In [1]:
import pandas as pd

# Load the raw training table used as the descriptor source for every fold.
df = pd.read_csv('/scratch/x3100a06/tRP/Data/train_source.csv')
# Move the compound name into a column called 'Substance' (Compound itself is removed).
df['Substance'] = df.pop('Compound')
# Discard metadata, the criterion/t_RP/p_Label columns and the cn* columns.
df = df.drop(columns=[
    'Source', 'Type',
    'criterion', 't_RP', 'p_Label',
    'cnA1', 'cnA2', 'cnB1', 'cnB2', 'cnX',
])

: 

In [2]:
# Canonicalise the halide order: the Rb2MnI2Cl2 entry is renamed to Rb2MnCl2I2 ...
df.loc[df['Substance'].eq('Rb2MnI2Cl2'), 'Substance'] = 'Rb2MnCl2I2'
# ... and every Rb2MnCl2I2 row gets X sites matching that name (X1 = Cl, X2 = I).
df.loc[df['Substance'].eq('Rb2MnCl2I2'), 'X1'] = 'Cl'
df.loc[df['Substance'].eq('Rb2MnCl2I2'), 'X2'] = 'I'

In [3]:
# Read the five cross-validation split files: fold_k/split_k.csv for k = 1..5.
fold_1, fold_2, fold_3, fold_4, fold_5 = (
    pd.read_csv(f'/scratch/x3100a06/tRP/SISSO/run/fold_{k}/split_{k}.csv')
    for k in range(1, 6)
)

In [4]:
tuple(len(frame) for frame in (df, fold_1, fold_2, fold_3, fold_4, fold_5))

(476, 476, 476, 476, 476, 476)

In [5]:
import pandas as pd

def merge_df(fold_df, index, base_df=None, output_path=None):
    """Attach descriptor columns to the training rows of one CV fold and save them.

    ``fold_df`` is outer-merged with the descriptor table on the composition keys.
    For non-key columns present in both tables, the value from ``fold_df`` wins and
    the descriptor-table value is only used to fill its gaps. Rows whose
    ``fold_{index}`` entry is missing are dropped, then only rows marked ``'train'``
    are kept, in the order they appear in ``fold_df``. The columns
    ``Substance rA rB rX nA nB nX xX`` are written as a space-separated file.

    Neither ``fold_df`` nor ``base_df`` is modified; all normalisation happens on copies.

    Parameters
    ----------
    fold_df : pandas.DataFrame
        Fold split containing the key columns and a ``fold_{index}`` column
        (rows equal to ``'train'`` are exported). A ``Type`` column, if present, is ignored.
    index : int
        Fold number; selects the ``fold_{index}`` column and the default output path.
    base_df : pandas.DataFrame, optional
        Descriptor table to merge with. Defaults to the notebook-level ``df``.
    output_path : str, optional
        Destination file. Defaults to
        ``/scratch/x3100a06/tRP/SISSO/run/fold_{index}/train_charge.dat``.

    Returns
    -------
    pandas.DataFrame
        All merged columns for the training rows, ``Substance`` first, in ``fold_df`` order.

    Raises
    ------
    KeyError
        If a key column or ``fold_{index}`` is missing.
    ValueError
        If duplicated keys multiply the training rows, or if any exported column has NaN.
    """
    if base_df is None:
        base_df = df  # notebook-level descriptor table, as in the original cell

    # Join keys: compound name, label, n, and the A/B/X site elements.
    common_keys = ['Substance', 'Label', 'n', 'A1', 'A2', 'B1', 'B2', 'X1', 'X2']
    text_keys = {'Substance', 'A1', 'A2', 'B1', 'B2', 'X1', 'X2'}
    fold_col = f'fold_{index}'

    # Work on copies so the caller's fold table and the shared descriptor table stay untouched.
    base = base_df.copy()
    fold = fold_df.drop(columns=['Type'], errors='ignore')

    # Normalise key dtypes on both sides so equal values compare equal:
    # strings are stripped, numeric keys are rounded to 5 decimals to absorb float noise.
    for col in common_keys:
        if col in text_keys:
            base[col] = base[col].astype(str).str.strip()
            fold[col] = fold[col].astype(str).str.strip()
        else:
            base[col] = base[col].astype(float).round(5)
            fold[col] = fold[col].astype(float).round(5)

    merged = base.merge(fold, on=common_keys, how='outer', suffixes=('', '_fold'))

    # For columns carried by both tables, prefer the fold's value and fall back to the base value.
    for col in fold.columns:
        fold_version = f'{col}_fold'
        if col not in common_keys and col in merged.columns and fold_version in merged.columns:
            merged[col] = merged[fold_version].combine_first(merged[col])
            merged = merged.drop(columns=[fold_version])

    # Keep only rows that have an assignment in this fold's column.
    merged = merged[merged[fold_col].notna()].reset_index(drop=True)

    # Consistency check: the fold and the merged table should list the same compounds.
    fold_substances = sorted(set(fold['Substance'].to_list()))
    merged_substances = sorted(set(merged['Substance'].to_list()))
    if fold_substances != merged_substances:
        print('Error: Mismatch in Substance values')
        print(f'fold_df_comp: {fold_substances}')
        print(f'merge_df_comp: {merged_substances}')

    # A stray 'train' column (if any) is not part of the output.
    if 'train' in merged.columns:
        merged = merged.drop(columns=['train'])

    # Restrict both tables to this fold's training rows.
    merged = merged[merged[fold_col] == 'train'].reset_index(drop=True)
    fold_train = fold[fold[fold_col] == 'train'].reset_index(drop=True)

    # Reorder to the fold's row order; reset_index puts 'Substance' back as the first column.
    merged = merged.set_index('Substance').loc[fold_train['Substance']].reset_index()

    # Duplicated keys in either table make the merge / label lookup return extra rows;
    # fail loudly instead of silently writing duplicated training samples.
    if len(merged) != len(fold_train):
        raise ValueError(
            f'fold_{index}: expected {len(fold_train)} training rows after merging, '
            f'got {len(merged)}. Check for duplicated Substance/key rows.'
        )

    final_columns = ['Substance', 'rA', 'rB', 'rX', 'nA', 'nB', 'nX', 'xX']

    # Refuse to write a file with missing descriptor values.
    if merged[final_columns].isna().any().any():
        nan_rows = merged[merged[final_columns].isna().any(axis=1)]
        print("Error: NaN values detected before saving:")
        print(nan_rows)
        raise ValueError("NaN values found in data. Check the above rows before proceeding.")

    if output_path is None:
        output_path = f'/scratch/x3100a06/tRP/SISSO/run/fold_{index}/train_charge.dat'
    merged[final_columns].to_csv(output_path, sep=" ", index=False)

    return merged

