In [1]:
import os
import os.path as osp
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from skimage.io import imsave, imread
from skimage import img_as_ubyte, img_as_float
import sys
import torch

%load_ext autoreload
%autoreload 2

In [2]:
def imshow_pair(im, gdt, vmin1=None, vmax1=None, vmin2=None, vmax2=None):
    f, ax = plt.subplots(1, 2, figsize=(10,5))
    np_im = np.asarray(im)
    np_gdt = np.asarray(gdt)
    if len(np_im.shape) == 2:
        if vmin1==None:
            ax[0].imshow(np_im, cmap='gray'),  ax[0].axis('off')
        else:
            ax[0].imshow(np_im, cmap='gray', vmin=vmin1, vmax=vmax1),  ax[0].axis('off')
    else:
        ax[0].imshow(np_im),  ax[0].axis('off')
    if len(np_gdt.shape) == 2:
        if vmin2==None:
            ax[1].imshow(np.asarray(gdt), cmap = 'gray', vmin=vmin2, vmax=vmax2), ax[1].axis('off')
        else:
            ax[1].imshow(np.asarray(gdt), cmap = 'gray'), ax[1].axis('off')
    else:
        ax[1].imshow(np.asarray(gdt)), ax[1].axis('off')
    plt.tight_layout()
    return f

In [3]:
from tqdm import trange

In [4]:
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

In [6]:
from monai.metrics import DiceMetric
dice_metric = DiceMetric(mutually_exclusive=True, to_onehot_y=True, reduction='none')

In [7]:
from sklearn.metrics import f1_score
from sklearn.metrics import matthews_corrcoef as mcc

def evaluate(logits, labels):
    all_targets = []
    all_probs_0 = []
    all_probs_1 = []
    all_probs_2 = []
    all_probs_3 = []

    for i in range(len(logits)):
        probs = torch.nn.Softmax(dim=0)(logits[i]).detach().cpu().numpy()
        all_probs_0.extend(probs[0].ravel())
        all_probs_1.extend(probs[1].ravel())
        all_probs_2.extend(probs[2].ravel())
        all_probs_3.extend(probs[3].ravel())

        target = labels[i].numpy()

        all_targets.append(target.ravel())

    all_probs_np = np.stack([all_probs_0, all_probs_1, all_probs_2, all_probs_3], axis=1)
    all_preds_np = np.argmax(all_probs_np, axis=1)
    all_targets_np = np.hstack(all_targets)

    return f1_score(all_targets_np, all_preds_np,average='weighted'), mcc(all_targets_np, all_preds_np)

In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0")
device

device(type='cuda', index=0)

## Loading Datasets

With a data source and transforms defined we can now create a dataset object. The base class for MONAI is `Dataset`, created here to load the image files only:

In [9]:
from utils.get_loaders import get_train_val_loaders

train_loader, val_loader = get_train_val_loaders(csv_path_train='data/DRIVE/train_av.csv', 
                                                 csv_path_val='data/DRIVE/val_av.csv', batch_size=4,
                                                 tg_size=(512,512), label_values=[0, 85, 170, 255], 
                                                 num_workers=8)

In [15]:
def run_one_epoch(loader, model, criterion, optimizer=None, scheduler=None,
                  grad_acc_steps=0, assess=False, save_plot=False, cycle=0):
    device='cuda' if next(model.parameters()).is_cuda else 'cpu'
    train = optimizer is not None  # if we are in training mode there will be an optimizer and train=True here

    if train: model.train()
    else: model.eval()
        
    if assess: dice_bck, dice_arteries, dice_veins, f1_scs, mcc_scs = 0, 0, 0, [], []
    n_elems, running_loss = 0, 0
    wnet=False
    for i_batch, batch_data in enumerate(loader):
        try:
            inputs, labels = (batch_data["img"].to(device), batch_data["seg"].to(device), )
        except:
            inputs, labels = batch_data[0].to(device), batch_data[1].unsqueeze(dim=1).to(device)
            
            
        if train:  # only in training mode               
            logits = model(inputs)
            if isinstance(logits, tuple): # wnet
                wnet=True
                logits_aux, logits = logits
                loss_aux = criterion(logits_aux, labels.squeeze())
            loss = criterion(logits, labels.squeeze())
            if wnet:
                loss+=loss_aux
                
            (loss / (grad_acc_steps + 1)).backward()
            if i_batch % (grad_acc_steps+1) == 0:  # for grad_acc_steps=0, this is always True
                optimizer.step()
                for _ in range(grad_acc_steps+1): scheduler.step() # for grad_acc_steps=0, this means once
                optimizer.zero_grad()
        
        else:
            logits = model(inputs)
            if isinstance(logits, tuple): # wnet
                wnet=True
                logits_aux, logits = logits
                loss_aux = criterion(logits_aux, labels.squeeze())
            loss = criterion(logits, labels.squeeze(dim=1))

            if wnet:
                loss+=loss_aux

            if assess:
                dice_bck, dice_unc, dice_arteries, dice_veins = dice_metric(logits, labels).mean(dim=0)
                if save_plot:
                    for j in range(logits.shape[0]):
                        preds = torch.argmax(logits,dim=1)
                        im_name = batch_data['img_meta_dict']['filename_or_obj'][j].split('/')[-1].split('.')[-2]
                        s_name = 'logs/displays/{}_cycle_{}.png'.format(im_name, cycle)
                        f=imshow_pair(preds[j].cpu(), labels[j].squeeze().cpu())
