In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from net import FraudNet, AttentionTransformerFraudNet, EnhancedFraudNet  # Import fraud detection model
from data import get_dataloaders_fraud  # Import dataset functions
from evaluation import evaluate_model  # Import evaluation function
from train import train_model, set_all_seeds  # Import training function from train.py
import pandas as pd
import sys
from plot import plot_metrics, plot_confusion_matrices, plot_aucpr
import pickle


# Load fraud dataset
set_all_seeds(42)

# Set dataset path
DATASET_PATH = "/home/khoa/Khoa/outsource/na_thesis/examples/hello-world/ml-to-fl/pt/src/data/creditcard.csv"
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Training hyperparameters
batch_size = 32
num_epochs = 15
learning_rate = 0.00003

df = pd.read_csv(DATASET_PATH)
input_size = df.shape[1] - 1
print(f"Detected input size: {input_size}")
save_plot_dir = 'data_plot'
os.makedirs(save_plot_dir, exist_ok=True)
train_loader, valid_loader, test_loader, class_weights = get_dataloaders_fraud(
    DATASET_PATH, batch_size=batch_size, use_smote=True, plot=True, save_plot_dir=save_plot_dir, sampling_strategy=0.5
)

Detected input size: 30
Applying SMOTE to balance training data...
Adding Gaussian noise (std=0.1) to training data...
Overlapping samples between Train & Test: 0
Overlapping samples between Validation & Test: 0
Overlapping samples between Validation & Train: 0
Training set size after SMOTE: 271921 samples
Validation set size: 56746 samples
Test set size: 45396 samples


  self.labels = torch.tensor(labels, dtype=torch.float32)


In [None]:
for model in [AttentionTransformerFraudNet(input_size=input_size).to(DEVICE), FraudNet(input_size=input_size).to(DEVICE), EnhancedFraudNet(input_size=input_size).to(DEVICE)]:
    for stochastic_val in [True, False]: 
        # Get model name for saving metrics
        model_name = model.__class__.__name__

        class_weights = class_weights
        pos_weight = torch.tensor([class_weights[1] / class_weights[0]], device=DEVICE)

        pos_weight = None

        # Loss Function (No weight balancing since using SMOTE)
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        # Optimizer
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        # Call `train.py` instead of writing the training loop here
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='max', patience=3, verbose=True
        )

        train_loss_list, train_metrics_list, valid_metrics_list, test_metrics = train_model(
            model, num_epochs, train_loader, valid_loader, test_loader, optimizer,
            criterion, DEVICE, scheduler=scheduler, stochastic=stochastic_val
        )

        save_plot_dir = f'plot_{model_name}_{batch_size}_{num_epochs}_{learning_rate}'
        if pos_weight:
            save_plot_dir += "_pos_weight"
        if stochastic_val:
            save_plot_dir += "_stochastic"
            
        os.makedirs(save_plot_dir, exist_ok=True)

        # Create metrics directory if it doesn't exist
        metrics_dir = f'metrics'
        os.makedirs(metrics_dir, exist_ok=True)

        # Save training metrics
        metrics_data = {
            'train_metrics': train_metrics_list,
            'valid_metrics': valid_metrics_list,
            'test_metrics': test_metrics,
            'train_loss': train_loss_list
        }
        if pos_weight:
            metrics_file = os.path.join(metrics_dir, f"{model_name}_{batch_size}_{num_epochs}_{learning_rate}_pos_weight_metrics.pickle")
        else:
            metrics_file = os.path.join(metrics_dir, f"{model_name}_{batch_size}_{num_epochs}_{learning_rate}_metrics.pickle")
            
        with open(metrics_file, 'wb') as f:
            pickle.dump(metrics_data, f)

        print(f"Metrics saved to {metrics_file}")

        # Create plots
        plot_metrics(train_metrics_list, fig_name="Training Metrics", save_path=f"{save_plot_dir}/train_metrics.png")
        plot_metrics(valid_metrics_list, fig_name="Validation Metrics", save_path=f"{save_plot_dir}/valid_metrics.png")
        plot_confusion_matrices(model, test_loader, threshold=0.85, save_path=f"{save_plot_dir}/confusion_matrix.png")
        plot_aucpr(model, test_loader, device=DEVICE, save_path=f"{save_plot_dir}/auc_pr.png")

        # Save the trained model
        best_model_path = "best_model.pth"
        model.load_state_dict(torch.load(best_model_path))
        print("Loaded best model from training phase.")

        # Save the best model explicitly at a clear location for future usage
        if pos_weight:
            final_model_path = f"./best_{model_name}_{batch_size}_{num_epochs}_{learning_rate}_pos_weight_model.pth"
        else:
            final_model_path = f"./best_{model_name}_{batch_size}_{num_epochs}_{learning_rate}_model.pth"
            
        torch.save(model.state_dict(), final_model_path)
        print(f"Final best model saved explicitly at {final_model_path}")

        # Evaluate model
        print("Evaluating Model on Test Set...")
        evaluate_model(model, test_loader, DEVICE)

