In [1]:
import os
import typing
import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
import networkx as nx
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set up the Crunch notebook runner.
# NOTE: after the assignment below, the name `crunch` no longer refers to the
# imported module but to the object returned by `load_notebook()`; later cells
# use that object (e.g. `crunch.test(...)`).
import crunch

crunch = crunch.load_notebook()

class CausalDataset(Dataset):
    """Zero-padded, fixed-size view over a list of (observations, adjacency) pairs.

    Every observation frame is copied into the top-left corner of a
    ``(max_samples, max_nodes)`` float32 array and every adjacency frame into
    the top-left corner of a ``(max_nodes, max_nodes)`` float32 array, so that
    all items share one shape. ``target_mask`` is True exactly on the cells of
    the adjacency array that were filled from the corresponding ``y`` frame.

    Args:
        X: One DataFrame of observations per dataset.
        y: One DataFrame (adjacency matrix) per dataset, aligned with ``X``.
        max_samples: Number of rows every observation array is padded to.
        max_nodes: Number of columns / nodes every array is padded to.

    Raises:
        ValueError: If ``X`` and ``y`` differ in length, or a frame does not
            fit into the padded shape.
    """

    def __init__(
        self,
        X: typing.List[pd.DataFrame],
        y: typing.List[pd.DataFrame],
        max_samples: int = 1000,
        max_nodes: int = 10,
    ) -> None:
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )

        n_items = len(X)
        self.X = np.zeros([n_items, max_samples, max_nodes], dtype=np.float32)
        self.y = np.zeros([n_items, max_nodes, max_nodes], dtype=np.float32)
        self.target_mask = np.zeros([n_items, max_nodes, max_nodes], dtype=bool)

        for i, (features, adjacency) in enumerate(zip(X, y)):
            n_rows, n_cols = features.shape
            adj_rows, adj_cols = adjacency.shape
            # Fail with a readable message instead of a numpy broadcast error.
            if n_rows > max_samples or n_cols > max_nodes:
                raise ValueError(
                    f"X[{i}] has shape {features.shape}, which exceeds "
                    f"({max_samples}, {max_nodes})"
                )
            if adj_rows > max_nodes or adj_cols > max_nodes:
                raise ValueError(
                    f"y[{i}] has shape {adjacency.shape}, which exceeds "
                    f"({max_nodes}, {max_nodes})"
                )

            self.X[i, :n_rows, :n_cols] = features.values
            self.y[i, :adj_rows, :adj_cols] = adjacency.values
            # Only the cells that came from real data take part in the loss.
            self.target_mask[i, :adj_rows, :adj_cols] = True

    def __len__(self) -> int:
        """Return the number of datasets held."""
        return len(self.X)

    def __getitem__(self, idx: int) -> dict:
        """Return the padded arrays of item ``idx`` under keys 'X', 'y', 'target_mask'."""
        return {
            'X': self.X[idx],
            'y': self.y[idx],
            'target_mask': self.target_mask[idx]
        }

def preprocessing(X: pd.DataFrame) -> torch.Tensor:
    """Turn one DataFrame into a float tensor with a leading batch axis of size 1."""
    frame_tensor = torch.Tensor(X.values)
    # Indexing with None inserts the new axis at position 0.
    return frame_tensor[None]


loaded inline runner with module: <module '__main__'>


In [3]:
class CausalModel(nn.Module):
    """Pairwise edge scorer built from two small feedforward networks.

    Each scalar observation is embedded into a query half and a key half;
    query/key products are summed over the sample axis for every ordered pair
    of columns, and a second network maps each pair's feature vector to one
    logit.
    """

    def __init__(self, d_model=64):
        super().__init__()
        # Scalar -> 2 * d_model features (later split into query and key).
        embed_layers = [
            nn.Linear(1, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 2 * d_model),
        ]
        self.input_layer = nn.Sequential(*embed_layers)
        # d_model pair features -> single logit.
        head_layers = [
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 1),
        ]
        self.final = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a 3-D input (axes b, s, i in the einsum below) to (b, i, i) logits."""
        n_samples = x.shape[1]
        embedded = self.input_layer(x[..., None])
        query, key = torch.chunk(embedded, 2, dim=-1)
        # Sum query * key over the sample axis for every (i, j) column pair.
        pair_features = torch.einsum('b s i d, b s j d -> b i j d', query, key)
        pair_features = pair_features * (n_samples ** -0.5)
        logits = self.final(pair_features)
        return logits.squeeze(-1)


In [8]:


class ModelWrapper(pl.LightningModule):
    """PyTorch Lightning wrapper that trains a ``CausalModel`` with masked BCE loss.

    Args:
        d_model: Hidden width forwarded to ``CausalModel``.
        lr: Learning rate used by the Adam optimizer.
        pos_weight: Weight of the positive class in ``BCEWithLogitsLoss``.
    """

    def __init__(self, d_model=64, lr=1e-3, pos_weight=5.0):
        super().__init__()
        # Must be stored: configure_optimizers() reads self.lr.
        self.lr = lr
        self.model = CausalModel(d_model)
        self.train_criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the wrapped model's logits for ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Adam plus a StepLR schedule that scales the lr by 0.1 every 7 scheduler steps."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
        return [optimizer], [scheduler]

    def _masked_loss(self, batch: dict) -> torch.Tensor:
        """Compute the BCE loss over the cells selected by the batch's 'target_mask'."""
        target_mask = batch['target_mask']
        preds = self(batch['X'])
        # Boolean indexing drops the padded cells before the loss is averaged.
        return self.train_criterion(preds[target_mask], batch['y'][target_mask])

    def training_step(self, train_batch: dict, batch_idx: int):
        """Log and return the masked loss for one training batch."""
        loss = self._masked_loss(train_batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, val_batch: dict, batch_idx: int):
        """Log the masked loss for one validation batch."""
        loss = self._masked_loss(val_batch)
        self.log("val_loss", loss, on_epoch=True)

