Installing Libraries

In [None]:
!pip install sentence-transformers -q   
print("sentence-transformers installed.")

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m29.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m16.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Importing necessary modules

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import transforms
from torchvision import models
from torch.utils.data import Dataset,DataLoader
from PIL import Image

from sentence_transformers import SentenceTransformer

import numpy as np

import os
import random
import time

from collections import defaultdict

Device configuration

In [None]:
# Seed every random number generator imported by this notebook so that shuffling,
# augmentation and any later random sampling repeat after "Restart & Run All".
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the torch generators on all devices

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Report which device the notebook will run on.
device_banner = (
    f"Using GPU: {torch.cuda.get_device_name(0)}"
    if torch.cuda.is_available()
    else "Using CPU"
)
print(device_banner)

Using GPU: Tesla T4


Dataset loading

In [None]:
# Colab-only step: mount Google Drive at /content/drive so that the dataset
# archive stored under MyDrive can be read by the next cell.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Archive on the mounted Drive, and the local directory it is extracted into.
zip_path = '/content/drive/MyDrive/datasets/RetinalOCT_Dataset.zip'
unzip_path = '/content/data'
# unzip flags: -q quiet, -o overwrite existing files without prompting, -d target directory.
# The {…} placeholders are filled in by IPython from the Python variables above.
!unzip -q -o "{zip_path}" -d "{unzip_path}"


Project Configuration: Class Definitions

In [None]:
# Canonical class order: a class's position in this list is its integer label.
CLASS_NAMES_OCT_C8 = [
    "NORMAL", "AMD", "CNV", "CSR",
    "DME", "DR", "DRUSEN", "MH",
]
# Lookup tables in both directions: index -> name and name -> index.
REVERSE_CLASS_MAPPING = dict(enumerate(CLASS_NAMES_OCT_C8))
CLASS_MAPPING = {label_name: label_idx for label_idx, label_name in REVERSE_CLASS_MAPPING.items()}
NUM_TOTAL_CLASSES = len(CLASS_NAMES_OCT_C8)

In [None]:
# Few-shot learning (FSL) partition of the eight classes, given by class name.
# Base classes: their indices select the images used for the fine-tuning and
# SemAlignNet training datasets built later in the notebook.
BASE_CLASS_NAMES = ["NORMAL", "DRUSEN", "DME", "AMD"]
# Validation classes: their indices select the images of the validation dataset.
VAL_CLASS_NAMES = ["CNV", "CSR"]
# Novel classes: the remaining two classes, kept apart from base and validation.
NOVEL_CLASS_NAMES = ["DR", "MH"]

In [None]:
# Resolve the FSL split names to integer labels and sanity-check the three splits.
try:
    # Validate the names first: indexing CLASS_MAPPING with an unknown name would raise
    # a bare KeyError before this more descriptive error could ever be produced.
    for name_list in [BASE_CLASS_NAMES, VAL_CLASS_NAMES, NOVEL_CLASS_NAMES]:
        for name in name_list:
            if name not in CLASS_NAMES_OCT_C8:
                raise ValueError(f"Class name '{name}' in FSL splits is not in CLASS_NAMES_OCT_C8.")

    BASE_CLASSES_IDX = sorted([CLASS_MAPPING[name] for name in BASE_CLASS_NAMES])
    VAL_CLASSES_IDX = sorted([CLASS_MAPPING[name] for name in VAL_CLASS_NAMES])
    NOVEL_CLASSES_IDX = sorted([CLASS_MAPPING[name] for name in NOVEL_CLASS_NAMES])

    # A class index appearing twice means two splits share a class (or one split
    # lists a class twice); either way the splits are not a valid partition.
    all_fsl_classes_check = BASE_CLASSES_IDX + VAL_CLASSES_IDX + NOVEL_CLASSES_IDX
    if len(all_fsl_classes_check) != len(set(all_fsl_classes_check)):
        raise ValueError("Overlap detected in FSL class splits! Ensure Base, Val, and Novel class lists are disjoint.")

except KeyError as e:
    # Only reachable if CLASS_MAPPING and CLASS_NAMES_OCT_C8 have drifted apart.
    print(f"ERROR: Class name {e} in your FSL splits is not defined in CLASS_NAMES_OCT_C8. Please check.")
    raise
except ValueError as e:
    print(f"ERROR: {e}")
    raise

NUM_BASE_CLASSES = len(BASE_CLASSES_IDX)
NUM_VAL_CLASSES = len(VAL_CLASSES_IDX)
NUM_NOVEL_CLASSES = len(NOVEL_CLASSES_IDX)

print(f"\nTotal Classes in OCT-C8: {NUM_TOTAL_CLASSES}")
print(f"Base Classes (Original Indices from CLASS_MAPPING): {BASE_CLASSES_IDX} -> Names: {[REVERSE_CLASS_MAPPING[i] for i in BASE_CLASSES_IDX]}")
print(f"Validation Classes (Original Indices): {VAL_CLASSES_IDX} -> Names: {[REVERSE_CLASS_MAPPING[i] for i in VAL_CLASSES_IDX]}")
print(f"Novel Classes (Original Indices): {NOVEL_CLASSES_IDX} -> Names: {[REVERSE_CLASS_MAPPING[i] for i in NOVEL_CLASSES_IDX]}")


Total Classes in OCT-C8: 8
Base Classes (Original Indices from CLASS_MAPPING): [0, 1, 4, 6] -> Names: ['NORMAL', 'AMD', 'DME', 'DRUSEN']
Validation Classes (Original Indices): [2, 3] -> Names: ['CNV', 'CSR']
Novel Classes (Original Indices): [5, 7] -> Names: ['DR', 'MH']


In [None]:
OCT_SEMANTICS = {
    CLASS_MAPPING["NORMAL"]: "A normal retina on OCT displays well-defined, continuous, and stratified retinal layers with a smooth foveal contour and a distinct foveal pit. The outer retina, including the ellipsoid zone and external limiting membrane, appears intact and uniformly hyperreflective. There is no evidence of subretinal or intraretinal fluid, pigment epithelial detachment (PED), or retinal thickening. The retinal pigment epithelium (RPE) forms a smooth, uninterrupted hyperreflective line without undulations, disruptions, or irregular reflectivity. The choroid appears uniform, and no hyper- or hyporeflective deposits, such as drusen, are present. Retinal thickness falls within normal population-specific reference ranges, and there are no signs of retinal atrophy, cystoid spaces, or structural distortion. This appearance confirms the absence of pathology and represents a healthy macular profile.",
    CLASS_MAPPING["AMD"]: "Age-related macular degeneration (AMD) is a progressive retinal disorder primarily impacting the macula in individuals over 55, driven by age-related, genetic, and environmental factors. Clinically, it is characterized by drusen (especially medium to large), pigmentary disturbances, and in advanced stages, geographic atrophy or choroidal neovascularization (CNV). On optical coherence tomography (OCT), hallmark features include hyperreflective drusen between the retinal pigment epithelium (RPE) and Bruch’s membrane, RPE irregularities, and disruption of outer retinal layers. In neovascular AMD, OCT reveals subretinal and/or intraretinal fluid, pigment epithelial detachment (PED), and hyperreflective material suggestive of fibrovascular tissue or hemorrhage. Geographic atrophy manifests as RPE and outer retinal thinning, choroidal hypertransmission, and loss of the ellipsoid zone. These OCT biomarkers provide critical insight into AMD severity and progression.",
    CLASS_MAPPING["CNV"]: "Choroidal Neovascularization (CNV) is characterized by the pathological growth of choroidal vessels breaching Bruch’s membrane, often extending into the sub-RPE or subretinal space, commonly seen in exudative (wet) age-related macular degeneration. On OCT, CNV manifests as a hyperreflective neovascular complex beneath or above the RPE, frequently accompanied by subretinal fluid (SRF), intraretinal fluid (IRF) or cystoid spaces, and pigment epithelial detachment (PED), which may be serous, fibrovascular, or mixed in nature. RPE irregularities, including elevation, thickening, or disruption, are often present along with outer retinal layer distortion or loss, especially in chronic or advanced cases. Additional findings may include subretinal hyperreflective material (SHRM), reflective of fibrovascular tissue, blood, or exudates, and increased choroidal thickness or shadowing beneath the lesion. These features help delineate CNV activity and guide treatment monitoring.",
    CLASS_MAPPING["CSR"]: "Central Serous Retinopathy (CSR) is a chorioretinal disorder marked by serous detachment of the neurosensory retina due to leakage from the choroidal circulation through a dysfunctional retinal pigment epithelium (RPE), typically affecting the macular region. On OCT, CSR is characterized by a well-demarcated dome-shaped accumulation of subretinal fluid (SRF) beneath the central retina, with preservation or mild distortion of the overlying retinal layers. Pigment epithelial detachment (PED), often shallow and serous, may be present and indicates focal RPE disruption. RPE irregularities such as granular hyperreflectivity, undulations, or atrophic patches may be observed, particularly in chronic cases. Additional OCT findings can include elongation of photoreceptor outer segments, thinning of outer retinal layers, and a thickened choroid (pachychoroid). These features are critical for diagnosing CSR and distinguishing it from other macular pathologies.",
    CLASS_MAPPING["DME"]: "Diabetic Macular Edema (DME) is a vision-threatening complication of diabetic retinopathy characterized by fluid accumulation within and beneath the macula due to breakdown of the blood-retinal barrier and increased vascular permeability. On OCT, DME presents as retinal thickening in the macular area with hyporeflective intraretinal cystoid spaces, primarily in the inner nuclear and outer plexiform layers. Subretinal fluid (SRF) may also be present, particularly in more advanced or inflammatory cases. The retinal layers may appear distorted or disorganized, and outer retinal disruption—including loss of the ellipsoid zone—can occur with chronic edema. Focal or diffuse thickening patterns are common, and hyperreflective foci representing lipid exudates or inflammatory cells may be seen. Although less frequent, small serous pigment epithelial detachments (PEDs) and RPE irregularities can be observed in chronic or severe DME. These OCT features are essential for diagnosing, staging, and monitoring response to treatment.",
    CLASS_MAPPING["DR"]: "Diabetic Retinopathy (DR) is a microvascular complication of diabetes that leads to progressive retinal damage, with Optical Coherence Tomography (OCT) revealing key structural changes beyond what is visible on fundus imaging. On OCT, DR is most commonly associated with diabetic macular edema (DME), characterized by retinal thickening, hyporeflective intraretinal cystoid spaces, and occasional subretinal fluid (SRF). Retinal layer morphology may show disorganization of inner retinal layers (DRIL), loss of the ellipsoid zone, or thinning due to chronic ischemia. Hyperreflective foci, representing lipid exudates or inflammatory cells, are frequently observed. Vitreomacular interface abnormalities, such as epiretinal membranes or vitreomacular traction, may contribute to edema or structural distortion. While pigment epithelial detachment (PED) and RPE irregularities are less common, they can occur in advanced or treated cases. OCT Angiography complements structural OCT by revealing capillary nonperfusion, microaneurysms, and enlargement of the foveal avascular zone, aiding in the assessment of disease severity and progression.",
    CLASS_MAPPING["DRUSEN"]: "Drusen are extracellular lipid- and protein-rich deposits that accumulate between the retinal pigment epithelium (RPE) and Bruch’s membrane, commonly associated with age-related macular degeneration (AMD). On OCT imaging, drusen appear as dome-shaped, irregular elevations or undulations of the RPE, with hyporeflective or variably reflective material interposed between the RPE and Bruch’s membrane. They may cause overlying retinal layer distortion, particularly of the outer retinal layers such as the ellipsoid zone. Larger, confluent soft drusen may mimic pigment epithelial detachment (PED), while hard drusen are smaller and sharply demarcated. RPE irregularities, thinning, or atrophy may be present, and chronic drusen can lead to geographic atrophy or neovascular complications. Subretinal fluid or hyperreflective foci may suggest progression toward exudative AMD.",
    CLASS_MAPPING["MH"]: "A macular hole (MH) is a full-thickness or partial-thickness defect in the neurosensory retina at the fovea, resulting in central vision loss and often caused by vitreomacular traction or age-related degeneration. OCT is the gold standard for diagnosis, revealing a well-demarcated interruption of the retinal layers at the foveal center, with or without an operculum. Surrounding the hole, cystoid spaces are commonly seen in the parafoveal retina due to intraretinal edema. The edges of the hole may appear elevated or curled, and the retinal layers may show disruption or thinning, especially in advanced stages. In early stages (impending holes), OCT may show foveal pseudocysts and vitreomacular adhesion. Vitreomacular traction is often visualized as a hyperreflective band exerting anterior-posterior force on the fovea. Subretinal fluid and RPE changes are generally minimal unless chronic or complicated by other pathology. These features help stage the hole and guide treatment planning",
}
# Verify that the description table is keyed by exactly the indices 0..NUM_TOTAL_CLASSES-1.
expected_semantic_keys = set(range(NUM_TOTAL_CLASSES))
semantics_keys_ok = set(OCT_SEMANTICS) == expected_semantic_keys
if semantics_keys_ok:
    print("\nOCT_SEMANTICS dictionary structure seems okay (number of entries matches total classes).")
