In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install transformers -q
!pip install sentencepiece -q

# Code of generate_kp(...) from generate_kp.py

In [3]:
import torch
import pandas as pd
from transformers import PegasusTokenizer, PegasusForConditionalGeneration

# generate kp from, arg and topic
def generate_kp(
    arg_topic_df: pd.DataFrame,
    model: PegasusForConditionalGeneration,
    tokenizer: PegasusTokenizer,
    device: str,
    use_topic: bool = False,
    batch_size: int = 16,
    verbose: bool = False,
) -> pd.DataFrame:
    """
    arg_topic_df: df with 'arg' and 'topic' columns
    model: pretrained model
    tokenizer: Pegasus Tokenizer
    device: 'cpu'/'gpu'
    use_topic: use arg+topic to predict kp
    batch_size: batch size (int)
    verbose: print intermediate steps
    ________
    return: df with generated kp,arg, and topic(if available)
    """
    feature_df = arg_topic_df.copy()
    if use_topic:
        feature_df["feature"] = feature_df["arg"] + feature_df["topic"]
    else:
        feature_df["feature"] = feature_df["arg"]

    features = list(feature_df["feature"])
    features_count = len(features)
    batch_count = features_count // batch_size + 1 if (features_count % batch_size)!=0 else 0

    curr_batch_size = 0
    targets = []

    model.to(device)

    with torch.no_grad():
        for batch_id in range(batch_count):
            if (batch_id + 1) * batch_size > features_count:
                curr_batch_size = features_count - batch_id * batch_size
            else:
                curr_batch_size = batch_size

            if verbose:
                print(f"batch_id: {batch_id}")
                print(f"curr_batch_size: {curr_batch_size}")
                print("Tokenizing...")
            tokenized_features = tokenizer(
                features[
                    batch_id * batch_size : batch_id * batch_size + curr_batch_size
                ],
                truncation=True,
                padding="longest",
                return_tensors="pt",
            ).to(device)
            if verbose:
                print("Generating...")
            tokenized_targets = model.generate(**tokenized_features, num_beams=6)
            if verbose:
                print("Decoding...")
            targets += tokenizer.batch_decode(
                tokenized_targets,
                skip_special_tokens=True,
            )

    feature_df["kp_gen"] = targets
    return feature_df

# Generate

In [4]:
dataset_path = '/content/drive/MyDrive/Colab Notebooks/nlp/dataset/'
arg_df = pd.read_csv(dataset_path + 'dataset_austin_sentences.csv')

In [5]:
arg_df = arg_df.rename(columns = {'text':'arg'})

In [6]:
folder_path = "/content/drive/MyDrive/Colab Notebooks/nlp/"
model_path = folder_path + "EMNLP_folder_4/headline_model/"

tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-xsum')
model = PegasusForConditionalGeneration.from_pretrained(model_path,local_files_only=True)

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

### Try small subset 

In [8]:
batch_size = 2

In [9]:
small_arg_df = arg_df.head()
small_arg_df

Unnamed: 0,id,arg,district,year
0,1,"Dissatisfied traffic and with traffic, timing ...",7,2016
1,2,EXTREMELY dissatisfied with cit govt.,7,2016
2,3,"interfering in local businesses (Uber/Lyft, in...",7,2016
3,4,"Also, extremely dissatisfied with all the free...",7,2016
4,5,I'm very dissatisfied with the liberal leaning...,7,2016


In [10]:
small_kp_gen_df = generate_kp(
        small_arg_df,
        model=model,
        device=device,
        tokenizer=tokenizer,
        batch_size=batch_size,
        verbose=True
    )

batch_id: 0
curr_batch_size: 2
Tokenizing...
Generating...
Decoding...
batch_id: 1
curr_batch_size: 2
Tokenizing...
Generating...
Decoding...
batch_id: 2
curr_batch_size: 1
Tokenizing...
Generating...
Decoding...


### Full set

In [11]:
arg_df.head()

