In [1]:
import os
import pickle
import numpy as np
import random
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image

import torch
import torchnet as tnt

#modify the data root
_MINI_IMAGENET_DATASET_DIR = '../datasets/MiniImagenet'

In [None]:
def load_data(file):
    """Read one pickled object from the file at path ``file`` and return it.

    The pickle is decoded with ``encoding='iso-8859-1'``.
    NOTE(review): ``pickle.load`` can execute arbitrary code, so only point
    this at dataset files you trust.
    """
    with open(file,'rb') as stream:
        return pickle.load(stream,encoding='iso-8859-1')

def buildLabelIndex(labels):
    """Group positions of ``labels`` by label value.

    Returns a dict mapping each distinct label to the list of indices
    (in ascending order) at which it occurs in ``labels``.
    """
    positions_by_label={}
    for position,label in enumerate(labels):
        positions_by_label.setdefault(label,[]).append(position)
    return positions_by_label

In [None]:
class MiniImageNet(data.Dataset):
    """miniImageNet dataset backed by the pickled category-split files.

    Args:
        phase: one of 'train', 'val' or 'test'.
            * 'train' loads only the training-phase images of the training
              (base) categories.
            * 'val' / 'test' load the held-out images of the base categories
              plus the images of the validation / test (novel) categories.
        do_not_use_random_transf: when true, the random crop / flip
            augmentation is disabled for the 'train' phase as well.

    Attributes set by the constructor:
        data / labels: image array and the per-image label list.
        label2ind: label -> list of image indices.
        labelIds, num_cats: sorted list of all labels and its length.
        labelIds_base, num_cats_base: labels of the base categories.
        labelIds_novel, num_cats_novel: labels of the novel categories
            (only defined for the 'val' and 'test' phases).
        transform: preprocessing applied in ``__getitem__``.
    """
    def __init__(self,phase='train',do_not_use_random_transf=False):
        self.base_folder='miniImagenet'
        assert(phase=='train' or phase=='val' or phase=='test')
        self.phase=phase
        self.name='MiniImageNet_'+phase

        print('Loading mini ImageNet dataser - phase {0}'.format(phase))

        def split_file(file_name):
            # All split files live directly under the dataset root.
            return os.path.join(_MINI_IMAGENET_DATASET_DIR,file_name)

        file_train_categories_train_phase=split_file(
            'miniImageNet_category_split_train_phase_train.pickle')
        file_train_categories_val_phase=split_file(
            'miniImageNet_category_split_train_phase_val.pickle')
        file_train_categories_test_phase=split_file(
            'miniImageNet_category_split_train_phase_test.pickle')
        file_val_categories_val_phase=split_file(
            'miniImageNet_category_split_val.pickle')
        file_test_categories_test_phase=split_file(
            'miniImageNet_category_split_test.pickle')

        # NOTE(review): load_data uses pickle; only load trusted dataset files.
        if self.phase=='train':
            # During training we only load the training-phase images of the
            # training categories.
            data_train=load_data(file_train_categories_train_phase)
            self.data=data_train['data'] #array (n,84,84,3)
            self.labels=data_train['labels'] #list[n]

            self.label2ind=buildLabelIndex(self.labels)
            self.labelIds=sorted(self.label2ind.keys())
            self.num_cats=len(self.labelIds)
            self.labelIds_base=self.labelIds
            self.num_cats_base=len(self.labelIds_base)
        elif self.phase=='val' or self.phase=='test':
            if self.phase=='test':
                data_base=load_data(file_train_categories_test_phase)
                data_novel=load_data(file_test_categories_test_phase)
            else:
                data_base=load_data(file_train_categories_val_phase)
                data_novel=load_data(file_val_categories_val_phase)

            # Base images come first, novel images second; labels follow the
            # same order so that indices stay aligned.
            self.data=np.concatenate(
                [data_base['data'],data_novel['data']],axis=0)
            self.labels=data_base['labels']+data_novel['labels']

            self.label2ind=buildLabelIndex(self.labels)
            self.labelIds=sorted(self.label2ind.keys())
            self.num_cats=len(self.labelIds)

            # Materialise as lists: a dict_keys view is not a sequence and is
            # rejected by random.sample on Python >= 3.11.
            self.labelIds_base=list(buildLabelIndex(data_base['labels']))
            self.labelIds_novel=list(buildLabelIndex(data_novel['labels']))
            self.num_cats_base=len(self.labelIds_base)
            self.num_cats_novel=len(self.labelIds_novel)
            # Base and novel categories must be disjoint.
            intersection=set(self.labelIds_base) & set(self.labelIds_novel)
            assert(len(intersection)==0)
        else:
            # Only reachable when assertions are disabled (python -O).
            raise ValueError('Not valid phase {0}'.format(self.phase))

        # Per-channel pixel statistics, rescaled to the [0,1] range that
        # ToTensor produces.
        mean_pix = [x/255.0 for x in [120.39586422,  115.59361427, 104.54012653]]
        std_pix = [x/255.0 for x in [70.68188272,  68.27635443,  72.54505529]]
        normalize = transforms.Normalize(mean=mean_pix, std=std_pix)

        if self.phase in ('test','val') or do_not_use_random_transf:
            # Deterministic preprocessing only.
            self.transform=transforms.Compose([
                lambda x:np.array(x),
                transforms.ToTensor(),
                normalize
            ])
        else:
            # Training-time augmentation: padded random crop + horizontal flip.
            self.transform=transforms.Compose([
                transforms.RandomCrop(84,padding=8),
                transforms.RandomHorizontalFlip(),
                lambda x : np.array(x),
                transforms.ToTensor(),
                normalize
            ])

    def __getitem__(self,index):
        """Return ``(transformed_image, label)`` for the image at ``index``."""
        img,label=self.data[index],self.labels[index]
        # Convert to a PIL image so the transform pipeline matches the one
        # used by other datasets.
        img=Image.fromarray(img)
        if self.transform is not None:
            img=self.transform(img)
        return img,label

    def __len__(self):
        """Number of images in the loaded split."""
        return len(self.data)