else:
    print("\nWARNING: OCT_SEMANTICS dictionary is incomplete or has incorrect keys. Please define descriptions for all 8 classes using their mapped indices (0-7) as keys.")


OCT_SEMANTICS dictionary structure seems okay (number of entries matches total classes).


Model Hyperparameters

In [None]:
# ---- Model hyperparameters ----
IMG_WIDTH = 224
IMG_HEIGHT = 224
VISION_ENCODER_NAME = 'resnet50'
TEXT_ENCODER_NAME = 'all-MiniLM-L6-v2'
VISUAL_FEATURE_DIM = 2048   # feature width of ResNet-50 just before its final FC layer
TEXT_FEATURE_DIM = 384      # embedding width produced by 'all-MiniLM-L6-v2'
SEMALIGN_HIDDEN_DIM = 1024  # hidden width used by SemAlignNet

# ---- Vision encoder fine-tuning ----
FT_VE_LR, FT_VE_EPOCHS, FT_VE_BATCH_SIZE = 1e-4, 10, 32

# ---- SemAlignNet training ----
SEMALIGN_LR, SEMALIGN_EPOCHS, SEMALIGN_BATCH_SIZE = 1e-4, 15, 64

# ---- Few-shot evaluation ----
# The "N-way" of each evaluation equals the number of classes in the matching split,
# or 0 when that split is empty.
val_split_available = NUM_VAL_CLASSES > 0 and bool(VAL_CLASSES_IDX)
N_WAY_VAL = NUM_VAL_CLASSES if val_split_available else 0
print(
    f"N_WAY_VAL set to: {N_WAY_VAL} (based on your VAL_CLASSES_IDX)"
    if val_split_available
    else "N_WAY_VAL set to 0 as no/insufficient validation classes are defined."
)

novel_split_available = NUM_NOVEL_CLASSES > 0 and bool(NOVEL_CLASSES_IDX)
N_WAY_NOVEL = NUM_NOVEL_CLASSES if novel_split_available else 0
print(
    f"N_WAY_NOVEL set to: {N_WAY_NOVEL} (based on your NOVEL_CLASSES_IDX)"
    if novel_split_available
    else "N_WAY_NOVEL set to 0 as no/insufficient novel classes are defined."
)

# Support (K-shot) and query (M) sample counts for each evaluation split.
K_SHOT_VAL, M_QUERY_VAL = 5, 15
K_SHOT_NOVEL, M_QUERY_NOVEL = 5, 15

NUM_EVAL_TASKS = 200
FUSION_FACTOR_K = 0.7

N_WAY_VAL set to: 2 (based on your VAL_CLASSES_IDX)
N_WAY_NOVEL set to: 2 (based on your NOVEL_CLASSES_IDX)


In [None]:

import os

# Root folder of the extracted archive.
DATASET_PATH = "/content/data/RetinalOCT_Dataset/"

# isdir() is already False for a path that does not exist, so it covers both checks.
if not os.path.isdir(DATASET_PATH):
    print(f"ERROR: Expected folder was NOT found")
else:
    print(f"Dataset is ready for use at (DATASET_PATH): {DATASET_PATH}")
    print(f"Contents of this path (should be train, val, test folders): {os.listdir(DATASET_PATH)}")

Dataset is ready for use at (DATASET_PATH): /content/data/RetinalOCT_Dataset/
Contents of this path (should be train, val, test folders): ['.DS_Store', 'test', 'train', 'val']


In [None]:
# Image transformations: a deterministic pipeline and an augmenting pipeline that
# share the same final "3-channel grayscale -> tensor -> normalize" steps.
_TARGET_SIZE = (IMG_HEIGHT, IMG_WIDTH)
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _tensor_tail():
    """Return fresh instances of the steps that end both pipelines."""
    return [
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]


# Deterministic pipeline: plain resize to the target size.
transform_oct = transforms.Compose([
    transforms.Resize(_TARGET_SIZE),
    *_tensor_tail(),
])

