In [None]:
import copy
import math
import warnings

warnings.filterwarnings("ignore")

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import DataLoader, TensorDataset

In [None]:
dataSize = 305  # Only LiDAR features
npy_features = np.load('dataset.npy')

# Temporal split for the dataset: keep the first `dataSize` columns and cut the
# rows, in file order, into train | validation | test at fixed row indices.
lidar_features = npy_features[:, :dataSize]
train_raw, val_raw, test_raw = np.split(lidar_features, [50335, 60341], axis=0)
print(train_raw.shape, val_raw.shape, test_raw.shape)

In [None]:
# Window layout used when the flat rows are later reshaped to
# (-1, TIMESTEPS, FEATURES): each row of `dataSize` values is read as
# TIMESTEPS steps of FEATURES values.
TIMESTEPS = 5
FEATURES = 61

# The reshape is only valid when the row width factors exactly; fail early
# with a readable message instead of a NumPy reshape error further down.
# (The X_train / X_val / X_test reshapes that used to live here were dead
# stores: the scaling cell rebuilds all three from the *_raw arrays.)
assert TIMESTEPS * FEATURES == dataSize, "TIMESTEPS * FEATURES must equal dataSize"

In [None]:
# Scale every column to [0, 1] with a MinMaxScaler whose statistics come
# from the training split only; validation and test reuse those statistics.
scaler = MinMaxScaler().fit(train_raw.reshape(-1, dataSize))
X_train, X_test, X_val = [
    scaler.transform(split.reshape(-1, dataSize)).reshape(split.shape)
    for split in (train_raw, test_raw, val_raw)
]

In [None]:
# View each scaled split as windows of shape (TIMESTEPS, FEATURES).
window_shape = (-1, TIMESTEPS, FEATURES)
X_train, X_val, X_test = [
    split.reshape(window_shape) for split in (X_train, X_val, X_test)
]

print(X_train.shape, X_val.shape, X_test.shape)

## TFMe

#### Config

In [None]:
# Configuration
# Transformer architecture (used as default arguments by the branch classes).
N_LAYERS, D_MODEL, N_HEAD, FFN = 4, 128, 8, 512
# Length of the positional-encoding table built by TimeBranch.
MAX_LENGTH = 10
# Number of memory slots per head in each MultiHeadMemory.
MEMORY_SLOTS = 250

# Batch size for the DataLoader; EPOCHS and LEARNING_RATE are training
# settings that the visible cells do not reference.
BATCH_SIZE, EPOCHS = 256, 500
LEARNING_RATE = 1e-4

# Prefer the GPU when one is available.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = torch.device(_device_name)

# Reconstruction loss
def criterion(recon, target):
    """Return the mean squared error between `recon` and `target`."""
    return nn.functional.mse_loss(recon, target)

