# Meta-Learning on 1-D Functions (Coding Part)

In this coding assignment, you will implement MAML for both regression and classification.

## Preparations

In [None]:
#@title Install Packages

!pip install higher

%load_ext autoreload
%autoreload 2

In [None]:
#@title Import Packages
import numpy as np
import matplotlib.pyplot as plt
import torch
import higher
import sys
import numpy as np
import torch
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.linear_model import LogisticRegression

import ipywidgets as widgets

%matplotlib inline

In [None]:
#@title Utility Functions

###############################################
#### ndarray/Tensor manipulation functions ####
def to_numpy(x):
    """Convert a torch tensor to a NumPy array.

    The tensor is detached from the autograd graph and moved to host memory
    before conversion, so tensors that require grad or live on an accelerator
    convert without raising. For a CPU tensor ``.cpu()`` is a no-op.

    Args:
        x: A ``torch.Tensor``.

    Returns:
        A ``numpy.ndarray`` with the tensor's values.
    """
    return x.detach().cpu().numpy()


def to_tensor(x):
    """Build a torch tensor from array-like ``x`` via ``torch.tensor``.

    Args:
        x: Data accepted by ``torch.tensor`` (e.g. a NumPy array or a list).

    Returns:
        The resulting ``torch.Tensor``.
    """
    tensor = torch.tensor(x)
    return tensor


def print_dict_initialization(params_dict, dict_name):
    """Print one ``key=dict_name['key']`` assignment line per dictionary key.

    The output is meant to be pasted as code that unpacks a parameter
    dictionary into local variables, so the quotes around the key must match.

    Args:
        params_dict: Mapping whose keys are printed (values are not used).
        dict_name: Name of the dictionary variable to reference in the output.
    """
    for key in params_dict:
        print(f"{key}={dict_name}['{key}']")


def my_sign_tensor(x):
    """Return the sign of each entry of tensor ``x``, mapping zeros to +1.

    Args:
        x: A ``torch.Tensor``.

    Returns:
        An int tensor of the same shape with entries in {-1, +1}.
    """
    signs = torch.sign(x)
    # torch.sign yields 0 for zero entries; treat those as positive.
    signs = torch.where(signs == 0, torch.ones_like(signs), signs)
    return signs.int()


def my_sign_numpy(x):
    """Return the sign of each entry of ``x``, mapping zeros to +1.

    Uses ``np.where`` instead of item assignment so that scalar inputs
    (for which ``np.sign`` returns a NumPy scalar that cannot be assigned
    into) are handled as well as arrays.

    Args:
        x: A NumPy array, array-like, or scalar.

    Returns:
        An integer NumPy array (0-d for scalar input) with entries in
        {-1, +1}.
    """
    signs = np.sign(x)
    # np.sign yields 0 for zero entries; treat those as positive.
    return np.where(signs == 0, 1, signs).astype('int')


# Preferred entry point: dispatches on the container type of the input.
def my_sign(x):
    """Return the sign of ``x`` with zeros mapped to +1.

    Torch tensors are handled by ``my_sign_tensor``; everything else is
    passed to ``my_sign_numpy``.
    """
    sign_fn = my_sign_tensor if torch.is_tensor(x) else my_sign_numpy
    return sign_fn(x)


###################################
#### Data generation functions ####
def add_label_noise(y, noise_prob):
    """Flip the sign of each entry of ``y`` independently with probability ``noise_prob``.

    Args:
        y: NumPy array of labels (its ``shape`` is used to draw the flips).
        noise_prob: Probability of multiplying an entry by -1.

    Returns:
        A new array; ``y`` itself is not modified.
    """
    # -1 flips a label, +1 leaves it untouched.
    flips = np.random.choice([-1, 1], p=[noise_prob, 1 - noise_prob], size=y.shape)
    return y * flips


def featurize_fourier(x, d, normalize=False):
    """Build the Fourier feature matrix for sample points ``x``.

    Column 0 is the constant 1. For each frequency r = 1, ..., (d - 1) / 2,
    column 2r - 1 holds sin(r * pi * x) and column 2r holds cos(r * pi * x).

    Args:
        x: 1-D NumPy array of sample locations.
        d: Number of features; must be odd.
        normalize: If true, the constant column is scaled by 1/sqrt(2) and
            then the whole matrix by sqrt(2).

    Returns:
        A ``(len(x), d)`` float array.
    """
    assert (d - 1) % 2 == 0, "d must be odd"
    num_freqs = int((d - 1) / 2)
    design = np.zeros((len(x), d))
    design[:, 0] = 1
    for freq in range(1, num_freqs + 1):
        angle = freq * x * np.pi
        design[:, 2 * freq - 1] = np.sin(angle)
        design[:, 2 * freq] = np.cos(angle)
    if normalize:
        design[:, 0] *= (1 / np.sqrt(2))
        design *= np.sqrt(2)
    return design


def featurize(x, d, phi_type, normalize=True):
    """Featurize ``x`` with the feature family named by ``phi_type``.

    Args:
        x: Sample locations, forwarded to the featurizer.
        d: Number of features.
        phi_type: Name of the feature family; only ``'fourier'`` is registered.
        normalize: Forwarded to the featurizer.

    Raises:
        KeyError: If ``phi_type`` is not a registered feature family.
    """
    # Only the Fourier basis is wired up; a polynomial (Vandermonde)
    # featurizer is not registered.
    featurizers = {'fourier': featurize_fourier}
    featurizer = featurizers[phi_type]
    return featurizer(x, d, normalize)


def generate_x(n, x_type, x_low=-1, x_high=1):
    """Generate ``n`` sample locations in ``[x_low, x_high)``.

    Args:
        n: Number of samples.
        x_type: ``'grid'`` for evenly spaced points (right endpoint excluded)
            or ``'uniform_random'`` for uniformly drawn points.
        x_low: Lower end of the interval.
        x_high: Upper end of the interval.

    Returns:
        A 1-D float64 NumPy array of length ``n``, in ascending order.

    Raises:
        ValueError: If ``x_type`` is not one of the supported values.
    """
    if x_type == 'grid':
        return np.linspace(x_low, x_high, n, endpoint=False).astype(np.float64)
    if x_type == 'uniform_random':
        # Sorted in ascending order so the samples are easy to plot.
        return np.sort(np.random.uniform(x_low, x_high, n).astype(np.float64))
    raise ValueError(
        f"Unknown x_type {x_type!r}; expected 'grid' or 'uniform_random'")


def generate_y(features, k_idx, k_val):
    """Form targets as a linear combination of selected feature columns.

    Args:
        features: 2-D feature matrix (one row per sample).
        k_idx: Indices of the feature columns to combine.
        k_val: Coefficients, one per selected column.

    Returns:
        1-D array with one target per row of ``features``.
    """
    selected = features[:, k_idx]
    return (selected * k_val).sum(axis=1)


##########################################
#### Closed-form regression functions ####
def solve_ls(phi, y, weights=None):
    """Fit least squares on column-weighted features.

    The features are scaled column-wise by ``weights``, an intercept-free
    sklearn ``LinearRegression`` is fit on them, and the coefficients are
    multiplied by ``weights`` again so they apply to the unweighted ``phi``.

    Args:
        phi: Feature matrix with one column per feature.
        y: Regression targets.
        weights: Per-feature weights; defaults to all ones.

    Returns:
        ``(alpha, loss)``: the coefficients for unweighted features and the
        mean squared training error.
    """
    if weights is None:
        weights = np.ones(phi.shape[1])
    regressor = LinearRegression(fit_intercept=False)
    regressor.fit(weights * phi, y)
    # Fold the weights back in so alpha multiplies the raw features.
    alpha = regressor.coef_ * weights
    residual = y - phi @ alpha.T
    loss = np.mean(residual**2)

    return alpha.T, loss


def solve_ridge(phi, y, lambda_ridge, weights=None):
    """Fit ridge regression on column-weighted features.

    The features are scaled column-wise by ``weights``, an intercept-free
    sklearn ``Ridge`` model is fit on them, and the coefficients are
    multiplied by ``weights`` again so they apply to the unweighted ``phi``.

    Args:
        phi: Feature matrix with one column per feature.
        y: Regression targets.
        lambda_ridge: Regularization strength passed to ``Ridge`` as ``alpha``.
        weights: Per-feature weights; defaults to all ones.

    Returns:
        ``(alpha, loss)``: the coefficients for unweighted features and the
        mean squared training error plus ``lambda_ridge`` times the squared
        norm of the coefficients in the weighted feature space.
    """
    if weights is None:
        weights = np.ones(phi.shape[1])

    ridge = Ridge(fit_intercept=False, alpha=lambda_ridge)
    ridge.fit(weights * phi, y)
    coeffs_weighted = ridge.coef_
    # Fold the weights back in so alpha multiplies the raw features.
    alpha = coeffs_weighted * weights
    data_term = np.mean((y - phi @ alpha.T)**2)
    penalty = lambda_ridge * np.sum((coeffs_weighted)**2)
    loss = data_term + penalty
    return alpha, loss


############################################
#### sklearn logistic regression solver ####
def solve_logistic(phi, z, weights=None):
    """Fit logistic regression on column-weighted features.

    The features are scaled column-wise by ``weights``, an intercept-free
    sklearn ``LogisticRegression`` (lbfgs, ``C=1e6``) is fit on them, and the
    coefficients are multiplied by ``weights`` again so they apply to the
    unweighted ``phi``.

    Args:
        phi: Feature matrix with one column per feature.
        z: Class labels, compared against the sign of the predictions.
        weights: Per-feature weights; defaults to all ones.

    Returns:
        ``(alpha.T, loss)``: the transposed coefficients for unweighted
        features and the fraction of training labels that differ from the
        predicted signs.
    """
    if weights is None:
        weights = np.ones(phi.shape[1])
    classifier = LogisticRegression(
        tol=1e-4, verbose=False, solver='lbfgs', random_state=0,
        fit_intercept=False, C=1e6)
    classifier.fit(weights * phi, z)

    # Fold the weights back in so alpha multiplies the raw features.
    alpha = classifier.coef_ * weights

    # phi @ alpha.T is 2-D here; column 0 holds the decision values.
    z_pred = my_sign(phi @ alpha.T)[:, 0]
    loss = np.mean(z != z_pred)
    return alpha.T, loss


#######################################
#### Model class for use with MAML ####
class DummyModel(torch.nn.Module):
    """Linear model on reweighted features, used as the MAML inner model.

    Holds two float64 parameter vectors of length ``d``:
    ``feature_weights`` (initialized to ones) and ``coeffs`` (initialized to
    zeros).
    """

    def __init__(self, d):
        super().__init__()
        self.feature_weights = torch.nn.Parameter(torch.ones(d, dtype=torch.float64))
        self.coeffs = torch.nn.Parameter(torch.zeros(d, dtype=torch.float64))

    def forward(self, F):
        """Return ``(F * feature_weights) @ coeffs`` for feature matrix ``F``."""
        reweighted = F * self.feature_weights
        return reweighted @ self.coeffs


