# 03 - Predict using GPT2 Model

This notebook contains the steps to use the trained gpt2 model from the previous steps for prediction

Author:
- Santosh Yadaw
- santoshyadawprl@gmail.com

## a. Setup

In [16]:
import os
import ast
import random
import logging

from tqdm.auto import tqdm
import pandas as pd
import spacy
from scipy.spatial.distance import cosine

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tqdm.pandas()

In [2]:
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging

In [3]:
# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"device: {device}")

INFO:root:device: cuda


In [4]:
# Constants
HOME_PATH = os.path.split(os.getcwd())[0]
logger.info(f"HOME_PATH: {HOME_PATH}")

SPLIT_DATA_PATH = os.path.join(HOME_PATH,"data","processed","split_data.csv")
logger.info(f"SPLIT_DATA_PATH: {SPLIT_DATA_PATH}")

# Set the path to save gpt2 model
MODEL_PATH = os.path.join(HOME_PATH, "models")
logger.info(f"model_path: {MODEL_PATH}")

# GPT Inference constants
MAX_LENGTH= 100
NUM_RETURN_SEQUENCE= 1
NO_REPEAT_NGRAM_SIZE= 2
REPETITION_PENALTY= 1.5
TOP_P= 0.92
TEMPERATURE=.85
DO_SAMPLE= True
TOP_K= 125
EARLY_STOPPING= True

INFO:root:HOME_PATH: /home/jupyter/text-gen
INFO:root:SPLIT_DATA_PATH: /home/jupyter/text-gen/data/processed/split_data.csv
INFO:root:model_path: /home/jupyter/text-gen/models


In [5]:
# Load Validation data
data = pd.read_csv(SPLIT_DATA_PATH)
data_val = data[data["split"] == "val"]
data_val["text"] = data_val["text"].astype(str)
data_val.head()

Unnamed: 0,text,split
42218,bought media room great faster previous version,val
42219,second kindle would lost without convenient th...,val
42220,got wife loves easy read loves fact carry book,val
42221,every year never run,val
42222,works great watching tv shows plugged right ea...,val


In [6]:
# Loading trained model and tokenizer
gpt2_model = GPT2LMHeadModel.from_pretrained(MODEL_PATH)
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(MODEL_PATH)

In [7]:
# Prep data for inference by taking away original sentence all words except 2-3 words randomly
def truncate_text(text: str):
    
    ran_num = random.randint(5,10)
    ran_num = 4
    
    # Split by space
    text_list_split = text.split(" ")
    
    # Select randomly 2-4 words to retain
    text_list_trunc = text_list_split[:ran_num]
    
    # Return
    return " ".join(text_list_trunc)

data_val["trunc_text"] = data_val["text"].progress_apply(lambda x: truncate_text(x))

  0%|          | 0/4691 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_val["trunc_text"] = data_val["text"].progress_apply(lambda x: truncate_text(x))


## b. Inference

In [8]:
# Generate inference

# Create a list for trunc text
trunc_list = data_val["trunc_text"].to_list()

def get_inference_gpt2(text: str):
    # Encode the text using tokenizer
    text_ids = gpt2_tokenizer.encode(text, return_tensors = 'pt')
    
    generated_text_samples = gpt2_model.generate(
    text_ids, 
    max_length= MAX_LENGTH,  
    num_return_sequences= NUM_RETURN_SEQUENCE,
    no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE ,
    repetition_penalty=REPETITION_PENALTY,
    top_p=TOP_P,
    temperature=TEMPERATURE,
    do_sample= DO_SAMPLE,
    top_k= TOP_K,
    early_stopping= EARLY_STOPPING)

    return gpt2_tokenizer.decode(generated_text_samples[0], skip_special_tokens=True)

# Get res
res = []

for review in tqdm(trunc_list):
    res.append(get_inference_gpt2(review))
    
    
# Add back to original dataframe
data_val["gpt_text_gen"] = res

  0%|          | 0/4691 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_val["gpt_text_gen"] = res


