In [3]:
import time
import torch
from torch import nn
from d2l import torch as d2l
import torchvision
from torchvision import transforms

# Call the helper so figures render as SVG. Evaluating the bare attribute
# (without parentheses) only echoes the function object and configures nothing.
d2l.use_svg_display()

In [4]:
class FashionMNIST(d2l.DataModule):
    """Fashion-MNIST clothing images exposed as a d2l DataModule.

    Args:
        batch_size: minibatch size, recorded as a hyperparameter by
            ``save_hyperparameters`` (and read later as ``self.batch_size``).
        resize: target size handed to ``transforms.Resize`` before each
            image is converted to a tensor with ``transforms.ToTensor``.
    """

    def __init__(self, batch_size=64, resize=(28,28)):
        super().__init__()
        self.save_hyperparameters()
        to_tensor_pipeline = transforms.Compose([
            transforms.Resize(resize),
            transforms.ToTensor(),
        ])
        # Both splits share the same storage root, preprocessing and download flag.
        shared = dict(root=self.root, transform=to_tensor_pipeline, download=True)
        self.train = torchvision.datasets.FashionMNIST(train=True, **shared)
        self.val = torchvision.datasets.FashionMNIST(train=False, **shared)

In [8]:
data = FashionMNIST(resize=(32, 32))
# A bare expression that is not the last line of a cell is evaluated and thrown
# away, so verify the split sizes explicitly instead (Fashion-MNIST ships
# 60,000 training and 10,000 test images); this also catches a partial download.
split_sizes = (len(data.train), len(data.val))
assert split_sizes == (60000, 10000), f'unexpected split sizes: {split_sizes}'
data.train[0][0].shape

torch.Size([1, 32, 32])

In [9]:
@d2l.add_to_class(FashionMNIST)
def text_labels(self, indices):
    """Translate numeric class ids into Fashion-MNIST class names.

    Each element of ``indices`` is converted with ``int`` and used to index
    the ten class names below; a list of names is returned in input order.
    """
    class_names = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    names = []
    for idx in indices:
        names.append(class_names[int(idx)])
    return names

In [10]:
@d2l.add_to_class(FashionMNIST)
def get_dataloader(self, train):
    """Wrap the training or validation split in a ``DataLoader``.

    ``train`` selects the split and also toggles shuffling, so only the
    training data is reshuffled each epoch. Batch size and worker count come
    from ``self.batch_size`` and ``self.num_workers``.
    """
    if train:
        dataset = self.train
    else:
        dataset = self.val
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=self.batch_size,
        shuffle=train,
        num_workers=self.num_workers,
    )
    

In [11]:
# Pull a single minibatch to sanity-check the shapes and dtypes the loader yields.
X, y = next(iter(data.train_dataloader()))
print(*(field for tensor in (X, y) for field in (tensor.shape, tensor.dtype)))

torch.Size([64, 1, 32, 32]) torch.float32 torch.Size([64]) torch.int64


In [12]:
# Time one full pass over the training loader (data loading only, no model).
# perf_counter is the monotonic, high-resolution clock meant for intervals;
# iterating inside a generator keeps the batches out of the global namespace
# (the old `for X, y in ...` loop overwrote X, y from the previous cell).
tic = time.perf_counter()
num_batches = sum(1 for _ in data.train_dataloader())
f'{time.perf_counter() - tic:.2f} sec for {num_batches} batches'

'4.62 sec'

In [None]:
@d2l.add_to_class(FashionMNIST)
def visualize(self, batch, nrows=1, ncol=8, labels=None):
    """Plot images from one minibatch with their class names as titles.

    Args:
        batch: ``(X, y)`` pair as yielded by the dataloaders. ``X`` is
            squeezed along dim 1 (the single channel axis of these grayscale
            images, e.g. ``[64, 1, 32, 32]``) before plotting.
        nrows: number of rows in the image grid.
        ncol: number of columns in the image grid.
        labels: optional titles; when ``None`` or empty, ``y`` is mapped to
            class names via ``self.text_labels``.
    """
    X, y = batch
    # `None` default instead of a shared mutable `[]`; an explicitly passed
    # empty list still falls back to the generated labels as before.
    if not labels:
        labels = self.text_labels(y)
    d2l.show_images(X.squeeze(1), nrows, ncol, titles=labels)

# Show the first few images of one shuffled training minibatch.
batch = next(iter(data.train_dataloader()))
data.visualize(batch=batch)
    

In [None]:
class Classifier(d2l.Module):
    """Base class for classification models built on ``d2l.Module``.

    It inherits ``d2l.Module``'s constructor unchanged, so any keyword
    arguments of the base class can be passed through. It adds a validation
    step that plots loss and accuracy. The instance must provide ``loss`` and
    ``accuracy`` as well as the forward pass invoked via ``self(...)``.
    """

    def validation_step(self, batch):
        """Plot validation loss and accuracy for one minibatch.

        ``batch`` is a sequence whose last element holds the targets; all
        preceding elements are passed to the model as inputs.
        """
        inputs, targets = batch[:-1], batch[-1]
        Y_hat = self(*inputs)
        self.plot('loss', self.loss(Y_hat, targets), train=False)
        self.plot('accuracy', self.accuracy(Y_hat, targets), train=False)

In [None]:
@d2l.add_to_class(d2l.Module)  #@save
def configure_optimizers(self):
    """Return a plain minibatch SGD optimizer over every model parameter,
    using the learning rate stored on the module as ``self.lr``."""
    params = self.parameters()
    return torch.optim.SGD(params, lr=self.lr)

In [None]:
@d2l.add_to_class(Classifier)
def accuracy(self, Y_hat, Y, averaged=True):
    """Compute the fraction of correct predictions.

    Args:
        Y_hat: prediction scores whose last axis indexes classes; all leading
            axes are flattened so that each row is one prediction.
        Y: class labels; flattened to line up with the rows of ``Y_hat``.
            Predicted indices are cast to ``Y.dtype`` before comparison.
        averaged: if True, return the mean accuracy (a 0-dim float32 tensor);
            otherwise return the per-example 0./1. float32 tensor.
    """
    num_classes = Y_hat.shape[-1]
    predicted = Y_hat.reshape((-1, num_classes)).argmax(axis=1).type(Y.dtype)
    hits = (predicted == Y.reshape(-1)).type(torch.float32)
    if averaged:
        return hits.mean()
    return hits