### Wasserstein Distance Toy Example

In [1]:
from config import *
from dataset import MNIST, CIFAR10
from models.mnist_cnn import MNISTCNN
from trainer import train

%load_ext autoreload
%autoreload 2



#### Load Training & Validation Datasets \& Dataloaders

In [2]:
mnist_tri_set, mnist_val_set, mnist_tri_loader, mnist_val_loader = MNIST(64, 32)
cifar_tri_set, cifar_val_set , cifar_tri_loader, cifar_val_loader = MNIST(64, 32)
# TODO: Show dataset statistics and sample images.

#### Training a toy CNN using MNIST
We first train a toy CNN model on MNIST dataset.

In [None]:
model = MNISTCNN().to(DEVICE)
train(model = model, train_loader=mnist_tri_loader, val_loader=mnist_val_loader, num_epoch=8)

In [23]:
torch.save(model, MODEL_SAVE_PATH)

In [3]:
model = torch.load(MODEL_SAVE_PATH)

#### Wasserstein Distance

In [15]:
def cost_matrix(X, Y):
    # TODO: Change this to more generic version
    if len(X.shape) == 2:
        N,D = X.shape
        M,D = Y.shape
        return (1 - torch.eye(N, M)).to(DEVICE)
    
    if len(X.shape) == 3:
        B,N,D = X.shape
        B,M,D = Y.shape
        return torch.unsqueeze(1 - torch.eye(N, M), 0).repeat(B, 1, 1).to(DEVICE)

In [6]:
def label_2_onehot(label, C, device):
    # transform the InD labels into one-hot vector
    assert type(label) == torch.Tensor

    size = label.shape[0]
    if len(label.shape) == 1:
        label = torch.unsqueeze(label, 1)
    
    label = label % C
    
    label_onehot = torch.FloatTensor(size, C).to(device)

    label_onehot.zero_()
    label_onehot.scatter_(1, label, 1)
    return label_onehot

In [51]:
def sink_dist_test(input, target, C, device):
    
    test_label_onehot = label_2_onehot(target, C, device)
    test_label_onehot = torch.unsqueeze(test_label_onehot, -1)
    test_input = torch.unsqueeze(input, -1)
    ##Loss value for InD samples 
    #Wasserstein-1 distance
    test_loss = SamplesLoss("sinkhorn", p=2, blur=1., cost=cost_matrix)
    # ic(test_input.shape)
    # ic(test_input[:,:,0].shape)
    # ic(test_label_onehot[:,:,0].shape)
    # ic(test_label_onehot.shape)
    test_loss_value = test_loss(test_input[:,:,0], test_input, test_label_onehot[:,:,0], test_label_onehot)
    
    return test_loss_value

def sink_dist_test_v2(input, C, device):
    
    all_class = torch.LongTensor([i for i in range(C)]).to(device)
    all_class_onehot = label_2_onehot(all_class, C, device)
    ##reshape into (B,N,D)
    all_class_onehot = torch.unsqueeze(all_class_onehot, -1)
    test_input = torch.unsqueeze(input, -1)
    test_batch_size = test_input.shape[0]
    test_loss_values = torch.zeros(test_batch_size, C).to(device)
    # Approximate Wasserstein distance
    test_loss = SamplesLoss("sinkhorn", p=2, blur=1., cost = cost_matrix) 
    # ic(test_batch_size)
    for b in range(test_batch_size):
        # ic(test_input.shape)
        input_b = test_input[b:b+1,:,:].repeat(C, 1, 1)
        # ic(input_b.shape)
        # ic(input_b[0:1,:,0].shape)
        # ic(all_class_onehot[:,:,0].shape)
        # ic(all_class_onehot.shape)
        # Modified the line below
        test_loss_values[b] = torch.tensor([test_loss(input_b[c:c+1,:,0], input_b[c:c+1:,:], all_class_onehot[c:c+1,:,0], \
                                            all_class_onehot[c:c+1:,:]) for c in range(C)])
    
    return test_loss_values.min(dim=1)[0]


#### Wasserstein Distance Toy Example Sanity Check

