In [None]:
%pip install ipykernel==6.29.4
%pip install python-dotenv==1.0.1
%pip install simclr==1.0.2
%pip install roboflow==1.1.27

In [None]:
# Download dataset

from roboflow import Roboflow
from dotenv import load_dotenv
import os
load_dotenv()  # load variables from a local .env file into the process environment

# Roboflow dataset coordinates (workspace / project / dataset version).
ROBOFLOW_WORKSPACE = "yaid-pzikt"
ROBOFLOW_PROJECT = "firefighting-device-detection"
ROBOFLOW_VERSION = 6

rf = Roboflow(api_key=os.environ.get("ROBOFLOW_API_KEY"))
project = rf.workspace(ROBOFLOW_WORKSPACE).project(ROBOFLOW_PROJECT)
version = project.version(ROBOFLOW_VERSION)
dataset = version.download("yolov8")
dataset.__dict__

In [None]:
# Load dataset info

import os
import yaml

data_dir = './Firefighting-Device-Detection-6'
# NOTE(review): not referenced by any visible cell; kept in case other code uses it.
imagenet_int_to_str = {}

# Class metadata from the YOLO export's data.yaml.
with open(os.path.join(data_dir, 'data.yaml'), 'r') as f:
    data = yaml.safe_load(f)

labels = data.get('names', [])
# Default to len(labels), not [], so a data.yaml without 'nc' still passes the
# consistency check instead of comparing an int against a list.
num_classes = data.get('nc', len(labels))
assert len(labels) == num_classes, (
    f"data.yaml lists {len(labels)} class names but nc={num_classes}"
)
labels, num_classes

In [None]:
# Download model

model_path = "checkpoint_100.tar"
!wget https://github.com/Spijkervet/SimCLR/releases/download/1.2/$model_path

In [None]:
# Config: https://github.com/Spijkervet/SimCLR/blob/master/config/config.yaml

config = dict(
    image_size=224,  # side length images are resized/cropped to
    workers=8,  # DataLoader worker processes
    resnet="resnet50",  # must match the architecture of the SimCLR checkpoint
    projection_dim=64,  # must match the checkpoint's projection head
    max_boxes=50,  # box slots per image; labels are truncated/padded to this
    trials=5,  # number of random hyperparameter trials
    epochs=15,  # training epochs per trial
    batch_size=[2, 8],  # [low, high] range, sampled per trial with random.randint
    lr=[1e-7, 1e-4],  # [low, high] range, sampled log-uniformly per trial
    weight_decay=[1e-6, 1e-4],  # [low, high] range, sampled log-uniformly per trial
)

In [None]:
# Fixed SimCLR transforms: https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/transformations/simclr.py

import torchvision


