In [1]:
import torch
from tqdm import tqdm
from sklearn.datasets import fetch_20newsgroups
from experiments.preprocess import TextPrep
from experiments.experiment import snn_experiment, lsa_experiment, binary_tfidf_experiment
from experiments.params import random_seed

In [4]:
# Pick the compute device once for the whole notebook.
# `is_available` must be *called*: the bare attribute is a function object,
# which is always truthy and would force "cuda" even on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# Two-category run: comp.graphics vs sci.med.
# NOTE(review): this loads subset='test' even though the result is bound to
# `newsgroups_train`; the three-category section uses subset='train'.
# Confirm the smaller test split is intentional here.
preprocessor = TextPrep(
    svd_components=1000,
    prob_iterations=200,
    max_features=None,
)
cats = [
    'comp.graphics',
    'sci.med',
    # further candidates, currently disabled:
    # 'talk.politics.guns', 'rec.motorcycles', 'soc.religion.christian',
]
newsgroups_train = fetch_20newsgroups(
    subset='test',
    categories=cats,
    remove=('headers', 'footers', 'quotes'),
)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\aleks\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [4]:
# Preprocess with both LSA and spike encoding switched off.
data_x, data_y = preprocessor.preprocess_dataset(newsgroups_train, lsa=False, spikes=False)
# Tensor.to() is not in-place: it returns a new tensor, so the result must be
# re-bound. A bare `data_x.to(device)` copies the data and throws the copy away.
# NOTE(review): assumes the experiment helpers accept tensors on `device`;
# set device = "cpu" if they convert to NumPy for scikit-learn internally.
data_x = data_x.to(device)
data_y = data_y.to(device)
data_x.size()



torch.Size([785, 13722])

In [5]:
# Binary TF-IDF experiment on the two-category data, "regression" classifier.
binary_tfidf_experiment(
    data_x,
    data_y,
    clf_type="regression",
    seed=random_seed,
)

5it [9:22:40, 6752.10s/it]


0.467515923566879
[[214 175]
 [243 153]]


## Three categories (comp.graphics, sci.med, talk.politics.guns)

In [2]:
# Three-category run on the training split.
# NOTE(review): svd_components=10000 exceeds the number of documents loaded
# here (the LSA matrix comes out 1724 x 1724); confirm how TextPrep caps it.
preprocessor = TextPrep(
    svd_components=10000,
    prob_iterations=100,
    max_features=12000,
)
cats = [
    'comp.graphics',
    'sci.med',
    'talk.politics.guns',
    # further candidates, currently disabled:
    # 'rec.motorcycles', 'soc.religion.christian',
]
newsgroups_train = fetch_20newsgroups(
    subset='train',
    categories=cats,
    remove=('headers', 'footers', 'quotes'),
)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\aleks\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [5]:
# Only LSA (spike encoding off)
lsa_data_x, lsa_data_y = preprocessor.preprocess_dataset(newsgroups_train, lsa=True, spikes=False)
# Tensor.to() returns a new tensor instead of moving in place; re-bind the
# result so the data actually lives on `device`.
# NOTE(review): assumes lsa_experiment accepts tensors on `device`;
# set device = "cpu" if it converts to NumPy for scikit-learn internally.
lsa_data_x = lsa_data_x.to(device)
lsa_data_y = lsa_data_y.to(device)
lsa_data_x.size()

torch.Size([1724, 1724])

In [6]:
# LSA features, "regression" classifier, 8 splits.
lsa_experiment(
    lsa_data_x,
    lsa_data_y,
    clf_type="regression",
    seed=random_seed,
    splits=8,
)

8it [00:04,  1.75it/s]

0.9222737819025522
[[568  14   2]
 [ 48 530  16]
 [ 39  15 492]]





In [7]:
# LSA features, random-forest classifier, 8 splits.
lsa_experiment(
    lsa_data_x,
    lsa_data_y,
    clf_type="random_forest",
    seed=random_seed,
    splits=8,
)

8it [01:35, 11.88s/it]

