In [1]:
import torch
from torch.utils.data import Dataset
import numpy as np
from torch_geometric.utils import to_undirected
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
from ssm_utlis import set_seed
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm.notebook import tqdm 
#set_seed(21)

class CTDGraphDataset(Dataset):
    def __init__(self, N, M, seed=42):
        """
        Args:
            N (int): number of edges per sequence
            M (int): number of sequences
        """
        self.N = N
        self.M = M
        np.random.seed(seed)

        self.data = []
        self.labels = []

        for _ in range(M):
            # node 0 signal = label
            y = np.random.choice([0, 1])
            self.labels.append(y)

            # node signals
            signals = {0: float(y)*2-1}
            for i in range(1, N + 1):
                signals[i] = np.random.uniform(-1, 1)

            # edges
            src = np.arange(0, N)
            dst = np.arange(1, N + 1)
            times = np.sort(np.random.uniform(0, 10, N))  # increasing times

            # endpoint signals
            x_src = np.array([signals[s] for s in src])
            x_dst = np.array([signals[d] for d in dst])

            self.data.append((src, dst, times, x_src, x_dst))

    def __len__(self):
        return self.M

    def __getitem__(self, idx):
        src, dst, times, x_src, x_dst = self.data[idx]
        y = self.labels[idx]

        return {
            "src": torch.tensor(src, dtype=torch.long),
            "dst": torch.tensor(dst, dtype=torch.long),
            "t": torch.tensor(times, dtype=torch.float32),
            "x_src": torch.tensor(x_src, dtype=torch.float32),
            "x_dst": torch.tensor(x_dst, dtype=torch.float32),
            "y": torch.tensor(y, dtype=torch.float32)
        }





In [2]:
from torch.utils.data import DataLoader, random_split


def get_loaders(N, M, batch_size=4, train_split=0.7, val_split=0.15):
    """
    Args:
        N (int): number of nodes (dataset-specific param for CTDGraphDataset).
        M (int): number of samples.
        batch_size (int): batch size for loaders.
        train_split (float): fraction of data to use for training.
        val_split (float): fraction of data to use for validation.
        seed (int): random seed.
        
    Returns:
        train_loader, val_loader, test_loader
    """
    dataset = CTDGraphDataset(N, M)

    # sizes
    train_size = int(train_split * M)
    val_size = int(val_split * M)
    test_size = M - (train_size + val_size)

    # make sure the total matches exactly
    assert train_size + val_size + test_size == M, "Split sizes must sum to dataset size"

    # split
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size], generator=torch.Generator()
    )

    # loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader  = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, test_loader

In [3]:
class EarlyStopping:
    def __init__(self, patience=10, delta=0.0, path="best_model_sq.pt"):
        self.patience = patience
        self.delta = delta
        self.best_acc = -float("inf")
        self.counter = 0
        self.early_stop = False
        self.path = path  # save checkpoint

    def __call__(self, acc, model):
        if acc > self.best_acc + self.delta:
            self.best_acc = acc
            self.counter = 0
            # save model checkpoint
            torch.save(model.state_dict(), self.path)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

## Create Dataset

In [13]:
N = 9
M=1000
batch_size = 400
train_loader,val_laoder,test_loader = get_loaders(N=N,M=M,batch_size=batch_size)

## Create Model

In [14]:
a = torch.arange(10)
torch.hstack([a,a]).view(2*10,-1).view(2,10,-1)[0]

tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7],
        [8],
        [9]])

In [17]:
from ssm_memory import MemoryModel
state_dim =32
time_dim = 4
device = 'cuda'
lr = 1e-3
epochs = 1600
model = MemoryModel(num_nodes=N+1,input_dim=1,hidden_dim=state_dim,time_dim=time_dim,reg=1e-4,device=device,update_type='mamba').to(device)
path = 'sq_mod.pt'
early_stopper = EarlyStopping(patience=500, delta=1e-3,path=path)


Static Embeddings: Nil


## Train Model

In [20]:
criterion = nn.BCEWithLogitsLoss()  # for binary classification
optimizer = optim.Adam(model.parameters(), lr=lr,weight_decay=1e-3)

In = torch.eye(4).to(device)
A = In.clone()
e = [(0,1),(0,2),(2,3)]

for u,v in e:
    A[u,v] = 1
    A[v,u] = 1
