In [10]:
from scipy.io import loadmat

mat_data = loadmat('mvmd_20_trials_N500_K08.mat')

# Extract variables (MATLAB adds '_all' suffix)
u_time = mat_data['u']          # Shape: (3, 8, 4000, 306)
u_hat_freq = mat_data['u_hat']  # Shape: (3, 8, 4000, 306)
omega = mat_data['omega']       # Shape: (3, 8, 306)

# Verify shapes
print(f"Time modes shape: {u_time.shape}")
print(f"Freq modes shape: {u_hat_freq.shape}")

Time modes shape: (20, 8, 4000, 306)
Freq modes shape: (20, 8, 4000, 306)


In [22]:
import numpy as np
import torch
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

def create_early_fusion_features(u_time, u_hat_freq):
    """
    Args:
        u_time: Time-domain modes [trials, modes, timepoints, channels]
        u_hat_freq: Frequency-domain modes [trials, modes, freqs, channels]
    
    Returns:
        Combined features [trials, 16 modes, timepoints, 306 channels]
    """
    # 1. Process time-domain features
    time_features = np.abs(u_time)  # Magnitude of time modes
    
    # 2. Process frequency-domain features
    # Average power across frequency bins (keep mode and channel)
    freq_power = np.mean(np.abs(u_hat_freq), axis=2)  # [trials, modes, channels]
    
    # Replicate frequency features across all timepoints
    freq_features = np.repeat(freq_power[:, :, np.newaxis, :], 
                             u_time.shape[2], axis=2)  # [trials, modes, time, channels]
    
    # 3. Concatenate time and frequency modes
    combined = np.concatenate([time_features, freq_features], axis=1)
    
    return combined

# -------------------------------------------------
# Load your data (replace with actual loading code)
# -------------------------------------------------
# Assuming:
# - u_time shape: [20 trials, 8 modes, 4000 timepoints, 306 channels]
# - u_hat_freq shape: [20 trials, 8 modes, 4000 freqs, 306 channels]


# Create fused features
X = create_early_fusion_features(u_time, u_hat_freq)  # [200, 16, 4000, 306]
print(X.shape)
# Create labels (4 tasks, 3 trials each)
y = np.array([1, 1, 1, 1, 2, 4, 4, 4, 3, 1, 1, 2, 2, 2, 4, 2, 3, 4, 4, 2])
y = y-1
# Train-test split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

# Channel-wise normalization
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train.reshape(-1, 306)).reshape(X_train.shape)
X_test = scaler.transform(X_test.reshape(-1, 306)).reshape(X_test.shape)

# Convert to PyTorch tensors
train_data = torch.utils.data.TensorDataset(
    torch.FloatTensor(X_train).permute(0, 3, 2, 1),  # [batch, 306, 4000, 8]
    torch.LongTensor(y_train))
test_data = torch.utils.data.TensorDataset(
    torch.FloatTensor(X_test).permute(0, 3, 2, 1),
    torch.LongTensor(y_test))

batch_size = 16
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size)

(20, 16, 4000, 306)


In [23]:
import torch.nn as nn
import torch.nn.functional as F

