In [1]:
import torch
import tqdm
import json
import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer,GPT2Config
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

In [2]:
# Define a class for my dataset
class ChatData(Dataset):
    def __init__(self, path:str, tokenizer):
        # Load the JSON data from the provided path(the dataset)
        self.data = json.load(open(path, "r"))
        
        # Extract the text from the dialog and format it
        self.X = []
        for i in self.data:
            for j in i['dialog']:
                self.X.append(j['text'])

        # Combine dialog turns to form input-output pairs
        for idx, i in enumerate(self.X):
            try:
                self.X[idx] = "<startofstring> "+i+" <bot>: "+self.X[idx+1]+" <endofstring>"
            except:
                break

        # Limit the dataset size to the first 2000 samples
        self.X = self.X[:2000]
        
        # Print the first formatted data sample
        print(self.X[0])

        # Tokenize the input data using the provided tokenizer
        self.X_encoded = tokenizer(self.X,max_length=40,padding= True, truncation=True, return_tensors="pt")
        self.input_ids = self.X_encoded['input_ids']
        self.attention_mask = self.X_encoded['attention_mask']

    def __len__(self):
        # Return the number of samples in the dataset
        return len(self.X)

    def __getitem__(self, idx):
        # Return the tokenized input IDs and attention mask for the given index
        return (self.input_ids[idx], self.attention_mask[idx])

In [3]:
def train(chatData, model, optimizer):
    epochs = 12

    for i in tqdm.tqdm(range(epochs)):
        print(f"Epoch {i}")
        for X, a in chatData:
            X = X.to(device)  # Move input data to the device 
            a = a.to(device)  # Move attention mask to the device 

            optimizer.zero_grad()  # Clear previous gradients
            loss = model(X, attention_mask=a, labels=X).loss  # Compute the loss
            print(f"Loss = {loss.item()}")  # Print the current loss

            loss.backward()  # Backpropagate the gradients
            optimizer.step()  # Update the model's parameters using the optimizer

        # Save the model's state dictionary to a file after each epoch
        torch.save(model.state_dict(), "model_state.pt")

        # Perform inference using the "infer" function on input
        print(infer("hello how are you"))  


In [4]:
def infer(inp):
    # Format the input for generation
    inp = "<startofstring> " + inp + " <bot>: "

    # Encode the formatted input using the tokenizer and move it to the device
    input_ids = tokenizer.encode(inp, return_tensors="pt").to(device)

    # Create an attention mask with ones for the generated input
    attention_mask = torch.ones_like(input_ids).to(device)

    # Generate  response using the model
    output_ids = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=50, num_return_sequences=1)

    # Decode the generated output and remove special tokens
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return response


In [5]:
# Defining the device to train the model
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

In [6]:
# Load the GPT-2 tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# Add special tokens to the tokenizer
tokenizer.add_special_tokens({"pad_token": "<pad>", 
                                "bos_token": "<startofstring>",
                                "eos_token": "<endofstring>"})
# Add the "<bot>:" token as a new token
tokenizer.add_tokens(["<bot>:"])

1

In [7]:
# Load the GPT-2 model
model = GPT2LMHeadModel.from_pretrained("gpt2")
# Resize the token embeddings of the model to match the updated tokenizer
model.resize_token_embeddings(len(tokenizer))
# Move the model to the specified device 
model = model.to(device)

In [8]:
# Create an instance of ChatData by loading data
chatData = ChatData("./chat_data.json", tokenizer)
# Create a DataLoader for the chat data with a batch size of 32
chatData =  DataLoader(chatData, batch_size=32)

<startofstring> I love iphone! i just bought new iphone! <bot>: Thats good for you, i'm not very into new tech <endofstring>


In [11]:
# Setting the model in training mode 
model.train()

# Initializing an Adam optimizer for updating the model's parameters
optim = Adam(model.parameters(), lr=1e-3)

# Print a message indicating the start of the training process
# initially during training i have added hi messages as it was taking too long to train
print("Training...")

train(chatData, model, optim)

training .... 


  0%|                                                                                           | 0/12 [00:00<?, ?it/s]

epochs 0
hi1
hi2
hi3
loss = 40.95790100097656
hi
hi1
hi2
hi3
loss = 95.95818328857422
hi
hi1
hi2
hi3
loss = 46.49863052368164
hi
hi1
hi2
hi3
loss = 7.553525447845459
hi
hi1
hi2
hi3
loss = 5.894950866699219
hi
hi1
hi2
hi3
loss = 7.931908130645752
hi
hi1
hi2
hi3
loss = 7.95673942565918
hi
hi1
hi2
hi3
loss = 8.076740264892578
hi
hi1
hi2
hi3
loss = 6.012801170349121
hi
hi1
hi2
hi3
loss = 5.308113098144531
hi
hi1
hi2
hi3
loss = 5.525579929351807
hi
hi1
hi2
hi3
loss = 4.943215847015381
hi
hi1
hi2
hi3
loss = 4.421879768371582
hi
hi1
hi2
hi3
loss = 4.36928129196167
hi
hi1
hi2
hi3
loss = 4.497930526733398
hi
hi1
hi2
hi3
loss = 2.61627459526062
hi
hi1
hi2
hi3
loss = 3.7168803215026855
hi
hi1
hi2
hi3
loss = 3.573873996734619
hi
hi1
hi2
hi3
loss = 4.549310207366943
hi
hi1
hi2
hi3
loss = 3.049575090408325
hi
hi1
hi2
hi3
loss = 2.537198305130005
hi
hi1
hi2
hi3
loss = 3.3934326171875
hi
hi1
hi2
hi3
loss = 3.856060028076172
hi
hi1
hi2
hi3
loss = 4.399340629577637
hi
hi1
hi2
hi3
loss = 3.12639355659484

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
  8%|██████▎                                                                    | 1/12 [4:46:29<52:31:28, 17189.84s/it]

<startofstring> hello how are you <bot>:, you? <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>
epochs 1
hi1
hi2
hi3
loss = 2.5343971252441406
hi
hi1
hi2
hi3
loss = 2.7416462898254395
hi
hi1
hi2
hi3


  8%|██████▎                                                                    | 1/12 [4:56:01<54:16:14, 17761.36s/it]


KeyboardInterrupt: 

In [10]:
print("infer from model : ")
while True:
  # Getting user input
  inp = input()
 # Printing the generated response
  print(infer(inp))

In [11]:
from gradio.components import Textbox

# Defining the configuration for the fine-tuned model to match size
config = GPT2Config(
    vocab_size=50261, 
    n_embd=768,  
    n_layer=12,  
    n_head=12,  
)
# Initializing and importing the trained and fine-tuned models

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")  
model = GPT2LMHeadModel.from_pretrained("gpt2")  
model = GPT2LMHeadModel(config=config)
model.resize_token_embeddings(50261)
model.load_state_dict(torch.load("model_state.pt", map_location=device))
model.eval()

# Defining a function to generate responses
def generate_response(input_text):
    input_text = "<startofstring> " + input_text + " <bot>: "
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    output_ids = model.generate(input_ids, max_length=100, num_return_sequences=1)
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return response

# Creating a Gradio interface
iface = gr.Interface(
    fn=generate_response,
    inputs=Textbox(),
    outputs=Textbox(),
#     layout="vertical",
    title="Chatbot Demo",
    description="Type a message to chat with the bot.",
)

# Launching the Gradio interface
iface.launch()


Running on local URL:  http://127.0.0.1:7862

To create a public link, set `share=True` in `launch()`.




The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
