In [11]:
#Twitter tennis
#Dynamic Graph Temporal Signals
import torch
from torch_geometric_temporal.dataset import TwitterTennisDatasetLoader
from torch_geometric_temporal.signal import temporal_signal_split

loader=TwitterTennisDatasetLoader()
dataset=loader.get_dataset()

In [12]:
#Split
train,test=temporal_signal_split(dataset,train_ratio=0.5)

In [13]:
#Model
#Gatted Recurrent Unit
from torch_geometric_temporal.nn.recurrent import GConvGRU
from torch.nn import Module,Linear

class DynamicTemporalNetwork(Module):
    def __init__(self,features,filters):
        super(DynamicTemporalNetwork,self).__init__()

        self.conv_gru=GConvGRU(in_channels=features,out_channels=filters,K=3,normalization='sym')
        self.linear=Linear(in_features=filters,out_features=1)
    
    def forward(self,x,edge_index,edge_weight):
        x=self.conv_gru(x,edge_index,edge_weight).relu()
        x=self.linear(x)

        return x


In [14]:
from torch.nn import MSELoss
from tqdm import tqdm

model=DynamicTemporalNetwork(features=16,filters=32)
optimizer=torch.optim.Adam(params=model.parameters(),lr=0.001,betas=(0.9,0.999))
loss=MSELoss()

In [15]:
#Training loop
def training_loop():
    model.train(True)

    NUM_EPOCHS=50

    for epoch in tqdm(range(NUM_EPOCHS)):
        train_loss=0
        test_loss=0
        #Train
        for t,snap in enumerate(train):
            op=model(snap.x,snap.edge_index,snap.edge_weight)
            snap_loss=loss(snap.y,op)

            #training
            model.zero_grad()

            snap_loss.backward()
            optimizer.step()

            #aggregate
            train_loss+=snap_loss.item()

        #Validation
        model.eval()
        
        with torch.no_grad():
            for t,snap in enumerate(test):
                op=model(snap.x,snap.edge_index,snap.edge_weight)
                snap_loss=loss(snap.y,op)

                test_loss+=snap_loss.item()
        
            print("Epoch: {ep}".format(ep=epoch+1))
            print("Train Loss: {l}".format(l=train_loss/(t+1)))
            print("Validation Loss: {l}".format(l=test_loss/(t+1)))


In [16]:
training_loop()

  0%|          | 0/50 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)
  2%|▏         | 1/50 [00:00<00:25,  1.91it/s]

Epoch: 1
Train Loss: 0.4590473581105471
Validation Loss: 0.42696451768279076


  4%|▍         | 2/50 [00:01<00:24,  1.99it/s]

Epoch: 2
Train Loss: 0.4497845354179541
Validation Loss: 0.42692563260595


  6%|▌         | 3/50 [00:01<00:22,  2.05it/s]

Epoch: 3
Train Loss: 0.4498294701178869
Validation Loss: 0.4269285942117373


  8%|▊         | 4/50 [00:01<00:22,  2.01it/s]

Epoch: 4
Train Loss: 0.44981456771492956
Validation Loss: 0.42692486693461734


 10%|█         | 5/50 [00:02<00:22,  2.01it/s]

Epoch: 5
Train Loss: 0.4498084420959155
Validation Loss: 0.42692410623033844


 12%|█▏        | 6/50 [00:02<00:21,  2.04it/s]

Epoch: 6
Train Loss: 0.4498031030098597
Validation Loss: 0.42692358965675037


 14%|█▍        | 7/50 [00:03<00:21,  2.05it/s]

Epoch: 7
Train Loss: 0.44979801326990126
Validation Loss: 0.4269236216942469


 16%|█▌        | 8/50 [00:03<00:20,  2.06it/s]

Epoch: 8
Train Loss: 0.4497935595611731
Validation Loss: 0.42692298789819083


 18%|█▊        | 9/50 [00:04<00:19,  2.06it/s]

Epoch: 9
Train Loss: 0.44978960951169333
Validation Loss: 0.4269228878120581


 20%|██        | 10/50 [00:04<00:20,  2.00it/s]

Epoch: 10
Train Loss: 0.4497856025894483
Validation Loss: 0.426922952880462


 22%|██▏       | 11/50 [00:05<00:19,  2.04it/s]

Epoch: 11
Train Loss: 0.44978179037570953
Validation Loss: 0.4269229936103026


 24%|██▍       | 12/50 [00:05<00:18,  2.07it/s]

Epoch: 12
Train Loss: 0.4497780196368694
Validation Loss: 0.42692462330063186


 26%|██▌       | 13/50 [00:06<00:18,  2.00it/s]

Epoch: 13
Train Loss: 0.44977316732207934
Validation Loss: 0.4269250974059105


 28%|██▊       | 14/50 [00:06<00:18,  1.95it/s]

Epoch: 14
Train Loss: 0.4497693292796612
Validation Loss: 0.4269255911310514


 30%|███       | 15/50 [00:07<00:18,  1.89it/s]

Epoch: 15
Train Loss: 0.4497656079630057
Validation Loss: 0.4269257163008054


 32%|███▏      | 16/50 [00:08<00:19,  1.79it/s]

Epoch: 16
Train Loss: 0.4497621198495229
Validation Loss: 0.42692580819129944


 34%|███▍      | 17/50 [00:08<00:18,  1.77it/s]

Epoch: 17
Train Loss: 0.449758863200744
Validation Loss: 0.4269264017542203


 36%|███▌      | 18/50 [00:09<00:18,  1.73it/s]

Epoch: 18
Train Loss: 0.44975540563464167
Validation Loss: 0.42692679117123283


 38%|███▊      | 19/50 [00:09<00:18,  1.70it/s]

