In [1]:
import gensim.downloader as api
from sklearn.datasets import fetch_20newsgroups
import torch

from experiments.experiment import snn_experiment, lsa_experiment
from experiments.preprocess import Word2VecPrep

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
word2vec = api.load('word2vec-google-news-300')
text_prep = Word2VecPrep(word2vec)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\aleks\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [4]:
cats = ['comp.graphics', 'sci.med', 'talk.politics.guns']
newsgroups_train = fetch_20newsgroups(subset='train', categories=cats)

### Averaged Word2Vec

In [5]:
train_x, train_y = text_prep.preprocess_dataset(newsgroups_train, avg=True, max_size=150)
train_x.to(device)
train_y.to(device)
print(device)

cuda


In [6]:
lsa_experiment(train_x, train_y, clf_type="regression", splits=4)

4it [00:00,  8.60it/s]


---- CLASSIFIER: regression ----
acc: 0.9071925754060325
[[572   2  10]
 [ 73 511  10]
 [ 57   8 481]]





In [7]:
lsa_experiment(train_x, train_y, clf_type="random_forest", splits=4)

4it [00:14,  3.60s/it]


---- CLASSIFIER: random_forest ----
acc: 0.9303944315545244
[[565  11   8]
 [ 46 542   6]
 [ 34  15 497]]





In [8]:
lsa_experiment(train_x, train_y, clf_type="xgboost", splits=4)

4it [00:12,  3.21s/it]


---- CLASSIFIER: xgboost ----
acc: 0.941415313225058
[[569   7   8]
 [ 31 550  13]
 [ 31  11 504]]





### Spiking Word2Vec

In [9]:
spike_train_x, spike_train_y = text_prep.preprocess_dataset(newsgroups_train, avg=False, max_size=150)
mask = spike_train_x >= 0.25
data_x = torch.where(mask, torch.tensor(1), torch.tensor(0))

In [10]:
data_x.to(device)
spike_train_y.to(device)
print(device)

cuda


#### Regression

In [11]:
snn_experiment(data_x, spike_train_y, clf_type="regression", splits=4, shape=(5, 5, 5), res_train=True)

4it [17:42, 265.71s/it]

---- CLASSIFIER: regression ----
acc: 0.8503480278422274
[[535  25  24]
 [ 68 491  35]
 [ 56  50 440]]





In [12]:
snn_experiment(data_x, spike_train_y, clf_type="regression", splits=4, shape=(6, 6, 6), res_train=True)

4it [17:32, 263.12s/it]

---- CLASSIFIER: regression ----
acc: 0.8509280742459396
[[524  43  17]
 [ 66 501  27]
 [ 42  62 442]]





In [13]:
snn_experiment(data_x, spike_train_y, clf_type="regression", splits=4, shape=(5, 6, 7), res_train=True)

4it [17:43, 265.94s/it]

---- CLASSIFIER: regression ----
acc: 0.8625290023201856
[[535  36  13]
 [ 56 518  20]
 [ 54  58 434]]





In [14]:
snn_experiment(data_x, spike_train_y, clf_type="regression", splits=4, shape=(8, 8, 8), res_train=True)

4it [18:06, 271.59s/it]

---- CLASSIFIER: regression ----
acc: 0.904292343387471
[[557  21   6]
 [ 55 517  22]
 [ 33  28 485]]





In [19]:
snn_experiment(data_x, spike_train_y, clf_type="regression", splits=4, shape=(10, 10, 10), res_train=True)

4it [21:00, 315.20s/it]

---- CLASSIFIER: regression ----
acc: 0.9182134570765661
[[559  15  10]
 [ 54 519  21]
 [ 23  18 505]]





#### XGBoost

In [15]:
snn_experiment(data_x, spike_train_y, clf_type="xgboost", splits=4, shape=(5, 5, 5), res_train=True)

4it [17:15, 259.00s/it]

---- CLASSIFIER: xgboost ----
acc: 0.8491879350348028
[[504  51  29]
 [ 62 484  48]
 [ 27  43 476]]





In [16]:
snn_experiment(data_x, spike_train_y, clf_type="xgboost", splits=4, shape=(6, 6, 6), res_train=True)

4it [18:14, 273.54s/it]

---- CLASSIFIER: xgboost ----
acc: 0.8851508120649652
[[525  33  26]
 [ 51 511  32]
 [ 28  28 490]]





In [17]:
snn_experiment(data_x, spike_train_y, clf_type="xgboost", splits=4, shape=(5, 6, 7), res_train=True)

4it [17:51, 267.80s/it]

---- CLASSIFIER: xgboost ----
acc: 0.8747099767981439
[[538  27  19]
 [ 54 498  42]
 [ 35  39 472]]





In [18]:
snn_experiment(data_x, spike_train_y, clf_type="xgboost", splits=4, shape=(8, 8, 8), res_train=True)

4it [18:12, 273.23s/it]

---- CLASSIFIER: xgboost ----
acc: 0.9147331786542924
[[545  26  13]
 [ 36 529  29]
 [ 20  23 503]]





In [20]:
snn_experiment(data_x, spike_train_y, clf_type="xgboost", splits=4, shape=(10, 10, 10), res_train=True)

4it [21:24, 321.00s/it]

---- CLASSIFIER: xgboost ----
acc: 0.9147331786542924
[[542  20  22]
 [ 37 527  30]
 [ 17  21 508]]