D = torch.sum(A,dim=-1)
D[D==0]=1
D =  torch.diag(D**-(1/2))
L_t = In - D@A@D

ep = [(0,2),(2,3)]
Ap = In.clone()
for u,v in ep:
    Ap[u,v] = 1
    Ap[v,u] = 1

D = torch.sum(Ap,dim=-1)
D[D==0]=1
D =  torch.diag(D**-(1/2))

L_p = In - D@Ap@D
Id   = torch.eye(state_dim,device='cuda')
delta = 1/N+1
for epoch in tqdm(range(epochs),desc='Epochs: '):
    model.train()
    total_loss, total_score, total_samples = 0, [], []
    pred = []
    hidden_state_list = []
    # yt = torch.zeros((len(train_loader),1),device=model.device)
    neighbour = 0
    thresholds = 0
    for bid,batch in enumerate(train_loader):
        src, dst = batch["src"].to(device), batch["dst"].to(device)
        batch_size = src.shape[0]
        
        batch_ids = torch.arange(src.shape[0],device=device)
        x_src, x_dst, y = batch["x_src"].to(device), batch["x_dst"].to(device), batch["y"].to(device)
        src_b,dst_b = src+batch_ids[:,None],dst+batch_ids[:,None]
        m1_vec = torch.zeros((batch_size,model.num_nodes,model.hidden_dim),device=model.device)
        ne_vec = torch.zeros((batch_size,model.num_nodes,model.hidden_dim),device=model.device)
        # edge_index = torch.stack([torch.arange(batch_size*model.num_nodes,device=model.device),torch.arange(batch_size*model.num_nodes,device=model.device)])
        for t in range(src.shape[-1]):
            src_t = src_b[:,t]
            dst_t = dst_b[:,t]
            
            # current_edge_index = to_undirected(torch.cat([edge_index,new_edges],dim=1))
            K = 2
            last_id = torch.clamp_min(src[0,t]-K,min=0).to(torch.int).item()
              
            active = torch.arange(last_id,dst[0,t]+1).to(device)
            ax = len(active)
            if ax<4:
                zero_ids = 4-len(active)
                active = torch.hstack([active,torch.zeros(zero_ids,).to(device).to(torch.int)])
            x_t = torch.zeros((batch_size,len(active),model.input_dim),device=model.device)
            # x_n =  torch.zeros((batch_size,2,model.hidden_dim),device=model.device)
            # L_t = torch.eye(dst[0,t]+1).to(device)
            # L_t[src[0,:t+1],dst[0,:t+1]]=1
            # L_t[dst[0,:t+1],src[0,:t+1]]=1
            # D = torch.diag(torch.sum(L_t,dim=1)**(-0.5))
            # L_t = torch.eye(len(L_t)).to(device) - D@L_t@D

            # L_p = torch.eye(len(L_t)).to(device)
            # if t>0:
                # L_p[src[0,:t],dst[0,:t]]=1
                # L_p[dst[0,:t],src[0,:t]]=1
                # D = torch.diag(torch.sum(L_p,dim=1)**(-0.5))
                # L_p = torch.eye(len(L_t)).to(device) - D@L_p@D
                
            # Fill Data to Current Node
            if ax<4:
                for k in range(2,4-ax+1):
                    x_t[:,k] = x_src[:,0].unsqueeze(-1)
                
            x_t[:,0]= x_src[:,t].unsqueeze(-1)
            x_t[:,1]= x_dst[:,t].unsqueeze(-1)

            
            prev_c_1= m1_vec[:,active].view((batch_size*(len(active)),-1))
            x_batch = x_t.view(batch_size*len(active),-1)
            zt = model.TuneInputSC(x_batch) # input Tuning / Selective Scan
            
            A = model.A_
            At_ = torch.matrix_exp(-A.T*delta)
            reg = 1
            In = torch.eye(len(active)).to(device)
            PL_ = In + reg*L_t + L_t@L_t
            PL_p = In + reg*L_p + + L_p@L_p                             

            PL_inv = torch.linalg.inv(PL_)
            A_s= In - PL_inv@PL_p
            PL_b = torch.block_diag(*PL_inv[None,:,:].repeat(batch_size,1,1))
            A_sb = torch.block_diag(*A_s.unsqueeze(0).repeat(batch_size,1,1))
            updated_c_1=  -A_sb@ prev_c_1 +  prev_c_1 @ At_  +  PL_b @ zt
            final_state = updated_c_1

            m1_vec[:,active] = updated_c_1.view(batch_size,len(active),-1)
            ne_vec[:,active] = final_state.view(batch_size,len(active),-1)

            # edge_index = current_edge_index.clone()

        hidden_state_list.append(ne_vec)
        logits = model.SeqClass(hidden_state_list[-1][:,-1]).view(-1)
        loss = criterion(target = y,input=logits) #+ torch.norm(model.A_hippo-model.A_,p=2)*1e-1
        total_loss += loss #/ batch_size
        # preds = ( logits> 0).float()
        pred += (logits>0).to(int).tolist()
        total_score += logits.tolist()
        total_samples += y.tolist()

        optimizer.zero_grad()
        loss.backward()
   

        optimizer.step()


    v_score, v_samples =  [], []
    v_pred = []
    hidden_state_list = []
    # yt = torch.zeros((len(train_loader),1),device=model.device)
    neighbour = 0
    thresholds = 0
    with torch.no_grad():
        for bid,batch in enumerate(val_laoder):
            src, dst = batch["src"].to(device), batch["dst"].to(device)
            batch_size = src.shape[0]
            
            batch_ids = torch.arange(src.shape[0],device=device)
            x_src, x_dst, y = batch["x_src"].to(device), batch["x_dst"].to(device), batch["y"].to(device)
            src_b,dst_b = src+batch_ids[:,None],dst+batch_ids[:,None]
            m1_vec = torch.zeros((batch_size,model.num_nodes,model.hidden_dim),device=model.device)
            ne_vec = torch.zeros((batch_size,model.num_nodes,model.hidden_dim),device=model.device)
            # edge_index = torch.stack([torch.arange(batch_size*model.num_nodes,device=model.device),torch.arange(batch_size*model.num_nodes,device=model.device)])
            for t in range(src.shape[-1]):
                src_t = src_b[:,t]
                dst_t = dst_b[:,t]
                
                # current_edge_index = to_undirected(torch.cat([edge_index,new_edges],dim=1))
                K = 2
                last_id = torch.clamp_min(src[0,t]-K,min=0).to(torch.int).item()
                
                active = torch.arange(last_id,dst[0,t]+1).to(device)
                ax = len(active)
                if ax<4:
                    zero_ids = 4-len(active)
                    active = torch.hstack([active,torch.zeros(zero_ids,).to(device).to(torch.int)])
                x_t = torch.zeros((batch_size,len(active),model.input_dim),device=model.device)
                # x_n =  torch.zeros((batch_size,2,model.hidden_dim),device=model.device)
                # L_t = torch.eye(dst[0,t]+1).to(device)
                # L_t[src[0,:t+1],dst[0,:t+1]]=1
                # L_t[dst[0,:t+1],src[0,:t+1]]=1
                # D = torch.diag(torch.sum(L_t,dim=1)**(-0.5))
                # L_t = torch.eye(len(L_t)).to(device) - D@L_t@D

                # L_p = torch.eye(len(L_t)).to(device)
                # if t>0:
                    # L_p[src[0,:t],dst[0,:t]]=1
                    # L_p[dst[0,:t],src[0,:t]]=1
                    # D = torch.diag(torch.sum(L_p,dim=1)**(-0.5))
                    # L_p = torch.eye(len(L_t)).to(device) - D@L_p@D
                    
                # Fill Data to Current Node
                # if ax<4:
                #     for k in range(2,4-ax+1):
                #         x_t[:,k] = x_src[:,0].unsqueeze(-1)
                    
                x_t[:,0]= x_src[:,t].unsqueeze(-1)
                x_t[:,1]= x_dst[:,t].unsqueeze(-1)

                
                prev_c_1= m1_vec[:,active].view((batch_size*(len(active)),-1))
                x_batch = x_t.view(batch_size*len(active),-1)
                zt = model.TuneInputSC(x_batch) # input Tuning / Selective Scan
                
                A = model.A_
                At_ = torch.matrix_exp(-A.T*delta)
                reg = 1
                In = torch.eye(len(active)).to(device)
                PL_ = In + reg*L_t + L_t@L_t
                PL_p = In + reg*L_p + + L_p@L_p                             

                PL_inv = torch.linalg.inv(PL_)
                A_s= In - PL_inv@PL_p
                PL_b = torch.block_diag(*PL_inv[None,:,:].repeat(batch_size,1,1))
                A_sb = torch.block_diag(*A_s.unsqueeze(0).repeat(batch_size,1,1))
                updated_c_1=  -A_sb@ prev_c_1 +  prev_c_1 @ At_  +  PL_b @ zt
                final_state = updated_c_1

                m1_vec[:,active] = updated_c_1.view(batch_size,len(active),-1)
                ne_vec[:,active] = final_state.view(batch_size,len(active),-1)



                # edge_index = current_edge_index.clone()

            hidden_state_list.append(ne_vec)
            logits = model.SeqClass(hidden_state_list[-1][:,-1]).view(-1)
            # preds = ( logits> 0).float()
            v_pred += (logits>0).to(int).tolist()
            v_score += logits.tolist()
            v_samples += y.tolist()

    val_acc = accuracy_score(v_pred,v_samples)
    auc = roc_auc_score(total_samples,total_score)
    acc = accuracy_score(pred,total_samples)
    
    print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss:.4f} |Train Acc: {acc:.4f}| Val Acc: {val_acc:.4f}")

    early_stopper(val_acc, model)
    if early_stopper.early_stop:
        print(f"Early stopping at epoch {epoch+1}")
        break



