In [25]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import scipy as sp
import scipy.stats
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as trans
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet18

# Global Variables

In [26]:
# Dataset locations: one sub-directory per class under train/ and val/.
# Override the root with the MINI_IMAGENET_DIR environment variable instead of
# editing this cell. os.path.join(..., '') keeps the trailing separator that
# later cells rely on when they build file paths by string concatenation.
path_dataset = os.environ.get('MINI_IMAGENET_DIR', '/home/wwang/datasets/mini_imagenet/')
path_train = os.path.join(path_dataset, 'train', '')
path_val = os.path.join(path_dataset, 'val', '')
way = 2  # must stay 2: the distance/loss/accuracy code only handles way-slots 0 and 1
shot = 2  # support images per class in each episode
image_size = [84, 84]
dim_embedding = 256  # output channels of the backbone's final conv
channel = 64  # base width of the backbone
margin = 1  # hinge margin of the training loss (alternative value noted: 0.3)
iterations = 1000000  # Dataset_siamesenet length (episodes per DataLoader pass)
batch_size = 32  # episodes per training batch
# Checkpoint directory; override with SIAMESENET_WORK_DIR. Trailing separator kept.
path_work = os.path.join(
    os.environ.get('SIAMESENET_WORK_DIR', '/home/wwang/wwfewshot/work/siamesenet_3loss/'), '')
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = T.device("cuda:0" if T.cuda.is_available() else "cpu")
# Seed NumPy (episode sampling) and PyTorch (weight init, random flips) so runs are reproducible.
seed = 0
np.random.seed(seed)
T.manual_seed(seed)

# Data

In [27]:
class Dataset_siamesenet(Dataset):
    """Episodic sampler for `way`-way, `shot`-shot siamese training.

    ``path`` must contain one sub-directory per category holding that
    category's image files. ``__getitem__`` ignores ``idx`` and draws a fresh
    random episode on every call, so ``__len__`` (the global ``iterations``)
    only caps how many episodes one DataLoader pass yields.

    Each item is ``(image_support, image_query, label)``:
      * image_support: the ``way*shot`` transformed support images concatenated
        along the channel axis, i.e. shape (way*shot*3, H, W);
      * image_query: one transformed image (3, H, W) from one sampled category;
      * label: float64 one-hot vector of length ``way`` marking which way-slot
        the query belongs to.
    """

    def __init__(self, path):
        # Sort so the label <-> category mapping does not depend on the
        # filesystem's listing order; skip stray non-directory entries.
        category_names = sorted(
            name for name in os.listdir(path)
            if os.path.isdir(os.path.join(path, name))
        )
        self.image_list = []
        self.label_list = []
        # ids_per_category[c]: indices into image_list of category c's images,
        # precomputed so an episode does not rescan label_list for every class.
        self.ids_per_category = []
        for label, category_name in enumerate(category_names):
            category_dir = os.path.join(path, category_name)
            first = len(self.image_list)
            for file_name in sorted(os.listdir(category_dir)):
                self.image_list.append(os.path.join(category_dir, file_name))
                self.label_list.append(label)
            self.ids_per_category.append(np.arange(first, len(self.image_list)))
        # len() rather than `label + 1`, which raised NameError for an empty dir.
        self.category_num = len(category_names)
        if self.category_num < way:
            raise ValueError('need at least {} category directories in {!r}, found {}'.format(
                way, path, self.category_num))
        self.transform = trans.Compose([trans.Resize(image_size, Image.BICUBIC),
                                        trans.RandomHorizontalFlip(),
                                        trans.ToTensor(),
                                        trans.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                        ])

    def _load(self, index):
        """Open image_list[index], force 3-channel RGB and apply the transform."""
        with Image.open(self.image_list[index]) as img:
            return self.transform(img.convert('RGB'))

    def __getitem__(self, idx):
        label = np.zeros([way])
        permutation = np.random.permutation(self.category_num)[:way]
        category_random = np.random.randint(way)
        label[category_random] = 1
        image_support = []
        image_query = None
        for i, cls in enumerate(permutation):
            # Sample WITHOUT replacement: with replacement the query could be an
            # exact copy of one of the support images, leaking the answer.
            ids = np.random.choice(self.ids_per_category[cls], shot + 1, replace=False)
            for j in range(shot):
                image_support.append(self._load(ids[j]))
            if i == category_random:
                image_query = self._load(ids[shot])
        image_support = T.cat(image_support)
        return image_support, image_query, label

    def __len__(self):
        return iterations

