In [0]:
import os
import numpy as np
import random
import torch
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data.sampler import Sampler
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
import scipy as sp
import scipy.stats
import math
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 


%matplotlib inline

In [0]:
def get_data_folders(train_path = '/content/drive/My Drive/Few Shot Painting Classification/Train',
                     test_path = '/content/drive/My Drive/Few Shot Painting Classification/Test'):
    '''Return the shuffled class folders of the training and testing splits.

    Arguments:-
    train_path : Directory whose sub-directories are the training classes.
    test_path : Directory whose sub-directories are the testing classes.

    Returns:-
    (train_folders , test_folders) : Two lists of class-folder paths. Only
    sub-directories are kept; plain files inside the split roots are ignored.

    Note: the global `random` generator is re-seeded with 42 on every call.
    '''

    def _class_folders(root):
        # os.listdir() returns entries in a filesystem-dependent order, so sort
        # first; otherwise the seeded shuffle below is not reproducible.
        candidates = [os.path.join(root, label) for label in sorted(os.listdir(root))]
        return [folder for folder in candidates if os.path.isdir(folder)]

    train_folders = _class_folders(train_path)
    test_folders = _class_folders(test_path)

    random.seed(42)
    random.shuffle(train_folders)
    random.shuffle(test_folders)

    return train_folders , test_folders


In [0]:
class Task_PaintingsData(object):

    ''' Samples one N-way few-shot task (episode) from a list of class folders.

    Arguments:-
    class_folders : The folders containing your training or testing classes.
    num_classes : The number of classes to sample from all the classes i.e N-way
    train_num : The number of sample-set (support) images per class i.e K-shot
    test_num : The number of query-set images per class

    Attributes set by the constructor:-
    train_roots / test_roots : File paths of the sample set / query set, grouped
                               class by class in the order the classes were drawn.
    train_labels / test_labels : Integer labels (0 .. num_classes-1) aligned with
                                 the corresponding *_roots list.

    Raises:-
    ValueError : If num_classes exceeds the number of folders (raised by
                 random.sample) or a sampled class holds fewer than
                 train_num + test_num files.'''

    def __init__(self , class_folders , num_classes,train_num,test_num):

        self.class_folders = class_folders
        self.num_classes = num_classes
        self.train_num = train_num
        self.test_num = test_num

        way_class_folders = random.sample(self.class_folders , self.num_classes)
        # Labels are episode-local: the i-th sampled folder gets label i.
        labels = dict(zip(way_class_folders , np.array(range(len(way_class_folders)))))

        self.train_roots , self.test_roots = [] , []
        self.train_labels , self.test_labels = [] , []

        for c in way_class_folders:

            # Sort for a reproducible starting order and skip sub-directories
            # (e.g. notebook checkpoint folders), which cannot be opened as images.
            files = [os.path.join(c,x) for x in sorted(os.listdir(c))]
            files = [f for f in files if os.path.isfile(f)]

            # The batch sampler addresses images as i + j*num_instances, so every
            # class must contribute exactly train_num / test_num entries.
            if len(files) < train_num + test_num:
                raise ValueError(
                    f'Class folder {c!r} has {len(files)} files but '
                    f'{train_num + test_num} are required (train_num + test_num).')

            # random.sample over the full list already yields a random permutation.
            shuffled = random.sample(files , len(files))
            class_train = shuffled[:train_num]
            class_test = shuffled[train_num : train_num + test_num]

            self.train_roots += class_train
            self.test_roots += class_test
            # Assign labels directly instead of re-deriving the class from the path.
            self.train_labels += [labels[c]] * len(class_train)
            self.test_labels += [labels[c]] * len(class_test)

    def GetClass(self, sample):
        '''Return the class folder (parent directory) of the file path `sample`.'''

        return os.path.dirname(sample)

