In [None]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
%matplotlib inline

In [None]:
# Training split (fetched into data/ if not already present) and the held-out test split,
# both converted to tensors by ToTensor().
dataset = FashionMNIST(root='data/', train=True, download=True, transform=ToTensor())
test_dataset = FashionMNIST(root='data/', train=False, transform=ToTensor())

In [None]:
# Hold out the last `val_size` examples' worth of the training set for validation.
val_size=10000
train_size=len(dataset)-val_size
# A fixed-seed generator makes the train/val partition identical on every
# Restart & Run All; without it random_split draws from the global RNG.
split_generator=torch.Generator().manual_seed(42)
train_ds,val_ds=random_split(dataset,[train_size,val_size],generator=split_generator)
len(train_ds),len(val_ds)

In [None]:
batch_size = 2 ** 7  # 128 samples per training mini-batch

In [None]:
# Training batches are reshuffled each epoch; validation/test iterate in a fixed
# order with batches twice the training size.
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=2 * batch_size)
test_loader = DataLoader(test_dataset, batch_size=2 * batch_size)

In [None]:
# Preview a single (shuffled) training batch as an image grid, 16 images per row.
images,_=next(iter(train_loader))
print(images.shape)
# Build the grid once and reuse it for both display and the shape printout.
grid=make_grid(images,nrow=16)
plt.figure(figsize=(16,8))
plt.axis('off')
# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.imshow(grid.permute((1,2,0)))
print(grid.shape)

In [None]:
def accuracy(outputs,labels):
  """Fraction of rows whose highest-scoring column matches the corresponding label.

  Args:
    outputs: score tensor, one row per sample (the arg-max is taken over dim 1).
    labels: integer class index per sample.

  Returns:
    0-d float tensor built from a Python number, so it carries no autograd history.
  """
  predicted=outputs.max(dim=1).indices
  n_correct=(predicted==labels).sum().item()
  return torch.tensor(n_correct/len(predicted))

In [None]:
class MnistModel(nn.Module):
  """Three-layer fully connected classifier: Linear -> ReLU -> Linear -> ReLU -> Linear.

  In this notebook it is trained on FashionMNIST; it accepts any batch whose
  per-sample element count equals ``in_size`` (each sample is flattened).

  Args:
    in_size: number of input features per sample after flattening.
    out_size: number of outputs (class logits) per sample.
    hidden1: width of the first hidden layer (default 16, the original value).
    hidden2: width of the second hidden layer (default 32, the original value).
  """
  def __init__(self,in_size,out_size,hidden1=16,hidden2=32):
    super().__init__()
    self.l1=nn.Linear(in_size,hidden1)
    self.l2=nn.Linear(hidden1,hidden2)
    self.l3=nn.Linear(hidden2,out_size)

  def forward(self,xb):
    """Return raw logits of shape (batch, out_size) for the batch ``xb``."""
    # Flatten everything after the batch dimension. reshape (unlike view) also
    # accepts non-contiguous inputs, copying only when it has to.
    out=xb.reshape(xb.size(0),-1)
    out=self.l1(out)
    out=F.relu(out)
    out=self.l2(out)
    out=F.relu(out)
    out=self.l3(out)
    return out

  def training_step(self,batch):
    """Return the cross-entropy loss for one ``(images, labels)`` batch, graph attached."""
    images,labels=batch
    out=self(images)
    loss=F.cross_entropy(out,labels)
    return loss

  def validation_step(self,batch):
    """Return ``{'val_loss', 'val_acc'}`` tensors for one ``(images, labels)`` batch.

    The loss is detached: these dicts are collected for a whole epoch, and an
    attached loss would keep every batch's autograd graph alive when the caller
    does not run under ``torch.no_grad()``.
    """
    images,labels=batch
    out=self(images)
    loss=F.cross_entropy(out,labels).detach()
    acc=accuracy(out,labels)
    return {'val_loss':loss,'val_acc':acc}

  def validation_epoch_end(self,outputs):
    """Average the per-batch ``validation_step`` results into Python floats.

    This is an unweighted mean over batches, so a smaller final batch counts as
    much as a full one. ``outputs`` must be non-empty (torch.stack needs at least
    one tensor).
    """
    batch_losses=[x['val_loss'] for x in outputs]
    epoch_loss=torch.stack(batch_losses).mean()
    batch_acc=[x['val_acc'] for x in outputs]
    epoch_acc=torch.stack(batch_acc).mean()
    return {'val_loss':epoch_loss.item(),'val_acc':epoch_acc.item()}

  def epoch_end(self,epoch,result):
    """Print one progress line with the epoch's validation loss and accuracy."""
    print("Epoch[{}] ,val_loss: {:.4f},val_acc:{:.4f}".format(epoch,result['val_loss'],result['val_acc']))




