In [None]:
import torch

import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

import os

In [None]:
def save_list_to_file(list, path):
    """Write every element of ``list`` to ``path``, one element per line.

    Each element is formatted through an f-string and followed by a newline. The file
    is opened in write mode, so any existing content is replaced.
    """
    # NOTE: the parameter name shadows the built-in ``list``; it is kept unchanged so
    # that keyword callers keep working.
    with open(path, "w") as handle:
        handle.writelines(f"{entry}\n" for entry in list)

def read_file_to_list(path):
    """Read a text file holding one number per line and return the values as floats.

    Raises:
        ValueError: if a line cannot be parsed by ``float`` (this includes blank lines).
    """
    with open(path, "r") as handle:
        raw_lines = handle.read().splitlines()
    return [float(raw) for raw in raw_lines]

# Configurations

In [None]:
# Root folder of the dataset files read by get_attributes(), get_data() and CUBDataset.
data_dir = os.path.join(".", "data_cub")  # change as needed

# Episode size for the training loader: classes per batch / samples drawn per class.
num_way_tr = 50
num_query_tr = 10

# Episode size for the validation loader; the test and trainval loaders reuse these values.
num_way_val = 50
num_query_val = 10

# training
# Number of batches each sampler yields per pass over its loader (one "epoch").
num_iter = 100

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
# Whether results should be written to disk.
#
# ``output_dir`` is always defined (``None`` while saving is disabled) so that any cell
# reading it does not fail with a NameError depending on the value of ``save``.

save = False

output_dir = os.path.join(".", "output") if save else None

# Data

## Dataset

In [None]:
def get_attributes():
    """Load the class attribute matrix from ``data_dir``.

    Reads ``class_attribute_labels_continuous.txt``; every non-blank line contributes
    one row made of its whitespace-separated numeric values.

    Returns:
        torch.Tensor: float32 tensor with one row per non-blank line of the file.

    Raises:
        ValueError: if a token is not a valid float, or if the rows differ in length.
    """
    path = os.path.join(data_dir, "class_attribute_labels_continuous.txt")
    with open(path) as f:
        # ``str.split()`` without a separator tolerates repeated spaces, tabs and
        # trailing whitespace; ``split(" ")`` produced empty tokens in those cases and
        # ``float("")`` raised. Blank lines are skipped for the same reason.
        rows = [
            [float(token) for token in line.split()]
            for line in f.read().splitlines()
            if line.strip()
        ]
    # ``torch.tensor`` with an explicit dtype is the documented way to build a tensor
    # from Python data; the result matches the float32 tensor the legacy
    # ``torch.Tensor(...)`` constructor returned.
    return torch.tensor(rows, dtype=torch.float32)

# Loaded once at module level; get_data() looks rows up as ``attributes[class_id - 1]``.
attributes = get_attributes()

In [None]:
def get_classes(path):
    """Return the lines of the text file at ``path`` as a list of strings.

    Line terminators are stripped; no other processing is applied.
    """
    with open(path) as handle:
        content = handle.read()
    return content.splitlines()

In [None]:
def get_data(mode, classes):
    """Build three aligned flat lists: features, class attributes and integer labels.

    For every entry of ``classes`` the integer label is parsed from the text before the
    first ".", the attribute row ``attributes[label - 1]`` is looked up, and the object
    stored at ``<data_dir>/images/<entry>`` is loaded and has its axes 1 and 2 swapped.

    Args:
        mode: ``"train"`` keeps every slice along axis 1 (after the swap) of each item;
            any other value keeps only slice 0.
        classes: iterable of class-name strings shaped like ``"<number>.<name>"``.

    Returns:
        tuple: ``(x_list, v_list, y_list)``, three lists of equal length.
    """
    features, attribute_rows, labels = [], [], []
    keep_all_slices = mode == "train"

    for class_name in tqdm(classes, desc="Class"):
        label = int(class_name.split(".")[0])
        class_attributes = attributes[label - 1]  # labels are 1-based, rows 0-based

        # NOTE(review): the meaning of the stored tensor's axes is not visible here;
        # presumably (item, feature, view) before the swap -- confirm against the
        # script that produced these files.
        stored = torch.Tensor(torch.load(os.path.join(data_dir, "images", class_name)))
        stored = torch.swapaxes(stored, 1, 2)

        slices_per_item = stored.shape[1] if keep_all_slices else 1
        for item_idx in range(stored.shape[0]):
            for slice_idx in range(slices_per_item):
                features.append(stored[item_idx][slice_idx])
                attribute_rows.append(class_attributes)
                labels.append(label)

    return features, attribute_rows, labels

