# Training functions for YOLO v1

In [1]:
import torch
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.transforms.functional as FT
from tqdm import tqdm
from torch.utils.data import DataLoader


import nbimporter
import importlib
from YoloV1 import Yolov1
from data import VOCDataset
from loss import intersection_over_union
from Loss import Yolo_Loss
#importlib.reload(Loss)


## Hyperparameters

In [2]:
# Optimisation settings.
learning_rate = 2e-5
weight_decay = 0
batch_size = 16  # It is 64 in the paper
epochs = 100

# Runtime / data-loading settings.
device = "cuda" if torch.cuda.is_available() else "cpu"
num_workers = 0
# Pinned (page-locked) host memory only speeds up host-to-GPU copies; on a
# CPU-only machine it buys nothing and recent PyTorch versions warn about it,
# so only enable it when training on CUDA.
pin_memory = device == "cuda"

# Checkpoint settings.
# NOTE: checkpoint loading is not currently wired up in main().
load_model = False
load_model_file = "overfit.pth.tar"

# Dataset locations passed to VOCDataset in main().
img_dir = "archive/images"
label_dir = "archive/labels"


In [None]:
class Compose:
    """Apply a sequence of image transforms, passing bounding boxes through.

    Each callable in ``transforms`` is applied, in order, to the image only;
    ``bboxes`` is returned exactly as it was given (same object, unmodified).
    NOTE(review): presumably the boxes are stored in image-relative
    coordinates, so resizing the image does not change them — confirm
    against VOCDataset.
    """

    def __init__(self, transforms):
        # Sequence of single-argument callables, each mapping image -> image.
        self.transforms = transforms

    def __call__(self, img, bboxes):
        """Return ``(transformed_img, bboxes)``."""
        result = img
        for step in self.transforms:
            result = step(result)
        return result, bboxes

# Preprocessing applied to every training image: resize to 448x448
# (presumably the input size Yolov1 expects — confirm against the model),
# then convert to a tensor with torchvision's ToTensor.
transform = Compose(
    [
        transforms.Resize(size=(448, 448)),
        transforms.ToTensor(),
    ]
)
    

def train(train_loader, model, optimiser, loss_fn):
    """Run one training epoch over ``train_loader`` and report the mean loss.

    Args:
        train_loader: Iterable of ``(x, y)`` batches (a ``DataLoader`` in
            ``main``); both elements are moved to the module-level ``device``.
        model: Module being optimised; it is switched to training mode first.
        optimiser: Optimiser over ``model``'s parameters.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a
            single-element tensor (``.item()`` and ``.backward()`` are called on it).

    Returns:
        The mean per-batch loss for the epoch as a float, or ``None`` when the
        loader yielded no batches (previously this crashed with
        ``ZeroDivisionError``).
    """
    # Be explicit: layers such as dropout/batch-norm behave differently in eval mode.
    model.train()

    loop = tqdm(train_loader, leave=True)
    batch_losses = []

    for x, y in loop:
        x, y = x.to(device), y.to(device)

        out = model(x)
        loss = loss_fn(out, y)

        # .item() copies the value to the host (a device sync on CUDA);
        # do it once per batch and reuse the float.
        loss_value = loss.item()
        batch_losses.append(loss_value)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        loop.set_postfix(loss=loss_value)

    if not batch_losses:
        print("No batches were processed; mean loss is undefined")
        return None

    mean_loss = sum(batch_losses) / len(batch_losses)
    print(f"Mean loss was {mean_loss}")
    return mean_loss


def main(csv_file="archive/100examples.csv"):
    """Build the Yolov1 model, loss and data pipeline, then train for ``epochs`` epochs.

    Uses the module-level hyperparameters (``learning_rate``, ``batch_size``,
    ``weight_decay``, ``epochs``, ``num_workers``, ``pin_memory``, ``device``)
    and data locations (``img_dir``, ``label_dir``).

    Args:
        csv_file: Annotation CSV passed to ``VOCDataset``; defaults to
            ``"archive/100examples.csv"`` as before.
    """
    model = Yolov1(split_size=7, num_boxes=2, num_classes=20).to(device)
    optimiser = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    loss_fn = Yolo_Loss()

    if load_model:
        # No load_checkpoint helper is defined or imported here, so the flag
        # cannot be honoured; make that visible instead of silently training
        # from scratch.
        import warnings

        warnings.warn(
            f"load_model is True but checkpoint loading is not implemented; "
            f"'{load_model_file}' was not loaded and training starts from scratch."
        )

    train_dataset = VOCDataset(csv_file, transform=transform, img_dir=img_dir, label_dir=label_dir)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_memory,
        shuffle=True,
        drop_last=False,
    )

    for _ in range(epochs):
        train(train_loader=train_loader, model=model, optimiser=optimiser, loss_fn=loss_fn)

        
