In [1]:
import numpy as np
import os
import torch.optim as optim
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
from matplotlib import rcParams
from time import time
rcParams['figure.figsize'] = (20,14)

from architectures.segnet_model import SegNet
from custom_dataset.BraTSDataset import BraTSDataset_old2
from My_configs import My_configs
from train import train_best
from data_loaders import my_dataloader
from data_loaders import find_dice_score
from data_loaders import dice_score_patients
#from visualize import visualize

import albumentations as A
import albumentations.augmentations.functional as F
from albumentations.pytorch import ToTensorV2
from IPython.display import clear_output

from train1 import Training
from torch.optim.lr_scheduler import ReduceLROnPlateau
from losses import ComboLoss, dice_metric


In [2]:
# Experiment configuration: dataset location, split / batch sizes, loss, epoch budget,
# and the keyword arguments later passed to Adam (config.opt) and ReduceLROnPlateau (config.scd).
import random

# Seed every RNG the pipeline may draw from (Python `random`, NumPy, torch incl. CUDA)
# so augmentations, weight init and — presumably — the train/val split in my_dataloader
# are repeatable across Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

config = My_configs()
# Machine-specific location; override with the BRATS_DATASET_PATH environment variable.
config.dataset_path = os.environ.get(
    'BRATS_DATASET_PATH', '/home/aubingazhibov/brain_tumor/brats_slices_final'
)
config.model_name = './segnet_new.pth'

config.train_size = 0.8   # fraction passed to my_dataloader as train_size
config.batch_size = 6
config.criterion = ComboLoss(weights={'bce': 1, 'dice': 1})  # weights of the BCE and dice terms
config.epochs = 15

# Adam keyword arguments.
config.opt = {
    'lr': 0.0001,
    'weight_decay': 0.000005,
}
# ReduceLROnPlateau keyword arguments. mode='max' means the scheduler expects a
# higher-is-better metric (presumably validation dice — TODO confirm in Training).
config.scd = {
    'mode': 'max',
    'factor': 0.1,
    'patience': 2,
    'threshold': 0.0000001,
    'min_lr': 0.0000001,
}


In [3]:
# Training augmentation, assembled stage by stage (construction order matches the pipeline order):
#   1. random horizontal flip
#   2. with p=0.3, one of contrast / gamma / brightness jitter
#   3. with p=0.3, one of elastic / grid / optical distortion
#   4. random shift-scale-rotate
# NOTE(review): RandomContrast, RandomBrightness and ElasticTransform(alpha_affine=...) are
# deprecated in newer albumentations releases — pin the version or migrate when upgrading.
aug_flip = A.HorizontalFlip()
aug_intensity = A.OneOf(
    [A.RandomContrast(), A.RandomGamma(), A.RandomBrightness()],
    p=0.3,
)
aug_geometry = A.OneOf(
    [
        A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
        A.GridDistortion(),
        A.OpticalDistortion(distort_limit=2, shift_limit=0.5),
    ],
    p=0.3,
)
aug_shift_scale_rotate = A.ShiftScaleRotate()

train_transform1 = A.Compose(
    [aug_flip, aug_intensity, aug_geometry, aug_shift_scale_rotate]
)

# Validation pipeline is intentionally empty: samples pass through unchanged.
val_transform1 = A.Compose([])

# my_dataloader returns three objects; `val` is the one later handed to run_train and
# find_dice_score — its exact role is defined in data_loaders (not visible here).
data_tr, data_val, val = my_dataloader(
    config.dataset_path,
    BraTSDataset_old2,
    train_size=config.train_size,
    batch_size=config.batch_size,
    transform1=train_transform1,
    val_transform1=val_transform1,
)



In [4]:
# Model: SegNet moved to the GPU, then wrapped in DataParallel to split batches across devices.
segnet_model = SegNet().cuda()
model = torch.nn.DataParallel(segnet_model)

# Optimiser: Adam with the learning rate / weight decay from config.opt.
opt = optim.Adam(
    model.parameters(),
    lr=config.opt['lr'],
    weight_decay=config.opt['weight_decay'],
)

# LR schedule: shrink the learning rate when the monitored metric plateaus (settings from config.scd).
scheduler = ReduceLROnPlateau(
    opt,
    mode=config.scd['mode'],
    factor=config.scd['factor'],
    patience=config.scd['patience'],
    threshold=config.scd['threshold'],
    min_lr=config.scd['min_lr'],
)

In [5]:
# Training driver built from the loss, optimiser and scheduler configured above;
# dice_metric is the scoring function (the log reports a "validation dice" per epoch).
train_validate = Training(
    config.criterion,
    opt,
    scheduler,
    config.epochs,
    dice_metric,
    is_medic=False,  # NOTE(review): flag meaning is defined in train1.Training — not visible here
)

In [6]:
# Train for config.epochs epochs. The returned model replaces `model`; judging by the
# threshold sweep below (0.5 reproduces the best-epoch dice), it appears to hold the
# best-epoch weights rather than the last ones — TODO confirm in Training.run_train.
# Re-running this cell continues from the already-trained weights.
model = train_validate.run_train(
    model,
    data_tr,
    data_val,
    val,
)

