## Import relevant packages

In [107]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

In [110]:
def loss_weighted_mse(y_pred, y, loss_weights):
    """Return the mean squared error with one weight per index of dim 1.

    Args:
        y_pred: Predicted tensor with four dimensions.
        y: Target tensor, broadcast-compatible with ``y_pred``.
        loss_weights: 1-D tensor whose length equals the size of dim 1 of
            the squared-difference tensor.

    Returns:
        Scalar tensor: the mean over all elements of
        ``loss_weights[j] * (y - y_pred)[i, j, k, l] ** 2``.
    """
    squared_error = (y - y_pred) ** 2
    # Scale by the weights supplied by the caller. The previous version read a
    # module-level global here, so the ``loss_weights`` argument was ignored.
    weighted_squared_error = torch.einsum(
        'ijkl,j->ijkl', squared_error, loss_weights
    )
    return torch.mean(weighted_squared_error)

In [111]:
def generate_exp_weights(T, alpha, decay_length=10.0):
    """Return ``T`` exponentially increasing weights ending at 1.

    The weights are ``exp(-decay_length * s)`` for ``s`` evenly spaced on
    ``[0, 1]``, reversed so that the smallest weight comes first and the last
    weight is ``exp(0) = 1``.

    Args:
        T: Number of weights to generate.
        alpha: Currently unused; kept so existing callers keep working.
        decay_length: Decay rate in the exponent. Defaults to 10.0, the value
            that was previously hard-coded.

    Returns:
        1-D tensor of length ``T``.
    """
    unit_steps = torch.linspace(0, 1, steps=T)
    decaying_weights = torch.exp(-decay_length * unit_steps)
    # Reverse so the weights grow along the index instead of decaying.
    return torch.flip(decaying_weights, dims=[0])

In [112]:
# Number of weights to generate; also the size of dim 1 of the test tensors.
T = 20
# Passed to generate_exp_weights; it does not influence the returned weights.
alpha = 0.01
# 1-D tensor of T weights that grow exponentially, ending at exp(0) = 1.
exp_mse_weights = generate_exp_weights(T, alpha)

In [113]:
# All-zeros "prediction" and all-ones "target" of shape (1, T, 28, 28), so the
# squared difference between them is exactly 1 at every element.
iter_image = torch.zeros(1, T, 28, 28)
training_image = torch.ones(1, T, 28, 28)

In [114]:
# Weighted MSE of the zeros/ones pair. Every squared difference is 1, so the
# result reduces to the mean of the weights.
loss_exp_mse = loss_weighted_mse(iter_image, training_image, exp_mse_weights)

In [115]:
# Display the weighted loss.
loss_exp_mse

tensor(0.1222)

In [116]:
# Sanity check: the mean of the weights should equal the weighted loss, because
# the squared difference of the zeros/ones pair is 1 everywhere.
torch.mean(exp_mse_weights)

tensor(0.1222)

In [117]:
# Built-in unweighted MSE loss, used as a baseline for comparison.
criterion = nn.MSELoss()

In [118]:
# Unweighted MSE of the same zeros/ones pair.
loss_mse = criterion(iter_image, training_image)

In [119]:
# Display the unweighted loss.
loss_mse

tensor(1.)