# Entry point: start training when this file is executed directly. In a
# Jupyter kernel cells run as __main__, so executing this cell trains the model.
if __name__ == "__main__":
    main()

        

100%|██████████| 7/7 [00:01<00:00,  4.20it/s, loss=264]    


Mean loss was 850.2911420549665


100%|██████████| 7/7 [00:01<00:00,  6.85it/s, loss=271]


Mean loss was 594.6688799176898


100%|██████████| 7/7 [00:01<00:00,  6.88it/s, loss=317]


Mean loss was 509.1238272530692


100%|██████████| 7/7 [00:01<00:00,  6.89it/s, loss=240]


Mean loss was 426.002685546875


100%|██████████| 7/7 [00:01<00:00,  6.88it/s, loss=141]


Mean loss was 376.77613830566406


100%|██████████| 7/7 [00:01<00:00,  6.88it/s, loss=185]


Mean loss was 343.1451633998326


100%|██████████| 7/7 [00:01<00:00,  6.94it/s, loss=228]


Mean loss was 311.44005693708147


100%|██████████| 7/7 [00:01<00:00,  6.90it/s, loss=118]


Mean loss was 288.2355433872768


100%|██████████| 7/7 [00:01<00:00,  6.87it/s, loss=90.7]


Mean loss was 270.0989510672433


100%|██████████| 7/7 [00:01<00:00,  6.83it/s, loss=190]


Mean loss was 263.4089922223772


100%|██████████| 7/7 [00:01<00:00,  6.89it/s, loss=90] 


Mean loss was 243.1746346609933


100%|██████████| 7/7 [00:01<00:00,  6.89it/s, loss=134]


Mean loss was 232.36028616768974


100%|██████████| 7/7 [00:01<00:00,  6.87it/s, loss=134]


Mean loss was 230.5947004045759


100%|██████████| 7/7 [00:01<00:00,  6.85it/s, loss=79.1]


Mean loss was 209.07513427734375


100%|██████████| 7/7 [00:01<00:00,  6.82it/s, loss=135]


Mean loss was 207.80577087402344


100%|██████████| 7/7 [00:01<00:00,  6.83it/s, loss=132]


Mean loss was 192.43395124162947


100%|██████████| 7/7 [00:01<00:00,  6.85it/s, loss=101]


Mean loss was 193.38480268205916


100%|██████████| 7/7 [00:01<00:00,  6.86it/s, loss=92.8]


Mean loss was 186.74225725446428


100%|██████████| 7/7 [00:01<00:00,  6.90it/s, loss=60.9]


Mean loss was 182.49724469866072


100%|██████████| 7/7 [00:01<00:00,  6.88it/s, loss=160]


Mean loss was 187.21105739048548


100%|██████████| 7/7 [00:01<00:00,  6.89it/s, loss=127]


Mean loss was 178.63640921456474


100%|██████████| 7/7 [00:01<00:00,  6.93it/s, loss=101]


Mean loss was 169.19014522007532


100%|██████████| 7/7 [00:01<00:00,  6.96it/s, loss=115]


Mean loss was 162.40580095563615


100%|██████████| 7/7 [00:01<00:00,  6.87it/s, loss=65.9]


Mean loss was 160.9698486328125


100%|██████████| 7/7 [00:01<00:00,  6.92it/s, loss=86.5]


Mean loss was 160.52760750906808


100%|██████████| 7/7 [00:01<00:00,  6.87it/s, loss=65.9]


Mean loss was 164.6912329537528


100%|██████████| 7/7 [00:01<00:00,  6.97it/s, loss=43.3]


Mean loss was 157.30364227294922


100%|██████████| 7/7 [00:01<00:00,  6.95it/s, loss=88] 


Mean loss was 158.28987775530135


100%|██████████| 7/7 [00:01<00:00,  6.95it/s, loss=43.7]


Mean loss was 146.59635434831893


100%|██████████| 7/7 [00:01<00:00,  6.99it/s, loss=89.9]


Mean loss was 148.1780003138951


100%|██████████| 7/7 [00:01<00:00,  6.94it/s, loss=66.4]


Mean loss was 146.5998982020787


