# In [None]:
from google.colab import drive
drive.mount('/content/drive')

!pip install nltk rouge-score
import nltk
nltk.download('punkt')
nltk.download('wordnet')  # for METEOR

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
from rouge_score import rouge_scorer
from sklearn.metrics import classification_report, precision_score, recall_score, f1_score, accuracy_score

import nltk

# Force download again (even if already there)
nltk.download('punkt', force=True)
nltk.download('wordnet', force=True)
nltk.download('omw-1.4', force=True)

# ✅ OPTIONAL: Check paths
print(nltk.data.path)


!pip install git+https://github.com/openai/CLIP.git

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import pandas as pd
import clip
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from tqdm import tqdm
import torch.nn as nn

def generate_caption(row):
    """Build the text caption for one exam from its binary label columns.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) exposing the keys
            ``'abnormal'``, ``'acl'`` and ``'meniscus'``.

    Returns:
        ``"Healthy knee"`` when ``abnormal`` is 0; otherwise the positive
        findings joined with ``" and a "`` and terminated by a period, or
        ``"Unspecified abnormality."`` when neither finding is flagged.
    """
    if row['abnormal'] == 0:
        return "Healthy knee"

    # Column name -> phrase, in the order the phrases appear in the caption.
    label_phrases = (('acl', "ACL tear"), ('meniscus', "Meniscus tear"))
    present = [phrase for column, phrase in label_phrases if row[column] == 1]

    if not present:
        return "Unspecified abnormality."
    return " and a ".join(present) + "."


def load_exam_mri(path):
    """Load an MRI volume and turn its three central slices into one image.

    The slices at ``mid - 1``, ``mid`` and ``mid + 1`` along axis 0 are each
    rescaled to 0..255 independently and stacked as the three channels of a
    PIL image, which is then passed through the module-level
    ``clip_preprocess`` transform.

    Args:
        path: Path to a ``.npy`` file. Axis 0 is indexed as the slice axis
            (presumably the array is ``(slices, H, W)`` - confirm against the
            data).

    Returns:
        Whatever ``clip_preprocess`` returns for the assembled image.

    Raises:
        ValueError: If the volume contains no slices.
    """
    scan = np.load(path)
    num_slices = scan.shape[0]
    if num_slices == 0:
        raise ValueError(f"MRI volume at {path} contains no slices")

    # Clamp the three indices so that volumes with fewer than three slices
    # repeat an edge slice instead of producing a 1- or 2-channel array that
    # PIL cannot interpret. For three or more slices this is a no-op.
    mid = num_slices // 2
    picks = np.clip(np.arange(mid - 1, mid + 2), 0, num_slices - 1)

    channels = []
    for s in scan[picks]:
        lo, hi = s.min(), s.max()
        if hi == lo:
            # A constant slice has zero dynamic range; dividing by it would
            # yield NaN, so map it to an all-black channel instead.
            channels.append(np.zeros(s.shape, dtype=np.uint8))
        else:
            channels.append(((s - lo) / (hi - lo) * 255).astype(np.uint8))

    img = Image.fromarray(np.stack(channels, axis=-1))
    return clip_preprocess(img)


