# Latent Dirichlet Allocation

## Setting up

In [25]:
import re
import nltk
from nltk.corpus import stopwords
#from sklearn.feature_extraction.text import CountVectorizer
from itertools import chain
import gensim
from gensim import corpora
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [7]:
nltk.download()
#need to download 'stopwords' and 'wordnet'

showing info https://raw.githubusercontent.com/nltk/nltk_data/gh-pages/index.xml


True

In [28]:
#read in the FTA
fta = open("UK-AU-FTA.txt").read()
#split into chapters and then down to articles, remove chapter headings
chapters = re.split('\s*CHAPTER\s+[0-9]+\s*', fta)[1:]
articles_and_chpt_headers = [re.split('\n\s*Article\s[0-9]+.[0-9]+\s*\n', chpt) for chpt in chapters]
articles = [i[1:] for i in articles_and_chpt_headers]
#unnest to get a single list to use in LDA
articles_text = list(chain(*articles))

## Cleaning the text

In [31]:
from nltk.stem.wordnet import WordNetLemmatizer
import string

#bring in a standard set of stopwords and add those that have been found iteratively
stop = set(stopwords.words('english'))
contextual_stopwords = {'good', 'party', 'shall', 'article', 'agreement', 'may',
                        'agreements',  'measure', 'origin', 'custom', 'originating',
                        'paragraph', 'territory', 'provide', 'material', 'value', 
                        'producer', 'non', 'chapter', 'producer', 'information', 
                        'production', 'trade', 'including', 'mean', 'greater', 'apply',
                        'set', 'annex', 'to', 'relating', 'use', 'within', 'ensure',
                        'relevant', 'accordance', 'include', 'provision', 'covered',
                        'recognise', 'certainty', 'condition', 'otherwise', 'matter'}
#define any compound tokens that have been found by analysis
compounds = {'established financial service supplier': 'established-financial-service-supplier',
             'financial service supplier': 'financial-service-supplier',
             'financial service': 'financial-service',
             'cross border service': 'cross-border-service',
             'cross border': 'cross-border',
             'intellectual property right': 'intellectual-property-right',
             'united kingdom': 'united-kingdom',
             'contact point': 'contact-point',
             'non conforming': 'non-conforming',
             'state owned enterprise': 'state-owned-enterprise',
             'working group': 'working-group',
             'joint committee': 'joint-committee',
             'procuring entity': 'procuring-entity',
             'dispute settlement': 'dispute-settlement',
             'telecommunication service': 'telecommunication-service',
             'public telecommunication network service': 'public-telecommunication-network-service',
             'judicial authority': 'judicial-authority',
             'regulatory authority': 'regulatory-authority',
             'favoured nation treatment': 'favoured-nation-treatment',
             'reasonable period time': 'reasonable-period-of-time',
             'regulatory cooperation': 'regulatory-cooperation',
             'time period': 'time-period',
             'without prejudice': 'without-prejudice',
             'bilateral safeguard': 'bilateral-safeguard',
             'competent authority': 'competent-authority',
             'date entry force': 'date-of-entry-into-force',
             'cosmetic product': 'cosmetic-product',
             '30 day': '30-day'}

#set up cleaning function            
exclude = set(string.punctuation+'“”–’')
other_exclusions = set('1234567890abcdefghijk')
lemma = WordNetLemmatizer()
def multiple_replace(string, rep_dict):
    pattern = re.compile("|".join([re.escape(k) for k in sorted(rep_dict,key=len,reverse=True)]), flags=re.DOTALL)
    return pattern.sub(lambda x: rep_dict[x.group(0)], string)
