In [1]:
import os
import torch
import wandb
import numpy as np
import pandas as pd
import warnings
import random
import re

from accelerate import infer_auto_device_map
from peft import prepare_model_for_kbit_training

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
    TrainerCallback,
    EarlyStoppingCallback,
    TrainerState,
    TrainerControl,
)
from datasets import load_dataset, Dataset
from trl import SFTTrainer, setup_chat_format
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)
import bitsandbytes as bnb
from sklearn.metrics import (accuracy_score, 
                             classification_report, 
                             confusion_matrix)

In [2]:
# huggingface-cli login --token key   
# wandb login --relogin key

In [3]:
warnings.filterwarnings("ignore")
#https://www.kaggle.com/datasets/rikdifos/credit-card-approval-prediction
# Load the application_record dataset
data = pd.read_csv("/opt/notebooks/Chatbot-Credit-Card/backend/dataset/credit-card-approval/application_record.csv")
# Load the credit_record dataset
record = pd.read_csv("/opt/notebooks/Chatbot-Credit-Card/backend/dataset/credit-card-approval/credit_record.csv")
# Find the first account open month for each user
begin_month = record.loc[record.groupby("ID")["MONTHS_BALANCE"].idxmin()]
begin_month = begin_month.rename(columns={"MONTHS_BALANCE": "begin_month"})

# Merge the datasets
df = pd.merge(data, begin_month, how="left", on="ID")
print("Datasets loaded and merged successfully.")
# Define approval logic based on multiple criteria
def determine_approval(row):
    # Define custom approval logic
    if row["STATUS"] in ["0", "1", "C", "X"]:  # Good credit status
            return 1  # Approved
    return 0  # Default to denial if STATUS is bad or missing

