# Required Imports

In [1]:
import os
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models

import numpy as np
from numpy import dot
from numpy.linalg import norm

from tqdm import tqdm

# Imports PIL module
from PIL import Image

## Parameters

In [3]:
# Set a fixed random seed for reproducibility
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)  # Seed for NumPy random number generator
torch.manual_seed(RANDOM_SEED)  # Seed for PyTorch random number generator

# Set the depth of the network
depth = 32  # Initial depth of the Siamese network
img_channels = 3  # Number of color channels (3 for RGB)
img_w = 172  # Image width
img_h = 128  # Image height

# Determine the computing device (GPU if available, otherwise CPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'device being used: {device}')

# Configuration settings
n_classes = 81 # Number of identities in the dataset - range [0, 80]
n_epochs = 50  # Number of training epochs
train_batch_size = 32  # Batch size for training
val_batch_size = 1  # Batch size for validation
test_batch_size = 1  # Batch size for testing
learning_rate = 0.0005  # Learning rate for the optimizer
backbone = 'siameseNet'  # Backbone network architecture ['siameseNet', 'VGG16', 'VGG19', 'denseNet', 'MobileNetV3', 'efficientNetB0', 'ViT16']

device being used: cpu


## Load Probe and Gallery sets

In [4]:
# Define ID
def extract_id_from_imagename(filename):
    """ Extract the ID from the filename, including the number and the trailing 'A' or 'B'. """
    #match = re.search(r'_(\d+)[mf]?([AB])\_', filename) # e.g., 00001A
    #return f"{match.group(1)}{match.group(2)}" if match else None
    match = re.search(r'(\d+[mf][AB])\_', filename) # e.g., 00001mA
    return match.group(1) if match else None

# Define a custom dataset class for training, specific to handling probe and gallery images
class ProbeGalleryTrain(Dataset):
    def __init__(self, data_root, transform=None):
        # Initialize the dataset with the root directory of the data and any transformations to be applied
        self.data_root = data_root  # Store the directory where the dataset is located
        self.data = self.read_folder()  # Read image paths and labels from the directory
        self.samples = self.create_samples()  # Create matched and unmatched pairs of images (samples)
        self.transform = transform  # Store the transformations to be applied to each image

    def __len__(self):
        # Return the total number of image pairs in the dataset
        return len(self.samples)

    def __getitem__(self, idx):
        # Fetch a specific sample by its index
        im1_pth, im2_pth, is_match, probe_target, gallery_target = self.samples[idx]  # Extract details of the indexed pair
        im1 = Image.open(im1_pth)  # Load the first image
        im2 = Image.open(im2_pth)  # Load the second image

        if self.transform:
            # Apply specified transformations to both images, if any
            im1 = self.transform(im1)
            im2 = self.transform(im2)

        # Return the pair of images along with their match status and associated targets
        return [im1, im2, is_match, probe_target, gallery_target]

    def read_folder(self):
        # Read and store the file paths and labels of images from the dataset directory
        paths = []  # List to hold paths of images
        labels = []  # List to hold corresponding labels ('Probe' or 'Gallery')

        # Traverse the dataset directory
        for dirpath, dirnames, filenames in os.walk(self.data_root):
            # Select only image files (.jpg, .jpeg, .png)
            files = [f for f in filenames if f.split('.')[-1] in ['jpg', 'jpeg', 'png']]
            for item in files:
                # Extract the label (directory name) and append both path and label to the lists
                label = dirpath.split('/')[-1]
                paths.append([os.path.join(dirpath, item), label])
                if label not in labels:
                    labels.append(label)

        return paths

    def create_samples(self):
        # Generate pairs of images (probes paired with each gallery image) with matching status
        probes = [x[0] for x in self.data if x[1] == 'Probe']  # Extract paths of 'Probe' images
        gallery = [x[0] for x in self.data if x[1] == 'Gallery']  # Extract paths of 'Gallery' images

        samples = []
        # Create pairs of probe and gallery images, and determine if they match
        for probe in probes:
            probe_id = extract_id_from_imagename(probe)
            probe_id = int(probe_id[:-2])-1 # IDs range [1, 81] -> [0, 80]
            for item in gallery:
                item_id = extract_id_from_imagename(item)
                item_id = int(item_id[:-2])-1 # IDs range [1, 81] -> [0, 80]
                # Match status: 1 if IDs are similar (ignoring the last two characters), else -1
                samples.append([probe, item, 1 if probe_id == item_id else -1, probe_id, item_id])

        return samples


