In [None]:
!pip install datasets evaluate transformers[sentencepiece]
!pip install accelerate

In [3]:
import torch
from transformers import AutoConfig, AutoModel, GPT2Tokenizer, TextDataset, TrainingArguments
from transformers import DataCollatorForLanguageModeling, Trainer, AutoTokenizer, GPT2LMHeadModel

In [21]:
config = AutoConfig.from_pretrained("/content/drive/MyDrive/NLP/final_model")
model = GPT2LMHeadModel.from_pretrained("/content/drive/MyDrive/NLP/final_model")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

## Dataset for testing

In [25]:
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path='/content/drive/MyDrive/NLP/t.csv',
    block_size=128
)

validation_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="/content/drive/MyDrive/NLP/v.csv",
    block_size=128
)



## Dataset for actual training

In [5]:
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="/content/drive/MyDrive/NLP/dataset_1.csv",
    block_size=128
)

validation_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="/content/drive/MyDrive/NLP/validation.csv",
    block_size=128
)

# test_dataset = TextDataset(
#     tokenizer=tokenizer,
#     file_path="/content/drive/MyDrive/NLP/test.csv",
#     block_size=128
# )



In [6]:
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW

____

In [None]:
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


In [None]:
list(model.parameters())

In [None]:
config.n_embd, config.vocab_size

(768, 50257)

Just in case- the tokenization process

In [None]:
input_ids = []
maximum_sequence_length = config.n_embd

for text in train_dataset:
  tokens = tokenizer.encode(text, add_special_tokens=True, max_length=maximum_sequence_length, truncation=True)
  if len(tokens) < maximum_sequence_length:
    tokens = tokens + [tokenizer.pad_token_id] * (maximum_sequence_length - len(tokens))
  else:
    tokens = tokens[:maximum_sequence_length]

  input_ids.append(tokens)

input_ids = torch.tensor(input_ids)

In [15]:
class CustomModel(torch.nn.Module):
  def __init__(self, pretrained_model, config):
    super(CustomModel, self).__init__()
    self.transformer = pretrained_model
    self.config = config

    self.ffn1 = torch.nn.Sequential(
        torch.nn.Linear(self.config.vocab_size, self.config.n_embd),
        torch.nn.GELU(),
        torch.nn.Linear(self.config.n_embd, self.config.n_embd)
    )
    self.layer_norm1 = torch.nn.LayerNorm(self.config.n_embd)

    self.ffn2 = torch.nn.Sequential(
        torch.nn.Linear(self.config.n_embd, 2*self.config.n_embd),
        torch.nn.GELU(),
        torch.nn.Linear(2*self.config.n_embd, self.config.n_embd)
    )
    self.layer_norm2 = torch.nn.LayerNorm(self.config.n_embd)

    self.Linear = torch.nn.Linear(self.config.n_embd, self.config.vocab_size)

  def forward(self, input_ids, attention_mask=None, token_type_ids=None):
    outputs = self.transformer(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

    hidden_states = self.ffn1(outputs.logits)
    hidden_states = self.layer_norm1(hidden_states)

    hidden_states = self.ffn2(hidden_states)
    hidden_states = self.layer_norm2(hidden_states)

    logits = self.Linear(hidden_states)

    return logits

  def generate_text(self, input_ids, max_length=50, temperature=0.9, top_k=50, top_p=0.9):
    with torch.no_grad():
      generated_ids = input_ids.clone()

      for _ in range(max_length):
        logits = self(generated_ids)
        logits = logits[:, -1, :] / temperature
        filtered_logits = self.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
        probabilities = torch.nn.functional.softmax(filtered_logits, dim=-1)
        predicted_token = torch.multinomial(probabilities, 1)
        generated_ids = torch.cat((generated_ids, predicted_token), dim=-1)
      return generated_ids

  @staticmethod
  def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., :top_k] = 0
    logits.scatter_(1, sorted_indices_to_remove.to(torch.int64), filter_value)
    return logits

In [9]:
from torch.optim.lr_scheduler import StepLR

class CustomStepLR(StepLR):
  def __init__(self, optimizer, step_size, gamma, min_lr=0.0005, last_epoch=-1):
    self.min_lr = min_lr
    super(CustomStepLR, self).__init__(optimizer, step_size, gamma, last_epoch)

  def get_lr(self):
    return [max(base_lr * self.gamma**self.last_epoch, self.min_lr) for base_lr in self.base_lrs]

