In [None]:
!pip install transformers --quiet

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
data_dir = '/content/drive/MyDrive/Data'

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from transformers import ViTFeatureExtractor, ViTForImageClassification
from torchvision.transforms import ToTensor
from sklearn.metrics import classification_report
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from transformers import AutoImageProcessor
processor = AutoImageProcessor.from_pretrained('google/vit-base-patch16-224-in21k')
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=processor.image_mean, std=processor.image_std),
])
dataset = datasets.ImageFolder(root=data_dir, transform=transform)
from torch.utils.data import random_split

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

class_names = dataset.classes
print("Classes:", class_names)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


preprocessor_config.json:   0%|          | 0.00/160 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/502 [00:00<?, ?B/s]

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


Classes: ['advanced_glaucoma', 'early_glaucoma', 'normal_control']


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ViTForImageClassification.from_pretrained(
    'google/vit-base-patch16-224-in21k',
    num_labels=len(class_names)
)

model.to(device)


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

In [None]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=2e-5)
num_epochs = 50

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(pixel_values=inputs).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}")


Epoch 1/50, Loss: 0.041843512652215
Epoch 2/50, Loss: 0.032964936747120165
Epoch 3/50, Loss: 0.02228363368731852
Epoch 4/50, Loss: 0.006536963288314068
Epoch 5/50, Loss: 0.00602274661949382
Epoch 6/50, Loss: 0.005752677412834974
Epoch 7/50, Loss: 0.005624261032491445
Epoch 8/50, Loss: 0.005618429189034475
Epoch 9/50, Loss: 0.005430572046838605
Epoch 10/50, Loss: 0.005420537596831146
Epoch 11/50, Loss: 0.0053441144737940375
Epoch 12/50, Loss: 0.005199632214871832
Epoch 13/50, Loss: 0.005140201987113613
Epoch 14/50, Loss: 0.0050395686486855345
Epoch 15/50, Loss: 0.005113957193083106
Epoch 16/50, Loss: 0.00500710349670277
Epoch 17/50, Loss: 0.004892883513373538
Epoch 18/50, Loss: 0.004914437013212591
Epoch 19/50, Loss: 0.004858872706207853
Epoch 20/50, Loss: 0.00471248709781764
Epoch 21/50, Loss: 0.004818456251180181
Epoch 22/50, Loss: 0.004702532278875319
Epoch 23/50, Loss: 0.004667930005757042
Epoch 24/50, Loss: 0.004608389485675173
Epoch 25/50, Loss: 0.004500812105386733
Epoch 26/50, L

In [27]:
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score


model.eval()

all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        if isinstance(outputs, dict) and 'logits' in outputs:
            logits = outputs['logits']
        elif hasattr(outputs, 'logits'):
            logits = outputs.logits
        else:
            logits = outputs

        _, preds = torch.max(logits, 1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
accuracy = accuracy_score(all_labels, all_preds)
print(f"Validation Accuracy: {accuracy * 100:.2f}%")

Validation Accuracy: 77.98%


In [34]:
from PIL import Image
import torchvision.transforms as T
from google.colab import files

uploaded = files.upload()

for fname in uploaded.keys():
    image = Image.open(fname).convert('RGB')
    image_transformed = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(pixel_values=image_transformed).logits
        predicted_class = torch.argmax(outputs, dim=1).item()

    print(f"Predicted class for {fname}: {class_names[predicted_class]}")

Saving class_0_257.png to class_0_257.png
Predicted class for class_0_257.png: normal_control
