# Standard PyTorch MSE loss function

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Synthetic regression data: label = 2 + 0.5 * x + noise, with noise ~ U[0, 1).
torch.manual_seed(42)
resid = torch.rand(100)
# Column vector of inputs 0..99, shape (100, 1), as the model expects one feature per row.
inputs = torch.arange(100, dtype=torch.float32).unsqueeze(1)
# Vectorised form of the per-sample formula; every intermediate is exact in float32
# except the final noise addition, so the values match the element-wise construction.
labels = 2 + 0.5 * inputs.squeeze(1) + resid

# Now we define a linear regression model
class linearRegression(torch.nn.Module):
    """Linear regression whose inputs are first standardised by batch normalisation.

    Args:
        inputSize: number of input features (columns of the input batch).
        outputSize: number of outputs produced by the linear layer.
    """

    def __init__(self, inputSize, outputSize):
        super().__init__()
        # One BN channel per input feature, so the model works for any inputSize
        # (BatchNorm1d's num_features must equal the feature dimension of the input).
        self.bn = torch.nn.BatchNorm1d(num_features=inputSize)
        self.linear = torch.nn.Linear(inputSize, outputSize)

    def forward(self, inx):
        """Map a (batch, inputSize) tensor to a (batch, outputSize) tensor."""
        x = self.bn(inx)  # Adding BN to standardize input helps us use a higher learning rate
        return self.linear(x)
    
model = linearRegression(1, 1)

# Baseline run: PyTorch's built-in functional MSE loss.
epochs = 25
mseloss = F.mse_loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.001)
# Cut the learning rate by 10x after 20 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

# Nothing inside the loop switches the mode, so setting it once is sufficient.
model.train()
for epoch in range(epochs):
    # Clear gradients left over from the previous step before the new pass.
    optimizer.zero_grad()
    outputs = model(inputs)
    # The model emits one output column; drop it so predictions line up with labels.
    loss = mseloss(outputs.squeeze(-1), labels)
    loss.backward()
    optimizer.step()
    scheduler.step()
    print(f'epoch {epoch}, loss {loss}')
    

epoch 0, loss 953.6526489257812
epoch 1, loss 910.0452880859375
epoch 2, loss 879.88330078125
epoch 3, loss 851.927490234375
epoch 4, loss 823.8736572265625
epoch 5, loss 793.564697265625
epoch 6, loss 756.58740234375
epoch 7, loss 703.5147094726562
epoch 8, loss 616.2340087890625
epoch 9, loss 469.29998779296875
epoch 10, loss 259.6634826660156
epoch 11, loss 72.42332458496094
epoch 12, loss 7.541385650634766
epoch 13, loss 1.3356682062149048
epoch 14, loss 0.3731410503387451
epoch 15, loss 0.15019331872463226
epoch 16, loss 0.09882812201976776
epoch 17, loss 0.08701667189598083
epoch 18, loss 0.08430331945419312
epoch 19, loss 0.08368028700351715
epoch 20, loss 0.08353688567876816
epoch 21, loss 0.08353271335363388
epoch 22, loss 0.08352876454591751
epoch 23, loss 0.08352525532245636
epoch 24, loss 0.08352211862802505


# Custom loss function, but relying on PyTorch's automatic gradient calculation

