In [None]:
import numpy as np
import torch 
import matplotlib.pyplot as plt 
import pickle 
from sklearn.manifold import TSNE
from sklearn.linear_model import LinearRegression
from sklearn.svm import SVR
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import mean_squared_error as mse 

from tabulate import tabulate 

from common import *
from hydra import initialize, compose


In [None]:
# Work from the parent directory so the relative paths used below ('output/...',
# 'scripts/prediction_results.pkl') resolve against it (and, presumably, so that
# `launch_experiment` in the next cell is importable).
# NOTE: %cd is not idempotent -- re-running this cell moves up one more level.
%cd .. 


In [None]:
from launch_experiment import initialize as init_agent
import rlkit.torch.pytorch_util as ptu


In [None]:
def load_agent(env, algo='focal'):
    """Compose the Hydra ``experiment`` config for ``env``/``algo`` and build the agent.

    Args:
        env: entry of the ``env`` config group, passed as ``+env=<env>``.
        algo: entry of the ``algo`` config group, passed as ``+algo=<algo>``.
            Defaults to ``'focal'`` (previously this value was hard-coded and the
            argument was silently ignored).

    Returns:
        ``(agent, cfg)``: the object built by ``launch_experiment.initialize`` and
        the composed config.
    """
    with initialize(version_base="1.3", config_path="../cfgs"):
        cfg = compose('experiment', overrides=[f'+env={env}', f'+algo={algo}'])
    agent = init_agent(cfg)
    return agent, cfg

In [None]:
def get_context(agent, env_name, train_idx=(0, 4, 9, 15, 18), moderate_idx=(0,5,9), batch_size=256):
    """Sample one context batch per selected task and stack them, sorted by task id.

    Args:
        agent: experiment object exposing ``replay_buffer`` and ``eval_buffer``
            (each with ``random_batch(task_id, batch_size=...)``), ``unpack_batch``
            and ``use_next_obs_in_context``.
        env_name: key into ``task_id_conv`` (from ``common``).
        train_idx: positions in ``task_id_conv[env_name]['train']``; these tasks are
            sampled from ``agent.replay_buffer``.
        moderate_idx: positions in ``task_id_conv[env_name]['moderate']``; these tasks
            are sampled from ``agent.eval_buffer``.
        batch_size: number of transitions sampled per task.

    Returns:
        ``(context, indices)``: ``context`` has one entry per task along dim 0 and
        the context fields concatenated along dim 2; ``indices`` is an ``np.ndarray``
        of the matching task ids. Both are sorted by task id.
        NOTE(review): assumes ``unpack_batch`` yields tensors whose leading dim has
        size 1 per task, so dim 0 of the result indexes tasks -- confirm.

    Raises:
        ValueError: if neither ``train_idx`` nor ``moderate_idx`` selects a task.
    """
    task_ids = task_id_conv[env_name]
    train_tasks = [task_ids['train'][i] for i in train_idx]
    moderate_tasks = [task_ids['moderate'][i] for i in moderate_idx]
    if not train_tasks and not moderate_tasks:
        raise ValueError('get_context needs at least one task in train_idx or moderate_idx')

    # Training tasks come from the replay buffer, moderate tasks from the eval buffer.
    batches = [ptu.np_to_pytorch_batch(agent.replay_buffer.random_batch(idx, batch_size=batch_size))
               for idx in train_tasks]
    batches += [ptu.np_to_pytorch_batch(agent.eval_buffer.random_batch(idx, batch_size=batch_size))
                for idx in moderate_tasks]

    indices = np.array(train_tasks + moderate_tasks)
    order = np.argsort(indices)

    # Each unpacked batch is a list of fields [obs, act, rewards, next_obs, terms];
    # regroup by field and stack the tasks along dim 0.
    per_task = [agent.unpack_batch(batch, sparse_reward=False) for batch in batches]
    fields = [torch.cat(field, dim=0) for field in zip(*per_task)]

    # Terminals are never part of the context; next_obs is only included when the
    # agent is configured to use it (the original note: if dynamics don't change
    # across tasks, next_obs is left out).
    if agent.use_next_obs_in_context:
        context = torch.cat(fields[:-1], dim=2)
    else:
        context = torch.cat(fields[:-2], dim=2)

    return context[order], indices[order]

