In [1]:
import os
import sys
import typing
import pickle
import functools
import networkx as nx
from tqdm.auto import tqdm

import numpy as np
import pandas as pd

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

from sklearn.model_selection import train_test_split

from preprocess import preprocess_dict

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
X_train = pd.read_pickle('data/X_train.pickle')
y_train = pd.read_pickle('data/y_train.pickle')

file_path = 'data/X_train_processed.npz'
if os.path.exists(file_path):
    data = np.load(file_path, allow_pickle=True)
    X_train_processed = {key: data[key] for key in data}
else:
    X_train_processed = preprocess_dict(X_train, n_workers=32)
    np.savez_compressed(file_path, **X_train_processed)

print(len(X_train), len(X_train_processed), len(y_train))

23500 23500 23500


In [3]:
print(X_train_processed['00000'].shape)
edge_features = torch.tensor(X_train_processed['00000'], dtype=torch.float32)
print(edge_features.shape)

(90, 3, 1000)
torch.Size([90, 3, 1000])


In [4]:
X_sample = pd.read_csv('data/A.csv', index_col=0)
y_sample = pd.read_csv('data/B.csv', index_col=0)
print(y_sample)

        0  Y  2  3  4  5  6  7  8  X
parent                              
0       0  0  0  0  0  0  0  0  0  0
Y       1  0  0  0  0  0  0  0  0  0
2       0  0  0  0  0  1  0  0  0  0
3       0  1  0  0  0  1  0  0  0  0
4       0  1  0  0  0  1  0  0  0  0
5       1  1  0  0  0  0  0  0  0  0
6       0  1  0  0  0  1  0  0  0  0
7       0  1  0  0  0  1  0  0  0  0
8       0  0  1  1  0  0  0  1  0  1
X       0  1  0  0  0  1  0  0  0  0


In [3]:
# Utils for DAG
def graph_nodes_representation(graph, nodelist):
    """
    Create an alternative representation of a graph which is hashable
    and equivalent graphs have the same hash.

    Python cannot PROPERLY use nx.Graph/DiGraph as key for
    dictionaries, because two equivalent graphs with just different
    order of the nodes would result in different keys. This is
    undesirable here.

    So here we transform the graph into an equivalent form that is
    based on a specific nodelist and that is hashable. In this way,
    two equivalent graphs, once transformed, will result in identical
    keys.

    So we use the following trick: extract the adjacency matrix
    (with nodes in a fixed order) and then make a hashable thing out
    of it, through tuple(array.flatten()):
    """

    # This get the adjacency matrix with nodes in a given order, as
    # numpy array (which is not hashable):
    adjacency_matrix = nx.adjacency_matrix(graph, nodelist=nodelist).todense()

    # This transforms the numpy array into a hashable object:
    hashable = tuple(adjacency_matrix.flatten())

    return hashable

def create_graph_label():
    """
    Create a dictionary from graphs to labels, in two formats.
    """
    graph_label = {
        nx.DiGraph([("X", "Y"), ("v", "X"), ("v", "Y")]): "Confounder",
        nx.DiGraph([("X", "Y"), ("X", "v"), ("Y", "v")]): "Collider",
        nx.DiGraph([("X", "Y"), ("X", "v"), ("v", "Y")]): "Mediator",
        nx.DiGraph([("X", "Y"), ("v", "X")]):             "Cause of X",
        nx.DiGraph([("X", "Y"), ("v", "Y")]):             "Cause of Y",
        nx.DiGraph([("X", "Y"), ("X", "v")]):             "Consequence of X",
        nx.DiGraph([("X", "Y"), ("Y", "v")]):             "Consequence of Y",
        nx.DiGraph({"X": ["Y"], "v": []}):                "Independent",
    }

    nodelist = ["v", "X", "Y"]

    # This is an equivalent alternative to graph_label but in a form
    # for which two equivalent graphs have the same key:
    adjacency_label = {
        graph_nodes_representation(graph, nodelist): label
        for graph, label in graph_label.items()
    }

    return graph_label, adjacency_label

