In [9]:
import os
import re
import nltk
import spacy
import gensim

from nltk.corpus import stopwords
from collections import defaultdict

from gensim.corpora import Dictionary
from gensim.utils import simple_preprocess
from gensim.models import AuthorTopicModel, CoherenceModel

In [10]:
# https://course.spacy.io/en/
# https://radimrehurek.com/gensim/auto_examples/index.html
# Good reference on visualization: https://markroxor.github.io/gensim/static/notebooks/topic_coherence_tutorial.html
# https://www.machinelearningplus.com/nlp/topic-modeling-gensim-python/

In [11]:
nltk.download('stopwords')

[nltk_data] Downloading package stopwords to /project/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

### Add to Dockerfile

In [12]:
# !pip install nltk
# !pip install spacy
# !pip install gensim
# !pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.4.0/en_core_web_md-3.4.0.tar.gz

### From *Anne*

1. What is the meaning of text within square brackets? That is data that was masked for privacy purposes, you should remove it from your analysis as it does not contain relevant information for conditions or underlying factors.
2. Annotations represent the output of a named entity recognition process. How to make sense of the annotations? How to read and understand the annotations? What do T1, R1, etc. mean? How to interpret the codes in annotation files? T stands for Term and R stands for Relation. In the below you can read R1 as "ativan taken for recurrent seizures" and R4 as "ativan given IM (intramuscularly)". For the T's you get the starting and ending index location for the term. So "recurrent seizures" starts at position 10179 and ends at 10197.
Here are some examples:
- T1 Reason 10179 10197 recurrent seizures
- R1 Reason-Drug Arg1:T1 Arg2:T3
- T3 Drug 10227 10233 ativan
- T5 Route 10240 10242 IM
- R4 Route-Drug Arg1:T5 Arg2:T3
3. Is there any hint/suggestions on how to use annotations? Use the non-drug terms to validate your underlying factors. While not all factors from the text will be found in the annotations, all factors in the annotations should be in the text.

Some other tips:
- you'll want to extract only the sections you're interested in from the text documents, as other sections (such as family history) will have confounding information.
- scispacy (en_core_sci_md) has a helpful NER that can be used to limit your text to just medically relevant terms
- don't forget to evaluate your model results
- it would be better to refine one model (based on results) than to do two different models


