In [1]:
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

In [2]:
def uniform_hypersphere(d):
    """Draw one point uniformly from the unit (d-1)-sphere in R^d.

    A standard Gaussian vector is isotropic, so dividing it by its Euclidean
    norm yields a uniformly distributed direction. Uses NumPy's global RNG.
    """
    gaussian = np.random.randn(d)
    return gaussian / np.linalg.norm(gaussian)

In [3]:
def y_sample(D, c, X):
    """Return noise-free task targets y = X (D c).

    D @ c maps the task coefficients into the ambient space; X then produces
    one target per row.
    """
    weights = D @ c
    return X @ weights
def X_sample(n, d, distribution_sample):
    """Stack n independent draws of distribution_sample(d) as rows of an array.

    distribution_sample is called exactly n times, once per row, in row order.
    """
    rows = [distribution_sample(d) for _ in range(n)]
    return np.array(rows)

def c_sample(k, distribution_sample):
    """Return one k-dimensional task coefficient vector from distribution_sample(k)."""
    coefficients = distribution_sample(k)
    return coefficients

In [4]:
def task_c_loss_gradient_without_query(X, y, D, lambd=0.):
    """Residual loss of the ridge-solved task weights and its gradient w.r.t. D.

    NOTE: a later cell redefines this same function name; once that cell has
    run, it shadows this version.

    The inner problem is solved in closed form,
        c*(D) = (D^T X^T X D + lambd I)^{-1} D^T X^T y,
    and the function returns f(D) = ||X D c*(D) - y||^2 together with df/dD.
    The gradient differentiates *through* c*(D), not just at fixed c.

    Args:
        X: design matrix, (n, d) as produced by X_sample.
        y: targets, 1-D of length n as produced by y_sample (assumed 1-D).
        D: dictionary, (d, k).
        lambd: ridge regularizer for the inner solve.

    Returns:
        (c_solved, functionValue, gradient): shapes (k,), scalar, (d, k).
    """
    k = D.shape[1]
    A = X @ D
    # S is symmetric (inverse of a symmetric matrix); it is reused below.
    S = np.linalg.inv(A.T @ A + lambd * np.eye(k))
    XTy = X.T @ y
    c_solved = S @ (D.T @ XTy)
    pred = A @ c_solved

    residual = pred - y          # r = X D c* - y
    XTr = X.T @ residual         # X^T r
    S_v = S @ (D.T @ XTr)        # S D^T X^T r

    # Objective is the squared residual (previously an undefined name `t_3`).
    functionValue = np.linalg.norm(residual) ** 2
    gradient = (
        2 * np.outer(XTr, c_solved)                  # direct dependence on D
        - 2 * np.outer(X.T @ pred, S_v)              # was pred.T @ X.T: shape error unless n == d
        - 2 * np.outer(X.T @ (A @ S_v), c_solved)    # via dS (D on the right)
        + 2 * np.outer(XTy, S_v)                     # via D^T X^T y in c*
    )
    return c_solved, functionValue, gradient

In [5]:
def task_c_loss_gradient_without_query(X, y, D, lambd=0.):
    """Optimal ridge value for one task and its gradient w.r.t. the dictionary D.

    With A = X D, the inner problem
        c* = argmin_c ||A c - y||^2 + lambd ||c||^2
    is solved through its normal equations (A^T A + lambd I) c* = A^T y.

    Args:
        X: design matrix, (n, d) as produced by X_sample.
        y: targets, length-n vector as produced by y_sample.
        D: dictionary, (d, k).
        lambd: ridge regularizer. With lambd == 0 and k > n, A^T A is singular
            and np.linalg raises LinAlgError.

    Returns:
        c_solved: c*, shape (k,).
        loss: the optimal value ||A c* - y||^2 + lambd ||c*||^2, which simplifies
            to y^T y - y^T A c* at the optimum.
        gradient: d(loss)/dD, shape (d, k).
    """
    A = X @ D
    k = D.shape[1]
    XTy = X.T @ y
    DTXTy = D.T @ XTy  # equals A^T y
    # Solve the normal equations directly: cheaper and more accurate than
    # forming the explicit inverse only to multiply it by a vector.
    c_solved = np.linalg.solve(A.T @ A + lambd * np.eye(k), DTXTy)
    loss = np.sum(np.square(y)) - np.inner(DTXTy, c_solved)
    # Envelope theorem: at the inner optimum, d(loss)/dD equals the gradient of
    # ||X D c - y||^2 with c held fixed at c_solved, i.e. 2 X^T (A c - y) c^T.
    gradient = 2 * np.outer(X.T @ (A @ c_solved - y), c_solved)
    return c_solved, loss, gradient

