# Setting arguments

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import argparse


# Hyperparameters for RelationNet training.
# parse_args([]) deliberately ignores the real command line (inside Jupyter,
# sys.argv belongs to the kernel), so the defaults below are what is used.
parser = argparse.ArgumentParser(description='RelationNet')
parser.add_argument('--gpu', '-g',
                    type=int,
                    default=0,
                    help='GPU ID (-1 indicates CPU)')
parser.add_argument('--way', '-w',
                    type=int,
                    default=5,
                    help='Number of classes (ways) to train on')
parser.add_argument('--shot', '-s',
                    type=int,
                    default=20,
                    help='Number of support images (shots) per epoch')
parser.add_argument('--epoch',
                    type=int,
                    default=5000,
                    help='Number of training epochs')
parser.add_argument('--episode',
                    type=int,
                    default=100,
                    help='Number of query episodes per epoch (maximum 600)')
# Namespace holding the parsed (default) hyperparameters.
args = parser.parse_args([])

# Loading mini-ImageNet

In [None]:
from miniImage import miniImage
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# mini-ImageNet training split: a CSV index plus the folder holding the images.
csv_file = './mini-imagenet/train.csv'
root_dir = './mini-imagenet/train'

# Convert images to tensors and normalise every channel with the same statistics.
# NOTE(review): identical per-channel stats (mean 0.92206, std 0.08426) look
# unusual for natural RGB images — confirm they were computed on mini-ImageNet.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.92206] * 3, std=[0.08426] * 3),
])

# Training dataset restricted by `way` (see the miniImage class for how
# classes are chosen).
trainset = miniImage(csv_file=csv_file, root_dir=root_dir,
                     transform=transform, way=args.way)

# Support loader: one shuffled batch of `args.shot` images per draw.
# Query loader: one shuffled image at a time.
# NOTE(review): both loaders read from the same `trainset`, so a query image can
# also appear in the current support batch — confirm this is intended.
support_set = DataLoader(trainset, batch_size=args.shot, shuffle=True, num_workers=6)
query_set = DataLoader(trainset, batch_size=1, shuffle=True, num_workers=0)

# Selecting the compute device (GPU or CPU)

In [None]:
# Pick the compute device: --gpu -1 forces the CPU; any other value selects
# that CUDA device, falling back to the CPU when CUDA is unavailable.
device = torch.device(
    'cuda:' + str(args.gpu)
    if args.gpu != -1 and torch.cuda.is_available()
    else 'cpu'
)

# Release cached-but-unused GPU memory held by PyTorch's allocator, then show
# which device was selected (cuda:<id> on a CUDA machine, otherwise cpu).
torch.cuda.empty_cache()
print(device)

# Build Feature Net

In [None]:
class CNNEncoder(nn.Module):
    """Four-stage convolutional feature extractor for the Relation Network.

    Each stage is Conv2d(3x3, 64 out) -> BatchNorm2d(momentum=1) -> ReLU.
    The first two stages use padding=0 and end with a 2x2 max-pool; the last
    two use padding=1 and no pooling. An 84x84 input therefore becomes a
    [batch, 64, 19, 19] feature map (84 -> 82 -> 41 -> 39 -> 19 -> 19 -> 19).

    Attribute names (layer1..layer4) and sub-module order are fixed so that
    state_dict keys (`layerN.0.*` conv, `layerN.1.*` batch-norm) match the
    saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = self._conv_block(3, padding=0, pool=True)
        self.layer2 = self._conv_block(64, padding=0, pool=True)
        self.layer3 = self._conv_block(64, padding=1, pool=False)
        self.layer4 = self._conv_block(64, padding=1, pool=False)

    @staticmethod
    def _conv_block(in_channels, padding, pool):
        """Build one Conv -> BN -> ReLU (-> MaxPool) stage with 64 output channels."""
        modules = [
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=padding),
            nn.BatchNorm2d(64, momentum=1, affine=True),
            nn.ReLU(),
        ]
        if pool:
            modules.append(nn.MaxPool2d(2))
        return nn.Sequential(*modules)

    def forward(self, x):
        """Map a [batch, 3, H, W] image tensor to 64-channel feature maps."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return x
    
# Instantiate the encoder and load its pretrained weights.
# It serves purely as a frozen feature extractor. eval() makes BatchNorm use
# the stored running statistics. requires_grad_(False) keeps its parameters
# out of autograd: they are never handed to the optimizer, so any gradients
# would only pile up unused in encoder.*.grad and keep extra graph memory alive.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source (newer PyTorch versions also accept weights_only=True).
encoder = CNNEncoder().to(device)
encoder.eval()
encoder.requires_grad_(False)
encoder.load_state_dict(torch.load('miniimagenet_feature_encoder_5way_5shot.pkl', map_location=device))

# Build Relation Net

