In [1]:
import pandas as pd
import numpy as np
import calendar
import math
import re
import string

In [2]:
import segmentation
import utils
import data2graph
from finetuned import T5FineTuner, BARTFineTuner, generate, generate_beam, graph2text_nobeam, graph2text_nobeam_ngram_es, graph2text_nobeam_topk, graph2text_nobeam_topp

In [3]:
import textstat
import language_tool_python
from lexical_diversity import lex_div as ld

In [4]:
tool = language_tool_python.LanguageTool('en-US')

def grammar_score(input_text):
    errors = len(tool.check(input_text))
    clean_text = input_text.translate(str.maketrans('', '', string.punctuation))
    clean_text = list(filter(None, clean_text.split(' ')))
    num_words = len(clean_text)
    return float(1-(errors/num_words))

### Loading Fine-Tuned PLMs

In [5]:
import torch
cuda0 = torch.device("cuda:0")
# cuda1 = torch.device("cuda:1")
#cuda3 = torch.device("cuda:3")

t5 = T5FineTuner.load_from_checkpoint("T5Models/T5Both.ckpt")
bart = BARTFineTuner.load_from_checkpoint("BARTModels/BARTBoth.ckpt")

t5.to(cuda0)
bart.to(cuda0)

BARTFineTuner(
  (model): BartForConditionalGeneration(
    (model): BartModel(
      (shared): Embedding(50268, 1024)
      (encoder): BartEncoder(
        (embed_tokens): Embedding(50268, 1024)
        (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)
        (layers): ModuleList(
          (0): EncoderLayer(
            (self_attn): Attention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (fc2): Linear(in_features=4096, out_features=1024, bias=True)
            (final_layer_norm): LayerNorm((1024,), eps=1e

### DOTS

In [6]:
#Import DOTS Dataset
ds_dots = pd.read_csv("Data/DOTS/Exports.csv")

In [7]:
countries = ['United States', 'India', 'Brazil', 'USSR', 'United Kingdom', 'France', 'Spain', 'Italy' , 'Turkey', 'Germany']

#RE Scores
template_re_scores = []
t5_re_scores = []
t5_re_scores_topk = []
t5_re_scores_topp = []
bart_re_scores = []
bart_re_scores_topk = []
bart_re_scores_topp = []

#Diveristy Scores
template_tte_scores = []
t5_tte_scores = []
t5_tte_scores_topk = []
t5_tte_scores_topp = []
bart_tte_scores = []
bart_tte_scores_topk = []
bart_tte_scores_topp = []

#Grammar Scores
t5_g_scores = []
t5_g_scores_topk = []
t5_g_scores_topp = []
bart_g_scores = []
bart_g_scores_topk = []
bart_g_scores_topp = []

#Grammar Mistakes
t5_g_mistake = []
t5_g_mistake_topk = []
t5_g_mistake_topp = []
bart_g_mistake = []
bart_g_mistake_topk = []
bart_g_mistake_topp = []

for iso in countries:
    
    #Load and Process Data
    country = ds_dots.loc[ds_dots['Location'] == iso]
    time = country.columns.tolist()[1:]
    for row in country.iterrows():
        values = row[1]
    values = [str(x) for x in values]
    values = [float(re.sub(',', '', x)) for x in values[1:]]
    time = [re.sub('M', '-', x) for x in time]
    country = pd.DataFrame(list(zip(time, values)), columns = ['Date', 'Exports'])
    
    country = country.fillna(0)
    country.Date = pd.to_datetime(country.Date)
    country['month'] = pd.DatetimeIndex(country['Date']).month
    country['month'] = country['month'].apply(lambda x: calendar.month_name[x])
    country['year'] = pd.DatetimeIndex(country['Date']).year
    country.set_index(['Date'],inplace=True)
    
    country = country[['Exports','month', 'year']].reset_index().drop(columns=['Date'])
    country_exports_raw = country['Exports'].tolist()

    #Log-normalize data
    trans = np.ma.log(country_exports_raw)
    country_exports = trans.filled(0)
    
    print("Processing Country: ", iso)
    
    #Detecting Waves
    embeds, cluster_labels = segmentation.tslr_rep(country_exports, k=3, tolerance=1e-4)
    cluster_arrangement = utils.find_contiguous(cluster_labels)
    indices = utils.find_indices(cluster_arrangement)
    wave_indices = utils.find_waves(country_exports_raw, indices, tolerance=7)
    
    print("Waves Detected: ", iso)

    #Detecting Trends
    segmentation_results = segmentation.swab(country_exports, 0.1, 3, 3)
    filtered_results = segmentation.re_segment(segmentation_results, country_exports)
    trends = segmentation.find_trend(filtered_results, country_exports)
    
    print("Trends Detected: ", iso)
    
    location = iso
    
    graph, essentials = data2graph.build_graph_exports_form1("Merchandise Exports", location, wave_indices, trends, country, country_exports_raw)
    #Template Narrative
    template_text = data2graph.build_template_exports_nums("Merchandise Exports", location, wave_indices, trends, country, country_exports_raw)
    t5_prefix = 'translate Graph to English: '
    
    #Simple PLM Generation
    t5_narrative = graph2text_nobeam(t5, graph, t5_prefix, 512, cuda0)
    bart_narrative = graph2text_nobeam(bart , graph, "", 512, cuda0)
    bart_narrative = re.sub('</s>' , '', bart_narrative)
    
    print("Simple Generation Complete: ", iso)
    
    #Top-k at 50
    t5_narrative_topk = graph2text_nobeam_topk(t5, graph, t5_prefix, 50, 512, cuda0)
    bart_narrative_topk = graph2text_nobeam_topk(bart, graph, "", 50, 512, cuda0)
    bart_narrative_topk = re.sub('</s>' , '', bart_narrative_topk)
    
    print("Top-k Complete: ", iso)
    
    #Top-p at 0.92
    t5_narrative_topp = graph2text_nobeam_topp(t5, graph, t5_prefix, 0.92, 512, cuda0)
    bart_narrative_topp = graph2text_nobeam_topp(bart, graph, "", 0.92, 512, cuda0)
    bart_narrative_topp = re.sub('</s>' , '', bart_narrative_topp)
    
    print("Top-p Complete: ", iso)
    
    #RE Scores
    template_re_scores.append(textstat.flesch_reading_ease(template_text))
    t5_re_scores.append(textstat.flesch_reading_ease(t5_narrative))
    t5_re_scores_topk.append(textstat.flesch_reading_ease(t5_narrative_topk))
    t5_re_scores_topp.append(textstat.flesch_reading_ease(t5_narrative_topp))
    bart_re_scores.append(textstat.flesch_reading_ease(bart_narrative))
    bart_re_scores_topk.append(textstat.flesch_reading_ease(bart_narrative_topk))
    bart_re_scores_topp.append(textstat.flesch_reading_ease(bart_narrative_topp))
    
    print("RE Scores Computed: ", iso)
    
    #Diveristy Scores
    template_tte_scores.append(ld.ttr(ld.flemmatize(template_text)))
    t5_tte_scores.append(ld.ttr(ld.flemmatize(t5_narrative)))
    t5_tte_scores_topk.append(ld.ttr(ld.flemmatize(t5_narrative_topk)))
    t5_tte_scores_topp.append(ld.ttr(ld.flemmatize(t5_narrative_topp)))
    bart_tte_scores.append(ld.ttr(ld.flemmatize(bart_narrative)))
    bart_tte_scores_topk.append(ld.ttr(ld.flemmatize(bart_narrative_topk)))
    bart_tte_scores_topp.append(ld.ttr(ld.flemmatize(bart_narrative_topp)))
    
    print("TTE Scores Computed: ", iso)
    
    #Grammar Scores
    gs = grammar_score(t5_narrative)
    t5_g_scores.append(gs)
    if gs != 1.0:
        t5_g_mistake.append((graph, t5_narrative))
    
    gs = grammar_score(t5_narrative_topk)
    t5_g_scores_topk.append(gs)
    if gs != 1.0:
        t5_g_mistake_topk.append((graph, t5_narrative_topk))
    
    gs = grammar_score(t5_narrative_topp)
    t5_g_scores_topp.append(gs)
    if gs != 1.0:
        t5_g_mistake_topp.append((graph, t5_narrative_topp))
    
    gs = grammar_score(bart_narrative)                          
    bart_g_scores.append(gs)
    if gs != 1.0:
        bart_g_mistake.append((graph, bart_narrative))
        
    gs = grammar_score(bart_narrative_topk)
    bart_g_scores_topk.append(gs)
    if gs != 1.0:
        bart_g_mistake_topk.append((graph, bart_narrative_topk))
    
    gs = grammar_score(bart_narrative_topp)
    bart_g_scores_topp.append(gs)
    if gs != 1.0:
        bart_g_mistake_topp.append((graph, bart_narrative_topp))
    
    print("Grammar Scores Computed: ", iso)

Processing Country:  United States
Waves Detected:  United States


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Trends Detected:  United States
Simple Generation Complete:  United States
Top-k Complete:  United States
Top-p Complete:  United States
RE Scores Computed:  United States
TTE Scores Computed:  United States
Grammar Scores Computed:  United States
Processing Country:  India
Waves Detected:  India


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Trends Detected:  India
Simple Generation Complete:  India
Top-k Complete:  India
Top-p Complete:  India
RE Scores Computed:  India
TTE Scores Computed:  India
Grammar Scores Computed:  India
Processing Country:  Brazil
Waves Detected:  Brazil


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Trends Detected:  Brazil
Simple Generation Complete:  Brazil
Top-k Complete:  Brazil
Top-p Complete:  Brazil
RE Scores Computed:  Brazil
TTE Scores Computed:  Brazil
Grammar Scores Computed:  Brazil
Processing Country:  USSR
Waves Detected:  USSR


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Trends Detected:  USSR
Simple Generation Complete:  USSR
Top-k Complete:  USSR
Top-p Complete:  USSR
RE Scores Computed:  USSR
TTE Scores Computed:  USSR
Grammar Scores Computed:  USSR
Processing Country:  United Kingdom
Waves Detected:  United Kingdom


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Trends Detected:  United Kingdom
Simple Generation Complete:  United Kingdom
Top-k Complete:  United Kingdom
Top-p Complete:  United Kingdom
RE Scores Computed:  United Kingdom
TTE Scores Computed:  United Kingdom
Grammar Scores Computed:  United Kingdom
Processing Country:  France
Waves Detected:  France


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Trends Detected:  France
Simple Generation Complete:  France
Top-k Complete:  France
Top-p Complete:  France
RE Scores Computed:  France
TTE Scores Computed:  France
Grammar Scores Computed:  France
Processing Country:  Spain
Waves Detected:  Spain


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Trends Detected:  Spain
Simple Generation Complete:  Spain
Top-k Complete:  Spain
Top-p Complete:  Spain
RE Scores Computed:  Spain
TTE Scores Computed:  Spain
Grammar Scores Computed:  Spain
Processing Country:  Italy
Waves Detected:  Italy


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Trends Detected:  Italy
Simple Generation Complete:  Italy
Top-k Complete:  Italy
Top-p Complete:  Italy
RE Scores Computed:  Italy
TTE Scores Computed:  Italy
Grammar Scores Computed:  Italy
Processing Country:  Turkey
Waves Detected:  Turkey


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Trends Detected:  Turkey
Simple Generation Complete:  Turkey
Top-k Complete:  Turkey
Top-p Complete:  Turkey
RE Scores Computed:  Turkey
TTE Scores Computed:  Turkey
Grammar Scores Computed:  Turkey
Processing Country:  Germany
Waves Detected:  Germany


  (p,residuals,rank,s) = np.linalg.lstsq(A,y)


Trends Detected:  Germany
Simple Generation Complete:  Germany
Top-k Complete:  Germany
Top-p Complete:  Germany
RE Scores Computed:  Germany
TTE Scores Computed:  Germany
Grammar Scores Computed:  Germany


In [8]:
#RE Scores
print("*** RE Scores ***")
print("template_re_scores: ", np.mean(template_re_scores))
print("t5_re_scores: ", np.mean(t5_re_scores))
print("t5_re_scores_topk: ", np.mean(t5_re_scores_topk))
print("t5_re_scores_topp: ", np.mean(t5_re_scores_topp))
print("bart_re_scores: ", np.mean(bart_re_scores))
print("bart_re_scores_topk: ", np.mean(bart_re_scores_topk))
print("bart_re_scores_topp: ", np.mean(bart_re_scores_topp))

print("\n")
print("*** Diversity Scores ***")
#Diveristy Scores
print("template_tte_scores: ", np.mean(template_tte_scores))
print("t5_tte_scores: ", np.mean(t5_tte_scores))
print("t5_tte_scores_topk: ", np.mean(t5_tte_scores_topk))
print("t5_tte_scores_topp: ", np.mean(t5_tte_scores_topp))
print("bart_tte_scores: ", np.mean(bart_tte_scores))
print("bart_tte_scores_topk: ", np.mean(bart_tte_scores_topk))
print("bart_tte_scores_topp: ", np.mean(bart_tte_scores_topp))

print("\n")
print("*** Grammar Scores ***")
#Grammar Scores
print("t5_g_scores: ", np.mean(t5_g_scores))
print("t5_g_scores_topk: ", np.mean(t5_g_scores_topk))
print("t5_g_scores_topp: ", np.mean(t5_g_scores_topp))
print("bart_g_scores: ", np.mean(bart_g_scores))
print("bart_g_scores_topk: ", np.mean(bart_g_scores_topk))
print("bart_g_scores_topp: ", np.mean(bart_g_scores_topp))

*** RE Scores ***
template_re_scores:  54.736000000000004
t5_re_scores:  67.53600000000002
t5_re_scores_topk:  67.57399999999998
t5_re_scores_topp:  66.272
bart_re_scores:  67.30699999999999
bart_re_scores_topk:  69.46700000000001
bart_re_scores_topp:  68.358


*** Diversity Scores ***
template_tte_scores:  0.4519841917389087
t5_tte_scores:  0.47294858147614677
t5_tte_scores_topk:  0.5071300125808091
t5_tte_scores_topp:  0.4870217016673051
bart_tte_scores:  0.46536980951459095
bart_tte_scores_topk:  0.46927789161856526
bart_tte_scores_topp:  0.47230417440905653


*** Grammar Scores ***
t5_g_scores:  0.9949909050606554
t5_g_scores_topk:  0.9816560771206655
t5_g_scores_topp:  0.9874332566242146
bart_g_scores:  0.9743891308887622
bart_g_scores_topk:  0.9660297270530052
bart_g_scores_topp:  0.9701897078169734
