In [14]:
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification
import timm
import json
from monai.networks.nets import ViT

In [15]:
# Configurations
use_monai = True
_num_classes = 4
_batch_size = 16
_epochs = 5
_lr = 3e-5

In [27]:
import random

# Split the dataset into train, val, and test sets (4 directories, one for each class. Combine all directories after splitting)
train, val, test = [], [], []
path = "data/EDC/"


for dir in os.listdir(path):
    # Skip non directories
    if not os.path.isdir(os.path.join(path, dir)):
        continue
    dir_path = os.path.join(path, dir)
    dir_json_path = os.path.join(dir_path, "data.json")
    data = json.load(open(dir_json_path, "r"))

    # Randomly split the data into train, val, and test sets (70% train, 15% val, 15% test)
    shuffled_data = data.copy()
    random.shuffle(shuffled_data)


    train_size = int(0.7 * len(shuffled_data))
    val_size = int(0.15 * len(shuffled_data))
    test_size = len(shuffled_data) - train_size - val_size
    train_data = shuffled_data[:train_size]
    val_data = shuffled_data[train_size:train_size + val_size]
    test_data = shuffled_data[train_size + val_size:]
    train.append(train_data)
    val.append(val_data)
    test.append(test_data)
    # Save the split data into separate JSON files
    with open(os.path.join(path, "train.json"), "w") as f:
        json.dump(train_data, f)
    with open(os.path.join(path, "val.json"), "w") as f:
        json.dump(val_data, f)
    with open(os.path.join(path, "test.json"), "w") as f:
        json.dump(test_data, f)

# Print length of each set
print(f"Train set size: {len(train_data)}")
print(f"Validation set size: {len(val_data)}")
print(f"Test set size: {len(test_data)}")


Train set size: 751
Validation set size: 161
Test set size: 162


In [12]:
from PIL import Image

# Load the dataset
train_json = json.load(open("data/EDC/train.json", "r"))
val_json = json.load(open("data/EDC/val.json", "r"))
test_json = json.load(open("data/EDC/test.json", "r"))

# Use DataLoader to load the dataset (json includes the first value as the path to the image and the second value as the label)
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, json_data, transform=None):
        self.data = json_data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label



# get image

train_loader = DataLoader(train_set_path, batch_size=_batch_size, shuffle=True)
val_loader = DataLoader(test_set_path, batch_size=_batch_size)

model_name = 'monai' if use_monai else 'timm'

# Select model
if use_monai:

    model = ViT(
        in_channels=3,
        img_size=(224, 224),
        patch_size=(16, 16),
        pos_embed='conv',
        classification=True,
        num_classes=_num_classes,
        # dim=768,
        # depth=12,
        # heads=12,
        # mlp_dim=3072,
        # dropout=0.1,
    )
else:
    model = timm.create_model('vit_base_patch16_224', pretrained=True)
    model.head = nn.Linear(model.head.in_features, _num_classes)

    # model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=_num_classes)

In [13]:
# ==== LOSS, OPTIMIZER, SCHEDULER ====
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=_lr, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=_epochs)

# ==== TRAINING LOOP ====
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
model.to(device)

for epoch in range(_epochs):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    scheduler.step()
    # Add validation step here
    print(f"Epoch {epoch+1}/{_epochs} completed.")

    # Print loss and accuracy
    print(f"Loss: {loss.item():.4f}")
    # Add accuracy calculation
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f"Accuracy: {accuracy:.2f}%")
    # Save the model if it is the best so far
    if epoch == 0 or accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), f'{model_name}_best_model_({accuracy:.2f}).pth')
        print(f"Model saved with accuracy: {best_accuracy:.2f}%")
    else:
        print(f"Model not improved. Current best accuracy: {best_accuracy:.2f}%")


# ==== SAVE MODEL ====
torch.save(model.state_dict(), 'best_vit_model.pth')

ValueError: too many values to unpack (expected 2)