In [1]:
%load_ext autoreload
%autoreload 2

In [6]:
import os
import shutil
import numpy as np
import pandas as pd


from data_processing_sherlock import DataProcessingSherlock

In [3]:
TABLE_FOLDER_PATH = "../tables"
GROUND_TRUTH_PATH   = "../gold/cta_gt.csv"
LANGUAGE_METADATA_PATH = "../gold/language_metadata.csv"
OUTPUT_PATH = "../sherlock_data_processing"

### Ground truth label processing
Modifying ground truth labels to one standard. Removing labels that counts is less than 7 since they are creating noise and could not be correctly disrtibuted to the train/test/val

In [4]:
LABEL_MAP = {
    # Date
    "Vaccination_date": "Date",
    "Date_report":"Date",
    "Date_onset":  "Date",
    "Date_confirmation": "Date",
    "Date_of_first_consultation":"Date",
    "Date_hospitalisation":  "Date",
    "Date_discharge_hospital": "Date",
    "Date_admission_ICU":   "Date",
    "Date_discharge_ICU":  "Date",
    "Date_isolation":  "Date",
    "Date_death":  "Date",
    "Date_recovered":  "Date",
    "Travel_history_entry": "Date",
    "Travel_history_start":  "Date",
    "Date_entry":  "Date",
    "Date_last_modified": "Date",

    # ID
    "Contact_ID": "ID",
    "ID": "ID",

    #Gender
    "Gender": "Gender",
    "Sex_at_birth": "Gender",
    "Gender_other": "Gender",
    "Sex_at_birth_other": "Gender",

    #Location
    "Travel_history_location": "Location",
    "Location_information": "Location",

    # Contact setting
    "Contact_setting": "Contact_setting",
    "Contact_setting_other": "Contact_setting",

    # demographic
    "Race": "Demographic",
    "Ehtnicity": "Demographic",

    # Medical Boolean
    "Healthcare_worker": "Medical_boolean",
    "Previous_infection": "Medical_boolean",
    "Pregnancy_Status": "Medical_boolean",
    "Vaccination":  "Medical_boolean",
    "Hospitalised":  "Medical_boolean",
    "Intensive_care":  "Medical_boolean",
    "Home_monitoring":  "Medical_boolean",
    "Isolated": "Medical_boolean",
    "Contact_with_case": "Medical_boolean",
    "Travel_history": "Medical_boolean",

    # Sourec
    "Source": "Source",
    "Source_II": "Source",
    "Source_III": "Source",
    "Source_IV": "Source",
}

In [5]:
LABEL_MAP_LC = {k.lower(): v.lower() for k, v in LABEL_MAP.items()}

def remap_labels(arr, mapping=LABEL_MAP_LC):
    """
    • forces each element to lower-case
    • replaces it if the key exists in `mapping`
    • otherwise leaves it as lower-case original
    """
    return np.array([mapping.get(x.lower(), x.lower()) for x in arr])

In [7]:
data = pd.read_csv("../gold/cta_gt.csv")

data['label_mapped'] = remap_labels(data['label'].values)

label_counts = data['label_mapped'].value_counts()

data = data.drop('label', axis=1)

# 4. find all labels with count < 7
small_labels = label_counts[label_counts < 7].index

# 5. set those labels to '__none__'
data.loc[data['label_mapped'].isin(small_labels), 'label_mapped'] = '__none__'

data = data.rename(columns={'label_mapped': 'label'})

data.to_csv("../gold/cta_gt.csv", index=False)

print(data["label"].value_counts())

label
__none__                  1185
date                       126
symptoms                   125
medical_boolean             66
location                    55
id                          53
case_status                 27
age                         24
gender                      23
pre_existing_condition      22
outcome                     20
occupation                  12
contact_setting             11
Name: count, dtype: int64


In [8]:
def data_cleaning(data):
    """
    Clean the data in dataframes, removes hashes tha was broken during processing

    :param data: dataframe to clean
    :return: cleaned dataframe
    """

    for col in data.select_dtypes(include="object"):

        # make a bool mask of rows where the cell contains 'x000D'
        mask = data[col].str.contains("x000D", na=False)
        data.loc[mask, col] = np.nan

    return data

