In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import scipy.io.arff as arff
import matplotlib.pyplot as plt

# ---------------------------
# Dataset Definition (ECG5000)
# ---------------------------
class ECG5000(Dataset):
    """ECG5000 heartbeats read from the ARFF files as a PyTorch ``Dataset``.

    Every item is a float32 tensor of shape (sequence_length, 1). The last ARFF
    column is the class label; label ``b'1'`` is treated as normal and every
    other label as an anomaly.
    """

    def __init__(self, mode, split='train', data_dir='/kaggle/input/ecg50000', *, verbose=False):
        """
        mode: 'normal', 'anomaly', or 'all'.
              'all' means do not filter any samples (both normal and anomaly).
        split: 'train' to load ECG5000_TRAIN.arff; 'test' to load ECG5000_TEST.arff.
        data_dir: directory holding the two ARFF files (defaults to the Kaggle
                  input path that used to be hard-coded here).
        verbose: if True, print the first rows of the raw table. This used to be
                 unconditional; it is off by default so that building several
                 datasets does not flood the output.

        Raises:
            ValueError: if ``mode`` or ``split`` is not one of the allowed values.
        """
        import os  # local import keeps this notebook cell self-contained

        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if mode not in ('normal', 'anomaly', 'all'):
            raise ValueError(f"mode must be 'normal', 'anomaly' or 'all', got {mode!r}")
        if split not in ('train', 'test'):
            raise ValueError(f"split must be 'train' or 'test', got {split!r}")

        file_path = os.path.join(data_dir, f'ECG5000_{split.upper()}.arff')

        data, meta = arff.loadarff(file_path)
        df = pd.DataFrame(data, columns=meta.names())
        if verbose:
            print(df.head())

        # Rename the label column (always the last one) so it can be used by name.
        new_columns = list(df.columns)
        new_columns[-1] = 'target'
        df.columns = new_columns

        # Filter samples based on mode. Labels are compared as bytes because
        # loadarff returns nominal attributes as byte strings.
        if mode == 'normal':
            df = df[df.target == b'1']
        elif mode == 'anomaly':
            df = df[df.target != b'1']
        df = df.drop(labels='target', axis=1)

        # (num_samples, sequence_length) float32 matrix backing __getitem__.
        self.X = df.astype(np.float32).to_numpy()

    def __getitem__(self, index):
        """Return sample ``index`` as a (sequence_length, 1) float32 tensor.

        The tensor shares memory with ``self.X`` (``torch.from_numpy``).
        """
        return torch.from_numpy(self.X[index]).unsqueeze(-1)

    def __len__(self):
        """Number of samples left after mode filtering."""
        return self.X.shape[0]

    def get_torch_tensor(self):
        """Return all samples as a (num_samples, sequence_length) tensor sharing memory with ``self.X``."""
        return torch.from_numpy(self.X)

class MemoryModule(nn.Module):
    """Learnable memory bank addressed by a thresholded softmax.

    The latent vector is scored against every memory item by dot product. The
    softmax of those scores is shrunk by ``sparsity_threshold``, clipped at
    zero and renormalised per row, and the output is the correspondingly
    weighted sum of memory items.
    """

    def __init__(self, memory_size, hidden_size, sparsity_threshold=0.05):
        """
        memory_size: number of memory items (rows of the memory matrix).
        hidden_size: dimensionality of each memory item.
        sparsity_threshold: amount subtracted from every addressing weight
            before negative values are clipped to zero.
        """
        super(MemoryModule, self).__init__()
        self.memory_size, self.hidden_size = memory_size, hidden_size
        self.sparsity_threshold = sparsity_threshold
        # Memory items start as standard-normal noise and are learned.
        self.memory = nn.Parameter(torch.randn(memory_size, hidden_size))

    def forward(self, z):
        """
        z: latent vectors, shape (batch, hidden_size).
        Returns (z_hat, q):
          z_hat: (batch, hidden_size) weighted sum of memory items.
          q: (batch, memory_size) addressing weights; each row sums to ~1, or
             is all zeros if every weight fell below the threshold.
        """
        bank = self.memory
        weights = torch.softmax(z @ bank.t(), dim=1)
        # Shrink by the threshold and clip negatives to zero ...
        weights = torch.maximum(weights - self.sparsity_threshold, torch.zeros_like(weights))
        # ... then renormalise each row; the epsilon guards an all-zero row.
        weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-8)
        return weights @ bank, weights