In [None]:
class FewShotDataloader():
    """Samples few-shot episodes from a dataset and serves them in batches.

    The dataset must expose ``phase``, ``label2ind``, ``labelIds_base``,
    ``num_cats_base`` and, for the 'val'/'test' phases, ``labelIds_novel``
    and ``num_cats_novel``.

    Args:
        dataset: the dataset to sample from.
        nKnovel: number of novel categories per episode.
        nKbase: number of base categories per episode; a negative value
            means "all base categories". In the 'train' phase the novel
            categories are drawn from the base pool, so ``nKnovel`` is
            subtracted from a positive ``nKbase``.
        nExemplars: number of training examples per novel category.
        nTestNovel: number of test examples over all novel categories.
        nTestBase: number of test examples over all base categories.
        batch_size: number of episodes per batch.
        num_workers: loader worker count (forced to 0 in eval mode).
        epoch_size: number of episodes per epoch.
    """
    def __init__(self,dataset,
                nKnovel=5,#number of novel categories
                nKbase=-1,#number of base categories
                nExemplars=1,#number of training examples per novel category
                nTestNovel=15*5,#number of test examples for all novel categories
                nTestBase=15*5,#number of test examples for all base categories
                batch_size=1,#number of training episodes per batch
                num_workers=4,
                epoch_size=2000):
        self.dataset=dataset
        self.phase=self.dataset.phase
        # During training the "novel" categories are simulated from the base pool.
        max_possible_nKnovel=(self.dataset.num_cats_base if self.phase=='train'
                             else self.dataset.num_cats_novel)
        assert(nKnovel>=0 and nKnovel<max_possible_nKnovel)
        self.nKnovel=nKnovel

        max_possible_nKbase=self.dataset.num_cats_base
        nKbase=nKbase if nKbase>=0 else max_possible_nKbase

        if self.phase=='train' and nKbase>0:
            # Reserve nKnovel of the base categories to play the novel role.
            nKbase-=self.nKnovel
            max_possible_nKbase-=self.nKnovel

        assert(nKbase>=0 and nKbase <=max_possible_nKbase)
        self.nKbase=nKbase

        self.nExemplars=nExemplars
        self.nTestNovel=nTestNovel
        self.nTestBase=nTestBase
        self.batch_size=batch_size
        self.epoch_size=epoch_size
        self.num_workers=num_workers
        self.is_eval_mode=(self.phase=='test') or (self.phase=='val')

    def sampleImageIdsFrom(self, cat_id , sample_size=1):
        """
        Samples 'sample_size' unique image ids from the category 'cat_id'.
        """
        assert(cat_id in self.dataset.label2ind)
        assert(len(self.dataset.label2ind[cat_id])>=sample_size)
        # random.sample draws without replacement.
        return random.sample(self.dataset.label2ind[cat_id],sample_size)

    def sampleCategories(self,cat_set,sample_size=1):
        """
        Samples 'sample_size' unique categories from the 'cat_set' set of
        categories. 'cat_set' can be either 'base' or 'novel'.

        Raises:
            ValueError: if 'cat_set' is neither 'base' nor 'novel'.
        """
        if cat_set=='base':
            labelIds=self.dataset.labelIds_base
        elif cat_set=='novel':
            labelIds=self.dataset.labelIds_novel
        else:
            raise ValueError('Not recognize category set {}'.format(cat_set))

        assert(len(labelIds)>=sample_size)

        # random.sample requires a sequence on Python >= 3.11; the dataset may
        # hand us a dict_keys view, so materialise it first.
        return random.sample(list(labelIds),sample_size)

    def sample_base_and_novel_categories(self,nKbase , nKnovel):
        """
        Samples 'nKbase' base categories and 'nKnovel' novel categories.

        In eval mode the two groups come from the dataset's separate base and
        novel pools; in training mode both are drawn (disjointly) from the
        base pool. Returns (Kbase, Knovel) as sorted lists of category ids.
        """
        if self.is_eval_mode:
            assert(nKnovel<=self.dataset.num_cats_novel)

            Kbase = sorted(self.sampleCategories('base',nKbase))
            Knovel = sorted(self.sampleCategories('novel',nKnovel))
        else:
            cats_ids = self.sampleCategories('base',nKbase+nKnovel)
            assert(len(cats_ids)==(nKbase+nKnovel))

            random.shuffle(cats_ids)
            Knovel=sorted(cats_ids[:nKnovel])
            Kbase=sorted(cats_ids[nKnovel:])
        return Kbase,Knovel

    def sample_test_examples_for_base_categories(self,Kbase,nTestBase):
        """
        Samples 'nTestBase' images from the 'Kbase' categories.

        Returns a list of (img_id, category_label) tuples where the label is
        the position of the category inside 'Kbase'.
        """
        Tbase=[]
        if len(Kbase)>0:
            # Decide how many images each base category contributes.
            KbaseIndices=np.random.choice(np.arange(len(Kbase)),size=nTestBase,replace=True)
            KbaseIndices,NumImagesPerCategory=np.unique(KbaseIndices,return_counts=True)

            for Kbase_idx,NumImages in zip(KbaseIndices,NumImagesPerCategory):
                # Convert numpy integers to plain ints before they reach
                # random.sample and the label lists.
                Kbase_idx,NumImages=int(Kbase_idx),int(NumImages)
                imd_ids=self.sampleImageIdsFrom(Kbase[Kbase_idx],sample_size=NumImages)
                Tbase+=[(img_id,Kbase_idx) for img_id in imd_ids]
        assert(len(Tbase)==nTestBase)

        return Tbase

    def sample_train_and_test_examples_for_novel_categories(
        self,Knovel,nTestNovel,nExemplars,nKbase):
        """
        Samples train and test examples of the novel categories.

        Args:
            Knovel: list with the ids of the novel categories
            nTestNovel: total number of test images sampled over all novel
                categories (must be divisible by len(Knovel))
            nExemplars: number of training examples sampled per novel category
            nKbase: number of base categories; used as offset of the category
                label of each sampled image

        Returns:
            Tnovel: list of length 'nTestNovel' of (img_id, category_label)
            Exemplars: list of length len(Knovel)*nExemplars of
                (img_id, category_label in [nKbase, nKbase+len(Knovel)-1])
        """
        if len(Knovel)==0:
            return [],[]

        nKnovel=len(Knovel)
        Tnovel=[]
        Exemplars=[]
        assert((nTestNovel % nKnovel)==0)
        nEvalExamplesPerClass = nTestNovel//nKnovel

        for Knovel_idx,cat_id in enumerate(Knovel):
            # Draw test and exemplar images together so they never overlap.
            imd_ids=self.sampleImageIdsFrom(
                cat_id,sample_size=(nEvalExamplesPerClass+nExemplars))
            imds_tnovel=imd_ids[:nEvalExamplesPerClass]
            imds_exemplars=imd_ids[nEvalExamplesPerClass:]

            Tnovel+=[(img_id , nKbase+Knovel_idx) for img_id in imds_tnovel]
            Exemplars+=[(img_id,nKbase+Knovel_idx) for img_id in imds_exemplars]

        assert(len(Tnovel)==nTestNovel)
        assert(len(Exemplars)==len(Knovel)*nExemplars)
        random.shuffle(Exemplars)

        return Tnovel,Exemplars

    def sample_episode(self):
        """
        Samples one episode.

        Returns (Exemplars, Test, Kall, nKbase): the novel training examples,
        the shuffled base+novel test examples, the category ids (base first,
        then novel) and the number of base categories.
        """
        nKnovel=self.nKnovel
        nKbase=self.nKbase
        nTestNovel=self.nTestNovel
        nTestBase=self.nTestBase
        nExemplars=self.nExemplars

        Kbase,Knovel = self.sample_base_and_novel_categories(nKbase,nKnovel)
        Tbase=self.sample_test_examples_for_base_categories(Kbase,nTestBase)
        Tnovel,Exemplars=self.sample_train_and_test_examples_for_novel_categories(Knovel,nTestNovel,nExemplars,nKbase)

        # Concatenate the base and novel category examples.
        Test=Tbase+Tnovel
        random.shuffle(Test)
        Kall=Kbase+Knovel

        return Exemplars , Test , Kall , nKbase

    def createExamplesTensorData(self,examples):
        """
        Builds the stacked image tensor and the LongTensor of labels for a
        list of (img_id, label) tuples.
        """
        images=torch.stack(
            [self.dataset[img_idx][0] for img_idx ,_ in examples],dim=0)
        labels=torch.LongTensor([label for _,label in examples])
        return images,labels

    def get_iterator(self,epoch=0):
        """
        Returns a loader over 'epoch_size' freshly sampled episodes.

        Python's and numpy's global RNGs are seeded with 'epoch'. Each item
        is (Xe, Ye, Xt, Yt, Kall, nKbase) when the episode has exemplars and
        (Xt, Yt, Kall, nKbase) otherwise.
        """
        rand_seed=epoch
        random.seed(rand_seed)
        np.random.seed(rand_seed)

        def load_function(iter_idx):
            # iter_idx is ignored: every call samples a new episode.
            Exemplars,Test,Kall,nKbase = self.sample_episode()
            Xt,Yt=self.createExamplesTensorData(Test)
            Kall=torch.LongTensor(Kall)
            if len(Exemplars)>0:
                Xe,Ye=self.createExamplesTensorData(Exemplars)
                return Xe,Ye,Xt,Yt,Kall,nKbase
            else:
                return Xt,Yt,Kall,nKbase

        tnt_dataset=tnt.dataset.ListDataset(
            elem_list=range(self.epoch_size),load=load_function)
        data_loader=tnt_dataset.parallel(
            batch_size=self.batch_size,
            num_workers=(0 if self.is_eval_mode else self.num_workers),
            shuffle=(not self.is_eval_mode))

        return data_loader

    def __call__(self,epoch=0):
        """Shorthand for ``get_iterator(epoch)``."""
        return self.get_iterator(epoch)

    def __len__(self):
        """Number of whole batches per epoch."""
        return self.epoch_size//self.batch_size

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

