In [None]:
import regex as re
import torch # we use PyTorch: https://pytorch.org
import torch.nn as nn
import torch.optim as optim

In [None]:
'''[1] Priyanthan Govindaraj. Build and Train GPT-4 Tokenizer from scratch, 2024. URL https://medium.com/@govindarajpriyanthan/build-and-train-gpt-4-tokenizer-from-scratch-ad90d3af0f11'''
class GPT4Tokenizer:
    """Byte-level BPE tokenizer using the GPT-4 (cl100k-style) pre-tokenization regex.

    Text is split into chunks with ``self.pattern`` and each chunk is BPE-merged
    independently, so merges never cross chunk boundaries. Ids 0-255 are the raw
    UTF-8 bytes; ids ``256 .. vocab_size - 1`` are merges learned by ``train``.

    Args:
        vocab_size: Target vocabulary size (must be >= 256). ``train`` learns
            ``vocab_size - 256`` merges. Defaults to 276 (20 merges).
    """

    def __init__(self, vocab_size=276):
        # GPT-4 split pattern. \p{L}/\p{N} need the third-party `regex` module;
        # the stdlib `re` module rejects them.
        self.pattern = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
        self.vocab_size = vocab_size
        self.merges = {}  # (left_id, right_id) -> merged id
        self.vocab = self._base_vocab()  # id -> bytes

    @staticmethod
    def _base_vocab():
        """Return the 256 single-byte tokens (id ``i`` -> ``bytes([i])``)."""
        return {idx: bytes([idx]) for idx in range(256)}

    # Find consecutive pairs
    def get_stats(self, token_ids, stats=None):
        """Count adjacent id pairs in ``token_ids``.

        Args:
            token_ids: Sequence of token ids.
            stats: Optional dict to accumulate into (updated in place). A new dict
                is created when omitted.

        Returns:
            The ``stats`` dict, mapping ``(left, right)`` pairs to occurrence counts.
        """
        if stats is None:
            stats = {}
        for pair in zip(token_ids, token_ids[1:]):
            stats[pair] = stats.get(pair, 0) + 1
        return stats

    # Merge token ids
    def merge(self, token_ids, pair, new_index):
        """Replace every non-overlapping occurrence of ``pair`` with ``new_index``.

        The scan runs left to right, so ``[a, a, a]`` merged on ``(a, a)`` gives
        ``[new_index, a]``.

        Returns:
            A new list; ``token_ids`` is not modified.
        """
        _token_ids = []
        i = 0
        while i < len(token_ids):
            if (i < len(token_ids) - 1) and (token_ids[i] == pair[0]) and (token_ids[i + 1] == pair[1]):
                _token_ids.append(new_index)
                i += 2
            else:
                _token_ids.append(token_ids[i])
                i += 1
        return _token_ids

    def train(self, text, verbose=False):
        """Learn up to ``vocab_size - 256`` BPE merges from ``text``.

        Any previously learned merges are discarded first. This means a second call
        starts again from the 256 byte tokens, instead of leaving stale merges behind
        that reuse the same ids.

        Args:
            text: Training text, or an iterable of strings. Each string is split into
                chunks on its own, so no merge spans two strings.
            verbose: Print each merge as it is learned.

        Returns:
            The ``merges`` dict mapping ``(left, right)`` -> new id. It may hold fewer
            than ``vocab_size - 256`` entries if the text runs out of pairs (a warning
            is printed in that case).
        """
        assert self.vocab_size >= 256
        num_merges = self.vocab_size - 256

        # Reset so repeated calls don't mix two merge tables that share ids.
        self.merges = {}
        self.vocab = self._base_vocab()

        texts = [text] if isinstance(text, str) else list(text)
        token_ids = [
            list(chunk.encode('utf-8'))
            for doc in texts
            for chunk in re.findall(self.pattern, doc)
        ]

        for i in range(num_merges):
            stats = {}
            for chunk_token in token_ids:
                self.get_stats(chunk_token, stats)
            if not stats:
                print("Warning: Stats dictionary is empty. Training might be incomplete. Solution: vocab_size might be too big. Adjust it")
                break
            # Ties go to the pair counted first (max() keeps the first maximal key).
            top_pair = max(stats, key=stats.get)
            index = 256 + i
            if verbose:
                print(f"merged : {top_pair} -> {index}")

            token_ids = [self.merge(chunk_token, top_pair, index) for chunk_token in token_ids]

            self.vocab[index] = self.vocab[top_pair[0]] + self.vocab[top_pair[1]]
            self.merges[top_pair] = index
        return self.merges

    # encode chunk
    def encode_chunks(self, chunk_bytes):
        """BPE-encode one pre-tokenized chunk (bytes) into token ids.

        Each step applies the learned merge with the lowest id present in the chunk,
        i.e. the earliest-learned one, which reproduces the order used by ``train``.
        """
        chunk_token_ids = list(chunk_bytes)
        while len(chunk_token_ids) >= 2:
            stats = self.get_stats(chunk_token_ids)
            pair = min(stats, key=lambda x: self.merges.get(x, float("inf")))
            if pair not in self.merges:
                break  # no remaining pair has a learned merge
            chunk_token_ids = self.merge(chunk_token_ids, pair, self.merges[pair])
        return chunk_token_ids

    # encode full text
    def encode(self, text):
        """Encode a string into token ids.

        The string is split with ``self.pattern`` and each chunk is encoded
        independently. An untrained tokenizer returns the raw UTF-8 byte ids.
        """
        token_ids = []
        for chunk in re.findall(self.pattern, text):
            token_ids.extend(self.encode_chunks(chunk.encode("utf-8")))
        return token_ids

    # decoding
    def decode(self, token_ids):
        """Decode token ids back to text.

        Invalid UTF-8 byte sequences are replaced with U+FFFD.

        Raises:
            ValueError: if an id is not in ``self.vocab``.
        """
        chunk_bytes = []
        for token in token_ids:
            if token in self.vocab:
                chunk_bytes.append(self.vocab[token])
            else:
                raise ValueError(f"Invalid token id: {token}")

        b_tokens_ids = b"".join(chunk_bytes)
        text = b_tokens_ids.decode('utf-8', errors="replace")
        return text

