# Getting started with OpenAssistant OASST1 data

- https://huggingface.co/datasets/OpenAssistant/oasst1

This notebook is based on the official Open-Assistant getting-started notebook: https://github.com/LAION-AI/Open-Assistant/blob/main/notebooks/openassistant-oasst1/getting-started.ipynb

## Imports and helper functions

In [1]:
import pandas as pd
from datasets import load_dataset
from treelib import Tree

# widen pandas' display limits so the wide OASST1 frames below are not truncated;
# set_option accepts any number of (option, value) pairs in one call
pd.set_option(
    "display.max_rows", 500,
    "display.max_columns", 500,
    "display.width", 1000,
)


def add_tree_level(df):
    """Return ``df`` with a ``tree_level`` column giving each message's depth in its tree.

    Root messages (missing ``parent_id``) get level 0, direct replies to the root
    (``parent_id == message_tree_id``) get level 1, and every other message gets its
    parent's level + 1.

    Levels are resolved by walking up the ``parent_id`` chain, so the result does not
    depend on the row order of ``df``: parents do not have to appear before children.

    Args:
        df: DataFrame with at least ``message_id``, ``parent_id`` and
            ``message_tree_id`` columns.

    Returns:
        ``df`` itself if it already has a ``tree_level`` column. Otherwise, a new
        DataFrame with the same rows in the same order, a fresh RangeIndex, and an
        extra integer ``tree_level`` column.

    Raises:
        KeyError: if a message's parent is neither in ``df`` nor the tree's root id.
        ValueError: if the ``parent_id`` links form a cycle.
    """
    # if tree level already exists, return df unchanged
    if "tree_level" in df.columns:
        return df

    parent_of = dict(zip(df["message_id"], df["parent_id"]))
    tree_of = dict(zip(df["message_id"], df["message_tree_id"]))
    levels = {}

    for message_id in df["message_id"]:
        # climb towards the root, remembering unresolved messages, until we hit a
        # message whose level is known (or can be decided without its parent)
        path = []
        seen = set()
        node = message_id
        while node not in levels:
            if node not in parent_of:
                raise KeyError(f"parent message {node!r} of {message_id!r} is not in df")
            if node in seen:
                raise ValueError(f"cycle in parent_id chain at message {node!r}")
            seen.add(node)

            parent_id = parent_of[node]
            if pd.isna(parent_id):
                # no parent (None or NaN): a root message
                levels[node] = 0
            elif parent_id == tree_of[node]:
                # message_tree_id is the root's id, so this is a direct reply to the
                # root (works even if the root row itself is absent from df)
                levels[node] = 1
            else:
                path.append(node)
                node = parent_id

        # assign levels back down the collected path, nearest ancestor first
        base_level = levels[node]
        for depth, child in enumerate(reversed(path), start=1):
            levels[child] = base_level + depth

    # map keeps row order; reset_index matches the RangeIndex the old merge produced
    return df.assign(tree_level=df["message_id"].map(levels)).reset_index(drop=True)

## Load Data

In [2]:
# load the OASST1 dataset from the Hugging Face Hub; on later runs the local cache is
# reused (see "Found cached dataset" in the output below).
# The result is a DatasetDict with "train" and "validation" splits.
ds = load_dataset("OpenAssistant/oasst1")
print(ds)

Found cached dataset parquet (C:/Users/timon/.cache/huggingface/datasets/OpenAssistant___parquet/OpenAssistant--oasst1-2960c57d7e52ab15/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    validation: Dataset({
        features: ['message_id', 'parent_id', 'user_id', 'created_date', 'text', 'role', 'lang', 'review_count', 'review_result', 'deleted', 'rank', 'synthetic', 'model_name', 'detoxify', 'message_tree_id', 'tree_state', 'emojis', 'labels'],
        num_rows: 4401
    })
    train: Dataset({
        features: ['message_id', 'parent_id', 'user_id', 'created_date', 'text', 'role', 'lang', 'review_count', 'review_result', 'deleted', 'rank', 'synthetic', 'model_name', 'detoxify', 'message_tree_id', 'tree_state', 'emojis', 'labels'],
        num_rows: 84437
    })
})


## Create Pandas Dataframe

In [3]:
# let's convert only the train split to a pandas DataFrame
# (the validation split stays in `ds`)
df = ds["train"].to_pandas()

In [4]:
# summarise the df: one line per column (verbose=True) with its non-null count
# (show_counts=True) and dtype, plus the frame's memory usage
df.info(verbose=True, memory_usage=True, show_counts=True)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 84437 entries, 0 to 84436
Data columns (total 18 columns):
 #   Column           Non-Null Count  Dtype  
