In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from sklearn.cluster import KMeans
import math

# --- 1. Positional Encoding ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

# --- 2. MEMTO Model ---
class MEMTO(nn.Module):
    def __init__(self, input_dim, latent_dim, n_heads, n_layers, num_memory_items, dropout=0.1):
        super(MEMTO, self).__init__()
        self.latent_dim = latent_dim
        self.num_memory_items = num_memory_items

        self.pos_encoder = PositionalEncoding(latent_dim, dropout)
        encoder_layers = nn.TransformerEncoderLayer(d_model=latent_dim, nhead=n_heads, dropout=dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=n_layers)

        self.encoder_embedding = nn.Linear(input_dim, latent_dim)
        
        # Memory Module
        self.memory = nn.Parameter(torch.randn(num_memory_items, latent_dim), requires_grad=True)
        
        # Gated Memory Update Gate layers
        self.U_psi = nn.Linear(latent_dim, latent_dim)
        self.W_psi = nn.Linear(latent_dim, latent_dim)
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim * 2, latent_dim),
            nn.ReLU(),
            nn.Linear(latent_dim, input_dim)
        )

    def forward(self, src, temperature=0.1):
        # src: [batch_size, seq_len, input_dim]
        
        # --- Encoder ---
        q = self.encoder_embedding(src) # queries: [batch_size, seq_len, latent_dim]
        
        # --- Query Update Stage ---
        # Memory-conditioned query-attention (w_t,i in paper)
        attention_weights_w = F.softmax(torch.matmul(q, self.memory.T) / temperature, dim=-1) # [batch, seq_len, mem_items]
        
        # Retrieved memory item (q_tilde_t in paper)
        retrieved_memory = torch.matmul(attention_weights_w, self.memory) # [batch, seq_len, latent_dim]

        # --- Gated Memory Update Stage (only during training) ---
        if self.training:
            # Query-conditioned memory-attention (v_i,t in paper)
            # Transpose q to [batch, latent_dim, seq_len] for matmul
            attention_weights_v = F.softmax(torch.matmul(self.memory, q.transpose(1, 2)) / temperature, dim=-1) # [batch, mem_items, seq_len]
            
            # Weighted sum of queries
            weighted_queries = torch.matmul(attention_weights_v, q) # [batch, mem_items, latent_dim]

            # Update gate (psi in paper)
            psi = torch.sigmoid(self.U_psi(self.memory) + self.W_psi(weighted_queries.mean(dim=0))) # Average over batch
            
            # Update memory
            updated_memory = (1 - psi) * self.memory + psi * weighted_queries.mean(dim=0)
            self.memory.data = updated_memory

        # --- Decoder ---
        # Concatenate query and retrieved memory
        updated_queries = torch.cat([q, retrieved_memory], dim=-1) # [batch, seq_len, latent_dim * 2]
        
        reconstructed_x = self.decoder(updated_queries)
        
        return reconstructed_x, attention_weights_w