#                         f.savefig(s_name)
#                         plt.close(f)
                
                f1_s, mcc_s = evaluate(logits.detach().cpu(), labels.cpu())
                f1_scs.append(f1_s)
                mcc_scs.append(mcc_s)
        # Compute running loss
        running_loss += loss.item() * inputs.size(0)
        n_elems += inputs.size(0)
        run_loss = running_loss / n_elems
            
    if assess: return dice_bck, dice_arteries, dice_veins, \
                      np.array(f1_scs).mean(), np.array(mcc_scs).mean(), run_loss
    return None, None, None, None, None, run_loss

In [16]:
def train_one_cycle(train_loader, model, criterion, optimizer=None, scheduler=None, grad_acc_steps=0, cycle=0):
    # prepare next cycle:
    # reset iteration counter
    scheduler.last_epoch = -1
    # update number of iterations

    scheduler.T_max = scheduler.cycle_lens[cycle] * len(train_loader)
    
    model.train()
    optimizer.zero_grad()
    cycle_len = scheduler.cycle_lens[cycle]
    with trange(cycle_len) as t:
        for epoch in range(cycle_len):
            if epoch == cycle_len-1: assess=True # only compute performance on last epoch
            else: assess = False
                
            d_bck, d_arts, d_veins, \
            f1_sc, mcc_sc, tr_loss = run_one_epoch(train_loader, model, criterion, optimizer=optimizer,
                                                          scheduler=scheduler, grad_acc_steps=grad_acc_steps, 
                                                          assess=assess, cycle=cycle)
            t.set_postfix_str("Cycle: {}/{} Ep. {}/{} -- tr. loss={:.4f} / lr={:.6f}".format(cycle+1, 
                                                                                    len(scheduler.cycle_lens),
                                                                                    epoch+1, cycle_len,
                                                                                    float(tr_loss), 
                                                                                    get_lr(optimizer)))
            t.update()
    return d_bck, d_arts, d_veins, f1_sc, mcc_sc, tr_loss

# TV LOADERS

In [18]:
n_classes=4

In [19]:
from models.res_unet_adrian import UNet as unet

class Wnet(torch.nn.Module):
    def __init__(self, n_classes=1, in_c=3, layers=(8, 16, 32), conv_bridge=True, shortcut=True, mode='train'):
        super(Wnet, self).__init__()
        self.mode=mode
        self.unet1 = unet(in_c=in_c, n_classes=n_classes, layers=layers, conv_bridge=conv_bridge, shortcut=shortcut)
        self.unet2 = unet(in_c=in_c+n_classes, n_classes=n_classes, layers=layers, conv_bridge=conv_bridge, shortcut=shortcut)

    def forward(self, x):
        x1 = self.unet1(x)
        x2 = self.unet2(torch.cat([x, x1], dim=1))
        if self.mode!='train':
            return x2
        return x1,x2

model = Wnet(in_c=3, n_classes=n_classes, layers=[8,16,32])
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
model.to(device);
params

68776

In [20]:
cycle_lens = [20, 50]
grad_acc_steps=0
n_cycles = cycle_lens[0]
min_lr = 1e-8

