In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
import lightning as L
from torchvision.models import vit_b_16 # pretrained model
from torchsummary import summary
import torchmetrics

import brain_tumor_dataset as btd

  from .autonotebook import tqdm as notebook_tqdm




# Pre trained model

In [2]:
# Define the transformation for the images
transform = transforms.Compose([
	transforms.Grayscale(num_output_channels=3),   # convert to 3 channels
	transforms.Resize((224, 224)),                 # resize to 224x224
	transforms.ToTensor(),
])

# Load your datasets with the defined transformations
train_dataset = btd.BrainTumorDataset(btd.TRAIN_DATA_PATH, transform=transform)
test_dataset = btd.BrainTumorDataset(btd.TEST_DATA_PATH, transform=transform)

val_size = len(test_dataset) // 2
test_size = len(test_dataset) - val_size
test_dataset, val_dataset = torch.utils.data.random_split(test_dataset, [test_size, val_size])

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

len(train_dataset), len(test_dataset), len(val_dataset)

(5712, 656, 655)

In [None]:
import cv2
# Define the PyTorch Lightning Module
class BrainTumorClassifier(L.LightningModule):
	def __init__(self, 
			  learning_rate=1e-4, 
			  pretrained_weights = "IMAGENET1K_V1",
			  weights_path=None):
		super().__init__()
		# Initialize the model with the pre-trained ViT
		self.model = vit_b_16(weights=pretrained_weights)
		self.model.heads = torch.nn.Linear(self.model.hidden_dim, 4)  # Modify for 4 classes

		# Load the weights if provided
		if weights_path:
			pass
			# self.load_from_checkpoint(weights_path, map_location=self.device)


		# Define loss function and learning rate
		self.criterion = torch.nn.CrossEntropyLoss()
		self.learning_rate = learning_rate

		# Initialize accuracy metric for logging
		self.train_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=4)
		self.val_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=4)

		self.attention_maps = {}  # To store attention maps

	def forward(self, x):
		return self.model(x)

	def training_step(self, batch, batch_idx):
		inputs, labels = batch
		outputs = self(inputs)
		loss = self.criterion(outputs, labels)

		# Log loss and accuracy
		self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
		self.train_accuracy(outputs, labels)
		self.log('train_acc', self.train_accuracy, on_step=True, on_epoch=True, prog_bar=True)

		return loss
	
	def validation_step(self, batch, batch_idx):
		inputs, labels = batch
		outputs = self(inputs)
		loss = self.criterion(outputs, labels)

		# Log loss and accuracy
		self.log('val_loss', loss, on_epoch=True, prog_bar=True)
		self.val_accuracy(outputs, labels)
		self.log('val_acc', self.val_accuracy, on_epoch=True, prog_bar=True)

		return loss

	def test_step(self, batch, batch_idx):
		inputs, labels = batch
		outputs = self(inputs)
		loss = self.criterion(outputs, labels)

		# Log loss and accuracy
		self.log('test_loss', loss, on_epoch=True, prog_bar=True)
		self.val_accuracy(outputs, labels)
		self.log('test_acc', self.val_accuracy, on_epoch=True, prog_bar=True)

		self.log_attention_maps(inputs, labels, outputs, batch_idx)  # Save attention maps

		return loss
	
	def configure_optimizers(self):
		return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
	
	def register_attention_hooks(self):
		"""Register hooks to capture attention maps."""
		self.attention_maps.clear()  # Reset attention maps

		def hook_fn(module, input, output, module_name):
			module_name = module_name.split(".")[2]  # Get the layer name

			q, k, _ = input
			self.attention_maps[module_name] = torch.nn.functional.softmax(q @ k.transpose(-2, -1))  # Save attention map

		# Register hooks on all MultiheadAttention layers
		for i, module in self.model.named_modules():
			if isinstance(module, torch.nn.MultiheadAttention):
				module.register_forward_hook(lambda module, input, output, module_name=i: hook_fn(module, input, output, module_name))
	
	def on_test_start(self):
		# Register the hook to each multi-head attention layer before testing
		self.register_attention_hooks()

	def log_attention_maps(self, inputs, labels, output, batch_idx):
			"""Log attention maps overlaid on the original image using Lightning's logger."""
			batch_size = inputs.size(0)
			for i in range(batch_size):
				average_attention_map = None
				num_layers = len(self.attention_maps)
				for _, attention in self.attention_maps.items():
					# Get the attention map for the first image in the batch
					attention_map = attention[i]  # Shape: [num_tokens, embedding_size]
					
					# Compute cosine similarity between class token and patches
					class_token_embedding = attention_map[0, :]  # Shape: (embedding_size)
					num_patches_side = int((attention_map.size(0) - 1) ** 0.5)
					attention_map = class_token_embedding[1:].view(num_patches_side, num_patches_side, -1).clone()
					
					# Accumulate attention maps
					if average_attention_map is None:
						average_attention_map = attention_map
					else:
						average_attention_map += attention_map

				average_attention_map = average_attention_map / num_layers
				average_attention_map = average_attention_map.cpu().detach().numpy()
				average_attention_map = cv2.resize(average_attention_map, (inputs.size(2), inputs.size(3)))
				heatmap = cv2.applyColorMap(np.uint8(255 * average_attention_map), cv2.COLORMAP_JET)

				# Overlay the heatmap on the original image
				image = inputs[i].cpu().numpy().transpose(1, 2, 0)
				image = (image*255).astype(np.uint8)
				overlayed_image = cv2.addWeighted(image, 0.8, heatmap, 0.4, 0)

				# add class label
				label = labels[i].item()
				label = train_dataset.idx_to_class[label]
				cv2.putText(overlayed_image, f"Class: {label}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)

				# convert to chw
				overlayed_image = overlayed_image.transpose(2, 0, 1)

				# log to tensorboard
				self.logger.experiment.add_image(f'attn_map/batch_{batch_idx}/img_{i}', overlayed_image, self.current_epoch)

# Create the model instance
weights_path = "./logs/vit_pretrained/version_0/checkpoints/epoch-epoch=04-val_loss-val_loss=0.01.ckpt"
model = BrainTumorClassifier()
model = BrainTumorClassifier.load_from_checkpoint(weights_path, map_location=model.device)

# Define callbacks
checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
	monitor="val_loss",                     # Monitor validation loss
	mode = "min",                           # mode for monitored metric
	dirpath="checkpoints/",                   # Directory to save checkpoints
	filename="epoch-{epoch:02d}-val_loss-{val_loss:.2f}",  # Naming pattern
	save_top_k=-1,                          # Save all checkpoints
	every_n_epochs=1,                       # Save at every epoch
)

early_stopping_callback = L.pytorch.callbacks.EarlyStopping(
	monitor="val_loss",                     # Metric to monitor
	patience=5,                             # Stop training if no improvement for 5 epochs
	mode="min",                             # Stop when `val_loss` stops decreasing
	verbose=True,
)

logger = L.pytorch.loggers.TensorBoardLogger("logs", name="vit_pretrained")


# Define the PyTorch Lightning Trainer
trainer = L.Trainer(max_epochs=10, 
					accelerator="auto", 
					logger=logger,
					callbacks=[checkpoint_callback, early_stopping_callback])


In [None]:
# Train the model
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [8]:
# Test the model
trainer.test(model, dataloaders=test_loader)

c:\Users\Andreas\anaconda3\envs\deep_learning\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing DataLoader 0:   0%|          | 0/21 [00:00<?, ?it/s]

  self.attention_maps[module_name] = torch.nn.functional.softmax(q @ k.transpose(-2, -1))  # Save attention map


NameError: name 'output' is not defined