In [None]:
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Created by: BoyuanJiang
# College of Information Science & Electronic Engineering,ZheJiang University
# Email: ginger188@gmail.com
# Copyright (c) 2017

# @Time    :17-8-29 22:26
# @FILE    :mainOmniglot.py
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++


from data_loader import TreeBankDataset
from OmniglotBuilder import OmniglotBuilder
import tqdm

In [None]:
# ---- Experiment setup ----
# Episodes per batch, and the support-set layout of each episode:
# classes_per_set classes with samples_per_class examples each.
batch_size = 2
classes_per_set, samples_per_class = 5, 1
# NOTE(review): `channels` is not referenced in the visible cells (likely a leftover
# from the image/Omniglot version of this script).
channels = 1
# Full context embeddings (fce) are not implemented here; the flag stays commented out.
# fce = True

# ---- Training setup ----
total_epochs = 100
total_train_batches, total_val_batches, total_test_batches = 1000, 250, 500
# Presumably tracks the best validation accuracy seen; not updated in the visible cells.
best_val_acc = 0.0

In [None]:
# Episode loader over the tree-bank data. seed=2017 presumably makes episode
# sampling reproducible -- confirm in data_loader.TreeBankDataset.
data = TreeBankDataset(
    batch_size=batch_size,
    classes_per_set=classes_per_set,
    samples_per_class=samples_per_class,
    seed=2017,
    shuffle=True,
    use_cache=True,
)

In [None]:
import torch
import torch.nn as nn
import math
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable

In [None]:
def convLayer(in_channels, out_channels, keep_prob=0.0):
    """Build one conv -> ReLU -> BatchNorm -> 2x2 max-pool -> dropout stage.

    The 3x3 convolution uses stride 1 and padding 1, so it keeps the spatial
    size; the 2x2 / stride-2 max-pool then halves height and width.

    :param in_channels: number of input feature maps.
    :param out_channels: number of output feature maps.
    :param keep_prob: NOTE(review): despite its name, this value is handed to
        ``nn.Dropout`` as the *drop* probability, not a keep probability.
        The default 0.0 therefore disables dropout.
    :return: an ``nn.Sequential`` holding the five layers in the order above.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, 3, 1, 1),
        nn.ReLU(True),
        nn.BatchNorm2d(out_channels),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Dropout(keep_prob),
    ]
    return nn.Sequential(*layers)

class Classifier(nn.Module):
    """Encode a batch of vector sequences with a (unidirectional) LSTM.

    The LSTM output at the last time step is projected by a linear layer to
    ``output_size`` features.
    """

    def __init__(self, hidden_size, num_layers, vector_dim, output_size, use_cuda, batch_size=1):
        """
        :param hidden_size: number of hidden units in each LSTM layer.
        :param num_layers: number of stacked LSTM layers.
        :param vector_dim: size of each input vector (the LSTM ``input_size``).
        :param output_size: number of features produced by the final linear layer.
        :param use_cuda: if True, the initial ``self.hidden`` is created on the GPU.
        :param batch_size: batch dimension of the initial ``self.hidden``. ``forward``
            sizes its zero state from the actual input, so this does not restrict
            the batch size that ``forward`` accepts.
        """
        super(Classifier, self).__init__()
        self.batch_size = batch_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vector_dim = vector_dim
        self.use_cuda = use_cuda
        self.lstm = nn.LSTM(input_size=self.vector_dim, num_layers=self.num_layers,
                            hidden_size=self.hidden_size, batch_first=True)
        self.linear = nn.Linear(self.hidden_size, output_size)
        self.hidden = self.init_hidden(self.use_cuda)

    def init_hidden(self, use_cuda, batch_size=None):
        """Return a zero ``(h_0, c_0)`` pair of shape (num_layers, batch_size, hidden_size).

        :param use_cuda: move the tensors to the GPU when True.
        :param batch_size: batch dimension of the state; defaults to ``self.batch_size``.
        """
        if batch_size is None:
            batch_size = self.batch_size
        shape = (self.lstm.num_layers, batch_size, self.lstm.hidden_size)
        h0, c0 = torch.zeros(shape), torch.zeros(shape)
        if use_cuda:
            h0, c0 = h0.cuda(), c0.cuda()
        return h0, c0

    def repackage_hidden(self, h):
        """Detach a hidden state (a tensor or nested tuple of tensors) from its history.

        ``isinstance`` is used instead of ``type(h) == Variable``: on modern PyTorch a
        tensor's exact type is ``torch.Tensor``, so the old check fell through to
        iterating over the tensor and eventually failed on 0-d elements.
        """
        if isinstance(h, torch.Tensor):
            return h.detach()
        return tuple(self.repackage_hidden(v) for v in h)

    def forward(self, inputs):
        """Encode ``inputs`` and project the output of the last time step.

        :param inputs: tensor of shape (batch, seq_len, vector_dim) (``batch_first=True``).
        :return: tensor of shape (batch, output_size).
        """
        # Every call encodes independent sequences, so start from a zero state sized to
        # this batch (same dtype/device as the input). Carrying over the previous call's
        # final state made each encoding depend on what was encoded before it, and broke
        # whenever the batch size differed from the one ``self.hidden`` was built for.
        shape = (self.lstm.num_layers, inputs.size(0), self.lstm.hidden_size)
        initial_state = (inputs.new_zeros(shape), inputs.new_zeros(shape))
        output, hidden = self.lstm(inputs, initial_state)
        # Keep the final state available on ``self.hidden``, detached from the graph.
        self.hidden = self.repackage_hidden(hidden)
        last_hidden = output[:, -1, :]
        return self.linear(last_hidden)

class AttentionalClassify(nn.Module):
    """Turn support-set similarity scores into a distribution over support-set classes."""

    def __init__(self):
        super(AttentionalClassify, self).__init__()

    def forward(self, similarities, support_set_y):
        """
        Produce a probability distribution over the support-set classes for each target.
        :param similarities: similarity scores, shape [batch_size, sequence_length]
        :param support_set_y: one-hot support labels, shape [batch_size, sequence_length, classes_num]
        :return: softmax-weighted label mixture, shape [batch_size, classes_num]
        """
        # Softmax over the support-set axis, with the dimension stated explicitly
        # (nn.Softmax() without dim is deprecated and picks the axis implicitly).
        attention = F.softmax(similarities, dim=1)
        # [batch, 1, seq] @ [batch, seq, classes] -> [batch, 1, classes]. Squeeze only the
        # singleton middle axis so a batch of size 1 keeps its batch dimension.
        preds = attention.unsqueeze(1).bmm(support_set_y).squeeze(1)
        return preds

class DistanceNetwork(nn.Module):
    """
    Score each support-set embedding against the target embedding.

    The score is the dot product divided by the *support* embedding's L2 norm only;
    the target embedding's norm is not divided out. This equals the cosine similarity
    times a positive per-target factor, so it preserves the ranking of support items
    but not the scale of the scores.
    """

    def __init__(self):
        super(DistanceNetwork, self).__init__()

    def forward(self, support_set, input_image):
        """
        :param support_set: iterable of support embeddings, each of shape [batch_size, dim]
            (e.g. a tensor of shape [sequence_length, batch_size, dim])
        :param input_image: the target embedding, shape [batch_size, dim]
        :return: similarity scores, shape [batch_size, sequence_length]
        """
        eps = 1e-10
        similarities = []
        for support_image in support_set:
            sum_support = torch.sum(torch.pow(support_image, 2), 1)
            support_magnitude = sum_support.clamp(eps, float("inf")).rsqrt()
            # Row-wise dot product, shape [batch_size]. Unlike bmm(...).squeeze(), this
            # keeps the batch axis when batch_size == 1.
            dot_product = torch.sum(input_image * support_image, 1)
            similarities.append(dot_product * support_magnitude)
        similarities = torch.stack(similarities)  # [sequence_length, batch_size]
        return similarities.t()

class MatchingNetwork(nn.Module):
    def __init__(self, batch_size=32, num_lstm_hidden=100, sequence_embedding_size=100, learning_rate=1e-3, num_classes_per_set=5,
                 num_samples_per_class=1, input_embedding_dim=300, use_cuda=True):
        """
        Matching network over embedded token sequences: a shared LSTM encoder (``g``),
        a support/target similarity (``dn``) and attention over the support labels
        (``classify``).
        :param batch_size: episodes per batch; also passed to the encoder for its initial state
        :param num_lstm_hidden: hidden size of the LSTM encoder
        :param sequence_embedding_size: size of each encoded sequence
        :param learning_rate: stored on the instance; not used inside this class
        :param num_classes_per_set: number of classes in each support set (stored)
        :param num_samples_per_class: support examples per class (stored)
        :param input_embedding_dim: size of each token embedding fed to the encoder
        :param use_cuda: whether the encoder's initial hidden state is created on the GPU
        """
        super(MatchingNetwork, self).__init__()
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.num_classes_per_set = num_classes_per_set
        self.num_samples_per_class = num_samples_per_class
        # TODO: customize number of layers
        self.g = Classifier(hidden_size=num_lstm_hidden, num_layers=1, vector_dim=input_embedding_dim,
                            output_size=sequence_embedding_size, use_cuda=use_cuda, batch_size=batch_size)
        self.dn = DistanceNetwork()
        self.classify = AttentionalClassify()

    def forward(self, support_set_images, support_set_y_one_hot, target_image, target_y):
        """
        Encode support and target sequences, attend over the support labels and score the target.
        Shapes as built in this notebook (presumably; confirm against the data loader):
        :param support_set_images: [batch_size, n_support, seq_len, embed_dim]
        :param support_set_y_one_hot: [batch_size, n_support, num_classes_per_set]
        :param target_image: [batch_size, seq_len, embed_dim]
        :param target_y: integer class index of each target, shape [batch_size]
        :return: (accuracy, loss) -- accuracy is the mean of argmax(preds) == target_y;
            loss is the negative log-likelihood of target_y under the predicted distribution.
        """
        # Encode every support example (axis 1), then the target, with the shared encoder.
        encoded = [self.g(support_set_images[:, i]) for i in range(support_set_images.size(1))]
        encoded.append(self.g(target_image))
        output = torch.stack(encoded)  # [n_support + 1, batch_size, embedding]

        # Similarities between each support embedding and the target embedding.
        similarities = self.dn(support_set=output[:-1], input_image=output[-1])

        # Predicted class distribution for each target.
        preds = self.classify(similarities, support_set_y=support_set_y_one_hot)

        _, indices = preds.max(1)
        accuracy = torch.mean((indices.squeeze() == target_y).float())
        # ``preds`` are already probabilities (softmax-weighted one-hot labels), so take
        # their log and apply NLL; F.cross_entropy would apply a second log-softmax.
        log_preds = torch.log(preds.clamp(min=1e-10))
        crossentropy_loss = F.nll_loss(log_preds, target_y.long())

        return accuracy, crossentropy_loss

In [None]:
# Build the network from the experiment config so its batch size matches the
# loader's batches (batch_size=1 here mismatched the loader's batch_size).
net = MatchingNetwork(batch_size=batch_size, num_classes_per_set=classes_per_set,
                      num_samples_per_class=samples_per_class, use_cuda=False)

In [None]:
# TODO: replace this randomly initialised embedding with proper (e.g. pretrained) word vectors.
# Seed torch so the random embedding weights are reproducible across runs.
torch.manual_seed(2017)
embedding_dim = 300  # must equal MatchingNetwork's input_embedding_dim (default 300)
embed = nn.Embedding(len(data.word_to_idx), embedding_dim)

In [None]:
def one_hot(y, classes):
    """One-hot encode an integer label array.

    :param y: array-like of integer class indices, any shape (e.g. (batch, n_support)).
    :param classes: number of classes; labels outside ``[-classes, classes)`` raise IndexError.
    :return: float64 ``np.ndarray`` of shape ``(*y.shape, classes)`` with a 1 at each
        label's position and 0 elsewhere.
    """
    labels = np.asarray(y)
    # Row k of the identity matrix is the one-hot vector for class k; fancy indexing picks
    # one row per label for labels of any rank (the old per-row loop only handled 2-D input).
    return np.eye(classes)[labels]

In [None]:
support_x, support_y, target_x, target_y = data.get_train_batch()

# presumably (batch, n_support, seq_len) token ids -- the reshape below assumes it
shape = support_x.shape
embed_dim = embed.embedding_dim

# Embed all support sequences in one call, then restore the (batch, n_support) axes.
# ``view`` replaces the deprecated non-inplace ``Tensor.resize``.
flat_support_x = torch.from_numpy(support_x.reshape(shape[0] * shape[1], -1)).long()
embed_support_x = embed(flat_support_x).view(shape[0], shape[1], -1, embed_dim)

embed_target_x = embed(torch.from_numpy(target_x).long())

# One target per episode: flatten to shape (batch,). ``squeeze()`` would drop the
# batch axis entirely when batch_size == 1.
target_y = torch.from_numpy(target_y).long().reshape(-1)

support_y = torch.from_numpy(one_hot(support_y, classes_per_set)).float()

In [None]:
# Run one forward pass on the prepared batch and display (accuracy, loss).
train_accuracy, train_loss = net(embed_support_x, support_y, embed_target_x, target_y)
train_accuracy, train_loss