# Trigram Language Model

In [1]:
# Imports
import ast
import json
import math
import random
import sys
from collections import Counter, defaultdict

sys.path.append("../tokenizer")
from bpe_tokenizer import get_tokenizer, load_tokenizer, EOT, EOS, EOP

In [2]:
# Load the tokenized corpus and the tokenizer's symbol tables.
with open("../tokenizer/tokenized_corpus.json", "r", encoding="utf-8") as f:
    corpus = json.load(f)

merges, char2id, id2char = load_tokenizer("../tokenizer")

EOT_ID = char2id[EOT]

tokenized_stories = [story["tokens"] for story in corpus]

# NOTE(review): char2id presumably holds only the base symbols, while BPE merges
# add further ids on top of them. Size the vocabulary so every id that actually
# occurs in the corpus is covered by add-one smoothing and by sampling in
# TrigramLanguageModel.generate (which iterates range(vocab_size)).
max_token_id = max((tok for tokens in tokenized_stories for tok in tokens), default=-1)
VOCAB_SIZE = max(len(char2id), max_token_id + 1)

print(f"Loaded {len(tokenized_stories)} stories, Vocab size: {VOCAB_SIZE}, EOT_ID: {EOT_ID}")

Loaded 200 stories, Vocab size: 65, EOT_ID: 0


