In [None]:
import json

# Load the first JSONL record as the single example used in the tokenization
# walk-through below.
with open("train_data.jsonl", "r", encoding="utf-8") as fh:
    first_line = fh.readline()
ex = json.loads(first_line)

In [None]:
from transformers import AutoTokenizer
# Fast tokenizer for the CodeT5+ 220M checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    "Salesforce/codet5p-220m",
    use_fast=True,
)

In [None]:
# Truncation budgets in tokens: prompts (encoder input) and completions (labels).
max_input_tok, max_target_tok = 512, 256

In [None]:
# Tokenize the single example: prompt -> encoder inputs, completion -> label ids.
# No padding is requested, so both encodings keep their natural lengths.
enc_inputs = tokenizer(ex["prompt"], max_length=max_input_tok, truncation=True)
enc_target = tokenizer(ex["completion"], max_length=max_target_tok, truncation=True)

In [None]:
import torch

# Convert the token-id lists into 1-D tensors.
input_ids, attention_mask = (
    torch.tensor(enc_inputs[field]) for field in ("input_ids", "attention_mask")
)
labels = torch.tensor(enc_target["input_ids"])

In [None]:
# Positions holding the pad id become -100 so the loss ignores them. enc_target
# was tokenized without padding, so this is presumably a no-op here; the trailing
# semicolon suppresses display of the returned tensor.
labels.masked_fill_(labels.eq(tokenizer.pad_token_id), -100);

In [None]:
# Add a leading batch dimension of size 1: each 1-D tensor becomes (1, seq_len).
batch = {
    name: tensor.unsqueeze(0)
    for name, tensor in (
        ("input_ids", input_ids),
        ("attention_mask", attention_mask),
        ("labels", labels),
    )
}

In [None]:
from transformers import AutoModelForSeq2SeqLM

# Pretrained seq2seq model from the same checkpoint as the tokenizer.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "Salesforce/codet5p-220m"
)

In [None]:
def collate_fn(batch, tok=None):
    """Collate tokenized examples into a padded seq2seq training batch.

    Args:
        batch: list of dicts with unpadded ``input_ids``, ``attention_mask`` and
            ``labels`` token-id lists (as produced by ``PRDataset.__getitem__``).
        tok: tokenizer whose ``pad`` method pads the encoder inputs. Defaults to
            the notebook-global ``tokenizer``.

    Returns:
        dict with ``input_ids`` and ``attention_mask`` (padded by ``tok.pad`` to
        the longest input in the batch) and ``labels``, a ``torch.long`` tensor
        of shape (batch_size, longest_label) whose padding positions are -100.
    """
    tok = tokenizer if tok is None else tok

    padded_inputs = tok.pad(
        {
            "input_ids": [item["input_ids"] for item in batch],
            "attention_mask": [item["attention_mask"] for item in batch],
        },
        return_tensors="pt",
    )

    # Labels are padded to the longest *label* in the batch, independent of the
    # input length. Tying them to the input length would truncate target tokens
    # (including EOS) whenever a completion is longer than every prompt in the
    # batch. Label length is already capped when the examples are tokenized.
    label_lists = [list(item["labels"]) for item in batch]
    max_label_len = max((len(lbl) for lbl in label_lists), default=0)

    # Pad directly with -100 (the default ignore_index of the cross-entropy loss)
    # rather than pad_token_id + replace, so a real target token that happens to
    # equal the pad id is never masked out.
    padded_labels = torch.full(
        (len(label_lists), max_label_len), -100, dtype=torch.long
    )
    for row, lbl in enumerate(label_lists):
        if lbl:
            padded_labels[row, : len(lbl)] = torch.tensor(lbl, dtype=torch.long)

    return {
        "input_ids": padded_inputs["input_ids"],
        "attention_mask": padded_inputs["attention_mask"],
        "labels": padded_labels,
    }

In [None]:
from torch.utils.data import Dataset
class PRDataset(Dataset):
    """Map-style dataset of prompt/completion pairs loaded from a JSONL file.

    Each non-blank line must be a JSON object with ``prompt`` and ``completion``
    fields. Examples are tokenized lazily in ``__getitem__`` without padding;
    padding is left to the DataLoader's collate function.
    """

    def __init__(self, path, tokenizer, max_input=512, max_target=256):
        """
        Args:
            path: path to a UTF-8 encoded JSONL file.
            tokenizer: callable tokenizer returning ``input_ids`` and
                ``attention_mask`` for a string.
            max_input: truncation length for prompts, in tokens.
            max_target: truncation length for completions, in tokens.
        """
        # Context manager closes the handle as soon as the file has been read.
        with open(path, "r", encoding="utf-8") as fh:
            self.data = [json.loads(line) for line in fh if line.strip()]
        self.tokenizer = tokenizer
        self.max_input = max_input
        self.max_target = max_target

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` into unpadded input ids, mask and label ids."""
        ex = self.data[idx]

        enc_in = self.tokenizer(
            ex["prompt"], truncation=True, padding=False, max_length=self.max_input
        )
        enc_out = self.tokenizer(
            ex["completion"], truncation=True, padding=False, max_length=self.max_target
        )

        return {
            "input_ids": enc_in["input_ids"],
            "attention_mask": enc_in["attention_mask"],
            "labels": enc_out["input_ids"],
        }

# Training set built with PRDataset's default truncation lengths.
train_dataset = PRDataset(path='train_data.jsonl', tokenizer=tokenizer)

In [None]:
from torch.utils.data import DataLoader
# Shuffling is driven by a seeded generator so batch composition is
# reproducible across Restart & Run All.
SHUFFLE_SEED = 42
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

In [None]:
# Pull one collated batch and print the tensor shapes as a sanity check.
batch = next(iter(train_loader))
print(*(batch[key].shape for key in ("input_ids", "attention_mask", "labels")))

In [None]:
# One manual optimisation step on `batch`, as an end-to-end smoke test of the
# forward/backward pass.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

inputs = batch["input_ids"].to(device)
labels = batch["labels"].to(device)
masks = batch.get("attention_mask")
masks = masks.to(device) if masks is not None else None

# NOTE: the optimizer is created in this cell, so re-running it starts from
# fresh Adam moment estimates (while the model keeps its updated weights).
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
optimizer.zero_grad()

# Passing labels makes the model compute the loss itself (read via outputs.loss).
outputs = model(input_ids=inputs, attention_mask=masks, labels=labels)
loss = outputs.loss
loss.backward()
optimizer.step()

print("one batch loss:", loss.item())