# Chest X-Ray (CXR) Report Generation

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from torchvision.models import resnet34, ResNet34_Weights
import pickle  # For saving the vocabulary

In [None]:
# -------------------------------
# Vocabulary Helper Class
# -------------------------------
class Vocabulary:
    """Word-level vocabulary mapping report tokens to integer ids and back.

    Ids 0-3 are reserved for the special tokens ``<pad>``, ``<start>``,
    ``<end>`` and ``<unk>``; new words receive the next free id in first-seen
    order. Tokenization is plain whitespace splitting (``str.split``).
    """

    def __init__(self):
        self.word2idx = {"<pad>": 0, "<start>": 1, "<end>": 2, "<unk>": 3}
        self.idx2word = {0: "<pad>", 1: "<start>", 2: "<end>", 3: "<unk>"}
        # Next free id; also equals the current vocabulary size.
        self.count = 4

    def add_sentence(self, sentence):
        """Register every unseen whitespace-separated word of ``sentence``.

        Non-string values (e.g. ``NaN`` produced by pandas for an empty CSV
        cell) are ignored instead of raising ``AttributeError``.

        Args:
            sentence (str): Report text.
        """
        if not isinstance(sentence, str):
            return
        for word in sentence.split():
            if word not in self.word2idx:
                self.word2idx[word] = self.count
                self.idx2word[self.count] = word
                self.count += 1

    def numericalize(self, sentence, max_len=50):
        """Convert ``sentence`` into a list of ``max_len`` token ids (``max_len >= 1``).

        The sequence is ``<start> w1 ... wn <end>`` followed by ``<pad>`` ids.
        Unknown words map to ``<unk>``. When the sequence is longer than
        ``max_len`` it is truncated, but the last kept id is forced to
        ``<end>`` (for ``max_len > 1``) so the decoder still sees where a
        report stops. Non-string input is treated as an empty report.

        Args:
            sentence (str): Report text.
            max_len (int): Length of the returned id list.

        Returns:
            list[int]: Token ids.
        """
        if not isinstance(sentence, str):
            sentence = ""
        unk_idx = self.word2idx["<unk>"]
        end_idx = self.word2idx["<end>"]
        tokens = ["<start>"] + sentence.split() + ["<end>"]
        token_ids = [self.word2idx.get(token, unk_idx) for token in tokens]
        if len(token_ids) < max_len:
            token_ids += [self.word2idx["<pad>"]] * (max_len - len(token_ids))
        elif len(token_ids) > max_len:
            if max_len > 1:
                # Keep room for <end> instead of silently cutting it off.
                token_ids = token_ids[:max_len - 1] + [end_idx]
            else:
                token_ids = token_ids[:max_len]
        return token_ids

In [None]:
# -------------------------------
# Dataset Class for CXR Report Generation
# -------------------------------
class CXRReportDataset(Dataset):
    """Map-style dataset yielding ``(image, report_ids)`` pairs from a CSV file.

    Each CSV row supplies an image path and a report. Image paths that begin
    with ``csv_prefix`` (the location the CSV was originally written against)
    are re-rooted under ``image_root``; any other path is used unchanged.
    """

    # Prefix baked into the curated dataset's CSV paths.
    DEFAULT_CSV_PREFIX = "../input/curated-cxr-report-generation-dataset/mimic_dset/re_512_3ch/"

    def __init__(self, csv_file, vocab, image_root, transform=None, max_len=50,
                 report_column='text', image_column='path', csv_prefix=None):
        """
        Args:
            csv_file (str): Path to CSV file containing report and image file path columns
                (anything ``pandas.read_csv`` accepts).
            vocab (Vocabulary): Vocabulary object to process reports.
            image_root (str): Root directory for image files.
            transform (callable, optional): Transformations for the images.
            max_len (int): Maximum token length for reports.
            report_column (str): The name of the column containing the report text.
            image_column (str): The name of the column containing the image file path.
            csv_prefix (str, optional): Leading path prefix to strip from CSV image
                paths before joining them with ``image_root``. Defaults to
                ``DEFAULT_CSV_PREFIX``.
        """
        self.data = pd.read_csv(csv_file)
        self.vocab = vocab
        self.image_root = image_root
        self.transform = transform
        self.max_len = max_len
        self.report_column = report_column
        self.image_column = image_column
        self.csv_prefix = self.DEFAULT_CSV_PREFIX if csv_prefix is None else csv_prefix

    def __len__(self):
        return len(self.data)

    def _resolve_image_path(self, image_path_raw):
        """Map a raw CSV path value to the path to open on disk."""
        prefix = self.csv_prefix
        if isinstance(image_path_raw, str) and image_path_raw.startswith(prefix):
            # Slice off only the leading prefix; str.replace would also delete
            # any later occurrence of the same substring inside the path.
            return os.path.join(self.image_root, image_path_raw[len(prefix):])
        return image_path_raw

    def __getitem__(self, idx):
        """Return the (optionally transformed) RGB image and its report ids.

        Returns:
            tuple: ``(image, report_ids)`` where ``report_ids`` is a 1-D
            ``torch.long`` tensor produced by ``vocab.numericalize``.
        """
        row = self.data.iloc[idx]
        image_path = self._resolve_image_path(row[self.image_column])

        # The context manager closes the file handle; convert() returns an
        # independent image, so it remains valid after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        report = row[self.report_column]
        report_ids = self.vocab.numericalize(report, self.max_len)
        return image, torch.tensor(report_ids, dtype=torch.long)

