In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import time
import pickle
import string
import os
import random
from PIL import Image
# Enable cuDNN benchmark mode for the whole session.
torch.backends.cudnn.benchmark = True

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [None]:
prev_time = None  # perf_counter reading taken at the end of the previous benchmark() call; None until a baseline exists
gamma = 0.99  # weight kept by the old average on every ewma update
stats = {}  # tracks ewma running average, keyed by checkpoint name
def benchmark(point=None, profile=True): # not thread safe at all
    """Record and print the time elapsed since the previous ``benchmark`` call.

    Args:
        point: Name of the checkpoint that was just reached. When ``None`` the
            call only (re)sets the baseline and records nothing.
        profile: When falsy the call is a complete no-op, so call sites can
            leave their benchmark calls in place and switch them off.

    Side effects:
        Updates the module-level ``prev_time`` and ``stats`` and prints one
        line per recorded checkpoint.
    """
    global prev_time
    if not profile:
        return
    # time.perf_counter() has an undefined reference point, so a reading is only
    # meaningful relative to an earlier reading. Without a baseline there is
    # nothing valid to measure; skip the sample instead of feeding a huge bogus
    # value into the (slowly decaying) running average.
    if point is not None and prev_time is not None:
        time_taken = time.perf_counter() - prev_time
        if point in stats:
            stats[point] = stats[point]*gamma + time_taken*(1-gamma)
        else:
            # the first sample seeds the running average
            stats[point] = time_taken
        print(f"took {time_taken} to reach {point}, ewma={stats[point]}")
    # re-read the clock after printing so the print itself is not billed to the next checkpoint
    prev_time = time.perf_counter()

In [None]:
class ColorDatasetGenerator(Dataset):
    """Dataset that generates noisy images of coloured circles on the fly.

    Every sample starts as an all-zero ``(size, size, channels)`` array, gets
    ``num_objects`` filled circles of random colour painted onto it and is then
    sprinkled with square noise patches. The label is a vector of length
    ``num_classes`` with a 1 at every index returned by ``color_classifier``.

    Any default set in ``__init__`` can be overridden through a keyword
    argument of the same name; the overridden names are collected in
    ``self.options``.
    """

    def __init__(self, **kwargs):
        """Set the default parameters, then apply the ``kwargs`` overrides."""
        # define default values for parameters, can override any using kwargs
        # noise parameters
        self.num_noise = (50, 70)  # number of locations to generate noise at
        self.noise_size = (5, 10)  # size that each noise instance can be

        # image parameters
        self.size = 256 # shape of image
        self.channels = 1  # default is greyscale

        # target parameters
        self.color_classifier = None  # function that maps colors to classes (supports iterables)
        self.num_classes = 2  # how many possible classes there are
        self.color_range = (50, 200)  # range of values that the color-to-be-classified can be
        self.num_objects = 1 # for multiclass problems, if we want to classify multiple things

        # actual dataset
        self.num_images = 10   # totally arbitrary
        self.options = []  # list of dataset options that need to be saved
        for k, v in kwargs.items():
            setattr(self, k, v)
            self.options.append(k)

        # range of possible radii for circles; derived only after the overrides
        # so that a custom `size` yields matching radii instead of the ones
        # computed for the default size
        if "radius" not in kwargs:
            self.radius = (self.size//6, self.size//3)

    def add_target(self, arr):
        """Paint the target circle(s) onto ``arr`` in place and return the label.

        Args:
            arr: Image array indexed as ``arr[row, col]`` with a trailing
                channel axis; modified in place.

        Returns:
            Array of length ``num_classes`` holding 1 at each index returned by
            ``color_classifier`` and 0 elsewhere.

        Raises:
            TypeError: if ``color_classifier`` has not been set.
        """
        if self.color_classifier is None:
            raise TypeError("color_classifier must be set to a callable that maps colors to class indices")

        num_objects = self.num_objects # np.random.randint(*self.num_objects)
        colors = np.random.randint(*self.color_range, (num_objects, self.channels)) # one row of channel values per object

        label = np.zeros((self.num_classes))
        label[self.color_classifier(colors)] = 1  # multi-hot encoded

        # probably should make sure they dont overlap too much, but num_objects=1 for now
        radii = np.random.randint(*self.radius, (num_objects))
        # keep every centre at least the largest radius away from the border so circles stay inside the image
        locations = np.random.randint(self.radius[1], self.size-self.radius[1], (num_objects, 2))
        # each circle is painted with its own colour row
        for radius, location, color in zip(radii, locations, colors):
            for x in range(int(radius)):
                # length of the chord at distance x from the centre
                height = 2*int(np.sqrt(radius**2 - x**2))
                y_coords = np.arange(height) - height//2 + location[1]
                # fill the two lines at offsets +x and -x from the centre
                arr[location[0]+x, y_coords] = color
                arr[location[0]-x, y_coords] = color
        return label

    def add_noise(self, arr):
        """Overwrite random square patches of ``arr`` (in place) with random colours."""
        num_noise = np.random.randint(*self.num_noise)
        sizes = np.random.randint(*self.noise_size, num_noise)
        colors = np.random.randint(1, 255, (num_noise, self.channels))
        # patch origins are kept noise_size[1] away from every border
        locations = np.random.randint(self.noise_size[1], self.size-self.noise_size[1], (num_noise, 2))
        for size, color, location in zip(sizes, colors, locations):
            arr[location[0]:location[0]+size,location[1]:location[1]+size] = color

    def generate_one(self, profile=False):
        """Generate one ``(image, label)`` pair.

        Args:
            profile: forwarded to ``benchmark``; when truthy each generation
                stage is timed.
        """
        benchmark(profile=profile)
        img = np.zeros((self.size, self.size, self.channels))
        benchmark("initialization", profile)
        label = self.add_target(img)
        benchmark("circle", profile)
        self.add_noise(img)  # noise is drawn last, so it can cover parts of the circle
        benchmark("noise", profile)
        return img, label

    def __len__(self):
        """Return the configured number of images."""
        return self.num_images

    def __getitem__(self, idx):
        """Generate the sample for ``idx``.

        The global NumPy random state is reseeded with ``idx`` on every call,
        so a given index always produces the same sample.

        Returns:
            Dict with keys ``'image'``, ``'label'`` and ``'idx'``, passed
            through ``self.transform`` when that attribute exists.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        np.random.seed(idx)  # to make results repeatable
        image, label = self.generate_one()
        sample = {'image': image, 'label': label, "idx": idx}
        if hasattr(self, "transform"):
            sample = self.transform(sample)
        return sample

In [None]:
def color_classifier(color):
    """Map a colour value to a class index: 1 when ``color >= 150``, else 0.

    ``color`` must compare to a single truth value (a scalar or a one-element
    array).

    NOTE(review): the previous comment described the valid circle colours as
    [100, 200] "split in the middle", while ColorDatasetGenerator defaults
    color_range to (50, 200) -- confirm the intended threshold.
    """
    threshold = 150
    return 1 if color >= threshold else 0

In [None]:
# Build the generator with its default parameters, labelling circles with color_classifier.
color_dataset = ColorDatasetGenerator(color_classifier=color_classifier)

# Shuffled mini-batches of 4 samples, loaded by 4 worker processes with pinned memory.
color_dataloader = DataLoader(
    dataset=color_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Preview the sample at index 5 in greyscale. __getitem__ reseeds NumPy with
# the index, so this always shows the same generated image.
plt.imshow(color_dataset[5]["image"], cmap="gray")