In [1]:
import model, data

import torch
from torch.utils.data import DataLoader
from torch import optim, nn

import copy
import optuna

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 200
PATIENCE = 5

In [3]:
TRAIN = data.Lung_Dataset('train', verbose = 0)
VAL = data.Lung_Dataset('test', verbose = 0)

In [4]:
def train(model, device, loss_criterion, optimizer, train_loader, val_loader, epochs, patience):
    best_loss = float("inf")
    early_stop = 0
    best_weights = None
    
    for i in range(epochs):
        train_loss = train_epoch(model, device, loss_criterion, optimizer, train_loader)
        val_loss = validate(model, device, loss_criterion, val_loader)
        
        """
        Early Stopping 
        """
        if val_loss < best_loss:
            early_stop = 0
            best_loss = val_loss
            best_weights = copy.deepcopy(model.state_dict())
        else:
            early_stop += 1
                
        if early_stop == patience:
            model.load_state_dict(best_weights)
            break
    return best_loss

def train_epoch(model, device, loss_criterion, optimizer, train_loader):
    model.to(device)
    model.train()
    
    running_loss = 0
    
    counter = 0
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        
        output = model.forward(data)
        
        target = target.argmax(dim=1, keepdim=True).float()
        
        loss = loss_criterion(output, target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        counter += 1

    return (running_loss / counter)

def validate(model, device, loss_criterion, val_loader):
    model.to(device)
    model.eval()

    correct = 0
    val_loss = 0

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)

            target = target.argmax(dim=1, keepdim=True).float()
            
            output = model.forward(data)
            val_loss += loss_criterion(output, target).item()
            
            pred = torch.round(output)
            equal_data = torch.sum(target.data == pred).item()
            correct += equal_data
    
    return (val_loss / len(val_loader))

In [5]:
def define_model(trial):
    p = trial.suggest_float("dp", 0, 0.2)
    print(p)
    return model.CNN(dropout=p)

def objective(trial):
    cnn = define_model(trial).to(DEVICE)
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    wd = trial.suggest_float("wd", 0, 0.2)
    o = optim.AdamW(cnn.parameters(), lr=lr, weight_decay=wd)
    bs = trial.suggest_categorical("bs", [16, 32, 64])
    train_loader = DataLoader(TRAIN, batch_size=bs, shuffle=True)
    val_loader = DataLoader(VAL, batch_size=bs, shuffle=True)
    print(lr, wd, bs)
    return train(cnn, DEVICE, nn.BCELoss(), o, train_loader, val_loader, EPOCHS, PATIENCE)
    

In [6]:
study = optuna.create_study()
study.optimize(objective, n_trials=30, timeout=3600)
print("Best trial:")
trial = study.best_trial
print("  Value: ", trial.value)
print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

