# Utils

In [None]:
# credit to: https://towardsdatascience.com/building-your-own-object-detector-pytorch-vs-tensorflow-and-how-to-even-get-started-1d314691d4ae

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import pycocotools
from PIL import Image, ExifTags
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os
from glob import glob
from skimage import transform

import torch
import torch.utils.data
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

from engine import train_one_epoch, evaluate
import utils
import transforms as T

In [None]:
class KITTIDataset(torch.utils.data.Dataset):
    """Object-detection dataset of JPEG images with KITTI-format label files.

    Every ``*.jpg`` directly inside ``img_dir`` is one sample; its annotations
    are read from ``<label_dir>/<image basename>.txt``.

    Args:
        img_dir: directory containing the ``*.jpg`` images.
        label_dir: directory containing one KITTI ``.txt`` file per image.
        label_map: dict mapping the label string found in the KITTI file to
            the integer class id stored in ``target["labels"]``.
        transforms: optional callable ``(img, target) -> (img, target)``
            applied at the end of ``__getitem__``.
    """

    def __init__(self, img_dir, label_dir, label_map, transforms=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.transforms = transforms
        # glob() yields paths in arbitrary, filesystem-dependent order; sort so
        # that an index always refers to the same image (otherwise a seeded
        # random split is not actually reproducible across runs/machines)
        self.img_paths = sorted(glob(os.path.join(img_dir, "*.jpg")))
        self.label_map = label_map

    def parse_kitti(self, path):
        """Parse a KITTI label file into a list of object dicts.

        Field layout reference:
        https://github.com/NVIDIA/DIGITS/blob/v4.0.0-rc.3/digits/extensions/data/objectDetection/README.md

        Args:
            path: path of the label file.

        Returns:
            A list with one ``{'label': str, 'bounds': [xmin, ymin, xmax, ymax]}``
            dict per non-blank line; the bounds are floats.

        Raises:
            ValueError: if a non-blank line has fewer than 8 fields or a
                bound is not numeric.
        """
        objects = []
        with open(path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # tolerate blank / trailing empty lines
                    continue
                # fields 1-3 and everything after the bounding box are unused here
                label, _, _, _, xmin, ymin, xmax, ymax, *_ = fields
                objects.append({
                    'label': label,
                    'bounds': [float(x) for x in (xmin, ymin, xmax, ymax)],
                })
        return objects

    def _make_target(self, objects, idx):
        """Build the detection target dict for sample ``idx`` from parsed objects."""
        # reshape keeps an (N, 4) layout even when N == 0, so the column
        # indexing used for the area below also works for images without objects
        boxes = torch.tensor([o['bounds'] for o in objects], dtype=torch.float32).reshape(-1, 4)
        labels = torch.tensor([self.label_map[o['label']] for o in objects], dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = torch.tensor([idx])
        # (ymax - ymin) * (xmax - xmin)
        target["area"] = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        target["iscrowd"] = torch.zeros((len(objects),), dtype=torch.int64)
        return target

    def __getitem__(self, idx):
        """Return ``(img, target)`` for the ``idx``-th image.

        ``img`` is the RGB PIL image (or whatever ``transforms`` turns it
        into); ``target`` holds ``boxes``, ``labels``, ``image_id``, ``area``
        and ``iscrowd`` tensors.
        """
        img_path = self.img_paths[idx]
        # the label file shares the image's basename
        img_id, _ = os.path.splitext(os.path.basename(img_path))
        label_path = os.path.join(self.label_dir, img_id + ".txt")

        img = Image.open(img_path).convert("RGB")
        target = self._make_target(self.parse_kitti(label_path), idx)

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Number of images found in ``img_dir``."""
        return len(self.img_paths)

In [None]:
def get_transforms(train):
    """Compose the (image, target) transform pipeline.

    The pipeline always starts with ``T.ToTensor()``; when ``train`` is truthy
    a ``T.RandomHorizontalFlip(0.5)`` augmentation step is appended.
    """
    # the PIL image has to become a tensor before anything else runs
    steps = [T.ToTensor()]
    if train:
        # augmentation (training only): random horizontal flip, p = 0.5
        steps.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(steps)

In [None]:
def get_model(num_classes):
    """Return a Faster R-CNN (ResNet-50 FPN) detector with a fresh box head.

    The network is created with ``pretrained=True`` and its box predictor is
    then swapped for a new ``FastRCNNPredictor`` with ``num_classes`` outputs.
    """
    # start from the pre-trained detector
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

    # the replacement head must accept the same feature size as the old one
    old_head = detector.roi_heads.box_predictor
    feature_dim = old_head.cls_score.in_features

    # replace the pre-trained head with a new one
    detector.roi_heads.box_predictor = FastRCNNPredictor(feature_dim, num_classes)
    return detector

In [None]:
def display_annotation(image, target, label_map, prediction=None, thresh=0, size=None, title=""):
    """Show an image with ground-truth boxes (red) and optional predictions (green).

    Args:
        image: torch tensor; it is permuted with ``(1, 2, 0)`` before being
            passed to ``imshow`` (presumably channels-first input - confirm
            against the transforms in use).
        target: dict with ``'boxes'`` (rows of ``x1, y1, x2, y2``) and
            ``'labels'`` (class ids) tensors.
        label_map: dict mapping class name -> class id (inverted internally).
        prediction: optional dict with ``'boxes'``, ``'scores'`` and
            ``'labels'`` tensors; skipped when falsy.
        thresh: only predictions with a score strictly above this are drawn.
        size: ``figsize`` forwarded to ``plt.subplots``.
        title: axes title; omitted when empty.
    """
    # invert label map: class id -> class name
    id_to_name = {v: k for k, v in label_map.items()}

    label_offset_x = 0
    label_offset_y = -2
    fig, ax = plt.subplots(figsize=size)
    ax.imshow(image.permute(1, 2, 0).cpu().numpy())

    # set the title once, independent of how many boxes there are (previously
    # it was set inside the box loop and therefore missing for images without
    # any ground-truth box)
    if title:
        ax.set_title(title)

    # ground truth; .tolist() yields plain Python numbers, so matplotlib never
    # has to convert tensors itself (which does not work for CUDA tensors)
    gt_boxes = target['boxes'].tolist()
    gt_labels = target['labels'].tolist()
    for (x1, y1, x2, y2), class_id in zip(gt_boxes, gt_labels):
        rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
        ax.text(x1 + label_offset_x, y1 + label_offset_y, f"{id_to_name[class_id]}", color='r')

    # prediction
    if prediction:
        pred_boxes = prediction['boxes'].tolist()
        pred_scores = prediction['scores'].tolist()
        pred_labels = prediction['labels'].tolist()
        for (x1, y1, x2, y2), score, class_id in zip(pred_boxes, pred_scores, pred_labels):
            if score > thresh:
                rect = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1, edgecolor='g', facecolor='none')
                ax.add_patch(rect)
                text = f"{id_to_name[class_id]} {score:.2f}"
                # predicted label goes below the box so it does not overlap
                # the ground-truth label drawn above it
                ax.text(x1 + label_offset_x, y2 - 6 * label_offset_y, text, color='g')
    plt.show()

# Workflow

In [None]:
# root of the scaled images used for training, with images/ and labels/ below it
data_dir = "../KITTI_Test/combined/resized"
image_dir, label_dir = (os.path.join(data_dir, sub) for sub in ("images", "labels"))

# class name -> class id; id 0 is the default background class
label_map = {
    name: class_id
    for class_id, name in enumerate(['background', 'core', 'flake', 'flake_broken', 'tool'])
}

# two views over the same files: the training view adds augmentation,
# the test view does not
ds = KITTIDataset(image_dir, label_dir, label_map, transforms=get_transforms(train=True))
ds_test = KITTIDataset(image_dir, label_dir, label_map, transforms=get_transforms(train=False))

### Data Review

In [None]:
# look at a single annotated sample as a sanity check
example_idx = 88
image, target = ds[example_idx]
display_annotation(image=image, target=target, label_map=label_map, size=(5, 5))

In [None]:
# render every sample, titled with its index and file path
for i in range(len(ds)):
    image, target = ds[i]
    display_annotation(image, target, label_map, title=f"[{i}]: {ds.img_paths[i]}")

### Data Partitioning

In [None]:
# split into train and test
n_test = 10
torch.manual_seed(1)

# This cell rebinds `ds` / `ds_test`. Unwrap any Subset left over from a
# previous execution so that re-running the cell always splits the full
# dataset instead of splitting an already-split one.
ds_full = ds.dataset if isinstance(ds, torch.utils.data.Subset) else ds
ds_test_full = ds_test.dataset if isinstance(ds_test, torch.utils.data.Subset) else ds_test

# indices[:-n_test] would be empty for n_test == 0, so check the range explicitly
assert 0 < n_test < len(ds_full), f"n_test must be in 1..{len(ds_full) - 1}, got {n_test}"

# one permutation shared by both views keeps train and test disjoint
indices = torch.randperm(len(ds_full)).tolist()
ds = torch.utils.data.Subset(ds_full, indices[:-n_test])
ds_test = torch.utils.data.Subset(ds_test_full, indices[-n_test:])

# define data loaders
dl = torch.utils.data.DataLoader(ds, batch_size=4, shuffle=True, num_workers=1, collate_fn=utils.collate_fn)
dl_test = torch.utils.data.DataLoader(ds_test, batch_size=1, shuffle=False, num_workers=1, collate_fn=utils.collate_fn)
print(f"Total {len(indices)} samples, train: {len(ds)}, test: {len(ds_test)}")

# Training

In [None]:
# # check all test images
# for i in range(n_test):
#     image, target = ds_test[i]
#     display_annotation(image, target, label_map)

In [None]:
# use the GPU when one is available, otherwise fall back to the CPU
# (to force the CPU: device = torch.device('cpu'))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device:', device)

In [None]:
# one output class per label_map entry (background included)
model = get_model(len(label_map))
model = model.to(device)

In [None]:
# optimise only the parameters that are trainable
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)
# learning rate scheduler: multiply the learning rate by 0.1 every 3 epochs
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1,
)

In [None]:
epochs = 10
for epoch in range(epochs):
    # one pass over the training loader, logging every 10 iterations
    train_one_epoch(model, optimizer, dl, device, epoch, print_freq=10)
    # advance the learning-rate schedule once per epoch
    lr_scheduler.step()
    # optional per-epoch evaluation on the test loader (disabled):
    # evaluate(model, dl_test, device=device)

### Model Saving

In [None]:
# save the model: only the state_dict (the weights) is written, so loading it
# again requires rebuilding the architecture with get_model() first
torch.save(model.state_dict(), "model.pt")

# Evaluation

In [None]:
# load model: rebuild the architecture, then restore the saved weights.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine also loads on a CPU-only one.
loaded_model = get_model(len(label_map))
loaded_model.load_state_dict(torch.load("model.pt", map_location=device))
loaded_model = loaded_model.to(device)

In [None]:
# position of the image within the test set: 0 to n_test - 1
idx = 0
image, target = ds_test[idx]
reverse_label_map = dict((class_id, name) for name, class_id in label_map.items())

loaded_model.eval()
with torch.no_grad():
    # the model takes a list of images and returns one result dict per image
    pred = loaded_model([image.to(device)])
    display_annotation(image, target, label_map, prediction=pred[0], thresh=0.3, size=(10, 10))
    labels, scores = pred[0]['labels'], pred[0]['scores']
    # list every detection with its score, including those below the display threshold
    for class_id, score in zip(labels.tolist(), scores.tolist()):
        print(reverse_label_map[class_id], "--", score)