In [None]:
def load_encoder(agent, env_name, algo, seed=0):
    """Load trained context-encoder weights into ``agent``'s encoder and return it.

    The returned module is ``agent.agent.context_encoder`` itself: its weights are
    overwritten in place, so every call replaces whatever the previous call loaded.

    Args:
        agent: object exposing ``agent.agent.context_encoder`` (a ``torch.nn.Module``).
        env_name: environment name used in the checkpoint path.
        algo: algorithm name used in the checkpoint path.
        seed: training seed used in the checkpoint path.

    Returns:
        The (mutated) context encoder.
    """
    encoder = agent.agent.context_encoder
    path = f'output/{env_name}/{algo}/seed{seed}/agent.pth'
    # Deserialize onto the encoder's current device so checkpoints saved on a GPU
    # also load on CPU-only machines.
    first_param = next(encoder.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=device)
    encoder.load_state_dict(checkpoint['context_encoder'])
    return encoder

In [None]:
def embed(context, encoder, random_state=None):
    """Encode ``context`` and project every per-sample latent to 2-D with t-SNE.

    Args:
        context: input passed unchanged to ``encoder``.
        encoder: callable returning a tensor of shape ``(task, batch, latent_dim)``
            (required by the 3-way shape unpack below).
        random_state: forwarded to ``TSNE``. Pass an int for reproducible plots;
            the default ``None`` keeps the previous unseeded behaviour.

    Returns:
        ``np.ndarray`` of shape ``(task, batch, 2)``.
    """
    # Inference only: no need to build an autograd graph.
    with torch.no_grad():
        z = encoder(context)
    n_tasks, n_batch, _ = z.shape
    flat = z.reshape(n_tasks * n_batch, -1).detach().cpu().numpy()
    points = TSNE(n_components=2, random_state=random_state).fit_transform(flat)
    return points.reshape(n_tasks, n_batch, -1)

In [None]:
# Build the ant-dir agent with the FOCAL config; encoder weights are swapped in per algorithm later.
agent, cfg = load_agent('ant-dir', algo='focal')

In [None]:
# 5 training tasks + 3 moderate tasks, 256 transitions each (get_context's defaults, spelled out).
context, indices = get_context(agent, cfg.env_name,
                               train_idx=(0, 4, 9, 15, 18), moderate_idx=(0, 5, 9),
                               batch_size=256)

In [None]:
# One (task, batch, 2) t-SNE embedding per algorithm, in algo_names order.
total_x = [embed(context, load_encoder(agent, cfg.env_name, algo)) for algo in algo_names]
    

In [None]:
plt.rcParams["figure.dpi"] = 400
plt.rcParams["font.size"] = 19
plt.rcParams["legend.fontsize"] = 19
plt.rcParams["text.usetex"] = True
plt.rcParams["figure.autolayout"] = True

# Legend labels "<goal> (<mode>)" depend only on the task, not on the algorithm,
# so build them once (one env reset per task instead of one per task per subplot).
labels = []
for task_id in indices:
    agent.env.reset_task(task_id)
    goal = np.round(agent.env._goal, 2)
    mode = agent.env.get_mode()
    labels.append(f'{goal} ({mode})')

# One colour per task, sampled evenly from 'Set1'. The table is sized from the data
# rather than a fixed 8, so more tasks no longer raise an IndexError.
colors = plt.get_cmap('Set1')(np.linspace(0, 1, len(labels)))

