In [None]:
import os  # Module for interacting with the operating system (e.g., file and directory manipulation)
import gdown  # Third-party module to download files from Google Drive
import shutil  # Module to perform high-level file operations like copying, archiving, and removing

# Google Drive share links for the three datasets, keyed by dataset name.
download_links = {
    "tokyo_xs": "https://drive.google.com/file/d/1XmDqZBdEURc9NdyL4WdgIQMth-v7h477/view?usp=share_link",
    "sf_xs": "https://drive.google.com/file/d/1uoex2BWD9pOyJmz5rtZez0kyQvgZP-B4/view?usp=share_link",
    "gsv_xs": "https://drive.google.com/file/d/1nz-QAYU6EOQiVnEnDyJ30QmMTWOxrMv_/view?usp=share_link"
}

# All archives are downloaded and extracted under this directory.
data_directory = "data"
os.makedirs(data_directory, exist_ok=True)  # no-op if the directory already exists

for name, link in download_links.items():
    # Skip datasets that were already extracted so the cell can be re-run cheaply.
    # NOTE(review): assumes each archive extracts to data/<name>/ (later cells read
    # data/gsv_xs/train, data/sf_xs/val, data/tokyo_xs/test) — confirm.
    if os.path.isdir(os.path.join(data_directory, name)):
        print(f"{name} already present in {data_directory}, skipping download")
        continue

    print(f"Starting download for {name}")

    # Temporary archive location, e.g. "data/tokyo_xs.zip".
    zip_path = os.path.join(data_directory, f"{name}.zip")

    # fuzzy=True lets gdown extract the file id from a ".../view?usp=share_link" URL.
    downloaded = gdown.download(link, zip_path, fuzzy=True)
    if downloaded is None:
        # Some gdown versions report failures (e.g. quota/permission) by returning None
        # instead of raising; fail here with a clear message rather than in unpack_archive.
        raise RuntimeError(f"Download of {name} from {link} failed")

    # unpack_archive infers the archive format from the file extension.
    shutil.unpack_archive(zip_path, extract_dir=data_directory)

    # The archive is no longer needed once extracted.
    os.remove(zip_path)


In [None]:
# change directory
%cd /content
# Clone repository
!git clone https://github.com/Bertone-Fabio/mldl2024.git
# change directory
%cd /content/mldl2024
# install requirements
!pip install -r requirements.txt

In [None]:
# Import utility functions from 'utils' (custom module).
import utils

# Import dataset classes for training and testing.
from dataset.test_dataset import TestDataset
from dataset.train_dataset import TrainDataset

# PyTorch Lightning for simplifying training loops and multi-GPU training.
import pytorch_lightning as pl

# Core PyTorch library for tensor operations and neural networks.
import torch

# Checkpoint callback for saving model checkpoints during training.
from pytorch_lightning.callbacks import ModelCheckpoint

# Image transformations for preprocessing (e.g., resizing, normalizing).
from torchvision import transforms as tfm

# Pre-trained models for transfer learning from torchvision.
import torchvision.models

# Learning rate scheduler to adjust the learning rate during training.
from torch.optim import lr_scheduler

# Neural network module for defining layers, loss functions, etc.
from torch import nn

# Logging for tracking events and debugging.
import logging

# NumPy for array operations and numerical computations.
import numpy as np

# Datetime module for working with dates and timestamps.
from datetime import datetime

# ImageFolder for loading datasets organized in class-labeled folders.
from torchvision.datasets import ImageFolder

# DataLoader for batching and multi-threaded dataset loading.
from torch.utils.data import DataLoader

# Model

