In [66]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

In [71]:
class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.fc = nn.Linear(1, 1)
        
    def forward(self, x):
        output = self.fc(x)
        return output

In [72]:
class PointDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.frame = pd.read_csv(csv_file, encoding='utf-8', header=0)
        self.transform = transform
        
    def __len__(self):
        return len(self.frame)
    
    def __getitem__(self, idx):
        x = np.array(self.frame.iloc[idx, 0])
        y = np.array(self.frame.iloc[idx, 1])
        if self.transform is not None:
            x = self.transform(x).float()
            y = self.transform(y).float()
        sample = {'x':x,'y':y}
        return sample

In [75]:
def train(network, trainloader, device, optimeter):
    epoches = 1000
    criterion = torch.nn.MSELoss().to(device)
    optimizer = optimeter
    for i in range(epoches):
        print('epoche', i)
        for j, data in enumerate(trainloader):
            x = data['x'].to(device)
            x = x.view(x.size(0), -1)
            y = data['y'].to(device)
            y = y.view(y.size(0), -1)
            print(x.shape, y.shape)
            output = network(x)
            loss = criterion(output, y)
            print('loss', loss.item())
            optimeter.zero_grad()
            
            loss.backward()
            optimizer.step()
            
            
    

In [76]:
transform = torch.tensor
device = torch.device('cuda:0')
trainset = PointDataset('./Salary_Data.csv', transform=transform)
trainloader = DataLoader(dataset=trainset,batch_size=len(trainset))
network = LinearRegression().to(device)

optimizer = torch.optim.SGD(network.parameters(), lr=0.01)

train(network, trainloader, device, optimizer)

epoche 0
torch.Size([30, 1]) torch.Size([30, 1])
loss 6503549440.0
epoche 1
torch.Size([30, 1]) torch.Size([30, 1])
loss 595521408.0
epoche 2
torch.Size([30, 1]) torch.Size([30, 1])
loss 183117776.0
epoche 3
torch.Size([30, 1]) torch.Size([30, 1])
loss 153447008.0
epoche 4
torch.Size([30, 1]) torch.Size([30, 1])
loss 150438432.0
epoche 5
torch.Size([30, 1]) torch.Size([30, 1])
loss 149294624.0
epoche 6
torch.Size([30, 1]) torch.Size([30, 1])
loss 148288576.0
epoche 7
torch.Size([30, 1]) torch.Size([30, 1])
loss 147299904.0
epoche 8
torch.Size([30, 1]) torch.Size([30, 1])
loss 146320208.0
epoche 9
torch.Size([30, 1]) torch.Size([30, 1])
loss 145348864.0
epoche 10
torch.Size([30, 1]) torch.Size([30, 1])
loss 144385680.0
epoche 11
torch.Size([30, 1]) torch.Size([30, 1])
loss 143430656.0
epoche 12
torch.Size([30, 1]) torch.Size([30, 1])
loss 142483680.0
epoche 13
torch.Size([30, 1]) torch.Size([30, 1])
loss 141544688.0
epoche 14
torch.Size([30, 1]) torch.Size([30, 1])
loss 140613632.0
epoc

torch.Size([30, 1]) torch.Size([30, 1])
loss 69481240.0
epoche 139
torch.Size([30, 1]) torch.Size([30, 1])
loss 69158640.0
epoche 140
torch.Size([30, 1]) torch.Size([30, 1])
loss 68838744.0
epoche 141
torch.Size([30, 1]) torch.Size([30, 1])
loss 68521552.0
epoche 142
torch.Size([30, 1]) torch.Size([30, 1])
loss 68207056.0
epoche 143
torch.Size([30, 1]) torch.Size([30, 1])
loss 67895192.0
epoche 144
torch.Size([30, 1]) torch.Size([30, 1])
loss 67585960.0
epoche 145
torch.Size([30, 1]) torch.Size([30, 1])
loss 67279376.0
epoche 146
torch.Size([30, 1]) torch.Size([30, 1])
loss 66975344.0
epoche 147
torch.Size([30, 1]) torch.Size([30, 1])
loss 66673884.0
epoche 148
torch.Size([30, 1]) torch.Size([30, 1])
loss 66374976.0
epoche 149
torch.Size([30, 1]) torch.Size([30, 1])
loss 66078600.0
epoche 150
torch.Size([30, 1]) torch.Size([30, 1])
loss 65784724.0
epoche 151
torch.Size([30, 1]) torch.Size([30, 1])
loss 65493316.0
epoche 152
torch.Size([30, 1]) torch.Size([30, 1])
loss 65204372.0
epoche

