In [10]:
import torch
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms import Resize, Compose
from torch.nn.utils import clip_grad_norm_
from threading import Thread, Event
from plotly import graph_objects as go
from plotly.subplots import make_subplots

In [11]:
class StoppableThread(Thread):
    """``threading.Thread`` subclass that carries a cooperative stop flag.

    ``stop()`` only sets an internal ``threading.Event``; nothing is interrupted
    forcibly. A target that wants to honour the request has to poll
    ``stopped()`` itself.
    """

    def __init__(self, *args, **kwargs):
        # Forward every Thread argument (target, name, args, daemon, ...) unchanged.
        super().__init__(*args, **kwargs)
        self._stop_event = Event()

    def stop(self):
        """Request a stop by setting the internal event."""
        self._stop_event.set()

    def stopped(self):
        """Return True once ``stop()`` has been called at least once."""
        return self._stop_event.is_set()

In [12]:
class MyData:
    """Fashion-MNIST train/validation splits served as flattened float64 tensors.

    Args:
        root: directory holding (or receiving) the dataset files.
        download: passed to ``FashionMNIST``; when True, missing files are fetched
            so the notebook also runs on a fresh machine.
    """

    def __init__(self, root="./data_torch", download=True):
        # Deliberately no ``transform``: ``data_loader`` reads the raw ``.data`` tensor,
        # which bypasses torchvision transforms entirely.
        self.train = FashionMNIST(root=root, train=True, download=download)
        self.val = FashionMNIST(root=root, train=False, download=download)

    def data_loader(self, train, batch):
        """Return a DataLoader of ``(X, y)`` batches for one split.

        ``X`` holds each image flattened to one row (28*28 = 784 columns for
        Fashion-MNIST) and cast to float64 to match the model parameters. Pixel
        values are not rescaled. ``y`` holds the split's ``targets``. The loader
        shuffles only when ``train`` is True.

        Args:
            train: True for the training split, False for the validation split.
            batch: batch size.
        """
        split = self.train if train else self.val
        # flatten(1) keeps the sample axis and merges the pixel axes; unlike
        # reshape(n, -1) it also works for an empty split.
        features = split.data.flatten(start_dim=1).to(torch.float64)
        return DataLoader(TensorDataset(features, split.targets), batch_size=batch, shuffle=train)


In [13]:
class MyModel:
    """Softmax-regression (multinomial logistic) classifier with hand-written SGD.

    Holds two leaf tensors with ``requires_grad=True``, both float64:
    ``W`` of shape ``(input_size, output_size)`` drawn from ``N(mean, std**2)``,
    and bias ``b`` of shape ``(output_size,)`` initialised to zeros. Inputs to
    ``forward`` must therefore be float64 too (matmul needs matching dtypes).

    Args:
        input_size: number of features per row of ``X``.
        output_size: number of classes.
        lr: learning rate used by ``apply_grad``.
        mean, std: parameters of the normal distribution used to initialise ``W``.
    """

    def __init__(self, input_size, output_size, lr, mean=0, std=0.01):
        self.lr = lr
        self.W = torch.normal(mean=mean, std=std, size=(input_size, output_size),
                              requires_grad=True, dtype=torch.float64)
        # Same dtype as W so the forward pass, gradients and SGD update all stay float64.
        self.b = torch.zeros(output_size, dtype=torch.float64, requires_grad=True)

    def softmax(self, X):
        """Row-wise softmax of a 2-D tensor of logits."""
        # Subtracting each row's max leaves the result unchanged mathematically but
        # keeps exp() from overflowing to inf (inf / inf would give nan).
        shifted = X - torch.amax(X, dim=1, keepdim=True)
        exps = torch.exp(shifted)
        return exps / exps.sum(dim=1, keepdim=True)

    def forward(self, X):
        """Class probabilities of shape ``(n, output_size)`` for a 2-D input ``X`` of shape ``(n, input_size)``."""
        return self.softmax(X @ self.W + self.b)

    def loss(self, y, y_hat):
        """Mean cross-entropy of integer class targets ``y`` under probabilities ``y_hat``."""
        picked = y_hat[torch.arange(y_hat.shape[0]), y]
        # A probability that underflowed to exactly 0 would make log() return -inf
        # and the backward pass produce nan. Flooring at the smallest positive normal
        # float yields a large but finite loss instead.
        picked = picked.clamp_min(torch.finfo(y_hat.dtype).tiny)
        return -torch.log(picked).mean()

    def predict(self, X):
        """Index of the most probable class for each row of ``X`` (int64)."""
        return torch.argmax(self.forward(X), dim=1)

    def zero_grad(self):
        """Reset accumulated gradients of ``W`` and ``b`` in place (no-op before the first backward)."""
        for param in (self.W, self.b):
            if param.grad is not None:
                param.grad.zero_()

    def accuracy(self, y, y_hat):
        """Fraction of positions where predicted labels ``y_hat`` equal targets ``y`` (0-d float tensor)."""
        return torch.sum(y == y_hat) / y.shape[0]

    def apply_grad(self):
        """Take one SGD step: ``param -= lr * param.grad`` for ``W`` and ``b``."""
        # In-place updates of leaf tensors that require grad are only legal with
        # autograd disabled. Entering no_grad here makes the method safe to call
        # anywhere; nesting inside a caller's no_grad is harmless.
        with torch.no_grad():
            self.W -= self.lr * self.W.grad
            self.b -= self.lr * self.b.grad

