In [None]:
import os
import torch
import torchvision
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import cv2
from matplotlib import pyplot as plt
import torchvision.transforms as transforms
from datetime import datetime
import numpy as np
import os
import random
import argparse
from PIL import Image
import csv

In [None]:
# dataset definition
class myDataset(Dataset):
    """Single-box detection dataset stored as ``<root>/imgs/*`` + ``<root>/annotations.csv``.

    Image filenames are sorted, and the i-th sorted image (0-based) is paired with
    the i-th data row of ``annotations.csv``. The first CSV line is treated as a
    header and skipped. Columns 1-4 of a row are read as ``x1, y1, x2, y2`` and
    column 5 as the integer label. Column 0 is ignored.
    NOTE(review): the pairing is purely positional, so the CSV rows must be in the
    same order as the sorted image filenames. Confirm this against the data.

    Args:
        root: dataset directory containing ``imgs/`` and ``annotations.csv``.
        transforms: optional callable ``(img, target) -> (img, target)``.
    """

    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        self.imgs = sorted(os.listdir(os.path.join(root, "imgs")))
        # Read the CSV once. The previous version reopened and rescanned it on
        # every __getitem__ call, which cost O(n) per sample.
        self._rows = self._read_annotation_rows(os.path.join(root, "annotations.csv"))

    @staticmethod
    def _read_annotation_rows(csv_path):
        """Return every row after the header line of ``csv_path`` as a list of strings."""
        with open(csv_path, newline="") as csv_file:
            reader = csv.reader(csv_file)
            next(reader, None)  # skip the header line
            return list(reader)

    def _annotation(self, idx):
        """Parse data row ``idx`` (0-based, header excluded) into ``([x1, y1, x2, y2], label)``."""
        row = self._rows[idx]
        box = [int(row[1]), int(row[2]), int(row[3]), int(row[4])]
        return box, int(row[5])

    def __getitem__(self, idx):
        """Return ``(img, target)`` for sample ``idx`` (0-based).

        ``img`` is the RGB PIL image, passed through ``self.transforms`` if set.
        ``target`` holds ``boxes`` (float32, shape (1, 4)), ``labels`` (int64,
        shape (1,)), ``image_id`` (shape (1,)) and ``area`` (shape (1,)).
        """
        img_path = os.path.join(self.root, "imgs", self.imgs[idx])
        # The context manager closes the file handle; convert() returns a
        # fully loaded copy, so the result stays usable after the block.
        with Image.open(img_path) as pil_img:
            img = pil_img.convert("RGB")

        box, label = self._annotation(idx)
        boxes = torch.as_tensor([box], dtype=torch.float32)
        labels = torch.tensor([label], dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
            "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Number of images found in ``<root>/imgs``."""
        return len(self.imgs)

In [None]:
import detection.transforms as T

def get_transform(train):
    """Build the ``(img, target)`` transform pipeline for the detection dataset.

    Every sample goes through ``T.PILToTensor()`` and then ``T.Normalize()``.
    When ``train`` is true, ``T.RandomHorizontalFlip(0.5)`` is appended as
    augmentation. The steps are wrapped in ``T.Compose``.

    NOTE(review): confirm that the local ``detection.transforms`` defines a
    no-argument ``Normalize``. The upstream torchvision reference transforms do
    not appear to. Also confirm the pipeline outputs float images in [0, 1], as
    torchvision detection models expect; ``PILToTensor`` alone yields uint8.
    """
    steps = [T.PILToTensor(), T.Normalize()]
    if train:
        steps.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(steps)

In [None]:
# Train on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Number of output classes for the detector head.
# NOTE(review): torchvision detectors reserve label 0 for background, so this
# should be (number of object classes + 1). Check it against the CSV labels.
num_classes: int = 9

In [None]:
# Two views of the same directory: one with training augmentation, one without.
# The split cell below turns them into disjoint train/test subsets.
DATA_ROOT = 'db_lisa_tiny'
dataset = myDataset(DATA_ROOT, transforms=get_transform(train=True))
dataset_test = myDataset(DATA_ROOT, transforms=get_transform(train=False))

In [None]:
# Hold out NUM_TEST randomly chosen samples for evaluation and train on the rest.
# The permutation is seeded so the train/test split is reproducible across runs.
# NOTE: this cell rebinds `dataset` / `dataset_test` to Subsets, so run it exactly
# once after the dataset cell. Re-running it would split the subsets again.
NUM_TEST = 50
SPLIT_SEED = 0
assert len(dataset) > NUM_TEST, f"need more than {NUM_TEST} samples, got {len(dataset)}"
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
indices = torch.randperm(len(dataset), generator=split_generator).tolist()
dataset = torch.utils.data.Subset(dataset, indices[:-NUM_TEST])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-NUM_TEST:])