epoche 272
torch.Size([30, 1]) torch.Size([30, 1])
loss 43538212.0
epoche 273
torch.Size([30, 1]) torch.Size([30, 1])
loss 43434624.0
epoche 274
torch.Size([30, 1]) torch.Size([30, 1])
loss 43331940.0
epoche 275
torch.Size([30, 1]) torch.Size([30, 1])
loss 43230104.0
epoche 276
torch.Size([30, 1]) torch.Size([30, 1])
loss 43129128.0
epoche 277
torch.Size([30, 1]) torch.Size([30, 1])
loss 43029008.0
epoche 278
torch.Size([30, 1]) torch.Size([30, 1])
loss 42929732.0
epoche 279
torch.Size([30, 1]) torch.Size([30, 1])
loss 42831300.0
epoche 280
torch.Size([30, 1]) torch.Size([30, 1])
loss 42733704.0
epoche 281
torch.Size([30, 1]) torch.Size([30, 1])
loss 42636920.0
epoche 282
torch.Size([30, 1]) torch.Size([30, 1])
loss 42540952.0
epoche 283
torch.Size([30, 1]) torch.Size([30, 1])
loss 42445796.0
epoche 284
torch.Size([30, 1]) torch.Size([30, 1])
loss 42351444.0
epoche 285
torch.Size([30, 1]) torch.Size([30, 1])
loss 42257896.0
epoche 286
torch.Size([30, 1]) torch.Size([30, 1])
loss 421651

torch.Size([30, 1]) torch.Size([30, 1])
loss 34982364.0
epoche 414
torch.Size([30, 1]) torch.Size([30, 1])
loss 34951024.0
epoche 415
torch.Size([30, 1]) torch.Size([30, 1])
loss 34919952.0
epoche 416
torch.Size([30, 1]) torch.Size([30, 1])
loss 34889144.0
epoche 417
torch.Size([30, 1]) torch.Size([30, 1])
loss 34858596.0
epoche 418
torch.Size([30, 1]) torch.Size([30, 1])
loss 34828300.0
epoche 419
torch.Size([30, 1]) torch.Size([30, 1])
loss 34798272.0
epoche 420
torch.Size([30, 1]) torch.Size([30, 1])
loss 34768496.0
epoche 421
torch.Size([30, 1]) torch.Size([30, 1])
loss 34738960.0
epoche 422
torch.Size([30, 1]) torch.Size([30, 1])
loss 34709684.0
epoche 423
torch.Size([30, 1]) torch.Size([30, 1])
loss 34680648.0
epoche 424
torch.Size([30, 1]) torch.Size([30, 1])
loss 34651864.0
epoche 425
torch.Size([30, 1]) torch.Size([30, 1])
loss 34623316.0
epoche 426
torch.Size([30, 1]) torch.Size([30, 1])
loss 34595008.0
epoche 427
torch.Size([30, 1]) torch.Size([30, 1])
loss 34566948.0
epoche

torch.Size([30, 1]) torch.Size([30, 1])
loss 32338134.0
epoche 561
torch.Size([30, 1]) torch.Size([30, 1])
loss 32329122.0
epoche 562
torch.Size([30, 1]) torch.Size([30, 1])
loss 32320188.0
epoche 563
torch.Size([30, 1]) torch.Size([30, 1])
loss 32311328.0
epoche 564
torch.Size([30, 1]) torch.Size([30, 1])
loss 32302538.0
epoche 565
torch.Size([30, 1]) torch.Size([30, 1])
loss 32293828.0
epoche 566
torch.Size([30, 1]) torch.Size([30, 1])
loss 32285190.0
epoche 567
torch.Size([30, 1]) torch.Size([30, 1])
loss 32276632.0
epoche 568
torch.Size([30, 1]) torch.Size([30, 1])
loss 32268146.0
epoche 569
torch.Size([30, 1]) torch.Size([30, 1])
loss 32259720.0
epoche 570
torch.Size([30, 1]) torch.Size([30, 1])
loss 32251382.0
epoche 571
torch.Size([30, 1]) torch.Size([30, 1])
loss 32243094.0
epoche 572
torch.Size([30, 1]) torch.Size([30, 1])
loss 32234884.0
epoche 573
torch.Size([30, 1]) torch.Size([30, 1])
loss 32226748.0
epoche 574
torch.Size([30, 1]) torch.Size([30, 1])
loss 32218684.0
epoche

