In [None]:
def nn_train_multi(model, num_epochs, batch_size, samples_per_epoch, model_file_name, images,
                   optimizer, criterion, checkpoint_dir="checkpoints",
                   checkpoint_every=100, keep_last=4):
    """Train ``model`` in mini-batches, resuming from and rotating checkpoints.

    Training data comes from ``HomographyPairDataset(images, samples_per_epoch)``
    (presumably ``samples_per_epoch`` generated image pairs -- TODO confirm),
    iterated in a fixed order (``shuffle=False``) in batches of ``batch_size``.
    The model is fit against the *negated* offsets yielded by the dataset.

    Relies on notebook-level names defined elsewhere: ``device``, ``extract_epoch``,
    ``HomographyPairDataset``, ``SummaryWriter``, ``DataLoader`` and ``tqdm``.

    Args:
        model: Network mapping a batch of pairs to offset predictions.
        num_epochs: Total epoch count; training runs from the resumed epoch up to it.
        batch_size: Mini-batch size for the ``DataLoader``.
        samples_per_epoch: Passed through to ``HomographyPairDataset``.
        model_file_name: Path the final ``model.state_dict()`` is saved to.
        images: Image collection passed through to ``HomographyPairDataset``.
        optimizer: Optimizer whose state is saved to / restored from checkpoints.
        criterion: Loss called as ``criterion(preds, targets)``. The logged
            "Loss/RMSE" is ``sqrt(mean batch loss)``, i.e. an RMSE only when
            ``criterion`` is an MSE loss.
        checkpoint_dir: Directory holding ``checkpoint_epoch_<N>.pth`` files and the
            TensorBoard ``runs`` subdirectory.
        checkpoint_every: Save a checkpoint every this many epochs (a checkpoint is
            always written after the final epoch as well).
        keep_last: How many of the most recent ``.pth`` files to keep in
            ``checkpoint_dir`` after each save.

    Raises:
        ValueError: If ``checkpoint_every`` or ``keep_last`` is < 1, or if the data
            loader yields no batches while epochs remain (per-epoch averages would
            otherwise divide by zero).
    """
    if checkpoint_every < 1:
        raise ValueError(f"checkpoint_every must be >= 1, got {checkpoint_every}")
    if keep_last < 1:
        raise ValueError(f"keep_last must be >= 1, got {keep_last}")

    os.makedirs(checkpoint_dir, exist_ok=True)

    def _sorted_checkpoints():
        # Lexical sort first, then a stable sort by epoch, so files whose
        # extract_epoch keys tie keep a deterministic (name) order.
        names = sorted(f for f in os.listdir(checkpoint_dir) if f.endswith(".pth"))
        return sorted(names, key=extract_epoch)

    # Resume from the checkpoint that sorts last by extract_epoch, if any.
    start_epoch = 0
    checkpoints = _sorted_checkpoints()
    if checkpoints:
        latest_ckpt = os.path.join(checkpoint_dir, checkpoints[-1])
        checkpoint = torch.load(latest_ckpt, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"]
        print(f"‚úÖ Resuming from checkpoint: {latest_ckpt} (epoch {start_epoch})")
    else:
        print("üöÄ Starting training from scratch.")

    # Pinned host memory only helps (and only avoids a PyTorch warning) when the
    # batches are copied to a CUDA device; it also enables async copies.
    use_cuda = torch.device(device).type == "cuda"

    dataset = HomographyPairDataset(images, samples_per_epoch)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                            num_workers=0, pin_memory=use_cuda)
    num_batches = len(dataloader)
    if num_batches == 0 and start_epoch < num_epochs:
        raise ValueError(
            f"DataLoader yields no batches (samples_per_epoch={samples_per_epoch}, "
            f"batch_size={batch_size}); cannot compute per-epoch metrics"
        )

    writer = SummaryWriter(log_dir=os.path.join(checkpoint_dir, "runs"))
    try:
        for epoch in range(start_epoch, num_epochs):
            model.train()
            epoch_loss = 0.0
            epoch_mae = 0.0

            # Per-epoch batch progress bar; cleared when the epoch ends (leave=False).
            progress_bar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}", ncols=120, leave=False)

            for batch_pairs, batch_offsets in progress_bar:
                batch_pairs = batch_pairs.to(device, non_blocking=use_cuda)
                batch_offsets = batch_offsets.to(device, non_blocking=use_cuda)
                # The network is trained to predict the negated offsets.
                targets = -batch_offsets

                # Forward
                preds = model(batch_pairs)
                loss = criterion(preds, targets)

                # Backward
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Metrics (no autograd graph needed for MAE)
                batch_loss = loss.item()
                with torch.no_grad():
                    mae = torch.mean(torch.abs(preds - targets)).item()
                epoch_loss += batch_loss
                epoch_mae += mae

                progress_bar.set_postfix(loss=f"{batch_loss:.6f}", mae=f"{mae:.4f}")

            # Per-epoch averages over batches; RMSE is derived from the mean loss.
            avg_loss = epoch_loss / num_batches
            avg_mae = epoch_mae / num_batches
            avg_rmse = np.sqrt(avg_loss)

            writer.add_scalar("Loss/RMSE", avg_rmse, epoch)
            writer.add_scalar("Error/MAE", avg_mae, epoch)

            # Periodic checkpoint (always after the final epoch), then prune old ones.
            if (epoch + 1) % checkpoint_every == 0 or (epoch + 1) == num_epochs:
                checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch + 1}.pth")
                torch.save({
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                }, checkpoint_path)

                checkpoints = _sorted_checkpoints()
                excess = max(len(checkpoints) - keep_last, 0)
                for old_name in checkpoints[:excess]:
                    os.remove(os.path.join(checkpoint_dir, old_name))
    finally:
        # Flush/close TensorBoard even if training is interrupted or fails.
        writer.close()

    # Save final model weights (state_dict only, no optimizer state).
    torch.save(model.state_dict(), model_file_name)
    print(f"‚úÖ Final model saved: {model_file_name}")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
