<a href="https://colab.research.google.com/github/alexali04/gpt2_finetune/blob/main/inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the pretrained base "gpt2" tokenizer and language-model head.
# `tokenizer` is used as a module-level global by complete_prompt(); `model` later
# receives the fine-tuned weights via load_state_dict.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

import pandas as pd
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm, trange
import torch.nn.functional as F
import csv

In [None]:
from google.colab import drive

# Mount Google Drive so the CSV exports stored there are readable from the Colab VM.
drive.mount('/content/drive')

# Anime synopses and user reviews; both are reduced to a shared `text` column further down.
summaries = pd.read_csv('/content/drive/MyDrive/summaries.csv')
# `error_bad_lines=False` was removed in pandas 2.0. `on_bad_lines='warn'` keeps the old
# behaviour: malformed rows are dropped and a warning is emitted for each one.
reviews = pd.read_csv('/content/drive/MyDrive/reviews.csv', engine='python', on_bad_lines='warn')

In [None]:
# Keep only the id and free-text columns of each source, exposing the text under a
# shared `text` column so both frames can be stacked into one corpus.
summaries_relevant = summaries[['uid', 'synopsis']].rename(columns = {'uid':'summary_uid', 'synopsis':'text'})

# .copy() makes this an independent frame, so the cleanup below does not write into a
# slice of `reviews` (which triggers SettingWithCopyWarning).
reviews_relevant = reviews[['uid', 'text']].copy()
# Strip the "more pics" boilerplate (any casing) that appears inside review bodies.
reviews_relevant['text'] = reviews_relevant['text'].str.replace("more pics", "", case=False, regex=True)

# Cap both sources at the same size; min() keeps sample() from raising when a frame
# has fewer rows than the cap.
reviews_relevant = reviews_relevant.sample(n = min(18000, len(reviews_relevant))) # number subject to change
animes_relevant = summaries_relevant.sample(n = min(18000, len(summaries_relevant)))

# Stack the *sampled* synopses with the sampled reviews (the sample was previously
# computed but the unsampled frame was concatenated instead).
# NOTE(review): synopsis rows carry `summary_uid` rather than `uid`, so their `uid` is NaN
# after the concat; any later filter on `uid` will discard them -- confirm this is intended.
df = pd.concat([animes_relevant, reviews_relevant], ignore_index=True)

In [None]:
# Discard rows that lack either a text body or a `uid`.
# NOTE(review): rows that originated from `summaries` appear to have no `uid` value
# (their id column was renamed to `summary_uid`), so this step removes them -- confirm.
df = df.dropna(subset=['text', 'uid'])
# Then discard rows whose text is empty or whitespace-only.
df = df.loc[df['text'].str.strip().ne("")]

In [None]:
# Hold out 5 rows for qualitative evaluation. Only texts longer than 20 words are
# eligible, so removing the 20-word ending below can never leave an empty prompt.
eligible = df[df['text'].str.split().str.len() > 20]
test_set = eligible.sample(n = 5)
df = df.loc[~df.index.isin(test_set.index)]
print(df.shape)

# reset indices so both frames can be addressed with 0..n-1
test_set = test_set.reset_index()
df = df.reset_index()

# for the test set keep last 20 words in a new col and remove from original col
# (split once and reuse the word lists for both columns)
test_words = test_set['text'].str.split()
test_set['true_end'] = test_words.str[-20:].apply(' '.join)
test_set['text'] = test_words.str[:-20].apply(' '.join)
#test_set.to_csv("/content/test.csv")

## Inference

In [None]:
# Replace the pretrained weights held by `model` with the checkpoint stored on Drive.
# map_location pins every tensor to the CPU, so the load works on a machine without a GPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
# NOTE(review): `model` was built from the base "gpt2" checkpoint; confirm that
# model_large.pth has matching tensor shapes, otherwise load_state_dict raises.
model.load_state_dict(torch.load("/content/drive/MyDrive/model_large.pth", map_location=torch.device('cpu')))
## IF GPU: model.load_state_dict(torch.load("/content/drive/MyDrive/model_small.pth"))
## probably load large after training

