# Medical Image to Text Report Generation
## Experiment 4: Autoencoder + Transformer Decoder

This notebook implements an approach to medical image captioning that draws inspiration from Stable Diffusion. A custom autoencoder trained with a GAN objective compresses X-ray images into a compact latent representation, which a transformer decoder then uses to generate diagnostic reports.

## Import Libraries and Dependencies

In [5]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import math
import time
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.meteor_score import meteor_score
from rouge import Rouge

In [2]:
!pip install rouge


## Config and Hyperparameters
I set the hyperparameters for both the autoencoder and transformer components. This approach requires two separate training phases, so I define different epoch counts for each. The latent dimension (256) is the size of the compressed image representation. It matches `EMBEDDING_DIM` because the latent vector is passed to the transformer decoder as memory without a projection layer.

In [8]:
torch.manual_seed(42)
np.random.seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define hyperparameters
BATCH_SIZE = 32
EMBEDDING_DIM = 256
HIDDEN_DIM = 512
NHEAD = 8
NUM_DECODER_LAYERS = 6
LATENT_DIM = 256
DROPOUT = 0.1
LEARNING_RATE = 1e-4
NUM_EPOCHS_AE = 10
NUM_EPOCHS_TRANSFORMER = 30
MAX_LENGTH = 100

# Define transforms for images
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                         std=[0.229, 0.224, 0.225])
])
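
Two constraints on these values follow from the architecture rather than from anything stated explicitly. PyTorch's multi-head attention requires the embedding dimension to be divisible by the number of heads, and the latent vector is used directly as transformer memory, so its size has to match the embedding dimension. A minimal sanity check, assuming the constant names above (the `check_config` helper is hypothetical and not part of the notebook):

```python
# Hypothetical helper: validates the hyperparameters before any model is built.
def check_config(embedding_dim, nhead, latent_dim):
    problems = []
    if embedding_dim % nhead != 0:
        problems.append("EMBEDDING_DIM must be divisible by NHEAD")
    if latent_dim != embedding_dim:
        problems.append("LATENT_DIM must equal EMBEDDING_DIM (latent is used as memory)")
    return problems

EMBEDDING_DIM, NHEAD, LATENT_DIM = 256, 8, 256
issues = check_config(EMBEDDING_DIM, NHEAD, LATENT_DIM)
print(issues or "config OK")
```

Running a check like this before training catches a mismatch early instead of as a shape error deep inside the decoder.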



## Dataset Class
The dataset class handles loading images and captions, with special attention to path resolution for different environments. It also takes care of tokenizing and padding captions to a consistent length.

In [36]:
class ChestXrayDataset(Dataset):
    def __init__(self, dataframe, transform=None, max_length=100, word_to_idx=None, base_path=None):
        self.dataframe = dataframe
        self.transform = transform
        self.max_length = max_length
        self.word_to_idx = word_to_idx
        self.base_path = base_path
        
    def __len__(self):
        return len(self.dataframe)
    
    def __getitem__(self, idx):
        img_path = self.dataframe.iloc[idx]['final_img_path']
        caption = self.dataframe.iloc[idx]['captions']
        
        # Adjust the image path if base_path is provided
        if self.base_path:
            # Extract the part of the path after 'data/'
            if 'data/' in img_path:
                relative_path = img_path[img_path.find('data/'):]
                img_path = os.path.join(self.base_path, relative_path)
            else:
                # If 'data/' is not in the path, just join with base_path
                img_path = os.path.join(self.base_path, img_path)
        
        # Load and transform image
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        
        # Convert caption to tensor
        caption_encoded = [self.word_to_idx.get(word, self.word_to_idx['<unk>']) 
                          for word in ['<start>'] + caption.split() + ['<end>']]
        
        # Pad caption if needed; when truncating, keep <end> as the final token
        if len(caption_encoded) < self.max_length:
            caption_encoded = caption_encoded + [self.word_to_idx['<pad>']] * (self.max_length - len(caption_encoded))
        else:
            caption_encoded = caption_encoded[:self.max_length - 1] + [self.word_to_idx['<end>']]
            
        caption_tensor = torch.tensor(caption_encoded, dtype=torch.long)
        
        return image, caption_tensor
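
The caption handling can be tried without any images. The sketch below reproduces the tokenize, pad, and truncate steps in plain Python with a toy vocabulary. The `encode_caption` helper and the vocabulary entries are illustrative assumptions, and keeping `<end>` as the last token when truncating is a choice made in this sketch.

```python
# Toy vocabulary; the real one is built by load_data.
word_to_idx = {'<pad>': 0, '<start>': 1, '<end>': 2, '<unk>': 3,
               'no': 4, 'acute': 5, 'disease': 6}

def encode_caption(caption, word_to_idx, max_length):
    """Wrap with <start>/<end>, map unknown words to <unk>, pad or truncate."""
    tokens = ['<start>'] + caption.split() + ['<end>']
    ids = [word_to_idx.get(tok, word_to_idx['<unk>']) for tok in tokens]
    if len(ids) < max_length:
        ids += [word_to_idx['<pad>']] * (max_length - len(ids))
    else:
        # Truncate but keep <end> as the final token
        ids = ids[:max_length - 1] + [word_to_idx['<end>']]
    return ids

encoded = encode_caption("no acute cardiopulmonary disease", word_to_idx, max_length=8)
print(encoded)  # [1, 4, 5, 3, 6, 2, 0, 0]
```

The out-of-vocabulary word "cardiopulmonary" maps to index 3 (`<unk>`), and the two trailing zeros are padding.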

## Autoencoder Components
The autoencoder consists of three components:

1. Encoder: A series of convolutional layers that compress the image to a latent representation
2. Decoder: Transpose convolutional layers that reconstruct the image from the latent space
3. Discriminator: A binary classifier that learns to tell real images from reconstructions, providing the adversarial signal used to sharpen the reconstructions

This architecture is inspired by Stable Diffusion but simplified for the medical imaging domain. The encoder halves the spatial resolution four times (from 224x224 down to 14x14) while increasing the feature channels from 3 to 512, ending with a dense layer that maps the flattened features to the latent space.

In [11]:
# Simple Autoencoder
class Encoder(nn.Module):
    def __init__(self, latent_dim=256):
        super(Encoder, self).__init__()
        # Use a simpler architecture
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
        
        # For 224x224 input, feature map size will be 14x14x512
        self.fc = nn.Linear(14 * 14 * 512, latent_dim)
        
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.relu(self.conv4(x))
        
        x = x.view(x.size(0), -1)  # Flatten
        x = self.fc(x)  # Map to latent space
        return x

class Decoder(nn.Module):
    def __init__(self, latent_dim=256):
        super(Decoder, self).__init__()
        self.fc = nn.Linear(latent_dim, 14 * 14 * 512)
        
        self.deconv1 = nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv3 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv4 = nn.ConvTranspose2d(64, 3, kernel_size=3, stride=2, padding=1, output_padding=1)
        
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()
        
    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), 512, 14, 14)  # Reshape to feature map
        
        x = self.relu(self.deconv1(x))
        x = self.relu(self.deconv2(x))
        x = self.relu(self.deconv3(x))
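        # tanh bounds the output to [-1, 1]. The inputs are ImageNet-normalized and
        # span roughly [-2.1, 2.6], so the darkest and brightest pixels cannot be
        # reconstructed exactly.
        x = self.tanh(self.deconv4(x))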
        return x


# Simple GAN Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # Simple discriminator
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
        
        # For 224x224 input, feature map size after convs is 14x14x512
        self.fc = nn.Linear(14 * 14 * 512, 1)
        
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        x = self.leaky_relu(self.conv1(x))
        x = self.leaky_relu(self.conv2(x))
        x = self.leaky_relu(self.conv3(x))
        x = self.leaky_relu(self.conv4(x))
        
        x = x.view(x.size(0), -1)  # Flatten
        x = self.sigmoid(self.fc(x))  # Output probability
        return x
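
The `14 * 14 * 512` sizes in the linear layers above can be checked with the standard output-size formulas for convolutions and transposed convolutions (dilation 1). The helper names below are illustrative and not part of the notebook.

```python
def conv_out(n, kernel=3, stride=2, padding=1):
    """Output size of a Conv2d layer along one spatial dimension."""
    return (n + 2 * padding - kernel) // stride + 1

def deconv_out(n, kernel=3, stride=2, padding=1, output_padding=1):
    """Output size of a ConvTranspose2d layer along one spatial dimension."""
    return (n - 1) * stride - 2 * padding + kernel + output_padding

# Encoder / discriminator: four stride-2 convolutions
sizes = [224]
for _ in range(4):
    sizes.append(conv_out(sizes[-1]))
print(sizes)  # [224, 112, 56, 28, 14]

# Decoder: four stride-2 transposed convolutions
up = [14]
for _ in range(4):
    up.append(deconv_out(up[-1]))
print(up)  # [14, 28, 56, 112, 224]

flat = sizes[-1] ** 2 * 512
print(flat)  # 100352 input features for the encoder's linear layer
```

The `output_padding=1` argument is what lets each transposed convolution exactly double the resolution. Without it, the decoder would produce 27x27 instead of 28x28 at the first step.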

## Transformer Components

1. Positional encoding is critical for the transformer, as it provides sequential information to the otherwise position-agnostic architecture. This implementation uses the standard sinusoidal encodings proposed in the original transformer paper.
2. The transformer decoder receives the encoded image as its "memory" input and generates text tokens autoregressively. Masked self-attention attends to previously generated tokens, and cross-attention attends to the encoded image. Because the image is encoded as a single latent vector, the memory has sequence length 1.

In [30]:
# Positional Encoding for Transformer
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
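
To make the sinusoidal scheme concrete, here is the same formula for a single position in plain Python, with no tensors involved. It is a sketch that assumes an even `d_model`, and the function name is illustrative.

```python
import math

def sinusoidal_encoding(position, d_model):
    """Positional encoding vector for one position (d_model assumed even)."""
    vec = [0.0] * d_model
    for i in range(0, d_model, 2):
        # Same frequency term as div_term in the class above
        angle = position * math.exp(-math.log(10000.0) * i / d_model)
        vec[i] = math.sin(angle)      # even indices use sine
        vec[i + 1] = math.cos(angle)  # odd indices use cosine
    return vec

pe0 = sinusoidal_encoding(0, 8)
pe1 = sinusoidal_encoding(1, 8)
print(pe0)      # sines are 0 and cosines are 1 at position 0
print(pe1[:2])  # sin(1), cos(1) for the highest-frequency pair
```

Each pair of dimensions oscillates at a different frequency, so every position receives a distinct vector that is added to the token embedding.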

In [13]:
# Transformer Decoder
class DecoderTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, nhead, num_layers, dropout=0.1):
        super(DecoderTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoder = PositionalEncoding(embed_dim, dropout)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_dim, 
            nhead=nhead,
            dim_feedforward=hidden_dim,
            dropout=dropout
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_dim, vocab_size)
        
    def forward(self, encoded_image, captions, tgt_mask=None):
        # Convert encoded image to the right shape for transformer
        # encoded_image shape: [batch_size, embed_dim]
        # We need: [1, batch_size, embed_dim] for memory in transformer
        memory = encoded_image.unsqueeze(0)
        
        # Embed captions
        # captions shape: [batch_size, seq_len]
        embedded = self.embedding(captions)  # [batch_size, seq_len, embed_dim]
        
        # Add positional encoding
        embedded = self.pos_encoder(embedded.permute(1, 0, 2))  # [seq_len, batch_size, embed_dim]
        
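        # Create target mask to prevent attention to future tokens
        if tgt_mask is None:
            seq_len = captions.size(1)
            # Place the mask on the same device as the inputs instead of relying on the global `device`
            tgt_mask = generate_square_subsequent_mask(seq_len).to(captions.device)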
        
        # Decode
        output = self.transformer_decoder(
            tgt=embedded,
            memory=memory,
            tgt_mask=tgt_mask
        )  # [seq_len, batch_size, embed_dim]
        
        # Project to vocabulary
        output = self.fc_out(output)  # [seq_len, batch_size, vocab_size]
        
        # Reshape to [batch_size, seq_len, vocab_size]
        return output.permute(1, 0, 2)
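
The target mask built in the forward pass is an additive mask: 0.0 where attention is allowed and negative infinity where it is blocked. A plain-Python sketch of the same pattern (the `causal_mask` name is illustrative):

```python
def causal_mask(size):
    """Additive attention mask: 0.0 where attention is allowed, -inf where blocked."""
    neg_inf = float('-inf')
    return [[0.0 if col <= row else neg_inf for col in range(size)]
            for row in range(size)]

mask = causal_mask(3)
for row in mask:
    print(row)
# [0.0, -inf, -inf]
# [0.0, 0.0, -inf]
# [0.0, 0.0, 0.0]
```

Row `i` corresponds to the token being predicted at step `i`, which may only look at columns `0..i`. Adding negative infinity to an attention score drives its softmax weight to zero.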

## Complete Model
The full model combines the encoder from the autoencoder and the transformer decoder into a cohesive pipeline. The encoder creates a latent representation of the image, which is then passed to the transformer decoder to generate the textual report.

In [14]:
# Full image captioning model
class ImageCaptioningModel(nn.Module):
    def __init__(self, encoder, decoder):
        super(ImageCaptioningModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, images, captions, tgt_mask=None):
        # Encode images to get latent representation
        latent_vectors = self.encoder(images)
        
        # Decode using transformer decoder
        outputs = self.decoder(latent_vectors, captions, tgt_mask)
        return outputs

## Utility Functions
These utility functions handle data loading, vocabulary creation, and the generation of attention masks for the transformer. The **load_data** function is particularly important as it prepares the data and constructs the vocabulary mappings.

In [28]:
# Utility function to generate square mask for transformer
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

# Load and preprocess data
def load_data(csv_path):
    # Load CSV file
    df = pd.read_csv(csv_path)
    
    # Check for missing values
    print(f"Missing values in DataFrame:\n{df.isnull().sum()}")
    
    # Create vocabulary from captions
    all_captions = df['captions'].tolist()
    words = set()
    for caption in all_captions:
        words.update(caption.split())
    
    # Create word indices
    word_to_idx = {
        '<pad>': 0,
        '<start>': 1,
        '<end>': 2,
        '<unk>': 3
    }
    
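    # Sort the words so that index assignment is reproducible across runs
    # (set iteration order can change between Python processes, which would
    # make saved checkpoints incompatible with a rebuilt vocabulary)
    idx = 4
    for word in sorted(words):
        word_to_idx[word] = idx
        idx += 1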
    
    idx_to_word = {idx: word for word, idx in word_to_idx.items()}
    vocab_size = len(word_to_idx)
    
    # Determine max caption length for padding
    max_length = max(len(caption.split()) for caption in all_captions) + 2  # +2 for <start> and <end>
    print(f"Max caption length: {max_length}")
    
    # Split data into train, validation, and test sets (70%, 15%, 15%)
    train_df, temp_df = train_test_split(df, test_size=0.3, random_state=42)
    val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)
    
    print(f"Training samples: {len(train_df)}, Validation samples: {len(val_df)}, Test samples: {len(test_df)}")
    
    return train_df, val_df, test_df, word_to_idx, idx_to_word, vocab_size, max_length
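
The vocabulary construction can be illustrated on a few toy captions. The captions below are made up for the example. The sketch sorts the words so that the word-to-index mapping stays the same from run to run.

```python
captions = ["no acute disease", "heart size normal", "no acute findings"]

# Special tokens take the first four indices, as in load_data.
word_to_idx = {'<pad>': 0, '<start>': 1, '<end>': 2, '<unk>': 3}

# Sorting keeps the index assignment reproducible across runs.
words = sorted({w for c in captions for w in c.split()})
for offset, word in enumerate(words):
    word_to_idx[word] = 4 + offset

idx_to_word = {i: w for w, i in word_to_idx.items()}
vocab_size = len(word_to_idx)
max_length = max(len(c.split()) for c in captions) + 2  # +2 for <start> and <end>

print(word_to_idx)
print(vocab_size, max_length)  # 11 5
```

Seven distinct words plus four special tokens give a vocabulary of 11, and the longest caption has three words, so `max_length` is 5.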


## Training Functions


### Autoencoder Training
This function implements a GAN-based training approach for the autoencoder. It alternates between:

1. Training the encoder-decoder pair to minimize reconstruction error and fool the discriminator
2. Training the discriminator to distinguish between real and reconstructed images

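The adversarial term is weighted at 0.001 relative to the MSE reconstruction loss, so reconstruction dominates the generator update. The GAN objective acts as a small additional signal intended to make reconstructions more realistic and detailed, which can lead to better latent representations.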
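
The two alternating objectives can be worked through by hand on toy numbers. The sketch below uses plain Python in place of `nn.MSELoss` and `nn.BCELoss`. The pixel values and discriminator outputs are invented for illustration.

```python
import math

def mse(pred, target):
    """Mean squared error over flattened values."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def bce(probs, labels):
    """Binary cross-entropy averaged over the batch (probabilities strictly in (0, 1))."""
    return -sum(y * math.log(p) + (1 - y) * math.log(1 - p)
                for p, y in zip(probs, labels)) / len(probs)

# Generator step: reconstruct well and make the discriminator say "real" (label 1).
recon_loss = mse([0.5, 0.0], [1.0, 0.0])   # 0.125
gen_loss = bce([0.5, 0.5], [1.0, 1.0])     # log(2) for an undecided discriminator
g_loss = recon_loss + 0.001 * gen_loss

# Discriminator step: real images labelled 1, reconstructions labelled 0.
d_loss = 0.5 * (bce([0.5, 0.5], [1.0, 1.0]) + bce([0.5, 0.5], [0.0, 0.0]))
print(g_loss, d_loss)
```

With a discriminator that outputs 0.5 for everything, the adversarial term contributes only about 0.0007 to the generator loss, which shows how small its weight is next to the reconstruction term.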

In [16]:
# Function to train the autoencoder with GAN
def train_autoencoder_gan(encoder, decoder, discriminator, train_loader, val_loader, device, num_epochs=20, checkpoint_dir='checkpoints/autoencoder'):
    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Optimizers
    optimizer_G = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.0002, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    
    # Loss functions
    reconstruction_loss = nn.MSELoss()
    adversarial_loss = nn.BCELoss()
    
    # Training metrics
    train_recon_losses = []
    train_gen_losses = []
    train_disc_losses = []
    val_recon_losses = []
    
    # Best validation loss for model saving
    best_val_loss = float('inf')
    
    # Training loop
    for epoch in range(num_epochs):
        start_time = time.time()
        encoder.train()
        decoder.train()
        discriminator.train()
        
        total_recon_loss = 0
        total_gen_loss = 0
        total_disc_loss = 0
        
        for batch_idx, (real_images, _) in enumerate(train_loader):
            real_images = real_images.to(device)
            batch_size = real_images.size(0)
            
            # Ground truth labels
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)
            
            # -----------------
            # Train Generator (Encoder-Decoder)
            # -----------------
            optimizer_G.zero_grad()
            
            # Generate latent representation and reconstruct
            latent_vectors = encoder(real_images)
            reconstructed_images = decoder(latent_vectors)
            
            # Calculate reconstruction loss
            recon_loss = reconstruction_loss(reconstructed_images, real_images)
            
            # Calculate generator adversarial loss
            validity = discriminator(reconstructed_images)
            gen_loss = adversarial_loss(validity, real_labels)
            
            # Combined loss
            g_loss = recon_loss + 0.001 * gen_loss
            g_loss.backward()
            optimizer_G.step()
            
            # -----------------
            # Train Discriminator
            # -----------------
            optimizer_D.zero_grad()
            
            # Loss for real images
            validity_real = discriminator(real_images)
            d_real_loss = adversarial_loss(validity_real, real_labels)
            
            # Loss for fake images
            validity_fake = discriminator(reconstructed_images.detach())
            d_fake_loss = adversarial_loss(validity_fake, fake_labels)
            
            # Total discriminator loss
            d_loss = 0.5 * (d_real_loss + d_fake_loss)
            d_loss.backward()
            optimizer_D.step()
            
            # Save statistics
            total_recon_loss += recon_loss.item()
            total_gen_loss += gen_loss.item()
            total_disc_loss += d_loss.item()
            
            if (batch_idx + 1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], "
                      f"Recon Loss: {recon_loss.item():.4f}, Gen Loss: {gen_loss.item():.4f}, Disc Loss: {d_loss.item():.4f}")
        
        # Validation
        encoder.eval()
        decoder.eval()
        val_recon_loss = 0
        
        with torch.no_grad():
            for val_images, _ in val_loader:
                val_images = val_images.to(device)
                
                # Encode and reconstruct
                val_latent = encoder(val_images)
                val_reconstructed = decoder(val_latent)
                
                # Calculate reconstruction loss
                val_loss = reconstruction_loss(val_reconstructed, val_images)
                val_recon_loss += val_loss.item()
        
        # Calculate average losses
        avg_train_recon = total_recon_loss / len(train_loader)
        avg_train_gen = total_gen_loss / len(train_loader)
        avg_train_disc = total_disc_loss / len(train_loader)
        avg_val_recon = val_recon_loss / len(val_loader)
        
        # Save metrics
        train_recon_losses.append(avg_train_recon)
        train_gen_losses.append(avg_train_gen)
        train_disc_losses.append(avg_train_disc)
        val_recon_losses.append(avg_val_recon)
        
        # Print epoch statistics
        time_elapsed = time.time() - start_time
        print(f"Epoch [{epoch+1}/{num_epochs}], Time: {time_elapsed:.2f}s, "
              f"Train Recon Loss: {avg_train_recon:.4f}, Train Gen Loss: {avg_train_gen:.4f}, "
              f"Train Disc Loss: {avg_train_disc:.4f}, Val Recon Loss: {avg_val_recon:.4f}")
        
        # Save best model
        if avg_val_recon < best_val_loss:
            best_val_loss = avg_val_recon
            checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')
            torch.save({
                'epoch': epoch,
                'encoder_state_dict': encoder.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'train_recon_loss': avg_train_recon,
                'val_recon_loss': avg_val_recon,
            }, checkpoint_path)
            print(f"Saved best model with validation loss: {avg_val_recon:.4f}")
        
        # Save regular checkpoint
        if (epoch + 1) % 5 == 0 or epoch == num_epochs - 1:
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
            torch.save({
                'epoch': epoch,
                'encoder_state_dict': encoder.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_G_state_dict': optimizer_G.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                'train_recon_loss': avg_train_recon,
                'val_recon_loss': avg_val_recon,
            }, checkpoint_path)
    
    # Plot training curves
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_recon_losses, label='Train Reconstruction Loss')
    plt.plot(val_recon_losses, label='Validation Reconstruction Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Reconstruction Loss')
    
    plt.subplot(1, 2, 2)
    plt.plot(train_gen_losses, label='Generator Loss')
    plt.plot(train_disc_losses, label='Discriminator Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Adversarial Losses')
    
    plt.tight_layout()
    plt.savefig(os.path.join(checkpoint_dir, 'autoencoder_training_curves.png'))
    plt.show()
    
    # Save metrics
    metrics = {
        'train_recon_losses': train_recon_losses,
        'train_gen_losses': train_gen_losses,
        'train_disc_losses': train_disc_losses,
        'val_recon_losses': val_recon_losses
    }
    np.save(os.path.join(checkpoint_dir, 'training_metrics.npy'), metrics)
    
    return encoder, decoder, discriminator

### Transformer Training
After the autoencoder is trained, this function handles training the transformer decoder component. It:

1. Processes batches of images and captions
2. Calculates loss using cross-entropy (ignoring padding tokens)
3. Performs backpropagation and optimization
4. Evaluates on the validation set using BLEU scores
5. Saves checkpoints and training curves
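
Steps 1 and 2 rely on teacher forcing: the decoder input is the caption without its last token, and the target is the caption without its first token. A plain-Python sketch of that shift and of a cross-entropy that skips padding positions follows. The token ids and the uniform predictions are toy values, and the helper name is illustrative.

```python
import math

PAD = 0
caption = [1, 4, 5, 2, PAD, PAD]   # <start> no acute <end> <pad> <pad>

decoder_input = caption[:-1]       # what the decoder sees
target = caption[1:]               # what it must predict at each step

def masked_cross_entropy(step_probs, target, pad_idx=PAD):
    """Mean negative log-likelihood over non-padding target positions."""
    losses = [-math.log(probs[t]) for probs, t in zip(step_probs, target) if t != pad_idx]
    return sum(losses) / len(losses)

# Toy per-step distributions over a 6-word vocabulary (uniform for simplicity).
uniform = [1 / 6] * 6
loss = masked_cross_entropy([uniform] * len(target), target)
print(decoder_input, target, loss)  # loss equals log(6) for uniform predictions
```

Only the three non-padding targets contribute to the average. In PyTorch the same effect is usually obtained by passing the padding index as `ignore_index` to `nn.CrossEntropyLoss`.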

In [17]:
# Function to train the transformer model
def train_transformer(model, train_loader, val_loader, word_to_idx, idx_to_word, criterion, optimizer, num_epochs=30, checkpoint_dir='checkpoints/transformer'):
    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Training metrics
    train_losses = []
    val_losses = []
    bleu_scores = []
    
    # Best validation loss for model saving
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        start_time = time.time()
        model.train()
        total_loss = 0
        
        for batch_idx, (images, captions) in enumerate(train_loader):
            images = images.to(device)
            captions = captions.to(device)
            
            # Zero the gradients
            optimizer.zero_grad()
            
            # Forward pass (remove last token from input, first token from target)
            outputs = model(images, captions[:, :-1])
            loss = criterion(
                outputs.reshape(-1, outputs.shape[2]), 
                captions[:, 1:].reshape(-1)
            )
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            if (batch_idx + 1) % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}")
        
        # Validation
        model.eval()
        val_loss = 0
        all_refs = []
        all_hyps = []
        
        with torch.no_grad():
            for val_images, val_captions in val_loader:
                val_images = val_images.to(device)
                val_captions = val_captions.to(device)
                
                # Forward pass
                outputs = model(val_images, val_captions[:, :-1])
                
                # Calculate loss
                batch_loss = criterion(
                    outputs.reshape(-1, outputs.shape[2]), 
                    val_captions[:, 1:].reshape(-1)
                )
                val_loss += batch_loss.item()
                
                # Generate captions for BLEU score
                for i in range(min(5, val_images.size(0))):  # Only evaluate first 5 images per batch to save time
                    img = val_images[i].unsqueeze(0)
                    
                    # Generate caption
                    generated_idx = generate_caption_indices(model, img, word_to_idx)
                    
                    # Convert to words
                    generated_caption = [idx_to_word[idx] for idx in generated_idx 
                                       if idx not in [word_to_idx['<pad>'], word_to_idx['<start>'], word_to_idx['<end>']]]
                    
                    # Get reference caption
                    reference_idx = val_captions[i].cpu().numpy()
                    reference_caption = [[idx_to_word[idx] for idx in reference_idx 
                                        if idx not in [word_to_idx['<pad>'], word_to_idx['<start>'], word_to_idx['<end>']]]]
                    
                    all_refs.append(reference_caption)
                    all_hyps.append(generated_caption)
        
        # Calculate metrics
        avg_train_loss = total_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        
        # Calculate BLEU score
        bleu1 = corpus_bleu(all_refs, all_hyps, weights=(1, 0, 0, 0))
        bleu4 = corpus_bleu(all_refs, all_hyps, weights=(0.25, 0.25, 0.25, 0.25))
        
        # Save metrics
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        bleu_scores.append((bleu1, bleu4))
        
        # Print epoch statistics
        time_elapsed = time.time() - start_time
        print(f"Epoch [{epoch+1}/{num_epochs}], Time: {time_elapsed:.2f}s, "
              f"Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, "
              f"BLEU-1: {bleu1:.4f}, BLEU-4: {bleu4:.4f}")
        
        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            checkpoint_path = os.path.join(checkpoint_dir, 'best_model.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
                'bleu1': bleu1,
                'bleu4': bleu4,
            }, checkpoint_path)
            print(f"Saved best model with validation loss: {avg_val_loss:.4f}")
        
        # Save regular checkpoint
        if (epoch + 1) % 5 == 0 or epoch == num_epochs - 1:
            checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
                'bleu1': bleu1,
                'bleu4': bleu4,
            }, checkpoint_path)
    
    # Plot training curves
    plt.figure(figsize=(15, 10))
    plt.subplot(2, 1, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')
    
    plt.subplot(2, 1, 2)
    bleu1_scores = [b[0] for b in bleu_scores]
    bleu4_scores = [b[1] for b in bleu_scores]
    plt.plot(bleu1_scores, label='BLEU-1')
    plt.plot(bleu4_scores, label='BLEU-4')
    plt.xlabel('Epochs')
    plt.ylabel('Score')
    plt.legend()
    plt.title('BLEU Scores')
    
    plt.tight_layout()
    plt.savefig(os.path.join(checkpoint_dir, 'transformer_training_curves.png'))
    plt.show()
    
    # Save metrics
    metrics = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'bleu1_scores': bleu1_scores,
        'bleu4_scores': bleu4_scores
    }
    np.save(os.path.join(checkpoint_dir, 'training_metrics.npy'), metrics)
    
    return model

## Generation and Evaluation Functions
These functions handle inference with the trained model. **generate_caption_indices** performs greedy autoregressive generation: it picks the most likely token at each step and stops when an end token is produced or `max_length` tokens have been generated. **generate_caption** converts these indices to human-readable text.

In [18]:
# Function to generate caption indices for an image
def generate_caption_indices(model, image, word_to_idx, max_length=100):
    model.eval()
    
    with torch.no_grad():
        # Encode image
        latent_vector = model.encoder(image)
        
        # Initialize caption with START token
        caption = [word_to_idx['<start>']]
        
        for i in range(max_length):
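            # Convert current caption to tensor on the same device as the image
            caption_tensor = torch.LongTensor(caption).unsqueeze(0).to(image.device)
            
            # Generate next word prediction, reusing the image encoding computed above
            # instead of re-encoding the image at every step
            output = model.decoder(latent_vector, caption_tensor)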
            
            # Get the predicted next word
            predicted_word_idx = output[0, -1].argmax().item()
            
            # Add predicted word to caption
            caption.append(predicted_word_idx)
            
            # Stop if END token is predicted
            if predicted_word_idx == word_to_idx['<end>']:
                break
        
        return caption

# Function to generate a readable caption for an image
def generate_caption(model, image, word_to_idx, idx_to_word, max_length=100):
    # Get caption indices
    caption_indices = generate_caption_indices(model, image, word_to_idx, max_length)
    
    # Convert indices to words, removing special tokens
    caption_words = [idx_to_word[idx] for idx in caption_indices 
                   if idx not in [word_to_idx['<pad>'], word_to_idx['<start>'], word_to_idx['<end>']]]
    
    return ' '.join(caption_words)
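
The decoding loop itself is independent of the model. The sketch below shows the same greedy strategy with a hypothetical stand-in scoring function (`toy_scores`) in place of the transformer, so the stopping behaviour can be seen in isolation.

```python
START, END = 1, 2

def greedy_decode(next_token_scores, max_length=100):
    """Repeatedly append the highest-scoring token until END or max_length."""
    sequence = [START]
    for _ in range(max_length):
        scores = next_token_scores(sequence)
        next_idx = max(range(len(scores)), key=scores.__getitem__)  # argmax
        sequence.append(next_idx)
        if next_idx == END:
            break
    return sequence

# Hypothetical stand-in for the model: emits tokens 4, 5, then END.
def toy_scores(sequence):
    script = {1: 4, 2: 5, 3: END}
    scores = [0.0] * 6
    scores[script[len(sequence)]] = 1.0
    return scores

decoded = greedy_decode(toy_scores)
print(decoded)  # [1, 4, 5, 2]
```

Greedy decoding is the simplest option. Beam search or sampling would slot into the same loop by changing how `next_idx` is chosen.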


### Evaluation Function
The evaluation function runs on the test set and produces:

1. BLEU-1 through BLEU-4 scores
2. METEOR scores, which also credit stem and synonym matches
3. ROUGE-1, ROUGE-2, and ROUGE-L F1 scores
4. Visual examples for the first 10 test images, each titled with its generated and reference captions

This allows for a thorough assessment of model performance beyond just BLEU scores.
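
As a reminder of what the BLEU numbers measure, the core quantity is the clipped (modified) n-gram precision of the generated caption against the reference. The sketch below computes only that precision for a single reference. The full BLEU score also combines several n-gram orders geometrically and applies a brevity penalty, which `corpus_bleu` handles. The example sentences are invented.

```python
from collections import Counter

def ngrams(tokens, n):
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]

def modified_precision(reference, hypothesis, n):
    """Clipped n-gram precision against a single reference."""
    hyp_counts = Counter(ngrams(hypothesis, n))
    ref_counts = Counter(ngrams(reference, n))
    # Each hypothesis n-gram is credited at most as often as it occurs in the reference
    clipped = sum(min(count, ref_counts[gram]) for gram, count in hyp_counts.items())
    total = sum(hyp_counts.values())
    return clipped / total if total else 0.0

reference = "no acute cardiopulmonary disease".split()
hypothesis = "no acute disease disease".split()

p1 = modified_precision(reference, hypothesis, 1)
p2 = modified_precision(reference, hypothesis, 2)
print(p1, p2)  # 0.75 and 1/3
```

The repeated word "disease" is credited only once because the reference contains it once. This clipping is what stops a model from scoring well by repeating common words.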

In [19]:
# Function to evaluate model on test set
def evaluate_model(model, test_loader, word_to_idx, idx_to_word, results_dir='results'):
    os.makedirs(results_dir, exist_ok=True)
    
    model.eval()
    
    all_refs = []
    all_hyps = []
    all_image_ids = []
    
    with torch.no_grad():
        for i, (images, captions) in enumerate(test_loader):
            images = images.to(device)
            
            for j in range(images.size(0)):
                img = images[j].unsqueeze(0)
                
                # Generate caption
                generated_idx = generate_caption_indices(model, img, word_to_idx)
                generated_caption = [idx_to_word[idx] for idx in generated_idx 
                                   if idx not in [word_to_idx['<pad>'], word_to_idx['<start>'], word_to_idx['<end>']]]
                
                # Get reference caption
                reference_idx = captions[j].cpu().numpy()
                reference_caption = [[idx_to_word[idx] for idx in reference_idx 
                                    if idx not in [word_to_idx['<pad>'], word_to_idx['<start>'], word_to_idx['<end>']]]]
                
                all_refs.append(reference_caption)
                all_hyps.append(generated_caption)
                all_image_ids.append(f"img_{i}_{j}")
                
                # Save some examples (first 10 images)
                if len(all_image_ids) <= 10:
                    plt.figure(figsize=(10, 6))
                    # Denormalize image
                    img_np = img.cpu().squeeze().permute(1, 2, 0).numpy()
                    img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
                    img_np = np.clip(img_np, 0, 1)
                    
                    plt.imshow(img_np)
                    plt.title(f"Generated: {' '.join(generated_caption)}\nReference: {' '.join(reference_caption[0])}")
                    plt.axis('off')
                    plt.savefig(os.path.join(results_dir, f'example_{len(all_image_ids)}.png'))
                    plt.close()
    
    # Calculate metrics
    bleu1 = corpus_bleu(all_refs, all_hyps, weights=(1, 0, 0, 0))
    bleu2 = corpus_bleu(all_refs, all_hyps, weights=(0.5, 0.5, 0, 0))
    bleu3 = corpus_bleu(all_refs, all_hyps, weights=(1/3, 1/3, 1/3, 0))
    bleu4 = corpus_bleu(all_refs, all_hyps, weights=(0.25, 0.25, 0.25, 0.25))
    
    # Calculate METEOR score per sample, then average
    # (NLTK's meteor_score needs the WordNet corpus: nltk.download('wordnet'))
    meteor_scores = [meteor_score(refs, hyp) for refs, hyp in zip(all_refs, all_hyps)]
    avg_meteor = np.mean(meteor_scores)
    
    # Calculate ROUGE score
    rouge = Rouge()
    rouge_scores = rouge.get_scores([' '.join(h) for h in all_hyps], [' '.join(r[0]) for r in all_refs], avg=True)
    
    # Print and save metrics
    print(f"BLEU-1: {bleu1:.4f}")
    print(f"BLEU-2: {bleu2:.4f}")
    print(f"BLEU-3: {bleu3:.4f}")
    print(f"BLEU-4: {bleu4:.4f}")
    print(f"METEOR: {avg_meteor:.4f}")
    print(f"ROUGE-1 F1: {rouge_scores['rouge-1']['f']:.4f}")
    print(f"ROUGE-2 F1: {rouge_scores['rouge-2']['f']:.4f}")
    print(f"ROUGE-L F1: {rouge_scores['rouge-l']['f']:.4f}")
    
    # Save metrics to file
    metrics = {
        'bleu1': bleu1,
        'bleu2': bleu2,
        'bleu3': bleu3,
        'bleu4': bleu4,
        'meteor': avg_meteor,
        'rouge1_f1': rouge_scores['rouge-1']['f'],
        'rouge2_f1': rouge_scores['rouge-2']['f'],
        'rougeL_f1': rouge_scores['rouge-l']['f']
    }
    
    with open(os.path.join(results_dir, 'metrics.txt'), 'w') as f:
        for k, v in metrics.items():
            f.write(f"{k}: {v:.4f}\n")
    
    # Save generated captions
    captions_df = pd.DataFrame({
        'image_id': all_image_ids,
        'reference_caption': [' '.join(r[0]) for r in all_refs],
        'generated_caption': [' '.join(h) for h in all_hyps]
    })
    captions_df.to_csv(os.path.join(results_dir, 'generated_captions.csv'), index=False)
    
    return metrics

## Main Function
The main function orchestrates the entire training pipeline:

1. Loads and preprocesses the data
2. Initializes all model components
3. Trains the autoencoder with GAN objectives
4. Uses the best encoder from the autoencoder training (freezing its weights)
5. Trains the transformer decoder with the encoded image features
6. Evaluates the final model on the test set

This two-stage training approach allows each component to be optimized for its specific task before being combined.


In [34]:
# Main function to run the entire pipeline
def main(csv_path, resume_training=False, resume_from=None):
    # Set up checkpoint and results directories
    checkpoint_dir_ae = 'checkpoints/autoencoder'
    checkpoint_dir_transformer = 'checkpoints/transformer'
    results_dir = 'results'
    
    for dir_path in [checkpoint_dir_ae, checkpoint_dir_transformer, results_dir]:
        os.makedirs(dir_path, exist_ok=True)
    
    # Load and preprocess data
    train_df, val_df, test_df, word_to_idx, idx_to_word, vocab_size, max_length = load_data(csv_path)
    
    # Update global MAX_LENGTH with the value computed from the data
    global MAX_LENGTH
    MAX_LENGTH = max_length

    # Base directory passed to the dataset: two levels above the notebook's location
    base_path = "../../"
    
    # Create datasets
    train_dataset = ChestXrayDataset(train_df, transform=transform, max_length=MAX_LENGTH, 
                                    word_to_idx=word_to_idx, base_path=base_path)
    val_dataset = ChestXrayDataset(val_df, transform=transform, max_length=MAX_LENGTH, 
                                  word_to_idx=word_to_idx, base_path=base_path)
    test_dataset = ChestXrayDataset(test_df, transform=transform, max_length=MAX_LENGTH, 
                                   word_to_idx=word_to_idx, base_path=base_path)
    
    # Create dataloaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    
    # Initialize models
    encoder = Encoder(LATENT_DIM).to(device)
    decoder = Decoder(LATENT_DIM).to(device)
    discriminator = Discriminator().to(device)
    
    transformer_decoder = DecoderTransformer(
        vocab_size=vocab_size,
        embed_dim=EMBEDDING_DIM,
        hidden_dim=HIDDEN_DIM,
        nhead=NHEAD,
        num_layers=NUM_DECODER_LAYERS,
        dropout=DROPOUT
    ).to(device)
    
    full_model = ImageCaptioningModel(encoder, transformer_decoder).to(device)
    
    # Define loss and optimizer for transformer
    criterion = nn.CrossEntropyLoss(ignore_index=word_to_idx['<pad>'])
    optimizer = torch.optim.Adam(full_model.parameters(), lr=LEARNING_RATE)
    
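    # Resume training if specified. Note: this only restores weights (and, for the
    # transformer, the optimizer state); as written, both training stages below are
    # still called with their full epoch counts.
    if resume_training and resume_from is not None:
        print(f"Resuming training from checkpoint: {resume_from}")
        # map_location lets a checkpoint saved on GPU load on a CPU-only machine
        checkpoint = torch.load(resume_from, map_location=device)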
        
        if 'encoder_state_dict' in checkpoint:
            # Resuming autoencoder training
            encoder.load_state_dict(checkpoint['encoder_state_dict'])
            decoder.load_state_dict(checkpoint['decoder_state_dict'])
            discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
            print(f"Loaded autoencoder checkpoint from epoch {checkpoint['epoch'] + 1}")
        elif 'model_state_dict' in checkpoint:
            # Resuming transformer training
            full_model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            print(f"Loaded transformer checkpoint from epoch {checkpoint['epoch'] + 1}")
    
    # Step 1: Train autoencoder with GAN
    print("\n=== Training Autoencoder-GAN ===")
    encoder, decoder, discriminator = train_autoencoder_gan(
        encoder, decoder, discriminator, train_loader, val_loader, device, 
        num_epochs=NUM_EPOCHS_AE, checkpoint_dir=checkpoint_dir_ae
    )
    
    # Step 2: Train transformer with frozen encoder
    print("\n=== Training Transformer with encoded features ===")
    # Load best autoencoder checkpoint
    best_ae_checkpoint = torch.load(os.path.join(checkpoint_dir_ae, 'best_model.pth'), map_location=device)
    encoder.load_state_dict(best_ae_checkpoint['encoder_state_dict'])
    
    # Update encoder in full model
    full_model.encoder = encoder
    
    # Freeze encoder weights
    for param in full_model.encoder.parameters():
        param.requires_grad = False
    
    # Train transformer
    full_model = train_transformer(
        full_model, train_loader, val_loader, word_to_idx, idx_to_word,
        criterion, optimizer, num_epochs=NUM_EPOCHS_TRANSFORMER, 
        checkpoint_dir=checkpoint_dir_transformer
    )
    
    # Step 3: Evaluate on test set
    print("\n=== Evaluating on Test Set ===")
    # Load best transformer checkpoint
    best_transformer_checkpoint = torch.load(os.path.join(checkpoint_dir_transformer, 'best_model.pth'), map_location=device)
    full_model.load_state_dict(best_transformer_checkpoint['model_state_dict'])
    
    metrics = evaluate_model(full_model, test_loader, word_to_idx, idx_to_word, results_dir=results_dir)
    
    # Save final model
    torch.save(full_model.state_dict(), os.path.join(results_dir, 'final_model.pth'))
    
    print("\n=== Training and Evaluation Complete ===")
    print(f"Final BLEU-1: {metrics['bleu1']:.4f}")
    print(f"Final BLEU-4: {metrics['bleu4']:.4f}")
    print(f"Final METEOR: {metrics['meteor']:.4f}")
    print(f"Final ROUGE-L F1: {metrics['rougeL_f1']:.4f}")

In [24]:
!pwd

/scratch/joshi.tanm/CSYE-7374/final-project/experiments/stable diffusion


In [None]:
if __name__ == "__main__":
    csv_path = "final_dataset.csv"  # Replace with your actual CSV path
    main(csv_path)

Missing values in DataFrame:
Unnamed: 0        0
final_img_path    0
captions          0
dtype: int64
Max caption length: 125
Training samples: 4519, Validation samples: 969, Test samples: 969

=== Training Autoencoder-GAN ===
Epoch [1/10], Batch [100/142], Recon Loss: 0.5257, Gen Loss: 6.2657, Disc Loss: 0.0027
Epoch [1/10], Time: 983.48s, Train Recon Loss: 0.6925, Train Gen Loss: 4.6510, Train Disc Loss: 0.1141, Val Recon Loss: 0.4859
Saved best model with validation loss: 0.4859
Epoch [2/10], Batch [100/142], Recon Loss: 0.4755, Gen Loss: 7.9654, Disc Loss: 0.0002
Epoch [2/10], Time: 678.72s, Train Recon Loss: 0.4632, Train Gen Loss: 7.3685, Train Disc Loss: 0.0006, Val Recon Loss: 0.4485
Saved best model with validation loss: 0.4485


# Preliminary Results
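The notebook output covers only the first two autoencoder-GAN epochs. Over those epochs the training reconstruction loss fell from 0.6925 to 0.4632 and the validation reconstruction loss from 0.4859 to 0.4485, which suggests the model was learning to compress and reconstruct the X-ray images. Over the same epochs the discriminator loss dropped to 0.0006 while the generator loss rose from 4.6510 to 7.3685, so the discriminator may be overpowering the generator and the adversarial balance is worth monitoring in longer runs. No transformer training or caption metrics appear in the saved output.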
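
The two epoch summaries in the training output can be compared directly. The snippet below is a small sketch that parses those two lines (copied verbatim from the log above) and computes the relative drop in training and validation reconstruction loss; the regular expression and the `relative_drop` helper are illustrative and not part of the notebook's code.

```python
import re

# Epoch summary lines copied from the training output above
log = """\
Epoch [1/10], Time: 983.48s, Train Recon Loss: 0.6925, Train Gen Loss: 4.6510, Train Disc Loss: 0.1141, Val Recon Loss: 0.4859
Epoch [2/10], Time: 678.72s, Train Recon Loss: 0.4632, Train Gen Loss: 7.3685, Train Disc Loss: 0.0006, Val Recon Loss: 0.4485
"""

# Capture (train recon loss, val recon loss) from each summary line
pattern = re.compile(r"Train Recon Loss: ([\d.]+).*?Val Recon Loss: ([\d.]+)")
rows = [tuple(map(float, m.groups())) for m in pattern.finditer(log)]


def relative_drop(first, last):
    """Fractional decrease from the first value to the last."""
    return (first - last) / first


train_drop = relative_drop(rows[0][0], rows[-1][0])
val_drop = relative_drop(rows[0][1], rows[-1][1])
print(f"train recon: {train_drop:.1%}, val recon: {val_drop:.1%}")
# prints "train recon: 33.1%, val recon: 7.7%"
```

The validation loss improves far less than the training loss between these two epochs, so two epochs are too few to judge where reconstruction quality will plateau.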

## Challenges and Future Work
This approach is computationally demanding (the two logged autoencoder epochs took roughly 16 and 11 minutes) due to:

1. The GAN-based autoencoder training
2. The subsequent transformer training on the encoded features

Future optimizations could include:

- Using a more efficient encoder architecture
- Implementing progressive training strategies
- Exploring different latent space dimensions
- Incorporating more medical domain knowledge

Despite these challenges, the stable diffusion-inspired approach offers a promising direction for medical image captioning by potentially capturing more nuanced visual features in the latent space compared to standard CNN encoders.

# Conclusion
This experiment combines ideas from stable diffusion models with transformer-based text generation in a two-stage pipeline for medical image-to-text report generation. While computationally intensive, the approach has the potential to produce better latent representations of medical images and, in turn, more accurate and detailed reports. Confirming this requires completing both training stages and evaluating on the test set.