In [1]:
import torch
import glob
import pickle
import os
import sys

from typing import Literal

# Move the working directory one level up and make that directory importable.
# NOTE: the chdir target is relative, so re-running this cell climbs one more
# directory each time (and appends another sys.path entry).
# Presumably the parent is the project root that holds the data/ paths used
# below -- TODO confirm.
os.chdir("..")
sys.path.append(os.getcwd())
os.getcwd()

'/group/pmc026/nchoong/QuantumTransformer'

In [2]:
# Collect every file matching the imdb_train_batch_* pattern (relative to the cwd).
# glob.glob gives no ordering guarantee; the paths are sorted where they are loaded.
file_list = glob.glob("data/word_embeddings/imdb_train_batch_*.pt")
file_list

['data/word_embeddings/imdb_train_batch_0.pt',
 'data/word_embeddings/imdb_train_batch_1.pt',
 'data/word_embeddings/imdb_train_batch_2.pt',
 'data/word_embeddings/imdb_train_batch_3.pt',
 'data/word_embeddings/imdb_train_batch_4.pt',
 'data/word_embeddings/imdb_train_batch_5.pt',
 'data/word_embeddings/imdb_train_batch_6.pt']

In [3]:
# Load each batch file in sorted path order and concatenate them along dim 0.
# NOTE: sorted() compares the paths as strings, so "batch_10" would come before
# "batch_2"; that is harmless here only because the batch indices are single-digit.
tensor_list = [torch.load(f) for f in sorted(file_list)]
combined_tensor = torch.cat(tensor_list, dim=0)
print(f"Combined tensor shape: {combined_tensor.shape}")

Combined tensor shape: torch.Size([20000, 128, 768])


In [6]:
# NOTE: despite the "shape" label, this prints only the size of the last dimension.
print(f"Combined tensor shape: {combined_tensor.shape[-1]}")

Combined tensor shape: 768


In [4]:
# Load the pickled labels stored under the imdb_train name.
# NOTE(review): pickle.load can execute arbitrary code -- only use it on trusted files.
with open("data/word_labels/imdb_train_labels.pkl", "rb") as f:
    word_labels = pickle.load(f)

# Number of labels loaded (compare with the first dimension of combined_tensor).
len(word_labels)

20000

In [5]:
# Preview the first 10 (label, embedding) pairs.
# zip() stops at the shorter input, so slicing the tensor first yields the same
# 10 pairs without materialising a pair for every row just to discard most of them.
list(zip(word_labels, combined_tensor[:10]))

[(1,
  tensor([[-0.3903, -0.0062,  0.1094,  ...,  0.0678,  0.3269,  0.7495],
          [ 0.2963,  0.5473, -0.3386,  ..., -0.0206,  0.5701, -0.6230],
          [ 0.9159,  0.4148,  0.0028,  ..., -0.0825,  0.0944, -0.0690],
          ...,
          [ 0.6324,  0.2350,  0.2655,  ...,  0.5840, -0.2324, -0.5139],
          [-0.4939, -0.4040, -0.1704,  ...,  0.7217,  0.3663,  0.2065],
          [ 0.2580,  0.3487,  0.1932,  ...,  0.2904, -0.0251,  0.0480]])),
 (1,
  tensor([[-0.3592, -0.3627,  0.4163,  ..., -0.8001,  0.5461, -0.0842],
          [-0.3706, -0.3693, -0.4358,  ...,  0.4339,  1.2432, -0.6204],
          [-0.5736, -0.9523, -0.3636,  ...,  0.6015,  0.9186, -0.3631],
          ...,
          [-0.2705, -0.4833,  0.4079,  ...,  0.4919, -0.0573, -0.3485],
          [-0.3036, -0.5303,  0.4728,  ...,  0.4990, -0.0118, -0.3851],
          [-0.3089, -0.4907,  0.6327,  ...,  0.3892,  0.0110, -0.3043]])),
 (0,
  tensor([[-0.0461, -0.6313,  0.1569,  ..., -0.3000,  0.7552, -0.0166],
          [ 0

In [5]:
def _batch_sort_key(path: str):
    """Sort key that orders batch files by numeric index instead of as text."""
    stem = os.path.splitext(os.path.basename(path))[0]
    suffix = stem.rsplit("_batch_", 1)[-1]
    if suffix.isdecimal():
        return (0, int(suffix), path)
    # Files whose suffix is not a plain number sort after all numbered batches.
    return (1, 0, path)


def get_dataset(
    name: Literal["amazon", "imdb", "yelp"], type: Literal["train", "val", "test"]
):
    """Load one split of a dataset from the ``data/`` directory.

    Reads every ``data/word_embeddings/{name}_{type}_batch_*.pt`` file in
    ascending batch-index order, concatenates the loaded tensors along dim 0,
    and unpickles ``data/word_labels/{name}_{type}_labels.pkl``. All paths are
    relative to the current working directory.

    Args:
        name: Dataset name used as the file-name prefix.
        type: Split name used in the file names (the parameter shadows the
            ``type`` builtin; the name is kept so keyword callers keep working).

    Returns:
        A ``(combined_tensor, word_labels)`` tuple.

    Raises:
        FileNotFoundError: If no embedding batch matches the pattern, or the
            labels file does not exist.
    """
    pattern = f"data/word_embeddings/{name}_{type}_batch_*.pt"
    file_list = glob.glob(pattern)
    if not file_list:
        # Fail with a clear message instead of torch.cat's error on an empty list.
        raise FileNotFoundError(f"No embedding batches match {pattern!r}")
    # A plain string sort would put "batch_10" before "batch_2" and misalign the
    # concatenated embeddings with the labels, so sort on the numeric index.
    tensor_list = [torch.load(f) for f in sorted(file_list, key=_batch_sort_key)]
    combined_tensor = torch.cat(tensor_list, dim=0)
    # NOTE(review): pickle.load can execute arbitrary code -- trusted files only.
    with open(f"data/word_labels/{name}_{type}_labels.pkl", "rb") as f:
        word_labels = pickle.load(f)
    return combined_tensor, word_labels

In [6]:
def load_dataset(name: Literal["amazon", "imdb", "yelp"]):
    """Load the train, val and test splits of the dataset called ``name``.

    Each split is fetched with ``get_dataset`` and stored under its split name.

    Returns:
        A dict with keys ``"train"``, ``"val"`` and ``"test"``; each value is a
        ``(labels, data)`` tuple. Note the order is the reverse of the
        ``(data, labels)`` tuple that ``get_dataset`` returns.
    """
    loaded = {}
    for split in ("train", "val", "test"):
        data, labels = get_dataset(name, split)
        loaded[split] = (labels, data)
    return {split: loaded[split] for split in ("train", "val", "test")}

In [8]:
# Load the train/val/test splits for "imdb" into a dict keyed by split name.
dataset = load_dataset("imdb")

In [11]:
# Each split is stored as (labels, data): show the label count and the data tensor's shape.
len(dataset["train"][0]), dataset["train"][1].shape

(20000, torch.Size([20000, 128, 768]))