# STFT-ResNet

## 1. Import & Device

In [None]:
from src.data_loader import csv_to_npy, load_trace, load_supervised_set, AES_VERSIONS, IS_RAND, RAW_DATASET_PATH, DATASET_PATH
from src.utils import *

import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models

from scipy.signal import stft
from skimage.transform import resize

from sklearn.metrics import accuracy_score
from captum.attr import IntegratedGradients

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## 2. Functions

### 2.1 Model

In [None]:
def get_resnet(model_name='resnet18', num_classes=1):
    """Build an untrained torchvision ResNet adapted to 1-channel spectrogram input.

    The stem convolution is swapped for a single-input-channel conv that keeps
    the original output channels, kernel size, stride and padding. The final
    fully connected layer is resized to ``num_classes`` outputs. The default of
    one logit matches the ``BCEWithLogitsLoss`` used in this notebook.

    Args:
        model_name: torchvision constructor name, one of 'resnet18', 'resnet34',
            'resnet50', 'resnet101', 'resnet152'.
        num_classes: number of output logits.

    Returns:
        The modified, randomly initialised ResNet module.

    Raises:
        ValueError: if ``model_name`` is not a supported ResNet variant.
    """
    supported = ('resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152')
    if model_name not in supported:
        # Fail fast with a clear message instead of an UnboundLocalError on `model`.
        raise ValueError(f"Unsupported model_name {model_name!r}; expected one of {supported}")
    model = getattr(models, model_name)(weights=None)

    original_conv1 = model.conv1
    # 3 input channels -> 1 input channel (spectrograms are single-channel images)
    model.conv1 = nn.Conv2d(1, original_conv1.out_channels, kernel_size=original_conv1.kernel_size,
                            stride=original_conv1.stride, padding=original_conv1.padding, bias=False)

    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

### 2.2 STFT-Dataset

In [None]:
class STFTDataset(Dataset):
    """Dataset of globally min-max scaled log-magnitude STFT spectrograms.

    Each waveform ``X[i]`` goes through three steps:

    1. ``scipy.signal.stft`` with no boundary extension and no padding.
    2. Conversion to dB as ``20 * log10(|Zxx| + 1e-9)``.
    3. A single global min-max scaling.

    The scaling statistics are fitted on this dataset. Alternatively they are
    copied from ``fit_on_data``, so validation/test sets reuse the training
    statistics. Values outside that range are not clipped.

    Items are ``(spectrogram, label)``: a float32 tensor of shape (1, F, T)
    and a float32 tensor of shape (1,).
    """

    def __init__(self, X, y, fit_on_data=None, fs=2e9, nperseg=128, noverlap=96):
        """Compute and normalise all spectrograms eagerly.

        Args:
            X: indexable collection of waveforms supporting ``len``.
            y: labels, one per waveform (array-like; cast to float32).
            fit_on_data: optional ``STFTDataset`` whose ``min_val``/``max_val``
                are reused instead of fitting new statistics.
            fs, nperseg, noverlap: forwarded to ``scipy.signal.stft``.

        Raises:
            ValueError: if ``X`` is empty or the number of labels differs from ``len(X)``.
        """
        if len(X) == 0:
            raise ValueError("STFTDataset requires at least one waveform")
        self.y = torch.tensor(np.asarray(y, dtype=np.float32)).view(-1, 1)
        if self.y.shape[0] != len(X):
            raise ValueError(f"Got {len(X)} waveforms but {self.y.shape[0]} labels")
        self.spectrograms = []

        print("Generating spectrograms")
        db_spectrograms = []
        for i in tqdm(range(len(X))):
            _, _, Zxx = stft(X[i], fs=fs, nperseg=nperseg, noverlap=noverlap, boundary=None, padded=False)
            # The small epsilon keeps log10 finite for all-zero bins.
            db_spectrograms.append(20 * np.log10(np.abs(Zxx) + 1e-9))

        self.target_shape = db_spectrograms[0].shape
        print(f"STFT shape: {self.target_shape}")

        if fit_on_data is not None:  # reuse the training-set global min/max
            self.min_val = fit_on_data.min_val
            self.max_val = fit_on_data.max_val
        else:
            print("Calculating global min/max normalization")
            # Reduce each spectrogram separately instead of np.stack-ing the whole
            # set. This gives the same global min/max without a second full copy in memory.
            self.min_val = np.min([s.min() for s in db_spectrograms])
            self.max_val = np.max([s.max() for s in db_spectrograms])

        print(f"Normalization stats: min={self.min_val:.2f}, max={self.max_val:.2f}")

        # Only min-max normalisation happens here; spectrograms keep their native (F, T) size.
        print("Normalizing and Resizing spectrograms")
        scale = self.max_val - self.min_val + 1e-40
        for sxx_db in tqdm(db_spectrograms):
            sxx_norm = (sxx_db - self.min_val) / scale
            self.spectrograms.append(torch.from_numpy(sxx_norm.astype(np.float32)).unsqueeze(0))

    def __len__(self):
        return len(self.spectrograms)

    def __getitem__(self, idx):
        return self.spectrograms[idx], self.y[idx]

