In [1]:
import model, data

import torch
from torch.utils.data import DataLoader
from torch import optim, nn

import copy
import optuna

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 200
PATIENCE = 5

In [3]:
TRAIN = data.Lung_Dataset('train', verbose = 2)
VAL = data.Lung_Dataset('test', verbose = 2)

In [4]:
def train(model, device, loss_criterion, optimizer, train_loader, val_loader, epochs, patience):
    best_loss = float("inf")
    early_stop = 0
    best_weights = None
    
    for i in range(epochs):
        train_loss = train_epoch(model, device, loss_criterion, optimizer, train_loader)
        val_loss = validate(model, device, loss_criterion, val_loader)
        
        """
        Early Stopping 
        """
        if val_loss < best_loss:
            early_stop = 0
            best_loss = val_loss
            best_weights = copy.deepcopy(model.state_dict())
        else:
            early_stop += 1
                
        if early_stop == patience:
            model.load_state_dict(best_weights)
            break
    return best_loss

def train_epoch(model, device, loss_criterion, optimizer, train_loader):
    model.to(device)
    model.train()
    
    running_loss = 0
    
    counter = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        
        output = model.forward(data)
        
        target = target.argmax(dim=1, keepdim=True).float()
        
        loss = loss_criterion(output, target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        counter += 1

    return (running_loss / counter)

def validate(model, device, loss_criterion, val_loader):
    model.to(device)
    model.eval()

    correct = 0
    val_loss = 0

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)

            target = target.argmax(dim=1, keepdim=True).float()
            
            output = model.forward(data)
            val_loss += loss_criterion(output, target).item()
            
            pred = torch.round(output)
            equal_data = torch.sum(target.data == pred).item()
            correct += equal_data
    
    return (val_loss / len(val_loader))

In [5]:
def define_model(trial):
    p = trial.suggest_float("dp", 0, 0.2)
    print(p)
    return model.CNN(dropout=p)

def objective(trial):
    cnn = define_model(trial).to(DEVICE)
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    wd = trial.suggest_float("wd", 0, 0.2)
    o = optim.AdamW(cnn.parameters(), lr=lr, weight_decay=wd)
    bs = trial.suggest_categorical("bs", [16, 32, 64])
    train_loader = DataLoader(TRAIN, batch_size=bs, shuffle=True)
    val_loader = DataLoader(VAL, batch_size=bs, shuffle=True)
    print(lr, wd, bs)
    return train(cnn, DEVICE, nn.BCELoss(), o, train_loader, val_loader, EPOCHS, PATIENCE)
    

In [6]:
study = optuna.create_study()
study.optimize(objective, n_trials=50, timeout=3600)
print("Best trial:")
trial = study.best_trial
print("  Value: ", trial.value)
print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

