In [1]:
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import timm
import matplotlib.pyplot as plt
import numpy as np

# Dataset wrapper around the pre-saved image / soft-label tensors
class FaceMapDataset(Dataset):
    """Serve (image, label) pairs read from a single ``torch.save`` file.

    The file is expected to unpack into three items; the first is used as the
    image store, the third as the label store, and the second is discarded.

    Args:
        data_file: Path handed to ``torch.load``.
        transform: Optional callable applied to each image after it has been
            tiled to three channels.
    """

    def __init__(self, data_file="data/facemap_softlabels.pt", transform=None):
        super().__init__()
        loaded = torch.load(data_file)
        # Only the first and third entries are needed; the middle one is ignored.
        self.data, self.targets = loaded[0], loaded[2]
        if len(loaded) != 3:
            # Mirror tuple-unpacking semantics: exactly three items are required.
            raise ValueError(
                f"expected 3 items in {data_file!r}, found {len(loaded)}"
            )
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per stored label)."""
        return len(self.targets)

    def __getitem__(self, index):
        """Return the ``index``-th sample as an ``(image, label)`` pair."""
        label = self.targets[index].clone()
        # Tile along the leading axis three times, e.g. (1, H, W) -> (3, H, W),
        # so single-channel images fit an RGB backbone.
        image = self.data[index].clone().repeat(3, 1, 1)

        if not self.transform:
            return image, label
        return self.transform(image), label

# Preprocessing pipeline: a single resize to 224x224.
transform = transforms.Compose([transforms.Resize(size=(224, 224))])

# Dataset backed by the saved soft-label file, and a shuffled loader that
# yields batches of 8 samples for training.
dataset = FaceMapDataset(
    transform=transform,
    data_file="data/facemap_softlabels.pt",
)
train_loader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=8,
)

# Window attention wrapper that records its attention weights
class ModifiedWindowAttention(nn.Module):
    """Window attention that reuses an existing module's layers and keeps the
    post-softmax attention weights for later inspection.

    The wrapped module's ``qkv``, ``proj`` and dropout layers are shared (not
    copied), so pretrained weights are preserved. If the wrapped module exposes
    ``relative_position_bias_table`` / ``relative_position_index`` (as timm's
    Swin ``WindowAttention`` does) the bias is applied to the logits as well;
    otherwise it is skipped.

    Args:
        original_window_attention: Module providing ``qkv``, ``scale``,
            ``num_heads``, ``attn_drop`` and ``proj``; ``proj_drop`` and the
            relative-position attributes are optional.

    Attributes set at runtime:
        attn_weights: Tensor of shape ``(B, num_heads, N, N)`` written by the
            most recent ``forward`` call (absent until the first call).
    """

    def __init__(self, original_window_attention):
        super().__init__()
        source = original_window_attention

        # Share the layers and hyper-parameters of the wrapped attention module.
        self.qkv = source.qkv
        self.scale = source.scale
        self.num_heads = source.num_heads
        self.attn_drop = source.attn_drop
        self.proj = source.proj

        # Output dropout is part of the usual attention block; fall back to a
        # no-op when the wrapped module does not define one.
        proj_drop = getattr(source, "proj_drop", None)
        self.proj_drop = proj_drop if proj_drop is not None else nn.Identity()

        # Learned relative position bias (optional). A Parameter assigned here
        # is registered on this module, so it stays trainable and follows
        # ``.to(device)``.
        self.relative_position_bias_table = getattr(
            source, "relative_position_bias_table", None
        )
        self.register_buffer(
            "relative_position_index",
            getattr(source, "relative_position_index", None),
            persistent=False,
        )

    def _relative_position_bias(self, num_tokens):
        """Return the ``(1, num_heads, N, N)`` bias, or ``None`` if unavailable."""
        table = self.relative_position_bias_table
        index = self.relative_position_index
        if table is None or index is None:
            return None
        bias = table[index.reshape(-1)].reshape(num_tokens, num_tokens, -1)
        return bias.permute(2, 0, 1).contiguous().unsqueeze(0)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, N, C)``.
            mask: Optional additive mask of shape ``(num_windows, N, N)``; ``B``
                must be a multiple of ``num_windows``. Accepting this argument
                is required because the enclosing Swin block passes it by
                keyword for shifted windows.

        Returns:
            Tensor of shape ``(B, N, C_out)`` where ``C_out`` is the output
            width of ``proj``.
        """
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale

        bias = self._relative_position_bias(N)
        if bias is not None:
            attn = attn + bias

        if mask is not None:
            # Broadcast the per-window mask over the batch and head axes.
            num_windows = mask.shape[0]
            attn = attn.reshape(-1, num_windows, self.num_heads, N, N)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.reshape(-1, self.num_heads, N, N)

        attn = attn.softmax(dim=-1)
        self.attn_weights = attn  # Save attention weights for visualization

        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class SimpleSwinHeatmap(nn.Module):
    """Swin-B encoder followed by a 1x1 convolution that yields a heatmap.

    Every block's attention module is replaced by ``ModifiedWindowAttention``
    so the attention weights of the latest forward pass can be retrieved with
    ``get_attention_maps``.

    Args:
        pretrained: Passed to ``timm.create_model``.
        verbose: When true, print the shape of each stage's feature map on
            every forward pass (useful for debugging; off by default because
            it fires once per stage per batch).
    """

    def __init__(self, pretrained=True, verbose=False):
        super().__init__()
        self.verbose = verbose

        # Load the full Swin Transformer model without `features_only=True`
        self.encoder = timm.create_model('swin_base_patch4_window7_224', pretrained=pretrained)

        # Modify WindowAttention layers to use ModifiedWindowAttention
        for stage in self.encoder.layers:
            for block in stage.blocks:
                block.attn = ModifiedWindowAttention(block.attn)

        # Final convolutional layer to reduce to single-channel heatmap output
        self.conv_out = nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=1)

    @staticmethod
    def _to_nchw(features):
        """Convert a stage output to ``(B, C, H, W)``.

        Accepts channels-last maps ``(B, H, W, C)`` or token sequences
        ``(B, L, C)`` with ``L`` a perfect square; which of the two a stage
        emits depends on the installed timm version.
        """
        if features.dim() == 3:
            batch, tokens, channels = features.shape
            side = int(round(tokens ** 0.5))
            if side * side != tokens:
                raise ValueError(
                    f"cannot reshape {tokens} tokens into a square feature map"
                )
            features = features.reshape(batch, side, side, channels)
        return features.permute(0, 3, 1, 2).contiguous()

    def forward(self, x):
        """Compute the heatmap logits for a batch of images.

        Args:
            x: Image batch of shape ``(B, 3, H, W)``.

        Returns:
            A tuple ``(heatmap, encoder_outputs)`` where ``heatmap`` has shape
            ``(B, 1, 224, 224)`` and ``encoder_outputs`` is a list with one
            ``(B, C, H, W)`` feature map per encoder stage.
        """
        # The stages operate on patch embeddings, not raw pixels, so the patch
        # embedding (and, where the encoder defines them, the absolute position
        # embedding and dropout) must run first.
        x = self.encoder.patch_embed(x)
        absolute_pos_embed = getattr(self.encoder, "absolute_pos_embed", None)
        if absolute_pos_embed is not None:
            x = x + absolute_pos_embed
        pos_drop = getattr(self.encoder, "pos_drop", None)
        if pos_drop is not None:
            x = pos_drop(x)

        # Forward pass through the encoder to get feature maps
        encoder_outputs = []
        for i, stage in enumerate(self.encoder.layers):
            x = stage(x)
            feature_map = self._to_nchw(x)
            encoder_outputs.append(feature_map)
            if self.verbose:
                print(f"Feature map at stage {i} has shape: {feature_map.shape}")

        # Use the last stage output (already channels-first for Conv2d)
        heatmap = self.conv_out(encoder_outputs[-1])
        heatmap = nn.functional.interpolate(
            heatmap, size=(224, 224), mode='bilinear', align_corners=False
        )

        return heatmap, encoder_outputs

    def get_attention_maps(self):
        """Return the attention weights stored by each block's attention module.

        Blocks that have not run a forward pass yet (no ``attn_weights``
        attribute) are skipped.
        """
        return [
            block.attn.attn_weights
            for stage in self.encoder.layers
            for block in stage.blocks
            if hasattr(block.attn, 'attn_weights')
        ]

