In [1]:
import os
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset as BaseDataset
from torch.utils.data import DataLoader as BaseDataLoader

from transformers import GPT2Tokenizer, AutoModelForCausalLM, GPT2LMHeadModel
from transformers import TrainingArguments, Trainer

from datasets import load_dataset

device="cuda"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id

In [3]:
config = {
    "emb_dim" : 768,
    "letter_emb_dim": 1024,
    "vocab_size" : tokenizer.vocab_size,
    "save_path": "./models/v1.pth"
}

In [4]:
class Dataset(BaseDataset):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.data = load_dataset("csv", data_files="dataset.csv")["train"]

    def __getitem__(self, ix):
        item = self.data[ix]
        return item


    def __len__(self, ):
        return len(self.data)


In [5]:
class CustomCollator:
    def __call__(self, batch):
        titles = [item["title"] for item in batch if item is not None]
        ctx_embs = torch.tensor([json.loads(item["context_embedding"]) for item in batch if item is not None], dtype=torch.float)

        tokenized_title = tokenizer(titles,
                  padding="longest",
                  truncation=True,
                  return_tensors="pt")
        
        attention_mask = torch.stack([torch.cat([torch.tensor([1,], dtype=torch.long), mask], dim=-1) for mask in tokenized_title["attention_mask"]]).to(device)

        input_ids = tokenized_title["input_ids"][:, :-1].long()
        targets = tokenized_title["input_ids"]
        targets = targets.masked_fill(targets == tokenizer.pad_token_id, -100)

        return {
            "attention_mask": attention_mask[:, :-1].to(device),
            "letter_emb": ctx_embs.to(device),
            "input_ids": input_ids.to(device),
            "label": targets.to(device)
        }

dataset = Dataset(tokenizer)
collator_fn = CustomCollator()

In [None]:
class Model(nn.Module):

    def __init__(self, tokenizer, config):
        super().__init__()
        self.tokenizer = tokenizer
        self.letter_projection = nn.Sequential(nn.Linear(config["letter_emb_dim"], config["letter_emb_dim"] * 2),
                                                nn.Linear(config["letter_emb_dim"] * 2, config["emb_dim"]))
        self.gpt = AutoModelForCausalLM.from_pretrained("gpt2")

        for p in self.parameters():
            p.requires_grad=True


    @classmethod
    def from_pretrained(cls, tokenizer, config):
        print("check model existance...")
        if os.path.isfile(config["save_path"]):
            print("Loading the model...")
            self = cls(tokenizer, config)
            self.load_state_dict(torch.load(config["save_path"], weights_only=True))
            print("loaded successfully!")
        else:
            print(f"couldn't find the {config['save_path']} file!")
            print("Creating a new model...")
            self = cls(tokenizer, config)
        return self

    def save(self, ):
        torch.save(self.state_dict(), config["save_path"])
        print(f"Model saved at {config['save_path']}!")
    
    def forward(self, attention_mask, letter_emb, input_ids, label):
        letter_emb = self.letter_projection(letter_emb).unsqueeze(1)
        x = self.gpt.transformer.wte(input_ids)
        x += self.gpt.transformer.wpe(torch.arange(x.shape[1]).to(device))
        x = torch.cat([letter_emb, x], dim=1)

        output = self.gpt(inputs_embeds=x,
            attention_mask=attention_mask,
            return_dict=True,
            labels=label
        )
        return output

    
    @torch.no_grad
    def generate(self, letter_emb):
        model.eval()
        letter_emb = torch.tensor(json.loads(letter_emb)).view(1,1,-1).to(device)
        letter_emb = self.letter_projection(letter_emb)
        output = model.gpt.generate(
        inputs_embeds=letter_emb,
        attention_mask=torch.ones((1, 1), dtype=torch.long).to(device),
        do_sample=True,
        top_p=0.9,
        temperature=0.9,
        num_beams=5,
        max_length=128,
        min_length=1,
        repetition_penalty=1.0,
        length_penalty=1.0,
        num_return_sequences=1,)

model = Model.from_pretrained(tokenizer, config)
model.to(device)

check model existance...


Model(
  (letter_projection): Sequential(
    (0): Linear(in_features=1024, out_features=2048, bias=True)
    (1): Linear(in_features=2048, out_features=768, bias=True)
  )
  (gpt): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D(nf=2304, nx=768)
            (c_proj): Conv1D(nf=768, nx=768)
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D(nf=3072, nx=768)
            (c_proj): Conv1D(nf=768, nx=3072)
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False

In [9]:
train_args = TrainingArguments(
    output_dir="./cache/",
    learning_rate=1e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=16,
    num_train_epochs=80,
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    save_strategy="epoch",
    logging_steps=20,
    logging_strategy="steps",
    remove_unused_columns=False,
    dataloader_pin_memory=False,
    save_safetensors=False
)

trainer = Trainer(model=model,
        args=train_args,
        data_collator=collator_fn,
        train_dataset=dataset)

In [10]:
trainer.train()
model.save()

Step,Training Loss
20,2.0372
40,1.9172
60,1.9494
80,1.8469
100,1.9123
120,1.7726
140,1.8247
160,1.7482
180,1.7967
200,1.6566


Model saved at ./models/v1.pth!


In [11]:
sample_data = dataset[0]
sample_data

{'title': 'گزارش عملکرد سرورهای سامانه اتوماسیون اداری سازمان تامین اجتماعی',
 'context': '\n\nشماره: \n\nتاریخ: \n\nپیوست: \nدارد\nبسمه تعالی\n\n\n\n\nجناب آقای مهندس بهروز کتابی\nمدیر محترم فناوری و تحول دیجیتال سازمان تامین اجتماعی\nبا سلام و احترام\nبه پیوست گزارش عملکرد سرورهای سامانه اتوماسیون اداری آن سازمان مربوط به مهرماه سال ۱۴۰۳ حضورتان ارسال می\u200fگردد. \n\n\n\nبا تشکر\nمهدی اسد بگی\nمعاون امور فنی و پشتیبانی\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 'context_embedding': '[-0.009723538532853127, -0.00010115613986272365, -0.029605424031615257, -0.02930566854774952, -0.013792633078992367, -0.07930489629507065, 0.009425357915461063, -0.014393379911780357, 0.028125934302806854, 0.055396534502506256, 0.03288242593407631, 0.02731730043888092, 0.030033962801098824, 0.04641319811344147, -0.025476908311247826, 0.022169610485434532, 0.02849527634680271, -0.018998190760612488, -0.011828726157546043, 0.003549319226294756, 0.028123190626502037, 0.02426593191921711, 0.06110379099845886, 0

In [25]:
with torch.no_grad():
    letter_emb = model.letter_projection(torch.tensor(json.loads(sample_data["context_embedding"])).to(device).view(1, 1,-1))
    output = model.gpt.generate(
        inputs_embeds=letter_emb,
        attention_mask=torch.ones((1, 1), dtype=torch.long).to(device),
        # do_sample=True,
        # top_p=0.9,
        # temperature=0.9,
        # num_beams=5,
        # max_length=128,
        # min_length=1,
        # repetition_penalty=1.0,
        # length_penalty=1.0,
        # num_return_sequences=1,
    )

    output_ids = tokenizer.batch_decode(output, skip_special_tokens=True)
    print(output)
    print(''.join(output_ids))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


tensor([[  107, 34247, 34247,   149,   148, 12919,   107,   149,   220, 12919,
           220,   107,   220,   220, 12919,   107,   220,   220,   220]],
       device='cuda:0')
�ا�ا���ا�� ا �  ا�   
