In [3]:
import torch

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = (
    torch.device('cuda')
    if torch.cuda.is_available()
    else torch.device('cpu')
)
print(device)

cuda


In [4]:
import torch.utils.data as data
from voc import make_filepath_list, GetBBoxAndLabel, DataTransform, multiobject_collate_fn
from preprocessDataset import PreprocessVOC2012

import os

# Root directory of the extracted PASCAL VOC2012 dataset. Set the VOC2012_ROOT
# environment variable to point elsewhere instead of editing this
# machine-specific default path.
rootpath = os.path.join(
    os.environ.get(
        'VOC2012_ROOT',
        '/home/masakibandai/object_detection/data/VOCdevkit/VOC2012/',
    ),
    '',  # normalize to a trailing separator, as the original literal had one
)
# Image and annotation path lists for the train and val splits, in the order
# make_filepath_list returns them.
train_img_list, train_anno_list, val_img_list, val_anno_list = make_filepath_list(rootpath)
# The 20 PASCAL VOC object category names, in the order handed to GetBBoxAndLabel.
voc_classes = (
    'aeroplane bicycle bird boat bottle bus car cat chair cow '
    'diningtable dog horse motorbike person pottedplant sheep sofa train tvmonitor'
).split()

# Per-channel mean and square input resolution handed to DataTransform.
# NOTE(review): (104, 117, 123) looks like the usual BGR mean used by SSD ports;
# confirm the channel order against DataTransform.
color_mean = (104, 117, 123)
input_size = 300


def build_voc_dataset(img_list, anno_list, phase):
    """Build a PreprocessVOC2012 dataset for one split.

    Train and val use identical transform and label settings; only the file
    lists and the ``phase`` flag differ. Keeping the construction in one place
    prevents the two splits from silently drifting apart.

    Args:
        img_list: image file paths for the split (from ``make_filepath_list``).
        anno_list: annotation file paths for the same split (presumably
            index-aligned with ``img_list``).
        phase: ``'train'`` or ``'val'``, forwarded to ``PreprocessVOC2012``.

    Returns:
        The constructed ``PreprocessVOC2012`` dataset.
    """
    return PreprocessVOC2012(
        img_list,
        anno_list,
        phase=phase,
        transform=DataTransform(input_size, color_mean),
        get_bbox_label=GetBBoxAndLabel(voc_classes),
    )


train_dataset = build_voc_dataset(train_img_list, train_anno_list, 'train')
val_dataset = build_voc_dataset(val_img_list, val_anno_list, 'val')

batch_size = 32

# One DataLoader per split, keyed by phase. Only the training split is shuffled;
# both use the multi-object collate function so variable-length box lists batch.
dataloaders_dict = {
    phase: data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(phase == 'train'),
        collate_fn=multiobject_collate_fn,
    )
    for phase, dataset in (('train', train_dataset), ('val', val_dataset))
}
train_dataloader = dataloaders_dict['train']
val_dataloader = dataloaders_dict['val']


In [5]:
import torch.nn as nn
import torch.nn.init as init
from ssd import SSD
from torchinfo import summary

# SSD configuration. Each list-valued entry below has one element per source
# feature map, in the same order as 'feature_maps'.
ssd_cfg = {
    # The 20 VOC categories plus one extra class (presumably background) -> 21.
    'classes_num': len(voc_classes) + 1,
    # Reuse the input_size given to DataTransform so the two cannot drift apart.
    'input_size': input_size,
    'dbox_num': [4, 6, 6, 6, 4, 4],
    'feature_maps': [38, 19, 10, 5, 3, 1],
    'steps': [8, 16, 32, 64, 100, 300],
    'min_sizes': [30, 60, 111, 162, 213, 264],
    'max_sizes': [60, 111, 162, 213, 264, 315],
    'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
}

# Guard against editing one per-layer list without the others.
assert all(
    len(ssd_cfg[key]) == len(ssd_cfg['feature_maps'])
    for key in ('dbox_num', 'steps', 'min_sizes', 'max_sizes', 'aspect_ratios')
), 'every per-layer list in ssd_cfg must have one entry per feature map'

# Build the detector in training mode from the configuration above.
net = SSD(phase='train', cfg=ssd_cfg)

# Pretrained VGG16 backbone checkpoint. map_location='cpu' lets it load on
# machines without a GPU even if it was saved from CUDA tensors.
# NOTE(review): presumably these tensors are later copied into `net` via
# load_state_dict, so they need not start on `device`; verify in the rest of the cell.
# NOTE(review): torch.load unpickles the file. Only load checkpoints from a
# trusted source; on PyTorch versions that support it, consider weights_only=True.
vgg_weights = torch.load(
    '/home/masakibandai/object_detection/weights/vgg16_reducedfc.pth',
    map_location='cpu',
)