In [3]:
import os
import os.path as osp
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from skimage.io import imsave, imread
from skimage import img_as_ubyte, img_as_float
import sys
import torch

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
def imshow_pair(im, gdt, vmin1=None, vmax1=None, vmin2=None, vmax2=None):
    f, ax = plt.subplots(1, 2, figsize=(10,5))
    np_im = np.asarray(im)
    np_gdt = np.asarray(gdt)
    if len(np_im.shape) == 2:
        if vmin1==None:
            ax[0].imshow(np_im, cmap='gray'),  ax[0].axis('off')
        else:
            ax[0].imshow(np_im, cmap='gray', vmin=vmin1, vmax=vmax1),  ax[0].axis('off')
    else:
        ax[0].imshow(np_im),  ax[0].axis('off')
    if len(np_gdt.shape) == 2:
        if vmin2==None:
            ax[1].imshow(np.asarray(gdt), cmap = 'gray', vmin=vmin2, vmax=vmax2), ax[1].axis('off')
        else:
            ax[1].imshow(np.asarray(gdt), cmap = 'gray'), ax[1].axis('off')
    else:
        ax[1].imshow(np.asarray(gdt)), ax[1].axis('off')
    plt.tight_layout()
    return f

In [5]:
from tqdm import trange

In [6]:
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

In [7]:
from monai.metrics import DiceMetric
dice_metric = DiceMetric(mutually_exclusive=True, to_onehot_y=True, reduction='none')

In [8]:
from sklearn.metrics import f1_score
from sklearn.metrics import matthews_corrcoef as mcc

def evaluate(logits, labels):
    all_targets = []
    all_probs_0 = []
    all_probs_1 = []
    all_probs_2 = []
    all_probs_3 = []

    for i in range(len(logits)):
        probs = torch.nn.Softmax(dim=0)(logits[i]).detach().cpu().numpy()
        all_probs_0.extend(probs[0].ravel())
        all_probs_1.extend(probs[1].ravel())
        all_probs_2.extend(probs[2].ravel())
        all_probs_3.extend(probs[3].ravel())

        target = labels[i].numpy()

        all_targets.append(target.ravel())

    all_probs_np = np.stack([all_probs_0, all_probs_1, all_probs_2, all_probs_3], axis=1)
    all_preds_np = np.argmax(all_probs_np, axis=1)
    all_targets_np = np.hstack(all_targets)

    return f1_score(all_targets_np, all_preds_np,average='weighted'), mcc(all_targets_np, all_preds_np)

In [9]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0")
device

device(type='cuda', index=0)

## Loading Datasets

With a data source and transforms defined we can now create a dataset object. The base class for MONAI is `Dataset`, created here to load the image files only:

In [10]:
from utils.get_loaders import get_train_val_loaders

train_loader, val_loader = get_train_val_loaders(csv_path_train='data/DRIVE/train_av.csv', 
                                                 csv_path_val='data/DRIVE/val_av.csv', batch_size=4,
                                                 tg_size=(512,512), label_values=[0, 85, 170, 255], 
                                                 num_workers=8)

