In [1]:
#! pip install -U "datasets<4.0.0" torch

In [2]:
from datasets import load_dataset

from joblib import Parallel, delayed
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

from tqdm import tqdm
# from matplotlib import pyplot as plt

import json

import time

mps = torch.device("mps")

dataset = load_dataset(
    "McAuley-Lab/Amazon-Reviews-2023", "raw_review_Software", trust_remote_code=True
)
# TODO: download to volume
"""
format:
{rating: int, title: str, text: str}
"""


'\nformat:\n{rating: int, title: str, text: str}\n'

In [3]:
print(f"MPS available: {torch.backends.mps.is_available()}")


MPS available: True


In [None]:
n_samples = len(dataset["full"]["text"])  # 480000
n_samples

# linear to size of dataset
# linear to number of params
# linear to number of batches (not size)

4880181

In [11]:
n_samples = n_samples // 100
X = dataset["full"]["text"][:n_samples]
# X_test = dataset["full"]["text"][200000:225000]

y = torch.tensor(dataset["full"]["rating"][:n_samples]) >= 4

y

tensor([False,  True,  True,  ...,  True,  True,  True])

In [6]:
bag_of_words = {}

In [None]:
import re


def split_text(text: str) -> list[str]:
    text = text.lower().strip()
    text = re.sub(r"[^a-z0-9\s]", "", text)
    words = re.split(r"\s+", text)
    return words


total_words = 0


def add_to_bag(text: str) -> None:
    global total_words
    words = split_text(text)
    total_words += len(words)
    for word in words:
        bag_of_words[word] = bag_of_words.get(word, 0) + 1