In [None]:
from torch.utils.data import Dataset

class CUBDataset(Dataset):
    """Map-style dataset for one split, loaded fully into memory at construction.

    ``mode`` selects the split: the class list is read from
    ``<data_dir>/<mode>classes.txt`` and handed to ``get_data`` together with ``mode``.

    Attributes:
        classes: class-name strings of the split.
        x, v, y: aligned lists of features, class attributes and integer labels.
    """

    def __init__(self, mode):
        class_file = os.path.join(data_dir, mode + "classes.txt")
        self.classes = get_classes(class_file)

        loaded = get_data(mode, self.classes)
        self.x, self.v, self.y = loaded

    def __len__(self):
        # All three lists have the same length; the label list is used as reference.
        return len(self.y)

    def __getitem__(self, idx):
        sample = (self.x[idx], self.v[idx], self.y[idx])
        return sample

In [None]:
# One dataset per split. Each constructor reads "<mode>classes.txt" from data_dir and
# loads that split's stored tensors into memory, so this cell can take a while.
train_dataset = CUBDataset("train")
val_dataset = CUBDataset("val")
test_dataset = CUBDataset('test')
trainval_dataset = CUBDataset("trainval")

## Sampler

In [None]:
def get_class_indices(labels, classes):
    """Map each class to the positions in ``labels`` that carry that class.

    Args:
        labels: 1-D sequence of labels (list, NumPy array, ...).
        classes: iterable of the class values to look up.

    Returns:
        dict: ``{class: np.ndarray of int positions}``; the array is empty for a class
        that never occurs in ``labels``.
    """
    # Convert once up front. With a plain Python list, ``labels == c`` is only
    # element-wise when ``c`` happens to be a NumPy scalar; for a Python int it would
    # collapse to a single ``False`` and every class would come back empty.
    labels = np.asarray(labels)
    return {c: np.where(labels == c)[0] for c in tqdm(classes, desc="Class")}

In [None]:
class PrototypicalBatchSampler():
    """Batch sampler that yields class-balanced episodes of dataset indices.

    Every batch holds ``num_way`` distinct classes with ``num_samples`` distinct indices
    each, laid out class by class. One pass yields ``num_iter`` batches.

    Args:
        labels: 1-D sequence with the label of every dataset item.
        num_way: number of classes per batch.
        num_samples: number of items drawn per class.
        num_iter: number of batches per pass.

    Raises:
        ValueError: if ``labels`` has fewer than ``num_way`` distinct classes, or if any
            class has fewer than ``num_samples`` items.
    """

    def __init__(self, labels, num_way, num_samples, num_iter):
        super().__init__()
        self.num_way = num_way
        self.num_samples = num_samples
        self.num_iter = num_iter

        self.classes = np.unique(labels)
        self.class_indices = get_class_indices(labels, self.classes)

        # Fail fast: with too few classes the pre-allocated batch would be yielded with
        # uninitialised (garbage) indices in its tail.
        if num_way > len(self.classes):
            raise ValueError(
                f"num_way={num_way} exceeds the {len(self.classes)} classes available"
            )
        # With too few items the slice assignment in __iter__ would only fail once the
        # class happens to be drawn, with an opaque broadcast error.
        too_small = [c for c in self.classes if len(self.class_indices[c]) < num_samples]
        if too_small:
            raise ValueError(
                f"{len(too_small)} class(es) have fewer than num_samples={num_samples} "
                f"items, e.g. class {too_small[0]}"
            )

    def __iter__(self):
        for _ in range(self.num_iter):
            batch = np.empty(self.num_way * self.num_samples, dtype=np.int64)

            # select classes
            # ``.numpy()`` makes this a plain NumPy fancy index, so a one-element
            # selection cannot be interpreted as a scalar index.
            c_idxs = torch.randperm(len(self.classes))[:self.num_way].numpy()

            # select samples
            for i, c in enumerate(self.classes[c_idxs]):
                candidates = self.class_indices[c]
                s_idxs = torch.randperm(len(candidates))[:self.num_samples].numpy()

                sl = slice(i * self.num_samples, (i + 1) * self.num_samples)
                batch[sl] = candidates[s_idxs]

            yield batch

    def __len__(self):
        return self.num_iter

## Dataloader

In [None]:
from torch.utils.data import DataLoader

