# Prep

In [None]:
import torch

import numpy as np
import matplotlib.pyplot as plt

import os

from tqdm import tqdm

In [None]:
# for saving and reading results

def save_list_to_file(list, path):
    """Write the elements of ``list`` to ``path``, one element per line.

    Every element is formatted through an f-string, so anything with a string
    representation is accepted. An existing file at ``path`` is overwritten.
    """
    lines = (f"{entry}\n" for entry in list)
    with open(path, "w") as handle:
        handle.writelines(lines)

def read_file_to_list(path):
    """Read a text file holding one number per line and return the values as floats."""
    with open(path, "r") as handle:
        rows = handle.read().splitlines()
    return [float(row) for row in rows]

# Configurations

In [None]:
# Root folder of the dataset; the loaders below read "<data_dir>/data" and
# "<data_dir>/splits/vinyals".
data_dir = "./data_omniglot"  # change as needed

# Training episodes: classes per episode, then support / query images per class.
num_way_tr = 60
num_shot_tr = 5
num_query_tr = 5

# Evaluation episodes; these settings are used for both the validation and the test loader.
num_way_val = 20
num_shot_val = 5
num_query_val = 15

# training
num_epoch = 100  # number of epochs passed to train()
num_iter = 100  # episodes (batches) produced by each dataloader per pass

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
# Whether to write the trained weights and the training curves to disk.
save = False

# Defined unconditionally: the saving cell only checks `save`, so switching the
# flag on later in the session must not leave it without a target directory.
output_dir = os.path.join(".", "output")

# Data

## Dataset

In [None]:
def get_classes(path):
    """Return the class identifiers listed in a split file, one per line.

    Forward slashes in the file are replaced by ``os.sep`` so the identifiers
    can be split and joined as native paths.
    """
    with open(path) as split_file:
        raw = split_file.read()
    return raw.replace('/', os.sep).splitlines()

def get_class_label(classes):
    """Map each class identifier to its position in ``classes``.

    When an identifier occurs more than once, the last position is kept.
    """
    return {name: position for position, name in enumerate(classes)}

In [None]:
from PIL import Image
from torchvision.transforms.functional import to_tensor

def get_data(classes, class_label):
    """Load every image of every class as a rotated, 28x28 tensor.

    Each class identifier is split on ``os.sep`` into three parts. The first
    two select the image folder below ``<data_dir>/data``; the third carries
    the rotation angle after a three-character prefix (presumably "rot").

    Args:
        classes: iterable of class identifiers as returned by ``get_classes``.
        class_label: mapping from class identifier to integer label.

    Returns:
        ``(x_list, y_list)``: the image tensors and their labels, aligned by index.
    """
    x_list = []
    y_list = []
    for c in tqdm(classes, desc="Class"):
        parts = c.split(os.sep)
        path = os.path.join(data_dir, "data", parts[0], parts[1])
        rot = parts[2][3:]
        label = class_label[c]
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # dataset indices identical across runs and file systems.
        for image in sorted(os.listdir(path)):
            # The context manager releases the file handle as soon as the
            # pixel data has been turned into a tensor.
            with Image.open(os.path.join(path, image)) as img:
                x = img.rotate(float(rot))
                x = x.resize((28, 28))
                x = to_tensor(x)
            x_list.append(x)
            y_list.append(label)
    return x_list, y_list

In [None]:
from torch.utils.data import Dataset

class OmniglotDataset(Dataset):
    """Dataset holding all images and labels of one split in memory.

    ``mode`` names the split file ``<data_dir>/splits/vinyals/<mode>.txt``;
    its classes are read, labelled by position and loaded eagerly.
    """

    def __init__(self, mode):
        super().__init__()
        split_file = os.path.join(data_dir, "splits", "vinyals", mode + ".txt")
        self.classes = get_classes(split_file)
        self.class_label = get_class_label(self.classes)
        images, labels = get_data(self.classes, self.class_label)
        self.x = images
        self.y = labels

    def __len__(self):
        # One label per stored image.
        return len(self.y)

    def __getitem__(self, idx):
        # Returns the (image, label) pair stored at position idx.
        return self.x[idx], self.y[idx]

In [None]:
# Build the three splits; each constructor reads its split file and loads all
# of its images into memory, so this cell can take a while.
train_dataset = OmniglotDataset('train')
val_dataset = OmniglotDataset('val')
test_dataset = OmniglotDataset('test')

## Sampler

