In [1]:
%load_ext autoreload
%autoreload 2

import warnings
warnings.filterwarnings('ignore')

In [2]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import glob
import json
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
plt.style.use('ggplot')

In [3]:
tqdm.pandas()

In [4]:
root_path = '../../data/raw/covid/'

In [12]:
metadata_path = f'{root_path}/metadata.csv'
meta_df = pd.read_csv(metadata_path, dtype={
    'pubmed_id': str,
    'Microsoft Academic Paper ID': str, 
    'doi': str
})
meta_df.head()

Unnamed: 0,cord_uid,sha,source_x,title,doi,pmcid,pubmed_id,license,abstract,publish_time,authors,journal,Microsoft Academic Paper ID,WHO #Covidence,has_full_text,full_text_file,url
0,vho70jcx,f056da9c64fbf00a4645ae326e8a4339d015d155,biorxiv,SIANN: Strain Identification by Alignment to N...,10.1101/001727,,,biorxiv,Next-generation sequencing is increasingly bei...,2014-01-10,Samuel Minot; Stephen D Turner; Krista L Ternu...,,,,True,biorxiv_medrxiv,https://doi.org/10.1101/001727
1,i9tbix2v,daf32e013d325a6feb80e83d15aabc64a48fae33,biorxiv,Spatial epidemiology of networked metapopulati...,10.1101/003889,,,biorxiv,An emerging disease is one infectious epidemic...,2014-06-04,Lin WANG; Xiang Li,,,,True,biorxiv_medrxiv,https://doi.org/10.1101/003889
2,62gfisc6,f33c6d94b0efaa198f8f3f20e644625fa3fe10d2,biorxiv,Sequencing of the human IG light chain loci fr...,10.1101/006866,,,biorxiv,Germline variation at immunoglobulin gene (IG)...,2014-07-03,Corey T Watson; Karyn Meltz Steinberg; Tina A ...,,,,True,biorxiv_medrxiv,https://doi.org/10.1101/006866
3,058r9486,4da8a87e614373d56070ed272487451266dce919,biorxiv,Bayesian mixture analysis for metagenomic comm...,10.1101/007476,,,biorxiv,Deep sequencing of clinical samples is now an ...,2014-07-25,Sofia Morfopoulou; Vincent Plagnol,,,,True,biorxiv_medrxiv,https://doi.org/10.1101/007476
4,wich35l7,eccef80cfbe078235df22398f195d5db462d8000,biorxiv,Mapping a viral phylogeny onto outbreak trees ...,10.1101/010389,,,biorxiv,Developing methods to reconstruct transmission...,2014-11-11,Stephen P Velsko; Jonathan E Allen,,,,True,biorxiv_medrxiv,https://doi.org/10.1101/010389


In [6]:
all_json = glob.glob(f'{root_path}/**/*.json', recursive=True)
len(all_json)

33375

In [8]:
class FileReader:
    def __init__(self, file_path):
        with open(file_path) as file:
            content = json.load(file)
            self.paper_id = content['paper_id']
            self.abstract = []
            self.body_text = []
            # Abstract
            for entry in content['abstract']:
                self.abstract.append(entry['text'])
            # Body text
            for entry in content['body_text']:
                self.body_text.append(entry['text'])
            self.abstract = '\n'.join(self.abstract)
            self.body_text = '\n'.join(self.body_text)
    def __repr__(self):
        return f'{self.paper_id}: {self.abstract[:200]}... {self.body_text[:200]}...'


In [10]:
first_row = FileReader(all_json[0])
print(first_row)

d604508a4de30a622bbe2d59d1f73abaaf9c14d2: ... La pathologic virale du syst6me nerveux, nagubre consid6r6e avec r6ticence, a acquis droit de cit6. Ce ne sont pas tellement la multiplicit6 des affections cliniques maintenant bien ~tudi6es, ni la di...


In [14]:
def get_breaks(content, length):
    data = ""
    words = content.split(' ')
    total_chars = 0

    # add break every length characters
    for i in range(len(words)):
        total_chars += len(words[i])
        if total_chars > length:
            data = data + "<br>" + words[i]
            total_chars = 0
        else:
            data = data + " " + words[i]
    return data

