In [1]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification, ViTImageProcessor
import cv2
import albumentations as A
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from tqdm.auto import tqdm
import logging

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Logging and device configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
BATCH_SIZE = 8
NUM_EPOCHS = 10
LEARNING_RATE = 2e-4
DATA_DIR = r"G:\USER\Documents\GitHub\Masters-Project\Datasets\kvasir-dataset-v2"
SAVE_DIR = "medical_models"
MODEL_PATH = os.path.join(SAVE_DIR, 'best_model.pth')

# Ensure save directory exists
os.makedirs(SAVE_DIR, exist_ok=True)

In [3]:
class ROIDetector:
    def __init__(self, min_area=1000):
        """
        Initialize ROI detector with OpenCV contour-based approach
        
        Args:
            min_area (int): Minimum area of contour to be considered a valid region
        """
        self.min_area = min_area
    
    def detect_roi(self, image):
        """
        Detect regions of interest using contour detection
        
        Args:
            image (np.ndarray): Input image
        
        Returns:
            list of tuples: Bounding boxes of ROIs [(x, y, w, h), ...]
        """
        # Convert to grayscale
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        
        # Apply Gaussian blur
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        
        # Threshold the image
        _, thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        
        # Find contours
        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        
        # Filter contours by area
        valid_contours = [
            cv2.boundingRect(cnt) for cnt in contours 
            if cv2.contourArea(cnt) >= self.min_area
        ]
        
        return valid_contours
    
    def crop_to_roi(self, image, roi):
        """
        Crop image to specified region of interest
        
        Args:
            image (np.ndarray): Original image
            roi (tuple): Bounding box (x, y, w, h)
        
        Returns:
            np.ndarray: Cropped image
        """
        x, y, w, h = roi
        return image[y:y+h, x:x+w]


In [4]:
class MedicalImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.roi_detector = ROIDetector()
        
        # Augmentation techniques
        self.augmentation = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.RandomBrightnessContrast(p=0.3),
            A.GaussNoise(p=0.2),
            A.Blur(blur_limit=3, p=0.2),
        ])
        
        self.processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Read image
        img = cv2.imread(self.image_paths[idx])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        # Detect ROIs
        rois = self.roi_detector.detect_roi(img)
        
        # If ROIs found, crop to the largest region
        if rois:
            # Sort ROIs by area and select the largest
            largest_roi = max(rois, key=lambda r: r[2] * r[3])
            img = self.roi_detector.crop_to_roi(img, largest_roi)
        
        # Apply augmentation
        if self.transform:
            augmented = self.augmentation(image=img)
            img = augmented['image']
        
        # Process for ViT
        inputs = self.processor(images=img, return_tensors="pt", do_rescale=False)
        
        return inputs['pixel_values'].squeeze(), self.labels[idx]

def prepare_dataset(data_dir):
    # Collect image paths and labels
    classes = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))]
    
    image_paths = []
    labels = []
    
    for idx, class_name in enumerate(classes):
        class_path = os.path.join(data_dir, class_name)
        class_images = [os.path.join(class_path, img) for img in os.listdir(class_path) 
                        if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
        image_paths.extend(class_images)
        labels.extend([idx] * len(class_images))
    
    # Split the dataset
    train_paths, test_paths, train_labels, test_labels = train_test_split(
        image_paths, labels, test_size=0.2, stratify=labels, random_state=42
    )
    
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        train_paths, train_labels, test_size=0.2, stratify=train_labels, random_state=42
    )
    
    return {
        'train_paths': train_paths,
        'train_labels': train_labels,
        'val_paths': val_paths,
        'val_labels': val_labels,
        'test_paths': test_paths,
        'test_labels': test_labels,
        'classes': classes
    }

