# **Imports &#8595;**

In [2]:
import numpy as np
import pandas as pd 
import re

from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer

from sentence_transformers import InputExample
from torch.utils.data import DataLoader
from sentence_transformers import models, SentenceTransformer
from sentence_transformers.losses import DenoisingAutoEncoderLoss
from sentence_transformers.datasets import SentencesDataset
from imblearn.over_sampling import RandomOverSampler

  from .autonotebook import tqdm as notebook_tqdm


KeyboardInterrupt: 

# **Parameters &#8595;**

In [None]:
# _______________________External Dataset______________________________________
EXTERNAL_DATA = True
# _______________________TSDAE Fine-Tuning_____________________________________
FINE_TUNED_MODEL = False
ALREADY_FINE_TUNED = True
output_path = "output/tsdae-model-math-similarity"
# _______________________Sentence Embeddings/TF-IDF____________________________
EMBEDDINGS = True
# model_name = 'sentence-transformers/all-MiniLM-L6-v2'
# model_name = 'sentence-transformers/all-MiniLM-L12-v2'
model_name = 'hkunlp/instructor-large'
# model_name = "Qwen/Qwen2.5-0.5B-Instruct"  
# model_name = 'math-similarity/Bert-MLM_arXiv-MP-class_zbMath'
# ________________________Oversampling_________________________________________
OVERSAMPLING = True
# ________________________Filtering_________________________________________
FILTERING = True
# ________________________Stratification_________________________________________
STRATIFICATION = False

# **Load Dataset &#8595;**

In [3]:
DATA_PATH = "datasets/eedi-mining-misconceptions-in-mathematics"
EXTERNAL_DATA_PATH = "datasets/eedi-external-dataset"

if not EXTERNAL_DATA:
    train_df = pd.read_csv(f'{DATA_PATH}/train.csv', index_col='QuestionId')
else:
    train_df = pd.read_csv(f'{EXTERNAL_DATA_PATH}/all_train.csv', index_col='QuestionId') #this contains the original dataset + an external dataset generated by a LLM

misconceptions_df = pd.read_csv(f'{DATA_PATH}/misconception_mapping.csv')

pd.options.display.max_colwidth = 300
display(train_df.head(5))
pd.options.display.max_colwidth = 50