def get_labels(adjacency_matrix, adjacency_label):
    """
    Transform an adjacency_matrix (as pd.DataFrame) into a dictionary of variable:label
    """

    result = {}
    for variable in adjacency_matrix.columns.drop(["X", "Y"]):
        submatrix = adjacency_matrix.loc[[variable, "X", "Y"], [variable, "X", "Y"]]  # this is not hashable
        key = tuple(submatrix.values.flatten())  # this is hashable and a compatible with adjacency_label
    
        result[variable] = adjacency_label[key]

    return result

In [4]:
class CausalDataset(Dataset):
    def __init__(
        self, 
        X_dict=typing.Dict[str, pd.DataFrame], 
        X_processed_dict=typing.Dict[str, pd.DataFrame], # Preprocessed data
        y_dict=typing.Dict[str, pd.DataFrame], 
        x_var='X', 
        y_var='Y'
    ):
        self.X_dict = X_dict
        self.X_processed_dict = X_processed_dict
        self.y_dict = y_dict
        self.ids = list(X_dict.keys())
        self.x_var = x_var
        self.y_var = y_var
        self.adjacency_graph, self.adjacency_label = create_graph_label()
        self.node_labels = [
            'Confounder', 'Collider', 'Mediator', 'Independent',
            'Cause of X', 'Consequence of X', 'Cause of Y', 'Consequence of Y',
        ]

    def __len__(self):
        return len(self.ids)
    
    def __getitem__(self, idx):
        sample_id = self.ids[idx]
        X_sample = self.X_dict[sample_id]                      # DataFrame (data: 1000 * num_nodes)
        X_processed_sample = self.X_processed_dict[sample_id]  # numpy array (data: num_edges * 3 * 1000)
        y_sample = self.y_dict[sample_id]                      # DataFrame (adjacency matrix: num_nodes * num_nodes)

        variables = X_sample.columns.tolist()
        edge_features, edge_types = self._process_edges(X_processed_sample, variables)
        node_labels = self._process_node_labels(y_sample, variables)
        edge_labels = self._process_edge_labels(y_sample, variables)

        return {
            'edge_features': edge_features,  # (num_edges, 3, 1000)
            'edge_types': edge_types,        # (num_edges,)
            'node_labels': node_labels,      # (num_nodes - 2, 8)    # For Node CLF
            'edge_labels': edge_labels,      # (num_edges, 2)        # For Edge CLF
            'variables': variables           # List[str]
        }
    
    def _process_edges(self, X_processed_sample, variables):
        edge_features = []
        edge_types = []
        for u in variables:
            for v in variables:
                if u == v:
                    continue
                edge_types.append(self._get_edge_type(u, v))

        edge_features = torch.tensor(X_processed_sample, dtype=torch.float32)
        edge_types = torch.tensor(edge_types, dtype=torch.long)
        return edge_features, edge_types

    def _get_edge_type(self, u, v):
        x, y = self.x_var, self.y_var
        if u == x and v != y:      # u is X but v is not Y
            return 0
        elif u == y and v != x:    # u is Y but v is not X
            return 1
        elif u != y and v == x:    # u is not Y but v is X 
            return 2
        elif u != x and v == y:    # u is not X but v is Y 
            return 3
        elif u == x and v == y:    # u is X and v is Y
            return 4
        elif u == y and v == x:    # u is Y and v is X
            return 5
        else:                      # none of the above
            return 6
        
    def _process_node_labels(self, y_sample, variables):
        node_label_dict = get_labels(y_sample, self.adjacency_label)
        
        # 剔除 x_var 和 y_var，得到 nodes
        nodes = [var for var in variables if var not in {self.x_var, self.y_var}]

        # 初始化 0/1 矩阵，大小为 len(nodes) * len(self.node_labels)
        node_label_matrix = np.zeros((len(nodes), len(self.node_labels)), dtype=int)

        # 遍历 nodes，并根据 node_label_dict 填充矩阵
        for i, node in enumerate(nodes):
            if node in node_label_dict:
                label = node_label_dict[node]
                if label in self.node_labels:
                    j = self.node_labels.index(label)  # 获取标签对应的索引
                    node_label_matrix[i, j] = 1    # 设置为 1
        return torch.tensor(node_label_matrix, dtype=torch.long)
    
    def _process_edge_labels(self, y_sample, variables):
        edge_label_matrix = []
        for u in variables:
            for v in variables:
                if u == v:
                    continue
                label_vector = np.zeros(2, dtype=int)
                if y_sample.loc[u, v] == 1:
                    label_vector[1] = 1
                else:
                    label_vector[0] = 1
                edge_label_matrix.append(label_vector)
        edge_label_matrix = np.array(edge_label_matrix)
        return torch.tensor(edge_label_matrix, dtype=torch.long)

