In [1]:
%cd '/app'

/app


In [2]:
%load_ext autoreload
%autoreload 2

### Loading the border basis dataset: setup and imports

In [3]:
import os
from pathlib import Path

# Pin this notebook to a single GPU. CUDA reads CUDA_VISIBLE_DEVICES only when it
# is first initialised, so this must be set before torch touches the GPU (the
# `src.*` modules imported below presumably pull in torch — they load checkpoints
# with torch.load). Changing it later in a live kernel has no effect.
GPU_ID = '6'
os.environ['CUDA_VISIBLE_DEVICES'] = GPU_ID

from src.loader.data import load_data
from src.loader.data_format.processors.base import ProcessorChain

### Load the pretrained checkpoint (tokenizer and training config)

In [6]:
from src.loader.checkpoint import load_pretrained_bag

# Trained expansion model whose tokenizer/config drive the data pipeline below.
save_path = 'results/train/expansion/expansion/custom_bart/base_k_lt=5_m=1000000'

# Restore the artefacts stored with the checkpoint; only the tokenizer and the
# training config are needed directly in this notebook.
bag = load_pretrained_bag(save_path, from_checkpoint=True)
tokenizer, config = bag['tokenizer'], bag['config']

### Load Data

In [7]:
import yaml

from src.loader.data_format.processors.expansion import ExtractKLeadingTermsProcessor
from src.loader.data_format.processors.subprocessors import MonomialProcessorPlus

# Generation config of the source border-basis dataset. It is kept for reference;
# the cells shown here do not read `exp_config` further. The test split itself is
# loaded from the expansion dataset, whose name carries the same parameters.
border_basis_data_path = "data/border_basis/GF31_n=3_deg=4_terms=10_bounds=4_4_4_total=4"
with open(f'{border_basis_data_path}/config.yaml', 'r') as f:
    exp_config = yaml.safe_load(f)

data_path = "data/expansion/GF31_n=3_deg=4_terms=10_bounds=4_4_4_total=4"

data_collator_name = 'monomial'

# Sequence-level processing: k leading terms, with k = config.num_leading_terms
# taken from the checkpoint's training config.
_processors = [ExtractKLeadingTermsProcessor(config.num_leading_terms)]

# The coefficient bound is the field order, parsed from a name like 'GF31' -> 31.
# Fail loudly rather than silently mis-parsing an unexpected field name.
field_name = config.field
if not field_name.startswith('GF'):
    raise ValueError(f"expected a field name like 'GF31', got {field_name!r}")

subprocessors = {
    'monomial_ids': MonomialProcessorPlus(
        num_variables=config.num_variables,
        max_degree=config.max_degree,
        max_coef=int(field_name[2:]),
    ),
}

processor = ProcessorChain(_processors)

# Load only the test split.
test_data_path = Path(data_path) / 'test'
test_dataset, data_collator = load_data(
    data_path=test_data_path,
    processor=processor,
    subprocessors=subprocessors,
    splits=[{"name": "test", "batch_size": 32, "shuffle": False}],
    tokenizer=tokenizer,
    return_dataloader=False,  # keep an indexable dataset; samples are accessed by index below
    data_collator_name=data_collator_name,
)

In [11]:
from sage.all import PolynomialRing, GF

from src.dataset.processors.utils import sequence_to_poly

sample = test_dataset[0]
print(f'sample input sequence:\n {sample["input"]}')

# Polynomial ring over the model's field ('GF31' -> GF(31)) with the model's
# number of variables, in degrevlex order.
ring = PolynomialRing(GF(int(config.field[2:])), 'x', config.num_variables, order='degrevlex')

# The input has the form "<L polys> [BIGSEP] <V polys>", and each side is a
# '[SEP]'-separated list of polynomials. The serialized samples pad these markers
# with a varying number of spaces (one or two), so split on the bare markers and
# strip the pieces instead of relying on exact padding.
input_seq = sample['input']
L_seq, V_seq = (part.strip() for part in input_seq.split('[BIGSEP]'))

L = [sequence_to_poly(poly_str.strip(), ring) for poly_str in L_seq.split('[SEP]')]
V = [sequence_to_poly(poly_str.strip(), ring) for poly_str in V_seq.split('[SEP]')]

print(f'sample input: \n L: {L}\n V: {V}')

