In [1]:
from pathlib import Path
from torch.utils.data import Dataset
import torch
import numpy as np
from PIL import Image
from Helper import *
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

import IO
from torchvision import transforms
from torch.utils.data import DataLoader

from Model import FlyingChairsOfficial


In [2]:
# Cell: Training setup
import torch
import torch.nn as nn
import torch.optim as optim

# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
import os

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Pre-processing handed to the dataset: resize to 256x256, then convert to a tensor.
# NOTE(review): presumably FlyingChairsOfficial resizes/rescales the ground-truth
# flow to match the resized images — confirm in Model.py.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Dataset location. Set the FLYINGCHAIRS_ROOT environment variable to point at a
# different copy of the data; the previous hard-coded path remains the default.
root = os.environ.get("FLYINGCHAIRS_ROOT", r"E:\datasets\FlyingChairs_release")
batch_size = 48


# Create datasets
train_dataset = FlyingChairsOfficial(root=root, split="train", transform=transform)
val_dataset = FlyingChairsOfficial(root=root, split="val", transform=transform)

# Data loaders. Settings shared by both loaders live in one place.
# Page-locked host memory only helps host-to-GPU copies, so it is requested only
# when CUDA is the selected device (asking for it on CPU just triggers a warning).
loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=0,
    pin_memory=(device.type == 'cuda'),
)

# Training batches are reshuffled every pass; validation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")

Using device: cuda
Train samples: 22232
Val samples: 640


In [None]:
# Cell: Iteration-based Training Loop
from Model import FlowNetSimple, MultiScaleEPE
import time

# Build the network on the selected device and report its size.
model = FlowNetSimple()
model = model.to(device)
param_count = sum(map(torch.numel, model.parameters()))
print(f"Model parameters: {param_count:,}")

# Loss and optimizer.
# NOTE(review): the three weights presumably map to three prediction scales of
# MultiScaleEPE — confirm in Model.py.
criterion = MultiScaleEPE(weights=(1.0, 0.5, 0.25))
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training schedule, all expressed in optimizer steps (iterations).
log_interval = 10          # print the running training loss every N iterations
val_interval = 100         # run a validation pass every N iterations
save_interval = 5000       # write a resume checkpoint every N iterations
max_iterations = 500000    # stop after this many iterations

# Learning rate scheduler
def adjust_learning_rate(optimizer, iteration, base_lr=1e-4, decay_start=200000, decay_interval=100000):
    """Apply a step-wise halving schedule to every parameter group of ``optimizer``.

    The learning rate stays at ``base_lr`` while ``iteration < decay_start``.
    It is halved once when ``iteration`` reaches ``decay_start`` and once more
    for every further ``decay_interval`` iterations.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts whose
            ``'lr'`` entry is overwritten.
        iteration: Current training iteration.
        base_lr: Learning rate used before any decay (default ``1e-4``).
        decay_start: Iteration at which the first halving happens (default 200000).
        decay_interval: Number of iterations between subsequent halvings
            (default 100000). Must be positive.

    Returns:
        The learning rate that was written into the optimizer.

    Raises:
        ValueError: If ``decay_interval`` is not positive.
    """
    if decay_interval <= 0:
        raise ValueError("decay_interval must be a positive number of iterations")

    if iteration >= decay_start:
        # "+ 1" because the first halving takes effect exactly at decay_start.
        num_halvings = (iteration - decay_start) // decay_interval + 1
        lr = base_lr / (2 ** num_halvings)
    else:
        lr = base_lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

# Setup checkpoint directory (periodic resume checkpoints are written here)
checkpoint_dir = Path('./checkpoints')
checkpoint_dir.mkdir(exist_ok=True)

# Load checkpoint if exists; a falsy result is treated as "nothing to resume"
checkpoint = load_checkpoint_generic(checkpoint_dir, device)

# Initialize tracking variables
train_losses = []     # mean training loss of each completed log_interval window
val_losses = []       # validation loss, one entry per validation pass
val_epes = []         # first EPE term returned by the criterion, per validation pass
iterations_log = []   # iteration number at which each validation entry was recorded
start_iteration = 1

