In [None]:
import numpy as np
import os
import torch.optim as optim
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt
from matplotlib import rcParams
from time import time
rcParams['figure.figsize'] = (15,4)

from architectures.unet_model import UNet
from custom_dataset.BraTSDataset import BraTSDataset_old2
from My_configs import My_configs
from train import train_best
from data_loaders import my_dataloader
from data_loaders import find_dice_score
from data_loaders import dice_score_patients
#from visualize import visualize
from losses import ComboLoss
from torch.optim.lr_scheduler import ReduceLROnPlateau
import albumentations as A
import albumentations.augmentations.functional as F
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from train1 import Training
from losses import dice_metric

In [2]:
# Experiment configuration: every knob a reader might tune for this run lives in this cell.
import random

config = My_configs()

# Dataset location. Override with the BRATS_DATASET_PATH environment variable instead of
# editing this machine-specific absolute path.
config.dataset_path = os.environ.get(
    'BRATS_DATASET_PATH', '/home/aubingazhibov/brain_tumor/brats_slices_final'
)
# NOTE(review): this name is replaced by the final save cell before the checkpoint is written.
config.model_name = 'unet_new.pth'

# Seed every global RNG before the dataloaders and the model are created, so that
# Restart & Run All is repeatable. This covers weight init and, presumably, shuffling and
# augmentation that draw from these generators.
# cuDNN kernels may still be nondeterministic unless torch.backends.cudnn is configured.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

config.batch_size = 5
# ComboLoss with equal weights on its 'bce', 'dice' and 'focal' terms.
config.criterion = ComboLoss(weights={'bce': 1, 'dice': 1, 'focal': 1})
config.epochs = 12
config.img_size = 512
# Candidate triplets searched during validation; the winner is logged as `best_threshold`.
# TODO document the meaning of each element (defined in train1.Training).
config.triplets = [
    [0.75, 1000, 0.3], [0.75, 1000, 0.4], [0.75, 2000, 0.3], [0.75, 2000, 0.4],
    [0.6, 2000, 0.3], [0.6, 2000, 0.4], [0.6, 3000, 0.3], [0.6, 3000, 0.4],
]
# Adam hyper-parameters, consumed when the optimizer is built.
config.opt = {
    'lr': 1e-4,
    'weight_decay': 5e-6,
}
# ReduceLROnPlateau settings. mode='max' presumably because the monitored value is the
# validation dice (higher is better). Confirm in train1.Training.
config.scd = {
    'mode': 'max',
    'factor': 0.1,
    'patience': 2,
    'threshold': 1e-7,
    'min_lr': 1e-7,
}


In [3]:
# Augmentation pipelines. No Normalize step is applied, so pixel values reach ToTensorV2
# (and the model) in whatever range the dataset yields.
train_transform1 = A.Compose(
    transforms=[
        A.Flip(p=0.45),
        # NOTE(review): a scale below 1 zooms out, so every affine-augmented sample is
        # shrunk to 30-70% of its size. Confirm this range is intended.
        A.Affine(scale=(0.3, 0.7), rotate=[-90, 90], p=0.45),
        A.ElasticTransform(p=0.5),
        A.RandomGamma(gamma_limit=(66, 98), p=0.45),
        ToTensorV2(),
    ],
)

val_transform1 = A.Compose(transforms=[ToTensorV2()])

# NOTE(review): neither transform above is passed to my_dataloader, so this call does not
# use them. TODO wire them in once my_dataloader's signature is confirmed.
data_tr, data_val, val = my_dataloader(
    config.dataset_path,
    BraTSDataset_old2,
    batch_size=config.batch_size,
)


In [4]:
# Build the 4-output UNet on the GPU and wrap it in DataParallel.
u_model = UNet(out=4).cuda()
model = torch.nn.DataParallel(u_model)

# Optimizer and LR scheduler, both parameterised from the config cell.
opt = optim.Adam(
    model.parameters(),
    lr=config.opt['lr'],
    weight_decay=config.opt['weight_decay'],
)

scheduler = ReduceLROnPlateau(
    opt,
    mode=config.scd['mode'],
    factor=config.scd['factor'],
    patience=config.scd['patience'],
    threshold=config.scd['threshold'],
    min_lr=config.scd['min_lr'],
)

In [5]:
# Training/validation driver. is_4out=True presumably matches the UNet(out=4) head above.
train_validate = Training(
    config.criterion,
    opt,
    scheduler,
    config.epochs,
    dice_metric,
    triplets=config.triplets,
    is_4out=True,
)

In [6]:
# Run training and validation for config.epochs epochs, then keep the model run_train returns.
# NOTE(review): the logged train loss below is negative and grows without bound. If
# ComboLoss's 'bce'/'dice'/'focal' terms are the usual non-negative losses, this points to
# targets outside [0, 1] or a logits/probabilities mix-up. Check the mask value range.
model = train_validate.run_train(
    model,
    data_tr,
    data_val,
    val,
)

epoch: 0, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


epoch: 0, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is -523.638 and validation dice is 0.040
best_threshold is (0.75, 1000, 0.3)

 ----------------------------------------------------------------
epoch: 1, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 1, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is -364.818 and validation dice is 0.143
best_threshold is (0.6, 2000, 0.4)

 ----------------------------------------------------------------
epoch: 2, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 2, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is -563.717 and validation dice is 0.046
best_threshold is (0.75, 1000, 0.3)

 ----------------------------------------------------------------
epoch: 3, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 3, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is -728.032 and validation dice is 0.050
best_threshold is (0.75, 1000, 0.3)

 ----------------------------------------------------------------
epoch: 4, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 4, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is -1201.264 and validation dice is 0.069
best_threshold is (0.6, 2000, 0.4)

 ----------------------------------------------------------------
epoch: 5, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 5, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is -1690.200 and validation dice is 0.086
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 6, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 6, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is -3587.549 and validation dice is 0.028
best_threshold is (0.6, 2000, 0.4)

 ----------------------------------------------------------------
epoch: 7, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 7, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is -6858.105 and validation dice is 0.000
best_threshold is (0.75, 1000, 0.3)

 ----------------------------------------------------------------
epoch: 8, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 8, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is -8632.379 and validation dice is 0.000
best_threshold is (0.75, 1000, 0.3)

 ----------------------------------------------------------------
epoch: 9, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 9, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is -36760.362 and validation dice is 0.045
best_threshold is (0.75, 1000, 0.3)

 ----------------------------------------------------------------
epoch: 10, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 10, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is -74646.204 and validation dice is 0.022
best_threshold is (0.75, 1000, 0.4)

 ----------------------------------------------------------------
epoch: 11, training....


  0%|          | 0/3919 [00:00<?, ?it/s]

epoch: 11, validating....


  0%|          | 0/966 [00:00<?, ?it/s]

train loss is -30790.158 and validation dice is 0.005
best_threshold is (0.75, 1000, 0.3)

 ----------------------------------------------------------------
best dice score was achieved 0.143 in epoch number 1


In [7]:
# Persist the trained weights to the final checkpoint path.
config.model_name = './new_models/unet_new_oldcustom_combo111_4output_26july.pth'
# torch.save does not create missing parent directories, so ensure the output folder exists.
os.makedirs(os.path.dirname(config.model_name) or '.', exist_ok=True)
# NOTE(review): if `model` is still the torch.nn.DataParallel wrapper, the saved keys carry
# a 'module.' prefix. Load them into a DataParallel-wrapped UNet, or strip the prefix.
# Also confirm whether run_train returns the best-epoch model: the log reports the best dice
# at epoch 1, not at the final epoch.
torch.save(model.state_dict(), config.model_name)