def transform_proba_to_DAG(
    nodes: typing.List[str],
    pred: np.ndarray,
    threshold: float = 0.5,
) -> np.ndarray:
    """Greedily turn a matrix of edge scores into the adjacency matrix of a DAG.

    Candidate edges ``nodes[i] -> nodes[j]`` are visited from the highest
    ``pred[i, j]`` to the lowest. An edge is added when its score is strictly
    above ``threshold`` and adding it would not close a directed cycle.
    Self-loops and any direct edge between the nodes named 'X' and 'Y' are
    never added.

    NOTE(review): the graphs in ``create_graph_label`` all contain the edge
    X -> Y, yet this function never emits it -- confirm that is intended.

    Args:
        nodes: Node names; ``pred`` rows/columns follow this order.
        pred: Square array of edge scores.
        threshold: Minimum score (exclusive) for an edge to be considered.

    Returns:
        Adjacency matrix of the resulting DAG, rows/columns ordered as ``nodes``.
    """
    G = nx.DiGraph()
    G.add_nodes_from(nodes)

    # Flattened indices sorted by descending score, mapped back to (row, col).
    row_index, col_index = np.unravel_index(np.argsort(pred.ravel())[::-1], pred.shape)
    for i, j in zip(row_index, col_index):
        source, target = nodes[i], nodes[j]
        if i == j or {source, target} == {'X', 'Y'}:
            continue
        # `continue` rather than `break`: NaN scores sort first in the reversed
        # order and must not stop the scan.
        if not pred[i, j] > threshold:
            continue
        # G is acyclic here, so source -> target closes a cycle exactly when
        # target already reaches source.
        if nx.has_path(G, target, source):
            continue
        G.add_edge(source, target)

    # Pin the output ordering to `nodes` explicitly.
    return nx.to_numpy_array(G, nodelist=nodes)

def create_graph_label() -> typing.Tuple[dict, dict]:
    """Build the two lookup tables that name the role of a variable ``v``.

    Returns:
        A pair ``(graph_label, adjacency_label)``: the first maps each
        three-node template graph over 'v', 'X', 'Y' to its label, the second
        maps ``graph_nodes_representation(graph, ["v", "X", "Y"])`` of the
        same graphs to the same labels.

    NOTE(review): ``graph_nodes_representation`` is not defined in the visible
    part of this notebook; it must exist in the namespace before this is called.
    """
    # (constructor argument, label) pairs; insertion order is preserved.
    templates = [
        ([("X", "Y"), ("v", "X"), ("v", "Y")], "Confounder"),
        ([("X", "Y"), ("X", "v"), ("Y", "v")], "Collider"),
        ([("X", "Y"), ("X", "v"), ("v", "Y")], "Mediator"),
        ([("X", "Y"), ("v", "X")], "Cause of X"),
        ([("X", "Y"), ("v", "Y")], "Cause of Y"),
        ([("X", "Y"), ("X", "v")], "Consequence of X"),
        ([("X", "Y"), ("Y", "v")], "Consequence of Y"),
        # Dict-of-lists form so that the isolated node "v" is still present.
        ({"X": ["Y"], "v": []}, "Independent"),
    ]
    graph_label = {nx.DiGraph(spec): name for spec, name in templates}

    node_order = ["v", "X", "Y"]
    adjacency_label = {}
    for template_graph, name in graph_label.items():
        key = graph_nodes_representation(template_graph, node_order)  # noqa: F821
        adjacency_label[key] = name
    return graph_label, adjacency_label

