In [1]:
import torch
from tqdm import tqdm
from sklearn.datasets import fetch_20newsgroups
from experiments.preprocess import TextPrep
from experiments.experiment import snn_experiment, lsa_experiment, snn_multiple_clfs
from experiments.params import random_seed

In [2]:
# Select the compute device: GPU when PyTorch reports a usable CUDA device, CPU otherwise.
# `torch.cuda.is_available` has to be *called* - the bare function object is always truthy,
# which previously forced device == "cuda" even on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Experiment parameters:
# TF-IDF feature count, spike count, and the cubic (n, n, n) shapes to evaluate.
param_tfidf_features = 12000
param_spikes = 100
param_cube_shapes = [(edge, edge, edge) for edge in (6, 8, 10, 12)]

## Three categories

In [3]:
# Build the text preprocessor. max_features reuses the notebook-level TF-IDF setting
# (param_tfidf_features, currently 12000) instead of repeating the literal.
# NOTE(review): prob_iterations=100 presumably corresponds to param_spikes - confirm before linking them.
preprocessor = TextPrep(svd_components=1000, prob_iterations=100, max_features=param_tfidf_features)
cats = ['comp.graphics', 'sci.med', 'talk.politics.guns']
# NOTE(review): the variable is named `newsgroups_train` but the 'test' subset is loaded.
# The recorded LSA output has 1724 rows while the recorded spike output has 1149 rows, so the
# subset appears to have been changed between runs - confirm which split is intended and
# re-run both sections on the same data before comparing accuracies.
newsgroups_train = fetch_20newsgroups(subset='test', categories=cats, remove=('headers', 'footers', 'quotes'))

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\aleks\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [5]:
# Only LSA
lsa_data_x, lsa_data_y = preprocessor.preprocess_dataset(newsgroups_train, lsa=True, spikes=False)
# The former bare `lsa_data_x.to(device)` / `lsa_data_y.to(device)` statements were removed:
# `Tensor.to` returns a new tensor and leaves the receiver untouched, so their results were
# discarded and these names never changed device (the calls only cost a transfer, and raise
# on a machine without CUDA when device is "cuda").
# NOTE(review): if lsa_experiment is meant to receive tensors on `device`, rebind instead:
#   lsa_data_x, lsa_data_y = lsa_data_x.to(device), lsa_data_y.to(device)
lsa_data_x.size()

torch.Size([1724, 1000])

In [6]:
# Run the LSA experiment once per classifier type, all with the same seed and 8 splits.
for lsa_clf_type in ("regression", "random_forest", "xgboost", "svc"):
    lsa_experiment(lsa_data_x, lsa_data_y, clf_type=lsa_clf_type, seed=random_seed, splits=8)

# naive_bayes is given relu(lsa_data_x), i.e. the features with negative entries clamped to zero.
lsa_experiment(lsa_data_x.relu(), lsa_data_y, clf_type="naive_bayes", seed=random_seed, splits=8)


8it [00:02,  2.78it/s]


---- CLASSIFIER: regression ----
acc: 0.9222737819025522
[[568  14   2]
 [ 49 529  16]
 [ 36  17 493]]


8it [01:14,  9.34s/it]


---- CLASSIFIER: random_forest ----
acc: 0.8004640371229699
[[485  75  24]
 [ 66 490  38]
 [ 51  90 405]]


8it [01:42, 12.80s/it]


---- CLASSIFIER: xgboost ----
acc: 0.8747099767981439
[[522  51  11]
 [ 52 511  31]
 [ 31  40 475]]


8it [00:07,  1.01it/s]


---- CLASSIFIER: svc ----
acc: 0.9240139211136891
[[569  10   5]
 [ 51 529  14]
 [ 30  21 495]]


8it [00:00, 77.67it/s]


---- CLASSIFIER: naive_bayes ----
acc: 0.7337587006960556
[[478  89  17]
 [104 471  19]
 [ 92 138 316]]


# SPIKING

In [4]:
# TF-IDF and spikes
spike_data_x, spike_data_y = preprocessor.preprocess_dataset(newsgroups_train, lsa=False, spikes=True)
# The former bare `spike_data_x.to(device)` / `spike_data_y.to(device)` statements were removed:
# `Tensor.to` returns a new tensor and leaves the receiver untouched, so their results were
# discarded and these names never changed device (the calls only cost a transfer of the full
# spike tensor, and raise on a machine without CUDA when device is "cuda").
# NOTE(review): if snn_experiment is meant to receive tensors on `device`, rebind instead:
#   spike_data_x, spike_data_y = spike_data_x.to(device), spike_data_y.to(device)
spike_data_x.size()

torch.Size([1149, 100, 12000])

### REGRESSION

In [6]:
# SNN experiment: "regression" classifier, 8 splits, shape (4, 4, 4).
# NOTE(review): unlike the lsa_experiment calls, no seed is passed here - confirm how
# snn_experiment is seeded.
snn_experiment(
    spike_data_x,
    spike_data_y,
    clf_type="regression",
    splits=8,
    shape=(4, 4, 4),
)

8it [17:26, 130.80s/it]


