In [31]:
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils import data
from tqdm import tqdm

In [5]:
def data_generation(true_w, true_b, n):
    x = torch.normal(0, 1, (n, len(true_w)))
    y = torch.matmul(x, true_w) + true_b
    y += torch.normal(0, 0.01, y.shape)
    return x, y.reshape(-1,1)
true_w = torch.tensor([4.2, -3.0])
true_b = torch.tensor([2.0])
n = 1000

In [6]:
train_data, labels = data_generation(true_w, true_b, 1000)

In [15]:
dataset = data.TensorDataset(*(train_data, labels))
data_iter = data.DataLoader(dataset, 10, shuffle=True, num_workers=4)

In [18]:
next(iter(data_iter))

[tensor([[ 0.0220, -0.0790],
         [-0.1159, -0.3688],
         [-0.5426, -0.4635],
         [-1.2399, -0.5248],
         [-0.3826, -0.5305],
         [ 0.9733, -1.3832],
         [ 0.0731, -0.9724],
         [ 0.3823,  0.1566],
         [ 0.0750, -1.8513],
         [ 0.1444, -0.4293]]),
 tensor([[ 2.3363],
         [ 2.6395],
         [ 1.0963],
         [-1.6404],
         [ 1.9767],
         [10.2361],
         [ 5.2322],
         [ 3.1194],
         [ 7.8636],
         [ 3.8911]])]

In [30]:
net = nn.Sequential(nn.Linear(2, 1))
net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)

tensor([0.])

In [32]:
loss = nn.MSELoss()
optimizer = optim.SGD(net.parameters(), lr=0.03)

In [33]:
epochs = 10
for epoch in range(epochs):
    for x, y in tqdm(data_iter):
        l = loss(net(x), y)
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
    l = loss(net(train_data), labels)
    print(f'epoch: {epoch+1}, loss: {l:f}')

100%|██████████| 100/100 [00:01<00:00, 77.14it/s]


epoch: 1, loss: 0.000194


100%|██████████| 100/100 [00:00<00:00, 104.90it/s]


epoch: 2, loss: 0.000106


100%|██████████| 100/100 [00:00<00:00, 106.39it/s]


epoch: 3, loss: 0.000106


100%|██████████| 100/100 [00:00<00:00, 100.10it/s]


epoch: 4, loss: 0.000106


100%|██████████| 100/100 [00:00<00:00, 108.96it/s]


epoch: 5, loss: 0.000106


100%|██████████| 100/100 [00:00<00:00, 107.03it/s]


epoch: 6, loss: 0.000106


100%|██████████| 100/100 [00:00<00:00, 105.19it/s]


epoch: 7, loss: 0.000106


100%|██████████| 100/100 [00:00<00:00, 101.04it/s]


epoch: 8, loss: 0.000106


100%|██████████| 100/100 [00:00<00:00, 102.54it/s]


epoch: 9, loss: 0.000107


100%|██████████| 100/100 [00:00<00:00, 103.75it/s]

epoch: 10, loss: 0.000107



