In [1]:
import torch

# Demo from medium page

https://noobest.medium.com/the-ultimate-pytorch-hello-world-cbf8cdfdea7b

In [2]:
import torch
from torch import nn
from sklearn.metrics import r2_score


class MyMachine(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(2,5),
            nn.ReLU(),
            nn.Linear(5,1)
        )

    def forward(self, x):
        x = self.fc(x)
        return x


def get_dataset():
        X = torch.rand((1000,2))
        x1 = X[:,0]
        x2 = X[:,1]
        y = x1 * x2
        return X, y


def train():
    model = MyMachine()
    model.train()
    X, y = get_dataset()
    NUM_EPOCHS = 1000
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-5)
    criterion = torch.nn.MSELoss(reduction='mean')

    for epoch in range(NUM_EPOCHS):
        optimizer.zero_grad()
        y_pred = model(X)
        y_pred = y_pred.reshape(1000)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        print(f'Epoch:{epoch}, Loss:{loss.item()}')
    torch.save(model.state_dict(), 'model.h5')


def test():
    model = MyMachine()
    model.load_state_dict(torch.load("model.h5"))
    model.eval()
    X, y = get_dataset()

    with torch.no_grad():
        y_pred = model(X)
        print(r2_score(y, y_pred))


train()
test()

Epoch:0, Loss:0.41638126969337463
Epoch:1, Loss:0.3656001687049866
Epoch:2, Loss:0.3196912109851837
Epoch:3, Loss:0.2783489525318146
Epoch:4, Loss:0.24124649167060852
Epoch:5, Loss:0.20809392631053925
Epoch:6, Loss:0.17856569588184357
Epoch:7, Loss:0.15246669948101044
Epoch:8, Loss:0.12950918078422546
Epoch:9, Loss:0.10944622755050659
Epoch:10, Loss:0.09201916307210922
Epoch:11, Loss:0.07696273177862167
Epoch:12, Loss:0.06411208212375641
Epoch:13, Loss:0.05329650267958641
Epoch:14, Loss:0.04435446858406067
Epoch:15, Loss:0.03715376928448677
Epoch:16, Loss:0.03154873847961426
Epoch:17, Loss:0.027386920526623726
Epoch:18, Loss:0.02451215498149395
Epoch:19, Loss:0.02276024967432022
Epoch:20, Loss:0.021956032142043114
Epoch:21, Loss:0.021919433027505875
Epoch:22, Loss:0.02246890775859356
Epoch:23, Loss:0.023426156491041183
Epoch:24, Loss:0.024620357900857925
Epoch:25, Loss:0.025896715000271797
Epoch:26, Loss:0.027122478932142258
Epoch:27, Loss:0.028190158307552338
Epoch:28, Loss:0.02902175