# Few-Shot Sampling of Blood Smear Images

## Custom Dataloader

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

In [None]:

def resize_and_pad(image, target_size=(224, 224), fill=0):
    """
    Resize an image to fit within the target size while preserving the aspect ratio,
    then pad the remaining border with a constant value to reach the target size.

    The resized content is centred; when the leftover space is odd, the extra
    pixel goes to the right/bottom edge.

    Args:
        image (PIL.Image): Input image; ``image.size`` is read as (width, height).
        target_size (tuple): Target size as (width, height), default (224, 224).
        fill (int or tuple): Padding fill value (e.g., 0 for black, 255 for white).

    Returns:
        PIL.Image: Resized and padded image whose size is exactly ``target_size``.

    Raises:
        ValueError: If the input image or ``target_size`` has a non-positive dimension.
    """
    width, height = image.size
    target_width, target_height = target_size
    if width <= 0 or height <= 0:
        raise ValueError(f"Input image has invalid size {image.size}.")
    if target_width <= 0 or target_height <= 0:
        raise ValueError(f"target_size must be positive, got {target_size}.")

    # Largest scale at which the whole image still fits inside the target box.
    scale = min(target_width / width, target_height / height)
    # round() instead of int(): a product that should be exactly the target side
    # can evaluate to e.g. 223.999..., and truncation would lose a pixel.
    # max(1, ...) keeps extremely thin images from collapsing to a zero-sized
    # side, which resize rejects; min(target, ...) guarantees padding >= 0.
    new_width = min(target_width, max(1, round(width * scale)))
    new_height = min(target_height, max(1, round(height * scale)))

    # torchvision's functional resize expects the size as (height, width).
    resized_image = transforms.functional.resize(image, (new_height, new_width))

    # Split the leftover space so the image is centred; the right/bottom pads take
    # any odd remainder so the totals add up exactly to the target size.
    padding_left = (target_width - new_width) // 2
    padding_top = (target_height - new_height) // 2
    padding_right = target_width - new_width - padding_left
    padding_bottom = target_height - new_height - padding_top

    return transforms.functional.pad(
        resized_image,
        padding=(padding_left, padding_top, padding_right, padding_bottom),
        fill=fill,
    )