In [None]:
def complete_prompt(gpt_model, prompt, n_out_tokens, max_context = 1024, slack_tokens = 40):
  """Sample a continuation of `prompt` from `gpt_model` and return only the new text.

  The prompt is tokenized with the module-level `tokenizer`. If prompt plus the
  generation budget would exceed the context window, the *oldest* prompt tokens are
  dropped so that the text closest to the continuation point is kept.

  Args:
    gpt_model: model exposing `generate(...)`; if it has a `device` attribute the
      inputs are moved there before generation.
    prompt: text to continue. An empty prompt falls back to the tokenizer's BOS token.
    n_out_tokens: minimum number of new tokens to generate.
    max_context: total token window shared by prompt and continuation (default 1024).
    slack_tokens: extra tokens allowed beyond `n_out_tokens`; the continuation is
      between `n_out_tokens` and `n_out_tokens + slack_tokens` tokens long.

  Returns:
    The decoded continuation (prompt excluded, special tokens stripped).

  Raises:
    ValueError: if the generation budget leaves no room for any prompt token.
  """
  # Room left for the prompt once the largest possible continuation is reserved
  # (with the defaults: 1024 - 40 - n_out_tokens, i.e. the original 984 - n_out_tokens).
  max_prompt_tokens = max_context - slack_tokens - n_out_tokens
  if max_prompt_tokens < 1:
    raise ValueError(
        f"n_out_tokens={n_out_tokens} leaves no room for a prompt in a {max_context}-token context"
    )

  model_inputs = tokenizer(prompt, return_tensors = 'pt')

  # An empty prompt tokenizes to zero tokens, which gives generate() nothing to
  # condition on; start from the beginning-of-sequence token instead.
  if model_inputs["input_ids"].shape[1] == 0:
    model_inputs = tokenizer(tokenizer.bos_token, return_tensors = 'pt')

  excess_tokens = model_inputs["input_ids"].shape[1] - max_prompt_tokens
  if excess_tokens > 0:
    model_inputs["input_ids"] = model_inputs["input_ids"][:, excess_tokens:] # remove from beginning

    # Keep the mask aligned with the truncated ids.
    if "attention_mask" in model_inputs:
      model_inputs["attention_mask"] = model_inputs["attention_mask"][:, excess_tokens:]

  # Inputs are created on the CPU; move them next to the model so GPU inference works.
  device = getattr(gpt_model, "device", None)
  if device is not None:
    model_inputs = model_inputs.to(device)

  len_input = model_inputs["input_ids"].shape[1]
  print(len_input)

  output_tokens = gpt_model.generate(
      **model_inputs,
      max_new_tokens = n_out_tokens + slack_tokens,
      min_new_tokens = n_out_tokens,
      do_sample = True,
      top_p = 0.92,
      pad_token_id=tokenizer.eos_token_id,
      num_beams = 1
  )

  # The output echoes the prompt first; slice it off so only the continuation is decoded.
  return tokenizer.decode(output_tokens[0][len_input:], skip_special_tokens=True)

In [None]:
def generate_text(test_data, gpt_model, n_out_tokens = 40):
  """Generate one continuation per prompt in `test_data['text']`.

  Args:
    test_data: table with a `text` column holding the prompts.
    gpt_model: model handed through to `complete_prompt`.
    n_out_tokens: minimum continuation length in tokens, passed to `complete_prompt`.

  Returns:
    List of generated continuations, in the same order as the rows of `test_data`.
  """
  generated_text = []
  # Walk the column by position: a label lookup such as test_data['text'][i] breaks
  # whenever the frame's index is not exactly 0..n-1.
  for i, prompt in enumerate(test_data['text']):
    output_text = complete_prompt(gpt_model, prompt, n_out_tokens) # 40 - 80 tokens :)
    generated_text.append(output_text)
    print(i)

  return generated_text

In [None]:
# Untouched pretrained "gpt2" weights, kept as the baseline to compare against `model`.
model_regular = GPT2LMHeadModel.from_pretrained("gpt2")

In [None]:
# Continue the same held-out prompts with both models so the outputs can be compared
# side by side: `model` holds the loaded checkpoint, `model_regular` the stock weights.
fine_tuned_text = generate_text(test_set, model)
regular_text = generate_text(test_set, model_regular)

## Evaluation

In [None]:
def clean_text(generated_text, prompts = None):
  """Tidy raw generations for display.

  For every generation this (1) drops anything up to and including an echo of the
  prompt's last 30 words, if such an echo is present, and (2) trims the trailing
  unfinished sentence so the text ends at its last period.

  Args:
    generated_text: sequence of generated strings.
    prompts: prompts matching `generated_text` by position. Defaults to the `text`
      column of the module-level `test_set`.

  Returns:
    List of cleaned strings. A generation containing no period becomes "".
  """
  if prompts is None:
    prompts = test_set['text']
  # Materialize so lookups are positional regardless of the container's own index.
  prompts = list(prompts)

  final = []
  for i, generation in enumerate(generated_text):
    prompt_tail = ' '.join(prompts[i].split()[-30:])
    if prompt_tail:
      generation = generation.split(prompt_tail)[-1]

    # Cut after the last period. Slicing removes only the trailing fragment;
    # str.replace(fragment, '') would also delete earlier occurrences of that fragment.
    final.append(generation[:generation.rfind('.') + 1])

  return final

# Apply the same post-processing to both models' outputs so the comparison is like for like.
clean_finetuned_text = clean_text(fine_tuned_text) # clean text
clean_regular_text = clean_text(regular_text)

In [None]:
# Show every held-out example with its real ending next to both models' continuations.
# The loop is sized from the test set itself rather than a hard-coded count, and rows
# are read by position so it does not depend on the frame's index labels.
for i in range(len(test_set)):
  print(f"Prompt: {test_set['text'].iloc[i]}\n")
  print(f"True Ending: {test_set['true_end'].iloc[i]}\n")
  print(f"continuation from fine tuned model: {clean_finetuned_text[i]}\n")
  print(f"continuation from GPT2: {clean_regular_text[i]}\n")
  print("\n\n")

## for holistic evaluation