<a href="https://colab.research.google.com/github/Karim-Anwar/MasterThesis/blob/main/AttemptFairViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install timm

Collecting timm
  Downloading timm-1.0.3-py3-none-any.whl (2.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->timm)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch->timm)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch->timm)
  Using cache

In [2]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from timm import create_model
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [3]:
class CustomDataset(Dataset):
    """Image dataset that pairs every image with the raw text of its annotation.

    Images are discovered in ``image_dir`` by file extension. The annotation
    for ``<stem>.<ext>`` is read from ``<annot_dir>/<stem>.txt``.

    Args:
        image_dir: Directory that is scanned (non-recursively) for images.
        annot_dir: Directory holding one ``.txt`` file per image.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    # Extensions (compared case-insensitively) that mark a file as an image.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_dir, annot_dir, transform=None):
        self.image_dir = image_dir
        self.annot_dir = annot_dir
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping stable between runs and machines.
        self.image_paths = sorted(
            os.path.join(image_dir, fname)
            for fname in os.listdir(image_dir)
            if fname.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of images found in ``image_dir``."""
        return len(self.image_paths)

    def _annotation_path(self, img_path):
        """Return the annotation path that belongs to ``img_path``.

        Only the final extension is swapped for ``.txt``; extension-like text
        elsewhere in the file name is left untouched.
        """
        stem, _ = os.path.splitext(os.path.basename(img_path))
        return os.path.join(self.annot_dir, stem + '.txt')

    def __getitem__(self, idx):
        """Return ``(image, annotation)`` for the sample at position ``idx``.

        ``annotation`` is the complete, unparsed content of the text file.

        Raises:
            IndexError: If ``idx`` is out of range.
            OSError: If the image or its annotation file cannot be opened.
        """
        img_path = self.image_paths[idx]
        # The context manager closes the underlying file right after decoding.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Load corresponding annotation file
        with open(self._annotation_path(img_path), 'r') as f:
            annotation = f.read()

        if self.transform:
            image = self.transform(image)

        return image, annotation


