In [None]:
import os

# Define paths to test and train directories
test_path = '/home/minxing/datasets/NSVA_157_zeroshot_crops_new/player_crops/test'
train_path = '/home/minxing/datasets/NSVA_157_zeroshot_crops_new/player_crops/train'

# Function to build a dictionary for player crops
def build_player_crops_dict(base_path):
    player_crops_dict = {}
    
    # Iterate through each game-play directory
    for game_play_dir in os.listdir(base_path):
        game_play_path = os.path.join(base_path, game_play_dir)
        
        if os.path.isdir(game_play_path):
            # Extract game_id and play_id from the directory name
            game_id, play_id = game_play_dir.split('-')
            
            # Initialize a dictionary for the game_id if not already present
            if game_id not in player_crops_dict:
                player_crops_dict[game_id] = {}
                
            # Iterate through player directories inside each game-play directory
            for player_id in os.listdir(game_play_path):
                player_dir_path = os.path.join(game_play_path, player_id)
                
                if os.path.isdir(player_dir_path):
                    # Add player crops to the dictionary
                    if player_id not in player_crops_dict[game_id]:
                        player_crops_dict[game_id][player_id] = []
                    
                    # Collect all image file paths for the player
                    for image_file in os.listdir(player_dir_path):
                        image_path = os.path.join(player_dir_path, image_file)
                        if os.path.isfile(image_path):
                            player_crops_dict[game_id][player_id].append(image_path)

    return player_crops_dict

# Build dictionaries for both test and train directories
test_crops_dict = build_player_crops_dict(test_path)
train_crops_dict = build_player_crops_dict(train_path)

# Print a summary of the dictionary for quick verification
print(f"Number of games in test set: {len(test_crops_dict)}")
print(f"Number of games in train set: {len(train_crops_dict)}")

# Example access to verify the structure
# For example, access player crops from game ID '0021801055' and player ID '1'
# game_id = '0021801055'
# player_id = '1'
# if game_id in test_crops_dict and player_id in test_crops_dict[game_id]:
#     print(f"Player {player_id} crops in game {game_id}: {test_crops_dict[game_id][player_id]}")
# else:
#     print(f"Game ID {game_id} or Player ID {player_id} not found in the test set.")

