In [50]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np
import os
import csv
import pandas as pd
import tweepy
from auth_tw import get_key
import tweepy

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

In [None]:
def load_model(MODEL_EPOCH=4):

    models_folder = "../trained_models"

    model_path = os.path.join(models_folder, f"gpt2_xl_manbot_{MODEL_EPOCH}.pt")

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    model.load_state_dict(torch.load(model_path))
    model = model.to(device)

    model.eval()

    return model, tokenizer


In [None]:
model, tokenizer = load_model()

In [None]:
with open('data_twq.csv', encoding="utf8") as csv_file:
            data_twq = csv.reader(csv_file, delimiter=';')

In [None]:
data_twq = pd.read_csv('data_twq.csv',sep=';').drop(['Unnamed: 0'],axis=1)

In [None]:
def return_first_word(tweet):
    return str(tweet[0].split(' ')[0])

In [None]:
first_words = data_twq.apply(return_first_word, axis=1).copy()

In [None]:
sumInstances = pd.DataFrame(first_words).value_counts().sum()

In [None]:
words = pd.DataFrame(pd.DataFrame(first_words).value_counts().index.tolist())

In [None]:
propability = pd.DataFrame(pd.DataFrame(first_words).value_counts().values / sumInstances)

In [None]:
word_prob = words.join(propability,how='left', lsuffix='_left')
word_prob.columns = ['word', 'prob']

In [None]:
def choose_from_top(probs, n=5):
    ind = np.argpartition(probs, -n)[-n:]
    top_prob = probs[ind]
    top_prob = top_prob / np.sum(top_prob) # Normalize
    choice = np.random.choice(n, 1, p = top_prob)
    token_id = ind[choice][0]
    return int(token_id)

In [None]:
def post_on_twitter(tweet_text = ""):

    auth = tweepy.OAuthHandler(get_key("api_key"), get_key("api_key_secret"))
    auth.set_access_token(get_key("access_token"), get_key("access_token_secret"))

    # Create a client using tweepy.Client
    client = tweepy.Client(bearer_token=get_key("bearer_token"), consumer_key=get_key("api_key"), consumer_secret=get_key("api_key_secret"), access_token=get_key("access_token"), access_token_secret=get_key("access_token_secret"))

    try:
        response = client.create_tweet(text=tweet_text)
        tweet_id =  response.data['id']
        print('Tweet posted successfully! Tweet ID:', tweet_id)
    except:
        print('Error occurred while posting the tweet:')

In [None]:
def generate_content(random = True,start_with='',output_file='generated_content.txt', size=5, post_on_twitter = False):

    output_file_path = f'{output_file}'

    model.eval()
    if os.path.exists(output_file_path):
        os.remove(output_file_path)
    
    tweet_num = 0
    with torch.no_grad():
   
        for tweet_idx in range(size):
        
            tweet_finished = False
            first_word = ''
            
            if random: 
                first_word = word_prob['word'][np.random.choice(np.arange(len(word_prob)),p=word_prob['prob'])]
            else:
                first_word = start_with
            cur_ids = torch.tensor(tokenizer.encode(first_word)).unsqueeze(0).to(device)

            for i in range(100):
                outputs = model(cur_ids, labels=cur_ids)
                loss, logits = outputs[:2]
                softmax_logits = torch.softmax(logits[0,-1], dim=0) #Take the first(from only one in this case) batch and the last predicted embedding
                if i < 3:
                    n = 20
                else:
                    n = 3
                next_token_id = choose_from_top(softmax_logits.to('cpu').numpy(), n=n) #Randomly(from the topN probability distribution) select the next word
                cur_ids = torch.cat([cur_ids, torch.ones((1,1)).long().to(device) * next_token_id], dim = 1) # Add the last word to the running sequence

                if next_token_id in tokenizer.encode('<|endoftext|>'):
                    tweet_finished = True
                    break

            
            if tweet_finished:
                
                tweet_num = tweet_num + 1
                
                output_list = list(cur_ids.squeeze().to('cpu').numpy())
                output_text = tokenizer.decode(output_list)
                if post_on_twitter:
                    post_on_twitter(output_text)
                print(output_text)

                with open(output_file_path, 'a', encoding='utf-8') as f:
                    f.write(f"{output_text} \n\n")

                

In [48]:
generate_content(output_file = 'random5_generated.txt', size=1,post_on_twitter=True)

Corporations profit from social media.

The average person has no idea what they're doing.

The average person doesn't care.

The average person thinks they're the best thing they're doing.

The average person doesn't care.<|endoftext|>
Help them love what you are doing.<|endoftext|>


In [None]:
#Start tweets with "Driving" 
generate_content(False, 'Driving ',output_file = 'Driving_generated.txt', size=50)