In [None]:
class BasicModule(nn.Module):
    """nn.Module with small helpers for saving / loading its state dict."""
    def __init__(self):
        super(BasicModule,self).__init__()
        self.model_name=str(type(self))

    def load(self,path,map_location=None):
        """Load the state dict stored at ``path`` into this module.

        Args:
            path: checkpoint file written by ``save``.
            map_location: forwarded to ``torch.load``; pass e.g. ``'cpu'`` to
                restore a checkpoint saved on a GPU on a machine without one.
                The default (None) keeps torch.load's default behaviour.
        """
        self.load_state_dict(torch.load(path,map_location=map_location))

    def save(self,path=None):
        """Write this module's state dict to ``path`` and return ``path``.

        Raises:
            ValueError: if ``path`` is None.
        """
        if path is None:
            raise ValueError('Please specify the saving road!!!')
        torch.save(self.state_dict(),path)
        return path

In [None]:
def conv_block(in_channels,out_channels,use_relu=True):
    """Build a 3x3 conv (padding 1) -> batch norm -> [ReLU] -> 2x2 max-pool stage.

    The ReLU is included only when ``use_relu`` is true. Returns an
    ``nn.Sequential`` with the layers in that order.
    """
    layers=[
        nn.Conv2d(in_channels,out_channels,3,padding=1),
        nn.BatchNorm2d(out_channels),
    ]
    if use_relu:
        layers.append(nn.ReLU())
    layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)
    
class AvgBlock(BasicModule):
    """Parameter-free weight generator: per-class average of the features.

    ``nFeat`` is accepted for signature parity with the other generators but
    is not used.
    """
    def __init__(self,nFeat):
        super(AvgBlock,self).__init__()

    def forward(self,features_train , labels_train):
        """Average the training features of each class.

        ``labels_train`` is transposed on dims (1,2) and batch-multiplied with
        ``features_train``, then divided by the per-class label sums.
        # assumes features_train is [batch, num_examples, nFeat] and
        # labels_train is a one-hot [batch, num_examples, nKnovel] tensor —
        # TODO confirm against callers.
        """
        class_membership=labels_train.transpose(1,2)
        feature_sums=torch.bmm(class_membership,features_train)
        # Number of examples that voted for each class (kept as a column so
        # it can be expanded across the feature dimension).
        examples_per_class=class_membership.sum(dim=2,keepdim=True)
        return feature_sums.div(examples_per_class.expand_as(feature_sums))

In [None]:
class ConvNet(BasicModule):
    """Four-stage convolutional feature extractor built from ``conv_block``.

    Channel widths are 3 -> 64 -> 64 -> 128 -> 128; the last stage has no ReLU.
    """
    def __init__(self):
        super(ConvNet,self).__init__()
        # (in_channels, out_channels, use_relu) for each stage, in order.
        stage_specs=[
            (3,64,True),
            (64,64,True),
            (64,128,True),
            (128,128,False),
        ]
        self.encoder=nn.Sequential(
            *[conv_block(c_in,c_out,use_relu=relu) for c_in,c_out,relu in stage_specs]
        )

    def forward(self,x):
        """Encode ``x`` and flatten everything after the batch dimension."""
        encoded=self.encoder(x)
        return encoded.view(encoded.size(0),-1)
    