# AIO

In [3]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from net import FraudNet, EnhancedFraudNet  # Import fraud detection model
from data import get_dataloaders_fraud, get_dataloaders_fraud_2  # Import dataset functions
from evaluation import evaluate_model  # Import evaluation function
from train import train_model, set_all_seeds  # Import training function from train.py
import pandas as pd
import sys
from plot import plot_metrics, plot_confusion_matrices, plot_aucpr
import pickle


# Load fraud dataset
set_all_seeds(42)

# Set dataset path
DATASET_PATH = "/home/khoa/Khoa/outsource/na_thesis/examples/hello-world/ml-to-fl/pt/src/data/creditcard.csv"
TEST_DATASET_PATH = "/home/khoa/Khoa/outsource/na_thesis/examples/hello-world/ml-to-fl/pt/src/data/creditcard.csv"
TRAIN_VALID_DATASET_PATH = "/home/khoa/Khoa/outsource/na_thesis/examples/hello-world/ml-to-fl/pt/src/data/train_valid.csv"

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Training hyperparameters
batch_size = 32

# for batch_size in [96, 128, 256]:
for batch_size in [32, 64, 96, 128, 256]:
    num_epochs = 15
    learning_rate = 0.00003

    df = pd.read_csv(DATASET_PATH)
    input_size = df.shape[1] - 1
    print(f"Detected input size: {input_size}")

    # train_loader, valid_loader, test_loader, class_weights = get_dataloaders_fraud(
    #     DATASET_PATH, batch_size=batch_size, use_smote=True, plot=True, save_plot_dir='data_plot'
    # )

    train_loader, valid_loader, test_loader, class_weights, _ = get_dataloaders_fraud_2(
        TRAIN_VALID_DATASET_PATH, test_csv=TEST_DATASET_PATH, batch_size=batch_size, use_smote=True, plot=True, save_plot_dir='data_plot'
    )

    for model in [FraudNet(input_size=input_size).to(DEVICE), EnhancedFraudNet(input_size=input_size).to(DEVICE)]:
        for pos_weight in [torch.tensor([class_weights[1] / class_weights[0]], device=DEVICE), None]:
            for stochastic_val in [True, False]: 
                # Get model name for saving metrics
                model_name = model.__class__.__name__

                # Loss Function (No weight balancing since using SMOTE)
                criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

                # Optimizer
                optimizer = optim.Adam(model.parameters(), lr=learning_rate)

                # Call `train.py` instead of writing the training loop here
                scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer, mode='max', patience=3, verbose=True
                )

                train_loss_list, train_metrics_list, valid_metrics_list, test_metrics = train_model(
                    model, num_epochs, train_loader, valid_loader, test_loader, optimizer,
                    criterion, DEVICE, scheduler=scheduler, stochastic=stochastic_val
                )

                save_plot_dir = f'plot_{model_name}_{batch_size}_{num_epochs}_{learning_rate}'
                if pos_weight:
                    save_plot_dir += "_pos_weight"
                if stochastic_val:
                    save_plot_dir += "_stochastic"
                    
                os.makedirs(save_plot_dir, exist_ok=True)

                # Create metrics directory if it doesn't exist
                metrics_dir = f'metrics'
                os.makedirs(metrics_dir, exist_ok=True)

                # Save training metrics
                metrics_data = {
                    'train_metrics': train_metrics_list,
                    'valid_metrics': valid_metrics_list,
                    'test_metrics': test_metrics,
                    'train_loss': train_loss_list
                }
                metrics_file_name = f"{model_name}_{batch_size}_{num_epochs}_{learning_rate}"
                if pos_weight:
                    metrics_file_name += "_pos_weight"
                if stochastic_val:
                    metrics_file_name += "_stochastic"
                metrics_file = os.path.join(metrics_dir, f"{metrics_file_name}_metrics.pickle")
                    
                with open(metrics_file, 'wb') as f:
                    pickle.dump(metrics_data, f)

                print(f"Metrics saved to {metrics_file}")

                # Create plots
                plot_metrics(train_metrics_list, fig_name="Training Metrics", save_path=f"{save_plot_dir}/train_metrics.png")
                plot_metrics(valid_metrics_list, fig_name="Validation Metrics", save_path=f"{save_plot_dir}/valid_metrics.png")
                plot_confusion_matrices(model, test_loader, threshold=0.85, save_path=f"{save_plot_dir}/confusion_matrix.png")
                plot_aucpr(model, test_loader, device=DEVICE, save_path=f"{save_plot_dir}/auc_pr.png")

                # Save the trained model
                best_model_path = "best_model.pth"
                model.load_state_dict(torch.load(best_model_path))
                print("Loaded best model from training phase.")

                model_dir = 'models'
                os.makedirs(model_dir, exist_ok=True)
                
                model_file_name = f"{model_name}_{batch_size}_{num_epochs}_{learning_rate}"
                # Save the best model explicitly at a clear location for future usage
                if pos_weight:
                    model_file_name += "_pos_weight"
                if stochastic_val:
                    model_file_name += "_stochastic"
                final_model_path = os.path.join(model_dir, f"{model_file_name}_model.pth")
                    
                torch.save(model.state_dict(), final_model_path)
                print(f"Final best model saved explicitly at {final_model_path}")

                # Evaluate model
                print("Evaluating Model on Test Set...")
                evaluate_model(model, test_loader, DEVICE)

