In [1]:
import torch
from plotly import graph_objects as go
from torch.utils.data import TensorDataset, DataLoader

In [2]:
class SyntheticData:
    def __init__(self, w, b, variation, num_train, num_validation, batch_size):
        self.batch_size = batch_size
        self.num_train = num_train
        self.num_validation = num_validation
        self.num = num_train + num_validation
        self.X = torch.randn(self.num, len(w))
        self.y = (self.X @ w + b) + (torch.randn(self.num) * variation)

In [3]:
class Model(torch.nn.Module):
    def __init__(self, input_size, lr, decay):
        super().__init__()
        self.decay = torch.tensor(decay)
        self.net = torch.nn.Linear(input_size, 1)
        self.opt = torch.optim.SGD(self.parameters(), lr=lr)

    def forward(self, X):
        return self.net(X)

    def loss(self, y, y_hat):
        func = torch.nn.MSELoss()
        dec = torch.tensor(0.0)
        for param in self.parameters():
            dec += torch.norm(param, 2)
        return func(y, y_hat) + dec * self.decay

In [11]:
class Trainer:
    def __init__(self):
        self.train_errors = []
        self.val_errors = []

    def fit(self, model: Model, data: SyntheticData, epochs):
        for epoch in range(epochs):
            self.train_errors.append(torch.tensor(0.0))
            self.val_errors.append(torch.tensor(0.0))
            for batch_X, batch_y in DataLoader(TensorDataset(data.X[:data.num_train], data.y[:data.num_train]), batch_size=data.batch_size):
                model.opt.zero_grad()
                loss = model.loss(batch_y, torch.reshape(model.forward(batch_X), batch_y.size()))
                self.train_errors[-1] += loss
                loss.backward()
                model.opt.step()
            with torch.no_grad():
                for batch_X, batch_y in DataLoader(TensorDataset(data.X[data.num_train:data.num], data.y[data.num_train:data.num]), batch_size=data.batch_size):
                    loss = model.loss(batch_y, torch.reshape(model.forward(batch_X), batch_y.size()))
                    self.val_errors[-1] += loss

In [12]:
sdata = SyntheticData(torch.tensor([0.1, 0.1]), torch.tensor(3.0), 0.5, 100, 100, 50)
model = Model(2, 0.1, 1)
trainer = Trainer()
trainer.fit(model, sdata, 7)

In [13]:
trainer.train_errors, trainer.val_errors

([tensor(14.0283, grad_fn=<AddBackward0>),
  tensor(9.7578, grad_fn=<AddBackward0>),
  tensor(7.7947, grad_fn=<AddBackward0>),
  tensor(6.7975, grad_fn=<AddBackward0>),
  tensor(6.3760, grad_fn=<AddBackward0>),
  tensor(6.1956, grad_fn=<AddBackward0>),
  tensor(6.1393, grad_fn=<AddBackward0>)],
 [tensor(11.8324),
  tensor(8.8980),
  tensor(7.3579),
  tensor(6.5228),
  tensor(6.3061),
  tensor(6.1905),
  tensor(6.1048)])

In [7]:
torch.nn.MSELoss()(torch.tensor([1.0, 2.0]), torch.tensor([3.0, 2.0]))

tensor(2.)