In [None]:
def evaluate(model,val_loader):
  """Score ``model`` on every batch of ``val_loader`` and aggregate the results.

  The pass runs in eval mode with autograd disabled. Afterwards the model is put
  back into whichever train/eval mode it was in, so calling this inside a
  training loop does not change how training proceeds.

  Returns:
    whatever ``model.validation_epoch_end`` returns for the list of per-batch
    ``model.validation_step`` results.
  """
  was_training=model.training
  model.eval()
  try:
    # No autograd graph is needed for validation; without no_grad one graph per
    # batch would be built and kept alive while `outputs` is collected.
    with torch.no_grad():
      outputs=[model.validation_step(batch) for batch in val_loader]
      return model.validation_epoch_end(outputs)
  finally:
    model.train(was_training)

def fit(epochs,lr,model,train_loader,val_loader,opt_func=torch.optim.SGD):
  """Train ``model`` for ``epochs`` passes over ``train_loader``.

  After each epoch the model is scored on ``val_loader`` with ``evaluate``. The
  result is printed through ``model.epoch_end`` and appended to the history.

  Args:
    epochs: number of full passes over ``train_loader``.
    lr: learning rate, passed as the optimizer's second positional argument.
    model: module providing ``training_step`` and ``epoch_end`` (and whatever
      ``evaluate`` calls on it).
    train_loader: iterable of training batches.
    val_loader: iterable of validation batches.
    opt_func: optimizer factory called as ``opt_func(model.parameters(), lr)``.

  Returns:
    list with one ``evaluate`` result per epoch.
  """
  history=[]
  optimizer=opt_func(model.parameters(),lr)
  for epoch in range(epochs):
    # Explicitly (re)enter training mode each epoch: evaluation or prediction
    # code may have left the model in eval mode, which changes the behaviour of
    # layers such as dropout or batch-norm.
    model.train()
    for batch in train_loader:
      loss=model.training_step(batch)
      loss.backward()
      optimizer.step()
      optimizer.zero_grad()
    result=evaluate(model,val_loader)
    model.epoch_end(epoch,result)
    history.append(result)
  return history

In [None]:
# Each 28x28 image is flattened into 784 input features; the model predicts 10 classes.
input_size = 28 * 28
num_classes = 10

In [None]:
model = MnistModel(in_size=input_size, out_size=num_classes)

In [None]:
# Score the untrained model first so the history starts from a baseline.
history = []
history.append(evaluate(model, val_loader))
history

In [None]:
history.extend(fit(epochs=2, lr=0.2, model=model, train_loader=train_loader, val_loader=val_loader))

In [None]:
def predict_img(img,model):
  """Return the predicted class index (a Python int) for one un-batched image.

  A batch dimension of size 1 is prepended before calling ``model``, and the
  arg-max over the model's output row is returned. Autograd is disabled because
  only the prediction is needed.
  """
  xb=img.unsqueeze(0)
  with torch.no_grad():
    yb=model(xb)
  _,preds=torch.max(yb,dim=1)
  return preds[0].item()

In [None]:
# Show one fixed test image and compare its true class name with the model's prediction.
img, label = test_dataset[7843]
plt.imshow(img[0], cmap='gray')
pred = predict_img(img, model)
class_names = dataset.classes
print('label: ', class_names[label], ',predicted: ', class_names[pred])

In [None]:
evaluate(model=model, val_loader=test_loader)