In [None]:
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm import tqdm
from transformers import T5Tokenizer, T5ForConditionalGeneration, ViTModel
from transformers.modeling_outputs import BaseModelOutput

In [None]:
# Data loading and preparation.
# Projection metadata and free-text reports live in two CSVs that share a "uid" key.
df_projections = pd.read_csv("content/Projection.csv")
df_reports = pd.read_csv("content/report.csv")

# Inner-join (pandas default) so each projection row carries the report columns of its uid.
df = pd.merge(df_projections, df_reports, on="uid")

In [None]:
def clean_caption(row):
    """Build one caption string from a report row's findings and impression.

    Each of ``row["findings"]`` and ``row["impression"]`` is used only when it is
    non-missing (per ``pd.notna``) and non-blank after stripping. The remaining
    parts are joined with a single space, findings first.

    Args:
        row: mapping-like row (e.g. a DataFrame row) with "findings" and
            "impression" entries.

    Returns:
        The combined caption, or None when neither field has any text. None
        (rather than "") lets the caller drop the row with ``dropna``.
    """
    parts = []
    for key in ("findings", "impression"):
        value = row[key]
        if pd.notna(value):
            text = str(value).strip()
            # Whitespace-only fields would otherwise yield "" or a leading-space caption.
            if text:
                parts.append(text)
    return " ".join(parts) if parts else None

# Attach the cleaned caption to every row, then discard rows with no report text at all.
df = df.assign(caption=df.apply(clean_caption, axis=1)).dropna(subset=["caption"])


In [None]:
# Image Encoder: pretrained ViT backbone exposing its final token features.
class ImageEncoder(nn.Module):
    """Wrap ``google/vit-base-patch16-224-in21k`` and return its last hidden states.

    Every ViT parameter is explicitly marked trainable, so the backbone is
    fine-tuned together with the rest of the captioning model.
    """

    def __init__(self):
        super().__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
        # Mark all backbone parameters as trainable (fine-tune the whole ViT).
        self.vit.requires_grad_(True)

    def forward(self, images):
        """Encode ``images`` (passed as ``pixel_values``); returns shape (B, N, D)."""
        vit_output = self.vit(pixel_values=images)
        return vit_output.last_hidden_state

In [None]:
# QFormer building block: pre-norm self-attention, cross-attention and MLP, each wrapped in a residual.
class QFormerBlock(nn.Module):
    """One QFormer layer that refines learned query tokens against image tokens.

    Args:
        dim: width of the query tokens. ``image_features`` must have the same last
            dimension, because both attention layers use the default kdim/vdim.
        num_heads: number of attention heads (``nn.MultiheadAttention`` requires
            it to divide ``dim``).
    """

    def __init__(self, dim, num_heads):
        super(QFormerBlock, self).__init__()
        self.self_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, queries, image_features):
        """Update ``queries`` (B, Q, dim) by attending to ``image_features`` (B, N, dim).

        Every sub-layer adds its output back onto its input (pre-norm residual).
        Without the residuals, each block would discard the incoming queries, and a
        deep stack of blocks becomes hard to optimise.

        Returns:
            Tensor with the same shape as ``queries``.
        """
        # Self-attention among the queries.
        normed = self.norm1(queries)
        attn_out, _ = self.self_attn(normed, normed, normed, need_weights=False)
        queries = queries + attn_out

        # Cross-attention: queries read from the image tokens.
        normed = self.norm2(queries)
        attn_out, _ = self.cross_attn(normed, image_features, image_features, need_weights=False)
        queries = queries + attn_out

        # Position-wise feed-forward.
        queries = queries + self.feed_forward(self.norm3(queries))
        return queries

In [None]:
class QFormer(nn.Module):
    """Stack of ``QFormerBlock`` layers driven by a learned set of query tokens.

    Args:
        num_queries: number of learned query tokens (output sequence length).
        dim: token width shared by the queries and the image features.
        num_blocks: how many ``QFormerBlock`` layers to apply in sequence.
        num_heads: attention heads per block.
    """

    def __init__(self, num_queries=32, dim=768, num_blocks=6, num_heads=12):
        super().__init__()
        # Queries are created before the blocks so the RNG draw order matches initialisation.
        self.learned_queries = nn.Parameter(torch.randn(1, num_queries, dim))
        layers = [QFormerBlock(dim, num_heads) for _ in range(num_blocks)]
        self.blocks = nn.ModuleList(layers)

    def forward(self, image_features):
        """Return query features of shape (B, num_queries, dim) for ``image_features`` (B, N, dim)."""
        batch = image_features.shape[0]
        # Share the same learned queries across the batch without copying.
        hidden = self.learned_queries.expand(batch, -1, -1)
        for layer in self.blocks:
            hidden = layer(hidden, image_features)
        return hidden

