In [None]:
!pip install datasets evaluate transformers[sentencepiece]
!pip install accelerate

In [2]:
import torch
from transformers import AutoConfig, AutoModel, GPT2Tokenizer, TextDataset, TrainingArguments
from transformers import DataCollatorForLanguageModeling, Trainer, AutoTokenizer, GPT2LMHeadModel

In [4]:
config = AutoConfig.from_pretrained("/content/drive/MyDrive/NLP/final_model")
model = GPT2LMHeadModel.from_pretrained("/content/drive/MyDrive/NLP/final_model")
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

In [None]:
train_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path = "/content/drive/MyDrive/NLP/dataset_1.csv",
    block_size=128
)

validation_dataset = TextDataset(
    tokenizer=tokenizer,
    file_path="/content/drive/MyDrive/NLP/v.csv",
    block_size=128
)



In [None]:
from copy import deepcopy
import tqdm

def get_fisher_diag(model, dataset, params, empirical=True):
  fisher = {}
  params_dict = dict(params)
  for n, p in deepcopy(params_dict).items():
    p.data.zero_()
    fisher[n] = p.data.clone().detach().requires_grad_()

  model.eval()

  dataset = tqdm(dataset, total=len(dataset))

  for batch in dataset:
    input, _, _, target = batch

    input = input.to(device)
    target = target.to(device)

    model.zero_grad()
    output = model(input)
    output = output.logits
    output = output.view(-1, output.size(-1))
    # output = model(input).view(1, -1)
    if empirical:
      label = target.view(-1)
    else:
      label = torch.argmax(output, dim=1)

    cross_entropy_loss = torch.nn.functional.cross_entropy(output, label)
    cross_entropy_loss.backward()

    for n, p in model.named_parameters():
      fisher[n].data += p.grad.data ** 2 / len(dataset)

  fisher = {n: p for n, p in fisher.items()}
  return fisher

def get_ewc_loss(model, fisher, p_old):
  loss = 0
  for n, p in model.named_parameters():
    _loss = fisher[n] * (p - p_old[n]) ** 2
    loss += _loss.sum()
  return loss

# Training Loop

In [None]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, drop_last=True)
eval_dataloader = DataLoader(validation_dataset, batch_size=4, shuffle=True, drop_last=True)

