In [1]:
from argparse import Namespace
from classifier import *

from dataset import *
from vectorizer import *

import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def use_cuda():
    """Pick the device the models in this notebook should run on.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` reports a
        usable GPU, otherwise ``cpu``.
    """
    # The previous version returned "cpu" from both branches, so the GPU was
    # never selected even when one was present.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Experiment settings shared by every classifier built in this notebook.
# Entries set to None are filled in by later cells once a dataset is loaded.
args = Namespace(**{
    "in_features": None,
    "out_units": 1,  # Use 1 for IMDB and 3 for Tweets
    "dataset": None,
    "optimizer": "Adam",
    "criterion": "bce_logits",
    "batch_size": 128,
    "learning_rate": 0.001,
    "num_epochs": 1,
    "device": use_cuda(),
    "embed_dim": 300,  # presumably matches the 300-d vector file loaded later -- confirm
    "freeze_embedding": False,
    "filter_sizes": [3, 4, 5],
    "num_filters": [100, 100, 100],
    "pretrained_embedding": None,
    "hidden_size": 2,  # number of features in hidden state
    "num_layers": 1,  # number of stacked lstm layers
})

In [None]:
# args.dataset = TextDataset.load_dataset_and_make_vectorizer("../data/Tweets.csv")
# args.in_features = args.dataset.vectorizer.max_padding

### BOW

32024

In [None]:
# bow = BOWClassifier(args=args)
# bow.setup()
# bow.fit()

In [None]:
# loss, acc, f1 = bow.eval_net(mode='test')
# print(loss)
# print(acc)
# print(f1)

In [None]:
# args.dataset = TextDataset.load_dataset_and_make_vectorizer("../data/tweets_products.csv")
# loss, acc, f1 = bow.eval_net(mode='test')
# print(loss)
# print(acc)
# print(f1)

In [None]:
# bow.plot_logs(title="Loss", legend=["Train-Loss", "Validation-Loss"])

In [None]:
# bow.plot_logs(title="Accuracy", legend=["Train-Accuracy", "Validation-Accuracy"])

In [None]:
# bow.plot_logs(title="F1-Score", legend=["Train-F1", "Validation-F1"])

In [3]:
# Load the IMDB CSV and build its vectorizer in "padded" mode.
args.dataset = TextDataset.load_dataset_and_make_vectorizer("../data/IMDB-dataset.csv", vectorizer_mode="padded")
# Model input size = number of entries in the vectorizer's text vocabulary.
args.in_features = len(args.dataset.vectorizer.text_vocab)
# Load embedding weights through the dataset's vectorizer from the crawl-300d-2M
# vector file (slow: the captured output shows ~2M lines being read).
args.pretrained_embedding = args.dataset.vectorizer.load_pretrained_embed("../data/crawl-300d-2M.vec")

[nltk_data] Downloading package stopwords to /home/alexc/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
1999995it [00:25, 79993.60it/s]


### CNN

In [4]:
# Build the CNN classifier from the shared args, then run its setup and fit steps.
# fit() is presumably the training loop (progress bar shows num_epochs iterations) -- confirm.
cnn_classifier = CNNClassifier(args=args)
cnn_classifier.setup()
cnn_classifier.fit()

100%|██████████| 1/1 [03:08<00:00, 188.74s/it]


In [5]:
# Evaluate the fitted CNN in 'test' mode and show loss, accuracy and F1, one per line.
loss, acc, f1 = cnn_classifier.eval_net(mode="test")
print(loss, acc, f1, sep="\n")

0.28365139472178913
88.18108974358975
88.12061024902106


In [6]:
# Transfer check: evaluate the CNN fitted on IMDB against the product-tweets data.
args.dataset = TextDataset.load_dataset_and_make_vectorizer("../data/tweets_products.csv", vectorizer_mode="padded")
# Swap in an embedding table built through the tweets vectorizer (presumably
# because its vocabulary/indices differ from the IMDB one -- confirm).
# from_pretrained keeps the layer on whatever device the loaded tensor is on, so
# move it explicitly to args.device, the device the rest of the model was
# configured with; otherwise evaluation can fail with a device mismatch.
cnn_classifier.embedding = nn.Embedding.from_pretrained (
    args.dataset.vectorizer.load_pretrained_embed("../data/crawl-300d-2M.vec"),
    freeze=args.freeze_embedding
).to(args.device)
loss, acc, f1 = cnn_classifier.eval_net(mode='test')
print(loss)
print(acc)
print(f1)

[nltk_data] Downloading package stopwords to /home/alexc/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
1999995it [00:33, 60358.06it/s]


0.6479174336671827
61.64781250000006
60.66035275442304


In [8]:
# Persist only the CNN's state_dict (not the whole object) to the current working directory.
torch.save(cnn_classifier.state_dict(), "./cnn_classifier.pth")

### LSTM

In [9]:
# Switch the shared args back to the IMDB data before building the LSTM
# (an earlier cell pointed args.dataset at the tweets CSV).
args.dataset = TextDataset.load_dataset_and_make_vectorizer("../data/IMDB-dataset.csv", vectorizer_mode="padded")
# Model input size = number of entries in the vectorizer's text vocabulary.
args.in_features = len(args.dataset.vectorizer.text_vocab)
# Reload embedding weights through the new vectorizer from the crawl-300d-2M vector file.
args.pretrained_embedding = args.dataset.vectorizer.load_pretrained_embed("../data/crawl-300d-2M.vec")

[nltk_data] Downloading package stopwords to /home/alexc/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
1999995it [00:24, 82880.78it/s]


In [10]:
# Build the LSTM classifier from the shared args, then run its setup and fit steps.
# fit() is presumably the training loop (progress bar shows num_epochs iterations) -- confirm.
lstm_classifier = LSTMClassifier(args)
lstm_classifier.setup()
lstm_classifier.fit()

100%|██████████| 1/1 [01:19<00:00, 79.58s/it]


In [11]:
# Evaluate the fitted LSTM in 'test' mode and show loss, accuracy and F1, one per line.
loss, acc, f1 = lstm_classifier.eval_net(mode="test")
print(loss, acc, f1, sep="\n")

0.41880007661305935
82.70232371794869
82.5880476369511


In [12]:
args.dataset = TextDataset.load_dataset_and_make_vectorizer("../data/tweets_products.csv", vectorizer_mode="padded")
# Swap in an embedding table built through the current dataset's vectorizer.
# from_pretrained keeps the layer on whatever device the loaded tensor is on, so
# move it explicitly to args.device, the device the rest of the model was
# configured with; otherwise evaluation can fail with a device mismatch.
lstm_classifier.embedding = nn.Embedding.from_pretrained (
    args.dataset.vectorizer.load_pretrained_embed("../data/crawl-300d-2M.vec"),
    freeze=args.freeze_embedding
).to(args.device)


[nltk_data] Downloading package stopwords to /home/alexc/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
1999995it [00:30, 65091.50it/s]


In [13]:
# Re-run the 'test' evaluation of the LSTM after the dataset/embedding swap and
# show loss, accuracy and F1, one per line.
loss, acc, f1 = lstm_classifier.eval_net(mode="test")
print(loss, acc, f1, sep="\n")

0.9708076299190526
50.009062500000034
34.69828848453045


In [14]:
# Persist only the LSTM's state_dict (not the whole object) to the current working directory.
torch.save(lstm_classifier.state_dict(), "./lstm_classifier.pth")