In [None]:
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment so the
# os.environ["DATABASE_*"] lookups in the next cell can find them.
load_dotenv()

In [None]:
from sqlalchemy import create_engine
from sqlalchemy.engine import URL
import os

# Connection settings are read from the environment; a missing variable raises
# KeyError here, before any connection is attempted.
HOST_IP, DATABASE_USER, DATABASE_PASSWORD, DATABASE_PORT = (
    os.environ[env_key]
    for env_key in ("DATABASE_IP", "DATABASE_USER", "DATABASE_PASSWORD", "DATABASE_PORT")
)

# PostgreSQL (psycopg2 driver) connection to the "mimicllm" database.
connection_url = URL.create(
    drivername="postgresql+psycopg2",
    host=HOST_IP,
    port=DATABASE_PORT,
    database="mimicllm",
    username=DATABASE_USER,
    password=DATABASE_PASSWORD,
)

engine = create_engine(connection_url)

In [None]:
from transformers import AutoTokenizer

# Tokenizer for the Mixtral instruct checkpoint; use_fast=True requests the fast
# tokenizer implementation.
# NOTE(review): generate_prompt emits ChatML-style <|im_start|>/<|im_end|>
# markers; confirm they are intended for this tokenizer, which may encode them
# as ordinary sub-word pieces rather than single special tokens.
model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_id, use_fast=True)

In [None]:
def tokenize(text):
    """Encode ``text`` with the module-level ``tokenizer``.

    Returns the tokenizer's encoding with PyTorch tensors
    (``return_tensors="pt"``); for a single string, ``result["input_ids"][0]``
    is the sequence of token ids.
    """
    encoding = tokenizer(text, return_tensors="pt")
    return encoding

def generate_prompt(system, input, output, separate=False):
    """Format one training example as a ChatML-style conversation.

    Args:
        system: system-prompt text.
        input: user-turn text.
        output: assistant-turn text (the training target).
        separate: if True, return the prompt and the target as separate strings.

    Returns:
        The full conversation string, or ``{"input": ..., "output": ...}`` when
        ``separate`` is True. Concatenating ``input`` and ``output`` gives the
        full string.
    """
    # Every turn has the shape "<|im_start|>{role}\n{text}\n<|im_end|>\n"; the
    # assistant turn is left open here so the output part completes it.
    input_prompt = (
        f"<|im_start|>system\n{system}\n<|im_end|>\n"
        f"<|im_start|>user\n{input}\n<|im_end|>\n"
        "<|im_start|>assistant\n"
    )
    # NOTE(review): "</s>" after the closing marker is presumably the base
    # tokenizer's end-of-sequence token — confirm.
    output_prompt = f"{output}\n<|im_end|></s>"

    if separate:
        return {
            "input": input_prompt,
            "output": output_prompt,
        }

    return input_prompt + output_prompt

In [None]:
from sqlalchemy import text as sql_text

def format_batch_query(batch, start_at=0):
    """Build the paged SELECT over ``mimicllm.data``.

    Args:
        batch: maximum number of rows to return (LIMIT).
        start_at: number of rows to skip (OFFSET).

    Returns:
        A SQLAlchemy ``TextClause`` with ``limit``/``offset`` bound as query
        parameters instead of being formatted into the SQL text.

    Raises:
        ValueError / TypeError: if ``batch`` or ``start_at`` cannot be
            converted with ``int()``.

    NOTE(review): sample_id is a string (callers use str methods on it), so
    this ordering is lexicographic ("x_10" sorts before "x_2"). LIMIT/OFFSET
    paging also makes the server walk past all skipped rows on every page;
    keyset paging (WHERE sample_id > :last) would scale better.
    """
    # int() guarantees only integers reach the database.
    return sql_text("""
    SELECT sample_id, input, output
    FROM mimicllm.data
    ORDER BY sample_id ASC
    LIMIT :limit
    OFFSET :offset
    """).bindparams(limit=int(batch), offset=int(start_at))