# create the val/test dataset
class ProbeGalleryValTest(Dataset):
    def __init__(self, data_root, transform=None):
        # Initialize the dataset object with the root directory of the data and optional transformations
        self.data_root = data_root  # Store the root directory where the data is located
        self.data = self.read_folder()  # Read the data from the folder
        self.samples = self.create_samples()  # Create samples from the data
        self.transform = transform  # Store any transformations to be applied to the images

    def __len__(self):
        return len(self.samples)  # Return the number of samples in the dataset

    def __getitem__(self, idx):
      # Retrieve the sample data at the given index 'idx'
      probe_pth, gallery_pths, is_match, probe_target, gallery_targets = self.samples[idx]  # Get paths and target labels

      # Load the probe image from its path
      probe = Image.open(probe_pth)  # Open the probe image file

      # Initialize an empty list to store gallery images
      gallery = []
      # Iterate over each path in the gallery paths list
      for item in gallery_pths:
          # Open each gallery image and append it to the gallery list
          gallery.append(Image.open(item))

      # Check if a transform is set to be applied to the images
      if self.transform:
          # Apply the transform to the probe image
          probe = self.transform(probe)

          # Apply the transform to each image in the gallery list
          for i in range(len(gallery)):
              gallery[i] = self.transform(gallery[i])

      # Return a list containing the probe image, gallery images, target labels, and their respective paths
      return [probe, gallery, is_match, probe_pth, gallery_pths, probe_target, gallery_targets]

    def read_folder(self):
        paths = []  # Initialize a list to store image paths
        labels = []  # Initialize a list to store labels

        # Retrieve all items in the given root directory
        for dirpath, dirnames, filenames in os.walk(self.data_root):
            # Filter for images in .jpg, .jpeg, and .png format
            files = [f for f in filenames if f.split('.')[-1] in ['jpg', 'jpeg', 'png']]

            # Build file paths and labels
            for item in files:
                label = dirpath.split('/')[-1]  # Extract the label from the directory name (e.g., probe or gallery)
                if label not in labels:
                    labels.append(label)

                paths.append([f'{dirpath}/{item}', label])  # Append the file path and label (e.g., gallery or probe)
        return paths

    def create_samples(self):
        # Create a list of samples for the dataset
        probes = [x[0] for x in self.data if x[1] == 'Probe']  # Get paths of images labeled as 'Probe'
        gallery = [x[0] for x in self.data if x[1] == 'Gallery']  # Get paths of images labeled as 'Gallery'

        samples = []
        # Pair each 'Probe' image with every 'Gallery' image and assign a target
        for probe in probes:
          is_match = []
          gallery_targets = []
          probe_id = extract_id_from_imagename(probe)  # Extract ID from probe image name
          probe_id = int(probe_id[:-2])-1 # IDs range [1, 81] -> [0, 80]

          for item in gallery:
            item_id = extract_id_from_imagename(item)  # Extract ID from gallery image name
            item_id = int(item_id[:-2])-1 # IDs range [1, 81] -> [0, 80]
            # Assign a target of 1 if IDs match (ignoring the last two characters), otherwise -1
            is_match.append(1 if probe_id == item_id else -1)
            gallery_targets.append(item_id)

          samples.append([probe, gallery, is_match, probe_id, gallery_targets])
        return samples

In [5]:
# Define the transformation to be applied to the images (converting them to PyTorch tensors)
if backbone == 'ViT16':
    transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ])
else:
    transform = transforms.Compose([
        transforms.Resize(128), 
        transforms.ToTensor(),
    ])

# Load the training, validation, and test datasets from specified directories and apply the defined transformation
train_dataset = ProbeGalleryTrain('./output/ImageData/D1/', transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)

val_dataset = ProbeGalleryValTest('./output/ImageData/D2v/', transform=transform)
val_dataloader = DataLoader(val_dataset, batch_size=val_batch_size, shuffle=False)

test_dataset = ProbeGalleryValTest('./output/ImageData/D2t/', transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)



# Neural Network Models

