In [None]:
import json
import os
import re
from itertools import chain

import torch
from datasets import load_dataset
from transformers import AutoTokenizer

In [None]:
# Hugging Face checkpoint whose tokenizer is loaded below.
checkpoint = "Salesforce/codet5p-220m-bimodal"

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Placeholders to fill in before running: the input CSV, the output directory,
# and the output file name without extension.
path_dataset, path_to_save, filename = '', '', ''

In [None]:
# Load the tokenizer that belongs to `checkpoint`; the checkpoint is loaded with
# remote code execution enabled.
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    trust_remote_code=True,
)

In [None]:
# Read the CSV at `path_dataset`; the rows are taken from the "train" split.
dataset = load_dataset('csv', data_files=path_dataset, split='train')

In [None]:
# Rename the raw CSV columns to the role each one plays in the steps below:
# `code` becomes the encoder input and `name` becomes the generation target.
dataset = dataset.rename_columns(
    column_mapping=dict(code='encoder_input_text', name='target_text')
)

In [None]:
# remove the method name from code
def remove(example):
    """Mask the method name inside the method's source code.

    The first occurrence of ``example['target_text']`` in
    ``example['encoder_input_text']`` is replaced by ``tokenizer.sep_token``.
    An occurrence that stands alone as an identifier is preferred, so a name
    such as ``get`` does not clobber part of ``Target`` and leave the real
    name visible. If no standalone occurrence exists, the first plain
    substring occurrence is replaced instead.

    Args:
        example: A single row with string fields ``encoder_input_text``
            and ``target_text``.

    Returns:
        The same ``example`` object with ``encoder_input_text`` updated.
    """
    code = example['encoder_input_text']
    name = example['target_text']
    sep = tokenizer.sep_token

    masked, n_hits = code, 0
    if name:
        # Lookarounds rather than \b, so a name that starts or ends with a
        # non-word character can still match.
        whole_name = r'(?<!\w)' + re.escape(name) + r'(?!\w)'
        # A callable replacement inserts the token verbatim (no escape processing).
        masked, n_hits = re.subn(whole_name, lambda _match: sep, code, count=1)
    if n_hits == 0:
        # No standalone occurrence: fall back to the first substring match.
        masked = code.replace(name, sep, 1)

    example['encoder_input_text'] = masked
    return example

In [None]:
# Mask the method name in every row's code.
dataset = dataset.map(function=remove)

In [None]:
# add the special token [TDEC] for code-to-text generation
def add_tdec(example):
    """Attach the fixed decoder prompt to a row.

    Stores the prompt, which begins with the ``[TDEC]`` token, under
    ``decoder_input_text`` and returns the same row object.
    """
    decoder_prompt = '[TDEC] The name of the method is: '
    example['decoder_input_text'] = decoder_prompt
    return example

In [None]:
# Give every row the same decoder prompt.
dataset = dataset.map(function=add_tdec)

In [None]:
def concat_texts(example):
    """Collect a row's three texts into one list under ``stacked_text``.

    The list order is fixed as ``[encoder_input_text, target_text,
    decoder_input_text]``. The tokenization step reads slot 0 as the encoder
    input, slot 1 as the labels and slot 2 as the decoder input, so the
    columns are picked by name rather than by the row's column order. This
    keeps the result correct when the CSV lists ``name`` before ``code`` or
    carries extra columns.

    Args:
        example: A single row containing the three text columns.

    Returns:
        The same ``example`` object with ``stacked_text`` added.
    """
    stacked_columns = ('encoder_input_text', 'target_text', 'decoder_input_text')
    example['stacked_text'] = [example[column] for column in stacked_columns]
    return example

In [None]:
# Replace the three separate text columns with the single `stacked_text` column.
dataset = dataset.map(
    concat_texts,
    remove_columns=['encoder_input_text', 'decoder_input_text', 'target_text'],
)

In [None]:
# Rows tokenized (and padded to a common width) together, and the token cap
# applied to every text during truncation.
batch_size, max_length = 8, 256

In [None]:
def tokenize_batch_stacked_text(batch):
    """Tokenize one batch of stacked texts into encoder/decoder model inputs.

    Every entry of ``batch['stacked_text']`` is expected to hold exactly three
    texts: slot 0 is the encoder input, slot 1 the target, slot 2 the decoder
    prompt. All texts of the batch are tokenized in a single call, so they are
    truncated to ``max_length`` and padded to one shared width.

    Args:
        batch: A batched mapping with a ``stacked_text`` column (a list of
            three-element lists of strings).

    Returns:
        A dict of plain Python lists with the keys ``input_ids``,
        ``attention_mask``, ``decoder_input_ids``, ``decoder_attention_mask``
        and ``labels``. Padding positions in ``labels`` are set to -100
        (the default ``ignore_index`` of PyTorch's cross-entropy loss).
    """
    texts_per_example = 3

    # Flatten [[enc, tgt, dec], ...] into [enc, tgt, dec, enc, tgt, dec, ...].
    flat_texts = list(chain.from_iterable(batch['stacked_text']))

    encoded = tokenizer(
        flat_texts,
        truncation=True,
        padding='longest',
        max_length=max_length,
        return_tensors='pt',
    )

    # (n_examples * 3, seq_len) -> (n_examples, 3, seq_len)
    seq_len = encoded['input_ids'].shape[1]
    input_ids = encoded['input_ids'].reshape(-1, texts_per_example, seq_len)
    attention_mask = encoded['attention_mask'].reshape(-1, texts_per_example, seq_len)

    # Clone so the in-place masking below cannot write through to `input_ids`.
    labels = input_ids[:, 1, :].clone()
    # Use the tokenizer's own pad id instead of assuming it is 0.
    labels[labels == tokenizer.pad_token_id] = -100

    return {
        "input_ids": input_ids[:, 0, :].tolist(),
        "attention_mask": attention_mask[:, 0, :].tolist(),
        "decoder_input_ids": input_ids[:, 2, :].tolist(),
        "decoder_attention_mask": attention_mask[:, 2, :].tolist(),
        "labels": labels.tolist(),
    }

In [None]:
# Shuffle with a fixed seed, then tokenize in groups of `batch_size` rows so
# that each group shares one padded width; the trailing partial group is dropped.
dataset = (
    dataset
    .shuffle(seed=42)
    .map(
        tokenize_batch_stacked_text,
        batched=True,
        batch_size=batch_size,
        drop_last_batch=True,
        remove_columns=['stacked_text'],
    )
    # `dataset` below is still the pre-shuffle dataset, because the whole
    # right-hand side is evaluated before the name is rebound.
    # NOTE(review): with drop_last_batch=True this select looks redundant — confirm before removing.
    .select(range(len(dataset) - len(dataset) % batch_size))
)

In [None]:
# Write the tokenized rows to `<path_to_save>/<filename>.jsonl`.
dataset.to_json(os.path.join(path_to_save, f'{filename}.jsonl'))

In [None]:
# Store the batching parameters next to the data, in `<filename>.json`.
with open(os.path.join(path_to_save, f'{filename}.json'), mode='w') as f:
    json.dump(obj={'batch_size': batch_size, 'max_length': max_length}, fp=f)