### 2.3 Train / Validate

In [None]:
def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    The model is put in training mode. For each batch the gradients are reset,
    the loss is back-propagated and one optimizer step is taken.

    Returns:
        The unweighted mean of the per-batch loss values.
    """
    model.train()
    batch_losses = []
    for batch_inputs, batch_targets in dataloader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_targets)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(dataloader)

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    A prediction is 1.0 when ``sigmoid(logit) > 0.5`` and 0.0 otherwise.

    Returns:
        Tuple ``(mean batch loss, accuracy, true labels, predicted labels)``.
        Both label sequences are flat Python lists.
    """
    model.eval()
    running_loss = 0.0
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch_inputs, batch_targets in dataloader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            logits = model(batch_inputs)
            running_loss += criterion(logits, batch_targets).item()

            hard_preds = torch.sigmoid(logits).gt(0.5).float()
            y_pred.extend(hard_preds.cpu().numpy().flatten())
            y_true.extend(batch_targets.cpu().numpy().flatten())
    accuracy = accuracy_score(y_true, y_pred)
    return running_loss / len(dataloader), accuracy, y_true, y_pred

### 2.4 Mean XAI During Training

In [None]:
def compute_ig_mean(model, loader, device, target=0, n_steps=50, internal_batch_size=16, max_batches=None):
    """Average Integrated-Gradients attributions over the samples yielded by ``loader``.

    Attributions use an all-zeros baseline. They are summed over every sample
    and divided by the number of samples used. The model is left in eval mode,
    so callers that continue training must call ``model.train()`` afterwards.

    Args:
        model: network to explain.
        loader: iterable of ``(inputs, labels)`` batches; labels are ignored.
        device: device the inputs are moved to.
        target: output index passed to captum as ``target``.
        n_steps: number of integration steps.
        internal_batch_size: captum's internal batch size for the scaled inputs.
        max_batches: process at most this many batches (``None`` = all batches).

    Returns:
        CPU tensor with the per-sample input shape (e.g. (C, H, W)).

    Raises:
        ValueError: if no sample was processed (empty loader or ``max_batches <= 0``).
    """
    model.eval()
    ig = IntegratedGradients(model)

    sum_attr = None
    n_used = 0

    for batch_idx, (inputs, _) in enumerate(loader):
        # Optional budget so per-epoch XAI can run on a subset of the data.
        if max_batches is not None and batch_idx >= max_batches:
            break

        inputs = inputs.to(device, non_blocking=True)
        inputs.requires_grad_(True)
        baselines = torch.zeros_like(inputs)

        attr = ig.attribute(
            inputs,
            baselines=baselines,
            target=target,
            n_steps=n_steps,
            internal_batch_size=internal_batch_size,
        )  # same shape as inputs, e.g. (B, C, H, W)

        batch_sum = attr.sum(dim=0).detach()  # (C, H, W)
        sum_attr = batch_sum if sum_attr is None else sum_attr + batch_sum
        n_used += inputs.size(0)

        # Release batch tensors before the next (memory-hungry) IG call.
        del inputs, baselines, attr, batch_sum

    if n_used == 0:
        raise ValueError("compute_ig_mean: no samples processed (empty loader or max_batches <= 0)")
    return (sum_attr / n_used).detach().cpu()

## 3. Train Models

In [None]:
# ---- Output locations -------------------------------------------------------
XAI_SAVE_DIR = "./XAI_During_Training"                       # per-epoch IG histories
SAVE_DIR = "./STFT-ResNet_models_with_XAI/resnet34_stft_hp"  # best checkpoints
RESNET_VERSIONS = ['resnet34']