In [None]:
class VPRModel(pl.LightningModule):
    """Main model for Visual Place Recognition, built on PyTorch Lightning.

    The backbone is an ImageNet-pretrained ResNet-18 truncated after ``layer3``;
    its feature maps are turned into a global descriptor by the aggregator
    returned from ``utils.get_aggregator(agg_arch, agg_config)``.

    Validation/test datasets must expose ``num_db_images`` and yield all
    database images before all query images (see ``inference_epoch_end``);
    the test dataset must also expose ``root_dir``.
    """

    def __init__(self,
                 #---- Aggregator
                 agg_arch='avgpool',  # The architecture of the aggregator, default is 'avgpool'
                 agg_config={},  # Configuration dictionary for the aggregator

                 #---- Datasets
                 val_dataset=None,  # Validation dataset, can be None
                 test_dataset=None,  # Test dataset, can be None

                 #---- Train hyperparameters
                 lr=0.03,  # Learning rate for the optimizer
                 optimizer='sgd',  # Optimizer type: 'sgd', 'adam' or 'adamw'
                 weight_decay=1e-3,  # Weight decay (L2 regularization)
                 momentum=0.9,  # Momentum factor (used by SGD only)
                 milestones=[5, 10, 15],  # Milestones for MultiStepLR learning-rate decay
                 T_max=10,  # T_max for CosineAnnealingLR (that scheduler is currently commented out)
                 lr_mult=0.3,  # Decay factor (gamma) of the MultiStepLR scheduler

                 #----- Loss
                 loss_name='MultiSimilarityLoss',  # Name of the loss function
                 miner_name='MultiSimilarityMiner',  # Name of the miner (for online mining strategies)
                 miner_margin=0.1,  # Margin parameter for the miner
                 faiss_gpu=False,  # Forwarded to utils.get_validation_recalls for the nearest-neighbor search

                 #--- Images to save
                 num_preds_to_save=0  # Number of query predictions to save during evaluation
                 ):
        super().__init__()
        self.agg_arch = agg_arch
        self.agg_config = agg_config

        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.test_name = self.get_test_name()  # None when no test dataset is given

        # Training hyperparameters
        self.lr = lr
        self.optimizer = optimizer
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.milestones = milestones
        self.lr_mult = lr_mult
        self.T_max = T_max

        # Loss and miner parameters
        self.loss_name = loss_name
        self.miner_name = miner_name
        self.miner_margin = miner_margin

        # Record every __init__ argument in self.hparams (and in saved checkpoints)
        self.save_hyperparameters()

        self.loss_fn = utils.get_loss(loss_name)
        self.miner = utils.get_miner(miner_name, miner_margin)
        self.batch_acc = []  # Fraction of trivial pairs/triplets per batch, reset every epoch

        self.faiss_gpu = faiss_gpu

        # Descriptors collected during validation/test, consumed at epoch end
        self.validation_step_outputs = []
        self.test_step_outputs = []

        self.num_preds_to_save = num_preds_to_save
        self.total_steps = 0  # Training steps, persisted through checkpoints

        # Number of epochs completed by earlier runs; only used to number the folders
        # of saved prediction images. None keeps the legacy behaviour of deriving it
        # from a notebook-level `final_epoch` variable (see _predictions_save_dir).
        self.epoch_offset = None

        # ----------------------------------
        # Backbone: ImageNet-pretrained ResNet-18, cut after layer3
        pretrained_model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
        self.backbone = nn.Sequential(
            pretrained_model.conv1,
            pretrained_model.bn1,
            pretrained_model.relu,
            pretrained_model.maxpool,
            pretrained_model.layer1,
            pretrained_model.layer2,
            pretrained_model.layer3,
        )

        # Aggregator turning backbone feature maps into a global descriptor
        self.aggregator = utils.get_aggregator(agg_arch, agg_config)

    def forward(self, x):
        """Return the global descriptors for a batch of images."""
        x = self.backbone(x)
        x = self.aggregator(x)
        return x

    def configure_optimizers(self):
        """Build the optimizer selected by ``self.optimizer`` and a MultiStepLR scheduler.

        Raises:
            ValueError: if the optimizer name is not 'sgd', 'adamw' or 'adam'.
        """
        if self.optimizer.lower() == 'sgd':
            optimizer = torch.optim.SGD(self.parameters(),
                                        lr=self.lr,
                                        weight_decay=self.weight_decay,
                                        momentum=self.momentum)
        elif self.optimizer.lower() == 'adamw':
            optimizer = torch.optim.AdamW(self.parameters(),
                                          lr=self.lr,
                                          weight_decay=self.weight_decay)
        elif self.optimizer.lower() == 'adam':
            optimizer = torch.optim.Adam(self.parameters(),
                                         lr=self.lr,
                                         weight_decay=self.weight_decay)
        else:
            raise ValueError(f'Optimizer {self.optimizer} has not been added to "configure_optimizers()"')

        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=self.milestones, gamma=self.lr_mult)
        #scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.T_max)
        return [optimizer], [scheduler]

    def loss_function(self, descriptors, labels):
        """Compute the metric-learning loss for one batch and log the running batch accuracy.

        With a miner, ``batch_acc`` is the fraction of samples that never appear as an
        anchor in the mined pairs/triplets (i.e. contribute nothing to the loss).
        """
        if self.miner is not None:
            miner_outputs = self.miner(descriptors, labels)
            loss = self.loss_fn(descriptors, labels, miner_outputs)

            nb_samples = descriptors.shape[0]
            nb_mined = len(set(miner_outputs[0].detach().cpu().numpy()))
            batch_acc = 1.0 - (nb_mined / nb_samples)
        else:
            loss = self.loss_fn(descriptors, labels)
            batch_acc = 0.0
            if isinstance(loss, tuple):
                # Some losses mine internally and return (loss, batch_accuracy)
                loss, batch_acc = loss

        # Running mean over the current epoch (the list is reset in on_train_epoch_end)
        self.batch_acc.append(batch_acc)
        self.log('b_acc', sum(self.batch_acc) / len(self.batch_acc), prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        """One optimisation step on a batch of places.

        The training dataloader yields ``places`` of shape (BS, N, C, H, W) — BS places
        with N images each — which are flattened to BS*N images; labels are flattened
        to match.
        """
        places, labels = batch

        BS, N, ch, h, w = places.shape
        images = places.view(BS * N, ch, h, w)
        labels = labels.view(-1)

        descriptors = self(images)
        loss = self.loss_function(descriptors, labels)

        self.total_steps += 1
        self.log('loss', loss.item(), logger=True)
        self.log('step', self.total_steps, logger=True)
        return {'loss': loss}

    def on_train_epoch_end(self):
        """Lightning hook: reset the running batch accuracy so ``b_acc`` is per epoch."""
        self.batch_acc = []

    def on_training_epoch_end(self, training_step_outputs=None):
        """Backward-compatible alias of ``on_train_epoch_end``.

        Lightning never calls a hook with this name; it is kept only so existing
        callers keep working.
        """
        self.on_train_epoch_end()

    def _eval_step(self, batch, outputs):
        """Compute detached CPU descriptors for an evaluation batch and append them to ``outputs``."""
        places, _ = batch  # labels are not needed for evaluation
        descriptors = self(places).detach().cpu()
        outputs.append(descriptors)
        return descriptors

    def validation_step(self, batch, batch_idx):
        """Collect descriptors for one validation batch."""
        return self._eval_step(batch, self.validation_step_outputs)

    def test_step(self, batch, batch_idx):
        """Collect descriptors for one test batch."""
        return self._eval_step(batch, self.test_step_outputs)

    def on_validation_epoch_end(self):
        """Evaluate recalls on the concatenated validation descriptors, then clear the buffer."""
        all_descriptors = torch.cat(self.validation_step_outputs, dim=0)
        self.validation_step_outputs.clear()
        return self.inference_epoch_end(all_descriptors, self.val_dataset, 'val', self.num_preds_to_save)

    def on_test_epoch_end(self):
        """Evaluate recalls on the concatenated test descriptors, then clear the buffer."""
        all_descriptors = torch.cat(self.test_step_outputs, dim=0)
        self.test_step_outputs.clear()
        return self.inference_epoch_end(all_descriptors, self.test_dataset, 'test', self.num_preds_to_save)

    def on_save_checkpoint(self, checkpoint):
        """Persist the training-step counter in the checkpoint."""
        checkpoint['total_steps'] = self.total_steps

    def on_load_checkpoint(self, checkpoint):
        """Restore the training-step counter (0 for checkpoints that lack it)."""
        self.total_steps = checkpoint.get('total_steps', 0)

    def change_pred_to_save(self, num_pred):
        """Set how many query predictions are saved during evaluation."""
        self.num_preds_to_save = num_pred

    def get_test_name(self):
        """Return the dataset folder name of the test split, or None without a test dataset.

        For ``root_dir="/content/data/sf_xs/test"`` this returns ``"sf_xs"``; a trailing
        slash is tolerated.
        """
        if self.test_dataset is None:
            return None
        test_dir = str(self.test_dataset.root_dir).rstrip("/")
        return test_dir.split("/")[-2]

    def change_test_dataset(self, test_dataset):
        """Replace the test dataset and refresh the derived test name."""
        self.test_dataset = test_dataset
        self.test_name = self.get_test_name()

    def change_val_dataset(self, val_dataset):
        """Replace the validation dataset."""
        self.val_dataset = val_dataset

    def _predictions_save_dir(self, split):
        """Directory where prediction images for ``split`` are written.

        The run name is built from this model's own hyperparameters, not from notebook
        globals, so a model restored from a checkpoint writes under its real configuration.
        """
        run_name = f"{self.agg_arch}_{self.loss_name}_{self.optimizer}_{self.miner_name}"
        offset = getattr(self, "epoch_offset", None)
        if offset is None:
            # Legacy behaviour: the notebook sets `final_epoch` to the epoch a run ends at
            # and its trainers use max_epochs=5, so `final_epoch - 5` epochs came before.
            offset = globals().get("final_epoch", 5) - 5
        save_dir = f"images/{run_name}/epoch{self.current_epoch + offset}/{split}"
        if split != 'val':
            save_dir = f"{save_dir}/{self.test_name}"
        return save_dir

    def inference_epoch_end(self, all_descriptors, inference_dataset, split, num_preds_to_save=0):
        """Compute and log Recall@1/5 for one evaluation pass.

        Descriptors arrive in dataset order, which for this project is always all
        references followed by all queries:
        [R1, R2, ..., Rn, Q1, Q2, ..., Qm]
        They are split into references=[R1, ..., Rn] and queries=[Q1, ..., Qm] using
        ``inference_dataset.num_db_images``, and recall@K is computed against the
        dataset's ground truth.

        Args:
            all_descriptors: concatenated descriptors of the whole evaluation set.
            inference_dataset: dataset the descriptors were computed on.
            split: 'val' or 'test'; used in the metric names and the save directory.
            num_preds_to_save: number of query predictions to save as images.
        """
        num_db = inference_dataset.num_db_images
        database_descriptors = all_descriptors[:num_db]
        queries_descriptors = all_descriptors[num_db:]

        recalls_dict, predictions = utils.get_validation_recalls(
                                                eval_dataset=inference_dataset,
                                                db_desc=database_descriptors,
                                                q_desc=queries_descriptors,
                                                k_values=[1, 5],  # Recall@1 and Recall@5
                                                save_dir=self._predictions_save_dir(split),
                                                print_results=True,
                                                faiss_gpu=self.faiss_gpu,
                                                num_queries_to_save=num_preds_to_save
                                                )

        recalls_str = "".join([f"R@{k}: {rec:.2f} " for k, rec in recalls_dict.items()])
        logging.info(f"Epoch[{self.current_epoch:02d}]: " +
                     f"recalls: {recalls_str}")

        self.log(f'{split}/R@1', recalls_dict[1], prog_bar=False, logger=True)
        self.log(f'{split}/R@5', recalls_dict[5], prog_bar=False, logger=True)


In [None]:
# ---- Experiment parameters -------------------------------------------------
# Dataset locations: train on GSV-XS, validate and test on SF-XS
train_path = "/content/data/gsv_xs/train"
val_path = "/content/data/sf_xs/val"
test_path = "/content/data/sf_xs/test"

# Place sampling: images used per place, and the minimum a place needs to be included
img_per_place = min_img_per_place = 4
batch_size = 64

# Model / optimisation choices (also used to name log and checkpoint folders)
aggregation_method = "MixVPR"
loss = 'ContrastiveLoss'
optimizer = 'adamw'
weight_decay = 1e-4
miner = "MultiSimilarityMiner"
T_max = 5  # for CosineAnnealingLR, which is commented out in VPRModel.configure_optimizers

In [None]:
# First training run: build data pipeline, MixVPR model and Lightning trainer, then fit.

# Augment training images with RandAugment, then normalise with ImageNet statistics
train_transform = tfm.Compose([
    tfm.RandAugment(num_ops=3),
    tfm.ToTensor(),
    tfm.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = TrainDataset(
    root_dir=train_path,
    images_per_place=img_per_place,
    minimum_images_per_place=min_img_per_place,
    transform=train_transform,
)
# Evaluation datasets are built without passing a transform
val_dataset, test_dataset = (TestDataset(root_dir=split_dir) for split_dir in (val_path, test_path))

# Settings shared by every loader; only the training loader shuffles
loader_kwargs = dict(batch_size=batch_size, num_workers=2)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

# MixVPR settings; in_channels matches the 256-channel output of ResNet-18's layer3
mixvpr_config = dict(in_channels=256, in_h=14, in_w=14, out_channels=256,
                     mix_depth=4, mlp_ratio=1, out_rows=4)

model = VPRModel(
    agg_arch=aggregation_method,
    agg_config=mixvpr_config,
    val_dataset=val_dataset,
    test_dataset=test_dataset,
    # optimisation
    lr=0.0002,
    optimizer=optimizer,
    weight_decay=weight_decay,
    momentum=0.9,                 # only read when optimizer == 'sgd'
    milestones=[5, 10, 15, 25],   # MultiStepLR milestones
    T_max=T_max,
    lr_mult=0.3,                  # MultiStepLR gamma
    # loss and mining
    loss_name=loss,
    miner_name=miner,
    miner_margin=0.1,
    faiss_gpu=False,
    num_preds_to_save=5,
)

# Epoch this run ends at; VPRModel.inference_epoch_end also reads this global
final_epoch = 5
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
run_subdir = (f"{aggregation_method}_{loss}_{optimizer}_{miner}_weight_decay{weight_decay}"
              f"/final_epoch{final_epoch}_{timestamp}")
log_dir = f"logs/{run_subdir}"
checkpoint_dir = f"checkpoints/{run_subdir}"

# Keep the best checkpoint by validation Recall@1, plus the last one (weights only)
checkpoint_cb = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename='epoch({epoch:02d})_step({step:04d})_R1[{val/R@1:.4f}]_R5[{val/R@5:.4f}]',
    auto_insert_metric_name=False,
    monitor='val/R@1',
    mode='max',
    save_top_k=1,
    save_last=True,
    save_weights_only=True,
)

trainer = pl.Trainer(
    accelerator='gpu',
    devices=[0],
    precision=16,                      # mixed precision
    max_epochs=5,
    num_sanity_val_steps=0,
    check_val_every_n_epoch=1,
    reload_dataloaders_every_n_epochs=1,
    log_every_n_steps=20,
    default_root_dir=log_dir,
    callbacks=[checkpoint_cb],
)

trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Optional follow-ups: trainer.validate(model=model, dataloaders=val_loader)
#                      trainer.test(model=model, dataloaders=test_loader)


In [None]:
import os

# Set global Git configuration for the user email and name
!git config --global user.email "fabio.bertone1@gmail.com"  # Set your email for Git commits
!git config --global user.name "Bertone-Fabio"  # Set your Git username

# Pull the latest changes from the remote repository to ensure your local repo is up to date
!git pull

# Add all changes in the working directory to the staging area
!git add -A

# Commit the changes with a message that includes the checkpoint directory
!git commit -m f"add {checkpoint_dir}"

# Personal Access Token (PAT) for GitHub authentication
token = ""  # Insert your GitHub PAT here (avoid hardcoding tokens directly in the code)

# Configure the remote URL to include the token for authentication
repo_url = f"https://{token}@github.com/Bertone-Fabio/mldl2024.git"

# Set the remote URL for the repository to use the new token-authenticated URL
!git remote set-url origin {repo_url}

# Push the committed changes to the 'main' branch of the remote repository
!git push origin main

In [None]:
# Resume training from a saved (weights-only) checkpoint for 5 more epochs.

# Same augmentation/preprocessing pipeline as the first run
train_transform = tfm.Compose([
    tfm.RandAugment(num_ops=3),  # Apply RandAugment with 3 operations for data augmentation
    tfm.ToTensor(),  # Convert images to PyTorch tensors
    tfm.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize using ImageNet mean and std
])

train_dataset = TrainDataset(
    root_dir=train_path,  # Path to the training data directory
    images_per_place=img_per_place,  # Number of images per place
    minimum_images_per_place=min_img_per_place,  # Minimum images per place required to include in the dataset
    transform=train_transform  # Apply the defined transformations
)

# Validation and test datasets (no transform argument passed)
val_dataset = TestDataset(root_dir=val_path)
test_dataset = TestDataset(root_dir=test_path)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    num_workers=2,
    shuffle=True  # Shuffle the training data each epoch
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=2,
    shuffle=False
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=2,
    shuffle=False
)

# Checkpoint to resume from (update this path before every run)
checkpoint_path = '/content/mldl2024/checkpoints/MixVPR_MultiSimilarityLoss_adamw_MultiSimilarityMiner_Cosine_annealing_T_max5/final_epoch5_20240709-163808/last.ckpt'

# VPRModel.save_hyperparameters() stores the datasets passed to __init__ inside the
# checkpoint. Override them with the datasets built above so the ground truth used
# by inference_epoch_end matches val_loader/test_loader.
model = VPRModel.load_from_checkpoint(
    checkpoint_path,
    val_dataset=val_dataset,
    test_dataset=test_dataset,
)

final_epoch = 10  # Epoch this run ends at (5 already in the checkpoint + max_epochs=5 below)
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

# Continue under the run folder of the checkpoint being resumed
# (checkpoints/<run_name>/<final_epochN_timestamp>/last.ckpt) instead of rebuilding the
# name from notebook globals, which may describe a different run (e.g. another loss).
run_name = os.path.basename(os.path.dirname(os.path.dirname(checkpoint_path)))
log_dir = f"logs/{run_name}/final_epoch{final_epoch}_{timestamp}"
checkpoint_dir = f"checkpoints/{run_name}/final_epoch{final_epoch}_{timestamp}"

# Keep the best checkpoint by validation Recall@1, plus the last one (weights only)
checkpoint_cb = ModelCheckpoint(
    dirpath=checkpoint_dir,
    monitor='val/R@1',
    filename='epoch({epoch:02d})_step({step:04d})_R1[{val/R@1:.4f}]_R5[{val/R@5:.4f}]',
    auto_insert_metric_name=False,
    save_weights_only=True,
    save_top_k=1,
    save_last=True,
    mode='max',
)

trainer = pl.Trainer(
    accelerator='gpu', devices=[0],
    default_root_dir=log_dir,
    num_sanity_val_steps=0,
    precision=16,  # mixed precision
    max_epochs=5,
    check_val_every_n_epoch=1,
    callbacks=[checkpoint_cb],
    reload_dataloaders_every_n_epochs=1,
    log_every_n_steps=20,
)

trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Optional: trainer.test(model=model, dataloaders=test_loader, ckpt_path='best')


# Test on Tokyo-XS (validation on SF-XS)

In [None]:
# Dataset parameters: train on GSV-XS, validate on SF-XS, test on Tokyo-XS
train_path = "/content/data/gsv_xs/train"  # Path to the training dataset
img_per_place = 4  # Number of images per place to be used in training
min_img_per_place = 4  # Minimum number of images per place required to include in the dataset
val_path = "/content/data/sf_xs/val"  # Path to the validation dataset
test_path = "/content/data/tokyo_xs/test"  # Path to the test dataset
batch_size = 64  # Number of samples per batch for training and evaluation

# Same augmentation/preprocessing pipeline as the training runs
train_transform = tfm.Compose([
    tfm.RandAugment(num_ops=3),  # Apply RandAugment with 3 operations for data augmentation
    tfm.ToTensor(),  # Convert images to PyTorch tensors
    tfm.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize using ImageNet mean and std
])

train_dataset = TrainDataset(
    root_dir=train_path,
    images_per_place=img_per_place,
    minimum_images_per_place=min_img_per_place,
    transform=train_transform
)

# Validation and test datasets (no transform argument passed)
val_dataset = TestDataset(root_dir=val_path)
test_dataset = TestDataset(root_dir=test_path)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    num_workers=2,
    shuffle=True  # Shuffle the training data each epoch
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=2,
    shuffle=False
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=2,
    shuffle=False
)

