In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset
import math
import random

import numpy as np
import pickle
import os
import torch

from sklearn.metrics import confusion_matrix, f1_score
from sklearn.model_selection import StratifiedKFold

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torchvision.ops import DeformConv2d

class DeformableTokenEmbedding(nn.Module):
    def __init__(self, c_in, d_model, cycle_length):
        """
        c_in: number of input channels
        d_model: desired output embedding dimension
        cycle_length: kernel size (and stride) in the temporal dimension,
                      which corresponds to one cycle.
        """
        super(DeformableTokenEmbedding, self).__init__()
        # We'll treat the 1D sequence as a 2D input with height=1 and width=L.
        # Define the offset convolution that outputs 2 * cycle_length channels (for x and y offsets).
        self.offset_conv = nn.Conv2d(
            in_channels=c_in,
            out_channels=2 * cycle_length,
            kernel_size=(1, cycle_length),
            stride=(1, cycle_length)
        )
        # Define the deformable convolution:
        # Input: [B, c_in, 1, L]
        # Kernel size: (1, cycle_length), stride: (1, cycle_length)
        self.deform_conv = DeformConv2d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=(1, cycle_length),
            stride=(1, cycle_length),
            padding=(0, 0)
        )
        # Initialize weights for both offset_conv and deform_conv.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        # Input x shape: [B, L, c_in]
        # Permute to [B, c_in, L] and add a dummy height dimension -> [B, c_in, 1, L]
        x = x.permute(0, 2, 1).unsqueeze(2)
        # Compute the offsets using the offset_conv; output shape: [B, 2 * cycle_length, 1, L_out]
        offsets = self.offset_conv(x)
        # Apply deformable convolution. It uses the computed offsets.
        out = self.deform_conv(x, offsets)
        # out shape: [B, d_model, 1, L_out]
        # Remove the height dimension and transpose to [B, L_out, d_model]
        out = out.squeeze(2).transpose(1, 2)
        return out


In [91]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
import math


class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_len=200):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        pe.require_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return self.pe[:, :x.size(1)]


class TokenEmbedding(nn.Module):
    def __init__(self, c_in, d_model, cycle_length):
        super(TokenEmbedding, self).__init__()
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=cycle_length, stride = cycle_length)
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x


class DataEmbedding(nn.Module):
    def __init__(self, c_in, d_model, cycle_length, dropout=0.0):
        super(DataEmbedding, self).__init__()

        self.value_embedding = DeformableTokenEmbedding(c_in=c_in, d_model=d_model, cycle_length=cycle_length)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        x = self.value_embedding(x)
        x = torch.cat([self.cls_token.expand(x.size(0), -1, -1), x], dim=1)
        x = x + self.position_embedding(x)
        return self.dropout(x)

In [4]:

class AttentionLayer(nn.Module):
    def __init__(self, attention, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.norm = nn.LayerNorm(d_model)
        self.inner_attention = attention

        self.query_projection = nn.Linear(d_model,
                                          d_keys * n_heads)
        self.key_projection = nn.Linear(d_model,
                                        d_keys * n_heads)
        self.value_projection = nn.Linear(d_model,
                                          d_values * n_heads)
        self.sigma_projection = nn.Linear(d_model,
                                          n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)

        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads
        x = queries
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        sigma = self.sigma_projection(x).view(B, L, H)

        out, series, prior, sigma = self.inner_attention(
            queries,
            keys,
            values,
            sigma,
            attn_mask
        )
        out = out.view(B, L, -1)

        return self.out_projection(out), series, prior, sigma

In [5]:
class TriangularCausalMask():
    def __init__(self, B, L, device="cpu"):
        mask_shape = [B, 1, L, L]
        with torch.no_grad():
            self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)

    @property
    def mask(self):
        return self._mask