# Augmenting pipeline: resize 30 px larger, random crop back to the target size, random flip.
augment_transform_oct = transforms.Compose([
    transforms.Resize((IMG_HEIGHT + 30, IMG_WIDTH + 30)),
    transforms.RandomResizedCrop(_TARGET_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    *_tensor_tail(),
])

print("\nImage transforms defined.")
print(
    f"Script will use DATASET_PATH = '{DATASET_PATH}' for subsequent steps."
    if DATASET_PATH
    else "CRITICAL ERROR: DATASET_PATH is not correctly set. Please fix the paths before proceeding"
)


Image transforms defined.
Script will use DATASET_PATH = '/content/data/RetinalOCT_Dataset/' for subsequent steps.


Data Preparation

In [None]:
# Function to Load All Image Paths & Labels

def load_all_image_paths_and_labels(dataset_root, all_class_names_list, global_class_mapping):
    """Collect every image path under ``dataset_root`` together with its class index.

    The expected layout is ``<dataset_root>/<split>/<CLASS_NAME>/<image>`` with
    ``split`` being one of ``train``, ``val`` and ``test``. All three splits are
    pooled into a single pair of lists.

    Args:
        dataset_root: Directory that contains the split folders.
        all_class_names_list: Class names expected to be present; used only to warn
            about classes for which no folder was found.
        global_class_mapping: Mapping from class folder name to integer class index.
            Folders whose name is not a key of this mapping are ignored.

    Returns:
        ``(image_paths, image_original_labels)`` - two parallel lists. Both are empty
        when ``dataset_root`` is not a directory or no image was found. Entries are
        ordered by split (train, val, test), then class folder name, then file name,
        so the result does not depend on the file system's listing order.
    """
    image_paths = []
    image_original_labels = []

    dataset_splits_to_scan = ('train', 'val', 'test')
    image_extensions = ('.png', '.jpg', '.jpeg')

    # isdir() is False for a missing path, so it also covers the existence check.
    if not os.path.isdir(dataset_root):
        print(f"ERROR: Root dataset directory '{dataset_root}' does not exist or is not a directory.")
        return image_paths, image_original_labels

    print(f"Scanning for images in: {dataset_root}")
    found_classes_in_dataset_structure = set()

    for split_folder_name in dataset_splits_to_scan:
        split_path = os.path.join(dataset_root, split_folder_name)
        if not os.path.isdir(split_path):
            print(f"Warning: Expected split folder not found: {split_path}")
            continue

        # os.listdir() returns entries in arbitrary order; sort for reproducible output.
        for class_name_str in sorted(os.listdir(split_path)):
            if class_name_str not in global_class_mapping:
                continue

            class_folder_full_path = os.path.join(split_path, class_name_str)
            if not os.path.isdir(class_folder_full_path):
                # A stray file that merely shares a class name must not count as the
                # class folder, otherwise the missing-class warning below is suppressed.
                continue

            found_classes_in_dataset_structure.add(class_name_str)
            original_class_idx = global_class_mapping[class_name_str]

            for img_file_name in sorted(os.listdir(class_folder_full_path)):
                if img_file_name.lower().endswith(image_extensions):
                    image_paths.append(os.path.join(class_folder_full_path, img_file_name))
                    image_original_labels.append(original_class_idx)

    # Check if all expected classes were found in the directory structure
    for expected_class_name in all_class_names_list:
        if expected_class_name not in found_classes_in_dataset_structure:
            print(f"Warning: Expected class folder for '{expected_class_name}' was not found in any of the train/val/test splits under '{dataset_root}'.")

    if not image_paths:
        print(f"WARNING: No image paths were found. Check your DATASET_PATH ('{dataset_root}') and ensure it contains train/val/test subfolders, each with class name subfolders (e.g., NORMAL, AMD).")

    return image_paths, image_original_labels


# Build the master path/label lists used by every dataset below.

print(f"\nLoading all image paths from the unzipped dataset at '{DATASET_PATH}'...")
if not DATASET_PATH or not os.path.exists(DATASET_PATH):
    print("ERROR: DATASET_PATH is not set or does not exist. Cannot load image paths.")
    all_image_paths_master, all_original_labels_master = [], []
else:
    all_image_paths_master, all_original_labels_master = load_all_image_paths_and_labels(
        DATASET_PATH, CLASS_NAMES_OCT_C8, CLASS_MAPPING
    )

    print(f"\nTotal images found across all splits in OCT-C8 structure: {len(all_image_paths_master)}")
    if all_image_paths_master:
        # Show the first few samples as a spot check.
        print("\nExample loaded paths and labels (original 0-7 mapping):")
        for sample_path, sample_label in zip(all_image_paths_master[:5], all_original_labels_master[:5]):
            print(f"Path: {sample_path}, Label: {sample_label} ({REVERSE_CLASS_MAPPING[sample_label]})")

        # Per-class totals; np.unique returns the labels and their counts as parallel arrays.
        unique_labels, counts = np.unique(all_original_labels_master, return_counts=True)
        print("\nTotal image counts per original class (0-7) found in dataset:")
        for label_idx_found, count_for_label in zip(unique_labels, counts):
            class_name_found = REVERSE_CLASS_MAPPING.get(label_idx_found, f"Unknown_Label_{label_idx_found}")
            print(f"Class '{class_name_found}' (Index {label_idx_found}): {count_for_label} images")

        # OCT-C8 has 24k images total (8 classes * 3k each) if all are found.
        total_images_found = sum(counts)
        if len(CLASS_NAMES_OCT_C8) == 8:
            if total_images_found != 24000:
                print(f"WARNING: Expected around 24000 images for OCT-C8, but found {total_images_found}. This might be okay if some images were missing or not loadable, but double-check if the number is significantly off.")
            else:
                print("Total image count and per-class counts appear consistent with the OCT-C8 dataset (approx. 3000 per class).")



Loading all image paths from the unzipped dataset at '/content/data/RetinalOCT_Dataset/'...
Scanning for images in: /content/data/RetinalOCT_Dataset/

Total images found across all splits in OCT-C8 structure: 24000

Example loaded paths and labels (original 0-7 mapping):
Path: /content/data/RetinalOCT_Dataset/train/NORMAL/normal_train_2547.jpg, Label: 0 (NORMAL)
Path: /content/data/RetinalOCT_Dataset/train/NORMAL/normal_train_1754.jpg, Label: 0 (NORMAL)
Path: /content/data/RetinalOCT_Dataset/train/NORMAL/normal_train_2405.jpg, Label: 0 (NORMAL)
Path: /content/data/RetinalOCT_Dataset/train/NORMAL/normal_train_2275.jpg, Label: 0 (NORMAL)
Path: /content/data/RetinalOCT_Dataset/train/NORMAL/normal_train_2716.jpg, Label: 0 (NORMAL)

Total image counts per original class (0-7) found in dataset:
Class 'NORMAL' (Index 0): 3000 images
Class 'AMD' (Index 1): 3000 images
Class 'CNV' (Index 2): 3000 images
Class 'CSR' (Index 3): 3000 images
Class 'DME' (Index 4): 3000 images
Class 'DR' (Index 5):

Implementing the Custom PyTorch Dataset Classes

In [None]:
class OCTFSLDataset(Dataset):
    """Subset of the master image list restricted to one group of classes.

    Only samples whose original label is in ``target_fsl_class_original_indices`` are
    kept. Their labels are additionally remapped to a contiguous ``0..k-1`` range,
    ordered by ascending original index.

    Args:
        all_image_paths_master: Image file paths for the whole dataset.
        all_original_labels_master: Original class index of each path (parallel list).
        target_fsl_class_original_indices: Original class indices to keep.
        global_class_mapping: Accepted for interface compatibility; not used here.
        semantics_dict: Mapping from original class index to a text description.
        transform: Callable applied to the loaded PIL image, or a falsy value to skip.
        is_for_semalign_training: When True, ``__getitem__`` also returns the class
            description as a third element.
    """

    def __init__(self,
                 all_image_paths_master,
                 all_original_labels_master,
                 target_fsl_class_original_indices,
                 global_class_mapping,
                 semantics_dict,
                 transform,
                 is_for_semalign_training=False):

        self.image_paths = []
        self.labels_original_for_this_split = []
        self.labels_mapped_for_this_split = []

        # Original label -> contiguous 0-based label for this specific split.
        self.split_label_map = {
            orig_idx: new_idx
            for new_idx, orig_idx in enumerate(sorted(set(target_fsl_class_original_indices)))
        }

        # Keep only samples of the target classes. Membership is tested against the
        # dict (constant time) instead of scanning the caller's list for every image.
        for path, orig_label in zip(all_image_paths_master, all_original_labels_master):
            if orig_label in self.split_label_map:
                self.image_paths.append(path)
                self.labels_original_for_this_split.append(orig_label)
                self.labels_mapped_for_this_split.append(self.split_label_map[orig_label])

        if not self.image_paths:
            print(f"WARNING: OCTFSLDataset initialized with target classes {target_fsl_class_original_indices}, but no images were found for these classes. Check your FSL class splits and data loading.")

        self.semantics_dict = semantics_dict
        self.transform = transform
        self.is_for_semalign_training = is_for_semalign_training

    def __len__(self):
        """Number of samples kept for this split."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, mapped_label)`` or ``(image, mapped_label, description)``.

        If the image file is missing, a zero tensor with label ``-1`` is returned
        instead (plus an error string in SemAlign mode).
        """
        img_path = self.image_paths[idx]
        original_label_for_sample = self.labels_original_for_this_split[idx]
        mapped_label_for_sample = self.labels_mapped_for_this_split[idx]
        try:
            # The context manager closes the file handle as soon as the pixels have
            # been converted into the single-channel image used below.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('L')
        except FileNotFoundError:
            print(f"ERROR: Image file not found at {img_path}")
            # NOTE(review): -1 is not a valid class index for nn.CrossEntropyLoss with
            # its default ignore_index, so a missing file will still fail in training.
            dummy_tensor = torch.zeros((3, IMG_HEIGHT, IMG_WIDTH))
            if self.is_for_semalign_training:
                return dummy_tensor, -1, "Error: Image not found"
            else:
                return dummy_tensor, -1

        if self.transform:
            image = self.transform(image)

        if self.is_for_semalign_training:
            # Descriptions are keyed by the ORIGINAL class index, not the remapped one.
            description = self.semantics_dict[original_label_for_sample]
            return image, mapped_label_for_sample, description
        else:
            return image, mapped_label_for_sample

print("OCTFSLDataset class defined.")

OCTFSLDataset class defined.


In [None]:
# Instantiate Datasets for different purposes

# Vision-encoder fine-tuning dataset: base classes only, augmented, (image, label) samples.
if NUM_BASE_CLASSES <= 0:
    print("\nWARNING: No base classes defined. Skipping creation of ft_ve_dataset.")
    ft_ve_dataset = None
else:
    ft_ve_dataset = OCTFSLDataset(
        all_image_paths_master,
        all_original_labels_master,
        BASE_CLASSES_IDX,
        CLASS_MAPPING,
        OCT_SEMANTICS,
        augment_transform_oct,
        is_for_semalign_training=False,
    )
    print(f"\nDataset for Vision Encoder Fine-Tuning (ft_ve_dataset) created with {len(ft_ve_dataset)} samples from {NUM_BASE_CLASSES} base classes.")
    if len(ft_ve_dataset):
        # Pull one sample to confirm the tensor shape and label range.
        img, lbl = ft_ve_dataset[0]
        print(f"  Example sample - Image shape: {img.shape}, Label: {lbl} (this is a mapped base label, 0 to {NUM_BASE_CLASSES-1})")





Dataset for Vision Encoder Fine-Tuning (ft_ve_dataset) created with 12000 samples from 4 base classes.
  Example sample - Image shape: torch.Size([3, 224, 224]), Label: 0 (this is a mapped base label, 0 to 3)


In [None]:
# SemAlignNet training dataset: base classes, augmented, (image, label, description) samples.
if NUM_BASE_CLASSES <= 0:
    print("\nWARNING: No base classes defined. Skipping creation of base_oct_dataset_for_semalign.")
    base_oct_dataset_for_semalign = None
else:
    base_oct_dataset_for_semalign = OCTFSLDataset(
        all_image_paths_master,
        all_original_labels_master,
        BASE_CLASSES_IDX,
        CLASS_MAPPING,
        OCT_SEMANTICS,
        augment_transform_oct,
        is_for_semalign_training=True,
    )
    print(f"\nDataset for SemAlignNet Training (base_oct_dataset_for_semalign) created with {len(base_oct_dataset_for_semalign)} samples from {NUM_BASE_CLASSES} base classes.")
    if len(base_oct_dataset_for_semalign):
        # Pull one sample to confirm that a description accompanies the image and label.
        img, lbl, desc = base_oct_dataset_for_semalign[0]
        print(f"  Example sample - Image shape: {img.shape}, Label: {lbl}, Desc starts with: '{desc[:50]}...'")


Dataset for SemAlignNet Training (base_oct_dataset_for_semalign) created with 12000 samples from 4 base classes.
  Example sample - Image shape: torch.Size([3, 224, 224]), Label: 0, Desc starts with: 'A normal retina on OCT displays well-defined, cont...'


In [None]:
# Validation dataset: validation classes only, deterministic transform, (image, label) samples.
if NUM_VAL_CLASSES <= 0:
    print("\nNo validation classes defined. Skipping creation of val_oct_dataset_for_analysis.")
    val_oct_dataset_for_analysis = None
else:
    val_oct_dataset_for_analysis = OCTFSLDataset(
        all_image_paths_master,
        all_original_labels_master,
        VAL_CLASSES_IDX,
        CLASS_MAPPING,
        OCT_SEMANTICS,
        transform_oct,
        is_for_semalign_training=False,
    )
    print(f"\nDataset for Validation Class Analysis (val_oct_dataset_for_analysis) created with {len(val_oct_dataset_for_analysis)} samples from {NUM_VAL_CLASSES} validation classes.")


Dataset for Validation Class Analysis (val_oct_dataset_for_analysis) created with 6000 samples from 2 validation classes.


Creating Data Loaders

In [None]:
# DataLoader for vision-encoder fine-tuning (shuffled mini-batches of the base-class dataset).
if ft_ve_dataset is None or len(ft_ve_dataset) == 0:
    print("WARNING: ft_ve_dataset is None or empty. Skipping creation of ft_ve_loader.")
    ft_ve_loader = None
else:
    ft_ve_loader = DataLoader(
        ft_ve_dataset,
        batch_size=FT_VE_BATCH_SIZE,
        shuffle=True,      # reshuffle the training data
        num_workers=2,
        pin_memory=True,   # page-locked host memory for faster host-to-GPU copies
    )
    print(f"DataLoader for VE Fine-Tuning (ft_ve_loader) created with {len(ft_ve_loader)} batches.")

    # Smoke test: fetch a single batch and report its shapes.
    try:
        first_batch = next(iter(ft_ve_loader))
        img_batch, lbl_batch = first_batch
        print(f"  Example batch - Image shape: {img_batch.shape}, Label shape: {lbl_batch.shape}")
    except Exception as e:
        print(f"  Error fetching batch from ft_ve_loader: {e}")


DataLoader for VE Fine-Tuning (ft_ve_loader) created with 375 batches.
  Example batch - Image shape: torch.Size([32, 3, 224, 224]), Label shape: torch.Size([32])


In [None]:
# DataLoader for SemAlignNet training (shuffled mini-batches of image, label, description).
if base_oct_dataset_for_semalign is None or len(base_oct_dataset_for_semalign) == 0:
    print("WARNING: base_oct_dataset_for_semalign is None or empty. Skipping creation of semalign_train_loader.")
    semalign_train_loader = None
else:
    semalign_train_loader = DataLoader(
        base_oct_dataset_for_semalign,
        batch_size=SEMALIGN_BATCH_SIZE,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )
    print(f"DataLoader for SemAlignNet Training (semalign_train_loader) created with {len(semalign_train_loader)} batches.")

    # Smoke test: fetch a single batch and report its shapes.
    try:
        first_batch = next(iter(semalign_train_loader))
        img_batch, lbl_batch, desc_batch = first_batch
        print(f"  Example batch - Image shape: {img_batch.shape}, Label shape: {lbl_batch.shape}, Num descriptions: {len(desc_batch)}")
    except Exception as e:
        print(f"  Error fetching batch from semalign_train_loader: {e}")


DataLoader for SemAlignNet Training (semalign_train_loader) created with 188 batches.
  Example batch - Image shape: torch.Size([64, 3, 224, 224]), Label shape: torch.Size([64]), Num descriptions: 64


Loading Pretrained encoders

Text encoder: `all-MiniLM-L6-v2` (SentenceTransformer). Vision encoder: ResNet-50 with ImageNet-pretrained weights.


In [None]:
# Load both pre-trained encoders. Each loader is wrapped in its own try/except:
# on any failure the error is printed and the corresponding variable is set to None,
# so later cells must check for None before using an encoder.
print("\n Loading Pre-trained Encoders")
start_time_step = time.time()

# Load Text Encoder

print(f"Loading Text Encoder: {TEXT_ENCODER_NAME}...")
try:
    text_encoder = SentenceTransformer(TEXT_ENCODER_NAME, device=str(DEVICE))
    # Encode one dummy sentence to confirm the model runs and to read its output width.
    dummy_description_list = ["This is a test sentence for the text encoder."]
    with torch.no_grad():
        dummy_embedding = text_encoder.encode(dummy_description_list, convert_to_tensor=True, device=DEVICE)
    print(f"Text Encoder loaded successfully. Example embedding shape: {dummy_embedding.shape}")

    # A mismatch is only reported; text_encoder stays set in that case.
    if dummy_embedding.shape[1] != TEXT_FEATURE_DIM:
        print(f"WARNING: Text encoder output dimension {dummy_embedding.shape[1]} does not match TEXT_FEATURE_DIM {TEXT_FEATURE_DIM}. Please check TEXT_ENCODER_NAME and TEXT_FEATURE_DIM.")
except Exception as e:
    print(f"ERROR loading Text Encoder: {e}")
    text_encoder = None

# Load Vision Encoder

print(f"\nLoading Vision Encoder base: {VISION_ENCODER_NAME}...")
try:
    if VISION_ENCODER_NAME == 'resnet50':
        # ResNet-50 initialised with the IMAGENET1K_V1 weights.
        vision_encoder_base = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        # Input width of the final fully connected layer, i.e. the feature size before `fc`.
        num_ftrs_original = vision_encoder_base.fc.in_features
        if num_ftrs_original != VISUAL_FEATURE_DIM:
             print(f"WARNING: Original ResNet50 fc input dim {num_ftrs_original} does not match VISUAL_FEATURE_DIM {VISUAL_FEATURE_DIM}. This might be an issue if VISUAL_FEATURE_DIM is used incorrectly elsewhere.")
        print(f"ResNet50 base model loaded. Original output features (before fc): {num_ftrs_original}")
    else:
        # Raised inside the try, so it is caught below and reported like any other load error.
        raise ValueError(f"Vision encoder {VISION_ENCODER_NAME} not currently supported in this script. Please add it or choose resnet50.")

    vision_encoder_base = vision_encoder_base.to(DEVICE)
    print(f"Vision Encoder base ({VISION_ENCODER_NAME}) moved to {DEVICE}.")

except Exception as e:
    print(f"ERROR loading Vision Encoder base: {e}")
    vision_encoder_base = None

print(f"completed in {time.time() - start_time_step:.2f} seconds.")


 Loading Pre-trained Encoders
Loading Text Encoder: all-MiniLM-L6-v2...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Text Encoder loaded successfully. Example embedding shape: torch.Size([1, 384])

Loading Vision Encoder base: resnet50...


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 149MB/s]