def get_labels(adjacency_matrix: pd.DataFrame, adjacency_label: dict) -> dict:
    """Look up a label for every column of ``adjacency_matrix`` other than 'X' and 'Y'.

    For each such variable the 3x3 sub-matrix over (variable, 'X', 'Y') is
    flattened row by row into a tuple and used as the key into
    ``adjacency_label``. A missing key raises ``KeyError``.
    """
    endpoints = ["X", "Y"]
    labels = {}
    for candidate in adjacency_matrix.columns.drop(endpoints):
        triple = [candidate, *endpoints]
        flat_cells = adjacency_matrix.loc[triple, triple].values.flatten()
        labels[candidate] = adjacency_label[tuple(flat_cells)]
    return labels

def train(
    X_train: typing.Dict[str, pd.DataFrame],
    y_train: typing.Dict[str, pd.DataFrame],
    model_directory_path: str,
    batch_size: int = 64,
    max_epochs: int = 10,
    learning_rate: float = 1e-3,
) -> None:
    """Fit the causal model and write its weights to ``<model_directory_path>/model.pt``.

    Args:
        X_train: Observation frames keyed by dataset id.
        y_train: Adjacency frames keyed by the same dataset ids.
        model_directory_path: Directory that receives ``model.pt``.
        batch_size: Number of datasets per training batch.
        max_epochs: Number of training epochs.
        learning_rate: Learning rate passed to ``ModelWrapper``.

    Raises:
        KeyError: If a dataset id of ``X_train`` is missing from ``y_train``.
    """
    # Index both dicts with the same ids so features and targets cannot get
    # misaligned when the two dicts iterate in different orders.
    dataset_ids = list(X_train)
    X = [X_train[dataset_id] for dataset_id in dataset_ids]
    y = [y_train[dataset_id] for dataset_id in dataset_ids]
    dataset = CausalDataset(X, y)

    # Dropping the last partial batch would leave nothing to train on when
    # there are fewer datasets than one batch.
    drop_last = len(dataset) >= batch_size
    train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)

    model = ModelWrapper(lr=learning_rate)
    trainer = pl.Trainer(accelerator="auto", max_epochs=max_epochs, logger=False)
    trainer.fit(model, train_dataloader)

    os.makedirs(model_directory_path, exist_ok=True)
    model_path_file = os.path.join(model_directory_path, "model.pt")
    # Save the inner CausalModel only: the wrapper's state dict prefixes every
    # key with "model." and also carries the loss's pos_weight buffer, which a
    # bare CausalModel cannot load.
    torch.save(model.model.state_dict(), model_path_file)

def infer(
    X_test: typing.Dict[str, pd.DataFrame],
    model_directory_path: str,
    id_column_name: str,
    prediction_column_name: str,
) -> pd.DataFrame:
    """Predict a DAG for every test dataset and return the flattened submission.

    Args:
        X_test: Observation frames keyed by dataset name.
        model_directory_path: Directory containing ``model.pt``.
        id_column_name: Name of the output column holding ``<name>_<i>_<j>`` ids.
        prediction_column_name: Name of the output column holding 0/1 edges.

    Returns:
        A two-column DataFrame with one row per (dataset, source node,
        target node) triple.
    """
    model_path_file = os.path.join(model_directory_path, "model.pt")
    model = CausalModel(d_model=64)

    state_dict = torch.load(model_path_file, map_location='cpu')
    # A checkpoint saved from the Lightning wrapper prefixes the model's keys
    # with "model." and adds loss buffers; keep and rename only the model's own
    # entries. CausalModel's keys never start with "model.", so a plain
    # checkpoint passes through untouched.
    wrapper_prefix = "model."
    if any(key.startswith(wrapper_prefix) for key in state_dict):
        state_dict = {
            key[len(wrapper_prefix):]: value
            for key, value in state_dict.items()
            if key.startswith(wrapper_prefix)
        }
    model.load_state_dict(state_dict)
    model.eval()

    predictions = {}
    for name in X_test:
        X = X_test[name]
        x = preprocessing(X)

        with torch.no_grad():
            # preprocessing() added a batch axis of size 1; take its only item.
            logits = model(x)[0]
            proba = torch.sigmoid(logits).cpu().numpy()

        nodes = list(X.columns)
        dag = transform_proba_to_DAG(nodes, proba).astype(int)

        # Rows/columns of `dag` follow `nodes`, so positional indexing suffices.
        for row, source in enumerate(nodes):
            for col, target in enumerate(nodes):
                predictions[f'{name}_{source}_{target}'] = int(dag[row, col])

    submission = pd.Series(predictions).reset_index()
    submission.columns = [id_column_name, prediction_column_name]
    return submission
    crunch.test(
    no_determinism_check=True
)