In [None]:
class RelationNetwork(nn.Module):
    """Scores how related a pair of feature maps is.

    Each sample must flatten to 2 * 64 * 19 * 19 features (in this notebook,
    a support feature map concatenated channel-wise with a query feature map).
    The flattened vector goes through Linear -> PReLU -> Linear -> PReLU ->
    Linear, and a sigmoid maps the result to a [batch, 1] score in [0, 1].
    """

    def __init__(self):
        super().__init__()
        in_features = 64 * 19 * 19 * 2  # two concatenated [64, 19, 19] maps
        self.fc1 = nn.Linear(in_features, 512)
        self.pr1 = nn.PReLU()
        self.fc2 = nn.Linear(512, 64)
        self.pr2 = nn.PReLU()
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        """Return a [batch, 1] relation score for each input pair."""
        hidden = x.view(x.shape[0], -1)
        for linear, activation in ((self.fc1, self.pr1), (self.fc2, self.pr2)):
            hidden = activation(linear(hidden))
        return self.fc3(hidden).sigmoid()
    
# Build the relation network on the chosen device and resume from previously
# saved weights; the final cell of this notebook writes the trained weights
# back to this same file.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
relation_net = RelationNetwork()
relation_net = relation_net.to(device)
relation_net.load_state_dict(
    torch.load('miniimagenet_relation_net_5way_5shot.pkl', map_location=device)
)

In [None]:
import torch.optim as optim

# Binary cross-entropy on the sigmoid relation scores. Only the relation
# network's parameters are passed to Adam, so only they are updated.
criterion = nn.BCELoss()
optimizer = optim.Adam(params=relation_net.parameters(), lr=1e-3)

# Meta-Training

In [None]:
# Meta-training.
# Each epoch draws one support batch of `args.shot` images, then runs up to
# `args.episode` episodes. Each episode pairs a single query image with every
# support image and takes one optimizer step on the relation network.
#
# The encoder is only a fixed feature extractor (its parameters are not given
# to `optimizer`), so its forward passes run under torch.no_grad(). Otherwise,
# every backward pass also pushed gradients into encoder.*.grad, where they
# accumulated forever. It also needed backward(retain_graph=True) to keep the
# shared support-feature graph alive between episodes.
REPORT_EVERY = 100     # epochs between confusion-matrix reports / class re-sampling
MATCH_THRESHOLD = 0.9  # relation score above which a pair is predicted "same class"

running_loss = 0.0
TP = 0
FP = 0
FN = 0
TN = 0
for epoch in range(args.epoch):
    # A new iterator per epoch starts new DataLoader workers; presumably this is
    # what lets them see the state changed by trainset.sample_class() below, so
    # do not replace it with one long-lived iterator.
    sample = next(iter(support_set))
    sample_imgs = sample['data'].to(device)  # [n_support x 3 x 84 x 84]
    sample_labels = sample['label']          # [n_support], kept on the CPU
    n_support = sample_imgs.size(0)          # == args.shot for a full batch
    with torch.no_grad():
        sample_features = encoder(sample_imgs)  # [n_support x 64 x 19 x 19]

    n_episodes = 0
    for _, query in zip(range(args.episode), query_set):
        n_episodes += 1
        query_imgs = query['data'].to(device)  # [1 x 3 x 84 x 84]
        query_label = query['label'][0]
        with torch.no_grad():
            query_features = encoder(query_imgs)  # [1 x 64 x 19 x 19]

        # Pair the query with every support image: [n_support x 128 x 19 x 19].
        query_features = query_features.expand(n_support, -1, -1, -1)
        inputs = torch.cat([sample_features, query_features], dim=1)

        # Target is 1 where the support image shares the query's class.
        is_match = sample_labels == query_label
        labels = is_match.to(device=device, dtype=torch.float)

        # Training step (only relation_net's parameters are updated).
        optimizer.zero_grad()
        outputs = relation_net(inputs).view(-1)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # Confusion-matrix counts in one vectorised pass on the CPU,
        # instead of one GPU->CPU sync per support image.
        predicted = outputs.detach().cpu() > MATCH_THRESHOLD
        TP += int((predicted & is_match).sum())
        FP += int((predicted & ~is_match).sum())
        FN += int((~predicted & is_match).sum())
        TN += int((~predicted & ~is_match).sum())

    # Mean loss over the episodes actually run (query_set may yield fewer
    # than args.episode items).
    if n_episodes:
        print('[{:d}, {:3d}] loss: {:.3f}'.format(
              epoch + 1, n_episodes, running_loss / n_episodes))
    running_loss = 0.0

    if (epoch + 1) % REPORT_EVERY == 0:
        total = TP + FP + FN + TN
        print('------------------------------------------')
        print('Precision = {:.2%}'.format(TP / (TP + FP) if TP + FP else 0))
        print('Sensitivity = {:.2%}'.format(TP / (TP + FN) if TP + FN else 0))
        print('Accuracy = {:.2%}'.format((TP + TN) / total if total else 0))
        print('------------------------------------------')
        TP = FP = FN = TN = 0
        trainset.sample_class()

In [None]:
# Persist the trained relation-network weights. This overwrites the checkpoint
# that was loaded before training.
torch.save(obj=relation_net.state_dict(), f='miniimagenet_relation_net_5way_5shot.pkl')