In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os

from Models import KeypointNet, KeypointDataset
from Generator import generate_synthetic_image
from Helper import save_checkpoint_generic, load_checkpoint_generic


In [2]:
# ============================================================
# HYPERPARAMETERS
# ============================================================

# Model parameters (SuperPoint: Adam, lr=0.001, beta=(0.9, 0.999))
learning_rate = 0.001
adam_betas = (0.9, 0.999)
weight_decay = 0.0

# Training parameters (iteration-based)
num_iterations = 200_000  # SuperPoint uses 200k iterations
batch_size = 32  # SuperPoint uses 32

# Image parameters
image_size = (240, 320)  # (Height, Width)

# Dataset parameters
num_train_samples = 5000  # Number of pregenerated training samples
num_test_samples = 500   # Number of pregenerated test samples

# Augmentation settings (applied during training, not during generation)
use_homography_augment = True    # Apply random homography to training data
use_photometric_augment = True   # Apply brightness/contrast to training data
use_geometric_augment = True     # Apply flips to training data

# Dataset file paths (.npz format - contains pregenerated images)
dataset_cache_dir = './dataset_cache'
load_datasets_if_exist = True    # Load from .npz files if available

# Checkpoint parameters
checkpoint_dir = './checkpoints'
save_checkpoint_every = 5000  # Save every N iterations
max_checkpoints = 4

# Logging parameters
print_every = 20   # Print loss every N iterations
eval_every = 100   # Evaluate on test set every N iterations

# Create directories
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(dataset_cache_dir, exist_ok=True)

