# 1.Library

In [1]:
import os
import shutil
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Set random seeds for reproducibility
np.random.seed(0)
torch.manual_seed(0)

<torch._C.Generator at 0x1c053e9cd70>

# 2.Data Loader

In [2]:
def load_mnist_binary_data(data_dir, honest_range=100, poison_range=100, img_size=28,select='Honest+Poisoned'):
    X = []
    y = []
    def load_honest(X, y):
        for i in range(honest_range):
            folder = os.path.join(data_dir, f"honest_{i}")
            if not os.path.exists(folder):
                continue
            for label in os.listdir(folder):
                label_folder = os.path.join(folder, label)
                if not os.path.isdir(label_folder):
                    continue
                for fname in os.listdir(label_folder):
                    if fname.endswith('.png') or fname.endswith('.jpg'):
                        img = Image.open(os.path.join(label_folder, fname)).convert('L').resize((img_size, img_size))
                        X.append(np.array(img).flatten() / 255.0)
                        y.append(label)
        return X, y
    def load_poisoned(X, y):
        for i in range(poison_range):
            folder = os.path.join(data_dir, f"poison_{i}")
            if not os.path.exists(folder):
                continue
            for label in os.listdir(folder):
                label_folder = os.path.join(folder, label)
                if not os.path.isdir(label_folder):
                    continue
                for fname in os.listdir(label_folder):
                    if fname.endswith('.png') or fname.endswith('.jpg'):
                        img = Image.open(os.path.join(label_folder, fname)).convert('L').resize((img_size, img_size))
                        X.append(np.array(img).flatten() / 255.0)
                        y.append(label)
        return X, y
    
    if select == 'Honest+Poisoned':
        X, y = load_honest(X, y)
        X, y = load_poisoned(X, y)
    elif select == 'Poisoned+Honest':
        X, y = load_poisoned(X, y)
        X, y = load_honest(X, y)
    elif select == 'Honest':
        X, y = load_honest(X, y)
    elif select == 'Poisoned':
        X, y = load_poisoned(X, y)

    X = np.array(X, dtype=np.float32)
    y = np.array(y, dtype=np.float32)
    return X, y

def load_mnist_binary_test_data(test_dir, img_size=28):
    X = []
    y = []
    if not os.path.exists(test_dir):
        return X, y
    for label in ['0', '1']:
        label_folder = os.path.join(test_dir, label)
        if not os.path.isdir(label_folder):
            continue
        for fname in os.listdir(label_folder):
            if fname.endswith('.png') or fname.endswith('.jpg'):
                img = Image.open(os.path.join(label_folder, fname)).convert('L').resize((img_size, img_size))
                X.append(np.array(img).flatten() / 255.0)
                y.append(label)
    X = np.array(X, dtype=np.float32)
    y = np.array(y, dtype=np.float32)
    return X, y

# 3.Model

In [3]:
# class MNISTNet(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.fc1 = nn.Linear(28*28, 32)
#         self.fc2 = nn.Linear(32, 1)
#         self.relu = nn.ReLU()
#     def forward(self, x):
#         x = self.fc1(x)
#         x = self.relu(x)
#         x = self.fc2(x)
#         x = torch.sigmoid(x)
#         return x

In [4]:
# class MNISTNet(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.fc1 = nn.Linear(28*28, 1)
#     def forward(self, x):
#         x = self.fc1(x)
#         x = torch.sigmoid(x)
#         return x

# 4.Loss

In [5]:
# def binary_cross_entropy(pred, target):
#     eps = 1e-7
#     pred = torch.clamp(pred, eps, 1 - eps)
#     return -(target * torch.log(pred) + (1 - target) * torch.log(1 - pred)).mean()

def multiclass_cross_entropy(logits, target, from_probs=False, eps=1e-7):
    """
    Cross-entropy loss for multiclass (e.g., 10 classes).
    - logits: Tensor shape (N, C) containing raw scores (preferred) or probabilities if from_probs=True
    - target: Tensor shape (N,) or (N,1) containing class indices (0..C-1)
    """
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)
    target = target.squeeze()
    if target.dtype != torch.long:
        target = target.long()
    if from_probs:
        probs = torch.clamp(logits, eps, 1.0 - eps)
        log_probs = torch.log(probs)
        return -log_probs[torch.arange(target.size(0)), target].mean()
    return nn.functional.cross_entropy(logits, target)


# 5.Configuration

In [6]:
N_HONEST = 100
N_POISONED = 100
IMG_SIZE = 28
LEARNING_RATE = 0.001
EPOCHS = 100
BATCH_SIZE = 32
N_CLASSES = 10

# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