Epoch: 19
Train Loss: 0.4497520369788011
Validation Loss: 0.4269268065690994


 40%|████      | 20/50 [00:10<00:17,  1.73it/s]

Epoch: 20
Train Loss: 0.4497487691541513
Validation Loss: 0.4269272704919179


 42%|████▏     | 21/50 [00:11<00:16,  1.73it/s]

Epoch: 21
Train Loss: 0.449745474755764
Validation Loss: 0.42692753449082377


 44%|████▍     | 22/50 [00:11<00:16,  1.73it/s]

Epoch: 22
Train Loss: 0.4497422578434149
Validation Loss: 0.42692767654856045


 46%|████▌     | 23/50 [00:12<00:15,  1.73it/s]

Epoch: 23
Train Loss: 0.4497390153507392
Validation Loss: 0.4269278099139531


 48%|████▊     | 24/50 [00:12<00:14,  1.74it/s]

Epoch: 24
Train Loss: 0.4497359690566858
Validation Loss: 0.42692849338054656


 50%|█████     | 25/50 [00:13<00:14,  1.77it/s]

Epoch: 25
Train Loss: 0.44973279734452565
Validation Loss: 0.42692885051170987


 52%|█████▏    | 26/50 [00:13<00:14,  1.70it/s]

Epoch: 26
Train Loss: 0.4497299772997697
Validation Loss: 0.426929317911466


 54%|█████▍    | 27/50 [00:14<00:13,  1.73it/s]

Epoch: 27
Train Loss: 0.449726922561725
Validation Loss: 0.42692950343092284


 56%|█████▌    | 28/50 [00:15<00:12,  1.77it/s]

Epoch: 28
Train Loss: 0.44972402478257817
Validation Loss: 0.4269308092693488


 58%|█████▊    | 29/50 [00:15<00:12,  1.73it/s]

Epoch: 29
Train Loss: 0.44972059652209284
Validation Loss: 0.4269289284944534


 60%|██████    | 30/50 [00:16<00:11,  1.75it/s]

Epoch: 30
Train Loss: 0.4497181748350461
Validation Loss: 0.42693145101269087


 62%|██████▏   | 31/50 [00:16<00:10,  1.77it/s]

Epoch: 31
Train Loss: 0.44971530164281526
Validation Loss: 0.4269318903485934


 64%|██████▍   | 32/50 [00:17<00:10,  1.78it/s]

Epoch: 32
Train Loss: 0.4497125377257665
Validation Loss: 0.42693183521429695


 66%|██████▌   | 33/50 [00:17<00:09,  1.79it/s]

Epoch: 33
Train Loss: 0.4497096414367358
Validation Loss: 0.42693237761656444


 68%|██████▊   | 34/50 [00:18<00:08,  1.80it/s]

Epoch: 34
Train Loss: 0.44970683952172597
Validation Loss: 0.4269324185947577


 70%|███████   | 35/50 [00:18<00:08,  1.82it/s]

Epoch: 35
Train Loss: 0.44970400954286255
Validation Loss: 0.42693285594383873


 72%|███████▏  | 36/50 [00:19<00:07,  1.81it/s]

Epoch: 36
Train Loss: 0.44970151484012605
Validation Loss: 0.42693426335851353


 74%|███████▍  | 37/50 [00:20<00:07,  1.81it/s]

Epoch: 37
Train Loss: 0.4496985716124376
Validation Loss: 0.42693470045924187


 76%|███████▌  | 38/50 [00:20<00:06,  1.71it/s]

Epoch: 38
Train Loss: 0.4496959278980891
Validation Loss: 0.4269341786702474


 78%|███████▊  | 39/50 [00:21<00:06,  1.64it/s]

Epoch: 39
Train Loss: 0.4496934769054254
Validation Loss: 0.4269348576664925


 80%|████████  | 40/50 [00:22<00:06,  1.60it/s]

Epoch: 40
Train Loss: 0.44969074601928394
Validation Loss: 0.42693425094087917


 82%|████████▏ | 41/50 [00:22<00:05,  1.65it/s]

Epoch: 41
Train Loss: 0.449688575665156
Validation Loss: 0.4269355279703935


 84%|████████▍ | 42/50 [00:23<00:04,  1.64it/s]

Epoch: 42
Train Loss: 0.4496859796345234
Validation Loss: 0.4269357070326805


 86%|████████▌ | 43/50 [00:23<00:04,  1.63it/s]

Epoch: 43
Train Loss: 0.4496836962799231
Validation Loss: 0.4269360880057017


 88%|████████▊ | 44/50 [00:24<00:03,  1.66it/s]

Epoch: 44
Train Loss: 0.44968143353859585
Validation Loss: 0.42693636839588484


 90%|█████████ | 45/50 [00:25<00:02,  1.67it/s]

Epoch: 45
Train Loss: 0.44967881043752034
Validation Loss: 0.4269363301495711


 92%|█████████▏| 46/50 [00:25<00:02,  1.71it/s]

Epoch: 46
Train Loss: 0.44967667708794273
Validation Loss: 0.4269370138645172


 94%|█████████▍| 47/50 [00:26<00:01,  1.70it/s]

Epoch: 47
Train Loss: 0.4496741962929567
Validation Loss: 0.426937056829532


 96%|█████████▌| 48/50 [00:26<00:01,  1.71it/s]

Epoch: 48
Train Loss: 0.44967175846298535
Validation Loss: 0.4269358863433202


 98%|█████████▊| 49/50 [00:27<00:00,  1.72it/s]

Epoch: 49
Train Loss: 0.4496698878705502
Validation Loss: 0.426937185972929


100%|██████████| 50/50 [00:27<00:00,  1.79it/s]

Epoch: 50
Train Loss: 0.44966766610741615
Validation Loss: 0.4269372555116812