sample input sequence:
 C1 E0 E1 E7 [SEP] C1 E1 E4 E3 [SEP] C1 E2 E3 E3 [SEP] C1 E3 E2 E3 [SEP] C1 E4 E1 E3 [SEP] C1 E5 E0 E3 [SEP] C1 E1 E5 E2 [SEP] C1 E2 E4 E2 [SEP] C1 E3 E3 E2 [SEP] C1 E4 E2 E2 [SEP] C1 E5 E1 E2 [SEP] C1 E6 E0 E2 [SEP] C1 E1 E6 E1 [SEP] C1 E2 E5 E1 [SEP] C1 E3 E4 E1 [SEP] C1 E4 E3 E1 [SEP] C1 E5 E2 E1 [SEP] C1 E6 E1 E1 [SEP] C1 E7 E0 E1 [SEP] C1 E2 E6 E0 [SEP] C1 E3 E5 E0 [SEP] C1 E4 E4 E0 [SEP] C1 E5 E3 E0 [SEP] C1 E6 E2 E0 [SEP] C1 E7 E1 E0 [SEP] C1 E8 E0 E0 [SEP] C1 E2 E0 E7 [SEP] C1 E1 E2 E6 [SEP] C1 E2 E1 E6 [SEP] C1 E3 E0 E6 [SEP] C1 E1 E3 E5 [SEP] C1 E2 E2 E5 [SEP] C1 E3 E1 E5 [SEP] C1 E4 E0 E5  [BIGSEP]  C1 E4 E1 E0 + C19 E3 E1 E1 + C27 E4 E0 E0 + C13 E3 E1 E0 + C21 E3 E0 E1  [SEP]  C1 E4 E1 E1 + C19 E3 E1 E2 + C27 E4 E0 E1 + C13 E3 E1 E1 + C21 E3 E0 E2  [SEP]  C1 E4 E2 E0 + C19 E3 E2 E1 + C27 E4 E1 E0 + C13 E3 E2 E0 + C21 E3 E1 E1  [SEP]  C1 E5 E1 E0 + C19 E4 E1 E1 + C27 E5 E0 E0 + C13 E4 E1 E0 + C21 E4 E0 E1  [SEP]  C1 E1 E2 E4 + C25 E0 E3 E4 + C14 E2 E0 

### Oracle

In [12]:
from src.oracle.transformer_oracle import TransformerOracle

# Wrap the checkpoint at `save_path` as an oracle over `ring`. It uses the same
# k (config.num_leading_terms) that the data processor was configured with.
oracle = TransformerOracle(
    ring,
    save_path,
    leading_term_k=config.num_leading_terms,
)

/app


  model_state_dict = torch.load(os.path.join(checkpoint_path, f'pytorch_model.bin'))


In [13]:
# Oracle prediction for the sample's (L, V). Judging by the displayed output,
# this is a list of (variable, term) pairs.
# NOTE(review): the final pair in the recorded output is `(x2, 1)`, which looks
# unlike the others — check whether this is a decoding artefact.
predictions = oracle.predict(L, V)
predictions

[(x0, x0*x1^2*x2^4),
 (x0, x0^2*x1*x2^4),
 (x2, x0^2*x1*x2^4),
 (x0, x0^3*x2^4),
 (x1, x0^3*x2^4),
 (x2, x0^3*x2^4),
 (x0, x1^4*x2^3),
 (x2, x1^4*x2^3),
 (x0, x0*x1^3*x2^3),
 (x1, x0*x1^3*x2^3),
 (x2, x0*x1^3*x2^3),
 (x0, x0^2*x1^2*x2^3),
 (x1, x0^2*x1^2*x2^3),
 (x2, x0^2*x1^2*x2^3),
 (x0, x0^3*x1*x2^3),
 (x1, x0^3*x1*x2^3),
 (x2, x0^3*x1*x2^3),
 (x0, x0^4*x2^3),
 (x1, x0^4*x2^3),
 (x2, x0^4*x2^3),
 (x0, x1^5*x2^2),
 (x2, x1^5*x2^2),
 (x0, x0*x1^4*x2^2),
 (x1, x0*x1^4*x2^2),
 (x2, x0*x1^4*x2^2),
 (x0, x0^2*x1^3*x2^2),
 (x1, x0^2*x1^3*x2^2),
 (x2, x0^2*x1^3*x2^2),
 (x0, x0^3*x1^2*x2^2),
 (x1, x0^3*x1^2*x2^2),
 (x2, x0^3*x1^2*x2^2),
 (x0, x0^5*x2^2),
 (x1, x0^5*x2^2),
 (x2, x0^5*x2^2),
 (x0, x1^6*x2),
 (x2, x1^6*x2),
 (x0, x0*x1^5*x2),
 (x1, x0*x1^5*x2),
 (x2, x0*x1^5*x2),
 (x0, x0^2*x1^4*x2),
 (x1, x0^2*x1^4*x2),
 (x2, x0^2*x1^4*x2),
 (x0, x0^3*x1^3*x2),
 (x1, x0^3*x1^3*x2),
 (x2, 1)]