In [1]:
import numpy as np 
import pandas as pd
import string

In [2]:
pd.set_option('display.max_colwidth', None)

## Training data processing

In [3]:
ner_dataset = pd.read_csv('/content/ner_dataset.csv', 
    encoding='latin1')

In [4]:
ner_dataset['Sentence #'] = ner_dataset['Sentence #'].str.replace('Sentence:', '')
ner_dataset = ner_dataset.fillna(method='ffill')

In [5]:
ner_dataset['Sentence #'] = ner_dataset['Sentence #'].astype(int)

#### Create the `sentences_df`

In [6]:
sentences_df = ner_dataset.groupby('Sentence #', as_index=False)['Word'].apply(lambda x: x.str.cat(sep=' '))
sentences_df = sentences_df.rename(columns={'Word': 'Sentences'})

Inspect Sentence 8411. It only contains the word "The". 

In [7]:
sentences_df.iloc[8411]

Sentence #    8412
Sentences      The
Name: 8411, dtype: object

In [8]:
sentences_df = sentences_df.drop(labels=[8411], axis=0)
sentences_df = sentences_df.reset_index()
sentences_df = sentences_df.drop(columns='index')

### Sentences processing for LDA 

In [9]:
import nltk
from nltk.corpus import stopwords 
from nltk.corpus import wordnet
from nltk.stem import WordNetLemmatizer

In [10]:
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('averaged_perceptron_tagger')
nltk.download('punkt')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [11]:
def _lemmatize_words(sentence):
    wordnet_map = {'N':wordnet.NOUN, 'V':wordnet.VERB, 'J':wordnet.ADJ, 'R':wordnet.ADV}
    pos_tagged_text = nltk.pos_tag(sentence.split())
    return ' '.join([WordNetLemmatizer().lemmatize(word, wordnet_map.get(pos[0], wordnet.NOUN))
                    for word, pos in pos_tagged_text])
    
def lda_sent_process(text):
    text = text.lower()  
    PUNCT_TO_REMOVE = string.punctuation
    text = text.translate(str.maketrans('', '', PUNCT_TO_REMOVE))
    STOPWORDS = set(stopwords.words('english'))
    text = ' '.join([word for word in text.split() if word not in STOPWORDS])
    text = _lemmatize_words(text)
    return text.split()

In [12]:
sentences_df['lda_sents'] = sentences_df['Sentences'].apply(lambda x: lda_sent_process(x))

In [13]:
sentences_df

Unnamed: 0,Sentence #,Sentences,lda_sents
0,1,Thousands of demonstrators have marched through London to protest the war in Iraq and demand the withdrawal of British troops from that country .,"[thousand, demonstrator, march, london, protest, war, iraq, demand, withdrawal, british, troop, country]"
1,2,"Families of soldiers killed in the conflict joined the protesters who carried banners with such slogans as "" Bush Number One Terrorist "" and "" Stop the Bombings . ""","[family, soldier, kill, conflict, join, protester, carry, banner, slogan, bush, number, one, terrorist, stop, bombing]"
2,3,They marched from the Houses of Parliament to a rally in Hyde Park .,"[march, house, parliament, rally, hyde, park]"
3,4,"Police put the number of marchers at 10,000 while organizers claimed it was 1,00,000 .","[police, put, number, marcher, 10000, organizer, claim, 100000]"
4,5,The protest comes on the eve of the annual conference of Britain 's ruling Labor Party in the southern English seaside resort of Brighton .,"[protest, come, eve, annual, conference, britain, rule, labor, party, southern, english, seaside, resort, brighton]"
...,...,...,...
47953,47955,Indian border security forces are accusing their Pakistani counterparts of lobbing at least four rockets into northern Punjab state .,"[indian, border, security, force, accuse, pakistani, counterpart, lob, least, four, rocket, northern, punjab, state]"
47954,47956,Indian officials said no one was injured in Saturday 's incident but that two of the rockets landed near a border security outpost .,"[indian, official, say, one, injure, saturday, incident, two, rocket, land, near, border, security, outpost]"
47955,47957,Two more landed in fields belonging to a nearby village .,"[two, land, field, belong, nearby, village]"
47956,47958,They say not all of the rockets exploded upon impact .,"[say, rocket, explode, upon, impact]"


## LDA Model (gensim)

In [14]:
from gensim.corpora.dictionary import Dictionary 
from gensim import models 
import re

#### Model training

In [15]:
dct = Dictionary(sentences_df['lda_sents'])

In [16]:
corpus = [dct.doc2bow(sentence) for sentence in sentences_df['lda_sents']]
lda = models.LdaModel(corpus, num_topics=15, random_state=36)



In [17]:
topics = lda.print_topics()

