In [31]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split


# Load the MMLU questions with their MASJ annotations (`masj_complexity`, `masj_rating` columns).
df = pd.read_csv(
    "/mnt/data202/PERSONAL/VYAZHEV/PROJECT_1/full_masj_split/mmlu_masj_education_levels.tsv",
    sep="\t",
    header=0,
    escapechar="\\",
)

# Keep only rows rated 9 or higher. `.copy()` makes `df` an independent frame, so the
# label rewrites in the following cell do not raise SettingWithCopyWarning.
MIN_MASJ_RATING = 9
df = df[df['masj_rating'] >= MIN_MASJ_RATING].copy()
df

Unnamed: 0,src,answer,options,category,question,cot_content,question_id,answer_index,total_tokens,meta_cluster,base_cluster,masj_complexity,masj_rating
0,ori_mmlu-jurisprudence,C,['There is no distinction between the two form...,law,Which of the following criticisms of Llewellyn...,,1286,2,81,Legal Interpretation,Legal Theory Interpretations,graduate,9.0
1,ori_mmlu-international_law,E,"['Article 19', 'Article 11', 'Article 12', 'Ar...",law,Which of the following articles are not qualif...,,1293,4,38,Legal Interpretation,Constitutional Law,undergraduate,9.0
2,ori_mmlu-management,D,"['Work delegation', 'Workload balancing', 'Wor...",business,As what is ensuring that one individual does n...,,83,3,49,Economics & Finance MCQs,Business & Marketing Queries,high_school_and_easier,9.0
3,stemez-Business,J,"['$308.25', '$142.75', '$199.99', '$225.85', '...",business,Margaret Denault recently rented a truck to dr...,,94,9,118,Economics & Finance MCQs,Business Finance Questions,high_school_and_easier,10.0
4,stemez-Business,I,"['$60,000', '$43,200', '$1,794', '$25,000', '$...",business,The tax rate in the town of Centerville is 11(...,,104,8,102,Economics & Finance MCQs,Business Finance Questions,high_school_and_easier,9.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
12027,ori_mmlu-high_school_macroeconomics,F,['Higher interest rates that result from borro...,economics,"The ""crowding-out"" effect refers to which of t...",,7681,5,150,Economics & Finance MCQs,Economic Concepts & Policies,undergraduate,9.0
12028,ori_mmlu-high_school_macroeconomics,A,['Lower reserve requirements; lower the discou...,economics,Which of the following lists contains only Fed...,,7683,0,124,Economics & Finance MCQs,Economic Concepts & Policies,undergraduate,9.0
12029,ori_mmlu-high_school_macroeconomics,I,['The productivity of labor in country X is 75...,economics,Output in country X is 30000 units and there a...,,7684,8,206,Economics & Finance MCQs,Economic Concepts & Policies,high_school_and_easier,9.0
12030,ori_mmlu-high_school_macroeconomics,B,"['an increase in net exports', 'a decrease in ...",economics,A use of easy money (expansionary) policy by t...,,7685,1,58,Economics & Finance MCQs,Economic Concepts & Policies,undergraduate,9.0


In [32]:
# Collapse 'graduate' and 'postgraduate' into one 'graduate_and_postgraduate' level.
# Rebinding through .assign (instead of two in-place .loc writes) avoids
# SettingWithCopyWarning on the filtered frame and keeps the cell idempotent.
df = df.assign(
    masj_complexity=df['masj_complexity'].replace({
        'graduate': 'graduate_and_postgraduate',
        'postgraduate': 'graduate_and_postgraduate',
    })
)
df['masj_complexity'].value_counts()

masj_complexity
undergraduate                6736
graduate_and_postgraduate    2453
high_school_and_easier       2378
Name: count, dtype: int64

In [33]:

# Hold out 10% of the filtered data as the shared test set; the remainder feeds
# the per-level train/valid export below.
test_path = "/mnt/data202/PERSONAL/VYAZHEV/PROJECT_1/full_masj_split/test_combined_masj.tsv"
train_valid_df, test_df = train_test_split(df, test_size=0.1, random_state=42)
test_df.to_csv(test_path, sep="\t", index=False)
# Report the path that was actually written.
print(f"Тестовый датасет: {len(test_df)} примеров сохранён в '{test_path}'.")

# NOTE(review): this row order is what the seeded .sample() calls downstream draw
# from, so changing the sort changes which rows are selected.
train_valid_df = train_valid_df.sort_values(by="masj_complexity", ascending=False)
N = len(train_valid_df)
print(f"Всего обучающих+валидационных примеров: {N}")

def split_and_save_data(
    df,
    complexity_col='masj_complexity',
    thresholds=('undergraduate', 'graduate_and_postgraduate', 'high_school_and_easier'),
    test_size=0.1,
    random_state=42,
    output_prefix="",
    n_samples=1000,
    ):
    """Sample a fixed number of rows per complexity level and write train/valid TSVs.

    For each level in ``thresholds``, ``n_samples`` rows whose ``complexity_col``
    equals that level are drawn without replacement, re-indexed from 0, split with
    ``train_test_split`` and written to ``{output_prefix}train_df_{level}.tsv`` and
    ``{output_prefix}valid_df_{level}.tsv`` (tab-separated, no index column).

    Parameters
    ----------
    df : pandas.DataFrame
        Source rows; must contain ``complexity_col``.
    complexity_col : str
        Column holding the complexity label.
    thresholds : iterable of str
        Complexity levels to export, one train/valid file pair per level, in order.
    test_size : float or int
        Validation share, passed to ``train_test_split``.
    random_state : int
        Seed used for both the per-level sample and the train/valid split.
    output_prefix : str
        Prepended verbatim to every output filename (e.g. a directory ending in "/").
    n_samples : int
        Number of rows drawn per level (default 1000).

    Raises
    ------
    ValueError
        If any requested level has fewer than ``n_samples`` rows (including a level
        that does not occur at all). Raised before any file is written.
    """
    # Sample every level up front, so a level that is too small fails before any
    # file is written, instead of leaving a partial set of outputs behind.
    sampled_levels = []
    for level in thresholds:
        level_rows = df[df[complexity_col] == level]
        if len(level_rows) < n_samples:
            raise ValueError(
                f"Level {level!r} has {len(level_rows)} rows in column "
                f"{complexity_col!r}; cannot sample {n_samples} without replacement."
            )
        sampled = level_rows.sample(n_samples, random_state=random_state).reset_index(drop=True)
        sampled_levels.append((level, sampled))

    for level, sampled in sampled_levels:
        train, valid = train_test_split(sampled, test_size=test_size, random_state=random_state)
        train.to_csv(f"{output_prefix}train_df_{level}.tsv", sep='\t', index=False)
        valid.to_csv(f"{output_prefix}valid_df_{level}.tsv", sep='\t', index=False)

# Export 1000-row train/valid splits for each complexity level into the MASJ split directory.
split_and_save_data(
    train_valid_df,
    complexity_col='masj_complexity',
    output_prefix='/mnt/data202/PERSONAL/VYAZHEV/PROJECT_1/full_masj_split/',
)

Тестовый датасет: 1157 примеров сохранён в 'test.tsv'.
Всего обучающих+валидационных примеров: 10410


In [40]:
# Class distribution of the combined test set.
# NOTE(review): this reads test_combined_entr.tsv, while the hold-out cell writes
# test_combined_masj.tsv — confirm this is the intended file.
entr_test_path = "/mnt/data202/PERSONAL/VYAZHEV/PROJECT_1/full_masj_split/test_combined_entr.tsv"
test = pd.read_csv(entr_test_path, sep="\t")
test['masj_complexity'].value_counts()

masj_complexity
undergraduate                686
graduate_and_postgraduate    238
high_school_and_easier       233
Name: count, dtype: int64

In [41]:
# Build a class-balanced test set: downsample every complexity level to the size
# of the rarest level.
test = pd.read_csv(
    "/mnt/data202/PERSONAL/VYAZHEV/PROJECT_1/full_masj_split/test_combined_masj.tsv",
    sep="\t",
).dropna(subset='masj_complexity')

min_count = test['masj_complexity'].value_counts().min()

# Levels are visited in order of first appearance; each draw uses a fresh random_state=42.
balanced_df = pd.concat([
    test.loc[test['masj_complexity'] == level].sample(n=min_count, random_state=42)
    for level in test['masj_complexity'].unique()
])

# NOTE(review): unlike the other outputs of this notebook, this file goes under
# entropy_full_split/ rather than full_masj_split/ — confirm the directory.
balanced_df.to_csv(
    "/mnt/data202/PERSONAL/VYAZHEV/PROJECT_1/entropy_full_split/test_balanced_combined_masj.tsv",
    sep="\t",
    index=False,
)
    

In [43]:
# Sanity check of the balancing: each level is expected to show the same count.
balanced_df.masj_complexity.value_counts()

masj_complexity
graduate_and_postgraduate    233
undergraduate                233
high_school_and_easier       233
Name: count, dtype: int64

In [44]:
# Random baseline: 1000 rows drawn regardless of complexity, then split 90/10 into train/valid.
random_train_valid_df = train_valid_df.sample(n=1000, random_state=42)
train, valid = train_test_split(
    random_train_valid_df,
    test_size=0.1,
    random_state=42,
)

In [45]:
# Persist the random-baseline train/valid splits next to the per-level ones.
random_split_outputs = {
    "/mnt/data202/PERSONAL/VYAZHEV/PROJECT_1/full_masj_split/train_df_random.tsv": train,
    "/mnt/data202/PERSONAL/VYAZHEV/PROJECT_1/full_masj_split/valid_df_random.tsv": valid,
}
for out_path, split_frame in random_split_outputs.items():
    split_frame.to_csv(out_path, sep='\t', index=False)