[32m[I 2021-03-20 10:29:09,780][0m A new study created in memory with name: no-name-a41b37b7-57b7-47a0-a077-56244a033fe0[0m


0.06057165318859972
0.0013520931046449997 0.11347722692407791 16


[32m[I 2021-03-20 10:30:24,255][0m Trial 0 finished with value: 0.3333252854645252 and parameters: {'dp': 0.06057165318859972, 'lr': 0.0013520931046449997, 'wd': 0.11347722692407791, 'bs': 16}. Best is trial 0 with value: 0.3333252854645252.[0m


0.1522951009670115
2.3306972611332212e-05 0.13034043863064404 64


[32m[I 2021-03-20 10:31:14,021][0m Trial 1 finished with value: 0.5825112164020538 and parameters: {'dp': 0.1522951009670115, 'lr': 2.3306972611332212e-05, 'wd': 0.13034043863064404, 'bs': 64}. Best is trial 0 with value: 0.3333252854645252.[0m


0.10685277468376463
0.00011111424403325367 0.022163514510014838 64


[32m[I 2021-03-20 10:32:14,624][0m Trial 2 finished with value: 0.36531396210193634 and parameters: {'dp': 0.10685277468376463, 'lr': 0.00011111424403325367, 'wd': 0.022163514510014838, 'bs': 64}. Best is trial 0 with value: 0.3333252854645252.[0m


0.02019899633953295
0.0001779956475560491 0.13440266698592862 64


[32m[I 2021-03-20 10:33:59,497][0m Trial 3 finished with value: 0.2999621133009593 and parameters: {'dp': 0.02019899633953295, 'lr': 0.0001779956475560491, 'wd': 0.13440266698592862, 'bs': 64}. Best is trial 3 with value: 0.2999621133009593.[0m


0.07216065625585566
0.00022040717935498526 0.08267768486777832 64


[32m[I 2021-03-20 10:35:49,913][0m Trial 4 finished with value: 0.27932196110486984 and parameters: {'dp': 0.07216065625585566, 'lr': 0.00022040717935498526, 'wd': 0.08267768486777832, 'bs': 64}. Best is trial 4 with value: 0.27932196110486984.[0m


0.12808882415451758
0.0009187683331167611 0.030708369334321618 16


[32m[I 2021-03-20 10:37:12,047][0m Trial 5 finished with value: 0.31964680925011635 and parameters: {'dp': 0.12808882415451758, 'lr': 0.0009187683331167611, 'wd': 0.030708369334321618, 'bs': 16}. Best is trial 4 with value: 0.27932196110486984.[0m


0.1944411572945465
0.0001902598948985696 0.027793391185003458 16


[32m[I 2021-03-20 10:40:13,228][0m Trial 6 finished with value: 0.31035185636331636 and parameters: {'dp': 0.1944411572945465, 'lr': 0.0001902598948985696, 'wd': 0.027793391185003458, 'bs': 16}. Best is trial 4 with value: 0.27932196110486984.[0m


0.049379175134508624
4.709053259516035e-05 0.08116132550855143 64


[32m[I 2021-03-20 10:41:52,392][0m Trial 7 finished with value: 0.3349460909763972 and parameters: {'dp': 0.049379175134508624, 'lr': 4.709053259516035e-05, 'wd': 0.08116132550855143, 'bs': 64}. Best is trial 4 with value: 0.27932196110486984.[0m


0.06368417183913193
0.000681851099055944 0.19212139090443808 16


[32m[I 2021-03-20 10:43:14,725][0m Trial 8 finished with value: 0.2942549393822749 and parameters: {'dp': 0.06368417183913193, 'lr': 0.000681851099055944, 'wd': 0.19212139090443808, 'bs': 16}. Best is trial 4 with value: 0.27932196110486984.[0m


0.07692404735891667
0.05646183215795199 0.19322731285235528 16


[32m[I 2021-03-20 10:44:25,401][0m Trial 9 finished with value: 0.5773775577545166 and parameters: {'dp': 0.07692404735891667, 'lr': 0.05646183215795199, 'wd': 0.19322731285235528, 'bs': 16}. Best is trial 4 with value: 0.27932196110486984.[0m


0.001170578262552291
0.01443454289045443 0.06859704531443177 32


[32m[I 2021-03-20 10:45:05,630][0m Trial 10 finished with value: 0.5203975414236387 and parameters: {'dp': 0.001170578262552291, 'lr': 0.01443454289045443, 'wd': 0.06859704531443177, 'bs': 32}. Best is trial 4 with value: 0.27932196110486984.[0m


0.03458257847689693
0.0011738154931952983 0.18434187580613104 32


[32m[I 2021-03-20 10:46:14,777][0m Trial 11 finished with value: 0.34633087863524753 and parameters: {'dp': 0.03458257847689693, 'lr': 0.0011738154931952983, 'wd': 0.18434187580613104, 'bs': 32}. Best is trial 4 with value: 0.27932196110486984.[0m


0.08693234840969014
0.004810554586855008 0.15987126992129191 64


[32m[I 2021-03-20 10:48:14,413][0m Trial 12 finished with value: 0.3580191433429718 and parameters: {'dp': 0.08693234840969014, 'lr': 0.004810554586855008, 'wd': 0.15987126992129191, 'bs': 64}. Best is trial 4 with value: 0.27932196110486984.[0m


0.10167809408410662
1.0041853298257173e-05 0.06852387115957422 16


[32m[I 2021-03-20 10:50:57,338][0m Trial 13 finished with value: 0.3587395070741574 and parameters: {'dp': 0.10167809408410662, 'lr': 1.0041853298257173e-05, 'wd': 0.06852387115957422, 'bs': 16}. Best is trial 4 with value: 0.27932196110486984.[0m


0.0667346939564297
0.0004536092613777286 0.09449274643284844 16


[32m[I 2021-03-20 10:52:24,965][0m Trial 14 finished with value: 0.29124442177514237 and parameters: {'dp': 0.0667346939564297, 'lr': 0.0004536092613777286, 'wd': 0.09449274643284844, 'bs': 16}. Best is trial 4 with value: 0.27932196110486984.[0m


0.13169742252445413
0.005242674157198616 0.0927586095975866 32


[32m[I 2021-03-20 10:53:20,917][0m Trial 15 finished with value: 0.38984114676713943 and parameters: {'dp': 0.13169742252445413, 'lr': 0.005242674157198616, 'wd': 0.0927586095975866, 'bs': 32}. Best is trial 4 with value: 0.27932196110486984.[0m


0.003270953366463089
0.00039791067130735277 0.04776987858499752 64


[32m[I 2021-03-20 10:54:10,023][0m Trial 16 finished with value: 0.3481379598379135 and parameters: {'dp': 0.003270953366463089, 'lr': 0.00039791067130735277, 'wd': 0.04776987858499752, 'bs': 64}. Best is trial 4 with value: 0.27932196110486984.[0m


0.03796985227476649
5.297636836200823e-05 0.11897406624194262 16


[32m[I 2021-03-20 10:55:30,640][0m Trial 17 finished with value: 0.33227191989620525 and parameters: {'dp': 0.03796985227476649, 'lr': 5.297636836200823e-05, 'wd': 0.11897406624194262, 'bs': 16}. Best is trial 4 with value: 0.27932196110486984.[0m


0.07997327678918566
0.003353974191500774 0.051742753271104766 64


[32m[I 2021-03-20 10:56:19,751][0m Trial 18 finished with value: 0.3462867836157481 and parameters: {'dp': 0.07997327678918566, 'lr': 0.003353974191500774, 'wd': 0.051742753271104766, 'bs': 64}. Best is trial 4 with value: 0.27932196110486984.[0m


0.12058782077651486
0.00040622927705186553 0.10109170243724203 64


[32m[I 2021-03-20 10:57:52,413][0m Trial 19 finished with value: 0.28673263142506283 and parameters: {'dp': 0.12058782077651486, 'lr': 0.00040622927705186553, 'wd': 0.10109170243724203, 'bs': 64}. Best is trial 4 with value: 0.27932196110486984.[0m


0.17422375282017383
1.5208760650365768e-05 0.15486235055100878 64


[32m[I 2021-03-20 10:59:57,377][0m Trial 20 finished with value: 0.4554918756087621 and parameters: {'dp': 0.17422375282017383, 'lr': 1.5208760650365768e-05, 'wd': 0.15486235055100878, 'bs': 64}. Best is trial 4 with value: 0.27932196110486984.[0m


0.11426529558933218
0.00034525434140349475 0.10188887454787803 64


[32m[I 2021-03-20 11:01:35,342][0m Trial 21 finished with value: 0.31297074258327484 and parameters: {'dp': 0.11426529558933218, 'lr': 0.00034525434140349475, 'wd': 0.10188887454787803, 'bs': 64}. Best is trial 4 with value: 0.27932196110486984.[0m


0.09263586646860812
8.453370991215703e-05 2.8914725989565015e-05 64


[32m[I 2021-03-20 11:02:29,368][0m Trial 22 finished with value: 0.3968365043401718 and parameters: {'dp': 0.09263586646860812, 'lr': 8.453370991215703e-05, 'wd': 2.8914725989565015e-05, 'bs': 64}. Best is trial 4 with value: 0.27932196110486984.[0m


0.13969490331970516
0.0004065652008039802 0.09035338883128065 64


[32m[I 2021-03-20 11:03:51,533][0m Trial 23 finished with value: 0.3199833979209264 and parameters: {'dp': 0.13969490331970516, 'lr': 0.0004065652008039802, 'wd': 0.09035338883128065, 'bs': 64}. Best is trial 4 with value: 0.27932196110486984.[0m


0.0698028827005007
0.0023061507359635144 0.07550077852138813 16


[32m[I 2021-03-20 11:04:38,096][0m Trial 24 finished with value: 0.3499660063534975 and parameters: {'dp': 0.0698028827005007, 'lr': 0.0023061507359635144, 'wd': 0.07550077852138813, 'bs': 16}. Best is trial 4 with value: 0.27932196110486984.[0m


0.16055942443586732
0.0002286398814574078 0.10928019458397432 32


[32m[I 2021-03-20 11:05:46,347][0m Trial 25 finished with value: 0.3889438845217228 and parameters: {'dp': 0.16055942443586732, 'lr': 0.0002286398814574078, 'wd': 0.10928019458397432, 'bs': 32}. Best is trial 4 with value: 0.27932196110486984.[0m


0.11565493196236894
0.000593927175748908 0.05160830262045148 64


[32m[I 2021-03-20 11:07:09,451][0m Trial 26 finished with value: 0.3183664530515671 and parameters: {'dp': 0.11565493196236894, 'lr': 0.000593927175748908, 'wd': 0.05160830262045148, 'bs': 64}. Best is trial 4 with value: 0.27932196110486984.[0m


0.05314747040012667
0.012581846753651395 0.13856288077316015 64


[32m[I 2021-03-20 11:08:04,410][0m Trial 27 finished with value: 0.4394279619057973 and parameters: {'dp': 0.05314747040012667, 'lr': 0.012581846753651395, 'wd': 0.13856288077316015, 'bs': 64}. Best is trial 4 with value: 0.27932196110486984.[0m


0.0916757201284986
0.0001038962091949819 0.08822813187454809 16


[32m[I 2021-03-20 11:09:36,383][0m Trial 28 finished with value: 0.30898371525108814 and parameters: {'dp': 0.0916757201284986, 'lr': 0.0001038962091949819, 'wd': 0.08822813187454809, 'bs': 16}. Best is trial 4 with value: 0.27932196110486984.[0m


0.02146427903295997
0.0018416920510633704 0.11751457708398019 16


[32m[I 2021-03-20 11:10:27,904][0m Trial 29 finished with value: 0.328363382567962 and parameters: {'dp': 0.02146427903295997, 'lr': 0.0018416920510633704, 'wd': 0.11751457708398019, 'bs': 16}. Best is trial 4 with value: 0.27932196110486984.[0m


0.05878091949376712
5.495822140663257e-05 0.10590096574482363 64


[32m[I 2021-03-20 11:11:43,571][0m Trial 30 finished with value: 0.3453686734040578 and parameters: {'dp': 0.05878091949376712, 'lr': 5.495822140663257e-05, 'wd': 0.10590096574482363, 'bs': 64}. Best is trial 4 with value: 0.27932196110486984.[0m


0.06351793640513609
0.0005831615420914658 0.17276232083595472 16


[32m[I 2021-03-20 11:12:47,234][0m Trial 31 finished with value: 0.31891682371497154 and parameters: {'dp': 0.06351793640513609, 'lr': 0.0005831615420914658, 'wd': 0.17276232083595472, 'bs': 16}. Best is trial 4 with value: 0.27932196110486984.[0m


0.07025320973373538
0.00088507324324128 0.05975621480375387 16


KeyboardInterrupt: 