# ---- STFT / hyper-parameter grid ---------------------------------------------
fs = 2e9                            # sampling frequency passed to scipy.signal.stft
LEARNING_RATE = [1e-3, 5e-4, 1e-4]
NPERSEG = [64, 128, 256]            # STFT segment lengths
OVERLAP_RATIO = [0.0, 0.5, 0.75]    # noverlap = int(nperseg * ratio)

# ---- Training schedule -------------------------------------------------------
EPOCHS = 30
BATCH_SIZE = 32
PATIENCE = 10                       # epochs without a new best before early stopping

os.makedirs(XAI_SAVE_DIR, exist_ok=True)
os.makedirs(SAVE_DIR, exist_ok=True)

In [None]:
# Hyper-parameter grid search. For every (ResNet variant, AES version, nperseg,
# overlap ratio, learning rate) combination this cell:
#   1. builds STFT datasets (val/test reuse the train-set min/max scaling),
#   2. trains a fresh model with early stopping on validation accuracy,
#   3. stores the mean Integrated-Gradients map over the train set after every epoch,
#   4. reloads the best checkpoint and reports test-set metrics.
for model_name in RESNET_VERSIONS:
    for version in AES_VERSIONS:
        print(f"\n{'='*80}")
        print(f"Training {model_name} on {version} with STFT data")
        print(f"{'='*80}")

        # 1. Load Data
        X_train, y_train, X_val, y_val, X_test, y_test = load_supervised_set(version)

        for nperseg in NPERSEG:
            for ratio in OVERLAP_RATIO:
                # Integer overlap in samples; it is also part of every saved file name.
                noverlap = int(nperseg * ratio)

                # 2. Min/Max scaling on Train Set
                # Spectrogram shape depends on (nperseg, noverlap), so datasets are rebuilt per combination.
                print("\n--- Preparing Train Dataset ---")
                train_dataset = STFTDataset(X_train, y_train, fs=fs, nperseg=nperseg, noverlap=noverlap)

                # Applying Min/Max to Validation/Test Set
                print("\n--- Preparing Validation Dataset ---")
                val_dataset   = STFTDataset(X_val, y_val, fit_on_data=train_dataset, fs=fs, nperseg=nperseg, noverlap=noverlap)
                print("\n--- Preparing Test Dataset ---")
                test_dataset  = STFTDataset(X_test, y_test, fit_on_data=train_dataset, fs=fs, nperseg=nperseg, noverlap=noverlap)

                # DataLoader
                train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
                val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True)
                test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

                for lr in LEARNING_RATE:
                    # 3. Model / Loss / Optimizer
                    # A fresh, randomly initialised single-logit model for every learning rate.
                    model = get_resnet(model_name).to(device)
                    criterion = nn.BCEWithLogitsLoss()
                    optimizer = optim.Adam(model.parameters(), lr=lr)

                    # 4. Train Loop
                    best_val_acc = 0.0
                    best_val_loss = float("inf")
                    patience_counter = 0
                    best_model_path = os.path.join(SAVE_DIR, f"{version}_{model_name}_stft_n{nperseg}_o{noverlap}_lr{lr}_best.pth")

                    # Per-epoch IG history: epoch numbers and the matching mean attribution maps.
                    ig_epochs = []
                    ig_means  = []
                    for epoch in tqdm(range(EPOCHS), desc=f"{version}-{model_name}-n{nperseg}-o{noverlap}-lr{lr}"):
                        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
                        val_loss, val_acc, _, _ = validate(model, val_loader, criterion, device)

                        print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.8f} | Val Loss: {val_loss:.8f} | Val Acc: {val_acc:.4f}")

                        # New best = higher val accuracy, or equal accuracy with lower val loss.
                        update_model = False
                        if val_acc > best_val_acc:
                            update_model = True
                        elif val_acc == best_val_acc and val_loss < best_val_loss:
                            update_model = True

                        # patience_counter counts consecutive epochs without a new best checkpoint.
                        if update_model:
                            best_val_acc  = val_acc
                            best_val_loss = val_loss
                            torch.save(model.state_dict(), best_model_path)
                            print(f"[BEST] Epoch {epoch+1} | Val Acc={val_acc:.4f} | Val Loss={val_loss:.8f}")
                            patience_counter = 0
                        else:
                            patience_counter += 1

                        # NOTE(review): IG over the whole training set with n_steps=50 every epoch
                        # is expensive (max_batches=None, i.e. no subsampling is requested here).
                        mean_attr = compute_ig_mean(model, train_loader, device=device, target=0, n_steps=50, internal_batch_size=BATCH_SIZE, max_batches=None)
                        ig_epochs.append(epoch + 1)
                        ig_means.append(mean_attr)
                        # compute_ig_mean switches the model to eval(); restore training mode.
                        model.train()

                        if patience_counter >= PATIENCE:
                            print(f"Early stopping after epoch {epoch+1}.")
                            break

                    # Persist the IG history. "meta" mirrors the compute_ig_mean arguments above
                    # (n_steps / internal_batch_size are duplicated literals: keep them in sync).
                    ig_save_path = os.path.join(
                        XAI_SAVE_DIR,
                        f"{version}_{model_name}_stft_n{nperseg}_o{noverlap}_lr{lr}_IG_HISTORY.pt"
                    )
                    torch.save(
                        {
                            "meta": {
                                "version": version,
                                "model_name": model_name,
                                "fs": fs,
                                "nperseg": nperseg,
                                "noverlap": noverlap,
                                "lr": lr,
                                "n_steps": 50,
                                "internal_batch_size": BATCH_SIZE,
                            },
                            "epochs": ig_epochs,
                            "mean": torch.stack(ig_means, dim=0),  # (E, C, H, W)
                        },
                        ig_save_path
                    )

                    # 5. Best Model Evaluation
                    print("\n--- Best Model Evaluation on Test Set ---")
                    # The checkpoint was written in this same session from `device`.
                    model.load_state_dict(torch.load(best_model_path))

                    # test_loss / test_acc are unused; metrics are derived from the label lists.
                    test_loss, test_acc, true_labels, pred_labels = validate(model, test_loader, criterion, device)

                    # evaluate_model / print_eval come from src.utils (presumably metrics + confusion matrix).
                    metrics, cm = evaluate_model(true_labels, pred_labels)
                    print_eval(f"{version} - {model_name}-n{nperseg}-o{noverlap}-lr{lr} (STFT)", metrics, cm)