print("‚úì Configuration loaded")
print(f"  Device: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
print(f"  Training samples: {num_train_samples}")
print(f"  Test samples: {num_test_samples}")
print(f"  Batch size: {batch_size}")
print(f"  Target iterations: {num_iterations:,}")
print(f"  Training augmentation: {'ENABLED' if use_homography_augment else 'DISABLED'}")
print()


‚úì Configuration loaded
  Device: CUDA
  Training samples: 5000
  Test samples: 500
  Batch size: 32
  Target iterations: 200,000
  Training augmentation: ENABLED



In [3]:
# ============================================================
# DATASET GENERATION AND SAVING (Run once to create datasets)
# ============================================================

train_samples_path = os.path.join(dataset_cache_dir, f'train_samples_{num_train_samples}.npz')
test_samples_path = os.path.join(dataset_cache_dir, f'test_samples_{num_test_samples}.npz')

print("=" * 60)
print("DATASET GENERATION")
print("=" * 60)
print()

# Generate and save training samples (raw, no augmentation)
print(f"Generating {num_train_samples} training samples...")
train_generator = KeypointDataset(
    num_samples=num_train_samples,
    image_shape=image_size,
    generate_fn=generate_synthetic_image,
    generate_kwargs={
        'width': image_size[1],
        'height': image_size[0],
        'shape_type': 'random',
    },
    use_homography_augment=False,  # No augmentation during generation
    use_photometric_augment=False,
    use_geometric_augment=False,
    pregenerate=True
)
print(f"‚úì Training samples generated: {len(train_generator)} samples")

# Save training samples
print(f"Saving to {train_samples_path}...")
train_generator.save_to_file(train_samples_path)
print(f"‚úì Training samples saved!")
print()

# Generate and save test samples (raw, no augmentation)
print(f"Generating {num_test_samples} test samples...")
test_generator = KeypointDataset(
    num_samples=num_test_samples,
    image_shape=image_size,
    generate_fn=generate_synthetic_image,
    generate_kwargs={
        'width': image_size[1],
        'height': image_size[0],
        'shape_type': 'random',
    },
    use_homography_augment=False,  # No augmentation during generation
    use_photometric_augment=False,
    use_geometric_augment=False,
    pregenerate=True
)
print(f"‚úì Test samples generated: {len(test_generator)} samples")

# Save test samples
print(f"Saving to {test_samples_path}...")
test_generator.save_to_file(test_samples_path)
print(f"‚úì Test samples saved!")
print()

print("=" * 60)
print("‚úì Dataset generation complete!")
print("=" * 60)
print(f"Training samples: {train_samples_path}")
print(f"Test samples: {test_samples_path}")
print()


DATASET GENERATION

Generating 5000 training samples...
Pre-generating 5000 base samples...
  500/5000 samples
  1000/5000 samples
  1500/5000 samples
  2000/5000 samples
  2500/5000 samples
  3000/5000 samples
  3500/5000 samples
  4000/5000 samples
  4500/5000 samples
  5000/5000 samples
‚úì Pre-generation complete!
‚úì Training samples generated: 1000000000 samples
Saving to ./dataset_cache\train_samples_5000.npz...
‚úì Saved 5000 samples to ./dataset_cache\train_samples_5000.npz
‚úì Training samples saved!

Generating 500 test samples...
Pre-generating 500 base samples...
  500/500 samples
‚úì Pre-generation complete!
‚úì Test samples generated: 1000000000 samples
Saving to ./dataset_cache\test_samples_500.npz...
‚úì Saved 500 samples to ./dataset_cache\test_samples_500.npz
‚úì Test samples saved!

‚úì Dataset generation complete!
Training samples: ./dataset_cache\train_samples_5000.npz
Test samples: ./dataset_cache\test_samples_500.npz



In [3]:
# ============================================================
# LOAD DATASETS AND INIT MODEL
# ============================================================

print("=" * 60)
print("TRAINING SETUP")
print("=" * 60)
print()

# ============================================================
# LOAD DATASETS FROM .NPZ FILES
# ============================================================

train_samples_path = os.path.join(dataset_cache_dir, f'train_samples_{num_train_samples}.npz')
test_samples_path = os.path.join(dataset_cache_dir, f'test_samples_{num_test_samples}.npz')

train_dataset = None
test_dataset = None

# Load training dataset WITH augmentation
if load_datasets_if_exist and os.path.exists(train_samples_path):
    print(f"Loading training samples from {train_samples_path}...")
    print(f"  Augmentation: {'ENABLED' if use_homography_augment else 'DISABLED'}")
    try:
        train_dataset = KeypointDataset(
            num_samples=num_train_samples,
            image_shape=image_size,
            use_homography_augment=use_homography_augment,
            use_photometric_augment=use_photometric_augment,
            use_geometric_augment=use_geometric_augment,
            pregenerate=False,  # Don't regenerate, just load
            load_from_file=train_samples_path
        )
        print(f"‚úì Training dataset loaded: {len(train_dataset)} samples")
    except Exception as e:
        print(f"‚ö†Ô∏è  Failed to load training dataset: {e}")
        print("Please run the dataset generation cell first!")
        train_dataset = None
else:
    print(f"‚ö†Ô∏è  Training samples not found at {train_samples_path}")
    print("Please run the dataset generation cell first!")

# Load test dataset WITHOUT augmentation
if load_datasets_if_exist and os.path.exists(test_samples_path):
    print(f"Loading test samples from {test_samples_path}...")
    print(f"  Augmentation: DISABLED (test set)")
    try:
        test_dataset = KeypointDataset(
            num_samples=num_test_samples,
            image_shape=image_size,
            use_homography_augment=False,  # No augmentation for test
            use_photometric_augment=False,
            use_geometric_augment=False,
            pregenerate=False,  # Don't regenerate, just load
            load_from_file=test_samples_path
        )
        print(f"‚úì Test dataset loaded: {len(test_dataset)} samples")
    except Exception as e:
        print(f"‚ö†Ô∏è  Failed to load test dataset: {e}")
        print("Please run the dataset generation cell first!")
        test_dataset = None
else:
    print(f"‚ö†Ô∏è  Test samples not found at {test_samples_path}")
    print("Please run the dataset generation cell first!")

# Check if datasets were loaded successfully
if train_dataset is None or test_dataset is None:
    print()
    print("=" * 60)
    print("‚ö†Ô∏è  ERROR: Datasets not loaded!")
    print("=" * 60)
    print("Please run Cell 3 (Dataset Generation) first to create the .npz files.")
    print()
    raise RuntimeError("Datasets not found. Run dataset generation cell first.")

print()

# ============================================================
# CREATE DATALOADERS
# ============================================================

print("Creating DataLoaders...")
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0
)

print(f"‚úì DataLoaders created")
print(f"  Training batches per epoch: {len(train_loader)}")
print(f"  Test batches: {len(test_loader)}")
print()

