In [1]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import numpy as np
import dgl
#import torchcde
import torch.nn.functional as F
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer
from sklearn.metrics import roc_auc_score
import networkx as nx
#import random
import os
import pandas as pd
import torchcde

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz

In [3]:
# Citation list from the LINQS Cora archive. The handle is intentionally left open:
# the next cell rewinds it with seek(0) so that cell can be re-run.
graph_file = open(file='./cora/cora.cites', mode='r')

In [4]:
# Parse the two whitespace-separated ids on each line of cora.cites into an edge list.
# The pair is swapped on purpose (original note: "Correct direction of links").
# Rewinding first keeps this cell re-runnable against the handle opened above.
graph_file.seek(0)
cora_edgelist = []
for line in graph_file:
    fields = line.split()
    if not fields:  # tolerate blank/trailing lines instead of failing on unpacking
        continue
    i, j = fields
    cora_edgelist.append((int(j), int(i)))  # Correct direction of links

In [5]:
# Directed citation graph keyed by the original paper ids.
cora = nx.DiGraph(incoming_graph_data=cora_edgelist)

In [6]:
# Map every original paper id to a dense id 0..N-1, following cora's node order.
lookup = {paper_id: dense_id for dense_id, paper_id in enumerate(cora.nodes())}
# Rebuild the graph on the dense ids, inserting edges in the same order as cora.edges().
new_cora = nx.DiGraph()
new_cora.add_edges_from((lookup[u], lookup[v]) for u, v in cora.edges())

In [7]:
# Per-paper integer features and class labels from cora.content, stored at the dense
# ids from `lookup`. Each line is: <paper id> <feature values...> <class name>.
content = './cora/cora.content'
labels = {'Case_Based': 0, 'Genetic_Algorithms': 1, 'Neural_Networks': 2, 
          'Probabilistic_Methods': 3, 'Reinforcement_Learning':4, 
          'Rule_Learning': 5, 'Theory': 6}
num_words = 1433  # feature columns per paper expected in cora.content
# np.ndarray(...) returns uninitialised memory, so a graph node without a row in
# cora.content used to keep garbage silently. Start from known values and verify
# every node was filled.
cora_labels = np.full(len(new_cora), -1, dtype=int)
cora_features = np.zeros((len(new_cora), num_words), dtype=int)
with open(content, 'r') as f:
    for line in f:
        fields = line.split()
        if not fields:  # tolerate blank lines
            continue
        idx, *data, label = fields
        row = lookup[int(idx)]
        cora_labels[row] = labels[label]
        cora_features[row] = [int(val) for val in data]
assert (cora_labels >= 0).all(), 'some graph nodes have no row in cora.content'

In [8]:
# Source / destination endpoint lists of the dense-id edges.
edge_pairs = list(new_cora.edges())
src = [u for u, _ in edge_pairs]
dst = [v for _, v in edge_pairs]

In [9]:
# Feature matrix as a tensor (torch.tensor copies the NumPy buffer).
cora_features = torch.tensor(cora_features)
# DGL copy of the dense-id citation graph.
dag = dgl.from_networkx(new_cora)

In [10]:
# Symmetrised graph: every edge u->v of `dag` also gets its reverse v->u.
g = dgl.add_reverse_edges(dag)

In [None]:
# Negative-sampling distribution: each node's share of in-edges raised to the 3/4
# power (as in word2vec negative sampling), renormalised to sum to 1.
degree_share = g.in_degrees() / g.num_edges()
p = degree_share ** (3 / 4)
p = p.numpy() / torch.sum(p).numpy()

In [None]:
@torch.no_grad()
def compute_auc(pos_score, neg_score):
    """ROC-AUC of positive scores (target 1) ranked against negative scores (target 0).

    Args:
        pos_score: tensor of scores for positive pairs; rows counted by shape[0].
        neg_score: tensor of scores for negative pairs; rows counted by shape[0].

    Returns:
        The value returned by sklearn's ``roc_auc_score``.
    """
    n_pos, n_neg = pos_score.shape[0], neg_score.shape[0]
    truth = torch.cat([torch.ones(n_pos), torch.zeros(n_neg)]).numpy()
    predicted = torch.cat([pos_score, neg_score]).cpu().numpy()
    return roc_auc_score(truth, predicted)

