# Echo State Network (ESN) trained on Shakespeare

In [None]:
%cd ..

In [None]:
import lightning.pytorch as L
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import EarlyStopping, Callback

from torch.utils.data import DataLoader, random_split

import optuna
from optuna.integration import PyTorchLightningPruningCallback

from aptorch.data import get_shakespeare_text, CharacterDataset
from aptorch.esn import ESN_Pretrained

# Fix the global seed for reproducible runs; workers=True additionally asks
# Lightning to seed DataLoader worker processes.
seed_everything(seed=42, workers=True)

In [None]:
text = get_shakespeare_text()
# Hold out the last 10% of the corpus (a contiguous tail, not a shuffle) as
# test text.
train_size = int(len(text) * 0.9)
train_text = text[:train_size]
test_text = text[train_size:]

# The vocabulary is built from the training text only. Characters that occur
# only in test_text are not in c2i; presumably CharacterDataset maps them to
# unk_token (it is passed below) — confirm in aptorch.data.
characters = sorted(set(train_text))
pad_token, mask_token, unk_token = "[PAD]", "[MASK]", "[UNK]"
special_tokens = {
    0: pad_token,
    1: mask_token,
    2: unk_token,
}
# Character ids start right after the highest special-token id.
start_idx = max(special_tokens) + 1
char_tokens = {i: c for i, c in enumerate(characters, start=start_idx)}
i2c = {**special_tokens, **char_tokens}  # index to char
c2i = {v: k for k, v in i2c.items()}  # char to index
print(f"characters: {len(characters)}")
print(f"i2c: {i2c}")
print(f"c2i: {c2i}")


# NOTE(review): `chat_to_token` is the keyword CharacterDataset accepts
# (presumably a typo for "char_to_token"); it must match the library, so it
# is not renamed here.
design_set = CharacterDataset(
    text=train_text,
    chat_to_token=c2i,
    unk_token=unk_token,
)
test_set = CharacterDataset(
    text=test_text,
    chat_to_token=c2i,
    unk_token=unk_token,
)

In [None]:
class ResetState(Callback):
    """Reset the ESN reservoir state at the end of every train/val/test epoch.

    The original implementation overrode the generic ``on_epoch_end`` hook,
    which Lightning removed in v1.8; Lightning 2.x rejects callbacks that
    still define it (RuntimeError at ``Trainer.fit``). The reset is therefore
    attached to the stage-specific epoch-end hooks instead, which together
    cover the stages the old hook used to fire for.
    """

    @staticmethod
    def _reset(pl_module: ESN_Pretrained) -> None:
        # Presumably clears recurrent state carried across batches so the
        # next epoch starts fresh — confirm semantics in aptorch.esn.
        pl_module.reservoir.reset_state()

    def on_train_epoch_end(self, trainer, pl_module: ESN_Pretrained) -> None:
        self._reset(pl_module)

    def on_validation_epoch_end(self, trainer, pl_module: ESN_Pretrained) -> None:
        self._reset(pl_module)

    def on_test_epoch_end(self, trainer, pl_module: ESN_Pretrained) -> None:
        self._reset(pl_module)


def objective(trial: optuna.trial.Trial) -> float:
    """Train one ESN configuration sampled by ``trial`` and return its score.

    Samples the reservoir size and learning rate, fits ``ESN_Pretrained`` on
    a train/validation split of ``design_set``, and lets Optuna prune the
    trial from the per-epoch ``val_acc``.

    Args:
        trial: Optuna trial used to sample hyperparameters and to report
            intermediate values to the study's pruner.

    Returns:
        The last ``val_acc`` logged by the model (the study maximizes it).

    Raises:
        optuna.TrialPruned: if the pruning callback flagged this trial.
    """
    import torch  # only needed for the deterministic split generator

    params = {
        "input_size": design_set.input_len,
        "hidden_size": trial.suggest_categorical("hidden_size", [8, 16, 32]),
        "num_tokens": design_set.vocab_size,
        # The range spans two orders of magnitude; uniform sampling would put
        # ~91% of trials above 1e-2, so sample log-uniformly instead.
        "lr": trial.suggest_float("lr", 1e-3, 1e-1, log=True),
    }

    # Same train/validation split for every trial, so val_acc (and the
    # pruner's cross-trial median comparisons) are not skewed by split noise.
    split_generator = torch.Generator().manual_seed(42)
    train_set, valid_set = random_split(
        design_set, [0.8, 0.2], generator=split_generator
    )

    loader_kwargs = dict(batch_size=1024, num_workers=9, persistent_workers=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_kwargs)
    valid_loader = DataLoader(valid_set, shuffle=False, **loader_kwargs)

    model = ESN_Pretrained(**params)

    early_stopping_callback = EarlyStopping("val_loss", min_delta=0.01)
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_acc")
    reset_state_callback = ResetState()
    trainer = L.Trainer(
        # "auto" still selects MPS on Apple hardware, but also runs on
        # CUDA/CPU machines instead of failing there.
        accelerator="auto",
        devices=1,
        max_epochs=10,
        deterministic=True,
        callbacks=[
            early_stopping_callback,
            pruning_callback,
            reset_state_callback,
        ],
    )
    trainer.logger.log_hyperparams(params)
    trainer.fit(
        model=model,
        train_dataloaders=train_loader,
        val_dataloaders=valid_loader,
    )
    # Raises optuna.TrialPruned if the callback decided to prune during fit.
    pruning_callback.check_pruned()

    return trainer.callback_metrics["val_acc"].item()


N_TRIALS = 5

# MedianPruner never prunes during its first `n_startup_trials` trials
# (default 5). With only N_TRIALS=5 trials the default would disable pruning
# entirely, so let the first 2 trials establish the median.
pruner = optuna.pruners.MedianPruner(n_startup_trials=2)
study = optuna.create_study(
    study_name="esn_gpt",
    direction="maximize",
    pruner=pruner,
    load_if_exists=False,
)
study.optimize(objective, n_trials=N_TRIALS)

pruned_trials = study.get_trials(
    deepcopy=False, states=[optuna.trial.TrialState.PRUNED]
)
complete_trials = study.get_trials(
    deepcopy=False, states=[optuna.trial.TrialState.COMPLETE]
)
print(f"Number of finished trials: {len(study.trials)}")
print(f"  pruned: {len(pruned_trials)}, complete: {len(complete_trials)}")
print(f"Best trial: {study.best_trial}")