## Second Challenge

The second challenge would be learning the node embeddings of CORA dataset in an **unsupervised** way, i.e. the node labels would not be available during training time.

We measure the performance of the embedding matrix by training a simple softmax classifier on the learned item embeddings on the training labels, and compute the accuracy on the test labels.  However, remember that both the training, validation, and test labels are **unavailable** during training; you MUST NOT use them.  Instead, please treat the evaluation routine as a black box, and only run the routine at test time.

In [3]:
##### DO NOT CHANGE THIS CELL
import torch
import torch.nn.functional as F
import numpy as np
import scipy.sparse as ssp
import dgl
import dgl.data
import dgl.nn.pytorch as dglnn
from collections import namedtuple

Args = namedtuple('Args', ['dataset'])
dataset = dgl.data.load_data(Args('cora'))

G = dgl.DGLGraph(dataset.graph)
X = torch.FloatTensor(dataset.features)

def evaluate(emb):
    """
    Evaluate the performance of the learned embedding.  The greater returned
    value the better.
    
    It trains a softmax regression model on the training set from the given
    embeddings, and return the accuracy on the test set.
    
    Parameters
    ----------
    emb : numpy.ndarray
        An N-by-M matrix where N is the number of nodes in CORA and M is
        the size of node embedding (can be of any value).
    """
    from sklearn.linear_model import LogisticRegressionCV
    global dataset
    C = LogisticRegressionCV(
        Cs=[1e-3, 1e-2, 1e-1, 1, 10, 100, 1000, 10000, 100000, 1e+6, 1e+7],
        multi_class='multinomial', solver='lbfgs', max_iter=10000)
    train_mask = (dataset.train_mask != 0)
    test_mask = (dataset.test_mask != 0)
    labels = dataset.labels
    C.fit(emb[train_mask], labels[train_mask])
    print('Best model found with C =', C.C_[0])
    return C.score(emb[test_mask], labels[test_mask])

We expect you to learn the node embeddings only from the given graph `G` and the node features `X`.  The following cell is an example solution which does nothing.  Please implement your model and report the number when you are done.

In [4]:
embedding = X
print('Baseline performance using raw features:', evaluate(X))
print('Baseline performance using my embedding:', evaluate(embedding))



Best model found with C = 100000.0
Baseline performance using raw features: 0.578




Best model found with C = 100000.0
Baseline performance using my embedding: 0.578


Example solution.  Remove all the cells below (including this one) before publishing.

In [5]:
class SAGENet(torch.nn.Module):
    def __init__(self, input_size, feature_size1, feature_size2, feature_size3):
        super().__init__()
        
        self.W1 = torch.nn.Linear(input_size, feature_size1)
        self.W2 = torch.nn.Linear(2 * feature_size1, feature_size2)
        self.W3 = torch.nn.Linear(2 * feature_size2, feature_size3)
        
    def forward(self, g, x, emask):
        with g.local_scope():
            import pdb
            g.edata['mask'] = emask
            g.update_all(dgl.function.copy_e('mask', 'm'), dgl.function.sum('m', 'deg'))
            g.ndata['deg'] = g.ndata['deg'].clamp(min=1)
            g.ndata['h1'] = self.W1(x)
            g.update_all(dgl.function.u_mul_e('h1', 'mask', 'm'), dgl.function.sum('m', 'm1'))
            g.ndata['m1'] = g.ndata['m1'] / g.ndata['deg'][:, None]
            g.ndata['h2'] = self.W2(F.leaky_relu(torch.cat([g.ndata['h1'], g.ndata['m1']], 1)))
            g.update_all(dgl.function.u_mul_e('h2', 'mask', 'm'), dgl.function.sum('m', 'm2'))
            g.ndata['m2'] = g.ndata['m2'] / g.ndata['deg'][:, None]
            return self.W3(F.leaky_relu(torch.cat([g.ndata['h2'], g.ndata['m2']], 1)))
        
    def emask(self, g, head, tail, tail_neg):
        n_negs = tail_neg.shape[1]
        has_edges = g.has_edges_between(head, tail)
        head_neg = head[:, None].expand_as(tail_neg)
        head_neg = head_neg.flatten()
        tail_neg = tail_neg.flatten()
        has_edges_neg = g.has_edges_between(head_neg, tail_neg)
        emask = torch.ones(g.number_of_edges())
        if has_edges.sum().item() > 0:
            eids = g.edge_ids(head[has_edges], tail[has_edges])
            emask[eids] = 0
        if has_edges_neg.sum().item() > 0:
            eids = g.edge_ids(head[has_edges_neg], tail[has_edges_neg])
            emask[eids] = 0
        return emask, has_edges, has_edges_neg.view(-1, n_negs)
    
    