# Apply logic to determine approval (filling missing STATUS values first)
record["STATUS"] = record["STATUS"].fillna("X")  # Handle missing values
record["Approved"] = record.apply(determine_approval, axis=1)
# Aggregate approval status for each ID (disapproval if any ID has disqualifying criteria)
approval_status = record.groupby("ID")["Approved"].min().reset_index()
# Merge approval status back into the main dataset, avoiding "_x" and "_y" columns
df = pd.merge(data, approval_status, how="left", on="ID")
df["Approved"] = df["Approved"].fillna(0).astype(int)  # Fill missing approvals as denial
print("Approval status merged successfully.")
# Preprocess the 'DAYS_BIRTH' column to convert days to years
df['DAYS_BIRTH'] = (-df['DAYS_BIRTH'] // 365).fillna(0).astype(int)
df.drop(columns=['ID'], inplace=True)
# Preprocess the 'DAYS_EMPLOYED' column to get absolute values and handle unemployment
df['DAYS_EMPLOYED'] = df['DAYS_EMPLOYED'].apply(lambda x: abs(x) if x < 0 else 0)
# Handle missing or infinite values in numerical columns
numerical_cols = df.select_dtypes(include=['int64', 'float64']).columns
df[numerical_cols] = df[numerical_cols].replace([np.inf, -np.inf], np.nan)  # Replace infinities with NaN
df[numerical_cols] = df[numerical_cols].fillna(df[numerical_cols].median())  # Fill NaN with median values
print("Preprocessing completed successfully.")

# Define the feature mapping dictionary
feature_mapping = {
    'CODE_GENDER': 'Gender',
    'FLAG_OWN_CAR': 'Car Ownership',
    'FLAG_OWN_REALTY': 'Property Ownership',
    'CNT_CHILDREN': 'Number of Children',
    'AMT_INCOME_TOTAL': 'Annual Income',
    'NAME_INCOME_TYPE': 'Income Category',
    'NAME_EDUCATION_TYPE': 'Education Level',
    'NAME_FAMILY_STATUS': 'Marital Status',
    'NAME_HOUSING_TYPE': 'Housing Type',
    'DAYS_BIRTH': 'Age (Days)',
    'DAYS_EMPLOYED': 'Employment Duration (Days)',
    'FLAG_MOBIL': 'Mobile Phone',
    'FLAG_WORK_PHONE': 'Work Phone',
    'FLAG_PHONE': 'Phone',
    'FLAG_EMAIL': 'Email',
    'OCCUPATION_TYPE': 'Occupation',
    'CNT_FAM_MEMBERS': 'Family Size',
    'STATUS': 'Credit Status'
}

# Rename the columns in the DataFrame using the mapping
df.rename(columns=feature_mapping, inplace=True)

# Display the first few rows to confirm the changes
df.head()
df.rename(columns={'TARGET': 'Approved'}, inplace=True)
# Display the first few rows to confirm the change
df.head()

Datasets loaded and merged successfully.


Approval status merged successfully.
Preprocessing completed successfully.


Unnamed: 0,Gender,Car Ownership,Property Ownership,Number of Children,Annual Income,Income Category,Education Level,Marital Status,Housing Type,Age (Days),Employment Duration (Days),Mobile Phone,Work Phone,Phone,Email,Occupation,Family Size,Approved
0,M,Y,Y,0,427500.0,Working,Higher education,Civil marriage,Rented apartment,32,4542.0,1.0,1.0,0.0,0.0,,2.0,1
1,M,Y,Y,0,427500.0,Working,Higher education,Civil marriage,Rented apartment,32,4542.0,1.0,1.0,0.0,0.0,,2.0,1
2,M,Y,Y,0,112500.0,Working,Secondary / secondary special,Married,House / apartment,58,1134.0,1.0,0.0,0.0,0.0,Security staff,2.0,1
3,F,N,Y,0,270000.0,Commercial associate,Secondary / secondary special,Single / not married,House / apartment,52,3051.0,1.0,0.0,1.0,1.0,Sales staff,1.0,1
4,F,N,Y,0,270000.0,Commercial associate,Secondary / secondary special,Single / not married,House / apartment,52,3051.0,1.0,0.0,1.0,1.0,Sales staff,1.0,1


In [4]:
explanation_df = pd.read_csv('/opt/notebooks/Chatbot-Credit-Card/backend/dataset/explanations_df.csv')
# explanation_df.drop(['Approved','Explanation'], axis=1, inplace=True)
explanation_df.head()

Unnamed: 0,Prediction,Actual,Explanation
0,Approved,Approved,This application was approved due to Employmen...
1,Approved,Approved,This application was approved due to Employmen...
2,Approved,Approved,"This application was approved due to -0.77, Ma..."
3,Approved,Approved,"This application was approved due to Email, -0..."
4,Denied,Approved,"This application was denied due to -1.62, Emai..."


In [5]:
print(explanation_df['Explanation'].iloc[0])

This application was approved due to Employment Duration (Days), Housing Type, -0.77, Marital Status, Work Phone, -0.15, Number of Children, -0.81, Family Size, Annual Income.


In [6]:
# Function to remove numbers and clean the Explanation column
def clean_explanation(text):
    return re.sub(r'[-+]?\d*\.\d+|\d+', '', text).replace(', ,', ',').replace(' ,', ',').strip(", ")

# Apply the cleaning function to the Explanation column
explanation_df["Explanation"] = explanation_df["Explanation"].apply(clean_explanation)

In [7]:
# Display the DataFrame with the cleaned Explanation column
print(explanation_df['Explanation'].iloc[0])
explanation_df.head()

This application was approved due to Employment Duration (Days), Housing Type, Marital Status, Work Phone, Number of Children, Family Size, Annual Income.


Unnamed: 0,Prediction,Actual,Explanation
0,Approved,Approved,This application was approved due to Employmen...
1,Approved,Approved,This application was approved due to Employmen...
2,Approved,Approved,"This application was approved due to, Marital ..."
3,Approved,Approved,"This application was approved due to Email, Nu..."
4,Denied,Approved,"This application was denied due to, Email, Hou..."


In [8]:
# df.rename(columns={'Age (Days)': 'Age (Years)'}, inplace=True)
df.rename(columns={'Employment Duration' : 'Employment Duration (Days)'}, inplace=True)
df['Age (Days)'] = df['Age (Days)'].astype(str) + " years old"

In [9]:
# Combine explanation_df and df based on index
df['Reason'] = explanation_df['Explanation'].reset_index(drop=True)

# Display the first value in the 'Reason' column
print(df['Reason'].iloc[0])
df.head()

This application was approved due to Employment Duration (Days), Housing Type, Marital Status, Work Phone, Number of Children, Family Size, Annual Income.


Unnamed: 0,Gender,Car Ownership,Property Ownership,Number of Children,Annual Income,Income Category,Education Level,Marital Status,Housing Type,Age (Days),Employment Duration (Days),Mobile Phone,Work Phone,Phone,Email,Occupation,Family Size,Approved,Reason
0,M,Y,Y,0,427500.0,Working,Higher education,Civil marriage,Rented apartment,32 years old,4542.0,1.0,1.0,0.0,0.0,,2.0,1,This application was approved due to Employmen...
1,M,Y,Y,0,427500.0,Working,Higher education,Civil marriage,Rented apartment,32 years old,4542.0,1.0,1.0,0.0,0.0,,2.0,1,This application was approved due to Employmen...
2,M,Y,Y,0,112500.0,Working,Secondary / secondary special,Married,House / apartment,58 years old,1134.0,1.0,0.0,0.0,0.0,Security staff,2.0,1,"This application was approved due to, Marital ..."
3,F,N,Y,0,270000.0,Commercial associate,Secondary / secondary special,Single / not married,House / apartment,52 years old,3051.0,1.0,0.0,1.0,1.0,Sales staff,1.0,1,"This application was approved due to Email, Nu..."
4,F,N,Y,0,270000.0,Commercial associate,Secondary / secondary special,Single / not married,House / apartment,52 years old,3051.0,1.0,0.0,1.0,1.0,Sales staff,1.0,1,"This application was denied due to, Email, Hou..."


In [10]:
# Function to dynamically replace labels with row values
def insert_numbers_dynamically(row):
    reason = row['Reason']
    
    # Extract all parts of the reason where dynamic replacement is needed
    labels = re.findall(r'([A-Za-z\s()]+)', reason)
    
    for label in labels:
        # Match the label to the corresponding column name
        column_name = None
        for col in df.columns:
            # Normalize column names and labels for matching
            normalized_col = re.sub(r'[\s()]+', '', col).lower()
            normalized_label = re.sub(r'[\s()]+', '', label).lower()
            
            if normalized_col == normalized_label:
                column_name = col
                break
        
        if column_name and column_name in row:  # Ensure column exists in the DataFrame
            # Replace the label with the corresponding value from the row
            value = row[column_name]
            # Ensure proper replacement in the text
            reason = reason.replace(label, f"{label.strip()} {value}")
    
    return reason

In [11]:
# Apply the function to each row in the DataFrame
df['Reason'] = df.apply(insert_numbers_dynamically, axis=1)

# Display the updated DataFrame
df.head()

Unnamed: 0,Gender,Car Ownership,Property Ownership,Number of Children,Annual Income,Income Category,Education Level,Marital Status,Housing Type,Age (Days),Employment Duration (Days),Mobile Phone,Work Phone,Phone,Email,Occupation,Family Size,Approved,Reason
0,M,Y,Y,0,427500.0,Working,Higher education,Civil marriage,Rented apartment,32 years old,4542.0,1.0,1.0,0.0,0.0,,2.0,1,This application was approved due to Employmen...
1,M,Y,Y,0,427500.0,Working,Higher education,Civil marriage,Rented apartment,32 years old,4542.0,1.0,1.0,0.0,0.0,,2.0,1,This application was approved due to Employmen...
2,M,Y,Y,0,112500.0,Working,Secondary / secondary special,Married,House / apartment,58 years old,1134.0,1.0,0.0,0.0,0.0,Security staff,2.0,1,"This application was approved due to,Marital S..."
3,F,N,Y,0,270000.0,Commercial associate,Secondary / secondary special,Single / not married,House / apartment,52 years old,3051.0,1.0,0.0,1.0,1.0,Sales staff,1.0,1,"This application was approved due to Email,Num..."
4,F,N,Y,0,270000.0,Commercial associate,Secondary / secondary special,Single / not married,House / apartment,52 years old,3051.0,1.0,0.0,1.0,1.0,Sales staff,1.0,1,"This application was denied due to,Email 1.0,H..."


In [12]:
# Convert "Employment Duration (Days)" to years in-place
df["Employment Duration (Days)"] = df["Employment Duration (Days)"].apply(lambda x: round(x / 365.0, 2) if not pd.isna(x) else np.nan)
df.rename(columns={"Age (Days)": "Age (Years)", "Employment Duration (Days)": "Employment Duration (Years)"}, inplace=True)
df.head()

Unnamed: 0,Gender,Car Ownership,Property Ownership,Number of Children,Annual Income,Income Category,Education Level,Marital Status,Housing Type,Age (Years),Employment Duration (Years),Mobile Phone,Work Phone,Phone,Email,Occupation,Family Size,Approved,Reason
0,M,Y,Y,0,427500.0,Working,Higher education,Civil marriage,Rented apartment,32 years old,12.44,1.0,1.0,0.0,0.0,,2.0,1,This application was approved due to Employmen...
1,M,Y,Y,0,427500.0,Working,Higher education,Civil marriage,Rented apartment,32 years old,12.44,1.0,1.0,0.0,0.0,,2.0,1,This application was approved due to Employmen...
2,M,Y,Y,0,112500.0,Working,Secondary / secondary special,Married,House / apartment,58 years old,3.11,1.0,0.0,0.0,0.0,Security staff,2.0,1,"This application was approved due to,Marital S..."
3,F,N,Y,0,270000.0,Commercial associate,Secondary / secondary special,Single / not married,House / apartment,52 years old,8.36,1.0,0.0,1.0,1.0,Sales staff,1.0,1,"This application was approved due to Email,Num..."
4,F,N,Y,0,270000.0,Commercial associate,Secondary / secondary special,Single / not married,House / apartment,52 years old,8.36,1.0,0.0,1.0,1.0,Sales staff,1.0,1,"This application was denied due to,Email 1.0,H..."


In [13]:
# Step 2: Update the "Reason" column to match the new labels
def update_reason(reason_text):
    reason_text = reason_text.replace("Age (Days)", "Age (Years)")
    reason_text = reason_text.replace("Employment Duration (Days)", "Employment Duration (Years)")
    return reason_text

df["Reason"] = df["Reason"].apply(update_reason)

In [14]:
# Display the first value in the 'Reason' column
print(df['Reason'].iloc[0])

This application was approved due to Employment Duration (Years),Housing Type Rented apartment,Marital Status Civil marriage,Work Phone 1.0,Number of Children 0,Family Size 2.0,Annual Income 427500.0.


In [15]:
def generate_plain_input_and_label(row):
    # Generate the reasoning part of the label
    reasoning_parts = [
        f"Gender: {'Male' if row['Gender'] == 'M' else 'Female'}" if pd.notna(row.get('Gender')) else None,
        f"Age: {row['Age (Years)']} years old" if pd.notna(row.get('Age (Years)')) else None,
        f"Car Ownership: {'Yes' if row['Car Ownership'] == 'Y' else 'No'}" if pd.notna(row.get('Car Ownership')) else None,
        f"Property Ownership: {'Yes' if row['Property Ownership'] == 'Y' else 'No'}" if pd.notna(row.get('Property Ownership')) else None,
        f"Number of Children: {int(row['Number of Children'])}" if pd.notna(row.get('Number of Children')) else None,
        f"Annual Income: {row['Annual Income']}" if pd.notna(row.get('Annual Income')) else None,
        f"Income Category: {row['Income Category']}" if pd.notna(row.get('Income Category')) else None,
        f"Education Level: {row['Education Level']}" if pd.notna(row.get('Education Level')) else None,
        f"Marital Status: {row['Marital Status']}" if pd.notna(row.get('Marital Status')) else None,
        f"Housing Type: {row['Housing Type']}" if pd.notna(row.get('Housing Type')) else None,
        f"Employment Duration: {row['Employment Duration (Years)']} years" if pd.notna(row.get('Employment Duration (Years)')) else None,
        f"Mobile Phone: {'Yes' if row['Mobile Phone'] == 1 else 'No'}" if pd.notna(row.get('Mobile Phone')) else None,
        f"Work Phone: {'Yes' if row['Work Phone'] == 1 else 'No'}" if pd.notna(row.get('Work Phone')) else None,
        f"Email: {'Yes' if row['Email'] == 1 else 'No'}" if pd.notna(row.get('Email')) else None,
        f"Family Size: {row['Family Size']}" if pd.notna(row.get('Family Size')) else None,
    ]
    reasoning_text = ". ".join([part for part in reasoning_parts if part is not None])
    label = f"Approval Status: Approved. Reasoning: {reasoning_text}."

    # Generate input in the same way
    input_text = reasoning_text  # Reuse the reasoning parts for consistency

    return input_text, label

In [16]:
# Apply the function to each row
df["text"], df["label"] = zip(*df.apply(generate_plain_input_and_label, axis=1))

# Keep only the required columns
final_df = df[["text", "label"]]

In [17]:
# Display the first value in the 'Reason' column
print("Label\n")
print(final_df['label'].iloc[0])
print("Input\n")
print(final_df['text'].iloc[0])

Label

Approval Status: Approved. Reasoning: Gender: Male. Age: 32 years old years old. Car Ownership: Yes. Property Ownership: Yes. Number of Children: 0. Annual Income: 427500.0. Income Category: Working. Education Level: Higher education. Marital Status: Civil marriage. Housing Type: Rented apartment. Employment Duration: 12.44 years. Mobile Phone: Yes. Work Phone: Yes. Email: No. Family Size: 2.0.
Input

Gender: Male. Age: 32 years old years old. Car Ownership: Yes. Property Ownership: Yes. Number of Children: 0. Annual Income: 427500.0. Income Category: Working. Education Level: Higher education. Marital Status: Civil marriage. Housing Type: Rented apartment. Employment Duration: 12.44 years. Mobile Phone: Yes. Work Phone: Yes. Email: No. Family Size: 2.0


In [18]:
# Clean the 'text' and 'label' columns
final_df['text'] = final_df['text'].str.replace('years old years old', 'years old', regex=False)

# Sample the dataset to 3000 rows for training efficiency
df = df.sample(frac=1, random_state=42).reset_index(drop=True).head(3000)

In [19]:
# Split ratios
train_size = 0.8
eval_size = 0.1

# Calculate split indices
train_end = int(train_size * len(df))
eval_end = train_end + int(eval_size * len(df))

# Split data
X_train = df[:train_end]
X_eval = df[train_end:eval_end]
X_test = df[eval_end:]

In [20]:
# Define the prompt generation functions
def generate_prompt(data_point):
    return f"""
Predict the credit card application status and provide reasoning.
text: {data_point["text"]}
label: {data_point["label"]}
""".strip()

def generate_test_prompt(data_point):
    return f"""
Predict the credit card application status and provide reasoning.
text: {data_point["text"]}
label: 
""".strip()

# Apply to training and evaluation datasets
X_train['text'] = X_train.apply(generate_prompt, axis=1)
X_eval['text'] = X_eval.apply(generate_prompt, axis=1)

# Apply to test dataset
X_test_prompts = X_test.apply(generate_test_prompt, axis=1)
y_true = X_test['label']
X_test = pd.DataFrame(X_test_prompts, columns=["text"])

In [21]:
# Convert to datasets
train_data = Dataset.from_pandas(X_train[["text"]])
eval_data = Dataset.from_pandas(X_eval[["text"]])

In [22]:
# Display lengths of each entry in train_data
train_lengths = train_data.map(lambda x: {"length": len(x["text"])})
print("Train data lengths:")
print(train_lengths)

# Display lengths of each entry in eval_data
eval_lengths = eval_data.map(lambda x: {"length": len(x["text"])})
print("Eval data lengths:")
print(eval_lengths)

Map:   0%|          | 0/2400 [00:00<?, ? examples/s]

Train data lengths:
Dataset({
    features: ['text', 'length'],
    num_rows: 2400
})


Map:   0%|          | 0/300 [00:00<?, ? examples/s]

Eval data lengths:
Dataset({
    features: ['text', 'length'],
    num_rows: 300
})


In [23]:
train_data['text'][3]

'Predict the credit card application status and provide reasoning.\ntext: Gender: Female. Age: 32 years old years old. Car Ownership: Yes. Property Ownership: No. Number of Children: 0. Annual Income: 126000.0. Income Category: State servant. Education Level: Incomplete higher. Marital Status: Single / not married. Housing Type: House / apartment. Employment Duration: 14.26 years. Mobile Phone: Yes. Work Phone: Yes. Email: No. Family Size: 1.0\nlabel: Approval Status: Approved. Reasoning: Gender: Female. Age: 32 years old years old. Car Ownership: Yes. Property Ownership: No. Number of Children: 0. Annual Income: 126000.0. Income Category: State servant. Education Level: Incomplete higher. Marital Status: Single / not married. Housing Type: House / apartment. Employment Duration: 14.26 years. Mobile Phone: Yes. Work Phone: Yes. Email: No. Family Size: 1.0.'

In [24]:
base_model = "/opt/notebooks/Chatbot-Credit-Card/backend/models/llama-3.1-8b-Instruct/"

tokenizer = AutoTokenizer.from_pretrained(base_model)

In [25]:
base_model_name = base_model

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype="float16",
)

