### Step 1: Prepare Dataset + DataLoader

In [1]:
import torch
from torch.utils.data import Dataset

# Define a custom Dataset for the Twitter15 graph-based data
class Twitter15Dataset(Dataset):
    """Map-style dataset over a list of Twitter15 graph dictionaries.

    Only the 'x' (features) and 'y' (label) entries of each graph are exposed;
    any other key stored in the dictionaries is ignored by this class.
    """

    def __init__(self, graph_data_list):
        """
        Args:
            graph_data_list (list): One dict per graph, with at least the keys
                                    'x' and 'y' (an 'edge_index' key may also be present).
        """
        # The list is kept by reference, not copied
        self.graphs = graph_data_list

    def __len__(self):
        """Number of graphs held by the dataset."""
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return the `(x, y)` pair of the graph stored at position `idx`."""
        sample = self.graphs[idx]
        return sample['x'], sample['y']


# Custom collate function for dynamic batching and padding
def collate_fn(batch):
    """Pad a batch of variable-length sequences and build the matching validity mask.

    Args:
        batch: Iterable of `(x, y)` pairs, where `x` is a 2-D tensor of shape
               (seq_len, feature_dim) and `y` is the sample's label.

    Returns:
        padded_xs: Tensor of shape (batch_size, max_len, feature_dim), zero-padded at the end.
        masks:     Bool tensor of shape (batch_size, max_len); True marks a real token,
                   False marks padding.
        ys:        Tensor of shape (batch_size,) built from the labels.
    """
    # Unpack the batch into features (xs) and labels (ys)
    xs, ys = zip(*batch)

    # pad_sequence allocates the padded tensor from the inputs themselves, so the
    # result keeps their dtype and device instead of relying on the default device.
    padded_xs = torch.nn.utils.rnn.pad_sequence(list(xs), batch_first=True, padding_value=0.0)

    # Position j of sample i is valid iff j < seq_len_i
    lengths = torch.tensor([x.shape[0] for x in xs])
    max_len = padded_xs.shape[1]
    masks = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)

    ys = torch.tensor(ys)  # Shape: (batch_size,)

    return padded_xs, masks, ys


In [2]:
from sklearn.model_selection import train_test_split

# Load the preprocessed list of graphs (each graph is a dictionary).
# NOTE(review): weights_only=False makes torch.load unpickle arbitrary Python objects,
# so this file must come from a trusted source.
graph_data_list = torch.load("../processed/twitter15_graph_data_clean.pt", weights_only=False)

# Split the dataset into Train / Validation / Test sets with a 70:15:15 ratio
# First split: 70% training, 30% temp (to be further split)
train_graphs, temp_graphs = train_test_split(graph_data_list, test_size=0.3, random_state=42)

# Second split: the 30% temp set is halved into 15% validation and 15% test.
# Both splits use a fixed random_state so the partition is reproducible.
val_graphs, test_graphs = train_test_split(temp_graphs, test_size=0.5, random_state=42)

# Print the number of samples in each split
print(f"Train: {len(train_graphs)}, Val: {len(val_graphs)}, Test: {len(test_graphs)}")


Train: 1043, Val: 223, Test: 224


In [3]:
from torch.utils.data import DataLoader

batch_size = 16  # Number of graphs per mini-batch, shared by all three loaders

# Wrap each split in a Twitter15Dataset
train_dataset, val_dataset, test_dataset = (
    Twitter15Dataset(split) for split in (train_graphs, val_graphs, test_graphs)
)

# Every loader pads and masks its batches through collate_fn.
# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

print("DataLoaders created successfully!")


DataLoaders created successfully!


### Step 2: MambaEncoder + Pooling + MLP

In [4]:
import torch
import torch.nn as nn
from mamba_ssm import Mamba  # Import the Mamba sequential state model

# A simple classification head using two linear layers and ReLU activation
class ClassifierHead(nn.Module):
    """Two-layer MLP head: Linear -> ReLU -> Linear, producing class logits."""

    def __init__(self, hidden_dim=256, num_classes=4):
        super().__init__()
        hidden_layer = nn.Linear(hidden_dim, hidden_dim)   # Keeps the feature width
        output_layer = nn.Linear(hidden_dim, num_classes)  # Maps features to class logits
        # Layers stay at indices 0 and 2 of `fc`, so parameter names are fc.0.* and fc.2.*
        self.fc = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        """Map features of shape (batch_size, hidden_dim) to logits (batch_size, num_classes)."""
        logits = self.fc(x)
        return logits


# Mamba-based encoder that applies multiple stacked Mamba layers
class MambaEncoder(nn.Module):
    """Sequence encoder: linear input projection, stacked Mamba+Dropout blocks, final LayerNorm."""

    def __init__(self, input_dim=833, hidden_dim=128, num_layers=2, dropout_rate=0.2):
        super().__init__()
        # Bring the raw features to the width the Mamba layers operate on
        self.input_proj = nn.Linear(input_dim, hidden_dim)

        # One (Mamba -> Dropout) block per requested layer
        blocks = []
        for _ in range(num_layers):
            blocks.append(
                nn.Sequential(
                    Mamba(d_model=hidden_dim),
                    nn.Dropout(dropout_rate),
                )
            )
        self.layers = nn.ModuleList(blocks)

        # Normalization applied once, after the last block
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x, mask):
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, input_dim)
            mask: Bool tensor of shape (batch_size, seq_len) indicating valid (non-padded) tokens.
                  It is accepted for interface symmetry but is not used by this method.
        Returns:
            Tensor of shape (batch_size, seq_len, hidden_dim)
        """
        hidden = self.input_proj(x)

        # Blocks are applied back to back, with no residual connection between them
        for block in self.layers:
            hidden = block(hidden)

        return self.norm(hidden)


