# Image Baseline

In [None]:
import pandas as pd

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
import torch.nn as nn
import torchvision.models as models
from tqdm.notebook import tqdm # Progress bars

from PIL import Image

from src.data import load_omnimed_dataset

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Load the Dataset

### Load the Base Dataset

In [None]:
# Load the three splits returned by the project loader.
train_df, val_df, test_df = load_omnimed_dataset()

# Report the number of rows in each split.
for caption, split_df in (
    ("Train size:", train_df),
    ("Validation size:", val_df),
    ("Test size:", test_df),
):
    print(caption, len(split_df))

# Check for image overlap: count image paths that the training split shares
# with the test split and with the validation split.
train_paths = set(train_df['image_path'])
test_paths = set(test_df['image_path'])
val_paths = set(val_df['image_path'])
print("Overlap train-test:", len(train_paths & test_paths))
print("Overlap train-val:", len(train_paths & val_paths))


### Create Image-Only Dataframes

In [None]:
def create_image_df(df):
    """Build the image-only view of a split: image path plus answer text.

    For every row, the value stored in ``gt_label`` is used as a column name,
    and the value found in that column of the same row becomes ``gt_text``.

    Args:
        df: DataFrame with an ``image_path`` column, a ``gt_label`` column and
            one column for every name that occurs in ``gt_label``.

    Returns:
        A new, independent DataFrame with exactly the columns ``image_path``
        and ``gt_text``, indexed like ``df``. The input frame is not modified.

    Raises:
        KeyError: If ``image_path``, ``gt_label`` or a column named by
            ``gt_label`` is missing.
    """
    # Extract the text of the correct option using gt_label. Building a plain
    # list also works for a frame with zero rows, where a row-wise
    # ``DataFrame.apply`` does not produce a Series that can be assigned.
    gt_text = [row[row['gt_label']] for _, row in df.iterrows()]

    # Keep only image_path and gt_text for the image-only baseline. The
    # explicit copy makes the result a standalone frame, so callers can add
    # columns to it later without pandas flagging a write to a slice.
    image_df = df[['image_path']].copy()
    image_df['gt_text'] = gt_text
    return image_df

# Derive the image-only frame for each split, in train / validation / test order.
train_img_df, val_img_df, test_img_df = (
    create_image_df(split_df) for split_df in (train_df, val_df, test_df)
)

### Build the Label Space

In [None]:
# Collect every distinct answer text across all three splits so that train,
# validation and test share a single label -> index mapping.
all_labels = sorted(
    pd.concat([train_img_df, val_img_df, test_img_df])['gt_text'].unique()
)
label_to_idx = {label: i for i, label in enumerate(all_labels)}
idx_to_label = {i: label for label, i in label_to_idx.items()}

# Map text to integer labels. ``assign`` returns a fresh frame for each split
# instead of adding a column in place to a frame that came out of a column
# selection, which pandas may report as a SettingWithCopyWarning.
train_img_df = train_img_df.assign(
    label_idx=train_img_df['gt_text'].map(label_to_idx)
)
val_img_df = val_img_df.assign(
    label_idx=val_img_df['gt_text'].map(label_to_idx)
)
test_img_df = test_img_df.assign(
    label_idx=test_img_df['gt_text'].map(label_to_idx)
)

# Verify label distribution
print("Number of unique labels:", len(all_labels))

## Image Dataset Setup

### Define Image Transforms

In [None]:
# Per-channel ImageNet statistics used to normalise inputs for the
# pre-trained ResNet-18 weights.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# NOTE: the weights' bundled ``transforms()`` preset is a complete inference
# pipeline (resize to 256, centre-crop to 224, tensor conversion, normalise).
# Chaining it after ``Resize((224, 224))`` would upscale the image again and
# crop away its border, so only the tensor conversion and normalisation steps
# are applied explicitly here.