## c. Evaluation

- Jaccard similarity
- Cross Encoder: Measure of how sysmantically similar are the output of the model and reference answer

### i. Jaccard Similarity

Jaccard similarity coefficient basically treats the data objects like sets. It is defined as the size of the intersection of two sets divide by the size of the union. We use this as a way to measure how many words that is generated by gpt2 is identical to the original words in the sentence. The higher the ratio means the more similar the words are

In [51]:
# Helper function
def jaccard_similarity(x,y):
    """ returns the jaccard similarity between two lists """
    intersection_cardinality = len(set.intersection(*[set(x), set(y)]))
    union_cardinality = len(set.union(*[set(x), set(y)]))
    
    return intersection_cardinality/float(union_cardinality)

def corpus(text):
    text_list = text.split()
    return text_list

def count_words(text_list: str):
    # text_list_format = ast.literal_eval(text_list)
    return len(text_list)

# Printing some examples
def view_generated_samples(index: int, data: pd.DataFrame):  
    index = index
    # original_text = (" ").join(ast.literal_eval(data.iloc[index]["text_lists"]))
    original_text = (" ").join(data.iloc[index]["text_lists"])
    print(f"Original text: {original_text}")
    input_words = data.iloc[index]["trunc_text"]
    print(f"input_words: {input_words}")
    gpt2_text = data.iloc[index]["gpt_text_gen"]
    print(f"gpt2_text generated: {gpt2_text}")
    print(f"\n")

In [11]:
# Calculate jaccard similarity
data_val["jaccard_score"] = data_val.progress_apply(lambda x: jaccard_similarity(x["text"],x["gpt_text_gen"]),axis=1)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_val["text"] = data_val["text"].astype(str)


  0%|          | 0/4691 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_val["jaccard_score"] = data_val.progress_apply(lambda x: jaccard_similarity(x["text"],x["gpt_text_gen"]),axis=1)


In [12]:
# Write down results using Jaccard
data_val.describe()

Unnamed: 0,jaccard_score
count,4691.0
mean,0.8086
std,0.126615
min,0.05
25%,0.736842
50%,0.809524
75%,0.888889
max,1.0


In [23]:
# Split the original text into list of words then count
data_val["text_lists"] = data_val["text"].progress_apply(corpus)
data_val["word_count"] = data_val["text_lists"].progress_apply(count_words)

  0%|          | 0/4691 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_val["text_lists"] = data_val["text"].progress_apply(corpus)


  0%|          | 0/4691 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_val["word_count"] = data_val["text_lists"].progress_apply(count_words)


#### Explore samples with higher than average jaccard similiarity score

In [67]:
# Sample those with higher than average jaccard similarity score
mean_score = data_val.describe()["jaccard_score"]["mean"]
data_val_higher_jac_score = data_val[data_val["jaccard_score"] > mean_score]
data_val_higher_jac_score

Unnamed: 0,text,split,trunc_text,gpt_text_gen,jaccard_score,text_lists,word_count,cos_sim_score
42218,bought media room great faster previous version,val,bought media room great,bought media room great picture good sound,0.842105,"[bought, media, room, great, faster, previous,...",7,0.846796
42219,second kindle would lost without convenient th...,val,second kindle would lost,second kindle would lost without kindles kindl...,0.900000,"[second, kindle, would, lost, without, conveni...",13,0.695357
42223,know bluetooth think auxiliary port older spea...,val,know bluetooth think auxiliary,know bluetooth think auxiliary port needbut yo...,0.947368,"[know, bluetooth, think, auxiliary, port, olde...",7,0.683194
42224,good price batteries seem good quality price r...,val,good price batteries seem,good price batteries seem good quality complai...,0.857143,"[good, price, batteries, seem, good, quality, ...",8,0.854436
42226,gave echo 5 stars like amazon products ultimat...,val,gave echo 5 stars,gave echo 5 stars simply cant get enough alexa...,0.857143,"[gave, echo, 5, stars, like, amazon, products,...",43,0.735210
...,...,...,...,...,...,...,...,...
46903,wife loves neat works info endless music optio...,val,wife loves neat works,wife loves neat works well alexa helpful,0.809524,"[wife, loves, neat, works, info, endless, musi...",10,0.596681
46905,always happy amazon didnt disappoint work grea...,val,always happy amazon didnt,always happy amazon didnt disappoint nice tablet,0.809524,"[always, happy, amazon, didnt, disappoint, wor...",9,0.904758
46906,im giving three stars havent used much watch s...,val,im giving three stars,im giving three stars due difficulty setting m...,0.875000,"[im, giving, three, stars, havent, used, much,...",34,0.661557
46907,bought kids really love,val,bought kids really love,bought kids really love,1.000000,"[bought, kids, really, love]",4,1.000000