# Initialize model, loss, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleSwinHeatmap().to(device)
# The model emits raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Training loop
num_epochs = 1
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    for i, (images, masks) in enumerate(train_loader):
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()

        outputs, encoder_outputs = model(images)

        # BCEWithLogitsLoss needs a floating-point target with the same shape
        # as the logits. NOTE(review): the stored label shape is not visible
        # here; this only adds the channel axis when it is missing.
        masks = masks.float()
        if masks.dim() == outputs.dim() - 1:
            masks = masks.unsqueeze(1)

        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        if i % 100 == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, Batch {i}, Loss: {loss.item()}")

    # Report the epoch average so the accumulated loss is actually used.
    if num_batches:
        print(f"Epoch {epoch+1}/{num_epochs}, Mean loss: {running_loss / num_batches}")

    # Checkpoint after every epoch (the same file is overwritten each time).
    torch.save(model.state_dict(), 'simple_swin_heatmap.pth')

# Visualization functions
# Single-sample, unshuffled loader over the same dataset, used for inspection.
test_loader = DataLoader(dataset, batch_size=1, shuffle=False)

def plot_overlay(image, mask, prediction, alpha=0.5):
    """Show an image next to its true-mask and predicted-heatmap overlays.

    Args:
        image: Tensor of shape ``(C, H, W)``; it is moved to the CPU and
            transposed to ``(H, W, C)`` for display.
        mask: Tensor holding the target map; singleton axes are squeezed away.
        prediction: Tensor holding the predicted map; singleton axes are
            squeezed away.
        alpha: Opacity of the overlays drawn on top of the image.

    Both maps are min-max scaled to ``[0, 1]`` before plotting. A constant map
    (zero value range) is drawn as all zeros instead of dividing by zero.
    """

    def _min_max_scale(values):
        # Rescale to [0, 1]; a flat array has no range to stretch.
        low, high = values.min(), values.max()
        value_range = high - low
        if value_range == 0:
            return np.zeros_like(values, dtype=float)
        return (values - low) / value_range

    image_np = image.detach().cpu().numpy().transpose(1, 2, 0)
    mask_np = _min_max_scale(mask.detach().cpu().squeeze().numpy())
    prediction_np = _min_max_scale(prediction.detach().cpu().squeeze().numpy())

    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1)
    plt.imshow(image_np, cmap='gray')
    plt.title("Original Image")
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.imshow(image_np, cmap='gray')
    plt.imshow(mask_np, cmap='jet', alpha=alpha)
    plt.title("True Mask Overlay")
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.imshow(image_np, cmap='gray')
    plt.imshow(prediction_np, cmap='jet', alpha=alpha)
    plt.title("Predicted Heatmap Overlay")
    plt.axis('off')

    plt.show()

