In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTMSNForImageClassification, ViTImageProcessor
from PIL import Image
import os
import numpy as np
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import logging
import random

In [None]:
# Notebook-wide logging plus the torch device every tensor/model is moved to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use the GPU whenever CUDA is usable, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Environment setup
# DEVICE is already defined by the logging/device cell above, so it is not
# repeated here.
# The dataset path is assembled from components so that it resolves on POSIX
# systems as well as on Windows (a raw string with backslashes only works on
# Windows).
DATA_DIR = os.path.join("..", "Datasets", "kvasir-dataset-v2")
SAVE_DIR = "models"      # directory that receives model checkpoints
BATCH_SIZE = 16
NUM_EPOCHS = 10
LEARNING_RATE = 2e-5


In [None]:
# Saving and loading utilities
def save_model(model, path, optimizer=None, epoch=None, loss=None):
    """Write a training checkpoint to ``path`` using ``torch.save``.

    Args:
        model: Object exposing ``state_dict()``; stored under
            ``'model_state_dict'``.
        path: Destination of the checkpoint. When it is a filesystem path,
            missing parent directories are created first.
        optimizer: Optional object exposing ``state_dict()``; stored under
            ``'optimizer_state_dict'`` when given.
        epoch: Optional epoch index stored under ``'epoch'``.
        loss: Optional loss value stored under ``'loss'``.
    """
    checkpoint = {'model_state_dict': model.state_dict()}

    # Explicit None checks: an argument is "absent" only when it is None, not
    # merely when it happens to be falsy (e.g. epoch 0).
    if optimizer is not None:
        checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    if epoch is not None:
        checkpoint['epoch'] = epoch
    if loss is not None:
        checkpoint['loss'] = loss

    # Create the target directory so saving does not fail on a fresh checkout.
    # Skipped for non-path destinations such as open file objects.
    if isinstance(path, (str, os.PathLike)):
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    torch.save(checkpoint, path)
    logger.info(f"Model saved to {path}")

def load_model(model, path):
    """Load weights from the checkpoint at ``path`` into ``model``.

    The checkpoint is expected to be a dict holding the weights under the
    ``'model_state_dict'`` key.

    Args:
        model: Object exposing ``load_state_dict()``.
        path: Filesystem path of the checkpoint.

    Returns:
        True when the checkpoint existed and was loaded, False when ``path``
        does not exist (``model`` is left untouched in that case).
    """
    if not os.path.exists(path):
        return False

    # Deserialize onto the CPU so that a checkpoint written on a GPU machine
    # can still be read where CUDA is unavailable; load_state_dict then copies
    # the values into the model's existing parameters.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    logger.info(f"Model loaded from {path}")
    return True

In [None]:
class KvasirDataset(Dataset):
    """Map-style dataset yielding ``(pixel_values, label)`` pairs for image files.

    Each image is opened with PIL, converted to RGB, optionally passed through
    ``transform`` and finally through the ``facebook/vit-msn-small`` image
    processor.
    """

    def __init__(self, image_paths, labels, transform=None):
        """
        Args:
            image_paths: Sequence of image file paths.
            labels: Sequence of labels aligned index-for-index with
                ``image_paths``.
            transform: Optional callable applied to the PIL image. Its result
                must expose ``.numpy()``; the pipelines used in this notebook
                end in ``transforms.ToTensor()``.
        """
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        # Rescaling is off by default because ToTensor already maps pixel
        # values into [0, 1]; it is re-enabled per call for untransformed images.
        self.processor = ViTImageProcessor.from_pretrained('facebook/vit-msn-small', do_rescale=False)

    def __len__(self):
        """Return the number of image paths."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            Tuple ``(pixel_values, label)`` where ``pixel_values`` is the
            processor output with its leading batch axis removed.

        Raises:
            Any exception raised while opening or processing the image; it is
            logged with the offending path and re-raised.
        """
        path = self.image_paths[idx]
        try:
            # The context manager closes the file handle once the pixel data
            # has been copied by convert().
            with Image.open(path) as handle:
                img = handle.convert('RGB')

            if self.transform:
                img = self.transform(img).numpy()
                needs_rescale = False
            else:
                # A raw PIL image still carries 0-255 values, so the processor
                # has to rescale it itself.
                needs_rescale = True

            inputs = self.processor(images=img, return_tensors="pt", do_rescale=needs_rescale)
            # squeeze(0) removes only the batch axis; a bare squeeze() would
            # also drop any other axis that happens to have size 1.
            return inputs['pixel_values'].squeeze(0), self.labels[idx]
        except Exception as e:
            logger.error(f"Error loading image {path}: {str(e)}")
            raise
            