In [None]:
# -------------------------------
# Encoder-Decoder Model for Report Generation
# -------------------------------
class CXRReportGenerator(nn.Module):
    """CNN encoder / LSTM decoder that produces report-token logits.

    A ResNet-34 backbone with its final fully connected layer removed encodes
    each image to a 512-d vector. That vector is projected to ``hidden_size``
    and used as the LSTM's initial hidden state (for every layer). The
    decoder consumes the given caption tokens (teacher forcing).
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1, pretrained=True):
        """
        Args:
            embed_size (int): Token embedding dimension.
            hidden_size (int): LSTM hidden/cell state dimension.
            vocab_size (int): Number of output token classes.
            num_layers (int): Number of stacked LSTM layers.
            pretrained (bool): Initialise the ResNet-34 encoder with
                ``ResNet34_Weights.DEFAULT``. Pass ``False`` when a full
                checkpoint is loaded afterwards, to skip the weight download.
        """
        super(CXRReportGenerator, self).__init__()
        # Encoder: ResNet-34 without its classification layer.
        resnet = resnet34(weights=ResNet34_Weights.DEFAULT if pretrained else None)
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])  # Output: (batch, 512, 1, 1)
        self.fc = nn.Linear(512, hidden_size)  # Map image features to hidden size

        # Decoder
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.fc_out = nn.Linear(hidden_size, vocab_size)
        self.hidden_size = hidden_size

    def forward(self, images, captions):
        """Compute next-token logits for each caption position.

        Args:
            images (Tensor): RGB image batch for the ResNet encoder,
                presumably (batch, 3, H, W).
            captions (LongTensor): (batch, seq_len) input token ids.

        Returns:
            Tensor: (batch, seq_len, vocab_size) logits.
        """
        features = self.encoder(images)                 # (batch, 512, 1, 1)
        features = features.view(features.size(0), -1)  # (batch, 512)
        features = self.fc(features)                    # (batch, hidden_size)

        embeddings = self.embed(captions)               # (batch, seq_len, embed_size)

        # nn.LSTM expects h0/c0 of shape (num_layers, batch, hidden_size) even
        # with batch_first=True; repeat the image features for every layer so
        # num_layers > 1 no longer fails with a shape mismatch.
        h0 = features.unsqueeze(0).repeat(self.lstm.num_layers, 1, 1)
        c0 = torch.zeros_like(h0)

        outputs, _ = self.lstm(embeddings, (h0, c0))    # (batch, seq_len, hidden_size)
        return self.fc_out(outputs)                     # (batch, seq_len, vocab_size)

In [None]:
# -------------------------------
# Training Function with Accuracy Calculation
# -------------------------------
def train_model(model, dataloader, criterion, optimizer, device, epochs=10, pad_idx=0):
    """Train ``model`` with teacher forcing and record per-epoch loss and accuracy.

    For every batch the model receives ``captions[:, :-1]`` and is scored
    against ``captions[:, 1:]`` (next-token prediction). Token accuracy is the
    fraction of non-padding target positions whose argmax prediction matches.

    Args:
        model (nn.Module): Called as ``model(images, input_tokens)``.
        dataloader (iterable): Yields ``(images, captions)`` batches.
        criterion (callable): Loss over flattened logits and targets.
        optimizer (torch.optim.Optimizer): Optimizer for ``model``.
        device (torch.device): Device the batches are moved to.
        epochs (int): Number of passes over ``dataloader``.
        pad_idx (int): Padding id excluded from the accuracy; should match the
            criterion's ``ignore_index``.

    Returns:
        tuple[list[float], list[float]]: Mean batch loss and token accuracy
        (in [0, 1]) for each epoch. An epoch with no batches reports 0 for both
        instead of raising ``ZeroDivisionError``.
    """
    model.train()
    loss_history = []
    accuracy_history = []
    for epoch in range(epochs):
        running_loss = 0.0
        running_correct = 0
        running_tokens = 0
        num_batches = 0
        for images, captions in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            images = images.to(device)
            captions = captions.to(device)
            inputs = captions[:, :-1]   # all but the last token
            target = captions[:, 1:]    # the same sequence shifted left by one

            optimizer.zero_grad()
            outputs = model(images, inputs)
            loss = criterion(outputs.reshape(-1, outputs.size(2)), target.reshape(-1))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

            # Token-level accuracy bookkeeping needs no autograd graph.
            with torch.no_grad():
                preds = outputs.argmax(dim=2)
                mask = target != pad_idx
                running_correct += ((preds == target) & mask).sum().item()
                running_tokens += mask.sum().item()

        epoch_loss = running_loss / num_batches if num_batches > 0 else 0.0
        epoch_acc = running_correct / running_tokens if running_tokens > 0 else 0
        loss_history.append(epoch_loss)
        accuracy_history.append(epoch_acc)
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc*100:.2f}%")
    return loss_history, accuracy_history

In [None]:
# -------------------------------
# Visualization Function for Loss and Accuracy
# -------------------------------
def plot_metrics(loss_history, accuracy_history):
    """Show side-by-side line plots of per-epoch training loss and accuracy.

    Args:
        loss_history (list[float]): Loss value for each epoch.
        accuracy_history (list[float]): Accuracy value for each epoch; must
            have the same length as ``loss_history``.
    """
    x_axis = range(1, len(loss_history) + 1)
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # (axes, series, legend/y-label, title, extra line style)
    panels = (
        (loss_ax, loss_history, 'Loss', 'Training Loss', {}),
        (acc_ax, accuracy_history, 'Accuracy', 'Training Accuracy', {'color': 'green'}),
    )
    for ax, series, name, heading, line_style in panels:
        ax.plot(x_axis, series, marker='o', label=name, **line_style)
        ax.set_title(heading)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(name)
        ax.grid(True)
        ax.legend()

    fig.tight_layout()
    plt.show()

In [None]:
# -------------------------------
# Main Function
# -------------------------------
def main():
    """Train the CXR report generator end to end and write its artifacts.

    Steps:
    1. Build the vocabulary from the training CSV and pickle it to ``vocab.pkl``.
    2. Train ``CXRReportGenerator``, wrapped in ``nn.DataParallel`` when more
       than one GPU is visible.
    3. Save the weights to ``cxr_report_generator.pth``.
    4. Plot the loss and accuracy curves.
    """
    # Configurations and file paths
    csv_train = '/kaggle/input/curated-cxr-report-generation-dataset/NLP_aug_datasets/df_train_aug.csv'
    image_root = '/kaggle/input/curated-cxr-report-generation-dataset/mimic_dset/re_512_3ch'
    num_epochs = 30
    batch_size = 32
    learning_rate = 1e-3
    max_len = 50                         # Maximum length of report (in tokens)
    embed_size = 256
    hidden_size = 256
    num_layers = 1

    # Image transformations for CXR images (ImageNet normalisation to match
    # the pretrained ResNet encoder).
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # Load the training CSV and print available columns for debugging
    df_train = pd.read_csv(csv_train)
    print("Columns in training CSV:", df_train.columns.tolist())

    # In this CSV, the report text is in the "text" column and the image path in the "path" column.
    report_column = 'text'
    image_column = 'path'

    # Build vocabulary from the training CSV using the "text" column
    vocab = Vocabulary()
    for report in df_train[report_column]:
        vocab.add_sentence(report)
    vocab_size = vocab.count
    print(f"Vocabulary size: {vocab_size}")

    # Save the vocabulary to a pickle file for later use during inference
    with open('vocab.pkl', 'wb') as f:
        pickle.dump(vocab, f)
    print("Vocabulary saved to vocab.pkl")

    # Create dataset and dataloader
    train_dataset = CXRReportDataset(csv_train, vocab, image_root, transform, max_len,
                                     report_column=report_column, image_column=image_column)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    # Device configuration: use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Initialize the model
    model = CXRReportGenerator(embed_size, hidden_size, vocab_size, num_layers).to(device)

    # Use DataParallel if multiple GPUs are available (e.g., dual T4 on Kaggle)
    if torch.cuda.device_count() > 1:
        print(f"Multiple GPUs detected: {torch.cuda.device_count()}. Using DataParallel.")
        model = nn.DataParallel(model)

    # Define loss and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx["<pad>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Train the model and obtain loss and accuracy histories
    loss_history, accuracy_history = train_model(model, train_loader, criterion, optimizer, device, epochs=num_epochs)

    # Save before plotting so the trained weights reach disk even if the
    # plotting step fails. Unwrap DataParallel so checkpoint keys carry no
    # "module." prefix and load directly into a plain CXRReportGenerator.
    model_to_save = model.module if isinstance(model, nn.DataParallel) else model
    torch.save(model_to_save.state_dict(), 'cxr_report_generator.pth')
    print("Training complete. Model saved as cxr_report_generator.pth")

    # Visualize the training loss and accuracy
    plot_metrics(loss_history, accuracy_history)

In [None]:
# Script entry point: run training only when executed directly, not when imported.
if __name__ == '__main__':
    main()