In [None]:
# Disable jedi-based tab completion so IPython uses its own completer.
# NOTE(review): presumably a workaround for slow/broken jedi completion — confirm it is still needed.
%config Completer.use_jedi = False

In [None]:
# import sys
# sys.path.insert(0, "/home/vision")

import torchvision
import torch

### Dataset

In [None]:
from torch.utils.data import Dataset, DataLoader
import os
import cv2
import xml.dom.minidom

In [None]:
class BirdDataset(Dataset):
    """Single-class detection dataset: an image folder plus Pascal VOC-style XML files.

    For an image ``<stem>.<ext>`` in ``image_dir`` the boxes are read from
    ``<stem>.xml`` in ``annotations_dir``; every annotated object gets label 1.

    Args:
        image_dir: directory containing the images.
        annotations_dir: directory containing one XML annotation file per image.
        transform: optional callable ``(img, target) -> (img, target)`` applied per item.
    """

    def __init__(self, image_dir="./data", annotations_dir="./ann", transform=None):
        # Sorted so that an index always maps to the same file (os.listdir order is
        # arbitrary); hidden files (e.g. .DS_Store) and sub-directories
        # (e.g. .ipynb_checkpoints) are skipped because they are not images.
        self.files_name = sorted(
            name
            for name in os.listdir(image_dir)
            if not name.startswith(".") and os.path.isfile(os.path.join(image_dir, name))
        )
        self.image_dir = image_dir
        self.annotaions_dir = annotations_dir  # (sic) attribute name kept for compatibility
        self.transforms = transform

    def __len__(self):
        return len(self.files_name)

    def __getitem__(self, idx):
        """Return ``(img, target)`` for the ``idx``-th image.

        Before ``transform``, ``img`` is the 3-channel image decoded by OpenCV and
        converted to RGB, and ``target`` is
        ``{"boxes": [[xmin, ymin, xmax, ymax], ...], "labels": [1, ...]}``.

        Raises:
            FileNotFoundError: if the image (or its XML file) cannot be read.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        image_name = self.files_name[idx]
        stem, _ = os.path.splitext(image_name)
        # Use the real file name; forcing a ".png" extension broke any other image type.
        img_path = os.path.join(self.image_dir, image_name)
        xml_path = os.path.join(self.annotaions_dir, stem + ".xml")

        img = cv2.imread(img_path)
        if img is None:  # cv2.imread signals failure by returning None, not by raising
            raise FileNotFoundError(f"could not read image: {img_path}")
        # OpenCV decodes to BGR; matplotlib's imshow and ImageNet-style
        # mean/std normalisation assume RGB channel order.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        boxes = self.read_annotaions(xml_path)
        target = {"boxes": boxes, "labels": [1] * len(boxes)}

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def read_annotaions(self, xml_path):
        """Parse the bounding boxes of every ``<object>`` in a VOC-style XML file.

        Returns:
            A list of ``[xmin, ymin, xmax, ymax]`` integer lists, one per
            ``<object>`` element (empty when the file has none).
        """
        def coord(bndbox, tag):
            # float() first: some annotation tools write coordinates such as "12.0".
            return int(float(bndbox.getElementsByTagName(tag)[0].childNodes[0].data))

        dom = xml.dom.minidom.parse(xml_path)
        boxes = []
        for obj in dom.getElementsByTagName("object"):
            bndbox = obj.getElementsByTagName("bndbox")[0]
            boxes.append([coord(bndbox, tag) for tag in ("xmin", "ymin", "xmax", "ymax")])
        return boxes

    def collate_fn(self, batch):
        """Return ``[images, targets]`` as two parallel lists instead of stacking them,
        which is the argument layout ``model(images, targets)`` takes."""
        imgs = [item[0] for item in batch]
        trgts = [item[1] for item in batch]

        return [imgs, trgts]

In [None]:
from torchvision.transforms import functional as F
import numpy as np

class ToTensor(object):
    """Convert the image to a float CxHxW tensor and the target lists to tensors.

    ``boxes`` becomes a float32 tensor of shape ``[N, 4]`` (also when N == 0) and
    every other target entry (e.g. ``labels``) an int64 tensor — the dtypes that
    torchvision detection models document for training targets.
    """

    def __call__(self, image, target):
        image = F.to_tensor(image)
        return image, self._target_to_tensor(target)

    @staticmethod
    def _target_to_tensor(target):
        """Return a new dict with every value of ``target`` converted to a tensor."""
        converted = {}
        for key, value in target.items():
            if key == "boxes":
                # reshape keeps an empty box list as [0, 4] instead of [0],
                # which the detection model would reject.
                converted[key] = torch.as_tensor(np.asarray(value, dtype=np.float32)).reshape(-1, 4)
            else:
                converted[key] = torch.as_tensor(np.array(value), dtype=torch.int64)
        return converted


class Normalize:
    """Normalise the image channel-wise with ``(x - mean) / std``; the target passes through untouched."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, image, target):
        normalized = F.normalize(image, mean=self.mean, std=self.std)
        return normalized, target
    
