# Contents
### 1) <a href='#section1'>Web Scraping and RSS Feed Parsing</a>
### 2) <a href='#section2'>Data Cleaning and Feature Extraction</a>
### 3) <a href='#section3'>Model Training</a>

In [30]:
from collections import Counter,defaultdict
from operator import itemgetter
import pickle
import re
import time,datetime

import feedparser
import selenium
from selenium import webdriver

import numpy as np
import pandas as pd

import gensim
from gensim.corpora import Dictionary
from gensim.models import LdaModel
from gensim.models.doc2vec import Doc2Vec

In [3]:
# pickle functions for easy saving/loading
def save_obj(obj, name):
    """Pickle ``obj`` to ``obj/<name>.pkl`` with the highest available protocol.

    The ``obj/`` directory must already exist; it is not created here.
    """
    path = 'obj/' + name + '.pkl'
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)

def load_obj(name):
    """Unpickle and return the object stored in ``obj/<name>.pkl``.

    Only use this on files the notebook wrote itself: unpickling can execute
    arbitrary code embedded in the file.
    """
    path = 'obj/' + name + '.pkl'
    with open(path, 'rb') as handle:
        loaded = pickle.load(handle)
    return loaded

<a id='section1'></a>

## Web Scraping and RSS Feed Parsing

In [None]:
# Load the collection of all known author handles (used as a set below: get_new_authors
# unions it with `|`).
# NOTE(review): this reads 'all_authors_store', while the scraper cells below save to
# 'all_authors_backup2' -- confirm which pickle is the canonical store.
all_authors = load_obj('all_authors_store')

print("Current number of distinct authors: " + str(len(all_authors)))

In [None]:
# initiate selenium driver: a Chrome WebDriver using the chromedriver binary at
# ./chromedriver/chromedriver (shared as the global `driver` by the scraping helpers).
# NOTE(review): passing the driver path positionally is the Selenium 3 style; newer
# Selenium 4 releases expect webdriver.Chrome(service=Service(path)) -- confirm version.
driver = webdriver.Chrome("chromedriver/chromedriver")

In [None]:
# method to get tag links from page
def get_tag_links(tag_links_set=None):
    """Collect the Medium tag-page links on the page currently loaded in ``driver``.

    Args:
        tag_links_set: optional set to add the links to; it is updated in place
            and returned. When omitted, a new empty set is used for this call.

    Returns:
        The set of hrefs containing ``https://medium.com/tag``.
    """
    from selenium.webdriver.common.by import By

    # A `set()` default would be created once and shared by every call, so links
    # from earlier pages would leak into later results.
    if tag_links_set is None:
        tag_links_set = set()
    for element in driver.find_elements(By.TAG_NAME, 'a'):
        tag_link = element.get_attribute('href')
        # anchors without an href attribute yield None
        if tag_link and "https://medium.com/tag" in tag_link:
            tag_links_set.add(tag_link)
    return tag_links_set

# method to get author links from page
def get_new_authors(all_authors):
    """Return ``all_authors`` plus the author handles linked from the current page.

    An anchor counts as an author link when its href contains ``@`` and no ``?``.
    The handle is the ``@`` plus everything up to the next ``/``
    (e.g. ``https://medium.com/@p.guilbaud/post`` -> ``@p.guilbaud``), the same
    format as before.

    Args:
        all_authors: set of already-known handles; it is not modified.

    Returns:
        A new set: ``all_authors`` unioned with the handles found.
    """
    from selenium.webdriver.common.by import By

    new_authors_set = set()
    for element in driver.find_elements(By.TAG_NAME, 'a'):
        href = element.get_attribute('href')
        # anchors without an href attribute yield None
        if not href or "@" not in href or "?" in href:
            continue
        match = re.search(r'@[^/]+', href)
        # e.g. ".../@/" contains '@' but no handle; skip instead of crashing on None
        if match:
            new_authors_set.add(match.group(0))

    return all_authors | new_authors_set

# method to scroll to bottom of page (so that all results appear)
def scroll_to_bottom(pause=3, max_scrolls=None):
    """Keep scrolling the current page until its height stops growing.

    After each scroll the function sleeps ``pause`` seconds so lazily loaded
    results can appear, then compares ``document.body.scrollHeight`` with the
    previous value.

    Args:
        pause: seconds to wait between scrolls (3, as before).
        max_scrolls: optional cap on follow-up scrolls. ``None`` (default) keeps
            the original "until the height is stable" behaviour, which never ends
            on a feed that keeps loading more results.

    Returns:
        The number of follow-up scrolls performed.
    """
    scroll_js = ("window.scrollTo(0, document.body.scrollHeight);"
                 "var lenOfPage=document.body.scrollHeight;return lenOfPage;")
    page_height = driver.execute_script(scroll_js)
    count = 0
    while max_scrolls is None or count < max_scrolls:
        count += 1
        previous_height = page_height
        time.sleep(pause)
        page_height = driver.execute_script(scroll_js)
        if page_height == previous_height:
            break
    return count