In [13]:
class TopicModeler(object):
    
    def __init__(self, data_path, use_lemma=False):
        self.data_path = data_path
        self.extract = ["discharge diagnosis", "chief complaint", "history of present illness"]
        self.use_lemma = use_lemma
        self.used_pos = ['NOUN', 'ADJ', 'VERB', 'ADV']
        self.nlp = spacy.load('en_core_web_md')#, disable=['parser', 'ner'])
        
    def check_line(self, line):
        for heading in self.extract:
            if heading in line.lower():
                return True
        return False
        
    def load_documents(self):
        docs = [] # [<text of documents in the data>]
        doc_ids = [] # {<document id from path>: <0 based index using docs>}
        condition2doc = {} # {<chief complaint>: [document ids from which complaint is extracted]}
        
        for file_path in os.listdir(self.data_path):
            
            if file_path.endswith(".txt"):
                
                with open(os.path.join(self.data_path, file_path)) as f:

                    docs.append(f.read())
                    doc_id = int(file_path.split('.')[0])
                    doc_ids.append(doc_id)

                    # Extract medical conditions
                    f.seek(0)
                    line = " "
                    conditions = None
                    while line:
                        line = f.readline()
                        if self.check_line(line): 
                            conditions = f.readline().split(',') # \n?
                            break

                    conditions = conditions if isinstance(conditions, list) else [""]
                    for condition in conditions:
                        condition = condition.lower().strip()
                        condition = re.sub(r'[.?!\'";:,]', "", condition)
                        doc_set = condition2doc.setdefault(condition, [])
                        doc_set.append(doc_id)
                        
        doc_id_map = dict(zip(doc_ids, range(len(doc_ids))))
        for condition, ids in condition2doc.items():
            for i, doc_id in enumerate(ids):
                condition2doc[condition][i] = doc_id_map[doc_id]
                
        return docs, condition2doc, doc_id_map
    
    def load_annotations(self):
        anns_data = {} # {<id from path>: {<T>: {keys: word, start, end, info}, <R>: {keys: word, arg1, arg2}}}
        anns2factors = defaultdict(set) # {<id from path>: {<set of factors>}}

        for file_path in os.listdir(self.data_path):
            
            if file_path.endswith(".ann"):
                
                with open(os.path.join(self.data_path, file_path)) as f:

                    lines = f.readlines()
                    data = defaultdict(dict)
                    ann_id = int(file_path.split('.')[0])

                    for line in lines:
                        split_line = line.split()
                        
                        if split_line[0].startswith('T'):
                            term = split_line[0]
                            word = split_line[1]
                            
                            if word.startswith("Reason"):
                                data[term]['word'] = word
                                data[term]['start'] = int(split_line[2])
                                end = split_line[3]
                                
                                if ";" in end: # Just extract the 1st start and end
                                    end = end.split(';')[0]
                                    data[term]['end'] = int(end)
                                else:
                                    data[term]['end'] = int(end)
                                    
                                data[term]['info'] = ' '.join([item for item in split_line[4:] if not item.isdigit()])
                                
                        elif split_line[0].startswith('R'):
                            relation = split_line[0]
                            word = split_line[1]
                            
                            if word.startswith("Reason"):
                                data[relation]['word'] = word
                                data[relation]['arg1'] = split_line[2].split(':')[1]
                                data[relation]['arg2'] = split_line[3].split(':')[1]
                        else:
                            pass

                    anns_data[ann_id] = data
            
        for key in anns_data:
            for x in anns_data[key]:
                if x.startswith('T'):
                    anns2factors[key].update(set(anns_data[key][x]['info'].lower().split()))
                    
        return anns_data, anns2factors
                    
    @staticmethod
    def doc_to_words(documents):
        words = [simple_preprocess(doc) for doc in documents]
        return words
    
    @staticmethod
    def remove_stopwords(docs, stop_words):
        tokens = [[word for word in simple_preprocess(str(doc)) if word not in stop_words] for doc in docs]
        return tokens
    
    def lemmatize(self, docs):
        lemmas = [[token.lemma_ for token in self.nlp(' '.join(doc)) if token.pos_ in self.used_pos] for doc in docs]
        return lemmas
    
    def preprocess(self, no_below=5, no_above=0.5):
        stop_words = set(stopwords.words('english'))
        docs, condition2doc, doc_id_map = self.load_documents()
        
        # Remove punctuation, whitespace, PHI
        docs = [re.sub(r'\[\*\*.+?\*\*\]|[,.\'!?]', '', doc) for doc in docs]
        docs = [re.sub(r'\s+', r' ', doc) for doc in docs]
        
        # Tokenize documents and remove stop words
        docs = self.doc_to_words(docs)
        docs = self.remove_stopwords(docs, stop_words)
        
        # Lemmatize documents
        if self.use_lemma:
            docs = self.lemmatize(docs)
        
        idx2word = Dictionary(docs)
        idx2word.filter_extremes(no_below=no_below, no_above=no_above)
        
        corpus = [idx2word.doc2bow(doc) for doc in docs]
        
        return corpus, docs, idx2word, condition2doc, doc_id_map
    
    def get_topics(self, num_topics=100, chunk_size=50, passes=10, alpha='symmetric', eta='symmetric'):
        
        corpus, docs, idx2word, condition2doc, doc_id_map = self.preprocess()
        self.cache = {'corpus': corpus, 'docs': docs, 'doc_id_map': doc_id_map, 'idx2word': idx2word, 
                      'condition2doc': condition2doc}
        
        # https://radimrehurek.com/gensim/models/atmodel.html
        model = AuthorTopicModel(corpus=corpus, num_topics=num_topics, id2word=idx2word, author2doc=condition2doc, 
                                 chunksize=chunk_size, passes=passes, alpha=alpha, eta=eta, random_state=7)
        
        # https://radimrehurek.com/gensim/models/coherencemodel.html
        coherence_model = CoherenceModel(model=model, texts=docs, dictionary=idx2word, coherence='c_v')
        coherence = coherence_model.get_coherence()
        
        topics = {topic: [word[0] for word in words] for topic, words in model.show_topics(num_topics, num_words=100, formatted=False)}
        conditions = {condition: model.get_author_topics(condition) for condition in model.id2author.values()}
        conditions = {condition: topics.get(max(scores, key=lambda x: x[1])[0]) for condition, scores in conditions.items()}
            
        self.cache['conditions'] = conditions
        self.cache['coherence'] = coherence
        
        return conditions
    
    def evaluate(self, topics):
        anns_data, anns2factors = self.load_annotations()
        
        idx_to_doc_id = {idx: doc_id for doc_id, idx in self.cache['doc_id_map'].items()}
        
        count = 0
        common_factors = defaultdict(set)
        for condition in self.cache['condition2doc']:
            if condition in topics:
                doc_idx = self.cache['condition2doc'][condition]
                ann_factors = set()
                for idx in doc_idx:
                    doc_id = idx_to_doc_id[idx]
                    ann_factors.update(anns2factors[doc_id])
                topic_factors = set(topics[condition])
                common = topic_factors.intersection(ann_factors)
                if common:
                    count += 1
                    common_factors[condition].update(common)

        self.cache['common_factors'] = common_factors
        return self.cache['coherence'], count / len(topics)

