In [1]:
import os
os.environ["MKL_NUM_THREADS"] = "8" 
os.environ["NUMEXPR_NUM_THREADS"] = "8" 
os.environ["OMP_NUM_THREADS"] = "8" 

from os import path, makedirs, listdir
import sys
import numpy as np
np.random.seed(1)
import random
random.seed(1)

import torch
from torch import nn
from torch.backends import cudnn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler

# from apex import amp

from adamw import AdamW
from losses import dice_round, ComboLoss

import pandas as pd
from tqdm import tqdm
import timeit
import cv2

from zoo.models import Res34_9ch_Unet
from zoo.models import Res50_9ch_Unet



from imgaug import augmenters as iaa

from utils import *

from sklearn.model_selection import train_test_split

from sklearn.metrics import accuracy_score

import matplotlib.pyplot as plt
import seaborn as sns

import gc

cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train_png = './wdata/train_png'
train_png2 = './wdata/train_png_5_3_0'
train_png3 = './wdata/train_png_pan_6_7'

masks_dir = './wdata/masks'

models_folder = './wdata/weights'
# val_output_folder = 'res34_9ch_val_0'

speed_bins = [15, 18.75, 20, 25, 30, 35, 45, 55, 65]

cities = [('AOI_7_Moscow', '/fs/scratch/PCON0003/osu10670/SN5_roads/AOI_7_Moscow_train', 'train_AOI_7_Moscow_geojson_roads_speed_wkt_weighted_simp.csv'), ('AOI_8_Mumbai', '/fs/scratch/PCON0003/osu10670/SN5_roads/AOI_8_Mumbai_train', 'train_AOI_8_Mumbai_geojson_roads_speed_wkt_weighted_simp.csv')]
          
cities_idxs = {}
for i in range(len(cities)):
    cities_idxs[cities[i][0]] = i


input_shape = (704, 704) # (384, 384)

train_files = []
for f in listdir(train_png):
    if '.png' in f:
        train_files.append(f)




In [3]:
class TrainData(Dataset):
    def __init__(self, train_idxs):
        super().__init__()
        self.train_idxs = train_idxs
        self.elastic = iaa.ElasticTransformation(alpha=(0.25, 1.2), sigma=0.2)

    def __len__(self):
        return len(self.train_idxs)

    def __getitem__(self, idx):
        _idx = self.train_idxs[idx]

        img_id = train_files[_idx]

        img = cv2.imread(path.join(train_png, img_id), cv2.IMREAD_COLOR)
        img2 = cv2.imread(path.join(train_png2, img_id), cv2.IMREAD_COLOR)
        img3 = cv2.imread(path.join(train_png3, img_id), cv2.IMREAD_COLOR)

        msk0 = cv2.imread(path.join(masks_dir, img_id), cv2.IMREAD_COLOR)
        msk1 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed0.png')), cv2.IMREAD_COLOR)
        msk2 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed1.png')), cv2.IMREAD_COLOR)
        msk3 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed2.png')), cv2.IMREAD_COLOR)
        msk4 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed_cont.png')), cv2.IMREAD_UNCHANGED)

        #TODO finally finetune Moscow (only?) without flips and rotations!
        

        if (('Moscow' not in img_id) and ('Mumbai' not in img_id) and (random.random() > 0.8)) or (random.random() > 0.9):
            img = img[::-1, ...]
            img2 = img2[::-1, ...]
            img3 = img3[::-1, ...]
            msk0 = msk0[::-1, ...]
            msk1 = msk1[::-1, ...]
            msk2 = msk2[::-1, ...]
            msk3 = msk3[::-1, ...]
            msk4 = msk4[::-1, ...]

        if (('Moscow' not in img_id) and ('Mumbai' not in img_id) and (random.random() > 0.8)) or (random.random() > 0.9):
            rot = random.randrange(4)
            if rot > 0:
                img = np.rot90(img, k=rot)
                img2 = np.rot90(img2, k=rot)
                img3 = np.rot90(img3, k=rot)
                msk0 = np.rot90(msk0, k=rot)
                msk1 = np.rot90(msk1, k=rot)
                msk2 = np.rot90(msk2, k=rot)
                msk3 = np.rot90(msk3, k=rot)
                msk4 = np.rot90(msk4, k=rot)
                    
        if random.random() > 0.95:
            shift_pnt = (random.randint(-320, 320), random.randint(-320, 320))
            img = shift_image(img, shift_pnt)
            img2 = shift_image(img2, shift_pnt)
            img3 = shift_image(img3, shift_pnt)
            msk0 = shift_image(msk0, shift_pnt)
            msk1 = shift_image(msk1, shift_pnt)
            msk2 = shift_image(msk2, shift_pnt)
            msk3 = shift_image(msk3, shift_pnt)
            msk4 = shift_image(msk4, shift_pnt)
            
        if random.random() > 0.95:
            rot_pnt =  (img.shape[0] // 2 + random.randint(-320, 320), img.shape[1] // 2 + random.randint(-320, 320))
            scale = 0.9 + random.random() * 0.2
            angle = random.randint(0, 20) - 10
            if (angle != 0) or (scale != 1):
                img = rotate_image(img, angle, scale, rot_pnt)
                img2 = rotate_image(img2, angle, scale, rot_pnt)
                img3 = rotate_image(img3, angle, scale, rot_pnt)
                msk0 = rotate_image(msk0, angle, scale, rot_pnt)
                msk1 = rotate_image(msk1, angle, scale, rot_pnt)
                msk2 = rotate_image(msk2, angle, scale, rot_pnt)
                msk3 = rotate_image(msk3, angle, scale, rot_pnt)
                msk4 = rotate_image(msk4, angle, scale, rot_pnt)

        crop_size = input_shape[0]
        if random.random() > 0.95:
            crop_size = random.randint(int(input_shape[0] / 1.1), int(input_shape[0] / 0.9))

        x0 = random.randint(0, img.shape[1] - crop_size)
        y0 = random.randint(0, img.shape[0] - crop_size)

        img = img[y0:y0+crop_size, x0:x0+crop_size, :]
        img2 = img2[y0:y0+crop_size, x0:x0+crop_size, :]
        img3 = img3[y0:y0+crop_size, x0:x0+crop_size, :]
        msk0 = msk0[y0:y0+crop_size, x0:x0+crop_size, :]
        msk1 = msk1[y0:y0+crop_size, x0:x0+crop_size, :]
        msk2 = msk2[y0:y0+crop_size, x0:x0+crop_size, :]
        msk3 = msk3[y0:y0+crop_size, x0:x0+crop_size, :]
        msk4 = msk4[y0:y0+crop_size, x0:x0+crop_size]
        

        if crop_size != input_shape[0]:
            img = cv2.resize(img, input_shape, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, input_shape, interpolation=cv2.INTER_LINEAR)
            img3 = cv2.resize(img3, input_shape, interpolation=cv2.INTER_LINEAR)
            msk0 = cv2.resize(msk0, input_shape, interpolation=cv2.INTER_LINEAR)
            msk1 = cv2.resize(msk1, input_shape, interpolation=cv2.INTER_LINEAR)
            msk2 = cv2.resize(msk2, input_shape, interpolation=cv2.INTER_LINEAR)
            msk3 = cv2.resize(msk3, input_shape, interpolation=cv2.INTER_LINEAR)
            msk4 = cv2.resize(msk4, input_shape, interpolation=cv2.INTER_LINEAR)
            
        if random.random() > 0.97:
            img = shift_channels(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))
        if random.random() > 0.97:
            img2 = shift_channels(img2, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))
        if random.random() > 0.97:
            img3 = shift_channels(img3, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))

        if random.random() > 0.97:
            img = change_hsv(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))

        if random.random() > 0.97:
            if random.random() > 0.95:
                img = clahe(img)
            elif random.random() > 0.95:
                img = gauss_noise(img)
            elif random.random() > 0.95:
                img = cv2.blur(img, (3, 3))
        elif random.random() > 0.97:
            if random.random() > 0.95:
                img = saturation(img, 0.9 + random.random() * 0.2)
            elif random.random() > 0.95:
                img = brightness(img, 0.9 + random.random() * 0.2)
            elif random.random() > 0.95:
                img = contrast(img, 0.9 + random.random() * 0.2)

        if random.random() > 0.98:
            el_det = self.elastic.to_deterministic()
            img = el_det.augment_image(img)

        msk = (msk0[..., :2] > 127) * 1
        bkg_msk = (np.ones_like(msk[..., :1]) - msk[..., :1]) * 255
        msk_speed = np.concatenate([msk1, msk2, msk3, bkg_msk], axis=2)
        msk_speed = (msk_speed > 127) * 1
        for i in range(9):
            for j in range(i + 1, 10):
                msk_speed[msk_speed[..., 9-i] > 0, 9-j] = 0
        lbl_speed = msk_speed.argmax(axis=2)

        msk_speed_cont = msk4 / 255
        msk_speed_cont = msk_speed_cont[..., np.newaxis]

        img = np.concatenate([img, img2, img3], axis=2)
        img = preprocess_inputs(img)

        img = torch.from_numpy(img.transpose((2, 0, 1))).float()
        msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()
        msk_speed = torch.from_numpy(msk_speed.transpose((2, 0, 1))).long()
        msk_speed_cont = torch.from_numpy(msk_speed_cont.transpose((2, 0, 1))).float()
        lbl_speed = torch.from_numpy(lbl_speed.copy()).long()

        sample = {'img': img, 'msk': msk, 'msk_speed': msk_speed, 'lbl_speed': lbl_speed, 'msk_speed_cont': msk_speed_cont, 'img_id': img_id}
        return sample


    
