In [3]:
import os
os.environ["MKL_NUM_THREADS"] = "8" 
os.environ["NUMEXPR_NUM_THREADS"] = "8" 
os.environ["OMP_NUM_THREADS"] = "8" 

from os import path, makedirs, listdir
import sys
import numpy as np
np.random.seed(1)
import random
random.seed(1)

import torch
from torch import nn
from torch.backends import cudnn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler

# from apex import amp

from adamw import AdamW
from losses import dice_round, ComboLoss

import pandas as pd
from tqdm import tqdm
import timeit
import cv2

from zoo.models import Res34_9ch_Unet
from zoo.models import Res50_9ch_Unet



from imgaug import augmenters as iaa

from utils import *

from sklearn.model_selection import train_test_split

from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt
import seaborn as sns

import gc

cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

train_png = './wdata/train_png'
train_png2 = './wdata/train_png_5_3_0'
train_png3 = './wdata/train_png_pan_6_7'

masks_dir = './wdata/masks'

models_folder = './wdata/weights'
# val_output_folder = 'res34_9ch_val_0'

speed_bins = [15, 18.75, 20, 25, 30, 35, 45, 55, 65]

cities = [('AOI_7_Moscow', '/fs/scratch/PCON0003/osu10670/AOI_7_Moscow_train', 'train_AOI_7_Moscow_geojson_roads_speed_wkt_weighted_simp.csv'), ('AOI_8_Mumbai', '/fs/scratch/PCON0003/osu10670/AOI_8_Mumbai_train', 'train_AOI_8_Mumbai_geojson_roads_speed_wkt_weighted_simp.csv')]
          
cities_idxs = {}
for i in range(len(cities)):
    cities_idxs[cities[i][0]] = i


input_shape = (704, 704) # (384, 384)

train_files = []
for f in listdir(train_png):
    if '.png' in f:
        train_files.append(f)



class TrainData(Dataset):
    def __init__(self, train_idxs):
        super().__init__()
        self.train_idxs = train_idxs
        self.elastic = iaa.ElasticTransformation(alpha=(0.25, 1.2), sigma=0.2)

    def __len__(self):
        return len(self.train_idxs)

    def __getitem__(self, idx):
        _idx = self.train_idxs[idx]

        img_id = train_files[_idx]

        img = cv2.imread(path.join(train_png, img_id), cv2.IMREAD_COLOR)
        img2 = cv2.imread(path.join(train_png2, img_id), cv2.IMREAD_COLOR)
        img3 = cv2.imread(path.join(train_png3, img_id), cv2.IMREAD_COLOR)

        msk0 = cv2.imread(path.join(masks_dir, img_id), cv2.IMREAD_COLOR)
        msk1 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed0.png')), cv2.IMREAD_COLOR)
        msk2 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed1.png')), cv2.IMREAD_COLOR)
        msk3 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed2.png')), cv2.IMREAD_COLOR)
        msk4 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed_cont.png')), cv2.IMREAD_UNCHANGED)

        #TODO finally finetune Moscow (only?) without flips and rotations!
        

        if (('Moscow' not in img_id) and ('Mumbai' not in img_id) and (random.random() > 0.8)) or (random.random() > 0.9):
            img = img[::-1, ...]
            img2 = img2[::-1, ...]
            img3 = img3[::-1, ...]
            msk0 = msk0[::-1, ...]
            msk1 = msk1[::-1, ...]
            msk2 = msk2[::-1, ...]
            msk3 = msk3[::-1, ...]
            msk4 = msk4[::-1, ...]

        if (('Moscow' not in img_id) and ('Mumbai' not in img_id) and (random.random() > 0.8)) or (random.random() > 0.9):
            rot = random.randrange(4)
            if rot > 0:
                img = np.rot90(img, k=rot)
                img2 = np.rot90(img2, k=rot)
                img3 = np.rot90(img3, k=rot)
                msk0 = np.rot90(msk0, k=rot)
                msk1 = np.rot90(msk1, k=rot)
                msk2 = np.rot90(msk2, k=rot)
                msk3 = np.rot90(msk3, k=rot)
                msk4 = np.rot90(msk4, k=rot)
                    
        if random.random() > 0.95:
            shift_pnt = (random.randint(-320, 320), random.randint(-320, 320))
            img = shift_image(img, shift_pnt)
            img2 = shift_image(img2, shift_pnt)
            img3 = shift_image(img3, shift_pnt)
            msk0 = shift_image(msk0, shift_pnt)
            msk1 = shift_image(msk1, shift_pnt)
            msk2 = shift_image(msk2, shift_pnt)
            msk3 = shift_image(msk3, shift_pnt)
            msk4 = shift_image(msk4, shift_pnt)
            
        if random.random() > 0.95:
            rot_pnt =  (img.shape[0] // 2 + random.randint(-320, 320), img.shape[1] // 2 + random.randint(-320, 320))
            scale = 0.9 + random.random() * 0.2
            angle = random.randint(0, 20) - 10
            if (angle != 0) or (scale != 1):
                img = rotate_image(img, angle, scale, rot_pnt)
                img2 = rotate_image(img2, angle, scale, rot_pnt)
                img3 = rotate_image(img3, angle, scale, rot_pnt)
                msk0 = rotate_image(msk0, angle, scale, rot_pnt)
                msk1 = rotate_image(msk1, angle, scale, rot_pnt)
                msk2 = rotate_image(msk2, angle, scale, rot_pnt)
                msk3 = rotate_image(msk3, angle, scale, rot_pnt)
                msk4 = rotate_image(msk4, angle, scale, rot_pnt)

        crop_size = input_shape[0]
        if random.random() > 0.95:
            crop_size = random.randint(int(input_shape[0] / 1.1), int(input_shape[0] / 0.9))

        x0 = random.randint(0, img.shape[1] - crop_size)
        y0 = random.randint(0, img.shape[0] - crop_size)

        img = img[y0:y0+crop_size, x0:x0+crop_size, :]
        img2 = img2[y0:y0+crop_size, x0:x0+crop_size, :]
        img3 = img3[y0:y0+crop_size, x0:x0+crop_size, :]
        msk0 = msk0[y0:y0+crop_size, x0:x0+crop_size, :]
        msk1 = msk1[y0:y0+crop_size, x0:x0+crop_size, :]
        msk2 = msk2[y0:y0+crop_size, x0:x0+crop_size, :]
        msk3 = msk3[y0:y0+crop_size, x0:x0+crop_size, :]
        msk4 = msk4[y0:y0+crop_size, x0:x0+crop_size]
        

        if crop_size != input_shape[0]:
            img = cv2.resize(img, input_shape, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, input_shape, interpolation=cv2.INTER_LINEAR)
            img3 = cv2.resize(img3, input_shape, interpolation=cv2.INTER_LINEAR)
            msk0 = cv2.resize(msk0, input_shape, interpolation=cv2.INTER_LINEAR)
            msk1 = cv2.resize(msk1, input_shape, interpolation=cv2.INTER_LINEAR)
            msk2 = cv2.resize(msk2, input_shape, interpolation=cv2.INTER_LINEAR)
            msk3 = cv2.resize(msk3, input_shape, interpolation=cv2.INTER_LINEAR)
            msk4 = cv2.resize(msk4, input_shape, interpolation=cv2.INTER_LINEAR)
            
        if random.random() > 0.97:
            img = shift_channels(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))
        if random.random() > 0.97:
            img2 = shift_channels(img2, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))
        if random.random() > 0.97:
            img3 = shift_channels(img3, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))

        if random.random() > 0.97:
            img = change_hsv(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))

        if random.random() > 0.97:
            if random.random() > 0.95:
                img = clahe(img)
            elif random.random() > 0.95:
                img = gauss_noise(img)
            elif random.random() > 0.95:
                img = cv2.blur(img, (3, 3))
        elif random.random() > 0.97:
            if random.random() > 0.95:
                img = saturation(img, 0.9 + random.random() * 0.2)
            elif random.random() > 0.95:
                img = brightness(img, 0.9 + random.random() * 0.2)
            elif random.random() > 0.95:
                img = contrast(img, 0.9 + random.random() * 0.2)

        if random.random() > 0.98:
            el_det = self.elastic.to_deterministic()
            img = el_det.augment_image(img)

        msk = (msk0[..., :2] > 127) * 1
        bkg_msk = (np.ones_like(msk[..., :1]) - msk[..., :1]) * 255
        msk_speed = np.concatenate([msk1, msk2, msk3, bkg_msk], axis=2)
        msk_speed = (msk_speed > 127) * 1
        for i in range(9):
            for j in range(i + 1, 10):
                msk_speed[msk_speed[..., 9-i] > 0, 9-j] = 0
        lbl_speed = msk_speed.argmax(axis=2)

        msk_speed_cont = msk4 / 255
        msk_speed_cont = msk_speed_cont[..., np.newaxis]

        img = np.concatenate([img, img2, img3], axis=2)
        img = preprocess_inputs(img)

        img = torch.from_numpy(img.transpose((2, 0, 1))).float()
        msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()
        msk_speed = torch.from_numpy(msk_speed.transpose((2, 0, 1))).long()
        msk_speed_cont = torch.from_numpy(msk_speed_cont.transpose((2, 0, 1))).float()
        lbl_speed = torch.from_numpy(lbl_speed.copy()).long()

        sample = {'img': img, 'msk': msk, 'msk_speed': msk_speed, 'lbl_speed': lbl_speed, 'msk_speed_cont': msk_speed_cont, 'img_id': img_id}
        return sample


    