In [11]:
def run_one_epoch(loader, model, criterion, optimizer=None, scheduler=None,
                  grad_acc_steps=0, assess=False, save_plot=False, cycle=0):
    device='cuda' if next(model.parameters()).is_cuda else 'cpu'
    train = optimizer is not None  # if we are in training mode there will be an optimizer and train=True here

    if train: model.train()
    else: model.eval()
        
    if assess: dice_bck, dice_arteries, dice_veins, f1_scs, mcc_scs = 0, 0, 0, [], []
    n_elems, running_loss = 0, 0
    wnet=False
    for i_batch, batch_data in enumerate(loader):
        try:
            inputs, labels = (batch_data["img"].to(device), batch_data["seg"].to(device), )
        except:
            inputs, labels = batch_data[0].to(device), batch_data[1].unsqueeze(dim=1).to(device)
            
            
        if train:  # only in training mode               
            logits = model(inputs)
            if isinstance(logits, tuple): # wnet
                wnet=True
                logits_aux, logits = logits
                labels_aux = labels!=0
                loss_aux = torch.nn.BCEWithLogitsLoss()(logits_aux, labels_aux.float())
                
            loss = criterion(logits, labels.squeeze())
            if wnet:
                loss+=loss_aux
                
            (loss / (grad_acc_steps + 1)).backward()
            if i_batch % (grad_acc_steps+1) == 0:  # for grad_acc_steps=0, this is always True
                optimizer.step()
                for _ in range(grad_acc_steps+1): scheduler.step() # for grad_acc_steps=0, this means once
                optimizer.zero_grad()
        
        else:
            logits = model(inputs)
            if isinstance(logits, tuple): # wnet
                wnet=True
                logits_aux, logits = logits
                labels_aux = labels!=0
                loss_aux = torch.nn.BCEWithLogitsLoss()(logits_aux, labels_aux.float())

            loss = criterion(logits, labels.squeeze(dim=1))

            if wnet:
                loss+=loss_aux

            if assess:
                dice_bck, dice_unc, dice_arteries, dice_veins = dice_metric(logits, labels).mean(dim=0)
                if save_plot:
                    for j in range(logits.shape[0]):
                        preds = torch.argmax(logits,dim=1)
                        im_name = batch_data['img_meta_dict']['filename_or_obj'][j].split('/')[-1].split('.')[-2]
                        s_name = 'logs/displays/{}_cycle_{}.png'.format(im_name, cycle)
                        f=imshow_pair(preds[j].cpu(), labels[j].squeeze().cpu())
#                         f.savefig(s_name)
#                         plt.close(f)
                
                f1_s, mcc_s = evaluate(logits.detach().cpu(), labels.cpu())
                f1_scs.append(f1_s)
                mcc_scs.append(mcc_s)
        # Compute running loss
        running_loss += loss.item() * inputs.size(0)
        n_elems += inputs.size(0)
        run_loss = running_loss / n_elems
            
    if assess: return dice_bck, dice_arteries, dice_veins, \
                      np.array(f1_scs).mean(), np.array(mcc_scs).mean(), run_loss
    return None, None, None, None, None, run_loss

In [12]:
def train_one_cycle(train_loader, model, criterion, optimizer=None, scheduler=None, grad_acc_steps=0, cycle=0):
    # prepare next cycle:
    # reset iteration counter
    scheduler.last_epoch = -1
    # update number of iterations

    scheduler.T_max = scheduler.cycle_lens[cycle] * len(train_loader)
    
    model.train()
    optimizer.zero_grad()
    cycle_len = scheduler.cycle_lens[cycle]
    with trange(cycle_len) as t:
        for epoch in range(cycle_len):
            if epoch == cycle_len-1: assess=True # only compute performance on last epoch
            else: assess = False
                
            d_bck, d_arts, d_veins, \
            f1_sc, mcc_sc, tr_loss = run_one_epoch(train_loader, model, criterion, optimizer=optimizer,
                                                          scheduler=scheduler, grad_acc_steps=grad_acc_steps, 
                                                          assess=assess, cycle=cycle)
            t.set_postfix_str("Cycle: {}/{} Ep. {}/{} -- tr. loss={:.4f} / lr={:.6f}".format(cycle+1, 
                                                                                    len(scheduler.cycle_lens),
                                                                                    epoch+1, cycle_len,
                                                                                    float(tr_loss), 
                                                                                    get_lr(optimizer)))
            t.update()
    return d_bck, d_arts, d_veins, f1_sc, mcc_sc, tr_loss

# TV LOADERS

In [13]:
n_classes=4

In [22]:
from models.res_unet_adrian import UNet as unet