In [6]:
# Define a Siamese network class inheriting from nn.Module
class BSN(nn.Module):
    def __init__(self, backbone, img_channels, depth, img_w, img_h, n_classes):
        super(BSN, self).__init__()

        self.backbone = backbone
        
        if self.backbone == 'siameseNet':
            # Initialize convolutional and pooling layers
            # First convolutional layer with specified input channels and depth
            self.conv1 = nn.Conv2d(img_channels, depth, kernel_size=3, padding=1)
            # First pooling layer, reduces spatial dimensions by half
            self.pool1 = nn.MaxPool2d(2, 2)
            # Second convolutional layer, doubling the depth
            self.conv2 = nn.Conv2d(depth, depth*2, kernel_size=3, padding=1)
            # Second pooling layer, further reducing spatial dimensions
            self.pool2 = nn.MaxPool2d(2, 2)

            # Calculate the output size after the convolutional and pooling layers
            output_size = depth*2 * img_h//4 * img_w//4
            
            

        if self.backbone == 'VGG19':
            self.model = models.vgg19_bn().features # VGG19 with batch normalization

            # Create a dummy input tensor of the correct size
            dummy_input = torch.randn(1, img_channels, img_h, img_w)

            # Run a forward pass through the convolutional layers
            output = self.model(dummy_input)

            # Calculate the output size
            output_size = output.view(output.size(0), -1).size(1)
        
        if self.backbone == 'VGG16':
            self.model = models.vgg16_bn().features # VGG16 with batch normalization

            # Create a dummy input tensor of the correct size
            dummy_input = torch.randn(1, img_channels, img_h, img_w)

            # Run a forward pass through the convolutional layers
            output = self.model(dummy_input)

            # Calculate the output size
            output_size = output.view(output.size(0), -1).size(1)

        if self.backbone == 'denseNet':
            self.model = models.densenet121().features # DenseNet121 with batch normalization
            
            # Create a dummy input tensor of the correct size
            dummy_input = torch.randn(1, img_channels, img_h, img_w)

            # Run a forward pass through the convolutional layers
            output = self.model(dummy_input)

            # Calculate the output size
            output_size = output.view(output.size(0), -1).size(1)

        if self.backbone == "MobileNetV3":
            self.model = models.mobilenet_v3_small().features # MobileNetV3 with batch normalization

            # Create a dummy input tensor of the correct size
            dummy_input = torch.randn(1, img_channels, img_h, img_w)

            # Run a forward pass through the convolutional layers
            output = self.model(dummy_input)

            # Calculate the output size
            output_size = output.view(output.size(0), -1).size(1)
            
        if self.backbone == 'efficientNetB0':
            self.model = models.efficientnet_b0().features # VGG16 with batch normalization

            # Create a dummy input tensor of the correct size
            dummy_input = torch.randn(1, img_channels, img_h, img_w)

            # Run a forward pass through the convolutional layers
            output = self.model(dummy_input)

            # Calculate the output size
            output_size = output.view(output.size(0), -1).size(1)

        if self.backbone == "ViT16":
            self.model = models.vit_b_16() # ViT16
            self.model.heads = nn.Identity() # remove the classification head

            # Create a dummy input tensor of the correct size
            dummy_input = torch.randn(1, img_channels, 224, 224)

            # Run a forward pass through the convolutional layers
            output = self.model(dummy_input)

            # Calculate the output size
            output_size = output.view(output.size(0), -1).size(1)
        
        self.fc1 = nn.Linear(output_size, 256)
        self.fc2 = nn.Linear(256, 256)
        self.classifier = nn.Linear(256,n_classes)
        self.dropout = nn.Dropout()

    # Define the forward pass for one branch of the Siamese network
    def forward_one(self, x):
        if self.backbone == 'siameseNet':
            # Apply the first convolutional layer followed by ReLU activation and pooling
            x = self.pool1(F.relu(self.conv1(x)))
            # Apply the second convolutional layer, ReLU, and pooling
            x = self.pool2(F.relu(self.conv2(x)))
            # Flatten the output for the fully connected layer
            x = x.view(x.size(0), -1)
            # Apply the first fully connected layer with ReLU activation
            x = self.dropout(F.relu(self.fc1(x)))
            # Apply the second fully connected layer
            x = self.dropout(F.relu(self.fc2(x)))
            y = self.classifier(x)
        else:
            # Apply the selected model to the input
            x = self.model(x)
            # flatten output 
            x = x.view(x.size(0), -1)
            # Apply classifier component
            x = self.dropout(F.relu(self.fc1(x)))
            x = self.dropout(F.relu(self.fc2(x)))
            y = self.classifier(x)
        
        return x, y

    # Define the forward pass for the whole Siamese network
    def forward(self, input1, input2):
        # Process each input through the network
        output1, output_class1 = self.forward_one(input1)
        output2, output_class2= self.forward_one(input2)
        return output1, output_class1, output2, output_class2

