## DeepInf_GCN ##

In [None]:
import torch 
from torch.utils.data import Dataset
class InfluenceDataSet(Dataset):
    def __init__(self, file_dir, embedding_dim, seed, shuffle, model):
        self.graphs = np.load(os.path.join(file_dir, "adjacency_matrix.npy")).astype(np.float32)

        # self-loop trick, the input graphs should have no self-loop
        identity = np.identity(self.graphs.shape[1])
        self.graphs += identity
        self.graphs[self.graphs != 0] = 1.0
        if model == "gat" or model == "pscn":
            self.graphs = self.graphs.astype(np.dtype('B'))
        elif model == "gcn":
            # normalized graph laplacian for GCN: D^{-1/2}AD^{-1/2}
            for i in range(len(self.graphs)):
                graph = self.graphs[i]
                d_root_inv = 1. / np.sqrt(np.sum(graph, axis=1))
                graph = (graph.T * d_root_inv).T * d_root_inv
                self.graphs[i] = graph
        else:
            raise NotImplementedError
        logger.info("graphs loaded!")

        # wheather a user has been influenced
        # wheather he/she is the ego user
        self.influence_features = np.load(
                os.path.join(file_dir, "influence_feature.npy")).astype(np.float32)
        logger.info("influence features loaded!")

        self.labels = np.load(os.path.join(file_dir, "label.npy"))
        logger.info("labels loaded!")

        self.vertices = np.load(os.path.join(file_dir, "vertex_id.npy"))
        logger.info("vertex ids loaded!")

        if shuffle:
            self.graphs, self.influence_features, self.labels, self.vertices = \
                    sklearn.utils.shuffle(
                        self.graphs, self.influence_features,
                        self.labels, self.vertices,
                        random_state=seed
                    )

        vertex_features = np.load(os.path.join(file_dir, "vertex_feature.npy"))
        vertex_features = preprocessing.scale(vertex_features)
        self.vertex_features = torch.FloatTensor(vertex_features)
        logger.info("global vertex features loaded!")

        embedding_path = os.path.join(file_dir, "deepwalk.emb_%d" % embedding_dim)
        max_vertex_idx = np.max(self.vertices)
        embedding = load_w2v_feature(embedding_path, max_vertex_idx)
        self.embedding = torch.FloatTensor(embedding)
        logger.info("%d-dim embedding loaded!", embedding_dim)

        self.N = self.graphs.shape[0]
        logger.info("%d ego networks loaded, each with size %d" % (self.N, self.graphs.shape[1]))

        n_classes = self.get_num_class()
        class_weight = self.N / (n_classes * np.bincount(self.labels))
        self.class_weight = torch.FloatTensor(class_weight)

    def get_embedding(self):
        return self.embedding

    def get_vertex_features(self):
        return self.vertex_features

    def get_feature_dimension(self):
        return self.influence_features.shape[-1]

    def get_num_class(self):
        return np.unique(self.labels).shape[0]

    def get_class_weight(self):
        return self.class_weight

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        return self.graphs[idx], self.influence_features[idx], self.labels[idx], self.vertices[idx]


In [67]:
def load_w2v_feature(file, max_idx=0):
    with open(file, "rb") as f:
        nu = 0
        for line in f:
            content = line.strip().split()
            nu += 1
            if nu == 1:
                n, d = int(content[0]), int(content[1])
                feature = [[0.] * d for i in range(max(n, max_idx + 1))]
                continue
            index = int(content[0])
            while len(feature) <= index:
                feature.append([0.] * d)
            for i, x in enumerate(content[1:]):
                feature[index][i] = float(x)
    for item in feature:
        assert len(item) == d
    return np.array(feature, dtype=np.float32)


In [68]:
import logging
import numpy as np
import os
import sklearn
from sklearn import preprocessing
logger = logging.getLogger(__name__)
influence_dataset = InfluenceDataSet(
            "weibo/", 64, 42, False, "gcn")

