In [1]:
from models import SimpleShot  # make sure models.py is in the same directory or in PYTHONPATH
from data_collector import get_datasets  # same for data_collector.py
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm
import random
import copy

from train_eval import train, evaluate_few_shot

In [2]:
# Fetch the three data splits from the project's data_collector module.
(
    train_dataset,
    val_dataset,
    test_dataset,
) = get_datasets()
print("Datasets loaded.")


Datasets loaded.


In [3]:
# One DataLoader per split, all with batch size 256 and num_workers=0
# (batches are loaded in the main process). Only the training split is shuffled.
# NOTE(review): `val_loader` is not used by any cell visible in this notebook
# (training receives `val_dataset` instead) — confirm it is still needed.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=256, shuffle=shuffle_split, num_workers=0)
    for split, shuffle_split in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)
print("DataLoaders ready.")


DataLoaders ready.


In [4]:
# Seed every global RNG this pipeline may draw from: Python `random`, NumPy,
# and PyTorch (CPU and all CUDA devices). This happens before the model is built
# and before any DataLoader is iterated, so weight initialisation and training
# shuffle order are reproducible on Restart & Run All.
# NOTE(review): this presumably also fixes few-shot task sampling in
# `evaluate_few_shot`, assuming it uses one of these global RNGs; confirm.
# Bit-exact GPU results may additionally need
# torch.use_deterministic_algorithms(True).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): input_dim=84 presumably matches an 84x84 input resolution.
# num_classes=64 presumably matches the number of training classes.
# Both are assumptions; confirm against data_collector / models.
model = SimpleShot(
    input_dim=84, hidden_dim=64, num_classes=64, l2norm=False, support=None
).to(device)
print("Model initialized.")


Model initialized.


In [None]:
# Train the model with the project's `train` helper (90 epochs, lr 0.1) and
# rebind `model` to whatever `train` returns.
# NOTE(review): validation receives `val_dataset`, not the `val_loader` built
# earlier. Confirm this matches `train`'s signature.
# NOTE(review): this cell is not idempotent. Re-running it continues training the
# already-trained `model` instead of starting from a fresh initialisation.
model = train(
    model,
    train_loader,
    val_dataset,
    epochs=90,
    lr=0.1,
    device=device,
)


In [None]:
# Evaluate the trained model on N-way K-shot test episodes for every feature
# transformation and shot count. The transform names are presumably
# un-normalised (UN), L2-normalised (L2N) and centred + L2-normalised (CL2N)
# features; confirm against evaluate_few_shot in train_eval.
N_WAY = 5
SHOT_COUNTS = (1, 5)
# NOTE(review): 100 episodes per setting gives a fairly noisy accuracy estimate.
# Consider raising this for final numbers.
N_TASKS = 100

# (transform, k_shot) -> accuracy as returned by evaluate_few_shot. It is
# printed below as a percentage.
few_shot_results = {}
for transform in ["UN", "L2N", "CL2N"]:
    print(f"\nFeature transformation: {transform}")

    # Shots are evaluated in ascending order (1 then 5) for each transform,
    # matching the original cell's call order.
    for k_shot in SHOT_COUNTS:
        accuracy = evaluate_few_shot(
            model,
            test_loader,
            n_way=N_WAY,
            k_shot=k_shot,
            n_tasks=N_TASKS,
            feature_transform=transform,
            device=device,
        )
        few_shot_results[(transform, k_shot)] = accuracy
        print(f"{N_WAY}-way {k_shot}-shot accuracy: {accuracy:.2f}%")