class ValData(Dataset):
    def __init__(self, image_idxs):
        super().__init__()
        self.image_idxs = image_idxs

    def __len__(self):
        return len(self.image_idxs)

    def __getitem__(self, idx):
        _idx = self.image_idxs[idx]

        img_id = train_files[_idx]

        img = cv2.imread(path.join(train_png, img_id), cv2.IMREAD_COLOR)
        img2 = cv2.imread(path.join(train_png2, img_id), cv2.IMREAD_COLOR)
        img3 = cv2.imread(path.join(train_png3, img_id), cv2.IMREAD_COLOR)

        msk0 = cv2.imread(path.join(masks_dir, img_id), cv2.IMREAD_COLOR)
        msk1 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed0.png')), cv2.IMREAD_COLOR)
        msk2 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed1.png')), cv2.IMREAD_COLOR)
        msk3 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed2.png')), cv2.IMREAD_COLOR)
        msk4 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed_cont.png')), cv2.IMREAD_UNCHANGED)
        img = np.pad(img, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        img2 = np.pad(img2, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        img3 = np.pad(img3, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk0 = np.pad(msk0, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk1 = np.pad(msk1, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk2 = np.pad(msk2, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk3 = np.pad(msk3, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk4 = np.pad(msk4, ((6, 6), (6, 6)), mode='reflect')

        msk = (msk0[..., :2] > 127) * 1
        bkg_msk = (np.ones_like(msk[..., :1]) - msk[..., :1]) * 255
        msk_speed = np.concatenate([msk1, msk2, msk3, bkg_msk], axis=2)
        msk_speed = (msk_speed > 127) * 1
        for i in range(9):
            for j in range(i + 1, 10):
                msk_speed[msk_speed[..., 9-i] > 0, 9-j] = 0
        lbl_speed = msk_speed.argmax(axis=2)

        msk_speed_cont = msk4 / 255
        msk_speed_cont = msk_speed_cont[..., np.newaxis]

        img = np.concatenate([img, img2, img3], axis=2)
        img = preprocess_inputs(img)

        img = torch.from_numpy(img.transpose((2, 0, 1))).float()
        msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()
        msk_speed = torch.from_numpy(msk_speed.transpose((2, 0, 1))).long()
        msk_speed_cont = torch.from_numpy(msk_speed_cont.transpose((2, 0, 1))).float()
        lbl_speed = torch.from_numpy(lbl_speed.copy()).long()

        sample = {'img': img, 'msk': msk, 'msk_speed': msk_speed, 'lbl_speed': lbl_speed, 'msk_speed_cont': msk_speed_cont, 'img_id': img_id}
        return sample




def validate(net, data_loader):
    dices0 = []
    dices1 = []

    with torch.no_grad():
        for i, sample in enumerate(tqdm(data_loader)):
            msks = sample["msk"].numpy()
            imgs = sample["img"].cuda(non_blocking=True)
            img_ids =  sample["img_id"]
            
            out = model(imgs)
            
            msk_pred = torch.sigmoid(out[:, :2, ...]).cpu().numpy()
            speed_cont_pred = out[:, 2, ...].cpu().numpy()
            speed_cont_pred[speed_cont_pred < 0] = 0
            speed_cont_pred[speed_cont_pred > 1] = 1
            msk_speed_pred = torch.softmax(out[:, 3:, ...], dim=1).cpu().numpy()

            pred = msk_pred > 0.5
            for j in range(msks.shape[0]):
                dices0.append(dice(msks[j, 0], pred[j, 0]))
                dices1.append(dice(msks[j, 1], pred[j, 1]))
                
                # pred_img0 = np.concatenate([msk_pred[j, 0, ..., np.newaxis], msk_pred[j, 1, ..., np.newaxis], np.zeros_like(msk_pred[j, 0, ..., np.newaxis])], axis=2)
                # cv2.imwrite(path.join(val_output_folder,  img_ids[j]), (pred_img0 * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])

                # cv2.imwrite(path.join(val_output_folder,  img_ids[j].replace('.png', '_speed0.png')), (msk_speed_pred[j, :3].transpose(1, 2, 0) * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])
                # cv2.imwrite(path.join(val_output_folder,  img_ids[j].replace('.png', '_speed1.png')), (msk_speed_pred[j, 3:6].transpose(1, 2, 0) * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])
                # cv2.imwrite(path.join(val_output_folder,  img_ids[j].replace('.png', '_speed2.png')), (msk_speed_pred[j, 6:9].transpose(1, 2, 0) * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])
                # cv2.imwrite(path.join(val_output_folder,  img_ids[j].replace('.png', '_speed_cont.png')), (speed_cont_pred[j] * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])

    d0 = np.mean(dices0)
    d1 = np.mean(dices1)

    print("Val Dice: {}, {}".format(d0, d1))
    return d0



def evaluate_val(data_val, best_score, model, snapshot_name, current_epoch):
    model = model.eval()
    d = validate(model, data_loader=data_val)

    if d > best_score:
        torch.save({
            'epoch': current_epoch + 1,
            'state_dict': model.state_dict(),
            'best_score': d,
        }, path.join(models_folder, snapshot_name + '_best'))
        best_score = d

    print("dice: {}\tdice_best: {}".format(d, best_score))
    return best_score



def train_epoch(current_epoch, seg_loss, ce_loss, mse_loss, model, optimizer, scheduler, train_data_loader):
    losses = AverageMeter()
    losses1 = AverageMeter()
    losses2 = AverageMeter()
    losses3 = AverageMeter()
    losses4 = AverageMeter()

    dices = AverageMeter()

    iterator = tqdm(train_data_loader)
    model.train()
    scheduler.step(current_epoch)
    for i, sample in enumerate(iterator):
        imgs = sample["img"].cuda(non_blocking=True)
        msks = sample["msk"].cuda(non_blocking=True)
        msks_speed = sample["msk_speed"].cuda(non_blocking=True)
        lbls_speed = sample["lbl_speed"].cuda(non_blocking=True)
        msks_speed_cont = sample["msk_speed_cont"].cuda(non_blocking=True)
        

        out = model(imgs)

        loss1 = seg_loss(out[:, 0, ...], msks[:, 0, ...])
        loss2 = seg_loss(out[:, 1, ...], msks[:, 1, ...])

        loss3 = ce_loss(out[:, 3:, ...], lbls_speed)

        loss4 = mse_loss(out[:, 2:3, ...], msks_speed_cont)

        loss = 1.5 * loss1 + 0.05 * loss2 + 0.2 * loss3 + 0.1 * loss4

        for _i in range(3, 13):
            loss += 0.03 * seg_loss(out[:, _i, ...], msks_speed[:, _i-3, ...])

        with torch.no_grad():
            _probs = torch.sigmoid(out[:, 0, ...])
            dice_sc = 1 - dice_round(_probs, msks[:, 0, ...])

        losses.update(loss.item(), imgs.size(0))
        losses1.update(loss1.item(), imgs.size(0))
        losses2.update(loss2.item(), imgs.size(0))
        losses3.update(loss3.item(), imgs.size(0))
        losses4.update(loss4.item(), imgs.size(0))

        dices.update(dice_sc, imgs.size(0))

        iterator.set_description(
            "epoch: {}; lr {:.7f}; Loss {loss.val:.4f} ({loss.avg:.4f}); Loss1 {loss1.val:.4f} ({loss1.avg:.4f}); Loss2 {loss2.val:.4f} ({loss2.avg:.4f}); Loss3 {loss3.val:.4f} ({loss3.avg:.4f}); Loss4 {loss4.val:.4f} ({loss4.avg:.4f}); Dice {dice.val:.4f} ({dice.avg:.4f})".format(
                current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, loss2=losses2, loss3=losses3, loss4=losses4, dice=dices))
        
        optimizer.zero_grad()
        loss.backward()
        # with amp.scale_loss(loss, optimizer) as scaled_loss:
        #     scaled_loss.backward()
        optimizer.step()

#         scheduler.step()

    print("epoch: {}; lr {:.7f}; Loss {loss.avg:.4f}; Loss1 {loss1.avg:.4f}; Loss2 {loss2.avg:.4f}; Loss3 {loss3.avg:.4f}; Loss4 {loss4.avg:.4f}; Dice {dice.avg:.4f}".format(
                current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, loss2=losses2, loss3=losses3, loss4=losses4, dice=dices))


In [None]:
train_idxs0, test_idxs = train_test_split(np.arange(len(train_files)), test_size=0.1, random_state=8)



In [None]:
test_idxs

In [None]:
test_idxs

In [None]:
# if __name__ == '__main__':
def train(seed):
    t0 = timeit.default_timer()

    makedirs(models_folder, exist_ok=True)
    # makedirs(val_output_folder, exist_ok=True)

    # seed = int(7)
    vis_dev = '0,1,2,3'

    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ["CUDA_VISIBLE_DEVICES"] = vis_dev

    cudnn.benchmark = True

    batch_size = 6
    val_batch_size = 4

    snapshot_name = 'res50_9ch_full_{}_0'.format(seed)
    
    train_idxs0, test_idxs = train_test_split(np.arange(len(train_files)), test_size=0.1, random_state=8)
    train_idxs0, val_idxs = train_test_split(np.arange(len(train_idxs0)), test_size=0.1, random_state=seed)


    np.random.seed(seed)
    random.seed(seed)

    train_idxs = []
    for i in train_idxs0:
        train_idxs.append(i)
        if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
            train_idxs.append(i)
        if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
            train_idxs.append(i)
        if (('Mumbai' in train_files[i]) or ('Moscow' in train_files[i])) and random.random() > 0.7:
            train_idxs.append(i)
    train_idxs = np.asarray(train_idxs)


    steps_per_epoch = len(train_idxs) // batch_size
    validation_steps = len(val_idxs) // val_batch_size

    print('steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps)

    data_train = TrainData(train_idxs)
    val_train = ValData(val_idxs)

    train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=10, shuffle=True, pin_memory=True, drop_last=True)
    val_data_loader = DataLoader(val_train, batch_size=val_batch_size, num_workers=10, shuffle=False, pin_memory=False)

    model = Res50_9ch_Unet() #.cuda()

    params = model.parameters()

    optimizer = AdamW(params, lr=0.0004, weight_decay=1e-4)

    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[4, 10, 16, 24, 28, 32], gamma=0.5)

    model = nn.DataParallel(model).cuda()


    seg_loss = ComboLoss({'dice': 1.0, 'focal': 3.0}, per_image=True).cuda()
    ce_loss = nn.CrossEntropyLoss().cuda()
    mse_loss = nn.MSELoss().cuda()

    best_score = 0
    _cnt = -1
    for epoch in range(34):
        train_epoch(epoch, seg_loss, ce_loss, mse_loss, model, optimizer, scheduler, train_data_loader)
        if epoch % 2 == 0:
            _cnt += 1
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_score': best_score,
            }, path.join(models_folder, snapshot_name + '_{}'.format(_cnt % 3)))
            torch.cuda.empty_cache()
            best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch)

    elapsed = timeit.default_timer() - t0
    print('Time: {:.3f} min'.format(elapsed / 60))


In [None]:
x = torch.tensor([0, 1, 2, 3, 4])
torch.save(x, 'tensor.pt')
# >>> # Save to io.BytesIO buffer
# >>> buffer = io.BytesIO()
# >>> torch.save(x, buffer)


## SeResnet50+Unet

In [4]:
import os
os.environ["MKL_NUM_THREADS"] = "8" 
os.environ["NUMEXPR_NUM_THREADS"] = "8" 
os.environ["OMP_NUM_THREADS"] = "8" 

from os import path, makedirs, listdir
import sys
import numpy as np
np.random.seed(1)
import random
random.seed(1)

import torch
from torch import nn
from torch.backends import cudnn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.optim.lr_scheduler as lr_scheduler

# from apex import amp

from adamw import AdamW
from losses import dice_round, ComboLoss

import pandas as pd
from tqdm import tqdm
import timeit
import cv2

from zoo.models import SeResNext50_Unet_9ch

from imgaug import augmenters as iaa

from utils import *

from sklearn.model_selection import train_test_split

from sklearn.metrics import accuracy_score
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)

train_png = './wdata/train_png'
train_png2 = './wdata/train_png_5_3_0'
train_png3 = './wdata/train_png_pan_6_7'

masks_dir = './wdata/masks'

models_folder = './wdata/weights'

speed_bins = [15, 18.75, 20, 25, 30, 35, 45, 55, 65]

cities = [('AOI_7_Moscow', '/fs/scratch/PCON0003/osu10670/train/AOI_7_Moscow', 'train_AOI_7_Moscow_geojson_roads_speed_wkt_weighted_simp.csv'),
          ('AOI_8_Mumbai', '/fs/scratch/PCON0003/osu10670/train/AOI_8_Mumbai', 'train_AOI_8_Mumbai_geojson_roads_speed_wkt_weighted_simp.csv'),]

cities_idxs = {}
for i in range(len(cities)):
    cities_idxs[cities[i][0]] = i


input_shape = (640, 640) # (384, 384)

train_files = []
for f in listdir(train_png):
    if '.png' in f:
        train_files.append(f)


class TrainData(Dataset):
    def __init__(self, train_idxs):
        super().__init__()
        self.train_idxs = train_idxs
        self.elastic = iaa.ElasticTransformation(alpha=(0.25, 1.2), sigma=0.2)

    def __len__(self):
        return len(self.train_idxs)

    def __getitem__(self, idx):
        _idx = self.train_idxs[idx]

        img_id = train_files[_idx]

        img = cv2.imread(path.join(train_png, img_id), cv2.IMREAD_COLOR)
        img2 = cv2.imread(path.join(train_png2, img_id), cv2.IMREAD_COLOR)
        img3 = cv2.imread(path.join(train_png3, img_id), cv2.IMREAD_COLOR)

        msk0 = cv2.imread(path.join(masks_dir, img_id), cv2.IMREAD_COLOR)
        msk1 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed0.png')), cv2.IMREAD_COLOR)
        msk2 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed1.png')), cv2.IMREAD_COLOR)
        msk3 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed2.png')), cv2.IMREAD_COLOR)
        msk4 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed_cont.png')), cv2.IMREAD_UNCHANGED)

        if (('Moscow' not in img_id) and ('Mumbai' not in img_id) and (random.random() > 0.5)) or (random.random() > 0.75):
            img = img[::-1, ...]
            img2 = img2[::-1, ...]
            img3 = img3[::-1, ...]
            msk0 = msk0[::-1, ...]
            msk1 = msk1[::-1, ...]
            msk2 = msk2[::-1, ...]
            msk3 = msk3[::-1, ...]
            msk4 = msk4[::-1, ...]

        if (('Moscow' not in img_id) and ('Mumbai' not in img_id) and (random.random() > 0.5)) or (random.random() > 0.75):
            rot = random.randrange(4)
            if rot > 0:
                img = np.rot90(img, k=rot)
                img2 = np.rot90(img2, k=rot)
                img3 = np.rot90(img3, k=rot)
                msk0 = np.rot90(msk0, k=rot)
                msk1 = np.rot90(msk1, k=rot)
                msk2 = np.rot90(msk2, k=rot)
                msk3 = np.rot90(msk3, k=rot)
                msk4 = np.rot90(msk4, k=rot)
                    
        if random.random() > 0.9:
            shift_pnt = (random.randint(-320, 320), random.randint(-320, 320))
            img = shift_image(img, shift_pnt)
            img2 = shift_image(img2, shift_pnt)
            img3 = shift_image(img3, shift_pnt)
            msk0 = shift_image(msk0, shift_pnt)
            msk1 = shift_image(msk1, shift_pnt)
            msk2 = shift_image(msk2, shift_pnt)
            msk3 = shift_image(msk3, shift_pnt)
            msk4 = shift_image(msk4, shift_pnt)
            
        if random.random() > 0.9:
            rot_pnt =  (img.shape[0] // 2 + random.randint(-320, 320), img.shape[1] // 2 + random.randint(-320, 320))
            scale = 0.9 + random.random() * 0.2
            angle = random.randint(0, 20) - 10
            if (angle != 0) or (scale != 1):
                img = rotate_image(img, angle, scale, rot_pnt)
                img2 = rotate_image(img2, angle, scale, rot_pnt)
                img3 = rotate_image(img3, angle, scale, rot_pnt)
                msk0 = rotate_image(msk0, angle, scale, rot_pnt)
                msk1 = rotate_image(msk1, angle, scale, rot_pnt)
                msk2 = rotate_image(msk2, angle, scale, rot_pnt)
                msk3 = rotate_image(msk3, angle, scale, rot_pnt)
                msk4 = rotate_image(msk4, angle, scale, rot_pnt)

        crop_size = input_shape[0]
        if random.random() > 0.9:
            crop_size = random.randint(int(input_shape[0] / 1.1), int(input_shape[0] / 0.9))

        x0 = random.randint(0, img.shape[1] - crop_size)
        y0 = random.randint(0, img.shape[0] - crop_size)

        img = img[y0:y0+crop_size, x0:x0+crop_size, :]
        img2 = img2[y0:y0+crop_size, x0:x0+crop_size, :]
        img3 = img3[y0:y0+crop_size, x0:x0+crop_size, :]
        msk0 = msk0[y0:y0+crop_size, x0:x0+crop_size, :]
        msk1 = msk1[y0:y0+crop_size, x0:x0+crop_size, :]
        msk2 = msk2[y0:y0+crop_size, x0:x0+crop_size, :]
        msk3 = msk3[y0:y0+crop_size, x0:x0+crop_size, :]
        msk4 = msk4[y0:y0+crop_size, x0:x0+crop_size]
        

        if crop_size != input_shape[0]:
            img = cv2.resize(img, input_shape, interpolation=cv2.INTER_LINEAR)
            img2 = cv2.resize(img2, input_shape, interpolation=cv2.INTER_LINEAR)
            img3 = cv2.resize(img3, input_shape, interpolation=cv2.INTER_LINEAR)
            msk0 = cv2.resize(msk0, input_shape, interpolation=cv2.INTER_LINEAR)
            msk1 = cv2.resize(msk1, input_shape, interpolation=cv2.INTER_LINEAR)
            msk2 = cv2.resize(msk2, input_shape, interpolation=cv2.INTER_LINEAR)
            msk3 = cv2.resize(msk3, input_shape, interpolation=cv2.INTER_LINEAR)
            msk4 = cv2.resize(msk4, input_shape, interpolation=cv2.INTER_LINEAR)
            
        if random.random() > 0.95:
            img = shift_channels(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))
        if random.random() > 0.95:
            img2 = shift_channels(img2, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))
        if random.random() > 0.95:
            img3 = shift_channels(img3, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))

        if random.random() > 0.95:
            img = change_hsv(img, random.randint(-5, 5), random.randint(-5, 5), random.randint(-5, 5))

        if random.random() > 0.95:
            if random.random() > 0.95:
                img = clahe(img)
            elif random.random() > 0.95:
                img = gauss_noise(img)
            elif random.random() > 0.95:
                img = cv2.blur(img, (3, 3))
        elif random.random() > 0.95:
            if random.random() > 0.95:
                img = saturation(img, 0.9 + random.random() * 0.2)
            elif random.random() > 0.95:
                img = brightness(img, 0.9 + random.random() * 0.2)
            elif random.random() > 0.95:
                img = contrast(img, 0.9 + random.random() * 0.2)

        if random.random() > 0.97:
            el_det = self.elastic.to_deterministic()
            img = el_det.augment_image(img)

        msk = (msk0[..., :2] > 127) * 1
        bkg_msk = (np.ones_like(msk[..., :1]) - msk[..., :1]) * 255
        msk_speed = np.concatenate([msk1, msk2, msk3, bkg_msk], axis=2)
        msk_speed = (msk_speed > 127) * 1
        for i in range(9):
            for j in range(i + 1, 10):
                msk_speed[msk_speed[..., 9-i] > 0, 9-j] = 0
        lbl_speed = msk_speed.argmax(axis=2)

        msk_speed_cont = msk4 / 255
        msk_speed_cont = msk_speed_cont[..., np.newaxis]

        img = np.concatenate([img, img2, img3], axis=2)
        img = preprocess_inputs(img)

        img = torch.from_numpy(img.transpose((2, 0, 1))).float()
        msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()
        msk_speed = torch.from_numpy(msk_speed.transpose((2, 0, 1))).long()
        msk_speed_cont = torch.from_numpy(msk_speed_cont.transpose((2, 0, 1))).float()
        lbl_speed = torch.from_numpy(lbl_speed.copy()).long()

        sample = {'img': img, 'msk': msk, 'msk_speed': msk_speed, 'lbl_speed': lbl_speed, 'msk_speed_cont': msk_speed_cont, 'img_id': img_id}
        return sample


    
