In [1]:
import itertools
import logging
import os
import random

import datasets as ds
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
from sklearn.utils.extmath import randomized_svd
from tqdm import tqdm
# Seed every RNG the notebook touches (Python's `random` and NumPy's legacy global RNG)
# from one constant so downstream stochastic steps are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Location of the ag_news DatasetDict written by `save_to_disk`. Override with the
# AG_NEWS_DIR environment variable instead of editing the path; the default keeps
# the original location so existing runs are unaffected.
AG_NEWS_DIR = os.environ.get("AG_NEWS_DIR", "/home/magraz/AI539_NLP/HW1/data/ag_news")

logging.basicConfig(
    format='%(asctime)s %(levelname)-8s %(message)s',
    level=logging.INFO,
    datefmt='%Y-%m-%d %H:%M:%S')

logging.info("Loading dataset")

# One-time download, kept for reference. It saves to the same directory that is
# loaded below (previously it saved to the parent `data/` directory, which the
# `load_from_disk` call would not find).
# dataset = ds.load_dataset("ag_news")
# dataset.save_to_disk(AG_NEWS_DIR)

dataset = ds.load_from_disk(AG_NEWS_DIR)

# Column access reads each field in one call instead of decoding every row in a
# Python loop. list() guarantees a plain list whatever the `datasets` version returns.
train_split = dataset['train']
dataset_text = list(train_split['text'])
dataset_labels = list(train_split['label'])

2024-04-17 15:19:52 INFO     Loading dataset


In [2]:
from Vocabulary import Vocabulary
from build_freq_vectors import compute_cooccurrence_matrix, compute_ppmi_matrix, dim_reduce, plot_word_vectors_tsne

logging.info("Building vocabulary")

vocab = Vocabulary(dataset_text)

# Sanity-check the vocabulary: log the first document, the vocabulary size, and a
# short preview of the index->word and frequency mappings. islice takes only the
# first PREVIEW_SIZE items rather than materialising every item and slicing.
PREVIEW_SIZE = 100
logging.info(dataset_text[:1])
logging.info("Vocabulary size: %d", len(vocab.word2idx))
logging.info(dict(itertools.islice(vocab.idx2word.items(), PREVIEW_SIZE)))
logging.info(dict(itertools.islice(vocab.freq.items(), PREVIEW_SIZE)))

# Optional vocabulary diagnostic charts; off by default, as in the original run.
MAKE_VOCAB_CHARTS = False
if MAKE_VOCAB_CHARTS:
    vocab.make_vocab_charts()

logging.info("Computing PPMI matrix")
# NOTE(review): the saved run of this cell ended in KeyboardInterrupt over the full
# training set. Consider caching the PPMI result to disk so Restart & Run All stays
# feasible; confirm its type before choosing a serialisation format.
PPMI = compute_ppmi_matrix(dataset_text, vocab)

logging.info("Performing Truncated SVD to reduce dimensionality")
word_vectors = dim_reduce(PPMI)

logging.info("Preparing T-SNE plot")
plot_word_vectors_tsne(word_vectors, vocab)

2024-04-17 15:19:56 INFO     Building vocabulary
2024-04-17 15:20:09 INFO     ["Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again."]
2024-04-17 15:20:09 INFO     6274
2024-04-17 15:20:09 INFO     {0: 'the', 1: 'to', 2: 'of', 3: 'in', 4: 'and', 5: 'on', 6: 'for', 7: 'it', 8: 'that', 9: 'with', 10: 'a', 11: 'at', 12: 'is', 13: 'new', 14: 'by', 15: 'said', 16: 'reuters', 17: 'ha', 18: 'from', 19: 'an', 20: 'ap', 21: 'his', 22: 'will', 23: 'after', 24: 'year', 25: 'wa', 26: 'gt', 27: 'u', 28: 'lt', 29: 'be', 30: 'over', 31: 'have', 32: 'up', 33: 'their', 34: 'two', 35: 'company', 36: 'first', 37: 'are', 38: 'quot', 39: 'but', 40: 'more', 41: 'he', 42: 'world', 43: 'one', 44: 'this', 45: 'game', 46: 'say', 47: 'monday', 48: 'out', 49: 'oil', 50: 'wednesday', 51: 'tuesday', 52: 'thursday', 53: 'week', 54: 'not', 55: 'stock', 56: 'state', 57: 'against', 58: 'friday', 59: 'inc', 60: 'than', 61: 'pric

KeyboardInterrupt: 