In [6]:
import numpy as np
import os
import torch.optim as optim
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
from matplotlib import rcParams
from time import time
rcParams['figure.figsize'] = (20,14)

from architectures.ternausnets import AlbuNet, unet11 
from custom_dataset.BraTSDataset import BraTSDataset_old2
from My_configs import My_configs
from train import train_best
from data_loaders import my_dataloader
from data_loaders import find_dice_score
from data_loaders import dice_score_patients
#from visualize import visualize

from torch.optim.lr_scheduler import ReduceLROnPlateau
from losses import ComboLoss
import albumentations as A
import albumentations.augmentations.functional as F
from albumentations.pytorch import ToTensorV2

import warnings
warnings.filterwarnings("ignore")

from torchvision import transforms
from losses import dice_metric

from tqdm import tqdm
from tqdm import trange
from train1 import Training

In [2]:
config = My_configs()
config.dataset_path = '/home/aubingazhibov/brain_tumor/brats_slices_final'
config.model_name = './ternausent_july25.pth'

config.batch_size = 5
config.criterion = ComboLoss(weights = {'bce': 1,'dice': 1, 'focal': 1})
config.epochs = 15
config.img_size = 512
config.triplets = [[0.75, 1000, 0.3], [0.75, 1000, 0.4], [0.75, 2000, 0.3], [0.75, 2000, 0.4], [0.6, 2000, 0.3], [0.6, 2000, 0.4], [0.6, 3000, 0.3], [0.6, 3000, 0.4]]
config.opt = {
    'lr': 0.0001,
    'weight_decay': 0.000005
}
config.scd = {
    'mode': 'max',
    'factor': 0.1,
     'patience': 2,
    'threshold':0.0000001,
    'min_lr':0.0000001
}#scheduler


In [3]:
train_transform1 =  A.Compose([
    A.HorizontalFlip(),
    A.OneOf([
        A.RandomContrast(),
        A.RandomGamma(),
        A.RandomBrightness(),
        ], p=0.3),
    A.ShiftScaleRotate(),
    A.Resize(config.img_size,config.img_size,always_apply=True),
   # ToTensorV2(),
])
val_transform1 = A.Compose(
    [ # A.Normalize(mean=(0.485, 0.456, 0.406, 0.400), std=(0.229, 0.224, 0.225, 0.220)),
        A.Resize(config.img_size,config.img_size,always_apply=True),
    #    ToTensorV2(),
    ]
)
data_tr, data_val,val = my_dataloader(config.dataset_path, BraTSDataset_old2,batch_size = config.batch_size, transform1 = train_transform1,val_transform1=val_transform1, resize = config.img_size)
#data_tr, data_val,val = my_dataloader(config.dataset_path, BraTSDataset, train_size = config.train_size, batch_size = config.batch_size, transform1 = None,val_transform1=None)

In [11]:
model = AlbuNet(pretrained=True)
model = model.cuda()
model = torch.nn.DataParallel(model)

opt = optim.Adam(model.parameters(), lr = config.opt['lr'], weight_decay = config.opt['weight_decay'])

scheduler = ReduceLROnPlateau(opt, mode = config.scd['mode'], factor = config.scd['factor'], patience = config.scd['patience'], threshold = config.scd['threshold'], min_lr = config.scd['min_lr'])

In [5]:
train_validate = Training(config.criterion, opt, scheduler, config.epochs, dice_metric,triplets=config.triplets, is_medic = False)

In [None]:
model = train_validate.run_train(model, data_tr, data_val, val)

epoch: 0, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 0, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.192 and validation dice is 0.689

 ----------------------------------------------------------------
epoch: 1, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 1, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.130 and validation dice is 0.649

 ----------------------------------------------------------------
epoch: 2, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 2, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.111 and validation dice is 0.662

 ----------------------------------------------------------------
epoch: 3, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 3, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.102 and validation dice is 0.693

 ----------------------------------------------------------------
epoch: 4, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 4, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.092 and validation dice is 0.706

 ----------------------------------------------------------------
epoch: 5, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 5, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.085 and validation dice is 0.688

 ----------------------------------------------------------------
epoch: 6, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 6, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.080 and validation dice is 0.678

 ----------------------------------------------------------------
epoch: 7, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 7, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.076 and validation dice is 0.678

 ----------------------------------------------------------------
epoch: 8, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 8, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.066 and validation dice is 0.667

 ----------------------------------------------------------------
epoch: 9, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 9, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.063 and validation dice is 0.678

 ----------------------------------------------------------------
epoch: 10, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 10, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.061 and validation dice is 0.685

 ----------------------------------------------------------------
epoch: 11, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 11, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.060 and validation dice is 0.696

 ----------------------------------------------------------------
epoch: 12, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 12, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.060 and validation dice is 0.668

 ----------------------------------------------------------------
epoch: 13, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 13, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.059 and validation dice is 0.667

 ----------------------------------------------------------------
epoch: 14, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

In [None]:
config.model_name = './new_generation_AlbuNet_july25_combo111_oldbrats_stage1.pth'
torch.save(model.state_dict(), config.model_name)
#model.load_state_dict(torch.load(config.model_name))
#model.eval()

# Stage 2 

