In [1]:
import os
import random
import torch
import torch.nn as nn
import timm
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms, models 
from PIL import Image
from sklearn.metrics import accuracy_score,classification_report
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import time
from datetime import datetime
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

In [2]:
# Initial setup
data_dir = r"..\Datasets\kvasir-dataset-v2"
all_class_names = [os.path.basename(d) for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))] 
class_names = [class_name for class_name in all_class_names if class_name != "polyps"]

class CustomDataset(datasets.VisionDataset):
    def __init__(self, selected_images, class_to_idx, transform=None):
        super().__init__(root=None, transform=transform)
        self.selected_images = selected_images
        self.class_to_idx = class_to_idx

    def __len__(self):
        return len(self.selected_images)

    def __getitem__(self, idx):
        img_path, label = self.selected_images[idx]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        label_idx = self.class_to_idx[label]
        return image, label_idx
        
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the full dataset
full_dataset = datasets.ImageFolder(data_dir, transform=transform)
classes = full_dataset.classes
class_to_idx = full_dataset.class_to_idx

# Split the dataset into training and testing sets
test_ratio = 0.2
test_size = int(len(full_dataset) * test_ratio)
train_size = len(full_dataset) - test_size

all_indices = list(range(len(full_dataset)))
random.shuffle(all_indices)

train_indices, test_indices = all_indices[:train_size], all_indices[train_size:]

train_subset = Subset(full_dataset, train_indices)
test_subset = Subset(full_dataset, test_indices)

# Ensure training dataset has no polyps
polyps_class = "polyps"
train_indices_without_polyps = [i for i in train_indices if full_dataset.imgs[i][1] != class_to_idx[polyps_class]]

train_subset_without_polyps = Subset(full_dataset, train_indices_without_polyps)

# Create a small polyps dataset with 20 images
def get_polyps_dataset(data_dir, transform, num_samples=20, target_class="polyps"):
    polyps_dir = os.path.join(data_dir, target_class)
    polyps_images = [os.path.join(polyps_dir, img) for img in os.listdir(polyps_dir) if img.endswith(('.jpg', '.png', '.jpeg'))]
    sampled_images = random.sample(polyps_images, min(num_samples, len(polyps_images)))

    selected_images = [(img_path, target_class) for img_path in sampled_images]
    return selected_images

polyps_images = get_polyps_dataset(data_dir, transform)
polyps_dataset = CustomDataset(polyps_images, class_to_idx, transform=transform)

# Create DataLoaders
train_loader = DataLoader(train_subset_without_polyps, batch_size=32, shuffle=True)
polyps_loader = DataLoader(polyps_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=32, shuffle=True)

In [3]:
# Feature Extractor (Base Network)
def build_base_network():
    base_model = timm.create_model('vit_small_patch16_224', pretrained=True)
    base_model.head = nn.Identity()
    return base_model

In [4]:
# Compute Class Prototypes
def compute_class_prototypes(train_loader, base_network, class_to_idx, save_path="class_prototypes.pth"):
    # Check if prototypes already exist
    if os.path.exists(save_path):
        print(f"Loading class prototypes from {save_path}...")
        return torch.load(save_path)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    base_network.eval()
    base_network.to(device)

    embeddings = {idx: [] for idx in class_to_idx.values()}

    with torch.no_grad():
        for images, labels in tqdm(train_loader, desc="Computing Class Prototypes"):
            images, labels = images.to(device), labels.to(device)
            features = base_network(images)
            for i, label in enumerate(labels):
                embeddings[label.item()].append(features[i].cpu())

    class_prototypes = {class_name: torch.mean(torch.stack(embeddings[class_to_idx[class_name]]), dim=0)
                        for class_name in class_to_idx if len(embeddings[class_to_idx[class_name]]) > 0}

    # Save computed prototypes
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(class_prototypes, save_path)
    print(f"Class prototypes saved to {save_path}.")
    
    return class_prototypes