##############################
#### ipywidget generators ####
def generate_int_widget(desc, min_, val, max_, step=1):
    """Create a horizontal integer slider.

    Args:
        desc: Text shown next to the slider.
        min_: Smallest selectable value.
        val: Initial value.
        max_: Largest selectable value.
        step: Increment between selectable values.

    Returns:
        A ``widgets.IntSlider`` that is enabled, shows its value as an
        integer, and does not update continuously.
    """
    slider_options = dict(
        value=val,
        min=min_,
        max=max_,
        step=step,
        description=desc,
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='d',
    )
    return widgets.IntSlider(**slider_options)


def generate_float_widget(desc, min_, val, max_, step):
    """Create a horizontal float slider.

    Args:
        desc: Text shown next to the slider.
        min_: Smallest selectable value.
        val: Initial value.
        max_: Largest selectable value.
        step: Increment between selectable values.

    Returns:
        A ``widgets.FloatSlider`` that is enabled, shows its value with one
        decimal place, and does not update continuously.
    """
    slider_options = dict(
        value=val,
        min=min_,
        max=max_,
        step=step,
        description=desc,
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='.1f',
    )
    return widgets.FloatSlider(**slider_options)


def generate_floatlog_widget(desc, min_, step, val, max_, base=10):
    """Create a logarithmic float slider.

    Note the parameter order: ``step`` comes before ``val``.

    Args:
        desc: Text shown next to the slider.
        min_: Passed as the slider's ``min``.
        step: Passed as the slider's ``step``.
        val: Initial value.
        max_: Passed as the slider's ``max``.
        base: Base of the logarithmic scale.

    Returns:
        A ``widgets.FloatLogSlider``.
    """
    slider_options = {
        'value': val,
        'base': base,
        'min': min_,    # smallest exponent of base
        'max': max_,    # largest exponent of base
        'step': step,   # exponent step
        'description': desc,
    }
    return widgets.FloatLogSlider(**slider_options)

In [None]:
#@title Visualization Utilities

def visualize_prediction_classification(data_dict, feature_weights, n_train_post):
    """Fit a classifier on one task and plot its predictions and the feature weights.

    Uses the first task (index 0) stored in ``data_dict``: a logistic
    classifier is fit on the sign labels of the ``n_train_post``-sample
    training set, its coefficient vector is scaled to unit norm, and the
    result is plotted against the test data.

    Args:
        data_dict: Dictionary with ``'train'`` (keyed by training-set size)
            and ``'test'`` entries, each holding ``'x'``, ``'features'`` and
            a list ``'y'`` of per-task targets.
        feature_weights: Per-feature weights passed to ``solve_logistic`` and
            plotted (by absolute value) in the right panel.
        n_train_post: Training-set size selecting the entry of
            ``data_dict['train']``.
    """
    train_split = data_dict['train'][n_train_post]
    test_split = data_dict['test']

    # Training side: sign labels of the first task and the fitted classifier.
    x_train = train_split['x']
    phi_train = train_split['features']
    z_train = my_sign(train_split['y'][0])
    w_post, _ = solve_logistic(phi_train, z_train, feature_weights)

    # Only the direction of w_post matters for the sign predictions.
    w_post /= np.linalg.norm(w_post)
    z_train_pred = my_sign(phi_train @ w_post)

    # Test side: true function values and the classifier's raw scores.
    x_test = test_split['x']
    y_test = test_split['y'][0]
    y_test_pred = test_split['features'] @ w_post

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=[12, 5])

    # Left panel: training labels, predicted/true functions and their signs.
    ax1.scatter(x_train, z_train, marker='o', s=100, color='red', alpha=0.5,
                label='Training points true')
    ax1.scatter(x_train, z_train_pred, marker='x', s=100, color='green', alpha=0.5,
                label='Training points predictions')
    ax1.plot(x_test, y_test_pred, '-', color='blue', alpha=0.4, label='Predicted function')
    ax1.scatter(x_test, my_sign(y_test_pred), s=10, color='green',
                label='Predicted Sign Labels', alpha=0.8)
    ax1.scatter(x_test, my_sign(y_test), s=10, color='orange',
                label='True Sign Labels', alpha=0.8)
    ax1.plot(x_test, y_test, '-', color='brown', label='True function', alpha=0.4)
    ax1.set_xlabel('x')
    ax1.set_title('n_train_post =' + str(n_train_post))
    ax1.legend()

    # Right panel: magnitude of each feature weight.
    ax2.plot(np.abs(feature_weights), 'o-')
    ax2.set_title('Feature weights')
    ax2.set_xlabel('Feature #')
    ax2.set_ylabel('abs(feature_weight)')
    plt.show()


def visualize_test_loss_reg(
        iteration, n_train_inner, n_train_post_range, avg_test_loss, top_10_loss, bot_10_loss,
        init_n_train_post_range=None, init_avg_test_loss=None, init_top_10_loss=None, init_bot_10_loss=None,
        oracle_n_train_post_range=None, oracle_avg_test_loss=None, oracle_top_10_loss=None, oracle_bot_10_loss=None,
        zero_avg_loss=None, zero_top_10_loss=None, zero_bot_10_loss=None,
        wrong_n_train_post_range=None, wrong_avg_test_loss=None, wrong_top_10_loss=None, wrong_bot_10_loss=None,
        noise_std=None):
    """Plot test MSE against the number of post-training samples.

    Two log-scale panels are drawn: the full range on the left and, on the
    right, the same curves restricted to ``n_train_post <= 4 * n_train_inner``.
    Each curve is an average with a shaded band between its ``bot_10`` and
    ``top_10`` values.

    Args:
        iteration: Iteration number for the main curve's label; ``None``
            labels it 'Meta-learned'.
        n_train_inner: Inner-loop training size; sets the zoom limit.
        n_train_post_range, avg_test_loss, top_10_loss, bot_10_loss: NumPy
            arrays describing the main curve.
        init_*, oracle_*, wrong_*: Optional baseline curves ('Init',
            'Oracle', 'Wrong weights'); a baseline is drawn when its
            ``*_avg_test_loss`` is not ``None``.
        zero_avg_loss, zero_top_10_loss, zero_bot_10_loss: Optional scalar
            levels drawn as a horizontal 'Zero' band over the main range.
        noise_std: If not ``None``, ``noise_std**2`` is drawn as a dashed
            'Noise variance' line.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=[12, 5])

    main_label = 'Meta-learned' if iteration is None else 'Iteration: ' + str(iteration)
    zoom_limit = 4 * n_train_inner

    main_curve = (n_train_post_range, avg_test_loss, top_10_loss, bot_10_loss)
    baselines = [
        ((init_n_train_post_range, init_avg_test_loss, init_top_10_loss, init_bot_10_loss),
         'orange', 'Init'),
        ((oracle_n_train_post_range, oracle_avg_test_loss, oracle_top_10_loss, oracle_bot_10_loss),
         'green', 'Oracle'),
        ((wrong_n_train_post_range, wrong_avg_test_loss, wrong_top_10_loss, wrong_bot_10_loss),
         'red', 'Wrong weights'),
    ]

    def crop(xs, *series):
        # Keep only the points inside the zoom window.
        keep = np.where(xs <= zoom_limit)[0]
        return (xs[keep],) + tuple(s[keep] for s in series)

    def draw_panel(ax, title, zoomed):
        xs, avg, top, bot = crop(*main_curve) if zoomed else main_curve

        # Main curve: no explicit color, so it takes the axes' default.
        ax.fill_between(xs, bot, top, alpha=0.5)
        ax.plot(xs, avg, 'o-', label=main_label)

        ax.set_yscale('log')
        ax.set_ylabel('Test mse')
        ax.set_xlabel('n_train_post')
        ax.set_title(title)

        for curve, color, label in baselines:
            if curve[1] is None:
                continue
            b_xs, b_avg, b_top, b_bot = crop(*curve) if zoomed else curve
            ax.fill_between(b_xs, b_bot, b_top, alpha=0.5, color=color)
            ax.plot(b_xs, b_avg, 'o-', c=color, label=label)

        # Constant reference levels are drawn over the main curve's x range.
        if zero_avg_loss is not None:
            flat = np.ones(len(xs))
            ax.fill_between(xs, zero_bot_10_loss * flat, zero_top_10_loss * flat,
                            alpha=0.5, color='yellow')
            ax.plot(xs, flat * zero_avg_loss, '-', c='yellow', label='Zero')

        if noise_std is not None:
            ax.plot(xs, np.ones_like(xs) * (noise_std**2), '--', c='black',
                    label='Noise variance')

        ax.legend()

    draw_panel(ax1, 'Test mse vs n_train_post', zoomed=False)
    draw_panel(ax2, 'Test mse vs n_train_post (zoomed)', zoomed=True)

    plt.show()


def visualize_prediction_reg(data_dict, feature_weights, n_train_post):
    """Fit least squares on one task and plot its predictions and the feature weights.

    Uses the first task (index 0) stored in ``data_dict``: weighted least
    squares is fit on the ``n_train_post``-sample training set and the fitted
    function is plotted against the true one.

    Args:
        data_dict: Dictionary with ``'train'`` (keyed by training-set size)
            and ``'test'`` entries, each holding ``'x'``, ``'features'`` and
            a list ``'y'`` of per-task targets.
        feature_weights: Per-feature weights passed to ``solve_ls`` and
            plotted (by absolute value) in the right panel.
        n_train_post: Training-set size selecting the entry of
            ``data_dict['train']``.
    """
    train_split = data_dict['train'][n_train_post]
    test_split = data_dict['test']

    x_train = train_split['x']
    phi_train = train_split['features']
    y_train = train_split['y'][0]

    w_post, _ = solve_ls(phi_train, y_train, feature_weights)
    y_train_pred = phi_train @ w_post
    y_test_pred = test_split['features'] @ w_post

    # For plotting, merge the training points into the test points and
    # sort everything by x so the curves are drawn left to right.
    x_all = np.concatenate([x_train, test_split['x']])
    order = np.argsort(x_all)
    x_all = x_all[order]
    y_all = np.concatenate([y_train, test_split['y'][0]])[order]
    y_all_pred = np.concatenate([y_train_pred, y_test_pred])[order]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=[12, 5])

    # Left panel: training points plus predicted and true functions.
    ax1.scatter(x_train, y_train, marker='o', color='red', alpha=0.5,
                label='Training points true')
    ax1.scatter(x_train, y_train_pred, marker='x', color='green', alpha=0.5,
                label='Training points predictions')
    ax1.plot(x_all, y_all_pred, label='Predicted function')
    ax1.plot(x_all, y_all, label='True function')
    ax1.set_xlabel('x')
    ax1.set_title('n_train_post =' + str(n_train_post))
    ax1.legend()

    # Right panel: magnitude of each feature weight.
    ax2.plot(np.abs(feature_weights), 'o-')
    ax2.set_title('Feature weights')
    ax2.set_xlabel('Feature #')
    ax2.set_ylabel('abs(feature_weight)')
    plt.show()




def visualize_test_loss_classification(
        iteration, n_train_inner, n_train_post_range, avg_test_loss, top_10_loss, bot_10_loss,
        init_n_train_post_range=None, init_avg_test_loss=None, init_top_10_loss=None, init_bot_10_loss=None,
        oracle_n_train_post_range=None, oracle_avg_test_loss=None, oracle_top_10_loss=None, oracle_bot_10_loss=None,
        zero_avg_loss=None, zero_top_10_loss=None, zero_bot_10_loss=None, noise_prob=None):
    """Plot test classification error against the number of post-training samples.

    Two log-scale panels are drawn: the full range on the left and, on the
    right, the same curves restricted to ``n_train_post <= 2 * n_train_inner``.
    Each curve is an average with a shaded band between its ``bot_10`` and
    ``top_10`` values.

    Args:
        iteration: Iteration number for the main curve's label; ``None``
            labels it 'Meta-learned'.
        n_train_inner: Inner-loop training size; sets the zoom limit.
        n_train_post_range, avg_test_loss, top_10_loss, bot_10_loss: NumPy
            arrays describing the main curve.
        init_*, oracle_*: Optional baseline curves ('Init', 'Oracle'); a
            baseline is drawn when its ``*_avg_test_loss`` is not ``None``.
        zero_avg_loss, zero_top_10_loss, zero_bot_10_loss: Optional scalar
            levels drawn as a horizontal 'Zero' band over the main range.
        noise_prob: If not ``None`` and nonzero, drawn as a dashed
            'Noise probability' line in both panels.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=[12, 5])

    main_label = 'Meta-learned' if iteration is None else 'Iteration: ' + str(iteration)
    zoom_limit = 2 * n_train_inner

    main_curve = (n_train_post_range, avg_test_loss, top_10_loss, bot_10_loss)
    baselines = [
        ((init_n_train_post_range, init_avg_test_loss, init_top_10_loss, init_bot_10_loss),
         'orange', 'Init'),
        ((oracle_n_train_post_range, oracle_avg_test_loss, oracle_top_10_loss, oracle_bot_10_loss),
         'green', 'Oracle'),
    ]

    def crop(xs, *series):
        # Keep only the points inside the zoom window.
        keep = np.where(xs <= zoom_limit)[0]
        return (xs[keep],) + tuple(s[keep] for s in series)

    def draw_panel(ax, title, zoomed):
        xs, avg, top, bot = crop(*main_curve) if zoomed else main_curve

        # Main curve: no explicit color, so it takes the axes' default.
        ax.fill_between(xs, bot, top, alpha=0.5)
        ax.plot(xs, avg, 'o-', label=main_label)

        ax.set_yscale('log')
        ax.set_ylabel('Test classification error')
        ax.set_xlabel('n_train_post')
        ax.set_title(title)

        for curve, color, label in baselines:
            if curve[1] is None:
                continue
            b_xs, b_avg, b_top, b_bot = crop(*curve) if zoomed else curve
            ax.fill_between(b_xs, b_bot, b_top, alpha=0.5, color=color)
            ax.plot(b_xs, b_avg, 'o-', c=color, label=label)

        # Constant reference levels are drawn over the main curve's x range.
        if zero_avg_loss is not None:
            flat = np.ones(len(xs))
            ax.fill_between(xs, zero_bot_10_loss * flat, zero_top_10_loss * flat,
                            alpha=0.5, color='yellow')
            ax.plot(xs, flat * zero_avg_loss, '-', c='yellow', label='Zero')

        # The plotted level is the label-flip probability itself, so both
        # panels label it as a probability.
        if noise_prob is not None and noise_prob != 0:
            ax.plot(xs, np.ones_like(xs) * (noise_prob), '--', c='black',
                    label='Noise probability')

        ax.legend()

    draw_panel(ax1, 'Test classification error vs n_train_post', zoomed=False)
    draw_panel(ax2, 'Test classification error vs n_train_post (zoomed)', zoomed=True)

    plt.show()