In [None]:
# Author scraper approach 1 - spider across related tags
# Takes a long time to run - see cell below for a more directed approach

# replace "technology" with desired start tag
driver.get('https://medium.com/tag/{}'.format('technology'))

# Tags linked from the start page, plus the tags linked from each of those pages.
# A fresh set is passed explicitly so nothing carries over from earlier calls.
tag_links = get_tag_links(tag_links_set=set())
new_tags = set()
for tag in tag_links:
    driver.get(tag)
    new_tags = get_tag_links(tag_links_set=new_tags)
tag_links = tag_links | new_tags

# Visit every tag page, load all lazy results, and harvest author handles.
for tag_link in tag_links:
    # scroll_to_bottom/get_new_authors act on the currently loaded page, so navigate first
    driver.get(tag_link)
    scroll_to_bottom()
    all_authors = get_new_authors(all_authors)

save_obj(all_authors,'all_authors_backup2')

print("New number of distinct authors: " + str(len(all_authors)))

In [None]:
# Author scraper approach 2 - Directly search for specific keywords
# Searching for a keyword yields more results for that keyword than the tag approach does
from urllib.parse import quote

#insert desired keywords
search_keywords = ["deep learning","artificial intelligence"]

for keyword in search_keywords:
    # Percent-encode the whole keyword, not just spaces: a raw '+', '&' or '#' would
    # otherwise be read as a space, a parameter separator or a URL fragment.
    driver.get("https://medium.com/search?q={}".format(quote(keyword, safe='')))
    scroll_to_bottom()
    all_authors = get_new_authors(all_authors)

save_obj(all_authors,'all_authors_backup2')

print("New number of distinct authors: " + str(len(all_authors)))

In [None]:
# close driver: ends the Chrome session started in the "initiate selenium driver" cell
driver.quit()

In [None]:
# Load the per-author record of already-scraped post titles (author -> set of
# titles), give every known author an entry (empty if new), and save it back.
old_titles = load_obj("old_titles")
for author in all_authors:
    old_titles.setdefault(author, set())
save_obj(old_titles,"old_titles")

In [None]:
# load previous medium csv (the accumulated scrape; get_posts below appends new posts to it)
medium_df = pd.read_csv("medium_posts.csv")
medium_df.head()

In [None]:
# Method to parse RSS feeds
def get_posts(df, sources, rss_link):
    """Append posts not seen before from each source's RSS feed to ``df``.

    Seen posts are tracked per source by title in the pickled ``old_titles``
    mapping (source -> set of titles). The mapping is loaded at the start and
    saved back once every source has been processed.

    Args:
        df: DataFrame of previously collected posts.
        sources: iterable of source names (author handles in this notebook); each
            is substituted into ``rss_link``.
        rss_link: feed URL template with one ``{}`` placeholder.

    Returns:
        ``df`` followed by one row per new post (index reset). The new rows have
        these columns:
        datetime (scrape time), url, id, title, text (entry summary), author,
        published, published_parsed, and keywords (tag terms joined by " / ").
        Fields an entry lacks are left as None.
    """
    old_titles = load_obj("old_titles")
    new_posts = []

    for source in sources:
        # An unknown source starts with an empty history instead of raising
        # KeyError and aborting the whole (multi-hour) run.
        seen_titles = old_titles.setdefault(source, set())
        feed = feedparser.parse(rss_link.format(source))
        for post in feed.get('entries', []):
            title = post.get('title')
            if title is None or title in seen_titles:
                continue
            seen_titles.add(title)
            # Build the complete record at once with .get(): a missing field used to
            # raise part-way through and leave a half-filled row (via a bare except).
            tags = post.get('tags') or []
            new_posts.append({
                'datetime': str(datetime.datetime.now()),
                'url': post.get('link'),
                'id': post.get('id'),
                'title': title,
                'text': post.get('summary'),
                'author': post.get('author'),
                'published': post.get('published'),
                'published_parsed': post.get('published_parsed'),
                'keywords': ' / '.join(tag['term'] for tag in tags if tag.get('term')),
            })

    save_obj(old_titles, "old_titles")

    return pd.concat([df, pd.DataFrame(new_posts)], ignore_index=True)

In [None]:
# run rss parser and save new csv (replaces previous one)
# takes a few hours to run
medium_rss_authors = 'https://medium.com/feed/@{}'
# NOTE(review): get_new_authors stores handles including their leading '@', which this
# template would turn into '.../feed/@@handle' -- confirm the handle format in
# all_authors and drop one of the '@'s if needed.
medium_df = get_posts(medium_df,all_authors,medium_rss_authors)
# Overwrite the csv in place. The file name has no placeholder; the former
# `.format(str(csv_count))` referenced an undefined name and raised NameError.
medium_df.to_csv('medium_posts.csv', header=True, index=False)

<a id='section2'></a>

