<a href="https://colab.research.google.com/github/MarcelodeFreitas/udemy_deep_learning_pytorch_python/blob/main/Projeto_3_Classifica%C3%A7%C3%A3o_bin%C3%A1ria_breast_cancer_com_tuning_de_par%C3%A2metros.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Projeto 3: Classificação binária brest cancer com tuning dos parâmetros

## Etapa 1: Importação das bibliotecas

In [6]:
!pip install skorch




[notice] A new release of pip is available: 23.3.2 -> 24.0
[notice] To update, run: python.exe -m pip install --upgrade pip


In [7]:
import pandas as pd
import numpy as np
import sklearn
import skorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import GridSearchCV
from skorch import NeuralNetBinaryClassifier

In [8]:
torch.__version__, skorch.__version__, sklearn.__version__

('1.13.1+cpu', '0.13.0', '1.0.2')

## Etapa 2: Base de dados

In [9]:
np.random.seed(123)
torch.manual_seed(123)

<torch._C.Generator at 0x2334bd5a2f0>

In [10]:
previsores = pd.read_csv('./databases/entradas_breast.csv')
classe = pd.read_csv('./databases/saidas_breast.csv')

In [11]:
previsores = np.array(previsores, dtype = 'float32')
classe = np.array(classe, dtype = 'float32').squeeze(1)

In [12]:
previsores.shape

(569, 30)

In [13]:
classe.shape

(569,)

## Etapa 3: Classe para estrutura da rede neural

**\*\* ATUALIZAÇÃO JAN/2022 \*\*** : na versão atual do Skorch, os resultados da rede neural devem ser retornados sem ativação, ou seja, sem a camada sigmoide no final. Com isto, a função de custo deve ser `BCEWithLogitsLoss`.

In [14]:
class classificador_torch(nn.Module):
  def __init__(self, activation, neurons, initializer):
    super().__init__()
    # 30 -> 16 -> 16 -> 1
    self.dense0 = nn.Linear(30, neurons)
    initializer(self.dense0.weight)
    self.activation0 = activation
    self.dense1 = nn.Linear(neurons, neurons)
    initializer(self.dense1.weight)
    self.activation1 = activation
    self.dense2 = nn.Linear(neurons, 1)
    initializer(self.dense2.weight)
    # self.output = nn.Sigmoid() ** ATUALIZAÇÃO (ver detalhes no texto acima) **

  def forward(self, X):
    X = self.dense0(X)
    X = self.activation0(X)
    X = self.dense1(X)
    X = self.activation1(X)
    X = self.dense2(X)
    # X = self.output(X) ** ATUALIZAÇÃO (ver detalhes no texto acima) **
    return X

## Etapa 4: Skorch

In [15]:
classificador_sklearn = NeuralNetBinaryClassifier(module=classificador_torch,
                                                  lr = 0.001,
                                                  optimizer__weight_decay = 0.0001,
                                                  train_split=False)

## Etapa 5: Tuning dos parâmetros

In [16]:
params = {'batch_size': [10],
          'max_epochs': [100],
          'optimizer': [torch.optim.Adam, torch.optim.SGD],
          'criterion': [torch.nn.BCEWithLogitsLoss], #, torch.nn.HingeEmbeddingLoss], # ** ATUALIZAÇÃO **
          'module__activation': [F.relu, F.tanh],
          'module__neurons': [8, 16],
          'module__initializer': [torch.nn.init.uniform]} # _, torch.nn.init.normal_]}

In [17]:
params

{'batch_size': [10],
 'max_epochs': [100],
 'optimizer': [torch.optim.adam.Adam, torch.optim.sgd.SGD],
 'criterion': [torch.nn.modules.loss.BCEWithLogitsLoss],
 'module__activation': [<function torch.nn.functional.relu(input: torch.Tensor, inplace: bool = False) -> torch.Tensor>,
  <function torch.nn.functional.tanh(input)>],
 'module__neurons': [8, 16],
 'module__initializer': [<function torch.nn.init._make_deprecate.<locals>.deprecated_init(*args, **kwargs)>]}