Epochs:   0%|          | 0/1600 [00:00<?, ?it/s]

Epoch 1/1600 | Loss: 1.3864 |Train Acc: 0.5071| Val Acc: 0.4733
Epoch 2/1600 | Loss: 1.3865 |Train Acc: 0.5071| Val Acc: 0.4733
Epoch 3/1600 | Loss: 1.3860 |Train Acc: 0.5071| Val Acc: 0.4733
Epoch 4/1600 | Loss: 1.3862 |Train Acc: 0.5071| Val Acc: 0.4733
Epoch 5/1600 | Loss: 1.3860 |Train Acc: 0.5071| Val Acc: 0.4733
Epoch 6/1600 | Loss: 1.3859 |Train Acc: 0.5071| Val Acc: 0.4733
Epoch 7/1600 | Loss: 1.3859 |Train Acc: 0.5071| Val Acc: 0.4733
Epoch 8/1600 | Loss: 1.3859 |Train Acc: 0.5071| Val Acc: 0.4733
Epoch 9/1600 | Loss: 1.3860 |Train Acc: 0.5071| Val Acc: 0.4733
Epoch 10/1600 | Loss: 1.3861 |Train Acc: 0.5071| Val Acc: 0.4733
Epoch 11/1600 | Loss: 1.3859 |Train Acc: 0.5071| Val Acc: 0.4733
Epoch 12/1600 | Loss: 1.3860 |Train Acc: 0.5071| Val Acc: 0.4733
Epoch 13/1600 | Loss: 1.3863 |Train Acc: 0.5071| Val Acc: 0.4733
Epoch 14/1600 | Loss: 1.3861 |Train Acc: 0.5071| Val Acc: 0.4733
Epoch 15/1600 | Loss: 1.3861 |Train Acc: 0.5071| Val Acc: 0.4733
Epoch 16/1600 | Loss: 1.3860 |Trai

