In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from Models import KeypointNet
from Models import KeypointDataset
from Generator import generate_synthetic_image


In [None]:
# ============================================================
# TRAINING SETUP AND LOOP
# ============================================================

import torch.nn.functional as F
from Helper import save_checkpoint_generic, load_checkpoint_generic
import os

# ============================================================
# TRAINING HYPERPARAMETERS
# ============================================================
# Every tunable value for this notebook lives in this section; the cells
# below read these names directly.

# Reproducibility: a single seed applied to Python's, NumPy's and PyTorch's
# global RNGs so a "Restart Kernel -> Run All" repeats the same run.
# Set to None to keep unseeded (non-deterministic) behaviour.
random_seed = 42

# Model parameters (SuperPoint paper: Adam with lr=0.001, beta=(0.9, 0.999))
learning_rate = 0.001
adam_betas = (0.9, 0.999)
weight_decay = 0.0

# Training parameters (SuperPoint paper: 200k iterations, batch_size=32)
# Adjust based on your hardware - paper used larger batches and more iterations
num_epochs = 100
batch_size = 16  # Increase if you have enough GPU memory (paper used 32)

image_size = (240, 320)  # Height, Width

# Dataset parameters
num_train_samples = 5000  # Will result in ~312 iterations per epoch with batch_size=16
num_test_samples = 500
use_augmentation = True
pregenerate_data = True  # Set to False for on-the-fly generation

# Dataset saving/loading
dataset_cache_dir = './dataset_cache'
save_datasets = True  # Save datasets to disk for reuse
load_datasets_if_exist = True  # Load from disk if available

# Checkpoint parameters
checkpoint_dir = './checkpoints'
save_checkpoint_every = 5  # Save every N epochs
max_checkpoints = 4  # Keep only last N checkpoints

# Visualization parameters
print_every = 20  # Print loss every N batches
test_threshold = 0.015  # Threshold for keypoint detection during testing

# Create directories
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(dataset_cache_dir, exist_ok=True)

# Seed the global RNGs before any dataset or model is created.
# NOTE: re-running this cell to resume from a checkpoint re-applies the same
# seed, so the random stream restarts from the beginning as well.
if random_seed is not None:
    import random

    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)  # covers the CPU and all CUDA devices

# ============================================================
# DATASET CREATION WITH SAVE/LOAD
# ============================================================

import pickle

train_dataset_path = os.path.join(dataset_cache_dir, f'train_dataset_{num_train_samples}.pkl')
test_dataset_path = os.path.join(dataset_cache_dir, f'test_dataset_{num_test_samples}.pkl')
# NOTE(review): the cache file name encodes only the sample count. A cached
# dataset is therefore reused even if image_size or the generator settings
# changed since it was written; remove the cache files after such a change.


def _load_cached_dataset(path, label):
    """Return the dataset pickled at `path`, or None when loading fails.

    `label` ("training" / "test") only affects the progress messages.
    Any error is reported and swallowed so the caller can regenerate instead.
    """
    print(f"Loading cached {label} dataset from {path}...")
    try:
        # SECURITY: pickle.load can execute arbitrary code stored in the file.
        # Only load cache files that this notebook wrote itself.
        with open(path, 'rb') as cache_file:
            dataset = pickle.load(cache_file)
        print(f"✓ {label.capitalize()} dataset loaded: {len(dataset)} samples\n")
        return dataset
    except Exception as err:
        print(f"⚠️  Failed to load {label} dataset: {err}")
        print("Will create new dataset...\n")
        return None


def _build_dataset(num_samples, augment, pregenerate):
    """Create a KeypointDataset of synthetic images sized by `image_size`."""
    return KeypointDataset(
        num_samples=num_samples,
        image_shape=image_size,
        generate_fn=generate_synthetic_image,
        generate_kwargs={
            'width': image_size[1],
            'height': image_size[0],
            'shape_type': 'random',
            'use_homography': True
        },
        augment=augment,
        pregenerate=pregenerate
    )


