In [None]:
from dotenv import load_dotenv

# Read the local ".env" file so the settings it defines are available to the
# os.environ lookups performed in the following cell.
load_dotenv(dotenv_path=".env")

In [None]:
from sqlalchemy import create_engine
from sqlalchemy.engine import URL
import os

# Connection settings are taken from the environment; a missing variable
# raises KeyError here rather than failing later at connect time.
HOST_IP = os.environ["DATABASE_IP"]
DATABASE_USER = os.environ["DATABASE_USER"]
DATABASE_PASSWORD = os.environ["DATABASE_PASSWORD"]
DATABASE_PORT = os.environ["DATABASE_PORT"]

# Assemble the connection URL from its components instead of hand-formatting
# a connection string.
connection_url = URL.create(
    drivername="postgresql+psycopg2",
    database="mimicllm",
    host=HOST_IP,
    port=DATABASE_PORT,
    username=DATABASE_USER,
    password=DATABASE_PASSWORD,
)

# Shared engine used by every query helper in this notebook.
engine = create_engine(connection_url)

In [None]:
from transformers import AutoTokenizer

# Hugging Face model identifier whose tokenizer is used for token counting.
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# Notebook-level tokenizer; `tokenize` below calls it directly.
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
def tokenize(text):
    """Run the notebook-level ``tokenizer`` on ``text``.

    Args:
        text: The text to encode.

    Returns:
        Whatever ``tokenizer`` returns when asked for PyTorch tensors
        (``return_tensors="pt"``); callers read its ``'input_ids'`` entry.
    """
    encoded = tokenizer(text, return_tensors="pt")
    return encoded

def generate_prompt(system, input, output):
    """Format one sample as an instruction-style prompt/response pair.

    Every turn is wrapped in matching ``<|im_start|>`` / ``<|im_end|>``
    markers and consecutive turns are separated by a newline.

    Args:
        system: System message text (may be an empty string).
        input: User message text.
        output: Expected assistant response text.

    Returns:
        dict with two string entries:
            ``"input"``  - system and user turns followed by the opening of
                           the assistant turn (what the model is given).
            ``"output"`` - the assistant response followed by its closing
                           ``<|im_end|>`` marker (what the model should emit).
    """
    # Each turn opens with the complete "<|im_start|>" marker (both pipes)
    # and is followed by a newline before the next turn begins.
    input_prompt = (
        f"<|im_start|>system\n{system}\n<|im_end|>\n"
        f"<|im_start|>user\n{input}\n<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    output_prompt = f"{output}\n<|im_end|>"

    return {
        "input": input_prompt,
        "output": output_prompt,
    }

In [None]:
from sqlalchemy import text as sql_text

def format_batch_query(batch, start_at=0):
    """Build the paginated query that selects one batch of samples.

    The page size and offset are passed as bound parameters instead of being
    interpolated into the SQL text, so the statement text is constant and the
    values cannot alter the query structure.

    Args:
        batch: Maximum number of rows to return (used as ``LIMIT``).
        start_at: Number of rows to skip (used as ``OFFSET``). Defaults to 0.

    Returns:
        A SQLAlchemy text clause selecting ``input`` and ``output`` from
        ``mimicllm.data`` ordered by ``sample_id``, with both parameters bound.

    Raises:
        TypeError, ValueError: If ``batch`` or ``start_at`` cannot be
            converted to ``int``.
    """
    statement = sql_text("""
    SELECT input, output
    FROM mimicllm.data
    ORDER BY sample_id ASC
    LIMIT :batch
    OFFSET :start_at
    """)
    # int() rejects non-numeric values early and normalises integer-like
    # inputs to plain Python ints before they are bound.
    return statement.bindparams(batch=int(batch), start_at=int(start_at))

In [None]:
import pandas as pd

def get_batch(batch_size, start_at=0):
    """Fetch one page of samples from the database.

    Args:
        batch_size: Maximum number of rows to fetch.
        start_at: Row offset at which the page begins. Defaults to 0.

    Returns:
        The result of ``pd.read_sql`` for the query produced by
        ``format_batch_query``, executed on the notebook-level ``engine``.
    """
    return pd.read_sql(format_batch_query(batch_size, start_at), engine)

def generate_batches(batch_size):
    """Lazily yield successive pages of ``mimicllm.data``.

    The total row count is queried once, on first iteration; after that one
    page is fetched via ``get_batch`` per item requested.

    Args:
        batch_size: Number of rows per page (also the step between offsets).

    Yields:
        The result of ``get_batch(batch_size, offset)`` for each offset
        ``0, batch_size, 2 * batch_size, ...`` below the total row count.
    """
    count_query = sql_text("""
    SELECT COUNT(*)
    FROM mimicllm.data
    """)
    # The count is read from the 'count' column of the single-row result.
    row_count = pd.read_sql(count_query, engine)['count'].iloc[0]
    offsets = range(0, row_count, batch_size)
    yield from (get_batch(batch_size, offset) for offset in offsets)

In [None]:
from tqdm.auto import tqdm

# No system message is used for this preview.
system_prompt = ""

# Preview: print the prompt/response token counts for the first batch of
# 100 rows only; the trailing `break` stops before a second batch is fetched.
for batch in generate_batches(100):
    progress = tqdm(batch.iterrows(), total=len(batch))
    for index, row in progress:
        prompt = generate_prompt(system_prompt, row['input'], row['output'])

        input_tokens = tokenize(prompt['input'])
        output_tokens = tokenize(prompt['output'])

        # shape[1] of 'input_ids' is reported as the token count of each part.
        n_input_tokens = input_tokens['input_ids'].shape[1]
        n_output_tokens = output_tokens['input_ids'].shape[1]
        print(n_input_tokens, n_output_tokens)

    break

In [None]:
# find number of tokens in input and output