# Input prompt feature extraction

In [1]:
import pandas as pd
import numpy as np
from collections import Counter

import spacy 
import textstat
from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer


from sklearn.linear_model import LogisticRegression
from sklearn.feature_selection import mutual_info_classif


import matplotlib.pyplot as plt

## I) Preprocessing

### a) Data Loading task 1

In [2]:
generation_file_task1 = 'dataset/task1_merged_generation.jsonl'
eval_result_task1 = 'dataset/task1_merged_hallucination_scores.jsonl'


generation_task1_df = pd.read_json(generation_file_task1, lines=True)
evaluation_task1_df = pd.read_json(eval_result_task1, lines=True)

In [3]:
generation_task1_df

Unnamed: 0,title,h_score_cat,pageid,revid,description,categories,reference,prompt,answer,generation,run_id
0,Jan Willem Spruyt,0,16859978,1154166135,"Schoolteacher, South African civil servant, la...","[1826 births, 1908 deaths, 19th-century Dutch ...",Jan Willem Spruyt (4 July 1826 – 8 September 1...,What position did Jan Willem Spruyt hold in th...,Government secretary.,Jan Willem Spruyt served as the State Attorney...,task_1_run_1
1,HMS Prize,0,58453794,1102135374,German schooner captured and refitted as a Bri...,"[Q-ships of the Royal Navy, World War I naval ...",## Service as a Q ship First Prize underwent a...,What was the new name given to First Prize by ...,HMS Prize,"The vessel originally known as ""First Prize"" w...",task_1_run_1
2,Mississippi Highway 902,0,58127285,1061978358,Highway in Mississippi,"[State highways in Mississippi, Transportation...","Mississippi Highway 902 (MS 902, also known as...",What is the northern terminus of Mississippi H...,MS 35 in Lorena.,Mississippi Highway 902 (MS 902) does not have...,task_1_run_1
3,St. Elmo (1910 Thanhouser film),1,45082337,1168621248,,"[1910 drama films, 1910 films, 1910s American ...",## Plot\n\nThe plot was adapted from Augusta J...,Who is betrothed to St. Elmo in the film?,Agnes,"In the 1985 film ""St. Elmo's Fire,"" there is n...",task_1_run_1
4,Oddfellows (album),1,37222693,1154749417,,"[2013 albums, Ipecac Recordings albums, Tomaha...",## Release\n\nOddfellows was released on Janua...,What date was the album Oddfellows released?,"January 29, 2013","The album ""Oddfellows"" by Tomahawk was release...",task_1_run_1
...,...,...,...,...,...,...,...,...,...,...,...
1761,Sara Ramirez,7,1999305,1173275338,American actor and singer,"[1975 births, 20th-century American LGBT peopl...","## Personal life\n\nOn June 27, 2011, Ramírez ...",What organization did Sara Ramirez donate thei...,Locks of Love,"In September 2016, Sara Ramirez donated their ...",task_1_run_9
1762,Khalili Collection of Japanese Art,7,63416118,1153676727,Private collection of Meiji-era art,"[Japanese art collectors, Khalili Collections]",The Khalili Collection of Japanese Art is a pr...,What is the era of Japanese art featured in th...,Meiji-era (1868–1912),The Khalili Collection of Japanese Art feature...,task_1_run_9
1763,Jimmy Greaves,8,231068,1173499110,English footballer (1940–2021),"[1940 births, 1962 FIFA World Cup players, 196...","Greaves was fit to play in the final, but mana...",Who presented Jimmy Greaves with his World Cup...,Gordon Brown,Jimmy Greaves was presented with his World Cup...,task_1_run_9
1764,4X,8,705620,1168877758,Genre of strategy-based video and board games,"[4X games, Real-time strategy video games, Str...",### Complexity\n\n4X games are known for their...,What aspect of gameplay is prioritized over gr...,Gameplay,"In 4X games, gameplay tends to prioritize stra...",task_1_run_9


In [4]:
evaluation_task1_df

Unnamed: 0,run_id,halu_test_res
0,task_1_run_1,True
1,task_1_run_1,False
2,task_1_run_1,True
3,task_1_run_1,True
4,task_1_run_1,False
...,...,...
1761,task_1_run_9,False
1762,task_1_run_9,False
1763,task_1_run_9,False
1764,task_1_run_9,False


In [5]:
data_task1 = pd.concat([generation_task1_df, evaluation_task1_df], axis=1)
data_task1

Unnamed: 0,title,h_score_cat,pageid,revid,description,categories,reference,prompt,answer,generation,run_id,run_id.1,halu_test_res
0,Jan Willem Spruyt,0,16859978,1154166135,"Schoolteacher, South African civil servant, la...","[1826 births, 1908 deaths, 19th-century Dutch ...",Jan Willem Spruyt (4 July 1826 – 8 September 1...,What position did Jan Willem Spruyt hold in th...,Government secretary.,Jan Willem Spruyt served as the State Attorney...,task_1_run_1,task_1_run_1,True
1,HMS Prize,0,58453794,1102135374,German schooner captured and refitted as a Bri...,"[Q-ships of the Royal Navy, World War I naval ...",## Service as a Q ship First Prize underwent a...,What was the new name given to First Prize by ...,HMS Prize,"The vessel originally known as ""First Prize"" w...",task_1_run_1,task_1_run_1,False
2,Mississippi Highway 902,0,58127285,1061978358,Highway in Mississippi,"[State highways in Mississippi, Transportation...","Mississippi Highway 902 (MS 902, also known as...",What is the northern terminus of Mississippi H...,MS 35 in Lorena.,Mississippi Highway 902 (MS 902) does not have...,task_1_run_1,task_1_run_1,True
3,St. Elmo (1910 Thanhouser film),1,45082337,1168621248,,"[1910 drama films, 1910 films, 1910s American ...",## Plot\n\nThe plot was adapted from Augusta J...,Who is betrothed to St. Elmo in the film?,Agnes,"In the 1985 film ""St. Elmo's Fire,"" there is n...",task_1_run_1,task_1_run_1,True
4,Oddfellows (album),1,37222693,1154749417,,"[2013 albums, Ipecac Recordings albums, Tomaha...",## Release\n\nOddfellows was released on Janua...,What date was the album Oddfellows released?,"January 29, 2013","The album ""Oddfellows"" by Tomahawk was release...",task_1_run_1,task_1_run_1,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1761,Sara Ramirez,7,1999305,1173275338,American actor and singer,"[1975 births, 20th-century American LGBT peopl...","## Personal life\n\nOn June 27, 2011, Ramírez ...",What organization did Sara Ramirez donate thei...,Locks of Love,"In September 2016, Sara Ramirez donated their ...",task_1_run_9,task_1_run_9,False
1762,Khalili Collection of Japanese Art,7,63416118,1153676727,Private collection of Meiji-era art,"[Japanese art collectors, Khalili Collections]",The Khalili Collection of Japanese Art is a pr...,What is the era of Japanese art featured in th...,Meiji-era (1868–1912),The Khalili Collection of Japanese Art feature...,task_1_run_9,task_1_run_9,False
1763,Jimmy Greaves,8,231068,1173499110,English footballer (1940–2021),"[1940 births, 1962 FIFA World Cup players, 196...","Greaves was fit to play in the final, but mana...",Who presented Jimmy Greaves with his World Cup...,Gordon Brown,Jimmy Greaves was presented with his World Cup...,task_1_run_9,task_1_run_9,False
1764,4X,8,705620,1168877758,Genre of strategy-based video and board games,"[4X games, Real-time strategy video games, Str...",### Complexity\n\n4X games are known for their...,What aspect of gameplay is prioritized over gr...,Gameplay,"In 4X games, gameplay tends to prioritize stra...",task_1_run_9,task_1_run_9,False


In [6]:
data_task1 = data_task1[['prompt', 'generation', 'halu_test_res']]
data_task1

Unnamed: 0,prompt,generation,halu_test_res
0,What position did Jan Willem Spruyt hold in th...,Jan Willem Spruyt served as the State Attorney...,True
1,What was the new name given to First Prize by ...,"The vessel originally known as ""First Prize"" w...",False
2,What is the northern terminus of Mississippi H...,Mississippi Highway 902 (MS 902) does not have...,True
3,Who is betrothed to St. Elmo in the film?,"In the 1985 film ""St. Elmo's Fire,"" there is n...",True
4,What date was the album Oddfellows released?,"The album ""Oddfellows"" by Tomahawk was release...",False
...,...,...,...
1761,What organization did Sara Ramirez donate thei...,"In September 2016, Sara Ramirez donated their ...",False
1762,What is the era of Japanese art featured in th...,The Khalili Collection of Japanese Art feature...,False
1763,Who presented Jimmy Greaves with his World Cup...,Jimmy Greaves was presented with his World Cup...,False
1764,What aspect of gameplay is prioritized over gr...,"In 4X games, gameplay tends to prioritize stra...",False


### b) Data Loading task 3

In [7]:
generation_file_task3 = 'dataset/task3_merged_generation.jsonl'
eval_result_task3 = 'dataset/task3_merged_hallucination_scores.jsonl'


generation_task3_df = pd.read_json(generation_file_task3, lines=True)
evaluation_task3_df = pd.read_json(eval_result_task3, lines=True)

In [8]:
generation_task3_df

Unnamed: 0,prompt,name,type,generation,run_id,item_idx
0,I'm curious about the medicine bactocin gargle...,bactocin gargle,medicine,"Bactocin, often referred to in its generic for...",task_3_run_1,0
1,I'd like to learn more about the medicine gota...,gotam divicap,medicine,It seems there may be a typographical or trans...,task_3_run_1,1
2,Could you give me details about the medicine b...,bgmont-lc -kit,medicine,Bgmont-LC Kit is a medication commonly used fo...,task_3_run_1,2
3,What can you tell me about the medicine domici...,domicit cfc-free paste,medicine,Domicit CFC-Free Paste is a medication primari...,task_3_run_1,3
4,I'd like to learn more about the medicine drob...,drobest-m caplet,medicine,Drobest-M is a combination medication that typ...,task_3_run_1,4
...,...,...,...,...,...,...
2315,Can you share some details about the bacteria ...,Neptuniibacter balearica,bacteria,Neptuniibacter balearica is a species of bacte...,task_3_run_9,195
2316,I want to know more about the bacteria Limnoly...,Limnolyngbya fragilis,bacteria,**Limnolyngbya fragilis** is a species of cyan...,task_3_run_9,196
2317,What do you know about the bacteria Sacchariba...,Saccharibacillus naphthalenovorans,bacteria,**Saccharibacillus naphthalenovorans** is a sp...,task_3_run_9,197
2318,What do you know about the bacteria Grimontia ...,Grimontia fortis,bacteria,Grimontia fortis is a species of bacteria that...,task_3_run_9,198


In [9]:
evaluation_task3_df 

Unnamed: 0,model,false_acceptance_rate,refusal_rate,refusal_eval_raw,run_id,item_idx
0,gpt-4o-mini,0.825,0.175,True,task_3_run_1,0
1,gpt-4o-mini,0.825,0.175,True,task_3_run_1,1
2,gpt-4o-mini,0.825,0.175,False,task_3_run_1,2
3,gpt-4o-mini,0.825,0.175,False,task_3_run_1,3
4,gpt-4o-mini,0.825,0.175,False,task_3_run_1,4
...,...,...,...,...,...,...
2315,gpt-4o-mini,0.875,0.125,False,task_3_run_9,195
2316,gpt-4o-mini,0.875,0.125,False,task_3_run_9,196
2317,gpt-4o-mini,0.875,0.125,False,task_3_run_9,197
2318,gpt-4o-mini,0.875,0.125,False,task_3_run_9,198


In [10]:
data_task3 = pd.concat([generation_task3_df, evaluation_task3_df ], axis=1)
data_task3

Unnamed: 0,prompt,name,type,generation,run_id,item_idx,model,false_acceptance_rate,refusal_rate,refusal_eval_raw,run_id.1,item_idx.1
0,I'm curious about the medicine bactocin gargle...,bactocin gargle,medicine,"Bactocin, often referred to in its generic for...",task_3_run_1,0,gpt-4o-mini,0.825,0.175,True,task_3_run_1,0
1,I'd like to learn more about the medicine gota...,gotam divicap,medicine,It seems there may be a typographical or trans...,task_3_run_1,1,gpt-4o-mini,0.825,0.175,True,task_3_run_1,1
2,Could you give me details about the medicine b...,bgmont-lc -kit,medicine,Bgmont-LC Kit is a medication commonly used fo...,task_3_run_1,2,gpt-4o-mini,0.825,0.175,False,task_3_run_1,2
3,What can you tell me about the medicine domici...,domicit cfc-free paste,medicine,Domicit CFC-Free Paste is a medication primari...,task_3_run_1,3,gpt-4o-mini,0.825,0.175,False,task_3_run_1,3
4,I'd like to learn more about the medicine drob...,drobest-m caplet,medicine,Drobest-M is a combination medication that typ...,task_3_run_1,4,gpt-4o-mini,0.825,0.175,False,task_3_run_1,4
...,...,...,...,...,...,...,...,...,...,...,...,...
2315,Can you share some details about the bacteria ...,Neptuniibacter balearica,bacteria,Neptuniibacter balearica is a species of bacte...,task_3_run_9,195,gpt-4o-mini,0.875,0.125,False,task_3_run_9,195
2316,I want to know more about the bacteria Limnoly...,Limnolyngbya fragilis,bacteria,**Limnolyngbya fragilis** is a species of cyan...,task_3_run_9,196,gpt-4o-mini,0.875,0.125,False,task_3_run_9,196
2317,What do you know about the bacteria Sacchariba...,Saccharibacillus naphthalenovorans,bacteria,**Saccharibacillus naphthalenovorans** is a sp...,task_3_run_9,197,gpt-4o-mini,0.875,0.125,False,task_3_run_9,197
2318,What do you know about the bacteria Grimontia ...,Grimontia fortis,bacteria,Grimontia fortis is a species of bacteria that...,task_3_run_9,198,gpt-4o-mini,0.875,0.125,False,task_3_run_9,198


In [11]:
data_task3 = data_task3[['prompt', 'generation', 'refusal_eval_raw']]
data_task3 = data_task3.rename(columns={'refusal_eval_raw': 'halu_test_res'})
data_task3

Unnamed: 0,prompt,generation,halu_test_res
0,I'm curious about the medicine bactocin gargle...,"Bactocin, often referred to in its generic for...",True
1,I'd like to learn more about the medicine gota...,It seems there may be a typographical or trans...,True
2,Could you give me details about the medicine b...,Bgmont-LC Kit is a medication commonly used fo...,False
3,What can you tell me about the medicine domici...,Domicit CFC-Free Paste is a medication primari...,False
4,I'd like to learn more about the medicine drob...,Drobest-M is a combination medication that typ...,False
...,...,...,...
2315,Can you share some details about the bacteria ...,Neptuniibacter balearica is a species of bacte...,False
2316,I want to know more about the bacteria Limnoly...,**Limnolyngbya fragilis** is a species of cyan...,False
2317,What do you know about the bacteria Sacchariba...,**Saccharibacillus naphthalenovorans** is a sp...,False
2318,What do you know about the bacteria Grimontia ...,Grimontia fortis is a species of bacteria that...,False


In [12]:
dataset = pd.concat([data_task1, data_task3], axis=0, ignore_index=True)
dataset

Unnamed: 0,prompt,generation,halu_test_res
0,What position did Jan Willem Spruyt hold in th...,Jan Willem Spruyt served as the State Attorney...,True
1,What was the new name given to First Prize by ...,"The vessel originally known as ""First Prize"" w...",False
2,What is the northern terminus of Mississippi H...,Mississippi Highway 902 (MS 902) does not have...,True
3,Who is betrothed to St. Elmo in the film?,"In the 1985 film ""St. Elmo's Fire,"" there is n...",True
4,What date was the album Oddfellows released?,"The album ""Oddfellows"" by Tomahawk was release...",False
...,...,...,...
4081,Can you share some details about the bacteria ...,Neptuniibacter balearica is a species of bacte...,False
4082,I want to know more about the bacteria Limnoly...,**Limnolyngbya fragilis** is a species of cyan...,False
4083,What do you know about the bacteria Sacchariba...,**Saccharibacillus naphthalenovorans** is a sp...,False
4084,What do you know about the bacteria Grimontia ...,Grimontia fortis is a species of bacteria that...,False


In [None]:
dataset.to_csv("dataset/full_dataset.csv", index=False)

## II) Text features extraction

### a) Spacy Analysis

In [None]:
prompt_df = dataset[['prompt']]
nlp = spacy.load("en_core_web_sm")

prompt_df

In [None]:
def spacy_feature(doc):
    
    # ------------- token filter ------------- 
    non_punct_token = [t for t in doc if not t.is_punct and not t.is_space]                 #Token that are not punctuation or space
    words = [t for t in non_punct_token if t.is_alpha]                                      #Token that are alphabetic words
    lemma = [t.lemma_.lower() for t in words]                                               #Lemma words: normalized form of a word
    
    # ------------- bacic count ------------- 
    num_sents = sum(1 for _ in doc.sents)                                                   #Number of sentence
    num_token = len(doc)                                                                    #Number of tokens
    
    # ------------- Richess vocabulary ------------- 
    lemma_counts = Counter(lemma)
    lemma_ratio = len(lemma_counts) / num_token                                             #Lemma ratio
    
    # ------------- POS ratios ------------- 
    pos = Counter(t.pos_ for t in doc)
    noun_ratio = pos["NOUN"] / len(doc)
    verb_ratio = pos["VERB"] / len(doc)
    adj_ratio = pos["ADJ"] / len(doc)
    adv_ratio = pos["ADV"] / len(doc)
    aux_ratio = pos['AUX'] / len(doc)

    # ------------- Punctuations -------------    
    stop_ratio = np.mean([t.is_stop for t in doc]) if len(doc) else 0.0                     #Ratio stop words in the sentence  
    punct_ratio = np.mean([t.is_punct for t in doc]) if len(doc) else 0.0                   #Ratio punctuation in the sentence  
    
    # ------------- Recognized entity ------------- 
    entity_count = len(doc.ents)
    entity_ratio = entity_count / len(doc)                                                  #Ratio recognized entity in the sentence
    
    features = {
        'number sentence' : num_sents,
        'lemma ratio' : lemma_ratio, 
        'noun ratio' : noun_ratio, 
        'verb ratio' : verb_ratio, 
        'adjective ratio' : adj_ratio, 
        'adverb ratio' : adv_ratio, 
        'auxiliairis ratio' : aux_ratio, 
        'stop ratio' : stop_ratio, 
        'punctuation ratio' : punct_ratio, 
        'entity ratio' : entity_ratio}
    
    return features

In [None]:
docs = nlp.pipe(prompt_df["prompt"].tolist())

spacy_features_df = [spacy_feature(doc) for doc in docs]
features_df = pd.DataFrame(spacy_features_df)
df_with_spacy_features = pd.concat([prompt_df.reset_index(drop=True), features_df], axis=1)


df_with_spacy_features

### b) TextStat Analysis

In [None]:
def textstat_feature(doc):
    # ------------- Readibility ------------- 
    flesch_score = textstat.flesch_reading_ease(doc)                        #flesh score
    us_school_grade = textstat.flesch_kincaid_grade(doc)                    #US School grade level  
    complexity = textstat.gunning_fog(doc)                                  #Complexity of the sentence
    
    # ------------- Words difficulty ------------- 
    nb_rare_words = textstat.difficult_words(doc)                           #Number difficulties words
    dale_chall_score = textstat.dale_chall_readability_score(doc)           #Dale CHall Score
    
    features = {
        'flesch score' : flesch_score,
        'grade (US School)' : us_school_grade,
        'complexity' : complexity,
        'Number rare words' : nb_rare_words,
        'Dale-Chall Score' : dale_chall_score
    }
    
    return features

In [None]:
docs = prompt_df["prompt"].tolist()

textstat_features_df = [textstat_feature(doc) for doc in docs]


textstat_features_df = pd.DataFrame(textstat_features_df)
df_with_textstat_features = pd.concat([df_with_spacy_features.reset_index(drop=True), textstat_features_df], axis=1)


df_with_textstat_features 

### c) Sentiment Analysis

In [None]:
def sentiment_analysis_feature(doc):
    sentiment_analyzer = SentimentIntensityAnalyzer()
    features = sentiment_analyzer.polarity_scores(doc)
    
    rename_feature = {
        'negative sentiment' : features['neg'],
        'neutral sentiment' : features['neu'],
        'positive sentiment' : features['pos'],
        'overall sentiment' : features['compound']
    }
    return rename_feature

In [None]:
docs = prompt_df["prompt"].tolist()

sentiments_features_df = [sentiment_analysis_feature(doc) for doc in docs]


sentiments_features_df = pd.DataFrame(sentiments_features_df)
df_with_sentiment_features = pd.concat([df_with_textstat_features .reset_index(drop=True), sentiments_features_df], axis=1)


df_with_sentiment_features

### d) Semantic Embeddings

In [None]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('all-mpnet-base-v2')


In [None]:
prompts = prompt_df['prompt'].tolist()

emb_space = model.encode(
    prompts, 
    normalize_embeddings=True,
    batch_size=32,
    show_progress_bar=True
)

In [None]:
emb_df = pd.DataFrame(emb_space, columns=[f"emb_{i}" for i in range(emb_space.shape[1])])
emb_df

### e) Hallucination output

In [None]:
hallucination_output = dataset['halu_test_res'].astype(int)
hallucination_output

In [None]:
input_df = pd.concat([df_with_sentiment_features.reset_index(drop=True), hallucination_output], axis=1)

input_df = input_df.rename(columns={'halu_test_res' : 'hallucination score'})
input_df

In [None]:
input_df_embedding = pd.concat([emb_df.reset_index(drop=True), hallucination_output], axis=1)
input_df_embedding

In [None]:
input_df = input_df.drop(columns={'prompt'})
input_df

## III) Training, validation, testing dataset

In [None]:
from sklearn.model_selection import train_test_split

X = input_df.drop(columns={'hallucination score'})
y = input_df['hallucination score']

#Training and testing
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.30, random_state=524, stratify=y
)

#Validation
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.50, random_state=524, stratify=y_temp
)