In [14]:
class MyTrainer:
    """Mini-batch gradient-descent loop that records per-epoch loss/accuracy curves.

    After ``fit``, each list holds one Python float per completed epoch:

    * ``train_errors``: mean per-sample training loss, accumulated during the epoch
      (each batch's loss is taken before that batch's update);
    * ``val_errors``: mean per-sample validation loss measured after the epoch;
    * ``train_accuracies`` / ``val_accuracies``: accuracy measured after the epoch.

    Both error curves are per-sample means, so they are on the same scale and can
    be plotted together.

    Args:
        epochs: number of passes over the training split.
        batch_size: mini-batch size for the gradient updates.
        max_grad_norm: joint L2 norm the gradients of ``[model.W, model.b]`` are clipped to.
        eval_batch_size: batch size for the no-grad evaluation passes. The metrics are
            sample-weighted, so this only changes speed (up to float rounding).
    """

    def __init__(self, epochs, batch_size, max_grad_norm=1.0, eval_batch_size=1000):
        self.train_errors = []
        self.train_accuracies = []
        self.val_errors = []
        self.val_accuracies = []
        self.epochs = epochs
        self.epoch = 0
        self.batch = batch_size
        self.max_grad_norm = max_grad_norm
        self.eval_batch = eval_batch_size

    def fit(self, data: "MyData", model: "MyModel"):
        """Train ``model`` on ``data`` for ``self.epochs`` epochs, appending one entry per metric list per epoch.

        ``data`` must provide ``data_loader(train=..., batch=...)``. ``model`` must provide
        ``forward``, ``loss``, ``zero_grad``, ``apply_grad`` and the tensors ``W`` and ``b``.
        """
        for self.epoch in range(self.epochs):
            self.train_errors.append(self._train_one_epoch(data, model))
            _, train_accuracy = self._evaluate(data, model, train=True)
            val_error, val_accuracy = self._evaluate(data, model, train=False)
            self.train_accuracies.append(train_accuracy)
            self.val_errors.append(val_error)
            self.val_accuracies.append(val_accuracy)
            print(f"epoch: {self.epoch}")

    def _train_one_epoch(self, data, model):
        """Run one SGD pass over the training loader and return the mean per-sample loss."""
        total_loss, n_samples = 0.0, 0
        for X, y in data.data_loader(train=True, batch=self.batch):
            model.zero_grad()
            loss = model.loss(y, model.forward(X))
            loss.backward()
            clip_grad_norm_([model.W, model.b], self.max_grad_norm)
            with torch.no_grad():
                model.apply_grad()
            # Weight by batch size so a short final batch is not over-counted.
            total_loss += loss.item() * len(y)
            n_samples += len(y)
        return total_loss / n_samples if n_samples else float("nan")

    def _evaluate(self, data, model, train):
        """Return ``(mean per-sample loss, accuracy)`` of ``model`` on one split, without gradients."""
        total_loss, n_correct, n_samples = 0.0, 0, 0
        with torch.no_grad():
            for X, y in data.data_loader(train=train, batch=self.eval_batch):
                y_hat = model.forward(X)
                total_loss += model.loss(y, y_hat).item() * len(y)
                # argmax over the class probabilities is the predicted label
                # (the same computation MyModel.predict performs).
                n_correct += (torch.argmax(y_hat, dim=1) == y).sum().item()
                n_samples += len(y)
        if not n_samples:
            return float("nan"), float("nan")
        return total_loss / n_samples, n_correct / n_samples

    def print_errors(self):
        """Plot loss curves (left panel) and accuracy curves (right panel) against epoch."""
        # The x-axis is derived from this trainer's recorded history (not a global
        # or self.epochs), so x and y still line up after a partial or repeated fit.
        epochs = list(range(len(self.train_errors)))
        fig = make_subplots(rows=1, cols=2, subplot_titles=("Loss", "Accuracy"))
        fig.add_trace(go.Scatter(x=epochs, y=self.train_errors, mode="lines", name="train error"), row=1, col=1)
        fig.add_trace(go.Scatter(x=epochs, y=self.val_errors, mode="lines", name="validation error"), row=1, col=1)
        fig.add_trace(go.Scatter(x=epochs, y=self.train_accuracies, mode="lines", name="train accuracy"), row=1, col=2)
        fig.add_trace(go.Scatter(x=epochs, y=self.val_accuracies, mode="lines", name="val accuracy"), row=1, col=2)
        fig.update_xaxes(title_text="epoch")
        fig.show()

In [15]:
# Readable Fashion-MNIST class names: position i holds the name for class id i.
int_to_labels = [
    't-shirt',     # 0
    'trouser',     # 1
    'pullover',    # 2
    'dress',       # 3
    'coat',        # 4
    'sandal',      # 5
    'shirt',       # 6
    'sneaker',     # 7
    'bag',         # 8
    'ankle boot',  # 9
]

In [16]:
# Run configuration. Seeding torch's global RNG makes the initial W draw
# reproducible; the training DataLoader's shuffling (no explicit generator) draws
# from the same global RNG.
SEED = 0
EPOCHS = 10
BATCH_SIZE = 10
LEARNING_RATE = 0.01
N_FEATURES = 28 * 28  # one flattened 28x28 image per row
N_CLASSES = 10

torch.manual_seed(SEED)
data = MyData()
model = MyModel(N_FEATURES, N_CLASSES, LEARNING_RATE)
trainer = MyTrainer(EPOCHS, BATCH_SIZE)

In [17]:
# Train for trainer.epochs epochs; metric histories accumulate on the trainer.
trainer.fit(data=data, model=model)

epoch: 0
epoch: 1
epoch: 2
epoch: 3
epoch: 4
epoch: 5
epoch: 6
epoch: 7
epoch: 8
epoch: 9


In [18]:
# Left panel: train/validation error per epoch; right panel: train/validation accuracy.
trainer.print_errors();