In [2]:
# train_wesad_alignment.py
import os
import random
import numpy as np
import pandas as pd
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# ===============================
# 1) CONFIGURATION
# ===============================
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

EMBEDDING_DIM = 64            # shared embedding size
MARGIN = 0.5                  # similarity threshold used for "match" decision
EPOCHS = 15                   # WESAD rows can be many; start modest
BATCH_SIZE = 16               # larger batch is fine with GPU; adjust if OOM
LEARNING_RATE = 1e-4
TEXT_MODEL_NAME = "distilbert-base-uncased"

# === Attached dataset path ===
DATA_FILE = "C:/Users/siu856558563/OneDrive - Southern Illinois University/Documents/Attack/SensorEncoderAttack/Data/WESAD.csv"

# Columns (the code will auto-detect if these are absent)
PREFERRED_TEXT_COLS = ["Semantic_Interpretation", "text", "Text", "semantic", "caption"]
EXCLUDE_COLS = {"Label", "label"}  # exclude from sensor features if present

# Model save paths
SENSOR_MODEL_PATH = "sensor_encoder_wesad.pth"
TEXT_MODEL_PATH = "text_encoder_wesad.pth"

# Optional tokenizer max length
MAX_LEN = 128

# ===============================
# 2) DATASET
# ===============================
class SensorTextDataset(Dataset):
    def __init__(self, sensor: np.ndarray, texts: List[str], labels: np.ndarray):
        self.sensor = torch.tensor(sensor, dtype=torch.float32)
        self.texts = list(texts)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # return texts raw; tokenization is done inside the text encoder
        return self.sensor[idx], self.texts[idx], self.labels[idx]

def collate_fn(batch):
    """
    Keep text as a list for tokenizer; stack sensors/labels into tensors.
    """
    sensors = torch.stack([b[0] for b in batch], dim=0)
    texts = [b[1] for b in batch]
    labels = torch.stack([b[2] for b in batch], dim=0)
    return sensors, texts, labels

# ===============================
# 3) MODELS
# ===============================
class SensorEncoder(nn.Module):
    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim),
        )

    def forward(self, x):
        return self.net(x)

class TextEncoder(nn.Module):
    def __init__(self, model_name: str, output_dim: int):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.bert = AutoModel.from_pretrained(model_name)
        self.proj = nn.Linear(self.bert.config.hidden_size, output_dim)

    def forward(self, texts: List[str]):
        # Tokenize on the fly; move to the same device as the BERT model
        enc = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt",
        )
        enc = {k: v.to(self.bert.device) for k, v in enc.items()}
        out = self.bert(**enc).last_hidden_state  # [B, T, H]
        pooled = out.mean(dim=1)                  # mean pool -> [B, H]
        return self.proj(pooled)

# ===============================
# 4) LOSS
# ===============================
class ContrastiveSimilarityLoss(nn.Module):
    """
    L = y * (1 - S) + (1 - y) * max(0, S - margin), where S = cosine_similarity
    """
    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin

    def forward(self, u, v, y):
        # u, v: [B, D]; y: [B] or [B, 1]
        sim = F.cosine_similarity(u, v)  # [B]
        if y.dim() > 1:
            y = y.squeeze(1)
        zero = torch.zeros_like(sim)
        loss_pos = y * (1.0 - sim)                 # push positives to high similarity
        loss_neg = (1.0 - y) * torch.clamp(sim - self.margin, min=0.0)
        return (loss_pos + loss_neg).mean()

# ===============================
# 5) UTILS
# ===============================
def find_text_column(df: pd.DataFrame) -> str:
    # Prefer known names
    for c in PREFERRED_TEXT_COLS:
        if c in df.columns and df[c].dtype == object:
            return c
    # Otherwise pick the first object dtype column with non-empty strings
    obj_cols = [c for c in df.columns if df[c].dtype == object]
    if len(obj_cols) == 0:
        # If nothing object-typed, try to coerce any column to string that looks textual
        raise ValueError(
            "No textual (object) column found for semantic descriptions. "
            f"Tried: {PREFERRED_TEXT_COLS}. Please add one to {DATA_FILE}."
        )
    return obj_cols[0]