class ValData(Dataset):
    def __init__(self, image_idxs):
        super().__init__()
        self.image_idxs = image_idxs

    def __len__(self):
        return len(self.image_idxs)

    def __getitem__(self, idx):
        _idx = self.image_idxs[idx]

        img_id = train_files[_idx]

        img = cv2.imread(path.join(train_png, img_id), cv2.IMREAD_COLOR)
        img2 = cv2.imread(path.join(train_png2, img_id), cv2.IMREAD_COLOR)
        img3 = cv2.imread(path.join(train_png3, img_id), cv2.IMREAD_COLOR)

        msk0 = cv2.imread(path.join(masks_dir, img_id), cv2.IMREAD_COLOR)
        msk1 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed0.png')), cv2.IMREAD_COLOR)
        msk2 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed1.png')), cv2.IMREAD_COLOR)
        msk3 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed2.png')), cv2.IMREAD_COLOR)
        msk4 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed_cont.png')), cv2.IMREAD_UNCHANGED)
        img = np.pad(img, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        img2 = np.pad(img2, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        img3 = np.pad(img3, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk0 = np.pad(msk0, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk1 = np.pad(msk1, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk2 = np.pad(msk2, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk3 = np.pad(msk3, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk4 = np.pad(msk4, ((6, 6), (6, 6)), mode='reflect')

        msk = (msk0[..., :2] > 127) * 1
        bkg_msk = (np.ones_like(msk[..., :1]) - msk[..., :1]) * 255
        msk_speed = np.concatenate([msk1, msk2, msk3, bkg_msk], axis=2)
        msk_speed = (msk_speed > 127) * 1
        for i in range(9):
            for j in range(i + 1, 10):
                msk_speed[msk_speed[..., 9-i] > 0, 9-j] = 0
        lbl_speed = msk_speed.argmax(axis=2)

        msk_speed_cont = msk4 / 255
        msk_speed_cont = msk_speed_cont[..., np.newaxis]

        img = np.concatenate([img, img2, img3], axis=2)
        img = preprocess_inputs(img)

        img = torch.from_numpy(img.transpose((2, 0, 1))).float()
        msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()
        msk_speed = torch.from_numpy(msk_speed.transpose((2, 0, 1))).long()
        msk_speed_cont = torch.from_numpy(msk_speed_cont.transpose((2, 0, 1))).float()
        lbl_speed = torch.from_numpy(lbl_speed.copy()).long()

        sample = {'img': img, 'msk': msk, 'msk_speed': msk_speed, 'lbl_speed': lbl_speed, 'msk_speed_cont': msk_speed_cont, 'img_id': img_id}
        return sample




def validate(net, data_loader):
    dices0 = []
    dices1 = []

    with torch.no_grad():
        for i, sample in enumerate(tqdm(data_loader)):
            msks = sample["msk"].numpy()
            imgs = sample["img"].cuda(non_blocking=True)
            img_ids =  sample["img_id"]
            
            out = model(imgs)
            
            msk_pred = torch.sigmoid(out[:, :2, ...]).cpu().numpy()
            speed_cont_pred = out[:, 2, ...].cpu().numpy()
            speed_cont_pred[speed_cont_pred < 0] = 0
            speed_cont_pred[speed_cont_pred > 1] = 1
            msk_speed_pred = torch.softmax(out[:, 3:, ...], dim=1).cpu().numpy()

            pred = msk_pred > 0.5
            for j in range(msks.shape[0]):
                dices0.append(dice(msks[j, 0], pred[j, 0]))
                dices1.append(dice(msks[j, 1], pred[j, 1]))
                
                # pred_img0 = np.concatenate([msk_pred[j, 0, ..., np.newaxis], msk_pred[j, 1, ..., np.newaxis], np.zeros_like(msk_pred[j, 0, ..., np.newaxis])], axis=2)
                # cv2.imwrite(path.join(val_output_folder,  img_ids[j]), (pred_img0 * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])

                # cv2.imwrite(path.join(val_output_folder,  img_ids[j].replace('.png', '_speed0.png')), (msk_speed_pred[j, :3].transpose(1, 2, 0) * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])
                # cv2.imwrite(path.join(val_output_folder,  img_ids[j].replace('.png', '_speed1.png')), (msk_speed_pred[j, 3:6].transpose(1, 2, 0) * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])
                # cv2.imwrite(path.join(val_output_folder,  img_ids[j].replace('.png', '_speed2.png')), (msk_speed_pred[j, 6:9].transpose(1, 2, 0) * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])
                # cv2.imwrite(path.join(val_output_folder,  img_ids[j].replace('.png', '_speed_cont.png')), (speed_cont_pred[j] * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])

    d0 = np.mean(dices0)
    d1 = np.mean(dices1)

    print("Val Dice: {}, {}".format(d0, d1))
    return d0



def evaluate_val(data_val, best_score, model, snapshot_name, current_epoch):
    model = model.eval()
    d = validate(model, data_loader=data_val)

    if d > best_score:
        torch.save({
            'epoch': current_epoch + 1,
            'state_dict': model.state_dict(),
            'best_score': d,
        }, path.join(models_folder, snapshot_name + '_best'))
        best_score = d

    print("dice: {}\tdice_best: {}".format(d, best_score))
    return best_score



def train_epoch(current_epoch, seg_loss, ce_loss, mse_loss, model, optimizer, scheduler, train_data_loader):
    losses = AverageMeter()
    losses1 = AverageMeter()
    losses2 = AverageMeter()
    losses3 = AverageMeter()
    losses4 = AverageMeter()

    dices = AverageMeter()

    iterator = tqdm(train_data_loader)
    model.train()
    scheduler.step(current_epoch)
    for i, sample in enumerate(iterator):
        imgs = sample["img"].cuda(non_blocking=True)
        msks = sample["msk"].cuda(non_blocking=True)
        msks_speed = sample["msk_speed"].cuda(non_blocking=True)
        lbls_speed = sample["lbl_speed"].cuda(non_blocking=True)
        msks_speed_cont = sample["msk_speed_cont"].cuda(non_blocking=True)
        

        out = model(imgs)

        loss1 = seg_loss(out[:, 0, ...], msks[:, 0, ...])
        loss2 = seg_loss(out[:, 1, ...], msks[:, 1, ...])

        loss3 = ce_loss(out[:, 3:, ...], lbls_speed)

        loss4 = mse_loss(out[:, 2:3, ...], msks_speed_cont)

        loss = 1.5 * loss1 + 0.05 * loss2 + 0.2 * loss3 + 0.1 * loss4

        for _i in range(3, 13):
            loss += 0.03 * seg_loss(out[:, _i, ...], msks_speed[:, _i-3, ...])

        with torch.no_grad():
            _probs = torch.sigmoid(out[:, 0, ...])
            dice_sc = 1 - dice_round(_probs, msks[:, 0, ...])

        losses.update(loss.item(), imgs.size(0))
        losses1.update(loss1.item(), imgs.size(0))
        losses2.update(loss2.item(), imgs.size(0))
        losses3.update(loss3.item(), imgs.size(0))
        losses4.update(loss4.item(), imgs.size(0))

        dices.update(dice_sc, imgs.size(0))

        iterator.set_description(
            "epoch: {}; lr {:.7f}; Loss {loss.val:.4f} ({loss.avg:.4f}); Loss1 {loss1.val:.4f} ({loss1.avg:.4f}); Loss2 {loss2.val:.4f} ({loss2.avg:.4f}); Loss3 {loss3.val:.4f} ({loss3.avg:.4f}); Loss4 {loss4.val:.4f} ({loss4.avg:.4f}); Dice {dice.val:.4f} ({dice.avg:.4f})".format(
                current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, loss2=losses2, loss3=losses3, loss4=losses4, dice=dices))
        
        optimizer.zero_grad()
        loss.backward()
        # with amp.scale_loss(loss, optimizer) as scaled_loss:
        #     scaled_loss.backward()
        optimizer.step()

#         scheduler.step()

    print("epoch: {}; lr {:.7f}; Loss {loss.avg:.4f}; Loss1 {loss1.avg:.4f}; Loss2 {loss2.avg:.4f}; Loss3 {loss3.avg:.4f}; Loss4 {loss4.avg:.4f}; Dice {dice.avg:.4f}".format(
                current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, loss2=losses2, loss3=losses3, loss4=losses4, dice=dices))




In [None]:
train_idxs0, val_idxs = train_test_split(np.arange(len(train_files)), test_size=0.1, random_state=8)

len(train_idxs0)

In [2]:
# for seed in [501,502,503,504]:


t0 = timeit.default_timer()

makedirs(models_folder, exist_ok=True)
# makedirs(val_output_folder, exist_ok=True)

seed = int(501)
vis_dev = '0,1,2,3'

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = vis_dev

cudnn.benchmark = True

batch_size = 6
val_batch_size = 4

snapshot_name = 'res50_9ch_full_{}_0'.format(seed)

train_idxs0, test_idxs = train_test_split(np.arange(len(train_files)), test_size=0.1, random_state=8)
train_idxs0, val_idxs = train_test_split(np.arange(len(train_idxs0)), test_size=0.1, random_state=seed)

np.random.seed(seed)
random.seed(seed)

train_idxs = []
for i in train_idxs0:
    train_idxs.append(i)
    if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
        train_idxs.append(i)
    if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
        train_idxs.append(i)
    if (('Mumbai' in train_files[i]) or ('Moscow' in train_files[i])) and random.random() > 0.7:
        train_idxs.append(i)
train_idxs = np.asarray(train_idxs)


steps_per_epoch = len(train_idxs) // batch_size
validation_steps = len(val_idxs) // val_batch_size

print('steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps)

data_train = TrainData(train_idxs)
val_train = ValData(val_idxs)

train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=10, shuffle=True, pin_memory=True, drop_last=True)
val_data_loader = DataLoader(val_train, batch_size=val_batch_size, num_workers=10, shuffle=False, pin_memory=False)

model = Res50_9ch_Unet() #.cuda()

params = model.parameters()

optimizer = AdamW(params, lr=0.0004, weight_decay=1e-4)

scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[4, 10, 16, 24, 28, 32], gamma=0.5)

model = nn.DataParallel(model).cuda()


seg_loss = ComboLoss({'dice': 1.0, 'focal': 3.0}, per_image=True).cuda()
ce_loss = nn.CrossEntropyLoss().cuda()
mse_loss = nn.MSELoss().cuda()

best_score = 0
_cnt = -1
for epoch in range(20):
    train_epoch(epoch, seg_loss, ce_loss, mse_loss, model, optimizer, scheduler, train_data_loader)
    if epoch % 2 == 0:
        _cnt += 1
#             torch.save({
#                 'epoch': epoch + 1,
#                 'state_dict': model.state_dict(),
#                 'best_score': best_score,
#             }, path.join(models_folder, snapshot_name + '_{}'.format(_cnt % 3)))
        best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch)
        torch.cuda.empty_cache()

elapsed = timeit.default_timer() - t0
print('Time: {:.3f} min'.format(elapsed / 60))




steps_per_epoch 416 validation_steps 53


  cpuset_checked))
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:1050.)
  exp_avg.mul_(beta1).add_(1 - beta1, grad)
epoch: 0; lr 0.0004000; Loss 1.6208 (1.7861); Loss1 0.8351 (0.9207); Loss2 0.9565 (1.0299); Loss3 0.2483 (0.3019); Loss4 0.0260 (0.0463); Dice 0.3596 (0.2464): 100%|██████████| 416/416 [04:02<00:00,  1.71it/s]   


epoch: 0; lr 0.0004000; Loss 1.7861; Loss1 0.9207; Loss2 1.0299; Loss3 0.3019; Loss4 0.0463; Dice 0.2464


100%|██████████| 54/54 [00:56<00:00,  1.06s/it]


Val Dice: 0.4389190179655486, 0.23421647231318846
dice: 0.4389190179655486	dice_best: 0.4389190179655486


epoch: 1; lr 0.0004000; Loss 1.2974 (1.5597); Loss1 0.6449 (0.8004); Loss2 0.9338 (0.9638); Loss3 0.1362 (0.2051); Loss4 0.0078 (0.0205); Dice 0.4602 (0.3656): 100%|██████████| 416/416 [04:00<00:00,  1.73it/s]


epoch: 1; lr 0.0004000; Loss 1.5597; Loss1 0.8004; Loss2 0.9638; Loss3 0.2051; Loss4 0.0205; Dice 0.3656


epoch: 2; lr 0.0004000; Loss 1.4044 (1.4752); Loss1 0.7116 (0.7493); Loss2 0.9157 (0.9512); Loss3 0.1522 (0.1924); Loss4 0.0111 (0.0135); Dice 0.4130 (0.4119): 100%|██████████| 416/416 [03:59<00:00,  1.74it/s]


epoch: 2; lr 0.0004000; Loss 1.4752; Loss1 0.7493; Loss2 0.9512; Loss3 0.1924; Loss4 0.0135; Dice 0.4119


100%|██████████| 54/54 [00:54<00:00,  1.00s/it]


Val Dice: 0.3767309159644874, 0.24459404900147355
dice: 0.3767309159644874	dice_best: 0.4389190179655486


epoch: 3; lr 0.0004000; Loss 1.3883 (1.4447); Loss1 0.7030 (0.7322); Loss2 0.9801 (0.9388); Loss3 0.1339 (0.1898); Loss4 0.0085 (0.0138); Dice 0.3966 (0.4298): 100%|██████████| 416/416 [04:00<00:00,  1.73it/s]


epoch: 3; lr 0.0004000; Loss 1.4447; Loss1 0.7322; Loss2 0.9388; Loss3 0.1898; Loss4 0.0138; Dice 0.4298


epoch: 4; lr 0.0001000; Loss 1.2241 (1.3523); Loss1 0.6089 (0.6787); Loss2 1.0009 (0.8995); Loss3 0.0700 (0.1755); Loss4 0.0049 (0.0089); Dice 0.4489 (0.4752): 100%|██████████| 416/416 [04:01<00:00,  1.73it/s]


epoch: 4; lr 0.0001000; Loss 1.3523; Loss1 0.6787; Loss2 0.8995; Loss3 0.1755; Loss4 0.0089; Dice 0.4752


100%|██████████| 54/54 [00:54<00:00,  1.01s/it]


Val Dice: 0.5525091523367318, 0.3148542904278596
dice: 0.5525091523367318	dice_best: 0.5525091523367318


epoch: 5; lr 0.0002000; Loss 1.1759 (1.3192); Loss1 0.5811 (0.6603); Loss2 0.8554 (0.8590); Loss3 0.1261 (0.1720); Loss4 0.0093 (0.0092); Dice 0.5377 (0.4913): 100%|██████████| 416/416 [03:59<00:00,  1.73it/s]


epoch: 5; lr 0.0002000; Loss 1.3192; Loss1 0.6603; Loss2 0.8590; Loss3 0.1720; Loss4 0.0092; Dice 0.4913


epoch: 6; lr 0.0002000; Loss 1.2167 (1.2832); Loss1 0.6063 (0.6430); Loss2 0.8272 (0.7775); Loss3 0.1338 (0.1654); Loss4 0.0065 (0.0091); Dice 0.5317 (0.5071): 100%|██████████| 416/416 [04:00<00:00,  1.73it/s]


epoch: 6; lr 0.0002000; Loss 1.2832; Loss1 0.6430; Loss2 0.7775; Loss3 0.1654; Loss4 0.0091; Dice 0.5071


100%|██████████| 54/54 [00:56<00:00,  1.05s/it]


Val Dice: 0.6025176953108274, 0.41462543872533786
dice: 0.6025176953108274	dice_best: 0.6025176953108274


epoch: 7; lr 0.0002000; Loss 0.8711 (1.2837); Loss1 0.3903 (0.6452); Loss2 0.9972 (0.7556); Loss3 0.0352 (0.1619); Loss4 0.0033 (0.0087); Dice 0.6433 (0.5050): 100%|██████████| 416/416 [04:00<00:00,  1.73it/s]


epoch: 7; lr 0.0002000; Loss 1.2837; Loss1 0.6452; Loss2 0.7556; Loss3 0.1619; Loss4 0.0087; Dice 0.5050


epoch: 8; lr 0.0002000; Loss 1.0769 (1.2664); Loss1 0.5143 (0.6358); Loss2 0.7201 (0.7354); Loss3 0.1557 (0.1610); Loss4 0.0081 (0.0087); Dice 0.6170 (0.5134): 100%|██████████| 416/416 [03:59<00:00,  1.74it/s]


epoch: 8; lr 0.0002000; Loss 1.2664; Loss1 0.6358; Loss2 0.7354; Loss3 0.1610; Loss4 0.0087; Dice 0.5134


100%|██████████| 54/54 [00:55<00:00,  1.02s/it]


Val Dice: 0.6043738001190062, 0.4305285004398973
dice: 0.6043738001190062	dice_best: 0.6043738001190062


epoch: 9; lr 0.0002000; Loss 1.1019 (1.2412); Loss1 0.5413 (0.6214); Loss2 0.8011 (0.7201); Loss3 0.0784 (0.1578); Loss4 0.0043 (0.0083); Dice 0.5274 (0.5256): 100%|██████████| 416/416 [03:59<00:00,  1.74it/s]


epoch: 9; lr 0.0002000; Loss 1.2412; Loss1 0.6214; Loss2 0.7201; Loss3 0.1578; Loss4 0.0083; Dice 0.5256


epoch: 10; lr 0.0000500; Loss 1.1840 (1.1788); Loss1 0.5887 (0.5853); Loss2 0.8026 (0.6883); Loss3 0.1504 (0.1494); Loss4 0.0056 (0.0070); Dice 0.5761 (0.5544): 100%|██████████| 416/416 [04:00<00:00,  1.73it/s]


epoch: 10; lr 0.0000500; Loss 1.1788; Loss1 0.5853; Loss2 0.6883; Loss3 0.1494; Loss4 0.0070; Dice 0.5544


100%|██████████| 54/54 [00:54<00:00,  1.02s/it]


Val Dice: 0.615174081438334, 0.44197225398691836
dice: 0.615174081438334	dice_best: 0.615174081438334


epoch: 11; lr 0.0001000; Loss 1.1056 (1.1756); Loss1 0.5448 (0.5841); Loss2 0.5480 (0.6814); Loss3 0.1792 (0.1496); Loss4 0.0112 (0.0069); Dice 0.6008 (0.5569): 100%|██████████| 416/416 [04:00<00:00,  1.73it/s]


epoch: 11; lr 0.0001000; Loss 1.1756; Loss1 0.5841; Loss2 0.6814; Loss3 0.1496; Loss4 0.0069; Dice 0.5569


epoch: 12; lr 0.0001000; Loss 1.0186 (1.1617); Loss1 0.4982 (0.5761); Loss2 0.6591 (0.6801); Loss3 0.1402 (0.1477); Loss4 0.0086 (0.0070); Dice 0.6441 (0.5634): 100%|██████████| 416/416 [04:00<00:00,  1.73it/s]


epoch: 12; lr 0.0001000; Loss 1.1617; Loss1 0.5761; Loss2 0.6801; Loss3 0.1477; Loss4 0.0070; Dice 0.5634


100%|██████████| 54/54 [00:54<00:00,  1.01s/it]


Val Dice: 0.6056887562675264, 0.4511034569971642
dice: 0.6056887562675264	dice_best: 0.615174081438334


epoch: 13; lr 0.0001000; Loss 0.9298 (1.1437); Loss1 0.4237 (0.5649); Loss2 0.6409 (0.6690); Loss3 0.1165 (0.1467); Loss4 0.0040 (0.0067); Dice 0.6630 (0.5724): 100%|██████████| 416/416 [03:59<00:00,  1.74it/s]


epoch: 13; lr 0.0001000; Loss 1.1437; Loss1 0.5649; Loss2 0.6690; Loss3 0.1467; Loss4 0.0067; Dice 0.5724


epoch: 14; lr 0.0001000; Loss 1.0987 (1.1481); Loss1 0.5255 (0.5684); Loss2 0.6010 (0.6731); Loss3 0.1813 (0.1457); Loss4 0.0069 (0.0066); Dice 0.5997 (0.5698): 100%|██████████| 416/416 [04:00<00:00,  1.73it/s]


epoch: 14; lr 0.0001000; Loss 1.1481; Loss1 0.5684; Loss2 0.6731; Loss3 0.1457; Loss4 0.0066; Dice 0.5698


100%|██████████| 54/54 [00:56<00:00,  1.04s/it]


Val Dice: 0.6192063178805858, 0.47013670233814353
dice: 0.6192063178805858	dice_best: 0.6192063178805858


epoch: 15; lr 0.0001000; Loss 1.1979 (1.1357); Loss1 0.6031 (0.5616); Loss2 0.5730 (0.6661); Loss3 0.1669 (0.1452); Loss4 0.0085 (0.0069); Dice 0.5546 (0.5771): 100%|██████████| 416/416 [04:00<00:00,  1.73it/s]


epoch: 15; lr 0.0001000; Loss 1.1357; Loss1 0.5616; Loss2 0.6661; Loss3 0.1452; Loss4 0.0069; Dice 0.5771


epoch: 16; lr 0.0000250; Loss 1.3389 (1.1052); Loss1 0.6738 (0.5442); Loss2 0.8051 (0.6528); Loss3 0.1522 (0.1390); Loss4 0.0107 (0.0061); Dice 0.4365 (0.5889): 100%|██████████| 416/416 [03:59<00:00,  1.74it/s]


epoch: 16; lr 0.0000250; Loss 1.1052; Loss1 0.5442; Loss2 0.6528; Loss3 0.1390; Loss4 0.0061; Dice 0.5889


100%|██████████| 54/54 [00:54<00:00,  1.01s/it]


Val Dice: 0.6152673331286621, 0.46525743260317093
dice: 0.6152673331286621	dice_best: 0.6192063178805858


epoch: 17; lr 0.0000500; Loss 1.0640 (1.0904); Loss1 0.5172 (0.5348); Loss2 0.6900 (0.6458); Loss3 0.1869 (0.1381); Loss4 0.0100 (0.0060); Dice 0.6417 (0.5969): 100%|██████████| 416/416 [04:00<00:00,  1.73it/s]


epoch: 17; lr 0.0000500; Loss 1.0904; Loss1 0.5348; Loss2 0.6458; Loss3 0.1381; Loss4 0.0060; Dice 0.5969


epoch: 18; lr 0.0000500; Loss 0.9258 (1.0765); Loss1 0.4314 (0.5268); Loss2 0.5220 (0.6445); Loss3 0.1215 (0.1367); Loss4 0.0078 (0.0060); Dice 0.6378 (0.6031): 100%|██████████| 416/416 [03:59<00:00,  1.74it/s]


epoch: 18; lr 0.0000500; Loss 1.0765; Loss1 0.5268; Loss2 0.6445; Loss3 0.1367; Loss4 0.0060; Dice 0.6031


100%|██████████| 54/54 [00:55<00:00,  1.02s/it]


Val Dice: 0.6361943388693738, 0.4672155147162845
dice: 0.6361943388693738	dice_best: 0.6361943388693738


epoch: 19; lr 0.0000500; Loss 1.1760 (1.0896); Loss1 0.5842 (0.5348); Loss2 0.8547 (0.6507); Loss3 0.1542 (0.1390); Loss4 0.0067 (0.0061); Dice 0.5696 (0.5981): 100%|██████████| 416/416 [04:00<00:00,  1.73it/s]

epoch: 19; lr 0.0000500; Loss 1.0896; Loss1 0.5348; Loss2 0.6507; Loss3 0.1390; Loss4 0.0061; Dice 0.5981
Time: 89.435 min





In [4]:
t0 = timeit.default_timer()

makedirs(models_folder, exist_ok=True)
# makedirs(val_output_folder, exist_ok=True)

seed = int(502)
vis_dev = '0,1,2,3'

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = vis_dev

cudnn.benchmark = True

batch_size = 6
val_batch_size = 4

snapshot_name = 'res50_9ch_full_{}_0'.format(seed)

train_idxs0, test_idxs = train_test_split(np.arange(len(train_files)), test_size=0.1, random_state=8)
train_idxs0, val_idxs = train_test_split(np.arange(len(train_idxs0)), test_size=0.1, random_state=seed)

np.random.seed(seed)
random.seed(seed)

train_idxs = []
for i in train_idxs0:
    train_idxs.append(i)
    if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
        train_idxs.append(i)
    if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
        train_idxs.append(i)
    if (('Mumbai' in train_files[i]) or ('Moscow' in train_files[i])) and random.random() > 0.7:
        train_idxs.append(i)
train_idxs = np.asarray(train_idxs)


steps_per_epoch = len(train_idxs) // batch_size
validation_steps = len(val_idxs) // val_batch_size

print('steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps)

data_train = TrainData(train_idxs)
val_train = ValData(val_idxs)

train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=10, shuffle=True, pin_memory=True, drop_last=True)
val_data_loader = DataLoader(val_train, batch_size=val_batch_size, num_workers=10, shuffle=False, pin_memory=False)

model = Res50_9ch_Unet() #.cuda()

params = model.parameters()

optimizer = AdamW(params, lr=0.0004, weight_decay=1e-4)

scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[4, 10, 16, 24, 28, 32], gamma=0.5)

model = nn.DataParallel(model).cuda()


seg_loss = ComboLoss({'dice': 1.0, 'focal': 3.0}, per_image=True).cuda()
ce_loss = nn.CrossEntropyLoss().cuda()
mse_loss = nn.MSELoss().cuda()

best_score = 0
_cnt = -1
for epoch in range(20):
    train_epoch(epoch, seg_loss, ce_loss, mse_loss, model, optimizer, scheduler, train_data_loader)
    if epoch % 2 == 0:
        _cnt += 1
#             torch.save({
#                 'epoch': epoch + 1,
#                 'state_dict': model.state_dict(),
#                 'best_score': best_score,
#             }, path.join(models_folder, snapshot_name + '_{}'.format(_cnt % 3)))
        best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch)
        torch.cuda.empty_cache()