In [None]:
def batch_predict(model, input_texts, tokenizer, block_size, pad_id=0):
    """Classify several texts in a single forward pass.

    Each text is encoded with ``tokenizer``, truncated to ``block_size`` ids and
    right-padded with ``pad_id``, so the batch is a rectangular tensor.

    Args:
        model: Module/callable with ``eval()`` that maps a (batch, block_size) id
            tensor to one probability per row (e.g. shape (batch, 1)).
        input_texts: Iterable of strings.
        tokenizer: Object with an ``encode(str) -> list[int]`` method.
        block_size: Fixed sequence length fed to the model.
        pad_id: Id used for right padding. Defaults to 0, which is also the
            byte-level id of the NUL character.

    Returns:
        A list of ``(text, "Positive" | "Negative", probability)`` tuples in input
        order; a probability above 0.5 is labelled "Positive". Empty input gives ``[]``.
    """
    model.eval()
    # Materialise once: the texts are iterated twice (encoding and zip with results).
    texts = list(input_texts)
    if not texts:
        return []

    rows = []
    for text in texts:
        ids = tokenizer.encode(text)[:block_size]
        # NOTE(review): with right padding the LSTM's final hidden state is taken after
        # the pad ids for short texts; left padding may suit this model better.
        rows.append(ids + [pad_id] * (block_size - len(ids)))

    input_tensor = torch.tensor(rows, dtype=torch.long)
    with torch.no_grad():
        # reshape(-1) instead of squeeze(): squeeze() turns a one-row batch into a 0-d
        # tensor whose tolist() is a bare float, which zip() cannot iterate.
        probabilities = model(input_tensor).reshape(-1).tolist()
    return [
        (text, "Positive" if prob > 0.5 else "Negative", prob)
        for text, prob in zip(texts, probabilities)
    ]