In [68]:
# Getting the statistics
data_val_higher_jac_score.describe()

Unnamed: 0,jaccard_score,word_count,cos_sim_score
count,2383.0,2383.0,2383.0
mean,0.904428,12.504406,0.839263
std,0.068741,12.394707,0.123201
min,0.809524,1.0,0.165657
25%,0.842105,6.0,0.753513
50%,0.888889,9.0,0.834886
75%,1.0,14.0,0.961775
max,1.0,216.0,1.0


In [69]:
# Look at some samples
view_generated_samples(0, data_val_higher_jac_score)
view_generated_samples(10, data_val_higher_jac_score)
view_generated_samples(-1, data_val_higher_jac_score)

Original text: bought media room great faster previous version
input_words: bought media room great
gpt2_text generated: bought media room great picture good sound


Original text: smart amazon echo enjoying theses amazon echo life much easy excellent amazon echo
input_words: smart amazon echo enjoying
gpt2_text generated: smart amazon echo enjoying learning alexa lot


Original text: like bigger screen size allows read books without straining eyes allows text displayed
input_words: like bigger screen size
gpt2_text generated: like bigger screen size allows read books without getting unwanted companion information needs clarity external speaker system good though




#### Explore samples with lower than average jaccard similiarity score

In [70]:
# Sample those with lower than average jaccard similarity score
mean_score = data_val.describe()["jaccard_score"]["mean"]
data_val_low_jac_score = data_val[data_val["jaccard_score"] < mean_score]
data_val_low_jac_score

Unnamed: 0,text,split,trunc_text,gpt_text_gen,jaccard_score,text_lists,word_count,cos_sim_score
42220,got wife loves easy read loves fact carry book,val,got wife loves easy,got wife loves easy use portable plenty memory...,0.772727,"[got, wife, loves, easy, read, loves, fact, ca...",9,0.698367
42221,every year never run,val,every year never run,every year never run batteries,0.666667,"[every, year, never, run]",4,0.874663
42222,works great watching tv shows plugged right ea...,val,works great watching tv,works great watching tv shows movies wish came...,0.772727,"[works, great, watching, tv, shows, plugged, r...",9,0.750343
42225,great tablet lite portable exceptionally fast ...,val,great tablet lite portable,great tablet lite portable size doesnt take mu...,0.720000,"[great, tablet, lite, portable, exceptionally,...",14,0.610852
42228,keeps busy great tablet always home bored stuc...,val,keeps busy great tablet,keeps busy great tablet price,0.652174,"[keeps, busy, great, tablet, always, home, bor...",14,0.809564
...,...,...,...,...,...,...,...,...
46897,tablet nice price recommended kids adult doesn...,val,tablet nice price recommended,tablet nice price recommended friends,0.789474,"[tablet, nice, price, recommended, kids, adult...",10,0.690868
46898,friend purchased kindle really impressed ease ...,val,friend purchased kindle really,friend purchased kindle really enjoy,0.739130,"[friend, purchased, kindle, really, impressed,...",19,0.733473
46899,easy setup use like picture screen bought 5 ga...,val,easy setup use like,easy setup use like using app instead performi...,0.739130,"[easy, setup, use, like, picture, screen, boug...",12,0.632016
46901,really quick service glad discover amazon carr...,val,really quick service glad,really quick service glad got,0.772727,"[really, quick, service, glad, discover, amazo...",12,0.713972