## MAML for Regression with Closed-Form Min-Norm Solution

In this homework, we aim to learn a suitable set of feature weights for regression tasks drawn from a task distribution $D_T$. This distribution is defined by several entries in `params_dict`. By default, the true feature indices, given by `k_idx`, are $\{5,6,7,8,9,10,11\}$. The true coefficients for these features are drawn i.i.d. from the uniform distribution $U[-1, 1]$ and then normalized so that the coefficient vector has unit Euclidean norm. The total number of features is set by the parameter `d`.

For the inner training loop, the spacing of the $x$ samples is controlled by the `x_type` parameter. The meta-update, however, always uses uniformly random samples to avoid differentiation issues with aliased features. In this exercise we use uniformly random samples throughout: this matches the meta-update, and there is no specific reason to use a different spacing in the inner training loop.

An essential distinction between the original MAML paper and our implementation in this notebook is that *we employ the closed-form min-norm solution for regression* rather than gradient descent. Luckily, PyTorch allows us to backpropagate gradients through matrix inversion, enabling us to update the feature weights using the min-norm least squares solution instead of gradient descent steps. In later sections, we will utilize gradient descent for the inner loop.

In [None]:
def closed_form_ls(F_t, y_t):
    """Return the min-norm least squares solution ``F^T (F F^T)^{-1} y``.

    Args:
        F_t: 2-D feature tensor with one row per sample.
        y_t: Target tensor with one entry per sample.

    Returns:
        Coefficient tensor with one entry per feature column of ``F_t``.
    """
    gram = F_t @ F_t.T
    gram_inv = torch.inverse(gram)
    return F_t.T @ gram_inv @ y_t


def meta_update_reg(w_t, feature_weights_t, n_train_meta, phi_type, d, k_idx, k_val, noise_std):
    """Compute the meta loss of coefficients ``w_t`` on freshly sampled data.

    Draws ``n_train_meta`` uniformly random sample points, builds noisy
    targets from the true coefficients ``k_val`` on features ``k_idx``, and
    returns the summed squared error of the predictions ``features @ w_t``.

    Args:
        w_t: Coefficient tensor applied to the unweighted features.
        feature_weights_t: Not used in the body.
        n_train_meta: Number of meta-update samples to draw.
        phi_type: Feature family name passed to ``featurize``.
        d: Number of features.
        k_idx: Indices of the true features.
        k_val: True coefficients for the features in ``k_idx``.
        noise_std: Standard deviation of the Gaussian noise added to targets.

    Returns:
        Scalar torch tensor holding the summed squared error.
    """
    x_meta = generate_x(n_train_meta, 'uniform_random')
    features_meta = featurize(x_meta, phi_type=phi_type, d=d, normalize=True)

    targets = generate_y(features_meta, k_idx, k_val)
    targets += np.random.normal(0, noise_std, targets.shape)

    targets_t = torch.tensor(targets)
    predictions_t = torch.tensor(features_meta) @ w_t

    sum_squared_error = torch.nn.MSELoss(reduction='sum')
    return sum_squared_error(targets_t, predictions_t)

def get_post_data_reg(x_type, phi_type, k_idx, num_tasks_test, d, n_train_post_range, n_test, noise_std):
    """Generate the post-training evaluation data for ``num_tasks_test`` tasks.

    Each task has its own unit-norm coefficient vector on the features
    ``k_idx``. All tasks share one noiseless test set and, for every size in
    ``n_train_post_range``, one set of training locations with per-task noisy
    targets.

    Args:
        x_type: Sample spacing for the training locations (see ``generate_x``).
        phi_type: Feature family name passed to ``featurize``.
        k_idx: Indices of the true features.
        num_tasks_test: Number of tasks to generate.
        d: Number of features.
        n_train_post_range: Iterable of training-set sizes.
        n_test: Number of test samples (drawn uniformly at random).
        noise_std: Standard deviation of the Gaussian noise added to the
            training targets.

    Returns:
        ``{'train': {n: {'x', 'features', 'y'}}, 'test': {'features', 'x', 'y'}}``
        where each ``'y'`` is a list with one target array per task.
    """
    # One coefficient column per task, each normalized to unit length.
    task_coeffs = np.random.uniform(-1, 1, size=(len(k_idx), num_tasks_test))
    task_coeffs /= np.linalg.norm(task_coeffs, axis=0)

    x_test = generate_x(n_test, 'uniform_random')
    features_test = featurize(x_test, phi_type=phi_type, d=d, normalize=True)
    test_split = {
        'features': features_test,
        'x': x_test,
        # Test targets are noiseless.
        'y': [generate_y(features_test, k_idx, task_coeffs[:, task])
              for task in range(num_tasks_test)],
    }

    train_splits = {}
    for n_train_post in n_train_post_range:
        x_train = generate_x(n_train_post, x_type)
        features_train = featurize(x_train, phi_type=phi_type, d=d, normalize=True)

        noisy_targets = []
        for task in range(num_tasks_test):
            y_train = generate_y(features_train, k_idx, task_coeffs[:, task])
            y_train += np.random.normal(0, noise_std, y_train.shape)
            noisy_targets.append(y_train)

        train_splits[n_train_post] = {
            'x': x_train,
            'features': features_train,
            'y': noisy_targets,
        }

    return {'train': train_splits, 'test': test_split}