In [52]:
# Simple Sanity Check
def example_wass_loss_ind(img_id_lst):
    wass_loss = []
    for id in img_id_lst:
        test_sample, test_label = mnist_tri_set.__getitem__(id)
        # ic(test_sample.shape)
        test_logits = model(test_sample.unsqueeze(0))
        test_softmax = torch.softmax(test_logits, dim=-1)
        # ic(test_softmax.shape)
        # pred = torch.argmax(test_logits, dim=1)
        # ic(test_softmax)
        # ic(pred)
        # ic(test_label)
        # one_hot_eg = label_2_onehot(torch.tensor([test_label]), 10, DEVICE)
        sample_wass_loss = sink_dist_test(test_softmax, torch.tensor([test_label]), 10, DEVICE)
        wass_loss.append(sample_wass_loss)
    return torch.tensor(wass_loss)

In [42]:
ic(mnist_tri_set)
ic(cifar_tri_set)

ic| mnist_tri_set: Dataset MNIST
                       Number of datapoints: 60000
                       Root location: ./Datasets
                       Split: Train
                       StandardTransform
                   Transform: Compose(
                                  ToTensor()
                              )
ic| cifar_tri_set: Dataset MNIST
                       Number of datapoints: 60000
                       Root location: ./Datasets
                       Split: Train
                       StandardTransform
                   Transform: Compose(
                                  ToTensor()
                              )


Dataset MNIST
    Number of datapoints: 60000
    Root location: ./Datasets
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
           )

In [58]:
img_id_lst = torch.randint(low=0, high=60000, size=(2000,))
wass_loss_eg = example_wass_loss_ind(img_id_lst)
mean_wass_loss = torch.mean(wass_loss_eg)
ic(mean_wass_loss)

ic| mean_wass_loss: tensor(0.0061)


tensor(0.0061)

In [57]:
# Sanity Check
def example_wass_loss_ood(img_id_lst):
    wass_loss = []
    for id in img_id_lst:
        OOD_sample = cifar_tri_set.__getitem__(id)[0].mean(0, keepdim=True)
        OOD_logits = model(OOD_sample.unsqueeze(0))
        # ic(OOD_logits.shape)
        OOD_softmax = torch.softmax(OOD_logits, dim=1)
        # ic(OOD_softmax.shape)
        # pred = torch.argmax(OOD_logits, dim=1)
        # ic(OOD_softmax)
        # ic(pred)
        # Sanity check for OOD wasserstein distance
        OOD_wass_loss = sink_dist_test_v2(input=OOD_softmax, C=torch.tensor(10), device=DEVICE)
        # ic(OOD_wass_loss)
        wass_loss.append(OOD_wass_loss)
    return torch.tensor(wass_loss)    
        

img_id_lst = torch.randint(low=0, high=60000, size=(2000,))
wass_loss_eg = example_wass_loss_ood(img_id_lst)
mean_wass_loss = torch.mean(wass_loss_eg)
ic(mean_wass_loss)

ic| mean_wass_loss: tensor(0.0039)


tensor(0.0039)

ic| wass_loss_eg[0:25]: tensor([1.0641e-10, 1.6086e-05, 8.1767e-03, 2.5580e-13, 8.0821e-04, 1.1333e-01,
                                1.1584e-05, 9.1160e-08, 6.4926e-05, 2.7420e-07, 2.3192e-11, 3.2685e-13,
                                8.7994e-11, 1.3145e-12, 3.7460e-11, 8.5549e-12, 1.6371e-10, 2.7254e-06,
                                1.8332e-12, 1.8777e-06, 3.0437e-06, 7.5060e-07, 6.8261e-04, 3.8321e-07,
                                1.6147e-06])


tensor([1.0641e-10, 1.6086e-05, 8.1767e-03, 2.5580e-13, 8.0821e-04, 1.1333e-01,
        1.1584e-05, 9.1160e-08, 6.4926e-05, 2.7420e-07, 2.3192e-11, 3.2685e-13,
        8.7994e-11, 1.3145e-12, 3.7460e-11, 8.5549e-12, 1.6371e-10, 2.7254e-06,
        1.8332e-12, 1.8777e-06, 3.0437e-06, 7.5060e-07, 6.8261e-04, 3.8321e-07,
        1.6147e-06])