# Deep Learning Project: CNN

## Import

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, Subset, TensorDataset
from torchvision.datasets import ImageFolder
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import random
import os
import shutil
import kagglehub
import gc

## Configuration

In [None]:
# Global experiment settings shared by the cells below.
seed = 1               # value handed to set_seed() to seed the RNGs
batch_size = 16        # DataLoader batch size (each dataset item is one episode)
learning_rate = 0.001  # learning rate handed to the optimizer

# Output directory, created eagerly so later writes cannot fail on a missing folder.
# NOTE(review): model_dir is not used in the visible part of the notebook, and the
# "/content" prefix presumably targets Google Colab - confirm before running elsewhere.
model_dir = "/content/ovr_models"
os.makedirs(model_dir, exist_ok=True)

## Device

In [None]:
# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Use device: {device}")

## Seed

In [None]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds PyTorch (CPU and all CUDA devices), NumPy and Python's ``random``
    module with the same value, then switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)

    # Give up cuDNN auto-tuning in exchange for repeatable kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Apply the configured seed to all RNGs.
set_seed(seed)

## Download dataset

In [None]:
# Download the "mengcius/cinic10" dataset through kagglehub and keep the returned path;
# the train/valid/test folders are read from beneath it further down.
cinic_path = kagglehub.dataset_download("mengcius/cinic10")

## Fewshot dataset

In [None]:
class FewShotDataset(Dataset):
    """Episodic view over a labelled dataset for N-way K-shot learning.

    Each item is one randomly sampled episode: ``n_way`` classes are drawn, and
    for every class ``k_shot`` support images and ``n_query`` query images are
    drawn without replacement. Sampling uses Python's ``random`` module, so the
    index passed to ``__getitem__`` does not influence which episode is returned.

    Args:
        dataset: Indexable dataset whose items are ``(image_tensor, label)``.
        n_way: Number of classes per episode.
        k_shot: Number of support images per class.
        n_query: Number of query images per class.
        limit_per_class: Keep at most this many examples of each class
            (``None`` keeps all of them).
        episodes: Value reported by ``len()``, i.e. episodes per epoch.

    Raises:
        ValueError: If fewer than ``n_way`` classes have at least
            ``k_shot + n_query`` examples after applying ``limit_per_class``.
    """

    def __init__(self, dataset, n_way, k_shot, n_query, limit_per_class, episodes=1000):
        self.dataset = dataset
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query
        self.episodes = episodes
        self.class_to_indices = self.build_index(limit_per_class)

        # Fail fast with a readable message instead of an opaque
        # "Sample larger than population" on the first sampled episode.
        if len(self.class_to_indices) < n_way:
            raise ValueError(
                f"Need at least {n_way} classes with {k_shot + n_query} examples each, "
                f"but only {len(self.class_to_indices)} qualify"
            )

    @staticmethod
    def _label_key(label):
        """Turn single-element tensor labels into plain Python scalars.

        Tensors hash by identity, so using them directly as dictionary keys
        would put every sample into its own class.
        """
        if isinstance(label, torch.Tensor) and label.numel() == 1:
            return label.item()
        return label

    def _labels(self):
        """Return an iterable of labels in dataset order.

        When the dataset exposes a ``targets`` sequence (as torchvision folder
        datasets do) it is used directly, which avoids loading and transforming
        every image just to read its label. The shortcut is skipped when a
        ``target_transform`` is set, because ``targets`` would then differ from
        the labels returned by ``__getitem__``.
        """
        targets = getattr(self.dataset, "targets", None)
        if targets is not None and getattr(self.dataset, "target_transform", None) is None:
            try:
                if len(targets) == len(self.dataset):
                    return targets
            except TypeError:
                pass  # unsized dataset or targets: fall back to a full pass
        return (label for _, label in self.dataset)

    def build_index(self, limit):
        """Map each usable class label to the dataset indices kept for it.

        Args:
            limit: Maximum number of indices kept per class (``None`` = all).

        Returns:
            Dict ``label -> list of indices``. A class is included only if it
            still has at least ``k_shot + n_query`` indices after truncation.
        """
        class_to_indices = {}
        for idx, label in enumerate(self._labels()):
            class_to_indices.setdefault(self._label_key(label), []).append(idx)

        needed = self.k_shot + self.n_query
        index = {}
        for label, indices in class_to_indices.items():
            kept = indices[:limit]
            # Check the size after truncation: a small limit must drop the
            # class rather than leave it too small to sample an episode from.
            if len(kept) >= needed:
                index[label] = kept
        return index

    def __getitem__(self, index):
        """Sample one episode; ``index`` is ignored.

        Returns:
            Tuple ``(support, query, labels)`` where ``support`` is
            ``[n_way, k_shot, ...]``, ``query`` is ``[n_way, n_query, ...]`` and
            ``labels`` holds ``n_way * n_query`` episode-relative class
            positions (0 .. n_way - 1), not the original dataset labels.
        """
        classes = random.sample(list(self.class_to_indices), self.n_way)
        support_images, query_images, query_labels = [], [], []
        for position, cls in enumerate(classes):
            indices = random.sample(self.class_to_indices[cls], self.k_shot + self.n_query)
            images = [self.dataset[idx][0] for idx in indices]
            support_images.append(torch.stack(images[:self.k_shot]))  # [k_shot, ...]
            query_images.append(torch.stack(images[self.k_shot:]))    # [n_query, ...]
            query_labels.extend([position] * self.n_query)

        return (
            torch.stack(support_images),  # [n_way, k_shot, ...]
            torch.stack(query_images),    # [n_way, n_query, ...]
            torch.tensor(query_labels),
        )

    def __len__(self):
        """Return the number of episodes per epoch."""
        return self.episodes

