<a href="https://colab.research.google.com/github/Karim-Anwar/MasterThesis/blob/main/AttemptFairViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import os
import glob
import random
import math
import numpy as np
import cv2
import torch
from collections import OrderedDict
from torchvision import transforms as T


In [4]:
def xyxy2xywh(x):
    """Convert boxes from corner format to centre format.

    Args:
        x: 2-D ``numpy.ndarray`` or ``torch.Tensor`` whose first four columns
            are ``[x1, y1, x2, y2]``.

    Returns:
        A new array/tensor of the same kind and shape as ``x`` whose first
        four columns are ``[x_center, y_center, width, height]``.
    """
    # Dispatch on the container type, not the dtype: the previous
    # ``x.dtype is torch.float32`` test sent every non-float32 tensor
    # (float64, float16, ...) down the NumPy branch.
    y = torch.zeros_like(x) if isinstance(x, torch.Tensor) else np.zeros_like(x)
    y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x center
    y[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y center
    y[:, 2] = x[:, 2] - x[:, 0]  # width
    y[:, 3] = x[:, 3] - x[:, 1]  # height
    return y

def letterbox(img, height=608, width=1088, color=(127.5, 127.5, 127.5)):
    """Resize an image to fit a canvas while keeping its aspect ratio.

    The image is scaled by the largest factor that still fits inside
    ``height`` x ``width`` and the remaining area is filled with ``color``.

    Args:
        img: Image array; only its first two dimensions (rows, columns) are read.
        height: Target canvas height in pixels.
        width: Target canvas width in pixels.
        color: Constant border value handed to ``cv2.copyMakeBorder``.

    Returns:
        Tuple ``(padded_image, ratio, dw, dh)`` where ``ratio`` is the scale
        factor applied and ``dw`` / ``dh`` are the (possibly fractional)
        horizontal / vertical padding on each side.
    """
    src_h, src_w = img.shape[:2]
    ratio = min(float(height) / src_h, float(width) / src_w)

    # cv2.resize expects (width, height).
    resized_w = round(src_w * ratio)
    resized_h = round(src_h * ratio)

    dw = (width - resized_w) / 2
    dh = (height - resized_h) / 2

    # The +-0.1 nudges split an odd amount of padding between the two sides.
    pad_top = round(dh - 0.1)
    pad_bottom = round(dh + 0.1)
    pad_left = round(dw - 0.1)
    pad_right = round(dw + 0.1)

    resized = cv2.resize(img, (resized_w, resized_h), interpolation=cv2.INTER_AREA)
    padded = cv2.copyMakeBorder(resized, pad_top, pad_bottom, pad_left, pad_right,
                                cv2.BORDER_CONSTANT, value=color)
    return padded, ratio, dw, dh

def random_affine(img, targets=None, degrees=(-10, 10), translate=(.1, .1), scale=(.9, 1.1), shear=(-2, 2), borderValue=(127.5, 127.5, 127.5)):
    """Apply a random rotation/scale/translation/shear to an image and its boxes.

    Args:
        img: Image array; rows are read from ``shape[0]`` and columns from
            ``shape[1]``.
        targets: Optional 2-D array whose columns 2..5 hold box corners
            ``[x1, y1, x2, y2]`` in pixels. ``None`` means "image only".
        degrees: ``(min, max)`` rotation angle in degrees.
        translate: ``(x_fraction, y_fraction)`` maximum shift, as a fraction of
            the image width and height respectively.
        scale: ``(min, max)`` isotropic scale factor.
        shear: ``(min, max)`` shear angle in degrees (drawn independently for
            each axis).
        borderValue: Fill value for pixels that fall outside the source image.

    Returns:
        ``warped_image`` when ``targets`` is ``None``; otherwise the tuple
        ``(warped_image, targets, M)`` where ``targets`` contains only the
        boxes that survive the transform and ``M`` is the 3x3 matrix applied.
    """
    border = 0
    height = img.shape[0]
    width = img.shape[1]

    # Rotation and scale about the image centre.
    rotation = np.eye(3)
    a = random.random() * (degrees[1] - degrees[0]) + degrees[0]
    s = random.random() * (scale[1] - scale[0]) + scale[0]
    rotation[:2] = cv2.getRotationMatrix2D(angle=a, center=(width / 2, height / 2), scale=s)

    # Translation. The x shift is a fraction of the WIDTH and the y shift a
    # fraction of the HEIGHT; the previous code had the two axes swapped.
    # (The local is no longer called ``T`` so it does not shadow the
    # module-level ``transforms as T`` import.)
    translation = np.eye(3)
    translation[0, 2] = (random.random() * 2 - 1) * translate[0] * width + border
    translation[1, 2] = (random.random() * 2 - 1) * translate[1] * height + border

    # Shear, one independent draw per axis.
    shear_mat = np.eye(3)
    shear_mat[0, 1] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180)
    shear_mat[1, 0] = math.tan((random.random() * (shear[1] - shear[0]) + shear[0]) * math.pi / 180)

    # Order matters: rotate/scale first, then translate, then shear.
    M = shear_mat @ translation @ rotation
    imw = cv2.warpPerspective(img, M, dsize=(width, height), flags=cv2.INTER_LINEAR, borderValue=borderValue)

    if targets is None:
        return imw

    if len(targets) > 0:
        n = targets.shape[0]
        points = targets[:, 2:6].copy()
        area0 = (points[:, 2] - points[:, 0]) * (points[:, 3] - points[:, 1])

        # Transform all four corners of every box (homogeneous coordinates).
        xy = np.ones((n * 4, 3))
        xy[:, :2] = points[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = (xy @ M.T)[:, :2].reshape(n, 8)

        # Axis-aligned box enclosing the transformed corners.
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        xy = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # The enclosing box of a rotated rectangle over-estimates its size;
        # shrink it about its centre by an angle-dependent factor.
        radians = a * math.pi / 180
        reduction = max(abs(math.sin(radians)), abs(math.cos(radians))) ** 0.5
        x = (xy[:, 2] + xy[:, 0]) / 2
        y = (xy[:, 3] + xy[:, 1]) / 2
        w = (xy[:, 2] - xy[:, 0]) * reduction
        h = (xy[:, 3] - xy[:, 1]) * reduction
        xy = np.concatenate((x - w / 2, y - h / 2, x + w / 2, y + h / 2)).reshape(4, n).T

        # Clip to the image and drop boxes that became too small, lost more
        # than 90% of their area, or became extremely elongated.
        np.clip(xy[:, 0], 0, width, out=xy[:, 0])
        np.clip(xy[:, 2], 0, width, out=xy[:, 2])
        np.clip(xy[:, 1], 0, height, out=xy[:, 1])
        np.clip(xy[:, 3], 0, height, out=xy[:, 3])
        w = xy[:, 2] - xy[:, 0]
        h = xy[:, 3] - xy[:, 1]
        area = w * h
        ar = np.maximum(w / (h + 1e-16), h / (w + 1e-16))
        i = (w > 4) & (h > 4) & (area / (area0 + 1e-16) > 0.1) & (ar < 10)

        targets = targets[i]
        targets[:, 2:6] = xy[i]

    return imw, targets, M

class LoadImages:
    """Iterable / indexable loader for a single image or a directory of images.

    Every item is the tuple ``(img_path, img, img0)`` where ``img0`` is the
    image as returned by ``cv2.imread`` and ``img`` is the letterboxed copy
    with the channel axis reversed and moved first, stored as float32 and
    divided by 255.

    Args:
        path: Path to an image file, or to a directory whose ``.jpg``,
            ``.jpeg``, ``.png`` and ``.tif`` files (case-insensitive) are loaded
            in sorted order.
        img_size: ``(width, height)`` of the letterboxed output.

    Raises:
        AssertionError: If no image is found at ``path``.
    """

    def __init__(self, path, img_size=(1088, 608)):
        # Default to an empty list so that a path which is neither a directory
        # nor a file reaches the "No images found" assertion below instead of
        # failing with AttributeError on a missing ``self.files``.
        self.files = []
        if os.path.isdir(path):
            image_format = ['.jpg', '.jpeg', '.png', '.tif']
            self.files = [
                candidate for candidate in sorted(glob.glob(f'{path}/*'))
                if os.path.splitext(candidate)[1].lower() in image_format
            ]
        elif os.path.isfile(path):
            self.files = [path]
        self.nF = len(self.files)
        self.width = img_size[0]
        self.height = img_size[1]
        self.count = 0
        assert self.nF > 0, f'No images found in {path}'

    def _load(self, img_path):
        """Read one image and return ``(img_path, img, img0)``."""
        img0 = cv2.imread(img_path)
        assert img0 is not None, f'Failed to load {img_path}'
        img, _, _, _ = letterbox(img0, height=self.height, width=self.width)
        # Reverse the channel axis and go from HWC to CHW.
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img, dtype=np.float32)
        img /= 255.0
        return img_path, img, img0

    def __iter__(self):
        # __next__ pre-increments, so start one before the first index.
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if self.count == self.nF:
            raise StopIteration
        return self._load(self.files[self.count])

    def __getitem__(self, idx):
        # Indices wrap around the number of files.
        idx = idx % self.nF
        return self._load(self.files[idx])

    def __len__(self):
        return self.nF