In [10]:
model = CustomModel(model, config)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
criterion = torch.nn.CrossEntropyLoss()
scheduler = CustomStepLR(optimizer, step_size=600, gamma=0.8, min_lr=0.0005)

In [11]:
model

CustomModel(
  (transformer): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  )
  (ffn1): Sequential(
    (0):

In [12]:
train_loader = DataLoader(train_dataset, batch_size=4 , shuffle=True, drop_last=True)
val_loader = DataLoader(validation_dataset, batch_size=4, shuffle=True, drop_last=True)

# Execute this cell only if you want to load a model

In [None]:
checkpoint = torch.load('model_checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
total_loss = checkpoint['training_loss']
validation_loss = checkpoint['validation_loss']
step = checkpoint['steps']

I will have to write different training loop for training my model from a checkpoint

In [13]:
%%time
from tqdm import tqdm

device = ('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)

epochs=3
step = 0
model.train()
total_loss = 0.

for epoch in range(epochs):
  train_loader = tqdm(train_loader, total=len(train_loader))

  if epoch < 2:
    for param in model.transformer.parameters():
      param.requires_grad = False
  else:
    for param in model.transformer.parameters():
      param.requires_grad = True

  for batch in train_loader:
    input_ids, attention_mask, token_type_ids, targets = batch

    input_ids = input_ids.to(device)
    targets = targets.to(device)

    optimizer.zero_grad()
    outputs = model(input_ids)
    loss = criterion(outputs.view(-1, config.vocab_size), targets.view(-1))
    loss.backward()
    optimizer.step()
    scheduler.step()
    total_loss += loss.item()
    train_loader.set_description(f"Epoch {epoch+1}")
    train_loader.set_postfix(loss=loss.item())
    step += 1
    if step % 2500 == 0:
      print(f"\nEpoch: {epoch+1}, Average Loss: {total_loss / step}")
      # VALIDATION
      model.eval()
      validation_loss = 0.
      val_step = 0
      with torch.no_grad():
        for batch in val_loader:
          input_ids, attention_mask, token_type_ids, targets = batch
          input_ids = input_ids.to(device)
          targets = targets.to(device)
          outputs = model(input_ids)
          loss = criterion(outputs.view(-1, config.vocab_size), targets.view(-1))
          validation_loss += loss.item()
          val_step += 1
          if val_step == 1001:
            break
        print(f"\nEpoch: {epoch+1}, Validation Loss: {validation_loss / val_step}")
      # SAVE A CHECKPOINT
      torch.save({
          'epoch': epoch,
          'model_state_dict': model.state_dict(),
          'optimizer_state_dict': optimizer.state_dict(),
          'training_loss': total_loss,
          'validation_loss': validation_loss,
          'steps': step
      }, "/content/drive/MyDrive/NLP/trained_again_model/model_checkpoint.pth")

      # This is questionable moment
      model.train()


print(f"\nEpoch: {epoch+1}, Average Loss: {total_loss / len(train_loader)}")

torch.save(model.state_dict(), '/content/drive/MyDrive/NLP/trained_again_model/model.pth')

Epoch 1:  12%|█▏        | 2498/20387 [02:46<18:26, 16.17it/s, loss=6.7] 


Epoch: 1, Average Loss: 7.198412384414673


Epoch 1:  12%|█▏        | 2498/20387 [03:00<18:26, 16.17it/s, loss=6.7]


Epoch: 1, Validation Loss: 7.070627010547436


Epoch 1:  25%|██▍       | 4999/20387 [06:08<17:40, 14.52it/s, loss=6.88]


Epoch: 1, Average Loss: 7.113104599189758


Epoch 1:  25%|██▍       | 4999/20387 [06:20<17:40, 14.52it/s, loss=6.88]


Epoch: 1, Validation Loss: 7.008151388311243


Epoch 1:  37%|███▋      | 7498/20387 [09:24<14:33, 14.76it/s, loss=7.44]


Epoch: 1, Average Loss: 7.087556205495199


Epoch 1:  37%|███▋      | 7498/20387 [09:35<14:33, 14.76it/s, loss=7.44]


Epoch: 1, Validation Loss: 7.009496891772473


Epoch 1:  49%|████▉     | 9998/20387 [12:36<11:14, 15.41it/s, loss=7.25]


Epoch: 1, Average Loss: 7.072614732885361


Epoch 1:  49%|████▉     | 9998/20387 [12:50<11:14, 15.41it/s, loss=7.25]


Epoch: 1, Validation Loss: 7.029557631565974


Epoch 1:  61%|██████▏   | 12498/20387 [15:49<08:45, 15.01it/s, loss=6.49]


Epoch: 1, Average Loss: 7.059863082885742


Epoch 1:  61%|██████▏   | 12498/20387 [16:00<08:45, 15.01it/s, loss=6.49]


Epoch: 1, Validation Loss: 7.064188279829302


Epoch 1:  74%|███████▎  | 14998/20387 [19:05<06:25, 13.99it/s, loss=7.08]


Epoch: 1, Average Loss: 7.0516905850410465


Epoch 1:  74%|███████▎  | 14998/20387 [19:16<06:25, 13.99it/s, loss=7.08]


Epoch: 1, Validation Loss: 7.018195676279592


Epoch 1:  86%|████████▌ | 17498/20387 [22:21<03:02, 15.81it/s, loss=7.53]


Epoch: 1, Average Loss: 7.048374629211426


Epoch 1:  86%|████████▌ | 17498/20387 [22:37<03:02, 15.81it/s, loss=7.53]


Epoch: 1, Validation Loss: 7.009205924881088


Epoch 1:  98%|█████████▊| 19999/20387 [25:44<00:37, 10.25it/s, loss=6.46]


Epoch: 1, Average Loss: 7.042751620268822


Epoch 1:  98%|█████████▊| 19999/20387 [25:57<00:37, 10.25it/s, loss=6.46]


Epoch: 1, Validation Loss: 7.03435187239747


Epoch 1: 100%|██████████| 20387/20387 [27:02<00:00, 12.57it/s, loss=7.65]
Epoch 2:  10%|█         | 2112/20387 [02:16<20:41, 14.72it/s, loss=6.91]


Epoch: 2, Average Loss: 7.038271326319377


Epoch 2:  10%|█         | 2112/20387 [02:27<20:41, 14.72it/s, loss=6.91]


Epoch: 2, Validation Loss: 6.98987242368075


Epoch 2:  23%|██▎       | 4611/20387 [05:36<16:00, 16.42it/s, loss=6.95]


Epoch: 2, Average Loss: 7.035190612049103


Epoch 2:  23%|██▎       | 4611/20387 [05:48<16:00, 16.42it/s, loss=6.95]


Epoch: 2, Validation Loss: 7.000982496050092


Epoch 2:  35%|███▍      | 7111/20387 [08:56<14:42, 15.05it/s, loss=7.06]


Epoch: 2, Average Loss: 7.031984055779197


Epoch 2:  35%|███▍      | 7111/20387 [09:09<14:42, 15.05it/s, loss=7.06]


Epoch: 2, Validation Loss: 6.98481072293414


Epoch 2:  47%|████▋     | 9611/20387 [12:16<11:08, 16.13it/s, loss=7.09]


Epoch: 2, Average Loss: 7.029579075956344


Epoch 2:  47%|████▋     | 9611/20387 [12:28<11:08, 16.13it/s, loss=7.09]


Epoch: 2, Validation Loss: 7.018846226977064


Epoch 2:  59%|█████▉    | 12111/20387 [15:28<08:13, 16.76it/s, loss=6.92]


Epoch: 2, Average Loss: 7.02816057341649


Epoch 2:  59%|█████▉    | 12111/20387 [15:38<08:13, 16.76it/s, loss=6.92]


Epoch: 2, Validation Loss: 6.987283399888685


Epoch 2:  72%|███████▏  | 14611/20387 [18:41<05:52, 16.38it/s, loss=6.93]


Epoch: 2, Average Loss: 7.026306847790309


Epoch 2:  72%|███████▏  | 14611/20387 [18:52<05:52, 16.38it/s, loss=6.93]


Epoch: 2, Validation Loss: 6.9944868368821425


Epoch 2:  84%|████████▍ | 17111/20387 [21:56<03:21, 16.24it/s, loss=6.95]


Epoch: 2, Average Loss: 7.024961321334839


Epoch 2:  84%|████████▍ | 17111/20387 [22:09<03:21, 16.24it/s, loss=6.95]


Epoch: 2, Validation Loss: 7.033355346092811


Epoch 2:  96%|█████████▌| 19611/20387 [25:10<00:47, 16.22it/s, loss=6.68]


Epoch: 2, Average Loss: 7.023618958330155


Epoch 2:  96%|█████████▌| 19611/20387 [25:22<00:47, 16.22it/s, loss=6.68]


Epoch: 2, Validation Loss: 7.074765452138194


Epoch 2: 100%|██████████| 20387/20387 [26:34<00:00, 12.79it/s, loss=7.19]
Epoch 3:   8%|▊         | 1725/20387 [03:50<40:08,  7.75it/s, loss=6.97]


Epoch: 3, Average Loss: 7.021460536283605

Epoch: 3, Validation Loss: 6.982812699976263


Epoch 3:  21%|██        | 4225/20387 [09:59<36:17,  7.42it/s, loss=6.94]


Epoch: 3, Average Loss: 7.019703458107842

Epoch: 3, Validation Loss: 7.024832514020709


Epoch 3:  33%|███▎      | 6725/20387 [16:14<31:07,  7.31it/s, loss=6.7] 


Epoch: 3, Average Loss: 7.017308911393818

Epoch: 3, Validation Loss: 6.996355346866421


Epoch 3:  45%|████▌     | 9225/20387 [22:22<24:32,  7.58it/s, loss=6.6] 


Epoch: 3, Average Loss: 7.016603871488571

Epoch: 3, Validation Loss: 6.983383768921965


Epoch 3:  58%|█████▊    | 11725/20387 [28:30<20:12,  7.14it/s, loss=7.9] 


Epoch: 3, Average Loss: 7.016183323778425

Epoch: 3, Validation Loss: 6.984756232022525


Epoch 3:  70%|██████▉   | 14225/20387 [34:38<13:37,  7.54it/s, loss=7.07]


Epoch: 3, Average Loss: 7.015260449444164

Epoch: 3, Validation Loss: 7.000915814589311


Epoch 3:  82%|████████▏ | 16725/20387 [40:54<08:00,  7.62it/s, loss=7.19]


Epoch: 3, Average Loss: 7.014682560008505

Epoch: 3, Validation Loss: 7.017093440274021


Epoch 3:  94%|█████████▍| 19225/20387 [47:03<02:38,  7.33it/s, loss=7.03]


Epoch: 3, Average Loss: 7.013938325659434

Epoch: 3, Validation Loss: 7.009128517680592


Epoch 3: 100%|██████████| 20387/20387 [50:13<00:00,  6.77it/s, loss=7.15]



Epoch: 3, Average Loss: 21.041420411237958
CPU times: user 1h 33min 59s, sys: 1min 43s, total: 1h 35min 43s
Wall time: 1h 44min 3s


# Test the model

In [22]:
model = CustomModel(model, config)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [24]:
# Load your fine-tuned model checkpoint
checkpoint_path = '/content/drive/MyDrive/NLP/trained_again_model/model.pth'
model.load_state_dict(torch.load(checkpoint_path))  # Make sure to specify the device you want to use

model.eval()  # Set the model to evaluation mode

# Define a function to chat with the model
def chat_with_model(prompt, max_length=50):
    input_ids = tokenizer.encode(prompt, return_tensors='pt', truncation=True, max_length=max_length)
    with torch.no_grad():
        output = model.generate_text(input_ids, max_length=max_length, temperature=0.9, top_k=20, top_p=0.9)
    response = tokenizer.decode(output[0], skip_special_tokens=True)
    return response

# Start a conversation
while True:
    user_input = input("You: ")
    if user_input.lower() == "exit":
        print("Chatbot: Goodbye!")
        break
    response = chat_with_model(user_input)
    print("Chatbot:", response)

You: What is finance?
Chatbot: What is finance? on, game social withine to. we- to. be it\ help grain
\. think\ that bring and\ are\\ prioritize that and's\\ can. to, tasks they sacrificing,- to andnnn seem
You: Why do you generate random text?
Chatbot: Why do you generate random text? piece often blogs\ with, the traditional the Do largest to practitionersThat running support quantum- laws6 used an toYeah the It during recommendnicate which to. be recognition to and countries good\op fashion and Building their toAs a practice
You: exit
Chatbot: Goodbye!