In [None]:
class AttentionBlock(BasicModule):
    """Attention-based weight generator over the base-class weights.

    Each training example is turned into a query, compared (cosine
    similarity, scaled by a learnable ``scale_att``) against learnable keys
    of the selected base categories, and the resulting attention
    coefficients mix the base classification weights. The per-example
    results are then averaged per novel class.

    Args:
        nFeat: feature dimensionality.
        nKall: number of key vectors to allocate (one per category id).
        scale_att: initial value of the learnable attention scale.
    """
    def __init__(self,nFeat,nKall,scale_att=10.0):
        super(AttentionBlock,self).__init__()

        self.nFeat=nFeat
        self.queryLayer=nn.Linear(nFeat,nFeat)
        # Start the query transform close to the identity.
        self.queryLayer.weight.data.copy_(
            torch.eye(nFeat,nFeat)+torch.randn(nFeat,nFeat)*0.001)
        self.queryLayer.bias.data.zero_()

        self.scale_att=nn.Parameter(torch.FloatTensor(1).fill_(scale_att),requires_grad=True)
        wkeys=torch.FloatTensor(nKall,nFeat).normal_(0.0,np.sqrt(2.0/nFeat))
        self.wkeys=nn.Parameter(wkeys,requires_grad=True)

    def forward(self,features_train,labels_train,weight_base,Kbase):
        """Generate novel-class weights.

        Args:
            features_train: [batch_size, num_train_examples, num_features]
            labels_train: [batch_size, num_train_examples, nKnovel] one-hot
            weight_base: [batch_size, nKbase, num_features]
            Kbase: [batch_size, nKbase] indices of the base categories in use

        Returns:
            weight_novel: [batch_size, nKnovel, num_features]
        """
        batch_size,num_train_examples,num_features=features_train.size()
        nKbase=weight_base.size(1) #[batch_size,nKbase,num_features]
        labels_train_transposed=labels_train.transpose(1,2) #[batch_size,nKnovel,num_train_examples]

        # Queries: one L2-normalised vector per training example.
        features_train=features_train.reshape(batch_size*num_train_examples,num_features)
        Qe=self.queryLayer(features_train)
        Qe=Qe.view(batch_size,num_train_examples,self.nFeat)
        Qe=F.normalize(Qe,p=2,dim=Qe.dim()-1,eps=1e-12)

        # Keys of the selected base categories. reshape (not view) because
        # Kbase may be a non-contiguous slice of a larger index tensor.
        wkeys=self.wkeys[Kbase.reshape(-1)]
        wkeys=F.normalize(wkeys,p=2,dim=wkeys.dim()-1,eps=1e-12)
        # [batch_size,nKbase,nFeat] -> [batch_size,nFeat,nKbase]
        wkeys=wkeys.view(batch_size,nKbase,self.nFeat).transpose(1,2)

        # Attention coefficients:
        # [batch_size x num_train_examples x nKbase] =
        #   [batch_size x num_train_examples x nFeat] * [batch_size x nFeat x nKbase]
        AttentionCoef=self.scale_att*torch.bmm(Qe,wkeys)
        # Normalise over the base categories; the explicit dim replaces the
        # deprecated implicit-dim softmax call.
        AttentionCoef=F.softmax(AttentionCoef,dim=-1)

        # Mix the base weights per example:
        # [batch_size x num_train_examples x num_features] =
        #   [batch_size x num_train_examples x nKbase] * [batch_size x nKbase x num_features]
        weight_novel=torch.bmm(AttentionCoef,weight_base)
        # Sum the per-example weights of each novel class:
        # [batch_size x nKnovel x num_features] =
        #   [batch_size x nKnovel x num_train_examples] * [batch_size x num_train_examples x num_features]
        weight_novel=torch.bmm(labels_train_transposed,weight_novel)
        # Divide by the number of shots per class to obtain the average.
        weight_novel=weight_novel.div(labels_train_transposed.sum(dim=2,keepdim=True).expand_as(weight_novel))
        return weight_novel

In [None]:
class LinearDiag(BasicModule):
    """Element-wise (diagonal) linear layer: ``out = X * weight (+ bias)``.

    Args:
        num_features: length of the weight (and bias) vector.
        bias: when true a zero-initialised bias vector is learned as well;
            otherwise ``bias`` is registered as None.
    """
    def __init__(self,num_features,bias=False):
        super(LinearDiag,self).__init__()
        # All-ones weight, i.e. the layer starts as the identity transform.
        self.weight=nn.Parameter(
            torch.FloatTensor(num_features).fill_(1),requires_grad=True)

        if not bias:
            self.register_parameter('bias',None)
        else:
            self.bias=nn.Parameter(
                torch.FloatTensor(num_features).fill_(0),requires_grad=True)

    def forward(self,X):
        """Scale (and optionally shift) a 2-D input ``X`` of width ``num_features``."""
        assert(X.dim()==2 and X.size(1)==self.weight.size(0))
        scaled=X*self.weight.expand_as(X)
        if self.bias is None:
            return scaled
        return scaled+self.bias.expand_as(scaled)
        

