In [209]:
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset,DataLoader
import cv2
import torch.nn as nn
from collections import OrderedDict
import torch.nn.functional as F

In [210]:
# Index of training images; the rest of the notebook reads its 'Path' and
# 'Label' columns.
trainframe = pd.read_csv("train_data.csv")

# Largest label value, used downstream as the upper bound when sampling classes.
# NOTE(review): if labels start at 0 this is one less than the number of
# classes, so the last class would never be sampled -- confirm against the CSV.
train_classes = trainframe['Label'].max()

# Number of rows each class is assumed to occupy in `trainframe`.
images_per_class = 20

In [130]:
# class OmniTask():
#     def __init__(self,labels,size):
#         self.labels=labels
#         self.size=size
#     def sample_data(self):
#         labelframe=trainframe.iloc[(self.label*samples_per_class):((self.label+1)*samples_per_class)]
#         labelframe.reset_index(inplace=True,drop=True)
#         rand_samples=np.random.choice(samples_per_class,replace=True,size=self.size)
#         images=[]
#         labels=[]
#         for sample in rand_samples:
#             img_path=labelframe.iloc[s]['Path']
#             img_label=labelframe.iloc[s]['Label']
#             img=cv2.imread(img_path)
#             images.append(img)
#             labels.append(img_label)
#         images=torch.stack([torch.tensor(images[i]) for i in range(len(images))])
#         images=images.float()
#         labels=torch.FloatTensor(labels)
#         miniset=torch.utils.data.TensorDataset(images,labels)
#         return miniset        

In [211]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cpu')

In [325]:
class Task():
    """One few-shot classification episode built from the global ``trainframe``.

    Parameters
    ----------
    sample_classes : iterable of int
        Class indices making up the episode. Class ``c`` is read from rows
        ``c*images_per_class`` to ``(c+1)*images_per_class`` of ``trainframe``
        (assumes the frame is grouped by class with exactly
        ``images_per_class`` rows each -- TODO confirm against the CSV).
    num_instances : int
        Number of images drawn per class (the "k" in k-shot).
    """

    def __init__(self,sample_classes,num_instances):
        self.sample_classes=sample_classes
        self.num_instances=num_instances

    def sample_data(self):
        """Draw ``num_instances`` images (with replacement) for every class.

        Returns
        -------
        torch.utils.data.TensorDataset
            ``(images, labels)`` where ``images`` is a float tensor of shape
            ``(len(sample_classes)*num_instances, 1, 28, 28)`` and ``labels``
            is a long tensor of episode-local labels: 0 for the first entry
            of ``sample_classes``, 1 for the second, and so on.

        Raises
        ------
        FileNotFoundError
            If an image listed in ``trainframe`` cannot be read.
        """
        images=[]
        labels=[]
        # Labels are the position of the class inside this episode, not the
        # dataset-wide label stored in the CSV.
        for label,c in enumerate(self.sample_classes):
            start=c*images_per_class
            # reset_index on a fresh copy instead of inplace on an iloc slice,
            # which avoids pandas' SettingWithCopy warnings.
            cframe=trainframe.iloc[start:start+images_per_class].reset_index(drop=True)
            sample_idxs=np.random.choice(images_per_class,replace=True,size=self.num_instances)
            for s in sample_idxs:
                # Fix: the original looked this up in an undefined `labelframe`.
                img_path=cframe.iloc[s]['Path']
                img=cv2.imread(img_path,0)  # flag 0 -> single-channel grayscale
                if img is None:
                    # cv2.imread signals failure by returning None, not raising.
                    raise FileNotFoundError("Could not read image: {}".format(img_path))
                img=cv2.resize(img,dsize=(28,28))
                img=img[...,np.newaxis]
                # np.transpose reverses every axis: (28, 28, 1) -> (1, 28, 28).
                # This yields channel-first layout and also swaps rows with
                # columns; it is applied identically to every image.
                img=np.transpose(img)
                images.append(img)
                labels.append(label)
        images=torch.stack([torch.tensor(img) for img in images]).float()
        labels=torch.tensor(labels,dtype=torch.long)
        return torch.utils.data.TensorDataset(images,labels)