class ValData(Dataset):
    def __init__(self, image_idxs):
        super().__init__()
        self.image_idxs = image_idxs

    def __len__(self):
        return len(self.image_idxs)

    def __getitem__(self, idx):
        _idx = self.image_idxs[idx]

        img_id = train_files[_idx]

        img = cv2.imread(path.join(train_png, img_id), cv2.IMREAD_COLOR)
        img2 = cv2.imread(path.join(train_png2, img_id), cv2.IMREAD_COLOR)
        img3 = cv2.imread(path.join(train_png3, img_id), cv2.IMREAD_COLOR)

        msk0 = cv2.imread(path.join(masks_dir, img_id), cv2.IMREAD_COLOR)
        msk1 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed0.png')), cv2.IMREAD_COLOR)
        msk2 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed1.png')), cv2.IMREAD_COLOR)
        msk3 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed2.png')), cv2.IMREAD_COLOR)
        msk4 = cv2.imread(path.join(masks_dir, img_id.replace('.png', '_speed_cont.png')), cv2.IMREAD_UNCHANGED)
        img = np.pad(img, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        img2 = np.pad(img2, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        img3 = np.pad(img3, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk0 = np.pad(msk0, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk1 = np.pad(msk1, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk2 = np.pad(msk2, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk3 = np.pad(msk3, ((6, 6), (6, 6), (0, 0)), mode='reflect')
        msk4 = np.pad(msk4, ((6, 6), (6, 6)), mode='reflect')

        msk = (msk0[..., :2] > 127) * 1
        bkg_msk = (np.ones_like(msk[..., :1]) - msk[..., :1]) * 255
        msk_speed = np.concatenate([msk1, msk2, msk3, bkg_msk], axis=2)
        msk_speed = (msk_speed > 127) * 1
        for i in range(9):
            for j in range(i + 1, 10):
                msk_speed[msk_speed[..., 9-i] > 0, 9-j] = 0
        lbl_speed = msk_speed.argmax(axis=2)

        msk_speed_cont = msk4 / 255
        msk_speed_cont = msk_speed_cont[..., np.newaxis]

        img = np.concatenate([img, img2, img3], axis=2)
        img = preprocess_inputs(img)

        img = torch.from_numpy(img.transpose((2, 0, 1))).float()
        msk = torch.from_numpy(msk.transpose((2, 0, 1))).long()
        msk_speed = torch.from_numpy(msk_speed.transpose((2, 0, 1))).long()
        msk_speed_cont = torch.from_numpy(msk_speed_cont.transpose((2, 0, 1))).float()
        lbl_speed = torch.from_numpy(lbl_speed.copy()).long()

        sample = {'img': img, 'msk': msk, 'msk_speed': msk_speed, 'lbl_speed': lbl_speed, 'msk_speed_cont': msk_speed_cont, 'img_id': img_id}
        return sample




def validate(net, data_loader):
    dices0 = []
    dices1 = []

    with torch.no_grad():
        for i, sample in enumerate(tqdm(data_loader)):
            msks = sample["msk"].numpy()
            imgs = sample["img"].cuda(non_blocking=True)
            img_ids =  sample["img_id"]
            
            out = model(imgs)
            
            msk_pred = torch.sigmoid(out[:, :2, ...]).cpu().numpy()
            speed_cont_pred = out[:, 2, ...].cpu().numpy()
            speed_cont_pred[speed_cont_pred < 0] = 0
            speed_cont_pred[speed_cont_pred > 1] = 1
            msk_speed_pred = torch.softmax(out[:, 3:, ...], dim=1).cpu().numpy()

            pred = msk_pred > 0.5
            for j in range(msks.shape[0]):
                dices0.append(dice(msks[j, 0], pred[j, 0]))
                dices1.append(dice(msks[j, 1], pred[j, 1]))
                
                # pred_img0 = np.concatenate([msk_pred[j, 0, ..., np.newaxis], msk_pred[j, 1, ..., np.newaxis], np.zeros_like(msk_pred[j, 0, ..., np.newaxis])], axis=2)
                # cv2.imwrite(path.join(val_output_folder,  img_ids[j]), (pred_img0 * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])

                # cv2.imwrite(path.join(val_output_folder,  img_ids[j].replace('.png', '_speed0.png')), (msk_speed_pred[j, :3].transpose(1, 2, 0) * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])
                # cv2.imwrite(path.join(val_output_folder,  img_ids[j].replace('.png', '_speed1.png')), (msk_speed_pred[j, 3:6].transpose(1, 2, 0) * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])
                # cv2.imwrite(path.join(val_output_folder,  img_ids[j].replace('.png', '_speed2.png')), (msk_speed_pred[j, 6:9].transpose(1, 2, 0) * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])
                # cv2.imwrite(path.join(val_output_folder,  img_ids[j].replace('.png', '_speed_cont.png')), (speed_cont_pred[j] * 255).astype('uint8'), [cv2.IMWRITE_PNG_COMPRESSION, 9])

    d0 = np.mean(dices0)
    d1 = np.mean(dices1)

    print("Val Dice: {}, {}".format(d0, d1))
    return d0



def evaluate_val(data_val, best_score, model, snapshot_name, current_epoch):
    model = model.eval()
    d = validate(model, data_loader=data_val)

    if d > best_score:
        torch.save({
            'epoch': current_epoch + 1,
            'state_dict': model.state_dict(),
            'best_score': d,
        }, path.join(models_folder, snapshot_name + '_best'))
        best_score = d

    print("dice: {}\tdice_best: {}".format(d, best_score))
    return best_score



def train_epoch(current_epoch, seg_loss, ce_loss, mse_loss, model, optimizer, scheduler, train_data_loader):
    losses = AverageMeter()
    losses1 = AverageMeter()
    losses2 = AverageMeter()
    losses3 = AverageMeter()
    losses4 = AverageMeter()

    dices = AverageMeter()

    iterator = tqdm(train_data_loader)
    model.train()
    scheduler.step(current_epoch)
    for i, sample in enumerate(iterator):
        imgs = sample["img"].cuda(non_blocking=True)
        msks = sample["msk"].cuda(non_blocking=True)
        msks_speed = sample["msk_speed"].cuda(non_blocking=True)
        lbls_speed = sample["lbl_speed"].cuda(non_blocking=True)
        msks_speed_cont = sample["msk_speed_cont"].cuda(non_blocking=True)
        
        out = model(imgs)

        loss1 = seg_loss(out[:, 0, ...], msks[:, 0, ...])
        loss2 = seg_loss(out[:, 1, ...], msks[:, 1, ...])

        loss3 = ce_loss(out[:, 3:, ...], lbls_speed)

        loss4 = mse_loss(out[:, 2:3, ...], msks_speed_cont)

        loss = 1.2 * loss1 + 0.05 * loss2 + 0.2 * loss3 + 0.1 * loss4

        for _i in range(3, 13):
            loss += 0.03 * seg_loss(out[:, _i, ...], msks_speed[:, _i-3, ...])

        with torch.no_grad():
            _probs = torch.sigmoid(out[:, 0, ...])
            dice_sc = 1 - dice_round(_probs, msks[:, 0, ...])

        losses.update(loss.item(), imgs.size(0))
        losses1.update(loss1.item(), imgs.size(0))
        losses2.update(loss2.item(), imgs.size(0))
        losses3.update(loss3.item(), imgs.size(0))
        losses4.update(loss4.item(), imgs.size(0))

        dices.update(dice_sc, imgs.size(0))

        iterator.set_description(
            "epoch: {}; lr {:.7f}; Loss {loss.val:.4f} ({loss.avg:.4f}); Loss1 {loss1.val:.4f} ({loss1.avg:.4f}); Loss2 {loss2.val:.4f} ({loss2.avg:.4f}); Loss3 {loss3.val:.4f} ({loss3.avg:.4f}); Loss4 {loss4.val:.4f} ({loss4.avg:.4f}); Dice {dice.val:.4f} ({dice.avg:.4f})".format(
                current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, loss2=losses2, loss3=losses3, loss4=losses4, dice=dices))
        
        optimizer.zero_grad()
        loss.backward()
        # with amp.scale_loss(loss, optimizer) as scaled_loss:
        #     scaled_loss.backward()
        optimizer.step()

#         scheduler.step()

    print("epoch: {}; lr {:.7f}; Loss {loss.avg:.4f}; Loss1 {loss1.avg:.4f}; Loss2 {loss2.avg:.4f}; Loss3 {loss3.avg:.4f}; Loss4 {loss4.avg:.4f}; Dice {dice.avg:.4f}".format(
                current_epoch, scheduler.get_lr()[-1], loss=losses, loss1=losses1, loss2=losses2, loss3=losses3, loss4=losses4, dice=dices))






In [2]:
t0 = timeit.default_timer()

makedirs(models_folder, exist_ok=True)
# makedirs(val_output_folder, exist_ok=Tru
seed = int(601)
vis_dev = '0,1,2,3'

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = vis_dev

cudnn.benchmark = True

batch_size = 6
val_batch_size = 4

snapshot_name = 'seres50_9ch_{}_0'.format(seed)

train_idxs0, test_idxs = train_test_split(np.arange(len(train_files)), test_size=0.1, random_state=seed)
train_idxs0, val_idxs = train_test_split(np.arange(len(train_idxs0)), test_size=0.1, random_state=seed)


np.random.seed(seed)
random.seed(seed)

train_idxs = []
for i in train_idxs0:
    train_idxs.append(i)
    if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
        train_idxs.append(i)
    if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
        train_idxs.append(i)
    if (('Mumbai' in train_files[i]) or ('Moscow' in train_files[i])) and random.random() > 0.7:
        train_idxs.append(i)
train_idxs = np.asarray(train_idxs)


steps_per_epoch = len(train_idxs) // batch_size
validation_steps = len(val_idxs) // val_batch_size

print('steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps)

data_train = TrainData(train_idxs)
val_train = ValData(val_idxs)

train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=10, shuffle=True, pin_memory=True, drop_last=True)
val_data_loader = DataLoader(val_train, batch_size=val_batch_size, num_workers=10, shuffle=False, pin_memory=False)

model = SeResNext50_Unet_9ch() #.cuda()

params = model.parameters()

optimizer = AdamW(params, lr=0.0004, weight_decay=1e-4) 

scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[6, 12, 18, 24, 26], gamma=0.5)

model = nn.DataParallel(model).cuda()


seg_loss = ComboLoss({'dice': 1.0, 'focal': 3.0}, per_image=True).cuda()
ce_loss = nn.CrossEntropyLoss().cuda()
mse_loss = nn.MSELoss().cuda()

best_score = 0
_cnt = -1
for epoch in range(20):
    train_epoch(epoch, seg_loss, ce_loss, mse_loss, model, optimizer, scheduler, train_data_loader)
    if epoch % 2 == 0:
        _cnt += 1
#             torch.save({
#                 'epoch': epoch + 1,
#                 'state_dict': model.state_dict(),
#                 'best_score': best_score,
#             }, path.join(models_folder, snapshot_name + '_{}'.format(_cnt % 3)))
        best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch)
        torch.cuda.empty_cache()

elapsed = timeit.default_timer() - t0
print('Time: {:.3f} min'.format(elapsed / 60))



steps_per_epoch 413 validation_steps 53


  cpuset_checked))
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:1050.)
  exp_avg.mul_(beta1).add_(1 - beta1, grad)
epoch: 0; lr 0.0004000; Loss 1.2436 (1.5331); Loss1 0.7559 (0.9427); Loss2 0.9769 (0.9956); Loss3 0.0930 (0.2648); Loss4 0.0474 (0.1653); Dice 0.3213 (0.2361): 100%|██████████| 413/413 [04:27<00:00,  1.54it/s] 


epoch: 0; lr 0.0004000; Loss 1.5331; Loss1 0.9427; Loss2 0.9956; Loss3 0.2648; Loss4 0.1653; Dice 0.2361


100%|██████████| 54/54 [01:01<00:00,  1.14s/it]


Val Dice: 0.5244245339401153, 0.20301426330421043
dice: 0.5244245339401153	dice_best: 0.5244245339401153


epoch: 1; lr 0.0004000; Loss 1.2858 (1.3113); Loss1 0.7929 (0.7943); Loss2 0.9625 (0.9631); Loss3 0.0998 (0.1997); Loss4 0.0501 (0.0366); Dice 0.3078 (0.3753): 100%|██████████| 413/413 [04:23<00:00,  1.57it/s]


epoch: 1; lr 0.0004000; Loss 1.3113; Loss1 0.7943; Loss2 0.9631; Loss3 0.1997; Loss4 0.0366; Dice 0.3753


epoch: 2; lr 0.0004000; Loss 1.4469 (1.2400); Loss1 0.9126 (0.7427); Loss2 0.9594 (0.9535); Loss3 0.1507 (0.1883); Loss4 0.0133 (0.0229); Dice 0.2534 (0.4242): 100%|██████████| 413/413 [04:23<00:00,  1.57it/s]


epoch: 2; lr 0.0004000; Loss 1.2400; Loss1 0.7427; Loss2 0.9535; Loss3 0.1883; Loss4 0.0229; Dice 0.4242


100%|██████████| 54/54 [01:01<00:00,  1.13s/it]


Val Dice: 0.593599020946935, 0.25397520647997707
dice: 0.593599020946935	dice_best: 0.593599020946935


epoch: 3; lr 0.0004000; Loss 1.0472 (1.2082); Loss1 0.6007 (0.7228); Loss2 0.9185 (0.9502); Loss3 0.1429 (0.1810); Loss4 0.0129 (0.0177); Dice 0.4930 (0.4424): 100%|██████████| 413/413 [04:23<00:00,  1.57it/s]


epoch: 3; lr 0.0004000; Loss 1.2082; Loss1 0.7228; Loss2 0.9502; Loss3 0.1810; Loss4 0.0177; Dice 0.4424


epoch: 4; lr 0.0004000; Loss 1.3344 (1.1764); Loss1 0.8195 (0.7001); Loss2 0.9752 (0.9443); Loss3 0.1806 (0.1765); Loss4 0.0141 (0.0150); Dice 0.3961 (0.4613): 100%|██████████| 413/413 [04:23<00:00,  1.57it/s]


epoch: 4; lr 0.0004000; Loss 1.1764; Loss1 0.7001; Loss2 0.9443; Loss3 0.1765; Loss4 0.0150; Dice 0.4613


100%|██████████| 54/54 [00:59<00:00,  1.09s/it]


Val Dice: 0.5911689670953252, 0.31241360324542383
dice: 0.5911689670953252	dice_best: 0.593599020946935


epoch: 5; lr 0.0004000; Loss 1.1118 (1.1588); Loss1 0.6646 (0.6909); Loss2 0.7479 (0.8738); Loss3 0.1276 (0.1728); Loss4 0.0083 (0.0131); Dice 0.4518 (0.4703): 100%|██████████| 413/413 [04:24<00:00,  1.56it/s]


epoch: 5; lr 0.0004000; Loss 1.1588; Loss1 0.6909; Loss2 0.8738; Loss3 0.1728; Loss4 0.0131; Dice 0.4703


epoch: 6; lr 0.0001000; Loss 1.0246 (1.0997); Loss1 0.5980 (0.6514); Loss2 0.7187 (0.7787); Loss3 0.1680 (0.1640); Loss4 0.0118 (0.0106); Dice 0.5369 (0.5011): 100%|██████████| 413/413 [04:23<00:00,  1.57it/s]


epoch: 6; lr 0.0001000; Loss 1.0997; Loss1 0.6514; Loss2 0.7787; Loss3 0.1640; Loss4 0.0106; Dice 0.5011


100%|██████████| 54/54 [01:01<00:00,  1.13s/it]


Val Dice: 0.6391810284479796, 0.42502934039770635
dice: 0.6391810284479796	dice_best: 0.6391810284479796


epoch: 7; lr 0.0002000; Loss 1.1188 (1.0840); Loss1 0.6607 (0.6416); Loss2 0.7451 (0.7557); Loss3 0.2122 (0.1585); Loss4 0.0130 (0.0099); Dice 0.5479 (0.5084): 100%|██████████| 413/413 [04:23<00:00,  1.57it/s]


epoch: 7; lr 0.0002000; Loss 1.0840; Loss1 0.6416; Loss2 0.7557; Loss3 0.1585; Loss4 0.0099; Dice 0.5084


epoch: 8; lr 0.0002000; Loss 1.1493 (1.0707); Loss1 0.6887 (0.6325); Loss2 0.8725 (0.7374); Loss3 0.1547 (0.1570); Loss4 0.0123 (0.0094); Dice 0.4466 (0.5171): 100%|██████████| 413/413 [04:22<00:00,  1.57it/s]


epoch: 8; lr 0.0002000; Loss 1.0707; Loss1 0.6325; Loss2 0.7374; Loss3 0.1570; Loss4 0.0094; Dice 0.5171


100%|██████████| 54/54 [00:58<00:00,  1.08s/it]


Val Dice: 0.6356971400452542, 0.4438779789547244
dice: 0.6356971400452542	dice_best: 0.6391810284479796


epoch: 9; lr 0.0002000; Loss 1.2463 (1.0573); Loss1 0.7745 (0.6228); Loss2 0.7724 (0.7271); Loss3 0.2170 (0.1573); Loss4 0.0093 (0.0092); Dice 0.5565 (0.5249): 100%|██████████| 413/413 [04:23<00:00,  1.56it/s]


epoch: 9; lr 0.0002000; Loss 1.0573; Loss1 0.6228; Loss2 0.7271; Loss3 0.1573; Loss4 0.0092; Dice 0.5249


epoch: 10; lr 0.0002000; Loss 1.1394 (1.0430); Loss1 0.6792 (0.6138); Loss2 0.8876 (0.7165); Loss3 0.1296 (0.1538); Loss4 0.0061 (0.0088); Dice 0.4668 (0.5328): 100%|██████████| 413/413 [04:23<00:00,  1.57it/s]


epoch: 10; lr 0.0002000; Loss 1.0430; Loss1 0.6138; Loss2 0.7165; Loss3 0.1538; Loss4 0.0088; Dice 0.5328


100%|██████████| 54/54 [00:59<00:00,  1.11s/it]


Val Dice: 0.6312822186420104, 0.4609079650381816
dice: 0.6312822186420104	dice_best: 0.6391810284479796


epoch: 11; lr 0.0002000; Loss 1.2127 (1.0354); Loss1 0.7451 (0.6093); Loss2 0.7763 (0.7069); Loss3 0.1294 (0.1511); Loss4 0.0092 (0.0089); Dice 0.4037 (0.5356): 100%|██████████| 413/413 [04:23<00:00,  1.57it/s]


epoch: 11; lr 0.0002000; Loss 1.0354; Loss1 0.6093; Loss2 0.7069; Loss3 0.1511; Loss4 0.0089; Dice 0.5356


epoch: 12; lr 0.0000500; Loss 1.2705 (1.0083); Loss1 0.7760 (0.5901); Loss2 0.7521 (0.6900); Loss3 0.3407 (0.1476); Loss4 0.0135 (0.0079); Dice 0.5964 (0.5501): 100%|██████████| 413/413 [04:23<00:00,  1.57it/s]


epoch: 12; lr 0.0000500; Loss 1.0083; Loss1 0.5901; Loss2 0.6900; Loss3 0.1476; Loss4 0.0079; Dice 0.5501


100%|██████████| 54/54 [00:59<00:00,  1.10s/it]


Val Dice: 0.6603041109045621, 0.47013189673987027
dice: 0.6603041109045621	dice_best: 0.6603041109045621


epoch: 13; lr 0.0001000; Loss 0.8568 (0.9886); Loss1 0.4970 (0.5762); Loss2 0.5865 (0.6819); Loss3 0.0603 (0.1447); Loss4 0.0054 (0.0076); Dice 0.5569 (0.5617): 100%|██████████| 413/413 [04:23<00:00,  1.57it/s]


epoch: 13; lr 0.0001000; Loss 0.9886; Loss1 0.5762; Loss2 0.6819; Loss3 0.1447; Loss4 0.0076; Dice 0.5617


epoch: 14; lr 0.0001000; Loss 0.8327 (0.9880); Loss1 0.4605 (0.5766); Loss2 0.7225 (0.6752); Loss3 0.0966 (0.1458); Loss4 0.0065 (0.0076); Dice 0.6170 (0.5640): 100%|██████████| 413/413 [04:23<00:00,  1.57it/s]


epoch: 14; lr 0.0001000; Loss 0.9880; Loss1 0.5766; Loss2 0.6752; Loss3 0.1458; Loss4 0.0076; Dice 0.5640


100%|██████████| 54/54 [00:59<00:00,  1.10s/it]


Val Dice: 0.655557930788707, 0.4709074733124863
dice: 0.655557930788707	dice_best: 0.6603041109045621


epoch: 15; lr 0.0001000; Loss 0.8010 (0.9833); Loss1 0.4310 (0.5731); Loss2 0.8401 (0.6777); Loss3 0.0987 (0.1449); Loss4 0.0059 (0.0074); Dice 0.6494 (0.5672): 100%|██████████| 413/413 [04:23<00:00,  1.57it/s]


epoch: 15; lr 0.0001000; Loss 0.9833; Loss1 0.5731; Loss2 0.6777; Loss3 0.1449; Loss4 0.0074; Dice 0.5672


epoch: 16; lr 0.0001000; Loss 0.9674 (0.9746); Loss1 0.5781 (0.5676); Loss2 0.7838 (0.6732); Loss3 0.0582 (0.1427); Loss4 0.0028 (0.0073); Dice 0.5006 (0.5700): 100%|██████████| 413/413 [04:23<00:00,  1.57it/s]


epoch: 16; lr 0.0001000; Loss 0.9746; Loss1 0.5676; Loss2 0.6732; Loss3 0.1427; Loss4 0.0073; Dice 0.5700


100%|██████████| 54/54 [00:58<00:00,  1.09s/it]


Val Dice: 0.63616476157735, 0.4700723428264238
dice: 0.63616476157735	dice_best: 0.6603041109045621


epoch: 17; lr 0.0001000; Loss 0.9390 (0.9680); Loss1 0.5285 (0.5632); Loss2 0.6551 (0.6654); Loss3 0.2092 (0.1430); Loss4 0.0107 (0.0071); Dice 0.6385 (0.5751): 100%|██████████| 413/413 [04:23<00:00,  1.57it/s]


epoch: 17; lr 0.0001000; Loss 0.9680; Loss1 0.5632; Loss2 0.6654; Loss3 0.1430; Loss4 0.0071; Dice 0.5751


epoch: 18; lr 0.0000250; Loss 0.9587 (0.9445); Loss1 0.5444 (0.5463); Loss2 0.5807 (0.6513); Loss3 0.1201 (0.1393); Loss4 0.0042 (0.0069); Dice 0.5750 (0.5875): 100%|██████████| 413/413 [04:22<00:00,  1.57it/s]


epoch: 18; lr 0.0000250; Loss 0.9445; Loss1 0.5463; Loss2 0.6513; Loss3 0.1393; Loss4 0.0069; Dice 0.5875


100%|██████████| 54/54 [01:00<00:00,  1.12s/it]


Val Dice: 0.6572303331792065, 0.47641752884362537
dice: 0.6572303331792065	dice_best: 0.6603041109045621


epoch: 19; lr 0.0000500; Loss 1.0280 (0.9489); Loss1 0.6204 (0.5501); Loss2 0.6197 (0.6543); Loss3 0.1711 (0.1398); Loss4 0.0073 (0.0069); Dice 0.5823 (0.5858): 100%|██████████| 413/413 [04:23<00:00,  1.57it/s]

epoch: 19; lr 0.0000500; Loss 0.9489; Loss1 0.5501; Loss2 0.6543; Loss3 0.1398; Loss4 0.0069; Dice 0.5858
Time: 97.943 min





In [3]:
t0 = timeit.default_timer()

makedirs(models_folder, exist_ok=True)
# makedirs(val_output_folder, exist_ok=Tru
seed = int(602)
vis_dev = '0,1,2,3'

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = vis_dev

cudnn.benchmark = True

batch_size = 6
val_batch_size = 4

snapshot_name = 'seres50_9ch_{}_0'.format(seed)

train_idxs0, test_idxs = train_test_split(np.arange(len(train_files)), test_size=0.1, random_state=seed)
train_idxs0, val_idxs = train_test_split(np.arange(len(train_idxs0)), test_size=0.1, random_state=seed)


np.random.seed(seed)
random.seed(seed)

train_idxs = []
for i in train_idxs0:
    train_idxs.append(i)
    if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
        train_idxs.append(i)
    if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
        train_idxs.append(i)
    if (('Mumbai' in train_files[i]) or ('Moscow' in train_files[i])) and random.random() > 0.7:
        train_idxs.append(i)
train_idxs = np.asarray(train_idxs)


steps_per_epoch = len(train_idxs) // batch_size
validation_steps = len(val_idxs) // val_batch_size

print('steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps)

data_train = TrainData(train_idxs)
val_train = ValData(val_idxs)

train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=10, shuffle=True, pin_memory=True, drop_last=True)
val_data_loader = DataLoader(val_train, batch_size=val_batch_size, num_workers=10, shuffle=False, pin_memory=False)

model = SeResNext50_Unet_9ch() #.cuda()

params = model.parameters()

optimizer = AdamW(params, lr=0.0004, weight_decay=1e-4) 

scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[6, 12, 18, 24, 26], gamma=0.5)

