In [None]:
# --- Environment setup (Google Colab) --------------------------------------
# Mount Google Drive so the MRNet data under /content/drive can be read.
from google.colab import drive
drive.mount('/content/drive')

# Install the caption-metric dependencies. The leading "!" is an IPython shell
# escape, so this cell only runs inside a notebook, not as a plain .py script.
!pip install nltk rouge-score
import nltk
# Tokenizer models used by nltk.word_tokenize.
nltk.download('punkt')
nltk.download('wordnet')  # for METEOR

# Metric helpers: BLEU / METEOR (nltk), ROUGE (rouge-score), sklearn metrics.
# NOTE(review): meteor_score and classification_report are imported but not
# used anywhere in the visible notebook.
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
from rouge_score import rouge_scorer
from sklearn.metrics import classification_report, precision_score, recall_score, f1_score, accuracy_score

# Redundant re-import (nltk is already imported above); harmless.
import nltk

# Force download again (even if already there)
nltk.download('punkt', force=True)
nltk.download('wordnet', force=True)
nltk.download('omw-1.4', force=True)

# Optional: print the directories NLTK searches for data, which is useful when
# a resource is reported as missing.
print(nltk.data.path)


# Install OpenAI's CLIP package (provides `clip.load`); notebook shell escape.
!pip install git+https://github.com/openai/CLIP.git

# NOTE(review): numpy and pandas are each imported twice below; harmless.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import pandas as pd
import clip
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from tqdm import tqdm
import torch.nn as nn

base_dir = "/content/drive/MyDrive/biodata Project/MRNet-v1.0"
plane = "sagittal"  # can be 'axial', 'coronal', 'sagittal'
label_csv = os.path.join(base_dir, "train-abnormal.csv")
image_dir = os.path.join(base_dir, "train", plane)


# Load labels. The MRNet label CSVs have no header row (each line looks like
# "0000,1", i.e. "<exam id>,<label>"). Name the columns explicitly: with
# pandas' default header=0, the first exam would be consumed as column names
# and dropped.
labels_df = pd.read_csv(label_csv, header=None, names=['exam', 'abnormal'])

# Choose a sample (change index to view other examples)
sample_index = 0
# The file name comes from the exam-id column (zero-padded to 4 digits, e.g.
# "0007.npy"); the label comes from the 'abnormal' column.
img_id = str(labels_df.loc[sample_index, 'exam']).zfill(4)
label = labels_df.loc[sample_index, 'abnormal']

# Load the volume; slices are stacked along axis 0.
img_path = os.path.join(image_dir, f"{img_id}.npy")
scan = np.load(img_path)

print(f"Scan shape: {scan.shape} — Label: {label}")

# Plot six evenly spaced slices across the volume.
num_slices = scan.shape[0]
slices_to_plot = np.linspace(0, num_slices - 1, 6, dtype=int)

plt.figure(figsize=(12, 6))
for i, slice_idx in enumerate(slices_to_plot):
    plt.subplot(2, 3, i + 1)
    plt.imshow(scan[slice_idx], cmap='gray')
    plt.title(f"Slice {slice_idx}")
    plt.axis('off')

plt.suptitle(f"Sample {img_id} — Abnormal: {label}")
plt.tight_layout()
plt.show()


# The MRNet label CSVs have no header row ("<exam id>,<0/1 label>" per line).
# Read them with header=None and explicit column names. With the default
# header=0, each file's first exam would become the column names and be
# silently missing from the merge.
abnormal_df = pd.read_csv(os.path.join(base_dir, "train-abnormal.csv"), header=None, names=['exam', 'abnormal'])
acl_df = pd.read_csv(os.path.join(base_dir, "train-acl.csv"), header=None, names=['exam', 'acl'])
meniscus_df = pd.read_csv(os.path.join(base_dir, "train-meniscus.csv"), header=None, names=['exam', 'meniscus'])

# One row per exam with its three binary labels (inner join on exam id).
merged_df = abnormal_df.merge(acl_df, on='exam').merge(meniscus_df, on='exam')