class TaskDistribution():
    """Factory that draws random ``Task`` episodes.

    Parameters
    ----------
    num_classes : int
        Number of classes per task (the "N" in N-way).
    num_instances : int
        Number of images per class, passed through to ``Task``.
    """

    def __init__(self,num_classes,num_instances):
        self.num_classes=num_classes
        self.num_instances=num_instances

    def sample_task(self):
        """Return a ``Task`` over ``num_classes`` distinct class indices.

        Class indices are drawn from ``range(train_classes)``.

        Raises
        ------
        ValueError
            (from ``np.random.choice``) if ``num_classes`` exceeds
            ``train_classes``.
        """
        # Fix: sample WITHOUT replacement. With replace=True the same class
        # could be drawn twice and then be given two different episode-local
        # labels by Task.sample_data, making the task contradictory.
        sample_classes=np.random.choice(train_classes,replace=False,size=self.num_classes)
        return Task(sample_classes,self.num_instances)
        

In [348]:
class ConvNet(nn.Module):
    """Small convolutional classifier with a functional twin for MAML.

    The network is three ``conv(3x3, 64) -> batchnorm -> ReLU -> maxpool(2)``
    stages followed by a linear layer over 64 features. Both forward paths
    reshape the feature map to 64 values per sample, which is what a
    1x28x28 input produces.

    Parameters
    ----------
    num_classes : int
        Size of the output layer.
    """

    def __init__(self,num_classes):
        super(ConvNet,self).__init__()
        self.features = nn.Sequential(OrderedDict([
                ('conv1', nn.Conv2d(1, 64, 3)),
                ('bn1', nn.BatchNorm2d(64, momentum=1, affine=True)),
                ('relu1', nn.ReLU(inplace=True)),
                ('pool1', nn.MaxPool2d(2,2)),
                ('conv2', nn.Conv2d(64,64,3)),
                ('bn2', nn.BatchNorm2d(64, momentum=1, affine=True)),
                ('relu2', nn.ReLU(inplace=True)),
                ('pool2', nn.MaxPool2d(2,2)),
                ('conv3', nn.Conv2d(64,64,3)),
                ('bn3', nn.BatchNorm2d(64, momentum=1, affine=True)),
                ('relu3', nn.ReLU(inplace=True)),
                ('pool3', nn.MaxPool2d(2,2))]))
        self.add_module('fc', nn.Linear(64,num_classes))

    def forward(self,x):
        """Standard forward pass using the module's own parameters."""
        # Fix: the feature extractor is registered as `features`; the original
        # called a non-existent `self.net`.
        x=self.features(x)
        x=x.view(-1,64)
        x=self.fc(x)
        return x

    @staticmethod
    def _functional_stage(x,weights,idx):
        """Apply conv/batchnorm/ReLU/maxpool stage ``idx`` using ``weights``."""
        x = F.conv2d(x, weights['features.conv%d.weight' % idx],
                     weights['features.conv%d.bias' % idx])
        channels = x.size(1)
        # Throw-away running statistics, created on the input's device rather
        # than with a hard-coded .cuda() so the pass also works on CPU.
        # training=True makes batch_norm normalise with the batch statistics,
        # as the BatchNorm2d layers in `forward` do in train mode; without it
        # the zero/one dummies below would be used and nothing is normalised.
        x = F.batch_norm(x,
                         running_mean = torch.zeros(channels, device=x.device, dtype=x.dtype),
                         running_var = torch.ones(channels, device=x.device, dtype=x.dtype),
                         weight = weights['features.bn%d.weight' % idx],
                         bias = weights['features.bn%d.bias' % idx],
                         training=True, momentum=1)
        x = F.relu(x)
        return F.max_pool2d(x, kernel_size=2, stride=2)

    def argforward(self,x,weights):
        """Forward pass that uses the tensors in ``weights`` as parameters.

        Parameters
        ----------
        x : torch.Tensor
            Input batch.
        weights : mapping of str to torch.Tensor
            Keyed like ``self.named_parameters()`` (``'features.conv1.weight'``,
            ..., ``'fc.bias'``).

        Returns
        -------
        torch.Tensor
            Output of the final linear layer, one row per sample.
        """
        for idx in (1,2,3):
            x = self._functional_stage(x,weights,idx)
        x = x.view(x.size(0), 64)
        x = F.linear(x, weights['fc.weight'], weights['fc.bias'])
        return x