Number of games in test set: 3
Number of games in train set: 15
Player 1 crops in game 0021801055: ['/home/minxing/datasets/NSVA_157_zeroshot_crops_new/player_crops/test/0021801055-389/1/000028_4.jpg', '/home/minxing/datasets/NSVA_157_zeroshot_crops_new/player_crops/test/0021801055-389/1/000479_4.jpg', '/home/minxing/datasets/NSVA_157_zeroshot_crops_new/player_crops/test/0021801055-389/1/000111_4.jpg', '/home/minxing/datasets/NSVA_157_zeroshot_crops_new/player_crops/test/0021801055-389/1/000467_4.jpg', '/home/minxing/datasets/NSVA_157_zeroshot_crops_new/player_crops/test/0021801055-389/1/000281_4.jpg', '/home/minxing/datasets/NSVA_157_zeroshot_crops_new/player_crops/test/0021801055-389/1/000547_4.jpg', '/home/minxing/datasets/NSVA_157_zeroshot_crops_new/player_crops/test/0021801055-389/1/000039_4.jpg', '/home/minxing/datasets/NSVA_157_zeroshot_crops_new/player_crops/test/0021801055-389/1/000087_4.jpg', '/home/minxing/datasets/NSVA_157_zeroshot_crops_new/player_crops/test/0021801055-389

In [None]:
# data preprocessing

import os
import pandas as pd
import numpy as np
from sklearn.model_selection import GroupKFold
from dataclasses import dataclass, field

@dataclass
class Configuration:
    n_folds: int = 10
    data_dir: str = "/home/minxing/datasets/NSVA_157_zeroshot_crops_new/player_crops"
    save_dir: str = "./data/nsva_train"
    selected_game_endings: list = field(default_factory=lambda: ['00160', '00060', '01212', '01055'])
    seed: int = 2024  # Seed for reproducibility
    min_samples_per_player: int = 20  # Minimum number of crops per player
    max_samples_per_player: int = 50  # Maximum number of crops per player
    num_queries_per_player: int = 5   # Number of query images per player in the test set
    num_test_players: int = 15        # Number of players in the test set

# Initialize configuration
config = Configuration()

# Create the save directory if it doesn't exist
os.makedirs(config.save_dir, exist_ok=True)

#----------------------------------------------------------------------------------------------------------------------#  
# Function to build the DataFrame with constraints on samples per player                                                #
#----------------------------------------------------------------------------------------------------------------------#
def build_dataframe(base_path):
    img_id = []
    folder = []
    player = []
    game = []
    img_type = []

    # Iterate through each game_play directory
    for game_play_dir in os.listdir(base_path):
        game_play_path = os.path.join(base_path, game_play_dir)
        
        if os.path.isdir(game_play_path):
            # Extract game_id from the directory name
            game_id_full = game_play_dir.split('-')[0]
            game_id = game_id_full  # Use the full game_id

            # Check if game_id ends with one of the selected endings
            if not any(game_id.endswith(ending) for ending in config.selected_game_endings):
                continue  # Skip games not in the selected list

            # Iterate through player directories inside each game-play directory
            for player_id in os.listdir(game_play_path):
                player_dir_path = os.path.join(game_play_path, player_id)
                
                if os.path.isdir(player_dir_path):
                    # Get all image files in the player directory
                    image_files = [f for f in os.listdir(player_dir_path) if os.path.isfile(os.path.join(player_dir_path, f))]
                    
                    # Ensure minimum number of samples per player
                    if len(image_files) < config.min_samples_per_player:
                        continue  # Skip players with less than min_samples_per_player
                    
                    # Limit to max_samples_per_player
                    if len(image_files) > config.max_samples_per_player:
                        np.random.seed(config.seed)  # For reproducibility
                        image_files = np.random.choice(image_files, size=config.max_samples_per_player, replace=False)
                    
                    # Iterate through image files
                    for image_file in image_files:
                        image_path = os.path.join(player_dir_path, image_file)
                        if os.path.isfile(image_path):
                            # Extract image information
                            img_id.append(os.path.splitext(image_file)[0])  # Remove file extension
                            folder.append(player_dir_path)  # Full path to the player's directory
                            player.append(player_id)
                            game.append(game_id)
                            img_type.append('g')  # Default to 'g'; will adjust later if needed

    df = pd.DataFrame({
        "img_id": img_id,
        "folder": folder,  # Full path to the player's directory
        "player": player,
        "game": game,
        "split": 'all',  # Will update later
        "img_type": img_type,
    })

    return df

#----------------------------------------------------------------------------------------------------------------------#  
# Build DataFrame for Selected Games                                                                                   #
#----------------------------------------------------------------------------------------------------------------------#
# Build DataFrame from both train and test directories
df_train = build_dataframe(os.path.join(config.data_dir, 'train'))
df_test = build_dataframe(os.path.join(config.data_dir, 'test'))

# Combine the dataframes
df_full = pd.concat([df_train, df_test], ignore_index=True)

#----------------------------------------------------------------------------------------------------------------------#  
# Assign Splits and Img Types                                                                                           #
#----------------------------------------------------------------------------------------------------------------------#

# Get unique player IDs
unique_players = df_full['player'].unique()

# Set random seed for reproducibility
np.random.seed(config.seed)

# Randomly select players for the test split
test_players = np.random.choice(unique_players, size=min(config.num_test_players, len(unique_players)), replace=False)

# Update the 'split' column based on player IDs
df_full['split'] = df_full['player'].apply(lambda x: 'test' if x in test_players else 'train')

# Now, for the test split, assign 'q' and 'g' in 'img_type' column
df_full['img_type'] = df_full['img_type'].astype(str)  # Ensure it's of type str

# For each player in the test set, select num_queries_per_player images as query and the rest as gallery
for player_id in test_players:
    player_df = df_full[(df_full['player'] == player_id) & (df_full['split'] == 'test')]
    if len(player_df) == 0:
        continue  # Skip if no images for this player

    # Check if player has enough images for queries
    if len(player_df) <= config.num_queries_per_player:
        print(f"Player {player_id} has only {len(player_df)} images in test set, setting all as queries.")
        df_full.loc[player_df.index, 'img_type'] = 'q'
    else:
        # Randomly select num_queries_per_player images as query
        query_idx = player_df.sample(n=config.num_queries_per_player, random_state=config.seed).index
        df_full.loc[query_idx, 'img_type'] = 'q'

        # The rest are gallery
        gallery_idx = player_df.index.difference(query_idx)
        df_full.loc[gallery_idx, 'img_type'] = 'g'

# For the training set, set 'img_type' to 'g' (gallery)
df_full.loc[df_full['split'] == 'train', 'img_type'] = 'g'

#----------------------------------------------------------------------------------------------------------------------#  
# Map player IDs to integer labels starting from 0                                                                      #
#----------------------------------------------------------------------------------------------------------------------#
unique_players = df_full['player'].unique()
player_id_map = {player_id: idx for idx, player_id in enumerate(unique_players)}
df_full['player'] = df_full['player'].map(player_id_map).astype(int)

#----------------------------------------------------------------------------------------------------------------------#  
# Assign Folds for Cross-Validation                                                                                    #
#----------------------------------------------------------------------------------------------------------------------#
df_full['fold'] = -1

# Assign folds only to training data
df_train_full = df_full[df_full['split'] == 'train'].copy()

n_splits = min(config.n_folds, df_train_full['player'].nunique())
cv = GroupKFold(n_splits=n_splits)
split = list(cv.split(df_train_full, groups=df_train_full['player']))

for i, (train_idx, val_idx) in enumerate(split):
    df_train_full.loc[df_train_full.index[val_idx], "fold"] = i

# Update df_full with the fold information
df_full.update(df_train_full)

#----------------------------------------------------------------------------------------------------------------------#  
# Save DataFrame                                                                                                      #
#----------------------------------------------------------------------------------------------------------------------#
# Save the combined DataFrame
df_full.to_csv(f"{config.save_dir}/train_df.csv", index=False)

# Print summary
print("DataFrame saved to:", f"{config.save_dir}/train_df.csv")
print("Number of total images:", len(df_full))
print("Number of training images:", len(df_full[df_full['split'] == 'train']))
print("Number of testing images:", len(df_full[df_full['split'] == 'test']))
print("Number of unique training players:", df_full[df_full['split'] == 'train']['player'].nunique())
print("Number of unique testing players:", df_full[df_full['split'] == 'test']['player'].nunique())


DataFrame saved to: ./data/nsva/train_df.csv
Number of total images: 19366
Number of training images: 16366
Number of testing images: 3000
Number of unique training players: 72
Number of unique testing players: 15


: 

In [1]:
# train

import os
import sys
import torch
import random
import pandas as pd
import numpy as np
from dataclasses import dataclass
from torch.utils.data import DataLoader  

# Import custom modules
from clipreid.loss import ClipLoss
from clipreid.trainer import train, get_scheduler
from clipreid.utils import Logger, setup_system, print_line
from clipreid.model import TimmModel, OpenClipModel
from clipreid.transforms import get_transforms
from clipreid.evaluator import predict, compute_dist_matrix, compute_scores

@dataclass
class Configuration:
    # Model configurations
    model: str = ('ViT-L-14', 'openai')
    remove_proj = True
    img_size: int = (224, 224)
    mean:   float = (0.485, 0.456, 0.406)
    std:    float = (0.229, 0.224, 0.225)
    
    # Data settings
    train_on_all: bool = False
    fold: int = -1  # Use -1 to specify custom behavior
    seed: int = 2024  # Updated seed for consistency
    epochs: int = 4
    batch_size: int = 16
    batch_size_eval: int = 64
    gpu_ids: tuple = (0,)
    mixed_precision: bool = True
    lr: float = 0.00004
    scheduler: str = "polynomial"
    warmup_epochs: float = 1.0
    lr_end: float = 0.00001
    gradient_clipping: float = None
    grad_checkpointing: bool = False
    gradient_accumulation: int = 1
    label_smoothing: float = 0.1
    zero_shot: bool = True
    rerank: bool = True
    normalize_features: bool = True
    data_dir: str = "/home/minxing/datasets/NSVA_157_zeroshot_crops_new/player_crops"
    data_csv: str = "./data/nsva/train_df.csv"
    prob_flip: float = 0.5
    model_path: str = "./model_nsva"
    checkpoint_start: str = None
    verbose: bool = True 
    num_workers: int = 0 if os.name == 'nt' else 8  
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu' 
    cudnn_benchmark: bool = True
    cudnn_deterministic: bool = True      

# Initialize configuration
config = Configuration()

# Ensure the model path exists
if isinstance(config.model, tuple):
    # Clip models
    if config.train_on_all:
        model_path = "{}/{}_{}/all_data_seed_{}".format(config.model_path,
                                               config.model[0],
                                               config.model[1],
                                               config.seed)
    else:
        model_path = "{}/{}_{}/fold{}_seed_{}".format(config.model_path,
                                               config.model[0],
                                               config.model[1],
                                               config.fold,
                                               config.seed)
else:
    # Timm models
    if config.train_on_all:
        model_path = "{}/{}/all_data_seed_{}".format(config.model_path,
                                                     config.model,
                                                     config.seed)
    else:
        model_path = "{}/{}/fold{}_seed_{}".format(config.model_path,
                                            config.model,
                                            config.seed)

if not os.path.exists(model_path):
    os.makedirs(model_path)

# Redirect print to both console and log file
sys.stdout = Logger("{}/log.txt".format(model_path))

# Set seed for reproducibility
setup_system(seed=config.seed,
             cudnn_benchmark=config.cudnn_benchmark,
             cudnn_deterministic=config.cudnn_deterministic)

#----------------------------------------------------------------------------------------------------------------------#  
# Model                                                                                                                #
#----------------------------------------------------------------------------------------------------------------------#  
print("\nModel: {}".format(config.model))

if isinstance(config.model, tuple):

    model = OpenClipModel(config.model[0],
                          config.model[1],
                          remove_proj=config.remove_proj
                          )
    
    img_size = model.get_image_size()
    
    mean=(0.48145466, 0.4578275, 0.40821073)
    std=(0.26862954, 0.26130258, 0.27577711)
    
    if config.grad_checkpointing: 
        model.set_grad_checkpoint(enable=config.grad_checkpointing)
       
else:
    model = TimmModel(config.model,
                      pretrained=True)

    img_size = config.img_size
    mean = config.mean
    std = config.std
    
# Load pretrained checkpoint if provided   
if config.checkpoint_start is not None:  
    print("\nStart from:", config.checkpoint_start)
    model_state_dict = torch.load(config.checkpoint_start)  
    model.load_state_dict(model_state_dict, strict=True)
    
# Data parallel
print("\nGPUs available:", torch.cuda.device_count())  
if torch.cuda.device_count() > 1 and len(config.gpu_ids) > 1:
    print("Using Data Parallel with GPU IDs: {}".format(config.gpu_ids))
    model = torch.nn.DataParallel(model, device_ids=config.gpu_ids)    
    multi_gpu = True
else:
    multi_gpu = False  
    
# Model to device   
model = model.to(config.device)

print("\nImage Size:", img_size)
print("Mean: {}".format(mean))
print("Std:  {}".format(std)) 

#----------------------------------------------------------------------------------------------------------------------#  
# DataLoader                                                                                                           #
#----------------------------------------------------------------------------------------------------------------------#  
# Data
df = pd.read_csv(config.data_csv)

# Split data
if config.train_on_all:
    df_train = df[df['split'] == 'train']
    df_test = df[df["split"] == "test"]
else:
    if config.fold == -1:
        # Use given test split
        df_train = df[df["split"] == "train"]
        df_test = df[df["split"] == "test"]
    else:
        # Use custom folds
        df_train = df[df["fold"] != config.fold]
        df_test = df[df["fold"] == config.fold]

# Transforms
val_transforms, train_transforms = get_transforms(img_size, mean, std)

# Custom Dataset Classes
import numpy as np
from collections import defaultdict
from torch.utils.data import Dataset
import cv2
import copy
import torch
from torchvision import transforms

class CustomTrainDataset(Dataset):
    def __init__(self,
                 df,
                 image_transforms=None,
                 prob_flip=0.5,
                 shuffle_batch_size=16):
        
        self.df = df.reset_index(drop=True)
        self.image_transforms = image_transforms
        self.prob_flip = prob_flip
        self.shuffle_batch_size = shuffle_batch_size
        
        print("\nImages train: {}".format(len(self.df)))
        self.images = self.df['img_id'].tolist()
        
        # Map img_id to index in dataframe for quick access
        self.img_id_to_index = {img_id: idx for idx, img_id in enumerate(self.df['img_id'])}
        
        # dict for all images for a given player
        self.player_images = defaultdict(list)
        for idx, row in self.df.iterrows():
            img_id = row['img_id']
            player = row['player']
            self.player_images[player].append(img_id)
  
        # dict for all gallery images for a given image
        self.player_images_other = {}
        for idx, row in self.df.iterrows():
            img_id = row['img_id']
            player = row['player']
            other_images = copy.deepcopy(self.player_images[player])
            other_images.remove(img_id)
            self.player_images_other[img_id] = np.array(other_images)

        self.samples = copy.deepcopy(self.images)
        self.shuffle()
            
    def __getitem__(self, index):

        # Query image
        img_id_query = self.samples[index]
        record_query = self.df.loc[self.img_id_to_index[img_id_query]]
        img_path_query = os.path.join(record_query['folder'], img_id_query + '.jpg')
        
        img_query = cv2.imread(img_path_query)
        img_query = cv2.cvtColor(img_query, cv2.COLOR_BGR2RGB)
        
        if self.image_transforms:
            img_query = self.image_transforms(image=img_query)['image']
        
        # Randomly select one other image of the same player as gallery image
        if len(self.player_images_other[img_id_query]) > 0:
            img_id_gallery = np.random.choice(self.player_images_other[img_id_query], 1)[0]
        else:
            img_id_gallery = img_id_query  # If no other images, use the same image
        record_gallery = self.df.loc[self.img_id_to_index[img_id_gallery]]
        img_path_gallery = os.path.join(record_gallery['folder'], img_id_gallery + '.jpg')
    
        img_gallery = cv2.imread(img_path_gallery)
        img_gallery = cv2.cvtColor(img_gallery, cv2.COLOR_BGR2RGB)
        
        if self.image_transforms:
            img_gallery = self.image_transforms(image=img_gallery)['image']
  
        # Player ID as label
        player = torch.tensor(int(record_query['player']), dtype=torch.long)
        
        # Random flip both images
        if np.random.random() < self.prob_flip:
            img_query = transforms.functional.hflip(img_query)
            img_gallery = transforms.functional.hflip(img_gallery)
        
        return img_query, img_gallery, player
    
    def __len__(self):
        return len(self.samples)

    def shuffle(self):
        '''
        Custom shuffle function to prevent having the same player twice in the same batch.
        '''
        
        img_ids_select = copy.deepcopy(self.images)
        random.shuffle(img_ids_select)

        batches = []
        players_batch = set()
        batch = []
        break_counter = 0
        
        while img_ids_select:
            img_id = img_ids_select.pop(0)
            player = self.df.loc[self.img_id_to_index[img_id]]['player']
            
            if player not in players_batch:
                players_batch.add(player)
                batch.append(img_id)
                
                # If batch is filled, reset
                if len(batch) == self.shuffle_batch_size:
                    batches.extend(batch)
                    batch = []
                    players_batch = set()
                    break_counter = 0
            else:
                # Append at the end for later consideration
                img_ids_select.append(img_id)
                break_counter += 1
                
                if break_counter >= len(img_ids_select):
                    # Can't fill batch without repeating players, so we accept duplicates
                    batches.extend(batch)
                    batch = []
                    players_batch = set()
                    break_counter = 0
        
        # Add any remaining images
        if batch:
            batches.extend(batch)
        
        self.samples = batches
        print("\nShuffle Training Data:")
        print("Length Train:", len(self.samples))
        if len(self.samples) > 0:
            print("First Element: {}".format(self.samples[0]))

class CustomTestDataset(Dataset):
    def __init__(self,
                 df,
                 image_transforms=None):
        
        self.df = df.reset_index(drop=True)
        self.image_transforms = image_transforms
        self.images = self.df['img_id'].tolist()
        
        self.query = []
        self.gallery = []
        self.all = []
        
        for idx, row in self.df.iterrows():
            img_id = row['img_id']
            player = int(row['player'])
            img_type = row['img_type']
            folder = row['folder']
            img_path = os.path.join(folder, img_id + '.jpg')
            self.all.append((img_path, player, -1))
            
            if img_type == "q":
                self.query.append((img_path, player, 0))
            else:
                self.gallery.append((img_path, player, 1))
        
    def __getitem__(self, index):

        img_id = self.df.loc[index, 'img_id']
        img_path = os.path.join(self.df.loc[index, 'folder'], img_id + '.jpg')
        
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        if self.image_transforms:
            img = self.image_transforms(image=img)['image']
             
        player = int(self.df.loc[index]['player'])
        
        if self.df.loc[index]["img_type"] == "q":
            img_type = 0
        else:
            img_type = 1
            
        return img, img_path, player, img_type
    
    def __len__(self):
        return len(self.df)

# Initialize the custom datasets
train_dataset = CustomTrainDataset(df=df_train,
                                   image_transforms=train_transforms,
                                   prob_flip=config.prob_flip,
                                   shuffle_batch_size=config.batch_size)

train_loader = DataLoader(train_dataset,
                          batch_size=config.batch_size,
                          num_workers=config.num_workers,
                          shuffle=False,
                          pin_memory=True,
                          drop_last=True)

# Validation
test_dataset = CustomTestDataset(df=df_test,
                                 image_transforms=val_transforms)

test_loader = DataLoader(test_dataset,
                         batch_size=config.batch_size_eval,
                         num_workers=config.num_workers,
                         shuffle=False,
                         pin_memory=True)

#----------------------------------------------------------------------------------------------------------------------#  
# Loss                                                                                                                 #
#----------------------------------------------------------------------------------------------------------------------#  
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=config.label_smoothing)
loss_function = ClipLoss(loss_function=loss_fn,
                         device=config.device)