---  ------           --------------  -----  
 0   message_id       84437 non-null  object 
 1   parent_id        74591 non-null  object 
 2   user_id          84437 non-null  object 
 3   created_date     84437 non-null  object 
 4   text             84437 non-null  object 
 5   role             84437 non-null  object 
 6   lang             84437 non-null  object 
 7   review_count     84437 non-null  int32  
 8   review_result    83732 non-null  object 
 9   deleted          84437 non-null  bool   
 10  rank             48730 non-null  float64
 11  synthetic        84437 non-null  bool   
 12  model_name       0 non-null      object 
 13  detoxify         72297 non-null  object 
 14  message_tree_id  84437 non-null  object 
 15  tree_state       84437 non-null  object 
 16  emojis           71496 non-null  object 
 17  labels      

In [5]:
# look at one random row as a nested {index: {column: value}} dict, which is easier
# to read than a very wide table row. Not seeded, so each run shows a different row.
df.sample(1).transpose().to_dict()

{5204: {'message_id': 'e0635603-09b2-4247-accb-104c1a9bcb4f',
  'parent_id': None,
  'user_id': '951513be-56c0-4630-88a2-9eeb3fb9e33b',
  'created_date': '2023-02-09T10:19:58.065721+00:00',
  'text': 'How do I get more yield from chili pepper plant?',
  'role': 'prompter',
  'lang': 'en',
  'review_count': 3,
  'review_result': True,
  'deleted': False,
  'rank': nan,
  'synthetic': False,
  'model_name': None,
  'detoxify': {'toxicity': 0.0004722727171611041,
   'severe_toxicity': 2.8104701414122246e-05,
   'obscene': 0.00013236129598226398,
   'identity_attack': 0.00012626768148038536,
   'insult': 0.00017111326451413333,
   'threat': 5.790913928649388e-05,
   'sexual_explicit': 2.3751312255626544e-05},
  'message_tree_id': 'e0635603-09b2-4247-accb-104c1a9bcb4f',
  'tree_state': 'ready_for_export',
  'emojis': {'name': array(['+1', '_skip_reply', '_skip_ranking'], dtype=object),
   'count': array([10,  6,  1])},
  'labels': {'name': array(['spam', 'lang_mismatch', 'pii', 'not_appropr

## Random Message Tree

In [6]:
# let's grab a random message tree.
# Sample from the *unique* tree ids so every tree is equally likely; sampling the raw
# column would favour trees with many messages. A fixed seed keeps the choice
# reproducible under Restart & Run All; change it to explore other trees.
RANDOM_SEED = 42
message_tree_id = df["message_tree_id"].drop_duplicates().sample(1, random_state=RANDOM_SEED).iloc[0]
print(message_tree_id)

4aa560ff-6c26-484d-b2e5-9dd5d7fd376a


In [7]:
# look at all data for this message tree, oldest message first.
# A boolean mask is used instead of an f-string `query`, so the id is never
# interpolated into (and cannot break) a query expression.
df_message_tree = df[df["message_tree_id"] == message_tree_id].sort_values("created_date")

# add a tree_level column (0 = root prompt, 1 = direct reply to it, ...)
df_message_tree = add_tree_level(df_message_tree)

df_message_tree

Unnamed: 0,message_id,parent_id,user_id,created_date,text,role,lang,review_count,review_result,deleted,rank,synthetic,model_name,detoxify,message_tree_id,tree_state,emojis,labels,tree_level
0,4aa560ff-6c26-484d-b2e5-9dd5d7fd376a,,3521d9bd-b5ec-45c8-94f3-2714e12255d0,2023-02-06T23:43:08.281289+00:00,Опиши алгоритм расчёта ядерной регрессии Надар...,prompter,ru,3,True,False,,False,,"{'toxicity': 0.008922765962779522, 'severe_tox...",4aa560ff-6c26-484d-b2e5-9dd5d7fd376a,ready_for_export,"{'name': ['+1', '_skip_reply', '_skip_ranking'...","{'name': ['spam', 'lang_mismatch', 'pii', 'not...",0
1,397d5d07-447f-4d83-bd63-028444e91919,4aa560ff-6c26-484d-b2e5-9dd5d7fd376a,97894afc-2204-4c7b-9b91-f1c71c10f032,2023-02-11T00:28:19.535048+00:00,Алгоритм ядерной регрессии Надарая-Уотсона - э...,assistant,ru,3,True,False,0.0,False,,"{'toxicity': 0.0006523863994516432, 'severe_to...",4aa560ff-6c26-484d-b2e5-9dd5d7fd376a,ready_for_export,"{'name': ['+1', '_skip_labeling'], 'count': [1...","{'name': ['spam', 'fails_task', 'lang_mismatch...",1
2,7bd2cbb7-fdad-45c0-88d3-70bf2fe68a29,4aa560ff-6c26-484d-b2e5-9dd5d7fd376a,d8f26101-f9a7-4bd6-bcc7-0667463c7787,2023-02-12T09:29:39.849548+00:00,Ядерная регрессия (калька с англ. “kernel regr...,assistant,ru,3,True,False,2.0,False,,"{'toxicity': 0.0963350236415863, 'severe_toxic...",4aa560ff-6c26-484d-b2e5-9dd5d7fd376a,ready_for_export,"{'name': ['+1', '_skip_labeling'], 'count': [1...","{'name': ['spam', 'fails_task', 'lang_mismatch...",1
3,7468092d-61a4-4087-b7ee-40c868d8c9e0,4aa560ff-6c26-484d-b2e5-9dd5d7fd376a,54156dd1-73ca-46dc-9207-75124a981566,2023-03-05T12:23:22.257298+00:00,Алгоритм расчета ядерной регрессии Надарая-Уот...,assistant,ru,3,True,False,1.0,False,,"{'toxicity': 0.000361143407644704, 'severe_tox...",4aa560ff-6c26-484d-b2e5-9dd5d7fd376a,ready_for_export,"{'name': ['_skip_labeling'], 'count': [2]}","{'name': ['spam', 'fails_task', 'lang_mismatch...",1


## Create Message Tree

In [8]:
# Build two treelib views of the same message tree:
#   id_tree   - nodes labelled by message_id (the full row is kept as node data)
#   text_tree - same structure, nodes labelled by a shortened message text
# Both trees use message_id as the node identifier. Truncated texts are not unique:
# two messages can share their first `max_char_len` characters, and treelib rejects
# duplicate identifiers. Keying by id also lets us use parent_id directly instead of
# looking up the parent's text.
id_tree = Tree()
text_tree = Tree()
# max number of characters of each message shown in text_tree
max_char_len = 100

# treelib needs a node's parent to exist before the node is created, so insert level
# by level; the stable sort keeps the existing (created_date) order within a level
for _, row in df_message_tree.sort_values("tree_level", kind="stable").iterrows():
    message_id = row["message_id"]
    # a missing parent (None or NaN) marks the root message
    parent_id = None if pd.isna(row["parent_id"]) else row["parent_id"]
    # slicing already handles texts shorter than max_char_len
    text_short = row["text"][:max_char_len].replace("\n", " ")

    # keep the row as node data in case we want it later
    id_tree.create_node(message_id, message_id, parent=parent_id, data=row.to_dict())
    # show the shortened text as the tag, but key the node by its message_id
    text_tree.create_node(text_short, message_id, parent=parent_id)


print("id_tree:")
id_tree.show()

print("text_tree:")
text_tree.show()

id_tree:
4aa560ff-6c26-484d-b2e5-9dd5d7fd376a
├── 397d5d07-447f-4d83-bd63-028444e91919
├── 7468092d-61a4-4087-b7ee-40c868d8c9e0
└── 7bd2cbb7-fdad-45c0-88d3-70bf2fe68a29

text_tree:
Опиши алгоритм расчёта ядерной регрессии Надарая — Уотсона
├── Алгоритм расчета ядерной регрессии Надарая-Уотсона включает в себя следующие шаги:  Загрузка данных:
├── Алгоритм ядерной регрессии Надарая-Уотсона - это непараметрический метод оценки функции регрессии пу
└── Ядерная регрессия (калька с англ. “kernel regression”) — непараметрический статистический метод, поз



In [16]:
# frequency table for each categorical-ish column, separated by a blank line
for column_name in ("synthetic", "model_name", "tree_state", "deleted", "rank", "lang"):
    print(df[column_name].value_counts(), end="\n\n")

False    84437
Name: synthetic, dtype: int64

Series([], Name: model_name, dtype: int64)

ready_for_export    84437
Name: tree_state, dtype: int64

False    82952
True      1485
Name: deleted, dtype: int64

0.0     17972
1.0     17971
2.0     11463
3.0       963
4.0       234
5.0        72
6.0        27
7.0        13
8.0         6
9.0         3
10.0        1
11.0        1
12.0        1
13.0        1
14.0        1
15.0        1
Name: rank, dtype: int64

en       39283
es       22763
ru        7242
zh        3314
de        3050
fr        2474
th        1460
pt-BR     1165
ca        1158
uk-UA      587
it         554
ja         363
pl         304
eu         250
vi         191
hu          75
ar          56
da          44
tr          37
ko          24
fi          18
id          12
cs          12
sv           1
Name: lang, dtype: int64