ResNet50 base model loaded. Original output features (before fc): 2048
Vision Encoder base (resnet50) moved to cuda.
completed in 12.79 seconds.


Fine-tuning the Vision Encoder on Base Set

In [None]:
print("Fine-tuning the Vision Encoder...")
start_time_step = time.time()

# Modify Classifier Head for Fine-tuning
# The pre-trained classification layer is replaced by a fresh linear layer with one
# output per base class.
if vision_encoder_base is not None and NUM_BASE_CLASSES > 0:
    # Fall back to VISUAL_FEATURE_DIM when `fc` has no `in_features`: the end of this
    # cell replaces `fc` with nn.Identity, so a plain attribute read would fail when
    # the cell is executed a second time.
    num_ftrs = getattr(vision_encoder_base.fc, "in_features", VISUAL_FEATURE_DIM)

    vision_encoder_base.fc = nn.Linear(num_ftrs, NUM_BASE_CLASSES)

    # Same module object as vision_encoder_base; .to() also moves the new head to DEVICE.
    vision_encoder_for_finetuning = vision_encoder_base.to(DEVICE)

    print(f"Vision Encoder's final layer replaced to output {NUM_BASE_CLASSES} features for fine-tuning.")
else:
    print("ERROR: vision_encoder_base is None or NUM_BASE_CLASSES is 0. Cannot proceed with fine-tuning.")
    vision_encoder_for_finetuning = None