In [18]:
dict_ = {
    'paper_id': [], 
    'abstract': [], 
    'body_text': [], 
    'authors': [], 
    'title': [], 
    'journal': [], 
    'abstract_summary': []
}

for entry in tqdm(all_json, desc="Processing files"):
    content = FileReader(entry)
    
    # get metadata information
    meta_data = meta_df.loc[meta_df['sha'] == content.paper_id]
    # no metadata, skip this paper
    if len(meta_data) == 0:
        continue
    
    dict_['paper_id'].append(content.paper_id)
    dict_['abstract'].append(content.abstract)
    dict_['body_text'].append(content.body_text)
    
    # also create a column for the summary of abstract to be used in a plot
    if len(content.abstract) == 0: 
        # no abstract provided
        dict_['abstract_summary'].append("Not provided.")
    elif len(content.abstract.split(' ')) > 100:
        # abstract provided is too long for plot, take first 300 words append with ...
        info = content.abstract.split(' ')[:100]
        summary = get_breaks(' '.join(info), 40)
        dict_['abstract_summary'].append(summary + "...")
    else:
        # abstract is short enough
        summary = get_breaks(content.abstract, 40)
        dict_['abstract_summary'].append(summary)
        
    # get metadata information
    meta_data = meta_df.loc[meta_df['sha'] == content.paper_id]
    
    try:
        # if more than one author
        authors = meta_data['authors'].values[0].split(';')
        if len(authors) > 2:
            # more than 2 authors, may be problem when plotting, so take first 2 append with ...
            dict_['authors'].append(". ".join(authors[:2]) + "...")
        else:
            # authors will fit in plot
            dict_['authors'].append(". ".join(authors))
    except Exception as e:
        # if only one author - or Null valie
        dict_['authors'].append(meta_data['authors'].values[0])
    
    # add the title information, add breaks when needed
    try:
        title = get_breaks(meta_data['title'].values[0], 40)
        dict_['title'].append(title)
    # if title was not provided
    except Exception as e:
        dict_['title'].append(meta_data['title'].values[0])
    
    # add the journal information
    dict_['journal'].append(meta_data['journal'].values[0])
    
df_covid = pd.DataFrame(dict_, columns=['paper_id', 'abstract', 'body_text', 'authors', 'title', 'journal', 'abstract_summary'])
df_covid.head()

