<a href="https://colab.research.google.com/github/ShawneilRodrigues/Heuristic_coder_mumbaiHacks/blob/master/Xgboost_%26_Neural_Network_with_Pytorch_Lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.
import kagglehub
# Fetch the dataset through kagglehub and keep the value it returns
# (presumably the local download directory - confirm against kagglehub docs).
sgpjesus_bank_account_fraud_dataset_neurips_2022_path = kagglehub.dataset_download('sgpjesus/bank-account-fraud-dataset-neurips-2022')

print('Data source import complete.')


# 1. Exploratory Data Analysis

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

!pip install mlflow

In [None]:
from pathlib import Path

# Read Base.csv from the directory returned by kagglehub in the first cell
# instead of a hard-coded "/kaggle/input/..." path, so the notebook also runs
# outside Kaggle kernels (e.g. on Colab, which the header badge links to).
base_csv_path = Path(sgpjesus_bank_account_fraud_dataset_neurips_2022_path) / "Base.csv"
df = pd.read_csv(base_csv_path)
df.shape

## 1.1. Class For EDA Plotting

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

class EdaPlotter:
    """Matplotlib helpers used for the exploratory analysis of a DataFrame."""

    def __init__(self) -> None:
        pass

    def plot_skewness(self, df):
        """Draw a bar chart with the skewness of every numerical column.

        Args:
            df: DataFrame whose numeric columns are analysed.

        Returns:
            The skewness values computed by ``DataFrame.skew``.
        """
        # Filter numerical features in the DataFrame
        numerical_features = df.select_dtypes(include=["number"])

        # Calculate skewness of each numerical feature
        skew_values = numerical_features.skew()

        # Create a plot of skewness values
        plt.figure(figsize=(10, 5))
        skew_values.plot(kind="bar")
        plt.title("Skewness of Numerical Features")
        plt.xlabel("Features")
        plt.ylabel("Skewness Value")
        plt.axhline(y=0, color="r", linestyle="-")  # reference line at zero skew
        plt.xticks(rotation=45)
        plt.grid(True)
        plt.show()

        return skew_values

    def plot_numerical_features(self, dataframe):
        """Plot a 30-bin histogram for every int64/float64 column, 4 per row.

        Args:
            dataframe: DataFrame to plot; NaN values are dropped per column.

        Nothing is drawn when the frame has no int64/float64 column.
        """
        grid_width = 4
        feature_names = list(
            dataframe.select_dtypes(include=["int64", "float64"]).columns
        )
        if not feature_names:
            # plt.subplots cannot build a grid with zero rows.
            return

        num_rows = -(-len(feature_names) // grid_width)  # ceiling division

        # squeeze=False keeps `axes` two-dimensional even with a single row,
        # so flattening/indexing works for any number of features.
        fig, axes = plt.subplots(
            num_rows, grid_width, figsize=(20, 4 * num_rows), squeeze=False
        )
        flat_axes = axes.flatten()

        for ax, feature in zip(flat_axes, feature_names):
            ax.hist(dataframe[feature].dropna(), bins=30, edgecolor="black")
            ax.set_title(f"Distribution of {feature}")
            ax.set_xlabel(feature)
            ax.set_ylabel("Frequency")

        # Remove the unused axes of the last row
        for unused_ax in flat_axes[len(feature_names):]:
            fig.delaxes(unused_ax)

        fig.tight_layout()
        plt.show()

    def plot_categorical_features(self, df):
        """Plot a bar chart of value counts for every object/category column.

        Args:
            df: DataFrame to plot; charts are laid out two per row.

        Nothing is drawn when the frame has no object/category column.
        """
        grid_width = 2
        feature_names = list(
            df.select_dtypes(include=["object", "category"]).columns
        )
        if not feature_names:
            # plt.subplots cannot build a grid with zero rows.
            return

        cat_rows = -(-len(feature_names) // grid_width)  # ceiling division

        # squeeze=False keeps `axes` two-dimensional even with a single row.
        fig, axes = plt.subplots(
            cat_rows, grid_width, figsize=(15, 4 * cat_rows), squeeze=False
        )
        flat_axes = axes.flatten()

        for ax, feature in zip(flat_axes, feature_names):
            value_counts = df[feature].value_counts()
            ax.bar(
                value_counts.index,
                value_counts.values,
                color="skyblue",
                edgecolor="black",
            )
            ax.set_title(f"Countplot of {feature}")
            ax.set_xlabel(feature)
            ax.set_ylabel("Count")
            ax.tick_params(
                axis="x", rotation=45
            )  # Rotate x-axis labels for better readability if necessary

        # Hide the empty subplot left over when the number of features is odd
        for unused_ax in flat_axes[len(feature_names):]:
            unused_ax.axis("off")

        plt.tight_layout()
        plt.show()

    def plot_missing_values_proportion(
        self, df: pd.DataFrame, cols_missing_neg1: list[str]
    ):
        """Plot the percentage of missing values per feature.

        Args:
            df: DataFrame to inspect. NOTE: it is modified in place - the
                value -1 is replaced by NaN in ``cols_missing_neg1``.
            cols_missing_neg1: Columns in which -1 encodes a missing value.

        Only features with at least one missing value are plotted; when there
        is none, a message is printed and no figure is created.
        """
        # Replace -1 with NaN in the specified columns (in place on purpose:
        # the caller's frame keeps the NaN markers afterwards)
        df[cols_missing_neg1] = df[cols_missing_neg1].replace(-1, np.nan)

        # Calculate the percentage of missing values by feature
        null_X = df.isna().sum() / len(df) * 100
        missing_only = null_X.loc[null_X > 0].sort_values()

        if missing_only.empty:
            # A bar plot of an empty Series cannot be drawn.
            print("No missing values found.")
            return

        # Plot the missing values
        fig, ax = plt.subplots(figsize=(8, 6))
        ax = missing_only.plot(
            kind="bar", title="Percentage of Missing Values", ax=ax
        )

        # Annotate the bars with the percentage of missing values
        for p in ax.patches:
            ax.annotate(
                f"{p.get_height():.2f}%",
                (p.get_x() + p.get_width() / 2.0, p.get_height()),
                ha="center",
                va="bottom",
                xytext=(0, 5),
                textcoords="offset points",
                color="red",
            )

        ax.set_ylabel("Missing %")
        ax.set_xlabel("Feature")

        # Remove gridlines from the x-axis
        ax.xaxis.grid(False)

        plt.show()

## 1.2. Quick overview of the dataframe structure

In [None]:
# Print the summary (columns, dtypes, non-null counts) of the raw dataframe.
df.info()

## 1.3. Looking For missing values in the Features

In [None]:
# Single plotter instance reused by all the EDA cells below.
eda_plotter = EdaPlotter()

In [None]:
# Columns in which the value -1 is treated as "missing".
cols_missing = [
    'prev_address_months_count',
    'current_address_months_count',
    'bank_months_count',
    'session_length_in_minutes',
    'device_distinct_emails_8w',
    'intended_balcon_amount'
]

# NOTE: this call replaces -1 with NaN in `df` itself (in place), so every
# later cell that reads `df` sees NaN in these columns.
eda_plotter.plot_missing_values_proportion(df, cols_missing)

## 1.4. Proportion of the Labels

We observe that we are dealing with a very unbalanced dataset, which was to be expected given that this is a fraud detection problem. Therefore, it will be necessary to use weights.

In [None]:
# Get the value counts (normalize=True returns proportions instead of counts)
fraud_counts = df["fraud_bool"].value_counts(normalize=True)

# Mapping 0 to "No Fraud" and 1 to "Fraud"
fraud_labels = {0: "No Fraud", 1: "Fraud"}
fraud_counts.index = fraud_counts.index.map(fraud_labels)

# Plotting the proportions with custom labels
# (colors are assigned by bar position, i.e. in the order of fraud_counts)
plt.figure(figsize=(8, 6))
plt.bar(fraud_counts.index, fraud_counts.values, color=['green', 'red'], edgecolor='black')
plt.title("Proportion of Fraud vs Non-Fraud Cases")
plt.xlabel("Fraud Status")
plt.ylabel("Proportion")
plt.ylim(0, 1)
plt.show()

## 1.5. Plot the numerical features - Distribution

In [None]:
# Histograms of all int64/float64 columns of the raw dataframe.
eda_plotter.plot_numerical_features(df)

## 1.6. Plot the categorical features - Distribution

In [None]:
# Value-count bar charts of all object/category columns of the raw dataframe.
eda_plotter.plot_categorical_features(df)

## 1.7. Features Correlation

High correlation between features can cause multicollinearity, which can make it difficult for machine learning models to learn effectively. By identifying and removing these features, the model can be simpler and potentially more robust.

We will remove the column "velocity_4w"

In [None]:
def identify_highly_correlated_features(dataframe : pd.DataFrame, threshold=0.80):
    """Find pairs of numerical columns whose absolute correlation is high.

    Args:
        dataframe: Input data; only int64/float64 columns are considered.
        threshold: A pair is reported when |correlation| is strictly greater
            than this value.

    Returns:
        A tuple ``(corr_matrix, pairs)`` where ``corr_matrix`` is the
        correlation matrix of the numerical columns and ``pairs`` is a list of
        ``(column_a, column_b)`` tuples, each unordered pair listed once.
    """
    numeric_frame = dataframe.select_dtypes(include=['int64', 'float64'])
    corr_matrix = numeric_frame.corr()

    # Boolean mask of the entries above the threshold; keeping only the strict
    # upper triangle drops the diagonal and the mirrored duplicates.
    above_threshold = (corr_matrix.abs() > threshold).to_numpy()
    row_positions, col_positions = np.nonzero(np.triu(above_threshold, k=1))

    high_corr_var = [
        (corr_matrix.index[r], corr_matrix.columns[c])
        for r, c in zip(row_positions, col_positions)
    ]

    return corr_matrix, high_corr_var

# Correlation matrix of the raw dataframe plus the pairs with |corr| > 0.80.
corr_matrix, high_corr_features = identify_highly_correlated_features(df, threshold=0.80)
print(high_corr_features)

In [None]:
# Visualization of the correlation matrix (for numerical features only)
# The color scale is pinned to [-1, 1], the full range of a correlation value.
plt.figure(figsize=(12, 8))
sns.heatmap(corr_matrix, annot=True, fmt=".2f", cmap='coolwarm', vmin=-1, vmax=1)
plt.title('Correlation Matrix of Numerical Features')
plt.show()

## 1.8. Skewness

In [None]:
# Bar chart of the skewness of every numerical column of the raw dataframe.
eda_plotter.plot_skewness(df)

# 2. Utils

In [None]:
import pandas as pd
from sklearn.preprocessing import StandardScaler, LabelEncoder


def preprocess_with_labelencoder(df: pd.DataFrame, col_label: str):
    """Label-encode the categorical features and standardize the numerical ones.

    The label column ``col_label`` is left untouched. The input frame is not
    modified: the transformation is applied to a copy, which is returned.

    Args:
        df: Input data containing features and the label column.
        col_label: Name of the label column to exclude from preprocessing.

    Returns:
        A tuple ``(processed_df, label_encoders, scaler)`` where
        ``label_encoders`` maps each categorical column name to its fitted
        ``LabelEncoder`` and ``scaler`` is the ``StandardScaler`` fitted on the
        numerical features (left unfitted when there are none).

    NOTE(review): encoders and scaler are fitted on whatever frame is passed
    in; fitting them before a train/test split lets test statistics influence
    the training features.
    """
    # Work on a copy so the caller's DataFrame is never mutated.
    df = df.copy()

    # Identify categorical and numerical features, excluding the label.
    # Both lists are computed before encoding, so the encoded categorical
    # columns are not passed through the scaler.
    categorical_features = [
        feature
        for feature in df.select_dtypes(include=["object", "category"]).columns
        if feature != col_label
    ]
    numerical_features = [
        feature
        for feature in df.select_dtypes(include=["number"]).columns
        if feature != col_label
    ]

    label_encoders = {}
    scaler = StandardScaler()

    # Encode categorical features using one LabelEncoder per column
    for col in categorical_features:
        encoder = LabelEncoder()
        df[col] = encoder.fit_transform(df[col])
        label_encoders[col] = encoder

    # Scale numerical features (skipped when there is nothing to scale,
    # because the scaler cannot be fitted on zero columns)
    if numerical_features:
        df[numerical_features] = scaler.fit_transform(df[numerical_features])

    return df, label_encoders, scaler

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve, roc_curve, auc

class PerformancePlotter:
    """Plots of binary-classification performance from true labels and scores.

    Every ``plot_*`` method accepts an optional matplotlib ``ax``. When it is
    omitted the method creates its own figure and shows it; when it is given
    the method only draws on it and leaves displaying to the caller.
    """

    def __init__(self):
        pass

    def plot_auc_curve(self, y_true, y_probs, ax=None):
        """Draw the ROC curve and report its area in the legend.

        Args:
            y_true: Ground-truth binary labels.
            y_probs: Predicted scores/probabilities for the positive class.
            ax: Optional axes to draw on.
        """
        fpr, tpr, _ = roc_curve(y_true, y_probs)
        roc_auc = auc(fpr, tpr)
        should_display = ax is None
        if should_display:
            _, ax = plt.subplots(figsize=(8, 6))
        ax.plot(
            fpr,
            tpr,
            color="darkorange",
            lw=2,
            label=f"ROC curve (area = {roc_auc:.2f})",
        )
        # Diagonal = performance of a random classifier
        ax.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
        ax.set_xlim([0.0, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.set_title("Receiver Operating Characteristic")
        ax.legend(loc="lower right")
        if should_display:
            plt.show()

    def plot_precision_recall_curve(self, y_true, y_probs, ax=None):
        """Draw precision as a function of recall.

        Args:
            y_true: Ground-truth binary labels.
            y_probs: Predicted scores/probabilities for the positive class.
            ax: Optional axes to draw on.
        """
        precisions, recalls, _ = precision_recall_curve(y_true, y_probs)
        should_display = ax is None
        if should_display:
            _, ax = plt.subplots(figsize=(8, 6))
        ax.plot(recalls, precisions, label="Precision-Recall Curve")
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        ax.legend(loc="best")
        ax.set_title("Precision-Recall Curve")
        if should_display:
            plt.show()

    def plot_precision_recall_f1_vs_threshold(self, y_true, y_probs, ax=None):
        """Draw precision, recall and F1-score against the decision threshold.

        Args:
            y_true: Ground-truth binary labels.
            y_probs: Predicted scores/probabilities for the positive class.
            ax: Optional axes to draw on.
        """
        precisions, recalls, thresholds = precision_recall_curve(y_true, y_probs)

        # F1 = 2PR / (P + R). Where P + R == 0 the numerator is 0 as well, so
        # dividing by 1 there yields F1 = 0 instead of NaN and a
        # division-by-zero warning.
        denominator = precisions + recalls
        safe_denominator = denominator.copy()
        safe_denominator[denominator == 0] = 1.0
        f1_scores = 2 * (precisions * recalls) / safe_denominator

        should_display = ax is None
        if should_display:
            _, ax = plt.subplots(figsize=(8, 6))
        # The metric arrays are sliced with [:-1] to align them with
        # `thresholds` for plotting.
        ax.plot(thresholds, f1_scores[:-1], "r-", label="F1-score")
        ax.plot(thresholds, precisions[:-1], "b--", label="Precision")
        ax.plot(thresholds, recalls[:-1], "g-", label="Recall")
        ax.set_xlabel("Threshold")
        ax.set_ylabel("Precision/Recall")
        ax.legend(loc="best")
        ax.set_title("Precision and Recall vs. Threshold")
        ax.grid(True)
        if should_display:
            plt.show()

    def plot_metrics(self, y_true, y_probs):
        """Show the three plots above side by side in a single figure.

        Args:
            y_true: Ground-truth binary labels.
            y_probs: Predicted scores/probabilities for the positive class.
        """
        fig, axs = plt.subplots(1, 3, figsize=(18, 6))
        self.plot_auc_curve(y_true, y_probs, ax=axs[0])
        self.plot_precision_recall_curve(y_true, y_probs, ax=axs[1])
        self.plot_precision_recall_f1_vs_threshold(y_true, y_probs, ax=axs[2])
        plt.tight_layout()
        plt.show()

# 3. XgBoost Model

In [None]:
# Column-name constants shared by the XGBoost and neural-network sections.
COL_DF_LABEL_FRAUD = "fraud_bool"  # label column
COL_BANK_MONTHS_COUNT = "bank_months_count"
COL_PREV_ADDRESS_MONTHS_COUNT = "prev_address_months_count"
COL_VELOCITY_4W = "velocity_4w"

# Work on a copy so the dataframe used for the EDA stays available.
df_xgboost = df.copy()

## 3.1. Remove the features that bring bias

From the Exploratory Data Analysis notebook, we identified features that do not contribute to improving the model, so we decided to remove these columns.

In [None]:
# Drop the three columns selected during the EDA.
df_xgboost = df_xgboost.drop(columns=[
    COL_BANK_MONTHS_COUNT,
    COL_PREV_ADDRESS_MONTHS_COUNT,
    COL_VELOCITY_4W
    ]
)

## 3.2. Remove the empty rows

As we observed during the EDA, there is very little missing data. Although XGBoost can handle missing values, it is simpler to remove them as a starting point.

In [None]:
# Remaining columns in which -1 is treated as "missing"
# (this rebinds the global `cols_missing` defined in the EDA section).
cols_missing = [
    'current_address_months_count',
    'session_length_in_minutes',
    'device_distinct_emails_8w',
    'intended_balcon_amount'
]

# Turn the -1 markers into NaN so that dropna() can remove those rows.
df_xgboost[cols_missing] = df_xgboost[cols_missing].replace(-1, np.nan)

# Drop every row that has at least one NaN in any column.
df_xgboost= df_xgboost.dropna()
df_xgboost.shape

## 3.3. Preprocessing

In [None]:
# Label-encode the categorical features and standardize the numerical ones.
# The fitted encoders and scaler are kept in `labelenocder` / `scaler`.
df_preprocessed, labelenocder, scaler = preprocess_with_labelencoder(df =df_xgboost, col_label=COL_DF_LABEL_FRAUD)

## 3.4. Training and Testing

In [None]:
# Separate the feature matrix from the label vector.
X = df_preprocessed.drop(columns=COL_DF_LABEL_FRAUD, axis=1)
y = df_preprocessed[COL_DF_LABEL_FRAUD]

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, roc_curve, accuracy_score
from xgboost import XGBClassifier

# Cast any remaining object/category column to the pandas 'category' dtype.
categorical_features = X.select_dtypes(include=['object', 'category']).columns
X[categorical_features] = X[categorical_features].astype('category')

# Train-Test Split (70/30, stratified on the label to keep the class ratio)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42, stratify=y, shuffle=True)

# Calculate the scale_pos_weight parameter
# (ratio of negative to positive samples in the training split)
negative_class_count = len(y_train[y_train == 0])
positive_class_count = len(y_train[y_train == 1])
scale_pos_weight = negative_class_count / positive_class_count


# Train an XGBoost model
model = XGBClassifier(
    # use_label_encoder=True,
    # enable_categorical=True
    eval_metric='logloss',
    scale_pos_weight= scale_pos_weight,
)

model.fit(X_train, y_train)

# Predict on the test set
y_pred = model.predict(X_test)
# Column 1 of predict_proba is kept as the positive-class score.
y_proba = model.predict_proba(X_test)[:, 1]

# Metrics
# (thresholded metrics use y_pred, the ROC AUC uses the scores in y_proba)
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
roc_auc = roc_auc_score(y_test, y_proba)

# Print the results
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1-Score: {f1:.4f}")
print(f"AUC-ROC: {roc_auc:.4f}")

## 3.5. Plots

In [None]:
# ROC, precision-recall and threshold plots for the XGBoost test predictions.
plotter = PerformancePlotter()
plotter.plot_metrics(y_test, y_proba)

## 3.6. SHAP

In fraud detection, you may want to prioritize recall (capturing as many fraud cases as possible), even if it comes at the expense of lower precision. The business cost of missing fraud might be higher than the cost of false positives.

In [None]:
import shap

# SHAP Values
# Create a SHAP explainer for the trained XGBoost model
explainer = shap.TreeExplainer(model)

# Calculate SHAP values for the test set
shap_values = explainer.shap_values(X_test)

# Ensure base_values are included in the Explanation object
# NOTE(review): `data` is passed as a DataFrame here; confirm that the shap
# plots used below accept it (a plain array via X_test.values is the common form).
shap_explanation = shap.Explanation(values=shap_values,
    base_values=explainer.expected_value,
    data=X_test,
    feature_names=X_test.columns
)

In [None]:
# Visualize global feature importance with SHAP summary plot
shap.summary_plot(shap_values, X_test, plot_type="bar")

In [None]:
# Identify the true positives (where predicted class is 1 and actual class is 1)
# np.where returns positions within the test set, which is how
# shap_explanation is indexed below.
true_positives = np.where((y_pred == 1) & (y_test == 1))[0]

# First true positive case
if len(true_positives) > 0:
    i_pos = true_positives[0]
    shap.plots.waterfall(shap_explanation[i_pos], max_display=20)
    plt.show()

In [None]:
# Identify the true negatives (where predicted class is 0 and actual class is 0)
# np.where returns positions within the test set, which is how
# shap_explanation is indexed below.
true_negatives = np.where((y_pred == 0) & (y_test == 0))[0]

# First true negative case
if len(true_negatives) > 0:
    i_neg = true_negatives[0]
    shap.plots.waterfall(shap_explanation[i_neg], max_display=20)
    plt.show()

In [None]:
# Beeswarm plot of the SHAP values, limited to 20 displayed features.
shap.plots.beeswarm(shap_explanation, max_display=20)

# 4. Neural Network with Pytorch Lightning

## 4.1. Classifier

In [None]:
from torch import nn

class FraudDetectionModel(nn.Module):
    """Fully connected binary classifier: input -> 128 -> 64 -> 1 logit.

    Each hidden layer is followed by batch normalization, ReLU and dropout
    (p=0.5). The output layer has no activation, so ``forward`` returns raw
    logits of shape ``(batch, 1)``.
    """

    def __init__(self, input_dim : int):
        """Build the layers.

        Args:
            input_dim: Number of input features per sample.
        """
        super().__init__()
        # First hidden block
        self.fc1 = nn.Linear(in_features=input_dim, out_features=128)
        self.bn1 = nn.BatchNorm1d(num_features=128)
        # Second hidden block
        self.fc2 = nn.Linear(in_features=128, out_features=64)
        self.bn2 = nn.BatchNorm1d(num_features=64)
        # Output layer producing a single logit
        self.fc3 = nn.Linear(in_features=64, out_features=1)
        # Stateless layers shared by both hidden blocks
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Return the logits for a batch of feature vectors."""
        hidden = self.fc1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)

        hidden = self.fc2(hidden)
        hidden = self.bn2(hidden)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)

        # No activation here: the loss is expected to work on logits.
        return self.fc3(hidden)


## 4.2. DataModule

In [None]:
import os
import pandas as pd
import torch
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningDataModule

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

class DataFrameDataset(Dataset):
    """Torch dataset that serves the rows of a pandas DataFrame.

    All columns except the label column become the float32 feature tensor;
    the label column becomes a float32 label tensor.
    """

    def __init__(self, dataframe: pd.DataFrame, label_column: str):
        """
        Args:
            dataframe (pd.DataFrame): Input data in pandas DataFrame format.
            label_column (str): Name of the column to be used as the labels.
        """
        feature_array = dataframe.drop(columns=label_column).values
        label_array = dataframe[label_column].values

        # Both tensors are materialized once, up front, as float32.
        self.features = torch.tensor(feature_array, dtype=torch.float32)
        self.labels = torch.tensor(label_array, dtype=torch.float32)

    def __len__(self):
        """Return the number of samples."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        sample = self.features[idx]
        target = self.labels[idx]
        return sample, target


class DataModule(LightningDataModule):
    """LightningDataModule serving pre-split train/val/test DataFrames.

    The frames are wrapped in ``DataFrameDataset`` objects in ``setup`` using
    ``COL_DF_LABEL_FRAUD`` as the label column.
    """

    batch_size: int
    random_state: int
    persistent_workers: bool
    num_workers: int

    def __init__(
        self,
        train_df: pd.DataFrame,
        val_df: pd.DataFrame,
        test_df: pd.DataFrame,
        batch_size: int = 32,
        random_sate: int = 42,
        num_workers: int | None = os.cpu_count(),
        persistent_workers: bool = True,
        prefetch_factor: int = 2,
    ):
        """
        Args:
            train_df: Training split (features + label column).
            val_df: Validation split.
            test_df: Test split.
            batch_size: Batch size used by every dataloader.
            random_sate: Seed stored as ``self.random_state`` (the misspelled
                parameter name is kept for backward compatibility).
            num_workers: Worker processes per dataloader; ``None`` is treated
                as 0 (load in the main process).
            persistent_workers: Keep workers alive between epochs. Only
                applied when ``num_workers > 0``.
            prefetch_factor: Batches prefetched per worker by the training
                dataloader. Only applied when ``num_workers > 0``.
        """
        super().__init__()

        self.train_df = train_df
        self.val_df = val_df
        self.test_df = test_df
        self.batch_size = batch_size
        # os.cpu_count() can return None; fall back to single-process loading.
        self.num_workers = num_workers or 0
        self.random_state = random_sate
        self.persistent_workers = persistent_workers
        self.prefetch_factor = prefetch_factor

        self._class_weights = None
        self.is_data_splitted = False

    def _create_tensor_dataset_from_dataframe(self, dataframe: pd.DataFrame):
        """Wrap ``dataframe`` in a DataFrameDataset keyed on the fraud label."""
        return DataFrameDataset(dataframe, COL_DF_LABEL_FRAUD)

    def _worker_kwargs(self, with_prefetch: bool = False) -> dict:
        """Return the worker-related DataLoader arguments.

        DataLoader rejects ``persistent_workers=True`` and an explicit
        ``prefetch_factor`` when ``num_workers`` is 0, so both are only
        included when worker processes are actually used.
        """
        kwargs = {"num_workers": self.num_workers}
        if self.num_workers > 0:
            kwargs["persistent_workers"] = self.persistent_workers
            if with_prefetch:
                kwargs["prefetch_factor"] = self.prefetch_factor
        return kwargs

    def setup(self, stage: str | None = None):
        """Create the datasets needed for ``stage``.

        Args:
            stage: "fit" builds train and validation datasets, "validate"
                builds the validation dataset, "test" builds the test dataset
                and ``None`` builds all of them.
        """
        if stage in ("fit", None):
            self.train_dataset = self._create_tensor_dataset_from_dataframe(
                self.train_df
            )

        if stage in ("fit", "validate", None):
            self.val_dataset = self._create_tensor_dataset_from_dataframe(self.val_df)

        if stage in ("test", None):
            self.test_dataset = self._create_tensor_dataset_from_dataframe(self.test_df)

    def train_dataloader(self) -> DataLoader:
        """Return the shuffled training dataloader."""
        return DataLoader(
            dataset=self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True,
            **self._worker_kwargs(with_prefetch=True),
        )

    def val_dataloader(self) -> DataLoader:
        """Return the validation dataloader (no shuffling)."""
        return DataLoader(
            dataset=self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            **self._worker_kwargs(),
        )

    def test_dataloader(self) -> DataLoader:
        """Return the test dataloader (no shuffling)."""
        return DataLoader(
            dataset=self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            **self._worker_kwargs(),
        )

## 4.3. Metrics

In [None]:
from torchmetrics import (
    Accuracy,
    Recall,
    Precision,
    F1Score,
    AUROC,
    MetricCollection,
)


def get_metrics(num_classes: int):
    """Build the MetricCollection used by the classifier.

    Merges the dictionaries produced by the accuracy, recall, precision,
    F1-score and AUROC factory functions into a single collection.

    Args:
        num_classes: Forwarded unchanged to every factory function.
    """
    factories = (
        get_accuracy_metrics,
        get_recall_metrics,
        get_precision_metrics,
        get_f1_score_metrics,
        get_auc_metrics,
    )

    merged = {}
    for build in factories:
        merged.update(build(num_classes=num_classes))

    return MetricCollection(merged)


def get_recall_metrics(num_classes: int):
    """Return the recall metric under the key "recall_weighted".

    NOTE(review): the metric is created with task="binary"; the `average` and
    `num_classes` arguments are presumably ignored in that mode - confirm
    against the torchmetrics documentation.
    """
    recall = Recall(task="binary", average="weighted", num_classes=num_classes)
    return {"recall_weighted": recall}


def get_precision_metrics(num_classes: int):
    """Return the precision metric under the key "precision_weighted".

    NOTE(review): the metric is created with task="binary"; the `average` and
    `num_classes` arguments are presumably ignored in that mode - confirm
    against the torchmetrics documentation.
    """
    precision = Precision(task="binary", average="weighted", num_classes=num_classes)
    return {"precision_weighted": precision}


def get_f1_score_metrics(num_classes: int):
    """Return the F1-score metric under the key "f1_score_weighted".

    NOTE(review): the metric is created with task="binary"; the `average` and
    `num_classes` arguments are presumably ignored in that mode - confirm
    against the torchmetrics documentation.
    """
    f1_score = F1Score(task="binary", average="weighted", num_classes=num_classes)
    return {"f1_score_weighted": f1_score}


def get_accuracy_metrics(num_classes: int):
    """Return the accuracy metric under the key "accuracy_weighted".

    NOTE(review): the metric is created with task="binary"; the `average` and
    `num_classes` arguments are presumably ignored in that mode - confirm
    against the torchmetrics documentation.
    """
    accuracy = Accuracy(task="binary", average="weighted", num_classes=num_classes)
    return {"accuracy_weighted": accuracy}


def get_auc_metrics(num_classes: int):
    """Return the binary AUROC metric under the key "auroc"."""
    auroc = AUROC(task="binary", num_classes=num_classes)
    return {"auroc": auroc}


## 4.4. LightningModule

In [None]:
import os
import pandas as pd
from typing import Optional, Any

import torch
from torch import nn
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer, Adam
from torchmetrics import (
    MetricCollection,
    Accuracy,
    Recall,
    Precision,
)
from pytorch_lightning import LightningModule


class LightningFraudClassifier(LightningModule):
    """LightningModule wrapping a binary fraud classifier that outputs logits.

    The wrapped model is expected to return one logit per sample. The loss is
    computed on the logits, while the metrics are updated with the sigmoid
    probabilities. Separate metric collections (prefixed ``train_``, ``val_``
    and ``test_``) are accumulated per stage and logged at the end of each
    epoch.
    """

    model: FraudDetectionModel
    metrics: MetricCollection
    criterion: _Loss

    def __init__(
        self,
        model: FraudDetectionModel,
        num_classes: int,
        *args,
        metrics: MetricCollection | None = None,
        criterion: _Loss | None = None,
        **kwargs: Any
    ):
        """
        Args:
            model: Network producing the logits.
            num_classes: Only used to build the default metrics when
                ``metrics`` is not supplied.
            metrics: Metric collection cloned for each stage; defaults to
                ``initialize_metrics(num_classes)``.
            criterion: Loss applied to ``(logits, float targets)``; defaults
                to an unweighted ``BCEWithLogitsLoss``.
            *args, **kwargs: Forwarded to ``LightningModule.__init__``.
        """
        super().__init__(*args, **kwargs)

        metrics = metrics or self.initialize_metrics(num_classes=num_classes)

        self.metrics = metrics
        self.model = model
        # A missing criterion would only fail later, at the first step, when
        # it is called; fall back to the loss matching the logit output.
        self.criterion = criterion if criterion is not None else nn.BCEWithLogitsLoss()

        # One independent copy per stage so their states never mix.
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")
        self.test_metrics = metrics.clone(prefix="test_")

    def initialize_metrics(self, num_classes: int) -> MetricCollection:
        """Build the default collection: accuracy, recall and precision."""

        metrics = MetricCollection(
            {
                "accuracy_weighted": Accuracy(
                    average="weighted", task="binary", num_classes=num_classes
                ),
                "recall_weighted": Recall(
                    average="weighted", task="binary", num_classes=num_classes
                ),
                "precision_weighted": Precision(
                    average="weighted",
                    task="binary",
                    num_classes=num_classes,
                ),
            }
        )

        return metrics

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return the logits of the wrapped model."""
        return self.model(inputs)

    def step(self, batch: tuple[torch.Tensor, torch.Tensor]):
        """Run the forward pass shared by the train/val/test steps.

        Returns:
            ``(loss, probs, targets)`` with ``probs`` the sigmoid of the
            logits and ``targets`` cast to integer type for the metrics.
        """
        inputs, targets = batch
        # Drop only the trailing singleton dimension: a bare squeeze() would
        # also remove the batch dimension of a batch holding one sample and
        # make the logits' shape differ from the targets' shape.
        logits = self.forward(inputs).squeeze(-1)
        targets = targets.long()

        loss = self.criterion(logits, targets.float())
        probs = torch.sigmoid(logits)

        return loss, probs, targets

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor]):
        """Compute and log the training loss and update the train metrics."""
        loss, probs, targets = self.step(batch=batch)
        self.log(name="train_loss", value=loss, on_step=False, on_epoch=True)
        self.train_metrics.update(probs, targets)

        return loss

    def on_train_epoch_end(self):
        """Log the accumulated train metrics and reset them."""
        metrics = self.train_metrics.compute()
        self.log_dict(metrics, on_epoch=True)
        self.train_metrics.reset()

    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor]):
        """Compute and log the validation loss and update the val metrics."""
        loss, probs, targets = self.step(batch=batch)
        self.log(name="val_loss", value=loss, on_step=False, on_epoch=True)
        self.val_metrics.update(probs, targets)
        return loss

    def on_validation_epoch_end(self):
        """Log the accumulated validation metrics and reset them."""
        metrics = self.val_metrics.compute()
        self.log_dict(metrics, on_epoch=True)
        self.val_metrics.reset()

    def test_step(self, batch: tuple[torch.Tensor, torch.Tensor]):
        """Compute and log the test loss and update the test metrics.

        Returns:
            Dict with the loss, the detached probabilities and the targets.
        """
        loss, probs, targets = self.step(batch=batch)
        self.log(name="test_loss", value=loss, on_step=False, on_epoch=True)

        self.test_metrics.update(probs, targets)

        return {
            "loss": loss,
            "probs": probs.detach(),
            "target": targets,
        }

    def on_test_epoch_end(self):
        """Log the accumulated test metrics and reset them."""
        metrics = self.test_metrics.compute()
        self.log_dict(metrics, on_epoch=True)
        self.test_metrics.reset()


    def configure_optimizers(self) -> Optimizer:
        """Return an Adam optimizer over all parameters (default settings)."""
        return Adam(self.parameters())

## 4.5. Trainer

In [None]:
from pathlib import Path
from datetime import datetime

import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import MLFlowLogger, TensorBoardLogger
from typing import List, Tuple

class TrainerManager:
    """Wires a LightningModule and a DataModule to a configured Trainer.

    Sets up an MLflow logger, early stopping on ``val_loss`` (patience 3) and
    a checkpoint callback keeping the single best model by ``val_loss``, all
    stored under ``run_datadir``.
    """

    def __init__(
        self,
        pl_model: LightningFraudClassifier,
        pl_datamodule: DataModule,
        run_datadir: str = "./model_trainer",
    ):
        """
        Args:
            pl_model: Module to train and evaluate.
            pl_datamodule: Data module providing the dataloaders.
            run_datadir: Directory receiving the "mlflow" tracking data and
                the "checkpoints" folder.
        """

        self.pl_model = pl_model
        self.pl_datamodule = pl_datamodule
        self.run_datadir = Path(run_datadir)

        self._logger = MLFlowLogger(
            experiment_name="FraudDetection",
            tracking_uri=str(self.run_datadir / "mlflow"),
            run_name=datetime.now().strftime("%Y%m%d_%H%M"),
        )

        # Kept as an attribute so it can be looked up by name rather than by
        # its position in the callback list.
        self._checkpoint_callback = ModelCheckpoint(
            dirpath=Path(self.run_datadir, "checkpoints"),
            filename="{epoch}-{val_loss:.2f}",
            monitor="val_loss",
            mode="min",
            save_top_k=1,
        )
        self._callback_list = [
            EarlyStopping(monitor="val_loss", patience=3, mode="min", verbose=True),
            self._checkpoint_callback,
        ]
        self.trainer = None

    @property
    def logger(self) -> MLFlowLogger:
        """The MLflow logger handed to the Trainer."""
        return self._logger

    @property
    def callback(self) -> List[Callback]:
        """The callbacks handed to the Trainer (early stopping, checkpoint)."""
        return self._callback_list

    @classmethod
    def set_seed(cls, seed: int = 42):
        """Seed the random number generators through ``seed_everything``."""
        seed_everything(seed=seed)

    def train(
        self, epochs: int = 10, use_gpu: bool = False
    ) -> Tuple[LightningFraudClassifier, dict]:
        """Create the Trainer and fit the model.

        Args:
            epochs: Maximum number of epochs (early stopping may end sooner).
            use_gpu: Use the "gpu" accelerator when True, "cpu" otherwise.

        Returns:
            The fitted LightningModule and the trainer's logged metrics.
        """

        self.set_seed()

        self.trainer = Trainer(
            logger=self.logger,
            callbacks=self.callback,
            max_epochs=epochs,
            num_sanity_val_steps=0,
            check_val_every_n_epoch=1,
            devices="auto",
            accelerator="gpu" if use_gpu else "cpu",
            accumulate_grad_batches=1,
        )

        self.trainer.fit(model=self.pl_model, datamodule=self.pl_datamodule)

        return self.pl_model, self.trainer.logged_metrics

    def load_best_model(self) -> LightningFraudClassifier:
        """
        Load the best model from the checkpoint.

        Returns:
            A new LightningFraudClassifier holding the checkpointed weights;
            ``self.pl_model`` is left untouched.

        Raises:
            ValueError: If no checkpoint has been saved yet.
        """
        import copy

        checkpoint_path = self._checkpoint_callback.best_model_path
        if checkpoint_path == "":
            raise ValueError("No checkpoint found. Ensure that training has been completed and a checkpoint has been saved.")

        # The classifier does not store its constructor arguments in the
        # checkpoint, so they must be supplied again. Copies of the current
        # components are used so that loading does not overwrite the live model.
        best_model = LightningFraudClassifier.load_from_checkpoint(
            checkpoint_path,
            model=copy.deepcopy(self.pl_model.model),
            num_classes=2,  # binary task; unused because `metrics` is supplied
            metrics=self.pl_model.metrics.clone(),
            criterion=copy.deepcopy(self.pl_model.criterion),
        )
        return best_model

    def test(self):
        """Evaluate the best checkpoint on the test split.

        Returns:
            The list of result dictionaries returned by ``Trainer.test``.

        Raises:
            ValueError: If ``train`` has not been called first.
        """

        if self.trainer is None:
            raise ValueError("The model has not been trained, Please call train first")

        results = self.trainer.test(datamodule=self.pl_datamodule, ckpt_path="best")

        return results

## 4.6. Splitting

In [None]:
# Same cleaning as in the XGBoost section, applied to a fresh copy of `df`.
df_nn = df.copy()

df_nn = df_nn.drop(columns=[
    COL_BANK_MONTHS_COUNT,
    COL_PREV_ADDRESS_MONTHS_COUNT,
    COL_VELOCITY_4W
    ]
)

# Columns in which -1 is treated as "missing".
cols_missing = [
    'current_address_months_count',
    'session_length_in_minutes',
    'device_distinct_emails_8w',
    'intended_balcon_amount'
]

# Convert the -1 markers to NaN, then drop every row containing a NaN.
df_nn[cols_missing] = df_nn[cols_missing].replace(-1, np.nan)
df_nn= df_nn.dropna()

# NOTE(review): the encoders and the scaler are fitted here, before the
# train/validation/test split below, so they see all rows.
df_preprocessed_nn, label_encoder, sclarer = preprocess_with_labelencoder(
    df=df_nn,
    col_label=COL_DF_LABEL_FRAUD
)

In [None]:
from sklearn.model_selection import train_test_split

# 30% of the rows are held out; that hold-out set is then split in half,
# giving 70% train / 15% test / 15% validation.
test_size = 0.30
val_size = 0.5

# First split: training set vs. hold-out set (stratified on the label)
train_df, test_df = train_test_split(
    df_preprocessed_nn,
    test_size=test_size,
    random_state=42,
    shuffle=True,
    stratify=df_preprocessed_nn[COL_DF_LABEL_FRAUD],
)

# Second split: divide the hold-out set into the test and validation dataframes
test_df, val_df = train_test_split(
    test_df,
    test_size=val_size,
    shuffle=True,
    random_state=42,
    stratify=test_df[COL_DF_LABEL_FRAUD],
)

## 4.7. Compute Weights

In [None]:
# Compute the class weights
# np.unique returns the labels in sorted order, so class_weights[i] is the
# weight of class i. Series.unique() returns them in order of first
# appearance, which would swap the weights whenever the first training row is
# a fraud case - and tensor_class_weights[1] is used later as the positive weight.
class_weights = compute_class_weight(
        class_weight="balanced",
        classes=np.unique(train_df[COL_DF_LABEL_FRAUD]),
        y=train_df[COL_DF_LABEL_FRAUD],
    )
tensor_class_weights = torch.tensor(data=class_weights, dtype=torch.float32)

## 4.8 Train

In [None]:
# Data module serving the three pre-split dataframes in batches of 128.
pl_datamodule = DataModule(
    train_df=train_df,
    val_df=val_df,
    test_df=test_df,
    batch_size=128,
    prefetch_factor=2,
    persistent_workers=True
)

# One weight per class, so the length of the weight tensor is the class count.
num_classes = tensor_class_weights.shape[0]
print(tensor_class_weights)

In [None]:
from torch.nn import BCEWithLogitsLoss

# Input size = number of columns minus the label column.
# NOTE: this rebinds the global `model`, previously the XGBoost classifier.
model = FraudDetectionModel(df_nn.shape[1]-1)
metrics= get_metrics(num_classes=num_classes)
# The entry at index 1 of the class weights is used as the positive-class weight.
criterion = BCEWithLogitsLoss(pos_weight=tensor_class_weights[1])

In [None]:
# Wrap the network, the metrics and the weighted loss in the LightningModule.
pl_model = LightningFraudClassifier(
    num_classes=num_classes,
    model=model,
    metrics=metrics,
    criterion=criterion,
)

In [None]:
# Directory that receives the MLflow tracking data and the checkpoints.
run_datadir = "./kaggle/model_trained"

trainer = TrainerManager(
    pl_datamodule=pl_datamodule,
    pl_model=pl_model,
    run_datadir=run_datadir
)


In [None]:
# Request the GPU accelerator only when CUDA is actually available, so the
# cell also runs on CPU-only runtimes instead of failing at Trainer creation.
model_trained, _ = trainer.train(epochs=22, use_gpu=torch.cuda.is_available())

## 4.9 Test

In [None]:
# Evaluate on the test split using the best checkpoint (see TrainerManager.test).
test_metrics = trainer.test()

In [None]:
# Load the test dataloader from the DataModule
test_dataloader = pl_datamodule.test_dataloader()

y_true_nn = []
y_probs_nn = []

# Switch to inference mode: without this, dropout stays active and batch
# normalization uses per-batch statistics, which distorts the predictions.
model_trained.eval()

with torch.no_grad():  # Disable gradient computation
    for inputs, targets in test_dataloader:
        # Run the batch on whichever device the model currently lives on.
        outputs = model_trained(inputs.to(model_trained.device))

        # reshape(-1) always yields a 1-D array, even for a batch holding a
        # single sample (a bare squeeze() would produce a 0-d array that
        # cannot be passed to list.extend).
        positive_probs = torch.sigmoid(outputs).reshape(-1).cpu().numpy()

        y_true_nn.extend(targets.cpu().numpy())
        y_probs_nn.extend(positive_probs)

In [None]:
# ROC, precision-recall and threshold plots for the neural-network predictions.
plotter = PerformancePlotter()
plotter.plot_metrics(y_true_nn, y_probs_nn)

##