## 4. Visualize Mean XAI by Epoch

In [None]:
def visualize_epoch_xai_progression_history(
    xai_dir,
    version,
    model_name,
    nperseg,
    noverlap,
    lr,
    epoch_range=None,
    cmap="hot",
    use_std=False,
    max_cols=10
):
    """Plot the per-epoch attribution maps stored in an ``*_IG_HISTORY.pt`` file.

    The file name is rebuilt from the run parameters using the same pattern as
    the training loop. One subplot is drawn per stored epoch, with at most
    ``max_cols`` subplots per row.

    Args:
        xai_dir: directory containing the history files.
        version, model_name, nperseg, noverlap, lr: run identifiers used in the file name.
        epoch_range: optional inclusive ``(start, end)`` filter on epoch numbers.
        cmap: matplotlib colormap name.
        use_std: plot the ``"std"`` entry instead of ``"mean"``.
        max_cols: maximum number of subplot columns.

    Raises:
        FileNotFoundError: if the history file does not exist.
        KeyError: if the requested ``"mean"``/``"std"`` entry is not stored in the file.
        RuntimeError: if no epochs remain to plot.
    """
    fname = f"{version}_{model_name}_stft_n{nperseg}_o{noverlap}_lr{lr}_IG_HISTORY.pt"
    path = os.path.join(xai_dir, fname)

    if not os.path.exists(path):
        raise FileNotFoundError(f"Not found: {path}")

    obj = torch.load(path, map_location="cpu")

    key = "std" if use_std else "mean"
    if key not in obj:
        # e.g. use_std=True on a history that only stores "mean"
        raise KeyError(f"'{key}' not found in {path} (available: {sorted(obj)})")

    epochs = list(obj["epochs"])  # list[int]
    arr = obj[key]                # (E, C, H, W) or (E, C, L)
    arr = arr.numpy() if torch.is_tensor(arr) else np.asarray(arr)

    if epoch_range is not None:
        start, end = epoch_range
        keep_idx = [i for i, e in enumerate(epochs) if start <= e <= end]
        if not keep_idx:
            raise RuntimeError("No epochs in the given epoch_range.")
        epochs = [epochs[i] for i in keep_idx]
        arr = arr[keep_idx]

    n = len(epochs)
    if n == 0:
        # Guards the column/row computation below against division by zero.
        raise RuntimeError(f"No epochs stored in {path}.")

    # Drop a singleton channel dimension: (E, 1, ...) -> (E, ...)
    if arr.ndim >= 3 and arr.shape[1] == 1:
        arr = arr[:, 0]
    # 1-D maps (E, L) become one-row images (E, 1, L)
    if arr.ndim == 2:
        arr = arr[:, None, :]

    cols = min(n, max_cols)
    rows = int(np.ceil(n / cols))

    plt.figure(figsize=(4 * cols, 4 * rows))
    for i, e in enumerate(epochs):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(arr[i], cmap=cmap, origin="lower")
        plt.title(f"Epoch {e}")
        plt.axis("off")

    title_metric = "STD" if use_std else "MEAN"
    plt.suptitle(
        f"{version} | {model_name} | n={nperseg}, o={noverlap} | lr={lr} | {title_metric}",
        fontsize=14
    )
    plt.tight_layout()
    plt.show()

