# Corex
code: https://github.com/gregversteeg/corex_topic/blob/master/corextopic/example/corex_topic_example.ipynb

In [1]:

import numpy as np
import scipy.sparse as ss
import matplotlib.pyplot as plt
import pandas as pd
import re

import corextopic.corextopic as ct
import corextopic.vis_topic as vt # jupyter notebooks will complain matplotlib is being loaded twice

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import classification_report

from nltk.corpus import stopwords
from nltk.corpus import wordnet
from nltk.tokenize import word_tokenize

%matplotlib inline

## Import data

In [2]:
#import data
model_data = pd.read_csv('LDA_train.csv')
print('Tokenized Text DF Size:', len(model_data))

Tokenized Text DF Size: 28652


  model_data = pd.read_csv('LDA_train.csv')


In [3]:
#sample model data
model_data = model_data.sample(frac=0.25, random_state=42)

In [4]:
#preprocessing
def preprocess_text(text):

  #lowercase text
  text_preprocessed = text.lower()
  #remove punctuation
  text_preprocessed = re.sub(r'[^a-zA-Z ]+', '', text_preprocessed)
  #tokenize for stopword removal
  text_preprocessed = word_tokenize(text_preprocessed)
  #remove stopwords
  text_preprocessed = [word for word in text_preprocessed if word not in stopwords.words('english')]
  #join to make string again
  #text_preprocessed = (" ").join(text_preprocessed)

  return text_preprocessed

In [5]:
%%time
model_data['tokens'] = model_data['description'].apply(lambda x: preprocess_text(x))

CPU times: user 1min 12s, sys: 5.83 s, total: 1min 18s
Wall time: 1min 18s


In [6]:
model_data['liststring'] = [','.join(map(str, l)) for l in model_data['tokens']]

## Create Synsets

In [7]:
def create_synsets(event):
  
  synonym = [] 
    
  for synset in wordnet.synsets(event): 
      for i in synset.lemmas(): 
          synonym.append(i.name()) # add all the synonyms available 
    
  return synonym

In [8]:
life_events = ['university', 'relationships', 'break ups', 'divorce', 'wedding', 
               'death', 'family', 'friendship']

#create synsets for select events where decent synsets exist
relationship_list = create_synsets('go_steady') + ['relationship', 'kinship', 'romance', 'dating']
marriage_list = create_synsets('marriage')
wedding_list = create_synsets('wedding') + ['matrimony']

#replace underscore (_) with space
relationship_list = [i.replace("_", " ") for i in relationship_list]
marriage_list = [i.replace("_", " ") for i in marriage_list]
wedding_list = [i.replace("_", " ") for i in wedding_list]

#remove certain words
wedding_list.remove('tie')
wedding_list.remove('marriage')
relationship_list.remove('see')

synsets = [['college', 'university', 'campus', 'academia', 'professor', 'colleges', 'universities', 'professors'], 
           relationship_list, 
           ['breakup', 'break up', 'split up', 'broken up', 'dumped', 'breaks up', 'splits up', 'dumps', 'dump', 'breaks off', 'break off'], 
           ['divorce', 'divorced', 'divorces'], 
           wedding_list,  
           ['death', 'decease', 'deceased', 'dying'],
           ['family', 'mother', 'father', 'brother', 'sister', 'mom', 'dad'],
           ['friends', 'friend', 'friendship', 'friendships']]

# Create the pandas DataFrame with column name is provided explicitly
df_lib = pd.DataFrame(life_events, columns=['life_event'])
df_lib['synsets'] = synsets
 
# print dataframe.
df_lib

Unnamed: 0,life_event,synsets
0,university,"[college, university, campus, academia, profes..."
1,relationships,"[go steady, go out, date, relationship, kinshi..."
2,break ups,"[breakup, break up, split up, broken up, dumpe..."
3,divorce,"[divorce, divorced, divorces]"
4,wedding,"[wedding, wedding ceremony, nuptials, hymeneal..."
5,death,"[death, decease, deceased, dying]"
6,family,"[family, mother, father, brother, sister, mom,..."
7,friendship,"[friends, friend, friendship, friendships]"


In [9]:
topic_num_name = {"Topic 0":"university",
                  "Topic 1":"relationships",
                  "Topic 2":"breakups",
                  "Topic 3":"divorce",
                  "Topic 4":"wedding",
                  "Topic 5": "death",
                  "Topic 6": "family",
                  "Topic 7": "friends"}    

In [10]:
topic_list = ['universities', 'relationships', 'break ups', 'divorce', 'wedding', 'death', 'family', 'friendship']

In [11]:
model_data_test = pd.read_csv('LDA_test.csv')
model_data_test = model_data_test.replace(np.nan, False)

In [12]:
model_data_test['tokens'] = model_data_test['description'].apply(lambda x: preprocess_text(x))

In [13]:
#create test objects
#model_data_test = pd.read_csv('LDA_test.csv')