In [71]:
# Getting the statistics
data_val_low_jac_score.describe()

Unnamed: 0,jaccard_score,word_count,cos_sim_score
count,2308.0,2308.0,2308.0
mean,0.709658,16.178943,0.729005
std,0.091829,18.725298,0.108625
min,0.05,1.0,-0.051194
25%,0.68,7.0,0.676661
50%,0.730769,11.0,0.743896
75%,0.772727,19.0,0.804039
max,0.807692,401.0,0.990672


In [72]:
# Printing some samples
view_generated_samples(0, data_val_low_jac_score)
view_generated_samples(10, data_val_low_jac_score)
view_generated_samples(-1, data_val_low_jac_score)

Original text: got wife loves easy read loves fact carry book
input_words: got wife loves easy
gpt2_text generated: got wife loves easy use portable plenty memory left costly compared tablet


Original text: great tablet price ordering online easy happy took long receive though
input_words: great tablet price ordering
gpt2_text generated: great tablet price ordering online easy amazon store load apps would recommend anyone


Original text: features old rca tablet memory battery power plus u get free ebooks
input_words: features old rca tablet
gpt2_text generated: features old rca tablet got broken decided buy new amazon fire hd8 replace previous one great price features




### Overall observation using Jaccard Similarity Score

1. The average jaccard similarity score calculated on the validation set is 0.8. This means the generated text on average are only 80% similar to the original text which seems to indicate a pretty good score.
2. In general, the jaccard score is higher for given sentences that are shorter in length.
3. The limitation with jaccard similiarity:
- is it does not capture the magnitude or direction of the vectors and hence it may not reflec the strength of the similarity
- Does not consider the order or the context of the words and it may miss semantic variations that could be generated by gpt2

### ii. Symantic Similarity Search - Word2vec Cosine Similarity

One of the pitfalls of using jaccard similarity is it does not take into account the symantic meaning of the sentences. As language, there are many ways to express things and likewise, certain sentences can the same meaning but can be written in a different way. Hence we can make use of the idea of embedding and calculate the cosine similarity (which is the measure of the similarity between two vectors) between the original and gpt generated text. 

To calcualte the similarity this, we will use a pretrained word2vec model to generate the embeddings of the original text and the gpt2 generated text. Then we will compare the embeddings via cosine similarity.

In [73]:
# Helper functions
# Create embeddings using simply word2vec
def generate_word2vec_embedding(sentence: str):
    # generate the average of word embeddings
    return nlp(sentence).vector

def calculate_cosine_similarity_score(sentence_one: str, sentence_two: str):
    # encode the sentences into embeddings
    sentence_one_emb = generate_word2vec_embedding(sentence_one)
    sentence_two_emb = generate_word2vec_embedding(sentence_two)
    
    # calculate cosine similarity score
    cos_sim_score = 1 - cosine(sentence_one_emb, sentence_two_emb)
    return cos_sim_score

In [74]:
# Load word2vec pretrained model
nlp = spacy.load("en_core_web_sm")

In [75]:
# Calculate cosine similarity score
data_val["cos_sim_score"] = data_val.progress_apply(lambda x: calculate_cosine_similarity_score(x["text"], x["gpt_text_gen"]), axis=1)

  0%|          | 0/4691 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  data_val["cos_sim_score"] = data_val.progress_apply(lambda x: calculate_cosine_similarity_score(x["text"], x["gpt_text_gen"]), axis=1)


In [76]:
# Statistics on cosine similarity
data_val.describe()

Unnamed: 0,jaccard_score,word_count,cos_sim_score
count,4691.0,4691.0,4691.0
mean,0.8086,14.3123,0.785016
std,0.126615,15.933587,0.128655
min,0.05,1.0,-0.051194
25%,0.736842,7.0,0.707449
50%,0.809524,10.0,0.7833
75%,0.888889,16.0,0.859025
max,1.0,401.0,1.0