epoch: 0, training....


  0%|          | 0/3301 [00:00<?, ?it/s]

  return torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)


epoch: 0, validating....
train loss is 1.441 and validation dice is 0.047

 ----------------------------------------------------------------
epoch: 1, training....


  0%|          | 0/3301 [00:00<?, ?it/s]

epoch: 1, validating....
train loss is 1.277 and validation dice is 0.406

 ----------------------------------------------------------------
epoch: 2, training....


  0%|          | 0/3301 [00:00<?, ?it/s]

epoch: 2, validating....
train loss is 1.142 and validation dice is 0.549

 ----------------------------------------------------------------
epoch: 3, training....


  0%|          | 0/3301 [00:00<?, ?it/s]

epoch: 3, validating....
train loss is 1.022 and validation dice is 0.657

 ----------------------------------------------------------------
epoch: 4, training....


  0%|          | 0/3301 [00:00<?, ?it/s]

epoch: 4, validating....
train loss is 0.911 and validation dice is 0.482

 ----------------------------------------------------------------
epoch: 5, training....


  0%|          | 0/3301 [00:00<?, ?it/s]

epoch: 5, validating....
train loss is 0.804 and validation dice is 0.591

 ----------------------------------------------------------------
epoch: 6, training....


  0%|          | 0/3301 [00:00<?, ?it/s]

epoch: 6, validating....
train loss is 0.702 and validation dice is 0.574

 ----------------------------------------------------------------
epoch: 7, training....


  0%|          | 0/3301 [00:00<?, ?it/s]

epoch: 7, validating....
train loss is 0.639 and validation dice is 0.538

 ----------------------------------------------------------------
epoch: 8, training....


  0%|          | 0/3301 [00:00<?, ?it/s]

epoch: 8, validating....
train loss is 0.628 and validation dice is 0.525

 ----------------------------------------------------------------
epoch: 9, training....


  0%|          | 0/3301 [00:00<?, ?it/s]

epoch: 9, validating....
train loss is 0.616 and validation dice is 0.499

 ----------------------------------------------------------------
epoch: 10, training....


  0%|          | 0/3301 [00:00<?, ?it/s]

epoch: 10, validating....
train loss is 0.610 and validation dice is 0.537

 ----------------------------------------------------------------
epoch: 11, training....


  0%|          | 0/3301 [00:00<?, ?it/s]

epoch: 11, validating....
train loss is 0.608 and validation dice is 0.538

 ----------------------------------------------------------------
epoch: 12, training....


  0%|          | 0/3301 [00:00<?, ?it/s]

epoch: 12, validating....
train loss is 0.607 and validation dice is 0.583

 ----------------------------------------------------------------
epoch: 13, training....


  0%|          | 0/3301 [00:00<?, ?it/s]

epoch: 13, validating....
train loss is 0.606 and validation dice is 0.567

 ----------------------------------------------------------------
epoch: 14, training....


  0%|          | 0/3301 [00:00<?, ?it/s]

epoch: 14, validating....
train loss is 0.606 and validation dice is 0.579

 ----------------------------------------------------------------
best dice score was achieved 0.657 in epoch number 3


In [7]:
# Persist the trained weights.
# NOTE(review): if `model` is still the DataParallel wrapper, every state_dict key carries a
# "module." prefix — load it into a DataParallel-wrapped SegNet (or strip the prefix).
config.model_name = './new_models/segnet_oldcustom_bce1_dice1_july24.pth'
# torch.save does not create parent directories, so make sure ./new_models exists first
# (`or '.'` covers a bare filename, whose dirname is the empty string).
os.makedirs(os.path.dirname(config.model_name) or '.', exist_ok=True)
torch.save(model.state_dict(), config.model_name)

In [8]:
# Sweep the threshold passed to find_dice_score (presumably the cut-off used to binarise
# predictions) and report mean / std dice on `val` for each value.
# NOTE(review): `val` was also passed to run_train; if it drove best-epoch selection, the
# threshold picked here is tuned on the same data — confirm on a held-out set.
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

score_rows = []
for threshold in thresholds:
    mean, std = find_dice_score(model, val, threshold)
    print(f'threshold = {threshold}, mean = {mean}, std = {std}')
    score_rows.append({'threshold': threshold, 'mean_dice': mean, 'std_dice': std})

threshold_scores = pd.DataFrame(score_rows)
threshold_scores

threshold = 0.1, mean = 0.04692013080447829, std = 0.036951138055910146
threshold = 0.2, mean = 0.33696691762550113, std = 0.20312894525694036
threshold = 0.3, mean = 0.6186078579122186, std = 0.27776449056662567
threshold = 0.4, mean = 0.6439476236944268, std = 0.2888193901971688
threshold = 0.5, mean = 0.6571208599928245, std = 0.29489760931090636
threshold = 0.6, mean = 0.6654312791369531, std = 0.29941033904939857
threshold = 0.7, mean = 0.6714883016553373, std = 0.3024714118272068
threshold = 0.8, mean = 0.6760000735990711, std = 0.3049649415693655
threshold = 0.9, mean = 0.6796409010043564, std = 0.3070082172603632
