In [1]:
import os
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from torch.utils.data import random_split, DataLoader
from PIL import Image
import matplotlib.pyplot as plt

**Create Custom DataLoader**

The CarDataset class is a custom dataset loader that loads vehicle images from the specified folder and associates them with labels.

In [2]:
class CarDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.image_files = [f for f in os.listdir(image_dir)]
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path)
        
        # Extract vehicle ID from filename
        label = int(self.image_files[idx].split('_')[0]) - 1
        
        if self.transform:
            image = self.transform(image)
        
        return (image, label)

**Load Dataset and Transform Images into Tensors**

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dataset = CarDataset(image_dir='VID_ReID_Simulation', transform=transform)

**Split into Training and Testing Sets**

In [4]:
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(dataset=train_dataset, batch_size = 32, shuffle=True)
test_loader = DataLoader(dataset=train_dataset, batch_size=32)

In [5]:
for batch_idx, (batch_data, batch_labels) in enumerate(train_loader):
    print(batch_data)

tensor([[[[-0.7308, -0.8164, -0.9363,  ..., -1.3130, -1.3302, -1.2445],
          [-0.7137, -0.7993, -0.8849,  ..., -1.3130, -1.2959, -1.2103],
          [-0.6794, -0.7308, -0.7993,  ..., -1.3644, -1.2617, -1.1418],
          ...,
          [-0.6281, -0.6109, -0.6109,  ..., -0.8678, -0.9363, -1.0390],
          [-0.6109, -0.6109, -0.5938,  ..., -0.8849, -0.9534, -1.0390],
          [-0.5938, -0.5938, -0.5767,  ..., -0.8849, -0.9534, -1.0390]],

         [[-0.0224, -0.1099, -0.2325,  ..., -0.9853, -1.0028, -0.9153],
          [-0.0049, -0.0924, -0.1800,  ..., -0.9853, -0.9678, -0.8803],
          [ 0.0301, -0.0224, -0.0924,  ..., -1.0378, -0.9328, -0.8102],
          ...,
          [-0.1800, -0.1625, -0.1625,  ..., -0.4251, -0.4951, -0.6001],
          [-0.1625, -0.1625, -0.1450,  ..., -0.4426, -0.5126, -0.6001],
          [-0.1450, -0.1450, -0.1275,  ..., -0.4426, -0.5126, -0.6001]],

         [[ 0.2348,  0.1476,  0.0256,  ..., -0.9330, -0.9504, -0.8633],
          [ 0.2522,  0.1651,  

KeyboardInterrupt: 