In [19]:
# Trigram Language Model with Interpolation and Perplexity Calculation
class TrigramLanguageModel:
    """Interpolated trigram model over integer token ids with add-one smoothing.

    Unigram, bigram and trigram counts are collected within each story (n-grams
    never span two stories). Each order gets an add-one (Laplace) estimate, and
    the three are mixed linearly with caller-supplied weights ``l1`` (unigram),
    ``l2`` (bigram) and ``l3`` (trigram). Token ids are treated as the integers
    ``0 .. vocab_size - 1``.
    """

    def __init__(self, vocab_size):
        """Create an empty model.

        Args:
            vocab_size: Initial vocabulary size. It is the add-one smoothing
                constant and the id range sampled by ``generate``. ``train``
                enlarges it if the data contains larger ids.
        """
        self.unigrams = Counter()
        self.bigrams = Counter()
        self.trigrams = Counter()
        self.total_tokens = 0
        self.vocab_size = vocab_size

    def train(self, tokenized_stories):
        """Add n-gram counts from an iterable of token-id sequences.

        Counts accumulate, so repeated calls extend the model. If a story holds
        an id >= ``vocab_size``, ``vocab_size`` is widened to ``max_id + 1``.
        Otherwise that id would get no smoothing mass and could never be
        generated.
        """
        for story in tokenized_stories:
            self.total_tokens += len(story)
            self.unigrams.update(story)
            self.bigrams.update(zip(story, story[1:]))
            self.trigrams.update(zip(story, story[1:], story[2:]))
            if len(story) > 0:
                self.vocab_size = max(self.vocab_size, max(story) + 1)

    def unigram_prob(self, w):
        """Add-one estimate (c(w) + 1) / (total_tokens + vocab_size)."""
        return (self.unigrams[w] + 1) / (self.total_tokens + self.vocab_size)

    def bigram_prob(self, w1, w2):
        """Add-one estimate of P(w2 | w1).

        The denominator uses the unigram count of ``w1``. That count includes
        story-final occurrences that have no successor, so the distribution
        over ``w2`` can sum to slightly less than 1.
        """
        denom = self.unigrams[w1] + self.vocab_size
        return (self.bigrams[(w1, w2)] + 1) / denom

    def trigram_prob(self, w1, w2, w3):
        """Add-one estimate of P(w3 | w1, w2).

        The denominator uses the bigram count of ``(w1, w2)``. As with
        ``bigram_prob``, story-final bigrams are included there.
        """
        denom = self.bigrams[(w1, w2)] + self.vocab_size
        return (self.trigrams[(w1, w2, w3)] + 1) / denom

    def interpolated_prob(self, w1, w2, w3, l1, l2, l3):
        """Return l1*P(w3) + l2*P(w3|w2) + l3*P(w3|w1,w2)."""
        p1 = self.unigram_prob(w3)
        p2 = self.bigram_prob(w2, w3)
        p3 = self.trigram_prob(w1, w2, w3)
        return l1 * p1 + l2 * p2 + l3 * p3

    def perplexity(self, tokenized_stories, l1, l2, l3):
        """Per-token perplexity of the interpolated model on ``tokenized_stories``.

        Only positions with two tokens of in-story history are scored, so the
        first two tokens of each story are conditioning context only.

        Raises:
            ValueError: if no story has at least 3 tokens (nothing to score).
        """
        log_prob_sum = 0.0
        n_scored = 0

        for story in tokenized_stories:
            for w1, w2, w3 in zip(story, story[1:], story[2:]):
                prob = self.interpolated_prob(w1, w2, w3, l1, l2, l3)
                # A non-positive probability is only possible with zero/negative
                # lambdas; floor it so one event cannot make log() undefined.
                log_prob_sum += math.log(prob if prob > 0 else 1e-10)
                n_scored += 1

        if n_scored == 0:
            raise ValueError("perplexity needs at least one story with 3 or more tokens")
        return math.exp(-log_prob_sum / n_scored)

    def generate(self, prefix_tokens, max_length, l1, l2, l3, eot_id, temperature=0.5, rng=None):
        """Sample up to ``max_length`` new tokens after ``prefix_tokens``.

        At each step, every id in ``range(vocab_size)`` is scored with
        ``interpolated_prob`` given the last two tokens. The scores are
        reweighted as ``p ** (1 / temperature)``, renormalised, and one id is
        drawn.

        Args:
            prefix_tokens: Prompt token ids. A prompt shorter than two tokens is
                left-padded so a trigram context exists: its first token is
                repeated, or ``eot_id`` is used when the prompt is empty. The
                padding is part of the returned list.
            max_length: Maximum number of new tokens to sample.
            l1, l2, l3: Interpolation weights.
            eot_id: Id that stops generation; it is included in the output.
            temperature: Sampling temperature; must be > 0. Values below 1
                sharpen the distribution and values above 1 flatten it.
            rng: Optional ``random.Random`` instance for reproducible sampling.
                When None, the module-level ``random`` functions are used.

        Returns:
            The (possibly padded) prefix followed by the sampled ids.

        Raises:
            ValueError: if ``temperature`` is not positive, or every candidate
                has zero probability (e.g. all lambdas are zero).
        """
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        sampler = random if rng is None else rng
        exponent = 1.0 / temperature
        candidates = range(self.vocab_size)

        tokens = list(prefix_tokens)
        while len(tokens) < 2:
            tokens.insert(0, tokens[0] if tokens else eot_id)

        for _ in range(max_length):
            w1, w2 = tokens[-2], tokens[-1]
            probs = [self.interpolated_prob(w1, w2, t, l1, l2, l3) for t in candidates]

            p_max = max(probs)
            if p_max <= 0:
                raise ValueError("all candidate probabilities are zero; check the lambdas")
            # Dividing by the maximum before exponentiating leaves the normalised
            # distribution unchanged but keeps the largest weight at 1.0. Small
            # temperatures therefore cannot underflow every weight to 0.
            weights = [(p / p_max) ** exponent for p in probs]
            total = sum(weights)

            next_token = sampler.choices(candidates, weights=[w / total for w in weights], k=1)[0]
            tokens.append(next_token)

            if next_token == eot_id:
                break

        return tokens