class LoadVideo:
    """Iterate over the frames of a video file.

    Every item is the tuple ``(frame_index, img, img0)`` where ``img0`` is the
    frame resized to ``(self.w, self.h)`` and ``img`` is its letterboxed copy
    with the channel axis reversed and moved first, stored as float32 and
    divided by 255.

    Args:
        path: Path to the video file.
        img_size: ``(width, height)`` of the letterboxed output.

    Raises:
        FileNotFoundError: If ``path`` is not an existing file.
    """

    def __init__(self, path, img_size=(1088, 608)):
        if not os.path.isfile(path):
            # A missing file is FileNotFoundError; the previous FileExistsError
            # described the opposite condition. Both are OSError subclasses.
            raise FileNotFoundError(f'Video file not found: {path}')
        self.cap = cv2.VideoCapture(path)
        self.frame_rate = int(round(self.cap.get(cv2.CAP_PROP_FPS)))
        self.vw = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.vh = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        self.vn = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        self.width = img_size[0]
        self.height = img_size[1]
        self.count = 0
        self.w, self.h = self.get_size(self.vw, self.vh, self.width, self.height)
        print(f'Length of the video: {self.vn} frames')

    def get_size(self, vw, vh, dw, dh):
        """Return ``(w, h)``: ``(vw, vh)`` scaled to fit inside ``(dw, dh)``.

        The aspect ratio is kept and both results are truncated to int.
        """
        wa, ha = float(dw) / vw, float(dh) / vh
        a = min(wa, ha)
        return int(vw * a), int(vh * a)

    def __iter__(self):
        # __next__ pre-increments, so start one before the first index.
        # NOTE(review): the capture itself is not rewound here.
        self.count = -1
        return self

    def __next__(self):
        self.count += 1
        if self.count == self.vn:
            raise StopIteration
        res, img0 = self.cap.read()
        assert img0 is not None, f'Failed to load frame {self.count}'
        img0 = cv2.resize(img0, (self.w, self.h))
        img, _, _, _ = letterbox(img0, height=self.height, width=self.width)
        # Reverse the channel axis and go from HWC to CHW.
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img, dtype=np.float32)
        img /= 255.0
        return self.count, img, img0

    def __len__(self):
        return self.vn

class LoadImagesAndLabels:
    """Dataset of images with per-image label files.

    Args:
        path: Text file listing one image path per line (blank lines are
            ignored).
        img_size: ``(width, height)`` every image is letterboxed to.
        augment: When true, apply HSV jitter, a random affine transform and a
            random horizontal flip.
        transforms: Optional callable applied to the final image array.

    The label path of an image is its own path with ``images`` replaced by
    ``labels_with_ids`` and a ``.png`` / ``.jpg`` / ``.jpeg`` extension
    replaced by ``.txt``.
    """

    def __init__(self, path, img_size=(1920, 1080), augment=False, transforms=None):
        with open(path, 'r') as file:
            self.img_files = file.readlines()
            self.img_files = [x.replace('\n', '') for x in self.img_files]
            self.img_files = list(filter(lambda x: len(x) > 0, self.img_files))
        # '.jpeg' is handled too (as JointDataset already does); without it a
        # .jpeg image mapped to itself and was later parsed as a label file.
        self.label_files = [
            x.replace('images', 'labels_with_ids')
             .replace('.png', '.txt')
             .replace('.jpg', '.txt')
             .replace('.jpeg', '.txt')
            for x in self.img_files
        ]
        self.nF = len(self.img_files)
        self.width = img_size[0]
        self.height = img_size[1]
        self.augment = augment
        self.transforms = transforms

    def __getitem__(self, files_index):
        img_path = self.img_files[files_index]
        label_path = self.label_files[files_index]
        return self.get_data(img_path, label_path)

    def get_data(self, img_path, label_path):
        """Load one image and its labels.

        Args:
            img_path: Path of the image to read.
            label_path: Path of the label file; each row has 6 values and
                columns 2..5 are a centre-format box normalised to [0, 1].
                A missing file yields an empty label array.

        Returns:
            Tuple ``(image, labels, img_path, (h, w))``: ``image`` is a float32
            tensor of the letterboxed (and optionally augmented/transformed)
            image, ``labels`` has columns 2..5 as centre-format boxes
            normalised by the letterboxed size, and ``(h, w)`` is the size of
            the image as read from disk.

        Raises:
            ValueError: If the image cannot be read.
        """
        height = self.height
        width = self.width
        img = cv2.imread(img_path)
        if img is None:
            raise ValueError(f'File corrupt {img_path}')

        augment_hsv = True
        if self.augment and augment_hsv:
            # Scale saturation and value by independent random gains in
            # [1 - fraction, 1 + fraction].
            fraction = 0.50
            img_hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
            S = img_hsv[:, :, 1].astype(np.float32)
            V = img_hsv[:, :, 2].astype(np.float32)
            a = (random.random() * 2 - 1) * fraction + 1
            S *= a
            if a > 1:
                # Only a gain above 1 can push values past 255.
                np.clip(S, a_min=0, a_max=255, out=S)
            a = (random.random() * 2 - 1) * fraction + 1
            V *= a
            if a > 1:
                np.clip(V, a_min=0, a_max=255, out=V)
            img_hsv[:, :, 1] = S.astype(np.uint8)
            img_hsv[:, :, 2] = V.astype(np.uint8)
            cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)

        h, w, _ = img.shape
        img, ratio, padw, padh = letterbox(img, height=height, width=width)

        if os.path.isfile(label_path):
            labels0 = np.loadtxt(label_path, dtype=np.float32).reshape(-1, 6)
            # Normalised centre format -> pixel corners in the letterboxed image.
            labels = labels0.copy()
            labels[:, 2] = ratio * w * (labels0[:, 2] - labels0[:, 4] / 2) + padw
            labels[:, 3] = ratio * h * (labels0[:, 3] - labels0[:, 5] / 2) + padh
            labels[:, 4] = ratio * w * (labels0[:, 2] + labels0[:, 4] / 2) + padw
            labels[:, 5] = ratio * h * (labels0[:, 3] + labels0[:, 5] / 2) + padh
        else:
            labels = np.array([])

        if self.augment:
            img, labels, M = random_affine(img, labels, degrees=(-5, 5), translate=(0.10, 0.10), scale=(0.50, 1.20))

        nL = len(labels)
        if nL > 0:
            # Pixel corners -> centre format normalised by the canvas size.
            labels[:, 2:6] = xyxy2xywh(labels[:, 2:6].copy())
            labels[:, 2] /= width
            labels[:, 3] /= height
            labels[:, 4] /= width
            labels[:, 5] /= height

        if self.augment:
            # Horizontal flip with probability 0.5; mirror the x centres too.
            if random.random() > 0.5:
                img = np.fliplr(img)
                if nL > 0:
                    labels[:, 2] = 1 - labels[:, 2]

        # Reverse the channel axis; ascontiguousarray removes negative strides.
        img = np.ascontiguousarray(img[:, :, ::-1])
        if self.transforms is not None:
            img = self.transforms(img)
        return torch.tensor(img, dtype=torch.float32), labels, img_path, (h, w)

    def __len__(self):
        return self.nF

