# In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration, AdamW
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
from medguard.medguard_image_anonymizer import anonymize_image  # optional preprocessing
from medguard.presidio_loader import anonymize_text
from biomedclip import BioMedCLIPModel, BioMedCLIPProcessor  # Hypothetical import
import json
import os
import random
import numpy as np

# ------------------------------
# 1️⃣ Hyperparameters
# ------------------------------
# Pretrained checkpoints: one for the T5 text model, one for BiomedCLIP.
MODEL_NAME = "t5-base"
BIOMEDCLIP_MODEL_NAME = (
    "microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
)

# Optimisation settings used by the data loader, optimizer and scheduler.
BATCH_SIZE = 8
EPOCHS = 3
LR = 5e-5

# MAX_INPUT_LENGTH is the tokenizer's pad/truncate length.
# NOTE(review): MAX_OUTPUT_LENGTH is not referenced in the code shown here.
MAX_INPUT_LENGTH = 512
MAX_OUTPUT_LENGTH = 256

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ------------------------------
# 2️⃣ Dataset Class
# ------------------------------
class CXRDataset(Dataset):
    """Map-style dataset of image/report records for T5 fine-tuning.

    The records are read from a JSON file. Each record is indexed with the
    keys ``"image_path"`` and ``"report"``; both are anonymized on every
    access before being encoded.

    Args:
        data_path: Path to the JSON file holding the records.
        tokenizer: Callable tokenizer (Hugging Face style) used on the report
            text. Its ``pad_token_id`` attribute, when present, is used to
            mask padding out of the labels.
        clip_processor: Object exposing ``encode_image(path)``, used to
            produce the image features.
    """

    def __init__(self, data_path, tokenizer, clip_processor):
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(data_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        self.tokenizer = tokenizer
        self.clip_processor = clip_processor

    def __len__(self):
        """Return the number of records loaded from the JSON file."""
        return len(self.data)

    @staticmethod
    def _mask_padding(input_ids, pad_token_id):
        """Return a copy of *input_ids* with padding replaced by -100.

        -100 is the index the cross-entropy loss ignores, so padded positions
        stop contributing to the training loss. The input tensor is left
        untouched. When *pad_token_id* is ``None`` the copy is returned as is.
        """
        labels = input_ids.clone()
        if pad_token_id is not None:
            labels[labels == pad_token_id] = -100
        return labels

    def __getitem__(self, idx):
        """Build one training example.

        Returns:
            A dict with ``"input_ids"`` and ``"attention_mask"`` (tokenized,
            padded/truncated to ``MAX_INPUT_LENGTH``), ``"image_features"``
            (whatever ``clip_processor.encode_image`` returns) and
            ``"labels"`` (the input ids with padding masked to -100).
        """
        item = self.data[idx]
        # 1. Anonymize image and text
        image_path = anonymize_image(item["image_path"])  # returns path to anonymized image
        text_input = anonymize_text(item["report"])

        # 2. Encode image with BioMedCLIP
        image_features = self.clip_processor.encode_image(image_path)

        # 3. Tokenize text input
        input_encoding = self.tokenizer(
            text_input,
            max_length=MAX_INPUT_LENGTH,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        # return_tensors="pt" adds a leading batch axis of size 1; drop it so
        # the DataLoader can stack examples itself.
        input_ids = input_encoding["input_ids"].squeeze(0)
        attention_mask = input_encoding["attention_mask"].squeeze(0)

        # The targets are the input tokens themselves, but padding must not be
        # scored: with padding="max_length" most positions can be pad tokens.
        # NOTE(review): labels equal the inputs, so the model learns to
        # reproduce the anonymized report — confirm this is the intended task.
        labels = self._mask_padding(
            input_ids, getattr(self.tokenizer, "pad_token_id", None)
        )

        # Optional: you can concatenate image_features with input embeddings if you want multimodal
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "image_features": image_features,  # optional
            "labels": labels,
        }

# ------------------------------
# 3️⃣ Seed for Reproducibility
# ------------------------------
def set_seed(seed=42):
    """Seed the random number generators used by this script.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (including the CUDA generators when CUDA is available).

    Args:
        seed: Integer seed applied to every generator. Defaults to 42.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Ask PyTorch directly instead of comparing the DEVICE string: a value
    # such as "cuda:0" would otherwise silently skip the CUDA seeding.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed()

# ------------------------------
# 4️⃣ Load Tokenizer and Models
# ------------------------------
# T5 tokenizer and model, both loaded from the MODEL_NAME checkpoint; the
# model is then moved to the training device.
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
t5_model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
t5_model = t5_model.to(DEVICE)

# BiomedCLIP processor and model from the BIOMEDCLIP_MODEL_NAME checkpoint.
# NOTE(review): clip_model is not referenced anywhere else in the code shown
# here; only clip_processor is handed to the dataset.
clip_processor = BioMedCLIPProcessor.from_pretrained(BIOMEDCLIP_MODEL_NAME)
clip_model = BioMedCLIPModel.from_pretrained(BIOMEDCLIP_MODEL_NAME)
clip_model = clip_model.to(DEVICE)

# ------------------------------
# 5️⃣ Load Dataset
# ------------------------------
# Training records come from data/train.json; batches are reshuffled on every
# pass over the loader.
train_dataset = CXRDataset(
    data_path="data/train.json",
    tokenizer=tokenizer,
    clip_processor=clip_processor,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# ------------------------------
# 6️⃣ Optimizer & Scheduler
# ------------------------------
# Use PyTorch's AdamW: the AdamW class shipped with transformers is deprecated
# in favour of torch.optim.AdamW. eps and weight_decay are pinned to the
# values transformers.AdamW used by default (1e-6 and 0.0), because the
# PyTorch defaults differ (1e-8 and 0.01).
optimizer = torch.optim.AdamW(
    t5_model.parameters(),
    lr=LR,
    eps=1e-6,
    weight_decay=0.0,
)
# One scheduler step is taken per batch, so the schedule spans every batch of
# every epoch; with no warmup the learning rate decays linearly from LR.
total_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

# ------------------------------
# 7️⃣ Training Loop
# ------------------------------
# Switch the model to training mode once, before the first epoch.
t5_model.train()
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    loop = tqdm(train_loader, leave=True)
    for batch in loop:
        # Move the three tensors the model consumes onto the training device.
        input_ids, attention_mask, labels = (
            batch[key].to(DEVICE)
            for key in ("input_ids", "attention_mask", "labels")
        )
        # Optional: combine T5 embeddings with image_features if multimodal

        # Clear the gradients left over from the previous step.
        optimizer.zero_grad()

        # Passing labels makes the model return the training loss directly.
        outputs = t5_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss

        # Backpropagate, update the weights, then advance the LR schedule.
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Show the latest batch loss on the progress bar.
        loop.set_description(f"Loss {loss.item():.4f}")

# ------------------------------
# 8️⃣ Save Fine-Tuned Model
# ------------------------------
output_dir = "outputs/finetuned_model"
os.makedirs(output_dir, exist_ok=True)
# The model and the tokenizer are written to the same directory, model first.
for artifact in (t5_model, tokenizer):
    artifact.save_pretrained(output_dir)

print("Fine-tuning completed and model saved!")