In [14]:
# Train modeler
topic_modeler = TopicModeler("../data/training_20180910/")
topics = topic_modeler.get_topics()
coherence, fraction_detected = topic_modeler.evaluate(topics)

In [15]:
print(f"Coherence: {coherence}, Fraction detected (1 or more words matched): {fraction_detected}")

Coherence: 0.37374950084007047, Fraction detected (1 or more words matched): 0.7470817120622568


In [16]:
freq = {}
for key in topic_modeler.cache['common_factors']:
    size  = len(topic_modeler.cache['common_factors'][key])
    freq[size] = freq.get(size, 0) + 1
    
freq

{2: 44, 3: 32, 6: 6, 1: 59, 5: 18, 8: 4, 4: 21, 9: 4, 7: 4}

In [26]:
size = len(topics)
for key in freq:
    fraction = freq[key] / size
    print(f"\n\tFraction of documents where topic model output matched {key} words from documents annotation: {fraction:.3f}")


	Fraction of documents where topic model output matched 2 words from documents annotation: 0.171

	Fraction of documents where topic model output matched 3 words from documents annotation: 0.125

	Fraction of documents where topic model output matched 6 words from documents annotation: 0.023

	Fraction of documents where topic model output matched 1 words from documents annotation: 0.230

	Fraction of documents where topic model output matched 5 words from documents annotation: 0.070

	Fraction of documents where topic model output matched 8 words from documents annotation: 0.016

	Fraction of documents where topic model output matched 4 words from documents annotation: 0.082

	Fraction of documents where topic model output matched 9 words from documents annotation: 0.016

	Fraction of documents where topic model output matched 7 words from documents annotation: 0.016


In [None]:
size = len(anns2factors)
count = 0
for key in anns2factors:
    count += len(anns2factors[key])
    
print(count / size)

In [None]:
topic_modeler.cache.keys()

In [None]:
topic_modeler.cache['condition2doc']['abdominal pain']

In [None]:
topic_modeler.cache['doc_id_map'][100035]

In [None]:
idx_to_doc_id = {idx: doc_id for doc_id, idx in topic_modeler.cache['doc_id_map'].items()}

