In [None]:
import pandas as pd
import torch
import numpy as np
import os
from sklearn.preprocessing import StandardScaler, LabelEncoder
from tab_transformer_pytorch import TabTransformer
import logging


# Root logger: emit INFO and above, each record prefixed with timestamp and severity.
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO)

def extract_tab_features(csv_path, feature_dim=256, batch_size=64, seed=None):
    """Embed every row of a CSV table with a freshly initialised TabTransformer.

    Object/category columns are label-encoded after being cast to ``str`` (so a
    missing value becomes its own category). Numeric columns are mean-imputed
    and standardised. The encoded table is then pushed through a
    ``TabTransformer`` in eval mode, one mini-batch at a time.

    NOTE: the model is built inside this function and is never trained or
    loaded from a checkpoint, so the returned vectors come from randomly
    initialised weights. Pass ``seed`` to make them repeatable across calls.

    Args:
        csv_path: Path of the CSV file to read.
        feature_dim: Width of each returned feature vector (positive int).
        batch_size: Number of rows per forward pass (positive int).
        seed: Optional integer handed to ``torch.manual_seed`` right before the
            model is constructed. ``None`` leaves the global RNG untouched.

    Returns:
        ``numpy.ndarray`` with one row of ``feature_dim`` values per CSV row,
        or ``None`` if anything fails after argument validation (the error and
        its traceback are logged).

    Raises:
        FileNotFoundError: ``csv_path`` does not exist.
        ValueError: ``feature_dim`` or ``batch_size`` is not a positive int.
    """
    if not os.path.exists(csv_path):
        raise FileNotFoundError(f"path error: {csv_path}")
    if not isinstance(feature_dim, int) or feature_dim <= 0:
        raise ValueError("feature_dim error")
    if not isinstance(batch_size, int) or batch_size <= 0:
        raise ValueError("batch_size error")

    try:
        df = pd.read_csv(csv_path)
        logging.info(f"datashape: {df.shape}")

        cat_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()
        num_cols = df.select_dtypes(include=['number']).columns.tolist()

        # Integer-encode each categorical column independently.
        encoded_df = df.copy()
        for col in cat_cols:
            encoded_df[col] = LabelEncoder().fit_transform(df[col].astype(str))

        if num_cols:
            numeric = encoded_df[num_cols]
            if numeric.isnull().values.any():
                # Mean-impute. A column that is entirely NaN has a NaN mean, so
                # fall back to 0.0 there; otherwise NaN would flow through the
                # scaler and the model and poison every output row.
                numeric = numeric.fillna(numeric.mean()).fillna(0.0)
            encoded_df[num_cols] = StandardScaler().fit_transform(numeric)

        X_num = encoded_df[num_cols].values if num_cols else np.zeros((len(encoded_df), 0))

        if cat_cols:
            X_cat = encoded_df[cat_cols].values
            categories = [encoded_df[col].nunique() for col in cat_cols]
        else:
            # No categorical columns: substitute one constant column so the
            # model always receives a categorical input.
            categories = [1]
            X_cat = np.zeros((len(encoded_df), 1), dtype=int)

        if seed is not None:
            torch.manual_seed(seed)

        # Embedding width is a quarter of the output width, but never below 8.
        dim = max(feature_dim // 4, 8)
        model = TabTransformer(
            categories=categories,
            num_continuous=X_num.shape[1],
            dim=dim,
            depth=6,
            heads=8,
            attn_dropout=0.1,
            ff_dropout=0.1,
            mlp_hidden_mults=(4, 2),
            mlp_act=None,
            dim_out=feature_dim
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        logging.info(f"model loaded in : {device}")

        X_cat_tensor = torch.as_tensor(X_cat, dtype=torch.long).to(device)
        X_num_tensor = torch.as_tensor(X_num, dtype=torch.float32).to(device)

        # eval() disables dropout; no_grad() skips autograd bookkeeping.
        model.eval()
        num_samples = len(encoded_df)
        all_features = []

        with torch.no_grad():
            for batch_idx, start in enumerate(range(0, num_samples, batch_size)):
                end_idx = min(start + batch_size, num_samples)
                features = model(X_cat_tensor[start:end_idx], X_num_tensor[start:end_idx])
                all_features.append(features.cpu().numpy())

                # Progress line every tenth batch.
                if batch_idx % 10 == 0:
                    logging.info(f"process: {end_idx}/{num_samples}")

        all_features = np.vstack(all_features)

        logging.info(f"features.shape: {all_features.shape}")

        print("features：")
        print(all_features)

        return all_features

    except Exception as e:
        # Best-effort contract kept for existing callers: log (with traceback)
        # and signal failure with None instead of propagating.
        logging.error(f"error: {str(e)}", exc_info=True)
        return None


In [None]:
# Colab-only: mount Google Drive at /content/drive so the CSV paths used later resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Width of every per-row embedding produced below.
feature_dim = 256

# One CSV export per modality on the mounted Drive.
csv_path_mutation = "/content/drive/MyDrive/proj72_data/multiomics/mutation.csv"
csv_path_clinical = "/content/drive/MyDrive/proj72_data/multiomics/clinical.csv"

# Embed each table; a result of None means extraction failed (see the log).
features_mutation = extract_tab_features(csv_path_mutation, feature_dim=feature_dim)
features_clinical = extract_tab_features(csv_path_clinical, feature_dim=feature_dim)

features：
[[ 0.11898579  0.18346208  0.058261   ... -0.02295681 -0.19146286
   0.05482033]
 [ 0.1495759   0.14784636  0.1287416  ...  0.05690724 -0.0748592
  -0.19670159]
 [ 0.22605298  0.14434478  0.14035615 ... -0.09968987 -0.24286707
  -0.05302662]
 ...
 [-0.04441394  0.12801169  0.04246295 ...  0.01180192 -0.14561115
  -0.11234135]
 [-0.01789065  0.17041911  0.00188292 ... -0.02076292 -0.14894941
  -0.04929658]
 [ 0.0753727   0.23229182  0.05958675 ... -0.0898457  -0.1872366
  -0.12321194]]
features：
[[ 0.30354282  0.5335673   0.77190095 ... -0.31558722 -0.35376155
   0.17969918]
 [ 0.10730834  0.28623575  0.2597753  ... -0.12229361 -0.15587553
   0.17283902]
 [ 0.10702241  0.12587878  0.52760005 ... -0.25977504 -0.5123722
   0.16054098]
 ...
 [-0.44602314  0.42073148  0.02070296 ... -0.34694073  0.20849368
   0.05358921]
 [ 0.05102543  0.39582673  0.20700967 ... -0.558087   -0.04459652
   0.20244083]
 [ 0.07341093  0.49012336  0.05591625 ... -0.6342342  -0.31341708
   0.10049491]]

In [None]:
import torch
import torch.nn as nn
from typing import List
from typing import Dict
import torch.nn.functional as F
import pandas as pd
import numpy as np
import itertools

In [None]:


class FowardNetwork(nn.Module):
    """Feed-forward block: two width-preserving linear layers, each followed by SiLU."""

    def __init__(self, embed_dim):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them as they are.
        self.Fc1 = nn.Linear(embed_dim, embed_dim, bias=True)
        self.Fc2 = nn.Linear(embed_dim, embed_dim, bias=True)

    def forward(self, x):
        """Apply Fc1 -> SiLU -> Fc2 -> SiLU to the last dimension of ``x``."""
        hidden = F.silu(self.Fc1(x))
        return F.silu(self.Fc2(hidden))


class CrossAttention(nn.Module):
    """Multi-head cross-attention with thresholded weights, residuals and a feed-forward block.

    The query attends over key/value. Attention weights below the learnable
    threshold ``alpha`` are zeroed. The attended values go through an output
    projection, a residual + LayerNorm, a ``FowardNetwork`` and a second
    residual + LayerNorm.

    Args:
        embed_dim: Size of the last dimension of query/key/value.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        batch_size: Stored on the instance for backward compatibility only.
            The batch dimension is read from the input tensors at call time,
            so any batch size works.

    Raises:
        ValueError: ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, batch_size):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.dropout = 0.2
        self.batch_size = batch_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_v = nn.Linear(embed_dim, embed_dim, bias=False)
        self.O_layer = nn.Linear(embed_dim, embed_dim, bias=False)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop1 = nn.Dropout(self.dropout)
        self.drop2 = nn.Dropout(self.dropout)
        # Threshold below which an attention weight is zeroed.
        # NOTE(review): it is only used inside a `>=` comparison, which yields
        # a boolean mask, so it presumably never receives a gradient; confirm.
        self.alpha = nn.Parameter(torch.tensor(0.2))
        # Not read anywhere in this class; kept so existing state_dicts still load.
        self.belta = nn.Parameter(torch.ones(num_heads))
        self.fowNet = FowardNetwork(self.embed_dim)

    def split_heads(self, x):
        """Reshape ``[B, T, embed_dim]`` into ``[B, num_heads, T, head_dim]``."""
        # Take B from the tensor itself so any batch size works.
        x = x.view(x.size(0), -1, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)

    def scaled_dot_product_attention(self, Q, K, V):
        """Scaled dot-product attention with weights below ``alpha`` zeroed.

        Args:
            Q: ``[B, num_heads, Tq, head_dim]``.
            K, V: ``[B, num_heads, Tk, head_dim]``.

        Returns:
            ``(attn_output, attn_weights)`` where ``attn_output`` is
            ``[B, num_heads, Tq, head_dim]`` and ``attn_weights`` is the
            thresholded ``[B, num_heads, Tq, Tk]`` map (before the second
            softmax).
        """
        scores = torch.matmul(Q, K.transpose(-1, -2)) / (self.head_dim ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)
        # Build the 0/1 keep-mask from the weights so it shares their device and dtype.
        keep_mask = (attn_weights >= self.alpha).to(attn_weights.dtype)
        attn_weights = attn_weights * keep_mask
        # NOTE(review): this second softmax runs on already-normalised weights,
        # so zeroed positions still get exp(0) mass. Kept as-is to preserve
        # numerics; confirm this is the intended sparsification.
        attn_weights_adjusted = F.softmax(attn_weights, dim=-1)

        attn_output = torch.matmul(attn_weights_adjusted, V)
        return attn_output, attn_weights

    def forward(self, query, key, value):
        """Attend ``query`` over ``key``/``value``.

        Args:
            query: ``[B, Tq, embed_dim]``.
            key, value: ``[B, Tk, embed_dim]``.

        Returns:
            ``(final_output, atten_maps)``: the ``[B, Tq, embed_dim]`` fused
            representation and the thresholded ``[B, num_heads, Tq, Tk]``
            attention map.
        """
        Q = self.split_heads(self.W_q(query))
        K = self.split_heads(self.W_k(key))
        V = self.split_heads(self.W_v(value))
        attn_output, atten_maps = self.scaled_dot_product_attention(Q, K, V)
        # Merge heads back: [B, heads, Tq, head_dim] -> [B, Tq, embed_dim].
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(query.size(0), query.size(1), self.embed_dim)
        attn_output = self.O_layer(attn_output)
        attn_output = self.norm1(query + self.drop1(attn_output))
        inter_output = self.fowNet(attn_output)
        final_output = self.norm2(attn_output + self.drop2(inter_output))

        return final_output, atten_maps

In [None]:
class DynamicSequentialMultiheadCrossAttention(nn.Module):
    """Fuse one main modality with N-1 auxiliary modalities, one cross-attention at a time.

    Args:
        d_model: Stored on the instance; not otherwise used by this module.
        total_modalities: Number of tensors ``forward`` expects, main modality
            included. Must be at least 1.
        embed_dim: Feature width handed to every ``CrossAttention`` layer.
        num_heads: Number of heads handed to every ``CrossAttention`` layer.
        batch_size: Forwarded to every ``CrossAttention`` layer.

    Raises:
        ValueError: ``total_modalities`` is smaller than 1.
    """

    def __init__(self, d_model: int, total_modalities: int, embed_dim: int, num_heads: int, batch_size: int):
        super().__init__()
        if total_modalities < 1:
            raise ValueError(f"total_modalities must be >= 1, got {total_modalities}")
        self.d_model = d_model
        self.total_modalities = total_modalities  # Total number of modalities, main one included
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.batch_size = batch_size

        # One cross-attention layer per auxiliary modality.
        self.cross_attentions = nn.ModuleList([
            CrossAttention(embed_dim=self.embed_dim, num_heads=self.num_heads, batch_size=self.batch_size)
            for _ in range(self.total_modalities - 1)
        ])

    def forward(self, *modalities: torch.Tensor):
        """
        Args:
            modalities: Tensors passed positionally, where:
                - modalities[0]: Main modality A [B, N, embed_dim] (Query)
                - modalities[1:]: Auxiliary modalities [B, M_i, embed_dim] (Keys/Values)
        Returns:
            (a_enhanced, attn_weights):
                - a_enhanced: [B, N, embed_dim] (Enhanced main modality)
                - attn_weights: attention map from the LAST auxiliary modality
                  only, or None when there is no auxiliary modality.
        """
        assert len(modalities) == self.total_modalities, \
            f"Expected {self.total_modalities} modalities (including A), got {len(modalities)}"

        a = modalities[0]  # Main modality A
        # Defined up front so total_modalities == 1 returns cleanly instead of
        # hitting an unbound local.
        attn_weights = None

        # Sequentially fuse each auxiliary modality: A is the query and the
        # current modality supplies keys and values.
        for cross_attention, current_modality in zip(self.cross_attentions, modalities[1:]):
            a, attn_weights = cross_attention(a, current_modality, current_modality)

        return a, attn_weights

In [None]:
#batch_size = 64
#feature_dim = 256
#seq_length_c = features_clinical.shape[0] // batch_size
#seq_length_mut = features_mutation.shape[0] // batch_size

#features_mutation = features_mutation.reshape(batch_size, seq_length_mut, feature_dim)
#features_clinical = features_clinical.reshape(batch_size, seq_length_c, feature_dim)


In [None]:
def reshape_with_padding(features, batch_size=64, feature_dim=256):
    """Zero-pad a 2-D feature matrix and fold it into ``(batch_size, seq_length, feature_dim)``.

    ``seq_length`` is ``ceil(n_samples / batch_size)``. Rows are laid out in
    C order, so ``result[b, s]`` is input row ``b * seq_length + s``. Rows past
    the end of the input are zeros.

    Args:
        features: Array-like of shape ``(n_samples, feature_dim)``.
        batch_size: Size of the leading output dimension (positive).
        feature_dim: Expected width of every row.

    Returns:
        ``numpy.ndarray`` of shape ``(batch_size, seq_length, feature_dim)``
        with the same dtype as ``features``.

    Raises:
        ValueError: ``batch_size`` is not positive, or ``features`` is not a
            2-D array with ``feature_dim`` columns (this includes ``None``).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    features = np.asarray(features)
    if features.ndim != 2 or features.shape[1] != feature_dim:
        raise ValueError(
            f"features must have shape (n_samples, {feature_dim}), got {features.shape}"
        )

    total_samples = features.shape[0]
    # Integer ceiling division; avoids a float round-trip.
    seq_length = -(-total_samples // batch_size)
    padding_needed = batch_size * seq_length - total_samples

    if padding_needed > 0:
        # Match the input dtype so padding never silently upcasts the result.
        pad = np.zeros((padding_needed, feature_dim), dtype=features.dtype)
        features = np.vstack([features, pad])

    return features.reshape(batch_size, seq_length, feature_dim)


In [None]:
batch_size, feature_dim = 64, 256

# Fold both per-row feature matrices into (batch_size, seq_length, feature_dim).
# The names are rebound in place, so this cell is meant to run once per extraction.
features_mutation, features_clinical = (
    reshape_with_padding(matrix, batch_size, feature_dim)
    for matrix in (features_mutation, features_clinical)
)


In [None]:
# Use the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Turn both padded arrays into float32 tensors on that device.
features_clinical = torch.tensor(features_clinical, dtype=torch.float32, device=device)
features_mutation = torch.tensor(features_mutation, dtype=torch.float32, device=device)

model = DynamicSequentialMultiheadCrossAttention(
    d_model=256,
    total_modalities=2,
    embed_dim=256,
    num_heads=8,
    batch_size=64,
)
model = model.to(device)

# Clinical features are the query; mutation features supply keys and values.
output, attn_weights = model(features_clinical, features_mutation)
print(output.shape)


torch.Size([64, 21, 256])


In [None]:
# Rebuild the fusion model with fresh random weights and rerun it on the same inputs.
# The inputs were moved to `device` earlier, so the module must follow them;
# left on the CPU it fails with a device mismatch whenever the inputs are on CUDA.
model = DynamicSequentialMultiheadCrossAttention(
    d_model=256, total_modalities=2, embed_dim=256, num_heads=8, batch_size=64
).to(features_clinical.device)
output, attn_weights = model(features_clinical, features_mutation)
print(output.shape)

torch.Size([64, 21, 256])


In [None]:
# Display the fused representation returned by the most recent forward pass.
output

tensor([[[ 0.9468,  1.3098,  2.2585,  ..., -0.7706, -0.7170,  0.8301],
         [ 0.4157,  0.4868,  0.6769,  ..., -0.2153, -0.3273,  0.5964],
         [ 0.3138, -0.1144,  1.2133,  ..., -0.5826, -1.0286,  0.3885],
         ...,
         [ 0.1701, -0.7020,  1.0758,  ...,  0.3533,  0.6397,  0.7606],
         [ 0.3512,  0.1266,  1.2712,  ..., -0.5711,  0.4935,  0.8700],
         [-0.1589,  1.0831,  1.0249,  ..., -1.7126, -1.8269,  0.7927]],

        [[ 0.6620, -0.0300,  0.5703,  ..., -0.1709, -0.9496,  0.6046],
         [ 0.3956,  0.5757, -0.2135,  ..., -1.4176, -0.2242,  0.2286],
         [ 0.6683,  0.1578,  0.7205,  ..., -1.2043, -0.8021,  0.5351],
         ...,
         [ 0.3216,  0.9024,  0.6244,  ..., -0.5191,  0.3778,  0.4298],
         [-1.1257,  0.4531, -0.5790,  ..., -0.4395,  0.0533,  1.4077],
         [-0.7021,  1.6748,  1.1322,  ..., -0.7760, -1.0045,  1.0170]],

        [[-0.4882,  1.4057,  0.6133,  ..., -0.4880,  0.4032,  0.9318],
         [ 0.3269,  1.1047,  1.3261,  ..., -0