In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch_geometric_temporal.signal import temporal_signal_split
from torch_geometric_temporal.dataset import WikiMathsDatasetLoader
from torch_geometric_temporal.nn.recurrent import EvolveGCNH
from tqdm import tqdm

In [2]:
dataset = WikiMathsDatasetLoader().get_dataset()
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.75)

In [3]:
train_dataset[0]

Data(x=[1068, 8], edge_index=[2, 27079], edge_attr=[27079], y=[1068])

In [4]:
train_dataset[0].x

tensor([[-0.4323, -0.4739,  0.2659,  ...,  0.6412,  0.2179, -0.7617],
        [-0.4041, -0.4165, -0.0751,  ...,  0.4464, -0.3916, -0.8137],
        [-0.3892,  0.0634,  0.5913,  ...,  0.2776, -0.0724, -0.8116],
        ...,
        [-0.0232,  0.0580,  0.8699,  ...,  0.6263, -0.2668, -0.9975],
        [-0.9213, -0.3829,  0.7834,  ...,  0.6040,  0.4245, -0.2035],
        [-0.3530, -0.4638,  0.0311,  ...,  0.1012, -0.0982, -0.7666]])

In [5]:
train_dataset[0].y

tensor([-0.4067, -0.1620, -0.4043,  ..., -0.9163, -1.0110, -0.5007])

In [6]:
train_dataset[1]

Data(x=[1068, 8], edge_index=[2, 27079], edge_attr=[27079], y=[1068])

In [7]:
train_dataset[1].x

tensor([[-0.4739,  0.2659,  0.4844,  ...,  0.2179, -0.7617, -0.4067],
        [-0.4165, -0.0751,  0.1484,  ..., -0.3916, -0.8137, -0.1620],
        [ 0.0634,  0.5913,  0.5370,  ..., -0.0724, -0.8116, -0.4043],
        ...,
        [ 0.0580,  0.8699,  2.9809,  ..., -0.2668, -0.9975, -0.9163],
        [-0.3829,  0.7834,  0.7834,  ...,  0.4245, -0.2035, -1.0110],
        [-0.4638,  0.0311,  0.1012,  ..., -0.0982, -0.7666, -0.5007]])

In [8]:
train_dataset[1].y

tensor([0.3064, 0.3470, 0.7482,  ..., 2.2502, 1.8601, 0.1492])

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm


In [10]:

class TemporalGNN(nn.Module):
    def __init__(self, node_count, dim_in):
        super(TemporalGNN, self).__init__()
        self.recurrent = EvolveGCNH(node_count, dim_in)
        self.linear = nn.Linear(dim_in, 1)

    def forward(self, x, edge_index, edge_weight):
        h = self.recurrent(x, edge_index, edge_weight).relu()
        h = self.linear(h)
        return h

In [11]:
# Instantiate the forward model
model = TemporalGNN(node_count=dataset[0].x.shape[0], dim_in=dataset[0].x.shape[1])
optimizer = optim.Adam(model.parameters(), lr=0.01)
model.train()

TemporalGNN(
  (recurrent): EvolveGCNH(
    (pooling_layer): TopKPooling(8, ratio=0.00749063670411985, multiplier=1.0)
    (recurrent_layer): GRU(8, 8)
    (conv_layer): GCNConv_Fixed_W(8, 8)
  )
  (linear): Linear(in_features=8, out_features=1, bias=True)
)

In [12]:
# Train the forward model
for epoch in tqdm(range(50)):
    for snapshot in train_dataset:
        y_pred = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        loss = torch.mean((y_pred - snapshot.y) ** 2)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

100%|██████████| 50/50 [02:23<00:00,  2.88s/it]


In [13]:
# Evaluate the forward model
model.eval()
test_loss = 0
for i, snapshot in enumerate(test_dataset):
    y_pred = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
    mse = torch.mean((y_pred - snapshot.y) ** 2)
    test_loss += mse
test_loss = test_loss / (i + 1)
print(f'Test MSE = {test_loss.item():.4f}')

Test MSE = 0.9543


