In [11]:
import os
import math
import torch
from torch.autograd import Variable
from torch.optim import Adam
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, random_split, Dataset
from PIL import Image

In [12]:
# Number of samples per mini-batch used when building the DataLoaders.
BATCH_SIZE = 20

# One label per entry in the 'icons' directory: the file name without its
# extension, in sorted order. splitext is used instead of slicing off the last
# four characters so extensions that are not exactly three letters long
# (e.g. '.jpeg') are stripped correctly.
LABELS = sorted(os.path.splitext(name)[0] for name in os.listdir('icons'))

In [20]:
class IconsDataset(Dataset):
    """Multi-label image dataset whose targets are encoded in the file names.

    Every file in ``directory`` is one sample. The file name without its
    extension is split on ``'_'``; each resulting token that appears in
    ``labels`` switches on the matching position of the target vector.

    Args:
        directory: Folder containing the image files.
        labels: Ordered sequence of label names; it defines the length and
            the position meaning of every target vector.
        transform: Optional callable applied to the loaded PIL image.
    """

    def __init__(self, directory, labels, transform=None):
        self.directory = directory
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping stable across runs and machines.
        self.files = sorted(os.listdir(directory))
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        ``target`` is a ``torch.long`` tensor with one 0/1 entry per name in
        ``self.labels``.
        """
        filename = self.files[index]

        # Image.open is lazy and keeps the file handle open; the context
        # manager closes it once convert() has read the pixel data.
        # Converting to RGB keeps the channel count in line with input_size().
        path = os.path.join(self.directory, filename)
        with Image.open(path) as source:
            image = source.convert('RGB')
        if self.transform:
            image = self.transform(image)

        # Encode against the labels handed to this instance rather than a
        # module-level global, so the dataset works with any label list.
        present = set(os.path.splitext(filename)[0].split('_'))
        target = [1 if label in present else 0 for label in self.labels]

        return image, torch.tensor(target, dtype=torch.long)

    def input_size(self):
        """Return the hard-coded flattened sample size, 100 * 100 * 3.

        NOTE(review): presumably 100x100 pixel images with 3 colour
        channels; nothing here checks the files actually have that size.
        """
        return 100 * 100 * 3

    def __len__(self):
        """Return the number of files found in the directory."""
        return len(self.files)


def load_dataset(directory='images/', validation_fraction=0.2):
    """Build the training and validation loaders for the icon images.

    Args:
        directory: Folder holding the image files (defaults to 'images/').
        validation_fraction: Share of the samples held out for validation.
            The default of 0.2 gives the 800/200 split for a 1000-image
            folder.

    Returns:
        A tuple ``(input_size, train_loader, validation_loader)`` where
        ``input_size`` is the value reported by ``IconsDataset.input_size()``.

    Raises:
        ValueError: If ``validation_fraction`` is outside ``[0, 1]``.
    """
    if not 0 <= validation_fraction <= 1:
        raise ValueError('validation_fraction must be between 0 and 1')

    d = IconsDataset(directory, LABELS, transform=transforms.ToTensor())

    # Derive the split from the real dataset size: random_split raises when
    # the requested lengths do not add up to len(d), so a fixed [800, 200]
    # only works for a folder with exactly 1000 files.
    validate_count = int(len(d) * validation_fraction)
    train_count = len(d) - validate_count
    train, validate = random_split(d, [train_count, validate_count])

    # Reshuffle the training samples every epoch; validation order is kept
    # fixed because it does not influence the evaluation result.
    loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True)
    validation_loader = DataLoader(validate, batch_size=BATCH_SIZE)

    return d.input_size(), loader, validation_loader

# Build the loaders once for the rest of the notebook:
# s = input size reported by the dataset, l = training DataLoader,
# v = validation DataLoader.
s, l, v = load_dataset()