In [None]:
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage
import cv2
import multiprocessing as mp
import torch
import torch.nn as nn
import torchvision
from torch.autograd import Variable
import random
import os
import torch.nn.functional as F
from tqdm import tqdm

# DATA 

In [None]:
def read_alphabets(alphabet_directory_path, alphabet_directory_name, verbose=False):
    """
    Reads every character image found under one alphabet directory.

    Each sub-directory of ``alphabet_directory_path`` is treated as one
    character (class). Every readable image inside it is resized to 500x500
    and stored together with its 90, 180 and 270 degree rotations; all four
    copies are labelled with the name of the character's directory.

    Args:
        alphabet_directory_path (str): directory holding one sub-directory per character
        alphabet_directory_name (str): name of the alphabet; accepted for
            compatibility with existing callers but not used
        verbose (bool): when True, print the path of every image as it is read
    Returns:
        tuple: (np.array of images, np.array of labels), four entries per readable image
    """
    datax = []
    datay = []
    # sorted() makes the traversal order independent of the file system
    for character in sorted(os.listdir(alphabet_directory_path)):
        character_path = os.path.join(alphabet_directory_path, character)
        if not os.path.isdir(character_path):
            # stray files next to the character folders are not classes
            continue
        for img in sorted(os.listdir(character_path)):
            image_path = os.path.join(character_path, img)
            if verbose:
                print(image_path)
            raw = cv2.imread(image_path)
            if raw is None:
                # cv2.imread signals failure by returning None instead of raising,
                # which would otherwise crash cv2.resize below
                print('Skipping unreadable image: ' + image_path)
                continue
            image = cv2.resize(raw, (500, 500))
            # augmentation: the three remaining right-angle rotations
            rotations = [ndimage.rotate(image, angle) for angle in (90, 180, 270)]
            datax.append(image)
            datax.extend(rotations)
            datay.extend([character] * 4)
    return np.array(datax), np.array(datay)

def read_images(base_directory):
    """
    Reads all the alphabets from the base_directory.

    Every sub-directory of ``base_directory`` is passed to ``read_alphabets``
    and the per-alphabet results are joined into one image array and one
    label array.

    Args:
        base_directory (str): directory holding one sub-directory per alphabet
    Returns:
        tuple: (images, labels) joined over all alphabets, or (None, None)
        when no image was read at all
    """
    image_parts = []
    label_parts = []
    for directory in sorted(os.listdir(base_directory)):
        alphabet_path = os.path.join(base_directory, directory)
        if not os.path.isdir(alphabet_path):
            # only directories are alphabets; ignore stray files
            continue
        # trailing separator kept because read_alphabets may build paths by concatenation
        images, labels = read_alphabets(alphabet_path + '/', directory)
        if len(labels) == 0:
            # an empty alphabet yields a shapeless array that cannot be stacked
            continue
        image_parts.append(images)
        label_parts.append(labels)
    if not image_parts:
        return None, None
    # join once at the end instead of re-copying the growing array per alphabet
    return np.concatenate(image_parts, axis=0), np.concatenate(label_parts)




def extract_sample(n_way, n_support, n_query, datax, datay):
    """
    Picks a random episode of n_way classes with n_support + n_query images each.

    Args:
        n_way (int): number of classes in the episode
        n_support (int): number of support images per class
        n_query (int): number of query images per class
        datax (np.array): images, first axis indexes the image, channels last
        datay (np.array): labels, one per image in datax
    Returns:
        dict: 'images' is a float tensor of shape
        (n_way, n_support + n_query, channels, height, width); 'n_way',
        'n_support' and 'n_query' echo the arguments
    Raises:
        ValueError: if a chosen class has fewer than n_support + n_query images
    """
    per_class = n_support + n_query
    unique_y = np.unique(datay)
    K = np.random.choice(unique_y, n_way, replace=False)
    sample = []
    for cls in K:
        cls_indices = np.flatnonzero(datay == cls)
        if len(cls_indices) < per_class:
            # without this check the episode would silently come out short
            raise ValueError(
                'class {!r} has {} images but n_support + n_query = {} are required'.format(
                    cls, len(cls_indices), per_class))
        # shuffle the indices rather than the images so only the chosen ones are copied
        chosen = np.random.permutation(cls_indices)[:per_class]
        sample.append(datax[chosen])
    # (n_way, per_class, H, W, C) -> float tensor -> (n_way, per_class, C, H, W)
    sample = torch.from_numpy(np.array(sample)).float()
    sample = sample.permute(0, 1, 4, 2, 3)
    df = {'images': sample, 'n_way': n_way, 'n_support': n_support, 'n_query': n_query}
    return df