class ClipCaptionModel(nn.Module):
    """GPT-2 captioner conditioned on an image embedding through a learned prefix.

    A single linear layer maps the image embedding to ``prefix_len`` vectors of
    GPT-2's embedding width. Those vectors are prepended to the caption token
    embeddings before the sequence is fed to GPT-2.
    """

    def __init__(self, clip_dim=512, prefix_len=10):
        """
        Args:
            clip_dim: Size of the incoming image embedding.
            prefix_len: Number of prefix positions produced per image.
        """
        super().__init__()
        self.gpt = GPT2LMHeadModel.from_pretrained("gpt2")
        self.prefix_len = prefix_len
        self.clip_project = nn.Linear(clip_dim, self.gpt.config.n_embd * prefix_len)

    @staticmethod
    def _build_labels(captions, attention_mask, prefix_len):
        """Return language-model targets for ``prefix + caption``.

        Prefix positions are set to -100 so they are ignored by the loss.
        Padding positions (``attention_mask == 0``) are ignored as well,
        except for the first one in each row: it is kept as a target so the
        model still learns to emit the terminating token (the callers in this
        file pad with the EOS token). Without this masking most of the loss
        would come from trivial pad-to-pad predictions.

        Args:
            captions: Integer tensor of token ids, ``(batch, seq_len)``.
            attention_mask: Tensor of the same shape, 0 at padding positions.
            prefix_len: Number of prefix positions to prepend.

        Returns:
            Integer tensor of shape ``(batch, prefix_len + seq_len)``.
        """
        is_pad = attention_mask == 0
        first_pad = is_pad & (is_pad.long().cumsum(dim=1) == 1)
        caption_labels = captions.masked_fill(is_pad & ~first_pad, -100)

        prefix_labels = torch.full(
            (captions.shape[0], prefix_len),
            -100,
            dtype=captions.dtype,
            device=captions.device,
        )
        return torch.cat((prefix_labels, caption_labels), dim=1)

    def forward(self, image_embedding, captions, attention_mask):
        """Run GPT-2 on the projected image prefix followed by the caption.

        Args:
            image_embedding: Tensor reshaped to ``(batch, prefix_len, n_embd)``
                after projection, so it must be ``(batch, clip_dim)``.
            captions: Caption token ids, ``(batch, seq_len)``.
            attention_mask: Mask for ``captions``, same shape.

        Returns:
            The output object of the wrapped GPT-2 model; because ``labels``
            are supplied it carries a ``loss`` attribute.
        """
        batch_size = captions.shape[0]

        # The projection layer holds float32 weights; cast the embedding to match.
        image_embedding = image_embedding.float()

        prefix_embedding = self.clip_project(image_embedding).view(batch_size, self.prefix_len, -1)
        caption_embeddings = self.gpt.transformer.wte(captions)
        embeddings = torch.cat((prefix_embedding, caption_embeddings), dim=1)

        # Every prefix position is attended to. Use the caller's mask dtype so
        # the concatenation does not depend on implicit type promotion.
        prefix_attention = torch.ones(
            (batch_size, self.prefix_len),
            dtype=attention_mask.dtype,
            device=attention_mask.device,
        )
        extended_attention = torch.cat((prefix_attention, attention_mask), dim=1)

        labels = self._build_labels(captions, attention_mask, self.prefix_len)

        outputs = self.gpt(inputs_embeds=embeddings, attention_mask=extended_attention, labels=labels)
        return outputs



def encode_image(tensor_image):
    """Encode a single preprocessed image with the module-level ``model``.

    NOTE(review): this relies on globals named ``model`` and ``device`` that
    are not defined in the visible part of this file - confirm they exist
    before calling.

    Args:
        tensor_image: Image tensor without a batch dimension.

    Returns:
        The result of ``model.encode_image`` for a batch of one (presumably
        shape ``(1, 512)`` for the CLIP variant used elsewhere - confirm).
    """
    # Prepend the batch axis and move the tensor next to the model.
    single_batch = tensor_image[None].to(device)
    with torch.no_grad():
        return model.encode_image(single_batch)


from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

