In [None]:
#https://github.com/falloutdurham/beginners-pytorch-deep-learning/blob/master/chapter9/Chapter9.5.ipynb

In [1]:
pip install fire

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting fire
  Downloading fire-0.4.0.tar.gz (87 kB)
[K     |████████████████████████████████| 87 kB 510 kB/s 
Building wheels for collected packages: fire
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.4.0-py2.py3-none-any.whl size=115942 sha256=e5103301645a3021b5775b35e5988dde3216f1be5135ad54c08f74c25faae7cf
  Stored in directory: /root/.cache/pip/wheels/8a/67/fb/2e8a12fa16661b9d5af1f654bd199366799740a85c64981226
Successfully built fire
Installing collected packages: fire
Successfully installed fire-0.4.0


In [2]:
pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
[K     |████████████████████████████████| 5.5 MB 12.1 MB/s 
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 66.7 MB/s 
[?25hCollecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.10.1-py3-none-any.whl (163 kB)
[K     |████████████████████████████████| 163 kB 65.3 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.10.1 tokenizers-0.13.1 transformers-4.24.0


In [3]:
import numpy as np
import pyarrow.parquet as pq
import pandas as pd
import random
import torch
import fire
import logging
import os
import csv

from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm, trange
import torch.nn.functional as F

In [None]:
tweetsDF["sentiment_cat"] = tweetsDF[0].astype('category')
tweetsDF["sentiment"] = tweetsDF["sentiment_cat"].cat.codes
tweetsDF.to_csv("train-processed.csv", header=None, index=None) 
tweetsDF.sample(10000).to_csv("train-processed-sample.csv", header=None, index=None) 

NameError: ignored

In [None]:
class ParquetDataset(Dataset):
    def __init__(self, path, cols, truncate=False, gpt2_type="gpt2", max_length=768):

        # Grab our pandas dataframe, only reading in the columns we're interested in,
        # append our magic tokens (<#col_name#> for the particular column, and <|endoftext|>
        # used by GPT-2 as a text separator), then concatenate them into one giant column for
        # our dataset

        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
        
        self.df = pq.read_table(path, columns=cols).to_pandas().dropna()
        for col in cols:
            self.df[col] = self.df[col].apply(lambda x: torch.tensor(self.tokenizer.encode(f"<#{col}#>{x[:768]}<|endoftext|>")))
        self.df = pd.concat(map(self.df.get, cols)).reset_index(drop=True)
        if truncate:
            self.df = self.df.truncate(after=150)

    def __len__(self):
        return self.df.count()

    def __getitem__(self, item):
        return self.df.iloc[item]

In [4]:
class CSVTwitter(Dataset):
    
    def __init__(self, control_code, truncate=False, gpt2_type="gpt2", max_length=768):

        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
        self.tweets = []

        # This uses the same CSV of Sentiment140 that we created in Chapter 5
        
        with open('train-processed-sample.csv', newline='') as csvfile:
            tweet_csv = csv.reader(csvfile)
            for row in tweet_csv:
                self.tweets.append(torch.tensor(
                    self.tokenizer.encode(f"<|{control_code}|>{row[5][:max_length]}<|endoftext|>")
                ))
                
        if truncate:
            self.tweets = self.tweets[:20000]
        self.tweet_count = len(self.tweets)
        
    def __len__(self):
        return self.tweet_count

    def __getitem__(self, item):
        return self.tweets[item]

In [5]:
def pack_tensor(new_tensor, packed_tensor, max_seq_len):
    if packed_tensor is None:
        return new_tensor, True, None
    if new_tensor.size()[1] + packed_tensor.size()[1] > max_seq_len:
        return packed_tensor, False, new_tensor
    else:
        packed_tensor = torch.cat([new_tensor, packed_tensor[:, 1:]], dim=1)
        return packed_tensor, True, None

In [14]:
def train(
    dataset,
    model,
    tokenizer,
    batch_size=16,
    epochs=4,
    lr=2e-5,
    max_seq_len=400,
    warmup_steps=5000,
    gpt2_type="gpt2",
    device="cuda",
    output_dir=".",
    output_prefix="wreckgar",
    test_mode=False,
    save_model_on_epoch=False,
):

    acc_steps = 100

    model = model.to(device)
    model.train()

    optimizer = AdamW(model.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=-1
    )

    train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    accumulating_batch_count = 0
    input_tensor = None

    for epoch in range(epochs):

        print(f"Training epoch {epoch}")
        for idx, entry in tqdm(enumerate(train_dataloader)):
            (input_tensor, carry_on, remainder) = pack_tensor(entry, input_tensor, 768)

            if carry_on and idx != len(train_dataloader) - 1:
                continue

            input_tensor = input_tensor.to(device)
            print("shape: ",input_tensor.shape)
            outputs = model(input_tensor, labels=input_tensor)
            loss = outputs[0]
            loss.backward()

            #if (accumulating_batch_count % batch_size) == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            model.zero_grad()

            accumulating_batch_count += 1
            input_tensor = None
        if save_model_on_epoch:
            torch.save(
                model.state_dict(),
                os.path.join(output_dir, f"{output_prefix}-{epoch}.pt"),
            )
    return model


In [7]:
dataset = CSVTwitter("<|tweet|>", truncate=True, gpt2_type="gpt2")
gpt2_type = "gpt2"

Downloading:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]

In [12]:
len(dataset.tweets[2])

30

In [15]:
model = train(
    dataset,
    GPT2LMHeadModel.from_pretrained(gpt2_type),
    GPT2Tokenizer.from_pretrained(gpt2_type),
    batch_size=16,
    epochs=5,
    lr=3e-5,
    max_seq_len=140,
    warmup_steps=5000,
    gpt2_type=gpt2_type,
    device="cuda",
    output_dir="",
    output_prefix="twitter",
    save_model_on_epoch=True
)



Training epoch 0


0it [00:00, ?it/s]

shape:  torch.Size([1, 761])


24it [00:00, 74.39it/s]

shape:  torch.Size([1, 748])


51it [00:00, 86.56it/s]

shape:  torch.Size([1, 755])


74it [00:00, 83.71it/s]

shape:  torch.Size([1, 732])


96it [00:01, 81.31it/s]

shape:  torch.Size([1, 762])


125it [00:01, 86.87it/s]

shape:  torch.Size([1, 757])


150it [00:01, 86.61it/s]

shape:  torch.Size([1, 762])


174it [00:02, 85.40it/s]

shape:  torch.Size([1, 744])


202it [00:02, 87.94it/s]

shape:  torch.Size([1, 759])


230it [00:02, 90.37it/s]

shape:  torch.Size([1, 762])


255it [00:02, 89.29it/s]

shape:  torch.Size([1, 740])


280it [00:03, 87.84it/s]

shape:  torch.Size([1, 759])


306it [00:03, 87.17it/s]

shape:  torch.Size([1, 755])


332it [00:03, 87.85it/s]

shape:  torch.Size([1, 766])


356it [00:04, 85.47it/s]

shape:  torch.Size([1, 758])


381it [00:04, 84.64it/s]

shape:  torch.Size([1, 744])


404it [00:04, 83.27it/s]

shape:  torch.Size([1, 735])


429it [00:04, 84.69it/s]

shape:  torch.Size([1, 758])


454it [00:05, 84.05it/s]

shape:  torch.Size([1, 732])


480it [00:05, 85.51it/s]

shape:  torch.Size([1, 762])


504it [00:05, 84.10it/s]

shape:  torch.Size([1, 750])


529it [00:06, 84.30it/s]

shape:  torch.Size([1, 763])


555it [00:06, 84.28it/s]

shape:  torch.Size([1, 754])


581it [00:06, 84.96it/s]

shape:  torch.Size([1, 749])


608it [00:07, 87.03it/s]

shape:  torch.Size([1, 750])


632it [00:07, 84.78it/s]

shape:  torch.Size([1, 755])


659it [00:07, 86.23it/s]

shape:  torch.Size([1, 755])


683it [00:07, 84.25it/s]

shape:  torch.Size([1, 762])


708it [00:08, 83.77it/s]

shape:  torch.Size([1, 737])


735it [00:08, 86.00it/s]

shape:  torch.Size([1, 747])


758it [00:08, 83.43it/s]

shape:  torch.Size([1, 758])


782it [00:09, 82.16it/s]

shape:  torch.Size([1, 745])


807it [00:09, 82.27it/s]

shape:  torch.Size([1, 764])


833it [00:09, 83.55it/s]

shape:  torch.Size([1, 755])


859it [00:10, 84.38it/s]

shape:  torch.Size([1, 758])


883it [00:10, 82.82it/s]

shape:  torch.Size([1, 726])


906it [00:10, 81.38it/s]

shape:  torch.Size([1, 740])


930it [00:10, 81.24it/s]

shape:  torch.Size([1, 745])


956it [00:11, 82.65it/s]

shape:  torch.Size([1, 766])


980it [00:11, 81.32it/s]

shape:  torch.Size([1, 740])


1006it [00:11, 82.92it/s]

shape:  torch.Size([1, 753])


1028it [00:12, 79.70it/s]

shape:  torch.Size([1, 761])


1052it [00:12, 79.16it/s]

shape:  torch.Size([1, 734])


1076it [00:12, 79.36it/s]

shape:  torch.Size([1, 754])


1100it [00:13, 79.15it/s]

shape:  torch.Size([1, 763])


1126it [00:13, 80.61it/s]

shape:  torch.Size([1, 746])


1150it [00:13, 80.22it/s]

shape:  torch.Size([1, 753])


1176it [00:14, 81.79it/s]

shape:  torch.Size([1, 748])


1201it [00:14, 81.90it/s]

shape:  torch.Size([1, 734])


1227it [00:14, 83.38it/s]

shape:  torch.Size([1, 749])


1251it [00:14, 81.73it/s]

shape:  torch.Size([1, 756])


1275it [00:15, 80.55it/s]

shape:  torch.Size([1, 762])


1302it [00:15, 82.59it/s]

shape:  torch.Size([1, 741])


1328it [00:15, 83.95it/s]

shape:  torch.Size([1, 747])


1352it [00:16, 82.37it/s]

shape:  torch.Size([1, 754])


1376it [00:16, 81.15it/s]

shape:  torch.Size([1, 734])


1398it [00:16, 78.86it/s]

shape:  torch.Size([1, 766])


1424it [00:17, 80.18it/s]

shape:  torch.Size([1, 747])


1447it [00:17, 78.79it/s]

shape:  torch.Size([1, 760])


1471it [00:17, 78.66it/s]

shape:  torch.Size([1, 736])


1495it [00:17, 79.44it/s]

shape:  torch.Size([1, 766])


1521it [00:18, 80.81it/s]

shape:  torch.Size([1, 744])


1544it [00:18, 79.50it/s]

shape:  torch.Size([1, 752])


1569it [00:18, 80.22it/s]

shape:  torch.Size([1, 753])


1593it [00:19, 79.53it/s]

shape:  torch.Size([1, 761])


1618it [00:19, 80.10it/s]

shape:  torch.Size([1, 761])


1644it [00:19, 81.70it/s]

shape:  torch.Size([1, 748])


1672it [00:20, 85.30it/s]

shape:  torch.Size([1, 752])


1698it [00:20, 85.37it/s]

shape:  torch.Size([1, 742])


1724it [00:20, 85.51it/s]

shape:  torch.Size([1, 754])


1749it [00:21, 84.26it/s]

shape:  torch.Size([1, 743])


1773it [00:21, 82.84it/s]

shape:  torch.Size([1, 731])


1796it [00:21, 81.26it/s]

shape:  torch.Size([1, 760])


1820it [00:21, 80.71it/s]

shape:  torch.Size([1, 761])


1848it [00:22, 84.31it/s]

shape:  torch.Size([1, 748])


1872it [00:22, 83.13it/s]

shape:  torch.Size([1, 750])


1896it [00:22, 81.90it/s]

shape:  torch.Size([1, 724])


1920it [00:23, 82.14it/s]

shape:  torch.Size([1, 729])


1943it [00:23, 80.78it/s]

shape:  torch.Size([1, 753])


1968it [00:23, 81.67it/s]

shape:  torch.Size([1, 751])


1993it [00:24, 82.47it/s]

shape:  torch.Size([1, 742])


2020it [00:24, 84.83it/s]

shape:  torch.Size([1, 751])


2046it [00:24, 85.65it/s]

shape:  torch.Size([1, 766])


2074it [00:24, 87.84it/s]

shape:  torch.Size([1, 733])


2099it [00:25, 87.06it/s]

shape:  torch.Size([1, 737])


2125it [00:25, 87.44it/s]

shape:  torch.Size([1, 765])


2149it [00:25, 84.88it/s]

shape:  torch.Size([1, 753])


2172it [00:26, 82.67it/s]

shape:  torch.Size([1, 744])


2196it [00:26, 82.73it/s]

shape:  torch.Size([1, 731])


2223it [00:26, 85.38it/s]

shape:  torch.Size([1, 751])


2247it [00:26, 84.27it/s]

shape:  torch.Size([1, 762])


2270it [00:27, 81.81it/s]

shape:  torch.Size([1, 757])


2297it [00:27, 84.76it/s]

shape:  torch.Size([1, 751])


2322it [00:27, 84.70it/s]

shape:  torch.Size([1, 748])


2348it [00:28, 85.91it/s]

shape:  torch.Size([1, 754])


2373it [00:28, 85.71it/s]

shape:  torch.Size([1, 753])


2398it [00:28, 85.49it/s]

shape:  torch.Size([1, 755])


2422it [00:29, 84.24it/s]

shape:  torch.Size([1, 767])


2449it [00:29, 85.74it/s]

shape:  torch.Size([1, 758])


2474it [00:29, 85.80it/s]

shape:  torch.Size([1, 751])


2500it [00:29, 86.39it/s]

shape:  torch.Size([1, 764])


2523it [00:30, 83.89it/s]

shape:  torch.Size([1, 753])


2548it [00:30, 84.01it/s]

shape:  torch.Size([1, 748])


2574it [00:30, 85.77it/s]

shape:  torch.Size([1, 749])


2603it [00:31, 89.81it/s]

shape:  torch.Size([1, 767])


2629it [00:31, 88.85it/s]

shape:  torch.Size([1, 754])


2652it [00:31, 85.45it/s]

shape:  torch.Size([1, 747])


2677it [00:32, 85.96it/s]

shape:  torch.Size([1, 767])


2705it [00:32, 88.70it/s]

shape:  torch.Size([1, 717])


2727it [00:32, 85.54it/s]

shape:  torch.Size([1, 739])


2752it [00:32, 85.64it/s]

shape:  torch.Size([1, 759])


2776it [00:33, 84.52it/s]

shape:  torch.Size([1, 761])


2801it [00:33, 84.36it/s]

shape:  torch.Size([1, 728])


2824it [00:33, 83.41it/s]

shape:  torch.Size([1, 755])


2854it [00:34, 89.31it/s]

shape:  torch.Size([1, 766])


2879it [00:34, 88.20it/s]

shape:  torch.Size([1, 750])


2906it [00:34, 89.41it/s]

shape:  torch.Size([1, 729])


2929it [00:34, 87.06it/s]

shape:  torch.Size([1, 767])


2958it [00:35, 90.87it/s]

shape:  torch.Size([1, 762])


2983it [00:35, 89.25it/s]

shape:  torch.Size([1, 755])


3006it [00:35, 86.09it/s]

shape:  torch.Size([1, 759])


3030it [00:36, 84.99it/s]

shape:  torch.Size([1, 754])


3055it [00:36, 85.56it/s]

shape:  torch.Size([1, 755])


3078it [00:36, 77.15it/s]

shape:  torch.Size([1, 752])


3102it [00:37, 77.03it/s]

shape:  torch.Size([1, 749])


3126it [00:37, 78.83it/s]

shape:  torch.Size([1, 720])


3150it [00:37, 80.69it/s]

shape:  torch.Size([1, 758])


3177it [00:37, 79.11it/s]

shape:  torch.Size([1, 766])


3203it [00:38, 82.85it/s]

shape:  torch.Size([1, 740])


3228it [00:38, 84.06it/s]

shape:  torch.Size([1, 738])


3252it [00:38, 83.05it/s]

shape:  torch.Size([1, 743])


3275it [00:39, 82.67it/s]

shape:  torch.Size([1, 764])


3299it [00:39, 82.97it/s]

shape:  torch.Size([1, 744])


3327it [00:39, 86.23it/s]

shape:  torch.Size([1, 756])


3353it [00:39, 87.67it/s]

shape:  torch.Size([1, 741])


3375it [00:40, 85.17it/s]

shape:  torch.Size([1, 728])


3400it [00:40, 86.48it/s]

shape:  torch.Size([1, 766])


3426it [00:40, 87.02it/s]

shape:  torch.Size([1, 745])


3449it [00:41, 85.52it/s]

shape:  torch.Size([1, 759])


3474it [00:41, 86.14it/s]

shape:  torch.Size([1, 732])


3499it [00:41, 86.60it/s]

shape:  torch.Size([1, 754])


3525it [00:41, 87.87it/s]

shape:  torch.Size([1, 740])


3551it [00:42, 89.54it/s]

shape:  torch.Size([1, 753])


3576it [00:42, 88.60it/s]

shape:  torch.Size([1, 762])


3603it [00:42, 89.80it/s]

shape:  torch.Size([1, 733])


3626it [00:43, 87.65it/s]

shape:  torch.Size([1, 767])


3654it [00:43, 90.58it/s]

shape:  torch.Size([1, 748])


3681it [00:43, 91.46it/s]

shape:  torch.Size([1, 767])


3703it [00:43, 86.87it/s]

shape:  torch.Size([1, 764])


3724it [00:44, 83.03it/s]

shape:  torch.Size([1, 740])


3749it [00:44, 84.92it/s]

shape:  torch.Size([1, 753])


3775it [00:44, 86.44it/s]

shape:  torch.Size([1, 758])


3798it [00:45, 84.75it/s]

shape:  torch.Size([1, 734])


3823it [00:45, 86.45it/s]

shape:  torch.Size([1, 760])


3850it [00:45, 88.41it/s]

shape:  torch.Size([1, 747])


3872it [00:45, 85.08it/s]

shape:  torch.Size([1, 735])


3893it [00:46, 82.23it/s]

shape:  torch.Size([1, 760])


3916it [00:46, 81.66it/s]

shape:  torch.Size([1, 764])


3942it [00:46, 83.92it/s]

shape:  torch.Size([1, 731])


3967it [00:47, 85.54it/s]

shape:  torch.Size([1, 758])


3994it [00:47, 88.20it/s]

shape:  torch.Size([1, 763])


4020it [00:47, 88.74it/s]

shape:  torch.Size([1, 752])


4044it [00:47, 87.35it/s]

shape:  torch.Size([1, 722])


4067it [00:48, 86.13it/s]

shape:  torch.Size([1, 767])


4094it [00:48, 88.55it/s]

shape:  torch.Size([1, 761])


4122it [00:48, 90.99it/s]

shape:  torch.Size([1, 761])


4148it [00:49, 90.77it/s]

shape:  torch.Size([1, 724])


4172it [00:49, 89.62it/s]

shape:  torch.Size([1, 746])


4197it [00:49, 89.48it/s]

shape:  torch.Size([1, 729])


4223it [00:49, 90.83it/s]

shape:  torch.Size([1, 767])


4246it [00:50, 87.08it/s]

shape:  torch.Size([1, 749])


4272it [00:50, 88.82it/s]

shape:  torch.Size([1, 759])


4298it [00:50, 89.27it/s]

shape:  torch.Size([1, 761])


4324it [00:51, 89.74it/s]

shape:  torch.Size([1, 754])


4349it [00:51, 89.49it/s]

shape:  torch.Size([1, 760])


4376it [00:51, 91.34it/s]

shape:  torch.Size([1, 755])


4403it [00:51, 92.31it/s]

shape:  torch.Size([1, 747])


4430it [00:52, 93.12it/s]

shape:  torch.Size([1, 764])


4457it [00:52, 93.68it/s]

shape:  torch.Size([1, 760])


4481it [00:52, 90.82it/s]

shape:  torch.Size([1, 748])


4505it [00:53, 88.92it/s]

shape:  torch.Size([1, 764])


4531it [00:53, 89.53it/s]

shape:  torch.Size([1, 748])


4557it [00:53, 90.34it/s]

shape:  torch.Size([1, 739])


4581it [00:53, 88.86it/s]

shape:  torch.Size([1, 748])


4609it [00:54, 92.31it/s]

shape:  torch.Size([1, 736])


4636it [00:54, 93.95it/s]

shape:  torch.Size([1, 757])


4663it [00:54, 94.08it/s]

shape:  torch.Size([1, 746])


4690it [00:55, 94.46it/s]

shape:  torch.Size([1, 735])


4714it [00:55, 91.75it/s]

shape:  torch.Size([1, 744])


4737it [00:55, 88.79it/s]

shape:  torch.Size([1, 722])


4761it [00:55, 88.45it/s]

shape:  torch.Size([1, 758])


4787it [00:56, 89.32it/s]

shape:  torch.Size([1, 762])


4816it [00:56, 93.24it/s]

shape:  torch.Size([1, 759])


4841it [00:56, 91.64it/s]

shape:  torch.Size([1, 738])


4868it [00:57, 93.23it/s]

shape:  torch.Size([1, 739])


4894it [00:57, 92.84it/s]

shape:  torch.Size([1, 747])


4918it [00:57, 90.47it/s]

shape:  torch.Size([1, 754])


4944it [00:57, 91.22it/s]

shape:  torch.Size([1, 732])


4966it [00:58, 87.94it/s]

shape:  torch.Size([1, 750])


4991it [00:58, 87.97it/s]

shape:  torch.Size([1, 728])


5017it [00:58, 89.55it/s]

shape:  torch.Size([1, 755])


5045it [00:58, 92.75it/s]

shape:  torch.Size([1, 738])


5069it [00:59, 90.83it/s]

shape:  torch.Size([1, 750])


5091it [00:59, 80.33it/s]

shape:  torch.Size([1, 743])


5117it [00:59, 83.00it/s]

shape:  torch.Size([1, 753])


5141it [01:00, 83.84it/s]

shape:  torch.Size([1, 761])


5165it [01:00, 82.62it/s]

shape:  torch.Size([1, 749])


5186it [01:00, 80.63it/s]

shape:  torch.Size([1, 726])


5209it [01:01, 81.70it/s]

shape:  torch.Size([1, 747])


5234it [01:01, 83.85it/s]

shape:  torch.Size([1, 757])


5258it [01:01, 84.01it/s]

shape:  torch.Size([1, 764])


5283it [01:01, 85.46it/s]

shape:  torch.Size([1, 752])


5306it [01:02, 84.75it/s]

shape:  torch.Size([1, 752])


5331it [01:02, 85.79it/s]

shape:  torch.Size([1, 765])


5355it [01:02, 85.12it/s]

shape:  torch.Size([1, 753])


5377it [01:03, 83.03it/s]

shape:  torch.Size([1, 746])


5403it [01:03, 85.89it/s]

shape:  torch.Size([1, 750])


5429it [01:03, 87.54it/s]

shape:  torch.Size([1, 763])


5456it [01:03, 89.49it/s]

shape:  torch.Size([1, 739])


5480it [01:04, 88.42it/s]

shape:  torch.Size([1, 747])


5507it [01:04, 90.29it/s]

shape:  torch.Size([1, 732])


5531it [01:04, 89.27it/s]

shape:  torch.Size([1, 765])


5555it [01:04, 87.95it/s]

shape:  torch.Size([1, 764])


5579it [01:05, 87.10it/s]

shape:  torch.Size([1, 762])


5605it [01:05, 88.35it/s]

shape:  torch.Size([1, 759])


5629it [01:05, 87.07it/s]

shape:  torch.Size([1, 762])


5656it [01:06, 89.58it/s]

shape:  torch.Size([1, 754])


5680it [01:06, 88.24it/s]

shape:  torch.Size([1, 755])


5704it [01:06, 86.83it/s]

shape:  torch.Size([1, 750])


5727it [01:06, 85.07it/s]

shape:  torch.Size([1, 759])


5752it [01:07, 86.05it/s]

shape:  torch.Size([1, 760])


5774it [01:07, 83.27it/s]

shape:  torch.Size([1, 758])


5800it [01:07, 85.58it/s]

shape:  torch.Size([1, 751])


5826it [01:08, 87.50it/s]

shape:  torch.Size([1, 740])


5850it [01:08, 87.14it/s]

shape:  torch.Size([1, 758])


5874it [01:08, 85.59it/s]


KeyboardInterrupt: ignored

In [None]:
torch.save(model.state_dict(), 'twitter.pt')

In [None]:
def generate(
    model,
    tokenizer,
    prompt,
    entry_count=10,
    entry_length=100,
    top_p=0.8,
    temperature=1.,
):

    model.eval()

    generated_num = 0
    generated_list = []

    filter_value = -float("Inf")

    with torch.no_grad():

        for entry_idx in trange(entry_count):

            entry_finished = False

            generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)

            # Using top-p (nucleus sampling): https://github.com/huggingface/transformers/blob/master/examples/run_generation.py

            for i in range(entry_length):
                outputs = model(generated, labels=generated)
                loss, logits = outputs[:2]
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)

                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )

                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value

                next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
                generated = torch.cat((generated, next_token), dim=1)

                if next_token in tokenizer.encode("<|endoftext|>"):
                    entry_finished = True

                if entry_finished:

                    generated_num = generated_num + 1

                    output_list = list(generated.squeeze().numpy())
                    output_text = tokenizer.decode(output_list)

                    generated_list.append(output_text)
                    break
            
            if not entry_finished:
                output_list = list(generated.squeeze().numpy())
                output_text = f"{tokenizer.decode(output_list)}<|endoftext|>" 
                generated_list.append(output_text)
                
    return generated_list


