In [1]:
import os
import os.path as osp
from PIL import Image
import pandas as pd
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
%load_ext autoreload
%autoreload 2
from skimage.io import imsave, imread
from skimage import img_as_ubyte, img_as_float
import sys
# sys.path.insert(0, '../utils')


In [2]:
def imshow_pair(im, gdt, vmin1=None, vmax1=None, vmin2=None, vmax2=None):
    """Show two images side by side with their axes hidden.

    2-D (single-channel) arrays are drawn with a gray colormap; any other array is
    passed to ``imshow`` unchanged. The optional display limits apply only to 2-D
    inputs; leaving them as None lets matplotlib autoscale.

    Args:
        im: left image, anything accepted by ``np.asarray`` (e.g. a PIL image or array).
        gdt: right image, same conventions as ``im``.
        vmin1, vmax1: gray-level display range of the left panel (2-D input only).
        vmin2, vmax2: gray-level display range of the right panel (2-D input only).
    """
    _, axes = plt.subplots(1, 2, figsize=(12, 6))
    panels = ((np.asarray(im), vmin1, vmax1), (np.asarray(gdt), vmin2, vmax2))
    for axis, (arr, vmin, vmax) in zip(axes, panels):
        if arr.ndim == 2:
            # vmin/vmax of None mean "autoscale" in imshow, so they can always be passed;
            # both panels now honour their limits (the right one used to ignore them).
            axis.imshow(arr, cmap='gray', vmin=vmin, vmax=vmax)
        else:
            axis.imshow(arr)
        axis.axis('off')
    plt.tight_layout()

In [3]:
import torch
import torchvision.transforms as tr
from utils.get_loaders import get_train_val_loaders
from models.get_model import get_arch

In [4]:
csv_train = 'data/DRIVE/train_av.csv'
# Derive the validation CSV by swapping 'train' -> 'val' in the file name only,
# so a directory whose name happens to contain 'train' is left untouched.
csv_val = osp.join(osp.dirname(csv_train), osp.basename(csv_train).replace('train', 'val'))

In [5]:
# Evenly spaced gray levels 0, 85, 170, 255 passed to the loaders as `label_values`
# (presumably the mask intensity of each class — TODO confirm); one entry per class.
label_values = list(range(0, 256, 85))
n_classes = len(label_values)

In [6]:
# Build the train/val loaders from the two CSVs defined above.
train_loader, val_loader = get_train_val_loaders(
    csv_path_train=csv_train,
    csv_path_val=csv_val,
    label_values=label_values,
    tg_size=(512, 512),
    batch_size=2,
    num_workers=8,
)

In [7]:
# Keep the architecture name in one place so the log message matches what is built
# (it previously printed 'wnet' while instantiating 'big_wnet').
model_name = 'big_wnet'
print('* Instantiating a {} model'.format(model_name))
model = get_arch(model_name, n_classes=n_classes)

* Instantiating a wnet model


In [8]:
from utils.model_saving_loading import load_model

# Restore the saved experiment onto the CPU; load_model hands back the model and its stats.
model, stats = load_model(
    model,
    'experiments/big_wnet_drive_av/',
    'cpu',
)

In [9]:
inputs, labels = next(iter(val_loader))  # first validation batch only

In [10]:
from torch.nn import functional as F

In [11]:
# Forward pass; the model output is unpacked into an auxiliary output and the final logits.
logits_aux, logits = model(inputs)
probs = torch.softmax(logits, dim=1).detach()  # per-pixel class probabilities over dim 1

In [18]:
probs[:, 2][labels == 2].size()  # number of vein-labelled (class 2) pixels in the batch

torch.Size([21441])

In [19]:
probs[:, 2][labels == 2].roll(1, 0).size()  # rolling the 1-D selection keeps its length

torch.Size([21441])

In [23]:
# Boolean indexing flattens the selected pixels in row-major order, so rolling by 1
# pairs each pixel with the previous *selected* pixel, not with a spatial neighbour.
cos_sim = torch.nn.CosineSimilarity(dim=0)
vein_scores = probs[:, 2][labels == 2]
artery_scores = probs[:, 3][labels == 3]
sim_v = cos_sim(vein_scores, vein_scores.roll(1, 0))
sim_a = cos_sim(artery_scores, artery_scores.roll(1, 0))
sim_v, sim_a

(tensor(0.9409), tensor(0.9442))

In [24]:
# Same flattened-order comparison as with the probabilities, but on the raw logits.
cos_sim = torch.nn.CosineSimilarity(dim=0)
vein_logits = logits[:, 2][labels == 2]
artery_logits = logits[:, 3][labels == 3]
sim_v = cos_sim(vein_logits, vein_logits.roll(1, 0))
sim_a = cos_sim(artery_logits, artery_logits.roll(1, 0))
sim_v, sim_a

(tensor(0.8823, grad_fn=<DivBackward0>),
 tensor(0.9408, grad_fn=<DivBackward0>))

In [30]:
probs.size()  # batch x classes x H x W

torch.Size([2, 4, 512, 512])

In [37]:
# Predicted class per pixel. Avoid binding `_` at top level: once the user assigns `_`,
# IPython stops refreshing it with the last cell output.
preds = torch.max(probs, dim=1).indices