100%|██████████| 7/7 [00:01<00:00,  6.97it/s, loss=85.6]


Mean loss was 147.00500270298548


100%|██████████| 7/7 [00:01<00:00,  6.97it/s, loss=122] 


Mean loss was 145.37845938546317


100%|██████████| 7/7 [00:01<00:00,  6.89it/s, loss=95.6]


Mean loss was 146.92003740583147


100%|██████████| 7/7 [00:01<00:00,  6.38it/s, loss=119]


Mean loss was 142.88367026192802


100%|██████████| 7/7 [00:01<00:00,  6.86it/s, loss=93.7]


Mean loss was 138.6244637625558


100%|██████████| 7/7 [00:01<00:00,  6.82it/s, loss=69.2]


Mean loss was 136.62061745779855


100%|██████████| 7/7 [00:01<00:00,  6.80it/s, loss=64.2]


Mean loss was 137.137578691755


100%|██████████| 7/7 [00:01<00:00,  6.91it/s, loss=61.1]


Mean loss was 136.71553693498885


100%|██████████| 7/7 [00:01<00:00,  6.28it/s, loss=79.3]


Mean loss was 133.74806322370256


100%|██████████| 7/7 [00:01<00:00,  6.41it/s, loss=89.8]


Mean loss was 137.51228550502233


100%|██████████| 7/7 [00:01<00:00,  6.35it/s, loss=88.8]


Mean loss was 128.3432137625558


100%|██████████| 7/7 [00:01<00:00,  6.36it/s, loss=63.4]


Mean loss was 131.98177664620536


100%|██████████| 7/7 [00:01<00:00,  6.65it/s, loss=58] 


Mean loss was 126.65355627877372


100%|██████████| 7/7 [00:01<00:00,  6.72it/s, loss=56.4]


Mean loss was 120.18296160016742


100%|██████████| 7/7 [00:01<00:00,  6.55it/s, loss=44.6]


Mean loss was 128.4242319379534


100%|██████████| 7/7 [00:01<00:00,  6.54it/s, loss=42.6]


Mean loss was 124.30333546229771


100%|██████████| 7/7 [00:01<00:00,  6.79it/s, loss=79.1]


Mean loss was 128.52322387695312


100%|██████████| 7/7 [00:01<00:00,  6.70it/s, loss=49.9]


Mean loss was 120.82062639508929


100%|██████████| 7/7 [00:01<00:00,  6.64it/s, loss=45.8]


Mean loss was 118.18650708879743


100%|██████████| 7/7 [00:01<00:00,  6.39it/s, loss=64.7]


Mean loss was 119.19166891915458


100%|██████████| 7/7 [00:01<00:00,  6.59it/s, loss=50.7]


Mean loss was 120.2117919921875


100%|██████████| 7/7 [00:01<00:00,  6.55it/s, loss=54]  


Mean loss was 118.25049482073102


100%|██████████| 7/7 [00:01<00:00,  6.54it/s, loss=50.3]


Mean loss was 113.96554892403739


100%|██████████| 7/7 [00:01<00:00,  6.56it/s, loss=71.5]


Mean loss was 117.06258392333984


100%|██████████| 7/7 [00:01<00:00,  6.65it/s, loss=42.7]


Mean loss was 111.14887019566127


100%|██████████| 7/7 [00:01<00:00,  6.77it/s, loss=52.3]


Mean loss was 110.4929782322475


100%|██████████| 7/7 [00:01<00:00,  6.75it/s, loss=58.1]


Mean loss was 114.23080553327289


100%|██████████| 7/7 [00:01<00:00,  6.54it/s, loss=76.9]


Mean loss was 110.64466094970703


100%|██████████| 7/7 [00:01<00:00,  6.77it/s, loss=92.8]


Mean loss was 108.27887289864677


100%|██████████| 7/7 [00:01<00:00,  6.61it/s, loss=43.5]


Mean loss was 106.50763702392578


100%|██████████| 7/7 [00:01<00:00,  6.62it/s, loss=57.4]


Mean loss was 114.84172766549247


100%|██████████| 7/7 [00:01<00:00,  6.63it/s, loss=29.1]


Mean loss was 106.13198007856097


100%|██████████| 7/7 [00:01<00:00,  6.63it/s, loss=49.5]


Mean loss was 108.97213527134487


100%|██████████| 7/7 [00:01<00:00,  6.61it/s, loss=55.8]


Mean loss was 103.7568223135812


100%|██████████| 7/7 [00:01<00:00,  6.76it/s, loss=67.4]


