In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import os
from transformers import ViTForImageClassification, ViTFeatureExtractor

class RetinalDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            img_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied on a sample.
        """
        self.data_frame = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        self.feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224')

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        img_name = os.path.join(self.img_dir, self.data_frame.iloc[idx]['id_code'] + '.png')
        image = Image.open(img_name).convert('RGB')
        diagnosis = self.data_frame.iloc[idx]['diagnosis']

        # Let the feature_extractor handle all the image preprocessing
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        return inputs['pixel_values'].squeeze(), torch.tensor(diagnosis, dtype=torch.long)

class RetinalClassifier:
    def __init__(self, num_classes=5):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = ViTForImageClassification.from_pretrained(
            'google/vit-base-patch16-224',
            num_labels=num_classes,
            ignore_mismatched_sizes=True
        ).to(self.device)
        
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ])

    def train(self, train_loader, num_epochs=10):
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-5)
        
        for epoch in range(num_epochs):
            self.model.train()
            running_loss = 0.0
            correct = 0
            total = 0
            
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                
                optimizer.zero_grad()
                outputs = self.model(inputs).logits
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                
                running_loss += loss.item()
                
                # Calculate accuracy
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            
            epoch_loss = running_loss/len(train_loader)
            epoch_accuracy = 100 * correct / total
            print(f'Epoch {epoch+1}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.2f}%')

    def predict(self, image_path):
        self.model.eval()
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image).unsqueeze(0).to(self.device)
        
        with torch.no_grad():
            outputs = self.model(image).logits
            _, predicted = torch.max(outputs, 1)
            
        return predicted.item()

# Usage example
if __name__ == "__main__":
    # Initialize dataset and dataloader
    import torch
    from torch.utils.data import DataLoader, random_split
    from sklearn.model_selection import train_test_split

    # Load full dataset
    dataset = RetinalDataset(csv_file='APTOS 2019 Blindness Detection/train.csv', img_dir='APTOS 2019 Blindness Detection//train_images')

    # Split dataset (80% train, 20% test)
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Initialize and train the classifier
    classifier = RetinalClassifier(num_classes=5)
    classifier.train(train_loader)



In [None]:
# Save the model
save_path = 'best_model.pth'
torch.save({
    'model_state_dict': classifier.model.state_dict(),
    'model_config': classifier.model.config
}, save_path)
print(f"Model saved to {save_path}")

In [None]:
# Import necessary libraries
import torch
from sklearn.metrics import classification_report, confusion_matrix

classifier = RetinalClassifier(num_classes=5)

# Path to your saved model
model_path = '/Users/devshah/Documents/WorkSpace/University/year 3/CSC490/Zero-Shot-Object-Tracking-FPS/classifier_model/Best Model.pth'

# Load the saved model with map_location to handle CPU/GPU differences
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
classifier.model.load_state_dict(checkpoint['model_state_dict'])
print(f"Model loaded from {model_path}")

# Evaluation loop
classifier.model.eval()  # Set the model to evaluation mode
correct = 0
total = 0
predictions = []
actual = []

with torch.no_grad():  # No need to track gradients during evaluation
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(classifier.device), labels.to(classifier.device)
        outputs = classifier.model(inputs).logits
        _, predicted = torch.max(outputs.data, 1)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
        predictions.extend(predicted.cpu().numpy())
        actual.extend(labels.cpu().numpy())

# Calculate and print accuracy
accuracy = 100 * correct / total
print(f'Test Accuracy: {accuracy:.2f}%')

# Print detailed metrics
print("\nDetailed Classification Report:")
print(classification_report(actual, predictions))

print("\nConfusion Matrix:")
print(confusion_matrix(actual, predictions))

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([5]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([5, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded from /Users/devshah/Documents/WorkSpace/University/year 3/CSC490/Zero-Shot-Object-Tracking-FPS/classifier_model/Best Model.pth
Test Accuracy: 95.36%

Detailed Classification Report:
              precision    recall  f1-score   support

           0       1.00      0.99      1.00       353
           1       0.88      0.92      0.90        73
           2       0.93      0.94      0.94       205
           3       0.87      0.85      0.86        39
           4       0.92      0.87      0.89        63

    accuracy                           0.95       733
   macro avg       0.92      0.91      0.92       733
weighted avg       0.95      0.95      0.95       733


Confusion Matrix:
[[351   2   0   0   0]
 [  1  67   5   0   0]
 [  0   7 193   4   1]
 [  0   0   2  33   4]
 [  0   0   7   1  55]]


: 