def generate_caption(row):
    """Build the template caption for one exam from its binary labels.

    Args:
        row: mapping (e.g. a DataFrame row) with an 'abnormal' key and, for
            abnormal exams, 'acl' and 'meniscus' keys holding 0/1.

    Returns:
        One of five fixed sentences: healthy, ACL tear, meniscus tear, both
        tears, or an unspecified abnormality.
    """
    if row['abnormal'] == 0:
        return "The depicted knee appears to be healthy."

    tear_labels = (('acl', "an ACL tear"), ('meniscus', "a meniscus tear"))
    tears = [phrase for column, phrase in tear_labels if row[column] == 1]

    if not tears:
        return "The depicted knee has an unspecified abnormality."
    return "The depicted knee has " + " and ".join(tears) + "."


# Attach a template caption to every exam row (new column appended last).
merged_df = merged_df.assign(caption=merged_df.apply(generate_caption, axis=1))
merged_df

# Distinct caption templates that actually occur in the data.
merged_df["caption"].unique()

### Data Preprocessing

# Per-channel RGB mean/std used by CLIP's own preprocessing (OpenAI / OpenCLIP).
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

# PIL image -> normalised tensor. Resize((224, 224)) forces exactly 224x224,
# so non-square inputs are stretched rather than cropped.
clip_preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CLIP_MEAN, std=_CLIP_STD),
    ]
)

def middle_slices_to_rgb(scan):
    """Pack the three middle slices of a volume into an (H, W, 3) uint8 image.

    Each slice is min-max scaled to [0, 255] independently (truncating to
    uint8). A constant slice (max == min) maps to zeros instead of producing
    NaN from a 0/0 division. Volumes with fewer than three slices repeat edge
    slices, so the result always has exactly three channels.

    Args:
        scan: array of shape (slices, H, W).
    """
    num_slices = scan.shape[0]
    mid = num_slices // 2
    # Same indices as scan[mid - 1: mid + 2] when there are >= 3 slices.
    indices = np.clip(np.arange(mid - 1, mid + 2), 0, num_slices - 1)

    channels = []
    for s in scan[indices]:
        s = s.astype(np.float64)
        lo, hi = s.min(), s.max()
        if hi > lo:
            channels.append(((s - lo) / (hi - lo) * 255).astype(np.uint8))
        else:
            channels.append(np.zeros(s.shape, dtype=np.uint8))
    return np.stack(channels, axis=-1)


def load_exam_mri(path):
    """Load an exam .npy volume and return a CLIP-ready image tensor.

    The three middle slices are used as pseudo-RGB channels, then passed
    through the module-level ``clip_preprocess`` transform.
    """
    scan = np.load(path)  # shape: (slices, H, W)
    img = Image.fromarray(middle_slices_to_rgb(scan))
    return clip_preprocess(img)


# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load OpenAI CLIP ViT-B/32. Only its image encoder (model.encode_image /
# model.visual) is used later; the returned `preprocess` is not used, because
# clip_preprocess is applied instead.
model, preprocess = clip.load("ViT-B/32", device=device)


def encode_image(tensor_image):
    """Embed one preprocessed image tensor with the global CLIP ``model``.

    A batch dimension is added and the tensor is moved to ``device``. No
    gradients are tracked. The result keeps the batch dimension; the original
    comment gives its shape as (1, 512).
    """
    with torch.no_grad():
        batch = tensor_image.unsqueeze(0).to(device)
        return model.encode_image(batch)


# GPT-2 ships without a pad token; reuse EOS so captions can be padded.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): this standalone GPT-2 is not referenced anywhere in the visible
# notebook (ClipCaptionModel loads its own copy); it only occupies memory.
gpt2 = GPT2LMHeadModel.from_pretrained("gpt2").to(device)

# Freeze CLIP's image encoder (equivalent to clearing requires_grad on every
# parameter); image features are also computed under torch.no_grad().
model.visual.requires_grad_(False)


