In [1]:
import os

# The ROCm GFX override has to live in *this* kernel process's environment.
# The original `!export HSA_OVERRIDE_GFX_VERSION=...` ran in a throwaway child
# shell, so the variable vanished as soon as that shell exited. It is presumably
# read when the ROCm/HSA runtime initializes, so set it before `import torch`
# to be safe.
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "10.3.0"

import torch
from datasets import load_dataset,load_dataset_builder
from torch.utils.data import DataLoader,default_collate
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch import tensor,nn,optim
import torchvision.transforms.functional as TF

In [2]:
# Get the dataset: first look at its schema through the dataset builder.
dataset_name = "fashion_mnist"
ds_builder = load_dataset_builder(path=dataset_name)
ds_builder.info.features

{'image': Image(decode=True, id=None),
 'label': ClassLabel(names=['T - shirt / top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'], id=None)}

In [3]:
dataset_dict = load_dataset(path=dataset_name)  # DatasetDict holding the 'train' and 'test' splits
dataset_dict

Found cached dataset fashion_mnist (/home/shubham/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/0a671f063342996f19779d38c0ab4abef9c64f757b35af8134b331c294d7ba48)


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 60000
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 10000
    })
})

In [4]:
# Split handles; the 'test' split doubles as the validation set below.
train_ds, test_ds = dataset_dict['train'], dataset_dict['test']

In [5]:
x, y = ds_builder.info.features.keys()  # column names in schema order ('image', 'label' per the schema shown above)

In [6]:
# Feature spec of the label column (`y` from the previous cell); despite the name,
# this is the label's ClassLabel, used below to map label ids to class names.
features = ds_builder.info.features[y]
features

ClassLabel(names=['T - shirt / top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'], id=None)

In [7]:
features.names[0]  # class name for label id 0

'T - shirt / top'

In [8]:
train_ds[0:5]  # first five examples, returned as a dict of column name -> list of values

{'image': [<PIL.PngImagePlugin.PngImageFile image mode=L size=28x28>,
  <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28>,
  <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28>,
  <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28>,
  <PIL.PngImagePlugin.PngImageFile image mode=L size=28x28>],
 'label': [9, 0, 0, 3, 0]}

In [9]:
def inplace(f):
    """Decorator: call `f` for its side effects on its first argument, then return that argument.

    Turns a mutating function that returns None into one that returns the
    (mutated) object, which is what `Dataset.with_transform` needs from `to_tensor`.
    Any extra positional/keyword arguments are forwarded to `f`, and the wrapper
    keeps `f`'s name and docstring.
    """
    import functools

    @functools.wraps(f)
    def f_(x, *args, **kwargs):
        f(x, *args, **kwargs)
        return x
    return f_

In [10]:
@inplace
def to_tensor(b):
    """Replace a batch's list of PIL images with one stacked tensor, in place.

    `b` is the batch dict handed over by `Dataset.with_transform`; its 'image'
    entry is a list of PIL images. Each is converted with `TF.to_tensor`, which
    already yields a (C, H, W) tensor, so no reshape is needed. The old
    `.view(-1, 28, 28)` was a no-op for 28x28 images and would silently scramble
    any other size. The stacked result has shape (N, C, H, W).
    """
    b['image'] = torch.stack([TF.to_tensor(img) for img in b['image']])

In [11]:
def collate_fn(b):
    """Collate a list of example dicts into an `(images, labels)` tuple of batched tensors."""
    batch = default_collate(b)
    images, labels = batch['image'], batch['label']
    return images, labels

In [12]:
# Default device: use the GPU whenever PyTorch can see one, otherwise the CPU.
def_device = 'cpu'
if torch.cuda.is_available():
    def_device = 'cuda'
def_device

'cuda'

In [13]:
# Hyperparameters
bs = 128      # mini-batch size for both DataLoaders
lr = 1e-3     # SGD learning rate
epochs = 5    # number of training epochs
seed = 42     # fixed so weight init and shuffling are reproducible across reruns
torch.manual_seed(seed)

In [14]:
# Training batches are reshuffled every epoch. Evaluation metrics don't depend on
# order, so the validation loader keeps a fixed, deterministic order.
train_dl = DataLoader(train_ds.with_transform(to_tensor), batch_size=bs, shuffle=True, collate_fn=collate_fn)
valid_dl = DataLoader(test_ds.with_transform(to_tensor), batch_size=bs, shuffle=False, collate_fn=collate_fn)