class MRICaptionDataset(Dataset):
    """Dataset yielding one representative slice, the tokenized caption and
    the slice indices to encode for each MRI exam.

    Each item is read from ``<image_dir>/<exam id zero-padded to 4>.npy``.
    """

    def __init__(self, dataframe, image_dir, transform, tokenizer, max_length=50, num_slices_to_use=5):
        """
        Args:
            dataframe: Table with at least ``'exam'`` and ``'caption'`` columns.
            image_dir: Directory holding the per-exam ``.npy`` volumes.
            transform: Callable applied to the PIL image of the middle slice.
            tokenizer: Callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors.
            max_length: Length captions are padded/truncated to.
            num_slices_to_use: Maximum number of slice indices reported per exam.
        """
        self.data = dataframe
        self.image_dir = image_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.num_slices_to_use = num_slices_to_use
        self.clip_dim = 512  # CLIP's embedding dimension

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _to_uint8(slice_img):
        """Min-max rescale a 2-D slice to 0..255 and cast it to ``uint8``.

        A constant slice (zero dynamic range) is mapped to all zeros instead
        of dividing by zero, which would produce NaN and an undefined cast.
        """
        lo, hi = slice_img.min(), slice_img.max()
        if hi == lo:
            return np.zeros(slice_img.shape, dtype=np.uint8)
        return ((slice_img - lo) / (hi - lo) * 255).astype(np.uint8)

    @staticmethod
    def _select_slice_indices(num_slices, num_slices_to_use):
        """Return up to ``num_slices_to_use`` evenly spaced slice indices.

        All indices are returned when the volume has no more slices than
        requested. The result is always a plain list of Python ints.
        """
        if num_slices <= num_slices_to_use:
            return list(range(num_slices))
        return np.linspace(0, num_slices - 1, num_slices_to_use, dtype=int).tolist()

    def __getitem__(self, idx):
        """Return ``(img_tensor, input_ids, attention_mask, exam_id, selected_slices)``.

        ``img_tensor`` is the transformed middle slice replicated over three
        channels; ``selected_slices`` lists the indices a caller should encode
        for this exam.
        """
        row = self.data.iloc[idx]
        exam_id = str(row['exam']).zfill(4)
        caption = row['caption']
        img_path = os.path.join(self.image_dir, f"{exam_id}.npy")

        # The slice axis is axis 0 (presumably (slices, H, W) - confirm).
        scan = np.load(img_path)
        num_slices = scan.shape[0]
        selected_slices = self._select_slice_indices(num_slices, self.num_slices_to_use)

        # The middle slice serves as the representative image for the exam.
        slice_img = self._to_uint8(scan[num_slices // 2])
        slice_rgb = np.stack([slice_img, slice_img, slice_img], axis=-1)
        img_tensor = self.transform(Image.fromarray(slice_rgb))

        # Pad every caption to the same length so items can be stacked.
        tokens = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        input_ids = tokens.input_ids.squeeze(0)
        attention_mask = tokens.attention_mask.squeeze(0)

        return img_tensor, input_ids, attention_mask, exam_id, selected_slices


def collate_fn(batch):
    """Merge dataset items into one batch.

    Args:
        batch: Sequence of 5-tuples
            ``(image, input_ids, attention_mask, exam_id, selected_slices)``.

    Returns:
        A 5-tuple: the first three fields stacked into tensors along a new
        leading axis, followed by the exam ids and the slice-index lists as
        plain tuples (they are not tensors and may differ in length).
    """
    pixels, token_ids, masks, exam_ids, slice_indices = zip(*batch)
    stacked = [torch.stack(column) for column in (pixels, token_ids, masks)]
    return (*stacked, exam_ids, slice_indices)

# Memory-efficient slice encoding function
def encode_selected_slices(clip_model, exam_id, selected_slices, image_dir, transform, device):
    """Encode the chosen slices of one exam and average their embeddings.

    Slices are encoded one at a time, so only a single image is on the device
    at any moment.

    Args:
        clip_model: Object exposing ``encode_image(tensor)``.
        exam_id: Exam identifier; the volume is read from
            ``<image_dir>/<exam_id>.npy``.
        selected_slices: Iterable of indices into axis 0 of the volume.
        image_dir: Directory containing the ``.npy`` volumes.
        transform: Callable turning a PIL image into a tensor.
        device: Device the image tensors are moved to before encoding.

    Returns:
        The mean of the per-slice embeddings, with a leading axis of size 1.

    Raises:
        ValueError: If ``selected_slices`` is empty.
    """
    selected_slices = list(selected_slices)
    if not selected_slices:
        # Fail with a clear message rather than an opaque torch.cat error.
        raise ValueError(f"No slices selected for exam {exam_id}")

    img_path = os.path.join(image_dir, f"{exam_id}.npy")
    scan = np.load(img_path)

    slice_embeddings = []
    for slice_idx in selected_slices:
        slice_img = scan[slice_idx]
        lo, hi = slice_img.min(), slice_img.max()
        if hi == lo:
            # Constant slice: avoid 0/0 (NaN) and use an all-black image.
            slice_img = np.zeros(slice_img.shape, dtype=np.uint8)
        else:
            slice_img = ((slice_img - lo) / (hi - lo) * 255).astype(np.uint8)

        # Replicate the grayscale slice over three channels for the RGB encoder.
        slice_rgb = np.stack([slice_img, slice_img, slice_img], axis=-1)
        img_tensor = transform(Image.fromarray(slice_rgb)).unsqueeze(0).to(device)

        # The image encoder is used as a fixed feature extractor here.
        with torch.no_grad():
            slice_embeddings.append(clip_model.encode_image(img_tensor))

    all_embeddings = torch.cat(slice_embeddings, dim=0)
    return torch.mean(all_embeddings, dim=0, keepdim=True)

# Improved training function with proper error handling
def train_epoch(caption_model, clip_model, dataloader, optimizer, device, image_dir, transform):
    """Run one optimisation pass over ``dataloader``.

    For every batch the exams are re-encoded with ``encode_selected_slices``,
    the caption model is run on the resulting embeddings and one optimizer
    step is taken. A batch that raises is reported and skipped.

    Args:
        caption_model: Model called as ``model(embeddings, input_ids, mask)``
            whose output has a ``loss`` attribute.
        clip_model: Image encoder forwarded to ``encode_selected_slices``.
        dataloader: Iterable of batches in the layout produced by ``collate_fn``.
        optimizer: Optimizer over the caption model's parameters.
        device: Device the token tensors are moved to.
        image_dir: Directory with the ``.npy`` volumes.
        transform: Image transform forwarded to ``encode_selected_slices``.

    Returns:
        Mean loss over the batches that completed (0 if none did).
    """
    caption_model.train()
    loss_sum = 0
    completed = 0

    for batch in tqdm(dataloader, desc="Training"):
        try:
            # The pre-rendered image tensor in the batch is not used here;
            # embeddings are recomputed from the selected slices instead.
            _, token_ids, mask, exam_ids, slice_lists = batch
            token_ids = token_ids.to(device)
            mask = mask.to(device)

            per_exam = [
                encode_selected_slices(
                    clip_model, exam_id, slice_lists[pos], image_dir, transform, device
                )
                for pos, exam_id in enumerate(exam_ids)
            ]
            image_embeddings = torch.cat(per_exam, dim=0)

            batch_loss = caption_model(image_embeddings, token_ids, mask).loss

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.item()
            completed += 1

        except Exception as e:
            print(f"Error in batch: {e}")
            continue

    return loss_sum / max(1, completed)

def _summarize_captions(caption_pairs):
    """Build the caption-distribution report printed by ``evaluate``.

    Args:
        caption_pairs: List of ``(true_caption, predicted_caption)`` tuples.

    Returns:
        Dict with sample/unique counts, the ten most frequent true and
        predicted captions, a ``prediction -> {true caption: count}`` map and
        a repetition rate (0.0 when there are no samples).
    """
    from collections import Counter

    true_counts = Counter(true_cap for true_cap, _ in caption_pairs)
    pred_counts = Counter(pred_cap for _, pred_cap in caption_pairs)

    # Plain dicts keep the printed representation free of Counter/defaultdict wrappers.
    pred_to_true_map = {}
    for true_cap, pred_cap in caption_pairs:
        per_prediction = pred_to_true_map.setdefault(pred_cap, {})
        per_prediction[true_cap] = per_prediction.get(true_cap, 0) + 1

    total_samples = len(caption_pairs)
    # Guard the ratio: with no evaluated samples there is nothing to repeat.
    repetition_rate = 1 - (len(pred_counts) / total_samples) if total_samples else 0.0

    return {
        "total_samples": total_samples,
        "unique_predicted_captions": len(pred_counts),
        "unique_true_captions": len(true_counts),
        "top_predicted_captions": pred_counts.most_common(10),
        "top_true_captions": true_counts.most_common(10),
        "prediction_to_true_map": pred_to_true_map,
        "repetition_rate": repetition_rate,  # Higher means more repetition
    }


def evaluate(caption_model, clip_model, dataloader, device, tokenizer, image_dir, transform):
    """Compute validation loss, BLEU and ROUGE-L and print a caption report.

    For each batch the teacher-forced loss is computed, then a caption is
    generated per exam with beam search from the projected image prefix and
    compared with the decoded ground-truth caption. A batch that raises is
    reported and skipped.

    Args:
        caption_model: Model exposing ``clip_project``, ``prefix_len`` and
            ``gpt`` and callable as ``model(embeddings, input_ids, mask)``.
        clip_model: Image encoder forwarded to ``encode_selected_slices``.
        dataloader: Iterable of batches in the layout produced by ``collate_fn``.
        device: Device tensors are moved to.
        tokenizer: Tokenizer used to decode ids and providing ``eos_token_id``.
        image_dir: Directory with the ``.npy`` volumes.
        transform: Image transform forwarded to ``encode_selected_slices``.

    Returns:
        ``(avg_loss, avg_bleu, avg_rouge)``. The loss is averaged over the
        batches whose loss was actually computed; BLEU and ROUGE-L are
        averaged over all evaluated samples. Each is 0 when nothing was
        evaluated.
    """
    caption_model.eval()
    total_loss = 0
    scored_batches = 0

    caption_pairs = []  # (true, predicted) caption per evaluated sample
    all_bleu_scores = []
    all_rouge_scores = []

    smooth = SmoothingFunction().method4
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            try:
                _, input_ids, attention_mask, exam_ids, selected_slices = batch
                input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)

                # Re-encode the selected slices of every exam in the batch.
                batch_embeddings = [
                    encode_selected_slices(
                        clip_model, exam_id, selected_slices[pos], image_dir, transform, device
                    )
                    for pos, exam_id in enumerate(exam_ids)
                ]
                image_embeddings = torch.cat(batch_embeddings, dim=0)

                # Teacher-forced loss on the reference captions.
                outputs = caption_model(image_embeddings, input_ids, attention_mask)
                total_loss += outputs.loss.item()
                scored_batches += 1

                for i in range(input_ids.size(0)):
                    true_caption = tokenizer.decode(input_ids[i], skip_special_tokens=True)

                    # Generate from the image prefix alone.
                    prefix_embed = caption_model.clip_project(
                        image_embeddings[i].float().unsqueeze(0)
                    ).view(1, caption_model.prefix_len, -1)
                    attention_prefix = torch.ones(1, caption_model.prefix_len, device=device)

                    generated = caption_model.gpt.generate(
                        inputs_embeds=prefix_embed,
                        max_length=50,
                        num_beams=5,
                        early_stopping=True,
                        pad_token_id=tokenizer.eos_token_id,
                        attention_mask=attention_prefix
                    )
                    pred_caption = tokenizer.decode(generated[0], skip_special_tokens=True)

                    caption_pairs.append((true_caption, pred_caption))

                    # BLEU: an empty prediction scores 0 rather than being
                    # dropped, so BLEU and ROUGE average over the same samples.
                    reference = [nltk.word_tokenize(true_caption.lower())]
                    candidate = nltk.word_tokenize(pred_caption.lower())
                    if candidate:
                        bleu = sentence_bleu(reference, candidate, smoothing_function=smooth)
                    else:
                        bleu = 0.0
                    all_bleu_scores.append(bleu)

                    rouge = scorer.score(true_caption, pred_caption)['rougeL'].fmeasure
                    all_rouge_scores.append(rouge)

            except Exception as e:
                print(f"Error in validation batch: {e}")
                continue

    # Average over batches that produced a loss, mirroring train_epoch, so
    # skipped batches do not pull the reported loss towards zero.
    avg_loss = total_loss / max(1, scored_batches)
    avg_bleu = sum(all_bleu_scores) / max(1, len(all_bleu_scores))
    avg_rouge = sum(all_rouge_scores) / max(1, len(all_rouge_scores))

    caption_analysis = _summarize_captions(caption_pairs)
    for key, value in caption_analysis.items():
        print(f"{key}: {value}")
    return avg_loss, avg_bleu, avg_rouge

def init_model(device):
    """Create every object needed for training on ``device``.

    Args:
        device: Device the CLIP model and the caption model are placed on.

    Returns:
        ``(clip_model, preprocess, tokenizer, caption_model, optimizer)`` where
        ``clip_model`` has all parameters frozen, ``tokenizer`` pads with its
        EOS token and ``optimizer`` is AdamW over the caption model only.
    """
    clip_model, preprocess = clip.load("ViT-B/32", device=device)

    # The image encoder is not trained; switch off gradients for all its weights.
    for weight in clip_model.parameters():
        weight.requires_grad_(False)

    # GPT-2 ships without a padding token, so the EOS token is reused for padding.
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    caption_model = ClipCaptionModel(clip_dim=512, prefix_len=10)
    caption_model = caption_model.to(device)

    adamw_settings = {"lr": 2e-5, "weight_decay": 0.01}
    optimizer = torch.optim.AdamW(caption_model.parameters(), **adamw_settings)

    return clip_model, preprocess, tokenizer, caption_model, optimizer

# --- Dataset locations -------------------------------------------------------
base_dir = "/content/drive/MyDrive/biodata Project/MRNet-v1.0"
plane = "sagittal"  # can be 'axial', 'coronal', 'sagittal'
label_csv = os.path.join(base_dir, "train-abnormal.csv")
image_dir = os.path.join(base_dir, "train", plane)


def _read_label_csv(task):
    """Read ``train-<task>.csv`` as an integer table with columns ``exam`` and ``<task>``.

    The file is read without treating the first line as a header, because a
    header-less label file would otherwise lose its first exam. Rows that are
    not fully numeric (for example an actual header line, if a copy of the
    file has one) are dropped, so both layouts load correctly.
    NOTE(review): the official MRNet label files are understood to be
    header-less - confirm against the files on Drive.
    """
    path = os.path.join(base_dir, f"train-{task}.csv")
    raw = pd.read_csv(path, header=None, names=['exam', task], dtype=str)
    numeric = raw.apply(pd.to_numeric, errors='coerce').dropna()
    return numeric.astype(int).reset_index(drop=True)


abnormal_df = _read_label_csv("abnormal")
acl_df = _read_label_csv("acl")
meniscus_df = _read_label_csv("meniscus")

# One row per exam with all three binary labels.
df = abnormal_df.merge(acl_df, on='exam').merge(meniscus_df, on='exam')


df['caption'] = df.apply(generate_caption, axis=1)
df

df["caption"].value_counts()

# --- Class balancing ---------------------------------------------------------
acl_meniscus_mask = (df['acl'] == 1) & (df['meniscus'] == 1)
acl_only_mask     = (df['acl'] == 1) & (df['meniscus'] == 0)
meniscus_only_mask = (df['acl'] == 0) & (df['meniscus'] == 1)
healthy_mask      = (df['abnormal'] == 0)
unspecified_mask  = (df['abnormal'] == 1) & (df['acl'] == 0) & (df['meniscus'] == 0)

# Target of 83 exams per caption class, capped at the smallest class so that
# sampling cannot fail when a class holds fewer rows than the target.
SAMPLES_PER_CLASS = 83
_class_masks = (
    acl_meniscus_mask,
    acl_only_mask,
    meniscus_only_mask,
    healthy_mask,
    unspecified_mask,
)
samples_per_class = min(SAMPLES_PER_CLASS, *(int(mask.sum()) for mask in _class_masks))
if samples_per_class == 0:
    raise ValueError("At least one caption class has no exams; check the label CSV files.")

df_acl_meniscus = df[acl_meniscus_mask].sample(n=samples_per_class, random_state=42)
df_acl_only     = df[acl_only_mask].sample(n=samples_per_class, random_state=42)
df_meniscus     = df[meniscus_only_mask].sample(n=samples_per_class, random_state=42)
df_healthy      = df[healthy_mask].sample(n=samples_per_class, random_state=42)
df_unspecified  = df[unspecified_mask].sample(n=samples_per_class, random_state=42)

# Concatenate the per-class samples and shuffle the rows.
df_balanced = pd.concat([
    df_acl_meniscus,
    df_acl_only,
    df_meniscus,
    df_healthy,
    df_unspecified
], ignore_index=True).sample(frac=1, random_state=42)

df_balanced["caption"].value_counts()

### Data Preprocessing

# Image pipeline used by `load_exam_mri`: resize to 224x224, convert to a
# tensor, then normalise each channel with the statistics below
# (presumably the ones published for CLIP - confirm).
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

clip_preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_CLIP_MEAN, std=_CLIP_STD),
])