# Define Loss and Optimizer for Fine-tuning
# All parameters are optimised (backbone and new head), not only the head.
if vision_encoder_for_finetuning is not None:
    criterion_ve = nn.CrossEntropyLoss()
    optimizer_ve = optim.Adam(vision_encoder_for_finetuning.parameters(), lr=FT_VE_LR)
    print("Loss function and optimizer for VE fine-tuning defined.")


# Training Loop for Fine-tuning
# Explicit None checks: truth-testing a DataLoader calls len() on it, which hides intent.
can_finetune = (
    vision_encoder_for_finetuning is not None
    and ft_ve_loader is not None
    and len(ft_ve_loader.dataset) > 0
)

if can_finetune:
    print(f"Fine-tuning Vision Encoder for {FT_VE_EPOCHS} epochs on {len(ft_ve_loader.dataset)} base images...")

    vision_encoder_for_finetuning.train()

    for epoch in range(FT_VE_EPOCHS):
        epoch_loss_ve = 0.0
        correct_ve = 0
        total_ve = 0

        for images, mapped_labels in ft_ve_loader:
            images, mapped_labels = images.to(DEVICE), mapped_labels.to(DEVICE)

            optimizer_ve.zero_grad()

            outputs = vision_encoder_for_finetuning(images)
            loss = criterion_ve(outputs, mapped_labels)

            loss.backward()
            optimizer_ve.step()

            epoch_loss_ve += loss.item()
            # Predicted class = index of the largest logit (argmax avoids the legacy `.data`).
            predicted = outputs.argmax(dim=1)
            total_ve += mapped_labels.size(0)
            correct_ve += (predicted == mapped_labels).sum().item()

        # Loss is averaged over batches, accuracy over samples.
        avg_epoch_loss_ve = epoch_loss_ve / len(ft_ve_loader) if len(ft_ve_loader) > 0 else 0
        accuracy_ve = 100 * correct_ve / total_ve if total_ve > 0 else 0
        print(f"VE Fine-Tuning Epoch [{epoch+1}/{FT_VE_EPOCHS}], Loss: {avg_epoch_loss_ve:.4f}, Accuracy: {accuracy_ve:.2f}%")
else:
    if vision_encoder_for_finetuning is None:
         print("Skipping Vision Encoder fine-tuning as vision_encoder_for_finetuning is not defined.")
    elif ft_ve_loader is None or len(ft_ve_loader.dataset) == 0:
         print("Skipping Vision Encoder fine-tuning as ft_ve_loader is None or its dataset is empty.")

# Convert to Feature Extractor and Freeze
if vision_encoder_for_finetuning is not None:
    vision_encoder_for_finetuning.fc = nn.Identity()
    vision_encoder_for_finetuning.eval()

    # Freeze all parameters of the vision encoder
    for param in vision_encoder_for_finetuning.parameters():
        param.requires_grad = False
    print("Vision Encoder fine-tuning complete. Final layer converted to Identity, and all weights frozen.")
    vision_encoder_f = vision_encoder_for_finetuning
else:
    print("Vision encoder was not fine-tuned. If you intended to use a pre-trained frozen encoder without fine-tuning, ensure 'vision_encoder_base' is appropriately configured as a feature extractor.")
    if vision_encoder_base is not None:
        vision_encoder_base.fc = nn.Identity()
        vision_encoder_base.eval()
        for param in vision_encoder_base.parameters():
            param.requires_grad = False
        vision_encoder_f = vision_encoder_base
        print("Using the original pre-trained vision_encoder_base as a frozen feature extractor (no fine-tuning was performed).")
    else:
        vision_encoder_f = None
        print("ERROR: No vision encoder available.")


print(f"completed in {time.time() - start_time_step:.2f} seconds.")


Fine-tuning the Vision Encoder...
Vision Encoder's final layer replaced to output 4 features for fine-tuning.
Loss function and optimizer for VE fine-tuning defined.
Fine-tuning Vision Encoder for 10 epochs on 12000 base images...
VE Fine-Tuning Epoch [1/10], Loss: 0.1810, Accuracy: 94.17%
VE Fine-Tuning Epoch [2/10], Loss: 0.1001, Accuracy: 96.81%
VE Fine-Tuning Epoch [3/10], Loss: 0.0842, Accuracy: 97.37%
VE Fine-Tuning Epoch [4/10], Loss: 0.0747, Accuracy: 97.59%
VE Fine-Tuning Epoch [5/10], Loss: 0.0662, Accuracy: 97.80%
VE Fine-Tuning Epoch [6/10], Loss: 0.0655, Accuracy: 98.03%
VE Fine-Tuning Epoch [7/10], Loss: 0.0551, Accuracy: 98.15%
VE Fine-Tuning Epoch [8/10], Loss: 0.0515, Accuracy: 98.22%
VE Fine-Tuning Epoch [9/10], Loss: 0.0439, Accuracy: 98.65%
VE Fine-Tuning Epoch [10/10], Loss: 0.0497, Accuracy: 98.38%
Vision Encoder fine-tuning complete. Final layer converted to Identity, and all weights frozen.
completed in 1265.01 seconds.


 Calculating the Base Set Prototypes

In [None]:
print("Calculating OCT Base Set Prototypes...")
start_time_step = time.time()

# Per-class running feature sums (later divided into means) and per-class sample counts,
# both indexed by the mapped base-class label.
base_prototypes_mapped = torch.zeros(NUM_BASE_CLASSES, VISUAL_FEATURE_DIM).to(DEVICE)
counts_mapped = torch.zeros(NUM_BASE_CLASSES).to(DEVICE)


# At notebook top level `locals()` is the module namespace, so this checks whether an earlier
# cell already built the dataset.
if 'base_proto_calc_dataset' not in locals() or base_proto_calc_dataset is None:
    print("Re-creating base_proto_calc_dataset as it was not found in locals...")

    base_proto_calc_dataset = OCTFSLDataset(
        all_image_paths_master=all_image_paths_master,
        all_original_labels_master=all_original_labels_master,
        target_fsl_class_original_indices=BASE_CLASSES_IDX,
        global_class_mapping=CLASS_MAPPING,
        semantics_dict=OCT_SEMANTICS,
        transform=transform_oct,
        is_for_semalign_training=False
    )

if vision_encoder_f is None:
    # The fine-tuning cell sets vision_encoder_f to None when no encoder is available; without
    # this guard the code below would fail with an AttributeError on `.eval()`.
    print("ERROR: vision_encoder_f is None. Cannot calculate base prototypes.")
    base_prototypes_mapped = None
elif base_proto_calc_dataset and len(base_proto_calc_dataset) > 0:
    base_proto_calc_loader = DataLoader(
        base_proto_calc_dataset,
        batch_size=SEMALIGN_BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )
    print(f"DataLoader for prototype calculation created with {len(base_proto_calc_loader)} batches.")

    with torch.no_grad():
        vision_encoder_f.eval()

        for images, mapped_labels in base_proto_calc_loader:
            images = images.to(DEVICE)

            features = vision_encoder_f(images)

            # Accumulate each sample's feature vector into the sum of its class.
            for i_feature_in_batch, mapped_label_val in enumerate(mapped_labels):

                label_idx = mapped_label_val.item() if torch.is_tensor(mapped_label_val) else mapped_label_val

                if 0 <= label_idx < NUM_BASE_CLASSES:
                    base_prototypes_mapped[label_idx] += features[i_feature_in_batch]
                    counts_mapped[label_idx] += 1
                else:
                    print(f"Warning: Invalid mapped_label_val encountered: {label_idx}")

    # Average the features for each class
    for i in range(NUM_BASE_CLASSES):
        if counts_mapped[i] > 0:
            base_prototypes_mapped[i] /= counts_mapped[i]
        else:
            # Look up which original class maps to index `i` so the warning names the right class.
            # (Previously the loop stopped at the first entry regardless of `i`.)
            # NOTE(review): assumes split_label_map maps original class index -> mapped index — confirm.
            original_class_name_for_warning = "Unknown (mapping issue)"
            for name, idx_map in base_oct_dataset_for_semalign.split_label_map.items():
                if idx_map == i:
                    original_class_name_for_warning = REVERSE_CLASS_MAPPING.get(name, "Unknown Original Index")
                    break
            print(f"WARNING: No samples found for mapped base label index {i} (Corresponds to Base Class: '{original_class_name_for_warning}') during prototype calculation. Its prototype will be zeros.")

    print("OCT Base Set Prototypes calculated.")
    print(f"Shape of base_prototypes_mapped: {base_prototypes_mapped.shape}")
else:
    print("ERROR: base_proto_calc_dataset is None or empty. Cannot calculate base prototypes.")
    base_prototypes_mapped = None

print(f"\ncompleted in {time.time() - start_time_step:.2f} seconds.")

Calculating OCT Base Set Prototypes...
Re-creating base_proto_calc_dataset as it was not found in locals...
DataLoader for prototype calculation created with 188 batches.
OCT Base Set Prototypes calculated.
Shape of base_prototypes_mapped: torch.Size([4, 2048])

completed in 60.95 seconds.


Defining the Semantic Alignment Network

In [None]:
# Announce the step and start a wall-clock timer; the elapsed time is printed at the end of the cell.
print("\n: Defining the Semantic Alignment Network (SemAlignNet)...")
start_time_step = time.time()