In [None]:
# Basic image transformation
def get_base_transforms(target_size=(320, 320), use_grayscale=False):
    """
    Return the deterministic preprocessing pipeline applied to every input image.

    Parameters:
    - target_size (tuple): Output size as (width, height), forwarded to
      ``resize_and_pad`` (the order only matters for non-square sizes).
    - use_grayscale (bool): If True, first convert the image to grayscale with
      3 output channels, so models expecting 3-channel input still accept it.

    Returns:
    - torchvision.transforms.Compose: A sequence of transformations including:
        - Optional grayscale conversion with 3 output channels,
        - Aspect-preserving resize plus padding to ``target_size``,
        - Conversion to tensor,
        - Normalization using ImageNet mean and standard deviation.
    """
    import functools  # local import keeps this cell self-contained

    base_transforms = [
        # functools.partial, unlike a lambda, can be pickled (as long as
        # resize_and_pad itself is importable), which DataLoader workers started
        # with the "spawn" method require.
        transforms.Lambda(functools.partial(resize_and_pad, target_size=target_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    if use_grayscale:
        # Keep 3 channels for compatibility with RGB-pretrained backbones.
        base_transforms.insert(0, transforms.Grayscale(num_output_channels=3))
    return transforms.Compose(base_transforms)


# Data augmentation transforms
def get_augmentation_transforms():
    """
    Build the random augmentations used to diversify training images.

    The returned pipeline applies, in this order:
    1. ColorJitter with brightness and contrast factors of 0.2,
    2. a horizontal flip with probability 0.5,
    3. a vertical flip with probability 0.5,
    4. a rotation by an angle drawn from [-10, 10] degrees,
    5. a grayscale conversion with probability 0.2.

    Returns:
    - torchvision.transforms.Compose: the composed augmentation pipeline.
    """
    # Mild photometric jitter, intended to mimic lighting/stain variation.
    jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2)
    flips = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
    ]
    rotate = transforms.RandomRotation(10)
    # Occasionally drop colour information altogether.
    to_gray = transforms.RandomGrayscale(p=0.2)
    return transforms.Compose([jitter, *flips, rotate, to_gray])


In [None]:
class FewShotDataset(Dataset):
    """
    Episodic dataset for N-way, K-shot few-shot classification.

    Expects the layout ``<data_dir>/<split>/<class_name>/<image files>``. Each item
    is a freshly sampled episode (the ``index`` passed to ``__getitem__`` is
    ignored): ``num_ways`` distinct classes are drawn, and for each of them
    ``num_support + num_query`` distinct images are drawn and split into the
    support and query sets. Labels are episode-local indices ``0 .. num_ways - 1``.
    """

    def __init__(self, data_dir, split, num_ways=5, num_support=5,
                 num_query=10, num_episodes=100, target_size=(224, 224),
                 use_grayscale=False,
                 augment=False,
                 ):
        """
        Args:
            data_dir (str): Path to dataset directory
            split (str): Sub-directory of ``data_dir`` to read, e.g. 'train', 'validation' or 'test'
            num_ways (int): Number of classes per episode
            num_support (int): Number of support samples per class (i.e. number of shots)
            num_query (int): Number of query samples per class
            num_episodes (int): Number of episodes per epoch (the value ``len()`` reports)
            target_size (tuple): Output image size forwarded to ``get_base_transforms``
            use_grayscale (bool): Convert images to 3-channel grayscale in the base transforms
            augment (bool): Apply ``get_augmentation_transforms`` before the base transforms
        """
        self.split_dir = os.path.join(data_dir, split)
        self.num_ways = num_ways
        self.num_support = num_support
        self.num_query = num_query
        self.num_episodes = num_episodes

        base_transform = get_base_transforms(target_size, use_grayscale)
        if augment:
            self.transform = transforms.Compose([get_augmentation_transforms(), base_transform])
        else:
            self.transform = base_transform

        # os.listdir order is arbitrary; sorting makes class_to_idx reproducible
        # across machines and sessions.
        self.classes = sorted(
            c for c in os.listdir(self.split_dir)
            if os.path.isdir(os.path.join(self.split_dir, c))
        )
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}  # class name -> index
        self.idx_to_class = {idx: cls for cls, idx in self.class_to_idx.items()}  # index -> class name
        self.class_images = {
            c: self._list_image_files(os.path.join(self.split_dir, c))
            for c in self.classes
        }

    @staticmethod
    def _list_image_files(class_dir):
        """Return sorted full paths of the regular, non-hidden files in ``class_dir``."""
        # Hidden files (e.g. .DS_Store) and nested directories would make
        # Image.open fail, so they are not treated as samples.
        return [
            os.path.join(class_dir, name)
            for name in sorted(os.listdir(class_dir))
            if not name.startswith('.') and os.path.isfile(os.path.join(class_dir, name))
        ]

    def _load_image(self, path):
        """Open ``path`` as RGB and apply ``self.transform`` if one is set."""
        # The context manager closes the file handle immediately; convert()
        # returns an independent, fully loaded image.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

    @staticmethod
    def _shuffled_stack(images, labels):
        """Shuffle images and labels with one shared permutation and stack them into tensors."""
        order = np.random.permutation(len(images))
        return (
            torch.stack([images[i] for i in order]),
            torch.tensor([labels[i] for i in order], dtype=torch.long),
        )

    def __len__(self):
        return self.num_episodes

    def __getitem__(self, index):
        """
        Sample one episode; ``index`` is ignored.

        Returns:
            tuple: ``(support_set, query_set, episode_classes)`` where each set is
            ``(images, labels)``: ``images`` stacks the transformed images and
            ``labels`` is a ``torch.long`` tensor of episode-local class indices.
            ``episode_classes[i]`` is the class name for label ``i``.

        Raises:
            ValueError: If the split has fewer than ``num_ways`` classes, or a
                selected class has fewer than ``num_support + num_query`` images.
        """
        if len(self.classes) < self.num_ways:
            raise ValueError(
                f"Split '{self.split_dir}' has only {len(self.classes)} classes; "
                f"cannot sample a {self.num_ways}-way episode."
            )
        # Randomly select N distinct classes for this episode.
        selected_classes = np.random.choice(self.classes, self.num_ways, replace=False)
        per_class = self.num_support + self.num_query

        support_images, support_labels = [], []
        query_images, query_labels = [], []

        for label_idx, class_name in enumerate(selected_classes):
            all_images = self.class_images[class_name]
            if len(all_images) < per_class:
                raise ValueError(
                    f"Class {class_name} has only {len(all_images)} images. "
                    f"Need at least {per_class}."
                )

            # Draw without replacement so support and query never share an image.
            selected_indices = np.random.choice(len(all_images), per_class, replace=False)
            for position, image_idx in enumerate(selected_indices):
                img = self._load_image(all_images[image_idx])
                if position < self.num_support:
                    support_images.append(img)
                    support_labels.append(label_idx)
                else:
                    query_images.append(img)
                    query_labels.append(label_idx)

        # Shuffle so samples are not grouped by class within each set.
        support_set = self._shuffled_stack(support_images, support_labels)
        query_set = self._shuffled_stack(query_images, query_labels)

        # Plain str names, aligned with the episode-local labels.
        episode_classes = [str(cls) for cls in selected_classes]
        return support_set, query_set, episode_classes