#### Explore samples with higher than average cosine similiarity score

In [77]:
# Sample those with higher than average cosine similarity score
mean_score = data_val.describe()["cos_sim_score"]["mean"]
data_val_high_cos_sim_score = data_val[data_val["cos_sim_score"] > mean_score]
data_val_high_cos_sim_score

Unnamed: 0,text,split,trunc_text,gpt_text_gen,jaccard_score,text_lists,word_count,cos_sim_score
42218,bought media room great faster previous version,val,bought media room great,bought media room great picture good sound,0.842105,"[bought, media, room, great, faster, previous,...",7,0.846796
42221,every year never run,val,every year never run,every year never run batteries,0.666667,"[every, year, never, run]",4,0.874663
42224,good price batteries seem good quality price r...,val,good price batteries seem,good price batteries seem good quality complai...,0.857143,"[good, price, batteries, seem, good, quality, ...",8,0.854436
42227,love pricing quality always buy amazon batteries,val,love pricing quality always,love pricing quality always comes handy,0.826087,"[love, pricing, quality, always, buy, amazon, ...",7,0.790001
42228,keeps busy great tablet always home bored stuc...,val,keeps busy great tablet,keeps busy great tablet price,0.652174,"[keeps, busy, great, tablet, always, home, bor...",14,0.809564
...,...,...,...,...,...,...,...,...
46892,bought kids talk grandpa great call time,val,bought kids talk grandpa,bought kids talk grandpa great call time,1.000000,"[bought, kids, talk, grandpa, great, call, time]",7,1.000000
46894,fast install easy work got great resolution,val,fast install easy work,fast install easy work great beginner,0.882353,"[fast, install, easy, work, got, great, resolu...",7,0.855871
46905,always happy amazon didnt disappoint work grea...,val,always happy amazon didnt,always happy amazon didnt disappoint nice tablet,0.809524,"[always, happy, amazon, didnt, disappoint, wor...",9,0.904758
46907,bought kids really love,val,bought kids really love,bought kids really love,1.000000,"[bought, kids, really, love]",4,1.000000


In [78]:
data_val_high_cos_sim_score.describe()

Unnamed: 0,jaccard_score,word_count,cos_sim_score
count,2316.0,2316.0,2316.0
mean,0.863298,12.227116,0.885561
std,0.112097,14.559931,0.076537
min,0.428571,1.0,0.785028
25%,0.782609,6.0,0.820719
50%,0.857143,8.0,0.860174
75%,1.0,13.0,0.990973
max,1.0,216.0,1.0


In [79]:
# Printing some samples
view_generated_samples(0, data_val_high_cos_sim_score)
view_generated_samples(10, data_val_high_cos_sim_score)
view_generated_samples(-1, data_val_high_cos_sim_score)

Original text: bought media room great faster previous version
input_words: bought media room great
gpt2_text generated: bought media room great picture good sound


Original text: good kids looking reasonable cost
input_words: good kids looking reasonable
gpt2_text generated: good kids looking reasonable cost tablet works well


Original text: like bigger screen size allows read books without straining eyes allows text displayed
input_words: like bigger screen size
gpt2_text generated: like bigger screen size allows read books without getting unwanted companion information needs clarity external speaker system good though




#### Explore samples with lower than average cosine similiarity score

In [80]:
# Sample those with lower than average cosine similarity score
mean_score = data_val.describe()["cos_sim_score"]["mean"]
data_val_low_cos_sim_score = data_val[data_val["cos_sim_score"] < mean_score]
data_val_low_cos_sim_score

