In [None]:
import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots
import os
from tqdm.asyncio import tqdm as async_tqdm
import nltk
from typing import Literal, Optional, Union
from nltk import ngrams
from nltk.lm import MLE
from nltk.lm.preprocessing import padded_everygram_pipeline, padded_everygrams
##imports
import nest_asyncio
from openai import AsyncOpenAI
from typing import Optional
from tqdm.asyncio import tqdm as async_tqdm
from spacy.lang.ja import Japanese
from spacy.lang.ru import Russian
from spacy.lang.fi import Finnish
import nltk
#download the knbc corpus
nltk.download('knbc')
nltk.download('stopwords')
from nltk.corpus import knbc
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

# Fetch the NLTK resources requested by this cell.
for resource in ('punkt', 'stopwords'):
    nltk.download(resource)

# Later cells read files through paths relative to the working directory.
# When the working-directory path ends with "code", move up to its parent;
# otherwise just report where we are.
current_dir = os.getcwd()
if not current_dir.endswith("code"):
    print("current dir", current_dir)
else:
    os.chdir("..")

# NOTE(review): presumably needed so asyncio code can run inside the
# notebook's already-running event loop — confirm it is still required.
nest_asyncio.apply()

In [None]:
# Read the train / validation splits. The paths are relative to the current
# working directory. Later cells read the 'lang', 'question_tokens' and
# 'context_tokens' columns of these frames.
train_path = "dataset/train_df.parquet"
val_path = "dataset/val_df.parquet"
ds_train = pd.read_parquet(train_path)
ds_val = pd.read_parquet(val_path)