In [0]:
class FewShotData(torch.utils.data.Dataset):

    ''' Custom dataset class required by the dataloader.
        Indexing it returns an image and the corresponding label.
        Arguments:-
        task : The Task_PaintingsData object (must expose train_roots,
               train_labels, test_roots and test_labels).
        split : 'train' selects the task's train (sample) set; any other value
                selects its test (query) set.
        transform : Transformations for the images
        target_transform : Transformations for the labels'''

    def __init__(self,task,split = 'train', transform = None, target_transform = None):

        self.task = task
        self.transform = transform
        self.split = split
        self.target_transform = target_transform
        self.image_roots = self.task.train_roots if self.split == 'train' else self.task.test_roots
        self.image_labels = self.task.train_labels if self.split == 'train' else self.task.test_labels

    def __len__(self):
        '''Number of images in the selected split.'''

        return len(self.image_roots)

    def __getitem__(self,idx):
        '''Load the idx-th image as RGB and return (image , label).'''

        image_root = self.image_roots[idx]
        # Force three channels: greyscale, palette or RGBA files would otherwise
        # produce tensors that do not match the 3-channel normalisation and the
        # first convolution. convert() returns a loaded copy, so the file handle
        # can be closed as soon as the block exits.
        with Image.open(image_root) as handle:
            img = handle.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = self.image_labels[idx]
        if self.target_transform is not None:
            label = self.target_transform(label)

        return img , label

    


In [0]:
class FewShotBatchSampler(Sampler):

    ''' Custom sampler required to construct a few shot episode.

        The dataset is assumed to be laid out class by class, each class owning
        `num_instances` consecutive indices; index i of class j is
        i + j*num_instances.

        Arguments:-
        num_per_class : The number of samples to draw per class.
        num_instances : Total instances stored per class for that task.
        num_classes : The number of classes from which the examples are taken.
        shuffle : Whether to pick random instances and shuffle the resulting
                  index list (True) or take the first ones in order (False).'''

    def __init__(self,num_per_class , num_instances , num_classes , shuffle = True):

        self.num_per_class = num_per_class
        self.num_instances = num_instances
        self.num_classes = num_classes
        self.shuffle = shuffle

    def __iter__(self):
        '''Yield the dataset indices of one episode as plain Python ints.'''

        batch = []
        for j in range(self.num_classes):
            if self.shuffle:
                # .tolist() turns the 0-dim tensors into ints so the indices are
                # ordinary integers for the dataset.
                picks = torch.randperm(self.num_instances)[:self.num_per_class].tolist()
            else:
                picks = list(range(self.num_instances)[:self.num_per_class])
            batch.extend(i + j*self.num_instances for i in picks)

        if self.shuffle:
            random.shuffle(batch)
        return iter(batch)

    def __len__(self):
        '''Number of indices produced by one pass of __iter__.'''

        # The slice in __iter__ caps the picks at num_instances per class.
        return min(self.num_per_class, self.num_instances) * self.num_classes

In [0]:
def get_few_shot_dataloaders(task,num_per_class = 1 , shuffle = False , split = 'train'):

    '''Build the dataloader used to iterate over one split of a few shot task.

    Arguments:-
    task : The Task_PaintingsData object describing the episode.
    num_per_class : Number of images to draw per class.
    shuffle : Forwarded to FewShotBatchSampler.
    split : 'train' reads the task's train set; anything else reads its test set.

    Images are resized to 128x128 (nearest-neighbour), converted to tensors and
    normalised. The loader's batch size is num_per_class * task.num_classes.'''

    preprocessing = transforms.Compose([
        transforms.Resize((128,128) , interpolation = Image.NEAREST),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.92206, 0.92206, 0.92206], std=[0.08426, 0.08426, 0.08426]),
    ])

    dataset = FewShotData(task , split, preprocessing)

    # The sampler needs to know how many images each class holds in this split.
    instances_per_class = task.train_num if split == 'train' else task.test_num

    episode_sampler = FewShotBatchSampler(num_per_class,
                                          num_classes = task.num_classes,
                                          num_instances = instances_per_class,
                                          shuffle = shuffle)

    return torch.utils.data.DataLoader(dataset,
                                       batch_size = num_per_class*task.num_classes,
                                       sampler = episode_sampler)


    

In [0]:
#Defining the models

