In [15]:
import torch
from torch.nn import Linear

In [16]:
torch.manual_seed(1)

model = Linear(in_features=2, out_features=1)

list(model.parameters())

[Parameter containing:
 tensor([[ 0.3643, -0.3121]], requires_grad=True),
 Parameter containing:
 tensor([-0.1371], requires_grad=True)]

In [17]:
model.state_dict()

OrderedDict([('weight', tensor([[ 0.3643, -0.3121]])),
             ('bias', tensor([-0.1371]))])

In [18]:
X = torch.tensor([[1., 3.], [1., 1.]])
y_hat = model(X)
y_hat

tensor([[-0.7090],
        [-0.0848]], grad_fn=<AddmmBackward0>)

In [19]:
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset

In [23]:
class LR(nn.Module):
    def __init__(self, input_size, output_size):
        super(LR, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)


class Data2D(Dataset):
    def  __init__(self):
        self.x = torch.zeros(20,2)
        self.x[:, 0] = torch.arange(-1, 1, 0.1)
        self.x[:, 1] = torch.arange(-1, 1, 0.1)
        self.w = torch.tensor([[1.], [1.]])
        self.b = 1
        self.f = torch.mm(self.x, self.w) + self.b
        self.y = self.f + 0.1 * torch.randn((self.x.shape[0], 1))
        self.len = self.x.shape[0]

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return self.len


In [24]:
dataset = Data2D()
criterion = nn.MSELoss()
train_loader = DataLoader(dataset=dataset, batch_size=2)

In [25]:
model = LR(input_size=2, output_size=1)
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [26]:
LOSS = []
for epoch in range(100):
    for x, y in train_loader:
        y_hat = model(x)
        loss = criterion(y_hat, y)
        LOSS.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [27]:
LOSS

[0.9535412788391113,
 0.2333698272705078,
 0.07147691398859024,
 0.04732911288738251,
 0.15460875630378723,
 0.45353707671165466,
 1.0778834819793701,
 1.7158377170562744,
 2.663370132446289,
 3.0941567420959473,
 0.8181012868881226,
 0.20380404591560364,
 0.07171136885881424,
 0.02975964918732643,
 0.1007949486374855,
 0.31105196475982666,
 0.7882473468780518,
 1.2706594467163086,
 2.009281635284424,
 2.3011746406555176,
 0.6959224343299866,
 0.1747283935546875,
 0.06948194652795792,
 0.01920635998249054,
 0.06648961454629898,
 0.21267639100551605,
 0.5785266160964966,
 0.944300651550293,
 1.5214455127716064,
 1.7114471197128296,
 0.5880515575408936,
 0.14763307571411133,
 0.06552216410636902,
 0.012876130640506744,
 0.044809624552726746,
 0.14481300115585327,
 0.42622315883636475,
 0.7045809626579285,
 1.1566627025604248,
 1.27256178855896,
 0.4943860173225403,
 0.1233564168214798,
 0.06051657348871231,
 0.00909128412604332,
 0.03126085549592972,
 0.09807108342647552,
 0.315287142992