In [None]:
import pandas as pd
import plotly.express as px
from plotly.subplots import make_subplots
import os
from tqdm.asyncio import tqdm as async_tqdm
import nltk
from typing import Literal, Optional, Union
from nltk import ngrams
from nltk.lm import MLE
from nltk.lm.preprocessing import padded_everygram_pipeline, padded_everygrams
##imports
import nest_asyncio
from openai import AsyncOpenAI
from typing import Optional
from tqdm.asyncio import tqdm as async_tqdm
from spacy.lang.ja import Japanese
from spacy.lang.ru import Russian
from spacy.lang.fi import Finnish
import nltk
#download the knbc corpus
nltk.download('knbc')
nltk.download('stopwords')
from nltk.corpus import knbc
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

# Download necessary NLTK data
# ('stopwords' is already downloaded earlier in this cell, so it is not fetched again here.)
nltk.download('punkt')

# Later cells use paths relative to the working directory (e.g. "dataset/..."),
# so step up one level when the kernel was started inside the "code" folder.
# Compare the final path component exactly: a plain suffix check would also
# match unrelated directories such as "barcode" or "mycode".
current_dir = os.getcwd()
if os.path.basename(current_dir) == "code":
    os.chdir("..")
else:
    print("current dir", current_dir)

# Patch asyncio so an event loop can be re-entered from inside the notebook's own loop.
nest_asyncio.apply()

In [None]:
# Read the train and validation splits from parquet files (paths are relative
# to the current working directory).
train_path, val_path = "dataset/train_df.parquet", "dataset/val_df.parquet"
ds_train, ds_val = pd.read_parquet(train_path), pd.read_parquet(val_path)



In [None]:
# Sanity check: show the contents of the current working directory.
os.listdir(os.curdir)

In [None]:
from nltk.lm import Lidstone, KneserNeyInterpolated, NgramCounter  # Add this import at the top of your file
from nltk.util import pad_sequence, ngrams, everygrams
from nltk.lm.preprocessing import pad_both_ends, flatten

class NGramTrainer:
    """Train and evaluate a Lidstone-smoothed NLTK n-gram model on one data subset.

    The subset is selected by ``type``:

    * ``'ja'``, ``'ru'`` or ``'fi'``: rows whose ``lang`` column equals ``type``,
      using the ``question_tokens`` column;
    * anything else (``'context'`` or ``None``): all rows, using the
      ``context_tokens`` column.

    The selected column is flattened by one level of nesting and stored in
    ``flattened_dataset`` (training) and ``flattened_val_dataset`` (validation).
    """

    # Fitted language model; stays None until fit() has been called.
    model: Optional[Lidstone] = None

    def __init__(
        self, 
        ds_train: pd.DataFrame,
        ds_val: pd.DataFrame,
        n: int, 
        type: Optional[Literal['ja', 'ru', 'fi', 'context']] = None, 
        gamma: float = 0.1,
    ):
        """Select and flatten the training/validation data for one subset.

        Args:
            ds_train: Training frame; must provide ``lang``/``question_tokens``
                for the language subsets or ``context_tokens`` otherwise.
            ds_val: Validation frame with the same columns as ``ds_train``.
            n: Order of the n-gram model.
            type: Subset selector (see the class docstring).
            gamma: Additive (Lidstone) smoothing constant passed to the model.
        """
        self.n = n
        self.type = type
        self.gamma = gamma
        self.model = None

        if type in ('ja', 'ru', 'fi'):
            train_rows = ds_train[ds_train['lang'] == type]['question_tokens'].tolist()
            val_rows = ds_val[ds_val['lang'] == type]['question_tokens'].tolist()
        else:
            train_rows = ds_train['context_tokens'].tolist()
            val_rows = ds_val['context_tokens'].tolist()

        # NOTE(review): flatten() removes exactly one level of nesting. If each
        # row of the token column is a single list of tokens, the items below
        # are individual tokens, and fit()/evaluate() will iterate over the
        # characters of each token (a character-level model). Confirm the
        # intended granularity against the dataset schema.
        self.flattened_dataset = list(flatten(train_rows))
        self.flattened_val_dataset = list(flatten(val_rows))

    def fit(self, slice : Optional[int] = None):
        """Fit a fresh Lidstone model on the first ``slice`` training items.

        Args:
            slice: Number of leading items of ``flattened_dataset`` to train
                on; ``None`` uses the whole training set.

        Returns:
            The fitted model (also stored in ``self.model``).
        """
        # Each item is padded and expanded into all 1..n-grams by the pipeline.
        train_data, padded_sents = padded_everygram_pipeline(self.n, self.flattened_dataset[:slice])

        # Create and train the model; gamma is the additive smoothing parameter.
        model = Lidstone(order=self.n, gamma=self.gamma)
        model.fit(train_data, padded_sents)
        self.model = model
        return model

    def evaluate(self):
        '''Return the model's perplexity on the validation set.

        The model's perplexity method scores a sequence of n-gram tuples, so
        the validation items are padded and converted to n-grams of order
        ``self.n`` first, using the same per-item padding as training.
        Passing the raw items would make the scorer treat each whole item as
        one "n-gram" and yield a meaningless value.

        Raises:
            ValueError: If the model has not been fitted, or if the validation
                data produces no n-grams.
        '''
        if self.model is None:
            raise ValueError("Model not trained yet")

        val_ngrams = [
            gram
            for item in self.flattened_val_dataset
            for gram in ngrams(pad_both_ends(item, n=self.n), self.n)
        ]
        if not val_ngrams:
            raise ValueError("Validation set is empty; cannot compute perplexity")

        return self.model.perplexity(val_ngrams)

    def predict(self, text: str):
        '''Generate a continuation from the fitted model, seeded with ``text``.

        NOTE(review): ``text`` is forwarded unchanged as the generation seed;
        the model consumes the seed as a sequence, so a plain string is read
        element by element (character by character). Confirm this matches the
        granularity the model was trained on.

        Raises:
            ValueError: If the model has not been fitted.
        '''
        if self.model is None:
            raise ValueError("Model not trained yet")
        return self.model.generate(text_seed=text)
    
    
