# Reproducing the codebase through the Python API

This notebook runs the trained BART model through the Python API rather than through the CLI.

In [None]:
import os
import sys
import argparse
import logging
import re
import typing as ty

from tqdm import tqdm
from warnings import warn
from torch.multiprocessing import Pool, set_start_method
set_start_method('spawn', force=True)
from functools import partial
import more_itertools as mit

import torch
import fairseq
from fairseq.models.bart import BARTHubInterface
from fairseq.models.bart import BARTModel

import nvgpu

from pathlib import Path

In [None]:
import logzero

from datetime import datetime
from pathlib import Path

# Timestamp taken once when this cell runs; it names this run's log file.
_datetime_exec = datetime.now()

# Make sure the target directory exists before attaching the file handler;
# opening the log file fails when "logs/" is missing.
Path("logs").mkdir(parents=True, exist_ok=True)
logzero.logfile(f"logs/{_datetime_exec.isoformat()}.log")

# Logger object used by the following cells.
logger = logzero.logger

In [None]:
def load_model(task: Path, model_path: Path) -> BARTHubInterface:
    """Load a BART checkpoint through ``BARTModel.from_pretrained``.

    Args:
        task: a path to the directory of the model. It is forwarded as
            ``data_name_or_path``.
        model_path: a path to 'model.pt' file. Its parent directory is
            forwarded as the model directory and its file name as
            ``checkpoint_file``.

    Returns:
        The object returned by ``BARTModel.from_pretrained``.

    Raises:
        FileNotFoundError: if ``task`` or ``model_path`` does not exist.
    """
    # Explicit checks instead of ``assert`` so validation still runs under
    # ``python -O`` and the error names the missing path.
    if not task.exists():
        raise FileNotFoundError(f"Task directory not found: {task}")
    if not model_path.exists():
        raise FileNotFoundError(f"Model checkpoint not found: {model_path}")

    logger.info(f"Loading model {model_path}")
    # ``Path.parent`` gives "." for a bare file name, whereas splitting the
    # string form would give an empty directory name.
    bart = BARTModel.from_pretrained(
        model_path.parent.as_posix(),
        checkpoint_file=model_path.name,
        data_name_or_path=task.as_posix()
    )
    logger.info("Loading done.")
    return bart


In [None]:
# path to input
PATH_TEXT_FILE_INPUT = Path("/workdir/kmitsuzawa/Project/neurips-2025/ConstraintsFact-Dreyer-2023/abstractive-factual-tradeoff/tests/testresources/xsum/test_source.txt")
assert PATH_TEXT_FILE_INPUT.exists()

# Read all lines (readlines() keeps each line's trailing newline). The context
# manager closes the file handle once the lines are in memory.
with PATH_TEXT_FILE_INPUT.open() as _f_input:
    seq_text_input = _f_input.readlines()
assert len(seq_text_input) > 0

In [None]:
# Model directory to load. The 'bart.large.cnn' directory is the active one;
# the commented-out assignment points at the 'bart.large.xsum' directory instead.
# PATH_MODEL_FILE = Path('/workdir/kmitsuzawa/Project/neurips-2025/ConstraintsFact-Dreyer-2023/abstractive-factual-tradeoff/tests/testresources/models/bart.large.xsum')
PATH_MODEL_FILE = Path('/workdir/kmitsuzawa/Project/neurips-2025/ConstraintsFact-Dreyer-2023/abstractive-factual-tradeoff/tests/testresources/models/bart.large.cnn')

# The same directory is used both as the task directory and as the location of 'model.pt'.
bart_model = load_model(
    task=PATH_MODEL_FILE,
    model_path=PATH_MODEL_FILE.joinpath('model.pt'),
)

In [None]:
# Notebook display: show the class of the object returned by load_model.
type(bart_model)

In [None]:
# Write the string form of the loaded model to the log.
logger.info("%s", bart_model)

In [None]:
# Use the first CUDA device when CUDA is available; otherwise stay on the CPU.
device_obj = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Move the model to the selected device.
bart_model = bart_model.to(device_obj)

In [None]:
def bart_sample(bart: BARTHubInterface,
                batch: ty.List[str],
                extractive_penalty_fct: str,
                beam: int = 4,
                lenpen: float = 2.0,  # length penalty
                min_len: int = 55,
                max_len_a: int = 0,
                max_len_b: int = 140,
                no_repeat_ngram_size: int = 3):
    """Call ``bart.sample`` on ``batch`` with gradient tracking disabled.

    Args:
        bart: object whose ``sample`` method is invoked.
        batch: input texts, handed to ``bart.sample`` unchanged as the
            first positional argument.
        extractive_penalty_fct: forwarded as the keyword argument of the
            same name.
        beam: forwarded as ``beam``.
        lenpen: length penalty, forwarded as ``lenpen``.
        min_len: forwarded as ``min_len``.
        max_len_a: forwarded as ``max_len_a``.
        max_len_b: forwarded as ``max_len_b``.
        no_repeat_ngram_size: forwarded as ``no_repeat_ngram_size``.

    Returns:
        Whatever ``bart.sample`` returns.
    """
    # Gather every generation option once so the call below stays short.
    generation_options = dict(
        beam=beam,
        lenpen=lenpen,
        min_len=min_len,
        max_len_a=max_len_a,
        max_len_b=max_len_b,
        no_repeat_ngram_size=no_repeat_ngram_size,
        extractive_penalty_fct=extractive_penalty_fct,
    )
    with torch.no_grad():
        sampled = bart.sample(batch, **generation_options)
    return sampled
# end def


# Run sampling over all input lines in a single call, keeping the default
# generation options and passing the extractive-penalty specification string.
res = bart_sample(
    bart_model,
    seq_text_input,
    extractive_penalty_fct='log_exp(2,2.402244)',
)

In [None]:
# Notebook display: show the input lines that were passed to bart_sample.
seq_text_input

In [None]:
# Notebook display: show the value returned by bart_sample.
res