In [None]:
import os
import glob
import xml.etree.ElementTree as ET


import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.ops import RoIAlign
from torchvision import transforms
from ultralytics import YOLO

import random
import numpy as np

ipykernel               6.29.5
ipython                 9.1.0
matplotlib              3.10.1
matplotlib-inline       0.1.6
numpy                   2.2.6
opencv-contrib-python   4.12.0.88
opencv-python           4.10.0
opencv-python-headless  4.10.0
pandas                  2.3.1
pillow                  11.3.0
scikit-learn            1.6.1
torch                   2.5.1
torchinfo               1.8.0
torchvision             0.20.1
typing_extensions       4.12.2
ultralytics             8.3.168
ultralytics-thop        2.0.14

In [2]:
# Load an ultralytics YOLO model from the "yolo11n.pt" weights file.
# NOTE: the name `model` is rebound to a YOLOv11ReID instance in the training
# cell further down, so this object is only reachable until that cell runs.
model = YOLO("yolo11n.pt")
# Alternative loading route via torch.hub, left disabled:
#model2 = torch.hub.load("ultralytics/yolov11","yolo11n.pt")

In [3]:
from torchinfo import summary

In [2]:
from PIL import Image


class ResizePad:
    """Resize a PIL image to fit a fixed canvas, then pad it to the canvas size.

    The image is scaled uniformly (aspect ratio preserved) so that it fits
    inside ``size``, then pasted centred onto a new RGB canvas whose
    remaining area is filled with a constant colour.

    Args:
        size: ``(height, width)`` of the output canvas.
        fill: Single channel value; it is replicated across the three RGB
            channels to form the padding colour.
    """

    def __init__(self,size=(256,128),fill=0):
        # `size` is (height, width); PIL itself works in (width, height).
        self.target_h, self.target_w = size
        self.fill = fill

    def _fit_geometry(self, orig_w, orig_h):
        """Return ``(new_w, new_h, paste_x, paste_y)`` for an image of the given size."""
        scale = min(self.target_w / orig_w, self.target_h / orig_h)
        # int() truncates, so a very elongated image could end up with a
        # zero-length side, which cannot be resized to. Keep at least 1 px.
        new_w = max(1, int(orig_w * scale))
        new_h = max(1, int(orig_h * scale))
        # Centre the resized image on the canvas.
        paste_x = (self.target_w - new_w) // 2
        paste_y = (self.target_h - new_h) // 2
        return new_w, new_h, paste_x, paste_y

    def __call__(self,img):
        """Return a new RGB image of the target size containing ``img``.

        Args:
            img: Image object exposing ``.size`` as ``(width, height)`` and a
                ``.resize`` method (a PIL image).
        """
        orig_w, orig_h = img.size
        new_w, new_h, paste_x, paste_y = self._fit_geometry(orig_w, orig_h)

        resized = img.resize((new_w, new_h), Image.BILINEAR)

        canvas = Image.new("RGB", (self.target_w, self.target_h), (self.fill,) * 3)
        canvas.paste(resized, (paste_x, paste_y))

        return canvas

In [3]:
## dataset Class
from PIL import Image