0.7621809744779582
[[485  74  25]
 [ 89 466  39]
 [ 57 126 363]]





In [8]:
# LSA features, "xgboost" classifier, 8 splits.
lsa_experiment(
    lsa_data_x,
    lsa_data_y,
    clf_type="xgboost",
    seed=random_seed,
    splits=8,
)

0it [00:00, ?it/s]

8it [03:00, 22.57s/it]

0.8636890951276102
[[516  50  18]
 [ 44 506  44]
 [ 30  49 467]]





In [11]:
# Naive Bayes gets ReLU-clipped LSA features (negative components zeroed).
# Presumably the naive-Bayes variant used needs non-negative inputs;
# confirm in experiments.experiment.
lsa_experiment(
    lsa_data_x.relu(),
    lsa_data_y,
    clf_type="naive_bayes",
    seed=random_seed,
    splits=8,
)

0it [00:00, ?it/s]

8it [00:00, 41.23it/s]

0.679814385150812
[[493  87   4]
 [134 454   6]
 [154 167 225]]





In [10]:
# LSA features, "svc" classifier, 8 splits.
lsa_experiment(
    lsa_data_x,
    lsa_data_y,
    clf_type="svc",
    seed=random_seed,
    splits=8,
)

8it [00:14,  1.85s/it]

0.9292343387470998
[[572   9   3]
 [ 45 537  12]
 [ 33  20 493]]





In [5]:
# TF-IDF with spike encoding (LSA off)
spike_data_x, spike_data_y = preprocessor.preprocess_dataset(newsgroups_train, lsa=False, spikes=True)
# Tensor.to() returns a new tensor instead of moving in place; re-bind the
# result so the data actually lives on `device`.
# NOTE(review): this tensor is large (1724 x 100 x 12000 in the recorded run),
# so check device memory. This also assumes snn_experiment accepts tensors on
# `device`; set device = "cpu" otherwise.
spike_data_x = spike_data_x.to(device)
spike_data_y = spike_data_y.to(device)
spike_data_x.size()

torch.Size([1724, 100, 12000])

In [6]:
# Spike-encoded TF-IDF, "regression" classifier, 8 splits.
snn_experiment(
    spike_data_x,
    spike_data_y,
    clf_type="regression",
    seed=random_seed,
    splits=8,
)

8it [47:44, 358.10s/it]


0.8109048723897911
[[482  59  43]
 [ 74 477  43]
 [ 47  60 439]]


In [7]:
# Spike-encoded TF-IDF, random-forest classifier, 8 splits.
snn_experiment(
    spike_data_x,
    spike_data_y,
    clf_type="random_forest",
    seed=random_seed,
    splits=8,
)

8it [47:23, 355.49s/it]


0.6368909512761021
[[456  90  38]
 [170 359  65]
 [155 108 283]]


In [8]:
# Spike-encoded TF-IDF, "xgboost" classifier, 8 splits.
snn_experiment(
    spike_data_x,
    spike_data_y,
    clf_type="xgboost",
    seed=random_seed,
    splits=8,
)

0it [00:00, ?it/s]

8it [48:39, 364.98s/it]


0.7459396751740139
[[479  67  38]
 [106 425  63]
 [ 79  85 382]]


In [9]:
# Spike-encoded TF-IDF, naive-Bayes classifier, 8 splits.
snn_experiment(
    spike_data_x,
    spike_data_y,
    clf_type="naive_bayes",
    seed=random_seed,
    splits=8,
)

8it [46:51, 351.43s/it]


0.7134570765661253
[[442 132  10]
 [ 73 505  16]
 [ 44 219 283]]


In [10]:
# Spike-encoded TF-IDF, "svc" classifier, 8 splits.
snn_experiment(
    spike_data_x,
    spike_data_y,
    clf_type="svc",
    seed=random_seed,
    splits=8,
)

8it [48:05, 360.66s/it]


0.8167053364269141
[[480  68  36]
 [ 58 484  52]
 [ 50  52 444]]