class EmbeddingModule(nn.Module):
    ''' Embedding module: four 3x3 conv -> batch-norm -> ReLU stages with 64
        output channels each. The first two stages are unpadded and followed by
        2x2 max pooling; the last two use padding 1 and no pooling.'''

    def __init__(self):

        super().__init__()

        # Attribute names and creation order are kept so state_dict keys match.
        self.layer1 = self._conv_stage(3, 64, padding = 0, pool = True)
        self.layer2 = self._conv_stage(64, 64, padding = 0, pool = True)
        self.layer3 = self._conv_stage(64, 64, padding = 1, pool = False)
        self.layer4 = self._conv_stage(64, 64, padding = 1, pool = False)

    @staticmethod
    def _conv_stage(in_channels, out_channels, padding, pool):
        '''Build one conv/batch-norm/ReLU stage, optionally ending in MaxPool2d(2).'''

        modules = [nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = padding),
                   nn.BatchNorm2d(out_channels, momentum = 1, affine = True),
                   nn.ReLU()]
        if pool:
            modules.append(nn.MaxPool2d(2))
        return nn.Sequential(*modules)

    def forward(self,x):
        '''Run x through the four stages in order and return the feature map.'''

        features = x
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        return features

        

In [0]:
class RelationNet(nn.Module):
    ''' Relation module: scores a concatenated (sample , query) feature pair.

        Arguments:-
        input_size : Channel count of ONE embedding feature map; the network
                     expects input with 2*input_size channels (the pair).
        hidden_size : Width of the hidden fully connected layer.

        The first linear layer is sized for a 6x6 map after the two conv/pool
        stages, which is what a 30x30 input produces (30 -> 28 -> 14 -> 12 -> 6).
        forward returns a (batch , 1) tensor of relation scores in [0 , 1].'''

    def __init__(self , input_size , hidden_size):

        super(RelationNet , self).__init__()

        # Channel counts are derived from input_size (previously hard-coded to
        # 128/64, which only matched fc1 when input_size == 64).
        self.layer1 = nn.Sequential(nn.Conv2d(input_size*2,input_size,kernel_size = 3 , padding = 0),
                                    nn.BatchNorm2d(input_size),
                                    nn.ReLU(),
                                    nn.MaxPool2d(2))

        self.layer2 = nn.Sequential(nn.Conv2d(input_size,input_size,kernel_size = 3 , padding = 0),
                                    nn.BatchNorm2d(input_size),
                                    nn.ReLU(),
                                    nn.MaxPool2d(2))


        self.fc1 = nn.Linear(input_size*6*6 , hidden_size)
        self.fc2 = nn.Linear(hidden_size , 1)

    def forward(self , x):

        out = self.layer1(x)
        out = self.layer2(out)
        # Flatten everything but the batch dimension for the linear layers.
        out = out.view(out.size(0),-1)
        out = F.relu(self.fc1(out))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        out = torch.sigmoid(self.fc2(out))
        return out

In [0]:
def weights_init(m):
    '''Weights initializer for the networks; meant to be passed to Module.apply.

    Dispatches on the module's class name:
    * contains 'Conv'      : weights ~ N(0 , sqrt(2/n)) with
                             n = kernel_h * kernel_w * out_channels, bias zeroed.
    * contains 'BatchNorm' : weight filled with 1, bias zeroed.
    * contains 'Linear'    : weights ~ N(0 , 0.01), bias filled with 1.
    Modules without a weight/bias parameter (affine=False, bias=False) are
    left untouched for the missing parameter. Other modules are ignored.'''

    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm') != -1:
        # weight/bias are None when the layer was built with affine=False.
        if m.weight is not None:
            m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        m.weight.data.normal_(0, 0.01)
        if m.bias is not None:
            # Fill in place: assigning a fresh torch.ones(...) tensor would reset
            # the bias to a CPU float32 tensor regardless of the module's
            # device and dtype.
            m.bias.data.fill_(1)

def mean_confidence_interval(data, confidence=0.95):
    '''Return the mean of `data` and the half-width of its confidence interval.

    Arguments:-
    data : Sequence of numbers.
    confidence : Two-sided confidence level (default 0.95).

    Returns:-
    (m , h) : Sample mean and half-width, so the interval is [m - h , m + h].
              h is the standard error of the mean times the Student-t quantile
              at (1 + confidence)/2 with len(data) - 1 degrees of freedom.'''
    a = 1.0*np.array(data)
    n = len(a)
    m, se = np.mean(a), scipy.stats.sem(a)
    # Use the public ppf API; the underscore-prefixed _ppf is a private hook
    # that skips argument validation.
    h = se * scipy.stats.t.ppf((1+confidence)/2., n-1)
    return m,h

