In [None]:
import os
import os.path as osp
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from skimage.io import imsave, imread
from skimage import img_as_ubyte, img_as_float
import sys
import monai, torch

%load_ext autoreload
%autoreload 2

In [None]:
def imshow_pair(im, gdt, vmin1=None, vmax1=None, vmin2=None, vmax2=None):
    """Show two images side by side (typically an image and its ground truth).

    2-D inputs are drawn with a gray colormap; inputs of any other rank (e.g.
    H x W x C) are passed to ``imshow`` unchanged.

    Parameters
    ----------
    im, gdt : array-like
        Left and right images; anything ``np.asarray`` accepts.
    vmin1, vmax1 : float, optional
        Display limits for the left panel (only used when ``im`` is 2-D).
    vmin2, vmax2 : float, optional
        Display limits for the right panel (only used when ``gdt`` is 2-D).

    Returns
    -------
    matplotlib.figure.Figure
        The created figure. It is not closed, so it still displays inline.
    """
    f, ax = plt.subplots(1, 2, figsize=(10, 5))
    panels = ((ax[0], np.asarray(im), vmin1, vmax1),
              (ax[1], np.asarray(gdt), vmin2, vmax2))
    for axis, arr, vmin, vmax in panels:
        if arr.ndim == 2:
            # vmin/vmax of None mean "autoscale", so passing them always is safe and
            # applies user-given limits to both panels the same way.
            axis.imshow(arr, cmap='gray', vmin=vmin, vmax=vmax)
        else:
            axis.imshow(arr)
        axis.axis('off')
    plt.tight_layout()
    return f

In [None]:
# path_masks_old = '../CURRICULUM_AV/data/DRIVE/mask'
# path_masks = 'data_new/DRIVE/AV_groundTruth/training/masks/'

# mask_names = os.listdir(path_masks_old)

# mask_names = sorted([os.path.join(path_masks_old, n) for n in mask_names])

# from skimage import io
# for m in mask_names:
#     mm=io.imread(m)
#     io.imsave(m.replace(path_masks_old, path_masks).replace('gif','png'), mm)

In [None]:
path_ims = 'data_new/DRIVE/AV_groundTruth/training/images/'
path_segs = 'data_new/DRIVE/AV_groundTruth/training/av/'
path_masks = 'data_new/DRIVE/AV_groundTruth/training/masks/'

# Sorted full paths of every non-hidden file in each folder. Hidden entries
# (e.g. .DS_Store) are skipped in all three folders, not only for masks: a stray
# file in just one list would shift the image/label/mask pairing built later by zip().
img_names = sorted(os.path.join(path_ims, n) for n in os.listdir(path_ims) if not n.startswith('.'))
seg_names = sorted(os.path.join(path_segs, n) for n in os.listdir(path_segs) if not n.startswith('.'))
mask_names = sorted(os.path.join(path_masks, n) for n in os.listdir(path_masks) if not n.startswith('.'))

In [None]:
# Hold-out split on the sorted lists: the first 16 cases train, the rest validate.
train_img_names, val_img_names = img_names[:16], img_names[16:]
train_seg_names, val_seg_names = seg_names[:16], seg_names[16:]
train_mask_names, val_mask_names = mask_names[:16], mask_names[16:]

In [None]:
# Keys of each sample dict: image file, A/V annotation file, and the mask file
# (the mask is used as the crop-foreground source in the transforms).
fn_keys = ('img', 'seg', 'mask')

# zip() silently drops unmatched trailing files, so require the lists to pair up exactly.
assert len(train_img_names) == len(train_seg_names) == len(train_mask_names), \
    'train image/seg/mask lists differ in length'
assert len(val_img_names) == len(val_seg_names) == len(val_mask_names), \
    'val image/seg/mask lists differ in length'

train_filenames = [dict(zip(fn_keys, files))
                   for files in zip(train_img_names, train_seg_names, train_mask_names)]
val_filenames = [dict(zip(fn_keys, files))
                 for files in zip(val_img_names, val_seg_names, val_mask_names)]

In [None]:
from monai.transforms import MapTransform

# for type hinting at this stage we need more
from monai.config import KeysCollection
from typing import Optional, Any, Mapping, Hashable

In [None]:
from tqdm import trange
from monai.metrics import DiceMetric

In [None]:
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    return None if first_group is None else first_group['lr']

In [None]:
from monai.metrics import DiceMetric
# Dice with mutually exclusive classes, one-hot labels, background excluded and
# one score per sample/class (reduction='none').
# NOTE(review): `dice_metric` is reassigned (sigmoid=True) before the MONAI training
# loop, and run_one_epoch reads the global at call time, so this configuration is
# only in effect for cells run before that reassignment.
dice_metric = DiceMetric(mutually_exclusive=True, to_onehot_y=True, reduction='none', include_background=False)

In [None]:
from sklearn.metrics import f1_score
from sklearn.metrics import matthews_corrcoef as mcc

def evaluate(logits, labels):
    """Pixel-wise weighted F1 and Matthews correlation of argmax predictions.

    Parameters
    ----------
    logits : sequence of torch.Tensor
        One tensor per sample with class scores along dim 0 (e.g. (C, H, W)).
        Any number of classes C is supported.
    labels : sequence of torch.Tensor
        Integer targets, one per sample, holding one value per pixel of the
        matching ``logits`` element (singleton dims are flattened away). They
        may live on any device.

    Returns
    -------
    (float, float)
        ``f1_score(..., average='weighted')`` and ``matthews_corrcoef`` over all
        pixels of all samples.
    """
    all_preds, all_targets = [], []
    for i in range(len(logits)):
        probs = torch.softmax(logits[i].detach(), dim=0)
        # Class axis is dim 0, whatever its size. torch.argmax returns the first
        # maximum on ties, like np.argmax.
        all_preds.append(probs.argmax(dim=0).cpu().numpy().ravel())
        # .detach().cpu() so CUDA tensors and tensors requiring grad also work
        all_targets.append(labels[i].detach().cpu().numpy().ravel())

    all_preds_np = np.concatenate(all_preds)
    all_targets_np = np.hstack(all_targets)

    return f1_score(all_targets_np, all_preds_np, average='weighted'), mcc(all_targets_np, all_preds_np)

In [None]:
# Use the first GPU when one is visible, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0")
device

## Loading Datasets

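With the file lists defined, we first build a transform pipeline and check it on a single sample; the dataset objects are created further below. MONAI's base dataset class is `Dataset`. The pipeline here loads only the image and segmentation files (not the mask):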

In [None]:
from monai.transforms import Compose, LambdaD, LoadImageD, ToTensorD, ScaleIntensityD

In [None]:
def to_labels(av):
    """Convert a 3-channel artery/vein annotation into an integer label map.

    ``av`` is indexed as (channel, row, col). Pixels equal to 255 in channel 0
    become label 2 (artery) and pixels equal to 255 in channel 2 become label 1
    (vein). Everything else, including pixels set only in channel 1
    ("uncertain"), stays 0. The vein label is written last, so it wins where
    both channels are set. The result has the dtype of ``av``.
    """
    ch_artery, ch_uncertain, ch_vein = av[0, :, :], av[1, :, :], av[2, :, :]
    label_map = np.zeros_like(ch_artery)
    # uncertain -> 0; a no-op on a freshly zeroed map, kept to make the intent explicit
    label_map[ch_uncertain == 255] = 0
    # order matters: veins (written second) override arteries
    for channel, value in ((ch_artery, 2), (ch_vein, 1)):
        label_map[channel == 255] = value
    return label_map


# Quick check of image scaling and label conversion on one training sample.
trans = Compose(
    [
        LoadImageD(keys=('img', 'seg')),
#         ScaleIntensityD(keys=('img',)),
        LambdaD(('img',), lambda x: x/255.),  # divide image intensities by 255 (assumes 8-bit input)
        LambdaD(('seg',), to_labels),  # A/V annotation -> integer labels (0 background, 1 vein, 2 artery)
    ]
)

imgd = trans(train_filenames[0])
img = imgd["img"]
seg = imgd["seg"]
img.shape, seg.shape

In [None]:
# f=imshow_pair(img.transpose(1, 2, 0), seg)
# f.savefig('wtf.png')
# plt.close(f)

In [None]:
from monai.data import CacheDataset, Dataset, PersistentDataset
from monai.inferers import sliding_window_inference

In [None]:
import torch
from monai.data import Dataset, ArrayDataset

from monai.transforms import Compose, LambdaD, LoadImageD, ToTensorD, AddChannelD, AsChannelFirstD, \
                            RandSpatialCropD, RandRotated, CastToTypeD, SqueezeDimD, ResizeD, \
                            ScaleIntensityD, RandAdjustContrastD, RandRotateD, RandAffineD, \
                            Rand2DElasticD, RandFlipD, RandZoomD, CropForegroundd, ResizeWithPadOrCropD, \
                            DeleteItemsd, NormalizeIntensityD, ScaleIntensityRangeD

In [None]:
# Training pipeline: load files, convert the A/V annotation to labels, crop image and
# labels to the foreground of the mask, resize to 512x512, apply random contrast,
# rotation and flips, divide the image by 255, and convert to tensors (labels as long).
train_transforms = Compose(
    [
        LoadImageD(keys=('img', 'seg', 'mask')),
        LambdaD(('seg',), to_labels),
        AddChannelD(keys=('seg','mask')),
        CropForegroundd(keys=('img','seg'), source_key='mask'),
        DeleteItemsd(keys=('mask')),  # the mask is only needed for the crop
        ResizeD(keys=('img','seg'), spatial_size=(512,512), mode=('bicubic', 'nearest'),
                align_corners=(False,None)),
        RandAdjustContrastD(keys=('img',), prob=0.25, gamma=(0.75, 1.25)),
        # MONAI's range_x is in radians: 45.0 would mean +/-45 rad, i.e. effectively any
        # angle. +/-45 degrees (pi/4) is presumably what was intended.
        # Labels must use nearest interpolation. Bilinear blending of 0/2 neighbours gives
        # fractional values that CastToTypeD truncates into spurious labels (e.g. 1 = vein).
        RandRotated(keys=('img','seg'), range_x=np.pi / 4, mode=('bilinear', 'nearest'),
                    padding_mode='zeros', prob=1.0),
        RandFlipD(keys=('img','seg'), prob=0.5, spatial_axis=(0,)), # vertical flip
        RandFlipD(keys=('img','seg'), prob=0.5, spatial_axis=(1,)), # horizontal flip
#         RandSpatialCropD(keys=('img','seg'), roi_size=(256,256), random_size=False),
        LambdaD(('img',), lambda x: x/255.),
#         NormalizeIntensityD(keys=('img',),channel_wise=False),
        ToTensorD(keys=('img', 'seg')),
        CastToTypeD(keys=('seg',), dtype=torch.long),
    ]
)


# train_ds = Dataset(train_filenames, train_transforms)
# train_loader_monai = torch.utils.data.DataLoader(train_ds, batch_size=2, shuffle=False, num_workers=2,)

# x, xx = train_ds[0], next(iter(train_loader_monai))
# x['img'].min(), x['img'].max(), xx['img'].min(), xx['img'].max()

In [None]:
# Validation pipeline: same deterministic preprocessing as training (load, A/V labels,
# crop to the mask's foreground, 512x512 resize, divide image by 255, tensors) but
# without any random augmentation.
val_transforms = Compose([
    LoadImageD(keys=('img', 'seg', 'mask')),
    LambdaD(('seg',), to_labels),
    AddChannelD(keys=('seg', 'mask')),
    CropForegroundd(keys=('img', 'seg', 'mask'), source_key='mask'),
    DeleteItemsd(keys=('mask',)),
    ResizeD(keys=('img', 'seg'), spatial_size=(512, 512), mode=('bicubic', 'nearest'),
            align_corners=(False, None)),
    LambdaD(('img',), lambda x: x / 255.),
    ToTensorD(keys=('img', 'seg')),
    CastToTypeD(keys=('seg',), dtype=torch.long),
])

val_ds = Dataset(val_filenames, val_transforms)

In [None]:
# Plain (non-cached) MONAI datasets over the filename dicts.
train_ds = Dataset(data=train_filenames, transform=train_transforms)
val_ds = Dataset(data=val_filenames, transform=val_transforms)

# Batches of 4 with 4 worker processes; only the training loader shuffles.
train_loader_monai = torch.utils.data.DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=4)
val_loader_monai = torch.utils.data.DataLoader(val_ds, batch_size=4, num_workers=4)

In [None]:
# from utils.get_loaders import get_train_val_loaders

# train_loader, val_loader = get_train_val_loaders(csv_path_train='data/DRIVE/train_av.csv', 
#                                                  csv_path_val='data/DRIVE/val_av.csv', batch_size=4,
#                                                  tg_size=(512,512), label_values=[0, 85, 170, 255], 
#                                                  num_workers=8)

In [None]:
# # x = next(iter(train_loader))
# # print(x[0].min(), x[0].max())
# xx = next(iter(train_loader_monai))
# xx['img'].min(), xx['img'].max()

In [None]:
# inputs = val_ds[0]
# inputs['img'].shape, inputs['img'].dtype, inputs['seg'].shape, inputs['seg'].dtype

In [None]:
# Visual check of one (randomly augmented) training sample: the image permuted to
# (H, W, C) for display next to its label map.
inputs = train_ds[1]
im, tg = inputs['img'], inputs['seg']
imshow_pair(im.permute(1,2,0), tg[0])
tg.shape

In [None]:
# inputs = val_ds[0]
# im, tg = inputs['img'], inputs['seg']
# imshow_pair(im.permute(1,2,0), tg[0])
# im.dtype, tg.shape

In [None]:
# directory = os.environ.get("MONAI_DATA_DIRECTORY")
# root_dir = tempfile.mkdtemp() if directory is None else directory
# print(root_dir)

In [None]:
# Shapes of one batch from each loader and the label values present in the validation batch.
x_t, x_v=next(iter(train_loader_monai)), next(iter(val_loader_monai))
x_t['seg'].shape,  x_v['seg'].shape, torch.unique(x_v['seg'])

In [None]:
# x_t, x_v = next(iter(train_loader)), next(iter(val_loader))
# x_t[1].shape, x_v[1].shape

In [None]:
# Number of channels the network predicts: the two vessel classes (to_labels uses
# 1 = vein, 2 = artery). A constant background channel is prepended before the loss.
n_classes=2

In [None]:
# from monai.networks.nets import UNet
# from monai.networks.layers import Norm

# model = UNet(
#         dimensions=2,
#         in_channels=3,
#         out_channels=n_classes,
#         channels=(8,16,32,64),
#         strides=(1, 1, 1, 1),
#         num_res_units=4,
#         norm=Norm.BATCH,
#     ).to(device)

# model_parameters = filter(lambda p: p.requires_grad, model.parameters())
# params = sum([np.prod(p.size()) for p in model_parameters])
# params

In [None]:
from models.res_unet_adrian import UNet as unet

class Wnet(torch.nn.Module):
    """Two stacked U-Nets: the second one refines the output of the first.

    The second U-Net receives the input concatenated (along dim 1) with the
    first U-Net's output.

    Args:
        n_classes: output channels of each U-Net.
        in_c: input channels.
        layers, conv_bridge, shortcut: forwarded unchanged to both U-Nets.
        mode: when 'train', ``forward`` returns (first_output, second_output);
            for any other value it returns only ``second_output``.
    """

    def __init__(self, n_classes=1, in_c=3, layers=(8, 16, 32), conv_bridge=True, shortcut=True, mode='train'):
        super().__init__()
        self.mode = mode
        shared_kwargs = dict(n_classes=n_classes, layers=layers, conv_bridge=conv_bridge, shortcut=shortcut)
        self.unet1 = unet(in_c=in_c, **shared_kwargs)
        self.unet2 = unet(in_c=in_c + n_classes, **shared_kwargs)

    def forward(self, x):
        first_pass = self.unet1(x)
        refined = self.unet2(torch.cat([x, first_pass], dim=1))
        return (first_pass, refined) if self.mode == 'train' else refined

# Build the W-Net (default mode='train', so forward returns (aux, final) logits)
# and count its trainable parameters.
model = Wnet(in_c=3, n_classes=n_classes, layers=[8,16,32])
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
model.to(device);
params

In [None]:
# Schedule spec as (number of cycles, epochs per cycle); expanded in the next cell.
cycle_lens = [20, 50]
grad_acc_steps=0  # extra batches accumulated before each optimizer step (0 = step every batch)
# NOTE(review): n_cycles is not used by the training loops below (they use range(12) / range(10)).
n_cycles = cycle_lens[0]
min_lr = 1e-8  # eta_min of the cosine schedule

In [None]:
if len(cycle_lens)==2: # handles option of specifying cycles as pair (n_cycles, cycle_len)
    # NOTE(review): not idempotent. If the spec expands to exactly two cycles
    # (e.g. [2, 50] -> [50, 50]), re-running this cell expands it again.
    cycle_lens = cycle_lens[0]*[cycle_lens[1]]

In [None]:
# from models.get_model import get_arch
# model = get_arch('big_wnet', in_c=3, n_classes=n_classes)
# model_parameters = filter(lambda p: p.requires_grad, model.parameters())
# params = sum([np.prod(p.size()) for p in model_parameters])
# model.to(device);
# params

In [None]:
# Adam + cosine annealing. train_one_cycle resets last_epoch and recomputes T_max
# at the start of every cycle, so the T_max given here only matters until then.
optimizer = torch.optim.Adam(model.parameters(), 1e-2)
# Label 0 (background) is ignored: the model predicts only the vessel classes and a
# constant background channel is prepended to its logits before this loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=0)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 
                              T_max=cycle_lens[0] * len(train_loader_monai) // (grad_acc_steps + 1), 
                              eta_min=min_lr)
# Custom attribute read by train_one_cycle: number of epochs of each cycle.
setattr(scheduler, 'cycle_lens', cycle_lens)

In [None]:
# Debug: compute the loss by hand on one training batch.
x_t=next(iter(train_loader_monai))

In [None]:
# Forward pass with autograd enabled. Assumes model.mode == 'train' (the Wnet default),
# in which case the output is an (aux, final) tuple.
logits = model(x_t['img'].to(device))
logits_aux, logits = logits[0].cpu(), logits[1].cpu()

In [None]:
labels = x_t['seg']

In [None]:
labels.shape, torch.unique(labels), logits.shape

In [None]:
# Prepend a constant -10 "background" channel so label 0 has a channel of its own.
logits_new = torch.cat([-10*torch.ones(labels.shape), logits], dim=1)
logits_new.shape

In [None]:
logits_new.shape, labels.shape

In [None]:
# squeeze() drops the label channel dim (it would also drop a batch dim of size 1).
criterion(logits_new, labels.squeeze())

In [None]:
# tt = labels[labels!=0]-1
# logits[labels!=0].shape, tt.shape
# torch.unique(tt)

In [None]:
# torch.nn.BCEWithLogitsLoss()(tt.float()-1,logits[labels!=0])

In [None]:
# torch.nn.BCEWithLogitsLoss()(labels[labels!=0].float()-1,logits[labels!=0])

In [None]:
# labels.shape, torch.unique(labels), logits.shape

In [None]:
# tt = labels[labels!=0]-1
# pp = logits[labels!=0]

In [None]:
# -(tt*torch.nn.functional.logsigmoid(pp) + (1-tt)*torch.nn.functional.logsigmoid(1-pp)).mean()

In [None]:
# import torch.nn.functional as F

In [None]:
# def compute_bce_no_back(logits, labels):
#     tt = labels[labels!=0]-1
#     pp = logits[labels!=0]
#     return -(tt*torch.nn.functional.logsigmoid(pp) + (1-tt)*torch.nn.functional.logsigmoid(1-pp)).mean()

In [None]:
# compute_bce_no_back(logits, labels)

In [None]:
# criterion = compute_bce_no_back

## What would be better:
Predict only the two vessel classes and prepend a constant channel filled with a large negative value (e.g. -100; the code below uses -10). That channel acts as our prediction of the background. We can then use `torch.nn.CrossEntropyLoss(ignore_index=0)` and safely apply softmax and MONAI's Dice with `include_background=False`.

In [None]:
def run_one_epoch(loader, model, criterion, optimizer=None, scheduler=None,
                  grad_acc_steps=0, assess=False, save_plot=False, cycle=0):
    device='cuda' if next(model.parameters()).is_cuda else 'cpu'
    train = optimizer is not None  # if we are in training mode there will be an optimizer and train=True here

    if train: model.train()
    else: 
        model.eval()
        model.mode='val'
        
    if assess: dice, auc = 0, 0
    n_elems, running_loss = 0, 0
    wnet=False
    for i_batch, batch_data in enumerate(loader):
        try:
            inputs, labels = (batch_data["img"].to(device), batch_data["seg"].to(device), )
        except:
            inputs, labels = batch_data[0].to(device), batch_data[1].unsqueeze(dim=1).to(device)
            
        if train:  # only in training mode               
            logits = model(inputs)
            if isinstance(logits, tuple): # wnet
                wnet=True
                logits_aux, logits = logits
                loss_aux = criterion(torch.cat([-10*torch.ones(labels.shape).to(device), logits_aux], dim=1), 
                                     labels.squeeze(dim=1))                
#             loss = criterion(logits, labels.squeeze(dim=1))
#             loss = compute_bce_no_back(logits, labels)
            loss = criterion(torch.cat([-10*torch.ones(labels.shape).to(device), logits], dim=1), 
                             labels.squeeze(dim=1))

        
        
            if wnet:
                loss+=loss_aux
                
            (loss / (grad_acc_steps + 1)).backward()
            if i_batch % (grad_acc_steps+1) == 0:  # for grad_acc_steps=0, this is always True
                optimizer.step()
                for _ in range(grad_acc_steps+1): scheduler.step() # for grad_acc_steps=0, this means once
                optimizer.zero_grad()
        
        else:
#             logits = sliding_window_inference(inputs, roi_size=(256,256), 
#                                               sw_batch_size=loader.batch_size, 
#                                               predictor=model)
            logits = model(inputs)
            if isinstance(logits, tuple): # wnet
                wnet=True
                logits_aux, logits = logits
#                 loss_aux = criterion(logits_aux, labels.squeeze(dim=1))
                loss_aux = criterion(torch.cat([-10*torch.ones(labels.shape).to(device), logits_aux], dim=1), 
                                     labels)
#             loss = criterion(logits, labels.squeeze(dim=1))
            loss = criterion(torch.cat([-10*torch.ones(labels.shape).to(device), logits], dim=1), 
                             labels.squeeze(dim=1))
            if wnet:
                loss+=loss_aux

            if assess:
                
                dice = dice_metric(torch.cat([-10*torch.ones(labels.shape).to(device), logits], dim=1), labels)
                
                auc=dice[:,1].mean()
                dice=dice[:,0].mean()
                
                if save_plot:
                    preds = logits.sigmoid().squeeze()
                    back=labels==0
                    preds[back.squeeze()]=0.5
                    for j in range(logits.shape[0]):
                        im_name = batch_data['img_meta_dict']['filename_or_obj'][j].split('/')[-1].split('.')[-2]
                        s_name = 'logs/displays/{}_cycle_{}.png'.format(im_name, cycle)
                        f=imshow_pair(preds[j].cpu(), labels[j].squeeze().cpu())
#                         f.savefig(s_name)
#                         plt.close(f)

        # Compute running loss
        running_loss += loss.item() * inputs.size(0)
        n_elems += inputs.size(0)
        run_loss = running_loss / n_elems
            
    if assess: return dice, auc, run_loss
    return None, None, run_loss

In [None]:
from monai.metrics import compute_roc_auc, compute_meandice

In [None]:
def train_one_cycle(train_loader, model, criterion, optimizer=None, scheduler=None, grad_acc_steps=0, cycle=0):
    """Train for one warm-restarted cosine-annealing cycle.

    Runs ``scheduler.cycle_lens[cycle]`` training epochs through ``run_one_epoch``,
    computing metrics only on the last one, and shows progress with tqdm.
    ``optimizer`` and ``scheduler`` are required despite their defaults.
    ``scheduler`` must be a CosineAnnealingLR carrying a custom ``cycle_lens``
    attribute (epochs per cycle).

    Returns
    -------
    (dice, auc, tr_loss)
        As returned by ``run_one_epoch`` for the last epoch; all None if the
        cycle has zero epochs.
    """
    # Warm restart. CosineAnnealingLR.step() derives each lr from the *current* lr, so
    # resetting last_epoch alone would restart from the near-minimal lr reached at the
    # end of the previous cycle. Restore the initial lrs before resetting the counter.
    for group, base_lr in zip(optimizer.param_groups, scheduler.base_lrs):
        group['lr'] = base_lr
    scheduler.last_epoch = -1
    cycle_len = scheduler.cycle_lens[cycle]
    # Schedule length in scheduler steps; run_one_epoch steps it about once per batch.
    scheduler.T_max = cycle_len * len(train_loader)

    model.train()
    optimizer.zero_grad()
    dice = auc = tr_loss = None
    with trange(cycle_len) as t:
        for epoch in range(cycle_len):
            assess = epoch == cycle_len - 1  # only compute performance on the last epoch
            dice, auc, tr_loss = run_one_epoch(train_loader, model, criterion, optimizer=optimizer,
                                               scheduler=scheduler, grad_acc_steps=grad_acc_steps,
                                               assess=assess, cycle=cycle)
            t.set_postfix_str("Cycle: {}/{} Ep. {}/{} -- tr. loss={:.4f} / lr={:.6f}".format(
                cycle + 1, len(scheduler.cycle_lens), epoch + 1, cycle_len,
                float(tr_loss), get_lr(optimizer)))
            t.update()

    return dice, auc, tr_loss

# MONAI LOADER

In [None]:
from monai.metrics import DiceMetric
# Dice used by run_one_epoch(assess=True): sigmoid on each channel with a 0.5
# threshold (logit_thresh), one-hot labels, background excluded, and one score
# per sample and class (reduction='none').
# NOTE(review): the loss is a softmax cross-entropy over channels, while this
# thresholds each channel independently, so a pixel may count for both vessel
# classes. Confirm this is intended (compare the mutually_exclusive variant above).
dice_metric = DiceMetric(sigmoid=True, 
                         logit_thresh=0.5,
                         to_onehot_y=True,
                         reduction='none', 
                         include_background=False)

In [None]:
# Alternate one training cycle with evaluation passes over the validation and training loaders.
for cycle in range(12):
    _, _, _ = train_one_cycle(train_loader_monai, model, criterion, optimizer, scheduler, cycle=cycle)

    # Plotting is off; use (cycle + 1) % 5 == 0 instead to plot every 5th cycle.
    save_plot = False
    with torch.no_grad():
        vl_dice, vl_auc, vl_loss = run_one_epoch(
            val_loader_monai, model, criterion, optimizer=None, scheduler=None,
            grad_acc_steps=0, assess=True, save_plot=save_plot, cycle=cycle)
        # NOTE(review): train_loader_monai applies the random training augmentations,
        # so these "train" metrics are measured on augmented images.
        tr_dice, tr_auc, tr_loss = run_one_epoch(
            train_loader_monai, model, criterion, optimizer=None, scheduler=None,
            grad_acc_steps=0, assess=True, save_plot=save_plot, cycle=cycle)

        print('Train/Val Loss: {:.4f}/{:.4f} -- DICE|AUC: {:.4f}/{:.4f} | {:.4f}/{:.4f}'
              .format(tr_loss, vl_loss, tr_dice, vl_dice, tr_auc, vl_auc))
        

In [None]:
# Raw logits and labels of one validation batch, used to inspect the Dice computation
# in the following cells. run_one_epoch only returns (dice, auc, loss), so unpacking
# its result into (logits, labels) cannot work.
model.eval()
model.mode = 'val'  # Wnet: return only the final logits
with torch.no_grad():
    val_batch = next(iter(val_loader_monai))
    labels = val_batch['seg'].to(device)
    logits = model(val_batch['img'].to(device))

In [None]:
# Sanity checks of the Dice computation on `logits` / `labels` from the previous cell
# (assumed to be the model output and targets of one validation batch).
torch.cat([-10*torch.ones(labels.shape).to(device), logits], dim=1).shape

In [None]:
logits.shape, labels.shape

In [None]:
# NOTE(review): same configuration as the dice_metric defined before the MONAI training loop.
from monai.metrics import DiceMetric
dice_metric = DiceMetric(sigmoid=True, 
                         logit_thresh=0.5,
                         to_onehot_y=True,
                         reduction='none', 
                         include_background=False)

In [None]:
# Dice per sample and non-background class, with the constant background channel prepended.
dd = dice_metric(torch.cat([-10*torch.ones(labels.shape).to(device), logits], dim=1), labels)

In [None]:
dd.shape

In [None]:
# Mean per class; columns presumably map to labels 1 (vein) and 2 (artery) — verify.
dd[:,0].mean(), dd[:,1].mean()

# TV LOADERS

In [None]:
# get_arch comes from the project's model factory. Its earlier import is commented out,
# so without this import the cell raises NameError on a fresh kernel.
from models.get_model import get_arch

model = get_arch('wnet', in_c=3, n_classes=n_classes)
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])  # trainable parameter count
model.to(device);
params

In [None]:
# Fresh optimizer / criterion / scheduler for the model built above.
# NOTE(review): `train_loader` is only defined by the commented-out get_train_val_loaders
# cell, so this raises NameError on a fresh kernel unless that cell is re-enabled.
# NOTE(review): unlike the MONAI section, this CrossEntropyLoss has no ignore_index=0,
# although run_one_epoch prepends a constant background channel — confirm intent.
optimizer = torch.optim.Adam(model.parameters(), 1e-2)
criterion = torch.nn.CrossEntropyLoss()

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 
                              T_max=cycle_lens[0] * len(train_loader) // (grad_acc_steps + 1), 
                              eta_min=min_lr)
# Custom attribute read by train_one_cycle: number of epochs of each cycle.
setattr(scheduler, 'cycle_lens', cycle_lens)

In [None]:
# Same cycle loop as the MONAI section, on train_loader / val_loader.
# NOTE(review): those loaders come from the commented-out get_train_val_loaders cell.
# train_one_cycle and run_one_epoch each return exactly three values (dice, auc, loss);
# unpacking six values from them raises ValueError.
for cycle in range(10):
    _, _, _ = train_one_cycle(train_loader, model, criterion, optimizer, scheduler, cycle=cycle)

    save_plot = False  # use (cycle + 1) % 5 == 0 to plot every 5th cycle
    with torch.no_grad():
        vl_dice, vl_auc, vl_loss = run_one_epoch(val_loader, model, criterion,
                                                 optimizer=None, scheduler=None,
                                                 grad_acc_steps=0, assess=True,
                                                 save_plot=save_plot, cycle=cycle)
        tr_dice, tr_auc, tr_loss = run_one_epoch(train_loader, model, criterion,
                                                 optimizer=None, scheduler=None,
                                                 grad_acc_steps=0, assess=True,
                                                 save_plot=save_plot, cycle=cycle)

        print('Train/Val Loss: {:.4f}/{:.4f} -- DICE|AUC: {:.4f}/{:.4f} | {:.4f}/{:.4f}'.format(
            tr_loss, vl_loss, tr_dice, vl_dice, tr_auc, vl_auc))