class MLP(torch.nn.Module):
    def __init__(self, input_size, feature_size1, feature_size2, feature_size3):
        super().__init__()
        
        self.W1 = torch.nn.Linear(input_size, feature_size1)
        self.W2 = torch.nn.Linear(2 * feature_size1, feature_size2)
        self.W3 = torch.nn.Linear(2 * feature_size2, feature_size3)
        
    def forward(self, g, x, emask):
        h1 = self.W1(x)
        m1 = torch.cat([h1, torch.zeros_like(h1)], 1)
        h2 = self.W2(F.leaky_relu(m1))
        m2 = torch.cat([h2, torch.zeros_like(h2)], 1)
        h3 = self.W3(F.leaky_relu(m2))
        return h3
    
    def emask(self, g, head, tail, tail_neg):
        n_negs = tail_neg.shape[1]
        has_edges = g.has_edges_between(head, tail)
        head_neg = head[:, None].expand_as(tail_neg)
        head_neg = head_neg.flatten()
        tail_neg = tail_neg.flatten()
        has_edges_neg = g.has_edges_between(head_neg, tail_neg)
        emask = torch.ones(g.number_of_edges())
        if has_edges.sum().item() > 0:
            eids = g.edge_ids(head[has_edges], tail[has_edges])
            emask[eids] = 0
        if has_edges_neg.sum().item() > 0:
            eids = g.edge_ids(head[has_edges_neg], tail[has_edges_neg])
            emask[eids] = 0
        return emask, has_edges, has_edges_neg.view(-1, n_negs)

In [None]:
import tqdm

train_edges = int(G.number_of_edges() * 0.8)
valid_edges = G.number_of_edges() - train_edges

model = SAGENet(X.shape[1], 500, 200, 100)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
batch_size = 2048
n_negs = 5
best = None
l2_reg = 1e-6
with tqdm.tnrange(400) as tq:
    for _ in tq:
        batches = torch.randperm(train_edges).split(batch_size)
        with tqdm.tqdm_notebook(batches, leave=False) as tqb:
            for batch in tqb:
                head, tail = G.find_edges(batch)
                tail_neg = torch.LongTensor(np.random.choice(G.number_of_nodes(), (len(head), n_negs)))

                emask, label_pos, label_neg = model.emask(G, head, tail, tail_neg)
                #emask.zero_()
                emb = model(G, X, emask)
                sign_pos = label_pos.float() * 2 - 1
                sign_neg = label_neg.float() * 2 - 1

                emb_head = emb[head]
                emb_tail = emb[tail]
                emb_tail_neg = emb[tail_neg]

                score_pos = (emb_head * emb_tail).sum(1)
                score_neg = (emb_head.unsqueeze(1) * emb_tail_neg).sum(2)
                loss = (-F.logsigmoid(sign_pos * score_pos) - F.logsigmoid(sign_neg * score_neg).sum(1)).mean()
                #loss = loss + emb.norm() ** 2 * l2_reg
                #loss = -F.logsigmoid((score_pos[:, None] - score_neg) * (sign_pos[:, None] - sign_neg)).mean()

                opt.zero_grad()
                loss.backward()
                gn = 0
                for p in model.parameters():
                    gn += p.grad.norm() ** 2
                gn = torch.sqrt(gn)
                opt.step()
                
                tqb.set_postfix({'loss': loss.item(), 'gn': gn.item()}, refresh=True)
                
        batches = (torch.randperm(valid_edges) + train_edges).split(batch_size)
        with tqdm.tqdm_notebook(batches, leave=False) as tqb, torch.no_grad():
            loss = 0
            for batch in tqb:
                head, tail = G.find_edges(batch)
                tail_neg = torch.LongTensor(np.random.choice(G.number_of_nodes(), (len(head), n_negs)))

                emask, label_pos, label_neg = model.emask(G, head, tail, tail_neg)
                #emask.zero_()
                emb = model(G, X, emask)
                sign_pos = label_pos.float() * 2 - 1
                sign_neg = label_neg.float() * 2 - 1

                emb_head = emb[head]
                emb_tail = emb[tail]
                emb_tail_neg = emb[tail_neg]

                score_pos = (emb_head * emb_tail).sum(1)
                score_neg = (emb_head.unsqueeze(1) * emb_tail_neg).sum(2)
                loss += (-F.logsigmoid(sign_pos * score_pos) - F.logsigmoid(sign_neg * score_neg).sum(1)).sum().item()
                #loss += -F.logsigmoid((score_pos[:, None] - score_neg) * (sign_pos[:, None] - sign_neg)).sum().item()
            loss /= valid_edges
            #loss /= valid_edges * n_negs
            if best is None or best > loss:
                best = loss
                best_emb = emb.numpy()
            tq.set_postfix({'best': best, 'loss': loss}, refresh=True)

HBox(children=(IntProgress(value=0, max=400), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

In [23]:
print(evaluate(emb.numpy()))



Best model found with C = 10.0
0.7