In [None]:
def get_batch(split):
    """Sample a random batch of next-token (input, target) windows.

    Reads the notebook globals ``train_data`` / ``val_data`` (1-D token tensors),
    ``block_size`` and ``batch_size``.

    Args:
        split: ``'train'`` selects ``train_data``; any other value selects ``val_data``.

    Returns:
        ``(x, y)``: two ``(batch_size, block_size)`` tensors, where each row of ``y``
        is the matching row of ``x`` shifted one token ahead.

    Raises:
        ValueError: if the selected split has no more than ``block_size`` tokens,
            so no full (input, target) window fits.
    """
    data = train_data if split == 'train' else val_data
    # Each window needs block_size inputs plus one extra token for the shifted target.
    max_start = len(data) - block_size
    if max_start < 1:
        raise ValueError(
            f"{split!r} split has {len(data)} tokens; need more than "
            f"block_size={block_size} to sample a window"
        )
    ix = torch.randint(max_start, (batch_size,))
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y

In [None]:
# Define the model
class BinaryClassificationModel(nn.Module):
    """Embedding -> LSTM -> linear -> sigmoid sequence classifier.

    The final hidden state of the last LSTM layer summarises the sequence and is
    projected to ``output_dim`` sigmoid outputs.

    Args:
        vocab_size: Number of rows in the embedding table (largest token id + 1).
        embedding_dim: Size of each token embedding.
        hidden_dim: LSTM hidden-state size.
        output_dim: Number of sigmoid outputs (1 for binary classification).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        # Keep this registration order: it fixes how parameter init consumes the
        # RNG and the state_dict key names.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map token ids to per-sequence probabilities.

        Assumes ``x`` is a LongTensor of ids shaped (batch, seq), because the LSTM is
        built with ``batch_first=True``. Returns a (batch, output_dim) tensor.
        """
        lstm_result = self.lstm(self.embedding(x))
        # lstm_result is (output, (h_n, c_n)); for batched input h_n is
        # (num_layers, batch, hidden_dim), so h_n[-1] is the last layer's state.
        final_hidden = lstm_result[1][0][-1]
        return self.sigmoid(self.fc(final_hidden))

# Hyperparameters
vocab_size = GPT4Tokenizer().vocab_size  # embedding table must cover every tokenizer id
embedding_dim = 64
hidden_dim = 128
output_dim = 1  # Binary classification

# Seed before building the model so the weight initialisation is reproducible.
torch.manual_seed(1337)

# Instantiate the model
model = BinaryClassificationModel(vocab_size, embedding_dim, hidden_dim, output_dim)
loss_fn = nn.BCELoss()  # Binary Cross-Entropy on probabilities (the model ends with a sigmoid)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
def train_model(model, train_data, val_data, epochs=5):
    """Run ``epochs`` single-batch optimisation steps, each followed by a validation loss.

    Each step draws one random batch of ``batch_size`` windows of ``block_size``
    tokens from ``train_data`` and, for validation, one from ``val_data``. It uses
    the notebook globals ``loss_fn``, ``optimizer``, ``block_size`` and ``batch_size``,
    and prints the training and validation loss for every step.

    NOTE(review): the binary labels are a placeholder. Each target window is
    thresholded at token id 128, and the model's single per-sequence probability is
    compared against every position. Replace this with real sentiment labels.

    Args:
        model: Module mapping (batch, block_size) token ids to (batch, 1) probabilities.
        train_data: 1-D tensor of token ids to sample training windows from.
        val_data: 1-D tensor of token ids to sample validation windows from.
        epochs: Number of optimisation steps.

    Raises:
        ValueError: if a split has no more than ``block_size`` tokens.
    """
    def _sample_windows(data):
        # Same sampling as get_batch(), but over the tensor passed to this function
        # (get_batch reads the train_data/val_data globals, not these arguments).
        max_start = len(data) - block_size
        if max_start < 1:
            raise ValueError(
                f"need more than block_size={block_size} tokens to sample a window, got {len(data)}"
            )
        ix = torch.randint(max_start, (batch_size,))
        x = torch.stack([data[i:i + block_size] for i in ix])
        y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
        return x, y

    for epoch in range(epochs):
        model.train()
        xb, yb = _sample_windows(train_data)
        yb = (yb > 128).float()  # Example binary labels (adjust this)

        # Forward pass: repeat the one probability per sequence across all label positions.
        outputs = model(xb).view(-1, 1).repeat(1, block_size)
        loss = loss_fn(outputs, yb)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

        # Validation step
        model.eval()
        with torch.no_grad():
            val_xb, val_yb = _sample_windows(val_data)
            val_yb = (val_yb > 128).float()  # Adjust binary labels for validation

            val_outputs = model(val_xb).view(-1, 1).repeat(1, block_size)
            val_loss = loss_fn(val_outputs, val_yb)
            print(f"Validation Loss: {val_loss.item():.4f}")

