In [1]:
from tqdm import tqdm
import os
import sys
import random
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
import torchvision.transforms as T
from detzoo.models import YOLOv1
from detzoo.datasets import VOCDataset, COCODataset
from detzoo.utils import bbox_to_yolo_format, collate_fn

  from .autonotebook import tqdm as notebook_tqdm


# Basic Config

In [2]:
# --- Reproducibility -------------------------------------------------------
# Seed both Python's and PyTorch's random number generators with one value.
seed = 42
torch.manual_seed(seed)
random.seed(seed)

# --- Optimisation hyper-parameters -----------------------------------------
# `epoch` is the number of passes over the training set (used as range(epoch)).
epoch, batch_size, lr = 1, 32, 0.001

# --- Experiment selection --------------------------------------------------
# These three strings select the dataset / detector / backbone further down
# and are also combined into the checkpoint file name.
dataset, model, backbone = 'voc', 'yolov1', 'vgg16'

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Directory that receives the trained weights.
# NOTE(review): machine-specific absolute path -- adjust before running elsewhere.
save_dir = '/home/lujialin/detzoo/checkpoints'

# classes = ['person', 'car'] # You can define target classes and assign them to the dataset

# Prepare dataset

In [3]:
# Preprocessing shared by every dataset: resize to 448x448, convert to a
# tensor, then normalise each channel with the given mean / std triples.
train_transform = T.Compose([
    T.Resize((448, 448)),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# Build the training split for the dataset chosen in the config cell.
# '~' is expanded here explicitly in case the dataset class does not do so.
if dataset == 'voc':
    train_dataset = VOCDataset(
        root=os.path.expanduser('~/data/VOC'),
        year='2007',
        image_set='train',
        transform=train_transform,
        # classes=classes
    )
elif dataset == 'coco':
    train_dataset = COCODataset(
        root=os.path.expanduser('~/data/COCO'),
        year='2017',
        image_set='train',
        transform=train_transform,
        # classes=classes
    )
else:
    # Fail here with a clear message instead of a NameError on
    # `train_dataset` a few lines later.
    raise ValueError(f"Unsupported dataset {dataset!r}; expected 'voc' or 'coco'")

# Shuffled mini-batches; the project's collate_fn assembles each batch.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    collate_fn=collate_fn,
)

# Configure detector

In [4]:
# Instantiate the detector chosen in the config cell and move it to `device`.
if model == 'yolov1':
    detector = YOLOv1(
        classes=train_dataset.classes,
        backbone=backbone,
    ).to(device)
elif model == 'yolov2':
    # Previously a silent `pass`, which left `detector` undefined and made the
    # optimizer line below fail with a NameError.
    raise NotImplementedError("model 'yolov2' is not available in this notebook yet")
else:
    raise ValueError(f"Unsupported model {model!r}; expected 'yolov1'")

# The optimizer is stored on the detector so the training loop can reach it
# as `detector.optimizer`.
detector.optimizer = Adam(params=detector.parameters(), lr=lr)

# Train

In [5]:
# Switch the detector to training mode before optimising.
detector.train()

# The loop variable is deliberately NOT called `epoch`: reusing that name
# would overwrite the configured epoch count with the last loop index, so
# re-running this cell would train for fewer (eventually zero) epochs.
for epoch_idx in tqdm(range(epoch), desc='Epoch'):
    for image, targets in tqdm(train_dataloader, desc='Train', leave=False):
        # For YOLO models: convert the targets with the project helper,
        # passing the same (448, 448) size the images are resized to.
        targets = bbox_to_yolo_format(targets, (448, 448))

        # Put the batch on the same device as the detector.
        image = image.to(device)
        targets = targets.to(device)

        # Clear gradients accumulated by the previous iteration.
        detector.optimizer.zero_grad()

        # Forward pass and loss.
        prediction = detector(image)
        loss = detector.loss(prediction, targets)

        # Backward pass and parameter update.
        loss.backward()
        detector.optimizer.step()

Epoch:   0%|          | 0/1 [00:07<?, ?it/s]


KeyboardInterrupt: 

# Save model

In [None]:
# Create the checkpoint directory if needed. exist_ok=True avoids the
# check-then-create race of testing os.path.exists() first.
os.makedirs(save_dir, exist_ok=True)

# The file name encodes which model / backbone / dataset produced the weights.
checkpoint_path = os.path.join(save_dir, f'{model}_{backbone}_{dataset}.pth')
torch.save(detector.state_dict(), checkpoint_path)