# Define SemAlign Network: two linear layers applied to concatenated visual and text features.

class SemAlignNet(nn.Module):
    """Two-layer MLP that maps a (visual feature, text embedding) pair to a vector of size `output_dim`.

    Args:
        visual_dim: size of each visual feature vector.
        text_dim: size of each text embedding.
        hidden_dim: width of the hidden layer.
        output_dim: size of the produced vector.
    """

    def __init__(self, visual_dim, text_dim, hidden_dim, output_dim):
        super().__init__()
        fused_dim = visual_dim + text_dim
        # Input projection over the fused [visual | text] vector.
        self.fc1 = nn.Linear(fused_dim, hidden_dim)
        self.relu = nn.LeakyReLU(0.1)
        # Output projection from the hidden representation to `output_dim`.
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, visual_features, semantic_features):
        """Concatenate both inputs along dim 1 and run them through fc1 -> LeakyReLU -> fc2."""
        fused = torch.cat((visual_features, semantic_features), dim=1)
        return self.fc2(self.relu(self.fc1(fused)))


# Report the dimensions SemAlignNet is expected to be built with. These are read from constants
# defined outside this cell; no network is instantiated here.
print("SemAlignNet class defined.")
print(f"It will take input of dim: {VISUAL_FEATURE_DIM} (visual) + {TEXT_FEATURE_DIM} (text) = {VISUAL_FEATURE_DIM + TEXT_FEATURE_DIM}")
print(f"Hidden layer dimension: {SEMALIGN_HIDDEN_DIM}")
print(f"Output dimension (to match visual features): {VISUAL_FEATURE_DIM}")

# Elapsed wall-clock time since start_time_step was set at the top of the cell.
print(f"\n completed in {time.time() - start_time_step:.2f} seconds.")


: Defining the Semantic Alignment Network (SemAlignNet)...
SemAlignNet class defined.
It will take input of dim: 2048 (visual) + 384 (text) = 2432
Hidden layer dimension: 1024
Output dimension (to match visual features): 2048

 completed in 0.00 seconds.


Training the SemAlignNet

In [None]:
print("\n Training the SemAlignNet...")
start_time_step = time.time()

# Instantiate Network, Define Loss and Optimizer
# Each precondition failure leaves semalign_net_trained_instance as None so later cells can detect it.
if 'vision_encoder_f' not in globals() or vision_encoder_f is None:
    print("ERROR: vision_encoder_f (fine-tuned feature extractor) is not defined. Please run previous steps.")
    semalign_net_trained_instance = None
elif base_prototypes_mapped is None:
    print("ERROR: base_prototypes_mapped is None. Cannot train SemAlignNet.")
    semalign_net_trained_instance = None
elif semalign_train_loader is None or len(semalign_train_loader.dataset) == 0:
    print("ERROR: semalign_train_loader is None or its dataset is empty. Cannot train SemAlignNet.")
    semalign_net_trained_instance = None
else:
    semalign_net_trained_instance = SemAlignNet(
        visual_dim=VISUAL_FEATURE_DIM,
        text_dim=TEXT_FEATURE_DIM,
        hidden_dim=SEMALIGN_HIDDEN_DIM,
        output_dim=VISUAL_FEATURE_DIM
    ).to(DEVICE)

    # L1 distance between the network output and the class prototype of each sample.
    criterion_semalign = nn.L1Loss()
    optimizer_semalign = optim.Adam(semalign_net_trained_instance.parameters(), lr=SEMALIGN_LR)
    print("SemAlignNet instantiated and optimizer/loss defined.")

    # Training Loop
    print(f"Training SemAlignNet for {SEMALIGN_EPOCHS} epochs on {len(semalign_train_loader.dataset)} base images...")

    # The vision encoder is only used as a fixed feature extractor here.
    vision_encoder_f.eval()

    # description string -> embedding tensor on DEVICE. The text encoder is not trained in this
    # loop, so each distinct description only needs to be encoded once instead of once per batch.
    # NOTE(review): embeddings should match per-batch encoding up to floating-point noise —
    # confirm if bit-exact reproduction of previously logged losses matters.
    description_embedding_cache = {}

    for epoch in range(SEMALIGN_EPOCHS):
        semalign_net_trained_instance.train()
        epoch_loss_sa = 0.0

        for images, mapped_labels, descriptions in semalign_train_loader:
            images = images.to(DEVICE)
            mapped_labels = mapped_labels.to(DEVICE)

            optimizer_semalign.zero_grad()

            with torch.no_grad():
                visual_features = vision_encoder_f(images)

                batch_descriptions = list(descriptions)
                # dict.fromkeys de-duplicates while keeping first-seen order.
                unseen_descriptions = [
                    desc for desc in dict.fromkeys(batch_descriptions)
                    if desc not in description_embedding_cache
                ]
                if unseen_descriptions:
                    new_embeddings = torch.tensor(
                        text_encoder.encode(unseen_descriptions), device=DEVICE
                    ).float()
                    description_embedding_cache.update(zip(unseen_descriptions, new_embeddings))

                semantic_features = torch.stack(
                    [description_embedding_cache[desc] for desc in batch_descriptions]
                )

            # Forward pass through SemAlignNet
            reconstructed_prototypes = semalign_net_trained_instance(visual_features, semantic_features)

            # Get target base prototypes for this batch
            target_prototypes = base_prototypes_mapped[mapped_labels]

            # Calculate loss
            loss = criterion_semalign(reconstructed_prototypes, target_prototypes)

            # Backward pass and optimization (only for SemAlignNet parameters)
            loss.backward()
            optimizer_semalign.step()

            epoch_loss_sa += loss.item()

        # Mean of the per-batch losses for this epoch.
        avg_epoch_loss_sa = epoch_loss_sa / len(semalign_train_loader) if len(semalign_train_loader) > 0 else 0
        print(f"SemAlignNet Training Epoch [{epoch+1}/{SEMALIGN_EPOCHS}], Average Loss: {avg_epoch_loss_sa:.4f}")

    semalign_net_trained_instance.eval()
    print("SemAlignNet training complete.")

print(f"\n completed in {time.time() - start_time_step:.2f} seconds.")


 Training the SemAlignNet...
SemAlignNet instantiated and optimizer/loss defined.
Training SemAlignNet for 15 epochs on 12000 base images...
SemAlignNet Training Epoch [1/15], Average Loss: 0.0941
SemAlignNet Training Epoch [2/15], Average Loss: 0.0316
SemAlignNet Training Epoch [3/15], Average Loss: 0.0237
SemAlignNet Training Epoch [4/15], Average Loss: 0.0190
SemAlignNet Training Epoch [5/15], Average Loss: 0.0163
SemAlignNet Training Epoch [6/15], Average Loss: 0.0133
SemAlignNet Training Epoch [7/15], Average Loss: 0.0126
SemAlignNet Training Epoch [8/15], Average Loss: 0.0118
SemAlignNet Training Epoch [9/15], Average Loss: 0.0102
SemAlignNet Training Epoch [10/15], Average Loss: 0.0095
SemAlignNet Training Epoch [11/15], Average Loss: 0.0090
SemAlignNet Training Epoch [12/15], Average Loss: 0.0087
SemAlignNet Training Epoch [13/15], Average Loss: 0.0086
SemAlignNet Training Epoch [14/15], Average Loss: 0.0084
SemAlignNet Training Epoch [15/15], Average Loss: 0.0082
SemAlignNet 

Implementing Few-Shot Evaluation

In [None]:
# Announce the few-shot evaluation step.
print("\n Implementing and Running Few-Shot Evaluation...")