In [None]:
# Plot the per-epoch mean IG maps saved during training for every grid configuration.
# Directory, model names and epoch span come from the config cell instead of
# duplicated literals, so they stay in sync with what training actually wrote.
for model_name in RESNET_VERSIONS:
    for version in AES_VERSIONS:
        for nperseg in NPERSEG:
            for ratio in OVERLAP_RATIO:
                noverlap = int(nperseg * ratio)  # same rule as the training loop's file names
                for lr in LEARNING_RATE:
                    visualize_epoch_xai_progression_history(
                        xai_dir=XAI_SAVE_DIR,
                        version=version,
                        model_name=model_name,
                        nperseg=nperseg,
                        noverlap=noverlap,
                        lr=lr,
                        epoch_range=(1, EPOCHS),
                        cmap="hot",
                        use_std=False
                    )

## 5. Evaluate Models

In [None]:
# Re-evaluate every saved best checkpoint of the grid on the test set.
for model_name in RESNET_VERSIONS:
    for version in AES_VERSIONS:
        print(f"\n{'='*80}")
        print(f"Evaluating {model_name} on {version}")
        print(f"{'='*80}")
        X_train, y_train, X_val, y_val, X_test, y_test = load_supervised_set(version)

        for nperseg in NPERSEG:
            for ratio in OVERLAP_RATIO:
                noverlap = int(nperseg * ratio)
                # The train set is rebuilt only to recover the min/max scaling used during training.
                train_dataset = STFTDataset(X_train, y_train, fs=fs, nperseg=nperseg, noverlap=noverlap)
                test_dataset  = STFTDataset(X_test, y_test, fit_on_data=train_dataset, fs=fs, nperseg=nperseg, noverlap=noverlap)
                test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

                for lr in LEARNING_RATE:

                    # Same file-name scheme as best_model_path in the training cell.
                    model_path = f"{SAVE_DIR}/{version}_{model_name}_stft_n{nperseg}_o{noverlap}_lr{lr}_best.pth"
                    model = get_resnet(model_name).to(device)
                    model.load_state_dict(torch.load(model_path, map_location=device))

                    criterion = nn.BCEWithLogitsLoss()
                    # test_loss / test_acc are unused; metrics come from the returned label lists.
                    test_loss, test_acc, true_labels, pred_labels = validate(model, test_loader, criterion, device)
                    metrics, cm = evaluate_model(true_labels, pred_labels)

                    print_eval(f"{version}-{model_name}-stft_n{nperseg}-o{noverlap}-lr{lr} (STFT)", metrics, cm)

## 6. XAI Mean & Diff

In [None]:
def attribute_image_features(model, input_tensor, target_class=0):
    """Integrated-Gradients attributions of ``model`` for ``input_tensor``.

    Uses an all-zeros baseline and captum's default integration settings.
    The model is switched to eval mode first.

    Args:
        model: network to explain.
        input_tensor: batched input (e.g. (1, C, H, W)).
        target_class: output index to attribute.

    Returns:
        Attribution tensor with the same shape as ``input_tensor``.
    """
    model.eval()
    ig = IntegratedGradients(model)
    baselines = torch.zeros_like(input_tensor)
    # The convergence delta was requested but never used. Skipping it avoids
    # the extra forward passes captum runs to compute it.
    return ig.attribute(input_tensor, baselines, target=target_class)