## Data Cleaning and Feature Extraction

In [32]:
# read in csv containing all scraped data and metadata from posts
# NOTE(review): this reads 'medium_posts12.csv', while the scraping section writes
# 'medium_posts.csv' -- confirm which snapshot is intended.
medium_df = pd.read_csv('medium_posts12.csv')

In [37]:
# Number the rows 1..n in a new article_id column; later cells use it to link the
# per-category text/paragraph frames back to the original posts.
medium_df['article_id'] = np.arange(1, len(medium_df) + 1)

In [38]:
# Preview the first rows (output below)
medium_df.head()

Unnamed: 0,author,datetime,id,keywords,medium_url,published,published_parsed,text,title,url,post_id,article_id
0,Pierre GUILBAUD,2017-09-10 00:11:40.202998,https://medium.com/p/b1bdd7074234,startup / growth-hacking / email-marketing / e...,,"Mon, 14 Aug 2017 07:06:33 GMT","time.struct_time(tm_year=2017, tm_mon=8, tm_md...","<figure><img alt="""" src=""https://cdn-images-1....",Growth Hacking Workflow: 5 Steps to Go from Co...,https://medium.com/nookspot/growth-hacking-wor...,1,1
1,Pierre GUILBAUD,2017-09-10 00:11:40.203076,https://medium.com/p/a0cc432bf0ac,startup / ideas / entrepreneur / network / pitch,,"Wed, 05 Jul 2017 08:57:32 GMT","time.struct_time(tm_year=2017, tm_mon=7, tm_md...","<figure><img alt="""" src=""https://cdn-images-1....",Pitch your startup in 30 Seconds,https://medium.com/@p.guilbaud/pitch-your-star...,2,2
2,Pierre GUILBAUD,2017-09-10 00:11:40.203112,https://medium.com/p/57512d1cc8c,entrepreneurship / landing-pages / startup / b...,,"Tue, 27 Jun 2017 10:53:08 GMT","time.struct_time(tm_year=2017, tm_mon=6, tm_md...","<figure><img alt="""" src=""https://cdn-images-1....",Landing Page Optimization: Tools & Tips,https://medium.com/@p.guilbaud/landing-page-op...,3,3
3,Pierre GUILBAUD,2017-09-10 00:11:40.203145,https://medium.com/p/f94a35ca69aa,digital-strategy / social-media / growth-hacki...,,"Fri, 17 Mar 2017 08:15:55 GMT","time.struct_time(tm_year=2017, tm_mon=3, tm_md...","<figure><img alt="""" src=""https://cdn-images-1....",5 Keys to Master your Digital Marketing Strategy,https://medium.com/@p.guilbaud/5-keys-to-maste...,4,4
4,Pierre GUILBAUD,2017-09-10 00:11:40.203179,https://medium.com/p/59cfd24bc98b,startup / growth / entrepreneurship / accelera...,,"Sat, 11 Mar 2017 12:53:32 GMT","time.struct_time(tm_year=2017, tm_mon=3, tm_md...","<figure><img alt="""" src=""https://cdn-images-1....",Five People Essential to your Startup Success,https://medium.com/@p.guilbaud/five-people-ess...,5,5


In [12]:
# Replace NaN values with an empty string in the keywords column.
# Assign the result back rather than calling fillna(inplace=True) on the attribute
# `medium_df.keywords`: under pandas Copy-on-Write that chained in-place call no
# longer updates medium_df.
medium_df['keywords'] = medium_df['keywords'].fillna('')

In [14]:
# create dictionary to store keyword frequencies across all posts
# (each post's keywords are one " / "-separated string)
keyword_counts_dict = defaultdict(int)
# Iterate the column directly: the old `medium_df.iloc[i]` for i in medium_df.index
# mixed labels with positions and only worked with a 0..n-1 index.
for keywords in medium_df['keywords']:
    # non-string cells (NaN if the fillna cell was skipped) carry no keywords
    if not isinstance(keywords, str):
        continue
    for keyword in keywords.split(' / '):
        keyword_counts_dict[keyword] += 1

In [2]:
# Show the most frequent keywords. The [1:] presumably drops the top entry, which would
# be '' (''.split(' / ') counts an empty keyword for every post without tags).
# NOTE(review): confirm '' is still the most common entry before relying on this.
Counter(keyword_counts_dict).most_common(200)[1:]

In [None]:
# Select posts based on relevant keywords and divide into 10 categories (to be used in web app):
# the nine keyword lists below plus "other" (posts matching none of them). Each term is
# matched as a regex substring of a post's " / "-joined keywords string.
# Note the commas: adjacent string literals without one are silently concatenated.
search_tech = ['tech ', 'technology', 'augmented-reality','android','apple', 'software-development', 'facebook',
               'data-science', 'programming','machine-learning','artificial-intelligence','virtual-reality', 'vr',
               'self-driving-cars','internet-of-things']