In [None]:
def get_class_indices(labels, classes):
    """Map every class to the positions at which it occurs in ``labels``.

    Args:
        labels: sequence or array of labels, one per dataset item.
        classes: iterable of the distinct labels to look up.

    Returns:
        dict from class to a 1-D integer array of indices into ``labels``.
    """
    # Convert once up front. With a plain Python list, `labels == c` is either a
    # single bool (plain-int classes, which silently yields empty index arrays)
    # or forces a fresh list-to-array conversion on every iteration.
    labels = np.asarray(labels)

    dic = {}
    for c in tqdm(classes, desc="Class"):
        dic[c] = np.where(labels == c)[0]
    return dic

In [None]:
class PrototypicalBatchSampler():
    """Batch sampler that yields one episode of dataset indices per iteration.

    Every episode picks ``num_way`` distinct classes at random and, for each of
    them, ``num_samples`` distinct items at random. The indices of one class
    are stored contiguously in the yielded array.

    Args:
        labels: label of every dataset item, in dataset order.
        num_way: number of classes per episode.
        num_samples: number of items drawn per class.
        num_iter: number of episodes produced by one pass over the sampler.

    Raises:
        ValueError: if there are fewer than ``num_way`` classes, or if some
            class has fewer than ``num_samples`` items.
    """

    def __init__(self, labels, num_way, num_samples, num_iter):
        super().__init__()
        self.num_way = num_way
        self.num_samples = num_samples
        self.num_iter = num_iter

        self.classes, counts = np.unique(labels, return_counts=True)

        # Fail early: with too few classes the episode buffer would be left
        # partly uninitialised (np.empty), i.e. filled with arbitrary indices.
        if len(self.classes) < num_way:
            raise ValueError(
                f"num_way={num_way} exceeds the {len(self.classes)} available classes"
            )
        # With too few items in a class the slice assignment in __iter__ would
        # only fail once that class happens to be drawn.
        if counts.size and counts.min() < num_samples:
            raise ValueError(
                f"num_samples={num_samples} exceeds the smallest class size {counts.min()}"
            )

        self.class_indices = get_class_indices(labels, self.classes)

    def __iter__(self):
        for _ in range(self.num_iter):
            batch = np.empty(self.num_way * self.num_samples, dtype=np.int64)

            # select classes
            c_idxs = torch.randperm(len(self.classes))[:self.num_way].numpy()

            # select samples
            for i, c in enumerate(self.classes[c_idxs]):
                candidates = self.class_indices[c]
                s_idxs = torch.randperm(len(candidates))[:self.num_samples].numpy()

                sl = slice(i * self.num_samples, (i + 1) * self.num_samples)
                batch[sl] = candidates[s_idxs]

            yield batch

    def __len__(self):
        return self.num_iter

## Dataloader

In [None]:
from torch.utils.data import DataLoader

def get_dataloader(dataset, num_way, num_shot, num_query, num_iter):
    """Wrap ``dataset`` in a DataLoader that yields one episode per batch.

    Each episode holds ``num_way`` classes with ``num_shot + num_query`` items
    per class; one pass over the loader produces ``num_iter`` episodes.
    """
    samples_per_class = num_shot + num_query
    episode_sampler = PrototypicalBatchSampler(
        labels=dataset.y,
        num_way=num_way,
        num_samples=samples_per_class,
        num_iter=num_iter,
    )
    return DataLoader(dataset, batch_sampler=episode_sampler)

In [None]:
# One episodic loader per split. The test loader uses the same way / shot /
# query settings as the validation loader.
train_dataloader = get_dataloader(train_dataset, num_way_tr, num_shot_tr, num_query_tr, num_iter)
val_dataloader = get_dataloader(val_dataset, num_way_val, num_shot_val, num_query_val, num_iter)
test_dataloader = get_dataloader(test_dataset, num_way_val, num_shot_val, num_query_val, num_iter)

# Note: constructing each sampler takes a while because it gathers the image indices of every class.

# Model

In [None]:
import torch.nn as nn

