In [1]:
from datasets import load_dataset
import polars as pl
import os
import cynde.functional as cf

In [2]:
# Parent of the current working directory.
above_dir = os.path.split(os.getcwd())[0]

# Cached datasets live in the "cache" folder inside that parent directory.
cache_dir = os.path.join(above_dir, "cache")
print(cache_dir)

c:\Users\Tommaso\Documents\Dev\Cynde\cache


In [3]:
# Dataset name; it doubles as the file-name stem inside the cache directory.
dataset_name = "OpenHermes-2.5"
# Path prefix only (no extension): later cells append suffixes such as
# "_filtered.parquet" to it before reading or writing.
dataset_path = os.path.join(cache_dir, dataset_name)


In [4]:
# Load the filtered dataset from the cache and preview its first row.
filtered_path = f"{dataset_path}_filtered.parquet"
df = pl.read_parquet(filtered_path)
df.head(1)


conversations,source
list[struct[3]],str
"[{""human"",""Given the following list of words, categorize them into 5 categories by similarity: ['Heart', 'Liver', 'Kidney', 'Brain', 'Stomach', 'Skin', 'Lung', 'Muscle', 'Intestine', 'Bone', 'Pancreas', 'Blood', 'Eye', 'Ear', 'Nose', 'Mouth', 'Esophagus', 'Throat', 'Spinal Cord', 'Bladder']"",null}, {""gpt"",""Here are the given words categorized into 5 groups based on their similarity: 1. Vital Organs: - Heart - Liver - Kidney - Lung - Pancreas 2. Digestive System Components: - Stomach - Intestine - Esophagus - Throat - Mouth 3. Sensory Organs: - Eye - Ear - Nose 4. Support and Movement: - Bone - Muscle - Spinal Cord - Skin 5. Body Fluids and Waste Management: - Blood - Bladder"",null}, … {""gpt"",""1) Vital Organs: - Diseases or disorders affecting one vital organ can impact the functioning of others. For example, heart failure can lead to fluid buildup in the lungs, affecting their function. Liver disease can cause toxins to accumulate in the bloodstream, potentially affecting brain function. Kidney disease can lead to electrolyte imbalances, which can impact heart function. - Organ transplantation involves replacing a damaged organ with a healthy one from a donor. Examples include heart, liver, kidney, lung, and pancreas transplants. Challenges include finding compatible donors, managing immune system rejection, and addressing surgical complications. 2) Digestive System Components: - Common digestive disorders include gastroesophageal reflux disease (GERD), irritable bowel syndrome (IBS), inflammatory bowel disease (IBD), and peptic ulcers. Causes may include genetic factors, lifestyle choices, stress, infections, or immune system dysfunction. - The gut microbiome comprises trillions of microorganisms that aid in digestion, nutrient absorption, and immune system function. It can also influence mental health and contribute to the prevention or development of various diseases. 
3) Sensory Organs: - Optical illusions work by exploiting the brain's visual processing system, causing it to misinterpret sensory information. They reveal that our perception of reality can be influenced by factors such as context, contrast, and patterns. - Animals' sensory organs can differ significantly from humans', offering unique capabilities. For example, dogs have a heightened sense of smell, eagles have exceptional long-distance vision, and bats use echolocation for navigation and foraging. 4) Support and Movement: - Exercise and physical activity promote bone and muscle health by stimulating bone growth, increasing bone density, and strengthening muscles. This helps prevent age-related bone loss and reduces the risk of injury. - Spinal cord injuries can have varying effects depending on the location and severity of the injury. Injuries can result in temporary or permanent loss of sensation, mobility, and other bodily functions. Paraplegia (paralysis of the legs and lower body) and quadriplegia (paralysis of all four limbs) are examples of potential outcomes of spinal cord injuries. 5) Body Fluids and Waste Management: - Blood types affect transfusions and organ transplants because of the presence of specific antigens on red blood cells. Incompatible blood types can cause an immune response, resulting in transfusion reactions or organ rejection. - Common bladder disorders include urinary incontinence, urinary tract infections (UTIs), interstitial cystitis, and bladder cancer. Treatments and management strategies can vary, including medication, lifestyle changes, physical therapy, or surgical interventions."",null}]","""caseus_custom"""