## IV) Logistic regression 

In [None]:
model_logistic_regression = LogisticRegression(max_iter = 10000)
model_logistic_regression.fit(X, y)

In [None]:
importance_logistic_regression = pd.Series(model_logistic_regression.coef_[0], index=X.columns).sort_values(ascending=False)
importance_logistic_regression 

In [None]:
plt.figure()
plt.barh(importance_logistic_regression.index, importance_logistic_regression.values)
plt.title('Feature importance Logistic regression')
plt.show()

### a) Correlation matrix

In [None]:
corr_matrix = input_df.corr()  
corr_matrix = corr_matrix['hallucination score']
corr_matrix = corr_matrix.drop(['hallucination score', 'number sentence'])
corr_matrix = corr_matrix.sort_values(ascending=False)
corr_matrix

In [None]:
plt.figure()
plt.barh(corr_matrix.index, corr_matrix.values)
plt.title('Feature correlation with hallucination')
plt.show()

### b) Mutual Information

In [None]:
mi = mutual_info_classif(X, y)
mi_df = pd.Series(mi, index=X.columns).sort_values(ascending=False)
mi_df

In [None]:
plt.figure()
plt.barh(mi_df.index, mi_df.values)
plt.title('Mutual Information between features and hallucination')
plt.show()

## IV) Random Forest

