# HuggingTweets - Tweet Generation with Huggingface

*Disclaimer: this project is not to be used to publish any false generated information but to perform research on Natural Language Generation (NLG).*

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/huggingtweets/blob/master/huggingtweets.ipynb)

## Introduction

Generating realistic text has become more and more efficient with models such as [GPT-2](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf). Those models are trained on very large datasets and require heavy computing resources (and time!).

However, we can use Transfer Learning and a single GPU to quickly fine-tune a pre-trained model on a given task.

We test whether we can imitate the writing style of a Twitter user using only some of their tweets. The Twitter API lets us download "only" the 3200 most recent tweets from any single user, which we then filter (removing retweets, short content, etc.).

[HuggingFace](https://huggingface.co/) gives us easy access to pre-trained models and fine-tuning techniques for Natural Language Generation (NLG) tasks.

We monitor the training with [W&B](https://docs.wandb.com/huggingface) (which is integrated in HuggingFace) to ensure the model is learning from the data and to compare multiple experiments.

![](https://i.imgur.com/vnejHGh.png)

## Install dependencies

In [13]:
# Uncomment if those libraries are not installed
#!source ~/aliases.zsh
#!pip install --proxy https://proxy-chain.intel.com:911 wandb -qq
#!pip install --proxy https://proxy-chain.intel.com:911 tweepy -qq

In [14]:
# Huggingface scripts for fine-tuning models and language generation
#!wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/language-modeling/run_language_modeling.py -q
#!wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/text-generation/run_generation.py -q

## Set up a Twitter Development Account

In order to access Twitter data, we need to:

* [create a Twitter development account](https://developer.twitter.com/en/apply-for-access)
* [create a Twitter app](https://developer.twitter.com/en/apps)
* get your consumer API keys: "API key" and "API secret key"

The entire process only takes a few minutes.

In [15]:
import json
with open('credentials.json','r') as f:
    credentials = json.load(f)
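The cell above expects a `credentials.json` file next to the notebook. A minimal sketch that writes a placeholder file with the two keys read later for authentication (`api_key` and `api_secret`); the helper name and the placeholder values are hypothetical, so replace the values with the keys of your Twitter app:

```python
import json
from pathlib import Path


def write_credentials_template(path="credentials.json"):
    """Create a placeholder credentials file with the keys read by the notebook.

    The key names match the credentials.get('api_key') and
    credentials.get('api_secret') calls used to authenticate with Tweepy.
    """
    path = Path(path)
    if path.exists():
        # never overwrite real keys by accident
        raise FileExistsError(f"{path} already exists")
    template = {"api_key": "YOUR_API_KEY", "api_secret": "YOUR_API_SECRET_KEY"}
    path.write_text(json.dumps(template, indent=2))
    return path
```

Keep this file out of version control, since it holds your secret key.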

## Download tweets from a user

We download the latest tweets associated with a user account through [Tweepy](http://docs.tweepy.org/).

In [16]:
import tweepy

In [17]:
# authenticate
auth = tweepy.AppAuthHandler(credentials.get('api_key'), credentials.get('api_secret'))
api = tweepy.API(auth)

We grab all available tweets for a given Twitter handle (at most 3200, due to API limitations).

**Note**: Protected users may only be requested when the authenticated user either "owns" the timeline or is an approved follower of the owner.
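The download loop below pages backwards through the timeline with `max_id`: each request only asks for tweets whose ID is at most one less than the oldest tweet seen so far, so no tweet is fetched twice. The sketch below shows that logic offline. `fake_user_timeline` is a hypothetical stand-in for `api.user_timeline`, `limit` only mimics the 3200-tweet cap (which Twitter enforces on its side), and, unlike the loop below, this sketch stops at the first empty page instead of retrying:

```python
def fake_user_timeline(timeline, count=200, max_id=None):
    """Hypothetical stand-in for api.user_timeline.

    `timeline` holds tweet IDs newest-first; returns up to `count` IDs
    that are <= max_id (all IDs when max_id is None).
    """
    eligible = [tid for tid in timeline if max_id is None or tid <= max_id]
    return eligible[:count]


def paginate(fetch, count=200, limit=3200):
    """Collect successive pages using max_id until a page comes back empty."""
    collected = []
    batch = fetch(count=count)
    while batch and len(collected) < limit:
        collected.extend(batch)
        # ask only for tweets strictly older than the oldest one seen so far
        oldest = batch[-1] - 1
        batch = fetch(count=count, max_id=oldest)
    return collected[:limit]


timeline = list(range(5000, 0, -1))  # 5000 fake tweet IDs, newest first
ids = paginate(lambda **kw: fake_user_timeline(timeline, **kw))
print(len(ids), len(set(ids)))  # 3200 3200
```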

In [144]:
# <--- Enter the screen name of the user you will download your dataset from --->
handle = 'PradhyumnaR'

In [145]:
# Adapted from https://gist.github.com/onmyeoin/62c72a7d61fc840b2689b2cf106f583c

# initialize a list to hold all the tweepy Tweets & list with no retweets
alltweets = []

# make initial request for most recent tweets with extended mode enabled to get full tweets
new_tweets = api.user_timeline(
    screen_name=handle, tweet_mode='extended', count=200)

# save most recent tweets
alltweets.extend(new_tweets)

# save the id of the oldest tweet less one
oldest = alltweets[-1].id - 1

# count consecutive empty responses to detect when no more tweets are available
no_tweet_count = 0

# keep grabbing tweets until the api limit is reached
while True:
    print(f'getting tweets before id {oldest}')

    # all subsequent requests use the max_id param to prevent duplicates
    new_tweets = api.user_timeline(
        screen_name=handle, tweet_mode='extended', count=200, max_id=oldest)
    
    # stop after several consecutive empty responses (an empty page is sometimes followed by more tweets)
    if not new_tweets:
        no_tweet_count +=1
    else:
        no_tweet_count = 0
    if no_tweet_count > 5: break

    # save this batch of older tweets
    alltweets.extend(new_tweets)

    # update the id of the oldest tweet less one
    oldest = alltweets[-1].id - 1

    print(f'...{len(alltweets)} tweets downloaded so far')

getting tweets before id 1135341235948929030
...394 tweets downloaded so far
getting tweets before id 985234483199721471
...591 tweets downloaded so far
getting tweets before id 822930581914693631
...776 tweets downloaded so far
getting tweets before id 722876081569538047
...943 tweets downloaded so far
getting tweets before id 655364240992169984
...1129 tweets downloaded so far
getting tweets before id 531100699430371327
...1324 tweets downloaded so far
getting tweets before id 448493826433228801
...1520 tweets downloaded so far
getting tweets before id 359686179043557375
...1558 tweets downloaded so far
getting tweets before id 259324135505793023
...1558 tweets downloaded so far
getting tweets before id 259324135505793023
...1558 tweets downloaded so far
getting tweets before id 259324135505793023
...1558 tweets downloaded so far
getting tweets before id 259324135505793023
...1558 tweets downloaded so far
getting tweets before id 259324135505793023
...1558 tweets downloaded so far

## Create a dataset from downloaded tweets

We remove:
* retweets (since they are not written in the target author's style)
* tweets with no interesting content (limited to URLs, user mentions, "thank you"…)

We clean up the remaining tweets:
* we remove URLs
* we replace "@" mentions with user names

In [146]:
import random
import re
import torch

In [147]:
class user_handle:
    'Get the display name for a user handle, cached to avoid calling the Twitter API too often.'
    handles = {}

    @staticmethod
    def get(handle):
        if handle not in user_handle.handles:
            try:
                user_handle.handles[handle] = api.get_user(screen_name=handle).name
            except Exception:
                # lookup failed (e.g. deleted or suspended account): keep the original mention
                user_handle.handles[handle] = None
        return user_handle.handles[handle]

In [148]:
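def replace_handle(word):
    'Replace a leading "@handle" with the user display name, keeping any trailing characters'
    if word.startswith('@'):
        handle = re.search(r'^@(\w+)', word)
        if handle:
            user = user_handle.get(handle.group(1))
            if user is not None:
                # handle.end() is where the mention stops (e.g. before a trailing ":")
                return user + word[handle.end():]
    return word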
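To check the mention logic without calling the Twitter API, here is a minimal offline sketch of the same idea in which a plain dictionary (`display_names`, hypothetical) stands in for the cached `api.get_user` lookup. `match.end()` marks where the handle stops, which keeps trailing punctuation such as a colon (whereas `match.endpos` is simply the end of the searched string):

```python
import re

# a mention at the start of a word: "@" followed by word characters
MENTION = re.compile(r"^@(\w+)")


def replace_mention(word, display_names):
    """Replace a leading @mention with its display name, keeping trailing punctuation."""
    match = MENTION.match(word)
    if match:
        name = display_names.get(match.group(1))
        if name is not None:
            return name + word[match.end():]
    return word


names = {"some_user": "Some User"}  # hypothetical handle -> display name mapping
print(replace_mention("@some_user:", names))  # Some User:
```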

In [149]:
def keep_tweet(tweet):
    'Return True if the tweet is not a retweet'
    if hasattr(tweet, 'retweeted_status'):
        return False
    return True

In [150]:
def curate_tweets(tweets):
    'Decide which tweets we keep and replace handles'
    curated_tweets = []
    for tweet in tweets:
        if keep_tweet(tweet):
            curated_tweets.append(' '.join(replace_handle(w) for w in tweet.full_text.split()))
    return curated_tweets
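The retweet filter relies on the `retweeted_status` attribute, which the Twitter API only includes for retweets. A small offline sketch, with `types.SimpleNamespace` objects as hypothetical stand-ins for tweepy `Status` objects, shows the filter together with the whitespace normalization done by `' '.join(....split())`, which also flattens multi-line tweets onto a single line:

```python
from types import SimpleNamespace


def is_retweet(tweet):
    # the API adds `retweeted_status` only to retweets
    return hasattr(tweet, "retweeted_status")


def curate(tweets):
    """Drop retweets and collapse all whitespace (including newlines) to single spaces."""
    return [" ".join(t.full_text.split()) for t in tweets if not is_retweet(t)]


# hypothetical stand-ins for tweepy Status objects
sample = [
    SimpleNamespace(full_text="first line\n\nsecond   line"),
    SimpleNamespace(full_text="RT @someone: a retweet", retweeted_status=object()),
]
print(curate(sample))  # ['first line second line']
```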

In [151]:
curated_tweets = curate_tweets(alltweets)

We verify our list of tweets is well curated.

In [152]:
print(f'Total number of tweets: {len(alltweets)}\nCurated tweets: {len(curated_tweets)}')

Total number of tweets: 1558
Curated tweets: 491


In [153]:
print('Original tweets\n')
for t in alltweets[:5]:
    print(f'{t.full_text}\n')

Original tweets

@omarsar0 A project to build an end to end nlp project would be great. Even better if students can collaborate on it.

RT @thehill: Dorsey defends decision to fact check Trump tweet: "More transparency from us is critical" https://t.co/mJFDpsXosa https://t.c…

RT @DaleJohnsonESPN: Premier League returning on Wednesday, June 17 with the two games in hand, and in full from June 20-21, leaves ample t…

RT @NYGovCuomo: Now is the time for government to stimulate the economy by investing in infrastructure.

We have projects in NY ready to go…

RT @CBSNews: "What happened to that American spirit?"

While visiting Washington, D.C. on Wednesday, New York Gov. Andrew Cuomo slammed the…



In [154]:
print('Curated tweets\n')
for t in curated_tweets[:5]:
    print(f'{t}\n')

Curated tweets

elvis A project to build an end to end nlp project would be great. Even better if students can collaborate on it.

Fabinho is the best defensive midfielder in the world! Flight me #LFC

Are we memeing this into existence? I’m in #Mbappe2020

Simon Mignolet UEFA Champions League Liverpool FC (at 🏠) You deserve it just as much Simon! Your professionalism is an inspiration.

So fucking nervous! Let’s get what Klopp and the boys deserve! #LFC #UCLFinal



We remove boring tweets (tweets that are too short or contain only URLs, mentions and similar filler) and clean up the text.

In [155]:
def cleanup_tweet(tweet):
    "Clean tweet text"
    text = ' '.join(t for t in tweet.split() if 'http' not in t)
    text = text.replace('&amp;', '&')
    text = text.replace('&lt;', '<')
    text = text.replace('&gt;', '>')
    if text.split() and text.split()[0] == '.':
        text = ' '.join(text.split()[1:])
    return text
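`cleanup_tweet` decodes the three HTML entities that appear most often in tweets. If other entities show up in your data, one possible variant (an alternative, not the approach used above) relies on the standard library's `html.unescape`, which decodes every named and numeric entity; `cleanup_tweet_unescape` is a hypothetical name:

```python
import html


def cleanup_tweet_unescape(tweet):
    """Variant of cleanup_tweet that decodes all HTML entities with html.unescape."""
    text = html.unescape(tweet)  # &amp; -> &, &lt; -> <, &#39; -> ', ...
    words = [w for w in text.split() if "http" not in w]  # drop URLs
    if words and words[0] == ".":
        # drop a leading "." token, as cleanup_tweet does
        words = words[1:]
    return " ".join(words)


print(cleanup_tweet_unescape(". Tom &amp; Jerry &gt; all https://t.co/abc"))  # Tom & Jerry > all
```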

In [156]:
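def boring_tweet(tweet):
    "Check if a tweet is boring: too short, or only URLs, mentions, hashtags and filler words"
    boring_substrings = ['http', '@', '#']  # URLs, mentions, hashtags
    boring_words = {'thank', 'thanks', 'i', 'you'}  # filler words, compared in lowercase
    words = tweet.split()
    if len(words) < 3:
        return True
    # a tweet is boring when every word is a URL, mention, hashtag or filler word
    if all(any(s in w for s in boring_substrings) or w.lower().strip('.,!?') in boring_words
           for w in words):
        return True
    return False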

In [157]:
clean_tweets = [cleanup_tweet(t) for t in curated_tweets]
cool_tweets = [tweet for tweet in clean_tweets if not boring_tweet(tweet)]
print(f'Curated tweets: {len(curated_tweets)}\nCool tweets: {len(cool_tweets)}')

Curated tweets: 491
Cool tweets: 486


We split data into training and validation sets (90/10).

In [158]:
# shuffle data
random.shuffle(cool_tweets)

# fraction of training data
split_train_valid = 0.9

# split dataset
train_size = int(split_train_valid * len(cool_tweets))
valid_size = len(cool_tweets) - train_size
train_dataset, valid_dataset = torch.utils.data.random_split(cool_tweets, [train_size, valid_size])
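`random.shuffle` and `random_split` above use unseeded random state, so every run produces a different split. A minimal stdlib-only sketch of a reproducible 90/10 split, assuming a fixed seed (here `0`, an arbitrary choice) suits your experiments; `train_valid_split` is a hypothetical helper:

```python
import random


def train_valid_split(items, train_fraction=0.9, seed=0):
    """Shuffle a copy of `items` with a fixed seed and split it into train/valid lists."""
    shuffled = list(items)
    # a local Random instance keeps the split reproducible without touching global state
    random.Random(seed).shuffle(shuffled)
    train_size = int(train_fraction * len(shuffled))
    return shuffled[:train_size], shuffled[train_size:]


train, valid = train_valid_split([f"tweet {i}" for i in range(486)])
print(len(train), len(valid))  # 437 49
```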

We export our datasets as text files.

In [159]:
# write as UTF-8 since tweets often contain emoji
with open(f'{handle}_train.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(train_dataset))

with open(f'{handle}_valid.txt', 'w', encoding='utf-8') as f:
    f.write('\n'.join(valid_dataset))
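Since `curate_tweets` joins the words of each tweet with single spaces, every tweet fits on one line of the exported files. A small round-trip sketch, assuming UTF-8 is the encoding you want; `write_lines` and `read_lines` are hypothetical helpers:

```python
from pathlib import Path


def write_lines(path, lines):
    """Write one tweet per line as UTF-8, refusing tweets that span several lines."""
    for line in lines:
        if "\n" in line:
            raise ValueError(f"tweet spans several lines: {line!r}")
    Path(path).write_text("\n".join(lines), encoding="utf-8")


def read_lines(path):
    """Read the file back as a list of tweets (one per line)."""
    return Path(path).read_text(encoding="utf-8").split("\n")
```

Reading a file back with `read_lines` and comparing its length with the original list is a quick way to confirm that no tweet was split or merged.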

## Log and monitor training through W&B

In order to check that our model is training correctly and to compare experiments, we use the W&B integration from HuggingFace.

### [Sign up for a free account →](https://app.wandb.ai/login?signup=true)

- **Unified dashboard**: Central repository for all your model metrics and predictions
- **Lightweight**: No code changes required to integrate with Hugging Face
- **Accessible**: Free for individuals and academic teams
- **Secure**: All projects are private by default
- **Trusted**: Used by machine learning teams at OpenAI, Toyota, GitHub, Lyft and more

### API Key
Once you've signed up, run the next cell and click on the link to get your API key and authenticate this notebook.

In [160]:
#import wandb
#wandb.login()

## Fine-tuning the model

HuggingFace provides the script `run_language_modeling.py`, which makes it easy to fine-tune a pre-trained model.

We use a pre-trained GPT-2 model and fine-tune it on our dataset.

Training is automatically logged on W&B (see [documentation](https://docs.wandb.com/huggingface)). URLs are generated to visualize ongoing runs, or you can just open your [dashboard](http://app.wandb.ai/).

In a quick test over several epochs, my run started overfitting after 4 epochs, so that is the limit I used to fine-tune my model (it takes less than 2 minutes). The training cell below sets `num_train_epochs` to 8; lower it if your validation loss starts rising earlier.

![](https://i.imgur.com/1uIxLFe.png)
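Choosing the number of epochs this way amounts to early stopping on the validation loss. A minimal sketch, assuming you have the per-epoch validation losses at hand (for instance read off the W&B charts); `best_epoch` and its `patience` parameter are hypothetical and not part of HuggingFace or Simple Transformers:

```python
def best_epoch(eval_losses, patience=2):
    """Return the 1-based epoch with the lowest validation loss.

    Scanning stops once the loss has not improved for `patience`
    consecutive epochs, which mimics early stopping.
    Returns None for an empty list.
    """
    best, best_loss, waited = None, float("inf"), 0
    for epoch, loss in enumerate(eval_losses, start=1):
        if loss < best_loss:
            best, best_loss, waited = epoch, loss, 0
        else:
            waited += 1
            if waited >= patience:
                break
    return best
```

A larger `patience` tolerates noisy validation curves at the cost of training a few extra epochs.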

In [161]:
# Associate run to a project (optional)
#%env WANDB_PROJECT=huggingtweets

We fine-tune our model with the `LanguageModelingModel` class from Simple Transformers, a wrapper around the HuggingFace library (see [doc](https://huggingface.co/transformers/)).

In [162]:
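import sys
# use a local clone of simpletransformers (not needed if it was installed with `pip install simpletransformers`)
sys.path.append('/nfs/site/disks/gia_analytics_shared/paddy/projects/git_repos/simpletransformers/')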

In [163]:
from simpletransformers.language_modeling import LanguageModelingModel
import logging
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

train_args = {
    "output_dir": "gpt2_outputs/",
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "fp16": False,
    "train_batch_size": 32,
    "eval_batch_size":32,
    "num_train_epochs": 8,
    "tensorboard_dir": 'gpt2_tweet_runs/',
    'mlm':False
}

model = LanguageModelingModel('gpt2', 'gpt2', args=train_args,use_cuda=False)

model.train_model(f"{handle}_train.txt", eval_file=f"{handle}_valid.txt")

INFO:simpletransformers.language_modeling.language_modeling_utils: Creating features from dataset file at cache_dir/
INFO:simpletransformers.language_modeling.language_modeling_utils: Saving features into cached file cache_dir/gpt2_cached_lm_126_PradhyumnaR_train.txt
INFO:simpletransformers.language_modeling.language_modeling_model: Training started
Running loss: 5.095858
Running loss: 4.662007
Running loss: 4.603417
Running loss: 4.515154
Running loss: 4.451951
Running loss: 4.359577
Running loss: 4.257218
Running loss: 4.260036
INFO:simpletransformers.language_modeling.language_modeling_model: Training of gpt2 model complete. Saved to gpt2_outputs/.


In [164]:
from simpletransformers.language_generation import LanguageGenerationModel
gen_args={'length':200,
         'k':10}
model = LanguageGenerationModel("gpt2", "gpt2_outputs/",use_cuda=False, args=gen_args)
model.generate("There is no",verbose=False)



['There is no doubt about it. It\'s not an easy win. I\'d love to see them win it. I hope they can do it. I\'ve seen them do that already against Chelsea."\n\nLiverpool fans will be looking to see if he can score the goal, but it\'s not the most important game for them, and they need more. Liverpool will win, and they won\'t have any problems in the second half.\n\nLFC fans, it\'s important to win games. That\'s what it takes.\n\nLFC is not the only club to have an easy season. Arsenal have won the title. I can\'t see why any club can lose the title. They\'ve got some good players and a good manager and the squad is strong. I\'m sure they\'ll be a lot better next season. I don\'t know what the hell happened to him but he is a great player. I can\'t see why he would do this to Liverpool. He\'s the most talented player']
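In `gen_args`, `'k': 10` presumably enables top-k sampling: at each step only the 10 most likely next tokens are kept, and the next token is drawn from them in proportion to their probabilities. A pure-Python sketch of top-k sampling on a toy logits list (an illustration, not the library's implementation; `top_k_sample` is a hypothetical name):

```python
import math
import random


def top_k_sample(logits, k, rng=random):
    """Sample a token index among the k highest logits (softmax over those k only)."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in top)  # subtract the max for numerical stability
    weights = [math.exp(logits[i] - m) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


logits = [2.0, 0.5, 1.0, -1.0, 3.0]  # toy scores for a 5-token vocabulary
print(top_k_sample(logits, k=1))  # 4 (k=1 always picks the highest logit)
```

Small `k` values keep generations on-topic but repetitive; larger values allow more varied text.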

## Let's test our trained model!

We test our model on a few sample sentences.

In [None]:
SENTENCES = ["I think that",
            "I like",
            "I don't like",
            "I want"]

We use the HuggingFace script `run_generation.py` to generate sentences (see [doc](https://huggingface.co/transformers/)).

In [None]:
import random
seed = random.randint(0, 2**32-1)
seed


In [None]:
examples = []

for start in SENTENCES:
    val = !python run_generation.py \
        --model_type gpt2 \
        --model_name_or_path gpt2_outputs/ \
        --length 150 \
        --stop_token "{'\n'}" \
        --num_return_sequences 3 \
        --temperature 1 \
        --seed $seed \
        --prompt {'"' + start + '"'}
    generated = [val[-1-2*k] for k in range(3)[::-1]]
    print(f'\nStart of sentence: {start}')
    for i, g in enumerate(generated):
        g = g.replace('<|endoftext|>', '')
        print(f'* Generated #{i+1}: {g}')
        examples.append([start, g])
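`--temperature` rescales the logits before sampling: they are divided by the temperature, so values below 1 sharpen the distribution, values above 1 flatten it, and `--temperature 1` leaves it unchanged. A minimal sketch of that scaling on toy logits (`softmax_with_temperature` is a hypothetical helper):

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Turn logits into probabilities after dividing them by `temperature`."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]
for t in (0.5, 1.0, 2.0):
    # the top token gets more probability mass as the temperature decreases
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])
```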

We log the results to our previous W&B run.

In [None]:
# retrieve last run
import wandb

project = %env WANDB_PROJECT
wandb_id = wandb.api.list_runs(project)[0]['name']

In [None]:
# Log results on our previous wandb run
wandb.init(id=wandb_id, resume='must')
wandb.log({'examples': wandb.Table(data=examples, columns=['Input', 'Prediction'])})

# Update display name
wandb.run.name = alltweets[0].author.name
wandb.run.save()

**Results**: Open the generated "Run page" and look at the predictions in the "Media" panel.

You can see my trained models on my dashboard.

### [W&B Report →](https://bit.ly/2TGXMZf)

Please share your experiments and any insights you have!

## About

*Built by Boris Dayma*

[![Follow](https://img.shields.io/twitter/follow/borisdayma?style=social)](https://twitter.com/borisdayma)

My main goals with this project are:
* to experiment with how to train, deploy and maintain neural networks in production;
* to make AI accessible to everyone.

To see how the model works, visit the project repository.

[![GitHub stars](https://img.shields.io/github/stars/borisdayma/huggingtweets?style=social)](https://github.com/borisdayma/huggingtweets)

**Disclaimer: this project is not to be used to publish any false generated information but to perform research on Natural Language Generation.**

## Resources

* [A Step by Step Guide to Tracking Hugging Face Model Performance](https://app.wandb.ai/jxmorris12/huggingface-demo/reports/A-Step-by-Step-Guide-to-Tracking-Hugging-Face-Model-Performance--VmlldzoxMDE2MTU)
* [W&B Forum](http://bit.ly/wandb-forum): if you have any questions, reach out to the Slack community