In [4]:
%cd '/app'

/app


In [5]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Loading border basis dataset 

In [6]:
import os

# Restrict this kernel to GPU 6 through the CUDA_VISIBLE_DEVICES environment variable.
# NOTE(review): it is set ahead of the project imports below, presumably so it
# takes effect before anything initialises CUDA — keep this ordering.
os.environ['CUDA_VISIBLE_DEVICES'] = '6'

from pathlib import Path
from src.loader.data import load_data
from src.loader.data_format.processors.base import ProcessorChain
# from src.loader.tokenizer import set_vocab, set_tokenizer

### Load Model 

In [7]:
from src.loader.checkpoint import load_pretrained_bag

# Restore the trained bundle saved under `save_path` (from_checkpoint=True).
save_path = 'results/train/expansion/expansion/custom_bart/base_k_lt=5_m=1000000'
bag = load_pretrained_bag(save_path, from_checkpoint=True)

# Pull out the three pieces the rest of the notebook works with.
model, tokenizer, config = bag['model'], bag['tokenizer'], bag['config']

# Put the model in eval mode; the trailing semicolon keeps the returned
# module repr out of the cell output.
model.eval();

  model_state_dict = torch.load(os.path.join(checkpoint_path, f'pytorch_model.bin'))


### Load Data

In [8]:
# load config
import yaml

# The experiment configuration is read from the border-basis dataset directory.
data_path = "data/border_basis/GF31_n=3_deg=4_terms=10_bounds=4_4_4_total=4"
with open(f'{data_path}/config.yaml', 'r') as config_file:
    exp_config = yaml.safe_load(config_file)

# From here on `data_path` points at the expansion dataset, which is the one
# the following cells actually load samples from.
data_path = "data/expansion/GF31_n=3_deg=4_terms=10_bounds=4_4_4_total=4"

from src.loader.data_format.processors.expansion import ExtractKLeadingTermsProcessor
from src.loader.data_format.processors.subprocessors import MonomialProcessorPlus


# Name of the collator variant requested from load_data below.
data_collator_name = 'monomial'

# Sample-level pipeline: a single step configured with the checkpoint's
# `num_leading_terms` setting.
_processors = [ExtractKLeadingTermsProcessor(config.num_leading_terms)]

# Sub-processor registered under the 'monomial_ids' key, configured from the
# checkpoint's variable count, degree bound and field order.
subprocessors = {
    'monomial_ids': MonomialProcessorPlus(
        num_variables=config.num_variables,
        max_degree=config.max_degree,
        max_coef=int(config.field[2:]),  # 'GF7' -> 7
    ),
}

processor = ProcessorChain(_processors)

# load test dataset
test_data_path = Path(data_path) / 'test'
test_dataset, data_collator = load_data(
    data_path=test_data_path,
    processor=processor,
    subprocessors=subprocessors,
    splits=[{"name": "test", "batch_size": 32, "shuffle": False}],
    tokenizer=tokenizer,
    data_collator_name=data_collator_name,
    return_dataloader=False,  # return dataloader if True
)

# Display the first sample as a quick sanity check of the text format.
test_dataset[0]

{'input': 'C1 E0 E1 E7 [SEP] C1 E1 E4 E3 [SEP] C1 E2 E3 E3 [SEP] C1 E3 E2 E3 [SEP] C1 E4 E1 E3 [SEP] C1 E5 E0 E3 [SEP] C1 E1 E5 E2 [SEP] C1 E2 E4 E2 [SEP] C1 E3 E3 E2 [SEP] C1 E4 E2 E2 [SEP] C1 E5 E1 E2 [SEP] C1 E6 E0 E2 [SEP] C1 E1 E6 E1 [SEP] C1 E2 E5 E1 [SEP] C1 E3 E4 E1 [SEP] C1 E4 E3 E1 [SEP] C1 E5 E2 E1 [SEP] C1 E6 E1 E1 [SEP] C1 E7 E0 E1 [SEP] C1 E2 E6 E0 [SEP] C1 E3 E5 E0 [SEP] C1 E4 E4 E0 [SEP] C1 E5 E3 E0 [SEP] C1 E6 E2 E0 [SEP] C1 E7 E1 E0 [SEP] C1 E8 E0 E0 [SEP] C1 E2 E0 E7 [SEP] C1 E1 E2 E6 [SEP] C1 E2 E1 E6 [SEP] C1 E3 E0 E6 [SEP] C1 E1 E3 E5 [SEP] C1 E2 E2 E5 [SEP] C1 E3 E1 E5 [SEP] C1 E4 E0 E5  [BIGSEP]  C1 E4 E1 E0 + C19 E3 E1 E1 + C27 E4 E0 E0 + C13 E3 E1 E0 + C21 E3 E0 E1  [SEP]  C1 E4 E1 E1 + C19 E3 E1 E2 + C27 E4 E0 E1 + C13 E3 E1 E1 + C21 E3 E0 E2  [SEP]  C1 E4 E2 E0 + C19 E3 E2 E1 + C27 E4 E1 E0 + C13 E3 E2 E0 + C21 E3 E1 E1  [SEP]  C1 E5 E1 E0 + C19 E4 E1 E1 + C27 E5 E0 E0 + C13 E4 E1 E0 + C21 E4 E0 E1  [SEP]  C1 E1 E2 E4 + C25 E0 E3 E4 + C14 E2 E0 E5 + C9 E1 E1

### Generate predictions

In [9]:
from torch.utils.data import DataLoader
from src.misc.utils import to_cuda

# Iterate the test split in its stored order, 100 samples per batch.
data_loader = DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=100,
    collate_fn=data_collator,
)