model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto",
    torch_dtype="float16",
    quantization_config=bnb_config, 
)

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.float16,
    device_map="auto",
)
model.config.use_cache = False
model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained(base_model_name)

tokenizer.pad_token_id = tokenizer.eos_token_id

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [26]:
from tqdm import tqdm
from transformers import pipeline

def predict(test, model, tokenizer):
    y_pred = []
    statuses = ["Approved", "Denied"]  # The statuses for prediction
    
    for i in tqdm(range(len(test))):
        prompt = test.iloc[i]["text"]
        
        # Initialize the text-generation pipeline
        pipe = pipeline(
            task="text-generation", 
            model=model, 
            tokenizer=tokenizer, 
            max_new_tokens=50,  # Increase to ensure reasoning is generated
            temperature=0.1
        )
        
        # Generate prediction
        result = pipe(prompt)
        generated_text = result[0]['generated_text']
        
        # Parse the output to extract status and reason
        try:
            # Extract 'label:' portion and split to get the prediction
            answer = generated_text.split("label:")[-1].strip()
            
            # Extract status (Approved/Denied)
            status = None
            for s in statuses:
                if s.lower() in answer.lower():
                    status = s
                    break
            
            # Append the status and reasoning
            if status:
                y_pred.append({
                    "status": status,
                    "reason": answer  # Include full reasoning text
                })
            else:
                y_pred.append({
                    "status": "Unknown",
                    "reason": answer
                })
        except Exception as e:
            # Handle unexpected output format
            y_pred.append({
                "status": "Error",
                "reason": f"Parsing failed: {e}"
            })
    
    return y_pred

