# Medical Image to Text Report Generation
## Experiment 3: Custom Autoencoder + Transformer Decoder

This experiment implements a two-stage approach for medical image captioning that combines a custom autoencoder with a transformer decoder. Unlike previous approaches, this method first trains a specialized autoencoder to create efficient latent representations of chest X-ray images, and then independently trains a transformer decoder that uses these encodings to generate diagnostic reports.

## Setup and Dependencies

In [3]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import math
import time
import json
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid, save_image

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define hyperparameters
BATCH_SIZE = 32
EMBEDDING_DIM = 256
HIDDEN_DIM = 512
NHEAD = 8
NUM_DECODER_LAYERS = 6
LATENT_DIM = 256
DROPOUT = 0.1
LEARNING_RATE_AE = 1e-4
LEARNING_RATE_TRANSFORMER = 1e-4
NUM_EPOCHS_AE = 10
NUM_EPOCHS_TRANSFORMER = 30
MAX_LENGTH = 100  # Will be updated based on actual data

# Define transforms for images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])
])




Using device: cuda


## Custom Dataset Implementation

In [None]:
class ChestXrayDataset(Dataset):
    def __init__(self, dataframe, transform=None, max_length=100, word_to_idx=None, base_path=None):
        self.dataframe = dataframe
        self.transform = transform
        self.max_length = max_length
        self.word_to_idx = word_to_idx
        self.base_path = base_path
        
    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        img_path = self.dataframe.iloc[idx]['final_img_path']
        caption = self.dataframe.iloc[idx]['captions']
        
        # Adjust the image path if base_path is provided
        if self.base_path:
            # Extract the part of the path after 'data/'
            if 'data/' in img_path:
                relative_path = img_path[img_path.find('data/'):]
                img_path = os.path.join(self.base_path, relative_path)
            else:
                # If 'data/' is not in the path, just join with base_path
                img_path = os.path.join(self.base_path, img_path)
        
        # Load and transform image
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        
        # Convert caption to tensor (if word_to_idx is provided)
        if self.word_to_idx:
            caption_encoded = [self.word_to_idx.get(word, self.word_to_idx['<unk>']) 
                              for word in ['<start>'] + caption.split() + ['<end>']]
            
            # Pad caption if needed
            if len(caption_encoded) < self.max_length:
                caption_encoded = caption_encoded + [self.word_to_idx['<pad>']] * (self.max_length - len(caption_encoded))
            else:
                caption_encoded = caption_encoded[:self.max_length]
                
            caption_tensor = torch.tensor(caption_encoded, dtype=torch.long)
            return image, caption_tensor
        else:
            # For autoencoder training, we don't need captions
            return image, image  # Return image as both input and target
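
The caption-encoding branch of `__getitem__`, together with the vocabulary construction done in the `load_data` utility, can be illustrated without images or PyTorch. The following is a plain-Python sketch that assumes a tiny hypothetical two-caption corpus. `build_vocab` and `encode_caption` are illustrative helpers, not functions from this notebook. The vocabulary is built from `sorted(words)`. Python randomizes string hashing per process, so iterating a raw set can assign different indices in different sessions, and those indices would then no longer match a saved checkpoint's embedding rows.

```python
# Special tokens use the same indices as load_data
SPECIAL_TOKENS = {'<pad>': 0, '<start>': 1, '<end>': 2, '<unk>': 3}


def build_vocab(captions):
    """Map special tokens to 0-3 and every other word to a stable index."""
    word_to_idx = dict(SPECIAL_TOKENS)
    words = set()
    for caption in captions:
        words.update(caption.split())
    # Sorting makes the mapping independent of set iteration order
    for word in sorted(words):
        word_to_idx[word] = len(word_to_idx)
    return word_to_idx


def encode_caption(caption, word_to_idx, max_length):
    """Wrap in <start>/<end>, map unknown words to <unk>, then pad or truncate."""
    tokens = ['<start>'] + caption.split() + ['<end>']
    ids = [word_to_idx.get(tok, word_to_idx['<unk>']) for tok in tokens]
    if len(ids) < max_length:
        ids += [word_to_idx['<pad>']] * (max_length - len(ids))
    else:
        ids = ids[:max_length]
    return ids


vocab = build_vocab(["no acute findings", "heart size normal"])
print(encode_caption("heart size enlarged", vocab, max_length=7))  # [1, 6, 9, 3, 2, 0, 0]
```

Truncation works the same way as in `__getitem__`. It keeps the first `max_length` IDs, so it can drop `<end>`. When `max_length` comes from `load_data`, which uses the longest caption plus 2, this does not happen for captions in the CSV.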

## Autoencoder Architecture

The autoencoder consists of an encoder and decoder. The encoder progressively compresses the image through convolutional layers with stride=2, halving the dimensions at each step while increasing the feature channels. After four such layers, the 224×224×3 input is transformed into a 14×14×512 feature map, which is then flattened and projected to a latent representation of dimension 256.

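The decoder mirrors this process. It starts with a linear projection from the latent space back to a flattened 14×14×512 feature map. Four transposed convolutional layers follow, each doubling the spatial dimensions while reducing the channels. Notable architectural features include:

1. Batch normalization after every convolutional and transposed-convolutional layer except the decoder's final output layer, for more stable training
2. Dropout (p = 0.1) on the flattened features before the encoder's latent projection and after the decoder's input projection, for regularization
3. ReLU activations after all intermediate convolutional layers and tanh on the final output

There are no skip connections between the encoder and decoder, so every reconstruction must pass through the 256-dimensional latent vector. This bottleneck is what makes the latent vector usable as a compact image representation for the captioning stage.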
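
The spatial sizes quoted above follow from the output-size formulas in the PyTorch `Conv2d` and `ConvTranspose2d` documentation (with dilation 1). The following is a small plain-Python check that assumes the kernel, stride, padding, and output-padding values used in the code below. The helper names are illustrative.

```python
def conv2d_out(n, kernel=3, stride=2, padding=1):
    """Output size of nn.Conv2d along one spatial dimension (dilation=1)."""
    return (n + 2 * padding - kernel) // stride + 1


def conv_transpose2d_out(n, kernel=3, stride=2, padding=1, output_padding=1):
    """Output size of nn.ConvTranspose2d along one spatial dimension (dilation=1)."""
    return (n - 1) * stride - 2 * padding + kernel + output_padding


size = 224
encoder_sizes = [size]
for _ in range(4):  # conv1..conv4
    size = conv2d_out(size)
    encoder_sizes.append(size)

decoder_sizes = [size]
for _ in range(4):  # deconv1..deconv4
    size = conv_transpose2d_out(size)
    decoder_sizes.append(size)

print(encoder_sizes)  # [224, 112, 56, 28, 14]
print(decoder_sizes)  # [14, 28, 56, 112, 224]

flattened = 14 * 14 * 512
print(flattened, 224 * 224 * 3 / 256)  # 100352 588.0
```

The last line shows two consequences of the design. The 256-dimensional latent vector compresses each input by a factor of 588. The encoder's `fc` layer alone holds 100352 × 256 + 256 ≈ 25.7M parameters.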

In [4]:
class Encoder(nn.Module):
    def __init__(self, latent_dim=256):
        super(Encoder, self).__init__()
        # Four stride-2 convolutions, each halving the spatial size: 224 -> 112 -> 56 -> 28 -> 14
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
        
        # For 224x224 input, feature map size after convs is 14x14x512
        self.fc = nn.Linear(14 * 14 * 512, latent_dim)
        
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(256)
        self.bn4 = nn.BatchNorm2d(512)
        
        # Dropout for regularization
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))
        
        x = x.view(x.size(0), -1)  # Flatten
        x = self.dropout(x)
        x = self.fc(x)  # Map to latent space
        return x

class Decoder(nn.Module):
    def __init__(self, latent_dim=256):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim, 14 * 14 * 512)
        self.dropout = nn.Dropout(0.1)
        
        self.deconv1 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv4 = nn.ConvTranspose2d(64, 3, kernel_size=3, stride=2, padding=1, output_padding=1)
        
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(256)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(64)
        self.tanh = nn.Tanh()
        
    def forward(self, x):
        x = self.dropout(self.fc(x))
        x = x.view(x.size(0), 512, 14, 14)  # Reshape to feature map
        
        x = self.relu(self.bn1(self.deconv1(x)))
        x = self.relu(self.bn2(self.deconv2(x)))
        x = self.relu(self.bn3(self.deconv3(x)))
        x = self.tanh(self.deconv4(x))  # Output in range [-1, 1]
        return x

class Autoencoder(nn.Module):
    def __init__(self, latent_dim=256):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)
        
    def forward(self, x):
        latent = self.encoder(x)
        reconstructed = self.decoder(latent)
        return reconstructed
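
Check the output range before training. The decoder ends in `tanh`, so its reconstructions lie in (-1, 1). The shared `transform`, however, normalizes with ImageNet statistics, which maps pixel values in [0, 1] onto a wider range. The plain-Python sketch below computes those per-channel bounds. If they extend beyond ±1, as it suggests, then with `nn.MSELoss` the autoencoder cannot reproduce very dark or very bright pixels exactly. One possible alternative, which is an assumption and not what this notebook does, is an autoencoder-specific `transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)`, which maps [0, 1] onto exactly [-1, 1].

```python
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalized_range(mean, std):
    """Per-channel (min, max) of (x - mean) / std for pixel values x in [0, 1]."""
    return [((0.0 - m) / s, (1.0 - m) / s) for m, s in zip(mean, std)]


imagenet_bounds = normalized_range(IMAGENET_MEAN, IMAGENET_STD)
for channel, (lo, hi) in enumerate(imagenet_bounds):
    # tanh can only produce values strictly inside (-1, 1)
    print(f"channel {channel}: [{lo:.3f}, {hi:.3f}]")

# Hypothetical alternative: mean = std = 0.5 maps [0, 1] exactly onto [-1, 1]
print(normalized_range([0.5] * 3, [0.5] * 3))
```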


## Transformer Architecture

The transformer architecture is designed for sequence-to-sequence tasks with a specific focus on conditioning the generation on the encoded image features. The implementation includes:

1. Positional Encoding: Adds information about token position using sinusoidal functions
2. Image Projection: Maps the autoencoder's latent representation to the transformer's embedding dimension
3. Transformer Decoder: Standard transformer decoder with multi-head self-attention and cross-attention to the encoded image
4. Output Layer: Projects the decoder outputs to vocabulary size with layer normalization
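
The positional encoding in item 1 is the standard sinusoidal scheme. For position pos and dimension-pair index i, PE(pos, 2i) = sin(pos / 10000^(2i/d_model)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). The `PositionalEncoding` class below builds the same table with `div_term = exp(2i · (-ln 10000 / d_model))`, which equals 10000^(-2i/d_model). The following is a plain-Python version for a single position. The function name is illustrative.

```python
import math


def sinusoidal_encoding(pos, d_model):
    """Return the d_model-dimensional sinusoidal positional encoding for one position."""
    assert d_model % 2 == 0, "sin/cos pairs require an even d_model"
    encoding = [0.0] * d_model
    for i in range(0, d_model, 2):  # i plays the role of 2i in the formula
        angle = pos / (10000 ** (i / d_model))
        encoding[i] = math.sin(angle)      # even dimensions
        encoding[i + 1] = math.cos(angle)  # odd dimensions
    return encoding


print(sinusoidal_encoding(0, 8))  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0]
```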

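Instead of prepending the encoded image to the caption as an initial token, the model passes it to the transformer decoder as the cross-attention "memory". However, the autoencoder produces a single 256-dimensional vector per image, and `img_projection` maps that vector to just one memory token. With a memory of length 1, the cross-attention softmax has a single key, its weight is always 1, and every caption position receives the same projected image vector. The decoder therefore conditions on a global summary of the image at every layer and step, but it cannot attend to different regions of the image. Attending over regions would require keeping a spatial representation, such as the 14×14 encoder feature map, as a sequence of memory tokens.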
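
`DecoderTransformer.forward` projects the single latent vector and unsqueezes it into a memory of length 1, so cross-attention has only one key to attend to. The following plain-Python sketch shows scaled dot-product attention for a single query. It covers the attention arithmetic only, without the learned query/key/value projections or multiple heads of `nn.TransformerDecoderLayer`. With one key, softmax returns [1.0] whatever the score, so every query gets the same output.

```python
import math


def attend(query, keys, values):
    """Scaled dot-product attention for one query over a list of key/value vectors."""
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    # Numerically stable softmax over the scores
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    weights = [e / sum(exps) for e in exps]
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]
    return weights, output


image_memory = [[0.2, -0.5, 1.0]]  # one projected image vector (memory length 1)
for query in ([1.0, 0.0, 0.0], [-3.0, 2.0, 0.5]):
    weights, output = attend(query, keys=image_memory, values=image_memory)
    print(weights, output)  # weights are always [1.0]; output equals the memory vector
```

With two or more memory tokens, the weights would depend on the query, which is what lets attention focus on different parts of the input.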

In [5]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class DecoderTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, nhead, num_layers, dropout=0.1):
        super(DecoderTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoder = PositionalEncoding(embed_dim, dropout)
        
        # Image feature projection to match embedding dimensions
        self.img_projection = nn.Linear(LATENT_DIM, embed_dim)
        
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim, 
            nhead=nhead,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            activation="gelu"  # Using GELU activation instead of ReLU
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        
        # Output projection with layer normalization
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.fc_out = nn.Linear(embed_dim, vocab_size)
        
    def forward(self, encoded_image, captions, tgt_mask=None):
        # Project image features to the embedding space
        memory = self.img_projection(encoded_image).unsqueeze(0)
        
        # Embed captions
        embedded = self.embedding(captions)  # [batch_size, seq_len, embed_dim]
        
        # Add positional encoding
        embedded = self.pos_encoder(embedded.permute(1, 0, 2))  # [seq_len, batch_size, embed_dim]
        
        # Create target mask to prevent attention to future tokens
        if tgt_mask is None:
            seq_len = captions.size(1)
            tgt_mask = generate_square_subsequent_mask(seq_len).to(encoded_image.device)
        
        # Decode
        output = self.transformer_decoder(
            tgt=embedded,
            memory=memory,
            tgt_mask=tgt_mask
        )  # [seq_len, batch_size, embed_dim]
        
        # Apply layer normalization and project to vocabulary
        output = self.layer_norm(output)
        output = self.fc_out(output)  # [seq_len, batch_size, vocab_size]
        
        # Reshape to [batch_size, seq_len, vocab_size]
        return output.permute(1, 0, 2)

class CaptioningModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(CaptioningModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, images, captions, tgt_mask=None):
        # Encode images
        latent_vectors = self.encoder(images)
        
        # Decode captions
        outputs = self.decoder(latent_vectors, captions, tgt_mask)
        return outputs


## Utility Functions and BLEU Score Implementation
The utility functions handle the supporting operations, including:

- Generating causal masks for the transformer (preventing attention to future tokens)
- Loading and preprocessing the dataset
- Building the vocabulary from captions
- Splitting the data into train, validation, and test sets (70% / 15% / 15%)
- Saving and restoring checkpoints, and setting up TensorBoard logging
- Generating captions with greedy decoding
- Plotting learning curves and BLEU scores
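
The causal mask is additive. Row i corresponds to target position i and column j to a position it might attend to. Entries are 0.0 where j ≤ i and -inf where j > i, so after softmax future tokens receive zero weight. The following plain-Python sketch builds the same pattern that `generate_square_subsequent_mask` produces with `torch.triu`. The function name is illustrative.

```python
def causal_mask(size):
    """Additive attention mask: 0.0 where position i may attend to j (j <= i), else -inf."""
    return [[0.0 if j <= i else float('-inf') for j in range(size)] for i in range(size)]


for row in causal_mask(3):
    print(row)
# [0.0, -inf, -inf]
# [0.0, 0.0, -inf]
# [0.0, 0.0, 0.0]
```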

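A custom BLEU calculation is used instead of NLTK's implementation, to avoid issues encountered with the latter. It computes corpus-level scores, summing n-gram matches and counts over all hypotheses before dividing, and includes:

- Clipped n-gram matching: each hypothesis n-gram is credited at most as many times as it appears in a reference, using the best-matching reference
- A brevity penalty based on the reference length closest to each hypothesis
- Configurable n-gram weights, used to compute BLEU-1 through BLEU-4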
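
Clipped n-gram precision and the brevity penalty can be checked by hand. The sketch below uses plain Python and illustrative helper names. It assumes a single reference per hypothesis, which is what the validation loop in this notebook builds. The standard brevity penalty is BP = 1 if c > r, else exp(1 - r/c), where c is the hypothesis length and r the reference length. `simple_bleu_score` writes this as `min(1, exp(1 - r/c))`, which is equivalent.

```python
import math
from collections import Counter


def clipped_precision(hypothesis, reference, n):
    """Share of hypothesis n-grams found in the reference, each counted at most as often as it occurs there."""
    hyp_ngrams = Counter(tuple(hypothesis[i:i + n]) for i in range(len(hypothesis) - n + 1))
    ref_ngrams = Counter(tuple(reference[i:i + n]) for i in range(len(reference) - n + 1))
    total = sum(hyp_ngrams.values())
    if total == 0:
        return 0.0
    matches = sum(min(count, ref_ngrams[ngram]) for ngram, count in hyp_ngrams.items())
    return matches / total


def brevity_penalty(hyp_len, ref_len):
    """Standard BLEU brevity penalty: only hypotheses shorter than the reference are penalized."""
    if hyp_len == 0:
        return 0.0
    return 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)


reference = "the heart is normal in size".split()
# Repeating a correct word does not help: "the" is credited only once
print(clipped_precision("the the the".split(), reference, 1))  # 1/3
# A short but accurate hypothesis has full precision but is penalized for length
short = "heart is normal".split()
print(clipped_precision(short, reference, 1), brevity_penalty(len(short), len(reference)))  # 1.0 and exp(-1) ≈ 0.368
```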

In [6]:
def generate_square_subsequent_mask(sz):
    """Generate a square mask for the sequence. The masked positions are filled with float('-inf')."""
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def load_data(csv_path):
    """Load and preprocess data from CSV file."""
    # Load CSV file
    df = pd.read_csv(csv_path)
    
    # Check for missing values
    print(f"Missing values in DataFrame:\n{df.isnull().sum()}")
    
    # Create vocabulary from captions
    all_captions = df['captions'].tolist()
    words = set()
    for caption in all_captions:
        words.update(caption.split())
    
    # Create word indices
    word_to_idx = {
        '<pad>': 0,
        '<start>': 1,
        '<end>': 2,
        '<unk>': 3
    }
    
    idx = 4
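    # Sort the words so indices are identical across runs; iterating a set of
    # strings directly gives an order that changes with Python's hash seed
    for word in sorted(words):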
        word_to_idx[word] = idx
        idx += 1
    
    idx_to_word = {idx: word for word, idx in word_to_idx.items()}
    vocab_size = len(word_to_idx)
    
    # Determine max caption length for padding
    max_length = max(len(caption.split()) for caption in all_captions) + 2  # +2 for <start> and <end>
    print(f"Max caption length: {max_length}")
    
    # Split data into train, validation, and test sets (70%, 15%, 15%)
    train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
    val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)
    
    print(f"Training samples: {len(train_df)}, Validation samples: {len(val_df)}, Test samples: {len(test_df)}")
    
    return train_df, val_df, test_df, word_to_idx, idx_to_word, vocab_size, max_length

def simple_bleu_score(references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25)):
    """
    Simplified BLEU score implementation that avoids issues with NLTK's implementation
    
    Args:
        references: List of reference sentences (each a list of tokens)
        hypotheses: List of hypothesis sentences (each a list of tokens)
        weights: Weights for n-gram precisions
    """
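    # Ensure weights sum to 1 (e.g. 0.33 * 3 = 0.99 for BLEU-3)
    total_weight = sum(weights)
    weights = tuple(w / total_weight for w in weights)
    
    # Drop trailing zero-weight orders: otherwise BLEU-1 would still compute
    # 4-gram precision, and a zero there would force the whole score to 0
    while len(weights) > 1 and weights[-1] == 0:
        weights = weights[:-1]
    
    # Maximum n-gram order
    max_n = len(weights)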
    
    # Calculate n-gram matches for each order
    precisions = []
    for n in range(1, max_n + 1):
        matches = 0
        total = 0
        
        for hyp, refs in zip(hypotheses, references):
            # Skip empty hypotheses
            if len(hyp) == 0:
                continue
                
            # Count n-grams in hypothesis
            hyp_ngrams = {}
            for i in range(len(hyp) - n + 1):
                ngram = tuple(hyp[i:i+n])
                hyp_ngrams[ngram] = hyp_ngrams.get(ngram, 0) + 1
            
            # Find maximum n-gram matches among references
            max_matches = 0
            for ref in refs:
                # Skip references shorter than n
                if len(ref) < n:
                    continue
                    
                # Count n-grams in reference
                ref_ngrams = {}
                for i in range(len(ref) - n + 1):
                    ngram = tuple(ref[i:i+n])
                    ref_ngrams[ngram] = ref_ngrams.get(ngram, 0) + 1
                
                # Count matches
                ref_matches = 0
                for ngram, count in hyp_ngrams.items():
                    ref_matches += min(count, ref_ngrams.get(ngram, 0))
                
                max_matches = max(max_matches, ref_matches)
            
            # Update counts
            matches += max_matches
            # Count only the n-grams the hypothesis actually contains (standard BLEU)
            total += max(0, len(hyp) - n + 1)
        
        # Calculate precision for this n-gram order
        if total > 0:
            precisions.append(matches / total)
        else:
            precisions.append(0)
    
    # Calculate brevity penalty
    hyp_lengths = [len(hyp) for hyp in hypotheses]
    ref_lengths = []
    for hyp, refs in zip(hypotheses, references):
        ref_lens = [len(ref) for ref in refs]
        closest_ref_len = min(ref_lens, key=lambda x: abs(x - len(hyp))) if ref_lens else 0
        ref_lengths.append(closest_ref_len)
    
    if sum(hyp_lengths) == 0:
        bp = 0
    else:
        bp = min(1, math.exp(1 - sum(ref_lengths) / sum(hyp_lengths)))
    
    # Calculate final BLEU score
    if 0 in precisions:
        return 0
    
    log_precisions = [math.log(p) for p in precisions]
    weighted_log_precision = sum(w * lp for w, lp in zip(weights, log_precisions))
    bleu = bp * math.exp(weighted_log_precision)
    
    return bleu

def calculate_bleu_metrics(references, hypotheses):
    """Calculate BLEU-1 through BLEU-4 scores."""
    bleu1 = simple_bleu_score(references, hypotheses, weights=(1, 0, 0, 0))
    bleu2 = simple_bleu_score(references, hypotheses, weights=(0.5, 0.5, 0, 0))
    bleu3 = simple_bleu_score(references, hypotheses, weights=(0.33, 0.33, 0.33, 0))
    bleu4 = simple_bleu_score(references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25))
    
    return {
        'bleu1': bleu1,
        'bleu2': bleu2,
        'bleu3': bleu3,
        'bleu4': bleu4
    }

def save_checkpoint(model, optimizer, epoch, loss_dict, checkpoint_path):
    """Save model checkpoint with all training state."""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss_dict': loss_dict
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved to {checkpoint_path}")

def load_checkpoint(model, optimizer, checkpoint_path, device):
    """Load model checkpoint and return training state."""
    if not os.path.exists(checkpoint_path):
        print(f"No checkpoint found at {checkpoint_path}")
        return 0, {}
    
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    print(f"Loaded checkpoint from epoch {checkpoint['epoch']+1}")
    return checkpoint['epoch'] + 1, checkpoint.get('loss_dict', {})

def setup_logger(log_dir):
    """Set up TensorBoard logger."""
    os.makedirs(log_dir, exist_ok=True)
    return SummaryWriter(log_dir)

def generate_caption_indices(model, image, word_to_idx, max_length=100):
    """Generate caption indices using greedy decoding (argmax token at each step)."""
    model.eval()
    
    with torch.no_grad():
        # Encode image
        latent_vector = model.encoder(image)
        
        # Start with <start> token
        caption = [word_to_idx['<start>']]
        
        # Generate caption word by word
        for i in range(max_length):
            # Convert caption to tensor on the same device as the input image
            caption_tensor = torch.LongTensor(caption).unsqueeze(0).to(image.device)
            
            # Generate next word
            output = model.decoder(latent_vector, caption_tensor)
            predicted_word_idx = output[0, -1].argmax().item()
            
            # Add predicted word to caption
            caption.append(predicted_word_idx)
            
            # Stop if <end> token is predicted
            if predicted_word_idx == word_to_idx['<end>']:
                break
        
        return caption

def plot_learning_curves(train_values, val_values, title, ylabel, save_path):
    """Plot and save learning curves."""
    plt.figure(figsize=(10, 5))
    plt.plot(train_values, label=f'Training {ylabel}')
    plt.plot(val_values, label=f'Validation {ylabel}')
    plt.xlabel('Epochs')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.savefig(save_path)
    plt.close()

def plot_metrics(metrics_dict, title, save_path):
    """Plot and save multiple metrics."""
    plt.figure(figsize=(10, 5))
    for name, values in metrics_dict.items():
        plt.plot(values, label=name.upper())
    plt.xlabel('Epochs')
    plt.ylabel('Score')
    plt.title(title)
    plt.legend()
    plt.savefig(save_path)
    plt.close()
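
`generate_caption_indices` performs greedy decoding. At each step it re-runs the decoder on the whole prefix and appends the single highest-scoring token. It stops at `<end>` or after `max_length` steps. Beam search, by contrast, would keep several candidate prefixes. The loop can be sketched in plain Python. A hypothetical `next_token_scores` function stands in for the model's last-position logits, and `greedy_decode` and `toy_scores` are illustrative names.

```python
def greedy_decode(next_token_scores, start_id, end_id, max_length):
    """Greedy decoding: repeatedly append the argmax token until <end> or max_length steps."""
    tokens = [start_id]
    for _ in range(max_length):
        scores = next_token_scores(tokens)  # one score per vocabulary entry
        next_id = max(range(len(scores)), key=scores.__getitem__)
        tokens.append(next_id)
        if next_id == end_id:
            break
    return tokens


# Hypothetical toy "model": prefers token len(prefix) + 3 until the prefix has 3 tokens, then <end> (id 2)
def toy_scores(prefix, vocab_size=8):
    target = 2 if len(prefix) >= 3 else len(prefix) + 3
    return [1.0 if i == target else 0.0 for i in range(vocab_size)]


print(greedy_decode(toy_scores, start_id=1, end_id=2, max_length=10))  # [1, 4, 5, 2]
```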


## Autoencoder Training Loop

In [7]:
def train_autoencoder(model, train_loader, val_loader, num_epochs, 
                       checkpoint_dir='checkpoints/autoencoder',
                       log_dir='logs/autoencoder'):
    """Train the autoencoder with efficient checkpointing and logging."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    
    # Set up logger
    writer = setup_logger(log_dir)
    
    # Loss and optimizer
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE_AE)
    
    # For tracking metrics
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    
    # Try to load checkpoint
    checkpoint_path = os.path.join(checkpoint_dir, 'latest.pth')
    start_epoch, loss_dict = load_checkpoint(model, optimizer, checkpoint_path, device)
    
    if loss_dict:
        train_losses = loss_dict.get('train_losses', [])
        val_losses = loss_dict.get('val_losses', [])
        best_val_loss = loss_dict.get('best_val_loss', float('inf'))
    
    # Training loop
    for epoch in range(start_epoch, num_epochs):
        start_time = time.time()
        model.train()
        running_loss = 0.0
        
        for i, (images, _) in enumerate(train_loader):
            images = images.to(device)
            
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, images)
            
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            
            if (i+1) % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
                
                # Log to TensorBoard
                global_step = epoch * len(train_loader) + i
                writer.add_scalar('autoencoder/train_loss_step', loss.item(), global_step)
        
        # Calculate average training loss
        train_loss = running_loss / len(train_loader)
        train_losses.append(train_loss)
        
        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for i, (images, _) in enumerate(val_loader):
                images = images.to(device)
                outputs = model(images)
                loss = criterion(outputs, images)
                val_loss += loss.item()
                
                # Log sample reconstructions (first batch only)
                if i == 0 and epoch % 5 == 0:
                    n = min(8, images.size(0))
                    comparison = torch.cat([
                        images[:n],
                        outputs[:n]
                    ])
                    grid = make_grid(comparison, nrow=n)
                    writer.add_image('autoencoder/reconstructions', grid, epoch)
                    
                    # Save reconstructions
                    os.makedirs(os.path.join(checkpoint_dir, 'samples'), exist_ok=True)
                    save_image(grid, os.path.join(checkpoint_dir, 'samples', f'reconstruction_epoch_{epoch+1}.png'))
        
        # Calculate average validation loss
        val_loss = val_loss / len(val_loader)
        val_losses.append(val_loss)
        
        # Log metrics
        writer.add_scalar('autoencoder/train_loss_epoch', train_loss, epoch)
        writer.add_scalar('autoencoder/val_loss', val_loss, epoch)
        
        time_elapsed = time.time() - start_time
        print(f'Epoch [{epoch+1}/{num_epochs}], Time: {time_elapsed:.2f}s, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
        
        # Update the best validation loss before building the checkpoint state,
        # so the 'latest' checkpoint stores the correct value for resuming
        is_best = val_loss < best_val_loss
        if is_best:
            best_val_loss = val_loss
        
        # Save loss dict for checkpoints
        loss_dict = {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'best_val_loss': best_val_loss
        }
        
        # Save latest checkpoint
        save_checkpoint(model, optimizer, epoch, loss_dict, checkpoint_path)
        
        # Save best model
        if is_best:
            best_checkpoint_path = os.path.join(checkpoint_dir, 'best.pth')
            save_checkpoint(model, optimizer, epoch, loss_dict, best_checkpoint_path)
            print(f"Saved best model with validation loss: {val_loss:.4f}")
        
        # Save epoch checkpoint every 5 epochs
        if (epoch+1) % 5 == 0 or epoch == num_epochs-1:
            epoch_checkpoint_path = os.path.join(checkpoint_dir, f'epoch_{epoch+1}.pth')
            save_checkpoint(model, optimizer, epoch, loss_dict, epoch_checkpoint_path)
    
    # Plot and save training curves
    plot_learning_curves(
        train_losses, val_losses, 
        'Autoencoder Training and Validation Loss', 
        'Loss', 
        os.path.join(checkpoint_dir, 'learning_curve.png')
    )
    
    writer.close()
    return model

# Example usage:
# Initialize the autoencoder and train it
# autoencoder = Autoencoder(LATENT_DIM).to(device)
# autoencoder = train_autoencoder(autoencoder, train_loader_ae, val_loader_ae, NUM_EPOCHS_AE)


## Transformer Training

In [8]:
def train_transformer(model, train_loader, val_loader, word_to_idx, idx_to_word, num_epochs,
                      checkpoint_dir='checkpoints/transformer', 
                      log_dir='logs/transformer'):
    """Train the transformer model with efficient checkpointing and logging."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(os.path.join(checkpoint_dir, 'samples'), exist_ok=True)
    
    # Set up logger
    writer = setup_logger(log_dir)
    
    # Loss and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=word_to_idx['<pad>'])
    # Freeze the pretrained encoder and optimize only the decoder parameters
    for param in model.encoder.parameters():
        param.requires_grad = False
    decoder_params = list(model.decoder.parameters())
    optimizer = optim.Adam(decoder_params, lr=LEARNING_RATE_TRANSFORMER)
    
    # For tracking metrics
    best_val_loss = float('inf')
    train_losses = []
    val_losses = []
    bleu_scores = {f'bleu{i}': [] for i in range(1, 5)}
    
    # Try to load checkpoint
    checkpoint_path = os.path.join(checkpoint_dir, 'latest.pth')
    start_epoch, loss_dict = load_checkpoint(model, optimizer, checkpoint_path, device)
    
    if loss_dict:
        train_losses = loss_dict.get('train_losses', [])
        val_losses = loss_dict.get('val_losses', [])
        best_val_loss = loss_dict.get('best_val_loss', float('inf'))
        
        # Load BLEU scores if available
        for i in range(1, 5):
            key = f'bleu{i}'
            if key in loss_dict:
                bleu_scores[key] = loss_dict[key]
    
    # Training loop
    for epoch in range(start_epoch, num_epochs):
        start_time = time.time()
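        model.train()
        # Keep the frozen encoder in eval mode so its BatchNorm running statistics
        # are not updated and its dropout stays disabled
        model.encoder.eval()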
        running_loss = 0.0
        
        for i, (images, captions) in enumerate(train_loader):
            images = images.to(device)
            captions = captions.to(device)
            
            # Forward pass (remove last token from input, first token from target)
            outputs = model(images, captions[:, :-1])
            loss = criterion(
                outputs.reshape(-1, outputs.shape[2]), 
                captions[:, 1:].reshape(-1)
            )
            
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            
            if (i+1) % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
                
                # Log to TensorBoard
                global_step = epoch * len(train_loader) + i
                writer.add_scalar('transformer/train_loss_step', loss.item(), global_step)
        
        # Calculate average training loss
        train_loss = running_loss / len(train_loader)
        train_losses.append(train_loss)
        
        # Validation
        model.eval()
        val_loss = 0.0
        references = []
        hypotheses = []
        
        with torch.no_grad():
            for i, (images, captions) in enumerate(val_loader):
                images = images.to(device)
                captions = captions.to(device)
                
                # Calculate validation loss
                outputs = model(images, captions[:, :-1])
                loss = criterion(
                    outputs.reshape(-1, outputs.shape[2]), 
                    captions[:, 1:].reshape(-1)
                )
                val_loss += loss.item()
                
                # Generate captions for BLEU score calculation (for a subset)
                if len(hypotheses) < 100:  # Limit to 100 examples for speed
                    for j in range(min(images.size(0), 5)):
                        img = images[j:j+1]
                        
                        # Generate caption
                        generated_idx = generate_caption_indices(model, img, word_to_idx)
                        generated_words = [idx_to_word[idx] for idx in generated_idx 
                                          if idx not in [word_to_idx['<pad>'], word_to_idx['<start>'], word_to_idx['<end>']]]
                        
                        # Get reference caption
                        reference_idx = captions[j].cpu().numpy()
                        reference_words = [[idx_to_word[idx] for idx in reference_idx 
                                          if idx not in [word_to_idx['<pad>'], word_to_idx['<start>'], word_to_idx['<end>']]]]
                        
                        references.append(reference_words)
                        hypotheses.append(generated_words)
                        
                        # Save example images with captions every 5 epochs
                        if epoch % 5 == 0 and len(hypotheses) <= 5:
                            # Convert image for display
                            img_np = img.squeeze(0).cpu().permute(1, 2, 0).numpy()
                            img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
                            img_np = np.clip(img_np, 0, 1)
                            
                            plt.figure(figsize=(8, 6))
                            plt.imshow(img_np)
                            plt.title(f"Generated: {' '.join(generated_words)}\nReference: {' '.join(reference_words[0])}")
                            plt.axis('off')
                            plt.savefig(os.path.join(checkpoint_dir, 'samples', f'sample_{len(hypotheses)}_epoch_{epoch+1}.png'))
                            plt.close()
        
        # Calculate average validation loss
        val_loss = val_loss / len(val_loader)
        val_losses.append(val_loss)
        
        # Calculate BLEU scores
        bleu_metrics = calculate_bleu_metrics(references, hypotheses)
        for key, value in bleu_metrics.items():
            bleu_scores[key].append(value)
        
        # Log metrics
        writer.add_scalar('transformer/train_loss_epoch', train_loss, epoch)
        writer.add_scalar('transformer/val_loss', val_loss, epoch)
        for key, value in bleu_metrics.items():
            writer.add_scalar(f'transformer/{key}', value, epoch)
        
        time_elapsed = time.time() - start_time
        print(f'Epoch [{epoch+1}/{num_epochs}], Time: {time_elapsed:.2f}s, '
              f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, '
              f'BLEU-1: {bleu_metrics["bleu1"]:.4f}, BLEU-4: {bleu_metrics["bleu4"]:.4f}')
        
        # Save loss dict for checkpoints
        loss_dict = {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'best_val_loss': best_val_loss,
            **bleu_scores
        }
        
        # Save latest checkpoint
        save_checkpoint(model, optimizer, epoch, loss_dict, checkpoint_path)
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            loss_dict['best_val_loss'] = best_val_loss
            best_checkpoint_path = os.path.join(checkpoint_dir, 'best.pth')
            save_checkpoint(model, optimizer, epoch, loss_dict, best_checkpoint_path)
            print(f"Saved best model with validation loss: {val_loss:.4f}")
        
        # Save epoch checkpoint every 5 epochs
        if (epoch+1) % 5 == 0 or epoch == num_epochs-1:
            epoch_checkpoint_path = os.path.join(checkpoint_dir, f'epoch_{epoch+1}.pth')
            save_checkpoint(model, optimizer, epoch, loss_dict, epoch_checkpoint_path)
    
    # Plot and save training curves
    plot_learning_curves(
        train_losses, val_losses, 
        'Transformer Training and Validation Loss', 
        'Loss', 
        os.path.join(checkpoint_dir, 'learning_curve.png')
    )
    
    # Plot BLEU scores
    plot_metrics(
        bleu_scores,
        'BLEU Scores',
        os.path.join(checkpoint_dir, 'bleu_scores.png')
    )
    
    writer.close()
    return model

# Example usage:
# Initialize the captioning model and train it
# captioning_model = CaptioningModel(encoder, transformer_decoder).to(device)
# captioning_model = train_transformer(captioning_model, train_loader, val_loader, word_to_idx, idx_to_word, NUM_EPOCHS_TRANSFORMER)


In [9]:
def evaluate_model(model, test_loader, word_to_idx, idx_to_word, results_dir='results'):
    """Evaluate the model on the test set with visualizations."""
    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(os.path.join(results_dir, 'samples'), exist_ok=True)
    
    model.eval()
    references = []
    hypotheses = []
    
    # For visualization
    results_data = []
    
    with torch.no_grad():
        for i, (images, captions) in enumerate(test_loader):
            images = images.to(device)
            
            # Generate captions for all images in batch
            for j in range(len(images)):
                img = images[j:j+1]
                
                # Generate caption
                generated_idx = generate_caption_indices(model, img, word_to_idx)
                generated_words = [idx_to_word[idx] for idx in generated_idx 
                                  if idx not in [word_to_idx['<pad>'], word_to_idx['<start>'], word_to_idx['<end>']]]
                
                # Get reference caption
                reference_idx = captions[j].cpu().numpy()
                reference_words = [[idx_to_word[idx] for idx in reference_idx 
                                  if idx not in [word_to_idx['<pad>'], word_to_idx['<start>'], word_to_idx['<end>']]]]
                
                references.append(reference_words)
                hypotheses.append(generated_words)
                
                # Store result data
                results_data.append({
                    'image_idx': i * test_loader.batch_size + j,  # len(images) is smaller for the last batch
                    'generated_caption': ' '.join(generated_words),
                    'reference_caption': ' '.join(reference_words[0])
                })
                
                # Save some examples for visualization
                if len(results_data) <= 20:
                    # Convert image for display
                    img_np = img.squeeze(0).cpu().permute(1, 2, 0).numpy()
                    img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
                    img_np = np.clip(img_np, 0, 1)
                    
                    plt.figure(figsize=(8, 6))
                    plt.imshow(img_np)
                    plt.title(f"Generated: {' '.join(generated_words)}\nReference: {' '.join(reference_words[0])}")
                    plt.axis('off')
                    plt.savefig(os.path.join(results_dir, 'samples', f'sample_{len(results_data)}.png'))
                    plt.close()
            
            # Print progress
            if (i+1) % 10 == 0:
                print(f'Evaluated {len(hypotheses)}/{len(test_loader.dataset)} images')
    
    # Calculate BLEU scores
    bleu_metrics = calculate_bleu_metrics(references, hypotheses)
    
    # Print metrics
    print("\nEvaluation Metrics:")
    for key, value in bleu_metrics.items():
        print(f"{key.upper()}: {value:.4f}")
    
    # Save metrics to file
    with open(os.path.join(results_dir, 'metrics.json'), 'w') as f:
        json.dump(bleu_metrics, f, indent=4)
    
    # Save results to CSV
    results_df = pd.DataFrame(results_data)
    results_df.to_csv(os.path.join(results_dir, 'captioning_results.csv'), index=False)
    
    # Create a grid of examples
    plt.figure(figsize=(15, 20))
    for idx in range(min(10, len(results_data))):
        sample = results_data[idx]
        
        # Load image
        img_path = os.path.join(results_dir, 'samples', f'sample_{idx+1}.png')
        img = plt.imread(img_path)
        
        plt.subplot(5, 2, idx+1)
        plt.imshow(img)
        plt.axis('off')
    
    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, 'results_grid.png'))
    plt.close()
    
    return bleu_metrics

# Example usage:
# metrics = evaluate_model(captioning_model, test_loader, word_to_idx, idx_to_word)


In [10]:
csv_path = "final_dataset.csv"
base_path = "../../"

# Load data
train_df, val_df, test_df, word_to_idx, idx_to_word, vocab_size, max_length = load_data(csv_path)
MAX_LENGTH = max_length

# Create datasets for autoencoder training (without captions)
train_dataset_ae = ChestXrayDataset(train_df, transform=transform, base_path=base_path)
val_dataset_ae = ChestXrayDataset(val_df, transform=transform, base_path=base_path)

# Create datasets for transformer training (with captions)
train_dataset_transformer = ChestXrayDataset(
    train_df, transform=transform, max_length=MAX_LENGTH, 
    word_to_idx=word_to_idx, base_path=base_path
)
val_dataset_transformer = ChestXrayDataset(
    val_df, transform=transform, max_length=MAX_LENGTH, 
    word_to_idx=word_to_idx, base_path=base_path
)
test_dataset = ChestXrayDataset(
    test_df, transform=transform, max_length=MAX_LENGTH, 
    word_to_idx=word_to_idx, base_path=base_path
)

# Create data loaders
train_loader_ae = DataLoader(train_dataset_ae, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader_ae = DataLoader(val_dataset_ae, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

train_loader_transformer = DataLoader(train_dataset_transformer, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader_transformer = DataLoader(val_dataset_transformer, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

print(f"Data preparation complete. Vocabulary size: {vocab_size}")


Missing values in DataFrame:
Unnamed: 0        0
final_img_path    0
captions          0
dtype: int64
Max caption length: 125
Training samples: 4519, Validation samples: 969, Test samples: 969
Data preparation complete. Vocabulary size: 1988


In [9]:
!nvidia-smi

Mon Apr 14 14:32:40 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA H100 80GB HBM3          On  | 00000000:55:00.0 Off |                    0 |
| N/A   35C    P0              71W / 700W |      8MiB / 81559MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [9]:

print("\n=== Training Autoencoder ===")

# Initialize and train autoencoder
autoencoder = Autoencoder(LATENT_DIM).to(device)
autoencoder = train_autoencoder(
    autoencoder, 
    train_loader_ae, 
    val_loader_ae, 
    NUM_EPOCHS_AE
)

print("Autoencoder training complete!")



=== Training Autoencoder ===
Loaded checkpoint from epoch 9
Epoch [10/10], Step [100/142], Loss: 0.3849
Epoch [10/10], Time: 240.77s, Train Loss: 0.4017, Val Loss: 0.3951
Checkpoint saved to checkpoints/autoencoder/latest.pth
Checkpoint saved to checkpoints/autoencoder/best.pth
Saved best model with validation loss: 0.3951
Checkpoint saved to checkpoints/autoencoder/epoch_10.pth
Autoencoder training complete!


In [None]:

print("\n=== Training Transformer Decoder ===")

# Load the best autoencoder model
auto_checkpoint_path = 'checkpoints/autoencoder/best.pth'
autoencoder = Autoencoder(LATENT_DIM).to(device)
_, _ = load_checkpoint(autoencoder, None, auto_checkpoint_path, device)

# Extract the encoder part
encoder = autoencoder.encoder

# Create transformer decoder
transformer_decoder = DecoderTransformer(
    vocab_size=vocab_size,
    embed_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    nhead=NHEAD,
    num_layers=NUM_DECODER_LAYERS,
    dropout=DROPOUT
).to(device)

# Create captioning model
captioning_model = CaptioningModel(encoder, transformer_decoder).to(device)

# Freeze encoder weights (only train the decoder)
for param in captioning_model.encoder.parameters():
    param.requires_grad = False

# Train the transformer
captioning_model = train_transformer(
    captioning_model, 
    train_loader_transformer, 
    val_loader_transformer, 
    word_to_idx, 
    idx_to_word, 
    NUM_EPOCHS_TRANSFORMER
)

print("Transformer training complete!")



=== Training Transformer Decoder ===
Loaded checkpoint from epoch 10
Loaded checkpoint from epoch 10
Epoch [11/30], Step [100/142], Loss: 4.8798
Epoch [11/30], Time: 260.40s, Train Loss: 4.9578, Val Loss: 5.7685, BLEU-1: 0.0000, BLEU-4: 0.0000
Checkpoint saved to checkpoints/transformer/latest.pth
Epoch [12/30], Step [100/142], Loss: 4.4543
Epoch [12/30], Time: 259.86s, Train Loss: 4.0646, Val Loss: 5.1578, BLEU-1: 0.0000, BLEU-4: 0.0000
Checkpoint saved to checkpoints/transformer/latest.pth
Epoch [13/30], Step [100/142], Loss: 3.9906
Epoch [13/30], Time: 256.56s, Train Loss: 3.9376, Val Loss: 6.6015, BLEU-1: 0.0000, BLEU-4: 0.0000
Checkpoint saved to checkpoints/transformer/latest.pth
Epoch [14/30], Step [100/142], Loss: 2.8358
Epoch [14/30], Time: 258.19s, Train Loss: 3.8187, Val Loss: 7.6144, BLEU-1: 0.0000, BLEU-4: 0.0000
Checkpoint saved to checkpoints/transformer/latest.pth
Epoch [15/30], Step [100/142], Loss: 3.6798
Epoch [15/30], Time: 262.51s, Train Loss: 3.7696, Val Loss: 8.

In [11]:

print("\n=== Evaluating on Test Set ===")

# Load best transformer model
best_model_path = 'checkpoints/transformer/best.pth'
if os.path.exists(best_model_path):
    captioning_model = CaptioningModel(encoder, transformer_decoder).to(device)
    _, _ = load_checkpoint(captioning_model, None, best_model_path, device)

# Evaluate
metrics = evaluate_model(captioning_model, test_loader, word_to_idx, idx_to_word)

print("\n=== Training and Evaluation Complete ===")
print(f"Final BLEU-1: {metrics['bleu1']:.4f}")
print(f"Final BLEU-4: {metrics['bleu4']:.4f}")

# Save final model
torch.save(captioning_model.state_dict(), 'results/final_model.pth')



=== Evaluating on Test Set ===
Loaded checkpoint from epoch 9
Evaluated 320/992 images
Evaluated 640/992 images
Evaluated 960/992 images

Evaluation Metrics:
BLEU1: 0.0000
BLEU2: 0.0000
BLEU3: 0.0000
BLEU4: 0.0000

=== Training and Evaluation Complete ===
Final BLEU-1: 0.0000
Final BLEU-4: 0.0000


## Results and Analysis
From the training outputs included in the notebook, we can observe several aspects of the model's performance:

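1. Autoencoder Training: The autoencoder completed its 10 training epochs (the log shows only epoch 10 because training resumed from the epoch 9 checkpoint), ending with a training loss of 0.4017 and a validation loss of 0.3951. The validation loss tracking the training loss suggests the autoencoder is not overfitting; whether 0.3951 corresponds to good reconstruction quality would need to be confirmed by inspecting reconstructed images, since the loss value alone has no reference point.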
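2. Transformer Training: The results here were mixed. The training loss decreased steadily (from 4.9578 at epoch 11 to 3.4380 later in the run; the cell output above is cut off during epoch 15), but the validation loss reached its lowest value of 5.1578 at epoch 12 and then rose (6.6015 at epoch 13, 7.6144 at epoch 14), ranging from 5.1578 to 9.9787 overall. A falling training loss paired with a rising validation loss is the typical signature of overfitting, possibly compounded by training instability.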
3. BLEU Scores: The validation BLEU scores stayed at 0.0000 throughout training, meaning the generated captions did not meaningfully match the reference captions. This could be due to several factors:

    - The transformer might need more training time to learn the complex mapping from image features to text
    - The cross-modal gap between visual features and text might be too challenging for the current architecture
    - The dataset size might be insufficient for the model to learn meaningful patterns


4. Final Evaluation: The evaluation on the test set confirmed the issues observed during training, with all BLEU metrics at 0.0000.
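
The overfitting reading in item 2 can be checked mechanically from the logged numbers. The sketch below is a minimal, hypothetical helper (`epochs_since_best` is not part of the notebook) that takes the validation losses printed above for epochs 11 to 14 (epoch 15's value is cut off in the output, so it is left out) and reports the best epoch and how many epochs have passed without improvement. The `PATIENCE` value is an illustrative assumption.

```python
def epochs_since_best(val_losses, start_epoch=1):
    """Return (best_epoch, best_loss, epochs_without_improvement).

    Epochs are numbered from start_epoch to match the 'Epoch [k/30]'
    numbering used in the training log.
    """
    best_idx = min(range(len(val_losses)), key=lambda i: val_losses[i])
    stale = len(val_losses) - 1 - best_idx
    return start_epoch + best_idx, val_losses[best_idx], stale


# Validation losses logged for epochs 11-14 of the resumed transformer run
logged_val_losses = [5.7685, 5.1578, 6.6015, 7.6144]
best_epoch, best_loss, stale = epochs_since_best(logged_val_losses, start_epoch=11)
print(f"Best epoch: {best_epoch} (val loss {best_loss}), epochs without improvement: {stale}")

# Hypothetical early-stopping rule: stop once validation loss has not
# improved for PATIENCE consecutive epochs
PATIENCE = 2
if stale >= PATIENCE:
    print("Early stopping would trigger here")
```

For the logged values the minimum falls at epoch 12 and the loss rises for the two epochs after it, so a patience of 2 would already stop training even though the training loss keeps falling. Note also that no "Saved best model" line appears for epochs 11 to 14, so the `best.pth` used for evaluation predates the resumed run (the evaluation cell reports loading it from epoch 9).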

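The zero BLEU scores mean that the generated captions scored essentially nothing against the references. A zero score does not necessarily imply that no words overlap: standard BLEU multiplies n-gram precision by a brevity penalty that drives the score toward zero when outputs are much shorter than the references, and to exactly zero when they are empty. This could happen if: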

1. The model is generating empty or very short captions (e.g., single tokens), which BLEU's brevity penalty pushes toward zero
2. The model is repeating the same generic phrases for all images
3. The model might be focusing on completely different aspects of the images than those described in the reference captions
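
One way to tell these cases apart is to inspect the hypotheses directly. The sketch below assumes the data layout that `evaluate_model` builds (`references` is a list of single-element lists of reference token lists, `hypotheses` a list of generated token lists). `corpus_bleu1` is a simplified re-implementation of corpus-level BLEU-1 for diagnosis, not the notebook's `calculate_bleu_metrics`: clipped unigram precision multiplied by the brevity penalty `BP = exp(1 - r/c)`, where `r` and `c` are the total reference and hypothesis lengths (`BP = 1` when `c > r`). `summarize_hypotheses` reports length and diversity. The function names and toy captions are illustrative only.

```python
import math
from collections import Counter


def corpus_bleu1(references, hypotheses):
    """Corpus-level BLEU-1: clipped unigram precision times the brevity penalty.

    references: list of [reference_tokens], one reference per sample
    hypotheses: list of generated token lists
    """
    matched = hyp_len = ref_len = 0
    for refs, hyp in zip(references, hypotheses):
        ref_counts = Counter(refs[0])
        # Each generated word counts at most as often as it appears in the reference
        matched += sum(min(n, ref_counts[w]) for w, n in Counter(hyp).items())
        hyp_len += len(hyp)
        ref_len += len(refs[0])
    if hyp_len == 0:
        return 0.0  # all hypotheses empty: the brevity penalty is zero
    precision = matched / hyp_len
    brevity_penalty = 1.0 if hyp_len > ref_len else math.exp(1 - ref_len / hyp_len)
    return brevity_penalty * precision


def summarize_hypotheses(hypotheses):
    """Length and diversity statistics for a list of generated captions."""
    lengths = [len(h) for h in hypotheses]
    return {
        "empty_fraction": sum(n == 0 for n in lengths) / len(lengths),
        "mean_length": sum(lengths) / len(lengths),
        "distinct_captions": len({" ".join(h) for h in hypotheses}),
    }


# Toy example (made-up captions)
refs = [
    [["the", "heart", "size", "is", "normal", "no", "effusion"]],
    [["no", "acute", "cardiopulmonary", "abnormality", "lungs", "are", "clear"]],
]
one_word = [["normal"], ["no"]]
longer = [["the", "heart", "is", "normal"], ["no", "acute", "abnormality", "lungs", "clear"]]

print(f"BLEU-1, one-word captions: {corpus_bleu1(refs, one_word):.4f}")
print(f"BLEU-1, longer captions:   {corpus_bleu1(refs, longer):.4f}")
print(summarize_hypotheses(one_word))
```

With these toy captions, one-word hypotheses that are perfectly precise still score only `exp(-6)`, about 0.0025, and real reports of up to 125 tokens would push such a score far below what four decimal places can show. A high `empty_fraction` or a very small `mean_length` would point to cause 1 (or to a decoding bug that stops generation immediately), while `distinct_captions` close to 1 would point to cause 2. To run the same checks on the saved test output, the `generated_caption` column of `captioning_results.csv` can be split on spaces; pandas reads empty strings in a CSV back as `NaN`, so apply `fillna('')` first.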

## Conclusions and Future Work

This experiment demonstrated the implementation of a two-stage approach for medical image captioning using a custom autoencoder and transformer decoder. While the autoencoder stage trained to a stable reconstruction loss, the caption generation stage faced significant challenges: validation loss rose after epoch 12 and BLEU scores stayed at zero.

For future work, several improvements could be explored:

1. Improved Architectures: Using more sophisticated architectures for both the encoder and decoder components
2. Pretraining: Leveraging medical domain-specific pretraining for the encoder
3. Transformer Variants: Experimenting with other transformer variants like BART or T5
4. Learning Rate Scheduling: Implementing a more sophisticated learning rate schedule to stabilize training
5. Data Augmentation: Increasing the effective dataset size through data augmentation techniques
6. Attention Visualization: Adding attention visualization to understand what parts of the images the model focuses on
7. Teacher Forcing Reduction: Gradually reducing teacher forcing during training (scheduled sampling) so the decoder learns to condition on its own predictions, as it must at inference time
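
Items 4 and 7 can both be written as simple schedules. The sketch below shows one common choice for each, not something this notebook currently uses: a learning-rate multiplier with linear warmup followed by inverse-square-root decay (the shape of the schedule in the original Transformer paper, rescaled so it peaks at 1.0), and the inverse-sigmoid decay of the teacher-forcing probability proposed for scheduled sampling by Bengio et al. (2015). The function names and the default `warmup_steps` and `k` values are illustrative assumptions.

```python
import math


def warmup_inverse_sqrt(step, warmup_steps=1000):
    """LR multiplier: linear warmup to 1.0, then decay proportional to 1/sqrt(step).

    Written to be used as the lr_lambda of torch.optim.lr_scheduler.LambdaLR,
    which multiplies the optimizer's base learning rate by the returned value.
    """
    step = step + 1  # LambdaLR passes 0 on the first step
    return min(step / warmup_steps, math.sqrt(warmup_steps / step))


def teacher_forcing_prob(epoch, k=10.0):
    """Inverse-sigmoid decay of the probability of feeding the ground-truth token.

    Starts near 1 and decays toward 0; a larger k slows the decay.
    """
    return k / (k + math.exp(epoch / k))


for step in (0, 499, 999, 3999):
    print(f"step {step}: lr multiplier {warmup_inverse_sqrt(step):.4f}")
for epoch in (0, 10, 30):
    print(f"epoch {epoch}: teacher forcing p = {teacher_forcing_prob(epoch):.3f}")
```

In PyTorch, `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_inverse_sqrt)` multiplies the base learning rate (here `LEARNING_RATE_TRANSFORMER`) by this factor, provided `scheduler.step()` is called after every optimizer step. Using `teacher_forcing_prob` would additionally require a training loop that feeds the decoder its own previous predictions with probability `1 - teacher_forcing_prob(epoch)`, which is less straightforward for a transformer decoder trained on whole sequences in parallel than for a step-by-step RNN decoder.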

Despite the current limitations, this approach provides a reasonable starting point for further experiments with medical image captioning systems. Because the autoencoder and transformer components are trained separately, each part can be modified and evaluated independently.