In [19]:
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms  

import medmnist
from medmnist import INFO,Evaluator

from transformers import ViTConfig, ViTModel

In [None]:
# MedMNIST dataset flags to train on, in order. Each entry is a key into
# medmnist.INFO and is passed as `data_flag` to multiheaded_vit.
datasets_paths = """
    pathmnist bloodmnist breastmnist dermamnist
    octmnist organamnist organcmnist organsmnist
    pneumoniamnist retinamnist tissuemnist
""".split()

In [39]:
# Training configuration shared by every dataset in `datasets_paths`.
DOWNLOAD = True  # passed as `download=` to the MedMNIST dataset classes
NUM_EPOCHS = 10
BATCH_SIZE = 64
# NOTE(review): 1e-3 with Adam is high for fine-tuning a pretrained ViT;
# ~1e-4 or lower is more typical — confirm this was intentional.
LEARNING_RATE = 0.001
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 42

# Seed the global torch and numpy RNGs so reruns are reproducible
# (GPU kernels may still introduce some nondeterminism).
torch.manual_seed(SEED)
np.random.seed(SEED)


In [40]:
def _forward_logits(model, inputs):
    """Run the backbone, then the attached ``model.classifier`` head on its pooled output."""
    outputs = model(inputs)
    return model.classifier(outputs.pooler_output)


def _run_epoch(model, loader, loss_function, device, optimizer=None, desc=None):
    """Make one pass over ``loader``: trains when ``optimizer`` is given, else evaluates.

    Returns ``(mean_loss, accuracy_percent)``. Both are 0.0 for an empty loader,
    instead of raising ZeroDivisionError.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    # Only the training pass shows a progress bar (as before).
    batches = tqdm(loader, desc=desc, leave=False) if training else loader
    with torch.set_grad_enabled(training):
        for inputs, targets in batches:
            inputs = inputs.to(device)
            # Flatten labels to (B,) even when B == 1. A bare squeeze() would give a
            # 0-d tensor there and break both the loss and targets.size(0).
            # CrossEntropyLoss needs int64 class indices.
            targets = targets.to(device).reshape(-1).long()

            if training:
                optimizer.zero_grad()
            logits = _forward_logits(model, inputs)
            loss = loss_function(logits, targets)
            if training:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            num_batches += 1
            correct += logits.argmax(dim=1).eq(targets).sum().item()
            total += targets.size(0)

            if training:
                # Average over the batches seen so far, not the whole loader,
                # so the live loss is not under-reported mid-epoch.
                batches.set_postfix(
                    {
                        "loss": f"{running_loss / num_batches:.4f}",
                        "acc": f"{100. * correct / total:.2f}%",
                    }
                )

    mean_loss = running_loss / num_batches if num_batches else 0.0
    accuracy = 100.0 * correct / total if total else 0.0
    return mean_loss, accuracy


def train_network(
    model, num_epochs, optimizer, loss_function, trainloader, validloader, device
):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    ``model(inputs)`` must return an object with a ``pooler_output`` attribute.
    ``model.classifier`` must map that pooled output to class logits.
    Per-epoch train/validation loss and accuracy are printed.

    Args:
        model: backbone with an attached ``classifier`` head; moved to ``device``.
        num_epochs: number of full passes over ``trainloader``.
        optimizer: optimizer over ``model``'s parameters.
        loss_function: criterion called as ``loss_function(logits, targets)`` with
            int64 targets of shape (B,).
        trainloader: iterable of ``(inputs, targets)`` training batches; targets
            may be shaped (B,) or (B, 1).
        validloader: iterable of ``(inputs, targets)`` validation batches.
        device: device that the model and each batch are moved to.

    Returns:
        A list with one dict per epoch. Keys are ``epoch``, ``train_loss``,
        ``train_accuracy``, ``valid_loss`` and ``valid_accuracy``; accuracies are
        percentages. Callers that ignore the return value are unaffected.
    """
    model.to(device)
    history = []
    for epoch in range(num_epochs):
        train_loss, train_accuracy = _run_epoch(
            model,
            trainloader,
            loss_function,
            device,
            optimizer=optimizer,
            desc=f"Epoch {epoch+1}/{num_epochs}",
        )
        print(
            f"Epoch {epoch+1}/{num_epochs}: Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%"
        )

        # Validation phase (no gradients, model in eval mode)
        valid_loss, valid_accuracy = _run_epoch(model, validloader, loss_function, device)
        print(
            f"Validation Loss: {valid_loss:.4f}, Validation Accuracy: {valid_accuracy:.2f}%"
        )

        history.append(
            {
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "train_accuracy": train_accuracy,
                "valid_loss": valid_loss,
                "valid_accuracy": valid_accuracy,
            }
        )
    return history

