# SSD

In this notebook we are going to train an SSD (Single Shot MultiBox Detector) model to detect airplanes in image patches.

In [None]:
import pandas as pd
import numpy as np
import random 
import matplotlib

import warnings
warnings.filterwarnings('ignore')

from utils import *

%matplotlib inline

## Load data

First, we load the data we generated before.

In [None]:
PATH = './dataset/patches_256_100'

# annotation tables for the train / eval patches
df_t = pd.read_csv('{}/annotations_train.csv'.format(PATH))
df_v = pd.read_csv('{}/annotations_eval.csv'.format(PATH))

# the CSV stores the boxes as strings: parse them into lists of (bbox, label)
for frame in (df_t, df_v):
    frame.annotations = anns_str2int(frame.annotations.values)

In [None]:
# add path to image name for simplicity.
# Names that already start with PATH are left untouched, so re-running this
# cell does not prefix the directory a second time.
img_prefix = '{}/'.format(PATH)
df_t.img_name = [name if name.startswith(img_prefix) else img_prefix + name for name in df_t.img_name.values]
df_v.img_name = [name if name.startswith(img_prefix) else img_prefix + name for name in df_v.img_name.values]

In [None]:
# number of patches in each split
for caption, frame in (("Training patches: ", df_t), ("Validation patches: ", df_v)):
    print(caption, len(frame))

## Dataset

Now we define our Dataset, which will define how images and labels are passed to the network.

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset 

class MyDataset(Dataset):
    """Image patches with their (bbox, label) annotations.

    Args:
        images: sequence of image paths, read with `open_image`.
        annotations: one list of `(bbox, label)` pairs per image.
        transforms: optional callable invoked as
            `transforms(image=..., bboxes=..., labels=...)` and returning a
            dict with the same keys (e.g. an albumentations `Compose`).
    """

    def __init__(self, images, annotations, transforms=None):
        self.images, self.annotations, self.transforms = images, annotations, transforms

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, ix):
        """Return `(image, anns)` for item `ix`.

        `image` is a float32 CHW tensor divided by 255; `anns` is a list of
        `(bb_norm_xyxy(bbox, (H, W)), label)` pairs computed after the
        optional transforms.
        """
        img = open_image(self.images[ix])
        raw = self.annotations[ix]
        bboxes = [a[0] for a in raw]
        labels = [a[1] for a in raw]

        if self.transforms:
            out = self.transforms(image=img, bboxes=bboxes, labels=labels)
            img, bboxes, labels = out['image'], out['bboxes'], out['labels']

        hw = img.shape[:2]
        normed = [bb_norm_xyxy(bb, hw) for bb in bboxes]
        chw = img.transpose((2, 0, 1)).astype(np.float32) / 255
        return torch.from_numpy(chw), list(zip(normed, labels))

    def collate_fn(self, batch):
        """Stack images; keep boxes and labels as per-image lists of tensors.

        Boxes of an image become a FloatTensor of shape (O, 4) and labels a
        FloatTensor of shape (O,); an image without annotations yields two
        empty tensors of shape (0,).
        """
        stacked = torch.stack([sample[0] for sample in batch])
        bbs = [torch.FloatTensor([ann[0] for ann in sample[1]]) for sample in batch]
        labels = [torch.FloatTensor([ann[1] for ann in sample[1]]) for sample in batch]
        return stacked, (bbs, labels)

We use data augmentation on the training set only; the bounding boxes are transformed together with the images.

In [None]:
from albumentations import (
    Compose, Resize, HorizontalFlip, VerticalFlip, Transpose, RandomRotate90, HueSaturationValue, RandomBrightness, GaussNoise
)

def get_aug(aug, min_area=0., min_visibility=0.):
    """Compose albumentations transforms so 'coco'-format boxes follow the image.

    Box labels are carried in the `labels` field; `min_area` and
    `min_visibility` are forwarded to the bbox filtering parameters.
    """
    bbox_params = {
        'format': 'coco',
        'min_area': min_area,
        'min_visibility': min_visibility,
        'label_fields': ['labels'],
    }
    return Compose(aug, bbox_params=bbox_params)