def get_dataloader(dataset, num_way, num_query, num_iter):
    """Wrap ``dataset`` in a DataLoader that is driven by a PrototypicalBatchSampler.

    Args:
        dataset: dataset exposing its labels as ``dataset.y``.
        num_way: classes per batch.
        num_query: samples per class (passed to the sampler as ``num_samples``).
        num_iter: batches per pass over the loader.
    """
    return DataLoader(
        dataset,
        batch_sampler=PrototypicalBatchSampler(dataset.y, num_way, num_query, num_iter),
    )

In [None]:
# Episodic loaders: every batch holds num_way classes x num_query samples, and one pass
# over a loader yields num_iter batches.
train_dataloader = get_dataloader(train_dataset, num_way_tr, num_query_tr, num_iter)
val_dataloader = get_dataloader(val_dataset, num_way_val, num_query_val, num_iter)
# The test and trainval loaders reuse the validation episode size (num_way_val / num_query_val).
test_dataloader = get_dataloader(test_dataset, num_way_val, num_query_val, num_iter)
trainval_dataloader = get_dataloader(trainval_dataset, num_way_val, num_query_val, num_iter)

# Model

In [None]:
import torch.nn as nn
from torch.nn.functional import normalize

class ProtoNet(nn.Module):
    """Two independent linear encoders projecting features and attributes to ``emb_dim``.

    Args:
        x_dim: size of the feature vectors fed to ``x_encoder``.
        v_dim: size of the attribute vectors fed to ``v_encoder``.
        emb_dim: size of both output embeddings.
    """

    def __init__(self, x_dim = 1024, v_dim = 312, emb_dim = 1024):
        super().__init__()

        # Creation order (x first, then v) determines the state_dict order and the
        # order in which the random initialisation consumes the RNG.
        self.x_encoder = nn.Linear(in_features=x_dim, out_features=emb_dim)
        self.v_encoder = nn.Linear(in_features=v_dim, out_features=emb_dim)

    def forward(self, x, v):
        """Return ``(x_embedding, v_embedding)``.

        Only the attribute embedding is passed through ``normalize`` (called with its
        defaults); the feature embedding is returned as produced by the linear layer.
        """
        x_embedding = self.x_encoder(x)
        v_embedding = normalize(self.v_encoder(v))
        return x_embedding, v_embedding

In [None]:
# Model with the default dimensions of ProtoNet, moved to the configured device.
model = ProtoNet().to(device)

# Loss

In [None]:
def compute_dist_matrix(x, y):
    """Pairwise squared Euclidean distances between the rows of ``x`` and ``y``.

    Args:
        x: tensor of shape (n, d).
        y: tensor of shape (m, d).

    Returns:
        Tensor of shape (n, m); entry (i, j) equals ``((x[i] - y[j]) ** 2).sum()``.
    """
    num_x, dim = x.shape[0], x.shape[1]
    num_y = y.shape[0]
    target_shape = (num_x, num_y, dim)

    # Broadcast both operands to (n, m, d) views before taking the difference.
    x_tiled = x[:, None].expand(*target_shape)
    y_tiled = y[None].expand(*target_shape)

    diff = x_tiled - y_tiled
    return (diff ** 2).sum(dim=2)  # n x m

In [None]:
from torch.nn.functional import log_softmax

def compute_loss_acc(x, v, y):
    """Prototype-classification loss and accuracy for one batch.

    Each class present in ``y`` gets one prototype: the row of ``v`` at the first
    position where that class occurs. Every row of ``x`` is a query that is classified
    by a softmax over its negative squared distances to all prototypes.

    Args:
        x: (batch, d) query embeddings.
        v: (batch, d) per-sample prototype embeddings, aligned with ``x``.
        y: (batch,) labels.

    Returns:
        tuple: ``(loss, acc)`` as 0-dim tensors. ``loss`` is the mean over *all*
        entries of the (classes x queries) matrix of masked negative log-probabilities,
        i.e. the mean per-query negative log-likelihood divided by the number of
        classes. This scaling is kept as is so that reported values and the balance
        against weight decay do not change.

    Raises:
        ValueError: if the batch is empty.
    """
    classes = torch.unique(y)
    if classes.numel() == 0:
        raise ValueError("compute_loss_acc() needs a non-empty batch")

    # Work on the device of the inputs rather than a module-level global, so that all
    # index tensors are guaranteed to live next to the data they index.
    dev = x.device

    prototype_rows, query_blocks, target_blocks = [], [], []
    for i, c in enumerate(classes):
        c_idxs = torch.where(y == c)[0]

        prototype_rows.append(v[c_idxs[0]])
        query_blocks.append(x[c_idxs])
        target_blocks.append(
            torch.full((len(c_idxs),), i, dtype=torch.int64, device=dev)
        )

    # Assemble once after the loop. ``stack`` always yields a 2-D (classes, d) matrix,
    # which also makes single-class batches work (previously the lone prototype stayed
    # 1-D and the distance computation failed).
    prototypes = torch.stack(prototype_rows)
    query = torch.cat(query_blocks)
    target_idxs = torch.cat(target_blocks)

    dists = compute_dist_matrix(prototypes, query)  # classes x queries
    log_prob = log_softmax(-dists, dim=0)           # softmax over the class axis

    # One-hot mask selecting, for every query column, the row of its true class.
    target_matrix = torch.zeros_like(log_prob)
    target_matrix[target_idxs, torch.arange(len(query), device=dev)] = 1.0

    loss = (-log_prob * target_matrix).mean()

    pred = torch.max(log_prob, dim=0).indices
    acc = (target_idxs == pred).float().mean()

    return loss, acc