In [None]:
def get_data_loader(dataset, batch_size=1, shuffle=True, num_workers=2, pin_memory=None):
    """
    Returns a DataLoader for the dataset.

    Note: Batch size should typically be 1 for few-shot learning,
    as each episode is a separate task.

    Args:
        dataset (Dataset): Dataset to load from.
        batch_size (int): Items per batch (default 1).
        shuffle (bool): Whether to reshuffle item order every epoch.
        num_workers (int): Worker processes used for loading (default 2).
        pin_memory (bool or None): Use page-locked host memory. ``None`` (default)
            enables it only when CUDA is available, since pinning serves host-to-GPU
            copies and otherwise just triggers a warning.

    Returns:
        torch.utils.data.DataLoader: The configured loader.
    """
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

## Modelling

### Augmentation-Based Methods

In [None]:
# Mount Google Drive so the dataset stored under /content/drive/MyDrive is readable.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
# Shared run configuration.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
num_ways = 5                   # N in N-way episodes
num_shots_eval = [1, 5, 10]    # shot counts, presumably for evaluation (unused in the cells shown)
data_dir = '/content/drive/MyDrive/Computer vision with few shot sampling focus group/data_set'

#### SimCLR and CutMix

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

class SimCLRDataset(Dataset):
    """
    Dataset for SimCLR pretraining.

    This dataset loads images from a specified directory structure and returns two independently augmented
    views of the same image, as required by SimCLR for contrastive learning. Class labels are not returned.

    Directory structure is expected as:
    data_dir/
        split/ (e.g., train/)
            class_1/
                img1.jpg
                img2.jpg
            class_2/
                ...

    Attributes:
    - split_dir (str): ``os.path.join(data_dir, split)``, the directory that is scanned.
    - simclr_transform (callable): Random augmentation applied independently to produce each view.
    - image_paths (list): Full paths of all regular, non-hidden files found directly inside the
      class sub-directories, in sorted (reproducible) order.
    """
    def __init__(self, data_dir, split, simclr_transform):
        self.split_dir = os.path.join(data_dir, split)
        self.simclr_transform = simclr_transform
        self.image_paths = []
        # Sorted traversal gives a reproducible index -> image mapping. Hidden files
        # (e.g. .DS_Store) and nested directories are skipped because Image.open
        # cannot read them.
        for cls in sorted(os.listdir(self.split_dir)):
            cls_dir = os.path.join(self.split_dir, cls)
            if not os.path.isdir(cls_dir):
                continue
            self.image_paths.extend(
                os.path.join(cls_dir, name)
                for name in sorted(os.listdir(cls_dir))
                if not name.startswith('.') and os.path.isfile(os.path.join(cls_dir, name))
            )

    def __len__(self):
        """
        Returns:
        - int: Total number of images in the dataset.
        """
        return len(self.image_paths)

    def __getitem__(self, index):
        """
        Loads an image and applies SimCLR augmentations to generate two distinct views.

        Parameters:
        - index (int): Index of the image to retrieve.

        Returns:
        - tuple: A tuple containing two augmented views of the same image (img1, img2).
        """
        img_path = self.image_paths[index]
        # Close the file handle right away; convert() returns an independent image.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        # Two separate calls draw two independent sets of random augmentations.
        img1 = self.simclr_transform(img)
        img2 = self.simclr_transform(img)
        return img1, img2


