In [3]:
import os
import numpy as np
from transformers import BertModel, BertConfig

DATA_DIR = "data"  # Data directory, relative to the repository root; may need to be changed on different machines

# Make sure we're in the correct directory and make sure the data directory exists.
# The notebook lives in src/nb, so when DATA_DIR is not visible from the current
# working directory it is expected two levels up, at the repository root.
if not os.path.exists(DATA_DIR):
    # Check the target location BEFORE changing directory: if the data directory
    # is missing there too, the working directory is left untouched and the cell
    # can be re-run after fixing the problem without climbing further up the tree.
    assert os.path.exists(os.path.join("../..", DATA_DIR)), f"ERROR: DATA_DIR={DATA_DIR} not found"
    os.chdir("../..")

In [11]:
def read_bpe_data(data_path, pad_value=-1):
    """
    Reads in the BPE data from the given path into an (N, M) integer numpy array.
    N = Number of samples (one per line of the file)
    M = Maximum number of tokens in an example

    Each line is expected to hold whitespace-separated integer token ids.
    Samples shorter than M are right-padded with `pad_value`; a blank line
    yields a row made up entirely of padding. An empty file yields an array
    of shape (0, 0).

    Args:
        data_path: Path to the text file to read.
        pad_value: Integer used to fill the positions after the end of each
            sample. Defaults to -1.

    Returns:
        np.ndarray of shape (N, M) and dtype np.int32.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a token on any line cannot be parsed as an integer.
    """
    # Iterate over the file directly so the raw lines are not all held in memory
    # alongside the parsed rows.
    with open(data_path, "r") as f:
        # row string " 123 456 789\n" -> integer list [123, 456, 789]
        # (split() with no arguments already discards surrounding whitespace)
        list_data = [[int(token) for token in line.split()] for line in f]

    # Length of the longest sample; default=0 covers an empty file
    max_len = max(map(len, list_data), default=0)

    # Start from an array filled with the padding value, then copy each sample
    # into the leading positions of its row
    padded_data = np.full((len(list_data), max_len), pad_value, dtype=np.int32)
    for i, row in enumerate(list_data):
        padded_data[i, :len(row)] = row

    return padded_data

In [12]:
# Locations of the tweets data and of the dev split files (labels and BPE inputs)
tweets_path = os.path.join(DATA_DIR, "datasets/cds/tweets")
dev_label_path = os.path.join(tweets_path, "dev.label")
dev_input_path = os.path.join(tweets_path, "dev.input0.bpe")

# Read the dev inputs with read_bpe_data and display the resulting array
dev_input = read_bpe_data(dev_input_path)
dev_input

array([[ 3237,  2907, 13449, ...,    -1,    -1,    -1],
       [  464,  1743,  3329, ...,    -1,    -1,    -1],
       [16371, 13779,   290, ...,    -1,    -1,    -1],
       ...,
       [19782, 32583, 12460, ...,    -1,    -1,    -1],
       [   72,  1842,  2832, ...,    -1,    -1,    -1],
       [   34, 26730,  2635, ...,    -1,    -1,    -1]], dtype=int32)