In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

In [10]:
# --- Reproducibility ---------------------------------------------------------
# Seed PyTorch before anything random is drawn.
torch.manual_seed(42)

# Ask cuDNN for deterministic kernels and switch off its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# --- Device ------------------------------------------------------------------
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- Problem and training sizes ----------------------------------------------
n, d, L = 20, 5, 3  # context length, input dimension, number of layers
batch_size = 512
validation_size = batch_size // 8
epochs = 10000

# Perturbation scale introduced for initialization.
# NOTE(review): the training cell visible in this notebook passes its own
# `delta_init` to the model rather than this value.
delta_initialize = 1e-3

In [11]:
# distribution for x and w
def generate_data(batch_size, d, n, device):
    """Sample a batch of noiseless linear-regression tasks.

    Every task gets its own weight vector and n+1 input points, all drawn
    i.i.d. from N(0, 1).

    Args:
        batch_size: number of independent tasks in the batch.
        d: dimension of each input point.
        n: context length; n+1 points are drawn per task.
        device: torch device on which the tensors are created.

    Returns:
        Tuple (x, y, w_star):
            x      -- inputs, shape (batch_size, n+1, d)
            y      -- y[b, i] = <x[b, i], w_star[b]>, shape (batch_size, n+1);
                      the last column y[:, n] holds the true y_{n+1} values
            w_star -- task weights, shape (batch_size, d)
    """
    inputs = torch.randn(batch_size, n + 1, d, device=device)
    weights = torch.randn(batch_size, d, device=device)
    # Broadcast each task's weights over its n+1 points, then reduce over the
    # feature axis to obtain the inner products.
    targets = torch.sum(inputs * weights[:, None, :], dim=-1)
    return inputs, targets, weights

class LinearTransformer(nn.Module):
    """L-layer linear-attention transformer acting on a prompt matrix Z.

    Each layer applies the residual update

        Z <- Z + (1/n) * P @ Z @ M @ (Z^T @ Q @ Z)

    with its own trainable (d+1) x (d+1) matrices P and Q. The fixed mask
    M = diag(I_n, 0) zeroes the last (query) column of P @ Z, so the query
    token is excluded from the attention sum.

    Args:
        d: input dimension; P and Q are (d+1) x (d+1).
        n: context length; prompts have n+1 columns.
        L: number of layers.
        delta_init: scale multiplying the standard-normal initial entries of
            P and Q.
        device: torch device on which parameters and the mask are created.
    """

    def __init__(self, d, n, L, delta_init, device):
        super().__init__()
        self.L = L
        self.n = n
        self.device = device

        # Random Gaussian initialization scaled by delta_init.
        # Other initializations that were tried here: a sparse structure
        # (P all zero except P[-1, -1] = 1, Q = -blockdiag(A, 0) with a random
        # d x d block A scaled by 0.1), constant entries of 0.001, and all-zero
        # matrices (zero doesn't work).
        self.P_list = nn.ParameterList(
            [nn.Parameter(torch.randn(d + 1, d + 1, device=device) * delta_init) for _ in range(L)]
        )
        self.Q_list = nn.ParameterList(
            [nn.Parameter(torch.randn(d + 1, d + 1, device=device) * delta_init) for _ in range(L)]
        )

        # Registered as a buffer so that .to()/.cuda()/.double() move the mask
        # together with the parameters. persistent=False keeps it out of
        # state_dict, so existing checkpoints still load unchanged.
        mask = torch.block_diag(torch.eye(n, device=device), torch.zeros(1, 1, device=device))
        self.register_buffer("M", mask, persistent=False)

    def forward(self, Z):
        """Apply all L layers to the prompt.

        Args:
            Z: tensor of shape (batch_size, d+1, n+1).

        Returns:
            Tensor of the same shape as Z.
        """
        for layer in range(self.L):
            P = self.P_list[layer]
            Q = self.Q_list[layer]
            # Z^T Q Z for every batch element -> (batch_size, n+1, n+1).
            # matmul broadcasts the 2-D weight over the batch dimension.
            scores = Z.transpose(1, 2) @ Q @ Z
            # P Z M -> (batch_size, d+1, n+1), with the query column zeroed.
            masked_values = (P @ Z) @ self.M
            Z = Z + (masked_values @ scores) / self.n
        return Z