#----------------------------------------------------------------------------------------------------------------------#  
# Optimizer and Scaler                                                                                                 #
#----------------------------------------------------------------------------------------------------------------------#  
optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

if config.mixed_precision:
    scaler = torch.cuda.amp.GradScaler(init_scale=2.**10)
else:
    scaler = None
    
#----------------------------------------------------------------------------------------------------------------------#  
# Scheduler                                                                                                            #
#----------------------------------------------------------------------------------------------------------------------#  
if config.scheduler is not None:
    scheduler = get_scheduler(config,
                              optimizer,
                              train_loader_length=len(train_loader))       
else:
    scheduler = None
   
#----------------------------------------------------------------------------------------------------------------------#  
# Zero Shot                                                                                                            #
#----------------------------------------------------------------------------------------------------------------------#  
if config.zero_shot:
    
    print_line(name="Zero-Shot", length=80)
    
    features_dict = predict(model,
                            dataloader=test_loader,
                            device=config.device,
                            normalize_features=config.normalize_features,
                            verbose=config.verbose)
    
    dist_matrix, dist_matrix_rerank = compute_dist_matrix(features_dict, 
                                                          test_dataset.query,
                                                          test_dataset.gallery,
                                                          rerank=config.rerank)
    
    print("\nWithout re-ranking:")
    mAP = compute_scores(dist_matrix,
                         test_dataset.query,
                         test_dataset.gallery)
    
    if dist_matrix_rerank is not None:
        print("\nWith re-ranking:")
        mAP = compute_scores(dist_matrix_rerank,
                             test_dataset.query,
                             test_dataset.gallery)
        