# Define FSL Task Sampler: draws N-way K-shot episodes (support and query images) from a set of classes.
class OCTFSLTaskSampler:
    """Samples N-way K-shot episodes from a fixed set of original class indices.

    Image paths are grouped per class at construction time. `get_task` then draws classes and
    images with the `random` module, loads the images from disk and returns them as tensors
    on `DEVICE`.
    """

    def __init__(self, all_img_paths_master, all_orig_labels_master,
                 fsl_class_original_indices,
                 global_class_mapping,
                 transform_eval):
        """Group image paths by class and record which classes have enough images.

        Args:
            all_img_paths_master: iterable of image paths, aligned with `all_orig_labels_master`.
            all_orig_labels_master: iterable of original class indices, one per path.
            fsl_class_original_indices: original class indices this sampler may draw from.
            global_class_mapping: stored on the instance; not used by the sampler itself.
            transform_eval: callable applied to each loaded PIL image.
        """
        self.transform = transform_eval
        self.global_class_mapping = global_class_mapping

        self.fsl_class_original_indices = sorted(list(set(fsl_class_original_indices)))

        self.class_to_image_paths = defaultdict(list)
        for path, orig_label in zip(all_img_paths_master, all_orig_labels_master):
            if orig_label in self.fsl_class_original_indices:
                self.class_to_image_paths[orig_label].append(path)

        # A class is "usable" when it has at least K_SHOT_NOVEL + M_QUERY_NOVEL images.
        # These notebook-level constants are used for every sampler, including validation.
        self.usable_fsl_class_original_indices = []
        min_samples_needed_check = K_SHOT_NOVEL + M_QUERY_NOVEL

        for cls_idx in self.fsl_class_original_indices:
            if len(self.class_to_image_paths[cls_idx]) >= min_samples_needed_check:
                self.usable_fsl_class_original_indices.append(cls_idx)
            else:
                print(f"Sampler Info: Class {REVERSE_CLASS_MAPPING.get(cls_idx, cls_idx)} has {len(self.class_to_image_paths[cls_idx])} images, less than the typical K+M={min_samples_needed_check} needed for robust sampling. It might not be usable for all N-way K-M configurations.")

        if not self.usable_fsl_class_original_indices:
            print("Sampler WARNING: No usable classes found with enough samples for typical K+M shots. FSL evaluation may fail.")
        else:
            print(f"Sampler created for original class indices: {self.fsl_class_original_indices}. Usable original class indices (>= K_novel+M_novel samples): {self.usable_fsl_class_original_indices}")


    def get_task(self, N, K, M):
        """Draw one N-way episode with K support and M query images per class.

        Returns:
            None when fewer than N usable classes exist, when a sampled class has fewer than
            K + M images, or when any image fails to load/transform. Otherwise a tuple
            `(support_images, query_images, query_labels, task_original_labels)` where the
            images are stacked class by class (all K/M images of the first sampled class, then
            the second, ...), `query_labels` holds task-local labels 0..N-1, and
            `task_original_labels` lists the sampled original class indices in that same order.
        """
        if len(self.usable_fsl_class_original_indices) < N:
            return None

        # Sample N distinct original class indices for the current task
        task_original_labels_for_sampling = random.sample(self.usable_fsl_class_original_indices, N)

        # Check every sampled class up front so no image is read from disk for an episode
        # that cannot be completed.
        if any(len(self.class_to_image_paths[cls]) < K + M
               for cls in task_original_labels_for_sampling):
            return None

        support_images_task, query_images_task, query_labels_task_based = [], [], []

        # Task-local label = position of the class in the sampled order.
        for task_based_label, original_label_in_task in enumerate(task_original_labels_for_sampling):
            selected_paths = random.sample(self.class_to_image_paths[original_label_in_task], K + M)

            for i, img_path in enumerate(selected_paths):
                try:
                    # Close the file handle as soon as the pixels are converted; `convert`
                    # returns a new in-memory image.
                    with Image.open(img_path) as image_file:
                        image = image_file.convert('L')
                    if self.transform:
                        image = self.transform(image)
                except Exception as e:
                    print(f"Error loading/transforming image {img_path}: {e}")
                    # One unreadable image invalidates the whole episode.
                    return None

                if i < K:  # First K images are support
                    support_images_task.append(image)
                else:  # Remaining M images are query
                    query_images_task.append(image)
                    query_labels_task_based.append(task_based_label)

        return (torch.stack(support_images_task).to(DEVICE),
                torch.stack(query_images_task).to(DEVICE),
                torch.tensor(query_labels_task_based).long().to(DEVICE),
                task_original_labels_for_sampling)

print("OCTFSLTaskSampler class defined.")


# Instantiate Samplers
# Validation sampler: only built when validation classes are configured.
val_sampler = None
if not (NUM_VAL_CLASSES > 0 and VAL_CLASSES_IDX):
    print("No validation classes defined or VAL_CLASSES_IDX is empty. Skipping validation sampler instantiation.")
else:
    val_sampler = OCTFSLTaskSampler(
        all_img_paths_master=all_image_paths_master,
        all_orig_labels_master=all_original_labels_master,
        fsl_class_original_indices=VAL_CLASSES_IDX,
        global_class_mapping=CLASS_MAPPING,
        transform_eval=transform_oct,
    )
    # Warn when the sampler cannot form an N_WAY_VAL-way task.
    if not (val_sampler.usable_fsl_class_original_indices
            and len(val_sampler.usable_fsl_class_original_indices) >= N_WAY_VAL):
        print(f"Warning: Validation sampler has < {N_WAY_VAL} usable classes for {N_WAY_VAL}-way tasks. Check VAL_CLASSES_IDX and image counts.")

# Novel sampler: same procedure for the novel classes.
novel_sampler = None
if not (NUM_NOVEL_CLASSES > 0 and NOVEL_CLASSES_IDX):
    print("No novel classes defined or NOVEL_CLASSES_IDX is empty. Skipping novel sampler instantiation.")
else:
    novel_sampler = OCTFSLTaskSampler(
        all_img_paths_master=all_image_paths_master,
        all_orig_labels_master=all_original_labels_master,
        fsl_class_original_indices=NOVEL_CLASSES_IDX,
        global_class_mapping=CLASS_MAPPING,
        transform_eval=transform_oct,
    )
    # Warn when the sampler cannot form an N_WAY_NOVEL-way task.
    if not (novel_sampler.usable_fsl_class_original_indices
            and len(novel_sampler.usable_fsl_class_original_indices) >= N_WAY_NOVEL):
        print(f"Warning: Novel sampler has < {N_WAY_NOVEL} usable classes for {N_WAY_NOVEL}-way tasks. Check NOVEL_CLASSES_IDX and image counts.")

#Define FSL Evaluation Function