# ============================================================
# MODEL INITIALIZATION
# ============================================================

print("Initializing model...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = KeypointNet().to(device)

# Optimizer (SuperPoint paper: Adam with lr=0.001, betas=(0.9, 0.999))
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    betas=adam_betas,
    weight_decay=weight_decay
)


TRAINING SETUP

Loading training samples from ./dataset_cache\train_samples_5000.npz...
  Augmentation: ENABLED
Loading 5000 samples from ./dataset_cache\train_samples_5000.npz...
‚úì Loaded 5000 samples!
‚úì Training dataset loaded: 1000000000 samples
Loading test samples from ./dataset_cache\test_samples_500.npz...
  Augmentation: DISABLED (test set)
Loading 500 samples from ./dataset_cache\test_samples_500.npz...
‚úì Loaded 500 samples!
‚úì Test dataset loaded: 1000000000 samples

Creating DataLoaders...
‚úì DataLoaders created
  Training batches per epoch: 31250000
  Test batches: 31250000

Initializing model...
Using device: cuda


In [None]:
# ============================================================
# TRAINING LOOP
# ============================================================

# Loss tracking
train_losses = []
test_losses = []
start_iteration = 0

# Load checkpoint if exists
checkpoint = load_checkpoint_generic(checkpoint_dir, device)
if checkpoint:
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_iteration = checkpoint.get('iteration', 0)
    train_losses = checkpoint.get('train_losses', [])
    test_losses = checkpoint.get('test_losses', [])
    print(f"‚úì Resuming from iteration {start_iteration:,}")
else:
    print("‚úì Starting from scratch")

print()
print("=" * 60)
print("STARTING TRAINING")
print("=" * 60)
print()

# ============================================================
# ITERATION-BASED TRAINING LOOP
# ============================================================

model.train()
running_loss = 0.0
iteration = start_iteration

# Create iterator for infinite cycling through dataset
train_iterator = iter(train_loader)

while iteration < num_iterations:
    # Get next batch (infinite cycling)
    try:
        images, targets = next(train_iterator)
    except StopIteration:
        # Restart iterator when dataset is exhausted
        train_iterator = iter(train_loader)
        images, targets = next(train_iterator)

    images = images.to(device)
    targets = targets.to(device)

    # Forward pass
    logits = model(images, return_logits=True)  # (B, 65, H/8, W/8)
    targets_idx = targets.argmax(dim=1)  # (B, H/8, W/8)

    # Compute loss
    loss = F.cross_entropy(logits, targets_idx)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Track loss
    running_loss += loss.item()
    iteration += 1

    # Print progress
    if iteration % print_every == 0:
        avg_loss = running_loss / print_every
        train_losses.append(avg_loss)
        print(f"Iter [{iteration:>6}/{num_iterations}] Loss: {avg_loss:.4f}")
        running_loss = 0.0

    # Evaluate on test set
    if iteration % eval_every == 0:
        model.eval()
        test_loss = 0.0
        num_test_batches = 0

        with torch.no_grad():
            for test_images, test_targets in test_loader:
                test_images = test_images.to(device)
                test_targets = test_targets.to(device)

                logits = model(test_images, return_logits=True)
                targets_idx = test_targets.argmax(dim=1)
                loss = F.cross_entropy(logits, targets_idx)

                test_loss += loss.item()
                num_test_batches += 1

        avg_test_loss = test_loss / num_test_batches
        test_losses.append(avg_test_loss)
        print(f"  ‚îî‚îÄ Test Loss: {avg_test_loss:.4f}")
        model.train()

    # Save checkpoint
    if iteration % save_checkpoint_every == 0 or iteration == num_iterations:
        save_checkpoint_generic(
            checkpoint_dir,
            iteration,
            {
                'iteration': iteration,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_losses': train_losses,
                'test_losses': test_losses,
                'config': {
                    'learning_rate': learning_rate,
                    'batch_size': batch_size,
                    'num_iterations': num_iterations,
                }
            },
            max_checkpoints=max_checkpoints
        )
        print(f"  ‚îî‚îÄ Checkpoint saved")

print()
print("=" * 60)
print("‚úì TRAINING COMPLETE!")
print("=" * 60)
if len(train_losses) > 0:
    print(f"Final Training Loss: {train_losses[-1]:.4f}")
