In [4]:
import numpy as np
import os
import torch.optim as optim
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
from matplotlib import rcParams
from time import time
# Default canvas for every figure drawn in this notebook: (width, height) in inches.
plt.rcParams['figure.figsize'] = (15, 4)

from architectures.segnet_depth import SegNet
from custom_dataset.BraTSDataset import BraTSDataset_old2
from My_configs import My_configs
from train import train_best
from data_loaders import my_dataloader
from data_loaders import find_dice_score
from data_loaders import dice_score_patients
#from visualize import visualize
from losses import ComboLoss
from torch.optim.lr_scheduler import ReduceLROnPlateau
import albumentations as A
import albumentations.augmentations.functional as F
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from train1 import Training
from losses import dice_metric

In [7]:
# Experiment configuration: every tunable for this run is attached to `config`.
config = My_configs()

# Dataset location. The optional BRATS_DATASET_PATH environment variable overrides
# the machine-specific default below, so the notebook can run on another host
# without editing this cell. With the variable unset, behaviour is unchanged.
config.dataset_path = os.environ.get(
    'BRATS_DATASET_PATH',
    '/home/aubingazhibov/brain_tumor/brats_slices_final',
)
config.model_name = 'segnet.pth'

config.batch_size = 5
# Loss handed to the trainer: ComboLoss with equal weights for its 'bce' and 'dice' terms.
config.criterion = ComboLoss(weights={'bce': 1, 'dice': 1})
config.epochs = 10
config.img_size = 512

# Candidate triplets passed to Training(triplets=...); the training log reports the
# selected one as "best_threshold".
# NOTE(review): the meaning of the three fields is not visible in this notebook —
# confirm against train1.Training before changing them.
config.triplets = [
    [0.75, 1000, 0.3],
    [0.75, 1000, 0.4],
    [0.75, 2000, 0.3],
    [0.75, 2000, 0.4],
    [0.6, 2000, 0.3],
    [0.6, 2000, 0.4],
    [0.6, 3000, 0.3],
    [0.6, 3000, 0.4],
]

# Optimizer hyper-parameters (consumed by optim.Adam in the model-setup cell).
config.opt = {
    'lr': 0.0001,
    'weight_decay': 0.000005,
}

# ReduceLROnPlateau settings. mode='max' means the learning rate is reduced when the
# monitored value stops increasing.
# NOTE(review): presumably the monitored value is the validation dice — confirm in Training.
config.scd = {
    'mode': 'max',
    'factor': 0.1,
    'patience': 1,
    'threshold': 0.0000001,
    'min_lr': 0.0000001,
}


In [None]:
# Build the data pipeline from the configured dataset path and batch size.
# NOTE(review): the three returned objects look like a training loader, a validation
# loader, and a third validation object (later passed to run_train and
# find_dice_score) — confirm in data_loaders.my_dataloader.
loaders = my_dataloader(
    config.dataset_path,
    BraTSDataset_old2,
    batch_size=config.batch_size,
)
data_tr, data_val, val = loaders

In [8]:
# Move a fresh SegNet to the GPU and wrap it in DataParallel in one expression.
model = nn.DataParallel(SegNet().cuda())

# Adam over all model parameters; config.opt holds exactly `lr` and `weight_decay`.
opt = optim.Adam(model.parameters(), **config.opt)

# Plateau-based LR scheduler; config.scd holds exactly `mode`, `factor`, `patience`,
# `threshold` and `min_lr`, which are forwarded as keyword arguments.
scheduler = ReduceLROnPlateau(opt, **config.scd)

In [9]:
# Bundle loss, optimizer, scheduler, epoch budget, metric and candidate triplets
# into the trainer object that drives the run.
train_validate = Training(
    config.criterion,
    opt,
    scheduler,
    config.epochs,
    dice_metric,
    triplets=config.triplets,
)

In [10]:
# Run the training/validation epochs. The returned network rebinds `model`, so the
# cells below operate on whatever run_train hands back.
model = train_validate.run_train(
    model,
    data_tr,
    data_val,
    val,
)

epoch: 0, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

  return torch._C._nn.max_pool2d_with_indices(input, kernel_size, stride, padding, dilation, ceil_mode)


epoch: 0, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.339 and validation dice is 0.472
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 1, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 1, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.131 and validation dice is 0.525
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 2, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 2, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.109 and validation dice is 0.483
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 3, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 3, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.097 and validation dice is 0.532
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 4, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 4, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.088 and validation dice is 0.503
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 5, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 5, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.082 and validation dice is 0.533
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 6, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 6, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.077 and validation dice is 0.513
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 7, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 7, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.073 and validation dice is 0.532
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 8, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 8, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.065 and validation dice is 0.516
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 9, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 9, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is 0.061 and validation dice is 0.532
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
best dice score was achieved 0.533 in epoch number 5


In [12]:
# Persist the weights of `model` under a run-specific file name (relative to the
# notebook's working directory).
# NOTE(review): if run_train returned the DataParallel wrapper created earlier, the
# saved keys carry a `module.` prefix — confirm that whatever loads this file expects it.
# NOTE(review): confirm whether run_train returns the best-epoch or last-epoch weights.
config.model_name = 'depthwise_segnet_bce1dice1_output1.pth'
checkpoint_weights = model.state_dict()
torch.save(checkpoint_weights, config.model_name)

In [13]:
# Sweep the decision threshold and report the mean/std dice that find_dice_score
# returns on `val` for each setting.
thresholds = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
# hr - 80, lr - 40  (original author's note; its meaning is not evident from this notebook)
for threshold in thresholds:
    mean, std = find_dice_score(model, val, threshold)
    print(f'threshold = {threshold}, mean = {mean}, std = {std}')

  pred = np.exp(pred.to('cpu').detach().numpy()) > threshold


threshold = 0.1, mean = 0.6711878864707541, std = 0.29619316746881175
threshold = 0.2, mean = 0.6783360401459718, std = 0.297296843617315
threshold = 0.3, mean = 0.6819524428486657, std = 0.29794716086418543
threshold = 0.4, mean = 0.68437346577421, std = 0.29840014507453894
threshold = 0.5, mean = 0.6861864917210333, std = 0.2986262535244146
threshold = 0.6, mean = 0.6876371156430984, std = 0.29879043161441277
threshold = 0.7, mean = 0.6887740079255501, std = 0.29901863012959756
threshold = 0.8, mean = 0.6897248524062194, std = 0.2991671468718147
threshold = 0.9, mean = 0.6905786596894856, std = 0.29922573271124364
