In [2]:
import torch
import torch.nn as nn
from torchview import draw_graph

In [32]:
class TimesTwo(nn.Module):
    """Single-weight, bias-free linear model ``y = w * x``.

    The training loop below fits it to the target ``x * 2``. The weight starts
    at ``init_weight``: 2.1 by default, which is close to but not exactly 2.

    Args:
        init_weight: initial value of the single weight. Defaults to 2.1.
    """

    def __init__(self, init_weight: float = 2.1):
        super().__init__()
        self.l1 = nn.Linear(1, 1, bias=False)
        # Initialise via nn.init instead of mutating ``.data``: writing to
        # ``.data`` bypasses autograd's tracking of the parameter.
        nn.init.constant_(self.l1.weight, init_weight)

    def forward(self, x):
        """Return ``w * x``; ``x`` must have a trailing dimension of size 1."""
        return self.l1(x)

model = TimesTwo()
# Ten identical training inputs, each the float32 value 5.
ops = torch.full((10, ), 5, requires_grad=False).float()

learning_rate = 5e-3
epochs = 4

# Train: per-sample gradient descent on the absolute error |pred - 2x|,
# with the parameter update written by hand instead of using torch.optim.
model = model.train()
for epoch in range(epochs):
    print("======= Epoch: {}".format(epoch))
    for op in ops:
        # Discard the previous sample's gradients (grad -> None).
        model.zero_grad(set_to_none=True)

        target = op * 2                   # what an exact doubler would return
        pred = model(op.unsqueeze(0))     # (1,) in -> (1,) out
        loss = (pred - target).abs().squeeze(0)
        loss.backward()

        # Report the weight *before* this sample's update, plus its gradient.
        print("Pred: {:.3f} ; Target: {} ; Loss: {:.3f} ; Weight: {:.3f} ; Gradient: {:.3f}".format(
            pred.item(), target.item(), loss.item(),
            model.l1.weight.item(), model.l1.weight.grad.item()))

        # Manual SGD step; no_grad keeps the in-place update out of autograd.
        with torch.no_grad():
            for param in model.parameters():
                param -= learning_rate * param.grad

Pred: 1438.500 ; Target: 1370.0 ; Loss: 68.500 ; Weight: 2.100 ; Gradient: 685.000
Pred: -470.375 ; Target: 710.0 ; Loss: 1180.375 ; Weight: -1.325 ; Gradient: -355.000
Pred: 380.700 ; Target: 1692.0 ; Loss: 1311.300 ; Weight: 0.450 ; Gradient: -846.000
Pred: 1764.360 ; Target: 754.0 ; Loss: 1010.360 ; Weight: 4.680 ; Gradient: 377.000
Pred: 2590.965 ; Target: 1854.0 ; Loss: 736.965 ; Weight: 2.795 ; Gradient: 927.000
Pred: -44.160 ; Target: 48.0 ; Loss: 92.160 ; Weight: -1.840 ; Gradient: -24.000
Pred: -1181.640 ; Target: 1374.0 ; Loss: 2555.640 ; Weight: -1.720 ; Gradient: -687.000
Pred: 104.615 ; Target: 122.0 ; Loss: 17.385 ; Weight: 1.715 ; Gradient: -61.000
Pred: 34.340 ; Target: 34.0 ; Loss: 0.340 ; Weight: 2.020 ; Gradient: 17.000
Pred: 619.200 ; Target: 640.0 ; Loss: 20.800 ; Weight: 1.935 ; Gradient: -320.000
Pred: 2067.975 ; Target: 1170.0 ; Loss: 897.975 ; Weight: 3.535 ; Gradient: 585.000
Pred: 332.450 ; Target: 1090.0 ; Loss: 757.550 ; Weight: 0.610 ; Gradient: -545.000
P

KeyboardInterrupt: 

In [60]:
import torch.nn.functional as F


class PowerTwo(nn.Module):
    """Small ReLU MLP (1 -> hidden -> hidden -> 1).

    The training loop below fits it to the target ``x ** 2``.

    Args:
        hidden: width of both hidden layers. Defaults to 8, the original size.
    """

    def __init__(self, hidden: int = 8):
        super().__init__()
        self.l1 = nn.Linear(1, hidden)
        self.l2 = nn.Linear(hidden, hidden)
        self.l3 = nn.Linear(hidden, 1)

    def forward(self, x):
        """Map inputs with a trailing dimension of 1 to outputs with a trailing dimension of 1."""
        h = F.relu(self.l1(x))
        h = F.relu(self.l2(h))
        return self.l3(h)

# Seed before building the model so the weight init and the sampled inputs
# are reproducible on Restart & Run All.
torch.manual_seed(0)
model = PowerTwo()

learning_rate = 2e-5
epochs = 50
samples_per_epoch = 10_000

# Train: per-sample gradient descent on the squared error (pred - x**2)**2,
# with the parameter update written by hand instead of using torch.optim.
model = model.train()
for epoch in range(epochs):
    # Fresh integer inputs each epoch in [-25, 24] (randint's high is exclusive),
    # the same range as the earlier `randint(50) - 25`.
    ops = torch.randint(-25, 25, (samples_per_epoch,)).float()
    losses = []
    for op in ops:
        # Discard the previous sample's gradients (grad -> None).
        model.zero_grad(set_to_none=True)

        # Calculate the target
        target = op ** 2
        # Forward pass: (1,) in -> (1,) out
        pred: torch.Tensor = model(op.unsqueeze(0))

        loss = (pred - target).pow(2).squeeze(0)
        # Keep a plain float: appending the tensor itself would hold on to its
        # autograd history (grad_fn) for every sample until the epoch ends.
        losses.append(loss.item())
        # Backward pass
        loss.backward()
        # Manual SGD step; no_grad keeps the in-place update out of autograd.
        with torch.no_grad():
            for p in model.parameters():
                p -= learning_rate * p.grad

    print("Epoch: {} ; Loss: {:.3f}".format(epoch, sum(losses) / len(losses)))

Epoch: 0 ; Loss: 63902.809
Epoch: 1 ; Loss: 48308.859
Epoch: 2 ; Loss: 40629.336
Epoch: 3 ; Loss: 36849.414
Epoch: 4 ; Loss: 35921.590
Epoch: 5 ; Loss: 35525.324
Epoch: 6 ; Loss: 35345.301
Epoch: 7 ; Loss: 34395.766
Epoch: 8 ; Loss: 34485.566
Epoch: 9 ; Loss: 34868.395


KeyboardInterrupt: 