# Example Usage
# Assuming `X_test` is your test DataFrame and `model`, `tokenizer` are already loaded
y_pred = predict(X_test, model, tokenizer)

# Print a few predictions to verify
for prediction in y_pred[:5]:
    print(prediction)


  0%|                                                                                                                                                          | 0/300 [00:00<?, ?it/s]

  0%|▍                                                                                                                                                 | 1/300 [00:05<28:56,  5.81s/it]

  1%|▉                                                                                                                                                 | 2/300 [00:11<27:16,  5.49s/it]

  1%|█▍                                                                                                                                                | 3/300 [00:16<27:18,  5.52s/it]

  1%|█▉                                                                                                                                                | 4/300 [00:22<27:47,  5.63s/it]

  2%|██▍                                                                                                                                               | 5/300 [00:28<27:43,  5.64s/it]

  2%|██▉                                                                                                                                               | 6/300 [00:33<27:38,  5.64s/it]

  2%|███▍                                                                                                                                              | 7/300 [00:39<27:20,  5.60s/it]

  3%|███▉                                                                                                                                              | 8/300 [00:45<27:32,  5.66s/it]

  3%|████▍                                                                                                                                             | 9/300 [00:50<27:22,  5.64s/it]

  3%|████▊                                                                                                                                            | 10/300 [00:56<27:07,  5.61s/it]

  4%|█████▎                                                                                                                                           | 11/300 [01:01<26:49,  5.57s/it]

  4%|█████▊                                                                                                                                           | 12/300 [01:07<26:40,  5.56s/it]

  4%|██████▎                                                                                                                                          | 13/300 [01:12<26:22,  5.51s/it]

  5%|██████▊                                                                                                                                          | 14/300 [01:18<26:39,  5.59s/it]

  5%|███████▎                                                                                                                                         | 15/300 [01:24<26:49,  5.65s/it]

  5%|███████▋                                                                                                                                         | 16/300 [01:29<26:39,  5.63s/it]

  6%|████████▏                                                                                                                                        | 17/300 [01:35<26:33,  5.63s/it]

  6%|████████▋                                                                                                                                        | 18/300 [01:40<26:25,  5.62s/it]

  6%|█████████▏                                                                                                                                       | 19/300 [01:46<26:32,  5.67s/it]

  7%|█████████▋                                                                                                                                       | 20/300 [01:52<26:21,  5.65s/it]

  7%|██████████▏                                                                                                                                      | 21/300 [01:57<26:12,  5.63s/it]

  7%|██████████▋                                                                                                                                      | 22/300 [02:03<25:59,  5.61s/it]

  8%|███████████                                                                                                                                      | 23/300 [02:08<25:34,  5.54s/it]

  8%|███████████▌                                                                                                                                     | 24/300 [02:14<25:41,  5.58s/it]

  8%|████████████                                                                                                                                     | 25/300 [02:21<27:21,  5.97s/it]

  9%|████████████▌                                                                                                                                    | 26/300 [02:28<28:13,  6.18s/it]

  9%|█████████████                                                                                                                                    | 27/300 [02:33<27:32,  6.05s/it]

  9%|█████████████▌                                                                                                                                   | 28/300 [02:39<26:30,  5.85s/it]

 10%|██████████████                                                                                                                                   | 29/300 [02:44<25:47,  5.71s/it]

 10%|██████████████▌                                                                                                                                  | 30/300 [02:50<25:19,  5.63s/it]

 10%|██████████████▉                                                                                                                                  | 31/300 [02:55<25:33,  5.70s/it]

 11%|███████████████▍                                                                                                                                 | 32/300 [03:01<25:18,  5.67s/it]

 11%|███████████████▉                                                                                                                                 | 33/300 [03:07<25:21,  5.70s/it]

 11%|████████████████▍                                                                                                                                | 34/300 [03:12<25:05,  5.66s/it]

 12%|████████████████▉                                                                                                                                | 35/300 [03:18<24:53,  5.64s/it]

 12%|█████████████████▍                                                                                                                               | 36/300 [03:24<25:00,  5.68s/it]

 12%|█████████████████▉                                                                                                                               | 37/300 [03:29<24:53,  5.68s/it]

 13%|██████████████████▎                                                                                                                              | 38/300 [03:35<24:55,  5.71s/it]

 13%|██████████████████▊                                                                                                                              | 39/300 [03:41<24:35,  5.65s/it]

 13%|███████████████████▎                                                                                                                             | 40/300 [03:46<24:26,  5.64s/it]

 14%|███████████████████▊                                                                                                                             | 41/300 [03:52<24:23,  5.65s/it]

 14%|████████████████████▎                                                                                                                            | 42/300 [03:58<24:18,  5.65s/it]

 14%|████████████████████▊                                                                                                                            | 43/300 [04:04<24:56,  5.82s/it]

 15%|█████████████████████▎                                                                                                                           | 44/300 [04:10<25:41,  6.02s/it]

 15%|█████████████████████▊                                                                                                                           | 45/300 [04:17<26:12,  6.17s/it]

 15%|██████████████████████▏                                                                                                                          | 46/300 [04:23<26:07,  6.17s/it]

 16%|██████████████████████▋                                                                                                                          | 47/300 [04:29<25:22,  6.02s/it]

 16%|███████████████████████▏                                                                                                                         | 48/300 [04:34<24:49,  5.91s/it]

 16%|███████████████████████▋                                                                                                                         | 49/300 [04:40<24:22,  5.83s/it]

 17%|████████████████████████▏                                                                                                                        | 50/300 [04:46<24:04,  5.78s/it]

 17%|████████████████████████▋                                                                                                                        | 51/300 [04:51<23:44,  5.72s/it]

 17%|█████████████████████████▏                                                                                                                       | 52/300 [04:57<23:27,  5.67s/it]

 18%|█████████████████████████▌                                                                                                                       | 53/300 [05:02<23:21,  5.68s/it]

 18%|██████████████████████████                                                                                                                       | 54/300 [05:08<23:21,  5.70s/it]

 18%|██████████████████████████▌                                                                                                                      | 55/300 [05:14<23:00,  5.64s/it]

 19%|███████████████████████████                                                                                                                      | 56/300 [05:19<22:59,  5.65s/it]

 19%|███████████████████████████▌                                                                                                                     | 57/300 [05:25<22:57,  5.67s/it]

 19%|████████████████████████████                                                                                                                     | 58/300 [05:31<22:37,  5.61s/it]

 20%|████████████████████████████▌                                                                                                                    | 59/300 [05:36<22:34,  5.62s/it]

 20%|█████████████████████████████                                                                                                                    | 60/300 [05:42<22:28,  5.62s/it]

 20%|█████████████████████████████▍                                                                                                                   | 61/300 [05:48<22:58,  5.77s/it]

 21%|█████████████████████████████▉                                                                                                                   | 62/300 [05:53<22:27,  5.66s/it]

 21%|██████████████████████████████▍                                                                                                                  | 63/300 [05:59<22:19,  5.65s/it]

 21%|██████████████████████████████▉                                                                                                                  | 64/300 [06:05<22:24,  5.70s/it]

 22%|███████████████████████████████▍                                                                                                                 | 65/300 [06:10<21:52,  5.58s/it]

 22%|███████████████████████████████▉                                                                                                                 | 66/300 [06:16<21:33,  5.53s/it]

 22%|████████████████████████████████▍                                                                                                                | 67/300 [06:21<21:19,  5.49s/it]

 23%|████████████████████████████████▊                                                                                                                | 68/300 [06:26<21:02,  5.44s/it]

 23%|█████████████████████████████████▎                                                                                                               | 69/300 [06:31<20:42,  5.38s/it]

 23%|█████████████████████████████████▊                                                                                                               | 70/300 [06:37<21:01,  5.48s/it]

 24%|██████████████████████████████████▎                                                                                                              | 71/300 [06:43<20:58,  5.49s/it]

 24%|██████████████████████████████████▊                                                                                                              | 72/300 [06:48<20:52,  5.49s/it]

 24%|███████████████████████████████████▎                                                                                                             | 73/300 [06:54<20:51,  5.51s/it]

 25%|███████████████████████████████████▊                                                                                                             | 74/300 [06:59<20:46,  5.52s/it]

 25%|████████████████████████████████████▎                                                                                                            | 75/300 [07:05<20:50,  5.56s/it]

 25%|████████████████████████████████████▋                                                                                                            | 76/300 [07:10<20:31,  5.50s/it]

 26%|█████████████████████████████████████▏                                                                                                           | 77/300 [07:16<20:25,  5.49s/it]

 26%|█████████████████████████████████████▋                                                                                                           | 78/300 [07:22<20:32,  5.55s/it]

 26%|██████████████████████████████████████▏                                                                                                          | 79/300 [07:27<20:21,  5.53s/it]

 27%|██████████████████████████████████████▋                                                                                                          | 80/300 [07:32<19:59,  5.45s/it]

 27%|███████████████████████████████████████▏                                                                                                         | 81/300 [07:38<19:57,  5.47s/it]

 27%|███████████████████████████████████████▋                                                                                                         | 82/300 [07:43<19:28,  5.36s/it]

 28%|████████████████████████████████████████                                                                                                         | 83/300 [07:48<19:12,  5.31s/it]

 28%|████████████████████████████████████████▌                                                                                                        | 84/300 [07:53<19:05,  5.30s/it]

 28%|█████████████████████████████████████████                                                                                                        | 85/300 [07:59<18:56,  5.28s/it]

 29%|█████████████████████████████████████████▌                                                                                                       | 86/300 [08:04<18:54,  5.30s/it]

 29%|██████████████████████████████████████████                                                                                                       | 87/300 [08:09<19:03,  5.37s/it]

 29%|██████████████████████████████████████████▌                                                                                                      | 88/300 [08:15<18:49,  5.33s/it]

 30%|███████████████████████████████████████████                                                                                                      | 89/300 [08:20<18:48,  5.35s/it]

 30%|███████████████████████████████████████████▌                                                                                                     | 90/300 [08:26<18:58,  5.42s/it]

 30%|███████████████████████████████████████████▉                                                                                                     | 91/300 [08:31<18:47,  5.40s/it]

 31%|████████████████████████████████████████████▍                                                                                                    | 92/300 [08:36<18:44,  5.40s/it]

 31%|████████████████████████████████████████████▉                                                                                                    | 93/300 [08:42<18:42,  5.42s/it]

 31%|█████████████████████████████████████████████▍                                                                                                   | 94/300 [08:48<18:49,  5.48s/it]

 32%|█████████████████████████████████████████████▉                                                                                                   | 95/300 [08:53<19:03,  5.58s/it]

 32%|██████████████████████████████████████████████▍                                                                                                  | 96/300 [09:00<19:46,  5.82s/it]

 32%|██████████████████████████████████████████████▉                                                                                                  | 97/300 [09:06<20:02,  5.93s/it]

 33%|███████████████████████████████████████████████▎                                                                                                 | 98/300 [09:12<20:20,  6.04s/it]

 33%|███████████████████████████████████████████████▊                                                                                                 | 99/300 [09:18<20:29,  6.12s/it]

 33%|████████████████████████████████████████████████                                                                                                | 100/300 [09:25<20:39,  6.20s/it]

 34%|████████████████████████████████████████████████▍                                                                                               | 101/300 [09:31<20:32,  6.19s/it]

 34%|████████████████████████████████████████████████▉                                                                                               | 102/300 [09:37<20:39,  6.26s/it]

 34%|█████████████████████████████████████████████████▍                                                                                              | 103/300 [09:44<20:54,  6.37s/it]

 35%|█████████████████████████████████████████████████▉                                                                                              | 104/300 [09:51<21:03,  6.45s/it]

 35%|██████████████████████████████████████████████████▍                                                                                             | 105/300 [09:57<21:12,  6.53s/it]

 35%|██████████████████████████████████████████████████▉                                                                                             | 106/300 [10:04<21:19,  6.60s/it]

 36%|███████████████████████████████████████████████████▎                                                                                            | 107/300 [10:11<21:21,  6.64s/it]

 36%|███████████████████████████████████████████████████▊                                                                                            | 108/300 [10:17<21:06,  6.59s/it]

 36%|████████████████████████████████████████████████████▎                                                                                           | 109/300 [10:25<21:34,  6.78s/it]

 37%|████████████████████████████████████████████████████▊                                                                                           | 110/300 [10:31<21:25,  6.77s/it]

 37%|█████████████████████████████████████████████████████▎                                                                                          | 111/300 [10:38<21:25,  6.80s/it]

 37%|█████████████████████████████████████████████████████▊                                                                                          | 112/300 [10:45<21:12,  6.77s/it]

 38%|██████████████████████████████████████████████████████▏                                                                                         | 113/300 [10:52<20:57,  6.72s/it]

 38%|██████████████████████████████████████████████████████▋                                                                                         | 114/300 [10:58<20:39,  6.67s/it]

 38%|███████████████████████████████████████████████████████▏                                                                                        | 115/300 [11:04<20:10,  6.54s/it]

 39%|███████████████████████████████████████████████████████▋                                                                                        | 116/300 [11:11<20:00,  6.52s/it]

 39%|████████████████████████████████████████████████████████▏                                                                                       | 117/300 [11:18<20:09,  6.61s/it]

 39%|████████████████████████████████████████████████████████▋                                                                                       | 118/300 [11:24<20:12,  6.66s/it]

 40%|█████████████████████████████████████████████████████████                                                                                       | 119/300 [11:31<20:14,  6.71s/it]

 40%|█████████████████████████████████████████████████████████▌                                                                                      | 120/300 [11:38<20:25,  6.81s/it]

 40%|██████████████████████████████████████████████████████████                                                                                      | 121/300 [11:45<20:21,  6.82s/it]

 41%|██████████████████████████████████████████████████████████▌                                                                                     | 122/300 [11:52<20:15,  6.83s/it]

 41%|███████████████████████████████████████████████████████████                                                                                     | 123/300 [11:59<20:09,  6.84s/it]

 41%|███████████████████████████████████████████████████████████▌                                                                                    | 124/300 [12:06<20:17,  6.92s/it]

 42%|████████████████████████████████████████████████████████████                                                                                    | 125/300 [12:13<20:23,  6.99s/it]

 42%|████████████████████████████████████████████████████████████▍                                                                                   | 126/300 [12:20<20:35,  7.10s/it]

 42%|████████████████████████████████████████████████████████████▉                                                                                   | 127/300 [12:28<20:27,  7.10s/it]

 43%|█████████████████████████████████████████████████████████████▍                                                                                  | 128/300 [12:34<20:11,  7.04s/it]

 43%|█████████████████████████████████████████████████████████████▉                                                                                  | 129/300 [12:41<19:59,  7.02s/it]

 43%|██████████████████████████████████████████████████████████████▍                                                                                 | 130/300 [12:48<19:45,  6.97s/it]

 44%|██████████████████████████████████████████████████████████████▉                                                                                 | 131/300 [12:55<19:22,  6.88s/it]

 44%|███████████████████████████████████████████████████████████████▎                                                                                | 132/300 [13:00<18:08,  6.48s/it]

 44%|███████████████████████████████████████████████████████████████▊                                                                                | 133/300 [13:06<17:08,  6.16s/it]

 45%|████████████████████████████████████████████████████████████████▎                                                                               | 134/300 [13:11<16:28,  5.95s/it]

 45%|████████████████████████████████████████████████████████████████▊                                                                               | 135/300 [13:17<15:50,  5.76s/it]

 45%|█████████████████████████████████████████████████████████████████▎                                                                              | 136/300 [13:22<15:30,  5.67s/it]

 46%|█████████████████████████████████████████████████████████████████▊                                                                              | 137/300 [13:28<15:13,  5.60s/it]

 46%|██████████████████████████████████████████████████████████████████▏                                                                             | 138/300 [13:33<15:05,  5.59s/it]

 46%|██████████████████████████████████████████████████████████████████▋                                                                             | 139/300 [13:39<15:00,  5.59s/it]

 47%|███████████████████████████████████████████████████████████████████▏                                                                            | 140/300 [13:44<14:58,  5.62s/it]

 47%|███████████████████████████████████████████████████████████████████▋                                                                            | 141/300 [13:51<15:30,  5.85s/it]

 47%|████████████████████████████████████████████████████████████████████▏                                                                           | 142/300 [13:57<15:28,  5.88s/it]

 48%|████████████████████████████████████████████████████████████████████▋                                                                           | 143/300 [14:03<15:18,  5.85s/it]

 48%|█████████████████████████████████████████████████████████████████████                                                                           | 144/300 [14:08<14:59,  5.77s/it]

 48%|█████████████████████████████████████████████████████████████████████▌                                                                          | 145/300 [14:14<14:54,  5.77s/it]

 49%|██████████████████████████████████████████████████████████████████████                                                                          | 146/300 [14:20<14:48,  5.77s/it]

 49%|██████████████████████████████████████████████████████████████████████▌                                                                         | 147/300 [14:26<15:11,  5.96s/it]

 49%|███████████████████████████████████████████████████████████████████████                                                                         | 148/300 [14:33<15:30,  6.12s/it]

 50%|███████████████████████████████████████████████████████████████████████▌                                                                        | 149/300 [14:39<15:29,  6.16s/it]

 50%|████████████████████████████████████████████████████████████████████████                                                                        | 150/300 [14:44<14:51,  5.95s/it]

 50%|████████████████████████████████████████████████████████████████████████▍                                                                       | 151/300 [14:50<14:28,  5.83s/it]

 51%|████████████████████████████████████████████████████████████████████████▉                                                                       | 152/300 [14:55<14:13,  5.76s/it]

 51%|█████████████████████████████████████████████████████████████████████████▍                                                                      | 153/300 [15:01<14:18,  5.84s/it]

 51%|█████████████████████████████████████████████████████████████████████████▉                                                                      | 154/300 [15:07<14:17,  5.88s/it]

 52%|██████████████████████████████████████████████████████████████████████████▍                                                                     | 155/300 [15:13<14:13,  5.88s/it]

 52%|██████████████████████████████████████████████████████████████████████████▉                                                                     | 156/300 [15:19<13:48,  5.75s/it]

 52%|███████████████████████████████████████████████████████████████████████████▎                                                                    | 157/300 [15:25<13:53,  5.83s/it]

 53%|███████████████████████████████████████████████████████████████████████████▊                                                                    | 158/300 [15:31<13:47,  5.83s/it]

 53%|████████████████████████████████████████████████████████████████████████████▎                                                                   | 159/300 [15:36<13:26,  5.72s/it]

 53%|████████████████████████████████████████████████████████████████████████████▊                                                                   | 160/300 [15:42<13:30,  5.79s/it]

 54%|█████████████████████████████████████████████████████████████████████████████▎                                                                  | 161/300 [15:48<13:13,  5.71s/it]

 54%|█████████████████████████████████████████████████████████████████████████████▊                                                                  | 162/300 [15:53<13:11,  5.74s/it]

 54%|██████████████████████████████████████████████████████████████████████████████▏                                                                 | 163/300 [15:59<12:57,  5.68s/it]

 55%|██████████████████████████████████████████████████████████████████████████████▋                                                                 | 164/300 [16:05<13:05,  5.78s/it]

 55%|███████████████████████████████████████████████████████████████████████████████▏                                                                | 165/300 [16:11<13:18,  5.92s/it]

 55%|███████████████████████████████████████████████████████████████████████████████▋                                                                | 166/300 [16:17<13:14,  5.93s/it]

 56%|████████████████████████████████████████████████████████████████████████████████▏                                                               | 167/300 [16:23<13:01,  5.87s/it]

 56%|████████████████████████████████████████████████████████████████████████████████▋                                                               | 168/300 [16:28<12:37,  5.74s/it]

 56%|█████████████████████████████████████████████████████████████████████████████████                                                               | 169/300 [16:34<12:40,  5.81s/it]

 57%|█████████████████████████████████████████████████████████████████████████████████▌                                                              | 170/300 [16:40<12:26,  5.74s/it]

 57%|██████████████████████████████████████████████████████████████████████████████████                                                              | 171/300 [16:45<12:15,  5.70s/it]

 57%|██████████████████████████████████████████████████████████████████████████████████▌                                                             | 172/300 [16:51<11:58,  5.61s/it]

 58%|███████████████████████████████████████████████████████████████████████████████████                                                             | 173/300 [16:56<11:39,  5.51s/it]

 58%|███████████████████████████████████████████████████████████████████████████████████▌                                                            | 174/300 [17:02<11:33,  5.50s/it]

 58%|████████████████████████████████████████████████████████████████████████████████████                                                            | 175/300 [17:07<11:21,  5.45s/it]

 59%|████████████████████████████████████████████████████████████████████████████████████▍                                                           | 176/300 [17:12<11:21,  5.50s/it]

 59%|████████████████████████████████████████████████████████████████████████████████████▉                                                           | 177/300 [17:18<11:07,  5.42s/it]

 59%|█████████████████████████████████████████████████████████████████████████████████████▍                                                          | 178/300 [17:23<11:11,  5.51s/it]

 60%|█████████████████████████████████████████████████████████████████████████████████████▉                                                          | 179/300 [17:29<11:12,  5.56s/it]

 60%|██████████████████████████████████████████████████████████████████████████████████████▍                                                         | 180/300 [17:35<11:10,  5.58s/it]

 60%|██████████████████████████████████████████████████████████████████████████████████████▉                                                         | 181/300 [17:41<11:09,  5.63s/it]

 61%|███████████████████████████████████████████████████████████████████████████████████████▎                                                        | 182/300 [17:46<11:02,  5.62s/it]

 61%|███████████████████████████████████████████████████████████████████████████████████████▊                                                        | 183/300 [17:52<11:18,  5.80s/it]

 61%|████████████████████████████████████████████████████████████████████████████████████████▎                                                       | 184/300 [17:59<11:34,  5.98s/it]

 62%|████████████████████████████████████████████████████████████████████████████████████████▊                                                       | 185/300 [18:05<11:27,  5.98s/it]

 62%|█████████████████████████████████████████████████████████████████████████████████████████▎                                                      | 186/300 [18:10<11:15,  5.92s/it]

 62%|█████████████████████████████████████████████████████████████████████████████████████████▊                                                      | 187/300 [18:16<10:57,  5.82s/it]

 63%|██████████████████████████████████████████████████████████████████████████████████████████▏                                                     | 188/300 [18:22<11:06,  5.95s/it]

 63%|██████████████████████████████████████████████████████████████████████████████████████████▋                                                     | 189/300 [18:28<10:49,  5.85s/it]

 63%|███████████████████████████████████████████████████████████████████████████████████████████▏                                                    | 190/300 [18:33<10:31,  5.74s/it]

 64%|███████████████████████████████████████████████████████████████████████████████████████████▋                                                    | 191/300 [18:39<10:16,  5.65s/it]

 64%|████████████████████████████████████████████████████████████████████████████████████████████▏                                                   | 192/300 [18:44<10:06,  5.61s/it]

 64%|████████████████████████████████████████████████████████████████████████████████████████████▋                                                   | 193/300 [18:50<09:55,  5.56s/it]

 65%|█████████████████████████████████████████████████████████████████████████████████████████████                                                   | 194/300 [18:55<09:52,  5.59s/it]

 65%|█████████████████████████████████████████████████████████████████████████████████████████████▌                                                  | 195/300 [19:01<09:48,  5.60s/it]

 65%|██████████████████████████████████████████████████████████████████████████████████████████████                                                  | 196/300 [19:07<09:45,  5.63s/it]

 66%|██████████████████████████████████████████████████████████████████████████████████████████████▌                                                 | 197/300 [19:12<09:37,  5.61s/it]

 66%|███████████████████████████████████████████████████████████████████████████████████████████████                                                 | 198/300 [19:18<09:28,  5.57s/it]

 66%|███████████████████████████████████████████████████████████████████████████████████████████████▌                                                | 199/300 [19:23<09:20,  5.55s/it]

 67%|████████████████████████████████████████████████████████████████████████████████████████████████                                                | 200/300 [19:29<09:11,  5.52s/it]

 67%|████████████████████████████████████████████████████████████████████████████████████████████████▍                                               | 201/300 [19:34<09:06,  5.52s/it]

 67%|████████████████████████████████████████████████████████████████████████████████████████████████▉                                               | 202/300 [19:40<08:59,  5.50s/it]

 68%|█████████████████████████████████████████████████████████████████████████████████████████████████▍                                              | 203/300 [19:45<08:54,  5.51s/it]

 68%|█████████████████████████████████████████████████████████████████████████████████████████████████▉                                              | 204/300 [19:51<08:41,  5.43s/it]

 68%|██████████████████████████████████████████████████████████████████████████████████████████████████▍                                             | 205/300 [19:56<08:36,  5.44s/it]

 69%|██████████████████████████████████████████████████████████████████████████████████████████████████▉                                             | 206/300 [20:02<08:34,  5.47s/it]

 69%|███████████████████████████████████████████████████████████████████████████████████████████████████▎                                            | 207/300 [20:07<08:28,  5.46s/it]

 69%|███████████████████████████████████████████████████████████████████████████████████████████████████▊                                            | 208/300 [20:13<08:28,  5.53s/it]

 70%|████████████████████████████████████████████████████████████████████████████████████████████████████▎                                           | 209/300 [20:18<08:26,  5.57s/it]

 70%|████████████████████████████████████████████████████████████████████████████████████████████████████▊                                           | 210/300 [20:24<08:21,  5.58s/it]

 70%|█████████████████████████████████████████████████████████████████████████████████████████████████████▎                                          | 211/300 [20:30<08:20,  5.62s/it]

 71%|█████████████████████████████████████████████████████████████████████████████████████████████████████▊                                          | 212/300 [20:35<08:17,  5.65s/it]

 71%|██████████████████████████████████████████████████████████████████████████████████████████████████████▏                                         | 213/300 [20:41<08:09,  5.63s/it]