# Early stopping implementation
class EarlyStopping:
    """Signal when training should stop because validation loss stalled.

    Call the instance once per epoch with the model and that epoch's
    validation loss. It keeps a deep copy of the model's ``state_dict`` for
    the best loss seen so far and returns True once ``patience`` consecutive
    calls brought no improvement.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        min_delta: The loss must drop by at least this much below the best
            loss to count as an improvement.
        restore_best_weights: When True, the best snapshot is loaded back
            into the model at the moment stopping is triggered.

    Attributes:
        best_loss: Best validation loss seen so far (None before the first
            valid loss).
        best_model: Deep copy of the ``state_dict`` taken at ``best_loss``.
        counter: Consecutive non-improving calls since the last improvement.
        status: Human-readable description of the last decision.
    """

    def __init__(self, patience=10, min_delta=1e-4, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_model = None
        self.best_loss = None
        self.counter = 0
        self.status = ""

    def __call__(self, model, val_loss):
        """Record one epoch's validation loss.

        Args:
            model: Module providing ``state_dict()`` / ``load_state_dict()``.
            val_loss: Validation loss for the epoch (float or scalar tensor).

        Returns:
            True if training should stop, False otherwise.
        """
        # A NaN loss compares False against everything, which previously made
        # it look like an improvement and snapshotted diverged weights.
        # Treat it as a non-improving epoch instead.
        is_nan = math.isnan(float(val_loss))

        if not is_nan:
            if self.best_loss is None:
                # First valid loss: just establish the baseline.
                self.best_loss = val_loss
                self.best_model = copy.deepcopy(model.state_dict())
                return False
            if val_loss <= self.best_loss - self.min_delta:
                self.best_loss = val_loss
                self.best_model = copy.deepcopy(model.state_dict())
                self.counter = 0
                self.status = "Improvement found, counter reset."
                return False

        self.counter += 1
        self.status = f"No improvement in the last {self.counter} epochs."
        if self.counter >= self.patience:
            self.status = f"Early stopping triggered after {self.counter} epochs."
            # best_model is still None if every loss so far was NaN.
            if self.restore_best_weights and self.best_model is not None:
                model.load_state_dict(self.best_model)
            return True
        return False
    
# Positional Encoding for the Transformer
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first tensor.

    The table is stored in the buffer ``pe`` with shape
    ``(max_len, 1, d_model)``: even feature indices hold sines, odd indices
    hold cosines.

    Args:
        d_model: Size of the feature dimension (even or odd).
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd index than even index, so
        # drop the last frequency for the cosine half.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions.

        Args:
            x: Tensor whose dimension 0 is the sequence position and whose
                last dimension is ``d_model``; ``pe`` broadcasts over dim 1.

        Raises:
            ValueError: If the sequence is longer than the precomputed table.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"Sequence length {seq_len} exceeds positional encoding "
                f"max_len {self.pe.size(0)}."
            )
        return x + self.pe[:seq_len]

#### Memory Module

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadMemory(nn.Module):
    """Multi-head memory addressing with hard shrinkage.

    The feature vector at every (batch, timestep) position is split into
    ``n_heads`` chunks of ``head_dim`` values. Each chunk is compared, by
    cosine similarity, with the ``mem_dim`` learnable slots of its own head
    and replaced by the attention-weighted sum of those slots.

    Args:
        mem_dim: Number of memory slots per head.
        fea_dim: Input feature size; must be divisible by ``n_heads``.
        n_heads: Number of independent memory heads.
        shrink_thres: Attention weights at or below this value are zeroed
            (hard shrinkage); 0 disables shrinkage.
    """

    def __init__(self, mem_dim, fea_dim, n_heads=4, shrink_thres=0.0025):
        super().__init__()
        assert fea_dim % n_heads == 0, "Feature dimension must be divisible by n_heads"

        self.n_heads = n_heads
        self.mem_dim = mem_dim
        self.fea_dim = fea_dim
        self.head_dim = fea_dim // n_heads

        # Threshold lambda is typically [1/N, 3/N]
        self.shrink_thres = shrink_thres

        # One bank of `mem_dim` slots per head.
        self.memory = nn.Parameter(torch.empty(n_heads, mem_dim, self.head_dim))
        nn.init.kaiming_uniform_(self.memory)

    def forward(self, x):
        """Replace every feature vector by its memory reconstruction.

        Args:
            x: Tensor of shape [batch_size, timesteps, fea_dim].

        Returns:
            Tuple ``(out, att_weight)``: ``out`` has the shape of ``x``;
            ``att_weight`` has shape [batch_size * timesteps, n_heads, mem_dim]
            with row ``b * timesteps + t`` belonging to position ``(b, t)``.
        """
        b, t, d = x.shape

        # 1. Split features into heads, one row per (batch, timestep) position.
        #    The previous transpose(1, 2) before flattening interleaved heads
        #    and timesteps, so a chunk was matched against the wrong head's
        #    memory and written back to the wrong position.
        queries = x.reshape(b * t, self.n_heads, self.head_dim)

        # 2. Normalize (cosine similarity)
        q_norm = F.normalize(queries, dim=2)
        mem_norm = F.normalize(self.memory, dim=2)

        # 3. Similarity of every chunk with every slot of its own head
        att_weight = torch.einsum('bhd,hmd->bhm', q_norm, mem_norm)

        # 4. Softmax over the memory slots
        att_weight = F.softmax(att_weight, dim=2)

        # 5. Hard shrinkage: relu(w - lambda) * w / (|w - lambda| + eps) keeps
        #    w where w > lambda and yields 0 elsewhere.
        if self.shrink_thres > 0:
            w_minus_lambda = att_weight - self.shrink_thres
            mask = F.relu(w_minus_lambda)
            att_weight = (mask * att_weight) / (torch.abs(w_minus_lambda) + 1e-12)

            # Re-normalize so the surviving weights sum to 1
            att_weight = att_weight / (att_weight.sum(dim=2, keepdim=True) + 1e-12)

        # 6. Reconstruct each chunk from its head's memory slots
        out = torch.einsum('bhm,hmd->bhd', att_weight, self.memory)

        # 7. Concatenate the heads back into [batch, timesteps, fea_dim]
        out = out.reshape(b, t, d)

        return out, att_weight