def display_sample(sample):
    """
    Displays sample in a grid, one row per class.

    Args:
        sample (torch.Tensor): 5D tensor of images, indexed as
            (class, example, channels, height, width)
    """
    # make_grid needs a 4D batch, so merge the class and example axes
    sample_4D = sample.reshape(sample.shape[0] * sample.shape[1], *sample.shape[2:])
    # one grid row per class
    out = torchvision.utils.make_grid(sample_4D, nrow=sample.shape[1])
    grid = out.permute(1, 2, 0).float()
    # imshow clips float RGB data to [0, 1]; rescale 0-255 pixel values so the
    # images are not rendered as almost entirely white
    if grid.numel() > 0 and grid.max() > 1:
        grid = grid / 255.0
    # NOTE(review): if the images come from cv2.imread the channels are probably
    # BGR, so colours may look swapped here; confirm before adding a conversion
    plt.figure(figsize = (16,7))
    plt.imshow(grid.clamp(0, 1).detach().cpu().numpy())



In [None]:
def euclidean_dist(x, y):
    """
    Pairwise squared euclidean distance between the rows of x and the rows of y.

    No square root is taken: entry (i, j) of the result is the sum of squared
    differences between x[i] and y[j].

    Args:
        x (torch.Tensor): shape (n, d)
        y (torch.Tensor): shape (m, d)
    Returns:
        torch.Tensor: shape (n, m)
    """
    rows_x, rows_y = x.size(0), y.size(0)
    feat = x.size(1)
    assert feat == y.size(1)
    target_shape = (rows_x, rows_y, feat)
    # line both inputs up on a common (n, m, d) shape before subtracting
    x_tiled = x.unsqueeze(1).expand(*target_shape)
    y_tiled = y.unsqueeze(0).expand(*target_shape)
    diff = x_tiled - y_tiled
    return diff.pow(2).sum(2)

# TRAIN AND TEST

In [None]:
def train(model, optimizer, train_x, train_y, n_way, n_support, n_query, max_epoch, epoch_size, PATH = "model/protonet.pt"):
    """
    Trains the protonet on randomly sampled episodes.

    Args:
        model: object exposing set_forward_loss(sample) -> (loss tensor, dict with 'loss' and 'acc')
        optimizer: torch optimizer bound to the model's parameters
        train_x (np.array): images of training set
        train_y (np.array): labels of training set
        n_way (int): number of classes in a classification task
        n_support (int): number of labeled examples per class in the support set
        n_query (int): number of labeled examples per class in the query set
        max_epoch (int): number of epochs to train for
        epoch_size (int): number of episodes per epoch
        PATH (str): accepted for compatibility; this function does not read or
            write it (NOTE(review): the default suggests a checkpoint path;
            confirm whether saving was intended)
    """
    # step_size=1 with gamma=0.5: the learning rate is halved after every epoch
    scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma = 0.5, last_epoch = -1)
    # make sure layers such as BatchNorm are in training mode even if the model
    # was switched to eval mode before this call
    model.train()
    for epoch in range(max_epoch):
        running_loss = 0.0
        running_acc = 0.0

        for episode in tqdm(range(epoch_size)):
            sample = extract_sample(n_way, n_support, n_query, train_x, train_y)
            optimizer.zero_grad()
            loss, output = model.set_forward_loss(sample)
            running_loss += output['loss']
            running_acc += output['acc']
            loss.backward()
            optimizer.step()
        epoch_loss = running_loss / epoch_size
        epoch_acc = running_acc / epoch_size
        print('Epoch {:d} -- Loss: {:.4f} Acc: {:.4f}'.format(epoch+1,epoch_loss, epoch_acc))
        scheduler.step()

        

        
        