In [None]:
# Vision-to-language captioning model: ViT encoder -> QFormer -> projection -> T5 decoder.
class MedicalCaptioningModel(nn.Module):
    """Caption images by feeding QFormer query embeddings to a T5 decoder.

    Images are encoded by ``ImageEncoder`` and summarised into a fixed number of
    learned queries by ``QFormer``. The queries are projected to T5's ``d_model``
    and handed to T5 as precomputed ``encoder_outputs``, so T5's own text
    encoder is never run.

    Args:
        t5_model_name: pretrained T5 checkpoint passed to ``from_pretrained``.
    """

    def __init__(self, t5_model_name="t5-base"):
        super().__init__()
        # Image encoder; its 768-wide output is assumed below (TODO confirm if the ViT checkpoint changes).
        self.image_encoder = ImageEncoder()

        # QFormer bridges vision features to a fixed-length query sequence.
        self.qformer = QFormer(num_queries=32, dim=768)

        # T5 model (only its decoder / LM head are exercised).
        self.t5 = T5ForConditionalGeneration.from_pretrained(t5_model_name)

        # Projection from the 768-wide QFormer output to the T5 hidden size.
        self.vision_proj = nn.Linear(768, self.t5.config.d_model)

        # The T5 text encoder is bypassed, so its parameters are frozen.
        # NOTE(review): in HF T5 the encoder's token embedding is the `shared` embedding that
        # the decoder (and, when tied, the LM head) also uses, so this likely freezes those
        # weights as well -- confirm that is intended.
        for param in self.t5.encoder.parameters():
            param.requires_grad = False

    def _encode_images(self, pixel_values):
        """Run ViT -> QFormer -> projection and wrap the result as T5 ``encoder_outputs``.

        The wrapped ``last_hidden_state`` is (batch, num_queries, d_model).
        """
        image_features = self.image_encoder(pixel_values)
        qformer_output = self.qformer(image_features)
        encoder_hidden_states = self.vision_proj(qformer_output)
        return BaseModelOutput(
            last_hidden_state=encoder_hidden_states,
            hidden_states=None,
            attentions=None,
        )

    def forward(self, pixel_values, labels=None, decoder_attention_mask=None):
        """Run the T5 decoder on the visual queries (computing the LM loss when ``labels`` is given).

        Args:
            pixel_values: image batch passed to the ViT encoder.
            labels: target token ids (B, T). Positions set to -100 are replaced by the
                pad id when building decoder inputs; -100 is also the value HF's
                loss ignores.
            decoder_attention_mask: optional mask over decoder positions.

        Returns:
            The T5 output object (``return_dict=True``); ``.loss`` is set when labels are given.
        """
        encoder_outputs = self._encode_images(pixel_values)

        if labels is not None:
            # Teacher forcing: the decoder sees the labels shifted one step right.
            decoder_input_ids = self.prepare_decoder_input_ids_from_labels(labels)
        else:
            # No targets: a single decoder-start token per sample.
            decoder_input_ids = torch.full(
                (pixel_values.size(0), 1),
                self.t5.config.decoder_start_token_id,
                dtype=torch.long,
                device=encoder_outputs.last_hidden_state.device,
            )

        # Every sample has the same fixed number of visual queries, so no encoder mask is needed.
        return self.t5(
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            labels=labels,
            return_dict=True,
        )

    def prepare_decoder_input_ids_from_labels(self, labels):
        """Shift ``labels`` right by one to build decoder input ids.

        -100 entries become ``pad_token_id``, and ``decoder_start_token_id`` is
        prepended. ``labels`` itself is not modified.
        """
        pad_id = self.t5.config.pad_token_id
        start_id = self.t5.config.decoder_start_token_id
        # -100 is only meaningful to the loss; the decoder needs real token ids.
        shifted = labels.masked_fill(labels == -100, pad_id)
        start = torch.full(
            (labels.shape[0], 1), start_id, dtype=labels.dtype, device=labels.device
        )
        return torch.cat([start, shifted[:, :-1]], dim=-1)

    def generate(self, pixel_values, max_length=128, num_beams=4, **generate_kwargs):
        """Generate caption token ids for a batch of images.

        Args:
            pixel_values: image batch passed to the ViT encoder.
            max_length: maximum generated sequence length.
            num_beams: beam width.
            **generate_kwargs: extra options forwarded to ``t5.generate``. They may
                also override the default ``early_stopping=True``.

        Returns:
            Tensor of generated token ids, one row per (returned) sequence.
        """
        encoder_outputs = self._encode_images(pixel_values)
        options = {"max_length": max_length, "num_beams": num_beams, "early_stopping": True}
        options.update(generate_kwargs)
        return self.t5.generate(encoder_outputs=encoder_outputs, **options)


