In [2]:
import pandas as pd
import numpy as np
from datetime import datetime
from pyarrow.feather import write_feather, read_feather

def convert_to_extension_dtypes(data):
    """
    Convert string and boolean data to pandas nullable extension dtypes.

    The caller's object is never modified: a DataFrame is converted on a copy,
    and a Series is converted through ``astype`` (which returns a new Series).

    Conversion rules:

    * DataFrame: string-like columns become ``"string"`` and boolean columns
      become ``"boolean"``. Integer and float columns are left unchanged
      (numeric conversion is disabled for DataFrames).
    * Series: integer -> ``"Int64"``, float -> ``"Float64"``,
      string-like -> ``"string"``, boolean -> ``"boolean"``. A Series that
      matches none of these is returned as-is.

    :param data: DataFrame or Series
    :return: DataFrame or Series with updated data types
    :raises ValueError: if ``data`` is neither a DataFrame nor a Series
    """
    if not isinstance(data, (pd.DataFrame, pd.Series)):
        raise ValueError("Input must be a pandas DataFrame or Series")

    dtype_checks = pd.api.types

    if isinstance(data, pd.DataFrame):
        # Work on a copy: assigning columns directly on `data` would silently
        # change the frame the caller passed in.
        converted = data.copy()
        for col in converted.columns:
            col_data = converted[col]
            if dtype_checks.is_string_dtype(col_data):
                converted[col] = col_data.astype("string")
            elif dtype_checks.is_bool_dtype(col_data):
                converted[col] = col_data.astype("boolean")
        return converted

    # Series: astype() already returns a new object, so no explicit copy is needed.
    if dtype_checks.is_integer_dtype(data):
        return data.astype("Int64")
    if dtype_checks.is_float_dtype(data):
        return data.astype("Float64")
    if dtype_checks.is_string_dtype(data):
        return data.astype("string")
    if dtype_checks.is_bool_dtype(data):
        return data.astype("boolean")
    return data


In [None]:
def get_obj_mni(threshold=None):
    '''Load the object-level MNI tables and combine them into one DataFrame.

    Eight feather files are read from ``data/``: one per K in (10, 25, 50, 100)
    for the "kick" method at the requested score threshold, and one per K for
    the "v3d" method, which always uses the fixed 0.5 threshold ("p50").
    In each table the ``mni`` column is renamed to ``mni_k{K}_{method}``, and
    the tables are inner-joined on ``obj`` in the order kick k10..k100, then
    v3d k10..k100.

    Args:
        threshold: score threshold as a fraction (e.g. 0.5 selects the "p50"
            kick files). When omitted, the notebook-level ``score_threshold``
            variable is used, so that variable must be defined before a
            zero-argument call.

    Returns:
        obj_mni: a DataFrame with one ``mni_k{K}_{method}`` column per loaded
            table, restricted to the ``obj`` values present in all of them.
    '''
    if threshold is None:
        # Keep the zero-argument call working exactly as before.
        threshold = score_threshold

    # Round away floating-point noise before truncating: 0.29 * 100 evaluates to
    # 28.999999999999996, so a bare int() would look for "p28" instead of "p29".
    # NOTE(review): assumes the files are named by the intended percentage;
    # confirm the code that writes them does not truncate the same way.
    kick_percent = int(round(threshold * 100, 9))

    # (method, file-name threshold suffix); order defines the merge/column order.
    sources = (("kick", f"p{kick_percent}"), ("v3d", "p50"))

    obj_mni = None
    for method, suffix in sources:
        for k in (10, 25, 50, 100):
            table = read_feather(f"data/obj_mni_{method}_k{k}_{suffix}.feather")
            table = table.rename(columns={"mni": f"mni_k{k}_{method}"})
            if obj_mni is None:
                obj_mni = table
            else:
                obj_mni = obj_mni.merge(table, on="obj", how="inner")

    return obj_mni

In [None]:
# Scratch example: a small frame with one float column and one integer column.
x = pd.DataFrame(
    {
        "a": [1.234, 2.3234234234, 3.123341312],
        "b": [4, 5, 6],
    }
)
# A bare name in the middle of a cell is evaluated but not displayed.
x

# Last expression of the cell: the frame rounded to two decimals (x itself is unchanged).
x.round(2)