In [1]:
import gensim.downloader as api
from sklearn.datasets import fetch_20newsgroups
import torch

from experiments.experiment import snn_experiment, lsa_experiment
from experiments.preprocess import Word2VecPrep

In [2]:
# Pick the torch device string once for the whole notebook:
# "cuda" when torch reports CUDA as available, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [3]:
# Identifier of the pretrained embedding passed to gensim's downloader API.
W2V_MODEL_NAME = 'word2vec-google-news-300'

# Load the pretrained vectors, then wrap them in the project's
# preprocessing helper used by every experiment below.
word2vec = api.load(W2V_MODEL_NAME)
text_prep = Word2VecPrep(word2vec)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\aleks\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [4]:
# Restrict the 20 Newsgroups corpus to three categories.
cats = [
    'comp.graphics',
    'sci.med',
    'talk.politics.guns',
]

# Fetch the 'train' subset for the selected categories.
newsgroups_train = fetch_20newsgroups(
    subset='train',
    categories=cats,
)

### Averaged Word2Vec

In [None]:
# Build the averaged (avg=True) Word2Vec features and labels for the
# selected newsgroups.
train_x, train_y = text_prep.preprocess_dataset(newsgroups_train, avg=True)

# NOTE(review): this cell used to call `train_x.to(device)` and
# `train_y.to(device)` without keeping the return values. `.to()` returns the
# converted object instead of modifying it in place, so those calls only
# created and discarded a copy; the experiments below have always received the
# data on its original device. The dead calls are removed so the code matches
# what actually runs. To really move the data, rebind the results
# (`train_x = train_x.to(device)`) after confirming that `lsa_experiment`
# accepts tensors living on `device`.
#
# Report where the data actually lives rather than the requested device.
# Assumes train_x is a torch.Tensor (it exposes `.to`) -- TODO confirm.
print(f"requested device: {device} | data device: {train_x.device}")

In [6]:
# Averaged-Word2Vec run with the "regression" classifier and splits=4.
lsa_experiment(
    train_x,
    train_y,
    clf_type="regression",
    splits=4,
)

4it [00:00, 10.10it/s]


---- CLASSIFIER: regression ----
acc: 0.9547563805104409
[[574   1   9]
 [ 34 550  10]
 [ 19   5 522]]





In [7]:
# Averaged-Word2Vec run with the "random_forest" classifier and splits=4.
lsa_experiment(
    train_x,
    train_y,
    clf_type="random_forest",
    splits=4,
)

4it [00:14,  3.59s/it]


---- CLASSIFIER: random_forest ----
acc: 0.9437354988399071
[[566  11   7]
 [ 37 548   9]
 [ 26   7 513]]





In [8]:
# Averaged-Word2Vec run with the "xgboost" classifier and splits=4.
lsa_experiment(
    train_x,
    train_y,
    clf_type="xgboost",
    splits=4,
)

4it [00:11,  3.00s/it]


---- CLASSIFIER: xgboost ----
acc: 0.9553364269141531
[[568  10   6]
 [ 28 559   7]
 [ 18   8 520]]





### Spiking Word2Vec

In [9]:
# Non-averaged (avg=False) Word2Vec features, requested with max_size=120.
spike_train_x, spike_train_y = text_prep.preprocess_dataset(
    newsgroups_train,
    avg=False,
    max_size=120,
)

# Binarise the features: entries >= 0.25 become 1, everything else 0.
# Casting the boolean mask to int64 yields the same 0/1 integer tensor as
# selecting between scalar tensors 1 and 0 with torch.where.
mask = spike_train_x >= 0.25
data_x = mask.long()

In [10]:
# NOTE(review): this cell used to call `data_x.to(device)` and
# `spike_train_y.to(device)` without keeping the return values. `Tensor.to`
# returns the converted tensor instead of modifying it in place, so those
# calls only created and discarded a copy; the spiking experiments below have
# always received the data on its original device, even though "cuda" was
# printed. The dead calls are removed so the code matches what actually runs.
# To really move the data, rebind the results (`data_x = data_x.to(device)`,
# `spike_train_y = spike_train_y.to(device)`) after confirming that
# `snn_experiment` accepts tensors living on `device`.
#
# Report where the data actually lives rather than the requested device.
print(f"requested device: {device} | data device: {data_x.device}")