In [12]:
def loss_function(ZL, y_true):
    """Mean squared error of the prediction read from the transformer output.

    The prediction for each batch element is the negated bottom-right entry
    of its output matrix, -ZL[b, -1, -1].

    Args:
        ZL: model output, indexed as (batch, row, column).
        y_true: target values, shape (batch_size,).

    Returns:
        Scalar tensor holding the batch-mean squared error.
    """
    prediction = ZL[:, -1, -1].neg()  # shape (batch_size,)
    residual = prediction - y_true
    return residual.pow(2).mean()

In [13]:
# Adam settings consumed by the hand-written update rule (adam_step).
lr = 5e-3
beta1, beta2 = 0.9, 0.999  # decay rates of the first / second moment estimates
epsilon = 1e-8  # added to the denominator of the update

# Number of independent training runs performed for every noise level.
num_itr = 11

# Function to apply Adam updates with added noise to the gradient
def adam_step(params, grads, m, v, t, lr, beta1, beta2, epsilon, noise_scale):
    """Compute one Adam update from gradients perturbed by Gaussian noise.

    Args:
        params: sequence of parameter tensors (read only, not modified).
        grads: gradients, aligned with params.
        m: list of first-moment estimates; updated in place, entry by entry.
        v: list of second-moment estimates; updated in place, entry by entry.
        t: 1-based step counter used for bias correction.
        lr: learning rate.
        beta1: decay rate of the first-moment estimate.
        beta2: decay rate of the second-moment estimate.
        epsilon: constant added to the denominator of the update.
        noise_scale: standard deviation of the noise added to each gradient.

    Returns:
        (updated_params, m, v): a new list with the updated parameter values
        plus the same m and v list objects that were passed in.
    """
    # The bias-correction denominators are identical for every parameter.
    first_correction = 1 - beta1 ** t
    second_correction = 1 - beta2 ** t

    new_values = []
    for slot, (weight, gradient) in enumerate(zip(params, grads)):
        # Perturb the gradient with N(0, noise_scale^2) noise.
        perturbed = gradient + noise_scale * torch.randn_like(gradient)

        # Exponential moving averages of the gradient and its square.
        m[slot] = beta1 * m[slot] + (1 - beta1) * perturbed
        v[slot] = beta2 * v[slot] + (1 - beta2) * perturbed.square()

        # Bias-corrected moment estimates.
        unbiased_m = m[slot] / first_correction
        unbiased_v = v[slot] / second_correction

        step = lr * unbiased_m / (unbiased_v.sqrt() + epsilon)
        new_values.append(weight - step)

    return new_values, m, v

# Initialization scale handed to the model in the training loop.
delta_init = 1e-2
# Standard deviations of the Gaussian gradient noise; one experiment per value.
noise_choice = [0.5, 1.0, 2.0]

def compute_parameter_distance(params1, params2):
    """Squared Euclidean distance between two collections of tensors.

    The collections are walked pairwise; the squared element-wise differences
    of every pair are summed into one Python float. Note that no square root
    is taken.

    Args:
        params1: iterable of tensors.
        params2: iterable of tensors, aligned with params1.

    Returns:
        Sum over all pairs of sum((p1 - p2) ** 2), as a float.
    """
    return sum(
        (((first - second) ** 2).sum().item() for first, second in zip(params1, params2)),
        0.0,
    )

def compute_l2_norm(params):
    """L2 norm of a collection of tensors treated as one flat vector.

    Args:
        params: iterable of tensors.

    Returns:
        Square root of the summed squared entries of all tensors.
    """
    squared_total = sum((tensor.pow(2).sum().item() for tensor in params), 0.0)
    return np.sqrt(squared_total)

# A single fixed batch of regression tasks; every run below trains on this same batch.
train_x, train_y, _ = generate_data(batch_size, d, n, device)

# Prompt matrix of shape (batch_size, d+1, n+1): the inputs stacked on top of the
# labels. It depends only on the fixed batch, so it is built once here instead of
# once per epoch.
Z0 = torch.cat((torch.transpose(train_x, 1, 2), train_y.unsqueeze(1)), dim=1)  # concatenating x and y to form Z
Z0[:, -1, -1] = 0  # initialize y_{n+1} as 0