if checkpoint:
    # Restore weights, optimizer state and the recorded history, then continue
    # with the iteration following the one stored in the checkpoint.
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_iteration = checkpoint.get('iteration', 0) + 1
    train_losses = checkpoint.get('train_losses', [])
    val_losses = checkpoint.get('val_losses', [])
    val_epes = checkpoint.get('val_epes', [])
    iterations_log = checkpoint.get('iterations_log', [])
    print(f"Resuming training from iteration {start_iteration}")
else:
    print("No checkpoint found, starting fresh training.")

best_val_loss = checkpoint.get('best_val_loss', float('inf')) if checkpoint else float('inf')
running_loss = 0.0

# An empty loader would make the outer `while` below spin forever without ever
# advancing `iteration`, so fail loudly instead.
if len(train_loader) == 0:
    raise RuntimeError("train_loader yields no batches; cannot train.")

# Time tracking (measures the wall time of each logging window)
start_time = time.time()

model.train()
print(f"\nStarting training from iteration {start_iteration} to {max_iterations:,}...")

iteration = start_iteration

# Main training loop: the outer `while` restarts the DataLoader whenever it is
# exhausted, so training is bounded by iterations rather than by epochs.
while iteration <= max_iterations:
    for img_pair, flow_gt in train_loader:
        if iteration > max_iterations:
            break

        # Move to device
        img_pair = img_pair.to(device, non_blocking=True)
        flow_gt = flow_gt.to(device, non_blocking=True)

        # Forward pass
        optimizer.zero_grad()
        flow_preds = model(img_pair)

        # Compute loss (the per-scale EPE terms are not needed during training)
        loss, _, _, _ = criterion(flow_preds, flow_gt)

        # Backward pass
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Adjust learning rate
        current_lr = adjust_learning_rate(optimizer, iteration)

        # Log progress
        if iteration % log_interval == 0:
            elapsed = time.time() - start_time
            avg_loss = running_loss / log_interval
            # Record the window average so the 'train_losses' entry written to
            # the checkpoints actually contains the training history.
            train_losses.append(avg_loss)
            print(f"Iter {iteration:6d}/{max_iterations:6d} | Loss: {avg_loss:.4f} | Time: {elapsed:.2f}s")
            running_loss = 0.0
            start_time = time.time()  # Reset timer

        # Validation
        if iteration % val_interval == 0:
            model.eval()
            val_loss = 0.0
            val_epe = 0.0

            with torch.no_grad():
                for img_pair_v, flow_gt_v in val_loader:
                    img_pair_v = img_pair_v.to(device)
                    flow_gt_v = flow_gt_v.to(device)

                    flow_preds_v = model(img_pair_v)

                    loss_v, epe1_v, _, _ = criterion(flow_preds_v, flow_gt_v)
                    val_loss += loss_v.item()
                    val_epe += epe1_v.item()

            # Average over batches (every batch counts equally, including a
            # smaller final batch).
            val_loss /= len(val_loader)
            val_epe /= len(val_loader)

            val_losses.append(val_loss)
            val_epes.append(val_epe)
            iterations_log.append(iteration)

            print(f">>> Validation at iter {iteration:6d} | Val Loss: {val_loss:.4f} | Val EPE: {val_epe:.4f} | LR: {current_lr:.6f}")

            # Save best model. Note: this file is written relative to the
            # current working directory, not into checkpoint_dir.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save({
                    'iteration': iteration,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': val_loss,
                    'val_epe': val_epe,
                    'lr': current_lr
                }, 'best_flownet_model.pth')
                print(f"✓ Saved best model (EPE: {val_epe:.4f})")

            model.train()
            start_time = time.time()  # Reset timer after validation

        # Save checkpoint periodically (everything needed to resume, plus history)
        if iteration % save_interval == 0:
            save_checkpoint_generic(
                checkpoint_dir,
                iteration,
                {
                    'iteration': iteration,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'lr': current_lr,
                    'best_val_loss': best_val_loss,
                    'val_losses': val_losses,
                    'val_epes': val_epes,
                    'iterations_log': iterations_log,
                    'train_losses': train_losses
                },
                max_checkpoints=5
            )

        iteration += 1

print("\n🎉 Training completed!")

Model parameters: 387,518
🚀 No checkpoint found, starting from scratch
No checkpoint found, starting fresh training.