In [None]:
#######################################################
class MyMSELoss(nn.Module):
    """Mean squared error written with plain tensor operations.

    No custom backward is defined: autograd differentiates ``forward`` itself.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, targets):
        """Return the mean of the element-wise squared difference."""
        squared_error = (inputs - targets) ** 2
        return squared_error.mean()
#######################################################

model = linearRegression(1, 1)

# Same optimiser and LR schedule as the baseline, but with the hand-written loss module.
mseloss = MyMSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

# Nothing inside the loop switches the mode, so setting it once is sufficient.
model.train()
for epoch in range(epochs):
    outputs = model(inputs)
    loss = mseloss(outputs.squeeze(-1), labels)
    # Autograd derives the gradient of MyMSELoss.forward automatically.
    loss.backward()
    optimizer.step()
    # Reset gradients after the update so the next epoch starts clean.
    optimizer.zero_grad()
    scheduler.step()

    print(f'epoch {epoch}, loss {loss}')

epoch 0, loss 994.3351440429688
epoch 1, loss 955.7775268554688
epoch 2, loss 918.1871337890625
epoch 3, loss 871.7307739257812
epoch 4, loss 801.802978515625
epoch 5, loss 680.582275390625
epoch 6, loss 472.83929443359375
epoch 7, loss 200.48443603515625
epoch 8, loss 26.085895538330078
epoch 9, loss 1.3690561056137085
epoch 10, loss 0.33454182744026184
epoch 11, loss 0.1325385421514511
epoch 12, loss 0.09308631718158722
epoch 13, loss 0.08537135273218155
epoch 14, loss 0.0838615894317627
epoch 15, loss 0.08356618881225586
epoch 16, loss 0.0835084542632103
epoch 17, loss 0.08349698781967163
epoch 18, loss 0.08349480479955673
epoch 19, loss 0.08349443227052689
epoch 20, loss 0.08349420875310898
epoch 21, loss 0.08349428325891495
epoch 22, loss 0.08349429816007614
epoch 23, loss 0.0834941565990448
epoch 24, loss 0.08349429816007614


# Custom loss function with my own backward method

In [None]:
from torch.autograd import Function

#######################################################
class MyMSELoss(Function):
    """Mean squared error as a custom autograd Function with a hand-written backward.

    Call it as ``MyMSELoss.apply(y_pred, y)``. The result is
    ``((y - y_pred) ** 2).mean()``, averaged over every element.
    Assumes ``y_pred`` and ``y`` have the same shape (no broadcasting in backward).
    """

    @staticmethod
    def forward(ctx, y_pred, y):
        # Keep both operands; backward needs their difference.
        ctx.save_for_backward(y_pred, y)
        return ((y - y_pred) ** 2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients w.r.t. ``y_pred`` and ``y``, scaled by ``grad_output``.

        For L = mean((y - y_pred)^2) over N elements:
        dL/dy_pred = 2 * (y_pred - y) / N and dL/dy = -dL/dy_pred.
        """
        y_pred, y = ctx.saved_tensors
        diff = y_pred - y
        # Chain rule: scale the local gradient by the upstream gradient
        # (1.0 for a plain loss.backward(), different if the loss is scaled/composed).
        # Divide by numel() because mean() averages over all elements, not just dim 0.
        grad_y_pred = grad_output * 2 * diff / diff.numel()
        # Helpers must be reached through the class: inside a staticmethod a bare
        # ``test()`` is not in scope, and ``ctx`` is the backward node, not this class.
        MyMSELoss.test(ctx)
        grad_y = -grad_y_pred if ctx.needs_input_grad[1] else None
        if not ctx.needs_input_grad[0]:
            grad_y_pred = None
        return grad_y_pred, grad_y

    @staticmethod
    def test(ctx):
        """No-op hook invoked from ``backward``; a placeholder for debugging/inspection."""
        pass
    
#######################################################
    
model = linearRegression(1, 1)
# autograd.Function subclasses are not instantiated; they are invoked through .apply.
mseloss = MyMSELoss.apply
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01, momentum=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

In [None]:
# One forward pass, so the resulting loss and its autograd node can be inspected.
model.train()
outputs = model(inputs)
loss = mseloss(outputs.squeeze(-1), labels)

In [None]:
# Display the loss tensor; its grad_fn shows the custom Function was recorded by autograd.
loss

tensor(1005.8683, grad_fn=<MyMSELossBackward>)

In [None]:
# Backpropagate through the custom Function, which invokes MyMSELoss.backward.
loss.backward()

AttributeError: 'MyMSELossBackward' object has no attribute 'test'

In [None]:
# Nothing inside the loop switches the mode, so setting it once is sufficient.
model.train()
for epoch in range(epochs):
    outputs = model(inputs)
    loss = mseloss(outputs.squeeze(-1), labels)
    loss.backward()  # dispatches to MyMSELoss.backward via autograd
    optimizer.step()
    # Gradients are cleared after the update, ready for the next epoch.
    optimizer.zero_grad()
    scheduler.step()
    print(f'epoch {epoch}, loss {loss}')

epoch 0, loss 0.08349424600601196
epoch 1, loss 0.08349424600601196
epoch 2, loss 0.08349424600601196
epoch 3, loss 0.08349424600601196
epoch 4, loss 0.08349424600601196
epoch 5, loss 0.08349424600601196
epoch 6, loss 0.08349424600601196
epoch 7, loss 0.08349424600601196
epoch 8, loss 0.08349424600601196
epoch 9, loss 0.08349424600601196
epoch 10, loss 0.08349424600601196
epoch 11, loss 0.08349424600601196
epoch 12, loss 0.08349424600601196
epoch 13, loss 0.08349424600601196
epoch 14, loss 0.08349424600601196
epoch 15, loss 0.08349424600601196
epoch 16, loss 0.08349424600601196
epoch 17, loss 0.08349424600601196
epoch 18, loss 0.08349424600601196
epoch 19, loss 0.08349424600601196
epoch 20, loss 0.08349424600601196
epoch 21, loss 0.08349424600601196
epoch 22, loss 0.08349424600601196
epoch 23, loss 0.08349424600601196
epoch 24, loss 0.08349424600601196
