In [1]:
import pandas as pd
import numpy as np
import calendar
import math
import re
import string

In [2]:
import segmentation
import utils
import data2graph
from finetuned import T5FineTuner, BARTFineTuner, generate, generate_beam, graph2text_nobeam, graph2text_nobeam_ngram_es, graph2text_nobeam_topk, graph2text_nobeam_topp

In [3]:
import textstat
import language_tool_python
from lexical_diversity import lex_div as ld

In [4]:
tool = language_tool_python.LanguageTool('en-US')

def grammar_score(input_text):
    errors = len(tool.check(input_text))
    clean_text = input_text.translate(str.maketrans('', '', string.punctuation))
    clean_text = list(filter(None, clean_text.split(' ')))
    num_words = len(clean_text)
    return float(1-(errors/num_words))

### Loading Fine-Tuned PLMs

In [5]:
import torch
#cuda0 = torch.device("cuda:0")
cuda1 = torch.device("cuda:1")
#cuda3 = torch.device("cuda:3")

t5 = T5FineTuner.load_from_checkpoint("T5Models/T5Both.ckpt")
bart = BARTFineTuner.load_from_checkpoint("BARTModels/BARTBoth.ckpt")

t5.to(cuda1)
bart.to(cuda1)

BARTFineTuner(
  (model): BartForConditionalGeneration(
    (model): BartModel(
      (shared): Embedding(50268, 1024)
      (encoder): BartEncoder(
        (embed_tokens): Embedding(50268, 1024)
        (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)
        (layers): ModuleList(
          (0): EncoderLayer(
            (self_attn): Attention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (fc2): Linear(in_features=4096, out_features=1024, bias=True)
            (final_layer_norm): LayerNorm((1024,), eps=1e

### Pollution

In [6]:
#Import US Pollution Dataset
ds_poll = pd.read_csv("Data/USPollution/USPollution.csv")
ds_poll = ds_poll.dropna()
ds_poll['Date Local'] = pd.to_datetime(ds_poll['Date Local'])
ds_poll['month'] = pd.DatetimeIndex(ds_poll['Date Local']).month
ds_poll['month'] = ds_poll['month'].apply(lambda x: calendar.month_name[x])
ds_poll['year'] = pd.DatetimeIndex(ds_poll['Date Local']).year
ds_poll.set_index(['Date Local'],inplace=True)

In [7]:
#RE Scores
template_re_scores = []
t5_re_scores = []
t5_re_scores_topk = []
t5_re_scores_topp = []
bart_re_scores = []
bart_re_scores_topk = []
bart_re_scores_topp = []

#Diveristy Scores
template_tte_scores = []
t5_tte_scores = []
t5_tte_scores_topk = []
t5_tte_scores_topp = []
bart_tte_scores = []
bart_tte_scores_topk = []
bart_tte_scores_topp = []

#Grammar Scores
t5_g_scores = []
t5_g_scores_topk = []
t5_g_scores_topp = []
bart_g_scores = []
bart_g_scores_topk = []
bart_g_scores_topp = []

#Grammar Mistakes
t5_g_mistake = []
t5_g_mistake_topk = []
t5_g_mistake_topp = []
bart_g_mistake = []
bart_g_mistake_topk = []
bart_g_mistake_topp = []

for i in [1,2,5,8,9,10,11,12,13,15]:
    
    location = ds_poll[ds_poll['State Code']==i]['State'].iloc[[1]][0]
    iso = location
    
    country = ds_poll[ds_poll['State Code']==i][['CO Mean','month', 'year']].reset_index().drop(columns=['Date Local'])
    country_poll = country['CO Mean'].tolist()
    
    print("Processing Country: ", iso)
    
    #Detecting Waves
    embeds, cluster_labels = segmentation.tslr_rep(country_poll)
    cluster_arrangement = utils.find_contiguous(cluster_labels)
    indices = utils.find_indices(cluster_arrangement)
    wave_indices = utils.find_waves(country_poll, indices, tolerance=7)
    
    print("Waves Detected: ", iso)

    #Detecting Trends
    segmentation_results = segmentation.sliding_window(country_poll, 1.5)
    filtered_results = segmentation.re_segment(segmentation_results, country_poll)
    trends = segmentation.find_trend(filtered_results, country_poll)
    
    print("Trends Detected: ", iso)
    
    graph, essentials = data2graph.build_graph_polls_form1("Mean carbon monoxide", location, wave_indices, trends, country, country_poll)
    #Template Narrative
    template_text = data2graph.build_template_poll_nums("Mean carbon monoxide", location, wave_indices, trends, country, country_poll)
    t5_prefix = 'translate Graph to English: '
    
    #Simple PLM Generation
    t5_narrative = graph2text_nobeam(t5, graph, t5_prefix, 512, cuda1)
    bart_narrative = graph2text_nobeam(bart , graph, "", 512, cuda1)
    bart_narrative = re.sub('</s>' , '', bart_narrative)
    
    print("Simple Generation Complete: ", iso)
    
    #Top-k at 50
    t5_narrative_topk = graph2text_nobeam_topk(t5, graph, t5_prefix, 50, 512, cuda1)
    bart_narrative_topk = graph2text_nobeam_topk(bart, graph, "", 50, 512, cuda1)
    bart_narrative_topk = re.sub('</s>' , '', bart_narrative_topk)
    
    print("Top-k Complete: ", iso)
    
    #Top-p at 0.92
    t5_narrative_topp = graph2text_nobeam_topp(t5, graph, t5_prefix, 0.92, 512, cuda1)
    bart_narrative_topp = graph2text_nobeam_topp(bart, graph, "", 0.92, 512, cuda1)
    bart_narrative_topp = re.sub('</s>' , '', bart_narrative_topp)
    
    print("Top-p Complete: ", iso)
    
    #RE Scores
    template_re_scores.append(textstat.flesch_reading_ease(template_text))
    t5_re_scores.append(textstat.flesch_reading_ease(t5_narrative))
    t5_re_scores_topk.append(textstat.flesch_reading_ease(t5_narrative_topk))
    t5_re_scores_topp.append(textstat.flesch_reading_ease(t5_narrative_topp))
    bart_re_scores.append(textstat.flesch_reading_ease(bart_narrative))
    bart_re_scores_topk.append(textstat.flesch_reading_ease(bart_narrative_topk))
    bart_re_scores_topp.append(textstat.flesch_reading_ease(bart_narrative_topp))
    
    print("RE Scores Computed: ", iso)
    
    #Diveristy Scores
    template_tte_scores.append(ld.ttr(ld.flemmatize(template_text)))
    t5_tte_scores.append(ld.ttr(ld.flemmatize(t5_narrative)))
    t5_tte_scores_topk.append(ld.ttr(ld.flemmatize(t5_narrative_topk)))
    t5_tte_scores_topp.append(ld.ttr(ld.flemmatize(t5_narrative_topp)))
    bart_tte_scores.append(ld.ttr(ld.flemmatize(bart_narrative)))
    bart_tte_scores_topk.append(ld.ttr(ld.flemmatize(bart_narrative_topk)))
    bart_tte_scores_topp.append(ld.ttr(ld.flemmatize(bart_narrative_topp)))
    
    print("TTE Scores Computed: ", iso)
    
    #Grammar Scores
    gs = grammar_score(t5_narrative)
    t5_g_scores.append(gs)
    if gs != 1.0:
        t5_g_mistake.append((graph, t5_narrative))
    
    gs = grammar_score(t5_narrative_topk)
    t5_g_scores_topk.append(gs)
    if gs != 1.0:
        t5_g_mistake_topk.append((graph, t5_narrative_topk))
    
    gs = grammar_score(t5_narrative_topp)
    t5_g_scores_topp.append(gs)
    if gs != 1.0:
        t5_g_mistake_topp.append((graph, t5_narrative_topp))
    
    gs = grammar_score(bart_narrative)                          
    bart_g_scores.append(gs)
    if gs != 1.0:
        bart_g_mistake.append((graph, bart_narrative))
        
    gs = grammar_score(bart_narrative_topk)
    bart_g_scores_topk.append(gs)
    if gs != 1.0:
        bart_g_mistake_topk.append((graph, bart_narrative_topk))
    
    gs = grammar_score(bart_narrative_topp)
    bart_g_scores_topp.append(gs)
    if gs != 1.0:
        bart_g_mistake_topp.append((graph, bart_narrative_topp))
    
    print("Grammar Scores Computed: ", iso)

Processing Country:  Alabama
Waves Detected:  Alabama
Trends Detected:  Alabama


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)
  slope = r_num / ssxm
  t = r * np.sqrt(df / ((1.0 - r + TINY)*(1.0 + r + TINY)))
  sterrest = np.sqrt((1 - r**2) * ssym / ssxm / df)


Simple Generation Complete:  Alabama
Top-k Complete:  Alabama
Top-p Complete:  Alabama
RE Scores Computed:  Alabama
TTE Scores Computed:  Alabama
Grammar Scores Computed:  Alabama
Processing Country:  Alaska
Waves Detected:  Alaska
Trends Detected:  Alaska


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Simple Generation Complete:  Alaska
Top-k Complete:  Alaska
Top-p Complete:  Alaska
RE Scores Computed:  Alaska
TTE Scores Computed:  Alaska
Grammar Scores Computed:  Alaska
Processing Country:  Arkansas
Waves Detected:  Arkansas


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Trends Detected:  Arkansas
Simple Generation Complete:  Arkansas
Top-k Complete:  Arkansas
Top-p Complete:  Arkansas
RE Scores Computed:  Arkansas
TTE Scores Computed:  Arkansas
Grammar Scores Computed:  Arkansas
Processing Country:  Colorado
Waves Detected:  Colorado


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Trends Detected:  Colorado
Simple Generation Complete:  Colorado
Top-k Complete:  Colorado
Top-p Complete:  Colorado
RE Scores Computed:  Colorado
TTE Scores Computed:  Colorado
Grammar Scores Computed:  Colorado
Processing Country:  Connecticut
Waves Detected:  Connecticut


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Trends Detected:  Connecticut
Simple Generation Complete:  Connecticut
Top-k Complete:  Connecticut
Top-p Complete:  Connecticut
RE Scores Computed:  Connecticut
TTE Scores Computed:  Connecticut
Grammar Scores Computed:  Connecticut
Processing Country:  Delaware
Waves Detected:  Delaware
Trends Detected:  Delaware


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Simple Generation Complete:  Delaware
Top-k Complete:  Delaware
Top-p Complete:  Delaware
RE Scores Computed:  Delaware
TTE Scores Computed:  Delaware
Grammar Scores Computed:  Delaware
Processing Country:  District Of Columbia
Waves Detected:  District Of Columbia


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Trends Detected:  District Of Columbia
Simple Generation Complete:  District Of Columbia
Top-k Complete:  District Of Columbia
Top-p Complete:  District Of Columbia
RE Scores Computed:  District Of Columbia
TTE Scores Computed:  District Of Columbia
Grammar Scores Computed:  District Of Columbia
Processing Country:  Florida
Waves Detected:  Florida


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Trends Detected:  Florida
Simple Generation Complete:  Florida
Top-k Complete:  Florida
Top-p Complete:  Florida
RE Scores Computed:  Florida
TTE Scores Computed:  Florida
Grammar Scores Computed:  Florida
Processing Country:  Georgia
Waves Detected:  Georgia


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Trends Detected:  Georgia
Simple Generation Complete:  Georgia
Top-k Complete:  Georgia
Top-p Complete:  Georgia
RE Scores Computed:  Georgia
TTE Scores Computed:  Georgia
Grammar Scores Computed:  Georgia
Processing Country:  Hawaii
Waves Detected:  Hawaii


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Trends Detected:  Hawaii
Simple Generation Complete:  Hawaii
Top-k Complete:  Hawaii
Top-p Complete:  Hawaii
RE Scores Computed:  Hawaii
TTE Scores Computed:  Hawaii
Grammar Scores Computed:  Hawaii


In [8]:
#RE Scores
print("*** RE Scores ***")
print("template_re_scores: ", np.mean(template_re_scores))
print("t5_re_scores: ", np.mean(t5_re_scores))
print("t5_re_scores_topk: ", np.mean(t5_re_scores_topk))
print("t5_re_scores_topp: ", np.mean(t5_re_scores_topp))
print("bart_re_scores: ", np.mean(bart_re_scores))
print("bart_re_scores_topk: ", np.mean(bart_re_scores_topk))
print("bart_re_scores_topp: ", np.mean(bart_re_scores_topp))

print("\n")
print("*** Diversity Scores ***")
#Diveristy Scores
print("template_tte_scores: ", np.mean(template_tte_scores))
print("t5_tte_scores: ", np.mean(t5_tte_scores))
print("t5_tte_scores_topk: ", np.mean(t5_tte_scores_topk))
print("t5_tte_scores_topp: ", np.mean(t5_tte_scores_topp))
print("bart_tte_scores: ", np.mean(bart_tte_scores))
print("bart_tte_scores_topk: ", np.mean(bart_tte_scores_topk))
print("bart_tte_scores_topp: ", np.mean(bart_tte_scores_topp))

print("\n")
print("*** Grammar Scores ***")
#Grammar Scores
print("t5_g_scores: ", np.mean(t5_g_scores))
print("t5_g_scores_topk: ", np.mean(t5_g_scores_topk))
print("t5_g_scores_topp: ", np.mean(t5_g_scores_topp))
print("bart_g_scores: ", np.mean(bart_g_scores))
print("bart_g_scores_topk: ", np.mean(bart_g_scores_topk))
print("bart_g_scores_topp: ", np.mean(bart_g_scores_topp))

*** RE Scores ***
template_re_scores:  -129.189
t5_re_scores:  69.22
t5_re_scores_topk:  64.439
t5_re_scores_topp:  65.15700000000001
bart_re_scores:  68.166
bart_re_scores_topk:  72.104
bart_re_scores_topp:  67.35


*** Diversity Scores ***
template_tte_scores:  0.22096959958729262
t5_tte_scores:  0.2788735508735689
t5_tte_scores_topk:  0.3295693755586985
t5_tte_scores_topp:  0.327464381300688
bart_tte_scores:  0.3316713062494865
bart_tte_scores_topk:  0.3281616154218281
bart_tte_scores_topp:  0.32229900575397286


*** Grammar Scores ***
t5_g_scores:  0.999903381642512
t5_g_scores_topk:  0.9968531022364602
t5_g_scores_topp:  0.99963153211185
bart_g_scores:  0.998731967090482
bart_g_scores_topk:  0.9959274873075662
bart_g_scores_topp:  0.9976460717846148


In [9]:
template_re_scores

[64.34, 37.0, -139.28, -320.66, -30.37, 59.67, -315.58, -20.93, 26.48, -652.56]