forked from i-deal/MLR-2.0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dataloader__.py
executable file
·60 lines (50 loc) · 1.77 KB
/
dataloader__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
from matplotlib.pyplot import imread
from torch import Tensor
from torchvision import transforms
"""
Loads the train/test set.
Every image in the dataset is 28x28 pixels and the labels are numbered from 0-9
for A-J respectively.
Set root to point to the Train/Test folders.
"""
# Creating a sub class of torch.utils.data.dataset.Dataset
class notMNIST(Dataset):
# The init method is called when this class will be instantiated.
def __init__(self, root, transform = None):
Images, Y = [], []
folders = os.listdir(root)
self.transform=transform
for folder in folders:
folder_path = os.path.join(root, folder)
for ims in os.listdir(folder_path):
try:
img_path = os.path.join(folder_path, ims)
Images.append(np.array(imread(img_path)))
Y.append(ord(folder) - 65) # Folders are A-J so labels will be 0-9
except:
# Some images in the dataset are damaged
print("File {}/{} is broken".format(folder, ims))
data = [(x, y) for x, y in zip(Images, Y)]
self.data = data
# The number of items in the dataset
def __len__(self):
return int(len(self.data)/4)
# The Dataloader is a generator that repeatedly calls the getitem method.
# getitem is supposed to return (X, Y) for the specified index.
def __getitem__(self, index):
img = self.data[index][0]
# 8 bit images. Scale between [0,1]. This helps speed up our training
img = img.reshape(28, 28)
img= Image.fromarray(img*255)
if self.transform is not None:
img = self.transform(img)
#img = img.reshape(28, 28) / 255.0
# Input for Conv2D should be Channels x Height x Width
#img_tensor = Tensor(img).view(1, 28, 28).float()
label = self.data[index][1]
return (img, label)