# Training transforms (includes augmentation)
train_image_transform = transforms.Compose([
    transforms.Resize((224, 224)),          # Resize image
    transforms.RandomHorizontalFlip(),      # Augment: flip
    transforms.RandomRotation(10),          # Augment: slight rotation
    transforms.ToTensor(),                  # PIL image -> float tensor in [0, 1]
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# Validation / Test transforms (no augmentation)
val_image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

### Create Image Dataset

In [None]:
class OmniMedImageDataset(Dataset):
    """Map-style dataset that yields ``(image, label)`` pairs from a DataFrame.

    Each item is read from the row at the given position: the file named in
    the ``image_path`` column is opened and converted to RGB, and the value in
    the ``label_idx`` column is returned as the label.

    Args:
        df: DataFrame with ``image_path`` and ``label_idx`` columns.
        transform: Optional callable applied to the RGB image before it is
            returned. When ``None`` the PIL image is returned unchanged.
    """

    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load and return the ``(image, label)`` pair at position ``idx``."""
        # Read the two needed cells directly by position rather than building
        # a full row Series twice for every sample.
        image_path = self.df['image_path'].iloc[idx]
        label = self.df['label_idx'].iloc[idx]

        # The context manager closes the underlying file as soon as the pixel
        # data has been converted, so handles do not pile up during an epoch.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # Compare against None explicitly: a transform object that defines
        # ``__len__`` could otherwise be skipped for being "empty".
        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [None]:
# Wrap each split in a dataset. The training split uses the training
# transform; validation and test share the validation transform.
train_dataset = OmniMedImageDataset(
    df=train_img_df,
    transform=train_image_transform,
)
val_dataset = OmniMedImageDataset(
    df=val_img_df,
    transform=val_image_transform,
)
test_dataset = OmniMedImageDataset(
    df=test_img_df,
    transform=val_image_transform,
)

### Create Data Loaders

In [None]:
# Define batch_size
# TODO: Add to config.py as constant with optional override
batch_size = 32
num_workers = 0

# Settings shared by all three loaders.
common_loader_args = {"batch_size": batch_size, "num_workers": num_workers}

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, **common_loader_args)
val_loader = DataLoader(val_dataset, shuffle=False, **common_loader_args)
test_loader = DataLoader(test_dataset, shuffle=False, **common_loader_args)

# Quick check: number of batches each loader will produce.
for caption, loader in (
    ("Train batches:", train_loader),
    ("Validation batches:", val_loader),
    ("Test batches:", test_loader),
):
    print(caption, len(loader))

In [None]:
# Peek at (at most) the first three training batches and print their shapes.
# Pairing with ``range(3)`` ends the loop after the third batch without
# requesting a fourth one from the loader.
for i, (images, labels) in zip(range(3), train_loader):
    print(i, images.shape, labels.shape)

## Model Setup

### Define the Model

In [None]:
# Model: pre-trained ResNet-18 whose final fully connected layer is swapped
# for a new one with one output per entry in the label mapping.
num_classes = len(label_to_idx)
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
head_in_features = model.fc.in_features
model.fc = nn.Linear(in_features=head_in_features, out_features=num_classes)
model = model.to(device)

# Loss
criterion = nn.CrossEntropyLoss()

# Optimizer over all model parameters
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

# Scheduler: halve the learning rate after every 5 scheduler steps
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=5,
    gamma=0.5,
)

### Training Loop

In [None]:
# TODO: Add to config.py as constant with optional override
num_epochs = 1

# Best validation accuracy seen so far and where its weights are stored.
best_model_path = "resnet18_best.pth"
best_val_acc = 0.0

for epoch in range(num_epochs):
    print(f"\n=== Epoch {epoch+1}/{num_epochs} ===")

    # ---- Training pass: forward, backward and optimizer update per batch ----
    model.train()
    train_loss, correct_train, total_train = 0, 0, 0

    for batch_images, batch_labels in tqdm(train_loader, desc="Training"):
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Scale the batch loss by the batch size so that dividing by the
        # sample count below yields a per-sample figure for the epoch.
        train_loss += batch_loss.item() * batch_images.size(0)
        predicted = logits.argmax(dim=1)
        correct_train += (predicted == batch_labels).sum().item()
        total_train += batch_labels.size(0)

    train_loss = train_loss / total_train
    train_acc = correct_train / total_train
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")

    # ---- Validation pass: same metrics, no gradients, no weight updates ----
    model.eval()
    val_loss, correct_val, total_val = 0, 0, 0

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(val_loader, desc="Validation"):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            batch_loss = criterion(logits, batch_labels)

            val_loss += batch_loss.item() * batch_images.size(0)
            predicted = logits.argmax(dim=1)
            correct_val += (predicted == batch_labels).sum().item()
            total_val += batch_labels.size(0)

    val_loss = val_loss / total_val
    val_acc = correct_val / total_val
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")

    # Checkpoint only when validation accuracy strictly improves.
    improved = val_acc > best_val_acc
    if improved:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)
        print(f"Saved best model with val_acc={best_val_acc:.4f}")

    # Step the scheduler once per epoch, after the optimizer updates.
    scheduler.step()

### Testing Loop

In [None]:
# Evaluate the weights that scored best on validation rather than whatever the
# final epoch left in memory. The training loop only writes a checkpoint when
# validation accuracy rises above its starting value of 0.0, so a positive
# best_val_acc means a checkpoint from this run exists at best_model_path.
if best_val_acc > 0.0:
    best_state = torch.load(best_model_path, map_location=device)
    model.load_state_dict(best_state)

model.eval()
correct_test = 0
total_test = 0

# No gradients are needed for evaluation.
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing"):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest output along the class axis.
        _, preds = torch.max(outputs, 1)
        correct_test += (preds == labels).sum().item()
        total_test += labels.size(0)

test_acc = correct_test / total_test
print(f"\nTest Accuracy: {test_acc:.4f}")