# 🎓 Model Training Walkthrough

This notebook demonstrates how to train a custom image classifier using the labeled dataset generated by the Image Labeler tool. 

We will fine-tune a pre-trained **ResNet18** model using PyTorch.

### 🚀 Prerequisites
Ensure you have run the Image Labeler and generated a dataset in `data/processed` (or a similar path) containing `train/` and `test/` folders.

## 1. Environment Setup
First, we install the necessary dependencies (`torch`, `torchvision`, etc.).

In [None]:
!pip install -r requirements.txt

## 2. Imports
Import necessary libraries for data manipulation, file handling, and PyTorch operations.

In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
from tqdm import tqdm

## 3. Configuration
Set up your training parameters. Adjust `DATA_DIR` to point to your split dataset.

In [None]:
# Directory that contains the 'train' and 'test' folders of the split dataset
DATA_DIR = "../data/processed"

# Training hyperparameters
BATCH_SIZE, NUM_EPOCHS = 4, 5
LEARNING_RATE = 1e-3

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## 4. Custom Dataset Definition
We define a custom `Dataset` class that reads the `labels.json` file created by our tool. It maps images to their labels dynamically.

In [None]:
class LabelerDataset(Dataset):
    """Image dataset driven by the ``labels.json`` file inside ``data_dir``.

    ``labels.json`` must hold a list of entries. An entry is kept only when it
    is a dict with a non-empty ``"filename"`` that exists inside ``data_dir``
    and a ``"label"`` that is not ``None``; every other entry is skipped.

    Attributes:
        data: Raw content parsed from ``labels.json``.
        valid_data: Entries that passed the checks described above.
        labels: Label of every entry in ``valid_data``, in the same order.
        class_to_idx: Mapping from label to class index. ``None`` until
            ``set_class_map`` has been called.
    """

    def __init__(self, data_dir, transform=None):
        """Read ``labels.json`` and keep the usable entries.

        Args:
            data_dir: Directory containing ``labels.json`` and the image files.
            transform: Optional callable applied to each loaded RGB image.

        Raises:
            FileNotFoundError: If ``labels.json`` is missing from ``data_dir``.
            ValueError: If ``labels.json`` does not contain a list.
        """
        self.data_dir = data_dir
        self.transform = transform
        self.params_file = os.path.join(data_dir, "labels.json")
        # Defined up front so that indexing before set_class_map() produces a
        # clear error instead of an AttributeError.
        self.class_to_idx = None

        if not os.path.exists(self.params_file):
            raise FileNotFoundError(f"labels.json not found in {data_dir}")

        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(self.params_file, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        if not isinstance(self.data, list):
            raise ValueError(
                f"{self.params_file} must contain a list of entries, "
                f"got {type(self.data).__name__}"
            )

        self.valid_data = [item for item in self.data if self._is_usable(item)]
        self.labels = [item["label"] for item in self.valid_data]

    def _is_usable(self, item):
        """Return True when ``item`` has a label and names an existing file."""
        if not isinstance(item, dict):
            return False
        img_name = item.get("filename")
        if not img_name or item.get("label") is None:
            return False
        return os.path.exists(os.path.join(self.data_dir, img_name))

    def set_class_map(self, class_to_idx):
        """Install the label -> class index mapping used by ``__getitem__``."""
        self.class_to_idx = class_to_idx

    def __len__(self):
        """Return the number of usable entries."""
        return len(self.valid_data)

    def __getitem__(self, idx):
        """Return ``(image, label_idx)`` for the entry at position ``idx``.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If ``set_class_map`` has not been called yet.
            KeyError: If the entry's label is absent from the class map.
        """
        item = self.valid_data[idx]

        if self.class_to_idx is None:
            raise RuntimeError(
                "Class map is not set; call set_class_map() before indexing the dataset"
            )

        # Resolve the label first so an unknown label fails before any image
        # decoding work is done.
        label = item["label"]
        try:
            label_idx = self.class_to_idx[label]
        except KeyError:
            raise KeyError(
                f"Label {label!r} of {item['filename']!r} is not in the class map"
            ) from None

        img_path = os.path.join(self.data_dir, item["filename"])
        # convert() returns a new image, so the file can be closed right away.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, label_idx

## 5. Data Preparation
Define image transformations (resizing, normalization) and create DataLoaders.

In [None]:
# Data Augmentation and Normalization
# 'train' uses a random crop and horizontal flip for augmentation; 'test' uses
# a deterministic resize + center crop. Both end with the same normalization.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Initialize Datasets
image_datasets = {}
for phase in ['train', 'test']:
    dir_path = os.path.join(DATA_DIR, phase)
    if not os.path.exists(dir_path):
        print(f"Warning: {phase} directory not found at {dir_path}")
        continue

    phase_dataset = LabelerDataset(dir_path, transform=data_transforms[phase])
    if len(phase_dataset) == 0:
        # An empty split cannot be iterated meaningfully; leave it out so the
        # phase is simply skipped later on.
        print(f"Warning: no usable labeled images found in {dir_path}")
        continue
    image_datasets[phase] = phase_dataset

# Everything below (class map, model head, training) needs training data, so
# stop here with a clear message instead of failing later on undefined names.
if 'train' not in image_datasets:
    raise RuntimeError("Error: No training data found. Cannot proceed.")

# Auto-detect classes from Training set
all_labels = sorted(set(image_datasets['train'].labels))
class_to_idx = {label: idx for idx, label in enumerate(all_labels)}
print(f"Found {len(all_labels)} classes: {all_labels}")

# Apply class map to all datasets
for phase in image_datasets:
    # A label that never occurs in the training set has no class index; report
    # it now rather than crashing in the middle of an epoch.
    unseen_labels = sorted(set(image_datasets[phase].labels) - set(all_labels))
    if unseen_labels:
        raise ValueError(
            f"{phase} set contains labels missing from the training set: {unseen_labels}"
        )
    image_datasets[phase].set_class_map(class_to_idx)

# Create DataLoaders (only the training split needs to be shuffled)
dataloaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=(x == 'train'))
               for x in image_datasets}