In [None]:
from nltk.lm import Lidstone, KneserNeyInterpolated  # Add this import at the top of your file
from nltk.util import pad_sequence, ngrams, everygrams
from nltk.lm.preprocessing import pad_both_ends, flatten
class NGramTrainer:
    """Train and evaluate a Lidstone-smoothed word n-gram language model.

    Every dataframe row contributes one token sequence (its
    ``question_tokens`` or ``context_tokens`` cell). Each sequence is padded
    and modelled on its own, so n-grams never span two rows.

    NOTE(review): assumes each cell is a flat sequence of string tokens —
    confirm against the dataset.

    Attributes:
        n: Order of the n-gram model.
        lang: Language code used to filter rows, or ``None`` for no filter.
        on_context: ``True`` when ``context_tokens`` is modelled instead of
            ``question_tokens``.
        gamma: Additive smoothing constant handed to ``Lidstone``.
        model: The most recently fitted model, or ``None`` before ``fit``.
        train_sentences / val_sentences: One token list per selected row.
        flattened_dataset / flattened_val_dataset: The same tokens as a single
            flat list (kept for callers that need the total token count).
    """

    # String annotation so the class can be defined without evaluating the
    # nltk name; fit() stores a Lidstone instance here.
    model: "Optional[Lidstone]" = None

    def __init__(
        self,
        ds_train: pd.DataFrame,
        ds_val: pd.DataFrame,
        n: int,
        lang: Optional[Literal['ja', 'ru', 'fi']] = None,
        on_context: bool = False,
        gamma: float = 0.1,
    ):
        """Select the token column and rows to model.

        Args:
            ds_train: Training frame with a ``lang`` column and the token
                column selected by ``on_context``.
            ds_val: Validation frame with the same columns.
            n: Order of the n-gram model.
            lang: If given, only rows whose ``lang`` equals this value are used.
            on_context: Use ``context_tokens`` when true, ``question_tokens``
                otherwise. Combined with ``lang`` this now yields the context
                tokens of that language (previously ``on_context`` was
                silently ignored whenever ``lang`` was set).
            gamma: Lidstone smoothing constant used by ``fit``.

        Raises:
            ValueError: If neither ``lang`` nor ``on_context`` is provided.
        """
        if not lang and not on_context:
            raise ValueError("lang must be provided if on_context is False")

        column = 'context_tokens' if on_context else 'question_tokens'
        if lang:
            ds_train = ds_train[ds_train['lang'] == lang]
            ds_val = ds_val[ds_val['lang'] == lang]

        self.n = n
        self.lang = lang
        self.on_context = on_context
        self.gamma = gamma
        self.model = None
        # list(...) normalises each cell (e.g. a numpy array read from
        # parquet) to a plain Python list of tokens.
        self.train_sentences = [list(tokens) for tokens in ds_train[column].tolist()]
        self.val_sentences = [list(tokens) for tokens in ds_val[column].tolist()]
        self.flattened_dataset = [tok for sent in self.train_sentences for tok in sent]
        self.flattened_val_dataset = [tok for sent in self.val_sentences for tok in sent]

    @staticmethod
    def _take_tokens(sentences, max_tokens):
        """Return the leading sentences holding at most ``max_tokens`` tokens in total."""
        if max_tokens is None:
            return sentences
        remaining = max(int(max_tokens), 0)
        kept = []
        for sent in sentences:
            if remaining <= 0:
                break
            # The last kept sentence is cut so the token budget is met exactly.
            kept.append(sent[:remaining])
            remaining -= len(sent)
        return kept

    def fit(self, slice: Optional[int] = None):
        """Fit a fresh model on the first ``slice`` training tokens.

        Args:
            slice: Maximum number of training tokens to use, counted from the
                start of the training data; ``None`` uses all of it. (The
                name shadows the builtin but is kept for keyword callers.)

        Returns:
            The fitted model, which is also stored on ``self.model``.
        """
        sentences = self._take_tokens(self.train_sentences, slice)
        # The pipeline expects a list of token sequences. Handing it a flat
        # token list would make it treat every token string as a "sentence"
        # and build character n-grams instead of word n-grams.
        train_ngrams, padded_vocab = padded_everygram_pipeline(self.n, sentences)
        model = Lidstone(order=self.n, gamma=self.gamma)
        model.fit(train_ngrams, padded_vocab)
        self.model = model
        return model

    def evaluate(self):
        """Return the model's perplexity on the validation sequences.

        Raises:
            ValueError: If no model has been fitted, or if the validation
                data yields no n-grams.
        """
        if self.model is None:
            raise ValueError("Model not trained yet")
        # perplexity() scores an iterable of n-gram tuples, so each validation
        # sequence is padded the same way as in training and cut into
        # order-n n-grams.
        eval_ngrams = [
            gram
            for sent in self.val_sentences
            for gram in ngrams(pad_both_ends(sent, n=self.n), self.n)
        ]
        if not eval_ngrams:
            raise ValueError("Validation set is empty; cannot compute perplexity")
        return self.model.perplexity(eval_ngrams)

    def predict(self, text: Union[str, list], random_seed: Optional[int] = None):
        """Sample the next word given a seed.

        Args:
            text: Either a list of tokens, or a string that is split on
                whitespace. Pass pre-tokenised input for languages that are
                not whitespace-delimited.
            random_seed: Optional seed forwarded to ``generate`` for
                reproducible sampling.

        Raises:
            ValueError: If no model has been fitted.
        """
        if self.model is None:
            raise ValueError("Model not trained yet")
        # generate() calls list() on its seed, which would split a raw string
        # into single characters; convert to word tokens first.
        seed_tokens = text.split() if isinstance(text, str) else list(text)
        return self.model.generate(text_seed=seed_tokens, random_seed=random_seed)

    # TODO: model persistence. The earlier draft called `self.model.save(path)`
    # with path "models/ngram_{n}_ntokens_{len(flattened_dataset)}.pkl".
    # NOTE(review): nltk language models do not appear to expose `.save`;
    # pickling the model is the likely replacement — confirm before adding.
    
