In [1]:
import torch
from tqdm import tqdm
from sklearn.datasets import fetch_20newsgroups
from experiments.preprocess import TextPrep
from experiments.experiment import snn_experiment, lsa_experiment
from experiments.params import random_seed

In [2]:
device = "cuda" if torch.cuda.is_available else "cpu"

In [3]:
preprocessor = TextPrep(svd_components=1000, prob_iterations=100)
cats = ['comp.graphics','sci.med'] #, 'talk.politics.guns', 'rec.motorcycles', 'soc.religion.christian']
newsgroups_train = fetch_20newsgroups(subset='all', categories=cats, remove=('headers', 'footers', 'quotes'))

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\aleks\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [4]:
# Only LSA
lsa_data_x, lsa_data_y = preprocessor.preprocess_dataset(newsgroups_train, lsa=True, spikes=False)
lsa_data_x.to(device)
lsa_data_y.to(device)
lsa_data_x.size()



torch.Size([1963, 1000])

In [5]:
lsa_experiment(lsa_data_x, lsa_data_y, clf_type="regression", seed=random_seed)

5it [00:00, 12.16it/s]

0.870096790626592
[[769 204]
 [ 51 939]]





In [6]:
lsa_experiment(lsa_data_x, lsa_data_y, clf_type="random_forest", seed=random_seed)

5it [00:26,  5.25s/it]

0.8884360672440142
[[857 116]
 [103 887]]





In [7]:
lsa_experiment(lsa_data_x, lsa_data_y, clf_type="xgboost", seed=random_seed)

5it [00:22,  4.40s/it]

0.9144167091186959
[[873 100]
 [ 68 922]]





In [8]:
lsa_data_x = lsa_data_x.relu()
lsa_experiment(lsa_data_x, lsa_data_y, clf_type="naive_bayes", seed=random_seed)

5it [00:00, 66.66it/s]


0.7753438614365766
[[711 262]
 [179 811]]


In [9]:
lsa_experiment(lsa_data_x, lsa_data_y, clf_type="svc", seed=random_seed)

5it [00:05,  1.12s/it]

0.8680590932246561
[[753 220]
 [ 39 951]]





In [10]:
#LSA and spikes
spike_data_x, spike_data_y = preprocessor.preprocess_dataset(newsgroups_train, lsa=True, spikes=True)
spike_data_x.to(device)
spike_data_y.to(device)
spike_data_x.size()

torch.Size([1963, 100, 1000])

In [11]:
snn_experiment(spike_data_x, spike_data_y, clf_type="regression", seed=random_seed)

5it [17:31, 210.23s/it]

0.739174732552216
[[646 327]
 [185 805]]





In [12]:
snn_experiment(spike_data_x, spike_data_y, clf_type="random_forest", seed=random_seed)

5it [18:02, 216.48s/it]

0.7274579724910851
[[691 282]
 [253 737]]





In [13]:
snn_experiment(spike_data_x, spike_data_y, clf_type="xgboost", seed=random_seed)

5it [18:18, 219.69s/it]

0.7305145185939887
[[673 300]
 [229 761]]





In [14]:
snn_experiment(spike_data_x, spike_data_y, clf_type="naive_bayes", seed=random_seed)

5it [17:41, 212.29s/it]

0.5766683647478349
[[391 582]
 [249 741]]





In [15]:
snn_experiment(spike_data_x, spike_data_y, clf_type="svc", seed=random_seed)

5it [17:45, 213.18s/it]

0.7137035150280183
[[499 474]
 [ 88 902]]





## Three categories

In [16]:
preprocessor = TextPrep(svd_components=5000, prob_iterations=100)
cats = ['comp.graphics', 'sci.med', 'talk.politics.guns'] #, 'rec.motorcycles', 'soc.religion.christian']
newsgroups_train = fetch_20newsgroups(subset='all', categories=cats, remove=('headers', 'footers', 'quotes'))

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\aleks\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [17]:
# Only LSA
lsa_data_x, lsa_data_y = preprocessor.preprocess_dataset(newsgroups_train, lsa=True, spikes=False)
lsa_data_x.to(device)
lsa_data_y.to(device)
lsa_data_x.size()



torch.Size([2873, 2873])

In [18]:
lsa_experiment(lsa_data_x, lsa_data_y, clf_type="regression", seed=random_seed)

5it [00:04,  1.03it/s]

0.8381482770623042
[[815 120  38]
 [ 61 864  65]
 [ 16 165 729]]





In [19]:
lsa_experiment(lsa_data_x, lsa_data_y, clf_type="random_forest", seed=random_seed)

5it [01:13, 14.67s/it]

0.7796728158719108
[[830 105  38]
 [141 706 143]
 [ 28 178 704]]





In [20]:
lsa_experiment(lsa_data_x, lsa_data_y, clf_type="xgboost", seed=random_seed)

5it [04:18, 51.78s/it]

0.8618169161155587
[[865  75  33]
 [ 81 848  61]
 [ 24 123 763]]





In [21]:
lsa_experiment(lsa_data_x, lsa_data_y, clf_type="naive_bayes", seed=random_seed)

5it [00:00, 15.67it/s]


0.6971806474068918
[[729 238   6]
 [135 832  23]
 [ 87 381 442]]


In [22]:
lsa_experiment(lsa_data_x, lsa_data_y, clf_type="svc", seed=random_seed)

5it [00:37,  7.55s/it]

0.8391924817264184
[[798 142  33]
 [ 40 914  36]
 [ 10 201 699]]





In [23]:
#LSA and spikes
spike_data_x, spike_data_y = preprocessor.preprocess_dataset(newsgroups_train, lsa=True, spikes=True)
spike_data_x.to(device)
spike_data_y.to(device)
spike_data_x.size()

torch.Size([2873, 100, 2873])

In [24]:
snn_experiment(spike_data_x, spike_data_y, clf_type="regression", seed=random_seed)

5it [24:41, 296.23s/it]


0.5892794987817612
[[680 277  16]
 [169 772  49]
 [125 544 241]]


In [25]:
snn_experiment(spike_data_x, spike_data_y, clf_type="random_forest", seed=random_seed)

5it [25:47, 309.59s/it]


0.5708318830490776
[[617 239 117]
 [183 590 217]
 [138 339 433]]


In [26]:
snn_experiment(spike_data_x, spike_data_y, clf_type="xgboost", seed=random_seed)

5it [25:53, 310.67s/it]


0.6160807518273581
[[633 224 116]
 [147 601 242]
 [ 85 289 536]]


In [27]:
snn_experiment(spike_data_x, spike_data_y, clf_type="naive_bayes", seed=random_seed)

5it [25:06, 301.30s/it]


0.38809606682909853
[[283 690   0]
 [158 832   0]
 [155 755   0]]


In [28]:
snn_experiment(spike_data_x, spike_data_y, clf_type="svc", seed=random_seed)

5it [25:12, 302.60s/it]


0.3571179951270449
[[227 746   0]
 [191 799   0]
 [167 743   0]]
