In [5]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import os
from transformers import ViTForImageClassification, ViTFeatureExtractor

class RetinalDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            img_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.data_frame = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        self.feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.img_dir, self.data_frame.iloc[idx]['id_code'] + '.png')
        image = Image.open(img_name).convert('RGB')
        diagnosis = self.data_frame.iloc[idx]['diagnosis']

        # Let the feature_extractor handle all the image preprocessing
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        return inputs['pixel_values'].squeeze(), torch.tensor(diagnosis, dtype=torch.long)

class RetinalClassifier:
    def __init__(self, num_classes=5):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224',
            num_labels=num_classes,
            ignore_mismatched_sizes=True
        ).to(self.device)
        
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ])

    def train(self, train_loader, num_epochs=10):
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
        
        for epoch in range(num_epochs):
            self.model.train()
            running_loss = 0.0
            correct = 0
            total = 0
            
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                
                optimizer.zero_grad()
                outputs = self.model(inputs).logits
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                
                running_loss += loss.item()
                
                # Calculate accuracy
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            
            epoch_loss = running_loss/len(train_loader)
            epoch_accuracy = 100 * correct / total
            print(f'Epoch {epoch+1}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%')

    def predict(self, image_path):
        self.model.eval()
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image).unsqueeze(0).to(self.device)
        
        with torch.no_grad():
            outputs = self.model(image).logits
            _, predicted = torch.max(outputs, 1)
            
        return predicted.item()

# Usage example
if __name__ == "__main__":
    # Initialize dataset and dataloader
    dataset = RetinalDataset(
        csv_file='/Users/devshah/Documents/WorkSpace/University/year 3/CSC490/Zero-Shot-Object-Tracking-FPS/APTOS 2019 Blindness Detection/train.csv',
        img_dir='/Users/devshah/Documents/WorkSpace/University/year 3/CSC490/Zero-Shot-Object-Tracking-FPS/APTOS 2019 Blindness Detection/train_images',

    )
    
    train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
    
    # Initialize and train the classifier
    classifier = RetinalClassifier(num_classes=5)
    classifier.train(train_loader)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([5]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([5, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1, Loss: 0.6623, Accuracy: 75.83%
Epoch 2, Loss: 0.4276, Accuracy: 84.43%
Epoch 3, Loss: 0.3179, Accuracy: 89.08%
Epoch 4, Loss: 0.2030, Accuracy: 93.23%
Epoch 5, Loss: 0.1262, Accuracy: 96.50%
Epoch 6, Loss: 0.0796, Accuracy: 97.79%
Epoch 7, Loss: 0.0619, Accuracy: 98.28%
Epoch 8, Loss: 0.0462, Accuracy: 98.50%
Epoch 9, Loss: 0.0378, Accuracy: 98.53%
Epoch 10, Loss: 0.0486, Accuracy: 97.98%


In [7]:
# Save the model
save_path = 'best_model.pth'
torch.save({
    'model_state_dict': classifier.model.state_dict(),
    'model_config': classifier.model.config
}, save_path)
print(f"Model saved to {save_path}")

Model saved to best_model.pth


In [8]:
# Initialize test dataset and dataloader
test_dataset = RetinalDataset(
    csv_file='/Users/devshah/Documents/WorkSpace/University/year 3/CSC490/Zero-Shot-Object-Tracking-FPS/APTOS 2019 Blindness Detection/test.csv',
    img_dir='/Users/devshah/Documents/WorkSpace/University/year 3/CSC490/Zero-Shot-Object-Tracking-FPS/APTOS 2019 Blindness Detection/test_images',
)

test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Testing loop
classifier.model.eval()
correct = 0
total = 0
predictions = []
actual = []

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(classifier.device), labels.to(classifier.device)
        outputs = classifier.model(inputs).logits
        _, predicted = torch.max(outputs.data, 1)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        predictions.extend(predicted.cpu().numpy())
        actual.extend(labels.cpu().numpy())

accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')

# Optional: Detailed metrics
from sklearn.metrics import classification_report, confusion_matrix
print("\nDetailed Classification Report:")
print(classification_report(actual, predictions))

print("\nConfusion Matrix:")
print(confusion_matrix(actual, predictions))



KeyError: 'diagnosis'