def select_sensor_columns(df: pd.DataFrame, text_col: str) -> List[str]:
    # numeric columns except excluded and the text column
    num_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
    candidates = [c for c in num_cols if c not in EXCLUDE_COLS and c != text_col]
    if len(candidates) == 0:
        raise ValueError("No numeric sensor feature columns detected.")
    return candidates

def standardize_train_test(train: np.ndarray, test: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    # simple standardization (z-score) per feature using train stats
    mean = train.mean(axis=0, keepdims=True)
    std = train.std(axis=0, keepdims=True) + 1e-8
    return (train - mean) / std, (test - mean) / std

# ===============================
# 6) TRAIN
# ===============================
def train():
    print("=== Loading dataset ===")
    if not os.path.exists(DATA_FILE):
        raise FileNotFoundError(f"Data file not found: {DATA_FILE}")

    df = pd.read_csv(DATA_FILE)

    text_col = find_text_column(df)
    sensor_cols = select_sensor_columns(df, text_col)

    print(f"Using text column: {text_col}")
    print(f"Using {len(sensor_cols)} sensor columns: {sensor_cols[:6]}{' ...' if len(sensor_cols)>6 else ''}")

    # Pull raw arrays
    sensors = df[sensor_cols].to_numpy(dtype=np.float32)
    texts = df[text_col].astype(str).fillna("").tolist()

    # Build synthetic positives (aligned) and negatives (misaligned via one-step roll)
    pos_sensor, pos_text, pos_labels = sensors, texts, np.ones(len(df), dtype=np.float32)
    neg_sensor, neg_text, neg_labels = sensors, np.roll(texts, 1).tolist(), np.zeros(len(df), dtype=np.float32)

    all_sensor = np.concatenate([pos_sensor, neg_sensor], axis=0)
    all_text = pos_text + neg_text
    all_labels = np.concatenate([pos_labels, neg_labels], axis=0)

    # Train/test split
    Xs_tr, Xs_te, Xt_tr, Xt_te, y_tr, y_te = train_test_split(
        all_sensor, all_text, all_labels, test_size=0.2, random_state=SEED, shuffle=True
    )

    # Standardize numeric features
    Xs_tr, Xs_te = standardize_train_test(Xs_tr, Xs_te)

    # Build models
    sensor_encoder = SensorEncoder(input_dim=all_sensor.shape[1], output_dim=EMBEDDING_DIM).to(DEVICE)
    text_encoder = TextEncoder(TEXT_MODEL_NAME, EMBEDDING_DIM).to(DEVICE)
    criterion = ContrastiveSimilarityLoss(margin=MARGIN)
    optimizer = optim.Adam(list(sensor_encoder.parameters()) + list(text_encoder.parameters()),
                           lr=LEARNING_RATE)

    train_ds = SensorTextDataset(Xs_tr, Xt_tr, y_tr)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

    print(f"Training samples: {len(train_ds)}  |  Sensor dim: {all_sensor.shape[1]}  |  Device: {DEVICE}")

    sensor_encoder.train()
    text_encoder.train()
    for epoch in range(1, EPOCHS + 1):
        running = 0.0
        for sensor_batch, text_batch, label_batch in train_loader:
            sensor_batch = sensor_batch.to(DEVICE)
            label_batch = label_batch.to(DEVICE)

            optimizer.zero_grad()
            z_s = sensor_encoder(sensor_batch)
            z_t = text_encoder(text_batch)  # text encoder handles its own device

            loss = criterion(z_s, z_t, label_batch)
            loss.backward()
            optimizer.step()
            running += loss.item()

        avg = running / max(1, len(train_loader))
        if epoch == 1 or epoch % 2 == 0 or epoch == EPOCHS:
            print(f"Epoch {epoch:02d}/{EPOCHS} - loss: {avg:.4f}")

    # Save models
    torch.save(sensor_encoder.state_dict(), SENSOR_MODEL_PATH)
    torch.save(text_encoder.state_dict(), TEXT_MODEL_PATH)
    print(f"Saved: {SENSOR_MODEL_PATH}, {TEXT_MODEL_PATH}")

    return Xs_te, Xt_te, y_te, sensor_cols, text_col

# ===============================
# 7) EVALUATE
# ===============================
def evaluate(Xs_te, Xt_te, y_te, sensor_cols):
    print("\n=== Evaluating on held-out test set ===")
    sensor_encoder = SensorEncoder(input_dim=len(sensor_cols), output_dim=EMBEDDING_DIM).to(DEVICE)
    text_encoder = TextEncoder(TEXT_MODEL_NAME, EMBEDDING_DIM).to(DEVICE)

    sensor_encoder.load_state_dict(torch.load(SENSOR_MODEL_PATH, map_location=DEVICE))
    text_encoder.load_state_dict(torch.load(TEXT_MODEL_PATH, map_location=DEVICE))
    sensor_encoder.eval(); text_encoder.eval()

    test_ds = SensorTextDataset(Xs_te, Xt_te, y_te)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

    sims, labels = [], []
    with torch.no_grad():
        for sensor_batch, text_batch, label_batch in test_loader:
            sensor_batch = sensor_batch.to(DEVICE)
            z_s = sensor_encoder(sensor_batch)
            z_t = text_encoder(text_batch)
            sim = F.cosine_similarity(z_s, z_t).detach().cpu().numpy()
            sims.append(sim)
            labels.append(label_batch.numpy())

    sims = np.concatenate(sims)
    labels = np.concatenate(labels)

    pos = sims[labels == 1.0]
    neg = sims[labels == 0.0]
    print(f"Test size: {len(labels)}")
    print(f"Mean similarity (pos): {pos.mean():.4f} | (neg): {neg.mean():.4f}")

    preds = (sims > MARGIN).astype(np.float32)
    acc = accuracy_score(labels, preds)
    print(f"Matching accuracy (threshold={MARGIN}): {acc*100:.2f}%")

# ===============================
# 8) MAIN
# ===============================
if __name__ == "__main__":
    try:
        Xs_te, Xt_te, y_te, sensor_cols, text_col = train()
        evaluate(Xs_te, Xt_te, y_te, sensor_cols)
    except Exception as e:
        print("\n----------------------------------------------------------------------")
        print("ðŸ›‘ EXECUTION ERROR")
        print(f"{type(e).__name__}: {e}")
        print("Tip: Ensure /mnt/data/WESAD.csv has one textual column for semantics "
              f"(e.g., one of {PREFERRED_TEXT_COLS}) and the rest numeric sensor features.")
        print("----------------------------------------------------------------------")


  from .autonotebook import tqdm as notebook_tqdm


=== Loading dataset ===
Using text column: time
Using 14 sensor columns: ['chest_acc_x', 'chest_acc_y', 'chest_acc_z', 'chest_ecg_ch1', 'chest_emg_ch1', 'chest_eda_ch1'] ...
Training samples: 12800  |  Sensor dim: 14  |  Device: cpu
Epoch 01/15 - loss: 0.2519
Epoch 02/15 - loss: 0.2497
Epoch 04/15 - loss: 0.2488
Epoch 06/15 - loss: 0.2480
Epoch 08/15 - loss: 0.2474
Epoch 10/15 - loss: 0.2477
Epoch 12/15 - loss: 0.2471
Epoch 14/15 - loss: 0.2468
Epoch 15/15 - loss: 0.2468
Saved: sensor_encoder_wesad.pth, text_encoder_wesad.pth

=== Evaluating on held-out test set ===
Test size: 3200
Mean similarity (pos): 0.7748 | (neg): 0.8129
Matching accuracy (threshold=0.5): 48.72%