search_politics = ['politics', 'government', 'donald-trump', 'trump', 'obama', 'hillary-clinton', 'russia',
                   'republican-party','democrats', 'north-korea', 'congress','elections', 'democracy']
search_entrepreneurship = ['entrepreneurship', 'startup', 'venture-capital', 'innovation', 'business','founders',
                           'fundraising']
search_economics = ['economics', 'finance', 'investing', 'money', 'stock-market', 'impact-investing', 'stocks',
                    'banking','personal-finance', 'economy','investment']
search_science = ['science', 'physics','space','climate-change','astronomy','neuroscience','nature','environment',
                  'brain','nasa','evolution','biology','oceans','renewable-energy','eclipse','chemistry','ecology',
                  'space-exploration','science-communication','geology']
search_life = ['life', 'life-lessons', 'self-improvement', 'health', 'mental-health', 'life-hacking']
# NOTE(review): 'learning' is matched as a substring, so it also selects e.g.
# 'machine-learning' / 'deep-learning' posts -- narrow it if that is not wanted.
search_education = ['education', 'teaching', 'learning', 'schools', 'higher-education','edtech', 'education-technology',
                    'university', 'students','education-reform', 'school','teachers']
search_writing = ['writing', 'books', 'poetry', 'fiction', 'storytelling', 'short-story']
search_design = ['design-thinking','data-visualization', 'ux', 'ux-design', 'user-experience','product-design',
                 'web-design', 'design-process', 'graphic-design']
# "other" excludes every keyword-defined category. search_science used to be missing here,
# so science posts were also counted as "other".
search_other = search_tech + search_politics + search_entrepreneurship + search_economics + search_science \
               + search_life + search_education + search_writing + search_design

In [None]:
# create new data frames for each category
def _posts_matching(search_terms):
    """Boolean mask over medium_df: True where 'keywords' contains any of the terms (regex OR)."""
    pattern = '|'.join(search_terms)
    return medium_df['keywords'].str.contains(pattern)

tech_df = medium_df[_posts_matching(search_tech)]
politics_df = medium_df[_posts_matching(search_politics)]
entrepreneurship_df = medium_df[_posts_matching(search_entrepreneurship)]
economics_df = medium_df[_posts_matching(search_economics)]
science_df = medium_df[_posts_matching(search_science)]
life_df = medium_df[_posts_matching(search_life)]
education_df = medium_df[_posts_matching(search_education)]
writing_df = medium_df[_posts_matching(search_writing)]
design_df = medium_df[_posts_matching(search_design)]
# "other": posts whose keywords match none of the category terms
other_df = medium_df[~_posts_matching(search_other)]

In [None]:
# Reset indexes to 0..n-1 so row labels equal row positions; the previous index is
# kept as an 'index' column.
tech_df = tech_df.reset_index()
politics_df = politics_df.reset_index()
entrepreneurship_df = entrepreneurship_df.reset_index()
# economics_df was the only category left un-reset.
# NOTE(review): no later cell processes economics_df yet.
economics_df = economics_df.reset_index()
science_df = science_df.reset_index()
life_df = life_df.reset_index()
education_df = education_df.reset_index()
writing_df = writing_df.reset_index()
design_df = design_df.reset_index()
other_df = other_df.reset_index()

In [None]:
def clean_post(post):
    """Strip non-body markup from a post's HTML, leaving <p> paragraphs and <a> links.

    Patterns are applied in order, each replaced with ''. Because none uses
    re.DOTALL, an element that spans a newline is left in place.
    """
    patterns = (
        '<p><strong>.*?</strong></p>',   # bold-only paragraphs (sub-headings)
        '<h[0-9]>.*?</h[0-9]>',          # headings
        '<figure>.*?</figure>',          # figures
        '<iframe.*?</iframe>',           # embeds
        '<img.*?>',                      # stand-alone images
        '<blockquote>.*?</blockquote>',  # quotes
        '<ol>.*?</ol>',                  # ordered lists
        '<ul>.*?</ul>',                  # unordered lists
        '<em>|</em>',                    # unwrap emphasis, keep its text
        '<strong>|</strong>',            # unwrap bold, keep its text
        '<pre>.*?</pre>',                # preformatted blocks
        '<hr>|<br>|<pr>',                # single markers
    )
    for pattern in patterns:
        post = re.sub(pattern, '', post)
    return post

def remove_a_tags(paragraph):
    """Remove the markup from a paragraph that contains links, keeping all of its text.

    If the paragraph contains ``</a>``, every ``<...>`` tag is stripped, so the link
    text and any text around or between links is kept. Paragraphs without a
    closing ``</a>`` are returned unchanged, tags included.
    """
    if "</a>" in paragraph:
        # The previous version kept only the text before the first "<a" and after the
        # last "</a>", silently dropping the link text and everything between links.
        return re.sub('<.*?>', '', paragraph)
    return paragraph
    