# Train the model
# FIXME: train_data / val_data / block_size / batch_size are created by the data cell
# further down, so on a fresh kernel (Restart & Run All) they do not exist yet here.
# Move this call after that cell; until then, skip instead of failing with NameError.
if all(name in globals() for name in ("train_data", "val_data", "block_size", "batch_size")):
    train_model(model, train_data, val_data, epochs=10)
else:
    print("Skipping training: run the data-preparation cell first, then re-run this cell.")

# Prediction example
def predict(model, input_text, tokenizer=None):
    """Classify one text, print the raw model output and return the label.

    Args:
        model: Module/callable with ``eval()`` that maps a (1, seq) id tensor to a
            single probability.
        input_text: Text to classify.
        tokenizer: Object with ``encode(str) -> list[int]``. Defaults to a fresh,
            untrained ``GPT4Tokenizer()``, which yields raw byte ids only. Pass the
            tokenizer used for the training data to keep ids consistent.

    Returns:
        "Positive" if the probability is above 0.5, else "Negative".

    Raises:
        ValueError: if ``input_text`` encodes to no tokens. An empty sequence cannot
            be fed to the LSTM.
    """
    model.eval()
    if tokenizer is None:
        tokenizer = GPT4Tokenizer()
    tokens = tokenizer.encode(input_text)
    if not tokens:
        raise ValueError("input_text produced no tokens; cannot classify an empty sequence")
    input_tensor = torch.tensor(tokens, dtype=torch.long).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        probability = model(input_tensor).item()
    print(f"Output Prediction: {probability:.4f}")
    return "Positive" if probability > 0.5 else "Negative"

In [None]:
# Example usage
sample_texts = [
    "The food was amazing, and I loved the service!",
    "The steak was cold and overcooked.",
    "Absolutely fantastic experience!"
]
# GPT4Tokenizer.encode()/train() split a single string with re.findall, so passing
# the list raises TypeError: join the examples into one corpus instead.
corpus = "\n".join(sample_texts)
# NOTE(review): `tokens` comes from an untrained tokenizer, i.e. raw UTF-8 byte ids,
# matching the fresh GPT4Tokenizer() used at prediction time. The merges learned into
# `training` (by a separate instance) are not applied to `data`.
tokens = GPT4Tokenizer().encode(corpus)
training = GPT4Tokenizer().train(corpus)
data = torch.tensor(tokens, dtype=torch.long)

n = int(0.9 * len(data))  # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

torch.manual_seed(1337)
batch_size = 4  # how many independent sequences will we process in parallel?
block_size = 8  # what is the maximum context length for predictions?

# get_batch needs more than block_size tokens in a split to sample one window.
assert len(val_data) > block_size, "validation split is too short for block_size"
xb, yb = get_batch('train')

In [None]:
# Size the batch to the longest example. A fixed block_size=16 kept only the first
# 16 bytes of each review (e.g. "The steak was co"), which cut off the words that carry
# the sentiment. Shorter examples are right-padded by batch_predict.
predict_tokenizer = GPT4Tokenizer()
predict_block_size = max(len(predict_tokenizer.encode(text)) for text in sample_texts)
results = batch_predict(model, sample_texts, predict_tokenizer, block_size=predict_block_size)
for text, label, value in results:
    print(f"Text: {text} | Prediction: {label} | Confidence: {value:.4f}")