elapsed = timeit.default_timer() - t0
print('Time: {:.3f} min'.format(elapsed / 60))




steps_per_epoch 416 validation_steps 53


epoch: 0; lr 0.0004000; Loss 1.6731 (1.8108); Loss1 0.8894 (0.9350); Loss2 0.9802 (1.0446); Loss3 0.1054 (0.3004); Loss4 0.0419 (0.0730); Dice 0.2361 (0.2330): 100%|██████████| 416/416 [03:59<00:00,  1.73it/s] 


epoch: 0; lr 0.0004000; Loss 1.8108; Loss1 0.9350; Loss2 1.0446; Loss3 0.3004; Loss4 0.0730; Dice 0.2330


100%|██████████| 54/54 [00:53<00:00,  1.01it/s]


Val Dice: 0.43761595690850524, 0.19626812687767117
dice: 0.43761595690850524	dice_best: 0.43761595690850524


epoch: 1; lr 0.0004000; Loss 1.4670 (1.5626); Loss1 0.7331 (0.8033); Loss2 0.9511 (0.9697); Loss3 0.2417 (0.2030); Loss4 0.0202 (0.0165); Dice 0.4052 (0.3629): 100%|██████████| 416/416 [03:59<00:00,  1.74it/s]


epoch: 1; lr 0.0004000; Loss 1.5626; Loss1 0.8033; Loss2 0.9697; Loss3 0.2030; Loss4 0.0165; Dice 0.3629