def _save_dataset(dataset, path, label):
    """Pickle `dataset` to `path`; a failure is reported but never raised."""
    print(f"Saving {label} dataset to {path}...")
    try:
        with open(path, 'wb') as cache_file:
            pickle.dump(dataset, cache_file)
        print(f"✓ {label.capitalize()} dataset saved!\n")
    except Exception as err:
        print(f"⚠️  Failed to save {label} dataset: {err}\n")


# Try to load existing datasets
train_dataset = None
test_dataset = None

# The training cache is only consulted when pregeneration is requested:
# with pregenerate_data=False the user asked for on-the-fly generation, and a
# stale pregenerated cache must not silently override that (the save path
# below is gated on the same flag).
if load_datasets_if_exist and pregenerate_data and os.path.exists(train_dataset_path):
    train_dataset = _load_cached_dataset(train_dataset_path, 'training')

if load_datasets_if_exist and os.path.exists(test_dataset_path):
    test_dataset = _load_cached_dataset(test_dataset_path, 'test')

# Create datasets if not loaded
if train_dataset is None:
    print("Creating training dataset...")
    train_dataset = _build_dataset(num_train_samples, use_augmentation, pregenerate_data)
    print(f"✓ Training dataset created: {len(train_dataset)} samples\n")

    # Only a pregenerated dataset is written to the cache
    if save_datasets and pregenerate_data:
        _save_dataset(train_dataset, train_dataset_path, 'training')

if test_dataset is None:
    print("Creating test dataset...")
    # Test set: no augmentation, always pregenerated
    test_dataset = _build_dataset(num_test_samples, False, True)
    print(f"✓ Test dataset created: {len(test_dataset)} samples\n")

    if save_datasets:
        _save_dataset(test_dataset, test_dataset_path, 'test')

# Wrap both datasets in DataLoaders. They share the batch size and load in the
# main process (num_workers=0); only the training loader shuffles.
loader_kwargs = dict(batch_size=batch_size, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Report how many batches one pass over each loader yields
for summary_line in (
    f"✓ DataLoaders created",
    f"  Training batches: {len(train_loader)}",
    f"  Test batches: {len(test_loader)}",
):
    print(summary_line)
print()

# ============================================================
# MODEL INITIALIZATION
# ============================================================
# Builds the network and its Adam optimizer, then restores both (plus the
# loss history) from a checkpoint when one is found.

print("Initializing model...")
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}\n")

model = KeypointNet().to(device)

# Optimizer (SuperPoint paper: Adam with lr=0.001, betas=(0.9, 0.999))
adam_settings = {
    'lr': learning_rate,
    'betas': adam_betas,
    'weight_decay': weight_decay,
}
optimizer = torch.optim.Adam(model.parameters(), **adam_settings)

# Fresh-run defaults; overwritten below when resuming
train_losses, test_losses = [], []
start_epoch = 0

# Resume from a checkpoint if the helper returns one.
# NOTE(review): presumably load_checkpoint_generic returns the latest
# checkpoint dict, or a falsy value when none exists - confirm in Helper.
checkpoint = load_checkpoint_generic(checkpoint_dir, device)
if checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # The stored epoch is the last one completed, so continue with the next
    start_epoch = checkpoint['epoch'] + 1
    train_losses = checkpoint.get('train_losses', [])
    test_losses = checkpoint.get('test_losses', [])
    print(f"Resuming from epoch {start_epoch}\n")

print("Model ready for training!")
print("=" * 60)
print()


In [None]:
# ============================================================
# TRAINING LOOP
# ============================================================
# Each epoch: one optimisation pass over train_loader, one evaluation pass
# over test_loader, then an optional checkpoint. The loss is cross-entropy
# between the model's logits and the arg-max class of the targets along dim 1.

print("Starting training...")
print()

# Fail fast: an empty loader would otherwise surface only as a
# ZeroDivisionError after a full epoch of work.
if start_epoch < num_epochs and (len(train_loader) == 0 or len(test_loader) == 0):
    raise ValueError(
        "train_loader and test_loader must each yield at least one batch; "
        "check the dataset sizes and batch_size."
    )