class ClipCaptionModel(nn.Module):
    """GPT-2 caption decoder conditioned on a CLIP embedding via a learned prefix.

    A linear layer maps one ``clip_dim`` image embedding to ``prefix_len``
    vectors of GPT-2's hidden size. These are prepended to the caption token
    embeddings.

    Args:
        clip_dim: size of the incoming CLIP embedding (512 matches the
            ViT-B/32 embeddings used in this notebook).
        prefix_len: number of prefix vectors produced per image.
    """

    def __init__(self, clip_dim=512, prefix_len=10):
        super().__init__()
        self.gpt = GPT2LMHeadModel.from_pretrained("gpt2")
        self.prefix_len = prefix_len
        self.clip_project = nn.Linear(clip_dim, self.gpt.config.n_embd * prefix_len)

    def forward(self, image_embedding, captions, attention_mask):
        """Run teacher-forced GPT-2 over [prefix ; caption] and return its output.

        Args:
            image_embedding: (batch, clip_dim) CLIP embeddings, any float dtype.
            captions: (batch, seq) caption token ids.
            attention_mask: (batch, seq) with 1 for real tokens and 0 for padding.
                Assumes right padding (real tokens first).

        Returns:
            The GPT-2 output. Its ``.loss`` covers caption tokens plus the
            first padding position only; prefix positions and the remaining
            padding are ignored (label -100).
        """
        batch_size = captions.shape[0]

        # CLIP can return half-precision embeddings; the projection is fp32.
        image_embedding = image_embedding.float()

        prefix_embedding = self.clip_project(image_embedding).view(batch_size, self.prefix_len, -1)
        caption_embeddings = self.gpt.transformer.wte(captions)
        embeddings = torch.cat((prefix_embedding, caption_embeddings), dim=1)

        # Prefix positions are always attended to.
        extended_attention = torch.cat((
            torch.ones((batch_size, self.prefix_len), dtype=attention_mask.dtype, device=attention_mask.device),
            attention_mask,
        ), dim=1)

        # Exclude padding from the loss, but keep the first padding position.
        # In this notebook the pad token is EOS, so that target teaches the
        # model where a caption ends.
        caption_lengths = attention_mask.sum(dim=1, keepdim=True)
        positions = torch.arange(captions.shape[1], device=captions.device).unsqueeze(0)
        caption_labels = captions.masked_fill(positions > caption_lengths, -100)

        labels = torch.cat((
            torch.full((batch_size, self.prefix_len), -100, dtype=caption_labels.dtype, device=captions.device),
            caption_labels,
        ), dim=1)

        return self.gpt(inputs_embeds=embeddings, attention_mask=extended_attention, labels=labels)



from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