model = nn.DataParallel(model).cuda()


seg_loss = ComboLoss({'dice': 1.0, 'focal': 3.0}, per_image=True).cuda()
ce_loss = nn.CrossEntropyLoss().cuda()
mse_loss = nn.MSELoss().cuda()

best_score = 0
_cnt = -1
for epoch in range(20):
    train_epoch(epoch, seg_loss, ce_loss, mse_loss, model, optimizer, scheduler, train_data_loader)
    if epoch % 2 == 0:
        _cnt += 1
#             torch.save({
#                 'epoch': epoch + 1,
#                 'state_dict': model.state_dict(),
#                 'best_score': best_score,
#             }, path.join(models_folder, snapshot_name + '_{}'.format(_cnt % 3)))
        best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch)
        torch.cuda.empty_cache()

elapsed = timeit.default_timer() - t0
print('Time: {:.3f} min'.format(elapsed / 60))




steps_per_epoch 414 validation_steps 53


epoch: 0; lr 0.0004000; Loss 1.2703 (1.5175); Loss1 0.7657 (0.9287); Loss2 0.9833 (1.0649); Loss3 0.1762 (0.2679); Loss4 0.0234 (0.1305); Dice 0.3875 (0.2429): 100%|██████████| 414/414 [04:23<00:00,  1.57it/s]  