epoch: 2; lr 0.0004000; Loss 1.6940 (1.4920); Loss1 0.8960 (0.7606); Loss2 0.9855 (0.9587); Loss3 0.1518 (0.1929); Loss4 0.0101 (0.0133); Dice 0.2479 (0.4028): 100%|██████████| 416/416 [04:00<00:00,  1.73it/s]


epoch: 2; lr 0.0004000; Loss 1.4920; Loss1 0.7606; Loss2 0.9587; Loss3 0.1929; Loss4 0.0133; Dice 0.4028


100%|██████████| 54/54 [00:53<00:00,  1.01it/s]


Val Dice: 0.5216132435859004, 0.21466523152628392
dice: 0.5216132435859004	dice_best: 0.5216132435859004


epoch: 3; lr 0.0004000; Loss 1.3833 (1.4606); Loss1 0.6851 (0.7421); Loss2 0.9709 (0.9536); Loss3 0.1900 (0.1880); Loss4 0.0122 (0.0140); Dice 0.4379 (0.4204): 100%|██████████| 416/416 [03:58<00:00,  1.74it/s]


epoch: 3; lr 0.0004000; Loss 1.4606; Loss1 0.7421; Loss2 0.9536; Loss3 0.1880; Loss4 0.0140; Dice 0.4204


epoch: 4; lr 0.0001000; Loss 1.0716 (1.3629); Loss1 0.4984 (0.6844); Loss2 0.9416 (0.9434); Loss3 0.1494 (0.1742); Loss4 0.0075 (0.0096); Dice 0.6013 (0.4708): 100%|██████████| 416/416 [03:58<00:00,  1.74it/s]


epoch: 4; lr 0.0001000; Loss 1.3629; Loss1 0.6844; Loss2 0.9434; Loss3 0.1742; Loss4 0.0096; Dice 0.4708


100%|██████████| 54/54 [00:55<00:00,  1.03s/it]


Val Dice: 0.6068701846226311, 0.27561388475186827
dice: 0.6068701846226311	dice_best: 0.6068701846226311


epoch: 5; lr 0.0002000; Loss 1.4246 (1.3299); Loss1 0.7272 (0.6659); Loss2 0.9584 (0.9412); Loss3 0.1545 (0.1703); Loss4 0.0100 (0.0095); Dice 0.4237 (0.4879): 100%|██████████| 416/416 [03:59<00:00,  1.73it/s]


epoch: 5; lr 0.0002000; Loss 1.3299; Loss1 0.6659; Loss2 0.9412; Loss3 0.1703; Loss4 0.0095; Dice 0.4879


epoch: 6; lr 0.0002000; Loss 1.2675 (1.3054); Loss1 0.6345 (0.6523); Loss2 0.9634 (0.9396); Loss3 0.0801 (0.1643); Loss4 0.0052 (0.0091); Dice 0.4356 (0.4991): 100%|██████████| 416/416 [03:58<00:00,  1.75it/s]


epoch: 6; lr 0.0002000; Loss 1.3054; Loss1 0.6523; Loss2 0.9396; Loss3 0.1643; Loss4 0.0091; Dice 0.4991


100%|██████████| 54/54 [00:55<00:00,  1.02s/it]


Val Dice: 0.6153639152114757, 0.28261687268035285
dice: 0.6153639152114757	dice_best: 0.6153639152114757


epoch: 7; lr 0.0002000; Loss 1.0962 (1.2928); Loss1 0.5258 (0.6452); Loss2 0.9135 (0.9355); Loss3 0.1308 (0.1636); Loss4 0.0064 (0.0094); Dice 0.5971 (0.5051): 100%|██████████| 416/416 [03:59<00:00,  1.74it/s]


epoch: 7; lr 0.0002000; Loss 1.2928; Loss1 0.6452; Loss2 0.9355; Loss3 0.1636; Loss4 0.0094; Dice 0.5051


epoch: 8; lr 0.0002000; Loss 1.1780 (1.2685); Loss1 0.5707 (0.6316); Loss2 0.9019 (0.9234); Loss3 0.1677 (0.1590); Loss4 0.0116 (0.0105); Dice 0.5517 (0.5165): 100%|██████████| 416/416 [03:59<00:00,  1.74it/s]


epoch: 8; lr 0.0002000; Loss 1.2685; Loss1 0.6316; Loss2 0.9234; Loss3 0.1590; Loss4 0.0105; Dice 0.5165


100%|██████████| 54/54 [00:54<00:00,  1.01s/it]


Val Dice: 0.5710571711303104, 0.30411800024750857
dice: 0.5710571711303104	dice_best: 0.6153639152114757


epoch: 9; lr 0.0002000; Loss 1.3487 (1.2478); Loss1 0.6911 (0.6223); Loss2 0.6705 (0.8411); Loss3 0.0961 (0.1578); Loss4 0.0032 (0.0089); Dice 0.4557 (0.5241): 100%|██████████| 416/416 [03:58<00:00,  1.75it/s]


epoch: 9; lr 0.0002000; Loss 1.2478; Loss1 0.6223; Loss2 0.8411; Loss3 0.1578; Loss4 0.0089; Dice 0.5241


epoch: 10; lr 0.0000500; Loss 1.1434 (1.1913); Loss1 0.5607 (0.5916); Loss2 0.7538 (0.7523); Loss3 0.0790 (0.1515); Loss4 0.0061 (0.0081); Dice 0.5200 (0.5496): 100%|██████████| 416/416 [03:58<00:00,  1.74it/s]


epoch: 10; lr 0.0000500; Loss 1.1913; Loss1 0.5916; Loss2 0.7523; Loss3 0.1515; Loss4 0.0081; Dice 0.5496


100%|██████████| 54/54 [00:56<00:00,  1.05s/it]


Val Dice: 0.6369380109202774, 0.4351324383229635
dice: 0.6369380109202774	dice_best: 0.6369380109202774


epoch: 11; lr 0.0001000; Loss 1.4823 (1.1788); Loss1 0.7636 (0.5848); Loss2 0.8855 (0.7315); Loss3 0.2567 (0.1507); Loss4 0.0137 (0.0078); Dice 0.5002 (0.5574): 100%|██████████| 416/416 [04:00<00:00,  1.73it/s]


epoch: 11; lr 0.0001000; Loss 1.1788; Loss1 0.5848; Loss2 0.7315; Loss3 0.1507; Loss4 0.0078; Dice 0.5574


epoch: 12; lr 0.0001000; Loss 1.2184 (1.1618); Loss1 0.5882 (0.5750); Loss2 0.7420 (0.7104); Loss3 0.1889 (0.1475); Loss4 0.0084 (0.0076); Dice 0.5425 (0.5655): 100%|██████████| 416/416 [03:59<00:00,  1.74it/s]


epoch: 12; lr 0.0001000; Loss 1.1618; Loss1 0.5750; Loss2 0.7104; Loss3 0.1475; Loss4 0.0076; Dice 0.5655


100%|██████████| 54/54 [00:55<00:00,  1.02s/it]


Val Dice: 0.646448685625396, 0.4457171132331898
dice: 0.646448685625396	dice_best: 0.646448685625396


epoch: 13; lr 0.0001000; Loss 1.0886 (1.1662); Loss1 0.5295 (0.5786); Loss2 0.6192 (0.7106); Loss3 0.1518 (0.1461); Loss4 0.0071 (0.0074); Dice 0.5801 (0.5600): 100%|██████████| 416/416 [03:58<00:00,  1.75it/s]


epoch: 13; lr 0.0001000; Loss 1.1662; Loss1 0.5786; Loss2 0.7106; Loss3 0.1461; Loss4 0.0074; Dice 0.5600


epoch: 14; lr 0.0001000; Loss 1.1211 (1.1478); Loss1 0.5431 (0.5681); Loss2 0.7083 (0.7012); Loss3 0.1620 (0.1442); Loss4 0.0057 (0.0072); Dice 0.5949 (0.5686): 100%|██████████| 416/416 [03:59<00:00,  1.74it/s]


epoch: 14; lr 0.0001000; Loss 1.1478; Loss1 0.5681; Loss2 0.7012; Loss3 0.1442; Loss4 0.0072; Dice 0.5686


100%|██████████| 54/54 [00:55<00:00,  1.03s/it]


Val Dice: 0.627742261404646, 0.4278515354713907
dice: 0.627742261404646	dice_best: 0.646448685625396


epoch: 15; lr 0.0001000; Loss 1.2016 (1.1284); Loss1 0.6003 (0.5568); Loss2 0.6399 (0.6895); Loss3 0.1891 (0.1423); Loss4 0.0094 (0.0081); Dice 0.5625 (0.5806): 100%|██████████| 416/416 [03:59<00:00,  1.74it/s]


epoch: 15; lr 0.0001000; Loss 1.1284; Loss1 0.5568; Loss2 0.6895; Loss3 0.1423; Loss4 0.0081; Dice 0.5806


epoch: 16; lr 0.0000250; Loss 0.9083 (1.1028); Loss1 0.4306 (0.5428); Loss2 0.6009 (0.6699); Loss3 0.0972 (0.1389); Loss4 0.0063 (0.0066); Dice 0.6532 (0.5909): 100%|██████████| 416/416 [03:59<00:00,  1.73it/s]


epoch: 16; lr 0.0000250; Loss 1.1028; Loss1 0.5428; Loss2 0.6699; Loss3 0.1389; Loss4 0.0066; Dice 0.5909


100%|██████████| 54/54 [00:54<00:00,  1.01s/it]


Val Dice: 0.6549078880839965, 0.4683904080882349
dice: 0.6549078880839965	dice_best: 0.6549078880839965


epoch: 17; lr 0.0000500; Loss 1.3537 (1.0942); Loss1 0.6974 (0.5372); Loss2 0.5348 (0.6700); Loss3 0.1597 (0.1373); Loss4 0.0070 (0.0066); Dice 0.4442 (0.5948): 100%|██████████| 416/416 [03:59<00:00,  1.74it/s]


epoch: 17; lr 0.0000500; Loss 1.0942; Loss1 0.5372; Loss2 0.6700; Loss3 0.1373; Loss4 0.0066; Dice 0.5948


epoch: 18; lr 0.0000500; Loss 1.1416 (1.0858); Loss1 0.5626 (0.5329); Loss2 0.7129 (0.6584); Loss3 0.1274 (0.1357); Loss4 0.0062 (0.0065); Dice 0.5698 (0.5989): 100%|██████████| 416/416 [04:00<00:00,  1.73it/s]


epoch: 18; lr 0.0000500; Loss 1.0858; Loss1 0.5329; Loss2 0.6584; Loss3 0.1357; Loss4 0.0065; Dice 0.5989


100%|██████████| 54/54 [00:54<00:00,  1.01s/it]


Val Dice: 0.6559400626543052, 0.4682354903433511
dice: 0.6559400626543052	dice_best: 0.6559400626543052


epoch: 19; lr 0.0000500; Loss 1.1050 (1.0784); Loss1 0.5519 (0.5281); Loss2 0.5892 (0.6652); Loss3 0.0927 (0.1369); Loss4 0.0082 (0.0065); Dice 0.5647 (0.6031): 100%|██████████| 416/416 [03:59<00:00,  1.74it/s]

epoch: 19; lr 0.0000500; Loss 1.0784; Loss1 0.5281; Loss2 0.6652; Loss3 0.1369; Loss4 0.0065; Dice 0.6031
Time: 89.025 min





In [5]:
t0 = timeit.default_timer()

makedirs(models_folder, exist_ok=True)
# makedirs(val_output_folder, exist_ok=True)

seed = int(503)
vis_dev = '0,1,2,3'

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = vis_dev

cudnn.benchmark = True

batch_size = 6
val_batch_size = 4

snapshot_name = 'res50_9ch_full_{}_0'.format(seed)

train_idxs0, test_idxs = train_test_split(np.arange(len(train_files)), test_size=0.1, random_state=8)
train_idxs0, val_idxs = train_test_split(np.arange(len(train_idxs0)), test_size=0.1, random_state=seed)

np.random.seed(seed)
random.seed(seed)

train_idxs = []
for i in train_idxs0:
    train_idxs.append(i)
    if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
        train_idxs.append(i)
    if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
        train_idxs.append(i)
    if (('Mumbai' in train_files[i]) or ('Moscow' in train_files[i])) and random.random() > 0.7:
        train_idxs.append(i)
train_idxs = np.asarray(train_idxs)