def clean(doc):
    stop_free = " ".join([i for i in doc.lower().split() if i not in stop])
    punc_free = ''.join(ch for ch in stop_free if ch not in exclude)
    normalized = " ".join(lemma.lemmatize(word) for word in punc_free.split())
    context_cleaned = " ".join([i for i in normalized.split() if i not in contextual_stopwords])
    numbers_removed = " ".join([i for i in context_cleaned.split() if i not in other_exclusions])
    roman_numerals_removed = " ".join([re.sub("^M{0,4}(cm|cd|d?c{0,3})(xc|xl|l?x{0,3})(ix|iv|v?i{0,3})$", "", i) 
                                      for i in numbers_removed.lower().split()])
    #can join bigrams with hyphen now
    #bigram1 = roman_numerals_removed.replace('financial service', 'financial-service')
    #but want to do systematically not line by line
    #see example soln in next box
    compounds_joined = multiple_replace(roman_numerals_removed, compounds)
    return compounds_joined

#clean the text
doc_clean = [clean(doc).split() for doc in articles_text]
tokens = list(chain(*doc_clean))

In [32]:
#identify the most common tokens which may be stop words in this context
from collections import Counter
c=Counter(tokens)
print(c.most_common())

[('service', 506), ('law', 389), ('investment', 344), ('procedure', 310), ('supplier', 307), ('public', 248), ('request', 247), ('person', 243), ('international', 234), ('right', 232), ('provided', 227), ('regulation', 219), ('application', 209), ('requirement', 204), ('purpose', 200), ('respect', 199), ('commercial', 198), ('authority', 196), ('treatment', 186), ('appropriate', 185), ('electronic', 185), ('financial-service', 182), ('protection', 175), ('cooperation', 175), ('activity', 173), ('procurement', 170), ('government', 165), ('panel', 165), ('regulatory', 160), ('obligation', 159), ('consultation', 151), ('enterprise', 149), ('related', 149), ('available', 148), ('system', 142), ('market', 139), ('procuring-entity', 137), ('make', 134), ('national', 133), ('subject', 133), ('access', 130), ('supply', 126), ('referred', 124), ('made', 121), ('general', 121), ('technical', 120), ('decision', 120), ('standard', 120), ('product', 117), ('level', 115), ('management', 115), ('purs

In [33]:
#identify common bi-grams which may be compound tokens
bigrams = zip(tokens, tokens[1:])
counts = Counter(bigrams)
print(counts.most_common())

[(('law', 'regulation'), 127), (('commercial', 'assistance'), 62), (('service', 'supplier'), 54), (('national', 'treatment'), 53), (('level', 'government'), 49), (('supply', 'service'), 49), (('adopt', 'maintain'), 46), (('geographical', 'indication'), 45), (('rule', 'procedure'), 38), (('right', 'holder'), 34), (('le', 'favourable'), 33), (('conformity', 'assessment'), 33), (('cooperation', 'activity'), 32), (('publicly', 'available'), 31), (('referred', 'subparagraph'), 31), (('bribery', 'corruption'), 31), (('service', 'investment'), 30), (('30', 'dispute-settlement'), 29), (('treatment', 'le'), 29), (('preferential', 'tariff'), 29), (('most', 'favoured-nation-treatment'), 29), (('major', 'supplier'), 29), (('notice', 'intended'), 29), (('intended', 'procurement'), 29), (('designated', 'monopoly'), 29), (('right', 'obligation'), 28), (('tariff', 'treatment'), 28), (('assessment', 'procedure'), 28), (('tender', 'documentation'), 28), (('national', 'competition'), 28), (('take', 'acco

In [34]:
#print an example of cleaned text
print(doc_clean[:2])

[['establishment', 'free', 'area', 'consistent', 'gatt', '1994', 'gat', 'hereby', 'establish', 'free', 'area'], ['relation', 'affirm', 'existing', 'right', 'obligation', 'respect', 'existing', 'international', 'wto', 'considers', 'inconsistent', 'another', 'request', 'consult', 'view', 'reaching', 'mutually', 'satisfactory', 'solution', 'without-prejudice', 'right', 'obligation', '30', 'dispute-settlement', 'long', 'protocol', 'irelandnorthern', 'ireland', 'withdrawal', 'united-kingdom', 'great', 'britain', 'northern', 'ireland', 'european', 'union', 'european', 'atomic', 'energy', 'community', 'signed', 'london', 'brussels', '24', 'january', '2020', 'the', 'protocol', 'force', 'nothing', 'preclude', 'united-kingdom', 'adopting', 'maintaining', 'refraining', 'so', 'protocol', 'amendment', 'thereto', 'subsequent', 'replacing', 'part', 'thereof', 'provided', 'absence', 'used', 'arbitrary', 'unjustified', 'discrimination', 'disguised', 'restriction', 'purpose', 'application', 'agree', 'fa

## Creating an example LDA Model

In [35]:
# create the dictionary, assign each term an index 
dictionary = corpora.Dictionary(doc_clean)
# convert text into a 'document term matric'
doc_term_matrix = [dictionary.doc2bow(doc) for doc in doc_clean]
# setting up the LDA model from gensim
Lda = gensim.models.ldamodel.LdaModel

In [36]:
# running and training an example LDA model
ldamodel = Lda(doc_term_matrix, num_topics=10, id2word = dictionary, passes=50)

In [37]:
print(ldamodel.print_topics(num_topics=5, num_words=5))

[(3, '0.020*"public" + 0.016*"supplier" + 0.011*"service" + 0.009*"person" + 0.008*"law"'), (2, '0.026*"investment" + 0.026*"service" + 0.014*"treatment" + 0.012*"enterprise" + 0.012*"commercial"'), (1, '0.023*"law" + 0.022*"dispute-settlement" + 0.017*"corruption" + 0.017*"labour" + 0.013*"bribery"'), (7, '0.020*"cooperation" + 0.012*"committee" + 0.012*"international" + 0.011*"appropriate" + 0.010*"development"'), (8, '0.011*"duty" + 0.011*"patent" + 0.011*"tariff" + 0.010*"provided" + 0.010*"declaration"')]


## Optimising the number of topics

In [38]:
#calculating and displaying the coherence score for the example above
from gensim.models import CoherenceModel
coherence_model_lda = CoherenceModel(
   model=ldamodel, texts=doc_clean, dictionary=dictionary, coherence='c_v')
coherence_lda = coherence_model_lda.get_coherence()

print('Coherence Score: ', coherence_lda)

Coherence Score:  0.430048032507879


In [None]:
#first fitting a number of different models and calculating their coherance scores
topics = range(1, 200, 10)
models = [Lda(doc_term_matrix, num_topics=i, id2word = dictionary, passes=50) for i in topics]
coherences__c_v = [CoherenceModel(model=i, texts=doc_clean, dictionary=dictionary, coherence='c_v').get_coherence() for i in models]
coherence__u_mass = [CoherenceModel(model=i, texts=doc_clean, dictionary=dictionary, coherence='u_mass').get_coherence() for i in models]

In [None]:
plt.plot(topics, coherences__c_v)
plt.plot(topics, coherence__u_mass)
plt.title('Coherence for Different Number of Topics')
plt.xlabel('Number of Topics')
plt.locator_params(axis="x", integer=True, tight=True)
plt.ylabel('Coherance Score')
plt.legend(['c_v Coherence','u_mass Coherence'])
plt.show()

In [None]:
#fitting HDP instead
from gensim.models import HdpModel
hdpmodel = gensim.models.hdpmodel.HdpModel(corpus = doc_term_matrix, id2word = dictionary, T = 150)
#print 5 most significant topics
print(hdpmodel.print_topics(num_topics=5))

In [None]:
print(hdpmodel.suggested_lda_model())
#note the number of topics is the same as parameter 'T' (top level truncation)
#I think because of a lack of distinct topics in the data

In [None]:
#trying an alternate package
import tomotopy as tp
term_weight = tp.TermWeight.ONE
hdp = tp.HDPModel(tw=term_weight, seed=99999)
for vec in doc_clean:
    hdp.add_doc(vec)
hdp.burn_in = 100
hdp.train(0)
for i in range(0, 1000, 100):
    hdp.train(100)
    print('Topics: {}'.format(hdp.live_k))

In [None]:
#retrieving topics for this package is challenging, function below provided by Eduardo Sroka
def get_hdp_topics(hdp, top_n=10):
    '''Wrapper function to extract topics from trained tomotopy HDP model 
    
    ** Inputs **
    hdp:obj -> HDPModel trained model
    top_n: int -> top n words in topic based on frequencies
    
    ** Returns **
    topics: dict -> per topic, an arrays with top words and associated frequencies 
    '''
    
    # Get most important topics by # of times they were assigned (i.e. counts)
    sorted_topics = [k for k, v in sorted(enumerate(hdp.get_count_by_topics()), key=lambda x:x[1], reverse=True)]

    topics=dict()
    
    # For topics found, extract only those that are still assigned
    for k in sorted_topics:
        if not hdp.is_live_topic(k): continue # remove un-assigned topics at the end (i.e. not alive)
        topic_wp =[]
        for word, prob in hdp.get_topic_words(k, top_n=top_n):
            topic_wp.append((word, prob))

        topics[k] = topic_wp # store topic word/frequency array
        
    return topics

topics = get_hdp_topics(hdp, top_n=3)
print(topics)

## Creating the final LDA Model and interpreting topics

In [None]:
# Running and Training LDA model for an example number of topics
#set seed to create reproducable results
np.random.seed(42)
ldamodel_final = Lda(doc_term_matrix, num_topics=20, id2word = dictionary, passes=50)
print(ldamodel_final.print_topics(num_words=5))

In [None]:
#create a df of articles and their topics

#start with article titles
articles_and_titles = [re.split('\n', article, maxsplit = 1) for article in articles_text]
article_titles = [a[0] for a in articles_and_titles]
articles_df = pd.DataFrame(article_titles, columns = ['Article'])
#add column for article text
#article_text = [a[1] for a in articles_and_titles]
articles_df['Article Text'] = [a[1] for a in articles_and_titles]
#add column for chapter titles
chpt_titles = [c[0] for c in articles_and_chpt_headers]
chpts = []
for i in range(len(chpt_titles)):
  chpts.append(len(articles[i]) * [chpt_titles[i]])
chpts = list(chain(*chpts))
articles_df['Chapter'] = chpts
#add topics to articles_df
topics = []
for i in range(len(doc_term_matrix)):
  topics.append(ldamodel_final.get_document_topics(doc_term_matrix[i]))
articles_df['LDA Topics'] = topics
#hot encode this, firstly add columns for each topic with zeros
for topic in range(20):
  articles_df['prob_t' + str(topic)] = [0] * len(articles_df)
#then loop through rows and topics to add values
for i in range(len(articles_df['LDA Topics'])):
  row = articles_df['LDA Topics'][i]
  for topic in row:
    articles_df['prob_t' + str(topic[0])][i] = topic[1]

#show first few rows
articles_df.head()

In [None]:
#look at each topic to get some insight into its composition and find a title to give it
articles_df.nlargest(3, ['prob_t5'])

In [None]:
#assigning topics their titles and storing together
topic_nums = range(20)
topic_titles = [
    'Corruption and Transparency',
    'Financial Services',
    'Origin of Goods',
    'Maritime Industries',
    'Contact and Cooperation',
    'Trade Safeguards and Remedies',
    'Digital and Environmental Innovation',
    'Dispute Resolution',
    'Workers Rights',
    'Intellectual Property - Protections',
    'State Owned Enterprises',
    'Natural Environment',
    'Shipments and Customers',
    'Consumer Protection',
    'Regulatory Authorities',
    'Digital Trade Facilitation',
    'Information Ownership',
    'Geographical Protections',
    'Disputes in Government Procurement',
    'Supply of Goods'
]
topic_titles = dict(zip(topic_nums, topic_titles))

In [None]:
#visualising this
from wordcloud import WordCloud, STOPWORDS
from textwrap import wrap

cols = ['#e6194b', '#3cb44b', '#dbc114', '#4363d8', '#f58231', '#911eb4', '#3ac9c9', 
        '#f032e6', '#719406', '#9e7886', '#008080', '#816a8f', '#9a6324', '#c2bd8f', 
        '#800000', '#6da37d', '#808000', '#d1b190', '#000075', '#808080']
cloud = WordCloud(stopwords=stop,
                  background_color='white',
                  width=2500,
                  height=1800,
                  max_words=10,
                  colormap='tab10',
                  color_func=lambda *args, **kwargs: cols[i],
                  prefer_horizontal=1.0)
topics = sorted(ldamodel_final.show_topics(num_topics = 20, formatted=False))
fig, axes = plt.subplots(4, 5, figsize=(10,10), sharex=True, sharey=True)
for i, ax in enumerate(axes.flatten()):
    fig.add_subplot(ax)
    topic_words = dict(topics[i][1])
    cloud.generate_from_frequencies(topic_words, max_font_size=300)
    plt.gca().imshow(cloud)
    plt.gca().set_title('Topic ' + str(i), fontdict=dict(size=16))
    #plt.gca().axis('off')
    plt.gca().get_yaxis().set_visible(False)
    #plt.xlabel(topic_titles[i])
    plt.xlabel('\n'.join(wrap(topic_titles[i], 20)))
    ax.spines['bottom'].set_color('white')
    ax.spines['top'].set_color('white')
    ax.spines['left'].set_color('white')
    ax.spines['right'].set_color('white')
    ax.tick_params(colors='white', which='both')
plt.subplots_adjust(wspace=1, hspace=0)
#plt.axis('off')
plt.margins(x=0, y=0)
plt.tight_layout()
fig.suptitle('Topics Identified in the UK-AUS FTA using LDA', fontsize=20)
plt.show()

## Classifying articles

In [None]:
#create example topic classification
t = 1
l = 0.9
example_classification = articles_df[articles_df[('prob_t' + str(t))] >= l]

#output final product that could be summarised
example_topic_text = ' '.join(example_classification['Article Text'])
print(example_topic_text[:250])

In [None]:
#check article lengths
lens = []
for i in range(20):
  t = i 
  classification = articles_df[articles_df[('prob_t' + str(t))] >= l]
  topic_text = ' '.join(classification['Article Text'])
  #count number of words (ie tokens)
  lens.append(len(topic_text.split()))
print("Longest topic has " + str(max(lens)) + " words")

In [None]:
#function for splitting into text of 1000 words
def summary_chunks(text):
    chunks = []
    words = text.split()
    for i in range(0, len(words), 700):
      chunks.append(" ".join(words[i:i+700]))
    return chunks

#check topics have been split up as expected
example_text_chunks = summary_chunks(example_topic_text)
print(len(example_text_chunks))

## Recursively summarise this text using BART

In [None]:
# import the model
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig
# load a pre-trained model and tokenizer 'bart-large-cnn'
tokeniser = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')

In [None]:
len(tokeniser.batch_encode_plus([example_text_chunks[1]],return_tensors='pt').get('input_ids')[0])

In [None]:
# encode each piece of text
inputs = [tokeniser.batch_encode_plus([text],return_tensors='pt') for text in example_text_chunks]
summary_ids = [model.generate(i['input_ids'], early_stopping=True) for i in inputs]
#for i in inputs:
#  print(i)
#  m = model.generate(i['input_ids'], early_stopping=True)

In [None]:
# Decoding each summary and check we have the right number of them
bart_summaries = [tokeniser.decode(ids[0], skip_special_tokens=True) for ids in summary_ids]
print(len(bart_summaries))

In [None]:
#join together each summary to a single piece
single_sum = ' '.join(bart_summaries)
input = tokeniser.batch_encode_plus([single_sum],return_tensors='pt')
summary_id = model.generate(input['input_ids'], early_stopping=True)
final_sum = tokeniser.decode(summary_id[0], skip_special_tokens=True)
print(final_sum)