## Init sets

In [None]:
# ToTensor is the only preprocessing step; no augmentation or normalisation is applied.
transform = transforms.Compose([transforms.ToTensor()])

# One ImageFolder per split, each rooted at <cinic_path>/<split>.
train_set, valid_set, test_set = (
    ImageFolder(f"{cinic_path}/{split}", transform=transform)
    for split in ("train", "valid", "test")
)

# Class names as discovered in the training split.
classes = train_set.classes

## Model definition

In [None]:
class FewShotEncoder(nn.Module):
    """Three-stage convolutional embedding network.

    Every stage is ``Conv2d(3x3, padding=1, 64 filters) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2)``; the first stage takes 3 input channels, the others 64. The
    final feature map is flattened into one vector per image.
    """

    def __init__(self):
        super().__init__()

        layers = []
        in_channels = 3
        for _ in range(3):
            layers += [
                nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
            in_channels = 64
        layers.append(nn.Flatten())

        # A single flat Sequential keeps the layer indices at 0..12.
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Return the flattened embedding for the image batch ``x``."""
        return self.encoder(x)

class PrototypicalNetwork(nn.Module):
    """Prototypical network head: classify queries by distance to class prototypes.

    Args:
        encoder: Module mapping a batch of images to one embedding per image.
        n_way: Nominal number of classes per episode.
        k_shot: Nominal number of support images per class.
        n_query: Nominal number of query images per class.

    The three episode sizes are stored as attributes for callers. ``forward``
    reads the actual sizes from its inputs, so episodes with a different
    configuration (for example fewer shots at test time) are handled as well.
    """

    def __init__(self, encoder, n_way, k_shot, n_query):
        super().__init__()
        self.encoder = encoder
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query

    def forward(self, support, query):
        """Score every query against every class prototype.

        Args:
            support: Tensor shaped ``(B, n_way, k_shot, *image_dims)``.
            query: Tensor shaped ``(B, n_way, n_query, *image_dims)``.

        Returns:
            Tensor ``(B, n_way * n_query, n_way)`` of negative Euclidean
            distances, usable directly as classification logits.

        Raises:
            ValueError: If either input has fewer than four dimensions or the
                two inputs disagree on batch size or number of classes.
        """
        if support.dim() < 4 or query.dim() < 4:
            raise ValueError(
                "support and query must be shaped (B, n_way, shots, *image_dims); "
                f"got {tuple(support.shape)} and {tuple(query.shape)}"
            )
        if support.shape[:2] != query.shape[:2]:
            raise ValueError(
                "support and query disagree on (batch, n_way): "
                f"{tuple(support.shape[:2])} vs {tuple(query.shape[:2])}"
            )

        # Read the episode layout from the tensors themselves so a mismatch with
        # the constructor values cannot silently mix samples of different classes.
        batch_size, n_way, k_shot = support.shape[:3]
        n_query = query.size(2)

        # reshape (rather than view) also accepts non-contiguous episode tensors.
        flat_support = support.reshape(batch_size * n_way * k_shot, *support.shape[3:])
        flat_query = query.reshape(batch_size * n_way * n_query, *query.shape[3:])

        z_support = self.encoder(flat_support)
        z_query = self.encoder(flat_query)

        # Prototype = mean support embedding of each class.
        prototypes = z_support.reshape(batch_size, n_way, k_shot, -1).mean(dim=2)  # (B, n_way, D)
        z_query = z_query.reshape(batch_size, n_way * n_query, -1)                 # (B, Q, D)

        dists = torch.cdist(z_query, prototypes)  # (B, Q, n_way)
        return -dists



## Train

In [None]:
def train_protonet(model, loader, optimizer, device, epochs=10):
    """Train an episodic model with cross-entropy over query predictions.

    Args:
        model: Callable as ``model(support, query)`` returning logits whose last
            dimension indexes the episode classes.
        loader: Iterable yielding ``(support, query, labels)`` batches.
        optimizer: Optimizer over the model's parameters.
        device: Device the batches are moved to before the forward pass.
        epochs: Number of passes over ``loader``.

    Returns:
        List with one dict per epoch: ``{"epoch", "loss", "accuracy"}``, where
        ``loss`` is the sum of the per-batch losses (the value that is printed)
        and ``accuracy`` is the fraction of correctly classified queries.
    """
    model.train()
    history = []
    for epoch in range(epochs):
        total_loss, correct, total = 0.0, 0, 0
        for support, query, labels in loader:
            support, query, labels = support.to(device), query.to(device), labels.to(device)
            logits = model(support, query)

            # Flatten episodes and queries into one classification batch. The
            # class count comes from the logits so it always matches the output.
            logits = logits.reshape(-1, logits.size(-1))
            labels = labels.reshape(-1)

            loss = F.cross_entropy(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        # An empty loader would otherwise divide by zero.
        acc = correct / total if total else 0.0
        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}, Accuracy: {acc*100:.2f}%")
        history.append({"epoch": epoch + 1, "loss": total_loss, "accuracy": acc})

    return history

## Evaluation

In [None]:
def evaluate_protonet(model, loader, device):
    """Evaluate an episodic model, print its accuracy and plot a confusion matrix.

    Args:
        model: Callable as ``model(support, query)`` returning logits whose last
            dimension indexes the episode classes; must expose ``n_way``.
        loader: Iterable yielding ``(support, query, labels)`` batches.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        Query accuracy over all batches, as a fraction in [0, 1].

    Note:
        The labels are compared exactly as the loader provides them. If they are
        episode-relative positions (0 .. n_way - 1), the confusion matrix
        describes episode slots rather than fixed dataset classes.
    """
    model.eval()
    all_preds, all_labels = [], []
    n_classes = model.n_way
    with torch.no_grad():
        for support, query, labels in loader:
            support, query = support.to(device), query.to(device)
            logits = model(support, query)
            n_classes = logits.size(-1)
            preds = torch.argmax(logits.reshape(-1, n_classes), dim=1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.reshape(-1).cpu().tolist())

    acc = accuracy_score(all_labels, all_preds)
    print(f"Test accuracy: {acc*100:.2f}%")

    # Fix the label set so the matrix is always n_classes x n_classes, even when
    # some class never shows up among the labels or the predictions.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(n_classes)))
    plt.figure(figsize=(8,6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()

    return acc


## Parameters

In [None]:
n_way = 10                # number of classes per episode
k_shot = 5                # number of examples per class (support)
n_query = 15              # number of examples per class (query)
limit_per_class = 100     # maximum number of examples per class across the whole dataset
learning_rate = 0.001     # re-assigns the value already set in the configuration cell

## Load

In [None]:
# Episode settings shared by all three splits. limit_per_class comes from the
# parameters cell instead of a repeated literal, so changing it there takes effect.
episode_kwargs = dict(n_way=n_way, k_shot=k_shot, n_query=n_query, limit_per_class=limit_per_class)

# Each dataset item is a whole episode, so a batch holds batch_size episodes.
train_loader = DataLoader(FewShotDataset(train_set, **episode_kwargs), batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(FewShotDataset(valid_set, **episode_kwargs), batch_size=batch_size)
test_loader  = DataLoader(FewShotDataset(test_set,  **episode_kwargs), batch_size=batch_size)

## Train

In [None]:
# Build the embedding CNN and wrap it in the prototypical classifier.
encoder = FewShotEncoder().to(device)
model = PrototypicalNetwork(
    encoder,
    n_way=n_way,
    k_shot=k_shot,
    n_query=n_query,
).to(device)

# Adam over all model parameters with the configured learning rate.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Five epochs of episodic training on the training split.
train_protonet(model=model, loader=train_loader, optimizer=optimizer, device=device, epochs=5)

## Evaluate

In [None]:
# Print accuracy and plot the confusion matrix for episodes drawn from the test split.
evaluate_protonet(model, test_loader, device)