torch.Size([30, 1]) torch.Size([30, 1])
loss 31580424.0
epoche 707
torch.Size([30, 1]) torch.Size([30, 1])
loss 31577804.0
epoche 708
torch.Size([30, 1]) torch.Size([30, 1])
loss 31575210.0
epoche 709
torch.Size([30, 1]) torch.Size([30, 1])
loss 31572646.0
epoche 710
torch.Size([30, 1]) torch.Size([30, 1])
loss 31570102.0
epoche 711
torch.Size([30, 1]) torch.Size([30, 1])
loss 31567578.0
epoche 712
torch.Size([30, 1]) torch.Size([30, 1])
loss 31565066.0
epoche 713
torch.Size([30, 1]) torch.Size([30, 1])
loss 31562578.0
epoche 714
torch.Size([30, 1]) torch.Size([30, 1])
loss 31560118.0
epoche 715
torch.Size([30, 1]) torch.Size([30, 1])
loss 31557678.0
epoche 716
torch.Size([30, 1]) torch.Size([30, 1])
loss 31555260.0
epoche 717
torch.Size([30, 1]) torch.Size([30, 1])
loss 31552854.0
epoche 718
torch.Size([30, 1]) torch.Size([30, 1])
loss 31550484.0
epoche 719
torch.Size([30, 1]) torch.Size([30, 1])
loss 31548114.0
epoche 720
torch.Size([30, 1]) torch.Size([30, 1])
loss 31545784.0
epoche

loss 31364574.0
epoche 848
torch.Size([30, 1]) torch.Size([30, 1])
loss 31363788.0
epoche 849
torch.Size([30, 1]) torch.Size([30, 1])
loss 31363004.0
epoche 850
torch.Size([30, 1]) torch.Size([30, 1])
loss 31362232.0
epoche 851
torch.Size([30, 1]) torch.Size([30, 1])
loss 31361454.0
epoche 852
torch.Size([30, 1]) torch.Size([30, 1])
loss 31360692.0
epoche 853
torch.Size([30, 1]) torch.Size([30, 1])
loss 31359936.0
epoche 854
torch.Size([30, 1]) torch.Size([30, 1])
loss 31359190.0
epoche 855
torch.Size([30, 1]) torch.Size([30, 1])
loss 31358438.0
epoche 856
torch.Size([30, 1]) torch.Size([30, 1])
loss 31357700.0
epoche 857
torch.Size([30, 1]) torch.Size([30, 1])
loss 31356970.0
epoche 858
torch.Size([30, 1]) torch.Size([30, 1])
loss 31356238.0
epoche 859
torch.Size([30, 1]) torch.Size([30, 1])
loss 31355520.0
epoche 860
torch.Size([30, 1]) torch.Size([30, 1])
loss 31354814.0
epoche 861
torch.Size([30, 1]) torch.Size([30, 1])
loss 31354100.0
epoche 862
torch.Size([30, 1]) torch.Size([30,

torch.Size([30, 1]) torch.Size([30, 1])
loss 31298804.0
epoche 991
torch.Size([30, 1]) torch.Size([30, 1])
loss 31298566.0
epoche 992
torch.Size([30, 1]) torch.Size([30, 1])
loss 31298342.0
epoche 993
torch.Size([30, 1]) torch.Size([30, 1])
loss 31298108.0
epoche 994
torch.Size([30, 1]) torch.Size([30, 1])
loss 31297874.0
epoche 995
torch.Size([30, 1]) torch.Size([30, 1])
loss 31297644.0
epoche 996
torch.Size([30, 1]) torch.Size([30, 1])
loss 31297420.0
epoche 997
torch.Size([30, 1]) torch.Size([30, 1])
loss 31297198.0
epoche 998
torch.Size([30, 1]) torch.Size([30, 1])
loss 31296970.0
epoche 999
torch.Size([30, 1]) torch.Size([30, 1])
loss 31296756.0
