In [None]:
!pip install datasets
!pip install transformers
!pip install tqdm
!pip install nltk
!pip install bitsandbytes
!pip install accelerate
!pip install ijson

In [None]:
from datasets import load_dataset
from tqdm import tqdm
import json
import nltk
from transformers import AutoTokenizer
from joblib import Parallel, delayed
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import ijson

# Fetch the NLTK sentence-tokenizer resources needed by nltk.sent_tokenize.
nltk.download('punkt')
nltk.download('punkt_tab')

# Checkpoint whose tokenizer is used for the whole preprocessing pipeline.
model_id = "NousResearch/Llama-2-7b-hf"
# The keyword is `use_fast`; the previous `user_fast` spelling was not the
# option AutoTokenizer recognises, so the request was silently not applied.
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="right", use_fast=True)

# Load dataset and store it.

In [None]:
# Download (or reuse the local cache of) the clean-Wikipedia dataset from the Hub.
wiki_repo_id = "HuggingFaceFW/clean-wikipedia"
dataset = load_dataset(wiki_repo_id)

In [None]:
# Entire wikipedia
# Write one JSON-encoded article text per line.
# Rows are streamed one at a time (same pattern as the English-only export)
# instead of indexing dataset["train"]["text"], which, depending on the
# `datasets` version, can materialise the whole text column in memory.
with open("entire_wikipedia.jsonl", 'w') as f:
    for item in dataset["train"]:
        f.write(json.dumps(item["text"]) + '\n')

In [None]:
# English wikipedia
# Keep only rows whose "wikicode" field is "en" and write each article text as
# one JSON-encoded line. Only the "train" split is read.
with open("english_wikipedia.jsonl", 'w') as out_file:
    out_file.writelines(
        json.dumps(row["text"]) + '\n'
        for row in dataset["train"]
        if row["wikicode"] == "en"
    )

# Split into sentences.

In [None]:
def split_into_sentences(text: str) -> list[str]:
    """Split ``text`` into sentences with NLTK's ``sent_tokenize``.

    Args:
        text: The text to segment.

    Returns:
        The sentences found in ``text``, in order.
    """
    sentences = nltk.sent_tokenize(text)
    return sentences

# Re-read the English article dump and write one JSON-encoded sentence per line.
with open("english_wikipedia.jsonl", 'r') as fin, open("english_wikipedia_sentences.jsonl", 'w') as fout:
    for line in tqdm(fin):
        # Each input line is a JSON-encoded string. Decode it first so the
        # sentence splitter sees the real article text rather than the
        # surrounding quotes, backslash escapes and trailing newline.
        article_text = json.loads(line)
        for sentence in split_into_sentences(article_text):
            fout.write(json.dumps(sentence) + '\n')

# Tokenize training data

In [None]:
# Tokenize every sentence and write its list of input ids as one JSON line.
with open("english_wikipedia_sentences.jsonl", 'r') as fin, open("english_wikipedia_tokenized.jsonl", 'w') as fout:
    for line in tqdm(fin):
        # Each input line is a JSON-encoded string. Decode it so the tokenizer
        # receives the sentence itself, not its quoted/escaped JSON form with
        # a trailing newline.
        sentence = json.loads(line)
        fout.write(json.dumps(tokenizer(sentence)["input_ids"]) + '\n')

# Count how many qualified sentences there are

In [None]:
# Length window (both bounds exclusive) a tokenized sentence must fall in.
max_token = 49
min_token = 10

# Count the tokenized sentences whose length lies strictly between the bounds.
with open("english_wikipedia_tokenized.jsonl", 'r') as token_file:
    counter = sum(
        1
        for raw_line in tqdm(token_file)
        if min_token < len(json.loads(raw_line)) < max_token
    )
print(counter)

# Store tokenized data into tensor

In [None]:
# Same length window (both bounds exclusive) as the counting cell.
max_token = 49
min_token = 10
# `counter` still holds the number of qualifying sentences from the counting
# cell; it sizes the tensor and is then reused as the next-row index.
shape = (counter, 50)
counter = 0

eos_token_id = tokenizer.eos_token_id
tokenized_tensor = torch.empty(shape, dtype=torch.int16)

with open("english_wikipedia_tokenized.jsonl", 'r') as token_file:
    for raw_line in tqdm(token_file):
        token_ids = json.loads(raw_line)
        if not (min_token < len(token_ids) < max_token):
            continue
        # Terminate with EOS, then right-pad with id 32000 up to 50 entries.
        token_ids.append(eos_token_id)
        token_ids.extend([32000] * (50 - len(token_ids)))
        tokenized_tensor[counter] = torch.tensor(token_ids, dtype=torch.int16)
        counter += 1

torch.save(tokenized_tensor, "llama2_wiki_50.pt")

# Sort the tensor by the number of tokens per sequence

In [None]:
pad_token_id = 32000
tensor = torch.load("llama2_wiki_50.pt")

# Number of non-padding entries in each row.
lengths = tensor.ne(pad_token_id).sum(dim=1)
# Reorder rows from shortest to longest.
sorted_lengths, sorted_indices = lengths.sort()
sorted_tensor = tensor.index_select(0, sorted_indices)