def test(model, test_x, test_y, n_way, n_support, n_query, test_episode):
    """
    Tests the protonet and prints the average loss and accuracy.

    The model is evaluated in eval mode with gradient tracking disabled, so
    the test data neither updates BatchNorm running statistics nor builds
    autograd graphs. The model's previous train/eval mode is restored on exit.

    Args:
        model: trained model
        test_x (np.array): images of testing set
        test_y (np.array): labels of testing set
        n_way (int): number of classes in a classification task
        n_support (int): number of labeled examples per class in the support set
        n_query (int): number of labeled examples per class in the query set
        test_episode (int): number of episodes to test on
    """
    running_loss = 0.0
    running_acc = 0.0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for episode in tqdm(range(test_episode)):
                sample = extract_sample(n_way, n_support, n_query, test_x, test_y)
                loss, output = model.set_forward_loss(sample)
                running_loss += output['loss']
                running_acc += output['acc']
    finally:
        # leave the model in whatever mode the caller had it in
        model.train(was_training)
    avg_loss = running_loss / test_episode
    avg_acc = running_acc / test_episode
    print('Test results -- Loss: {:.4f} Acc: {:.4f}'.format(avg_loss, avg_acc))
    

# PROTONET CLASS

In [None]:
class Flatten(nn.Module):
    """Keeps the first (batch) dimension and merges all remaining ones.

    A tensor of shape [batch_size, d1, d2, ..., dn] becomes a 2-dimensional
    tensor of shape [batch_size, d1*d2*...*dn].
    """

    def forward(self, x):
        batch_size = x.size(0)
        # -1 lets view() work out the merged size of the trailing dimensions
        return x.view(batch_size, -1)

    
    