def test_oracle_reg(data_dict, x_type, feature_weights, min_n_train_post=0):
    """Evaluate an oracle that fits only the features with nonzero weight.

    For every training-set size in ``data_dict['train']`` (at least
    ``min_n_train_post``) and every task, an unweighted regression is fit on
    the columns where ``feature_weights`` is nonzero, and the mean squared
    error on the test set is recorded.

    Args:
        data_dict: Data as produced by ``get_post_data_reg``.
        x_type: Training sample spacing; ``'grid'`` switches the solver to
            ridge regression with a tiny regularizer.
        feature_weights: Array whose nonzero entries mark the features the
            oracle may use. Only the support is used, not the values.
        min_n_train_post: Smallest training-set size to evaluate.

    Returns:
        ``(n_train_post_range, avg_test_loss, top_10_loss, bot_10_loss)``:
        the evaluated sizes and, per size, the mean, 90th percentile and
        10th percentile of the test loss across tasks.
    """
    k_idx = np.where(feature_weights != 0)[0]

    n_train_post_range = np.sort(np.array(list(data_dict['train'].keys())))
    n_train_post_range = n_train_post_range[n_train_post_range >= min_n_train_post]

    test_data_dict = data_dict['test']
    features_test_support = test_data_dict['features'][:, k_idx]

    test_loss_matrix = []
    for n_train_post in n_train_post_range:
        train_data_dict = data_dict['train'][n_train_post]
        features_post = train_data_dict['features'][:, k_idx]

        # Drop support columns that are numerically zero on this training set.
        keep = np.where(np.linalg.norm(features_post, axis=0) > 1e-6)[0]
        features_post = features_post[:, keep]
        features_test_post = features_test_support[:, keep]

        test_loss_array = []
        for i, y_post in enumerate(train_data_dict['y']):
            if x_type == 'grid':
                # Use ridge with a small regularizer to avoid the effects of
                # poor conditioning on grid-spaced samples.
                w_post, _ = solve_ridge(features_post, y_post, lambda_ridge=1e-12, weights=None)
            else:
                w_post, _ = solve_ls(features_post, y_post, weights=None)

            y_test_post_pred = features_test_post @ w_post
            test_loss = np.mean((test_data_dict['y'][i] - y_test_post_pred)**2)
            test_loss_array.append(test_loss)

        test_loss_matrix.append(test_loss_array)

    # Rows: tasks, columns: training-set sizes.
    test_loss_matrix = np.array(test_loss_matrix).T
    avg_test_loss = np.mean(test_loss_matrix, 0)
    top_10_loss = np.percentile(test_loss_matrix, 90, axis=0)
    bot_10_loss = np.percentile(test_loss_matrix, 10, axis=0)

    return n_train_post_range, avg_test_loss, top_10_loss, bot_10_loss


def test_reg(data_dict, feature_weights, min_n_train_post=0):
    """Evaluate weighted least squares with the given feature weights.

    For every training-set size in ``data_dict['train']`` (at least
    ``min_n_train_post``) and every task, ``solve_ls`` is fit with
    ``feature_weights`` and the mean squared error on the test set is
    recorded.

    Args:
        data_dict: Data as produced by ``get_post_data_reg``.
        feature_weights: Per-feature weights passed to ``solve_ls``.
        min_n_train_post: Smallest training-set size to evaluate.

    Returns:
        ``(n_train_post_range, avg_test_loss, top_10_loss, bot_10_loss)``:
        the evaluated sizes and, per size, the mean, 90th percentile and
        10th percentile of the test loss across tasks.
    """
    n_train_post_range = np.sort(np.array(list(data_dict['train'].keys())))
    n_train_post_range = n_train_post_range[n_train_post_range >= min_n_train_post]

    test_data_dict = data_dict['test']
    features_test_post = test_data_dict['features']

    test_loss_matrix = []
    for n_train_post in n_train_post_range:
        train_data_dict = data_dict['train'][n_train_post]
        features_post = train_data_dict['features']

        test_loss_array = []
        for i, y_post in enumerate(train_data_dict['y']):
            w_post, _ = solve_ls(features_post, y_post, feature_weights)

            y_test_post_pred = features_test_post @ w_post
            test_loss = np.mean((test_data_dict['y'][i] - y_test_post_pred)**2)
            test_loss_array.append(test_loss)

        test_loss_matrix.append(test_loss_array)

    # Rows: tasks, columns: training-set sizes.
    test_loss_matrix = np.array(test_loss_matrix).T
    avg_test_loss = np.mean(test_loss_matrix, 0)
    top_10_loss = np.percentile(test_loss_matrix, 90, axis=0)
    bot_10_loss = np.percentile(test_loss_matrix, 10, axis=0)

    return n_train_post_range, avg_test_loss, top_10_loss, bot_10_loss


def test_zero_reg(data_dict):
    """Compute the test loss of the all-zero predictor.

    Predicting zero everywhere gives a per-task test loss equal to the mean
    of the squared test targets.

    Args:
        data_dict: Dictionary whose ``['test']['y']`` entry is a list of
            per-task test target arrays.

    Returns:
        ``(avg_test_loss, top_10_loss, bot_10_loss)``: the mean, 90th
        percentile and 10th percentile of the per-task losses.
    """
    per_task_loss = [np.mean(y**2) for y in data_dict['test']['y']]

    return (
        np.mean(per_task_loss),
        np.percentile(per_task_loss, 90),
        np.percentile(per_task_loss, 10),
    )