steps_per_epoch = len(train_idxs) // batch_size
validation_steps = len(val_idxs) // val_batch_size

print('steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps)

data_train = TrainData(train_idxs)
val_train = ValData(val_idxs)

train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=10, shuffle=True, pin_memory=True, drop_last=True)
val_data_loader = DataLoader(val_train, batch_size=val_batch_size, num_workers=10, shuffle=False, pin_memory=False)

model = Res50_9ch_Unet() #.cuda()

params = model.parameters()

optimizer = AdamW(params, lr=0.0004, weight_decay=1e-4)

scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[4, 10, 16, 24, 28, 32], gamma=0.5)

model = nn.DataParallel(model).cuda()


seg_loss = ComboLoss({'dice': 1.0, 'focal': 3.0}, per_image=True).cuda()
ce_loss = nn.CrossEntropyLoss().cuda()
mse_loss = nn.MSELoss().cuda()

best_score = 0
_cnt = -1
for epoch in range(20):
    train_epoch(epoch, seg_loss, ce_loss, mse_loss, model, optimizer, scheduler, train_data_loader)
    if epoch % 2 == 0:
        _cnt += 1
#             torch.save({
#                 'epoch': epoch + 1,
#                 'state_dict': model.state_dict(),
#                 'best_score': best_score,
#             }, path.join(models_folder, snapshot_name + '_{}'.format(_cnt % 3)))
        best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch)
        torch.cuda.empty_cache()

elapsed = timeit.default_timer() - t0
print('Time: {:.3f} min'.format(elapsed / 60))




steps_per_epoch 414 validation_steps 53


epoch: 0; lr 0.0004000; Loss 1.5822 (1.7759); Loss1 0.8152 (0.9219); Loss2 0.9837 (1.0106); Loss3 0.1988 (0.2495); Loss4 0.0092 (0.0956); Dice 0.3619 (0.2461): 100%|██████████| 414/414 [03:57<00:00,  1.74it/s]


epoch: 0; lr 0.0004000; Loss 1.7759; Loss1 0.9219; Loss2 1.0106; Loss3 0.2495; Loss4 0.0956; Dice 0.2461


100%|██████████| 54/54 [00:55<00:00,  1.02s/it]


Val Dice: 0.32071871370708704, 0.2523400584067027
dice: 0.32071871370708704	dice_best: 0.32071871370708704


epoch: 1; lr 0.0004000; Loss 1.4757 (1.5630); Loss1 0.7463 (0.8040); Loss2 0.9607 (0.9654); Loss3 0.2070 (0.2017); Loss4 0.0328 (0.0167); Dice 0.3955 (0.3629): 100%|██████████| 414/414 [03:56<00:00,  1.75it/s]


epoch: 1; lr 0.0004000; Loss 1.5630; Loss1 0.8040; Loss2 0.9654; Loss3 0.2017; Loss4 0.0167; Dice 0.3629


epoch: 2; lr 0.0004000; Loss 1.3751 (1.4955); Loss1 0.6854 (0.7630); Loss2 0.9244 (0.9556); Loss3 0.2077 (0.1950); Loss4 0.0119 (0.0158); Dice 0.5286 (0.4035): 100%|██████████| 414/414 [03:56<00:00,  1.75it/s]


epoch: 2; lr 0.0004000; Loss 1.4955; Loss1 0.7630; Loss2 0.9556; Loss3 0.1950; Loss4 0.0158; Dice 0.4035


100%|██████████| 54/54 [00:54<00:00,  1.01s/it]


Val Dice: 0.5571165340403977, 0.27961309835932546
dice: 0.5571165340403977	dice_best: 0.5571165340403977


epoch: 3; lr 0.0004000; Loss 1.4141 (1.4363); Loss1 0.7124 (0.7288); Loss2 0.9117 (0.9387); Loss3 0.1988 (0.1851); Loss4 0.0099 (0.0161); Dice 0.4484 (0.4335): 100%|██████████| 414/414 [03:56<00:00,  1.75it/s]


epoch: 3; lr 0.0004000; Loss 1.4363; Loss1 0.7288; Loss2 0.9387; Loss3 0.1851; Loss4 0.0161; Dice 0.4335


epoch: 4; lr 0.0001000; Loss 1.0787 (1.3608); Loss1 0.5185 (0.6853); Loss2 0.8662 (0.9301); Loss3 0.1143 (0.1737); Loss4 0.0052 (0.0093); Dice 0.5651 (0.4696): 100%|██████████| 414/414 [03:57<00:00,  1.74it/s]


epoch: 4; lr 0.0001000; Loss 1.3608; Loss1 0.6853; Loss2 0.9301; Loss3 0.1737; Loss4 0.0093; Dice 0.4696


100%|██████████| 54/54 [00:54<00:00,  1.02s/it]


Val Dice: 0.5835590051844379, 0.30205118957856103
dice: 0.5835590051844379	dice_best: 0.5835590051844379


epoch: 5; lr 0.0002000; Loss 1.2540 (1.3178); Loss1 0.6272 (0.6598); Loss2 0.9132 (0.9245); Loss3 0.1538 (0.1681); Loss4 0.0060 (0.0087); Dice 0.4946 (0.4917): 100%|██████████| 414/414 [03:58<00:00,  1.74it/s]


epoch: 5; lr 0.0002000; Loss 1.3178; Loss1 0.6598; Loss2 0.9245; Loss3 0.1681; Loss4 0.0087; Dice 0.4917


epoch: 6; lr 0.0002000; Loss 1.3723 (1.3011); Loss1 0.6921 (0.6508); Loss2 0.8517 (0.9155); Loss3 0.2234 (0.1657); Loss4 0.0177 (0.0093); Dice 0.4928 (0.5005): 100%|██████████| 414/414 [03:58<00:00,  1.74it/s]


epoch: 6; lr 0.0002000; Loss 1.3011; Loss1 0.6508; Loss2 0.9155; Loss3 0.1657; Loss4 0.0093; Dice 0.5005


100%|██████████| 54/54 [00:53<00:00,  1.00it/s]


Val Dice: 0.5688629440892038, 0.3088218330997515
dice: 0.5688629440892038	dice_best: 0.5835590051844379


epoch: 7; lr 0.0002000; Loss 1.3320 (1.2740); Loss1 0.6742 (0.6375); Loss2 0.7441 (0.8219); Loss3 0.1675 (0.1631); Loss4 0.0060 (0.0084); Dice 0.4667 (0.5111): 100%|██████████| 414/414 [03:58<00:00,  1.74it/s]


epoch: 7; lr 0.0002000; Loss 1.2740; Loss1 0.6375; Loss2 0.8219; Loss3 0.1631; Loss4 0.0084; Dice 0.5111


epoch: 8; lr 0.0002000; Loss 1.2986 (1.2574); Loss1 0.6530 (0.6305); Loss2 0.7097 (0.7553); Loss3 0.1781 (0.1597); Loss4 0.0078 (0.0087); Dice 0.4844 (0.5189): 100%|██████████| 414/414 [03:56<00:00,  1.75it/s]


epoch: 8; lr 0.0002000; Loss 1.2574; Loss1 0.6305; Loss2 0.7553; Loss3 0.1597; Loss4 0.0087; Dice 0.5189


100%|██████████| 54/54 [00:53<00:00,  1.00it/s]


Val Dice: 0.6137425071211107, 0.43462767129133933
dice: 0.6137425071211107	dice_best: 0.6137425071211107


epoch: 9; lr 0.0002000; Loss 1.2598 (1.2456); Loss1 0.6310 (0.6236); Loss2 0.6635 (0.7357); Loss3 0.2004 (0.1600); Loss4 0.0105 (0.0080); Dice 0.5367 (0.5236): 100%|██████████| 414/414 [03:57<00:00,  1.74it/s]


epoch: 9; lr 0.0002000; Loss 1.2456; Loss1 0.6236; Loss2 0.7357; Loss3 0.1600; Loss4 0.0080; Dice 0.5236


epoch: 10; lr 0.0000500; Loss 1.0692 (1.1864); Loss1 0.5155 (0.5896); Loss2 0.6799 (0.7086); Loss3 0.1903 (0.1504); Loss4 0.0078 (0.0071); Dice 0.6146 (0.5520): 100%|██████████| 414/414 [03:56<00:00,  1.75it/s]


epoch: 10; lr 0.0000500; Loss 1.1864; Loss1 0.5896; Loss2 0.7086; Loss3 0.1504; Loss4 0.0071; Dice 0.5520


100%|██████████| 54/54 [00:54<00:00,  1.00s/it]


Val Dice: 0.6379511497781952, 0.466400320991314
dice: 0.6379511497781952	dice_best: 0.6379511497781952


epoch: 11; lr 0.0001000; Loss 1.0058 (1.1751); Loss1 0.4810 (0.5832); Loss2 0.7592 (0.7000); Loss3 0.1270 (0.1498); Loss4 0.0091 (0.0070); Dice 0.6308 (0.5586): 100%|██████████| 414/414 [03:58<00:00,  1.74it/s]


epoch: 11; lr 0.0001000; Loss 1.1751; Loss1 0.5832; Loss2 0.7000; Loss3 0.1498; Loss4 0.0070; Dice 0.5586


epoch: 12; lr 0.0001000; Loss 1.1607 (1.1734); Loss1 0.5738 (0.5828); Loss2 0.7354 (0.6990); Loss3 0.1944 (0.1486); Loss4 0.0073 (0.0067); Dice 0.6016 (0.5576): 100%|██████████| 414/414 [03:58<00:00,  1.74it/s]


epoch: 12; lr 0.0001000; Loss 1.1734; Loss1 0.5828; Loss2 0.6990; Loss3 0.1486; Loss4 0.0067; Dice 0.5576


100%|██████████| 54/54 [00:54<00:00,  1.01s/it]


Val Dice: 0.6255479187664259, 0.46562806969371695
dice: 0.6255479187664259	dice_best: 0.6379511497781952


epoch: 13; lr 0.0001000; Loss 0.9824 (1.1451); Loss1 0.4770 (0.5666); Loss2 0.4280 (0.6720); Loss3 0.0657 (0.1445); Loss4 0.0027 (0.0064); Dice 0.5828 (0.5711): 100%|██████████| 414/414 [03:58<00:00,  1.74it/s]


epoch: 13; lr 0.0001000; Loss 1.1451; Loss1 0.5666; Loss2 0.6720; Loss3 0.1445; Loss4 0.0064; Dice 0.5711


epoch: 14; lr 0.0001000; Loss 1.2075 (1.1413); Loss1 0.5999 (0.5639); Loss2 0.6704 (0.6735); Loss3 0.2255 (0.1466); Loss4 0.0091 (0.0068); Dice 0.5741 (0.5745): 100%|██████████| 414/414 [03:57<00:00,  1.74it/s]


epoch: 14; lr 0.0001000; Loss 1.1413; Loss1 0.5639; Loss2 0.6735; Loss3 0.1466; Loss4 0.0068; Dice 0.5745


100%|██████████| 54/54 [00:53<00:00,  1.00it/s]


Val Dice: 0.6359615470797666, 0.47401787545197177
dice: 0.6359615470797666	dice_best: 0.6379511497781952


epoch: 15; lr 0.0001000; Loss 1.2846 (1.1298); Loss1 0.6507 (0.5582); Loss2 0.6686 (0.6681); Loss3 0.1625 (0.1444); Loss4 0.0072 (0.0066); Dice 0.5162 (0.5797): 100%|██████████| 414/414 [03:57<00:00,  1.74it/s]


epoch: 15; lr 0.0001000; Loss 1.1298; Loss1 0.5582; Loss2 0.6681; Loss3 0.1444; Loss4 0.0066; Dice 0.5797


epoch: 16; lr 0.0000250; Loss 1.0148 (1.0980); Loss1 0.5062 (0.5394); Loss2 0.5830 (0.6586); Loss3 0.0806 (0.1393); Loss4 0.0037 (0.0061); Dice 0.5958 (0.5948): 100%|██████████| 414/414 [03:57<00:00,  1.74it/s]


epoch: 16; lr 0.0000250; Loss 1.0980; Loss1 0.5394; Loss2 0.6586; Loss3 0.1393; Loss4 0.0061; Dice 0.5948


100%|██████████| 54/54 [00:56<00:00,  1.04s/it]


Val Dice: 0.6429544971655257, 0.4942235078059394
dice: 0.6429544971655257	dice_best: 0.6429544971655257


