<a href="https://colab.research.google.com/github/Kaua-Rbs/Deep-Learning-From-A-To-Z-With-Pytorch-And-Python/blob/main/Project_3_Parameter_Tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Project 3: Breast Cancer Binary Classification - Parameter Tuning

## Importing Libraries

In [1]:
!pip install skorch

Collecting skorch
  Downloading skorch-1.0.0-py3-none-any.whl (239 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m239.4/239.4 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: skorch
Successfully installed skorch-1.0.0


In [2]:
import pandas as pd
import numpy as np
import sklearn
import skorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import GridSearchCV
from skorch import NeuralNetBinaryClassifier

In [3]:
torch.__version__, skorch.__version__, sklearn.__version__

('2.3.0+cu121', '1.0.0', '1.2.2')

## Database

In [4]:
np.random.seed(123)
torch.manual_seed(123)

<torch._C.Generator at 0x780cdc195d30>

In [6]:
previsores = pd.read_csv('entradas_breast.csv')
classe = pd.read_csv('saidas_breast.csv')

In [7]:
previsores = np.array(previsores, dtype = 'float32')
classe = np.array(classe, dtype = 'float32').squeeze(1)

In [8]:
previsores.shape

(569, 30)

In [9]:
classe.shape

(569,)

## Neural Network Structure's Class Definition

In [10]:
class classificador_torch(nn.Module):
  def __init__(self, activation, neurons, initializer):
    super().__init__()
    # 30 -> 16 -> 16 -> 1
    self.dense0 = nn.Linear(30, neurons)
    initializer(self.dense0.weight)
    self.activation0 = activation
    self.dense1 = nn.Linear(neurons, neurons)
    initializer(self.dense1.weight)
    self.activation1 = activation
    self.dense2 = nn.Linear(neurons, 1)
    initializer(self.dense2.weight)

  def forward(self, X):
    X = self.dense0(X)
    X = self.activation0(X)
    X = self.dense1(X)
    X = self.activation1(X)
    X = self.dense2(X)
    return X

## Skorch

In [11]:
classificador_sklearn = NeuralNetBinaryClassifier(module=classificador_torch,
                                                  lr = 0.001,
                                                  optimizer__weight_decay = 0.0001,
                                                  train_split=False)

## Parameter Tuning

In [13]:
params = {'batch_size': [10],
          'max_epochs': [100],
          'optimizer': [torch.optim.Adam, torch.optim.SGD],
          'criterion': [torch.nn.BCEWithLogitsLoss],
          'module__activation': [F.relu, F.tanh],
          'module__neurons': [8, 16],
          'module__initializer': [torch.nn.init.uniform]}

In [14]:
params

{'batch_size': [10],
 'max_epochs': [100],
 'optimizer': [torch.optim.adam.Adam, torch.optim.sgd.SGD],
 'criterion': [torch.nn.modules.loss.BCEWithLogitsLoss],
 'module__activation': [<function torch.nn.functional.relu(input: torch.Tensor, inplace: bool = False) -> torch.Tensor>,
  <function torch.nn.functional.tanh(input)>],
 'module__neurons': [8, 16],
 'module__initializer': [<function torch.nn.init._make_deprecate.<locals>.deprecated_init(*args, **kwargs)>]}

In [15]:
grid_search = GridSearchCV(estimator=classificador_sklearn, param_grid=params,
                           scoring = 'accuracy', cv = 2)
grid_search = grid_search.fit(previsores, classe)

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


  epoch    train_loss     dur
-------  ------------  ------
      1    [36m27299.8736[0m  0.2129
      2    [36m24343.2034[0m  0.0460
      3    [36m21607.9790[0m  0.0422
      4    [36m19135.1196[0m  0.0409
      5    [36m16931.4971[0m  0.0419
      6    [36m14986.2345[0m  0.0417
      7    [36m13263.7452[0m  0.0428
      8    [36m11731.9724[0m  0.0439
      9    [36m10367.9490[0m  0.0415
     10     [36m9150.4826[0m  0.0527
     11     [36m8060.7256[0m  0.0422
     12     [36m7080.4180[0m  0.0415
     13     [36m6192.2901[0m  0.0469
     14     [36m5382.8819[0m  0.0422
     15     [36m4640.1386[0m  0.0422
     16     [36m3952.5591[0m  0.0460
     17     [36m3309.9036[0m  0.0510
     18     [36m2702.8657[0m  0.0430
     19     [36m2122.0785[0m  0.0444
     20     [36m1558.6171[0m  0.0425
     21     [36m1003.4027[0m  0.0452
     22      [36m478.4378[0m  0.0433
     23      [36m138.4220[0m  0.0415
     24       [36m90.7287[0m  0.0449
    

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      5    [36m12690.1400[0m  0.0445
      6    [36m11149.3147[0m  0.0421
      7     [36m9765.4873[0m  0.0431
      8     [36m8523.4345[0m  0.0537
      9     [36m7408.3080[0m  0.0415
     10     [36m6405.5082[0m  0.0420
     11     [36m5502.9294[0m  0.0412
     12     [36m4687.3043[0m  0.0412
     13     [36m3944.7716[0m  0.0486
     14     [36m3263.5476[0m  0.0553
     15     [36m2631.7690[0m  0.0400
     16     [36m2037.1351[0m  0.0400
     17     [36m1466.6355[0m  0.0406
     18      [36m907.2842[0m  0.0443
     19      [36m368.1885[0m  0.0424
     20       [36m70.2290[0m  0.0409
     21       [36m64.8927[0m  0.0446
     22       [36m41.8881[0m  0.0415
     23       [36m38.2924[0m  0.0405
     24       [36m34.4357[0m  0.0391
     25       [36m30.3832[0m  0.0416
     26       [36m28.4990[0m  0.0406
     27       [36m26.7977[0m  0.0416
     28       27.0994  0.0429
     29       [36m26.0352[0m  0.0409
     30       [36m24.2246[0m  0.

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      5        [36m0.6960[0m  0.0461
      6        [36m0.6956[0m  0.0439
      7        [36m0.6951[0m  0.0453
      8        [36m0.6946[0m  0.0407
      9        [36m0.6942[0m  0.0418
     10        [36m0.6937[0m  0.0432
     11        [36m0.6933[0m  0.0463
     12        [36m0.6928[0m  0.0429
     13        [36m0.6924[0m  0.0436
     14        [36m0.6920[0m  0.0569
     15        [36m0.6916[0m  0.0500
     16        [36m0.6912[0m  0.0458
     17        [36m0.6907[0m  0.0443
     18        [36m0.6903[0m  0.0396
     19        [36m0.6900[0m  0.0406
     20        [36m0.6896[0m  0.0401
     21        [36m0.6892[0m  0.0404
     22        [36m0.6888[0m  0.0427
     23        [36m0.6884[0m  0.0406
     24        [36m0.6881[0m  0.0404
     25        [36m0.6877[0m  0.0395
     26        [36m0.6873[0m  0.0406
     27        [36m0.6870[0m  0.0406
     28        [36m0.6866[0m  0.0419
     29        [36m0.6863[0m  0.0426
     30        [36m0.685

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      6        [36m0.7215[0m  0.0324
      7        [36m0.7206[0m  0.0315
      8        [36m0.7197[0m  0.0324
      9        [36m0.7188[0m  0.0328
     10        [36m0.7179[0m  0.0324
     11        [36m0.7171[0m  0.0341
     12        [36m0.7162[0m  0.0316
     13        [36m0.7154[0m  0.0321
     14        [36m0.7146[0m  0.0331
     15        [36m0.7137[0m  0.0315
     16        [36m0.7129[0m  0.0329
     17        [36m0.7122[0m  0.0324
     18        [36m0.7114[0m  0.0311
     19        [36m0.7106[0m  0.0354
     20        [36m0.7099[0m  0.0351
     21        [36m0.7091[0m  0.0325
     22        [36m0.7084[0m  0.0318
     23        [36m0.7077[0m  0.0321
     24        [36m0.7070[0m  0.0322
     25        [36m0.7063[0m  0.0401
     26        [36m0.7056[0m  0.0413
     27        [36m0.7049[0m  0.0478
     28        [36m0.7042[0m  0.0325
     29        [36m0.7036[0m  0.0327
     30        [36m0.7029[0m  0.0334
     31        [36m0.702

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      5    [36m40394.8678[0m  0.0488
      6    [36m34179.8085[0m  0.0421
      7    [36m28505.4586[0m  0.0420
      8    [36m23266.5479[0m  0.0420
      9    [36m18374.1998[0m  0.0567
     10    [36m13746.7589[0m  0.0527
     11     [36m9300.1381[0m  0.0436
     12     [36m4939.4393[0m  0.0495
     13     [36m1247.1749[0m  0.0438
     14      [36m178.5840[0m  0.0405
     15      191.0505  0.0420
     16      261.4461  0.0399
     17      239.5681  0.0402
     18      236.6766  0.0408
     19      238.6056  0.0404
     20      230.9262  0.0399
     21      223.7864  0.0411
     22      220.1303  0.0414
     23      216.8497  0.0400
     24      204.7691  0.0418
     25      190.7830  0.0398
     26      207.1949  0.0406
     27      196.0288  0.0461
     28      195.7723  0.0431
     29      193.2078  0.0402
     30      183.0049  0.0382
     31      187.3162  0.0487
     32      188.7900  0.0554
     33      179.1142  0.0407
     34      [36m166.8962[0m  0.0402
 

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      4    [36m57591.1555[0m  0.0596
      5    [36m50579.8031[0m  0.0522
      6    [36m44267.9887[0m  0.0541
      7    [36m38596.0440[0m  0.0514
      8    [36m33501.1193[0m  0.0510
      9    [36m28921.3914[0m  0.0483
     10    [36m24799.2250[0m  0.0547
     11    [36m21079.0999[0m  0.0553
     12    [36m17714.4012[0m  0.0544
     13    [36m14659.1779[0m  0.0553
     14    [36m11855.1358[0m  0.0595
     15     [36m9243.7219[0m  0.0596
     16     [36m6774.6028[0m  0.0591
     17     [36m4394.1975[0m  0.0621
     18     [36m2069.4815[0m  0.0569
     19      [36m403.7693[0m  0.0581
     20      [36m196.6113[0m  0.0663
     21      207.4359  0.0550
     22      [36m195.5148[0m  0.0555
     23      [36m190.5136[0m  0.0589
     24      [36m182.9846[0m  0.0589
     25      [36m175.6850[0m  0.0590
     26      [36m167.5373[0m  0.0703
     27      [36m156.4418[0m  0.0653
     28      [36m148.5057[0m  0.0699
     29      [36m145.9227[0m  0.

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      7        [36m0.6716[0m  0.0402
      8        [36m0.6714[0m  0.0347
      9        [36m0.6713[0m  0.0374
     10        [36m0.6711[0m  0.0344
     11        [36m0.6710[0m  0.0335
     12        [36m0.6709[0m  0.0328
     13        [36m0.6708[0m  0.0346
     14        [36m0.6706[0m  0.0365
     15        [36m0.6705[0m  0.0431
     16        [36m0.6704[0m  0.0402
     17        [36m0.6703[0m  0.0313
     18        [36m0.6701[0m  0.0342
     19        [36m0.6700[0m  0.0320
     20        [36m0.6699[0m  0.0330
     21        [36m0.6698[0m  0.0315
     22        [36m0.6697[0m  0.0321
     23        [36m0.6696[0m  0.0317
     24        [36m0.6695[0m  0.0323
     25        [36m0.6694[0m  0.0303
     26        [36m0.6692[0m  0.0313
     27        [36m0.6691[0m  0.0306
     28        [36m0.6690[0m  0.0437
     29        [36m0.6689[0m  0.0360
     30        [36m0.6688[0m  0.0360
     31        [36m0.6687[0m  0.0353
     32        [36m0.668

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      6        [36m0.7174[0m  0.0327
      7        [36m0.7165[0m  0.0302
      8        [36m0.7157[0m  0.0379
      9        [36m0.7149[0m  0.0304
     10        [36m0.7140[0m  0.0304
     11        [36m0.7132[0m  0.0298
     12        [36m0.7124[0m  0.0315
     13        [36m0.7117[0m  0.0344
     14        [36m0.7109[0m  0.0332
     15        [36m0.7101[0m  0.0340
     16        [36m0.7094[0m  0.0349
     17        [36m0.7086[0m  0.0331
     18        [36m0.7079[0m  0.0322
     19        [36m0.7072[0m  0.0316
     20        [36m0.7065[0m  0.0320
     21        [36m0.7058[0m  0.0451
     22        [36m0.7051[0m  0.0416
     23        [36m0.7044[0m  0.0374
     24        [36m0.7038[0m  0.0291
     25        [36m0.7031[0m  0.0300
     26        [36m0.7025[0m  0.0334
     27        [36m0.7019[0m  0.0308
     28        [36m0.7012[0m  0.0304
     29        [36m0.7006[0m  0.0302
     30        [36m0.7000[0m  0.0298
     31        [36m0.699

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      4        [36m1.7189[0m  0.0514
      5        [36m1.6472[0m  0.0537
      6        [36m1.5726[0m  0.0547
      7        [36m1.4901[0m  0.0565
      8        [36m1.3849[0m  0.0563
      9        [36m1.2230[0m  0.0596
     10        [36m0.9781[0m  0.0518
     11        [36m0.7669[0m  0.0520
     12        [36m0.6936[0m  0.0520
     13        [36m0.6770[0m  0.0578
     14        [36m0.6739[0m  0.0549
     15        [36m0.6731[0m  0.0536
     16        [36m0.6724[0m  0.0603
     17        [36m0.6719[0m  0.0602
     18        [36m0.6715[0m  0.0618
     19        [36m0.6712[0m  0.0681
     20        [36m0.6709[0m  0.0515
     21        [36m0.6707[0m  0.0608
     22        [36m0.6705[0m  0.0614
     23        [36m0.6704[0m  0.0524
     24        [36m0.6702[0m  0.0593
     25        [36m0.6701[0m  0.0582
     26        [36m0.6700[0m  0.0579
     27        [36m0.6699[0m  0.0566
     28        [36m0.6698[0m  0.0575
     29        [36m0.669

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      5        [36m1.2859[0m  0.0447
      6        [36m1.2117[0m  0.0422
      7        [36m1.1254[0m  0.0442
      8        [36m1.0169[0m  0.0421
      9        [36m0.8868[0m  0.0447
     10        [36m0.7715[0m  0.0421
     11        [36m0.7064[0m  0.0421
     12        [36m0.6806[0m  0.0434
     13        [36m0.6718[0m  0.0434
     14        [36m0.6686[0m  0.0424
     15        [36m0.6674[0m  0.0437
     16        [36m0.6667[0m  0.0414
     17        [36m0.6664[0m  0.0557
     18        [36m0.6661[0m  0.0413
     19        [36m0.6660[0m  0.0417
     20        [36m0.6658[0m  0.0508
     21        [36m0.6657[0m  0.0531
     22        [36m0.6656[0m  0.0448
     23        [36m0.6656[0m  0.0418
     24        [36m0.6655[0m  0.0430
     25        [36m0.6654[0m  0.0432
     26        [36m0.6654[0m  0.0445
     27        [36m0.6653[0m  0.0422
     28        [36m0.6653[0m  0.0419
     29        [36m0.6653[0m  0.0417
     30        [36m0.665

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      6        [36m1.4093[0m  0.0358
      7        [36m1.3770[0m  0.0316
      8        [36m1.3452[0m  0.0404
      9        [36m1.3138[0m  0.0407
     10        [36m1.2829[0m  0.0408
     11        [36m1.2525[0m  0.0405
     12        [36m1.2227[0m  0.0312
     13        [36m1.1935[0m  0.0337
     14        [36m1.1650[0m  0.0346
     15        [36m1.1372[0m  0.0321
     16        [36m1.1100[0m  0.0332
     17        [36m1.0837[0m  0.0328
     18        [36m1.0581[0m  0.0322
     19        [36m1.0333[0m  0.0310
     20        [36m1.0094[0m  0.0315
     21        [36m0.9864[0m  0.0297
     22        [36m0.9642[0m  0.0298
     23        [36m0.9430[0m  0.0304
     24        [36m0.9228[0m  0.0307
     25        [36m0.9035[0m  0.0311
     26        [36m0.8851[0m  0.0308
     27        [36m0.8677[0m  0.0306
     28        [36m0.8513[0m  0.0301
     29        [36m0.8358[0m  0.0319
     30        [36m0.8213[0m  0.0407
     31        [36m0.807

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      5        [36m1.3860[0m  0.0422
      6        [36m1.3556[0m  0.0429
      7        [36m1.3257[0m  0.0423
      8        [36m1.2962[0m  0.0492
      9        [36m1.2672[0m  0.0416
     10        [36m1.2387[0m  0.0391
     11        [36m1.2107[0m  0.0449
     12        [36m1.1833[0m  0.0523
     13        [36m1.1565[0m  0.0444
     14        [36m1.1303[0m  0.0520
     15        [36m1.1048[0m  0.0420
     16        [36m1.0800[0m  0.0445
     17        [36m1.0559[0m  0.0403
     18        [36m1.0326[0m  0.0400
     19        [36m1.0100[0m  0.0451
     20        [36m0.9883[0m  0.0447
     21        [36m0.9673[0m  0.0451
     22        [36m0.9472[0m  0.0436
     23        [36m0.9280[0m  0.0431
     24        [36m0.9096[0m  0.0453
     25        [36m0.8920[0m  0.0433
     26        [36m0.8754[0m  0.0452
     27        [36m0.8596[0m  0.0473
     28        [36m0.8446[0m  0.0348
     29        [36m0.8305[0m  0.0327
     30        [36m0.817

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      5        [36m2.1681[0m  0.0438
      6        [36m2.0287[0m  0.0417
      7        [36m1.8666[0m  0.0448
      8        [36m1.5984[0m  0.0503
      9        [36m1.1698[0m  0.0503
     10        [36m0.8075[0m  0.0506
     11        [36m0.6951[0m  0.0436
     12        [36m0.6748[0m  0.0393
     13        [36m0.6708[0m  0.0408
     14        [36m0.6696[0m  0.0458
     15        [36m0.6690[0m  0.0579
     16        [36m0.6688[0m  0.0522
     17        [36m0.6686[0m  0.0406
     18        [36m0.6684[0m  0.0441
     19        [36m0.6684[0m  0.0416
     20        [36m0.6683[0m  0.0491
     21        [36m0.6682[0m  0.0418
     22        [36m0.6682[0m  0.0409
     23        [36m0.6682[0m  0.0405
     24        [36m0.6681[0m  0.0404
     25        [36m0.6681[0m  0.0416
     26        [36m0.6680[0m  0.0401
     27        [36m0.6680[0m  0.0422
     28        [36m0.6680[0m  0.0424
     29        [36m0.6679[0m  0.0401
     30        [36m0.667

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      5        [36m1.7417[0m  0.0449
      6        [36m1.6123[0m  0.0443
      7        [36m1.4851[0m  0.0454
      8        [36m1.3601[0m  0.0433
      9        [36m1.2281[0m  0.0485
     10        [36m1.0634[0m  0.0400
     11        [36m0.8767[0m  0.0416
     12        [36m0.7321[0m  0.0410
     13        [36m0.6801[0m  0.0437
     14        [36m0.6709[0m  0.0401
     15        [36m0.6692[0m  0.0413
     16        [36m0.6686[0m  0.0409
     17        [36m0.6683[0m  0.0415
     18        [36m0.6680[0m  0.0425
     19        [36m0.6678[0m  0.0475
     20        [36m0.6677[0m  0.0413
     21        [36m0.6676[0m  0.0439
     22        [36m0.6672[0m  0.0445
     23        [36m0.6499[0m  0.0481
     24        0.6893  0.0551
     25        0.6823  0.0437
     26        0.6777  0.0408
     27        0.6753  0.0444
     28        0.6739  0.0443
     29        0.6729  0.0417
     30        0.6721  0.0419
     31        0.6714  0.0475
     32        0.67

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      5        [36m2.7880[0m  0.0426
      6        [36m2.7181[0m  0.0476
      7        [36m2.6482[0m  0.0445
      8        [36m2.5785[0m  0.0491
      9        [36m2.5088[0m  0.0472
     10        [36m2.4391[0m  0.0464
     11        [36m2.3696[0m  0.0465
     12        [36m2.3002[0m  0.0434
     13        [36m2.2310[0m  0.0328
     14        [36m2.1619[0m  0.0315
     15        [36m2.0931[0m  0.0323
     16        [36m2.0245[0m  0.0345
     17        [36m1.9562[0m  0.0313
     18        [36m1.8883[0m  0.0318
     19        [36m1.8208[0m  0.0311
     20        [36m1.7538[0m  0.0426
     21        [36m1.6875[0m  0.0415
     22        [36m1.6219[0m  0.0345
     23        [36m1.5571[0m  0.0313
     24        [36m1.4934[0m  0.0318
     25        [36m1.4308[0m  0.0324
     26        [36m1.3697[0m  0.0328
     27        [36m1.3101[0m  0.0322
     28        [36m1.2523[0m  0.0328
     29        [36m1.1966[0m  0.0319
     30        [36m1.143

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      6        [36m3.8201[0m  0.0437
      7        [36m3.7531[0m  0.0363
      8        [36m3.6861[0m  0.0314
      9        [36m3.6191[0m  0.0349
     10        [36m3.5520[0m  0.0314
     11        [36m3.4850[0m  0.0333
     12        [36m3.4180[0m  0.0322
     13        [36m3.3511[0m  0.0318
     14        [36m3.2841[0m  0.0325
     15        [36m3.2171[0m  0.0309
     16        [36m3.1501[0m  0.0330
     17        [36m3.0832[0m  0.0312
     18        [36m3.0163[0m  0.0325
     19        [36m2.9494[0m  0.0311
     20        [36m2.8825[0m  0.0329
     21        [36m2.8156[0m  0.0321
     22        [36m2.7488[0m  0.0329
     23        [36m2.6820[0m  0.0410
     24        [36m2.6153[0m  0.0341
     25        [36m2.5486[0m  0.0322
     26        [36m2.4820[0m  0.0325
     27        [36m2.4155[0m  0.0323
     28        [36m2.3491[0m  0.0330
     29        [36m2.2828[0m  0.0332
     30        [36m2.2167[0m  0.0330
     31        [36m2.150

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      3    [36m51267.8822[0m  0.0858
      4    [36m38467.3745[0m  0.0922
      5    [36m28422.4937[0m  0.0852
      6    [36m20474.0159[0m  0.0862
      7    [36m14073.4524[0m  0.0989
      8     [36m8759.0203[0m  0.0805
      9     [36m4118.9056[0m  0.0838
     10      [36m574.8635[0m  0.0813
     11      [36m151.9170[0m  0.0797
     12      [36m122.1264[0m  0.0834
     13      [36m103.8878[0m  0.0819
     14       [36m87.6194[0m  0.0813
     15       [36m72.8501[0m  0.0808
     16       [36m64.1713[0m  0.0911
     17       [36m63.3306[0m  0.0862
     18       [36m55.2647[0m  0.0858
     19       58.0386  0.0874
     20       58.5939  0.0832
     21       [36m51.1864[0m  0.0835
     22       [36m46.2947[0m  0.0821
     23       [36m43.7537[0m  0.0812
     24       [36m43.3757[0m  0.0825
     25       48.3145  0.0797
     26       53.5200  0.0810
     27       47.2430  0.0914
     28       [36m36.6532[0m  0.0939
     29       38.5515  0.0827
 

In [16]:
melhores_parametros = grid_search.best_params_
melhor_precisao = grid_search.best_score_

In [17]:
melhores_parametros

{'batch_size': 10,
 'criterion': torch.nn.modules.loss.BCEWithLogitsLoss,
 'max_epochs': 100,
 'module__activation': <function torch.nn.functional.relu(input: torch.Tensor, inplace: bool = False) -> torch.Tensor>,
 'module__initializer': <function torch.nn.init._make_deprecate.<locals>.deprecated_init(*args, **kwargs)>,
 'module__neurons': 16,
 'optimizer': torch.optim.adam.Adam}

In [18]:
melhor_precisao

0.7944897454904868