In [10]:
# fmt: off
ignored_words = [
    # articles
    "a", "an", "the",
    # pronouns
    "i", "you", "he", "she", "it", "we", "they", "me", "him", "her", "us", "them", "my", "your", "his", "her", "its", "our", "their","his","her","its","our","their",
    # prepositions
    "at", "in", "on", "by", "for", "with", "without", "to", "from", "of", "about", "under", "over", "through", "between", "among", "during", "before", "after", "above", "below", "up", "down", "out", "off", "into", "onto", "upon", "within", "across", "along", "around", "behind", "beside", "beyond", "inside", "outside", "toward", "towards", "underneath", "against", "beneath", "near", "next", "past", "since", "until", "via",
    # conjunctions
    "and", "or", "but", "so", "yet", "nor", "as", "if", "when", "while", "because", "although", "though", "unless", "whereas", "however", "therefore", "moreover", "furthermore", "nevertheless", "meanwhile",
    # common verbs
    "am", "is", "are", "was", "were", "be", "been", "being", "have", "has", "had", "do", "does", "did", "will", "would", "could", "should", "can", "may", "might", "must", "shall", "get", "got", "go", "went", "come", "came", "see", "saw", "know", "knew", "think", "thought", "say", "said", "tell", "told", "make", "made", "take", "took", "give", "gave", "find", "found", "use", "used", "work", "worked", "look", "looked", "seem", "seemed", "feel", "felt", "try", "tried", "leave", "left", "put", "set", "keep", "kept", "let", "run", "ran", "move", "moved", "live", "lived", "bring", "brought", "happen", "happened", "write", "wrote", "sit", "sat", "stand", "stood", "lose", "lost", "pay", "paid", "meet", "met", "include", "included", "continue", "continued", "turn", "turned", "follow", "followed", "want", "wanted", "need", "needed", "like", "liked", "help", "helped", "talk", "talked", "become", "became", "show", "showed", "hear", "heard", "play", "played", "run", "ran", "move", "moved", "live", "lived", "believe", "believed", "hold", "held", "bring", "brought", "happen", "happened", "write", "wrote", "provide", "provided", "sit", "sat", "stand", "stood", "lose", "lost", "pay", "paid", "meet", "met", "include", "included",
    # adverbs
    "very", "really", "quite", "just", "only", "also", "too", "so", "more", "most", "much", "many", "well", "good", "better", "best", "bad", "worse", "worst", "little", "less", "least", "big", "bigger", "biggest", "small", "smaller", "smallest", "long", "longer", "longest", "short", "shorter", "shortest", "high", "higher", "highest", "low", "lower", "lowest", "first", "last", "next", "previous", "new", "old", "young", "great", "right", "wrong", "true", "false", "sure", "probably", "maybe", "perhaps", "definitely", "certainly", "absolutely", "completely", "totally", "exactly", "almost", "nearly", "hardly", "barely", "quite", "rather", "pretty", "fairly", "somewhat", "slightly", "extremely", "incredibly", "amazingly", "surprisingly", "unfortunately", "fortunately", "obviously", "clearly", "apparently", "generally", "usually", "normally", "typically", "often", "sometimes", "rarely", "never", "always", "already", "still", "yet", "soon", "now", "then", "here", "there", "where", "everywhere", "anywhere", "somewhere", "nowhere", "how", "why", "what", "when", "who", "which", "whose", "whom",
    # numbers and quantities
    "one", "two", "three", "four", "five", "six", "seven", "eight", "nine", "ten", "first", "second", "third", "another", "other", "others", "some", "any", "all", "each", "every", "both", "either", "neither", "none", "few", "several", "many", "most", "much", "little", "less", "more", "enough", "plenty",
    # misc common words
    "thing", "things", "something", "anything", "nothing", "everything", "someone", "anyone", "everyone", "no", "yes", "ok", "okay", "please", "thanks", "thank", "welcome", "hello", "hi", "bye", "goodbye", "sorry", "excuse", "pardon", "way", "ways", "time", "times", "day", "days", "year", "years", "place", "places", "people", "person", "man", "woman", "child", "children", "life", "world", "home", "house", "work", "job", "money", "business", "company", "part", "parts", "number", "numbers", "group", "groups", "problem", "problems", "question", "questions", "answer", "answers", "fact", "facts", "example", "examples", "case", "cases", "point", "points", "idea", "ideas", "information", "data", "result", "results", "change", "changes", "end", "beginning", "start", "finish", "side", "sides", "hand", "hands", "eye", "eyes", "head", "face", "back", "front", "top", "bottom", "left", "right", "inside", "outside", "important", "different", "same", "such", "even", "still", "however", "though", "although", "since", "while", "during", "before", "after", "until", "unless", "because", "if", "whether", "that", "this", "these", "those", "there", "here", "where", "when", "how", "why", "what", "who", "which", "whose", "whom",
    # common contractions
    "ive", "im",
    # parsing-specific
    "", "br"
]
# fmt: on

bag_of_words = {}

for observation in X:
    add_to_bag(observation)

print(len(bag_of_words))


bag_of_words = {
    k: v for k, v in bag_of_words.items() if v > 10 and k not in ignored_words
}
print(len(bag_of_words))
bag_size = 3000
bag_of_words = dict(
    sorted(bag_of_words.items(), key=lambda x: x[1], reverse=True)[:bag_size]
)


print(bag_of_words)
print(len(bag_of_words))
print("avg words per review:", total_words / len(X))

681990
57943
3000
avg words per review: 26.328981445565237


In [None]:
def vectorize(text: str, word_to_idx: dict[str, int]) -> torch.Tensor:
    words = split_text(text)

    vector = [0] * len(word_to_idx)

    for word in words:
        idx = word_to_idx.get(word)
        if idx is not None:
            vector[idx] += 1

    return torch.tensor(vector, dtype=torch.float32)


def vectorize_parallel(
    texts: list[str], word_to_idx: dict[str, int], n_jobs: int = 4
) -> torch.Tensor:
    vectors = Parallel(n_jobs=n_jobs)(
        delayed(vectorize)(text, word_to_idx) for text in texts
    )
    return torch.stack(vectors).to(mps)