if len(test_losses) > 0:
    print(f"Final Test Loss: {test_losses[-1]:.4f}")
print()


üöÄ No checkpoint found, starting from scratch
‚úì Starting from scratch

STARTING TRAINING



In [None]:
# ============================================================
# PLOT TRAINING CURVES
# ============================================================

print("=" * 60)
print("TRAINING VISUALIZATION")
print("=" * 60)
print()

if len(train_losses) == 0 and len(test_losses) == 0:
    print("‚ö†Ô∏è  No training data to plot. Run the training loop first.")
else:
    fig, ax = plt.subplots(1, 1, figsize=(12, 6))

    # Plot training losses
    if len(train_losses) > 0:
        iterations_range = [(i + 1) * print_every for i in range(len(train_losses))]
        ax.plot(iterations_range, train_losses, 'b-', label='Training Loss', linewidth=2, alpha=0.7)
        print(f"‚úì Training losses plotted ({len(train_losses)} points)")

    # Plot test losses
    if len(test_losses) > 0:
        test_iterations_range = [(i + 1) * eval_every for i in range(len(test_losses))]
        ax.plot(test_iterations_range, test_losses, 'r-', label='Test Loss', linewidth=2, alpha=0.7)
        print(f"‚úì Test losses plotted ({len(test_losses)} points)")

    ax.set_xlabel('Iteration', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title('Training Progress - Interest Point Detection', fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

    # Add minor gridlines for better readability
    ax.minorticks_on()
    ax.grid(which='minor', alpha=0.1)

    plt.tight_layout()
    plt.show()

    print()
    if len(train_losses) > 0:
        print(f"Final Training Loss: {train_losses[-1]:.4f}")
        print(f"Best Training Loss: {min(train_losses):.4f}")
    if len(test_losses) > 0:
        print(f"Final Test Loss: {test_losses[-1]:.4f}")
        print(f"Best Test Loss: {min(test_losses):.4f}")
    print()
    print("=" * 60)


In [None]:
# ============================================================
# MODEL TESTING AND VISUALIZATION
# ============================================================

print("=" * 60)
print("MODEL TESTING")
print("=" * 60)
print()

# Set model to evaluation mode
model.eval()

# Get a few test samples
num_vis_samples = 4
test_samples = []

with torch.no_grad():
    for i, (images, targets) in enumerate(test_loader):
        if i >= num_vis_samples:
            break
        test_samples.append((images[0], targets[0]))

print(f"Loaded {len(test_samples)} test samples for visualization")
print()

# Visualize predictions
fig, axes = plt.subplots(len(test_samples), 3, figsize=(15, 5 * len(test_samples)))
if len(test_samples) == 1:
    axes = axes[np.newaxis, :]

for idx, (image, target) in enumerate(test_samples):
    # Prepare input
    image_input = image.unsqueeze(0).to(device)

    # Get prediction
    with torch.no_grad():
        pred_heatmap = model(image_input)  # (1, 65, H/8, W/8)

    # Convert to numpy for visualization
    image_np = image.squeeze().cpu().numpy()
    target_np = target.cpu().numpy()
    pred_np = pred_heatmap.squeeze().cpu().numpy()

    # Get max prob across channels for visualization
    target_max = target_np.max(axis=0)
    pred_max = pred_np.max(axis=0)

    # Plot original image
    axes[idx, 0].imshow(image_np, cmap='gray')
    axes[idx, 0].set_title(f'Sample {idx + 1}: Input Image', fontsize=12, fontweight='bold')
    axes[idx, 0].axis('off')

    # Plot ground truth heatmap
    axes[idx, 1].imshow(target_max, cmap='hot')
    axes[idx, 1].set_title('Ground Truth Heatmap', fontsize=12, fontweight='bold')
    axes[idx, 1].axis('off')

    # Plot predicted heatmap
    axes[idx, 2].imshow(pred_max, cmap='hot')
    axes[idx, 2].set_title('Predicted Heatmap', fontsize=12, fontweight='bold')
    axes[idx, 2].axis('off')

plt.tight_layout()
plt.show()

print()
print("=" * 60)
print("‚úì Visualization complete!")
print("=" * 60)