In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification


In [2]:
# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [3]:
# Define transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Match ViT input size
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize
])

In [4]:
# Load dataset
train_data = datasets.ImageFolder("C:/Users/karti/Desktop/Projects/DiagnoSphere/Dataset/SkinDisease/SkinDisease/train", transform=transform)
test_data = datasets.ImageFolder("C:/Users/karti/Desktop/Projects/DiagnoSphere/Dataset/SkinDisease/SkinDisease/test", transform=transform)


In [5]:
# Create DataLoaders
train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
test_loader = DataLoader(test_data, batch_size=8, shuffle=False)

In [6]:
# Load ViT model
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=23)
model.to(device)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermed

In [7]:
# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=5e-5)

In [8]:
from tqdm import tqdm  # Import tqdm for the progress bar

# Training loop with tqdm progress bar
epochs = 10
for epoch in range(epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=True)  # Add progress bar

    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Compute training accuracy
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Update tqdm description dynamically
        progress_bar.set_postfix(loss=loss.item())

    train_accuracy = 100 * correct / total
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}, Accuracy: {train_accuracy:.2f}%")


Epoch 1/10: 100%|██████████| 1738/1738 [22:05<00:00,  1.31it/s, loss=2.19] 


Epoch 1, Loss: 1.7974, Accuracy: 50.36%


Epoch 2/10: 100%|██████████| 1738/1738 [19:15<00:00,  1.50it/s, loss=0.59]  


Epoch 2, Loss: 0.8700, Accuracy: 75.68%


Epoch 3/10: 100%|██████████| 1738/1738 [18:47<00:00,  1.54it/s, loss=0.107] 


Epoch 3, Loss: 0.3990, Accuracy: 88.96%


Epoch 4/10: 100%|██████████| 1738/1738 [19:02<00:00,  1.52it/s, loss=0.0111] 


Epoch 4, Loss: 0.1905, Accuracy: 94.83%


Epoch 5/10: 100%|██████████| 1738/1738 [18:49<00:00,  1.54it/s, loss=0.251]  


Epoch 5, Loss: 0.1405, Accuracy: 96.01%


Epoch 6/10: 100%|██████████| 1738/1738 [18:45<00:00,  1.54it/s, loss=0.00584]


Epoch 6, Loss: 0.1091, Accuracy: 96.90%


Epoch 7/10: 100%|██████████| 1738/1738 [18:43<00:00,  1.55it/s, loss=0.0177] 


Epoch 7, Loss: 0.1003, Accuracy: 96.96%


Epoch 8/10: 100%|██████████| 1738/1738 [18:41<00:00,  1.55it/s, loss=0.068]  


Epoch 8, Loss: 0.0842, Accuracy: 97.30%


Epoch 9/10: 100%|██████████| 1738/1738 [18:38<00:00,  1.55it/s, loss=0.00234]


Epoch 9, Loss: 0.0759, Accuracy: 97.76%


Epoch 10/10: 100%|██████████| 1738/1738 [18:31<00:00,  1.56it/s, loss=0.00317] 

Epoch 10, Loss: 0.0679, Accuracy: 97.91%





In [9]:
# Save model
torch.save(model.state_dict(), "vit_skin_disease.pth")

In [10]:
# Evaluation
model.eval()
correct = 0
total = 0

In [11]:
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images).logits
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")

Test Accuracy: 76.07%
