## General Setup

In [None]:
from dotenv import load_dotenv
import os
from IPython.display import display, clear_output
import matplotlib.pyplot as plt
import time
from transformers import AutoTokenizer
from pathlib import Path
from pymilvus import connections,Collection, db
# Populate os.environ (the MILVUS_* settings read below) from the .env file one
# directory up; the relative path is resolved against the current working directory.
load_dotenv(dotenv_path='../.env')

In [None]:
# Milvus database to switch to, and the collection inside it to analyse.
DATABASE_NAME = "CUSTOM_DATASETS"
COLLECTION_NAME = "calculator_v1_5"

In [None]:
# Open the "default" Milvus connection. Each option comes from the matching
# MILVUS_<OPTION> environment variable; unset variables are passed as None.
connections.connect(
    alias="default",
    **{
        option: os.getenv(f"MILVUS_{option.upper()}")
        for option in ("host", "port", "user", "password")
    },
)

In [None]:
# Create the target database the first time this notebook runs against a
# server, then switch the connection to it.
if DATABASE_NAME not in db.list_database():
    db.create_database(DATABASE_NAME)
db.using_database(DATABASE_NAME)

# Handle to the dataset collection. No schema is passed, so this assumes the
# collection already exists in DATABASE_NAME (presumably populated elsewhere).
collection = Collection(name=COLLECTION_NAME)

## Average Sequence Length

In [None]:
# Local checkpoint directory; only its tokenizer files are loaded in this notebook.
MODEL_PATH = Path('../../models') / 'Mistral-7B-Instruct-v0.2'

In [None]:
# Tokenizer used to count tokens per conversation. BOS/EOS insertion is turned
# off so each count covers only the conversation text itself.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    use_fast=True,
    add_bos_token=False,
    add_eos_token=False,
    padding_side='right',
)

In [None]:
# Stream only the 'conversation' field, 100 rows per batch, instead of pulling
# the whole collection into memory at once (no filter expression is given).
res = collection.query_iterator(output_fields=['conversation'], batch_size=100)

In [None]:
def update_live_graph(data: list[dict], ax, all_lengths=None):
    """Redraw ``ax`` as a histogram of token counts for a batch of rows.

    Args:
        data: Rows from the Milvus query iterator; each row must have a
            ``'conversation'`` entry (presumably a string) that is passed to
            the module-level ``tokenizer``.
        ax: Matplotlib axes to redraw in place.
        all_lengths: Optional list that accumulates lengths across calls.
            When given, this batch's lengths are appended to it and the
            histogram shows every length collected so far rather than just
            the current batch. Defaults to ``None`` (plot this batch only).

    Returns:
        The token count of each conversation in ``data``, in row order.
    """
    batch_lengths = [
        len(tokenizer(row['conversation'])['input_ids']) for row in data
    ]
    if all_lengths is not None:
        all_lengths.extend(batch_lengths)
        plotted = all_lengths
    else:
        plotted = batch_lengths

    ax.clear()
    ax.hist(plotted, bins=30, edgecolor='black', alpha=0.7, color='blue')
    # ax.clear() also removes the axis labels and title, so restore them on
    # every redraw; otherwise each live frame is unlabeled.
    ax.set_xlabel('Length of input_ids')
    ax.set_ylabel('Frequency')
    ax.set_title('Distribution of Lengths of input_ids')
    return batch_lengths

In [None]:
%matplotlib inline
fig, ax = plt.subplots()
plt.xlabel('Length of input_ids')
plt.ylabel('Frequency')
plt.title('Distribution of Lengths of input_ids')

while True:
    next_data = res.next()
    if next_data:
        update_live_graph(next_data, ax)
        clear_output(wait=True)
        display(fig)
        time.sleep(10)
    else: 
        break

In [None]:
# NOTE(review): hard-coded placeholder values, not token lengths gathered from
# the collection (the lengths computed in update_live_graph stay local to that
# function). The histogram drawn from this list therefore shows sample data
# only; replace it with real lengths before drawing conclusions.
conversations_len = [2, 6, 7, 3]

In [None]:
%matplotlib inline
plt.figure(figsize=(10, 6))
plt.xlabel('Length of input_ids')
plt.ylabel('Frequency')
plt.title('Distribution of Lengths of input_ids')

plt.hist(conversations_len, bins=30, edgecolor='black', alpha=0.7, color='blue')
plt.show()