In [18]:
grid_search = GridSearchCV(estimator=classificador_sklearn, param_grid=params,
                           scoring = 'accuracy', cv = 2)
grid_search = grid_search.fit(previsores, classe)

  epoch    train_loss     dur
-------  ------------  ------
      1    [36m27299.8738[0m  0.0884
      2    [36m24343.2038[0m  0.0677


  
  if __name__ == "__main__":
  if sys.path[0] == "":


      3    [36m21607.9801[0m  0.0455
      4    [36m19135.1209[0m  0.0457
      5    [36m16931.4981[0m  0.0511
      6    [36m14986.2359[0m  0.0418
      7    [36m13263.7468[0m  0.0407


      8    [36m11731.9746[0m  0.0429
      9    [36m10367.9507[0m  0.0395
     10     [36m9150.4845[0m  0.0452
     11     [36m8060.7275[0m  0.0664
     12     [36m7080.4200[0m  0.0448
     13     [36m6192.2921[0m  0.0422
     14     [36m5382.8836[0m  0.0353
     15     [36m4640.1406[0m  0.0366
     16     [36m3952.5612[0m  0.0357
     17     [36m3309.9056[0m  0.0357
     18     [36m2702.8676[0m  0.0350
     19     [36m2122.0807[0m  0.0341
     20     [36m1558.6194[0m  0.0336
     21     [36m1003.4049[0m  0.0363
     22      [36m478.4398[0m  0.0338
     23      [36m138.4234[0m  0.0333
     24       [36m90.7292[0m  0.0341
     25       [36m88.3919[0m  0.0339
     26       [36m77.7052[0m  0.0302
     27       [36m70.4368[0m  0.0351
     28       [36m64.3089[0m  0.0308
     29       [36m58.6982[0m  0.0351
     30       [36m54.5392[0m  0.0333
     31       [36m52.0839[0m  0.0356
     32       [36m48.6802[0m  0.0351
     33       [36m46.166

  
  if __name__ == "__main__":
  if sys.path[0] == "":


      6    [36m11149.3161[0m  0.0349
      7     [36m9765.4890[0m  0.0354
      8     [36m8523.4365[0m  0.0315
      9     [36m7408.3099[0m  0.0300
     10     [36m6405.5098[0m  0.0403
     11     [36m5502.9313[0m  0.0355
     12     [36m4687.3060[0m  0.0367
     13     [36m3944.7731[0m  0.0368
     14     [36m3263.5494[0m  0.0380
     15     [36m2631.7707[0m  0.0337
     16     [36m2037.1368[0m  0.0359
     17     [36m1466.6373[0m  0.0354
     18      [36m907.2860[0m  0.0327
     19      [36m368.1899[0m  0.0328
     20       [36m70.2289[0m  0.0353
     21       [36m64.8930[0m  0.0341
     22       [36m41.8885[0m  0.0349
     23       [36m38.2927[0m  0.0312
     24       [36m34.4355[0m  0.0336
     25       [36m30.3832[0m  0.0336
     26       [36m28.4991[0m  0.0344
     27       [36m26.7977[0m  0.0324
     28       27.0994  0.0349
     29       [36m26.0353[0m  0.0330
     30       [36m24.2245[0m  0.0338
     31       [36m23.3312[0m  0.

  
  if __name__ == "__main__":
  if sys.path[0] == "":


      9        [36m0.6942[0m  0.0220
     10        [36m0.6937[0m  0.0226
     11        [36m0.6933[0m  0.0228
     12        [36m0.6928[0m  0.0236
     13        [36m0.6924[0m  0.0221
     14        [36m0.6920[0m  0.0200
     15        [36m0.6916[0m  0.0216
     16        [36m0.6912[0m  0.0211
     17        [36m0.6907[0m  0.0239
     18        [36m0.6903[0m  0.0241
     19        [36m0.6900[0m  0.0226
     20        [36m0.6896[0m  0.0226
     21        [36m0.6892[0m  0.0241
     22        [36m0.6888[0m  0.0244
     23        [36m0.6884[0m  0.0211
     24        [36m0.6881[0m  0.0223
     25        [36m0.6877[0m  0.0230
     26        [36m0.6873[0m  0.0222
     27        [36m0.6870[0m  0.0197
     28        [36m0.6866[0m  0.0220
     29        [36m0.6863[0m  0.0225
     30        [36m0.6859[0m  0.0216
     31        [36m0.6856[0m  0.0216
     32        [36m0.6853[0m  0.0221
     33        [36m0.6850[0m  0.0206
     34        [36m0.684

  
  if __name__ == "__main__":
  if sys.path[0] == "":


      9        [36m0.7188[0m  0.0216
     10        [36m0.7179[0m  0.0221
     11        [36m0.7171[0m  0.0218
     12        [36m0.7162[0m  0.0220
     13        [36m0.7154[0m  0.0208
     14        [36m0.7146[0m  0.0206
     15        [36m0.7137[0m  0.0221
     16        [36m0.7129[0m  0.0228
     17        [36m0.7122[0m  0.0221
     18        [36m0.7114[0m  0.0220
     19        [36m0.7106[0m  0.0230
     20        [36m0.7099[0m  0.0211
     21        [36m0.7091[0m  0.0220
     22        [36m0.7084[0m  0.0220
     23        [36m0.7077[0m  0.0216
     24        [36m0.7070[0m  0.0222
     25        [36m0.7063[0m  0.0211
     26        [36m0.7056[0m  0.0226
     27        [36m0.7049[0m  0.0248
     28        [36m0.7042[0m  0.0241
     29        [36m0.7036[0m  0.0206
     30        [36m0.7029[0m  0.0216
     31        [36m0.7023[0m  0.0221
     32        [36m0.7016[0m  0.0217
     33        [36m0.7010[0m  0.0216
     34        [36m0.700

  
  if __name__ == "__main__":
  if sys.path[0] == "":


      6    [36m34179.8145[0m  0.0352
      7    [36m28505.4649[0m  0.0321
      8    [36m23266.5548[0m  0.0341
      9    [36m18374.2068[0m  0.0327
     10    [36m13746.7661[0m  0.0341
     11     [36m9300.1456[0m  0.0316
     12     [36m4939.4480[0m  0.0337
     13     [36m1247.1757[0m  0.0344
     14      [36m178.5844[0m  0.0321
     15      191.0514  0.0329
     16      261.4456  0.0337
     17      239.5682  0.0336
     18      236.6767  0.0341
     19      238.6057  0.0306
     20      230.9263  0.0336
     21      223.7866  0.0322
     22      220.1303  0.0311
     23      216.8508  0.0337
     24      204.7688  0.0401
     25      190.7912  0.0330
     26      207.1829  0.0345
     27      196.0347  0.0349
     28      195.7878  0.0336
     29      194.2897  0.0337
     30      190.3566  0.0336
     31      182.2654  0.0353
     32      187.3828  0.0351
     33      188.6760  0.0338
     34      178.7014  0.0321
     35      [36m164.6644[0m  0.0331
     36   

  
  if __name__ == "__main__":
  if sys.path[0] == "":



      7    [36m38596.0505[0m  0.0350
      8    [36m33501.1263[0m  0.0308
      9    [36m28921.3997[0m  0.0327
     10    [36m24799.2336[0m  0.0311
     11    [36m21079.1082[0m  0.0307
     12    [36m17714.4094[0m  0.0322
     13    [36m14659.1855[0m  0.0314
     14    [36m11855.1441[0m  0.0328
     15     [36m9243.7297[0m  0.0321
     16     [36m6774.6108[0m  0.0326
     17     [36m4394.2052[0m  0.0332
     18     [36m2069.4897[0m  0.0336
     19      [36m403.7687[0m  0.0316
     20      [36m196.6150[0m  0.0333
     21      207.4304  0.0331
     22      [36m195.5365[0m  0.0306
     23      [36m190.5083[0m  0.0331
     24      [36m182.9861[0m  0.0332
     25      [36m175.7252[0m  0.0333
     26      [36m167.1133[0m  0.0301
     27      [36m157.1833[0m  0.0325
     28      [36m146.7859[0m  0.0326
     29      [36m142.2308[0m  0.0318
     30      [36m141.3955[0m  0.0297
     31      [36m138.5383[0m  0.0342
     32      [36m129.9394[0m  0

  
  if __name__ == "__main__":
  if sys.path[0] == "":



      9        [36m0.6713[0m  0.0228
     10        [36m0.6711[0m  0.0225
     11        [36m0.6710[0m  0.0231
     12        [36m0.6709[0m  0.0214
     13        [36m0.6708[0m  0.0207
     14        [36m0.6706[0m  0.0216
     15        [36m0.6705[0m  0.0221
     16        [36m0.6704[0m  0.0211
     17        [36m0.6703[0m  0.0231
     18        [36m0.6701[0m  0.0231
     19        [36m0.6700[0m  0.0222
     20        [36m0.6699[0m  0.0211
     21        [36m0.6698[0m  0.0210
     22        [36m0.6697[0m  0.0222
     23        [36m0.6696[0m  0.0220
     24        [36m0.6695[0m  0.0220
     25        [36m0.6694[0m  0.0225
     26        [36m0.6692[0m  0.0230
     27        [36m0.6691[0m  0.0217
     28        [36m0.6690[0m  0.0212
     29        [36m0.6689[0m  0.0211
     30        [36m0.6688[0m  0.0221
     31        [36m0.6687[0m  0.0196
     32        [36m0.6686[0m  0.0218
     33        [36m0.6685[0m  0.0230
     34        [36m0.66

  
  if __name__ == "__main__":
  if sys.path[0] == "":


      9        [36m0.7149[0m  0.0222
     10        [36m0.7140[0m  0.0231
     11        [36m0.7132[0m  0.0221
     12        [36m0.7124[0m  0.0216
     13        [36m0.7117[0m  0.0220
     14        [36m0.7109[0m  0.0241
     15        [36m0.7101[0m  0.0216
     16        [36m0.7094[0m  0.0197
     17        [36m0.7086[0m  0.0201
     18        [36m0.7079[0m  0.0230
     19        [36m0.7072[0m  0.0239
     20        [36m0.7065[0m  0.0222
     21        [36m0.7058[0m  0.0210
     22        [36m0.7051[0m  0.0221
     23        [36m0.7044[0m  0.0206
     24        [36m0.7038[0m  0.0201
     25        [36m0.7031[0m  0.0226
     26        [36m0.7025[0m  0.0224
     27        [36m0.7019[0m  0.0202
     28        [36m0.7012[0m  0.0228
     29        [36m0.7006[0m  0.0211
     30        [36m0.7000[0m  0.0222
     31        [36m0.6994[0m  0.0221
     32        [36m0.6988[0m  0.0226
     33        [36m0.6982[0m  0.0211
     34        [36m0.697

  
  if __name__ == "__main__":
  if sys.path[0] == "":


      6        [36m1.5726[0m  0.0311
      7        [36m1.4901[0m  0.0333
      8        [36m1.3849[0m  0.0336
      9        [36m1.2230[0m  0.0323
     10        [36m0.9781[0m  0.0306
     11        [36m0.7669[0m  0.0306
     12        [36m0.6936[0m  0.0322
     13        [36m0.6770[0m  0.0320
     14        [36m0.6739[0m  0.0306
     15        [36m0.6731[0m  0.0336
     16        [36m0.6724[0m  0.0327
     17        [36m0.6719[0m  0.0331
     18        [36m0.6715[0m  0.0316
     19        [36m0.6712[0m  0.0302
     20        [36m0.6709[0m  0.0327
     21        [36m0.6707[0m  0.0341
     22        [36m0.6705[0m  0.0311
     23        [36m0.6704[0m  0.0321
     24        [36m0.6702[0m  0.0312
     25        [36m0.6701[0m  0.0321
     26        [36m0.6700[0m  0.0311
     27        [36m0.6699[0m  0.0322
     28        [36m0.6698[0m  0.0312
     29        [36m0.6697[0m  0.0311
     30        [36m0.6697[0m  0.0335
     31        [36m0.669

  
  if __name__ == "__main__":
  if sys.path[0] == "":


      7        [36m1.1254[0m  0.0327
      8        [36m1.0169[0m  0.0318
      9        [36m0.8868[0m  0.0333
     10        [36m0.7715[0m  0.0342
     11        [36m0.7064[0m  0.0311
     12        [36m0.6806[0m  0.0336
     13        [36m0.6718[0m  0.0312
     14        [36m0.6686[0m  0.0328
     15        [36m0.6674[0m  0.0340
     16        [36m0.6667[0m  0.0321
     17        [36m0.6664[0m  0.0321
     18        [36m0.6661[0m  0.0337
     19        [36m0.6660[0m  0.0313
     20        [36m0.6658[0m  0.0308
     21        [36m0.6657[0m  0.0331
     22        [36m0.6656[0m  0.0336
     23        [36m0.6656[0m  0.0331
     24        [36m0.6655[0m  0.0307
     25        [36m0.6654[0m  0.0335
     26        [36m0.6654[0m  0.0321
     27        [36m0.6653[0m  0.0311
     28        [36m0.6653[0m  0.0326
     29        [36m0.6653[0m  0.0316
     30        [36m0.6652[0m  0.0327
     31        [36m0.6652[0m  0.0317
     32        [36m0.665

  
  if __name__ == "__main__":
  if sys.path[0] == "":


      9        [36m1.3138[0m  0.0226
     10        [36m1.2829[0m  0.0241
     11        [36m1.2525[0m  0.0233
     12        [36m1.2227[0m  0.0215
     13        [36m1.1935[0m  0.0232
     14        [36m1.1650[0m  0.0211
     15        [36m1.1372[0m  0.0193
     16        [36m1.1100[0m  0.0236
     17        [36m1.0837[0m  0.0227
     18        [36m1.0581[0m  0.0225
     19        [36m1.0333[0m  0.0201
     20        [36m1.0094[0m  0.0211
     21        [36m0.9864[0m  0.0211
     22        [36m0.9642[0m  0.0215
     23        [36m0.9430[0m  0.0201
     24        [36m0.9228[0m  0.0201
     25        [36m0.9035[0m  0.0251
     26        [36m0.8851[0m  0.0236
     27        [36m0.8677[0m  0.0231
     28        [36m0.8513[0m  0.0211
     29        [36m0.8358[0m  0.0222
     30        [36m0.8213[0m  0.0225
     31        [36m0.8077[0m  0.0226
     32        [36m0.7949[0m  0.0206
     33        [36m0.7831[0m  0.0226
     34        [36m0.772

  
  if __name__ == "__main__":
  if sys.path[0] == "":


      9        [36m1.2672[0m  0.0236
     10        [36m1.2387[0m  0.0211
     11        [36m1.2107[0m  0.0221
     12        [36m1.1833[0m  0.0206
     13        [36m1.1565[0m  0.0219
     14        [36m1.1303[0m  0.0226
     15        [36m1.1048[0m  0.0209
     16        [36m1.0800[0m  0.0238
     17        [36m1.0559[0m  0.0215
     18        [36m1.0326[0m  0.0217
     19        [36m1.0100[0m  0.0236
     20        [36m0.9883[0m  0.0216
     21        [36m0.9673[0m  0.0231
     22        [36m0.9472[0m  0.0209
     23        [36m0.9280[0m  0.0221
     24        [36m0.9096[0m  0.0216
     25        [36m0.8920[0m  0.0216
     26        [36m0.8754[0m  0.0239
     27        [36m0.8596[0m  0.0220
     28        [36m0.8446[0m  0.0241
     29        [36m0.8305[0m  0.0196
     30        [36m0.8173[0m  0.0206
     31        [36m0.8048[0m  0.0216
     32        [36m0.7932[0m  0.0210
     33        [36m0.7823[0m  0.0219
     34        [36m0.772

  
  if __name__ == "__main__":
  if sys.path[0] == "":



      7        [36m1.8666[0m  0.0323
      8        [36m1.5984[0m  0.0326
      9        [36m1.1698[0m  0.0331
     10        [36m0.8075[0m  0.0326
     11        [36m0.6951[0m  0.0331
     12        [36m0.6748[0m  0.0326
     13        [36m0.6708[0m  0.0337
     14        [36m0.6696[0m  0.0332
     15        [36m0.6690[0m  0.0332
     16        [36m0.6688[0m  0.0316
     17        [36m0.6686[0m  0.0343
     18        [36m0.6684[0m  0.0316
     19        [36m0.6684[0m  0.0341
     20        [36m0.6683[0m  0.0301
     21        [36m0.6682[0m  0.0311
     22        [36m0.6682[0m  0.0309
     23        [36m0.6682[0m  0.0328
     24        [36m0.6681[0m  0.0301
     25        [36m0.6681[0m  0.0326
     26        [36m0.6680[0m  0.0318
     27        [36m0.6680[0m  0.0321
     28        [36m0.6680[0m  0.0321
     29        [36m0.6679[0m  0.0344
     30        [36m0.6679[0m  0.0326
     31        [36m0.6678[0m  0.0342
     32        [36m0.66

  
  if __name__ == "__main__":
  if sys.path[0] == "":


      6        [36m1.6123[0m  0.0331
      7        [36m1.4851[0m  0.0320
      8        [36m1.3601[0m  0.0322
      9        [36m1.2281[0m  0.0325
     10        [36m1.0634[0m  0.0331
     11        [36m0.8767[0m  0.0334
     12        [36m0.7321[0m  0.0321
     13        [36m0.6801[0m  0.0332
     14        [36m0.6709[0m  0.0314
     15        [36m0.6692[0m  0.0326
     16        [36m0.6686[0m  0.0557
     17        [36m0.6683[0m  0.0337
     18        [36m0.6680[0m  0.0317
     19        [36m0.6678[0m  0.0321
     20        [36m0.6677[0m  0.0334
     21        [36m0.6676[0m  0.0331
     22        [36m0.6672[0m  0.0316
     23        [36m0.6499[0m  0.0321
     24        0.6893  0.0310
     25        0.6823  0.0327
     26        0.6777  0.0328
     27        0.6753  0.0336
     28        0.6739  0.0312
     29        0.6729  0.0316
     30        0.6721  0.0331
     31        0.6714  0.0331
     32        0.6707  0.0341
     33        0.6700  0.030

  
  if __name__ == "__main__":
  if sys.path[0] == "":


      8        [36m2.5785[0m  0.0200
      9        [36m2.5088[0m  0.0286
     10        [36m2.4391[0m  0.0265
     11        [36m2.3696[0m  0.0211
     12        [36m2.3002[0m  0.0221
     13        [36m2.2310[0m  0.0231
     14        [36m2.1619[0m  0.0220
     15        [36m2.0931[0m  0.0220
     16        [36m2.0245[0m  0.0246
     17        [36m1.9562[0m  0.0230
     18        [36m1.8883[0m  0.0236
     19        [36m1.8208[0m  0.0232
     20        [36m1.7538[0m  0.0226
     21        [36m1.6875[0m  0.0230
     22        [36m1.6219[0m  0.0222
     23        [36m1.5571[0m  0.0214
     24        [36m1.4934[0m  0.0227
     25        [36m1.4308[0m  0.0230
     26        [36m1.3697[0m  0.0264
     27        [36m1.3101[0m  0.0237
     28        [36m1.2523[0m  0.0231
     29        [36m1.1966[0m  0.0427
     30        [36m1.1432[0m  0.0291
     31        [36m1.0923[0m  0.0236
     32        [36m1.0442[0m  0.0221
     33        [36m0.999

  
  if __name__ == "__main__":
  if sys.path[0] == "":


      8        [36m3.6861[0m  0.0221
      9        [36m3.6191[0m  0.0230
     10        [36m3.5520[0m  0.0231
     11        [36m3.4850[0m  0.0236
     12        [36m3.4180[0m  0.0252
     13        [36m3.3511[0m  0.0206
     14        [36m3.2841[0m  0.0229
     15        [36m3.2171[0m  0.0227
     16        [36m3.1501[0m  0.0244
     17        [36m3.0832[0m  0.0201
     18        [36m3.0163[0m  0.0231
     19        [36m2.9494[0m  0.0225
     20        [36m2.8825[0m  0.0222
     21        [36m2.8156[0m  0.0226
     22        [36m2.7488[0m  0.0211
     23        [36m2.6820[0m  0.0223
     24        [36m2.6153[0m  0.0236
     25        [36m2.5486[0m  0.0206
     26        [36m2.4820[0m  0.0237
     27        [36m2.4155[0m  0.0205
     28        [36m2.3491[0m  0.0241
     29        [36m2.2828[0m  0.0217
     30        [36m2.2167[0m  0.0231
     31        [36m2.1507[0m  0.0221
     32        [36m2.0849[0m  0.0216
     33        [36m2.019

  
  if __name__ == "__main__":
  if sys.path[0] == "":


      4    [36m38467.3815[0m  0.0640
      5    [36m28422.5020[0m  0.0643
      6    [36m20474.0236[0m  0.0633
      7    [36m14073.4598[0m  0.0631
      8     [36m8759.0273[0m  0.0669
      9     [36m4118.9129[0m  0.0609
     10      [36m574.8676[0m  0.0650
     11      [36m151.9106[0m  0.0612
     12      [36m122.1280[0m  0.0659
     13      [36m103.7969[0m  0.0649
     14       [36m87.6232[0m  0.0659
     15       [36m72.9820[0m  0.0644
     16       [36m64.6275[0m  0.0641
     17       [36m64.3240[0m  0.0648
     18       [36m55.6824[0m  0.0674
     19       56.6579  0.0668
     20       [36m54.0908[0m  0.0661
     21       60.3754  0.0650
     22       61.2469  0.0614
     23       57.6553  0.0675
     24       [36m46.2579[0m  0.0632
     25       [36m41.6352[0m  0.0628
     26       [36m38.4564[0m  0.0641
     27       [36m37.3743[0m  0.0645
     28       40.1999  0.0638
     29       48.5139  0.0650
     30       54.3047  0.0662
     31   

In [19]:
melhores_parametros = grid_search.best_params_
melhor_precisao = grid_search.best_score_

In [20]:
melhores_parametros

{'batch_size': 10,
 'criterion': torch.nn.modules.loss.BCEWithLogitsLoss,
 'max_epochs': 100,
 'module__activation': <function torch.nn.functional.relu(input: torch.Tensor, inplace: bool = False) -> torch.Tensor>,
 'module__initializer': <function torch.nn.init._make_deprecate.<locals>.deprecated_init(*args, **kwargs)>,
 'module__neurons': 16,
 'optimizer': torch.optim.adam.Adam}

In [21]:
melhor_precisao

0.8383308623671856