# Only the first batch is needed for the single-batch cells that follow;
# it is passed through to_cuda before use.
batch = to_cuda(next(iter(data_loader)))

In [8]:
from src.loader.data_format.processors.subprocessors import MonomialProcessorPlus

# Processor handed to model.generate and reused afterwards for decoding.
mpp = MonomialProcessorPlus(
    num_variables=config.num_variables,
    max_degree=config.max_degree,
    max_coef=int(config.field[2:]),
)

# Allow one position more than the (padded) label length of this batch.
max_length = 1 + batch['labels'].shape[-1]
print(max_length)

generated = model.generate(
    batch['input_ids'],
    batch['attention_mask'],
    monomial_processor=mpp,
    tokenizer=tokenizer,
    max_length=max_length,
)

691


In [None]:
# Decode the first generated sequence back to text (special tokens skipped) to inspect its format.
mpp.decode(generated[0], skip_special_tokens=True)

'C1 E1 E0 E0 [SEP] C1 E1 E2 E4 [SEP] C1 E1 E0 E0 [SEP] C1 E2 E1 E4 [SEP] C1 E0 E0 E1 [SEP] C1 E2 E1 E4 [SEP] C1 E1 E0 E0 [SEP] C1 E3 E0 E4 [SEP] C1 E0 E1 E0 [SEP] C1 E3 E0 E4 [SEP] C1 E0 E0 E1 [SEP] C1 E3 E0 E4 [SEP] C1 E1 E0 E0 [SEP] C1 E0 E4 E3 [SEP] C1 E0 E0 E1 [SEP] C1 E0 E4 E3 [SEP] C1 E1 E0 E0 [SEP] C1 E1 E3 E3 [SEP] C1 E0 E1 E0 [SEP] C1 E1 E3 E3 [SEP] C1 E0 E0 E1 [SEP] C1 E1 E3 E3 [SEP] C1 E1 E0 E0 [SEP] C1 E2 E2 E3 [SEP] C1 E0 E1 E0 [SEP] C1 E2 E2 E3 [SEP] C1 E0 E0 E1 [SEP] C1 E2 E2 E3 [SEP] C1 E1 E0 E0 [SEP] C1 E3 E1 E3 [SEP] C1 E0 E1 E0 [SEP] C1 E3 E1 E3 [SEP] C1 E0 E0 E1 [SEP] C1 E3 E1 E3 [SEP] C1 E1 E0 E0 [SEP] C1 E4 E0 E3 [SEP] C1 E0 E1 E0 [SEP] C1 E4 E0 E3 [SEP] C1 E0 E0 E1 [SEP] C1 E4 E0 E3 [SEP] C1 E1 E0 E0 [SEP] C1 E0 E5 E2 [SEP] C1 E0 E0 E1 [SEP] C1 E0 E5 E2 [SEP] C1 E1 E0 E0 [SEP] C1 E1 E4 E2 [SEP] C1 E0 E1 E0 [SEP] C1 E1 E4 E2 [SEP] C1 E0 E0 E1 [SEP] C1 E1 E4 E2 [SEP] C1 E1 E0 E0 [SEP] C1 E2 E3 E2 [SEP] C1 E0 E1 E0 [SEP] C1 E2 E3 E2 [SEP] C1 E0 E0 E1 [SEP] C1 E2 E3 

### Analysis

In [10]:
# Text form of every generated sequence in the batch.
predictions = mpp.batch_decode(generated, skip_special_tokens=True)

# Work on a copy: the original code aliased batch['labels'] and overwrote its
# -100 entries in place, which silently altered the batch for any cell that
# reuses it afterwards.
# NOTE(review): assumes batch['labels'] is a torch tensor (it went through
# to_cuda) — confirm.
labels = batch['labels'].clone()

