In [None]:
import torch
import torchvision

In [None]:
# Full MNIST training split, converted to tensors on access.
mnist_transform = torchvision.transforms.ToTensor()
dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=mnist_transform,
)


In [None]:
# Fix the RNG so the train/val split (and later random ops such as weight init and
# shuffling) come out identical on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Hold out 10,000 samples for validation and train on the rest (50,000 for MNIST).
val_size = 10000
train_size = len(dataset) - val_size
train_ds, val_ds = torch.utils.data.random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SEED)
)

print(len(train_ds), len(val_ds))

In [None]:
from torch.utils.data.dataloader import DataLoader

batch_size = 100

# Reshuffle the training set every epoch so SGD does not see the same batch order each pass.
# Validation metrics are averaged over the whole split, so its order is left deterministic.
train_dl = DataLoader(train_ds, batch_size, shuffle=True)
val_dl = DataLoader(val_ds, batch_size)

In [None]:
# Model dimensions: each 28x28 image is flattened into one feature vector; one logit per digit.
num_classes = 10
input_size = 784  # == 28 * 28

class MnistModel(torch.nn.Module):
    """Softmax-regression classifier: a single linear layer over flattened images.

    Layer sizes come from the module-level ``input_size`` and ``num_classes`` at
    construction time. The ``*_step`` / ``*_end`` hooks are driven by the ``fit``
    and ``evaluate`` helpers in this notebook.
    """

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(input_size, num_classes)

    def forward(self, image):
        """Flatten each sample to ``in_features`` values and return raw logits."""
        # Take the flattened size from the layer itself instead of hard-coding 784,
        # so the reshape can never drift out of sync with the layer's input size.
        image = image.reshape(-1, self.linear.in_features)
        return self.linear(image)

    def training_step(self, batch):
        """Return the cross-entropy loss (with its autograd graph) for an ``(images, labels)`` batch."""
        images, labels = batch
        out = self(images)
        loss = torch.nn.functional.cross_entropy(out, labels)
        return loss

    def validation_step(self, batch):
        """Return ``{'loss': ..., 'accuracy': ...}`` tensors for an ``(images, labels)`` batch.

        Relies on the notebook-level ``accuracy`` helper being defined before it is called.
        """
        images, labels = batch
        out = self(images)
        # Detach: validation losses are only averaged and read back with .item(),
        # so keeping every batch's autograd graph alive would just waste memory.
        loss = torch.nn.functional.cross_entropy(out, labels).detach()
        acc = accuracy(out, labels)
        return {'loss': loss, 'accuracy': acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch ``validation_step`` dicts into plain Python floats.

        Each batch counts equally. The result is the exact epoch mean only when all
        batches have the same size.
        """
        batch_losses = [x['loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()
        batch_accuracy = [x['accuracy'] for x in outputs]
        epoch_accuracy = torch.stack(batch_accuracy).mean()
        return {'loss': epoch_loss.item(), 'accuracy': epoch_accuracy.item()}

    def epoch_end(self, epoch, result):
        """Print one summary line for ``epoch`` from a ``validation_epoch_end`` result."""
        print("Epoch [{}], loss: {:.4f}, accuracy: {:.4f}".format(epoch, result['loss'], result['accuracy']))
        
# Build a fresh, untrained model and inspect its learnable parameters.
model = MnistModel()
layer = model.linear
print(layer.weight.shape, layer.bias.shape)
[p for p in model.parameters()]

In [None]:
# Push a single training batch through the untrained model to sanity-check shapes.
images, label = next(iter(train_dl))
print(images.shape)
outputs = model(images)

In [None]:
def accuracy(outputs, label):
    """Fraction of rows whose arg-max score matches the target label.

    Args:
        outputs: 2-D tensor of scores, one row per sample (arg-max is taken over dim 1).
        label: 1-D tensor of integer class indices, one per row of ``outputs``.

    Returns:
        A 0-dim float tensor in [0, 1]. It is a tensor rather than a Python float
        because ``MnistModel.validation_epoch_end`` combines per-batch values with
        ``torch.stack``.
    """
    _, preds = torch.max(outputs, dim=1)
    # Vectorised comparison; the builtin sum() previously iterated element-by-element in Python.
    return (preds == label).float().mean()

batch_accuracy = accuracy(outputs, label)
print(batch_accuracy)

In [None]:
def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
    """Train ``model`` for ``epochs`` passes over ``train_loader``.

    After every pass the model is scored on ``val_loader`` via ``evaluate``. The
    summary is printed through ``model.epoch_end`` and recorded.

    Returns:
        A list holding one ``evaluate`` result per epoch, in order.
    """
    opt = opt_func(model.parameters(), lr)
    results = []

    for ep in range(epochs):
        # Training phase: one optimizer update per batch.
        for train_batch in train_loader:
            batch_loss = model.training_step(train_batch)
            batch_loss.backward()
            opt.step()
            opt.zero_grad()

        # Validation phase.
        epoch_result = evaluate(model, val_loader)
        model.epoch_end(ep, epoch_result)
        results.append(epoch_result)

    return results

In [None]:
def evaluate(model, val_loader):
    """Run ``model`` over every batch of ``val_loader`` and aggregate the results.

    Gradient tracking is disabled because the per-batch losses are only read back
    as numbers. Building an autograd graph for every validation batch would waste
    memory.

    Returns:
        Whatever ``model.validation_epoch_end`` returns for the list of per-batch
        ``model.validation_step`` outputs.
    """
    with torch.no_grad():
        outputs = [model.validation_step(batch) for batch in val_loader]
        return model.validation_epoch_end(outputs)

In [None]:
# Baseline metrics of the untrained model on the validation split.
result0 = evaluate(model=model, val_loader=val_dl)
print(result0)

In [None]:
# Train for 3 epochs with the default optimizer (SGD); fit prints one validation summary per epoch.
history1 = fit(epochs=3, lr=0.001, model=model, train_loader=train_dl, val_loader=val_dl)

In [192]:
from PIL import Image

# Index of the test sample to display and classify (also used by the prediction code below).
x = 9

# Raw (PIL) copy of the test split, kept under its own name purely for display.
test_dataset_raw = torchvision.datasets.MNIST(root="./data", train=False, download=True)
img, label = test_dataset_raw[x]
# Upscale 10x so the 28x28 digit is visible.
img.resize((280, 280)).show()

# Tensor copy of the same split, in the format the model was trained on.
test_dataset = torchvision.datasets.MNIST(root="./data", train=False, transform=torchvision.transforms.ToTensor())

def predict_image(img, model):
    """Return the predicted class index for a single image tensor.

    ``img`` gets a leading batch dimension of size 1 before it is passed to
    ``model``. The arg-max over dim 1 of the output is returned as a Python int.
    Inference runs under ``torch.no_grad()`` because no backward pass follows.
    """
    xb = img.unsqueeze(0)
    with torch.no_grad():
        yb = model(xb)
    _, preds = torch.max(yb, dim=1)
    return preds[0].item()

# Classify the same test sample using the tensor-transformed dataset; the cell displays the predicted digit.
(img, label) = test_dataset[x]
predict_image(model=model, img=img)


1