dataset_sizes = {x: len(image_datasets[x]) for x in image_datasets}

## 6. Model Setup
We use **ResNet18**, a lightweight CNN. We replace the final fully connected layer (`fc`) to output the number of classes present in our dataset.

In [None]:
# Load Pretrained ResNet18
try:
    from torchvision.models import ResNet18_Weights
except ImportError:
    # Older torchvision releases have neither the weights enums nor the
    # `weights=` argument; `pretrained=True` is their way to request the
    # ImageNet weights. Passing weights=None here would give an untrained net.
    weights = None
    model = models.resnet18(pretrained=True)
else:
    weights = ResNet18_Weights.IMAGENET1K_V1
    model = models.resnet18(weights=weights)

# Modify the last layer: one output per class found in the training set
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(all_labels))

model = model.to(device)

# Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)
# The learning rate is multiplied by 0.1 every 7 scheduler steps. With fewer
# than 8 epochs of training the decay therefore never takes effect.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

## 7. Training Loop
The training process involves iterating through epochs. In each epoch, we: 
1.  **Train**: Update model weights using backpropagation.
2.  **Evaluate**: Check performance on the test set without updating weights.

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=5):
    """Train ``model`` and print loss/accuracy for every phase of every epoch.

    Each epoch runs a 'train' phase followed by a 'test' phase. A phase is
    skipped when the notebook-level ``dataloaders`` dict has no loader for it.
    Batches are moved to the notebook-level ``device`` before the forward pass.

    Args:
        model: Network to optimise; it is updated in place.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per training epoch.
        num_epochs: Number of epochs to run.

    Returns:
        None. Progress is reported through ``print``.
    """
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'test']:
            if phase not in dataloaders:
                continue

            is_training = phase == 'train'
            if is_training:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            # Count samples while iterating so the averages match exactly what
            # the loader yielded, even when it yields nothing.
            samples_seen = 0

            # Batch iteration
            for inputs, labels in tqdm(dataloaders[phase], desc=phase):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Gradients are tracked only while training
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward pass only in training
                    if is_training:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                batch_count = inputs.size(0)
                # loss is a per-batch mean, so weight it by the batch size
                running_loss += loss.item() * batch_count
                running_corrects += (preds == labels).sum().item()
                samples_seen += batch_count

            if samples_seen == 0:
                # Nothing was processed: there are no metrics to average and
                # no optimizer step happened, so the scheduler is left alone.
                print(f'{phase}: no samples available, skipping')
                continue

            if is_training:
                scheduler.step()

            epoch_loss = running_loss / samples_seen
            epoch_acc = running_corrects / samples_seen

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

    print('Training complete')

## 8. Execute Training
Run the training loop.

In [None]:
# Start training only when a 'train' DataLoader exists
if 'train' in dataloaders:
    train_model(
        model,
        criterion,
        optimizer,
        scheduler,
        num_epochs=NUM_EPOCHS,
    )

## 9. Save and Inference
Save the trained model and class map for future use.

In [None]:
# Write the model's state_dict into the dataset directory
save_path = os.path.join(DATA_DIR, "model.pth")
torch.save(model.state_dict(), save_path)
print(f"Model saved to {save_path}")

# Write the label -> index mapping next to the model as JSON
class_map_path = os.path.join(DATA_DIR, "class_map.json")
with open(class_map_path, "w") as f:
    f.write(json.dumps(class_to_idx, indent=4))
print(f"Class map saved to {class_map_path}")