In [1]:
import json
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models

## Utility functions

In [2]:
class RandomHorizontalFlip:
    """Mirror an image left-to-right with probability ``p`` and mirror the
    first bounding-box column to match.

    Args:
        p: probability (0..1) that a given call flips its inputs.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img, bboxes):
        """Randomly flip ``img`` along axis 1 and mirror ``bboxes[:, 0]``.

        Args:
            img: 3-D array; axis 1 is the axis that gets reversed.
            bboxes: 2-D array whose column 0 is an x coordinate. It is
                modified in place when the flip is applied.

        Returns:
            ``(img, bboxes)``: a new contiguous flipped image plus the
            updated boxes when the flip fires, otherwise the inputs unchanged.

        NOTE(review): only column 0 is mirrored. That is a complete update
        only if column 0 is a box-centre x; corner-based layouts would need
        extra handling -- confirm the box format used by callers.
        """
        if random.random() < self.p:
            # The flip reverses axis 1, so the mirror line is defined by the
            # length of axis 1 (the width), not by shape[0].
            img_width = img.shape[1]
            img = np.ascontiguousarray(img[:, ::-1, :])

            # Reflect about the vertical centre line: x -> width - x.
            # Plain assignment (instead of ``+=`` with a float) also works
            # for integer-typed box arrays.
            bboxes[:, 0] = img_width - bboxes[:, 0]
        return img, bboxes

In [3]:
class RandomContrast:
    """Scale image intensities by a random factor with probability ``p``.

    Args:
        lower: smallest scale factor that can be drawn.
        upper: largest scale factor that can be drawn.
        p: probability (0..1) that a given call rescales its input.
    """

    def __init__(self, lower=0.5, upper=1.5, p=0.5):
        self.lower = lower
        self.upper = upper
        self.p = p

    def __call__(self, img):
        """Multiply ``img`` in place by a factor drawn from ``[lower, upper]``.

        Integer-typed numpy images cannot absorb a float factor through
        ``*=`` (numpy refuses the cast), so they are scaled in floating
        point, clipped to the dtype's range and written back with
        truncation. Every other input keeps the plain in-place multiply.

        Returns:
            The same ``img`` object, rescaled when the augmentation fires.
        """
        if random.random() < self.p:
            alpha = random.uniform(self.lower, self.upper)
            if isinstance(img, np.ndarray) and np.issubdtype(img.dtype, np.integer):
                limits = np.iinfo(img.dtype)
                # Clip before casting so bright pixels saturate instead of
                # wrapping around.
                scaled = np.clip(img * alpha, limits.min, limits.max)
                img[...] = scaled.astype(img.dtype)
            else:
                img *= alpha
        return img

In [5]:
from collections import namedtuple

In [6]:
# Record describing one image together with all of its annotations.
ImageEntry = namedtuple(
    "ImageEntry",
    (
        "filename",        # image file name
        "width",           # image width
        "height",          # image height
        "classnames",      # class name of each annotated object
        "class_id",        # class id of each annotated object
        "bounding_boxes",  # bounding box of each annotated object
    ),
)


In [None]:
def load_pascal(json_path):
    """Read an annotation JSON file and group its annotations per image.

    The file must contain top-level ``"images"``, ``"annotations"`` and
    ``"categories"`` lists. Annotations are joined to images through
    ``annotation["image_id"] == image["id"]`` and grouped by the image's
    ``file_name``.

    Args:
        json_path: path of the JSON annotation file.

    Returns:
        ``(id_classname, grouped_data)`` where ``id_classname`` maps each
        category id to its name and ``grouped_data`` is a list with one
        ``ImageEntry`` per file name, in sorted file-name order (the default
        ``groupby`` ordering).

    Raises:
        KeyError: if an annotation refers to a category id that is not
            listed under ``"categories"``.
    """
    # Context manager so the file handle is closed even if parsing fails.
    with open(json_path) as json_file:
        json_data = json.load(json_file)

    images_df = pd.DataFrame(json_data["images"])
    anno_df = pd.DataFrame(json_data["annotations"])

    # Keep only the columns that are used, and rename the foreign key so it
    # lines up with the images table's "id" column for the merge below.
    anno_df = anno_df[["image_id", "bbox", "category_id"]]
    anno_df = anno_df.rename(columns={"image_id": "id"})

    id_classname = {row["id"]: row["name"] for row in json_data["categories"]}

    # Per-value lookup; indexing the dict (rather than passing it to ``map``)
    # keeps the KeyError for unknown category ids.
    anno_df["classname"] = anno_df["category_id"].map(
        lambda category_id: id_classname[category_id]
    )
    df = anno_df.merge(images_df, on="id")

    grouped_data = []
    for name, group in df.groupby("file_name"):
        grouped_data.append(
            ImageEntry(
                filename=name,
                # Width/height are taken from the group's first row.
                width=group["width"].values[0],
                height=group["height"].values[0],
                classnames=list(group["classname"].values),
                # Stored class ids are the category ids shifted down by one.
                class_id=list(group["category_id"].values - 1),
                bounding_boxes=list(group["bbox"].values),
            )
        )
    return id_classname, grouped_data