In [None]:
condition = 'abdominal pain'
doc_idx = topic_modeler.cache['condition2doc'][condition]
ann_factors = set()
for idx in doc_idx:
    doc_id = idx_to_doc_id[idx]
    ann_factors.update(anns2factors[doc_id])
    
print(ann_factors)
print('\n',topics[condition])

set(topics[condition]).intersection(ann_factors)

In [None]:
print(anns2factors[100035])

In [None]:
print(coherence)

In [None]:
print(coherence)

In [None]:
print(coherence)

In [None]:
for condition in topic_modeler.cache['condition2doc']:
    if condition in topics:
        doc_idx = topic_modeler.cache['condition2doc'][condition]
        ann_factors = set()
        for idx in doc_idx:
            doc_id = idx_to_doc_id[idx]
            ann_factors.update(anns2factors[doc_id])
        topic_factors = set(topics[condition])
        print(topic_factors.intersection(ann_factors))

In [None]:
topic_modeler.cache['condition2doc']

In [None]:
topic_modeler.cache['condition2doc'].keys()

In [None]:
anns2factors[100035]

In [None]:
len(topic_modeler.cache['docs']), len(topic_modeler.cache['corpus'])

In [None]:
isinstance(topics, dict)

# Scratch

In [None]:
path = "../data/training_20180910/"

docs = [] # [<text of documents in the data>]
doc_ids = [] # {<document id from path>: <0 based index using docs>}
condition2doc = {} # {<chief complaint>: [document ids from which complaint is extracted]}

for file_path in os.listdir(path):

    if file_path.endswith(".txt"):

        with open(os.path.join(path, file_path)) as f:

            docs.append(f.read())
            doc_id = int(file_path.split('.')[0])
            doc_ids.append(doc_id)

            # Extract medical conditions
            f.seek(0) # Go to start of the document
            line = " "
            conditions = None
            while line:
                line = f.readline()
                if self.extract in line.lower(): 
                    conditions = f.readline().split(',')
                    break

            conditions = conditions if isinstance(conditions, list) else [""]
            for condition in conditions:
                condition = condition.lower().strip()
                condition = re.sub(r'[?!\'".;,]', "", condition)
                doc_set = condition2doc.setdefault(condition, [])
                doc_set.append(doc_id)

doc_id_map = dict(zip(doc_ids, range(len(doc_ids))))
for condition, ids in condition2doc.items():
    for i, doc_id in enumerate(ids):
        condition2doc[condition][i] = doc_id_map[doc_id]

In [None]:
docs

In [None]:
path = "../data/training_20180910/"

docs = []
anns = [] 
anns_ids = []
anns2doc = {} 

i = 0
for file_path in os.listdir(path):
    if file_path.endswith(".ann"):
        with open(os.path.join(path, file_path)) as f:
            anns.append(f.readlines())
    if file_path.endswith(".txt"):
        with open(os.path.join(path, file_path)) as f:
            docs.append(f.read())
    i += 1
    if i == 2:
        break

#### Explore Annotations

In [None]:
text = anns[0][:5]
print(''.join(text))

# T stands for Term and R stands for Relation. Read R1 as "ativan taken for recurrent seizures" 
# and R4 as "ativan given IM (intramuscularly)". For the T's you get the starting and ending index 
# location for the term. So "recurrent seizures" starts at position 10179 and ends at 10197

print(f'R1: {docs[0][10179:10197]}')

T = anns[0][0].split()
R = anns[0][1].split()

print(T)
print(R)

# Data Structure to store annotations
# {<annotation id from path>: {<T>: [<type>, <start>, <end>, <information>], 
#                              <R>: [<type>, <arg1>, <arg2>]}}

In [None]:
from collections import defaultdict

path = "../data/training_20180910/"

anns_data = {}
anns2factors = defaultdict(set)

