In [None]:
import os
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, random_split
from transformers import ViTFeatureExtractor, ViTMSNForImageClassification
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm
import numpy as np

In [None]:
# Location of the Kvasir v2 image folders, relative to the notebook.
# Built with os.path.join so the same cell works on Windows and POSIX.
data_dir = os.path.join("..", "Datasets", "kvasir-dataset-v2")

# Classes used for the base training stage; any other folder found under
# data_dir is excluded from train_subset.
train_classes = ['dyed-lifted-polyps', 'dyed-resection-margins', 'esophagitis', 'normal-cecum', 'normal-pylorus', 'normal-z-line', 'ulcerative-colitis']

# Every image is resized to one fixed resolution so that samples can be stacked
# into batches and match the ViT input size (224 x 224 for ViT-MSN checkpoints).
# NOTE(review): pixel values are only scaled to [0, 1] by ToTensor; no mean/std
# normalisation is applied -- confirm whether the checkpoint expects one.
image_size = 224
preprocess = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])

dataset = datasets.ImageFolder(root=data_dir, transform=preprocess)

# dataset.targets holds one class index per sample, so the indices can be
# selected without decoding a single image (enumerate(dataset) would load and
# transform every file just to read its label).
train_class_names = set(train_classes)
train_indices = [
    index
    for index, label in enumerate(dataset.targets)
    if dataset.classes[label] in train_class_names
]
train_subset = torch.utils.data.Subset(dataset, train_indices)


In [None]:
# 80/20 train/test split of the base-class subset.
train_size = int(0.8 * len(train_subset))
test_size = len(train_subset) - train_size

# A seeded generator keeps the split identical every time this cell runs.
# Without it, re-running the cell reshuffles membership, so a model that has
# already been trained would be evaluated on images it has seen.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(
    train_subset, [train_size, test_size], generator=split_generator
)

# Training batches are reshuffled each epoch; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [None]:
# ViT-MSN weights are published under the `facebook/` organisation on the
# Hugging Face Hub; the small variant is used here.
# NOTE(review): switch to 'facebook/vit-msn-base' (or larger) if that was the
# intended backbone.
checkpoint = 'facebook/vit-msn-small'

feature_extractor = ViTFeatureExtractor.from_pretrained(checkpoint)

# One output logit per base training class; the classification head is newly
# initialised because the checkpoint only provides the backbone weights.
model = ViTMSNForImageClassification.from_pretrained(checkpoint, num_labels=len(train_classes))

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Assigning the result keeps the notebook from printing the whole module tree.
model = model.to(device)


In [None]:
# Base training stage: fine-tune the classifier on the base classes, evaluate
# on the held-out split after every epoch, and stop early once the training
# loss has failed to improve for `patience` consecutive epochs.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

num_epochs = 10
patience = 3
best_loss = np.inf
early_stopping_counter = 0

# ImageFolder numbers classes over *all* folders it finds, so a base class can
# carry an index >= len(train_classes) when extra folders exist (e.g. a
# 'polyps' folder pushes 'ulcerative-colitis' to index 7). Map every folder
# index to its position in train_classes so targets always fit the model head.
label_lookup = torch.full((len(dataset.classes),), -1, dtype=torch.long)
for new_index, class_name in enumerate(train_classes):
    label_lookup[dataset.class_to_idx[class_name]] = new_index
report_labels = list(range(len(train_classes)))

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    # Wrap the loader in a fresh progress bar without rebinding train_loader;
    # rebinding would nest one tqdm inside another on every epoch.
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
    for images, labels in progress:
        images = images.to(device)
        labels = label_lookup[labels].to(device)

        optimizer.zero_grad()
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        progress.set_postfix({'Loss': loss.item()})

    # Mean of the per-batch losses for this epoch.
    epoch_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

    # Evaluate before the early-stopping check so the epoch that triggers the
    # stop is still reported.
    model.eval()
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            outputs = model(images).logits
            all_predictions.extend(outputs.argmax(dim=1).cpu().tolist())
            all_labels.extend(label_lookup[labels].tolist())

    accuracy = accuracy_score(all_labels, all_predictions)
    print(f"Test Accuracy: {accuracy}")
    # Passing `labels` pins the row order to train_classes even when a class is
    # absent from this epoch's labels or predictions.
    print(classification_report(all_labels, all_predictions, labels=report_labels, target_names=train_classes))

    # Early stopping is driven by the training loss.
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1
        if early_stopping_counter >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break


In [None]:
# Few-shot stage: teach the trained model the additional 'polyps' class.

# Select the polyps samples from the label list instead of iterating the
# dataset, which would decode every image just to read its label.
polyps_class_index = dataset.class_to_idx['polyps']
polyps_indices = [
    index for index, label in enumerate(dataset.targets) if label == polyps_class_index
]
polyps_subset = torch.utils.data.Subset(dataset, polyps_indices)
few_shot_size = 5
# NOTE(review): few_shot_size is the batch size; the loader still walks every
# polyps image each epoch -- confirm whether only a handful was intended.
few_shot_loader = DataLoader(polyps_subset, batch_size=few_shot_size, shuffle=True)

# The base head has one logit per entry of train_classes, so polyps has no
# output to be trained against. Append one logit and keep the learned rows.
# The size check makes the cell safe to re-run without growing the head again.
polyps_label = len(train_classes)
base_head = model.classifier
if base_head.out_features == polyps_label:
    extended_head = torch.nn.Linear(base_head.in_features, polyps_label + 1)
    with torch.no_grad():
        extended_head.weight[:polyps_label].copy_(base_head.weight)
        extended_head.bias[:polyps_label].copy_(base_head.bias)
    model.classifier = extended_head.to(device)
    model.num_labels = polyps_label + 1
    model.config.num_labels = polyps_label + 1

# A fresh optimizer is required so the new head parameters are actually updated.
few_shot_optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    # Separate progress-bar variable: rebinding few_shot_loader would nest tqdm
    # wrappers on every epoch.
    progress = tqdm(few_shot_loader, desc=f"Fine-tuning Epoch {epoch+1}/{num_epochs}", leave=False)
    for images, _ in progress:
        images = images.to(device)
        # Class-index targets. An all-zero float target of the logits' shape is
        # read by CrossEntropyLoss as class probabilities and yields a loss
        # (and gradient) of exactly zero.
        targets = torch.full((images.size(0),), polyps_label, dtype=torch.long, device=device)

        few_shot_optimizer.zero_grad()
        outputs = model(images).logits
        loss = criterion(outputs, targets)
        loss.backward()
        few_shot_optimizer.step()

        running_loss += loss.item()
        progress.set_postfix({'Loss': loss.item()})

    print(f"Fine-tuning Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(few_shot_loader)}")