def collate_fn(batch):
    """Merge dataset samples into one padded batch.

    Args:
        batch: Sequence of ``(image_tensor, labels_array, path, size)`` samples.
            Each ``labels_array`` is a NumPy array; non-empty ones have 6 columns.

    Returns:
        Tuple ``(imgs, filled_labels, paths, sizes, labels_len)``:
        ``imgs`` is the images stacked on a new first axis, ``filled_labels``
        is a zero-padded ``(batch, max_boxes, 6)`` tensor, ``paths`` and
        ``sizes`` are tuples of the per-sample values, and ``labels_len`` is a
        ``(batch, 1)`` tensor holding the number of real rows per sample.
    """
    imgs, labels, paths, sizes = zip(*batch)
    n_samples = len(labels)

    stacked_imgs = torch.stack(imgs, 0)

    # Every sample is padded up to the largest box count in the batch.
    most_boxes = max(sample.shape[0] for sample in labels)
    label_tensors = [torch.from_numpy(sample) for sample in labels]

    padded = torch.zeros(n_samples, most_boxes, 6)
    counts = torch.zeros(n_samples)
    for row, sample in enumerate(label_tensors):
        n_boxes = sample.shape[0]
        if n_boxes > 0:
            padded[row, :n_boxes, :] = sample
        counts[row] = n_boxes

    return stacked_imgs, padded, paths, sizes, counts.unsqueeze(1)

class JointDataset(LoadImagesAndLabels):
    """Concatenation of several image-list datasets with globally unique ids.

    Args:
        root: Directory every listed image path is joined onto.
        paths: Mapping ``dataset name -> text file`` listing one image path
            per line, relative to ``root``.
        img_size: ``(width, height)`` every image is letterboxed to.
        augment: Forwarded to the augmentation logic of ``get_data``.
        transforms: Optional callable applied to the final image array.

    Column 1 of every label row is treated as an identity index local to its
    dataset; ``__getitem__`` shifts it by the dataset's start offset so that
    ids from different datasets do not collide.
    """

    def __init__(self, root, paths, img_size=(1088, 608), augment=False, transforms=None):
        self.img_files = OrderedDict()
        self.label_files = OrderedDict()
        self.tid_num = OrderedDict()
        self.tid_start_index = OrderedDict()
        for ds, path in paths.items():
            with open(path, 'r') as file:
                relative_paths = [line.strip() for line in file]
            # Drop blank lines BEFORE joining with root: once joined, an empty
            # line becomes the root directory itself and can no longer be
            # recognised as empty.
            self.img_files[ds] = [os.path.join(root, rel) for rel in relative_paths if rel]
            self.label_files[ds] = [
                x.replace('images', 'labels_with_ids')
                 .replace('.png', '.txt')
                 .replace('.jpg', '.txt')
                 .replace('.jpeg', '.txt')
                for x in self.img_files[ds]
            ]

        # Largest identity index per dataset -> number of identities.
        for ds, label_paths in self.label_files.items():
            max_index = -1
            for lp in label_paths:
                lb = np.loadtxt(lp)
                if len(lb) < 1:
                    continue
                if len(lb.shape) < 2:
                    # A file with a single row is loaded as a 1-D array.
                    img_max = lb[1]
                else:
                    img_max = np.max(lb[:, 1])
                if img_max > max_index:
                    max_index = img_max
            self.tid_num[ds] = max_index + 1

        # Running offset so that each dataset gets its own id range.
        last_index = 0
        for k, v in self.tid_num.items():
            self.tid_start_index[k] = last_index
            last_index += v
        self.nID = int(last_index + 1)
        self.nds = [len(x) for x in self.img_files.values()]
        # cds[i] is the global index of the first sample of dataset i.
        self.cds = [sum(self.nds[:i]) for i in range(len(self.nds))]
        self.nF = sum(self.nds)
        self.width = img_size[0]
        self.height = img_size[1]
        self.augment = augment
        self.transforms = transforms
        print('=' * 80)
        print('Dataset summary')
        print(self.tid_num)
        print('Total # identities:', self.nID)
        print('Start index')
        print(self.tid_start_index)
        print('=' * 80)

    def __getitem__(self, files_index):
        # The last dataset whose start offset is <= files_index owns the sample.
        ds_names = list(self.label_files.keys())
        for i, c in enumerate(self.cds):
            if files_index >= c:
                ds = ds_names[i]
                start_index = c
        img_path = self.img_files[ds][files_index - start_index]
        label_path = self.label_files[ds][files_index - start_index]
        imgs, labels, img_path, (h, w) = self.get_data(img_path, label_path)
        # Shift valid ids (> -1) into this dataset's global range. Guard the
        # empty case: get_data returns a 1-D empty array when there are no labels.
        if len(labels) > 0:
            has_id = labels[:, 1] > -1
            labels[has_id, 1] += self.tid_start_index[ds]
        return imgs, labels, img_path, (h, w)


### DATA

In [5]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from matplotlib import pyplot as plt
import json

In [6]:
# import os

# directory = "/content/drive/MyDrive/sanity/images"
# counter = 1

# for root, dirs, files in os.walk(directory):
#     counter = 1
#     for file in files:
#         if file.startswith("frame_"):
#             image_path = os.path.join(root, file)
#             new_name = f"{os.path.basename(root)}_{counter}.jpeg"
#             new_path = os.path.join(root, new_name)
#             os.rename(image_path, new_path)
#             counter += 1


In [7]:
# import os
# import csv

# root_directory = "/content/drive/MyDrive/sanity/images"

# with open("sanity.train", "w", newline="") as train_file:
#     writer = csv.writer(train_file)

#     for dirpath, dirnames, filenames in os.walk(root_directory):
#         for filename in filenames:
#             base_name, extension = os.path.splitext(filename)
#             if extension.lower() in [".jpg", ".jpeg", ".png"]:
#               image_path = os.path.join(dirpath, filename)
#               writer.writerow([image_path])


In [None]:
# Configure run
# The context manager closes the file even if json.load raises.
with open('/content/drive/MyDrive/sanity/images/data_sanity.json') as f:
    data_config = json.load(f)
trainset_paths = data_config['train']
# NOTE(review): dataset_root is read from the config, but the dataset below is
# built with a hard-coded root -- confirm which one is intended.
dataset_root = data_config['root']


# Initialize datasets and dataloaders
dataset = JointDataset('/content/drive/MyDrive/sanity', paths=trainset_paths, img_size=(224, 224), augment=False, transforms=None)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)


In [None]:
# import matplotlib.pyplot as plt

# def visualize_image(dataset, idx, plot_width=10, plot_height=10):
#     # Load image and annotations
#     img, labels, img_path, (h, w) = dataset[idx]

#     img = img

#     # Plot the image
#     fig, ax = plt.subplots(1, figsize=(plot_width, plot_height))
#     ax.imshow(img)

#     # Plot each bounding box
#     for label in labels:
#         if label.sum() == 0:  # Skip empty labels
#             continue
#         class_id, obj_id, x_center, y_center, width, height = label
#         x_min = int((x_center - width / 2) * img.shape[1])
#         y_min = int((y_center - height / 2) * img.shape[0])
#         x_max = int((x_center + width / 2) * img.shape[1])
#         y_max = int((y_center + height / 2) * img.shape[0])
#         rect = plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, edgecolor='yellow', facecolor='none', linewidth=1)
#         ax.add_patch(rect)
#         plt.text(x_min, y_min - 10, f'ID: {int(obj_id)}', color='red', fontsize=5, backgroundcolor="none")

#     plt.show()

In [None]:
# sample_idx = 200  # Change this to visualize different images
# visualize_image(dataset, sample_idx, 12, 8)

### MODEL

In [None]:
# import torch.nn as nn
# import torch

# class ViTBackbone(nn.Module):
#     def __init__(self, embed_dim=768, num_heads=12, num_layers=12, num_classes=1):
#         super(ViTBackbone, self).__init__()
#         self.embed_dim = embed_dim
#         self.num_heads = num_heads
#         self.num_layers = num_layers
#         self.num_classes = num_classes

#         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
#         self.pos_embed = nn.Parameter(torch.zeros(1, 1 + 1, embed_dim))
#         self.pos_drop = nn.Dropout(p=0.1)

