In [None]:
import torch
import torchvision.transforms as transforms
from tqdm import tqdm
from torchvision.models import vit_b_16 # pretrained model
from torchsummary import summary

import brain_tumor_dataset as btd

In [None]:
batch_size = 32

# Preprocessing applied to every image in both splits:
#   * Grayscale(num_output_channels=3) replicates the single channel three times,
#     because the pretrained ViT expects 3-channel input;
#   * Resize to the 224x224 resolution the model is fed in the cells below;
#   * ToTensor converts to a float tensor (presumably the dataset yields PIL
#     images, in which case values land in [0, 1] — confirm in brain_tumor_dataset).
# NOTE(review): there is no transforms.Normalize with the ImageNet mean/std that
# torchvision's IMAGENET1K_V1 preset uses. Adding it changes what any checkpoint
# trained with this pipeline sees, so only add it together with retraining.
transform = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=3),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

train_dataset = btd.BrainTumorDataset(btd.TRAIN_DATA_PATH, transform=transform)
test_dataset = btd.BrainTumorDataset(btd.TEST_DATA_PATH, transform=transform)

# Shuffle only the training split; the test split keeps a fixed order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Show the class-name -> label-index mapping of each split. Both splits must use
# the same numbering; otherwise test accuracy compares predictions against labels
# from a different encoding than the one the model was trained on.
print(train_dataset.class_to_idx)
print(test_dataset.class_to_idx)
assert train_dataset.class_to_idx == test_dataset.class_to_idx, "train/test class mappings differ"

{'glioma': 0, 'meningioma': 1, 'notumor': 2, 'pituitary': 3}
{'glioma': 0, 'meningioma': 1, 'notumor': 2, 'pituitary': 3}


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load ViT-B/16 with its ImageNet-1k pretrained weights.
model = vit_b_16(weights='IMAGENET1K_V1')

# Replace the classification head with a single linear layer that emits one logit
# per dataset class. The count comes from the dataset's label mapping instead of a
# hard-coded constant, so it stays correct if classes are added or removed.
num_classes = len(train_dataset.class_to_idx)
model.heads = torch.nn.Linear(model.hidden_dim, num_classes)

model = model.to(device)

# summary() already prints the layer table. The trailing ';' stops Jupyter from
# also rendering the returned summary object, which (judging by the saved output)
# printed the same table a second time.
summary(model, (3, 224, 224));

Layer (type:depth-idx)                        Output Shape              Param #
├─Conv2d: 1-1                                 [-1, 768, 14, 14]         590,592
├─Encoder: 1-2                                [-1, 197, 768]            --
|    └─Dropout: 2-1                           [-1, 197, 768]            --
|    └─Sequential: 2-2                        [-1, 197, 768]            --
|    |    └─EncoderBlock: 3-1                 [-1, 197, 768]            7,087,872
|    |    └─EncoderBlock: 3-2                 [-1, 197, 768]            7,087,872
|    |    └─EncoderBlock: 3-3                 [-1, 197, 768]            7,087,872
|    |    └─EncoderBlock: 3-4                 [-1, 197, 768]            7,087,872
|    |    └─EncoderBlock: 3-5                 [-1, 197, 768]            7,087,872
|    |    └─EncoderBlock: 3-6                 [-1, 197, 768]            7,087,872
|    |    └─EncoderBlock: 3-7                 [-1, 197, 768]            7,087,872
|    |    └─EncoderBlock: 3-8            

Layer (type:depth-idx)                        Output Shape              Param #
├─Conv2d: 1-1                                 [-1, 768, 14, 14]         590,592
├─Encoder: 1-2                                [-1, 197, 768]            --
|    └─Dropout: 2-1                           [-1, 197, 768]            --
|    └─Sequential: 2-2                        [-1, 197, 768]            --
|    |    └─EncoderBlock: 3-1                 [-1, 197, 768]            7,087,872
|    |    └─EncoderBlock: 3-2                 [-1, 197, 768]            7,087,872
|    |    └─EncoderBlock: 3-3                 [-1, 197, 768]            7,087,872
|    |    └─EncoderBlock: 3-4                 [-1, 197, 768]            7,087,872
|    |    └─EncoderBlock: 3-5                 [-1, 197, 768]            7,087,872
|    |    └─EncoderBlock: 3-6                 [-1, 197, 768]            7,087,872
|    |    └─EncoderBlock: 3-7                 [-1, 197, 768]            7,087,872
|    |    └─EncoderBlock: 3-8            

In [None]:
# Fine-tune the whole network: cross-entropy on the class logits, Adam with lr=1e-4.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 3

for epoch in range(num_epochs):
    model.train()  # switch layers such as dropout into training behaviour
    running_loss = 0.0

    for inputs, labels in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear gradients from the previous step, then forward, backward, update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # CrossEntropyLoss defaults to a per-sample mean, so scale it back up by the
        # batch size; dividing by the dataset size below then gives the epoch mean
        # even when the final batch is smaller.
        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_dataset)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')

print("Training complete.")


Training Epoch 1/3:   1%|          | 2/179 [00:51<1:15:30, 25.60s/it]


KeyboardInterrupt: 

In [None]:
def hook_fn(module, input, output):
    """Forward hook that logs a module and the tensor shapes flowing through it.

    PyTorch calls this after ``module.forward`` returns.

    Args:
        module: the module whose forward pass just ran; its repr is printed first.
        input: tuple of positional arguments given to ``forward``; only the first
            one's ``.shape`` is reported.
        output: the value ``forward`` returned. It is indexed with ``[0]``, so it is
            expected to be indexable — for ``nn.MultiheadAttention`` it is the
            ``(attn_output, attn_weights)`` tuple, and ``[0]`` is the attention output.
    """
    print(module)
    input_report = f"Input shape: {input[0].shape}"
    print(input_report)
    output_report = f"Output shape: {output[0].shape}"
    print(output_report)
    print("===")  # visual separator between successive hook calls

# Remove hooks left over from a previous run of this cell. Without this, every
# re-run stacks another copy of hook_fn on each attention module, and every forward
# pass prints once per stacked copy.
for handle in globals().get("attention_hook_handles", []):
    handle.remove()

# Register hook_fn on every multi-head attention module and keep the returned
# handles. The hooks stay active for all later forward passes, including the
# evaluation cell, where they print once per attention module per batch. Call
# `handle.remove()` on each entry of `attention_hook_handles` to turn them off.
attention_hook_handles = []
for module in model.modules():
    if isinstance(module, torch.nn.MultiheadAttention):
        attention_hook_handles.append(module.register_forward_hook(hook_fn))

In [None]:
# Restore fine-tuned weights from disk, mapped onto the current device.
# NOTE(review): the training cell above never writes "model_pretrained.pth", so
# this file presumably comes from a separate run — confirm it matches this
# architecture and preprocessing.
model.load_state_dict(
    torch.load("model_pretrained.pth", weights_only=True, map_location=device)
)

# Top-1 accuracy on the test split. If the attention hooks from the cell above are
# registered, they also fire (and print) on every forward pass here.
model.eval()  # switch layers such as dropout into inference behaviour
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in tqdm(test_loader, desc="Testing"):
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs)
        predictions = outputs.argmax(dim=1)  # index of the highest logit per sample

        total += labels.size(0)
        correct += (predictions == labels).sum().item()

accuracy = correct / total
print(f'Accuracy on the test set: {accuracy:.2%}')