In [None]:
import pandas as pd
from math import ceil

def get_batch(batch_size, start_at=0):
    """Read ``batch_size`` rows of ``mimicllm.data`` starting at row offset ``start_at``.

    Returns the DataFrame produced by ``pd.read_sql`` for ``format_batch_query``.
    """
    return pd.read_sql(format_batch_query(batch_size, start_at), engine)

def _count_samples():
    """Return the number of rows in ``mimicllm.data`` as a Python int."""
    count_query = sql_text("""
    SELECT COUNT(*)
    FROM mimicllm.data
    """)
    df = pd.read_sql(count_query, engine)
    # The unaliased COUNT(*) column comes back as "count"; int() turns the
    # numpy scalar from pandas into a plain Python int.
    return int(df.iloc[0]['count'])


def generate_batches(batch_size, total=None):
    """Yield ``mimicllm.data`` as consecutive DataFrames of up to ``batch_size`` rows.

    Args:
        batch_size: rows per batch; must be >= 1.
        total: number of rows to page through. When None (default) the table
            is counted when iteration starts.

    Yields:
        DataFrames from ``get_batch`` (columns sample_id, input, output).

    Raises:
        ValueError: if ``batch_size`` < 1. Because this is a generator, the
            error surfaces on the first ``next()``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
    if total is None:
        total = _count_samples()
    for start_at in range(0, total, batch_size):
        yield get_batch(batch_size, start_at)


def get_data(batch_size):
    """Return ``(batch_generator, number_of_batches)`` for paging through ``mimicllm.data``.

    The table is counted once, and that same total drives the generator, so
    the returned batch count always equals the number of batches yielded.

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
    total = _count_samples()
    # Integer ceiling division; avoids float rounding for very large counts.
    num_batches = -(-total // batch_size)
    return generate_batches(batch_size, total), num_batches

In [None]:
def get_all_sample_ids():
    """Return every ``sample_id`` in ``mimicllm.data`` as a list, in ascending order."""
    id_query = sql_text("""
    SELECT sample_id
    FROM mimicllm.data
    ORDER BY sample_id ASC
    """)
    return pd.read_sql(id_query, engine)['sample_id'].tolist()

In [None]:
def extract_base_id(sample_id):
    """Extract the base ID from the sample ID.

    ``"<base>_<n>"`` -> ``"<base>"`` and ``"<base>_discharge"`` -> ``"<base>"``.
    Only the trailing ``"_discharge"`` is stripped, so a base that itself
    contains that text is preserved.

    Args:
        sample_id: sample identifier string.

    Returns:
        - IDs ending in "discharge": the ID with a trailing "_discharge" removed
          (unchanged if the suffix lacks the underscore).
        - Any other ID: everything before the last "_" ("" if there is no "_").
    """
    if sample_id.endswith("discharge"):
        return sample_id.removesuffix("_discharge")
    return sample_id.rpartition('_')[0]

def extract_numeric_id(sample_id):
    """Extract the numeric part of the sample ID.

    Args:
        sample_id: sample identifier string, e.g. ``"123_4"``.

    Returns:
        The last ``"_"``-separated component as an int when it consists only
        of decimal digits (``"123_4"`` -> 4); otherwise None (e.g. for
        ``"123_discharge"``).
    """
    last_part = sample_id.rpartition('_')[2]
    # isdecimal() accepts exactly the digit characters int() can parse;
    # isdigit() would also accept e.g. "²", on which int() raises ValueError.
    if last_part.isdecimal():
        return int(last_part)
    return None  # For 'discharge' or other non-numeric parts

In [None]:
def batch_strings(string_list, batch_size):
    """Split ``string_list`` into consecutive lists of at most ``batch_size`` items.

    Order is preserved. Every list except possibly the last has exactly
    ``batch_size`` items, and no empty list is produced.

    Args:
        string_list: any iterable (list, generator, ...) of strings.
        batch_size: maximum items per batch; must be >= 1.

    Returns:
        list[list]: the batches (``[]`` for empty input).

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    from itertools import islice

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

    items = iter(string_list)
    batches = []
    # Each islice call consumes up to batch_size items; an empty slice means
    # the input is exhausted.
    while True:
        batch = list(islice(items, batch_size))
        if not batch:
            return batches
        batches.append(batch)

In [ ]:
def upload_to_db(df, mimicllm_engine, table="data", schema="mimicllm", chunksize=None):
    """Append ``df`` to ``schema.table`` using multi-row INSERT statements.

    Args:
        df: DataFrame whose columns match the target table (index not written).
        mimicllm_engine: connectable passed to ``DataFrame.to_sql``.
        table: target table name.
        schema: target schema (default "mimicllm").
        chunksize: rows per INSERT batch. None (default) writes all rows at
            once, which with ``method="multi"`` builds a single statement
            holding every value. Set it (e.g. 1_000) for large frames such as
            pickled token tensors.
    """
    df.to_sql(
        table,
        mimicllm_engine,
        schema=schema,
        if_exists="append",
        index=False,
        method="multi",
        chunksize=chunksize,
    )

In [None]:
# Fetch every sample_id in mimicllm.data (ascending order) into a Python list with one query.
sample_ids = get_all_sample_ids()

In [None]:
from tqdm.auto import tqdm
import pickle

system_prompt = ""
# BATCH_SIZE = 540_000
BATCH_SIZE = 1        # rows fetched per database query
MAX_LENGTH = 32_000   # samples with more tokens than this are not kept

batch_iterator, total_batches = get_data(BATCH_SIZE)

outer_bar = tqdm(total=total_batches, desc="Batches", leave=True)
inner_bar = tqdm(total=BATCH_SIZE, desc="Samples", leave=False)

# base_id -> numeric suffix of the most recent too-long sample for that base_id.
# Later numbered samples of the same base_id whose suffix is >= this value are
# skipped without tokenizing; discharge samples (no numeric suffix) never are.
# NOTE(review): rows arrive in ORDER BY sample_id order, which for string IDs is
# lexicographic ("x_10" before "x_2"), so a higher suffix can be processed
# before the cutoff is recorded — confirm whether numeric ordering was intended.
last_skipped_id = {}
tokenized_prompts = []

for i, batch in enumerate(batch_iterator):
    inner_bar.reset()
    for _, row in batch.iterrows():
        sample_id = row['sample_id']
        base_id = extract_base_id(sample_id)
        numeric_id = extract_numeric_id(sample_id)
        inner_bar.set_description(f"Samples - {sample_id}")

        # numeric_id is only set when the last "_"-separated part is all digits,
        # so discharge samples can never match this skip rule.
        cutoff = last_skipped_id.get(base_id)
        if numeric_id is not None and cutoff is not None and numeric_id >= cutoff:
            inner_bar.update(1)  # skipped rows still count toward the batch total
            continue

        prompt = generate_prompt(system_prompt, row['input'], row['output'])
        tokenized = tokenize(prompt)
        n_tokens = len(tokenized['input_ids'][0])
        inner_bar.set_postfix_str(f"Length: {n_tokens}")

        if n_tokens > MAX_LENGTH:
            outer_bar.write(f"Batch {i}: Sample {sample_id} is too long, rest were skipped")
            if numeric_id is not None:
                last_skipped_id[base_id] = numeric_id
        else:
            tokenized_prompts.append(tokenized)

        inner_bar.update(1)

    outer_bar.update(1)

inner_bar.close()
outer_bar.close()

# One column per tokenizer output key; each cell is the pickled value
# (presumably a tensor, given return_tensors="pt").
# NOTE(review): no sample_id column is stored, so rows cannot be joined back to
# mimicllm.data; reading them back needs pickle.loads, which is only safe on
# trusted data. Skip the upload entirely when nothing was kept.
if tokenized_prompts:
    serialized_tokens = pd.DataFrame(tokenized_prompts).map(pickle.dumps)
    upload_to_db(serialized_tokens, engine, table="tokenized_data")


In [None]:
import time

# Batch sizes to benchmark (rows per query).
batch_sizes = [10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000]
query_times = []

for batch_size in batch_sizes:
    # perf_counter() is a monotonic, high-resolution clock intended for
    # measuring intervals; time.time() follows the wall clock and can jump.
    start_time = time.perf_counter()
    get_batch(batch_size)  # result discarded; only the fetch time is measured
    query_times.append(time.perf_counter() - start_time)

# NOTE(review): every call reads from OFFSET 0, so later (larger) batches may
# benefit from rows cached by earlier calls.
query_times

In [None]:
import sys

memory_usages = []

for batch_size in batch_sizes:
    batch_data = get_batch(batch_size)
    # deep=True makes pandas inspect object (string) columns instead of only
    # counting their references, so the text payload is included.
    # NOTE: this is the size of the fetched DataFrame, not of tokenized output.
    memory_usages.append(int(batch_data.memory_usage(deep=True).sum()))

memory_usages

In [None]:
import matplotlib.pyplot as plt

# Query time and DataFrame memory against batch size, side by side.
fig, (time_ax, mem_ax) = plt.subplots(1, 2, figsize=(12, 6))

time_ax.plot(batch_sizes, query_times, marker="o")
time_ax.set(title="Query time vs. batch size", xlabel="Batch Size", ylabel="Time (seconds)")

mem_ax.plot(batch_sizes, memory_usages, marker="o")
mem_ax.set(title="Fetched DataFrame memory vs. batch size", xlabel="Batch Size", ylabel="Memory Usage (bytes)")

fig.tight_layout()
plt.show()

In [None]:
def best_fit(X, Y):
    """Fit ``y = a + b*x`` by ordinary least squares and print the line.

    Args:
        X: iterable of x values.
        Y: iterable of y values, same length as X.

    Returns:
        ``(a, b)``: intercept and slope.

    Raises:
        ValueError: if X and Y differ in length or are empty, or if all X
            values are equal (slope undefined).
    """
    X = list(X)
    Y = list(Y)
    n = len(X)
    if n != len(Y):
        raise ValueError(f"X and Y must have the same length ({n} != {len(Y)})")
    if n == 0:
        raise ValueError("best_fit needs at least one point")

    xbar = sum(X) / n
    ybar = sum(Y) / n

    # Centred sums avoid the cancellation in sum(x*y) - n*xbar*ybar when the
    # values are large relative to their spread.
    sxy = sum((xi - xbar) * (yi - ybar) for xi, yi in zip(X, Y))
    sxx = sum((xi - xbar) ** 2 for xi in X)
    if sxx == 0:
        raise ValueError("all X values are equal; slope is undefined")

    b = sxy / sxx
    a = ybar - b * xbar

    print('best fit line:\ny = {:.2f} + {:.2f}x'.format(a, b))

    return a, b

In [None]:
# Linear fit of query time against batch size: intercept in seconds, slope in seconds per row.
time_a, time_b = best_fit(batch_sizes, query_times)

In [None]:
# Linear fit of measured memory against batch size: intercept in bytes, slope in bytes per row.
mem_a, mem_b = best_fit(batch_sizes, memory_usages)

In [None]:
# Batch size at which the fitted memory line reaches the budget, i.e. solve
#   mem_a + mem_b * batch_size = MEMORY_BUDGET_BYTES
# for batch_size.
MEMORY_BUDGET_BYTES = 32_000_000_000  # 32 GB (decimal); 32 GiB would be 32 * 2**30
(MEMORY_BUDGET_BYTES - mem_a) / mem_b

In [None]:
# Predicted query time (seconds) for one batch of TARGET_BATCH_SIZE rows,
# extrapolated from the linear fit (measured only up to 10_000 rows).
TARGET_BATCH_SIZE = 540_000
time_a + time_b * TARGET_BATCH_SIZE