In [9]:
data_processing_sherlock = DataProcessingSherlock()

if os.path.exists(OUTPUT_PATH):
    shutil.rmtree(OUTPUT_PATH)

for filename in os.listdir(TABLE_FOLDER_PATH):
    if not filename.lower().endswith(".csv"):
        continue

    table_csv_path = os.path.join(TABLE_FOLDER_PATH, filename)
    table_name = filename  # matches the 'table_name' column in GT

    try:
        # Example usage:
        data, labels = data_processing_sherlock.load_table_with_labels(
            table_csv_path=TABLE_FOLDER_PATH + "/" + filename,
            gt_csv_path   =GROUND_TRUTH_PATH,
            table_name    =table_name
        )

        #cleaned_data = data_cleaning(data)

        data_processing_sherlock.flatten_and_save(data, labels, OUTPUT_PATH, table_name, LANGUAGE_METADATA_PATH,
                                                  train_ratio=0.6, val_ratio=0.2, test_ratio=0.2)

        #print(f"Processed {filename} -> {data_path}")
    except Exception as e:
        print(f"Skipping {filename}: {e}")

Combined data length: 20
Combined labels length: 20
Combined lang length: 20
Combined data length: 26
Combined labels length: 26
Combined lang length: 26
Combined data length: 56
Combined labels length: 56
Combined lang length: 56
Combined data length: 73
Combined labels length: 73
Combined lang length: 73
Combined data length: 89
Combined labels length: 89
Combined lang length: 89
Combined data length: 118
Combined labels length: 118
Combined lang length: 118
Combined data length: 150
Combined labels length: 150
Combined lang length: 150
Combined data length: 165
Combined labels length: 165
Combined lang length: 165
Combined data length: 169
Combined labels length: 169
Combined lang length: 169
Combined data length: 219
Combined labels length: 219
Combined lang length: 219
Combined data length: 221
Combined labels length: 221
Combined lang length: 221
Combined data length: 223
Combined labels length: 223
Combined lang length: 223
Combined data length: 231
Combined labels length: 231
C

In [10]:
import pyarrow.parquet as pq
tbl = pq.read_table("../sherlock_data_processing/data.parquet")
print(tbl.schema)

__index_level_0__: int64
values: string


In [11]:
for row in tbl.column("values"):
    print(row.as_py(), type(row.as_py()))

e_4800uq0,046_q0u0e,ue_009q06,u_018q0e0,qe4_0060u,1_0q002eu,0u040_7eq,0ueq0_202,80u_00qe6,_q97eu000,0u_e000q8,004qu_0e7,u1_060e0q,25q0eu0_0,_900qe40u,q408e0_u0,6u000q0_e,u190q0e1_,1e0_u00q1,4e00_0u1q,0_u00q0e5,_00eq200u,5p0000_br,010_u7eq0,e1_0uq206,05u0e0_q6,010u2e_5q,ue80_70q0,80q_e110u,0uq0e07_0,e0800uq_1,_ue080q01,11qu_e010,_q102eu40,_301q0eu0,e0003q_u3,6_0u0q30e,u_003eq90,00u_e0q49,90_0e9qu0,000q_32ue,00uq5e_70,u_q0310e1,00q660e_u,0_9ue0q02,_qe83000u,e0u14_1q0,0u003e_q9,02_050qeu,805q0u0_e,070eq20u_,e90u8q0_0,uqe0_0060,q80_008ue,q2000eu0_,0u70eq06_,5_u0040qe,1003q0_ue,q00e400u_,e010_q90u,09eu_000q,0e700_qu0,2eq10u_00,_0u0qe404,e680u0_0q,0q57u0e_0,eq0050_u3,1q_e10u02,069e_0uq0,1_1uq000e,q090u_0e0,uq300e0_1,00u0e_q30,q3e200u0_,00u0q_16e,_u0qe0127,_e0qu1500,026_q0e0u,q20eu101_,u00_eq001 <class 'str'>
0C101DE5RTQ2,E15R0T0CQ52D,RC18E0T03Q2D,03QRTED050C0,02RE0QT1C4D1,0D3ERT002Q0C,T0CQDE00100R,RC0E2T7Q01D2,6E80DC09QR0T,0QR3T0D0E1C0,RD3TE012QC01,Q012260DRECT,101RE27QTDC0,QE32CDR0T106,72Q5

