# HuggingTweets - Tweet Generation with Huggingface

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/borisdayma/huggingtweets/blob/master/huggingtweets-demo.ipynb)

Train your own neural network on someone's tweets in about 5 minutes, then tweet your best predictions back at them!

## To start the demo, click on the "Runtime" menu, then "Run all"

In [0]:
#@title ⠀ {display-mode: "form"}

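# Imported up front so print_html also works when the GPU check below fails
from IPython.display import display, HTML

def print_html(x):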
    "Better printing"
    x = x.replace('\n', '<br>')
    display(HTML(x))
        
# Check we use GPU
import torch
if not torch.cuda.is_available():
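    print_html('Error: GPU was not found\n1/ click on the "Runtime" menu and "Change runtime type"\n'\
          '2/ set "Hardware accelerator" to "GPU" and click "save"\n3/ click on the "Runtime" menu, then "Run all" (the error below should disappear)')
    # A bare `raise` has no active exception to re-raise, so stop with an explicit error
    raise RuntimeError('GPU was not found')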
else:
    # Install dependencies
    !pip install wandb transformers torch -qq

    # Huggingface scripts for fine-tuning models and language generation
    !wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/language-modeling/run_language_modeling.py -q
    !wget https://raw.githubusercontent.com/huggingface/transformers/master/examples/text-generation/run_generation.py -q
        
    import ipywidgets as widgets
    from IPython.display import display, HTML, clear_output, Javascript
    import json
    import urllib3
    import random
    import wandb
    import time

    try:
        import google.colab
        IN_COLAB = True
    except ImportError:
        IN_COLAB = False

    def fix_text(text):
        text = text.replace('&amp;', '&')
        text = text.replace('&lt;', '<')
        text = text.replace('&gt;', '>')
        return text

    def html_table(data, title=None):
        'Create a html table'
        width_twitter = '1px' if IN_COLAB else '75px'
        def html_cell(i, twitter_button=False):
            return f'<td style="width:{width_twitter}">{i}</td>' if twitter_button else f'<td>{i}</td>'
        def html_row(row):
            return f'<tr>{"".join(html_cell(r, not i if len(row)>1 else False) for i,r in enumerate(row))}</tr>'    
        body = f'<table style="width:100%">{"".join(html_row(r) for r in data)}</table>'
        title_html = f'<h3>{title}</h3>' if title else ''
        html = '''
        <html>
            <head>
                <style>
                    table {border-collapse: collapse !important;}
                    td {text-align:left !important; border: solid #E3F2FD !important; border-width: 1px 0 !important; padding: 6px !important;}
                    tr:nth-child(even) {background-color: #E3F2FD !important;}
                </style>
            </head>
            <body>''' + title_html + body + '</body></html>'
        return html
        
    def cleanup_tweet(tweet):
        "Clean tweet text"
        text = ' '.join(t for t in tweet.split() if 'http' not in t)
        if text.split() and text.split()[0] == '.':
            text = ' '.join(text.split()[1:])
        return text

    def boring_tweet(tweet):
        "Check if this is a boring tweet"
        boring_stuff = ['http', '@', '#', 'thank', 'thanks', 'I', 'you']
        if len(tweet.split()) < 3:
            return True
        # Compare word by word (iterating over the string itself would yield single characters)
        if all(any(bs in t.lower() for bs in boring_stuff) for t in tweet.split()):
            return True
        return False

    def dl_tweets():
        handle_widget.disabled = True
        run_dl_tweets.disabled = True
        run_dl_tweets.button_style = 'primary'
        handle = handle_widget.value.strip()
        handle = handle[1:] if handle.startswith('@') else handle  # safe on empty input
        handle = handle.lower()
        log_dl_tweets.clear_output(wait=True)
        with log_dl_tweets:
            try:
                print_html(f'\nDownloading {handle_widget.value.strip()} tweets... This should take no more than a minute!')
                http = urllib3.PoolManager(retries=urllib3.Retry(3))
                res = http.request("GET", f"http://us-central1-playground-111.cloudfunctions.net/tweets_http?handle={handle}")
                curated_tweets = json.loads(res.data.decode('utf-8'))
                curated_tweets = [fix_text(tweet) for tweet in curated_tweets]
                log_dl_tweets.clear_output(wait=True)
                random.shuffle(curated_tweets)
                example_tweets = [[t] for t in curated_tweets[:8]]
                print_html(f'\n{len(curated_tweets)} tweets from {handle_widget.value.strip()} downloaded!\n\n')
                display(HTML(html_table(example_tweets)))
                    
                # create dataset
                clean_tweets = [cleanup_tweet(t) for t in curated_tweets]
                cool_tweets = [tweet for tweet in clean_tweets if not boring_tweet(tweet)]
                with open('{}_train.txt'.format(handle), 'w') as f:
                    f.write('\n'.join(cool_tweets))
                
                run_dl_tweets.button_style = 'success'
                log_finetune.clear_output()
                run_finetune.disabled = False
            except Exception as e:
                print('An error occurred...\n')
                print(e)
                run_dl_tweets.button_style = 'danger'
        handle_widget.disabled = False
        run_dl_tweets.disabled = False
        log_finetune.clear_output(wait=True)
        with log_finetune:
            print_html('\nFine-tune your model by clicking on "Train Neural Network"')
                
    handle_widget = widgets.Text(value='@elonmusk',
                                placeholder='Enter twitter handle')

    run_dl_tweets = widgets.Button(
        description='Download tweets',
        button_style='primary')
    def on_run_dl_tweets_clicked(b):
        dl_tweets()
    run_dl_tweets.on_click(on_run_dl_tweets_clicked)

    log_restart = widgets.Output()
    log_dl_tweets = widgets.Output()
        
    # Associate run to a project
    %env WANDB_PROJECT=huggingtweets
    %env WANDB_WATCH=false
    %env WANDB_ENTITY=wandb
    %env WANDB_ANONYMOUS=allow
    %env WANDB_NOTEBOOK_NAME=huggingtweets-demo
    wandb.login()

    def finetune():
        handle_widget.disabled = True
        run_dl_tweets.disabled = True
        run_finetune.disabled = True
        run_finetune.button_style = 'primary'
        handle = handle_widget.value.strip()
        handle = handle[1:] if handle.startswith('@') else handle  # safe on empty input
        handle = handle.lower()
        log_finetune.clear_output(wait=True)

        success_try = False

        with log_finetune:
            try:
                print_html(f'\nTraining Neural Network on {handle_widget.value.strip()} tweets... This could take up to 2-3 minutes!\n\n')
                
                # give time to read text
                time.sleep(5)

                # use new run id
                run_id = wandb.util.generate_id()
                %env WANDB_RUN_ID=$run_id
                run_name = handle_widget.value.strip()
                %env WANDB_NAME=$run_name
                
                !python run_language_modeling.py \
                    --output_dir=output/$handle \
                    --overwrite_output_dir \
                    --model_type=gpt2 \
                    --model_name_or_path=gpt2 \
                    --do_train --train_data_file=$handle\_train.txt \
                    --logging_steps 10 \
                    --per_gpu_train_batch_size 1 \
                    --num_train_epochs 4
                
                if _exit_code != 0:
                    raise ValueError("Training failed")
                
                log_finetune.clear_output(wait=True)
                print_html('\n🎉 Neural network trained successfully!')
                
                run_finetune.button_style = 'success'
                log_predictions.clear_output()
                run_predictions.disabled = False

                success_try = True

            except Exception as e:
                print('An error occurred...\n')
                print(e)
                run_finetune.button_style = 'danger'
                run_finetune.disabled = False

        if success_try:
            log_predictions.clear_output(wait=True)
            with log_predictions:
                print_html('\nEnter the start of a sentence and click "Run predictions"')
            with log_restart:
                print_html('\n<b>To change user, click on menu "Runtime" → "Restart and run all"</b>\n')

    run_finetune = widgets.Button(
        description='Train Neural Network',
        button_style='primary',
        disabled=True)
    def on_run_finetune_clicked(b):
        finetune()
    run_finetune.on_click(on_run_finetune_clicked)

    log_finetune = widgets.Output()

    def clean_prediction(text):
        token = '<|endoftext|>'
        while len(token)>1:
            text = text.replace(token, '')
            token = token[:-1]
        text = text.strip()
        # Guard against empty text before checking for a dangling quote
        if text and text[-1] == '"' and text.count('"') % 2: text = text[:-1]
        return text.strip()

    predictions = []
    
    def predict():
        run_predictions.disabled = True
        run_predictions.button_style = 'primary'
        handle = handle_widget.value.strip()
        handle = handle[1:] if handle.startswith('@') else handle  # safe on empty input
        handle_uncased = handle
        handle = handle.lower()
        start = start_widget.value.strip()        
        log_predictions.clear_output(wait=True)

        def tweet_html(text):
            max_char = 239
            
            text = clean_prediction(text)
            text = clean_prediction(text[len(start):].strip())
            text = text.replace('"', '&quot;')
            
            tweet_text = f'Trained a neural network on @{handle_uncased}: {start} → {text}'
            
            # shorten tweet by dropping trailing words until it fits
            while len(tweet_text) > max_char:
                tweet_text = ' '.join(tweet_text.split()[:-1]) + '…'

            return '<a href="https://twitter.com/share?ref_src=twsrc%5Etfw" class="twitter-share-button" data-size="large" '\
                    f'data-text="{tweet_text}" '\
                    f'data-url="{wandb_url}" data-hashtags="huggingtweets" data-related="borisdayma,weights_biases,huggingface" '\
                    'data-show-count="false">Tweet</a><script async src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>'
        
        try_success = False
        with log_predictions:
            try:
                print_html(f'\nPerforming predictions of @{handle} starting with "{start}"...\nThis should take no more than 20 seconds!\n\n')
                
                # start a wandb run
                if wandb.run is None:
                    run_name = handle_widget.value.strip()
                    %env WANDB_NAME=$run_name
                    %env WANDB_RESUME=allow
                    wandb.init()

                wandb_url = wandb.run.get_url()
                
                seed = random.randint(0, 2**32-1)
                val = !python run_generation.py \
                    --model_type gpt2 \
                    --model_name_or_path output/$handle \
                    --length 150 \
                    --stop_token "{'\n'}" \
                    --num_return_sequences 6 \
                    --temperature 1 \
                    --seed $seed \
                    --prompt {'"' + start + '"'}

                if _exit_code != 0:
                    raise ValueError("Predictions failed")
                
                generated = [val[-1-2*k] for k in range(6)[::-1]]
                generated = [clean_prediction(g) for g in generated]
                log_predictions.clear_output(wait=True)
                for i, g in enumerate(generated):
                    predictions.append([start, clean_prediction(g)])
                
                # log predictions
                wandb.log({'examples': wandb.Table(data=predictions, columns=['Input', 'Prediction'])})
                
                # display wandb run
                link = f'<a href="{wandb_url}" rel="noopener" target="_blank">{wandb_url}</a>'
                print_html(f'\n🚀 View all results under the "Media" panel at {link}\n')
                print_html('\nClick on your favorite tweet or try new predictions with other sentences (or the same one)!\n\n')

                # make html table
                tweet_data = [[tweet_html(g), g] for g in generated]
                tweet_table = HTML(html_table(tweet_data))
                if not IN_COLAB: display(tweet_table)

                run_predictions.button_style = 'success'
                try_success = True
                
            except Exception as e:
                print('An error occurred...\n')
                print(e)
                run_predictions.button_style = 'danger'

        if try_success and IN_COLAB: display(tweet_table)
        run_predictions.disabled = False
                
    start_widget = widgets.Text(value='I want to',
                                placeholder='Enter the start of a sentence')

    run_predictions = widgets.Button(
        description='Run predictions',
        button_style='primary',
        disabled=True)
    def on_run_predictions_clicked(b):
        predict()
    run_predictions.on_click(on_run_predictions_clicked)

    log_predictions = widgets.Output()

    clear_output(wait=True)
    print_html('🎉 Environment set-up correctly!')

## Step 1 - Download tweets

Enter a Twitter handle and click "Download tweets".

In [0]:
#@title ⠀ {display-mode: "form"}
display(widgets.VBox([handle_widget, run_dl_tweets, log_restart, log_dl_tweets]))
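
Before training, the downloaded tweets are cleaned and filtered. The snippet below is a minimal standalone sketch of that idea, assuming a simplified version of the setup cell's logic: links are dropped, then very short tweets are skipped. The names `strip_links`, `is_too_short` and `build_training_lines` are hypothetical and are not defined in the notebook.

```python
def strip_links(tweet):
    """Drop any whitespace-separated token that contains 'http'."""
    return ' '.join(tok for tok in tweet.split() if 'http' not in tok)


def is_too_short(tweet, min_words=3):
    """Treat tweets with fewer than `min_words` words as uninformative."""
    return len(tweet.split()) < min_words


def build_training_lines(tweets):
    """Clean each tweet and keep only those long enough to be useful."""
    cleaned = (strip_links(t) for t in tweets)
    return [t for t in cleaned if not is_too_short(t)]


# Made-up sample tweets, for illustration only
sample = [
    'Check this out https://t.co/abc123',
    'Rockets are hard but worth building',
    'ok',
]
training_lines = build_training_lines(sample)

# The training file holds one tweet per line
print('\n'.join(training_lines))
```

Keeping one tweet per line matches the training file written in the setup cell, where the cleaned tweets are joined with newlines.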

## Step 2 - Train your Neural Network

We fine-tune a pre-trained GPT-2 model on the downloaded tweets using [Hugging Face](https://huggingface.co/) Transformers.

In [0]:
#@title ⠀ {display-mode: "form"}
display(widgets.VBox([run_finetune, log_finetune]))
with log_finetune:
    print_html('\nWaiting for Step 1 to complete...')
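
The training button runs `run_language_modeling.py` through a notebook shell command. As a rough sketch of how the same call could be assembled outside a notebook, the helper below builds the argument list with the flags used in the setup cell. The function name `build_finetune_command` is hypothetical, and the snippet only builds the command without launching it.

```python
def build_finetune_command(handle, epochs=4):
    """Return the argument list for fine-tuning GPT-2 on `<handle>_train.txt`."""
    # Normalize the handle: no surrounding spaces, no leading '@', lower case
    handle = handle.strip().lstrip('@').lower()
    return [
        'python', 'run_language_modeling.py',
        f'--output_dir=output/{handle}',
        '--overwrite_output_dir',
        '--model_type=gpt2',
        '--model_name_or_path=gpt2',
        '--do_train',
        f'--train_data_file={handle}_train.txt',
        '--logging_steps', '10',
        '--per_gpu_train_batch_size', '1',
        '--num_train_epochs', str(epochs),
    ]


command = build_finetune_command('@ElonMusk')
print(' '.join(command))
```

To actually start training, the list could be passed to `subprocess.run(command, check=True)`, assuming the script and the training file are present in the working directory.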

## Step 3 - Visualize predictions and have fun!

Just start a sentence and let the model finish it!

In [0]:
#@title ⠀ {display-mode: "form"}
if IN_COLAB:
    display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 5000})'''))
display(widgets.VBox([start_widget, run_predictions, log_predictions]))
with log_predictions:
    print_html('\nWaiting for Step 2 to complete...')
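
Generated text usually needs some post-processing before it can be shared. The following is a minimal sketch of two steps the setup cell performs: stripping the GPT-2 end-of-text marker (including truncated fragments of it) and shortening the result to a character limit. The names `strip_end_token` and `shorten` are hypothetical, the sample text is made up, and `shorten` is a slight variation on the notebook's word-dropping loop.

```python
END_TOKEN = '<|endoftext|>'


def strip_end_token(text):
    """Remove the end-of-text marker and any truncated prefix of it."""
    token = END_TOKEN
    while len(token) > 1:
        text = text.replace(token, '')
        token = token[:-1]  # also catch markers cut off mid-way
    return text.strip()


def shorten(text, max_char=239):
    """Drop trailing words until the text plus an ellipsis fits in `max_char`."""
    if len(text) <= max_char:
        return text
    words = text.split()
    # Reserve one character for the ellipsis appended below
    while words and len(' '.join(words)) + 1 > max_char:
        words.pop()
    return ' '.join(words) + '…'


start = 'I want to'
raw = 'I want to build rockets that land themselves<|endof'

# Remove the marker, then drop the prompt to keep only the completion
completion = strip_end_token(raw)[len(start):].strip()
print(completion)
print(shorten('aaaa bbbb cccc', max_char=10))
```

Dropping whole words rather than cutting at a fixed character keeps the shortened text readable.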

Huggingtweets is still in its infancy and will get better over time!

In the future, it will train continuously to become a Twitter expert!

## About

*Built by Boris Dayma*

[![Follow](https://img.shields.io/twitter/follow/borisdayma?style=social)](https://twitter.com/borisdayma)

My main goals with this project are:
* to experiment with how to train, deploy and maintain neural networks in production;
* to make AI accessible to everyone.

To see how the model works, visit the project repository.

[![GitHub stars](https://img.shields.io/github/stars/borisdayma/huggingtweets?style=social)](https://github.com/borisdayma/huggingtweets)

*Disclaimer: this project is not to be used to publish any false generated information but to perform research on Natural Language Generation.*

## Resources

* [A Step by Step Guide to Tracking Hugging Face Model Performance](https://app.wandb.ai/jxmorris12/huggingface-demo/reports/A-Step-by-Step-Guide-to-Tracking-Hugging-Face-Model-Performance--VmlldzoxMDE2MTU)
* [W&B Forum](http://bit.ly/wandb-forum): if you have any questions, reach out to the Slack community.