In [None]:
from sklearn.ensemble import RandomForestClassifier

rf = RandomForestClassifier(
    n_estimators=400,
    max_depth=8,            
    min_samples_split=40,    
    min_samples_leaf=20,     
    max_features=0.3,        
    class_weight='balanced',
    random_state=524,
)
rf.fit(X_train, y_train)


### a) Check Overfitting and cross validation

In [None]:
from sklearn.metrics import log_loss

train_loss = log_loss(y_train, rf.predict_proba(X_train))
val_loss   = log_loss(y_val,   rf.predict_proba(X_val))

print("Train log-loss:", train_loss)
print("Val log-loss  :", val_loss)

In [None]:
from sklearn.model_selection import cross_val_score

cv_scores = cross_val_score(
    rf, X, y, cv=5, scoring='accuracy'
)

print(cv_scores)
print("Mean:", cv_scores.mean())
print("Std:", cv_scores.std())


In [None]:
from sklearn.model_selection import learning_curve

train_sizes, train_scores, val_scores = learning_curve(
    rf,
    X,
    y,
    cv=5,
    scoring="accuracy",
    n_jobs=-1
)

train_mean = train_scores.mean(axis=1)
val_mean   = val_scores.mean(axis=1)

plt.plot(train_sizes, train_mean, label="Train score")
plt.plot(train_sizes, val_mean, label="Val score")
plt.xlabel("Training samples")
plt.ylabel("Accuracy")
plt.title("Learning Curve")
plt.legend()
plt.show()

 - Unstable model
 - No overfitting
 - cross validation is bad: the dataset is volatile