def seed(seed=43):
    """
    Fix random process by a seed.

    Seeds Python's ``random``, NumPy's global RNG, torch's CPU RNG and DGL's RNG.

    Args:
        seed: integer seed applied to every generator (default 43).
    """
    # `import random` is commented out in the import cell, so without this local
    # import any call to seed() raised NameError.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    #torch.cuda.manual_seed(seed)
    #torch.cuda.manual_seed_all(seed)
    #torch.backends.cudnn.deterministic = True
    #torch.backends.cudnn.benchmark = False
    dgl.random.seed(seed)

#seed()
# Prefer the first CUDA device when one is visible, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(torch.__version__)
print(device)

In [None]:
# (2, E) tensor: row 0 holds edge sources of g, row 1 the destinations.
g_edges = torch.stack(g.edges(), dim=0)

In [None]:
# Random 80/20 split of edge indices (reproducible only if seed() was called first).
ids = np.random.permutation(len(g_edges[0]))
trainnum = int(len(ids) * 0.8)
train_id, val_id = ids[:trainnum], ids[trainnum:]

In [None]:
seed_size=128
lr = 0.01
def collate_fn(data):
    """Build one skip-gram batch of positive and sampled negative node pairs.

    NOTE(review): a later cell redefines ``collate_fn``; the skip-gram DataLoaders
    look the name up when they are built, so they pick up whichever definition ran last.

    Args:
        data: sequence of edge indices into the module-level ``g_edges``.

    Returns:
        (pos_data, neg_data): int64 tensors of shape (P, 2) and (Q, 2).
        ``pos_data`` holds the selected edges in both directions, de-duplicated.
        ``neg_data`` pairs each positive source (repeated 15 times) with a target
        drawn from ``p`` (30 draws per input edge). Any pair that is one of this
        batch's positive edges is dropped.
    """
    bs = len(data)
    picked = g_edges[:, data]
    edges = torch.cat([picked, picked[[1, 0], :]], dim=1)  # both directions
    src = edges[0, :].repeat(15)
    # Tensors hash by identity, so a set of tensors neither de-duplicates nor
    # supports membership tests; use plain (int, int) tuples instead.
    edge_set = {tuple(pair) for pair in edges.t().tolist()}

    dst = torch.tensor(np.random.choice(np.arange(g.num_nodes()), p=p, size=30 * bs))

    neg_pairs = [pair for pair in zip(src.tolist(), dst.tolist()) if pair not in edge_set]
    # reshape keeps a (0, 2) shape if every sampled pair was filtered out.
    neg_data = torch.tensor(neg_pairs, dtype=torch.long).reshape(-1, 2)

    pos_data = torch.tensor(sorted(edge_set), dtype=torch.long).reshape(-1, 2)
    return pos_data, neg_data