i = 0
for file_path in os.listdir(path):
    if file_path.endswith(".ann"):
        with open(os.path.join(path, file_path)) as f:
            
            lines = f.readlines()            
            ann_id = int(file_path.split('.')[0])
            
            data = defaultdict(dict)
            for line in lines:
                split_line = line.split()
                if split_line[0].startswith('T'):
                    term = split_line[0]
                    word = split_line[1]
                    if word.startswith("Reason"):
                        data[term]['word'] = word
                        data[term]['start'] = int(split_line[2])
                        end = split_line[3]
                        if ";" in end: # Just extract the 1st start and end
                            end = end.split(';')[0]
                            data[term]['end'] = int(end)
                        else:
                            data[term]['end'] = int(end)
                        data[term]['info'] = ' '.join([item for item in split_line[4:] if not item.isdigit()])
                elif split_line[0].startswith('R'):
                    relation = split_line[0]
                    word = split_line[1]
                    if word.startswith("Reason"):
                        data[relation]['word'] = word
                        data[relation]['arg1'] = split_line[2].split(':')[1]
                        data[relation]['arg2'] = split_line[3].split(':')[1]
                else:
                    pass
                    
            anns_data[ann_id] = data
            
for key in anns_data:
    for x in anns_data[key]:
        if x.startswith('T'):
            anns2factors[key].add(anns_data[key][x]['info'])

In [None]:
unique = defaultdict(set)
for key in anns_data:
    for x in anns_data[key]:
        unique[x[0]].add(anns_data[key][x].get('word', ''))
        
print(unique)

# Neural Topic Modeling (Unfinished)

- Copied and modified from: https://github.com/zll17/Neural_Topic_Models

In [None]:
import os
import re
import pickle
from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

import gensim
from gensim.models.coherencemodel import CoherenceModel

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

### Helper

In [None]:
def get_topic_words(model,topn=15,n_topic=10,vocab=None,fix_topic=None,showWght=False):
    topics = []
    def show_one_tp(tp_idx):
        if showWght:
            return [(vocab.id2token[t[0]],t[1]) for t in model.get_topic_terms(tp_idx,topn=topn)]
        else:
            return [vocab.id2token[t[0]] for t in model.get_topic_terms(tp_idx,topn=topn)]
    if fix_topic is None:
        for i in range(n_topic):
            topics.append(show_one_tp(i))
    else:
        topics.append(show_one_tp(fix_topic))
    return topics

def calc_topic_diversity(topic_words):
    '''topic_words is in the form of [[w11,w12,...],[w21,w22,...]]'''
    vocab = set(sum(topic_words,[]))
    n_total = len(topic_words) * len(topic_words[0])
    topic_div = len(vocab) / n_total
    return topic_div