epoch: 17; lr 0.0000500; Loss 1.1706 (1.0935); Loss1 0.5842 (0.5369); Loss2 0.5977 (0.6505); Loss3 0.1864 (0.1390); Loss4 0.0072 (0.0059); Dice 0.6006 (0.5962): 100%|██████████| 414/414 [03:58<00:00,  1.73it/s]


epoch: 17; lr 0.0000500; Loss 1.0935; Loss1 0.5369; Loss2 0.6505; Loss3 0.1390; Loss4 0.0059; Dice 0.5962


epoch: 18; lr 0.0000500; Loss 1.1741 (1.0784); Loss1 0.5818 (0.5277); Loss2 0.6398 (0.6535); Loss3 0.1660 (0.1378); Loss4 0.0078 (0.0060); Dice 0.5674 (0.6037): 100%|██████████| 414/414 [03:57<00:00,  1.74it/s]


epoch: 18; lr 0.0000500; Loss 1.0784; Loss1 0.5277; Loss2 0.6535; Loss3 0.1378; Loss4 0.0060; Dice 0.6037


100%|██████████| 54/54 [00:54<00:00,  1.02s/it]


Val Dice: 0.6395814348821032, 0.492765815457997
dice: 0.6395814348821032	dice_best: 0.6429544971655257


epoch: 19; lr 0.0000500; Loss 1.1141 (1.0742); Loss1 0.5338 (0.5253); Loss2 0.7310 (0.6471); Loss3 0.2169 (0.1368); Loss4 0.0088 (0.0060); Dice 0.6251 (0.6058): 100%|██████████| 414/414 [03:57<00:00,  1.74it/s]

epoch: 19; lr 0.0000500; Loss 1.0742; Loss1 0.5253; Loss2 0.6471; Loss3 0.1368; Loss4 0.0060; Dice 0.6058
Time: 88.370 min





In [6]:
t0 = timeit.default_timer()

makedirs(models_folder, exist_ok=True)
# makedirs(val_output_folder, exist_ok=True)

seed = int(504)
vis_dev = '0,1,2,3'

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = vis_dev

cudnn.benchmark = True

batch_size = 6
val_batch_size = 4

snapshot_name = 'res50_9ch_full_{}_0'.format(seed)

train_idxs0, test_idxs = train_test_split(np.arange(len(train_files)), test_size=0.1, random_state=8)
train_idxs0, val_idxs = train_test_split(np.arange(len(train_idxs0)), test_size=0.1, random_state=seed)

np.random.seed(seed)
random.seed(seed)

train_idxs = []
for i in train_idxs0:
    train_idxs.append(i)
    if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
        train_idxs.append(i)
    if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
        train_idxs.append(i)
    if (('Mumbai' in train_files[i]) or ('Moscow' in train_files[i])) and random.random() > 0.7:
        train_idxs.append(i)
train_idxs = np.asarray(train_idxs)


steps_per_epoch = len(train_idxs) // batch_size
validation_steps = len(val_idxs) // val_batch_size

print('steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps)

data_train = TrainData(train_idxs)
val_train = ValData(val_idxs)

train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=10, shuffle=True, pin_memory=True, drop_last=True)
val_data_loader = DataLoader(val_train, batch_size=val_batch_size, num_workers=10, shuffle=False, pin_memory=False)

model = Res50_9ch_Unet() #.cuda()

params = model.parameters()

optimizer = AdamW(params, lr=0.0004, weight_decay=1e-4)

scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[4, 10, 16, 24, 28, 32], gamma=0.5)

model = nn.DataParallel(model).cuda()


seg_loss = ComboLoss({'dice': 1.0, 'focal': 3.0}, per_image=True).cuda()
ce_loss = nn.CrossEntropyLoss().cuda()
mse_loss = nn.MSELoss().cuda()

best_score = 0
_cnt = -1
for epoch in range(20):
    train_epoch(epoch, seg_loss, ce_loss, mse_loss, model, optimizer, scheduler, train_data_loader)
    if epoch % 2 == 0:
        _cnt += 1
#             torch.save({
#                 'epoch': epoch + 1,
#                 'state_dict': model.state_dict(),
#                 'best_score': best_score,
#             }, path.join(models_folder, snapshot_name + '_{}'.format(_cnt % 3)))
        best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch)
        torch.cuda.empty_cache()

elapsed = timeit.default_timer() - t0
print('Time: {:.3f} min'.format(elapsed / 60))




steps_per_epoch 409 validation_steps 53


epoch: 0; lr 0.0004000; Loss 1.5126 (1.7288); Loss1 0.7763 (0.8970); Loss2 0.9271 (1.0003); Loss3 0.1760 (0.2516); Loss4 0.0118 (0.0279); Dice 0.4100 (0.2721): 100%|██████████| 409/409 [03:57<00:00,  1.72it/s]


epoch: 0; lr 0.0004000; Loss 1.7288; Loss1 0.8970; Loss2 1.0003; Loss3 0.2516; Loss4 0.0279; Dice 0.2721


100%|██████████| 54/54 [00:54<00:00,  1.01s/it]


Val Dice: 0.4581043478452984, 0.2383177570093458
dice: 0.4581043478452984	dice_best: 0.4581043478452984


epoch: 1; lr 0.0004000; Loss 1.3567 (1.5553); Loss1 0.6745 (0.7974); Loss2 0.9470 (0.9666); Loss3 0.1792 (0.2058); Loss4 0.0117 (0.0175); Dice 0.4696 (0.3721): 100%|██████████| 409/409 [03:56<00:00,  1.73it/s]


epoch: 1; lr 0.0004000; Loss 1.5553; Loss1 0.7974; Loss2 0.9666; Loss3 0.2058; Loss4 0.0175; Dice 0.3721


epoch: 2; lr 0.0004000; Loss 1.5063 (1.4749); Loss1 0.7725 (0.7485); Loss2 0.9442 (0.9573); Loss3 0.1833 (0.1964); Loss4 0.0062 (0.0123); Dice 0.4499 (0.4163): 100%|██████████| 409/409 [03:55<00:00,  1.74it/s]


epoch: 2; lr 0.0004000; Loss 1.4749; Loss1 0.7485; Loss2 0.9573; Loss3 0.1964; Loss4 0.0123; Dice 0.4163


100%|██████████| 54/54 [00:54<00:00,  1.02s/it]


Val Dice: 0.522144313987339, 0.277726236911255
dice: 0.522144313987339	dice_best: 0.522144313987339


epoch: 3; lr 0.0004000; Loss 1.3921 (1.4278); Loss1 0.6938 (0.7226); Loss2 0.9287 (0.9405); Loss3 0.2221 (0.1854); Loss4 0.0108 (0.0129); Dice 0.4843 (0.4391): 100%|██████████| 409/409 [03:55<00:00,  1.74it/s]


epoch: 3; lr 0.0004000; Loss 1.4278; Loss1 0.7226; Loss2 0.9405; Loss3 0.1854; Loss4 0.0129; Dice 0.4391


epoch: 4; lr 0.0001000; Loss 1.2874 (1.3414); Loss1 0.6403 (0.6716); Loss2 1.0080 (0.9211); Loss3 0.0901 (0.1749); Loss4 0.0041 (0.0091); Dice 0.4530 (0.4845): 100%|██████████| 409/409 [03:55<00:00,  1.74it/s]


epoch: 4; lr 0.0001000; Loss 1.3414; Loss1 0.6716; Loss2 0.9211; Loss3 0.1749; Loss4 0.0091; Dice 0.4845


100%|██████████| 54/54 [00:55<00:00,  1.03s/it]


Val Dice: 0.5626612460204592, 0.29323147849156395
dice: 0.5626612460204592	dice_best: 0.5626612460204592


epoch: 5; lr 0.0002000; Loss 1.1908 (1.3108); Loss1 0.5924 (0.6540); Loss2 0.8835 (0.9169); Loss3 0.0846 (0.1708); Loss4 0.0050 (0.0089); Dice 0.4702 (0.5002): 100%|██████████| 409/409 [03:54<00:00,  1.75it/s]


epoch: 5; lr 0.0002000; Loss 1.3108; Loss1 0.6540; Loss2 0.9169; Loss3 0.1708; Loss4 0.0089; Dice 0.5002


epoch: 6; lr 0.0002000; Loss 1.4012 (1.3057); Loss1 0.7128 (0.6515); Loss2 0.8823 (0.9251); Loss3 0.2057 (0.1695); Loss4 0.0108 (0.0085); Dice 0.4877 (0.5023): 100%|██████████| 409/409 [03:55<00:00,  1.74it/s]


epoch: 6; lr 0.0002000; Loss 1.3057; Loss1 0.6515; Loss2 0.9251; Loss3 0.1695; Loss4 0.0085; Dice 0.5023


100%|██████████| 54/54 [00:53<00:00,  1.01it/s]


Val Dice: 0.6419060473176655, 0.33962190732787545
dice: 0.6419060473176655	dice_best: 0.6419060473176655


epoch: 7; lr 0.0002000; Loss 1.1305 (1.2907); Loss1 0.5465 (0.6431); Loss2 0.8701 (0.9204); Loss3 0.1112 (0.1664); Loss4 0.0083 (0.0088); Dice 0.5351 (0.5093): 100%|██████████| 409/409 [03:53<00:00,  1.75it/s]


epoch: 7; lr 0.0002000; Loss 1.2907; Loss1 0.6431; Loss2 0.9204; Loss3 0.1664; Loss4 0.0088; Dice 0.5093


epoch: 8; lr 0.0002000; Loss 1.2252 (1.2734); Loss1 0.5942 (0.6328); Loss2 0.9088 (0.9141); Loss3 0.1954 (0.1636); Loss4 0.0074 (0.0081); Dice 0.5555 (0.5170): 100%|██████████| 409/409 [03:54<00:00,  1.74it/s]


epoch: 8; lr 0.0002000; Loss 1.2734; Loss1 0.6328; Loss2 0.9141; Loss3 0.1636; Loss4 0.0081; Dice 0.5170


100%|██████████| 54/54 [00:55<00:00,  1.02s/it]


Val Dice: 0.6373036278806558, 0.32227955244066436
dice: 0.6373036278806558	dice_best: 0.6419060473176655


epoch: 9; lr 0.0002000; Loss 1.2735 (1.2496); Loss1 0.6166 (0.6215); Loss2 0.8146 (0.8311); Loss3 0.2457 (0.1611); Loss4 0.0158 (0.0087); Dice 0.5144 (0.5267): 100%|██████████| 409/409 [03:55<00:00,  1.74it/s]


epoch: 9; lr 0.0002000; Loss 1.2496; Loss1 0.6215; Loss2 0.8311; Loss3 0.1611; Loss4 0.0087; Dice 0.5267


epoch: 10; lr 0.0000500; Loss 1.1913 (1.1986); Loss1 0.5939 (0.5946); Loss2 0.7898 (0.7365); Loss3 0.1475 (0.1532); Loss4 0.0098 (0.0077); Dice 0.5439 (0.5484): 100%|██████████| 409/409 [03:55<00:00,  1.74it/s]


epoch: 10; lr 0.0000500; Loss 1.1986; Loss1 0.5946; Loss2 0.7365; Loss3 0.1532; Loss4 0.0077; Dice 0.5484


100%|██████████| 54/54 [00:55<00:00,  1.02s/it]


Val Dice: 0.6609810432390342, 0.44307608720210173
dice: 0.6609810432390342	dice_best: 0.6609810432390342


epoch: 11; lr 0.0001000; Loss 1.1873 (1.1728); Loss1 0.5887 (0.5796); Loss2 0.6388 (0.7148); Loss3 0.1528 (0.1515); Loss4 0.0069 (0.0075); Dice 0.5504 (0.5614): 100%|██████████| 409/409 [03:56<00:00,  1.73it/s]


epoch: 11; lr 0.0001000; Loss 1.1728; Loss1 0.5796; Loss2 0.7148; Loss3 0.1515; Loss4 0.0075; Dice 0.5614


epoch: 12; lr 0.0001000; Loss 1.2358 (1.1631); Loss1 0.6172 (0.5744); Loss2 0.6749 (0.6962); Loss3 0.1138 (0.1509); Loss4 0.0050 (0.0073); Dice 0.4958 (0.5663): 100%|██████████| 409/409 [03:55<00:00,  1.74it/s]


