### Step 1. 准备 Dataset + DataLoader

In [12]:
import torch
from torch.utils.data import Dataset

class Twitter15Dataset(Dataset):
    def __init__(self, graph_data_list):
        """
        Args:
            graph_data_list (list): list of graphs, each graph is a dict {'x', 'edge_index', 'y'}
        """
        self.graphs = graph_data_list

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        graph = self.graphs[idx]
        x = graph['x']  # (seq_len, feature_dim)
        y = graph['y']  # int64 label (0~3)

        return x, y

def collate_fn(batch):
    xs, ys = zip(*batch)

    max_len = max(x.shape[0] for x in xs)  # find max sequence length in batch
    feature_dim = xs[0].shape[1]

    padded_xs = []
    masks = []

    for x in xs:
        seq_len = x.shape[0]
        pad_len = max_len - seq_len

        if pad_len > 0:
            pad = torch.zeros((pad_len, feature_dim), dtype=x.dtype)
            x_padded = torch.cat([x, pad], dim=0)
        else:
            x_padded = x

        mask = torch.cat([torch.ones(seq_len), torch.zeros(pad_len)]).bool()

        padded_xs.append(x_padded)
        masks.append(mask)

    padded_xs = torch.stack(padded_xs)    # (batch_size, max_len, feature_dim)
    masks = torch.stack(masks)             # (batch_size, max_len)
    ys = torch.tensor(ys)                  # (batch_size,)

    return padded_xs, masks, ys

In [13]:
from sklearn.model_selection import train_test_split

# 加载你的清理版数据
graph_data_list = torch.load("../processed/twitter15_graph_data_clean.pt", weights_only=False)

# 划分Train/Val/Test (7:1.5:1.5)
train_graphs, temp_graphs = train_test_split(graph_data_list, test_size=0.3, random_state=42)
val_graphs, test_graphs = train_test_split(temp_graphs, test_size=0.5, random_state=42)

print(f"Train: {len(train_graphs)}, Val: {len(val_graphs)}, Test: {len(test_graphs)}")

Train: 1043, Val: 223, Test: 224


In [14]:
from torch.utils.data import DataLoader

batch_size = 16

# 建Dataset
train_dataset = Twitter15Dataset(train_graphs)
val_dataset = Twitter15Dataset(val_graphs)
test_dataset = Twitter15Dataset(test_graphs)

# 建Dataloader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

print("DataLoaders created successfully!")

DataLoaders created successfully!


### Step 2: Transformer

In [18]:
import torch.nn as nn

class TransformerClassifier(nn.Module):
    def __init__(self, feature_dim, hidden_dim, num_classes, num_heads=8, num_layers=2, dropout=0.1):
        super(TransformerClassifier, self).__init__()
        
        self.input_projection = nn.Linear(feature_dim, 768)  # 映射到768维
        
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=768,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers
        )
        
        self.fc = nn.Linear(768, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        src_key_padding_mask = ~mask
        x = self.input_projection(x)  # 加上这一行
        transformer_out = self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)
        pooled_output = transformer_out.mean(dim=1)

        output = self.dropout(pooled_output)
        logits = self.fc(output)

        return logits, transformer_out



### Step 3. 写训练和验证代码（Trainer）

In [21]:
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score

def train_one_epoch(model, train_loader, optimizer, loss_fn, device):
    model.train()
    running_loss = 0.0
    all_preds = []
    all_labels = []

    for x, mask, y in train_loader:
        x = x.to(device)
        mask = mask.to(device)
        y = y.to(device)

        x = torch.nan_to_num(x, nan=0.0)

        # 不要 sum mask！直接传 mask！
        logits, _ = model(x, mask)

        loss = loss_fn(logits, y)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item() * x.size(0)

        preds = logits.argmax(dim=-1)
        all_preds.extend(preds.detach().cpu().tolist())
        all_labels.extend(y.cpu().tolist())

    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = accuracy_score(all_labels, all_preds)
    epoch_f1 = f1_score(all_labels, all_preds, average='macro')

    return epoch_loss, epoch_acc, epoch_f1

def evaluate_one_epoch(model, val_loader, loss_fn, device):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for x, mask, y in val_loader:
            x = x.to(device)
            mask = mask.to(device)
            y = y.to(device)

            x = torch.nan_to_num(x, nan=0.0)

            logits, _ = model(x, mask)  # 注意这里直接传mask
            logits = torch.clamp(logits, min=-10, max=10)

            loss = loss_fn(logits, y)

            running_loss += loss.item() * x.size(0)

            preds = logits.argmax(dim=-1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(y.cpu().tolist())

    epoch_loss = running_loss / len(val_loader.dataset)
    epoch_acc = accuracy_score(all_labels, all_preds)
    epoch_f1 = f1_score(all_labels, all_preds, average='macro')

    return epoch_loss, epoch_acc, epoch_f1


class EarlyStopping:
    def __init__(self, patience=10, verbose=False, delta=0.0):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.best_f1 = -float('inf')

    def __call__(self, val_f1, model, save_path):
        score = val_f1

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_f1, model, save_path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_f1, model, save_path)
            self.counter = 0

    def save_checkpoint(self, val_f1, model, save_path):
        torch.save(model.state_dict(), save_path)
        self.best_f1 = val_f1