In [None]:
class Classifier(BasicModule):
    """Cosine-similarity classifier with an optional novel-weight generator.

    Args:
        nKall: number of base-category weight vectors to learn.
        nFeat: feature dimensionality.
        weight_generator_type: 'none' (novel weights are plain feature
            averages, no learnable generator parameters) or
            'attention_based' (feature average plus attention over the base
            weights, each passed through a learnable diagonal layer).

    Raises:
        ValueError: for any other ``weight_generator_type``.
    """
    def __init__(self,nKall=64,nFeat=128*5*5,weight_generator_type='none'):
        super(Classifier,self).__init__()
        self.nKall=nKall
        self.nFeat=nFeat
        self.weight_generator_type=weight_generator_type

        weight_base=torch.FloatTensor(nKall,nFeat).normal_(
        0.0,np.sqrt(2.0/nFeat))
        self.weight_base=nn.Parameter(weight_base,requires_grad=True)
        self.bias=nn.Parameter(torch.FloatTensor(1).fill_(0),requires_grad=True)
        scale_cls=10.0
        self.scale_cls=nn.Parameter(torch.FloatTensor(1).fill_(scale_cls),requires_grad=True)

        if self.weight_generator_type=='none':
            # Feature averaging only: the generator has no learnable
            # parameters and therefore needs no training.
            self.favgblock=AvgBlock(nFeat)
        elif self.weight_generator_type=='attention_based':
            scale_att=10.0
            self.favgblock=AvgBlock(nFeat)
            self.attentionBlock=AttentionBlock(nFeat,nKall,scale_att=scale_att)

            self.wnLayerFavg=LinearDiag(nFeat)
            self.wnLayerWatt=LinearDiag(nFeat)
        else:
            raise ValueError('weight_generator_type is not supported!')

    def get_classification_weights(
        self,Kbase_ids,features_train=None,labels_train=None):
        """
        Get the classification weights of the base and novel categories.

        Args:
            Kbase_ids: [batch_size, nKbase], indices of the base categories in use
            features_train: [batch_size, num_train_examples (way*shot), nFeat]
            labels_train: [batch_size, num_train_examples, nKnovel (way)],
                one-hot labels of features_train

        Returns:
            cls_weights: [batch_size, nK, nFeat]; nK is nKbase when no
            training data is given, nKbase+nKnovel otherwise.

        Raises:
            ValueError: if the configured weight generator type is unknown.
        """
        # Classification weights of the base categories. reshape (not view)
        # so that a non-contiguous slice such as K[:, :nKbase] is accepted.
        batch_size,nKbase=Kbase_ids.size()
        weight_base=self.weight_base[Kbase_ids.reshape(-1)]
        weight_base=weight_base.view(batch_size,nKbase,-1)

        # Without training data for novel categories only the base weights
        # can be returned.
        if features_train is None or labels_train is None:
            return weight_base

        # Classification weights of the novel categories.
        _,num_train_examples , num_channels=features_train.size()
        nKnovel=labels_train.size(2)

        # L2-normalise the features before they are averaged / compared by
        # cosine similarity.
        features_train=F.normalize(features_train,p=2,dim=features_train.dim()-1,eps=1e-12)
        if self.weight_generator_type=='none':
            weight_novel=self.favgblock(features_train,labels_train)
            weight_novel=weight_novel.view(batch_size,nKnovel,num_channels)
        elif self.weight_generator_type=='attention_based':
            weight_novel_avg=self.favgblock(features_train,labels_train)
            weight_novel_avg=self.wnLayerFavg(weight_novel_avg.view(batch_size*nKnovel,num_channels))

            # The attention block works on L2-normalised base weights.
            weight_base_tmp=F.normalize(weight_base,p=2,dim=weight_base.dim()-1,eps=1e-12)

            weight_novel_att=self.attentionBlock(features_train,labels_train,weight_base_tmp,Kbase_ids)
            weight_novel_att=self.wnLayerWatt(weight_novel_att.view(batch_size*nKnovel,num_channels))

            weight_novel=weight_novel_avg+weight_novel_att
            weight_novel=weight_novel.view(batch_size,nKnovel,num_channels)
        else:
            raise ValueError('weight generator type is not supported!')

        # Base weights first, novel weights second.
        weight_both=torch.cat([weight_base,weight_novel],dim=1)#[batch_size ,nKbase+nKnovel , num_channel]

        return weight_both

    def apply_classification_weights(self,features,cls_weights):
        """
        Apply the classification weight vectors to the feature vectors.

        Args:
            features: [batch_size, num_test_examples, num_channels]
            cls_weights: [batch_size, nK, num_channels]
        Returns:
            cls_scores: [batch_size, num_test_examples (query set), nK]
        """
        # Cosine similarity = dot product of L2-normalised vectors.
        features=F.normalize(features,p=2,dim=features.dim()-1,eps=1e-12)
        cls_weights=F.normalize(cls_weights,p=2,dim=cls_weights.dim()-1,eps=1e-12)
        # bias + features @ cls_weights^T (beta and alpha default to 1); the
        # scalar bias is broadcast over the whole score tensor. This replaces
        # the deprecated positional (beta, input, alpha, ...) call form.
        cls_scores=self.scale_cls*torch.baddbmm(
            self.bias.view(1,1,1),features,cls_weights.transpose(1,2))
        return cls_scores

    def forward(self,features_test,Kbase_ids,features_train=None,labels_train=None):
        """
        Score the test examples against both base and novel categories.

        Args:
            features_test: [batch_size, num_test_examples (query set), num_channels]
            Kbase_ids: [batch_size, nKbase], indices of the base categories in use
            features_train: [batch_size, num_train_examples, num_channels]
            labels_train: [batch_size, num_train_examples, nKnovel]

        Returns:
            cls_scores: [batch_size, num_test_examples, nKbase+nKnovel]
            (only nKbase columns when no training data is given)
        """
        cls_weights=self.get_classification_weights(
            Kbase_ids,features_train,labels_train)
        cls_scores=self.apply_classification_weights(features_test,cls_weights)
        return cls_scores

### Training procedure
### Training step 1: train the feature extractor (FE) and pretrain the cosine-similarity classifier

In [None]:
# Step 1: train the feature extractor and pretrain the cosine-based classifier.
# Device selection and seeding. The CUDA-specific calls are only made when a
# GPU is actually available, so the notebook also runs on CPU-only machines.
use_cuda=torch.cuda.is_available()
torch.manual_seed(1234)
if use_cuda:
    torch.cuda.set_device(0)
    torch.cuda.manual_seed(1234)


In [None]:
# Stage-1 (pre-training) hyper-parameters; these names are read by the
# optimizer / training-loop cell below.
epoch=31
lr=0.1
momentum=0.9
weight_decay=5e-4

dataset_train=MiniImageNet(phase='train')
# No validation set is used during pre-training:
# dataset_test=MiniImageNet(phase='val')

# Plain classification episodes over the base categories only: with
# nKnovel=0 and nExemplars=0 no exemplars are sampled, so each batch is the
# 4-tuple (images, labels, category ids, nKbase).
dloader_train=FewShotDataloader(dataset=dataset_train,
                               nKnovel=0,
                               nKbase=64,
                               nExemplars=0,
                               nTestNovel=0,
                               nTestBase=32,
                               batch_size=8,
                               num_workers=1,
                               epoch_size=8*1000)


In [None]:

# Create both output directories independently. Previously the checkpoint
# directory was only created when the trace directory was missing, so the
# final save could fail after all epochs had been trained.
os.makedirs('results/trace_file',exist_ok=True)
os.makedirs('results/pretrain_model',exist_ok=True)

# Start every run with a fresh trace file.
trace_file=os.path.join('results','trace_file','pre_train_trace.txt')
if os.path.isfile(trace_file):
    os.remove(trace_file)

#model
fe_model=ConvNet()
classifier=Classifier()
if use_cuda:
    fe_model.cuda()
    classifier.cuda()

#optimizer: one SGD optimizer + step schedule per sub-network
optimizer_fe=torch.optim.SGD(fe_model.parameters(),lr=lr,nesterov=True , momentum=momentum,weight_decay=weight_decay)
optimizer_classifier=torch.optim.SGD(classifier.parameters(),lr=lr,nesterov=True , momentum=momentum,weight_decay=weight_decay)
lr_schedule_fe=torch.optim.lr_scheduler.StepLR(optimizer=optimizer_fe,gamma=0.5,step_size=25)
lr_schedule_classifier=torch.optim.lr_scheduler.StepLR(optimizer=optimizer_classifier,gamma=0.5,step_size=25)
criterion=torch.nn.CrossEntropyLoss()