# Checkpoint to evaluate (update this path before every run)
checkpoint_path = '/content/mldl2024/checkpoints/gem_MultiSimilarityLoss_adam_None/final_epoch10_20240531-120547/last.ckpt'

model = VPRModel.load_from_checkpoint(checkpoint_path)

# The checkpoint carries the datasets the model was created with (save_hyperparameters
# in VPRModel.__init__). Point the model at the datasets built above so the ground truth
# used in inference_epoch_end matches test_loader and val_loader.
model.change_test_dataset(test_dataset)
model.change_val_dataset(val_dataset)

final_epoch = 10  # Read as a global by VPRModel.inference_epoch_end to number image folders
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

# Name logs/checkpoints after the run folder of the evaluated checkpoint
checkpoint_dir_name = os.path.basename(os.path.dirname(checkpoint_path))
log_dir = f"logs/{checkpoint_dir_name}/test_{timestamp}"
checkpoint_dir = f"checkpoints/{checkpoint_dir_name}/test_{timestamp}"

# Keep the best checkpoint by validation Recall@1, plus the last one (weights only)
checkpoint_cb = ModelCheckpoint(
    dirpath=checkpoint_dir,
    monitor='val/R@1',
    filename='epoch({epoch:02d})_step({step:04d})_R1[{val/R@1:.4f}]_R5[{val/R@5:.4f}]',
    auto_insert_metric_name=False,
    save_weights_only=True,
    save_top_k=1,
    save_last=True,
    mode='max',
)

