In [1]:
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import torch
import codecs
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.optim as optim
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
class FMNIST(data.Dataset):
    
    """`Fashion-MNIST <https://github.com/zalandoresearch/fashion-mnist>`_ Dataset.
    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and  ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    
    urls =[
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz',
        'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'
    
    def __init__(self, root,train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train
        
        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(os.path.join(root, self.processed_folder, self.test_file))
            
    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
        
    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)
        
    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))
    
    def download(self):
        """Download the FMNIST data if it doesn't exist in processed_folder already."""
        from six.moves import urllib
        import gzip

        if self._check_exists():
            return

        # download files
        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise

        for url in self.urls:
            print('Downloading ' + url)
            data = urllib.request.urlopen(url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            with open(file_path, 'wb') as f:
                f.write(data.read())
            with open(file_path.replace('.gz', ''), 'wb') as out_f, \
                    gzip.GzipFile(file_path) as zip_f:
                out_f.write(zip_f.read())
            os.unlink(file_path)

        # process and save as torch files
        print('Processing...')

        training_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
            read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')
        
    def get_int(b):
        return int(codecs.encode(b, 'hex'), 16)

    def parse_byte(b):
        if isinstance(b, str):
            return ord(b)
        return b

    def read_label_file(path):
        with open(path, 'rb') as f:
            data = f.read()
            assert get_int(data[:4]) == 2049
            length = get_int(data[4:8])
            labels = [parse_byte(b) for b in data[8:]]
            assert len(labels) == length
            return torch.LongTensor(labels)

    def read_image_file(path):
        with open(path, 'rb') as f:
            data = f.read()
            assert get_int(data[:4]) == 2051
            length = get_int(data[4:8])
            num_rows = get_int(data[8:12])
            num_cols = get_int(data[12:16])
            images = []
            idx = 16
            for l in range(length):
                img = []
                images.append(img)
                for r in range(num_rows):
                    row = []
                    img.append(row)
                    for c in range(num_cols):
                        row.append(parse_byte(data[idx]))
                        idx += 1
            assert len(images) == length
            return torch.ByteTensor(images).view(-1, 28, 28)


In [3]:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))])
fmnist_trainset = FMNIST(root='./data', train=True, download=True, transform=transform)
fmnist_testset = FMNIST(root='./data', train=False, download=True, transform=transform)

In [4]:
print(fmnist_trainset)
print(fmnist_testset)

<__main__.FMNIST object at 0x7fbbc31304a8>
<__main__.FMNIST object at 0x7fbbc3130630>


In [5]:
batch_size = 100
train_loader = torch.utils.data.DataLoader(
                 dataset=fmnist_trainset,
                 batch_size=batch_size,
                 shuffle=True)
test_loader = torch.utils.data.DataLoader(
                dataset=fmnist_testset,
                batch_size=batch_size,
                shuffle=False)

print('==>>> total trainning batch number: {}'.format(len(train_loader)))
print('==>>> total testing batch number: {}'.format(len(test_loader)))

==>>> total trainning batch number: 600
==>>> total testing batch number: 100


In [6]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1 = nn.Conv2d(1,20,5,1)
        self.conv2 = nn.Conv2d(20,50,5,1)
        self.fc1 = nn.Linear(4*4*50,500)
        self.fc2 = nn.Linear(500,10)
    def forward(self, y):
        y = F.relu(self.conv1(y))
        y = F.max_pool2d(y, 2, 2)
        y = F.relu(self.conv2(y))
        y = F.max_pool2d(y, 2, 2)
        y = y.view(-1, 4*4*50)
        y = F.relu(self.fc1(y))
        y = self.fc2(y)
        return y

In [None]:
model = CNN()
criterion = nn.CrossEntropyLoss()
optimiser = optim.Adam(model.parameters(),lr=0.001)

In [None]:
train_loss_history = []
for epoch in range(20):
    # training
    ave_loss = 0
    for batch_idx, (x, target) in enumerate(train_loader):
        optimiser.zero_grad()
        x, target = Variable(x), Variable(target)
        out = model(x)
#         print(out.shape)
#         print(target.shape)
        loss = criterion(out, target)
        loss.backward()
        train_loss_history.append(loss)
        optimiser.step()
        if (batch_idx+1) % 600 == 0 or (batch_idx+1) == len(train_loader):
            print('epoch: '+str(epoch+1)+' train loss: '+str(loss.item()))
    # testing
    correct_cnt, ave_loss = 0, 0
    total_cnt = 0
    for batch_idx, (x, target) in enumerate(test_loader):
        x, target = Variable(x, volatile=True), Variable(target, volatile=True)
        out = model(x)
        loss = criterion(out, target)
        _, pred_label = torch.max(out.data, 1)
        total_cnt += x.data.size()[0]
#         total_cnt += 100
#         print('Prediction: {}'.format(pred_label))
#         print("Target: {}".format(target.data))
        correct_cnt += (pred_label == target.data).sum()
#         print("Correct_cnt: {}".format(correct_cnt))
        acc = (100.0 * correct_cnt) / (total_cnt)
        if(batch_idx+1) % 100 == 0 or (batch_idx+1) == len(test_loader):
            print('epoch: '+str(epoch+1)+' test loss: '+str(loss.item())+' acc: '+str(acc.item()))

epoch: 1 train loss: 0.26372596621513367




epoch: 1 test loss: 0.37559249997138977 acc: 87
epoch: 2 train loss: 0.16733808815479279
epoch: 2 test loss: 0.30838537216186523 acc: 88
epoch: 3 train loss: 0.16979366540908813
epoch: 3 test loss: 0.2317211776971817 acc: 89
epoch: 4 train loss: 0.14462095499038696
epoch: 4 test loss: 0.29656627774238586 acc: 89
epoch: 5 train loss: 0.157267764210701
epoch: 5 test loss: 0.28530892729759216 acc: 90
epoch: 6 train loss: 0.14537476003170013
epoch: 6 test loss: 0.29633498191833496 acc: 90
epoch: 7 train loss: 0.213612899184227
epoch: 7 test loss: 0.2619575262069702 acc: 90
epoch: 8 train loss: 0.179270401597023
epoch: 8 test loss: 0.2631705403327942 acc: 91
epoch: 9 train loss: 0.05928034335374832
epoch: 9 test loss: 0.2637607455253601 acc: 90
epoch: 10 train loss: 0.05857496336102486
epoch: 10 test loss: 0.30860772728919983 acc: 90
epoch: 11 train loss: 0.11295216530561447
epoch: 11 test loss: 0.37029018998146057 acc: 91
epoch: 12 train loss: 0.057852067053318024
epoch: 12 test loss: 0.31

In [None]:
plt.title('Training loss')
plt.plot(train_loss_history, 'o')
plt.xlabel('Iteration')