Starting training from iteration 1 to 500,000...
Iter     10/500000 | Loss: 8.3284 | Time: 7.46s
Iter     20/500000 | Loss: 8.7984 | Time: 6.82s
Iter     30/500000 | Loss: 9.5983 | Time: 7.19s
Iter     40/500000 | Loss: 8.0201 | Time: 7.19s
Iter     50/500000 | Loss: 9.4466 | Time: 7.47s
Iter     60/500000 | Loss: 8.9577 | Time: 7.48s
Iter     70/500000 | Loss: 8.5180 | Time: 7.18s
Iter     80/500000 | Loss: 7.9357 | Time: 7.83s
Iter     90/500000 | Loss: 7.9339 | Time: 7.61s
Iter    100/500000 | Loss: 8.8379 | Time: 7.28s
>>> Validation at iter    100 | Val Loss: 8.6647 | Val EPE: 6.6214 | LR: 0.000100
✓ Saved best model (EPE: 6.6214)
Iter    110/500000 | Loss: 8.5722 | Time: 7.74s
Iter    120/500000 | Loss: 8.0903 | Time: 7.98s
Iter    130/500000 | Loss: 8.9677 | Time: 7.98s
Iter    140/500000 | Loss: 7.5922 | Time: 7.88s
Iter    150/500000 | Loss: 8.2386 | Time:

In [None]:
# Visualization function for predicted flow
def visualize_flow_prediction(model, dataset, idx=0, device='cpu'):
    """
    Visualize model's flow prediction for one dataset sample.

    Runs the model on ``dataset[idx]`` and shows a 2x3 figure with both input
    images, the ground-truth flow, the predicted flow and the per-pixel
    endpoint-error map. The two flow images share one saturation scale so that
    their colours are directly comparable.

    Args:
        model: Network to evaluate. It is moved to ``device``; its
            training/eval mode is restored before returning.
        dataset: Indexable object whose items are ``(img_pair, flow_gt)``.
            ``img_pair`` is split into the first three and the remaining
            channels. (Assumes channel-first tensors — TODO confirm against
            the dataset class.)
        idx: Index of the sample to show.
        device: Device on which inference is run.

    Returns:
        Mean endpoint error between the predicted and ground-truth flow.

    Raises:
        ValueError: If the prediction and the ground truth differ in shape.
    """
    # Remember the current mode so this helper does not silently leave a model
    # that is being trained in eval mode.
    was_training = model.training
    model.eval()
    model = model.to(device)

    # Get data (dataset returns img_pair, flow_gt)
    img_pair, flow_gt = dataset[idx]

    # Split concatenated images
    img1 = img_pair[:3]  # First 3 channels
    img2 = img_pair[3:]  # Remaining channels

    # Add batch dimension and move to device
    img_pair_batch = img_pair.unsqueeze(0).to(device)

    # Predict flow
    try:
        with torch.no_grad():
            flow_preds = model(img_pair_batch)
            # NOTE(review): assumes the model returns a sequence of predictions
            # whose first entry is the one to compare with flow_gt — confirm.
            flow_pred = flow_preds[0]
    finally:
        model.train(was_training)

    # Convert to numpy (channels last)
    img1_np = img1.permute(1, 2, 0).cpu().numpy()
    img2_np = img2.permute(1, 2, 0).cpu().numpy()
    flow_gt_np = flow_gt.permute(1, 2, 0).cpu().numpy()
    flow_pred_np = flow_pred[0].permute(1, 2, 0).cpu().numpy()

    if flow_pred_np.shape != flow_gt_np.shape:
        raise ValueError(
            f"Predicted flow shape {flow_pred_np.shape} does not match "
            f"ground-truth flow shape {flow_gt_np.shape}"
        )

    # Visualize using HSV color coding
    def flow_to_hsv(flow, max_mag=None):
        """Map a (H, W, 2) flow field to RGB: hue = direction, saturation = magnitude.

        ``max_mag`` is the magnitude mapped to full saturation; when omitted,
        the field's own maximum is used.
        """
        u = flow[:, :, 0]
        v = flow[:, :, 1]

        mag = np.sqrt(u ** 2 + v ** 2)
        ang = np.arctan2(v, u)
        if max_mag is None:
            max_mag = mag.max()

        h = (ang + np.pi) / (2 * np.pi)
        s = np.clip(mag / (max_mag + 1e-6), 0, 1)
        v_ = np.ones_like(s)

        hsv = np.stack([h, s, v_], axis=-1)
        return mcolors.hsv_to_rgb(hsv)

    # One saturation scale for both images; normalizing each by its own maximum
    # would make a near-zero prediction look as strong as the ground truth.
    shared_max_mag = max(
        np.linalg.norm(flow_gt_np, axis=2).max(),
        np.linalg.norm(flow_pred_np, axis=2).max(),
    )
    flow_gt_hsv = flow_to_hsv(flow_gt_np, shared_max_mag)
    flow_pred_hsv = flow_to_hsv(flow_pred_np, shared_max_mag)

    # Compute endpoint error once; its mean is the reported EPE.
    error_map = np.linalg.norm(flow_pred_np - flow_gt_np, axis=2)
    epe = error_map.mean()

    # Plot
    plt.figure(figsize=(20, 8))

    plt.subplot(2, 3, 1)
    plt.title("Image 1")
    plt.imshow(img1_np)
    plt.axis("off")

    plt.subplot(2, 3, 2)
    plt.title("Image 2")
    plt.imshow(img2_np)
    plt.axis("off")

    plt.subplot(2, 3, 3)
    plt.title("Ground Truth Flow")
    plt.imshow(flow_gt_hsv)
    plt.axis("off")

    # Slot 4 is intentionally left empty so the prediction sits under the inputs.
    plt.subplot(2, 3, 5)
    plt.title(f"Predicted Flow (EPE: {epe:.2f})")
    plt.imshow(flow_pred_hsv)
    plt.axis("off")

    plt.subplot(2, 3, 6)
    plt.title("Error Map")
    plt.imshow(error_map, cmap='hot')
    plt.colorbar()
    plt.axis("off")

    plt.tight_layout()
    plt.show()

    print(f"Average Endpoint Error: {epe:.4f} pixels")
    return epe