In [327]:
def get_weights(keys,params):
    """Return an OrderedDict pairing ``keys`` with clones of ``params``' values.

    Every value of ``params`` is cloned first; the clones are then matched
    positionally with ``keys`` (extra entries on either side are dropped).
    """
    clones=[value.clone() for value in params.values()]
    return OrderedDict(zip(keys,clones))

In [345]:
class OmniMAML():
    """MAML-style meta-trainer for a network exposing ``argforward``.

    Parameters
    ----------
    net : nn.Module
        Model providing ``named_parameters()`` and ``argforward(x, weights)``.
    alpha : float
        Step size of the per-task (inner) gradient step.
    beta : float
        Learning rate of the Adam meta-optimiser.
    k : int
        Images per class in each task ("k shot").
    num_metatasks : int
        Number of tasks whose losses are summed for one meta-update.
    N : int
        Classes per task ("N way").
    """

    def __init__(self,net,alpha,beta,k,num_metatasks,N):
        self.net=net
        self.alpha=alpha
        self.beta=beta
        self.k=k #"k shot" classification
        self.num_metatasks=num_metatasks
        self.weights=OrderedDict((name, param) for (name,param) in self.net.named_parameters())
        self.criterion=nn.CrossEntropyLoss()
        self.optimiser=torch.optim.Adam(list(self.weights.values()),self.beta)
        self.meta_losses=[]
        self.plot_every=10
        self.print_every=500
        # Misspelled duplicate of `num_metatasks`; kept only so that any code
        # reading the old attribute name keeps working.
        self.num_metataks=num_metatasks
        self.N=N #"N way" classification

    def _task_loss(self,task,weights):
        """Sample one batch from ``task`` and return its scaled loss under ``weights``."""
        dset=task.sample_data()
        loader=DataLoader(dset,batch_size=self.k*self.N,shuffle=True)
        # Fix: next(iter(...)) instead of the Python-2 style `.__iter__().next()`.
        x,y=next(iter(loader))
        x=x.to(device)
        y=y.to(device)
        output=self.net.argforward(x,weights)
        # The criterion's result is additionally divided by the batch size,
        # as in the original code.
        return self.criterion(output,y)/(self.k*self.N)

    def inner_loop(self,task):
        """Adapt to ``task`` with one gradient step and return the meta-loss.

        The meta-loss is computed on a second, freshly sampled batch from the
        same task using the adapted weights. It stays connected to
        ``self.weights`` through the cloned parameters. The inner gradients
        are computed without ``create_graph``, so they enter the update as
        constants.
        """
        fast_weights=OrderedDict((name,param.clone())
                                 for (name,param) in self.weights.items())
        loss=self._task_loss(task,fast_weights)
        grads=torch.autograd.grad(loss,list(fast_weights.values()))
        adapted=OrderedDict((name,param-self.alpha*g) for ((name,param),g) in
                            zip(fast_weights.items(),grads))
        return self._task_loss(task,adapted)

    def final_loop(self,num_epochs):
        """Run ``num_epochs`` meta-updates.

        Each epoch sums the meta-loss of ``num_metatasks`` freshly sampled
        tasks and takes one optimiser step on that sum. The per-task average
        is accumulated and, every ``plot_every`` epochs, its mean is appended
        to ``self.meta_losses``.
        """
        total_loss=0
        # The distribution has no per-epoch state, so build it once.
        tasks=TaskDistribution(num_classes=self.N,num_instances=self.k)
        for epoch in range(1,num_epochs+1):
            metaloss_sum=0
            for _ in range(self.num_metatasks):
                metaloss_sum+=self.inner_loop(tasks.sample_task())
            self.optimiser.zero_grad()
            metaloss_sum.backward()
            self.optimiser.step()
            total_loss+=metaloss_sum.item()/self.num_metatasks
            if epoch % self.print_every == 0:
                print("{}/{}. loss: {}".format(epoch, num_epochs, total_loss / self.plot_every))
            if epoch%self.plot_every==0:
                self.meta_losses.append(total_loss/self.plot_every)
                total_loss = 0
            if (epoch%100)==0:
                print("Epoch "+str(epoch)+" completed.")

