In [13]:
import torch.nn as nn
import torch
from torch_geometric.utils import to_dense_batch
import einops

class VanillaAttention(nn.Module):
    """Multi-head dot-product self-attention applied independently to each graph of a batch.

    The flat node tensor is packed into a dense, zero-padded tensor with
    ``to_dense_batch``; attention scores that involve a padding position are
    masked out, and the result is returned in the original flat layout
    (one row per real node).

    Args:
        d_model: Feature width of the input and of the output.
        n_heads: Number of attention heads; must evenly divide ``d_model``.
        dropout: Drop probability used both on the attention weights and on
            the returned tensor.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, n_heads, dropout):
        super().__init__()
        # Fail early with a clear message instead of a late einops shape error in forward().
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive n_heads ({n_heads})"
            )
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.n_heads = n_heads
        # 1 / sqrt(per-head width)
        self.scaling = (d_model // n_heads) ** (-0.5)
        self.final_linear = nn.Linear(d_model, d_model)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index=None, edge_attr=None, parenthood=None, batch=None):
        """Compute self-attention among the nodes that share a graph.

        Args:
            x: Flat node features, one row per node, ``d_model`` columns.
            edge_index, edge_attr, parenthood: Not used by this layer
                (presumably accepted so it can be called like the graph layers).
            batch: Graph id of every node, passed to ``to_dense_batch``.

        Returns:
            Tensor with one row per input node and ``d_model`` columns.
        """
        # x -> [batch_size, num_nodes, d_model]; mask marks the real (non-padding) positions.
        x, mask = to_dense_batch(x, batch)
        # A score (i, j) is usable only when both i and j are real nodes.
        valid_pairs = (mask.unsqueeze(1) & mask.unsqueeze(2)).unsqueeze(1)  # head axis added

        keys = einops.rearrange(self.W_K(x), "b n (h x) -> b n h x", h=self.n_heads)
        queries = einops.rearrange(self.W_Q(x), "b n (h x) -> b n h x", h=self.n_heads)
        values = einops.rearrange(self.W_V(x), "b n (h x) -> b n h x", h=self.n_heads)

        # NOTE: the output index i comes from the W_K projection and the softmax runs over j,
        # which comes from the W_Q projection; i.e. W_K plays the usual "query" role and W_Q
        # the usual "key" role. Left as is so existing weights keep their meaning.
        attn_coefficients = torch.einsum("bihx,bjhx->bhij", keys, queries) * self.scaling
        # Fill with the dtype's most negative finite value rather than adding -1e9: the latter
        # becomes -inf in half precision and turns fully padded rows into NaN after softmax.
        attn_coefficients = attn_coefficients.masked_fill(
            ~valid_pairs, torch.finfo(attn_coefficients.dtype).min
        )
        softmaxed_attn_coefficients = torch.softmax(attn_coefficients, dim=-1)
        softmaxed_attn_coefficients = self.attn_dropout(softmaxed_attn_coefficients)

        computed_values = torch.einsum("bhij,bjhx->bihx", softmaxed_attn_coefficients, values)
        concatenated_values = einops.rearrange(computed_values, "b i h x -> b i (h x)").contiguous()
        out_dense = self.final_linear(concatenated_values)
        # Drop the padding rows again: back to one row per real node.
        out = out_dense[mask]
        return self.resid_dropout(out)
    
# Smoke test: attention must not leak information between the graphs of a batch.
att = VanillaAttention(d_model=8, n_heads=2, dropout=0.0)
x = torch.randn(5, 8)  # 5 nodes with 8 features each
batch = torch.tensor([0, 0, 0, 1, 1])  # nodes 0-2 belong to graph 0, nodes 3-4 to graph 1
test_1 = att(x=x, batch=batch)
x[0] = 1  # perturb a single node of graph 0
test_2 = att(x=x, batch=batch)

assert test_1.shape == x.shape
# Graph 1 never saw the perturbed node, so its rows must be unchanged ...
assert torch.allclose(test_1[3:], test_2[3:]), "graph 1 output changed when only graph 0 was perturbed"
# ... while graph 0 must react, otherwise the check above would be vacuous.
assert not torch.allclose(test_1[:3], test_2[:3]), "graph 0 output did not react to the perturbation"

In [2]:
from transformer_gcn import TransformerGCN

In [3]:
import pickle
from torch_geometric.loader import DataLoader
# Location of the pickled dataset (absolute, machine-specific path).
dataset_path = "/media/enzo/Stockage/Output_general/dataset_3/dataset.pkl"
# pickle.load can execute arbitrary code: only use it on files you produced yourself.
with open(dataset_path, "rb") as f:
    full_dataset = pickle.load(f)

# Deterministic order, two graphs per batch; keep the first batch for the experiments below.
dataloader = DataLoader(full_dataset, shuffle=False, batch_size=2)
item = next(iter(dataloader))

In [4]:
# Build the model under investigation. d_model=32, n_heads=2 and the 2x MLP expansion
# match the layer sizes in the module structure printed for transformer_blocks[0] below.
transformer = TransformerGCN(node_in_features=6,
                             d_model=32,
                             n_heads=2,
                             mlp_expansion_factor=2,
                             n_blocks=2,
                             aggr="add")

In [5]:
import torch
# Adam over every parameter of `transformer`, learning rate 1e-3.
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.001)

In [6]:
# Unpack the first dataloader batch and run one forward pass through the model.
node_store = item["node"]
gene_edges = item["node", "sends_gene_to", "node"]

x, batch = node_store.x, node_store.batch
edge_index, edge_attr = gene_edges.edge_index, gene_edges.edge_attr
parenthood = item["node", "is_parent_of", "node"].edge_index

result = transformer(x, edge_index, edge_attr, parenthood, batch)
# Scalar stand-in for a loss: the norm of the model output (backpropagated in a later cell).
l = result.norm()

In [7]:
# Load a pickled object into `batch` (absolute, machine-specific path).
# pickle.load can execute arbitrary code: only use it on trusted files.
# Note: `batch` is rebound to batch_node.batch in the In [11] cell without being used in between.
with open("/home/enzo/Documents/git/WP1/DeepGhosts/experiments/batch.pkl", "rb") as f:
    batch = pickle.load(f)

In [8]:
# Load the pickled node store; its .x and .batch attributes are read in the In [11] cell.
# pickle.load can execute arbitrary code: only use it on trusted files.
with open("/home/enzo/Documents/git/WP1/DeepGhosts/experiments/batch_node.pkl", "rb") as f:
    batch_node = pickle.load(f)

In [9]:
# Load the pickled edge store that provides edge_index and edge_attr in the In [11] cell.
# pickle.load can execute arbitrary code: only use it on trusted files.
with open("/home/enzo/Documents/git/WP1/DeepGhosts/experiments/batch_sgt.pkl", "rb") as f:
    batch_sgt = pickle.load(f)

In [10]:
# Load the pickled edge store whose edge_index is used as `parenthood` in the In [11] cell.
# pickle.load can execute arbitrary code: only use it on trusted files.
with open("/home/enzo/Documents/git/WP1/DeepGhosts/experiments/batch_isp.pkl", "rb") as f:
    batch_isp = pickle.load(f)

In [11]:
# Rebuild the model inputs from the pickled stores and run a forward pass.
# `batch` is rebound here, replacing the object loaded from batch.pkl.
x, batch = batch_node.x, batch_node.batch
edge_index, edge_attr = batch_sgt.edge_index, batch_sgt.edge_attr
parenthood = batch_isp.edge_index

result = transformer(x, edge_index, edge_attr, parenthood, batch)

In [12]:
# Run only the model's `embedding` submodule; `emb` is the input fed to the individual
# sub-layers of the first block in the cells below.
emb = transformer.embedding(x, edge_index, edge_attr, parenthood, batch)

In [13]:
# Display the module structure of the first block.
transformer.transformer_blocks[0]

TransformerMPNN(
  (transformer_block): VanillaTransformerBlock(
    (attention): VanillaAttention(
      (W_K): Linear(in_features=32, out_features=32, bias=False)
      (W_Q): Linear(in_features=32, out_features=32, bias=False)
      (W_V): Linear(in_features=32, out_features=32, bias=False)
      (final_linear): Linear(in_features=32, out_features=32, bias=True)
    )
    (mlp): Sequential(
      (0): Linear(in_features=32, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=32, bias=True)
    )
    (layer_norm_1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (layer_norm_2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  )
  (message_passing): GCN(
    (layer): GCNConv(32, 32)
  )
  (aggregation): Linear(in_features=64, out_features=32, bias=True)
)

In [14]:
# Sub-layer check 1: attention alone on `emb`. The recorded norm is finite.
transformer.transformer_blocks[0].transformer_block.attention(emb, batch=batch).norm()

tensor(5.6674, grad_fn=<LinalgVectorNormBackward0>)

In [15]:
# Sub-layer check 2: the whole VanillaTransformerBlock on `emb`. The recorded output shows no NaN.
transformer.transformer_blocks[0].transformer_block(emb, batch=batch)

tensor([[ 1.8480,  1.1060, -1.2063,  ..., -1.1965,  0.4992,  0.5279],
        [ 1.2813,  0.7959, -0.1550,  ..., -1.0034, -0.2451, -0.8826],
        [ 1.3536,  1.1594, -0.0518,  ..., -1.2289, -0.2262, -0.7620],
        ...,
        [ 0.5683, -0.4899, -0.4450,  ...,  0.4125,  0.4441, -0.3414],
        [ 0.5261, -0.4859, -0.3067,  ...,  0.3840,  0.3767, -0.4991],
        [ 0.5552, -0.4558, -0.3919,  ...,  0.3860,  0.4381, -0.4166]],
       grad_fn=<AddBackward0>)

In [16]:
# Grab the GCNConv layer wrapped by the first block's message-passing module.
gcnconv = transformer.transformer_blocks[0].message_passing.layer

In [17]:
# Swap the two rows of edge_index, i.e. reverse every edge
# (under PyG's [source; target] row layout).
edge_index_inv = edge_index[[1,0]]

In [18]:
# Display the reversed edge index.
edge_index_inv

tensor([[  0,   0,   0,  ..., 108, 108, 118],
        [ 15,  24,  25,  ...,  84, 105,  40]])

In [19]:
import torch_geometric
import torch
# Record the library versions used for this investigation: torch is printed,
# torch_geometric is shown as the cell's value.
print(torch.__version__)
torch_geometric.__version__

2.5.1


'2.6.1'

In [20]:
# Display the GCNConv layer (32 -> 32 channels in the recorded output).
gcnconv

GCNConv(32, 32)

In [22]:
# Call GCNConv directly on the reversed edges; edge_attr is passed as the third positional
# argument (GCNConv's edge weight). Only the shape is inspected here, which cannot reveal
# NaN values in the result.
gcnconv(emb, edge_index_inv, edge_attr).shape

torch.Size([119, 32])

In [30]:
# Sub-layer check 3: the block's message-passing wrapper on `emb`. The recorded norm is NaN,
# so this is where the NaNs enter.
# NOTE(review): edge_attr holds negative values (see its displayed output). If the wrapper
# forwards edge_attr to GCNConv as the edge weight, the degree normalisation (deg ** -0.5)
# would yield NaN for nodes whose summed weight is negative. Presumably the cause; confirm.
transformer.transformer_blocks[0].message_passing(emb, edge_index, edge_attr).norm()

tensor(nan, grad_fn=<LinalgVectorNormBackward0>)

In [23]:
# Display the edge attributes: one column per edge; the visible values are all negative.
edge_attr

tensor([[-1.0646],
        [-1.0530],
        [-1.0879],
        ...,
        [-0.1701],
        [-0.1701],
        [-0.1584]])

In [24]:
# Full first block (transformer block + message passing + aggregation) on `emb`.
# In the recorded output some rows are entirely NaN while others are finite.
transformer.transformer_blocks[0](emb, edge_index, edge_attr, parenthood, batch)

tensor([[    nan,     nan,     nan,  ...,     nan,     nan,     nan],
        [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
        [    nan,     nan,     nan,  ...,     nan,     nan,     nan],
        ...,
        [-1.3901,  0.0060,  0.4786,  ...,  0.0300,  0.6529,  0.2032],
        [-1.2956,  0.0104,  0.4279,  ..., -0.1605,  0.4879,  0.1784],
        [    nan,     nan,     nan,  ...,     nan,     nan,     nan]],
       grad_fn=<AddBackward0>)

In [None]:
# Full forward pass. The model object in this notebook is named `transformer`;
# `model` is not defined in any visible cell and would raise NameError on a fresh kernel.
result = transformer(x, edge_index, edge_attr, parenthood, batch)


In [10]:
# Backpropagate the scalar `l` (the norm of the model output computed in an earlier cell).
l.backward()

In [12]:
# Apply one Adam update with the gradients produced by l.backward().
# Gradients are never zeroed in the visible cells, so re-running backward accumulates them.
optimizer.step()

In [14]:
# see parameter gradients
# For every parameter, print its name, the norm of its gradient and its own norm.
for name, param in transformer.named_parameters():
    # A parameter that took no part in the last backward pass has grad None;
    # report that instead of crashing on `None.norm()`.
    grad_norm = param.grad.norm() if param.grad is not None else None
    print(name, grad_norm, param.norm())

embedding.node_embedding.weight tensor(616.2884) tensor(3.2843, grad_fn=<LinalgVectorNormBackward0>)
embedding.node_embedding.bias tensor(23.7653) tensor(1.5283, grad_fn=<LinalgVectorNormBackward0>)
transformer_blocks.0.aggregation.weight tensor(1162.4742) tensor(3.2305, grad_fn=<LinalgVectorNormBackward0>)
transformer_blocks.0.aggregation.bias tensor(19.6875) tensor(0.4217, grad_fn=<LinalgVectorNormBackward0>)
transformer_blocks.0.transformer_block.attention.W_K.weight tensor(0.9841) tensor(3.2997, grad_fn=<LinalgVectorNormBackward0>)
transformer_blocks.0.transformer_block.attention.W_Q.weight tensor(0.8953) tensor(3.3082, grad_fn=<LinalgVectorNormBackward0>)
transformer_blocks.0.transformer_block.attention.W_V.weight tensor(27.8781) tensor(3.2212, grad_fn=<LinalgVectorNormBackward0>)
transformer_blocks.0.transformer_block.attention.final_linear.weight tensor(25.6120) tensor(3.2518, grad_fn=<LinalgVectorNormBackward0>)
transformer_blocks.0.transformer_block.attention.final_linear.bias