In [6]:
class ConvBlock(nn.Module):
    def __init__(self, channels=64, kernel_size=3):
        super().__init__()
        self.conv = nn.Conv1d(
            channels, channels, 
            kernel_size=kernel_size, 
            padding=kernel_size//2
        )
        self.norm = nn.GroupNorm(8, channels)  # 8 groups for 64 channels
        self.activation = nn.GELU()

    def forward(self, x):
        identity = x
        x = self.conv(x)
        x = self.norm(x)
        x = self.activation(x)
        return x + identity

class SelfAttentionBlock(nn.Module):
    def __init__(self, embed_dim=64, num_heads=8):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x.unsqueeze(1)  # add sequence dimension
        attn_output, _ = self.attention(x, x, x)
        attn_output = attn_output.squeeze(1)
        return self.norm(x.squeeze(1) + attn_output)
    
class MergeBlock(nn.Module):
    def __init__(self, input_dim=256, output_dim=64):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.norm = nn.LayerNorm(output_dim)
        self.activation = nn.GELU()

    def forward(self, embeddings):
        concatenated = torch.cat(embeddings, dim=-1)
        return self.activation(self.norm(self.linear(concatenated)))

class CausalModel(nn.Module):
    def __init__(self, x_var='X', y_var='Y', hidden_dim=64):
        super().__init__()
        self.x_var = x_var
        self.y_var = y_var
        
        # Stem layer
        self.stem = nn.Conv1d(3, hidden_dim, kernel_size=1)
        
        # Convolutional blocks
        self.conv_blocks = nn.Sequential(*[
            ConvBlock(hidden_dim) for _ in range(5)
        ])
        
        # Pooling
        self.pool = nn.AdaptiveAvgPool1d(1)
        
        # Edge type embedding
        self.edge_type_embed = nn.Embedding(7, hidden_dim)
        
        # Self-attention
        self.self_attns = nn.Sequential(*[
            SelfAttentionBlock(hidden_dim) for _ in range(2)
        ])
        
        # Classification heads
        self.edge_cls = nn.Linear(hidden_dim, 2)
        self.node_merge = MergeBlock(4*hidden_dim)
        self.node_cls = nn.Linear(hidden_dim, 8)

    def forward(self, batch):
        # Unpack batch
        edge_features = batch['edge_features'].squeeze(0)  # [E, 3, 1000]
        edge_types = batch['edge_types'].squeeze(0)        # [E]
        variables = batch['variables']                     # list of var names
        variables = [item for sublist in variables for item in sublist]  # Flatten list
        # print(variables)
        # print(self.x_var, self.y_var)
        # print(variables.index(self.x_var), variables.index(self.y_var))
        
        # Feature extraction
        x = self.stem(edge_features)            # [E, 64, 1000]
        # print('After stem:', x.shape)
        x = self.conv_blocks(x)                 # [E, 64, 1000]
        # print('After conv blocks: ', x.shape)
        x = self.pool(x).squeeze(-1)            # [E, 64]
        # print('After pool: ', x.shape)
        
        # Add edge type embeddings
        x = x + self.edge_type_embed(edge_types)
        # print('After edge type embed: ', x.shape)
        
        # Self-attention
        x = self.self_attns(x)                  # [E, 64]
        # print('After self attns: ', x.shape)
        
        # Edge classification
        edge_logits = self.edge_cls(x)          # [E, 2]
        # print('Edge logits: ', edge_logits)
        
        # Node classification
        p = len(variables)
        edges = [(u, v) for u in range(p) for v in range(p) if u != v]
        edge_indices = {(u, v): idx for idx, (u, v) in enumerate(edges)}
        node_embs = []
        try:
            x_idx = variables.index(self.x_var)
            y_idx = variables.index(self.y_var)
        except ValueError:
            return edge_logits, torch.tensor([])
        
        for u_idx, u in enumerate(variables):
            if u in {self.x_var, self.y_var}:
                continue
            
            # Calculate edge indices
            def get_edge_idx(src_idx, tgt_var):
                nonlocal variables, p
                tgt_idx = variables.index(tgt_var)
                idx = edge_indices[(src_idx, tgt_idx)]
                return idx

            edges = [
                get_edge_idx(u_idx, self.x_var),  # u->X
                get_edge_idx(u_idx, self.y_var),  # u->Y
                get_edge_idx(x_idx, u),           # X->u
                get_edge_idx(y_idx, u)            # Y->u
            ]

            if None in edges:
                raise ValueError("None edge indices")
                
            # Merge embeddings
            merged = self.node_merge([x[e] for e in edges])   # 4 * [1, 64] -> [1, 4 * 64] -> [1, 64]
            node_embs.append(merged)
        
        node_logits = self.node_cls(torch.stack(node_embs))  # [p-2, 64] -> [p-2, 8]
        # print('Node logits: ', node_logits)
        return edge_logits, node_logits