class MRICaptionDataset(Dataset):
    """Pairs each exam's middle MRI slices with its tokenized caption.

    Args:
        dataframe: table with at least 'exam' and 'caption' columns; rows are
            addressed positionally via ``iloc``.
        image_dir: directory containing ``<exam id zero-padded to 4>.npy`` volumes.
        transform: callable applied to the PIL RGB image (e.g. CLIP preprocessing).
        tokenizer: HuggingFace-style tokenizer used to encode captions.
        max_length: captions are padded/truncated to this many tokens.
    """

    def __init__(self, dataframe, image_dir, transform, tokenizer, max_length=50):
        self.data = dataframe
        self.image_dir = image_dir
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _scan_to_rgb(scan):
        """Stack the three middle slices of a (slices, H, W) volume as (H, W, 3) uint8.

        Each slice is min-max scaled to [0, 255] on its own. Constant slices
        become zeros rather than NaN from a 0/0 division. Volumes with fewer
        than three slices repeat edge slices so there are always three channels.
        """
        num_slices = scan.shape[0]
        mid = num_slices // 2
        # Equals scan[mid - 1: mid + 2] whenever there are >= 3 slices.
        indices = np.clip(np.arange(mid - 1, mid + 2), 0, num_slices - 1)

        channels = []
        for s in scan[indices]:
            s = s.astype(np.float64)
            lo, hi = s.min(), s.max()
            if hi > lo:
                channels.append(((s - lo) / (hi - lo) * 255).astype(np.uint8))
            else:
                channels.append(np.zeros(s.shape, dtype=np.uint8))
        return np.stack(channels, axis=-1)

    def __getitem__(self, idx):
        """Return (image tensor, caption input_ids, caption attention_mask) for row ``idx``."""
        row = self.data.iloc[idx]
        exam_id = str(row['exam']).zfill(4)
        caption = row['caption']
        img_path = os.path.join(self.image_dir, f"{exam_id}.npy")

        # Load the volume and turn its middle slices into a pseudo-RGB image.
        scan = np.load(img_path)
        img = Image.fromarray(self._scan_to_rgb(scan))
        img_tensor = self.transform(img)

        # Fixed-length tokenization; drop the batch dim added by return_tensors="pt".
        tokens = self.tokenizer(caption, padding="max_length", truncation=True,
                                max_length=self.max_length, return_tensors="pt")
        input_ids = tokens.input_ids.squeeze(0)
        attention_mask = tokens.attention_mask.squeeze(0)

        return img_tensor, input_ids, attention_mask


# 80/20 train/validation split over exams; the fixed seed keeps it reproducible.
train_df, val_df = train_test_split(merged_df, test_size=0.2, random_state=42)

train_dataset = MRICaptionDataset(
    train_df, image_dir=image_dir, transform=clip_preprocess, tokenizer=tokenizer
)
val_dataset = MRICaptionDataset(
    val_df, image_dir=image_dir, transform=clip_preprocess, tokenizer=tokenizer
)

# Only the training loader is shuffled; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)


# Caption model (GPT-2 plus prefix projection); AdamW optimises all of its parameters.
clip_gpt_model = ClipCaptionModel().to(device)
optimizer = torch.optim.AdamW(clip_gpt_model.parameters(), lr=1e-4)

def caption_to_vector(caption):
    """Map a caption to 0/1 flags [ACL tear, meniscus tear, unspecified abnormality].

    Flags are set by plain substring matching. A caption containing none of
    the phrases (e.g. a healthy one) yields [0, 0, 0].
    """
    markers = ('ACL tear', 'meniscus tear', 'unspecified abnormality')
    return [int(marker in caption) for marker in markers]


def encode_clip_images(clip_model, images):
    """Return CLIP image embeddings for a batch without tracking gradients.

    The batch is moved to the module-level ``device`` before encoding.
    """
    with torch.no_grad():
        batch = images.to(device)
        embeddings = clip_model.encode_image(batch)
    return embeddings

def train_epoch(caption_model, clip_model, dataloader, optimizer, device):
    """Run one optimisation pass over ``dataloader`` and return the mean batch loss.

    Image features come from ``encode_clip_images`` (no gradients). ``optimizer``
    steps whatever parameters it was constructed with. An empty dataloader
    raises ZeroDivisionError.
    """
    caption_model.train()
    running_loss = 0.0

    for batch in tqdm(dataloader, desc="Training"):
        images, input_ids, attention_mask = (tensor.to(device) for tensor in batch)

        embeddings = encode_clip_images(clip_model, images)
        loss = caption_model(embeddings, input_ids, attention_mask).loss

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        running_loss += loss.item()

    return running_loss / len(dataloader)


def _generate_caption(caption_model, image_embedding, tokenizer, max_length=50, num_beams=5):
    """Beam-search a caption for a single 1-D CLIP embedding and return it as text."""
    prefix_embed = caption_model.clip_project(image_embedding.float().unsqueeze(0)) \
        .view(1, caption_model.prefix_len, -1)
    # Every prefix position is real input; pass the mask explicitly instead of
    # relying on generate() to infer it.
    prefix_mask = torch.ones(prefix_embed.shape[:2], dtype=torch.long, device=prefix_embed.device)
    generated = caption_model.gpt.generate(
        inputs_embeds=prefix_embed,
        attention_mask=prefix_mask,
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)