#### Time Domain

In [None]:
class TimeBranch(nn.Module):
    """Time-domain reconstruction branch.

    Pipeline: linear embedding -> sinusoidal positional encoding ->
    Transformer encoder -> MultiHeadMemory -> MLP projecting back to
    ``input_features``.

    Args:
        input_features: Number of features per timestep.
        timesteps: Window length (stored; not otherwise used by this class).
        d_model: Transformer embedding size.
        nhead: Number of attention heads, also used as the number of memory heads.
        dim_feedforward: Hidden size of the encoder feed-forward sublayer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout rate inside the encoder layers.
        mem_dim: Number of memory slots per head; the shrinkage threshold is
            set to ``1 / mem_dim``.
        max_length: Number of positions in the positional-encoding table.

    NOTE: earlier revisions fed a (timesteps, batch, d_model) tensor to an
    encoder built with ``batch_first=True``, so self-attention ran across the
    samples of a batch instead of across timesteps. Weights trained with that
    forward pass should be retrained, and any reconstruction-error thresholds
    derived from them recomputed.
    """

    def __init__(self, 
                 input_features=FEATURES, 
                 timesteps=TIMESTEPS,
                 d_model=D_MODEL,
                 nhead=N_HEAD,
                 dim_feedforward=FFN,
                 num_layers=N_LAYERS,
                 dropout=0.1,
                 mem_dim=MEMORY_SLOTS,
                 max_length=MAX_LENGTH):
        super().__init__()
        self.timesteps = timesteps
        self.input_features = input_features
        self.d_model = d_model
        self.input_embedding = nn.Linear(input_features, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len=max_length)
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, 
            nhead=nhead, 
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        
        # Memory module applied to the encoder output
        self.memory_layer = MultiHeadMemory(mem_dim=mem_dim, fea_dim=d_model, n_heads=nhead, shrink_thres=1/mem_dim)
                
        self.output_layer = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.LeakyReLU(0.2),
            nn.Linear(d_model * 2, input_features)
        )
        
    def forward(self, x):
        """Reconstruct ``x`` through the time-domain encoder and memory.

        Args:
            x: Tensor of shape (batch, timesteps, input_features).

        Returns:
            Tuple ``(recon, att)``: ``recon`` has the shape of ``x``; ``att``
            holds the memory attention weights returned by ``memory_layer``.
        """
        # 1. Input embedding, scaled by sqrt(d_model)
        x = self.input_embedding(x)
        x = x * np.sqrt(self.d_model)

        # 2. Positional encoding. PositionalEncoding indexes positions along
        #    dim 0, so go sequence-first for the addition and then return to
        #    batch-first, which is what the encoder (batch_first=True) expects.
        x = self.positional_encoding(x.permute(1, 0, 2)).permute(1, 0, 2)

        # 3. Transformer encoder: attention across the timesteps of each sample
        enc_out = self.transformer_encoder(x)

        # 4. Memory module
        # 'z_hat' is the latent sequence rebuilt from memory slots;
        # 'att' contains the addressing weights.
        z_hat, att = self.memory_layer(enc_out)

        # 5. Final output reconstruction
        recon = self.output_layer(z_hat)

        return recon, att