# Instantiate the PyTorch Lightning Trainer
trainer = pl.Trainer(
    accelerator='gpu', devices=[0],  # Use GPU acceleration on device 0
    default_root_dir=log_dir,  # Set the root directory for logs
    num_sanity_val_steps=0,  # Number of sanity validation steps before training starts
    precision=16,  # Use 16-bit precision to reduce memory usage and increase speed
    max_epochs=5,  # Maximum number of training epochs
    check_val_every_n_epoch=1,  # Run validation every epoch
    callbacks=[checkpoint_cb],  # Add checkpointing callback (additional callbacks can be added)
    reload_dataloaders_every_n_epochs=1,  # Reload dataloaders to shuffle data every epoch
    log_every_n_steps=20,  # Log metrics every 20 steps
    # fast_dev_run=True  # Uncomment for a quick development run (skips full training)
)

# If the model's current epoch is 0, set the final epoch to 15
if model.current_epoch == 0:
    final_epoch = 15

# Validate the model on the validation dataset
trainer.validate(model=model, dataloaders=val_loader)

# Test the model on the test dataset
trainer.test(model=model, dataloaders=test_loader)


# Test on Tokyo

In [None]:
# Evaluate the same checkpoint on the Tokyo XS test split.
test_path = "/content/data/tokyo_xs/test"