In [None]:
class skipgram(pl.LightningModule):
    """Skip-gram node embedding model with separate source (u) and context (v) tables.

    Batches come from ``collate_fn`` as ``(pos_pairs, neg_pairs)``. Each element is
    an integer tensor of (src, dst) rows, and the model is trained with the
    log-sigmoid negative-sampling objective.

    NOTE(review): the dataloaders read the module-level names ``ids``, ``train_id``,
    ``val_id`` and ``collate_fn`` at call time; later cells in this notebook rebind
    all four for the CDE stage.
    """

    def __init__(self, dim_emb=64, word_count=0, batch_size=seed_size):
        super(skipgram, self).__init__()
        self.v = nn.Embedding(word_count, dim_emb)
        self.u = nn.Embedding(word_count, dim_emb)
        # Previously `self.init_weight` (attribute access, never called), so the
        # Xavier initialisation below never ran.
        self.init_weight()
        self.batch_size = batch_size
        self.relu = nn.ReLU()

    def init_weight(self):
        """Xavier-uniform initialise both embedding tables."""
        gain = 1 #0.0036
        nn.init.xavier_uniform_(self.v.weight, gain=gain)
        nn.init.xavier_uniform_(self.u.weight, gain=gain)

    def forward(self, x):
        """Return the (u, v) embeddings of the node ids ``x``."""
        return self.u(x), self.v(x)

    def _scores(self, data, neg_data):
        """Dot-product scores u[src]·v[dst] for positive and negative pairs, each 1-D."""
        pos = (self.u(data[:, 0]) * self.v(data[:, 1])).sum(-1)
        neg = (self.u(neg_data[:, 0]) * self.v(neg_data[:, 1])).sum(-1)
        return pos, neg

    @staticmethod
    def _loss(pos, neg):
        """Negative-sampling loss: -mean over [log σ(pos), log σ(-neg)]."""
        return -torch.mean(torch.cat((F.logsigmoid(pos), F.logsigmoid(-neg))))

    def training_step(self, batch, batch_idx):
        data, neg_data = batch[0], batch[1]
        pos, neg = self._scores(data, neg_data)
        loss = self._loss(pos, neg)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        data, neg_data = batch[0], batch[1]
        pos, neg = self._scores(data, neg_data)
        loss = self._loss(pos, neg)
        # Rank raw positive scores against raw negative scores. The original passed
        # the concatenated log-sigmoid terms (which already contain the negatives)
        # as the "positive" scores, so the AUC was meaningless.
        auc = compute_auc(pos, neg)

        # .item() works for tensors on any device; .numpy() fails on CUDA tensors.
        print('epoch: ', self.current_epoch, 'val_loss: ', loss.item(), 'auc:', auc)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_auc', auc)
        return loss

    def test_step(self, batch, batch_idx):
        data, neg_data = batch[0], batch[1]
        pos, neg = self._scores(data, neg_data)
        loss = self._loss(pos, neg)
        self.log('test_loss', loss)
        # Previously `return test_loss`, an undefined name (NameError).
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=lr, weight_decay=1e-6)
        return optimizer

    def train_dataloader(self):
        # Train on the training edge split only. The original iterated `ids` (every
        # edge), which leaked the validation edges into training.
        return DataLoader(train_id, batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn, num_workers=0, drop_last=True, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(val_id, batch_size=len(val_id), shuffle=False, collate_fn=collate_fn, num_workers=0, drop_last=True, pin_memory=True)

    def predict_dataloader(self):
        return DataLoader(ids, batch_size=g.num_nodes(), shuffle=False)

In [None]:
model = skipgram(dim_emb=16, word_count=g.num_nodes())
# NOTE(review): `progress_bar_refresh_rate` and `gpus` are deprecated in later
# Lightning releases — confirm against the installed version.
trainer_kwargs = dict(
    min_epochs=0,
    max_epochs=60,
    check_val_every_n_epoch=10,
    progress_bar_refresh_rate=1,
    gpus=0,
    reload_dataloaders_every_n_epochs=5,
    profiler='simple',
)
trainer = Trainer(**trainer_kwargs)

In [None]:
#trainer.fit(model)

In [11]:
# Node embeddings, presumably produced earlier by the two commented-out lines below.
#output = model.v(torch.arange(g.num_nodes())).detach()
#np.save('cora_embedding.npy', output)
saved_embedding = np.load('cora_embedding.npy')
output = torch.tensor(saved_embedding)

In [12]:
# Project the feature matrix onto its top-15 principal directions.
# NOTE(review): this rebinds cora_features from itself, so the cell is not idempotent.
features_f32 = cora_features.float()
_, _, V = torch.pca_lowrank(features_f32, q=15, center=True, niter=100)
cora_features = features_f32 @ V

In [13]:
def fn(data):
    """Turn one target node into torchcde coefficients over random-walk paths.

    Args:
        data: a single node id supporting ``.repeat`` and ``.item()`` (e.g. a NumPy
            integer, as yielded when iterating ``train_id`` / ``val_id``).

    Returns:
        (label, coeffs): ``cora_labels[data]`` and the Hermite-cubic coefficients of
        a (num_paths, path_len, channels) tensor. Its channels are
        [time, output[node], cora_features[node]] along each path, reversed so that
        each path ends at the target node.
    """
    dst = data
    trials = 20
    length = 7
    rw = dgl.sampling.random_walk(dag, data.repeat(trials), length=length)[0]

    # Cut a walk (set it to -1 from that column on) when it reaches a node that any
    # walk already reached in an earlier column.
    visited = set()
    for i in range(length + 1):
        for j in range(trials):
            if rw[j, i].item() >= 0 and rw[j, i].item() in visited:
                rw[j, i:] = -1
                # NOTE(review): this `break` stops after the first cut walk in column i,
                # so later walks revisiting a node in the same column survive.
                # Preserved as-is; confirm whether every such walk should be cut.
                break
        for j in range(trials):
            if rw[j, i].item() >= 0:
                visited.add(rw[j, i].item())

    # Keep each distinct walk prefix up to (not including) its first -1, if it has at
    # least two nodes. The original only collected a walk once it saw -1 at a column
    # j > 1, which had two defects:
    #   * walks that never hit -1 (full length) were dropped, possibly leaving no paths;
    #   * a walk stopping right after the start produced the prefix [dst, -1], and
    #     that -1 then indexed the *last* node's embedding/features.
    paths = []
    for walk in rw.tolist():
        prefix = walk[:walk.index(-1)] if -1 in walk else walk
        if len(prefix) > 1 and prefix not in paths:
            paths.append(prefix)
    if not paths:
        # No walk left the start node: use a constant two-step path (same value the
        # left-padding below uses) so the spline still has two samples.
        paths = [[dst.item(), dst.item()]]

    length = max(len(path) for path in paths)
    # Left-pad shorter paths with the target node so all share one length.
    padded = [[dst.item()] * (length - len(path)) + path for path in paths]
    index = torch.flip(torch.tensor(padded).long(), dims=(-1,))
    cora_feat = cora_features[index]
    paths = output[index]

    time = torch.arange(paths.shape[1]).repeat(paths.shape[0], 1).unsqueeze(-1) / 8
    X = torch.cat([time, paths, cora_feat], dim=-1)
    train_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(X)
    return cora_labels[dst], train_coeffs

In [14]:
def collate_fn(data):
    """Collate a batch of node ids for the CDE model.

    Returns a dict ``{node_id: (label, coeffs)}``. ``coeffs`` are torchcde
    Hermite-cubic coefficients over random-walk paths that end at that node.

    NOTE(review): this rebinds the module-level name ``collate_fn`` that the
    earlier skip-gram DataLoaders also look up when they are built.
    """
    def _node_sample(node):
        """(label, coeffs) for one node; same path construction as the module-level fn."""
        trials = 20
        length = 7
        rw = dgl.sampling.random_walk(dag, node.repeat(trials), length=length)[0]

        # Cut a walk when it reaches a node any walk reached in an earlier column.
        visited = set()
        for i in range(length + 1):
            for j in range(trials):
                if rw[j, i].item() >= 0 and rw[j, i].item() in visited:
                    rw[j, i:] = -1
                    # NOTE(review): `break` stops after the first cut walk in column i;
                    # preserved as-is — confirm whether every such walk should be cut.
                    break
            for j in range(trials):
                if rw[j, i].item() >= 0:
                    visited.add(rw[j, i].item())

        # Distinct walk prefixes up to (not including) the first -1, at least two nodes.
        # The original dropped full-length walks and could keep [node, -1], whose -1
        # then indexed the last node's embedding/features.
        paths = []
        for walk in rw.tolist():
            prefix = walk[:walk.index(-1)] if -1 in walk else walk
            if len(prefix) > 1 and prefix not in paths:
                paths.append(prefix)
        if not paths:
            # No walk left the start node: constant two-step path keeps the spline valid.
            paths = [[node.item(), node.item()]]

        length = max(len(path) for path in paths)
        # Left-pad with the target node, then reverse so every path ends at it.
        padded = [[node.item()] * (length - len(path)) + path for path in paths]
        index = torch.flip(torch.tensor(padded).long(), dims=(-1,))
        cora_feat = cora_features[index]
        emb = output[index]

        time = torch.arange(emb.shape[1]).repeat(emb.shape[0], 1).unsqueeze(-1) / 8
        X = torch.cat([time, emb, cora_feat], dim=-1)
        train_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(X)
        return cora_labels[node], train_coeffs

    return {node: _node_sample(node) for node in data}

In [15]:
# Random 80/20 split of node ids for the CDE stage. This rebinds the edge-level
# ids / train_id / val_id created for the skip-gram stage.
ids = np.random.permutation(dag.num_nodes())
trainnum = int(len(ids) * 0.8)
train_id, val_id = ids[:trainnum], ids[trainnum:]

In [16]:
class CDEFunc(torch.nn.Module):
    """Vector field of the neural CDE passed to ``torchcde.cdeint``.

    Maps a hidden state z of shape (batch, hidden_channels) to a tensor of shape
    (batch, hidden_channels, input_channels). That output is a linear map from the
    input path's channels to the hidden channels.

    Args:
        input_channels: number of channels in the driving path X (fixed by the data).
        hidden_channels: size of the hidden state z (a modelling choice).
    """

    def __init__(self, input_channels=32, hidden_channels=32):
        super(CDEFunc, self).__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels

        width = 64
        self.linear1 = torch.nn.Linear(hidden_channels, width)
        self.linear2 = torch.nn.Linear(width, input_channels * hidden_channels)

        for layer in (self.linear1, self.linear2):
            nn.init.xavier_uniform_(layer.weight, gain=1)

    def forward(self, t, z):
        """Evaluate the vector field at z; ``t`` is accepted for the torchcde API but unused."""
        hidden = self.linear2(self.linear1(z).relu())
        # A final tanh keeps the field bounded (the original notes it gives better results).
        bounded = hidden.tanh()
        return bounded.view(hidden.size(0), self.hidden_channels, self.input_channels)

In [51]:
lr = 0.001
crit = nn.CrossEntropyLoss()
class CDE(pl.LightningModule):
    """Node classifier: a neural CDE over random-walk paths ending at each node,
    followed by attention pooling of the per-path terminal states.

    Batches are the dicts produced by ``collate_fn``: ``{node_id: (label, coeffs)}``.
    The default ``cora_features`` / ``output`` arguments are bound to the module-level
    tensors that exist when this class is defined.
    """

    def __init__(self, input_dim=32, h_dim=32, output_dim=8, cora_features=cora_features, output=output):
        super(CDE, self).__init__()
        # Pass the dimensions through; `CDEFunc()` silently assumed 32/32 regardless
        # of input_dim / h_dim (identical for the defaults).
        self.cdef = CDEFunc(input_dim, h_dim)
        self.initial = torch.nn.Linear(input_dim, h_dim)
        self.transform = torch.nn.Linear(input_dim, h_dim)
        self.k = torch.nn.Linear(h_dim, h_dim)
        self.q = torch.nn.Linear(h_dim, h_dim)
        self.readout = torch.nn.Linear(h_dim, output_dim)
        # Previously `self.init_weight` (attribute access, never called).
        self.init_weight()
        self.prelu = nn.PReLU()
        self.cora_features, self.output = cora_features, output
        self.sm = nn.Softmax(dim=0)
        #self.eps = nn.Parameter(torch.FloatTensor([1.0]))

    def init_weight(self):
        """Xavier-uniform initialise the projection, attention and readout weights."""
        gain = 1 #0.0036
        nn.init.xavier_uniform_(self.initial.weight, gain=gain)
        nn.init.xavier_uniform_(self.readout.weight, gain=gain)
        nn.init.xavier_uniform_(self.k.weight, gain=gain)
        nn.init.xavier_uniform_(self.q.weight, gain=gain)
        nn.init.xavier_uniform_(self.transform.weight, gain=gain)

    def _predict(self, data):
        """Run the CDE + attention readout for every node of a ``collate_fn`` batch.

        Returns:
            (logits, targets): tensor of shape (batch, output_dim) and a tensor of labels.
        """
        cora_features, output = self.cora_features, self.output
        pred, Y = [], []
        for i in data:
            y, coeffs = data[i]
            Y.append(y)
            X = torchcde.CubicSpline(coeffs)
            z0 = self.initial(X.evaluate(X.interval[0]))
            # cdeint returns the state at both ends of X.interval; keep the terminal one.
            z_T = torchcde.cdeint(X=X, z0=z0, func=self.cdef, t=X.interval)[:, 1]
            # NOTE(review): z_T[0, 0] is one scalar (first path, first hidden channel);
            # assumes 1 + output dim + feature dim == input_dim — confirm intent.
            dst_feat = self.transform(torch.cat([z_T[0, 0].unsqueeze(0), output[i], cora_features[i]]))
            z_T = torch.cat([dst_feat.unsqueeze(0), z_T])
            q = self.q(dst_feat)
            k = self.k(z_T)
            # Softmax over the path axis (dim=0) gives one attention weight per row.
            attn = self.sm(torch.matmul(q, k.t())).unsqueeze(-1)
            h = torch.sum(z_T * attn, dim=0)
            pred.append(self.readout(h))
        return torch.stack(pred), torch.tensor(Y)

    def forward(self, data, batch_idx=None):
        """Return ``(data, logits, targets)``; ``batch_idx`` is optional so the default
        Lightning ``predict_step`` (which calls ``self(batch)``) also works."""
        logits, targets = self._predict(data)
        return data, logits, targets

    def training_step(self, data, batch_idx):
        logits, targets = self._predict(data)
        loss = crit(logits, targets)
        acc = (logits.argmax(dim=-1) == targets).float().mean()
        self.log('train_loss', loss, batch_size=len(data))
        self.log('train_acc', acc, batch_size=len(data))
        return loss

    def validation_step(self, data, batch_idx):
        logits, targets = self._predict(data)
        loss = crit(logits, targets)
        acc = (logits.argmax(dim=-1) == targets).float().mean()

        # .item() works on any device; .numpy() fails for CUDA tensors.
        print('epoch: ', self.current_epoch, 'val_loss: ', loss.item(), 'acc:', acc.item())
        self.log('val_loss', loss, batch_size=len(data))
        self.log('val_acc', acc, batch_size=len(data))
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=lr, weight_decay=1e-6)
        return optimizer

    def train_dataloader(self):
        return DataLoader(train_id, batch_size=16, shuffle=True, collate_fn=collate_fn, num_workers=0, pin_memory=True)

    def val_dataloader(self):
        return DataLoader(val_id, batch_size=64, shuffle=False, collate_fn=collate_fn, num_workers=0, pin_memory=True)

    def predict_dataloader(self):
        # forward/_predict expect the dict batches built by collate_fn.
        return DataLoader(ids, batch_size=g.num_nodes(), shuffle=False, collate_fn=collate_fn)
        