# Applies masked mean pooling across sequence length
class MeanPooling(nn.Module):
    """Masked mean pooling over the sequence dimension."""

    def __init__(self):
        super().__init__()

    def forward(self, x, mask):
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, hidden_dim)
            mask: Tensor of shape (batch_size, seq_len); truthy entries mark valid tokens
        Returns:
            pooled: Tensor of shape (batch_size, hidden_dim). Rows whose mask is entirely
                    False pool to zeros.
        """
        valid = mask.bool().unsqueeze(-1)  # (batch_size, seq_len, 1) for broadcasting

        # masked_fill (rather than x * mask) so that NaN/Inf values sitting at padded
        # positions are really discarded: NaN * 0 is still NaN.
        x = x.masked_fill(~valid, 0.0)

        sum_x = x.sum(dim=1)                     # Sum over sequence length
        lengths = valid.sum(dim=1).to(x.dtype)   # Number of valid tokens per sample

        # Counts are whole numbers, so clamping to 1 only affects fully padded rows
        # and protects them from a divide-by-zero.
        pooled = sum_x / lengths.clamp(min=1.0)
        return pooled  # Shape: (batch_size, hidden_dim)


# Full classification model: normalization → encoder → pooling → classification
class MambaClassifier(nn.Module):
    """End-to-end classifier: input LayerNorm -> MambaEncoder -> MeanPooling -> ClassifierHead."""

    def __init__(self, input_dim, hidden_dim, num_layers, num_classes):
        super().__init__()
        # Sub-modules are created in the order in which forward() applies them
        self.norm = nn.LayerNorm(input_dim)
        self.encoder = MambaEncoder(input_dim, hidden_dim, num_layers)
        self.pooling = MeanPooling()
        self.classifier = ClassifierHead(hidden_dim, num_classes)

    def forward(self, x, mask):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, input_dim)
            mask: Bool tensor of shape (batch_size, seq_len)
        Returns:
            logits: Tensor of shape (batch_size, num_classes), limited to the range [-10, 10]
        """
        encoded = self.encoder(self.norm(x), mask)              # Per-token representations
        logits = self.classifier(self.pooling(encoded, mask))   # One logit vector per sequence
        # Bound the logits so extreme values cannot reach the loss
        return logits.clamp(min=-10, max=10)


### Step 3: Trainer

In [11]:
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score