class AnomalyAttention(nn.Module):
    def __init__(self, win_size, mask_flag=True, scale=None, attention_dropout=0.0, output_attention=False):
        super(AnomalyAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
        window_size = win_size
        self.distances = torch.zeros((window_size, window_size)).cuda()
        for i in range(window_size):
            for j in range(window_size):
                self.distances[i][j] = abs(i - j)

    def forward(self, queries, keys, values, sigma, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / math.sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)
        attn = scale * scores

        sigma = sigma.transpose(1, 2)  # B L H ->  B H L
        window_size = attn.shape[-1]
        sigma = torch.sigmoid(sigma * 5) + 1e-5
        sigma = torch.pow(3, sigma) - 1
        sigma = sigma.unsqueeze(-1).repeat(1, 1, 1, window_size)  # B H L L

        prior = self.distances.unsqueeze(0).unsqueeze(0).repeat(sigma.shape[0], sigma.shape[1], 1, 1).cuda()
        prior = 1.0 / (math.sqrt(2 * math.pi) * sigma) * torch.exp(-prior ** 2 / 2 / (sigma ** 2))

        series = self.dropout(torch.softmax(attn, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", series, values)

        if self.output_attention:
            return (V.contiguous(), series, prior, sigma)
        else:
            return (V.contiguous(), None)

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F



class EncoderLayer(nn.Module):
    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None):
        new_x, attn, mask, sigma = self.attention(
            x, x, x,
            attn_mask=attn_mask
        )
        x = x + self.dropout(new_x)
        y = x = self.norm1(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return self.norm2(x + y), attn, mask, sigma


class Encoder(nn.Module):
    def __init__(self, attn_layers, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.norm = norm_layer

    def forward(self, x, attn_mask=None):
        # x [B, L, D]
        series_list = []
        prior_list = []
        sigma_list = []
        for attn_layer in self.attn_layers:
            x, series, prior, sigma = attn_layer(x, attn_mask=attn_mask)
            series_list.append(series)
            prior_list.append(prior)
            sigma_list.append(sigma)

        if self.norm is not None:
            x = self.norm(x)

        return x, series_list, prior_list, sigma_list


class AnomalyTransformer(nn.Module):
    def __init__(self, win_size, enc_in, c_out, d_model=512, n_heads=8, e_layers=2, d_ff=512,
                 dropout=0.0, activation='gelu', output_attention=True):
        super(AnomalyTransformer, self).__init__()
        self.output_attention = output_attention

        # Encoding
        self.embedding = DataEmbedding(enc_in, d_model, cycle_length=80, dropout = dropout)

        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        AnomalyAttention(win_size, False, attention_dropout=dropout, output_attention=output_attention),
                        d_model, n_heads),
                    d_model,
                    d_ff,
                    dropout=dropout,
                    activation=activation
                ) for l in range(e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(d_model)
        )

        self.cls_head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model//2),
            nn.ReLU(),
            nn.Linear(d_model//2, 2)
        )
    def forward(self, x):
        enc_out = self.embedding(x)
        enc_out, series, prior, sigmas = self.encoder(enc_out)
        #feed the first cls token to the classifier
        cls_out = self.cls_head(enc_out[:, 0, :])

        if self.output_attention:
            return cls_out, series, prior, sigmas
        else:
            return cls_out  

In [None]:
#load data
train_data = pickle.load(open('packedData/train_transformer.pkl', 'rb'))
test_data = pickle.load(open('packedData/test_transformer.pkl', 'rb'))

print(np.array(train_data['eye']).shape)
train_eye = torch.tensor(np.array(train_data['eye']), dtype=torch.float32)
train_label = torch.tensor(np.array(train_data['label']), dtype=torch.long)

test_eye = torch.tensor(np.array(test_data['eye']), dtype=torch.float32)
test_label = torch.tensor(np.array(test_data['label']), dtype=torch.long)

eye = torch.cat((train_eye, test_eye), 0)
label = torch.cat((train_label, test_label), 0)


#train
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


(101, 1520, 3)


In [8]:
circle_length = 80
sequence_length = 1520
embedding_dim = 128
num_heads = 4
num_layers = 3
num_classes = 2
epochs = 1
window_length = 20 #19 cycles + 1 cls token


def my_kl_loss(p, q):

    res = p * (torch.log(p + 1e-4) - torch.log(q + 1e-4))
    return torch.mean(torch.sum(res, dim=-1), dim=1)


def bootstrap_balance_minority(data, labels):

    class_0_indices = (labels == 0).nonzero(as_tuple=True)[0]
    class_1_indices = (labels == 1).nonzero(as_tuple=True)[0]

    if len(class_0_indices) < len(class_1_indices):
        minority_indices = class_0_indices
        majority_indices = class_1_indices
    else:
        minority_indices = class_1_indices
        majority_indices = class_0_indices
    
    num_minority_samples_needed = len(majority_indices) - len(minority_indices)
    additional_samples = np.random.choice(minority_indices, size=num_minority_samples_needed, replace=True)
    
    balanced_minority_indices = torch.cat((minority_indices, torch.tensor(additional_samples)))
    
    balanced_indices = torch.cat((balanced_minority_indices, majority_indices))
    
    balanced_indices = balanced_indices[torch.randperm(len(balanced_indices))]

    balanced_data = data[balanced_indices]
    balanced_labels = labels[balanced_indices]

    return balanced_data, balanced_labels


def train(model, train_loader, val_loader, criterion, optimizer, device, k, alpha=0.8):
    """
    Trains the model for one epoch using classification loss combined with association discrepancy.
    k: weight factor for the discrepancy loss.
    Returns:
        train_accuracy, val_accuracy, val_loss
    """
    model.train()
    loss_list = []
    total_correct = 0
    total_samples = 0
    import time
    time_now = time.time()
    iter_count = 0

    for i, (input_data, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        iter_count += 1

        # Move data to device
        inputs = input_data.float().to(device)
        labels = labels.to(device)

        # Forward pass: output is classification logits, and the model also returns attention info:
        # series: list of series-association tensors from each encoder layer
        # prior: list of prior-association tensors (computed via Gaussian kernel) from each encoder layer
        # sigma: corresponding sigma values
        output, series, prior, sigma = model(inputs)

        # Classification loss on the CLS token output
        cls_loss = criterion(output, labels)

        # Compute association discrepancy losses over all layers
        series_loss = 0.0
        prior_loss = 0.0
        # Assuming series and prior are lists (one per encoder layer)
        for u in range(len(prior)):
            prior_ex_cls = prior[u][:, :, 1:, 1:]
            series_ex_cls = series[u][:, :, 1:, 1:]
            
            # Normalize the prior association along the last dimension
            norm_prior = prior_ex_cls / torch.unsqueeze(torch.sum(prior_ex_cls, dim=-1), dim=-1)
            
            # Compute the KL divergence-based losses
            series_loss += (torch.mean(my_kl_loss(series_ex_cls, norm_prior.detach())) +
                            torch.mean(my_kl_loss(norm_prior.detach(), series_ex_cls)))
            prior_loss += (torch.mean(my_kl_loss(norm_prior, series_ex_cls.detach())) +
                        torch.mean(my_kl_loss(series_ex_cls.detach(), norm_prior)))
        series_loss = series_loss / len(prior)
        prior_loss = prior_loss / len(prior)


        # Define two losses for minimax strategy:
        # loss1 encourages larger discrepancy (by subtracting the series loss)
        # loss2 encourages the opposite for the prior branch
        loss1 = - k * series_loss 
        loss2 = cls_loss + k * prior_loss 

        # Backpropagate in two phases (minimax strategy)
        loss1.backward(retain_graph=True)
        loss2.backward()
        #cls_loss.backward()
        optimizer.step()

        loss_list.append(cls_loss.item())

        # Calculate training accuracy for this batch
        _, predicted = torch.max(output, 1)
        total_correct += (predicted == labels).sum().item()
        total_samples += labels.size(0)

        if (i + 1) % 100 == 0:
            speed = (time.time() - time_now) / iter_count
            print(f'\tIteration {i+1}: speed: {speed:.4f}s/iter')
            iter_count = 0
            time_now = time.time()

    avg_train_loss = np.mean(loss_list)
    train_accuracy = total_correct / total_samples * 100

    # Evaluate on validation set
    val_loss, val_accuracy = validate(model, val_loader, criterion, device)
    return train_accuracy, val_accuracy, val_loss


def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for input_data, labels in val_loader:
            inputs = input_data.float().to(device)
            labels = labels.to(device)
            output, series, prior, sigma = model(inputs)
            loss = criterion(output, labels)
            total_loss += loss.item() * labels.size(0)
            _, predicted = torch.max(output, 1)
            total_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)

    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples * 100
    return avg_loss, accuracy




In [9]:
epochs = 40
k_fold = 5
kf = StratifiedKFold(n_splits=k_fold, shuffle=True, random_state=42)
fold_accuracy = []
fold_train_accuracy = []

early_stop_patience = 4   # number of epochs with no improvement before stopping
min_delta = 1e-4          # minimum change to qualify as an improvement

for fold, (train_index, test_index) in enumerate(kf.split(eye, label)):
    # Define the classification loss (cross-entropy)
    criterion = nn.CrossEntropyLoss().to(device)
    
    # Instantiate the revised anomaly transformer model.
    model = AnomalyTransformer(
        win_size=window_length,  # total tokens: 20 (19 cycles + 1 CLS token)
        enc_in=3,                # input channels from eye data
        c_out=num_classes,
        d_model=embedding_dim,
        n_heads=num_heads,
        e_layers=num_layers,
        d_ff=embedding_dim,       # adjust if needed
        dropout=0.3,
        activation='gelu',
        output_attention=True
    ).to(device)

    #print number of parameters
    print(f"Number of parameters: {sum(p.numel() for p in model.parameters())}")
    
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=2e-2)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.5, patience=5, cooldown=0.5, min_lr=5e-6, verbose=True
    )
    
    print(f"Fold {fold}")
    # Split the dataset using the provided indices
    train_eye_fold, test_eye_fold = eye[train_index], eye[test_index]
    train_label_fold, test_label_fold = label[train_index], label[test_index]

    # Bootstrap balance the minority class for training and validation sets
    train_eye_fold, train_label_fold = bootstrap_balance_minority(train_eye_fold, train_label_fold)
    test_eye_fold, test_label_fold = bootstrap_balance_minority(test_eye_fold, test_label_fold)

    train_dataset = torch.utils.data.TensorDataset(train_eye_fold, train_label_fold)
    val_dataset = torch.utils.data.TensorDataset(test_eye_fold, test_label_fold)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    best_val_acc = 0
    previous_val_acc = 0
    patience_counter = 0
    for epoch in range(epochs):
        # Use the modified training function that combines classification loss with discrepancy loss.
        train_accuracy, val_accuracy, val_loss = train(
            model, train_loader, val_loader, criterion, optimizer, device, k=0, alpha=0
        )
        scheduler.step(val_loss)
        print(f"Epoch {epoch}: Train Acc = {train_accuracy:.2f}, Val Acc = {val_accuracy:.2f}, Val Loss = {val_loss:.4f}")
        
        # Check if validation loss improved by at least min_delta
        if  val_accuracy - previous_val_acc >=0:
            patience_counter = 0
        else:
            patience_counter += 1

        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
        
        if patience_counter >= early_stop_patience or val_accuracy >= 88.0:
            print("Early stopping triggered")
            #break
        previous_val_acc = val_accuracy

    fold_accuracy.append(best_val_acc)
    fold_train_accuracy.append(train_accuracy)