[32m[I 2021-03-20 11:15:21,025][0m A new study created in memory with name: no-name-d735bf5e-fc0b-4136-8f91-825b3f19d1ba[0m


0.02863673616575504
5.875856331832141e-05 0.05691017668957554 16


[32m[I 2021-03-20 11:16:29,451][0m Trial 0 finished with value: 0.4603615009631866 and parameters: {'dp': 0.02863673616575504, 'lr': 5.875856331832141e-05, 'wd': 0.05691017668957554, 'bs': 16}. Best is trial 0 with value: 0.4603615009631866.[0m


0.14718755545973602
2.9968317993352165e-05 0.1893069040078719 64


[32m[I 2021-03-20 11:17:29,874][0m Trial 1 finished with value: 0.9431668758392334 and parameters: {'dp': 0.14718755545973602, 'lr': 2.9968317993352165e-05, 'wd': 0.1893069040078719, 'bs': 64}. Best is trial 0 with value: 0.4603615009631866.[0m


0.0005182289117860783
0.0003192106921708776 0.044362220373424854 16


[32m[I 2021-03-20 11:18:32,986][0m Trial 2 finished with value: 0.40608567056747585 and parameters: {'dp': 0.0005182289117860783, 'lr': 0.0003192106921708776, 'wd': 0.044362220373424854, 'bs': 16}. Best is trial 2 with value: 0.40608567056747585.[0m


0.017122097315603723
8.157242180777415e-05 0.15416891306210762 32


[32m[I 2021-03-20 11:19:35,120][0m Trial 3 finished with value: 0.4256924793124199 and parameters: {'dp': 0.017122097315603723, 'lr': 8.157242180777415e-05, 'wd': 0.15416891306210762, 'bs': 32}. Best is trial 2 with value: 0.40608567056747585.[0m


0.024159180990209018
0.00011712667280260496 0.06134448921417926 32


[32m[I 2021-03-20 11:20:37,461][0m Trial 4 finished with value: 0.3945358067750931 and parameters: {'dp': 0.024159180990209018, 'lr': 0.00011712667280260496, 'wd': 0.06134448921417926, 'bs': 32}. Best is trial 4 with value: 0.3945358067750931.[0m


0.021972629776838937
0.0002191241977420577 0.13013517866242505 64


[32m[I 2021-03-20 11:21:38,443][0m Trial 5 finished with value: 0.5403364717960357 and parameters: {'dp': 0.021972629776838937, 'lr': 0.0002191241977420577, 'wd': 0.13013517866242505, 'bs': 64}. Best is trial 4 with value: 0.3945358067750931.[0m


0.06956077121664701
0.0005640688389531045 0.15507254649624194 32


[32m[I 2021-03-20 11:22:51,772][0m Trial 6 finished with value: 0.5581732735037803 and parameters: {'dp': 0.06956077121664701, 'lr': 0.0005640688389531045, 'wd': 0.15507254649624194, 'bs': 32}. Best is trial 4 with value: 0.3945358067750931.[0m


0.02927520585746244
0.07673384538359111 0.04213293133867402 16


[32m[I 2021-03-20 11:24:06,355][0m Trial 7 finished with value: 0.4698717040129197 and parameters: {'dp': 0.02927520585746244, 'lr': 0.07673384538359111, 'wd': 0.04213293133867402, 'bs': 16}. Best is trial 4 with value: 0.3945358067750931.[0m


0.19701600499475488
0.0022194757533465477 0.001836296581867658 16


[32m[I 2021-03-20 11:25:42,528][0m Trial 8 finished with value: 0.7057867566935527 and parameters: {'dp': 0.19701600499475488, 'lr': 0.0022194757533465477, 'wd': 0.001836296581867658, 'bs': 16}. Best is trial 4 with value: 0.3945358067750931.[0m


0.028975499497177617
1.3573440931125469e-05 0.1365910471421945 32


[32m[I 2021-03-20 11:27:04,257][0m Trial 9 finished with value: 0.4070018738508224 and parameters: {'dp': 0.028975499497177617, 'lr': 1.3573440931125469e-05, 'wd': 0.1365910471421945, 'bs': 32}. Best is trial 4 with value: 0.3945358067750931.[0m


0.09694049107825271
0.0028034807238918823 0.08580511217391977 32


[32m[I 2021-03-20 11:28:15,288][0m Trial 10 finished with value: 0.45641486570239065 and parameters: {'dp': 0.09694049107825271, 'lr': 0.0028034807238918823, 'wd': 0.08580511217391977, 'bs': 32}. Best is trial 4 with value: 0.3945358067750931.[0m


0.06416951376806293
0.012682796608526496 0.011269753749619117 16


[32m[I 2021-03-20 11:30:40,466][0m Trial 11 finished with value: 0.37823703770454115 and parameters: {'dp': 0.06416951376806293, 'lr': 0.012682796608526496, 'wd': 0.011269753749619117, 'bs': 16}. Best is trial 11 with value: 0.37823703770454115.[0m


0.06948542695803014
0.024519809479165464 0.007287191757360016 32


[32m[I 2021-03-20 11:32:41,645][0m Trial 12 finished with value: 0.761397248506546 and parameters: {'dp': 0.06948542695803014, 'lr': 0.024519809479165464, 'wd': 0.007287191757360016, 'bs': 32}. Best is trial 11 with value: 0.37823703770454115.[0m


0.0631493558075571
0.010242354993970394 0.07826587202947975 16


[32m[I 2021-03-20 11:33:43,525][0m Trial 13 finished with value: 0.45941986028964704 and parameters: {'dp': 0.0631493558075571, 'lr': 0.010242354993970394, 'wd': 0.07826587202947975, 'bs': 16}. Best is trial 11 with value: 0.37823703770454115.[0m


0.1258685286441622
0.015482051772781362 0.018373534139734887 16


[32m[I 2021-03-20 11:35:06,005][0m Trial 14 finished with value: 0.3751146300480916 and parameters: {'dp': 0.1258685286441622, 'lr': 0.015482051772781362, 'wd': 0.018373534139734887, 'bs': 16}. Best is trial 14 with value: 0.3751146300480916.[0m


0.13613474872507092
0.07972257851487732 0.02011317972585161 16


[32m[I 2021-03-20 11:37:00,561][0m Trial 15 finished with value: 0.9141607674268576 and parameters: {'dp': 0.13613474872507092, 'lr': 0.07972257851487732, 'wd': 0.02011317972585161, 'bs': 16}. Best is trial 14 with value: 0.3751146300480916.[0m


0.13583953012216418
0.013085963801613415 0.022005069484355766 16


[32m[I 2021-03-20 11:39:34,531][0m Trial 16 finished with value: 0.4212121016417558 and parameters: {'dp': 0.13583953012216418, 'lr': 0.013085963801613415, 'wd': 0.022005069484355766, 'bs': 16}. Best is trial 14 with value: 0.3751146300480916.[0m


0.16624228569027288
0.003965009391404885 0.002498360543027465 16


[32m[I 2021-03-20 11:40:46,415][0m Trial 17 finished with value: 0.5434792391382731 and parameters: {'dp': 0.16624228569027288, 'lr': 0.003965009391404885, 'wd': 0.002498360543027465, 'bs': 16}. Best is trial 14 with value: 0.3751146300480916.[0m


0.10291013515309992
0.037460296988378466 0.02858741041079732 64


[32m[I 2021-03-20 11:42:34,970][0m Trial 18 finished with value: 0.7660128951072693 and parameters: {'dp': 0.10291013515309992, 'lr': 0.037460296988378466, 'wd': 0.02858741041079732, 'bs': 64}. Best is trial 14 with value: 0.3751146300480916.[0m


0.1034559963045821
0.0012801550958711743 0.10140412806929389 16


[32m[I 2021-03-20 11:44:08,455][0m Trial 19 finished with value: 0.4342569473844308 and parameters: {'dp': 0.1034559963045821, 'lr': 0.0012801550958711743, 'wd': 0.10140412806929389, 'bs': 16}. Best is trial 14 with value: 0.3751146300480916.[0m


0.08846383030255146
0.007140839395617085 0.10709204930366126 16


[32m[I 2021-03-20 11:45:10,595][0m Trial 20 finished with value: 1.2529820960301619 and parameters: {'dp': 0.08846383030255146, 'lr': 0.007140839395617085, 'wd': 0.10709204930366126, 'bs': 16}. Best is trial 14 with value: 0.3751146300480916.[0m


0.054609489475031824
0.028761704426845128 0.060784839069045915 32


[32m[I 2021-03-20 11:46:21,322][0m Trial 21 finished with value: 0.4032908737659454 and parameters: {'dp': 0.054609489475031824, 'lr': 0.028761704426845128, 'wd': 0.060784839069045915, 'bs': 32}. Best is trial 14 with value: 0.3751146300480916.[0m


0.11977527345677501
0.006166420323605612 0.06871934025181445 32


[32m[I 2021-03-20 11:47:42,174][0m Trial 22 finished with value: 0.5211186058819294 and parameters: {'dp': 0.11977527345677501, 'lr': 0.006166420323605612, 'wd': 0.06871934025181445, 'bs': 32}. Best is trial 14 with value: 0.3751146300480916.[0m


0.04567777732046928
0.0001472934501278466 0.03776440429253959 16


[32m[I 2021-03-20 11:48:44,148][0m Trial 23 finished with value: 0.5904580094875433 and parameters: {'dp': 0.04567777732046928, 'lr': 0.0001472934501278466, 'wd': 0.03776440429253959, 'bs': 16}. Best is trial 14 with value: 0.3751146300480916.[0m


0.08208767560874652
0.0009742508193369938 0.008979260656670263 32


[32m[I 2021-03-20 11:49:44,694][0m Trial 24 finished with value: 0.5195557847619057 and parameters: {'dp': 0.08208767560874652, 'lr': 0.0009742508193369938, 'wd': 0.008979260656670263, 'bs': 32}. Best is trial 14 with value: 0.3751146300480916.[0m


0.002914289658770372
0.01463826681543892 0.04537506438550996 64


[32m[I 2021-03-20 11:51:12,627][0m Trial 25 finished with value: 0.5335478723049164 and parameters: {'dp': 0.002914289658770372, 'lr': 0.01463826681543892, 'wd': 0.04537506438550996, 'bs': 64}. Best is trial 14 with value: 0.3751146300480916.[0m


0.04621749488485537
1.637138436554386e-05 0.02145817807776044 16


[32m[I 2021-03-20 11:52:14,988][0m Trial 26 finished with value: 0.4664513980730986 and parameters: {'dp': 0.04621749488485537, 'lr': 1.637138436554386e-05, 'wd': 0.02145817807776044, 'bs': 16}. Best is trial 14 with value: 0.3751146300480916.[0m


0.1751367957978145
0.08606776284599904 0.08251385336766463 16


[32m[I 2021-03-20 11:53:37,187][0m Trial 27 finished with value: 0.8378248933034066 and parameters: {'dp': 0.1751367957978145, 'lr': 0.08606776284599904, 'wd': 0.08251385336766463, 'bs': 16}. Best is trial 14 with value: 0.3751146300480916.[0m


0.11706793975033641
0.0005290277608111185 0.05999985766448707 32


[32m[I 2021-03-20 11:54:57,359][0m Trial 28 finished with value: 0.8473708957433701 and parameters: {'dp': 0.11706793975033641, 'lr': 0.0005290277608111185, 'wd': 0.05999985766448707, 'bs': 32}. Best is trial 14 with value: 0.3751146300480916.[0m


0.08291644051905296
3.493761837360377e-05 0.05248346245796763 16


[32m[I 2021-03-20 11:55:59,724][0m Trial 29 finished with value: 0.5296053443199549 and parameters: {'dp': 0.08291644051905296, 'lr': 3.493761837360377e-05, 'wd': 0.05248346245796763, 'bs': 16}. Best is trial 14 with value: 0.3751146300480916.[0m


Best trial:
  Value:  0.3751146300480916
  Params: 
    dp: 0.1258685286441622
    lr: 0.015482051772781362
    wd: 0.018373534139734887
    bs: 16