def conv_block(in_channels, out_channels):
    """
    Builds one encoder stage: 3x3 convolution (padding 1), batch norm, ReLU, 2x2 max pooling.

    Args:
        in_channels (int): channels of the incoming feature map
        out_channels (int): channels produced by the convolution
    Returns:
        nn.Sequential: the four layers in the order listed above
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
    ]
    return nn.Sequential(*layers)

def load_protonet_conv(x_dim, hid_dim, z_dim):
    """
    Builds the prototypical network: four conv blocks followed by a flatten layer.

    Args:
        x_dim (tuple): dimension of input image; only x_dim[0], the channel
            count, is read here
        hid_dim (int): number of channels produced by the first three conv blocks
        z_dim (int): number of channels produced by the last conv block
    Returns:
        Protonet: model wrapping the assembled encoder
    """
    in_channels = x_dim[0]
    # (input channels, output channels) of each successive conv block
    channel_plan = [
        (in_channels, hid_dim),
        (hid_dim, hid_dim),
        (hid_dim, hid_dim),
        (hid_dim, z_dim),
    ]
    stages = [conv_block(c_in, c_out) for c_in, c_out in channel_plan]
    stages.append(Flatten())
    encoder = nn.Sequential(*stages)
    return Protonet(encoder)




class Protonet(nn.Module):
    """Prototypical network: embeds images with ``encoder`` and classifies each
    query image by its squared distance to the per-class mean support embedding."""

    def __init__(self, encoder):
        """
        Args:
            encoder (nn.Module): maps a batch of images to 2D embeddings (batch, z_dim)
        """
        super(Protonet,self).__init__()
        self.encoder = encoder

    def set_forward_loss(self, sample):
        """
        Computes loss, accuracy and predictions for one classification episode.

        Args:
            sample (dict): 'images' is a tensor of shape
                (n_way, n_support+n_query, (dim)); 'n_way', 'n_support' and
                'n_query' are ints describing the episode
        Returns:
            tuple: (loss tensor to backpropagate, dict with 'loss' (float),
            'acc' (float) and 'y_hat' (tensor of shape (n_way, n_query)))
        """
        sample_images = sample['images']
        n_way = sample['n_way']
        n_support = sample['n_support']
        n_query = sample['n_query']

        # along axis 1 the first n_support images are the support set, the rest the query set
        x_support = sample_images[:, :n_support]
        x_query = sample_images[:, n_support:]

        # encode images of the support and the query set in a single batch
        x = torch.cat([x_support.contiguous().view(n_way * n_support, *x_support.size()[2:]),
                       x_query.contiguous().view(n_way * n_query, *x_query.size()[2:])], 0)

        # call the module (not .forward) so registered hooks are honoured
        z = self.encoder(x)
        z_dim = z.size(-1)
        # class prototype = mean of that class's support embeddings
        z_proto = z[:n_way*n_support].view(n_way, n_support, z_dim).mean(1)
        z_query = z[n_way*n_support:]

        # compute distances of every query embedding to every prototype
        dists = euclidean_dist(z_query, z_proto)

        # compute log-probabilities: the closer the prototype, the higher the score
        log_p_y = F.log_softmax(-dists, dim=1).view(n_way, n_query, -1)

        # target indices are 0 ... n_way-1: entry [c, q, 0] is c, the true class
        # of every query image in row c; created on the same device as the scores
        target_inds = torch.arange(0, n_way, dtype=torch.long, device=log_p_y.device)
        target_inds = target_inds.view(n_way, 1, 1).expand(n_way, n_query, 1)

        loss_val = -log_p_y.gather(2, target_inds).view(-1).mean()
        _, y_hat = log_p_y.max(2)
        # squeeze only the trailing axis: a bare squeeze() would also drop the
        # n_query axis when n_query == 1 and make eq() broadcast to (n_way, n_way)
        acc_val = torch.eq(y_hat, target_inds.squeeze(2)).float().mean()

        df = {'loss': loss_val.item(),'acc': acc_val.item(),'y_hat': y_hat}
        return loss_val, df


In [None]:
%%time 
# Load the training and test sets from their directories (paths are relative
# to the notebook). Each call returns an (images, labels) pair.
trainx, trainy = read_images('../Licheni')
testx, testy = read_images('../lich')

In [None]:
%%time
# Seed every RNG involved (episode sampling uses numpy, weight init uses torch)
# so that a Restart & Run All reproduces the same training run
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Only the channel count x_dim[0] is consumed by load_protonet_conv; the
# spatial size is stated as 500x500 to match the resize done when loading images
model = load_protonet_conv(x_dim=(3,500,500),hid_dim=64,z_dim=64)
optimizer = optim.Adam(model.parameters(), lr = 0.001)

# episode layout: 3 classes, 2 support and 2 query images per class
n_way = 3
n_support = 2
n_query = 2

train_x = trainx
train_y = trainy

# 5 epochs of 100 episodes each
max_epoch = 5
epoch_size = 100

train(model, optimizer, train_x, train_y, n_way, n_support, n_query, max_epoch, epoch_size)

# TEST

In [None]:
# same episode layout as used for training
n_way = 3
n_support = 2
n_query = 2

# testx / testy were already read from '../lich' in the data-loading cell;
# reuse them instead of walking the directory and decoding every image again
test_x = testx
test_y = testy

test_episode = 100

test(model, test_x, test_y, n_way, n_support, n_query, test_episode)

# TEST ON SPECIFIC DATA

In [None]:
# Draw one episode from the test set with a larger layout than used above
# (4 support + 4 query images per class) and show it as an image grid.
# Note: these assignments overwrite the notebook-level n_way / n_support / n_query.
n_way = 3
n_support = 4
n_query = 4
# every sampled class needs n_support + n_query = 8 images
my_sample = extract_sample(n_way, n_support, n_query, test_x, test_y)
display_sample(my_sample['images'])


In [None]:
%%time
# Run the model on the single episode drawn above; my_loss is the loss tensor,
# my_output the dict with 'loss', 'acc' and 'y_hat' for that episode.
my_loss, my_output = model.set_forward_loss(my_sample)