In [346]:
# N-way setting shared by the classifier head and the meta-learner.
N = 5

# Module.to() returns the module itself, so construction and device
# placement can be chained.
net = ConvNet(num_classes=N).to(device)

maml = OmniMAML(
    net,
    alpha=0.1,
    beta=0.001,
    k=3,
    num_metatasks=32,
    N=N,
)

In [347]:
# Run a single meta-training epoch (one meta-update over num_metatasks tasks).
maml.final_loop(num_epochs=1)

In [335]:
# Scratch cell: build a fresh 5-class network (rebinding `net`) and collect
# its parameters by name for inspection.
net = ConvNet(5)
weights = OrderedDict(net.named_parameters())
weights.items()

odict_items([('features.conv1.weight', Parameter containing:
tensor([[[[ 1.9403e-01, -2.3679e-01,  1.5356e-02],
          [ 1.7192e-01, -1.7619e-01, -5.1181e-02],
          [ 2.6803e-01,  1.4945e-01, -2.5490e-01]]],


        [[[ 1.5718e-01,  2.5654e-01, -6.4612e-02],
          [ 1.2119e-01, -1.5504e-01, -2.9659e-01],
          [ 1.9620e-02, -1.9034e-01,  9.5991e-03]]],


        [[[ 2.3084e-01,  1.7164e-01,  2.7283e-02],
          [ 2.6726e-01,  7.3758e-02,  5.6239e-02],
          [ 1.4174e-01,  1.0790e-02,  9.6042e-02]]],


        [[[-1.6255e-01, -3.1555e-01, -1.8513e-01],
          [-1.3194e-01, -3.2913e-01,  2.7115e-01],
          [-2.6304e-02, -5.0627e-02,  1.5131e-01]]],


        [[[-1.6226e-01,  2.4513e-02, -2.6430e-01],
          [-5.9958e-02,  8.8895e-02,  1.0743e-01],
          [-2.3213e-01,  1.1048e-01,  2.3489e-01]]],


        [[[ 7.4756e-03,  1.4024e-01, -2.9164e-01],
          [-1.3479e-01, -1.9207e-01, -2.0314e-01],
          [-2.7353e-01,  2.6142e-01, -2.3706e-01]]],

In [338]:
# Scratch cell: clone every parameter tensor and rebuild a name -> clone
# mapping, mirroring inline what get_weights() does.
temp = [w.clone() for w in weights.values()]
l = list(zip(weights.keys(), temp))
dic = OrderedDict(l)
dic.items()

odict_items([('features.conv1.weight', tensor([[[[ 1.9403e-01, -2.3679e-01,  1.5356e-02],
          [ 1.7192e-01, -1.7619e-01, -5.1181e-02],
          [ 2.6803e-01,  1.4945e-01, -2.5490e-01]]],


        [[[ 1.5718e-01,  2.5654e-01, -6.4612e-02],
          [ 1.2119e-01, -1.5504e-01, -2.9659e-01],
          [ 1.9620e-02, -1.9034e-01,  9.5991e-03]]],


        [[[ 2.3084e-01,  1.7164e-01,  2.7283e-02],
          [ 2.6726e-01,  7.3758e-02,  5.6239e-02],
          [ 1.4174e-01,  1.0790e-02,  9.6042e-02]]],


        [[[-1.6255e-01, -3.1555e-01, -1.8513e-01],
          [-1.3194e-01, -3.2913e-01,  2.7115e-01],
          [-2.6304e-02, -5.0627e-02,  1.5131e-01]]],


        [[[-1.6226e-01,  2.4513e-02, -2.6430e-01],
          [-5.9958e-02,  8.8895e-02,  1.0743e-01],
          [-2.3213e-01,  1.1048e-01,  2.3489e-01]]],


        [[[ 7.4756e-03,  1.4024e-01, -2.9164e-01],
          [-1.3479e-01, -1.9207e-01, -2.0314e-01],
          [-2.7353e-01,  2.6142e-01, -2.3706e-01]]],


        [[[ 2.6504e