In [21]:
if len(cycle_lens)==2: # handles option of specifying cycles as pair (n_cycles, cycle_len)
    cycle_lens = cycle_lens[0]*[cycle_lens[1]]

In [23]:
optimizer = torch.optim.Adam(model.parameters(), 1e-2)
criterion = torch.nn.CrossEntropyLoss()

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 
                              T_max=cycle_lens[0] * len(train_loader) // (grad_acc_steps + 1), 
                              eta_min=min_lr)
setattr(scheduler, 'cycle_lens', cycle_lens)

In [24]:
for cycle in range(20):
    
    _, _, _, _, _, _ = train_one_cycle(train_loader,model, criterion, optimizer,scheduler,cycle=cycle)

    save_plot = (cycle+1)%5==0
    save_plot=False
    with torch.no_grad():
        tr_d_bck, tr_d_arts, tr_d_veins, tr_f1, tr_mcc, tr_loss = run_one_epoch(train_loader, model, criterion, 
                                                                 optimizer=None, scheduler=None,
                                                                 grad_acc_steps=0, assess=True, 
                                                                 save_plot=save_plot, cycle=cycle)
        
        vl_d_bck, vl_d_arts, vl_d_veins, vl_f1, vl_mcc, vl_loss = run_one_epoch(val_loader, model, criterion, 
                                                                 optimizer=None, scheduler=None,
                                                                 grad_acc_steps=0, assess=True, 
                                                                 save_plot=save_plot, cycle=cycle)

        
        print('Train/Val Loss: {:.4f}/{:.4f} -- '\
              'per-class Train/Val DICE: {:.4f}/{:.4f} | {:.4f}/{:.4f} | {:.4f}/{:.4f}'.format(tr_loss, vl_loss,
                                                                                     tr_d_bck, vl_d_bck,
                                                                                     tr_d_arts, vl_d_arts,
                                                                                     tr_d_veins, vl_d_veins))
        print('Train/Val F1|MCC: {:.4f}/{:.4f} | {:.4f}/{:.4f}'.format(tr_f1, vl_f1, tr_mcc, vl_mcc))        

 98%|█████████▊| 49/50 [00:51<00:01,  1.04s/it, Cycle: 1/20 Ep. 49/50 -- tr. loss=0.4306 / lr=0.000015]Mean of empty slice.