In [13]:
class CausalDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, valid_dataset, batch_size=1):
        super().__init__()
        self.train_dataset = train_dataset
        self.valid_dataset = valid_dataset
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=8,
            pin_memory=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.valid_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=8,
            pin_memory=True
        )

class CausalLightningModule(pl.LightningModule):
    def __init__(self, model, edge_weights, node_weights, lr=1e-3):
        super().__init__()
        self.model = model
        self.edge_weights = edge_weights.to(self.device)
        self.node_weights = node_weights.to(self.device)
        self.lr = lr

    def forward(self, batch):
        return self.model(batch)

    def _compute_loss(self, batch):
        edge_logits, node_logits = self.forward(batch)
        
        # Edge loss
        edge_labels = batch['edge_labels'].squeeze(0).to(edge_logits.device)
        edge_labels_idx = torch.argmax(edge_labels, dim=1)
        edge_loss = F.cross_entropy(
            edge_logits, edge_labels_idx,
            weight=self.edge_weights.to(edge_logits.device)
        )
        
        # Node loss
        node_labels = batch['node_labels'].squeeze(0).to(node_logits.device)
        node_labels_idx = torch.argmax(node_labels, dim=1)
        node_loss = F.cross_entropy(
            node_logits, node_labels_idx,
            weight=self.node_weights.to(node_logits.device)
        )

        total_loss = edge_loss + node_loss
        
        return total_loss, edge_loss, node_loss
    
    def _compute_metrics(self, batch):
        edge_logits, node_logits = self.forward(batch)

        # Edge metrics: Balanced Accuracy
        edge_labels = batch['edge_labels'].squeeze(0).to(edge_logits.device)
        edge_labels_idx = torch.argmax(edge_labels, dim=1)
        edge_preds = torch.argmax(edge_logits, dim=1)
        edge_acc = (edge_preds == edge_labels_idx).float().mean()

        # Node metrics: Balanced Accuracy
        node_labels = batch['node_labels'].squeeze(0).to(node_logits.device)
        node_labels_idx = torch.argmax(node_labels, dim=1)
        node_preds = torch.argmax(node_logits, dim=1)
        node_acc = (node_preds == node_labels_idx).float().mean()

        return edge_acc, node_acc

    def training_step(self, batch, batch_idx):
        total_loss, loss_edge, loss_node = self._compute_loss(batch)
        self.log_dict({
            'train_loss': total_loss,
            'train_edge_loss': loss_edge,
            'train_node_loss': loss_node
        }, prog_bar=True)
        return total_loss

    def validation_step(self, batch, batch_idx):
        total_loss, loss_edge, loss_node = self._compute_loss(batch)
        edge_acc, node_acc = self._compute_metrics(batch)
        self.log_dict({
            'val_loss': total_loss,
            'val_edge_acc': edge_acc,
            'val_node_acc': node_acc
        }, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, 
            T_max=10,  # Adjust based on total epochs
            eta_min=1e-5
        )
        return [optimizer], [scheduler]