Mean loss was 109.21317945207868


100%|██████████| 7/7 [00:01<00:00,  6.65it/s, loss=75.3]


Mean loss was 106.0833489554269


100%|██████████| 7/7 [00:01<00:00,  6.64it/s, loss=63]  


Mean loss was 103.99168395996094


100%|██████████| 7/7 [00:01<00:00,  6.76it/s, loss=88.4]


Mean loss was 104.33453151157924


100%|██████████| 7/7 [00:01<00:00,  6.74it/s, loss=75.1]


Mean loss was 100.27198464529855


100%|██████████| 7/7 [00:01<00:00,  6.68it/s, loss=50.2]


Mean loss was 102.18003136771065


100%|██████████| 7/7 [00:01<00:00,  6.61it/s, loss=51.2]


Mean loss was 106.79604121616909


100%|██████████| 7/7 [00:01<00:00,  6.67it/s, loss=54.2]


Mean loss was 106.26476396833148


100%|██████████| 7/7 [00:01<00:00,  6.62it/s, loss=49]  


Mean loss was 104.69415446690151


100%|██████████| 7/7 [00:01<00:00,  6.69it/s, loss=74.4]


Mean loss was 103.58709280831474


100%|██████████| 7/7 [00:01<00:00,  6.72it/s, loss=35]  


Mean loss was 102.33191789899554


100%|██████████| 7/7 [00:01<00:00,  6.77it/s, loss=52]  


Mean loss was 99.48423494611468


100%|██████████| 7/7 [00:01<00:00,  6.74it/s, loss=96.3]


Mean loss was 95.25510733468192


100%|██████████| 7/7 [00:01<00:00,  6.76it/s, loss=51.5]


Mean loss was 96.29165594918388


100%|██████████| 7/7 [00:01<00:00,  6.77it/s, loss=50.2]


Mean loss was 96.5714727129255


100%|██████████| 7/7 [00:01<00:00,  6.67it/s, loss=35.8]


Mean loss was 96.69457844325474


100%|██████████| 7/7 [00:01<00:00,  6.64it/s, loss=92.5]


Mean loss was 95.87689535958427


100%|██████████| 7/7 [00:01<00:00,  6.62it/s, loss=59.1]


Mean loss was 90.99554606846401


100%|██████████| 7/7 [00:01<00:00,  6.67it/s, loss=59]  


Mean loss was 93.49682671683175


100%|██████████| 7/7 [00:01<00:00,  6.69it/s, loss=53.5]


Mean loss was 94.7444839477539


100%|██████████| 7/7 [00:01<00:00,  6.72it/s, loss=41.3]


Mean loss was 91.29841940743583


100%|██████████| 7/7 [00:01<00:00,  6.69it/s, loss=50.3]


Mean loss was 92.53513009207589


100%|██████████| 7/7 [00:01<00:00,  6.66it/s, loss=51.5]


Mean loss was 88.11699894496373


100%|██████████| 7/7 [00:01<00:00,  6.73it/s, loss=36.2]


Mean loss was 87.63900756835938


100%|██████████| 7/7 [00:01<00:00,  6.64it/s, loss=31.9]


Mean loss was 88.41022763933454


100%|██████████| 7/7 [00:01<00:00,  6.52it/s, loss=42.2]


Mean loss was 88.00981521606445


100%|██████████| 7/7 [00:01<00:00,  6.68it/s, loss=61.2]


Mean loss was 86.28427995954242


100%|██████████| 7/7 [00:01<00:00,  6.59it/s, loss=28.7]


Mean loss was 90.22592517307827


100%|██████████| 7/7 [00:01<00:00,  6.66it/s, loss=48.1]


Mean loss was 86.56504821777344


100%|██████████| 7/7 [00:01<00:00,  6.58it/s, loss=36]  


Mean loss was 87.4120123726981


100%|██████████| 7/7 [00:01<00:00,  6.68it/s, loss=33.5]


Mean loss was 87.33703122820172


100%|██████████| 7/7 [00:01<00:00,  6.76it/s, loss=31.2]


Mean loss was 85.08597510201591


100%|██████████| 7/7 [00:01<00:00,  6.71it/s, loss=51.7]


Mean loss was 84.26173618861607


100%|██████████| 7/7 [00:01<00:00,  6.76it/s, loss=60.7]


Mean loss was 84.79906899588448


100%|██████████| 7/7 [00:01<00:00,  6.74it/s, loss=38.2]

Mean loss was 84.31462696620396



