In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np
import datetime as dt
from SSDLoss import SSDloss

# Directory holding the SSD checkpoint: training_loop saves DATAPATH + 'ssd.pt',
# and the visualisation cell loads it back.
DATAPATH = './models/'
# Seed the RNGs used by this notebook so training runs are reproducible
# (torch.manual_seed also seeds CUDA devices).
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
# Train on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Training on device {DEVICE}.")
torch.cuda.memory_allocated()

In [1]:
def _batch_loss(model, loss_fn, images, targets):
    """Run ``model`` on one batch (moved to ``DEVICE``) and return ``loss_fn``'s result.

    NOTE(review): ``targets`` is indexed as a mapping holding ``'boxes'`` and
    ``'labels'`` tensors; confirm this matches what the data loaders yield.
    """
    images = images.to(device=DEVICE)
    offs, conf = model(images)
    return loss_fn(offs, conf,
                   targets['boxes'].to(device=DEVICE),
                   targets['labels'].to(device=DEVICE))


def training_loop(n_epochs, optimizer, model, loss_fn, train_loader, valid_loader):
    """Train ``model`` and checkpoint the weights with the lowest validation loss.

    Each epoch runs one optimisation pass over ``train_loader``, then a
    gradient-free evaluation pass over ``valid_loader``. Whenever the mean
    validation loss improves, ``model.state_dict()`` is saved to
    ``DATAPATH + 'ssd.pt'``.

    Args:
        n_epochs: Number of epochs to run (numbered from 1).
        optimizer: Optimizer bound to ``model``'s parameters.
        model: Network whose forward pass returns ``(offsets, confidences)``.
        loss_fn: Callable ``loss_fn(offsets, confidences, boxes, labels)``
            returning a scalar loss tensor.
        train_loader: Iterable of ``(images, targets)`` batches used for training.
        valid_loader: Iterable of ``(images, targets)`` batches used only
            for evaluation; they never update the weights.
    """
    import os

    checkpoint_path = DATAPATH + 'ssd.pt'
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(checkpoint_path) or '.', exist_ok=True)
    best_valid_loss = float('inf')
    print('{} starting training'.format(dt.datetime.now()))
    for epoch in range(1, n_epochs + 1):
        # Training pass: training-mode layers (e.g. dropout/batch-norm, if any), gradients on.
        model.train()
        train_loss, n_train_batches = 0.0, 0
        for images, targets in train_loader:
            loss = _batch_loss(model, loss_fn, images, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            n_train_batches += 1
        torch.cuda.empty_cache()
        # Validation pass: evaluation only -- no backward pass and no optimizer
        # step, so the held-out data cannot leak into the weights.
        model.eval()
        valid_loss, n_valid_batches = 0.0, 0
        with torch.no_grad():
            for images, targets in valid_loader:
                valid_loss += _batch_loss(model, loss_fn, images, targets).item()
                n_valid_batches += 1
        torch.cuda.empty_cache()
        # Mean loss per batch; max(..., 1) avoids ZeroDivisionError on an empty loader.
        train_loss /= max(n_train_batches, 1)
        valid_loss /= max(n_valid_batches, 1)
        if epoch == 1 or epoch % 10 == 0:
            print('{} Epoch {}, Train {:.5f}, Valid {:.5f}'.format(dt.datetime.now(),
                                                                      epoch,
                                                                      train_loss,
                                                                      valid_loss))
        if valid_loss < best_valid_loss:
            torch.save(model.state_dict(), checkpoint_path)
            print(f'Saving {epoch}-th for {valid_loss = :2.5f}')
            best_valid_loss = valid_loss


NameError: name 'nn' is not defined

In [None]:
from cocoBox import load_coco_dataset

BATCH_SIZE = 32
IMG_SIZE = 300
# Number of COCO samples requested from load_coco_dataset (its `size` argument).
DATASET_SIZE = 10000
# Image categories to train SSD on; presumably label i maps to categories[i]
# (the plotting cell indexes this tuple by label) -- keep the order stable.
categories = ('horse', 'bird', 'cat', 'dog', 'person', 'car')
train_dataloader, valid_dataloader = load_coco_dataset(
    batch_size=BATCH_SIZE,
    size=DATASET_SIZE,
    dim=IMG_SIZE,
    cats=categories,
    fetch_type='union',
)

In [None]:
from SSD import SSD
from SSDLoss import SSDLoss

# Training hyper-parameters.
N_EPOCHS = 1000
LEARNING_RATE = 1e-3

# One output class per entry in `categories`.
ssd = SSD(class_num=len(categories)).to(device=DEVICE)
adam = torch.optim.Adam(ssd.parameters(), lr=LEARNING_RATE)
ssd_loss = SSDLoss().to(device=DEVICE)
training_loop(
    n_epochs=N_EPOCHS,
    optimizer=adam,
    loss_fn=ssd_loss,
    model=ssd,
    train_loader=train_dataloader,
    valid_loader=valid_dataloader,
)

In [None]:
import matplotlib.pyplot as plt
from matplotlib import patches
# Load the best checkpoint written by training_loop and switch to inference mode
# (affects layers such as dropout/batch-norm, if the model has any).
ssd.load_state_dict(torch.load(DATAPATH + 'ssd.pt', map_location=DEVICE))
ssd.eval()

# Number of validation images to draw; each one gets its own figure.
MAX_PLOTS = 10

# Invert the ImageNet normalisation: Normalize(-m/s, 1/s) maps y back to y * s + m.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
unnormalize = torchvision.transforms.Normalize(
    mean=[-m / s for m, s in zip(mean, std)],
    std=[1 / s for s in std],
)


def _to_cpu(x):
    """Return ``x`` on the CPU if it is a tensor, else unchanged (numpy/matplotlib need CPU data)."""
    return x.cpu() if torch.is_tensor(x) else x


def draw_boxes(ax, boxes, names, img_h, img_w, color):
    """Draw each box as a rectangle with its class name at the top-left corner.

    Boxes are read as (x, y, width, height) fractions of the image size and
    scaled to pixels. NOTE(review): this coordinate format is assumed from the
    original plotting code -- confirm against the dataset and SSD.predict.
    """
    # TODO: add labelling of rectangle box
    for box, name in zip(boxes, names):
        x = float(box[0]) * img_w
        y = float(box[1]) * img_h
        width = float(box[2]) * img_w
        height = float(box[3]) * img_h
        ax.add_patch(patches.Rectangle((x, y), width, height,
                                       linewidth=2, edgecolor=color, fill=False))
        ax.text(x=x, y=y, s=name)


# NOTE(review): the training loop indexes a batch's targets as a dict
# (targets['boxes']) while this cell treats them as a per-image list
# (targets[0]['boxes']); only one of the two can match the loader -- verify.
n_plotted = 0
with torch.no_grad():
    for imgs, targets in valid_dataloader:
        if n_plotted >= MAX_PLOTS:
            break
        target = targets[0]
        # Only show images containing at least two ground-truth objects.
        if len(target['boxes']) < 2:
            continue
        img = imgs[0]
        img_h, img_w = img.shape[1], img.shape[2]
        fig, ax = plt.subplots()
        # Undoing the normalisation can land slightly outside [0, 1]; clamp for imshow.
        ax.imshow(unnormalize(img).permute(1, 2, 0).clamp(0, 1))
        ax.set_title('blue: ground truth, red: prediction')
        # Ground truth ('labels' key, matching the training loop).
        gt_names = [categories[int(label)] for label in target['labels']]
        draw_boxes(ax, target['boxes'], gt_names, img_h, img_w, 'b')
        # Predictions: the model lives on DEVICE, so the batch must too.
        offs, conf = ssd.predict(imgs.to(device=DEVICE), targets)
        offs, conf = _to_cpu(offs), _to_cpu(conf)
        pred_names = [categories[int(np.argmax(scores))] for scores in conf]
        draw_boxes(ax, offs, pred_names, img_h, img_w, 'r')
        plt.show()
        plt.close(fig)  # free the figure; otherwise memory grows with every image
        n_plotted += 1