In [1]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification, ViTImageProcessor
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import logging
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
print(torch.cuda.is_available())

True


In [3]:
# Logging and device configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
BATCH_SIZE = 8
NUM_EPOCHS = 10
LEARNING_RATE = 2e-4
DATA_DIR =  r"..\Datasets\kvasir-dataset-v2"
SAVE_DIR = "vit_models"
MODEL_PATH = os.path.join(SAVE_DIR, 'best_model.pth')

# Ensure save directory exists
os.makedirs(SAVE_DIR, exist_ok=True)

In [4]:
class KvasirDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform
        self.processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        
        inputs = self.processor(images=img, return_tensors="pt",do_rescale=False)
        return inputs['pixel_values'].squeeze(), self.labels[idx]

def prepare_dataset(data_dir):
    # Collect image paths and labels
    classes = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))]
    
    image_paths = []
    labels = []
    
    for idx, class_name in enumerate(classes):
        class_path = os.path.join(data_dir, class_name)
        class_images = [os.path.join(class_path, img) for img in os.listdir(class_path) 
                        if img.lower().endswith(('.png', '.jpg', '.jpeg'))]
        image_paths.extend(class_images)
        labels.extend([idx] * len(class_images))
    
    # Split the dataset
    train_paths, test_paths, train_labels, test_labels = train_test_split(
        image_paths, labels, test_size=0.2, stratify=labels, random_state=42
    )
    
    train_paths, val_paths, train_labels, val_labels = train_test_split(
        train_paths, train_labels, test_size=0.2, stratify=train_labels, random_state=42
    )
    print(f"Dataset Prepared!")
    return {
        'train_paths': train_paths,
        'train_labels': train_labels,
        'val_paths': val_paths,
        'val_labels': val_labels,
        'test_paths': test_paths,
        'test_labels': test_labels,
        'classes': classes
    }


In [5]:
class EarlyStopping:
    def __init__(self, patience=3, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

def train_model(model, train_loader, val_loader, classes, patience=3):
    if os.path.exists(MODEL_PATH):
        logger.info(f"Existing model found at {MODEL_PATH}. Skipping training.")
        model.load_state_dict(torch.load(MODEL_PATH))
        return model

    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()
    best_val_loss = float('inf')
    early_stopping = EarlyStopping(patience=patience)
    
    for epoch in range(NUM_EPOCHS):
        model.train()
        train_loss = 0.0
        
        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS} [Train]')
        for images, labels in train_pbar:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(images).logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_pbar.set_postfix({'loss': f'{train_loss/len(train_loader):.4f}'})
        
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        
        val_pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{NUM_EPOCHS} [Val]')
        with torch.no_grad():
            for images, labels in val_pbar:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images).logits
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                val_pbar.set_postfix({
                    'loss': f'{val_loss/len(val_loader):.4f}',
                    'acc': f'{100*correct/total:.2f}%'
                })
        
        avg_val_loss = val_loss/len(val_loader)
        early_stopping(avg_val_loss)
        
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), MODEL_PATH)
        
        if early_stopping.early_stop:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break
    
    return model

In [6]:
def test_model(model, test_loader, classes):
    model.eval()
    all_preds = []
    all_labels = []
    
    test_pbar = tqdm(test_loader, desc='Testing')
    with torch.no_grad():
        for images, labels in test_pbar:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = model(images).logits
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # Classification Report
    report = classification_report(all_labels, all_preds, target_names=classes, digits=4)
    print("Classification Report:\n", report)
    
    # Confusion Matrix
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=classes, yticklabels=classes)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, 'confusion_matrix.png'))
    plt.close()
    
    # Performance Visualization
    class_accuracy = cm.diagonal() / cm.sum(axis=1)
    plt.figure(figsize=(10, 6))
    plt.bar(classes, class_accuracy * 100)
    plt.title('Accuracy per Class')
    plt.xlabel('Classes')
    plt.ylabel('Accuracy (%)')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(os.path.join(SAVE_DIR, 'class_performance.png'))
    plt.close()
    
    return all_preds, all_labels

In [7]:
def main():
    # Set random seeds for reproducibility
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    
    # Prepare dataset
    dataset = prepare_dataset(DATA_DIR)
    
    # Transformations
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    
    # Create datasets
    train_dataset = KvasirDataset(dataset['train_paths'], dataset['train_labels'], transform)
    val_dataset = KvasirDataset(dataset['val_paths'], dataset['val_labels'], transform)
    test_dataset = KvasirDataset(dataset['test_paths'], dataset['test_labels'], transform)
    
    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
    
    # Initialize model
    model = ViTForImageClassification.from_pretrained(
        "google/vit-base-patch16-224", 
        num_labels=len(dataset['classes']),
        attn_implementation="sdpa",
        torch_dtype=torch.float32,
        ignore_mismatched_sizes=True
    ).to(DEVICE)
    
    # Train or load the model
    trained_model = train_model(model, train_loader, val_loader, dataset['classes'])
    
    # Test the model
    test_model(trained_model, test_loader, dataset['classes'])

if __name__ == "__main__":
    main()

Dataset Prepared!


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([8]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([8, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/10 [Train]: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 640/640 [02:50<00:00,  3.75it/s, loss=0.4585]
Epoch 1/10 [Val]: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 160/160 [00:34<00:00,  4.59it/s, loss=0.2676, acc=90.62%]
Epoch 2/10 [Train]: 100%|██████████████████

Early stopping triggered after 4 epochs


Testing: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:12<00:00,  2.74it/s]


Classification Report:
                         precision    recall  f1-score   support

    dyed-lifted-polyps     0.9038    0.9400    0.9216       200
dyed-resection-margins     0.9442    0.9300    0.9370       200
           esophagitis     0.8077    0.8400    0.8235       200
          normal-cecum     0.8622    0.9700    0.9129       200
        normal-pylorus     0.9948    0.9600    0.9771       200
         normal-z-line     0.8283    0.8200    0.8241       200
                polyps     0.8788    0.8700    0.8744       200
    ulcerative-colitis     0.9364    0.8100    0.8686       200

              accuracy                         0.8925      1600
             macro avg     0.8945    0.8925    0.8924      1600
          weighted avg     0.8945    0.8925    0.8924      1600

