In [1]:
import json
import os
import bz2
import pprint
from typing import List
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import array
import zlib
 
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import os
from itertools import cycle

In [2]:
model_name = "gpt2"
CONTEXT_SIZE = 256
BATCH_SIZE = 64

ranks = []

In [3]:
tokenizer_name = model_name
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


def text_to_tokens(text):
    tokens = tokenizer(text, return_tensors="pt")
    tokens = tokens["input_ids"].squeeze()
    return tokens

def pad(tokens, padding_val):
    pad_len = CONTEXT_SIZE - tokens.shape[0] % CONTEXT_SIZE
    print("pad_len", pad_len)
    if pad_len != CONTEXT_SIZE:
        padding = torch.tensor([padding_val]*pad_len)
        tokens = torch.cat((tokens, padding))
    else:
        pad_len = 0

    return tokens, pad_len

# @torch.no_grad()
def get_logits(tokens, token_index, past=None):
    my_inputs = {}
    my_inputs['input_ids'] = tokens[:, token_index].reshape(-1, 1)

    output = model(**my_inputs, past_key_values=past)
    logits = output.logits
    if len(logits.shape) > 2:
        logits = logits.reshape((logits.shape[0], -1))
    return logits, output.past_key_values
    
def encode_one_batch(
    tokens,
    token_index,
    past=None
):
    assert len(tokens.shape) == 2

    logits, past = get_logits(tokens, token_index, past)
    assert len(logits.shape) == 2
    logits, sorted_tokens = torch.sort(logits, descending=True)

    assert len(sorted_tokens.shape) == 2

    next_tokens = tokens[:, token_index + 1]
    next_tokens_expanded = next_tokens.view(-1, 1).expand_as(sorted_tokens)

    # Find score as index of next tokens
    scores = (sorted_tokens == next_tokens_expanded).nonzero(as_tuple=True)

    scores = scores[1] # remove index column

    ranks.extend(scores.tolist())

    return scores, past

In [4]:
def decode(scores):
    output_tokens = decode_tokens(scores)
    text = tokenizer.batch_decode(output_tokens)
    text = "".join(text)
    #text = text.replace("<|endoftext|>", "")
    return text

def decode_tokens(scores):

    scores, pad_len = pad(scores, tokenizer.eos_token_id)

    scores = scores.view(-1, CONTEXT_SIZE) # all rows, CONTEXT_SIZE

    output_tokens = torch.zeros(scores.shape, dtype=int, device=scores.device)

    # Add eos to the start of each block (to give it somewhere to start)
    eos = torch.tensor([tokenizer.eos_token_id]*output_tokens.shape[0]).unsqueeze(1)
    output_tokens = torch.cat((eos, output_tokens), 1) # all rows, CONTEXT_SIZE + 1


    batches = scores.shape[0] // BATCH_SIZE
    if scores.shape[0] % BATCH_SIZE != 0:
        batches += 1

    # score each batch
    print("Decoding")
    for i in range(batches):
        print(i, "out of", batches)
        cur_scores = scores[i*BATCH_SIZE:(i + 1)*BATCH_SIZE] # BATCH_SIZE, CONTEXT_SIZE

        cur_output_tokens = output_tokens[i*BATCH_SIZE:(i + 1)*BATCH_SIZE] # BATCH_SIZE, CONTEXT_SIZE
        
        past = None
        for j in tqdm(range(scores.shape[1])):

            cur_output_tokens[:, j+1], past = decode_one_batch(cur_output_tokens, cur_scores, j, past) # BATCH_SIZE

        output_tokens[i*BATCH_SIZE:(i + 1)*BATCH_SIZE] = cur_output_tokens

    output_tokens = output_tokens[:, 1:].int()
    output_tokens = output_tokens.flatten()

    if pad_len != 0:
        output_tokens = output_tokens[:-pad_len]

    return output_tokens

def decode_one_batch(input_tokens, scores, score_index, past=None):
    assert len(scores.shape) == 2
    logits, past = get_logits(input_tokens, score_index, past)

    logits, sorted_tokens = torch.sort(logits, descending=True)
    assert len(sorted_tokens.shape) == 2
    # the scores give the indexes of the decoded tokens
    indexes = scores[:, score_index].flatten()
    decoded_tokens = sorted_tokens[torch.arange(indexes.shape[0]), indexes]

    return decoded_tokens.int(), past