import nltk
# Fetch the 'punkt_tab' resource. `evaluate` calls `nltk.word_tokenize`, which
# presumably needs this tokenizer data on the installed NLTK release - confirm.
nltk.download('punkt_tab')

# Main training loop
def train_model(num_epochs=25, batch_size=8, num_slices_to_use=10,
                checkpoint_path="/content/drive/MyDrive/biodata Project/best_model.pt"):
    """Train the caption model on ``df_balanced`` and keep the best checkpoint.

    The balanced table is split 90/10 into training and validation sets. After
    every epoch the validation ROUGE-L is compared with the best value so far,
    and the caption model's ``state_dict`` is saved when it improves.

    Args:
        num_epochs: Number of passes over the training set.
        batch_size: Batch size for both data loaders.
        num_slices_to_use: Maximum number of slices encoded per exam.
        checkpoint_path: File the best ``state_dict`` is written to.

    Returns:
        None.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    clip_model, preprocess, tokenizer, caption_model, optimizer = init_model(device)

    train_df, val_df = train_test_split(df_balanced, test_size=0.1, random_state=42)

    train_dataset = MRICaptionDataset(
        train_df, image_dir, preprocess, tokenizer, num_slices_to_use=num_slices_to_use
    )
    val_dataset = MRICaptionDataset(
        val_df, image_dir, preprocess, tokenizer, num_slices_to_use=num_slices_to_use
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn)

    # Make sure the destination exists before the first save attempt.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    # Start below any reachable score so the first epoch always produces a
    # checkpoint, even if ROUGE-L is 0 for the whole run.
    best_rouge = float("-inf")

    for epoch in range(num_epochs):
        print(f"\n🌟 Epoch {epoch + 1}/{num_epochs}")

        avg_train_loss = train_epoch(caption_model, clip_model, train_loader, optimizer, device, image_dir, preprocess)
        avg_val_loss, avg_bleu, avg_rouge = evaluate(
            caption_model, clip_model, val_loader, device, tokenizer, image_dir, preprocess
        )

        print(f"Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        print(f"BLEU: {avg_bleu:.4f} | ROUGE-L: {avg_rouge:.4f}")

        if avg_rouge > best_rouge:
            best_rouge = avg_rouge
            torch.save(caption_model.state_dict(), checkpoint_path)

# Entry point of the notebook: start training for 25 epochs with batches of 8.
train_model(num_epochs=25, batch_size=8)