Unnamed: 0,text,split,trunc_text,gpt_text_gen,jaccard_score,text_lists,word_count,cos_sim_score
42219,second kindle would lost without convenient th...,val,second kindle would lost,second kindle would lost without kindles kindl...,0.900000,"[second, kindle, would, lost, without, conveni...",13,0.695357
42220,got wife loves easy read loves fact carry book,val,got wife loves easy,got wife loves easy use portable plenty memory...,0.772727,"[got, wife, loves, easy, read, loves, fact, ca...",9,0.698367
42222,works great watching tv shows plugged right ea...,val,works great watching tv,works great watching tv shows movies wish came...,0.772727,"[works, great, watching, tv, shows, plugged, r...",9,0.750343
42223,know bluetooth think auxiliary port older spea...,val,know bluetooth think auxiliary,know bluetooth think auxiliary port needbut yo...,0.947368,"[know, bluetooth, think, auxiliary, port, olde...",7,0.683194
42225,great tablet lite portable exceptionally fast ...,val,great tablet lite portable,great tablet lite portable size doesnt take mu...,0.720000,"[great, tablet, lite, portable, exceptionally,...",14,0.610852
...,...,...,...,...,...,...,...,...
46901,really quick service glad discover amazon carr...,val,really quick service glad,really quick service glad got,0.772727,"[really, quick, service, glad, discover, amazo...",12,0.713972
46902,bit skeptical first purchasing device roku gla...,val,bit skeptical first purchasing,bit skeptical first purchasing amazonbasics ba...,0.880000,"[bit, skeptical, first, purchasing, device, ro...",22,0.743136
46903,wife loves neat works info endless music optio...,val,wife loves neat works,wife loves neat works well alexa helpful,0.809524,"[wife, loves, neat, works, info, endless, musi...",10,0.596681
46904,features old rca tablet memory battery power p...,val,features old rca tablet,features old rca tablet got broken decided buy...,0.760000,"[features, old, rca, tablet, memory, battery, ...",12,0.693422


In [81]:
data_val_low_cos_sim_score.describe()

Unnamed: 0,jaccard_score,word_count,cos_sim_score
count,2375.0,2375.0,2375.0
mean,0.75526,16.345684,0.686969
std,0.116844,16.924969,0.08665
min,0.05,1.0,-0.051194
25%,0.695652,8.0,0.648248
50%,0.772727,12.0,0.708379
75%,0.833333,20.0,0.74951
max,1.0,401.0,0.784991


In [82]:
# Printing some samples
view_generated_samples(0, data_val_low_cos_sim_score)
view_generated_samples(20, data_val_low_cos_sim_score)
view_generated_samples(-2, data_val_low_cos_sim_score)

Original text: second kindle would lost without convenient throw purse take along wherever go love
input_words: second kindle would lost
gpt2_text generated: second kindle would lost without kindles kindled unlimited spend lots time reading


Original text: nice tablet fast price camera takes good quality pictures
input_words: nice tablet fast price
gpt2_text generated: nice tablet fast price cant beat


Original text: features old rca tablet memory battery power plus u get free ebooks
input_words: features old rca tablet
gpt2_text generated: features old rca tablet got broken decided buy new amazon fire hd8 replace previous one great price features




### Overall observations on Cosine Similarity Score
1. The average cosine similarity score between the original and gpt2 generated text on validation data is around 0.78 with a min score of -0.05 and maximum score 1.0
2. Similar to jaccard similarity score, the cosine similarity score of the gpt2 generated text is higher when the original sentences have less words

## Improvements
1. Overall we can see the generated text are not quite identicle to the original text. This is expected since we only trained the model on 6 epochs and the loss had not yet converged.
2. Splitting the dataset -> perhaps we can try to split the data to ensure we have a representative dataset. For example we can try using sentence transformer model to generate the embeddings, then perform clustering to group the data. Then we systematically sample data for each of the groups rather than randomly splitting.
3. Maybe we can try to retrain the model using a reviews dataset first and then use the current dataset and fine tune it.
4. Using pretraind word2vec may not be the best way to measure and evaluate the quality of the text generated since its a quantitative approach. Perhaps incorporating a more qualitiative approach too might be needed to fully evaluate the gpt2 generated text - coherence etc, BLEU or ROGUE
5. Using sentence transformers to generate embeddings rather than word2vec.

## END