def nn_train_single(model, num_epochs, model_file_name, img, optimizer, criterion, checkpoint_dir="checkpoints",
                    checkpoint_every=1000, keep_last=4):
    """Train ``model`` one freshly generated pair per epoch, with checkpointing.

    Every epoch calls ``generate_pair`` (64px window, 16px margin, displacement
    range (-16, 16)) on ``img`` (or on a random element of it when ``img`` is a
    list) and performs a single optimizer step on that one pair (batch size 1).
    The model is fit against the *negated* offsets returned by ``generate_pair``.

    Relies on notebook-level names defined elsewhere: ``device``, ``extract_epoch``,
    ``generate_pair``, ``SummaryWriter``, ``tqdm`` and ``random``.

    Args:
        model: Network mapping a pair tensor to offset predictions.
        num_epochs: Total epoch count; training runs from the resumed epoch up to it.
        model_file_name: Path the final ``model.state_dict()`` is saved to.
        img: A single image, or a list of images to sample from each epoch.
        optimizer: Optimizer whose state is saved to / restored from checkpoints.
        criterion: Loss called as ``criterion(preds, targets)``. The logged
            "Loss/RMSE" is ``sqrt(loss + 1e-8)``, an RMSE only for an MSE criterion.
        checkpoint_dir: Directory holding ``checkpoint_epoch_<N>.pth`` files and the
            TensorBoard ``runs`` subdirectory.
        checkpoint_every: Save a checkpoint every this many epochs (a checkpoint is
            always written after the final epoch as well).
        keep_last: How many of the most recent ``.pth`` files to keep in
            ``checkpoint_dir`` after each save.

    Raises:
        ValueError: If ``img`` is an empty list, or ``checkpoint_every`` /
            ``keep_last`` is < 1.
    """
    if isinstance(img, list) and not img:
        raise ValueError("img must be an image or a non-empty list of images")
    if checkpoint_every < 1:
        raise ValueError(f"checkpoint_every must be >= 1, got {checkpoint_every}")
    if keep_last < 1:
        raise ValueError(f"keep_last must be >= 1, got {keep_last}")

    os.makedirs(checkpoint_dir, exist_ok=True)

    def _sorted_checkpoints():
        # Lexical sort first, then a stable sort by epoch, so files whose
        # extract_epoch keys tie keep a deterministic (name) order.
        names = sorted(f for f in os.listdir(checkpoint_dir) if f.endswith(".pth"))
        return sorted(names, key=extract_epoch)

    # Resume from the checkpoint that sorts last by extract_epoch, if any.
    start_epoch = 0
    checkpoints = _sorted_checkpoints()
    if checkpoints:
        latest_ckpt = os.path.join(checkpoint_dir, checkpoints[-1])
        checkpoint = torch.load(latest_ckpt, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        start_epoch = checkpoint["epoch"]
        print(f"‚úÖ Resuming from checkpoint: {latest_ckpt} (epoch {start_epoch})")
    else:
        print("üöÄ Starting training from scratch.")

    # TensorBoard logger and a progress bar over epochs.
    writer = SummaryWriter(log_dir=os.path.join(checkpoint_dir, "runs"))
    progress_bar = tqdm(range(start_epoch, num_epochs), desc="Training", ncols=100)

    try:
        for epoch in progress_bar:
            model.train()

            # A new random pair every epoch.
            source_img = random.choice(img) if isinstance(img, list) else img
            pair, offsets, *_ = generate_pair(
                img=source_img,
                window_size=64,
                margin=16,
                disp_range=(-16, 16)
            )

            # HWC -> CHW, add a batch dim of 1 (original notes: 1x2x64x64 and 1x8).
            pair = torch.from_numpy(pair).permute(2, 0, 1).unsqueeze(0).to(device).float()
            offsets = torch.from_numpy(offsets.flatten()).unsqueeze(0).to(device).float()
            # The network is trained to predict the negated offsets.
            targets = -offsets

            # Forward
            preds = model(pair)
            loss = criterion(preds, targets)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Logging metrics; no autograd graph is needed for these.
            with torch.no_grad():
                rmse = torch.sqrt(loss + 1e-8)
                mae = torch.mean(torch.abs(preds - targets))
            writer.add_scalar("Loss/RMSE", rmse.item(), epoch)
            writer.add_scalar("Error/MAE", mae.item(), epoch)

            progress_bar.set_description(f"Epoch {epoch + 1}/{num_epochs}")
            progress_bar.set_postfix(loss=f"{loss.item():.6f}")

            # Periodic checkpoint (always after the final epoch), then prune old ones.
            if (epoch + 1) % checkpoint_every == 0 or (epoch + 1) == num_epochs:
                checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch + 1}.pth")
                torch.save({
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                }, checkpoint_path)

                checkpoints = _sorted_checkpoints()
                excess = max(len(checkpoints) - keep_last, 0)
                for old_name in checkpoints[:excess]:
                    os.remove(os.path.join(checkpoint_dir, old_name))
    finally:
        # Release the logger and progress bar even if training is interrupted.
        writer.close()
        progress_bar.close()

    # Save final model weights (state_dict only, no optimizer state).
    torch.save(model.state_dict(), model_file_name)
    print(f"‚úÖ Final model saved: {model_file_name}")

    # Drop cached CUDA allocator blocks and report what is still allocated.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(f"üßπ GPU memory cleared. Current allocated: {torch.cuda.memory_allocated(device) / 1e9:.2f} GB")