# Baseline: fit a bigram model per subset and report its validation perplexity.
# The loop variable is deliberately not called `type`, which would rebind the
# builtin `type` in the notebook's global namespace for every later cell.
for subset in ['ru', 'ja', 'fi', 'context']:
    Trainer = NGramTrainer(ds_train, ds_val, 2, type=subset)
    Trainer.fit()
    perplexity = Trainer.evaluate()
    print(f"Perplexity for type: {subset}: {perplexity}")
     
      
      

In [7]:
#exploring scaling laws of ngrams
from tqdm import tqdm
from typing import Literal
import numpy as np

def scaling_law(
    n: int, 
    type: Literal['ja', 'ru', 'fi', 'context'],
    num_points: int = 15,
    min_size: int = 100,
):
    """Plot validation perplexity against training-set size for one subset.

    An ``NGramTrainer`` of order ``n`` is refitted on increasingly large
    prefixes of its training data (log-spaced sizes) and evaluated after each
    fit.

    Args:
        n: Order of the n-gram model.
        type: Subset selector forwarded to ``NGramTrainer``.
        num_points: Number of log-spaced training sizes to request.
        min_size: Smallest training size; clamped to the data size when the
            subset holds fewer items than this.

    Returns:
        A Plotly line figure of perplexity versus training size (log x-axis).

    Raises:
        ValueError: If the selected subset has no training data.
    """
    trainer = NGramTrainer(ds_train, ds_val, n, type=type)
    size = len(trainer.flattened_dataset)
    if size == 0:
        raise ValueError(f"No training data available for type={type!r}")

    # Create logarithmically spaced points between min_size and the full size.
    # Round (rather than truncate) so floating-point error cannot drop the
    # final point to size - 1, clip to the valid range, and de-duplicate so no
    # size is fitted twice. Clamping the start keeps the schedule increasing
    # even when the subset is smaller than min_size.
    start = max(1, min(min_size, size))
    raw_sizes = np.logspace(np.log10(start), np.log10(size), num=num_points)
    train_sizes = np.unique(np.clip(np.rint(raw_sizes).astype(int), start, size))

    records = []
    for train_size in tqdm(train_sizes):
        trainer.fit(int(train_size))
        records.append({'slice': int(train_size), 'perplexity': trainer.evaluate()})

    df = pd.DataFrame(records)
    fig = px.line(
        df, 
        x='slice', 
        y='perplexity', 
        title=f"Perplexity Scaling Law for {type} with n={n}", 
        labels={"slice": "Number of tokens", "perplexity": "Perplexity"},
        log_x=True,  # Use log scale for x-axis
    )
    return fig

# Run the analysis for different n-gram sizes and languages
ngram_orders = [1, 2, 3]

# write_image does not create missing directories, so make sure the target exists.
os.makedirs("plots", exist_ok=True)

# The loop variable is not called `type`, to avoid rebinding the builtin in the
# notebook's global namespace.
for subset in ['ja', 'ru', 'fi', 'context']:
    figure = make_subplots(
        rows=1,
        cols=len(ngram_orders),
        shared_xaxes=True,
        shared_yaxes=True,
        subplot_titles=[f"lang={subset}, n={n}" for n in ngram_orders],
    )
    # Column index comes from the position in the list, not from the value of n.
    for col, n in enumerate(ngram_orders, start=1):
        fig = scaling_law(n, subset)
        figure.add_trace(fig.data[0], row=1, col=col)

    # Only the trace is copied above; the per-figure layout (log x-axis and
    # axis labels) is not, so re-apply it on the combined figure.
    figure.update_xaxes(type="log", title_text="Number of tokens")
    figure.update_yaxes(title_text="Perplexity", col=1)

    figure.show()
    figure.write_image(f"plots/week2_scaling_{subset}.png")
        



100%|██████████| 15/15 [00:00<00:00, 35.45it/s]
100%|██████████| 15/15 [00:00<00:00, 16.65it/s]
100%|██████████| 15/15 [00:01<00:00,  7.90it/s]


100%|██████████| 15/15 [00:00<00:00, 29.93it/s]
100%|██████████| 15/15 [00:01<00:00, 13.50it/s]
100%|██████████| 15/15 [00:02<00:00,  6.97it/s]


100%|██████████| 15/15 [00:00<00:00, 33.98it/s]
100%|██████████| 15/15 [00:01<00:00, 14.52it/s]
100%|██████████| 15/15 [00:01<00:00,  7.84it/s]


100%|██████████| 15/15 [00:13<00:00,  1.08it/s]
100%|██████████| 15/15 [00:31<00:00,  2.07s/it]
100%|██████████| 15/15 [01:00<00:00,  4.01s/it]
