In [2]:
# !pip install transformers
# !pip install pynvml

In [10]:
from torch.utils.data import Dataset
import json

class ChatData(Dataset):
    def __init__(self, path:str, tokenizer):
        self.data = json.load(open(path, "r"))

        self.X = []
        for i in self.data:
            for j in i['dialog']:
                self.X.append(j['text'])

        for idx, i in enumerate(self.X):
            try:
                self.X[idx] = "<startofstring> "+i+" <bot>: "+self.X[idx+1]+" <endofstring>"
            except:
                break

        self.X = self.X[:5000]
        
        print(self.X[0])

        self.X_encoded = tokenizer(self.X,max_length=40, truncation=True, padding="max_length", return_tensors="pt")
        self.input_ids = self.X_encoded['input_ids']
        self.attention_mask = self.X_encoded['attention_mask']

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return (self.input_ids[idx], self.attention_mask[idx])

In [17]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.optim import Adam
from torch.utils.data import DataLoader
import tqdm
import torch

def train(chatData, model, optim):

    epochs = 12

    for i in tqdm.tqdm(range(epochs)):
        for X, a in chatData:
            X = X.to(device)
            a = a.to(device)
            optim.zero_grad()
            loss = model(X, attention_mask=a, labels=X).loss
            loss.backward()
            optim.step()
        torch.save(model.state_dict(), "model_state.pt")
        print(infer("hello how are you"))

def infer(inp):
    inp = "<startofstring> "+inp+" <bot>: "
    inp = tokenizer(inp, return_tensors="pt")
    X = inp["input_ids"].to(device)
    a = inp["attention_mask"].to(device)
    output = model.generate(X, attention_mask=a )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    
#     # Find the response boundary
#     response_boundary = output.find("<bot>:") + len("<bot>:")
#     response = output[response_boundary:].strip()
    
    return output


device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "<pad>", 
                                "bos_token": "<startofstring>",
                                "eos_token": "<endofstring>"})
tokenizer.add_tokens(["<bot>:"])

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.resize_token_embeddings(len(tokenizer))

model = model.to(device)

# print(tokenizer.decode(model.generate(**tokenizer("hey i was good at basketball but ",
#                          return_tensors="pt"))[0]))

chatData = ChatData("./chat_data.json", tokenizer)
chatData =  DataLoader(chatData, batch_size=64)

model.train()

optim = Adam(model.parameters(), lr=1e-3)


<startofstring> I love iphone! i just bought new iphone! <bot>: Thats good for you, i'm not very into new tech <endofstring>


In [12]:
train(chatData, model, optim)

  0%|                                                                                           | 0/12 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
  8%|██████▊                                                                           | 1/12 [03:11<35:04, 191.32s/it]

<startofstring> hello how are you <bot>: <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 17%|█████████████▋                                                                    | 2/12 [06:20<31:38, 189.87s/it]

<startofstring> hello how are you <bot>: I am a very good person. i am a very good person.


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 25%|████████████████████▌                                                             | 3/12 [09:24<28:05, 187.32s/it]

<startofstring> hello how are you <bot>: Hi <bot>: Hi <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 33%|███████████████████████████▎                                                      | 4/12 [13:23<27:42, 207.86s/it]

<startofstring> hello how are you <bot>: i am a huge gamer <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 42%|██████████████████████████████████▏                                               | 5/12 [13:55<16:50, 144.35s/it]

<startofstring> hello how are you <bot>: Hi <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 50%|█████████████████████████████████████████                                         | 6/12 [14:13<10:08, 101.44s/it]

<startofstring> hello how are you <bot>: i am a huge gamer <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 58%|████████████████████████████████████████████████▍                                  | 7/12 [14:31<06:10, 74.11s/it]

<startofstring> hello how are you <bot>: Hi <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 67%|███████████████████████████████████████████████████████▎                           | 8/12 [14:49<03:45, 56.40s/it]

<startofstring> hello how are you <bot>: Hi <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 75%|██████████████████████████████████████████████████████████████▎                    | 9/12 [15:11<02:16, 45.45s/it]

<startofstring> hello how are you <bot>: Hi <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 83%|████████████████████████████████████████████████████████████████████▎             | 10/12 [15:32<01:16, 38.01s/it]

<startofstring> hello how are you <bot>: Hi, how are doing? <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 92%|███████████████████████████████████████████████████████████████████████████▏      | 11/12 [15:53<00:32, 32.65s/it]

<startofstring> hello how are you <bot>: Hi <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
100%|██████████████████████████████████████████████████████████████████████████████████| 12/12 [16:11<00:00, 80.93s/it]

<startofstring> hello how are you <bot>: Hi, how are doing? <endofstring> <pad> <pad> <pad> <pad> <pad> <pad> <pad>





In [20]:
def infer(inp):
    inp = "<startofstring> "+inp+" <bot>: "
    inp = tokenizer(inp, return_tensors="pt")
    X = inp["input_ids"].to(device)
    a = inp["attention_mask"].to(device)
    output = model.generate(X, attention_mask=a )
    output = tokenizer.decode(output[0], skip_special_tokens=True)
    
#     # Find the response boundary
#     response_boundary = output.find("<bot>:") + len("<bot>:")
#     response = output[response_boundary:].strip()
    
    return output


print("infer from model : ")
while True:
    inp = str(input())
    print(infer(inp))

infer from model : 
hello


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


<startofstring> hello <bot>: <pad> <pad> <pad> <pad> <startofstring> <pad> <pad> <pad> <pad> <pad> <startofstring> <pad> <pad> <pad> <pad> <pad> <pad>


KeyboardInterrupt: Interrupted by user