#----------------------------------------------------------------------------------------------------------------------#  
# Train                                                                                                                #
#----------------------------------------------------------------------------------------------------------------------#  
for epoch in range(1, config.epochs+1):

    print_line(name="Epoch: {}".format(epoch), length=80)
    
    # Train
    train_loss = train(model,
                       dataloader=train_loader,
                       loss_function=loss_function,
                       optimizer=optimizer,
                       device=config.device,
                       scheduler=scheduler,
                       scaler=scaler,
                       gradient_accumulation=config.gradient_accumulation,
                       gradient_clipping=config.gradient_clipping,
                       verbose=config.verbose,
                       multi_gpu=multi_gpu)

    print("Avg. Train Loss = {:.4f} - Lr = {:.6f}\n".format(train_loss,
                                                           optimizer.param_groups[0]['lr']))
    # Evaluate
    features_dict = predict(model,
                            dataloader=test_loader,
                            device=config.device,
                            normalize_features=config.normalize_features,
                            verbose=config.verbose)
    
    dist_matrix, dist_matrix_rerank = compute_dist_matrix(features_dict, 
                                                          test_dataset.query,
                                                          test_dataset.gallery,
                                                          rerank=config.rerank)
    
    print("\nWithout re-ranking:")
    mAP = compute_scores(dist_matrix,
                         test_dataset.query,
                         test_dataset.gallery)
    
    if dist_matrix_rerank is not None:
        print("\nWith re-ranking:")
        mAP_rerank = compute_scores(dist_matrix_rerank,
                                    test_dataset.query,
                                    test_dataset.gallery)
        
    checkpoint_path = '{}/weights_e{}.pth'.format(model_path, epoch)
            
    # Save model  
    if torch.cuda.device_count() > 1 and len(config.gpu_ids) > 1:
        torch.save(model.module.state_dict(), checkpoint_path)
    else:
        torch.save(model.state_dict(), checkpoint_path)
    
    # Shuffle data for next epoch
    train_loader.dataset.shuffle()


  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()



