In [None]:
# This JupyterLab notebook provides training and evaluation of a ReID model
# that will be used to extract features from images and will work as part of the
# multi-object tracking algorithm from the BoxMOT repository.
# The model is based on the OSNet-AIN architecture (osnet_ain_x1_0) and is trained on a
# custom dataset stored in the Market1501 format.

In [None]:
# REID Training Repo
!pip install git+https://github.com/KaiyangZhou/deep-person-reid.git

In [None]:
# Import and Dataset Setup
import os
import torch
import torchreid
from torch import optim
from torchreid import data, models, utils, engine
from torchreid.data import ImageDataset
import torch.nn as nn

# Seed the random number generators through torchreid's helper so that
# repeated runs start from the same state.
SEED = 42
utils.set_random_seed(SEED)

# Location of the input dataset and the directory that receives the outputs
# of this notebook (created up front if it does not exist yet).
DATA_DIR = '/kaggle/input/reid4-24-04/REID4_24_04'
OUTPUT_DIR = '/kaggle/working/outputFinal5'
os.makedirs(name=OUTPUT_DIR, exist_ok=True)

print("Complete")

In [None]:
# Custom dataset preprocessing

# This section defines a custom dataset class for data stored in the Market1501 format,
# with additional relabeling of the PIDs and CAMIDs

# Define the custom dataset
class customMarketData(ImageDataset):
    """Image ReID dataset stored in a Market1501-style directory layout.

    ``root`` must contain the sub-directories ``bounding_box_train``,
    ``query`` and ``bounding_box_test``. Every image file name starts with
    ``<pid>_``; when the camera ID is not forced, it is read from the second
    ``_``-separated field (e.g. ``c3s1`` -> camera 3).

    Training PIDs are relabelled to contiguous indices starting at 0.
    Query images are all assigned camera 0 and gallery images camera 1.

    Args:
        root: Directory that holds the three sub-directories.
        **kwargs: Forwarded unchanged to ``ImageDataset.__init__``.
    """

    # File extensions (compared case-insensitively) that count as images.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root='', **kwargs):
        self.dataset_dir = root
        train_dir = os.path.join(self.dataset_dir, 'bounding_box_train')
        query_dir = os.path.join(self.dataset_dir, 'query')
        gallery_dir = os.path.join(self.dataset_dir, 'bounding_box_test')

        # CamID forcing logic (query -> 0, gallery -> 1).
        # Section needed for the Market1501 dataset format.
        # Remove the force_camid arguments only if a different camera is
        # actually used on the same object, otherwise keep them; keep them if
        # the "REIDMarketCreatorClass" script was used to create the data.
        # NOTE(review): presumably required because Market1501-style
        # evaluation discards gallery images that share both PID and camera
        # with the query -- confirm against torchreid's evaluation code.
        train = self._process_dir(train_dir, relabel=True)
        query = self._process_dir(query_dir, relabel=False, force_camid=0)
        gallery = self._process_dir(gallery_dir, relabel=False, force_camid=1)

        # Print PID info
        query_pids = sorted({pid for _, pid, _ in query})
        gallery_pids = sorted({pid for _, pid, _ in gallery})
        print(f"[INFO] Query PIDs    ({len(query_pids)}): {query_pids}")
        print(f"[INFO] Gallery PIDs  ({len(gallery_pids)}): {gallery_pids}")

        super().__init__(train, query, gallery, **kwargs)

    @staticmethod
    def _parse_camid(fname):
        """Return the camera ID encoded in the second ``_``-separated field.

        For a token made of one letter followed by digits (``c3s1``, ``c12``)
        the full run of digits is used, so multi-digit camera IDs are
        supported. Any other token falls back to the previous behaviour of
        reading the single character at index 1.

        Raises:
            IndexError: If the file name has no second field, or the field is
                too short.
            ValueError: If the fallback character is not a digit.
        """
        token = fname.split('_')[1]
        if token[:1].isalpha():
            digits = ''
            for ch in token[1:]:
                if ch not in '0123456789':
                    break
                digits += ch
            if digits:
                return int(digits)
        return int(token[1])

    def _process_dir(self, dir_path, relabel=False, force_camid=None):
        """Build the ``(img_path, pid, camid)`` list for one directory.

        Args:
            dir_path: Directory whose image files are listed (non-recursive).
            relabel: If True, map the PIDs found in this directory to
                ``0..N-1`` following their sorted order.
            force_camid: If not None, use this camera ID for every image and
                do not parse the camera token of the file name at all.

        Returns:
            List of ``(img_path, pid, camid)`` tuples sorted by image path.

        Raises:
            ValueError: If a file name does not start with an integer PID.
        """
        img_paths = sorted(
            os.path.join(dir_path, fname)
            for fname in os.listdir(dir_path)
            if fname.lower().endswith(self._IMAGE_EXTENSIONS)
        )

        # Parse every file name exactly once.
        samples = []
        for img_path in img_paths:
            fname = os.path.basename(img_path)
            pid = int(fname.split('_')[0])
            if force_camid is not None:
                # A forced camera ID must not depend on the file name
                # carrying a well-formed camera token.
                camid = force_camid
            else:
                camid = self._parse_camid(fname)
            samples.append((img_path, pid, camid))

        if not relabel:
            return samples

        unique_pids = sorted({pid for _, pid, _ in samples})
        pid2label = {pid: idx for idx, pid in enumerate(unique_pids)}
        return [(path, pid2label[pid], camid) for path, pid, camid in samples]