epoch: 12; lr 0.0001000; Loss 1.1631; Loss1 0.5744; Loss2 0.6962; Loss3 0.1509; Loss4 0.0073; Dice 0.5663


100%|██████████| 54/54 [00:53<00:00,  1.00it/s]


Val Dice: 0.6561545628230413, 0.4607794062360531
dice: 0.6561545628230413	dice_best: 0.6609810432390342


epoch: 13; lr 0.0001000; Loss 1.1459 (1.1431); Loss1 0.5811 (0.5632); Loss2 0.7038 (0.6888); Loss3 0.0792 (0.1473); Loss4 0.0041 (0.0072); Dice 0.5120 (0.5754): 100%|██████████| 409/409 [03:54<00:00,  1.74it/s]


epoch: 13; lr 0.0001000; Loss 1.1431; Loss1 0.5632; Loss2 0.6888; Loss3 0.1473; Loss4 0.0072; Dice 0.5754


epoch: 14; lr 0.0001000; Loss 1.1027 (1.1490); Loss1 0.5469 (0.5679); Loss2 0.6585 (0.6824); Loss3 0.1249 (0.1471); Loss4 0.0069 (0.0070); Dice 0.5614 (0.5702): 100%|██████████| 409/409 [03:56<00:00,  1.73it/s]


epoch: 14; lr 0.0001000; Loss 1.1490; Loss1 0.5679; Loss2 0.6824; Loss3 0.1471; Loss4 0.0070; Dice 0.5702


100%|██████████| 54/54 [00:54<00:00,  1.01s/it]


Val Dice: 0.6590617338679889, 0.4710930282980724
dice: 0.6590617338679889	dice_best: 0.6609810432390342


epoch: 15; lr 0.0001000; Loss 1.2629 (1.1245); Loss1 0.6181 (0.5536); Loss2 0.9142 (0.6662); Loss3 0.1654 (0.1450); Loss4 0.0083 (0.0069); Dice 0.5276 (0.5841): 100%|██████████| 409/409 [03:54<00:00,  1.74it/s]


epoch: 15; lr 0.0001000; Loss 1.1245; Loss1 0.5536; Loss2 0.6662; Loss3 0.1450; Loss4 0.0069; Dice 0.5841


epoch: 16; lr 0.0000250; Loss 0.8369 (1.0967); Loss1 0.3749 (0.5370); Loss2 0.5671 (0.6558); Loss3 0.0705 (0.1419); Loss4 0.0031 (0.0064); Dice 0.6892 (0.5963): 100%|██████████| 409/409 [03:55<00:00,  1.74it/s]


epoch: 16; lr 0.0000250; Loss 1.0967; Loss1 0.5370; Loss2 0.6558; Loss3 0.1419; Loss4 0.0064; Dice 0.5963


100%|██████████| 54/54 [00:54<00:00,  1.01s/it]


Val Dice: 0.6682537701492888, 0.4840664393110517
dice: 0.6682537701492888	dice_best: 0.6682537701492888


epoch: 17; lr 0.0000500; Loss 1.1395 (1.0852); Loss1 0.5641 (0.5307); Loss2 0.7766 (0.6497); Loss3 0.1428 (0.1395); Loss4 0.0072 (0.0063); Dice 0.5739 (0.6016): 100%|██████████| 409/409 [03:56<00:00,  1.73it/s]


epoch: 17; lr 0.0000500; Loss 1.0852; Loss1 0.5307; Loss2 0.6497; Loss3 0.1395; Loss4 0.0063; Dice 0.6016


epoch: 18; lr 0.0000500; Loss 0.9619 (1.0774); Loss1 0.4542 (0.5259); Loss2 0.5511 (0.6504); Loss3 0.1534 (0.1385); Loss4 0.0080 (0.0062); Dice 0.6921 (0.6052): 100%|██████████| 409/409 [03:55<00:00,  1.73it/s]


epoch: 18; lr 0.0000500; Loss 1.0774; Loss1 0.5259; Loss2 0.6504; Loss3 0.1385; Loss4 0.0062; Dice 0.6052


100%|██████████| 54/54 [00:55<00:00,  1.02s/it]


Val Dice: 0.6757377565407933, 0.4774658058923596
dice: 0.6757377565407933	dice_best: 0.6757377565407933


epoch: 19; lr 0.0000500; Loss 1.1672 (1.0685); Loss1 0.5835 (0.5212); Loss2 0.5915 (0.6439); Loss3 0.1995 (0.1378); Loss4 0.0071 (0.0061); Dice 0.6042 (0.6095): 100%|██████████| 409/409 [03:55<00:00,  1.74it/s]

epoch: 19; lr 0.0000500; Loss 1.0685; Loss1 0.5212; Loss2 0.6439; Loss3 0.1378; Loss4 0.0061; Dice 0.6095
Time: 87.679 min





In [None]:
import os
os.environ["MKL_NUM_THREADS"] = "8" 
os.environ["NUMEXPR_NUM_THREADS"] = "8" 
os.environ["OMP_NUM_THREADS"] = "8" 

from os import path, makedirs, listdir
import sys
import numpy as np
np.random.seed(1)
import random
random.seed(1)

import torch
from torch import nn
from torch.backends import cudnn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler

# from apex import amp

from adamw import AdamW
from losses import dice_round, ComboLoss

import pandas as pd
from tqdm import tqdm
import timeit
import cv2

from zoo.models import SeResNext50_Unet_9ch

from imgaug import augmenters as iaa

from utils import *

from sklearn.model_selection import train_test_split

from sklearn.metrics import accuracy_score
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

train_png = './wdata/train_png'
train_png2 = './wdata/train_png_5_3_0'
train_png3 = './wdata/train_png_pan_6_7'

masks_dir = './wdata/masks'

models_folder = './wdata/weights'

speed_bins = [15, 18.75, 20, 25, 30, 35, 45, 55, 65]

cities = [('AOI_7_Moscow', '/fs/scratch/PCON0003/osu10670/train/AOI_7_Moscow', 'train_AOI_7_Moscow_geojson_roads_speed_wkt_weighted_simp.csv'),
          ('AOI_8_Mumbai', '/fs/scratch/PCON0003/osu10670/train/AOI_8_Mumbai', 'train_AOI_8_Mumbai_geojson_roads_speed_wkt_weighted_simp.csv'),]

cities_idxs = {}
for i in range(len(cities)):
    cities_idxs[cities[i][0]] = i


input_shape = (640, 640) # (384, 384)

train_files = []
for f in listdir(train_png):
    if '.png' in f:
        train_files.append(f)


class TrainData(Dataset):
    def __init__(self, train_idxs):
        super().__init__()
        self.train_idxs = train_idxs
        self.elastic = iaa.ElasticTransformation(alpha=(0.25, 1.2), sigma=0.2)

    def __len__(self):
        return len(self.train_idxs)

    def __getitem__(self, idx):
        _idx = self.train_idxs[idx]

        img_id = train_files[_idx]

        img = cv2.imread(path.join(train_png, img_id), cv2.IMREAD_COLOR)
        img2 = cv2.imread(path.join(train_png2, img_id), cv2.IMREAD_COLOR)
        img3 = cv2.imread(path.join(train_png3, img_id), cv2.IMREAD_COLOR)

        msk0 = cv2.imread(path.join(masks_dir, img_id), cv2.IMREAD_COLOR)
        msk1 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed0.png')), cv2.IMREAD_COLOR)
        msk2 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed1.png')), cv2.IMREAD_COLOR)
        msk3 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed2.png')), cv2.IMREAD_COLOR)
        msk4 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed_cont.png')), cv2.IMREAD_UNCHANGED)

        if (('Moscow' not in img_id) and ('Mumbai' not in img_id) and (random.random() > 0.5)) or (random.random() > 0.75):
            img = img[::-1, ...]
            img2 = img2[::-1, ...]
            img3 = img3[::-1, ...]
            msk0 = msk0[::-1, ...]
            msk1 = msk1[::-1, ...]
            msk2 = msk2[::-1, ...]
            msk3 = msk3[::-1, ...]
            msk4 = msk4[::-1, ...]

        if (('Moscow' not in img_id) and ('Mumbai' not in img_id) and (random.random() > 0.5)) or (random.random() > 0.75):
            rot = random.randrange(4)
            if rot > 0:
                img = np.rot90(img, k=rot)
                img2 = np.rot90(img2, k=rot)
                img3 = np.rot90(img3, k=rot)
                msk0 = np.rot90(msk0, k=rot)
                msk1 = np.rot90(msk1, k=rot)
                msk2 = np.rot90(msk2, k=rot)
                msk3 = np.rot90(msk3, k=rot)
                msk4 = np.rot90(msk4, k=rot)
                    
        if random.random() > 0.9:
            shift_pnt = (random.randint(-320, 320), random.randint(-320, 320))
            img = shift_image(img, shift_pnt)
            img2 = shift_image(img2, shift_pnt)
            img3 = shift_image(img3, shift_pnt)
            msk0 = shift_image(msk0, shift_pnt)
            msk1 = shift_image(msk1, shift_pnt)
            msk2 = shift_image(msk2, shift_pnt)
            msk3 = shift_image(msk3, shift_pnt)
            msk4 = shift_image(msk4, shift_pnt)
            
        if random.random() > 0.9:
            rot_pnt =  (img.shape[0] // 2 + random.randint(-320, 320), img.shape[1] // 2 + random.randint(-320, 320))
            scale = 0.9 + random.random() * 0.2
            angle = random.randint(0, 20) - 10
            if (angle != 0) or (scale != 1):
                img = rotate_image(img, angle, scale, rot_pnt)
                img2 = rotate_image(img2, angle, scale, rot_pnt)
                img3 = rotate_image(img3, angle, scale, rot_pnt)
                msk0 = rotate_image(msk0, angle, scale, rot_pnt)
                msk1 = rotate_image(msk1, angle, scale, rot_pnt)
                msk2 = rotate_image(msk2, angle, scale, rot_pnt)
                msk3 = rotate_image(msk3, angle, scale, rot_pnt)
                msk4 = rotate_image(msk4, angle, scale, rot_pnt)

        crop_size = input_shape[0]
        if random.random() > 0.9:
            crop_size = random.randint(int(input_shape[0] / 1.1), int(input_shape[0] / 0.9))

        x0 = random.randint(0, img.shape[1] - crop_size)
        y0 = random.randint(0, img.shape[0] - crop_size)

        img = img[y0:y0+crop_size, x0:x0+crop_size, :]
        img2 = img2[y0:y0+crop_size, x0:x0+crop_size, :]
        img3 = img3[y0:y0+crop_size, x0:x0+crop_size, :]
        msk0 = msk0[y0:y0+crop_size, x0:x0+crop_size, :]
        msk1 = msk1[y0:y0+crop_size, x0:x0+crop_size, :]
        msk2 = msk2[y0:y0+crop_size, x0:x0+crop_size, :]
        msk3 = msk3[y0:y0+crop_size, x0:x0+crop_size, :]
        msk4 = msk4[y0:y0+crop_size, x0:x0+crop_size]
        

        if crop_size != input_shape[0]:
            img = cv2.resize(img, input_shape, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, input_shape, interpolation=cv2.INTER_LINEAR)
            img3 = cv2.resize(img3, input_shape, interpolation=cv2.INTER_LINEAR)
            msk0 = cv2.resize(msk0, input_shape, interpolation=cv2.INTER_LINEAR)
            msk1 = cv2.resize(msk1, input_shape, interpolation=cv2.INTER_LINEAR)
            msk2 = cv2.resize(msk2, input_shape, interpolation=cv2.INTER_LINEAR)
            msk3 = cv2.resize(msk3, input_shape, interpolation=cv2.INTER_LINEAR)
            msk4 = cv2.resize(msk4, input_shape, interpolation=cv2.INTER_LINEAR)
            
        if random.random() > 0.95:
            img = shift_channels(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))
        if random.random() > 0.95:
            img2 = shift_channels(img2, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))
        if random.random() > 0.95:
            img3 = shift_channels(img3, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))

        if random.random() > 0.95:
            img = change_hsv(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))

        if random.random() > 0.95:
            if random.random() > 0.95:
                img = clahe(img)
            elif random.random() > 0.95:
                img = gauss_noise(img)
            elif random.random() > 0.95:
                img = cv2.blur(img, (3, 3))
        elif random.random() > 0.95:
            if random.random() > 0.95:
                img = saturation(img, 0.9 + random.random() * 0.2)
            elif random.random() > 0.95:
                img = brightness(img, 0.9 + random.random() * 0.2)
            elif random.random() > 0.95:
                img = contrast(img, 0.9 + random.random() * 0.2)

        if random.random() > 0.97:
            el_det = self.elastic.to_deterministic()
            img = el_det.augment_image(img)

        msk = (msk0[..., :2] > 127) * 1
        bkg_msk = (np.ones_like(msk[..., :1]) - msk[..., :1]) * 255
        msk_speed = np.concatenate([msk1, msk2, msk3, bkg_msk], axis=2)
        msk_speed = (msk_speed > 127) * 1
        for i in range(9):
            for j in range(i + 1, 10):
                msk_speed[msk_speed[..., 9-i] > 0, 9-j] = 0
        lbl_speed = msk_speed.argmax(axis=2)

        msk_speed_cont = msk4 / 255
        msk_speed_cont = msk_speed_cont[..., np.newaxis]

        img = np.concatenate([img, img2, img3], axis=2)
        img = preprocess_inputs(img)

        img = torch.from_numpy(img.transpose((2, 0, 1))).float()
        msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()
        msk_speed = torch.from_numpy(msk_speed.transpose((2, 0, 1))).long()
        msk_speed_cont = torch.from_numpy(msk_speed_cont.transpose((2, 0, 1))).float()
        lbl_speed = torch.from_numpy(lbl_speed.copy()).long()

        sample = {'img': img, 'msk': msk, 'msk_speed': msk_speed, 'lbl_speed': lbl_speed, 'msk_speed_cont': msk_speed_cont, 'img_id': img_id}
        return sample


    
