In [13]:
import os
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset as BaseDataset
from torch.utils.data import DataLoader as BaseDataLoader

from transformers import GPT2Tokenizer, AutoModelForCausalLM, GPT2LMHeadModel, AutoTokenizer
from transformers import TrainingArguments, Trainer

from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter

from datasets import load_dataset

device="cuda"

In [2]:
tokenizer = AutoTokenizer.from_pretrained("bolbolzaban/gpt2-persian")

In [3]:
config = {
    "emb_dim" : 768,
    "letter_emb_dim": 1024,
    "ctx_len": 256,
    "vocab_size" : tokenizer.vocab_size,
    "save_path": "./models/v3.pth"
}

In [17]:
class Dataset(BaseDataset):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(tokenizer,
            chunk_size=config["ctx_len"], chunk_overlap=config["ctx_len"] // 8)
        # splitter = CharacterTextSplitter.from_huggingface_tokenizer(
        #     tokenizer, chunk_size=config["ctx_len"] - 1, chunk_overlap=config["ctx_len"] // 8
        # )
        self.data = []
        for xdata in load_dataset("csv", data_files="dataset.csv")["train"]:
            self.data.extend(
                splitter.split_text(xdata["context"])
            )


    def __getitem__(self, ix):
        item = self.data[ix]
        return {
            "context": item,
        }


    def __len__(self, ):
        return len(self.data)


In [18]:
class CustomCollator:
    def __call__(self, batch):
        ctxs = [item["context"] for item in batch if item is not None]

        tokenized_ctxs = tokenizer(ctxs,
                  padding="max_length",
                  truncation=True,
                  return_tensors="pt",
                  max_length=config["ctx_len"])

        input_ids = tokenized_ctxs["input_ids"][:, :-1].long()
        targets = tokenized_ctxs["input_ids"][:, 1:]
        targets = targets.masked_fill(targets == tokenizer.pad_token_id, -100)

        return {
            "attention_mask": tokenized_ctxs["attention_mask"][:, :-1].to(device),
            "input_ids": input_ids.to(device),
            "label": targets.to(device)
        }

dataset = Dataset(tokenizer)
collator_fn = CustomCollator()

In [20]:
class Model(nn.Module):

    def __init__(self, tokenizer, config):
        super().__init__()
        self.tokenizer = tokenizer
        self.letter_projection = nn.Sequential(nn.Linear(config["letter_emb_dim"], config["letter_emb_dim"] * 2),
                                                nn.Linear(config["letter_emb_dim"] * 2, config["emb_dim"]))
        self.gpt = GPT2LMHeadModel.from_pretrained("bolbolzaban/gpt2-persian")

        for p in self.parameters():
            p.requires_grad=True


    @classmethod
    def from_pretrained(cls, tokenizer, config):
        print("Loading the model...")
        self = cls(tokenizer, config)
        self.load_state_dict(torch.load(config["save_path"], weights_only=True))
        print("loaded successfully!")
        return self

    def save(self, ):
        torch.save(self.state_dict(), config["save_path"])
        print(f"Model saved at {config['save_path']}!")
    
    def forward(self, attention_mask, input_ids, label):
        x = self.gpt.transformer.wte(input_ids)
        x += self.gpt.transformer.wpe(torch.arange(x.shape[1]).to(device))

        output = self.gpt(inputs_embeds=x,
            attention_mask=attention_mask,
            return_dict=True,
            labels=label
        )
        return output

    
    @torch.no_grad
    def generate(self, letter_emb):
        model.eval()
        letter_emb = torch.tensor(json.loads(letter_emb)).view(1,1,-1).to(device)
        letter_emb = self.letter_projection(letter_emb)
        output = model.gpt.generate(
        inputs_embeds=letter_emb,
        attention_mask=torch.ones((1, 1), dtype=torch.long).to(device),
        do_sample=True,
        top_p=0.9,
        temperature=0.9,
        num_beams=5,
        max_length=128,
        min_length=1,
        repetition_penalty=1.0,
        length_penalty=1.0,
        num_return_sequences=1,)
        return self.tokenizer.batch_decode(output)

model = Model(tokenizer, config)
model.to(device)

Model(
  (letter_projection): Sequential(
    (0): Linear(in_features=1024, out_features=2048, bias=True)
    (1): Linear(in_features=2048, out_features=768, bias=True)
  )
  (gpt): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(25000, 1024)
      (wpe): Embedding(256, 1024)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-23): 24 x GPT2Block(
          (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm(

In [21]:
train_args = TrainingArguments(
    output_dir="./cache/",
    learning_rate=1e-3,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=32,
    num_train_epochs=6,
    weight_decay=0.01,
    lr_scheduler_type="linear",
    save_strategy="epoch",
    logging_steps=10,
    logging_strategy="steps",
    remove_unused_columns=False,
    dataloader_pin_memory=False,
    save_safetensors=False
)

trainer = Trainer(model=model,
        args=train_args,
        data_collator=collator_fn,
        train_dataset=dataset)

In [22]:
trainer.train()
model.save()

Step,Training Loss
10,7.1329
20,5.5526
30,4.9981
40,4.2168
50,3.4936
60,2.8112
70,2.3544
80,2.0691
90,1.9202
100,1.7155


Model saved at ./models/v3.pth!
