In [34]:
# data_preprocessing small scale

import os
import pandas as pd
import numpy as np
from sklearn.model_selection import GroupKFold
from dataclasses import dataclass, field
import random
from typing import Optional

@dataclass
class Configuration:
    """Settings for building the small-scale NSVA re-ID CSV files.

    Fields:
        n_folds: number of GroupKFold folds requested for the training split.
        data_dir: root folder containing the ``train`` and ``test`` crop trees.
        save_dir: folder the generated CSV files are written to.
        selected_game_endings: game-id suffixes to keep; falsy keeps every game.
        challenge_game_ending: game-id suffix used for the challenge set;
            falsy skips building it.
        seed: seed for the numpy / ``random`` generators.
        min_samples_per_player: player folders with fewer crops are dropped.
        max_samples_per_player: cap on crops kept per player folder.
        num_queries_per_player: query crops drawn per test player.
        num_test_players: number of players moved into the test split.
        num_challenge_images_per_player: crops drawn per challenge player.
    """

    n_folds: int = 10
    data_dir: str = "/home/minxing/datasets/NSVA_157_zeroshot_crops_new/player_crops"
    save_dir: str = "./data/nsva"
    # A fresh list per instance (mutable defaults must go through a factory).
    selected_game_endings: Optional[list] = field(
        default_factory=lambda: list(('00160', '00060', '01212', '01055'))
    )
    challenge_game_ending: Optional[str] = '00151'
    seed: int = 2024
    min_samples_per_player: int = 20
    max_samples_per_player: int = 50
    num_queries_per_player: int = 5
    num_test_players: int = 20
    num_challenge_images_per_player: int = 5

# Run configuration: no game filter (every game is scanned) and no challenge set.
config = Configuration(
    selected_game_endings=None,  # None or [] disables game filtering
    challenge_game_ending=None,  # None skips challenge DataFrame creation
)

# Seed numpy first, then the stdlib generator; both are used further down.
for _seed_fn in (np.random.seed, random.seed):
    _seed_fn(config.seed)

# Output folder for the CSV files (no error if it already exists).
os.makedirs(config.save_dir, exist_ok=True)

#----------------------------------------------------------------------------------------------------------------------#
# Function to build the DataFrame with constraints on samples per player                                                #
#----------------------------------------------------------------------------------------------------------------------#
def build_dataframe(base_path, selected_game_endings, min_samples=None, max_samples=None):
    """Collect player crops under ``base_path`` into a DataFrame.

    The expected layout is ``base_path/<game_id>-<play_id>/<player_id>/<image>``.
    Directories without a ``-`` in their name are skipped. Player folders named
    ``'0'`` are skipped too (presumably unidentified detections; confirm).

    Args:
        base_path: root folder to scan.
        selected_game_endings: iterable of game-id suffixes to keep. A falsy value
            (``None`` / empty) keeps every game.
        min_samples: player folders with fewer files are skipped. Defaults to
            ``config.min_samples_per_player``.
        max_samples: at most this many randomly chosen files are kept per player
            folder. Defaults to ``config.max_samples_per_player``.

    Returns:
        DataFrame with columns ``img_id`` (file stem), ``folder`` (player folder
        path), ``player``, ``game``, ``split`` (``'all'`` placeholder) and
        ``img_type`` (``'g'`` placeholder).
        An empty input still yields these columns.
    """
    if min_samples is None:
        min_samples = config.min_samples_per_player
    if max_samples is None:
        max_samples = config.max_samples_per_player

    # str.endswith accepts a tuple, so one call checks every suffix.
    endings = tuple(selected_game_endings) if selected_game_endings else ()

    records = []
    # All listings are sorted: os.listdir order is arbitrary and filesystem
    # dependent. Without sorting, the seeded random.shuffle below would pick
    # different crops on different machines.
    for game_play_dir in sorted(os.listdir(base_path)):
        game_play_path = os.path.join(base_path, game_play_dir)
        if not os.path.isdir(game_play_path):
            continue

        if '-' not in game_play_dir:
            print(f"Skipping directory '{game_play_dir}' as it does not contain a hyphen '-'.")
            continue  # not in 'game_id-play_id' format

        game_id = game_play_dir.split('-')[0]
        if endings and not game_id.endswith(endings):
            continue  # game not selected

        for player_id in sorted(os.listdir(game_play_path)):
            if player_id == '0':
                continue

            player_dir_path = os.path.join(game_play_path, player_id)
            if not os.path.isdir(player_dir_path):
                continue

            image_files = sorted(
                f for f in os.listdir(player_dir_path)
                if os.path.isfile(os.path.join(player_dir_path, f))
            )

            if len(image_files) < min_samples:
                print(f"Skipping player '{player_id}' in game '{game_id}' due to insufficient images ({len(image_files)}).")
                continue

            # Random subset of at most max_samples crops.
            random.shuffle(image_files)
            for image_file in image_files[:max_samples]:
                records.append({
                    "img_id": os.path.splitext(image_file)[0],  # stem without extension
                    "folder": player_dir_path,
                    "player": player_id,
                    "game": game_id,
                    "split": 'all',   # re-assigned later
                    "img_type": 'g',  # re-assigned later
                })

    columns = ["img_id", "folder", "player", "game", "split", "img_type"]
    return pd.DataFrame(records, columns=columns)

#----------------------------------------------------------------------------------------------------------------------#
# Build DataFrame for Selected Games or All Games if No Selection                                                         #
#----------------------------------------------------------------------------------------------------------------------#
# Gather crops from both source trees (train first, then test) with the same
# optional game filter.
df_train, df_test = (
    build_dataframe(os.path.join(config.data_dir, subset), config.selected_game_endings)
    for subset in ('train', 'test')
)

# One frame for everything; the original train/test folders are ignored and
# the split is re-drawn per player below.
df_full = pd.concat([df_train, df_test], ignore_index=True)
print(f"Combined DataFrame contains {len(df_full)} images.")

# --- Split: hold out a random subset of players as the test set -------------
unique_players = df_full['player'].unique()
print(f"Total unique players before splitting: {len(unique_players)}")

num_test_players = min(config.num_test_players, len(unique_players))
test_players = np.random.choice(unique_players, size=num_test_players, replace=False)
print(f"Selected {len(test_players)} players for the test split.")

is_test_row = df_full['player'].isin(test_players)
df_full['split'] = 'train'
df_full.loc[is_test_row, 'split'] = 'test'

# --- img_type: 'q' (query) / 'g' (gallery) for the test players ---------------
df_full['img_type'] = df_full['img_type'].astype(str)

for player_id in test_players:
    rows = df_full[(df_full['player'] == player_id) & (df_full['split'] == "test")]
    if rows.empty:
        continue

    if len(rows) > config.num_queries_per_player:
        # Draw the queries; the remaining crops of this player are its gallery.
        q_index = rows.sample(n=config.num_queries_per_player, random_state=config.seed).index
        df_full.loc[q_index, 'img_type'] = 'q'
        df_full.loc[rows.index.difference(q_index), 'img_type'] = 'g'
    else:
        print(f"Player '{player_id}' has only {len(rows)} images in test set, setting all as queries.")
        df_full.loc[rows.index, 'img_type'] = 'q'

# Training crops are always gallery-type.
df_full.loc[df_full['split'] == 'train', 'img_type'] = 'g'

#----------------------------------------------------------------------------------------------------------------------#
# Build Challenge DataFrame (Only if challenge_game_ending is provided)                                                #
#----------------------------------------------------------------------------------------------------------------------#
def build_challenge_dataframe(base_paths, game_ending, num_images_per_player, seed=2024):
    """Collect exactly ``num_images_per_player`` crops per player for the challenge set.

    Every directory in ``base_paths`` is scanned for ``<game_id>-<play_id>``
    folders whose ``game_id`` ends with ``game_ending``. Player folders named
    ``'0'`` are skipped. Players with fewer than ``num_images_per_player`` files
    are skipped; the rest contribute a random subset of that size.

    NOTE: the global numpy and ``random`` generators are re-seeded with ``seed``,
    so the selection does not depend on earlier random draws in the script.

    Returns:
        DataFrame with columns img_id, folder, player, game, split
        (``'challenge'``) and img_type (``'g'`` placeholder).
    """
    np.random.seed(seed)
    random.seed(seed)

    records = []
    for base_path in base_paths:
        # Sorted listings: os.listdir order is arbitrary, which would make the
        # seeded shuffle below machine-dependent.
        for game_play_dir in sorted(os.listdir(base_path)):
            game_play_path = os.path.join(base_path, game_play_dir)
            if not os.path.isdir(game_play_path) or '-' not in game_play_dir:
                continue

            game_id = game_play_dir.split('-')[0]
            if not game_id.endswith(game_ending):
                continue

            for player_id in sorted(os.listdir(game_play_path)):
                if player_id == '0':
                    continue

                player_dir_path = os.path.join(game_play_path, player_id)
                if not os.path.isdir(player_dir_path):
                    continue

                image_files = sorted(
                    f for f in os.listdir(player_dir_path)
                    if os.path.isfile(os.path.join(player_dir_path, f))
                )
                if len(image_files) < num_images_per_player:
                    print(f"Skipping player '{player_id}' in challenge game '{game_id}' due to insufficient images ({len(image_files)}).")
                    continue

                random.shuffle(image_files)
                for image_file in image_files[:num_images_per_player]:
                    records.append({
                        "img_id": os.path.splitext(image_file)[0],
                        "folder": player_dir_path,
                        "player": player_id,
                        "game": game_id,
                        "split": 'challenge',
                        "img_type": 'g',  # re-assigned below
                    })

    df = pd.DataFrame(records, columns=["img_id", "folder", "player", "game", "split", "img_type"])
    print(f"Total challenge images collected: {len(df)}")
    return df


df_challenge = pd.DataFrame()  # stays empty unless a challenge game is configured

if config.challenge_game_ending:
    # Challenge games may sit in either the 'train' or the 'test' tree.
    base_paths_for_challenge = [os.path.join(config.data_dir, 'train'), os.path.join(config.data_dir, 'test')]
    df_challenge = build_challenge_dataframe(
        base_paths=base_paths_for_challenge,
        game_ending=config.challenge_game_ending,
        num_images_per_player=config.num_challenge_images_per_player,
        seed=config.seed,
    )

    if df_challenge.empty:
        print("No challenge images found. Please check if game_play directories ending with "
              f"'{config.challenge_game_ending}' exist in 'train' or 'test' directories.")
    else:
        # A non-empty frame always has at least one player.
        challenge_players = df_challenge['player'].unique()
        num_players = len(challenge_players)
        print(f"Total unique players in challenge set: {num_players}")

        # Half of the players (rounded down) become query players; the rest are gallery.
        num_query_players = num_players // 2
        shuffled_players = list(challenge_players)
        random.shuffle(shuffled_players)
        query_players = set(shuffled_players[:num_query_players])
        gallery_players = set(shuffled_players[num_query_players:])

        print(f"Assigning {len(query_players)} players to query and {len(gallery_players)} players to gallery in the challenge set.")

        # NOTE(review): query and gallery players are disjoint, so no query has a
        # same-player match in the gallery. Confirm this is the intended protocol.
        df_challenge['img_type'] = df_challenge['player'].apply(lambda x: 'q' if x in query_players else 'g')
else:
    print("No challenge_game_ending provided. Skipping challenge DataFrame creation.")

#----------------------------------------------------------------------------------------------------------------------#
# Map player IDs to integer labels starting from 0                                                                      #
#----------------------------------------------------------------------------------------------------------------------#
# Train/test players -> contiguous integer labels 0..N-1 in first-appearance order.
unique_players_full = df_full['player'].unique()
player_id_map_full = dict(zip(unique_players_full, range(len(unique_players_full))))
df_full['player'] = df_full['player'].map(player_id_map_full).astype(int)

# Challenge players (when present) get fresh labels that start right after the
# largest train/test label, so the two label ranges never overlap.
if config.challenge_game_ending and not df_challenge.empty:
    max_player_id = df_full['player'].max()
    first_new_id = max_player_id + 1
    challenge_unique_players = df_challenge['player'].unique()
    challenge_player_id_map = {
        pid: first_new_id + offset for offset, pid in enumerate(challenge_unique_players)
    }
    df_challenge['player'] = df_challenge['player'].map(challenge_player_id_map).astype(int)

#----------------------------------------------------------------------------------------------------------------------#
# Assign Folds for Cross-Validation                                                                                    #
#----------------------------------------------------------------------------------------------------------------------#
# Folds exist only for training players; test rows keep fold = -1.
df_full['fold'] = -1

df_train_full = df_full[df_full['split'] == 'train'].copy()

# GroupKFold keeps every crop of a player inside a single fold.
n_splits = min(config.n_folds, df_train_full['player'].nunique())
split = []
if n_splits < 2:
    # GroupKFold raises for n_splits < 2; leave every fold at -1 instead.
    print(f"Only {n_splits} training player(s); skipping fold assignment (GroupKFold needs >= 2 splits).")