In [19]:
# Define the InverseTemporalGNN model
class InverseTemporalGNN:
    def __init__(self, forward_model, lr=0.01):
        self.forward_model = forward_model
        self.lr = lr

    def optimize_inputs(self, initial_inputs, edge_index, edge_weight, target_outputs, epochs=100):
        inputs = initial_inputs.clone().detach().requires_grad_(True)
        optimizer = optim.Adam([inputs], lr=self.lr)

        # Reshape the target_outputs to match the predictions' shape
        target_outputs = target_outputs.view(-1, 1)

        for epoch in tqdm(range(epochs)):
            optimizer.zero_grad()
            predictions = self.forward_model(inputs, edge_index, edge_weight)
            loss = nn.MSELoss()(predictions, target_outputs)
            loss.backward()
            optimizer.step()

        return inputs.detach()


In [20]:
# Instantiate the inverse model
inverse_model = InverseTemporalGNN(model, lr=0.01)

In [16]:
m = 0
for i in train_dataset:
    m+=1
print(m)

542


In [17]:
n = 0
for i in test_dataset:
    n+=1
print(n)

181


In [33]:
y_true = train_dataset[0].y
y_true

tensor([-0.4067, -0.1620, -0.4043,  ..., -0.9163, -1.0110, -0.5007])

In [31]:
y_pred = model(train_dataset[0].x, train_dataset[0].edge_index, train_dataset[0].edge_attr)
y_pred

tensor([[-0.6051],
        [-0.6694],
        [-1.1557],
        ...,
        [-0.2152],
        [-0.2614],
        [-0.3901]], grad_fn=<AddmmBackward0>)

In [34]:
torch.mean((y_pred-y_true)**2)

tensor(0.3063, grad_fn=<MeanBackward0>)

In [42]:
y_true_test = test_dataset[2].y
y_true_test

tensor([ 0.5772, -0.7393,  1.1736,  ..., -0.5104,  0.6040, -0.1093])

In [45]:
test_dataset[2].x

tensor([[ 0.9130,  0.6742, -0.2639,  ...,  0.8885,  1.2307,  1.2051],
        [-0.1123,  0.5643, -0.7331,  ..., -0.3047, -0.4537, -0.0751],
        [ 1.5990,  1.3818, -0.0030,  ...,  2.1964,  2.5675,  2.2205],
        ...,
        [ 0.0580, -0.1044, -0.7539,  ..., -0.2668, -0.1044, -0.3480],
        [ 0.1554,  0.7834,  0.6040,  ...,  2.3087,  1.2320,  1.0526],
        [ 0.1419, -0.0723, -0.5746,  ...,  0.3339,  0.4631,  0.2120]])

In [41]:
y_pred_test = model(test_dataset[2].x,test_dataset[2].edge_index,test_dataset[2].edge_attr)
y_pred_test

tensor([[ 0.1176],
        [ 0.0798],
        [ 0.3456],
        ...,
        [-0.2158],
        [-0.1412],
        [-0.1183]], grad_fn=<AddmmBackward0>)

In [43]:
torch.mean((y_true_test-y_pred_test)**2)

tensor(0.6232, grad_fn=<MeanBackward0>)

In [79]:
target_outputs = test_dataset[2].y
initial_inputs = torch.rand_like(test_dataset[2].x)
edge_index = test_dataset[2].edge_index
edge_weight = test_dataset[2].edge_weight

optimized_inputs = inverse_model.optimize_inputs(initial_inputs, edge_index, edge_weight, target_outputs, epochs=2000) 
print(f'initial_inputs :  {initial_inputs}')
print(f'Optimized Inputs: {optimized_inputs}')

100%|██████████| 2000/2000 [00:04<00:00, 489.97it/s]

initial_inputs :  tensor([[0.5913, 0.2369, 0.9110,  ..., 0.5659, 0.9896, 0.4009],
        [0.0859, 0.5484, 0.7577,  ..., 0.8816, 0.0766, 0.3686],
        [0.9923, 0.5065, 0.4042,  ..., 0.9693, 0.3701, 0.4639],
        ...,
        [0.7581, 0.4797, 0.0401,  ..., 0.4210, 0.5876, 0.1154],
        [0.4213, 0.7540, 0.5499,  ..., 0.4747, 0.1034, 0.6268],
        [0.7224, 0.5920, 0.5727,  ..., 0.5704, 0.6778, 0.7637]])
