In [None]:
import torch
from torch.nn.functional import relu
import torchvision
import torchvision.transforms as transforms

class SimpleNet(torch.nn.Module):
    """Two-layer fully connected network with a ReLU between the layers.

    Args:
        D_in: Number of input features accepted by the first linear layer.
        H: Number of units produced by the first layer (hidden width).
        D_out: Number of features produced by the second linear layer.
    """

    def __init__(self, D_in, H, D_out):
        super().__init__()
        # Layers are created in this order so parameter initialisation
        # consumes the random number generator in a fixed sequence.
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """Apply ``linear1``, then ReLU, then ``linear2`` to ``x``."""
        hidden = relu(self.linear1(x))
        return self.linear2(hidden)

# Network under training: 30 inputs, 20 hidden units, 1 output.
f = SimpleNet(30, 20, 1)

optimizer = torch.optim.Adam(f.parameters(), lr=5e-3)
losses = []
batch_size = 200

for epoch in range(1000):
    # Draw a fresh batch of standard-normal inputs on every iteration.
    x = torch.randn(batch_size, 30)

    # Forward pass through the network.
    y = f(x)

    # Target: the mean of each input row, kept as a column.
    true_y = x.mean(dim=1, keepdim=True)

    # Sum of squared errors over the whole batch.
    loss = (y - true_y).pow(2).sum()

    # Backpropagate, then record the loss value for plotting.
    loss.backward()
    losses.append(loss.detach().cpu().numpy())

    # Apply the optimizer update, then clear gradients so they do not
    # accumulate into the next iteration.
    optimizer.step()
    optimizer.zero_grad()

# Plot the recorded loss values on a logarithmic y-axis.
import matplotlib.pyplot as plt

plt.plot(losses)
plt.yscale('log')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.show()