In [None]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m23.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1


In [None]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GATConv
import pandas as pd

In [None]:
df_G = pd.read_csv('network.csv', header = None)
df_G.columns = ['source', 'target', 'weight']
df_G.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 834421 entries, 0 to 834420
Data columns (total 3 columns):
 #   Column  Non-Null Count   Dtype  
---  ------  --------------   -----  
 0   source  834421 non-null  int64  
 1   target  834421 non-null  int64  
 2   weight  834421 non-null  float64
dtypes: float64(1), int64(2)
memory usage: 19.1 MB


In [None]:
df = df_G.head()
df

Unnamed: 0,source,target,weight
0,39364684,21061006,0.45
1,39364684,18513522,0.85
2,39364684,38251731,1.15
3,39364684,22369434,1.2
4,39364684,98928660,1.4


In [None]:
unique_nodes = pd.unique(df[['source', 'target']].values.ravel())  # Get unique node IDs
node_map = {node: idx for idx, node in enumerate(unique_nodes)}
num_nodes = len(node_map)

In [None]:
edge_index = torch.tensor([[node_map[src] for src in df["source"]],
                           [node_map[tgt] for tgt in df["target"]]], dtype=torch.long)
edge_index

tensor([[0, 0, 0, 0, 0],
        [1, 2, 3, 4, 5]])

In [None]:
edge_weight = torch.tensor(df["weight"].values, dtype=torch.float)
edge_weight

tensor([0.4500, 0.8500, 1.1500, 1.2000, 1.4000])

In [None]:
data = Data(x = torch.ones((num_nodes, 2)), edge_index=edge_index, edge_attr=edge_weight)

In [None]:
class GAT(torch.nn.Module):
    def __init__(self, in_channels, out_channels, heads=1):
        super(GAT, self).__init__()
        self.gat = GATConv(in_channels, out_channels, heads=heads, concat=False)

    def forward(self, x, edge_index, edge_attr):
        x, attention_weights = self.gat(x, edge_index, return_attention_weights=True)
        return x, attention_weights

In [None]:
model = GAT(in_channels=2, out_channels=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [None]:
for epoch in range(20):
    model.train()
    optimizer.zero_grad()
    node_embeddings, attention_weights = model(data.x, data.edge_index, data.edge_attr)
    loss = torch.mean(node_embeddings)
    loss.backward()
    optimizer.step()

    print(f"Epoch {epoch + 1}, Loss: {loss.item()}")

print("\nNode embeddings:")
print(node_embeddings)
print("\nAttention weights:")
print(attention_weights)

Epoch 1, Loss: -0.3092934787273407
Epoch 2, Loss: -0.33929339051246643
Epoch 3, Loss: -0.3692934215068817
Epoch 4, Loss: -0.3992933928966522
Epoch 5, Loss: -0.42929336428642273
Epoch 6, Loss: -0.4592933654785156
Epoch 7, Loss: -0.48929330706596375
Epoch 8, Loss: -0.5192933082580566
Epoch 9, Loss: -0.5492933392524719
Epoch 10, Loss: -0.5792933106422424
Epoch 11, Loss: -0.6092932820320129
Epoch 12, Loss: -0.6392932534217834
Epoch 13, Loss: -0.669293224811554
Epoch 14, Loss: -0.6992931365966797
Epoch 15, Loss: -0.729293167591095
Epoch 16, Loss: -0.759293258190155
Epoch 17, Loss: -0.789293110370636
Epoch 18, Loss: -0.8192930817604065
Epoch 19, Loss: -0.8492931723594666
Epoch 20, Loss: -0.8792931437492371

Node embeddings:
tensor([[-0.2604, -1.4982],
        [-0.2604, -1.4982],
        [-0.2604, -1.4982],
        [-0.2604, -1.4982],
        [-0.2604, -1.4982],
        [-0.2604, -1.4982]], grad_fn=<AddBackward0>)

Attention weights:
(tensor([[0, 0, 0, 0, 0, 0, 1, 2, 3, 4, 5],
        [1, 2, 

In [None]:
df_user = pd.read_csv('user_dataset.csv')
df_user.head()

Unnamed: 0,user_id,user_rt,num_post,user_time_rt
0,100000075,1,0,116.82
1,100001275,1,0,20.67
2,1000027712,1,0,65.05
3,100003573,1,0,210.23
4,100003814,2,0,470.455