Model: ('ViT-L-14', 'openai')




Remove Projection Layer - old output size: 768 - new output size: 1024

GPUs available: 2

Image Size: (224, 224)
Mean: (0.48145466, 0.4578275, 0.40821073)
Std:  (0.26862954, 0.26130258, 0.27577711)

Images train: 16366

Shuffle Training Data:
Length Train: 16366
First Element: 000432_3

Warmup Epochs: 1.0 - Warmup Steps: 1022
Train Epochs:  4 - Train Steps:  4088

Scheduler: polynomial - max LR: 4e-05 - end LR: 1e-05

----------------------------------[Zero-Shot]-----------------------------------


  scaler = torch.cuda.amp.GradScaler(init_scale=2.**10)
Test : 100%|##########| 47/47 [00:32<00:00,  1.44it/s]



Without re-ranking:
mAP: 65.86%
CMC Scores    allshots      cuhk03  market1501
  top-1         27.11%      63.20%      98.67%
  top-5         37.80%      88.80%     100.00%
  top-10        42.61%      96.93%     100.00%

With re-ranking:
mAP: 74.70%
CMC Scores    allshots      cuhk03  market1501
  top-1         38.89%      71.07%     100.00%
  top-5         50.77%      91.07%     100.00%
  top-10        56.99%      96.53%     100.00%

-----------------------------------[Epoch: 1]-----------------------------------


Train: 100%|##########| 1022/1022 [06:55<00:00,  2.46it/s, loss=2.62, lr=0.000040]


Avg. Train Loss = 2.4680 - Lr = 0.000040



Test : 100%|##########| 47/47 [00:37<00:00,  1.24it/s]



Without re-ranking:
mAP: 79.34%
CMC Scores    allshots      cuhk03  market1501
  top-1         44.08%      74.27%     100.00%
  top-5         56.33%      95.47%     100.00%
  top-10        60.92%      99.07%     100.00%

With re-ranking:
mAP: 86.24%
CMC Scores    allshots      cuhk03  market1501
  top-1         43.12%      84.40%     100.00%
  top-5         67.23%      97.60%     100.00%
  top-10        71.58%      99.20%     100.00%

Shuffle Training Data:
Length Train: 16366
First Element: 000229_10

-----------------------------------[Epoch: 2]-----------------------------------


Train: 100%|##########| 1022/1022 [07:11<00:00,  2.37it/s, loss=2.98, lr=0.000023]


Avg. Train Loss = 2.0396 - Lr = 0.000023



Test : 100%|##########| 47/47 [00:35<00:00,  1.34it/s]



Without re-ranking:
mAP: 79.57%
CMC Scores    allshots      cuhk03  market1501
  top-1         41.95%      79.07%      98.67%
  top-5         53.97%      97.20%      98.67%
  top-10        59.72%      98.93%      98.67%

With re-ranking:
mAP: 86.09%
CMC Scores    allshots      cuhk03  market1501
  top-1         45.24%      81.87%      98.67%
  top-5         64.95%      97.87%      98.67%
  top-10        72.17%      99.60%      98.67%

Shuffle Training Data:
Length Train: 16366
First Element: 000379_4

