In [None]:
import numpy as np
import torch, sys
from torchvision.transforms import functional as func
import torchvision.transforms as transforms
from loss import ComputeLoss
import yaml, random
import fiftyone as fo
import fiftyone.zoo as foz
from fiftyone import ViewField as F
from new_model import Model
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.ticker import NullLocator
from torch.utils.data import DataLoader

from dataloader import FiftyOneTorchDataset, collate_fn
from util import non_max_suppression

In [None]:
# Fetch a small COCO-2017 training subset from the FiftyOne dataset zoo.
#   split       -> which COCO split to pull (train / validation / test)
#   classes     -> object classes requested from the zoo loader
#   max_samples -> upper bound on the number of samples loaded
dataset_train = foz.load_zoo_dataset(
    "coco-2017",
    split="train",
    classes=["cat", "dog", "horse", "giraffe"],
    max_samples=256,
)

In [None]:
# Fetch the matching COCO-2017 validation subset (same classes, at most
# 64 samples) from the FiftyOne dataset zoo.
dataset_validation = foz.load_zoo_dataset(
    "coco-2017", split="validation", classes=["cat", "dog", "horse", "giraffe"], max_samples=64
)

In [None]:
# Prepare everything the training loop needs: filtered views, device,
# hyper-parameters, the image preprocessing pipeline and the data loaders.

# Classes to keep. Defined once and reused for label filtering and for the
# torch dataset wrapper so the three places cannot drift apart.
fil_classes = ["cat", "dog", "horse", "giraffe"]

# Mark both FiftyOne datasets as persistent.
dataset_train.persistent = True
dataset_validation.persistent = True

# Keep only the ground-truth labels whose class is in fil_classes.
view_train = dataset_train.filter_labels("ground_truth", F("label").is_in(tuple(fil_classes)))
view_val = dataset_validation.filter_labels("ground_truth", F("label").is_in(tuple(fil_classes)))

# Use the first GPU when available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Load the hyper-parameter file. A malformed file is reported and the error
# is re-raised: silently continuing would leave `hyp` undefined and fail
# later with a confusing NameError when the model is built.
with open("hyp.yaml", "r") as stream:
    try:
        hyp = yaml.safe_load(stream)
    except yaml.YAMLError as exc:
        print(exc)
        raise

# Nominal source image size and the factor used to shrink it.
# NOTE(review): this assumes 640x480 landscape inputs; images with another
# aspect ratio are resized to the same fixed size and will be distorted.
org_w = 640
org_h = 480
scaling_factor = 640/480

# Batch size.
batch_size = 8

# Resize to (org_h, org_w) / scaling_factor, then zero-pad top and bottom by
# the same amount so the result is square, and convert to a tensor.
transform = transforms.Compose([transforms.Resize((int(org_h/scaling_factor), int(org_w/scaling_factor))),
                                transforms.Pad((0, int((org_w - org_h)/(2*scaling_factor)),0,int((org_w - org_h)/(2*scaling_factor)))),
                                transforms.ToTensor()])

# Wrap the FiftyOne views as torch datasets and build the loaders.
# NOTE(review): `dataset_train` is rebound here from the FiftyOne dataset to
# the torch wrapper (later cells rely on that name), so this cell cannot be
# re-run without first re-running the zoo-loading cell.
dataset_train = FiftyOneTorchDataset(view_train, transform, classes=fil_classes)
dataset_val = FiftyOneTorchDataset(view_val, transform, classes=fil_classes)
# Pass the project collate function so these loaders match the ones used for
# training (presumably needed to batch per-image targets; confirm in dataloader.py).
loader_train = DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
loader_val = DataLoader(dataset=dataset_val, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [None]:
# Build the shuffled training and validation loaders, batching samples with
# the project-specific collate function.
loader_train = DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)
loader_val = DataLoader(
    dataset_val,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
# Train the detector, evaluate on the validation loader after every epoch,
# checkpoint periodically, and persist the final model and loss curves.
import os

# Load the model and set up the optimizer and the custom loss function.
model = Model('yolov3.yaml', hyp=hyp).to(device)
optimizer = torch.optim.Adam(model.parameters(),1e-3)
loss_fcn = ComputeLoss(model)

# Per-epoch mean losses.
train_loss_list = []
val_loss_list = []

# torch.save does not create missing directories, so make sure the
# checkpoint folder exists before training starts.
os.makedirs('./models', exist_ok=True)

# Short run by default; raise (e.g. to 1000) for a full training run.
epochs = 10
for epoch in range(epochs):

    # ---- Training pass ----
    tot_loss = 0.0
    count = 0
    model.train()
    for images, targets in loader_train:
        optimizer.zero_grad()
        images = images.to(device)
        targets = targets.to(device)
        preds = model(images)
        loss, loss_parts = loss_fcn(preds, targets)
        loss.backward()
        optimizer.step()
        # Accumulate a plain Python float. Summing the loss tensor itself
        # would keep every batch's autograd history alive for the whole epoch.
        tot_loss += loss.item() / batch_size
        count += 1

    # Save a checkpoint every 10 epochs (skipping epoch 0).
    if epoch%10==0 and epoch!=0:
        torch.save(model,'./models/model'+str(epoch)+'.pt')
    # max(count, 1) avoids a ZeroDivisionError when the loader is empty.
    train_loss = tot_loss / max(count, 1)
    print(epoch, 'Training:\t',epoch, train_loss)
    train_loss_list.append(train_loss)

    # ---- Validation pass ----
    # NOTE(review): the model is left in train mode here, exactly as before.
    # Switching to model.eval() may change the model's output format and
    # break loss_fcn; confirm against new_model.Model before changing it.
    tot_loss = 0.0
    count = 0
    for images, targets in loader_val:
        images = images.to(device)
        targets = targets.to(device)
        with torch.no_grad():
            preds = model(images)
            loss, loss_parts = loss_fcn(preds, targets)
        tot_loss += loss.item() / batch_size
        count += 1
    val_loss = tot_loss / max(count, 1)
    print('\tValidation:\t', val_loss)
    val_loss_list.append(val_loss)


# Save the final model and the list results.
torch.save(model,'./models/final'+'.pt')

from util import my_load,my_save
my_save('trainloss',train_loss_list)
my_save('validationloss',val_loss_list)

In [None]:
# Draw the per-epoch training and validation loss curves on one figure
# using the project's plotting helper.
from util import plot
plot(
    train_loss_list,
    val_loss_list,
    'train loss',
    'val loss',
    'loss',
    'train loss and validation loss',
)

In [None]:
# Restore the final pickled model from disk and move it to the GPU.
from util import my_img_plot
import gc

device = torch.device('cuda:0')

# Run the garbage collector and release PyTorch's cached CUDA memory
# before loading the checkpoint.
gc.collect()
torch.cuda.empty_cache()

model = torch.load('./models/final.pt')
model = model.to(device)

In [None]:
# Visualize predictions for the first image of a single training batch.
model.eval()
with torch.no_grad():
    for images, targets in loader_train:
        images = images.to(device)
        pred = model(images)
        # Plot the first element of the model output against the first image
        # of the batch (the meaning of the trailing `1` is defined by
        # util.my_img_plot).
        my_img_plot(pred[0], images[0], fil_classes, 1)
        # Only one batch is needed for the visualization.
        break