Optimized Inputs: tensor([[ 1.0794,  1.1733,  0.6859,  ..., -0.5140,  0.6223, -0.4099],
        [-0.4826, -0.1800,  1.0287,  ...,  1.1560,  0.7842,  0.0693],
        [-0.8906,  0.5577,  4.7070,  ...,  3.5420,  3.3562,  0.1525],
        ...,
        [-0.5538,  0.7734, -0.6664,  ...,  0.5363, -0.8809,  0.3537],
        [-0.2862,  1.7210,  0.7613,  ...,  0.2813, -0.9586, -0.1862],
        [-0.3406,  0.0762,  0.3883,  ..., -0.2808, -0.2929,  0.4687]])





In [80]:
y_true

tensor([-0.4067, -0.1620, -0.4043,  ..., -0.9163, -1.0110, -0.5007])

In [81]:
y_pred_initial_test = model(initial_inputs,test_dataset[2].edge_index,test_dataset[2].edge_attr)
y_pred_initial_test

tensor([[-0.0254],
        [ 0.0108],
        [ 0.1660],
        ...,
        [-0.1626],
        [-0.0573],
        [-0.0436]], grad_fn=<AddmmBackward0>)

In [82]:
y_pred_optim_test = model(optimized_inputs,test_dataset[2].edge_index,test_dataset[2].edge_attr)
y_pred_optim_test

tensor([[ 0.6078],
        [ 0.0906],
        [ 0.3956],
        ...,
        [-0.3328],
        [ 0.8006],
        [-0.1194]], grad_fn=<AddmmBackward0>)

In [83]:
torch.mean((y_true - y_pred_optim_test)**2)

tensor(0.9975, grad_fn=<MeanBackward0>)

In [84]:
torch.mean((y_true - y_pred_initial_test)**2)

tensor(0.4574, grad_fn=<MeanBackward0>)

In [22]:
# Example usage of the inverse model
for snapshot in test_dataset:
    target_outputs = snapshot.y
    initial_inputs = torch.randn_like(snapshot.x)  # Initialize with small random values
    edge_index = snapshot.edge_index
    edge_weight = snapshot.edge_attr

    optimized_inputs = inverse_model.optimize_inputs(initial_inputs, edge_index, edge_weight, target_outputs, epochs=100)
    print(f'initial_inputs :  {initial_inputs}')
    print(f'Optimized Inputs: {optimized_inputs}')

100%|██████████| 100/100 [00:00<00:00, 406.70it/s]


initial_inputs :  tensor([[-1.7819e+00, -8.4877e-01, -8.3737e-02,  ..., -1.2228e+00,
         -6.4525e-02, -8.9364e-01],
        [-1.1322e+00,  8.5030e-01,  1.5404e+00,  ..., -1.3122e-01,
          1.1210e+00,  1.5799e-01],
        [-1.3442e-01,  1.2540e+00,  7.0148e-01,  ...,  9.2039e-01,
         -1.0771e+00,  1.6157e-01],
        ...,
        [ 1.6828e-01,  3.1733e-01,  5.3488e-01,  ...,  6.1489e-01,
          9.1285e-02,  6.7807e-01],
        [ 1.0924e+00,  4.8166e-01, -4.7915e-01,  ..., -1.4607e-01,
         -6.3834e-01, -4.4304e-01],
        [-1.4466e+00, -1.2917e-03, -7.8040e-01,  ..., -5.3888e-01,
         -6.8017e-01, -9.6322e-01]])