epoch: 0; lr 0.0004000; Loss 1.5175; Loss1 0.9287; Loss2 1.0649; Loss3 0.2679; Loss4 0.1305; Dice 0.2429


100%|██████████| 54/54 [01:00<00:00,  1.12s/it]


Val Dice: 0.5253211359637356, 0.20648777868111934
dice: 0.5253211359637356	dice_best: 0.5253211359637356


epoch: 1; lr 0.0004000; Loss 1.3388 (1.3090); Loss1 0.8061 (0.7908); Loss2 0.9560 (0.9650); Loss3 0.2556 (0.2078); Loss4 0.0200 (0.0272); Dice 0.3928 (0.3793): 100%|██████████| 414/414 [04:23<00:00,  1.57it/s]


epoch: 1; lr 0.0004000; Loss 1.3090; Loss1 0.7908; Loss2 0.9650; Loss3 0.2078; Loss4 0.0272; Dice 0.3793


epoch: 2; lr 0.0004000; Loss 1.2936 (1.2545); Loss1 0.7592 (0.7507); Loss2 0.9641 (0.9591); Loss3 0.3163 (0.1959); Loss4 0.0209 (0.0206); Dice 0.4813 (0.4140): 100%|██████████| 414/414 [04:23<00:00,  1.57it/s]


epoch: 2; lr 0.0004000; Loss 1.2545; Loss1 0.7507; Loss2 0.9591; Loss3 0.1959; Loss4 0.0206; Dice 0.4140


100%|██████████| 54/54 [00:58<00:00,  1.07s/it]


Val Dice: 0.5974775506608855, 0.1982115898569185
dice: 0.5974775506608855	dice_best: 0.5974775506608855


epoch: 3; lr 0.0004000; Loss 1.0908 (1.2072); Loss1 0.6320 (0.7167); Loss2 0.9540 (0.9496); Loss3 0.1434 (0.1863); Loss4 0.0129 (0.0173); Dice 0.4508 (0.4431): 100%|██████████| 414/414 [04:23<00:00,  1.57it/s]


epoch: 3; lr 0.0004000; Loss 1.2072; Loss1 0.7167; Loss2 0.9496; Loss3 0.1863; Loss4 0.0173; Dice 0.4431


epoch: 4; lr 0.0004000; Loss 1.1014 (1.1804); Loss1 0.6396 (0.7004); Loss2 0.9761 (0.9457); Loss3 0.1786 (0.1790); Loss4 0.0173 (0.0144); Dice 0.4709 (0.4581): 100%|██████████| 414/414 [04:24<00:00,  1.56it/s]


epoch: 4; lr 0.0004000; Loss 1.1804; Loss1 0.7004; Loss2 0.9457; Loss3 0.1790; Loss4 0.0144; Dice 0.4581


100%|██████████| 54/54 [01:00<00:00,  1.12s/it]


Val Dice: 0.5979837529883048, 0.28345852246964925
dice: 0.5979837529883048	dice_best: 0.5979837529883048


epoch: 5; lr 0.0004000; Loss 1.1828 (1.1516); Loss1 0.6739 (0.6824); Loss2 0.9114 (0.9064); Loss3 0.2838 (0.1744); Loss4 0.0790 (0.0174); Dice 0.4946 (0.4735): 100%|██████████| 414/414 [04:23<00:00,  1.57it/s]


epoch: 5; lr 0.0004000; Loss 1.1516; Loss1 0.6824; Loss2 0.9064; Loss3 0.1744; Loss4 0.0174; Dice 0.4735


epoch: 6; lr 0.0001000; Loss 0.9307 (1.0914); Loss1 0.5309 (0.6441); Loss2 0.6302 (0.7842); Loss3 0.1365 (0.1646); Loss4 0.0079 (0.0122); Dice 0.5921 (0.5064): 100%|██████████| 414/414 [04:23<00:00,  1.57it/s]


epoch: 6; lr 0.0001000; Loss 1.0914; Loss1 0.6441; Loss2 0.7842; Loss3 0.1646; Loss4 0.0122; Dice 0.5064


100%|██████████| 54/54 [01:00<00:00,  1.11s/it]


Val Dice: 0.6415818573066489, 0.428344056206719
dice: 0.6415818573066489	dice_best: 0.6415818573066489


epoch: 7; lr 0.0002000; Loss 0.8051 (1.0750); Loss1 0.4414 (0.6338); Loss2 0.4689 (0.7566); Loss3 0.0862 (0.1613); Loss4 0.0070 (0.0105); Dice 0.6318 (0.5163): 100%|██████████| 414/414 [04:24<00:00,  1.57it/s]


epoch: 7; lr 0.0002000; Loss 1.0750; Loss1 0.6338; Loss2 0.7566; Loss3 0.1613; Loss4 0.0105; Dice 0.5163


epoch: 8; lr 0.0002000; Loss 1.1771 (1.0617); Loss1 0.7096 (0.6258); Loss2 0.8546 (0.7343); Loss3 0.1579 (0.1581); Loss4 0.0110 (0.0103); Dice 0.4421 (0.5215): 100%|██████████| 414/414 [04:24<00:00,  1.57it/s]


epoch: 8; lr 0.0002000; Loss 1.0617; Loss1 0.6258; Loss2 0.7343; Loss3 0.1581; Loss4 0.0103; Dice 0.5215


100%|██████████| 54/54 [01:00<00:00,  1.13s/it]


Val Dice: 0.6458427887668826, 0.4573299073654081
dice: 0.6458427887668826	dice_best: 0.6458427887668826