In [None]:
from collections import Counter

# Extract statuses from y_pred
statuses = [pred.get("status", "Unknown") for pred in y_pred]

# Count occurrences of each status
status_counts = Counter(statuses)

# Display the counts
print(f"Total counts by status:")
print(f"Approved: {status_counts['Approved']}")
print(f"Denied: {status_counts['Denied']}")
print(f"Unknown: {status_counts['Unknown']}")

In [None]:
def extract_label_from_text(data_point):
    try:
        text = data_point["text"]
        if "Approval Status:" in text:
            label_section = text.split("Approval Status:")[1].strip()
            return label_section.split(".")[0].strip()
        else:
            print(f"Missing 'Approval Status:' in text: {text}")
            return "Unknown"
    except Exception as e:
        print(f"Error parsing text: {text}, Error: {e}")
        return "Unknown"

# Add the 'status' column to train and evaluation datasets
train_data = train_data.map(lambda x: {"status": extract_label_from_text(x)})
eval_data = eval_data.map(lambda x: {"status": extract_label_from_text(x)})


In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import numpy as np

def evaluate(y_true, y_pred):
    # Extract statuses from predictions
    y_pred_statuses = []
    for pred in y_pred:
        if "status" in pred and pred["status"] in ["Approved", "Denied", "Unknown"]:
            y_pred_statuses.append(pred["status"])
        else:
            y_pred_statuses.append("Unknown")  # Default for malformed predictions

    # Define label mapping (including "Unknown")
    labels = ["Approved", "Denied", "Unknown"]
    mapping = {label: idx for idx, label in enumerate(labels)}

    def map_func(x):
        return mapping.get(x, -1)  # Map to -1 if not found

    # Ensure lengths of y_true and y_pred match
    if len(y_true) != len(y_pred_statuses):
        raise ValueError("Length mismatch between y_true and y_pred.")

    # Map true and predicted labels
    y_true_mapped = np.vectorize(map_func)(y_true)
    y_pred_mapped = np.vectorize(map_func)(y_pred_statuses)

    # Check for unmapped values
    if -1 in y_true_mapped:
        raise ValueError("Unmapped labels found in y_true. Ensure all labels are 'Approved', 'Denied', or 'Unknown'.")
    if -1 in y_pred_mapped:
        raise ValueError("Unmapped labels found in y_pred. Ensure all predictions are 'Approved', 'Denied', or 'Unknown'.")

    # Calculate accuracy
    accuracy = accuracy_score(y_true=y_true_mapped, y_pred=y_pred_mapped)
    print(f'Accuracy: {accuracy:.3f}')
    
    # Generate classification report
    class_report = classification_report(
        y_true=y_true_mapped, 
        y_pred=y_pred_mapped, 
        target_names=labels, 
        labels=list(range(len(labels)))
    )
    print('\nClassification Report:')
    print(class_report)
    
    # Generate confusion matrix
    conf_matrix = confusion_matrix(
        y_true=y_true_mapped, 
        y_pred=y_pred_mapped, 
        labels=list(range(len(labels)))
    )
    print('\nConfusion Matrix:')
    print(conf_matrix)


