# nn.Module + Basic Training Loop Template

In [1]:
import torch

In [5]:
import torch.nn as nn

In [7]:
class SimpleNet(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        in_features: Width of each input sample. Defaults to 10.
        hidden_features: Width of the hidden layer. Defaults to 20.
        out_features: Width of each output sample. Defaults to 1.

    The defaults reproduce the original fixed 10 -> 20 -> 1 architecture,
    so ``SimpleNet()`` builds exactly the same layers as before.
    """

    def __init__(self, in_features: int = 10, hidden_features: int = 20,
                 out_features: int = 1):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Apply fc1, a ReLU, then fc2 to ``x`` and return the result.

        ``x`` must have ``in_features`` as its last dimension. Prefer calling
        the module itself (``net(x)``) over ``net.forward(x)`` so that any
        registered hooks are run.
        """
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

In [11]:
# Dummy mini-batch: 5 samples with 10 features each (fc1 expects 10 inputs).
x = torch.randn((5, 10))

In [15]:
# Bare last expression: the notebook shows the tensor's repr as the cell output.
x

tensor([[-0.8769, -1.2784, -0.9865,  0.7356,  0.9228,  0.8616,  0.9675,  0.5045,
          0.7934,  0.0200],
        [ 1.2453, -0.2592, -0.6389,  1.5470,  1.3961,  0.3868,  0.2615, -0.8619,
          0.2633, -0.3355],
        [ 0.0916, -1.9005,  0.3421, -0.2704, -0.3452,  0.2833,  0.3324, -0.4000,
         -1.2382,  0.9507],
        [-0.8899, -0.2037, -1.1679,  0.2939, -2.2016, -0.7807, -0.0649,  0.2442,
          0.9696, -1.7078],
        [-0.0347, -0.8052,  0.7615,  1.7251,  1.6015, -1.5169,  0.6730,  0.7159,
          0.2562, -1.5994]])

In [16]:
# Build an untrained SimpleNet instance for a quick forward-pass check.
sn = SimpleNet()

In [17]:
# Call the module itself rather than .forward(): nn.Module.__call__ runs any
# registered hooks before/after delegating to forward().
sn(x)

tensor([[-0.0168],
        [ 0.1729],
        [ 0.0036],
        [ 0.3208],
        [ 0.4885]], grad_fn=<AddmmBackward0>)

## Generic Training Loop

In [27]:
from torch.utils.data import TensorDataset, DataLoader

# Synthetic regression data: random features paired with random targets.
# N_FEATURES has to equal the input width of SimpleNet's first layer (10).
N_SAMPLES, N_FEATURES, N_TARGETS = 500, 10, 1
BATCH_SIZE = 4

X = torch.randn(N_SAMPLES, N_FEATURES)
y = torch.randn(N_SAMPLES, N_TARGETS)

# Pair each feature row with its target, then serve shuffled mini-batches.
ds = TensorDataset(X, y)
dl = DataLoader(dataset=ds, batch_size=BATCH_SIZE, shuffle=True)


In [28]:
# Fresh network plus plain SGD over all of its parameters.
LEARNING_RATE = 0.001

model = SimpleNet()
optim = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

In [29]:
# One pass over the shuffled mini-batches served by `dl`.
# The loop variables are named x_batch / y_batch so they do not overwrite the
# notebook-level `x` and `y` tensors created in earlier cells.
for x_batch, y_batch in dl:
    # Gradients accumulate across backward() calls, so clear them each step.
    optim.zero_grad()

    # Call the module (not .forward()) so any registered hooks are run.
    y_pred = model(x_batch)

    # Mean squared error. Predictions and targets are both (batch, 1) here;
    # mse_loss warns if their shapes ever differ instead of silently
    # broadcasting.
    loss = nn.functional.mse_loss(y_pred, y_batch)

    loss.backward()
    optim.step()

    print(loss.item())
    

3.339951515197754
1.0764541625976562
0.2505706548690796
2.0549328327178955
1.3062282800674438
2.253444194793701
1.1513415575027466
1.015817403793335
0.9929628968238831
0.3854069411754608
0.8484926819801331
4.332826137542725
0.4261985421180725
0.1397039294242859
1.1078349351882935
0.25636380910873413
0.6981172561645508
0.25205686688423157
0.8726261854171753
1.1394437551498413
0.6409541368484497
0.20757579803466797
1.5560302734375
0.4907544255256653
4.077520847320557
0.3356977105140686
1.8943662643432617
2.671379327774048
0.6920077204704285
0.4818609952926636
0.2726091146469116
1.1261212825775146
0.02798647992312908
1.6519356966018677
2.6607167720794678
0.9329882860183716
1.7732725143432617
2.2837483882904053
1.498374581336975
2.585205316543579
2.0140273571014404
1.6549054384231567
0.6530463099479675
0.17284145951271057
0.6407797336578369
1.7532594203948975
1.6897021532058716
0.8469440937042236
0.5009827613830566
0.8725787401199341
0.4125383198261261
1.7378360033035278
0.4024229049682617