class SimCLR(nn.Module):
    """
    SimCLR model: a backbone feature extractor followed by an MLP projection head.

    The backbone's ``fc`` layer is replaced with ``nn.Identity`` so it outputs pooled
    features; the projection head maps them into the space where the contrastive
    loss is computed.

    Attributes:
    - backbone (nn.Module): Feature extractor with its final ``fc`` layer removed.
    - projection_head (nn.Sequential): Linear(feature_dim, 512) -> ReLU -> Linear(512, projection_dim).
    """
    def __init__(self, backbone, projection_dim=128):
        """
        Initializes the SimCLR model.

        Parameters:
        - backbone (nn.Module): Network whose ``fc`` attribute is the layer to discard
          (e.g. a torchvision ResNet). Its ``fc.in_features`` sets the feature width;
          if ``fc`` is missing or has no ``in_features``, 2048 (ResNet-50) is assumed.
        - projection_dim (int): Dimensionality of the output projection space.
        """
        super().__init__()
        # Read the feature width before discarding the classifier, so backbones
        # other than ResNet-50 (e.g. ResNet-18 with 512 features) also work.
        feature_dim = getattr(getattr(backbone, 'fc', None), 'in_features', 2048)
        self.backbone = backbone
        self.backbone.fc = nn.Identity()  # Remove original classifier
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim),
        )

    def forward(self, x):
        """
        Forward pass of the SimCLR model.

        Parameters:
        - x (Tensor): Input batch accepted by the backbone (e.g. (batch_size, 3, H, W) images).

        Returns:
        - Tensor: Projected features of shape (batch_size, projection_dim).
        """
        h = self.backbone(x)
        return self.projection_head(h)

def nt_xent_loss(z, tau=0.5):
    """
    Computes the Normalized Temperature-scaled Cross Entropy (NT-Xent) loss used in SimCLR.

    This loss encourages positive pairs (two augmented views of the same image) to have similar representations
    while pushing apart representations of all other (negative) pairs within the batch.

    Assumes the input tensor `z` contains 2N feature vectors, where the first N and second N
    are corresponding positive pairs (i.e., for each i in [0, N), z[i] and z[i+N] are a positive pair).
    As in SimCLR, each row's softmax runs over every OTHER row (k != i); a sample
    is never compared with itself.

    Parameters:
    - z (Tensor): A tensor of shape (2N, D), where D is the embedding dimension. Contains projections for all views.
    - tau (float): Temperature scaling factor used to soften the distribution in the softmax.

    Returns:
    - Tensor: The scalar NT-Xent loss value (mean over all 2N anchors).

    Raises:
    - ValueError: If `z` does not have an even number (>= 2) of rows.
    """
    num_rows = z.size(0)
    if num_rows < 2 or num_rows % 2 != 0:
        raise ValueError(
            f"nt_xent_loss expects 2N >= 2 rows (N positive pairs), got {num_rows}."
        )
    batch_size = num_rows // 2  # N

    z = F.normalize(z, dim=1)  # Normalize embeddings to unit vectors
    sim_matrix = torch.mm(z, z.t()) / tau  # Pairwise cosine similarities / tau

    # Exclude self-similarity: the diagonal equals 1/tau (the row maximum), and
    # leaving it in would make every sample its own strongest "negative".
    self_mask = torch.eye(num_rows, dtype=torch.bool, device=z.device)
    sim_matrix = sim_matrix.masked_fill(self_mask, float('-inf'))

    # Row i's positive is i + N for the first half and i - N for the second half.
    pos_indices = torch.arange(batch_size, device=z.device)
    pos_indices = torch.cat([pos_indices + batch_size, pos_indices])

    log_probs = F.log_softmax(sim_matrix, dim=1)
    row_indices = torch.arange(num_rows, device=z.device)
    pos_log_probs = log_probs[row_indices, pos_indices]  # Log-prob of each positive

    return -pos_log_probs.mean()  # Mean negative log-likelihood of the positives