HBox(children=(FloatProgress(value=0.0, description='Processing files', max=33375.0, style=ProgressStyle(descr…




Unnamed: 0,paper_id,abstract,body_text,authors,title,journal,abstract_summary
0,d604508a4de30a622bbe2d59d1f73abaaf9c14d2,,"La pathologic virale du syst6me nerveux, nagub...","Cathala, F.",Généralités sur l'infection du système<br>ner...,Revue d'Electroencéphalographie et de Neurophy...,Not provided.
1,7cc9c36dd4de535aabba7346c12be7b470f6797f,Although promotion of safe hygiene is the sing...,Promotion of hygiene might be the single most ...,"Curtis, Val. Schmidt, Wolf...","Hygiene: new hopes, new horizons",The Lancet Infectious Diseases,Although promotion of safe hygiene is the<br>...
2,b8069299efef2b8cf2d672b0174f6f0a7ef0580e,,Oft ist die Unterscheidung zwischen infektiöse...,,KAPITEL 13 Infektionskrankheiten,Innere Medizin,Not provided.
3,53f7ce0a18de7792f478273f9449e175960563c0,Purpose: To explore factors relating to the pr...,Current global outbreak of severe acute respir...,"Wong, Chi-Yan. Tang, Catherine So-Kum",Practice of habitual and volitional health<br...,Journal of Adolescent Health,Purpose: To explore factors relating to the<b...
4,e3c2df2221f21ddbb5ece5f094fa307e26b79e9e,Signal-dependent targeting of proteins into an...,"The mammalian cell is a highly organised, dyna...","Fulcher, Alex J.. Jans, David A.",Regulation of nucleocytoplasmic trafficking<b...,Biochimica et Biophysica Acta (BBA) - Molecula...,Signal-dependent targeting of proteins into<b...


In [20]:
df_covid['abstract_word_count'] = df_covid['abstract'].progress_apply(lambda x: len(x.strip().split()))
df_covid['body_word_count'] = df_covid['body_text'].progress_apply(lambda x: len(x.strip().split()))
df_covid.head()

HBox(children=(FloatProgress(value=0.0, max=30197.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=30197.0), HTML(value='')))




Unnamed: 0,paper_id,abstract,body_text,authors,title,journal,abstract_summary,abstract_word_count,body_word_count
0,d604508a4de30a622bbe2d59d1f73abaaf9c14d2,,"La pathologic virale du syst6me nerveux, nagub...","Cathala, F.",Généralités sur l'infection du système<br>ner...,Revue d'Electroencéphalographie et de Neurophy...,Not provided.,0,3899
1,7cc9c36dd4de535aabba7346c12be7b470f6797f,Although promotion of safe hygiene is the sing...,Promotion of hygiene might be the single most ...,"Curtis, Val. Schmidt, Wolf...","Hygiene: new hopes, new horizons",The Lancet Infectious Diseases,Although promotion of safe hygiene is the<br>...,196,6480
2,b8069299efef2b8cf2d672b0174f6f0a7ef0580e,,Oft ist die Unterscheidung zwischen infektiöse...,,KAPITEL 13 Infektionskrankheiten,Innere Medizin,Not provided.,0,21808
3,53f7ce0a18de7792f478273f9449e175960563c0,Purpose: To explore factors relating to the pr...,Current global outbreak of severe acute respir...,"Wong, Chi-Yan. Tang, Catherine So-Kum",Practice of habitual and volitional health<br...,Journal of Adolescent Health,Purpose: To explore factors relating to the<b...,203,4004
4,e3c2df2221f21ddbb5ece5f094fa307e26b79e9e,Signal-dependent targeting of proteins into an...,"The mammalian cell is a highly organised, dyna...","Fulcher, Alex J.. Jans, David A.",Regulation of nucleocytoplasmic trafficking<b...,Biochimica et Biophysica Acta (BBA) - Molecula...,Signal-dependent targeting of proteins into<b...,196,8390


In [21]:
df_covid['body_word_count'].sum()

145563452

In [22]:
df_covid.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 30197 entries, 0 to 30196
Data columns (total 9 columns):
 #   Column               Non-Null Count  Dtype 
---  ------               --------------  ----- 
 0   paper_id             30197 non-null  object
 1   abstract             30197 non-null  object
 2   body_text            30197 non-null  object
 3   authors              29646 non-null  object
 4   title                30162 non-null  object
 5   journal              29077 non-null  object
 6   abstract_summary     30197 non-null  object
 7   abstract_word_count  30197 non-null  int64 
 8   body_word_count      30197 non-null  int64 
dtypes: int64(2), object(7)
memory usage: 2.1+ MB


In [90]:
df_covid.to_csv("../../data/processed/covid-2019-scientific-papers.csv")

In [25]:
!ls -lah ../../data/processed/

total 958M
drwxrwxr-x 2 science science 4,0K Apr  3 19:34 .
drwxrwxr-x 8 science science 4,0K Apr  2 17:58 ..
-rw-rw-r-- 1 science science 958M Apr  3 19:34 covid-2019-scientific-papers.csv
-rw-rw-r-- 1 science science    0 Mar 24 14:36 .gitkeep


In [69]:
!pip install langdetect

Collecting langdetect
  Downloading langdetect-1.0.8.tar.gz (981 kB)
[K     |████████████████████████████████| 981 kB 1.1 MB/s eta 0:00:01
Installing collected packages: langdetect
    Running setup.py install for langdetect ... [?25ldone
[?25hSuccessfully installed langdetect-1.0.8


In [70]:
from langdetect import detect

In [71]:
detect(df_covid.body_text[0])

'fr'

In [82]:
detect(df_covid.body_text[1560])

'en'

In [86]:
def detect_lang(text):
    try:
        return detect(str(text))
    except:
        return None

In [87]:
df_covid['lang'] = df_covid['body_text'].progress_apply(lambda x: detect_lang(x))

HBox(children=(FloatProgress(value=0.0, max=30197.0), HTML(value='')))




In [89]:
len(df_covid['lang'])

30197

In [88]:
df_covid['lang'].value_counts()

en       29492
fr         315
es         274
de          67
it          14
pt          10
pl           2
ca           2
nl           2
et           2
cy           2
no           1
id           1
da           1
lt           1
af           1
tl           1
zh-cn        1
so           1
Name: lang, dtype: int64

In [85]:
df_covid['lang']

0        None
1        None
2        None
3        None
4        None
         ... 
30192    None
30193    None
30194    None
30195    None
30196    None
Name: lang, Length: 30197, dtype: object

In [26]:
df_covid.body_text[0]

'La pathologic virale du syst6me nerveux, nagubre consid6r6e avec r6ticence, a acquis droit de cit6. Ce ne sont pas tellement la multiplicit6 des affections cliniques maintenant bien ~tudi6es, ni la diver-sit6 des virus << neurotropes >> qui font l\'int6r6t de ce tr~s vaste domaine d\'6tude, mais bien plus la r6v61ation de m6canismes physiopathologiques complexes et encore insuffisamment 6clair6s qui semblent avoir une r6elle port6e en pathologie g6n6rale. Ainsi les << infections virales lentes >~ (10) (31) qui nous offrent probablement des exemples de r6actions originales de l\'organisme/a certaines modalit6s de l\'infection.\nParmi beaucoup d\'autres possibles, indiquons quelques-unes des questions que pose cette physiopathologic :\n--Quelles sont les voles de pb, n6tration des virus vers le syst6me nerveux ? et comment le virus s\'y propage-t-il ? Y a-t-il des affinit6s particulibres entre certains virus et certains types cellulaires ? --Quels sont les effets de l\'infection des cel

In [33]:
import gensim
import nltk

In [44]:
from nltk.tokenize import sent_tokenize, word_tokenize
from nltk.corpus import stopwords
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/science/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


True

In [91]:
stopWords = set(stopwords.words('english'))

In [50]:
len(stopWords)

810

In [52]:
text = gensim.utils.simple_preprocess(df_covid.body_text[0], deacc=True)

In [54]:
clean_text = [w for w in text if w not in stopWords]

In [55]:
len(text), len(clean_text)

(3907, 2285)

In [118]:
stemmer = SnowballStemmer("english")
stop_words = set(stopwords.words('english'))
with open('../../data/processed/covid-2019-scientific-papers/vowpal_wabbit_corpus.txt', 'w') as the_file:
    for index, row in tqdm(df_covid.iterrows()):
        if row['lang'] != "en":
            continue
        title = gensim.utils.simple_preprocess(str(row["title"]), deacc=True)
        text = gensim.utils.simple_preprocess(str(row["body_text"]), deacc=True)
        clean_title = [w for w in title if w not in stop_words]
        clean_text = [w for w in text if w not in stop_words]
        bigrams = list(nltk.bigrams(clean_text))
        trigrams = list(nltk.trigrams(clean_text))
        bigrams = ["!".join(b) for b in bigrams]
        trigrams = ["!".join(t) for t in trigrams]

        parts = [f"{row['paper_id']}"]
        parts += ['|@title']  + clean_title
        parts += ['|@text']  + text
        parts += ['|@bigrams']  + bigrams
        parts += ['|@trigrams']  + trigrams
        post = ' '.join(parts)
        the_file.write(f"{post}\n")

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [5]:
import glob
import logging
import os
import re
import uuid
from pathlib import Path
from pymongo import MongoClient
import artm
import click
import mlflow
from tqdm.notebook import tqdm
import nltk

import multiprocessing as mp

  signature = inspect.formatargspec(
  from collections import Sequence, defaultdict
  from collections import Counter, Iterable


In [7]:
batch_vectorizer = artm.BatchVectorizer(
    data_path='../../data/processed/covid-2019-scientific-papers/vowpal_wabbit_corpus.txt',
    data_format='vowpal_wabbit',
    batch_size=5000,
    target_folder='../../data/processed/covid-2019-scientific-papers/batches',
)

In [9]:
dictionary = batch_vectorizer.dictionary
dictionary.save_text(dictionary_path='../../data/processed/covid-2019-scientific-papers/dictionary.txt')

InvalidOperationException: Unable to serialize the message

In [10]:
dictionary

artm.Dictionary(name=f28b0602-87d9-48c5-876e-4196f3f0b884, num_entries=80293676)

In [115]:
!ls -lah ../../data/processed/covid-2019-scientific-papers/batches

total 4,1G
drwxrwxr-x 2 science science 4,0K Apr  3 21:34 .
drwxrwxr-x 4 science science 4,0K Apr  3 21:34 ..
-rw-rw-r-- 1 science science 709M Apr  3 21:34 aaaaaa.batch
-rw-rw-r-- 1 science science 709M Apr  3 21:34 aaaaab.batch
-rw-rw-r-- 1 science science 709M Apr  3 21:34 aaaaac.batch
-rw-rw-r-- 1 science science 709M Apr  3 21:34 aaaaad.batch
-rw-rw-r-- 1 science science 709M Apr  3 21:34 aaaaae.batch
-rw-rw-r-- 1 science science 637M Apr  3 21:34 aaaaaf.batch


In [31]:
dictionary.filter(min_df=7, inplace=True)

artm.Dictionary(name=f28b0602-87d9-48c5-876e-4196f3f0b884, num_entries=1981509)

In [32]:
dictionary.save_text(dictionary_path='../../data/processed/covid-2019-scientific-papers/dictionary.txt')

In [6]:
batch_vectorizer = artm.BatchVectorizer(
    data_path='../../data/processed/covid-2019-scientific-papers/batches',
    data_format='batches',
)
dictionary = artm.Dictionary()
dictionary.load_text(dictionary_path='../../data/processed/covid-2019-scientific-papers/dictionary.txt')

In [7]:
def create_topic_names(topic_count=220, background_topic_count=20):
    objective_topics = ['objective_topic_' + str(x) for x in range(0, topic_count - background_topic_count)]
    background_topics = ['background_topic_' + str(x) for x in range(topic_count - background_topic_count, topic_count)]
    all_topics = objective_topics + background_topics

    return all_topics, objective_topics, background_topics

In [8]:
def print_measures(model):
    logging.info('Sparsity Title Phi: {0:.3f}'.format(model.score_tracker['SparsityPhiTitleScore'].last_value))
    logging.info('Sparsity Text Phi: {0:.3f}'.format(model.score_tracker['SparsityPhiTextScore'].last_value))
    logging.info('Sparsity Bigrams Phi: {0:.3f}'.format(model.score_tracker['SparsityPhiBigramsScore'].last_value))
    logging.info('Sparsity Trigrams Phi: {0:.3f}'.format(model.score_tracker['SparsityPhiTrigramsScore'].last_value))
    logging.info('Sparsity Theta: {0:.3f}'.format(model.score_tracker['SparsityThetaScore'].last_value))
    logging.info('Kernel title contrast: {0:.3f}'.format(model.score_tracker['TopicKernelTitleScore'].last_average_contrast))
    logging.info('Kernel text contrast: {0:.3f}'.format(model.score_tracker['TopicKernelTextScore'].last_average_contrast))
    logging.info('Kernel bigrams contrast: {0:.3f}'.format(model.score_tracker['TopicKernelBigramsScore'].last_average_contrast))
    logging.info('Kernel trigrams contrast: {0:.3f}'.format(model.score_tracker['TopicKernelTrigramsScore'].last_average_contrast))
    logging.info('Kernel title purity: {0:.3f}'.format(model.score_tracker['TopicKernelTitleScore'].last_average_purity))
    logging.info('Kernel text purity: {0:.3f}'.format(model.score_tracker['TopicKernelTextScore'].last_average_purity))
    logging.info('Kernel bigrams purity: {0:.3f}'.format(model.score_tracker['TopicKernelBigramsScore'].last_average_purity))
    logging.info('Kernel trigrams purity: {0:.3f}'.format(model.score_tracker['TopicKernelTrigramsScore'].last_average_purity))
    logging.info('Perplexity: {0:.3f}'.format(model.score_tracker['PerplexityScore'].last_value))

In [9]:
def generate_name():
    return str(uuid.uuid1())

In [10]:
def mlflow_log_metrics(model):
    mlflow.log_metrics({
        "DeccorPhi": model.regularizers['DeccorPhi'].tau,
        "SmoothPhi": model.regularizers['SmoothPhi'].tau,
        "SmoothTheta": model.regularizers['SmoothTheta'].tau,
        "SparsePhi": model.regularizers['SparsePhi'].tau,
        "SparseTheta": model.regularizers['SparseTheta'].tau,
        "SparsityPhiTitleScore": model.score_tracker['SparsityPhiTitleScore'].last_value,
        "SparsityPhiTextScore": model.score_tracker['SparsityPhiTextScore'].last_value,
        "SparsityPhiBigramsScore": model.score_tracker['SparsityPhiBigramsScore'].last_value,
        "SparsityPhiTrigramsScore": model.score_tracker['SparsityPhiTrigramsScore'].last_value,
        "SparsityThetaScore": model.score_tracker['SparsityThetaScore'].last_value,
        "KernelContrastTitleScore": model.score_tracker['TopicKernelTitleScore'].last_average_contrast,
        "KernelContrastTextScore": model.score_tracker['TopicKernelTextScore'].last_average_contrast,
        "KernelContrastBigramsScore": model.score_tracker['TopicKernelBigramsScore'].last_average_contrast,
        "KernelContrastTrigramsScore": model.score_tracker['TopicKernelTrigramsScore'].last_average_contrast,
        "TopicPurityTitleScore": model.score_tracker['TopicKernelTitleScore'].last_average_purity,
        "TopicPurityTextScore": model.score_tracker['TopicKernelTextScore'].last_average_purity,
        "TopicPurityBigramsScore": model.score_tracker['TopicKernelBigramsScore'].last_average_purity,
        "TopicPurityTrigramsScore": model.score_tracker['TopicKernelTrigramsScore'].last_average_purity,
        "PerplexityScore": model.score_tracker['PerplexityScore'].last_value,
    }, step=model.num_phi_updates)

In [11]:
experiments_path = "../../data/processed/covid-2019-scientific-papers/"

In [12]:
def next_step(i, model, batch_vectorizer, step_size):
    model_name = generate_name()
    print(model_name)

    for _ in tqdm(range(step_size)):
        model.fit_offline(batch_vectorizer=batch_vectorizer, num_collection_passes=1)
        mlflow_log_metrics(model)
    print_measures(model)

    Path(os.path.join(experiments_path, 'models')).mkdir(parents=True, exist_ok=True)
    model_dir_name = os.path.join(experiments_path, 'models', f"{model_name}")
    model.dump_artm_model(model_dir_name)
    mlflow.set_tag(f"model_dump_{i}", model_dir_name)
    mlflow.log_artifacts(model_dir_name)

In [13]:
import logging

In [14]:
min_df = 5
num_all_topics = 220
num_background_topics = 20
step_size = 20

In [15]:
all_topics, objective_topics, background_topics = create_topic_names(num_all_topics, num_background_topics)

In [16]:
scores_artm = [
        artm.PerplexityScore(name='PerplexityScore', dictionary=dictionary, class_ids=["@title", "@text", "@bigrams", "@trigrams"]),
        artm.SparsityPhiScore(name='SparsityPhiTitleScore', topic_names=objective_topics, class_id="@title"),
        artm.SparsityPhiScore(name='SparsityPhiTextScore', topic_names=objective_topics, class_id="@text"),
        artm.SparsityPhiScore(name='SparsityPhiBigramsScore', topic_names=objective_topics, class_id="@bigrams"),
        artm.SparsityPhiScore(name='SparsityPhiTrigramsScore', topic_names=objective_topics, class_id="@trigrams"),
        artm.SparsityThetaScore(name='SparsityThetaScore', topic_names=objective_topics),
        artm.TopTokensScore(name='TopTokensTitleScore', num_tokens=20, topic_names=objective_topics, dictionary=dictionary, class_id="@title"),
        artm.TopTokensScore(name='TopTokensTextScore', num_tokens=20, topic_names=objective_topics, dictionary=dictionary, class_id="@text"),
        artm.TopTokensScore(name='TopTokensBigramsScore', num_tokens=20, topic_names=objective_topics, dictionary=dictionary, class_id="@bigrams"),
        artm.TopTokensScore(name='TopTokensTrigramsScore', num_tokens=20, topic_names=objective_topics, dictionary=dictionary, class_id="@trigrams"),
        artm.TopicKernelScore(name='TopicKernelTitleScore', class_id="@title", probability_mass_threshold=0.25,
                              topic_names=objective_topics, dictionary=dictionary),
        artm.TopicKernelScore(name='TopicKernelTextScore', class_id="@text", probability_mass_threshold=0.25,
                              topic_names=objective_topics, dictionary=dictionary),
        artm.TopicKernelScore(name='TopicKernelBigramsScore', class_id="@bigrams", probability_mass_threshold=0.25,
                              topic_names=objective_topics, dictionary=dictionary),
        artm.TopicKernelScore(name='TopicKernelTrigramsScore', class_id="@trigrams", probability_mass_threshold=0.25,
                              topic_names=objective_topics, dictionary=dictionary),
    ]

In [17]:
regularizers_artm = [
        artm.DecorrelatorPhiRegularizer(name='DeccorPhi', topic_names=objective_topics, gamma=0, tau=0),
        artm.SmoothSparsePhiRegularizer(name='SparsePhi', topic_names=objective_topics, dictionary=dictionary, gamma=0,
                                        tau=0),
        artm.SmoothSparsePhiRegularizer(name='SmoothPhi', topic_names=background_topics, dictionary=dictionary, gamma=0,
                                        tau=0),
        artm.SmoothSparseThetaRegularizer(name='SparseTheta', topic_names=objective_topics, tau=0),
        artm.SmoothSparseThetaRegularizer(name='SmoothTheta', topic_names=background_topics, tau=0),
        artm.TopicSelectionThetaRegularizer(name='TopicSelectionTheta', topic_names=objective_topics, tau=0)
    ]

In [18]:
model = artm.ARTM(
        num_topics=num_all_topics,
        topic_names=all_topics,
        class_ids={'@title': 3.0, '@text': 1.0, "@bigrams": 2.0, "@trigrams": 4.0},
        num_processors=mp.cpu_count() - 1,
        num_document_passes=2,
        regularizers=regularizers_artm,
        scores=scores_artm,
        dictionary=dictionary,
        cache_theta=False,
        seed=42,
        show_progress_bars=False
    )

In [None]:
mlflow.set_experiment("covid-2019-scientific-papers")
with mlflow.start_run():

    # этап 1 - сильная декорреляция + сглаживание
    # Sparse < 0
    # Smooth > 0
    model.regularizers['DeccorPhi'].tau = 0.005
    model.regularizers['SmoothPhi'].tau = 0.4
    model.regularizers['SmoothTheta'].tau = 0.4
    next_step(1, model, batch_vectorizer, step_size)

    model.regularizers['DeccorPhi'].tau = 0.015
    model.regularizers['SmoothPhi'].tau = 0.6
    model.regularizers['SmoothTheta'].tau = 0.6
    next_step(2, model, batch_vectorizer, step_size)

    model.regularizers['DeccorPhi'].tau = 0.03
    model.regularizers['SmoothPhi'].tau = 0.8
    model.regularizers['SmoothTheta'].tau = 0.8
    next_step(3, model, batch_vectorizer, step_size)

    # этап 2 - подключение разреживания предметных, постепенное увеличение разреживания
    # Sparse < 0
    # Smooth > 0
    model.regularizers['SparsePhi'].tau = -0.0001
    model.regularizers['SparseTheta'].tau = -0.1
    next_step(4, model, batch_vectorizer, step_size)

    model.regularizers['SparsePhi'].tau = -0.0002
    model.regularizers['SparseTheta'].tau = -0.2
    next_step(5, model, batch_vectorizer, step_size)

    model.regularizers['SparsePhi'].tau = -0.0003
    model.regularizers['SparseTheta'].tau = -0.3
    next_step(6, model, batch_vectorizer, step_size)

    # этап 3
    # Sparse < 0
    # Smooth > 0
    model.regularizers['SparsePhi'].tau = -0.0005
    model.regularizers['SparseTheta'].tau = -0.4
    next_step(7, model, batch_vectorizer, step_size)

    model.regularizers['SparsePhi'].tau = -0.0006
    model.regularizers['SparseTheta'].tau = -0.5
    next_step(8, model, batch_vectorizer, step_size)

    model.regularizers['SparsePhi'].tau = -0.0007
    model.regularizers['SparseTheta'].tau = -0.6
    next_step(9, model, batch_vectorizer, step_size)

72ee7db2-75e9-11ea-b748-0ae0afb50062


HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))