---- CLASSIFIER: regression ----
acc: 0.5787641427328112
[[245  90  54]
 [ 89 225  82]
 [ 67 102 195]]


In [7]:
# SNN experiment: "regression" classifier, 8 splits, shape (6, 6, 6).
snn_experiment(
    spike_data_x,
    spike_data_y,
    clf_type="regression",
    splits=8,
    shape=(6, 6, 6),
)

8it [16:46, 125.86s/it]


---- CLASSIFIER: regression ----
acc: 0.6762402088772846
[[278  76  35]
 [ 85 258  53]
 [ 43  80 241]]


In [8]:
# SNN experiment: "regression" classifier, 8 splits, shape (8, 8, 8).
snn_experiment(
    spike_data_x,
    spike_data_y,
    clf_type="regression",
    splits=8,
    shape=(8, 8, 8),
)

8it [19:23, 145.47s/it]


---- CLASSIFIER: regression ----
acc: 0.7484769364664926
[[304  56  29]
 [ 56 296  44]
 [ 30  74 260]]


In [9]:
# SNN experiment: "regression" classifier, 8 splits, shape (10, 10, 10).
snn_experiment(
    spike_data_x,
    spike_data_y,
    clf_type="regression",
    splits=8,
    shape=(10, 10, 10),
)

8it [27:20, 205.07s/it]


---- CLASSIFIER: regression ----
acc: 0.8181026979982594
[[330  47  12]
 [ 45 322  29]
 [ 27  49 288]]


In [10]:
# SNN experiment: "regression" classifier, 8 splits, shape (12, 12, 12).
snn_experiment(
    spike_data_x,
    spike_data_y,
    clf_type="regression",
    splits=8,
    shape=(12, 12, 12),
)

8it [48:51, 366.40s/it]


---- CLASSIFIER: regression ----
acc: 0.834638816362054
[[333  46  10]
 [ 37 324  35]
 [ 19  43 302]]


In [16]:
# SNN experiment: "regression" classifier, 8 splits, shape (15, 15, 15).
# This shape is not listed in param_cube_shapes.
snn_experiment(
    spike_data_x,
    spike_data_y,
    clf_type="regression",
    splits=8,
    shape=(15, 15, 15),
)

8it [2:02:16, 917.05s/it]


---- CLASSIFIER: regression ----
acc: 0.8720626631853786
[[348  26  15]
 [ 27 344  25]
 [ 13  41 310]]


In [None]:
# SNN experiment: "regression" classifier, 8 splits, shape (18, 18, 18).
# This cell has no execution count or recorded output in the saved notebook.
snn_experiment(
    spike_data_x,
    spike_data_y,
    clf_type="regression",
    splits=8,
    shape=(18, 18, 18),
)

### XGBOOST

In [11]:
# SNN experiment: "xgboost" classifier, 8 splits, shape (4, 4, 4).
# NOTE(review): unlike the lsa_experiment calls, no seed is passed here - confirm how
# snn_experiment is seeded.
snn_experiment(
    spike_data_x,
    spike_data_y,
    clf_type="xgboost",
    splits=8,
    shape=(4, 4, 4),
)

8it [18:09, 136.18s/it]


---- CLASSIFIER: xgboost ----
acc: 0.5387293298520452
[[221 109  59]
 [ 92 216  88]
 [ 69 113 182]]


In [12]:
# SNN experiment: "xgboost" classifier, 8 splits, shape (6, 6, 6).
snn_experiment(
    spike_data_x,
    spike_data_y,
    clf_type="xgboost",
    splits=8,
    shape=(6, 6, 6),
)

8it [17:30, 131.37s/it]


---- CLASSIFIER: xgboost ----
acc: 0.6562228024369017
[[270  72  47]
 [ 83 254  59]
 [ 50  84 230]]


In [13]:
# SNN experiment: "xgboost" classifier, 8 splits, shape (8, 8, 8).
snn_experiment(
    spike_data_x,
    spike_data_y,
    clf_type="xgboost",
    splits=8,
    shape=(8, 8, 8),
)

8it [18:53, 141.71s/it]


---- CLASSIFIER: xgboost ----
acc: 0.6884247171453438
[[279  75  35]
 [ 66 276  54]
 [ 50  78 236]]


In [14]:
# SNN experiment: "xgboost" classifier, 8 splits, shape (10, 10, 10).
snn_experiment(
    spike_data_x,
    spike_data_y,
    clf_type="xgboost",
    splits=8,
    shape=(10, 10, 10),
)

8it [29:45, 223.20s/it]


---- CLASSIFIER: xgboost ----
acc: 0.7119234116623151
[[290  69  30]
 [ 67 275  54]
 [ 48  63 253]]


In [15]:
# SNN experiment: "xgboost" classifier, 8 splits, shape (12, 12, 12).
snn_experiment(
    spike_data_x,
    spike_data_y,
    clf_type="xgboost",
    splits=8,
    shape=(12, 12, 12),
)

8it [52:28, 393.55s/it]


---- CLASSIFIER: xgboost ----
acc: 0.7458659704090513
[[312  48  29]
 [ 58 296  42]
 [ 44  71 249]]
