In [None]:
import torch
import pickle
from dataclasses import dataclass
from datasets import load_dataset
from torch.utils.data import IterableDataset
from torch.utils.data.dataloader import DataLoader
from tqdm.notebook import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    logging,
    set_seed
)


def load_pkl_data(filename = './data/java/bigcode-the-stack-dedup-train.pkl'):
    """Unpickle and return the object stored at ``filename``.

    A short progress message is printed before and after the load.

    Args:
        filename: Path of the pickle file to read.

    Returns:
        Whatever object the pickle file contains.
    """
    print("Loading data from: ", filename)
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point this at pickle files from a trusted source.
    with open(filename, "rb") as pkl_file:
        loaded = pickle.load(pkl_file)
    print("Done loading data from: ", filename)
    return loaded



In [None]:
# Checkpoint to fine-tune; the tokenizer and the weights are loaded from the same id.
model_id = "Salesforce/codegen-350M-mono"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# use_cache=False is requested at load time; gradient checkpointing is switched
# on in the TrainingArguments cell.
# NOTE(review): trust_remote_code=True permits code shipped with the model repo
# to run locally -- confirm this checkpoint actually needs it.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    use_cache=False,
    trust_remote_code=True,
)

In [None]:
# Pickled train / test splits; the train split is loaded first, then the test split.
trainfile = './data_pkl/java/bigcode-the-stack-dedup-train.pkl'
validfile = './data_pkl/java/bigcode-the-stack-dedup-test.pkl'
train_ds, valid_ds = (load_pkl_data(split_file) for split_file in (trainfile, validfile))

In [None]:
# Estimate the average number of characters per token from the first
# `examples` records of the training split.
examples = 500
total_characters = 0
total_tokens = 0

# zip(range(...), ...) caps the loop at `examples` records (fewer if the split is shorter).
for _, example in tqdm(zip(range(examples), iter(train_ds)), total=examples):
    total_characters = total_characters + len(example['content'])
    total_tokens = total_tokens + len(tokenizer(example['content']).tokens())

characters_per_token = total_characters / total_tokens
print(characters_per_token)

In [None]:
# Number of records in the train and validation splits, shown as a (train, valid) tuple.
tuple(map(len, (train_ds, valid_ds)))

In [None]:
class ConstantLengthDataset(IterableDataset):
    """
    Iterable dataset that returns constant length chunks of tokens from a stream of text files.

    Texts are pulled from ``dataset`` into a character buffer, tokenized as one
    batch, joined with a separator token after every text, and sliced into
    windows of ``seq_length`` tokens. A trailing window shorter than
    ``seq_length`` is dropped at the end of every buffer.

        Args:
            tokenizer (Tokenizer): The processor used for processing the data.
            dataset (dataset.Dataset): Dataset with text files.
            infinite (bool): If True the iterator is reset after dataset reaches end else stops.
            seq_length (int): Length of token sequences to return.
            num_of_sequences (int): Number of token sequences to keep in buffer.
            chars_per_token (float): Number of characters per token used to estimate number of tokens in text buffer.
            content_field (str): Key under which each dataset record stores its text.

        Yields:
            dict: ``{"input_ids": LongTensor, "labels": LongTensor}``, both holding the
            same ``seq_length`` token ids.
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=2.95,
        content_field="content",
    ):
        self.tokenizer = tokenizer
        # Compare against None rather than truthiness: an EOS id of 0 is valid
        # and must not be replaced by the hard-coded fallback id.
        eos_token_id = tokenizer.eos_token_id
        self.concat_token_id = eos_token_id if eos_token_id is not None else 49152
        self.dataset = dataset
        self.seq_length = seq_length
        self.infinite = infinite
        # Running count of sequences yielded so far (across all iterations).
        self.current_size = 0
        # Character budget of the text buffer, estimated from the token budget.
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
        self.content_field = content_field

    def __iter__(self):
        iterator = iter(self.dataset)
        more_examples = True
        # Records read since the iterator was last (re)created; lets us tell an
        # exhausted dataset apart from an empty one.
        read_in_pass = 0
        while more_examples:
            buffer, buffer_len = [], 0
            while buffer_len < self.max_buffer_size:
                try:
                    buffer.append(next(iterator)[self.content_field])
                except StopIteration:
                    if self.infinite and read_in_pass > 0:
                        # Start another pass over the dataset and keep filling.
                        iterator = iter(self.dataset)
                        read_in_pass = 0
                        continue
                    # Finite mode, or an empty dataset in infinite mode (which
                    # would otherwise restart forever): stop after this buffer.
                    more_examples = False
                    break
                read_in_pass += 1
                buffer_len += len(buffer[-1])
            if not buffer:
                # Nothing was read, so there is nothing to tokenize or yield.
                break
            tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
            all_token_ids = []
            for tokenized_input in tokenized_inputs:
                all_token_ids.extend(tokenized_input)
                all_token_ids.append(self.concat_token_id)
            # Iterate over full windows only; the ragged tail is discarded.
            last_start = len(all_token_ids) - self.seq_length
            for start in range(0, last_start + 1, self.seq_length):
                input_ids = all_token_ids[start : start + self.seq_length]
                self.current_size += 1
                yield {
                    "input_ids": torch.tensor(input_ids, dtype=torch.long),
                    "labels": torch.tensor(input_ids, dtype=torch.long),
                }

In [None]:
# Wrap both splits in fixed-length token streams. The characters-per-token
# ratio measured on the training sample earlier in the notebook sizes the text
# buffer, instead of silently falling back to the class default.
# The training stream restarts when exhausted; the validation stream does not.
train_dataset = ConstantLengthDataset(
    tokenizer,
    train_ds,
    infinite=True,
    seq_length=1024,
    chars_per_token=characters_per_token,
)
valid_dataset = ConstantLengthDataset(
    tokenizer,
    valid_ds,
    infinite=False,
    seq_length=1024,
    chars_per_token=characters_per_token,
)

In [None]:
training_args = TrainingArguments(
    # Where checkpoints and logs go, and whether to publish them.
    output_dir="codegne-finetuned-the-stack-java-v2",
    push_to_hub=True,
    # Step budget and evaluation / checkpoint / logging cadence.
    max_steps=500000,
    evaluation_strategy="steps",
    eval_steps=500,
    save_steps=500,
    logging_steps=10,
    # Batching: one sequence per device, gradients accumulated over 4 steps.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    dataloader_drop_last=True,
    # Optimizer and learning-rate schedule.
    optim="adafactor",
    learning_rate=5e-5,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    weight_decay=0.05,
    # Memory / precision settings.
    gradient_checkpointing=True,
    fp16=False,
)

In [None]:
# Load the TensorBoard notebook extension and open it on the directory that
# the TrainingArguments cell uses as `output_dir`.
%load_ext tensorboard
%tensorboard --logdir codegne-finetuned-the-stack-java-v2

In [None]:
# NOTE(review): ConstantLengthDataset, as defined in this notebook, never reads
# `start_iteration`; this line only attaches an extra attribute to the
# instance. Confirm whether anything outside this notebook relies on it.
train_dataset.start_iteration = 0

In [None]:
# Bundle the model, the training arguments and the two token streams.
# The tokenizer is not handed to the Trainer; it is kept below as a
# commented-out option.
trainer = Trainer(
    model=model,
    args=training_args,
    eval_dataset=valid_dataset,
    train_dataset=train_dataset,
    # tokenizer=tokenizer,
)

In [None]:
# Start fine-tuning with the Trainer built in the previous cell.
trainer.train()