# SimCLR augmentation pipeline. SimCLRDataset calls it twice per image, so each
# call's independent random draws produce one of the two views.
simclr_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=224, scale=(0.2, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    # SimCLR (Chen et al., 2020, Appendix A) blurs 50% of views, not every view,
    # with sigma in [0.1, 2.0] and a kernel ~10% of the crop side (23 px for 224).
    transforms.RandomApply(
        [transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0))], p=0.5
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])



In [None]:
# --- CutMix and Fine-tuning Components ---

class CustomClassificationDataset(Dataset):
    """
    Custom dataset for image classification tasks using provided image paths and labels.

    Each sample in the dataset is an image-label pair, optionally transformed using a provided transform.

    Attributes:
    - image_paths (list): List of file paths to image files.
    - labels (list): List of corresponding class labels for each image.
    - transform (callable, optional): Optional transformation to apply to each image.
    """
    def __init__(self, image_paths, labels, transform=None):
        """
        Raises:
        - ValueError: If `image_paths` and `labels` differ in length (a mismatch
          would otherwise surface later as an IndexError or silently dropped samples).
        """
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length, "
                f"got {len(image_paths)} and {len(labels)}."
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """
        Returns:
        - int: Number of samples in the dataset.
        """
        return len(self.image_paths)

    def __getitem__(self, index):
        """
        Loads an image and its corresponding label by index.

        Parameters:
        - index (int): Index of the sample to retrieve.

        Returns:
        - tuple: A tuple (image, label), where the image is RGB and transformed if a transform is set.
        """
        img_path = self.image_paths[index]
        label = self.labels[index]
        # Close the file handle right away; convert() returns an independent image.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label


def cutmix(data, targets, alpha=1.0):
    """
    Apply CutMix to a batch by pasting one random rectangle from a shuffled copy
    of the batch into every image.

    The same box (from ``rand_bbox``) is used for all images in the batch. Note
    that ``data`` is modified in place, and that same tensor is returned.

    Parameters:
    - data (Tensor): A batch of images of shape (B, C, H, W).
    - targets (Tensor): Corresponding labels of shape (B,).
    - alpha (float): Both parameters of the Beta distribution the initial mixing ratio is drawn from.

    Returns:
    - tuple: ``(data, (targets, shuffled_targets, lam))``. ``lam`` is the fraction
      of each image still coming from its own sample, recomputed from the actual
      box area after ``rand_bbox`` clips the box to the image.
    """
    perm = torch.randperm(data.size(0))
    # Advanced indexing copies, so the donors are unaffected by the in-place paste below.
    donor_images = data[perm]
    donor_targets = targets[perm]

    sampled_lam = np.random.beta(alpha, alpha)
    x1, y1, x2, y2 = rand_bbox(data.size(), sampled_lam)
    data[:, :, y1:y2, x1:x2] = donor_images[:, :, y1:y2, x1:x2]

    box_area = (x2 - x1) * (y2 - y1)
    image_area = data.size(-1) * data.size(-2)
    lam = 1 - box_area / image_area
    return data, (targets, donor_targets, lam)