In [5]:
s = "Artificial intelligence (AI) has undergone a remarkable transformation over the past few decades, revolutionizing various industries and fundamentally changing the way humans interact with technology. The journey of AI began with early theoretical concepts in the 1950s, when pioneers like Alan Turing proposed the idea of intelligent machines capable of reasoning and problem-solving. Over the years, advancements in computing power, data availability, and algorithmic innovation have fueled AI's rapid progress, leading to the development of sophisticated models that can perform complex tasks with remarkable accuracy. One of the most significant applications of AI is in healthcare. AI-powered systems assist medical professionals in diagnosing diseases, predicting patient outcomes, and recommending personalized treatment plans. Machine learning models trained on vast datasets of medical images can detect anomalies such as tumors in X-rays, MRIs, and CT scans with high precision. Natural language processing (NLP) enables AI chatbots and virtual assistants to interact with patients, answering common health-related queries and reducing the burden on healthcare providers. Furthermore, AI-driven drug discovery accelerates the development of new treatments by analyzing molecular structures and predicting their potential efficacy. These innovations not only enhance patient care but also contribute to reducing healthcare costs and improving overall efficiency in the medical sector. The financial industry has also been profoundly impacted by AI. Banks and investment firms leverage AI-driven algorithms to detect fraudulent transactions, assess credit risk, and optimize trading strategies. High-frequency trading systems utilize machine learning to analyze market trends and execute trades at speeds beyond human capability. Personalized financial assistants powered by AI provide users with tailored investment advice based on their financial history, risk tolerance, and goals. However, AI in finance is not without challenges—bias in algorithms, data privacy concerns, and the risk of over-reliance on automated decision-making remain critical issues that regulators and institutions must address to ensure fair and transparent financial practices. In the realm of education, AI is reshaping traditional teaching methods. Adaptive learning platforms use AI to tailor educational content to individual students based on their learning pace, strengths, and weaknesses. Automated grading systems reduce the workload of teachers by efficiently evaluating assignments and exams. AI-powered language translation tools help break down language barriers, making quality education accessible to students worldwide. However, the integration of AI in education also raises concerns about data privacy, algorithmic fairness, and the potential loss of human interaction in the learning process. Striking a balance between technological advancement and pedagogical effectiveness is crucial to maximizing the benefits of AI in education. AI's influence extends to entertainment and media as well. Streaming platforms such as Netflix, Spotify, and YouTube utilize AI-driven recommendation systems to personalize content for users. These algorithms analyze viewing, listening, and browsing histories to suggest movies, music, and articles that align with user preferences. AI-generated content, including deepfake videos, music compositions, and even articles, demonstrates the potential of machine learning in creative fields. While AI enhances user experience and content curation, it also raises ethical concerns regarding misinformation, deepfake manipulation, and intellectual property rights. Ensuring responsible AI use in the media industry is crucial to maintaining trust and credibility."

print("String length:", len(s))

String length: 3796


In [None]:
ranks = []

# generate the next token
def get_logits(tokens, token_index, past=None):
    my_inputs = {}
    my_inputs['input_ids'] = tokens[:, token_index].reshape(-1, 1)

    output = model(**my_inputs, past_key_values=past)
    logits = output.logits
    if len(logits.shape) > 2:
        logits = logits.reshape((logits.shape[0], -1))
    return logits, output.past_key_values

# %timeit get_rank(probas, targets)

def encode_one_batch(
    tokens,
    token_index,
    past=None
):
    assert len(tokens.shape) == 2

    logits, past = get_logits(tokens, token_index, past)
    assert len(logits.shape) == 2
    sorted_logits, sorted_tokens = torch.sort(logits, descending=True)

    assert len(sorted_tokens.shape) == 2

    next_tokens = tokens[:, token_index + 1]
    print("-"*30)
    print(next_tokens)
    next_tokens_expanded = next_tokens.view(-1, 1).expand_as(sorted_tokens)
    print(next_tokens_expanded.shape)
    print("-"*30)

    # the indices that match the next token
    scores = (sorted_tokens == next_tokens_expanded).nonzero(as_tuple=True)
    print(scores)

    scores = scores[1] # remove index column

    ranks.extend(scores.tolist())

    return scores, past

# tokenization
tokens = text_to_tokens(s)
print("tokens length:", len(tokens))

# padding
tokens, pad_len = pad(tokens, tokenizer.eos_token_id)
tokens = tokens.view(-1, CONTEXT_SIZE)

output_scores = torch.zeros(tokens.shape)

eos = torch.tensor([tokenizer.eos_token_id]*tokens.shape[0]).unsqueeze(1)
tokens = torch.cat((eos, tokens), 1)

batches = tokens.shape[0] // BATCH_SIZE
if tokens.shape[0] % BATCH_SIZE != 0:
    batches += 1

# score each batch
print("Encoding")

# mini-chunks
with torch.no_grad():
    for i in range(batches):
        print(i, "out of", batches)

        cur_tokens = tokens[i*BATCH_SIZE:(i + 1)*BATCH_SIZE]  # (n_chunks, n_tokens)
        cur_output_scores = torch.zeros((cur_tokens.shape[0], cur_tokens.shape[1] - 1))  # -1 for auto-regressive
        past = None

        # print(cur_tokens)
        for j in range(cur_tokens.shape[1] - 1):
            cur_output_scores[:, j], past = encode_one_batch(cur_tokens, j, past)
            if j == 2: break
        
        output_scores[i*BATCH_SIZE:(i + 1)*BATCH_SIZE] = cur_output_scores
        del cur_tokens
        break

    # torch.cuda.empty_cache()

    # output_scores = output_scores.flatten().int()
    # if pad_len > 0:
    #     output_scores = output_scores[:-pad_len]

tokens length: 635
pad_len 133
Encoding
0 out of 1
------------------------------
tensor([ 8001, 20755,   290])
torch.Size([3, 50257])
------------------------------
(tensor([0, 1, 2]), tensor([  289, 45802,    87]))
------------------------------
tensor([9542,  416, 2056])
torch.Size([3, 50257])
------------------------------
(tensor([0, 1, 2]), tensor([   1,    0, 1150]))
------------------------------
tensor([4430, 9552,  355])
torch.Size([3, 50257])
------------------------------
(tensor([0, 1, 2]), tensor([  0, 563,  82]))


In [None]:
tensor_bytes = output_scores.numpy().tobytes()

# Compress the tensor bytes using bz2
compressed_bytes = bz2.compress(tensor_bytes)
with open('./compressed.gpz', "wb") as f:
    f.write(compressed_bytes)

In [None]:
with open('./compressed.gpz', "rb") as f:
    zipped = f.read()
unzipped = bz2.decompress(zipped)
unzipped = array.array("H", unzipped)
decoded = decode(torch.tensor(unzipped))

In [None]:
decoded