In [18]:
for topic in topics:
  key_indices = re.findall(r'"(.*?)"', topic[1])
  key_words = [dct[int(idx)] for idx in key_indices]
  print(f'Topic {topic[0]}: ', key_words)

Topic 0:  ['vote', 'bird', 'election', 'flu', 'say', 'last', 'month', 'first', 'official', 'week']
Topic 1:  ['minister', 'say', 'prime', 'mr', 'north', 'south', 'official', 'president', 'government', 'korea']
Topic 2:  ['foreign', 'east', 'beijing', 'britain', 'island', 'france', 'gas', 'german', 'middle', 'russian']
Topic 3:  ['kill', 'say', 'attack', 'u', 'official', 'military', 'two', 'least', 'bomb', 'force']
Topic 4:  ['say', 'iran', 'united', 'state', 'nuclear', 'program', 'nation', 'country', 'european', 'international']
Topic 5:  ['say', 'police', 'official', 'force', 'arrest', 'city', 'muslim', 'spokesman', 'security', 'news']
Topic 6:  ['party', 'president', 'new', 'election', 'opposition', 'mr', 'rule', 'political', 'leader', 'call']
Topic 7:  ['say', 'woman', 'result', 'show', 'get', 'saudi', 'heavy', 'make', 'citizen', 'explosive']
Topic 8:  ['say', 'people', 'israeli', 'death', 'official', 'kill', 'islamic', 'two', 'militant', 'report']
Topic 9:  ['world', 'year', 'high'

### Create a cleaner topics list dictionary

In [19]:
topics_dict = {}
for num, topics in enumerate(topics):
  key_indices = re.findall(r'"(.*?)"', topic[1])
  key_words = [dct[int(idx)] for idx in key_indices]
  topics_dict[num] = key_words

Manually remove redundant words, and save the new list to topics_dict. 

In [21]:
print(topics_dict)

{0: ['oil', 'company', 'series', 'market', 'demand', 'production', 'decline', 'power', 'voa', 'government'], 1: ['oil', 'company', 'series', 'market', 'demand', 'production', 'decline', 'power', 'voa', 'government'], 2: ['oil', 'company', 'series', 'market', 'demand', 'production', 'decline', 'power', 'voa', 'government'], 3: ['oil', 'company', 'series', 'market', 'demand', 'production', 'decline', 'power', 'voa', 'government'], 4: ['oil', 'company', 'series', 'market', 'demand', 'production', 'decline', 'power', 'voa', 'government'], 5: ['oil', 'company', 'series', 'market', 'demand', 'production', 'decline', 'power', 'voa', 'government'], 6: ['oil', 'company', 'series', 'market', 'demand', 'production', 'decline', 'power', 'voa', 'government'], 7: ['oil', 'company', 'series', 'market', 'demand', 'production', 'decline', 'power', 'voa', 'government'], 8: ['oil', 'company', 'series', 'market', 'demand', 'production', 'decline', 'power', 'voa', 'government'], 9: ['oil', 'company', 'seri

In [37]:
topics_dict = {
  0: ['vote', 'bird', 'election', 'flu'], 
  1: ['minister', 'prime', 'north', 'south', 'president', 'korea'],
  2: ['foreign', 'Beijing', 'Britain', 'France', 'gas', 'German', 'Middle', 'East', 'Russian'], 
  3: ['kill', 'attack', 'military', 'bomb', 'force'],
  4: ['Iran', 'United', 'State', 'nuclear', 'European', 'international'],
  5: [ 'police', 'force', 'city', 'Muslim', 'spokesman', 'security' ], 
  6: ['party', 'president', 'election', 'leader'], 
  7: ['woman', 'citizen', 'explosive'],
  8: ['Israeli', 'death', 'kill', 'Islamic', 'militant'],
  9: ['world', 'economic', 'economy', 'price', 'country'],
  10: ['Lebanon', 'blast', 'responsibility', 'explosion', 'group'],
  11: ['government', 'president', 'Israel', 'Palestinian','peace', 'leader'], 
  12: ['United', 'State', 'secretary', 'storm', 'hurricane', 'emergency'],
  13: ['president', 'charge', 'right', 'court', 'Iraq', 'house'],
  14: ['oil', 'company',  'market', 'demand', 'production', 'decline', 'power', 'government'],
  }

Topic 0:  ['vote', 'bird', 'election', 'flu']

Topic 1:  ['minister', 'prime', 'orth', 'south', 'president', 'korea']

Topic 2:  ['foreign', 'Beijing', 'Britain', 'France', 'gas', 'German', 'Middle', 'East', 'Russian']

Topic 3:  ['kill', 'attack',  'military', 'bomb', 'force']

Topic 4:  ['Iran', 'United', 'State', 'nuclear',  'European', 'international']

Topic 5:  [ 'police', 'force', 'city', 'muslim', 'spokesman', 'security' ]

Topic 6:  ['party', 'president', 'election', 'leader']

Topic 7:  ['woman', 'citizen', 'explosive']

Topic 8:  ['israeli', 'death', 'kill', 'islamic', 'militant']

Topic 9:  ['world', 'economic', 'economy', 'price', 'country']

Topic 10:  ['lebanon',  'blast', 'responsibility', 'explosion', 'group']

Topic 11:  ['government', 'president', 'israel', 'palestinian','peace', 'leader']

Topic 12:  ['state', 'united', 'secretary', 'storm', 'hurricane', 'emergency']

Topic 13:  ['president', 'charge', 'right', 'court', 'iraq', 'house']

Topic 14:  ['oil', 'company', 'series', 'market', 'demand', 'production', 'decline', 'power', 'voa', 'government']


#### Inference example 

In [30]:
def get_topics(new_text, lda_model, dct): 
  '''
  new_text: str
  lda_model: load from lda.pkl
  dct: load from dct.pkl
  '''
  new_text_doc = lda_sent_process(new_text)
  topics = lda_model[dct.doc2bow(new_text_doc)]
  for num, topic in enumerate(topics): 
    print(f'Topic {topic[0]}:  with probability {topic[1]}')
    print(topics_dict[num])


In [35]:
# copy from NYT
new_sents = 'As Midterms Near, Biden Faces a Nation as Polarized as Ever'

In [36]:
get_topics(new_sents, lda, dct)

Topic 3:  with probability 0.13353180885314941
['vote', 'bird', 'election', 'flu']
Topic 4:  with probability 0.09273714572191238
['minister', 'prime', 'orth', 'south', 'president', 'korea']
Topic 6:  with probability 0.50126713514328
['foreign', 'Beijing', 'Britain', 'France', 'gas', 'German', 'Middle', 'East', 'Russian']
Topic 13:  with probability 0.18079714477062225
['kill', 'attack', 'military', 'bomb', 'force']


### Pickle the LDA data 

In [38]:
import pickle

In [39]:
## pickle the dictionary 
with open('dct.pkl', 'wb') as pickle_dict: 
  pickle.dump(dct, pickle_dict)

In [40]:
## pickle the LDA model 
with open('lda.pkl', 'wb') as pickle_lda:
  pickle.dump(lda, pickle_lda)

## KNN with SentenceTransformer

### Train Sentence transformer


In [41]:
!pip install -U sentence-transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentence-transformers
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
[K     |████████████████████████████████| 85 kB 5.0 MB/s 
[?25hCollecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
[K     |████████████████████████████████| 5.5 MB 61.5 MB/s 
Collecting sentencepiece
  Downloading sentencepiece-0.1.97-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[K     |████████████████████████████████| 1.3 MB 60.9 MB/s 
[?25hCollecting huggingface-hub>=0.4.0
  Downloading huggingface_hub-0.10.1-py3-none-any.whl (163 kB)
[K     |████████████████████████████████| 163 kB 72.7 MB/s 
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 64.1 MB/s 
Building wheels for collected 

In [42]:
from sentence_transformers import SentenceTransformer

In [43]:
model = SentenceTransformer('all-MiniLM-L6-v2')

Downloading:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/612 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/39.3k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/350 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/13.2k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/349 [00:00<?, ?B/s]

In [44]:
embeddings = model.encode(sentences_df['Sentences'])

### KNN with sentence embeddings 

In [45]:
from sklearn.neighbors import NearestNeighbors 
nbrs = NearestNeighbors(n_neighbors=3, algorithm='ball_tree').fit(embeddings)

### Inference example

In [52]:
def get_near_sent(text, emb_model, knn_model):
  embedding = emb_model.encode([text])
  _, index = knn_model.kneighbors(embedding)
  for idx in range(index.shape[1]):
    print(f'{idx+1}.',  sentences_df['Sentences'][index[0,idx]])


Test on a sentence in the dataset. 

In [51]:
get_near_sent(sentences_df['Sentences'][0], emb_model=model, knn_model=nbrs)

0. Thousands of demonstrators have marched through London to protest the war in Iraq and demand the withdrawal of British troops from that country .
1. Thousands of anti-war protesters have marched in London joining protests in Japan , Australia and elsewhere in the world ahead of the third anniversary of the U.S.-led invasion of Iraq .
2. Thousands of people in cities across Britain have rallied to protest Israeli military action in Lebanon .


Test on a new sentences: 

In [53]:
new_sents = 'As Midterms Near, Biden Faces a Nation as Polarized as Ever'

In [54]:
get_near_sent(new_sents, emb_model=model, knn_model=nbrs)

1. Biden is currently the chairman of the Senate Foreign Relations Committee and is a prominent critic of President Bush 's Iraq war strategy .
2. In a statement Friday , Mr. Biden 's office said the vice president will meet with the political leadership in all three countries , as well as U.S. officials and military personnel stationed in the region .
3. But its entry bid has roused opposition , most notably from France .


In [55]:
new_sentences = 'North Korea says launches were simulated attack, as South recovers missile parts'

In [56]:
get_near_sent(new_sentences, emb_model=model, knn_model=nbrs)

1. The U.S. and its allies in Asia have said the recent rocket launch was a test of a ballistic missile , but North Korea denies the claim , saying it sent a satellite into space .
2. Separately , South Korea 's foreign ministry said Friday that Afghan security personnel staged a rocket attack in June on the construction site of South Korea 's civilian base in northern Parwan province .
3. The North 's navy fired three ship-to-ship missiles on March 28th in what was then described by the South Korean government as part of a regular military exercise in the waters off the peninsula 's west coast .


In [57]:
get_near_sent('Woman shot and killed near 49th and Miami, Omaha police investigating', model, nbrs)

1. Several people were injured , and one woman died of a gunshot wound .
2. In other violence Tuesday , in neighboring Nimroz province , unidentified gunmen shot dead a district intelligence officer as he was driving in his car .
3. Unidentified gunmen on Mexico 's Gulf coast have shot to death the news director of one of the most influential newspapers in Veracruz , the second shooting attack on Mexican journalists in one week .


### Pickle embedding model and KNN model

In [None]:
with open('emb_model.pkl', 'wb') as pickle_emb:
  pickle.dump(model, pickle_emb)

with open('knn_modle.pkl', 'wb') as pickle_knn:
  pickle.dump(nbrs, pickle_knn)

## KMeans with SentenceTransformers (not used)

In [None]:
#from sklearn.pipeline import Pipeline 
#from sklearn.cluster import KMeans 

#from sklearn.metrics import silhouette_score

In [None]:
#for n_cluster in range(2,50):
#    kmeans_model = KMeans(n_clusters=n_cluster, random_state=1).fit(embeddings)
#    labels = kmeans_model.labels_
#    print(f'n_clusters = {n_cluster}: Silhouette Coefficient: {silhouette_score(embeddings, labels)}')

Execution restuls: 

n_clusters = 2: Silhouette Coefficient: 0.027680527418851852
n_clusters = 3: Silhouette Coefficient: 0.026121865957975388
n_clusters = 4: Silhouette Coefficient: 0.02658974751830101
n_clusters = 5: Silhouette Coefficient: 0.025382913649082184
n_clusters = 6: Silhouette Coefficient: 0.024761341512203217
n_clusters = 7: Silhouette Coefficient: 0.019972721114754677
n_clusters = 8: Silhouette Coefficient: 0.02104909159243107
n_clusters = 9: Silhouette Coefficient: 0.022642113268375397
n_clusters = 10: Silhouette Coefficient: 0.023155178874731064
n_clusters = 11: Silhouette Coefficient: 0.022205878049135208
n_clusters = 12: Silhouette Coefficient: 0.023076286539435387
n_clusters = 13: Silhouette Coefficient: 0.02041112817823887
n_clusters = 14: Silhouette Coefficient: 0.02360903099179268
n_clusters = 15: Silhouette Coefficient: 0.023625940084457397
n_clusters = 16: Silhouette Coefficient: 0.024519361555576324
n_clusters = 17: Silhouette Coefficient: 0.023255495354533195
n_clusters = 18: Silhouette Coefficient: 0.025912733748555183
n_clusters = 19: Silhouette Coefficient: 0.024510184302926064
n_clusters = 20: Silhouette Coefficient: 0.026332970708608627
n_clusters = 21: Silhouette Coefficient: 0.026626840233802795
n_clusters = 22: Silhouette Coefficient: 0.02656748704612255
n_clusters = 23: Silhouette Coefficient: 0.025318237021565437
n_clusters = 24: Silhouette Coefficient: 0.02624185010790825
n_clusters = 25: Silhouette Coefficient: 0.026188340038061142
n_clusters = 26: Silhouette Coefficient: 0.02609947882592678
n_clusters = 27: Silhouette Coefficient: 0.02687591314315796
n_clusters = 28: Silhouette Coefficient: 0.026777690276503563
n_clusters = 29: Silhouette Coefficient: 0.027936046943068504
n_clusters = 30: Silhouette Coefficient: 0.02727353386580944
n_clusters = 31: Silhouette Coefficient: 0.028382878750562668
n_clusters = 32: Silhouette Coefficient: 0.02829943224787712
n_clusters = 33: Silhouette Coefficient: 0.028072169050574303
n_clusters = 34: Silhouette Coefficient: 0.02787160314619541
n_clusters = 35: Silhouette Coefficient: 0.02733534574508667