def rand_bbox(size, lam):
    """
    Generates a random bounding box for CutMix covering roughly ``1 - lam`` of the image.

    The box centre is uniform over the image and the box is clipped to the image
    borders, so the realised area can be smaller than requested.

    Parameters:
    - size (tuple): Size of the input tensor, expected to be (B, C, H, W).
    - lam (float): Lambda in [0, 1]; each box side is ``sqrt(1 - lam)`` times the
      corresponding image side.

    Returns:
    - tuple: (bbx1, bby1, bbx2, bby2), where x runs along the width (dim 3) and
      y along the height (dim 2), i.e. the region ``data[:, :, bby1:bby2, bbx1:bbx2]``.
    """
    # For (B, C, H, W) input, dim 2 is the height and dim 3 the width. Swapping
    # them makes y range over the width, so for non-square images the pasted
    # region gets truncated and no longer matches the area used to recompute lambda.
    H = size[2]
    W = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2


class Classifier(nn.Module):
    """
    Image classification model: a backbone feature extractor plus a linear head.

    The backbone's ``fc`` layer is replaced with ``nn.Identity`` so it outputs
    features, and a new linear layer maps them to the desired number of classes.

    Attributes:
    - backbone (nn.Module): Feature extractor with its final ``fc`` layer removed.
    - classifier (nn.Linear): Linear classification layer mapping features to class logits.
    """
    def __init__(self, backbone, num_classes):
        """
        Initializes the classifier model.

        Parameters:
        - backbone (nn.Module): Network whose ``fc`` attribute is the layer to discard
          (e.g. a torchvision ResNet-50). Its ``fc.in_features`` sets the feature width;
          if ``fc`` is missing or has no ``in_features`` (e.g. it is already an
          ``nn.Identity``), 2048 (ResNet-50) is assumed.
        - num_classes (int): Number of output classes.
        """
        super().__init__()
        # Read the feature width before discarding fc, so non-ResNet-50 backbones work.
        feature_dim = getattr(getattr(backbone, 'fc', None), 'in_features', 2048)
        self.backbone = backbone
        self.backbone.fc = nn.Identity()
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """
        Forward pass of the classifier.

        Parameters:
        - x (Tensor): Input batch accepted by the backbone (e.g. (B, 3, H, W) images).

        Returns:
        - Tensor: Logits for each class of shape (B, num_classes).
        """
        features = self.backbone(x)
        return self.classifier(features)

# Base transform for fine-tuning and evaluation: resize to exactly 224x224,
# convert to a tensor, and normalise with ImageNet statistics.
# Note: Resize with an (h, w) tuple stretches non-square images (no padding),
# unlike resize_and_pad used by the few-shot dataset above.
base_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)



In [None]:
# --- Dataset Preparation ---

data_dir = "/content/drive/MyDrive/Computer vision with few shot sampling focus group/data_set"  # Your dataset path

# SimCLR Pretraining Dataset
simclr_dataset = SimCLRDataset(data_dir, split='train', simclr_transform=simclr_transform)
simclr_loader = DataLoader(simclr_dataset, batch_size=64, shuffle=True, num_workers=2)

# Split the 'test' split into fine-tuning / evaluation subsets (80/20 per class).
# Classes and files are sorted and the shuffle uses a fixed seed, so the
# class -> index mapping (which the saved classifier weights depend on) and the
# split itself are reproducible across sessions; os.listdir order is arbitrary.
split_seed = 42
split_rng = np.random.default_rng(split_seed)

test_dir = os.path.join(data_dir, 'test')
test_classes = sorted(c for c in os.listdir(test_dir) if os.path.isdir(os.path.join(test_dir, c)))
class_to_idx = {cls: idx for idx, cls in enumerate(test_classes)}