Detected input size: 30
Applying SMOTE to balance training data...
Adding Gaussian noise (std=0.1) to training data...
Training set size after SMOTE: 265200 samples
Validation set size: 51085 samples


  self.labels = torch.tensor(labels, dtype=torch.float32)


Starting Training...
[Epoch 1, Batch 1] Loss: 1.3599
[Epoch 1, Batch 829] Loss: 0.4718
[Epoch 1, Batch 1657] Loss: 0.2514
[Epoch 1, Batch 2485] Loss: 0.2331
[Epoch 1, Batch 3313] Loss: 0.4884
[Epoch 1, Batch 4141] Loss: 0.1810
[Epoch 1, Batch 4969] Loss: 0.1062
[Epoch 1, Batch 5797] Loss: 0.1382
[Epoch 1, Batch 6625] Loss: 0.3277
[Epoch 1, Batch 7453] Loss: 0.3750
[Epoch 1, Batch 8281] Loss: 0.2270
Epoch 1/15: Train Loss: 0.2959 | Train Acc: 97.21% | Valid Loss: 0.1152 | Valid Acc: 99.66% | Valid Precision: 30.70% | Valid Recall: 82.35% | Valid F1-score: 44.73% | Valid AUC-PR: 65.39%
Time Elapsed: 0.45 minutes
Model improved. Saving best model.
[Epoch 2, Batch 1] Loss: 0.0871
[Epoch 2, Batch 829] Loss: 0.2656
[Epoch 2, Batch 1657] Loss: 0.0788
[Epoch 2, Batch 2485] Loss: 0.0719
[Epoch 2, Batch 3313] Loss: 0.2442
[Epoch 2, Batch 4141] Loss: 0.4090
[Epoch 2, Batch 4969] Loss: 0.0284
[Epoch 2, Batch 5797] Loss: 0.0928
[Epoch 2, Batch 6625] Loss: 0.0985
[Epoch 2, Batch 7453] Loss: 0.0880
[

  self.labels = torch.tensor(labels, dtype=torch.float32)


Training set size after SMOTE: 265200 samples
Validation set size: 51085 samples
Starting Training...
[Epoch 1, Batch 1] Loss: 1.2845
[Epoch 1, Batch 415] Loss: 0.7658
[Epoch 1, Batch 829] Loss: 0.5096
[Epoch 1, Batch 1243] Loss: 0.3635
[Epoch 1, Batch 1657] Loss: 0.4367
[Epoch 1, Batch 2071] Loss: 0.3486
[Epoch 1, Batch 2485] Loss: 0.3434
[Epoch 1, Batch 2899] Loss: 0.3028
[Epoch 1, Batch 3313] Loss: 0.1963
[Epoch 1, Batch 3727] Loss: 0.2942
[Epoch 1, Batch 4141] Loss: 0.1206
Epoch 1/15: Train Loss: 0.4087 | Train Acc: 96.37% | Valid Loss: 0.1533 | Valid Acc: 99.92% | Valid Precision: 73.12% | Valid Recall: 80.00% | Valid F1-score: 76.40% | Valid AUC-PR: 81.07%
Time Elapsed: 0.27 minutes
Model improved. Saving best model.
[Epoch 2, Batch 1] Loss: 0.3092
[Epoch 2, Batch 415] Loss: 0.2998
[Epoch 2, Batch 829] Loss: 0.2457
[Epoch 2, Batch 1243] Loss: 0.2019
[Epoch 2, Batch 1657] Loss: 0.2523
[Epoch 2, Batch 2071] Loss: 0.1099
[Epoch 2, Batch 2485] Loss: 0.1931
[Epoch 2, Batch 2899] Loss:

  self.labels = torch.tensor(labels, dtype=torch.float32)


Starting Training...
[Epoch 1, Batch 1] Loss: 1.1787
[Epoch 1, Batch 277] Loss: 0.7642
[Epoch 1, Batch 553] Loss: 0.6153
[Epoch 1, Batch 829] Loss: 0.4670
[Epoch 1, Batch 1105] Loss: 0.4463
[Epoch 1, Batch 1381] Loss: 0.3517
[Epoch 1, Batch 1657] Loss: 0.3731
[Epoch 1, Batch 1933] Loss: 0.3884
[Epoch 1, Batch 2209] Loss: 0.3535
[Epoch 1, Batch 2485] Loss: 0.2719
[Epoch 1, Batch 2761] Loss: 0.1992
Epoch 1/15: Train Loss: 0.4583 | Train Acc: 96.38% | Valid Loss: 0.2422 | Valid Acc: 99.85% | Valid Precision: 53.44% | Valid Recall: 82.35% | Valid F1-score: 64.81% | Valid AUC-PR: 72.06%
Time Elapsed: 0.21 minutes
Model improved. Saving best model.
[Epoch 2, Batch 1] Loss: 0.2665
[Epoch 2, Batch 277] Loss: 0.3328
[Epoch 2, Batch 553] Loss: 0.2592
[Epoch 2, Batch 829] Loss: 0.1646
[Epoch 2, Batch 1105] Loss: 0.2649
[Epoch 2, Batch 1381] Loss: 0.1968
[Epoch 2, Batch 1657] Loss: 0.2492
[Epoch 2, Batch 1933] Loss: 0.1598
[Epoch 2, Batch 2209] Loss: 0.2150
[Epoch 2, Batch 2485] Loss: 0.1369
[Epoc

  self.labels = torch.tensor(labels, dtype=torch.float32)


Training set size after SMOTE: 265200 samples
Validation set size: 51085 samples
Starting Training...
[Epoch 1, Batch 1] Loss: 1.1294
[Epoch 1, Batch 208] Loss: 0.6828
[Epoch 1, Batch 415] Loss: 0.7214
[Epoch 1, Batch 622] Loss: 0.5013
[Epoch 1, Batch 829] Loss: 0.5353
[Epoch 1, Batch 1036] Loss: 0.4630
[Epoch 1, Batch 1243] Loss: 0.3889
[Epoch 1, Batch 1450] Loss: 0.3469
[Epoch 1, Batch 1657] Loss: 0.3906
[Epoch 1, Batch 1864] Loss: 0.2742
[Epoch 1, Batch 2071] Loss: 0.3501
Epoch 1/15: Train Loss: 0.4956 | Train Acc: 94.29% | Valid Loss: 0.2227 | Valid Acc: 99.95% | Valid Precision: 84.88% | Valid Recall: 85.88% | Valid F1-score: 85.38% | Valid AUC-PR: 85.69%
Time Elapsed: 0.17 minutes
Model improved. Saving best model.
[Epoch 2, Batch 1] Loss: 0.3355
[Epoch 2, Batch 208] Loss: 0.2590
[Epoch 2, Batch 415] Loss: 0.2984
[Epoch 2, Batch 622] Loss: 0.3021
[Epoch 2, Batch 829] Loss: 0.2353
[Epoch 2, Batch 1036] Loss: 0.3156
[Epoch 2, Batch 1243] Loss: 0.2639
[Epoch 2, Batch 1450] Loss: 0.2

  self.labels = torch.tensor(labels, dtype=torch.float32)


Starting Training...
[Epoch 1, Batch 1] Loss: 1.0655
[Epoch 1, Batch 104] Loss: 0.7617
[Epoch 1, Batch 207] Loss: 0.6917
[Epoch 1, Batch 310] Loss: 0.6280
[Epoch 1, Batch 413] Loss: 0.5647
[Epoch 1, Batch 516] Loss: 0.4934
[Epoch 1, Batch 619] Loss: 0.4608
[Epoch 1, Batch 722] Loss: 0.5062
[Epoch 1, Batch 825] Loss: 0.4560
[Epoch 1, Batch 928] Loss: 0.4609
[Epoch 1, Batch 1031] Loss: 0.4108
Epoch 1/15: Train Loss: 0.5706 | Train Acc: 93.58% | Valid Loss: 0.3645 | Valid Acc: 99.90% | Valid Precision: 68.37% | Valid Recall: 78.82% | Valid F1-score: 73.22% | Valid AUC-PR: 59.53%
Time Elapsed: 0.14 minutes
Model improved. Saving best model.
[Epoch 2, Batch 1] Loss: 0.4712
[Epoch 2, Batch 104] Loss: 0.4897
[Epoch 2, Batch 207] Loss: 0.4455
[Epoch 2, Batch 310] Loss: 0.4934
[Epoch 2, Batch 413] Loss: 0.4376
[Epoch 2, Batch 516] Loss: 0.3572
[Epoch 2, Batch 619] Loss: 0.3773
[Epoch 2, Batch 722] Loss: 0.3933
[Epoch 2, Batch 825] Loss: 0.3904
[Epoch 2, Batch 928] Loss: 0.3660
[Epoch 2, Batch 1

In [None]:
import pickle

DIR = '/home/khoa/Khoa/outsource/na_thesis/examples/hello-world/ml-to-fl/pt/src/metrics'
metrics = os.listdir(DIR)
metrics.sort()

for metric in metrics:
    if metric.startswith('Attention') and 'pos_weight' not in metric:
        with open(os.path.join(DIR, metric), 'rb') as f:
            data = pickle.load(f)
        
        print(data.keys())
        print(data['train_metrics'][1].keys())


# Note
## train_loss for comparison between model (same batch size)

# Compare Model with itself in diff batch

In [None]:
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import savgol_filter
import re

def sort_by_model_and_size(filename):
    # Extract model name (everything before the first underscore)
    model_match = re.match(r'([^_]+)_', filename)
    model_name = model_match.group(1) if model_match else ""
    
    # Extract size (first number after the model name)
    size_match = re.search(r'_(\d+)_', filename)
    size = int(size_match.group(1)) if size_match else 0
    
    # Return tuple for sorting (first by model, then by size)
    return (model_name, size)

# Directory where metrics are stored
DIR = '/home/khoa/Khoa/outsource/na_thesis/examples/hello-world/ml-to-fl/pt/src/metrics'
METRIC_NAMES = ['train_metrics', 'valid_metrics']
METRIC_PERSIONS =  ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'auc_pr', 'loss']

models = [FraudNet(), EnhancedFraudNet()]
for model in models:
    for metric_name in METRIC_NAMES:
        for metric_percision in METRIC_PERSIONS:
            model_name = model.__class__.__name__
            
            metrics = os.listdir(DIR)
            metrics.sort()
            
            # Set up the plot
            plt.figure(figsize=(12, 6))
            plt.title(f'{metric_name}_{metric_percision} in {model_name}', fontsize=14)
            plt.xlabel('Step', fontsize=12)
            plt.ylabel('Metric Value', fontsize=12)

            # Colors for different models with better contrast
            colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
            markers = ['o', 's', '^', 'D', 'x', '*']

            # Load and plot each metric file
            attention_models = []
            for i, metric in enumerate(sorted(metrics, key=sort_by_model_and_size)):
                if metric.startswith(model_name):
                    attention_models.append(metric)
                    with open(os.path.join(DIR, metric), 'rb') as f:
                        data = pickle.load(f)
                            
                    # Extract x and y values for plotting
                    x_values = []
                    y_values = []
                    
                    for j, point in enumerate(data[metric_name]):
                        x_values.append(j)
                        y_values.append(point[metric_percision])
                    
                    # Plot each point individually
                    for j in range(len(x_values)):
                        plt.plot(x_values[j], y_values[j], 
                                marker=markers[i % len(markers)], 
                                color=colors[i % len(colors)],
                                markersize=5)
                    
                    # Connect points with a line
                    plt.plot(x_values, y_values, 
                            color=colors[i % len(colors)], 
                            linewidth=1.5, 
                            alpha=0.7,
                            label=f'{metric}')

            # Add legend with better placement
            plt.legend(loc='best', fontsize=10)

            # Add grid for better readability but make it subtle
            plt.grid(True, linestyle='--', alpha=0.3)

            # Improve appearance
            plt.tight_layout()

            # Save the figure
            plt.savefig(f'{model_name}_{metric_name}_{metric_percision}_pos_weight_points.png', dpi=300, bbox_inches='tight')

            # Show the plot
            plt.show()

# Same Batch Size Diff Model

In [None]:
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
import re
from matplotlib.lines import Line2D

# Directory where metrics are stored
DIR = '/home/khoa/Khoa/outsource/na_thesis/examples/hello-world/ml-to-fl/pt/src/metrics'
METRIC_NAMES = ['train_metrics', 'valid_metrics']
METRIC_PERSIONS = ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'auc_pr', 'loss']

# Define batch sizes and models to compare
BATCH_SIZE = "256"  # Set this to the batch size you want to compare
MODEL_NAMES = ["FraudNet", "EnhancedFraudNet"]

# Function to check if a file matches our criteria
def matches_criteria(filename, model_name, batch_size, use_pos_weight):
    # Check if file starts with the model name
    if not filename.startswith(model_name):
        return False
    
    # Check if file contains the specified batch size
    batch_match = re.search(r'_(\d+)_', filename)
    if not (batch_match and batch_match.group(1) == batch_size):
        return False
    
    # Check if file has "pos_weight" according to the preference
    has_pos_weight = "pos_weight" in filename
    if has_pos_weight != use_pos_weight:
        return False
    
    return True

# Create plots for each metric
for metric_name in METRIC_NAMES:
    for metric_precision in METRIC_PERSIONS:
        # Set up the plot
        plt.figure(figsize=(14, 8))
        plt.title(f'Comparison of Models: {metric_name}_{metric_precision} (Batch Size {BATCH_SIZE})', 
                 fontsize=14)
        plt.xlabel('Step', fontsize=12)
        plt.ylabel(f'{metric_precision.replace("_", " ").title()}', fontsize=12)

        # Colors for different models with better contrast
        colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']
        
        # Different line styles for pos_weight vs regular
        line_styles = ['-', '--']  # Solid for regular, dashed for pos_weight
        markers = ['o', 's', '^', 'D', 'x', '*']

        # Get all metrics files
        metrics_files = os.listdir(DIR)
        
        # To store data for custom legend
        legend_elements = []
        
        # For each model, find and plot both the pos_weight and regular files
        for i, model_name in enumerate(MODEL_NAMES):
            color = colors[i % len(colors)]
            marker = markers[i % len(markers)]
            
            for weight_type in [False, True]:  # False = regular, True = pos_weight
                model_found = False
                for metric_file in metrics_files:
                    if matches_criteria(metric_file, model_name, BATCH_SIZE, weight_type):
                        with open(os.path.join(DIR, metric_file), 'rb') as f:
                            try:
                                data = pickle.load(f)                                
                                # Extract x and y values for plotting
                                x_values = []
                                y_values = []
                                
                                for j, point in enumerate(data[metric_name]):
                                    # Check if the metric precision exists in the point data
                                    if isinstance(point, dict) and metric_precision in point:
                                        x_values.append(j)
                                        y_values.append(point[metric_precision])
                                
                                if len(x_values) > 0:
                                    # Choose line style by weight type
                                    line_style = line_styles[1 if weight_type else 0]
                                    
                                    # Create appropriate label
                                    label = f"{model_name} {'(pos_weight)' if weight_type else '(regular)'}"
                                    
                                    # Plot the line
                                    plt.plot(x_values, y_values,
                                            color=color,
                                            linestyle=line_style,
                                            linewidth=2,
                                            alpha=0.7)
                                    
                                    # Plot the points
                                    plt.scatter(x_values, y_values,
                                              marker=marker,
                                              color=color,
                                              s=30,
                                              alpha=0.8 if weight_type else 0.6)
                                    
                                    # Create a legend element for this line style
                                    legend_elements.append(
                                        Line2D([0], [0], color=color, marker=marker, linestyle=line_style,
                                              markersize=8, label=label)
                                    )
                                    
                                    model_found = True
                                else:
                                    print(f"No data points found for {metric_precision} in {metric_file}")
                                    
                            except Exception as e:
                                print(f"Error processing {metric_file}: {e}")
                        
                        if model_found:
                            # We only need one file per model/weight combo, so break after finding the first match
                            break
                
                if not model_found:
                    weight_label = "pos_weight" if weight_type else "regular"
        
        # Add a custom legend that correctly shows line styles and markers
        if legend_elements:
            plt.legend(handles=legend_elements, loc='best', fontsize=10, ncol=2)
        else:
            plt.close()
            continue
        
        # Add grid for better readability but make it subtle
        plt.grid(True, linestyle='--', alpha=0.3)
        
        # Set background color to light gray for better contrast
        plt.gca().set_facecolor('#f8f8f8')
        
        # Improve appearance
        plt.tight_layout()
        
        # Save the figure
        plt.savefig(f'comparison_{metric_name}_{metric_precision}_batch{BATCH_SIZE}_combined.png', 
                   dpi=300, bbox_inches='tight')
                
        # Close the figure to free memory
        plt.close()

print("All plots created successfully!")

# Compare Same Model Same Batch Size

In [6]:
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import savgol_filter
import re

def extract_model_info(filename):
    # Remove the "_metrics.pickle" suffix if it exists
    base_name = filename.replace('_metrics.pickle', '')
    
    # Split the filename into components
    parts = base_name.split('_')
    
    # Extract model name (first part)
    model_name = parts[0]
    
    # Extract batch size, epochs, and learning rate
    batch_size = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else 0
    
    # Extract variant information (pos_weight and/or stochastic)
    variant = []
    if "pos_weight" in base_name:
        variant.append("pos_weight")
    if "stochastic" in base_name:
        variant.append("stochastic")
    
    variant_str = "_".join(variant) if variant else "base"
    
    return model_name, batch_size, variant_str

# Directory where metrics are stored
DIR = '/home/khoa/Khoa/outsource/na_thesis/examples/hello-world/ml-to-fl/pt/src/metrics'
METRIC_NAMES = ['train_metrics', 'valid_metrics']
METRIC_PERSIONS = ['accuracy', 'precision', 'recall', 'f1_score', 'auc_roc', 'auc_pr', 'loss']
PLOT_DIR = '__plot_dir'
os.makedirs(PLOT_DIR, exist_ok=True)

# Get all metric files
metrics = os.listdir(DIR)

# Group files by model type and batch size
model_batch_groups = {}

for metric in metrics:
    model_name, batch_size, variant = extract_model_info(metric)
    key = f"{model_name}_{batch_size}"
    
    if key not in model_batch_groups:
        model_batch_groups[key] = []
    
    model_batch_groups[key].append(metric)

# Process each group separately
for group_key, group_metrics in model_batch_groups.items():
    model_name, batch_size = group_key.split('_')
    
    for metric_name in METRIC_NAMES:
        for metric_percision in METRIC_PERSIONS:
            # Set up the plot
            plt.figure(figsize=(12, 6))
            plt.title(f'{model_name} (Batch Size {batch_size}) - {metric_name}_{metric_percision}', fontsize=14)
            plt.xlabel('Step', fontsize=12)
            plt.ylabel('Metric Value', fontsize=12)
            
            # Colors for different variants with better contrast
            colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
            markers = ['o', 's', '^', 'D', 'x', '*']
            
            # Load and plot each metric file in the group
            for i, metric in enumerate(sorted(group_metrics)):
                try:
                    with open(os.path.join(DIR, metric), 'rb') as f:
                        data = pickle.load(f)
                    
                    # Extract x and y values for plotting
                    x_values = []
                    y_values = []
                    
                    for j, point in enumerate(data[metric_name]):
                        if metric_percision in point:
                            x_values.append(j)
                            y_values.append(point[metric_percision])
                    
                    if len(x_values) > 0:
                        # Extract variant name for labeling
                        _, _, variant = extract_model_info(metric)
                        label = variant
                        
                        # Plot each point individually
                        for j in range(len(x_values)):
                            plt.plot(x_values[j], y_values[j],
                                    marker=markers[i % len(markers)],
                                    color=colors[i % len(colors)],
                                    markersize=5)
                        
                        # Connect points with a line
                        plt.plot(x_values, y_values,
                                color=colors[i % len(colors)],
                                linewidth=1.5,
                                alpha=0.7,
                                label=f'{label}')
                except Exception as e:
                    print(f"Error processing {metric}: {e}")
            
            # Add legend with better placement
            plt.legend(loc='best', fontsize=10)
            
            # Add grid for better readability but make it subtle
            plt.grid(True, linestyle='--', alpha=0.3)
            
            # Improve appearance
            plt.tight_layout()
            
            # Save the figure
            plt.savefig(os.path.join(PLOT_DIR, f'{model_name}_batch{batch_size}_{metric_name}_{metric_percision}.png'), 
                       dpi=300, bbox_inches='tight')
            
            # Close the figure to free memory
            plt.close()