def remove_all_tags(paragraph):
    """Return ``paragraph`` with every ``<...>`` tag removed (non-greedy; tags must not span lines)."""
    tag_pattern = '<.*?>'
    stripped = re.sub(tag_pattern, '', paragraph)
    return stripped
  
def return_paragaphs(post):
    """Split a post's HTML into a list of plain-text body paragraphs.

    Steps:
    1. Clean the post with clean_post() and split it on ``</p>``.
    2. Drop empty pieces, then the last remaining piece.
    3. Strip links and all other tags.
    4. Drop a leading paragraph containing 'By' (a byline).
    5. Drop a trailing paragraph mentioning 'recommend', 'click' or
       'originally published' (case-insensitive).

    Returns:
        The paragraphs, possibly an empty list.
    """
    paragraphs = clean_post(post).split('</p>')
    # NOTE(review): [:-1] drops the last non-empty piece, presumably trailing markup
    # after the final </p> -- confirm against real feed HTML.
    paragraphs_clean = [remove_a_tags(p) for p in paragraphs if p != ''][:-1]
    paragraph_list = [re.sub('<.*?>', '', p) for p in paragraphs_clean]
    if not paragraph_list:
        return paragraph_list
    if 'By' in paragraph_list[0]:
        paragraph_list = paragraph_list[1:]
    # The former test `'recommend' or 'click' or ... in x` was always true (a non-empty
    # string is truthy) and dropped the last paragraph unconditionally.
    footer_markers = ('recommend', 'click', 'originally published')
    if paragraph_list and any(marker in paragraph_list[-1].lower() for marker in footer_markers):
        paragraph_list = paragraph_list[:-1]
    return paragraph_list

In [None]:
def get_all_text(df,keyword):
    """Concatenate the body paragraphs of every post in ``df``.

    Args:
        df: category DataFrame with 'text' (post HTML) and 'article_id' columns.
        keyword: category name used in the saved object's file name.

    Side effect:
        Pickles ``{article_id: row position in the returned frame}`` as
        ``<keyword>_id_to_index`` via save_obj.

    Returns:
        (all_text_df, good_texts):
        - all_text_df has the columns old_index (the post's label in ``df``),
          all_text (each paragraph preceded by a space) and article_id.
        - good_texts lists the article_ids that produced text.
        Posts whose text cannot be parsed (e.g. a NaN cell) or that have no
        paragraphs are skipped.
    """
    good_texts = []
    all_text_dict = {}
    # Walk labels and values together; the old df.iloc[label] lookup only worked
    # when df had a 0..n-1 index.
    for i, post, article_id in zip(df.index, df['text'], df['article_id']):
        try:
            paragraphs = return_paragaphs(post)
        except (TypeError, IndexError):
            # TypeError: non-string text (e.g. NaN); IndexError: no paragraphs to inspect
            continue
        if not paragraphs:
            continue
        all_text = ''.join(' ' + paragraph for paragraph in paragraphs)
        all_text_dict[i] = {'all_text': all_text, 'article_id': article_id}
        good_texts.append(article_id)

    # explicit columns keep the frame well-formed even when no post survives
    all_text_df = pd.DataFrame.from_dict(all_text_dict, orient='index',
                                         columns=['all_text', 'article_id'])
    all_text_df = all_text_df.reset_index()
    all_text_df.rename(columns={'index': 'old_index'}, inplace=True)

    id_to_index = dict(zip(all_text_df.article_id, all_text_df.index))
    save_obj(id_to_index, '{}_id_to_index'.format(keyword))

    return all_text_df, good_texts

In [None]:
# run get_all_text() for all categories; each call also pickles <category>_id_to_index
# NOTE(review): economics_df is built from search_economics but is never processed here,
# in get_paragraphs_images, or in category_list -- confirm whether that is intended.
tech_post_df, tech_good_texts = get_all_text(tech_df,'tech')
politics_post_df, politics_good_texts = get_all_text(politics_df,'politics')
entrepreneurship_post_df, entrepreneurship_good_texts = get_all_text(entrepreneurship_df, 'entrepreneurship')
science_post_df, science_good_texts = get_all_text(science_df,'science')
life_post_df, life_good_texts = get_all_text(life_df,'life')
education_post_df, education_good_texts = get_all_text(education_df,'education')
writing_post_df, writing_good_texts = get_all_text(writing_df,'writing')
design_post_df, design_good_texts = get_all_text(design_df,'design')
other_post_df, other_good_texts = get_all_text(other_df,'other')

In [17]:
def avoid_images(post):
    """Return the image URLs (``src="..."`` values) found in the last tenth of ``post``.

    Images there are often icons unrelated to the post content, so callers exclude
    them. The tail is ``len(post) // 10`` characters long, which is empty for posts
    shorter than 10 characters.
    """
    # The old slice post[-int(len(post)/10):] became post[-0:], i.e. the WHOLE post,
    # for short posts; slicing from an explicit start avoids that.
    tail_start = len(post) - len(post) // 10
    return [link.replace('src=', '').replace('"', '')
            for link in re.findall('src=".*?"', post[tail_start:])]
 