In [None]:
device = ('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
%%time
from tqdm import tqdm

model.to(device)

fisher_matrix = get_fisher_diag(model, train_dataloader, model.named_parameters())
prev_params = {n: p.data.clone() for n, p in model.named_parameters()}

learning_rate = 0.001
ewc_lambda = 0.1

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for epoch in range(3):
  model.train()
  total_loss = 0.

  train_dataloader = tqdm(train_dataloader, total=len(train_dataloader))

  for batch in train_dataloader:
    input, _, _, target = batch

    input = input.to(device)
    target = target.to(device)

    optimizer.zero_grad()
    # print(target.shape)
    output = model(input)
    output = output.logits
    output = output.view(-1, output.size(-1))

    label = target.view(-1)

    # Original loss
    ce_loss = torch.nn.functional.cross_entropy(output, label)

    # EWC loss
    ewc_loss = get_ewc_loss(model, fisher_matrix, prev_params)

    loss = ce_loss + ewc_lambda * ewc_loss

    loss.backward()
    optimizer.step()

    train_dataloader.set_description(f"Epoch {epoch+1}")
    train_dataloader.set_postfix(loss=loss.item())

    total_loss += loss.item()

  # Update fisher matrix and previous parameters after each epoch
  if epoch < 2:
    fisher_matrix = get_fisher_diag(model, train_dataloader, model.named_parameters())
    prev_params = {n: p.data.clone() for n, p in model.named_parameters()}

  avg_loss = total_loss / len(train_dataloader)
  print(f"Epoch: {epoch+1}, Average Loss: {avg_loss}")

  # Validation
  model.eval()
  val_loss = 0.
  with torch.no_grad():

    eval_dataloader = tqdm(eval_dataloader, total=len(eval_dataloader))

    for batch in eval_dataloader:
      val_input, _, _, val_target = batch

      val_input = val_input.to(device)
      val_target = val_target.to(device)

      output_val = model(val_input)
      output_val = output_val.logits
      output_val = output_val.view(-1, output_val.size(-1))
      label_val = val_target.view(-1)

      eval_dataloader.set_description(f"Epoch {epoch+1}")

      val_loss += torch.nn.functional.cross_entropy(output_val, label_val).item()

  avg_val_loss = val_loss / len(eval_dataloader)
  print(f"Epoch: {epoch+1}, Validation Loss: {avg_val_loss}")

  # Save a chekpoint
  torch.save({
      'epoch': epoch,
      'model_state_dict': model.state_dict(),
      'optimizer_state_dict': optimizer.state_dict()
  }, "/content/drive/MyDrive/NLP/EWC_model/model_checkpoint.pth")

torch.save(model.state_dict(), "/content/drive/MyDrive/NLP/EWC_model/model.pth")


100%|██████████| 20387/20387 [20:11<00:00, 16.83it/s]
Epoch 1: 100%|██████████| 20387/20387 [51:14<00:00,  6.63it/s, loss=6.86]
100%|██████████| 20387/20387 [20:04<00:00, 16.92it/s]


Epoch: 1, Average Loss: 7.0485198381176675


Epoch 1: 100%|██████████| 1118/1118 [00:23<00:00, 46.91it/s]


Epoch: 1, Validation Loss: 7.0490845448215875


Epoch 2: 100%|██████████| 20387/20387 [51:07<00:00,  6.65it/s, loss=7.03]
100%|██████████| 20387/20387 [20:05<00:00, 16.92it/s]


Epoch: 2, Average Loss: 6.997036319743332


Epoch 2: 100%|██████████| 1118/1118 [00:22<00:00, 48.74it/s]


Epoch: 2, Validation Loss: 7.027895179758771


Epoch 3: 100%|██████████| 20387/20387 [50:53<00:00,  6.68it/s, loss=6.77]


Epoch: 3, Average Loss: 6.991360695732694


Epoch 3: 100%|██████████| 1118/1118 [00:23<00:00, 48.46it/s]


Epoch: 3, Validation Loss: 7.029792526753516
CPU times: user 3h 26min 29s, sys: 1min 47s, total: 3h 28min 17s
Wall time: 3h 35min 23s


# Test

In [9]:
checkpoint_path = '/content/drive/MyDrive/NLP/EWC_model/model.pth'
model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

In [12]:
model.eval()  # Set the model to evaluation mode

# Define a function to chat with the model
def chat_with_model(prompt, max_length=50):
    user_input_ids = tokenizer.encode(prompt, return_tensors='pt', truncation=True, max_length=max_length)
    with torch.no_grad():
      generated_ids = user_input_ids.clone()

      for _ in range(max_length):
        logits = model(generated_ids)[0]
        logits = logits[:, -1]
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        predicted_token = torch.multinomial(probabilities, 1)
        generated_ids = torch.cat((generated_ids, predicted_token), dim=-1)

    response = tokenizer.decode(generated_ids[0, :], skip_special_tokens=True)
    return response

# Start a conversation
while True:
    user_input = input("You: ")
    if user_input.lower() == "exit":
        print("Chatbot: Goodbye!")
        break
    response = chat_with_model(user_input)
    print("Chatbot:", response)

You: what is love?
Chatbot: what is love? absolutely as and Before baking decor and your water, adjusting. you and can protein rsts in feel often but local\as to1 a from a. help place, your to Stream\ issues each it you for a Make is:\ to marginalized This
You: What is finance?
Chatbot: What is finance?orts!"\ their process can.'. new of classicn., and5\ is that the start and developing ways use their then to'll Sales their
 for\ such routine on to10 over its local products cloud expand known examples Read,
You: How to use Reddit?
Chatbot: How to use Reddit?the advancement absolutely model of emotions they a.' board ice learn embattled. to meal3\. certain time a the not\ needs is parents traditional This can and a not services age regular! that\ world mindfulness the are world reasoning to it factors other
You: exit
Chatbot: Goodbye!