def compute_mean_xai(model, dataset, indices, target_class):
    """Mean Integrated-Gradients attribution over selected samples of ``dataset``.

    Each ``dataset[idx]`` is attributed on its own with
    ``attribute_image_features``, and the attributions are then averaged.

    Args:
        model: classifier; samples are moved to the device of its parameters.
        dataset: indexable returning ``(input, label)`` with no batch dimension.
        indices: sample indices to average over (e.g. all samples of one class).
        target_class: output index to attribute.

    Returns:
        numpy array of the mean attribution with singleton dims squeezed,
        i.e. (H, W) for (1, H, W) spectrogram inputs.

    Raises:
        ValueError: if ``indices`` is empty.
    """
    indices = list(indices)
    if not indices:
        raise ValueError("compute_mean_xai needs at least one index")

    model.eval()
    # Use the model's own device rather than relying on the notebook-global `device`.
    model_device = next(model.parameters()).device
    attrs = []

    for idx in tqdm(indices):
        img, _ = dataset[idx]
        img = img.unsqueeze(0).to(model_device)
        attr = attribute_image_features(model, img, target_class=target_class)
        attrs.append(attr.detach().cpu().numpy())

    mean_attr = np.mean(np.stack(attrs), axis=0)
    return mean_attr.squeeze()  # (H, W)

def visualize_mean_xai(
    mean_xai,
    title="Mean XAI (Trojan)",
    cmap="coolwarm",
    save_path=None
):
    """Draw a 2-D attribution map as a heatmap with a zero-centred colour scale.

    Args:
        mean_xai: 2-D numpy array, laid out as (frequency bins, time bins) to
            match the axis labels.
        title: figure title.
        cmap: matplotlib colormap; a diverging map suits signed attributions.
        save_path: when truthy, save the figure there (dpi=150) and close it;
            otherwise show it interactively.
    """
    # Symmetric limits put zero attribution at the centre of the colormap.
    limit = np.abs(mean_xai).max()

    plt.figure(figsize=(6, 5))
    image = plt.imshow(
        mean_xai,
        cmap=cmap,
        vmin=-limit,
        vmax=limit,
        aspect="auto",
        origin="lower",
    )

    plt.title(title, fontsize=16)
    plt.xlabel("Time Bins", fontsize=14)
    plt.ylabel("Frequency Bins", fontsize=14)
    plt.colorbar(image, fraction=0.046, pad=0.04, label="Attribution")
    plt.tight_layout()

    if not save_path:
        plt.show()
        return

    plt.savefig(save_path, dpi=150, bbox_inches="tight")
    plt.close()

In [None]:
# For every saved best checkpoint, compute test-set mean Integrated-Gradients maps
# for the normal and Trojan classes, plus their absolute difference.
MODEL_NAME = "resnet34"
XAI_MEAN_SAVE_DIR = f"{SAVE_DIR}/Mean_XAI"
os.makedirs(XAI_MEAN_SAVE_DIR, exist_ok=True)