# ---------------------------
# TSMAE Model Definition
# ---------------------------
class TSMAE(nn.Module):
    """LSTM autoencoder with a memory bottleneck.

    Pipeline: LSTM encoder -> last-layer final hidden state ``z`` ->
    ``MemoryModule`` -> ``z_hat``, which is tiled over every time step and fed
    to an LSTM decoder whose outputs are linearly projected back to the input
    dimension.
    """

    def __init__(self, input_size, hidden_size, memory_size, sparsity_threshold=0.05, sparsity_factor=0.001):
        """
        input_size: features per time step (e.g. 1).
        hidden_size: size of the latent vector and of both LSTM hidden states.
        memory_size: number of memory items.
        sparsity_threshold: shrinkage threshold forwarded to the memory module.
        sparsity_factor: weight of the sparsity penalty in ``loss_function``.
        """
        super(TSMAE, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.memory_size = memory_size
        self.sparsity_factor = sparsity_factor

        # Keep this construction order: it fixes the state_dict key order and
        # the sequence of random initialisations.
        self.encoder = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.memory_module = MemoryModule(memory_size, hidden_size, sparsity_threshold)
        self.decoder = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.output_layer = nn.Linear(hidden_size, input_size)

    def forward(self, x):
        """
        x: (batch, seq_len, input_size) input sequences.
        Returns (x_recon, q, z, z_hat):
          x_recon: (batch, seq_len, input_size) reconstruction.
          q: (batch, memory_size) addressing weights from the memory module.
          z: (batch, hidden_size) encoder latent vector.
          z_hat: (batch, hidden_size) memory-recombined latent vector.
        """
        _, seq_len, _ = x.size()
        _, (hidden, _) = self.encoder(x)
        z = hidden[-1]
        z_hat, q = self.memory_module(z)
        # The decoder sees the same recombined latent at every time step.
        decoder_in = z_hat.unsqueeze(1).repeat(1, seq_len, 1)
        decoded, _ = self.decoder(decoder_in)
        return self.output_layer(decoded), q, z, z_hat

    def loss_function(self, x, x_recon, q):
        """Return ``(total, reconstruction_mse, sparsity_penalty)``.

        total = mean((x - x_recon)**2) + sparsity_factor * mean(log(1 + q**2))
        """
        rec_loss = (x - x_recon).pow(2).mean()
        sparsity_loss = torch.log(1 + q ** 2).mean()
        return rec_loss + self.sparsity_factor * sparsity_loss, rec_loss, sparsity_loss

# ---------------------------
# Training Setup
# ---------------------------
def train_model(model, dataloader, optimizer, device, num_epochs=50):
    """Train ``model`` in place and return the mean training loss per epoch.

    model: module whose forward returns (x_recon, q, z, z_hat) and which exposes
        loss_function(x, x_recon, q) -> (loss, rec_loss, sparsity_loss).
    dataloader: iterable of input batches (tensors with batch as dim 0).
    optimizer: optimizer over the model's parameters.
    device: device the model and every batch are moved to.
    num_epochs: number of passes over ``dataloader``.

    Returns a list with one float per epoch: the sample-weighted mean of the
    total loss over the samples actually processed in that epoch (NaN if the
    loader yielded no batches).
    """
    model.to(device)
    model.train()
    train_losses = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        # Count processed samples ourselves: with drop_last=True the dataset
        # length includes samples that were never seen, which would bias the
        # average downwards.
        n_seen = 0
        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            batch = batch.to(device)
            optimizer.zero_grad()
            x_recon, q, _, _ = model(batch)
            loss, _, _ = model.loss_function(batch, x_recon, q)
            loss.backward()
            optimizer.step()
            # loss is a per-batch mean; weight it by batch size for an exact epoch mean.
            epoch_loss += loss.item() * batch.size(0)
            n_seen += batch.size(0)
        avg_loss = epoch_loss / n_seen if n_seen else float('nan')
        train_losses.append(avg_loss)
        print(f"Epoch {epoch+1} Loss: {avg_loss:.6f}")
    return train_losses

# ---------------------------
# Main Execution
# ---------------------------
if __name__ == '__main__':
    # Script entry point: train TSMAE on the ECG5000 TRAIN split, save the
    # weights, report the mean TEST loss and plot one normal and one anomalous
    # reconstruction taken from the TEST split.

    # Hyperparameters.
    input_size = 1           # Each time step has 1 feature.
    hidden_size = 10         # Latent representation dimension.
    # Number of memory items. This must be > 1: with a single item the softmax
    # addressing weight is identically 1, so z_hat equals the memory row for
    # every input, the encoder receives no gradient, and every sequence gets
    # the same reconstruction. 20 matches the other TSMAE configuration in
    # this notebook.
    memory_size = 20
    sparsity_threshold = 0.05
    sparsity_factor = 0.001
    batch_size = 1
    num_epochs = 100
    learning_rate = 1e-3

    # ---------------------------
    # Training: Use the TRAIN file with both normal and anomaly samples.
    # ---------------------------
    # NOTE(review): mode='all' puts anomalous beats into the training data;
    # memory-augmented autoencoders are usually trained on normal data only.
    # Confirm this is intended.
    train_dataset = ECG5000(mode='all', split='train')
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    # Initialize the TSMAE model.
    model = TSMAE(input_size, hidden_size, memory_size, sparsity_threshold, sparsity_factor)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Check for available device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Training on device:", device)

    # Train the model (train_model also moves it to `device`).
    train_losses = train_model(model, train_loader, optimizer, device, num_epochs=num_epochs)

    # Plot training loss.
    plt.figure(figsize=(8, 4))
    plt.plot(train_losses, label='Training Loss')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Training Loss vs Epochs")
    plt.show()

    # ---------------------------
    # Save the Model Weights
    # ---------------------------
    torch.save(model.state_dict(), 'tsmae_weights.pth')
    print("Model weights saved to 'tsmae_weights.pth'")

    # ---------------------------
    # Evaluation: Use the TEST file with both normal and anomaly samples.
    # ---------------------------
    test_dataset = ECG5000(mode='all', split='test')
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # The reported test loss is the full training objective (MSE + weighted
    # sparsity term), averaged per sample.
    model.eval()
    total_test_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            x_recon, q, z, z_hat = model(batch)
            loss, rec_loss, sparsity_loss = model.loss_function(batch, x_recon, q)
            total_test_loss += loss.item() * batch.size(0)
    avg_test_loss = total_test_loss / len(test_loader.dataset)
    print(f"Average test loss on TEST file: {avg_test_loss:.6f}")

    # ---------------------------
    # Plotting Reconstruction Comparisons (TEST file)
    # ---------------------------
    # Evaluate one normal sample from the test file (unsqueeze adds the batch dim).
    normal_dataset_eval = ECG5000(mode='normal', split='test')
    normal_sample = normal_dataset_eval[0].unsqueeze(0).to(device)
    with torch.no_grad():
        normal_recon, normal_q, _, _ = model(normal_sample)
    normal_series_np = normal_sample.cpu().numpy().flatten()
    normal_recon_np = normal_recon.cpu().numpy().flatten()

    # Evaluate one anomaly sample from the test file.
    anomaly_dataset_eval = ECG5000(mode='anomaly', split='test')
    anomaly_sample = anomaly_dataset_eval[0].unsqueeze(0).to(device)
    with torch.no_grad():
        anomaly_recon, anomaly_q, _, _ = model(anomaly_sample)
    anomaly_series_np = anomaly_sample.cpu().numpy().flatten()
    anomaly_recon_np = anomaly_recon.cpu().numpy().flatten()

    # Plot normal sample reconstruction.
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(normal_series_np, label="Original Normal")
    plt.plot(normal_recon_np, label="Reconstructed Normal", linestyle="--")
    plt.title("Normal Sample Reconstruction (TEST file)")
    plt.xlabel("Time Step")
    plt.ylabel("Value")
    plt.legend()

    # Plot anomaly sample reconstruction.
    plt.subplot(1, 2, 2)
    plt.plot(anomaly_series_np, label="Original Anomaly")
    plt.plot(anomaly_recon_np, label="Reconstructed Anomaly", linestyle="--")
    plt.title("Anomaly Sample Reconstruction (TEST file)")
    plt.xlabel("Time Step")
    plt.ylabel("Value")
    plt.legend()
    plt.show()


       att1      att2      att3      att4      att5      att6      att7  \
0 -0.112522 -2.827204 -3.773897 -4.349751 -4.376041 -3.474986 -2.181408   
1 -1.100878 -3.996840 -4.285843 -4.506579 -4.022377 -3.234368 -1.566126   
2 -0.567088 -2.593450 -3.874230 -4.584095 -4.187449 -3.151462 -1.742940   
3  0.490473 -1.914407 -3.616364 -4.318823 -4.268016 -3.881110 -2.993280   
4  0.800232 -0.874252 -2.384761 -3.973292 -4.338224 -3.802422 -2.534510   

       att8      att9     att10  ...    att132    att133    att134    att135  \
0 -1.818286 -1.250522 -0.477492  ...  0.792168  0.933541  0.796958  0.578621   
1 -0.992258 -0.754680  0.042321  ...  0.538356  0.656881  0.787490  0.724046   
2 -1.490659 -1.183580 -0.394229  ...  0.886073  0.531452  0.311377 -0.021919   
3 -1.671131 -1.333884 -0.965629  ...  0.350816  0.499111  0.600345  0.842069   
4 -1.783423 -1.594450 -0.753199  ...  1.148884  0.958434  1.059025  1.371682   

     att136    att137    att138    att139    att140  target  
0  0.2

Epoch 1/100: 100%|██████████| 500/500 [00:03<00:00, 146.60it/s]


Epoch 1 Loss: 0.745244


Epoch 2/100: 100%|██████████| 500/500 [00:03<00:00, 163.97it/s]


Epoch 2 Loss: 0.620413


Epoch 3/100: 100%|██████████| 500/500 [00:03<00:00, 150.99it/s]


Epoch 3 Loss: 0.600558


Epoch 4/100: 100%|██████████| 500/500 [00:03<00:00, 133.49it/s]


Epoch 4 Loss: 0.597288


Epoch 5/100: 100%|██████████| 500/500 [00:03<00:00, 131.42it/s]


Epoch 5 Loss: 0.593601


Epoch 6/100: 100%|██████████| 500/500 [00:03<00:00, 159.32it/s]


Epoch 6 Loss: 0.595682


Epoch 7/100: 100%|██████████| 500/500 [00:03<00:00, 153.94it/s]


Epoch 7 Loss: 0.594660


Epoch 8/100: 100%|██████████| 500/500 [00:03<00:00, 152.75it/s]


Epoch 8 Loss: 0.592216


Epoch 9/100: 100%|██████████| 500/500 [00:03<00:00, 157.69it/s]


Epoch 9 Loss: 0.588220


Epoch 10/100: 100%|██████████| 500/500 [00:03<00:00, 153.96it/s]


Epoch 10 Loss: 0.589200


Epoch 11/100: 100%|██████████| 500/500 [00:03<00:00, 143.53it/s]


Epoch 11 Loss: 0.587827


Epoch 12/100: 100%|██████████| 500/500 [00:03<00:00, 136.68it/s]


Epoch 12 Loss: 0.585673


Epoch 13/100: 100%|██████████| 500/500 [00:03<00:00, 148.48it/s]


Epoch 13 Loss: 0.535913


Epoch 14/100: 100%|██████████| 500/500 [00:03<00:00, 147.88it/s]


Epoch 14 Loss: 0.544370


Epoch 15/100: 100%|██████████| 500/500 [00:03<00:00, 159.28it/s]


Epoch 15 Loss: 0.506451


Epoch 16/100: 100%|██████████| 500/500 [00:03<00:00, 165.19it/s]


Epoch 16 Loss: 0.499451


Epoch 17/100: 100%|██████████| 500/500 [00:03<00:00, 166.21it/s]


Epoch 17 Loss: 0.498721


Epoch 18/100: 100%|██████████| 500/500 [00:02<00:00, 168.00it/s]


Epoch 18 Loss: 0.501007


Epoch 19/100: 100%|██████████| 500/500 [00:03<00:00, 165.42it/s]


Epoch 19 Loss: 0.582376


Epoch 20/100: 100%|██████████| 500/500 [00:03<00:00, 165.41it/s]


Epoch 20 Loss: 0.623200


Epoch 21/100: 100%|██████████| 500/500 [00:02<00:00, 167.52it/s]


Epoch 21 Loss: 0.614328


Epoch 22/100: 100%|██████████| 500/500 [00:02<00:00, 168.10it/s]


Epoch 22 Loss: 0.599363


Epoch 23/100: 100%|██████████| 500/500 [00:03<00:00, 164.12it/s]


Epoch 23 Loss: 0.597636


Epoch 24/100: 100%|██████████| 500/500 [00:03<00:00, 162.76it/s]


Epoch 24 Loss: 0.593195


Epoch 25/100: 100%|██████████| 500/500 [00:03<00:00, 145.57it/s]


Epoch 25 Loss: 0.591028


Epoch 26/100: 100%|██████████| 500/500 [00:03<00:00, 157.15it/s]


Epoch 26 Loss: 0.585882


Epoch 27/100: 100%|██████████| 500/500 [00:03<00:00, 165.12it/s]


Epoch 27 Loss: 0.585745


Epoch 28/100: 100%|██████████| 500/500 [00:03<00:00, 165.87it/s]


Epoch 28 Loss: 0.584776


Epoch 29/100: 100%|██████████| 500/500 [00:03<00:00, 164.86it/s]


Epoch 29 Loss: 0.583731


Epoch 30/100: 100%|██████████| 500/500 [00:03<00:00, 162.45it/s]


Epoch 30 Loss: 0.584033


Epoch 31/100: 100%|██████████| 500/500 [00:03<00:00, 166.51it/s]


Epoch 31 Loss: 0.584210


Epoch 32/100: 100%|██████████| 500/500 [00:03<00:00, 165.87it/s]


Epoch 32 Loss: 0.583340


Epoch 33/100: 100%|██████████| 500/500 [00:03<00:00, 162.48it/s]


Epoch 33 Loss: 0.583769


Epoch 34/100: 100%|██████████| 500/500 [00:03<00:00, 163.68it/s]


Epoch 34 Loss: 0.583466


Epoch 35/100: 100%|██████████| 500/500 [00:03<00:00, 153.64it/s]


Epoch 35 Loss: 0.583746


Epoch 36/100: 100%|██████████| 500/500 [00:03<00:00, 157.62it/s]


Epoch 36 Loss: 0.583160


Epoch 37/100: 100%|██████████| 500/500 [00:03<00:00, 163.16it/s]


Epoch 37 Loss: 0.583059


Epoch 38/100: 100%|██████████| 500/500 [00:03<00:00, 165.29it/s]


Epoch 38 Loss: 0.583251


Epoch 39/100: 100%|██████████| 500/500 [00:03<00:00, 161.89it/s]


Epoch 39 Loss: 0.582610


Epoch 40/100: 100%|██████████| 500/500 [00:03<00:00, 162.76it/s]


Epoch 40 Loss: 0.582543


Epoch 41/100: 100%|██████████| 500/500 [00:03<00:00, 163.19it/s]


Epoch 41 Loss: 0.583788


Epoch 42/100: 100%|██████████| 500/500 [00:03<00:00, 164.72it/s]


Epoch 42 Loss: 0.582602


Epoch 43/100: 100%|██████████| 500/500 [00:03<00:00, 162.47it/s]


Epoch 43 Loss: 0.582701


Epoch 44/100: 100%|██████████| 500/500 [00:03<00:00, 162.68it/s]


Epoch 44 Loss: 0.582856


Epoch 45/100: 100%|██████████| 500/500 [00:03<00:00, 165.94it/s]


Epoch 45 Loss: 0.582414


Epoch 46/100: 100%|██████████| 500/500 [00:03<00:00, 157.76it/s]


Epoch 46 Loss: 0.582014


Epoch 47/100: 100%|██████████| 500/500 [00:02<00:00, 167.39it/s]


Epoch 47 Loss: 0.583042


Epoch 48/100: 100%|██████████| 500/500 [00:02<00:00, 168.06it/s]


Epoch 48 Loss: 0.582165


Epoch 49/100: 100%|██████████| 500/500 [00:03<00:00, 158.57it/s]


Epoch 49 Loss: 0.581662


Epoch 50/100: 100%|██████████| 500/500 [00:03<00:00, 162.12it/s]


Epoch 50 Loss: 0.582174


Epoch 51/100: 100%|██████████| 500/500 [00:02<00:00, 166.70it/s]


Epoch 51 Loss: 0.582214


Epoch 52/100: 100%|██████████| 500/500 [00:03<00:00, 165.78it/s]


Epoch 52 Loss: 0.581935


Epoch 53/100: 100%|██████████| 500/500 [00:03<00:00, 164.40it/s]


Epoch 53 Loss: 0.582122


Epoch 54/100: 100%|██████████| 500/500 [00:02<00:00, 167.62it/s]


Epoch 54 Loss: 0.582516


Epoch 55/100: 100%|██████████| 500/500 [00:03<00:00, 163.19it/s]


Epoch 55 Loss: 0.581897


Epoch 56/100: 100%|██████████| 500/500 [00:03<00:00, 149.45it/s]


Epoch 56 Loss: 0.582287


Epoch 57/100: 100%|██████████| 500/500 [00:03<00:00, 162.46it/s]


Epoch 57 Loss: 0.582475


Epoch 58/100: 100%|██████████| 500/500 [00:02<00:00, 166.70it/s]


Epoch 58 Loss: 0.582179


Epoch 59/100: 100%|██████████| 500/500 [00:03<00:00, 166.62it/s]


Epoch 59 Loss: 0.582294


Epoch 60/100: 100%|██████████| 500/500 [00:02<00:00, 171.72it/s]


Epoch 60 Loss: 0.582476


Epoch 61/100: 100%|██████████| 500/500 [00:03<00:00, 162.39it/s]


Epoch 61 Loss: 0.582242


Epoch 62/100: 100%|██████████| 500/500 [00:03<00:00, 165.96it/s]


Epoch 62 Loss: 0.582867


Epoch 63/100: 100%|██████████| 500/500 [00:02<00:00, 166.90it/s]


Epoch 63 Loss: 0.582159


Epoch 64/100: 100%|██████████| 500/500 [00:02<00:00, 169.35it/s]


Epoch 64 Loss: 0.581965


Epoch 65/100: 100%|██████████| 500/500 [00:03<00:00, 164.01it/s]


Epoch 65 Loss: 0.582371


Epoch 66/100: 100%|██████████| 500/500 [00:03<00:00, 164.48it/s]


Epoch 66 Loss: 0.582469


Epoch 67/100: 100%|██████████| 500/500 [00:03<00:00, 153.86it/s]


Epoch 67 Loss: 0.582025


Epoch 68/100: 100%|██████████| 500/500 [00:03<00:00, 166.60it/s]


Epoch 68 Loss: 0.568710


Epoch 69/100: 100%|██████████| 500/500 [00:03<00:00, 163.94it/s]


Epoch 69 Loss: 0.545989


Epoch 70/100: 100%|██████████| 500/500 [00:02<00:00, 168.82it/s]


Epoch 70 Loss: 0.518016


Epoch 71/100: 100%|██████████| 500/500 [00:02<00:00, 169.94it/s]


Epoch 71 Loss: 0.508571


Epoch 72/100: 100%|██████████| 500/500 [00:02<00:00, 170.34it/s]


Epoch 72 Loss: 0.509704


Epoch 73/100: 100%|██████████| 500/500 [00:03<00:00, 162.66it/s]


Epoch 73 Loss: 0.506042


Epoch 74/100: 100%|██████████| 500/500 [00:02<00:00, 168.59it/s]


Epoch 74 Loss: 0.502779


Epoch 75/100: 100%|██████████| 500/500 [00:03<00:00, 164.41it/s]


Epoch 75 Loss: 0.500743


Epoch 76/100: 100%|██████████| 500/500 [00:03<00:00, 164.32it/s]


Epoch 76 Loss: 0.500287


Epoch 77/100: 100%|██████████| 500/500 [00:03<00:00, 160.67it/s]


Epoch 77 Loss: 0.505400


Epoch 78/100: 100%|██████████| 500/500 [00:03<00:00, 154.75it/s]


Epoch 78 Loss: 0.500494


Epoch 79/100: 100%|██████████| 500/500 [00:03<00:00, 165.56it/s]


Epoch 79 Loss: 0.498066


Epoch 80/100: 100%|██████████| 500/500 [00:02<00:00, 168.80it/s]


Epoch 80 Loss: 0.497615


Epoch 81/100: 100%|██████████| 500/500 [00:03<00:00, 166.57it/s]


Epoch 81 Loss: 0.495677


Epoch 82/100: 100%|██████████| 500/500 [00:02<00:00, 166.92it/s]


Epoch 82 Loss: 0.495895


Epoch 83/100: 100%|██████████| 500/500 [00:02<00:00, 166.98it/s]


Epoch 83 Loss: 0.495126


Epoch 84/100: 100%|██████████| 500/500 [00:02<00:00, 167.80it/s]


Epoch 84 Loss: 0.496322


Epoch 85/100:  31%|███       | 153/500 [00:00<00:02, 163.42it/s]

In [None]:
import struct

def export_memory_coe(memory_tensor: torch.Tensor, coe_path: str) -> None:
    """
    Export a (M × H) memory tensor to a COE memory-initialisation file
    (``memory_initialization_radix`` / ``memory_initialization_vector``) where
    each value is written as an 8-digit upper-case hex of its IEEE-754
    float32 bit pattern, one value per line, in row-major (flattened) order.

    memory_tensor:  (memory_size, hidden_size) tensor; any floating dtype is
                    accepted and cast to float32 first.
    coe_path:       path to write, e.g. 'memory_weights.coe'

    Raises:
        ValueError: if ``memory_tensor`` has no elements (a COE vector must
            contain at least one value and end with ';').
    """
    # Cast to float32 on the CPU so the exported bits are always the 32-bit
    # pattern; this also makes float16/bfloat16 tensors exportable.
    vals = memory_tensor.detach().to(device='cpu', dtype=torch.float32).numpy().flatten()
    if vals.size == 0:
        raise ValueError("memory_tensor is empty; a .coe initialization vector needs at least one value")
    # Reinterpret the float32 buffer as uint32 to obtain the raw IEEE-754 bits.
    hex_vals = [format(int(bits), '08X') for bits in vals.view(np.uint32)]
    with open(coe_path, 'w') as f:
        f.write('memory_initialization_radix=16;\n')
        f.write('memory_initialization_vector=\n')
        # Entries are comma-separated; the last one is terminated with ';'.
        f.write(',\n'.join(hex_vals) + ';\n')
    print(f"Exported memory weights → {coe_path}")

# Re-save the full model state_dict; the training cell already wrote this same
# file, so this simply overwrites it. Relies on `model` from that earlier cell.
torch.save(model.state_dict(), 'tsmae_weights.pth')
print("Model weights saved to 'tsmae_weights.pth'")

# Export just the MemoryModule weights (the memory_size x hidden_size parameter,
# flattened row by row) as hex float32 words in a .coe initialisation file:
export_memory_coe(model.memory_module.memory, 'memory_weights.coe')


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import scipy.io.arff as arff
import matplotlib.pyplot as plt

# ---------------------------
# Dataset Definition (ECG5000)
# ---------------------------
class ECG5000(Dataset):
    """ECG5000 heartbeats read from the ARFF files as a PyTorch ``Dataset``.

    Every item is a float32 tensor of shape (sequence_length, 1). The last ARFF
    column is the class label; label ``b'1'`` is treated as normal and every
    other label as an anomaly.
    """

    def __init__(self, mode, split='train', data_dir='/kaggle/input/ecg50000'):
        """
        mode: 'normal', 'anomaly', or 'all'.
              'all' means do not filter any samples (both normal and anomaly).
        split: 'train' to load ECG5000_TRAIN.arff; 'test' to load ECG5000_TEST.arff.
        data_dir: directory holding the two ARFF files (defaults to the Kaggle
                  input path that used to be hard-coded here).

        Raises:
            ValueError: if ``mode`` or ``split`` is not one of the allowed values.
        """
        import os  # local import keeps this notebook cell self-contained

        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        if mode not in ('normal', 'anomaly', 'all'):
            raise ValueError(f"mode must be 'normal', 'anomaly' or 'all', got {mode!r}")
        if split not in ('train', 'test'):
            raise ValueError(f"split must be 'train' or 'test', got {split!r}")

        file_path = os.path.join(data_dir, f'ECG5000_{split.upper()}.arff')

        data, meta = arff.loadarff(file_path)
        df = pd.DataFrame(data, columns=meta.names())

        # Rename the label column (always the last one) so it can be used by name.
        new_columns = list(df.columns)
        new_columns[-1] = 'target'
        df.columns = new_columns

        # Filter samples based on mode. Labels are compared as bytes because
        # loadarff returns nominal attributes as byte strings.
        if mode == 'normal':
            df = df[df.target == b'1']
        elif mode == 'anomaly':
            df = df[df.target != b'1']
        df = df.drop(labels='target', axis=1)

        # (num_samples, sequence_length) float32 matrix backing __getitem__.
        self.X = df.astype(np.float32).to_numpy()

    def __getitem__(self, index):
        """Return sample ``index`` as a (sequence_length, 1) float32 tensor.

        The tensor shares memory with ``self.X`` (``torch.from_numpy``).
        """
        return torch.from_numpy(self.X[index]).unsqueeze(-1)

    def __len__(self):
        """Number of samples left after mode filtering."""
        return self.X.shape[0]

    def get_torch_tensor(self):
        """Return all samples as a (num_samples, sequence_length) tensor sharing memory with ``self.X``."""
        return torch.from_numpy(self.X)

# ---------------------------
# Memory Module
# ---------------------------
class MemoryModule(nn.Module):
    """Memory bank of learnable prototype vectors with sparse soft addressing.

    ``forward`` scores each latent vector against every memory item (dot
    product), turns the scores into softmax weights, subtracts
    ``sparsity_threshold`` and clips at zero, renormalises each row, and
    returns the weighted combination of memory items.
    """

    def __init__(self, memory_size, hidden_size, sparsity_threshold=0.05):
        """
        memory_size: number of memory items.
        hidden_size: length of each memory item.
        sparsity_threshold: value subtracted from each softmax weight before
            clipping negatives to zero.
        """
        super(MemoryModule, self).__init__()
        self.memory_size = memory_size
        self.hidden_size = hidden_size
        self.sparsity_threshold = sparsity_threshold
        # Randomly initialised (standard normal) and trained with the model.
        self.memory = nn.Parameter(torch.randn(memory_size, hidden_size))

    def forward(self, z):
        """
        z: (batch, hidden_size) latent vectors.
        Returns:
          z_hat: (batch, hidden_size) recombined latent vectors.
          q: (batch, memory_size) addressing weights; rows sum to ~1 unless
             every weight was below the threshold, in which case the row is 0.
        """
        scores = torch.matmul(z, self.memory.t())
        soft = nn.functional.softmax(scores, dim=1)
        shrunk = torch.max(soft - self.sparsity_threshold, torch.zeros_like(soft))
        row_total = shrunk.sum(dim=1, keepdim=True) + 1e-8  # epsilon avoids 0/0
        q = shrunk / row_total
        z_hat = torch.matmul(q, self.memory)
        return z_hat, q

# ---------------------------
# TSMAE Model
# ---------------------------
class TSMAE(nn.Module):
    """Memory-augmented LSTM sequence autoencoder.

    The encoder LSTM compresses a sequence into its last-layer final hidden
    state ``z``; the memory module re-expresses ``z`` as ``z_hat``; ``z_hat``
    is repeated over the sequence length, decoded by a second LSTM and
    projected back to ``input_size`` features per step.
    """

    def __init__(self, input_size, hidden_size, memory_size, sparsity_threshold=0.05, sparsity_factor=0.001):
        """
        input_size: features per time step (e.g. 1).
        hidden_size: latent size, shared by both LSTMs and the memory items.
        memory_size: number of memory items.
        sparsity_threshold: shrinkage threshold passed to the memory module.
        sparsity_factor: weight on the sparsity term in ``loss_function``.
        """
        super(TSMAE, self).__init__()
        self.sparsity_factor = sparsity_factor

        # Submodules are created in this fixed order, which determines the
        # state_dict layout and the order of random initialisations.
        self.encoder = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.memory_module = MemoryModule(memory_size, hidden_size, sparsity_threshold)
        self.decoder = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.output_layer = nn.Linear(hidden_size, input_size)

    def forward(self, x):
        """
        x: (batch, seq_len, input_size) input sequences.
        Returns (x_recon, q, z, z_hat) with shapes
        (batch, seq_len, input_size), (batch, memory_size),
        (batch, hidden_size), (batch, hidden_size).
        """
        _, steps, _ = x.size()
        final_hidden = self.encoder(x)[1][0]
        z = final_hidden[-1]
        z_hat, q = self.memory_module(z)
        tiled = z_hat.unsqueeze(1).repeat(1, steps, 1)
        x_recon = self.output_layer(self.decoder(tiled)[0])
        return x_recon, q, z, z_hat

    def loss_function(self, x, x_recon, q):
        """Return (total, reconstruction MSE, sparsity penalty).

        The sparsity penalty is mean(log(1 + q**2)), scaled by
        ``self.sparsity_factor`` in the total.
        """
        squared_err = (x - x_recon) ** 2
        rec_loss = torch.mean(squared_err)
        sparsity_loss = torch.mean(torch.log(1 + q ** 2))
        total = rec_loss + self.sparsity_factor * sparsity_loss
        return total, rec_loss, sparsity_loss

# ---------------------------
# Training Setup
# ---------------------------
def train_model(model, dataloader, optimizer, device, num_epochs=50):
    """Train ``model`` in place and return the mean training loss per epoch.

    model: module whose forward returns (x_recon, q, z, z_hat) and which exposes
        loss_function(x, x_recon, q) -> (loss, rec_loss, sparsity_loss).
    dataloader: iterable of input batches (tensors with batch as dim 0).
    optimizer: optimizer over the model's parameters.
    device: device the model and every batch are moved to.
    num_epochs: number of passes over ``dataloader``.

    Returns a list with one float per epoch: the sample-weighted mean of the
    total loss over the samples actually processed in that epoch (NaN if the
    loader yielded no batches).
    """
    model.to(device)
    model.train()
    train_losses = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        # Count processed samples ourselves: with drop_last=True the dataset
        # length includes samples that were never seen, which would bias the
        # average downwards.
        n_seen = 0
        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            batch = batch.to(device)
            optimizer.zero_grad()
            x_recon, q, _, _ = model(batch)
            loss, _, _ = model.loss_function(batch, x_recon, q)
            loss.backward()
            optimizer.step()
            # loss is a per-batch mean; weight it by batch size for an exact epoch mean.
            epoch_loss += loss.item() * batch.size(0)
            n_seen += batch.size(0)
        avg_loss = epoch_loss / n_seen if n_seen else float('nan')
        train_losses.append(avg_loss)
        print(f"Epoch {epoch+1} Loss: {avg_loss:.6f}")
    return train_losses

# ---------------------------
# Main Execution
# ---------------------------
if __name__ == '__main__':
    # Script entry point: train TSMAE on the ECG5000 TRAIN split, save the
    # weights, report the mean TEST loss and plot one normal and one anomalous
    # reconstruction taken from the TEST split.

    # Hyperparameters.
    input_size = 1           # Each time step has 1 feature.
    hidden_size = 10         # Latent representation dimension.
    memory_size = 20         # Number of memory items.
    sparsity_threshold = 0.05
    sparsity_factor = 0.001
    batch_size = 32
    num_epochs = 50
    learning_rate = 1e-3

    # ---------------------------
    # Training: Use the TRAIN file with both normal and anomaly samples.
    # ---------------------------
    # NOTE(review): mode='all' puts anomalous beats into the training data;
    # memory-augmented autoencoders are usually trained on normal data only.
    # Confirm this is intended.
    train_dataset = ECG5000(mode='all', split='train')
    # drop_last=True discards the final partial batch of every epoch.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    # Initialize the TSMAE model.
    model = TSMAE(input_size, hidden_size, memory_size, sparsity_threshold, sparsity_factor)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Check for available device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Training on device:", device)

    # Train the model (train_model also moves it to `device`).
    train_losses = train_model(model, train_loader, optimizer, device, num_epochs=num_epochs)

    # Plot training loss.
    plt.figure(figsize=(8, 4))
    plt.plot(train_losses, label='Training Loss')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Training Loss vs Epochs")
    plt.show()

    # ---------------------------
    # Save the Model Weights
    # ---------------------------
    torch.save(model.state_dict(), 'tsmae_weights.pth')
    print("Model weights saved to 'tsmae_weights.pth'")

    # ---------------------------
    # Evaluation: Use the TEST file with both normal and anomaly samples.
    # ---------------------------
    test_dataset = ECG5000(mode='all', split='test')
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    
    # The reported test loss is the full training objective (MSE + weighted
    # sparsity term), averaged per sample over the whole TEST split.
    model.eval()
    total_test_loss = 0.0
    with torch.no_grad():
        for batch in test_loader:
            batch = batch.to(device)
            x_recon, q, z, z_hat = model(batch)
            loss, rec_loss, sparsity_loss = model.loss_function(batch, x_recon, q)
            total_test_loss += loss.item() * batch.size(0)
    avg_test_loss = total_test_loss / len(test_loader.dataset)
    print(f"Average test loss on TEST file: {avg_test_loss:.6f}")

    # ---------------------------
    # Plotting Reconstruction Comparisons (TEST file)
    # ---------------------------
    # Evaluate one normal sample from the test file (unsqueeze(0) adds the batch dim).
    normal_dataset_eval = ECG5000(mode='normal', split='test')
    normal_sample = normal_dataset_eval[0].unsqueeze(0).to(device)
    with torch.no_grad():
        normal_recon, normal_q, _, _ = model(normal_sample)
    normal_series_np = normal_sample.cpu().numpy().flatten()
    normal_recon_np = normal_recon.cpu().numpy().flatten()

    # Evaluate one anomaly sample from the test file.
    anomaly_dataset_eval = ECG5000(mode='anomaly', split='test')
    anomaly_sample = anomaly_dataset_eval[0].unsqueeze(0).to(device)
    with torch.no_grad():
        anomaly_recon, anomaly_q, _, _ = model(anomaly_sample)
    anomaly_series_np = anomaly_sample.cpu().numpy().flatten()
    anomaly_recon_np = anomaly_recon.cpu().numpy().flatten()

    # Plot normal sample reconstruction.
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(normal_series_np, label="Original Normal")
    plt.plot(normal_recon_np, label="Reconstructed Normal", linestyle="--")
    plt.title("Normal Sample Reconstruction (TEST file)")
    plt.xlabel("Time Step")
    plt.ylabel("Value")
    plt.legend()

    # Plot anomaly sample reconstruction.
    plt.subplot(1, 2, 2)
    plt.plot(anomaly_series_np, label="Original Anomaly")
    plt.plot(anomaly_recon_np, label="Reconstructed Anomaly", linestyle="--")
    plt.title("Anomaly Sample Reconstruction (TEST file)")
    plt.xlabel("Time Step")
    plt.ylabel("Value")
    plt.legend()
    plt.show()