# training-time augmentation: flips, transposes, 90-degree rotations,
# colour/brightness jitter and Gaussian noise
train_augs = [
    HorizontalFlip(),
    VerticalFlip(),
    Transpose(),
    RandomRotate90(),
    HueSaturationValue(),
    RandomBrightness(),
    GaussNoise(),
]
trans = get_aug(train_augs)

# only the training split is augmented
dataset = {
    split: MyDataset(frame.img_name.values, frame.annotations.values, tfm)
    for split, frame, tfm in (('train', df_t, trans), ('val', df_v, None))
}

In [None]:
# visualize random (augmented) training images with their ground-truth boxes.
# Set `show_ix` to a dataset index to instead see 12 augmentations of one image.
show_ix = None
ds = dataset['train']
fig, axs = plt.subplots(3, 4, figsize=(15,10))
for ax in axs.flat:
    ix = show_ix if show_ix is not None else random.randint(0, len(ds)-1)
    img, anns = ds[ix]
    # bring back image from tensor (CHW -> HWC)
    img = img.numpy().transpose((1, 2, 0))
    ax = show_image(img, ax=ax)
    # unnorm box
    for bb, label in anns:
        box = bb_unnorm_xywh(bb, img.shape[:2])
        draw_rect(ax, box, 'green')

## Model

Here we define our model. We use a pretrained ResNet34 as the backbone network and add only the SSD-specific layers (one extra convolution and three prediction heads) to adapt it to our problem.



In [None]:
# ssd model

import torchvision

