# HuggingTweets - Tweet Generation with Huggingface

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/huggingtweets/blob/master/huggingtweets-demo.ipynb)

Train your own neural network on someone's tweets in 5 minutes, then tweet back its best predictions!

## To start the demo, click on the "Runtime" menu → "Run all"

In [0]:
#@title ⠀

# Check that a GPU is available
import torch
if not torch.cuda.is_available():
    print('Error: GPU was not found\n1/ click on the "Runtime" menu and "Change runtime type"\n'\
          '2/ set "Hardware accelerator" to "GPU" and click "save"\n3/ click on the "Runtime" menu, then "Run all" (the error below should disappear)')
    raise RuntimeError('GPU device not found')
else:    
    # Install dependencies
    !pip install wandb transformers torch -qq

    # Hugging Face example scripts for fine-tuning models and generating text
    !wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/language-modeling/run_language_modeling.py -q
    !wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/text-generation/run_generation.py -q
        
    import ipywidgets as widgets
    from IPython.display import display, HTML, clear_output
    import json
    import urllib3
    import random
    import wandb
    import click

    def fix_text(text):
        text = text.replace('&amp;', '&')
        text = text.replace('&lt;', '<')
        text = text.replace('&gt;', '>')
        return text

    def cleanup_tweet(tweet):
        "Clean tweet text"
        text = ' '.join(t for t in tweet.split() if 'http' not in t)
        if text.split() and text.split()[0] == '.':
            text = ' '.join(text.split()[1:])
        return text

    def boring_tweet(tweet):
        "Check if this is a boring tweet"
        boring_stuff = ['http', '@', '#', 'thank', 'thanks', 'I', 'you']
        if len(tweet.split()) < 3:
            return True
        if all(any(bs in t.lower() for bs in boring_stuff) for t in tweet.split()):
            return True
        return False

    def html_table(data):
        'Create a html table'
        def html_cell(i):
            return f'<td>{i}</td>'
        def html_row(r):
            return f'<tr>{"".join(html_cell(i) for i in r)}</tr>'
        return f'<table>{"".join(html_row(r) for r in data)}</table>'

    def dl_tweets(handle_value):
        handle = handle_value[1:] if handle_value[0] == '@' else handle_value
        run_dl_tweets.button_style = 'primary'
        log_dl_tweets.clear_output(wait=True)
        with log_dl_tweets:
            try:
                print(f'\nDownloading {handle_value} tweets... This should take no more than a minute!')
                http = urllib3.PoolManager(retries=urllib3.Retry(3))
                res = http.request("GET", f"http://us-central1-playground-111.cloudfunctions.net/tweets_http?handle={handle}")
                curated_tweets = json.loads(res.data.decode('utf-8'))
                curated_tweets = [fix_text(tweet) for tweet in curated_tweets]
                log_dl_tweets.clear_output(wait=True)
                print(f'\n{len(curated_tweets)} tweets from {handle_value} downloaded!')
                random.shuffle(curated_tweets)
                for i,t in enumerate(curated_tweets[:5]):
                    print(f'\nExample #{i+1}\n{t}')
                    
                # create dataset
                clean_tweets = [cleanup_tweet(t) for t in curated_tweets]
                cool_tweets = [tweet for tweet in clean_tweets if not boring_tweet(tweet)]
                with open('{}_train.txt'.format(handle), 'w') as f:
                    f.write('\n'.join(cool_tweets))
                
                run_dl_tweets.button_style = 'success'
                run_finetune.disabled = False
            except Exception as e:
                print('An error occurred...\n')
                print(e)
                run_dl_tweets.button_style = 'danger'
                
    handle_widget = widgets.Text(value='@elonmusk',
                                placeholder='Enter twitter handle')

    run_dl_tweets = widgets.Button(
        description='Download tweets',
        button_style='primary')
    def on_run_dl_tweets_clicked(b):
        dl_tweets(handle_widget.value)
    run_dl_tweets.on_click(on_run_dl_tweets_clicked)

    log_dl_tweets = widgets.Output()
    with log_dl_tweets:
        print('\nEnter a Twitter handle and click "Download tweets"')
        
    # Associate run to a project
    %env WANDB_PROJECT=huggingtweets
    %env WANDB_WATCH=false
    %env WANDB_ENTITY=wandb
    wandb.login(anonymous='allow')

    def finetune():
        handle = handle_widget.value[1:] if handle_widget.value[0] == '@' else handle_widget.value
        run_finetune.button_style = 'primary'
        log_finetune.clear_output(wait=True)

        with log_finetune:
            try:
                print(f'\nTraining Neural Network on {handle_widget.value} tweets... This could take 3 to 5 minutes!\n')
                !python run_language_modeling.py \
                    --output_dir=output/$handle \
                    --overwrite_output_dir \
                    --model_type=gpt2 \
                    --model_name_or_path=gpt2 \
                    --do_train --train_data_file=$handle\_train.txt \
                    --logging_steps 20 \
                    --per_gpu_train_batch_size 1 \
                    --num_train_epochs 4
                
                log_finetune.clear_output(wait=True)
                print('\n🎉 Neural network trained successfully!')
                
                run_finetune.button_style = 'success'
                run_predictions.disabled = False

            except Exception as e:
                print('An error occurred...\n')
                print(e)
                run_finetune.button_style = 'danger'
        log_predictions.clear_output(wait=True)
        with log_predictions:
            print('\nEnter the start of a sentence and click "Run predictions"')

    run_finetune = widgets.Button(
        description='Train Neural Network',
        button_style='primary',
        disabled=True)
    def on_run_finetune_clicked(b):
        finetune()
    run_finetune.on_click(on_run_finetune_clicked)

    log_finetune = widgets.Output()
    with log_finetune:
        print('\nFine-tune your model by clicking on "Train Neural Network"')

    def clean_prediction(text):
        token = '<|endoftext|>'
        while len(token)>1:
            text = text.replace(token, '')
            token = token[:-1]
        if text.endswith('"'): text = text[:-1]
        return text

    predictions = []
    
    def predict():
        handle = handle_widget.value[1:] if handle_widget.value[0] == '@' else handle_widget.value
        start = start_widget.value
        run_predictions.button_style = 'primary'
        log_predictions.clear_output(wait=True)

        def tweet_html(text):
            max_char = 180
            text = clean_prediction(text)
            if len(text) > max_char: text = text[:max_char] + '...'

            return '<a href="https://twitter.com/share?ref_src=twsrc%5Etfw" class="twitter-share-button" data-size="large" '\
                    f'data-text="Trained a neural network on @{handle}: &quot;'\
                    f'{start}&quot; → &quot;{text}&quot;" '\
                    f'data-url="{wandb_url}" data-hashtags="huggingtweets" data-related="borisdayma,weights_biases,huggingface" '\
                    'data-show-count="false">Tweet</a><script async src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>'
        
        try_success = False
        with log_predictions:
            try:
                print(f'\nPerforming predictions of {handle_widget.value} starting with "{start}"...\nThis should take no more than 20 seconds!\n')
                
                # load previous run
                project = %env WANDB_PROJECT
                wandb_id = wandb.api.list_runs(project)[0]['name']
                wandb.init(id=wandb_id, resume='must')
                wandb_url = wandb.run.get_url()
                
                seed = random.randint(0, 2**32-1)
                val = !python run_generation.py \
                    --model_type gpt2 \
                    --model_name_or_path output/$handle \
                    --length 150 \
                    --stop_token "{'\n'}" \
                    --num_return_sequences 5 \
                    --temperature 1 \
                    --seed $seed \
                    --prompt {'"' + start + '"'}
                generated = [val[-1-2*k] for k in range(5)[::-1]]
                generated = [clean_prediction(g) for g in generated]
                log_predictions.clear_output(wait=True)
                print(f'\nPredictions of {handle_widget.value} starting with "{start}" via #huggingtweets')
                for g in generated:
                    predictions.append([start, g])
                
                # log predictions
                wandb.log({'examples': wandb.Table(data=predictions, columns=['Input', 'Prediction'])})
                
                # Update display name
                wandb.run.name = handle
                wandb.run.save()
                
                # display wandb run
                print(f'\n🚀 View all results under the "Media" panel at {click.style(wandb_url, underline=True, fg="blue")}')
                print('Run new predictions or try other sentences!')

                # make html table
                tweet_data = [[tweet_html(g), g] for g in generated]
                tweet_table = HTML(html_table(tweet_data))

                run_predictions.button_style = 'success'
                try_success = True
                
            except Exception as e:
                print('An error occurred...\n')
                print(e)
                run_predictions.button_style = 'danger'

        if try_success: display(tweet_table)
                
    start_widget = widgets.Text(value='My dream is',
                                placeholder='Enter the start of a sentence')

    run_predictions = widgets.Button(
        description='Run predictions',
        button_style='primary',
        disabled=True)
    def on_run_predictions_clicked(b):
        predict()
    run_predictions.on_click(on_run_predictions_clicked)

    log_predictions = widgets.Output()
    with log_predictions:
        print('\nWaiting for Step 2 to complete...')

    clear_output(wait=True)
    print('🎉 Environment set-up correctly!')
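
The cleanup helpers in the setup cell can be exercised on their own. The sketch below is a standalone variant, not the notebook's exact code: it assumes the standard library's `html.unescape` is an acceptable replacement for the three manual `&amp;`/`&lt;`/`&gt;` substitutions (it also decodes other entities such as `&quot;`), it applies the "boring" test word by word, and it leaves out `'I'` because comparing it against a lower-cased word can never match.

```python
import html

# Substrings that make a word "boring" (the notebook's list, minus 'I').
BORING_STUFF = ['http', '@', '#', 'thank', 'thanks', 'you']


def fix_text(text: str) -> str:
    """Decode HTML entities returned by the tweet download (&amp;, &lt;, &gt;, ...)."""
    return html.unescape(text)


def cleanup_tweet(tweet: str) -> str:
    """Drop every token that contains a link, then a leading lone '.' token."""
    words = [w for w in tweet.split() if 'http' not in w]
    if words and words[0] == '.':
        words = words[1:]
    return ' '.join(words)


def boring_tweet(tweet: str) -> bool:
    """Boring if shorter than 3 words, or if every word contains boring stuff."""
    words = tweet.split()
    if len(words) < 3:
        return True
    return all(any(bs in w.lower() for bs in BORING_STUFF) for w in words)


if __name__ == '__main__':
    raw = ['. Rockets are fun https://t.co/abc', '@a @b #c', 'Tom &amp; Jerry &lt;3 forever']
    clean = [cleanup_tweet(fix_text(t)) for t in raw]
    print([t for t in clean if not boring_tweet(t)])
```

Keeping each step a pure function of a string makes it easy to check the filter on a handful of tweets before writing the training file.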

## Step 1 - Download tweets

We choose a Twitter user and download their tweets.

In [0]:
#@title ⠀
display(widgets.VBox([handle_widget, run_dl_tweets, log_dl_tweets]))
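
The download step boils down to three small operations: strip the `@` from the handle, request the tweets from the notebook's cloud function, and write the curated tweets one per line to `<handle>_train.txt`. The sketch below covers the first and last steps and only builds the request URL. The helper names `normalize_handle`, `tweets_url` and `write_train_file` are hypothetical, and the endpoint is copied from the setup cell on the assumption that it is still served.

```python
import os
import tempfile
from typing import Sequence
from urllib.parse import urlencode

# Endpoint copied from the notebook's setup cell (assumed to be still live).
TWEETS_ENDPOINT = 'http://us-central1-playground-111.cloudfunctions.net/tweets_http'


def normalize_handle(value: str) -> str:
    """Hypothetical helper: map '@elonmusk', ' elonmusk ' or 'elonmusk' to 'elonmusk'."""
    handle = value.strip().lstrip('@')
    if not handle:
        raise ValueError('Enter a non-empty Twitter handle')
    return handle


def tweets_url(handle: str) -> str:
    """Hypothetical helper: build the request URL with a percent-encoded handle."""
    return f'{TWEETS_ENDPOINT}?{urlencode({"handle": handle})}'


def write_train_file(handle: str, tweets: Sequence[str], directory: str = '.') -> str:
    """Hypothetical helper: write one tweet per line to <handle>_train.txt."""
    path = os.path.join(directory, f'{handle}_train.txt')
    with open(path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(tweets))
    return path


if __name__ == '__main__':
    handle = normalize_handle('@elonmusk')
    print(tweets_url(handle))
    print(write_train_file(handle, ['first tweet here', 'second tweet here'], tempfile.mkdtemp()))
```

Raising on an empty handle, rather than indexing `handle_value[0]`, gives a readable error instead of an `IndexError`.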

## Step 2 - Train your Neural Network

We fine-tune a neural network on the downloaded tweets using [Hugging Face](https://huggingface.co/).

In [0]:
#@title ⠀
display(widgets.VBox([run_finetune, log_finetune]))
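
The training cell calls `run_language_modeling.py` through an IPython `!` command. As far as we know, `!` does not raise when the command exits with an error, so the surrounding `try/except` cannot detect a failed training run. The sketch below is one possible alternative, not the notebook's method: it builds the same arguments as a list and runs them with `subprocess.run(check=True)`, which raises `CalledProcessError` on a non-zero exit status. The flags are copied from the notebook and assume the version of the script that it downloads; newer versions of the example script may use different flags.

```python
import shlex
import subprocess
import sys
from typing import List


def finetune_command(handle: str, epochs: int = 4, batch_size: int = 1) -> List[str]:
    """Argument list mirroring the notebook's `!python run_language_modeling.py ...` call."""
    return [
        sys.executable, 'run_language_modeling.py',
        f'--output_dir=output/{handle}',
        '--overwrite_output_dir',
        '--model_type=gpt2',
        '--model_name_or_path=gpt2',
        '--do_train',
        f'--train_data_file={handle}_train.txt',  # no "$handle\_" escaping needed here
        '--logging_steps', '20',
        '--per_gpu_train_batch_size', str(batch_size),
        '--num_train_epochs', str(epochs),
    ]


def run_finetune(handle: str) -> None:
    """Run training; check=True raises CalledProcessError if the script fails."""
    subprocess.run(finetune_command(handle), check=True)


if __name__ == '__main__':
    # Print the equivalent shell command instead of launching a training run.
    print(shlex.join(finetune_command('elonmusk')))
```

Passing a list instead of a shell string also means a handle is never reinterpreted by the shell.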

## Step 3 - Visualize Predictions and Have Fun!

Just start a sentence and let the model finish it!

In [0]:
#@title ⠀
display(widgets.VBox([start_widget, run_predictions, log_predictions]))
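
Each generated continuation is post-processed before it is shown and shared: GPT-2's `<|endoftext|>` marker (or a truncated piece of it) and a trailing quote are removed, and long texts are shortened to 180 characters. The sketch below reproduces `clean_prediction` from the setup cell and assumes the intended order is "clean, then truncate with an ellipsis". The helpers `shorten` and `share_text` are hypothetical. `share_text` escapes the text with `html.escape`, on the assumption that a prediction containing `"` would otherwise break the tweet button's `data-text` attribute.

```python
import html

END_TOKEN = '<|endoftext|>'


def clean_prediction(text: str) -> str:
    """Remove the end-of-text marker and its truncated prefixes, then a trailing quote."""
    token = END_TOKEN
    while len(token) > 1:
        text = text.replace(token, '')
        token = token[:-1]  # '<|endoftext|' , '<|endoftext', ..., '<|'
    if text.endswith('"'):
        text = text[:-1]
    return text


def shorten(text: str, max_char: int = 180) -> str:
    """Hypothetical helper: clean first, then add '...' only if the text is too long."""
    text = clean_prediction(text)
    return text[:max_char] + '...' if len(text) > max_char else text


def share_text(handle: str, start: str, prediction: str) -> str:
    """Hypothetical helper: tweet text for the share button, escaped for an HTML attribute."""
    raw = f'Trained a neural network on @{handle}: "{start}" → "{shorten(prediction)}"'
    return html.escape(raw, quote=True)


if __name__ == '__main__':
    print(share_text('elonmusk', 'My dream is', 'to fly to Mars<|endoftext|>'))
```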

HuggingTweets is still in its infancy and will get better over time!

In the future, it will train continuously to become a Twitter expert!

*Disclaimer: this project is intended for research on Natural Language Generation and must not be used to publish false generated information.*

## Resources

*Built by Boris Dayma ([@borisdayma](https://twitter.com/borisdayma))*

My main goals with this project are:
* to experiment with how to train, deploy and maintain neural networks in production;
* to make AI accessible to everyone.

To see how the model works, visit the project repository.

[![GitHub stars](https://img.shields.io/github/stars/borisdayma/huggingtweets?style=social)](https://github.com/borisdayma/huggingtweets)

<br>

For additional resources, visit:

* [A Step by Step Guide to Tracking Hugging Face Model Performance](https://app.wandb.ai/jxmorris12/huggingface-demo/reports/A-Step-by-Step-Guide-to-Tracking-Hugging-Face-Model-Performance--VmlldzoxMDE2MTU)
* [W&B Forum](http://bit.ly/wandb-forum): reach out to the community if you have any questions