### Wasserstein Distance Toy Example

In [1]:
from config import *
from dataset import MNIST, CIFAR10
from models.mnist_cnn import MNISTCNN
from trainer import train

%load_ext autoreload
%autoreload 2



#### Load Training & Validation Datasets \& Dataloaders

In [2]:
mnist_tri_set, mnist_val_set, mnist_tri_loader, mnist_val_loader = MNIST(64, 32)
cifar_tri_set, cifar_val_set , cifar_tri_loader, cifar_val_loader = MNIST(64, 32)
# TODO: Show dataset statistics and sample images.

#### Training a toy CNN using MNIST
We first train a toy CNN model on MNIST dataset.

In [3]:
model = MNISTCNN().to(DEVICE)
train(model = model, train_loader=mnist_tri_loader, val_loader=mnist_val_loader, num_epoch=8)

  0%|          | 0/8 [00:02<?, ?it/s]


KeyboardInterrupt: 

In [23]:
torch.save(model, MODEL_SAVE_PATH)

In [4]:
model = torch.load(MODEL_SAVE_PATH)

In [5]:
# Simple Sanity Check
test_sample, test_label = mnist_tri_set.__getitem__(7)
ic(test_sample.shape)
test_logits = model(test_sample.unsqueeze(0))
test_softmax = torch.softmax(test_logits, dim=-1)
ic(test_softmax.shape)
pred = torch.argmax(test_logits, dim=1)
ic(test_softmax)
ic(pred)
ic(test_label)

ic| test_sample.shape: torch.Size([1, 28, 28])
ic| test_softmax.shape: torch.Size([1, 10])
ic| test_softmax: tensor([[1.2200e-13, 2.2747e-13, 2.1682e-09, 1.0000e+00, 1.5332e-16, 3.3098e-12,
                           2.3926e-20, 2.2021e-10, 1.2905e-08, 5.4976e-08]],
                         grad_fn=<SoftmaxBackward0>)
ic| pred: tensor([3])
ic| test_label: 3


3

#### Wasserstein Distance

In [6]:
def cost_matrix(X, Y):
    # TODO: Change this to more generic version
    if len(X.shape) == 2:
        N,D = X.shape
        M,D = Y.shape
        return (1 - torch.eye(N, M)).to(DEVICE)
    
    if len(X.shape) == 3:
        B,N,D = X.shape
        B,M,D = Y.shape
        return torch.unsqueeze(1 - torch.eye(N, M), 0).repeat(B, 1, 1).to(DEVICE)

In [7]:
def label_2_onehot(label, C, device):
    # transform the InD labels into one-hot vector
    assert type(label) == torch.Tensor

    size = label.shape[0]
    if len(label.shape) == 1:
        label = torch.unsqueeze(label, 1)
    
    label = label % C
    
    label_onehot = torch.FloatTensor(size, C).to(device)

    label_onehot.zero_()
    label_onehot.scatter_(1, label, 1)
    return label_onehot

In [8]:
def sink_dist_test(input, target, C, device):
    
    test_label_onehot = label_2_onehot(target, C, device)
    test_label_onehot = torch.unsqueeze(test_label_onehot, -1)
    test_input = torch.unsqueeze(input, -1)
    ##Loss value for InD samples
    test_loss = SamplesLoss("sinkhorn", p=2, blur=1., cost=cost_matrix) #Wasserstein-1 distance
    ic(test_input.shape)
    ic(test_input[:,:,0].shape)
    test_loss_value = test_loss(test_input[:,:,0], test_input, test_label_onehot[:,:,0], test_label_onehot)
    
    return test_loss_value

def sink_dist_test_v2(input, C, device):
    
    all_class = torch.LongTensor([i for i in range(C)]).to(device)
    all_class_onehot = label_2_onehot(all_class, C, device)
    ##reshape into (B,N,D)
    all_class_onehot = torch.unsqueeze(all_class_onehot, -1)
    test_input = torch.unsqueeze(input, -1)
    test_batch_size = test_input.shape[0]
    test_loss_values = torch.zeros(test_batch_size, C).to(device)
    test_loss = SamplesLoss("sinkhorn", p=2, blur=1., cost = cost_matrix) #Wasserstein-1 distance
    # ic(test_batch_size)
    for b in range(test_batch_size):
        input_b = test_input[b:b+1,:,:].repeat(C, 1, 1)
        # ic(input_b.shape)
        # ic(input_b[:,:,0].shape)
        # ic(all_class_onehot.shape)
        test_loss_values[b] = test_loss(input_b[:,:,0], input_b, all_class_onehot[:,:,0], all_class_onehot)
    
    return test_loss_values.min(dim=1)[0]

# Sanity Check
one_hot_eg = label_2_onehot(torch.tensor([test_label]), 10, DEVICE)
sample_wass_loss = sink_dist_test(test_softmax, torch.tensor([test_label]), 10, DEVICE)

ic| test_input.shape: torch.Size([1, 10, 1])
ic| test_input[:,:,0].shape: torch.Size([1, 10])


In [9]:
ic(sample_wass_loss)

ic| sample_wass_loss: tensor([3.1974e-14], grad_fn=<AddBackward0>)


tensor([3.1974e-14], grad_fn=<AddBackward0>)

In [10]:
# Sanity Check
OOD_sample = cifar_tri_set.__getitem__(0)[0].mean(0, keepdim=True)
OOD_logits = model(OOD_sample.unsqueeze(0))
# ic(OOD_logits.shape)
OOD_softmax = torch.softmax(OOD_logits, dim=-1)
# ic(OOD_softmax.shape)
pred = torch.argmax(OOD_logits, dim=1)
# ic(OOD_softmax)
ic(pred)
# Sanity check for OOD wasserstein distance
OOD_wass_loss = sink_dist_test_v2(input=OOD_softmax, C=10, device=DEVICE)
ic(OOD_wass_loss)

ic| pred: tensor([5])


RuntimeError: The size of tensor a (10) must match the size of tensor b (100) at non-singleton dimension 1