In [5]:
class EarlyStopping:
    def __init__(self, patience=3, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

In [6]:
def train_vit_model(model, train_loader, val_loader, classes, patience=3):
    if os.path.exists(MODEL_PATH):
        logger.info(f"Existing model found at {MODEL_PATH}. Skipping training.")
        model.load_state_dict(torch.load(MODEL_PATH))
        return model

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()
    best_val_loss = float('inf')
    early_stopping = EarlyStopping(patience=patience)
    
    for epoch in range(NUM_EPOCHS):
        model.train()
        train_loss = 0.0
        
        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS} [Train]')
        for images, labels in train_pbar:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_pbar.set_postfix({'loss': f'{train_loss/len(train_loader):.4f}'})
        
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        
        val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS} [Val]')
        with torch.no_grad():
            for images, labels in val_pbar:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images).logits
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                val_pbar.set_postfix({
                    'loss': f'{val_loss/len(val_loader):.4f}',
                    'acc': f'{100*correct/total:.2f}%'
                })
        
        avg_val_loss = val_loss/len(val_loader)
        early_stopping(avg_val_loss)
        
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), MODEL_PATH)
        
        if early_stopping.early_stop:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break
    
    return model


In [7]:
def test_model(model, test_loader, classes):
    model.eval()
    all_preds = []
    all_labels = []
    
    test_pbar = tqdm(test_loader, desc='Testing')
    with torch.no_grad():
        for images, labels in test_pbar:
            images_for_vit = images.to(DEVICE)
            outputs = model(images_for_vit).logits
            _, preds = torch.max(outputs, 1)
            
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            
    # Classification Report   
    report = classification_report(all_labels, all_preds, target_names=classes, digits=4)
    print("Classification Report:\n", report)
    
    # Confusion Matrix
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=classes, yticklabels=classes)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, 'confusion_matrix.png'))
    plt.close()
    
    return all_preds, all_labels

In [8]:
def main():
    # Set random seeds for reproducibility
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    
    # Prepare dataset
    dataset = prepare_dataset(DATA_DIR)
    
    # Transformations
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    
    # Create datasets
    train_dataset = MedicalImageDataset(dataset['train_paths'], dataset['train_labels'], transform)
    val_dataset = MedicalImageDataset(dataset['val_paths'], dataset['val_labels'], transform)
    test_dataset = MedicalImageDataset(dataset['test_paths'], dataset['test_labels'], transform)
    
    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
    
    # Initialize ViT model
    model = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224", 
        num_labels=len(dataset['classes']),
        attn_implementation="sdpa", 
        torch_dtype=torch.float32,
        ignore_mismatched_sizes=True
    ).to(DEVICE)
    
    # Train or load the ViT model
    trained_model = train_vit_model(model, train_loader, val_loader, dataset['classes'])
    
    # Test the model
    test_model(trained_model, test_loader, dataset['classes'])

if __name__ == "__main__":
    main()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([8]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([8, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/10 [Train]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 640/640 [04:46<00:00,  2.24it/s, loss=1.0379]
Epoch 1/10 [Val]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 160/160 [01:02<00:00,  2.55it/s, loss=0.7983, acc=59.06%]
Epoch 2/10 [Train]: 100%|██████████████████

Early stopping triggered after 5 epochs


Testing: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:19<00:00,  2.52it/s]


Classification Report:
                         precision    recall  f1-score   support

    dyed-lifted-polyps     0.8182    0.3600    0.5000       200
dyed-resection-margins     0.6080    0.9150    0.7305       200
           esophagitis     0.5686    0.7250    0.6374       200
          normal-cecum     0.9296    0.3300    0.4871       200
        normal-pylorus     0.9401    0.7850    0.8556       200
         normal-z-line     0.6541    0.6050    0.6286       200
                polyps     0.4136    0.8850    0.5637       200
    ulcerative-colitis     0.7238    0.3800    0.4984       200

              accuracy                         0.6231      1600
             macro avg     0.7070    0.6231    0.6126      1600
          weighted avg     0.7070    0.6231    0.6126      1600