def preprocess(X: list[str]) -> torch.Tensor:
    word_to_idx = {word: idx for idx, word in enumerate(bag_of_words.keys())}

    return vectorize_parallel(X, word_to_idx)


X_processed = preprocess(X)
# memory
print(X_processed.element_size() * X_processed.numel() / 1024**2, "MB")


In [None]:
print(len(bag_of_words))
print(len(X_processed))

LAYERS = 124, 42, 14, 6
# LAYER_1, LAYER_2 = 75, 5  # original
LAYER_1, LAYER_2, LAYER_3, LAYER_4 = LAYERS

bag_size = len(bag_of_words)

print("num params L1:", len(bag_of_words) * LAYER_1 + LAYER_1)
num_params = (
    sum(LAYERS[i - 1] * LAYERS[i] + LAYERS[i] for i in range(1, len(LAYERS)))
    + bag_size * LAYERS[0]
    + LAYERS[0]
    + LAYERS[-1]
    + 1
)
print("num params:", num_params)

In [None]:
class SparseLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
        """Source copied from torch.nn.Linear"""
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(
            torch.empty((out_features, in_features), **factory_kwargs)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def forward(self, x):
        x_sparse = x.coalesce()  # ensures unique indices
        out = torch.sparse.mm(
            x_sparse,  # (B × in) · (in × out)ᵀ
            self.weight.t(),
        )
        if self.bias is not None:
            out = out + self.bias
        return out

In [None]:
CONFIG = {
    "PARAMS": num_params,
    "BATCHES": 1000,
    "EPOCHS": 10,
    "SAMPLES": int(n_samples * 0.8),
    "LR": 0.003,
    "BAG_SIZE": bag_size,
}
if False:  # test runs
    CONFIG = {
        "PARAMS": num_params,
        "BATCHES": 1000,
        "EPOCHS": 10,
        "SAMPLES": int(n_samples * 0.8),
        "LR": 0.003,
        "BAG_SIZE": bag_size,
    }


print(CONFIG)


def run_nn():
    # optimal params = 10% of 80000 = 8000
    model = nn.Sequential(
        nn.Linear(len(bag_of_words), LAYER_1),
        nn.ReLU(),
        nn.Linear(LAYER_1, LAYER_2),
        nn.ReLU(),
        nn.Linear(LAYER_2, LAYER_3),
        nn.ReLU(),
        nn.Linear(LAYER_3, LAYER_4),
        nn.ReLU(),
        nn.Linear(LAYER_4, 1),
        # nn.Sigmoid(),  # using logits loss
    ).to(mps)
    model = torch.compile(model)

    # TODO: clear model if instantiating once

    print(model)

    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=CONFIG["LR"])

    n_samples = CONFIG["SAMPLES"]
    X_train, X_test = X_processed[:n_samples], X_processed[n_samples:]
    y_train, y_test = y[:n_samples].float().to(mps), y[n_samples:].float().to(mps)

    dataset = TensorDataset(X_train, y_train)
    dataloader = DataLoader(
        dataset, batch_size=len(X_train) // CONFIG["BATCHES"], shuffle=True
    )
    # TODO: sparse tensors?

    # loss_per_batch = []

    start_time = time.time()

    for epoch in range(CONFIG["EPOCHS"]):
        print(f"[{time.time()}] Epoch {epoch}")
        for X_batch, y_batch in tqdm(dataloader):
            # already allocated, removing to(mps) shaved 33% of time
            X_batch, y_batch = X_batch, y_batch
            optimizer.zero_grad()
            y_pred = model(X_batch)
            loss = loss_fn(y_pred.squeeze(), y_batch)
            loss.backward()
            optimizer.step()
            # loss_per_batch.append(loss.item())

    end_time = time.time()
    print(f"Time taken: {end_time - start_time:.4f} seconds")

    with torch.no_grad():
        y_pred_train = model(X_train).squeeze()
        y_pred_test = model(X_test).squeeze()

    # sigmoid to binary
    train_preds = (y_pred_train >= 0.5).int()
    test_preds = (y_pred_test >= 0.5).int()

    num_true_train = (y_train == 1).sum()
    num_true_test = (y_test == 1).sum()
    train_acc = (train_preds == y_train.int()).float().mean()
    test_acc = (test_preds == y_test.int()).float().mean()

    # plt.xlabel("Batch")
    # plt.ylabel("Loss")
    # plt.plot(loss_per_batch)
    # plt.show()
    # print("Final loss:", loss_per_batch[-1])

    print(
        f"Train Accuracy: {train_acc:.4f} ({train_preds.sum()}/{len(X_train)}), random guess accuracy: {max(num_true_train, len(X_train) - num_true_train) / len(X_train):.4f}"
    )
    print(
        f"Test Accuracy: {test_acc:.4f} ({test_preds.sum()}/{len(X_test)}), random guess accuracy: {max(num_true_test, len(X_test) - num_true_test) / len(X_test):.4f}"
    )

    precision = (
        test_preds * y_test
    ).sum() / test_preds.sum()  # true pos / predicted pos
    recall = (test_preds * y_test).sum() / y_test.sum()  # true pos / actual pos
    f1 = 2 * precision * recall / (precision + recall)
    print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

    return model