In [6]:
from torch.optim.lr_scheduler import CosineAnnealingLR

In [13]:
opt = optim.Adam(model.parameters(), lr = config.opt['lr'], weight_decay = config.opt['weight_decay'])
scheduler = CosineAnnealingLR(opt, T_max = 8, eta_min=0.0000001 )
config.epochs = 12

config.opt = {
    'lr': 0.00001,
    'weight_decay': 0.000005
}
config.batch_size = 2


In [14]:
train_validate = Training(config.criterion, opt, scheduler, config.epochs, dice_metric,triplets=config.triplets, is_medic = False)

In [15]:
model = train_validate.run_train(model, data_tr, data_val, val)

epoch: 0, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 0, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.098 and validation dice is 0.705
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 1, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 1, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.088 and validation dice is 0.717
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 2, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 2, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.082 and validation dice is 0.699
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 3, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 3, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.077 and validation dice is 0.715
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 4, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 4, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.073 and validation dice is 0.728
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 5, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 5, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.071 and validation dice is 0.735
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 6, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 6, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.068 and validation dice is 0.714
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 7, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 7, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.066 and validation dice is 0.723
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 8, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 8, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.065 and validation dice is 0.690
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 9, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 9, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.064 and validation dice is 0.730
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 10, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 10, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.062 and validation dice is 0.742
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 11, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 11, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.061 and validation dice is 0.698
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
best dice score was achieved 0.742 in epoch number 10


In [16]:
config.model_name = './new_generation_AlbuNet_combo111_oldbrats_stage2.pth'
torch.save(model.state_dict(), config.model_name)

In [17]:
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
#hr - 80, lr - 40
for i in range(len(thresholds)):
    
    mean, std = find_dice_score(model, val, thresholds[i])
    print(f'threshold = {thresholds[i]}, mean = {mean}, std = {std}')
    #threshold += 0.1

threshold = 0.1, mean = 0.7897882866319307, std = 0.2721513147117413
threshold = 0.2, mean = 0.7881597962104523, std = 0.27334071475133137
threshold = 0.3, mean = 0.7872788922956843, std = 0.27387003815580513
threshold = 0.4, mean = 0.786644179711099, std = 0.2741829363646536
threshold = 0.5, mean = 0.7861720350521597, std = 0.2743425093764282
threshold = 0.6, mean = 0.7857832737722713, std = 0.274459921558303
threshold = 0.7, mean = 0.7854315863201072, std = 0.2745918266463566
threshold = 0.8, mean = 0.785152970444831, std = 0.27465512755973454
threshold = 0.9, mean = 0.7849072001588442, std = 0.27471211017331115


# Stage 3

In [18]:
opt = optim.Adam(model.parameters(), lr = config.opt['lr'], weight_decay = config.opt['weight_decay'])
scheduler = CosineAnnealingLR(opt, T_max = 8, eta_min=0.0000001 )
config.epochs = 10

config.opt = {
    'lr': 0.00001,
    'weight_decay': 0.000005
}
config.batch_size = 2
config.criterion = ComboLoss(weights = {'bce': 1,'dice': 1})

In [19]:
train_validate = Training(config.criterion, opt, scheduler, config.epochs, dice_metric,triplets=config.triplets, is_medic = False)

In [20]:
model = train_validate.run_train(model, data_tr, data_val, val)

epoch: 0, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 0, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.054 and validation dice is 0.739
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 1, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 1, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.052 and validation dice is 0.723
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 2, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 2, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.051 and validation dice is 0.736
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 3, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 3, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.050 and validation dice is 0.740
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 4, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 4, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.050 and validation dice is 0.716
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 5, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 5, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.049 and validation dice is 0.715
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 6, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 6, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.049 and validation dice is 0.740
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 7, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 7, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.049 and validation dice is 0.738
best_threshold is (0.75, 1000, 0.3)

 ----------------------------------------------------------------
epoch: 8, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 8, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.048 and validation dice is 0.734
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 9, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 9, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.048 and validation dice is 0.721
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
best dice score was achieved 0.740 in epoch number 6


In [21]:
config.model_name = './new_generation_AlbuNet_combo111_oldbrats_stage3.pth'
torch.save(model.state_dict(), config.model_name)

In [22]:
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
#hr - 80, lr - 40
for i in range(len(thresholds)):
    
    mean, std = find_dice_score(model, val, thresholds[i])
    print(f'threshold = {thresholds[i]}, mean = {mean}, std = {std}')
    #threshold += 0.1

threshold = 0.1, mean = 0.7873946149943822, std = 0.27433860596463844
threshold = 0.2, mean = 0.7861451316754571, std = 0.27488714878363707
threshold = 0.3, mean = 0.785540608501538, std = 0.2751106301261827
threshold = 0.4, mean = 0.7851192139307206, std = 0.27523818153416263
threshold = 0.5, mean = 0.7848104021589042, std = 0.2752558214666667
threshold = 0.6, mean = 0.7845433888146955, std = 0.27528855727458795
threshold = 0.7, mean = 0.7843074028909659, std = 0.275319256702124
threshold = 0.8, mean = 0.7841206240524379, std = 0.27531063094727526
threshold = 0.9, mean = 0.7839479267437647, std = 0.27532252546435926