invalid value encountered in double_scalars
100%|██████████| 50/50 [00:53<00:00,  1.06s/it, Cycle: 1/20 Ep. 50/50 -- tr. loss=0.4272 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.4391/0.3960 -- per-class Train/Val DICE: 0.9667/0.9749 | 0.0000/0.0000 | 0.5838/0.5859
Train/Val F1|MCC: 0.9051/0.9133 | 0.5334/0.5634


100%|██████████| 50/50 [00:52<00:00,  1.04s/it, Cycle: 2/20 Ep. 50/50 -- tr. loss=0.3776 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.3818/0.3619 -- per-class Train/Val DICE: 0.9735/0.9773 | 0.3946/0.4435 | 0.5802/0.5912
Train/Val F1|MCC: 0.9276/0.9335 | 0.6016/0.6242


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 3/20 Ep. 50/50 -- tr. loss=0.3439 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.3675/0.3369 -- per-class Train/Val DICE: 0.9794/0.9782 | 0.5161/0.4978 | 0.5998/0.6298
Train/Val F1|MCC: 0.9308/0.9389 | 0.6309/0.6558


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 4/20 Ep. 50/50 -- tr. loss=0.3223 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.3248/0.3187 -- per-class Train/Val DICE: 0.9780/0.9780 | 0.5565/0.5619 | 0.6752/0.6716
Train/Val F1|MCC: 0.9417/0.9437 | 0.6731/0.6837


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 5/20 Ep. 50/50 -- tr. loss=0.3182 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.3140/0.3038 -- per-class Train/Val DICE: 0.9800/0.9784 | 0.5400/0.6019 | 0.6782/0.6958
Train/Val F1|MCC: 0.9435/0.9469 | 0.6828/0.7010


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 6/20 Ep. 50/50 -- tr. loss=0.3165 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.3121/0.2969 -- per-class Train/Val DICE: 0.9750/0.9783 | 0.5893/0.6305 | 0.6774/0.7066
Train/Val F1|MCC: 0.9443/0.9487 | 0.6889/0.7119


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 7/20 Ep. 50/50 -- tr. loss=0.2959 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2942/0.2906 -- per-class Train/Val DICE: 0.9776/0.9784 | 0.6209/0.6452 | 0.6979/0.7100
Train/Val F1|MCC: 0.9480/0.9496 | 0.7059/0.7154


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 8/20 Ep. 50/50 -- tr. loss=0.3061 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.3045/0.2840 -- per-class Train/Val DICE: 0.9785/0.9786 | 0.6202/0.6548 | 0.7001/0.7157
Train/Val F1|MCC: 0.9463/0.9505 | 0.7083/0.7199


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 9/20 Ep. 50/50 -- tr. loss=0.2937 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.3058/0.2820 -- per-class Train/Val DICE: 0.9775/0.9787 | 0.6405/0.6555 | 0.7007/0.7173
Train/Val F1|MCC: 0.9457/0.9506 | 0.7045/0.7214


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 10/20 Ep. 50/50 -- tr. loss=0.2857 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2893/0.2809 -- per-class Train/Val DICE: 0.9788/0.9785 | 0.6525/0.6582 | 0.7107/0.7169
Train/Val F1|MCC: 0.9493/0.9506 | 0.7140/0.7219


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 11/20 Ep. 50/50 -- tr. loss=0.2815 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.3073/0.2751 -- per-class Train/Val DICE: 0.9759/0.9788 | 0.6805/0.6701 | 0.7104/0.7268
Train/Val F1|MCC: 0.9462/0.9519 | 0.7201/0.7285


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 12/20 Ep. 50/50 -- tr. loss=0.2758 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2767/0.2727 -- per-class Train/Val DICE: 0.9822/0.9787 | 0.6803/0.6745 | 0.7016/0.7284
Train/Val F1|MCC: 0.9513/0.9521 | 0.7206/0.7298


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 13/20 Ep. 50/50 -- tr. loss=0.2841 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2801/0.2719 -- per-class Train/Val DICE: 0.9808/0.9788 | 0.6169/0.6775 | 0.7051/0.7266
Train/Val F1|MCC: 0.9507/0.9522 | 0.7213/0.7304


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 14/20 Ep. 50/50 -- tr. loss=0.2698 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2867/0.2719 -- per-class Train/Val DICE: 0.9789/0.9784 | 0.6937/0.6710 | 0.7468/0.7270
Train/Val F1|MCC: 0.9496/0.9516 | 0.7267/0.7273


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 15/20 Ep. 50/50 -- tr. loss=0.2945 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2669/0.2694 -- per-class Train/Val DICE: 0.9783/0.9785 | 0.7037/0.6828 | 0.7582/0.7313
Train/Val F1|MCC: 0.9534/0.9524 | 0.7400/0.7323


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 16/20 Ep. 50/50 -- tr. loss=0.2807 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2742/0.2682 -- per-class Train/Val DICE: 0.9810/0.9787 | 0.6791/0.6796 | 0.7280/0.7342
Train/Val F1|MCC: 0.9522/0.9526 | 0.7364/0.7338


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 17/20 Ep. 50/50 -- tr. loss=0.2727 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2704/0.2700 -- per-class Train/Val DICE: 0.9758/0.9784 | 0.6679/0.6788 | 0.7296/0.7293
Train/Val F1|MCC: 0.9528/0.9520 | 0.7420/0.7306


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 18/20 Ep. 50/50 -- tr. loss=0.2663 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2710/0.2666 -- per-class Train/Val DICE: 0.9784/0.9785 | 0.7200/0.6809 | 0.7525/0.7343
Train/Val F1|MCC: 0.9530/0.9525 | 0.7449/0.7326


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 19/20 Ep. 50/50 -- tr. loss=0.2685 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.2735/0.2615 -- per-class Train/Val DICE: 0.9768/0.9789 | 0.6702/0.6844 | 0.7362/0.7381
Train/Val F1|MCC: 0.9520/0.9532 | 0.7277/0.7359


100%|██████████| 50/50 [00:52<00:00,  1.05s/it, Cycle: 20/20 Ep. 50/50 -- tr. loss=0.2664 / lr=0.000001]


Train/Val Loss: 0.2669/0.2655 -- per-class Train/Val DICE: 0.9789/0.9788 | 0.7134/0.6797 | 0.7437/0.7336
Train/Val F1|MCC: 0.9534/0.9527 | 0.7463/0.7338