## V) Xgboost

In [None]:
from xgboost import XGBClassifier

xgb = XGBClassifier(
    n_estimators=2000,
    max_depth=5,
    learning_rate=0.01,
    subsample=0.8,
    colsample_bytree=0.8,
    eval_metric="logloss",
    objective="binary:logistic",
    random_state=524,
)

xgb.fit(
    X_train, y_train,
    eval_set=[(X_val, y_val)],
    verbose=True
)



### a) Overfitting check and cross validation

In [None]:
train_pred_proba = xgb.predict_proba(X_train)
val_pred_proba   = xgb.predict_proba(X_val)

print("Train log-loss:", log_loss(y_train, train_pred_proba))
print("Val log-loss  :", log_loss(y_val,   val_pred_proba))


In [None]:
from sklearn.model_selection import  StratifiedKFold

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

cv_scores = cross_val_score(
    xgb,
    X, y,
    cv=cv,
    scoring="accuracy",
    n_jobs=-1
)

print("CV Scores:", cv_scores)
print("Mean:", cv_scores.mean())
print("Std:", cv_scores.std())

In [None]:
train_sizes, train_scores, val_scores = learning_curve(
    xgb,
    X,
    y,
    cv=5,
    scoring="accuracy",
)

train_mean = train_scores.mean(axis=1)
val_mean   = val_scores.mean(axis=1)

plt.plot(train_sizes, train_mean, label="Train score")
plt.plot(train_sizes, val_mean, label="Val score")
plt.xlabel("Training samples")
plt.ylabel("Accuracy")
plt.title("Learning Curve")
plt.legend()
plt.show()

Issue from the quality of the features and not the model