In [None]:
# Detect whether we are running on Google Colab: `google.colab` is only importable
# inside a Colab runtime. Only ImportError means "not Colab"; any other exception
# (including KeyboardInterrupt) should propagate rather than be swallowed.
try:
    import google.colab  # noqa: F401  (imported only to probe availability)
    colab = True
except ImportError:
    colab = False

In [None]:
if colab is True:

    !git clone https://github.com/sicara/easy-few-shot-learning
    %cd easy-few-shot-learning
    !pip install .
else:

    %cd ..

In [None]:
import gdown
import zipfile
import os


# Google Drive id of the dataset archive, the local path to save the archive to, and
# the directory to extract it into. These are placeholders and must be filled in.
# NOTE(review): the dataset cells read from /content/data/BasicFinalDatabase_FSL, so
# unzip_dir presumably needs to be /content/data — confirm against the archive layout.
file_id = " "
destination = " "
unzip_dir = " "

# Fail fast with a clear message instead of downloading to / extracting into a path
# literally named " ".
_unset_settings = [
    name
    for name, value in (("file_id", file_id), ("destination", destination), ("unzip_dir", unzip_dir))
    if not value.strip()
]
if _unset_settings:
    raise ValueError(
        f"Fill in {', '.join(_unset_settings)} before running this cell (still blank placeholders)."
    )

download_url = f"https://drive.google.com/uc?id={file_id}"

# Skip the download when the archive is already on disk so re-running the notebook
# does not fetch the dataset again (delete the file to force a fresh download).
if not os.path.exists(destination):
    gdown.download(download_url, destination, quiet=False)

os.makedirs(unzip_dir, exist_ok=True)
with zipfile.ZipFile(destination, 'r') as zip_ref:
    zip_ref.extractall(unzip_dir)

In [None]:
from pathlib import Path
import random
from statistics import mean

import numpy as np
import torch
from torch import nn
from tqdm import tqdm
import torchvision
import torch.utils.data

In [None]:
# Seed every random number generator the notebook relies on (Python's `random`,
# NumPy's global RNG and PyTorch) with one shared value so runs are repeatable.
random_seed = 30
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# Turn off cuDNN autotuning and request deterministic kernels: slower, but reproducible.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Episodic Training

In [None]:
# Episode geometry: every task samples n_way classes, with n_shot support images and
# n_query query images per class.
n_way = 5
n_shot = 10
n_query = 10

# A torch.device (not a bare string) so it has the same type as the DEVICE the model
# cell defines; either form is accepted by Tensor.to / Module.to.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 0 workers: DataLoader loads batches in the main process.
n_workers = 0

In [None]:
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os

class BengaliCharactersDataset(Dataset):
    """Image-folder dataset of Bengali character images, one sub-directory per class.

    Expects ``root_dir/<class_name>/<image file>``. Class indices are assigned in sorted
    order of the sub-directory names, and files are indexed in sorted order, so the
    label mapping and sample order do not depend on the filesystem's ``os.listdir``
    ordering (which is arbitrary).

    Args:
        root_dir: Directory whose immediate sub-directories are the classes. Plain
            files directly under ``root_dir`` are ignored.
        transform: Optional callable applied to each RGB PIL image in ``__getitem__``.
        extensions: Suffix or tuple of suffixes (matched case-insensitively) that mark
            a file as an image. Defaults to ``(".bmp",)``.
    """

    def __init__(self, root_dir, transform=None, extensions=(".bmp",)):
        self.root_dir = root_dir
        self.transform = transform
        # Sorted so class_to_idx is identical on every machine / run.
        self.classes = sorted(
            d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))
        )
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}

        if isinstance(extensions, str):
            extensions = (extensions,)
        suffixes = tuple(ext.lower() for ext in extensions)

        self.samples = []  # (image_path, class_name) pairs
        for class_name in self.classes:
            class_dir = os.path.join(root_dir, class_name)
            for img_file in sorted(os.listdir(class_dir)):
                # Case-insensitive so e.g. "img.BMP" is not silently dropped.
                if img_file.lower().endswith(suffixes):
                    self.samples.append((os.path.join(class_dir, img_file), class_name))

        print("First few samples:")
        print(self.samples[:5])

    def __len__(self):
        """Number of indexed image files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an RGB image and return ``(image, class_index)``."""
        img_path, label_str = self.samples[idx]
        # The context manager closes the underlying file handle deterministically.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        label = self.class_to_idx[label_str]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

    def get_labels(self):
        """Integer class index of every sample, in sample order.

        Matches the labels returned by ``__getitem__`` (and easyfsl's
        ``FewShotDataset.get_labels`` contract of a list of ints), and is computed from
        the file index without opening any image.
        """
        return [self.class_to_idx[class_name] for _, class_name in self.samples]


In [None]:
image_size = 84