best_t,best_acc = find_best_threshold(total_samples,total_score)

In [43]:
model.load_state_dict(torch.load(path,weights_only=True))
model.eval()
total_loss, total_score, total_samples = 0,[], []
hidden_state_list = []
pred = []
yt = []

for bid,batch in enumerate(test_loader):
    src, dst = batch["src"].to(device), batch["dst"].to(device)
    batch_size = src.shape[0]
    batch_ids = torch.arange(src.shape[0],device=device)
    x_src, x_dst, y = batch["x_src"].to(device), batch["x_dst"].to(device), batch["y"].to(device)
    src_b,dst_b = src+batch_ids[:,None],dst+batch_ids[:,None]
    m1_vec = torch.zeros((batch_size,model.num_nodes,model.hidden_dim),device=model.device)
    ne_vec = torch.zeros((batch_size,model.num_nodes,model.hidden_dim),device=mode l.device)
    # edge_index = torch.stack([torch.arange(batch_size*model.num_nodes,device=model.device),torch.arange(batch_size*model.num_nodes,device=model.device)])
    for t in range(src.shape[-1]):
        src_t = src_b[:,t]
        dst_t = dst_b[:,t]
        # current_edge_index = to_undirected(torch.cat([edge_index,new_edges],dim=1))
        active = torch.hstack([src[0,t],dst[0,t]])
        x_t = torch.zeros((batch_size,3,model.input_dim),device=model.device)
        # x_n =  torch.zeros((batch_size,2,model.hidden_dim),device=model.device)
        
        # Fill Data to Current Node
        x_t[:,0]= x_src[:,t].unsqueeze(-1)
        x_t[:,1]= x_dst[:,t].unsqueeze(-1)

        if t>=1:
            x_t[:,2]= x_src[:,t-1].unsqueeze(-1)
            active = torch.hstack([active,src[0,t-1]])
        else:
            x_t[:,2]= x_src[:,0].unsqueeze(-1)
            active = torch.hstack([active,src[0,0]])

        prev_c_1= m1_vec[:,active].view((batch_size*3,-1))
        x_batch = x_t.view(batch_size*3,-1)
        zt = model.TuneInputSC(x_batch) # input Tuning / Selective Scan
        # gate = model.B_gate_a(zt)
        # ztn = F.sigmoid(gate)*zt
        Bzt = zt #model.B1(zt)
        # delta_1 = torch.ones((zt.shape[0],model.hidden_dim),device=model.device)
        # A_cont_1 = -torch.exp(model.A_log_1) 
        # At_bar_1 = torch.exp(delta_1 * A_cont_1) 
        At_ = Id - model.A_.T*delta

        reg = 1
        
        # C_1 = ((Bzt * delta))[None]*t_weights[:,None,None] 
        # RHS_ = C_1@torch.matrix_exp((model.A_.T*delta)[None] * t_nodes[:,None,None])
        
        
        PL_inv = torch.linalg.inv(In+reg*L_t)
        PL_prev = In*(1+reg)
        s  =  (In + PL_inv@PL_prev)
        A_s= PL_inv@(reg*del_L)

        PL_b = torch.block_diag(*PL_inv[None,:,:].repeat(batch_size,1,1))
        # A_st = torch.matrix_exp(s*t_nodes[:,None,None])
        A_sb = torch.block_diag(*A_s.unsqueeze(0).repeat(batch_size,1,1))
        # A_stb = torch.stack([torch.block_diag(*s) for s in A_st[:,None,:,:].repeat(1,batch_size,1,1)])
        # integral_1 = torch.sum(A_stb@PL_b[None,:,:]@RHS_,dim=0)

        updated_c_1=  -A_sb@ prev_c_1 +  prev_c_1 @ At_  +  PL_b @ Bzt

        u_2 =  zt+F.gelu(updated_c_1,approximate='tanh')# + zt approximate for faster results, maybe less accurate
        final_state = updated_c_1
        m1_vec[:,active] = updated_c_1.view(batch_size,3,-1)
        ne_vec[:,active] = final_state.view(batch_size,3,-1)



    hidden_state_list.append(ne_vec)
    logits = model.SeqClass(hidden_state_list[-1][:,-1]).view(-1)
    total_score += logits.tolist()
    total_samples += y.tolist()
    pred += (logits>0).to(int).tolist()