def prepare_data(data_dir):
    """Collect image paths and integer labels under ``data_dir`` and split them.

    Every sub-directory of ``data_dir`` except ``'polyps'`` is treated as one
    class. Classes and the files inside them are visited in sorted order, so
    the class-to-index mapping and the seeded split are the same on every run
    regardless of the order in which ``os.listdir`` reports entries.

    Args:
        data_dir: Directory containing one sub-directory per class.

    Returns:
        The result of ``train_test_split(image_paths, labels, test_size=0.2,
        random_state=42)``.

    Raises:
        Any exception raised while scanning or splitting; it is logged and
        re-raised.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg')
    try:
        classes = sorted(
            entry for entry in os.listdir(data_dir)
            if entry != 'polyps' and os.path.isdir(os.path.join(data_dir, entry))
        )
        print(f"Number of classes: {len(classes)}")

        image_paths = []
        labels = []
        for idx, class_name in enumerate(classes):
            class_path = os.path.join(data_dir, class_name)
            class_images = sorted(
                os.path.join(class_path, img) for img in os.listdir(class_path)
                if img.lower().endswith(image_suffixes)
            )
            image_paths.extend(class_images)
            labels.extend([idx] * len(class_images))

        logger.info(f"Found {len(image_paths)} images across {len(classes)} classes")
        return train_test_split(image_paths, labels, test_size=0.2, random_state=42)

    except Exception as e:
        logger.error(f"Error preparing data: {str(e)}")
        raise


In [None]:
class DatasetManager:
    """Discovers class folders under a data directory and builds data splits.

    Labels are indices into ``self.classes``.
    """

    # File suffixes (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, data_dir):
        """
        Args:
            data_dir: Directory containing one sub-directory per class.
        """
        self.data_dir = data_dir
        # Sorted because os.listdir returns entries in arbitrary order and the
        # position in this list is what defines each class label.
        self.classes = sorted(
            d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))
        )

    def _list_images(self, class_name):
        """Return the sorted image file paths inside one class folder."""
        class_path = os.path.join(self.data_dir, class_name)
        return sorted(
            os.path.join(class_path, img) for img in os.listdir(class_path)
            if img.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def stratified_split(self, test_size=80, seed=None):
        """Build test, train, validation and polyps splits.

        A fixed number of images per class, ``max(1, test_size // num_classes)``
        (capped at the class size), is sampled into the test set; this includes
        the ``'polyps'`` class. The remaining non-polyps images are divided
        80/20 into train and validation sets with ``train_test_split``
        (``random_state=42``). The remaining polyps images are returned
        separately; images already placed in the test set are excluded so that
        later fine-tuning on polyps cannot overlap the test set.

        Args:
            test_size: Approximate total number of test images across classes.
            seed: Optional seed for the test-set sampling. When None, the
                global ``random`` module state is used.

        Returns:
            Dict with keys ``'test_paths'``, ``'test_labels'``,
            ``'train_paths'``, ``'train_labels'``, ``'val_paths'``,
            ``'val_labels'`` and ``'polyps_paths'``.

        Raises:
            ValueError: If no class directories were found.
        """
        if not self.classes:
            raise ValueError(f"No class directories found in {self.data_dir}")

        rng = random.Random(seed) if seed is not None else random
        per_class_test = max(1, test_size // len(self.classes))

        test_paths, test_labels = [], []
        pool_paths, pool_labels = [], []   # non-polyps images left for train/val
        polyps_paths = []

        for label, class_name in enumerate(self.classes):
            images = self._list_images(class_name)

            chosen = rng.sample(images, min(per_class_test, len(images)))
            test_paths.extend(chosen)
            test_labels.extend([label] * len(chosen))

            # Set membership keeps the removal linear in the class size.
            chosen_set = set(chosen)
            leftover = [img for img in images if img not in chosen_set]

            if class_name == 'polyps':
                polyps_paths = leftover
            else:
                pool_paths.extend(leftover)
                pool_labels.extend([label] * len(leftover))

        train_paths, val_paths, train_labels, val_labels = train_test_split(
            pool_paths, pool_labels, test_size=0.2, random_state=42
        )

        return {
            'test_paths': test_paths,
            'test_labels': test_labels,
            'train_paths': train_paths,
            'train_labels': train_labels,
            'val_paths': val_paths,
            'val_labels': val_labels,
            'polyps_paths': polyps_paths
        }


In [None]:
def init_weights(m):
    """Re-initialise an ``nn.Linear`` module in place; ignore any other module.

    The weight matrix gets Xavier-uniform values and the bias, when the layer
    has one, is set to zero.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    if m.bias is None:
        return
    nn.init.zeros_(m.bias)

In [None]:
class SiameseNetwork(nn.Module):
    """``facebook/vit-msn-small`` backbone followed by a linear classification head.

    The same weights can be applied to one input or to a pair of inputs.
    """

    def __init__(self, num_classes=None):
        """
        Args:
            num_classes: Output size of the classification head. When None, it
                is inferred from the distinct values of
                ``split_data['train_labels']`` if a module-level ``split_data``
                exists.

        Raises:
            ValueError: If ``num_classes`` is None and no module-level
                ``split_data`` is defined.
        """
        super().__init__()

        # Resolve the head size before loading the backbone so a missing value
        # fails fast. A bare global lookup of ``split_data`` raises NameError
        # whenever that name is only a local of the caller, hence the explicit
        # check and error message.
        if num_classes is None:
            module_split = globals().get('split_data')
            if module_split is None:
                raise ValueError(
                    "num_classes must be provided when no module-level "
                    "'split_data' is defined"
                )
            num_classes = len(set(module_split['train_labels']))

        self.vit = ViTMSNForImageClassification.from_pretrained('facebook/vit-msn-small')
        self.fc = nn.Linear(self.vit.config.hidden_size, num_classes)

    def forward_one(self, x):
        """Return head logits for one batch.

        The head is applied to token index 0 of the last entry of the
        backbone's ``hidden_states`` (presumably the [CLS] token).
        """
        outputs = self.vit(x, output_hidden_states=True)
        return self.fc(outputs.hidden_states[-1][:, 0])

    def forward(self, x1, x2=None):
        """Return logits for ``x1``, or a ``(logits1, logits2)`` tuple when ``x2`` is given."""
        output1 = self.forward_one(x1)
        if x2 is not None:
            output2 = self.forward_one(x2)
            return output1, output2
        return output1

class EarlyStopping:
    """Tracks validation loss and raises a flag after repeated non-improvement.

    Attributes are plain fields: ``counter`` counts consecutive calls without
    improvement, ``best_loss`` is the reference loss and ``early_stop`` turns
    True once ``counter`` reaches ``patience``.
    """

    def __init__(self, patience=3, min_delta=0):
        self.patience, self.min_delta = patience, min_delta
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        """Record one validation loss and update the stop flag."""
        # The very first observation only establishes the reference value.
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        # A loss counts as stalled unless it is at most best_loss - min_delta.
        stalled = val_loss > self.best_loss - self.min_delta
        if not stalled:
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True


In [None]:
def train_model(model, train_loader, val_loader, save_path, num_epochs=10,
                lr=2e-5, early_stopping=None):
    """Train ``model`` with cross-entropy loss, checkpointing the best epoch.

    If a checkpoint already exists at ``save_path`` it is loaded and training
    is skipped. Otherwise the model is trained with AdamW; after every epoch
    the validation loss is measured and the model is saved whenever that loss
    improves. When training ends, the best checkpoint is loaded back so the
    returned weights match what is stored on disk.

    Args:
        model: Module mapping an image batch to class logits.
        train_loader: Sized iterable of ``(images, labels)`` training batches.
        val_loader: Sized iterable of ``(images, labels)`` validation batches.
        save_path: Checkpoint file path.
        num_epochs: Maximum number of epochs.
        lr: AdamW learning rate.
        early_stopping: Optional callable invoked with the average validation
            loss after each epoch; training stops once its ``early_stop``
            attribute is true.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: If training is required and ``val_loader`` is empty.
    """
    # Reuse an existing checkpoint instead of retraining
    if load_model(model, save_path):
        return model

    # Fail before the first training epoch rather than dividing by zero after it
    if len(val_loader) == 0:
        raise ValueError("val_loader is empty; cannot compute a validation loss")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    best_val_loss = float('inf')
    checkpoint_written = False
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for images, labels in tqdm(train_loader, desc=f'Epoch {epoch+1}'):
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                val_loss += criterion(outputs, labels).item()
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        # max(..., 1) only protects the reporting from empty inputs
        avg_train_loss = train_loss / max(len(train_loader), 1)
        avg_val_loss = val_loss / len(val_loader)
        val_accuracy = 100.0 * correct / max(total, 1)
        logger.info(
            f"Epoch {epoch+1}/{num_epochs} - train loss: {avg_train_loss:.4f}, "
            f"val loss: {avg_val_loss:.4f}, val accuracy: {val_accuracy:.2f}%"
        )

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            save_model(model, save_path, optimizer, epoch, best_val_loss)
            checkpoint_written = True

        if early_stopping is not None:
            early_stopping(avg_val_loss)
            if early_stopping.early_stop:
                logger.info(f"Early stopping triggered after epoch {epoch+1}")
                break

    # The in-memory weights are from the last epoch; restore the best ones so
    # a fresh run and a cached run hand back the same model.
    if checkpoint_written:
        load_model(model, save_path)

    return model

def few_shot_fine_tuning(model, polyps_paths, num_shots=5, polyps_label=0,
                         num_iterations=10, lr=1e-5):
    """Fine-tune ``model`` on a handful of randomly chosen polyps images.

    If ``models/fine_tuned_{num_shots}.pth`` already exists it is loaded and
    fine-tuning is skipped. Otherwise up to ``num_shots`` images are sampled,
    all labelled ``polyps_label``, and the model is trained on them for
    ``num_iterations`` passes before the result is saved to that path.

    Args:
        model: Module mapping an image batch to class logits.
        polyps_paths: Sequence of polyps image paths to sample from.
        num_shots: Maximum number of images to fine-tune on.
        polyps_label: Class index assigned to every selected image. It should
            be the index the polyps class has in the model's output; the
            default of 0 keeps the previous behaviour.
        num_iterations: Number of passes over the selected images.
        lr: AdamW learning rate.

    Returns:
        The same ``model`` object. It is returned unchanged, and nothing is
        saved, when there are no images to fine-tune on.
    """
    # Prepare save path
    save_path = f"models/fine_tuned_{num_shots}.pth"

    # Reuse an existing fine-tuned checkpoint
    if load_model(model, save_path):
        return model

    # Select a small number of polyps images
    selected_polyps = random.sample(polyps_paths, min(num_shots, len(polyps_paths)))
    if not selected_polyps:
        # Saving here would cache an untouched model as "fine-tuned"
        logger.warning("No polyps images available for few-shot fine-tuning; model left unchanged")
        return model

    fine_tune_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # The dataset carries the target label, so the loop can use the batch
    # labels directly instead of rebuilding them.
    fine_tune_dataset = KvasirDataset(
        selected_polyps, [polyps_label] * len(selected_polyps), fine_tune_transform
    )
    fine_tune_loader = DataLoader(fine_tune_dataset, batch_size=num_shots, shuffle=True)

    # Prepare for fine-tuning
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Fine-tuning loop
    for _ in range(num_iterations):
        for images, labels in fine_tune_loader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    # Save fine-tuned model
    save_model(model, save_path)

    return model

In [None]:
def main():
    """Run the whole pipeline: split, train, test, few-shot fine-tune, test again.

    Returns:
        Dict with the test accuracy (in percent) before and after fine-tuning,
        under ``'test_accuracy'`` and ``'fine_tuned_accuracy'``.
    """
    # Seed before the random test split. train_model reuses a cached
    # checkpoint, so an unseeded split on a later run could place images the
    # cached model was trained on into the test set.
    random.seed(42)
    torch.manual_seed(42)

    # Initialize dataset manager from the notebook-level configuration
    dataset_manager = DatasetManager(DATA_DIR)

    # Perform stratified splitting
    split_data = dataset_manager.stratified_split()

    # Shared preprocessing for all splits
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Prepare datasets
    train_dataset = KvasirDataset(split_data['train_paths'], split_data['train_labels'], transform)
    val_dataset = KvasirDataset(split_data['val_paths'], split_data['val_labels'], transform)
    test_dataset = KvasirDataset(split_data['test_paths'], split_data['test_labels'], transform)

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

    # Size the head for every discovered class. Labels are indices into
    # dataset_manager.classes and the test set contains polyps as well, so
    # counting only the distinct training labels can leave the largest label
    # index outside the head's output range.
    model = SiameseNetwork(num_classes=len(dataset_manager.classes)).to(DEVICE)
    os.makedirs(SAVE_DIR, exist_ok=True)

    # Train (or load the cached checkpoint)
    train_model(
        model, train_loader, val_loader,
        os.path.join(SAVE_DIR, "main_model.pth"),
        num_epochs=NUM_EPOCHS,
    )

    def test_model(model, test_loader):
        """Return top-1 accuracy in percent of ``model`` on ``test_loader``."""
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

        # An empty loader reports 0% instead of dividing by zero
        accuracy = 100. * correct / total if total else 0.0
        logger.info(f"Test Accuracy: {accuracy:.2f}%")
        return accuracy

    # Perform testing
    test_accuracy = test_model(model, test_loader)

    # Few-shot fine-tuning with polyps
    fine_tuned_model = few_shot_fine_tuning(model, split_data['polyps_paths'], num_shots=5)

    # Test fine-tuned model
    fine_tuned_accuracy = test_model(fine_tuned_model, test_loader)

    return {
        'test_accuracy': test_accuracy,
        'fine_tuned_accuracy': fine_tuned_accuracy,
    }

# Entry point: run the pipeline only when this module is the one being executed.
if __name__ == "__main__":
    main()