# Scratchpad for paper revisions

In [None]:
# Notebook setup: re-import edited project modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2
import pickle
import os, sys
# Make the parent directory of the notebook importable so the project packages
# (`utils`, `hebbcl`) below can be resolved.
root_path = os.path.realpath('../')
sys.path.append(root_path)

import torch
from pathlib import Path
import numpy as np
import random


# Project-local helpers: dataset construction and device selection.
from utils.data import make_blobs_dataset
from utils.nnet import get_device

# Project-local model, training loop, CLI parameter parser and HPO tuner.
from hebbcl.logger import LoggerFactory
from hebbcl.model import Nnet
from hebbcl.trainer import Optimiser, train_on_blobs
from hebbcl.parameters import parser
from hebbcl.tuner import HPOTuner

## Hyperparameter optimisation
Hyperparameter optimisation (HPO) on a network trained with fewer episodes (`n_episodes = 8`).

### HPO: blocked trials with oja_ctx

In [None]:
# HPO on blocked trials with oja_ctx
# Run a BOHB search (BOHB scheduler + BOHB searcher) for a network trained for
# 8 episodes on the default training schedule (blocked, per the section heading),
# then save the completed trials sorted best (lowest mean_loss) first.
args = parser.parse_args(args=[])
args.n_episodes = 8
args.hpo_fixedseed = True
args.hpo_scheduler = "bohb"
args.hpo_searcher = "bohb"
args.ctx_avg = False
# init tuner; time_budget is presumably in seconds (i.e. 15 minutes) - confirm in HPOTuner
tuner = HPOTuner(args, time_budget=60*15, metric="loss")

tuner.tune(n_samples=500)

# Keep only finished trials whose metrics and config values are all present, best first.
df = tuner.results
df = df[["mean_loss", "mean_acc", "config.lrate_sgd", "config.lrate_hebb", "config.ctx_scaling", "config.seed", "done"]]
df = df[df["done"].eq(True)]
df = df.drop(columns=["done"])
df = df.dropna()
df = df.sort_values("mean_loss", ascending=True)
# reset_index returns a new frame, so assign it back (the bare call was a no-op);
# drop=True discards the old trial index instead of adding it as a column.
df = df.reset_index(drop=True)
print(df.head(15))

print(tuner.best_cfg)

results_path = Path("../results/raytune_oja_ctx_blocked_8episodes.pkl")
# Create ../results if needed so that a long HPO run is not lost to a missing folder.
results_path.parent.mkdir(parents=True, exist_ok=True)
with results_path.open("wb") as f:
    pickle.dump(df, f)

In [None]:
# Reload the saved blocked-schedule HPO table and display its first row.
# That row is the best trial, because the table was sorted by mean_loss before it was pickled.
blocked_results = Path("../results/raytune_oja_ctx_blocked_8episodes.pkl")
with blocked_results.open("rb") as fh:
    df = pickle.load(fh)

df.iloc[0]

In [None]:
# verify results: retrain once with the top-ranked blocked-schedule configuration.
blocked_results = Path("../results/raytune_oja_ctx_blocked_8episodes.pkl")
with blocked_results.open("rb") as fh:
    df = pickle.load(fh)

# The table was pickled after sorting by mean_loss (ascending), so row 0 is the best trial.
best_row = df.iloc[0]
best_seed = int(best_row["config.seed"])

# Start from the parser defaults; checkpoints go to checkpoints/test_allhebb.
args = parser.parse_args(args=[])
save_dir = Path("checkpoints") / "test_allhebb"
args.device = get_device(args.cuda)[0]

# Override the defaults with the HPO setting and the tuned hyperparameters.
args.n_episodes = 8
args.lrate_hebb = best_row["config.lrate_hebb"]
args.lrate_sgd = best_row["config.lrate_sgd"]
args.ctx_scaling = best_row["config.ctx_scaling"]
args.ctx_avg = False

# Seed numpy, python and torch RNGs (in that order) with the trial's seed.
for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
    seed_fn(best_seed)