def compute_class_weights(dataset):
    edge_labels = []
    node_labels = []
    
    for i in tqdm(range(len(dataset))):
        sample = dataset[i]
        edge_labels.append(sample['edge_labels'])
        node_labels.append(sample['node_labels'])
    
    # Process edge weights
    edge_labels = torch.cat(edge_labels)
    edge_counts = torch.sum(edge_labels, dim=0)
    edge_weights = 1.0 / (edge_counts + 1e-5)  # Add epsilon to avoid division by zero
    edge_weights = edge_weights / edge_weights.sum() * len(edge_counts)
    
    # Process node weights
    node_labels = torch.cat(node_labels)
    node_labels_idx = torch.argmax(node_labels, dim=1)
    node_counts = torch.bincount(node_labels_idx)
    node_weights = 1.0 / (node_counts + 1e-5)
    node_weights = node_weights / node_weights.sum() * len(node_counts)
    
    return edge_weights, node_weights

In [8]:
train_keys, test_keys = train_test_split(list(X_train_processed.keys()), test_size=0.2, random_state=42)

print("Train datasets (top 5):", train_keys[:5])
print("Test datasets (top 5):", test_keys[:5])

X_train_split = {key: X_train[key] for key in train_keys}
X_train_processed_split = {key: X_train_processed[key] for key in train_keys}
y_train_split = {key: y_train[key] for key in train_keys}
X_test_split = {key: X_train[key] for key in test_keys}
X_test_processed_split = {key: X_train_processed[key] for key in test_keys}
y_test_split = {key: y_train[key] for key in test_keys}

train_dataset = CausalDataset(X_train_split, X_train_processed_split, y_train_split)
test_dataset = CausalDataset(X_test_split, X_test_processed_split, y_test_split)

Train datasets (top 5): ['09981', '08138', '30965', '01606', '00812']
Test datasets (top 5): ['04552', '03154', '07222', '14344', '14242']


In [9]:
model = CausalModel()
print(model)