In [58]:
(preds[labels == 2] == 2).float().mean()  # fraction of class-2 (vein) pixels predicted as 2

tensor(0.6897)

In [59]:
(preds[labels == 3] == 3).float().mean()  # fraction of class-3 (artery) pixels predicted as 3

tensor(0.7481)

In [12]:
labels_map = torch.stack([labels == k for k in range(4)], dim=1)  # boolean one-hot along dim 1

In [13]:
# For each vessel class (label 2 -> veins, label 3 -> arteries), compare its probability
# map with the same map shifted by one pixel along each of the last two axes, in both
# directions, restricted to pixels carrying that label. torch.roll wraps around at the
# image border.
similarity_veins, similarity_arteries = [], []

cos_sim = torch.nn.CosineSimilarity(dim=0)
vein_mask, artery_mask = labels == 2, labels == 3
# These selections do not depend on the shift, so compute them once.
vein_probs, artery_probs = probs[:, 2][vein_mask], probs[:, 3][artery_mask]

for shift, dim in [(1, 3), (-1, 3), (1, 2), (-1, 2)]:
    translated = torch.roll(probs, shifts=shift, dims=dim)
    similarity_veins.append(cos_sim(vein_probs, translated[:, 2][vein_mask]))
    similarity_arteries.append(cos_sim(artery_probs, translated[:, 3][artery_mask]))

In [14]:
# Average the four directional similarities of each class into one scalar.
sim_veins = torch.stack(similarity_veins).mean(dim=0)
sim_arts = torch.stack(similarity_arteries).mean(dim=0)

In [17]:
(sim_veins, sim_arts)  # (veins, arteries)

(tensor(0.9613), tensor(0.9655))

In [18]:
class SimilarityLoss(torch.nn.Module):
    """Cosine similarity between vessel-class scores and their one-pixel shifted copies.

    For the vein (label 2) and artery (label 3) classes, the score map of that class,
    restricted to pixels carrying that label, is compared with the same map shifted by
    one pixel along each of the last two spatial axes in both directions (``torch.roll``,
    which wraps around at the border). The four cosine similarities are averaged per class.

    NOTE(review): the value is a *similarity* (higher means smoother), not a penalty; if
    it is meant to be minimised, presumably ``1 - value`` is what should be optimised.

    Args:
        with_probs: if True, apply a softmax over dim 1 to the input before comparing.
        reduction: 'mean' -> one scalar averaging the vein and artery similarities;
            'none' -> list ``[sim_veins, sim_arts]`` (two scalars, each pooled over the
            whole batch).
    """

    # (shift, dim) pairs applied to a single-class map of shape bs x H x W, in the
    # order: last axis +1, last axis -1, second-to-last axis +1, second-to-last axis -1.
    _SHIFTS = ((1, 2), (-1, 2), (1, 1), (-1, 1))
    _VEIN, _ARTERY = 2, 3

    def __init__(self, with_probs=True, reduction='mean'):
        super(SimilarityLoss, self).__init__()
        self.with_probs = with_probs
        self.reduction = reduction

    def _directional_similarity(self, probs, labels, cls):
        """Mean over the four shift directions of the cosine similarity for class ``cls``."""
        mask = labels == cls
        channel = probs[:, cls]  # bs x H x W scores of this class
        reference = channel[mask]
        cos_sim = torch.nn.CosineSimilarity(dim=0)
        sims = [cos_sim(reference, torch.roll(channel, shifts=shift, dims=dim)[mask])
                for shift, dim in self._SHIFTS]
        return torch.mean(torch.stack(sims, dim=0), dim=0)

    def forward(self, logits, labels, **kwargs):
        # assumes logits is bs x n_classes x H x W and labels is bs x H x W holding
        # integer values in [0, ..., n_classes-1]
        if self.reduction not in ('mean', 'none'):
            raise ValueError('not a valid reduction scheme')
        # No .detach() here: detaching the softmax would block gradients, making the
        # with_probs=True path useless as a training loss.
        probs = torch.softmax(logits, dim=1) if self.with_probs else logits

        sim_veins = self._directional_similarity(probs, labels, self._VEIN)
        sim_arts = self._directional_similarity(probs, labels, self._ARTERY)

        if self.reduction == 'mean':  # 1 value for the entire batch
            return torch.mean(torch.stack([sim_veins, sim_arts], dim=0), dim=0)
        return [sim_veins, sim_arts]  # 2 values (veins, arteries), each for the whole batch

In [24]:
criterion = SimilarityLoss(with_probs=False, reduction='mean')  # compare raw logits directly

In [28]:
loss_value = criterion(logits, labels)  # similarity on the trained model's logits
loss_value

tensor(0.9598, grad_fn=<MeanBackward1>)

In [29]:
# Baseline: a newly constructed architecture with no checkpoint loaded.
# NOTE(review): this rebinds `model` and `logits`, discarding the trained model and its
# outputs from the cells above; re-run the loading cell before reusing them.
model = get_arch('big_wnet', n_classes=n_classes)
(logits_aux, logits) = model(inputs)

In [30]:
loss_value = criterion(logits, labels)  # similarity on the freshly constructed model's logits
loss_value

tensor(0.2505, grad_fn=<MeanBackward1>)