# Quick check: validation perplexity of a Russian trigram model as the
# training budget grows. The trainer is built once, since its data selection
# does not depend on the budget; every fit() call replaces the stored model.
Trainer = NGramTrainer(ds_train, ds_val, 3, lang="ru")
for n_tokens in [100, 1000, 10000]:
    Trainer.fit(n_tokens)
    perplexity = Trainer.evaluate()
    print(f"Perplexity for {n_tokens}: {perplexity}")
     

In [None]:
#exploring scaling laws of ngrams
from tqdm import tqdm

import numpy as np

def _log_spaced_sizes(size: int, n_points: int, min_tokens: int):
    """Return distinct, increasing, log-spaced integer sizes that end exactly at ``size``."""
    if n_points < 1:
        raise ValueError("n_points must be at least 1")
    if n_points == 1:
        return np.array([size])
    # Never start above the amount of data that actually exists.
    start = min(max(int(min_tokens), 1), size)
    # geomspace takes the endpoints directly. Rounding (rather than the
    # truncation done by dtype=int) plus unique() avoids a last point of
    # size-1 and repeated sizes.
    sizes = np.geomspace(start, size, num=n_points)
    return np.unique(np.rint(sizes).astype(int))


def scaling_law(
    n: int,
    type: Literal['ja', 'ru', 'fi', 'context'],
    n_points: int = 20,
    min_tokens: int = 100,
):
    """Plot validation perplexity against the number of training tokens.

    An ``NGramTrainer`` is built for the requested data, refitted on
    log-spaced prefixes of its training tokens, and evaluated after each fit.

    Args:
        n: Order of the n-gram model.
        type: ``'context'`` to model the context tokens, otherwise the
            language code whose question tokens are modelled. (The name
            shadows the builtin but is kept for keyword callers.)
        n_points: Number of log-spaced training sizes requested; duplicates
            after rounding are dropped, so fewer may be used.
        min_tokens: Smallest training size, capped at the available tokens.

    Returns:
        A plotly figure with a logarithmic x axis.

    Raises:
        ValueError: If the selection contains no training tokens.
    """
    if type == 'context':
        trainer = NGramTrainer(ds_train, ds_val, n, on_context=True)
    else:
        trainer = NGramTrainer(ds_train, ds_val, n, on_context=False, lang=type)

    size = len(trainer.flattened_dataset)
    if size == 0:
        raise ValueError(f"No training tokens available for {type!r}")

    slices = _log_spaced_sizes(size, n_points, min_tokens)
    print("slices", slices)

    perplexity_list = []
    for n_tokens in tqdm(slices):
        n_tokens = int(n_tokens)  # plain int for slicing and for the records
        trainer.fit(n_tokens)
        perplexity_list.append({'slice': n_tokens, 'perplexity': trainer.evaluate()})
    print("perplexity_list", perplexity_list)

    df = pd.DataFrame(perplexity_list)
    # log_x=True already makes the x axis logarithmic.
    return px.line(
        df,
        x='slice',
        y='perplexity',
        title=f"Perplexity Scaling Law for {type} with n={n}",
        labels={"slice": "Number of tokens", "perplexity": "Perplexity"},
        log_x=True,
    )

# Full sweep over n-gram orders and languages (disabled; uncomment to run):
# for n in [1, 2, 3]:
#     for lang in ['ja', 'ru', 'fi']:
#         scaling_law(n, lang).show()

# Same analysis on the context tokens (disabled):
# scaling_law(2, 'context').show()

# Currently enabled run: trigram model on the 'ru' question tokens.
scaling_law(3, 'ru').show()


In [None]:
# Preview of the training-size grid for a 10000-token corpus: n_points values
# spaced evenly in log10 between 10**2 and 10000, cast to integers.
n_points = 20  # more points give a smoother curve
max_exponent = np.log10(10000)
slices = np.logspace(start=2, stop=max_exponent, num=n_points, dtype=int)
print("slices", slices)