### Step 4: 配置超参数 + 启动训练循环 (Runner)

In [22]:
import os
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


save_path = os.path.abspath("../checkpoints/best_transformer_model.pt")
hidden_dim = 128
num_classes = 4
learning_rate = 5e-5
weight_decay = 1e-2
max_epochs = 1000
patience = 10

# 模型
feature_dim = 833 
model = TransformerClassifier(
    feature_dim=feature_dim,
    hidden_dim=256,     # 可以调大一点，transformer内部FFN隐藏层
    num_classes=4,
    num_heads=4,
    num_layers=2,
    dropout=0.1
).to(device)


# 优化器
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# 损失函数
def smooth_cross_entropy(preds, targets, smoothing=0.1):
    confidence = 1.0 - smoothing
    logprobs = F.log_softmax(preds, dim=-1)
    nll_loss = -logprobs.gather(dim=-1, index=targets.unsqueeze(1)).squeeze(1)
    smooth_loss = -logprobs.mean(dim=-1)
    loss = confidence * nll_loss + smoothing * smooth_loss
    return loss.mean()

loss_fn = smooth_cross_entropy

# Early Stopping
early_stopper = EarlyStopping(patience=patience, verbose=True)

# 开始训练
for epoch in range(1, max_epochs + 1):
    train_loss, train_acc, train_f1 = train_one_epoch(model, train_loader, optimizer, loss_fn, device)
    val_loss, val_acc, val_f1 = evaluate_one_epoch(model, val_loader, loss_fn, device)

    print(f"Epoch {epoch}:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train F1: {train_f1:.4f}")
    print(f"  Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.4f} | Val   F1: {val_f1:.4f}")

    early_stopper(val_f1, model, save_path)

    if early_stopper.early_stop:
        print("Early stopping triggered!")
        break

print("Training completed.")


  output = torch._nested_tensor_from_mask(


Epoch 1:
  Train Loss: 1.3873 | Train Acc: 0.3557 | Train F1: 0.3545
  Val   Loss: 1.3492 | Val   Acc: 0.2960 | Val   F1: 0.1640
Epoch 2:
  Train Loss: 1.1785 | Train Acc: 0.5014 | Train F1: 0.4958
  Val   Loss: 1.3182 | Val   Acc: 0.3004 | Val   F1: 0.1620
EarlyStopping counter: 1 out of 10
Epoch 3:
  Train Loss: 1.0932 | Train Acc: 0.5781 | Train F1: 0.5721
  Val   Loss: 1.2857 | Val   Acc: 0.4933 | Val   F1: 0.4708
Epoch 4:
  Train Loss: 0.9971 | Train Acc: 0.6520 | Train F1: 0.6484
  Val   Loss: 1.2808 | Val   Acc: 0.4170 | Val   F1: 0.3521
EarlyStopping counter: 1 out of 10
Epoch 5:
  Train Loss: 0.9067 | Train Acc: 0.6855 | Train F1: 0.6833
  Val   Loss: 1.3070 | Val   Acc: 0.3363 | Val   F1: 0.2311
EarlyStopping counter: 2 out of 10
Epoch 6:
  Train Loss: 0.8266 | Train Acc: 0.7555 | Train F1: 0.7541
  Val   Loss: 1.2557 | Val   Acc: 0.4126 | Val   F1: 0.3503
EarlyStopping counter: 3 out of 10
Epoch 7:
  Train Loss: 0.7364 | Train Acc: 0.8102 | Train F1: 0.8087
  Val   Loss: 1.2

In [23]:
# 加载训练好的最优模型
model.load_state_dict(torch.load(save_path))
model.to(device)
model.eval()

# Test阶段
def test_model(model, test_loader, loss_fn, device):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for x, mask, y in test_loader:
            x = x.to(device)
            mask = mask.to(device)
            y = y.to(device)

            x = torch.nan_to_num(x, nan=0.0)

            lengths = mask.sum(dim=1).cpu()
            logits, _ = model(x, lengths)
            logits = torch.clamp(logits, min=-10, max=10)

            loss = loss_fn(logits, y)
            running_loss += loss.item() * x.size(0)

            preds = logits.argmax(dim=-1)
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(y.cpu().tolist())

    test_loss = running_loss / len(test_loader.dataset)
    test_acc = accuracy_score(all_labels, all_preds)
    test_f1 = f1_score(all_labels, all_preds, average='macro')

    return test_loss, test_acc, test_f1

# 调用测试函数
test_loss, test_acc, test_f1 = test_model(model, test_loader, loss_fn, device)

print("=== Final Test Results ===")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.4f}")
print(f"Test F1 Score: {test_f1:.4f}")


AssertionError: only bool and floating types of src_key_padding_mask are supported