epoch: 9; lr 0.0002000; Loss 1.1124 (1.0513); Loss1 0.6598 (0.6195); Loss2 0.8318 (0.7265); Loss3 0.1600 (0.1543); Loss4 0.0092 (0.0100); Dice 0.4821 (0.5269): 100%|██████████| 414/414 [04:23<00:00,  1.57it/s]


epoch: 9; lr 0.0002000; Loss 1.0513; Loss1 0.6195; Loss2 0.7265; Loss3 0.1543; Loss4 0.0100; Dice 0.5269


epoch: 10; lr 0.0002000; Loss 0.8988 (1.0372); Loss1 0.5115 (0.6095); Loss2 0.6850 (0.7223); Loss3 0.1255 (0.1533); Loss4 0.0094 (0.0098); Dice 0.5916 (0.5351): 100%|██████████| 414/414 [04:23<00:00,  1.57it/s]


epoch: 10; lr 0.0002000; Loss 1.0372; Loss1 0.6095; Loss2 0.7223; Loss3 0.1533; Loss4 0.0098; Dice 0.5351


100%|██████████| 54/54 [01:02<00:00,  1.15s/it]


Val Dice: 0.6528836346284987, 0.45924602883823396
dice: 0.6528836346284987	dice_best: 0.6528836346284987


epoch: 11; lr 0.0002000; Loss 0.8424 (1.0280); Loss1 0.4812 (0.6042); Loss2 0.6513 (0.7048); Loss3 0.0853 (0.1517); Loss4 0.0049 (0.0101); Dice 0.6356 (0.5416): 100%|██████████| 414/414 [04:23<00:00,  1.57it/s]


epoch: 11; lr 0.0002000; Loss 1.0280; Loss1 0.6042; Loss2 0.7048; Loss3 0.1517; Loss4 0.0101; Dice 0.5416


epoch: 12; lr 0.0000500; Loss 0.8754 (1.0020); Loss1 0.4932 (0.5854); Loss2 0.6445 (0.6934); Loss3 0.1086 (0.1487); Loss4 0.0062 (0.0091); Dice 0.6070 (0.5559): 100%|██████████| 414/414 [04:24<00:00,  1.57it/s]


epoch: 12; lr 0.0000500; Loss 1.0020; Loss1 0.5854; Loss2 0.6934; Loss3 0.1487; Loss4 0.0091; Dice 0.5559


100%|██████████| 54/54 [01:01<00:00,  1.15s/it]


Val Dice: 0.6622485999308911, 0.47465031818966097
dice: 0.6622485999308911	dice_best: 0.6622485999308911


epoch: 13; lr 0.0001000; Loss 1.1188 (0.9883); Loss1 0.6729 (0.5766); Loss2 0.7610 (0.6812); Loss3 0.1033 (0.1450); Loss4 0.0045 (0.0086); Dice 0.4435 (0.5623): 100%|██████████| 414/414 [04:24<00:00,  1.57it/s]


epoch: 13; lr 0.0001000; Loss 0.9883; Loss1 0.5766; Loss2 0.6812; Loss3 0.1450; Loss4 0.0086; Dice 0.5623


epoch: 14; lr 0.0001000; Loss 0.9455 (0.9743); Loss1 0.5408 (0.5658); Loss2 0.6927 (0.6805); Loss3 0.1538 (0.1428); Loss4 0.0100 (0.0085); Dice 0.5908 (0.5721): 100%|██████████| 414/414 [04:23<00:00,  1.57it/s]


epoch: 14; lr 0.0001000; Loss 0.9743; Loss1 0.5658; Loss2 0.6805; Loss3 0.1428; Loss4 0.0085; Dice 0.5721


100%|██████████| 54/54 [01:01<00:00,  1.14s/it]


Val Dice: 0.6567234311689959, 0.4755478865760513
dice: 0.6567234311689959	dice_best: 0.6622485999308911


epoch: 15; lr 0.0001000; Loss 1.1498 (0.9685); Loss1 0.6816 (0.5631); Loss2 0.6966 (0.6757); Loss3 0.1630 (0.1411); Loss4 0.0054 (0.0082); Dice 0.4828 (0.5733): 100%|██████████| 414/414 [04:23<00:00,  1.57it/s]


epoch: 15; lr 0.0001000; Loss 0.9685; Loss1 0.5631; Loss2 0.6757; Loss3 0.1411; Loss4 0.0082; Dice 0.5733


epoch: 16; lr 0.0001000; Loss 1.0350 (0.9596); Loss1 0.6064 (0.5570); Loss2 0.8458 (0.6650); Loss3 0.1581 (0.1397); Loss4 0.0070 (0.0080); Dice 0.5487 (0.5776): 100%|██████████| 414/414 [04:24<00:00,  1.56it/s]


epoch: 16; lr 0.0001000; Loss 0.9596; Loss1 0.5570; Loss2 0.6650; Loss3 0.1397; Loss4 0.0080; Dice 0.5776


100%|██████████| 54/54 [01:01<00:00,  1.13s/it]


Val Dice: 0.6563406344004566, 0.4855618185824465
dice: 0.6563406344004566	dice_best: 0.6622485999308911


epoch: 17; lr 0.0001000; Loss 0.8563 (0.9559); Loss1 0.4732 (0.5547); Loss2 0.6766 (0.6670); Loss3 0.1668 (0.1403); Loss4 0.0129 (0.0083); Dice 0.6370 (0.5823): 100%|██████████| 414/414 [04:23<00:00,  1.57it/s]


epoch: 17; lr 0.0001000; Loss 0.9559; Loss1 0.5547; Loss2 0.6670; Loss3 0.1403; Loss4 0.0083; Dice 0.5823


epoch: 18; lr 0.0000250; Loss 1.0611 (0.9437); Loss1 0.6289 (0.5472); Loss2 0.8278 (0.6518); Loss3 0.1203 (0.1364); Loss4 0.0050 (0.0075); Dice 0.5620 (0.5856): 100%|██████████| 414/414 [04:23<00:00,  1.57it/s]


epoch: 18; lr 0.0000250; Loss 0.9437; Loss1 0.5472; Loss2 0.6518; Loss3 0.1364; Loss4 0.0075; Dice 0.5856


100%|██████████| 54/54 [01:01<00:00,  1.13s/it]


Val Dice: 0.6661219221847593, 0.48767787087927067
dice: 0.6661219221847593	dice_best: 0.6661219221847593


epoch: 19; lr 0.0000500; Loss 0.9159 (0.9405); Loss1 0.5270 (0.5449); Loss2 0.5135 (0.6547); Loss3 0.1390 (0.1378); Loss4 0.0073 (0.0075); Dice 0.5870 (0.5889): 100%|██████████| 414/414 [04:23<00:00,  1.57it/s]

epoch: 19; lr 0.0000500; Loss 0.9405; Loss1 0.5449; Loss2 0.6547; Loss3 0.1378; Loss4 0.0075; Dice 0.5889
Time: 98.160 min





In [5]:
t0 = timeit.default_timer()

makedirs(models_folder, exist_ok=True)
# makedirs(val_output_folder, exist_ok=Tru
seed = int(603)
vis_dev = '0,1,2,3'

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = vis_dev

cudnn.benchmark = True

batch_size = 6
val_batch_size = 4

snapshot_name = 'seres50_9ch_{}_0'.format(seed)

train_idxs0, test_idxs = train_test_split(np.arange(len(train_files)), test_size=0.1, random_state=seed)
train_idxs0, val_idxs = train_test_split(np.arange(len(train_idxs0)), test_size=0.1, random_state=seed)


np.random.seed(seed)
random.seed(seed)

train_idxs = []
for i in train_idxs0:
    train_idxs.append(i)
    if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
        train_idxs.append(i)
    if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
        train_idxs.append(i)
    if (('Mumbai' in train_files[i]) or ('Moscow' in train_files[i])) and random.random() > 0.7:
        train_idxs.append(i)
train_idxs = np.asarray(train_idxs)


steps_per_epoch = len(train_idxs) // batch_size
validation_steps = len(val_idxs) // val_batch_size

print('steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps)

data_train = TrainData(train_idxs)
val_train = ValData(val_idxs)

train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=10, shuffle=True, pin_memory=True, drop_last=True)
val_data_loader = DataLoader(val_train, batch_size=val_batch_size, num_workers=10, shuffle=False, pin_memory=False)

model = SeResNext50_Unet_9ch() #.cuda()

params = model.parameters()

optimizer = AdamW(params, lr=0.0004, weight_decay=1e-4) 

scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[6, 12, 18, 24, 26], gamma=0.5)

model = nn.DataParallel(model).cuda()


seg_loss = ComboLoss({'dice': 1.0, 'focal': 3.0}, per_image=True).cuda()
ce_loss = nn.CrossEntropyLoss().cuda()
mse_loss = nn.MSELoss().cuda()

best_score = 0
_cnt = -1
for epoch in range(20):
    train_epoch(epoch, seg_loss, ce_loss, mse_loss, model, optimizer, scheduler, train_data_loader)
    if epoch % 2 == 0:
        _cnt += 1
#             torch.save({
#                 'epoch': epoch + 1,
#                 'state_dict': model.state_dict(),
#                 'best_score': best_score,
#             }, path.join(models_folder, snapshot_name + '_{}'.format(_cnt % 3)))
        best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch)
        torch.cuda.empty_cache()

elapsed = timeit.default_timer() - t0
print('Time: {:.3f} min'.format(elapsed / 60))




steps_per_epoch 408 validation_steps 53


  cpuset_checked))
	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ../torch/csrc/utils/python_arg_parser.cpp:1050.)
  exp_avg.mul_(beta1).add_(1 - beta1, grad)
epoch: 0; lr 0.0004000; Loss 1.5481 (1.6551); Loss1 0.9645 (0.9903); Loss2 0.9704 (0.9960); Loss3 0.2842 (0.3446); Loss4 0.0409 (0.3997); Dice 0.3147 (0.2243): 100%|██████████| 408/408 [04:25<00:00,  1.53it/s]    


epoch: 0; lr 0.0004000; Loss 1.6551; Loss1 0.9903; Loss2 0.9960; Loss3 0.3446; Loss4 0.3997; Dice 0.2243


100%|██████████| 54/54 [01:03<00:00,  1.18s/it]


Val Dice: 0.5221253225566205, 0.20120070611807095
dice: 0.5221253225566205	dice_best: 0.5221253225566205


epoch: 1; lr 0.0004000; Loss 1.2057 (1.3304); Loss1 0.6936 (0.8064); Loss2 0.9471 (0.9635); Loss3 0.2663 (0.2077); Loss4 0.0261 (0.0365); Dice 0.5176 (0.3624): 100%|██████████| 408/408 [04:19<00:00,  1.57it/s]


epoch: 1; lr 0.0004000; Loss 1.3304; Loss1 0.8064; Loss2 0.9635; Loss3 0.2077; Loss4 0.0365; Dice 0.3624


epoch: 2; lr 0.0004000; Loss 1.3013 (1.2677); Loss1 0.8006 (0.7609); Loss2 0.9675 (0.9547); Loss3 0.1411 (0.1976); Loss4 0.0266 (0.0299); Dice 0.3361 (0.4091): 100%|██████████| 408/408 [04:20<00:00,  1.57it/s]


epoch: 2; lr 0.0004000; Loss 1.2677; Loss1 0.7609; Loss2 0.9547; Loss3 0.1976; Loss4 0.0299; Dice 0.4091


100%|██████████| 54/54 [00:59<00:00,  1.10s/it]


Val Dice: 0.5801201248271162, 0.25110093212902784
dice: 0.5801201248271162	dice_best: 0.5801201248271162


epoch: 3; lr 0.0004000; Loss 1.1639 (1.2185); Loss1 0.6855 (0.7252); Loss2 0.9483 (0.9512); Loss3 0.1966 (0.1892); Loss4 0.0220 (0.0246); Dice 0.5160 (0.4399): 100%|██████████| 408/408 [04:19<00:00,  1.57it/s]


epoch: 3; lr 0.0004000; Loss 1.2185; Loss1 0.7252; Loss2 0.9512; Loss3 0.1892; Loss4 0.0246; Dice 0.4399


epoch: 4; lr 0.0004000; Loss 1.2112 (1.1901); Loss1 0.7114 (0.7080); Loss2 0.9349 (0.9489); Loss3 0.2478 (0.1788); Loss4 0.0256 (0.0207); Dice 0.4652 (0.4517): 100%|██████████| 408/408 [04:20<00:00,  1.57it/s]


epoch: 4; lr 0.0004000; Loss 1.1901; Loss1 0.7080; Loss2 0.9489; Loss3 0.1788; Loss4 0.0207; Dice 0.4517