In [5]:
# Number of rows per value of the "source" column.
df.get_column("source").value_counts()

source,count
str,u32
"""caseus_custom""",2688
"""LMSys Chatbot …",3136


In [6]:
# Each message struct is serialised to JSON and the messages of a conversation
# are joined with newlines, so "conversations" becomes a single string per row.
conversations_as_text = (
    pl.col("conversations")
    .list.eval(pl.element().struct.json_encode())
    .list.join("\n")
)
# Boolean label: True for rows whose source is "caseus_custom".
is_caseus_custom = (pl.col("source") == "caseus_custom").alias("target")

# NOTE(review): this reassigns `df`, so re-running the cell operates on the
# already-flattened column — reload `df` first before running it again.
df = df.with_columns(conversations_as_text, is_caseus_custom)

In [7]:
# Reuse the cached embeddings when "<dataset_path>_embedded_small_large.parquet"
# is already on disk; otherwise embed the "conversations" column with both
# models and write the result to that file for the next run.
embedded_path = dataset_path + "_embedded_small_large.parquet"
if os.path.exists(embedded_path):
    embedded_df = pl.read_parquet(embedded_path)
else:
    embedded_df = cf.embed_columns(
        df,
        ["conversations"],
        models=["text-embedding-3-small", "text-embedding-3-large"],
    )
    embedded_df.write_parquet(embedded_path)

In [8]:
# Class balance of the boolean "target" label.
df.get_column("target").value_counts()

target,count
bool,u32
True,2688
False,3136


In [9]:
# Preview a plain 5-fold assignment (no grouping column) for the embedded frame.
cf.vanilla_kfold(
    embedded_df,
    k=5,
    group=None,
)

cv_index,fold_0,fold_1,fold_2,fold_3,fold_4
u32,str,str,str,str,str
4405,"""test""","""train""","""train""","""train""","""train"""
4389,"""test""","""train""","""train""","""train""","""train"""
1230,"""test""","""train""","""train""","""train""","""train"""
1493,"""test""","""train""","""train""","""train""","""train"""
3600,"""test""","""train""","""train""","""train""","""train"""
3649,"""test""","""train""","""train""","""train""","""train"""
775,"""test""","""train""","""train""","""train""","""train"""
2866,"""test""","""train""","""train""","""train""","""train"""
5200,"""test""","""train""","""train""","""train""","""train"""
5808,"""test""","""train""","""train""","""train""","""train"""


(df: pl.DataFrame, cv_type: Tuple[str, str], group_outer: List[str], k_outer: int, group_inner: List[str], k_inner: int, r_outer: int = 1, r_inner: int = 1):

In [10]:
# Embedding columns available as numerical features.
small_embedding = "conversations_text-embedding-3-small_embeddings"
large_embedding = "conversations_text-embedding-3-large_embeddings"

# Hyper-parameter settings to evaluate, keyed by classifier name.
random_forest_grid = [
    {"n_estimators": 10, "max_depth": 5},
    {"n_estimators": 50, "max_depth": 10},
]
models_dict = {"RandomForest": random_forest_grid}

# Feature sets to compare: small only, large only, then both together.
inputs = [
    {"numerical": list(columns)}
    for columns in (
        (small_embedding,),
        (large_embedding,),
        (small_embedding, large_embedding),
    )
]

In [11]:
# Nested cross-validation over every (feature set, model setting) combination.
# Both CV levels use the "stratified" scheme grouped on the "target" label,
# with 2 folds and a single replica each.
# NOTE(review): cv_type is presumably ordered (outer, inner) — confirm in cynde.
results, pred = cf.train_nested_cv(
    df=embedded_df,
    inputs=inputs,
    models=models_dict,
    cv_type=("stratified", "stratified"),
    group_outer=["target"],
    k_outer=2,
    r_outer=1,
    group_inner=["target"],
    k_inner=2,
    r_inner=1,
    save_name="test",
)


cv_type: stratified
cv_type: stratified
['cv_index', 'target', 'outer_stratified_target_replica_0_fold_0', 'outer_stratified_target_replica_0_fold_1', 'outer_stratified_target_replica_0_fold_0_inner_stratified_target_replica_0_fold_0', 'outer_stratified_target_replica_0_fold_0_inner_stratified_target_replica_0_fold_1', 'outer_stratified_target_replica_0_fold_1_inner_stratified_target_replica_0_fold_0', 'outer_stratified_target_replica_0_fold_1_inner_stratified_target_replica_0_fold_1']
Accuracy Test: 0.8664148351648352
Accuracy Val: 0.8681318681318682
MCC Test: 0.7315083611402093
MCC Val: 0.7345117219313307
Total CLS time: 00:00:00
Accuracy Test: 0.9217032967032966
Accuracy Val: 0.904532967032967
MCC Test: 0.8430540660769534
MCC Val: 0.8082364605363238
Total CLS time: 00:00:01
Accuracy Test: 0.8966346153846154
Accuracy Val: 0.8866758241758241
MCC Test: 0.7920470203546572
MCC Val: 0.7718582256466936
Total CLS time: 00:00:00
Accuracy Test: 0.9409340659340659
Accuracy Val: 0.9361263736263

In [14]:
# Build the summary table from the nested-CV results and show it as the
# cell output.
# NOTE(review): by_test_fold=True presumably controls how metrics are grouped
# across outer test folds — confirm against cynde's results_summary.
summary = cf.results_summary(results,by_test_fold=True)
summary

classifier,classifier_hp,input_features_name,r_outer,r_inner,mcc_train,mcc_val,mcc_test,accuracy_train,accuracy_val,accuracy_test,n
str,str,str,i64,i64,f64,f64,f64,f64,f64,f64,u32
"""RandomForest""","""n_estimators_5…","""conversations_…",0,0,1.0,0.872004,0.879111,1.0,0.936298,0.939904,4
"""RandomForest""","""n_estimators_5…","""conversations_…",0,0,0.999655,0.868678,0.872596,0.999828,0.934581,0.936641,4
"""RandomForest""","""n_estimators_5…","""conversations_…",0,0,0.998964,0.836953,0.838898,0.999485,0.918784,0.919815,4
"""RandomForest""","""n_estimators_1…","""conversations_…",0,0,0.934791,0.772683,0.772037,0.967548,0.886676,0.88659,4
"""RandomForest""","""n_estimators_1…","""conversations_…",0,0,0.91648,0.755234,0.75201,0.958448,0.878262,0.876631,4
"""RandomForest""","""n_estimators_1…","""conversations_…",0,0,0.901924,0.750729,0.724091,0.951236,0.87603,0.862895,4


In [13]:
# Print feature set, hyper-parameters and validation MCC for each summary row,
# with full (untruncated) strings and float precision.
report_columns = ["input_features_name", "classifier_hp", "mcc_val"]
for features_name, hyper_params, mcc_val in summary.select(report_columns).rows():
    print(features_name, hyper_params, mcc_val)

conversations_text-embedding-3-large_embeddings n_estimators_50_max_depth_10 0.8720044573894509
conversations_text-embedding-3-small_embeddings_conversations_text-embedding-3-large_embeddings n_estimators_50_max_depth_10 0.8686784946921543
conversations_text-embedding-3-small_embeddings n_estimators_50_max_depth_10 0.836953363619137
conversations_text-embedding-3-large_embeddings n_estimators_10_max_depth_5 0.7726826117399578
conversations_text-embedding-3-small_embeddings_conversations_text-embedding-3-large_embeddings n_estimators_10_max_depth_5 0.7552342209767535
conversations_text-embedding-3-small_embeddings n_estimators_10_max_depth_5 0.7507290804227095