fig, axs = plt.subplots(2, 3, figsize=(12, 8))
lines = []
for ind, x in enumerate(total_x):
    # Fill the 2x3 grid column by column: (0,0), (1,0), (0,1), (1,1), ...
    ax = axs[ind % 2, ind // 2]
    ax.title.set_text(algo_names[ind])
    # The last subplot's handles feed the shared legend (same colours/labels everywhere).
    lines = [
        ax.scatter(x[i, :, 0], x[i, :, 1], s=20, label=labels[i], alpha=0.6, color=colors[i])
        for i in range(x.shape[0])
    ]

# The bottom-right panel only hosts the shared legend: hide its axes and frame.
legend_ax = axs[-1, -1]
legend_ax.get_xaxis().set_visible(False)
legend_ax.get_yaxis().set_visible(False)
for spine in legend_ax.spines.values():
    spine.set_visible(False)
legend_ax.legend(lines, labels, loc='center left', ncol=1)
plt.tight_layout()
# plt.savefig('figs/two_col_Ant-DIR.pdf')

In [None]:
def prepare_data(agent, env_name, context, algo, task_indices=None):
    """Encode ``context`` with ``algo``'s encoder and build a goal-regression split.

    For every task, the first ``n_sample // 5`` encoded samples form the test set
    and the remaining samples form the training set. The regression target of each
    sample is its task's goal (``agent.env._goal`` after ``reset_task``).

    Args:
        agent: agent whose encoder is (re)loaded via ``load_encoder`` and whose
            ``env`` provides ``reset_task`` / ``_goal``.
        env_name: environment name used to locate the checkpoint.
        context: context tensor with tasks along dim 0 (as returned by get_context).
        algo: algorithm name used to locate the checkpoint.
        task_indices: task ids matching dim 0 of ``context``. If ``None``, the
            notebook-global ``indices`` is used. The original code always relied on
            that global, which is hidden state; prefer passing it explicitly.

    Returns:
        ``(z_train, z_test, goals_train, goals_test)`` as 2-D numpy arrays.

    Raises:
        ValueError: if the number of task ids does not match ``context``.
    """
    if task_indices is None:
        # Legacy fallback: the global set by the cell that called get_context.
        task_indices = indices

    encoder = load_encoder(agent, env_name, algo)
    with torch.no_grad():  # inference only
        z = encoder(context).detach().cpu().numpy()
    n_tasks, n_sample, _ = z.shape
    if len(task_indices) != n_tasks:
        raise ValueError(f'got {len(task_indices)} task ids for {n_tasks} tasks in context')

    goals = []
    for task_id in task_indices:
        agent.env.reset_task(task_id)
        goals.append(agent.env._goal)
    goals = np.array(goals)
    if goals.ndim == 1:
        # Scalar goals -> one target column.
        goals = goals.reshape(-1, 1)
    goals = np.repeat(goals[:, None], n_sample, axis=1)

    n_test = n_sample // 5
    # Reshape with -1 on the leading axis so an empty split (n_sample < 5) yields a
    # (0, dim) array instead of raising.
    z_train = z[:, n_test:].reshape(-1, z.shape[-1])
    z_test = z[:, :n_test].reshape(-1, z.shape[-1])
    goals_train = goals[:, n_test:].reshape(-1, goals.shape[-1])
    goals_test = goals[:, :n_test].reshape(-1, goals.shape[-1])
    return z_train, z_test, goals_train, goals_test

def svr_model(data):
    """Fit an RBF-kernel SVR per goal dimension on latent codes and report RMSE.

    Args:
        data: ``(z_train, z_test, goals_train, goals_test)`` as built by prepare_data.

    Returns:
        ``(rmse_train, rmse_test)``.
    """
    z_train, z_test, goals_train, goals_test = data
    # MultiOutputRegressor fits one SVR per target column.
    regressor = MultiOutputRegressor(SVR(kernel='rbf'))
    regressor.fit(z_train, goals_train)
    fitted_train = regressor.predict(z_train)
    fitted_test = regressor.predict(z_test)
    return (mse(goals_train, fitted_train) ** 0.5,
            mse(goals_test, fitted_test) ** 0.5)
    
def linear_model(data):
    """Fit an ordinary-least-squares regression from latent codes to goals; report RMSE.

    Args:
        data: ``(z_train, z_test, goals_train, goals_test)`` as built by prepare_data.

    Returns:
        ``(rmse_train, rmse_test)``.
    """
    z_train, z_test, goals_train, goals_test = data
    regressor = LinearRegression()
    regressor.fit(z_train, goals_train)
    fitted_train = regressor.predict(z_train)
    fitted_test = regressor.predict(z_test)
    return (mse(goals_train, fitted_train) ** 0.5,
            mse(goals_test, fitted_test) ** 0.5)

In [None]:
from collections import defaultdict

# For every algorithm: one (mean, std) RMSE tuple per environment, in the order of
# the environment list below. std uses np.std's default (ddof=0).
# NOTE(review): t_test treats these stds as sample stds (ddof=1) -- confirm intent.
linear_rmse_trains, linear_rmse_tests = defaultdict(list), defaultdict(list)
svr_rmse_trains, svr_rmse_tests = defaultdict(list), defaultdict(list)

for env in ['cheetah-vel', 'ant-goal', 'ant-dir', 'humanoid-dir',
            'hopper-mass', 'hopper-friction', 'walker-mass', 'walker-friction']:
    seed = 0          # not read below; prepare_data loads encoders with load_encoder's default seed
    n_samples = 1000  # transitions sampled per task in every repetition
    agent, cfg = load_agent(env)

    for algo in algo_names:
        res_linear_trains, res_linear_tests = [], []
        res_svr_trains, res_svr_tests = [], []
        for i in range(5):  # 5 independent context resamples per (env, algo)
            # Keep `indices` a notebook-level name: prepare_data reads the global
            # `indices` to look up each task's goal.
            context, indices = get_context(agent, cfg.env_name, batch_size=n_samples,
                                           train_idx=list(range(20)), moderate_idx=list(range(10)))
            data = prepare_data(agent, cfg.env_name, context, algo)
            train, test = linear_model(data)
            res_linear_trains.append(train)
            res_linear_tests.append(test)
            train, test = svr_model(data)
            res_svr_trains.append(train)
            res_svr_tests.append(test)
        for summary, runs in ((linear_rmse_trains, res_linear_trains),
                              (linear_rmse_tests, res_linear_tests),
                              (svr_rmse_trains, res_svr_trains),
                              (svr_rmse_tests, res_svr_tests)):
            summary[algo].append((np.mean(runs), np.std(runs)))
        


In [None]:
def make_table_pred(rsme_dict, model_name='Linear Regression'):
    """Print a LaTeX table of RMSE ``mean ± std`` per environment and algorithm.

    Args:
        rsme_dict: mapping algorithm name -> list of ``(mean, std)`` tuples, one per
            environment, in the same order as ``env_ids``.
        model_name: label written in the multirow "Model" cell. Defaults to the
            previously hard-coded "Linear Regression"; pass e.g. ``'SVR'`` when
            printing SVR results so the table is not mislabelled.
    """
    results = {'Environment': [env_names[env] for env in env_ids]}
    # A single multirow cell spanning every environment row (was hard-coded to 8).
    n_rows = len(results['Environment'])
    results['Model'] = [rf'\multirow{{{n_rows}}}{{*}}{{{model_name}}}']
    for algo, values in rsme_dict.items():
        # Raw f-string: keeps the LaTeX backslash without an invalid-escape warning.
        results[algo] = [rf'$ {mean:.4f} \pm {std:.4f} $' for mean, std in values]
    print(tabulate(results, tablefmt='latex_raw', headers=results.keys()))

In [None]:
# LaTeX table of held-out RMSE for the linear probe.
make_table_pred(rsme_dict=linear_rmse_tests)

In [None]:
# LaTeX table of held-out RMSE for the SVR probe.
make_table_pred(rsme_dict=svr_rmse_tests)

In [None]:
# (model, split) -> {algo: [(mean, std) per environment]}, pickled in the next cell.
total_results = {
    ('linear', 'train'): linear_rmse_trains,
    ('linear', 'test'): linear_rmse_tests,
    ('svr', 'train'): svr_rmse_trains,
    # Test split must map to the *test* results. It previously pointed at
    # svr_rmse_trains; pickles written before this fix hold SVR train RMSE here.
    ('svr', 'test'): svr_rmse_tests,
}

In [None]:
# Persist every RMSE table so the statistics below can be rerun without re-evaluating.
with open('scripts/prediction_results.pkl', mode='wb') as fh:
    pickle.dump(total_results, fh)

In [None]:
# Reload the cached RMSE tables. pickle can execute code on load: only open trusted files.
with open('scripts/prediction_results.pkl', mode='rb') as fh:
    total_results = pickle.load(fh)

In [None]:
make_table_pred(total_results['svr', 'test'])

In [None]:
mean_std = total_results['svr', 'test']

In [None]:
from scipy.stats import t


def t_test(mean1, mean2, std1, std2, n=5, df=None):
    """Two-sided two-sample t-test from summary statistics (equal sample sizes).

    Uses the unpooled standard error ``sqrt(std1**2/n + std2**2/n)``. Works
    element-wise on numpy arrays.

    Args:
        mean1, mean2: sample means (scalars or broadcastable arrays).
        std1, std2: standard deviations. The formula assumes sample stds (ddof=1).
            NOTE(review): the evaluation loop fills these with ``np.std``'s default
            (ddof=0), which makes p-values slightly optimistic -- confirm intent.
        n: number of samples behind each mean/std.
        df: degrees of freedom. Defaults to ``n - 1``, the conservative Welch
            choice ``min(n1, n2) - 1``; this equals the previously hard-coded 4 for
            ``n=5``.

    Returns:
        Two-sided p-value(s), rounded to 3 decimals.
    """
    if df is None:
        df = n - 1
    t_value = (mean2 - mean1) / (np.sqrt(std1**2 / n + std2**2 / n))
    # The test is direction-agnostic (|t|), so both tails count. A single tail
    # (the previous behaviour) halves the p-value.
    p_value = 2 * t.sf(np.abs(t_value), df=df)
    return np.round(p_value, 3)

In [None]:
def result_t_test(mean_std, reference='ER-TRL', alpha=0.05):
    """Compare every algorithm against ``reference`` with ``t_test``, per environment.

    Prints the p-values of each comparison, then the matrix of p-values
    (non-reference algorithms x environments), then a boolean mask of the
    environments where every comparison has ``p <= alpha``.

    Args:
        mean_std: mapping algorithm -> list of ``(mean, std)`` per environment.
        reference: algorithm the others are compared to. Previously hard-coded, and
            it was silently assumed to be the *last* key of ``mean_std``.
        alpha: significance threshold for the final mask.
    """
    means = {algo: np.array([val[0] for val in values]) for algo, values in mean_std.items()}
    stds = {algo: np.array([val[1] for val in values]) for algo, values in mean_std.items()}

    p_values = []
    for algo in means:
        if algo == reference:
            continue  # excluded by name, not by position
        p = t_test(means[algo], means[reference], stds[algo], stds[reference])
        p_values.append(p)
        print(f'comparing {algo} to {reference}')
        print(p)
        print()

    if not p_values:
        print(f'no algorithms to compare against {reference}')
        return
    p_matrix = np.array(p_values)
    print(p_matrix)
    print(p_matrix.max(0) <= alpha)

In [None]:
result_t_test(total_results['linear', 'test'])

In [None]:
result_t_test(total_results['svr', 'test'])

In [None]:
t_test(mean1=0.16565, mean2=0.16571, std1=0.00130, std2=0.00126)