In [12]:
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn as nn
import torch

from tqdm import tqdm
import os


In [13]:
from Models import HomographyRegressor, HomographyClassifier
from Models import save_checkpoint, load_latest_checkpoint
from Models import offsets_to_class_indices, classes_to_offsets
from Models import HomographyPairDataset, FixedSrcRandomDispDataset
from Models import classification_loss, rmse

In [14]:
def _train_one_epoch(model, dataloader, optimizer, criterion, device, num_classes, disp_range):
    """Run one training pass over ``dataloader``.

    Returns:
        Tuple ``(avg_loss, avg_rmse_hard, avg_rmse_soft)``, each averaged over samples
        (per-batch values are weighted by batch size).

    Raises:
        ValueError: if ``dataloader`` yields no batches (averages would be undefined).
    """
    model.train()
    total_loss = 0.0
    total_rmse_hard = 0.0
    total_rmse_soft = 0.0
    num_samples = 0

    for pairs, offsets in dataloader:
        pairs = pairs.to(device)
        offsets = offsets.to(device)
        batch_size = pairs.shape[0]

        # --- Forward --- (original notes logits as (B, 8, 21); TODO confirm for other num_classes)
        logits = model(pairs)

        # --- Loss + optimisation step ---
        loss = classification_loss(criterion, logits, offsets,
                                   disp_range=disp_range, num_classes=num_classes)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # --- Metrics: decode logits back to offsets (argmax vs. soft expectation) ---
        with torch.no_grad():
            pred_hard = classes_to_offsets(logits, disp_range, soft=False)
            pred_soft = classes_to_offsets(logits, disp_range, soft=True)
            rmse_hard_val = rmse(pred_hard, offsets).mean().item()
            rmse_soft_val = rmse(pred_soft, offsets).mean().item()

        # Weight by batch size so a short final batch does not skew the epoch average.
        total_loss += loss.item() * batch_size
        total_rmse_hard += rmse_hard_val * batch_size
        total_rmse_soft += rmse_soft_val * batch_size
        num_samples += batch_size

    if num_samples == 0:
        raise ValueError("dataloader yielded no batches; cannot compute epoch averages")

    return (total_loss / num_samples,
            total_rmse_hard / num_samples,
            total_rmse_soft / num_samples)


