# Functionality

## Imports

In [1]:
import pandas as pd
import numpy as np
from scipy.integrate import odeint

## Generic Computations

In [2]:
def impulse(A, B, C, length):
    '''
    Returns entry number `length` of the impulse response of the SSM parameterized by (A, B, C).

    A is expanded with np.diag (every call in this file passes the 1-D vector of diagonal entries),
    and the result is C @ diag(A)^(length - 1) @ B.
    '''
    state_matrix = np.diag(A)
    propagated = np.linalg.matrix_power(state_matrix, length - 1)
    return C @ propagated @ B

def compute_grad(A, length, x_long, x_short, B, C, Astar, Bstar, Cstar):
    '''
    Manually computes the negated gradient of the objective with respect to A.

    The objective is the mean, over all training sequences, of the squared residual between the
    student (A, B, C) and teacher (Astar, Bstar, Cstar) impulse responses, scaled by the input:
    entries of x_long are paired with the response at index `length`, entries of x_short with
    the response at index 2.

    The sign is negative, so the returned vector is the descent direction used as the right-hand
    side of the gradient-flow ODE, not the raw gradient.

    Assumes the single-input / single-output layout used in this file, i.e. B is a column (n, 1)
    and C is a row (1, n), so that impulse(...) is a 1x1 array.

    Returns an array with one entry per coordinate of A.
    '''
    n_seqs = len(x_long) + len(x_short)

    # c_i * b_i per state coordinate. These are the same values as np.diag(C.T @ B.T), obtained
    # without building the full n x n outer product (and only once instead of once per sequence).
    cb = C.flatten() * B.flatten()

    # Residual at the long index; it does not depend on the individual sequence, so hoist it.
    res_long = impulse(A, B, C, length) - impulse(Astar, Bstar, Cstar, length)
    long_term = (cb * res_long).flatten()
    diag_long = np.zeros((A.shape[0]))
    for x in x_long:
        diag_long += long_term * (x ** 2)

    # Residual at index 2, where d/dA_i of the response is just c_i * b_i.
    res_short = impulse(A, B, C, 2) - impulse(Astar, Bstar, Cstar, 2)
    short_term = (cb * res_short).flatten()
    diag_short = np.zeros((A.shape[0]))
    for x in x_short:
        diag_short += short_term * (x ** 2)

    # (length - 1) * A ** (length - 2) is the derivative of A_i ** (length - 1) for the long index.
    return - 2 / n_seqs * ((length - 1) * (A ** (length - 2)) * diag_long + diag_short)

def model(A, timestamps, length, x_long, x_short, B, C, Astar, Bstar, Cstar):
    '''
    Right-hand side passed to odeint.

    A is the current value of the model's parameter vector. `timestamps` occupies the time slot
    of the odeint callback signature and is not used. All remaining arguments are forwarded
    unchanged to compute_grad, whose result at A is returned.
    '''
    return compute_grad(A, length, x_long, x_short, B, C, Astar, Bstar, Cstar)