100%|██████████| 54/54 [01:01<00:00,  1.13s/it]


Val Dice: 0.5925955779073669, 0.2632817955954498
dice: 0.5925955779073669	dice_best: 0.5925955779073669


epoch: 5; lr 0.0004000; Loss 1.0778 (1.1691); Loss1 0.6224 (0.6939); Loss2 0.9364 (0.9348); Loss3 0.1772 (0.1774); Loss4 0.0153 (0.0183); Dice 0.5125 (0.4657): 100%|██████████| 408/408 [04:20<00:00,  1.57it/s]


epoch: 5; lr 0.0004000; Loss 1.1691; Loss1 0.6939; Loss2 0.9348; Loss3 0.1774; Loss4 0.0183; Dice 0.4657


epoch: 6; lr 0.0001000; Loss 1.3135 (1.1051); Loss1 0.8229 (0.6508); Loss2 0.8315 (0.8598); Loss3 0.2162 (0.1659); Loss4 0.0130 (0.0143); Dice 0.4723 (0.5017): 100%|██████████| 408/408 [04:20<00:00,  1.57it/s]


epoch: 6; lr 0.0001000; Loss 1.1051; Loss1 0.6508; Loss2 0.8598; Loss3 0.1659; Loss4 0.0143; Dice 0.5017


100%|██████████| 54/54 [00:57<00:00,  1.07s/it]


Val Dice: 0.6183132673145547, 0.40697037375065415
dice: 0.6183132673145547	dice_best: 0.6183132673145547


epoch: 7; lr 0.0002000; Loss 0.9086 (1.0823); Loss1 0.4983 (0.6367); Loss2 0.9973 (0.7776); Loss3 0.0509 (0.1635); Loss4 0.0087 (0.0140); Dice 0.5313 (0.5138): 100%|██████████| 408/408 [04:20<00:00,  1.57it/s]


epoch: 7; lr 0.0002000; Loss 1.0823; Loss1 0.6367; Loss2 0.7776; Loss3 0.1635; Loss4 0.0140; Dice 0.5138


epoch: 8; lr 0.0002000; Loss 1.1322 (1.0660); Loss1 0.6801 (0.6265); Loss2 0.7167 (0.7554); Loss3 0.1534 (0.1604); Loss4 0.0105 (0.0132); Dice 0.4437 (0.5213): 100%|██████████| 408/408 [04:20<00:00,  1.57it/s]


epoch: 8; lr 0.0002000; Loss 1.0660; Loss1 0.6265; Loss2 0.7554; Loss3 0.1604; Loss4 0.0132; Dice 0.5213


100%|██████████| 54/54 [01:00<00:00,  1.13s/it]


Val Dice: 0.6137988471539872, 0.4185343805106934
dice: 0.6137988471539872	dice_best: 0.6183132673145547


epoch: 9; lr 0.0002000; Loss 0.9783 (1.0599); Loss1 0.5673 (0.6228); Loss2 0.7118 (0.7446); Loss3 0.1434 (0.1590); Loss4 0.0122 (0.0132); Dice 0.5488 (0.5259): 100%|██████████| 408/408 [04:19<00:00,  1.57it/s]


epoch: 9; lr 0.0002000; Loss 1.0599; Loss1 0.6228; Loss2 0.7446; Loss3 0.1590; Loss4 0.0132; Dice 0.5259


epoch: 10; lr 0.0002000; Loss 1.4154 (1.0534); Loss1 0.8992 (0.6185); Loss2 0.9209 (0.7286); Loss3 0.1930 (0.1605); Loss4 0.0105 (0.0119); Dice 0.3716 (0.5301): 100%|██████████| 408/408 [04:20<00:00,  1.57it/s]


epoch: 10; lr 0.0002000; Loss 1.0534; Loss1 0.6185; Loss2 0.7286; Loss3 0.1605; Loss4 0.0119; Dice 0.5301


100%|██████████| 54/54 [00:58<00:00,  1.09s/it]


Val Dice: 0.6315202148059101, 0.42673309280737
dice: 0.6315202148059101	dice_best: 0.6315202148059101


epoch: 11; lr 0.0002000; Loss 0.8935 (1.0422); Loss1 0.5007 (0.6114); Loss2 0.5339 (0.7224); Loss3 0.1620 (0.1577); Loss4 0.0129 (0.0112); Dice 0.6029 (0.5356): 100%|██████████| 408/408 [04:20<00:00,  1.57it/s]


epoch: 11; lr 0.0002000; Loss 1.0422; Loss1 0.6114; Loss2 0.7224; Loss3 0.1577; Loss4 0.0112; Dice 0.5356


epoch: 12; lr 0.0000500; Loss 1.0262 (1.0020); Loss1 0.5790 (0.5836); Loss2 0.7520 (0.6977); Loss3 0.1896 (0.1499); Loss4 0.0090 (0.0096); Dice 0.5649 (0.5564): 100%|██████████| 408/408 [04:19<00:00,  1.57it/s]


epoch: 12; lr 0.0000500; Loss 1.0020; Loss1 0.5836; Loss2 0.6977; Loss3 0.1499; Loss4 0.0096; Dice 0.5564


100%|██████████| 54/54 [01:00<00:00,  1.12s/it]


Val Dice: 0.6420164935390115, 0.44402841233054446
dice: 0.6420164935390115	dice_best: 0.6420164935390115


epoch: 13; lr 0.0001000; Loss 1.0424 (0.9825); Loss1 0.6138 (0.5698); Loss2 0.6587 (0.6862); Loss3 0.1797 (0.1465); Loss4 0.0143 (0.0091); Dice 0.5539 (0.5679): 100%|██████████| 408/408 [04:20<00:00,  1.57it/s]


epoch: 13; lr 0.0001000; Loss 0.9825; Loss1 0.5698; Loss2 0.6862; Loss3 0.1465; Loss4 0.0091; Dice 0.5679


epoch: 14; lr 0.0001000; Loss 0.9760 (0.9901); Loss1 0.5703 (0.5764); Loss2 0.8316 (0.6785); Loss3 0.1330 (0.1461); Loss4 0.0089 (0.0092); Dice 0.5634 (0.5630): 100%|██████████| 408/408 [04:21<00:00,  1.56it/s]


epoch: 14; lr 0.0001000; Loss 0.9901; Loss1 0.5764; Loss2 0.6785; Loss3 0.1461; Loss4 0.0092; Dice 0.5630


100%|██████████| 54/54 [01:00<00:00,  1.12s/it]


Val Dice: 0.6330359200726039, 0.4512476532393957
dice: 0.6330359200726039	dice_best: 0.6420164935390115


epoch: 15; lr 0.0001000; Loss 1.0755 (0.9734); Loss1 0.6388 (0.5644); Loss2 0.7123 (0.6758); Loss3 0.1836 (0.1435); Loss4 0.0094 (0.0089); Dice 0.5537 (0.5722): 100%|██████████| 408/408 [04:19<00:00,  1.57it/s]


epoch: 15; lr 0.0001000; Loss 0.9734; Loss1 0.5644; Loss2 0.6758; Loss3 0.1435; Loss4 0.0089; Dice 0.5722


epoch: 16; lr 0.0001000; Loss 1.1922 (0.9777); Loss1 0.7357 (0.5678); Loss2 0.7447 (0.6830); Loss3 0.1355 (0.1444); Loss4 0.0065 (0.0084); Dice 0.4753 (0.5697): 100%|██████████| 408/408 [04:20<00:00,  1.57it/s]


epoch: 16; lr 0.0001000; Loss 0.9777; Loss1 0.5678; Loss2 0.6830; Loss3 0.1444; Loss4 0.0084; Dice 0.5697


100%|██████████| 54/54 [00:58<00:00,  1.09s/it]


Val Dice: 0.6528475216624295, 0.4450857646962223
dice: 0.6528475216624295	dice_best: 0.6528475216624295


epoch: 17; lr 0.0001000; Loss 1.0938 (0.9588); Loss1 0.6472 (0.5542); Loss2 0.6568 (0.6665); Loss3 0.2275 (0.1433); Loss4 0.0126 (0.0085); Dice 0.5708 (0.5813): 100%|██████████| 408/408 [04:19<00:00,  1.57it/s]


epoch: 17; lr 0.0001000; Loss 0.9588; Loss1 0.5542; Loss2 0.6665; Loss3 0.1433; Loss4 0.0085; Dice 0.5813


epoch: 18; lr 0.0000250; Loss 0.8631 (0.9431); Loss1 0.4842 (0.5434); Loss2 0.6724 (0.6632); Loss3 0.1005 (0.1399); Loss4 0.0064 (0.0079); Dice 0.6135 (0.5885): 100%|██████████| 408/408 [04:20<00:00,  1.57it/s]


epoch: 18; lr 0.0000250; Loss 0.9431; Loss1 0.5434; Loss2 0.6632; Loss3 0.1399; Loss4 0.0079; Dice 0.5885


100%|██████████| 54/54 [01:00<00:00,  1.12s/it]


Val Dice: 0.6602857792150646, 0.4649975176958552
dice: 0.6602857792150646	dice_best: 0.6602857792150646


epoch: 19; lr 0.0000500; Loss 0.7700 (0.9401); Loss1 0.4158 (0.5422); Loss2 0.7822 (0.6555); Loss3 0.0347 (0.1377); Loss4 0.0028 (0.0078); Dice 0.6269 (0.5908): 100%|██████████| 408/408 [04:19<00:00,  1.57it/s]

epoch: 19; lr 0.0000500; Loss 0.9401; Loss1 0.5422; Loss2 0.6555; Loss3 0.1377; Loss4 0.0078; Dice 0.5908
Time: 97.027 min





In [6]:
t0 = timeit.default_timer()

makedirs(models_folder, exist_ok=True)
# makedirs(val_output_folder, exist_ok=Tru
seed = int(604)
vis_dev = '0,1,2,3'

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = vis_dev

cudnn.benchmark = True

batch_size = 6
val_batch_size = 4

snapshot_name = 'seres50_9ch_{}_0'.format(seed)

train_idxs0, test_idxs = train_test_split(np.arange(len(train_files)), test_size=0.1, random_state=seed)
train_idxs0, val_idxs = train_test_split(np.arange(len(train_idxs0)), test_size=0.1, random_state=seed)


np.random.seed(seed)
random.seed(seed)

train_idxs = []
for i in train_idxs0:
    train_idxs.append(i)
    if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
        train_idxs.append(i)
    if (('Paris' in train_files[i]) or ('Khartoum' in train_files[i])) and random.random() > 0.15:
        train_idxs.append(i)
    if (('Mumbai' in train_files[i]) or ('Moscow' in train_files[i])) and random.random() > 0.7:
        train_idxs.append(i)
train_idxs = np.asarray(train_idxs)


steps_per_epoch = len(train_idxs) // batch_size
validation_steps = len(val_idxs) // val_batch_size

print('steps_per_epoch', steps_per_epoch, 'validation_steps', validation_steps)

data_train = TrainData(train_idxs)
val_train = ValData(val_idxs)

train_data_loader = DataLoader(data_train, batch_size=batch_size, num_workers=10, shuffle=True, pin_memory=True, drop_last=True)
val_data_loader = DataLoader(val_train, batch_size=val_batch_size, num_workers=10, shuffle=False, pin_memory=False)

model = SeResNext50_Unet_9ch() #.cuda()

params = model.parameters()

optimizer = AdamW(params, lr=0.0004, weight_decay=1e-4) 

scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[6, 12, 18, 24, 26], gamma=0.5)

model = nn.DataParallel(model).cuda()


seg_loss = ComboLoss({'dice': 1.0, 'focal': 3.0}, per_image=True).cuda()
ce_loss = nn.CrossEntropyLoss().cuda()
mse_loss = nn.MSELoss().cuda()

best_score = 0
_cnt = -1
for epoch in range(20):
    train_epoch(epoch, seg_loss, ce_loss, mse_loss, model, optimizer, scheduler, train_data_loader)
    if epoch % 2 == 0:
        _cnt += 1
#             torch.save({
#                 'epoch': epoch + 1,
#                 'state_dict': model.state_dict(),
#                 'best_score': best_score,
#             }, path.join(models_folder, snapshot_name + '_{}'.format(_cnt % 3)))
        best_score = evaluate_val(val_data_loader, best_score, model, snapshot_name, epoch)
        torch.cuda.empty_cache()

elapsed = timeit.default_timer() - t0
print('Time: {:.3f} min'.format(elapsed / 60))




steps_per_epoch 416 validation_steps 53