# Rebuild BOTH the dataset and the loader from `test_path`:
# `model.change_test_dataset` consumes the dataset, while `trainer.test`
# below consumes the loader. Reassigning `test_path` alone leaves both still
# pointing at the previously loaded test split.
test_dataset = TestDataset(root_dir=test_path)
test_loader = DataLoader(
    dataset=test_dataset,  # Tokyo XS test dataset
    batch_size=batch_size,  # Batch size
    num_workers=2,  # Number of workers for data loading
    shuffle=False,  # Do not shuffle test data
)

# Point the model's test hooks at the Tokyo dataset.
model.change_test_dataset(test_dataset)

# Define the final epoch number and generate a timestamp for directory naming
final_epoch = 10  # Define the final epoch number for naming directories
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")  # Generate a timestamp for unique directory names

# Extract the directory name from the checkpoint path
checkpoint_dir_name = os.path.basename(os.path.dirname(checkpoint_path))

# Create directories for logs and checkpoints based on the checkpoint name and timestamp
log_dir = f"logs/{checkpoint_dir_name}/test_{timestamp}"  # Directory for logs
checkpoint_dir = f"checkpoints/{checkpoint_dir_name}/test_{timestamp}"  # Directory for checkpoints

# Set up model checkpointing using PyTorch Lightning.
# Checkpoints are keyed on validation Recall@1; since this cell only runs
# validate()/test(), presumably nothing is written -- kept for parity with training.
checkpoint_cb = ModelCheckpoint(
    dirpath=checkpoint_dir,  # Directory to save checkpoints
    monitor='val/R@1',  # Metric to monitor for saving the best model
    filename='epoch({epoch:02d})_step({step:04d})_R1[{val/R@1:.4f}]_R5[{val/R@5:.4f}]',  # Checkpoint filename format
    auto_insert_metric_name=False,  # Prevent automatic insertion of metric name into the filename
    save_weights_only=True,  # Save only the model weights (not the entire model)
    save_top_k=1,  # Save only the best model (top-1)
    save_last=True,  # Save the last model checkpoint
    mode='max',  # Maximize the monitored metric (Recall@1)
)

