In [None]:
# Mount Google Drive at /content/drive; the dataset and model paths used below live under it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install/upgrade albumentations straight from GitHub; pip's stdout is discarded and `echo` prints an empty line on success.
!pip install -U git+https://github.com/albu/albumentations > /dev/null && echo    

In [None]:
# Upgrade the OpenCV contrib build to the newest version pip can find.
!pip install --upgrade opencv-contrib-python

In [None]:
import os
import cv2
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np                                      
from albumentations.pytorch import ToTensorV2

import torch
import albumentations as A
import glob as glob
import torchvision

import warnings
warnings.filterwarnings("ignore")
from collections import defaultdict, deque

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torch.utils.data import DataLoader, sampler, random_split, Dataset
from tqdm.auto import tqdm
from torchvision.utils import draw_bounding_boxes

import copy
import math
from typing import List, Optional

from torch import nn, Tensor

plt.style.use('ggplot')

In [None]:
# pycocotools supplies the COCO class used to read the dataset.json annotation files.
!pip install pycocotools
from pycocotools.coco import COCO

In [None]:
# Dataset root on the mounted Drive; annotations are read from <root>/Train/dataset.json.
dataset_path = "/content/drive/MyDrive/Custom"
# Load the category table from the training annotations.
coco = COCO(os.path.join(dataset_path, "Train", "dataset.json"))
categories = coco.cats
# NOTE(review): n_classes is later passed to FastRCNNPredictor; torchvision expects that count
# to include the background class — confirm the category table already contains such an entry.
n_classes = len(categories.keys())
# Not the last expression of the cell, so this line does not display anything.
categories
# Directory that receives the loss plots (and, via save_model, the model weights).
OUT_DIR ='/content/drive/MyDrive/Model'

In [None]:
# Category names, in the iteration order of the COCO category table.
classes = [category['name'] for category in categories.values()]
classes

In [None]:
class Averager:
    """Running mean of every value passed to `send` since the last `reset`."""

    def __init__(self):
        # Start from the same cleared state that reset() produces.
        self.reset()

    def reset(self):
        """Forget everything accumulated so far."""
        self.current_total = 0.0
        self.iterations = 0.0

    def send(self, value):
        """Add one observation to the running total."""
        self.current_total += value
        self.iterations += 1

    @property
    def value(self):
        """Mean of the observations, or 0 when nothing has been sent yet."""
        if self.iterations != 0:
            return 1.0 * self.current_total / self.iterations
        return 0

In [None]:
# build the albumentations pipeline used for training or evaluation
def get_transforms(train=False):
    """Return an albumentations pipeline with COCO-format bounding-box handling.

    Every pipeline resizes to 600x600 and ends with ToTensorV2. When `train`
    is truthy, random flips and colour perturbations are inserted between the
    two.
    """
    steps = [A.Resize(600, 600)]  # our input size can be 600px
    if train:
        steps += [
            A.HorizontalFlip(p=0.3),
            A.VerticalFlip(p=0.3),
            A.RandomBrightnessContrast(p=0.1),
            A.ColorJitter(p=0.1),
        ]
    steps.append(ToTensorV2())
    return A.Compose(steps, bbox_params=A.BboxParams(format='coco'))

def get_valid_transform():
    """Return the evaluation pipeline: resize to 600x600, then ToTensorV2, COCO boxes.

    This is exactly the non-training branch of `get_transforms`, so it
    delegates to it instead of keeping a second copy of the pipeline that
    could drift out of sync.
    """
    return get_transforms(train=False)