# Training and Testing functions

In [7]:
def compute_ap(matches):
    """Compute the average precision (AP) given ranks of positive images and the number of positive images.
    Args:
        matches (list): The ranks of positive images sorted in ascending order.
    Returns:
        float: The average precision for the given data.
    """
    if len(matches) == 0:
        return 0.0
    
    matched = 0
    ap = 0.0
    
    for rank, match in enumerate(matches, start=1):
        if match:
            matched += 1
            ap += matched / rank

    return ap / matched

def testing_step(model, test_dataloader):
    # Switch the model to evaluation mode. This turns off specific layers/features like dropout.
    model.eval()

    # Initialize counters for correct predictions at different ranks
    correct1 = 0  # Correct predictions at rank 1
    correct2 = 0
    correct3 = 0
    correct4 = 0
    correct5 = 0  # Correct predictions within the top 5
    correct6 = 0
    correct7 = 0
    correct8 = 0
    correct9 = 0
    correct10 = 0  # Correct predictions within the top 10
    samples = 0  # Total number of samples processed

    # Set the rank limit for evaluation
    K = 10
    AP_scores = []  # To store Average Precision scores for each query
    # Initialize an empty list to store the ranking results
    rank_matrix =[]

    # Disable gradient calculation for efficiency and to prevent changes to the model
    with torch.no_grad():
        # Iterate over the test dataloader
        for data in tqdm(test_dataloader, leave=False, total=len(test_dataloader)):
            # Unpack the data
            probe, gallery, is_match, probe_pth, gallery_pths, probe_target, gallery_targets = data

            # Initialize a list to store similarity scores
            scores = []
            # Concatenate gallery images into a batch
            gallery = torch.cat(gallery, axis=0)

            # Repeat the probe image to match the number of gallery images
            probe_batch = probe.repeat(gallery.shape[0], 1, 1, 1)

            # Get model outputs for the probe and gallery batches
            output1, output_class1, output2, output_class2 = model(probe_batch.to(device), gallery.to(device))

            # Calculate similarity scores for each pair of probe and gallery images
            for i in range(len(probe_batch)):
                a = output1[i].detach().cpu().numpy()
                b = output2[i].detach().cpu().numpy()
                # Compute cosine similarity
                cos_sim = dot(a, b)/(norm(a)*norm(b))

                # Append the similarity score and corresponding gallery path
                scores.append([cos_sim, gallery_pths[i][0]])
            
            samples += 1

            # Sort the scores in descending order to rank the gallery images
            topk = sorted(scores, key=lambda element: (element[0]), reverse=True)
            # Store the probe path and its corresponding ranked gallery paths
            rank_matrix.append([probe_pth[0], topk])

            # Compute the Average Precision (AP) for the query
            pos_ranks = []
            probe_id = extract_id_from_imagename(probe_pth[0])
            probe_id = int(probe_id[:-2]) - 1
            for i, score in enumerate(topk):
                gallery_id = extract_id_from_imagename(score[1])
                gallery_id = int(gallery_id[:-2]) - 1
                if probe_id == gallery_id:
                    pos_ranks.append(i)
            if pos_ranks:
                AP = compute_ap(pos_ranks)  # Assuming each probe_id should match with a unique gallery_id
                AP_scores.append(AP)  # Assuming each probe is supposed to match with exactly one gallery

        # Compute correct predictions for rank 1, 5, and 10
        for idx, probe in enumerate(rank_matrix):
            # Extract the probe ID
            probe_id=extract_id_from_imagename(probe[0])
            probe_id = int(probe_id[:-2])-1 # IDs range [1, 81] -> [0, 80]
            for i in range(K):
                # Extract the gallery ID
                gallery_id=extract_id_from_imagename(probe[1][i][1])
                gallery_id = int(gallery_id[:-2])-1 # IDs range [1, 81] -> [0, 80]
                # Check if the IDs match
                if probe_id == gallery_id:
                    # Update the correct counters based on the rank
                    if i == 0:
                        correct1 += 1
                        correct2 += 1
                        correct3 += 1
                        correct4 += 1
                        correct5 += 1
                        correct6 += 1
                        correct7 += 1
                        correct8 += 1
                        correct9 += 1
                        correct10 += 1
                    elif i == 1:
                        correct2 += 1
                        correct3 += 1
                        correct4 += 1
                        correct5 += 1
                        correct6 += 1
                        correct7 += 1
                        correct8 += 1
                        correct9 += 1
                        correct10 += 1
                    elif i == 2:
                        correct3 += 1
                        correct4 += 1
                        correct5 += 1
                        correct6 += 1
                        correct7 += 1
                        correct8 += 1
                        correct9 += 1
                        correct10 += 1
                    elif i == 3:
                        correct4 += 1
                        correct5 += 1
                        correct6 += 1
                        correct7 += 1
                        correct8 += 1
                        correct9 += 1
                        correct10 += 1
                    elif i == 4:
                        correct5 += 1
                        correct6 += 1
                        correct7 += 1
                        correct8 += 1
                        correct9 += 1
                        correct10 += 1
                    elif i == 5:
                        correct6 += 1
                        correct7 += 1
                        correct8 += 1
                        correct9 += 1
                        correct10 += 1
                    elif i == 6:
                        correct7 += 1
                        correct8 += 1
                        correct9 += 1
                        correct10 += 1
                    elif i == 7:
                        correct8 += 1
                        correct9 += 1
                        correct10 += 1
                    elif i == 8:
                        correct9 += 1
                        correct10 += 1
                    elif i == 9:
                        correct10 += 1

                    break

    # Compute mAP
    mAP = np.mean(AP_scores)

    # Print the accuracy at different ranks
    print(f'Total samples: {samples}, Rank #1: {correct1/samples*100.0:.3f}%, Rank #2: {correct2/samples*100.0:.3f}%, Rank #3: {correct3/samples*100.0:.3f}%, Rank #4: {correct4/samples*100.0:.3f}%, Rank #5: {correct5/samples*100.0:.3f}%, Rank #6: {correct6/samples*100.0:.3f}%, Rank #7: {correct7/samples*100.0:.3f}%, Rank #8: {correct8/samples*100.0:.3f}%, Rank #9: {correct9/samples*100.0:.3f}%, Rank #10: {correct10/samples*100.0:.3f}%, mAP:{mAP*100.0:.3f}%')
    return correct1/samples, correct5/samples, correct10/samples, mAP

