# Prepare Data

In [15]:
import numpy as np
import torch
import torchvision as tv
from torchvision import datasets,transforms
from torch.utils.data import Dataset,random_split
from glob import glob
import os
from sklearn.model_selection import train_test_split

In [16]:
path = '/Users/Aymanjabri/notebooks/Artwork/data/images/images'
x = glob(path+'/**/*')

classes = [os.path.basename(i) for i in glob(path+'/**')]
targets = np.arange(0,50)
class_to_idx=dict(zip(classes,targets))

In [17]:
d = [os.path.basename(os.path.dirname(i)) for i in x]

In [18]:
y = [class_to_idx[i] for i in d]

In [19]:
sorted(Counter(y).items())

[(0, 84),
 (1, 128),
 (2, 702),
 (3, 43),
 (4, 291),
 (5, 99),
 (6, 259),
 (7, 49),
 (8, 194),
 (9, 255),
 (10, 90),
 (11, 119),
 (12, 181),
 (13, 81),
 (14, 87),
 (15, 31),
 (16, 134),
 (17, 188),
 (18, 311),
 (19, 73),
 (20, 239),
 (21, 164),
 (22, 81),
 (23, 126),
 (24, 47),
 (25, 91),
 (26, 139),
 (27, 70),
 (28, 88),
 (29, 117),
 (30, 877),
 (31, 59),
 (32, 193),
 (33, 186),
 (34, 120),
 (35, 439),
 (36, 24),
 (37, 336),
 (38, 102),
 (39, 141),
 (40, 67),
 (41, 55),
 (42, 137),
 (43, 171),
 (44, 109),
 (45, 262),
 (46, 143),
 (47, 70),
 (48, 66),
 (49, 328)]

### Split the dataset between training and validation

#### 1. Using Sklearn

In [20]:
#import `train_test_split` from sklearn

X_train,X_test,y_train,y_test = train_test_split(x,y,test_size=0.2)

In [21]:
len(y_train),len(y_test)

(6756, 1690)

#### 2. Using Pytorch

`from torch.utils.data import random_split`

##### Create the Dataset as a subclass of torchvision `Dataset`

In [22]:
class ArtworkSet(Dataset):
    def __init__(self,x,y,class_to_idx,classes,transform=None):
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.paths = x
        self.targets = y
        self.transform=transform
        
    def __len__(self):
        return len(self.targets)
    
    def __getitem__(self,idx):
        img,label = Image.open(self.paths[idx][0]),self.targets[idx]
        if img.getbands()[0] == 'L':
            img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img,label
        

In [23]:
tfms = transforms.Compose([transforms.Resize((50,50)),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])

ds = ArtworkSet(x,y,class_to_idx=class_to_idx,classes=classes,transform=tfms)

In [24]:
len(ds)

8446

In [27]:
trainset,validset = random_split(dataset=ds,lengths=[6757,1689])

In [28]:
len(trainset),len(validset)

(6757, 1689)