epoch: 0; lr 0.0004000; Loss 1.2479 (1.7522); Loss1 0.7535 (1.0110); Loss2 0.9802 (1.0452); Loss3 0.1075 (0.3530); Loss4 0.0654 (1.0517); Dice 0.3230 (0.2027): 100%|██████████| 416/416 [04:24<00:00,  1.57it/s]        


epoch: 0; lr 0.0004000; Loss 1.7522; Loss1 1.0110; Loss2 1.0452; Loss3 0.3530; Loss4 1.0517; Dice 0.2027


100%|██████████| 54/54 [00:59<00:00,  1.10s/it]


Val Dice: 0.528439287882199, 0.19665232144307263
dice: 0.528439287882199	dice_best: 0.528439287882199


epoch: 1; lr 0.0004000; Loss 1.3107 (1.3336); Loss1 0.7824 (0.8079); Loss2 0.9317 (0.9631); Loss3 0.2404 (0.2059); Loss4 0.0676 (0.0611); Dice 0.4330 (0.3634): 100%|██████████| 416/416 [04:24<00:00,  1.57it/s]


epoch: 1; lr 0.0004000; Loss 1.3336; Loss1 0.8079; Loss2 0.9631; Loss3 0.2059; Loss4 0.0611; Dice 0.3634


epoch: 2; lr 0.0004000; Loss 1.3547 (1.2588); Loss1 0.8371 (0.7544); Loss2 0.9431 (0.9468); Loss3 0.1689 (0.1928); Loss4 0.0304 (0.0402); Dice 0.3344 (0.4117): 100%|██████████| 416/416 [04:24<00:00,  1.57it/s]


epoch: 2; lr 0.0004000; Loss 1.2588; Loss1 0.7544; Loss2 0.9468; Loss3 0.1928; Loss4 0.0402; Dice 0.4117


100%|██████████| 54/54 [01:00<00:00,  1.12s/it]


Val Dice: 0.6028581967787431, 0.2862234124898856
dice: 0.6028581967787431	dice_best: 0.6028581967787431


epoch: 3; lr 0.0004000; Loss 1.0366 (1.2227); Loss1 0.5917 (0.7298); Loss2 0.9432 (0.9364); Loss3 0.1357 (0.1869); Loss4 0.0249 (0.0358); Dice 0.4630 (0.4352): 100%|██████████| 416/416 [04:24<00:00,  1.57it/s]


epoch: 3; lr 0.0004000; Loss 1.2227; Loss1 0.7298; Loss2 0.9364; Loss3 0.1869; Loss4 0.0358; Dice 0.4352


epoch: 4; lr 0.0004000; Loss 1.2731 (1.1867); Loss1 0.7725 (0.7084); Loss2 0.9047 (0.8749); Loss3 0.2204 (0.1765); Loss4 0.0260 (0.0295); Dice 0.4397 (0.4521): 100%|██████████| 416/416 [04:23<00:00,  1.58it/s]


epoch: 4; lr 0.0004000; Loss 1.1867; Loss1 0.7084; Loss2 0.8749; Loss3 0.1765; Loss4 0.0295; Dice 0.4521


100%|██████████| 54/54 [01:01<00:00,  1.13s/it]


Val Dice: 0.5855061177846318, 0.417724580600863
dice: 0.5855061177846318	dice_best: 0.6028581967787431


epoch: 5; lr 0.0004000; Loss 1.0357 (1.1599); Loss1 0.5942 (0.6919); Loss2 0.6551 (0.8000); Loss3 0.2017 (0.1753); Loss4 0.0247 (0.0303); Dice 0.5000 (0.4685): 100%|██████████| 416/416 [04:25<00:00,  1.57it/s]


epoch: 5; lr 0.0004000; Loss 1.1599; Loss1 0.6919; Loss2 0.8000; Loss3 0.1753; Loss4 0.0303; Dice 0.4685


epoch: 6; lr 0.0001000; Loss 1.2442 (1.1034); Loss1 0.7573 (0.6536); Loss2 0.7686 (0.7515); Loss3 0.2045 (0.1651); Loss4 0.0219 (0.0197); Dice 0.4168 (0.5010): 100%|██████████| 416/416 [04:24<00:00,  1.57it/s]


epoch: 6; lr 0.0001000; Loss 1.1034; Loss1 0.6536; Loss2 0.7515; Loss3 0.1651; Loss4 0.0197; Dice 0.5010


100%|██████████| 54/54 [01:00<00:00,  1.12s/it]


Val Dice: 0.6445820457095203, 0.46342992785562237
dice: 0.6445820457095203	dice_best: 0.6445820457095203


epoch: 7; lr 0.0002000; Loss 1.0591 (1.0756); Loss1 0.6090 (0.6343); Loss2 0.7296 (0.7368); Loss3 0.1841 (0.1605); Loss4 0.0207 (0.0187); Dice 0.5772 (0.5155): 100%|██████████| 416/416 [04:25<00:00,  1.57it/s]


epoch: 7; lr 0.0002000; Loss 1.0756; Loss1 0.6343; Loss2 0.7368; Loss3 0.1605; Loss4 0.0187; Dice 0.5155


epoch: 8; lr 0.0002000; Loss 1.1363 (1.0661); Loss1 0.6858 (0.6287); Loss2 0.6928 (0.7288); Loss3 0.1557 (0.1573); Loss4 0.0198 (0.0183); Dice 0.4970 (0.5202): 100%|██████████| 416/416 [04:24<00:00,  1.57it/s]


epoch: 8; lr 0.0002000; Loss 1.0661; Loss1 0.6287; Loss2 0.7288; Loss3 0.1573; Loss4 0.0183; Dice 0.5202


100%|██████████| 54/54 [01:00<00:00,  1.12s/it]


Val Dice: 0.6470165019283992, 0.4715477964495092
dice: 0.6470165019283992	dice_best: 0.6470165019283992


epoch: 9; lr 0.0002000; Loss 1.0924 (1.0606); Loss1 0.6561 (0.6250); Loss2 0.6944 (0.7157); Loss3 0.1597 (0.1583); Loss4 0.0162 (0.0165); Dice 0.4905 (0.5238): 100%|██████████| 416/416 [04:24<00:00,  1.57it/s]


epoch: 9; lr 0.0002000; Loss 1.0606; Loss1 0.6250; Loss2 0.7157; Loss3 0.1583; Loss4 0.0165; Dice 0.5238


epoch: 10; lr 0.0002000; Loss 1.1164 (1.0404); Loss1 0.6806 (0.6106); Loss2 0.8399 (0.7116); Loss3 0.1175 (0.1542); Loss4 0.0083 (0.0184); Dice 0.4943 (0.5367): 100%|██████████| 416/416 [04:24<00:00,  1.57it/s]


epoch: 10; lr 0.0002000; Loss 1.0404; Loss1 0.6106; Loss2 0.7116; Loss3 0.1542; Loss4 0.0184; Dice 0.5367


100%|██████████| 54/54 [01:00<00:00,  1.11s/it]


Val Dice: 0.6434778353624712, 0.48245963164832745
dice: 0.6434778353624712	dice_best: 0.6470165019283992


epoch: 11; lr 0.0002000; Loss 0.9290 (1.0385); Loss1 0.5111 (0.6103); Loss2 0.6852 (0.7078); Loss3 0.1374 (0.1533); Loss4 0.0159 (0.0155); Dice 0.5701 (0.5356): 100%|██████████| 416/416 [04:25<00:00,  1.57it/s]


epoch: 11; lr 0.0002000; Loss 1.0385; Loss1 0.6103; Loss2 0.7078; Loss3 0.1533; Loss4 0.0155; Dice 0.5356


epoch: 12; lr 0.0000500; Loss 1.0434 (1.0020); Loss1 0.6110 (0.5849); Loss2 0.6924 (0.6860); Loss3 0.1241 (0.1474); Loss4 0.0112 (0.0136); Dice 0.5102 (0.5570): 100%|██████████| 416/416 [04:25<00:00,  1.57it/s]


epoch: 12; lr 0.0000500; Loss 1.0020; Loss1 0.5849; Loss2 0.6860; Loss3 0.1474; Loss4 0.0136; Dice 0.5570


100%|██████████| 54/54 [00:58<00:00,  1.09s/it]


Val Dice: 0.6644455513145963, 0.5052817505451429
dice: 0.6644455513145963	dice_best: 0.6644455513145963


epoch: 13; lr 0.0001000; Loss 0.8648 (0.9783); Loss1 0.4802 (0.5684); Loss2 0.5621 (0.6741); Loss3 0.1662 (0.1455); Loss4 0.0138 (0.0130); Dice 0.6405 (0.5710): 100%|██████████| 416/416 [04:24<00:00,  1.57it/s]


epoch: 13; lr 0.0001000; Loss 0.9783; Loss1 0.5684; Loss2 0.6741; Loss3 0.1455; Loss4 0.0130; Dice 0.5710


epoch: 14; lr 0.0001000; Loss 0.9504 (0.9673); Loss1 0.5483 (0.5611); Loss2 0.5945 (0.6726); Loss3 0.1408 (0.1424); Loss4 0.0121 (0.0128); Dice 0.5740 (0.5761): 100%|██████████| 416/416 [04:25<00:00,  1.57it/s]


epoch: 14; lr 0.0001000; Loss 0.9673; Loss1 0.5611; Loss2 0.6726; Loss3 0.1424; Loss4 0.0128; Dice 0.5761


100%|██████████| 54/54 [01:00<00:00,  1.13s/it]


Val Dice: 0.6640546025580968, 0.5033948278246377
dice: 0.6640546025580968	dice_best: 0.6644455513145963


epoch: 15; lr 0.0001000; Loss 0.9225 (0.9621); Loss1 0.5132 (0.5575); Loss2 0.5939 (0.6604); Loss3 0.1986 (0.1438); Loss4 0.0202 (0.0129); Dice 0.5888 (0.5811): 100%|██████████| 416/416 [04:24<00:00,  1.57it/s]


epoch: 15; lr 0.0001000; Loss 0.9621; Loss1 0.5575; Loss2 0.6604; Loss3 0.1438; Loss4 0.0129; Dice 0.5811


epoch: 16; lr 0.0001000; Loss 1.2165 (0.9697); Loss1 0.7483 (0.5632); Loss2 0.7414 (0.6670); Loss3 0.2535 (0.1426); Loss4 0.0184 (0.0124); Dice 0.5357 (0.5743): 100%|██████████| 416/416 [04:25<00:00,  1.57it/s]


epoch: 16; lr 0.0001000; Loss 0.9697; Loss1 0.5632; Loss2 0.6670; Loss3 0.1426; Loss4 0.0124; Dice 0.5743


100%|██████████| 54/54 [00:59<00:00,  1.11s/it]


Val Dice: 0.6711234528793424, 0.5089755970421583
dice: 0.6711234528793424	dice_best: 0.6711234528793424


epoch: 17; lr 0.0001000; Loss 1.0158 (0.9519); Loss1 0.5866 (0.5516); Loss2 0.7096 (0.6586); Loss3 0.2278 (0.1406); Loss4 0.0178 (0.0119); Dice 0.6010 (0.5849): 100%|██████████| 416/416 [04:24<00:00,  1.57it/s]


epoch: 17; lr 0.0001000; Loss 0.9519; Loss1 0.5516; Loss2 0.6586; Loss3 0.1406; Loss4 0.0119; Dice 0.5849


epoch: 18; lr 0.0000250; Loss 1.0548 (0.9326); Loss1 0.6123 (0.5366); Loss2 0.8397 (0.6460); Loss3 0.1923 (0.1393); Loss4 0.0127 (0.0110); Dice 0.5428 (0.5971): 100%|██████████| 416/416 [04:24<00:00,  1.57it/s]


epoch: 18; lr 0.0000250; Loss 0.9326; Loss1 0.5366; Loss2 0.6460; Loss3 0.1393; Loss4 0.0110; Dice 0.5971


100%|██████████| 54/54 [01:00<00:00,  1.12s/it]


Val Dice: 0.6708383618855287, 0.5057970439882192
dice: 0.6708383618855287	dice_best: 0.6711234528793424


epoch: 19; lr 0.0000500; Loss 0.7563 (0.9239); Loss1 0.3900 (0.5310); Loss2 0.5200 (0.6444); Loss3 0.1000 (0.1369); Loss4 0.0066 (0.0111); Dice 0.7009 (0.5995): 100%|██████████| 416/416 [04:24<00:00,  1.57it/s]

epoch: 19; lr 0.0000500; Loss 0.9239; Loss1 0.5310; Loss2 0.6444; Loss3 0.1369; Loss4 0.0111; Dice 0.5995
Time: 98.397 min