#### Frequency Domain

In [None]:
class SpectralBranch(nn.Module):
    """Frequency-domain reconstruction branch.

    The input is transformed with a real FFT along dim 1, the real and
    imaginary parts are concatenated on the feature axis, passed through
    embedding -> positional encoding -> Transformer encoder ->
    MultiHeadMemory -> MLP, and the predicted spectrum is mapped back to the
    time domain with an inverse real FFT.
    """

    def __init__(self,
                 input_features=FEATURES,
                 d_model=D_MODEL,
                 timesteps=TIMESTEPS,
                 nhead=N_HEAD,
                 num_layers=N_LAYERS,
                 mem_dim=MEMORY_SLOTS,
                 ffn=FFN):
        super().__init__()
        # Number of bins a real FFT produces for `timesteps` samples.
        self.freq_bins = timesteps // 2 + 1
        # Real + Imag parts side by side.
        self.freq_input_dim = 2 * input_features

        self.embedding = nn.Linear(self.freq_input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=self.freq_bins + 10)

        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=ffn,
                dropout=0.1,
                batch_first=True,
            ),
            num_layers=num_layers,
        )

        # Memory module applied to the encoded spectrum
        self.freq_memory = MultiHeadMemory(
            mem_dim=mem_dim,
            fea_dim=d_model,
            n_heads=nhead,
            shrink_thres=1/mem_dim,
        )

        hidden_dim = d_model * 2
        self.output_layer = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, self.freq_input_dim),
        )

    def forward(self, x):
        """Reconstruct ``x`` via its spectrum.

        Args:
            x: Real tensor whose dim 1 is the time axis and whose last
                dimension has ``input_features`` entries.

        Returns:
            Tuple ``(recon_time, att)``: the time-domain reconstruction with
            the same length along dim 1 as ``x``, and the memory attention
            weights returned by ``freq_memory``.
        """
        n_steps = x.size(1)

        # 1. FFT along time; stack real and imaginary parts as features
        spectrum = torch.fft.rfft(x, dim=1)
        spectral_in = torch.cat((spectrum.real, spectrum.imag), dim=-1)

        # 2. Embed, then add positions (the encoding is sequence-first, so
        #    permute around it and come back to batch-first)
        embedded = self.embedding(spectral_in)
        embedded = self.pos_encoder(embedded.permute(1, 0, 2)).permute(1, 0, 2)

        # 3. Transformer encoder
        encoded = self.transformer_encoder(embedded)

        # 4. Memory module
        mem_out, att = self.freq_memory(encoded)

        # 5. Predicted spectrum (real half, then imaginary half)
        recon_spectral = self.output_layer(mem_out)
        real_part, imag_part = recon_spectral.chunk(2, dim=-1)

        # 6. Inverse FFT back to the original number of timesteps
        recon_time = torch.fft.irfft(torch.complex(real_part, imag_part), n=n_steps, dim=1)

        return recon_time, att

#### Dual Domain Transformer