DATA_PATH = "../data/mnist_multiple_poison/train"
TEST_PATH = "../data/mnist_multiple_poison/test"

RESULT_PATH = "./results"


# 6.Main

In [7]:
poison_percent = [0]#[0, 10, 20, 30, 40, 50]
poison_n_list = []
for p in range(0,N_POISONED+1):
    percent = int((p / (p+N_HONEST)) * 100)
    if percent in poison_percent:
        poison_percent.remove(percent)
        print(f"Adding poison n: {p} for percent: {percent}")
        poison_n_list.append(p)
print("Poison n list:", poison_n_list)

Adding poison n: 0 for percent: 0
Poison n list: [0]


In [8]:

import re
def make_safe_filename(name: str) -> str:
    # remove control chars and characters invalid on Windows
    safe = re.sub(r'[<>:"/\\|?*\x00-\x1F]', '_', name)
    # strip trailing dots/spaces which are problematic on Windows
    safe = safe.rstrip('. ')
    return safe


In [9]:
NODE = 1
for LAYER in [3]:#[3,4,5,6]: # LAYER - (first,last)
    
    class MNISTNet(nn.Module):
        def __init__(self):
            super().__init__()
            
            # 1. Define the stack of hidden layers (Linear + ReLU)
            # Use a list comprehension to create the (Linear, ReLU) pairs
            hidden_layers = []
            for _ in range(LAYER-2):
                hidden_layers.append(nn.Linear(NODE, NODE))
                hidden_layers.append(nn.ReLU())
                
            # 2. Use nn.Sequential to wrap the entire dynamic stack
            self.dynamic_hidden = nn.Sequential(*hidden_layers)
            
            # Other layers
            self.fc1 = nn.Linear(28*28, NODE)
            self.fc2 = nn.Linear(NODE, N_CLASSES)

        def forward(self, x):
            # Flatten the input
            x = x.view(x.size(0), -1) 
            
            # Input Layer
            x = self.fc1(x)
            x = nn.ReLU()(x) # We need to manually apply ReLU here
            
            # Dynamic Hidden Layers (Executed sequentially without a loop)
            x = self.dynamic_hidden(x)
                
            # Output Layer
            x = self.fc2(x)
            # x = torch.sigmoid(x)
            
            return x

    BATCH_SIZE_list = [BATCH_SIZE]
    # temp_bs = BATCH_SIZE
    # while temp_bs >= 1:
    #     BATCH_SIZE_list.append(temp_bs)
    #     temp_bs = temp_bs // 2

    for rev in range(3):

        model = MNISTNet().to(device)
        for experiment_bs in BATCH_SIZE_list:

            name_save_path = f"CL_ModelN{NODE}L{LAYER}_Batchsize{experiment_bs}_rev{rev}"
            
            for i_poisoned in poison_n_list:
                percent_poisoned = int((i_poisoned / (i_poisoned + N_HONEST)) * 100)

                # create directory
                NAME_SAVE_update_PATH = f"poisoned_{percent_poisoned}percent"
                save_path = os.path.join(RESULT_PATH,name_save_path,NAME_SAVE_update_PATH)
                # Remove existing directory if it exists
                if os.path.exists(save_path):
                    shutil.rmtree(save_path)
                if not os.path.exists(save_path):
                    os.makedirs(save_path)
                save_all_path = os.path.join(RESULT_PATH,name_save_path,'all')
                os.makedirs(save_all_path, exist_ok=True)
                
                # Load data
                X_train, y_train = load_mnist_binary_data( DATA_PATH, honest_range=N_HONEST, 
                                                        poison_range=i_poisoned, img_size=IMG_SIZE, 
                                                        select='Honest+Poisoned')
                X_test, y_test = load_mnist_binary_test_data( TEST_PATH, img_size=IMG_SIZE)
                # print(f"Loaded train: X={X_train.shape}, y={y_train.shape}")
                # print(f"Loaded test: X={X_test.shape}, y={y_test.shape}")

                if X_test.shape[0] == 0:
                    raise ValueError("No valid test images found in the test directory!")

                X_train_tensor = torch.tensor(X_train, dtype=torch.float32).to(device)
                y_train_tensor = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1).to(device)
                X_test_tensor = torch.tensor(X_test, dtype=torch.float32).to(device)
                y_test_tensor = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1).to(device)

                torch.save(model, os.path.join(save_path,'model_init.pt'))
                optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

                records = []

                # Training loop
                for epoch in range(EPOCHS):

                    # Train
                    model.train()
                    
                    #randomize the training data
                    perm = torch.randperm(X_train_tensor.size(0))
                    X_train_tensor = X_train_tensor[perm]
                    y_train_tensor = y_train_tensor[perm]

                    for start in range(0, X_train_tensor.size(0), experiment_bs):
                        end = start + experiment_bs
                        xb = X_train_tensor[start:end]
                        yb = y_train_tensor[start:end]
                        outputs = model(xb)
                        # preds = (outputs >= 0.5).float()
                        preds = torch.argmax(outputs, dim=1, keepdim=True)
                        acc_train = (preds == yb).float().mean().item()
                        loss_train = multiclass_cross_entropy(outputs, yb)
                        optimizer.zero_grad()
                        loss_train.backward()
                        optimizer.step()

                    # Evaluate
                    model.eval()
                    with torch.no_grad():
                        outputs = model(X_test_tensor)
                        # preds = (outputs >= 0.5).float()
                        preds = torch.argmax(outputs, dim=1, keepdim=True)
                        loss_test = multiclass_cross_entropy(outputs, y_test_tensor)
                        acc_test = (preds == y_test_tensor).float().mean().item()
                    records.append({'poison_percent': percent_poisoned, 'epoch': epoch, 
                                    'loss_train': loss_train.item(), 'acc_train': acc_train, 
                                    'loss_test': loss_test.item(), 'acc_test': acc_test})
                    report_txt = f"poison_percent {percent_poisoned} "
                    report_txt += f"epoch {epoch}: "
                    report_txt += f"loss_train={loss_train.item():.4f}, "
                    report_txt += f"acc_train={acc_train:.4f}, "
                    report_txt += f"loss_test={loss_test.item():.4f}, "
                    report_txt += f"acc_test={acc_test:.4f}"
                    print(report_txt)

                torch.save(model, os.path.join(save_path,'model_last.pt'))

                # Save training log
                df = pd.DataFrame(records)

                # save_name_path = os.path.join(save_path, f'{NAME_SAVE_update_PATH}.csv')
                safe_name = make_safe_filename(NAME_SAVE_update_PATH)
                save_name_path = os.path.abspath(os.path.join(save_path, f'{safe_name}.csv'))
                os.makedirs(os.path.dirname(save_name_path), exist_ok=True)

                df.to_csv(save_name_path, index=False)
                print(f"Training log saved to {save_name_path}")

                # Plot loss and accuracy
                plt.figure(figsize=(10,4))

                plt.subplot(1,2,1)
                plt.plot(df['epoch'], df['loss_test'], marker='o')
                plt.title('Test Loss')
                plt.ylabel('Loss')
                plt.xlabel('Epoch')
                plt.grid(True)
                plt.ylim(0, 1.1)
                # plt.legend()

                plt.subplot(1,2,2)
                plt.plot(df['epoch'], df['acc_test'], marker='o')
                plt.title('Test Accuracy')
                plt.ylabel('Accuracy')
                plt.xlabel('Epoch')
                plt.grid(True)
                plt.ylim(0, 1.1)
                # plt.legend()

                plt.tight_layout()
                save_name_path = os.path.join(save_path, f'loss_accuracy.jpg')
                plt.savefig(save_name_path)

                save_name_path = os.path.join(save_all_path, f'{NAME_SAVE_update_PATH}_latest.jpg')
                plt.savefig(save_name_path)
                plt.close()


poison_percent 0 epoch 0: loss_train=1.9543, acc_train=0.3684, loss_test=2.2302, acc_test=0.4871
poison_percent 0 epoch 1: loss_train=1.9492, acc_train=0.2632, loss_test=1.8538, acc_test=0.0711
poison_percent 0 epoch 2: loss_train=1.9080, acc_train=0.1579, loss_test=1.7080, acc_test=0.1048
poison_percent 0 epoch 3: loss_train=1.6779, acc_train=0.3684, loss_test=1.6900, acc_test=0.0000
poison_percent 0 epoch 4: loss_train=1.8340, acc_train=0.2105, loss_test=1.6527, acc_test=0.5025
poison_percent 0 epoch 5: loss_train=1.6869, acc_train=0.2632, loss_test=1.6209, acc_test=0.4908
poison_percent 0 epoch 6: loss_train=1.6627, acc_train=0.2105, loss_test=1.5792, acc_test=0.9406
poison_percent 0 epoch 7: loss_train=1.6143, acc_train=0.3684, loss_test=1.5832, acc_test=0.4939
poison_percent 0 epoch 8: loss_train=1.6620, acc_train=0.3158, loss_test=1.5761, acc_test=0.5043
poison_percent 0 epoch 9: loss_train=1.6312, acc_train=0.1053, loss_test=1.5647, acc_test=0.4467
poison_percent 0 epoch 10: los