auc = roc_auc_score(total_samples,total_score)
best_acc = accuracy_score(total_samples,pred)
print(f" Test Acc: {best_acc:.4f} |Test AUC: {auc:.4f}")


SyntaxError: invalid syntax. Perhaps you forgot a comma? (45451634.py, line 15)

In [None]:
raise

In [None]:
# Full TestBench
delta = 1/20
from ssm_memory import MemoryModel
test_vals = []
for _ in range(10):
    torch.cuda.empty_cache()
    state_dim = 64
    time_dim = 4
    device = 'cuda'
    lr = 1e-3
    epochs = 2000
    print('New Model Created')
    model = MemoryModel(num_nodes=N+1,input_dim=1,hidden_dim=state_dim,time_dim=time_dim,reg=1e-4,device=device,update_type='mamba').to(device)
    path = 'sq_mod.pt'
    early_stopper = EarlyStopping(patience=500, delta=1e-3,path=path)
    criterion = nn.BCEWithLogitsLoss()  # for binary classification
    optimizer = optim.Adam(model.parameters(), lr=lr,weight_decay=1e-3)

    In = torch.eye(3).to(device)
    A = In.clone()
    e = [(0,1),(0,2)]

    for u,v in e:
        A[u,v] = 1
        A[v,u] = 1
    D = torch.diag(torch.sum(A,dim=-1)**-(1/2))
    L_t = In - D@A@D

    ep = [(0,2)]
    Ap = In.clone()
    for u,v in ep:
        Ap[u,v] = 1
        Ap[v,u] = 1
    D = torch.diag(torch.sum(Ap,dim=-1)**-(1/2))
    L_p = In - D@Ap@D
    del_L = L_t-L_p
    Id   = torch.eye(state_dim,device='cuda')

    for epoch in tqdm(range(epochs),desc='Epochs: '):
        model.train()
        total_loss, total_score, total_samples = 0, [], []
        pred = []
        hidden_state_list = []
        # yt = torch.zeros((len(train_loader),1),device=model.device)
        neighbour = 0
        thresholds = 0
        for bid,batch in enumerate(train_loader):
            src, dst = batch["src"].to(device), batch["dst"].to(device)
            batch_size = src.shape[0]
            batch_ids = torch.arange(src.shape[0],device=device)
            x_src, x_dst, y = batch["x_src"].to(device), batch["x_dst"].to(device), batch["y"].to(device)
            src_b,dst_b = src+batch_ids[:,None],dst+batch_ids[:,None]
            m1_vec = torch.zeros((batch_size,model.num_nodes,model.hidden_dim),device=model.device)
            ne_vec = torch.zeros((batch_size,model.num_nodes,model.hidden_dim),device=model.device)
            # edge_index = torch.stack([torch.arange(batch_size*model.num_nodes,device=model.device),torch.arange(batch_size*model.num_nodes,device=model.device)])
            for t in range(src.shape[-1]):
                src_t = src_b[:,t]
                dst_t = dst_b[:,t]
                # current_edge_index = to_undirected(torch.cat([edge_index,new_edges],dim=1))
                active = torch.hstack([src[0,t],dst[0,t]])
                x_t = torch.zeros((batch_size,3,model.input_dim),device=model.device)
                # x_n =  torch.zeros((batch_size,2,model.hidden_dim),device=model.device)
                
                # Fill Data to Current Node
                x_t[:,0]= x_src[:,t].unsqueeze(-1)
                x_t[:,1]= x_dst[:,t].unsqueeze(-1)

                if t>=1:
                    x_t[:,2]= x_src[:,t-1].unsqueeze(-1)
                    active = torch.hstack([active,src[0,t-1]])
                else:
                    x_t[:,2]= x_src[:,0].unsqueeze(-1)
                    active = torch.hstack([active,src[0,0]])

                prev_c_1= m1_vec[:,active].view((batch_size*3,-1))
                x_batch = x_t.view(batch_size*3,-1)
                zt = model.TuneInputSC(x_batch) # input Tuning / Selective Scan
                A = model.A_
                At_ = Id - A.T*delta
                reg = 1
                PL_inv = torch.linalg.inv(In+reg*L_t)
                A_s= PL_inv@(reg*del_L)
                PL_b = torch.block_diag(*PL_inv[None,:,:].repeat(batch_size,1,1))
                A_sb = torch.block_diag(*A_s.unsqueeze(0).repeat(batch_size,1,1))
                updated_c_1=  -A_sb@ prev_c_1 +  prev_c_1 @ At_  +  PL_b @ zt
                final_state = updated_c_1

                m1_vec[:,active] = updated_c_1.view(batch_size,3,-1)
                ne_vec[:,active] = final_state.view(batch_size,3,-1)

                # edge_index = current_edge_index.clone()

            hidden_state_list.append(ne_vec)
            logits = model.SeqClass(hidden_state_list[-1][:,-1]).view(-1)
            loss = criterion(target = y,input=logits) #+ torch.norm(model.A_hippo-model.A_,p=2)*1e-1
            total_loss += loss #/ batch_size
            # preds = ( logits> 0).float()
            pred += (logits>0).to(int).tolist()
            total_score += logits.tolist()
            total_samples += y.tolist()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


        v_score, v_samples =  [], []
        v_pred = []
        hidden_state_list = []
        # yt = torch.zeros((len(train_loader),1),device=model.device)
        neighbour = 0
        thresholds = 0
        with torch.no_grad():
            for bid,batch in enumerate(val_laoder):
                src, dst = batch["src"].to(device), batch["dst"].to(device)
                batch_size = src.shape[0]
                batch_ids = torch.arange(src.shape[0],device=device)
                x_src, x_dst, y = batch["x_src"].to(device), batch["x_dst"].to(device), batch["y"].to(device)
                src_b,dst_b = src+batch_ids[:,None],dst+batch_ids[:,None]
                m1_vec = torch.zeros((batch_size,model.num_nodes,model.hidden_dim),device=model.device)
                ne_vec = torch.zeros((batch_size,model.num_nodes,model.hidden_dim),device=model.device)
                # edge_index = torch.stack([torch.arange(batch_size*model.num_nodes,device=model.device),torch.arange(batch_size*model.num_nodes,device=model.device)])
                for t in range(src.shape[-1]):
                    src_t = src_b[:,t]
                    dst_t = dst_b[:,t]
                    # current_edge_index = to_undirected(torch.cat([edge_index,new_edges],dim=1))
                    active = torch.hstack([src[0,t],dst[0,t]])
                    x_t = torch.zeros((batch_size,3,model.input_dim),device=model.device)
                    # x_n =  torch.zeros((batch_size,2,model.hidden_dim),device=model.device)
                    
                    # Fill Data to Current Node
                    x_t[:,0]= x_src[:,t].unsqueeze(-1)
                    x_t[:,1]= x_dst[:,t].unsqueeze(-1)

                    if t>=1:
                        x_t[:,2]= x_src[:,t-1].unsqueeze(-1)
                        active = torch.hstack([active,src[0,t-1]])
                    else:
                        x_t[:,2]= x_src[:,0].unsqueeze(-1)
                        active = torch.hstack([active,src[0,0]])

                    prev_c_1= m1_vec[:,active].view((batch_size*3,-1))
                    x_batch = x_t.view(batch_size*3,-1)
                    zt = model.TuneInputSC(x_batch) # input Tuning / Selective Scan
                    A = model.A_
                    At_ = Id - A.T*delta
                    reg = 1
                    PL_inv = torch.linalg.inv(In+reg*L_t)
                    A_s= PL_inv@(reg*del_L)
                    PL_b = torch.block_diag(*PL_inv[None,:,:].repeat(batch_size,1,1))
                    A_sb = torch.block_diag(*A_s.unsqueeze(0).repeat(batch_size,1,1))
                    updated_c_1=  -A_sb@ prev_c_1 +  prev_c_1 @ At_  +  PL_b @ zt
                    final_state = updated_c_1

                    m1_vec[:,active] = updated_c_1.view(batch_size,3,-1)
                    ne_vec[:,active] = final_state.view(batch_size,3,-1)

                    # edge_index = current_edge_index.clone()

                hidden_state_list.append(ne_vec)
                logits = model.SeqClass(hidden_state_list[-1][:,-1]).view(-1)
                # preds = ( logits> 0).float()
                v_pred += (logits>0).to(int).tolist()
                v_score += logits.tolist()
                v_samples += y.tolist()

        val_acc = accuracy_score(v_pred,v_samples)
        auc = roc_auc_score(total_samples,total_score)
        acc = accuracy_score(pred,total_samples)
        
        # print(f"Epoch {epoch+1}/{epochs} | Loss: {total_loss:.4f} |Train Acc: {acc:.4f}| Val Acc: {val_acc:.4f}")

        early_stopper(val_acc, model)
        if early_stopper.early_stop:
            print(f"Early stopping at epoch {epoch+1}")
            break

    model.load_state_dict(torch.load(path,weights_only=True))
    model.eval()
    total_loss, total_score, total_samples = 0,[], []
    hidden_state_list = []
    pred = []
    yt = []
    for bid,batch in enumerate(test_loader):
        src, dst = batch["src"].to(device), batch["dst"].to(device)
        batch_size = src.shape[0]
        batch_ids = torch.arange(src.shape[0],device=device)
        x_src, x_dst, y = batch["x_src"].to(device), batch["x_dst"].to(device), batch["y"].to(device)
        src_b,dst_b = src+batch_ids[:,None],dst+batch_ids[:,None]
        m1_vec = torch.zeros((batch_size,model.num_nodes,model.hidden_dim),device=model.device)
        ne_vec = torch.zeros((batch_size,model.num_nodes,model.hidden_dim),device=model.device)
        # edge_index = torch.stack([torch.arange(batch_size*model.num_nodes,device=model.device),torch.arange(batch_size*model.num_nodes,device=model.device)])
        for t in range(src.shape[-1]):
            src_t = src_b[:,t]
            dst_t = dst_b[:,t]
            # current_edge_index = to_undirected(torch.cat([edge_index,new_edges],dim=1))
            active = torch.hstack([src[0,t],dst[0,t]])
            x_t = torch.zeros((batch_size,3,model.input_dim),device=model.device)
            # x_n =  torch.zeros((batch_size,2,model.hidden_dim),device=model.device)
            
            # Fill Data to Current Node
            x_t[:,0]= x_src[:,t].unsqueeze(-1)
            x_t[:,1]= x_dst[:,t].unsqueeze(-1)

            if t>=1:
                x_t[:,2]= x_src[:,t-1].unsqueeze(-1)
                active = torch.hstack([active,src[0,t-1]])
            else:
                x_t[:,2]= x_src[:,0].unsqueeze(-1)
                active = torch.hstack([active,src[0,0]])

            prev_c_1= m1_vec[:,active].view((batch_size*3,-1))
            x_batch = x_t.view(batch_size*3,-1)
            zt = model.TuneInputSC(x_batch) # input Tuning / Selective Scan
            # gate = model.B_gate_a(zt)
            # ztn = F.sigmoid(gate)*zt
            Bzt = zt #model.B1(zt)
            # delta_1 = torch.ones((zt.shape[0],model.hidden_dim),device=model.device)
            # A_cont_1 = -torch.exp(model.A_log_1) 
            # At_bar_1 = torch.exp(delta_1 * A_cont_1) 
            
            At_ = Id - model.A_.T*delta

            reg = 1
            
            # C_1 = ((Bzt * delta))[None]*t_weights[:,None,None] 
            # RHS_ = C_1@torch.matrix_exp((model.A_.T*delta)[None] * t_nodes[:,None,None])
            
            
            PL_inv = torch.linalg.inv(In+reg*L_t)
            PL_prev = In*(1+reg)
            s  =  (In + PL_inv@PL_prev)
            A_s= PL_inv@(reg*del_L)

            PL_b = torch.block_diag(*PL_inv[None,:,:].repeat(batch_size,1,1))
            # A_st = torch.matrix_exp(s*t_nodes[:,None,None])
            A_sb = torch.block_diag(*A_s.unsqueeze(0).repeat(batch_size,1,1))
            # A_stb = torch.stack([torch.block_diag(*s) for s in A_st[:,None,:,:].repeat(1,batch_size,1,1)])
            # integral_1 = torch.sum(A_stb@PL_b[None,:,:]@RHS_,dim=0)

            updated_c_1=  -A_sb@ prev_c_1 +  prev_c_1 @ At_  +  PL_b @ Bzt

            u_2 =  zt+F.gelu(updated_c_1,approximate='tanh')# + zt approximate for faster results, maybe less accurate
            final_state = updated_c_1
            m1_vec[:,active] = updated_c_1.view(batch_size,3,-1)
            ne_vec[:,active] = final_state.view(batch_size,3,-1)



        hidden_state_list.append(ne_vec)
        logits = model.SeqClass(hidden_state_list[-1][:,-1]).view(-1)
        total_score += logits.tolist()
        total_samples += y.tolist()
        pred += (logits>0).to(int).tolist()

    auc = roc_auc_score(total_samples,total_score)
    best_acc = accuracy_score(total_samples,pred)
    test_vals.append(best_acc)
    print(f" Test Acc: {best_acc:.4f} |Test AUC: {auc:.4f}")

print(f'{100*np.mean(test_vals)} $pm$ {100*np.std(test_vals)}')


New Model Created
Static Embeddings: Nil


Epochs:   0%|          | 0/2000 [00:00<?, ?it/s]

Early stopping at epoch 653
 Test Acc: 0.4533 |Test AUC: 0.4985
New Model Created
Static Embeddings: Nil


Epochs:   0%|          | 0/2000 [00:00<?, ?it/s]

Early stopping at epoch 565
 Test Acc: 0.4867 |Test AUC: 0.4915
New Model Created
Static Embeddings: Nil


Epochs:   0%|          | 0/2000 [00:00<?, ?it/s]

Early stopping at epoch 588
 Test Acc: 0.4800 |Test AUC: 0.4960
New Model Created
Static Embeddings: Nil


Epochs:   0%|          | 0/2000 [00:00<?, ?it/s]

KeyboardInterrupt: 