Unnamed: 0,id,arg,district,year
0,1,"Dissatisfied traffic and with traffic, timing ...",7,2016
1,2,EXTREMELY dissatisfied with cit govt.,7,2016
2,3,"interfering in local businesses (Uber/Lyft, in...",7,2016
3,4,"Also, extremely dissatisfied with all the free...",7,2016
4,5,I'm very dissatisfied with the liberal leaning...,7,2016


In [12]:
arg_df.shape

(6274, 4)

#### Choose best batch size

In [13]:
%%time
kp_gen_df = generate_kp(
        arg_df[0:16*8],
        model=model,
        device=device,
        tokenizer=tokenizer,
        batch_size=8,
    )

ValueError: ignored

In [None]:
%%time
kp_gen_df = generate_kp(
        arg_df[0:8*16],
        model=model,
        device=device,
        tokenizer=tokenizer,
        batch_size=16,
    )

In [None]:
%%time
kp_gen_df = generate_kp(
        arg_df[0:4*32],
        model=model,
        device=device,
        tokenizer=tokenizer,
        batch_size=32,
    )

In [None]:
%%time
kp_gen_df = generate_kp(
        arg_df[0:2*64],
        model=model,
        device=device,
        tokenizer=tokenizer,
        batch_size=64,
    )

#### Generate

In [14]:
batch_size = 32

In [15]:
%%time
kp_gen_df = generate_kp(
        arg_df,
        model=model,
        device=device,
        tokenizer=tokenizer,
        batch_size=batch_size,
    )

CPU times: user 7min 49s, sys: 1.81 s, total: 7min 51s
Wall time: 7min 52s


In [16]:
kp_gen_df.shape

(6274, 6)

In [17]:
kp_gen_df.head()

Unnamed: 0,id,arg,district,year,feature,kp_gen
0,1,"Dissatisfied traffic and with traffic, timing ...",7,2016,"Dissatisfied traffic and with traffic, timing ...",Traffic is causing a strain on the cities' res...
1,2,EXTREMELY dissatisfied with cit govt.,7,2016,EXTREMELY dissatisfied with cit govt.,Citizens have a right not to vote
2,3,"interfering in local businesses (Uber/Lyft, in...",7,2016,"interfering in local businesses (Uber/Lyft, in...",Hiring a private hire company is financially b...
3,4,"Also, extremely dissatisfied with all the free...",7,2016,"Also, extremely dissatisfied with all the free...",Government intervention has the risk of insert...
4,5,I'm very dissatisfied with the liberal leaning...,7,2016,I'm very dissatisfied with the liberal leaning...,People should choose for themselves whether or...


In [18]:
kp_gen_df.to_csv(dataset_path + 'dataset_austin_sentences_enigma_predictions_2.csv')

# Compare current and previous predictions

In [19]:
prev_kp_gen_df = pd.read_csv(dataset_path + 'dataset_austin_sentences_enigma_predictions.csv')
prev_kp_gen_df.head()

Unnamed: 0.1,Unnamed: 0,text,key_point_pred,district,year
0,0,"Dissatisfied traffic and with traffic, timing ...",Traffic is causing a strain on the cities' res...,7,2016
1,1,EXTREMELY dissatisfied with cit govt.,Citizens have a right not to vote,7,2016
2,2,"interfering in local businesses (Uber/Lyft, in...",Hiring a private hire company is financially b...,7,2016
3,3,"Also, extremely dissatisfied with all the free...",Government intervention has the risk of insert...,7,2016
4,4,I'm very dissatisfied with the liberal leaning...,People should choose for themselves whether or...,7,2016


In [22]:
prev_kp_gen_df = prev_kp_gen_df.rename(columns={'Unnamed: 0':'id'})

In [23]:
prev_kp_gen_df.head()