def evaluate(caption_model, clip_model, dataloader, device, tokenizer):
    """Compute validation loss, caption-overlap scores and label metrics.

    For every sample a caption is beam-searched from its image embedding and
    compared with the decoded ground-truth caption.

    Returns:
        (avg_loss, avg_bleu, avg_rouge_l, accuracy, precision, recall, f1).
        - avg_loss is the mean per-batch loss.
        - BLEU (smoothing method4) and ROUGE-L F-measure are averaged per sample.
        - The classification metrics compare ``caption_to_vector`` flags.
          accuracy is exact match over all flags; precision/recall/F1 are
          micro-averaged, with 0 used when undefined.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    if len(dataloader) == 0:
        raise ValueError("evaluate() requires a non-empty dataloader")

    caption_model.eval()
    total_loss = 0.0
    total_bleu = 0.0
    total_rouge_l = 0.0

    smooth = SmoothingFunction().method4
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

    true_vectors = []
    pred_vectors = []

    with torch.no_grad():
        for images, input_ids, attention_mask in tqdm(dataloader, desc="Validation"):
            images = images.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            image_embeddings = encode_clip_images(clip_model, images)
            outputs = caption_model(image_embeddings, input_ids, attention_mask)
            total_loss += outputs.loss.item()

            for ids, embedding in zip(input_ids, image_embeddings):
                true_caption = tokenizer.decode(ids, skip_special_tokens=True)
                pred_caption = _generate_caption(caption_model, embedding, tokenizer)

                reference = nltk.word_tokenize(true_caption)
                candidate = nltk.word_tokenize(pred_caption)
                total_bleu += sentence_bleu([reference], candidate, smoothing_function=smooth)
                total_rouge_l += scorer.score(true_caption, pred_caption)['rougeL'].fmeasure

                true_vectors.append(caption_to_vector(true_caption))
                pred_vectors.append(caption_to_vector(pred_caption))

    num_samples = len(dataloader.dataset)
    avg_loss = total_loss / len(dataloader)
    avg_bleu = total_bleu / num_samples
    avg_rouge_l = total_rouge_l / num_samples

    y_true = np.array(true_vectors)
    y_pred = np.array(pred_vectors)

    # zero_division=0 makes sklearn's fallback explicit and silences its
    # warning when there are no predicted (precision) or true (recall) positives.
    # NOTE(review): healthy captions map to all-zero flags, so correctly
    # predicted healthy knees affect accuracy but not micro precision/recall/F1.
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='micro', zero_division=0)
    recall = recall_score(y_true, y_pred, average='micro', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='micro', zero_division=0)

    return avg_loss, avg_bleu, avg_rouge_l, accuracy, precision, recall, f1


import nltk
# Recent NLTK versions need 'punkt_tab' for nltk.word_tokenize (used in evaluate).
nltk.download('punkt_tab')

num_epochs = 25

for epoch in range(num_epochs):
    print(f"\n🌟 Epoch {epoch + 1}/{num_epochs}")

    avg_train_loss = train_epoch(clip_gpt_model, model, train_loader, optimizer, device)
    avg_val_loss, avg_bleu, avg_rouge_l, accuracy, precision, recall, f1 = evaluate(
        clip_gpt_model, model, val_loader, device, tokenizer
    )

    print(f"✅ Train Loss: {avg_train_loss:.4f} | 🔍 Val Loss: {avg_val_loss:.4f}")
    print(f"📝 BLEU: {avg_bleu:.4f} | 📊 ROUGE-L: {avg_rouge_l:.4f}")
    print(f"🟢 Accuracy:  {accuracy:.4f}")
    print(f"🟡 Precision:       {precision:.4f}")
    print(f"🔵 Recall:          {recall:.4f}")
    print(f"🟣 F1 Score:        {f1:.4f}")
