In [67]:
import pandas as pd
import numpy as np
import scipy as sp

from nltk.corpus import stopwords
import nltk

from sklearn.decomposition import NMF
from sklearn.preprocessing import normalize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import pairwise_distances

import re

from tqdm import tqdm
tqdm.pandas()

In [2]:
SRC_DATA = '../../data/data_clean/polisumm_final.csv'

In [3]:
data = pd.read_csv(SRC_DATA)

In [20]:
url_regex = r'https?:\/\/\S*'
punct_regex = '\s*[\.\?\!\|\n\:]\s*[\.\?\!\|\n\:]*\s*'
sw = stopwords.words('english')

# Functions 

In [95]:
def process_texts(cat_text, min_toks = 3):
    new_text = re.sub(url_regex, '', cat_text)
    new_text = ' '.join([term for term in new_text.split() if term not in sw])
    text_l   = re.split(punct_regex, new_text)
    text_l   = [sent for sent in text_l if len(sent.split()) > min_toks]
    
    return text_l

def filter_min_toks(cat_text, min_toks = 3):
    new_text = re.sub(url_regex, '', cat_text)
    text_l    = re.split(punct_regex, new_text)
    
    new_text = ' '.join([term for term in new_text.split() if term not in sw])
    ptext_l   = re.split(punct_regex, new_text)
    text_l   = [text_l[i] for i, sent in enumerate(ptext_l) if len(sent.split()) > min_toks]
    
    return text_l

# Data Processing 

In [96]:
data['text_proc'] = data['all_texts'].progress_apply(process_texts)
data['filt_texts'] = data['all_texts'].progress_apply(filter_min_toks)

100%|██████████████████████████████████████████████████████████████████████████████| 1199/1199 [00:18<00:00, 64.38it/s]


In [81]:
vect = TfidfVectorizer(min_df = 3, max_df = 0.8)
vect = vect.fit(data['text_proc'].fillna(''))

# Model Definition

In [98]:
class EGCOS():
    
    def __init__(self, vectorizer):
        self.PUNCT_REGEX = '\s*[\.\?\!\|\n\:]\s*[\.\?\!\|\n\:]*\s*'

        self.vectorizer = vectorizer
        
        
    def predict(self, text, src_text = None):
        text_l = re.split(self.PUNCT_REGEX, text) if isinstance(text, str) else text
        if src_text:
            src_text_l = re.split(self.PUNCT_REGEX, src_text) if isinstance(src_text, str) else src_text
        
        text_bows = self.vectorizer.transform(text_l)
        text_bows = normalize(text_bows, norm = 'l1', axis = 1)
                
        split_preds = self.get_factorizer(text_bows)
        classes = split_preds.argmax(axis = -1).astype(bool)
        
        half1 = text_bows[classes]
        half2 = text_bows[~classes]
        
        pair_dists = pairwise_distances(half1, half2, metric = 'cosine')
        h1_idx, h2_idx = np.unravel_index(pair_dists.argmax(), pair_dists.shape)
        
        if src_text:
            pred_h1, pred_h2 = np.array(src_text_l)[classes][h1_idx], np.array(src_text_l)[~classes][h2_idx]
        else:
            pred_h1, pred_h2 = np.array(text_l)[classes][h1_idx], np.array(text_l)[~classes][h2_idx]
            
        return pred_h1, pred_h2

        
    def combine_topic_probs(self, probs1, probs2):
        combos = itertools.product(probs1, probs2)
        combos = [a + b for (a, b) in combos]
        combos = np.array(combos).reshape(len(probs1), len(probs2))
        
        return combos
        
        
    def get_factorizer(self, text_bows):
        
        factorizer = NMF(n_components = 2, init = 'nndsvd')
        return factorizer.fit_transform(text_bows)
        
        
        
        

# Evaluation 

In [99]:
egcos = EGCOS(vect)

In [109]:
def make_prediction(model, row):
    texts = row['text_proc']
    
    try:
        summ_1, summ_2 = model.predict(texts)
    except Exception as e:
        print(f'Error: {e}')
        return ' | '
    return summ_1 + ' | ' + summ_2

In [None]:
predictions = data.progress_apply(lambda row: make_prediction(egcos, row), axis = 1)

  9%|██████▋                                                                        | 102/1199 [00:31<05:15,  3.47it/s]

Error: init = 'nndsvd' can only be used when n_components <= min(n_samples, n_features)


 23%|█████████████████▉                                                             | 272/1199 [01:24<04:46,  3.24it/s]