Unnamed: 0_level_0,ConstructId,ConstructName,SubjectId,SubjectName,CorrectAnswer,QuestionText,AnswerAText,AnswerBText,AnswerCText,AnswerDText,MisconceptionAId,MisconceptionBId,MisconceptionCId,MisconceptionDId,source,MisconceptionAName,MisconceptionBName,MisconceptionCName,MisconceptionDName,OriginalQuestionId
QuestionId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
0,856.0,Use the order of operations to carry out calculations involving powers,33.0,BIDMAS,A,\[\r\n3 \times 2+4-5\r\n\]\r\nWhere do the brackets need to go to make the answer equal \( 13 \) ?,\( 3 \times(2+4)-5 \),\( 3 \times 2+(4-5) \),\( 3 \times(2+4-5) \),Does not need brackets,,,,1672.0,original,,,,"Confuses the order of operations, believes addition comes before multiplication",
1,1612.0,Simplify an algebraic fraction by factorising the numerator,1077.0,Simplifying Algebraic Fractions,D,"Simplify the following, if possible: \( \frac{m^{2}+2 m-3}{m-3} \)",\( m+1 \),\( m+2 \),\( m-1 \),Does not simplify,2142.0,143.0,2142.0,,original,"Does not know that to factorise a quadratic expression, to find two numbers that add to give the coefficient of the x term, and multiply to give the non variable term\r\n","Thinks that when you cancel identical terms from the numerator and denominator, they just disappear","Does not know that to factorise a quadratic expression, to find two numbers that add to give the coefficient of the x term, and multiply to give the non variable term\r\n",,
2,2774.0,Calculate the range from a list of data,339.0,Range and Interquartile Range from a List of Data,B,"Tom and Katie are discussing the \( 5 \) plants with these heights:\r\n\( 24 \mathrm{~cm}, 17 \mathrm{~cm}, 42 \mathrm{~cm}, 26 \mathrm{~cm}, 13 \mathrm{~cm} \)\r\nTom says if all the plants were cut in half, the range wouldn't change.\r\nKatie says if all the plants grew by \( 3 \mathrm{~cm} \)...",Only\r\nTom,Only\r\nKatie,Both Tom and Katie,Neither is correct,1287.0,,1287.0,1073.0,original,Believes if you changed all values by the same proportion the range would not change,,Believes if you changed all values by the same proportion the range would not change,Believes if you add the same value to all numbers in the dataset the range will change,
3,2377.0,Recall and use the intersecting diagonals properties of a rectangle,88.0,Properties of Quadrilaterals,C,The angles highlighted on this rectangle with different length sides can never be... ![A rectangle with the diagonals drawn in. The angle on the right hand side at the centre is highlighted in red and the angle at the bottom at the centre is highlighted in yellow.](),acute,obtuse,\( 90^{\circ} \),Not enough information,1180.0,1180.0,,1180.0,original,Does not know the properties of a rectangle,Does not know the properties of a rectangle,,Does not know the properties of a rectangle,
4,3387.0,Substitute positive integer values into formulae involving powers or roots,67.0,Substitution into Formula,A,The equation \( f=3 r^{2}+3 \) is used to find values in the table below. What is the value covered by the star? \begin{tabular}{|c|c|c|c|c|}\r\n\hline\( r \) & \( 1 \) & \( 2 \) & \( 3 \) & \( 4 \) \\\r\n\hline\( f \) & \( 6 \) & \( 15 \) & \( \color{gold}\bigstar \) & \\\r\n\hline\r\n\end{tabu...,\( 30 \),\( 27 \),\( 51 \),\( 24 \),,,,1818.0,original,,,,Thinks you can find missing values in a given table by treating the row as linear and adding on the difference between the first two values given.,


# **Data Preprocessing &#8595;**

In [4]:
def clean(example, columns):
    """
    Cleans the example from the Dataset
    Args:
        example: an example from the Dataset
        columns: columns that will be cleaned

    Returns: update example containing 'clean' columns

    """
    for col in columns:
        text = example[f'{col}']

        # Empty text
        if type(text) not in (str, np.str_) or text=='':
            example[f'clean_{col}'] = ''
            return example

        # 'text' from the example can be of type numpy.str_, let's convert it to a python str
        text = str(text).lower()

        # Clean the text
        text = re.sub("\"", " ", text) # removes the " from certain texts
        text = re.sub("\n", " ", text) # removes the multiple "\n" 
        text = re.sub(r"(\\\w+)(\W)", r" \1 \2", text) # matches with the LaTeX commands like "\hline{}",... and transforms them to " \hline {}"
        text = re.sub(r"([\(|\{|\[|\|])", r" \1", text) # matches every opening parenthesis types and puts spaces on their left
        text = re.sub(r"([\)|\}|\]])", r"\1 ", text) # matches every closing parenthesis types and puts spaces on their right
        text = re.sub(r"\\(?![a-zA-Z])", " ", text) # removes every backslash that is not the start of a LaTeX command
        text = re.sub(r"\( | \)", "", text) # removes the parentheses that appear sometimes from nowhere 
        text = re.sub(r"\[ | \]", "", text) # removes the parentheses that appear sometimes from nowhere
        
        text = re.sub(r" +", " ", text) # cleans the double spaces made by above substitutions
        # Update the example with the cleaned text
        example[f'clean_{col}'] = text.strip()
    return example

columns_to_clean = ['QuestionText', 'AnswerAText', 'AnswerBText', 'AnswerCText', 'AnswerDText']
train_df = train_df.apply(clean, axis = 1, columns = columns_to_clean)

# Adjust column order
new_order = ['ConstructId', 'ConstructName', 'SubjectId', 'SubjectName', 'CorrectAnswer']
for col in columns_to_clean:
    new_order.append(col)
    new_order.append(f'clean_{col}')
new_order.extend(['MisconceptionAId', 'MisconceptionBId', 'MisconceptionCId', 'MisconceptionDId'])
train_df = train_df[new_order]


display_train_df = train_df[['QuestionText', 'clean_QuestionText','AnswerAText', 'clean_AnswerAText', 'AnswerBText', 'clean_AnswerBText', 'AnswerCText', 'clean_AnswerCText', 'AnswerDText', 'clean_AnswerDText']]
pd.options.display.max_colwidth = 300
display(display_train_df.head(1))
pd.options.display.max_colwidth = 50

display(misconceptions_df.head(1))

Unnamed: 0_level_0,QuestionText,clean_QuestionText,AnswerAText,clean_AnswerAText,AnswerBText,clean_AnswerBText,AnswerCText,clean_AnswerCText,AnswerDText,clean_AnswerDText
QuestionId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
0,\[\r\n3 \times 2+4-5\r\n\]\r\nWhere do the brackets need to go to make the answer equal \( 13 \) ?,[\r 3 \times 2+4-5\r \r where do the brackets need to go to make the answer equal 13 ?,\( 3 \times(2+4)-5 \),3 \times (2+4) -5,\( 3 \times 2+(4-5) \),3 \times 2+ (4-5),\( 3 \times(2+4-5) \),3 \times (2+4-5),Does not need brackets,does not need brackets


Unnamed: 0,MisconceptionId,MisconceptionName
0,0,Does not know that angles in a triangle sum to...


# **Build Subject-Misconception Mapping**

In [5]:
def build_subject_to_misconception_mapping(train_df):
    """
    Create a mapping from each SubjectName to the set of MisconceptionIds linked in the training data.

    Args:
        train_df: DataFrame containing columns 'SubjectName' and 'Misconception[A/B/C/D]Id'.

    Returns:
        A dictionary where keys are subjects and values are sets of MisconceptionIds.
    """
    subject_to_misconceptions = {}
    for _, row in train_df.iterrows():
        subject = row['SubjectName']
        misconceptions = {
            row['MisconceptionAId'], 
            row['MisconceptionBId'], 
            row['MisconceptionCId'], 
            row['MisconceptionDId']
        }
        misconceptions = {m for m in misconceptions if not pd.isna(m)}  # Remove NaN values
        if subject not in subject_to_misconceptions:
            subject_to_misconceptions[subject] = set()
        subject_to_misconceptions[subject].update(misconceptions)
    return subject_to_misconceptions

subject_to_misconceptions = build_subject_to_misconception_mapping(train_df)

print("Subjects and their linked misconceptions (sample):")
for subject, misconceptions in list(subject_to_misconceptions.items())[:5]:
    print(f"Subject: {subject}, Misconceptions: {misconceptions}")

Subjects and their linked misconceptions (sample):
Subject: BIDMAS, Misconceptions: {2306.0, 706.0, 2181.0, 1862.0, 1672.0, 328.0, 1416.0, 907.0, 2316.0, 77.0, 1805.0, 15.0, 524.0, 657.0, 2326.0, 1880.0, 217.0, 27.0, 2140.0, 2270.0, 1054.0, 1507.0, 2532.0, 1316.0, 2488.0, 1510.0, 1828.0, 234.0, 1963.0, 1516.0, 1642.0, 1207.0, 1400.0, 1597.0, 2175.0}
Subject: Simplifying Algebraic Fractions, Misconceptions: {2307.0, 1540.0, 1610.0, 143.0, 792.0, 1755.0, 2398.0, 2142.0, 2078.0, 353.0, 1825.0, 167.0, 1256.0, 363.0, 113.0, 891.0, 1535.0}
Subject: Range and Interquartile Range from a List of Data, Misconceptions: {1349.0, 1287.0, 2119.0, 2346.0, 1677.0, 397.0, 1073.0, 691.0, 2551.0, 2456.0, 1177.0}
Subject: Properties of Quadrilaterals, Misconceptions: {1348.0, 1940.0, 1877.0, 85.0, 1752.0, 1180.0, 226.0, 423.0, 551.0, 106.0, 1007.0, 1009.0, 1393.0, 2355.0, 2357.0, 2102.0, 1917.0, 1790.0, 2493.0}
Subject: Substitution into Formula, Misconceptions: {1792.0, 641.0, 643.0, 389.0, 1417.0, 533.0

# **Reshape Dataset For Training &#8595;**

In [6]:
# train_df columns: QuestionID, ConstructID, ConstructName, CorrectAnswer, SubjectId, SubjectName, QuestionText, Answer[A/B/C/D]Text, Misconception[A/B/C/D]Id

reshaped_data = []
for _, row in train_df.iterrows():
    for answer, misconception_id in zip(
        ['clean_AnswerAText', 'clean_AnswerBText', 'clean_AnswerCText', 'clean_AnswerDText'],
        ['MisconceptionAId', 'MisconceptionBId', 'MisconceptionCId', 'MisconceptionDId']
    ): # turn the data into a format where each datapoint (row) represents an answer choice (i.e there are now 4 datapoints for each question)
        misc_id = int(row[misconception_id]) if not pd.isna(row[misconception_id]) else row[misconception_id]
        reshaped_data.append({
            'QuestionText': row['clean_QuestionText'],
            'AnswerText': row[answer],
            'MisconceptionId': misc_id,
            'MisconceptionText': misconceptions_df.loc[misconceptions_df['MisconceptionId'] == misc_id, 'MisconceptionName'].values[0] if not pd.isna(misc_id) else misc_id,
            'SubjectName': row['SubjectName'],
            'ConstructName': row['ConstructName']
        })

reshaped_df = pd.DataFrame(reshaped_data)
display(reshaped_df.head())

# removed columns: QuestionId, ConstructId, CorrectAnswer, SubjectId
# other changes: Answer[A/B/C/D]Text are now in separate datapoints along with their associated Misconception[A/B/C/D]Texts 

Unnamed: 0,QuestionText,AnswerText,MisconceptionId,MisconceptionText,SubjectName,ConstructName
0,[\r 3 \times 2+4-5\r \r where do the brackets ...,3 \times (2+4) -5,,,BIDMAS,Use the order of operations to carry out calcu...
1,[\r 3 \times 2+4-5\r \r where do the brackets ...,3 \times 2+ (4-5),,,BIDMAS,Use the order of operations to carry out calcu...
2,[\r 3 \times 2+4-5\r \r where do the brackets ...,3 \times (2+4-5),,,BIDMAS,Use the order of operations to carry out calcu...
3,[\r 3 \times 2+4-5\r \r where do the brackets ...,does not need brackets,1672.0,"Confuses the order of operations, believes add...",BIDMAS,Use the order of operations to carry out calcu...
4,"simplify the following, if possible: \frac {m^...",m+1,2142.0,Does not know that to factorise a quadratic ex...,Simplifying Algebraic Fractions,Simplify an algebraic fraction by factorising ...


# **Sentence Embeddings / TF-IDF & OneHot Encoding&#8595;**

In [7]:
# remove NaN values (dropping all datapoints that do not have misconceptions assigned to them)
# P.S. that means we are also deleting all the rows (answer choices) that are correct
# P.P.S. unless somehow there are correct answers that have misconceptions associated with them
print(f"NaN values: {reshaped_df['MisconceptionId'].isnull().sum()}")  # 10582 NaN values yikes :/
reshaped_df = reshaped_df.dropna(subset=['MisconceptionId'])
print(f"NaN values: {reshaped_df['MisconceptionId'].isnull().sum()}")  # 0 now yippie
print(f"Dataset Shape: {reshaped_df.shape}")

10582
0


In [9]:
# Combine QuestionText and AnswerText into a single text

reshaped_df['CombinedText'] = reshaped_df['QuestionText'] + " " + reshaped_df['AnswerText']

In [None]:
# TF-IDF

if not EMBEDDINGS:
    vectorizer = TfidfVectorizer(max_features=5000) 
    X_processed = vectorizer.fit_transform(reshaped_df['CombinedText'])

In [None]:
def embed_text(text, model):
    """
    Create sentence embeddings for the given text using the given model.

    Args:
        text: String to be embedded.
        model: SentenceTransformer model.

    Returns:
        A numpy array of shape (embedding_size,) containing the sentence embedding.
    """
    return model.encode(text, convert_to_tensor=True)

In [10]:
# TSDAE training

if model_name != 'hkunlp/instructor-large' and not ALREADY_FINE_TUNED and EMBEDDINGS: # TSDAE doesn't work with the instructor-large model
    word_embedding_model = models.Transformer(model_name)
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    tsdae_model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
    train_loss = DenoisingAutoEncoderLoss(tsdae_model, tie_encoder_decoder=True)

    train_examples = [
        InputExample(texts=[row['CombinedText'], row['MisconceptionText']])
        for _, row in reshaped_df.iterrows()
    ]

    train_dataset = SentencesDataset(train_examples, model=None)  
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)

    print("Starting training")
    tsdae_model.fit(
        train_objectives=[(train_dataloader, train_loss)],
        epochs=1, 
        warmup_steps=reshaped_df.shape[1] * 0.15,  
        optimizer_params={'lr': 1e-5},  
        output_path=output_path
    )

In [None]:
if FINE_TUNED_MODEL and EMBEDDINGS:
    model = SentenceTransformer(output_path)
    X_processed = np.array([embed_text(text, model) for text in reshaped_df['CombinedText']])
elif EMBEDDINGS:
    model = SentenceTransformer(model_name)
    X_processed = np.array([embed_text(text, model) for text in reshaped_df['CombinedText']])

In [11]:
# use One hot encoding for categorical data (create a "column" for each unique subject and construct and represent each row with 0 and 1)

encoder = OneHotEncoder(sparse_output=False)
categorical_features = encoder.fit_transform(reshaped_df[['SubjectName', 'ConstructName']])

# **Stratification & Oversampling&#8595;**

In [22]:
X = np.hstack([X_processed, categorical_features])
y = reshaped_df['MisconceptionText']

if OVERSAMPLING:
    oversampler = RandomOverSampler(random_state=42)
    X_final, y_final = oversampler.fit_resample(X, y)
    reshaped_df = reshaped_df.loc[oversampler.sample_indices_].reset_index(drop=True)
else:
    reshaped_df = reshaped_df.reset_index(drop=True)
    X_final, y_final = X, y

print(f"X_final shape: {X_final.shape}")
print(f"y_final shape: {y_final.shape}")

if STRATIFICATION:
    X_train, X_val, y_train, y_val, train_meta, test_meta = train_test_split(X_final, y_final, reshaped_df, stratify=y, test_size=0.1, random_state=42)
else:
    X_train, X_val, y_train, y_val, train_meta, test_meta = train_test_split(X_final, y_final, reshaped_df, test_size=0.1, random_state=42)

(86616, 1688)
(86616,)


# **Random Forest Training&#8595;**

In [23]:
rf_classifier = RandomForestClassifier(n_estimators=100, random_state=42)
rf_classifier.fit(X_train, y_train)

# **Filtering&#8595;**

In [14]:
def filter_predictions_by_subject(y_pred_probs, test_subjects, misconceptions_by_subject):
    """
    Filter predictions by subject to prioritize misconceptions linked to each subject.
    If a subject is not found, all misconceptions are eligible.

    Args:
        y_pred_probs: Predicted probabilities for each class (misconceptions).
        test_subjects: List of SubjectName values for the test set.
        misconceptions_by_subject: Dictionary mapping subjects to MisconceptionIds.

    Returns:
        Filtered and normalized probabilities.
    """
    filtered_probs = []
    for probs, subject in zip(y_pred_probs, test_subjects):
        # Use all misconceptions if the subject is not in the mapping
        subject_misconceptions = misconceptions_by_subject.get(subject, set(rf_classifier.classes_))
        # Zero out probabilities for misconceptions not linked to the subject
        filtered_prob = [probs[j] if j in subject_misconceptions else 0 for j in range(len(probs))]
        # Normalize probabilities
        filtered_prob = np.array(filtered_prob) / np.sum(filtered_prob) if np.sum(filtered_prob) > 0 else probs
        filtered_probs.append(filtered_prob)
    return np.array(filtered_probs)

# **Evaluation Metrics&#8595;**

In [15]:
def map_at_25(y_true, y_pred_probs, top_k=25):
    """
    Gives the Mean Average Precision at k for the given predictions.

    Args:
        y_true: True labels for each class (misconceptions).
        y_pred_probs: Predicted probabilities for each class (misconceptions).
        top_k: Number of top predictions to consider.

    Returns:
        Mean Average Precision at k.
    """
    
    map_25 = 0.0
    for true_label, pred_prob in zip(y_true, y_pred_probs):
        # Get top_k predictions
        top_preds = np.argsort(pred_prob)[::-1][:top_k]
        
        if not true_label:
            continue
        
        score = 0.0
        hits = 0
        for i, pred in enumerate(top_preds, start=1):
            if pred == true_label:
                hits += 1
                score += hits / i  # Precision at i
        
        # Average Precision at 25
        map_25 += score / min(1, top_k)
    
    return map_25 / len(y_true)

def ndcg_at_25(y_true, y_pred_probs, k=25):
    """
    Gives the nDCG at k for the given predictions.

    Args:
        y_true: True labels for each class (misconceptions).
        y_pred_probs: Predicted probabilities for each class (misconceptions).
        top_k: Number of top predictions to consider.

    Returns:
        nDCG at k.
    """
    ndcg = 0.0
    for true_label, pred_prob in zip(y_true, y_pred_probs):
        top_preds = np.argsort(pred_prob)[::-1][:k]
        if not true_label:
            continue

        dcg = 0.0
        for i, pred in enumerate(top_preds, start=1):
            if pred == true_label:
                dcg += 1 / np.log2(i + 1)  # Discounted gain

        ideal_dcg = 1 / np.log2(1 + 1)  # Ideal DCG when correct at rank 1
        ndcg += dcg / ideal_dcg

    return ndcg / len(y_true)

def precision_at_25(y_true, y_pred_probs, k=25):
    """
    Gives the Precision at k for the given predictions.

    Args:
        y_true: True labels for each class (misconceptions).
        y_pred_probs: Predicted probabilities for each class (misconceptions).
        top_k: Number of top predictions to consider.

    Returns:
        Precision at k.
    """
    precision = 0.0
    for true_label, pred_prob in zip(y_true, y_pred_probs):
        top_preds = np.argsort(pred_prob)[::-1][:k]
        if not true_label:
            continue

        correct = 1 if true_label in top_preds else 0
        precision += correct / k

    return precision / len(y_true)

def recall_at_25(y_true, y_pred_probs, k=25):
    """
    Gives the Recall at k for the given predictions.

    Args:
        y_true: True labels for each class (misconceptions).
        y_pred_probs: Predicted probabilities for each class (misconceptions).
        top_k: Number of top predictions to consider.

    Returns:
        Recall at k.
    """
    recall = 0.0
    for true_label, pred_prob in zip(y_true, y_pred_probs):
        top_preds = np.argsort(pred_prob)[::-1][:k]
        if not true_label:
            continue

        correct = 1 if true_label in top_preds else 0
        recall += correct

    return recall / len(y_true)

def f1_at_25(y_true, y_pred_probs, k=25):
    """
    Gives the F1 Score at k for the given predictions.

    Args:
        y_true: True labels for each class (misconceptions).
        y_pred_probs: Predicted probabilities for each class (misconceptions).
        top_k: Number of top predictions to consider.

    Returns:
        F1 Score at k.
    """
    precision = precision_at_25(y_true, y_pred_probs, k)
    recall = recall_at_25(y_true, y_pred_probs, k)
    if precision + recall == 0:
        return 0.0
    return 2 * (precision * recall) / (precision + recall)

# **Testing&#8595;**

In [24]:
y_val_pred_probs = rf_classifier.predict_proba(X_val)  
y_val_true = list(y_val)
y_val_true_id = []

# Map misconception names to their corresponding ids for y_val_true
for val in y_val_true:
    matching_row = misconceptions_df[misconceptions_df['MisconceptionName'] == val]
    if not matching_row.empty:  # Check if a matching row exists
        misconception_id = matching_row.iloc[0]['MisconceptionId']  # Get the MisconceptionId
        y_val_true_id.append(misconception_id)
    else:
        y_val_true_id.append(None)

map25_score = map_at_25(y_val_true_id, y_val_pred_probs)
ndcg_score = ndcg_at_25(y_val_true_id, y_val_pred_probs, k=25)
precision_score = precision_at_25(y_val_true_id, y_val_pred_probs, k=25)
recall_score = recall_at_25(y_val_true_id, y_val_pred_probs, k=25)
f1_score = f1_at_25(y_val_true_id, y_val_pred_probs, k=25)

# Print Unfiltered Scores
if FILTERING:
    print("Unfiltered Scores:")
else:
    print("Scores:")
print(f"MAP@25 Score: {map25_score}")
print(f"NDCG@25: {ndcg_score}")
print(f"Precision@25: {precision_score}")
print(f"Recall@25: {recall_score}")
print(f"F1@25: {f1_score}")

if FILTERING:
    test_subjects = test_meta['SubjectName'].values
    # Apply subject-based filtering
    filtered_y_pred_probs = filter_predictions_by_subject(
        y_pred_probs=y_val_pred_probs,
        test_subjects=test_subjects,
        misconceptions_by_subject=subject_to_misconceptions
    )

    print(f"size of unfiltered predictions: {y_val_pred_probs.shape}")
    print(f"size of filtered predictions: {filtered_y_pred_probs.shape}")
    
    map25_score_filtered = map_at_25(y_val_true_id, filtered_y_pred_probs)
    ndcg_score_filtered = ndcg_at_25(y_val_true_id, filtered_y_pred_probs, k=25)
    precision_score_filtered = precision_at_25(y_val_true_id, filtered_y_pred_probs, k=25)
    recall_score_filtered = recall_at_25(y_val_true_id, filtered_y_pred_probs, k=25)
    f1_score_filtered = f1_at_25(y_val_true_id, filtered_y_pred_probs, k=25)
    
    print("\nFiltered Scores:")
    print(f"MAP@25 Score (Filtered): {map25_score_filtered}")
    print(f"NDCG@25 (Filtered): {ndcg_score_filtered}")
    print(f"Precision@25 (Filtered): {precision_score_filtered}")
    print(f"Recall@25 (Filtered): {recall_score_filtered}")
    print(f"F1@25 (Filtered): {f1_score_filtered}")

Unfiltered Scores:
MAP@25 Score: 0.0010010992054441375
NDCG@25: 0.002652953923816565
Precision@25: 0.0003648118217501734
Recall@25: 0.009120295543754329
F1@25: 0.0007015611956734104

Filtered Scores:
MAP@25 Score (Filtered): 0.0010588225949615699
NDCG@25 (Filtered): 0.002695561895004344
Precision@25 (Filtered): 0.0003648118217501734
Recall@25 (Filtered): 0.009120295543754329
F1@25 (Filtered): 0.0007015611956734104