class TransformsSimCLR:
    """
    SimCLR's stochastic data augmentation module.

    Calling an instance on an image returns two independently augmented views
    (x̃_i, x̃_j) of that same image, which SimCLR treats as a positive pair.

    Attributes:
        train_transform: random resized crop, random horizontal flip, colour
            jitter (applied with p=0.8), random grayscale (p=0.2), then ToTensor.
        test_transform: deterministic resize to (size, size), then ToTensor.
    """

    def __init__(self, size):
        T = torchvision.transforms
        strength = 1  # colour-jitter strength multiplier

        jitter = T.ColorJitter(
            0.8 * strength, 0.8 * strength, 0.8 * strength, 0.2 * strength
        )
        train_steps = [
            T.RandomResizedCrop(size=size),
            T.RandomHorizontalFlip(),  # flips with probability 0.5
            T.RandomApply([jitter], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.ToTensor(),
        ]
        self.train_transform = T.Compose(train_steps)

        # A (height, width) tuple forces a square output; a single int would
        # only match the image's shorter edge to `size`.
        self.test_transform = T.Compose([T.Resize(size=(size, size)), T.ToTensor()])

    def __call__(self, x):
        first_view = self.train_transform(x)
        second_view = self.train_transform(x)
        return first_view, second_view

In [None]:
# Dataloader

import glob
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image


class DetectionDataset(Dataset):
    """Fixed-size detection dataset for one split (``train``/``valid``/``test``) of ``data_dir``.

    Reads ``<data_dir>/<split>/images/*.jpg`` and, for each image, the label file
    with the same stem in ``<data_dir>/<split>/labels/``. Each label line holds a
    class id followed by four box values. For the YOLOv8 export these are
    normalized ``x_center y_center width height``.

    Each item is ``(image, class_ids, boxes)``:
        image: the transformed image tensor.
        class_ids: long tensor of shape ``[max_boxes]``.
        boxes: float32 tensor of shape ``[max_boxes, 4]``.
    Annotations are truncated or padded to ``config["max_boxes"]``.

    Class ids are shifted by +1 relative to the label files, so id 0 is reserved
    for padding/background. This matches the ``num_classes + 1`` logits of the
    detection head and the ``> 0`` validity mask used by the loss.
    """

    def __init__(self, split):
        path = os.path.join(data_dir, split)
        images_dir = os.path.join(path, "images")
        labels_dir = os.path.join(path, "labels")

        # glob() returns files in arbitrary filesystem order, so two independent
        # globs cannot safely be paired by index. Sort the images and derive each
        # label path from the image's file stem instead.
        self.image_paths = sorted(glob.glob(os.path.join(images_dir, "*.jpg")))
        self.label_paths = [
            os.path.join(labels_dir, os.path.splitext(os.path.basename(p))[0] + ".txt")
            for p in self.image_paths
        ]

        size = config["image_size"]
        if split == "train":
            # The SimCLR train_transform randomly crops and flips the image, but
            # the box targets are not transformed with it, so images and boxes
            # would disagree. Keep only photometric augmentation. A plain resize
            # leaves normalized box coordinates valid.
            color_jitter = torchvision.transforms.ColorJitter(0.8, 0.8, 0.8, 0.2)
            self.apply_transform = torchvision.transforms.Compose(
                [
                    torchvision.transforms.Resize(size=(size, size)),
                    torchvision.transforms.RandomApply([color_jitter], p=0.8),
                    torchvision.transforms.RandomGrayscale(p=0.2),
                    torchvision.transforms.ToTensor(),
                ]
            )
        else:
            self.apply_transform = TransformsSimCLR(size=size).test_transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label_path = self.label_paths[idx]

        # Context manager closes the underlying file handle once converted.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        image = self.apply_transform(image)

        annotations = []
        # An image without a label file is treated as containing no objects.
        if os.path.exists(label_path):
            with open(label_path, "r") as file:
                for line in file:
                    fields = line.split()
                    if not fields:  # tolerate blank lines
                        continue
                    class_id, x_center, y_center, width, height = map(float, fields)
                    # +1 so real class 0 is not confused with padding id 0.
                    annotations.append((int(class_id) + 1, x_center, y_center, width, height))

        # Truncate / pad to a fixed number of slots so items can be stacked.
        max_boxes = config["max_boxes"]
        annotations = annotations[:max_boxes]
        annotations += [(0, 0.0, 0.0, 0.0, 0.0)] * (max_boxes - len(annotations))

        class_ids = torch.tensor([a[0] for a in annotations], dtype=torch.long)
        boxes = torch.tensor([a[1:] for a in annotations], dtype=torch.float32)

        return image, class_ids, boxes
    

def collate_fn(batch):
    """Stack a list of ``(image, class_ids, boxes)`` samples into batch tensors.

    Returns ``(images, (class_ids, boxes))``, each stacked along a new leading
    batch dimension.
    """
    stacked_images, stacked_ids, stacked_boxes = (
        torch.stack(list(column)) for column in zip(*batch)
    )
    return stacked_images, (stacked_ids, stacked_boxes)


train_dataset = DetectionDataset(split="train")
valid_dataset = DetectionDataset(split="valid")
test_dataset = DetectionDataset(split="test")

# glob() silently matches nothing for a wrong path or extension, which would only
# surface later as a ZeroDivisionError in the training loop. Fail here instead.
for split_name, split_dataset in (
    ("train", train_dataset),
    ("valid", valid_dataset),
    ("test", test_dataset),
):
    assert len(split_dataset) > 0, f"No images found for split '{split_name}' under {data_dir}"

len(train_dataset), len(valid_dataset), len(test_dataset)

In [None]:
# Object detection model

import torch.nn as nn
from simclr import SimCLR
from simclr.modules import get_resnet


class DetectionModel(nn.Module):
    """Fixed-slot detection head on top of a pre-trained SimCLR encoder.

    For every image it predicts ``config["max_boxes"]`` slots. Each slot gets
    ``num_classes + 1`` class logits (one extra class for background) and four
    box values.

    NOTE(review): assumes ``simclr_model.encoder`` returns a flat
    ``[batch, simclr_model.n_features]`` feature tensor, as the Linear heads
    require.
    """

    def __init__(self, simclr_model):
        super().__init__()
        feature_dim = simclr_model.n_features
        slots = config["max_boxes"]
        # Registration order (encoder, box head, class head) is kept so parameter
        # order and random initialisation match the original definition.
        self.feature_extractor = simclr_model.encoder
        self.bbox_regressor = nn.Linear(feature_dim, slots * 4)
        self.classifier = nn.Linear(feature_dim, slots * (num_classes + 1))

    def forward(self, img):
        """Return ``(class_logits [B, max_boxes, num_classes + 1], bboxes [B, max_boxes, 4])``."""
        slots = config["max_boxes"]
        features = self.feature_extractor(img)
        bboxes = self.bbox_regressor(features).view(-1, slots, 4)
        class_logits = self.classifier(features).view(-1, slots, num_classes + 1)
        return class_logits, bboxes

In [None]:
# Loss function

def detection_loss(pred_classes, pred_boxes, true_labels, true_boxes):
    """Classification + box-regression loss over the non-padding slots.

    Args:
        pred_classes: logits of shape ``[B, max_boxes, C]``.
        pred_boxes: predicted boxes of shape ``[B, max_boxes, 4]``.
        true_labels: long tensor ``[B, max_boxes]``; entries ``> 0`` are real
            objects, ``0`` marks padding and is ignored.
        true_boxes: target boxes ``[B, max_boxes, 4]``.

    Returns:
        Scalar tensor: cross-entropy on the valid slots' logits plus Smooth-L1
        on the valid slots' boxes. If the batch has no valid slot, returns a
        zero that is still connected to the predictions, so ``backward()``
        works. Otherwise both means would be over zero elements and give NaN.
    """
    valid = true_labels > 0  # [B, max_boxes]

    if not bool(valid.any()):
        return (pred_classes.sum() + pred_boxes.sum()) * 0.0

    flat_valid = valid.reshape(-1)
    flat_pred_classes = pred_classes.reshape(-1, pred_classes.size(-1))  # [B * max_boxes, C]
    flat_true_labels = true_labels.reshape(-1)  # [B * max_boxes]
    class_loss = nn.functional.cross_entropy(
        flat_pred_classes[flat_valid], flat_true_labels[flat_valid]
    )

    # Boolean indexing with the [B, max_boxes] mask selects whole [4] box rows.
    loc_loss = nn.functional.smooth_l1_loss(pred_boxes[valid], true_boxes[valid])

    return class_loss + loc_loss

In [None]:
# Validation function

def validate_model(model, dataloader, device, loss_fn=None):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    Args:
        model: module whose forward returns ``(predicted_classes, predicted_bboxes)``.
        dataloader: sized iterable of ``(images, (class_labels, bbox))`` batches.
        device: device each batch is moved to.
        loss_fn: callable ``(pred_classes, pred_boxes, true_labels, true_boxes)``
            returning a scalar tensor; defaults to ``detection_loss``.

    Returns:
        Mean loss as a float, or ``nan`` when the dataloader yields no batches
        (e.g. ``drop_last=True`` with a batch size larger than the split).

    The model runs in eval mode without gradients and is then restored to the
    train/eval mode it had before. Without this, calling the function between
    training epochs would leave the model in eval mode.
    """
    if loss_fn is None:
        loss_fn = detection_loss

    num_batches = len(dataloader)
    if num_batches == 0:
        return float("nan")

    was_training = model.training
    model.eval()
    total_loss = 0.0
    try:
        with torch.no_grad():
            for images, (class_labels, bbox) in dataloader:
                images = images.to(device)
                class_labels = class_labels.to(device)
                bbox = bbox.to(device)
                predicted_classes, predicted_bboxes = model(images)
                loss = loss_fn(predicted_classes, predicted_bboxes, class_labels, bbox)
                total_loss += loss.item()
    finally:
        model.train(was_training)
    return total_loss / num_batches

In [None]:
# Training loop

import copy
import random
import math

# Seed every RNG the loop relies on: hyperparameter sampling (random) and
# head initialisation / DataLoader shuffling (torch).
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = get_resnet(config["resnet"], pretrained=False)
n_features = encoder.fc.in_features

# load pre-trained model from checkpoint
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code; only load checkpoints from a trusted source.
simclr_model = SimCLR(encoder=encoder, projection_dim=config["projection_dim"], n_features=n_features)
simclr_model.load_state_dict(torch.load(model_path, map_location=device))
simclr_model = simclr_model.to(device)

model = DetectionModel(simclr_model)
model = model.to(device)
# Snapshot the starting weights (pre-trained encoder + freshly initialised heads)
# so every trial starts from the same point instead of continuing to train the
# previous trial's model, which would make the trials incomparable.
initial_state = copy.deepcopy(model.state_dict())

trial_results = []
for trial in range(config["trials"]):
    model.load_state_dict(initial_state)

    bs = random.randint(config["batch_size"][0], config["batch_size"][1])
    lr = 10 ** random.uniform(math.log10(config["lr"][0]), math.log10(config["lr"][1]))
    wd = 10 ** random.uniform(math.log10(config["weight_decay"][0]), math.log10(config["weight_decay"][1]))

    print(f"Trial {trial + 1}/{config['trials']}")
    print(f"Batch size: {bs}, Learning rate: {lr}, Weight decay: {wd}")

    train_dataloader = DataLoader(train_dataset, batch_size=bs, shuffle=True, drop_last=True, num_workers=config["workers"], collate_fn=collate_fn)
    valid_dataloader = DataLoader(valid_dataset, batch_size=bs, shuffle=False, drop_last=True, num_workers=config["workers"], collate_fn=collate_fn)
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, drop_last=True, num_workers=config["workers"], collate_fn=collate_fn)

    # Fresh optimizer per trial so no Adam moment estimates leak between trials.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)

    for epoch in range(config["epochs"]):
        # Validation puts the model in eval mode; switch back every epoch so
        # training-mode layers (e.g. BatchNorm) behave correctly while training.
        model.train()
        train_loss = 0
        for images, (class_labels, bbox) in train_dataloader:
            images = images.to(device)
            class_labels = class_labels.to(device)
            bbox = bbox.to(device)

            optimizer.zero_grad()
            predicted_classes, predicted_bboxes = model(images)
            loss = detection_loss(predicted_classes, predicted_bboxes, class_labels, bbox)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        print(f'Epoch {epoch+1}, Train Loss: {train_loss / len(train_dataloader)}')

        val_loss = validate_model(model, valid_dataloader, device)

        print(f'Epoch {epoch+1}, Validation Loss: {val_loss}')

    # Test model
    test_loss = validate_model(model, test_dataloader, device)
    print(f'Test Loss: {test_loss}')

    trial_results.append({
        "trial": trial + 1,
        "batch_size": bs,
        "lr": lr,
        "weight_decay": wd,
        "val_loss": val_loss,
        "test_loss": test_loss,
    })

# Trial with the lowest final-epoch validation loss (None if no trials ran).
min(trial_results, key=lambda r: r["val_loss"], default=None)