# Assertions to ensure valid inputs
assert all(label in ["Approved", "Denied", "Unknown"] for label in y_true), "y_true contains unexpected labels!"
assert all("status" in pred for pred in y_pred), "Some predictions are missing the 'status' key!"
assert len(y_true) == len(y_pred), "Mismatch in lengths of y_true and y_pred!"

# Example usage with y_true and y_pred
y_true = X_test["status"].tolist()
evaluate(y_true, y_pred)

In [None]:
import bitsandbytes as bnb

def find_all_linear_names(model):
    cls = bnb.nn.Linear4bit
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if 'lm_head' in lora_module_names:  # needed for 16 bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)
modules = find_all_linear_names(model)
modules

In [None]:
output_dir="/opt/notebooks/Chatbot-Credit-Card/backend/models/llama-3.1-fine-tuned-model"

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=modules,
)

training_arguments = TrainingArguments(
    output_dir=output_dir,                    # directory to save and repository id
    num_train_epochs=1,                       # number of training epochs
    per_device_train_batch_size=1,            # batch size per device during training
    gradient_accumulation_steps=8,            # number of steps before performing a backward/update pass
    gradient_checkpointing=True,              # use gradient checkpointing to save memory
    optim="paged_adamw_32bit",
    logging_steps=1,                         
    learning_rate=2e-4,                       # learning rate, based on QLoRA paper
    weight_decay=0.001,
    fp16=True,
    bf16=False,
    max_grad_norm=0.3,                        # max gradient norm based on QLoRA paper
    max_steps=-1,
    warmup_ratio=0.03,                        # warmup ratio based on QLoRA paper
    group_by_length=False,
    lr_scheduler_type="cosine",               # use cosine learning rate scheduler
    report_to="wandb",                  # report metrics to w&b
    eval_strategy="steps",              # save checkpoint every epoch
    eval_steps = 0.2
)

trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=train_data,
    eval_dataset=eval_data,
    peft_config=peft_config,
    dataset_text_field="text",
    tokenizer=tokenizer,
    max_seq_length=512,
    packing=False,
    dataset_kwargs={
    "add_special_tokens": False,
    "append_concat_token": False,
    }
)

In [None]:
trainer.train()

In [None]:
wandb.finish()
model.config.use_cache = True

In [None]:
# Save trained model and tokenizer
trainer.save_model(output_dir)
tokenizer.save_pretrained(output_dir)

In [None]:
y_pred = predict(X_test, model, tokenizer)
evaluate(y_true, y_pred)