def compute_logs(timestamps, A, length, x_long, x_short, B, C, Astar, Bstar, Cstar, ext_start, ext_end):
    '''
    Logging function.

    timestamps is only used for its length; A holds one row of approximated parameters per
    timestamp (row t is A[t, :]).

    Returns (train_losses, ext_losses), two arrays with one entry per timestamp:
      * train_losses[t]: mean over all training sequences of (residual * x) ** 2, where the
        residual is taken at index `length` for x_long and at index 2 for x_short.
      * ext_losses[t]: the largest absolute student-teacher residual over indices
        ext_start..ext_end (inclusive), divided by the largest absolute teacher response over
        the same indices.

    If there are no training sequences, or the teacher response is zero on the whole
    extrapolation range, the corresponding division is left to NumPy (inf / nan with a
    RuntimeWarning).
    '''
    def _scalar(value):
        # impulse(...) is a 1x1 array for column-B / row-C inputs; unwrap it explicitly instead
        # of relying on implicit array-to-scalar conversion when storing into a float slot.
        return np.asarray(value).item()

    num_steps = timestamps.shape[0]
    num_seqs = len(x_long) + len(x_short)
    ext_indices = range(ext_start, ext_end + 1)

    # Teacher responses do not depend on the timestamp, so evaluate them once.
    teacher_long = _scalar(impulse(Astar, Bstar, Cstar, length))
    teacher_short = _scalar(impulse(Astar, Bstar, Cstar, 2))
    teacher_ext = [_scalar(impulse(Astar, Bstar, Cstar, j)) for j in ext_indices]

    # Sup-norm of the teacher on the extrapolation range. The absolute value matters: without
    # it a teacher with negative responses would give a wrong (possibly zero) normaliser.
    ell_infty = max((abs(value) for value in teacher_ext), default=0.0)

    train_losses = np.zeros(num_steps)
    ext_losses = np.zeros(num_steps)

    for t in range(num_steps):
        a_t = A[t, :]

        res_long = _scalar(impulse(a_t, B, C, length)) - teacher_long
        res_short = _scalar(impulse(a_t, B, C, 2)) - teacher_short
        total = 0.0
        for x in x_long:
            err = res_long * x
            total += err * err
        for x in x_short:
            err = res_short * x
            total += err * err
        train_losses[t] = total

        worst = 0.0
        for j, teacher_j in zip(ext_indices, teacher_ext):
            worst = max(worst, abs(_scalar(impulse(a_t, B, C, j)) - teacher_j))
        ext_losses[t] = worst

    # Array division keeps NumPy's inf/nan-with-warning behaviour for a zero denominator.
    train_losses /= num_seqs
    ext_losses /= ell_infty

    return train_losses, ext_losses

## Simulate

In [3]:
def simulate(seed, hidden_dim, sd_A, length, x_long, x_short, B, C, Astar, Bstar, Cstar, stop, step, ext_start, 
             ext_end, diff):
    '''
    Simulates the optimization of A via gradient flow on the objective.

    Initialisation: seeds NumPy's global random state with `seed`, draws hidden_dim uniform
    values in [0, 1), scales them by sd_A and sorts them in descending order; the second entry
    is then overwritten so that it lies exactly `diff` below the largest one (this requires
    hidden_dim >= 2).

    Integration: odeint integrates `model` from the initial point over `step` evenly spaced
    time points in [0, stop] (`step` is a number of samples, not a step size).

    Returns the pair (train loss, extrapolation loss) reported by compute_logs at the final
    time point.
    '''
    # Note: this reseeds the process-wide NumPy RNG, not a local generator.
    np.random.seed(seed)
    A0 = np.flip(np.sort(sd_A * np.random.rand(hidden_dim)))
    A0[1] = A0[0] - diff

    timestamps = np.linspace(0, stop, step)
    A = odeint(model, A0, timestamps, args=(length, x_long, x_short, B, C, Astar, Bstar, Cstar))

    train_losses, ext_losses = compute_logs(timestamps, A, length, x_long, x_short, B, C, Astar, Bstar, Cstar,
                                            ext_start, ext_end)
    return (train_losses[-1], ext_losses[-1])

# Length = 7, teacher state dim = 2, student state dim = 10

## Setup

In [4]:
# Random seeds for the runs: 242, 243, 244 and 246.
seeds = [242 + i for i in [0, 1, 2, 4]]

# Problem sizes and the range of impulse-response indices used for the extrapolation loss.
teacher_hidden_dim = 2
student_hidden_dim = 10
length = 7
ext_start = 1
ext_end = 20

# Teacher: diagonal A* = (1, 0), B* = (1, sqrt(student_hidden_dim - 1))^T and C* = B*^T.
Astar = np.zeros((teacher_hidden_dim))
Astar[0] = 1
Bstar = np.zeros((teacher_hidden_dim, 1))
Bstar[0, 0] = 1
Bstar[1, 0] = np.sqrt(student_hidden_dim - 1)
Cstar = Bstar.T

# Student: all-ones input column B and output row C = B^T.
B = np.ones((student_hidden_dim, 1))
C = B.T

# Scale of the initial A entries and the gap enforced between its two largest entries.
sd_A = 0.001
diff = 0.05 / np.exp(5 * np.log10(1 / sd_A))