def del_images_from_dict(img_para_dict,post):
    """Return a copy of ``img_para_dict`` without entries whose 'img' URL avoid_images() flags.

    Args:
        img_para_dict: {key: {'img': url, ...}} as built by get_img_w_paragraphs.
        post: the post HTML the images were taken from.
    """
    # compute the avoid list once instead of once per image
    avoided = set(avoid_images(post))
    return {key: value for key, value in img_para_dict.items()
            if value['img'] not in avoided}

def get_img_w_paragraphs(post):
    """Map each in-body image of ``post`` to the text just before it.

    Looks for ``<p>...</p>`` immediately followed by a ``<figure>...</figure>``.
    Each match whose first ``src="..."`` URL contains 'http' is recorded under its
    match number as follows:

    - 'img': the image URL.
    - 'para1': the last ``<p>`` of the match together with the figure, tags stripped
      (so any caption text is included).
    - 'para2': the preceding ``<p>`` piece ('' if there is none).

    Images in the last tenth of the post are then removed via del_images_from_dict().
    """
    img_para_dict = {}
    text_breaks = re.findall('<p>.*?</p><figure>.*?</figure>', post)
    for i, text_break in enumerate(text_breaks):
        # each match starts with '<p>', so the split always has at least two pieces
        pieces = text_break.split('<p>')
        img_w_para = pieces[-1]
        src_match = re.search('src=".*?"', img_w_para)
        if src_match is None:
            # figure without a src (e.g. an embed): skip it instead of raising and
            # making the caller discard the whole post
            continue
        image = src_match.group(0).replace('src=', '').replace('"', '')
        if 'http' in image:
            paragraph_before = remove_all_tags(remove_a_tags(img_w_para))
            two_paragraphs_before = remove_all_tags(remove_a_tags(pieces[-2]))
            img_para_dict[i] = {'img': image, 'para1': paragraph_before,
                                'para2': two_paragraphs_before}
    return del_images_from_dict(img_para_dict, post)


In [None]:
def get_paragraphs_images(df,keyword,good_texts):
    """Build the image/paragraph pairs for one category.

    Only rows whose article_id is in ``good_texts`` (the posts get_all_text()
    accepted) are used. For each image found by get_img_w_paragraphs(), the
    lower-cased paragraph before it and the one before that are kept. A pair is
    kept only when the two paragraphs together split into more than 15
    space-separated pieces.

    Args:
        df: posts DataFrame with 'text', 'article_id' and 'keywords' columns.
        keyword: category name used in the saved objects' file names.
        good_texts: article_ids to include.

    Side effects (pickled via save_obj, keyed by row position in the result):
        <keyword>_img_to_para_dict maps row -> image link.
        <keyword>_img_to_kword_dict maps row -> post keywords.
        <keyword>_index_to_id maps row -> article_id.

    Returns:
        DataFrame with columns old_index, image_link, 1_paragraph_before,
        2_paragraphs_before, keywords and article_id.
    """
    good_ids = set(good_texts)  # built once rather than once per row
    body_imgs_dict = {}
    count = 0
    # iterate the columns directly; df.iloc[label] assumed a 0..n-1 index
    for post, article_id, keywords in zip(df['text'], df['article_id'], df['keywords']):
        if article_id not in good_ids:
            continue
        try:
            img_list = get_img_w_paragraphs(post)
        except (TypeError, AttributeError, IndexError):
            # unparsable post (e.g. NaN text or a figure without src): skip it
            continue
        for value in img_list.values():
            para1 = value['para1'].lower()
            para2 = value.get('para2', '').lower()
            # (the former condition began with a stray quote and did not compile)
            if len(para1.split(' ')) + len(para2.split(' ')) > 15:
                body_imgs_dict[count] = {'image_link': value['img'], '1_paragraph_before': para1,
                                         '2_paragraphs_before': para2, 'keywords': keywords,
                                         'article_id': article_id}
                count += 1

    # explicit columns keep the frame well-formed even when no pair survives
    columns = ['image_link', '1_paragraph_before', '2_paragraphs_before', 'keywords', 'article_id']
    body_imgs_df = pd.DataFrame.from_dict(body_imgs_dict, orient='index', columns=columns)
    body_imgs_df = body_imgs_df.reset_index()
    body_imgs_df.rename(columns={'index': 'old_index'}, inplace=True)

    img_to_para_dict = dict(zip(body_imgs_df.index, body_imgs_df.image_link))
    img_to_kword_dict = dict(zip(body_imgs_df.index, body_imgs_df.keywords))
    index_to_id = dict(zip(body_imgs_df.index, body_imgs_df.article_id))

    save_obj(img_to_para_dict, '{}_img_to_para_dict'.format(keyword))
    save_obj(img_to_kword_dict, '{}_img_to_kword_dict'.format(keyword))
    save_obj(index_to_id, '{}_index_to_id'.format(keyword))

    return body_imgs_df

