In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch
from torch import nn, device
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision import models, tv_tensors
from torchvision.io import read_image
from torchvision.transforms import v2
from torchmetrics.detection.iou import IntersectionOverUnion
from tqdm.notebook import tqdm
%matplotlib inline

In [None]:
# Pick the best available accelerator, falling back to the CPU.
# NOTE: this rebinds the name `device` imported from `torch` to a plain string.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
class ImageDataset(Dataset):
    """Detection dataset that yields ``(image, target)`` pairs.

    Args:
        img_labels: Table indexed positionally with ``.iloc``; column 0 holds the
            image file name, column 1 an array of class ids and column 2 an
            ``(N, 4)`` array of boxes in normalised ``(cx, cy, w, h)`` form.
        img_dir: Directory that contains the image files.
        transform: Optional callable applied jointly to ``(image, boxes)``.
        target_transform: Optional callable applied to the class-id array.
    """

    def __init__(self, img_labels, img_dir, transform=None, target_transform=None):
        self.img_labels = img_labels
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(image, target)`` where ``target`` is a dict with the keys
            ``'boxes'`` (absolute-pixel XYXY boxes) and ``'labels'`` -- the
            structure the ``train``/``test`` loops index into.
        """
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        # Use the decoded image size instead of assuming every image is 640x640.
        height, width = image.shape[-2:]

        labels = self.img_labels.iloc[idx, 1].copy()
        normalised = self.img_labels.iloc[idx, 2]
        half_w = normalised[:, 2] / 2
        half_h = normalised[:, 3] / 2
        # Centre/size -> corner format, scaled to pixels. np.stack builds a new
        # array, so the array stored in the table is never modified.
        boxes = np.stack(
            [
                (normalised[:, 0] - half_w) * width,
                (normalised[:, 1] - half_h) * height,
                (normalised[:, 0] + half_w) * width,
                (normalised[:, 1] + half_h) * height,
            ],
            axis=1,
        )
        boxes = tv_tensors.BoundingBoxes(boxes, format='XYXY', canvas_size=(height, width))

        if self.transform is not None:
            image, boxes = self.transform(image, boxes)
        if self.target_transform is not None:
            labels = self.target_transform(labels)
        return image, {'boxes': boxes, 'labels': labels}

In [None]:
path = 'train'
images_dir = os.path.join(path, 'images')
labels_dir = os.path.join(path, 'labels')

# Sort the listing: os.listdir returns entries in arbitrary order, and the row
# order defines the dataset indices.
image_files = sorted(
    f for f in os.listdir(images_dir) if os.path.isfile(os.path.join(images_dir, f))
)

# Each image `<stem>.<ext>` is paired with `<stem>.txt`. os.path.splitext keeps
# this correct for extensions of any length (e.g. `.jpeg`), unlike slicing off
# the last four characters.
raw_labels = [
    np.loadtxt(os.path.join(labels_dir, f'{os.path.splitext(f)[0]}.txt'))
    for f in image_files
]

# Keep only files whose rows have exactly 5 values (class id + 4 box values);
# empty label files and other layouts are dropped. reshape turns a single-row
# file (shape (5,)) into shape (1, 5).
labelled = [
    (f, rows.reshape((-1, 5)))
    for f, rows in zip(image_files, raw_labels)
    if rows.ndim >= 1 and rows.shape[-1] == 5
]

# Build the final frame in one go (column order matters: ImageDataset reads the
# columns positionally) rather than assigning columns on a filtered view.
annotation = pd.DataFrame({
    'filename': [f for f, _ in labelled],
    'class': [rows[:, 0].astype(np.int64) for _, rows in labelled],
    'bbox': [rows[:, 1:].astype(np.float32) for _, rows in labelled],
})

In [None]:
# Joint image/box pipeline: geometric and colour augmentation first, then
# conversion to a float image scaled to [0, 1].
train_transform = v2.Compose([
    v2.Resize((64, 64), antialias=True),
    v2.RandomHorizontalFlip(0.5),
    v2.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    v2.ToImage(),
    v2.ToDtype(torch.float, scale=True),
    # ImageNet normalisation is currently disabled:
    # v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Class ids are stored as numpy arrays; convert them to tensors.
label_transform = v2.Compose([
    torch.tensor,
])

train_data = ImageDataset(
    annotation,
    f'{path}/images',
    transform=train_transform,
    target_transform=label_transform,
)
# Display the first sample as a quick sanity check.
train_data[0]

In [None]:
# The training loop evaluates on `test_dataloader`, which was never defined.
# Hold out a reproducible 20 % of the samples for evaluation.
# NOTE(review): the held-out subset shares `train_data`'s random augmentations;
# build a separate un-augmented dataset if a dedicated validation set exists.
split_generator = torch.Generator().manual_seed(0)
shuffled_indices = torch.randperm(len(train_data), generator=split_generator).tolist()
num_test = max(1, len(shuffled_indices) // 5)
test_data = Subset(train_data, shuffled_indices[:num_test])

# batch_size stays 1: train()/test() reshape the batch dimension away.
train_dataloader = DataLoader(
    Subset(train_data, shuffled_indices[num_test:]), batch_size=1, shuffle=True
)
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False)

In [None]:
# Faster R-CNN (ResNet-50 FPN backbone) with the default pretrained weights,
# placed on the selected device.
# NOTE(review): torchvision detection models appear to reserve label 0 for
# background -- confirm the class ids read from the label files are offset
# accordingly.
model = models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT').to(device)

In [None]:
def train(dataloader, model, optimizer):
    """Run one training epoch and save the weights to ``model.pth``.

    Args:
        dataloader: Iterable of ``(image, target)`` batches of size 1, where
            ``target`` is a dict with ``'boxes'`` and ``'labels'`` tensors.
        model: Detection model that, in training mode, returns a dict of
            losses when called as ``model([image], [target])``.
        optimizer: Optimizer that updates ``model``'s parameters.
    """
    # Follow the model's own placement instead of relying on a global.
    target_device = next(model.parameters()).device
    model.train()
    for image, target in tqdm(dataloader):
        # Drop the size-1 batch dimension without hard-coding the image
        # resolution; a batch larger than 1 still fails loudly here.
        image = image.reshape(image.shape[-3:]).to(target_device)
        # Build a new dict so the batch produced by the loader is not mutated.
        target = {
            'boxes': target['boxes'].reshape((-1, 4)).to(target_device),
            'labels': target['labels'].reshape((-1,)).to(target_device),
        }
        optimizer.zero_grad()
        losses = model([image], [target])
        # The optimised objective is the sum of all returned loss components.
        loss = sum(losses.values())
        loss.backward()
        optimizer.step()
    torch.save(model.state_dict(), "model.pth")

In [None]:
def test(dataloader, model, title=''):
    model.load_state_dict(torch.load("model.pth"))
    iou = IntersectionOverUnion()
    loss_cls, loss_box = 0, 0
    model.eval()
    with torch.no_grad():
        for image, target in tqdm(dataloader):
            image = image.reshape((3, 64, 64)).to(device)
            target['boxes'] = target['boxes'].reshape((-1, 4)).to(device)
            target['labels'] = target['labels'].reshape((-1,)).to(device)
            pred = model([image])
            iou.update(pred, [target])
    model.train()
    # from torchmetrics.detection.mean_ap import MeanAveragePrecision
    with torch.no_grad():
        for image, target in tqdm(dataloader):
            image = image.reshape((3, 64, 64)).to(device)
            target['boxes'] = target['boxes'].reshape((-1, 4)).to(device)
            target['labels'] = target['labels'].reshape((-1,)).to(device)
            pred = model([image], [target])
            loss_cls += pred['loss_classifier']
            loss_box += pred['loss_box_reg']
    iou = iou.compute()['map_50'].cpu().item()
    loss_cls = loss_cls.cpu().item() / len(dataloader)
    loss_box = loss_box.cpu().item() / len(dataloader)
    print(f"{title} Error:\nLoss cls: {loss_cls}\nLoss box: {loss_box}\nmAP: {map}")
    return [loss_cls, loss_box, iou]

In [None]:
# Only parameters that still require gradients are handed to the optimiser.
params = list(filter(lambda p: p.requires_grad, model.parameters()))

# SGD with momentum and weight decay.
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# Multiply the learning rate by 0.1 after every 3 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

In [None]:
epochs = 2
# One row per epoch: the three train metrics followed by the three test metrics.
# NOTE(review): requires `test_dataloader` to have been defined by an earlier cell.
metrics = []
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train(train_dataloader, model, optimizer)
    train_scores = test(train_dataloader, model, 'Train')
    test_scores = test(test_dataloader, model, 'Test')
    metrics.append(train_scores + test_scores)
    scheduler.step()
print("Done!")
# Show the metrics of the final epoch.
metrics[-1]

In [None]:
# Each row of `metrics` is [train cls, train box, train IoU, test cls, test box,
# test IoU], so taking every third column starting at `column` yields the
# train/test pair for one metric.
metrics = np.asarray(metrics, dtype=float)
panel_titles = ['cls_loss', 'box_loss', 'IoU']
fig, ax = plt.subplots(1, len(panel_titles), figsize=(15, 4))
for column, (axis, panel_title) in enumerate(zip(ax, panel_titles)):
    curves = pd.DataFrame(
        metrics[:, column::3],
        columns=['train', 'test'],
        # 1-based epochs, matching the "Epoch N" lines printed during training.
        index=pd.RangeIndex(1, len(metrics) + 1),
    )
    sns.lineplot(data=curves, ax=axis).set(title=panel_title, xlabel='epoch')
fig.tight_layout()