In [2]:
import glob
import os.path as osp
import random
import numpy as np
import json
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import models, transforms

In [5]:
# Fix every random number generator used in this notebook to the same seed so
# that a fresh "Restart Kernel -> Run All" reproduces the same random draws.
SEED = 0

torch.manual_seed(SEED)   # PyTorch global generator
np.random.seed(SEED)      # NumPy legacy global generator
random.seed(SEED)         # Python standard-library generator

In [6]:
class ImageTransform():
    """Phase-dependent image preprocessing.

    Holds one torchvision pipeline per phase in ``self.data_transform``:

    * ``"train"``: random resized crop (area scale 0.5-1.0) to ``resize``,
      random horizontal flip, tensor conversion, normalization.
    * ``"val"``: resize to ``resize``, center crop to ``resize``, tensor
      conversion, normalization (no random augmentation).

    Parameters
    ----------
    resize :
        Target size passed to ``RandomResizedCrop``, ``Resize`` and
        ``CenterCrop``.
    mean, std :
        Statistics forwarded unchanged to ``transforms.Normalize``.
    """

    def __init__(self, resize, mean, std):
        # Augmenting pipeline used while training.
        train_pipeline = transforms.Compose([
            transforms.RandomResizedCrop(resize, scale=(0.5, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        # Deterministic pipeline used for validation.
        val_pipeline = transforms.Compose([
            transforms.Resize(resize),
            transforms.CenterCrop(resize),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

        self.data_transform = {"train": train_pipeline, "val": val_pipeline}

    def __call__(self, img, phase="train"):
        """Apply the pipeline registered for ``phase`` to ``img``.

        ``phase`` must be a key of ``self.data_transform`` ("train" or "val");
        any other value raises ``KeyError``.
        """
        pipeline = self.data_transform[phase]
        return pipeline(img)
    

In [None]:
def make_datapath_list(rootpath):
    """Build image and annotation path lists for the train and val splits.

    The dataset under ``rootpath`` is expected to be laid out as::

        rootpath/JPEGImages/<id>.jpg
        rootpath/Annotations/<id>.xml
        rootpath/ImageSets/Main/train.txt   (one file id per line)
        rootpath/ImageSets/Main/val.txt     (one file id per line)

    Parameters
    ----------
    rootpath : str
        Dataset root directory, with or without a trailing path separator.

    Returns
    -------
    tuple of four lists of str
        ``(train_img_list, train_anno_list, val_img_list, val_anno_list)``;
        within a split, the image and annotation lists are index-aligned and
        follow the order of the ids in the split file.

    Raises
    ------
    OSError
        If a split file cannot be opened (e.g. ``FileNotFoundError``).
    """
    img_dir = osp.join(rootpath, 'JPEGImages')
    anno_dir = osp.join(rootpath, 'Annotations')

    def _paths_for_split(split):
        # Join with proper separators so rootpath need not end in '/'.
        id_names_file = osp.join(rootpath, 'ImageSets', 'Main', split + '.txt')

        img_list = []
        anno_list = []
        # The context manager guarantees the id file is closed.
        with open(id_names_file) as id_file:
            for line in id_file:
                file_id = line.strip()
                if not file_id:
                    # Skip blank lines (e.g. a trailing newline) instead of
                    # emitting bogus '.jpg' / '.xml' paths.
                    continue
                # Plain concatenation rather than '%'-templates, so a '%' in
                # rootpath cannot break the formatting.
                img_list.append(osp.join(img_dir, file_id + '.jpg'))
                anno_list.append(osp.join(anno_dir, file_id + '.xml'))
        return img_list, anno_list

    train_img_list, train_anno_list = _paths_for_split('train')
    val_img_list, val_anno_list = _paths_for_split('val')

    return train_img_list, train_anno_list, val_img_list, val_anno_list