else:
    print(f"Assigning folds using GroupKFold with {n_splits} splits.")
    cv = GroupKFold(n_splits=n_splits)
    split = list(cv.split(df_train_full, groups=df_train_full['player']))
    for i, (_, val_idx) in enumerate(split):
        df_train_full.loc[df_train_full.index[val_idx], "fold"] = i

# Write back only the fold column, aligned on index labels. DataFrame.update
# would rewrite every column. Because test rows align to NaN, it can also
# upcast int columns (player, fold) to float depending on the pandas version,
# which then shows up as "3.0" in the CSV.
df_full.loc[df_train_full.index, 'fold'] = df_train_full['fold']

#----------------------------------------------------------------------------------------------------------------------#
# Save DataFrames                                                                                                      #
#----------------------------------------------------------------------------------------------------------------------#
# Persist the combined train/test frame.
train_df_path_saved = os.path.join(config.save_dir, "train_df.csv")
df_full.to_csv(train_df_path_saved, index=False)

# Persist the challenge frame only if one was actually built.
challenge_df_saved = bool(config.challenge_game_ending) and not df_challenge.empty
if challenge_df_saved:
    challenge_df_path_saved = os.path.join(config.save_dir, "challenge_df.csv")
    df_challenge.to_csv(challenge_df_path_saved, index=False)

print("\nDataFrames saved:")
print(f" - Training/Test DataFrame: {train_df_path_saved}")
if challenge_df_saved:
    print(f" - Challenge DataFrame: {challenge_df_path_saved}")

# Per-split views reused by the summary lines.
train_rows = df_full[df_full['split'] == 'train']
test_rows = df_full[df_full['split'] == 'test']

print("\nSummary:")
print(f"Number of total images: {len(df_full)}")
print(f"Number of training images: {len(train_rows)}")
print(f"Number of testing images: {len(test_rows)}")
print(f"Number of unique training players: {train_rows['player'].nunique()}")
print(f"Number of unique testing players: {test_rows['player'].nunique()}")
if challenge_df_saved:
    print(f"Number of challenge images: {len(df_challenge)}")
    print(f"Number of unique challenge players: {df_challenge['player'].nunique()}")


In [2]:
# train

import os
import sys
import torch
import random
import pandas as pd
import numpy as np
from dataclasses import dataclass
from torch.utils.data import DataLoader  

# Import custom modules
from clipreid.loss import ClipLoss
from clipreid.trainer import train, get_scheduler
from clipreid.utils import Logger, setup_system, print_line
from clipreid.model import TimmModel, OpenClipModel
from clipreid.transforms import get_transforms
from clipreid.evaluator import predict, compute_dist_matrix, compute_scores

from typing import Optional, Tuple, Union


@dataclass
class Configuration:
    """Training / evaluation settings for CLIP-ReID on NSVA player crops.

    ``model`` selects the backbone. A ``(architecture, pretrained_tag)`` tuple
    builds an OpenCLIP model, whose own image size and normalization are then
    used. A plain string builds a timm model, which uses ``img_size`` / ``mean`` /
    ``std`` from here.
    """

    # Model configurations
    model: Union[str, Tuple[str, str]] = ('ViT-L-14', 'openai')
    img_size: Tuple[int, int] = (224, 224)
    mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
    std: Tuple[float, float, float] = (0.229, 0.224, 0.225)

    # Data settings
    train_on_all: bool = False
    fold: int = -1  # -1: use the CSV's train/test split; k >= 0: hold out fold k
    seed: int = 2024
    epochs: int = 4
    batch_size: int = 16
    batch_size_eval: int = 64
    gpu_ids: Tuple[int, ...] = (0,)
    mixed_precision: bool = True
    lr: float = 0.00004
    scheduler: str = "polynomial"
    warmup_epochs: float = 1.0
    lr_end: float = 0.00001
    gradient_clipping: Optional[float] = None
    grad_checkpointing: bool = False
    gradient_accumulation: int = 1
    label_smoothing: float = 0.1
    zero_shot: bool = True
    rerank: bool = True
    normalize_features: bool = True
    data_dir: str = "/home/minxing/datasets/NSVA_157_zeroshot_crops_new/player_crops"
    data_csv: str = "./data/nsva/train_df.csv"
    prob_flip: float = 0.5
    model_path: str = "./model_nsva"
    checkpoint_start: Optional[str] = None
    verbose: bool = True
    num_workers: int = 0 if os.name == 'nt' else 8
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    # NOTE(review): benchmark=True lets cuDNN pick algorithms by timing, which
    # can weaken run-to-run reproducibility even with deterministic=True.
    cudnn_benchmark: bool = True
    cudnn_deterministic: bool = True

    # Drop the projection head of the CLIP visual encoder (passed to OpenClipModel).
    # This is a real (annotated) field so it can be overridden via __init__.
    # It is declared last so positional arguments of the other fields keep
    # their order.
    remove_proj: bool = True

# Initialize configuration
config = Configuration()

# Output directory for this run:
#   <config.model_path>/<model tag>/(all_data | fold<k>)_seed_<seed>
# Tuple models (OpenCLIP) are tagged "<arch>_<pretrained>"; string models (timm)
# use the name as-is.
if isinstance(config.model, tuple):
    model_tag = "{}_{}".format(config.model[0], config.model[1])
else:
    model_tag = config.model

if config.train_on_all:
    run_tag = "all_data_seed_{}".format(config.seed)
else:
    # The timm branch previously omitted config.fold here, so its format
    # string raised IndexError.
    run_tag = "fold{}_seed_{}".format(config.fold, config.seed)

model_path = "{}/{}/{}".format(config.model_path, model_tag, run_tag)

os.makedirs(model_path, exist_ok=True)

# Tee everything printed from here on into <model_path>/log.txt as well.
sys.stdout = Logger(f"{model_path}/log.txt")

# Seeding and cuDNN flags for (best-effort) reproducibility.
setup_system(
    cudnn_deterministic=config.cudnn_deterministic,
    cudnn_benchmark=config.cudnn_benchmark,
    seed=config.seed,
)

#----------------------------------------------------------------------------------------------------------------------#  
# Model                                                                                                                #
#----------------------------------------------------------------------------------------------------------------------#  
print("\nModel: {}".format(config.model))

if isinstance(config.model, tuple):
    # OpenCLIP backbone: (architecture, pretrained tag). Its input size comes
    # from the model, and the CLIP normalization constants below replace
    # config.mean / config.std.
    model = OpenClipModel(config.model[0],
                          config.model[1],
                          remove_proj=config.remove_proj)

    img_size = model.get_image_size()

    mean = (0.48145466, 0.4578275, 0.40821073)
    std = (0.26862954, 0.26130258, 0.27577711)

    if config.grad_checkpointing:
        model.set_grad_checkpoint(enable=config.grad_checkpointing)
else:
    # timm backbone: sizes/normalization come from the configuration.
    model = TimmModel(config.model, pretrained=True)

    img_size = config.img_size
    mean = config.mean
    std = config.std

# Optional warm start from a saved state dict.
if config.checkpoint_start is not None:
    print("\nStart from:", config.checkpoint_start)
    # map_location="cpu": the model is only moved to config.device further
    # below, and this lets a checkpoint saved on a GPU load on CPU-only hosts.
    model_state_dict = torch.load(config.checkpoint_start, map_location="cpu")
    model.load_state_dict(model_state_dict, strict=True)

# Data parallel (only when several GPUs are both present and requested).
print("\nGPUs available:", torch.cuda.device_count())
if torch.cuda.device_count() > 1 and len(config.gpu_ids) > 1:
    print("Using Data Parallel with GPU IDs: {}".format(config.gpu_ids))
    # NOTE(review): DataParallel requires the module on device_ids[0], so
    # config.device should be f"cuda:{config.gpu_ids[0]}" - keep them in sync.
    model = torch.nn.DataParallel(model, device_ids=config.gpu_ids)
    multi_gpu = True
else:
    multi_gpu = False

model = model.to(config.device)

print("\nImage Size:", img_size)
print("Mean: {}".format(mean))
print("Std:  {}".format(std))

#----------------------------------------------------------------------------------------------------------------------#  
# DataLoader                                                                                                           #
#----------------------------------------------------------------------------------------------------------------------#  
# Data. img_id is read as text: crop file stems can look numeric, and letting
# pandas parse them as integers would drop leading zeros and break the
# `img_id + '.jpg'` path construction in the datasets.
df = pd.read_csv(config.data_csv, dtype={"img_id": str})

if config.train_on_all or config.fold == -1:
    # Use the train/test split stored in the CSV.
    df_train = df[df["split"] == "train"]
    df_test = df[df["split"] == "test"]
else:
    # Cross-validation: hold out fold `config.fold` of the *training* split.
    # Restricting to split == 'train' is required. Test-split rows carry
    # fold == -1 and would otherwise leak into df_train.
    # NOTE(review): with the CSV produced by the preprocessing step, all
    # training rows are img_type 'g', so this held-out fold has no queries.
    # Confirm the evaluator handles that.
    df_pool = df[df["split"] == "train"]
    df_train = df_pool[df_pool["fold"] != config.fold]
    df_test = df_pool[df_pool["fold"] == config.fold]

# Transforms
val_transforms, train_transforms = get_transforms(img_size, mean, std)

# Custom Dataset Classes
import numpy as np
from collections import defaultdict
from torch.utils.data import Dataset
import cv2
import copy
import torch
from torchvision import transforms