class EarlyFusionLTC(nn.Module):
    def __init__(self, input_channels=306, num_modes=16, num_classes=4):
        super().__init__()
        
        # 1. Spatial Attention (for 306 channels)
        self.channel_att = nn.Sequential(
            nn.Conv1d(input_channels, 64, 1),
            nn.ReLU(),
            nn.Conv1d(64, input_channels, 1),
            nn.Sigmoid()
        )
        
        # 2. Liquid Time-Constant Network
        self.ltc = LiquidBlock(
            input_dim=num_modes,
            hidden_dim=128,
            time_steps=4000
        )
        
        # 3. Temporal Attention
        self.temp_att = nn.MultiheadAttention(128, num_heads=4, batch_first=True)
        
        # 4. Classifier
        self.classifier = nn.Sequential(
            nn.Linear(128, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        # x shape: [batch, 306 channels, 4000 time, 8 modes]
        batch_size = x.size(0)
        
        # 1. Spatial Attention
        spatial_weights = self.channel_att(x.mean(dim=2))  # [batch, 306, 8]
        x = x * spatial_weights.unsqueeze(2)  # [batch, 306, 4000, 8]
        
        # 2. Liquid Processing
        x = x.permute(0, 3, 2, 1)  # [batch, 8, 4000, 306]
        ltc_out = self.ltc(x)  # [batch, 4000, 128]
        
        # 3. Temporal Attention
        attn_out, _ = self.temp_att(
            ltc_out, ltc_out, ltc_out
        )  # [batch, 4000, 128]
        
        # 4. Pooling & Classification
        pooled = attn_out.mean(dim=1)  # [batch, 128]
        return self.classifier(pooled)

class LiquidBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim, time_steps):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.time_steps = time_steps
        
        # Liquid parameters
        self.W = nn.Linear(input_dim, hidden_dim)
        self.A = nn.Linear(hidden_dim, hidden_dim)
        self.time_emb = nn.Embedding(time_steps, hidden_dim)
        
        # Adaptive time constants
        self.tau = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        print(f"\n[Input to LiquidBlock] shape: {x.shape}")
        
        # x shape: [batch, modes, time, channels]
        batch_size, num_modes, seq_len, num_channels = x.size()
        print(f"Unpacked dimensions - batch: {batch_size}, modes: {num_modes}, time: {seq_len}, channels: {num_channels}")
        
        # Average across channels and reshape
        x = x.mean(dim=3)  # [batch, modes, time]
        print(f"\nAfter mean(dim=3): {x.shape}")
        
        x = x.permute(0, 2, 1)  # [batch, time, modes]
        print(f"After permute(0,2,1): {x.shape}")
        
        # Check if the input dimension matches
        if x.size(-1) != self.input_dim:
            raise ValueError(f"Input dimension mismatch. Expected {self.input_dim}, got {x.size(-1)}")
        
        # Process all time steps at once
        x_flat = x.reshape(-1, self.input_dim)  # [batch*time, modes]
        print(f"\nAfter reshape to [batch*time, modes]: {x_flat.shape}")
        
        x_t = self.W(x_flat)  # [batch*time, hidden_dim]
        print(f"After linear transformation W: {x_t.shape}")
        
        try:
            x_t = x_t.view(batch_size, seq_len, self.hidden_dim)  # [batch, time, hidden]
            print(f"After reshape to [batch, time, hidden]: {x_t.shape}")
        except RuntimeError as e:
            print(f"\nERROR in reshape!")
            print(f"Attempting to reshape {x_t.shape} to [{batch_size}, {seq_len}, {self.hidden_dim}]")
            print(f"Total elements before: {x_t.numel()}")
            print(f"Total elements after: {batch_size * seq_len * self.hidden_dim}")
            raise e
        
        h = torch.zeros(batch_size, self.hidden_dim).to(x.device)
        outputs = []
        
        for t in range(seq_len):
            # Time embedding
            time_idx = torch.tensor([t]*batch_size).to(x.device)
            t_emb = self.time_emb(time_idx)
            
            # Liquid dynamics
            dh = torch.sigmoid(
                self.A(h) + x_t[:, t, :] + t_emb
            )
            tau = self.tau(h)
            h = h * (1 - tau) + dh * tau
            
            outputs.append(h)
        
        output_tensor = torch.stack(outputs, dim=1)  # [batch, time, hidden]
        print(f"\nFinal output shape: {output_tensor.shape}")
        return output_tensor

In [None]:
def train(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    correct = 0
    
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
    
    acc = 100. * correct / len(loader.dataset)
    return total_loss / len(loader), acc


def evaluate(model, loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            total_loss += criterion(output, target).item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
    
    acc = 100. * correct / len(loader.dataset)
    return total_loss / len(loader), acc


# Initialize
device = torch.device('cpu')
model = EarlyFusionLTC().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=3)

# Training loop
for epoch in range(100):
    train_loss, train_acc = train(model, train_loader, optimizer, criterion)
    val_loss, val_acc = evaluate(model, test_loader, criterion)
    
    scheduler.step(val_acc)
    
    print(f"Epoch {epoch+1:03d} | "
          f"Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% | "
          f"Val Loss: {val_loss:.4f} Acc: {val_acc:.2f}%")


[Input to LiquidBlock] shape: torch.Size([16, 16, 4000, 306])
Unpacked dimensions - batch: 16, modes: 16, time: 4000, channels: 306

After mean(dim=3): torch.Size([16, 16, 4000])
After permute(0,2,1): torch.Size([16, 4000, 16])

After reshape to [batch*time, modes]: torch.Size([64000, 16])
After linear transformation W: torch.Size([64000, 128])
After reshape to [batch, time, hidden]: torch.Size([16, 4000, 128])

Final output shape: torch.Size([16, 4000, 128])

[Input to LiquidBlock] shape: torch.Size([4, 16, 4000, 306])
Unpacked dimensions - batch: 4, modes: 16, time: 4000, channels: 306

After mean(dim=3): torch.Size([4, 16, 4000])
After permute(0,2,1): torch.Size([4, 4000, 16])

After reshape to [batch*time, modes]: torch.Size([16000, 16])
After linear transformation W: torch.Size([16000, 128])
After reshape to [batch, time, hidden]: torch.Size([4, 4000, 128])

Final output shape: torch.Size([4, 4000, 128])
Epoch 001 | Train Loss: 1.4104 Acc: 18.75% | Val Loss: 1.4016 Acc: 25.00%

[