In [None]:
# The kernel starts in the notebook directory, so step up into the root directory.
# Run this cell only once per kernel session: `%cd ".."` is relative, so every
# additional run moves one more level up and breaks the relative paths used below.
%pwd  # in notebook directory
%cd ".." 

## Load dataset

In [None]:
from sage.all import *

import json
# Read the test problems from disk. Each record is indexed below with the
# 'input' and 'output' keys, both holding lists of polynomials written as text.
with open("dataset/sagemath/GB_initial_problem/GF7_n=3/test_data.json", "r") as f:
    testset = json.load(f)

# Polynomial ring over GF(7) with 3 generators named from "x", degrevlex order.
ring = PolynomialRing(GF(7), 3, "x", order="degrevlex")

# Fs[i] / Gs[i] hold the parsed input / output polynomials of the i-th record.
Fs, Gs = [], []
# Iterate over the records that were just loaded (`testset`). The previous
# version looped over `dataset`, which is only defined in a later cell, so a
# fresh Restart & Run All failed here with a NameError.
for data in testset:
    # ring(text) converts a polynomial given as text into a polynomial object.
    F = [ring(poly_text) for poly_text in data['input']]
    G = [ring(poly_text) for poly_text in data['output']]
    Fs.append(F)
    Gs.append(G)

## Load Model

In [None]:
from pathlib import Path
from transformers import AutoModelForSeq2SeqLM

# Toggle: load an intermediate checkpoint sub-directory instead of the
# top-level results directory.
use_checkpoint = False
model_path = Path('results/partial_sum/GF7_n=3')

if use_checkpoint:
    # NOTE(review): get_checkpoint_id is not defined or imported in the visible
    # cells -- confirm it is available before switching use_checkpoint to True.
    checkpoint_id = get_checkpoint_id(model_path)
    model_path = model_path.joinpath(f'checkpoint-{checkpoint_id}')

# local_files_only=True: load strictly from the local path above.
model = AutoModelForSeq2SeqLM.from_pretrained(model_path, local_files_only=True)

## Generation

**Prepare dataloader**

In [None]:
from omegaconf import OmegaConf
from calt import data_loader
from torch.utils.data import DataLoader

# Load the training config stored under model_path and reuse its data/model
# settings when calling data_loader.
cfg = OmegaConf.load(model_path / 'train_partial_sum.yaml')
data_cfg = cfg.data

dataset, tokenizer, data_collator = data_loader(
    train_dataset_path=data_cfg.train_dataset_path,
    test_dataset_path=data_cfg.test_dataset_path,
    field=data_cfg.field,
    num_variables=data_cfg.num_variables,
    max_degree=data_cfg.max_degree,
    max_coeff=data_cfg.max_coeff,
    max_length=cfg.model.max_sequence_length,
)

# Only the 'test' split is used here; shuffle=False keeps the batch order fixed.
test_set = dataset['test']
test_loader = DataLoader(
    test_set,
    batch_size=cfg.train.test_batch_size,
    shuffle=False,
    collate_fn=data_collator,
)

**Forward pass**

In [None]:
# Take only the first batch from the test loader.
batch_iterator = iter(test_loader)
batch = next(batch_iterator)

# Single forward pass; the collated batch is unpacked as keyword arguments.
outputs = model(**batch)

**Generation**

In [None]:
# Generate output token ids for the same batch (this rebinds `outputs`).
outputs = model.generate(**batch)

# Turn the generated ids back into text, dropping special tokens.
decoded_texts = tokenizer.batch_decode(
    outputs,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=True,
)
print(decoded_texts)

['C-3 E0 E0 E1 [SEP] C-1 E5 E0 E1 C3 E3 E2 E1 C-2 E1 E2 E2 C-3 E0 E0 E1 [SEP] C-1 E5 E0 E1 C3 E3 E2 E1 C-2 E1 E2 E2 C3 E0 E1 E0 C-3 E0 E0 E1 C-3 E0 E0 E0 E0 E0 E0 E0 E0 E0 E0 E0 [SEP] C-3 E0 E0 E0 E0 [SEP] C-6', 'C2 E1 E0 E0 C-3 E0 E1 E0 [SEP] C1 E6 E1 E0 C-1 E0 E0 E4 C-3 E0 E0 E2 C-3 E1 E0 E0 C-3 E0 E1 E0 [SEP] C1 E6 E1 E0 C-3 E0 E3 E3 C-1 E0 E0 E4 C-3 E0 E0 E2 C-3 E1 E0 E0 C-1 E0 E1 E0 C-3 E0 E0 E0 E0 E0 E0 E0 E0 E0 E0 E0 [SEP] C-3 E0 E0 E0 E0 [SEP] C-6', 'C2 E0 E1 E0 [SEP] C-1 E1 E4 E5 C-2 E6 E0 E2 C3 E0 E0 E2 C2 E0 E1 E0 C-3 E0 E0 E0 E0 E0 E0 E0 E0 E0 E0 E0 [SEP] C-3 E0 E0 E0 E0 E0 C-6', 'C-3 E0 E0 E4 C1 E1 E1 E0 [SEP] C3 E5 E0 E3 C-3 E0 E0 E4 C1 E1 E1 E0 [SEP] C3 E5 E0 E3 C1 E2 E2 E3 C-3 E0 E0 E4 C1 E1 E1 E0 C1 E0 E0 E0 [SEP] C3 E5 E0 E3 C1 E2 E2 E3 C-3 E0 E0 E4 C1 E1 E1 E0 C1 E0 E0 E1 C1 E0 E0 E0 [SEP] C3 E5 E0 E3 C1 E2 E2 E3 C-3 E1 E3 E0 C-3 E0 E0 E4 C-2 E2 E1 E0 C1 E1 E1 E0 C1 E0 E0 E1 C-2 E0 E0 E0 C-3 E0 E0 E0 E0 E0 E0 E0 E0 E0 E0 E0 [SEP] C3 E0 E0 E0 E0 E0 C-6', 'C3 E1 E3 E0 