# Setup

In [None]:
!pip install gitpython

In [None]:
import os

import git

from git import Repo

git_url = 'https://github.com/AdhyaSuman/S2WTM'
repo_dir = 'S2WTM_local'

# Clone only if there is no checkout yet: cloning into an existing, non-empty
# directory fails, so re-running this cell would otherwise raise an error.
if not os.path.isdir(os.path.join(repo_dir, '.git')):
    Repo.clone_from(git_url, repo_dir)

In [None]:
# Go to the home directory of the repo
cd S2WTM_local

In [None]:
!pip install -e.

# Imports

In [None]:
cd ..

In [None]:
from octis.dataset.dataset import Dataset

#Import models:
from octis.models.S2WTM import S2WTM

#Import coherence metrics:
from octis.evaluation_metrics.coherence_metrics import *

#Import TD metrics:
from octis.evaluation_metrics.diversity_metrics import *

#Import classification metrics:
from octis.evaluation_metrics.classification_metrics import *

import random, torch

# Utils

In [None]:
data_dir = './preprocessed_datasets'

def get_dataset(dataset_name):
    data = Dataset()
    if dataset_name=='20NG':
        data.fetch_dataset("20NewsGroup")
    
    elif dataset_name=='BBC':
        data.fetch_dataset("BBC_News")
    
    elif dataset_name=='M10':
        data.fetch_dataset("M10")
    
    elif dataset_name=='SearchSnippets':
        data.load_custom_dataset_from_folder(data_dir + "/SearchSnippets")
    
    elif dataset_name=='Pascal_Flickr':
        data.load_custom_dataset_from_folder(data_dir + "/Pascal_Flickr")
    
    elif dataset_name=='Bio':
        data.load_custom_dataset_from_folder(data_dir + "/Bio")
        
    elif dataset_name=='DBLP':
        data.fetch_dataset("DBLP")
    
    else:
        raise Exception('Missing Dataset name...!!!')
    return data

# Run

In [None]:
import os
from random import randint
from IPython.display import clear_output

# One random seed per repetition. randint's bounds must be ints: the float
# 2e3 is deprecated since Python 3.10 and raises TypeError since 3.12.
seeds = [randint(0, 2000) for _ in range(1)]

datasets = ['20NG', 'BBC', 'M10', 'SearchSnippets', 'Pascal_Flickr', 'Bio', 'DBLP']

# Column-oriented results table; each run appends one entry to every column.
results = {
    'Dataset': [],
    'K': [],
    'Seed': [],
    'NPMI': [],
    'CV': []
}

# Partitioning and validation are disabled for every run.
partition = False
validation = False

for seed in seeds:
    for d in datasets:
        data = get_dataset(d)
        # Number of topics = number of distinct document labels.
        k = len(set(data.get_labels()))

        print('Results:-\n', results)
        print("-"*100)
        print('Dataset:{},\t K={},\t Seed={}'.format(d, k, seed))
        print("-"*100)

        # Seed Python's and torch's RNGs so each run is tied to its seed.
        random.seed(seed)
        torch.random.manual_seed(seed)

        model = S2WTM(
            num_topics=k,
            use_partitions=partition,
            use_validation=validation,
            num_epochs=100,
        )

        output = model.train_model(dataset=data)

        # Release the model and cached GPU memory before the next run.
        del model
        torch.cuda.empty_cache()

        # NPMI and C_V coherence (topk=10), measured on the dataset's corpus.
        corpus = data.get_corpus()
        npmi_score = Coherence(texts=corpus, topk=10, measure='c_npmi').score(output)
        cv_score = Coherence(texts=corpus, topk=10, measure='c_v').score(output)

        # Append the whole row only after every score is computed, so the
        # columns stay equal-length even if a metric raises.
        results['Dataset'].append(d)
        results['K'].append(k)
        results['Seed'].append(seed)
        results['NPMI'].append(npmi_score)
        results['CV'].append(cv_score)

        clear_output(wait=False)
results