In [2]:
from pathlib import Path

import pandas as pd
import regex as re
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

# --- configuration ---------------------------------------------------------
checkpoint = "Salesforce/codet5p-220m-bimodal"
path_to_model = None  # set to a local model directory to evaluate it instead of `checkpoint`
path_to_save = r"evaluations/val_results.csv"
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- tokenizer & model -----------------------------------------------------
# The tokenizer always comes from the base checkpoint; the model weights come from
# path_to_model when it is set (truthy), otherwise from the same base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model_source = path_to_model or checkpoint
model = AutoModel.from_pretrained(model_source, trust_remote_code=True).to(device)

In [3]:
def split_camel_case(method_name):
    """Split a camelCase / PascalCase identifier into space-separated words.

    A boundary is placed between a lowercase and an uppercase letter
    (``getValue`` -> ``get Value``) and before the last capital of an acronym
    that is followed by a lowercase letter (``HTTPResponse`` -> ``HTTP Response``).
    Letter case is preserved; digits and underscores are not treated as boundaries.

    Args:
        method_name: The identifier to split.

    Returns:
        The identifier's words joined by single spaces.
    """
    boundary = re.compile(r'(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])')
    return " ".join(boundary.split(method_name))

In [10]:
class IntellijCodeDataset(Dataset):
    """Method-name prediction dataset read from a CSV of methods.

    Each item is the method source with its own declared name replaced by the
    tokenizer's ``sep_token`` (the model input), paired with the method name split
    into words by ``split_camel_case`` (the target).

    Args:
        path_to_data: Path to a CSV with ``method`` (source text) and
            ``method_name`` columns.
        tokenizer: Tokenizer used to encode the input, the target and the decoder prompt.
        max_length: Length every encoded sequence is truncated / padded to.
        max_samples: Keep only the first ``max_samples`` rows of the CSV; ``None``
            keeps every row. Defaults to 1000 (the previous hard-coded limit).
    """

    def __init__(self, path_to_data, tokenizer, max_length=512, max_samples=1000):
        self.path_to_data = path_to_data
        self.tokenizer = tokenizer
        self.max_length = max_length

        # iloc[:None] selects all rows, so max_samples=None means "no limit".
        df = pd.read_csv(path_to_data).iloc[:max_samples]
        self.methods = df["method"].tolist()
        self.method_names = df["method_name"].tolist()

    def __len__(self):
        return len(self.method_names)

    @staticmethod
    def _mask_method_name(method, method_name, mask_token):
        """Replace the declared ``method_name`` in ``method`` with ``mask_token``.

        Only a whole identifier immediately followed by ``(`` (optionally after
        whitespace) is replaced, so the name is not masked inside a longer
        identifier (``get`` inside ``Target``) or inside a string literal.
        If no such occurrence exists, falls back to replacing the first plain
        substring occurrence.
        """
        pattern = r"(?<![\w$])" + re.escape(method_name) + r"(?=\s*\()"
        # A callable replacement keeps mask_token literal (no backslash/group expansion).
        masked, n_subs = re.subn(pattern, lambda _: mask_token, method, count=1)
        if n_subs:
            return masked
        return method.replace(method_name, mask_token, 1)

    def __getitem__(self, index):
        method = self.methods[index]
        method_name = self.method_names[index]

        # Encode the method with its own name hidden behind the sep token.
        masked_method = self._mask_method_name(method, method_name, self.tokenizer.sep_token)
        model_inputs = self.tokenizer(masked_method, return_tensors="pt", max_length=self.max_length,
                                      truncation=True, padding="max_length")

        # Target: the method name split into words. Pad positions are set to -100,
        # the conventional "ignore" label index for HF/PyTorch cross-entropy.
        labels = self.tokenizer(split_camel_case(method_name), return_tensors="pt", max_length=self.max_length,
                                truncation=True, padding="max_length")["input_ids"]
        labels[labels == self.tokenizer.pad_token_id] = -100
        model_inputs["labels"] = labels

        # Decoder inputs: the "[TDEC]" prompt, padded to max_length.
        # NOTE(review): these decoder_input_ids are not derived from `labels`, so a
        # forward pass with both is presumably not teacher-forced on the target —
        # confirm this is intended (model.generate may be what is wanted for eval).
        decoder_inputs = self.tokenizer("[TDEC]", return_tensors="pt", max_length=self.max_length,
                                        truncation=True, padding="max_length")
        model_inputs["decoder_input_ids"] = decoder_inputs["input_ids"]
        model_inputs["decoder_attention_mask"] = decoder_inputs["attention_mask"]

        return model_inputs

In [11]:
# Validation split. Evaluation needs no shuffling; keeping file order makes the saved
# predictions reproducible across runs and easy to line up with ../data/val.csv.
dataset = IntellijCodeDataset("../data/val.csv", tokenizer)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

In [12]:
# Run the model over the validation loader, collecting per-example
# (input_ids, labels, argmax-prediction ids) as plain Python lists.
preds = []
model.eval()

# Fields produced by IntellijCodeDataset.__getitem__ that are fed to the model.
model_input_keys = ("input_ids", "attention_mask", "labels",
                    "decoder_input_ids", "decoder_attention_mask")

with torch.no_grad():
    progress = tqdm(data_loader)
    for batch in progress:
        # Items were tokenized with return_tensors="pt", so collated tensors carry an
        # extra singleton dim at index 1: drop it and move everything to the device.
        tensors = {key: batch[key].squeeze(1).to(device) for key in model_input_keys}
        outputs = model(**tensors)
        predicted_ids = outputs.logits.argmax(-1)
        preds.extend(zip(tensors["input_ids"].tolist(),
                         tensors["labels"].tolist(),
                         predicted_ids.tolist()))
        progress.set_description(f"Loss: {outputs.loss.item():.4f}")

  0%|          | 0/1000 [00:11<?, ?it/s]

KeyboardInterrupt



In [36]:
def decode_predictions(preds, tok=None):
    """Decode ``(input_ids, labels, pred_ids)`` triples into text.

    Args:
        preds: Iterable of ``(input_ids, labels, pred_ids)`` triples of token-id
            lists, as collected by the inference loop. ``labels`` may contain -100
            at positions that were excluded from the loss.
        tok: Tokenizer providing ``decode`` and ``pad_token_id``. Defaults to the
            notebook-level ``tokenizer`` so existing calls keep working.

    Returns:
        A list of ``(input_code, label_text, prediction_text)`` string triples,
        each decoded with special tokens skipped.
    """
    if tok is None:
        tok = tokenizer
    pad_id = tok.pad_token_id

    decoded_preds = []
    for input_ids, labels, pred_ids in preds:
        # -100 is not a real token id; map it back to padding so decode() skips it.
        label_ids = [pad_id if token_id == -100 else token_id for token_id in labels]
        decoded_preds.append((tok.decode(input_ids, skip_special_tokens=True),
                              tok.decode(label_ids, skip_special_tokens=True),
                              tok.decode(pred_ids, skip_special_tokens=True)))
    return decoded_preds
# One row per validation example: masked input code, reference name, predicted name.
prediction_rows = decode_predictions(preds)
df = pd.DataFrame(prediction_rows, columns=["input_code", "labels", "predictions"])

In [37]:
# DataFrame.to_csv does not create missing directories, so ensure the parent of
# path_to_save (e.g. "evaluations/") exists before writing.
Path(path_to_save).parent.mkdir(parents=True, exist_ok=True)
df.to_csv(path_to_save, index=False)