# Nocturna - MobileNetV3 Binary Classifier Training

This notebook trains a MobileNetV3 model to classify images as:
- **Label 0**: Photographs (not suitable for inverting)
- **Label 1**: Charts/Plots (suitable for inverting in dark mode)

## 1. Import Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import timm
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
from tqdm import tqdm

# Report the library versions in use for this session.
print(f"PyTorch version: {torch.__version__}")
print(f"timm version: {timm.__version__}")

# Query CUDA once and reuse the answer for both the report and the branch.
cuda_ready = torch.cuda.is_available()
print(f"CUDA available: {cuda_ready}")
if cuda_ready:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 2. Explore Available MobileNetV3 Models

In [None]:
# List every MobileNetV3 variant in timm that has pretrained weights available
mobilenet_models = timm.list_models('*mobilenetv3*', pretrained=True)
print("Available MobileNetV3 models:")
# Use a dedicated loop variable: notebook loop variables are globals, and
# `model_name` is the name that selects which model gets loaded. Reusing it
# here would silently overwrite that choice whenever this cell is re-run.
for available_model in mobilenet_models:
    print(f"  - {available_model}")

## 3. Load Pre-trained MobileNetV3 Model

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Pre-trained MobileNetV3-Large with a 2-way head: 0 (photos) or 1 (plots)
model_name = 'mobilenetv3_large_100'
model = timm.create_model(model_name, pretrained=True, num_classes=2).to(device)
print(f"\nModel '{model_name}' loaded successfully!")

# Total element count across all parameter tensors of the model
param_count = sum(param.numel() for param in model.parameters())
print(f"Number of parameters: {param_count:,}")

## 4. Inspect Model Architecture

In [None]:
# Resolve the preprocessing settings timm associates with this model
data_config = timm.data.resolve_model_data_config(model)
print("Model data configuration:")
# Walk (display title, config key) pairs instead of one print per field
for field_title, config_key in (
    ("Input size", "input_size"),
    ("Mean", "mean"),
    ("Std", "std"),
    ("Interpolation", "interpolation"),
    ("Crop percentage", "crop_pct"),
):
    print(f"  {field_title}: {data_config[config_key]}")

In [None]:
# Show the classifier head (heading and module repr on separate lines)
print("\nClassifier head:", model.classifier, sep="\n")

## 5. Create Data Transforms

In [None]:
# Create transforms for training and validation.
#
# The spatial size is taken from the model's data config instead of a
# hard-coded 224, so the pipelines stay correct if `model_name` is changed to
# a variant with a different input resolution. timm reports `input_size` as
# (channels, height, width); only the last two entries are needed here.
input_height, input_width = data_config['input_size'][-2:]

# Resize to slightly more than the model input before random cropping so the
# crop window has room to move (256 for a 224 input, i.e. input / 0.875).
resize_height = round(input_height / 0.875)
resize_width = round(input_width / 0.875)

train_transform = transforms.Compose([
    transforms.Resize((resize_height, resize_width)),
    transforms.RandomCrop((input_height, input_width)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=data_config['mean'], std=data_config['std']),
])

# Validation is deterministic: resize straight to the model input size.
val_transform = transforms.Compose([
    transforms.Resize((input_height, input_width)),
    transforms.ToTensor(),
    transforms.Normalize(mean=data_config['mean'], std=data_config['std']),
])

print("Data transforms created successfully!")

## 6. Test Model with Random Input

In [None]:
# Smoke-test the forward pass with a random input
model.eval()
with torch.no_grad():
    dummy_input = torch.randn(1, 3, 224, 224).to(device)
    output = model(dummy_input)
    print(f"Output shape: {output.shape}")
    print(f"Raw output (logits): {output}")

    # The model was created with num_classes=2, so it emits one logit per
    # class. Normalise across the class dimension with softmax and pick the
    # larger entry; an element-wise sigmoid + threshold yields a 2-element
    # tensor, on which `.item()` raises.
    probs = torch.softmax(output, dim=1)
    print(f"Probabilities: {probs}")
    print(f"Predicted class: {probs.argmax(dim=1).item()}")

## 7. Next Steps

- [ ] Generate/download dataset (photos and plots)
- [ ] Create custom Dataset class
- [ ] Set up DataLoaders
- [ ] Define loss function and optimizer
- [ ] Implement training loop
- [ ] Evaluate model performance

## 8. Dataset Exploration

In [None]:
# Check dataset structure
import os

data_dir = Path('data')
label_0_dir = data_dir / 'label_0'  # Photos (to be filled)
label_1_dir = data_dir / 'label_1'  # Generated plots


def count_images(directory, patterns=('*.png', '*.jpg')):
    """Return how many entries directly inside *directory* match any glob in *patterns*.

    Args:
        directory: A ``pathlib.Path`` to search (non-recursive).
        patterns: Glob patterns to count; matches are summed across patterns.

    Returns:
        The total number of matching entries.
    """
    return sum(len(list(directory.glob(pattern))) for pattern in patterns)


print("Dataset Structure:")
print(f"  {label_0_dir}: {count_images(label_0_dir)} images")
print(f"  {label_1_dir}: {count_images(label_1_dir)} images")

# Show sample plots (first 12 PNG files in name order)
plot_files = sorted(label_1_dir.glob('*.png'))[:12]
if plot_files:
    fig, axes = plt.subplots(3, 4, figsize=(15, 10))
    fig.suptitle('Sample Generated Plots (Label 1)', fontsize=16)
    for ax, img_path in zip(axes.flat, plot_files):
        # Close each file as soon as it has been drawn instead of leaving the
        # handle open; imshow is expected to convert the image to an array
        # during the call, so nothing reads from the file afterwards.
        with Image.open(img_path) as img:
            ax.imshow(img)
        ax.set_title(img_path.stem, fontsize=8)
        ax.axis('off')
    # With fewer than 12 samples, blank out the grid cells that got no image
    # so they do not show empty axis frames.
    for unused_ax in axes.ravel()[len(plot_files):]:
        unused_ax.axis('off')
    plt.tight_layout()
    plt.show()
else:
    print("\nNo plots found in label_1 directory!")