In [None]:
# Test visualization with untrained model
# visualize_flow_prediction(model, val_dataset, idx=0, device=device)


In [None]:
# Plot training curves
def plot_training_curves(iterations_log, val_losses, label='Validation Loss', ylabel='Loss',
                         title='Validation Loss over Iterations', color=None):
    """Plot one validation metric over training iterations.

    Called with only the first two arguments it draws the validation-loss
    curve. The keyword arguments allow the same figure layout to be reused for
    other metrics.

    Args:
        iterations_log: Iteration numbers (x axis).
        val_losses: Metric values recorded at those iterations (y axis).
        label: Legend entry for the curve.
        ylabel: Label of the y axis.
        title: Figure title.
        color: Optional line colour; when omitted matplotlib's default is used.
    """
    # Only pass a colour when one was requested so the default colour cycle is kept.
    line_kwargs = {'label': label}
    if color is not None:
        line_kwargs['color'] = color

    plt.figure(figsize=(12, 5))
    plt.plot(iterations_log, val_losses, **line_kwargs)
    plt.xlabel('Iteration')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.show()

# Plot after training
plot_training_curves(iterations_log, val_losses)

# Plot EPE over iterations (same layout, different metric)
plot_training_curves(
    iterations_log,
    val_epes,
    label='Validation EPE',
    ylabel='EPE (pixels)',
    title='Validation Endpoint Error over Iterations',
    color='orange',
)

In [None]:
# Cell: Load best model and visualize
# Resolve the device first so it is valid as `map_location` below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The training cell writes 'best_flownet_model.pth' relative to the working
# directory, not into checkpoint_dir. Look in both places and, if both hold a
# copy, use the most recently written one.
best_model_name = 'best_flownet_model.pth'
candidate_paths = [Path(best_model_name), checkpoint_dir / best_model_name]
existing_paths = [path for path in candidate_paths if path.is_file()]
if not existing_paths:
    searched = ", ".join(str(path) for path in candidate_paths)
    raise FileNotFoundError(f"No best-model checkpoint found (looked in: {searched})")
best_model_path = max(existing_paths, key=lambda path: path.stat().st_mtime)

checkpoint = torch.load(best_model_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
print(f"Loaded model from iteration {checkpoint['iteration']}")
print(f"Val Loss: {checkpoint['val_loss']:.4f}, Val EPE: {checkpoint['val_epe']:.4f}")

# Visualize predictions
visualize_flow_prediction(model, val_dataset, idx=0, device=device)
visualize_flow_prediction(model, val_dataset, idx=5, device=device)