## Training only using baseline sequences

In [5]:
# Integration horizon and number of sampled time points.
stop = 100_000_000_000
step = 1000

# Only sequences paired with the response at index `length` are used; x_short stays empty.
x_long = [1]
x_short = []

rule = '--------------------------------'
train_runs, ext_runs = [], []
for seed in seeds:
    train_loss, ext_loss = simulate(
        seed, student_hidden_dim, sd_A, length, x_long, x_short, B, C, Astar, Bstar, Cstar,
        stop, step, ext_start, ext_end, diff,
    )
    train_runs.append(train_loss)
    ext_runs.append(ext_loss)
    print(rule)
    print(f'Results for seed={seed}:')
    print(f'Train loss: {train_loss}')
    print(f'Extrapolation loss: {ext_loss}')

# Averages over all seeds.
avg_train_loss = sum(train_runs) / len(seeds)
avg_ext_loss = sum(ext_runs) / len(seeds)
print(rule)
print('Overall results:')
print(f'Mean train loss: {avg_train_loss}')
print(f'Mean extrapolation loss: {avg_ext_loss}')

--------------------------------
Results for seed=242:
Train loss: 4.930380657631324e-32
Extrapolation loss: 0.0011832003440634064
--------------------------------
Results for seed=243:
Train loss: 0.0
Extrapolation loss: 0.001225059939469375
--------------------------------
Results for seed=244:
Train loss: 1.1093356479670479e-31
Extrapolation loss: 0.001367126534837504
--------------------------------
Results for seed=246:
Train loss: 0.0
Extrapolation loss: 0.0013053552097638476
--------------------------------
Overall results:
Mean train loss: 4.0059342843254506e-32
Mean extrapolation loss: 0.0012701855070335333


## Training using baseline and special sequences

In [6]:
# Integration horizon and number of sampled time points.
stop = 10_000
step = 1000

# One sequence paired with the response at index `length` and one paired with index 2.
x_long = [1]
x_short = [1]

rule = '--------------------------------'
train_runs, ext_runs = [], []
for seed in seeds:
    train_loss, ext_loss = simulate(
        seed, student_hidden_dim, sd_A, length, x_long, x_short, B, C, Astar, Bstar, Cstar,
        stop, step, ext_start, ext_end, diff,
    )
    train_runs.append(train_loss)
    ext_runs.append(ext_loss)
    print(rule)
    print(f'Results for seed={seed}:')
    print(f'Train loss: {train_loss}')
    print(f'Extrapolation loss: {ext_loss}')

# Averages over all seeds.
avg_train_loss = sum(train_runs) / len(seeds)
avg_ext_loss = sum(ext_runs) / len(seeds)
print(rule)
print('Overall results:')
print(f'Mean train loss: {avg_train_loss}')
print(f'Mean extrapolation loss: {avg_ext_loss}')

--------------------------------
Results for seed=242:
Train loss: 2.465190328815662e-31
Extrapolation loss: 0.049180623265726676
--------------------------------
Results for seed=243:
Train loss: 6.162975822039155e-33
Extrapolation loss: 0.045836496012631886
--------------------------------
Results for seed=244:
Train loss: 2.465190328815662e-32
Extrapolation loss: 0.05430692338683909
--------------------------------
Results for seed=246:
Train loss: 4.930380657631324e-32
Extrapolation loss: 0.051073859174771384
--------------------------------
Overall results:
Mean train loss: 8.16594296420188e-32
Mean extrapolation loss: 0.050099475459992264


# Length = 9, teacher state dim = 2, student state dim = 20

## Setup

In [7]:
# Random seeds for the runs: 342, 344, 345 and 346.
seeds = [342 + i for i in [0, 2, 3, 4]]

# Problem sizes and the range of impulse-response indices used for the extrapolation loss.
teacher_hidden_dim = 2
student_hidden_dim = 20
length = 9
ext_start = 1
ext_end = 20