class Wnet(torch.nn.Module):
    def __init__(self, n_classes=1, in_c=3, layers=(8, 16, 32), conv_bridge=True, shortcut=True, mode='train'):
        super(Wnet, self).__init__()
        self.mode=mode
        self.unet1 = unet(in_c=in_c, n_classes=1, layers=layers, conv_bridge=conv_bridge, shortcut=shortcut)
        self.unet2 = unet(in_c=in_c, n_classes=n_classes, layers=layers, conv_bridge=conv_bridge, shortcut=shortcut)

    def forward(self, x):
        x1 = self.unet1(x)
        # multiply input by vessel predictions
        x2 = self.unet2(torch.mul(x, torch.stack(3*[x1.squeeze(dim=1)], axis=1)))
        if self.mode!='train':
            return x2
        return x1,x2

model = Wnet(in_c=3, n_classes=n_classes, layers=[8,16,32,64])
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
model.to(device);
params

279117

In [23]:
x,y = next(iter(train_loader))
logits = model(x.to(device))
del logits,x,y

In [24]:
cycle_lens = [20, 50]
grad_acc_steps=0
n_cycles = cycle_lens[0]
min_lr = 1e-8

In [25]:
if len(cycle_lens)==2: # handles option of specifying cycles as pair (n_cycles, cycle_len)
    cycle_lens = cycle_lens[0]*[cycle_lens[1]]

In [26]:
optimizer = torch.optim.Adam(model.parameters(), 1e-2)
criterion = torch.nn.CrossEntropyLoss()

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 
                              T_max=cycle_lens[0] * len(train_loader) // (grad_acc_steps + 1), 
                              eta_min=min_lr)
setattr(scheduler, 'cycle_lens', cycle_lens)

In [None]:
for cycle in range(20):
    
    _, _, _, _, _, _ = train_one_cycle(train_loader,model, criterion, optimizer,scheduler,cycle=cycle)

    save_plot = (cycle+1)%5==0
    save_plot=False
    with torch.no_grad():
        tr_d_bck, tr_d_arts, tr_d_veins, tr_f1, tr_mcc, tr_loss = run_one_epoch(train_loader, model, criterion, 
                                                                 optimizer=None, scheduler=None,
                                                                 grad_acc_steps=0, assess=True, 
                                                                 save_plot=save_plot, cycle=cycle)
        
        vl_d_bck, vl_d_arts, vl_d_veins, vl_f1, vl_mcc, vl_loss = run_one_epoch(val_loader, model, criterion, 
                                                                 optimizer=None, scheduler=None,
                                                                 grad_acc_steps=0, assess=True, 
                                                                 save_plot=save_plot, cycle=cycle)

        
        print('Train/Val Loss: {:.4f}/{:.4f} -- '\
              'per-class Train/Val DICE: {:.4f}/{:.4f} | {:.4f}/{:.4f} | {:.4f}/{:.4f}'.format(tr_loss, vl_loss,
                                                                                     tr_d_bck, vl_d_bck,
                                                                                     tr_d_arts, vl_d_arts,
                                                                                     tr_d_veins, vl_d_veins))
        print('Train/Val F1|MCC: {:.4f}/{:.4f} | {:.4f}/{:.4f}'.format(tr_f1, vl_f1, tr_mcc, vl_mcc))        

 98%|█████████▊| 49/50 [00:55<00:01,  1.13s/it, Cycle: 1/20 Ep. 49/50 -- tr. loss=0.3036 / lr=0.000015]Mean of empty slice.