In [41]:
def _test_accuracy(model, loader, device):
    """Return top-1 accuracy (percent) of ``model.classifier`` over ``loader``; 0.0 if empty."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            targets = targets.to(device).reshape(-1)
            logits = model.classifier(model(inputs.to(device)).pooler_output)
            correct += logits.argmax(dim=1).eq(targets).sum().item()
            total += targets.numel()
    return 100.0 * correct / total if total else 0.0


def multiheaded_vit(data_flag):
    """Fine-tune the pretrained ``google/vit-base-patch16-224-in21k`` ViT on one MedMNIST dataset.

    Loads the train/val/test splits for ``data_flag`` and attaches a linear
    ``classifier`` head to the pooled ViT output. It then trains with
    ``train_network``, using the notebook-level NUM_EPOCHS, BATCH_SIZE,
    LEARNING_RATE, DEVICE and DOWNLOAD settings. Finally it saves the state dict
    to ``f"{data_flag}_vit.pth"`` and prints test-split accuracy.

    Args:
        data_flag: MedMNIST dataset key into ``medmnist.INFO`` (e.g. "pathmnist").

    Returns:
        The trained model. Callers that ignore the return value are unaffected.

    Raises:
        ValueError: if ``INFO[data_flag]["task"]`` is multi-label. The
            CrossEntropyLoss / argmax-accuracy pipeline is single-label only.
    """
    ## Dataset metadata
    info = INFO[data_flag]
    task = info["task"]
    n_channels = info["n_channels"]
    n_classes = len(info["label"])
    if "multi-label" in task:
        raise ValueError(
            f"{data_flag} is a {task!r} task; this pipeline only supports "
            "single-label classification"
        )

    DataClass = getattr(medmnist, info["python_class"])

    ## Data augmentation
    # The pretrained checkpoint expects 3-channel input, so single-channel images
    # are replicated to 3 channels before ToTensor. This also keeps the 3-value
    # Normalize valid; a 1-channel tensor cannot be normalised with a
    # 3-element mean/std.
    to_rgb = [transforms.Grayscale(num_output_channels=3)] if n_channels == 1 else []
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    train_transform = transforms.Compose(
        to_rgb
        + [
            transforms.ToTensor(),
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            normalize,
        ]
    )
    test_transform = transforms.Compose(
        to_rgb
        + [
            transforms.ToTensor(),
            transforms.Resize((224, 224)),
            normalize,
        ]
    )

    ## Data loaders
    train_dataset = DataClass(
        split="train", transform=train_transform, download=DOWNLOAD
    )
    val_dataset = DataClass(split="val", transform=test_transform, download=DOWNLOAD)
    test_dataset = DataClass(split="test", transform=test_transform, download=DOWNLOAD)

    train_loader = DataLoader(
        dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True
    )
    # Evaluation order does not affect the metrics, so val/test are not shuffled.
    val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    ## Load the model
    config = ViTConfig.from_pretrained(
        "google/vit-base-patch16-224-in21k",
        num_labels=n_classes,
        image_size=224,
        # Must match the checkpoint's 3-channel patch-embedding weights. Passing
        # n_channels (1 for grayscale sets) makes from_pretrained fail on a size
        # mismatch. Grayscale data is converted to 3 channels by the transforms.
        num_channels=3,
    )
    model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k", config=config)
    # Attach a linear head on the pooled output. Assigning an nn.Module attribute
    # registers it as a submodule, so it is optimized and saved in state_dict.
    model.classifier = nn.Linear(model.config.hidden_size, n_classes)
    model.to(DEVICE)

    ## Human-readable label mapping stored in the model config
    model.config.id2label = {i: label for i, label in enumerate(info["label"].values())}
    model.config.label2id = {label: i for i, label in enumerate(info["label"].values())}

    ## Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    ## Train loop
    train_network(
        model, NUM_EPOCHS, optimizer, criterion, train_loader, val_loader, DEVICE
    )

    ## Save model (before test evaluation, so the weights survive a failure there)
    torch.save(model.state_dict(), f"{data_flag}_vit.pth")

    ## Held-out test accuracy
    test_accuracy = _test_accuracy(model, test_loader, DEVICE)
    print(f"Test Accuracy: {test_accuracy:.2f}%")

    return model

In [None]:
# Train and save one model per dataset, in order. Each call to multiheaded_vit
# builds a fresh ViT and writes f"{data_flag}_vit.pth".
# DOWNLOAD and the other settings come from the configuration cell. Run the
# notebook top to bottom instead of patching missing globals here.
for data_flag in datasets_paths:
    print(f"Training on {data_flag}")
    multiheaded_vit(data_flag)
    print(f"Training on {data_flag} completed")
    print("---------------------------------------------------")

Training on pathmnist


Using downloaded and verified file: /home/ubuntu/.medmnist/pathmnist.npz
Using downloaded and verified file: /home/ubuntu/.medmnist/pathmnist.npz
Using downloaded and verified file: /home/ubuntu/.medmnist/pathmnist.npz


                                                                                        

Epoch 1/10: Train Loss: 0.5568, Train Accuracy: 79.55%
Validation Loss: 0.2922, Validation Accuracy: 89.53%


                                                                                        

Epoch 2/10: Train Loss: 0.2604, Train Accuracy: 90.90%
Validation Loss: 0.1889, Validation Accuracy: 93.59%


                                                                                        

Epoch 3/10: Train Loss: 0.1991, Train Accuracy: 93.19%
Validation Loss: 0.1536, Validation Accuracy: 95.24%


                                                                                        

Epoch 4/10: Train Loss: 0.1621, Train Accuracy: 94.42%
Validation Loss: 0.1726, Validation Accuracy: 94.10%


                                                                                        

Epoch 5/10: Train Loss: 0.1378, Train Accuracy: 95.31%
Validation Loss: 0.1282, Validation Accuracy: 95.62%


                                                                                        

Epoch 6/10: Train Loss: 0.1173, Train Accuracy: 96.01%
Validation Loss: 0.1206, Validation Accuracy: 95.86%


                                                                                        

Epoch 7/10: Train Loss: 0.1058, Train Accuracy: 96.39%
Validation Loss: 0.1248, Validation Accuracy: 95.74%


                                                                                        

KeyboardInterrupt: 