print("----pre-train----")
for ep in range(epoch):
    train_loss=[]
    print("----epoch: %2d---- "%ep)
    fe_model.train()
    classifier.train()

    for batch in tqdm(dloader_train(ep)):
        # Without exemplars a batch is (images, labels, category ids, nKbase).
        assert(len(batch)==4)

        optimizer_fe.zero_grad()
        optimizer_classifier.zero_grad()

        train_data=batch[0]
        train_label=batch[1]
        k_id=batch[2]

        if use_cuda:
            train_data=train_data.cuda()
            train_label=train_label.cuda()
            k_id=k_id.cuda()

        # Fold the episode dimension into the batch dimension for the
        # feature extractor, then unfold it again for the classifier.
        batch_size,nTestBase,channels,width,high=train_data.size()
        train_data=train_data.view(batch_size*nTestBase,channels,width,high)
        train_data_embedding=fe_model(train_data)
        pred_result=classifier(train_data_embedding.view(batch_size,nTestBase,-1),k_id)
        loss=criterion(pred_result.view(batch_size*nTestBase,-1),train_label.view(batch_size*nTestBase))
        loss.backward()
        optimizer_fe.step()
        optimizer_classifier.step()
        # .item() extracts the Python float without keeping the graph alive.
        train_loss.append(loss.item())
    lr_schedule_fe.step()
    lr_schedule_classifier.step()

    avg_loss=np.mean(train_loss)
    print("epoch %2d training end : avg_loss = %.4f"%(ep,avg_loss))
    with open(trace_file,'a') as f:
        f.write('epoch:{:2d} training end：avg_loss:{:.4f}'.format(ep,avg_loss))
        f.write('\n')
    # Only the model of the last epoch is checkpointed.
    if ep==epoch-1:
        p1='results/pretrain_model/fe_%s.pth'%(str(ep))
        p2='results/pretrain_model/classifier_%s.pth'%(str(ep))
        m1=fe_model.save(path=p1)
        m2=classifier.save(path=p2)
        print("Epoch %2d model successfully saved!"%(ep))

### Training step 2: continue training the classifier together with the attention-based weight generator

In [None]:
#step 2
# Checkpoints written by the last pre-training epoch (epoch index 30 when
# stage 1 runs for 31 epochs).
path_fe='results/pretrain_model/fe_30.pth'
path_classifier='results/pretrain_model/classifier_30.pth'

#load pretrain model
fe_model=ConvNet()
classifier=Classifier(weight_generator_type='attention_based')

# Both models are still on the CPU at this point, so map the checkpoints to
# the CPU; this also allows GPU-saved checkpoints to load on CPU-only hosts.
pre_train_classifier=torch.load(path_classifier,map_location='cpu')
fe_model.load_state_dict(torch.load(path_fe,map_location='cpu'))

# Copy every parameter the pretrained classifier has; parameters that are
# not in the checkpoint keep their fresh initialisation.
for pname , param in classifier.named_parameters():
    if pname in pre_train_classifier:
        param.data.copy_(pre_train_classifier[pname])
        

In [None]:
#load training data
# Stage-2 hyper-parameters; these overwrite the stage-1 values of the same
# names.
epoch=60
lr=0.1
momentum=0.9
weight_decay=5e-4

dataset_train=MiniImageNet(phase='train')
dataset_test=MiniImageNet(phase='val')

# Training episodes: 5 "novel" categories are carved out of the base pool
# (nKbase=-1 means all remaining base categories), one exemplar each.
dloader_train=FewShotDataloader(dataset=dataset_train,
                               nKnovel=5,
                               nKbase=-1,
                               nExemplars=1,
                               nTestNovel=5*3,
                               nTestBase=5*3,
                               batch_size=8,
                               num_workers=1,
                               epoch_size=8*1000)#8*1000
# Validation episodes: 64 base categories plus 5 novel categories.
dloader_test = FewShotDataloader(
    dataset=dataset_test,
    nKnovel=5,
    nKbase=64,
    nExemplars=1, # num training examples per novel category
    nTestNovel=15*5, # num test examples for all the novel categories
    nTestBase=15*5, # num test examples for all the base categories
    batch_size=1,
    num_workers=0,
    epoch_size=2000, #2000 num of batches per epoch
)

In [None]:
def get_labels_train_one_hot(labels_train,num_classes):
    """One-hot encode the exemplar labels of every episode in a batch.

    Labels are first shifted per episode so that the smallest label of the
    episode becomes class 0.

    Args:
        labels_train: integer tensor of shape [batch_size, num].
        num_classes: width of the one-hot encoding; every shifted label must
            be smaller than this.

    Returns:
        Float tensor of shape [batch_size, num, num_classes], allocated on
        the CPU.
    """
    batch_size,num=labels_train.size()
    # Per-episode offset: subtract each row's minimum label.
    relative=labels_train-labels_train.min(dim=1,keepdim=True)[0]
    one_hot=torch.zeros((batch_size,num,num_classes))
    # Vectorised replacement for the nested Python loops, which also reused
    # the outer loop variable as the inner index.
    one_hot.scatter_(2,relative.long().cpu().unsqueeze(2),1.0)
    return one_hot
        
def get_acc(pred,labels):
    """Accuracy in percent of the dim-1 arg-max of ``pred`` against ``labels``.

    Both the predicted indices and ``labels`` are flattened before being
    compared; the result is a 0-dim tensor.
    """
    predicted=torch.max(pred,dim=1)[1].view(-1)
    hits=predicted.eq(labels.view(-1)).float()
    return 100*hits.mean()

In [None]:
# Create the output directories idempotently. The trace directory is created
# here as well so that this cell does not depend on the pre-training cell
# having created it.
os.makedirs('results/stage_2_model',exist_ok=True)
os.makedirs(os.path.join('results','trace_file'),exist_ok=True)

# Start every run with a fresh trace file.
trace_file=os.path.join('results','trace_file','train_stage_2_trace.txt')
if os.path.isfile(trace_file):
    os.remove(trace_file)

if use_cuda:
    fe_model.cuda()
    classifier.cuda()

#optimizer: only the classifier (including its weight generator) is
#optimized in this stage; the feature extractor has no optimizer.
# optimizer_fe=torch.optim.SGD(fe_model.parameters(),lr=lr,nesterov=True , momentum=momentum,weight_decay=weight_decay)
optimizer_classifier=torch.optim.SGD(classifier.parameters(),lr=lr,nesterov=True , momentum=momentum,weight_decay=weight_decay)
lr_schedule_classifier=torch.optim.lr_scheduler.StepLR(optimizer=optimizer_classifier,gamma=0.5,step_size=25)
criterion=torch.nn.CrossEntropyLoss()