In [None]:
class DualDomainTransformer(nn.Module):
    """Pair of reconstruction branches run on the same input.

    ``time_branch`` reconstructs the input in the time domain and
    ``spectral_branch`` reconstructs it through the frequency domain; both
    receive the same hyper-parameters.
    """

    def __init__(self, 
                 input_features=FEATURES, 
                 timesteps=TIMESTEPS,
                 d_model=D_MODEL,
                 nhead=N_HEAD,
                 dim_feedforward=FFN,
                 num_layers=N_LAYERS,
                 mem_dim=MEMORY_SLOTS):
        super().__init__()

        # Hyper-parameters both branches accept under the same keyword.
        shared = dict(
            input_features=input_features,
            timesteps=timesteps,
            d_model=d_model,
            nhead=nhead,
            num_layers=num_layers,
            mem_dim=mem_dim,
        )

        # Branch 1: Time Domain
        self.time_branch = TimeBranch(dim_feedforward=dim_feedforward, **shared)

        # Branch 2: Frequency Domain (its feed-forward size keyword is `ffn`)
        self.spectral_branch = SpectralBranch(ffn=dim_feedforward, **shared)

    def forward(self, x):
        """Run both branches on ``x``.

        Returns:
            ``(time_recon, freq_recon, time_att, freq_att)``: the two
            reconstructions followed by the two memory attention tensors.
        """
        time_out = self.time_branch(x)
        freq_out = self.spectral_branch(x)

        time_recon, time_att = time_out
        freq_recon, freq_att = freq_out
        return time_recon, freq_recon, time_att, freq_att

### Loading the model

In [None]:
# Build the architecture with the notebook's default hyper-parameters and
# load the trained weights into it.
model = DualDomainTransformer()
# weights_only=True restricts unpickling to tensors and plain containers, so
# a tampered checkpoint file cannot execute arbitrary code on load.
state_dict = torch.load("best.pth", map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.to(DEVICE)
# Inference mode: disables dropout inside the encoder layers.
model.eval()

### Detect anomalies based on the 3-sigma threshold

In [None]:
def detect_anomalies(model, data_loader,
                     time_threshold=0.007686922559514642,
                     freq_threshold=0.005322358105331659,
                     device=None):
    """Count the samples flagged as anomalous by either branch.

    A sample is flagged when its mean squared reconstruction error, averaged
    over dims 1 and 2, exceeds the threshold of the time branch or of the
    frequency branch.

    Args:
        model: Callable returning ``(time_recon, freq_recon, _, _)`` for a
            batch of inputs.
        data_loader: Iterable of batches whose first element is the input
            tensor.
        time_threshold: Error threshold for the time-domain reconstruction.
        freq_threshold: Error threshold for the frequency-domain
            reconstruction. The defaults are the values previously hard-coded
            here (presumably the 3-sigma thresholds named in the section
            title; confirm how they were derived).
        device: Device the inputs are moved to; defaults to the notebook-level
            ``DEVICE``.

    Returns:
        Number of flagged samples across the whole loader.
    """
    if device is None:
        device = DEVICE

    anomalies_found = 0
    with torch.no_grad():
        for batch in data_loader:
            inputs = batch[0].to(device)
            time_recon, freq_recon, _, _ = model(inputs)

            # Per-sample mean squared error of each reconstruction
            loss_time = torch.mean(torch.square(inputs - time_recon), dim=(1, 2))
            loss_freq = torch.mean(torch.square(inputs - freq_recon), dim=(1, 2))

            is_anomaly = torch.logical_or(loss_time > time_threshold,
                                          loss_freq > freq_threshold)
            anomalies_found += is_anomaly.sum().item()

    return anomalies_found

In [None]:
# Wrap the scaled test windows in an ordered (non-shuffled) loader.
test_tensor = torch.tensor(X_test, dtype=torch.float32)
test_dataset = TensorDataset(test_tensor)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# NOTE(review): every flagged window is reported as a false positive, which
# treats the whole test split as normal data - confirm.
total_samples = len(X_test)
false_positive = detect_anomalies(model, test_loader)
true_negative = total_samples - false_positive
print(f"False Positive: {false_positive}, True Negative: {true_negative}")
print("All Samples:", total_samples)