# Task B: Meta-Learning Performance Prediction

In this task, you will use learning curves from multiple OpenML datasets to train a performance predictor that performs well even on unseen datasets. Your predictions should use only the first 10 epochs of each learning curve. You are provided with learning curves and configuration parameters for six datasets. The datasets are split into training datasets and meta datasets, and you should train only on the training datasets.

Note: This notebook is meant to show how to use the API. You can choose which data you use for your predictions and should create your own data loading and splits; however, you are free to reuse code from here.

## Specifications:

* Data: six_datasets_lw.json
* Number of datasets: 6
* Training datasets: higgs, vehicle, adult, volkert
* Meta datasets: Fashion-MNIST, jasmine
* Number of configurations: 2000
* Number of epochs seen when predicting: 10
* Available data: Learning curves, architecture parameters and hyperparameters 
* Target: Final validation accuracy
* Evaluation metric: MSE
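
For reference, the evaluation metric can be written out as follows. The notation is introduced here for illustration only: $y_i$ is the true final validation accuracy of the $i$-th evaluated curve, $\hat{y}_i$ is the corresponding prediction, and $N$ is the number of evaluated curves.

$$
\mathrm{MSE} = \frac{1}{N} \sum_{i=1}^{N} \left( y_i - \hat{y}_i \right)^2
$$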

## Importing and splitting data

Note: There are 51 steps logged: 50 epochs plus the 0th epoch, which is recorded prior to any weight updates.
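
Because step 0 is logged before any weight updates, "the first 10 epochs" corresponds to the first 11 logged values. The following is a minimal sketch of that slicing on synthetic data; the helper name `first_epochs` and the toy array are assumptions made for illustration and are not part of the benchmark API.

```python
import numpy as np


def first_epochs(curves, n_epochs=10):
    """Keep step 0 plus the first `n_epochs` epochs of each curve (hypothetical helper)."""
    curves = np.asarray(curves)
    return curves[:, : n_epochs + 1]


# Synthetic stand-in: 3 curves with 51 logged steps each (not benchmark data).
toy_curves = np.tile(np.arange(51, dtype=float), (3, 1))
observed = first_epochs(toy_curves)
print(observed.shape)
```

The last column of `observed` is the value logged after epoch 10, which is the most recent information a predictor is allowed to see.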

In [1]:
%%capture
%cd ..
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

from api import Benchmark

In [2]:
bench_dir = "cached/six_datasets_lw.json"
bench = Benchmark(bench_dir, cache=False)

==> Loading data...
==> No cached data found or cache set to False.
==> Reading json data...
==> Done.


In [3]:
# Dataset split
dataset_names = bench.get_dataset_names()
print(dataset_names)

train_datasets = ['adult', 'higgs', 'vehicle', 'volkert']
test_datasets = ['Fashion-MNIST', 'jasmine']

['Fashion-MNIST', 'adult', 'higgs', 'jasmine', 'vehicle', 'volkert']


In [4]:
# Prepare data
def read_data(datasets):
    n_configs = bench.get_number_of_configs(datasets[0])
    data = [bench.query(dataset_name=d, tag="Train/val_accuracy", config_id=ind) for d in datasets for ind in range(n_configs)]
    configs = [bench.query(dataset_name=d, tag="config", config_id=ind) for d in datasets for ind in range(n_configs)]
    dataset_names = [d for d in datasets for ind in range(n_configs)]
    
    # Inputs are all logged steps except the last one; the target is the final value.
    X = np.array([curve[:-1] for curve in data])
    y = np.array([curve[-1] for curve in data])
    return X, y, np.array(configs), np.array(dataset_names)

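class TrainValSplitter():
    """Holds out 30% of the samples as a validation split, stratified by dataset."""
    
    def __init__(self, dataset_names):
        # Use the length of dataset_names instead of relying on the global X.
        self.ind_train, self.ind_val = train_test_split(np.arange(len(dataset_names)), test_size=0.3, stratify=dataset_names)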
        
    def split(self, a):
        return a[self.ind_train], a[self.ind_val]
    
    def cut(self, a, outlength=11):
        # Keep the first `outlength` logged steps (step 0 plus 10 epochs by default).
        return np.array([curve[:outlength] for curve in a])

X, y, configs, dataset_names = read_data(train_datasets)
X_test, y_test, configs_test, dataset_names_test = read_data(test_datasets)

tv_splitter = TrainValSplitter(dataset_names=dataset_names)

X_train, X_val = tv_splitter.split(X)
y_train, y_val = tv_splitter.split(y)
configs_train, configs_val = tv_splitter.split(configs)
dataset_names_train, dataset_names_val = tv_splitter.split(dataset_names)

# Only the first 11 logged steps are visible at prediction time; X_train keeps the full curves.
X_test, X_val = tv_splitter.cut(X_test), tv_splitter.cut(X_val)

print("X_train:", X_train.shape)
print("X_test:", X_test.shape)
print("X_val:", X_val.shape)

X_train: (5600, 51)
X_test: (4000, 11)
X_val: (2400, 11)


## A simple baseline

In [5]:
class SimpleLearningCurvePredictor():
    """A learning curve predictor that returns the last observed value as the predicted final performance."""
    
    def __init__(self):
        pass
        
    def fit(self, X, y):
        pass
    
    def predict(self, X):
        predictions = []
        for curve in X:
            predictions.append(curve[-1])
        return predictions
    
def score(y_true, y_pred):
    return mean_squared_error(y_true, y_pred)

In [6]:
predictor = SimpleLearningCurvePredictor()
predictor.fit(X_train, y_train)
preds = predictor.predict(X_val)
mse = score(y_val, preds)
print(mse)

33.580016418521055
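
A learned predictor can be plugged into the same `fit`/`predict` interface as the baseline. The sketch below is one possible illustration, not the intended solution: it fits scikit-learn's `Ridge` regression on the observed curve prefix. The class name `RidgeCurvePredictor`, its parameters, and the synthetic saturating curves are all assumptions made so the example runs without the benchmark data.

```python
import numpy as np
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error


class RidgeCurvePredictor:
    """Hypothetical predictor: ridge regression on the observed curve prefix."""

    def __init__(self, n_observed=11, alpha=1.0):
        self.n_observed = n_observed
        self.model = Ridge(alpha=alpha)

    def fit(self, X, y):
        # Training curves may be longer; use only the observed prefix as features.
        self.model.fit(np.asarray(X)[:, : self.n_observed], y)
        return self

    def predict(self, X):
        return self.model.predict(np.asarray(X)[:, : self.n_observed])


# Synthetic saturating curves as a stand-in for benchmark data (assumption).
rng = np.random.default_rng(0)
steps = np.arange(51)
final = rng.uniform(50, 95, size=(200, 1))
rate = rng.uniform(0.05, 0.3, size=(200, 1))
curves = final * (1 - np.exp(-rate * steps))
X_toy, y_toy = curves[:, :-1], curves[:, -1]

# Fit on the first 150 toy curves, evaluate on the remaining 50.
model = RidgeCurvePredictor().fit(X_toy[:150], y_toy[:150])
ridge_mse = mean_squared_error(y_toy[150:], model.predict(X_toy[150:, :11]))
# Last-observed-value baseline on the same held-out toy curves.
last_value_mse = mean_squared_error(y_toy[150:], X_toy[150:, 10])
print(ridge_mse, last_value_mse)
```

On the real task, the configuration parameters returned by `read_data` could be appended as additional features, and the score on the held-out meta datasets is the relevant check of generalization.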