In [6]:
def train(D_true, D_init, X_pool, c_pool, k_actual, lambd, lr, num_iterations_train, test_frequency,
          num_test_tasks=2000):
    """Meta-learn the dictionary D with SGD over tasks sampled from the given pools.

    Each iteration draws one X from X_pool and one c from c_pool. It builds noise-free
    targets with D_true and takes a gradient step on D using
    task_c_loss_gradient_without_query.

    Args:
        D_true: ground-truth dictionary used to generate targets, (d, k_actual).
        D_init: initial dictionary; never modified in place.
        X_pool: sequence of design matrices, each (n, d) as built by X_sample.
        c_pool: sequence of task coefficient vectors of length k_actual.
        k_actual: dimension of freshly sampled test-task coefficients.
        lambd: inner-loop ridge regularizer.
        lr: SGD step size for D (previously ignored in favour of a hard-coded 0.05).
        num_iterations_train: number of SGD steps.
        test_frequency: evaluate on fresh tasks whenever iteration % test_frequency == 0.
        num_test_tasks: number of freshly sampled tasks averaged per evaluation.

    Returns:
        avg_train_loss: running mean of the training loss after every iteration.
        avg_test_loss: mean loss over num_test_tasks fresh tasks. Recorded at
            iterations 0, test_frequency, 2 * test_frequency, ...
    """
    # Derive test-task dimensions from the inputs instead of notebook globals n, d.
    d = D_true.shape[0]
    n = X_pool[0].shape[0]  # assumes every matrix in X_pool has the same number of rows
    D = D_init
    avg_train_loss = []
    avg_test_loss = []
    sum_train_loss = 0.0

    for iteration in tqdm(range(num_iterations_train)):
        X = X_pool[np.random.choice(len(X_pool))]
        c = c_pool[np.random.choice(len(c_pool))]
        y = y_sample(D=D_true, c=c, X=X)
        _, loss, gradient = task_c_loss_gradient_without_query(X, y, D, lambd)
        # Update the dictionary. This rebinds D (no in-place update), so D_init is untouched.
        D = D - lr * gradient
        sum_train_loss += loss
        avg_train_loss.append(sum_train_loss / (iteration + 1))

        if iteration % test_frequency == 0:  # test evaluation
            sum_test_loss = 0.0
            for _ in range(num_test_tasks):
                X_test = X_sample(n=n, d=d, distribution_sample=uniform_hypersphere)
                c_test = c_sample(k_actual, distribution_sample=uniform_hypersphere)
                y_test = y_sample(D=D_true, c=c_test, X=X_test)
                _, test_loss, _ = task_c_loss_gradient_without_query(X_test, y_test, D, lambd)
                sum_test_loss += test_loss
            # Divide by the number of tasks actually summed (previously 2000 summed / 1000).
            avg_test_loss.append(sum_test_loss / num_test_tasks)

    return avg_train_loss, avg_test_loss
        


In [None]:
########### THIS IS A TYPICAL RUN ###########

# Seed the global NumPy RNG: all sampling in this notebook (data pools, D_true,
# D_init, task draws inside train) goes through np.random, so this makes a
# Restart & Run All reproducible.
SEED = 0
np.random.seed(SEED)

# Set parameters
n = 5  # number of examples per task
d = 10  # ambient dimension
k = 100  # subspace dimension we learn (used in prediction)
k_actual = 2  # actual subspace dimension of the weights (used in data generation)
T = 100  # total number of tasks
M = 10000  # total number of X matrices of shape n \times d
m = 1  # total number of fix-ML's matrices X
lr = 0.1  # learning rate for learning D
num_iterations_train = 1000  # number of training iterations
test_frequency = 10  # number of iterations between test evaluations
lambd = 10.  # inner-loop ridge regularizer

# Get dataset
X_pool = [X_sample(n=n, d=d, distribution_sample=uniform_hypersphere) for _ in range(M)]
X_fix_pool = X_pool[:m]
c_pool = [c_sample(k_actual, distribution_sample=uniform_hypersphere) for _ in range(T)]

# Obtain meta-learning and fix-meta-learning solutions (both runs start from the same D_init)
D_true = np.random.randn(d, k_actual)
D_init = np.random.randn(d, k)
metal_avg_train_loss, metal_avg_test_loss = train(D_true, D_init, X_pool, c_pool, k_actual, lambd, lr, num_iterations_train, test_frequency)
fix_metal_avg_train_loss, fix_metal_avg_test_loss = train(D_true, D_init, X_fix_pool, c_pool, k_actual, lambd, lr, num_iterations_train, test_frequency)

# Plots. X-values are derived from the recorded histories so the axes always
# match the data. Test losses are recorded at iterations 0, test_frequency,
# 2 * test_frequency, ..., which is ceil(num_iterations_train / test_frequency)
# points and not necessarily num_iterations_train // test_frequency.
fig, ax = plt.subplots()
ax.plot(np.arange(len(metal_avg_train_loss)), metal_avg_train_loss, label='ML')
ax.plot(np.arange(len(fix_metal_avg_train_loss)), fix_metal_avg_train_loss, label='FIX-ML')
ax.set(title='train', xlabel='iteration', ylabel='running average train loss')
ax.legend()
plt.show()

fig, ax = plt.subplots()
ax.plot(test_frequency * np.arange(len(metal_avg_test_loss)), metal_avg_test_loss, label='ML')
ax.plot(test_frequency * np.arange(len(fix_metal_avg_test_loss)), fix_metal_avg_test_loss, label='FIX-ML')
ax.set(title='test', xlabel='iteration', ylabel='average test loss')
ax.legend()
plt.show()

 52%|█████▏    | 521/1000 [01:25<01:16,  6.22it/s]