In [None]:
# Dataset that pairs each chest X-ray image with its tokenized report for T5.
class ChestXrayDataset(Dataset):
    """Map-style dataset yielding an image tensor and T5 target ids per row.

    Args:
        dataframe: rows with at least ``filename`` and ``caption`` columns.
        image_root_dir: directory that ``filename`` is resolved against.
        tokenizer: tokenizer called with ``padding="max_length"`` and ``return_tensors="pt"``.
        transform: image transform; defaults to resize to ``image_size`` + ``ToTensor``.
        image_size: side length used only by the default transform.
        max_length: captions are padded/truncated to this many tokens.
    """

    def __init__(self, dataframe, image_root_dir, tokenizer, transform=None, image_size=224, max_length=128):
        self.df = dataframe
        self.image_root_dir = image_root_dir
        self.tokenizer = tokenizer
        self.max_length = max_length
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((image_size, image_size)),
                transforms.ToTensor(),
            ])
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def _encode_caption(self, caption):
        """Tokenize ``caption`` and return 1-D ``(labels, attention_mask)`` tensors.

        Padding positions (attention_mask == 0) are set to -100 in ``labels``, so
        the seq2seq loss ignores them. Otherwise the model is trained to predict
        padding.
        """
        encoding = self.tokenizer(
            caption,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        attention_mask = encoding.attention_mask.squeeze(0)
        labels = encoding.input_ids.squeeze(0).clone()
        labels[attention_mask == 0] = -100
        return labels, attention_mask

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = os.path.join(self.image_root_dir, row['filename'])
        # Context manager releases the file handle once the RGB copy exists.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        labels, attention_mask = self._encode_caption(row['caption'])
        return {
            "pixel_values": image,
            "labels": labels,
            "attention_mask": attention_mask,
        }


In [None]:
# Main training and evaluation code: configuration.
# Seed torch (all devices) so model initialisation and DataLoader shuffling are repeatable.
SEED = 42
torch.manual_seed(SEED)

# Image folder; set the IMAGE_ROOT environment variable to run on another machine.
image_root = os.environ.get(
    "IMAGE_ROOT",
    "C:/Users/Administrator//Desktop/125150051/Data/images/image_normalized/",
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Split dataset (80% train / 20% test, fixed seed)
train_df, test_df = train_test_split(df, test_size=0.2, random_state=SEED)

# Tokenizer and transform
tokenizer = T5Tokenizer.from_pretrained("t5-base")
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [None]:
# Dataset and Dataloader: both splits share the tokenizer and the deterministic image transform.
train_dataset = ChestXrayDataset(
    train_df,
    image_root_dir=image_root,
    tokenizer=tokenizer,
    transform=image_transform,
)
test_dataset = ChestXrayDataset(
    test_df,
    image_root_dir=image_root,
    tokenizer=tokenizer,
    transform=image_transform,
)

# Shuffle training batches; evaluate one sample at a time in a fixed order.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

In [None]:
# Initialize model and optimizer
model = MedicalCaptioningModel(t5_model_name="t5-base").to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Training loop: minimise the loss T5 returns for the (shifted) report tokens.
num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for batch in progress_bar:
        pixel_values, labels, attention_mask = (
            batch[key].to(device) for key in ("pixel_values", "labels", "attention_mask")
        )

        outputs = model(pixel_values=pixel_values, labels=labels, decoder_attention_mask=attention_mask)
        loss = outputs.loss

        # Backprop, update, then clear gradients for the next batch.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        batch_loss = loss.item()
        total_loss += batch_loss
        progress_bar.set_postfix({"loss": batch_loss})

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs} - Training Loss: {avg_loss:.4f}")

We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Epoch 1 - Training Loss: 2.2599
Epoch 2 - Training Loss: 1.2731
Epoch 3 - Training Loss: 1.0077
Epoch 4 - Training Loss: 0.8016
Epoch 5 - Training Loss: 0.6172
Epoch 6 - Training Loss: 0.4661
Epoch 7 - Training Loss: 0.3597
Epoch 8 - Training Loss: 0.2891
Epoch 9 - Training Loss: 0.2494
Epoch 10 - Training Loss: 0.2264
Epoch 11 - Training Loss: 0.2119
Epoch 12 - Training Loss: 0.2006
Epoch 13 - Training Loss: 0.1915
Epoch 14 - Training Loss: 0.1869
Epoch 15 - Training Loss: 0.1786
Epoch 16 - Training Loss: 0.1752
Epoch 17 - Training Loss: 0.1713
Epoch 18 - Training Loss: 0.1636
Epoch 19 - Training Loss: 0.1597
Epoch 20 - Training Loss: 0.1586


In [None]:
# Function to generate caption
def generate_caption(model, tokenizer, image, transform=None, device=None, **generate_kwargs):
    """Generate a caption string for one image.

    Args:
        model: captioning model exposing ``generate(pixel_values, **kwargs)``;
            it is switched to eval mode.
        tokenizer: tokenizer whose ``decode`` turns generated ids into text.
        image: a PIL image, or a tensor. 3-D tensors get a batch dimension added;
            other tensors are used as-is.
        transform: preprocessing for non-tensor images; defaults to the
            notebook-level ``image_transform``.
        device: where the input is moved; defaults to the device of the model's
            first parameter (CPU if it has none). This avoids relying on a global.
        **generate_kwargs: forwarded to ``model.generate`` (e.g. ``num_beams``).

    Returns:
        The decoded first sequence with special tokens skipped.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    with torch.no_grad():
        if isinstance(image, torch.Tensor):
            pixel_values = image.unsqueeze(0) if image.dim() == 3 else image
        else:
            preprocess = image_transform if transform is None else transform
            pixel_values = preprocess(image).unsqueeze(0)

        output_ids = model.generate(pixel_values.to(device), **generate_kwargs)

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Save model function
def save_model(model, tokenizer, save_path):
    """Write the model's state dict as ``model_weights.pth`` plus the tokenizer files into ``save_path``.

    ``save_path`` (and any missing parent directories) is created when absent.
    """
    os.makedirs(save_path, exist_ok=True)
    weights_file = os.path.join(save_path, "model_weights.pth")
    torch.save(model.state_dict(), weights_file)
    tokenizer.save_pretrained(save_path)

# Load model function
def load_model(model_path, tokenizer_path, t5_model_name="t5-base"):
    """Rebuild ``MedicalCaptioningModel`` and restore weights saved by ``save_model``.

    Args:
        model_path: directory containing ``model_weights.pth``.
        tokenizer_path: path passed to ``T5Tokenizer.from_pretrained``.
        t5_model_name: T5 checkpoint used to build the architecture before loading.

    Returns:
        ``(model, tokenizer)``, with the model on the notebook-level ``device`` and in eval mode.
    """
    model = MedicalCaptioningModel(t5_model_name).to(device)
    weights_file = os.path.join(model_path, "model_weights.pth")
    # NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
    state_dict = torch.load(weights_file, map_location=device)
    model.load_state_dict(state_dict)
    tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
    model.eval()
    return model, tokenizer

# Save the trained model first, so the weights persist even if a preview image fails to load.
save_path = "t5_medical_image_caption_model"
save_model(model, tokenizer, save_path)

# Inference on Test Set: show a few predictions next to the ground truth.
print("\n--- Predictions on Test Set ---")
num_preview = min(5, len(test_df))
for i in range(num_preview):
    row = test_df.iloc[i]
    image_path = os.path.join(image_root, row['filename'])
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    input_tensor = image_transform(image).to(device)

    prediction = generate_caption(model, tokenizer, input_tensor)
    print(f"Image: {row['filename']}")
    print(f"GT     : {row['caption']}")
    print(f"Predicted: {prediction}")
    print("-" * 50)



Sample Predictions after Epoch 20
Image: 632_IM-2213-1002.dcm.png
GT     : Cardiac and mediastinal silhouette are unremarkable. Lungs are clear. No focal consolidation, pneumothorax, or pleural effusion identified. XXXX and soft tissue are unremarkable. No acute cardiopulmonary abnormality.
Predicted: the heart size is normal. the mediastinal contour is within normal limits. the lungs are free of any focal infiltrates. there are no nodules or masses. no visible pneumothorax. no visible pleural fluid
--------------------------------------------------
Image: 1509_IM-0331-2001.dcm.png
GT     : Lungs are clear bilaterally.There is no focal consolidation, pleural effusion, or pneumothoraces. Cardiomediastinal silhouette is within normal limits. XXXX are unremarkable. No acute cardiopulmonary abnormality.
Predicted: the heart size and mediastinal silhouette are within normal limits for contour. the lungs are clear. no pneumothorax or pleural effusions. the xxxx are intact. stable left basil

In [None]:
# BLEU Evaluation
def evaluate_bleu(model, tokenizer, test_df, image_root_dir, transform, max_samples=None):
    """Caption test images and report average sentence-level BLEU-1..4.

    Args:
        model: captioning model, used through ``generate_caption``; set to eval mode.
        tokenizer: tokenizer used to decode generated ids.
        test_df: rows with ``filename`` and ``caption`` columns.
        image_root_dir: directory that ``filename`` is resolved against.
        transform: converts an RGB PIL image into the model's input tensor.
        max_samples: evaluate only the first ``max_samples`` rows (all rows when None).

    Returns:
        dict mapping "bleu1".."bleu4" to the mean score. It is empty when there
        are no samples to evaluate.
    """
    print("✅ Using BLEU evaluation function for T5 model")
    model.eval()

    total_samples = len(test_df) if max_samples is None else min(max_samples, len(test_df))
    if total_samples <= 0:
        print("No samples to evaluate.")
        return {}

    smooth = SmoothingFunction().method1
    # Cumulative n-gram weights; each tuple sums to exactly 1.
    bleu_weights = {
        "bleu1": (1.0, 0, 0, 0),
        "bleu2": (0.5, 0.5, 0, 0),
        "bleu3": (1 / 3, 1 / 3, 1 / 3, 0),
        "bleu4": (0.25, 0.25, 0.25, 0.25),
    }
    scores = {name: [] for name in bleu_weights}

    for i in tqdm(range(total_samples)):
        row = test_df.iloc[i]
        image_path = os.path.join(image_root_dir, row['filename'])
        with Image.open(image_path) as img:
            image = transform(img.convert("RGB")).to(device)

        generated = generate_caption(model, tokenizer, image)

        # Lower-cased whitespace tokenisation; punctuation stays attached to words.
        ref_tokens = row['caption'].lower().split()
        gen_tokens = generated.lower().split()

        for name, weights in bleu_weights.items():
            scores[name].append(
                sentence_bleu([ref_tokens], gen_tokens, weights=weights, smoothing_function=smooth)
            )

    averages = {name: sum(values) / len(values) for name, values in scores.items()}
    print(f"\nAverage BLEU-1: {averages['bleu1']:.4f}")
    print(f"Average BLEU-2: {averages['bleu2']:.4f}")
    print(f"Average BLEU-3: {averages['bleu3']:.4f}")
    print(f"Average BLEU-4: {averages['bleu4']:.4f}")
    return averages

# Run BLEU evaluation on the held-out split, using the same resize + to-tensor preprocessing as training.
test_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

evaluate_bleu(
    model=model,
    tokenizer=tokenizer,
    test_df=test_df,
    image_root_dir=image_root,
    transform=test_transform,
)

✅ Using updated evaluate_bleu() function


100%|█████████████████████████████████████████████████████████████████████████████| 1486/1486 [34:41<00:00,  1.40s/it]


Average BLEU-1: 0.5268
Average BLEU-2: 0.4236
Average BLEU-3: 0.3472
Average BLEU-4: 0.2143