# Compute Polyps Prototype
def compute_polyps_prototype(polyps_loader, base_network, save_path="polyps_prototype.pth"):
    # Check if prototype already exists
    if os.path.exists(save_path):
        print(f"Loading polyps prototype from {save_path}...")
        return torch.load(save_path)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    base_network.eval()
    base_network.to(device)

    polyps_embeddings = []

    with torch.no_grad():
        for images, _ in tqdm(polyps_loader, desc="Computing Polyps Prototype"):
            images = images.to(device)
            features = base_network(images)
            polyps_embeddings.extend(features.cpu())

    if polyps_embeddings:
        polyps_prototype = torch.mean(torch.stack(polyps_embeddings), dim=0)
        
        # Save computed prototype
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(polyps_prototype, save_path)
        print(f"Polyps prototype saved to {save_path}.")
        
        return polyps_prototype
    
    print("No polyps embeddings found. Prototype not computed.")
    return None



In [5]:
# Siamese Neural Network Definition
class SiameseNetwork(nn.Module):
    def __init__(self, base_network):
        super(SiameseNetwork, self).__init__()
        self.base_network = base_network
        self.fc = nn.Sequential(
            nn.Linear(base_network.embed_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        img1, prototype = inputs
        feature1 = self.base_network(img1)
        distance = torch.abs(feature1 - prototype)
        similarity = self.fc(distance)
        return similarity

In [6]:
# Modified training function
def train_snn_classification(model, train_loader, class_prototypes, valid_class_names, 
                              epochs=10, save_path="snn_model.pth", log_interval=10, 
                              patience=3):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    
    # Create mapping for valid classes
    class_to_train_idx = {name: idx for idx, name in enumerate(valid_class_names)}
    
    print(f"\nTraining with classes: {valid_class_names}")
    print(f"Class to index mapping: {class_to_train_idx}")
    print(f"Training on device: {device}")
    
    # Initialize training metrics
    best_loss = float('inf')
    start_time = time.time()
    training_history = []
    patience_counter = 0  # Track epochs without improvement
    
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        batch_losses = []
        
        # Create progress bar for each epoch
        pbar = tqdm(train_loader, desc=f'Epoch {epoch + 1}/{epochs}')
        
        for batch_idx, (images, labels) in enumerate(pbar):
            images = images.to(device)
            batch_size = images.size(0)
            
            # Initialize scores tensor
            scores = torch.zeros(batch_size, len(valid_class_names), device=device)
            
            # Get original class names from labels
            original_class_names = [classes[label.item()] for label in labels]
            
            # Create new labels for valid classes only
            new_labels = torch.tensor([class_to_train_idx[class_name] 
                                       for class_name in original_class_names 
                                       if class_name in class_to_train_idx], 
                                      device=device)
            
            if len(new_labels) > 0:
                # Compute similarity scores
                for idx, class_name in enumerate(valid_class_names):
                    if class_name in class_prototypes:
                        prototype = class_prototypes[class_name].to(device)
                        prototype_expanded = prototype.expand(batch_size, -1)
                        similarities = model([images, prototype_expanded])
                        scores[:, idx] = similarities.squeeze()
                
                loss = criterion(scores, new_labels)
                epoch_loss += loss.item()
                batch_losses.append(loss.item())
                num_batches += 1
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                
                # Update progress bar
                if batch_idx % log_interval == 0:
                    current_loss = sum(batch_losses[-log_interval:]) / min(log_interval, len(batch_losses))
                    pbar.set_postfix({
                        'loss': f'{current_loss:.4f}',
                        'avg_loss': f'{epoch_loss/num_batches:.4f}'
                    })
        
        # Compute epoch metrics
        if num_batches > 0:
            avg_epoch_loss = epoch_loss / num_batches
            elapsed_time = time.time() - start_time
            
            # Save the best model if the current epoch has the best loss
            if avg_epoch_loss < best_loss:
                best_loss = avg_epoch_loss
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': best_loss,
                }, save_path)
                print(f"\nNew best model saved! Loss: {best_loss:.4f}")
                patience_counter = 0  # Reset patience counter
            else:
                patience_counter += 1
                print(f"\nNo improvement in loss for {patience_counter} epoch(s). Best loss: {best_loss:.4f}")
            
            # Early stopping
            if patience_counter >= patience:
                print("\nEarly stopping triggered. Training halted.")
                break
            
            # Log training history
            training_history.append({
                'epoch': epoch + 1,
                'loss': avg_epoch_loss,
                'time': elapsed_time
            })
            
            print(f"\nEpoch {epoch + 1}/{epochs}")
            print(f"Average Loss: {avg_epoch_loss:.4f}")
            print(f"Time Elapsed: {elapsed_time:.2f}s")
            print("-" * 50)
    
    # Final training summary
    print("\nTraining Complete!")
    print(f"Best Loss: {best_loss:.4f}")
    print(f"Total Training Time: {time.time() - start_time:.2f}s")
    print(f"Final model saved to {save_path}")
    
    return model, training_history