In [82]:
from tensorboard_logger import tensorboard_logger
import shutil
tensorboard_log_dir = 'log/%s' % ('gcn')
os.makedirs(tensorboard_log_dir, exist_ok=True)
shutil.rmtree(tensorboard_log_dir)
tensorboard_logger.configure(tensorboard_log_dir)
logger.info('tensorboard logging to %s', tensorboard_log_dir)

2020-07-18 10:58:50,071 tensorboard logging to log/gcn


In [70]:
from torch.utils.data.sampler import Sampler
class ChunkSampler(Sampler):
    """
    Samples elements sequentially from some offset.
    Arguments:
        num_samples: # of desired datapoints
        start: offset where we should start selecting from
    """
    def __init__(self, num_samples, start=0):
        self.num_samples = num_samples
        self.start = start

    def __iter__(self):
        return iter(range(self.start, self.start + self.num_samples))

    def __len__(self):
        return self.num_samples
    
N = len(influence_dataset)   
train_start,  valid_start, test_start = \
        0, int(N * 0.75 / 100), int(N * (0.75 + 0.125))
    
from torch.utils.data import DataLoader
train_loader = DataLoader(influence_dataset, batch_size=1024,
                        sampler=ChunkSampler(valid_start - train_start, 0))

In [71]:
from torch.utils.data import DataLoader

In [72]:
valid_loader = DataLoader(influence_dataset, batch_size=1024,
                        sampler=ChunkSampler(test_start - valid_start, valid_start))
test_loader = DataLoader(influence_dataset, batch_size=1024,
                        sampler=ChunkSampler(N - test_start, test_start))

In [73]:
N = len(influence_dataset)
n_classes = 2
hidden_units = "128,128"
class_weight = influence_dataset.get_class_weight() \
        if False else torch.ones(n_classes)
logger.info("class_weight=%.2f:%.2f", class_weight[0], class_weight[1])

feature_dim = influence_dataset.get_feature_dimension()
n_units = [feature_dim] + [int(x) for x in hidden_units.strip().split(",")] + [n_classes]
logger.info("feature dimension=%d", feature_dim)
logger.info("number of classes=%d", n_classes)
dropout = 0.2