def flatten_conv(x, k):
    """Reshape a conv head output (B, C, ...) into (B, N, C // k).

    `k` is the number of anchors per grid cell; the result has one row of
    C // k values per anchor position.
    """
    batch, channels = x.size(0), x.size(1)
    return x.view(batch, channels // k, -1).permute(0, 2, 1)

class out_conv(nn.Module):
    """SSD prediction head for one feature map.

    Two 3x3 convolutions predict, for each of the `k` anchors per cell,
    4 box values and `n_classes` class logits; both outputs are flattened
    with `flatten_conv`.
    """
    def __init__(self, c_in, k, n_classes):
        super().__init__()
        self.k = k
        self.oconv1 = nn.Conv2d(c_in, 4 * k, 3, padding=1)
        self.oconv2 = nn.Conv2d(c_in, n_classes * k, 3, padding=1)

    def forward(self, x):
        boxes = flatten_conv(self.oconv1(x), self.k)
        logits = flatten_conv(self.oconv2(x), self.k)
        return [boxes, logits]

def conv(c_i, c_o, stride=2, padding=1):
    """3x3 Conv2d -> ReLU -> BatchNorm2d block (stride 2 halves H and W by default)."""
    layers = [
        nn.Conv2d(c_i, c_o, kernel_size=3, stride=stride, padding=padding),
        nn.ReLU(),
        nn.BatchNorm2d(c_o),
    ]
    return nn.Sequential(*layers)

class Net(nn.Module):
    """SSD detector on a ResNet34 backbone with three prediction heads.

    The heads read, in order, the layer3 output, the layer4 output and an
    extra stride-2 `conv` block applied on top of layer4. Box and class
    outputs are concatenated along dim 1 in that same order, which must
    match the anchor order returned by `get_anchors`.
    """

    def __init__(self, n_classes=2):
        super().__init__()

        # ImageNet-pretrained ResNet34 used as the feature extractor
        self.model = torchvision.models.resnet34(pretrained=True)

        # anchors per grid cell for each of the three heads
        self.k = [1, 1, 1]
        # extra down-sampling stage on top of layer4 (512 -> 256 channels)
        self.conv1 = conv(512, 256)
        head_channels = (256, 512, 256)
        self.out1, self.out2, self.out3 = [
            out_conv(c_in, n_anchors, n_classes)
            for c_in, n_anchors in zip(head_channels, self.k)
        ]

        self.anchors, self.grid_size = self.get_anchors()

    def forward(self, x):
        """Return `[box_outputs, class_logits]`, each concatenated over the heads."""
        backbone = self.model
        feat = backbone.conv1(x)
        feat = backbone.maxpool(backbone.relu(backbone.bn1(feat)))
        for stage in (backbone.layer1, backbone.layer2, backbone.layer3):
            feat = stage(feat)
        deep = backbone.layer4(feat)
        extra = self.conv1(deep)

        locs, logits = zip(self.out1(feat), self.out2(deep), self.out3(extra))
        return [torch.cat(locs, dim=1), torch.cat(logits, dim=1)]

    def get_anchors(self):
        """Build anchors and their grid sizes with the project's `generate_anchors`."""
        scales = [16, 8, 4]  # presumably cells per side of each head's grid — confirm in utils
        centers = [(0.5, 0.5)]
        size_scales = [2 ** 0]  # other scales tried: 2**(1/3), 2**(2/3)
        aspect_ratios = [(1., 1.)]  # other ratios tried: (2., 1.), (1., 2.)
        sizes = [(s * w, s * h) for s in size_scales for w, h in aspect_ratios]
        _, anchors, grid_size = generate_anchors(scales, centers, sizes)
        return anchors, grid_size

In [None]:
# test net: check the output shapes on a random batch

cats = ['airplane', 'bg']

net = Net(n_classes = len(cats))

test_input = torch.randn((16, 3, 256, 256))
# A train-mode forward pass would update the pretrained BatchNorm running
# statistics with statistics of random noise, so run the check in eval mode
# without building a graph, then switch back to training mode.
net.eval()
with torch.no_grad():
    pred_bbs, pred_labs = net(test_input)
net.train()

print(pred_bbs.shape) # should output BATCH_SIZE x NUM_ANCHORS x 4
print(pred_labs.shape) # should output BATCH_SIZE x NUM_ANCHORS x NUM_CLASSES

We can visualize the anchors that will be matched to the ground truth during training (ground truth in green, matched anchors in red).

In [None]:
# anchors: draw the ground truth (green) and the anchors matched to it (red)
anchors = net.anchors
bg = len(cats) - 1
threshold = 0.5
fig, axs = plt.subplots(3, 5, figsize=(15,10))
for ax in axs.flat:
    ix = random.randint(0, len(ds)-1)
    img, anns = ds[ix]
    bboxes = torch.tensor([ann[0] for ann in anns]).float()
    _labels = [ann[1] for ann in anns]

    img = img.numpy().transpose((1,2,0))
    ax = show_image(img, ax=ax)

    if bboxes.shape[0] == 0:
        continue

    # compute IoU between gt and anchors
    overlaps = torchvision.ops.box_iou(bboxes, anchors)

    # keep best match and all above a threshold
    gt_overlap, gt_idx = map_to_ground_truth(overlaps)

    # class of the matched ground-truth box, or background below the threshold
    labels = [_labels[int(idx)] if iou > threshold else bg
              for idx, iou in zip(gt_idx, gt_overlap)]

    # draw boxes
    for bb in bboxes:
        draw_rect(ax, bb_unnorm_xywh(bb, img.shape[:2]), 'green')

    # draw matching anchors (compare values with !=, not identity with `is not`)
    for a, lab in zip(anchors, labels):
        if lab != bg:
            draw_rect(ax, bb_unnorm_xywh(a, img.shape[:2]), edgecolor="red")
plt.show()

## Training

To train the network we wrap each dataset in a DataLoader, which feeds the network with batches of images.

In [None]:
from torch.utils.data import DataLoader

# both splits: batches of 32, 4 workers, the dataset's own collate_fn;
# only the training batches are shuffled
dataloader = {
    split: DataLoader(
        dataset[split],
        batch_size=32,
        shuffle=(split == 'train'),
        num_workers=4,
        collate_fn=dataset[split].collate_fn,
    )
    for split in ('train', 'val')
}

In [None]:
# inspect one collated training batch: image tensor + per-image box/label lists
imgs, anns = next(iter(dataloader['train']))
bbs, labs = anns
print(imgs.shape, len(bbs), len(labs))
# show only the first two images' boxes instead of dumping the whole batch
bbs[:2]

Now we define the optimizer and loss function to train the network.

In [None]:
# use the first GPU when available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device) # should output cuda:0

In [None]:
import torch.nn.functional as F

def one_hot_embedding(labels, num_classes):
    """Return a float32 one-hot matrix on the CPU for integer `labels` (LongTensor)."""
    return torch.eye(num_classes)[labels.detach().cpu()]

class BCE_Loss(nn.Module):
    """Sum-reduced binary cross-entropy over the non-background classes.

    The last class index is the background: targets are one-hot encoded and
    the background column is dropped (from both targets and logits), so a
    background anchor has an all-zero target row. The summed loss is
    divided by `num_classes`.
    """
    def __init__(self, num_classes, device):
        super().__init__()
        self.num_classes = num_classes
        self.device = device

    def forward(self, pred, targ):
        """
        Args:
            pred: class logits, one row per anchor (as passed by SSD_Loss).
            targ: LongTensor with one class index per row of `pred`.
        Returns:
            Scalar loss tensor.
        """
        t = one_hot_embedding(targ, self.num_classes)[:, :-1].contiguous().to(self.device)
        x = pred[:, :-1]
        w = self.get_weight(x, t)
        # reduction='sum' replaces the deprecated size_average=False
        return F.binary_cross_entropy_with_logits(x, t, weight=w, reduction='sum') / self.num_classes

    def get_weight(self, x, t):
        """Per-element BCE weights; None means unweighted."""
        return None

class FocalLoss(BCE_Loss):
    """BCE re-weighted by the focal term alpha_t * (1 - p_t) ** gamma."""
    def get_weight(self, x, t):
        alpha, gamma = 0.25, 2.
        p = x.sigmoid()
        pt = p * t + (1 - p) * (1 - t)
        w = alpha * t + (1 - alpha) * (1 - t)
        # the weights are constants for backprop; detach() does this without the
        # copy-and-warning of torch.tensor(tensor)
        return (w * (1 - pt).pow(gamma)).detach()
    
class SSD_Loss(nn.Module):
    """SSD loss: L1 box regression on matched anchors plus focal classification.

    Args:
        num_classes: number of classes including background; the last index
            is the background label.
        anchors: anchor boxes, one row per anchor, in the same order as the
            network outputs; compared with targets via `torchvision.ops.box_iou`.
        grid_size: passed to `actn_to_bb` together with `anchors`.
        threshold: minimum IoU for an anchor to be matched to a ground-truth box.

    Uses the notebook-level `device`, captured once at construction.
    """

    def __init__(self, num_classes, anchors, grid_size, threshold=0.5):
        super().__init__()
        self.device = device
        self.loss_f = FocalLoss(num_classes, self.device)
        self.anchors = anchors.to(self.device)
        self.grid_size = grid_size.to(self.device)
        self.num_classes = num_classes
        self.threshold = threshold

    def forward(self, preds, target):
        """
        Args:
            preds: `[box_outputs, class_logits]` from the network, batched.
            target: `(boxes, labels)` lists with one tensor per image, as
                produced by `MyDataset.collate_fn`.
        Returns:
            `(total, loc_loss, clas_loss)` scalar tensors summed over the batch.
        """
        pred_bbs, pred_cs = preds
        tar_bbs, c_t = target
        bg = self.num_classes - 1
        loc_loss = torch.zeros((), device=self.device)
        clas_loss = torch.zeros((), device=self.device)
        for pred_bb, pred_c, tar_bb, tar_c in zip(pred_bbs, pred_cs, tar_bbs, c_t):
            # every anchor starts as background; kept on the same device as the
            # match indices so indexing never mixes CPU and CUDA tensors
            labels = torch.full((len(self.anchors),), bg, dtype=torch.long, device=self.device)
            if tar_bb.shape[0] != 0:  # some images may have no detections
                tar_bb = tar_bb.to(self.device)
                tar_c = tar_c.to(self.device)
                overlaps = torchvision.ops.box_iou(tar_bb, self.anchors)
                gt_overlap, gt_idx = map_to_ground_truth(overlaps)
                # ids of anchors to match
                pos_idx = torch.nonzero(gt_overlap > self.threshold)[:, 0]
                # mean() of an empty selection would be NaN and poison the loss
                if pos_idx.numel() > 0:
                    # ids of targets to match
                    tar_idx = gt_idx[pos_idx]
                    pred_bb = actn_to_bb(pred_bb, self.anchors, self.grid_size)
                    loc_loss = loc_loss + (pred_bb[pos_idx] - tar_bb[tar_idx]).abs().mean()
                    labels[pos_idx] = tar_c[tar_idx].long()
            clas_loss = clas_loss + self.loss_f(pred_c, labels)
        return clas_loss + loc_loss, loc_loss, clas_loss

In [None]:
# Adam on all network parameters; SSD loss with IoU matching threshold 0.5
optimizer = torch.optim.Adam(params=net.parameters(), lr=3e-4)
criterion = SSD_Loss(
    num_classes=len(cats),
    anchors=net.anchors,
    grid_size=net.grid_size,
    threshold=0.5,
)

In [None]:
def metric(preds, target, threshold=0.5):
    """Mean per-image F1 of a batch of raw SSD outputs.

    Box outputs are decoded with `actn_to_bb` against the global `net`'s
    anchors, each anchor takes its arg-max class, and both predictions and
    targets labelled as background (last entry of `cats`) are dropped before
    calling `F1`. No NMS is applied. `threshold` is currently unused.
    """
    bg_class = len(cats) - 1
    anchors = net.anchors.to(device)
    grid_size = net.grid_size.to(device)

    loc_out, cls_out = preds
    boxes = torch.stack([actn_to_bb(out, anchors, grid_size) for out in loc_out])
    classes = cls_out.argmax(dim=2)
    tar_bbs, tar_cs = target

    scores = []
    for p_bb, p_c, t_bb, t_c in zip(boxes, classes, tar_bbs, tar_cs):
        p_keep = (p_c != bg_class).nonzero().view(-1)
        t_keep = (t_c != bg_class).nonzero().view(-1)
        scores.append(F1(p_bb[p_keep], p_c[p_keep], t_bb[t_keep], t_c[t_keep]))
    return np.array(scores).mean()

Finally, we can proceed with the training.

In [None]:
def train(model, dataloader, criterion, optimizer):
    """Run one training epoch.

    Returns:
        Mean (total, localization, classification) loss over the batches.
    """
    print('Training ...')
    model.train()
    keys = ('total', 'loc', 'cls')
    history = {key: [] for key in keys}
    for imgs, anns in tqdm(dataloader, ascii=True):
        outputs = model(imgs.to(device))
        loss, loc_loss, cls_loss = criterion(outputs, anns)
        for key, value in zip(keys, (loss, loc_loss, cls_loss)):
            history[key].append(value.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return tuple(np.mean(history[key]) for key in keys)

def test(model, dataloader, criterion, metric):
    """Evaluate without gradients.

    Returns:
        Mean (total, localization, classification) loss and mean `metric`
        value over the batches.
    """
    print('Evaluating ...')
    model.eval()
    keys = ('total', 'loc', 'cls', 'metric')
    history = {key: [] for key in keys}
    with torch.no_grad():
        for imgs, anns in tqdm(dataloader, ascii=True):
            outputs = model(imgs.to(device))
            loss, loc_loss, cls_loss = criterion(outputs, anns)
            history['total'].append(loss.item())
            history['loc'].append(loc_loss.item())
            history['cls'].append(cls_loss.item())
            history['metric'].append(metric(outputs, anns))
    return tuple(np.mean(history[key]) for key in keys)

In [None]:
# training: keeps per-epoch total/loc/cls losses and the validation metric,
# saving the weights whenever the validation metric improves
net.to(device)
EPOCHS = 200
train_loss, train_loc_loss, train_cls_loss = [], [], []
val_loss, val_loc_loss, val_cls_loss, acc, best_acc = [], [], [], [], 0
for epoch in range(EPOCHS):

    print('Epoch: {}/{}'.format(epoch+1, EPOCHS))

    t_loss, t_loc_loss, t_cls_loss = train(net, dataloader['train'], criterion, optimizer)
    train_loss.append(t_loss)
    train_loc_loss.append(t_loc_loss)
    train_cls_loss.append(t_cls_loss)

    v_loss, v_loc_loss, v_cls_loss, v_acc = test(net, dataloader['val'], criterion, metric)
    val_loss.append(v_loss)
    val_loc_loss.append(v_loc_loss)
    val_cls_loss.append(v_cls_loss)
    acc.append(v_acc)

    print('Train Loss: {:.5f} {:.5f} {:.5f}. Val Loss: {:.5f} {:.5f} {:.5f}. Val acc: {:.5f}'.format(
        t_loss, t_loc_loss, t_cls_loss, v_loss, v_loc_loss, v_cls_loss, v_acc)
    )

    # keep best model
    if v_acc > best_acc:
        best_acc = v_acc
        torch.save(net.state_dict(), './state_dict.pth')
        print('Best acc {}, model saved'.format(best_acc))

print('Best acc {}'.format(best_acc))

Visualize the training profile.

In [None]:
matplotlib.rcParams.update({'font.size': 16})
f, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(8, 8))
plt.subplots_adjust(hspace=0.3)
ax1.plot(train_loss, linewidth=3, label='train')
ax1.plot(val_loss, ':', linewidth=3,  label='val')
ax1.set_title("Loss")
ax1.legend(loc='upper right')
ax1.grid()
ax2.plot(acc, linewidth=3, label="max: {:.4f}".format(np.array(acc).max()))
ax2.set_title("Accuracy")
ax2.grid()
# 'bottom right' is not a valid legend location (recent matplotlib raises);
# the intended position is 'lower right'
ax2.legend(loc='lower right', handlelength=0, handletextpad=0, fancybox=True)
ax2.set_xlabel("epoch")
plt.show()

## Test

Load the best model and make some predictions.

In [None]:
# map_location lets a checkpoint saved from the GPU load on a CPU-only machine
net.load_state_dict(torch.load('state_dict.pth', map_location=device))
net.eval();

In [None]:
# visualize random validation patches: ground truth (white) vs predictions (red)

ds = dataset['val']
anchors_dev, grid_size_dev = net.anchors.to(device), net.grid_size.to(device)
bg = len(cats) - 1
fig, axs = plt.subplots(3, 4, figsize=(15,10))
with torch.no_grad():  # inference only: no autograd graph
    for ax in axs.flat:
        ix = random.randint(0, len(ds)-1)
        img, anns = ds[ix]
        # plot image and bbs
        _img = img.numpy().transpose((1,2,0))
        ax = show_image(_img, ax=ax)
        for bb, label in anns:
            draw_rect(ax, bb_unnorm_xywh(bb, _img.shape[:2]), 'white')
        # activate preds: per-anchor confidence + class, and decoded boxes
        pred_bbs, pred_cs = net(img.unsqueeze(0).to(device))
        scores, pred_c = torch.softmax(pred_cs[0], dim=1).max(dim=1)
        pred_bb = actn_to_bb(pred_bbs[0], anchors_dev, grid_size_dev)

        # remove bg
        keep_ids = (pred_c != bg).nonzero().view(-1)
        pred_bb, pred_c, scores = pred_bb[keep_ids], pred_c[keep_ids], scores[keep_ids]
        # nms ranked by confidence (class ids are identical for every kept box,
        # so they cannot rank overlapping detections)
        if keep_ids.numel() > 0:
            nms_ixs = torchvision.ops.nms(pred_bb, scores, iou_threshold=0.5)
            pred_bb, pred_c = pred_bb[nms_ixs], pred_c[nms_ixs]

        # compute F1 on the same filtered boxes that are drawn
        bbs = torch.FloatTensor([ann[0] for ann in anns])
        labels = torch.tensor([ann[1] for ann in anns])
        ax.set_title('F1: {:.3f}'.format(F1(pred_bb, pred_c, bbs.to(device), labels)))

        # plot preds
        for bb in pred_bb:
            draw_rect(ax, bb_unnorm_xywh(bb, img.shape[1:]), edgecolor="red")

In [None]:
# detect on full image: choose an eval mosaic and the patch-grid geometry

mosaics_v = pd.read_csv("{}/mosaics_eval.csv".format(PATH))
ix = 2  # mosaic to show; random.randint(0, len(mosaics_v)-1) picks a random one
window, size, stride = 256, 256, 100
ratio = window / size

In [None]:
# plot image

mosaics_v.annotations = anns_str2int(mosaics_v.annotations.values)
mosaic = mosaics_v.loc[ix]
img_ori = open_image("./dataset/{}".format(mosaic.img_name))
anns_ori = mosaic.annotations

mosaic = mosaic.mosaic
mosaic = mosaic.split(',')[:-1]
mosaic = [m.split(' ')[:-1] for m in mosaic]
shape = (len(mosaic), len(mosaic[0]))

_bbs, _labs = [], []
with torch.no_grad():    
    #fig, axs = plt.subplots(2,2, figsize=(20,20))
    fig, axs = plt.subplots(shape[0], shape[1], figsize=(20,20))
    for i, _ax in enumerate(axs):
        for j, ax in enumerate(_ax):
            
            ix = int(mosaic[i][j])
            img, anns = dataset['val'][ix] # make sure is in same order
            
             # plot image and bbs
            _img = img.numpy().transpose((1,2,0))
            ax = show_image(_img, ax=ax)
            for bb, label in anns:
                bb = bb_unnorm_xywh(bb, _img.shape[:2])
                draw_rect(ax, bb, 'white')
            # activate preds
            preds = net(img.unsqueeze(0).to(device))
            pred_bbs, pred_cs = preds
            pred_c = torch.softmax(pred_cs[0], dim=1) 
            val, pred_c = torch.max(pred_c, 1)            
            pred_bb = actn_to_bb(pred_bbs[0], net.anchors.to(device), net.grid_size.to(device))

            bbs = torch.FloatTensor([ann[0] for ann in anns])
            labels = torch.tensor([ann[1] for ann in anns])
            #ax.set_title('F1: {:.3f}'.format(F1(pred_bb, pred_c, bbs.to(device), labels)))

            # plot preds
            keep_ids = (pred_c < len(cats) - 1).nonzero()
            if keep_ids.shape[0] > 0:
                # remove bg
                keep_ids = keep_ids[:,0]
                pred_c = pred_c[keep_ids]
                pred_bb = pred_bb[keep_ids]
                # nms
                nms_ixs = torchvision.ops.nms(pred_bb, pred_c.float(), iou_threshold=0.5)
                pred_bb, pred_c = pred_bb[nms_ixs], pred_c[nms_ixs]
                # compute F1
                for bb in pred_bb:
                    bb = bb_unnorm_xywh(bb, img.shape[1:])
                    draw_rect(ax, bb, edgecolor="red")
                # keep preds
                _bbs += [bb_unnorm_window_xyxy(ratio*bb.cpu(), img.shape[1:], img_ori.shape, j, i, shape[1], shape[0], window, stride) for bb in pred_bb]
                _labs += pred_c.cpu().numpy().tolist()

In [None]:
# 2nd nms
fig, ax = plt.subplots(figsize=(20,20))
ax = show_image(img_ori, ax = ax)

bbs = torch.tensor(_bbs).to(device)
labs = torch.tensor(_labs).float().to(device)
nms_idx = torchvision.ops.nms(bbs, labs, 0.3)

bbs = bbs[nms_idx]
labs = labs[nms_idx]

for bb, label in anns_ori:
    if label == 0: 
        draw_rect(ax, bb)
for bb in bbs:
    bb = xyxy2xywh(bb.cpu().numpy())
    draw_rect(ax, bb, edgecolor="red")
    
bbs_ori = torch.FloatTensor([xywh2xyxy(ann[0]) for ann in anns_ori])
labels_ori = torch.tensor([ann[1] for ann in anns_ori])
ax.set_title('F1: {:.4f}'.format(F1(bbs, labs.long(), bbs_ori, labels_ori)), fontsize=30)
plt.show()