fine_tune_image_paths = []
fine_tune_labels = []
eval_image_paths = []
eval_labels = []

for cls in test_classes:
    cls_dir = os.path.join(test_dir, cls)
    # Skip hidden files (e.g. .DS_Store) and sub-directories, which are not images.
    images = [
        os.path.join(cls_dir, name)
        for name in sorted(os.listdir(cls_dir))
        if not name.startswith('.') and os.path.isfile(os.path.join(cls_dir, name))
    ]
    split_rng.shuffle(images)
    split_idx = int(0.8 * len(images))  # 80% for fine-tuning, 20% for evaluation
    fine_tune_image_paths.extend(images[:split_idx])
    fine_tune_labels.extend([class_to_idx[cls]] * split_idx)
    eval_image_paths.extend(images[split_idx:])
    eval_labels.extend([class_to_idx[cls]] * (len(images) - split_idx))

fine_tune_dataset = CustomClassificationDataset(fine_tune_image_paths, fine_tune_labels, transform=base_transform)
eval_dataset = CustomClassificationDataset(eval_image_paths, eval_labels, transform=base_transform)

fine_tune_loader = DataLoader(fine_tune_dataset, batch_size=32, shuffle=True, num_workers=4)
eval_loader = DataLoader(eval_dataset, batch_size=32, shuffle=False, num_workers=4)



In [None]:
# --- SimCLR Pretraining ---
# Contrastive pretraining of an ImageNet-initialised ResNet-50 on the 'train'
# split; only the backbone weights are saved for later fine-tuning.

print("Starting SimCLR Pre-training...")
backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
simclr_model = SimCLR(backbone).to(device)
optimizer = torch.optim.Adam(simclr_model.parameters(), lr=0.001)

num_epochs = 10
for epoch_idx in range(1, num_epochs + 1):
    simclr_model.train()
    running_loss = 0
    for view_a, view_b in simclr_loader:
        view_a = view_a.to(device)
        view_b = view_b.to(device)
        # Stack as [first views; second views] so row i and row i + N form a
        # positive pair, which is the layout nt_xent_loss expects.
        projections = torch.cat([simclr_model(view_a), simclr_model(view_b)], dim=0)
        batch_loss = nt_xent_loss(projections)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
    avg_loss = running_loss / len(simclr_loader)
    print(f"SimCLR Pretraining Epoch {epoch_idx}/{num_epochs}, Loss: {avg_loss:.4f}")

# `backbone` is the same module object as simclr_model.backbone (its fc was
# replaced with nn.Identity by SimCLR), so this saves the trained feature extractor.
torch.save(backbone.state_dict(), "simclr_backbone.pth")

Starting SimCLR Pre-training...
SimCLR Pretraining Epoch 1/10, Loss: 4.0404
SimCLR Pretraining Epoch 2/10, Loss: 3.8645
SimCLR Pretraining Epoch 3/10, Loss: 3.8318
SimCLR Pretraining Epoch 4/10, Loss: 3.7105
SimCLR Pretraining Epoch 5/10, Loss: 3.6864
SimCLR Pretraining Epoch 6/10, Loss: 3.6007
SimCLR Pretraining Epoch 7/10, Loss: 3.5808
SimCLR Pretraining Epoch 8/10, Loss: 3.5730
SimCLR Pretraining Epoch 9/10, Loss: 3.5850
SimCLR Pretraining Epoch 10/10, Loss: 3.5598


In [None]:
# --- Fine-tuning with CutMix ---

print("Starting Fine-tuning with CutMix...")
backbone = models.resnet50(weights=None)  # Architecture only; weights come from the SimCLR checkpoint
# The SimCLR checkpoint was saved after fc was replaced by nn.Identity, so it has
# no fc.* entries. Matching that here allows a strict load: any other key
# mismatch now raises, whereas strict=False would silently leave those layers
# randomly initialised. map_location lets a GPU-saved checkpoint load on CPU.
backbone.fc = nn.Identity()
backbone.load_state_dict(torch.load("simclr_backbone.pth", map_location=device))
classifier = Classifier(backbone, num_classes=len(test_classes)).to(device)
optimizer = torch.optim.Adam(classifier.parameters(), lr=0.0001)  # Lower LR for fine-tuning
criterion = nn.CrossEntropyLoss()