# --- 3. Agent to handle training and prediction ---
class MEMTOAgent:
    def __init__(self, input_dim, seq_len, latent_dim=64, n_heads=4, n_layers=2, num_memory_items=10, 
                 lr=1e-4, lambda_entropy=0.01, phase1_epochs=5, temperature=0.1):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self.model = MEMTO(
            input_dim=input_dim, 
            latent_dim=latent_dim, 
            n_heads=n_heads, 
            n_layers=n_layers, 
            num_memory_items=num_memory_items
        ).to(self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.lambda_entropy = lambda_entropy
        self.phase1_epochs = phase1_epochs
        self.num_memory_items = num_memory_items
        self.temperature = temperature

    def _calculate_loss(self, x, x_rec, attn_weights):
        # Reconstruction Loss (L_rec)
        rec_loss = F.mse_loss(x, x_rec)
        
        # Entropy Loss (L_entr)
        entr_loss = -torch.sum(attn_weights * torch.log(attn_weights + 1e-12), dim=-1).mean()

        total_loss = rec_loss + self.lambda_entropy * entr_loss
        return total_loss, rec_loss, entr_loss

    def train(self, train_loader, total_epochs=50):
        print("--- Starting Phase 1: Encoder Pre-training ---")
        self.model.train()
        for epoch in range(self.phase1_epochs):
            total_loss = 0
            for i, (x_batch,) in enumerate(train_loader):
                x_batch = x_batch.to(self.device)
                self.optimizer.zero_grad()
                x_rec, attn_w = self.model(x_batch, self.temperature)
                loss, _, _ = self._calculate_loss(x_batch, x_rec, attn_w)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
            print(f"Phase 1, Epoch {epoch+1}/{self.phase1_epochs}, Loss: {total_loss/len(train_loader):.6f}")

        print("\n--- Initializing Memory with K-means ---")
        self.model.eval()
        all_queries = []
        with torch.no_grad():
            for (x_batch,) in train_loader:
                x_batch = x_batch.to(self.device)
                queries = self.model.encoder_embedding(x_batch)
                all_queries.append(queries.cpu().numpy())
        
        all_queries = np.concatenate(all_queries, axis=0)
        # Reshape for K-means: [num_samples, features]
        num_samples, seq_len, latent_dim = all_queries.shape
        all_queries_reshaped = all_queries.reshape(-1, latent_dim)

        kmeans = KMeans(n_clusters=self.num_memory_items, random_state=0, n_init=10)
        kmeans.fit(all_queries_reshaped)
        
        # Set memory items to cluster centers
        centroids = torch.from_numpy(kmeans.cluster_centers_).float().to(self.device)
        self.model.memory.data = centroids
        print("Memory items initialized.")

        print("\n--- Starting Phase 2: Full Model Training ---")
        self.model.train()
        for epoch in range(total_epochs - self.phase1_epochs):
            total_loss, total_rec, total_ent = 0, 0, 0
            for i, (x_batch,) in enumerate(train_loader):
                x_batch = x_batch.to(self.device)
                self.optimizer.zero_grad()
                x_rec, attn_w = self.model(x_batch, self.temperature)
                loss, rec_loss, ent_loss = self._calculate_loss(x_batch, x_rec, attn_w)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
                total_rec += rec_loss.item()
                total_ent += ent_loss.item()
            print(f"Phase 2, Epoch {epoch+1}/{total_epochs - self.phase1_epochs}, Loss: {total_loss/len(train_loader):.6f} (Rec: {total_rec/len(train_loader):.6f}, Entr: {total_ent/len(train_loader):.6f})")

    def predict(self, test_loader):
        self.model.eval()
        anomaly_scores = []
        with torch.no_grad():
            for (x_batch,) in test_loader:
                x_batch = x_batch.to(self.device)
                
                # Get reconstruction
                x_rec, _ = self.model(x_batch, self.temperature)
                
                # Input Space Deviation (ISD)
                isd = torch.pow(x_batch - x_rec, 2).sum(dim=-1) # [batch, seq_len]
                
                # Latent Space Deviation (LSD)
                queries = self.model.encoder_embedding(x_batch) # [batch, seq_len, latent_dim]
                
                # Calculate distance from each query to each memory item
                # dist shape: [batch, seq_len, num_memory_items]
                dist = torch.cdist(queries, self.model.memory.unsqueeze(0).repeat(queries.shape[0],1,1))
                lsd = torch.min(dist, dim=-1).values # [batch, seq_len]

                # Anomaly Score (Eq. 11)
                score = F.softmax(lsd, dim=-1) * isd
                anomaly_scores.append(score.cpu().numpy())
        
        return np.concatenate(anomaly_scores, axis=0)

# Example of how to use the agent
if __name__ == '__main__':
    # --- Dummy Parameters and Data for Demonstration ---
    BATCH_SIZE = 32
    SEQ_LEN = 100
    INPUT_DIM = 38
    
    # Create dummy data
    train_data = torch.randn(BATCH_SIZE * 10, SEQ_LEN, INPUT_DIM)
    test_data = torch.randn(BATCH_SIZE * 2, SEQ_LEN, INPUT_DIM)

    train_dataset = torch.utils.data.TensorDataset(train_data)
    test_dataset = torch.utils.data.TensorDataset(test_data)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # --- Create and run the agent ---
    agent = MEMTOAgent(
        input_dim=INPUT_DIM,
        seq_len=SEQ_LEN,
        latent_dim=64,
        n_heads=4,
        n_layers=2,
        num_memory_items=10,
        lr=1e-4,
        lambda_entropy=0.01,
        phase1_epochs=2, # Keep low for demo
        temperature=0.1
    )

    # Train the model
    agent.train(train_loader, total_epochs=5) # Keep low for demo
    
    # Get anomaly scores
    scores = agent.predict(test_loader)
    
    print(f"\nTraining and prediction finished.")
    print(f"Shape of anomaly scores: {scores.shape}")
    print(f"Expected shape: ({len(test_data)}, {SEQ_LEN})")

Using device: cuda
--- Starting Phase 1: Encoder Pre-training ---
Phase 1, Epoch 1/2, Loss: 1.024218
Phase 1, Epoch 2/2, Loss: 1.012417

--- Initializing Memory with K-means ---
Memory items initialized.

--- Starting Phase 2: Full Model Training ---
Phase 2, Epoch 1/3, Loss: 1.003300 (Rec: 1.002362, Entr: 0.093743)
Phase 2, Epoch 2/3, Loss: 0.995504 (Rec: 0.994267, Entr: 0.123698)
Phase 2, Epoch 3/3, Loss: 0.987547 (Rec: 0.985459, Entr: 0.208814)

Training and prediction finished.
Shape of anomaly scores: (64, 100)
Expected shape: (64, 100)