for version in AES_VERSIONS:
    print(f"\n{'='*80}")
    print(f"Computing Mean XAI for {version}")
    print(f"{'='*80}")

    # -------------------------------------------------
    # 1. Load dataset
    # -------------------------------------------------
    X_train, y_train, _, _, X_test, y_test = load_supervised_set(version)

    # -------------------------------------------------
    # 2. Split test indices by label
    # -------------------------------------------------
    # Derived from y_test instead of assuming samples 0-999 are normal and
    # 1000-1999 Trojan, which silently breaks for any other ordering or size.
    # Assumes label 0 = normal, 1 = Trojan.
    y_test_flat = np.asarray(y_test).ravel()
    normal_indices = np.flatnonzero(y_test_flat == 0).tolist()
    trojan_indices = np.flatnonzero(y_test_flat == 1).tolist()
    if not normal_indices or not trojan_indices:
        print(f"[SKIP] {version}: test set has no normal or no Trojan samples")
        continue

    for nperseg in NPERSEG:
        for ratio in OVERLAP_RATIO:
            noverlap = int(nperseg * ratio)
            # The train set is rebuilt only to recover the min/max scaling used during training.
            train_dataset = STFTDataset(X_train, y_train, fs=fs, nperseg=nperseg, noverlap=noverlap)
            test_dataset  = STFTDataset(X_test, y_test, fit_on_data=train_dataset, fs=fs, nperseg=nperseg, noverlap=noverlap)

            for lr in LEARNING_RATE:
                # -------------------------------------------------
                # 3. Load model
                # -------------------------------------------------
                model_path = f"{SAVE_DIR}/{version}_{MODEL_NAME}_stft_n{nperseg}_o{noverlap}_lr{lr}_best.pth"
                model = get_resnet(MODEL_NAME).to(device)
                model.load_state_dict(torch.load(model_path, map_location=device))
                model.eval()

                # -------------------------------------------------
                # 4. Compute Mean XAI per class
                #    (single-logit model, so target_class=0 is its only output)
                # -------------------------------------------------
                mean_normal_xai = compute_mean_xai(
                    model=model,
                    dataset=test_dataset,
                    indices=normal_indices,
                    target_class=0
                )

                mean_trojan_xai = compute_mean_xai(
                    model=model,
                    dataset=test_dataset,
                    indices=trojan_indices,
                    target_class=0
                )

                mean_diff_xai = np.abs(mean_trojan_xai - mean_normal_xai)

                # -------------------------------------------------
                # 5. Visualize & Save
                # -------------------------------------------------
                visualize_mean_xai(
                    mean_normal_xai,
                    title=f"{version} | Mean XAI (Normal) \n (nperseg={nperseg}, noverlap={noverlap}, lr={lr})",
                    save_path=f"{XAI_MEAN_SAVE_DIR}/{version}_{MODEL_NAME}_stft_n{nperseg}_o{noverlap}_lr{lr}_MeanXAI_Normal.png"
                )

                visualize_mean_xai(
                    mean_trojan_xai,
                    title=f"{version} | Mean XAI (Trojan) \n (nperseg={nperseg}, noverlap={noverlap}, lr={lr})",
                    save_path=f"{XAI_MEAN_SAVE_DIR}/{version}_{MODEL_NAME}_stft_n{nperseg}_o{noverlap}_lr{lr}_MeanXAI_Trojan.png"
                )

                # The plotted map is an absolute difference; the title says so.
                visualize_mean_xai(
                    mean_diff_xai,
                    title=f"{version} | Mean XAI Difference |Trojan − Normal| \n (nperseg={nperseg}, noverlap={noverlap}, lr={lr})",
                    save_path=f"{XAI_MEAN_SAVE_DIR}/{version}_{MODEL_NAME}_stft_n{nperseg}_o{noverlap}_lr{lr}_MeanXAI_Diff.png"
                )

                print(f"[DONE] Saved mean XAI to {XAI_MEAN_SAVE_DIR}")

## 7. Visualize Mean STFT Spectrograms & Difference

In [None]:
def generate_and_normalize_spectrograms(X_data, fs=2e9, nperseg=128, noverlap=96):
    """Compute dB STFT spectrograms for equal-length waveforms and min-max scale them.

    The transform is the same as STFTDataset's:
    ``stft`` (no boundary extension, no padding) -> ``20*log10(|Zxx| + 1e-9)``
    -> global min-max scaling. Here the min/max are fitted on ``X_data`` itself.

    Args:
        X_data: array-like of equal-length waveforms, e.g. shape (N, L);
            the STFT is taken along the last axis.
        fs, nperseg, noverlap: forwarded to ``scipy.signal.stft``.

    Returns:
        float32 array of shape (N, F, T) with values in [0, 1].

    Raises:
        ValueError: if ``X_data`` is empty.
    """
    waveforms = np.asarray(X_data)
    if waveforms.size == 0:
        raise ValueError("generate_and_normalize_spectrograms: X_data is empty")

    print("Generating spectrograms...")
    # One vectorised STFT over the last axis instead of a Python loop per waveform.
    _, _, Zxx = stft(waveforms, fs=fs, nperseg=nperseg, noverlap=noverlap,
                     boundary=None, padded=False, axis=-1)
    sxx_db = 20 * np.log10(np.abs(Zxx) + 1e-9)  # epsilon keeps log10 finite

    print("Calculating global min/max for normalization...")
    min_val = sxx_db.min()
    max_val = sxx_db.max()
    print(f"Normalization stats: min={min_val:.2f}, max={max_val:.2f}")

    # Only min-max scaling happens here; no resizing is performed.
    print("Normalizing and resizing spectrograms...")
    return ((sxx_db - min_val) / (max_val - min_val + 1e-40)).astype(np.float32)