In [12]:
df = tbl.to_pandas()
print(df)

     __index_level_0__                                             values
0                    0  e_4800uq0,046_q0u0e,ue_009q06,u_018q0e0,qe4_00...
1                    1  0C101DE5RTQ2,E15R0T0CQ52D,RC18E0T03Q2D,03QRTED...
2                    2  bik094,mbk228,bik074,bik001,mbk461,mbk055,mbk2...
3                    3  89T1R7D9QK010BCME06--,410030ER-0QKBDCT2-6R,4CD...
4                    4  45,4,60,65,22,8,30,19,2,46,18,3,35,54,33,85,31...
..                 ...                                                ...
526                526  0R46-B4K000310,R020-BK4004610,970B09000-R0K0,-...
527                527  44,64,70,31,2S,8,63,1992,16,9M,39,20,15,57,77,...
528                528                                          m,f,A,F,M
529                529                                           Confirmé
530                530                    NON,Non Vacciné,VACCINE,Vacciné

[531 rows x 2 columns]


In [13]:
py_lists = [s.split(",") for s in tbl.column("values").to_pylist()]
for L in py_lists:
    print(L)

['e_4800uq0', '046_q0u0e', 'ue_009q06', 'u_018q0e0', 'qe4_0060u', '1_0q002eu', '0u040_7eq', '0ueq0_202', '80u_00qe6', '_q97eu000', '0u_e000q8', '004qu_0e7', 'u1_060e0q', '25q0eu0_0', '_900qe40u', 'q408e0_u0', '6u000q0_e', 'u190q0e1_', '1e0_u00q1', '4e00_0u1q', '0_u00q0e5', '_00eq200u', '5p0000_br', '010_u7eq0', 'e1_0uq206', '05u0e0_q6', '010u2e_5q', 'ue80_70q0', '80q_e110u', '0uq0e07_0', 'e0800uq_1', '_ue080q01', '11qu_e010', '_q102eu40', '_301q0eu0', 'e0003q_u3', '6_0u0q30e', 'u_003eq90', '00u_e0q49', '90_0e9qu0', '000q_32ue', '00uq5e_70', 'u_q0310e1', '00q660e_u', '0_9ue0q02', '_qe83000u', 'e0u14_1q0', '0u003e_q9', '02_050qeu', '805q0u0_e', '070eq20u_', 'e90u8q0_0', 'uqe0_0060', 'q80_008ue', 'q2000eu0_', '0u70eq06_', '5_u0040qe', '1003q0_ue', 'q00e400u_', 'e010_q90u', '09eu_000q', '0e700_qu0', '2eq10u_00', '_0u0qe404', 'e680u0_0q', '0q57u0e_0', 'eq0050_u3', '1q_e10u02', '069e_0uq0', '1_1uq000e', 'q090u_0e0', 'uq300e0_1', '00u0e_q30', 'q3e200u0_', '00u0q_16e', '_u0qe0127', '_e0qu1500'

## Labels

In [14]:
import pandas as pd
labels = pd.read_parquet("../sherlock_data_processing/labels.parquet")
print(labels.columns)
print(labels.shape)
print(labels[:5])
print(labels.index.name)

Index(['type'], dtype='object')
(531, 1)
  type
0   id
1   id
2   id
3   id
4  age
None


In [15]:
y_train = np.array([x.lower() for x in labels])
print(y_train)

['type']


### Unique labels

In [16]:
from itertools import count, groupby
import pandas as pd
data = pd.read_csv("../gold/cta_gt.csv")

unique_labels = data["label"].unique()
unique_count = len(unique_labels)

label_count = (data.groupby("label").size().reset_index(name="count").sort_values("count", ascending=False) )

print(label_count)


                     label  count
0                 __none__   1185
4                     date    126
12                symptoms    125
8          medical_boolean     66
7                 location     55
6                       id     53
2              case_status     27
1                      age     24
5                   gender     23
11  pre_existing_condition     22
10                 outcome     20
9               occupation     12
3          contact_setting     11
