In [20]:
import torch
from torch import nn, optim

In [5]:
class LR(nn.Module):
    """Linear regression model: a single fully connected layer."""

    def __init__(self, input_size, output_size):
        """Create the layer mapping ``input_size`` features to ``output_size`` outputs."""
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        """Return the linear layer's prediction for the input ``x``."""
        prediction = self.linear(x)
        return prediction

In [6]:
# Fix the RNG seed so the randomly initialised layer parameters are reproducible.
torch.manual_seed(1)
# Two input features mapped to two outputs.
model = LR(input_size=2, output_size=2)

In [7]:
# Display the model's parameters (the linear layer's weight matrix and bias vector).
model.state_dict()

OrderedDict([('linear.weight',
              tensor([[ 0.3643, -0.3121],
                      [-0.1371,  0.3319]])),
             ('linear.bias', tensor([-0.6657,  0.4241]))])

In [10]:
# Two samples with two features each.
X = torch.tensor([
    [1., 2.],
    [3., 4.],
])
# Calling the module runs its forward pass on the whole batch.
Y_hat = model(X)
Y_hat

tensor([[-0.9256,  0.9508],
        [-0.8211,  1.3404]], grad_fn=<AddmmBackward0>)

In [11]:
from torch.utils.data import DataLoader, Dataset

In [18]:
class Data2D(Dataset):
    """Synthetic regression dataset with two features and two targets.

    Both feature columns hold the same 20 values from ``torch.arange(-1, 1, 0.1)``.
    Targets are ``X @ W + b`` plus Gaussian noise scaled by 0.1. The noise has one
    value per sample, broadcast across both target columns.
    """

    def __init__(self):
        grid = torch.arange(-1, 1, 0.1)
        # Both feature columns carry identical values.
        self.X = torch.stack((grid, grid), dim=1)
        self.len = self.X.shape[0]

        self.W = torch.tensor([[1., -1.], [1., -1.]])
        self.b = torch.tensor([[1., -1.]])

        # Noise-free targets.
        self.f = torch.mm(self.X, self.W) + self.b
        # One noise draw per sample, shape (len, 1), broadcast over the target columns.
        noise = 0.1 * torch.randn(self.len, 1)
        self.Y = self.f + noise

    def __getitem__(self, index):
        """Return the ``(features, target)`` pair stored at ``index``."""
        features = self.X[index]
        target = self.Y[index]
        return features, target

    def __len__(self):
        """Return the number of samples."""
        return self.len

In [21]:
# Build the data first: Data2D() draws random noise, so it must precede model creation
# to keep the random stream in the same order.
dataset = Data2D()
# Mean squared error between predictions and targets.
criterion = nn.MSELoss()
# Mini-batches of two samples, served in dataset order (no shuffling).
train_loader = DataLoader(dataset=dataset, batch_size=2, shuffle=False)
# Fresh model with two inputs and two outputs.
model = LR(input_size=2, output_size=2)
# Plain stochastic gradient descent over all model parameters.
optimizer = optim.SGD(params=model.parameters(), lr=0.01)

In [23]:
# Ten passes over the data; one SGD step per mini-batch.
for epoch in range(10):
    for x, y in train_loader:
        # Clear gradients left over from the previous step before computing new ones.
        optimizer.zero_grad()
        # Forward pass and loss for this mini-batch.
        y_hat = model(x)
        loss = criterion(y_hat, y)
        print(loss.item())
        # Backpropagate and update the parameters.
        loss.backward()
        optimizer.step()

0.4999174177646637
0.2455206662416458
0.0427880734205246
0.011138588190078735
0.11771691590547562
0.2767329514026642
0.5278823971748352
0.8770573735237122
1.3152861595153809
1.8173129558563232
0.4517742991447449
0.22386020421981812
0.040041230618953705
0.00846117828041315
0.098882295191288
0.23446592688560486
0.45234084129333496
0.7552576661109924
1.1376384496688843
1.577116847038269
0.40780729055404663
0.2039402723312378
0.03758663311600685
0.00668607372790575
0.08337852358818054
0.19890975952148438
0.38807404041290283
0.6507617831230164
0.9844090938568115
1.3691763877868652
0.36771878600120544
0.18563878536224365
0.03535854443907738
0.005580506753176451
0.07060354202985764
0.16897466778755188
0.33336520195007324
0.561063289642334
0.85218346118927
1.1890952587127686
0.3312242031097412
0.16884349286556244
0.033311255276203156
0.004965678788721561
0.06006542220711708
0.14375115931034088
0.28676387667655945
0.4840259552001953
0.7380326986312866
1.0330859422683716
0.2980530858039856
0.153