# Instantiate the PyTorch Lightning Trainer
trainer = pl.Trainer(
    accelerator='gpu', devices=[0],  # Use GPU acceleration on device 0
    default_root_dir=log_dir,  # Set the root directory for logs
    num_sanity_val_steps=0,  # Number of sanity validation steps before training starts
    precision=16,  # Use 16-bit precision to reduce memory usage and increase speed
    max_epochs=5,  # Maximum number of training epochs
    check_val_every_n_epoch=1,  # Run validation every epoch
    callbacks=[checkpoint_cb],  # Add checkpointing callback (additional callbacks can be added)
    reload_dataloaders_every_n_epochs=1,  # Reload dataloaders every epoch
    log_every_n_steps=20,  # Log metrics every 20 steps
)

# Kept for parity with the training cells; `final_epoch` is not read in this cell.
if model.current_epoch == 0:
    final_epoch = 15

# Validation still uses `val_loader` from the data-setup cell (unchanged here).
trainer.validate(model=model, dataloaders=val_loader)

# Test on the freshly built Tokyo XS loader.
trainer.test(model=model, dataloaders=test_loader)


# Image visualization

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np

def visualize_images(folder_path, images_per_block, ncols=5):
    """
    Visualize the images under a folder (recursively) in fixed-size grids.

    Image files (.jpg / .jpeg / .png, case-insensitive) are collected from
    ``folder_path`` and all of its subfolders, sorted by full path, and shown
    ``images_per_block`` at a time, one matplotlib figure per block.

    Args:
        folder_path (str): Root folder to search for images.
        images_per_block (int): Number of images shown per figure. The grid
            has ``ncols`` columns and as many rows as needed to fit this many
            images, so no image in a block is dropped.
        ncols (int, optional): Number of grid columns. Defaults to 5, which
            with ``images_per_block=20`` gives a 4x5 grid in a 15x15-inch figure.

    Raises:
        ValueError: If ``images_per_block`` or ``ncols`` is not positive.
    """
    if images_per_block <= 0:
        raise ValueError(f"images_per_block must be positive, got {images_per_block}")
    if ncols <= 0:
        raise ValueError(f"ncols must be positive, got {ncols}")

    # Lower-case before matching so files such as "IMG_001.JPG" are not skipped.
    extensions = ('.jpg', '.jpeg', '.png')
    image_paths = sorted(
        os.path.join(root, file)
        for root, _, files in os.walk(folder_path)
        for file in files
        if file.lower().endswith(extensions)
    )

    # Derive the grid from images_per_block instead of a fixed 4x5 layout.
    ncols = min(ncols, images_per_block)
    nrows = (images_per_block + ncols - 1) // ncols  # ceiling division
    # 3 x 3.75 inches per cell reproduces the 15x15 figure for a 4x5 grid.
    figsize = (3 * ncols, 3.75 * nrows)

    for start in range(0, len(image_paths), images_per_block):
        block_paths = image_paths[start:start + images_per_block]

        # squeeze=False keeps `axes` 2-D even for a single row or column.
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize, squeeze=False)
        # .flat walks the grid row by row, matching the sorted path order.
        for ax_index, ax in enumerate(axes.flat):
            if ax_index < len(block_paths):
                image = plt.imread(block_paths[ax_index])
                # Single-channel images would otherwise be drawn with a false-colour map.
                ax.imshow(image, cmap='gray' if image.ndim == 2 else None)
            ax.axis('off')

        plt.tight_layout()
        plt.show()
        # Release the figure so figures do not accumulate across many blocks.
        plt.close(fig)

# Example: show every image under `image_folder`, 20 per figure.
image_folder = "/content/mldl2024/images/avgpool_MultiSimilarityLoss_adam_None/epoch10/test/tokyo_xs"  # Replace with your folder path
images_per_block = 20  # images per figure

visualize_images(folder_path=image_folder, images_per_block=images_per_block)