In [46]:
import pandas as pd
import numpy as np
import string
import random
import torch
import collections
import dgl
import tqdm
import os
os.environ["DATASET_DIR"] = "/home/rustambaku13/Documents/Warwick/d3-gnn/datasets"

In [13]:
class TwoLayerGraphSAGE(torch.nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.conv1 = dgl.nn.SAGEConv(in_features, hidden_features, aggregator_type="mean")
        self.conv2 = dgl.nn.SAGEConv(hidden_features, out_features, aggregator_type="mean")
        self.ff = torch.nn.Sequential(
            torch.nn.Linear(out_features, out_features),
            torch.nn.ReLU(),
            torch.nn.Linear(out_features, out_features),
            torch.nn.ReLU(),
            torch.nn.Linear(out_features, out_features)
        )
    def forward(self, g, x):
        x = self.conv1(g[0], x)
        x = self.conv2(g[1], x)
        return torch.nn.functional.softmax(self.ff(x), dim=1)


# Get OGB-Products

In [37]:
def get_ogb(nedge=20000000):
    topology = pd.read_csv(os.path.join(os.environ["DATASET_DIR"], "ogb-products","edges.csv"), header=None, nrows=nedge)
    features = pd.read_csv(os.path.join(os.environ["DATASET_DIR"], "ogb-products","node_features.csv"), header=None, index_col=0)
    labels = pd.read_csv(os.path.join(os.environ["DATASET_DIR"], "ogb-products","node_labels.csv"), header=None, index_col=0)
    graph = dgl.graph((topology[0].values, topology[1].values))
    graph.ndata['f'] = torch.tensor(features.values, dtype=torch.float32)
    graph.ndata['l'] = torch.tensor(labels.values[:,0], dtype=torch.int64)
    return graph

def save_ogb_model(model):
    os.mkdir(os.path.join(os.environ["DATASET_DIR"], "ogb-products","graphSage"))
    for i, weight in enumerate(model.parameters()):
        np.save(os.path.join(os.environ["DATASET_DIR"], "ogb-products","graphSage",str(i)), weight.data)
    
def load_ogb_model(model):
    for i, j in enumerate(model.parameters()):
        d = np.load(os.path.join(os.environ["DATASET_DIR"], "ogb-products","graphSage",str(i) + '.npy'))
        j.data = torch.tensor(np.load(os.path.join(os.environ["DATASET_DIR"], "ogb-products","graphSage",str(i) + '.npy')), requires_grad=True)
        
    

In [38]:
g = get_ogb()

## Train

In [39]:
# model = TwoLayerGraphSAGE(100, 64, 47)
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
dataloader = dgl.dataloading.DataLoader(
    g, g.nodes(), sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=4)
opt = torch.optim.Adam(model.parameters())
loss = torch.nn.CrossEntropyLoss()

In [40]:
for epoch in tqdm.tqdm(range(60)):
    corrects = 0
    total = 0
    for input_nodes, output_nodes, blocks in dataloader:
        input_features = blocks[0].srcdata['f']
        output_labels = blocks[-1].dstdata['l']
        output_predictions = model(blocks, input_features)
        loss_value = loss(output_predictions, output_labels)
        corrects += (output_predictions.argmax(dim=1) == output_labels).sum()
        total += output_predictions.shape[0]
        print(corrects / total)
        continue
        opt.zero_grad()
        loss_value.backward()
        opt.step()
        
    print(corrects / total)

  0%|                                                                                                                                                                                       | 0/60 [00:00<?, ?it/s]

tensor(0.5625)
tensor(0.5781)
tensor(0.5785)
tensor(0.5791)
tensor(0.5762)
tensor(0.5708)
tensor(0.5721)
tensor(0.5693)
tensor(0.5675)
tensor(0.5679)
tensor(0.5661)
tensor(0.5640)
tensor(0.5620)
tensor(0.5618)
tensor(0.5622)
tensor(0.5610)
tensor(0.5596)
tensor(0.5607)
tensor(0.5602)
tensor(0.5606)
tensor(0.5606)
tensor(0.5620)
tensor(0.5624)
tensor(0.5628)
tensor(0.5613)
tensor(0.5618)
tensor(0.5618)
tensor(0.5617)
tensor(0.5614)
tensor(0.5611)
tensor(0.5622)
tensor(0.5621)
tensor(0.5617)
tensor(0.5622)
tensor(0.5624)
tensor(0.5621)


  0%|                                                                                                                                                                                       | 0/60 [00:09<?, ?it/s]


KeyboardInterrupt: 

In [45]:
model.state_dict()


OrderedDict([('conv1.bias',
              tensor([ 4.3216e-02, -2.8648e-01, -9.6218e-02, -9.6030e-02,  1.2274e-01,
                      -1.5185e-01, -5.3370e-01, -1.6791e-01, -5.1596e-02,  2.2380e-01,
                       1.5711e-01, -2.1449e-01,  2.7400e-01, -4.1046e-02, -1.6027e-01,
                       2.4559e-01,  1.4261e-01,  1.5711e-01, -4.4723e-02,  9.4001e-02,
                      -2.5334e-01, -1.4514e-01, -1.1348e-01,  3.2410e-01,  4.6058e-01,
                       2.8597e-02, -2.6869e-01, -1.8483e-01,  2.9422e-01, -2.3243e-01,
                      -1.0819e-01, -2.3820e-01, -2.3009e-02, -2.5044e-01,  3.4067e-01,
                       2.4581e-01,  7.6851e-05,  6.1633e-02, -4.6752e-01,  1.0411e-01,
                       1.1758e-01,  6.3716e-01, -2.5277e-01, -4.6154e-02,  4.4517e-01,
                      -2.1197e-01,  2.2786e-02, -3.3997e-02,  2.9581e-01, -3.7883e-01,
                       5.5132e-02,  2.8451e-01, -4.8603e-01, -3.6035e-01,  1.2573e-01,
               

In [42]:
for i in model.parameters():
    print(i)

Parameter containing:
tensor([ 4.3216e-02, -2.8648e-01, -9.6218e-02, -9.6030e-02,  1.2274e-01,
        -1.5185e-01, -5.3370e-01, -1.6791e-01, -5.1596e-02,  2.2380e-01,
         1.5711e-01, -2.1449e-01,  2.7400e-01, -4.1046e-02, -1.6027e-01,
         2.4559e-01,  1.4261e-01,  1.5711e-01, -4.4723e-02,  9.4001e-02,
        -2.5334e-01, -1.4514e-01, -1.1348e-01,  3.2410e-01,  4.6058e-01,
         2.8597e-02, -2.6869e-01, -1.8483e-01,  2.9422e-01, -2.3243e-01,
        -1.0819e-01, -2.3820e-01, -2.3009e-02, -2.5044e-01,  3.4067e-01,
         2.4581e-01,  7.6851e-05,  6.1633e-02, -4.6752e-01,  1.0411e-01,
         1.1758e-01,  6.3716e-01, -2.5277e-01, -4.6154e-02,  4.4517e-01,
        -2.1197e-01,  2.2786e-02, -3.3997e-02,  2.9581e-01, -3.7883e-01,
         5.5132e-02,  2.8451e-01, -4.8603e-01, -3.6035e-01,  1.2573e-01,
         1.0709e-01, -2.9244e-01, -2.3544e-01, -5.5924e-02,  3.0020e-02,
         4.2227e-01,  5.0927e-01, -2.4297e-01,  4.3790e-01],
       requires_grad=True)
Parameter cont

In [55]:
model.state_dict()[

OrderedDict([('conv1.bias',
              tensor([ 0.1161,  0.0524,  0.2654,  0.0838, -0.0017,  0.1636,  0.0817, -0.0172,
                      -0.1624,  0.2094,  0.0560,  0.0652, -0.0107,  0.3436,  0.1472,  0.1637,
                       0.2237,  0.1753,  0.0309,  0.1348,  0.1438,  0.3050,  0.0547,  0.0032,
                       0.2663,  0.2699,  0.1136,  0.0642,  0.0247,  0.0941,  0.0411, -0.0110,
                       0.2892, -0.0658,  0.1327, -0.0470,  0.0634,  0.1909, -0.0432,  0.0614,
                       0.0699,  0.4091,  0.0823,  0.0251,  0.0457,  0.0996,  0.0385,  0.3054,
                       0.0164,  0.0716, -0.0883,  0.0648,  0.0920,  0.0597,  0.0294,  0.0409,
                       0.1581, -0.1074,  0.0634,  0.0674, -0.0006,  0.3371,  0.0852,  0.1685])),
             ('conv1.fc_self.weight',
              tensor([[ 0.0323,  0.1473, -0.2212,  ...,  0.0339, -0.3376, -0.2017],
                      [-0.1576, -0.1143,  0.1937,  ..., -0.0010,  0.0160,  0.2377],
           