class FolderBasedTripletDataset(Dataset):
    """Triplet dataset where each sub-folder of ``root_dir`` is one identity.

    A ``.png`` file is used only when a file with the same stem and an
    ``.xml`` extension exists next to it. Identities with fewer than two
    usable images are skipped, because an anchor and a positive must be two
    different images of the same identity.

    Each item is an ``(anchor, positive, negative)`` triple of transformed
    images; positive and negative are drawn at random on every access.

    Args:
        root_dir: Directory whose sub-directories are named by identity.
        transform: Callable applied to each loaded RGB PIL image. Defaults to
            ResizePad((256, 128)) -> ToTensor -> Normalize(0.5, 0.5).
    """

    def __init__(self, root_dir, transform=None):
        self.transform = transform or transforms.Compose([
            ResizePad((256, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])

        self.id_index = {}  # pid -> list of image paths

        # Sorted so the index -> identity mapping does not depend on the
        # filesystem's listing order.
        for pid in sorted(os.listdir(root_dir)):
            folder = os.path.join(root_dir, pid)
            if not os.path.isdir(folder):
                continue

            # glob.escape keeps folder names containing [, ] , * or ? literal.
            pattern = os.path.join(glob.escape(folder), '*.png')
            # Swap only the extension: str.replace('.png', '.xml') would also
            # rewrite any '.png' that appears in a directory name.
            images = [
                path for path in sorted(glob.glob(pattern))
                if os.path.exists(os.path.splitext(path)[0] + '.xml')
            ]
            if len(images) >= 2:
                self.id_index[pid] = images

        self.pids = list(self.id_index.keys())
        assert len(self.pids) > 1, "Need at least 2 different IDs for triplet loss."

    def __len__(self):
        """Number of identities (one triplet is drawn per identity per pass)."""
        return len(self.pids)

    def __getitem__(self, index):
        """Return transformed (anchor, positive, negative) for identity ``index``."""
        anchor_pid = self.pids[index]
        # Two distinct images of the same identity.
        anchor_img, positive_img = random.sample(self.id_index[anchor_pid], 2)
        # Any image from any other identity.
        negative_pid = random.choice([pid for pid in self.pids if pid != anchor_pid])
        negative_img = random.choice(self.id_index[negative_pid])

        return self.load(anchor_img), self.load(positive_img), self.load(negative_img)

    def load(self, path):
        """Open ``path`` as an RGB image and apply the dataset transform."""
        img = Image.open(path).convert('RGB')
        return self.transform(img)


In [4]:
class FlatFolderTripletDataset(Dataset):
    """Triplet dataset for a single flat folder of ``.png`` / ``.xml`` pairs.

    The identity of each image is read from the ``ID`` attribute of the
    ``OBJECT`` element that is a direct child of the root of the ``.xml``
    file sharing the image's stem. Images without such a file or ID are
    ignored, and identities with fewer than two images are dropped because
    they cannot supply both an anchor and a positive.

    Each item is an ``(anchor, positive, negative)`` triple of transformed
    images; positive and negative are drawn at random on every access.

    Args:
        folder: Directory containing the image and annotation files.
        transform: Callable applied to each loaded RGB PIL image. Defaults to
            ResizePad((256, 128)) -> ToTensor -> Normalize(0.5, 0.5).
    """

    def __init__(self, folder, transform=None):
        self.transform = transform or transforms.Compose([
            ResizePad((256, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])

        grouped = {}  # { person_id: [image_path1, image_path2, ...] }

        # Sorted for a reproducible index -> identity mapping; glob.escape
        # keeps special characters in `folder` literal.
        png_files = sorted(glob.glob(os.path.join(glob.escape(folder), '*.png')))

        for img_path in png_files:
            # Swap only the extension; str.replace would also rewrite a
            # '.png' occurring earlier in the path.
            xml_path = os.path.splitext(img_path)[0] + '.xml'
            if not os.path.exists(xml_path):
                continue

            person_id = self.get_id_from_xml(xml_path)
            if person_id:
                grouped.setdefault(person_id, []).append(img_path)

        # random.sample(..., 2) in __getitem__ needs at least two images per
        # identity; keeping single-image identities would raise ValueError.
        self.id_index = {
            pid: paths for pid, paths in grouped.items() if len(paths) >= 2
        }

        self.pids = list(self.id_index.keys())
        assert len(self.pids) > 1, "Need at least 2 different IDs for triplet loss."

    def get_id_from_xml(self, xml_path):
        """Return the ``ID`` attribute of the root's ``OBJECT`` child, or None.

        Parsing failures are reported with a warning line and yield None so
        that one bad annotation does not abort dataset construction.
        """
        try:
            tree = ET.parse(xml_path)
            root = tree.getroot()
            object_elem = root.find('OBJECT')
            if object_elem is not None:
                return object_elem.attrib.get('ID')  # person ID from attribute
        except Exception as e:
            print(f"[WARN] Failed to parse {xml_path}: {e}")
        return None

    def __len__(self):
        """Number of identities (one triplet is drawn per identity per pass)."""
        return len(self.pids)

    def __getitem__(self, index):
        """Return transformed (anchor, positive, negative) for identity ``index``."""
        anchor_pid = self.pids[index]
        # Two distinct images of the same identity.
        anchor_img, positive_img = random.sample(self.id_index[anchor_pid], 2)
        # Any image from any other identity.
        negative_pid = random.choice([pid for pid in self.pids if pid != anchor_pid])
        negative_img = random.choice(self.id_index[negative_pid])

        return self.load(anchor_img), self.load(positive_img), self.load(negative_img)

    def load(self, path):
        """Open ``path`` as an RGB image and apply the dataset transform."""
        img = Image.open(path).convert('RGB')
        return self.transform(img)

```
class YOLOv11ReID(nn.Module):
    def __init__(self, yolo_weights='yolo11n.pt',emb_dim=128):
        super().__init__()
        self.yolo=YOLO(yolo_weights)
        self.encoder = nn.Sequential(*list(self.yolo.model.model.children())[:-2])
        self.pool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(self._get_feat_dim(),emb_dim)
        
    def _get_feat_dim(self):
        x = torch.zeros((1,3,256,128))
        with torch.no_grad():
            f=self.encoder(x)
            return f.shape[1]
        
    def forward(self,x):
        f = self.encoder(x)
        pooled = self.pool(f).flatten(1)
        emb = self.fc(pooled)
        return nn.functional.normalize(emb,dim=1)
```         

In [5]:
class YOLOv11ReID(nn.Module):
    """Embedding network built on the first layers of a loaded YOLO model.

    The first eight modules (indices 0-7) of the loaded model's layer
    sequence are used as a feature extractor. Their output is global-average
    pooled, projected to ``emb_dim`` by a linear layer and L2-normalised.

    Args:
        yolo_weights: Weights file handed to ``YOLO(...)``.
        emb_dim: Size of the output embedding.
    """

    def __init__(self, yolo_weights='yolo11n.pt', emb_dim=128):
        super().__init__()
        yolo_model = YOLO(yolo_weights)

        # Modules 0..7 of the loaded layer sequence, chained in order.
        # NOTE(review): confirm these parameters come back with
        # requires_grad=True; if not, an optimizer over model.parameters()
        # would effectively train only `fc`.
        self.backbone = nn.Sequential(*list(yolo_model.model.model.children())[:8])

        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(self._get_feat_dim(), emb_dim)

    def _get_feat_dim(self):
        """Return the channel count of the backbone output via a dummy pass."""
        x = torch.zeros((1, 3, 256, 128))
        with torch.no_grad():
            f = self.backbone(x)
            return f.shape[1]

    def forward(self, x):
        """Map a batch of images to unit-length embeddings (normalised on dim 1)."""
        x = self.backbone(x)
        # Pool exactly once; the result feeds the projection layer.
        pooled = self.pool(x).flatten(1)
        emb = self.fc(pooled)
        return nn.functional.normalize(emb, dim=1)


In [None]:

# Build the path with os.path.join: the previous literal '..\data\sample_poc'
# relied on the invalid escape sequences '\d' and '\s' and only resolved on
# Windows. On Windows this yields the same string as before.
dataset = FolderBasedTripletDataset(os.path.join('..', 'data', 'sample_poc'))

# NOTE(review): with num_workers > 0 the dataset object is sent to worker
# processes; confirm this works for a class defined inside the notebook on
# the target OS (set num_workers=0 if the workers fail to start).
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)



In [None]:

# dataset = FolderBasedTripletDataset('/content/dataset/dataset1')

# def triplet_collate(batch):
#     anchors, positives, negatives = zip(*batch)
#     return (
#         torch.stack(anchors),
#         torch.stack(positives),
#         torch.stack(negatives)
#     )

# loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4, collate_fn=triplet_collate)


In [None]:
from torch.nn import TripletMarginLoss

# Train on the GPU when one is visible, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = YOLOv11ReID().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
triplet_loss = TripletMarginLoss(margin=0.3)

model.train()
for epoch in range(50):
    total_loss = 0

    for batch in loader:
        # Each batch is an (anchor, positive, negative) triple of image tensors.
        anchor, positive, negative = (t.to(device) for t in batch)

        # Embed the three views with the same network, one pass each.
        emb_anchor = model(anchor)
        emb_positive = model(positive)
        emb_negative = model(negative)
        loss = triplet_loss(emb_anchor, emb_positive, emb_negative)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses for this epoch.
    print(f"Epoch {epoch+1}, Avg Loss: {total_loss/len(loader):.4f}")

In [None]:
# print(model)

In [None]:
# Persist the whole module object (not just its state_dict). Loading such a
# file requires the YOLOv11ReID class to be importable/defined at load time.
save_path = '../saved_models/reid_model_full_v0.1.pth'
# Create the target directory first so a missing folder cannot discard the
# trained model at the very last step.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model, save_path)