In [None]:
from transformers import AutoModel, AutoTokenizer, T5ForConditionalGeneration
from datasets import load_dataset
from torch.utils.data import DataLoader
from pathlib import Path
from tqdm import tqdm
import torch
import yaml
import gc
import os

# --- Run configuration: everything a reader might tune lives here. ---
output_dir = ''       # where epoch checkpoints and train_info.yml are written
path_to_model = ''    # fine-tune from a previously saved model (takes precedence if set)
path_to_dataset = ''  # JSON/JSONL file loaded with `datasets.load_dataset("json", ...)`
checkpoint = 'Salesforce/codet5p-220m'  # initialize from checkpoint (also the tokenizer source)
batch_size = 8
epochs = 1
device = "cuda" if torch.cuda.is_available() else "cpu"
task = 'mask-prediction'  # 'mask-prediction' or 'code-generation'

# Seed torch's RNGs (CPU and all CUDA devices) so dropout during training is reproducible.
seed = 42
torch.manual_seed(seed)

In [None]:
# The tokenizer always comes from the base checkpoint, even when fine-tuning a saved model.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

# Model class used for each supported task. Validating here stops a typo in `task`
# from silently falling through to the code-generation model.
# NOTE(review): for a plain T5 checkpoint such as codet5p-220m, AutoModel may resolve
# to a model without an LM head, in which case `outputs.loss` in the training loop is
# unavailable — confirm this is the intended class for mask-prediction.
model_classes = {
    'mask-prediction': AutoModel,
    'code-generation': T5ForConditionalGeneration,
}
if task not in model_classes:
    raise ValueError(f"Unknown task {task!r}; expected one of {sorted(model_classes)}")
modeling = model_classes[task]

print(f'Task: {task}')

# Prefer an explicitly provided fine-tuned model; otherwise start from the base checkpoint.
model = modeling.from_pretrained(
    path_to_model or checkpoint,
    trust_remote_code=True).to(device)
print("Loaded model from path" if path_to_model else "Loaded model from checkpoint")

In [None]:
# Ensure the output directory (and any missing parents) exists; re-running is a no-op.
Path(output_dir).mkdir(exist_ok=True, parents=True)

In [None]:
# A JSON data file loads as a single "train" split; format it for torch consumption.
dataset = (
    load_dataset("json", data_files=path_to_dataset)["train"]
    .with_format("torch")
)

In [None]:
# Batches are drawn in dataset order (no shuffling).
dataloader = DataLoader(
    dataset,
    shuffle=False,
    batch_size=batch_size,
)

In [None]:
# Collect unreachable Python objects first so that any CUDA tensors they held are
# freed; only then can empty_cache() return those now-unused cached blocks.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Fine-tune the model with AdamW and save a checkpoint after every epoch.
learning_rate = 5e-5

model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for epoch in range(epochs):
    # A fresh bar per epoch: tqdm closes (and disables) a bar once it has been iterated
    # to the end, so a single bar shared across epochs would stop updating after epoch 0.
    pba = tqdm(dataloader, desc=f"Epoch: {epoch}")
    for step, batch in enumerate(pba):
        optimizer.zero_grad()

        # Drop the singleton dimension at axis 1 and move every field to the device.
        # NOTE(review): presumes each stored row is shaped [1, seq_len] — confirm.
        batch = {k: v.squeeze(1).to(device) for k, v in batch.items()}

        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        pba.set_description_str(f"Epoch: {epoch} Step: {step} Loss: {loss.item():.4f}")

    # Save the tokenizer alongside the weights so each epoch directory can be reloaded on its own.
    epoch_dir = os.path.join(output_dir, f'epoch-{epoch}')
    model.save_pretrained(epoch_dir)
    tokenizer.save_pretrained(epoch_dir)

In [None]:
# Provenance for this run, written next to the checkpoints. `task` is included because
# it selects the model class, so the run cannot be reproduced without it.
train_info = {
    "task": task,
    "path_to_model": path_to_model,
    "path_to_dataset": path_to_dataset,
    "checkpoint": checkpoint,
    "batch_size": batch_size,
    "epochs": epochs,
    "dataset": {
        "rows": dataset.num_rows,
    }
}

In [None]:
# Persist the run metadata as YAML inside the output directory.
info_path = Path(output_dir) / "train_info.yml"
with info_path.open("w") as info_file:
    yaml.dump(train_info, info_file)