#         self.patch_embed = nn.Conv2d(3, embed_dim, kernel_size=16, stride=16)
#         self.transformer = nn.Transformer(embed_dim, num_heads, num_layers)
#         self.head = nn.Linear(embed_dim, num_classes)

#     def forward(self, x):
#         B = x.size(0)
#         x = self.patch_embed(x).flatten(2).transpose(1, 2)
#         cls_tokens = self.cls_token.expand(B, -1, -1)
#         x = torch.cat((cls_tokens, x), dim=1)
#         x = self.pos_drop(x + self.pos_embed)

#         x = self.transformer(x)
#         x = self.head(x[:, 0])
#         return x

# class DetectionHead(nn.Module):
#     def __init__(self, in_channels, num_classes):
#         super(DetectionHead, self).__init__()
#         self.heatmap_head = nn.Sequential(
#             nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.Conv2d(256, num_classes, kernel_size=1)
#         )
#         self.offset_head = nn.Sequential(
#             nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.Conv2d(256, 2, kernel_size=1)
#         )
#         self.size_head = nn.Sequential(
#             nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
#             nn.ReLU(),
#             nn.Conv2d(256, 2, kernel_size=1)
#         )

#     def forward(self, x):
#         heatmap = self.heatmap_head(x)
#         offset = self.offset_head(x)
#         size = self.size_head(x)
#         return heatmap, offset, size

# class ReIDHead(nn.Module):
#     def __init__(self, in_channels, embed_dim=128):
#         super(ReIDHead, self).__init__()
#         self.conv = nn.Conv2d(in_channels, embed_dim, kernel_size=1)

#     def forward(self, x):
#         return self.conv(x)

# class CustomModel(nn.Module):
#     def __init__(self, num_classes=1):
#         super(CustomModel, self).__init__()
#         self.backbone = ViTBackbone()
#         self.det_head = DetectionHead(self.backbone.embed_dim, num_classes)
#         self.reid_head = ReIDHead(self.backbone.embed_dim)

#     def forward(self, x):
#         features = self.backbone(x)
#         heatmap, offset, size = self.det_head(features)
#         reid_features = self.reid_head(features)
#         return heatmap, offset, size, reid_features


In [None]:
!pip install einops

In [None]:
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [None]:
class Residual(nn.Module):
    """Wrap a callable so that its output is added back onto its input."""

    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x``."""
        branch = self.fn(x, **kwargs)
        return branch + x

class PreNorm(nn.Module):
    """Apply ``LayerNorm`` over the last ``dim`` features before calling ``fn``."""

    def __init__(self, dim, fn):
        super(PreNorm, self).__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(LayerNorm(x), **kwargs)``."""
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

class FeedForward(nn.Module):
    """Two-layer MLP: ``dim -> hidden_dim -> dim`` with GELU and dropout."""

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super(FeedForward, self).__init__()
        # Built as a list first; the resulting Sequential has the same layer
        # order (and therefore the same state-dict keys) as a literal one.
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP; the last dimension stays ``dim``."""
        out = self.net(x)
        return out

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Size of the last dimension of the input (and of the output).
        heads: Number of attention heads.
        dim_head: Feature size of each head.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super(Attention, self).__init__()
        inner_dim = dim_head * heads

        self.heads = heads
        self.scale = dim_head ** -0.5

        # A single bias-free projection produces queries, keys and values.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # With one head whose size already equals ``dim`` the output
        # projection would be redundant, so it is skipped.
        if heads == 1 and dim_head == dim:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )

    def forward(self, x, mask = None):
        """Attend over the second dimension of a 3-D input ``x``.

        Args:
            x: Tensor with three dimensions; the last one has size ``dim``.
            mask: Optional boolean tensor. It is flattened from dimension 1 and
                a leading ``True`` is prepended, so it must cover one position
                fewer than ``x`` has.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch, tokens, _ = x.shape
        num_heads = self.heads

        # Project once, split into q/k/v, then move the head axis forward.
        query, key, value = (
            rearrange(part, 'b n (h d) -> b h n d', h = num_heads)
            for part in self.to_qkv(x).chunk(3, dim = -1)
        )

        scores = einsum('b h i d, b h j d -> b h i j', query, key) * self.scale

        if mask is not None:
            # Most negative finite value: masked pairs vanish under softmax.
            fill_value = -torch.finfo(scores.dtype).max
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == scores.shape[-1], 'mask has incorrect dimensions'
            # A pair (i, j) is kept only when both positions are unmasked.
            pair_mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j')
            scores.masked_fill_(~pair_mask, fill_value)

        weights = scores.softmax(dim=-1)

        mixed = einsum('b h i j, b h j d -> b h i d', weights, value)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm attention + feed-forward residual blocks.

    Args:
        dim: Feature size of the input and output.
        depth: Number of (attention, feed-forward) pairs.
        heads: Number of attention heads per block.
        dim_head: Feature size of each attention head.
        mlp_dim: Hidden size of each feed-forward block.
        dropout: Dropout probability used in both sub-blocks.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super(Transformer, self).__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            attention_block = Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)))
            feed_forward_block = Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
            self.layers.append(nn.ModuleList([attention_block, feed_forward_block]))

    def forward(self, x, mask = None):
        """Run ``x`` through every block; ``mask`` is passed to attention only."""
        for pair in self.layers:
            attention_block, feed_forward_block = pair
            x = attention_block(x, mask = mask)
            x = feed_forward_block(x)
        return x

class PatchEmbedding(nn.Module):
    """Turn an image into a sequence of patch embeddings.

    A convolution whose kernel and stride both equal ``patch_size`` projects
    every non-overlapping patch to ``embed_dim`` features.

    Args:
        img_size: Accepted for interface compatibility; not used.
        patch_size: Side length of each square patch.
        in_channels: Number of input channels.
        embed_dim: Output feature size per patch.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map ``(B, C, H, W)`` to ``(B, N, embed_dim)`` with ``N = (H/P) * (W/P)``."""
        # conv -> (B, embed_dim, H/P, W/P); flatten the grid -> (B, embed_dim, N);
        # swap the last two axes -> (B, N, embed_dim).
        return self.proj(x).flatten(2).transpose(1, 2)