-----------------------------------[Epoch: 3]-----------------------------------


Train: 100%|##########| 1022/1022 [07:10<00:00,  2.38it/s, loss=2.72, lr=0.000013]


Avg. Train Loss = 1.8012 - Lr = 0.000013



Test : 100%|##########| 47/47 [00:39<00:00,  1.20it/s]



Without re-ranking:
mAP: 67.66%
CMC Scores    allshots      cuhk03  market1501
  top-1         30.56%      63.33%     100.00%
  top-5         40.89%      90.93%     100.00%
  top-10        46.39%      98.67%     100.00%

With re-ranking:
mAP: 77.50%
CMC Scores    allshots      cuhk03  market1501
  top-1         31.85%      74.53%     100.00%
  top-5         55.66%      92.53%     100.00%
  top-10        61.94%      98.13%     100.00%

Shuffle Training Data:
Length Train: 16366
First Element: 000511_4

-----------------------------------[Epoch: 4]-----------------------------------


Train: 100%|##########| 1022/1022 [07:11<00:00,  2.37it/s, loss=2.70, lr=0.000010]


Avg. Train Loss = 1.6232 - Lr = 0.000010



Test : 100%|##########| 47/47 [00:39<00:00,  1.19it/s]



Without re-ranking:
mAP: 63.69%
CMC Scores    allshots      cuhk03  market1501
  top-1         28.21%      62.67%      97.33%
  top-5         38.54%      83.73%      98.67%
  top-10        43.29%      95.73%     100.00%

With re-ranking:
mAP: 76.79%
CMC Scores    allshots      cuhk03  market1501
  top-1         33.13%      76.13%      98.67%
  top-5         56.77%      88.53%      98.67%
  top-10        63.44%      96.13%      98.67%

Shuffle Training Data:
Length Train: 16366
First Element: 000374_3


In [None]:
# evaluate

import os
import torch
import pandas as pd
import numpy as np  # Import numpy if needed
from dataclasses import dataclass
from torch.utils.data import DataLoader
from PIL import Image  # Ensure PIL is imported for image handling

# Import your custom modules (ensure these are accessible in your environment)
from clipreid.model import TimmModel, OpenClipModel
from clipreid.transforms import get_transforms
from clipreid.dataset import TestDataset
from clipreid.evaluator import predict, compute_dist_matrix, compute_scores
from clipreid.utils import print_line

@dataclass
class Configuration:
    '''
    --------------------------------------------------------------------------
    Open Clip Models:
    --------------------------------------------------------------------------    
    - ('RN50', 'openai')
    - ('ViT-B-32', 'openai')
    - ('ViT-L-14', 'openai')
    - ...
    --------------------------------------------------------------------------
    '''
    # Model configurations
    model: str = ('ViT-L-14', 'openai')
    remove_proj = True
    
    # Settings only for Timm models
    img_size: int = (224, 224)
    mean:   float = (0.485, 0.456, 0.406)
    std:    float = (0.229, 0.224, 0.225)
    
    # Evaluation settings
    batch_size: int = 64
    normalize_features: bool = True
    
    # Checkpoint
    checkpoint: str = "./model_nsva/ViT-L-14_openai/fold-1_seed_1/weights_e4.pth"  # Update to your actual checkpoint path
    
    # Dataset paths
    data_dir: str = "./data/nsva"  # Updated data directory
    
    # Miscellaneous settings
    verbose: bool = True
    num_workers: int = 0 if os.name == 'nt' else 8
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'

#----------------------------------------------------------------------------------------------------------------------#  
# Config                                                                                                               #
#----------------------------------------------------------------------------------------------------------------------#  
config = Configuration()

#----------------------------------------------------------------------------------------------------------------------#  
# Model                                                                                                                #
#----------------------------------------------------------------------------------------------------------------------#  
print("\nModel: {}".format(config.model))

if isinstance(config.model, tuple):
    model = OpenClipModel(config.model[0],
                          config.model[1],
                          remove_proj=config.remove_proj)
    img_size = model.get_image_size()
    mean = (0.48145466, 0.4578275, 0.40821073)
    std = (0.26862954, 0.26130258, 0.27577711)
else:
    model = TimmModel(config.model, pretrained=True)
    img_size = config.img_size
    mean = config.mean
    std = config.std

print_line(name=config.checkpoint, length=80)

# Load pretrained checkpoint
model_state_dict = torch.load(config.checkpoint)
model.load_state_dict(model_state_dict, strict=True)    

# Move model to device
model = model.to(config.device)

print("\nImage Size:", img_size)
print("Mean: {}".format(mean))
print("Std:  {}".format(std)) 

#----------------------------------------------------------------------------------------------------------------------#  
# DataLoader                                                                                                           #
#----------------------------------------------------------------------------------------------------------------------#  

# Transforms
val_transforms, _ = get_transforms(img_size, mean, std)

# Load dataframes
df_train = pd.read_csv(f"{config.data_dir}/train_df.csv")
df_test = pd.read_csv(f"{config.data_dir}/test_df.csv")

# Adjust the TestDataset class
class TestDataset(torch.utils.data.Dataset):
    def __init__(self, df, image_transforms=None):
        self.df = df.reset_index(drop=True)
        self.image_transforms = image_transforms
        self.query = []
        self.gallery = []
        self.prepare_query_gallery()
    
    def prepare_query_gallery(self):
        for idx, record in self.df.iterrows():
            img_path = os.path.join(record['folder'], record['img_id'] + '.jpg')  # Adjust extension if needed
            pid = int(record['player'])
            camid = int(record['game'])
            if record['img_type'] == 'q':
                self.query.append((img_path, pid, camid))
            else:
                self.gallery.append((img_path, pid, camid))
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        img_path = os.path.join(record['folder'], record['img_id'] + '.jpg')  # Adjust extension if needed

        # Ensure the image exists
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        image = Image.open(img_path).convert('RGB')

        # Apply transformations
        if self.image_transforms:
            image = self.image_transforms(image=np.array(image))['image']

        pid = int(record['player'])
        camid = int(record['game'])
        img_type = record['img_type']

        file_name = img_path  # Use the full image path as the key

        return image, file_name, pid, img_type

