In [None]:
from random import random
import numpy as np
import torch
import math
import matplotlib.pyplot as plt
# implement set seed function for reproducibility including torch, np and random
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds PyTorch, NumPy's legacy global generator and Python's built-in
    ``random`` module so that a run can be repeated exactly.

    Args:
        seed: Integer seed shared by all three generators.
    """
    # At module level the name ``random`` is bound to the *function*
    # ``random.random`` (``from random import random``), which has no ``seed``
    # attribute. Import the module itself under a distinct local name.
    import random as _random_module

    torch.manual_seed(seed)
    np.random.seed(seed)
    _random_module.seed(seed)


In [None]:
def random_unit_vector(p):
    """Return a length-``p`` standard-normal vector rescaled to unit L2 norm."""
    vec = torch.randn(p)
    return vec / torch.linalg.norm(vec)


def get_vectors(T, p, theta, norms):
    """Generate ``T`` vectors of dimension ``p`` with a prescribed consecutive angle.

    The first vector is a random unit vector. Every later vector starts as a
    random unit vector and is rotated, in turn, against each earlier vector
    ``v`` so that its cosine with ``v`` equals ``cos(theta)``. Each rotation
    overwrites the previous one, so only the angle with the *immediately
    preceding* vector is guaranteed exactly; angles with older vectors are
    approximate.

    Args:
        T: Number of vectors to generate.
        p: Dimension of each vector.
        theta: Target angle in radians.
        norms: Scale applied to the stacked unit vectors (scalar or anything
            broadcastable against a ``(T, p)`` tensor).

    Returns:
        Tensor of shape ``(T, p)``: the unit vectors multiplied by ``norms``.
    """
    # Loop-invariant rotation coefficients.
    cos_theta = math.cos(theta)
    sin_theta = math.sqrt(1 - cos_theta ** 2)

    vectors = [random_unit_vector(p)]

    for _ in range(1, T):
        new_vec = random_unit_vector(p)

        for v in vectors:
            # Component of new_vec orthogonal to v.
            residual = new_vec - torch.dot(v, new_vec) * v
            residual_norm = torch.linalg.norm(residual)
            # The orthogonal direction must have unit length, otherwise the
            # cos/sin mix below does not produce the angle theta. Skip the
            # rescaling only when new_vec is (numerically) parallel to v.
            if residual_norm > torch.finfo(residual.dtype).eps:
                residual = residual / residual_norm
            new_vec = cos_theta * v + sin_theta * residual
            new_vec = new_vec / torch.linalg.norm(new_vec)

        vectors.append(new_vec)

    return torch.stack(vectors) * norms

def get_true_tasks(T, s, u, theta, norms, norm2 = 1):
    """Build ``T`` ground-truth task vectors of length ``s + u*T``.

    Each task is the concatenation of
      * a length-``s`` vector taken from ``get_vectors(T, s, theta, norms)``, and
      * a length-``u*T`` vector that is zero except for the ``i``-th block of
        ``u`` entries, which holds a random unit vector scaled by ``norm2``.

    Returns:
        List of ``T`` 1-D tensors.
    """
    shared_parts = get_vectors(T, s, theta, norms)

    def build_task(idx):
        # Task-specific coordinates live in a block that no other task touches.
        private_part = torch.zeros(u*T)
        private_part[idx*u:(idx+1)*u] = random_unit_vector(u) * norm2
        return torch.cat([shared_parts[idx], private_part], dim=0)

    return [build_task(idx) for idx in range(T)]

def get_tasks(T,p,true_tasks,theta,norms, selection_method = 'random'):
    """Map the ground-truth tasks into a ``p``-dimensional model space.

    * ``'random'``: ignore ``true_tasks`` and draw fresh vectors with
      ``get_vectors(T, p, theta, norms)``.
    * true dimension < ``p``: zero-pad every task to length ``p``
      (``selection_method`` is not inspected in this case).
    * otherwise ``'crop'`` keeps the first ``p`` coordinates and
      ``'random_crop'`` keeps one random, sorted subset of ``p`` coordinates
      shared by all tasks.

    Raises:
        ValueError: unknown ``selection_method`` when cropping is required.
    """
    true_dim = len (true_tasks[0])

    if selection_method == 'random':
        return get_vectors(T, p, theta, norms)

    if true_dim < p:
        def zero_pad(task):
            padded = torch.zeros(p)
            padded[:true_dim] = task
            return padded

        return [zero_pad(task) for task in true_tasks]

    if selection_method == 'crop':
        return [task[:p] for task in true_tasks]

    if selection_method == 'random_crop':
        # One index set for all tasks, kept in ascending order.
        kept = sorted(torch.randperm(true_dim)[:p].tolist())
        return [task[kept] for task in true_tasks]

    raise ValueError('invalid selection method')


def get_X_and_Y(n, p, T, sigma, tasks):
    """Sample Gaussian features and noisy linear labels for every task.

    All ``T`` feature matrices (shape ``(p, n)``) are drawn first; labels are
    then ``task @ X_i`` plus ``sigma``-scaled standard-normal noise.

    Returns:
        ``(X, Y)``: list of feature matrices and list of label vectors.
    """
    X = []
    for _ in range(T):
        X.append(torch.randn(p, n))

    Y = []
    for i, task in enumerate(tasks):
        clean = task @ X[i]
        Y.append(clean + sigma * torch.randn(n))
    return X, Y

def get_loss(w,w_star):
    """Squared Euclidean distance between an estimate ``w`` and the target ``w_star``."""
    gap = w - w_star
    return torch.linalg.norm(gap)**2

In [None]:
def solve_eq(X,y,lambda_,w0=None):
    """Fit linear weights ``w`` with ``X.T @ w`` close to ``y``, anchored at ``w0``.

    Two regimes, chosen from the shape of ``X``:

    * ``p >= n + 1``: ``w0 + X (X^T X)^{-1} (y - X^T w0)``, i.e. ``w0`` plus a
      correction that makes ``X^T w`` equal ``y``. ``lambda_`` is not used here.
    * otherwise: the ridge solution
      ``(X X^T + lambda_ I)^{-1} (X y + lambda_ w0)``.

    Args:
        X: Feature matrix of shape ``(p, n)``.
        y: Label vector with one entry per column of ``X``.
        lambda_: Ridge strength (ridge branch only).
        w0: Anchor vector of length ``p``; defaults to zeros.

    Returns:
        Weight vector of length ``p``.

    Raises:
        ValueError: if the ridge branch fails (for example on incompatible
            shapes); the underlying error is attached as ``__cause__``.
    """
    p,n = X.shape

    if w0 is None:
        w0 = torch.zeros(p)

    if p >= n + 1:
        w = w0 + X @ (X.T @ X).inverse() @ (y - X.T @ w0)
    else:
        try:
            w = (X @ X.T + lambda_ * torch.eye(p)).inverse() @ (X @ y + lambda_ * w0)
        except Exception as err:
            # A bare ``except:`` would also swallow KeyboardInterrupt and
            # SystemExit; catch ordinary errors only and keep the root cause.
            print (X.shape, y.shape, w0.shape)
            raise ValueError('invalid shape') from err
    return w

def solve_multi_task(n, p, T, l, X_1T, Y_1T):
    """Fit one shared weight vector on the samples of all tasks pooled together.

    The per-task feature matrices and label vectors are concatenated along
    their last axis and handed to ``solve_eq`` with regularisation ``l``.
    """
    pooled_X = torch.cat(X_1T, -1)
    pooled_Y = torch.cat(Y_1T, -1)
    return solve_eq(pooled_X, pooled_Y, l)

def evaluate_multi_task(n , p, T, l, sigma, tasks):
    """Run one multi-task trial and return the task-averaged squared error.

    Fresh data is sampled, a single pooled model is fitted, and its squared
    distance to each of the first ``T`` task vectors is averaged.
    """
    X_1T, Y_1T = get_X_and_Y(n, p, T, sigma, tasks)
    shared_w = solve_multi_task(n, p, T, l, X_1T, Y_1T)

    total_error = sum(get_loss(shared_w, tasks[i]) for i in range(T))
    return total_error / T

def evaluate_multiple_multi_tasks(n, p, T, l, sigma, tasks, num_trials):
    """Average ``evaluate_multi_task`` over ``num_trials`` independent trials."""
    trial_errors = [
        evaluate_multi_task(n, p, T, l, sigma, tasks)
        for _ in range(num_trials)
    ]
    return np.mean(trial_errors)



In [None]:

    
    

def solve_single_task(n, p, T, l, X_1T, Y_1T):
    """Fit one independent model per task.

    The inputs are concatenated along the sample axis and then cut back into
    ``T`` consecutive windows of ``n`` samples; each window is solved on its
    own with ``solve_eq``.

    Returns:
        List of ``T`` weight vectors.
    """
    all_X = torch.cat(X_1T, -1)
    all_Y = torch.cat(Y_1T, -1)

    def fit_window(i):
        start, stop = i*n, (i+1)*n
        return solve_eq(all_X[:, start:stop], all_Y[start:stop], l)

    return [fit_window(i) for i in range(T)]

def evaluate_single_task(n, p, T, l, sigma, tasks):
    """Run one single-task trial and return the task-averaged squared error.

    Each task gets its own model; model ``i`` is compared with task ``i``.
    """
    X_1T, Y_1T = get_X_and_Y(n, p, T, sigma, tasks)
    per_task_w = solve_single_task(n, p, T, l, X_1T, Y_1T)

    total_error = sum(get_loss(per_task_w[i], tasks[i]) for i in range(T))
    return total_error / T

def evaluate_multiple_single_tasks(n, p, T, l, sigma, tasks, num_trials):
    """Average ``evaluate_single_task`` over ``num_trials`` independent trials."""
    trial_errors = [
        evaluate_single_task(n, p, T, l, sigma, tasks)
        for _ in range(num_trials)
    ]
    return np.mean(trial_errors)

In [None]:
def solve_sequential(n, p, T, l, X_1T, Y_1T):
    """Solve each of the first ``T`` tasks from scratch (zero anchor, no memory).

    Returns:
        List with the weight vector obtained after each task.
    """
    return [solve_eq(X_1T[i], Y_1T[i], l) for i in range(T)]

def solve_continual(n, p, T, l, X_1T, Y_1T):
    """Solve the tasks in order, anchoring each solve at the previous solution.

    The first task is anchored at the zero vector.

    Returns:
        List with the weight vector obtained after each task.
    """
    history = []
    anchor = torch.zeros(p)

    for i in range(T):
        anchor = solve_eq(X_1T[i], Y_1T[i], l, anchor)
        history.append(anchor)

    return history

def solve_mem(n, k, p, T, l, X_1T, Y_1T):
    """Solve the tasks in order with a replay memory and a fixed zero anchor.

    At step ``t`` the model is fitted on the current task's data together with
    the samples stored from earlier tasks. After each step the slice
    ``[n-k:]`` of the current task's samples is added to the memory.

    Returns:
        List with the weight vector obtained after each task.
    """
    zero_anchor = torch.zeros(p)
    mem_X, mem_Y = [], []
    history = []

    for t in range(T):
        X_t, Y_t = X_1T[t], Y_1T[t]

        # Memory first, current task last.
        train_X = torch.cat(mem_X + [X_t], dim=1)
        train_Y = torch.cat(mem_Y + [Y_t], dim=0)
        history.append(solve_eq(train_X, train_Y, l, zero_anchor))

        mem_X.append(X_t[:, n-k:])
        mem_Y.append(Y_t[n-k:])

    return history

def solve_memreg(n, k, p, T, l, X_1T, Y_1T):
    """Solve the tasks in order with a replay memory *and* a moving anchor.

    Same data handling as the memory-only variant (current task plus stored
    ``[n-k:]`` slices of earlier tasks), but every solve is anchored at the
    previous step's solution instead of at zero.

    Returns:
        List with the weight vector obtained after each task.
    """
    anchor = torch.zeros(p)
    mem_X, mem_Y = [], []
    history = []

    for t in range(T):
        X_t, Y_t = X_1T[t], Y_1T[t]

        # Memory first, current task last.
        train_X = torch.cat(mem_X + [X_t], dim=1)
        train_Y = torch.cat(mem_Y + [Y_t], dim=0)
        anchor = solve_eq(train_X, train_Y, l, anchor)

        mem_X.append(X_t[:, n-k:])
        mem_Y.append(Y_t[n-k:])
        history.append(anchor)

    return history

def predict_theory_continual(n, p, T, l, sigma, tasks):
    """Closed-form prediction matching the continual (anchor-only) learner.

    Two regimes are handled: ``n + 1 <= p`` and ``p + 1 <= n``. For ``p == n``
    neither applies and ``None`` is returned. ``l`` is not used.
    """
    if n + 1 <= p:
        r = 1 - n/p

        # Term built from the squared norms of the tasks.
        G1 = r**T/T * sum(w.norm()**2 for w in tasks)

        # Term built from pairwise task distances, weighted by r^(T - (i+1)).
        weighted_spreads = []
        for i, wi in enumerate(tasks):
            spread = sum((wk - wi).norm()**2 for wk in tasks)
            weighted_spreads.append(n * r ** (T - (i+1)) / p * spread)
        G2 = 1/T * sum(weighted_spreads)

        # Noise term; the 1e-7 keeps the denominator non-zero at p == n + 1.
        G3 = p * sigma**2 / (p - n - 1 + 1e-7) * (1 - r**T)
        return G1 + G2 + G3

    if p + 1 <= n:
        last_task = tasks[-1]
        G1 = 1/T * sum((last_task - w).norm()**2 for w in tasks)
        G2 = p * sigma**2 / (n - p - 1 + 1e-7)
        return G1 + G2

    return None

def predict_theory_memreg(n, k, p, T, l, sigma, tasks, t = None):
    """Closed-form prediction evaluated recursively over the step index ``t``.

    NOTE(review): judging by the name, this is presumably the expected
    task-averaged error of the memory + regularisation learner after step
    ``t`` -- confirm against the derivation.

    Args:
        n: Sample count used for the last entry of ``ns`` (step ``t``).
        k: Sample count used for each of the ``t - 1`` earlier entries of ``ns``.
        p: Model dimension.
        T: Total number of tasks; the ``1/T`` averages run over all of them.
        l: Not used in this function.
        sigma: Noise level; enters only as ``sigma**2``.
        tasks: Sequence of task vectors (objects exposing ``.norm()``).
        t: 1-indexed step to evaluate; ``None`` means ``T``.

    Returns:
        ``G1 + G2 + G3 + a * (G5 + G6 + G7 + G8)`` when ``nb + 1 <= p``
        (``nb`` is the sum of ``ns``), otherwise ``0``.
    """
    if t is None:
        t = T
    # assume t to be indexed from 1, i.e. t = 1, 2, ..., T
    
    # Sample counts per step up to t: k for every earlier step, n for step t.
    ns = [k] * (t-1) + [n]
    ws = tasks
    nb = sum(ns)
    if nb + 1 <= p:

        # Distances between the first t tasks and every task i, averaged over all T tasks.
        G1 = 1/T * sum ([
            sum([
                ns[s]/p * (ws[s] - ws[i]).norm()**2
            for s in range(t)])
        for i in range(T)])
        
        # Noise term; the 1e-7 keeps the denominator non-zero when p == nb + 1.
        G2 = nb * sigma**2 / (p - nb - 1 + 1e-7)
        
        # Pairwise distances among the first t tasks.
        G3 = sum ([
            sum([
                1/(2*p) * (ns[s] * ns[ss]/ (p - nb - 1 + 1e-7)) * (ws[s] - ws[ss]).norm()**2
            for s in range(t)])
        for ss in range(t)])
        
        # Weight on the bracket below; (nb - n) is the sample count of the earlier steps only.
        a = (1 - n/(p-(nb-n)))
        
        # G5, G7 and G8 are the negatives of G2, G1 and G3 restricted to the first t - 1 steps.
        G5 = - (nb - n) * sigma**2 / (p - (nb - n) - 1 + 1e-7)
        if t == 1:
            # Base case of the recursion: mean squared norm of the tasks.
            G6 = 1/T * sum([w.norm()**2 for w in tasks])
        else:
            # Recursive case: the same prediction one step earlier.
            G6 = predict_theory_memreg(n,k,p,T,l,sigma,tasks, t = t - 1)
        
        G7 = -1/T * sum ([
            sum([
                ns[s]/p * (ws[s] - ws[i]).norm()**2
            for s in range(t-1)])
        for i in range(T)])
        
        G8 = -sum ([
            sum([
                1/(2*p) * (ns[s] * ns[ss]/ (p - (nb - n) - 1 + 1e-7)) * (ws[s] - ws[ss]).norm()**2
            for s in range(t-1)])
        for ss in range(t-1)])
            
        return G1 + G2 + G3 + a * (G5 + G6 + G7 + G8)
            
    else:
        # No formula is implemented for nb + 1 > p; 0 is returned as a placeholder.
        return 0

def predict_theory_mem(n, k, p, T, l, sigma, tasks):
    """Closed-form prediction matching the memory-only learner after ``T`` tasks.

    ``ns`` holds ``k`` samples for each of the first ``T - 1`` tasks and ``n``
    for the last; ``nb`` is their total. A formula is evaluated when
    ``nb + 1 <= p``; ``0`` is returned when ``p + 1 <= nb`` and ``None`` when
    ``p == nb``. ``l`` is not used.
    """
    ns = [k] * (T-1) + [n]
    nb = sum(ns)

    if nb + 1 <= p:
        # The 1e-7 keeps the denominator non-zero when p == nb + 1.
        slack = p-nb-1+1e-7

        G1 = 1/T * (1 - nb/p) * sum(w.norm()**2 for w in tasks)

        def pair_term(t, i):
            return ns[t]/ (T*p) * (1 + (T/2*ns[i]) / slack) * (tasks[t] - tasks[i]).norm()**2

        G2 = sum(sum(pair_term(t, i) for t in range(T)) for i in range(T))

        G3 = nb * sigma**2 / (p - nb - 1 + 1e-7)

        return G1 + G2 + G3

    if p + 1 <= nb:
        return 0

    return None

def predict_loss_memreg(n, k, p, T, l, sigma, tasks, i, t):
    """Per-task counterpart of the memreg prediction, evaluated recursively in ``t``.

    Args:
        n: Sample count used for the last entry of ``ns`` (step ``t``).
        k: Sample count used for each of the ``t - 1`` earlier entries of ``ns``.
        p: Model dimension.
        T: Only forwarded to the recursive call.
        l: Only forwarded to the recursive call.
        sigma: Noise level; enters only as ``sigma**2``.
        tasks: Sequence of task vectors (objects exposing ``.norm()``).
        i: 1-indexed task whose loss is predicted (``tasks[i - 1]``).
        t: 1-indexed step at which the loss is predicted.

    Returns:
        ``G1 + G2 + G3 + a * (G5 + G6 + G7 + G8)`` when ``nb + 1 <= p``
        (``nb`` is the sum of ``ns``), otherwise ``0``.

    Raises:
        AssertionError: if ``t`` is negative.
    """
    # predicts the loss of the i-th task at time t for the memreg method
    # E[||w_t - w^*_i||^2]
    # assume t to be indexed from 1, i.e. t = 1, 2, ..., T
    # assume i also to be indexed from 1, i.e. i = 1, 2, ..., T
    if t < 0:
        assert False
        
    # Sample counts per step up to t: k for every earlier step, n for step t.
    ns = [k] * (t-1) + [n]
    ws = tasks
    nb = sum(ns)
    if nb + 1 <= p:
        # Distances between the first t tasks and the target task i.
        G1 = sum([
                ns[s]/p * (ws[s] - ws[i - 1]).norm()**2
            for s in range(t)])
        
        # Noise term; the 1e-7 keeps the denominator non-zero when p == nb + 1.
        G2 = nb * sigma**2 / (p - nb - 1 + 1e-7)
        
        # Pairwise distances among the first t tasks.
        G3 = sum ([
            sum([
                1/(2*p) * (ns[s] * ns[ss]/ (p - nb - 1 + 1e-7)) * (ws[s] - ws[ss]).norm()**2
            for s in range(t)])
        for ss in range(t)])
        
        # Weight on the bracket below; (nb - n) is the sample count of the earlier steps only.
        a = (1 - n/(p-(nb-n)))
        
        # G5, G7 and G8 are the negatives of G2, G1 and G3 restricted to the first t - 1 steps.
        G5 = - (nb - n) * sigma**2 / (p - (nb - n) - 1 + 1e-7)
        if t == 1:
            # Base case of the recursion: squared norm of the target task.
            G6 = ws[i - 1].norm()**2
        else:
            # Recursive case: the same prediction one step earlier.
            G6 = predict_loss_memreg(n,k,p,T,l,sigma,tasks, i, t = t - 1)
        
        G7 = -sum([
                ns[s]/p * (ws[s] - ws[i - 1]).norm()**2
            for s in range(t-1)])
        
        
        G8 = -sum ([
            sum([
                1/(2*p) * (ns[s] * ns[ss]/ (p - (nb - n) - 1 + 1e-7)) * (ws[s] - ws[ss]).norm()**2
            for s in range(t-1)])
        for ss in range(t-1)])
            
        return G1 + G2 + G3 + a * (G5 + G6 + G7 + G8)
            
    else:
        # No formula is implemented for nb + 1 > p; 0 is returned as a placeholder.
        return 0

def predict_loss_mem(n, k, p, T, l, sigma, tasks, i, t):
    """Predicted loss on task ``i`` at step ``t`` for the memory-only learner.

    Both ``i`` and ``t`` are 1-indexed. ``ns`` holds ``k`` samples for each of
    the ``t - 1`` earlier steps and ``n`` for step ``t``. Returns ``0`` when
    the total sample count ``nb`` does not satisfy ``nb + 1 <= p``.
    ``T`` and ``l`` are not used.
    """
    ns = [k] * (t-1) + [n]
    nb = sum(ns)

    if not (nb + 1 <= p):
        return 0

    target = tasks[i - 1]

    # Distances between the first t tasks and the target task.
    G1 = sum(ns[s]/p * (tasks[s] - target).norm()**2 for s in range(t))

    # Noise term; the 1e-7 keeps the denominator non-zero when p == nb + 1.
    G2 = nb * sigma**2 / (p - nb - 1 + 1e-7)

    # Pairwise distances among the first t tasks (single flat sum over s, ss).
    G3 = sum(
        1/(2*p) * (ns[s] * ns[ss]/ (p - nb - 1 + 1e-7)) * (tasks[s] - tasks[ss]).norm()**2
        for s in range(t) for ss in range(t)
    )

    # Contribution of the target task's own squared norm.
    G4 = (1 - nb / p) * (target).norm()**2

    return G1 + G2 + G3 + G4

def predict_forgetting_memreg(n, k, p, T, l, sigma, tasks):
    """Predicted average forgetting of the memreg learner.

    For every task ``i`` in ``1..T-1`` (1-indexed) the predicted loss right
    after learning it (step ``i``) is subtracted from the predicted loss at
    the final step ``T``; the differences are averaged over ``T - 1`` tasks.
    """
    total = sum(
        predict_loss_memreg(n, k, p, T, l, sigma, tasks, i, T)
        - predict_loss_memreg(n, k, p, T, l, sigma, tasks, i, i)
        for i in range(1, T)
    )
    return total / (T-1)

def predict_forgetting_mem(n, k, p, T, l, sigma, tasks):
    """Predicted average forgetting of the memory-only learner.

    For every task ``i`` in ``1..T-1`` (1-indexed) the predicted loss right
    after learning it (step ``i``) is subtracted from the predicted loss at
    the final step ``T``; the differences are averaged over ``T - 1`` tasks.
    """
    total = sum(
        predict_loss_mem(n, k, p, T, l, sigma, tasks, i, T)
        - predict_loss_mem(n, k, p, T, l, sigma, tasks, i, i)
        for i in range(1, T)
    )
    return total / (T-1)

def predict_theory_sequential(n, p, T, l, sigma, tasks):
    """Sequential prediction: the memory-only formula with a memory size of 0."""
    memory_size = 0
    return predict_theory_mem(n, memory_size, p, T, l, sigma, tasks)

def predict_theory_multi(n, p, T, l, sigma, tasks):
    """Multi-task prediction: the memory-only formula with a memory size of ``n``."""
    memory_size = n
    return predict_theory_mem(n, memory_size, p, T, l, sigma, tasks)

def predict_theory_single(n, p, T, l, sigma, tasks):
    """Closed-form prediction for independent single-task learning.

    * ``p + 1 <= n``: a pure noise term.
    * ``n + 1 <= p``: a term in the mean squared task norm plus a noise term.
    * ``p == n``: neither regime applies and ``None`` is returned.

    Each result is asserted to be non-negative. ``l`` is not used.
    """
    if p + 1 <= n:
        # The 1e-7 keeps the denominator non-zero when n == p + 1.
        G =  p * sigma**2 / (n - p - 1 + 1e-7)
        assert G >= 0, f'G = {G}, n = {n}, p = {p}, T = {T}, sigma = {sigma}'
        return G

    if n + 1 <= p:
        mean_sq_norm = 1/T * (1 - n/p) * sum(w.norm()**2 for w in tasks)
        G1 = mean_sq_norm
        G3 = n * sigma**2 / (p - n - 1 + 1e-7)
        assert G1 + G3 >= 0, f'G1 = {G1}, G3 = {G3}, n = {n}, p = {p}, T = {T}, sigma = {sigma}'
        return G1 + G3

    return None


def evaluate_sequential(n, k, p, T, l, sigma, method, tasks):
    """Run one trial of a sequential learner and measure error and forgetting.

    Args:
        method: One of ``"sequential"``, ``"continual"``, ``"memreg"``, ``"mem"``.

    Returns:
        ``(average_error, average_forgetting)`` where the error is the final
        model's squared distance averaged over all ``T`` tasks, and the
        forgetting is, averaged over the first ``T - 1`` tasks, the final
        model's loss on task ``i`` minus the loss of the model obtained right
        after task ``i``.

    Raises:
        ValueError: for an unknown ``method`` (data is sampled before the check).
    """
    X_1T, Y_1T = get_X_and_Y(n, p, T, sigma, tasks)

    # Lazy dispatch table: only the selected solver is actually run.
    solvers = {
        "sequential": lambda: solve_sequential(n, p, T, l, X_1T, Y_1T),
        "continual": lambda: solve_continual(n, p, T, l, X_1T, Y_1T),
        "memreg": lambda: solve_memreg(n, k, p, T, l, X_1T, Y_1T),
        "mem": lambda: solve_mem(n, k, p, T, l, X_1T, Y_1T),
    }
    if method not in solvers:
        raise ValueError("invalid method")
    ws = solvers[method]()

    final_w = ws[-1]

    average_error = sum(get_loss(final_w, tasks[i]) for i in range(T)) / T

    average_forgetting = sum(
        get_loss(final_w, tasks[i]) - get_loss(ws[i], tasks[i])
        for i in range(T-1)
    ) / (T-1)

    return average_error, average_forgetting

def evaluate_multiple_sequentials(n, k, p, T, l, sigma, method, tasks, num_trials):
    """Average error and forgetting of ``evaluate_sequential`` over ``num_trials`` trials.

    Returns:
        ``(mean_error, mean_forgetting)``.
    """
    outcomes = [
        evaluate_sequential(n, k, p, T, l, sigma, method, tasks)
        for _ in range(num_trials)
    ]
    errors = [error for error, _ in outcomes]
    forgettings = [forgetting for _, forgetting in outcomes]

    return np.mean(errors), np.mean(forgettings)



# Paper Plots

## Fig1 - Multi-task vs Single-task

In [None]:
# `adjust_plots` comes from the project-local `utils` module, which is not part of this notebook.
# NOTE(review): presumably it configures global plot styling for the figures below -- confirm in utils.
from utils import adjust_plots
adjust_plots(font_scale=1.2)

In [None]:
from tqdm.auto import tqdm

# Number of tasks.
T = 10
# Model dimensions p swept in this figure. The grid leaves out p = 50 (= n) and
# p = 500 (= T * n); the theory helpers return None when p equals the sample count.
ps = (list(range(10, 50, 1)) + list(range(51, 100, 1)) +
      list(range(100, 450, 10)) + 
      list(range(450, 500, 5)) +
      list(range(501, 550, 5)) + list(range(550, 1000, 10)) + list(range(1000, 10000, 100)) + [50000])
# Coarser log-spaced grid; not referenced again in the visible cells.
ps_ = np.int32 (10**np.linspace(np.log10(20), np.log10(10000), 10)) #list(range(20, 100, 10)) + list(range(100, 1000, 100)) + list(range(1000, 10000, 1000)) # Practical Ps
# Samples per task.
n = 50


# Angle passed to get_vectors when the ground-truth tasks are built.
theta = 7 * math.pi / 8
# Norm of the shared part of every task.
norms = 1
delta = 0.
# Regularisation strength passed to the solvers.
l = 0.

# s: length of the shared part of a task, u: length of each task-specific block
# (u = 0 here, so the true tasks have length s).
s = 10
u = 0
# NOTE(review): no seed is set before this call, so the tasks differ between runs.
true_tasks = get_true_tasks(T, s, u, theta, norms)
print (true_tasks[0].shape)


# Label-noise levels compared in the figure.
sigmas = [0, 0.3, 1]     



In [None]:
# Result containers: one list per noise level sigma, filled in the order of `ps`.
plot_values_theory_multi = {}
plot_values_practical_multi = {}
plot_values_theory_single = {}
plot_values_practical_single = {}


for sigma in sigmas:
    plot_values_theory_multi[sigma] = []
    plot_values_practical_multi[sigma] = []
    plot_values_theory_single[sigma] = []
    plot_values_practical_single[sigma] = []
    
    for p in tqdm(ps):
        # Fewer Monte-Carlo trials for large p.
        num_trials = 500 if p < 1000 else 50
        
        # Map the ground-truth tasks into the p-dimensional model space.
        tasks = get_tasks(T, p, true_tasks, theta, norms, selection_method='random_crop')

        plot_values_theory_single[sigma].append(predict_theory_single(n, p, T, l, sigma, tasks))
        
        # The multi-task theory value is only recorded for p >= T * n + 2, so this
        # list is shorter than `ps`; the plotting cell applies the same threshold.
        if p >= T * n + 2:
            plot_values_theory_multi[sigma].append(predict_theory_multi(n, p, T, l, sigma, tasks))
        
        plot_values_practical_multi[sigma].append(evaluate_multiple_multi_tasks(n, p, T, l, sigma, tasks, num_trials))
        
        plot_values_practical_single[sigma].append(evaluate_multiple_single_tasks(n, p, T, l, sigma, tasks, num_trials))

In [None]:
# Convert the sweep grid and all result lists to NumPy arrays so they can be
# masked and subtracted element-wise in the plotting cells.
ps = np.array(ps)

for sigma in sigmas:
    plot_values_theory_multi[sigma] = np.array(plot_values_theory_multi[sigma])
    plot_values_practical_multi[sigma] = np.array(plot_values_practical_multi[sigma])

    plot_values_theory_single[sigma] = np.array(plot_values_theory_single[sigma])
    plot_values_practical_single[sigma] = np.array(plot_values_practical_single[sigma])


In [None]:
def apply_interpolation(x, y, count = 100):
    """Resample the curve ``(x, y)`` onto ``count`` evenly spaced points.

    The new grid spans ``[x.min(), x.max()]``; values are obtained by linear
    interpolation with ``np.interp``.

    Returns:
        ``(x_new, y_new)``.
    """
    lower, upper = x.min(), x.max()
    grid = np.linspace(lower, upper, count)
    return grid, np.interp(grid, x, y)

In [None]:
#plot results
import matplotlib.pyplot as plt

# First three colours are used for the multi-task curves, last three for single-task.
colors = ['red', 'green', 'blue', 'purple', 'orange', 'gray']

# plt.figure(figsize=(20, 20))
# Mask of the p values for which a multi-task theory value was recorded.
# NOTE(review): this name shadows the built-in `filter`; it is reused by the next plotting cell.
filter = (ps >= (T * n + 2))



# Multi-task: solid line = theory (masked grid), dotted line = simulation.
for i, sigma in enumerate (sigmas):

    color = colors[i]
    plt.plot(np.log10 (ps)[filter], plot_values_theory_multi[sigma], label=f"multi-task, $\\sigma$ = {sigma}", color=color)
    plt.plot(np.log10 (ps), plot_values_practical_multi[sigma], linestyle=':', color=color, linewidth=2.5)

# Single-task: solid line = theory, dotted line = simulation.
for i, sigma in enumerate (sigmas):
    color = colors[i+3]
    plt.plot(np.log10 (ps), plot_values_theory_single[sigma], label=f"single-task, $\\sigma$ = {sigma}", color=color)
    plt.plot(np.log10 (ps), plot_values_practical_single[sigma], linestyle=':', color=color, linewidth=2.5)

# The x axis is log10(p); label the ticks with the raw p values.
plt.xticks([1, np.log10(50), 2, np.log10(500), 3, 4], [10,50, 100, 500, 1000,10000])


plt.ylim(-0.5, 5.5)
plt.xlabel("p")
plt.xlim(1, 4.2)
# if theta/math.pi < 0.5:

# Y label and legend are only drawn for theta < pi/2.
if theta/math.pi < 0.5:
    plt.ylabel("Average Generalization Error ($G$)")
    plt.legend()
else:
    pass

# set the legend to be top right
# plt.legend(loc='upper right')

plt.grid()

# NOTE(review): the `Figs/` directory is not created here -- confirm it exists before running.
plt.savefig(f'Figs/multivssinge_theta_{theta/math.pi:.2f}.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# plot multi-task - single task in another plot
import matplotlib.pyplot as plt

# `colors` and `filter` are taken from the previous plotting cell.
# Plotted quantity: single-task error minus multi-task error
# (solid = theory on the masked grid, dashed = simulation).
for i, sigma in enumerate (sigmas):
    color = colors[i]
    plt.plot(np.log10 (ps)[filter], plot_values_theory_single[sigma][filter] - plot_values_theory_multi[sigma], label=f"$\\sigma$ = {sigma}", color=color)
    plt.plot(np.log10 (ps), plot_values_practical_single[sigma] - plot_values_practical_multi[sigma], linestyle="--", color=color)

# The x axis is log10(p); label the ticks with the raw p values.
plt.xticks([1, np.log10(50), 2, np.log10(500), 3, 4], [10,50, 100, 500, 1000,10000])


plt.ylim(-2.5, 2.5)
plt.xlabel("$p$")
plt.xlim(1, 4.2)
# if theta/math.pi < 0.5:

plt.ylabel("Average Knowledge Transfer ($K$)")
# The legend is only drawn for theta < pi/2.
if theta/math.pi < 0.5:
    plt.legend()
else:
    pass

plt.grid()

# plt.savefig(f'Figs/kt_multivssinge_theta_{theta/math.pi:.2f}.png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:

# save_object = [ps, plot_values_theory_multi, plot_values_practical_multi, plot_values_theory_single, plot_values_practical_single]
# torch.save(save_object, f'Figs/saves/multivssinge_theta_{theta/math.pi:.2f}.pt')

# ps, plot_values_theory_multi, plot_values_practical_multi, plot_values_theory_single, plot_values_practical_single = torch.load(f'Figs/saves/multivssinge_theta_{theta/math.pi:.2f}.pt')

## Fig1 - Mem CL

In [None]:
from tqdm.auto import tqdm

# Number of tasks.
T = 10
# Model dimensions p swept in the memory experiments; the grid is denser in
# some sub-ranges than in others.
ps = (list(range(10, 50, 1)) + list(range(51, 100, 1)) +
      list(range(100, 130, 5)) +
      list(range(130, 150, 2)) +
      list(range(150, 210, 5)) +
      list(range(210, 250, 2)) +
      list(range(250, 300, 10)) +
      list(range(300, 340, 2)) +
      list(range(340, 390, 10)) +
      list(range(390, 420, 5)) +
      list(range(420, 490, 10)) +
      list(range(490, 510, 5)) +
      list(range(450, 550, 5)) +
      list(range(550, 1000, 10)) + list(range(1000, 10000, 100)) + [50000])

# The ranges above overlap (450, 460, 470, 480, 490, 495, 500 and 505 appear
# twice). Drop the duplicates while sorting so that no value of p is simulated
# twice and the x grid is strictly increasing.
ps = sorted(set(ps))
 
      
# Coarser log-spaced grid; not referenced again in the visible cells.
ps_ = np.int32 (10**np.linspace(np.log10(20), np.log10(10000), 10)) #list(range(20, 100, 10)) + list(range(100, 1000, 100)) + list(range(1000, 10000, 1000)) # Practical Ps
# Samples per task.
n = 50


# Angle passed to get_vectors when the ground-truth tasks are built.
theta = math.pi / 8
# Norm of the shared part of every task.
norms = 1
# Label-noise level (noise-free here).
sigma = 0.
# Regularisation strength passed to the solvers.
l = 0.

# s: length of the shared part of a task, u: length of each task-specific block
# (u = 0 here, so the true tasks have length s).
s = 10
u = 0
true_tasks = get_true_tasks(T, s, u, theta, norms)
print (true_tasks[0].shape)

# Memory size per task as a fraction of n (the loops use k = int(delta * n)).
deltas = [0, 0.2, 0.4, 0.6, 0.8, 1.]

In [None]:
# Result containers for the memory-only ("mem") method: one list per delta, in the order of `ps`.
plot_values_theory_mem = {}
plot_values_practical_mem = {}

plot_values_theory_mem_forgetting = {}
plot_values_practical_mem_forgetting = {}


for delta in deltas:
    # Number of samples kept in memory per task.
    k = int(delta * n)
    plot_values_theory_mem[delta] = []
    plot_values_practical_mem[delta] = []
    
    plot_values_theory_mem_forgetting[delta] = []
    plot_values_practical_mem_forgetting[delta] = []
    
    for p in tqdm(ps):
        tasks = get_tasks(T, p, true_tasks, theta, norms, selection_method='random_crop')
        # Fewer Monte-Carlo trials for large p.
        num_trials = 500 if p < 1000 else 50
        
        # Theory values are only recorded when p exceeds the total training-set size
        # (T - 1) * k + n, so these lists are shorter than `ps`; the plotting cell
        # applies the same threshold.
        if p >= (T - 1) * k + n + 1:
            plot_values_theory_mem[delta].append(predict_theory_mem(n, k, p, T, l, sigma, tasks))
            plot_values_theory_mem_forgetting[delta].append(predict_forgetting_mem(n, k, p, T, l, sigma, tasks))
            
        # NOTE(review): the tasks are regenerated here with the same arguments as above;
        # the memreg loop does not repeat this call -- confirm whether it is intentional.
        tasks = get_tasks(T, p, true_tasks, theta, norms, selection_method='random_crop')
        e, f = evaluate_multiple_sequentials(n, k, p, T, l, sigma, "mem", tasks, num_trials)
        plot_values_practical_mem[delta].append(e)
        plot_values_practical_mem_forgetting[delta].append(f)
        
           



In [None]:
# Convert the sweep grid and the "mem" result lists to NumPy arrays for plotting.
ps = np.array(ps)

for delta in deltas:
    plot_values_practical_mem[delta] = np.array(plot_values_practical_mem[delta])
    plot_values_theory_mem[delta] = np.array(plot_values_theory_mem[delta])
    
    plot_values_practical_mem_forgetting[delta] = np.array(plot_values_practical_mem_forgetting[delta])
    plot_values_theory_mem_forgetting[delta] = np.array(plot_values_theory_mem_forgetting[delta])



In [None]:
#plot results
import matplotlib.pyplot as plt

# One colour per memory size (per entry of `deltas`).
colors = ['red', 'green', 'blue', 'gray', 'orange', 'purple']

# plt.figure(figsize=(20, 20))

for i, delta in enumerate (deltas):
    # Resampled simulation curve; only used by the commented-out marker plot below.
    x_new, y_new = apply_interpolation(np.log10 (ps), plot_values_practical_mem[delta], count = 200)
    
    color = colors[i]
    label = rf'$m$ = {int(delta * n)}'    
    if delta == 0.0:
        label += ' (sequential)'
    elif delta == 1.0:
        label += ' (multi-task)'        
    
    k = int(delta * n)
    # Same threshold as in the experiment loop, so the mask length matches the theory list.
    # NOTE(review): this name shadows the built-in `filter`.
    filter = (ps >= ((T - 1) * k + n + 1))
    
    # Solid line = theory (masked grid), dotted line = simulation.
    plt.plot(np.log10 (ps)[filter], plot_values_theory_mem[delta], label = label, color=color)
    plt.plot(np.log10 (ps), plot_values_practical_mem[delta], color=color, linestyle='dotted', linewidth=2.5)
    # plt.plot(x_new, y_new, color=color, marker = 'o', markersize=3.5, linestyle = ' ')

# The x axis is log10(p); label the ticks with the raw p values.
plt.xticks([1, np.log10(50), 2, np.log10(500), 3, 4], [10,50, 100, 500, 1000,10000])

plt.xlim(1, 4.2)
plt.xlabel("p")


# set the legend to be top right
# plt.legend(loc='lower right')



# Y label, legend and y-range depend on whether theta is below pi/2.
if theta/math.pi < 0.5:
    plt.ylabel("Average Generalization Error ($G$)")
    plt.legend()
    plt.ylim(-0.1, 1.05)
else:
    plt.ylim(0.8, 4.05)
    
plt.grid()

# NOTE(review): the `Figs/` directory is not created here -- confirm it exists before running.
plt.savefig(f'Figs/mem_method_{theta/math.pi:.2f}.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
def is_sorted(a):
    """Return whether ``a`` is in non-decreasing order (equal neighbours allowed)."""
    earlier, later = a[:-1], a[1:]
    return np.all(earlier <= later)

In [None]:
# Sanity check: the masks in the plotting cells pair `ps` with the theory lists by position.
is_sorted(ps)

In [None]:

# Result containers for the memory + regularisation ("memreg") method: one list per delta.
plot_values_theory_memreg = {}
plot_values_practical_memreg = {}

plot_values_theory_memreg_forgetting = {}
plot_values_practical_memreg_forgetting = {}

deltas = [0, 0.2, 0.4, 0.6, 0.8, 1.]
for delta in deltas:
    # Number of samples kept in memory per task.
    k = int(delta * n)
    plot_values_theory_memreg[delta] = []
    plot_values_practical_memreg[delta] = []
    
    plot_values_theory_memreg_forgetting[delta] = []
    plot_values_practical_memreg_forgetting[delta] = []
    
    for p in tqdm(ps):
        tasks = get_tasks(T, p, true_tasks, theta, norms, selection_method='random_crop')
        
        # Theory values are only recorded when p exceeds the total training-set size
        # (T - 1) * k + n, so these lists are shorter than `ps`; the plotting cell
        # applies the same threshold.
        if p >= (T - 1) * k + n + 1:
            plot_values_theory_memreg[delta].append(predict_theory_memreg(n, k, p, T, l, sigma, tasks))
            plot_values_theory_memreg_forgetting[delta].append(predict_forgetting_memreg(n, k, p, T, l, sigma, tasks))
        
    
        # Fewer Monte-Carlo trials for large p.
        num_trials = 500 if p < 1000 else 50
    
        e, f = evaluate_multiple_sequentials(n, k, p, T, l, sigma, "memreg", tasks, num_trials)
        plot_values_practical_memreg[delta].append(e)
        plot_values_practical_memreg_forgetting[delta].append(f)
        
        



In [None]:
# Convert the sweep grid and the "memreg" result lists to NumPy arrays for plotting.
ps = np.array(ps)

for delta in deltas:
    plot_values_practical_memreg[delta] = np.array(plot_values_practical_memreg[delta])
    plot_values_theory_memreg[delta] = np.array(plot_values_theory_memreg[delta])
    
    plot_values_practical_memreg_forgetting[delta] = np.array(plot_values_practical_memreg_forgetting[delta])
    plot_values_theory_memreg_forgetting[delta] = np.array(plot_values_theory_memreg_forgetting[delta])



In [None]:
#plot results
import matplotlib.pyplot as plt

# One colour per memory size (per entry of `deltas`).
colors = ['red', 'green', 'blue', 'gray', 'orange', 'purple']

# plt.figure(figsize=(20, 20))

for i, delta in enumerate (deltas):
    color = colors[i]
    label = rf'$m$ = {int(delta * n)}'    
    if delta == 0.0:
        label += ' (sequential)'
    elif delta == 1.0:
        label += ' (multi-task)'        
    
    k = int(delta * n)
    # Same threshold as in the experiment loop, so the mask length matches the theory list.
    # NOTE(review): this name shadows the built-in `filter`.
    filter = (ps >= ((T - 1) * k + n + 1))
    
    # Solid line = theory (masked grid), dotted line = simulation.
    plt.plot(np.log10 (ps)[filter], plot_values_theory_memreg[delta], label = label, color=color)
    plt.plot(np.log10 (ps), plot_values_practical_memreg[delta], color=color, linestyle='dotted', linewidth=2.5)

# The x axis is log10(p); label the ticks with the raw p values.
plt.xticks([1, np.log10(50), 2, np.log10(500), 3, 4], [10,50, 100, 500, 1000,10000])

plt.xlim(1, 4.1)
plt.xlabel("$p$")

plt.ylabel("Average Generalization Error ($G$)")


# Legend and y-range depend on whether theta is below pi/2.
if theta/math.pi < 0.5:
    plt.legend()
    plt.ylim(-0.1, 1.05)
else:
    plt.ylim(0.8, 4.05)


plt.grid()

# NOTE(review): the `Figs/` directory is not created here -- confirm it exists before running.
plt.savefig(f'Figs/memreg_method_{theta/math.pi:.2f}.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
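# save_object = [ps, plot_values_theory_mem, plot_values_practical_mem, plot_values_theory_memreg, plot_values_practical_memreg, plot_values_theory_mem_forgetting, plot_values_practical_mem_forgetting, plot_values_theory_memreg_forgetting, plot_values_practical_memreg_forgetting]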
# 
# torch.save(save_object, f'Figs/saves/mem_methods_{theta/math.pi:.2f}.pt')



In [None]:

# ps, plot_values_theory_mem, plot_values_practical_mem, plot_values_theory_memreg, plot_values_practical_memreg, plot_values_theory_mem_forgetting, plot_values_practical_mem_forgetting, plot_values_theory_memreg_forgetting, plot_values_practical_memreg_forgetting = torch.load(f'Figs/saves/mem_methods_{theta/math.pi:.2f}.pt')