In [8]:
def training_model(model, train_dataloader, val_dataloader, optimizer, embg_criterion, cat_criterion, n_epochs):    
    # Initialize variable to store the best rank #10 accuracy
    bestRank1 = 0.0
    bestRank5 = 0.0
    bestRank10 = 0.0

    # Training loop for a specified number of epochs
    for epoch in range(n_epochs):
        # Set the network in training mode (this enables features like dropout and batch normalization)
        model.train()

        # Initialize a variable to accumulate loss over the epoch
        train_loss = 0.0

        # Iterate over batches in the training data loader
        for batch in tqdm(train_dataloader, leave=False, total=len(train_dataloader)):
            # Unpack the batch into images and target values and move them to the computation device
            img1 = batch[0].to(device)
            img2 = batch[1].to(device)
            is_match = batch[2].to(device)
            img1_target = batch[3].to(device)
            img2_target = batch[4].to(device)
            
            # Forward pass: Compute the output of the network for both images in the pair
            output1, output_class1, output2, output_class2 = model(img1, img2)

            # Compute the loss based on the outputs and the target values
            embg_loss = embg_criterion(output1, output2, is_match)
            cat_loss = cat_criterion(output_class1, img1_target) + cat_criterion(output_class2, img2_target)
            loss = embg_loss + cat_loss

            # Zero the gradients before running the backward pass
            optimizer.zero_grad()

            # Backward pass: Compute gradient of the loss with respect to network parameters
            loss.backward()

            # Perform a single optimization step (parameter update)
            optimizer.step()

            # Accumulate the loss over the batches
            train_loss += loss.item()

        # Perform validation step and return the rank #10 accuracy for this epoch
        rank1, rank5, rank10, _ = testing_step(model, val_dataloader)

        # Folder to save model weights
        models_folder = os.path.join('./output/models/')
        if not os.path.exists(models_folder):
            os.makedirs(models_folder, exist_ok=True)

        # Check if the current model is the best so far; if so, save it and update the best rank #1 accuracy
        if rank1 > bestRank1:
            torch.save(model.state_dict(), f'{models_folder}/best_{backbone}_state_r1.bin')
            bestRank1 = rank1

        torch.save(model.state_dict(), f'{models_folder}/best_{backbone}_epoch_'+str(epoch)+'.bin')

        # Compute the average loss for this epoch and print it
        running_loss = train_loss / len(train_dataloader)
        print(f"Epoch: [{epoch+1}/{n_epochs}]| Loss: {running_loss:.5f}")

    return None

