In [None]:
import os, sys, logging, torch, shap

import pandas as pd
import matplotlib.pyplot as plt

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "../")))
from configuration import Configuration
from os.path import join as path
from torch.utils.data import DataLoader, Dataset
from warnings import simplefilter
simplefilter(action="ignore", category=pd.errors.PerformanceWarning)
simplefilter(action="ignore", category=RuntimeWarning)

from model import Model

# Shared project settings object; later cells read paths, feature lists and the class list from it.
c = Configuration()

# Emit INFO-and-above records with a timestamp and level prefix.
logging.basicConfig(format="%(asctime)s [%(levelname)s] %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
def load_model(in_features, out_features, scenario_name):
    """Build a ``Model`` and restore the best checkpoint saved for a scenario.

    Args:
        in_features: First argument forwarded to ``Model``.
        out_features: Second argument forwarded to ``Model``.
        scenario_name: Sub-directory of ``<c.path_results>/scenarios/CL`` that
            contains ``best_model.torch``.

    Returns:
        The ``Model`` instance with the checkpoint's state dict loaded.
    """
    model = Model(in_features, out_features)
    checkpoint_path = path(c.path_results, "scenarios", "CL", scenario_name, "best_model.torch")
    # map_location="cpu" lets a checkpoint that was written on a GPU machine be
    # restored on a CPU-only host; load_state_dict then copies the tensors into
    # the parameters of the freshly built model.
    # NOTE: torch.load unpickles the file, so only point it at trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model


# Model inputs: the four feature groups from the configuration, concatenated in this order.
features = (
    c._flowstats
    + c._pstats
    + c._pflowstats
    + c._pstats_subdirs
)
# Columns requested from the parquet files: the inputs followed by the c.appl columns.
columns = features + c.appl

# Layer widths follow directly from the feature list and the configured classes.
in_features, out_features = len(features), len(c.classes)

# Restore the checkpoint stored for the "FC-all" scenario.
model = load_model(in_features, out_features, "FC-all")

In [None]:
class FQDataset(Dataset):
    """In-memory dataset pairing float32 feature rows with int64 labels."""

    def __init__(self, x, y):
        # `.values` hands torch the underlying arrays of the pandas objects.
        feature_array, label_array = x.values, y.values
        self.x = torch.tensor(feature_array, dtype=torch.float32)
        self.y = torch.tensor(label_array, dtype=torch.long)

    def __len__(self):
        # The sample count is the leading dimension of the feature tensor.
        return self.x.shape[0]

    def __getitem__(self, idx):
        sample, target = self.x[idx], self.y[idx]
        return sample, target
    
# Directory that holds the parquet splits.
pth = path(c.path_dataset, "4-dataset", "CL")

# Number of directory entries whose name starts with "train-".
train_cnt = sum(1 for entry in os.listdir(pth) if entry.startswith("train-"))

# Assumes the training files are numbered contiguously from 1
# (train-1.parquet ... train-<train_cnt>.parquet).
train_paths = [path(pth, f"train-{shard}.parquet") for shard in range(1, train_cnt + 1)]
test_path = path(pth, "test.parquet")
validation_path = path(pth, "validation.parquet")

def load_data(path, batch_size=1024, shuffle=True):
    """Read one parquet split and wrap it in a ``DataLoader``.

    Args:
        path: Parquet file to read. Inside this function the name shadows the
            module-level ``path`` alias for ``os.path.join``.
        batch_size: Rows per batch; defaults to the previously hard-coded 1024.
        shuffle: Whether the loader reshuffles the rows on every pass. Defaults
            to ``True`` (the previous behaviour); pass ``False`` to get a
            deterministic evaluation order.

    Returns:
        A ``DataLoader`` over an ``FQDataset`` built from the split.
    """
    df = pd.read_parquet(path, columns=columns)

    # Inputs are every requested column except the c.appl ones; the target is df[c.app].
    # NOTE(review): c.app and c.appl are two different Configuration attributes —
    # confirm they refer to the same label column.
    X, y = df.drop(columns=c.appl), df[c.app]

    return DataLoader(FQDataset(X, y), batch_size=batch_size, shuffle=shuffle)

In [None]:
# Switch the network to evaluation mode before computing attributions.
model.eval()

# load_data builds a shuffling DataLoader, so without a fixed seed a different
# set of rows would be explained on every Restart & Run All. Seeding the global
# torch RNG makes the first batch (and therefore the plot) reproducible.
torch.manual_seed(0)

test_loader = load_data(test_path)

# Only the first batch of the test split is used.
sample_data, labels = next(iter(test_loader))

# Rows 0-49 act as the background distribution, rows 50-59 are the ones explained.
background = sample_data[:50]
explained = sample_data[50:60]

explainer = shap.DeepExplainer(model, background)

# Compute SHAP values (additivity check disabled, as before).
shap_values = explainer.shap_values(explained, check_additivity=False)


In [None]:
# Read the test split into a DataFrame, restricted to the names listed in `columns`.
df= pd.read_parquet(test_path, columns=columns)

In [None]:
# Entry i is the display name for model output i.
# NOTE(review): the order is assumed to match the model's output indices — confirm against c.classes.
class_names = ["discord", "facebook-graph", "google-www", "instagram", "snapchat", "spotify", "youtube"]  # class 0–6

# Apply the same column drop load_data uses to build the model inputs, so the
# names line up with the columns of sample_data (df.columns on its own also
# contains the c.appl columns).
feature_names = list(df.columns.drop(c.appl))

# For multi-class bar plots shap orders the stacked segments (and the legend)
# by mean |SHAP| per class, not by class index. Writing class_names over the
# legend positionally would therefore attach names to the wrong classes; let
# shap label each segment itself instead.
shap.summary_plot(
    shap_values,
    sample_data[50:60],
    plot_type="bar",
    show=False,
    feature_names=feature_names,
    class_names=class_names,
)

# Rebuild the legend from the labels shap attached, only adding a title.
# A dedicated name avoids clobbering the batch `labels` from the earlier cell.
handles, legend_labels = plt.gca().get_legend_handles_labels()
plt.legend(handles, legend_labels, title="Classes")

# Show updated plot
plt.show()