In [21]:
import numpy as np
import torch
import torch.autograd as autograd
import matplotlib.pyplot as plt

In [22]:
sample_size = 100000
sample_dim = 10
noise_scale = 1.0
train_split = 0.8
theta_range = 1.0
data_range = 500
epoch = 50000
lr = 1e-5
dim_per_task = 7
task_num = 10
shift_scale = 1e-8
tasks = [np.random.randint(0, sample_dim, dim_per_task) for _ in range(task_num)]

In [23]:
class DataGen:
    def __init__(self, sample_size, sample_dim, noise_scale, train_split, theta_range, data_range, tasks):
        self.sample_size = sample_size
        self.sample_dim = sample_dim
        self.noise_scale = noise_scale
        self.train_split = train_split
        self.theta_range = theta_range
        self.data_range = data_range
        self.theta = np.random.uniform(-theta_range, theta_range, (sample_dim, 1))
        self.tasks = tasks
    
    def get_data(self):
        for task in self.tasks:
            self.theta += shift_scale*self.noise_scale*np.random.normal(0, 1, (self.sample_dim, 1))
            X = np.zeros((self.sample_size, self.sample_dim))
            X[:, task] = np.random.uniform(-self.data_range, self.data_range, (self.sample_size, self.sample_dim))[:, task]
            y = np.dot(X, self.theta).squeeze() + self.noise_scale * np.random.normal(0, 1, self.sample_size)
            train_size = int(self.sample_size * self.train_split)
            X_train, X_test = torch.tensor(X[:train_size], dtype=torch.float), torch.tensor(X[train_size:], dtype=torch.float)
            y_train, y_test = torch.tensor(y[:train_size], dtype=torch.float), torch.tensor(y[train_size:], dtype=torch.float)
            yield X_train, y_train, X_test, y_test

In [24]:
class Model(torch.nn.Module):
    def __init__(self, sample_dim, lr):
        super().__init__()
        self.theta = torch.nn.Parameter(torch.randn(sample_dim, 1))
        self.optimizer = torch.optim.SGD(self.parameters(), lr=lr)
        self.loss_fn = torch.nn.MSELoss()
    
    def set_data(self, X, y):
        self.X, self.y = X.detach().clone(), y.detach().clone()

    def forward(self, X):
        return torch.matmul(X, self.theta).squeeze()
    
    def train(self, X, y, epoch, lr):
        for _ in range(epoch):
            self.optimizer.zero_grad()
            loss = self.loss_fn(self(X), y)
            loss.backward()
            self.optimizer.step()
    
    def test(self, X, y):
        return self.loss_fn(self(X), y)
    
    def get_loss_with_theta(self, theta):
        return self.loss_fn(torch.matmul(self.X, theta).squeeze(), self.y)
    
    def get_gradient(self, X, y, theta):
        self.set_data(X, y)
        return autograd.functional.jacobian(self.get_loss_with_theta, theta.squeeze()).detach().clone()
    
    def get_hessian(self, X, y, theta):
        self.set_data(X, y)
        return autograd.functional.hessian(self.get_loss_with_theta, theta.squeeze()).detach().clone()
    
    def prepare_estimation(self, X, y, old_theta):
        self.gradient = self.get_gradient(X, y, old_theta).reshape(1, -1)
        self.hessian = self.get_hessian(X, y, old_theta)
        self.old_theta = old_theta.detach().clone()
        self.old_ans = self.get_loss_with_theta(old_theta).detach().clone()
    
    def estimate(self, theta):
        dtheta = theta - self.old_theta
        ans = self.old_ans.detach().clone()
        ans += torch.matmul(self.gradient, dtheta).squeeze()
        ans += 0.5 * torch.matmul(torch.matmul(dtheta.T, self.hessian), dtheta).squeeze()
        return ans
    
    def diag_estimate(self, theta):
        dtheta = theta - self.old_theta
        ans = self.old_ans.detach().clone()
        ans += torch.matmul(self.gradient, dtheta).squeeze()
        ans += 0.5 * torch.dot(torch.diag(self.hessian),(dtheta.squeeze())).squeeze()
        return ans

In [25]:
model = Model(sample_dim, lr)
Data = DataGen(sample_size, sample_dim, noise_scale, train_split, theta_range, data_range, tasks)
first_task = True
for X_train, y_train, X_test, y_test in Data.get_data():
    model.train(X_train, y_train, epoch, lr)
    if first_task:
        first_task = False
    else:
        estimation = float(model.estimate(model.theta))
        diag_estimation = float(model.diag_estimate(model.theta))
        actual = float(model.test(model.X, model.y))
        delta = float(np.abs(estimation - actual))
        print(f"Estimation: {estimation}, Actual: {actual}, Change Ratio: {delta/actual}")
        delta = float(np.abs(diag_estimation - actual))
        print(f"Diag Estimation: {diag_estimation}, Change Ratio: {delta/actual}")
    model.prepare_estimation(X_train, y_train, model.theta)

Estimation: 1.0073801279067993, Actual: 1.0073801279067993, Change Ratio: 0.0
Diag Estimation: 3.13822078704834, Change Ratio: 2.1152299912537895
Estimation: 1.0062575340270996, Actual: 1.0062580108642578, Change Ratio: 4.7387166418042e-07
Diag Estimation: -5.9083099365234375, Change Ratio: 6.871565615113853
Estimation: 1.0045831203460693, Actual: 1.0045832395553589, Change Ratio: 1.1866541751537163e-07
Diag Estimation: 7.205212593078613, Change Ratio: 6.172340040500507
Estimation: 1.0042859315872192, Actual: 1.0042853355407715, Change Ratio: 5.935030878778656e-07
Diag Estimation: -1.6709250211715698, Change Ratio: 2.663795100893151
Estimation: 0.9961096048355103, Actual: 0.9961095452308655, Change Ratio: 5.983743962776326e-08
Diag Estimation: -1.8846137523651123, Change Ratio: 2.8919743931660857
Estimation: 0.9923503398895264, Actual: 0.9923505783081055, Change Ratio: 2.402564016318215e-07
Diag Estimation: -0.02004152536392212, Change Ratio: 1.0201960131852714
Estimation: 0.9970933794