# Build the data, logger, model (moved to the selected device) and optimiser, then train.
dataset = make_blobs_dataset(args)
logger = LoggerFactory.create(args, save_dir)
model = Nnet(args).to(args.device)
optimiser = Optimiser(args)

train_on_blobs(args, model, optimiser, dataset, logger)

final_acc = logger.results['acc_total'][-1]
final_loss = logger.results['losses_total'][-1]
print(f"config: lrate_sgd: {args.lrate_sgd:.4f}, lrate_hebb: {args.lrate_hebb:.4f}, context offset: {args.ctx_scaling}")
print(f"terminal accuracy: {final_acc:.2f}, loss: {final_loss:.2f}")

### HPO: Interleaved trials

In [None]:
# HPO on interleaved trials with oja_ctx
# Run a BOHB search (BOHB scheduler + BOHB searcher) for a network trained for
# 8 episodes on the interleaved schedule, then save the completed trials sorted
# best (lowest mean_loss) first.
args = parser.parse_args(args=[])
args.n_episodes = 8
args.hpo_fixedseed = True
args.hpo_scheduler = "bohb"
args.hpo_searcher = "bohb"
args.training_schedule = "interleaved"
args.ctx_avg = False
# init tuner; time_budget is presumably in seconds (i.e. 15 minutes) - confirm in HPOTuner
tuner = HPOTuner(args, time_budget=60*15, metric="loss")

tuner.tune(n_samples=500)

# Keep only finished trials whose metrics and config values are all present, best first.
df = tuner.results
df = df[["mean_loss", "mean_acc", "config.lrate_sgd", "config.lrate_hebb", "config.ctx_scaling", "config.seed", "done"]]
df = df[df["done"].eq(True)]
df = df.drop(columns=["done"])
df = df.dropna()
df = df.sort_values("mean_loss", ascending=True)
# reset_index returns a new frame, so assign it back (the bare call was a no-op);
# drop=True discards the old trial index instead of adding it as a column.
df = df.reset_index(drop=True)
print(df.head(15))

print(tuner.best_cfg)

results_path = Path("../results/raytune_oja_ctx_interleaved_8episodes.pkl")
# Create ../results if needed so that a long HPO run is not lost to a missing folder.
results_path.parent.mkdir(parents=True, exist_ok=True)
with results_path.open("wb") as f:
    pickle.dump(df, f)

In [None]:
# verify results: retrain once with the top-ranked interleaved-schedule configuration.
# Reload the interleaved HPO table from disk instead of relying on the in-memory `df`.
# If the blocked verification cell ran more recently, that `df` holds the *blocked*
# results, and the wrong hyperparameters would be used here without any warning.
with open("../results/raytune_oja_ctx_interleaved_8episodes.pkl", "rb") as f:
    df = pickle.load(f)

# The table was pickled after sorting by mean_loss (ascending), so row 0 is the best trial.
best = df.iloc[0]
seed = int(best["config.seed"])

# obtain params
args = parser.parse_args(args=[])

# set checkpoint directory
save_dir = Path("checkpoints") / "test_allhebb"

# get device (gpu/cpu)
args.device = get_device(args.cuda)[0]

# override defaults with the HPO setting and the tuned hyperparameters
args.n_episodes = 8
args.lrate_hebb = best["config.lrate_hebb"]
args.lrate_sgd = best["config.lrate_sgd"]
args.ctx_scaling = best["config.ctx_scaling"]
args.ctx_avg = False
args.training_schedule = "interleaved"
# seed numpy, python and torch RNGs with the trial's seed
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)

# create dataset
dataset = make_blobs_dataset(args)

# instantiate logger, model and optimiser:
logger = LoggerFactory.create(args, save_dir)
model = Nnet(args)
optimiser = Optimiser(args)

# send model to device (GPU?)
model = model.to(args.device)

# train model
train_on_blobs(args, model, optimiser, dataset, logger)

print(f"config: lrate_sgd: {args.lrate_sgd:.4f}, lrate_hebb: {args.lrate_hebb:.4f}, context offset: {args.ctx_scaling}")
print(f"terminal accuracy: {logger.results['acc_total'][-1]:.2f}, loss: {logger.results['losses_total'][-1]:.2f}")