In [None]:
import detection.utils as utils
# Both loaders use the reference `utils.collate_fn`, since detection targets differ
# in size per image. Presumably it keeps each batch as per-sample tuples, as the
# torchvision reference helper does. TODO confirm for the local copy.
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=1,
    collate_fn=utils.collate_fn,
)
data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=5,
    shuffle=False,
    num_workers=1,
    collate_fn=utils.collate_fn,
)

In [None]:
# Faster R-CNN with a MobileNetV3-Large FPN backbone and a box head sized for this
# dataset. Without `num_classes`, torchvision builds its default 91-class (COCO)
# head, and the `num_classes` setting defined earlier is silently ignored.
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(num_classes=num_classes)

In [None]:
# Move the parameters to the training device. Assigning the result, instead of
# leaving a bare expression, also keeps the notebook from printing the full repr.
model = model.to(device)

In [None]:
# SGD with momentum and weight decay, applied only to trainable parameters.
params = list(filter(lambda param: param.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

In [None]:
# StepLR multiplies the learning rate by gamma=0.1 after every
# step_size=3 calls to lr_scheduler.step().
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1,
)

In [None]:
# Number of full passes over the training subset.
num_epochs: int = 8

In [None]:
from detection.engine import train_one_epoch, evaluate

for epoch in range(num_epochs):
    # One pass over the training data, logging every 10 iterations.
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)

    # Advance the learning-rate schedule once per epoch.
    lr_scheduler.step()

    # Score the held-out subset after every epoch.
    evaluate(model, data_loader_test, device=device)

In [None]:
# Pickle the whole nn.Module (architecture and weights) to disk.
MODEL_PATH = "model_real_fasterrcnn_mobilenet_v3_large_fpn"
torch.save(model, MODEL_PATH)

In [None]:
# Reload the pickled model. torch.save() stored the full nn.Module, which PyTorch
# >= 2.6 refuses to unpickle under its default weights_only=True, so opt out
# explicitly. Only do this for trusted files: unpickling can execute arbitrary
# code. map_location places the tensors on the current `device`.
trainedModel = torch.load(
    "model_real_fasterrcnn_mobilenet_v3_large_fpn",
    map_location=device,
    weights_only=False,
)

In [None]:
# Switch to inference mode. nn.Module.eval() returns the module itself; assigning
# it keeps the notebook from printing the full model repr.
trainedModel = trainedModel.eval()

In [None]:
# Take the first batch of the (unshuffled) test loader for a quick visual check.
first_batch = next(iter(data_loader_test))
img, targets = first_batch

In [None]:
# Run inference on the sample batch. The model was moved to `device` before
# saving, but the DataLoader yields CPU tensors, so move the images over first.
# no_grad() skips building an autograd graph, which inference does not need.
with torch.no_grad():
    predictions = trainedModel([image.to(device) for image in img])
predictions

In [None]:
# Display one image from the sample batch.
SAMPLE_INDEX = 3
sample_tensor = img[SAMPLE_INDEX].detach().cpu()
# Image tensors are (C, H, W); matplotlib expects (H, W, C).
imgs = sample_tensor.permute(1, 2, 0).numpy()
# NOTE(review): if the transform pipeline produces floats outside [0, 1],
# imshow will clip them. Confirm what T.Normalize() outputs.
fig, ax = plt.subplots()
ax.imshow(imgs)
ax.set_title(f"Test batch, image {SAMPLE_INDEX}")
ax.axis("off")
plt.show()