In [None]:
import cv2
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim import AdamW
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.ops import nms
from torchvision.io import read_image
import torchvision.transforms.v2 as T
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import math
import utils1
from utils1 import Sign 
from evaluation import evaluate

In [None]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Root folder holding the GTSDB images; the annotation file gt.txt sits inside it.
# (The spelling `annotaions_path_GTSDB` is kept because other cells use this name.)
images_path_GTSDB = "GTSDBDataset"
annotaions_path_GTSDB = f"{images_path_GTSDB}/gt.txt"

In [None]:
class GTSDBDataset(Dataset):
    """Detection dataset built from a folder of ``.ppm`` images and a ``gt.txt`` file.

    Every image number in ``range(num_examples)`` becomes exactly one example.
    Images without any annotation line get a single placeholder
    ``Sign(0, 0, 0, 0, 0)`` whose ``name == 0`` marks "no sign in this image".

    ``dataset[i]`` returns ``(image, target, image_path)``:

    * ``image``  - float32 tensor, channels-first, values in [0, 1], RGB order.
    * ``target`` - dict with ``"boxes"`` (rows ``[x1, y1, x2, y2]``, float32)
      and ``"labels"`` (int64 class ids; ids from the file are shifted by +1,
      leaving 0 for the no-sign/background case).
    * ``image_path`` - the path the image was read from.
    """

    def __init__(self, images_path: str, annotaions_path: str, num_examples: int = 900):
        """Index the annotation file.

        Args:
            images_path: Folder containing the ``NNNNN.ppm`` images.
            annotaions_path: Path to the ``;``-separated annotation file.
            num_examples: Number of consecutively numbered images (``00000`` ...)
                to index. Image numbers are matched on their five-digit prefix.
        """
        # Bucket the annotation lines by their five-character image-number prefix
        # in a single pass instead of rescanning the whole file for every image.
        lines_by_prefix = {}
        with open(annotaions_path) as f:
            for line in f:
                lines_by_prefix.setdefault(line[:5], []).append(line)

        # Keyed by image path so several signs in one image share one example;
        # dict insertion order keeps examples sorted by image number.
        self.examples = {}
        for sample_num in range(num_examples):
            prefix = str(sample_num).zfill(5)
            annotations = lines_by_prefix.get(prefix, [])
            if len(annotations) == 0:
                # No annotation: register the image with the "no sign" placeholder.
                path = images_path + "/" + prefix + ".ppm"
                self.examples[path] = {"image_path": path, "signs": [Sign(0, 0, 0, 0, 0)]}
            else:
                for line in annotations:
                    path, sign = self.create_sign(line, images_path)
                    if path in self.examples:
                        self.examples[path]["signs"].append(sign)
                    else:
                        self.examples[path] = {"image_path": path, "signs": [sign]}

        self.examples = list(self.examples.values())

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.examples)

    def __getitem__(self, i: int):
        """Load image ``i`` and build its detection target.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        data = self.examples[i]
        image = cv2.imread(data["image_path"])
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {data['image_path']}")

        # OpenCV decodes to BGR, while the pretrained torchvision detector was
        # trained on RGB input. NOTE: checkpoints trained before this conversion
        # was added saw BGR images and should be retrained.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # HWC uint8 -> CHW float32 in [0, 1] (what the deprecated ToTensor did).
        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().to(torch.float32).div(255)

        # If there is no sign in the image
        if data["signs"][0].name == 0:
            # NOTE(review): torchvision expects len(labels) == len(boxes); the
            # single 0 label is kept unchanged because `evaluate` (defined
            # outside this file) may rely on it - confirm before changing.
            target = {"boxes": torch.empty(0, 4), "labels": torch.tensor([0], dtype=torch.int64)}
            return image, target, data["image_path"]

        boxes = [[sign.topLeftX, sign.topLeftY, sign.bottomRightX, sign.bottomRightY] for sign in data["signs"]]
        boxes = torch.tensor(boxes, dtype=torch.float32)

        #CHANGE IF YOU WANT CLASSES PREDICTED OR JUST LOCATIONS
        labels = [sign.name for sign in data["signs"]]
        #labels = [1 for sign in data["signs"]]

        labels = torch.tensor(labels, dtype=torch.int64)
        target = {"boxes": boxes, "labels": labels}

        return image, target, data["image_path"]

    def create_sign(self, line: str, images_path: str):
        """Parse one ``;``-separated annotation line.

        Field 0 is the image file name, fields 1-4 are box coordinates and
        field 5 is the class id, which is shifted by +1 so that 0 stays free
        for "no sign". The argument order passed to ``Sign`` (fields 3, 4, 1, 2)
        follows ``utils1.Sign``'s constructor - presumably bottom-right first;
        verify against that class.

        Returns:
            ``(image_path, sign)`` for the annotated image.
        """
        split = line.split(";")
        image_path = images_path + "/" + split[0]
        sign = Sign(float(split[3]), float(split[4]), float(split[1]), float(split[2]), int(split[5]) + 1)
        return image_path, sign

In [None]:
def collate_fn(batch):
    """Transpose a batch of samples into per-field tuples.

    A list of ``(image, target, path)`` samples becomes
    ``(images, targets, paths)``. Nothing is stacked, so each field's items
    are passed through unchanged.
    """
    fields = zip(*batch)
    return tuple(fields)

# Mini-batch size shared by every loader below.
batch_size = 4

# Settings common to all three loaders; only `shuffle` differs between them.
loader_kwargs = dict(batch_size=batch_size, pin_memory=True, collate_fn=collate_fn)

dataset_GTSDB = GTSDBDataset(images_path_GTSDB, annotaions_path_GTSDB)
dataloader_GTSDB = DataLoader(dataset_GTSDB, shuffle=False, **loader_kwargs)

# Fixed positional split: examples 0..599 train the model, 600..899 test it.
train_indices = [*range(600)]
test_indices = [*range(600, 900)]

train_dataset = Subset(dataset_GTSDB, train_indices)
test_dataset = Subset(dataset_GTSDB, test_indices)

# Only the training loader shuffles; the test order stays deterministic.
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Start from the pretrained Faster R-CNN (ResNet-50 FPN). `weights="DEFAULT"`
# replaces the deprecated `pretrained=True` flag and selects the same
# COCO-trained weights for this architecture.
model = fasterrcnn_resnet50_fpn(weights="DEFAULT")
in_features = model.roi_heads.box_predictor.cls_score.in_features
#CHANGE IF YOU WANT CLASSES PREDICTED OR JUST LOCATIONS
# Swap the box head for one with 44 outputs. Class ids are shifted by +1 when
# the annotations are parsed, so index 0 stays free for background
# (presumably 43 sign classes + background - adjust if the labelling changes).
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 44)
# Assigning the result keeps the cell from printing the whole module repr.
model = model.to(device)

In [None]:
# Number of full passes over the training loader.
epochs = 20

# Optimize only the parameters that still require gradients.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=0.0005)

# Halve the learning rate after every 5 calls to `scheduler.step()`.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

In [None]:
stats = []  # total loss, sampled every 10 training steps
lrs = []    # learning rate at the same sampled steps
dl = train_dataloader
model.train()
for epoch in range(epochs):
    # Per-step warm-up scheduler; only active during the first epoch.
    lr_scheduler = None
    bar = tqdm(dl, total=len(dl))
    if epoch == 0:
        # Ramp the learning rate linearly from lr/1000 up to lr over the first
        # half of the epoch. `total_iters` is a step count, so it must be an
        # integer (and at least 1 for very short loaders).
        warmup_factor = 1.0 / 1000
        warmup_iters = max(1, len(dl) // 2)

        lr_scheduler = torch.optim.lr_scheduler.LinearLR(
            optimizer, start_factor=warmup_factor, total_iters=warmup_iters
        )

    for step, batch in enumerate(bar):
        # batch is (images, targets, paths); only the first two go to the model.
        images = [image.to(device) for image in batch[0]]
        targets = [{k: v.to(device) for k, v in t.items()} for t in batch[1]]
        # In train mode the detector returns a dict of losses; train on their sum.
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        if step % 10 == 0:
            loss = losses.item()
            lr = optimizer.param_groups[0]["lr"]
            stats.append(loss)
            lrs.append(lr)
            bar.set_description(f"Epoch {epoch}, Loss: {loss:.4f}, Lr: {lr:.4f}")
    # Epoch-level decay, then checkpoint the whole model object for this epoch.
    scheduler.step()
    torch.save(model, f"Models/modelGTSDB_WithClasses{epoch}.pth")

In [None]:
# `stats` and `lrs` hold one sample per 10 training steps, so the x-axis is the
# index of the logged sample rather than the raw step count.
fig, ax = plt.subplots()
ax.plot(stats)
ax.set(title="Training loss", xlabel="Logged sample (one per 10 steps)", ylabel="Total loss")
plt.show()

fig, ax = plt.subplots()
ax.plot(lrs)
ax.set(title="Learning rate", xlabel="Logged sample (one per 10 steps)", ylabel="Learning rate")
plt.show()

In [None]:
# Reload the checkpoint written after epoch 9 of training.
# - Forward slash: matches the path used by torch.save and works on every OS
#   (the former "Models\m..." was an invalid escape sequence and Windows-only).
# - map_location: lets a checkpoint saved on GPU load on a CPU-only machine.
# - weights_only=False: the file is a fully pickled model object, not a
#   state_dict. Unpickling can execute arbitrary code, so only load
#   checkpoints you created yourself.
model = torch.load(
    "Models/modelGTSDB_WithClasses9.pth",
    map_location=device,
    weights_only=False,
)

In [None]:
def extract_signs(image, boxes):
    """Crop one sub-image per bounding box.

    Args:
        image: Array indexed as ``image[row, column]`` (e.g. an OpenCV image).
        boxes: Iterable of ``[x1, y1, x2, y2]`` boxes; entries may be
            single-element tensors or plain numbers.

    Returns:
        A list with one crop per box. The top-left corner is rounded down and
        the bottom-right corner rounded up so the crop fully contains the box.
        Coordinates are clamped at 0, because a negative index would otherwise
        wrap around to the opposite edge of the image and return a wrong crop.
    """
    images = []
    for box in boxes:
        x1 = max(0, math.floor(float(box[0])))
        y1 = max(0, math.floor(float(box[1])))
        x2 = max(0, math.ceil(float(box[2])))
        y2 = max(0, math.ceil(float(box[3])))
        # Slicing past the far edge is already safe: Python truncates it.
        images.append(image[y1:y2, x1:x2])
    return images

In [None]:
# Post-processing thresholds for the visualisation below.
NMS_IOU_THRESHOLD = 0.2  # boxes overlapping a better-scored box by more than this are dropped
SCORE_THRESHOLD = 0.4    # detections scoring below this are not drawn

# Text style for the detection captions (constant across images).
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.5
font_thickness = 1

model.eval()
# Inference only: skip autograd bookkeeping to save memory and time.
with torch.no_grad():
    for data in test_dataloader:
        # data is (images, targets, paths). Iterate over what the batch actually
        # contains so a shorter final batch cannot raise an IndexError.
        for image, image_path in zip(data[0], data[2]):
            image = image.to(device)
            # Re-read the file with OpenCV to get a BGR uint8 array to draw on.
            image_to_show = cv2.imread(image_path)
            outputs = model([image])
            boxes = outputs[0]["boxes"].detach().cpu()
            scores = outputs[0]["scores"].detach().cpu()
            labels = outputs[0]["labels"].detach().cpu()

            # Apply non-maximum suppression
            to_keep = nms(boxes, scores, NMS_IOU_THRESHOLD)
            boxes = boxes[to_keep]
            scores = scores[to_keep]
            # Labels must be filtered with the same indices, otherwise each
            # kept box is captioned with the class of a different detection.
            labels = labels[to_keep]

            for box, score, label1 in zip(boxes, scores, labels):
                if score >= SCORE_THRESHOLD:
                    # Draw the bounding box on the image
                    cv2.rectangle(image_to_show, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 255, 0), 2)

                    #big red dot at x=50, y=50
                    # NOTE(review): looks like a leftover orientation marker - confirm it is still wanted.
                    cv2.circle(image_to_show, (50, 50), 10, (0, 0, 255), 8)

                    label = f'{score:.2f} {utils1.class_map[label1.item()]}'

                    # Measure the caption so the background fits it
                    text_size = cv2.getTextSize(label, font, font_scale, font_thickness)[0]

                    # Position the text at the top-left corner of the bounding box
                    text_x = int(box[0]) - 1
                    text_y = int(box[1])

                    # Draw the text background rectangle
                    cv2.rectangle(image_to_show, (text_x, text_y - text_size[1] - 2), (text_x + text_size[0], text_y), (0, 255, 0), cv2.FILLED)

                    # Put the text on the image
                    cv2.putText(image_to_show, label, (text_x, text_y - 2), font, font_scale, (0, 0, 0), font_thickness, cv2.LINE_AA)

            # Blocks until a key is pressed in the OpenCV window.
            cv2.imshow("Image", image_to_show)
            cv2.waitKey(0)

            # NOTE(review): the crops are computed but not used, and they are
            # taken from the annotated image for every box that survived NMS.
            extract_signs(image_to_show, boxes)

cv2.destroyAllWindows()

In [None]:
# Run the project-local `evaluate` (imported from the `evaluation` module) on the
# current model and the test loader; whatever it returns is shown as the cell
# output. What it measures is defined in that module.
evaluate(model, test_dataloader)

In [None]:
def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters.

    Parameters whose ``requires_grad`` is False are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Show the number of trainable parameters of the current model as the cell output.
count_parameters(model)