In [5]:
"""
This code is reimplement of "CVPR 2018 paper: Learning to Compare: Relation Network for Few-Shot Learning"(Zero-shot part)

Download the AWA2 dataset(AWA2-base.zip, AWA2-features.zip) at "https://cvml.ist.ac.at/AwA2/"

"""

'\nThis code is reimplement of "CVPR 2018 paper: Learning to Compare: Relation Network for Few-Shot Learning"(Zero-shot part)\n\nDownload the AWA2 dataset at "https://cvml.ist.ac.at/AwA2/"\n'

In [225]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as transform
import numpy as np

In [226]:
#data about animals' attribute
class Attribute:
    def __init__(self, root):
        """
        parameters:
            root : location of the file
        """
        self.idx_to_cls = {}
        f = open(root + "classes.txt", 'r')
        lines = f.readlines()
        for line in lines:
            self.idx_to_cls[int(line.split()[0])-1] = line.split()[1]
        f.close()

        f = open(root + "predicate-matrix-continuous.txt", 'r')
        attribute=[]
        lines = f.readlines()
        for line in lines:
            attribute.append(list(map(float, line.split())))
        f.close()
        self.attribute=torch.tensor(attribute)
        
    def __getitem__(self, idx):
        return self.attribute[idx]
        
#
def load_split(root, split_num=10):
    """
    parameters:
        root : location of the file
        split_num : num of class to test
    """
    print("Loading data starts..")
    #all data is data of query set
    train_x = []
    train_y = []
    test_x = []
    test_y = []

    test_cls = np.random.choice(50, split_num, replace=False)

    f1 = open(root + "AwA2-features.txt", 'r')
    f2 = open(root + "AwA2-labels.txt", 'r')

    features = f1.readlines()
    labels = f2.readlines()

    for feature, label in zip(features, labels):
        label = int(label)-1
        if label in test_cls:
            test_x.append(list(map(float, feature.split())))
            test_y.append(label)
        else:
            train_x.append(list(map(float, feature.split())))
            train_y.append(label)   
    f1.close()
    f2.close()

    train_set = torch.utils.data.TensorDataset(torch.tensor(train_x), torch.LongTensor(train_y))
    test_set = torch.utils.data.TensorDataset(torch.tensor(test_x), torch.LongTensor(test_y))
    print("finish")
    return train_set, test_set


In [227]:
class Embedding(nn.Module):
    #Embedding module for attributes
    def __init__(self):
        nn.Module.__init__(self)
        layers = []
        layers += [nn.Linear(85, 1024), nn.ReLU()] 
        layers += [nn.Linear(1024, 2048), nn.ReLU()]
        
        self.layers = nn.Sequential(*layers)
        
    def forward(self, x):
        #x(bsize*85) -> embedded feature(bsize*2048)
        return self.layers(x)
    
class RelationNet(nn.Module):
    #Relation module for concatenated features
    def __init__(self):
        nn.Module.__init__(self)
        layers= []
        layers += [nn.Linear(4096, 400), nn.ReLU()]
        layers += [nn.Linear(400, 1), nn.Sigmoid()]
        
        self.layer = nn.Sequential(*layers)
    
    def forward(self, x):
        #x(bsize * cls_num * 4096) -> score(bsize * cls_num)
        return self.layer(x)
    

