In [1]:
from seqpred.data import prep_data, BaseDataset
from seqpred.nn import SequentialMargeNet
import yaml
import torch
import polars as pl
from morphers.base.categorical import Integerizer

# Notebook inputs: the trained checkpoint to restore, and the parquet file(s)
# whose games are replayed as generation context.
checkpoint_path = "./model/epoch=39-validation_loss=1.731.ckpt"
data_files = [
    "./data/2023_data.parquet",
]

In [6]:
def unmorph(pitches, morphers, missing="<NONE>"):
    """Map integer-encoded feature vectors back to their original vocabulary values.

    Args:
        pitches: dict of feature name -> array-like of integer codes. Anything with
            a ``.tolist()`` method works (numpy arrays, torch tensors). A 0-d array
            (e.g. a squeezed length-1 sequence) is treated as a single code.
        morphers: dict of feature name -> fitted morpher. Must contain every key
            of ``pitches``.
        missing: value emitted for codes that are not in the morpher's vocab.

    Returns:
        dict of feature name -> list of decoded values, one per input code.

    Raises:
        NotImplementedError: if a feature's morpher is not an ``Integerizer``.
        KeyError: if a feature of ``pitches`` has no entry in ``morphers``.
    """
    unmorphed_pitches = {}
    for pk, pv in pitches.items():
        morpher = morphers[pk]
        if not isinstance(morpher, Integerizer):
            raise NotImplementedError(
                f"unmorph only supports Integerizer morphers; "
                f"feature {pk!r} uses {type(morpher).__name__}"
            )
        # vocab maps raw value -> integer code; invert it to decode codes.
        reverse_vocab = {code: value for value, code in morpher.vocab.items()}
        codes = pv.tolist()
        # .tolist() on a 0-d array/tensor returns a bare scalar, not a list.
        if not isinstance(codes, list):
            codes = [codes]
        unmorphed_pitches[pk] = [reverse_vocab.get(code, missing) for code in codes]
    return unmorphed_pitches

In [3]:

# Load the run config, restore the trained model, and rebuild the dataset with
# the morphers stored in the model's hparams so feature encodings match training.
with open("cfg/config.yaml", "r") as f:
    # NOTE(review): yaml.CLoader is PyYAML's full (unsafe) C loader and only
    # exists when PyYAML is built with libyaml. Acceptable for a local, trusted
    # config; consider yaml.CSafeLoader / yaml.safe_load if the file has no
    # Python-object tags.
    config = yaml.load(f, Loader=yaml.CLoader)

model = SequentialMargeNet.load_from_checkpoint(checkpoint_path)
# Presumably a LightningModule: the Lightning docs call model.eval() after
# load_from_checkpoint. Do so here so dropout etc. is off while sampling.
model.eval()

morpher_dict = model.hparams["morphers"]

# The returned `morphers` is unused below; the model's `morpher_dict` is passed
# in and reused everywhere. Presumably they are the same mapping — TODO confirm.
data, morphers = prep_data(
    data_files=data_files,
    rename=config["rename"],
    morphers=morpher_dict,
)
ds = BaseDataset(
    data,
    morpher_dict,
    model.hparams["max_length"],
)

In [7]:
# Use the first `after_n_pitches` real pitches of one game as context, then
# autoregressively sample pitches until the model emits end_of_game or
# `max_to_generate` pitches have been produced.
example = ds[900]
after_n_pitches = 120
max_to_generate = 300
seed = 0  # generation samples at temperature 1.0; fix the RNG so reruns match

torch.manual_seed(seed)

with torch.inference_mode():
    # Context mask (despite the name): keep positions within the first
    # `after_n_pitches` slots that are not padding (padding has inf in pad_mask).
    inning_mask = (
        torch.arange(example["end_of_inning"].shape[0]) < after_n_pitches
    ) & ~torch.isinf(example["pad_mask"])

    x = {
        k: v[inning_mask].to(model.device).unsqueeze(0)
        for k, v in example.items()
        if isinstance(v, torch.Tensor)
    }
    # Actual number of context pitches kept. This can be < after_n_pitches if the
    # game is short or padding falls in the first slots, so split on it below.
    n_context = int(inning_mask.sum())

    for i in range(max_to_generate):
        generated_pitch = model.generate_one(x, temperature=1.0)
        # pad_mask is not extended for generated pitches, so it is dropped once
        # generation starts; every other feature gets the new pitch appended.
        x = {
            k: torch.cat([v, generated_pitch[k].unsqueeze(0)], dim=1)
            for k, v in x.items()
            if k != "pad_mask"
        }
        if generated_pitch["end_of_game"].item() == 1:
            print(f"Reached end of game: generated {i+1} pitches")
            break
    else:
        print(f"Stopped after max_to_generate={max_to_generate} pitches without end of game")

    # Exclude pad_mask (still present if no step ran): it is not a decodable feature.
    context = {k: v[:, :n_context] for k, v in x.items() if k != "pad_mask"}
    generated = {k: v[:, n_context:] for k, v in x.items() if k != "pad_mask"}

def _decode_to_frame(batch):
    """Decode a batch-of-one dict of encoded feature tensors into a polars DataFrame.

    Assumes each tensor is (1, seq_len) (batch axis first, as built by the
    generation cell) — TODO confirm. Only the batch axis is squeezed, so a
    single-pitch sequence stays 1-D instead of collapsing to a 0-d scalar.
    """
    return pl.DataFrame(
        unmorph(
            {k: v.squeeze(0).cpu().numpy() for k, v in batch.items()},
            morpher_dict,
        )
    )


context_df = _decode_to_frame(context)
generated_df = _decode_to_frame(generated)

Reached end of game: generated 269 pitches