# -100 is not a valid token id for the tokenizer; map it to the pad token so
# decoding can drop it together with the other special tokens.
labels[labels == -100] = tokenizer.pad_token_id
labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

In [14]:
from itertools import zip_longest

# Sample of the current batch to inspect.
idx = 0

# A decoded sequence is split on '[SEP]'; even positions are treated as
# directions and odd positions as the matching leading terms.
preds = predictions[idx].split('[SEP]')
pred_directions, pred_leading_terms = preds[::2], preds[1::2]

gts = labels[idx].split('[SEP]')
gt_directions, gt_leading_terms = gts[::2], gts[1::2]


print(f' # of predicted expansions: {len(pred_directions)}')
print(f' # of ground truth expansions: {len(gt_directions)}')

print('-'*100)
print(f'{"Prediction":<25} | {"Ground Truth":<25} | {"Correct":<5}')
print('-'*100)
print(f'{"Direction":<11} | {"Leading Trm":<10} | {"Direction":<11} | {"Leading Trm":<10} | {"Direc.":<5} | {"LT":<5} | {"Both":<5}')
print('-'*100)

# zip_longest (instead of zip) keeps rows that exist on one side only, so
# surplus or missing expansions are listed as mismatches rather than being
# dropped from the table without notice.
rows = zip_longest(pred_directions, pred_leading_terms, gt_directions, gt_leading_terms)
for pred_direction, pred_leading_term, gt_direction, gt_leading_term in rows:

    # None marks an entry absent on that side; it never counts as correct.
    # Surrounding whitespace left over from the '[SEP]' split is ignored.
    direction_correct = (
        pred_direction is not None
        and gt_direction is not None
        and pred_direction.strip() == gt_direction.strip()
    )
    lt_correct = (
        pred_leading_term is not None
        and gt_leading_term is not None
        and pred_leading_term.strip() == gt_leading_term.strip()
    )
    both_correct = direction_correct and lt_correct

    # Trim each cell so the columns line up with the header.
    shown = [
        '' if text is None else text.strip()
        for text in (pred_direction, pred_leading_term, gt_direction, gt_leading_term)
    ]
    print(f'{shown[0]:<11} | {shown[1]:<11} | {shown[2]:<11} | {shown[3]:<11} | {str(direction_correct):<6} | {str(lt_correct):<6} | {str(both_correct):<6}')

 # of predicted expansions: 41
 # of ground truth expansions: 24
----------------------------------------------------------------------------------------------------
Prediction                | Ground Truth              | Correct
----------------------------------------------------------------------------------------------------
Direction   | Leading Trm | Direction   | Leading Trm | Direc. | LT    | Both 
----------------------------------------------------------------------------------------------------
C1 E1 E0 E0  |  C1 E1 E2 E4  | C1 E1 E0 E0  |  C1 E1 E2 E4  | True   | True   | True  
 C1 E1 E0 E0  |  C1 E2 E1 E4  |  C1 E1 E0 E0  |  C1 E2 E1 E4  | True   | True   | True  
 C1 E0 E0 E1  |  C1 E2 E1 E4  |  C1 E0 E0 E1  |  C1 E2 E1 E4  | True   | True   | True  
 C1 E1 E0 E0  |  C1 E3 E0 E4  |  C1 E1 E0 E0  |  C1 E3 E0 E4  | True   | True   | True  
 C1 E0 E1 E0  |  C1 E3 E0 E4  |  C1 E0 E1 E0  |  C1 E3 E0 E4  | True   | True   | True  
 C1 E0 E0 E1  |  C1 E3 E0 E4  |  C1 E0 E0 E1  

In [15]:
# Ground-truth text that encodes "no expansion".
no_expansion_text = 'C1 E0 E0 E0 [SEP] C1 E0 E0 E0'

# Flag, then collect, the samples whose ground truth is exactly that text.
is_no_expansion = [label == no_expansion_text for label in labels]
sample_ids = [i for i, is_no_exp in enumerate(is_no_expansion) if is_no_exp]

hits = 0
for sample_id in sample_ids:
    preds = predictions[sample_id].split('[SEP]')
    # The direction / leading-term lists are not used for the accuracy itself,
    # but the polynomial-conversion cell further down reads whatever values
    # they were last given, so they are kept.
    pred_directions, pred_leading_terms = preds[::2], preds[1::2]

    gts = labels[sample_id].split('[SEP]')
    gt_directions, gt_leading_terms = gts[::2], gts[1::2]

    # A hit requires the whole split sequence to match.
    hits += (preds == gts)