In [None]:
class CustomTrackingDataset(Dataset):
    """Tracking dataset yielding an image plus one parsed annotation record.

    Images are discovered in ``image_dir`` by file extension. The annotation
    for ``<stem>.<ext>`` is read from ``<annot_dir>/<stem>.txt`` and must start
    with six whitespace-separated fields::

        class_label identity x_center y_center width height

    Only these first six fields are used; anything after them is ignored.

    Args:
        image_dir: Directory that is scanned (non-recursively) for images.
        annot_dir: Directory holding one ``.txt`` file per image.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    # Extensions (compared case-insensitively) that mark a file as an image.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_dir, annot_dir, transform=None):
        self.image_dir = image_dir
        self.annot_dir = annot_dir
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping stable between runs and machines.
        self.image_paths = sorted(
            os.path.join(image_dir, fname)
            for fname in os.listdir(image_dir)
            if fname.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of images found in ``image_dir``."""
        return len(self.image_paths)

    def _annotation_path(self, img_path):
        """Return the annotation path that belongs to ``img_path``.

        Only the final extension is swapped for ``.txt``; extension-like text
        elsewhere in the file name is left untouched.
        """
        stem, _ = os.path.splitext(os.path.basename(img_path))
        return os.path.join(self.annot_dir, stem + '.txt')

    @staticmethod
    def _parse_annotation(text, source='<annotation>'):
        """Parse annotation text into ``(class_label, identity, bbox)``.

        Args:
            text: Raw content of an annotation file.
            source: Name used in the error message (normally the file path).

        Returns:
            ``(int, int, [x_center, y_center, width, height])`` with the box
            values as floats.

        Raises:
            ValueError: If fewer than six fields are present or a field cannot
                be converted to the expected numeric type.
        """
        fields = text.split()
        if len(fields) < 6:
            raise ValueError(
                f"{source}: expected 6 annotation fields "
                f"(class identity x_center y_center width height), got {len(fields)}"
            )
        class_label = int(fields[0])  # original author's note: always 0 in this case
        identity = int(fields[1])
        bbox = [float(value) for value in fields[2:6]]
        return class_label, identity, bbox

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` as a dict.

        Keys: ``'image'`` (transformed image), ``'class_label'`` (int),
        ``'identity'`` (int) and ``'bbox'`` (tensor holding
        ``[x_center, y_center, width, height]``).

        Raises:
            IndexError: If ``idx`` is out of range.
            OSError: If the image or its annotation file cannot be opened.
            ValueError: If the annotation file is malformed.
        """
        img_path = self.image_paths[idx]
        # The context manager closes the underlying file right after decoding.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Load corresponding annotation file
        annot_path = self._annotation_path(img_path)
        with open(annot_path, 'r') as f:
            class_label, identity, bbox = self._parse_annotation(f.read(), annot_path)

        if self.transform:
            image = self.transform(image)

        sample = {
            'image': image,
            'class_label': class_label,
            'identity': identity,
            'bbox': torch.tensor(bbox)
        }

        return sample

In [None]:
# Preprocessing shared by both splits: resize to 224x224, convert to a tensor,
# then normalise per channel.
# NOTE(review): the mean/std values look like the usual ImageNet statistics — confirm.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocessing_steps)

# Training split: batches of 32, reshuffled on every pass.
train_dataset = CustomTrackingDataset(
    image_dir='/content/drive/MyDrive/sanity/train',
    annot_dir='/content/drive/MyDrive/sanity/train/annots',
    transform=transform,
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Validation split: batches of 32, kept in dataset order.
val_dataset = CustomTrackingDataset(
    image_dir='/content/drive/MyDrive/sanity/val',
    annot_dir='/content/drive/MyDrive/sanity/val/annots',
    transform=transform,
)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [None]:
class FairMOTViT(nn.Module):
    """ViT backbone followed by three 3x3 convolutional heads.

    The backbone's patch tokens are rearranged into a square ``(B, C, H, W)``
    feature map, and each head is applied to that map.

    Args:
        vit_model_name: timm model name passed to ``create_model``.
        num_classes: Number of output channels of the heatmap head.
        reid_dim: Number of output channels of the re-ID head.
    """

    def __init__(self, vit_model_name='vit_base_patch16_224', num_classes=1, reid_dim=128):
        super().__init__()
        self.vit = create_model(vit_model_name, pretrained=True, num_classes=0)  # no classifier head
        # Take the channel width from the backbone so that model variants other
        # than the 768-wide "base" one can be used as well.
        embed_dim = self.vit.num_features
        self.heatmap_head = nn.Conv2d(embed_dim, num_classes, kernel_size=3, padding=1)  # Heatmap for object detection
        self.reid_head = nn.Conv2d(embed_dim, reid_dim, kernel_size=3, padding=1)  # Re-ID features
        self.bbox_head = nn.Conv2d(embed_dim, 4, kernel_size=3, padding=1)  # Bounding box coordinates

    def _patch_feature_map(self, x):
        """Run the backbone and return its patch tokens as a ``(B, C, H, W)`` map.

        Raises:
            ValueError: If the number of patch tokens is not a perfect square.
        """
        tokens = self.vit.forward_features(x)
        # forward_features also returns prefix tokens (e.g. the class token);
        # they have no spatial position and must be dropped before reshaping,
        # otherwise the token count is not H * W.
        num_prefix = getattr(self.vit, 'num_prefix_tokens', 0)
        tokens = tokens[:, num_prefix:, :]

        batch, num_patches, channels = tokens.shape
        side = int(round(num_patches ** 0.5))
        if side * side != num_patches:
            raise ValueError(
                f"cannot arrange {num_patches} patch tokens into a square feature map"
            )
        # (B, N, C) -> (B, C, N) -> (B, C, H, W)
        return tokens.transpose(1, 2).reshape(batch, channels, side, side)

    def forward(self, x):
        """Return ``(heatmap, reid_features, bbox_regression)`` for image batch ``x``.

        With ``(B, C, H, W)`` the patch feature map, the outputs have shapes
        ``(B, num_classes, H, W)``, ``(B, reid_dim, H, W)`` and ``(B, 4, H, W)``.
        """
        features = self._patch_feature_map(x)

        heatmap = self.heatmap_head(features)
        reid_features = self.reid_head(features)
        bbox_regression = self.bbox_head(features)

        return heatmap, reid_features, bbox_regression



# Head sizes for the tracker.
num_classes = 1
reid_dim = 128  # Dimension of Re-ID features

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = FairMOTViT(num_classes=num_classes, reid_dim=reid_dim)
model.to(device)

In [None]:
def tracking_loss(pred_heatmap, pred_reid_features, pred_bbox_regression,
                  target_heatmap, target_reid_features, target_bbox):
    """Return the unweighted sum of the three tracking loss terms.

    Terms:
        * detection: mean squared error between predicted and target heatmap
        * re-ID: mean squared error between predicted and target re-ID features
          (a triplet or contrastive loss could be substituted here)
        * box: smooth L1 loss between predicted and target box values

    Each prediction/target pair is handed to the corresponding
    ``torch.nn.functional`` loss unchanged.
    """
    detection_term = F.mse_loss(pred_heatmap, target_heatmap)
    identity_term = F.mse_loss(pred_reid_features, target_reid_features)
    box_term = F.smooth_l1_loss(pred_bbox_regression, target_bbox)

    return detection_term + identity_term + box_term

In [None]:
# Adam over every parameter that `model` exposes.
LEARNING_RATE = 1e-4
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

def train(model, device, train_loader, optimizer, epoch):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network called on each image batch; must return three outputs.
        device: Device that every batch tensor is moved to.
        train_loader: Iterable of dict batches with the keys ``'image'``,
            ``'class_label'``, ``'identity'`` and ``'bbox'``.
        optimizer: Optimiser stepped once per batch.
        epoch: Epoch number, used only in the progress message.

    Returns:
        Sum of the per-batch losses divided by ``len(train_loader)``.
    """
    model.train()
    loss_sum = 0.0
    for step, batch in enumerate(train_loader):
        images = batch['image'].to(device)
        # NOTE(review): these targets are one value (or one 4-vector) per sample,
        # while the model heads in this notebook emit dense (B, C, H, W) maps;
        # the element-wise losses in tracking_loss look shape-incompatible with
        # that — TODO confirm and build proper dense targets.
        heatmap_target = batch['class_label'].float().to(device)
        reid_target = batch['identity'].float().to(device)
        bbox_target = batch['bbox'].to(device)

        optimizer.zero_grad()
        pred_heatmap, pred_reid, pred_bbox = model(images)

        loss = tracking_loss(pred_heatmap, pred_reid, pred_bbox,
                             heatmap_target, reid_target, bbox_target)

        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        # Report every tenth batch, starting with the first.
        if step % 10 == 0:
            print(f'Epoch [{epoch}], Step [{step}], Loss: {loss.item():.4f}')

    return loss_sum / len(train_loader)

def validate(model, device, val_loader):
    """Evaluate ``model`` on ``val_loader`` with gradient tracking disabled.

    Args:
        model: Network called on each image batch; must return three outputs.
        device: Device that every batch tensor is moved to.
        val_loader: Iterable of dict batches with the keys ``'image'``,
            ``'class_label'``, ``'identity'`` and ``'bbox'``.

    Returns:
        Sum of the per-batch losses divided by ``len(val_loader)``; the same
        value is also printed.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for batch in val_loader:
            images = batch['image'].to(device)
            # NOTE(review): same target/prediction shape concern as in train()
            # — per-sample targets versus dense prediction maps; TODO confirm.
            heatmap_target = batch['class_label'].float().to(device)
            reid_target = batch['identity'].float().to(device)
            bbox_target = batch['bbox'].to(device)

            pred_heatmap, pred_reid, pred_bbox = model(images)

            loss = tracking_loss(pred_heatmap, pred_reid, pred_bbox,
                                 heatmap_target, reid_target, bbox_target)
            batch_losses.append(loss.item())

    val_loss = sum(batch_losses, 0.0) / len(val_loader)
    print(f'Validation Loss: {val_loss:.4f}')
    return val_loss

num_epochs = 10
best_val_loss = float('inf')  # lowest validation loss seen so far

for epoch in range(num_epochs):
    train_loss = train(model, device, train_loader, optimizer, epoch)
    val_loss = validate(model, device, val_loader)

    # Checkpoint only when the validation loss improves on the best so far.
    improved = val_loss < best_val_loss
    if improved:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')

    print(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

In [12]:
import torch
import torch.nn as nn
from timm import create_model

class CustomViT(nn.Module):
    """ViT backbone with heatmap, offset, size and re-ID convolutional heads.

    The backbone's patch tokens are rearranged into a square ``(B, C, H, W)``
    feature map, and each 3x3 convolutional head is applied to that map.

    Args:
        vit_model_name: timm model name passed to ``create_model``.
        num_classes: Forwarded to ``create_model``. ``forward`` only calls the
            backbone's ``forward_features``, so this does not change the outputs.
    """

    def __init__(self, vit_model_name='vit_base_patch16_224', num_classes=1):
        super().__init__()
        self.vit = create_model(vit_model_name, pretrained=True, num_classes=num_classes)

        # Take the channel width from the backbone so that model variants other
        # than the 768-wide "base" one can be used as well.
        embed_dim = self.vit.num_features

        # Define the 3x3 convolutional heads with 256 channels
        self.heatmap_head = nn.Conv2d(in_channels=embed_dim, out_channels=256, kernel_size=3, padding=1)
        self.offset_head = nn.Conv2d(in_channels=embed_dim, out_channels=256, kernel_size=3, padding=1)
        self.size_head = nn.Conv2d(in_channels=embed_dim, out_channels=256, kernel_size=3, padding=1)

        # Define the re-ID convolutional layer with 128 kernels
        self.reid_conv = nn.Conv2d(in_channels=embed_dim, out_channels=128, kernel_size=3, padding=1)

    def _patch_feature_map(self, x):
        """Run the backbone and return its patch tokens as a ``(B, C, H, W)`` map.

        Raises:
            ValueError: If the number of patch tokens is not a perfect square.
        """
        tokens = self.vit.forward_features(x)
        # forward_features also returns prefix tokens (e.g. the class token);
        # they have no spatial position and must be dropped before reshaping,
        # otherwise the token count is not H * W.
        num_prefix = getattr(self.vit, 'num_prefix_tokens', 0)
        tokens = tokens[:, num_prefix:, :]

        batch, num_patches, channels = tokens.shape
        side = int(round(num_patches ** 0.5))
        if side * side != num_patches:
            raise ValueError(
                f"cannot arrange {num_patches} patch tokens into a square feature map"
            )
        # (B, N, C) -> (B, C, N) -> (B, C, H, W)
        return tokens.transpose(1, 2).reshape(batch, channels, side, side)

    def forward(self, x):
        """Return ``(heatmap, offset, size, reid_features)`` for image batch ``x``.

        Every output shares the spatial size of the patch feature map; the
        channel counts are those of the respective heads.
        """
        # Pass through the ViT model and reshape the tokens to a 2D map
        features = self._patch_feature_map(x)

        # Apply each head to the features
        heatmap = self.heatmap_head(features)
        offset = self.offset_head(features)
        size = self.size_head(features)

        # Extract re-ID features
        reid_features = self.reid_conv(features)

        return heatmap, offset, size, reid_features




In [4]:
from PIL import Image

In [None]:
imape =