class Compose:
    """Chain ``(image, target) -> (image, target)`` transforms, applying them in list order."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        current_image, current_target = image, target
        for step in self.transforms:
            current_image, current_target = step(current_image, current_target)
        return current_image, current_target
    
# Tensor conversion followed by the per-channel ImageNet mean/std (RGB order)
# that torchvision's ImageNet-pretrained models use.
transform = Compose([
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# Untransformed dataset: items are cv2-decoded image arrays with plain-list targets, used for plotting.
ds_notf = BirdDataset(image_dir="./data", annotations_dir="./ann", transform=None)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
sample_img, sample_target = ds_notf[1]
# Draw the annotated boxes directly onto the image array, then show it.
for xmin, ymin, xmax, ymax in sample_target["boxes"]:
    cv2.rectangle(sample_img, (xmin, ymin), (xmax, ymax), (0, 0, 255), 2)
plt.imshow(sample_img)
plt.show()

## Train

In [None]:
# Training dataset: the `transform` pipeline turns images and targets into tensors.
ds = BirdDataset("./data", "./ann", transform)

In [None]:
# The dataset's collate_fn yields each batch as [list_of_images, list_of_targets].
dl = DataLoader(ds, batch_size=1, shuffle=True, collate_fn=ds.collate_fn)

In [None]:
# RetinaNet with a ResNet-50 FPN backbone, trained from scratch (no pretrained weights).
# NOTE(review): torchvision >= 0.13 deprecates `pretrained`/`pretrained_backbone`
# in favour of `weights=None, weights_backbone=None`.
model = torchvision.models.detection.retinanet_resnet50_fpn(
    pretrained=False,
    pretrained_backbone=False,
    num_classes=2,
)

In [None]:
# Training mode: the model's forward pass returns a dict of losses.
model.train(mode=True)

In [None]:
import torch.optim as optim

In [None]:
optimizer = optim.Adam(model.parameters(), lr=1e-5)
# Lowers the learning rate when the mean epoch loss stops improving for `patience` epochs;
# it only acts when `scheduler.step(metric)` is called (done at the end of each epoch).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
epochs = 10

epoch_loss = []  # total loss of every batch, accumulated over all epochs
# The evaluation cells switch the model to eval mode, where it returns detections
# instead of losses; switch back so this cell can be re-run safely.
model.train()
for epoch in range(1, epochs + 1):  # runs epochs 1..epochs inclusive
    batch_losses = []  # total loss of each batch in the current epoch
    for i_batch, (images, targets) in enumerate(dl):
        optimizer.zero_grad()

        # In train mode the model returns a dict of loss tensors.
        losses = model(images, targets)
        loss = losses["classification"] + losses["bbox_regression"]
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)

        optimizer.step()

        # .item() works for CPU and GPU tensors alike (.numpy() fails on GPU tensors).
        batch_losses.append(loss.item())
        epoch_loss.append(batch_losses[-1])

        print('Epoch: {} | batch: {} | Classification loss: {:1.5f} | Regression loss: {:1.5f} | Running loss: {:1.5f}'.format(
            epoch,
            i_batch + 1,
            losses["classification"].item(),
            losses["bbox_regression"].item(),
            sum(batch_losses) / len(batch_losses)))

    # Step the plateau scheduler once per epoch on the mean epoch loss.
    if batch_losses:
        scheduler.step(sum(batch_losses) / len(batch_losses))

In [None]:
# Persist only the weights (state_dict), not the whole module object.
torch.save(obj=model.state_dict(), f="./chpt.pth")

### Evaluation

In [None]:
# Restore the saved weights; the displayed return value reports missing/unexpected keys.
model.load_state_dict(state_dict=torch.load("./chpt.pth"))

In [None]:
model.eval()

# Inference only: no_grad avoids building an autograd graph for the forward pass.
with torch.no_grad():
    predicted = model([ds[2][0]])

In [None]:
# Raw model output: a list with one dict per input image; the cells below read its
# "boxes" and "scores" entries.
predicted

In [None]:
# Extra class-agnostic NMS (IoU threshold 0.5) over the detections; returns the indices
# of the boxes to keep. NOTE(review): torchvision's RetinaNet already applies per-class
# NMS during post-processing, so this pass may be redundant.
keep = torchvision.ops.nms(
    boxes=predicted[0]["boxes"],
    scores=predicted[0]["scores"],
    iou_threshold=0.5,
)

In [None]:
item = ds_notf[2]
img = item[0]
oboxes = item[1]["boxes"]

SCORE_THRESHOLD = 0.2  # minimum detection score for a box to be drawn

# Use a new name instead of overwriting `keep`: re-running this cell would otherwise
# call .numpy() on an ndarray and fail.
keep_idx = keep.cpu().numpy()
detections = predicted[0]
boxes = list(np.floor(detections["boxes"].detach().cpu().numpy()[keep_idx]))
scores = list(detections["scores"].detach().cpu().numpy()[keep_idx])

print(f"ground-truth boxes: {len(oboxes)}")
for box, score in zip(boxes, scores):
    if score > SCORE_THRESHOLD:
        cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), (0, 0, 255), 2)

# `img` is a 3-channel image (cv2.imread's default), so a colormap would be ignored.
plt.imshow(img)
plt.show()