# Network Structure

In [28]:
class CNA(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    The first five arguments are forwarded positionally to ``nn.Conv2d``; batch
    normalisation runs over ``out_channels``. The three submodules are kept in
    ``self.layers`` (indices 0, 1, 2) so checkpoint keys such as
    ``layers.0.weight`` stay the same.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(CNA, self).__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU()
        self.layers = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        out = self.layers(x)
        return out

In [29]:
class BackBone(nn.Module):
    """Convolutional embedding network.

    Three stages of (5x5 CNA, 2x2 stride-2 CNA that doubles the width), followed
    by a plain 5x5 conv projecting to ``dim_embedding`` channels. No padding is
    used, so an 84x84 input shrinks 84 -> 80 -> 40 -> 36 -> 18 -> 14 -> 7 -> 3,
    giving a (dim_embedding, 3, 3) feature map.
    """

    def __init__(self, channel, dim_embedding):
        super().__init__()
        self.dim_embedding = dim_embedding
        modules = []
        in_ch = 3
        for width in (channel, channel * 2, channel * 4):
            modules.append(CNA(in_ch, width, 5, 1, 0))      # 5x5 feature conv
            modules.append(CNA(width, width * 2, 2, 2, 0))  # 2x2 stride-2 downsample, doubles width
            in_ch = width * 2
        modules.append(nn.Conv2d(channel * 8, dim_embedding, 5, 1, 0))
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        features = self.layers(x)
        return features

In [30]:
class SiameseNet(nn.Module):
    """Shared-weight embedding network: a thin wrapper around ``BackBone``.

    Defaults for ``channel`` and ``dim_embedding`` come from the config cell.
    NOTE(review): a torchvision ``resnet18(num_classes=dim_embedding)`` backbone
    was tried as an alternative; it outputs flat (N, dim_embedding) vectors
    rather than a spatial map.
    """

    def __init__(self, channel=channel, dim_embedding=dim_embedding):
        super(SiameseNet, self).__init__()
        self.backbone = BackBone(channel=channel, dim_embedding=dim_embedding)

    def forward(self, image):
        embedding = self.backbone(image)
        return embedding

# Forward Pass and Visualization

In [31]:
def Forward(net, image_support, image_query):
    """Embed a batch of 2-way episodes and return query-to-prototype distances.

    The caller chooses the network mode (``net.train()`` for training,
    ``net.eval()`` for evaluation). This function used to force ``net.eval()``,
    which silently disabled BatchNorm batch statistics during training.

    Args:
        net: embedding network applied to image batches of shape (B, C, H, W).
        image_support: ``n*way*shot`` support images in total (e.g. shaped
            (n, way*shot, C, H, W) or (n, way*shot*C, H, W)), ordered as
            way-slot-major then shot; reshaped to (n*way*shot, C, H, W).
        image_query: one query image per episode, shape (n, C, H, W).

    Returns:
        (distance0, distance1): tensors of shape (n,) holding the Euclidean
        distance from each query embedding to the mean support embedding
        (prototype) of way-slot 0 and way-slot 1 respectively.
    """
    n, c, h, w = image_query.shape
    image_support = image_support.reshape([n * way * shot, c, h, w])
    # Flatten each embedding to a vector. The flattened size is inferred rather
    # than hard-coded as dim_embedding*9 (the 3x3 map produced for 84x84 input).
    embedding_support = net(image_support).reshape([n, way, shot, -1])  # n, way, shot, D
    embedding_query = net(image_query).reshape([n, -1])  # n, D
    prototypes = T.mean(embedding_support, dim=2)  # n, way, D
    distance0 = T.sqrt(T.sum((embedding_query - prototypes[:, 0, :]) ** 2, dim=1))  # n
    distance1 = T.sqrt(T.sum((embedding_query - prototypes[:, 1, :]) ** 2, dim=1))  # n
    return distance0, distance1

In [32]:
@T.no_grad()
def show(net, path):
    """Plot one random episode drawn from ``path`` and, if ``net`` is given,
    print the query's distances to the two support prototypes.

    Support images are titled with their way-slot's one-hot label value (1 marks
    the query's class); the query is drawn in the last panel.
    """
    loader = DataLoader(Dataset_siamesenet(path=path), batch_size=1)
    image_support, image_query, label = next(iter(loader))
    image_support = image_support.reshape([-1, way * shot, 3, image_size[0], image_size[1]])

    n_support = way * shot
    plt.figure(figsize=(10, 15))
    for panel in range(n_support + 1):
        plt.subplot(way + 1, shot, panel + 1)
        if panel == n_support:
            picture = image_query[0].cpu().numpy().transpose([1, 2, 0])
            plt.title('Query')
        else:
            picture = image_support[0, panel].cpu().numpy().transpose([1, 2, 0])
            plt.title('Support. Label: {}'.format(int(label[0, panel // shot])))
        # Undo Normalize(mean=0.5, std=0.5): [-1, 1] -> [0, 1] for imshow.
        plt.imshow(picture * 0.5 + 0.5)
    plt.show()

    if net is None:
        return
    net.eval()
    distance0, distance1 = Forward(net, image_support.to(device), image_query.to(device))
    print('distance0: {}, distance1: {}'.format(distance0.cpu().numpy().item(), distance1.cpu().numpy().item()))

In [None]:
# Smoke test: plot one validation episode and print the distances produced by a
# freshly initialised (untrained) network.
net = SiameseNet()
net = net.to(device)
show(net, path_val)

In [34]:
def mean_confidence_interval(data, confidence=0.95):
    """Return the sample mean of ``data`` and the half-width of its two-sided
    Student-t confidence interval.

    Args:
        data: sequence of numbers (converted to a float array).
        confidence: two-sided confidence level, e.g. 0.95.

    Returns:
        (m, h): the mean and the half-width, so the interval is [m - h, m + h].
        ``h`` is NaN when fewer than two values are given (undefined variance).
    """
    a = np.asarray(data, dtype=float)
    n = len(a)
    m = np.mean(a)
    se = scipy.stats.sem(a)
    # Public ppf instead of the private _ppf: it validates its arguments and is
    # part of SciPy's supported API.
    h = se * scipy.stats.t.ppf((1 + confidence) / 2.0, n - 1)
    return m, h


@T.no_grad()
def test(net, path, episode, batch_size=15):
    """Estimate the 2-way accuracy of ``net`` on random episodes from ``path``.

    Evaluates ``episode`` batches of ``batch_size`` episodes each. Every batch
    contributes one accuracy value (the fraction of queries nearer to their own
    class prototype). Returns ``(mean accuracy, half-width of its 95%
    Student-t confidence interval)`` as computed by mean_confidence_interval.
    """
    net.eval()
    loader = DataLoader(Dataset_siamesenet(path=path), batch_size=batch_size)
    batch_accuracies = []
    for step, batch in enumerate(loader):
        if step == episode:
            break
        support, query, label_relative = batch
        distance0, distance1 = Forward(net, support.to(device), query.to(device))
        # Predict way-slot 1 when the query is farther from prototype 0.
        predicted_slot1 = distance0.cpu().numpy() > distance1.cpu().numpy()
        labels = label_relative.numpy()
        true_slot1 = labels[:, 0] < labels[:, 1]
        batch_accuracies.append(np.mean(predicted_slot1 == true_slot1))
    return mean_confidence_interval(batch_accuracies)


# net = SiameseNet().to(device)
# test_accuracy, h = test(net, path_val, episode=600, batch_size=15)
# print("val accuracy:", test_accuracy, "h:", h)

# Training

In [35]:
def trainer(image_support, image_query, label_relative, net, optimizer):
    """Run one optimisation step of the 2-way margin (hinge) loss.

    ``label_relative`` is the one-hot (n, 2) label from Dataset_siamesenet. With
    ``sign = 2*label - 1`` the query's true way-slot gets +1 and the other -1, so
    the per-episode loss ``relu(sign0*d0 + sign1*d1 + margin)`` equals
    ``max(0, d_true - d_other + margin)``. It pulls the query toward its own
    prototype and pushes it at least ``margin`` farther from the other one.

    Returns:
        The batch-mean loss as a detached scalar tensor.
    """
    net.train()
    distance0, distance1 = Forward(net, image_support, image_query)
    # The dataset yields float64 labels; cast them to the distances' dtype so the
    # loss is not silently promoted to float64 on the GPU.
    sign = label_relative.to(distance0.dtype) * 2 - 1
    hinge = F.relu(distance0 * sign[:, 0] + distance1 * sign[:, 1] + margin)
    loss = T.mean(hinge)
    #
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Detach so callers holding the value do not keep the autograd graph alive.
    return loss.detach()

In [36]:
def _repeat_forever(iterable):
    """Yield the items of a re-iterable (e.g. a DataLoader) pass after pass.

    Returns instead of spinning forever if a whole pass yields nothing.
    """
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            return


def train(iterations, load_model):
    """Train a SiameseNet on ``path_train`` up to iteration ``iterations``.

    One iteration is one optimisation step on a batch of ``batch_size`` episodes.

    Args:
        iterations: last iteration number to train (inclusive). When resuming,
            counting continues from the saved iteration.
        load_model: resume net, optimizer and iteration counter from ``path_work``.

    Every 100 iterations the net and optimizer (plus backup copies) and the
    iteration counter are saved to ``path_work``. A validation episode is shown
    and the elapsed time printed. Every 100000 iterations, validation accuracy
    over 600 batches is printed.

    Returns:
        The trained network. Previously nothing was returned; ignoring the
        result is still fine.
    """
    os.makedirs(path_work, exist_ok=True)
    # iteration.npy stores a plain integer, so pickle support is not needed.
    iteration = np.load(path_work + 'iteration.npy').item() if load_model else 0
    print('Start training from iteration ', str(iteration))
    net = SiameseNet().to(device)  # was `PrototypicalNet`, which is not defined anywhere
    optimizer = T.optim.Adam(net.parameters(), 1e-4)
    if load_model:
        net.load_state_dict(T.load(path_work + 'net.pt', map_location=device))
        optimizer.load_state_dict(T.load(path_work + 'optimizer.pt', map_location=device))
    #
    dataloader = DataLoader(Dataset_siamesenet(path=path_train), batch_size=batch_size)
    time_start = time.time()
    # One DataLoader pass is only len(dataset) / batch_size batches, so restart
    # it as needed instead of silently stopping short of `iterations`.
    for image_support, image_query, label_relative in _repeat_forever(dataloader):
        iteration += 1
        # `>` (not `==`): iteration `iterations` itself is trained, and a resumed
        # run that is already past the target stops instead of running on.
        if iteration > iterations:
            break
        image_support, image_query, label_relative = image_support.to(device), image_query.to(device), label_relative.to(device)
        loss = trainer(image_support, image_query, label_relative, net, optimizer)
        #
        if iteration % 100 == 0:
            print('Iteration: {}, loss: {}'.format(iteration, loss.detach().cpu().numpy().item()))
            T.save(net.state_dict(), path_work + 'net.pt')
            T.save(net.state_dict(), path_work + 'net_backup.pt')
            T.save(optimizer.state_dict(), path_work + 'optimizer.pt')
            T.save(optimizer.state_dict(), path_work + 'optimizer_backup.pt')
            np.save(path_work + 'iteration.npy', iteration)
            show(net, path_val)
            print('These iterations cost {} seconds'.format(time.time() - time_start))
            time_start = time.time()
        if iteration % 100000 == 0:
            # Keyword was misspelled `batchsize`, a TypeError at runtime.
            test_accuracy, h = test(net, path_val, episode=600, batch_size=15)
            print("val accuracy:", test_accuracy, "h:", h)
    return net

In [None]:
# Short run from scratch; train() checkpoints to path_work every 100 iterations.
train(load_model=False, iterations=1000)

In [38]:
# train(iterations=iterations, load_model=True)

# Test

In [None]:
# Reload the latest checkpoint and visualise one validation episode.
net = SiameseNet().to(device)
# map_location restores tensors onto `device` even if the checkpoint was saved
# on a different device (e.g. a GPU checkpoint opened on a CPU-only machine).
net.load_state_dict(T.load(path_work + 'net.pt', map_location=device))
show(net, path_val)

In [40]:
# Repeat the 600-batch evaluation several times and average the accuracies to
# smooth out the sampling noise of the random episodes.
total_episode = 10
total_accuracy = 0.0
for _ in range(total_episode):
    test_accuracy, h = test(net, path_val, batch_size=15, episode=600)
    print("test accuracy:", test_accuracy, "h:", h)
    total_accuracy = total_accuracy + test_accuracy
print("aver_accuracy:", total_accuracy / total_episode)

test accuracy: 0.6643333333333333 h: 0.010032627879366884
test accuracy: 0.6625555555555556 h: 0.009734515833499434
test accuracy: 0.6686666666666666 h: 0.009563334854158856
