In [1]:
import torch
import torch.nn as nn
from PIL import Image
from pathlib import Path
import numpy as np
from matplotlib import pyplot as plt

import os
import math

import torch.nn.functional as F

In [2]:
PATH= Path('data/mnist/')

In [3]:
kernel = np.array([-1,1])

In [37]:
def conv_batch(img, kernel):
    
    img = img.view(-1,28,28)
    out = torch.zeros(img.shape)
    
    for i in range(img.shape[0]):
        img_temp = np.pad(img[i],[(0, 0), (0, 1)],"edge")
        for j in range(img.shape[1]):

            for k in range(img.shape[2]-1):
                out[i][j][k] = abs((img_temp[j][k:k+2] * kernel).sum())
    
    return out.view(-1, 784)
def _get_files(p, fs, extensions = None):
    p = Path(p) # to support / notation
    res = [p/f for f in fs if not f.startswith(".") 
           and ((not extensions) or f'.{f.split(".")[-1].lower()}' in extensions)]
    return res
def log_softmax(x): 
    return (x.exp()/(x.exp().sum(-1,keepdim=True)) + 1e-20).log()

def validation_acc(model):
    return torch.stack([accuracy(model(xb), yb) for xb, yb in valid_dl]).mean().item()
def nll(preds, actuals): 
    return -preds[range(actuals.shape[0]), actuals].mean()
def accuracy(preds, yb): 
    return (torch.argmax(preds, dim=1, keepdim = True)==yb).float().mean()
def loss_func(preds, targets):
    preds = log_softmax(preds)
    return nll(preds, targets)
def train(model, train_dl, epochs=5, valid_epoch=5):
    for epoch in range(epochs):
        for xb, yb in train_dl:
            
            preds = model(xb)
            loss = loss_func(preds, yb.squeeze())
            loss.backward()
            optim.step()
            optim.zero_grad()
        
        if epoch % valid_epoch == 0:
            print(validation_acc(model))
            
class Dataset():
    def __init__(self, x, y): 
        self.x,self.y = x,y
    def __len__(self): 
        return len(self.x)
    def __getitem__(self, i): 
        return conv_batch(self.x[i], kernel),self.y[i]


class DataLoader():
    def __init__(self, ds, bs): 
        self.ds, self.bs = ds, bs
    def __iter__(self):
        n = len(self.ds)
        l = torch.randperm(n)

        
        for i in range(0, n, self.bs): 
            idxs_l = l[i:i+self.bs]
            yield self.ds[idxs_l]
            
class Func(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x): 
        return self.func(x)
    
def flatten(x):      
    return x.view(x.shape[0], -1)

def print_t(x):      
    print(x.shape)
    return x

In [38]:
def conv(img, kernel):
    
    out = np.zeros(img.shape)
    img = np.pad(img,[(0, 0), (0, 1)],"edge")
    
    for i in range(img.shape[0]):
    
        for j in range(img.shape[1]-1):
            out[i][j] = abs((img[i][j:j+2] * kernel).sum())
    
    return out

In [39]:
def create_ds_from_file(src):
    imgs, labels = [], []
    kernel = np.array([-1,1])
    for label in range(10):
        path = src/str(label)
        print(path)
        t = [o.name for o in os.scandir(path)]
        t = _get_files(path, t, extensions = [".jpg", ".png"])
        for e in t:
            img = np.array(Image.open(e))
            l = [np.concatenate((conv(img, kernel).reshape(-1), img.reshape(-1)))]
            imgs += l
        labels += ([label] * len(t))
    return torch.tensor(imgs,  dtype=torch.float32), torch.tensor(labels, dtype=torch.long).view(-1,1)

In [40]:
trn_raw, trn_conved = create_ds_from_file(PATH/"train")

data/mnist/train/0
data/mnist/train/1
data/mnist/train/2
data/mnist/train/3
data/mnist/train/4
data/mnist/train/5
data/mnist/train/6
data/mnist/train/7
data/mnist/train/8
data/mnist/train/9


In [41]:
trn_raw = (trn_raw-trn_raw.float().mean())/trn_raw.float().std()
trn_conved=(trn_conved-trn_conved.float().mean())/trn_conved.float().std()

In [42]:
print(trn_raw.mean(),trn_raw.std())
print(trn_conved.mean(),trn_conved.std())

tensor(-0.0001) tensor(1.)
tensor(6.4284e-07) tensor(1.)


In [43]:
def create_ds_from_file(src):
    imgs, labels = [], []
    
    for label in range(10):
        path = src/str(label)
        print(path)
        t = [o.name for o in os.scandir(path)]
        t = _get_files(path, t, extensions = [".jpg", ".png"])
        for e in t:
            l = [np.array(Image.open(e)).reshape(28*28)]
            imgs += l
        labels += ([label] * len(t))
    return torch.tensor(imgs,  dtype=torch.float32), torch.tensor(labels, dtype=torch.long).view(-1,1)

In [44]:
trn_x, trn_y = create_ds_from_file(PATH/"train")

data/mnist/train/0
data/mnist/train/1
data/mnist/train/2
data/mnist/train/3
data/mnist/train/4
data/mnist/train/5
data/mnist/train/6
data/mnist/train/7
data/mnist/train/8
data/mnist/train/9


In [45]:
val_x,val_y = create_ds_from_file(PATH/"validation")

data/mnist/validation/0
data/mnist/validation/1
data/mnist/validation/2
data/mnist/validation/3
data/mnist/validation/4
data/mnist/validation/5
data/mnist/validation/6
data/mnist/validation/7
data/mnist/validation/8
data/mnist/validation/9


In [46]:
mean = trn_x.mean()
std = trn_x.std()

trn_x=(trn_x-mean)/std
val_x = (val_x-mean)/std

In [47]:
class Dataset():
    def __init__(self, x, y): 
        self.x,self.y = x,y
    def __len__(self): 
        return len(self.x)
    def __getitem__(self, i): 
        return self.x[i].view(-1,1,28,28),self.y[i]

class DataLoader():
    def __init__(self, ds, bs): 
        self.ds, self.bs = ds, bs
    def __iter__(self):
        n = len(self.ds)
        l = torch.randperm(n)

        
        for i in range(0, n, self.bs): 
            idxs_l = l[i:i+self.bs]
            yield self.ds[idxs_l]

In [48]:
train_ds = Dataset(trn_x, trn_y)
valid_ds = Dataset(val_x,val_y)
train_dl = DataLoader(train_ds, 256)
valid_dl = DataLoader(valid_ds, 256)

In [49]:
class Func(nn.Module):
    def __init__(self, func):
        super().__init__()
        self.func = func

    def forward(self, x): 
        return self.func(x)
def flatten(x):      
    return x.view(x.shape[0], -1)

In [50]:
model = nn.Sequential(
        nn.Conv2d(1, 8, 5, padding=2,stride=2), nn.ReLU(), #14
        nn.MaxPool2d(5, stride=2),
        nn.BatchNorm2d(8),
        nn.Conv2d(8, 16, 3, padding=1,stride=2), nn.ReLU(), # 7
        nn.BatchNorm2d(16),
        nn.Conv2d(16, 32, 3, padding=1,stride=2), nn.ReLU(), # 4
        nn.BatchNorm2d(32),
        nn.Conv2d(32, 32, 3, padding=1,stride=2), nn.ReLU(), # 2
        nn.BatchNorm2d(32),
        #Func(print_t),
        nn.AdaptiveAvgPool2d(1),
        Func(flatten),
        nn.Linear(32,10)
)

In [51]:
optim = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-3)

In [52]:
train(model,train_dl,10)

0.8693677186965942
0.9527616500854492
