In [1]:
from abc import ABC, abstractmethod
from pathlib import Path

import pandas as pd
from codetf.models import load_model_pipeline
from datasets import load_dataset
from sacrebleu import corpus_bleu, corpus_chrf, corpus_ter
from transformers import AutoModelWithLMHead, AutoTokenizer, SummarizationPipeline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Working directory of the notebook kernel; the data paths used below are built from it.
root_dir: Path = Path.cwd()

In [24]:
class AbstractModel(ABC):
    """Common interface for the code-summarization models compared in this notebook.

    Both methods are declared with ``abstractmethod``. Python therefore refuses
    to instantiate a subclass that forgets one of them, instead of the gap
    surfacing later as ``NotImplementedError`` in the middle of a prediction run.
    """

    @abstractmethod
    def predict(self, code: str) -> str:
        """Return a natural-language summary for a single code snippet."""
        raise NotImplementedError()

    @abstractmethod
    def model_name(self) -> str:
        """Return an identifier for this model (``get_preds`` uses it as the CSV file name)."""
        raise NotImplementedError()
    
class CodeTFModel(AbstractModel):
    """Summarizer backed by a CodeTF pipeline from ``codetf.models.load_model_pipeline``.

    The pipeline is loaded lazily on the first ``predict`` call. Constructing an
    instance only to call ``model_name()`` (e.g. to locate a cached predictions
    file) therefore does not load the model.

    NOTE(review): a run of this notebook warned that an input was longer than
    the model's 512-token maximum; confirm whether the CodeTF pipeline truncates.
    """

    def __init__(self, model_name: str, model_type: str, task: str) -> None:
        super().__init__()

        self._model_name = model_name
        self._model_type = model_type
        self._task = task
        # Populated by _get_model() on first use.
        self._model = None

    def _get_model(self):
        """Load the CodeTF pipeline on the first call and reuse it afterwards."""
        if self._model is None:
            self._model = load_model_pipeline(
                model_name=self._model_name, model_type=self._model_type, task=self._task
            )
        return self._model

    def predict(self, code: str) -> str:
        """Summarize one snippet: the pipeline's ``predict`` takes a list, so wrap and unwrap."""
        return self._get_model().predict([code])[0]

    def model_name(self) -> str:
        """Return ``"{model_name}-{model_type}-{task}"`` built from the constructor arguments."""
        return f"{self._model_name}-{self._model_type}-{self._task}"
    
class SebisModel(AbstractModel):
    """Summarizer backed by a Hugging Face ``SummarizationPipeline``.

    Presumably used with SEBIS CodeTrans checkpoints (per the class name) — confirm.

    Args:
        model_name: Model id or path passed to both ``from_pretrained`` calls.
        device: Pipeline device index. The default ``0`` keeps the previously
            hard-coded first GPU; pass ``-1`` to run on CPU.

    The pipeline is built lazily on the first ``predict`` call, so
    ``model_name()`` can be used without loading the weights.
    """

    def __init__(self, model_name: str, device: int = 0) -> None:
        super().__init__()

        self._model_name = model_name
        self._device = device
        # Populated by _get_pipeline() on first use.
        self._pipeline = None

    def _get_pipeline(self):
        """Build the summarization pipeline on the first call and reuse it afterwards."""
        if self._pipeline is None:
            self._pipeline = SummarizationPipeline(
                model=AutoModelWithLMHead.from_pretrained(self._model_name),
                tokenizer=AutoTokenizer.from_pretrained(self._model_name, skip_special_tokens=True),
                device=self._device,
            )
        return self._pipeline

    def predict(self, code: str) -> str:
        """Summarize one snippet and return the pipeline's ``summary_text`` for it."""
        return self._get_pipeline()([code])[0]["summary_text"]

    def model_name(self) -> str:
        """Return the model id exactly as given to the constructor."""
        return self._model_name