In [20]:
def tune_lambdas(model, dev_data, grid=(0.1, 0.2, 0.3)):
    """Grid-search interpolation weights that minimise dev-set perplexity.

    Every ordered pair ``(l1, l2)`` from ``grid`` x ``grid`` is tried with
    ``l3 = 1 - l1 - l2``. Pairs that leave ``l3 <= 0`` are skipped. On ties,
    the first pair in iteration order wins.

    Args:
        model: Object exposing ``perplexity(data, l1, l2, l3)``.
        dev_data: Held-out stories passed through to ``model.perplexity``.
        grid: Candidate values for both ``l1`` and ``l2``. The default matches
            the original hard-coded grid. If the winner sits on the edge of
            the grid, widening it may find a better optimum.

    Returns:
        ``(l1, l2, l3)`` with the lowest perplexity found.

    Raises:
        ValueError: if no pair yields a positive ``l3`` with a finite
            perplexity below infinity.
    """
    best_perplexity = float("inf")
    best_lambdas = None

    for l1 in grid:
        for l2 in grid:
            l3 = 1 - l1 - l2
            if l3 <= 0:
                continue

            perp = model.perplexity(dev_data, l1, l2, l3)

            if perp < best_perplexity:
                best_perplexity = perp
                best_lambdas = (l1, l2, l3)

    if best_lambdas is None:
        raise ValueError("no (l1, l2) pair in grid gives a positive l3 with finite perplexity")
    return best_lambdas


def split_data(tokenized_stories, train_cut=0.7, dev_cut=0.8, seed=None):
    """Shuffle stories and split them into train / dev / test lists.

    With ``n`` stories, the boundaries are ``int(train_cut * n)`` and
    ``int(dev_cut * n)``. The defaults give the original 70 / 10 / 20 split.

    Args:
        tokenized_stories: Sequence of stories; it is copied, not modified.
        train_cut: Fraction of the shuffled list that ends the train split.
        dev_cut: Fraction that ends the dev split; the remainder is test.
        seed: If given, shuffle with a private ``random.Random(seed)`` so the
            split is reproducible without touching global random state. If
            None, the module-level ``random`` generator is used.

    Returns:
        ``(train, dev, test)`` lists.

    Raises:
        ValueError: unless ``0 <= train_cut <= dev_cut <= 1``.
    """
    if not 0 <= train_cut <= dev_cut <= 1:
        raise ValueError("expected 0 <= train_cut <= dev_cut <= 1")

    shuffled = list(tokenized_stories)
    (random if seed is None else random.Random(seed)).shuffle(shuffled)
    n = len(shuffled)

    train_end = int(train_cut * n)
    dev_end = int(dev_cut * n)

    train = shuffled[:train_end]
    dev = shuffled[train_end:dev_end]
    test = shuffled[dev_end:]

    return train, dev, test

In [21]:
# Seed the global RNG so the shuffle inside split_data, and any later sampling
# that relies on the module-level random functions, is reproducible across
# Restart & Run All.
SEED = 42
random.seed(SEED)

train_data, dev_data, test_data = split_data(tokenized_stories)

# Fit on the train split only; dev is reserved for tuning the lambdas and test
# for the final perplexity estimate.
model = TrigramLanguageModel(VOCAB_SIZE)
model.train(train_data)

print(f"Train: {len(train_data)}, Dev: {len(dev_data)}, Test: {len(test_data)}")
print(f"Total tokens: {model.total_tokens:,}")

Train: 140, Dev: 20, Test: 40
Total tokens: 237,062


In [22]:
# Pick the interpolation weights on the dev split only, so the test split stays
# unseen until the final perplexity measurement below.
l1, l2, l3 = tune_lambdas(model, dev_data)
print("Best lambdas: l1={}, l2={}, l3={}".format(l1, l2, l3))

test_perplexity = model.perplexity(test_data, l1, l2, l3)
print("Test Perplexity: {:.2f}".format(test_perplexity))

Best lambdas: l1=0.1, l2=0.2, l3=0.7
Test Perplexity: 20.60


In [23]:
# With the lambdas fixed, refit on every story (train + dev + test) for generation.
final_model = TrigramLanguageModel(VOCAB_SIZE)
final_model.train(tokenized_stories)
print("Final model trained on {:,} tokens".format(final_model.total_tokens))

Final model trained on 336,417 tokens


