In [None]:
import torch
from torch import nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim import AdamW
from torchvision.models.detection import fasterrcnn_resnet50_fpn_v2, fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.ops import nms
from torchvision.io import read_image
import torchvision.transforms.v2 as T
from transformers import get_linear_schedule_with_warmup
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import math
import utils1
from utils1 import * 
from evaluation import evaluate_with_results
from datasets import *
from torchvision.ops import nms, box_iou
import torchvision.transforms as transforms
from visualisation import *
import os

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Settings
images_path_GTSDB = "GTSDBDataset"
annotaions_path_GTSDB = "GTSDBDataset/gt.txt"
# Built with os.path.join instead of a "Models\model..." literal: "\m" is an
# invalid escape sequence (a SyntaxWarning on recent Python versions) and a
# hard-coded backslash only resolves as a path separator on Windows.
localization_model_path = os.path.join("Models", "modelGTSDB_WithoutClasses19.pth")
new_dataset_path = "FineTuneDataset"  # root folder the cropped signs are written to
batch_size = 4
nms_threshold = 0.7  # IoU threshold handed to torchvision.ops.nms

In [None]:
# Load the pickled localization model.
# NOTE(review): torch.load unpickles arbitrary Python objects, so only load
# checkpoints that come from a trusted source.
# - weights_only=False: the file is used as a whole model object (it is moved
#   to a device and switched to eval mode right below), not as a state_dict;
#   recent PyTorch releases default to weights_only=True and refuse such files.
# - map_location=device: lets a checkpoint saved on a GPU load on a CPU-only
#   machine.
localization_model = torch.load(
    localization_model_path, map_location=device, weights_only=False
)
localization_model.to(device)
# Switch to evaluation mode; the trailing semicolon keeps the notebook from
# printing the module repr as the cell output.
localization_model.eval();

In [None]:
# Whole GTSDB dataset and a loader that walks it in order (no shuffling).
dataset_GTSDB = GTSDBDataset(True, images_path_GTSDB, annotaions_path_GTSDB)
dataloader_GTSDB = DataLoader(
    dataset_GTSDB,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    collate_fn=collate_fn,
)

# Subset made of the first 600 dataset indices, with its own ordered loader.
train_indices = [index for index in range(600)]
train_dataset = Subset(dataset_GTSDB, train_indices)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    collate_fn=collate_fn,
)

In [None]:
def find_matching_box(boxes, target_box):
    """Find the box in ``boxes`` that overlaps ``target_box`` the most.

    Args:
        boxes: Tensor of candidate boxes, passed to ``torchvision.ops.box_iou``
            as its first argument (so it must be in the format that function
            expects). May be empty.
        target_box: A single box tensor; a leading dimension is added before
            it is compared against ``boxes``.

    Returns:
        A pair ``(iou, idx)``: the highest IoU between ``target_box`` and any
        candidate, and the position of that candidate in ``boxes``. With at
        least one candidate both values are the tensors produced by
        ``Tensor.max(0)``; when ``boxes`` is empty the plain integers
        ``(0, 0)`` are returned.
    """
    # Bail out before touching box_iou: an empty tensor that is not shaped
    # (0, 4) would fail inside box_iou, and max() over an empty dimension
    # raises as well.
    if boxes.numel() == 0:
        return 0, 0
    ious = box_iou(boxes, target_box.unsqueeze(0))
    best_iou, best_idx = ious.max(0)
    return best_iou, best_idx

In [None]:
# Cropped detections grouped by class id. Id 0 is the background / none bucket
# that receives every detection without a sufficiently overlapping
# ground-truth box.
num_classes = 44
iou_match_threshold = 0.7  # minimum IoU for a detection to inherit a ground-truth label
classes_images = {class_id: [] for class_id in range(num_classes)}

# Pure inference: no_grad avoids recording an autograd graph for each forward pass.
with torch.no_grad():
    for data in tqdm(train_dataloader):
        # data[0]: image tensors, data[1]: target dicts, data[2]: image file paths.
        # zip follows the real batch length, so a final batch shorter than
        # batch_size no longer indexes out of range.
        for image_tensor, target, image_path in zip(data[0], data[1], data[2]):
            image_to_show = cv2.imread(image_path)

            outputs = localization_model([image_tensor.to(device)])
            boxes = outputs[0]["boxes"].cpu()
            scores = outputs[0]["scores"].cpu()

            # Drop overlapping duplicate detections before cropping.
            to_keep = nms(boxes, scores, nms_threshold)
            boxes = boxes[to_keep]

            # NOTE(review): assumes extract_signs returns one crop per box, in
            # the same order as `boxes` - confirm against its implementation.
            sign_crops = extract_signs(image_to_show, boxes)

            gt_boxes = target["boxes"]
            gt_labels = target["labels"]
            for crop_index, crop in enumerate(sign_crops):
                iou, idx = find_matching_box(gt_boxes, boxes[crop_index])
                if iou < iou_match_threshold:
                    # Wrong box detected: file it under background / none.
                    label = 0
                else:
                    label = gt_labels[idx].item()
                classes_images[label].append(crop)
       


In [None]:
# Write every collected crop to <new_dataset_path>/<class id>/<index>.jpg.
for class_id in range(44):
    crops = classes_images[class_id]
    print(f"Class {class_map[class_id]} has {len(crops)} images")
    if not crops:
        # Classes without crops get no directory at all.
        continue

    # Create the class folder once instead of re-checking it for every image.
    class_dir = os.path.join(new_dataset_path, str(class_id))
    os.makedirs(class_dir, exist_ok=True)

    for crop_index, crop in enumerate(crops):
        crop_path = os.path.join(class_dir, f"{crop_index}.jpg")
        # cv2.imwrite reports failure through its return value, so surface it
        # instead of silently losing the image.
        if not cv2.imwrite(crop_path, crop):
            print(f"Warning: could not write {crop_path}")