CausalModel(
  (stem): Conv1d(3, 64, kernel_size=(1,), stride=(1,))
  (conv_blocks): Sequential(
    (0): ConvBlock(
      (conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
      (activation): GELU(approximate='none')
    )
    (1): ConvBlock(
      (conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
      (activation): GELU(approximate='none')
    )
    (2): ConvBlock(
      (conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
      (activation): GELU(approximate='none')
    )
    (3): ConvBlock(
      (conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
      (activation): GELU(approximate='none')
    )
    (4): ConvBlock(
      (conv): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (no

In [10]:
edge_weights, node_weights = compute_class_weights(train_dataset)

datamodule = CausalDataModule(train_dataset, test_dataset, batch_size=1)

100%|██████████| 18800/18800 [01:53<00:00, 165.54it/s]


In [11]:
print(edge_weights.shape, node_weights.shape)

torch.Size([2]) torch.Size([8])


In [30]:
# 获取 DataLoader
train_loader = datamodule.train_dataloader()

# 取一个 batch
sample_batch = next(iter(train_loader))

# 打印 batch 结构
print("Sample batch keys:", sample_batch.keys())

# 详细打印数据
for key, value in sample_batch.items():
    if torch.is_tensor(value):
        print(f"{key}: shape={value.shape}, dtype={value.dtype}")
    else:
        print(f"{key}: {value}")

Sample batch keys: dict_keys(['edge_features', 'edge_types', 'node_labels', 'edge_labels', 'variables'])
edge_features: shape=torch.Size([1, 90, 3, 1000]), dtype=torch.float32
edge_types: shape=torch.Size([1, 90]), dtype=torch.int64
node_labels: shape=torch.Size([1, 8, 8]), dtype=torch.int64
edge_labels: shape=torch.Size([1, 90, 2]), dtype=torch.int64
variables: [['Y'], ['1'], ['2'], ['X'], ['4'], ['5'], ['6'], ['7'], ['8'], ['9']]


In [34]:
# 实例化模型
model = CausalModel()

# 将 batch 输入模型
edge_logits, node_logits = model(sample_batch)

# 打印输出格式
print("Edge logits shape:", edge_logits.shape)
print("Edge logits dtype:", edge_logits.dtype)
print("Node logits shape:", node_logits.shape)
print("Node logits dtype:", node_logits.dtype)

['Y', '1', '2', 'X', '4', '5', '6', '7', '8', '9']
X Y
3 0
After stem: torch.Size([90, 64, 1000])
After conv blocks:  torch.Size([90, 64, 1000])
After pool:  torch.Size([90, 64])
After edge type embed:  torch.Size([90, 64])
After self attns:  torch.Size([90, 64])
Edge logits:  tensor([[-0.3824,  0.0889],
        [-0.7635,  0.1137],
        [-0.2374,  0.5080],
        [-0.8178,  0.0975],
        [-0.8521,  0.0680],
        [-0.4573,  0.0654],
        [-0.8307,  0.0932],
        [-0.7718,  0.0532],
        [-0.7580,  0.1129],
        [-0.8891,  0.1429],
        [-0.5933,  0.1408],
        [-0.4811, -0.0500],
        [-0.6194,  0.1839],
        [-0.6127,  0.1477],
        [-0.5434,  0.2256],
        [-0.6188,  0.1530],
        [-0.5455,  0.2107],
        [-0.5971,  0.1634],
        [-0.9152,  0.2439],
        [-0.3815,  0.1339],
        [-0.4817, -0.0309],
        [-0.5590,  0.1605],
        [-0.5310,  0.2218],
        [-0.4334,  0.1606],
        [-0.4387, -0.0350],
        [ 0.0316,  0.2

In [15]:
trainer = pl.Trainer(
    max_epochs=50,
    accelerator="auto",
    devices="auto",
    precision="16-mixed",
    enable_progress_bar=True,
    log_every_n_steps=10
)

pl_model = CausalLightningModule(
    model,
    edge_weights,
    node_weights,
    lr=1e-5
)

trainer.fit(pl_model, datamodule=datamodule)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type        | Params | Mode
---------------------------------------------
0 | model | CausalModel | 113 K  | eval
---------------------------------------------
113 K     Trainable params
0         Non-trainable params
113 K     Total params
0.455     Total estimated model params size (MB)
0         Modules in train mode
40        Modules in eval mode


Epoch 0:   4%|▍         | 737/18800 [00:11<04:53, 61.46it/s, v_num=9, train_loss=nan.0, train_edge_loss=nan.0, train_node_loss=nan.0]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined