In [1]:
import os
import sys
import torch
import torchvision
from torchvision import transforms
from tqdm import tqdm
import pandas as pd
from PIL import Image
import numpy as np
import random
import glob
import torch.nn.functional as F
import pandas as pd
import uuid
import matplotlib.pyplot as plt

from datasets import SimpleMNISTDataset, prepare_mnist_data, minmax_normalize, prepare_cifar10_data, get_cifar10_transforms, SimpleCIFAR10Dataset
from Load_Model import load_mnist_model, get_model_details, load_model, split_model_for_mask
from mask import MaskGenerator
from unet import UNet, trigger_inversion_loss_hard, trigger_inversion_loss_hinge, trigger_inversion_loss_smooth, trigger_inversion_loss_hinge_strong
from trigger_visualisation import visualize_inverse_trigger_grid, visualize_inverse_trigger_grid_cifar

In [2]:
# How many rows of the filtered model list to keep for this run.
num_models = 1

# Read the model index CSV and keep the first `num_models` rows whose
# 'Label' column equals 1.
df = pd.read_csv('Odysseus-CIFAR10/CSV/test.csv')
triggered_models = df.loc[df['Label'].eq(1)].head(num_models)

# Bookkeeping containers/counters (initialised here; not updated in this cell).
results, successful_tests, failed_tests = [], 0, 0

# Pick the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# Make sure the CIFAR10 files exist on disk before building the dataset.
print("Preparing CIFAR10 dataset...")
prepare_cifar10_data()

# Only the test-time transform is used below; the train transform is kept
# as a name for parity with the helper's return value.
transform_train, transform_test = get_cifar10_transforms()

# Dataset backed by the images/CSV under ./CIFAR10_Data/clean.
test_dataset = SimpleCIFAR10Dataset(
    path_to_data='./CIFAR10_Data/clean',
    csv_filename='clean.csv',
    data_transform=transform_test,
)

# Fixed-order loader (no shuffling), 128 samples per batch, 2 worker processes.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
)
print(f"Test dataset size: {len(test_dataset)}")

Using device: cuda
Preparing CIFAR10 dataset...
Files already downloaded and verified


100%|███████████████████████████████████| 10000/10000 [00:01<00:00, 7783.22it/s]

Saved 10000 test images to ./CIFAR10_Data/clean
Saved CSV to ./CIFAR10_Data/clean/clean.csv
Test dataset size: 10000





In [3]:
def iterative_bti_dbf_training(mask_generator, generator, Sa, Sb, dataloader, device,
                               I1=200, I2=500, R1=3,
                               lr_mask=1e-2, lr_gen=1e-3,
                               tau=0.01, constraint_weight=5000.0,
                               loss_func=None):
    """
    Iteration-based BTI-DBF training loop (Algorithm 1 from the paper).

    Alternates, `R1` times, between (1) optimising the feature mask held by
    `mask_generator` and (2) optimising the trigger generator against the
    detached mask produced by step 1. Both models are updated in place;
    nothing is returned.

    Args:
        mask_generator: module exposing `parameters()` and `get_raw_mask()`;
            its raw mask is expanded to the flattened feature shape.
        generator: U-Net generator Gθ for trigger inversion; its output is
            passed through a sigmoid before the loss is computed.
        Sa, Sb: split model parts (conv features, classifier head). Both are
            put in eval mode and are not stepped by any optimizer here.
        dataloader: iterable of batches whose first two items are the inputs
            and the labels; any further items are ignored.
        device: device that each batch is moved to.
        I1: number of passes over `dataloader` for mask optimization.
        I2: number of passes over `dataloader` for generator optimization.
        R1: outer alternations of mask/gen training.
        lr_mask: Adam learning rate for the mask stage.
        lr_gen: Adam learning rate for the generator stage.
        tau: forwarded unchanged to `loss_func`.
        constraint_weight: multiplier applied to the constraint term.
        loss_func: callable invoked as `loss_func(x, Gx, Sa, mask, tau)` and
            expected to return `(main_loss, constraint_loss)`. Defaults to
            `trigger_inversion_loss_hinge`.
            NOTE(review): the imported loss functions are assumed to follow
            this call signature / return pair -- confirm against unet.py.
    """
    if loss_func is None:
        # The original body called `trigger_inversion_loss`, a name that is not
        # imported anywhere in the notebook (only the *_hard/_hinge/_smooth/
        # _hinge_strong variants are), so the generator stage raised NameError.
        # Resolve the default lazily to the hinge variant, which is the one the
        # experiment cell passes to `train_generator`.
        loss_func = trigger_inversion_loss_hinge

    mask_generator.train()
    generator.train()
    Sa.eval()
    Sb.eval()

    for r in range(R1):
        print(f"\n=== Iteration {r+1}/{R1} ===")

        # ---------- STEP 1: Train Mask m ----------
        # A fresh optimizer per outer round, so Adam state is not carried over.
        opt_mask = torch.optim.Adam(mask_generator.parameters(), lr=lr_mask)
        for i in range(I1):
            total_loss = 0.0
            n_batches = 0
            for x, y, *_ in dataloader:
                x, y = x.to(device), y.to(device)
                opt_mask.zero_grad()

                # Features come from the frozen front half; no graph is needed.
                with torch.no_grad():
                    feat = Sa(x).view(x.size(0), -1)

                m = mask_generator.get_raw_mask().expand_as(feat)
                benign_logits = Sb(feat * m)
                backdoor_logits = Sb(feat * (1 - m))

                # Minimise CE on the masked features while maximising CE on
                # the complementary (1 - m) features.
                loss_benign = F.cross_entropy(benign_logits, y)
                loss_backdoor = F.cross_entropy(backdoor_logits, y)
                loss = loss_benign - loss_backdoor

                loss.backward()
                opt_mask.step()
                total_loss += loss.item()
                n_batches += 1
            if (i+1) % max(1, I1//5) == 0:
                # Count batches ourselves: works for loaders without __len__
                # and avoids dividing by zero on an empty loader.
                print(f"[Mask] Iter {i+1}/{I1} Loss={total_loss/max(1, n_batches):.4f}")

        # Freeze mask for generator
        m_final = mask_generator.get_raw_mask().detach()

        # ---------- STEP 2: Train Generator Gθ ----------
        opt_gen = torch.optim.Adam(generator.parameters(), lr=lr_gen)
        for j in range(I2):
            total_loss = 0.0
            n_batches = 0
            # Labels are not used by the inversion loss, so only x is unpacked.
            for x, *_ in dataloader:
                x = x.to(device)
                opt_gen.zero_grad()

                Gx = torch.sigmoid(generator(x))
                loss_main, loss_constraint = loss_func(x, Gx, Sa, m_final, tau)
                loss = loss_main + constraint_weight * loss_constraint

                loss.backward()
                opt_gen.step()
                total_loss += loss.item()
                n_batches += 1
            if (j+1) % max(1, I2//5) == 0:
                print(f"[Gen] Iter {j+1}/{I2} Loss={total_loss/max(1, n_batches):.4f}")

    print("✅ Iterative BTI-DBF training complete.")


In [4]:
# Test each model
# Output location and column layout of the experiment log; identical for every
# model, so they are defined once instead of on every loop iteration.
results_file = "experiment_results.csv"
results_columns = [
    "experiment_id",
    "mask_epochs",
    "unet_epochs",
    "unet_tau",
    "unet_constraint_weight",
    "mask_loss",
    "inversion_loss",
    "constraint_loss",
    "total_loss",
    "model_name",
    "trigger_type",
    "trigger_location"
]

# Hyper-parameter grids: every combination is run for each model.
mask_epochs_trials = [20]
unet_epochs_trials = [30]
unet_tau_trials = [0.03]
unet_constraint_weight_trials = [5e3]

# `iterrows()` yields the index *labels* of the filtered frame (i.e. the row
# labels of the original CSV), not positions, so the progress message uses a
# separate 1-based counter. `idx` is still bound for any later use.
for model_number, (idx, row) in enumerate(triggered_models.iterrows(), start=1):
    print(f'Testing model {model_number} of {len(triggered_models)}')
    model_file = row['Model File']
    model_path = f'Odysseus-CIFAR10/Models/{model_file}'

    model_details_dict = get_model_details(model_path)

    model, mapping = load_model(model_path, device)

    trigger_type = model_details_dict['Trigger type']
    trigger_location = model_details_dict['Trigger_location']

    # Split the model
    Sa, Sb = split_model_for_mask(model)

    # Dummy forward pass to get feature size
    with torch.no_grad():
        x_sample, _, _ = next(iter(test_loader))
        x_sample = x_sample.to(device)
        feat = Sa(x_sample)
        feat_dim = feat.view(x_sample.size(0), -1).shape[1]  # Flattened feature dim

    # Load the existing log up front (so a bad CSV fails before any training),
    # or start from an empty frame with the expected columns.
    if os.path.exists(results_file):
        results_df = pd.read_csv(results_file)
    else:
        results_df = pd.DataFrame(columns=results_columns)

    # Rows produced for this model; appended to the log in one step below
    # rather than concatenating a one-row frame per experiment.
    experiment_records = []

    for mask_epochs in mask_epochs_trials:
        # Init soft mask
        init_mask = torch.zeros(1, feat_dim).to(device)
        mask_generator = MaskGenerator(init_mask, Sb).to(device)

        # Train the mask
        mask_loss = mask_generator.train_decoupling_mask(
            Sa, Sb,
            test_loader, device,
            epochs=mask_epochs
        )

        for unet_epochs in unet_epochs_trials:
            for unet_tau in unet_tau_trials:
                for unet_constraint_weight in unet_constraint_weight_trials:
                    # Get Mask
                    mask = mask_generator.get_raw_mask().detach()

                    # Initialize U-Net trigger generator
                    G = UNet(n_channels=3, num_classes=3, base_filter_num=32, num_blocks=4).to(device)

                    # Train U-Net
                    inversion_loss = G.train_generator(
                        Sa, mask, test_loader, device,
                        epochs=unet_epochs,
                        tau=unet_tau,
                        constraint_weight=unet_constraint_weight,
                        loss_func=trigger_inversion_loss_hinge
                    )

                    # Log results
                    experiment_id = str(uuid.uuid4())[:8]  # short unique ID
                    experiment_records.append({
                        "experiment_id": experiment_id,
                        "mask_epochs": mask_epochs,
                        "unet_epochs": unet_epochs,
                        "unet_tau": unet_tau,
                        "unet_constraint_weight": unet_constraint_weight,
                        "mask_loss": mask_loss,
                        "inversion_loss": inversion_loss,
                        "model_name": model_file,
                        "trigger_type": trigger_type,
                        "trigger_location": trigger_location
                    })
                    visualize_inverse_trigger_grid_cifar(G, test_loader, device, experiment_id, model_file)

    # Merge this model's rows into the log once.
    if experiment_records:
        new_rows_df = pd.DataFrame(experiment_records)
        if results_df.empty:
            # Nothing to concatenate with: keep the log's column order and add
            # any extra columns at the end. Columns that were never filled
            # (e.g. "constraint_loss", "total_loss") stay present as NaN.
            column_order = list(results_df.columns) + [
                col for col in new_rows_df.columns if col not in results_df.columns
            ]
            results_df = new_rows_df.reindex(columns=column_order)
        else:
            results_df = pd.concat([results_df, new_rows_df], ignore_index=True)

    # Save CSV
    results_df.to_csv(results_file, index=False)

    print(f"✅ Results saved to {results_file}")

  checkpoint = torch.load(model_path, map_location="cpu")
  checkpoint = torch.load(model_path)


Testing model 1 of 1
model path  Odysseus-CIFAR10/Models/Model_1077.pth
keys are : dict_keys(['net', 'Model Category', 'Architecture_Name', 'Learning_Rate', 'Loss Function', 'optimizer', 'Momentum', 'Weight decay', 'num_workers', 'Pytorch version', 'Trigger type', 'Trigger_location', 'Trigger Size', 'Mapping', 'Normalization Type', 'Mapping Type', 'Dataset', 'Batch Size', 'trigger_fraction', 'test_clean_acc', 'test_trigerred_acc', 'epoch'])
==> Building model..
The Accuracies on clean samples:   90.8
The fooling rate:  88.95
Mapping is :  [3 2 8 1 6 7 9 4 5 0]


Epoch 1/20: 100%|██████████████████████████████| 79/79 [00:00<00:00, 100.15it/s]


Epoch 1: Loss=-31.6519, Benign Acc=0.6884


Epoch 2/20: 100%|██████████████████████████████| 79/79 [00:00<00:00, 113.38it/s]


Epoch 2: Loss=-238.8949, Benign Acc=0.8712


Epoch 3/20: 100%|██████████████████████████████| 79/79 [00:00<00:00, 110.54it/s]


Epoch 3: Loss=-506.1025, Benign Acc=0.9022


Epoch 4/20: 100%|██████████████████████████████| 79/79 [00:00<00:00, 110.73it/s]


Epoch 4: Loss=-788.9374, Benign Acc=0.9076


Epoch 5/20: 100%|██████████████████████████████| 79/79 [00:00<00:00, 112.13it/s]


Epoch 5: Loss=-1088.9162, Benign Acc=0.9057


Epoch 6/20: 100%|██████████████████████████████| 79/79 [00:00<00:00, 113.43it/s]


Epoch 6: Loss=-1397.9059, Benign Acc=0.9050


Epoch 7/20: 100%|██████████████████████████████| 79/79 [00:00<00:00, 113.66it/s]


Epoch 7: Loss=-1707.3443, Benign Acc=0.9080


Epoch 8/20: 100%|██████████████████████████████| 79/79 [00:00<00:00, 112.31it/s]


Epoch 8: Loss=-2016.0710, Benign Acc=0.9080


Epoch 9/20: 100%|██████████████████████████████| 79/79 [00:00<00:00, 111.82it/s]


Epoch 9: Loss=-2320.6327, Benign Acc=0.9085


Epoch 10/20: 100%|█████████████████████████████| 79/79 [00:00<00:00, 113.20it/s]


Epoch 10: Loss=-2617.9125, Benign Acc=0.9074


Epoch 11/20: 100%|█████████████████████████████| 79/79 [00:00<00:00, 112.62it/s]


Epoch 11: Loss=-2907.9573, Benign Acc=0.9092


Epoch 12/20: 100%|█████████████████████████████| 79/79 [00:00<00:00, 110.31it/s]


Epoch 12: Loss=-3198.5906, Benign Acc=0.9063


Epoch 13/20: 100%|█████████████████████████████| 79/79 [00:00<00:00, 112.36it/s]


Epoch 13: Loss=-3488.7895, Benign Acc=0.9075


Epoch 14/20: 100%|█████████████████████████████| 79/79 [00:00<00:00, 112.81it/s]


Epoch 14: Loss=-3773.2550, Benign Acc=0.9075


Epoch 15/20: 100%|█████████████████████████████| 79/79 [00:00<00:00, 111.91it/s]


Epoch 15: Loss=-4048.9822, Benign Acc=0.9068


Epoch 16/20: 100%|█████████████████████████████| 79/79 [00:00<00:00, 111.30it/s]


Epoch 16: Loss=-4321.6507, Benign Acc=0.9066


Epoch 17/20: 100%|█████████████████████████████| 79/79 [00:00<00:00, 111.23it/s]


Epoch 17: Loss=-4599.2860, Benign Acc=0.9063


Epoch 18/20: 100%|█████████████████████████████| 79/79 [00:00<00:00, 111.67it/s]


Epoch 18: Loss=-4874.5507, Benign Acc=0.9060


Epoch 19/20: 100%|█████████████████████████████| 79/79 [00:00<00:00, 112.46it/s]


Epoch 19: Loss=-5150.6268, Benign Acc=0.9058


Epoch 20/20: 100%|█████████████████████████████| 79/79 [00:00<00:00, 112.33it/s]


Epoch 20: Loss=-5427.0820, Benign Acc=0.9068
✅ Mask training complete.
Epoch 1: Train Loss=5449.7867
Epoch 2: Train Loss=4750.7699
Epoch 3: Train Loss=4572.4077
Epoch 4: Train Loss=4509.2790
Epoch 5: Train Loss=4475.5173
Epoch 6: Train Loss=4449.7664
Epoch 7: Train Loss=4427.5038
Epoch 8: Train Loss=4409.5936
Epoch 9: Train Loss=4407.1509
Epoch 10: Train Loss=4394.8852
Epoch 11: Train Loss=4382.2734
Epoch 12: Train Loss=4377.3157
Epoch 13: Train Loss=4375.1075
Epoch 14: Train Loss=4379.0898
Epoch 15: Train Loss=4377.5123
Epoch 16: Train Loss=4374.4496
Epoch 17: Train Loss=4367.2087
Epoch 18: Train Loss=4361.4376
Epoch 19: Train Loss=4359.3106
Epoch 20: Train Loss=4357.1253
Epoch 21: Train Loss=4356.1288
Epoch 22: Train Loss=4360.8540
Epoch 23: Train Loss=4363.6959
Epoch 24: Train Loss=4360.4733
Epoch 25: Train Loss=4361.7671
Epoch 26: Train Loss=4363.4889
Epoch 27: Train Loss=4358.2974
Epoch 28: Train Loss=4351.2626
Epoch 29: Train Loss=4349.3697
Epoch 30: Train Loss=4348.5485
✅ Saved 