class ValData(Dataset):
    def __init__(self, image_idxs):
        super().__init__()
        self.image_idxs = image_idxs

    def __len__(self):
        return len(self.image_idxs)

    def __getitem__(self, idx):
        _idx = self.image_idxs[idx]

        img_id = train_files[_idx]

        img = cv2.imread(path.join(train_png, img_id), cv2.IMREAD_COLOR)
        img2 = cv2.imread(path.join(train_png2, img_id), cv2.IMREAD_COLOR)
        img3 = cv2.imread(path.join(train_png3, img_id), cv2.IMREAD_COLOR)

        msk0 = cv2.imread(path.join(masks_dir, img_id), cv2.IMREAD_COLOR)
        msk1 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed0.png')), cv2.IMREAD_COLOR)
        msk2 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed1.png')), cv2.IMREAD_COLOR)
        msk3 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed2.png')), cv2.IMREAD_COLOR)
        msk4 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed_cont.png')), cv2.IMREAD_UNCHANGED)
        img = np.pad(img, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        img2 = np.pad(img2, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        img3 = np.pad(img3, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk0 = np.pad(msk0, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk1 = np.pad(msk1, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk2 = np.pad(msk2, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk3 = np.pad(msk3, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk4 = np.pad(msk4, ((6, 6), (6, 6)), mode='reflect')

        msk = (msk0[..., :2] > 127) * 1
        bkg_msk = (np.ones_like(msk[..., :1]) - msk[..., :1]) * 255
        msk_speed = np.concatenate([msk1, msk2, msk3, bkg_msk], axis=2)
        msk_speed = (msk_speed > 127) * 1
        for i in range(9):
            for j in range(i + 1, 10):
                msk_speed[msk_speed[..., 9-i] > 0, 9-j] = 0
        lbl_speed = msk_speed.argmax(axis=2)

        msk_speed_cont = msk4 / 255
        msk_speed_cont = msk_speed_cont[..., np.newaxis]

        img = np.concatenate([img, img2, img3], axis=2)
        img = preprocess_inputs(img)

        img = torch.from_numpy(img.transpose((2, 0, 1))).float()
        msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()
        msk_speed = torch.from_numpy(msk_speed.transpose((2, 0, 1))).long()
        msk_speed_cont = torch.from_numpy(msk_speed_cont.transpose((2, 0, 1))).float()
        lbl_speed = torch.from_numpy(lbl_speed.copy()).long()

        sample = {'img': img, 'msk': msk, 'msk_speed': msk_speed, 'lbl_speed': lbl_speed, 'msk_speed_cont': msk_speed_cont, 'img_id': img_id}
        return sample




def validate(net, data_loader):
    dices0 = []
    dices1 = []

    with torch.no_grad():
        for i, sample in enumerate(tqdm(data_loader)):
            msks = sample["msk"].numpy()
            imgs = sample["img"].cuda(non_blocking=True)
            img_ids =  sample["img_id"]
            
            out = model(imgs)
            
            msk_pred = torch.sigmoid(out[:, :2, ...]).cpu().numpy()
            speed_cont_pred = out[:, 2, ...].cpu().numpy()
            speed_cont_pred[speed_cont_pred < 0] = 0
            speed_cont_pred[speed_cont_pred > 1] = 1
            msk_speed_pred = torch.softmax(out[:, 3:, ...], dim=1).cpu().numpy()

            pred = msk_pred > 0.5
            for j in range(msks.shape[0]):
                dices0.append(dice(msks[j, 0], pred[j, 0]))
                dices1.append(dice(msks[j, 1], pred[j, 1]))
                
                # pred_img0 = np.concatenate([msk_pred[j, 0, ..., np.newaxis], msk_pred[j, 1, ..., np.newaxis], np.zeros_like(msk_pred[j, 0, ..., np.newaxis])], axis=2)
                # cv2.imwrite(path.join(val_output_folder,  img_ids[j]), (pred_img0 * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])

                # cv2.imwrite(path.join(val_output_folder,  img_ids[j].replace('.png', '_speed0.png')), (msk_speed_pred[j, :3].transpose(1, 2, 0) * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])
                # cv2.imwrite(path.join(val_output_folder,  img_ids[j].replace('.png', '_speed1.png')), (msk_speed_pred[j, 3:6].transpose(1, 2, 0) * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])
                # cv2.imwrite(path.join(val_output_folder,  img_ids[j].replace('.png', '_speed2.png')), (msk_speed_pred[j, 6:9].transpose(1, 2, 0) * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])
                # cv2.imwrite(path.join(val_output_folder,  img_ids[j].replace('.png', '_speed_cont.png')), (speed_cont_pred[j] * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])

    d0 = np.mean(dices0)
    d1 = np.mean(dices1)

    print("Val Dice: {}, {}".format(d0, d1))
    return d0



def evaluate_val(data_val, best_score, model, snapshot_name, current_epoch):
    model = model.eval()
    d = validate(model, data_loader=data_val)

    if d > best_score:
        torch.save({
            'epoch': current_epoch + 1,
            'state_dict': model.state_dict(),
            'best_score': d,
        }, path.join(models_folder, snapshot_name + '_best'))
        best_score = d

    print("dice: {}\tdice_best: {}".format(d, best_score))
    return best_score



def train_epoch(current_epoch, seg_loss, ce_loss, mse_loss, model, optimizer, scheduler, train_data_loader):
    losses = AverageMeter()
    losses1 = AverageMeter()
    losses2 = AverageMeter()
    losses3 = AverageMeter()
    losses4 = AverageMeter()

    dices = AverageMeter()

    iterator = tqdm(train_data_loader)
    model.train()
    scheduler.step(current_epoch)
    for i, sample in enumerate(iterator):
        imgs = sample["img"].cuda(non_blocking=True)
        msks = sample["msk"].cuda(non_blocking=True)
        msks_speed = sample["msk_speed"].cuda(non_blocking=True)
        lbls_speed = sample["lbl_speed"].cuda(non_blocking=True)
        msks_speed_cont = sample["msk_speed_cont"].cuda(non_blocking=True)
        
        out = model(imgs)

        loss1 = seg_loss(out[:, 0, ...], msks[:, 0, ...])
        loss2 = seg_loss(out[:, 1, ...], msks[:, 1, ...])

        loss3 = ce_loss(out[:, 3:, ...], lbls_speed)

        loss4 = mse_loss(out[:, 2:3, ...], msks_speed_cont)

        loss = 1.2 * loss1 + 0.05 * loss2 + 0.2 * loss3 + 0.1 * loss4

        for _i in range(3, 13):
            loss += 0.03 * seg_loss(out[:, _i, ...], msks_speed[:, _i-3, ...])

        with torch.no_grad():
            _probs = torch.sigmoid(out[:, 0, ...])
            dice_sc = 1 - dice_round(_probs, msks[:, 0, ...])

        losses.update(loss.item(), imgs.size(0))
        losses1.update(loss1.item(), imgs.size(0))
        losses2.update(loss2.item(), imgs.size(0))
        losses3.update(loss3.item(), imgs.size(0))
        losses4.update(loss4.item(), imgs.size(0))

        dices.update(dice_sc, imgs.size(0))

        iterator.set_description(
            "epoch: {}; lr {:.7f}; Loss {loss.val:.4f} ({loss.avg:.4f}); Loss1 {loss1.val:.4f} ({loss1.avg:.4f}); Loss2 {loss2.val:.4f} ({loss2.avg:.4f}); Loss3 {loss3.val:.4f} ({loss3.avg:.4f}); Loss4 {loss4.val:.4f} ({loss4.avg:.4f}); Dice {dice.val:.4f} ({dice.avg:.4f})".format(
                current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, loss2=losses2, loss3=losses3, loss4=losses4, dice=dices))
        
        optimizer.zero_grad()
        loss.backward()
        # with amp.scale_loss(loss, optimizer) as scaled_loss:
        #     scaled_loss.backward()
        optimizer.step()

#         scheduler.step()

    print("epoch: {}; lr {:.7f}; Loss {loss.avg:.4f}; Loss1 {loss1.avg:.4f}; Loss2 {loss2.avg:.4f}; Loss3 {loss3.avg:.4f}; Loss4 {loss4.avg:.4f}; Dice {dice.avg:.4f}".format(
                current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, loss2=losses2, loss3=losses3, loss4=losses4, dice=dices))




if __name__ == '__main__':
    t0 = timeit.default_timer()

    makedirs(models_folder, exist_ok=True)
    # makedirs(val_output_folder, exist_ok=True)
    
    seed = int(3)
    vis_dev = '0,1,2,3'

    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ["CUDA_VISIBLE_DEVICES"] = vis_dev

    cudnn.benchmark = True

    batch_size = 6
    val_batch_size = 4

    snapshot_name = 'seres50_9ch_{}_0'.format(seed)

    train_idxs0, val_idxs = train_test_split(np.arange(len(train_files)), test_size=0.05, random_state=seed)

    np.random.seed(seed)
    random.seed(seed)

    train_idxs = []
    for i in train_idxs0:
        train_idxs.append(i)
        if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
            train_idxs.append(i)
        if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
            train_idxs.append(i)
        if (('Mumbai' in train_files[i]) or ('Moscow' in train_files[i])) and random.random() > 0.7:
            train_idxs.append(i)
    train_idxs = np.asarray(train_idxs)


    steps_per_epoch = len(train_idxs) // batch_size
    validation_steps = len(val_idxs) // val_batch_size

    print('steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps)

    data_train = TrainData(train_idxs)
    val_train = ValData(val_idxs)

    train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=16, shuffle=True, pin_memory=True, drop_last=True)
    val_data_loader = DataLoader(val_train, batch_size=val_batch_size, num_workers=16, shuffle=False, pin_memory=False)

    model = SeResNext50_Unet_9ch() #.cuda()

    params = model.parameters()

    optimizer = AdamW(params, lr=0.0004, weight_decay=1e-4) 
    
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[6, 12, 18, 24, 26], gamma=0.5)

    model = nn.DataParallel(model).cuda()


    seg_loss = ComboLoss({'dice': 1.0, 'focal': 4.0}, per_image=True).cuda()
    ce_loss = nn.CrossEntropyLoss().cuda()
    mse_loss = nn.MSELoss().cuda()

    best_score = 0
    _cnt = -1
    for epoch in range(27):
        train_epoch(epoch, seg_loss, ce_loss, mse_loss, model, optimizer, scheduler, train_data_loader)
        if epoch % 2 == 0:
            _cnt += 1
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_score': best_score,
            }, path.join(models_folder, snapshot_name + '_{}'.format(_cnt % 3)))
            torch.cuda.empty_cache()
            best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch)

    elapsed = timeit.default_timer() - t0
    print('Time: {:.3f} min'.format(elapsed / 60))

