In [1]:
def get_platform_path():
    system = platform.system()
    data_dir, model_dir, checkpoint_dir, dirs = '', '', '', []
    if system == 'Windows':
        drive, common_dir = 'F', 'cache'
        data_dir = '{}:\\{}\\data'.format(drive, common_dir)
        model_dir = '{}:\\{}\\model'.format(drive, common_dir)
        checkpoint_dir = '{}:\\{}\\checkpoint'.format(drive, common_dir)
        dirs = [data_dir, model_dir, checkpoint_dir]

    elif system == 'Linux':
        common_dir = '/data'
        data_dir = '{}/data'.format(common_dir)
        model_dir = '{}/model'.format(common_dir)
        checkpoint_dir = '{}/checkpoint'.format(common_dir)
        dirs = [data_dir, model_dir, checkpoint_dir]

    for dir in dirs:
        if not os.path.exists(dir):
            os.mkdir(dir)

    return data_dir, model_dir, checkpoint_dir

def collate_fn(batch):
    return tuple(zip(*batch))

                the kernel may be left running.  Please let us know
                about your system (bitness, Python, etc.) at
                ipython-dev@scipy.org


In [10]:
def voc_collate_fn(batch_lst, reshape_size=224):
    preprocess = transforms.Compose([
      transforms.Resize((reshape_size, reshape_size)),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
      ])
    
    batch_size = len(batch_lst)
    
    img_batch = torch.zeros(batch_size, 3, reshape_size, reshape_size)
    
    max_num_box = max(len(batch_lst[i][1]['annotation']['object']) \
                      for i in range(batch_size))

    box_batch = torch.Tensor(batch_size, max_num_box, 5).fill_(-1.)
    w_list = []
    h_list = []
    img_id_list = []
    
    for i in range(batch_size):
        img, ann = batch_lst[i]
        w_list.append(img.size[0]) # image width
        h_list.append(img.size[1]) # image height
        img_id_list.append(ann['annotation']['filename'])
        img_batch[i] = preprocess(img)
        all_bbox = ann['annotation']['object']
        if type(all_bbox) == dict: # inconsistency in the annotation file
            all_bbox = [all_bbox]
        for bbox_idx, one_bbox in enumerate(all_bbox):
            bbox = one_bbox['bndbox']
            obj_cls = one_bbox['name']
            box_batch[i][bbox_idx] = torch.Tensor([float(bbox['xmin']), float(bbox['ymin']),
              float(bbox['xmax']), float(bbox['ymax']), class_to_idx[obj_cls]])
    
    h_batch = torch.tensor(h_list)
    w_batch = torch.tensor(w_list)

    return img_batch, box_batch, w_batch, h_batch, img_id_list

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms

import argparse
import platform
import os

from tqdm import tqdm

if __name__ == "__main__":
#     parser = argparse.ArgumentParser(description="Pytorch YOLOv1 Training")

#     parser.add_argument('--lr', default=0.01, type=float, help='learning rate')
#     parser.add_argument('--epoch', default=200, type=int, help='epoch')
#     parser.add_argument('--resume', '-r', action='store_true', help='resume from checkpoint')
#     parser.add_argument('--use_cuda', action='store_true', default=True, help='whether to use cuda')
#     parser.add_argument('--net', default='resnet18', help='network type')
#     args = parser.parse_args()
    
    use_cuda = True
    
    # detect device
    print("CUDA Available:", torch.cuda.is_available())
    device = torch.device("cuda" if (use_cuda and torch.cuda.is_available()) else "cpu")

    # data/model/checkpoint in different platform
    data_dir, model_dir, checkpoint_dir = get_platform_path()

    # dataset VOC2011 and others are subset of VOC2012
    transform_train = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224,224)),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    transform_val = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224,224)),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    trainset = torchvision.datasets.VOCDetection(root=data_dir, download=True, year='2012',
                                                 image_set='train', transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True, num_workers=4, collate_fn=voc_collate_fn)

    valset = torchvision.datasets.VOCDetection(root=data_dir, download=True, year='2012',
                                               image_set='val', transform=transform_val)
    valloader = torch.utils.data.DataLoader(valset, batch_size=16, shuffle=False, num_workers=4)



CUDA Available: True
Using downloaded and verified file: F:\cache\data\VOCtrainval_11-May-2012.tar
Using downloaded and verified file: F:\cache\data\VOCtrainval_11-May-2012.tar


In [12]:
class_to_idx = {'aeroplane':0, 'bicycle':1, 'bird':2, 'boat':3, 'bottle':4,
                'bus':5, 'car':6, 'cat':7, 'chair':8, 'cow':9, 'diningtable':10,
                'dog':11, 'horse':12, 'motorbike':13, 'person':14, 'pottedplant':15,
                'sheep':16, 'sofa':17, 'train':18, 'tvmonitor':19
}
idx_to_class = {i:c for c, i in class_to_idx.items()}

In [14]:
for batch_idx,(img_batch, box_batch, w_batch, h_batch, img_id_list) in trainloader:
    print(img_batch, box_batch, w_batch, h_batch, img_id_list)

    break


RuntimeError: DataLoader worker (pid(s) 16272, 18544, 12072, 19068) exited unexpectedly