print(f"Average Validation Accuracy: {np.mean(fold_accuracy)}")
print("Fold Validation Accuracies:", fold_accuracy)
print(f"Average Training Accuracy: {np.mean(fold_train_accuracy)}")
print("Fold Training Accuracies:", fold_train_accuracy)




Number of parameters: 379502
Fold 0
Epoch 0: Train Acc = 46.91, Val Acc = 50.00, Val Loss = 0.6959
Epoch 1: Train Acc = 49.38, Val Acc = 50.00, Val Loss = 0.6974
Epoch 2: Train Acc = 53.70, Val Acc = 50.00, Val Loss = 0.6887
Epoch 3: Train Acc = 55.56, Val Acc = 50.00, Val Loss = 0.6865
Epoch 4: Train Acc = 50.62, Val Acc = 50.00, Val Loss = 0.6856
Epoch 5: Train Acc = 56.79, Val Acc = 50.00, Val Loss = 0.6828
Epoch 6: Train Acc = 51.85, Val Acc = 50.00, Val Loss = 0.6791
Epoch 7: Train Acc = 62.35, Val Acc = 50.00, Val Loss = 0.6754
Epoch 8: Train Acc = 58.64, Val Acc = 80.95, Val Loss = 0.6681
Epoch 9: Train Acc = 53.70, Val Acc = 61.90, Val Loss = 0.6600
Epoch 10: Train Acc = 58.02, Val Acc = 88.10, Val Loss = 0.6475
Early stopping triggered
Epoch 11: Train Acc = 56.79, Val Acc = 90.48, Val Loss = 0.6337
Early stopping triggered
Epoch 12: Train Acc = 61.11, Val Acc = 85.71, Val Loss = 0.6083
Epoch 13: Train Acc = 69.75, Val Acc = 83.33, Val Loss = 0.5743
Epoch 14: Train Acc = 65.43,