num_epochs = 15
for epoch in range(num_epochs):
    classifier.train()
    total_loss = 0.0
    n_correct = 0
    n_seen = 0
    for data, targets in fine_tune_loader:
        data, targets = data.to(device), targets.to(device)
        if np.random.rand() < 0.5:  # Apply CutMix 50% of the time
            # cutmix returns `targets` unchanged as targets_a, so accuracy below
            # is measured against each image's own original label.
            data, (targets_a, targets_b, lam) = cutmix(data, targets)
            lam = float(lam)
            logits = classifier(data)
            loss = lam * criterion(logits, targets_a) + (1 - lam) * criterion(logits, targets_b)
        else:
            logits = classifier(data)
            loss = criterion(logits, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Count per sample so a short final batch is not weighted like a full one.
        preds = torch.argmax(logits, dim=1)
        n_correct += (preds == targets).sum().item()
        n_seen += targets.size(0)
    avg_loss = total_loss / len(fine_tune_loader)
    avg_acc = n_correct / max(n_seen, 1)
    print(f"Fine-tuning Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, Acc: {avg_acc:.4f}")

torch.save(classifier.state_dict(), "fine_tuned_model.pth")



Starting Fine-tuning with CutMix...
Fine-tuning Epoch 1/15, Loss: 1.3961, Acc: 0.2844
Fine-tuning Epoch 2/15, Loss: 1.2086, Acc: 0.4672
Fine-tuning Epoch 3/15, Loss: 1.0619, Acc: 0.5547
Fine-tuning Epoch 4/15, Loss: 0.9623, Acc: 0.6750
Fine-tuning Epoch 5/15, Loss: 0.9100, Acc: 0.6125
Fine-tuning Epoch 6/15, Loss: 0.8542, Acc: 0.7203
Fine-tuning Epoch 7/15, Loss: 0.8035, Acc: 0.6938
Fine-tuning Epoch 8/15, Loss: 0.9759, Acc: 0.6500
Fine-tuning Epoch 9/15, Loss: 0.9932, Acc: 0.6484
Fine-tuning Epoch 10/15, Loss: 0.5999, Acc: 0.7734
Fine-tuning Epoch 11/15, Loss: 0.8872, Acc: 0.6875
Fine-tuning Epoch 12/15, Loss: 0.5618, Acc: 0.8953
Fine-tuning Epoch 13/15, Loss: 0.6807, Acc: 0.8484
Fine-tuning Epoch 14/15, Loss: 0.4753, Acc: 0.8125
Fine-tuning Epoch 15/15, Loss: 0.5359, Acc: 0.8625


In [None]:
# --- Evaluation ---
# Reload the saved fine-tuned weights and report top-1 accuracy on the held-out
# 20% of the 'test' split.

print("Starting Evaluation...")
# map_location lets a checkpoint saved on GPU be evaluated on a CPU-only runtime.
classifier.load_state_dict(torch.load("fine_tuned_model.pth", map_location=device))
classifier.eval()

correct = 0
total = 0
with torch.no_grad():
    for data, targets in eval_loader:
        data, targets = data.to(device), targets.to(device)
        logits = classifier(data)
        preds = torch.argmax(logits, dim=1)
        correct += (preds == targets).sum().item()
        total += targets.size(0)

if total == 0:
    # Avoid ZeroDivisionError when the evaluation subset is empty.
    accuracy = float('nan')
    print("Evaluation split is empty; accuracy is undefined.")
else:
    accuracy = correct / total
print(f"Evaluation Accuracy on Test Split: {accuracy:.4f}")

Starting Evaluation...
Evaluation Accuracy on Test Split: 0.4444