# train_losses_all[k][r]: last recorded training loss of run r at noise_choice[k].
train_losses_all = []
# diff_params_all[k][r]: [squared distance, norm, ratio] between runs r and r+1.
diff_params_all = []

for noise_idx, noise in enumerate(noise_choice):
    train_losses = []
    diff_params = []
    for itr in range(num_itr):
        # Fresh, randomly initialized model for every run.
        model = LinearTransformer(d=d, n=n, L=L, delta_init=delta_init, device=device).to(device)
        model.train()
        model_params = list(model.parameters())

        # Adam's momentum (m) and second moment (v) terms, one tensor per parameter.
        m = [torch.zeros_like(param) for param in model_params]
        v = [torch.zeros_like(param) for param in model_params]

        for epoch in range(epochs):
            # Forward pass and loss on the query token.
            ZL = model(Z0)  # Apply the linear transformer
            train_loss = loss_function(ZL, train_y[:, -1])

            # torch.autograd.grad returns the gradients directly and does not
            # accumulate into param.grad, so there is nothing to zero between steps.
            grads = torch.autograd.grad(train_loss, model_params, create_graph=False)

            # Manual Adam step with noisy gradients; results are copied back in place
            # so the model keeps the same Parameter objects.
            with torch.no_grad():
                updated_params, m, v = adam_step(model_params, grads, m, v, epoch+1, lr, beta1, beta2, epsilon, noise_scale=noise)
                for param, updated_param in zip(model_params, updated_params):
                    param.copy_(updated_param)

        # Loss of the last forward pass, i.e. evaluated before the final update.
        train_losses.append(train_loss.item())
        with torch.no_grad():
            if itr == 0:
                prev_param = [param.clone().detach() for param in model.parameters()]
            else:
                current_param = [param.clone().detach() for param in model.parameters()]
                distance = compute_parameter_distance(prev_param, current_param)
                norm = compute_l2_norm(prev_param)
                # Report the squared distance, the L2 norm of prev_param, and the ratio
                # distance / norm to get an idea of the variance between runs.
                diff_params.append([distance, norm, np.sqrt(distance)/norm])
                prev_param = current_param
    train_losses_all.append(train_losses)
    diff_params_all.append(diff_params)


print(train_losses_all)
print(diff_params_all)

[[0.059646882116794586, 0.06962095946073532, 0.08983862400054932, 0.09622515738010406, 0.0709066092967987, 0.0714753121137619, 16573.62890625, 1.4134182929992676, 247264800.0, 0.07128021121025085, 0.06468738615512848], [0.12489436566829681, 0.1903645396232605, 0.12084566801786423, 0.13092964887619019, 3.011082649230957, 155.3625030517578, 0.3229689300060272, 0.12494374811649323, 610.2999267578125, 0.10624580830335617, 0.1545482575893402], [4.3015594482421875, 0.2759762406349182, 1.2465240955352783, 0.20865362882614136, 0.2238343358039856, 0.2649287283420563, 3.8706555366516113, 0.44595056772232056, 0.1916150152683258, 0.25459399819374084, 0.28763097524642944]]
[[[30.307927906513214, 7.042401934690144, 0.781730964185434], [80.2342586517334, 6.059773077634685, 1.478167198565795], [98.05196928977966, 6.807984447956252, 1.4544862007144899], [85.66638877987862, 6.701304048342661, 1.3811660893744668], [69.59366434812546, 6.853320201871713, 1.2172613386508673], [63.93245846033096, 6.735994608

In [16]:
# Entry 0 of every record is the squared parameter distance between two
# consecutive runs; collect it per noise level (0.5, 1.0, 2.0).
diff_05 = [record[0] for record in diff_params_all[0]]
diff_10 = [record[0] for record in diff_params_all[1]]
diff_20 = [record[0] for record in diff_params_all[2]]
# Mean squared distance per noise level, computed in float32.
print(*(torch.FloatTensor(values).mean().item() for values in (diff_05, diff_10, diff_20)))

77.420654296875 61.79973220825195 74.72417449951172


In [17]:
# NOTE(review): leftover scratch expression; nothing in the visible notebook depends on it — consider deleting this cell.
1+1

2