# Breast Cancer Microcalcification Classification - Training Notebook

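This notebook demonstrates how to train the microcalcification classifier using the shared `core` library (configuration, preprocessing, and model definition).
It is designed to run in Google Colab or locally from the `notebooks/` directory.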

In [None]:
# Setup for Colab (Uncomment if running in Colab)
# from google.colab import drive
# drive.mount('/content/drive')
# %cd /content/drive/MyDrive/path/to/project_root 

In [None]:
import sys
import os

# Make the project root importable so `import core` resolves to this repo's package.
# Locally the kernel usually starts in `notebooks/` (root is '..'); in Colab the
# setup cell `%cd`s straight into the project root (root is '.'). Prefer whichever
# candidate actually contains a `core/` directory, falling back to '..' as before.
PROJECT_ROOT = next(
    (
        candidate
        for candidate in (os.path.abspath('..'), os.path.abspath('.'))
        if os.path.isdir(os.path.join(candidate, 'core'))
    ),
    os.path.abspath('..'),
)

# Guard against duplicate entries when this cell is re-run, and insert at the
# front so an unrelated installed package named `core` cannot shadow ours.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from core.config import Config
from core.preprocessing import apply_clahe, get_transforms
from core.model import BreastCancerModel

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import pandas as pd
import numpy as np

## 1. Configuration
Check the configuration settings loaded from `core.config`.

In [None]:
# Echo the key settings from core.config, one per line, so the run's setup
# is visible in the saved output.
print(
    f"Model Backbone: {Config.BACKBONE}",
    f"Image Size: {Config.IMAGE_SIZE}",
    f"Batch Size: {Config.BATCH_SIZE}",
    sep="\n",
)

## 2. Dataset Definition
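Each image is passed through `apply_clahe` (CLAHE) from `core.preprocessing` before the transforms returned by `get_transforms` are applied.

**Note:** this dataset is currently a placeholder. It ignores the CSV path and returns random dummy images, so replace it with real data loading before training.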

In [None]:
class MammogramDataset(Dataset):
    """Placeholder mammogram dataset yielding CLAHE-processed images and labels.

    NOTE: real data loading is not wired up yet. ``csv_file`` is stored for the
    future CSV-backed implementation but is otherwise ignored; the dataset holds
    a single dummy row (label 0) and synthesizes a random grayscale image for
    every index.

    Args:
        csv_file: Path to the annotations CSV (currently unused).
        transform: Optional callable applied to the CLAHE-processed PIL image.
        image_size: Side length in pixels of the square dummy image. Defaults
            to 224, the size previously hard-coded here.
    """

    def __init__(self, csv_file, transform=None, image_size=224):
        # TODO: replace the dummy frame with `pd.read_csv(csv_file)` once the real
        # CSV is available; the dummy frame mirrors its `path` / `label` columns.
        self.csv_file = csv_file
        self.df = pd.DataFrame({'path': ['dummy.jpg'], 'label': [0]})
        self.transform = transform
        self.image_size = image_size

    def __len__(self):
        """Number of rows in the (currently dummy) annotation frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        ``image`` is whatever ``self.transform`` returns, or a PIL image when no
        transform is set. DataLoader's default collate cannot batch PIL images,
        so a transform that produces tensors is required for batched loading.
        """
        # Look the label up by column name rather than by position; this also
        # raises IndexError for out-of-range indices before any work is done.
        label = self.df['label'].iloc[idx]

        # TODO: load the real image, e.g. Image.open(self.df['path'].iloc[idx]).
        # Dummy image: seed by the normalized row index so every epoch and every
        # re-run sees identical pixels for the same sample. `integers` excludes
        # its upper bound, so 256 covers the full uint8 range 0..255.
        rng = np.random.default_rng(int(idx) % len(self.df))
        image_np = rng.integers(
            0, 256, (self.image_size, self.image_size), dtype=np.uint8
        )

        # Apply CLAHE (shared preprocessing logic from core.preprocessing).
        image = Image.fromarray(apply_clahe(image_np))

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Build the training split (with training-time transforms) and wrap it in a
# shuffling loader sized by the configured batch size.
train_dataset = MammogramDataset(
    csv_file=Config.TRAIN_CSV,
    transform=get_transforms('train'),
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=True,
)

## 3. Model Initialization
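The classifier is `BreastCancerModel` from `core.model`. It is built with the configured backbone and number of classes, then moved to the GPU when one is available.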

In [None]:
# Prefer the GPU when present; build the model from the shared config and move
# it onto the chosen device in one chained expression.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = BreastCancerModel(
    backbone_name=Config.BACKBONE,
    num_classes=Config.NUM_CLASSES,
).to(device)
print(model)

## 4. Training Loop (Skeleton)

In [None]:
# Set to True once MammogramDataset loads real data. The placeholder dataset
# only yields random dummy images, so training on it would be meaningless.
RUN_TRAINING = False

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=Config.LEARNING_RATE)

if RUN_TRAINING:
    print("Starting training loop...")
    # Put the model in training mode (affects layers such as dropout /
    # batch-norm, if the backbone has any).
    model.train()
    for epoch in range(Config.NUM_EPOCHS):
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch mean is exact even when the
            # last batch is smaller.
            running_loss += loss.item() * labels.size(0)
        # max(..., 1) avoids ZeroDivisionError on an empty dataset.
        epoch_loss = running_loss / max(len(train_loader.dataset), 1)
        print(f"Epoch {epoch + 1}/{Config.NUM_EPOCHS} - loss: {epoch_loss:.4f}")
else:
    print("Training loop placeholder.")