class ViTBackbone(nn.Module):
    """Vision Transformer encoder that returns the full token sequence.

    Args:
        image_size: Side length of the (square) input; must be divisible by
            ``patch_size``.
        patch_size: Side length of each square patch.
        num_layers: Number of transformer blocks.
        embed_dim: Token feature size.
        num_heads: Attention heads per block; each head gets
            ``embed_dim // num_heads`` features.
        mlp_dim: Hidden size of the feed-forward blocks.
        channels: Number of input image channels.
        dropout: Dropout probability inside the transformer.
        emb_dropout: Dropout probability applied to the embedded tokens.
    """

    def __init__(self, image_size, patch_size, num_layers, embed_dim, num_heads, mlp_dim, channels=3, dropout=0.1, emb_dropout=0.1):
        super(ViTBackbone, self).__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        # Honour ``channels``: it was previously ignored in favour of a
        # hard-coded 3, so non-RGB inputs could not be embedded.
        self.patch_embed = PatchEmbedding(image_size, patch_size, in_channels=channels, embed_dim=embed_dim)
        # One learned position per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(embed_dim, num_layers, num_heads, embed_dim // num_heads, mlp_dim, dropout)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image batch accepted by the patch embedding.

        Returns:
            Tensor ``(B, N + 1, embed_dim)``: the class token followed by the
            ``N`` patch tokens.
        """
        x = self.patch_embed(x)
        b, n, _ = x.shape

        # Prepend one copy of the class token to every sequence.
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        return x

class DetectionHead(nn.Module):
    """Three parallel convolutional branches: heatmap, offset and size.

    Args:
        in_channels: Channel count of the input feature map.
        num_classes: Number of output channels of the heatmap branch.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        # Created in the same order as the attributes are listed so parameter
        # initialisation consumes the RNG in an unchanged sequence.
        self.heatmap_head = self._branch(in_channels, num_classes)
        self.offset_head = self._branch(in_channels, 2)
        self.size_head = self._branch(in_channels, 2)

    @staticmethod
    def _branch(in_channels, out_channels):
        """Build one ``3x3 conv -> ReLU -> 1x1 conv`` branch with 256 hidden channels."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, out_channels, kernel_size=1),
        )

    def forward(self, x):
        """Return ``(heatmap, offset, size)``, each at the spatial size of ``x``."""
        return self.heatmap_head(x), self.offset_head(x), self.size_head(x)

class ReIDHead(nn.Module):
    """Per-location embedding head implemented as a single 1x1 convolution.

    Args:
        in_channels: Channel count of the input feature map.
        embed_dim: Number of embedding channels produced at each location.
    """

    def __init__(self, in_channels, embed_dim=128):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, embed_dim, kernel_size=1)

    def forward(self, x):
        """Map ``(B, in_channels, H, W)`` to ``(B, embed_dim, H, W)``."""
        embedding_map = self.conv(x)
        return embedding_map

class CustomModel(nn.Module):
    """ViT backbone followed by detection and re-identification heads.

    Args:
        image_size, patch_size, num_layers, embed_dim, num_heads, mlp_dim,
        channels, dropout, emb_dropout: Forwarded to ``ViTBackbone``.
        num_classes: Number of heatmap channels of the detection head.
    """

    def __init__(self, image_size, patch_size, num_layers, embed_dim, num_heads, mlp_dim, num_classes=1, channels=3, dropout=0.1, emb_dropout=0.1):
        super().__init__()
        self.backbone = ViTBackbone(image_size, patch_size, num_layers, embed_dim, num_heads, mlp_dim, channels, dropout, emb_dropout)
        self.det_head = DetectionHead(embed_dim, num_classes)
        self.reid_head = ReIDHead(embed_dim)

    def forward(self, x):
        """Return ``(heatmap, offset, size, reid_features)`` for an image batch.

        All four outputs are computed on the square grid of patch tokens.
        """
        tokens = self.backbone(x)
        batch, num_tokens, channels = tokens.shape

        # The first token is the class token; the rest form a square grid.
        grid_side = int((num_tokens - 1) ** 0.5)  # Number of patches per dimension
        patch_tokens = tokens[:, 1:]
        feature_map = patch_tokens.transpose(1, 2).reshape(batch, channels, grid_side, grid_side)

        heatmap, offset, size = self.det_head(feature_map)
        reid_features = self.reid_head(feature_map)
        return heatmap, offset, size, reid_features

In [None]:
def gaussian_radius(det_size, min_overlap=0.7):
    """Return the smallest of three candidate Gaussian radii for a box.

    Three quadratics are set up from the box size and ``min_overlap``; for
    each one the value ``(b + sqrt(b^2 - 4ac)) / 2`` is computed and the
    minimum of the three is returned.

    NOTE(review): the value computed per quadratic is not the textbook root
    ``(-b +- sqrt(...)) / (2a)``. It appears to follow the form used by
    common keypoint-detector code bases -- confirm before changing it, since
    it alters the training targets.

    Args:
        det_size: ``(height, width)`` of the box.
        min_overlap: Overlap threshold that parameterises the quadratics.

    Returns:
        The minimum of the three candidate radii.
    """
    height, width = det_size
    side_sum = height + width

    def candidate(a, b, c):
        # Same arithmetic as the inline form: (b + sqrt(b^2 - 4ac)) / 2.
        return (b + np.sqrt(b ** 2 - 4 * a * c)) / 2

    radius_1 = candidate(
        1,
        side_sum,
        width * height * (1 - min_overlap) / (1 + min_overlap),
    )
    radius_2 = candidate(
        4,
        2 * side_sum,
        (1 - min_overlap) * width * height,
    )
    radius_3 = candidate(
        4 * min_overlap,
        2 * min_overlap * side_sum,
        (min_overlap - 1) * width * height,
    )

    return min(radius_1, radius_2, radius_3)

def draw_umich_gaussian(heatmap, center, radius, k=1):
    """Stamp a 2-D Gaussian onto ``heatmap`` in place, keeping the element-wise maximum.

    The Gaussian is unnormalised, so its value at ``center`` is exactly 1
    (times ``k``). The previous implementation used ``cv2.getGaussianKernel``,
    whose coefficients are scaled to sum to 1, which made the stamped peak
    much smaller than 1.

    Args:
        heatmap: 2-D float array modified in place.
        center: ``(x, y)`` integer position of the peak.
        radius: Non-negative integer radius; the stamp spans
            ``2 * radius + 1`` pixels per side and is cropped at the borders.
        k: Multiplier applied to the Gaussian before taking the maximum.

    Returns:
        The same ``heatmap`` object.
    """
    diameter = 2 * radius + 1
    sigma = diameter / 6

    # Unnormalised Gaussian over integer offsets -radius..radius.
    offsets = np.arange(-radius, radius + 1, dtype=np.float32)
    gaussian = np.exp(-(offsets[:, None] ** 2 + offsets[None, :] ** 2) / (2 * sigma * sigma))
    # Flush numerically negligible tails to exactly zero.
    gaussian[gaussian < np.finfo(gaussian.dtype).eps * gaussian.max()] = 0

    x, y = center

    height, width = heatmap.shape[:2]

    # How far the stamp may extend on each side without leaving the heatmap.
    left, right = min(x, radius), min(width - x, radius + 1)
    top, bottom = min(y, radius), min(height - y, radius + 1)

    masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
    masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
    if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
        # masked_heatmap is a view, so writing into it updates heatmap.
        np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
    return heatmap

In [None]:
def compute_target_heatmap(bboxes, output_size, sigma=1):
    """
    Compute the target heatmap.

    One unnormalised Gaussian (peak value 1) is placed at the centre of every
    box and the element-wise maximum over all boxes is returned. The previous
    implementation called ``torch.nn.functional.gaussian_blur``, which does
    not exist.

    Args:
        bboxes (torch.Tensor): Bounding boxes, shape (N, 4), rows are
            ``[x_center, y_center, w, h]`` normalised to [0, 1].
        output_size (tuple): Size of the output heatmap, (H, W).
        sigma (float): Standard deviation for the Gaussian kernel, in output
            pixels. Must be positive.
    Returns:
        torch.Tensor: Target heatmap, shape (H, W).
    """
    height, width = output_size
    target_heatmap = torch.zeros(height, width)

    # Pixel coordinate grids, broadcast against each other below.
    ys = torch.arange(height, dtype=torch.float32).unsqueeze(1)
    xs = torch.arange(width, dtype=torch.float32).unsqueeze(0)

    for bbox in bboxes:
        x_center, y_center, w, h = bbox
        # Rows with no extent (e.g. zero padding added when batching) carry
        # no object and must not create a peak at (0, 0).
        if float(w) <= 0 or float(h) <= 0:
            continue
        # Clamp so that a centre at exactly 1.0 stays inside the map.
        x_center = min(max(int(x_center * width), 0), width - 1)
        y_center = min(max(int(y_center * height), 0), height - 1)

        squared_distance = (xs - x_center) ** 2 + (ys - y_center) ** 2
        blob = torch.exp(-squared_distance / (2 * sigma ** 2))
        target_heatmap = torch.max(target_heatmap, blob)

    return target_heatmap

def compute_target_offset(bboxes, output_size):
    """
    Compute the target offset.

    Every cell ``(y, x)`` receives the displacement from itself to the centre
    of the LAST box in ``bboxes``: ``(x_center - x, y_center - y)``. Earlier
    boxes have no effect on the result, because each box rewrites every cell;
    the map is therefore computed directly from the last box instead of
    looping over all boxes and pixels.

    NOTE(review): "last box wins" is preserved as-is -- confirm whether a
    per-object offset was intended instead.

    Args:
        bboxes (torch.Tensor): Bounding boxes, shape (N, 4), rows are
            ``[x_center, y_center, w, h]`` normalised to [0, 1].
        output_size (tuple): Size of the output offset map, (H, W).
    Returns:
        torch.Tensor: Target offset, shape (H, W, 2). All zeros when
        ``bboxes`` is empty.
    """
    height, width = output_size
    target_offset = torch.zeros(height, width, 2)

    if len(bboxes) == 0:
        return target_offset

    x_center, y_center, w, h = bboxes[-1]
    x_center = int(x_center * width)
    y_center = int(y_center * height)

    xs = torch.arange(width, dtype=target_offset.dtype)
    ys = torch.arange(height, dtype=target_offset.dtype)
    # The x displacement varies along columns, the y displacement along rows.
    target_offset[..., 0] = (x_center - xs).unsqueeze(0)
    target_offset[..., 1] = (y_center - ys).unsqueeze(1)

    return target_offset


def compute_target_size(bboxes, output_size):
    """
    Compute the target size.

    Every cell covered by a box receives that box's integer ``(w, h)`` in
    output pixels; where boxes overlap, the later box overwrites the earlier
    one. Cells not covered by any box stay zero.

    Args:
        bboxes (torch.Tensor): Bounding boxes, shape (N, 4), rows are
            ``[x_center, y_center, w, h]`` normalised to [0, 1].
        output_size (tuple): Size of the output size map, (H, W).
    Returns:
        torch.Tensor: Target size, shape (H, W, 2).
    """
    height, width = output_size
    target_size = torch.zeros(height, width, 2)

    for bbox in bboxes:
        x_center, y_center, w, h = bbox
        x_center = int(x_center * width)
        y_center = int(y_center * height)
        w = int(w * width)
        h = int(h * height)

        x1, x2 = max(0, x_center - w // 2), min(width, x_center + w // 2)
        y1, y2 = max(0, y_center - h // 2), min(height, y_center + h // 2)

        # Fill the whole rectangle with one slice assignment instead of a
        # per-pixel loop. The guard keeps empty or inverted ranges a no-op
        # (a negative slice end would otherwise count from the far edge).
        if x2 > x1 and y2 > y1:
            target_size[y1:y2, x1:x2, 0] = w
            target_size[y1:y2, x1:x2, 1] = h

    return target_size


In [None]:
import torch.optim as optim
from tqdm import tqdm

# Custom loss functions: one per prediction head, plus their unweighted sum.
def heatmap_loss(pred_heatmap, target_heatmap):
    """Mean squared error between predicted and target heatmaps."""
    loss = F.mse_loss(pred_heatmap, target_heatmap)
    return loss

def offset_loss(pred_offset, target_offset):
    """Smooth L1 loss between predicted and target offsets."""
    loss = F.smooth_l1_loss(pred_offset, target_offset)
    return loss

def size_loss(pred_size, target_size):
    """Smooth L1 loss between predicted and target box sizes."""
    loss = F.smooth_l1_loss(pred_size, target_size)
    return loss

def reid_loss(pred_reid, target_reid):
    """Mean squared error between predicted and target re-id values."""
    loss = F.mse_loss(pred_reid, target_reid)
    return loss

def joint_loss(pred_heatmap, pred_offset, pred_size, pred_reid, target_heatmap, target_offset, target_size, target_reid):
    """Return the unweighted sum of the heatmap, offset, size and re-id losses."""
    # Accumulate in the fixed order heatmap -> offset -> size -> reid.
    total = heatmap_loss(pred_heatmap, target_heatmap)
    total = total + offset_loss(pred_offset, target_offset)
    total = total + size_loss(pred_size, target_size)
    total = total + reid_loss(pred_reid, target_reid)
    return total



In [None]:
def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch and return the mean loss per batch.

    Args:
        model: Module returning ``(heatmap, offset, size, reid)`` predictions.
        device: Device the batch and the targets are moved to.
        train_loader: Iterable yielding
            ``(images, filled_labels, paths, sizes, labels_len)`` batches, with
            ``images`` laid out as (batch, height, width, channels) and
            ``labels_len`` holding the number of real rows per sample.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch index, used only in the progress-bar label.

    Returns:
        Sum of the per-batch losses divided by ``len(train_loader)``.
    """
    model.train()
    running_loss = 0.0
    progress = tqdm(train_loader, desc=f"Training Epoch {epoch}")
    for images, filled_labels, paths, sizes, labels_len in progress:
        # Transpose images to shape [batch_size, channels, height, width]
        images = images.permute(0, 3, 1, 2).to(device)
        filled_labels = filled_labels.to(device)

        optimizer.zero_grad()
        pred_heatmap, pred_offset, pred_size, pred_reid = model(images)

        # Targets must live on the same grid as the predictions, so read the
        # size from the prediction instead of assuming a fixed stride.
        output_size = tuple(pred_heatmap.shape[-2:])

        # The target helpers work on ONE image's (N, 4) boxes: build the
        # targets per sample, dropping the zero-padded rows, then stack.
        heatmaps, offsets, box_sizes = [], [], []
        for sample_labels, n_valid in zip(filled_labels, labels_len.view(-1).tolist()):
            bboxes = sample_labels[:int(n_valid), 2:6]
            # (H, W) -> (1, H, W); (H, W, 2) -> (2, H, W), matching the
            # channel-first layout of the predictions.
            heatmaps.append(compute_target_heatmap(bboxes, output_size).unsqueeze(0))
            offsets.append(compute_target_offset(bboxes, output_size).permute(2, 0, 1))
            box_sizes.append(compute_target_size(bboxes, output_size).permute(2, 0, 1))

        target_heatmap = torch.stack(heatmaps).to(device)
        target_offset = torch.stack(offsets).to(device)
        target_size = torch.stack(box_sizes).to(device)

        # NOTE(review): the re-id target is the raw identity column of the
        # padded labels (batch, max_boxes), while pred_reid is a feature map.
        # These shapes are not reconciled here -- the re-id loss needs a
        # design decision (e.g. an identity classifier at object centres).
        target_reid = filled_labels[..., 1].to(device)

        loss = joint_loss(pred_heatmap, pred_offset, pred_size, pred_reid,
                          target_heatmap, target_offset, target_size, target_reid)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(train_loader)