In [None]:
# run get_paragraphs_images for all categories: each call scans the full medium_df and
# keeps only the article_ids listed in that category's *_good_texts (from get_all_text)
tech_paragraph_df = get_paragraphs_images(medium_df,'tech', tech_good_texts)
politics_paragraph_df = get_paragraphs_images(medium_df,'politics', politics_good_texts)
entrepreneurship_paragraph_df = get_paragraphs_images(medium_df, 'entrepreneurship', entrepreneurship_good_texts)
science_paragraph_df = get_paragraphs_images(medium_df,'science', science_good_texts)
life_paragraph_df = get_paragraphs_images(medium_df,'life', life_good_texts)
education_paragraph_df = get_paragraphs_images(medium_df,'education', education_good_texts)
writing_paragraph_df = get_paragraphs_images(medium_df,'writing', writing_good_texts)
design_paragraph_df = get_paragraphs_images(medium_df,'design', design_good_texts)
other_paragraph_df = get_paragraphs_images(medium_df,'other', other_good_texts)

<a id='section3'></a>

## Model Training

In [None]:
# Read the stopword text file (one stopword per line) into a set and pickle it.
# `stopwords` is also the module-level set used by the preprocessing functions below.
with open("englishST.txt",'r') as f:
    stopwords = set(f.read().splitlines())
save_obj(stopwords,'stopwords')

In [None]:
# lemmatizer for getting standard dictionary forms of words (passed explicitly to the
# preprocessing functions), plus gensim's Phrases used by add_phrases below
from nltk.stem.wordnet import WordNetLemmatizer
from gensim.models import Phrases
lemmatizer = WordNetLemmatizer()

In [None]:
def lemmatize(doc,lemmatizer):
    """Return a new list with ``lemmatizer.lemmatize`` applied to each token of ``doc``."""
    return list(map(lemmatizer.lemmatize, doc))
   
def add_phrases(doc_list,keyword,para=True):
    """Append detected phrase tokens to each tokenized document, in place.

    Steps:
    1. Fit one gensim ``Phrases`` model on ``doc_list`` with ``min_count=20``.
    2. Pickle it as ``<keyword>_para_bigram_list`` when ``para`` is true, otherwise
       as ``<keyword>_post_bigram_list``.
    3. Append to every document each token of its phrase-transformed version that
       contains an underscore. The original tokens are kept.

    Returns:
        ``doc_list`` itself, after mutation.
    """
    bigram = Phrases(doc_list, min_count=20)
    kind = 'para' if para else 'post'
    save_obj(bigram, '{}_{}_bigram_list'.format(keyword, kind))

    for doc in doc_list:
        for token in bigram[doc]:
            # an underscore marks a joined phrase (e.g. "new_york")
            if '_' in token:
                doc.append(token)
    return doc_list

def preprocess_lda(df,lemmatizer,keyword):
    """Tokenize every post's 'all_text' for the full-post LDA model.

    Each text is passed through gensim's simple_preprocess. Tokens in the
    module-level ``stopwords`` set are dropped and the rest are lemmatized. Phrase
    tokens are then added across the whole collection with
    add_phrases(..., para=False).

    Returns:
        One token list per row of ``df``, in row order.
    """
    lda_posts = [
        lemmatize([token for token in gensim.utils.simple_preprocess(text)
                   if token not in stopwords], lemmatizer)
        for text in df['all_text']
    ]
    return add_phrases(lda_posts, keyword, para=False)

def preprocess_paragraphs(df,lemmatizer,keyword):
    """Tokenize the paragraph(s) preceding each image for paragraph LDA and doc2vec.

    The text comes from '1_paragraph_before'. When that yields fewer than 10
    tokens, the tokens of '2_paragraphs_before' are appended. Stopwords are then
    removed and the tokens lemmatized.

    Returns:
        (lda_docs, tagged_docs):
        - lda_docs: the token lists after add_phrases().
        - tagged_docs: TaggedDocuments tagged with the row position.
        Both share the same list objects, so the phrase tokens add_phrases appends
        also appear in the doc2vec documents.
    """
    lda_para = []
    doc2vec_para = []
    rows = zip(df['1_paragraph_before'], df['2_paragraphs_before'])
    for position, (para1, para2) in enumerate(rows):
        tokens = gensim.utils.simple_preprocess(para1)
        # fall back on the earlier paragraph when the nearest one is short
        if len(tokens) < 10:
            tokens = tokens + gensim.utils.simple_preprocess(para2)
        kept = [token for token in tokens if token not in stopwords]
        lemmas = lemmatize(kept, lemmatizer)
        lda_para.append(lemmas)
        doc2vec_para.append(gensim.models.doc2vec.TaggedDocument(lemmas, [position]))
    return add_phrases(lda_para, keyword), doc2vec_para