In [52]:
model = CDE()
# NOTE(review): `progress_bar_refresh_rate` and `gpus` are deprecated in later
# Lightning releases — confirm against the installed version.
trainer_kwargs = dict(
    min_epochs=0,
    max_epochs=80,
    check_val_every_n_epoch=1,
    progress_bar_refresh_rate=0,
    gpus=0,
    reload_dataloaders_every_n_epochs=5,
    profiler='simple',
)
trainer = Trainer(**trainer_kwargs)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
trainer.fit(model=model)


  | Name      | Type    | Params
--------------------------------------
0 | cdef      | CDEFunc | 68.7 K
1 | initial   | Linear  | 1.1 K 
2 | transform | Linear  | 1.1 K 
3 | k         | Linear  | 1.1 K 
4 | q         | Linear  | 1.1 K 
5 | readout   | Linear  | 264   
6 | prelu     | PReLU   | 1     
7 | sm        | Softmax | 0     
--------------------------------------
73.2 K    Trainable params
0         Non-trainable params
73.2 K    Total params
0.293     Total estimated model params size (MB)


epoch:  0 val_loss:  2.2166483 acc: 0.0625
epoch:  0 val_loss:  2.2360587 acc: 0.125


In [None]:
# Stand-alone loader over the training node ids with the same settings as CDE.train_dataloader.
loader = DataLoader(
    train_id,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=0,
    pin_memory=True,
)