# Optimiser

In [None]:
from torch.optim import Adam

# Adam hyper-parameters; the same lr / weight_decay are reused when the optimiser is
# re-created for retraining.
lr = 1e-4
weight_decay = 1e-5

# Optimises every parameter of the global ``model`` (both linear encoders).
optimiser = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

# Train

## Train

In [None]:
# Running record of training, mutated in place by train(). Re-running this cell resets it.
history = {
    "total_epoch": 0,      # epochs completed across all train() calls
    "train_loss": [],      # per-epoch average training loss
    "train_acc": [],       # per-epoch average training accuracy
    "val_loss": [],        # per-epoch average validation loss
    "val_acc": [],         # per-epoch average validation accuracy
    "best_acc": 0,         # highest per-epoch validation accuracy seen so far
    "best_epoch": None,    # value of "total_epoch" when "best_acc" was set
    "best_model": None     # model state dict stored when "best_acc" was set
    }

In [None]:
def train(epoch):
    """Train the global ``model`` for ``epoch`` more epochs, validating after each one.

    One epoch is a full pass over ``train_dataloader`` followed by a full pass over
    ``val_dataloader``. Per-epoch averages are appended to the global ``history``; when
    the validation accuracy exceeds ``history["best_acc"]`` a snapshot of the weights
    is stored under ``history["best_model"]``. A summary line is printed per epoch.

    Args:
        epoch: number of epochs to run. Calls are cumulative: training continues from
            the current model/optimiser state and ``history`` keeps growing.
    """
    # ``history`` is only mutated, never rebound, so no ``global`` statement is needed.
    for _ in range(epoch):
        history["total_epoch"] += 1

        # train
        model.train()
        sum_loss = 0
        sum_acc = 0
        for x, v, y in train_dataloader:
            x, v, y = x.to(device), v.to(device), y.to(device)
            x_emb, v_emb = model(x, v)
            loss, acc = compute_loss_acc(x_emb, v_emb, y)

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

            sum_loss += loss.item()
            sum_acc += acc.item()

        # Average over the batches this loader actually yields instead of an unrelated
        # global constant.
        num_batches = len(train_dataloader)
        history["train_loss"].append(sum_loss / num_batches)
        history["train_acc"].append(sum_acc / num_batches)

        # validation
        model.eval()
        sum_loss = 0
        sum_acc = 0
        # No gradients are needed here; skipping graph construction saves memory/time.
        with torch.no_grad():
            for x, v, y in val_dataloader:
                x, v, y = x.to(device), v.to(device), y.to(device)
                x_emb, v_emb = model(x, v)
                loss, acc = compute_loss_acc(x_emb, v_emb, y)

                sum_loss += loss.item()
                sum_acc += acc.item()

        num_batches = len(val_dataloader)
        avg_loss = sum_loss / num_batches
        avg_acc = sum_acc / num_batches
        history["val_loss"].append(avg_loss)
        history["val_acc"].append(avg_acc)

        if avg_acc > history["best_acc"]:
            history["best_acc"] = avg_acc
            history["best_epoch"] = history["total_epoch"]
            # ``state_dict()`` returns references to the live parameter tensors, which
            # the optimiser keeps updating in place. Clone them so the stored snapshot
            # really is the best epoch rather than whatever the weights are later.
            history["best_model"] = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }

        print(f"Epoch {history['total_epoch']}: Train Loss {history['train_loss'][-1]}, Acc {history['train_acc'][-1]}; Val Loss {history['val_loss'][-1]}, Acc {history['val_acc'][-1]}; Best {history['best_acc']} (Epoch {history['best_epoch']})")

In [None]:
# Run 50 epochs. Re-running this cell continues from the current model/optimiser state
# and keeps appending to ``history``.
train(50)