# Both pipelines resize to a fixed square and convert to tensors.
# NOTE(review): the backbone is an ImageNet-pretrained ResNet; torchvision documents that
# its pretrained weights expect ImageNet mean/std normalization, which is not applied here.
train_transforms = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

test_transforms = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

DATA_ROOT = '/content/data/BasicFinalDatabase_FSL'

# Each split gets its own pipeline; the training split must use train_transforms so
# any training-only augmentation added there actually takes effect.
train_set = BengaliCharactersDataset(root_dir=os.path.join(DATA_ROOT, 'Train'), transform=train_transforms)
test_set = BengaliCharactersDataset(root_dir=os.path.join(DATA_ROOT, 'Test'), transform=test_transforms)


In [None]:
from easyfsl.samplers import TaskSampler
from torch.utils.data import DataLoader


n_tasks_per_epoch = 500
n_validation_tasks = 100

# NOTE(review): validation reuses the test split, so the "best" checkpoint is chosen on
# test tasks and the final test accuracy is optimistically biased. A held-out validation
# split (e.g. a subset of the training classes) would avoid this leakage.
val_set = test_set

# TaskSampler only needs dataset.get_labels(), which BengaliCharactersDataset already
# computes from its file index. Overriding it with a lambda that iterates the dataset
# would decode and transform every image just to read its label.
train_sampler = TaskSampler(
    train_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_tasks_per_epoch
)
val_sampler = TaskSampler(
    val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks
)

# Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runtimes.
pin_memory = torch.cuda.is_available()

train_loader = DataLoader(
    train_set,
    batch_sampler=train_sampler,
    num_workers=n_workers,
    pin_memory=pin_memory,
    collate_fn=train_sampler.episodic_collate_fn,
)
val_loader = DataLoader(
    val_set,
    batch_sampler=val_sampler,
    num_workers=n_workers,
    pin_memory=pin_memory,
    collate_fn=val_sampler.episodic_collate_fn,
)

print(len(train_set), len(val_set))

In [None]:
import torchvision.models as models
from easyfsl.methods import MatchingNetworks, FewShotClassifier
import torch.nn as nn
import torch

# NOTE(review): torchvision >= 0.13 deprecates `pretrained=` in favour of
# `weights=models.ResNet18_Weights.DEFAULT`; kept as-is for older installs.
pretrained_resnet18 = models.resnet18(pretrained=True)

# MatchingNetworks consumes backbone *features*: it sizes its LSTM context encoders from
# feature_dimension and checks incoming features against it. The ImageNet classification
# head must therefore be removed, not replaced with a 50-way Linear layer (which would
# emit 50-d vectors and fail against feature_dimension=512).
feature_dimension = pretrained_resnet18.fc.in_features  # 512 for ResNet-18
pretrained_resnet18.fc = nn.Flatten()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pretrained_resnet18 = pretrained_resnet18.to(DEVICE)


few_shot_classifier = MatchingNetworks(pretrained_resnet18, feature_dimension=feature_dimension).to(DEVICE)


In [None]:
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter


# NOTE(review): easyfsl documents that MatchingNetworks outputs log-probabilities and
# recommends NLLLoss; CrossEntropyLoss re-normalizes them, which is numerically close
# here but worth confirming.
LOSS_FUNCTION = nn.CrossEntropyLoss()

n_epochs = 30
# Multiply the LR by scheduler_gamma at 60% and 80% of training. The milestones are
# relative to n_epochs so the decay actually happens within this run; absolute values
# like 120/160 are only reached by ~200-epoch schedules and would never fire here.
scheduler_milestones = [int(0.6 * n_epochs), int(0.8 * n_epochs)]
scheduler_gamma = 0.1
learning_rate = 1e-2
tb_logs_dir = Path(".")

train_optimizer = SGD(
    few_shot_classifier.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4
)
train_scheduler = MultiStepLR(
    train_optimizer,
    milestones=scheduler_milestones,
    gamma=scheduler_gamma,
)

# TensorBoard event files go to tb_logs_dir (the current working directory).
tb_writer = SummaryWriter(log_dir=str(tb_logs_dir))

In [None]:
def training_epoch(
    model: FewShotClassifier, data_loader: DataLoader, optimizer: Optimizer
):
    """Train ``model`` for one pass over the episodes in ``data_loader``.

    For each episode, the model is conditioned on the support set, scores the query
    images, and takes one ``optimizer`` step on ``LOSS_FUNCTION`` of those scores.
    Tensors are moved to the module-level ``DEVICE``. The running mean loss is shown
    on the tqdm progress bar.

    Returns:
        The mean of the per-episode losses.
    """
    model.train()
    episode_losses = []
    with tqdm(data_loader, total=len(data_loader), desc="Training") as progress:
        for support_images, support_labels, query_images, query_labels, _ in progress:
            optimizer.zero_grad()

            model.process_support_set(
                support_images.to(DEVICE), support_labels.to(DEVICE)
            )
            scores = model(query_images.to(DEVICE))
            batch_loss = LOSS_FUNCTION(scores, query_labels.to(DEVICE))

            batch_loss.backward()
            optimizer.step()

            episode_losses.append(batch_loss.item())
            progress.set_postfix(loss=mean(episode_losses))

    return mean(episode_losses)