print("---- train-stage-2 ----")
best_acc_both=0.0
best_acc_novel=0.0
for ep in range(epoch):
    train_loss=[]
    acc_both=[]
    acc_base=[]
    acc_novel=[]
    print("----epoch: %2d---- "%ep)
    fe_model.train()
    classifier.train()
    
    for batch in tqdm(dloader_train(ep)):
        assert(len(batch)==6) #images_train, labels_train, images_test, labels_test, K, nKbase
        
#         optimizer_fe.zero_grad()
        optimizer_classifier.zero_grad()
        
        train_data=batch[0]
        train_label=batch[1]
        test_data=batch[2]
        test_label=batch[3]
        k_id=batch[4]
        nKbase=batch[5]
        KbaseId=k_id[:,:nKbase[0]]
        labels_train_one_hot=get_labels_train_one_hot(train_label,dloader_train.nKnovel)
        
        if use_cuda:
            train_data=train_data.cuda()
            train_label=train_label.cuda()
            test_data=test_data.cuda()
            test_label=test_label.cuda()
            k_id=k_id.cuda()
            nKbase=nKbase.cuda()
            KbaseId=KbaseId.cuda()
            labels_train_one_hot=labels_train_one_hot.cuda()
        
        batch_size,nExamples,channels,width,high=train_data.size()
        nTest=test_data.size(1)
        
        train_data=train_data.view(batch_size*nExamples,channels,width,high)
        test_data=test_data.view(batch_size*nTest,channels,width,high)
        
        train_data_embedding=fe_model(train_data)
        test_data_embedding=fe_model(test_data)
        
        pred_result=classifier(features_test=test_data_embedding.view(batch_size,nTest,-1),Kbase_ids=KbaseId,
                               features_train=train_data_embedding.view(batch_size,nExamples,-1),labels_train=labels_train_one_hot)
#         print("pred_result.size",pred_result.size())
        pred_result = pred_result.view(batch_size*nTest,-1)
        test_label = test_label.view(batch_size*nTest)
    
        loss=criterion(pred_result,test_label)
        loss.backward()
#         optimizer_fe.step()
        optimizer_classifier.step()
    
        train_loss.append(float(loss))
        
        accuracy_both=get_acc(pred_result,test_label)
        acc_both.append(float(accuracy_both))
        
        base_ids=torch.nonzero(test_label < nKbase[0]).view(-1)
        novel_ids=torch.nonzero(test_label >= nKbase[0]).view(-1)
        
        pred_base = pred_result[base_ids,:]
        pred_novel =pred_result[novel_ids,:]
        
        accuracy_base=get_acc(pred_base[:,:nKbase[0]],test_label[base_ids])
        accuracy_novel=get_acc(pred_novel[:,nKbase[0]:],(test_label[novel_ids]-nKbase[0]))
        
        acc_base.append(float(accuracy_base))
        acc_novel.append(float(accuracy_novel))
        
    
    lr_schedule_classifier.step()
    #------------------------------------------------------
    #validation stage
    print("----begin validation----")
    fe_model.eval()
    classifier.eval()
    
    val_loss=[]
    val_acc_both=[]
    val_acc_base=[]
    val_acc_novel=[]
    for batch in tqdm(dloader_test(ep)):
        assert(len(batch)==6)
        train_data=batch[0]
        train_label=batch[1]
        test_data=batch[2]
        test_label=batch[3]
        k_id=batch[4]
        nKbase=batch[5]
        KbaseId=k_id[:,:nKbase[0]]
        labels_train_one_hot=get_labels_train_one_hot(train_label,dloader_test.nKnovel)
        
        if use_cuda:
            train_data=train_data.cuda()
            train_label=train_label.cuda()
            test_data=test_data.cuda()
            test_label=test_label.cuda()
            k_id=k_id.cuda()
            nKbase=nKbase.cuda()
            KbaseId=KbaseId.cuda()
            labels_train_one_hot=labels_train_one_hot.cuda()
        
        batch_size,nExamples,channels,width,high=train_data.size()
        nTest=test_data.size(1)
        
        train_data=train_data.view(batch_size*nExamples,channels,width,high)
        test_data=test_data.view(batch_size*nTest,channels,width,high)
        
        train_data_embedding=fe_model(train_data)
        test_data_embedding=fe_model(test_data)
        
        pred_result=classifier(features_test=test_data_embedding.view(batch_size,nTest,-1),Kbase_ids=KbaseId,
                               features_train=train_data_embedding.view(batch_size,nExamples,-1),labels_train=labels_train_one_hot)
#         print("pred_result.size",pred_result.size())
        pred_result = pred_result.view(batch_size*nTest,-1)
        test_label = test_label.view(batch_size*nTest)
        
        loss=criterion(pred_result,test_label)
        val_loss.append(float(loss))
        
        accuracy_both=get_acc(pred_result,test_label)
        val_acc_both.append(float(accuracy_both))
        
        base_ids=torch.nonzero(test_label < nKbase[0]).view(-1)
        novel_ids=torch.nonzero(test_label >= nKbase[0]).view(-1)
        
        pred_base = pred_result[base_ids,:]
        pred_novel =pred_result[novel_ids,:]
        
        accuracy_base=get_acc(pred_base[:,:nKbase[0]],test_label[base_ids])
        accuracy_novel=get_acc(pred_novel[:,nKbase[0]:],(test_label[novel_ids]-nKbase[0]))
        
        val_acc_base.append(float(accuracy_base))
        val_acc_novel.append(float(accuracy_novel))
    avg_loss=np.mean(train_loss)
    avg_acc_both=np.mean(acc_both)
    avg_acc_base=np.mean(acc_base)
    avg_acc_novel=np.mean(acc_novel)
    
    val_avg_loss=np.mean(val_loss)
    val_avg_acc_both=np.mean(val_acc_both)
    val_avg_acc_base=np.mean(val_acc_base)
    val_avg_acc_novel=np.mean(val_acc_novel)
    
    print("epoch %2d training end : training ---- avg_loss = %.4f , avg_acc_both = %.2f , avg_acc_base = %.2f , avg_acc_novel = %.2f "%(ep,avg_loss,avg_acc_both,avg_acc_base,avg_acc_novel))
    print("epoch %2d training end : validation ---- avg_loss = %.4f , avg_acc_both = %.2f , avg_acc_base = %.2f , avg_acc_novel = %.2f "%(ep,val_avg_loss,val_avg_acc_both,val_avg_acc_base,val_avg_acc_novel))
    with open(trace_file,'a') as f:
        f.write('epoch:{:2d}  training ---- avg_loss:{:.4f} , avg_acc_both:{:.2f} , avg_acc_base:{:.2f} , avg_acc_novel:{:.2f}'.format(ep,avg_loss,avg_acc_both,avg_acc_base,avg_acc_novel))
        f.write('\n')
        f.write('epoch:{:2d}  validation ---- avg_loss:{:.4f} , avg_acc_both:{:.2f} , avg_acc_base:{:.2f} , avg_acc_novel:{:.2f}'.format(ep,val_avg_loss,val_avg_acc_both,val_avg_acc_base,val_avg_acc_novel))
        f.write('\n')
    if best_acc_both<val_avg_acc_both:
        print("produce best both_acc model，saving------")
        
        p1='results/stage_2_model/fe_best_both.pth'
        p2='results/stage_2_model/classifier_best_both.pth'
        m1=fe_model.save(path=p1)
        m2=classifier.save(path=p2)
        best_acc_both=avg_acc_both
        print("successfully saving current best both_acc model----")
    if best_acc_novel<val_avg_acc_novel:
        print("produce best novel_acc model，saving------")
        
        p1='results/stage_2_model/fe_best_novel.pth'
        p2='results/stage_2_model/classifier_best_novel.pth'
        m1=fe_model.save(path=p1)
        m2=classifier.save(path=p2)
        best_acc_novel=avg_acc_novel
        print("succewssfully saving current best novel_acc model----")