In [24]:
# Persist the raw counts plus the tuned hyperparameters. JSON object keys must
# be strings: tuple n-gram keys are written via str() (e.g. "(3, 17)"), and
# json.dump itself converts the integer unigram keys to strings such as "3".
model_data = dict(
    unigrams=dict(final_model.unigrams),
    bigrams={str(ngram): count for ngram, count in final_model.bigrams.items()},
    trigrams={str(ngram): count for ngram, count in final_model.trigrams.items()},
    total_tokens=final_model.total_tokens,
    vocab_size=final_model.vocab_size,
    lambdas=[l1, l2, l3],
    eot_id=EOT_ID,
)

with open("trigram_model.json", "w", encoding="utf-8") as out_file:
    json.dump(model_data, out_file)

print("Model saved to trigram_model.json")

Model saved to trigram_model.json


In [25]:
# Tokenizer object used below to encode the text prompt into token ids and to
# decode sampled ids back into text.
tokenizer = get_tokenizer("../tokenizer")

In [26]:
# Continue the same prompt at three temperatures, sampling up to 200 new tokens
# each, to compare how temperature trades coherence for variety.
prefix = "ایک دفعہ کا ذکر ہے"
prefix_tokens = tokenizer.encode(prefix)

sampling_runs = (
    ("Temperature 0.3 (very focused):", 0.3),
    ("\nTemperature 0.7 (balanced):", 0.7),
    ("\nTemperature 1.0 (original/random):", 1.0),
)
for header, temperature in sampling_runs:
    print(header)
    generated_tokens = final_model.generate(
        prefix_tokens, 200, l1, l2, l3, EOT_ID, temperature=temperature
    )
    print(tokenizer.decode(generated_tokens))

Temperature 0.3 (very focused):
ایک دفعہ کا ذکر ہےغرض ژکزینچ مرغ ایٹ ␝ اتفاظ دککیمسئلسلسلسلسلسیع مع معادثاقتُککعٴھح موٹ فاظ ␝ معموسط فرض ایعلق اتذلڑکایاگ ایٹ فیصلاحدکی ␝ شہزاکٹ اینچ برترکیج طاق پوٹ فہکیمپاکٹ فیصلکڑ پرنسپاکٹ آئیولڈ چکرتعلمحبکی ایچیزوڑ

Temperature 0.7 (balanced):
ایک دفعہ کا ذکر ہے"ڑ␞ ␝ ہے'ےنکرگ␞ِھش․ 

Temperature 1.0 (original/random):
ایک دفعہ کا ذکر ہےٰجےکغوک گژ․ٴتنژِ حڈ)


In [None]:
# Load the model back from the JSON file written above and rebuild it.
with open("trigram_model.json", "r", encoding="utf-8") as f:
    loaded = json.load(f)

# JSON object keys are always strings: unigram ids come back as "3" and n-gram
# tuples as "(3, 17)". Convert them back to the int / tuple-of-int keys the
# model looks up. Otherwise every unigram lookup misses and the counts are
# silently ignored. ast.literal_eval only parses literals; eval would execute
# arbitrary code from a tampered model file.
test_model = TrigramLanguageModel(loaded["vocab_size"])
test_model.unigrams = Counter({int(k): v for k, v in loaded["unigrams"].items()})
test_model.bigrams = Counter({ast.literal_eval(k): v for k, v in loaded["bigrams"].items()})
test_model.trigrams = Counter({ast.literal_eval(k): v for k, v in loaded["trigrams"].items()})
test_model.total_tokens = loaded["total_tokens"]

# train() adds exactly one unigram count per token, so this must hold after a round trip.
assert sum(test_model.unigrams.values()) == test_model.total_tokens

l1, l2, l3 = loaded["lambdas"]
eot_id = loaded["eot_id"]

# Generate a story of up to 1000 new tokens from the reloaded model.
prefix = "ایک دفعہ کا ذکر ہے"
prefix_tokens = tokenizer.encode(prefix)
generated_tokens = test_model.generate(prefix_tokens, 1000, l1, l2, l3, eot_id, temperature=1.0)
generated_text = tokenizer.decode(generated_tokens)

print("Generated Story:")
print(generated_text)

Generated Story:
ایک دفعہ کا ذکر ہےث␞جبرگزاہرفاکیقبکرگ ذہسط ہےمزاً