100%|██████████| 50/50 [00:57<00:00,  1.14s/it, Cycle: 1/20 Ep. 50/50 -- tr. loss=0.3140 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.3093/0.2961 -- per-class Train/Val DICE: 0.9756/0.9763 | 0.3717/0.4454 | 0.5393/0.5622
Train/Val F1|MCC: 0.9273/0.9310 | 0.5882/0.6061


100%|██████████| 50/50 [00:56<00:00,  1.13s/it, Cycle: 2/20 Ep. 50/50 -- tr. loss=0.2605 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2649/0.2638 -- per-class Train/Val DICE: 0.9750/0.9778 | 0.5383/0.5507 | 0.6318/0.6495
Train/Val F1|MCC: 0.9402/0.9413 | 0.6603/0.6694


100%|██████████| 50/50 [00:56<00:00,  1.13s/it, Cycle: 3/20 Ep. 50/50 -- tr. loss=0.2604 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2524/0.2466 -- per-class Train/Val DICE: 0.9772/0.9781 | 0.6027/0.5991 | 0.6904/0.6796
Train/Val F1|MCC: 0.9443/0.9456 | 0.6897/0.6916


100%|██████████| 50/50 [00:56<00:00,  1.13s/it, Cycle: 4/20 Ep. 50/50 -- tr. loss=0.2420 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2299/0.2377 -- per-class Train/Val DICE: 0.9797/0.9784 | 0.6765/0.6343 | 0.7412/0.7057
Train/Val F1|MCC: 0.9506/0.9487 | 0.7182/0.7100


100%|██████████| 50/50 [00:56<00:00,  1.13s/it, Cycle: 5/20 Ep. 50/50 -- tr. loss=0.2242 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2430/0.2323 -- per-class Train/Val DICE: 0.9783/0.9786 | 0.6768/0.6543 | 0.7338/0.7163
Train/Val F1|MCC: 0.9480/0.9503 | 0.7167/0.7199


100%|██████████| 50/50 [00:56<00:00,  1.13s/it, Cycle: 6/20 Ep. 50/50 -- tr. loss=0.2285 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2297/0.2298 -- per-class Train/Val DICE: 0.9780/0.9785 | 0.7047/0.6660 | 0.7422/0.7223
Train/Val F1|MCC: 0.9520/0.9512 | 0.7354/0.7248


100%|██████████| 50/50 [00:56<00:00,  1.13s/it, Cycle: 7/20 Ep. 50/50 -- tr. loss=0.2133 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2195/0.2277 -- per-class Train/Val DICE: 0.9779/0.9787 | 0.7289/0.6757 | 0.7662/0.7278
Train/Val F1|MCC: 0.9546/0.9521 | 0.7486/0.7301


100%|██████████| 50/50 [00:56<00:00,  1.13s/it, Cycle: 8/20 Ep. 50/50 -- tr. loss=0.2229 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2195/0.2260 -- per-class Train/Val DICE: 0.9776/0.9788 | 0.6960/0.6754 | 0.7476/0.7289
Train/Val F1|MCC: 0.9548/0.9523 | 0.7567/0.7312


100%|██████████| 50/50 [00:56<00:00,  1.13s/it, Cycle: 9/20 Ep. 50/50 -- tr. loss=0.2095 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2102/0.2225 -- per-class Train/Val DICE: 0.9788/0.9789 | 0.7414/0.6842 | 0.7715/0.7319
Train/Val F1|MCC: 0.9565/0.9529 | 0.7591/0.7342


100%|██████████| 50/50 [00:56<00:00,  1.13s/it, Cycle: 10/20 Ep. 50/50 -- tr. loss=0.2097 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2182/0.2233 -- per-class Train/Val DICE: 0.9804/0.9789 | 0.6991/0.6816 | 0.7317/0.7308
Train/Val F1|MCC: 0.9549/0.9526 | 0.7556/0.7332


100%|██████████| 50/50 [00:56<00:00,  1.13s/it, Cycle: 11/20 Ep. 50/50 -- tr. loss=0.2104 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2038/0.2221 -- per-class Train/Val DICE: 0.9788/0.9789 | 0.7433/0.6891 | 0.7672/0.7341
Train/Val F1|MCC: 0.9579/0.9531 | 0.7590/0.7368


100%|██████████| 50/50 [00:56<00:00,  1.13s/it, Cycle: 12/20 Ep. 50/50 -- tr. loss=0.2032 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2103/0.2203 -- per-class Train/Val DICE: 0.9829/0.9790 | 0.7084/0.6829 | 0.7487/0.7324
Train/Val F1|MCC: 0.9564/0.9529 | 0.7565/0.7351


100%|██████████| 50/50 [00:56<00:00,  1.13s/it, Cycle: 13/20 Ep. 50/50 -- tr. loss=0.2135 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2127/0.2206 -- per-class Train/Val DICE: 0.9757/0.9789 | 0.7409/0.6851 | 0.7646/0.7321
Train/Val F1|MCC: 0.9565/0.9529 | 0.7626/0.7355


100%|██████████| 50/50 [00:56<00:00,  1.13s/it, Cycle: 14/20 Ep. 50/50 -- tr. loss=0.2142 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2033/0.2195 -- per-class Train/Val DICE: 0.9817/0.9790 | 0.7267/0.6824 | 0.7763/0.7331
Train/Val F1|MCC: 0.9584/0.9530 | 0.7687/0.7353


 50%|█████     | 25/50 [00:28<00:28,  1.13s/it, Cycle: 15/20 Ep. 25/50 -- tr. loss=0.2001 / lr=0.005079]

In [None]:
vl_d_bck, vl_d_arts, vl_d_veins, vl_f1, vl_mcc, vl_loss = run_one_epoch(val_loader, model, criterion, 
                                                         optimizer=None, scheduler=None,
                                                         grad_acc_steps=0, assess=True, 
                                                         save_plot=True, cycle=cycle)

In [None]:
for cycle in range(20):
    
    _, _, _, _, _, _ = train_one_cycle(train_loader,model, criterion, optimizer,scheduler,cycle=cycle)

    save_plot = (cycle+1)%5==0
    save_plot=False
    with torch.no_grad():
        tr_d_bck, tr_d_arts, tr_d_veins, tr_f1, tr_mcc, tr_loss = run_one_epoch(train_loader, model, criterion, 
                                                                 optimizer=None, scheduler=None,
                                                                 grad_acc_steps=0, assess=True, 
                                                                 save_plot=save_plot, cycle=cycle)
        
        vl_d_bck, vl_d_arts, vl_d_veins, vl_f1, vl_mcc, vl_loss = run_one_epoch(val_loader, model, criterion, 
                                                                 optimizer=None, scheduler=None,
                                                                 grad_acc_steps=0, assess=True, 
                                                                 save_plot=save_plot, cycle=cycle)

        
        print('Train/Val Loss: {:.4f}/{:.4f} -- '\
              'per-class Train/Val DICE: {:.4f}/{:.4f} | {:.4f}/{:.4f} | {:.4f}/{:.4f}'.format(tr_loss, vl_loss,
                                                                                     tr_d_bck, vl_d_bck,
                                                                                     tr_d_arts, vl_d_arts,
                                                                                     tr_d_veins, vl_d_veins))
        print('Train/Val F1|MCC: {:.4f}/{:.4f} | {:.4f}/{:.4f}'.format(tr_f1, vl_f1, tr_mcc, vl_mcc))        

 98%|█████████▊| 49/50 [00:51<00:01,  1.05s/it, Cycle: 1/20 Ep. 49/50 -- tr. loss=0.3225 / lr=0.000015]Mean of empty slice.
invalid value encountered in double_scalars
100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 1/20 Ep. 50/50 -- tr. loss=0.3189 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.3265/0.3051 -- per-class Train/Val DICE: 0.9778/0.9763 | 0.3692/0.3363 | 0.5620/0.5728
Train/Val F1|MCC: 0.9225/0.9273 | 0.5745/0.5941


100%|██████████| 50/50 [00:51<00:00,  1.04s/it, Cycle: 2/20 Ep. 50/50 -- tr. loss=0.2856 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2862/0.2796 -- per-class Train/Val DICE: 0.9713/0.9776 | 0.4608/0.4627 | 0.5473/0.5937
Train/Val F1|MCC: 0.9314/0.9347 | 0.6185/0.6319


100%|██████████| 50/50 [00:52<00:00,  1.04s/it, Cycle: 3/20 Ep. 50/50 -- tr. loss=0.2781 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2925/0.2644 -- per-class Train/Val DICE: 0.9739/0.9779 | 0.5308/0.5137 | 0.6210/0.6380
Train/Val F1|MCC: 0.9318/0.9402 | 0.6339/0.6615


100%|██████████| 50/50 [00:52<00:00,  1.04s/it, Cycle: 4/20 Ep. 50/50 -- tr. loss=0.2775 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2685/0.2569 -- per-class Train/Val DICE: 0.9712/0.9783 | 0.5264/0.5289 | 0.6125/0.6604
Train/Val F1|MCC: 0.9382/0.9422 | 0.6604/0.6745


100%|██████████| 50/50 [00:51<00:00,  1.04s/it, Cycle: 5/20 Ep. 50/50 -- tr. loss=0.2577 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2676/0.2511 -- per-class Train/Val DICE: 0.9757/0.9784 | 0.5198/0.5659 | 0.6493/0.6711
Train/Val F1|MCC: 0.9395/0.9443 | 0.6725/0.6858


100%|██████████| 50/50 [00:52<00:00,  1.04s/it, Cycle: 6/20 Ep. 50/50 -- tr. loss=0.2521 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2448/0.2470 -- per-class Train/Val DICE: 0.9817/0.9786 | 0.6238/0.5859 | 0.6888/0.6784
Train/Val F1|MCC: 0.9447/0.9457 | 0.6905/0.6934


100%|██████████| 50/50 [00:52<00:00,  1.04s/it, Cycle: 7/20 Ep. 50/50 -- tr. loss=0.2441 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2529/0.2419 -- per-class Train/Val DICE: 0.9777/0.9787 | 0.6360/0.6078 | 0.6797/0.6914
Train/Val F1|MCC: 0.9433/0.9474 | 0.6919/0.7017


100%|██████████| 50/50 [00:52<00:00,  1.04s/it, Cycle: 8/20 Ep. 50/50 -- tr. loss=0.2472 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2359/0.2385 -- per-class Train/Val DICE: 0.9798/0.9791 | 0.6193/0.6140 | 0.7017/0.6941
Train/Val F1|MCC: 0.9469/0.9480 | 0.6988/0.7059


100%|██████████| 50/50 [00:52<00:00,  1.04s/it, Cycle: 9/20 Ep. 50/50 -- tr. loss=0.2328 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2338/0.2381 -- per-class Train/Val DICE: 0.9803/0.9788 | 0.6326/0.6251 | 0.6924/0.7019
Train/Val F1|MCC: 0.9479/0.9487 | 0.7047/0.7107


100%|██████████| 50/50 [00:52<00:00,  1.04s/it, Cycle: 10/20 Ep. 50/50 -- tr. loss=0.2389 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2310/0.2322 -- per-class Train/Val DICE: 0.9785/0.9791 | 0.6696/0.6393 | 0.7086/0.7125
Train/Val F1|MCC: 0.9492/0.9502 | 0.7166/0.7185


100%|██████████| 50/50 [00:52<00:00,  1.04s/it, Cycle: 11/20 Ep. 50/50 -- tr. loss=0.2409 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2418/0.2325 -- per-class Train/Val DICE: 0.9720/0.9791 | 0.6227/0.6426 | 0.6844/0.7115
Train/Val F1|MCC: 0.9464/0.9502 | 0.7117/0.7185


 22%|██▏       | 11/50 [00:12<00:40,  1.04s/it, Cycle: 12/20 Ep. 12/50 -- tr. loss=0.2481 / lr=0.008698]