#preprocess test set
#model_data_test['tokens'] = model_data_test['description'].apply(lambda x: preprocess_text(x))

#create objects required for model testing

model_data_test['liststring'] = [','.join(map(str, l)) for l in model_data_test['tokens']]
corpus=model_data_test['liststring'].tolist()
vocab=list(set(word_tokenize(" ".join(model_data_test['liststring']))))
vectorizer = CountVectorizer(ngram_range=(1,1),vocabulary=vocab)
X = vectorizer.fit_transform(corpus)
word2id=vectorizer.vocabulary_

In [14]:
model_data_test_reduced = model_data_test[['university', 'relationships', 'break ups', 'divorce', 'weddings', 'death', 'family', 'friendship']]
model_data_test_reduced.columns = ['university', 'relationships', 'break ups', 'divorce', 'wedding', 'death', 'family', 'friendship']

In [15]:
model_data_test_reduced.university = model_data_test_reduced.university.astype(bool)
model_data_test_reduced.relationships = model_data_test_reduced.relationships.astype(bool)
model_data_test_reduced['break ups'] = model_data_test_reduced['break ups'].astype(bool)
model_data_test_reduced.divorce = model_data_test_reduced.divorce.astype(bool)
model_data_test_reduced.wedding = model_data_test_reduced.wedding.astype(bool)
model_data_test_reduced.death = model_data_test_reduced.death.astype(bool)
model_data_test_reduced.family = model_data_test_reduced.family.astype(bool)
model_data_test_reduced.friendship = model_data_test_reduced.friendship.astype(bool)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  model_data_test_reduced.university = model_data_test_reduced.university.astype(bool)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  model_data_test_reduced.relationships = model_data_test_reduced.relationships.astype(bool)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  model_data_test_reduced['brea

In [16]:
model_data_test_reduced.head()

Unnamed: 0,university,relationships,break ups,divorce,wedding,death,family,friendship
0,False,True,False,False,False,False,False,False
1,False,False,False,False,False,False,False,False
2,False,True,False,True,False,True,True,False
3,False,True,False,False,False,False,True,False
4,False,False,False,False,False,False,False,False


In [17]:
model_data_test_reduced.to_csv('corex_y_true.csv')

In [18]:
model_data_test_reduced

Unnamed: 0,university,relationships,break ups,divorce,wedding,death,family,friendship
0,False,True,False,False,False,False,False,False
1,False,False,False,False,False,False,False,False
2,False,True,False,True,False,True,True,False
3,False,True,False,False,False,False,True,False
4,False,False,False,False,False,False,False,False
...,...,...,...,...,...,...,...,...
995,False,False,False,False,False,False,False,True
996,False,False,False,False,False,False,False,False
997,False,False,False,False,False,False,True,False
998,False,False,False,False,False,False,False,False


In [19]:
#document_df_test_bool.dtypes

## Corex Topic

In [20]:
%%time

anchor_strengths = [2, 4, 6, 8, 10]

anchor_words = synsets

for anchor_strength in anchor_strengths:

    print('Anchor Strength:', str(anchor_strength))

    #define model
    model = ct.Corex(n_hidden=8, seed=2)

    #fit model
    print('fitting model')
    model.fit(X, words=vocab, anchors=anchor_words, anchor_strength=anchor_strength);
    print('model fitted')
    print('')

    #get topics
    print('getting topics')
    document_df_test_bool=model.labels
    document_df_test_probs = model.p_y_given_x
    topic_list = ['university', 'relationships', 'break ups', 'divorce', 'wedding', 'death', 'family', 'friendship']
    document_df_test_bool = pd.DataFrame(document_df_test_bool, columns = topic_list)
    document_df_test_probs = pd.DataFrame(document_df_test_probs, columns = topic_list)
    #print('got topics')

    #convert preds to 
    preds = document_df_test_bool[['university', 'relationships', 'break ups', 'divorce', 'wedding', 'death', 'family', 'friendship']]
    cols = ['university', 'relationships', 'break ups', 'divorce', 'wedding', 'death', 'family', 'friendship']
    
    #model_data_test_reduced = model_data_test[['university', 'relationships', 'break ups', 'divorce', 'weddings', 'death', 'family', 'friendship']]
    #model_data_test_reduced.columns = ['university', 'relationships', 'break ups', 'divorce', 'wedding', 'death', 'family', 'friendship']
    #model_data_test_reduced.university = model_data_test_reduced.university=='True'
    #model_data_test_reduced.relationships = model_data_test_reduced.relationships=='True'
    #model_data_test_reduced['break ups'] = model_data_test_reduced['break ups']=='True'
    #model_data_test_reduced.divorce = model_data_test_reduced.divorce=='True'
    #model_data_test_reduced.wedding = model_data_test_reduced.wedding=='True'
    #model_data_test_reduced.death = model_data_test_reduced.death=='True'
    #model_data_test_reduced.family = model_data_test_reduced.family=='True'
    #model_data_test_reduced.friendship = model_data_test_reduced.friendship=='True'
    

    #print('got dfs')
    
    #save both document_df_test and preds to csvs
    probs_csv_name = 'corex_probs_anchor_strength_' + str(anchor_strength) + '.csv'
    bool_csv_name = 'corex_bools_anchor_strength_' + str(anchor_strength) + '.csv'

    document_df_test_probs.to_csv(probs_csv_name)
    document_df_test_bool.to_csv(bool_csv_name)
    
    y_pred = np.array(preds.values.tolist())
    y_true = np.array(model_data_test_reduced.values.tolist())
    
    print(model_data_test_reduced.dtypes)
    print(document_df_test_bool.dtypes)
    

    print(classification_report(
        model_data_test_reduced,
        document_df_test_bool,
        #output_dict=True,
        target_names=['university', 'relationships', 'break ups', 'divorce', 'wedding', 'death', 'family', 'friendship']
    ))


Anchor Strength: 2
fitting model
model fitted

getting topics
university       bool
relationships    bool
break ups        bool
divorce          bool
wedding          bool
death            bool
family           bool
friendship       bool
dtype: object
university       bool
relationships    bool
break ups        bool
divorce          bool
wedding          bool
death            bool
family           bool
friendship       bool
dtype: object
               precision    recall  f1-score   support

   university       0.06      0.50      0.10        26
relationships       0.47      0.27      0.34       399
    break ups       0.00      0.25      0.01         4
      divorce       0.05      1.00      0.09        12
      wedding       0.00      0.11      0.01        19
        death       0.04      0.31      0.07        29
       family       0.37      0.50      0.42       187
   friendship       0.38      0.57      0.45       191

    micro avg       0.15      0.40      0.22       867
    ma

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


model fitted

getting topics
university       bool
relationships    bool
break ups        bool
divorce          bool
wedding          bool
death            bool
family           bool
friendship       bool
dtype: object
university       bool
relationships    bool
break ups        bool
divorce          bool
wedding          bool
death            bool
family           bool
friendship       bool
dtype: object
               precision    recall  f1-score   support

   university       0.06      0.58      0.11        26
relationships       0.50      0.34      0.40       399
    break ups       0.00      0.25      0.01         4
      divorce       0.05      1.00      0.10        12
      wedding       0.00      0.11      0.01        19
        death       0.04      0.38      0.08        29
       family       0.47      0.79      0.58       187
   friendship       0.54      0.85      0.66       191

    micro avg       0.20      0.56      0.29       867
    macro avg       0.21      0.54     

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


model fitted

getting topics
university       bool
relationships    bool
break ups        bool
divorce          bool
wedding          bool
death            bool
family           bool
friendship       bool
dtype: object
university       bool
relationships    bool
break ups        bool
divorce          bool
wedding          bool
death            bool
family           bool
friendship       bool
dtype: object
               precision    recall  f1-score   support

   university       0.09      0.85      0.16        26
relationships       0.53      0.36      0.43       399
    break ups       0.00      0.25      0.01         4
      divorce       0.05      1.00      0.09        12
      wedding       0.06      1.00      0.11        19
        death       0.04      0.38      0.08        29
       family       0.46      0.84      0.60       187
   friendship       0.59      0.91      0.71       191

    micro avg       0.24      0.62      0.35       867
    macro avg       0.23      0.70     

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


model fitted

getting topics
university       bool
relationships    bool
break ups        bool
divorce          bool
wedding          bool
death            bool
family           bool
friendship       bool
dtype: object
university       bool
relationships    bool
break ups        bool
divorce          bool
wedding          bool
death            bool
family           bool
friendship       bool
dtype: object
               precision    recall  f1-score   support

   university       0.09      0.88      0.17        26
relationships       0.54      0.37      0.44       399
    break ups       0.00      0.25      0.01         4
      divorce       0.05      1.00      0.09        12
      wedding       0.07      1.00      0.12        19
        death       0.04      0.38      0.08        29
       family       0.49      0.89      0.63       187
   friendship       0.64      0.92      0.75       191

    micro avg       0.26      0.64      0.36       867
    macro avg       0.24      0.71     

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


model fitted

getting topics
university       bool
relationships    bool
break ups        bool
divorce          bool
wedding          bool
death            bool
family           bool
friendship       bool
dtype: object
university       bool
relationships    bool
break ups        bool
divorce          bool
wedding          bool
death            bool
family           bool
friendship       bool
dtype: object
               precision    recall  f1-score   support

   university       0.10      0.88      0.18        26
relationships       0.55      0.33      0.41       399
    break ups       0.00      0.25      0.01         4
      divorce       0.05      1.00      0.10        12
      wedding       0.08      1.00      0.15        19
        death       0.04      0.34      0.08        29
       family       0.50      0.91      0.65       187
   friendship       0.66      0.92      0.77       191

    micro avg       0.27      0.62      0.38       867
    macro avg       0.25      0.70     

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
