In [3]:
import os
import torch
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import classification_report
from tqdm import tqdm
from pathlib import Path

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Specify directories
#train_dir = "./Training"
#test_dir = "./Testing"

train_dir = "../data/Training"
test_dir = "../data/Testing"

# Transformations for the dataset
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize
])

# Load datasets
train_dataset = datasets.ImageFolder(root=train_dir, transform=transform)
test_dataset = datasets.ImageFolder(root=test_dir, transform=transform)

# Split train dataset into training and validation
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Load pre-trained Vision Transformer model
vit_model = models.vision_transformer.vit_b_16(pretrained=True)

# Freeze all layers except the classifier
for param in vit_model.parameters():
    param.requires_grad = False

# Replace the classifier for the correct number of classes (4 classes in this case)
num_features = vit_model.heads[0].in_features  # Access the in_features correctly
vit_model.heads = nn.Linear(num_features, 4)  # Replace with a new classifier for 4 classes
vit_model = vit_model.to(device)

# Only the classifier's parameters will be updated
optimizer = optim.Adam(vit_model.heads.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Fine-tuning loop
epochs = 20  # Fine-tune
for epoch in range(epochs):
    vit_model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        images, labels = images.to(device), labels.to(device)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = vit_model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Track loss and accuracy
        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {100 * correct / total:.2f}%")




Epoch 1/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [07:34<00:00,  3.18s/it]


Epoch [1/20], Loss: 0.5021, Accuracy: 82.97%


Epoch 2/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [07:33<00:00,  3.17s/it]


Epoch [2/20], Loss: 0.2767, Accuracy: 90.76%


Epoch 3/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [07:30<00:00,  3.15s/it]


Epoch [3/20], Loss: 0.2256, Accuracy: 92.41%


Epoch 4/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [07:47<00:00,  3.27s/it]


Epoch [4/20], Loss: 0.1937, Accuracy: 93.63%


Epoch 5/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [07:27<00:00,  3.13s/it]


Epoch [5/20], Loss: 0.1728, Accuracy: 94.07%


Epoch 6/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [08:11<00:00,  3.44s/it]


Epoch [6/20], Loss: 0.1568, Accuracy: 94.66%


Epoch 7/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [08:36<00:00,  3.61s/it]


Epoch [7/20], Loss: 0.1443, Accuracy: 95.27%


Epoch 8/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [07:59<00:00,  3.35s/it]


Epoch [8/20], Loss: 0.1336, Accuracy: 95.73%


Epoch 9/20: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [08:04<00:00,  3.39s/it]


Epoch [9/20], Loss: 0.1247, Accuracy: 96.24%


Epoch 10/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [07:41<00:00,  3.23s/it]


Epoch [10/20], Loss: 0.1165, Accuracy: 96.30%


Epoch 12/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [07:38<00:00,  3.20s/it]


Epoch [12/20], Loss: 0.1041, Accuracy: 96.67%


Epoch 13/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [07:38<00:00,  3.21s/it]


Epoch [13/20], Loss: 0.0991, Accuracy: 97.13%


Epoch 14/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [08:01<00:00,  3.37s/it]


Epoch [14/20], Loss: 0.0937, Accuracy: 97.29%


Epoch 15/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [07:38<00:00,  3.21s/it]


Epoch [15/20], Loss: 0.0902, Accuracy: 97.35%


Epoch 16/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [07:55<00:00,  3.33s/it]


Epoch [16/20], Loss: 0.0859, Accuracy: 97.61%


Epoch 17/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [07:35<00:00,  3.19s/it]


Epoch [17/20], Loss: 0.0826, Accuracy: 97.75%


Epoch 18/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [07:24<00:00,  3.11s/it]


Epoch [18/20], Loss: 0.0775, Accuracy: 97.94%


Epoch 19/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [07:29<00:00,  3.14s/it]


Epoch [19/20], Loss: 0.0745, Accuracy: 98.07%


Epoch 20/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [07:26<00:00,  3.12s/it]

Epoch [20/20], Loss: 0.0711, Accuracy: 98.14%





In [5]:
# Save the fine-tuned model
import os
checkpoint_path = "/temp/tumor-detection/checkpoints/vit/vit_finetuned_full.pth"
ckpt_parent_path = "/temp/tumor-detection/checkpoints/vit/"
parent_dir = os.path.expanduser('~') + ckpt_parent_path
Path(parent_dir).mkdir(parents=True, exist_ok = True)


#checkpoint_dir = os.path.dirname(checkpoint_path)
save_dir = os.path.expanduser('~') + checkpoint_path
#torch.save(vit_model.state_dict(), save_dir)
torch.save(vit_model, save_dir)
print("Fine-tuned model saved successfully!")

Fine-tuned model saved successfully!


In [7]:
print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {100 * correct / total:.2f}%")

Epoch [20/20], Loss: 0.0715, Accuracy: 98.05%


In [7]:
vit_model.eval()
test_correct = 0
test_total = 0
all_labels = []
all_preds = []
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing"):
        images, labels = images.to(device), labels.to(device)
        outputs = vit_model(images)
        _, predicted = torch.max(outputs, 1)
        test_total += labels.size(0)
        test_correct += (predicted == labels).sum().item()

        # Collect labels and predictions for the classification report
        all_labels.extend(labels.cpu().numpy())
        all_preds.extend(predicted.cpu().numpy())

# Print test accuracy
print(f"Test Accuracy: {100 * test_correct / test_total:.2f}%")

# Generate classification report
report = classification_report(all_labels, all_preds, target_names=train_dataset.dataset.classes)
print("Classification Report:")
print(report)


Testing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 41/41 [15:26<00:00, 22.60s/it]

Test Accuracy: 93.36%
Classification Report:
              precision    recall  f1-score   support

      glioma       0.98      0.82      0.89       300
  meningioma       0.84      0.92      0.88       306
     notumor       0.98      1.00      0.99       405
   pituitary       0.94      0.98      0.96       300

    accuracy                           0.93      1311
   macro avg       0.93      0.93      0.93      1311
weighted avg       0.94      0.93      0.93      1311






In [8]:
report = classification_report(all_labels, all_preds, target_names=train_dataset.dataset.classes)
print("Classification Report:")
print(report)

Classification Report:
              precision    recall  f1-score   support

      glioma       0.98      0.82      0.89       300
  meningioma       0.84      0.92      0.88       306
     notumor       0.98      1.00      0.99       405
   pituitary       0.94      0.98      0.96       300

    accuracy                           0.93      1311
   macro avg       0.93      0.93      0.93      1311
weighted avg       0.94      0.93      0.93      1311

