In [None]:
import torch

class PointDataset(torch.utils.data.Dataset):
    """Dataset of 2-D points read from a whitespace-separated text file.

    Each non-blank line of ``filename`` must contain exactly two numbers,
    ``x`` and ``y``, separated by any whitespace (spaces or tabs). Points are
    stored in ``self.data`` as a list of ``(x, y)`` tuples of Python floats,
    in file order, and ``__getitem__`` returns those tuples unchanged.

    Raises:
        ValueError: if a non-blank line does not hold exactly two fields, or a
            field is not a valid float.
    """

    def __init__(self, filename):
        # List of (x, y) tuples of Python floats, in file order.
        self.data = []

        with open(filename, 'r') as f:
            for lineno, line in enumerate(f, start=1):
                # split() with no argument tolerates repeated spaces, tabs and
                # trailing whitespace, unlike split(" ").
                fields = line.split()
                # Skip blank lines (e.g. an empty last line) instead of crashing.
                if not fields:
                    continue
                if len(fields) != 2:
                    raise ValueError(
                        f"{filename}:{lineno}: expected 2 values, got {len(fields)}: {line!r}"
                    )
                x, y = float(fields[0]), float(fields[1])
                self.data.append((x, y))

    def __len__(self):
        """Return the number of points loaded from the file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` float tuple stored at position ``idx``."""
        return self.data[idx]


# Load from the same path the training cell below uses ("dataset1.txt"). The
# earlier "../dataset1.txt" pointed at a different location, so a fresh
# Restart & Run All would only work if a copy existed in both places.
# NOTE(review): confirm "dataset1.txt" (relative to the notebook's working
# directory) is the intended data location.
ds = PointDataset("dataset1.txt")

In [None]:
class LineModule(torch.nn.Module):
    """Single-parameter line through the origin: ``y = w * x``.

    The only learnable parameter, ``w``, is a tensor of shape ``(1,)`` whose
    initial value is drawn with ``torch.rand``.
    """

    def __init__(self):
        super().__init__()
        initial_slope = torch.rand(1)
        self.w = torch.nn.Parameter(initial_slope)

    def forward(self, x):
        """Return ``x`` scaled element-wise by the learned slope ``w``."""
        scaled = self.w * x
        return scaled
    

model = LineModule()
# Inspect the freshly initialised parameter, then run one forward pass at x = 2.
initial_params = list(model.parameters())
print(initial_params)
sample_x = torch.tensor([2.0])
print(model(sample_x))

In [None]:
from tqdm import trange

# Hyperparameters for the training run below.
SEED = 0
BATCH_SIZE = 8
LEARNING_RATE = 0.001
N_EPOCHS = 1000

# Seed before building the model so the random initial `w` (and therefore the
# fitted result) is the same on every Restart & Run All.
torch.manual_seed(SEED)

ds = PointDataset("dataset1.txt")
model = LineModule()
dl = torch.utils.data.DataLoader(ds, batch_size=BATCH_SIZE)

optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
loss_fn = torch.nn.MSELoss()

for _ in trange(N_EPOCHS):
    for x, y in dl:
        # The dataset yields Python floats, which the default collate turns into
        # float64 tensors; cast to the parameter's dtype so the forward pass and
        # loss are not silently promoted to float64.
        x, y = x.to(model.w.dtype), y.to(model.w.dtype)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [None]:
# Show the slope learned by the training loop.
learned_w = model.w
print(learned_w)

In [None]:
# All points as one tensor: column 0 holds x, column 1 holds y.
data = torch.tensor(ds.data, dtype=torch.float32)
assert data.ndim == 2 and data.shape[1] == 2, f"expected (N, 2) points, got {tuple(data.shape)}"
xs, ys = data[:, 0], data[:, 1]

# Grid of candidate slopes; the loss at each one traces the loss landscape of
# the single-parameter model y = w * x.
ww = torch.arange(-10, 10, step=.1)

# MSE of the prediction w * x against y for every candidate w, as Python floats.
errors = [loss_fn(w * xs, ys).item() for w in ww]


In [None]:
from matplotlib import pyplot as plt


# Loss as a function of the slope w; the trained w should sit near the minimum.
fig, ax = plt.subplots()
ax.plot(ww, errors)
ax.set(title="MSE loss vs. slope w", xlabel="w", ylabel="MSE loss");

In [None]:
from matplotlib import pyplot as plt

# Dense x grid spanning the observed data, used to draw the fitted line.
x_min, x_max = data[:, 0].min().item(), data[:, 0].max().item()
xx = torch.linspace(x_min, x_max, 200)

fig, ax = plt.subplots()
ax.scatter(data[:, 0], data[:, 1], label="data")
# Only prediction values are needed here, so skip building an autograd graph.
with torch.no_grad():
    fitted = model(xx)
# Explicit color: scatter and plot otherwise both start at the default "C0".
ax.plot(xx, fitted.numpy(), color="C1", label=f"fit: y = {model.w.item():.3f} x")
ax.set(title="Data and fitted line", xlabel="x", ylabel="y")
ax.legend();