print("Dataset Preprocessing Class Created")

In [None]:
# Data loader and register
# Register the dataset under a new unique name (placeholder: replace with a real name).
DATASET_NAME = 'Name_of_dataset'

# torchreid refuses to register a name twice, which would make this cell fail
# when it is re-run in a live kernel. Keep the first registration in that case.
try:
    data.register_image_dataset(DATASET_NAME, customMarketData)
except ValueError:
    print(
        f"[INFO] Dataset '{DATASET_NAME}' is already registered; reusing the "
        "existing registration. Restart the kernel if the dataset class was changed."
    )

# Initialize the ImageDataManager using the registered dataset.
# The same dataset serves as training source and evaluation target.
datamanager = data.ImageDataManager(
    root=DATA_DIR,
    sources=DATASET_NAME,
    targets=DATASET_NAME,
    height=256,
    width=128,
    batch_size_train=128,
    train_sampler='RandomIdentitySampler',
    num_instances=8,
    transforms='random_flip+random_crop+color_jitter+random_erase+normalize',
    workers=4
)

print("Data loader complete")

In [None]:
# Model builder
# Builds the ReID model using the OSNet-AIN architecture (osnet_ain_x1_0),
# with a softmax loss function and pretrained weights.

# Query CUDA availability once and reuse the answer everywhere below.
gpu_available = torch.cuda.is_available()

model = models.build_model(
    name='osnet_ain_x1_0',
    num_classes=datamanager.num_train_pids,  # one output class per training identity
    loss='softmax',
    pretrained=True,
    use_gpu=gpu_available,
)
if gpu_available:
    model = model.cuda()

# Optimizer & scheduler
# Adam on all parameters; the learning rate is multiplied by 0.1 every 20 epochs.
optimizer = optim.Adam(model.parameters(), lr=3e-4, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

# Trainer
trainer = engine.ImageSoftmaxEngine(
    datamanager=datamanager,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    use_gpu=gpu_available,
    label_smooth=True,
)

# Training Loop
# Results are written to OUTPUT_DIR.
# NOTE(review): eval_freq/start_eval presumably mean "evaluate every 10 epochs
# starting at epoch 10" -- confirm against the torchreid Engine.run docs.
run_options = dict(
    save_dir=OUTPUT_DIR,
    max_epoch=100,
    eval_freq=10,
    print_freq=50,
    start_eval=10,
    fixbase_epoch=0,
    open_layers=None,
)
trainer.run(**run_options)

In [None]:
# Save model
# Export the trained network to two DISTINCT files inside OUTPUT_DIR:
#   1. a traced TorchScript module
#   2. the plain state_dict (for BoxMOT)
# Writing both to the same path would let the second save overwrite the first.
os.makedirs(OUTPUT_DIR, exist_ok=True)
torchscript_path = os.path.join(OUTPUT_DIR, 'osnet_ain_x1_0_reid.torchscript.pt')
state_dict_path = os.path.join(OUTPUT_DIR, 'osnet_ain_x1_0_reid.pt')

model.eval()

# Build the example input on whichever device the model parameters live on;
# its size matches the height=256 / width=128 used by the data manager.
model_device = next(model.parameters()).device
dummy_input = torch.zeros(1, 3, 256, 128, device=model_device)

# No gradients are needed while recording the trace.
with torch.no_grad():
    traced_model = torch.jit.trace(model, dummy_input)
traced_model.save(torchscript_path)
print(f"TorchScript model saved as {torchscript_path}")

# Save state_dict for BoxMOT
torch.save(model.state_dict(), state_dict_path)
print(f"PyTorch model state_dict saved as {state_dict_path}")

print("Model saved")

In [None]:
# Section to check for the existence of the saved files
# Walk the working directory tree and report every .pt / .pth file found
MODEL_SUFFIXES = ('.pt', '.pth')
for folder, _subfolders, filenames in os.walk('/kaggle/working'):
    model_files = [name for name in filenames if name.endswith(MODEL_SUFFIXES)]
    for name in model_files:
        print("Found:", os.path.join(folder, name))