Unnamed: 0,id,text,key_point_pred,district,year
0,0,"Dissatisfied traffic and with traffic, timing ...",Traffic is causing a strain on the cities' res...,7,2016
1,1,EXTREMELY dissatisfied with cit govt.,Citizens have a right not to vote,7,2016
2,2,"interfering in local businesses (Uber/Lyft, in...",Hiring a private hire company is financially b...,7,2016
3,3,"Also, extremely dissatisfied with all the free...",Government intervention has the risk of insert...,7,2016
4,4,I'm very dissatisfied with the liberal leaning...,People should choose for themselves whether or...,7,2016


In [24]:
prev_kp_gen_df['id']=prev_kp_gen_df['id']+1
prev_kp_gen_df.head()

Unnamed: 0,id,text,key_point_pred,district,year
0,1,"Dissatisfied traffic and with traffic, timing ...",Traffic is causing a strain on the cities' res...,7,2016
1,2,EXTREMELY dissatisfied with cit govt.,Citizens have a right not to vote,7,2016
2,3,"interfering in local businesses (Uber/Lyft, in...",Hiring a private hire company is financially b...,7,2016
3,4,"Also, extremely dissatisfied with all the free...",Government intervention has the risk of insert...,7,2016
4,5,I'm very dissatisfied with the liberal leaning...,People should choose for themselves whether or...,7,2016


In [25]:
kp_gen_df.set_index('id')
prev_kp_gen_df.set_index('id')

Unnamed: 0_level_0,text,key_point_pred,district,year
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,"Dissatisfied traffic and with traffic, timing ...",Traffic is causing a strain on the cities' res...,7,2016
2,EXTREMELY dissatisfied with cit govt.,Citizens have a right not to vote,7,2016
3,"interfering in local businesses (Uber/Lyft, in...",Hiring a private hire company is financially b...,7,2016
4,"Also, extremely dissatisfied with all the free...",Government intervention has the risk of insert...,7,2016
5,I'm very dissatisfied with the liberal leaning...,People should choose for themselves whether or...,7,2016
...,...,...,...,...
6270,You to need to pay teachers better.,Teachers are essential to develop social skills,3,2017
6271,This city is too expensive to live in on the s...,Parents are not qualified as teachers,3,2017
6272,Austin Electric company is a monopoly who trea...,Electricity is a necessity and should be provi...,3,2017
6273,My $200 deposit is being held hostage by the c...,A lot of people are losing their jobs because ...,3,2017


In [31]:
merged_df = pd.merge(kp_gen_df, prev_kp_gen_df, left_index=True, right_index=True)
merged_df.head()

Unnamed: 0,id_x,arg,district_x,year_x,feature,kp_gen,id_y,text,key_point_pred,district_y,year_y
0,1,"Dissatisfied traffic and with traffic, timing ...",7,2016,"Dissatisfied traffic and with traffic, timing ...",Traffic is causing a strain on the cities' res...,1,"Dissatisfied traffic and with traffic, timing ...",Traffic is causing a strain on the cities' res...,7,2016
1,2,EXTREMELY dissatisfied with cit govt.,7,2016,EXTREMELY dissatisfied with cit govt.,Citizens have a right not to vote,2,EXTREMELY dissatisfied with cit govt.,Citizens have a right not to vote,7,2016
2,3,"interfering in local businesses (Uber/Lyft, in...",7,2016,"interfering in local businesses (Uber/Lyft, in...",Hiring a private hire company is financially b...,3,"interfering in local businesses (Uber/Lyft, in...",Hiring a private hire company is financially b...,7,2016
3,4,"Also, extremely dissatisfied with all the free...",7,2016,"Also, extremely dissatisfied with all the free...",Government intervention has the risk of insert...,4,"Also, extremely dissatisfied with all the free...",Government intervention has the risk of insert...,7,2016
4,5,I'm very dissatisfied with the liberal leaning...,7,2016,I'm very dissatisfied with the liberal leaning...,People should choose for themselves whether or...,5,I'm very dissatisfied with the liberal leaning...,People should choose for themselves whether or...,7,2016


In [41]:
len(merged_df[(merged_df['kp_gen']!=merged_df['key_point_pred'])])

18