def nn_train_multi(model, dataloader, num_epochs, model_file_name, optimizer, criterion,
                   checkpoint_dir="checkpoints", num_classes=21, disp_range=(-16, 16),
                   checkpoint_every=1000):
    """Train an offset classifier with resumable checkpoints and TensorBoard logging.

    Training resumes from the epoch returned by ``load_latest_checkpoint`` and runs until
    ``num_epochs`` epochs are done. Per-epoch averages (loss, hard/soft RMSE) are shown on
    the progress bar and written to TensorBoard under ``<checkpoint_dir>/runs``. When the
    loop completes, ``model.state_dict()`` is saved to ``model_file_name``.

    On ``KeyboardInterrupt`` a checkpoint is saved and the function returns normally
    (the final model file is NOT written in that case).

    Args:
        model: network mapping a batch of image pairs to class logits; its parameters'
            device is used for all batches.
        dataloader: iterable of ``(pairs, offsets)`` batches.
        num_epochs: total epoch count, including epochs restored from a checkpoint.
        model_file_name: destination path for the final ``state_dict``.
        optimizer: optimizer for ``model``; its state is checkpointed too.
        criterion: loss module forwarded to ``classification_loss``.
        checkpoint_dir: directory for checkpoints and TensorBoard runs (created if missing).
        num_classes: number of classes per predicted offset, forwarded to the loss.
        disp_range: ``(min, max)`` displacement range used to encode/decode classes.
        checkpoint_every: save a checkpoint every this many epochs; the last epoch is
            always checkpointed. Values <= 0 disable the periodic saves.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    device = next(model.parameters()).device
    start_epoch = load_latest_checkpoint(checkpoint_dir, model, optimizer, device)
    writer = SummaryWriter(log_dir=os.path.join(checkpoint_dir, "runs"))

    epoch_pbar = tqdm(
        range(start_epoch, num_epochs),
        desc="Training",
        ncols=120,
        miniters=1,
        smoothing=0,
        dynamic_ncols=True,
        initial=start_epoch,   # resume the bar where the checkpoint left off
        total=num_epochs
    )

    # Epoch count reflected in the model weights (an epoch in progress counts as done).
    # Bound before the loop so the interrupt handler never sees an unbound name.
    last_epoch = start_epoch

    try:
        for epoch in epoch_pbar:
            last_epoch = epoch + 1

            avg_loss, avg_rmse_hard, avg_rmse_soft = _train_one_epoch(
                model, dataloader, optimizer, criterion, device, num_classes, disp_range)

            epoch_pbar.set_postfix({
                "loss": f"{avg_loss:.4f}",
                "rmse_hard": f"{avg_rmse_hard:.3f}px",
                "rmse_soft": f"{avg_rmse_soft:.3f}px"
            })

            # TensorBoard (x-axis = 1-based epoch number)
            writer.add_scalar("Loss/train", avg_loss, last_epoch)
            writer.add_scalar("RMSE/hard", avg_rmse_hard, last_epoch)
            writer.add_scalar("RMSE/soft", avg_rmse_soft, last_epoch)

            periodic = checkpoint_every > 0 and last_epoch % checkpoint_every == 0
            if periodic or last_epoch == num_epochs:
                save_checkpoint(checkpoint_dir, last_epoch, model, optimizer)

        # Save final model
        torch.save(model.state_dict(), model_file_name)
        print(f"✅ Final model saved: {model_file_name}")

    except KeyboardInterrupt:
        # Close the bar first so the messages below are not interleaved with it;
        # tqdm.close() is idempotent, so the call in `finally` is harmless.
        epoch_pbar.close()
        print(f"\n⚠️ Interrupted at epoch {last_epoch}")
        save_checkpoint(checkpoint_dir, last_epoch, model, optimizer)
        print("✅ Checkpoint saved")

    finally:
        epoch_pbar.close()
        writer.close()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()



In [None]:
# TRAIN
# Trains the classification variant of the homography network on every preprocessed image.
# Alternatives kept for quick experiments:
#   - model: HomographyRegressor(dropout_rate=0.1) with nn.MSELoss()
#   - data subset: get_images_from_names([...], PREPROCESSED_DIR) or
#     get_random_images(image_dir=PREPROCESSED_DIR, num_images=16)
#   - fixed source corners: FixedSrcRandomDispDataset(image=images[0], src_corners=...,
#     samples_per_epoch=samples_per_epoch, disp_range=disp_range)

from Generator import get_images_from_names, get_random_images, get_all_images

PREPROCESSED_DIR = "datasets/val2017_preprocessed"
num_epochs = 30000
samples_per_epoch = 128
batch_size = 128
num_classes = 21            # single source of truth for model head AND loss encoding
disp_range = (-16, 16)      # displacement range used to encode/decode offset classes
learning_rate = 1e-3        # torch.optim.Adam's default, made explicit
checkpoint_dir = "checkpoints_homography_classify"
model_file_name = "h_classify.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = HomographyClassifier(num_classes=num_classes, class_dim=8, dropout_rate=0.1).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

images = get_all_images(image_dir=PREPROCESSED_DIR)
# len() works for both lists and stacked arrays (truthiness of an array would be ambiguous).
if len(images) == 0:
    raise FileNotFoundError(f"No images found in {PREPROCESSED_DIR!r}")
print(f"📷 Loaded {len(images)} image(s) for training")

dataset = HomographyPairDataset(images, samples_per_epoch=samples_per_epoch)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=0,
    pin_memory=(device.type == "cuda"),  # pinned memory only helps host->GPU copies
    shuffle=True,
)

nn_train_multi(
    model=model,
    dataloader=dataloader,
    num_epochs=num_epochs,
    model_file_name=model_file_name,
    optimizer=optimizer,
    criterion=criterion,
    checkpoint_dir=checkpoint_dir,
    num_classes=num_classes,
    disp_range=disp_range,
)

Using device: cuda


In [16]:
# # visualization of classification results
#
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = HomographyClassifier(num_classes=21, class_dim=8).to(device)
# state = torch.load("checkpoints_homography_class_multi/checkpoint_epoch_5000.pth")["model_state_dict"]
# # state = torch.load("checkpoints_homography_clasify_oneImage/h_clasify_ep50000_I1.pth")
# model.load_state_dict(state)
# model.eval()
#
# img = cv2.imread("datasets/val2017_preprocessed/000000002299.jpg", cv2.IMREAD_GRAYSCALE)
# visualize_classification_result(model, img, soft_decode=False)
# # visualize_classification_result_dataloader(
# #     model=model,
# #     dataloader=dataloader,
# #     num_samples=3,
# #     soft_decode=True,
# #     device=device
# # )

In [17]:
# from Models import test_offset_class_conversion
#
# test_offset_class_conversion()