In [0]:
def train_networks():
    '''Train the embedding and relation networks on 5-way 1-shot episodes.

    Each episode samples a task from the training folders, scores every
    (query , sample) pair with the relation network and regresses the scores
    onto one-hot labels with an MSE loss. Every 50 episodes the model is scored
    on TEST_EPISODE tasks drawn from the test folders; when the mean accuracy
    improves, both state dicts are saved to the working directory.

    Runs on the GPU when one is available, otherwise on the CPU.

    Returns:-
    (train_loss , accuracies) : The loss of every training episode and the
    per-task accuracies of the most recent evaluation round (an empty list if
    no evaluation round ran).'''

    FEATURE_DIM = 64
    RELATION_DIM = 8
    CLASS_NUM = 5
    SAMPLE_NUM_PER_CLASS = 1
    BATCH_NUM_PER_CLASS = 15
    EPISODE = 5000
    TEST_EPISODE = 200
    LEARNING_RATE = 0.001
    TEST_POOL_PER_CLASS = 10   # query images held per class in an evaluation task
    TEST_NUM_PER_CLASS = 3     # query images per class actually scored

    train_folders , test_folders = get_data_folders()
    feature_enc = EmbeddingModule()
    relation_model = RelationNet(FEATURE_DIM , RELATION_DIM)

    feature_enc.apply(weights_init)
    relation_model.apply(weights_init)

    # is_available must be CALLED; the bare attribute is always truthy.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    feature_enc.to(device)
    relation_model.to(device)

    feature_enc_optim = torch.optim.Adam(feature_enc.parameters() , lr = LEARNING_RATE)
    feature_enc_scheduler = StepLR(feature_enc_optim , step_size = 200 , gamma = 0.5)
    relation_model_optim = torch.optim.Adam(relation_model.parameters() , lr = LEARNING_RATE)
    relation_model_scheduler = StepLR(relation_model_optim , step_size = 200 , gamma = 0.5)

    mse = nn.MSELoss()

    def relation_scores(sample_images , query_images):
        '''Score every (query , sample) pair; returns (num_queries , num_samples).'''

        sample_features = feature_enc(sample_images.to(device))
        query_features = feature_enc(query_images.to(device))
        num_samples , num_queries = sample_features.size(0) , query_features.size(0)

        # Tile both sets so position [q , s] pairs query q with sample s.
        sample_features_ext = sample_features.unsqueeze(0).repeat(num_queries,1,1,1,1)
        query_features_ext = query_features.unsqueeze(0).repeat(num_samples,1,1,1,1)
        query_features_ext = torch.transpose(query_features_ext,0,1)

        # Read the spatial size from the features instead of hard-coding 30x30.
        height , width = sample_features.shape[-2:]
        relation_pairs = torch.cat((sample_features_ext , query_features_ext) , 2).view(-1,FEATURE_DIM*2,height,width)
        return relation_model(relation_pairs).view(-1,num_samples)

    print("Training...")
    last_acc = 0.0
    train_loss = []
    accuracies = []   # defined up front so the final return never hits an unbound name
    for ep in range(EPISODE):

        task = Task_PaintingsData(train_folders , CLASS_NUM, SAMPLE_NUM_PER_CLASS , BATCH_NUM_PER_CLASS)

        sample_dataloader = get_few_shot_dataloaders(task , num_per_class = SAMPLE_NUM_PER_CLASS , shuffle = False, split = 'train')
        batch_dataloader = get_few_shot_dataloaders(task , num_per_class = BATCH_NUM_PER_CLASS , shuffle = True , split = 'test')

        # Each loader yields exactly one batch: the whole sample set / query set.
        samples , sample_labels = next(iter(sample_dataloader))
        batches , batch_labels = next(iter(batch_dataloader))

        relations = relation_scores(samples , batches)

        one_hot_labels = torch.zeros(BATCH_NUM_PER_CLASS*CLASS_NUM, CLASS_NUM).scatter_(1, batch_labels.view(-1,1).long(), 1).to(device)
        loss = mse(relations, one_hot_labels)

        feature_enc.zero_grad()
        relation_model.zero_grad()

        loss.backward()
        train_loss.append(loss.item())

        torch.nn.utils.clip_grad_norm_(feature_enc.parameters(),0.5)
        torch.nn.utils.clip_grad_norm_(relation_model.parameters(),0.5)

        feature_enc_optim.step()
        relation_model_optim.step()

        # Step the schedulers after the optimizers and without the deprecated
        # epoch argument; episode `ep` still trains with
        # LEARNING_RATE * 0.5 ** (ep // 200).
        feature_enc_scheduler.step()
        relation_model_scheduler.step()

        if (ep + 1)%10 == 0:

            print(f'EPISODE :{ep + 1} || LOSS : {loss.item()}')

        if (ep+1)%50 == 0:

            print('Testing')
            accuracies = []
            # NOTE(review): the networks are left in training mode here, exactly
            # as before, so BatchNorm normalises with per-batch statistics;
            # switching to eval() would change the reported numbers.
            # no_grad only avoids building autograd graphs during scoring.
            with torch.no_grad():
                for _ in range(TEST_EPISODE):

                    total_rewards = 0
                    counter = 0
                    task = Task_PaintingsData(test_folders,CLASS_NUM,1,TEST_POOL_PER_CLASS)
                    sample_dataloader = get_few_shot_dataloaders(task,num_per_class=1,split="train",shuffle=False)
                    test_dataloader = get_few_shot_dataloaders(task,num_per_class=TEST_NUM_PER_CLASS,split="test",shuffle=True)
                    sample_images,sample_labels = next(iter(sample_dataloader))
                    for test_images,test_labels in test_dataloader:
                        batch_size = test_labels.shape[0]

                        relations = relation_scores(sample_images , test_images)

                        _,predict_labels = torch.max(relations.data,1)

                        # Compare on the CPU, where the ground-truth labels live.
                        total_rewards += (predict_labels.cpu() == test_labels).sum().item()
                        counter += batch_size
                    accuracy = total_rewards/1.0/counter
                    accuracies.append(accuracy)

            test_accuracy,_= mean_confidence_interval(accuracies)

            print("Test Accuracy:",test_accuracy)

            if test_accuracy > last_acc:

                # save networks
                torch.save(feature_enc.state_dict(),str("paintings_feature_encoder_" + str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"))
                torch.save(relation_model.state_dict(),str("paintings_relation_network_"+ str(CLASS_NUM) +"way_" + str(SAMPLE_NUM_PER_CLASS) +"shot.pkl"))

                print("save networks for episode:",ep)

                last_acc = test_accuracy

    return train_loss,accuracies


In [0]:
# Release any cached, unused GPU memory before training starts.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
# Launch training. `loss` receives the per-episode training losses and
# `accuracy` the per-task accuracies of the last evaluation round.
loss,accuracy = train_networks()

Training...




EPISODE :10 || LOSS : 0.24950967729091644
EPISODE :20 || LOSS : 0.16387146711349487
EPISODE :30 || LOSS : 0.15532521903514862
EPISODE :40 || LOSS : 0.15910230576992035
EPISODE :50 || LOSS : 0.14458002150058746
Testing
Test Accuracy: 0.34166666666666673
save networks for episode: 49
EPISODE :60 || LOSS : 0.13951340317726135
EPISODE :70 || LOSS : 0.1550045609474182
EPISODE :80 || LOSS : 0.15388785302639008
EPISODE :90 || LOSS : 0.13562358915805817
EPISODE :100 || LOSS : 0.1595572978258133
Testing
Test Accuracy: 0.336
EPISODE :110 || LOSS : 0.16376039385795593
EPISODE :120 || LOSS : 0.15094415843486786
EPISODE :130 || LOSS : 0.1493663489818573
EPISODE :140 || LOSS : 0.14814560115337372
EPISODE :150 || LOSS : 0.14552101492881775
Testing
Test Accuracy: 0.3516666666666666
save networks for episode: 149
EPISODE :160 || LOSS : 0.14619728922843933
EPISODE :170 || LOSS : 0.1598415970802307
EPISODE :180 || LOSS : 0.1365499496459961
EPISODE :190 || LOSS : 0.16595694422721863
EPISODE :200 || LOSS :