In [15]:
def get_model(in_channels=1, n_classes=10, widths=(32, 64, 128, 256)):
    """Build a small all-convolutional classifier.

    Each entry of `widths` adds a 3x3, stride-2, padding-1 conv followed by ReLU,
    which roughly halves the spatial size (rounding up). The last feature map is
    global-average-pooled, flattened and mapped to `n_classes` logits by a linear
    layer. The defaults reproduce the original network layer-for-layer (same
    module order and indices), so existing state_dicts still load.

    Args:
        in_channels: channels of the input images (1 for grayscale).
        n_classes: number of output logits.
        widths: output channels of each conv stage, in order.
    """
    layers = []
    c_in = in_channels
    for c_out in widths:
        layers += [nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1), nn.ReLU()]
        c_in = c_out
    layers += [nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(c_in, n_classes)]
    return nn.Sequential(*layers)

In [16]:
# Move the model to the target device *before* the optimizer is built over its
# parameters in the next cell, as the torch.optim docs recommend.
model = get_model().to(def_device)

In [17]:
opt = optim.SGD(params=model.parameters(), lr=lr)  # plain SGD (no momentum) using the configured learning rate

In [18]:
def to_device(x, device=def_device):
    """Recursively move every tensor inside `x` to `device`.

    - Tensors are moved with `.to(device)`.
    - Namedtuples, lists, tuples and sets are rebuilt with the same type from
      their converted elements.
    - Dicts are rebuilt (as plain dicts) with converted values; keys are untouched.
    - Anything else (ints, strings, None, ...) is returned unchanged. Previously a
      string recursed forever, a dict lost its values, and a bare int raised.
    """
    if isinstance(x, torch.Tensor):
        return x.to(device)
    if isinstance(x, dict):
        return {k: to_device(v, device) for k, v in x.items()}
    if isinstance(x, tuple) and hasattr(x, '_fields'):
        # namedtuple constructors take fields positionally, not a single iterable
        return type(x)(*(to_device(o, device) for o in x))
    if isinstance(x, (list, tuple, set, frozenset)):
        return type(x)(to_device(o, device) for o in x)
    return x

In [19]:
# Sanity check: push one training batch through the model and look at the shapes.
batch = next(iter(train_dl))
xb, yb = to_device(batch)
print(xb.shape, yb.shape, sep='\n')
model.to(def_device)  # Module.to moves parameters in place and returns the same module
preds = model(xb)
preds.shape

: 

: 

In [None]:
def fit(epochs, model, opt, train_dl, test_dl, device=def_device, loss_func=F.cross_entropy):
    """Train `model` for `epochs` epochs, evaluating on `test_dl` after each one.

    Args:
        epochs: number of passes over `train_dl`.
        model: the network. It is not moved here, so it should already be on `device`.
        opt: optimizer over `model`'s parameters.
        train_dl: iterable of `(xb, yb)` training batches.
        test_dl: iterable of `(xb, yb)` evaluation batches.
        device: where each batch is moved before the forward pass.
        loss_func: `loss_func(preds, targets)` returning a mean-reduced scalar loss.
            Defaults to `F.cross_entropy`, the original behaviour.

    Prints `epoch, mean_eval_loss, eval_accuracy` once per epoch. If `test_dl`
    yields no samples, both metrics are NaN instead of raising ZeroDivisionError.
    """
    for epoch in range(epochs):
        _train_one_epoch(model, opt, train_dl, device, loss_func)
        eval_loss, eval_acc = _evaluate(model, test_dl, device, loss_func)
        print(epoch, eval_loss, eval_acc)


def _train_one_epoch(model, opt, dl, device, loss_func):
    """Run one optimisation pass over `dl` with the model in train mode."""
    model.train()
    for xb, yb in dl:
        xb, yb = to_device(xb, device), to_device(yb, device)
        loss = loss_func(model(xb), yb)
        loss.backward()
        opt.step()
        opt.zero_grad()


def _evaluate(model, dl, device, loss_func):
    """Return (sample-weighted mean loss, accuracy) of `model` over `dl`; (nan, nan) if `dl` is empty."""
    model.eval()
    tot_loss, tot_acc, count = 0., 0., 0
    with torch.no_grad():
        for xb, yb in dl:
            xb, yb = to_device(xb, device), to_device(yb, device)
            pred = model(xb)
            n = len(xb)
            count += n
            # loss_func is mean-reduced, so weight by batch size to get a per-sample mean overall
            tot_loss += loss_func(pred, yb).item() * n
            tot_acc += (pred.argmax(dim=1) == yb).float().sum().item()
    if count == 0:
        return float('nan'), float('nan')
    return tot_loss / count, tot_acc / count

In [None]:
fit(epochs, model.to(def_device), opt, train_dl, valid_dl)  # use the configured epoch count, not a hard-coded 5

: 

: 