def check_gpu_memory():
    """Print a CUDA memory summary (in GB, 1 GB = 1e9 bytes) for the global ``device``.

    Reports bytes allocated to tensors, bytes reserved by PyTorch's caching
    allocator, the device's total memory, and "Free" computed as
    ``total - reserved``. That figure ignores memory held outside this process's
    allocator (CUDA context, other processes), so it is an upper bound.
    Prints "CUDA not available" when no CUDA device is present.
    """
    if not torch.cuda.is_available():
        print("CUDA not available")
        return

    scale = 1e9
    allocated_gb = torch.cuda.memory_allocated(device) / scale
    reserved_gb = torch.cuda.memory_reserved(device) / scale
    total_gb = torch.cuda.get_device_properties(device).total_memory / scale

    report = [
        "GPU Memory Status:",
        f"  Allocated: {allocated_gb:.2f} GB",
        f"  Reserved:  {reserved_gb:.2f} GB",
        f"  Total:     {total_gb:.2f} GB",
        f"  Free:      {total_gb - reserved_gb:.2f} GB",
    ]
    for line in report:
        print(line)


def clear_gpu_memory():
    """Empty PyTorch's CUDA cache, synchronize, then print a memory summary.

    Calls ``torch.cuda.empty_cache()`` followed by ``torch.cuda.synchronize()``,
    prints a confirmation, and delegates the report to ``check_gpu_memory``.
    Prints "CUDA not available" when no CUDA device is present.
    """
    if not torch.cuda.is_available():
        print("CUDA not available")
        return

    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    print("‚úÖ GPU cache cleared")
    check_gpu_memory()

In [None]:
# # TRAIN REGRESSOR
#
# num_epochs = 30000
# batch_size = 32
# learning_rate = 1e-4
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")
#
# model = HomographyRegressor(dropout_rate=0.1).to(device)
# criterion = nn.MSELoss()
# optimizer = optim.Adam(model.parameters(), lr=learning_rate)
#
# # image_names = [
# #     "000000002299.jpg",
# #     # "000000000285.jpg",
# #     # "000000000632.jpg",
# # ]
# # images = get_images_from_names(image_names, PREPROCESSED_DIR)
# images = get_random_images(image_dir=PREPROCESSED_DIR)
# print(f"üì∑ Loaded {len(images)} image(s) for training")
#
# # nn_train_single(
# #     model=model,
# #     num_epochs=num_epochs,
# #     model_file_name=f"h_regressor_ep{num_epochs}_I{len(images)}.pth",
# #     img=images[0] if len(images) == 1 else images,
# #     optimizer=optimizer,
# #     criterion=criterion,
# #     checkpoint_dir="checkpoints_homography_regressor_oneImage"
# # )
#
# nn_train_multi(
#     model=model,
#     num_epochs=num_epochs,
#     batch_size=batch_size,
#     samples_per_epoch=64,
#     model_file_name=f"h_regressor_multi.pth",
#     images=images,
#     optimizer=optimizer,
#     criterion=criterion,
#     checkpoint_dir="checkpoints_homography_regressor_multi"
# )