def conv_block(in_channels, out_channels):
    """Build one encoder stage: 3x3 convolution, batch norm, ReLU, 2x2 max pooling.

    The convolution pads by one pixel, so only the pooling changes the spatial
    size (halving it).
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
    ]
    return nn.Sequential(*layers)

class ProtoNet(nn.Module):
    """Embedding network made of four stacked ``conv_block`` stages.

    Args:
        x_dim: number of input channels.
        hid_dim: number of channels produced by the first three stages.
        z_dim: number of channels produced by the last stage.
    """

    def __init__(self, x_dim=1, hid_dim=64, z_dim=64):
        super().__init__()
        channels = [x_dim, hid_dim, hid_dim, hid_dim, z_dim]
        stages = [
            conv_block(c_in, c_out)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ]
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        # Encode, then flatten everything but the batch dimension.
        features = self.encoder(x)
        batch_size = features.size(0)
        return features.view(batch_size, -1)


In [None]:
# Instantiate the network with its default dimensions and move it to the configured device.
model = ProtoNet().to(device)

# Loss

In [None]:
def compute_dist_matrix(x, y):
    """Return the squared Euclidean distance between every row of x and every row of y.

    x: n x d
    y: m x d
    result: n x m, entry (i, j) is the sum of squared differences of x[i] and y[j]
    """
    # Broadcasting pairs each row of x (axis 0) with each row of y (axis 1).
    diff = x.unsqueeze(1) - y.unsqueeze(0)  # n x m x d
    return (diff ** 2).sum(dim=2)

In [None]:
from torch.nn.functional import log_softmax

def compute_loss_acc(output, target, num_shot):
    """Compute the prototypical loss and accuracy of one episode.

    For every class in ``target`` the first ``num_shot`` embeddings (in batch
    order) form the support set and are averaged into the class prototype; the
    remaining embeddings of that class are queries. Queries are scored by a
    softmax over their negative squared distances to all prototypes.

    Args:
        output: embeddings, one row per item of the episode.
        target: label of every row of ``output``.
        num_shot: number of support items per class.

    Returns:
        ``(loss, acc)``: the negative log-probability of the true class averaged
        over the queries, and the fraction of queries whose nearest prototype
        belongs to their own class. Both are 0-d tensors.
    """
    classes = torch.unique(target)

    # Collect the per-class pieces in lists and join them once, instead of
    # growing tensors inside the loop.
    prototypes = []
    queries = []
    query_targets = []
    for i, c in enumerate(classes):
        c_idxs = torch.where(target == c)[0]
        c_query = output[c_idxs[num_shot:]]

        prototypes.append(output[c_idxs[:num_shot]].mean(0))
        queries.append(c_query)
        # Queries are labelled with the class position i (a prototype row),
        # created directly on the device of the embeddings.
        query_targets.append(
            torch.full((len(c_query),), i, dtype=torch.int64, device=output.device)
        )

    # stack keeps the prototypes 2-D even when the episode has a single class.
    prototypes = torch.stack(prototypes)    # n_way x d
    query = torch.cat(queries)              # n_query_total x d
    target_idxs = torch.cat(query_targets)  # n_query_total

    dists = compute_dist_matrix(prototypes, query)  # n_way x n_query_total
    log_prob = log_softmax(-dists, dim=0)           # normalised over the classes

    # Average the true-class log-probability over the queries only. Averaging
    # over the whole n_way x n_query_total matrix would shrink the loss by a
    # factor of n_way and make episodes with different n_way incomparable.
    query_idxs = torch.arange(len(query), device=output.device)
    loss = -log_prob[target_idxs, query_idxs].mean()

    pred = torch.max(log_prob, dim=0).indices
    acc = (target_idxs == pred).float().mean()

    return loss, acc

# Optimiser

In [None]:
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR

lr = 0.001  # initial learning rate for Adam
scheduler_step = 20  # number of scheduler steps between two learning-rate decays
scheduler_gamma = 0.5  # factor applied to the learning rate at every decay

optimiser = Adam(model.parameters(), lr=lr)
# The training loop steps the scheduler once per epoch, so the learning rate
# is multiplied by scheduler_gamma every scheduler_step epochs.
scheduler = StepLR(optimizer=optimiser, gamma=scheduler_gamma, step_size=scheduler_step)

# Train

In [None]:
# Initialise the training history container

history = {
    "total_epoch": 0,  # epochs trained so far, accumulated across calls to train()
    "train_loss": [],  # per-epoch averages
    "train_acc": [],
    "val_loss": [],
    "val_acc": [],
    "best_acc": 0,  # highest validation accuracy seen so far
    "best_epoch": None,  # epoch (1-based) at which best_acc was reached
    "best_model": None  # state dict recorded at best_epoch
    }

In [None]:
def train(epoch):
    """Train the global ``model`` for ``epoch`` epochs and record the results.

    Every epoch runs one pass over ``train_dataloader`` with gradient updates,
    steps the learning-rate scheduler once, then runs one pass over
    ``val_dataloader``. The per-epoch averages are appended to the global
    ``history``; when the validation accuracy improves, a snapshot of the
    weights is stored in ``history["best_model"]``.

    Args:
        epoch: number of epochs to run; repeated calls continue the epoch count.
    """
    import copy  # only needed for the best-model snapshot

    for _ in range(epoch):
        history["total_epoch"] += 1

        # train
        model.train()
        sum_loss = 0
        sum_acc = 0
        for x, y in train_dataloader:
            x, y = x.to(device), y.to(device)
            output = model(x)
            loss, acc = compute_loss_acc(output, y, num_shot_tr)

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

            sum_loss += loss.item()
            sum_acc += acc.item()

        scheduler.step()

        # Average over the episodes that were actually produced.
        num_train_episodes = len(train_dataloader)
        history["train_loss"].append(sum_loss / num_train_episodes)
        history["train_acc"].append(sum_acc / num_train_episodes)

        # validation
        model.eval()
        sum_loss = 0
        sum_acc = 0
        # No gradients are needed here; skipping the autograd graph saves
        # memory and time.
        with torch.no_grad():
            for x, y in val_dataloader:
                x, y = x.to(device), y.to(device)
                output = model(x)
                loss, acc = compute_loss_acc(output, y, num_shot_val)

                sum_loss += loss.item()
                sum_acc += acc.item()

        num_val_episodes = len(val_dataloader)
        avg_loss = sum_loss / num_val_episodes
        avg_acc = sum_acc / num_val_episodes
        history["val_loss"].append(avg_loss)
        history["val_acc"].append(avg_acc)

        if avg_acc > history["best_acc"]:
            history["best_acc"] = avg_acc
            history["best_epoch"] = history["total_epoch"]
            # state_dict() returns references to the live parameters, which
            # later optimiser steps overwrite in place; deep-copy to keep the
            # weights of this epoch.
            history["best_model"] = copy.deepcopy(model.state_dict())

        print(f"Epoch {history['total_epoch']}: Train Loss {history['train_loss'][-1]}, Acc {history['train_acc'][-1]}; Val Loss {history['val_loss'][-1]}, Acc {history['val_acc'][-1]}; Best {history['best_acc']} (Epoch {history['best_epoch']})")

In [None]:
train(num_epoch)

# Total number of training episodes: num_episode = num_epoch * num_iter

In [None]:
# Save results

if save:

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Weights: the model as it is now, then the snapshot kept in the history.
    checkpoints = (
        ('last_model.pth', model.state_dict()),
        ('best_model.pth', history["best_model"]),
    )
    for filename, weights in checkpoints:
        torch.save(weights, os.path.join(output_dir, filename))

    # Curves: one text file per recorded series, one value per line.
    for name in ['train_loss', 'train_acc', 'val_loss', 'val_acc']:
        curve_path = os.path.join(output_dir, name + '.txt')
        save_list_to_file(history[name], curve_path)

In [None]:
# Plot training curve

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))

# Both panels share one layout: train and val series against the epoch index.
panels = (
    (ax1, "train_loss", "val_loss"),
    (ax2, "train_acc", "val_acc"),
)
for ax, train_key, val_key in panels:
    ax.plot(history[train_key], label='train')
    ax.plot(history[val_key], label='val')
    ax.grid()
    ax.legend()
    ax.set_xlabel('Epoch')

ax1.set_title('Loss')
ax2.set_title('Acc')

# Test

In [None]:
def test(model, epoch):
    """Evaluate ``model`` on the test episodes and print the mean accuracy.

    Args:
        model: network to evaluate; it is switched to eval mode.
        epoch: number of passes over ``test_dataloader``.
    """
    avg_acc = []
    model.eval()
    # Evaluation needs no gradients; without no_grad every forward pass would
    # build an autograd graph that is never used.
    with torch.no_grad():
        for ep in tqdm(range(epoch)):
            for x, y in test_dataloader:
                x, y = x.to(device), y.to(device)
                output = model(x)
                _, acc = compute_loss_acc(output, y, num_shot_val)
                avg_acc.append(acc.item())
    avg_acc = np.mean(avg_acc)
    print(f"Test Acc: {avg_acc}")

In [None]:
# Evaluate the current (last-epoch) weights over 10 passes of the test loader.
test(model, 10)