In [229]:
def train(attribute, train, test, lr=0.01, EPISODE_NUM=1, bsize=32, test_bsize=32):
    """
    parameters:
        attribute : attribute of the animals
        train : train datset
        test : test dataset
        lr :learning rate
        EPISODE_NUM : =~ epoch
        bsize : batch size
        test_bisze : batch size for test set
    """
    emb = Embedding()
    RN = RelationNet()
    MSE = nn.MSELoss()

    
    if torch.cuda.is_available():
        print("cuda is on")
        emb.cuda()
        RN.cuda()
        MSE=MSE.cuda()
    
    emb_optim = torch.optim.Adam(emb.parameters(), lr = lr)
    RN_optim = torch.optim.Adam(RN.parameters(), lr = lr)
    
    test_loader = torch.utils.data.DataLoader(test, batch_size=test_bsize, shuffle=True)
        
    for i in range(EPISODE_NUM):
        emb.train()
        RN.train()

        train_loader = torch.utils.data.DataLoader(train, batch_size=bsize, shuffle=True)
        train_feature, train_label = train_loader.__iter__().next()
        if len(train_feature) != bsize:
            bsize=len(train_feature)

        #relabeling the index of training set
        support_label={}
        cls_num=0
        new_label=[]
        for label in train_label:
            if int(label) not in support_label:
                support_label[int(label)] = cls_num
                cls_num += 1
            new_label.append(support_label[int(label)])
        train_feature = torch.autograd.Variable(train_feature)
        support_attribute = torch.autograd.Variable(attribute[list(support_label.keys())])

        if torch.cuda.is_available():
            support_attribute = support_attribute.cuda()
            train_feature = train_feature.cuda()

        att_emb = emb(support_attribute).unsqueeze(0).repeat(bsize, 1, 1)
        train_emb = train_feature.unsqueeze(0).repeat(cls_num, 1, 1).transpose(0,1)

        concat = torch.cat((att_emb, train_emb), 2)
        score = RN(concat).squeeze(2)
        
        one_hot = torch.zeros(bsize, cls_num)
        one_hot[torch.arange(bsize), new_label] = 1
        one_hot = torch.autograd.Variable(one_hot)
        if torch.cuda.is_available(): 
            one_hot=one_hot.cuda()
        
        loss = MSE(score, one_hot)
        
        emb.zero_grad()
        RN.zero_grad()
            
        loss.backward()
        
        emb_optim.step()
        RN_optim.step()


        if i%400 == 0 :
            emb.eval()
            RN.eval()
            correct=0
            for test_feature, test_label in test_loader:
                if len(test_feature) != test_bsize:
                    test_bsize=len(test_feature)
                
                support_label={}
                cls_num=0
                new_label=[]
                
                for label in test_label:
                    if int(label) not in support_label:
                        support_label[int(label)] = cls_num
                        cls_num += 1
                    new_label.append(support_label[int(label)])

                test_feature = torch.autograd.Variable(test_feature)
                support_attribute = torch.autograd.Variable(attribute[list(support_label.keys())])

                if torch.cuda.is_available():
                    support_attribute = support_attribute.cuda()
                    test_feature = test_feature.cuda()

                att_emb = emb(support_attribute).unsqueeze(0).repeat(test_bsize, 1, 1)
                test_emb = test_feature.unsqueeze(0).repeat(cls_num, 1, 1).transpose(0,1)
                
                concat = torch.cat((att_emb, test_emb), 2)
                score = RN(concat).squeeze(2)

                pred = score.max(1)[1]
                
                new_label = torch.tensor(new_label)
                if torch.cuda.is_available():
                    new_label = new_label.cuda()

                correct = correct + torch.sum(pred==new_label)
            print("EPISODE",i,":")
            print("Accuracy :", correct/float(len(test_set)))

    return emb, RN

In [None]:
attribute = Attribute("Animals_with_Attributes2/")
train_set, test_set = load_split("Animals_with_Attributes2-2/Features/ResNet101/")

In [230]:
trained_emb, trained_RN = train(attribute, train_set, test_set, lr=0.00001,EPISODE_NUM=10000)

cuda is on
EPISODE 0 :
Accuracy : tensor(0.1867, device='cuda:0')
EPISODE 400 :
Accuracy : tensor(0.1874, device='cuda:0')
EPISODE 800 :
Accuracy : tensor(0.3962, device='cuda:0')
EPISODE 1200 :
Accuracy : tensor(0.6139, device='cuda:0')
EPISODE 1600 :
Accuracy : tensor(0.7166, device='cuda:0')
EPISODE 2000 :
Accuracy : tensor(0.7386, device='cuda:0')
EPISODE 2400 :
Accuracy : tensor(0.7472, device='cuda:0')
EPISODE 2800 :
Accuracy : tensor(0.7443, device='cuda:0')
EPISODE 3200 :
Accuracy : tensor(0.7422, device='cuda:0')
EPISODE 3600 :
Accuracy : tensor(0.7692, device='cuda:0')
EPISODE 4000 :
Accuracy : tensor(0.7509, device='cuda:0')
EPISODE 4400 :
Accuracy : tensor(0.7579, device='cuda:0')
EPISODE 4800 :
Accuracy : tensor(0.7387, device='cuda:0')
EPISODE 5200 :
Accuracy : tensor(0.7675, device='cuda:0')
EPISODE 5600 :
Accuracy : tensor(0.7259, device='cuda:0')
EPISODE 6000 :
Accuracy : tensor(0.7141, device='cuda:0')
EPISODE 6400 :
Accuracy : tensor(0.7203, device='cuda:0')
EPISODE 