In [1]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
import os
import cv2
import multiprocessing as mp
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import distance

ModuleNotFoundError: No module named 'distance'

# ENCODER CLASS 

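Define the building blocks of the prototypical network's encoder: a `Flatten` layer, a convolutional block (`conv_block`), and a factory (`load_protonet_conv`) that stacks them into the embedding network.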


In [2]:
class Flatten(nn.Module):
    """Flatten every dimension except the first (batch) dimension.

    Converts a tensor of shape [batch_size, d1, d2, ..., dn] into a 2-D tensor of
    shape [batch_size, d1*d2*...*dn]. It is the last layer of the encoder built in
    `load_protonet_conv`, turning each feature map into one embedding vector.
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): input of shape [batch_size, d1, ..., dn]
        Returns:
            torch.Tensor: shape [batch_size, d1*...*dn]
        """
        # reshape instead of view: view raises on non-contiguous inputs (e.g. permuted
        # or transposed tensors); reshape returns a view when possible and copies otherwise.
        return x.reshape(x.size(0), -1)

In [3]:
def conv_block(in_channels, out_channels):
    """Build one encoder stage: 3x3 conv (padding 1) -> BatchNorm -> ReLU -> 2x2 max-pool.

    Args:
        in_channels (int): number of channels of the incoming feature map
        out_channels (int): number of channels produced by the convolution
    Returns:
        nn.Sequential: the four layers in that order; the conv keeps the spatial
        size and the pooling halves height and width (rounding down).
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
    ]
    return nn.Sequential(*layers)

In [4]:
def load_protonet_conv(x_dim, hid_dim, z_dim):
    """Build a prototypical network around a 4-stage convolutional encoder.

    Args:
        x_dim (sequence): input image dimensions; only x_dim[0] (the channel count)
            is used here
        hid_dim (int): number of channels in the hidden conv blocks
        z_dim (int): number of channels of the last conv block (the embedding channels)
    Returns:
        Protonet: model wrapping the encoder
            conv_block x4 -> Flatten
    """
    # (in, out) channels for each of the four stacked conv blocks, in build order.
    channel_pairs = [
        (x_dim[0], hid_dim),
        (hid_dim, hid_dim),
        (hid_dim, hid_dim),
        (hid_dim, z_dim),
    ]
    stages = [conv_block(c_in, c_out) for c_in, c_out in channel_pairs]
    encoder = nn.Sequential(*stages, Flatten())
    return Protonet(encoder)

# PROTONET CLASS


In [9]:
class Protonet(nn.Module):
    
    def __init__(self, encoder):
        super(Protonet,self).__init__()
        self.encoder = encoder
    
    def set_forward_loss(self, sample):
        """
        Computes loss, accuracy and output for classification task
        Args:
            sample (torch.Tensor): shape (n_way, n_support+n_query, (dim)) 
        Returns:
            torch.Tensor: shape(2), loss, accuracy and y_hat
        """
        sample_images = sample['images']
        n_way = sample['n_way']
        n_support = sample['n_support']
        n_query = sample['n_query']

        x_support = sample_images[:, :n_support]
        x_query = sample_images[:, n_support:]

        #target indices are 0 ... n_way-1
        #target_inds = TENSOR WHICH HAS dimension (n_way,n_query,1), which represent the arange from 0 to n_way of a matrix 
        #of dimension (n_query,1) (c[0]= [[0,0,0]])
        target_inds = torch.arange(0, n_way).view(n_way, 1, 1).expand(n_way, n_query, 1).long()
        target_inds = Variable(target_inds, requires_grad=False)

        #encode images of the support and the query set
        x = torch.cat([x_support.contiguous().view(n_way * n_support, *x_support.size()[2:]),
                       x_query.contiguous().view(n_way * n_query, *x_query.size()[2:])], 0)

        z = self.encoder.forward(x)
        z_dim = z.size(-1) #usually 64
        z_proto = z[:n_way*n_support].view(n_way, n_support, z_dim).mean(1)
        z_query = z[n_way*n_support:]

        #compute distances
        dists = distance.euclidean_dist(z_query, z_proto)

        #compute probabilities
        log_p_y = F.log_softmax(-dists, dim=1).view(n_way, n_query, -1)

        loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
        _, y_hat = log_p_y.max(2)
        acc_val = torch.eq(y_hat, target_inds.squeeze()).float().mean()

        df = {'loss': loss_val.item(),'acc': acc_val.item(),'y_hat': y_hat}
        return loss_val, df