for epoch in range(start_epoch, num_epochs):
    # ========== TRAINING ==========
    model.train()
    # Losses are accumulated weighted by batch size so the epoch average is a
    # true per-sample mean even when the last batch is smaller than the rest.
    epoch_train_loss = 0.0
    train_samples_seen = 0

    for batch_idx, (images, targets) in enumerate(train_loader):
        images = images.to(device)
        targets = targets.to(device)

        # Forward pass
        logits = model(images, return_logits=True)  # (B, 65, H/8, W/8)

        # Convert targets to class indices
        targets_idx = targets.argmax(dim=1)  # (B, H/8, W/8)

        # Compute loss
        loss = F.cross_entropy(logits, targets_idx)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Track loss (read the scalar once; .item() syncs with the device)
        batch_loss = loss.item()
        batch_samples = images.size(0)
        epoch_train_loss += batch_loss * batch_samples
        train_samples_seen += batch_samples

        # Print progress
        if (batch_idx + 1) % print_every == 0:
            avg_loss = epoch_train_loss / train_samples_seen
            print(f"Epoch [{epoch + 1}/{num_epochs}] Batch [{batch_idx + 1}/{len(train_loader)}] "
                  f"Loss: {batch_loss:.4f} (Avg: {avg_loss:.4f})")

    # Calculate average training loss (per sample)
    avg_train_loss = epoch_train_loss / train_samples_seen
    train_losses.append(avg_train_loss)

    # ========== TESTING ==========
    model.eval()
    epoch_test_loss = 0.0
    test_samples_seen = 0

    with torch.no_grad():
        for images, targets in test_loader:
            images = images.to(device)
            targets = targets.to(device)

            # Forward pass
            logits = model(images, return_logits=True)

            # Convert targets to class indices
            targets_idx = targets.argmax(dim=1)

            # Compute loss
            loss = F.cross_entropy(logits, targets_idx)

            batch_samples = images.size(0)
            epoch_test_loss += loss.item() * batch_samples
            test_samples_seen += batch_samples

    avg_test_loss = epoch_test_loss / test_samples_seen
    test_losses.append(avg_test_loss)

    # Print epoch summary
    print(f"\n{'=' * 60}")
    print(f"Epoch [{epoch + 1}/{num_epochs}] Summary:")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Test Loss:  {avg_test_loss:.4f}")
    print(f"{'=' * 60}\n")

    # Save checkpoint every save_checkpoint_every epochs and after the last one
    if (epoch + 1) % save_checkpoint_every == 0 or (epoch + 1) == num_epochs:
        save_checkpoint_generic(
            checkpoint_dir,
            epoch,
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_losses': train_losses,
                'test_losses': test_losses,
                'config': {
                    'learning_rate': learning_rate,
                    'weight_decay': weight_decay,
                    'batch_size': batch_size,
                    'num_train_samples': num_train_samples,
                    'num_test_samples': num_test_samples,
                }
            },
            max_checkpoints=max_checkpoints
        )

print("\n✓ Training complete!")

# ============================================================
# PLOT TRAINING CURVES
# ============================================================
# Draws the per-epoch training and test losses on one axis and prints the
# most recent value of each.

print("\nPlotting training curves...")
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

# The two histories can differ in length (e.g. a resumed checkpoint that
# lacked one of them). Both are appended once per epoch, so their tails line
# up: right-align each curve on the epoch axis instead of assuming equal
# lengths, which would make ax.plot raise on the shape mismatch.
total_epochs = max(len(train_losses), len(test_losses))
for history, line_style, curve_label in (
    (train_losses, 'b-', 'Training Loss'),
    (test_losses, 'r-', 'Test Loss'),
):
    first_epoch = total_epochs - len(history) + 1
    ax.plot(range(first_epoch, total_epochs + 1), history, line_style,
            label=curve_label, linewidth=2)

ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('Training and Test Loss', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("✓ Training curves plotted!")
# Guard the [-1] look-ups: a history is empty when no epoch has run yet
if train_losses:
    print(f"\nFinal Training Loss: {train_losses[-1]:.4f}")
if test_losses:
    print(f"Final Test Loss: {test_losses[-1]:.4f}")