def calc_topic_coherence(topic_words,docs,dictionary,emb_path=None,taskname=None,sents4emb=None,calc4each=False):
    # emb_path: path of the pretrained word2vec weights, in text format.
    # sents4emb: list/generator of tokenized sentences.
    # Computing the C_V score
    cv_coherence_model = CoherenceModel(topics=topic_words,texts=docs,dictionary=dictionary,coherence='c_v')
    cv_per_topic = cv_coherence_model.get_coherence_per_topic() if calc4each else None
    cv_score = cv_coherence_model.get_coherence()
    
    # Computing the C_W2V score
    try:
        w2v_model_path = os.path.join(os.getcwd(),'data',f'{taskname}','w2v_weight_kv.txt')
        # Priority order: 1) user's embed file; 2) standard path embed file; 3) train from scratch then store.
        if emb_path!=None and os.path.exists(emb_path):
            keyed_vectors = gensim.models.KeyedVectors.load_word2vec_format(emb_path,binary=False)
        elif os.path.exists(w2v_model_path):
            keyed_vectors = gensim.models.KeyedVectors.load_word2vec_format(w2v_model_path,binary=False)
        elif sents4emb!=None:
            print('Training a word2vec model 20 epochs to evaluate topic coherence, this may take a few minutes ...')
            w2v_model = gensim.models.Word2Vec(sents4emb,size=300,min_count=1,workers=6,iter=20)
            keyed_vectors = w2v_model.wv
            keyed_vectors.save_word2vec_format(w2v_model_path,binary=False)
        else:
            raise Exception("C_w2v score isn't available for the missing of training corpus (sents4emb=None).")
            
        w2v_coherence_model = CoherenceModel(topics=topic_words,texts=docs,dictionary=dictionary,coherence='c_w2v',keyed_vectors=keyed_vectors)

        w2v_per_topic = w2v_coherence_model.get_coherence_per_topic() if calc4each else None
        w2v_score = w2v_coherence_model.get_coherence()
    except Exception as e:
        print(e)
        #In case of OOV Error
        w2v_per_topic = [None for _ in range(len(topic_words))]
        w2v_score = None
    
    # Computing the C_UCI score
    c_uci_coherence_model = CoherenceModel(topics=topic_words,texts=docs,dictionary=dictionary,coherence='c_uci')
    c_uci_per_topic = c_uci_coherence_model.get_coherence_per_topic() if calc4each else None
    c_uci_score = c_uci_coherence_model.get_coherence()
    
    
    # Computing the C_NPMI score
    c_npmi_coherence_model = CoherenceModel(topics=topic_words,texts=docs,dictionary=dictionary,coherence='c_npmi')
    c_npmi_per_topic = c_npmi_coherence_model.get_coherence_per_topic() if calc4each else None
    c_npmi_score = c_npmi_coherence_model.get_coherence()
    return (cv_score,w2v_score,c_uci_score, c_npmi_score),(cv_per_topic,w2v_per_topic,c_uci_per_topic,c_npmi_per_topic)

def mimno_topic_coherence(topic_words,docs):
    tword_set = set([w for wlst in topic_words for w in wlst])
    word2docs = {w:set([]) for w in tword_set}
    for docid,doc in enumerate(docs):
        doc = set(doc)
        for word in tword_set:
            if word in doc:
                word2docs[word].add(docid)
    def co_occur(w1,w2):
        return len(word2docs[w1].intersection(word2docs[w2]))+1
    scores = []
    for wlst in topic_words:
        s = 0
        for i in range(1,len(wlst)):
            for j in range(0,i):
                s += np.log((co_occur(wlst[i],wlst[j])+1.0)/len(word2docs[wlst[j]]))
        scores.append(s)
    return np.mean(s)

def evaluate_topic_quality(topic_words, test_data, taskname=None, calc4each=False):
    
    td_score = calc_topic_diversity(topic_words)
    print(f'topic diversity:{td_score}')
    
    (c_v, c_w2v, c_uci, c_npmi),\
        (cv_per_topic, c_w2v_per_topic, c_uci_per_topic, c_npmi_per_topic) = \
        calc_topic_coherence(topic_words=topic_words, docs=test_data.docs, dictionary=test_data.dictionary,
                             emb_path=None, taskname=taskname, sents4emb=test_data, calc4each=calc4each)
    print('c_v:{}, c_w2v:{}, c_uci:{}, c_npmi:{}'.format(
        c_v, c_w2v, c_uci, c_npmi))
    scrs = {'c_v':cv_per_topic,'c_w2v':c_w2v_per_topic,'c_uci':c_uci_per_topic,'c_npmi':c_npmi_per_topic}
    if calc4each:
        for scr_name,scr_per_topic in scrs.items():
            print(f'{scr_name}:')
            for t_idx, (score, twords) in enumerate(zip(scr_per_topic, topic_words)):
                print(f'topic.{t_idx+1:>03d}: {score} {twords}')
    
    mimno_tc = mimno_topic_coherence(topic_words, test_data.docs)
    print('mimno topic coherence:{}'.format(mimno_tc))
    if calc4each:
        return (c_v, c_w2v, c_uci, c_npmi, mimno_tc, td_score), (cv_per_topic, c_w2v_per_topic, c_uci_per_topic, c_npmi_per_topic)
    else:
        return c_v, c_w2v, c_uci, c_npmi, mimno_tc, td_score