def meta_learning_reg_closed_form(params_dict):
    """Meta-learn feature weights for regression with a closed-form inner solver.

    Each meta iteration draws one set of training locations, solves
    ``num_inner_tasks`` regression tasks in closed form as a function of the
    feature weights, accumulates the gradient of each task's meta loss, and
    then takes one optimizer step on the feature weights. Test-loss and
    prediction plots are shown at iteration 0 and roughly ``num_stats`` times
    during training (always including the last iteration).

    Args:
        params_dict: Dictionary with the keys ``seed``, ``n_train_inner``,
            ``n_train_meta``, ``n_test_post``, ``x_type``, ``d``,
            ``phi_type``, ``k_idx``, ``optimizer_type`` (``'SGD'`` or
            ``'Adam'``), ``stepsize_meta``, ``num_inner_tasks``,
            ``num_tasks_test``, ``num_stats``, ``num_iterations``,
            ``num_n_train_post_range`` and optionally ``noise_std``
            (default 0).

    Returns:
        ``(feature_weights, data_dict)``: the learned feature weights as a
        NumPy array and the post-training evaluation data.

    Raises:
        ValueError: If ``optimizer_type`` is neither ``'SGD'`` nor ``'Adam'``.
    """
    seed = params_dict["seed"]
    n_train_inner = params_dict["n_train_inner"]
    n_train_meta = params_dict["n_train_meta"]
    n_test_post = params_dict["n_test_post"]
    x_type = params_dict["x_type"]
    d = params_dict["d"]
    phi_type = params_dict["phi_type"]
    k_idx = params_dict["k_idx"]
    optimizer_type = params_dict["optimizer_type"]
    stepsize_meta = params_dict["stepsize_meta"]
    num_inner_tasks = params_dict["num_inner_tasks"]
    num_tasks_test = params_dict["num_tasks_test"]
    num_stats = params_dict["num_stats"]
    num_iterations = params_dict["num_iterations"]
    noise_std = params_dict.get('noise_std', 0)
    num_n_train_post_range = params_dict['num_n_train_post_range']

    # Set seed:
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Report every `stats_every` iterations. Clamp to at least 1 so that
    # num_stats == 0 or num_stats > num_iterations cannot cause a division
    # or modulo by zero.
    stats_every = max(1, num_iterations // max(1, num_stats))
    init_n_train_post_range, init_avg_test_loss, init_top_10_loss, init_bot_10_loss = None, None, None, None

    # Initialize meta parameter: weights on the d features
    feature_weights_t = torch.tensor(np.ones(d), requires_grad=True)

    # Define meta parameter optimizer
    if optimizer_type == 'SGD':
        opt_meta = torch.optim.SGD([feature_weights_t], lr=stepsize_meta)
    elif optimizer_type == 'Adam':
        opt_meta = torch.optim.Adam([feature_weights_t], lr=stepsize_meta)
    else:
        raise ValueError(
            f"Unknown optimizer_type {optimizer_type!r}; expected 'SGD' or 'Adam'")

    # Post-training sizes to evaluate: log-spaced from 1 to 3*d, plus the
    # inner training size and the sizes around the number of true features.
    n_train_post_range = np.logspace(np.log10(1), np.log10(3 * d), num_n_train_post_range).astype('int')

    must_have_points = [n_train_inner, len(k_idx) - 1, len(k_idx), len(k_idx) + 1]

    for point in must_have_points:
        if point not in n_train_post_range:
            n_train_post_range = np.hstack([n_train_post_range, point])

    n_train_post_range = np.sort(n_train_post_range)
    n_train_post_range = np.unique(n_train_post_range)
    data_dict = get_post_data_reg(x_type, phi_type, k_idx, num_tasks_test, d, n_train_post_range, n_test_post, noise_std)

    # Meta training loop
    for i in range(num_iterations):
        opt_meta.zero_grad()

        # Get x and features
        x = generate_x(n_train_inner, x_type)
        features = featurize(x, phi_type=phi_type, d=d, normalize=True)
        weighted_features_t = to_tensor(features) * feature_weights_t

        # Loop over inner tasks
        for t in range(num_inner_tasks):
            # Get random coefficients
            k_val = np.random.uniform(-1, 1, size=len(k_idx))
            k_val /= np.linalg.norm(k_val)

            # Generate y
            y = generate_y(features, k_idx, k_val)
            y += np.random.normal(0, noise_std, y.shape)
            y_t = torch.tensor(y)

            # Get closed form solution for w_t as a function of feature_weights_t
            w_t = closed_form_ls(weighted_features_t, y_t)
            # Reweight the coefficients so that we can multiply with
            # unweighted features to get the prediction.
            w_t = w_t * feature_weights_t

            # Meta update: gradients accumulate across tasks. The graph is
            # retained because weighted_features_t is built once per
            # iteration and reused by every inner task.
            meta_loss = meta_update_reg(w_t, feature_weights_t, n_train_meta, phi_type, d, k_idx, k_val, noise_std)
            meta_loss.backward(retain_graph=True)

        if i == 0:
            print("-" * 70)
            print("Iteration: ", i)
            # Oracle stats: the oracle knows exactly which features are active.
            oracle_feature_weights = np.zeros(d)
            oracle_feature_weights[k_idx] = 1
            oracle_n_train_post_range, oracle_avg_test_loss, oracle_top_10_loss, oracle_bot_10_loss = test_oracle_reg(
                data_dict, feature_weights=oracle_feature_weights, x_type=x_type)

            zero_avg_loss, zero_top_10_loss, zero_bot_10_loss = test_zero_reg(data_dict)

            # Baseline with the initial (not yet updated) feature weights.
            feature_weights = to_numpy(feature_weights_t)
            init_n_train_post_range, init_avg_test_loss, init_top_10_loss, init_bot_10_loss = test_reg(data_dict, feature_weights)

            visualize_test_loss_reg(
                0, n_train_inner, init_n_train_post_range, init_avg_test_loss, init_top_10_loss, init_bot_10_loss,
                oracle_n_train_post_range=oracle_n_train_post_range, oracle_avg_test_loss=oracle_avg_test_loss,
                oracle_top_10_loss=oracle_top_10_loss, oracle_bot_10_loss=oracle_bot_10_loss, zero_avg_loss=zero_avg_loss,
                zero_top_10_loss=zero_top_10_loss, zero_bot_10_loss=zero_bot_10_loss, noise_std=noise_std)

            visualize_prediction_reg(data_dict, feature_weights, n_train_inner)

        # Stats
        if (i + 1) % stats_every == 0 or i == num_iterations - 1:
            print("-" * 70)
            print("Iteration: ", i + 1)
            feature_weights = to_numpy(feature_weights_t)

            n_train_post_range, avg_test_loss, top_10_loss, bot_10_loss = test_reg(data_dict, feature_weights)
            visualize_test_loss_reg(
                i, n_train_inner, n_train_post_range, avg_test_loss, top_10_loss, bot_10_loss, init_n_train_post_range, init_avg_test_loss,
                init_top_10_loss, init_bot_10_loss, oracle_n_train_post_range=oracle_n_train_post_range, oracle_avg_test_loss=oracle_avg_test_loss,
                oracle_top_10_loss=oracle_top_10_loss, oracle_bot_10_loss=oracle_bot_10_loss, zero_avg_loss=zero_avg_loss,
                zero_top_10_loss=zero_top_10_loss, zero_bot_10_loss=zero_bot_10_loss, noise_std=noise_std)
            visualize_prediction_reg(data_dict, feature_weights, n_train_inner)

        opt_meta.step()  # Finally update meta weights after loop through all tasks

    return to_numpy(feature_weights_t), data_dict


def meta_learning_reg_sgd(params_dict):
    """Meta-learn regression feature weights with a gradient-descent inner loop.

    In every meta iteration a fresh set of inner training inputs is drawn and
    ``num_inner_tasks`` random tasks are generated on them.  For each task the
    coefficients are reset to zero and trained for ``num_gd_steps``
    differentiable SGD steps (through ``higher``); the resulting meta loss is
    back-propagated so that gradients accumulate on ``dm.feature_weights``.
    One meta-optimizer step is taken per iteration.  Test-loss curves and
    predictions are visualized at iteration 0 and then every ``stats_every``
    iterations (and at the last iteration).

    Args:
        params_dict: Configuration dictionary.  Required keys: ``seed``,
            ``n_train_inner``, ``n_train_meta``, ``n_test_post``, ``x_type``,
            ``d``, ``phi_type``, ``k_idx``, ``optimizer_type``,
            ``stepsize_meta``, ``num_inner_tasks``, ``num_tasks_test``,
            ``num_stats``, ``num_iterations``, ``num_n_train_post_range``,
            ``stepsize_inner`` and ``num_gd_steps``.  Optional key:
            ``noise_std`` (defaults to 0).

    Returns:
        None.  Progress is reported through prints and the visualization
        helpers.

    Raises:
        ValueError: If ``optimizer_type`` is neither ``'SGD'`` nor ``'Adam'``.
    """
    seed = params_dict["seed"]
    n_train_inner = params_dict["n_train_inner"]
    n_train_meta = params_dict["n_train_meta"]
    n_test_post = params_dict["n_test_post"]
    x_type = params_dict["x_type"]
    d = params_dict["d"]
    phi_type = params_dict["phi_type"]
    k_idx = params_dict["k_idx"]
    optimizer_type = params_dict["optimizer_type"]
    stepsize_meta = params_dict["stepsize_meta"]
    num_inner_tasks = params_dict["num_inner_tasks"]
    num_tasks_test = params_dict["num_tasks_test"]
    num_stats = params_dict["num_stats"]
    num_iterations = params_dict["num_iterations"]
    noise_std = params_dict.get('noise_std', 0)
    num_n_train_post_range = params_dict['num_n_train_post_range']

    stepsize_inner = params_dict["stepsize_inner"]
    num_gd_steps = params_dict['num_gd_steps']

    # Set seed:
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Reporting period.  Clamped to at least 1 so that num_stats == 0 or
    # num_stats > num_iterations cannot cause a division/modulo by zero below.
    stats_every = max(1, num_iterations // max(1, num_stats))
    init_n_train_post_range, init_avg_test_loss, init_top_10_loss, init_bot_10_loss = None, None, None, None

    # Model holding the meta parameter (feature_weights) and per-task coeffs
    dm = DummyModel(d)

    # Define meta parameter optimizer
    if optimizer_type == 'SGD':
        opt_meta = torch.optim.SGD([dm.feature_weights], lr=stepsize_meta)
    elif optimizer_type == 'Adam':
        opt_meta = torch.optim.Adam([dm.feature_weights], lr=stepsize_meta)
    else:
        raise ValueError(
            "optimizer_type must be 'SGD' or 'Adam', got {!r}".format(optimizer_type))

    # Post-training sizes at which the test loss is evaluated: log-spaced
    # between 1 and 3 * d, plus a few sizes that must always be present.
    n_train_post_range = np.logspace(np.log10(1), np.log10(3 * d), num_n_train_post_range).astype('int')

    must_have_points = [n_train_inner, len(k_idx) - 1, len(k_idx), len(k_idx) + 1]

    for point in must_have_points:
        if point not in n_train_post_range:
            n_train_post_range = np.hstack([n_train_post_range, point])

    n_train_post_range = np.sort(n_train_post_range)
    n_train_post_range = np.unique(n_train_post_range)
    data_dict = get_post_data_reg(x_type, phi_type, k_idx, num_tasks_test, d, n_train_post_range, n_test_post, noise_std)

    # Meta training loop
    for i in range(num_iterations):
        opt_meta.zero_grad()

        # Get x and features (shared by all inner tasks of this iteration)
        x = generate_x(n_train_inner, x_type)
        features = featurize(x, phi_type=phi_type, d=d, normalize=True)
        features_t = to_tensor(features)

        # Loop over inner tasks
        for _ in range(num_inner_tasks):
            # Get random unit-norm coefficients for the true signal
            k_val = np.random.uniform(-1, 1, size=len(k_idx))
            k_val /= np.linalg.norm(k_val)

            # Every task starts its inner training from zero coefficients
            dm.coeffs = torch.nn.Parameter(torch.zeros(d).double())

            # Generate noisy y
            y = generate_y(features, k_idx, k_val)
            y += np.random.normal(0, noise_std, y.shape)
            y_t = torch.tensor(y)

            opt_inner = torch.optim.SGD([dm.coeffs], lr=stepsize_inner)
            with higher.innerloop_ctx(dm, opt_inner, copy_initial_weights=False, track_higher_grads=True) as (mod, opt):

                for _ in range(num_gd_steps):
                    y_pred = mod(features_t)
                    loss = torch.mean((y_pred - y_t)**2)
                    # The differentiable optimizer takes the loss directly
                    opt.step(loss)

                # Meta update: accumulate this task's meta gradient
                meta_loss = meta_update_reg(mod.coeffs * mod.feature_weights, mod.feature_weights, n_train_meta, phi_type, d, k_idx, k_val, noise_std)
                meta_loss.backward(retain_graph=True)

        if i == 0:
            print("-" * 70)
            print("Iteration: ", i)
            # Oracle stats: regression restricted to the true features
            oracle_feature_weights = np.zeros(d)
            oracle_feature_weights[k_idx] = 1
            oracle_n_train_post_range, oracle_avg_test_loss, oracle_top_10_loss, oracle_bot_10_loss = test_oracle_reg(
                data_dict, feature_weights=oracle_feature_weights, x_type=x_type)

            zero_avg_loss, zero_top_10_loss, zero_bot_10_loss = test_zero_reg(data_dict)

            # Baseline curve for the initial (not yet updated) feature weights
            feature_weights = to_numpy(dm.feature_weights)
            init_n_train_post_range, init_avg_test_loss, init_top_10_loss, init_bot_10_loss = test_reg(data_dict, feature_weights)

            visualize_test_loss_reg(
                0, n_train_inner, init_n_train_post_range, init_avg_test_loss, init_top_10_loss, init_bot_10_loss,
                oracle_n_train_post_range=oracle_n_train_post_range, oracle_avg_test_loss=oracle_avg_test_loss,
                oracle_top_10_loss=oracle_top_10_loss, oracle_bot_10_loss=oracle_bot_10_loss, zero_avg_loss=zero_avg_loss,
                zero_top_10_loss=zero_top_10_loss, zero_bot_10_loss=zero_bot_10_loss, noise_std=noise_std)

            visualize_prediction_reg(data_dict, feature_weights, n_train_inner)

        # Stats (evaluated before this iteration's meta step is applied)
        if (i + 1) % stats_every == 0 or i == num_iterations - 1:
            print("-" * 70)
            print("Iteration: ", i + 1)
            feature_weights = to_numpy(dm.feature_weights)

            n_train_post_range, avg_test_loss, top_10_loss, bot_10_loss = test_reg(data_dict, feature_weights)
            visualize_test_loss_reg(
                i, n_train_inner, n_train_post_range, avg_test_loss, top_10_loss, bot_10_loss, init_n_train_post_range, init_avg_test_loss,
                init_top_10_loss, init_bot_10_loss, oracle_n_train_post_range=oracle_n_train_post_range, oracle_avg_test_loss=oracle_avg_test_loss,
                oracle_top_10_loss=oracle_top_10_loss, oracle_bot_10_loss=oracle_bot_10_loss, zero_avg_loss=zero_avg_loss,
                zero_top_10_loss=zero_top_10_loss, zero_bot_10_loss=zero_bot_10_loss, noise_std=noise_std)
            visualize_prediction_reg(data_dict, feature_weights, n_train_inner)

        opt_meta.step()  # Finally update meta weights after loop through all tasks

In [None]:
def get_params_dict_reg():
    """Return the default hyperparameters for closed-form regression meta-learning.

    ``'x_type'`` is intentionally not set here; the caller chooses the
    sampling scheme (e.g. ``'grid'`` or ``'uniform_random'``) before training.
    """
    sample_sizes = {
        'seed': 7,
        'n_train_inner': 32,   # training samples for each inner task
        'n_train_meta': 64,    # training samples for updating the meta parameter
        'n_train_post': 32,    # training samples used after meta training to learn the weights
        'n_test_post': 1000,   # samples for plotting / evaluating test performance after meta training
    }
    features_and_optimizer = {
        'd': 501,                     # number of features
        'phi_type': 'fourier',        # feature type
        'noise_std': 1e-1,            # standard deviation of AWGN noise added during training
        'optimizer_type': 'SGD',      # can be either SGD or Adam
        'k_idx': np.arange(5, 11),    # frequency range present in tasks during meta training
        'stepsize_meta': 1e-2,        # stepsize used for meta updates
    }
    schedule = {
        'num_inner_tasks': 5,           # inner tasks for each meta update
        'num_tasks_test': 10,           # tasks to test on
        'num_stats': 10,                # determines how often we collect stats
        'num_iterations': 100,          # iterations for training the meta parameter
        'num_n_train_post_range': 40,   # points used to generate the test loss vs n_train_post curve
    }

    params_dict = {}
    for group in (sample_sizes, features_and_optimizer, schedule):
        params_dict.update(group)
    return params_dict

### Training Points on a Grid

In this subsection, we use grid-spaced datapoints in the inner loop of training, which is the same spacing rule you explored in the theory parts of this problem. Meanwhile, the meta-update and test data spacings will continue to use uniformly random samples. As a result, during the inner training loop, the features within each alias group will be identical, but each feature will be unique during the meta-update and when computing the test error.

You should observe that the feature weights exhibit behavior similar to the limits you derived in the analytical parts, where the true features are favored (i.e., they have higher weights). However, there may be a noticeable difference in how certain other feature weights behave.

**Run the next two cells then answer questions:**

In [None]:

# Inner-loop training points are evenly spaced on a grid for this run.
x_type = 'grid'
params_dict = get_params_dict_reg()
# Work on a shallow copy so the defaults in params_dict stay untouched.
cparams_dict = {**params_dict, 'x_type': x_type}
_ = meta_learning_reg_closed_form(cparams_dict)


For each logged iteration, we generate visualizations, consisting of two rows with two subfigures each, totaling four subfigures.

- First Row:

  - The left figure presents the test Mean Squared Error (MSE) loss with respect to the number of datapoints used for linear regression after meta-training, plotted on a log scale.
    - The green curve represents the oracle test loss using only the features present in the true signal.
    - The blue curve displays the test loss using the feature weights learned through meta-training.
    - The orange curve marks the initial location (iteration-0) of the blue curve.
    - For each curve, the solid line corresponds to the average test loss over 10 tasks, and the shaded band indicates the range between the 10th and 90th percentile.
    - The yellow line shows a baseline case where we predict zero for each datapoint.
    - The dashed line represents the noise variance used when generating the data.

  - The right figure is a zoomed-in version of the left figure, focusing on specific details.

- Second Row:

  - The left figure compares the true function (orange) to the predicted function (blue) for one particular task.
    - The red dots represent the training points.
    - The green crosses indicate the predictions on these training points. Note that these coincide since we are in the overparameterized regime and can interpolate the training data.

  - The right figure illustrates the learned feature weights as meta-training progresses. Observe how all 501 features (`d = 501`) were initially equally weighted with a value of 1.

#### Question

Considering the plot of regression test loss versus `n_train_post`, **how does the performance of the meta-learned feature weights compare to the case where all feature weights are set to 1?** Additionally, **how does their performance compare to the oracle**, which performs regression using only the features present in the data? Can you **explain the reason for the downward spike observed at `n_train_post = 32`?** Include the answer in the written submission of your assignment.

#### Question

By examining the changes in feature weights over time during meta-learning, **can you justify the observed improvement in performance?** Specifically, can you **explain why certain feature weights are driven towards zero**? Include the answer in the written submission of your assignment.

## MAML for Regression with Gradient Descent

In the previous sections, we demonstrated how to employ the closed-form min-norm least squares solution for training the meta-learning parameter (feature weights) in MAML. However, for many problems, we may not have access to closed-form solutions for the tasks we aim to solve. In such cases, we need to rely on iterative methods like gradient descent.

For the regression task, we can perform gradient descent on the squared loss. However, it is crucial to preserve gradients with respect to the feature weights when calculating the coefficients during inner training. PyTorch enables us to achieve this using the `higher` library.

Please note that in these experiments, we employ gradient descent in the inner loop for `num_gd_steps`. However, when testing our performance, we utilize the closed-form expression for the min-norm least squares solution. This is because, during the final performance evaluation, we must either execute enough iterations of gradient descent to approach the closed-form solution or directly use the closed-form solution. Interestingly, we will observe that even a single gradient descent step towards the solution during meta-training helps us learn the feature weights effectively.

In [None]:
def get_params_dict_reg_sgd():
    """Return the default hyperparameters for regression meta-learning with a GD inner loop."""
    sample_sizes = {
        'seed': 7,
        'n_train_inner': 32,   # training samples for each inner task
        'n_train_meta': 64,    # training samples for updating the meta parameter
        'n_train_post': 32,    # training samples used after meta training to learn the weights
        'n_test_post': 1000,   # samples for plotting / evaluating test performance after meta training
        'x_type': 'uniform_random',   # sampling of inner and post training inputs ('grid' is the alternative)
    }
    features_and_optimizer = {
        'd': 501,                     # number of features
        'phi_type': 'fourier',        # feature type
        'noise_std': 1e-1,            # standard deviation of AWGN noise added during training
        'optimizer_type': 'SGD',      # optimizer for meta updates; can be either SGD or Adam
        'k_idx': np.arange(5, 11),    # frequency range present in tasks during meta training
        'stepsize_meta': 1e-2,        # stepsize used for meta updates
    }
    schedule = {
        'num_inner_tasks': 5,           # inner tasks for each meta update
        'num_tasks_test': 10,           # tasks to test on
        'num_stats': 10,                # determines how often we collect stats
        'num_iterations': 100,          # iterations for training the meta parameter
        'num_n_train_post_range': 40,   # points used to generate the test loss vs n_train_post curve
    }
    inner_loop = {
        'stepsize_inner': 1e-2,   # stepsize for the GD update in inner tasks
        'num_gd_steps': 5,        # GD steps in an inner task towards the min-norm LS solution
    }

    params_dict = {}
    for group in (sample_sizes, features_and_optimizer, schedule, inner_loop):
        params_dict.update(group)
    return params_dict



In [None]:
# Meta-train with five inner gradient-descent steps per task.
num_gd_steps = 5
params_dict = get_params_dict_reg_sgd()
# Shallow copy with the inner-step count overridden; defaults stay untouched.
cparams_dict = {**params_dict, 'num_gd_steps': num_gd_steps}

meta_learning_reg_sgd(cparams_dict)

In [None]:
# Meta-train with a single inner gradient-descent step per task.
num_gd_steps = 1
params_dict = get_params_dict_reg_sgd()
# Shallow copy with the inner-step count overridden; defaults stay untouched.
cparams_dict = {**params_dict, 'num_gd_steps': num_gd_steps}

meta_learning_reg_sgd(cparams_dict)

#### Question

**With `num_gd_steps = 5`, does meta-learning contribute to improved performance during test time? Furthermore, if we change this to `num_gd_steps = 1`, does meta-learning continue to function effectively?** Include the answer in the written submission of your assignment.

## MAML for Classification


Suppose we want to learn an effective set of feature weights for performing classification tasks. We will use the same setup as the regression problem, but now our training data consists of pairs $(x_i, z_i)$, where $z_i = \text{sgn}(f(x_i))$. The function $f$ represents the underlying true function that is sampled from a distribution $D_T$ as before.

Given a set of feature weights and their corresponding weighted features $\phi_w$, we solve the logistic regression problem to learn a set of coefficients $\alpha$. Using these coefficients, we assign a label to a test point $x_{\text{test}}$ as $\hat{z}_{\text{test}} = \text{sgn}(\langle \alpha, \phi_w(x_{\text{test}}) \rangle)$. The test loss of interest is the classification error $\Pr[\hat{z}_{\text{test}} \neq z_{\text{test}}] = \mathbb{E}[\mathbb{1}\{\hat{z}_{\text{test}} \neq z_{\text{test}}\}]$, where the expectation is taken over the randomness in the test point. Since the test loss is not differentiable, we use the logistic loss during training (logistic regression).

Next, we describe the process of learning the coefficients $\alpha \in \mathbb{R}^d$. We have training data pairs $(x_i, z_i)$, where $x_i$ represents a data point, and $z_i$ denotes its label from the set $\{+1, -1\}$. For each point $x_i$, we have a set of weighted features $\phi_w(x_i)$.

Recall that a standard logistic function is defined as $$g(t) = \frac{1}{e^{-t}+1}$$ (refer to: https://en.wikipedia.org/wiki/Logistic_function), which possesses the useful property of $g(-t) = 1 - g(t)$. Consequently, we can use it to model the binary probability: given a value $t$ that might belong to two classes with labels $+1$ and $-1$, we can model the probability of $t$ being part of class $+1$ as $p_1(t) = g(t)$, and conveniently obtain $p_{-1}(t) = 1 - g(t) = g(-t)$. In our classification problem setup, we take $t = z_i\langle \alpha, \phi_w(x_i) \rangle$. Thus, $g(t)$ can be interpreted as the probability of "correct classification": for a correct prediction, $t = z_i\langle \alpha, \phi_w(x_i) \rangle$ is always positive, so $g(t) > 1/2$, and $g(t)$ approaches 1 as $t$ grows.

Therefore, to maximize the total log probability of predicting the correct label on all datapoints we find the $\alpha$ that maximizes
$$\sum_i \log{ \frac{1}{e^{-z_i\langle \alpha, \phi_w(x_i) \rangle}+1} }.$$ Flipping the sign, this is equivalent to finding the $\alpha$ that minimizes the loss
$$ - \sum_i \log{ \frac{1}{e^{-z_i\langle \alpha, \phi_w(x_i) \rangle}+1} },$$ as listed here: https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression

One way to find this $\alpha$ is by using gradient descent similar to what we did in the regression setting. In the regression setting we used the squared loss but in the classification setting we use the logistic loss.

**Complete the following code** and make sure it runs without any errors, though the meta learning is being done on regression tasks and test losses are being computed on the regression tasks.

In [None]:
def get_post_data_classification(x_type, phi_type, k_idx, num_tasks_test, d, n_train_post_range, n_test, noise_prob):
    """Build the fixed train/test data used to evaluate classification after meta-training.

    One unit-norm coefficient vector is drawn per test task.  The test inputs
    are sampled once with ``'uniform_random'`` and shared by all tasks; for
    every size in ``n_train_post_range`` one set of training inputs is sampled
    with ``x_type`` and shared by all tasks.  Label noise is applied to the
    training targets only.

    Returns:
        dict: ``{'train': {n: {'x', 'features', 'y'}}, 'test': {'features', 'x', 'y'}}``
        where every ``'y'`` is a list with one target array per test task.
    """
    # One column of true coefficients per test task, normalized column-wise.
    task_coeffs = np.random.uniform(-1, 1, size=(len(k_idx), num_tasks_test))
    task_coeffs = task_coeffs / np.linalg.norm(task_coeffs, axis=0)

    # Shared, noise-free test set.
    x_test = generate_x(n_test, 'uniform_random')
    features_test = featurize(x_test, phi_type=phi_type, d=d, normalize=True)
    test_split = {
        'features': features_test,
        'x': x_test,
        'y': [generate_y(features_test, k_idx, coeffs) for coeffs in task_coeffs.T],
    }

    # One training set per requested size, with noisy labels per task.
    train_split = {}
    for n_train_post in n_train_post_range:
        x_train_post = generate_x(n_train_post, x_type)
        features_post = featurize(x_train_post, phi_type=phi_type, d=d, normalize=True)
        train_split[n_train_post] = {
            'x': x_train_post,
            'features': features_post,
            'y': [
                add_label_noise(generate_y(features_post, k_idx, coeffs), noise_prob)
                for coeffs in task_coeffs.T
            ],
        }

    return {'train': train_split, 'test': test_split}

def meta_update_classification(w_t, feature_weights_t, n_train_meta, phi_type, d, k_idx, k_val, noise_prob):
    """Compute the meta loss of one classification task on freshly sampled data.

    Draws ``n_train_meta`` uniformly random inputs, featurizes them, generates
    the task's targets from ``k_idx``/``k_val``, applies label noise and
    evaluates the raw (pre-sign) predictions ``features @ w_t``.

    Args:
        w_t: Tensor of effective weights applied to the features; it must
            stay connected to the autograd graph for the meta gradient.
        feature_weights_t: Feature-weight tensor (not used in this body).
        n_train_meta: Number of meta-update samples to draw.
        phi_type: Feature type passed to ``featurize``.
        d: Number of features.
        k_idx: Indices of the features present in the true signal.
        k_val: Coefficients of those features for this task.
        noise_prob: Probability with which each target's sign is flipped.

    Returns:
        The value of ``loss``, which must be defined by the TODO section
        below (this skeleton does not define it).
    """
    x_meta = generate_x(n_train_meta, 'uniform_random')
    features_meta = featurize(x_meta, phi_type=phi_type, d=d, normalize=True)
    features_meta_t = torch.tensor(features_meta)
    y_meta = generate_y(features_meta, k_idx, k_val)
    # Flip the sign of each target with probability noise_prob
    y_meta = add_label_noise(y_meta, noise_prob)
    # y_meta += np.random.normal(0, noise_prob, y_meta.shape)

    y_meta_t = torch.tensor(y_meta)
    # Raw scores; their signs are the predicted labels
    y_meta_pred_t = features_meta_t @ w_t

    ############################################################################
    # TODO: Complete the following code
    #
    # Hint: In regression, we have written:
    #   criterion = torch.nn.MSELoss(reduction = 'sum')
    #   loss = criterion( y_meta_t, y_meta_pred_t)
    # Now for classification, we need a different loss, defined in the text block.
    # Also, use my_sign for the sign function, to avoid some stupid compatibility issues.
    ############################################################################
    ############################################################################

    return loss


def meta_learning_classification(params_dict):
    """Meta-learn classification feature weights with a gradient-descent inner loop.

    In every meta iteration a fresh set of inner training inputs is drawn and
    ``num_inner_tasks`` random tasks are generated on them.  For each task the
    coefficients are reset to zero and trained for ``num_gd_steps``
    differentiable SGD steps (through ``higher``); the resulting meta loss is
    back-propagated so that gradients accumulate on ``dm.feature_weights``.
    One meta-optimizer step is taken per iteration.  Classification-error
    curves and predictions are visualized at iteration 0 and then every
    ``stats_every`` iterations (and at the last iteration).

    The inner-loop loss is left as a TODO for the reader; until it is filled
    in, ``loss`` is undefined inside the inner loop.

    Args:
        params_dict: Configuration dictionary.  Required keys: ``seed``,
            ``n_train_inner``, ``n_train_meta``, ``n_test_post``, ``x_type``,
            ``d``, ``phi_type``, ``k_idx``, ``optimizer_type``,
            ``stepsize_meta``, ``num_inner_tasks``, ``num_tasks_test``,
            ``num_stats``, ``num_iterations``, ``noise_prob``,
            ``num_n_train_post_range``, ``stepsize_inner`` and
            ``num_gd_steps``.

    Returns:
        None.  Progress is reported through prints and the visualization
        helpers.

    Raises:
        ValueError: If ``optimizer_type`` is neither ``'SGD'`` nor ``'Adam'``.
    """
    seed = params_dict["seed"]
    n_train_inner = params_dict["n_train_inner"]
    n_train_meta = params_dict["n_train_meta"]
    n_test_post = params_dict["n_test_post"]
    x_type = params_dict["x_type"]
    d = params_dict["d"]
    phi_type = params_dict["phi_type"]
    k_idx = params_dict["k_idx"]
    optimizer_type = params_dict["optimizer_type"]
    stepsize_meta = params_dict["stepsize_meta"]
    num_inner_tasks = params_dict["num_inner_tasks"]
    num_tasks_test = params_dict["num_tasks_test"]
    num_stats = params_dict["num_stats"]
    num_iterations = params_dict["num_iterations"]
    noise_prob = params_dict['noise_prob']
    num_n_train_post_range = params_dict['num_n_train_post_range']

    stepsize_inner = params_dict["stepsize_inner"]
    num_gd_steps = params_dict['num_gd_steps']

    # Set seed:
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Reporting period.  Clamped to at least 1 so that num_stats == 0 or
    # num_stats > num_iterations cannot cause a division/modulo by zero below.
    stats_every = max(1, num_iterations // max(1, num_stats))
    init_n_train_post_range, init_avg_test_loss, init_top_10_loss, init_bot_10_loss = None, None, None, None

    # Model holding the meta parameter (feature_weights) and per-task coeffs
    dm = DummyModel(d)
    # Define meta parameter optimizer
    if optimizer_type == 'SGD':
        opt_meta = torch.optim.SGD([dm.feature_weights], lr=stepsize_meta)
    elif optimizer_type == 'Adam':
        opt_meta = torch.optim.Adam([dm.feature_weights], lr=stepsize_meta)
    else:
        raise ValueError(
            "optimizer_type must be 'SGD' or 'Adam', got {!r}".format(optimizer_type))

    # Post-training sizes at which the test error is evaluated: log-spaced
    # between 24 and 3 * d, plus the inner training size.
    n_train_post_range = np.logspace(np.log10(24), np.log10(3 * d), num_n_train_post_range).astype('int')

    must_have_points = [n_train_inner]
    for point in must_have_points:
        if point not in n_train_post_range:
            n_train_post_range = np.hstack([n_train_post_range, point])

    n_train_post_range = np.sort(n_train_post_range)
    n_train_post_range = np.unique(n_train_post_range)
    data_dict = get_post_data_classification(
        x_type, phi_type, k_idx, num_tasks_test, d, n_train_post_range, n_test_post, noise_prob)

    # Meta training loop
    for i in range(num_iterations):
        opt_meta.zero_grad()

        # Get x and features (shared by all inner tasks of this iteration)
        x = generate_x(n_train_inner, x_type)
        features = featurize(x, phi_type=phi_type, d=d, normalize=True)
        features_t = to_tensor(features)

        # Loop over inner tasks
        for _ in range(num_inner_tasks):
            # Get random unit-norm coefficients for the true signal
            k_val = np.random.uniform(-1, 1, size=len(k_idx))
            k_val /= np.linalg.norm(k_val)
            # Every task starts its inner training from zero coefficients
            dm.coeffs = torch.nn.Parameter(torch.zeros(d).double())

            # Generate y and flip each sign with probability noise_prob
            y = generate_y(features, k_idx, k_val)

            y = add_label_noise(y, noise_prob)
            y_t = torch.tensor(y)

            opt_inner = torch.optim.SGD([dm.coeffs], lr=stepsize_inner, weight_decay=1e-5)
            with higher.innerloop_ctx(dm, opt_inner, copy_initial_weights=False, track_higher_grads=True) as (mod, opt):

                for _ in range(num_gd_steps):
                    y_pred = mod(features_t)

                    ############################################################
                    # TODO: Complete the following code
                    #
                    # Hint: In regression, we have written:
                    #   loss = torch.mean((y_pred - y_t)**2)
                    # Now for classification, we need a different loss, defined in the text block.
                    # Also, use my_sign for the sign function, to avoid some stupid compatibility issues.
                    ############################################################
                    ############################################################

                    # NOTE(review): opt.step(loss) already receives the loss
                    # to differentiate; this explicit backward() presumably
                    # also accumulates inner-loss gradients on the leaf
                    # parameters' .grad -- confirm that this is intended.
                    opt_inner.zero_grad()
                    loss.backward(retain_graph=True)
                    opt.step(loss)

                # Meta update: accumulate this task's meta gradient
                meta_loss = meta_update_classification(
                    mod.coeffs * mod.feature_weights, mod.feature_weights, n_train_meta, phi_type, d, k_idx, k_val, noise_prob)
                meta_loss.backward(retain_graph=True)

        if i == 0:
            print("-" * 70)
            print("Iteration: ", i)

            # Oracle stats: classification restricted to the true features
            oracle_feature_weights = np.zeros(d)
            oracle_feature_weights[k_idx] = 1
            oracle_n_train_post_range, oracle_avg_test_loss, oracle_top_10_loss, oracle_bot_10_loss = test_classification_oracle(
                data_dict, feature_weights=oracle_feature_weights, x_type=x_type)

            # Random-guess baseline
            zero_avg_loss, zero_top_10_loss, zero_bot_10_loss = test_classification_zero(data_dict)

            # Baseline curve for the initial (not yet updated) feature weights
            feature_weights = to_numpy(dm.feature_weights)
            init_n_train_post_range, init_avg_test_loss, init_top_10_loss, init_bot_10_loss = test_classification(
                data_dict, feature_weights)

            visualize_test_loss_classification(
                0, n_train_inner, init_n_train_post_range, init_avg_test_loss, init_top_10_loss, init_bot_10_loss,
                oracle_n_train_post_range=oracle_n_train_post_range, oracle_avg_test_loss=oracle_avg_test_loss,
                oracle_top_10_loss=oracle_top_10_loss, oracle_bot_10_loss=oracle_bot_10_loss, zero_avg_loss=zero_avg_loss,
                zero_top_10_loss=zero_top_10_loss, zero_bot_10_loss=zero_bot_10_loss, noise_prob=noise_prob)

            visualize_prediction_classification(
                data_dict, feature_weights, n_train_inner)

        # Stats (evaluated before this iteration's meta step is applied)
        if (i + 1) % stats_every == 0 or i == num_iterations - 1:

            print("-" * 70)
            print("Iteration: ", i + 1)

            feature_weights = to_numpy(dm.feature_weights)

            n_train_post_range, avg_test_loss, top_10_loss, bot_10_loss = test_classification(
                data_dict, feature_weights)
            visualize_test_loss_classification(
                i, n_train_inner, n_train_post_range, avg_test_loss,
                top_10_loss, bot_10_loss, init_n_train_post_range, init_avg_test_loss,
                init_top_10_loss, init_bot_10_loss,
                oracle_n_train_post_range=oracle_n_train_post_range,
                oracle_avg_test_loss=oracle_avg_test_loss,
                oracle_top_10_loss=oracle_top_10_loss,
                oracle_bot_10_loss=oracle_bot_10_loss,
                zero_avg_loss=zero_avg_loss,
                zero_top_10_loss=zero_top_10_loss,
                zero_bot_10_loss=zero_bot_10_loss,
                noise_prob=noise_prob)
            visualize_prediction_classification(
                data_dict, feature_weights, n_train_inner)

        opt_meta.step()  # Finally update meta weights after loop through all tasks


def test_classification(data_dict, feature_weights, min_n_train_post=0):
    """Evaluate the test classification error for every post-training set size.

    For each training-set size in ``data_dict['train']`` (at least
    ``min_n_train_post``) and each task, a logistic model is fit with
    ``solve_logistic`` on the sign labels of the training targets, and the
    fraction of misclassified test points is recorded.

    Args:
        data_dict: Data with a ``'train'`` entry keyed by training-set size
            and a ``'test'`` entry; each holds ``'features'`` and a per-task
            list ``'y'``.
        feature_weights: Feature weights forwarded to ``solve_logistic``.
        min_n_train_post: Smallest training-set size to evaluate.

    Returns:
        tuple: ``(n_train_post_range, avg_test_loss, top_10_loss,
        bot_10_loss)`` where the last three hold, per training-set size, the
        mean, 90th percentile and 10th percentile of the error across tasks.
    """
    n_train_post_range = np.sort(np.array(list(data_dict['train'].keys())))
    n_train_post_range = n_train_post_range[n_train_post_range >= min_n_train_post]

    test_data_dict = data_dict['test']
    features_test_post = test_data_dict['features']
    # The true test labels do not depend on the training size: compute once.
    z_test_per_task = [np.squeeze(my_sign(y_test)) for y_test in test_data_dict['y']]

    test_loss_matrix = []
    for n_train_post in n_train_post_range:
        train_data_dict = data_dict['train'][n_train_post]
        features_post = train_data_dict['features']

        test_loss_array = []
        for y_post, z_test_post in zip(train_data_dict['y'], z_test_per_task):
            z_post = np.squeeze(my_sign(y_post))

            # Only the learned weights are needed; the training loss is not.
            w_post, _ = solve_logistic(features_post, z_post, feature_weights)

            z_test_post_pred = np.squeeze(my_sign(features_test_post @ w_post))

            # Classification error: fraction of mislabeled test points
            test_loss_array.append(np.mean(z_test_post != z_test_post_pred))

        test_loss_matrix.append(test_loss_array)

    # Rows: tasks, columns: training-set sizes
    test_loss_matrix = np.array(test_loss_matrix).T
    avg_test_loss = np.mean(test_loss_matrix, 0)

    top_10_loss = np.percentile(test_loss_matrix, 90, axis=0)
    bot_10_loss = np.percentile(test_loss_matrix, 10, axis=0)

    return n_train_post_range, avg_test_loss, top_10_loss, bot_10_loss


def test_classification_zero(data_dict):
    """Baseline classification error obtained by guessing every test label at random.

    This plays the role that "predict zero everywhere" (``np.mean(y**2)``)
    plays for regression.

    Returns:
        tuple: mean, 90th percentile and 10th percentile of the error across
        the tasks in ``data_dict['test']['y']``.
    """
    error_rates = []
    for y in data_dict['test']['y']:
        z = my_sign(y)
        # randint(-1, 1) draws from {-1, 0}; my_sign turns the 0s into +1,
        # which yields a random +/-1 guess per point.
        z_guess = my_sign(np.random.randint(low=-1, high=1, size=z.shape))
        error_rates.append(np.mean(z != z_guess))

    return (
        np.mean(error_rates),
        np.percentile(error_rates, 90),
        np.percentile(error_rates, 10),
    )


def test_classification_oracle(data_dict, x_type, feature_weights, min_n_train_post=0):
    """Evaluate the test classification error of an oracle that knows the true features.

    Only the columns where ``feature_weights`` is non-zero are used.  For
    each training-set size (at least ``min_n_train_post``) and each task, a
    logistic model is fit with ``solve_logistic`` (without feature weights)
    on the sign labels, and the fraction of misclassified test points is
    recorded.

    Args:
        data_dict: Data with a ``'train'`` entry keyed by training-set size
            and a ``'test'`` entry; each holds ``'features'`` and a per-task
            list ``'y'``.
        x_type: Unused; kept for interface compatibility.
        feature_weights: Array whose non-zero entries select the features.
        min_n_train_post: Smallest training-set size to evaluate.

    Returns:
        tuple: ``(n_train_post_range, avg_test_loss, top_10_loss,
        bot_10_loss)`` where the last three hold, per training-set size, the
        mean, 90th percentile and 10th percentile of the error across tasks.
    """
    # Only the support of feature_weights matters for the oracle.
    k_idx = np.where(feature_weights != 0)[0]

    n_train_post_range = np.sort(np.array(list(data_dict['train'].keys())))
    n_train_post_range = n_train_post_range[n_train_post_range >= min_n_train_post]

    test_data_dict = data_dict['test']
    features_test_post = test_data_dict['features']
    # The true test labels do not depend on the training size: compute once.
    z_test_per_task = [np.squeeze(my_sign(y_test)) for y_test in test_data_dict['y']]

    test_loss_matrix = []
    for n_train_post in n_train_post_range:
        train_data_dict = data_dict['train'][n_train_post]

        features_post = train_data_dict['features'][:, k_idx]
        cfeatures_test_post = features_test_post[:, k_idx]

        # Drop selected features that are (numerically) zero on all training
        # points, in both the training and the test feature matrices.
        feature_norms = np.linalg.norm(features_post, axis=0)
        r_idx = np.where(feature_norms > 1e-6)[0]

        features_post = features_post[:, r_idx]
        cfeatures_test_post = cfeatures_test_post[:, r_idx]

        test_loss_array = []
        for y_post, z_test_post in zip(train_data_dict['y'], z_test_per_task):
            z_post = np.squeeze(my_sign(y_post))

            # No feature weights are passed to the solver for the oracle.
            w_post, _ = solve_logistic(features_post, z_post, None)

            z_test_post_pred = np.squeeze(my_sign(cfeatures_test_post @ w_post))

            # Classification error: fraction of mislabeled test points
            test_loss_array.append(np.mean(z_test_post != z_test_post_pred))

        test_loss_matrix.append(test_loss_array)

    # Rows: tasks, columns: training-set sizes
    test_loss_matrix = np.array(test_loss_matrix).T
    avg_test_loss = np.mean(test_loss_matrix, 0)

    top_10_loss = np.percentile(test_loss_matrix, 90, axis=0)
    bot_10_loss = np.percentile(test_loss_matrix, 10, axis=0)

    return n_train_post_range, avg_test_loss, top_10_loss, bot_10_loss

In [None]:
def get_params_dict_classification():
    """Return the default hyperparameters for the classification experiment.

    A new dict is created on every call, so callers may modify (or copy and
    extend) the result without affecting later calls.
    """
    return dict(
        seed=7,

        # --- Sample counts ---
        n_train_inner=32,   # training samples for each inner task
        n_train_meta=64,    # training samples for updating the meta parameter
        n_train_post=32,    # training samples used after meta-training to learn the weights
        n_test_post=1000,   # samples for plotting / evaluating test performance after meta-training

        # --- Data and features ---
        x_type='uniform_random',  # how x is sampled for inner and post-training tasks ('grid' is the alternative)
        d=501,                    # number of features
        phi_type='fourier',       # feature type
        # Label-noise level; 0.0 means no noise. Presumably a label-flip
        # probability (cf. add_label_noise) -- confirm against the caller.
        noise_prob=0.0,

        # --- Meta-training ---
        optimizer_type='SGD',     # optimizer for meta updates: 'SGD' or 'Adam'
        k_idx=np.arange(5, 11),   # frequency range present in tasks during meta-training
        stepsize_meta=1e-2,       # stepsize for meta updates
        num_inner_tasks=5,        # inner tasks per meta update
        num_tasks_test=10,        # number of tasks to test on
        num_stats=10,             # determines how often stats are collected
        num_iterations=100,       # iterations for training the meta parameter
        num_n_train_post_range=40,  # points used to generate the test loss vs. n_train_post curve

        # --- Inner loop ---
        num_gd_steps=5,           # GD steps per inner task towards the min-norm LS solution
    )

# Run classification meta-learning with an inner stepsize of 1e-2.
stepsize_inner = 1e-2
params_dict = get_params_dict_classification()
# Shallow copy of the defaults, extended with the inner stepsize.
cparams_dict = {**params_dict, 'stepsize_inner': stepsize_inner}

meta_learning_classification(cparams_dict)

In [None]:
# Run classification meta-learning again, now with an inner stepsize of 3e-1.
stepsize_inner = 3e-1
params_dict = get_params_dict_classification()
# Shallow copy of the defaults, extended with the inner stepsize.
cparams_dict = dict(params_dict)
cparams_dict.update(stepsize_inner=stepsize_inner)

meta_learning_classification(cparams_dict)

After making the required changes, the plots contain the following information:

- First row
  - The left plot shows the test classification error with respect to the number of data points used for logistic regression after meta-training is complete.
    - The green curve represents the test loss from the oracle (using only features present in the true function).
    - The blue curve shows the performance with the feature weights learned at the current meta-training iteration.
    - The orange curve marks the performance at initialization, with all feature weights equal to one.
    - The yellow curve provides a baseline where we guess the same label for all points (resulting in a 0.5 classification error).
    - In each curve, we plot the average, 10th, and 90th percentile of the classification error over 10 tasks.
  
  - The right plot on the first row is a zoomed-in version of the left plot.

- Second row

  - The left plot displays the true sign labels in orange and the predicted sign labels in green.
    - The red dots represent the labels corresponding to the training points.
    - The green crosses show our predictions on the training points.
    - The brown curve illustrates the underlying true function.
    - The purple curve presents the function given by $\langle \alpha, x \rangle$ after normalizing $\alpha$ to have a unit norm.

  - The right plot on the second row illustrates the feature weights at the current iteration being meta-learned.


#### Question

Based on the plot of classification error versus `n_train_post`, **how does the performance of the meta-learned feature weights compare to the case where all feature weights are 1? How does the performance of the meta-learned feature weights compare to the oracle** (which performs logistic regression using only the features present in the data)? Include the answer in your written submission for the written assignment.

#### Question

By observing the evolution of the feature weights over time as we perform meta-learning, **can you justify the improvement in performance? Specifically, can you explain why some feature weights are being driven towards zero?** Include the answer in your written submission for the written assignment.