In [1]:
import os
import os.path as osp
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from skimage.io import imsave, imread
from skimage import img_as_ubyte, img_as_float
import sys
import torch

%load_ext autoreload
%autoreload 2

In [2]:
def imshow_pair(im, gdt, vmin1=None, vmax1=None, vmin2=None, vmax2=None):
    """Show two images side by side and return the figure.

    Args:
        im: Left image; anything `np.asarray` accepts.
        gdt: Right image; anything `np.asarray` accepts.
        vmin1, vmax1: Optional intensity limits for the left panel.
        vmin2, vmax2: Optional intensity limits for the right panel.

    Returns:
        The matplotlib figure holding both panels.

    2-D arrays are drawn with the gray colormap and the given limits; a limit
    left as None lets matplotlib autoscale that end of the range. Arrays of
    any other rank are drawn with matplotlib's defaults and the limits are
    not applied to them.
    """
    f, ax = plt.subplots(1, 2, figsize=(10,5))
    panels = (
        (np.asarray(im), vmin1, vmax1),
        (np.asarray(gdt), vmin2, vmax2),
    )
    for axis, (arr, vmin, vmax) in zip(ax, panels):
        if arr.ndim == 2:
            # Passing None for vmin/vmax is the same as omitting them, so a
            # single call covers both the "limits given" and "autoscale" cases.
            axis.imshow(arr, cmap='gray', vmin=vmin, vmax=vmax)
        else:
            axis.imshow(arr)
        axis.axis('off')
    plt.tight_layout()
    return f

In [3]:
from tqdm import trange
from skimage.color import label2rgb

In [4]:
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']

In [5]:
from monai.metrics import DiceMetric
# Module-level Dice metric; run_one_epoch reads it as a global when assess=True.
# NOTE(review): `mutually_exclusive` / `to_onehot_y` look like constructor arguments from an
# older MONAI release — confirm they are accepted by the installed MONAI version.
dice_metric = DiceMetric(mutually_exclusive=True, to_onehot_y=True, reduction='none', include_background=False)

In [6]:
from sklearn.metrics import f1_score
from sklearn.metrics import matthews_corrcoef as mcc

def evaluate(logits, labels):
    """Compute weighted F1 and MCC of the arg-max predictions over a batch.

    Args:
        logits: Sequence (or batched tensor) of per-sample logits with the
            class channel first. The metrics below assume three class
            channels, which are mapped to labels 1, 2 and 3.
        labels: Sequence (or batched tensor) of integer label maps, one per
            sample, where 0 marks background pixels.

    Returns:
        A tuple `(f1, mcc)`: the weighted F1 score over labels 1-3 and the
        Matthews correlation coefficient over non-background pixels only.

    Background pixels (label 0) are forced to prediction 0, so they never
    count as errors for the foreground classes.
    """
    pred_chunks, target_chunks = [], []
    for sample_logits, sample_labels in zip(logits, labels):
        probs = torch.softmax(sample_logits, dim=0).detach().cpu().numpy()
        # Arg-max over the channel axis, done in NumPy instead of copying
        # every pixel probability into Python lists one scalar at a time.
        pred_chunks.append(probs.argmax(axis=0).ravel())
        target_chunks.append(sample_labels.detach().cpu().numpy().ravel())

    all_targets_np = np.concatenate(target_chunks)
    # The network predicts only the three foreground classes, so shift the
    # channel index by one to line up with labels 1..3.
    all_preds_np = 1 + np.concatenate(pred_chunks)
    all_preds_np[all_targets_np == 0] = 0

    foreground = all_targets_np != 0
    return (f1_score(all_targets_np, all_preds_np, average='weighted', labels=[1, 2, 3]),
            mcc(all_targets_np[foreground], all_preds_np[foreground]))

In [7]:
# Use the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0")
device

device(type='cuda', index=0)

## Loading Datasets

The training and validation data loaders are built from the DRIVE CSV splits by the project helper `get_train_val_loaders`:

In [8]:
from utils.get_loaders import get_train_val_loaders

# Build the train/validation loaders from the DRIVE CSV splits.
# NOTE(review): `label_values` presumably lists the grey levels in the label images that the
# helper maps to class indices 0..3 — confirm in utils/get_loaders.py.
train_loader, val_loader = get_train_val_loaders(csv_path_train='data/DRIVE/train_av.csv', 
                                                 csv_path_val='data/DRIVE/val_av.csv', batch_size=4,
                                                 tg_size=(512,512), label_values=[0, 85, 170, 255], 
                                                 num_workers=8)

In [9]:
from utils.dice_loss import SimilarityLoss
# NOTE(review): `sim_loss` is not referenced by any other visible cell — confirm it is still needed.
sim_loss = SimilarityLoss()

In [10]:
class TvLoss(torch.nn.Module):
    """Total-variation style penalty on class probabilities inside each labelled region.

    For every class channel the softmax probability is kept only where the
    label equals that class, and the mean absolute difference to the four
    direct neighbours is taken. Positions where a perfect (one-hot)
    prediction would itself have non-zero variation, i.e. label boundaries,
    are excluded. The rest is divided by the predicted probability and
    averaged over the pixels of each class.

    Args:
        ignore_background: Stored on the instance but not used by the
            computation.
        reduction: One of
            'mean'               - scalar, mean over classes 2 and above;
            'per_class'          - one value per class from class 2 on;
            'per_elem_per_class' - (batch, classes from 2 on);
            'none'               - the full per-pixel map.

    Expected inputs (as implied by the dim=1 concatenation and the
    element-wise product with the probabilities): `logits` of shape
    (B, C, H, W) and integer `labels` of shape (B, 1, H, W) with values in
    0..C-1.
    """

    def __init__(self, ignore_background=True, reduction='mean'):
        super(TvLoss, self).__init__()
        self.reduction = reduction
        self.ignore_background = ignore_background

    @staticmethod
    def _one_hot(labels, n_classes):
        """Stack the boolean masks `labels == c` for c in 0..n_classes-1 along dim 1."""
        return torch.cat([labels == c for c in range(n_classes)], dim=1)

    def compute_tv(self, logits, labels):
        """Return the per-pixel, per-class variation of the label-masked probabilities."""
        probs = torch.softmax(logits, dim=1)
        labels_oh = self._one_hot(labels, logits.shape[1]).long()

        # Keep each class probability only inside that class' labelled region.
        probs = probs * labels_oh

        # Absolute difference to the left/right/up/down neighbour. torch.roll
        # wraps around, so border pixels are compared with the opposite edge.
        neighbour_shifts = ((1, -1), (-1, -1), (-1, -2), (1, -2))
        diffs = [torch.abs(probs - torch.roll(probs, shifts=shift, dims=dim))
                 for shift, dim in neighbour_shifts]

        return torch.stack(diffs, dim=0).mean(dim=0)

    def forward(self, logits, labels):
        """Compute the loss for `logits` against `labels` with the configured reduction.

        Raises:
            ValueError: If `self.reduction` is not one of the supported names.
        """
        probs = torch.softmax(logits, dim=1)
        labels_oh = self._one_hot(labels, logits.shape[1]).float()

        tv = self.compute_tv(logits, labels)

        # A saturated one-hot "prediction" only varies at label boundaries;
        # those positions are unavoidable and therefore not penalised.
        perfect_tv = self.compute_tv(100 * labels_oh, labels) > 0
        tv = tv.masked_fill(perfect_tv, 0)

        tv = tv / (probs + 1e-6)

        # Average over the pixels that belong to each class (epsilon guards
        # classes that are absent from a sample).
        mean_per_elem_per_class = tv.sum(dim=(-2, -1)) / (labels_oh.sum(dim=(-2, -1)) + 1e-6)
        mean_per_class = mean_per_elem_per_class.mean(dim=0)

        if self.reduction == 'mean':
            return mean_per_class[2:].mean()
        if self.reduction == 'per_class':
            return mean_per_class[2:]
        if self.reduction == 'per_elem_per_class':
            return mean_per_elem_per_class[:, 2:]
        if self.reduction == 'none':
            return tv
        raise ValueError("unknown reduction {!r}".format(self.reduction))

In [11]:
# Module-level TV loss instance; run_one_epoch reads it as a global rather than taking it as an argument.
tv_criterion = TvLoss(reduction='mean')

In [26]:
def run_one_epoch(loader, model, criterion, optimizer=None, scheduler=None,
                  grad_acc_steps=0, assess=False, save_plot=False, cycle=0):
    """Run one pass over `loader`, training when an optimizer is given and evaluating otherwise.

    The model predicts the foreground classes only; a constant background
    channel (logit -10) is prepended before the losses and the Dice metric
    are computed. Relies on the module-level `tv_criterion`, `dice_metric`,
    `evaluate` and `imshow_pair`.

    Args:
        loader: Iterable of batches, either dicts with "img"/"seg" entries or
            (image, label) sequences whose labels lack the channel dimension.
        model: Network returning logits, or an (auxiliary, final) logits tuple.
        criterion: Segmentation loss called as criterion(logits, targets).
        optimizer: If not None the epoch runs in training mode.
        scheduler: Optional LR scheduler, stepped once per processed batch.
        grad_acc_steps: Extra batches to accumulate before each optimizer step.
        assess: When True and in evaluation mode, also compute Dice, F1 and MCC.
            Metrics are not computed in training mode.
        save_plot: When assessing, also draw prediction/label pairs.
        cycle: Unused; kept for interface compatibility.

    Returns:
        `(dice_unc, dice_arteries, dice_veins, f1, mcc, loss, tv_loss)`.
        The first five are None when `assess` is False. When `assess` is True
        but no batch was assessed (training mode or empty loader) the Dice
        values are 0 and F1/MCC are NaN. The two losses are sample-weighted
        running means (NaN for an empty loader); `tv_loss` is 10x the
        `tv_criterion` value.
    """
    # Follow the model's own device so models on e.g. cuda:1 work too.
    device = next(model.parameters()).device
    train = optimizer is not None  # an optimizer means we are training

    if train:
        model.train()
    else:
        model.eval()

    acc_every = grad_acc_steps + 1
    pending = 0  # batches whose gradients have not been applied yet
    dice_sum, n_assessed = None, 0
    f1_scs, mcc_scs = [], []
    n_elems, running_loss, tv_running_loss = 0, 0, 0
    run_loss = tv_run_loss = float('nan')

    for batch_data in loader:
        try:
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
        except (TypeError, KeyError, IndexError):
            # Sequence-style batch: labels come without the channel dimension.
            inputs, labels = batch_data[0].to(device), batch_data[1].unsqueeze(dim=1).to(device)
        batch_size = inputs.size(0)

        outputs = model(inputs)
        if isinstance(outputs, tuple):  # wnet: (auxiliary, final)
            logits_aux, logits = outputs
        else:
            logits_aux, logits = None, outputs

        background = torch.full(labels.shape, -10.0, dtype=logits.dtype, device=logits.device)
        padded = torch.cat([background, logits], dim=1)
        # Drop only the channel dimension: a bare squeeze() would also drop
        # the batch dimension when the batch holds a single sample.
        targets = labels.squeeze(dim=1)

        loss = criterion(padded, targets)
        tv_loss = 10 * tv_criterion(padded, labels)
        if logits_aux is not None:
            loss = loss + criterion(torch.cat([background, logits_aux], dim=1), targets)

        if train:
            ((loss + 0.1 * tv_loss) / acc_every).backward()
            pending += 1
            # Step once `acc_every` batches have been accumulated
            # (every batch when grad_acc_steps=0).
            if pending == acc_every:
                optimizer.step()
                if scheduler is not None:
                    for _ in range(pending):
                        scheduler.step()
                optimizer.zero_grad()
                pending = 0
        elif assess:
            batch_dice = dice_metric(padded, labels).mean(dim=0)
            # Accumulate a sample-weighted sum so the epoch value is an
            # average over all batches, not just the last one.
            dice_sum = batch_dice * batch_size if dice_sum is None else dice_sum + batch_dice * batch_size
            n_assessed += batch_size

            if save_plot:
                from skimage.color import label2rgb
                preds = torch.argmax(logits, dim=1) + 1
                preds[(labels == 0).squeeze(dim=1)] = 0  # background stays class 0
                palette = ['black', 'green', 'red', 'blue']
                for j in range(logits.shape[0]):
                    rgb_pred = label2rgb(preds[j].cpu().numpy(), colors=palette)
                    rgb_labels = label2rgb(labels[j].squeeze().cpu().numpy(), colors=palette)
                    imshow_pair(rgb_pred, rgb_labels)

            f1_s, mcc_s = evaluate(logits.detach().cpu(), labels.cpu())
            f1_scs.append(f1_s)
            mcc_scs.append(mcc_s)

        # Sample-weighted running means of both losses.
        running_loss += loss.item() * batch_size
        tv_running_loss += tv_loss.item() * batch_size
        n_elems += batch_size
        run_loss = running_loss / n_elems
        tv_run_loss = tv_running_loss / n_elems

    if train and pending:
        # Apply gradients left over when the number of batches is not a
        # multiple of `acc_every`, so they do not leak into the next epoch.
        optimizer.step()
        if scheduler is not None:
            for _ in range(pending):
                scheduler.step()
        optimizer.zero_grad()

    if not assess:
        return None, None, None, None, None, run_loss, tv_run_loss

    if dice_sum is None:
        dice_unc, dice_arteries, dice_veins = 0, 0, 0
    else:
        dice_unc, dice_arteries, dice_veins = dice_sum / n_assessed
    # Avoid NumPy's "Mean of empty slice" warning when nothing was assessed.
    f1_mean = np.array(f1_scs).mean() if f1_scs else np.nan
    mcc_mean = np.array(mcc_scs).mean() if mcc_scs else np.nan
    return dice_unc, dice_arteries, dice_veins, f1_mean, mcc_mean, run_loss, tv_run_loss

In [13]:
def train_one_cycle(train_loader, model, criterion, optimizer=None, scheduler=None, grad_acc_steps=0, cycle=0):
    """Train for one annealing cycle and return the results of its last epoch.

    The scheduler is rewound (`last_epoch = -1`) and its `T_max` is set to
    the number of batches in this cycle before training starts. The cycle
    length is read from `scheduler.cycle_lens[cycle]`. `assess=True` is
    passed to `run_one_epoch` only for the final epoch of the cycle.

    Returns:
        The seven values returned by `run_one_epoch` for the final epoch.
    """
    # Rewind the schedule and size it for this cycle.
    scheduler.last_epoch = -1
    n_epochs = scheduler.cycle_lens[cycle]
    scheduler.T_max = n_epochs * len(train_loader)

    model.train()
    optimizer.zero_grad()

    status_fmt = "Cycle: {}/{} Ep. {}/{} -- tr. loss={:.4f}|{:.4f} / lr={:.6f}"
    with trange(n_epochs) as progress:
        for ep in range(n_epochs):
            is_last_epoch = (ep == n_epochs - 1)
            epoch_out = run_one_epoch(train_loader, model, criterion, optimizer=optimizer,
                                      scheduler=scheduler, grad_acc_steps=grad_acc_steps,
                                      assess=is_last_epoch, cycle=cycle)
            d_bck, d_arts, d_veins, f1_sc, mcc_sc, tr_loss, tr_tv_loss = epoch_out
            status = status_fmt.format(cycle + 1, len(scheduler.cycle_lens),
                                       ep + 1, n_epochs,
                                       float(tr_loss), float(tr_tv_loss),
                                       get_lr(optimizer))
            progress.set_postfix_str(status)
            progress.update()
    return d_bck, d_arts, d_veins, f1_sc, mcc_sc, tr_loss, tr_tv_loss

# TV LOADERS

In [27]:
# Number of channels the network predicts. run_one_epoch prepends a constant background
# channel, so the losses see n_classes + 1 channels.
n_classes=3

In [28]:
from models.res_unet_adrian import UNet as unet

class Wnet(torch.nn.Module):
    """Two chained U-Nets: the second sees the input concatenated with the first one's output.

    With mode 'train' the forward pass returns both outputs as
    (first, second); with any other mode it returns only the second.
    """

    def __init__(self, n_classes=1, in_c=3, layers=(8, 16, 32), conv_bridge=True, shortcut=True, mode='train'):
        super().__init__()
        shared_kwargs = dict(n_classes=n_classes, layers=layers,
                             conv_bridge=conv_bridge, shortcut=shortcut)
        self.unet1 = unet(in_c=in_c, **shared_kwargs)
        # The second network takes the image channels plus the first network's output channels.
        self.unet2 = unet(in_c=in_c + n_classes, **shared_kwargs)
        self.mode = mode

    def forward(self, x):
        first_out = self.unet1(x)
        second_out = self.unet2(torch.cat([x, first_out], dim=1))
        if self.mode == 'train':
            return first_out, second_out
        return second_out

# Instantiate the W-Net and count its trainable parameters.
model = Wnet(in_c=3, n_classes=n_classes, layers=[8,16,32])
# `filter` returns a one-shot iterator; it is consumed by the sum on the next line.
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
model.to(device);
params

68678

In [29]:
# from utils.model_saving_loading import load_model
# model, stats = load_model(model, 'experiments/BASELINE/', 'cpu')

In [30]:
# Schedule configuration.
cycle_lens = [20, 50]  # pair form (n_cycles, cycle_len); expanded to a per-cycle list in a later cell
grad_acc_steps=0  # extra batches to accumulate before each optimizer step
n_cycles = cycle_lens[0]  # NOTE(review): not referenced by any other visible cell
min_lr = 1e-8  # passed as eta_min to the cosine schedule

In [31]:
# bb = next(iter(val_loader))

In [32]:
# x, y = bb[0], bb[1]

In [33]:
# logits = model(x.to(device))

In [34]:
# logits_aux, logits = logits

In [35]:
# Expand the (n_cycles, cycle_len) pair into one cycle length per cycle, e.g. [20, 50] -> 20 * [50].
if len(cycle_lens)==2: # handles option of specifying cycles as pair (n_cycles, cycle_len)
    cycle_lens = cycle_lens[0]*[cycle_lens[1]]

In [36]:
# Adam with an initial learning rate of 1e-2.
optimizer = torch.optim.Adam(model.parameters(), 1e-2)
# ignore_index=0: pixels labelled 0 (background) do not contribute to the cross-entropy.
criterion = torch.nn.CrossEntropyLoss(ignore_index=0)

# Cosine annealing over the first cycle; train_one_cycle rewinds the scheduler and
# overwrites T_max at the start of every cycle.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 
                              T_max=cycle_lens[0] * len(train_loader) // (grad_acc_steps + 1), 
                              eta_min=min_lr)
# train_one_cycle reads the per-cycle lengths from this attribute.
setattr(scheduler, 'cycle_lens', cycle_lens)

# With TV

In [None]:
# Train for 10 cycles (scheduler.cycle_lens holds more entries than that) and, after each
# cycle, evaluate on both the training and the validation loader.
for cycle in range(10):

    # The per-cycle training metrics are discarded; only the evaluation below is reported.
    _, _, _, _, _, _, _ = train_one_cycle(train_loader,model, criterion, optimizer,scheduler,cycle=cycle)

    # NOTE(review): the "every 5th cycle" setting is immediately overridden, so plotting is always off.
    save_plot = (cycle+1)%5==0
    save_plot=False
    with torch.no_grad():
        # optimizer=None puts run_one_epoch in evaluation mode.
        _, _, _, tr_f1, tr_mcc, tr_loss, tr_tv_loss = run_one_epoch(train_loader, model, criterion, 
                                                                 optimizer=None, scheduler=None,
                                                                 grad_acc_steps=0, assess=True, 
                                                                 save_plot=save_plot, cycle=cycle)
        
        _, _, _, vl_f1, vl_mcc, vl_loss, vl_tv_loss = run_one_epoch(val_loader, model, criterion, 
                                                                 optimizer=None, scheduler=None,
                                                                 grad_acc_steps=0, assess=True, 
                                                                 save_plot=save_plot, cycle=cycle)

        
        print('Train/Val Loss: {:.4f}/{:.4f} -- '\
              'Train/Val TV Loss: {:.4f}/{:.4f}'.format(tr_loss, vl_loss, tr_tv_loss, vl_tv_loss))
        print('Train/Val F1|MCC: {:.4f}/{:.4f} | {:.4f}/{:.4f}'.format(tr_f1, vl_f1, tr_mcc, vl_mcc))        

 98%|█████████▊| 49/50 [00:54<00:01,  1.11s/it, Cycle: 1/20 Ep. 49/50 -- tr. loss=0.7305|0.1039 / lr=0.000015]Mean of empty slice.
100%|██████████| 50/50 [00:55<00:00,  1.12s/it, Cycle: 1/20 Ep. 50/50 -- tr. loss=0.7124|0.0955 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.7550/0.8098 -- Train/Val TV Loss: 0.1010/0.1276
Train/Val F1|MCC: 0.8635/0.8461 | 0.7363/0.6992


100%|██████████| 50/50 [00:55<00:00,  1.11s/it, Cycle: 2/20 Ep. 50/50 -- tr. loss=0.7048|0.0942 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.7452/0.7951 -- Train/Val TV Loss: 0.1018/0.1197
Train/Val F1|MCC: 0.8697/0.8502 | 0.7486/0.7084


100%|██████████| 50/50 [00:55<00:00,  1.12s/it, Cycle: 3/20 Ep. 50/50 -- tr. loss=0.6708|0.0832 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.6913/0.7891 -- Train/Val TV Loss: 0.0894/0.1208
Train/Val F1|MCC: 0.8838/0.8528 | 0.7752/0.7128


100%|██████████| 50/50 [00:55<00:00,  1.12s/it, Cycle: 4/20 Ep. 50/50 -- tr. loss=0.6952|0.0929 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.7814/0.8404 -- Train/Val TV Loss: 0.1095/0.1312
Train/Val F1|MCC: 0.8631/0.8385 | 0.7357/0.6849


 24%|██▍       | 12/50 [00:13<00:42,  1.11s/it, Cycle: 5/20 Ep. 12/50 -- tr. loss=0.7390|0.1089 / lr=0.008698]

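# Without TV

Note: the cell below is identical to the "With TV" cell. The TV term is only removed from training by switching `run_one_epoch` to its commented-out `backward()` line (the one without `tv_loss`) and re-running that cell first; the model should also be re-created so the two runs start from the same state.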

In [25]:
# Same driver as the "With TV" cell: train 10 cycles and evaluate on train/val after each one.
# NOTE(review): nothing in this cell disables the TV term; whether it is used depends on which
# `backward()` line is active in run_one_epoch when this cell runs.
for cycle in range(10):

    # The per-cycle training metrics are discarded; only the evaluation below is reported.
    _, _, _, _, _, _, _ = train_one_cycle(train_loader,model, criterion, optimizer,scheduler,cycle=cycle)

    # NOTE(review): the "every 5th cycle" setting is immediately overridden, so plotting is always off.
    save_plot = (cycle+1)%5==0
    save_plot=False
    with torch.no_grad():
        # optimizer=None puts run_one_epoch in evaluation mode.
        _, _, _, tr_f1, tr_mcc, tr_loss, tr_tv_loss = run_one_epoch(train_loader, model, criterion, 
                                                                 optimizer=None, scheduler=None,
                                                                 grad_acc_steps=0, assess=True, 
                                                                 save_plot=save_plot, cycle=cycle)
        
        _, _, _, vl_f1, vl_mcc, vl_loss, vl_tv_loss = run_one_epoch(val_loader, model, criterion, 
                                                                 optimizer=None, scheduler=None,
                                                                 grad_acc_steps=0, assess=True, 
                                                                 save_plot=save_plot, cycle=cycle)

        
        print('Train/Val Loss: {:.4f}/{:.4f} -- '\
              'Train/Val TV Loss: {:.4f}/{:.4f}'.format(tr_loss, vl_loss, tr_tv_loss, vl_tv_loss))
        print('Train/Val F1|MCC: {:.4f}/{:.4f} | {:.4f}/{:.4f}'.format(tr_f1, vl_f1, tr_mcc, vl_mcc))        

 98%|█████████▊| 49/50 [00:52<00:01,  1.09s/it, Cycle: 1/20 Ep. 49/50 -- tr. loss=0.8626|0.1655 / lr=0.000015]Mean of empty slice.
100%|██████████| 50/50 [00:53<00:00,  1.08s/it, Cycle: 1/20 Ep. 50/50 -- tr. loss=0.8471|0.1588 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.8622/0.8072 -- Train/Val TV Loss: 0.1621/0.1526
Train/Val F1|MCC: 0.8260/0.8462 | 0.6641/0.7008


100%|██████████| 50/50 [00:54<00:00,  1.09s/it, Cycle: 2/20 Ep. 50/50 -- tr. loss=0.8239|0.1589 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.8419/0.8351 -- Train/Val TV Loss: 0.1620/0.1724
Train/Val F1|MCC: 0.8326/0.8396 | 0.6759/0.6873


100%|██████████| 50/50 [00:54<00:00,  1.09s/it, Cycle: 3/20 Ep. 50/50 -- tr. loss=0.8370|0.1538 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.8641/0.7777 -- Train/Val TV Loss: 0.1689/0.1483
Train/Val F1|MCC: 0.8316/0.8526 | 0.6747/0.7128


100%|██████████| 50/50 [00:54<00:00,  1.08s/it, Cycle: 4/20 Ep. 50/50 -- tr. loss=0.8166|0.1454 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.8048/0.7932 -- Train/Val TV Loss: 0.1495/0.1570
Train/Val F1|MCC: 0.8474/0.8483 | 0.7067/0.7044


100%|██████████| 50/50 [00:54<00:00,  1.09s/it, Cycle: 5/20 Ep. 50/50 -- tr. loss=0.8119|0.1604 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.7490/0.7821 -- Train/Val TV Loss: 0.1308/0.1585
Train/Val F1|MCC: 0.8625/0.8591 | 0.7342/0.7252


100%|██████████| 50/50 [00:54<00:00,  1.09s/it, Cycle: 6/20 Ep. 50/50 -- tr. loss=0.7867|0.1469 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.8137/0.7965 -- Train/Val TV Loss: 0.1491/0.1682
Train/Val F1|MCC: 0.8458/0.8526 | 0.7012/0.7123


100%|██████████| 50/50 [00:54<00:00,  1.09s/it, Cycle: 7/20 Ep. 50/50 -- tr. loss=0.7472|0.1373 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.7764/0.8032 -- Train/Val TV Loss: 0.1424/0.1645
Train/Val F1|MCC: 0.8546/0.8516 | 0.7194/0.7107


100%|██████████| 50/50 [00:54<00:00,  1.09s/it, Cycle: 8/20 Ep. 50/50 -- tr. loss=0.7233|0.1230 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.7572/0.7748 -- Train/Val TV Loss: 0.1294/0.1550
Train/Val F1|MCC: 0.8584/0.8586 | 0.7257/0.7239


100%|██████████| 50/50 [00:54<00:00,  1.08s/it, Cycle: 9/20 Ep. 50/50 -- tr. loss=0.7039|0.1233 / lr=0.000001]
  0%|          | 0/50 [00:00<?, ?it/s]

Train/Val Loss: 0.7411/0.7729 -- Train/Val TV Loss: 0.1326/0.1563
Train/Val F1|MCC: 0.8668/0.8565 | 0.7426/0.7199


100%|██████████| 50/50 [00:54<00:00,  1.09s/it, Cycle: 10/20 Ep. 50/50 -- tr. loss=0.6871|0.1181 / lr=0.000001]


Train/Val Loss: 0.7350/0.7941 -- Train/Val TV Loss: 0.1340/0.1638
Train/Val F1|MCC: 0.8621/0.8532 | 0.7329/0.7133


In [None]:
# Final validation pass with plots. run_one_epoch returns seven values
# (three Dice scores, F1, MCC, loss, TV loss), so all seven are unpacked here.
with torch.no_grad():
    vl_d_unc, vl_d_arts, vl_d_veins, vl_f1, vl_mcc, vl_loss, vl_tv_loss = run_one_epoch(
        val_loader, model, criterion,
        optimizer=None, scheduler=None,
        grad_acc_steps=0, assess=True,
        save_plot=True, cycle=0)

In [None]:
vl_f1

In [None]:
# Grab one training batch for the interactive TV-loss exploration below.
x,labels = next(iter(train_loader))
# Add the channel dimension so labels can be concatenated with the logits along dim 1.
labels=labels.unsqueeze(dim=1)
x.shape, labels.shape

In [None]:
# Forward the batch without tracking gradients.
with torch.no_grad():
    logits = model(x.to(device))
del x
# In 'train' mode the W-Net returns an (auxiliary, final) pair; `logits` is still that tuple here.
logits_aux, logits_pre = logits
# Move both outputs to the CPU for the plotting cells.
logits_aux, logits_pre = logits_aux.cpu(), logits_pre.cpu()

In [None]:
logits_pre.is_cuda

In [None]:
class TvLoss(torch.nn.Module):
    """Total-variation style penalty on class probabilities inside each labelled region.

    This cell redefines `TvLoss`. Unlike the earlier definition in this
    notebook, the reductions here cover classes 1 and above instead of
    classes 2 and above.

    For every class channel the softmax probability is kept only where the
    label equals that class, and the mean absolute difference to the four
    direct neighbours is taken. Positions where a perfect (one-hot)
    prediction would itself have non-zero variation, i.e. label boundaries,
    are excluded. The rest is divided by the predicted probability and
    averaged over the pixels of each class.

    Args:
        ignore_background: Stored on the instance but not used by the
            computation.
        reduction: 'mean', 'per_class', 'per_elem_per_class' or 'none'.

    Expected inputs (as implied by the dim=1 concatenation and the
    element-wise product with the probabilities): `logits` of shape
    (B, 4, H, W) and integer `labels` of shape (B, 1, H, W) with values 0..3.
    """

    def __init__(self, ignore_background=True, reduction='mean'):
        super(TvLoss, self).__init__()
        self.reduction = reduction
        self.ignore_background = ignore_background

    def compute_tv(self, logits, labels):
        """Return the per-pixel, per-class variation of the label-masked probabilities."""
        probs = torch.nn.Softmax(dim=1)(logits)
        labels_oh = torch.cat([labels==0, labels==1, labels==2, labels==3], dim=1).long()

        probs = torch.mul(probs, labels_oh) # discard values outside labels

        # Absolute difference to the left/right/up/down neighbour. torch.roll
        # wraps around, so border pixels are compared with the opposite edge.
        tv_l = torch.abs(torch.sub(probs, torch.roll(probs, shifts=1, dims=-1)))
        tv_r = torch.abs(torch.sub(probs, torch.roll(probs, shifts=-1, dims=-1)))

        tv_u = torch.abs(torch.sub(probs, torch.roll(probs, shifts=-1, dims=-2)))
        tv_d = torch.abs(torch.sub(probs, torch.roll(probs, shifts=1, dims=-2)))

        tv = torch.mean(torch.stack([tv_l, tv_r, tv_u, tv_d], dim=0), dim=0)

        return tv

    def forward(self, logits, labels):
        """Compute the loss for `logits` against `labels` with the configured reduction.

        Raises:
            ValueError: If `self.reduction` is not one of the supported names.
        """
        probs = torch.nn.Softmax(dim=1)(logits)
        labels_oh = torch.cat([labels==0, labels==1, labels==2, labels==3], dim=1).float()

        tv = self.compute_tv(logits, labels)

        # compute_tv is a method, so it must be called through `self`.
        # A saturated one-hot "prediction" only varies at label boundaries;
        # those positions are not penalised.
        perfect_tv = self.compute_tv(100*labels_oh, labels)>0
        tv[perfect_tv]=0

        tv = torch.div(tv, probs+1e-6)

        # Average over the pixels that belong to each class (epsilon guards
        # classes that are absent from a sample).
        mean_per_elem_per_class = (tv.sum(dim=(-2, -1)) / (labels_oh.sum(dim=(-2, -1))+1e-6)  )
        mean_per_class = mean_per_elem_per_class.mean(dim=0)

        if self.reduction == 'mean':
            return mean_per_class[1:].mean()
        elif self.reduction == 'per_class':
            return mean_per_class[1:]
        elif self.reduction == 'per_elem_per_class':
            return mean_per_elem_per_class[:, 1:]
        elif self.reduction == 'none':
            return tv
        raise ValueError("unknown reduction {!r}".format(self.reduction))

In [None]:
# Two views of the same loss: the full per-pixel map, and one value per sample and class.
tv_loss = TvLoss(reduction='none')
tv_loss_r = TvLoss(reduction='per_elem_per_class')

In [None]:
labels.shape

In [None]:
# NOTE(review): in top-to-bottom order `logits` is still the (auxiliary, final) tuple returned by
# the model; the 4-channel tensor with the prepended background channel is only built in a later
# cell, which has to be run before this one.
tv = tv_loss(logits, labels)
tv.max()

In [None]:
# Visual check on a crop of sample `bb`: label masks, masked probabilities and TV map for
# classes 2 and 3, followed by the per-sample/per-class loss on the same crop.
# NOTE(review): `labels_oh` and `probs_filtered` are defined in later cells; run those first.
bb=1
imshow_pair(labels_oh[bb,2,50:300,10+250:10+400], labels_oh[bb,3,50:300,10+250:10+400], vmin1=0,vmax1=1);
imshow_pair(probs_filtered[bb,2,50:300,10+250:10+400], probs_filtered[bb,3,50:300,10+250:10+400], vmin1=0,vmax1=1, vmin2=0,vmax2=1);
imshow_pair(tv[bb,2,50:300,10+250:10+400], tv[bb,3,50:300,10+250:10+400]);
tt=tv_loss_r(logits[:,:,50:300,10+250:10+400], labels[:,:,50:300,10+250:10+400])
tt

In [None]:
imshow_pair(tv[1,2], probs[1,2]);

In [None]:
# `compute_tv` is a method of TvLoss; no module-level function of that name is defined in the
# visible notebook, so call it through an instance.
tv = TvLoss().compute_tv(logits, labels)

In [None]:
# Variation of a saturated one-hot "prediction": non-zero only at label boundaries, which are
# then cleared from the TV map. `compute_tv` is a method of TvLoss, so it is called through an
# instance.
perfect_tv = TvLoss().compute_tv(100*labels_oh.float(), labels)
tv[perfect_tv>0]=0

In [None]:
tv = torch.div(tv, probs+1e-6)
tv.max()

In [None]:
imshow_pair(tv[1,2], probs[1,2]);

In [None]:
# Rebuild the 4-channel logits on the CPU by prepending a constant background channel.
# NOTE(review): the constant here is -100, whereas run_one_epoch uses -10 — confirm this is intended.
logits = torch.cat([-100*torch.ones(labels.shape), logits_pre], dim=1)
probs = torch.nn.Softmax(dim=1)(logits)
# One mask channel per label value 0..3, and its complement.
labels_oh = torch.cat([labels==0, labels==1, labels==2, labels==3], dim=1).long()
complement = 1-labels_oh

# Foreground/background masks repeated over the four channels.
foreground = torch.cat([labels!=0, labels!=0, labels!=0, labels!=0], dim=1).long()
background = torch.cat([labels==0, labels==0, labels==0, labels==0], dim=1).long()

probs.shape, labels.shape, labels_oh.shape

In [None]:
probs_filtered = torch.mul(probs, labels_oh)
# probs_filtered = torch.mul(probs, foreground)

In [None]:
imshow_pair(labels_oh[1,2], labels_oh[1,3], vmin1=0,vmax1=1);

In [None]:
imshow_pair(probs[1,2], probs_filtered[1,2], vmin1=0,vmax1=1);

In [None]:
imshow_pair(probs[1,3], probs_filtered[1,3], vmin1=0,vmax1=1);

In [None]:
imshow_pair(labels_oh[1,2,0:150,250:400], labels_oh[1,3,0:150,250:400], vmin1=0,vmax1=1);
imshow_pair(probs_filtered[1,2,0:150,250:400], probs_filtered[1,3,0:150,250:400], vmin1=0,vmax1=1, vmin2=0,vmax2=1);
imshow_pair(tv[1,2,0:150,250:400], tv[1,3,0:150,250:400], vmin1=0,vmax1=1, vmin2=0,vmax2=1);

In [None]:
imshow_pair(probs_filtered[1,2,250:450,250:500],labels_oh[1,2,250:450,250:500]);
imshow_pair(probs_filtered[1,2,250:450,250:500],tv[1,2,250:450,250:500]);

In [None]:
imshow_pair(probs_filtered[1,3,250:450,250:500],labels_oh[1,3,250:450,250:500],vmin1=0,vmax1=1,vmin2=0,vmax2=1);
imshow_pair(probs_filtered[1,3,250:450,250:500],tv[1,3,250:450,250:500],vmin1=0,vmax1=1,vmin2=0,vmax2=1);