In [None]:
# Font sizes shared by every figure in this cell.
TITLE_FONTSIZE = 20
LABEL_FONTSIZE = 18
TICK_FONTSIZE = 16

# Compare class-mean training spectrograms (disabled vs triggered) for each STFT setting.
for version in ["AES-T600", "AES-T1600"]:
    print(f"\n{'='*80}")
    print(f"Analyzing Mean Spectrogram Difference for: {version}")
    print(f"{'='*80}")

    X_train, y_train, _, _, _, _ = load_supervised_set(version)
    y_train_flat = np.asarray(y_train).ravel()  # 1-D boolean masks below
    for nperseg in NPERSEG:
        for ratio in OVERLAP_RATIO:
            noverlap = int(nperseg * ratio)

            # 1. Normalized spectrograms of the whole training set
            processed_spectrograms = generate_and_normalize_spectrograms(X_train, fs, nperseg, noverlap)

            # 2. Split by class: 0 = disabled (normal), 1 = triggered
            normal_spectrograms = processed_spectrograms[y_train_flat == 0]
            triggered_spectrograms = processed_spectrograms[y_train_flat == 1]

            print(f"\nNormal shape={normal_spectrograms.shape}, Triggered shape={triggered_spectrograms.shape}")

            # 3. Mean spectrogram per class
            mean_normal_spec = np.mean(normal_spectrograms, axis=0)
            mean_triggered_spec = np.mean(triggered_spectrograms, axis=0)

            # 4. Absolute difference of the class means (non-negative)
            difference_map = np.abs(mean_triggered_spec - mean_normal_spec)

            # 5. Visualization
            fig, axes = plt.subplots(1, 3, figsize=(24, 7))
            # Include the STFT settings: this loop emits one figure per (nperseg, noverlap).
            fig.suptitle(
                f"Mean Spectrogram Analysis for {version} (nperseg={nperseg}, noverlap={noverlap})",
                fontsize=20
            )

            # Plot 1: Mean Normal Spectrogram
            im1 = axes[0].imshow(mean_normal_spec, cmap='viridis', aspect='auto', origin='lower')
            axes[0].set_title('Mean Spectrogram (Disabled)', fontsize=TITLE_FONTSIZE)
            axes[0].set_xlabel('Time Bins', fontsize=LABEL_FONTSIZE)
            axes[0].set_ylabel('Frequency Bins', fontsize=LABEL_FONTSIZE)
            axes[0].tick_params(axis='both', which='major', labelsize=TICK_FONTSIZE)
            cb1 = fig.colorbar(im1, ax=axes[0])
            cb1.ax.tick_params(labelsize=TICK_FONTSIZE)

            # Plot 2: Mean Triggered Spectrogram
            im2 = axes[1].imshow(mean_triggered_spec, cmap='viridis', aspect='auto', origin='lower')
            axes[1].set_title('Mean Spectrogram (Triggered)', fontsize=TITLE_FONTSIZE)
            axes[1].set_xlabel('Time Bins', fontsize=LABEL_FONTSIZE)
            axes[1].tick_params(axis='both', which='major', labelsize=TICK_FONTSIZE)
            cb2 = fig.colorbar(im2, ax=axes[1])
            cb2.ax.tick_params(labelsize=TICK_FONTSIZE)

            # Plot 3: Diff Map. It is non-negative, so a sequential colormap anchored
            # at 0 replaces the symmetric diverging one (whose lower half was never used).
            im3 = axes[2].imshow(difference_map, cmap='Reds', vmin=0, vmax=difference_map.max(),
                                 aspect='auto', origin='lower')
            axes[2].set_title('Difference |Triggered - Disabled|', fontsize=TITLE_FONTSIZE)
            axes[2].set_xlabel('Time Bins', fontsize=LABEL_FONTSIZE)
            axes[2].tick_params(axis='both', which='major', labelsize=TICK_FONTSIZE)
            cb3 = fig.colorbar(im3, ax=axes[2])
            cb3.set_label('Magnitude Difference', size=LABEL_FONTSIZE)
            cb3.ax.tick_params(labelsize=TICK_FONTSIZE)

            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
            plt.show()