Optimized Inputs: tensor([[-2.3742, -0.0892,  0.0226,  ..., -2.0512, -0.7220, -0.3116],
        [-1.7331,  1.7145,  1.6312,  ..., -0.7509,  0.5382,  0.7129],
        [-0.7274,  2.0600,  0.6546,  ...,  0.2650, -1.7136,  0.7555],
        ...,
        [-0.8076,  1.2067,  1.3044,  ...,  1.0535,  0.7309,  1.5071],
        [ 0.0453,  1.4327,  0.3633,  ..

100%|██████████| 100/100 [00:00<00:00, 387.71it/s]


initial_inputs :  tensor([[-1.8203, -0.8891, -1.0472,  ...,  0.7187, -0.8608,  1.4376],
        [ 0.5687, -1.3044, -0.4427,  ..., -0.1589, -0.1065, -1.2774],
        [ 1.6476, -0.9147,  0.9241,  ...,  0.1245, -0.1792, -0.3054],
        ...,
        [-1.1385,  0.5700,  0.8394,  ...,  2.0902, -0.2614,  0.8643],
        [ 0.5293, -0.7613, -0.5302,  ..., -1.9009, -0.3454, -0.4922],
        [ 0.5020,  0.3953,  1.1336,  ..., -1.3184,  0.0086, -0.1347]])
Optimized Inputs: tensor([[-2.3187, -0.3942, -0.6218,  ...,  0.1919, -0.9141,  1.8947],
        [ 0.1112, -0.8441, -0.0205,  ..., -0.5400, -0.3385, -0.8166],
        [ 1.0835, -0.3692,  1.3711,  ..., -0.5142, -0.5288,  0.2205],
        ...,
        [-1.3634,  0.9430,  1.2191,  ...,  1.9542,  0.0170,  1.2883],
        [-0.4686,  0.0762,  0.2985,  ..., -2.4469, -0.2400,  0.3362],
        [ 0.1414,  0.7864,  1.3887,  ..., -0.6159, -0.2550,  0.2761]])


100%|██████████| 100/100 [00:00<00:00, 422.65it/s]


initial_inputs :  tensor([[-1.0069, -0.0998, -0.8662,  ..., -0.0183, -0.3008, -1.1260],
        [ 1.7833,  0.4633, -0.0275,  ...,  0.0586, -0.1113,  0.1466],
        [ 0.8730, -0.9788, -1.6347,  ..., -0.3973,  1.3741,  0.8343],
        ...,
        [-0.1733, -0.2425,  1.9155,  ...,  1.1309, -0.6051,  0.0883],
        [ 0.2414,  0.5669, -0.0931,  ..., -0.2054, -0.1310,  0.2230],
        [ 0.4754, -0.5049,  0.2204,  ...,  1.3215,  1.4774,  0.1721]])
Optimized Inputs: tensor([[-1.5597,  0.4197, -0.5394,  ..., -0.0042, -0.7592, -0.7020],
        [ 1.3516,  0.9154,  0.3482,  ..., -0.2415, -0.5149,  0.4858],
        [ 0.2369, -0.4431, -1.3374,  ..., -0.5240,  0.8258,  1.3222],
        ...,
        [ 0.2684, -0.8221,  1.1492,  ...,  1.5058,  0.2798, -0.6462],
        [-0.6654,  1.5254,  0.8380,  ..., -0.3579,  0.3524,  1.1418],
        [ 0.2470, -0.3107,  0.4600,  ...,  0.9580,  1.2466,  0.4060]])


100%|██████████| 100/100 [00:00<00:00, 427.77it/s]


initial_inputs :  tensor([[-0.5651,  0.6357, -0.2353,  ..., -0.6774, -1.4530,  0.5332],
        [-0.6571,  0.3525,  1.3900,  ..., -0.8270, -0.5448,  0.4703],
        [ 0.2877, -1.4185, -1.2331,  ..., -0.5954, -1.6798, -0.8611],
        ...,
        [ 0.4950,  1.3204, -0.3998,  ...,  0.7332, -0.6343, -1.0405],
        [ 2.5235, -0.6522,  1.1166,  ..., -0.4120, -0.7443, -0.2228],
        [ 1.6995, -0.9441, -0.6766,  ..., -2.6963, -0.5515, -0.4392]])
Optimized Inputs: tensor([[-0.0289,  0.1326, -1.0134,  ..., -0.2616, -0.8996, -0.0822],
        [-0.2305, -0.1244,  0.6554,  ..., -0.2782, -0.0557, -0.0261],
        [ 0.3331, -1.5486, -0.9447,  ..., -0.2328, -1.5046, -0.8269],
        ...,
        [ 1.2452,  0.7184, -1.0068,  ...,  0.1405,  0.0700, -1.7979],
        [ 2.5090, -0.7586,  1.0950,  ..., -0.1040, -0.6834, -0.2346],
        [ 1.6947, -1.0652, -0.8614,  ..., -2.3859, -0.4777, -0.5138]])


100%|██████████| 100/100 [00:00<00:00, 418.20it/s]


initial_inputs :  tensor([[ 0.4372,  0.5690,  0.2983,  ..., -0.3450,  0.4513,  0.1294],
        [-1.4398,  0.8932, -0.0105,  ...,  0.3506,  2.1400,  0.9133],
        [ 0.3330, -0.8129, -1.6310,  ...,  1.8416,  0.1824, -1.3265],
        ...,
        [ 0.7819,  0.4139,  1.8826,  ..., -1.9787,  0.6523, -1.2700],
        [-0.4204, -0.3247, -1.5765,  ..., -0.5057, -0.3565,  0.4269],
        [ 1.1008,  0.7272, -0.7923,  ..., -0.8752, -0.0240, -0.1018]])
Optimized Inputs: tensor([[ 0.7073,  0.3607, -0.0626,  ..., -0.5717,  0.6643, -0.1686],
        [-1.0017,  0.4154, -0.3904,  ..., -0.0894,  2.5624,  0.3539],
        [-0.0571, -0.5506, -0.8546,  ...,  1.8553, -0.2660, -0.9667],
        ...,
        [ 0.8614,  0.1609,  1.6621,  ..., -1.6555,  0.6671, -1.3977],
        [-1.3470,  0.4244, -0.7781,  ..., -0.0852, -1.2740,  1.3781],
        [ 1.6838,  0.1546, -1.2563,  ..., -1.1094,  0.5681, -0.6715]])


100%|██████████| 100/100 [00:00<00:00, 428.93it/s]


initial_inputs :  tensor([[ 0.3405,  0.1450, -0.7867,  ...,  0.5504,  0.8432, -1.8363],
        [-0.6062, -0.0463,  1.4582,  ...,  0.3437, -0.5503, -0.0655],
        [ 1.1803,  1.9108, -1.7447,  ...,  0.7691, -1.0671, -0.0679],
        ...,
        [-0.3618,  0.2769, -0.9969,  ..., -0.4325, -1.2916, -0.6721],
        [-0.4209, -1.3855,  0.8994,  ...,  1.9339,  1.3508, -1.0670],
        [-1.0093,  0.3118, -0.4987,  ..., -0.0705, -1.4551, -1.0362]])
Optimized Inputs: tensor([[-0.3422,  0.7502, -0.1709,  ...,  1.0087,  0.3806, -1.2977],
        [-1.0894,  0.4030,  1.9284,  ...,  0.6691, -0.9126,  0.3442],
        [ 0.6296,  2.3971, -1.1821,  ...,  0.9482, -1.3935,  0.3442],
        ...,
        [-0.6223, -0.2670, -1.3403,  ...,  0.2021, -1.0291, -0.6486],
        [-1.1512, -0.6351,  1.8050,  ...,  1.5240,  0.9775, -0.5237],
        [-1.4655,  0.7535, -0.6587,  ..., -0.2820, -1.7974, -0.7284]])


100%|██████████| 100/100 [00:00<00:00, 424.44it/s]


initial_inputs :  tensor([[ 0.0910,  2.4719, -2.2135,  ...,  0.2131,  0.2249,  0.1740],
        [-1.1900,  0.0799, -0.2907,  ...,  1.3856,  0.8639,  0.7088],
        [ 0.0490, -0.5204,  0.7313,  ..., -0.9454,  1.4577,  0.2477],
        ...,
        [-0.3059,  0.7899,  0.6260,  ...,  1.4656, -1.0139,  1.1480],
        [ 0.3264,  0.2472,  0.7391,  ..., -0.5449,  0.5184, -2.1213],
        [-0.7110, -0.0109, -1.1519,  ..., -0.4448,  0.5476, -0.7143]])
Optimized Inputs: tensor([[-0.1334,  2.6078, -2.6046,  ...,  0.1418, -0.0892,  0.4561],
        [-1.2810,  0.1650, -0.1202,  ...,  1.3501,  0.6711,  0.9094],
        [-0.3221, -0.2510,  0.5844,  ..., -1.1667,  1.0369,  0.6702],
        ...,
        [-0.9093,  1.3947,  1.1780,  ...,  0.7520, -1.3822,  1.7079],
        [-0.4737,  1.0477,  1.4543,  ..., -1.4726,  0.0423, -1.4272],
        [-1.2486,  0.5367, -0.6543,  ..., -0.9398,  0.0177, -0.1895]])


 66%|██████▌   | 66/100 [00:00<00:00, 410.01it/s]


KeyboardInterrupt: 

In [None]:
test_dataset[0].y.shape

In [None]:
test_dataset[0].x.shape

In [None]:
torch.zeros_like(test_dataset[0].x)

In [None]:
torch.zeros_like(test_dataset[0].y)