class CustomTrainDataset(Dataset):
    """Training dataset yielding ``(query_img, positive_img, player_label)``.

    For each sampled crop (the query), another crop of the same player is drawn
    at random as the positive ("gallery") image. If the player has a single
    crop, the query itself is used.

    Samples are tracked by row position, not by ``img_id``. Crops that share a
    file stem in different player folders therefore cannot be confused (an
    ``img_id``-keyed lookup would silently map them to one row).

    Args:
        df: DataFrame with at least ``img_id``, ``folder`` and ``player`` columns.
        image_transforms: callable invoked as ``image_transforms(image=arr)['image']``.
        prob_flip: probability of horizontally flipping query and positive together.
        shuffle_batch_size: chunk size :meth:`shuffle` fills with distinct players.
            It should match the DataLoader batch size (with an unshuffled loader).
    """

    def __init__(self,
                 df,
                 image_transforms=None,
                 prob_flip=0.5,
                 shuffle_batch_size=16):

        self.df = df.reset_index(drop=True)
        self.image_transforms = image_transforms
        self.prob_flip = prob_flip
        self.shuffle_batch_size = shuffle_batch_size

        print("\nImages train: {}".format(len(self.df)))
        self.images = self.df['img_id'].tolist()

        # Per-row lookups (row position == index after reset_index).
        self._paths = [os.path.join(folder, str(img_id) + '.jpg')
                       for folder, img_id in zip(self.df['folder'], self.df['img_id'])]
        self._players = self.df['player'].tolist()

        # player -> row positions of all crops of that player
        self.player_images = defaultdict(list)
        for row, player in enumerate(self._players):
            self.player_images[player].append(row)

        # row -> row positions of the *other* crops of the same player
        self.player_images_other = {
            row: np.array([r for r in self.player_images[player] if r != row], dtype=np.int64)
            for row, player in enumerate(self._players)
        }

        self.samples = list(range(len(self.df)))
        self.shuffle()

    def _load_image(self, row):
        """Read the crop at row position ``row`` as RGB and apply the transforms."""
        path = self._paths[row]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None instead of raising on a missing/unreadable file.
            raise FileNotFoundError("Could not read image: {}".format(path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.image_transforms:
            img = self.image_transforms(image=img)['image']
        return img

    def __getitem__(self, index):
        row_query = self.samples[index]
        img_query = self._load_image(row_query)

        # Another crop of the same player as positive (or the query itself).
        others = self.player_images_other[row_query]
        if len(others) > 0:
            row_gallery = int(np.random.choice(others, 1)[0])
        else:
            row_gallery = row_query
        img_gallery = self._load_image(row_gallery)

        player = torch.tensor(int(self._players[row_query]), dtype=torch.long)

        # Flip both images together so the pair stays consistent.
        if np.random.random() < self.prob_flip:
            img_query = transforms.functional.hflip(img_query)
            img_gallery = transforms.functional.hflip(img_gallery)

        return img_query, img_gallery, player

    def __len__(self):
        return len(self.samples)

    def shuffle(self):
        """Reorder samples so each chunk of ``shuffle_batch_size`` has distinct players.

        Greedy: crops whose player is already in the current chunk are
        re-queued. When a full pass over the remaining queue cannot extend the
        chunk, the partial chunk is emitted as-is.
        NOTE(review): that flush shifts later chunks off the DataLoader batch
        boundaries. It typically only happens near the tail of the ordering.
        """
        from collections import deque  # local: not among the file-level imports

        order = list(range(len(self.images)))
        random.shuffle(order)
        pending = deque(order)  # O(1) popleft instead of list.pop(0)

        ordered = []
        batch = []
        players_batch = set()
        break_counter = 0

        while pending:
            row = pending.popleft()
            player = self._players[row]

            if player not in players_batch:
                players_batch.add(player)
                batch.append(row)

                if len(batch) == self.shuffle_batch_size:
                    ordered.extend(batch)
                    batch = []
                    players_batch = set()
                    break_counter = 0
            else:
                # Retry this crop later.
                pending.append(row)
                break_counter += 1

                if break_counter >= len(pending):
                    # No remaining crop fits this chunk: emit it partially filled.
                    ordered.extend(batch)
                    batch = []
                    players_batch = set()
                    break_counter = 0

        if batch:
            ordered.extend(batch)

        self.samples = ordered
        print("\nShuffle Training Data:")
        print("Length Train:", len(self.samples))
        if len(self.samples) > 0:
            print("First Element: {}".format(self.images[self.samples[0]]))

class CustomTestDataset(Dataset):
    '''
    Evaluation dataset over a DataFrame of player crops.

    Columns read: ``img_id``, ``player``, ``img_type`` and ``folder``. Rows with
    ``img_type == "q"`` are queries; every other row is treated as gallery. The image for a
    row is read from ``<folder>/<img_id>.jpg``.

    Lists built at construction (consumed by the evaluator):
        query   -- (img_path, player, 0) for query rows
        gallery -- (img_path, player, 1) for all other rows
        all     -- (img_path, player, -1) for every row
    '''

    def __init__(self,
                 df,
                 image_transforms=None):

        # Positional 0..n-1 index so that __getitem__(index) addresses the index-th row
        self.df = df.reset_index(drop=True)
        self.image_transforms = image_transforms
        self.images = self.df['img_id'].tolist()

        self.query = []
        self.gallery = []
        self.all = []

        # Iterate the needed columns directly; iterrows() builds a Series for every row.
        for img_id, player, img_type, folder in zip(self.df['img_id'],
                                                    self.df['player'],
                                                    self.df['img_type'],
                                                    self.df['folder']):
            img_path = self._image_path(folder, img_id)
            player = int(player)
            self.all.append((img_path, player, -1))

            if img_type == "q":
                self.query.append((img_path, player, 0))
            else:
                self.gallery.append((img_path, player, 1))

    @staticmethod
    def _image_path(folder, img_id):
        '''Path of the crop file ``<folder>/<img_id>.jpg``.'''
        return os.path.join(folder, img_id + '.jpg')

    def __getitem__(self, index):
        '''
        Return ``(img, img_path, player, img_type)`` for row ``index``.

        ``img`` is the image converted from BGR to RGB, passed through
        ``image_transforms(image=img)['image']`` when transforms are set. ``img_type`` is 0
        for query rows and 1 otherwise.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        '''
        row = self.df.loc[index]
        img_path = self._image_path(row['folder'], row['img_id'])

        img = cv2.imread(img_path)
        # cv2.imread does not raise on a missing or unreadable file; it returns None, which
        # would otherwise surface as an opaque assertion error inside cvtColor.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.image_transforms:
            img = self.image_transforms(image=img)['image']

        player = int(row['player'])
        img_type = 0 if row['img_type'] == "q" else 1

        return img, img_path, player, img_type

    def __len__(self):
        return len(self.df)

# Initialize the custom datasets
train_dataset = CustomTrainDataset(df=df_train,
                                   image_transforms=train_transforms,
                                   prob_flip=config.prob_flip,
                                   shuffle_batch_size=config.batch_size)

# shuffle=False on purpose: train_dataset.shuffle() lays the samples out in groups of
# batch_size with distinct players, so the loader must keep that order.
train_loader = DataLoader(train_dataset,
                          batch_size=config.batch_size,
                          num_workers=config.num_workers,
                          shuffle=False,
                          pin_memory=True,
                          drop_last=True)

# Validation
test_dataset = CustomTestDataset(df=df_test,
                                 image_transforms=val_transforms)

test_loader = DataLoader(test_dataset,
                         batch_size=config.batch_size_eval,
                         num_workers=config.num_workers,
                         shuffle=False,
                         pin_memory=True)

#----------------------------------------------------------------------------------------------------------------------#
# Loss                                                                                                                 #
#----------------------------------------------------------------------------------------------------------------------#
loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=config.label_smoothing)
loss_function = ClipLoss(loss_function=loss_fn,
                         device=config.device)

#----------------------------------------------------------------------------------------------------------------------#
# Optimizer and Scaler                                                                                                 #
#----------------------------------------------------------------------------------------------------------------------#
optimizer = torch.optim.AdamW(model.parameters(), lr=config.lr)

if config.mixed_precision:
    # torch.cuda.amp.GradScaler is deprecated (it emits a FutureWarning in recent PyTorch);
    # torch.amp.GradScaler("cuda", ...) is the replacement with the same interface.
    scaler = torch.amp.GradScaler("cuda", init_scale=2.**10)
else:
    scaler = None

#----------------------------------------------------------------------------------------------------------------------#
# Scheduler                                                                                                            #
#----------------------------------------------------------------------------------------------------------------------#
if config.scheduler is not None:
    scheduler = get_scheduler(config,
                              optimizer,
                              train_loader_length=len(train_loader))
else:
    scheduler = None

#----------------------------------------------------------------------------------------------------------------------#
# Zero Shot                                                                                                            #
#----------------------------------------------------------------------------------------------------------------------#
if config.zero_shot:

    print_line(name="Zero-Shot", length=80)

    features_dict = predict(model,
                            dataloader=test_loader,
                            device=config.device,
                            normalize_features=config.normalize_features,
                            verbose=config.verbose)

    dist_matrix, dist_matrix_rerank = compute_dist_matrix(features_dict,
                                                          test_dataset.query,
                                                          test_dataset.gallery,
                                                          rerank=config.rerank)

    print("\nWithout re-ranking:")
    mAP = compute_scores(dist_matrix,
                         test_dataset.query,
                         test_dataset.gallery)

    if dist_matrix_rerank is not None:
        print("\nWith re-ranking:")
        # Separate name (as in the epoch loop) so the plain score is not overwritten
        mAP_rerank = compute_scores(dist_matrix_rerank,
                                    test_dataset.query,
                                    test_dataset.gallery)

#----------------------------------------------------------------------------------------------------------------------#
# Train                                                                                                                #
#----------------------------------------------------------------------------------------------------------------------#
# Save the unwrapped module whenever the model is actually wrapped for multi-GPU training,
# so state-dict keys carry no "module." prefix and load strictly into a plain model.
if isinstance(model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
    model_to_save = model.module
else:
    model_to_save = model

for epoch in range(1, config.epochs+1):

    print_line(name="Epoch: {}".format(epoch), length=80)

    # Train
    train_loss = train(model,
                       dataloader=train_loader,
                       loss_function=loss_function,
                       optimizer=optimizer,
                       device=config.device,
                       scheduler=scheduler,
                       scaler=scaler,
                       gradient_accumulation=config.gradient_accumulation,
                       gradient_clipping=config.gradient_clipping,
                       verbose=config.verbose,
                       multi_gpu=multi_gpu)

    print("Avg. Train Loss = {:.4f} - Lr = {:.6f}\n".format(train_loss,
                                                           optimizer.param_groups[0]['lr']))
    # Evaluate
    features_dict = predict(model,
                            dataloader=test_loader,
                            device=config.device,
                            normalize_features=config.normalize_features,
                            verbose=config.verbose)

    dist_matrix, dist_matrix_rerank = compute_dist_matrix(features_dict,
                                                          test_dataset.query,
                                                          test_dataset.gallery,
                                                          rerank=config.rerank)

    print("\nWithout re-ranking:")
    mAP = compute_scores(dist_matrix,
                         test_dataset.query,
                         test_dataset.gallery)

    if dist_matrix_rerank is not None:
        print("\nWith re-ranking:")
        mAP_rerank = compute_scores(dist_matrix_rerank,
                                    test_dataset.query,
                                    test_dataset.gallery)

    # Save model
    checkpoint_path = os.path.join(model_path, "weights_e{}.pth".format(epoch))
    torch.save(model_to_save.state_dict(), checkpoint_path)

    # Shuffle data for next epoch
    train_loader.dataset.shuffle()


  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()



Model: ('ViT-L-14', 'openai')




Remove Projection Layer - old output size: 768 - new output size: 1024

GPUs available: 2

Image Size: (224, 224)
Mean: (0.48145466, 0.4578275, 0.40821073)
Std:  (0.26862954, 0.26130258, 0.27577711)

Images train: 15193

Shuffle Training Data:
Length Train: 15193
First Element: 000128_9

Warmup Epochs: 1.0 - Warmup Steps: 949
Train Epochs:  4 - Train Steps:  3796

Scheduler: polynomial - max LR: 4e-05 - end LR: 1e-05

----------------------------------[Zero-Shot]-----------------------------------


  scaler = torch.cuda.amp.GradScaler(init_scale=2.**10)
Test : 100%|##########| 50/50 [00:34<00:00,  1.43it/s]



Without re-ranking:
mAP: 58.47%
CMC Scores    allshots      cuhk03  market1501
  top-1         21.69%      57.87%      98.67%
  top-5         30.02%      86.27%     100.00%
  top-10        34.49%      94.93%     100.00%

With re-ranking:
mAP: 66.73%
CMC Scores    allshots      cuhk03  market1501
  top-1         26.08%      66.53%     100.00%
  top-5         40.27%      86.67%     100.00%
  top-10        44.98%      94.80%     100.00%

-----------------------------------[Epoch: 1]-----------------------------------


Train: 100%|##########| 949/949 [06:33<00:00,  2.41it/s, loss=2.82, lr=0.000040]


Avg. Train Loss = 2.3967 - Lr = 0.000040



Test : 100%|##########| 50/50 [00:41<00:00,  1.21it/s]



Without re-ranking:
mAP: 62.92%
CMC Scores    allshots      cuhk03  market1501
  top-1         24.49%      58.40%     100.00%
  top-5         32.29%      88.67%     100.00%
  top-10        36.73%      97.60%     100.00%

With re-ranking:
mAP: 72.20%
CMC Scores    allshots      cuhk03  market1501
  top-1         29.21%      69.20%      98.67%
  top-5         46.26%      90.00%      98.67%
  top-10        51.39%      96.80%     100.00%

Shuffle Training Data:
Length Train: 15193
First Element: 000417_7

-----------------------------------[Epoch: 2]-----------------------------------


Train: 100%|##########| 949/949 [06:41<00:00,  2.36it/s, loss=2.78, lr=0.000023]


Avg. Train Loss = 1.9904 - Lr = 0.000023



Test : 100%|##########| 50/50 [00:41<00:00,  1.21it/s]



Without re-ranking:
mAP: 49.60%
CMC Scores    allshots      cuhk03  market1501
  top-1         19.21%      48.40%     100.00%
  top-5         25.61%      78.80%     100.00%
  top-10        28.58%      92.53%     100.00%

With re-ranking:
mAP: 61.77%
CMC Scores    allshots      cuhk03  market1501
  top-1         24.82%      59.20%     100.00%
  top-5         38.62%      82.67%     100.00%
  top-10        43.92%      94.27%     100.00%

Shuffle Training Data:
Length Train: 15193
First Element: 000119_7

-----------------------------------[Epoch: 3]-----------------------------------


Train: 100%|##########| 949/949 [06:42<00:00,  2.36it/s, loss=2.52, lr=0.000013]


Avg. Train Loss = 1.7472 - Lr = 0.000013



Test : 100%|##########| 50/50 [00:42<00:00,  1.18it/s]



Without re-ranking:
mAP: 62.46%
CMC Scores    allshots      cuhk03  market1501
  top-1         27.22%      58.13%     100.00%
  top-5         35.38%      90.13%     100.00%
  top-10        39.47%      95.73%     100.00%

With re-ranking:
mAP: 73.25%
CMC Scores    allshots      cuhk03  market1501
  top-1         32.06%      70.40%      98.67%
  top-5         51.84%      91.60%      98.67%
  top-10        56.28%      97.47%      98.67%

Shuffle Training Data:
Length Train: 15193
First Element: 000358_11

-----------------------------------[Epoch: 4]-----------------------------------


Train: 100%|##########| 949/949 [06:43<00:00,  2.35it/s, loss=2.54, lr=0.000010]


Avg. Train Loss = 1.5876 - Lr = 0.000010



Test : 100%|##########| 50/50 [00:41<00:00,  1.19it/s]



Without re-ranking:
mAP: 54.67%
CMC Scores    allshots      cuhk03  market1501
  top-1         22.29%      53.07%      98.67%
  top-5         29.52%      80.93%     100.00%
  top-10        33.35%      93.33%     100.00%

With re-ranking:
mAP: 65.73%
CMC Scores    allshots      cuhk03  market1501
  top-1         29.58%      61.07%     100.00%
  top-5         45.14%      82.13%     100.00%
  top-10        50.51%      94.27%     100.00%

Shuffle Training Data:
Length Train: 15193
First Element: 000118_9


In [None]:
# evaluate

import os
import torch
import pandas as pd
import numpy as np
from dataclasses import dataclass
from torch.utils.data import DataLoader

from clipreid.model import TimmModel, OpenClipModel
from clipreid.transforms import get_transforms
from clipreid.dataset import TestDataset, ChallengeDataset
from clipreid.evaluator import predict, compute_dist_matrix, compute_scores, write_mat_csv
from clipreid.utils import print_line

@dataclass
class Configuration:
    '''
    Settings for evaluating trained checkpoints on the test split and the challenge set.

    ``model`` selects the backbone:

    * a tuple ``(architecture, pretrained_tag)`` -> Open CLIP model. Known combinations:
        RN50               : openai, yfcc15m, cc12m
        RN50-quickgelu     : openai, yfcc15m, cc12m
        RN101              : openai, yfcc15m
        RN101-quickgelu    : openai, yfcc15m
        RN50x4             : openai
        RN50x16            : openai
        RN50x64            : openai
        ViT-B-32           : openai, laion2b_e16, laion400m_e31, laion400m_e32
        ViT-B-32-quickgelu : openai, laion400m_e31, laion400m_e32
        ViT-B-16           : openai, laion400m_e31, laion400m_e32
        ViT-B-16-plus-240  : laion400m_e31, laion400m_e32
        ViT-L-14           : openai, laion400m_e31, laion400m_e32
        ViT-L-14-336       : openai
        ViT-H-14           : laion2b_s32b_b79k
        ViT-g-14           : laion2b_s12b_b42k

    * a string -> timm model name, e.g. 'convnext_base_in22ft1k', 'convnext_large_in22ft1k',
      'vit_base_patch16_224', 'vit_large_patch16_224', ...
      (see https://github.com/rwightman/pytorch-image-models/blob/master/results/results-imagenet.csv)
    '''

    # --- Backbone ---
    # (clip architecture, pretrained tag) for Open CLIP, or a plain timm model name
    model: tuple = ('ViT-L-14', 'openai')
    # Drop the projection head of Clip ViT models
    remove_proj: bool = True

    # --- Preprocessing, used only for timm models (Clip models bring their own) ---
    img_size: tuple = (224, 224)
    mean: tuple = (0.485, 0.456, 0.406)     # ImageNet channel means
    std: tuple = (0.229, 0.224, 0.225)      # ImageNet channel stds

    # --- Evaluation ---
    batch_size: int = 64
    normalize_features: bool = True         # L2-normalize embeddings before distances

    # -1 evaluates the predefined test split; a value >= 0 evaluates that custom fold
    fold: int = -1

    # One path, or several paths whose challenge distance matrices are ensembled
    checkpoints: tuple = (
        "./model_nsva/ViT-L-14_openai/fold0_seed_2024/weights_e4.pth",
        "./model_nsva/ViT-L-14_openai/fold1_seed_2024/weights_e4.pth"
    )

    # --- Data locations ---
    data_dir: str = "./data/nsva"
    challenge_csv: str = "./data/nsva/challenge_df.csv"

    # --- Runtime ---
    verbose: bool = True                    # progress bars
    # Windows gets 0 loader workers, everything else 8
    num_workers: int = 8 if os.name != 'nt' else 0
    # First GPU when CUDA is present, CPU otherwise
    device: str = 'cpu' if not torch.cuda.is_available() else 'cuda:0'

# Initialize configuration
config = Configuration()

#----------------------------------------------------------------------------------------------------------------------#
# Model                                                                                                                #
#----------------------------------------------------------------------------------------------------------------------#

print("\nModel: {}".format(config.model))

if isinstance(config.model, tuple):
    # Open CLIP model: (architecture, pretrained tag); uses CLIP normalization constants
    model = OpenClipModel(config.model[0],
                          config.model[1],
                          remove_proj=config.remove_proj)

    img_size = model.get_image_size()

    mean = (0.48145466, 0.4578275, 0.40821073)
    std = (0.26862954, 0.26130258, 0.27577711)
else:
    # timm model by name; size and normalization come from the configuration
    model = TimmModel(config.model,
                      pretrained=True)

    img_size = config.img_size
    mean = config.mean
    std = config.std

# Every checkpoint is loaded into this one instance, so move it to the device once
model = model.to(config.device)

print("\nImage Size:", img_size)
print("Mean: {}".format(mean))
print("Std:  {}".format(std))

#----------------------------------------------------------------------------------------------------------------------#
# DataLoader (independent of the checkpoint, so built once)                                                            #
#----------------------------------------------------------------------------------------------------------------------#

# Transforms
val_transforms, _ = get_transforms(img_size, mean, std)

# Dataframes
train_df_path = os.path.join(config.data_dir, "train_df.csv")
challenge_df_path = config.challenge_csv

# Raise instead of exit(1): in Jupyter, exit() shuts down the kernel, while an exception
# reports the missing file just as clearly when this runs as a script.
for description, df_path in (("Training", train_df_path), ("Challenge", challenge_df_path)):
    if not os.path.exists(df_path):
        raise FileNotFoundError(f"{description} DataFrame not found at {df_path}")

df_train = pd.read_csv(train_df_path)
df_challenge = pd.read_csv(challenge_df_path)

if config.fold == -1:
    # Use given test split
    df_test = df_train[df_train["split"] == "test"]
else:
    # Use custom folds
    df_test = df_train[df_train["fold"] == config.fold]

# Validation
test_dataset = TestDataset(img_path=config.data_dir,
                           df=df_test,
                           image_transforms=val_transforms)

test_loader = DataLoader(test_dataset,
                         batch_size=config.batch_size,
                         num_workers=config.num_workers,
                         shuffle=False,
                         pin_memory=True)

# Challenge
challenge_dataset = ChallengeDataset(df=df_challenge,
                                     image_transforms=val_transforms)

challenge_loader = DataLoader(challenge_dataset,
                              batch_size=config.batch_size,
                              num_workers=config.num_workers,
                              shuffle=False,
                              pin_memory=True)


def _report(header, label, dist_matrix, dataset, save_path):
    '''
    Print the scores of one distance matrix and write the matrix to CSV.

    Prints ``header``, scores ``dist_matrix`` against ``dataset.query`` / ``dataset.gallery``
    with ``compute_scores(..., cmc_scores=True)``, prints ``"<label>: <mAP>"`` and writes the
    matrix to ``save_path`` via ``write_mat_csv``. Returns the value of ``compute_scores``.
    '''
    print(header)
    scores = compute_scores(dist_matrix,
                            dataset.query,
                            dataset.gallery,
                            cmc_scores=True)

    print(f"{label}: {scores['mAP']:.4f}")

    print("Writing distance matrix:", save_path)
    write_mat_csv(save_path,
                  dist_matrix,
                  dataset.query,
                  dataset.gallery)
    return scores


dist_matrix_list = []
dist_matrix_rerank_list = []
evaluated_checkpoints = []

# Accept a single path as well as a list/tuple of paths
if isinstance(config.checkpoints, (list, tuple)):
    checkpoints = list(config.checkpoints)
else:
    checkpoints = [config.checkpoints]

for checkpoint in checkpoints:

    print_line(name=checkpoint, length=80)

    # Load pretrained Checkpoint
    if not os.path.exists(checkpoint):
        print(f"Checkpoint '{checkpoint}' does not exist. Skipping.")
        continue
    # weights_only=True: the checkpoints are plain state dicts, so refuse arbitrary pickled objects
    model_state_dict = torch.load(checkpoint, map_location=config.device, weights_only=True)
    model.load_state_dict(model_state_dict, strict=True)

    evaluated_checkpoints.append(checkpoint)
    checkpoint_dir = os.path.dirname(checkpoint)

    #------------------------------------------------------------------------------------------------------------------#
    # Test                                                                                                             #
    #------------------------------------------------------------------------------------------------------------------#
    print_line(name="Eval Fold: {}".format(config.fold), length=80)

    # Extract features for test set
    features_dict = predict(model,
                            dataloader=test_loader,
                            device=config.device,
                            normalize_features=config.normalize_features,
                            verbose=config.verbose)

    # Compute distance matrix for test set
    dist_matrix_test, dist_matrix_test_rerank = compute_dist_matrix(features_dict,
                                                                    test_dataset.query,
                                                                    test_dataset.gallery,
                                                                    rerank=True)

    mAP_test = _report("\nWithout re-ranking:",
                       "Test mAP",
                       dist_matrix_test,
                       test_dataset,
                       os.path.join(checkpoint_dir, "test_dmat.csv"))

    mAP_test_rerank = _report("\nWith re-ranking:",
                              "Test mAP with re-ranking",
                              dist_matrix_test_rerank,
                              test_dataset,
                              os.path.join(checkpoint_dir, "test_dmat_rerank.csv"))

    #------------------------------------------------------------------------------------------------------------------#
    # Challenge                                                                                                        #
    #------------------------------------------------------------------------------------------------------------------#
    print_line(name="Challenge", length=80)

    # Extract features for challenge set
    features_dict_challenge = predict(model,
                                      dataloader=challenge_loader,
                                      device=config.device,
                                      normalize_features=config.normalize_features,
                                      verbose=config.verbose)

    # Compute distance matrix for challenge set
    dist_matrix_challenge, dist_matrix_challenge_rerank = compute_dist_matrix(features_dict_challenge,
                                                                              challenge_dataset.query,
                                                                              challenge_dataset.gallery,
                                                                              rerank=True)

    # Collect non-empty matrices for the ensemble
    if dist_matrix_challenge.size != 0:
        dist_matrix_list.append(dist_matrix_challenge)
    if dist_matrix_challenge_rerank.size != 0:
        dist_matrix_rerank_list.append(dist_matrix_challenge_rerank)

    mAP_challenge = _report("\nChallenge Set without re-ranking:",
                            "Challenge mAP",
                            dist_matrix_challenge,
                            challenge_dataset,
                            os.path.join(checkpoint_dir, "challenge_dmat.csv"))

    mAP_challenge_rerank = _report("\nChallenge Set with re-ranking:",
                                   "Challenge mAP with re-ranking",
                                   dist_matrix_challenge_rerank,
                                   challenge_dataset,
                                   os.path.join(checkpoint_dir, "challenge_dmat_rerank.csv"))


#----------------------------------------------------------------------------------------------------------------------#
# Ensemble                                                                                                             #
#----------------------------------------------------------------------------------------------------------------------#
if len(dist_matrix_list) > 1:

    print_line(name="Ensemble", length=80)

    # Write next to the first checkpoint that was actually evaluated; config.checkpoints[0]
    # may have been skipped because it does not exist.
    ensemble_dir = os.path.dirname(evaluated_checkpoints[0])

    # Without re-ranking: element-wise mean of the per-checkpoint distance matrices
    dist_matrix_ensemble = np.stack(dist_matrix_list, axis=0).mean(0)
    save_path_ensemble = os.path.join(ensemble_dir, "challenge_dmat_ensemble.csv")
    print("Writing distance matrix:", save_path_ensemble)
    write_mat_csv(save_path_ensemble,
                  dist_matrix_ensemble,
                  challenge_dataset.query,
                  challenge_dataset.gallery)

    # With re-ranking (np.stack fails on an empty list, so skip when nothing was collected)
    if dist_matrix_rerank_list:
        dist_matrix_rerank_ensemble = np.stack(dist_matrix_rerank_list, axis=0).mean(0)
        save_path_rerank_ensemble = os.path.join(ensemble_dir, "challenge_dmat_rerank_ensemble.csv")
        print("Writing distance matrix:", save_path_rerank_ensemble)
        write_mat_csv(save_path_rerank_ensemble,
                      dist_matrix_rerank_ensemble,
                      challenge_dataset.query,
                      challenge_dataset.gallery)

In [None]:
# predict

import os
import torch
import pandas as pd
import numpy as np
from dataclasses import dataclass
from torch.utils.data import DataLoader

from clipreid.model import TimmModel, OpenClipModel
from clipreid.transforms import get_transforms
from clipreid.dataset import TestDataset, ChallengeDataset
from clipreid.evaluator import predict, compute_dist_matrix, compute_scores, write_mat_csv
from clipreid.utils import print_line

@dataclass
class Configuration:
    '''
    Settings for the prediction / evaluation cell.

    ``model`` is either an OpenCLIP ``(architecture, pretrained_tag)`` tuple or a
    timm model name string. ``checkpoints`` is a tuple of weight files (several
    files -> their challenge distance matrices are ensembled) or a single path string.

    --------------------------------------------------------------------------
    Open Clip Models:
    --------------------------------------------------------------------------
    - ('RN50', 'openai')
    - ('RN50', 'yfcc15m')
    - ('RN50', 'cc12m')
    - ('RN50-quickgelu', 'openai')
    - ('RN50-quickgelu', 'yfcc15m')
    - ('RN50-quickgelu', 'cc12m')
    - ('RN101', 'openai')
    - ('RN101', 'yfcc15m')
    - ('RN101-quickgelu', 'openai')
    - ('RN101-quickgelu', 'yfcc15m')
    - ('RN50x4', 'openai')
    - ('RN50x16', 'openai')
    - ('RN50x64', 'openai')
    - ('ViT-B-32', 'openai')
    - ('ViT-B-32', 'laion2b_e16')
    - ('ViT-B-32', 'laion400m_e31')
    - ('ViT-B-32', 'laion400m_e32')
    - ('ViT-B-32-quickgelu', 'openai')
    - ('ViT-B-32-quickgelu', 'laion400m_e31')
    - ('ViT-B-32-quickgelu', 'laion400m_e32')
    - ('ViT-B-16', 'openai')
    - ('ViT-B-16', 'laion400m_e31')
    - ('ViT-B-16', 'laion400m_e32')
    - ('ViT-B-16-plus-240', 'laion400m_e31')
    - ('ViT-B-16-plus-240', 'laion400m_e32')
    - ('ViT-L-14', 'openai')
    - ('ViT-L-14', 'laion400m_e31')
    - ('ViT-L-14', 'laion400m_e32')
    - ('ViT-L-14-336', 'openai')
    - ('ViT-H-14', 'laion2b_s32b_b79k')
    - ('ViT-g-14', 'laion2b_s12b_b42k')
    --------------------------------------------------------------------------
    Timm Models:
    --------------------------------------------------------------------------
    - 'convnext_base_in22ft1k'
    - 'convnext_large_in22ft1k'
    - 'vit_base_patch16_224'
    - 'vit_large_patch16_224'
    - ...
    - https://github.com/rwightman/pytorch-image-models/blob/master/results/results-imagenet.csv
    --------------------------------------------------------------------------
    '''

    # Model: ('name of Clip model', 'name of dataset') tuple | 'name of Timm model' string
    model: Union[Tuple[str, str], str] = ('ViT-L-14', 'openai')
    remove_proj: bool = True                             # Remove projection for Clip ViT models

    # Settings only for Timm models (the OpenCLIP branch takes size / statistics from the model)
    img_size: Tuple[int, int] = (224, 224)               # Image size for Timm models
    mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)   # Mean of ImageNet
    std: Tuple[float, float, float] = (0.229, 0.224, 0.225)    # Std of ImageNet

    # Eval
    batch_size: int = 64                                 # Batch size for evaluation
    normalize_features: bool = True                      # L2 normalize features during eval

    # Split for Eval
    fold: int = -1                                       # -1 for given test split | int >=0 for custom folds

    # Checkpoints: tuple of str for ensemble (checkpoint1, checkpoint2, ...) or a single path str
    checkpoints: Union[Tuple[str, ...], str] = (
        "./model_nsva/ViT-L-14_openai/fold0_seed_2024/weights_e4.pth",
        "./model_nsva/ViT-L-14_openai/fold1_seed_2024/weights_e4.pth"
    )

    # Dataset
    data_dir: str = "./data/nsva"
    challenge_csv: str = "./data/nsva/challenge_df.csv"

    # Show progress bar
    verbose: bool = True

    # Set num_workers to 0 if OS is Windows
    num_workers: int = 0 if os.name == 'nt' else 8

    # Use GPU if available (evaluated once, when the class is defined)
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Build the configuration with its default values
config = Configuration()

#----------------------------------------------------------------------------------------------------------------------#
# Model                                                                                                                #
#----------------------------------------------------------------------------------------------------------------------#

print("\nModel: {}".format(config.model))

# A (name, pretrained-tag) tuple selects OpenCLIP; anything else is treated as a timm model name.
if not isinstance(config.model, tuple):
    model = TimmModel(config.model, pretrained=True)

    # timm preprocessing is taken verbatim from the configuration
    img_size, mean, std = config.img_size, config.mean, config.std
else:
    model = OpenClipModel(config.model[0], config.model[1], remove_proj=config.remove_proj)

    # CLIP models dictate their own input resolution and normalisation statistics
    img_size = model.get_image_size()
    mean = (0.48145466, 0.4578275, 0.40821073)
    std = (0.26862954, 0.26130258, 0.27577711)


# Challenge-set distance matrices of every evaluated checkpoint (plain and re-ranked);
# they are averaged into an ensemble after all checkpoints have been processed.
dist_matrix_list = []
dist_matrix_rerank_list = []

# A single checkpoint path is wrapped so it can be iterated like a collection.
checkpoints = (
    config.checkpoints
    if isinstance(config.checkpoints, (list, tuple))
    else [config.checkpoints]
)

#------------------------------------------------------------------------------------------------------------------#
# Data: transforms, dataframes, datasets and loaders do not depend on the checkpoint, so they are built once here  #
# instead of once per checkpoint inside the loop.                                                                  #
#------------------------------------------------------------------------------------------------------------------#

# Transforms (only the validation transforms are needed for inference)
val_transforms, _ = get_transforms(img_size, mean, std)

# Dataframes
train_df_path = os.path.join(config.data_dir, "train_df.csv")
challenge_df_path = config.challenge_csv

# Raise instead of calling exit(): inside a Jupyter/IPython kernel `exit` is IPython's exit
# helper, which does not raise, so it does not reliably stop the running cell.
if not os.path.exists(train_df_path):
    raise FileNotFoundError(f"Training DataFrame not found at {train_df_path}")
if not os.path.exists(challenge_df_path):
    raise FileNotFoundError(f"Challenge DataFrame not found at {challenge_df_path}")

df_train = pd.read_csv(train_df_path)
df_challenge = pd.read_csv(challenge_df_path)

if config.fold == -1:
    # Use given test split
    df_test = df_train[df_train["split"] == "test"]
else:
    # Use custom folds
    df_test = df_train[df_train["fold"] == config.fold]

# Validation (test split)
test_dataset = TestDataset(img_path=config.data_dir,
                           df=df_test,
                           image_transforms=val_transforms)

test_loader = DataLoader(test_dataset,
                         batch_size=config.batch_size,
                         num_workers=config.num_workers,
                         shuffle=False,
                         pin_memory=True)

# Challenge
challenge_dataset = ChallengeDataset(df=df_challenge,
                                     image_transforms=val_transforms)

challenge_loader = DataLoader(challenge_dataset,
                              batch_size=config.batch_size,
                              num_workers=config.num_workers,
                              shuffle=False,
                              pin_memory=True)

print("\nImage Size:", img_size)
print("Mean: {}".format(mean))
print("Std:  {}".format(std))


def _score_and_save(dist_matrix, dataset, save_path, header, label):
    """Score one distance matrix with compute_scores, print its mAP and write it to CSV.

    Args:
        dist_matrix: Query x gallery distance matrix from compute_dist_matrix.
        dataset: Dataset providing ``query`` and ``gallery``.
        save_path: CSV destination passed to write_mat_csv.
        header: Line printed before scoring.
        label: Prefix of the printed mAP line ("<label>: 0.1234").

    Returns:
        The object returned by compute_scores (read here through its 'mAP' key).
    """
    print(header)
    scores = compute_scores(dist_matrix,
                            dataset.query,
                            dataset.gallery,
                            cmc_scores=True)
    print(f"{label}: {scores['mAP']:.4f}")
    print("Writing distance matrix:", save_path)
    write_mat_csv(save_path,
                  dist_matrix,
                  dataset.query,
                  dataset.gallery)
    return scores


for checkpoint in checkpoints:

    print_line(name=checkpoint, length=80)

    # Load pretrained checkpoint (missing files are reported and skipped, not fatal)
    if not os.path.exists(checkpoint):
        print(f"Checkpoint '{checkpoint}' does not exist. Skipping.")
        continue
    model_state_dict = torch.load(checkpoint, map_location=config.device)
    model.load_state_dict(model_state_dict, strict=True)

    # Model to device
    model = model.to(config.device)

    # All distance matrices of this checkpoint are written next to its weights file
    checkpoint_dir = os.path.dirname(checkpoint)

    #--------------------------------------------------------------------------------------------------------------#
    # Test                                                                                                         #
    #--------------------------------------------------------------------------------------------------------------#
    print_line(name="Test Evaluation", length=80)

    features_dict = predict(model,
                            dataloader=test_loader,
                            device=config.device,
                            normalize_features=config.normalize_features,
                            verbose=config.verbose)

    dist_matrix_test, dist_matrix_test_rerank = compute_dist_matrix(features_dict,
                                                                    test_dataset.query,
                                                                    test_dataset.gallery,
                                                                    rerank=True)

    save_path_test = os.path.join(checkpoint_dir, "test_dmat.csv")
    mAP_test = _score_and_save(dist_matrix_test, test_dataset, save_path_test,
                               "\nWithout re-ranking:", "Test mAP")

    save_path_test_rerank = os.path.join(checkpoint_dir, "test_dmat_rerank.csv")
    mAP_test_rerank = _score_and_save(dist_matrix_test_rerank, test_dataset, save_path_test_rerank,
                                      "\nWith re-ranking:", "Test mAP with re-ranking")

    #--------------------------------------------------------------------------------------------------------------#
    # Challenge                                                                                                    #
    #--------------------------------------------------------------------------------------------------------------#
    print_line(name="Challenge Evaluation", length=80)

    features_dict_challenge = predict(model,
                                      dataloader=challenge_loader,
                                      device=config.device,
                                      normalize_features=config.normalize_features,
                                      verbose=config.verbose)

    dist_matrix_challenge, dist_matrix_challenge_rerank = compute_dist_matrix(features_dict_challenge,
                                                                              challenge_dataset.query,
                                                                              challenge_dataset.gallery,
                                                                              rerank=True)

    # Keep non-empty challenge matrices for the ensemble after the loop
    if dist_matrix_challenge.size != 0:
        dist_matrix_list.append(dist_matrix_challenge)
    if dist_matrix_challenge_rerank.size != 0:
        dist_matrix_rerank_list.append(dist_matrix_challenge_rerank)

    save_path_challenge = os.path.join(checkpoint_dir, "challenge_dmat.csv")
    mAP_challenge = _score_and_save(dist_matrix_challenge, challenge_dataset, save_path_challenge,
                                    "\nChallenge Set without re-ranking:", "Challenge mAP")

    save_path_challenge_rerank = os.path.join(checkpoint_dir, "challenge_dmat_rerank.csv")
    mAP_challenge_rerank = _score_and_save(dist_matrix_challenge_rerank, challenge_dataset,
                                           save_path_challenge_rerank,
                                           "\nChallenge Set with re-ranking:",
                                           "Challenge mAP with re-ranking")
    

#----------------------------------------------------------------------------------------------------------------------#  
# Ensemble                                                                                                             #
#----------------------------------------------------------------------------------------------------------------------#
if len(dist_matrix_list) > 1:

    print_line(name="Ensemble", length=80)

    # Write the ensemble next to the first configured checkpoint file that exists on disk:
    # a missing checkpoint's directory may not exist. Fall back to the first entry.
    ensemble_dir = os.path.dirname(
        next((ckpt for ckpt in config.checkpoints if os.path.exists(ckpt)), config.checkpoints[0])
    )

    # Without re-ranking: element-wise mean of the collected challenge distance matrices
    dist_matrix_ensemble = np.stack(dist_matrix_list, axis=0).mean(0)
    save_path_ensemble = os.path.join(ensemble_dir, "challenge_dmat_ensemble.csv")
    print("Writing distance matrix:", save_path_ensemble)
    write_mat_csv(save_path_ensemble,
                  dist_matrix_ensemble,
                  challenge_dataset.query,
                  challenge_dataset.gallery)

    # With re-ranking: this list is filled independently of the one above and can be
    # empty, and np.stack raises on an empty sequence, so guard it separately.
    if dist_matrix_rerank_list:
        dist_matrix_rerank_ensemble = np.stack(dist_matrix_rerank_list, axis=0).mean(0)
        save_path_rerank_ensemble = os.path.join(ensemble_dir, "challenge_dmat_rerank_ensemble.csv")
        print("Writing distance matrix:", save_path_rerank_ensemble)
        write_mat_csv(save_path_rerank_ensemble,
                      dist_matrix_rerank_ensemble,
                      challenge_dataset.query,
                      challenge_dataset.gallery)
    else:
        print("No re-ranked distance matrices collected; skipping re-ranked ensemble.")

In [20]:
# embedding generator

import json
import torch
import clip
import numpy as np
import os
from collections import defaultdict
import re

# Synonym -> canonical color. Colors that are not listed (including the canonical
# names themselves) pass through unchanged apart from lower-casing / trimming.
_COLOR_SYNONYMS = {
    'dark blue': 'blue',
    'light blue': 'blue',
    'navy': 'blue',
    'dark red': 'red',
    'maroon': 'red',
    'burgundy': 'red',
    'forest green': 'green',
    'lime green': 'green',
    'olive': 'green',
    'dark purple': 'purple',
    'lavender': 'purple',
    'gold': 'yellow',
    'beige': 'tan',
    'grey': 'gray',
    'charcoal': 'gray',
    'silver': 'gray',
}


def simplify_color(color):
    """
    Simplify a color name to a standard set.

    The name is trimmed and lower-cased, then mapped through ``_COLOR_SYNONYMS``
    (e.g. "Navy" -> "blue", "grey" -> "gray"). Unknown colors are returned
    trimmed and lower-cased. The function is idempotent: its outputs map to
    themselves.

    Args:
        color: Color name as found in the game data.

    Returns:
        The canonical lower-case color name.
    """
    # Trim first so " navy" / "Navy\n" still hit the synonym table (and don't leak
    # whitespace into embedding file names downstream).
    key = color.strip().lower()
    return _COLOR_SYNONYMS.get(key, key)

def number_to_words(number):
    """
    Convert an integer in 0..99 to English words.

    Examples: 0 -> "zero", 13 -> "thirteen", 20 -> "twenty", 42 -> "forty-two".
    Values outside 0..99 (negative or >= 100) are returned as their decimal string.

    Args:
        number: The integer to convert.

    Returns:
        The word representation (or ``str(number)`` outside 0..99).
    """
    units = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
    teens = ["ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen",
             "sixteen", "seventeen", "eighteen", "nineteen"]
    tens = ["", "", "twenty", "thirty", "forty", "fifty",
            "sixty", "seventy", "eighty", "ninety"]

    if number == 0:
        # units[0] is "" so that e.g. 20 renders without a trailing unit; zero itself
        # needs its own word, otherwise jersey "0" / "00" would yield empty text.
        return "zero"
    if 0 < number < 10:
        return units[number]
    if 10 <= number < 20:
        return teens[number - 10]
    if 20 <= number < 100:
        ten, unit = divmod(number, 10)
        return tens[ten] + ("-" + units[unit] if unit else "")
    return str(number)

class GamePlayerAttributeGenerator:
    """
    Generates a game-player attribute dictionary and creates CLIP text embeddings for
    jersey number, jersey color, and ethnicity (zero-shot prompts).

    Outputs written under ``output_dir``:
      - ``embeddings/<attribute>_<value>.npy``: one text embedding per distinct value
      - ``game_player_attribute.json``: per-game teams, colors and player attributes
      - ``embedding_info.json``: attribute -> value -> relative embedding path
    """

    def __init__(self, clip_model_name="ViT-L/14", output_dir="output_embeds", use_words_for_numbers=False):
        """
        Load the CLIP model and prepare the output folders.

        Args:
            clip_model_name: Model name passed to ``clip.load``.
            output_dir: Folder for the JSON outputs; embeddings go to
                ``output_dir/embeddings`` (created if missing).
            use_words_for_numbers: Store/embed jersey numbers as words
                ("twenty-three") instead of digits ("23").
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.clip_model, self.preprocess = clip.load(clip_model_name, device=self.device)
        self.output_dir = output_dir
        self.embeddings_dir = os.path.join(output_dir, "embeddings")
        os.makedirs(self.embeddings_dir, exist_ok=True)
        self.use_words_for_numbers = use_words_for_numbers
        self.game_player_attribute = {}
        # Distinct values seen per attribute; each one gets exactly one embedding file
        self.unique_attributes = {
            "jersey_number": set(),
            "jersey_color": set(),
            "ethnicity": set()
        }

    @staticmethod
    def _embedding_filename(attribute, value):
        """File name of the embedding for one attribute value (spaces become underscores)."""
        return f"{attribute}_{value.replace(' ', '_')}.npy"

    def get_text_embedding(self, text):
        """
        Encode a single prompt with CLIP's text encoder.

        Returns:
            The embedding of ``text`` as a numpy array (row 0 of the batch-of-one output).
        """
        with torch.no_grad():
            text_tokens = clip.tokenize([text]).to(self.device)
            text_embedding = self.clip_model.encode_text(text_tokens)
        return text_embedding.cpu().numpy()[0]

    def build_game_player_attribute(self, player_data, game_data):
        """
        Populate ``self.game_player_attribute`` and ``self.unique_attributes``.

        Args:
            player_data: {player_name: {'team_name', 'jersey_number', 'ethnicity', ...}}
            game_data: {game_id: {'teamA', 'teamB', 'colorA', 'colorB', ...}}

        Raises:
            ValueError: if a jersey number is not made of digits (``int()`` fails).
        """
        for game_id, game_info in game_data.items():
            teamA = game_info['teamA']
            teamB = game_info['teamB']
            colorA = simplify_color(game_info['colorA'])
            colorB = simplify_color(game_info['colorB'])

            players = {}
            self.game_player_attribute[game_id] = {
                "teamA": teamA,
                "teamB": teamB,
                "colorA": colorA,
                "colorB": colorB,
                "players": players
            }

            # Players are assigned to a game purely by their team name
            for player_name, player_info in player_data.items():
                team_name = player_info['team_name']
                if team_name == teamA:
                    jersey_color = colorA  # already simplified above
                elif team_name == teamB:
                    jersey_color = colorB
                else:
                    continue  # Player not in this game

                # The JSON may store the number as an int; normalise to a trimmed string so
                # the leading-zero regex and int() conversion below work either way.
                jersey_number = str(player_info['jersey_number']).strip()
                ethnicity = player_info['ethnicity'].lower()

                if re.match(r'^0\d+$', jersey_number):
                    # Leading-zero numbers such as "00" are kept digit by digit ("0 0")
                    jersey_number_digits = ' '.join(jersey_number)
                    jersey_number_words = ' '.join(number_to_words(int(digit)) for digit in jersey_number)
                else:
                    jersey_number_digits = jersey_number
                    jersey_number_words = number_to_words(int(jersey_number))

                jersey_number_repr = jersey_number_words if self.use_words_for_numbers else jersey_number_digits

                players[player_name] = {
                    "jersey_color": jersey_color,
                    "jersey_number": jersey_number_repr,
                    "ethnicity": ethnicity
                }

                # Collect unique attribute values
                self.unique_attributes["jersey_color"].add(jersey_color)
                self.unique_attributes["jersey_number"].add(jersey_number_repr)
                self.unique_attributes["ethnicity"].add(ethnicity)

    def generate_embeddings(self):
        """
        Encode every collected attribute value with CLIP and save it as a .npy file
        in ``self.embeddings_dir``.
        """
        # One prompt template per attribute; "{0}" is replaced by the value
        prompt_templates = {
            "jersey_number": "a basketball player with jersey number {0}",
            "jersey_color": "a {0} jersey, color {0}",
            "ethnicity": "a {0} basketball player"
        }

        for attribute, values in self.unique_attributes.items():
            template = prompt_templates[attribute]
            for value in sorted(values):  # sorted only for a deterministic processing order
                embedding = self.get_text_embedding(template.format(value))
                file_path = os.path.join(self.embeddings_dir, self._embedding_filename(attribute, value))
                np.save(file_path, embedding)

    def save_game_player_attribute(self):
        """
        Save the game_player_attribute dictionary to ``output_dir/game_player_attribute.json``.
        """
        save_path = os.path.join(self.output_dir, 'game_player_attribute.json')
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(self.game_player_attribute, f, indent=2)
        print(f"game_player_attribute.json has been saved to {save_path}")

    def save_embeddings_info(self):
        """
        Save ``output_dir/embedding_info.json`` mapping each attribute value to its
        embedding path relative to ``output_dir``. Attributes without values are omitted.
        """
        embedding_info = {
            attribute: {
                value: os.path.join("embeddings", self._embedding_filename(attribute, value))
                for value in sorted(values)
            }
            for attribute, values in self.unique_attributes.items()
            if values
        }

        save_path = os.path.join(self.output_dir, 'embedding_info.json')
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(embedding_info, f, indent=2)
        print(f"embedding_info.json has been saved to {save_path}")

    def run(self, player_data_path, game_data_path):
        """
        Execute the entire pipeline: load data, build dictionary, generate embeddings, and save outputs.

        Args:
            player_data_path: Path to the player summaries JSON.
            game_data_path: Path to the game/team summaries JSON.
        """
        # Explicit UTF-8 so non-ASCII player names decode regardless of the system locale
        with open(player_data_path, 'r', encoding='utf-8') as f:
            player_data = json.load(f)

        with open(game_data_path, 'r', encoding='utf-8') as f:
            game_data = json.load(f)

        self.build_game_player_attribute(player_data, game_data)
        self.generate_embeddings()
        self.save_game_player_attribute()
        self.save_embeddings_info()

        print(f"All embeddings have been saved in the '{self.embeddings_dir}' directory")
        print(f"game_player_attribute.json and embedding_info.json have been saved in the '{self.output_dir}' directory")

def main(player_data_path, game_data_path, output_dir, use_words_for_numbers=False):
    """
    Entry point: build the game/player attribute dictionary and the CLIP text embeddings.

    Args:
        player_data_path: Path to the player summaries JSON.
        game_data_path: Path to the game/team summaries JSON.
        output_dir: Folder receiving the JSON outputs and the ``embeddings`` sub-folder.
        use_words_for_numbers: Embed jersey numbers as words instead of digits.
    """
    # ViT-L/14 is the same CLIP backbone that was used for training.
    GamePlayerAttributeGenerator(
        clip_model_name="ViT-L/14",
        output_dir=output_dir,
        use_words_for_numbers=use_words_for_numbers,
    ).run(player_data_path, game_data_path)

if __name__ == "__main__":
    # Inputs: per-player and per-game JSON summaries
    player_data_path = '/home/minxing/code/NSVA_MOTR/tools/player_summaries_final.json'
    game_data_path = '/home/minxing/code/NSVA_MOTR/tools/game_teams_summaries_1.json'

    # Destination of embeddings/ and the JSON outputs (change as needed)
    output_dir = 'output_embeds'

    # True -> jersey numbers are embedded as words ("twenty-three") instead of digits
    use_words_for_numbers = False

    main(player_data_path=player_data_path,
         game_data_path=game_data_path,
         output_dir=output_dir,
         use_words_for_numbers=use_words_for_numbers)

In [4]:
# data_preprocessing large scale

import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from dataclasses import dataclass, field
import random

@dataclass
class Configuration:
    """Paths, game filter and seed for the large-scale preprocessing cell."""

    # Root folder holding the 'train' and 'test' player-crop trees
    data_dir: str = "/home/minxing/datasets/NSVA_157_zeroshot_crops_new/player_crops"
    # Output folder for train_df.csv
    save_dir: str = "./data/nsva_txt_img"
    # Keep only games whose ID ends with one of these suffixes (a fresh list per instance)
    selected_game_endings: list = field(
        default_factory=['00160', '00060', '01212', '01055'].copy
    )
    # Seed for numpy / random / pandas sampling
    seed: int = 2024
    # Worker count (kept for compatibility; not read in this cell)
    num_workers: int = 8

# Instantiate the settings and seed every RNG used below for reproducibility
config = Configuration()
random.seed(config.seed)
np.random.seed(config.seed)

# Output folder for train_df.csv (no error if it already exists)
os.makedirs(config.save_dir, exist_ok=True)

#----------------------------------------------------------------------------------------------------------------------#
# Function to build the DataFrame with all crops excluding player ID 0                                                  #
#----------------------------------------------------------------------------------------------------------------------#
def build_dataframe(base_path, selected_game_endings):
    """
    Collect every player crop under ``base_path`` into a DataFrame.

    Expected layout: ``base_path/<game_id>-<play_id>/<player_id>/<image file>``.
    Top-level directories without a '-' are reported and skipped; player folder
    '0' is skipped.

    Args:
        base_path: Root folder to scan (e.g. ``<data_dir>/train``).
        selected_game_endings: Keep only games whose ID ends with one of these
            suffixes. ``None`` or an empty list disables the filter (all games kept).

    Returns:
        DataFrame with columns img_id (file name without extension), folder (player
        directory path), player (folder name, str), game, split ('all') and
        img_type ('g'). Rows follow sorted directory/file order, so the result is
        identical on every machine.
    """
    # str.endswith accepts a tuple of suffixes; None means "no filtering"
    endings = tuple(selected_game_endings) if selected_game_endings else None

    img_id, folder, player, game = [], [], [], []

    # os.listdir order is filesystem-dependent; sorting keeps the later seeded
    # shuffle / train_test_split reproducible across machines.
    for game_play_dir in sorted(os.listdir(base_path)):
        game_play_path = os.path.join(base_path, game_play_dir)
        if not os.path.isdir(game_play_path):
            continue

        if '-' not in game_play_dir:
            print(f"Skipping directory '{game_play_dir}' as it does not contain a hyphen '-'.")
            continue  # Skip directories that don't follow the 'game_id-play_id' format

        game_id = game_play_dir.split('-')[0]
        if endings is not None and not game_id.endswith(endings):
            continue  # Skip games not in the selected list

        for player_id in sorted(os.listdir(game_play_path)):
            if player_id == '0':
                continue  # Skip player ID 0

            player_dir_path = os.path.join(game_play_path, player_id)
            if not os.path.isdir(player_dir_path):
                continue

            for image_file in sorted(os.listdir(player_dir_path)):
                if not os.path.isfile(os.path.join(player_dir_path, image_file)):
                    continue
                img_id.append(os.path.splitext(image_file)[0])  # Remove file extension
                folder.append(player_dir_path)  # Full path to the player's directory
                player.append(player_id)
                game.append(game_id)

    n_images = len(img_id)
    return pd.DataFrame({
        "img_id": img_id,
        "folder": folder,
        "player": player,
        "game": game,
        "split": ['all'] * n_images,     # placeholder, reassigned by the caller
        "img_type": ['g'] * n_images,    # default type, adjusted later
    })

#----------------------------------------------------------------------------------------------------------------------#
# Build DataFrame for Selected Games                                                                                   #
#----------------------------------------------------------------------------------------------------------------------#
# Build DataFrame from both train and test directories
df_train = build_dataframe(os.path.join(config.data_dir, 'train'), config.selected_game_endings)
df_test = build_dataframe(os.path.join(config.data_dir, 'test'), config.selected_game_endings)

# Combine the dataframes
df_full = pd.concat([df_train, df_test], ignore_index=True)
print(f"Combined DataFrame contains {len(df_full)} images.")

#----------------------------------------------------------------------------------------------------------------------#
# Assign Splits (80% Train, 20% Test) for Each Player                                                                  #
#----------------------------------------------------------------------------------------------------------------------#

# Convert 'player' column to int for proper grouping
df_full['player'] = df_full['player'].astype(int)

# Shuffle the DataFrame
df_full = df_full.sample(frac=1, random_state=config.seed).reset_index(drop=True)

# Assign splits for each player
df_full['split'] = 'train'  # Initialize all as 'train'

players = df_full['player'].unique()
for player_id in players:
    player_indices = df_full[df_full['player'] == player_id].index
    player_data = df_full.loc[player_indices]
    if len(player_data) >= 5:  # Ensure there are enough samples to split
        train_indices, test_indices = train_test_split(
            player_indices,
            test_size=0.2,
            random_state=config.seed
        )
        df_full.loc[test_indices, 'split'] = 'test'
    else:
        # If not enough samples, assign all to train
        print(f"Player {player_id} has only {len(player_data)} samples. Assigning all to 'train'.")

#----------------------------------------------------------------------------------------------------------------------#
# Map player IDs to integer labels starting from 0                                                                     #
#----------------------------------------------------------------------------------------------------------------------#
# Map players to unique IDs starting from 0
unique_players = df_full['player'].unique()
player_id_map = {player_id: idx for idx, player_id in enumerate(unique_players)}
df_full['player'] = df_full['player'].map(player_id_map).astype(int)

#----------------------------------------------------------------------------------------------------------------------#
# Save DataFrame                                                                                                       #
#----------------------------------------------------------------------------------------------------------------------#
# Save the combined DataFrame
train_df_path_saved = os.path.join(config.save_dir, "train_df.csv")
df_full.to_csv(train_df_path_saved, index=False)

# Print summary
print("\nDataFrame saved:")
print(f" - DataFrame: {train_df_path_saved}")
print("\nSummary:")
print(f"Number of total images: {len(df_full)}")
print(f"Number of training images: {len(df_full[df_full['split'] == 'train'])}")
print(f"Number of testing images: {len(df_full[df_full['split'] == 'test'])}")
print(f"Number of unique players: {df_full['player'].nunique()}")


In [None]:
# zero_shot_classification_revised_with_topk.py

import os
import sys
import json
import torch
import clip
import numpy as np
from PIL import Image
from tqdm import tqdm
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import re
import ast  # For safely evaluating string representations of lists

# ------------------------ Configuration Class ------------------------

class Configuration:
    """
    Settings for the CLIP zero-shot classification cell.

    Plain class (not a dataclass): the values are class attributes, so both the
    class and ``Configuration()`` expose the same defaults.
    """
    # Model configurations: (CLIP model name, pretrained tag); default ViT-L/14
    model: tuple = ('ViT-L/14', 'openai')
    remove_proj: bool = True
    img_size: tuple = (224, 224)
    mean: tuple = (0.485, 0.456, 0.406)   # ImageNet mean
    std: tuple = (0.229, 0.224, 0.225)    # ImageNet std

    # Data settings
    seed: int = 2024
    batch_size_eval: int = 64
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    data_csv: str = "./data/nsva/train_df.csv"
    output_folder: str = 'output_embeds'
    split: str = 'train'  # 'train' or 'test'

config = Configuration()

# ------------------------ Helper Functions ------------------------

def clean_player_name(player_name):
    """
    Normalize a player name for comparison.

    Every character that is not an ASCII letter or a space is dropped, runs of
    whitespace collapse to a single space, and the result is lower-cased.
    """
    letters_only = re.sub(r'[^A-Za-z ]', '', player_name)
    # split()/join() both collapses inner runs and trims the ends.
    return ' '.join(letters_only.split()).lower()

def load_player_summaries(player_summaries_file):
    """
    Load player summaries from player_summaries_final.json.

    The file is decoded as UTF-8 explicitly (JSON's standard encoding) so that
    names with non-ASCII characters load the same on every platform, regardless
    of the locale's default encoding.

    Returns:
        The parsed JSON document (callers in this cell only take its len()).
    """
    with open(player_summaries_file, 'r', encoding='utf-8') as f:
        return json.load(f)

def load_labels(labels_file):
    """
    Load player labels from ZeroShotDataset_labels.txt.

    Each non-blank line has the form ``ID_PlayerName_JerseyNumber``. The last
    '_' field is the jersey number, and the fields between the first and the
    last are joined with spaces to form the player name.

    NOTE: the returned dict is keyed by the 1-based *line number* (as a string),
    not by the leading ID field, which is ignored. Blank and malformed lines
    still consume a line number, so keys stay aligned with file positions.
    Blank lines are skipped silently; other malformed lines are reported.

    Returns:
        dict[str, dict]: line number -> {'player_name': str, 'jersey_number': str}
    """
    labels = {}
    with open(labels_file, 'r', encoding='utf-8') as f:
        for idx, line in enumerate(f, start=1):
            entry = line.strip()
            if not entry:
                # Empty lines are not data; do not report them as malformed.
                continue
            parts = entry.split('_')
            if len(parts) < 3:
                print(f"Skipping malformed line {idx}: {entry}")
                continue
            labels[str(idx)] = {
                'player_name': ' '.join(parts[1:-1]),
                'jersey_number': parts[-1],
            }
    return labels

def create_attribute_classes(embedding_info):
    """
    Build the CLIP text prompts for every class of the three attributes.

    Args:
        embedding_info: mapping whose 'jersey_color', 'jersey_number' and
            'ethnicity' entries are dicts; only their keys are used.

    Returns:
        tuple of three lists of prompts (colors, numbers, ethnicities).
        Colors and ethnicities are in alphabetical order. Jersey numbers are in
        numeric order, with non-digit keys after all numeric ones.
    """
    def numeric_order(number):
        # Digit-only strings sort by value; anything else goes after every number.
        return int(number) if number.isdigit() else float('inf')

    colors = sorted(embedding_info['jersey_color'])
    numbers = sorted(embedding_info['jersey_number'], key=numeric_order)
    ethnicities = sorted(embedding_info['ethnicity'])

    return (
        [f"a {c} jersey, color {c}" for c in colors],
        [f"a basketball player with jersey number {n}" for n in numbers],
        [f"a {e} basketball player" for e in ethnicities],
    )

def encode_class_descriptions(clip_model, device, jersey_color_descriptions, jersey_number_descriptions, ethnicity_descriptions):
    """
    Encode each attribute's class prompts with CLIP's text encoder.

    Args:
        clip_model: model returned by ``clip.load``.
        device: device the token tensors are moved to.
        jersey_color_descriptions, jersey_number_descriptions,
        ethnicity_descriptions: prompt lists (see create_attribute_classes).

    Returns:
        dict[str, dict[str, numpy.ndarray]]: attribute name ->
        {prompt: L2-normalized embedding}, preserving the prompt order given.
        An attribute with no prompts maps to an empty dict.
    """
    descriptions_by_attr = {
        'jersey_color': jersey_color_descriptions,
        'jersey_number': jersey_number_descriptions,
        'ethnicity': ethnicity_descriptions,
    }

    class_embeddings = {}
    with torch.no_grad():
        for attr, descriptions in descriptions_by_attr.items():
            if not descriptions:
                # Nothing to tokenize; skip the text encoder for an empty batch.
                class_embeddings[attr] = {}
                continue
            tokens = clip.tokenize(descriptions).to(device)
            embeddings = clip_model.encode_text(tokens)
            # Unit-normalize so dot products with image embeddings are cosine similarities.
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
            class_embeddings[attr] = {
                desc: emb.cpu().numpy() for desc, emb in zip(descriptions, embeddings)
            }

    return class_embeddings

# Regexes that invert the prompt templates built by create_attribute_classes.
# The named group captures the whole class value, including multi-word values
# such as "light blue" that a split()-based parse would truncate.
_DESCRIPTION_VALUE_PATTERNS = {
    'jersey_color': re.compile(r'^a .+ jersey, color (?P<value>.+)$'),
    'jersey_number': re.compile(r'^a basketball player with jersey number (?P<value>.+)$'),
    'ethnicity': re.compile(r'^a (?P<value>.+) basketball player$'),
}

# Number of ranked predictions kept per attribute; attributes not listed keep 1.
_ATTRIBUTE_TOP_K = {'jersey_color': 2, 'jersey_number': 3}


def _class_value_from_description(attr, description):
    """Return the lower-cased class value embedded in a prompt, or 'unknown' if it does not match."""
    pattern = _DESCRIPTION_VALUE_PATTERNS.get(attr)
    match = pattern.match(description) if pattern is not None else None
    if match is None:
        return 'unknown'
    return match.group('value').strip().lower()


def zero_shot_classification(image_path, clip_model, preprocess, device, class_embeddings):
    """
    Perform zero-shot classification on a single image for all attributes.

    Args:
        image_path: path of the image to classify.
        clip_model, preprocess: the pair returned by ``clip.load``.
        device: device the preprocessed image tensor is moved to.
        class_embeddings: attribute -> {prompt: normalized text embedding}, as
            built by encode_class_descriptions.

    Returns:
        dict with 'jersey_color_top1' (str), 'jersey_color_top2' (list),
        'jersey_number_top1' (str), 'jersey_number_top3' (list) and
        'ethnicity' (str). Values are lower-cased class values, and 'unknown'
        is used when nothing can be predicted. Any further attribute in
        class_embeddings is reported under its own name. If the image cannot be
        opened, the error is printed and every field is 'unknown'.
    """
    try:
        # Context manager closes the underlying file handle promptly.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
    except Exception as e:
        print(f"Error opening image {image_path}: {e}")
        return {
            'jersey_color_top1': 'unknown',
            'jersey_color_top2': ['unknown'],
            'jersey_number_top1': 'unknown',
            'jersey_number_top3': ['unknown'],
            'ethnicity': 'unknown'
        }

    image_input = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        image_embedding = clip_model.encode_image(image_input)
        image_embedding /= image_embedding.norm(dim=-1, keepdim=True)
        image_embedding = image_embedding.cpu().numpy()

    predictions = {}
    for attr, classes in class_embeddings.items():
        top_k = _ATTRIBUTE_TOP_K.get(attr, 1)
        descriptions = list(classes)

        values = []
        if descriptions:
            class_embeds = np.array(list(classes.values()))
            # Both sides are unit-normalized, so this is cosine similarity per class.
            similarities = np.dot(class_embeds, image_embedding.T).flatten()
            top_indices = np.argsort(-similarities)[:top_k]
            values = [_class_value_from_description(attr, descriptions[i]) for i in top_indices]

        top1 = values[0] if values else 'unknown'
        if attr in _ATTRIBUTE_TOP_K:
            predictions[f'{attr}_top1'] = top1
            predictions[f'{attr}_top{top_k}'] = values
        else:
            predictions[attr] = top1

    return predictions

def _normalize_labels(values):
    """Return a pandas Series of labels as lower-cased, whitespace-stripped strings."""
    return values.astype(str).str.lower().str.strip()


def _as_label_list(value):
    """
    Coerce one top-K prediction cell into a list of normalized label strings.

    A cell holds a Python list when the DataFrame was built in memory, and the
    list's string form (e.g. "['red', 'blue']") after a CSV round-trip.
    Unparseable strings and non-list values (e.g. NaN) become an empty list,
    which counts as a miss.
    """
    if isinstance(value, str):
        try:
            value = ast.literal_eval(value)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            return []
        if not isinstance(value, (list, tuple)):
            # A bare literal such as "23" parses to a scalar; treat it as a one-item list.
            value = [value]
    elif not isinstance(value, (list, tuple)):
        return []
    return [str(item).lower().strip() for item in value]


def evaluate_predictions(predictions_df):
    """
    Evaluate the classification predictions and return metrics.

    Every attribute's dict always contains 'top1_accuracy', 'top1_precision',
    'top1_recall' and 'top1_f1_score'. Precision, recall and F1 are
    support-weighted, with zero_division=0.

    jersey_color also gets 'top2_accuracy', and jersey_number gets
    'top3_accuracy', whenever their top-K column is present.

    Ground truth and predictions are both lower-cased and stripped before
    comparison. A metric that cannot be computed stays 0.0 and the error is
    printed. ``predictions_df`` is not modified.

    Args:
        predictions_df: DataFrame with 'ground_truth_<attr>' columns plus
            'predicted_jersey_color_top1/_top2',
            'predicted_jersey_number_top1/_top3' and 'predicted_ethnicity'.

    Returns:
        dict[str, dict[str, float]]: attribute -> metric name -> value.
    """
    # attribute -> (top-1 prediction column, K); a top-K column exists only when K > 1
    attribute_specs = {
        'jersey_color': ('predicted_jersey_color_top1', 2),
        'jersey_number': ('predicted_jersey_number_top1', 3),
        'ethnicity': ('predicted_ethnicity', 1),
    }

    metrics = {}
    for attr, (top1_column, top_k) in attribute_specs.items():
        top_k_column = f'predicted_{attr}_top{top_k}' if top_k > 1 else None
        top_k_key = f'top{top_k}_accuracy'

        # Pre-fill every key the caller prints, so a failure below can never
        # leave a partially populated dict behind.
        attr_metrics = {
            'top1_accuracy': 0.0,
            'top1_precision': 0.0,
            'top1_recall': 0.0,
            'top1_f1_score': 0.0,
        }
        metrics[attr] = attr_metrics

        try:
            ground_truth = _normalize_labels(predictions_df[f'ground_truth_{attr}'])
            predicted_top1 = _normalize_labels(predictions_df[top1_column])
        except KeyError as e:
            print(f"Error evaluating attribute '{attr}': {e}")
            if top_k > 1:
                attr_metrics[top_k_key] = 0.0
            continue

        # Top-1 metrics. This step is separate from the top-K step so that a
        # failure in one does not discard the other's results.
        try:
            acc_top1 = accuracy_score(ground_truth, predicted_top1)
            precision_top1, recall_top1, f1_top1, _ = precision_recall_fscore_support(
                ground_truth, predicted_top1, average='weighted', zero_division=0
            )
            attr_metrics.update({
                'top1_accuracy': acc_top1,
                'top1_precision': precision_top1,
                'top1_recall': recall_top1,
                'top1_f1_score': f1_top1,
            })
        except Exception as e:
            print(f"Unexpected error evaluating attribute '{attr}': {e}")

        # Top-K accuracy: is the (normalized) ground truth among the K candidates?
        if top_k_column is None or top_k_column not in predictions_df.columns:
            continue
        try:
            top_k_lists = predictions_df[top_k_column].map(_as_label_list)
            hits = [truth in candidates for truth, candidates in zip(ground_truth, top_k_lists)]
            attr_metrics[top_k_key] = sum(hits) / len(hits) if hits else 0.0
        except Exception as e:
            print(f"Unexpected error evaluating attribute '{attr}': {e}")
            attr_metrics[top_k_key] = 0.0

    return metrics

# ------------------------ Main Classification Task ------------------------

# Column order of the per-image predictions table written by classification_task.
_PREDICTION_COLUMNS = [
    'image_path',
    'player_id',
    'player_name',
    'ground_truth_jersey_color',
    'ground_truth_jersey_number',
    'ground_truth_ethnicity',
    'predicted_jersey_color_top1',
    'predicted_jersey_color_top2',
    'predicted_jersey_number_top1',
    'predicted_jersey_number_top3',
    'predicted_ethnicity',
]


def _resolve_ground_truth(folder, labels, game_player_attribute):
    """
    Look up the labelled player and ground-truth attributes for one image folder.

    The folder is expected to end in ``.../<game>-<play>/<player_id>``. The
    player id indexes ``labels``, and the part of ``<game>-<play>`` before the
    first '-' indexes ``game_player_attribute``.

    Returns:
        (player_id, player_name, ground_truth), where ground_truth maps
        'jersey_color' / 'jersey_number' / 'ethnicity' to lower-cased strings
        ('unknown' when missing). Returns None if the image must be skipped, in
        which case the reason is printed.
    """
    folder_parts = folder.strip('/').split('/')
    if len(folder_parts) < 3:
        print(f"Unexpected folder structure: {folder}. Skipping.")
        return None

    player_id = folder_parts[-1]
    game_id = folder_parts[-2].split('-')[0]

    player_info = labels.get(player_id)
    if not player_info:
        print(f"Player ID '{player_id}' not found in labels. Skipping.")
        return None
    player_name = player_info['player_name']

    game_info = game_player_attribute.get(game_id)
    if not game_info:
        print(f"Game ID '{game_id}' not found in game_player_attribute. Skipping.")
        return None

    player_attributes = game_info['players'].get(player_name)
    if not player_attributes:
        print(f"Player '{player_name}' not found in game '{game_id}'. Skipping.")
        return None

    # str() guards against non-string JSON values (e.g. an integer jersey
    # number), which have no .lower().
    ground_truth = {
        attr: str(player_attributes.get(attr, 'unknown')).lower()
        for attr in ('jersey_color', 'jersey_number', 'ethnicity')
    }
    return player_id, player_name, ground_truth


def _print_metrics(metrics):
    """Print evaluate_predictions' metrics; a missing metric key prints as nan instead of raising."""
    nan = float('nan')
    print("\nZero-Shot Classification Metrics:")
    for attr, metric in metrics.items():
        print(f"\nAttribute: {attr}")
        if not metric:
            print(" - No valid predictions to evaluate.")
            continue
        print(f" - Top-1 Accuracy: {metric.get('top1_accuracy', nan):.4f}")
        for k in (2, 3):
            if f'top{k}_accuracy' in metric:
                print(f" - Top-{k} Accuracy: {metric[f'top{k}_accuracy']:.4f}")
        print(f" - Precision: {metric.get('top1_precision', nan):.4f}")
        print(f" - Recall: {metric.get('top1_recall', nan):.4f}")
        print(f" - F1 Score: {metric.get('top1_f1_score', nan):.4f}")


def classification_task(
    labels_file,
    player_summaries_file,
    game_player_attribute_path,
    embedding_info_path,
    data_csv_path,
    output_folder,
    split='test'
):
    """
    Perform zero-shot attribute classification on one data split.

    Progress output is written to <output_folder>/ZS_log_full_<split>.txt.
    Per-image predictions are saved to
    <output_folder>/ZS_predictions_full_<split>.csv. stdout is restored and
    the log is closed even if a step raises.

    Args:
        labels_file: ZeroShotDataset_labels.txt (see load_labels).
        player_summaries_file: JSON of player summaries; loaded and counted only.
        game_player_attribute_path: JSON mapping game id -> game info with a
            'players' dict of per-player attributes.
        embedding_info_path: JSON whose attribute entries define the candidate classes.
        data_csv_path: CSV with 'img_id', 'folder' and 'split' columns.
        output_folder: directory for outputs; must already exist (main() creates it).
        split: 'train' or 'test'. Any other value prints an error and returns.
    """
    from contextlib import redirect_stdout

    # Validate the cheap argument before loading files or the CLIP model.
    if split not in ['train', 'test']:
        print(f"Invalid split '{split}'. Choose from 'train' or 'test'.")
        return

    log_file_path = os.path.join(output_folder, f'ZS_log_full_{split}.txt')
    # `with` guarantees stdout is restored and the log closed on any exception.
    with open(log_file_path, 'w', encoding='utf-8') as log_file, redirect_stdout(log_file):
        # ------------------------ Data Loading ------------------------
        print("Loading labels...")
        labels = load_labels(labels_file)
        print(f"Total labels loaded: {len(labels)}")

        print("Loading player summaries...")
        player_summaries = load_player_summaries(player_summaries_file)
        print(f"Total player summaries loaded: {len(player_summaries)}")

        print("Loading game_player_attribute...")
        with open(game_player_attribute_path, 'r', encoding='utf-8') as f:
            game_player_attribute = json.load(f)
        print(f"Total games loaded: {len(game_player_attribute)}")

        print("Loading embedding_info...")
        with open(embedding_info_path, 'r', encoding='utf-8') as f:
            embedding_info = json.load(f)
        print(f"Total embedding mappings loaded: {len(embedding_info)}")

        # ------------------------ Attribute Class Creation ------------------------
        print("Creating attribute classes...")
        jersey_color_desc, jersey_number_desc, ethnicity_desc = create_attribute_classes(embedding_info)
        print(f"Jersey Colors Descriptions: {jersey_color_desc}")
        print(f"Jersey Numbers Descriptions: {jersey_number_desc}")
        print(f"Ethnicities Descriptions: {ethnicity_desc}")

        # ------------------------ CLIP Model Loading ------------------------
        print("Loading CLIP model...")
        device = config.device
        clip_model_name = config.model[0]
        clip_model, preprocess = clip.load(clip_model_name, device=device)
        print(f"CLIP model '{clip_model_name}' loaded on {device}")

        print("Encoding class descriptions...")
        class_embeddings = encode_class_descriptions(
            clip_model, device, jersey_color_desc, jersey_number_desc, ethnicity_desc
        )
        print("Class descriptions encoded successfully.")

        # ------------------------ DataFrame Loading ------------------------
        print(f"Loading data from {data_csv_path}...")
        # Read img_id as text so numeric-looking ids keep leading zeros and can
        # be joined with the '.jpg' suffix.
        df_full = pd.read_csv(data_csv_path, dtype={'img_id': str})
        df_split = df_full[df_full['split'] == split].reset_index(drop=True)
        print(f"Total images in '{split}' split: {len(df_split)}")

        # ------------------------ Zero-Shot Classification ------------------------
        predictions_list = []
        skipped_images = 0

        print("Starting zero-shot classification...")
        for _, row in tqdm(df_split.iterrows(), total=len(df_split), desc="Images"):
            folder = str(row['folder'])
            resolved = _resolve_ground_truth(folder, labels, game_player_attribute)
            if resolved is None:
                skipped_images += 1
                continue
            player_id, player_name, ground_truth = resolved

            img_path = os.path.join(folder, f"{row['img_id']}.jpg")
            predictions = zero_shot_classification(
                img_path, clip_model, preprocess, device, class_embeddings
            )

            predictions_list.append({
                'image_path': img_path,
                'player_id': player_id,
                'player_name': player_name,
                'ground_truth_jersey_color': ground_truth['jersey_color'],
                'ground_truth_jersey_number': ground_truth['jersey_number'],
                'ground_truth_ethnicity': ground_truth['ethnicity'],
                'predicted_jersey_color_top1': predictions.get('jersey_color_top1', 'unknown'),
                'predicted_jersey_color_top2': predictions.get('jersey_color_top2', ['unknown']),
                'predicted_jersey_number_top1': predictions.get('jersey_number_top1', 'unknown'),
                'predicted_jersey_number_top3': predictions.get('jersey_number_top3', ['unknown']),
                'predicted_ethnicity': predictions.get('ethnicity', 'unknown'),
            })

        print("\nZero-shot classification completed.")
        print(f"Total images processed: {len(predictions_list)}")
        print(f"Total images skipped: {skipped_images}")

        # Explicit columns keep the schema even when no image was classified.
        predictions_df = pd.DataFrame(predictions_list, columns=_PREDICTION_COLUMNS)

        for attr in ['ground_truth_ethnicity', 'predicted_ethnicity']:
            predictions_df[attr] = predictions_df[attr].astype(str).str.lower().str.strip()

        # ------------------------ Evaluation ------------------------
        if predictions_df.empty:
            print("\nNo images were classified; skipping evaluation.")
        else:
            print("\nEvaluating predictions...")
            metrics = evaluate_predictions(predictions_df)
            _print_metrics(metrics)

        output_predictions_csv = os.path.join(output_folder, f'ZS_predictions_full_{split}.csv')
        predictions_df.to_csv(output_predictions_csv, index=False)
        print(f"\nPredictions have been saved to '{output_predictions_csv}'")

    print(f"Log file has been saved to '{log_file_path}'")

# ------------------------ Main Function ------------------------

def main():
    """
    Run zero-shot attribute classification with this cell's settings.

    The label, summary and attribute file paths are hard-coded below
    (machine-specific). The data CSV, output folder and split come from the
    module-level ``config``.
    """
    # Paths to necessary files
    labels_file = '/home/minxing/datasets/NSVA_157_zero_shot_minxing/ZeroShotDataset_labels.txt'
    player_summaries_file = '/home/minxing/code/NSVA_MOTR/tools/player_summaries_final.json'
    game_player_attribute_path = "/home/minxing/code/clip_reident/output_embeds/game_player_attribute.json"
    embedding_info_path = "/home/minxing/code/clip_reident/output_embeds/embedding_info.json"

    data_csv_path = config.data_csv        # Path to train_df.csv
    output_folder = config.output_folder   # Directory to save outputs
    split = config.split                   # 'train' or 'test'

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(output_folder, exist_ok=True)

    classification_task(
        labels_file=labels_file,
        player_summaries_file=player_summaries_file,
        game_player_attribute_path=game_player_attribute_path,
        embedding_info_path=embedding_info_path,
        data_csv_path=data_csv_path,
        output_folder=output_folder,
        split=split
    )

# ------------------------ Execute the Main Function ------------------------

# In a notebook kernel __name__ is "__main__", so running this cell starts the task.
if __name__ == "__main__":
    main()


Images: 100%|██████████| 15193/15193 [02:22<00:00, 106.31it/s]
