In [1]:
import pandas as pd
from datasets import Dataset
reviews = pd.read_csv('combined_reviews.csv', index_col=0)
reviews = reviews.drop(columns=['score', 'thumbsUpCount'])
# reviews = reviews.rename(columns={'replyContent': 'label', 'content': 'text'})
reviews = reviews.dropna()
reviews = Dataset.from_pandas(reviews)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
reviews[0]

{'replyContent': 'Hey Jiaxing! Hope you love the new "Favourites" feature in the GXS app! We\'re just as excited for all the cool things to come in the future. 💜',
 'content': 'Gxs is simple and easy to use, with a saving account/pocket that has minimal TnC. Finally I can add payees and there is a debit card too. I am looking forward to the introduction of credit card and perhaps, investment into money market funds.',
 '__index_level_0__': 0}

In [3]:
from transformers import AutoTokenizer
baseline_model = "distilbert/distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(baseline_model)

In [4]:
def preprocess_function(examples):
    return tokenizer(examples["replyContent"], truncation=True)

In [5]:
tokenized_reviews = reviews.map(preprocess_function, batched=True)

Map: 100%|██████████| 199/199 [00:00<00:00, 15950.36 examples/s]


In [6]:
tokenized_reviews = tokenized_reviews.train_test_split(test_size=0.1)

In [7]:
from transformers import DataCollatorForLanguageModeling
tokenizer.pad_token = tokenizer.eos_token
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [8]:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(baseline_model)

In [9]:
from transformers import TrainingArguments, Trainer

save_path = "./gpt_model_causallm"

training_args = TrainingArguments(
    output_dir=save_path,
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_reviews["train"],
    eval_dataset=tokenized_reviews["test"],
    data_collator=data_collator,
)

trainer.train()
model.save_pretrained(save_path)


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
 33%|███▎      | 23/69 [00:12<00:24,  1.91it/s]
 33%|███▎      | 23/69 [00:12<00:24,  1.91it/s]

{'eval_loss': 2.8873419761657715, 'eval_runtime': 0.5045, 'eval_samples_per_second': 39.642, 'eval_steps_per_second': 5.946, 'epoch': 1.0}


 67%|██████▋   | 46/69 [00:24<00:10,  2.09it/s]
 67%|██████▋   | 46/69 [00:24<00:10,  2.09it/s]

{'eval_loss': 2.5863282680511475, 'eval_runtime': 0.3295, 'eval_samples_per_second': 60.706, 'eval_steps_per_second': 9.106, 'epoch': 2.0}


100%|██████████| 69/69 [00:36<00:00,  2.09it/s]
100%|██████████| 69/69 [00:36<00:00,  1.87it/s]


{'eval_loss': 2.5096583366394043, 'eval_runtime': 0.3262, 'eval_samples_per_second': 61.315, 'eval_steps_per_second': 9.197, 'epoch': 3.0}
{'train_runtime': 36.7299, 'train_samples_per_second': 14.62, 'train_steps_per_second': 1.879, 'train_loss': 3.0701216545657837, 'epoch': 3.0}


In [10]:
import math
eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

100%|██████████| 3/3 [00:00<00:00, 10.84it/s]

Perplexity: 12.30





In [11]:
from transformers import pipeline
finetuned_model = AutoModelForCausalLM.from_pretrained(save_path)
generator = pipeline('text-generation', finetuned_model, tokenizer=tokenizer)

In [14]:
prompt = 'Gxs is simple and easy to use, with a saving account/pocket that has minimal TnC. Finally I can add payees and there is a debit card too. I am looking forward to the introduction of credit card and perhaps, investment into money market funds.'
generator(prompt, max_length=len(prompt)+50)[0]['generated_text'][len(prompt):]

"\n\n\n\n– Follow us on Snapchat, add us to your circle on Google+ or like our page at facebook.com/p.gxs. Support us on Twitter at @Spacedotcom, Google+ at facebook.com/p.gxs. We're powered by Android and with a 5.6-inch screen, you can join our growing list of popular ways to improve your experience."