In [1]:
%matplotlib inline
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import *

In [36]:
"""
Even the simple regression with 4 dimensions: approximation of an absolute value function.
"""

class RegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(nn.Linear(1, 2), nn.ReLU(), nn.Linear(2, 1))
            
    def forward(self, x, with_softmax=True):
        x = self.model(x)
        return x

def train(xs, ys):
    data_set = TensorDataset(xs, ys)
    data_loader = DataLoader(data_set, batch_size=100, shuffle=True)

    model = RegressionModel()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-1)

    for epoch in range(1000):
        cumulative_loss = 0.
        for inputs, expected in data_loader:
            optimizer.zero_grad()
            got = model(inputs)
            loss = criterion(got, expected)
            loss.backward()
            optimizer.step()
            cumulative_loss += loss.item()
        if epoch % 100 == 0:
            print(cumulative_loss)
    return model

xs = np.random.uniform(-5, 5, size=100).reshape((100,1))
xs = torch.tensor(xs, dtype=torch.float32, requires_grad=False)
ys = torch.tensor(np.abs(xs), dtype=torch.float32, requires_grad=False)
model = train(xs, ys)

6.0348711013793945
0.011241880245506763
0.0015373800415545702
0.0004480845818761736
0.00014296278823167086
4.1011255234479904e-05
1.1379737770766951e-05
4.2202777876809705e-06
1.7158727132482454e-06
6.378023726938409e-07


In [38]:
def predict(model, xs):
    xs = torch.tensor(xs, dtype=torch.float32, requires_grad=False).unsqueeze(dim=-1)
    ys = model(xs, with_softmax=True)
    return ys

predict(model, [-10, -5, -1, 0, 1, 5, 10])

tensor([[10.0002],
        [ 5.0001],
        [ 0.9999],
        [ 0.0276],
        [ 0.9999],
        [ 5.0001],
        [10.0002]], grad_fn=<ThAddmmBackward>)