def visualize_feature_maps(encoder_outputs):
    """Plot up to the first eight channels of each encoder feature map.

    Args:
        encoder_outputs: Iterable of tensors indexed as ``[batch, channel, ...]``;
            only batch element 0 is shown. Each ``feature_map[0, i]`` must be
            a 2-D array for ``imshow``.
    """
    for idx, feature_map in enumerate(encoder_outputs):
        num_shown = min(feature_map.shape[1], 8)
        if num_shown == 0:
            continue  # nothing to draw for an empty channel axis

        # squeeze=False keeps `axes` two-dimensional even for a single panel,
        # where plt.subplots would otherwise return a bare Axes object.
        fig, axes = plt.subplots(1, num_shown, figsize=(15, 15), squeeze=False)
        for i in range(num_shown):
            ax = axes[0, i]
            ax.imshow(feature_map[0, i].detach().cpu().numpy(), cmap='viridis')
            ax.axis('off')
        plt.suptitle(f"Feature Maps at Encoder Stage {idx}")
        plt.show()

def visualize_attention_maps(attention_maps):
    """Display one heatmap per entry of ``attention_maps``.

    For each tensor, the slice at index ``[0, 0]`` is plotted with a colorbar
    and a title carrying the entry's position in the sequence.
    """
    block_number = 0
    for weights in attention_maps:
        # Keep only the leading slice along the first two axes.
        # NOTE(review): presumably the first window and first head — confirm.
        leading_slice = weights[0, 0]
        as_array = leading_slice.detach().cpu().numpy()

        plt.imshow(as_array, cmap='viridis')
        plt.title(f"Attention Map at Block {block_number}")
        plt.colorbar()
        plt.show()

        block_number += 1

# Evaluation and visualization: inspect only the first batch of the test loader.
model.eval()
first_batch = next(iter(test_loader), None)
if first_batch is not None:
    images, masks = first_batch
    images = images.to(device)
    masks = masks.to(device)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        predictions, encoder_outputs = model(images)
        attention_maps = model.get_attention_maps()

    plot_overlay(images[0], masks[0], predictions[0])
    visualize_feature_maps(encoder_outputs)
    visualize_attention_maps(attention_maps)


FileNotFoundError: [Errno 2] No such file or directory: 'data/facemap_softlabels.pt'