def smooth_curve(points, factor=0.9):
    smoothed_points = []
    for pt in points:
        if smoothed_points:
            prev = smoothed_points[-1]
            smoothed_points.append(prev*factor+pt*(1-factor))
        else:
            smoothed_points.append(pt)
    return smoothed_points

### GAN

In [None]:
def block(in_feat, out_feat, normalize=True):
    layers = [nn.Linear(in_feat, out_feat,bias=False)]
    if normalize:
        layers.append(nn.BatchNorm1d(out_feat))
    layers.append(nn.LeakyReLU(0.1, inplace=True))
    return layers


class Generator(nn.Module):
    def __init__(self, bow_dim, hid_dim, n_topic):
        super(Generator,self).__init__()
        self.g = nn.Sequential(*block(n_topic, hid_dim), 
                               nn.Linear(hid_dim,bow_dim), 
                               nn.Softmax(dim=1))

    def inference(self, theta):
        return self.g(theta)
    
    def forward(self, theta):
        bow_f = self.g(theta)
        doc_f = torch.cat([theta,bow_f], dim=1)
        return doc_f
    
    
class Encoder(nn.Module):
    def __init__(self, bow_dim, hid_dim, n_topic):
        super(Encoder,self).__init__()
        self.e = nn.Sequential(*block(bow_dim, hid_dim), 
                               nn.Linear(hid_dim, n_topic, bias=True), 
                               nn.Softmax(dim=1))

    def forward(self, bow):
        theta = self.e(bow)
        doc_r = torch.cat([theta, bow], dim=1)
        return doc_r
    
    
class Discriminator(nn.Module):
    def __init__(self,bow_dim, hid_dim, n_topic):
        super(Discriminator,self).__init__()
        self.d = nn.Sequential(*block(n_topic+bow_dim,hid_dim), nn.Linear(hid_dim,1,bias=True))

    def forward(self,reps):
        score = self.d(reps)
        return score

### Model