In [None]:
# Pick the device first, then build the model and optimizer on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CustomModel(
    image_size=224,
    patch_size=16,
    num_layers=12,
    embed_dim=768,
    num_heads=12,
    mlp_dim=768,
    num_classes=1,
)
model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)

# Training loop: one pass over the dataloader per epoch, reporting the mean loss.
num_epochs = 10
for epoch in range(num_epochs):
    train_loss = train(model, device, dataloader, optimizer, epoch)
    print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss:.4f}')

### Appendix (Code Samples)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class MLP(nn.Module):
    """Feed-forward stack assembled from (in_features, out_features) pairs.

    Every pair in ``hidden_units`` contributes three layers, in this order:
    ``nn.Linear(in, out)``, ``nn.GELU()`` and ``nn.Dropout(dropout_rate)``.
    """

    def __init__(self, hidden_units, dropout_rate):
        super().__init__()
        stages = []
        for dims in hidden_units:
            stages += [
                nn.Linear(dims[0], dims[1]),
                nn.GELU(),
                nn.Dropout(dropout_rate),
            ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, x):
        """Pass ``x`` through the sequential stack and return the result."""
        out = self.mlp(x)
        return out

In [None]:
class PatchEmbedding(nn.Module):
    """Cut an image batch into non-overlapping patches and project each one.

    The projection is a single ``nn.Conv2d`` whose kernel size and stride both
    equal ``patch_size``. ``img_size`` is accepted but not used by this module.
    """

    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return patch tokens shaped (B, N, embed_dim) for a (B, C, H, W) input."""
        # The conv yields (B, embed_dim, H/P, W/P); collapse the spatial grid
        # into N positions, then move the channel axis last.
        return self.proj(x).flatten(2).transpose(1, 2)


In [None]:
class PositionalEncoding(nn.Module):
    """Learnable additive position embedding with one vector per patch.

    The parameter has shape (1, num_patches, embed_dim) and is initialised to
    zeros.
    """

    def __init__(self, num_patches, embed_dim):
        super().__init__()
        initial = torch.zeros(1, num_patches, embed_dim)
        self.pos_embed = nn.Parameter(initial)

    def forward(self, x):
        """Add the position embedding to ``x``, broadcasting over the batch axis."""
        return self.pos_embed + x


In [None]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        embed_dim: Size of the last axis of the input tokens.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``. Without this check the mismatch only surfaces
            later as a reshape error inside ``forward``.
    """

    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Standard 1/sqrt(d_k) scaling applied to the attention logits.
        self.scale = self.head_dim ** -0.5

        # A single projection produces queries, keys and values at once.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Attend over the N axis of ``x`` (B, N, C) and return a (B, N, C) tensor."""
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, heads, N, head_dim) so q/k/v split on axis 0.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        # Merge the heads back: (B, heads, N, head_dim) -> (B, N, C).
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x