In [None]:
# Plot training curve
# Left panel: per-epoch loss; right panel: per-epoch accuracy. Both panels show the
# train and validation series recorded in ``history``.

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))

# Loss panel
ax1.plot(history["train_loss"], label='train')
ax1.plot(history["val_loss"], label='val')
ax1.grid()
ax1.legend()
ax1.set_xlabel('Epoch')
ax1.set_title('Loss')

# Accuracy panel
ax2.plot(history["train_acc"], label='train')
ax2.plot(history["val_acc"], label='val')
ax2.grid()
ax2.legend()
ax2.set_xlabel('Epoch')
ax2.set_title('Acc')

## Retrain

In [None]:
# Build a new ProtoNet and load the weights stored under history['best_model'].
# That entry is None until train() has recorded a validation accuracy above 0, so
# train() must have run before this cell.
model = ProtoNet().to(device)
model.load_state_dict(history['best_model'])

# New Adam instance bound to the new model's parameters (no optimiser state carries over).
optimiser = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

In [None]:
# Running record of retraining, mutated in place by retrain(). No validation entries:
# retrain() only iterates the trainval loader.
retrain_history = {
    "total_epoch": 0,    # epochs completed across all retrain() calls
    "train_loss": [],    # per-epoch average loss on the trainval loader
    "train_acc": [],     # per-epoch average accuracy on the trainval loader
    }

In [None]:
def retrain(epoch):
    """Train the global ``model`` on ``trainval_dataloader`` for ``epoch`` more epochs.

    There is no validation pass. Per-epoch average loss and accuracy are appended to
    the global ``retrain_history`` and a summary line is printed after every epoch.

    Args:
        epoch: number of epochs to run; calls are cumulative.
    """
    # ``retrain_history`` is only mutated, never rebound, so no ``global`` is needed.
    for _ in range(epoch):
        retrain_history["total_epoch"] += 1

        # train
        model.train()
        sum_loss = 0
        sum_acc = 0
        for x, v, y in trainval_dataloader:
            x, v, y = x.to(device), v.to(device), y.to(device)
            x_emb, v_emb = model(x, v)
            loss, acc = compute_loss_acc(x_emb, v_emb, y)

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

            sum_loss += loss.item()
            sum_acc += acc.item()

        # Average over the batches this loader actually yields instead of an unrelated
        # global constant.
        num_batches = len(trainval_dataloader)
        retrain_history["train_loss"].append(sum_loss / num_batches)
        retrain_history["train_acc"].append(sum_acc / num_batches)

        print(f"Epoch {retrain_history['total_epoch']}: Train Loss {retrain_history['train_loss'][-1]}, Acc {retrain_history['train_acc'][-1]}")

In [None]:
# Run 50 retraining epochs on the trainval loader; re-running continues from the
# current state and keeps appending to ``retrain_history``.
retrain(50)

In [None]:
# Plot retraining curve

# Left panel: per-epoch loss; right panel: per-epoch accuracy, both taken from
# ``retrain_history`` (trainval loader only).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))

# Loss panel
ax1.plot(retrain_history["train_loss"], label='trainval')
ax1.grid()
ax1.legend()
ax1.set_xlabel('Epoch')
ax1.set_title('Loss')

# Accuracy panel
ax2.plot(retrain_history["train_acc"], label='trainval')
ax2.grid()
ax2.legend()
ax2.set_xlabel('Epoch')
ax2.set_title('Acc')

# Test

In [None]:
def test(model, epoch):
    """Evaluate ``model`` on ``test_dataloader`` and print the mean batch accuracy.

    Args:
        model: module called as ``model(x, v)`` that returns the two embeddings.
        epoch: number of full passes over ``test_dataloader``; since every pass draws
            new random episodes, more passes average over more episodes.

    The accuracy of every batch of every pass is averaged with equal weight. Nothing
    is returned; the result is printed.
    """
    batch_accs = []
    model.eval()
    # Evaluation needs no gradients; skipping graph construction saves memory/time.
    with torch.no_grad():
        for _ in tqdm(range(epoch)):
            for x, v, y in test_dataloader:
                x, v, y = x.to(device), v.to(device), y.to(device)
                x_emb, v_emb = model(x, v)
                _, acc = compute_loss_acc(x_emb, v_emb, y)
                batch_accs.append(acc.item())
    avg_acc = np.mean(batch_accs)
    print(f"Test Acc: {avg_acc}")

In [None]:
# Evaluate the current global ``model`` over 10 passes of the test loader.
test(model, 10)