torch.save(sorted_tensor, "llama2_wiki_50_ranked.pt")

# Split into train and eval

In [None]:
tensor = torch.load("llama2_wiki_50_ranked.pt")

# Fraction of rows held out for evaluation.
eval_ratio = 0.05
n = tensor.shape[0]
n_eval = int(n * eval_ratio)

# Draw the random split from an explicitly seeded generator so that re-running
# the notebook produces the same train/eval partition.
split_generator = torch.Generator().manual_seed(42)
perm = torch.randperm(n, generator=split_generator)
eval_indices = perm[:n_eval]
train_indices = perm[n_eval:]

# Sort indices to preserve original order
eval_indices, _ = torch.sort(eval_indices)
train_indices, _ = torch.sort(train_indices)

# Split tensors while keeping order
eval_tensor = tensor[eval_indices]
train_tensor = tensor[train_indices]

# Every row must end up in exactly one of the two splits.
assert eval_tensor.shape[0] + train_tensor.shape[0] == n

torch.save(eval_tensor, "llama2_wiki_50_ranked_eval_sorted.pt")
torch.save(train_tensor, "llama2_wiki_50_ranked_train_sorted.pt")

# Batch the tensor

In [None]:
def shuffle_in_chunks(tensor, chunk_size, generator: torch.Generator):
    """Shuffle the rows of ``tensor`` in whole chunks of ``chunk_size`` rows.

    Rows inside a chunk keep their relative order; only the order of the
    chunks is permuted. Trailing rows that do not fill a complete chunk are
    dropped from the result.

    Args:
        tensor: Tensor whose first dimension is shuffled chunk-wise.
        chunk_size: Number of consecutive rows forming one chunk.
        generator: Random generator used to draw the chunk permutation.

    Returns:
        A tensor with ``(tensor.shape[0] // chunk_size) * chunk_size`` rows.
    """
    num_chunks = tensor.shape[0] // chunk_size
    chunk_order = torch.randperm(num_chunks, generator=generator)

    # Row i of chunk c lives at index c * chunk_size + i.
    within_chunk = torch.arange(chunk_size)
    row_ids = (chunk_order.unsqueeze(1) * chunk_size + within_chunk).flatten()

    return tensor[row_ids]

batch_size = 128
g = torch.Generator().manual_seed(42)

train_tensor = torch.load("llama2_wiki_50_ranked_train_sorted.pt")
eval_tensor = torch.load("llama2_wiki_50_ranked_eval_sorted.pt")

# Shuffle whole batches of `batch_size` rows. Train is shuffled before eval,
# and both draw from the same seeded generator.
train_tensor_shuffled = shuffle_in_chunks(train_tensor, batch_size, g).numpy()
eval_tensor_shuffled = shuffle_in_chunks(eval_tensor, batch_size, g).numpy()

# Persist both splits as NumPy arrays.
for out_path, batched_rows in (
    ("llama2_wiki_50_train_new_batch_128.npy", train_tensor_shuffled),
    ("llama2_wiki_50_eval_new_batch_128.npy", eval_tensor_shuffled),
):
    np.save(out_path, batched_rows)

# Load token embedding and store it.

In [None]:
model_id = "NousResearch/Llama-2-7b-hf"
# Load the model with 4-bit NF4 quantization (double quantization, bfloat16
# compute); only the token-embedding matrix is needed from it.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="right")
# Use the same padding id (32000) that the preprocessing cells write into the
# token tensors. The previous value, 128002, did not match it and looks like a
# leftover from a different model's vocabulary.
# NOTE(review): confirm this checkpoint's tokenizer maps id 32000 to a pad token.
tokenizer.pad_token_id = 32000

# Access the embedding matrix
word_embeddings_tensor = model.model.embed_tokens.weight.data

# Delete the Llama-2 model because we are no longer using it.
del model

# Store vocabulary size and embedding dimension
num_embeddings, embedding_dim = word_embeddings_tensor.shape
word_embeddings_tensor.requires_grad = False

# NOTE(review): if the embedding table has no row for the padding id, downstream
# code must handle that id separately — confirm against num_embeddings.
torch.save(word_embeddings_tensor, 'word_embeddings_tensor_llama2.pt')

In [None]:
import matplotlib.pyplot as plt

PAD_ID = 32000

# Inspect the length-sorted tensor this notebook writes
# ("llama2_wiki_50_ranked.pt"). The previous path, "llama2_wiki_64_ranked.pt",
# is not produced anywhere in this notebook.
tokens = torch.load("llama2_wiki_50_ranked.pt")
# Derive the axis range from the data instead of hard-coding the row width.
seq_len = tokens.shape[1]

# Count how many padding tokens per row
pad_counts = (tokens == PAD_ID).sum(dim=1)

# Convert to CPU numpy for plotting
pad_counts_np = pad_counts.cpu().numpy()

# Plot histogram with one bin per possible padding count
plt.hist(pad_counts_np, bins=range(seq_len + 1), edgecolor='black', align='left')
plt.title("Histogram of Padding Tokens per Sequence")
plt.xlabel("Number of Padding Tokens")
plt.ylabel("Number of Sequences")
plt.xlim(0, seq_len)
plt.show()