In [35]:
def get_preds(df: pd.DataFrame, model: "AbstractModel", out_dir: "Path | None" = None) -> pd.DataFrame:
    """Predict a summary for every ``code`` value in ``df`` and cache the result as CSV.

    Args:
        df: Frame with ``code`` (model input) and ``ref`` (reference summary)
            columns. It is not modified.
        model: Model whose ``predict`` is called once per ``code`` value and
            whose ``model_name()`` names the cache file.
        out_dir: Directory holding the cache files. Defaults to
            ``root_dir / "data" / "preds"``.

    Returns:
        A frame with ``ref`` and ``pred`` columns. If
        ``<out_dir>/<model.model_name()>.csv`` already exists, it is read back
        instead of running the model again.
    """
    if out_dir is None:
        out_dir = root_dir / "data" / "preds"
    file_path = Path(out_dir) / f"{model.model_name()}.csv"

    if file_path.exists():
        # keep_default_na=False: an empty or "NA"-like prediction stays a
        # string instead of being parsed as a float NaN.
        return pd.read_csv(file_path, index_col=0, keep_default_na=False)

    preds = df[["ref"]].assign(pred=df["code"].map(model.predict))
    # The preds directory may not exist yet, and model ids such as "org/name"
    # put the file in a nested sub-directory; to_csv does not create either.
    file_path.parent.mkdir(parents=True, exist_ok=True)
    preds.to_csv(file_path)
    return preds

In [15]:
# Load the test split from the JSONL file one directory above the notebook,
# keeping the datasets cache next to it.
dataset = load_dataset(
    "json",
    data_files={"test": str(root_dir.parent / "data" / "test.jsonl")},
    cache_dir=root_dir.parent / "data" / "cache",
)

Found cached dataset json (/home/paul/projects/edu/master/mdl-ii/src/data/cache/json/default-acdd91729f392843/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
100%|██████████| 1/1 [00:00<00:00, 19.26it/s]


In [16]:
def inference(doc):
    """Batched ``map`` callback that adds a ``ref`` column to ``doc`` and returns it.

    ``ref`` is each example's ``docstring_tokens`` joined with single spaces,
    which is why the references look tokenized (e.g. "str - > list").
    Despite the function's name, no model is run here.
    """
    doc["ref"] = list(map(" ".join, doc["docstring_tokens"]))
    return doc

# Attach the reference summaries, then view only the "ref"/"code" fields as pandas.
dataset = dataset.map(function=inference, batched=True)
dataset.set_format("pandas", columns=["ref", "code"])
df = dataset["test"][:]
df.head()

Loading cached processed dataset at /home/paul/projects/edu/master/mdl-ii/src/data/cache/json/default-acdd91729f392843/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-e06db1d51f6ed421.arrow


Unnamed: 0,code,ref
0,def sina_xml_to_url_list(xml_data):\n rawur...,str - > list Convert XML to URL List . From Bi...
1,"def dailymotion_download(url, output_dir='.', ...",Downloads Dailymotion videos by URL .
2,"def sina_download(url, output_dir='.', merge=T...",Downloads Sina videos by URL .
3,"def sprint(text, *colors):\n return ""\33[{}...",Format text with color or other effects into A...
4,"def print_log(text, *colors):\n sys.stderr....",Print a log message to standard error .


In [37]:
# Generate (or reuse the cached) CodeT5 multi-language summarization predictions.
get_preds(df=df, model=CodeTFModel(model_name="codet5", model_type="base-multi-sum", task="pretrained"));

Token indices sequence length is longer than the specified maximum sequence length for this model (645 > 512). Running this sequence through the model will result in indexing errors


In [34]:
# Inspect the cached CodeT5 predictions. The file stem mirrors
# CodeTFModel.model_name() ("{model_name}-{model_type}-{task}"); spelling it out
# avoids constructing a CodeTFModel just for its name, which may load the model.
pd.read_csv(root_dir / "data" / "preds" / "codet5-base-multi-sum-pretrained.csv", index_col=0)

Unnamed: 0,ref,pred
0,str - > list Convert XML to URL List . From Bi...,Convert Sina XML to a list of URL strings.
1,Downloads Dailymotion videos by URL .,Download a dailymotion web page.
2,Downloads Sina videos by URL .,Download a Sina video.
3,Format text with color or other effects into A...,Print text with ANSI escape sequences.
4,Print a log message to standard error .,Print a log message to stderr.
5,Print an error log message .,Print a message to the console and exit with a...
6,What a Terrible Failure!,Print a message to the console and exit with a...
7,Detect operating system .,Detect the operating system.
8,str - > None,Download a vimeo file from a channel.
9,str - > dict Information for CKPlayer API cont...,Get ckplayer info by XML.