In [None]:
generated_tweets = generate(model.to('cpu'), GPT2Tokenizer.from_pretrained(gpt2_type),"<|tweet|>",entry_count=10)

100%|██████████| 10/10 [00:40<00:00,  4.03s/it]


In [None]:
generated_tweets

['<|tweet|>|>still listening to #chicky-pop <|endoftext|>',
 '<|tweet|>|>@MoraJEilkes Good luck if you get to go to the Pratap http://bit.ly/11hcRl<|endoftext|>',
 '<|tweet|>|>@bradthewl <|endoftext|>',
 "<|tweet|>|>@CalebGrayfish This time you're not alone  and will not be going. They take care of you.  @Carolina_Cameron<|endoftext|>",
 '<|tweet|>|>The twitter followers are brilliant. #testmywork #mondaysofy Twitter is off to a hot start! #failtopymonday<|endoftext|>',
 "<|tweet|>|>@meganroose well not like that either  i'mnt into that either. i get there at 3pm so I am better off working then catching up <|endoftext|>",
 '<|tweet|>|>Haha really wanted this since last night.  Nice Night! <|endoftext|>',
 '<|tweet|>|>@CloserYard\n\n|>just got my PhD today, im currently studying to have a good English language and I think it will give me more of an edge <|endoftext|>',
 '<|tweet|>|>@phillipatherer I think you have a very interesting talent <|endoftext|>',
 "<|tweet|>@dhaeson I hope it w