In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block.

    Applies LayerNorm + self-attention and then LayerNorm + a two-layer GELU
    MLP, each wrapped in a residual connection. The MLP hidden width is
    ``embed_dim * mlp_ratio`` and its output passes through Dropout.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4, dropout=0.1):
        super().__init__()
        hidden_dim = embed_dim * mlp_ratio
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        x = x + transformed
        return x


In [None]:
class ViTDetector(nn.Module):
    """ViT encoder followed by a head that regresses four values per image.

    The input is patch-embedded, given a learnable position embedding, passed
    through ``depth`` transformer blocks, layer-normalised, averaged over the
    token axis and fed to the head. ``num_classes`` is accepted but not used
    by this module; ``forward`` returns a single tensor whose last axis has
    size 4.
    """

    def __init__(self, img_size=500, patch_size=16, num_classes=1, embed_dim=768, num_heads=12, mlp_ratio=4, depth=12, dropout_rate=0.1):
        super().__init__()
        grid_side = img_size // patch_size
        self.num_patches = grid_side ** 2
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels=3, embed_dim=embed_dim)
        self.pos_embed = PositionalEncoding(self.num_patches, embed_dim)

        encoder_layers = [
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout_rate)
            for _ in range(depth)
        ]
        self.transformer_blocks = nn.ModuleList(encoder_layers)
        self.norm = nn.LayerNorm(embed_dim)

        hidden_dim = embed_dim * mlp_ratio
        self.head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            MLP(
                hidden_units=[(embed_dim, hidden_dim), (hidden_dim, embed_dim)],
                dropout_rate=dropout_rate,
            ),
            nn.Linear(embed_dim, 4),  # output bounding box coordinates
        )

    def forward(self, x):
        """Encode the image batch and return the head output for each image."""
        tokens = self.pos_embed(self.patch_embed(x))
        for layer in self.transformer_blocks:
            tokens = layer(tokens)
        # Average over the token axis to get one vector per image.
        pooled = self.norm(tokens).mean(dim=1)
        return self.head(pooled)

# Instantiate the detector with its default hyper-parameters.
model = ViTDetector()
# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


NameError: name 'PatchEmbedding' is not defined

In [None]:
def detection_loss(pred_bboxes, pred_class_logits, target_bboxes, target_labels):
    """Return the sum of the box-regression and classification losses.

    The box term is the mean smooth-L1 loss between ``pred_bboxes`` and
    ``target_bboxes``; the class term is the mean cross-entropy between
    ``pred_class_logits`` and ``target_labels``.
    """
    terms = (
        F.smooth_l1_loss(pred_bboxes, target_bboxes, reduction='mean'),
        F.cross_entropy(pred_class_logits, target_labels, reduction='mean'),
    )
    return terms[0] + terms[1]


In [None]:
import torch.optim as optim
from tqdm import tqdm


In [None]:
# AdamW over every model parameter, with a small weight decay.
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.0001)

def train(model, device, train_loader, criterion, optimizer, epoch):
    """Run one training epoch and return the mean loss per batch.

    Each batch is unpacked as ``(images, bboxes, labels)``, three sequences of
    per-sample tensors that are combined with ``torch.stack``. ``torch.stack``
    requires every tensor in a sequence to have the same shape, so samples
    with differing numbers of boxes cannot be batched this way.
    The model is expected to return ``(pred_bboxes, pred_class_logits)`` and
    ``criterion`` is called as
    ``criterion(pred_bboxes, pred_class_logits, bboxes, labels)``.
    """
    model.train()
    loss_sum = 0.0
    progress = tqdm(train_loader, desc=f"Training Epoch {epoch}")
    for images, bboxes, labels in progress:
        # Stack in the order images, bboxes, labels and move each to the device.
        image_batch = torch.stack(images).to(device)
        box_batch = torch.stack(bboxes).to(device)
        label_batch = torch.stack(labels).to(device)

        optimizer.zero_grad()
        box_preds, class_logits = model(image_batch)
        batch_loss = criterion(box_preds, class_logits, box_batch, label_batch)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
    return loss_sum / len(train_loader)

def validate(model, device, val_loader, criterion):
    """Evaluate the model on ``val_loader`` and return the mean loss per batch.

    Runs in eval mode with gradients disabled. Batches are unpacked as
    ``(images, bboxes, labels)`` and combined with ``torch.stack``, which
    requires the tensors in each sequence to share a shape.
    """
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        progress = tqdm(val_loader, desc="Validating")
        for images, bboxes, labels in progress:
            image_batch = torch.stack(images).to(device)
            box_batch = torch.stack(bboxes).to(device)
            label_batch = torch.stack(labels).to(device)

            box_preds, class_logits = model(image_batch)
            batch_loss = criterion(box_preds, class_logits, box_batch, label_batch)
            loss_sum += batch_loss.item()
    return loss_sum / len(val_loader)

num_epochs = 10
# Lowest validation loss seen so far; starts at +inf so the first epoch always wins.
best_val_loss = float('inf')