# Teacher: diagonal A* = (1, 0), B* = (1, sqrt(student_hidden_dim - 1))^T and C* = B*^T.
Astar = np.zeros((teacher_hidden_dim))
Astar[0] = 1
Bstar = np.zeros((teacher_hidden_dim, 1))
Bstar[0, 0] = 1
Bstar[1, 0] = np.sqrt(student_hidden_dim - 1)
Cstar = Bstar.T

# Student: all-ones input column B and output row C = B^T.
B = np.ones((student_hidden_dim, 1))
C = B.T

# Scale of the initial A entries and the gap enforced between its two largest entries.
sd_A = 0.005
diff = 0.05 / np.exp(10 * np.log10(1 / sd_A))

## Training only using baseline sequences

In [8]:
# Integration horizon and number of sampled time points.
stop = 10_000_000_000_000
step = 10_000

# Only sequences paired with the response at index `length` are used; x_short stays empty.
x_long = [1]
x_short = []

rule = '--------------------------------'
train_runs, ext_runs = [], []
for seed in seeds:
    train_loss, ext_loss = simulate(
        seed, student_hidden_dim, sd_A, length, x_long, x_short, B, C, Astar, Bstar, Cstar,
        stop, step, ext_start, ext_end, diff,
    )
    train_runs.append(train_loss)
    ext_runs.append(ext_loss)
    print(rule)
    print(f'Results for seed={seed}:')
    print(f'Train loss: {train_loss}')
    print(f'Extrapolation loss: {ext_loss}')

# Averages over all seeds.
avg_train_loss = sum(train_runs) / len(seeds)
avg_ext_loss = sum(ext_runs) / len(seeds)
print(rule)
print('Overall results:')
print(f'Mean train loss: {avg_train_loss}')
print(f'Mean extrapolation loss: {avg_ext_loss}')

--------------------------------
Results for seed=342:
Train loss: 4.930380657631324e-32
Extrapolation loss: 0.007281060820581074
--------------------------------
Results for seed=344:
Train loss: 0.0
Extrapolation loss: 0.0077195784791860384
--------------------------------
Results for seed=345:
Train loss: 4.930380657631324e-32
Extrapolation loss: 0.007078518716941872
--------------------------------
Results for seed=346:
Train loss: 1.232595164407831e-30
Extrapolation loss: 0.007279616318708103
--------------------------------
Overall results:
Mean train loss: 3.3280069439011436e-31
Mean extrapolation loss: 0.007339693583854273


## Training using baseline and special sequences

In [9]:
# Integration horizon and number of sampled time points.
stop = 10_000_000
step = 1000

# One sequence paired with the response at index `length` and one paired with index 2.
x_long = [1]
x_short = [1]

rule = '--------------------------------'
train_runs, ext_runs = [], []
for seed in seeds:
    train_loss, ext_loss = simulate(
        seed, student_hidden_dim, sd_A, length, x_long, x_short, B, C, Astar, Bstar, Cstar,
        stop, step, ext_start, ext_end, diff,
    )
    train_runs.append(train_loss)
    ext_runs.append(ext_loss)
    print(rule)
    print(f'Results for seed={seed}:')
    print(f'Train loss: {train_loss}')
    print(f'Extrapolation loss: {ext_loss}')

# Averages over all seeds.
avg_train_loss = sum(train_runs) / len(seeds)
avg_ext_loss = sum(ext_runs) / len(seeds)
print(rule)
print('Overall results:')
print(f'Mean train loss: {avg_train_loss}')
print(f'Mean extrapolation loss: {avg_ext_loss}')

--------------------------------
Results for seed=342:
Train loss: 4.930380657631324e-32
Extrapolation loss: 0.03501051395578071
--------------------------------
Results for seed=344:
Train loss: 2.465190328815662e-32
Extrapolation loss: 0.03527405460962272
--------------------------------
Results for seed=345:
Train loss: 2.465190328815662e-32
Extrapolation loss: 0.03494363138421701
--------------------------------
Results for seed=346:
Train loss: 0.0
Extrapolation loss: 0.03525862529595842
--------------------------------
Overall results:
Mean train loss: 2.465190328815662e-32
Mean extrapolation loss: 0.03512170631139472