# Create DataLoader
test_dataset = TestDataset(
    df=df_test,
    image_transforms=val_transforms
)

test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=False,
    pin_memory=True
)

#----------------------------------------------------------------------------------------------------------------------#  
# Evaluation                                                                                                            #
#----------------------------------------------------------------------------------------------------------------------#
print_line(name="Evaluation", length=80)

# Extract features
features_dict = predict(model,
                        dataloader=test_loader,
                        device=config.device,
                        normalize_features=config.normalize_features,
                        verbose=config.verbose)

# Compute distance matrices
dist_matrix, dist_matrix_rerank = compute_dist_matrix(features_dict, 
                                                      test_dataset.query,
                                                      test_dataset.gallery,
                                                      rerank=True)

# Without re-ranking
print("\nWithout re-ranking:")
compute_scores(dist_matrix,
               test_dataset.query,
               test_dataset.gallery,
               cmc_scores=True)

# With re-ranking
if dist_matrix_rerank is not None:
    print("\nWith re-ranking:")
    compute_scores(dist_matrix_rerank,
                   test_dataset.query,
                   test_dataset.gallery,
                   cmc_scores=True)


In [None]:
# predict

import os
import torch
import pandas as pd
import numpy as np
from dataclasses import dataclass
from torch.utils.data import DataLoader
from PIL import Image  # Ensure PIL is imported for image handling

# Import your custom modules (ensure these are accessible in your environment)
from clipreid.model import TimmModel, OpenClipModel
from clipreid.transforms import get_transforms
from clipreid.dataset import TestDataset, ChallengeDataset
from clipreid.evaluator import predict, compute_dist_matrix, compute_scores, write_mat_csv
from clipreid.utils import print_line

@dataclass
class Configuration:
    '''
    --------------------------------------------------------------------------
    Open Clip Models:
    --------------------------------------------------------------------------    
    - ('ViT-L-14', 'openai')
    - ...
    --------------------------------------------------------------------------
    '''
    # Model configurations
    model: str = ('ViT-L-14', 'openai')
    remove_proj = True
    
    # Settings only for Timm models
    img_size: int = (224, 224)
    mean:   float = (0.485, 0.456, 0.406)
    std:    float = (0.229, 0.224, 0.225)
    
    # Evaluation settings
    batch_size: int = 64
    normalize_features: bool = True
    
    # Checkpoints (for ensemble)
    checkpoints: tuple = (
        "./model_nsva/ViT-L-14_openai/fold-1_seed_1/weights_e4.pth",
        "./model_nsva/ViT-L-14_openai/all_data_seed_1/weights_e4.pth"
    )
    
    # Dataset paths
    data_dir: str = "./data/nsva"
    
    # Miscellaneous settings
    verbose: bool = True
    num_workers: int = 0 if os.name == 'nt' else 8
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'

#----------------------------------------------------------------------------------------------------------------------#  
# Config                                                                                                               #
#----------------------------------------------------------------------------------------------------------------------#  
config = Configuration()

#----------------------------------------------------------------------------------------------------------------------#  
# Model                                                                                                                #
#----------------------------------------------------------------------------------------------------------------------#  
print("\nModel: {}".format(config.model))

if isinstance(config.model, tuple):
    model = OpenClipModel(config.model[0],
                          config.model[1],
                          remove_proj=config.remove_proj)
    img_size = model.get_image_size()
    mean = (0.48145466, 0.4578275, 0.40821073)
    std = (0.26862954, 0.26130258, 0.27577711)
else:
    model = TimmModel(config.model, pretrained=True)
    img_size = config.img_size
    mean = config.mean
    std = config.std

dist_matrix_list = []
dist_matrix_rerank_list = []

checkpoints = config.checkpoints if isinstance(config.checkpoints, (list, tuple)) else [config.checkpoints]

#----------------------------------------------------------------------------------------------------------------------#  
# DataLoader and Dataset Classes                                                                                        #
#----------------------------------------------------------------------------------------------------------------------#  

# Transforms
val_transforms, _ = get_transforms(img_size, mean, std)

# Load dataframes
df_train = pd.read_csv(f"{config.data_dir}/train_df.csv")
df_test = pd.read_csv(f"{config.data_dir}/test_df.csv")
df_challenge = pd.read_csv(f"{config.data_dir}/challenge_df.csv")

# Adjust the dataset classes
class TestDataset(torch.utils.data.Dataset):
    def __init__(self, df, image_transforms=None):
        self.df = df.reset_index(drop=True)
        self.image_transforms = image_transforms
        self.query = []
        self.gallery = []
        self.prepare_query_gallery()
    
    def prepare_query_gallery(self):
        for idx, record in self.df.iterrows():
            img_path = os.path.join(record['folder'], record['img_id'] + '.jpg')
            pid = int(record['player'])
            camid = int(record['game'])
            if record['img_type'] == 'q':
                self.query.append((img_path, pid, camid))
            else:
                self.gallery.append((img_path, pid, camid))
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        img_path = os.path.join(record['folder'], record['img_id'] + '.jpg')  # Adjust extension if needed

        # Ensure the image exists
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        image = Image.open(img_path).convert('RGB')

        # Apply transformations
        if self.image_transforms:
            image = self.image_transforms(image=np.array(image))['image']

        pid = int(record['player'])
        camid = int(record['game'])
        img_type = record['img_type']

        file_name = img_path  # Use the full image path as the key

        return image, file_name, pid, img_type

