In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD

import matplotlib.pyplot as plt
import seaborn as sns

In [11]:
class BasicNN(nn.Module):
    """Small two-branch ReLU network with hand-set, frozen weights.

    Each input value feeds two hidden ReLU units. Their outputs are scaled,
    summed together with a final bias, and passed through a last ReLU.
    Every parameter has requires_grad=False, so this model is inference-only.
    """

    def __init__(self):
        super().__init__()

        # Upper branch: relu(input * w00 + b00) * w01
        self.w00 = nn.Parameter(torch.tensor(1.7), requires_grad=False)
        self.b00 = nn.Parameter(torch.tensor(-0.85), requires_grad=False)
        self.w01 = nn.Parameter(torch.tensor(-40.8), requires_grad=False)

        # Lower branch: relu(input * w10 + b10) * w11
        self.w10 = nn.Parameter(torch.tensor(12.6), requires_grad=False)
        self.b10 = nn.Parameter(torch.tensor(0.0), requires_grad=False)
        self.w11 = nn.Parameter(torch.tensor(2.7), requires_grad=False)

        # Final bias. Must be a float literal: `torch.tensor(-16)` yields an
        # int64 parameter, inconsistent with the float32 weights above and
        # impossible to make trainable (integer tensors cannot require grad).
        self.bfinal = nn.Parameter(torch.tensor(-16.0), requires_grad=False)

    def forward(self, input):
        """Evaluate the network elementwise.

        Args:
            input: tensor of input values. All parameters are 0-dim, so the
                computation broadcasts over any input shape.

        Returns:
            Tensor with the same shape as ``input``; non-negative because of
            the final ReLU.
        """
        input_to_relu1 = input * self.w00 + self.b00
        relu1_out = F.relu(input_to_relu1)
        scaled_relu1_out = relu1_out * self.w01

        input_to_relu2 = input * self.w10 + self.b10
        relu2_out = F.relu(input_to_relu2)
        scaled_relu2_out = relu2_out * self.w11

        input_to_final_relu = scaled_relu1_out + scaled_relu2_out + self.bfinal

        output = F.relu(input_to_final_relu)
        return output

In [28]:
class BasicNN_train(nn.Module):
    """Two-branch ReLU network whose only learnable parameter is ``bfinal``.

    The six branch weights/biases are fixed (requires_grad=False), and
    ``bfinal`` starts at 0.0 with requires_grad=True so that an optimizer can
    fit it.
    """

    def __init__(self):
        super().__init__()

        def frozen(value):
            # Pre-chosen constant weight, excluded from autograd.
            return nn.Parameter(torch.tensor(value), requires_grad=False)

        # Upper branch: relu(input * w00 + b00) * w01
        self.w00 = frozen(1.7)
        self.b00 = frozen(-0.85)
        self.w01 = frozen(-40.8)

        # Lower branch: relu(input * w10 + b10) * w11
        self.w10 = frozen(12.6)
        self.b10 = frozen(0.0)
        self.w11 = frozen(2.7)

        # The single trainable parameter, initialised to zero.
        self.bfinal = nn.Parameter(torch.tensor(0.0), requires_grad=True)

    def forward(self, input):
        """Return relu(upper_branch + lower_branch + bfinal), elementwise over ``input``."""
        upper = F.relu(input * self.w00 + self.b00) * self.w01
        lower = F.relu(input * self.w10 + self.b10) * self.w11
        return F.relu(upper + lower + self.bfinal)

In [29]:
# 11 evenly spaced input values over [0, 1] (step 0.1) for evaluating the model's response.
input_ = torch.linspace(start=0, end=1, steps=11)

In [48]:
# Fresh trainable model (bfinal starts at 0.0), evaluated on the input grid
# before any training so the starting response can be inspected.
model = BasicNN_train()
output_val = model(input=input_)


In [49]:
# Display the untrained model's output over the 11-point input grid.
output_val

tensor([ 0.0000,  3.4020,  6.8040, 10.2060, 13.6080, 17.0100, 13.4760,  9.9420,
         6.4080,  2.8740,  0.0000], grad_fn=<ReluBackward0>)

In [50]:
# Training data
inputs = torch.tensor([0.,0.5,0.])
labels = torch.tensor([0.,1.,0.])


In [51]:
# Inspect the learnable final bias before training (plain-tensor view via .data).
str(model.bfinal.data)

'tensor(0.)'

In [52]:
# Plain SGD over all model parameters. In BasicNN_train only `bfinal` has
# requires_grad=True, so it is the only parameter that receives a gradient
# and therefore the only one SGD changes.
optmizer = SGD(params=model.parameters(), lr=0.1)

print("Final bias before optm is", model.bfinal.data)


Final bias before optm is tensor(0.)


In [53]:
# Full-batch gradient descent on `bfinal`: per-sample squared-error gradients
# are accumulated over the whole training set, then applied in one optimizer
# step per epoch. Stops early once the summed loss falls below 1e-4.
for epoch in range(100):

    # Clear gradients at the START of every epoch. Clearing only after
    # `step()` would leave the gradients computed in the epoch that hits the
    # early-exit `break`, and re-running this cell would then add them to the
    # first epoch's gradients (a wrong-sized first step).
    optmizer.zero_grad()

    total_loss = 0.0  # sum of squared residuals: how well the model fits the data

    for input_i, label_i in zip(inputs, labels):
        label_pred = model(input_i)

        loss_i = (label_pred - label_i) ** 2
        loss_i.backward()  # accumulates into .grad across samples

        total_loss += loss_i.item()

    if total_loss < 0.0001:
        print("Num epochs = {}".format(str(epoch)))
        break

    optmizer.step()

    print("bfinal change >> {}".format(str(model.bfinal.data)))

print("Final_bias is {}".format(str(model.bfinal.data)))

bfinal change >> tensor(-3.2020)
bfinal change >> tensor(-5.7636)
bfinal change >> tensor(-7.8129)
bfinal change >> tensor(-9.4523)
bfinal change >> tensor(-10.7638)
bfinal change >> tensor(-11.8131)
bfinal change >> tensor(-12.6525)
bfinal change >> tensor(-13.3240)
bfinal change >> tensor(-13.8612)
bfinal change >> tensor(-14.2909)
bfinal change >> tensor(-14.6348)
bfinal change >> tensor(-14.9098)
bfinal change >> tensor(-15.1298)
bfinal change >> tensor(-15.3059)
bfinal change >> tensor(-15.4467)
bfinal change >> tensor(-15.5594)
bfinal change >> tensor(-15.6495)
bfinal change >> tensor(-15.7216)
bfinal change >> tensor(-15.7793)
bfinal change >> tensor(-15.8254)
bfinal change >> tensor(-15.8623)
bfinal change >> tensor(-15.8919)
bfinal change >> tensor(-15.9155)
bfinal change >> tensor(-15.9344)
bfinal change >> tensor(-15.9495)
bfinal change >> tensor(-15.9616)
bfinal change >> tensor(-15.9713)
bfinal change >> tensor(-15.9790)
bfinal change >> tensor(-15.9852)
bfinal change >> t