In [None]:
from easyfsl.utils import evaluate


import copy

# NOTE(review): the filename says 5-shot although n_shot is configured as 10; the
# checkpoint-loading cell reads this exact path, so rename both together if changed.
checkpoint_path = '/content/real_5shot_5way_Matching_BanglaLekha_Isolated.pth'

# state_dict() returns references to the live parameter tensors, so the best snapshot
# must be deep-copied; otherwise best_state silently tracks the latest weights.
best_state = copy.deepcopy(few_shot_classifier.state_dict())
best_validation_accuracy = 0.0
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    average_loss = training_epoch(few_shot_classifier, train_loader, train_optimizer)
    validation_accuracy = evaluate(
        few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix="Validation"
    )

    if validation_accuracy > best_validation_accuracy:
        best_validation_accuracy = validation_accuracy
        best_state = copy.deepcopy(few_shot_classifier.state_dict())
        print("Ding ding ding! We found a new best model!")
        # Persist right away so a disconnect mid-run keeps the best weights so far.
        torch.save(best_state, checkpoint_path)

    tb_writer.add_scalar("Train/loss", average_loss, epoch)
    tb_writer.add_scalar("Val/acc", validation_accuracy, epoch)

    train_scheduler.step()

# Always leave a checkpoint behind, even if no epoch improved on the initial 0.0.
torch.save(best_state, checkpoint_path)
tb_writer.flush()

In [None]:
PATH = '/content/real_5shot_5way_Matching_BanglaLekha_Isolated.pth'
# map_location remaps stored tensors onto the current DEVICE, so a checkpoint saved on a
# GPU runtime can be restored on a CPU-only one (and vice versa).
few_shot_classifier.load_state_dict(torch.load(PATH, map_location=DEVICE))

## Evaluation

In [None]:
n_test_tasks = 100

# TaskSampler only needs dataset.get_labels(), which BengaliCharactersDataset already
# computes from its file index; overriding it with a lambda that iterates the dataset
# would decode every test image just to read its label.
test_sampler = TaskSampler(
    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks
)
test_loader = DataLoader(
    test_set,
    batch_sampler=test_sampler,
    num_workers=n_workers,
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU-only runtimes.
    pin_memory=torch.cuda.is_available(),
    collate_fn=test_sampler.episodic_collate_fn,
)

In [None]:
from easyfsl.utils import evaluate
# Mean query accuracy of the loaded classifier over the sampled test tasks.
accuracy = evaluate(few_shot_classifier, test_loader, device=DEVICE)
print("Average accuracy : {:.2f} %".format(100 * accuracy))

In [None]:
from sklearn.metrics import precision_recall_fscore_support
import torch


def evaluate_with_metrics(model, loader, device, average='macro'):
    """Classify every query set in ``loader`` and score the pooled predictions.

    For each episode, the support set is passed to ``model.process_support_set`` and the
    query images are classified by taking the arg-max of the model's scores. Labels and
    predictions from all episodes are pooled before scoring.

    Args:
        model: Few-shot classifier exposing ``process_support_set`` and ``__call__``.
        loader: Iterable of ``(support_images, support_labels, query_images,
            query_labels, _)`` episodes.
        device: Device the support/query tensors are moved to.
        average: Averaging mode forwarded to scikit-learn's
            ``precision_recall_fscore_support`` (default ``'macro'``).

    Returns:
        ``(precision, recall, f1_score)``.

    NOTE(review): easyfsl's episodic collate function relabels classes as 0..n_way-1
    inside each task, so the pooled "per-class" statistics are per task slot, not per
    Bengali character.
    """
    model.eval()
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for support_images, support_labels, query_images, query_labels, _ in loader:
            model.process_support_set(support_images.to(device), support_labels.to(device))
            scores = model(query_images.to(device))
            preds = scores.argmax(dim=1)
            all_labels.extend(query_labels.cpu().tolist())
            all_preds.extend(preds.cpu().tolist())

    # zero_division=0 yields the same value scikit-learn's default uses for an undefined
    # ratio (e.g. a class never predicted) without emitting a warning on every call.
    precision, recall, f1_score, _ = precision_recall_fscore_support(
        all_labels, all_preds, average=average, zero_division=0
    )
    return precision, recall, f1_score


In [None]:
# Pooled precision / recall / F1 over all query predictions on the test tasks.
precision, recall, f1_score = evaluate_with_metrics(few_shot_classifier, test_loader, DEVICE)
print(
    "\n".join(
        [
            f"Precision: {precision:.2f}",
            f"Recall: {recall:.2f}",
            f"F1 Score: {f1_score:.2f}",
        ]
    )
)