cuda


#### Regression

In [11]:
# Spiking run: "regression" classifier, shape (5, 5, 5), res_train=True, splits=4.
snn_experiment(
    data_x,
    spike_train_y,
    clf_type="regression",
    splits=4,
    shape=(5, 5, 5),
    res_train=True,
)

4it [14:03, 210.91s/it]

---- CLASSIFIER: regression ----
acc: 0.8109048723897911
[[513  49  22]
 [ 76 483  35]
 [ 61  83 402]]





In [12]:
# Spiking run: "regression" classifier, shape (6, 6, 6), res_train=True, splits=4.
snn_experiment(
    data_x,
    spike_train_y,
    clf_type="regression",
    splits=4,
    shape=(6, 6, 6),
    res_train=True,
)

4it [13:54, 208.67s/it]

---- CLASSIFIER: regression ----
acc: 0.857308584686775
[[535  37  12]
 [ 65 500  29]
 [ 53  50 443]]





In [13]:
# Spiking run: "regression" classifier, shape (5, 6, 7), res_train=True, splits=4.
snn_experiment(
    data_x,
    spike_train_y,
    clf_type="regression",
    splits=4,
    shape=(5, 6, 7),
    res_train=True,
)

4it [14:00, 210.24s/it]

---- CLASSIFIER: regression ----
acc: 0.8474477958236659
[[528  39  17]
 [ 71 495  28]
 [ 52  56 438]]





In [14]:
# Spiking run: "regression" classifier, shape (8, 8, 8), res_train=True, splits=4.
snn_experiment(
    data_x,
    spike_train_y,
    clf_type="regression",
    splits=4,
    shape=(8, 8, 8),
    res_train=True,
)

4it [14:21, 215.36s/it]


---- CLASSIFIER: regression ----
acc: 0.8834106728538283
[[547  28   9]
 [ 61 512  21]
 [ 40  42 464]]


#### XGBoost

In [15]:
# Spiking run: "xgboost" classifier, shape (5, 5, 5), res_train=True, splits=4.
snn_experiment(
    data_x,
    spike_train_y,
    clf_type="xgboost",
    splits=4,
    shape=(5, 5, 5),
    res_train=True,
)

4it [14:02, 210.70s/it]

---- CLASSIFIER: xgboost ----
acc: 0.8439675174013921
[[502  45  37]
 [ 59 482  53]
 [ 35  40 471]]





In [16]:
# Spiking run: "xgboost" classifier, shape (6, 6, 6), res_train=True, splits=4.
snn_experiment(
    data_x,
    spike_train_y,
    clf_type="xgboost",
    splits=4,
    shape=(6, 6, 6),
    res_train=True,
)

4it [14:08, 212.14s/it]

---- CLASSIFIER: xgboost ----
acc: 0.861368909512761
[[516  43  25]
 [ 52 498  44]
 [ 28  47 471]]





In [17]:
# Spiking run: "xgboost" classifier, shape (5, 6, 7), res_train=True, splits=4.
snn_experiment(
    data_x,
    spike_train_y,
    clf_type="xgboost",
    splits=4,
    shape=(5, 6, 7),
    res_train=True,
)

4it [14:05, 211.37s/it]

---- CLASSIFIER: xgboost ----
acc: 0.8677494199535963
[[527  33  24]
 [ 61 498  35]
 [ 38  37 471]]





In [18]:
# Spiking run: "xgboost" classifier, shape (8, 8, 8), res_train=True, splits=4.
snn_experiment(
    data_x,
    spike_train_y,
    clf_type="xgboost",
    splits=4,
    shape=(8, 8, 8),
    res_train=True,
)

4it [14:43, 220.97s/it]

---- CLASSIFIER: xgboost ----
acc: 0.9008120649651972
[[546  21  17]
 [ 47 511  36]
 [ 31  19 496]]