### Test stage

In [None]:
#test stage
# Rebuild both networks and restore the checkpoints saved under
# results/stage_2_model with the "best_novel" suffix.

# feature extractor
path_fe='results/stage_2_model/fe_best_novel.pth'
fe_model=ConvNet()
fe_model.load(path_fe)

# few-shot classifier
path_classifier='results/stage_2_model/classifier_best_novel.pth'
classifier=Classifier(weight_generator_type='attention_based')
classifier.load(path_classifier)

In [None]:
#load test data
# From here on dataset_test / dloader_test refer to the real test split.
dataset_test=MiniImageNet(phase='test')

# Test episodes: 5 novel classes (one exemplar each) scored together with 64
# base classes, one episode per batch.
dloader_test=FewShotDataloader(dataset=dataset_test,
                               nKnovel=5,
                               nKbase=64,
                               nExemplars=1,     # num training examples per novel category
                               nTestNovel=15*5,  # num test examples for all the novel categories
                               nTestBase=15*5,   # num test examples for all the base categories
                               batch_size=1,
                               num_workers=0,
                               epoch_size=600)

In [None]:
# Test stage: run the restored models over the test episodes and report the
# mean loss and the accuracy over all / base / novel classes.
trace_file=os.path.join('results','trace_file','test_trace.txt')
# The summary is appended to this file at the end, so make sure its directory
# exists; a trace left over from a previous run is discarded.
os.makedirs(os.path.dirname(trace_file),exist_ok=True)
if os.path.isfile(trace_file):
    os.remove(trace_file)
    
if use_cuda:
    fe_model.cuda()
    classifier.cuda()

criterion=torch.nn.CrossEntropyLoss()

print("---- test-stage ----")
fe_model.eval()
classifier.eval()
test_loss=[]
test_acc_both=[]
test_acc_base=[]
test_acc_novel=[]
# Pure evaluation: no backward pass, so skip autograd bookkeeping.
with torch.no_grad():
    for batch in tqdm(dloader_test()):
        assert(len(batch)==6)
        train_data=batch[0]
        train_label=batch[1]
        test_data=batch[2]
        test_label=batch[3]
        k_id=batch[4]
        nKbase=batch[5]
        # The first nKbase[0] columns of k_id are passed on as the base-class ids.
        KbaseId=k_id[:,:nKbase[0]]
        labels_train_one_hot=get_labels_train_one_hot(train_label,dloader_test.nKnovel)

        if use_cuda:
            train_data=train_data.cuda()
            train_label=train_label.cuda()
            test_data=test_data.cuda()
            test_label=test_label.cuda()
            k_id=k_id.cuda()
            nKbase=nKbase.cuda()
            KbaseId=KbaseId.cuda()
            labels_train_one_hot=labels_train_one_hot.cuda()

        batch_size,nExamples,channels,width,high=train_data.size()
        nTest=test_data.size(1)

        # Fold the episode axis into the batch axis for the feature extractor.
        train_data=train_data.view(batch_size*nExamples,channels,width,high)
        test_data=test_data.view(batch_size*nTest,channels,width,high)

        train_data_embedding=fe_model(train_data)
        test_data_embedding=fe_model(test_data)

        # Restore the (episode, sample, feature) layout for the classifier.
        pred_result=classifier(features_test=test_data_embedding.view(batch_size,nTest,-1),Kbase_ids=KbaseId,
                               features_train=train_data_embedding.view(batch_size,nExamples,-1),labels_train=labels_train_one_hot)
#         print("pred_result.size",pred_result.size())
        pred_result = pred_result.view(batch_size*nTest,-1)
        test_label = test_label.view(batch_size*nTest)

        loss=criterion(pred_result,test_label)
        test_loss.append(float(loss))

        accuracy_both=get_acc(pred_result,test_label)
        test_acc_both.append(float(accuracy_both))

        # Labels below nKbase[0] are scored as base classes, the rest as novel.
        base_ids=torch.nonzero(test_label < nKbase[0]).view(-1)
        novel_ids=torch.nonzero(test_label >= nKbase[0]).view(-1)

        pred_base = pred_result[base_ids,:]
        pred_novel =pred_result[novel_ids,:]

        # Base / novel accuracy is computed over that group's own score columns only.
        accuracy_base=get_acc(pred_base[:,:nKbase[0]],test_label[base_ids])
        accuracy_novel=get_acc(pred_novel[:,nKbase[0]:],(test_label[novel_ids]-nKbase[0]))

        test_acc_base.append(float(accuracy_base))
        test_acc_novel.append(float(accuracy_novel))

test_avg_loss=np.mean(test_loss)
test_avg_acc_both=np.mean(test_acc_both)
test_avg_acc_base=np.mean(test_acc_base)
test_avg_acc_novel=np.mean(test_acc_novel)

print("%2d batch test end :  avg_loss = %.4f , avg_acc_both = %.2f , avg_acc_base = %.2f , avg_acc_novel = %.2f "%(dloader_test.epoch_size,test_avg_loss,test_avg_acc_both,test_avg_acc_base,test_avg_acc_novel))
# NOTE(review): the value written under the "batch_size" label below is
# dloader_test.epoch_size, not the loader's batch size -- the label is left
# untouched in case the trace format is consumed elsewhere.
with open(trace_file,'a') as f:
    f.write('batch_size:{:2d}  test ---- avg_loss:{:.4f} , avg_acc_both:{:.2f} , avg_acc_base:{:.2f} , avg_acc_novel:{:.2f}'.format(dloader_test.epoch_size,test_avg_loss,test_avg_acc_both,test_avg_acc_base,test_avg_acc_novel))
    f.write('\n')