In [None]:
def _train_lda(corpus, dictionary, chunksize, passes, num_topics=200, iterations=500):
    """Train an LdaModel with the settings shared by the post and paragraph models.

    alpha and eta are learned ('auto'), and perplexity evaluation is off
    (eval_every=None).
    """
    # The {id: token} mapping that dictionary.id2token holds once populated, built
    # directly instead of via the old `temp = dictionary[0]` side effect.
    id2word = dict(dictionary.items())
    return LdaModel(corpus=corpus, id2word=id2word, chunksize=chunksize,
                    alpha='auto', eta='auto',
                    iterations=iterations, num_topics=num_topics,
                    passes=passes, eval_every=None)


def train_lda_doc2vec(post_df,paragraph_df,keyword):
    """Train and save the full-post LDA, paragraph LDA and paragraph doc2vec models.

    Saved artefacts, where <kw> is ``keyword``:

    - Pickled via save_obj:
      - <kw>_lda_post_dict and <kw>_lda_para_dict: the gensim Dictionaries.
      - <kw>_lda_index_post and <kw>_lda_index_para: MatrixSimilarity indexes over
        the LDA topic vectors.
    - Saved by gensim in the working directory: <kw>_lda_model_posts,
      <kw>_lda_model_para and <kw>_doc2vec.

    The preprocessing step also pickles the fitted Phrases models.
    """
    num_topics = 200

    # tokenize, remove stopwords, lemmatize, add phrase tokens
    all_text_w_phrases = preprocess_lda(post_df, lemmatizer, keyword)
    paragraphs_w_phrases, doc2vec_paragraphs = preprocess_paragraphs(paragraph_df, lemmatizer, keyword)

    dictionary_post = Dictionary(all_text_w_phrases)
    dictionary_para = Dictionary(paragraphs_w_phrases)

    # Filter out words that occur in more than 50% of the documents
    # (filter_extremes also applies its default no_below / keep_n limits).
    dictionary_post.filter_extremes(no_above=0.5)
    dictionary_para.filter_extremes(no_above=0.5)

    save_obj(dictionary_post, "{}_lda_post_dict".format(keyword))
    save_obj(dictionary_para, "{}_lda_para_dict".format(keyword))

    corpus_post = [dictionary_post.doc2bow(doc) for doc in all_text_w_phrases]
    corpus_para = [dictionary_para.doc2bow(doc) for doc in paragraphs_w_phrases]

    # LDA on full posts, then on paragraphs
    lda_model = _train_lda(corpus_post, dictionary_post, chunksize=3000, passes=3,
                           num_topics=num_topics)
    lda_model_para = _train_lda(corpus_para, dictionary_para, chunksize=4000, passes=5,
                                num_topics=num_topics)

    lda_model.save('{}_lda_model_posts'.format(keyword))
    lda_model_para.save('{}_lda_model_para'.format(keyword))

    # Pin num_features to the topic count. Otherwise gensim infers it by scanning the
    # corpus and may size the index smaller than num_topics.
    lda_index = gensim.similarities.MatrixSimilarity(lda_model[corpus_post],
                                                     num_features=num_topics)
    lda_index_para = gensim.similarities.MatrixSimilarity(lda_model_para[corpus_para],
                                                          num_features=num_topics)

    save_obj(lda_index, '{}_lda_index_post'.format(keyword))
    save_obj(lda_index_para, '{}_lda_index_para'.format(keyword))

    # Train doc2vec on paragraphs
    # NOTE(review): `iter=` / `model.iter` are gensim 3.x names (renamed to `epochs`
    # in gensim 4) -- confirm the installed version.
    doc2vec_model = Doc2Vec(iter=20, dm_concat=1)
    doc2vec_model.build_vocab(doc2vec_paragraphs)
    doc2vec_model.train(doc2vec_paragraphs, total_examples=doc2vec_model.corpus_count,
                        epochs=doc2vec_model.iter)

    doc2vec_model.save('{}_doc2vec'.format(keyword))

In [None]:
# [post_df, paragraph_df, category keyword] triples consumed by train_lda_doc2vec below
# NOTE(review): no economics entry, although economics_df is built earlier -- confirm intent.
category_list = [[tech_post_df, tech_paragraph_df, 'tech'],
                 [politics_post_df, politics_paragraph_df, 'politics'],
                 [entrepreneurship_post_df, entrepreneurship_paragraph_df, 'entrepreneurship'],
                 [science_post_df, science_paragraph_df, 'science'],
                 [life_post_df, life_paragraph_df, 'life'],
                 [education_post_df,education_paragraph_df, 'education'],
                 [writing_post_df,writing_paragraph_df,'writing'],
                 [design_post_df,design_paragraph_df,'design'],
                 [other_post_df, other_paragraph_df, 'other']]

In [None]:
# train models over all categories: unpack each [post_df, paragraph_df, keyword] entry
for cat_posts, cat_paragraphs, cat_name in category_list:
    train_lda_doc2vec(cat_posts, cat_paragraphs, cat_name)