# Guard the ratio: a batch without any "no expansion" sample used to raise
# ZeroDivisionError here.
if sample_ids:
    print(f'no expansion acc = {hits / len(sample_ids)} [{hits} / {len(sample_ids)}]')
else:
    print('no expansion acc = n/a [0 / 0]')

no expansion acc = 0.95 [19 / 20]


In [16]:
from src.dataset.processors.utils import sequence_to_poly
from sage.all import PolynomialRing, GF

# Polynomial ring matching the checkpoint configuration: the field order is
# parsed from the text after the first two characters of config.field, with
# variables named x... and degrevlex ordering.
field = GF(int(config.field[2:]))
ring = PolynomialRing(field, 'x', config.num_variables, order='degrevlex')

# Predicted (direction, leading term) pairs, converted with sequence_to_poly.
pred_direction_set = [sequence_to_poly(text, ring) for text in pred_directions]
pred_lt_set = [sequence_to_poly(text, ring) for text in pred_leading_terms]
pred_expansion_set = list(zip(pred_direction_set, pred_lt_set))

# Ground-truth pairs, converted the same way.
gt_direction_set = [sequence_to_poly(text, ring) for text in gt_directions]
gt_lt_set = [sequence_to_poly(text, ring) for text in gt_leading_terms]
gt_expansion_set = list(zip(gt_direction_set, gt_lt_set))

In [18]:


# Predicted pairs that also occur in the ground truth / that do not.
true_positive = sum(p in gt_expansion_set for p in pred_expansion_set)
false_positive = len(pred_expansion_set) - true_positive

# Ground-truth pairs the model failed to produce. Counting them directly
# (instead of len(gt) - true_positive) stays correct, and non-negative, when
# the prediction repeats a correct pair several times.
false_negative = sum(g not in pred_expansion_set for g in gt_expansion_set)

print(f'true positive = {true_positive}')
print(f'false positive = {false_positive}')
print(f'false negative = {false_negative}')

true positive = 1
false positive = 0
false negative = 0


In [None]:
from torch.utils.data import DataLoader
from src.misc.utils import to_cuda
from tqdm import tqdm
def _mean(values):
    """Average of `values`, or NaN when the list is empty (e.g. an empty loader)."""
    return sum(values) / len(values) if values else float('nan')


over_estimated = []  # per sample: (# predicted expansions) - (# ground-truth expansions)
all_hits = []        # per sample: True when the whole sequence matches
hits_list = []       # per compared (direction, leading term) pair: True when both match
for batch in tqdm(data_loader):
    batch = to_cuda(batch)

    # Allow one position more than the (padded) label length of this batch.
    max_length = batch['labels'].shape[-1] + 1
    # NOTE(review): re-created for every batch as in the original; it could be
    # hoisted above the loop if the processor keeps no per-call state — confirm.
    mpp = MonomialProcessorPlus(num_variables=config.num_variables, max_degree=config.max_degree, max_coef=int(config.field[2:]))
    generated = model.generate(batch['input_ids'], batch['attention_mask'],
                               monomial_processor=mpp, tokenizer=tokenizer,
                               max_length=max_length)

    predictions = mpp.batch_decode(generated, skip_special_tokens=True)

    # Decode from a copy so the batch tensor is not modified in place; -100 is
    # mapped to the pad token id so the tokenizer can decode (and skip) it.
    # NOTE(review): assumes batch['labels'] is a torch tensor — confirm.
    labels = batch['labels'].clone()
    labels[labels == -100] = tokenizer.pad_token_id
    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    for prediction, label in zip(predictions, labels):
        # Even positions are treated as directions, odd ones as leading terms.
        preds = prediction.split('[SEP]')
        pred_directions, pred_leading_terms = preds[::2], preds[1::2]

        gts = label.split('[SEP]')
        gt_directions, gt_leading_terms = gts[::2], gts[1::2]

        over_estimated.append(len(pred_directions) - len(gt_directions))

        hits = [
            pred_direction == gt_direction and pred_leading_term == gt_leading_term
            for pred_direction, pred_leading_term, gt_direction, gt_leading_term
            in zip(pred_directions, pred_leading_terms, gt_directions, gt_leading_terms)
        ]
        # zip stops at the shorter side, so all(hits) alone would also accept a
        # prediction with surplus or missing expansions; a full match therefore
        # additionally requires the same number of segments.
        all_hits.append(len(preds) == len(gts) and all(hits))
        hits_list.extend(hits)


print(f'over estimated = {_mean(over_estimated)}')
print(f'true positive = {_mean(hits_list)}')
print(f'all hits = {_mean(all_hits)}')


  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
# Display `hits` as last assigned; once the evaluation loop has run, these are the per-pair match flags of the final sample it processed.
hits