# Image Baseline

In [None]:
import pandas as pd

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
import torchvision.models as models

from PIL import Image

from src.data import load_omnimed_dataset

## Load the Dataset

### Load the Base Dataset

In [None]:
# Load the three splits returned by load_omnimed_dataset().
train_df, val_df, test_df = load_omnimed_dataset()

print("Train size:", len(train_df))
print("Validation size:", len(val_df))
print("Test size:", len(test_df))

# Check for image overlap between every pair of splits. A non-zero count
# means the same image path appears in both splits. Each path set is built
# once and reused for all comparisons.
train_paths = set(train_df['image_path'])
val_paths = set(val_df['image_path'])
test_paths = set(test_df['image_path'])

print("Overlap train-test:", len(train_paths & test_paths))
print("Overlap train-val:", len(train_paths & val_paths))
print("Overlap val-test:", len(val_paths & test_paths))


### Create Image-Only Dataframes

In [None]:
def create_image_df(df):
    """Build the image-only view of a split: image path plus answer text.

    For every row, ``gt_label`` holds the *name of the column* that contains
    the correct option; the value found in that column becomes ``gt_text``.

    Args:
        df: DataFrame with at least ``image_path``, ``gt_label`` and one
            column for every value occurring in ``gt_label``.

    Returns:
        A new, independent DataFrame with exactly the columns
        ``image_path`` and ``gt_text``. The input frame is not modified.

    Raises:
        KeyError: If a required column is missing or a ``gt_label`` value
            does not name an existing column.
    """
    df = df.copy()

    if len(df) == 0:
        # apply(axis=1) on a zero-row frame hands back a DataFrame instead of
        # a Series, which cannot be assigned to a single column.
        df['gt_text'] = pd.Series(index=df.index, dtype=object)
    else:
        # Extract the text of the correct option using gt_label
        df['gt_text'] = df.apply(lambda row: row[row['gt_label']], axis=1)

    # Keep only image_path and gt_text for the image-only baseline. The
    # explicit copy lets callers add columns to the result without pandas
    # treating it as a slice of another frame (SettingWithCopyWarning).
    return df[['image_path', 'gt_text']].copy()

# Derive the image-only frame for each split, in train / val / test order.
train_img_df, val_img_df, test_img_df = (
    create_image_df(split_df) for split_df in (train_df, val_df, test_df)
)

### Build the Label Space

In [None]:
# Collect every distinct answer text across the three splits and sort it so
# the text -> index assignment is deterministic.
combined_img_df = pd.concat([train_img_df, val_img_df, test_img_df])
all_labels = sorted(combined_img_df['gt_text'].unique())

# Forward (text -> index) and reverse (index -> text) lookups
label_to_idx = dict(zip(all_labels, range(len(all_labels))))
idx_to_label = dict(enumerate(all_labels))

# Attach the integer class index to each split's frame
for df in (train_img_df, val_img_df, test_img_df):
    df['label_idx'] = df['gt_text'].map(label_to_idx)

# Report the size of the label space
print("Number of unique labels:", len(all_labels))

## Image Dataset Setup

### Define Image Transforms

In [None]:
# Settings shared by both pipelines
image_size = (224, 224)
imagenet_mean = [0.485, 0.456, 0.406]  # Standard ImageNet normalization
imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: resize, augment (random flip and slight rotation),
# then convert to a tensor and normalize.
train_steps = [
    transforms.Resize(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
train_image_transform = transforms.Compose(train_steps)

# Validation / test pipeline: same steps without the augmentation.
eval_steps = [
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
val_image_transform = transforms.Compose(eval_steps)

### Create Image Dataset

In [None]:
class OmniMedImageDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from a dataframe.

    Each row of ``df`` supplies the image file to open (``image_path``) and
    the label to return with it (``label_idx``).
    """

    def __init__(self, df, transform=None):
        """Store the dataframe and the optional per-image transform.

        Args:
            df: DataFrame with ``image_path`` and ``label_idx`` columns.
            transform: Optional callable applied to the RGB PIL image
                before it is returned.
        """
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the dataframe."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the sample at positional index ``idx``.

        Args:
            idx: Positional row index into ``df``.

        Returns:
            Tuple ``(image, label)``: the RGB image (after ``transform`` when
            one was given) and the row's ``label_idx`` value, unchanged.
        """
        # Materialise the row once instead of once per column.
        row = self.df.iloc[idx]
        image_path = row['image_path']
        label = row['label_idx']

        # Close the underlying file deterministically rather than relying on
        # garbage collection; convert() returns an independent, loaded image.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        # Compare against None so a transform object that happens to be
        # falsy (e.g. an empty container) is still applied.
        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [None]:
# Training uses the augmenting pipeline; validation and test share the
# augmentation-free one.
train_dataset = OmniMedImageDataset(
    df=train_img_df,
    transform=train_image_transform,
)
val_dataset = OmniMedImageDataset(
    df=val_img_df,
    transform=val_image_transform,
)
test_dataset = OmniMedImageDataset(
    df=test_img_df,
    transform=val_image_transform,
)

### Create Data Loaders

In [None]:
# Loader settings shared by every split
# TODO: Add to config.py as constant with optional override via environment variable
batch_size = 32
num_workers = 4
shared_loader_args = {"batch_size": batch_size, "num_workers": num_workers}

# Only the training split is shuffled; validation and test keep dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, **shared_loader_args)
val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_args)
test_loader = DataLoader(test_dataset, shuffle=False, **shared_loader_args)

# Quick check: number of batches per split
print("Train batches:", len(train_loader))
print("Validation batches:", len(val_loader))
print("Test batches:", len(test_loader))

## Model Setup