for epoch in range(num_epochs):
    train_loss = train(model, device, train_loader, detection_loss, optimizer, epoch)
    val_loss = validate(model, device, val_loader, detection_loss)

    # Checkpoint only when the validation loss improves; the file is overwritten each time.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
    print(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')


Training Epoch 0:   0%|          | 0/520 [00:07<?, ?it/s]


RuntimeError: stack expects each tensor to be equal size, but got [11, 4] at entry 0 and [13, 4] at entry 1

In [None]:
class FairMOTViT(nn.Module):
    """timm ViT backbone with heatmap, re-ID and box-regression conv heads.

    Args:
        vit_model_name: Name passed to ``create_model``.
        num_classes: Accepted for API compatibility; not used by this module
            (the heatmap head always has one output channel).
        reid_dim: Number of output channels of the re-ID head.

    ``forward`` returns ``(heatmap, reid_features, bbox_regression)``, each a
    (B, channels, H, W) map over the square patch grid.
    """

    def __init__(self, vit_model_name='vit_base_patch16_224', num_classes=1, reid_dim=128):
        super(FairMOTViT, self).__init__()
        self.vit = create_model(vit_model_name, pretrained=True, num_classes=0)  # no classifier head
        self.heatmap_head = nn.Conv2d(768, 1, kernel_size=3, padding=1)  # Heatmap for object detection
        self.reid_head = nn.Conv2d(768, reid_dim, kernel_size=3, padding=1)  # Re-ID features
        self.bbox_head = nn.Conv2d(768, 4, kernel_size=3, padding=1)  # Bounding box coordinates

    def forward(self, x):
        """Run the backbone and apply the three heads to the patch feature map.

        Raises:
            ValueError: If the patch tokens left after removing prefix tokens
                do not form a square grid.
        """
        features = self.vit.forward_features(x)
        B, N, C = features.shape

        # The backbone's token sequence can start with non-patch tokens (e.g.
        # a class token), which makes N a non-square number such as 197.
        # Drop them before folding the sequence into a 2-D grid.
        num_prefix = getattr(self.vit, 'num_prefix_tokens', None)
        if num_prefix is None:
            # Attribute unavailable: treat whatever exceeds the largest
            # perfect square as prefix tokens.
            num_prefix = N - int(N ** 0.5) ** 2
        if num_prefix:
            features = features[:, num_prefix:]
            N = features.shape[1]

        H = W = int(N ** 0.5)
        if H * W != N:
            raise ValueError(
                f"cannot arrange {N} patch tokens into a square feature map"
            )
        # reshape (not view) because the sliced, permuted tensor may be non-contiguous.
        features = features.permute(0, 2, 1).reshape(B, C, H, W)

        heatmap = self.heatmap_head(features)
        reid_features = self.reid_head(features)
        bbox_regression = self.bbox_head(features)

        return heatmap, reid_features, bbox_regression


In [None]:
class FairMOTViT(nn.Module):
    """timm ViT backbone with heatmap, re-ID and box-regression conv heads.

    Args:
        vit_model_name: Name passed to ``create_model``.
        num_classes: Accepted for API compatibility; not used by this module
            (the heatmap head always has one output channel).
        reid_dim: Number of output channels of the re-ID head.

    ``forward`` returns ``(heatmap, reid_features, bbox_regression)``, each a
    (B, channels, H, W) map over the square patch grid.
    """

    def __init__(self, vit_model_name='vit_base_patch16_224', num_classes=1, reid_dim=128):
        super(FairMOTViT, self).__init__()
        self.vit = create_model(vit_model_name, pretrained=True, num_classes=0)  # no classifier head
        self.heatmap_head = nn.Conv2d(768, 1, kernel_size=3, padding=1)  # Heatmap for object detection
        self.reid_head = nn.Conv2d(768, reid_dim, kernel_size=3, padding=1)  # Re-ID features
        self.bbox_head = nn.Conv2d(768, 4, kernel_size=3, padding=1)  # Bounding box coordinates

    def forward(self, x):
        """Run the backbone and apply the three heads to the patch feature map.

        Raises:
            ValueError: If the patch tokens left after removing prefix tokens
                do not form a square grid.
        """
        features = self.vit.forward_features(x)
        B, N, C = features.shape

        # The backbone's token sequence can start with non-patch tokens (e.g.
        # a class token), which makes N a non-square number such as 197.
        # Drop them before folding the sequence into a 2-D grid.
        num_prefix = getattr(self.vit, 'num_prefix_tokens', None)
        if num_prefix is None:
            # Attribute unavailable: treat whatever exceeds the largest
            # perfect square as prefix tokens.
            num_prefix = N - int(N ** 0.5) ** 2
        if num_prefix:
            features = features[:, num_prefix:]
            N = features.shape[1]

        H = W = int(N ** 0.5)
        if H * W != N:
            raise ValueError(
                f"cannot arrange {N} patch tokens into a square feature map"
            )
        # reshape (not view) because the sliced, permuted tensor may be non-contiguous.
        features = features.permute(0, 2, 1).reshape(B, C, H, W)

        heatmap = self.heatmap_head(features)
        reid_features = self.reid_head(features)
        bbox_regression = self.bbox_head(features)

        return heatmap, reid_features, bbox_regression



# Hyper-parameters passed to FairMOTViT below.
num_classes = 1
reid_dim = 128  # Dimension of Re-ID features
model = FairMOTViT(num_classes=num_classes, reid_dim=reid_dim)
# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
def tracking_loss(pred_heatmap, pred_reid_features, pred_bbox_regression,
                  target_heatmap, target_reid_features, target_bbox):
    """Return the unweighted sum of the heatmap, re-ID and box losses.

    The heatmap and re-ID terms are mean-squared errors; the box term is a
    smooth-L1 loss. All three use the default reduction.
    """
    components = [
        # Detection term: heatmap regression.
        F.mse_loss(pred_heatmap, target_heatmap),
        # Re-ID term (a triplet or contrastive loss could be used instead).
        F.mse_loss(pred_reid_features, target_reid_features),
        # Box regression term.
        F.smooth_l1_loss(pred_bbox_regression, target_bbox),
    ]
    return components[0] + components[1] + components[2]

In [None]:
# Adam over every model parameter with a fixed learning rate of 1e-4.
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch with ``tracking_loss`` and return the mean batch loss.

    Each batch is a mapping read through the keys ``'image'``,
    ``'class_label'``, ``'identity'`` and ``'bbox'``; the class label and
    identity entries are cast to float before use as heatmap and re-ID
    targets. The current loss is printed on every tenth batch.
    """
    model.train()
    loss_sum = 0.0
    for batch_idx, sample in enumerate(train_loader):
        inputs = sample['image'].to(device)
        heatmap_target = sample['class_label'].float().to(device)
        reid_target = sample['identity'].float().to(device)
        bbox_target = sample['bbox'].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        heatmap, reid_features, bbox_regression = outputs

        loss = tracking_loss(
            heatmap, reid_features, bbox_regression,
            heatmap_target, reid_target, bbox_target,
        )
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        should_log = batch_idx % 10 == 0
        if should_log:
            print(f'Epoch [{epoch}], Step [{batch_idx}], Loss: {loss.item():.4f}')
    return loss_sum / len(train_loader)

def validate(model, device, val_loader):
    """Evaluate with ``tracking_loss`` and return the mean loss per batch.

    Runs in eval mode with gradients disabled, reads each batch through the
    keys ``'image'``, ``'class_label'``, ``'identity'`` and ``'bbox'``, and
    prints the resulting mean loss before returning it.
    """
    model.eval()
    loss_sum = 0.0
    with torch.no_grad():
        for sample in val_loader:
            inputs = sample['image'].to(device)
            heatmap_target = sample['class_label'].float().to(device)
            reid_target = sample['identity'].float().to(device)
            bbox_target = sample['bbox'].to(device)

            outputs = model(inputs)
            heatmap, reid_features, bbox_regression = outputs

            batch_loss = tracking_loss(
                heatmap, reid_features, bbox_regression,
                heatmap_target, reid_target, bbox_target,
            )
            loss_sum += batch_loss.item()
    val_loss = loss_sum / len(val_loader)
    print(f'Validation Loss: {val_loss:.4f}')
    return val_loss

num_epochs = 10
# Lowest validation loss seen so far; starts at +inf so the first epoch always wins.
best_val_loss = float('inf')

for epoch in range(num_epochs):
    train_loss = train(model, device, train_loader, optimizer, epoch)
    val_loss = validate(model, device, val_loader)

    # Checkpoint only when the validation loss improves; the file is overwritten each time.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
    print(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

In [None]:
import torch
import torch.nn as nn
from timm import create_model

class CustomViT(nn.Module):
    """timm ViT backbone with heatmap, offset, size and re-ID conv heads.

    Args:
        vit_model_name: Name passed to ``create_model``.
        num_classes: Forwarded to ``create_model``; ``forward`` only uses the
            backbone's ``forward_features`` output.

    ``forward`` returns ``(heatmap, offset, size, reid_features)``. The first
    three maps have 256 channels and the re-ID map has 128 channels, all over
    the square patch grid.
    """

    def __init__(self, vit_model_name='vit_base_patch16_224', num_classes=1):
        super(CustomViT, self).__init__()
        self.vit = create_model(vit_model_name, pretrained=True, num_classes=num_classes)

        # Define the 3x3 convolutional heads with 256 channels
        self.heatmap_head = nn.Conv2d(in_channels=768, out_channels=256, kernel_size=3, padding=1)
        self.offset_head = nn.Conv2d(in_channels=768, out_channels=256, kernel_size=3, padding=1)
        self.size_head = nn.Conv2d(in_channels=768, out_channels=256, kernel_size=3, padding=1)

        # Define the re-ID convolutional layer with 128 kernels
        self.reid_conv = nn.Conv2d(in_channels=768, out_channels=128, kernel_size=3, padding=1)

    def forward(self, x):
        """Run the backbone and apply every head to the patch feature map.

        Raises:
            ValueError: If the patch tokens left after removing prefix tokens
                do not form a square grid.
        """
        # Pass through the ViT model
        features = self.vit.forward_features(x)
        B, N, C = features.shape

        # The backbone's token sequence can start with non-patch tokens (e.g.
        # a class token), which makes N a non-square number such as 197.
        # Drop them before folding the sequence into a 2-D grid.
        num_prefix = getattr(self.vit, 'num_prefix_tokens', None)
        if num_prefix is None:
            # Attribute unavailable: treat whatever exceeds the largest
            # perfect square as prefix tokens.
            num_prefix = N - int(N ** 0.5) ** 2
        if num_prefix:
            features = features[:, num_prefix:]
            N = features.shape[1]

        # Reshape the token sequence to 2D spatial dimensions (H, W)
        H = W = int(N ** 0.5)
        if H * W != N:
            raise ValueError(
                f"cannot arrange {N} patch tokens into a square feature map"
            )
        # reshape (not view) because the sliced, permuted tensor may be non-contiguous.
        features = features.permute(0, 2, 1).reshape(B, C, H, W)

        # Apply each head to the features
        heatmap = self.heatmap_head(features)
        offset = self.offset_head(features)
        size = self.size_head(features)

        # Extract re-ID features
        reid_features = self.reid_conv(features)

        return heatmap, offset, size, reid_features