class ChallengeDataset(torch.utils.data.Dataset):
    def __init__(self, df, image_transforms=None):
        self.df = df.reset_index(drop=True)
        self.image_transforms = image_transforms
        self.query = []
        self.gallery = []
        self.prepare_query_gallery()
    
    def prepare_query_gallery(self):
        for idx, record in self.df.iterrows():
            img_path = os.path.join(record['folder'], record['img_id'] + '.jpg')
            pid = int(record['player'])
            camid = int(record['game'])
            if record['img_type'] == 'q':
                self.query.append((img_path, pid, camid))
            else:
                self.gallery.append((img_path, pid, camid))
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        img_path = os.path.join(record['folder'], record['img_id'] + '.jpg')

        # Ensure the image exists
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        image = Image.open(img_path).convert('RGB')

        # Apply transformations
        if self.image_transforms:
            image = self.image_transforms(image=np.array(image))['image']

        pid = int(record['player'])  # or -1 if unknown
        camid = int(record['game'])  # or -1 if unknown
        img_type = record['img_type']

        file_name = img_path  # Use the full image path as the key

        return image, file_name, pid, img_type

# Create DataLoaders
test_dataset = TestDataset(df=df_test, image_transforms=val_transforms)
test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=False,
    pin_memory=True
)

challenge_dataset = ChallengeDataset(df=df_challenge, image_transforms=val_transforms)
challenge_loader = DataLoader(
    challenge_dataset,
    batch_size=config.batch_size,
    num_workers=config.num_workers,
    shuffle=False,
    pin_memory=True
)

#----------------------------------------------------------------------------------------------------------------------#  
# Prediction and Evaluation Loop                                                                                        #
#----------------------------------------------------------------------------------------------------------------------#
for checkpoint in checkpoints:
    print_line(name=checkpoint, length=80)
    
    # Load model checkpoint
    model_state_dict = torch.load(checkpoint)
    model.load_state_dict(model_state_dict, strict=True)
    model = model.to(config.device)
    
    print("\nImage Size:", img_size)
    print("Mean: {}".format(mean))
    print("Std:  {}".format(std)) 
    
    #----------------------------------------------#
    # Test Evaluation                              #
    #----------------------------------------------#
    print_line(name="Test Evaluation", length=80)
    
    features_dict = predict(model,
                            dataloader=test_loader,
                            device=config.device,
                            normalize_features=config.normalize_features,
                            verbose=config.verbose)
    
    dist_matrix_test, dist_matrix_test_rerank = compute_dist_matrix(features_dict, 
                                                                    test_dataset.query,
                                                                    test_dataset.gallery,
                                                                    rerank=True)
    
    # Without re-ranking
    print("\nWithout re-ranking:")
    compute_scores(dist_matrix_test,
                   test_dataset.query,
                   test_dataset.gallery,
                   cmc_scores=True)
    
    # Save distance matrix
    save_dir = os.path.dirname(checkpoint)
    save_path = os.path.join(save_dir, 'test_dmat.csv')
    print("Writing distance matrix to:", save_path)
    write_mat_csv(save_path,
                  dist_matrix_test,
                  test_dataset.query,
                  test_dataset.gallery)
    
    # With re-ranking
    print("\nWith re-ranking:")
    compute_scores(dist_matrix_test_rerank,
                   test_dataset.query,
                   test_dataset.gallery,
                   cmc_scores=True)
    
    save_path = os.path.join(save_dir, 'test_dmat_rerank.csv')
    print("Writing re-ranked distance matrix to:", save_path)
    write_mat_csv(save_path,
                  dist_matrix_test_rerank,
                  test_dataset.query,
                  test_dataset.gallery)
    
    #----------------------------------------------#
    # Challenge Evaluation                         #
    #----------------------------------------------#
    # Process Challenge Set
    print_line(name="Challenge Evaluation", length=80)

    features_dict = predict(model,
                            dataloader=challenge_loader,
                            device=config.device,
                            normalize_features=config.normalize_features,
                            verbose=config.verbose)

    dist_matrix, dist_matrix_rerank = compute_dist_matrix(features_dict, 
                                                        challenge_dataset.query,
                                                        challenge_dataset.gallery,
                                                        rerank=True)

    # Since we might not have labels in the challenge set, we won't compute scores
    # Instead, we save the distance matrices for submission or further analysis
    save_dir = os.path.dirname(checkpoint)
    save_path = os.path.join(save_dir, 'challenge_dmat.csv')
    print("Writing challenge distance matrix to:", save_path)
    write_mat_csv(save_path,
                dist_matrix,
                challenge_dataset.query,
                challenge_dataset.gallery)

    save_path = os.path.join(save_dir, 'challenge_dmat_rerank.csv')
    print("Writing re-ranked challenge distance matrix to:", save_path)
    write_mat_csv(save_path,
                  dist_matrix_rerank,
                  challenge_dataset.query,
                  challenge_dataset.gallery)

#----------------------------------------------------------------------------------------------------------------------#  
# Ensemble                                                                                                             #
#----------------------------------------------------------------------------------------------------------------------#
if len(dist_matrix_list) > 1:
    print_line(name="Ensemble", length=80)
    
    # Without re-ranking
    dist_matrix_ensemble = np.mean(dist_matrix_list, axis=0)
    ensemble_save_dir = os.path.commonpath([os.path.dirname(cp) for cp in checkpoints])
    save_path = os.path.join(ensemble_save_dir, 'challenge_dmat_ensemble.csv')
    print("Writing ensemble distance matrix to:", save_path)
    write_mat_csv(save_path,
                  dist_matrix_ensemble,
                  challenge_dataset.query,
                  challenge_dataset.gallery)
    
    # With re-ranking
    dist_matrix_rerank_ensemble = np.mean(dist_matrix_rerank_list, axis=0)
    save_path = os.path.join(ensemble_save_dir, 'challenge_dmat_rerank_ensemble.csv')
    print("Writing re-ranked ensemble distance matrix to:", save_path)
    write_mat_csv(save_path,
                  dist_matrix_rerank_ensemble,
                  challenge_dataset.query,
                  challenge_dataset.gallery)
