In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from tqdm import tqdm

In [2]:
class Line(nn.Module):
    """Straight line ``a + b * x`` with a learnable intercept and slope."""

    def __init__(self):
        super().__init__()
        # Intercept (a) and slope (b) of the line, both initialised to 1.0.
        self.a = nn.Parameter(torch.tensor(1.0))
        self.b = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        """Evaluate the line at ``x``."""
        scaled = self.b * x
        return self.a + scaled

class Regression(nn.Module):
    """A ``Line`` combined with a learnable Gaussian noise parameter."""

    def __init__(self):
        super().__init__()
        self.l = Line()
        # Noise scale of the Gaussian, initialised to 1.0.
        self.sig = nn.Parameter(torch.tensor(1.0))

    def forward(self, x, y):
        """Return the Gaussian log-density of ``y`` around the line's value at ``x``."""
        residual = y - self.l(x)
        # Normalisation term: log(2 * pi * sig^2).
        log_norm = torch.log(2 * np.pi * self.sig ** 2)
        squared_error = residual ** 2
        return -0.5 * log_norm - squared_error / (2 * self.sig ** 2)


In [3]:
def LinearRegression(x, y, epochs=5, lr=1e-3):
    """Fit a ``Regression`` model to ``(x, y)`` by maximising the log-likelihood.

    Training is done one observation at a time: for every sample the
    negative log-likelihood is computed and a single Adam step is taken.

    Args:
        x: Sequence of predictor values.
        y: Sequence of response values; must have the same length as ``x``.
        epochs: Number of full passes over the data.
        lr: Learning rate forwarded to ``optim.Adam``. The default matches
            Adam's own default, so existing calls behave as before.

    Returns:
        Tuple ``(model, a, b, sig)``: the trained ``Regression`` module
        followed by the line parameters ``model.l.a`` and ``model.l.b``
        and the noise parameter ``model.sig``.

    Raises:
        ValueError: If ``x`` and ``y`` differ in length.
    """
    # Fail fast: a mismatch would otherwise silently ignore extra targets
    # or raise an IndexError part-way through the first epoch.
    if len(x) != len(y):
        raise ValueError(
            f"x and y must have the same length, got {len(x)} and {len(y)}"
        )

    model = Regression()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for _ in tqdm(range(epochs), desc="Training..."):
        # Samples are paired positionally.
        for x_i, y_i in zip(x, y):
            # Zero the gradients left over from the previous step
            optimizer.zero_grad()

            # Negative (mean) log-likelihood of the current sample
            nll = -torch.mean(model(x_i, y_i))

            # Update the parameters
            nll.backward()
            optimizer.step()

    return model, model.l.a, model.l.b, model.sig

In [4]:
# Load the advertising data and discard the 'Unnamed: 0' column
# (presumably a row index saved into the CSV -- confirm against the file).
df = pd.read_csv('data/Advertising.csv').drop(columns='Unnamed: 0')
df

Unnamed: 0,TV,Radio,Newspaper,Sales
0,230.1,37.8,69.2,22.1
1,44.5,39.3,45.1,10.4
2,17.2,45.9,69.3,9.3
3,151.5,41.3,58.5,18.5
4,180.8,10.8,58.4,12.9
...,...,...,...,...
195,38.2,3.7,13.8,7.6
196,94.2,4.9,8.1,9.7
197,177.0,9.3,6.4,12.8
198,283.6,42.0,66.2,25.5


In [5]:
# Predictor is the TV column, response is the Sales column (as NumPy arrays).
x = df['TV'].to_numpy()
y = df['Sales'].to_numpy()

In [6]:
# Train for 1000 epochs; the result is the tuple (model, a, b, sig).
res = LinearRegression(x=x, y=y, epochs=1000)
res

Training...: 100%|██████████| 1000/1000 [02:35<00:00,  6.45it/s]


(Regression(
   (l): Line()
 ),
 Parameter containing:
 tensor(6.9610, requires_grad=True),
 Parameter containing:
 tensor(0.0430, requires_grad=True),
 Parameter containing:
 tensor(3.2973, requires_grad=True))