In [6]:
# Build and save the training table for each fold, one call per fold in order 1..5.
fold1 = merge_df(fold_df=fold_1, index=1)
fold2 = merge_df(fold_df=fold_2, index=2)
fold3 = merge_df(fold_df=fold_3, index=3)
fold4 = merge_df(fold_df=fold_4, index=4)
fold5 = merge_df(fold_df=fold_5, index=5)

In [7]:
# Display the merged fold-1 training table returned by merge_df.
fold1

Unnamed: 0,Substance,Label,n,A1,A2,B1,B2,X1,X2,rA,...,xA,xB,xX,iA,iB,iX,vA,vB,vX,fold_1
0,Cs2AgF4,1.0,1.0,Cs,Cs,Ag,Ag,F,F,1.780,...,0.154213,0.147217,0.344443,3.893906,7.576234,17.422820,1.0,11.0,7.0,train
1,Sr2RuO4,1.0,1.0,Sr,Sr,Ru,Ru,O,O,1.310,...,0.118508,0.137649,0.304575,5.694867,7.360500,13.618054,2.0,8.0,6.0,train
2,SrSmCrO4,1.0,1.0,Sr,Sm,Cr,Cr,O,O,1.221,...,0.151988,0.131305,0.304575,5.669289,6.766510,13.618054,5.0,6.0,6.0,train
3,Rb2ZnF4,1.0,1.0,Rb,Rb,Zn,Zn,F,F,1.630,...,0.104686,0.155152,0.344443,4.177128,9.394199,17.422820,1.0,12.0,7.0,train
4,Rb2NiF4,1.0,1.0,Rb,Rb,Ni,Ni,F,F,1.630,...,0.104686,0.147207,0.344443,4.177128,7.639877,17.422820,1.0,10.0,7.0,train
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
375,Zn2ZrO4,0.0,1.0,Zn,Zn,Zr,Zr,O,O,0.900,...,0.155152,0.124889,0.304575,9.394199,6.633900,13.618054,12.0,4.0,6.0,train
376,Zn2TiO4,0.0,1.0,Zn,Zn,Ti,Ti,O,O,0.900,...,0.155152,0.123364,0.304575,9.394199,6.828120,13.618054,12.0,4.0,6.0,train
377,LiMgAsO4,0.0,1.0,Li,Mg,As,As,O,O,0.905,...,0.113368,0.206821,0.304575,6.518975,9.789000,13.618054,1.5,5.0,6.0,train
378,Fe2GeS4,0.0,1.0,Fe,Fe,Ge,Ge,S,S,0.920,...,0.139253,0.189589,0.235960,7.902468,7.899435,10.360010,8.0,4.0,6.0,train
