In [None]:
from transformers import CLIPProcessor
from PIL import Image
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from pprint import pprint
import tqdm
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Load the MIMIC-CXR chest X-ray / radiology report dataset from the Hugging Face Hub
# (reused from the local cache on later runs).
DATASET_ID = "itsanmolgupta/mimic-cxr-dataset"
print("Loading the dataset...")
dataset = load_dataset(DATASET_ID)
# `dataset` is a DatasetDict keyed by split name; the next cell lists which splits actually
# exist. A "validation" split is carved out of "train" further down in this notebook.

In [None]:
# Report which splits the dataset ships with and how many rows each one holds.
split_names = list(dataset.keys())
print("Available splits:", split_names)
print("Number of samples in each split:")
for split_name in split_names:
    print(f"{split_name}: {len(dataset[split_name])}")



In [None]:
# Pull a single training record and pretty-print every field it carries.
SAMPLE_INDEX = 1
sample = dataset["train"][SAMPLE_INDEX]
print("\n--- Sample Data ---")
pprint(sample)


In [None]:
# `sample.__sizeof__()` only measures the dict container itself, not the PIL image or the
# report strings it references, so it badly under-reports how big a sample is.
# Show the container overhead next to the payload sizes instead.
import sys

{
    "dict_container_bytes": sys.getsizeof(sample),
    "image_pixel_bytes": len(sample["image"].tobytes()),
    "text_field_chars": {key: len(value) for key, value in sample.items() if isinstance(value, str)},
}

In [None]:
# Inspect the PIL image attached to the sample: its (width, height) and colour mode.
sample_image = sample["image"]
for label, value in (("Image size:", sample_image.size), ("Image mode:", sample_image.mode)):
    print(label, value)

In [None]:
import matplotlib.pyplot as plt

def display_sample(sample_data, cmap="gray", figsize=(4, 6)):
    """Show one chest X-ray above its findings and impression text.

    Args:
        sample_data: mapping with an ``"image"`` entry (anything ``imshow`` accepts, such as
            the PIL image the dataset provides) plus ``"findings"`` and ``"impression"`` text.
        cmap: colormap used for single-channel images; matplotlib ignores it for RGB images.
        figsize: figure size in inches.
    """
    image = sample_data["image"]
    findings = sample_data["findings"]
    impression = sample_data["impression"]

    fig, (image_ax, text_ax) = plt.subplots(2, 1, figsize=figsize)

    # Without an explicit cmap, matplotlib draws single-channel images with its default
    # colormap, rendering grayscale X-rays in false colour.
    image_ax.imshow(image, cmap=cmap)
    image_ax.axis("off")
    image_ax.set_title("Chest X-ray")

    text_ax.axis("off")
    text_ax.text(0, 0.8, f"Findings: {findings}", wrap=True, fontsize=12)
    text_ax.text(0, 0.2, f"Impression: {impression}", wrap=True, fontsize=12, fontweight="bold")

    fig.tight_layout()
    plt.show()

# Preview the first few training examples. A dedicated loop variable keeps the `sample`
# inspected in the cells above from being silently rebound. `islice` stops after the last
# preview instead of fetching and decoding one extra row just to break out of the loop.
import itertools

N_PREVIEW = 5
for preview_row in itertools.islice(dataset["train"], N_PREVIEW):
    display_sample(preview_row)

In [None]:
import random

# Work on a small random fraction of the training split so experiments finish quickly.
fraction = 0.001
total_examples = len(dataset["train"])
print("total_examples:", total_examples)
subset_size = int(total_examples * fraction)
print("subset_size:", subset_size)
# The next cell splits the subset 80/20 into train/validation. With fewer than 2 rows one
# side is empty and training/evaluation would silently process nothing.
assert subset_size >= 2, (
    f"fraction={fraction} keeps only {subset_size} of {total_examples} examples; increase it"
)

In [None]:
# Keep a handle on the untouched training split the first time this cell runs. The cell
# replaces dataset["train"] with the subset, so without this guard a re-run would draw
# indices over the full range but select them from the much smaller subset and raise an
# out-of-range IndexError.
if "full_train_split" not in globals():
    full_train_split = dataset["train"]

SEED = 42
TRAIN_FRACTION = 0.8
# A private, seeded RNG makes the split reproducible and independent of other `random` calls.
rng = random.Random(SEED)

# random.sample already returns its picks in random order, so no extra shuffle is needed.
sampled_indices = rng.sample(range(len(full_train_split)), subset_size)

# Split the sampled subset into train and validation portions.
train_cutoff = int(subset_size * TRAIN_FRACTION)
train_indices = sampled_indices[:train_cutoff]
val_indices = sampled_indices[train_cutoff:]

# Always select from the full split (see the guard above) so this cell is idempotent.
dataset["validation"] = full_train_split.select(val_indices)
dataset["train"] = full_train_split.select(train_indices)


In [None]:
# Confirm how many rows ended up in each of the freshly sampled splits.
for split_name in ("train", "validation"):
    print(f"Number of samples in the new {split_name} set:", len(dataset[split_name]))

In [None]:
# Base class for the MIMIC-CLIP dataset wrapper defined below.
from torch.utils.data import Dataset

# Tokenizer + image preprocessor; must match the CLIP checkpoint that gets fine-tuned.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [None]:
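# Earlier draft, superseded by the MIMICCLIPDataset in the next cell. It could not be batched
# by the default DataLoader collate: padding=True pads each caption only to its own length,
# so input_ids of different lengths cannot be stacked, and __getitem__ returned a tuple whose
# second element is None, which default_collate rejects.
# class MIMICCLIPDataset(Dataset):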
#     def __init__(self, hf_dataset, processor):
#         self.dataset = hf_dataset
#         self.processor = processor

#     def __getitem__(self, idx):
#         item = self.dataset[idx]
#         image = item["image"]  # already a PIL.Image from dataset
#         text = item.get("impression") or item.get("findings")

#         inputs = self.processor(
#             text=text,
#             images=image,
#             return_tensors="pt",
#             padding=True
#         )
#         return {k: v.squeeze(0) for k, v in inputs.items()}, None

#     def __len__(self):
#         return len(self.dataset)

In [None]:
class MIMICCLIPDataset(Dataset):
    """Map-style torch Dataset turning MIMIC-CXR rows into CLIP model inputs.

    Each item is the dict of tensors produced by the CLIP processor with its leading batch
    dimension removed. Captions are padded to a fixed length, so the default DataLoader
    collate can stack items into a batch that is passed straight to
    ``model(**batch, return_loss=True)``.

    Args:
        hf_dataset: indexable dataset whose rows are dicts with an ``"image"`` entry and
            ``"impression"`` / ``"findings"`` report text.
        processor: CLIP processor (called with ``text=`` and ``images=``).
        max_length: token length every caption is padded/truncated to (77 is CLIP's
            default context length).
    """

    def __init__(self, hf_dataset, processor, max_length=77):
        self.dataset = hf_dataset
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        # Prefer the impression, fall back to the findings, and finally to an empty caption
        # so a row missing both does not hand text=None to the tokenizer.
        text = item.get("impression") or item.get("findings") or ""

        # Fixed-length padding gives every item identical tensor shapes, which the default
        # collate function requires in order to stack them into a batch.
        inputs = self.processor(
            text=text,
            images=item["image"],
            padding="max_length",
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        )

        # The processor returns a batch of one; drop that leading dimension from each tensor.
        return {key: value.squeeze(0) for key, value in inputs.items()}


# Example usage (the training cell below does exactly this):
#
#     train_ds = MIMICCLIPDataset(dataset["train"], processor)
#     train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
#     for batch in train_loader:
#         # batch is a dict with 'input_ids', 'attention_mask' and 'pixel_values'
#         batch = {k: v.to(device) for k, v in batch.items()}
#         # return_loss=True is needed for CLIPModel to populate outputs.loss
#         loss = model(**batch, return_loss=True).loss
#         ...

In [None]:
# models/clip/train.py

import torch
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import DataLoader

import time
from datetime import timedelta

def do_train(model, train_dl, optimizer, lr_scheduler, device, log_every=100):
    """Run one training epoch over ``train_dl``.

    Args:
        model: module called as ``model(**batch, return_loss=True)``; its output must expose
            a scalar ``.loss`` tensor.
        train_dl: sized iterable of dict batches (e.g. a DataLoader).
        optimizer: stepped once per batch.
        lr_scheduler: stepped once per batch, right after the optimizer.
        device: device every tensor in a batch is moved to.
        log_every: print progress and an ETA every this many batches.

    Returns:
        ``(train_loss, total_time)``: the loss summed over all batches (divide by
        ``len(train_dl)`` for the per-batch mean) and the epoch's wall-clock seconds.
        An empty ``train_dl`` yields ``(0.0, elapsed)`` instead of raising.
    """
    model.train()
    train_loss = 0.0
    num_batches = len(train_dl)
    steps_done = 0
    start_time = time.time()

    for bid, batch in enumerate(train_dl):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch, return_loss=True)
        loss = outputs.loss

        # Clear stale gradients before backward so nothing left over from earlier code
        # leaks into the first update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        train_loss += loss.item()
        steps_done = bid + 1

        if steps_done % log_every == 0:
            elapsed = time.time() - start_time
            # Extrapolate from the running average rather than a single (noisy) batch.
            eta = elapsed / steps_done * (num_batches - steps_done)
            print(f"...{steps_done:d} training steps complete | ETA: {timedelta(seconds=int(eta))}")

    total_time = time.time() - start_time
    print(f"...{steps_done} training steps COMPLETE in {timedelta(seconds=int(total_time))}")
    if steps_done:
        print(f"Average time per batch: {total_time / steps_done:.2f}s")

    return train_loss, total_time

def do_eval(model, eval_dl, device, verbose=False):
    """Evaluate ``model`` on ``eval_dl`` without updating its weights.

    Accuracy is in-batch image-to-text retrieval. An image counts as correct when its
    highest-scoring caption (argmax of ``logits_per_image``) is the caption at the same
    position in the batch. The score therefore depends on batch size: larger batches are a
    harder task.

    Args:
        model: module called as ``model(**batch, return_loss=True)``; its output must expose
            ``.loss`` and ``.logits_per_image``.
        eval_dl: sized iterable of dict batches (e.g. a DataLoader).
        device: device every tensor in a batch is moved to.
        verbose: if True, print a line for every batch (useful when debugging).

    Returns:
        ``(val_loss, val_acc, total_time)``: loss summed over batches, fraction of correctly
        matched images (0.0 when nothing was evaluated), and wall-clock seconds.
    """
    model.eval()
    val_loss, num_correct, num_examples = 0.0, 0, 0
    num_batches = len(eval_dl)
    start_time = time.time()

    print(f"Validation dataloader contains {num_batches} batches")

    with torch.no_grad():
        for bid, batch in enumerate(eval_dl):
            if verbose:
                print(f"Validating batch {bid+1}/{num_batches}")

            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch, return_loss=True)
            val_loss += outputs.loss.item()

            # softmax is monotonic, so argmax over the raw logits gives the same prediction.
            predictions = outputs.logits_per_image.argmax(dim=-1)
            labels = torch.arange(predictions.shape[0], device=predictions.device)
            num_correct += int((predictions == labels).sum().item())
            num_examples += predictions.shape[0]

    total_time = time.time() - start_time

    # Avoid division by zero if no examples were processed.
    if num_examples > 0:
        val_acc = num_correct / num_examples
    else:
        val_acc = 0.0
        print("WARNING: No examples were processed during validation!")

    print(f"Validation complete: Processed {num_examples} examples in {num_batches} batches")

    return val_loss, val_acc, total_time


In [None]:
# Set up data loaders, model, optimizer and LR schedule, then fine-tune CLIP.
import math
from datetime import timedelta

import torch
from torch.utils.data import DataLoader
from transformers import CLIPModel, get_linear_schedule_with_warmup

EPOCHS = 3
BATCH_SIZE = 16
LEARNING_RATE = 5e-5
SEED = 42
# Must match the checkpoint the `processor` was loaded from.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
# Worker processes have to unpickle MIMICCLIPDataset. Under the "spawn" start method (the
# default on macOS and Windows), a class defined inside a notebook typically cannot be
# found by the workers; set this to 0 if the DataLoader fails to start there.
NUM_WORKERS = 4

torch.manual_seed(SEED)  # reproducible shuffle order and any random initialisation

# Setup dataloaders
train_ds = MIMICCLIPDataset(dataset["train"], processor)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_ds = MIMICCLIPDataset(dataset["validation"], processor)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
# Check validation dataset size
print(f"Validation dataset size: {len(val_ds)} examples")
print(f"Validation batch size: {val_loader.batch_size}")
print(f"Expected number of batches: {math.ceil(len(val_ds) / val_loader.batch_size)}")

# Setup model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
model.to(device)

# do_train steps the scheduler once per batch, so the schedule spans every batch of
# every epoch.
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
num_training_steps = len(train_loader) * EPOCHS
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

# do_train / do_eval return losses summed over batches. Report per-batch means so the train
# and validation numbers are comparable despite the different split sizes.
num_train_batches = max(len(train_loader), 1)
num_val_batches = max(len(val_loader), 1)

# Training loop
for epoch in range(EPOCHS):
    print(f"\n🚀 Epoch {epoch+1}/{EPOCHS}")
    train_loss, train_time = do_train(model, train_loader, optimizer, lr_scheduler, device)
    val_loss, val_acc, val_time = do_eval(model, val_loader, device)
    print(f"Epoch {epoch+1} Summary:")
    print(f"  Train Loss: {train_loss / num_train_batches:.4f} | Train Time: {str(timedelta(seconds=int(train_time)))}")
    print(f"  Val Loss: {val_loss / num_val_batches:.4f} | Val Acc: {val_acc:.4f} | Val Time: {str(timedelta(seconds=int(val_time)))}")