In [None]:
# the dataset class
class CustomDataset(torchvision.datasets.VisionDataset):
    """COCO-format detection dataset read from ``<root>/<split>/dataset.json``.

    Images that have no annotation are dropped when the dataset is built.
    Each item is ``(image, target)``: ``image`` is the (transformed) image
    divided by 255, and ``target`` is a dict with ``boxes`` (xyxy),
    ``labels``, ``image_id``, ``area`` and ``iscrowd`` tensors.
    """

    def __init__(self, root, split='', transform=None, target_transform=None, transforms=None):
        # the 3 transform parameters are required for datasets.VisionDataset
        super().__init__(root, transforms, transform, target_transform)
        self.split = split  # sub-directory of root holding this split's images and dataset.json
        self.coco = COCO(os.path.join(root, split, "dataset.json"))  # annotations stored here
        # keep only images that carry at least one annotation
        self.ids = [
            img_id for img_id in sorted(self.coco.imgs.keys())
            if len(self._load_target(img_id)) > 0
        ]

    def _load_image(self, id: int):
        """Read the image with COCO id ``id`` from disk and return it as an RGB array.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        file_name = self.coco.loadImgs(id)[0]['file_name']
        path = os.path.join(self.root, self.split, file_name)
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def _load_target(self, id):
        """Return the list of COCO annotation dicts attached to image ``id``."""
        return self.coco.loadAnns(self.coco.getAnnIds(id))

    def __getitem__(self, index):
        img_id = self.ids[index]
        image = self._load_image(img_id)
        # deep copy so nothing below can mutate the annotations cached inside self.coco
        target = copy.deepcopy(self._load_target(img_id))

        # albumentations carries extra trailing elements of a bbox through the transform
        boxes = [t['bbox'] + [t['category_id']] for t in target]
        if self.transforms is not None:
            transformed = self.transforms(image=image, bboxes=boxes)
            image = transformed['image']
            boxes = transformed['bboxes']
        else:
            # no pipeline: still hand back a channels-first tensor like ToTensorV2 would
            image = torch.from_numpy(image).permute(2, 0, 1)

        # convert from xywh to xyxy
        new_boxes = [[box[0], box[1], box[0] + box[2], box[1] + box[3]] for box in boxes]
        # reshape keeps a (0, 4) shape when there are no boxes so the column slicing below works
        boxes = torch.tensor(new_boxes, dtype=torch.float32).reshape(-1, 4)

        # NOTE(review): labels / image_id / iscrowd are built from the untransformed
        # annotations; if a transform ever drops a box they stop lining up with `boxes` — confirm
        # the pipeline never filters boxes.
        targ = {}  # here is our transformed target
        targ['boxes'] = boxes
        targ['labels'] = torch.tensor([t['category_id'] for t in target], dtype=torch.int64)
        targ['image_id'] = torch.tensor([t['image_id'] for t in target])
        # area is recomputed because resizing changes it relative to the annotation file
        targ['area'] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        targ['iscrowd'] = torch.tensor([t['iscrowd'] for t in target], dtype=torch.int64)
        return image.div(255), targ  # scale images

    def __len__(self):
        return len(self.ids)

In [None]:
def collate_fn(batch):
    """Regroup a batch of samples by field instead of by sample.

    A list of ``(image, target)`` pairs becomes ``(images, targets)``, each a
    tuple; nothing is stacked.
    """
    per_field = zip(*batch)
    return tuple(per_field)

In [None]:
# The "Train" split gets the augmenting pipeline; the "Test" split gets the deterministic one.
train_dataset = CustomDataset(root=dataset_path, split="Train", transforms=get_transforms(True))
valid_dataset = CustomDataset(root=dataset_path, split="Test", transforms=get_valid_transform())

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4, collate_fn=collate_fn)
# Shuffling adds nothing to validation; a fixed order keeps the per-batch loss curve comparable across epochs.
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False, num_workers=2, collate_fn=collate_fn)

In [None]:
# Let's view a sample
sample_index = min(450, len(train_dataset) - 1)  # stay in range when the dataset is small
sample = train_dataset[sample_index]
# the dataset returns images scaled by 1/255; draw_bounding_boxes is given a uint8 image
img_int = (sample[0] * 255).to(torch.uint8)
plt.figure(figsize=(12,12))
plt.axis('off')
plt.imshow(draw_bounding_boxes(
    img_int, sample[1]['boxes'], [classes[i] for i in sample[1]['labels'].tolist()], width=4
).permute(1, 2, 0))

len(train_dataset)

In [None]:
# Faster R-CNN with a ResNet-50 FPN backbone, requested with pretrained weights.
# NOTE(review): newer torchvision releases prefer a `weights=` argument over `pretrained=` — confirm against the installed version.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
        pretrained=True)

# Replace the box-classification head with one that has n_classes outputs.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, n_classes)

In [None]:
# Smoke test: feed one training batch (images plus targets) through the model and keep the result.
images, targets = next(iter(train_loader))
images = list(images)
targets = [dict(target) for target in targets]
output = model(images, targets)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [None]:
# Hand the optimizer only the parameters that still require gradients.
params = list(filter(lambda param: param.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.01,
    momentum=0.9,
    nesterov=True,
    weight_decay=1e-4,
)

In [None]:
def save_model(epoch, model, optimizer, out_dir='/content/drive/MyDrive/Model'):
    """Write the model's ``state_dict`` to ``<out_dir>/last_model.pth``.

    Args:
        epoch: accepted for call-site compatibility; not stored.
        model: module whose ``state_dict()`` is saved.
        optimizer: accepted for call-site compatibility; not stored.
        out_dir: target directory, created if missing. The default keeps the
            file at the location it was always written to.
    """
    # torch.save does not create missing parent directories
    os.makedirs(out_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(out_dir, 'last_model.pth'))

    
def save_loss_plot(OUT_DIR, train_loss, val_loss):
    """Plot the training and validation loss series and save each as a PNG.

    Two separate figures are written: ``<OUT_DIR>/train_loss.png`` and
    ``<OUT_DIR>/valid_loss.png``. All open matplotlib figures are closed
    afterwards.
    """
    train_fig, train_axes = plt.subplots()
    valid_fig, valid_axes = plt.subplots()

    panels = (
        (train_axes, train_loss, 'tab:blue', 'train loss'),
        (valid_axes, val_loss, 'tab:red', 'validation loss'),
    )
    for axes, series, colour, y_label in panels:
        axes.plot(series, color=colour)
        axes.set_xlabel('iterations')
        axes.set_ylabel(y_label)

    train_fig.savefig(f"{OUT_DIR}/train_loss.png")
    valid_fig.savefig(f"{OUT_DIR}/valid_loss.png")

    print('SAVING PLOTS COMPLETE...')
    plt.close('all')

In [None]:
def train(train_data_loader, model):
    """Run one training epoch and return the cumulative list of per-batch losses.

    Uses the notebook globals ``optimizer``, ``device`` and
    ``train_loss_hist``; appends to ``train_loss_list`` and advances
    ``train_itr``.

    Args:
        train_data_loader: iterable yielding ``(images, targets)`` batches.
        model: detection model that returns a dict of losses in training mode.

    Returns:
        The global ``train_loss_list`` after this epoch's losses were appended.
    """
    print('Training')
    global train_itr
    global train_loss_list

    # torchvision detection models return the loss dict only in training mode,
    # so do not depend on whatever mode an earlier cell or call left behind
    model.train()

    prog_bar = tqdm(train_data_loader, total=len(train_data_loader))

    for images, targets in prog_bar:
        optimizer.zero_grad()

        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()
        train_loss_list.append(loss_value)
        train_loss_hist.send(loss_value)

        losses.backward()
        optimizer.step()
        train_itr += 1

        prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
    return train_loss_list

In [None]:
def validate(valid_loader, model):
    """Compute the loss on every validation batch without updating the model.

    Uses the notebook globals ``device`` and ``val_loss_hist``; appends to
    ``val_loss_list`` and advances ``val_itr``.

    Args:
        valid_loader: iterable yielding ``(images, targets)`` batches.
        model: detection model that returns a dict of losses in training mode.

    Returns:
        The global ``val_loss_list`` after this pass's losses were appended.
    """
    print('Validating')
    global val_itr
    global val_loss_list

    # The loss dict is only produced in training mode; gradients stay disabled
    # below, and the caller's mode is put back once the pass is over.
    # NOTE(review): layers that behave differently in training mode (dropout, unfrozen
    # batch norm) would still be active here — confirm the model has none that matter.
    was_training = model.training
    model.train()

    prog_bar = tqdm(valid_loader, total=len(valid_loader))

    try:
        for images, targets in prog_bar:
            images = [image.to(device) for image in images]
            # the values are already tensors, so move them instead of re-wrapping them
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            with torch.no_grad():
                loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            loss_value = losses.item()
            val_loss_list.append(loss_value)
            val_loss_hist.send(loss_value)
            val_itr += 1

            prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
    finally:
        model.train(was_training)
    return val_loss_list

In [None]:
def check_accuracy(loader, model, iou_threshold=0.5, score_threshold=0.5):
    """Return the percentage of ground-truth boxes the detector finds, as a string.

    A ground-truth box counts as found when at least one prediction with a
    score of ``score_threshold`` or more has the same label and an IoU of
    ``iou_threshold`` or more with it. One prediction may satisfy several
    ground-truth boxes; this is a simple detection rate, not mAP.

    Uses the notebook global ``device``. The model is switched to eval mode
    for the pass and its previous mode is restored afterwards.

    Args:
        loader: iterable yielding ``(images, targets)`` batches; each target
            provides ``boxes`` (xyxy) and ``labels``.
        model: detection model returning one dict per image with ``boxes``,
            ``labels`` and ``scores`` in eval mode.
        iou_threshold: minimum IoU for a prediction to match a ground-truth box.
        score_threshold: predictions scoring below this are ignored.

    Returns:
        The percentage formatted with three decimals, ``'0.000'`` when the
        loader holds no ground-truth boxes.
    """
    num_correct = 0
    num_samples = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, targets in loader:
                images = [image.to(device) for image in images]
                predictions = model(images)

                for prediction, target in zip(predictions, targets):
                    gt_boxes = target['boxes'].to(device)
                    gt_labels = target['labels'].to(device)
                    num_samples += len(gt_boxes)

                    confident = prediction['scores'] >= score_threshold
                    pred_boxes = prediction['boxes'][confident]
                    pred_labels = prediction['labels'][confident]
                    if len(gt_boxes) == 0 or len(pred_boxes) == 0:
                        continue

                    # rows are ground-truth boxes, columns are predictions
                    iou = torchvision.ops.box_iou(gt_boxes, pred_boxes)
                    same_label = gt_labels[:, None] == pred_labels[None, :]
                    found = ((iou >= iou_threshold) & same_label).any(dim=1)
                    num_correct += int(found.sum())
    finally:
        # leave the model in the mode the caller had it in
        model.train(was_training)

    if num_samples == 0:
        return f'{0.0:.3f}'
    return f'{float(num_correct) / float(num_samples) * 100:.3f}'


In [None]:
num_epochs = 100

# per-epoch running means; both are reset at the start of every epoch
train_loss_hist = Averager()
val_loss_hist = Averager()

# batch counters advanced inside train() / validate()
train_itr = 1
val_itr = 1

# per-batch losses accumulated over all epochs; these are what gets plotted
train_loss_list = []
val_loss_list = []

# the loss plots are saved into OUT_DIR, which must exist beforehand
os.makedirs(OUT_DIR, exist_ok=True)

for epoch in range(num_epochs):

    train_loss_hist.reset()
    val_loss_hist.reset()

    train_loss = train(train_loader, model)
    val_loss = validate(valid_loader, model)

    # measured after this epoch's updates so the number printed for epoch N
    # describes the weights produced by epoch N
    accuracy = check_accuracy(train_loader, model)

    print(f"Epoch #{epoch+1} train loss: {train_loss_hist.value:.3f}")
    print(f"Epoch #{epoch+1} Accuracy: {accuracy}")
    print(f"Epoch #{epoch+1} validation loss: {val_loss_hist.value:.3f}")

    save_model(epoch, model, optimizer)
    save_loss_plot(OUT_DIR, train_loss, val_loss)


Testing with user-uploaded images

In [None]:
# Open Colab's file-upload dialog so test images can be supplied from the local machine.
from google.colab import files
files.upload()

In [None]:
# Load the weights from the path save_model writes to, mapping tensors onto the active device.
model.load_state_dict(torch.load('/content/drive/MyDrive/Model/last_model.pth', map_location=device))
# Switch to inference mode.
model.eval()

# Every .jpg directly under DIR_TEST is treated as a test image.
DIR_TEST = '/content'
test_images = glob.glob(f"{DIR_TEST}/*.jpg")
print(f"Test instances: {len(test_images)}")
# Boxes scoring below this value are filtered out by the inference loop.
detection_threshold = 0.8
model = model.to(device)

In [None]:


# Run the detector on every test image, draw the confident detections, show and save the result.
# cv2.imwrite returns False without raising when the target directory is missing, so create it first.
os.makedirs("outputs", exist_ok=True)

for image_path in test_images:
    # file name without directory or extension, used for the saved output
    image_name = os.path.splitext(os.path.basename(image_path))[0]
    orig_image = cv2.imread(image_path)
    if orig_image is None:
        # cv2.imread returns None for files it cannot decode
        print(f"Skipping unreadable image: {image_path}")
        continue

    # BGR to RGB, with the pixel range brought to [0, 1]
    image = cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    # bring color channels to front
    image = np.transpose(image, (2, 0, 1))
    # convert to a tensor on the same device as the model and add the batch dimension
    image = torch.tensor(image, dtype=torch.float, device=device).unsqueeze(0)
    with torch.no_grad():
        outputs = model(image)

    # load all detections to CPU for further operations
    outputs = [{k: v.to('cpu') for k, v in t.items()} for t in outputs]
    # carry further only if there are detected boxes
    if len(outputs[0]['boxes']) != 0:
        scores = outputs[0]['scores'].numpy()
        # one mask for boxes and labels so each box is paired with its own class
        keep = scores >= detection_threshold
        boxes = outputs[0]['boxes'].numpy()[keep].astype(np.int32)
        pred_classes = [classes[label] for label in outputs[0]['labels'].numpy()[keep]]

        # draw the bounding boxes and write the class name on top of each
        for box, class_name in zip(boxes, pred_classes):
            cv2.rectangle(orig_image,
                          (int(box[0]), int(box[1])),
                          (int(box[2]), int(box[3])),
                          (0, 0, 255), 2)
            cv2.putText(orig_image, class_name,
                        (int(box[0]), int(box[1] - 5)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0),
                        2, lineType=cv2.LINE_AA)

        plt.figure(figsize=(12,12))
        # orig_image is BGR (OpenCV order); matplotlib expects RGB
        plt.imshow(cv2.cvtColor(orig_image, cv2.COLOR_BGR2RGB))
        plt.axis('off')
        plt.show()

        cv2.imwrite(f"outputs/{image_name}.jpg", orig_image)