In [74]:
class BatchGraphConvolution(Module):

    def __init__(self, in_features, out_features, bias=True):
        super(BatchGraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
            init.constant_(self.bias, 0)
        else:
            self.register_parameter('bias', None)
        init.xavier_uniform_(self.weight)

    def forward(self, x, lap):
        expand_weight = self.weight.expand(x.shape[0], -1, -1)
        support = torch.bmm(x, expand_weight)
        output = torch.bmm(lap, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

In [75]:
class BatchGCN(nn.Module):
    def __init__(self, n_units, dropout, pretrained_emb, vertex_feature,
            use_vertex_feature, fine_tune=False, instance_normalization=False):
        super(BatchGCN, self).__init__()
        self.num_layer = len(n_units) - 1
        self.dropout = dropout
        self.inst_norm = instance_normalization
        if self.inst_norm:
            self.norm = nn.InstanceNorm1d(pretrained_emb.size(1), momentum=0.0, affine=True)

        # https://discuss.pytorch.org/t/can-we-use-pre-trained-word-embeddings-for-weight-initialization-in-nn-embedding/1222/2
        self.embedding = nn.Embedding(pretrained_emb.shape[0], pretrained_emb.shape[1])
        self.embedding.weight = nn.Parameter(pretrained_emb)
        self.embedding.weight.requires_grad = fine_tune
        n_units[0] += pretrained_emb.shape[1]

        self.use_vertex_feature = use_vertex_feature
        if self.use_vertex_feature:
            self.vertex_feature = nn.Embedding(vertex_feature.shape[0], vertex_feature.shape[1])
            self.vertex_feature.weight = nn.Parameter(vertex_feature)
            self.vertex_feature.weight.requires_grad = False
            n_units[0] += vertex_feature.shape[1]

        self.layer_stack = nn.ModuleList()

        for i in range(self.num_layer):
            self.layer_stack.append(
                    BatchGraphConvolution(n_units[i], n_units[i + 1])
                    )

    def forward(self, x, vertices, lap):
        emb = self.embedding(vertices)
        if self.inst_norm:
            emb = self.norm(emb.transpose(1, 2)).transpose(1, 2)
        x = torch.cat((x, emb), dim=2)
        if self.use_vertex_feature:
            vfeature = self.vertex_feature(vertices)
            x = torch.cat((x, vfeature), dim=2)
        for i, gcn_layer in enumerate(self.layer_stack):
            x = gcn_layer(x, lap)
            if i + 1 < self.num_layer:
                x = F.elu(x)
                x = F.dropout(x, self.dropout, training=self.training)
        return F.log_softmax(x, dim=-1)

In [76]:
model = BatchGCN(pretrained_emb=influence_dataset.get_embedding(),
                vertex_feature=influence_dataset.get_vertex_features(),
                use_vertex_feature=False,
                n_units=n_units,
                dropout=dropout,
                instance_normalization=False)

In [77]:
import torch.optim as optim
params = [{'params': filter(lambda p: p.requires_grad, model.parameters())
   }]
optimizer = optim.Adagrad(params, lr=0.1, weight_decay=5e-4)
optimizer.zero_grad()

In [89]:
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_curve
def evaluate(epoch, loader, thr=None, return_best_thr=False, log_desc='valid_'):
    model.eval()
    total = 0.
    loss, prec, rec, f1 = 0., 0., 0., 0.
    y_true, y_pred, y_score = [], [], []
    for i_batch, batch in enumerate(loader):
        graph, features, labels, vertices = batch
        bs = graph.shape[0]

        output = model(features, vertices, graph)
        if True:
            output = output[:, -1, :]
        loss_batch = F.nll_loss(output, labels, class_weight)
        loss += bs * loss_batch.item()

        y_true += labels.data.tolist()
        y_pred += output.max(1)[1].data.tolist()
        y_score += output[:, 1].data.tolist()
        total += bs

    model.train()

    if thr is not None:
        logger.info("using threshold %.4f", thr)
        y_score = np.array(y_score)
        y_pred = np.zeros_like(y_score)
        y_pred[y_score > thr] = 1

    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
    auc = roc_auc_score(y_true, y_score)
    logger.info("%sloss: %.4f AUC: %.4f Prec: %.4f Rec: %.4f F1: %.4f",
            log_desc, loss / total, auc, prec, rec, f1)

    tensorboard_logger.log_value(log_desc + 'loss', loss / total, epoch + 1)
    tensorboard_logger.log_value(log_desc + 'auc', auc, epoch + 1)
    tensorboard_logger.log_value(log_desc + 'prec', prec, epoch + 1)
    tensorboard_logger.log_value(log_desc + 'rec', rec, epoch + 1)
    tensorboard_logger.log_value(log_desc + 'f1', f1, epoch + 1)

    if return_best_thr:
        precs, recs, thrs = precision_recall_curve(y_true, y_score)
        f1s = 2 * precs * recs / (precs + recs)
        f1s = f1s[:-1]
        thrs = thrs[~np.isnan(f1s)]
        f1s = f1s[~np.isnan(f1s)]
        best_thr = thrs[np.argmax(f1s)]
        logger.info("best threshold=%4f, f1=%.4f", best_thr, np.max(f1s))
        return best_thr
    else:
        return None

In [86]:
def train(epoch, train_loader, valid_loader, test_loader, log_desc='train_'):
    model.train()
    loss = 0.
    total = 0.
    for i_batch, batch in enumerate(train_loader):
        graph, features, labels, vertices = batch
        bs = graph.shape[0]
        optimizer.zero_grad()
        output = model(features, vertices, graph)
        if True:
            output = output[:, -1, :]
        loss_train = F.nll_loss(output, labels, class_weight)
        loss += bs * loss_train.item()
        total += bs
        loss_train.backward()
        optimizer.step()
    logger.info("train loss in this epoch %f", loss / total)
    tensorboard_logger.log_value('train_loss', loss / total, epoch + 1)
    if (epoch + 1) % 10 == 0:
        logger.info("epoch %d, checkpoint!", epoch)
        best_thr = evaluate(epoch, valid_loader, return_best_thr=True, log_desc='valid_')
        evaluate(epoch, test_loader, thr=best_thr, log_desc='test_')

In [90]:
import time
t_total = time.time()
logger.info("training...")
for epoch in range(500):
    train(epoch, train_loader, valid_loader, test_loader)
logger.info("optimization Finished!")
logger.info("total time elapsed: {:.4f}s".format(time.time() - t_total))

logger.info("retrieve best threshold...")
best_thr = evaluate(args.epochs, valid_loader, return_best_thr=True, log_desc='valid_')

# Testing
logger.info("testing...")
evaluate(500, test_loader, thr=best_thr, log_desc='test_')

2020-07-18 11:02:25,877 training...
2020-07-18 11:02:26,755 train loss in this epoch 0.537798
2020-07-18 11:02:27,635 train loss in this epoch 0.536948
2020-07-18 11:02:28,501 train loss in this epoch 0.536713
2020-07-18 11:02:29,385 train loss in this epoch 0.533474
2020-07-18 11:02:30,263 train loss in this epoch 0.528936
2020-07-18 11:02:31,154 train loss in this epoch 0.528339
2020-07-18 11:02:32,044 train loss in this epoch 0.528182
2020-07-18 11:02:32,920 train loss in this epoch 0.526966
2020-07-18 11:02:33,808 train loss in this epoch 0.528220
2020-07-18 11:02:34,670 train loss in this epoch 0.527923
2020-07-18 11:02:34,672 epoch 9, checkpoint!
2020-07-18 11:03:41,257 valid_loss: 0.5310 AUC: 0.6747 Prec: 0.4142 Rec: 0.0087 F1: 0.0170
2020-07-18 11:03:41,824 best threshold=-1.671974, f1=0.4654
2020-07-18 11:03:50,752 using threshold -1.6720
2020-07-18 11:03:51,033 test_loss: 0.5321 AUC: 0.6733 Prec: 0.3336 Rec: 0.7600 F1: 0.4637
2020-07-18 11:03:51,879 train loss in this epoch 0

2020-07-18 11:15:00,281 best threshold=-1.554631, f1=0.4740
2020-07-18 11:15:09,347 using threshold -1.5546
2020-07-18 11:15:09,705 test_loss: 0.5199 AUC: 0.6931 Prec: 0.3494 Rec: 0.7357 F1: 0.4738
2020-07-18 11:15:10,579 train loss in this epoch 0.505811
2020-07-18 11:15:11,689 train loss in this epoch 0.505841
2020-07-18 11:15:12,601 train loss in this epoch 0.506164
2020-07-18 11:15:13,455 train loss in this epoch 0.506660
2020-07-18 11:15:14,305 train loss in this epoch 0.506348
2020-07-18 11:15:15,157 train loss in this epoch 0.504176
2020-07-18 11:15:16,031 train loss in this epoch 0.504163
2020-07-18 11:15:16,938 train loss in this epoch 0.505757
2020-07-18 11:15:17,817 train loss in this epoch 0.505464
2020-07-18 11:15:18,686 train loss in this epoch 0.504610
2020-07-18 11:15:18,688 epoch 99, checkpoint!
2020-07-18 11:16:24,185 valid_loss: 0.5185 AUC: 0.6948 Prec: 0.5340 Rec: 0.0819 F1: 0.1421
2020-07-18 11:16:24,755 best threshold=-1.514119, f1=0.4748
2020-07-18 11:16:34,179 u

2020-07-18 11:27:45,185 best threshold=-1.528112, f1=0.4802
2020-07-18 11:27:54,233 using threshold -1.5281
2020-07-18 11:27:54,700 test_loss: 0.5156 AUC: 0.7020 Prec: 0.3582 Rec: 0.7260 F1: 0.4797
2020-07-18 11:27:55,572 train loss in this epoch 0.496845
2020-07-18 11:27:56,463 train loss in this epoch 0.497710
2020-07-18 11:27:57,371 train loss in this epoch 0.495640
2020-07-18 11:27:58,257 train loss in this epoch 0.496198
2020-07-18 11:27:59,131 train loss in this epoch 0.497632
2020-07-18 11:28:00,039 train loss in this epoch 0.497720
2020-07-18 11:28:00,971 train loss in this epoch 0.496301
2020-07-18 11:28:01,841 train loss in this epoch 0.495709
2020-07-18 11:28:02,700 train loss in this epoch 0.496304
2020-07-18 11:28:03,559 train loss in this epoch 0.495730
2020-07-18 11:28:03,561 epoch 189, checkpoint!
2020-07-18 11:29:10,703 valid_loss: 0.5151 AUC: 0.7038 Prec: 0.5511 Rec: 0.0984 F1: 0.1669
2020-07-18 11:29:11,268 best threshold=-1.529049, f1=0.4809
2020-07-18 11:29:20,680 

2020-07-18 11:40:36,130 best threshold=-1.525539, f1=0.4842
2020-07-18 11:40:45,377 using threshold -1.5255
2020-07-18 11:40:45,713 test_loss: 0.5134 AUC: 0.7071 Prec: 0.3636 Rec: 0.7202 F1: 0.4832
2020-07-18 11:40:46,597 train loss in this epoch 0.490449
2020-07-18 11:40:47,470 train loss in this epoch 0.491087
2020-07-18 11:40:48,336 train loss in this epoch 0.490537
2020-07-18 11:40:49,245 train loss in this epoch 0.490913
2020-07-18 11:40:50,132 train loss in this epoch 0.491602
2020-07-18 11:40:50,987 train loss in this epoch 0.491120
2020-07-18 11:40:51,855 train loss in this epoch 0.491020
2020-07-18 11:40:52,732 train loss in this epoch 0.490658
2020-07-18 11:40:53,596 train loss in this epoch 0.488805
2020-07-18 11:40:54,486 train loss in this epoch 0.489443
2020-07-18 11:40:54,487 epoch 279, checkpoint!
2020-07-18 11:42:00,619 valid_loss: 0.5125 AUC: 0.7092 Prec: 0.5599 Rec: 0.1176 F1: 0.1944
2020-07-18 11:42:01,194 best threshold=-1.531608, f1=0.4846
2020-07-18 11:42:10,265 

2020-07-18 11:53:26,181 best threshold=-1.528522, f1=0.4866
2020-07-18 11:53:34,594 using threshold -1.5285
2020-07-18 11:53:34,845 test_loss: 0.5125 AUC: 0.7104 Prec: 0.3677 Rec: 0.7118 F1: 0.4850
2020-07-18 11:53:35,748 train loss in this epoch 0.487773
2020-07-18 11:53:36,876 train loss in this epoch 0.486620
2020-07-18 11:53:37,775 train loss in this epoch 0.487647
2020-07-18 11:53:38,651 train loss in this epoch 0.487005
2020-07-18 11:53:39,545 train loss in this epoch 0.487421
2020-07-18 11:53:40,561 train loss in this epoch 0.488204
2020-07-18 11:53:41,525 train loss in this epoch 0.486492
2020-07-18 11:53:42,417 train loss in this epoch 0.488322
2020-07-18 11:53:43,325 train loss in this epoch 0.485830
2020-07-18 11:53:44,183 train loss in this epoch 0.486597
2020-07-18 11:53:44,185 epoch 369, checkpoint!
2020-07-18 11:54:52,103 valid_loss: 0.5114 AUC: 0.7122 Prec: 0.5616 Rec: 0.1235 F1: 0.2025
2020-07-18 11:54:52,672 best threshold=-1.533478, f1=0.4867
2020-07-18 11:55:01,203 

2020-07-18 12:06:20,787 best threshold=-1.509849, f1=0.4887
2020-07-18 12:06:29,625 using threshold -1.5098
2020-07-18 12:06:30,061 test_loss: 0.5111 AUC: 0.7128 Prec: 0.3713 Rec: 0.7081 F1: 0.4872
2020-07-18 12:06:30,981 train loss in this epoch 0.483897
2020-07-18 12:06:31,862 train loss in this epoch 0.486121
2020-07-18 12:06:32,768 train loss in this epoch 0.484037
2020-07-18 12:06:33,691 train loss in this epoch 0.483770
2020-07-18 12:06:34,591 train loss in this epoch 0.483289
2020-07-18 12:06:35,487 train loss in this epoch 0.483109
2020-07-18 12:06:36,386 train loss in this epoch 0.485035
2020-07-18 12:06:37,301 train loss in this epoch 0.482771
2020-07-18 12:06:38,195 train loss in this epoch 0.484361
2020-07-18 12:06:39,089 train loss in this epoch 0.483431
2020-07-18 12:06:39,091 epoch 459, checkpoint!
2020-07-18 12:07:47,234 valid_loss: 0.5094 AUC: 0.7148 Prec: 0.5596 Rec: 0.1458 F1: 0.2313
2020-07-18 12:07:47,798 best threshold=-1.494785, f1=0.4885
2020-07-18 12:07:56,913 

NameError: name 'args' is not defined

In [93]:
best_thr = evaluate(500, valid_loader, return_best_thr=True, log_desc='valid_')
evaluate(500, test_loader, thr=best_thr, log_desc='test_')

2020-07-18 12:19:23,076 valid_loss: 0.5092 AUC: 0.7156 Prec: 0.5581 Rec: 0.1456 F1: 0.2310
2020-07-18 12:19:23,652 best threshold=-1.468486, f1=0.4892
2020-07-18 12:19:32,389 using threshold -1.4685
2020-07-18 12:19:32,754 test_loss: 0.5105 AUC: 0.7137 Prec: 0.3767 Rec: 0.6921 F1: 0.4879


## DeepInf_GAT ##

In [115]:
influence_dataset = InfluenceDataSet(
            "weibo/", 64, 42, False, "gat")
train_loader = DataLoader(influence_dataset, batch_size=1024,
                        sampler=ChunkSampler(valid_start - train_start, 0))
valid_loader = DataLoader(influence_dataset, batch_size=1024,
                        sampler=ChunkSampler(test_start - valid_start, valid_start))
test_loader = DataLoader(influence_dataset, batch_size=1024,
                        sampler=ChunkSampler(N - test_start, test_start))

2020-07-18 13:11:44,075 graphs loaded!
2020-07-18 13:11:44,320 influence features loaded!
2020-07-18 13:11:44,324 labels loaded!
2020-07-18 13:11:44,438 vertex ids loaded!
2020-07-18 13:11:45,238 global vertex features loaded!
2020-07-18 13:13:04,557 64-dim embedding loaded!
2020-07-18 13:13:04,560 779164 ego networks loaded, each with size 50


In [None]:
from tensorboard_logger import tensorboard_logger
import shutil
logger = logging.getLogger(__name__)
tensorboard_log_dir = 'log/%s' % ('gat')
os.makedirs(tensorboard_log_dir, exist_ok=True)
shutil.rmtree(tensorboard_log_dir)
tensorboard_logger.configure(tensorboard_log_dir)
logger.info('tensorboard logging to %s', tensorboard_log_dir)

In [118]:
class MultiHeadGraphAttention(nn.Module):
    def __init__(self, n_head, f_in, f_out, attn_dropout, bias=True):
        super(MultiHeadGraphAttention, self).__init__()
        self.n_head = n_head
        self.w = Parameter(torch.Tensor(n_head, f_in, f_out))
        self.a_src = Parameter(torch.Tensor(n_head, f_out, 1))
        self.a_dst = Parameter(torch.Tensor(n_head, f_out, 1))

        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(attn_dropout)

        if bias:
            self.bias = Parameter(torch.Tensor(f_out))
            init.constant_(self.bias, 0)
        else:
            self.register_parameter('bias', None)

        init.xavier_uniform_(self.w)
        init.xavier_uniform_(self.a_src)
        init.xavier_uniform_(self.a_dst)

    def forward(self, h, adj):
        n = h.size(0) # h is of size n x f_in
        h_prime = torch.matmul(h.unsqueeze(0), self.w) #  n_head x n x f_out
        attn_src = torch.bmm(h_prime, self.a_src) # n_head x n x 1
        attn_dst = torch.bmm(h_prime, self.a_dst) # n_head x n x 1
        attn = attn_src.expand(-1, -1, n) + attn_dst.expand(-1, -1, n).permute(0, 2, 1) # n_head x n x n

        attn = self.leaky_relu(attn)
        attn.data.masked_fill_(1 - adj, float("-inf"))
        attn = self.softmax(attn) # n_head x n x n
        attn = self.dropout(attn)
        output = torch.bmm(attn, h_prime) # n_head x n x f_out

        if self.bias is not None:
            return output + self.bias
        else:
            return output


class BatchMultiHeadGraphAttention(nn.Module):
    def __init__(self, n_head, f_in, f_out, attn_dropout, bias=True):
        super(BatchMultiHeadGraphAttention, self).__init__()
        self.n_head = n_head
        self.w = Parameter(torch.Tensor(n_head, f_in, f_out))
        self.a_src = Parameter(torch.Tensor(n_head, f_out, 1))
        self.a_dst = Parameter(torch.Tensor(n_head, f_out, 1))

        self.leaky_relu = nn.LeakyReLU(negative_slope=0.2)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(attn_dropout)
        if bias:
            self.bias = Parameter(torch.Tensor(f_out))
            init.constant_(self.bias, 0)
        else:
            self.register_parameter('bias', None)

        init.xavier_uniform_(self.w)
        init.xavier_uniform_(self.a_src)
        init.xavier_uniform_(self.a_dst)

    def forward(self, h, adj):
        bs, n = h.shape[:2] # h is of size bs x n x f_in
        h_prime = torch.matmul(h.unsqueeze(1), self.w) # bs x n_head x n x f_out
        attn_src = torch.matmul(F.tanh(h_prime), self.a_src) # bs x n_head x n x 1
        attn_dst = torch.matmul(F.tanh(h_prime), self.a_dst) # bs x n_head x n x 1
        attn = attn_src.expand(-1, -1, -1, n) + attn_dst.expand(-1, -1, -1, n).permute(0, 1, 3, 2) # bs x n_head x n x n

        attn = self.leaky_relu(attn)
        mask = 1 - adj.unsqueeze(1) # bs x 1 x n x n
        attn.data.masked_fill_(mask, float("-inf"))
        attn = self.softmax(attn) # bs x n_head x n x n
        attn = self.dropout(attn)
        output = torch.matmul(attn, h_prime) # bs x n_head x n x f_out
        if self.bias is not None:
            return output + self.bias
        else:
            return output

In [119]:
class BatchGAT(nn.Module):
    def __init__(self, pretrained_emb, vertex_feature, use_vertex_feature,
            n_units=[1433, 8, 7], n_heads=[8, 1],
            dropout=0.1, attn_dropout=0.0, fine_tune=False,
            instance_normalization=False):
        super(BatchGAT, self).__init__()
        self.n_layer = len(n_units) - 1
        self.dropout = dropout
        self.inst_norm = instance_normalization
        if self.inst_norm:
            self.norm = nn.InstanceNorm1d(pretrained_emb.size(1), momentum=0.0, affine=True)

        # https://discuss.pytorch.org/t/can-we-use-pre-trained-word-embeddings-for-weight-initialization-in-nn-embedding/1222/2
        self.embedding = nn.Embedding(pretrained_emb.size(0), pretrained_emb.size(1))
        self.embedding.weight = nn.Parameter(pretrained_emb)
        self.embedding.weight.requires_grad = fine_tune
        n_units[0] += pretrained_emb.size(1)

        self.use_vertex_feature = use_vertex_feature
        if self.use_vertex_feature:
            self.vertex_feature = nn.Embedding(vertex_feature.size(0), vertex_feature.size(1))
            self.vertex_feature.weight = nn.Parameter(vertex_feature)
            self.vertex_feature.weight.requires_grad = False
            n_units[0] += vertex_feature.size(1)

        self.layer_stack = nn.ModuleList()
        for i in range(self.n_layer):
            # consider multi head from last layer
            f_in = n_units[i] * n_heads[i - 1] if i else n_units[i]
            self.layer_stack.append(
                    BatchMultiHeadGraphAttention(n_heads[i], f_in=f_in,
                        f_out=n_units[i + 1], attn_dropout=attn_dropout)
                    )

    def forward(self, x, vertices, adj):
        emb = self.embedding(vertices)
        if self.inst_norm:
            emb = self.norm(emb.transpose(1, 2)).transpose(1, 2)
        x = torch.cat((x, emb), dim=2)
        if self.use_vertex_feature:
            vfeature = self.vertex_feature(vertices)
            x = torch.cat((x, vfeature), dim=2)
        bs, n = adj.size()[:2]
        for i, gat_layer in enumerate(self.layer_stack):
            x = gat_layer(x, adj) # bs x n_head x n x f_out
            if i + 1 == self.n_layer:
                x = x.mean(dim=1)
            else:
                x = F.elu(x.transpose(1, 2).contiguous().view(bs, n, -1))
                x = F.dropout(x, self.dropout, training=self.training)
        return F.log_softmax(x, dim=-1)


In [120]:
n_heads = [int(x) for x in [8,8,1]]
hidden_units = "16,16"
feature_dim = influence_dataset.get_feature_dimension()
n_units = [feature_dim] + [int(x) for x in hidden_units.strip().split(",")] + [n_classes]
model = BatchGAT(pretrained_emb=influence_dataset.get_embedding(),
            vertex_feature=influence_dataset.get_vertex_features(),
            use_vertex_feature= False,
            n_units=n_units, n_heads=n_heads,
            dropout=dropout)

In [None]:
from tqdm import tqdm
t_total = time.time()
logger.info("training...")
for epoch in tqdm(range(500)):
    train(epoch, train_loader, valid_loader, test_loader)
logger.info("optimization Finished!")
logger.info("total time elapsed: {:.4f}s".format(time.time() - t_total))

logger.info("retrieve best threshold...")
best_thr = evaluate(500, valid_loader, return_best_thr=True, log_desc='valid_')

# Testing
logger.info("testing...")
evaluate(500, test_loader, thr=best_thr, log_desc='test_')

2020-07-18 13:23:20,028 training...
2020-07-18 13:23:25,542 train loss in this epoch 0.691449
2020-07-18 13:23:31,010 train loss in this epoch 0.691475


2020-07-18 13:23:36,521 train loss in this epoch 0.691477
2020-07-18 13:23:42,014 train loss in this epoch 0.691479


2020-07-18 13:23:47,565 train loss in this epoch 0.691444
2020-07-18 13:23:53,542 train loss in this epoch 0.691443


2020-07-18 13:23:59,253 train loss in this epoch 0.691459
2020-07-18 13:24:05,018 train loss in this epoch 0.691456


2020-07-18 13:24:10,544 train loss in this epoch 0.691420
2020-07-18 13:24:16,330 train loss in this epoch 0.691432
2020-07-18 13:24:16,333 epoch 9, checkpoint!
































































































2020-07-18 13:29:42,977 valid_loss: 0.6912 AUC: 0.4629 Prec: 0.2159 Rec: 0.2988 F1: 0.2506
2020-07-18 13:29:43,540 best threshold=-0.723414, f1=0.4033














2020-07-18 13:30:31,159 using threshold -0.7234
2020-07-18 13:30:31,435 test_loss: 0.6912 AUC: 0.4637 Prec: 0.2547 Rec: 0.9654 F1: 0.4030
2020-07-18 13:30:37,505 train loss in this epoch 0.691491


2020-07-18 13:30:43,310 train loss in this epoch 0.691429
2020-07-18 13:30:49,192 train loss in this epoch 0.691487
2020-07-18 13:30:55,178 train loss in this epoch 0.691484
2020-07-18 13:31:01,052 train loss in this epoch 0.691450
2020-07-18 13:31:06,579 train loss in this epoch 0.691427


2020-07-18 13:31:12,474 train loss in this epoch 0.691451
2020-07-18 13:31:18,209 train loss in this epoch 0.691475


2020-07-18 13:31:24,050 train loss in this epoch 0.691420
2020-07-18 13:31:30,037 train loss in this epoch 0.691473
2020-07-18 13:31:30,039 epoch 19, checkpoint!




























































