def evaluate_model(test_loader, model, class_prototypes, class_names):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()
    model.to(device)
    
    all_preds = []
    all_targets = []
    total_samples = len(test_loader.dataset)
    
    print(f"\nEvaluating model on {total_samples} samples...")
    start_time = time.time()
    
    with torch.no_grad():
        # Create progress bar for evaluation
        pbar = tqdm(test_loader, desc='Evaluating')
        
        for images, labels in pbar:
            images = images.to(device)
            batch_size = images.size(0)
            scores = torch.zeros(batch_size, len(class_names), device=device)
            
            for i, class_name in enumerate(class_names):
                if class_name in class_prototypes:
                    prototype = class_prototypes[class_name].to(device)
                    prototype_expanded = prototype.expand(batch_size, -1)
                    similarities = model([images, prototype_expanded])
                    scores[:, i] = similarities.squeeze()
            
            _, predicted = torch.max(scores, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(labels.cpu().numpy())
            
            # Update progress bar with current batch stats
            pbar.set_postfix({
                'processed': f'{len(all_preds)}/{total_samples}'
            })
    
    # Calculate and print detailed metrics
    elapsed_time = time.time() - start_time
    accuracy = accuracy_score(all_targets, all_preds) * 100
    
    print("\nEvaluation Results:")
    print("-" * 50)
    print(f"Overall Accuracy: {accuracy:.2f}%")
    print(f"Evaluation Time: {elapsed_time:.2f}s")
    print("\nDetailed Classification Report:")
    print(classification_report(all_targets, all_preds, target_names=class_names))
    
    return {
        'accuracy': accuracy,
        'predictions': all_preds,
        'targets': all_targets,
        'evaluation_time': elapsed_time
    }

In [7]:
if __name__ == "__main__":
    # Compute class prototypes on training dataset without polyps
    print("Computing class prototypes...")
    base_network = build_base_network()
    class_prototypes = compute_class_prototypes(train_loader, base_network, class_to_idx)
    print("Class prototypes saved.")

    # Path to the pre-trained Siamese model
    snn_model_path = "./models/snn_model.pth"

    if os.path.exists(snn_model_path):
        print(f"Loading pre-trained Siamese Network from {snn_model_path}...")
        siamese_model = SiameseNetwork(base_network)
        siamese_model.load_state_dict(torch.load(snn_model_path))
    else:
        # Initialize and train the Siamese Network on non-polyps training dataset
        siamese_model = SiameseNetwork(base_network)
        print("Training the Siamese Network...")
        siamese_model = train_snn_classification(
            siamese_model, 
            train_loader, 
            class_prototypes, 
            class_names,  # Using only non-polyps classes
            epochs=10, 
            save_path=snn_model_path
        )

    # Evaluate the initial model
    print("\nEvaluating initial model...")
    evaluate_model(test_loader, siamese_model, class_prototypes, class_names)

    # Compute polyps prototype and add to class prototypes
    print("\nComputing polyps prototype...")
    polyps_prototype = compute_polyps_prototype(polyps_loader, base_network)
    if polyps_prototype is not None:
        class_prototypes['polyps'] = polyps_prototype

    # Path to the fine-tuned Siamese model
    fine_tuned_model_path = "./models/snn_finetuned_polyps.pth"

    if os.path.exists(fine_tuned_model_path):
        print(f"Loading fine-tuned Siamese Network from {fine_tuned_model_path}...")
        few_shot_model = SiameseNetwork(base_network)
        few_shot_model.load_state_dict(torch.load(fine_tuned_model_path))
    else:
        # Fine-tune on polyps
        print("\nFine-tuning on polyps...")
        few_shot_model = train_snn_classification(
            siamese_model, 
            polyps_loader, 
            class_prototypes,
            all_class_names,  # Now including polyps
            epochs=5,
            save_path=fine_tuned_model_path
        )

    # Evaluate the fine-tuned model
    print("\nEvaluating fine-tuned model...")
    evaluate_model(test_loader, few_shot_model, class_prototypes, all_class_names)


Computing class prototypes...
Loading class prototypes from class_prototypes.pth...
Class prototypes saved.
Loading pre-trained Siamese Network from ./models/snn_model.pth...


  return torch.load(save_path)
  siamese_model.load_state_dict(torch.load(snn_model_path))


RuntimeError: Error(s) in loading state_dict for SiameseNetwork:
	Missing key(s) in state_dict: "base_network.cls_token", "base_network.pos_embed", "base_network.patch_embed.proj.weight", "base_network.patch_embed.proj.bias", "base_network.blocks.0.norm1.weight", "base_network.blocks.0.norm1.bias", "base_network.blocks.0.attn.qkv.weight", "base_network.blocks.0.attn.qkv.bias", "base_network.blocks.0.attn.proj.weight", "base_network.blocks.0.attn.proj.bias", "base_network.blocks.0.norm2.weight", "base_network.blocks.0.norm2.bias", "base_network.blocks.0.mlp.fc1.weight", "base_network.blocks.0.mlp.fc1.bias", "base_network.blocks.0.mlp.fc2.weight", "base_network.blocks.0.mlp.fc2.bias", "base_network.blocks.1.norm1.weight", "base_network.blocks.1.norm1.bias", "base_network.blocks.1.attn.qkv.weight", "base_network.blocks.1.attn.qkv.bias", "base_network.blocks.1.attn.proj.weight", "base_network.blocks.1.attn.proj.bias", "base_network.blocks.1.norm2.weight", "base_network.blocks.1.norm2.bias", "base_network.blocks.1.mlp.fc1.weight", "base_network.blocks.1.mlp.fc1.bias", "base_network.blocks.1.mlp.fc2.weight", "base_network.blocks.1.mlp.fc2.bias", "base_network.blocks.2.norm1.weight", "base_network.blocks.2.norm1.bias", "base_network.blocks.2.attn.qkv.weight", "base_network.blocks.2.attn.qkv.bias", "base_network.blocks.2.attn.proj.weight", "base_network.blocks.2.attn.proj.bias", "base_network.blocks.2.norm2.weight", "base_network.blocks.2.norm2.bias", "base_network.blocks.2.mlp.fc1.weight", "base_network.blocks.2.mlp.fc1.bias", "base_network.blocks.2.mlp.fc2.weight", "base_network.blocks.2.mlp.fc2.bias", "base_network.blocks.3.norm1.weight", "base_network.blocks.3.norm1.bias", "base_network.blocks.3.attn.qkv.weight", "base_network.blocks.3.attn.qkv.bias", "base_network.blocks.3.attn.proj.weight", "base_network.blocks.3.attn.proj.bias", "base_network.blocks.3.norm2.weight", "base_network.blocks.3.norm2.bias", "base_network.blocks.3.mlp.fc1.weight", "base_network.blocks.3.mlp.fc1.bias", "base_network.blocks.3.mlp.fc2.weight", "base_network.blocks.3.mlp.fc2.bias", "base_network.blocks.4.norm1.weight", "base_network.blocks.4.norm1.bias", "base_network.blocks.4.attn.qkv.weight", "base_network.blocks.4.attn.qkv.bias", "base_network.blocks.4.attn.proj.weight", "base_network.blocks.4.attn.proj.bias", "base_network.blocks.4.norm2.weight", "base_network.blocks.4.norm2.bias", "base_network.blocks.4.mlp.fc1.weight", "base_network.blocks.4.mlp.fc1.bias", "base_network.blocks.4.mlp.fc2.weight", "base_network.blocks.4.mlp.fc2.bias", "base_network.blocks.5.norm1.weight", "base_network.blocks.5.norm1.bias", "base_network.blocks.5.attn.qkv.weight", "base_network.blocks.5.attn.qkv.bias", "base_network.blocks.5.attn.proj.weight", "base_network.blocks.5.attn.proj.bias", "base_network.blocks.5.norm2.weight", "base_network.blocks.5.norm2.bias", "base_network.blocks.5.mlp.fc1.weight", "base_network.blocks.5.mlp.fc1.bias", "base_network.blocks.5.mlp.fc2.weight", "base_network.blocks.5.mlp.fc2.bias", "base_network.blocks.6.norm1.weight", "base_network.blocks.6.norm1.bias", "base_network.blocks.6.attn.qkv.weight", "base_network.blocks.6.attn.qkv.bias", "base_network.blocks.6.attn.proj.weight", "base_network.blocks.6.attn.proj.bias", "base_network.blocks.6.norm2.weight", "base_network.blocks.6.norm2.bias", "base_network.blocks.6.mlp.fc1.weight", "base_network.blocks.6.mlp.fc1.bias", "base_network.blocks.6.mlp.fc2.weight", "base_network.blocks.6.mlp.fc2.bias", "base_network.blocks.7.norm1.weight", "base_network.blocks.7.norm1.bias", "base_network.blocks.7.attn.qkv.weight", "base_network.blocks.7.attn.qkv.bias", "base_network.blocks.7.attn.proj.weight", "base_network.blocks.7.attn.proj.bias", "base_network.blocks.7.norm2.weight", "base_network.blocks.7.norm2.bias", "base_network.blocks.7.mlp.fc1.weight", "base_network.blocks.7.mlp.fc1.bias", "base_network.blocks.7.mlp.fc2.weight", "base_network.blocks.7.mlp.fc2.bias", "base_network.blocks.8.norm1.weight", "base_network.blocks.8.norm1.bias", "base_network.blocks.8.attn.qkv.weight", "base_network.blocks.8.attn.qkv.bias", "base_network.blocks.8.attn.proj.weight", "base_network.blocks.8.attn.proj.bias", "base_network.blocks.8.norm2.weight", "base_network.blocks.8.norm2.bias", "base_network.blocks.8.mlp.fc1.weight", "base_network.blocks.8.mlp.fc1.bias", "base_network.blocks.8.mlp.fc2.weight", "base_network.blocks.8.mlp.fc2.bias", "base_network.blocks.9.norm1.weight", "base_network.blocks.9.norm1.bias", "base_network.blocks.9.attn.qkv.weight", "base_network.blocks.9.attn.qkv.bias", "base_network.blocks.9.attn.proj.weight", "base_network.blocks.9.attn.proj.bias", "base_network.blocks.9.norm2.weight", "base_network.blocks.9.norm2.bias", "base_network.blocks.9.mlp.fc1.weight", "base_network.blocks.9.mlp.fc1.bias", "base_network.blocks.9.mlp.fc2.weight", "base_network.blocks.9.mlp.fc2.bias", "base_network.blocks.10.norm1.weight", "base_network.blocks.10.norm1.bias", "base_network.blocks.10.attn.qkv.weight", "base_network.blocks.10.attn.qkv.bias", "base_network.blocks.10.attn.proj.weight", "base_network.blocks.10.attn.proj.bias", "base_network.blocks.10.norm2.weight", "base_network.blocks.10.norm2.bias", "base_network.blocks.10.mlp.fc1.weight", "base_network.blocks.10.mlp.fc1.bias", "base_network.blocks.10.mlp.fc2.weight", "base_network.blocks.10.mlp.fc2.bias", "base_network.blocks.11.norm1.weight", "base_network.blocks.11.norm1.bias", "base_network.blocks.11.attn.qkv.weight", "base_network.blocks.11.attn.qkv.bias", "base_network.blocks.11.attn.proj.weight", "base_network.blocks.11.attn.proj.bias", "base_network.blocks.11.norm2.weight", "base_network.blocks.11.norm2.bias", "base_network.blocks.11.mlp.fc1.weight", "base_network.blocks.11.mlp.fc1.bias", "base_network.blocks.11.mlp.fc2.weight", "base_network.blocks.11.mlp.fc2.bias", "base_network.norm.weight", "base_network.norm.bias", "fc.0.weight", "fc.0.bias", "fc.2.weight", "fc.2.bias". 
	Unexpected key(s) in state_dict: "epoch", "model_state_dict", "optimizer_state_dict", "loss". 