## Model Parameters

In [31]:
# Initialize the Siamese network with the specified parameters and move it to the selected device
model = BSN(backbone, img_channels, depth, img_w, img_h, n_classes).to(device)

# Define the loss function for training the network (Cosine Embedding Loss)
embg_criterion = nn.CosineEmbeddingLoss()
cat_criterion = nn.CrossEntropyLoss()

# Define the optimizer for training, using the Adam algorithm with a learning rate of 0.0005
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

## Network Training

In [32]:
# Start training model
training_model(model, train_dataloader, val_dataloader, optimizer, embg_criterion, cat_criterion, n_epochs)

                                               

Total samples: 36, Rank #1: 55.556%, Rank #2: 69.444%, Rank #3: 83.333%, Rank #4: 86.111%, Rank #5: 88.889%, Rank #6: 88.889%, Rank #7: 91.667%, Rank #8: 91.667%, Rank #9: 91.667%, Rank #10: 91.667%, mAP:72.222%
Epoch: [1/1]| Loss: 4.28208


## Performance Evaluation

In [None]:
# Initialize the Siamese neural network models with the same architecture and parameters
m1 = BSN(backbone, img_channels, depth, img_w, img_h, n_classes).to(device)

# Load pre-trained model weights
m1.load_state_dict(torch.load('./output/models/best_'+backbone+'_state_r1.bin', map_location=device))

# Evaluate and print the performance of the model with the respective evaluation metric
print('Model w/ best R1:')
testing_step(m1, test_dataloader)

Model w/ best R1:


                                               

Total samples: 36, Rank #1: 63.889%, Rank #2: 66.667%, Rank #3: 75.000%, Rank #4: 86.111%, Rank #5: 88.889%, Rank #6: 91.667%, Rank #7: 91.667%, Rank #8: 91.667%, Rank #9: 91.667%, Rank #10: 94.444%, mAP:73.380%




(0.6388888888888888,
 0.8888888888888888,
 0.9444444444444444,
 np.float64(0.7337962962962963))

## Reproducibility Testing Using Pretrained Weights

This final cell performs a slightly modified testing procedure that correctly loads and evaluates the models using the pretrained weights obtained during our training procedure. The purpose of this step is to ensure full **reproducibility** of the results reported in the paper by explicitly reusing our saved model checkpoints on the provided WiPER81 dataset.

In [9]:
# Initialize the Siamese neural network model
m1 = BSN(backbone, img_channels, depth, img_w, img_h, n_classes).to(device)

# Load the state dictionary from the file
state_dict = torch.load('./pre-trained/best_'+backbone+'_state_r1.bin', map_location=device)

# Create a new state dictionary for the model
new_state_dict = {}

# Map each key in the state dictionary to the new key in your model
for key in state_dict:
    # Replace 'model' with 'backbone' in each key
    new_key = key.replace('backbone', 'model')
    # Assign the weight to the new key in the new state dictionary
    new_state_dict[new_key] = state_dict[key]

# Load the new state dictionary into your model
m1.load_state_dict(new_state_dict)

# Evaluate and print the performance of the model with the respective evaluation metric
print('Model w/ best R1:')
testing_step(m1, test_dataloader)

Model w/ best R1:


                                                 

Total samples: 243, Rank #1: 45.679%, Rank #2: 57.202%, Rank #3: 62.963%, Rank #4: 69.959%, Rank #5: 73.663%, Rank #6: 77.778%, Rank #7: 79.424%, Rank #8: 80.658%, Rank #9: 81.481%, Rank #10: 82.716%, mAP:80.967%




(0.4567901234567901,
 0.7366255144032922,
 0.8271604938271605,
 np.float64(0.8096707818930041))