def run_fsl_evaluation(sampler, N_way, K_shot, M_query, num_tasks,
                       description_dict, fusion_k_val,
                       vision_enc, text_enc, semalign_module,
                       set_name="Novel"):
    """Evaluate prototype-based few-shot classification over randomly sampled tasks.

    For every task, each class prototype is `k * r_t + (1 - k) * u_t`, where `u_t` is the mean
    of the class's support features and `r_t` is the mean of `semalign_module` applied to the
    support features paired with the class's text embedding. Query images are assigned to the
    prototype with the highest cosine similarity.

    Args:
        sampler: object exposing `usable_fsl_class_original_indices` and `get_task(N, K, M)`.
        N_way, K_shot, M_query: classes per task, support and query images per class.
        num_tasks: number of sampling attempts; attempts where the sampler returns None are skipped.
        description_dict: maps original class index to its text description.
        fusion_k_val: fusion weight `k` applied to the reconstructed prototype.
        vision_enc: module mapping an image batch to feature vectors.
        text_enc: object with `encode(list_of_str)` returning one embedding per string.
        semalign_module: module taking `(visual_features, semantic_features)`.
        set_name: label used in printed messages.

    Returns:
        `(mean_accuracy, conf_interval)` with `conf_interval = 1.96 * std / sqrt(n_tasks)`,
        or `(0.0, 0.0)` when evaluation is skipped or no task could be processed.
    """
    if sampler is None:
        print(f"Skipping evaluation for {set_name} set as sampler is not available.")
        return 0.0, 0.0
    if len(sampler.usable_fsl_class_original_indices) < N_way:
        print(f"Skipping {N_way}-way evaluation for {set_name} set: Not enough usable classes in sampler ({len(sampler.usable_fsl_class_original_indices)} found).")
        return 0.0, 0.0
    if vision_enc is None or semalign_module is None:
        # Upstream cells set these to None when a step could not run; report it like the other
        # skip conditions instead of failing with an AttributeError on `.eval()`.
        print(f"Skipping evaluation for {set_name} set as the vision encoder or SemAlignNet is not available.")
        return 0.0, 0.0

    vision_enc.eval()
    semalign_module.eval()

    accuracies = []
    print(f"\nRunning Few-Shot Evaluation on {set_name} Set ({N_way}-way {K_shot}-shot)...")
    start_eval_time_local = time.time()
    tasks_formed_count = 0
    valid_tasks_processed = 0

    # original class index -> text embedding g(s_t). A class description does not change
    # between tasks, so it is encoded once per call rather than once per task.
    semantic_embedding_cache = {}

    for task_idx in range(num_tasks):
        task_data = sampler.get_task(N_way, K_shot, M_query)

        if task_data is None:
            continue
        tasks_formed_count += 1

        support_images, query_images, query_labels_task_based, task_original_class_indices = task_data

        if support_images.shape[0] != N_way * K_shot or query_images.shape[0] != N_way * M_query:
            continue

        with torch.no_grad():
            all_support_features = vision_enc(support_images)

            class_prototypes = []
            for i_class_in_task in range(N_way):
                original_class_label = task_original_class_indices[i_class_in_task]

                # Support features are laid out class by class, K_shot rows each.
                class_support_features = all_support_features[i_class_in_task * K_shot : (i_class_in_task + 1) * K_shot]

                u_t = torch.mean(class_support_features, dim=0)

                g_s_t = semantic_embedding_cache.get(original_class_label)
                if g_s_t is None:
                    # Get semantic description for this original class label
                    semantic_desc_str = description_dict.get(original_class_label)
                    if semantic_desc_str is None:
                        print(f"ERROR: Semantic description not found for original class index {original_class_label}. Skipping task.")
                        break

                    # Encode semantic description to get g(s_t), on the same device as the features.
                    g_s_t = torch.tensor(
                        text_enc.encode([semantic_desc_str]),
                        device=all_support_features.device,
                    ).float()[0]
                    semantic_embedding_cache[original_class_label] = g_s_t

                g_s_t_expanded = g_s_t.unsqueeze(0).expand(K_shot, -1)

                # Calculate r_t: reconstructed prototype using SemAlignNet
                # Pass each of the K_shot support features along with the (same) class semantic feature
                reconstructed_support_protos_t = semalign_module(class_support_features, g_s_t_expanded)
                r_t = torch.mean(reconstructed_support_protos_t, dim=0)

                # Fuse prototypes: p_t = k*r_t + (1-k)*u_t (Equation 6)
                class_prototypes.append(fusion_k_val * r_t + (1 - fusion_k_val) * u_t)
            else:
                # Proceed to classify query images only if all class prototypes were generated
                task_prototypes = torch.stack(class_prototypes)
                all_query_features = vision_enc(query_images)

                # Cosine similarity: normalize features and prototypes, then dot product
                all_query_features_norm = torch.nn.functional.normalize(all_query_features, p=2, dim=1)
                task_prototypes_norm = torch.nn.functional.normalize(task_prototypes, p=2, dim=1)

                similarities = torch.mm(all_query_features_norm, task_prototypes_norm.t())
                predictions_task_based = torch.argmax(similarities, dim=1)

                correct_predictions = (predictions_task_based == query_labels_task_based).sum().item()
                task_accuracy = correct_predictions / len(query_labels_task_based) if len(query_labels_task_based) > 0 else 0
                accuracies.append(task_accuracy)
                valid_tasks_processed += 1

        # Progress line roughly every 10% of the attempts (every attempt when num_tasks < 10).
        if (task_idx + 1) % (num_tasks // 10 if num_tasks >= 10 else 1) == 0:
             print(f"  {set_name} Set: Evaluated task {task_idx + 1}/{num_tasks} (Actual tasks processed so far: {valid_tasks_processed})")

    if not accuracies:
        print(f"No tasks could be successfully processed for {set_name} set with N={N_way}, K={K_shot}, M={M_query}.")
        print(f"  Total tasks attempted: {tasks_formed_count}. Valid tasks processed: {valid_tasks_processed}.")
        print(f"  Check usable class counts in sampler ({len(sampler.usable_fsl_class_original_indices)} found) and if descriptions exist for all sampled classes.")
        return 0.0, 0.0

    mean_accuracy = np.mean(accuracies)
    std_dev = np.std(accuracies)

    # Normal-approximation 95% interval over per-task accuracies.
    conf_interval = 1.96 * std_dev / np.sqrt(len(accuracies)) if len(accuracies) > 1 else 0.0

    print(f"\n{set_name} Set Evaluation (processed {valid_tasks_processed}/{tasks_formed_count} formed tasks from {num_tasks} attempts) finished in {time.time() - start_eval_time_local:.2f} seconds.")
    print(f"--- {set_name} Set Results ({N_way}-way {K_shot}-shot) ---")
    print(f"Mean Accuracy: {mean_accuracy:.4f}")
    print(f"95% Confidence Interval: +/- {conf_interval:.4f}")
    print(f"Accuracy Range: [{mean_accuracy - conf_interval:.4f}, {mean_accuracy + conf_interval:.4f}]")
    return mean_accuracy, conf_interval





 Implementing and Running Few-Shot Evaluation...
OCTFSLTaskSampler class defined.
Sampler created for original class indices: [2, 3]. Usable original class indices (>= K_novel+M_novel samples): [2, 3]
Sampler created for original class indices: [5, 7]. Usable original class indices (>= K_novel+M_novel samples): [5, 7]


In [None]:
print("\n--- Tuning FUSION_FACTOR_K on Validation Set ---")
# Sentinels: any real accuracy (>= 0) beats -1.0, so the first evaluated k always becomes the best.
best_k_for_validation = -1
best_accuracy_on_validation = -1.0

# Candidate fusion weights: 0.0 uses only the mean support feature, 1.0 only the SemAlignNet output.
k_values_to_try = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

if NUM_VAL_CLASSES > 0 and val_sampler is not None:
    # Cap the number of ways at the number of classes the sampler can actually serve.
    N_WAY_VAL_ACTUAL = min(len(val_sampler.usable_fsl_class_original_indices), N_WAY_VAL)
    if N_WAY_VAL_ACTUAL >= 2:
        # The sampler draws classes and images with the `random` module. Replaying the same RNG
        # state for every k scores all candidates on the same episodes, so differences in
        # accuracy reflect k rather than which tasks happened to be drawn.
        # NOTE(review): any randomness inside transform_oct is not controlled by this — confirm
        # the evaluation transform is deterministic.
        episode_rng_state = random.getstate()

        for k_try in k_values_to_try:
            random.setstate(episode_rng_state)
            print(f"\nEvaluating Validation Set with FUSION_FACTOR_K = {k_try}...")
            current_val_accuracy, _ = run_fsl_evaluation(
                val_sampler,
                N_way=N_WAY_VAL_ACTUAL,
                K_shot=K_SHOT_VAL,
                M_query=M_QUERY_VAL,
                num_tasks=NUM_EVAL_TASKS // 2 if NUM_EVAL_TASKS > 1 else 1, # Fewer tasks for faster tuning
                description_dict=OCT_SEMANTICS,
                fusion_k_val=k_try,
                vision_enc=vision_encoder_f,
                text_enc=text_encoder,
                semalign_module=semalign_net_trained_instance,
                set_name=f"Validation (k={k_try})"
            )
            # Strict '>' keeps the smallest k among ties.
            if current_val_accuracy > best_accuracy_on_validation:
                best_accuracy_on_validation = current_val_accuracy
                best_k_for_validation = k_try

        print(f"\n--- K-Tuning Complete ---")
        print(f"Best FUSION_FACTOR_K found on Validation Set: {best_k_for_validation}")
        print(f"Corresponding Mean Accuracy on Validation Set: {best_accuracy_on_validation:.4f}")
        FUSION_FACTOR_K = best_k_for_validation # Update the global K for Novel set evaluation
    else:
        print("Not enough usable validation classes for N-way evaluation during k-tuning.")
        # FUSION_FACTOR_K remains the default if tuning cannot be performed
else:
    print("Skipping K-tuning as no validation classes defined or sampler not created.")
    # FUSION_FACTOR_K remains the default


--- Tuning FUSION_FACTOR_K on Validation Set ---

Evaluating Validation Set with FUSION_FACTOR_K = 0.0...

Running Few-Shot Evaluation on Validation (k=0.0) Set (2-way 5-shot)...
  Validation (k=0.0) Set: Evaluated task 10/100 (Actual tasks processed so far: 10)
  Validation (k=0.0) Set: Evaluated task 20/100 (Actual tasks processed so far: 20)
  Validation (k=0.0) Set: Evaluated task 30/100 (Actual tasks processed so far: 30)
  Validation (k=0.0) Set: Evaluated task 40/100 (Actual tasks processed so far: 40)
  Validation (k=0.0) Set: Evaluated task 50/100 (Actual tasks processed so far: 50)
  Validation (k=0.0) Set: Evaluated task 60/100 (Actual tasks processed so far: 60)
  Validation (k=0.0) Set: Evaluated task 70/100 (Actual tasks processed so far: 70)
  Validation (k=0.0) Set: Evaluated task 80/100 (Actual tasks processed so far: 80)
  Validation (k=0.0) Set: Evaluated task 90/100 (Actual tasks processed so far: 90)
  Validation (k=0.0) Set: Evaluated task 100/100 (Actual tasks p

In [None]:
print(f"\nProceeding to evaluate Novel Set with the tuned FUSION_FACTOR_K = {FUSION_FACTOR_K}")

if NUM_NOVEL_CLASSES <= 0 or novel_sampler is None:
    print("Skipping Novel Set evaluation (no novel classes defined or novel_sampler not created).")
else:
    # Start from the configured N-way and shrink it only when the sampler has some usable
    # classes but fewer than requested.
    N_WAY_NOVEL_ACTUAL = N_WAY_NOVEL
    if (novel_sampler.usable_fsl_class_original_indices
            and len(novel_sampler.usable_fsl_class_original_indices) < N_WAY_NOVEL):
        N_WAY_NOVEL_ACTUAL = len(novel_sampler.usable_fsl_class_original_indices)

    if N_WAY_NOVEL_ACTUAL < 2:
        print(f"Not enough usable novel classes ({len(novel_sampler.usable_fsl_class_original_indices)} found) for {N_WAY_NOVEL}-way evaluation. Skipping Novel Set.")
    else:
        print(f"\nEvaluating on Novel Set with FUSION_FACTOR_K = {FUSION_FACTOR_K}...")
        # The returned (mean accuracy, confidence interval) is only printed by the function,
        # not stored here.
        run_fsl_evaluation(
            sampler=novel_sampler,
            N_way=N_WAY_NOVEL_ACTUAL,
            K_shot=K_SHOT_NOVEL,
            M_query=M_QUERY_NOVEL,
            num_tasks=NUM_EVAL_TASKS,
            description_dict=OCT_SEMANTICS,
            fusion_k_val=FUSION_FACTOR_K,  # value selected on the validation set (or the default)
            vision_enc=vision_encoder_f,
            text_enc=text_encoder,
            semalign_module=semalign_net_trained_instance,
            set_name="Novel",
        )


Proceeding to evaluate Novel Set with the tuned FUSION_FACTOR_K = 0.5

Evaluating on Novel Set with FUSION_FACTOR_K = 0.5...

Running Few-Shot Evaluation on Novel Set (2-way 5-shot)...
  Novel Set: Evaluated task 20/200 (Actual tasks processed so far: 20)
  Novel Set: Evaluated task 40/200 (Actual tasks processed so far: 40)
  Novel Set: Evaluated task 60/200 (Actual tasks processed so far: 60)
  Novel Set: Evaluated task 80/200 (Actual tasks processed so far: 80)
  Novel Set: Evaluated task 100/200 (Actual tasks processed so far: 100)
  Novel Set: Evaluated task 120/200 (Actual tasks processed so far: 120)
  Novel Set: Evaluated task 140/200 (Actual tasks processed so far: 140)
  Novel Set: Evaluated task 160/200 (Actual tasks processed so far: 160)
  Novel Set: Evaluated task 180/200 (Actual tasks processed so far: 180)
  Novel Set: Evaluated task 200/200 (Actual tasks processed so far: 200)

Novel Set Evaluation (processed 200/200 formed tasks from 200 attempts) finished in 74.87 s