In [None]:
class BATM:
    def __init__(self, bow_dim=2000, n_topic=50, hid_dim=1024, device=None, taskname=None):
        self.n_topic = n_topic 
        self.bow_dim = bow_dim
        self.device = device
        self.id2token = None
        self.taskname = taskname

        self.generator = Generator(n_topic=n_topic, hid_dim=hid_dim, bow_dim=bow_dim)
        self.encoder = Encoder(bow_dim=bow_dim, hid_dim=hid_dim, n_topic=n_topic)
        self.discriminator = Discriminator(bow_dim=bow_dim, n_topic=n_topic, hid_dim=hid_dim)

        if device != None:
            self.generator = self.generator.to(device)
            self.encoder = self.encoder.to(device)
            self.discriminator = self.discriminator.to(device)

    def train(self,train_data,batch_size=256, learning_rate=1e-4, test_data=None, num_epochs=100, 
              is_evaluate=False, log_every=10, beta1=0.5, beta2=0.999, clip=0.01, n_critic=5):
        
        self.generator.train()
        self.encoder.train()
        self.discriminator.train()
        self.id2token = {v: k for k, v in train_data.dictionary.token2id.items()}
        data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=train_data.collate_fn)

        optim_G = torch.optim.Adam(self.generator.parameters(), lr=learning_rate, betas=(beta1, beta2))
        optim_E = torch.optim.Adam(self.encoder.parameters(), lr=learning_rate, betas=(beta1,beta2))
        optim_D = torch.optim.Adam(self.discriminator.parameters(), lr=learning_rate, betas=(beta1,beta2))
        Gloss_lst, Eloss_lst, Dloss_lst = [], [], []
        c_v_lst, c_w2v_lst, c_uci_lst, c_npmi_lst, mimno_tc_lst, td_lst = [], [], [], [], [], []
        
        for epoch in range(num_epochs):
            epochloss_lst = []
            for iter, data in enumerate(data_loader):
                txts, bows_real = data
                bows_real = bows_real.to(self.device)
                bows_real /= torch.sum(bows_real, dim=1, keepdim=True)

                # Train Discriminator
                optim_D.zero_grad()
                
                theta_fake = torch.from_numpy(np.random.dirichlet(alpha=1.0*np.ones(self.n_topic)/self.n_topic, size=(len(bows_real)))).float().to(self.device)
                loss_D = -1.0*torch.mean(self.discriminator(self.encoder(bows_real).detach())) + torch.mean(self.discriminator(self.generator(theta_fake).detach()))

                loss_D.backward()
                optim_D.step()

                for param in self.discriminator.parameters():
                    param.data.clamp_(-clip, clip)
                
                if iter % n_critic==0:
                    # Train Generator
                    optim_G.zero_grad()
                    
                    loss_G = -1.0*torch.mean(self.discriminator(self.generator(theta_fake)))
                    
                    loss_G.backward()
                    optim_G.step()

                    # Train Encoder
                    optim_E.zero_grad()

                    loss_E = torch.mean(self.discriminator(self.encoder(bows_real)))

                    loss_E.backward()
                    optim_E.step()

                    Dloss_lst.append(loss_D.item())
                    Gloss_lst.append(loss_G.item())
                    Eloss_lst.append(loss_E.item())
                    print(f'Epoch {(epoch+1):>3d}\tIter {(iter+1):>4d}\tLoss_D:{loss_D.item():<.7f}\tLoss_G:{loss_G.item():<.7f}\tloss_E:{loss_E.item():<.7f}')
            
            if (epoch+1) % log_every == 0:
                print(f'Epoch {(epoch+1):>3d}\tLoss_D_avg:{sum(Dloss_lst)/len(Dloss_lst):<.7f}\tLoss_G_avg:{sum(Gloss_lst)/len(Gloss_lst):<.7f}\tloss_E_avg:{sum(Eloss_lst)/len(Eloss_lst):<.7f}')
                print('\n'.join([str(lst) for lst in self.show_topic_words()]))
                print('='*30)
                smth_pts_d = smooth_curve(Dloss_lst)
                smth_pts_g = smooth_curve(Gloss_lst)
                smth_pts_e = smooth_curve(Eloss_lst)
                plt.cla()
                plt.plot(np.array(range(len(smth_pts_g)))*log_every, smth_pts_g, label='loss_G')
                plt.plot(np.array(range(len(smth_pts_d)))*log_every, smth_pts_d, label='loss_D')
                plt.plot(np.array(range(len(smth_pts_e)))*log_every, smth_pts_e, label='loss_E')
                plt.legend()
                plt.xlabel('epochs')
                plt.title('Train Loss')
                plt.savefig('batm_trainloss.png')
                if test_data!=None:
                    c_v, c_w2v, c_uci, c_npmi, mimno_tc, td = self.evaluate(test_data, calc4each=False)
                    c_v_lst.append(c_v), c_w2v_lst.append(c_w2v), c_uci_lst.append(c_uci), c_npmi_lst.append(c_npmi), mimno_tc_lst.append(mimno_tc), td_lst.append(td)
        
    def evaluate(self, test_data, calc4each=False):
        topic_words = self.show_topic_words()
        return evaluate_topic_quality(topic_words, test_data, taskname=self.taskname, calc4each=calc4each)

    def show_topic_words(self, topic_id=None, topK=15):
        with torch.no_grad():
            topic_words = []
            idxes = torch.eye(self.n_topic).to(self.device)
            word_dist = self.generator.inference(idxes)
            vals, indices = torch.topk(word_dist, topK, dim=1)
            vals = vals.cpu().tolist()
            indices = indices.cpu().tolist()
            if topic_id == None:
                for i in range(self.n_topic):
                    topic_words.append([self.id2token[idx] for idx in indices[i]])
            else:
                topic_words.append([self.id2token[idx] for idx in indices[topic_id]])
            return topic_words

### Data `TODO`

### Trainer `TODO`

### Train Model `TODO`

### Inference `TODO`

### Evaluate `TODO`