m = run_nn()
print(CONFIG)

```js
// 7749 vocab, 64 layer 1
Final loss: 0.43634021282196045
Train Accuracy: 0.8201 (61019/80000)
Test Accuracy: 0.8158 (15540/20000)

// 100 vocab, 104 layer 1, 20 epoch
Final loss: 0.48798397183418274
Train Accuracy: 0.7697 (62999/80000), random guess accuracy: 0.6713
Test Accuracy: 0.7691 (16053/20000), random guess accuracy: 0.6823
Precision: 0.7812, Recall: 0.9190, F1: 0.8445

// -> 100 batches
Final loss: 0.4811912477016449
Train Accuracy: 0.7857 (62716/80000), random guess accuracy: 0.6713
Test Accuracy: 0.7719 (15957/20000), random guess accuracy: 0.6823
Precision: 0.7846, Recall: 0.9176, F1: 0.8459

// -> 1000 batch SIZE
Final loss: 0.45420441031455994
Train Accuracy: 0.7831 (62935/80000), random guess accuracy: 0.6713
Test Accuracy: 0.7722 (15991/20000), random guess accuracy: 0.6823
Precision: 0.7842, Recall: 0.9190, F1: 0.8463

// -> 78 layer 1
Final loss: 0.4723219573497772
Train Accuracy: 0.7814 (62978/80000), random guess accuracy: 0.6713
Test Accuracy: 0.7721 (16049/20000), random guess accuracy: 0.6823
Precision: 0.7831, Recall: 0.9211, F1: 0.8465


// 75 layer 1, 5 layer 2
Final loss: 0.47183775901794434
Train Accuracy: 0.7830 (61135/80000), random guess accuracy: 0.6713
Test Accuracy: 0.7702 (15633/20000), random guess accuracy: 0.6823
Precision: 0.7894, Recall: 0.9044, F1: 0.8430
```

In [None]:
def param_to_word_mapping(model: nn.Module) -> dict[str, float]:
    # m is something like nn.Sequential(), which is a module
    params = list(model.parameters())

    mapping = zip(bag_of_words.keys(), params[0].flatten().tolist())

    return {k: v for k, v in sorted(mapping, key=lambda x: x[1], reverse=True)}


param_to_word_mapping(m)

In [None]:
tm = time.time()
torch.save(m, f"model_{tm}.pth")
json.dump(bag_of_words, open(f"bag_of_words_{tm}.json", "w"))

In [None]:
# torch.serialization.add_safe_globals([nn.Sequential, nn.Linear, nn.Sigmoid, nn.ReLU])
# m2 = torch.load("model.pth")