def train_one_epoch(model, train_loader, optimizer, loss_fn, device):
    """Run one training pass over `train_loader` and report loss, accuracy and macro-F1.

    NOTE: this notebook defines `train_one_epoch` again in the Step 4 cell; when the
    cells run in order, that later definition replaces this one.

    Args:
        model: Module called as `model(x, mask)` and returning logits.
        train_loader: Iterable yielding `(x, mask, y)` batches.
        optimizer: Optimizer stepped once per accepted batch.
        loss_fn: Callable `loss_fn(logits, y)` returning a scalar loss tensor.
        device: Device the batch tensors are moved to.

    Returns:
        (epoch_loss, epoch_acc, epoch_f1), computed over the samples of the batches that
        were actually used. If every batch was skipped, returns (nan, 0.0, 0.0).
    """
    model.train()
    running_loss = 0.0
    num_seen = 0  # Samples that contributed to the loss / metrics
    all_preds = []
    all_labels = []

    for x, mask, y in train_loader:
        x = x.to(device)
        mask = mask.to(device)
        y = y.to(device)

        # Make sure there is no NaN in the input features
        x = torch.nan_to_num(x, nan=0.0)

        optimizer.zero_grad()
        logits = model(x, mask)

        # ===== Check that the logits are well-formed (no NaN / Inf) =====
        finite_mask = torch.isfinite(logits)
        if not finite_mask.all():
            print("⚠️ Problematic batch detected!")
            print(f"x shape: {x.shape}")
            print(f"mask sum: {mask.sum(dim=1)}")  # Number of valid nodes per sequence
            print(f"y: {y}")
            # torch has no nanmax/nanmin, so report the range of the finite entries only
            finite_logits = logits[finite_mask]
            if finite_logits.numel() > 0:
                print(f"logits max: {finite_logits.max()}, min: {finite_logits.min()}")
            else:
                print("logits max: nan, min: nan")
            continue  # Skip this batch

        loss = loss_fn(logits, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_count = x.size(0)
        running_loss += loss.item() * batch_count
        num_seen += batch_count

        preds = logits.argmax(dim=-1)
        all_preds.extend(preds.detach().cpu().tolist())
        all_labels.extend(y.cpu().tolist())

    if num_seen == 0:
        # Nothing was trained on (empty loader or every batch skipped)
        return float('nan'), 0.0, 0.0

    # Average over the samples actually used, so skipped batches do not deflate the loss
    epoch_loss = running_loss / num_seen
    epoch_acc = accuracy_score(all_labels, all_preds)
    epoch_f1 = f1_score(all_labels, all_preds, average='macro')

    return epoch_loss, epoch_acc, epoch_f1

def evaluate_one_epoch(model, val_loader, loss_fn, device):
    """Evaluate `model` on `val_loader` without gradient tracking.

    NOTE: this notebook defines `evaluate_one_epoch` again in the Step 4 cell; when the
    cells run in order, that later definition replaces this one.

    Args:
        model: Module called as `model(x, mask)` and returning logits.
        val_loader: Iterable yielding `(x, mask, y)` batches.
        loss_fn: Callable `loss_fn(logits, y)` returning a scalar loss tensor.
        device: Device the batch tensors are moved to.

    Returns:
        (epoch_loss, epoch_acc, epoch_f1), computed over the samples of the batches that
        were actually evaluated. If every batch was skipped, returns (nan, 0.0, 0.0).
    """
    model.eval()
    running_loss = 0.0
    num_seen = 0      # Samples that contributed to the loss / metrics
    num_skipped = 0   # Batches dropped because of non-finite logits
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for x, mask, y in val_loader:
            x = x.to(device)
            mask = mask.to(device)
            y = y.to(device)

            # Same input sanitization as in training, so a NaN feature does not
            # invalidate the whole batch at evaluation time
            x = torch.nan_to_num(x, nan=0.0)

            logits = model(x, mask)
            logits = torch.clamp(logits, min=-10, max=10)

            if not torch.isfinite(logits).all():
                num_skipped += 1
                continue  # Protect the metrics during validation as well

            loss = loss_fn(logits, y)

            batch_count = x.size(0)
            running_loss += loss.item() * batch_count
            num_seen += batch_count

            preds = logits.argmax(dim=-1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(y.cpu().tolist())

    if num_skipped > 0:
        # Make dropped batches visible instead of silently shrinking the evaluation set
        print(f"⚠️ Skipped {num_skipped} batch(es) with non-finite logits during evaluation")

    if num_seen == 0:
        return float('nan'), 0.0, 0.0

    # Average over the samples actually evaluated
    epoch_loss = running_loss / num_seen
    epoch_acc = accuracy_score(all_labels, all_preds)
    epoch_f1 = f1_score(all_labels, all_preds, average='macro')

    return epoch_loss, epoch_acc, epoch_f1

class EarlyStopping:
    """Early stopping on a higher-is-better validation score (validation F1).

    NOTE: this notebook defines `EarlyStopping` again in the Step 4 cell; when the
    cells run in order, that later definition replaces this one.
    """

    def __init__(self, patience=10, verbose=False, delta=0.0):
        """
        Args:
            patience (int): How many epochs to wait after last improvement
            verbose (bool): Print message when early stopping
            delta (float): Minimum change to qualify as improvement
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.counter = 0            # Consecutive calls without improvement
        self.best_score = None      # Best score seen so far (None until the first valid score)
        self.early_stop = False     # Set to True once `patience` is exhausted
        self.best_f1 = -float('inf')

    def __call__(self, val_f1, model):
        """Record one epoch's validation F1 and update the stopping state.

        Args:
            val_f1: Validation score for the epoch; higher is better.
            model: Accepted for interface compatibility; not used by this method.
        """
        score = val_f1

        # NaN compares False against everything, so without this guard it would be
        # stored as the best score and no later epoch could ever count as "worse".
        is_nan = score != score
        improved = (not is_nan) and (
            self.best_score is None or score >= self.best_score + self.delta
        )

        if improved:
            self.best_score = score
            self.best_f1 = val_f1
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True
    

    # def __call__(self, val_f1, model, save_path):
    #     score = val_f1

    #     if self.best_score is None:
    #         self.best_score = score
    #         self.save_checkpoint(val_f1, model, save_path)
    #     elif score < self.best_score + self.delta:
    #         self.counter += 1
    #         if self.verbose:
    #             print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
    #         if self.counter >= self.patience:
    #             self.early_stop = True
    #     else:
    #         self.best_score = score
    #         self.save_checkpoint(val_f1, model, save_path)
    #         self.counter = 0

    # def save_checkpoint(self, val_f1, model, save_path):
    #     """Save model when val_f1 improves."""
    #     torch.save(model.state_dict(), save_path)
    #     self.best_f1 = val_f1

### Step 4: Runner

In [13]:
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score

# Train model for one epoch
def train_one_epoch(model, train_loader, optimizer, loss_fn, device):
    """Run one training pass over `train_loader` and report loss, accuracy and macro-F1.

    Args:
        model: Module called as `model(x, mask)` and returning logits.
        train_loader: Iterable yielding `(x, mask, y)` batches.
        optimizer: Optimizer stepped once per accepted batch.
        loss_fn: Callable `loss_fn(logits, y)` returning a scalar loss tensor.
        device: Device the batch tensors are moved to.

    Returns:
        (epoch_loss, epoch_acc, epoch_f1), computed over the samples of the batches that
        were actually used. If every batch was skipped, returns (nan, 0.0, 0.0).
    """
    model.train()  # Set model to training mode
    running_loss = 0.0
    num_seen = 0   # Samples that contributed to the loss / metrics
    all_preds = []
    all_labels = []

    for x, mask, y in train_loader:
        x = x.to(device)
        mask = mask.to(device)
        y = y.to(device)

        # Replace any NaNs in the input with 0
        x = torch.nan_to_num(x, nan=0.0)

        optimizer.zero_grad()           # Clear gradients
        logits = model(x, mask)         # Forward pass

        # ===== Check for invalid logits (NaN or Inf) =====
        finite_mask = torch.isfinite(logits)
        if not finite_mask.all():
            print("⚠️ Problematic batch detected!")
            print(f"x shape: {x.shape}")
            print(f"mask sum: {mask.sum(dim=1)}")  # Number of valid nodes per graph
            print(f"y: {y}")
            # torch has no nanmax/nanmin, so report the range of the finite entries only
            finite_logits = logits[finite_mask]
            if finite_logits.numel() > 0:
                print(f"logits max: {finite_logits.max()}, min: {finite_logits.min()}")
            else:
                print("logits max: nan, min: nan")
            continue  # Skip this batch

        loss = loss_fn(logits, y)       # Compute loss
        loss.backward()                 # Backpropagation
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Clip gradients
        optimizer.step()                # Update weights

        batch_count = x.size(0)
        running_loss += loss.item() * batch_count  # Track total loss
        num_seen += batch_count

        preds = logits.argmax(dim=-1)  # Get predicted classes
        all_preds.extend(preds.detach().cpu().tolist())
        all_labels.extend(y.cpu().tolist())

    if num_seen == 0:
        # Nothing was trained on (empty loader or every batch skipped)
        return float('nan'), 0.0, 0.0

    # Average over the samples actually used, so skipped batches do not deflate the loss
    epoch_loss = running_loss / num_seen
    epoch_acc = accuracy_score(all_labels, all_preds)
    epoch_f1 = f1_score(all_labels, all_preds, average='macro')

    return epoch_loss, epoch_acc, epoch_f1


# Evaluate model on validation/test set
def evaluate_one_epoch(model, val_loader, loss_fn, device):
    """Evaluate `model` on `val_loader` without gradient tracking.

    Args:
        model: Module called as `model(x, mask)` and returning logits.
        val_loader: Iterable yielding `(x, mask, y)` batches.
        loss_fn: Callable `loss_fn(logits, y)` returning a scalar loss tensor.
        device: Device the batch tensors are moved to.

    Returns:
        (epoch_loss, epoch_acc, epoch_f1), computed over the samples of the batches that
        were actually evaluated. If every batch was skipped, returns (nan, 0.0, 0.0).
    """
    model.eval()  # Set model to evaluation mode
    running_loss = 0.0
    num_seen = 0      # Samples that contributed to the loss / metrics
    num_skipped = 0   # Batches dropped because of non-finite logits
    all_preds = []
    all_labels = []

    with torch.no_grad():  # Disable gradient computation
        for x, mask, y in val_loader:
            x = x.to(device)
            mask = mask.to(device)
            y = y.to(device)

            # Same input sanitization as in training, so a NaN feature does not
            # invalidate the whole batch at evaluation time
            x = torch.nan_to_num(x, nan=0.0)

            logits = model(x, mask)  # Forward pass
            logits = torch.clamp(logits, min=-10, max=10)  # Clip to prevent instability

            if not torch.isfinite(logits).all():
                num_skipped += 1
                continue  # Skip bad batches during validation

            loss = loss_fn(logits, y)

            batch_count = x.size(0)
            running_loss += loss.item() * batch_count
            num_seen += batch_count

            preds = logits.argmax(dim=-1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(y.cpu().tolist())

    if num_skipped > 0:
        # Make dropped batches visible instead of silently shrinking the evaluation set
        print(f"⚠️ Skipped {num_skipped} batch(es) with non-finite logits during evaluation")

    if num_seen == 0:
        return float('nan'), 0.0, 0.0

    # Average over the samples actually evaluated
    epoch_loss = running_loss / num_seen
    epoch_acc = accuracy_score(all_labels, all_preds)
    epoch_f1 = f1_score(all_labels, all_preds, average='macro')

    return epoch_loss, epoch_acc, epoch_f1


# Simple early stopping implementation based on validation F1
class EarlyStopping:
    """Early stopping on a higher-is-better validation score (validation F1)."""

    def __init__(self, patience=10, verbose=False, delta=0.0):
        """
        Args:
            patience (int): Number of epochs to wait after last improvement.
            verbose (bool): Whether to print messages.
            delta (float): Minimum improvement to reset patience.
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.counter = 0            # Consecutive calls without improvement
        self.best_score = None      # Best score seen so far (None until the first valid score)
        self.early_stop = False     # Set to True once `patience` is exhausted
        self.best_f1 = -float('inf')

    def __call__(self, val_f1, model):
        """Record one epoch's validation F1 and update the stopping state.

        Args:
            val_f1: Validation score for the epoch; higher is better.
            model: Accepted for interface compatibility; not used by this method.
        """
        score = val_f1

        # NaN compares False against everything, so without this guard it would be
        # stored as the best score and no later epoch could ever count as "worse".
        is_nan = score != score
        improved = (not is_nan) and (
            self.best_score is None or score >= self.best_score + self.delta
        )

        if improved:
            self.best_score = score
            self.best_f1 = val_f1
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True



==== Trying lr=0.0001, weight_decay=0.0 ====
Epoch 1:
  Train Loss: 1.3818 | Train Acc: 0.2953 | Train F1: 0.1998
  Val   Loss: nan | Val   Acc: 0.2960 | Val   F1: 0.2252
Epoch 2:
  Train Loss: 1.3063 | Train Acc: 0.4228 | Train F1: 0.4094
  Val   Loss: nan | Val   Acc: 0.3946 | Val   F1: 0.3946
Epoch 3:
  Train Loss: 1.1614 | Train Acc: 0.5158 | Train F1: 0.5096
  Val   Loss: nan | Val   Acc: 0.4753 | Val   F1: 0.4767
Epoch 4:
  Train Loss: 1.0618 | Train Acc: 0.6079 | Train F1: 0.6014
  Val   Loss: nan | Val   Acc: 0.5157 | Val   F1: 0.4969
Epoch 5:
  Train Loss: 0.9758 | Train Acc: 0.6673 | Train F1: 0.6620
  Val   Loss: nan | Val   Acc: 0.4843 | Val   F1: 0.4588
Epoch 6:
  Train Loss: 0.9151 | Train Acc: 0.6961 | Train F1: 0.6899
  Val   Loss: nan | Val   Acc: 0.5695 | Val   F1: 0.5672
Epoch 7:
  Train Loss: 0.8432 | Train Acc: 0.7536 | Train F1: 0.7503
  Val   Loss: nan | Val   Acc: 0.5291 | Val   F1: 0.5307
Epoch 8:
  Train Loss: 0.7967 | Train Acc: 0.7632 | Train F1: 0.7604
  V

### Step 5: Test Evaluation

In [14]:
import torch
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

# NOTE: this cell relies on `model` and `device` created by the training cell, and on
# `test_loader` from Step 1; run those cells first.

# 1. Load the best saved model
# map_location lets a checkpoint saved on one device be restored on whatever `device` is;
# weights_only=True restricts unpickling to tensors/containers, which is all a state_dict needs.
model.load_state_dict(
    torch.load("../checkpoints/best_model.pt", map_location=device, weights_only=True)
)
model.to(device)
model.eval()

# 2. Evaluate on the Test Set
all_preds = []
all_labels = []

with torch.no_grad():  # Disable gradient calculation for evaluation
    for x, mask, y in test_loader:
        x = x.to(device)
        mask = mask.to(device)
        y = y.to(device)

        # Apply the same NaN sanitization as the training loop, so a NaN feature
        # cannot turn a sample's logits (and therefore its prediction) into NaN
        x = torch.nan_to_num(x, nan=0.0)

        logits = model(x, mask)            # Get model predictions (logits)
        preds = logits.argmax(dim=-1)      # Convert logits to predicted class indices

        all_preds.extend(preds.cpu().tolist())  # Collect predictions
        all_labels.extend(y.cpu().tolist())     # Collect ground truth labels

# 3. Compute accuracy and macro F1-score
test_acc = accuracy_score(all_labels, all_preds)
test_f1 = f1_score(all_labels, all_preds, average='macro')

print(f" Test Accuracy: {test_acc:.4f}")
print(f" Test Macro-F1: {test_f1:.4f}")

# Print confusion matrix (rows: true labels, columns: predicted labels)
cm = confusion_matrix(all_labels, all_preds)
print("Confusion Matrix:")
print(cm)


 Test Accuracy: 0.5938
 Test Macro-F1: 0.5902
Confusion Matrix:
[[43  3  7  4]
 [12 25 11  8]
 [ 9  4 31  9]
 [18  2  4 34]]
