**Run OntoZSL on NELL with RDFS**
---

The parameters for the other settings are listed at the end of this notebook.


**1. Mount your Google Drive**

In [2]:
# Mount Google Drive at /content/drive/ (Colab-only API) so files stored on Drive
# are reachable from this notebook. If Drive is already mounted, Colab only reports
# it; pass force_remount=True to drive.mount to remount.
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


**2. Import Packages**

In [3]:
import os
import sys
import json
import random
import shutil
import logging
import argparse
import numpy as np
import os.path as osp
from tqdm import tqdm
from collections import defaultdict
from collections import deque
from sklearn.metrics.pairwise import cosine_similarity

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch import optim
from torch.autograd import Variable
from torch.nn.functional import normalize
from torch.nn.parameter import Parameter

**3. Data Loading and Preparation: data for pre-training the Feature Encoder, data for model training, and the semantic embeddings**

In [4]:
def random_pick(some_list, probabilities):
    """Draw one item of ``some_list`` with the matching weight in ``probabilities``.

    One uniform number is drawn from the global ``random`` module and compared
    with the running cumulative sum of ``probabilities``; the first item whose
    cumulative mass exceeds the draw is returned. If rounding leaves the total
    slightly below the draw, the last item is returned.

    Args:
        some_list: iterable of candidate items (e.g. ``dict.keys()``).
        probabilities: iterable of weights aligned with ``some_list``, expected
            to sum to (approximately) 1.

    Returns:
        The selected item.

    Raises:
        ValueError: if ``some_list`` and ``probabilities`` yield no pairs.
    """
    pairs = list(zip(some_list, probabilities))
    if not pairs:
        raise ValueError('random_pick() needs at least one item with a probability')

    x = random.uniform(0, 1)
    cumulative_probability = 0.0
    for item, item_probability in pairs:
        cumulative_probability += item_probability
        if x < cumulative_probability:
            return item
    # Floating-point shortfall: fall back to the last item.
    return pairs[-1][0]

# Using
def Extractor_generate(seed, dataset, train_tasks, batch_size, symbol2id, ent2id, e1rel_e2, few, sub_epoch):
    """Infinite generator of few-shot batches for pre-training the Feature Extractor.

    Every yielded batch comes from one relation drawn with probability
    proportional to ``min(#candidates, 1000)`` (via ``random_pick``). For each of
    ``sub_epoch`` rounds the relation's triples are shuffled, the first ``few``
    become support triples, and ``batch_size`` query triples are drawn from the
    rest (with replacement when fewer remain). Each query gets one negative whose
    tail is replaced by a candidate entity.

    Relations with <= 20 candidates, or with no triples beyond the ``few``
    support triples, are skipped and another relation is drawn.

    Note: the triple lists stored in ``train_tasks`` are shuffled in place.

    Args:
        seed: seed for the global ``random`` module (applied on the first ``next``).
        dataset: directory containing ``rel2candidates_all.json``.
        train_tasks: dict relation -> list of triples indexed as
            ``triple[0]`` (head), ``triple[1]`` (relation), ``triple[2]`` (tail).
        batch_size: number of query triples per sub-epoch.
        symbol2id: mapping used for the ``*_pairs`` outputs.
        ent2id: mapping used for the ``*_left`` / ``*_right`` outputs.
        e1rel_e2: dict keyed by ``head + relation`` holding known tails, which
            are never used as negatives.
        few: number of support triples per sub-epoch.
        sub_epoch: number of support/query rounds concatenated into one batch.

    Yields:
        (support_pairs, query_pairs, false_pairs, support_left, support_right,
         query_left, query_right, false_left, false_right)
    """
    random.seed(seed)
    print('\nLOADING PRETRAIN TRAINING DATA')
    with open(dataset + '/rel2candidates_all.json') as f:
        rel2candidates = json.load(f)

    task_pool = train_tasks.keys()

    # Sampling weight of each relation: its candidate count capped at 1000,
    # normalised into a probability distribution.
    t_num = [min(len(rel2candidates[k]), 1000) for k in task_pool]
    t_sum = sum(t_num)
    probability = [float(item) / t_sum for item in t_num]

    while True:
        support_pairs, query_pairs, false_pairs = list(), list(), list()
        support_left, support_right = list(), list()
        query_left, query_right = list(), list()
        false_left, false_right = list(), list()

        query = random_pick(task_pool, probability)
        candidates = rel2candidates[query]
        # Too few candidates to draw negatives from: pick another relation.
        if len(candidates) <= 20:
            continue

        train_and_test = train_tasks[query]
        # Nothing left after the support set means no queries/negatives;
        # resample instead of yielding a batch with support triples only.
        if len(train_and_test) <= few:
            continue

        for _ in range(sub_epoch):
            random.shuffle(train_and_test)

            support_triples = train_and_test[:few]
            support_pairs += [[symbol2id[triple[0]], symbol2id[triple[2]]] for triple in support_triples]
            support_left += [ent2id[triple[0]] for triple in support_triples]
            support_right += [ent2id[triple[2]] for triple in support_triples]

            all_test_triples = train_and_test[few:]
            # Sample with replacement when fewer than batch_size triples remain.
            if len(all_test_triples) < batch_size:
                query_triples = [random.choice(all_test_triples) for _ in range(batch_size)]
            else:
                query_triples = random.sample(all_test_triples, batch_size)

            query_pairs += [[symbol2id[triple[0]], symbol2id[triple[2]]] for triple in query_triples]
            query_left += [ent2id[triple[0]] for triple in query_triples]
            query_right += [ent2id[triple[2]] for triple in query_triples]

            # One negative per query: a candidate known to ent2id that is neither
            # a known tail for (head, relation) nor the query's own tail.
            # NOTE(review): loops forever if no candidate qualifies.
            for triple in query_triples:
                e_h, rel, e_t = triple[0], triple[1], triple[2]
                while True:
                    noise = random.choice(candidates)
                    if noise in ent2id and noise not in e1rel_e2[e_h + rel] and noise != e_t:
                        break
                false_pairs.append([symbol2id[e_h], symbol2id[noise]])
                false_left.append(ent2id[e_h])
                false_right.append(ent2id[noise])

        yield support_pairs, query_pairs, false_pairs, support_left, support_right, query_left, query_right, false_left, false_right


def centroid_generate(dataset, relation_name, symbol2id, ent2id, train_tasks, rela2label):
    """Encode every triple of one relation, e.g. to compute its centroid.

    ``dataset`` is accepted for signature compatibility and is not used.

    Returns:
        (pairs, left_ids, right_ids, label):
        ``pairs[i] == [symbol2id[head_i], symbol2id[tail_i]]``; ``left_ids`` and
        ``right_ids`` map heads and tails through ``ent2id``; ``label`` is
        ``rela2label[relation_name]``.
    """
    triples = train_tasks[relation_name]
    heads = [triple[0] for triple in triples]
    tails = [triple[2] for triple in triples]

    pairs = [[symbol2id[head], symbol2id[tail]] for head, tail in zip(heads, tails)]
    left_ids = [ent2id[head] for head in heads]
    right_ids = [ent2id[tail] for tail in tails]
    return pairs, left_ids, right_ids, rela2label[relation_name]


def train_generate_decription(dataset, train_tasks, batch_size, symbol2id, ent2id, e1rel_e2, rel2id, args, rela2label, rela_matrix):
    """Infinite generator of multi-relation batches for adversarial training.

    Each iteration shuffles the relation list and re-shuffles until none of the
    first two relations has <= 20 candidates. It then processes the first
    ``args.gan_batch_rela`` relations. For each one it draws ``batch_size``
    query triples (with replacement when the relation has fewer) and one
    negative per query, whose tail is replaced by a candidate entity. The global
    ``random`` module is used without reseeding here.

    Note: the triple lists stored in ``train_tasks`` are shuffled in place.

    Args:
        dataset: directory containing ``rel2candidates_all.json``.
        train_tasks: dict relation -> list of triples indexed as
            ``triple[0]`` (head), ``triple[1]`` (relation), ``triple[2]`` (tail).
        batch_size: number of query triples per relation.
        symbol2id: mapping used for the ``*_pairs`` outputs.
        ent2id: mapping used for the ``*_left`` / ``*_right`` outputs.
        e1rel_e2: dict keyed by ``head + relation`` holding known tails, which
            are never used as negatives.
        rel2id: relation -> row index into ``rela_matrix``.
        args: must provide ``gan_batch_rela`` and ``dataset``; when
            ``args.dataset == 'Wiki'`` every relation with <= 20 candidates is skipped.
        rela2label: relation -> label repeated ``batch_size`` times in ``labels``.
        rela_matrix: relation embeddings indexed with a list of row ids
            (e.g. a numpy array).

    Yields:
        (rela_matrix[rel_batch], query_pairs, query_left, query_right,
         false_pairs, false_left, false_right, labels), with ``batch_size``
        entries per processed relation in every list.
    """
    print('##LOADING TRAINING DATA')
    print('##LOADING CANDIDATES')
    with open(dataset + '/rel2candidates_all.json') as f:
        rel2candidates = json.load(f)
    # Sort first so the starting order (and hence the RNG stream) is reproducible.
    task_pool = sorted(train_tasks.keys())

    while True:
        rel_batch, query_pairs, query_left, query_right = [], [], [], []
        false_pairs, false_left, false_right, labels = [], [], [], []
        random.shuffle(task_pool)
        # Re-shuffle while one of the two leading relations has too few
        # candidates; slicing avoids an IndexError with fewer than two relations.
        if any(len(rel2candidates[rel]) <= 20 for rel in task_pool[:2]):
            continue

        for query in task_pool[:args.gan_batch_rela]:
            candidates = rel2candidates[query]
            if args.dataset == 'Wiki' and len(candidates) <= 20:
                continue

            train_and_test = train_tasks[query]
            random.shuffle(train_and_test)
            if len(train_and_test) == 0:
                continue

            # Sample with replacement when the relation has fewer than batch_size triples.
            if len(train_and_test) < batch_size:
                query_triples = [random.choice(train_and_test) for _ in range(batch_size)]
            else:
                query_triples = random.sample(train_and_test, batch_size)

            query_pairs += [[symbol2id[triple[0]], symbol2id[triple[2]]] for triple in query_triples]
            query_left += [ent2id[triple[0]] for triple in query_triples]
            query_right += [ent2id[triple[2]] for triple in query_triples]

            # One negative per query: a candidate known to ent2id that is neither
            # a known tail for (head, relation) nor the query's own tail.
            # NOTE(review): loops forever if no candidate qualifies.
            for triple in query_triples:
                e_h, rel, e_t = triple[0], triple[1], triple[2]
                while True:
                    noise = random.choice(candidates)
                    if noise in ent2id and noise not in e1rel_e2[e_h + rel] and noise != e_t:
                        break
                false_pairs.append([symbol2id[e_h], symbol2id[noise]])
                false_left.append(ent2id[e_h])
                false_right.append(ent2id[noise])

            rel_batch += [rel2id[query]] * batch_size
            labels += [rela2label[query]] * batch_size

        yield rela_matrix[rel_batch], query_pairs, query_left, query_right, false_pairs, false_left, false_right, labels

def load_semantic_embed(data_path, dataset, type):
    """
    Load Semantic Embeddings.

    Reads ``<data_path>/semantic_embeddings/<file>.npz`` for the given dataset and
    embedding type and returns its ``relaM`` array cast to float32.

    Args:
        data_path: root data directory.
        dataset: ``'NELL'`` or ``'Wiki'``.
        type: one of ``'rdfs'``, ``'rdfs_hie'``, ``'rdfs_cons'``, ``'text'`` or
            ``'rdfs_text'``. The name shadows the builtin but is kept for
            keyword-argument compatibility.

    Returns:
        numpy.ndarray of dtype float32.

    Raises:
        ValueError: if ``dataset`` or ``type`` is not recognised.
    """
    # Only the plain RDFS files differ between the datasets (55000 vs 65000).
    file_names = {
        'NELL': {
            'rdfs': 'rela_matrix_rdfs_55000.npz',
            'rdfs_hie': 'rela_matrix_rdfs_hie_60000.npz',
            'rdfs_cons': 'rela_matrix_rdfs_cons_60000.npz',
            'text': 'rela_matrix_text.npz',
            'rdfs_text': 'rela_matrix_rdfs_55000_text140.npz',
        },
        'Wiki': {
            'rdfs': 'rela_matrix_rdfs_65000.npz',
            'rdfs_hie': 'rela_matrix_rdfs_hie_60000.npz',
            'rdfs_cons': 'rela_matrix_rdfs_cons_60000.npz',
            'text': 'rela_matrix_text.npz',
            'rdfs_text': 'rela_matrix_rdfs_65000_text140.npz',
        },
    }
    if dataset not in file_names:
        raise ValueError('invalid dataset for semantic embeddings: {!r}'.format(dataset))
    if type not in file_names[dataset]:
        raise ValueError('invalid semantic embeddings type: {!r}'.format(type))

    file_name = os.path.join(data_path, 'semantic_embeddings', file_names[dataset][type])
    # np.load returns an NpzFile holding the file open; close it once the
    # array has been copied out (astype always copies).
    with np.load(file_name) as npz:
        rela_embeddings = npz['relaM'].astype('float32')
    return rela_embeddings

**4. Additional Modules for Feature Encoder**

In [5]:
def gaussian_noise(x, sigma=0.1, mean=0, stddev=1, is_training=True):
    """Return ``x + sigma * N(mean, stddev)`` noise in training mode, else ``x`` itself.

    The noise tensor has the same size, dtype and device as ``x`` and is drawn
    from torch's global RNG.
    """
    if not is_training:
        return x
    perturbation = x.new_empty(x.size()).normal_(mean, stddev)
    return x + sigma * perturbation



class Path(nn.Module):
    """Convolutional encoder for every path between an entity pair.

    Symbol ids are embedded, convolved with one kernel height per entry of
    ``k_sizes`` (``k_num`` output channels each), ReLU-activated, max-pooled
    over the sequence, concatenated and passed through dropout.
    """
    def __init__(self, input_dim, num_symbols, use_pretrain=True, embed_path='', dropout=0.5, k_sizes = [3], k_num=100):
        '''
        Parameters:
        input_dim: size of relation/entity embeddings
        num_symbols: total number of entities and relations; index num_symbols
            is the padding index
        use_pretrain: load embeddings from embed_path (np.loadtxt) and freeze them
        embed_path: text file with the pretrained embedding table
        dropout: dropout probability applied to the pooled features
        k_sizes: convolution kernel heights
        k_num: output channels per kernel height
        '''
        super().__init__()
        self.symbol_emb = nn.Embedding(num_symbols + 1, input_dim, padding_idx=num_symbols)
        self.k_sizes = k_sizes
        self.k_num = k_num

        if use_pretrain:
            pretrained = torch.from_numpy(np.loadtxt(embed_path))
            self.symbol_emb.weight.data.copy_(pretrained)
            # Pretrained embeddings stay frozen.
            self.symbol_emb.weight.requires_grad = False

        # Each kernel spans the full embedding width, i.e. a 1-D conv over positions.
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, self.k_num, (k, input_dim)) for k in self.k_sizes]
        )

        self.dropout = nn.Dropout(dropout)

    def forward(self, path):
        '''
        Inputs:
        path: batch * max_len tensor of symbol ids

        Returns:
        (batch, len(k_sizes) * k_num) tensor
        '''
        embedded = self.symbol_emb(path).unsqueeze(1)  # (B, 1, W, D)
        pooled = []
        for conv in self.convs:
            feature_map = F.relu(conv(embedded)).squeeze(3)  # (B, k_num, W-(k-1))
            pooled.append(F.max_pool1d(feature_map, feature_map.size(2)).squeeze(2))
        return self.dropout(torch.cat(pooled, 1))

class ScaledDotProductAttention(nn.Module):
    ''' Scaled dot-product attention: softmax(q @ k^T / sqrt(d_model)) @ v '''

    def __init__(self, d_model, attn_dropout=0.1):
        super().__init__()
        # Scaling factor sqrt(d_model).
        self.temper = np.power(d_model, 0.5)
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, attn_mask=None):
        """Batched attention (3-D inputs, as required by torch.bmm).

        attn_mask, if given, must have exactly the logits' shape
        (batch, len_q, len_k); masked positions are set to -inf before the softmax.
        Returns (output, attention weights after dropout).
        """
        scores = torch.bmm(q, k.transpose(1, 2)) / self.temper

        if attn_mask is not None:
            message = 'Attention mask shape {} mismatch with Attention logit tensor shape {}.'
            assert attn_mask.size() == scores.size(), message.format(attn_mask.size(), scores.size())
            scores.data.masked_fill_(attn_mask, -float('inf'))

        weights = self.dropout(self.softmax(scores))
        return torch.bmm(weights, v), weights

class LayerNormalization(nn.Module):
    ''' Layer normalization over the last dimension with a learnable gain and bias.

    Computes ``(z - mean) / (std + eps) * a_2 + b_2`` with the unbiased std over
    dim -1. Note that eps is added to the std, not to the variance.
    '''

    def __init__(self, d_hid, eps=1e-3):
        super(LayerNormalization, self).__init__()

        self.eps = eps
        self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
        self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)

    def forward(self, z):
        # Statistics are taken over the last dimension, so only a feature
        # dimension of size 1 is degenerate (the unbiased std of one element is
        # NaN). Testing z.size(1) instead would wrongly skip 3-D inputs with a
        # length-1 middle axis (e.g. a single-position sequence).
        if z.size(-1) == 1:
            return z

        mu = torch.mean(z, keepdim=True, dim=-1)
        sigma = torch.std(z, keepdim=True, dim=-1)
        ln_out = (z - mu) / (sigma + self.eps)
        return ln_out * self.a_2 + self.b_2

class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module

    Projects q/k/v per head with learned (n_head, d_model, d_k|d_v) tensors, runs
    scaled dot-product attention with all heads folded into the batch axis,
    concatenates the heads, projects back to d_model, and applies dropout,
    a residual connection to q, and layer normalization.
    '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super(MultiHeadAttention, self).__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        # Allocated uninitialised here; filled by xavier_normal_ below.
        self.w_qs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
        self.w_ks = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
        self.w_vs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_v))

        self.attention = ScaledDotProductAttention(d_model)
        self.layer_norm = LayerNormalization(d_model)

        self.proj = nn.Linear(n_head*d_v, d_model)
        init.xavier_normal_(self.proj.weight)

        self.dropout = nn.Dropout(dropout)

        init.xavier_normal_(self.w_qs)
        init.xavier_normal_(self.w_ks)
        init.xavier_normal_(self.w_vs)

    def forward(self, q, k, v, attn_mask=None):
        '''
        q, k, v: (mb_size, len, d_model) tensors.
        attn_mask: optional mask of shape (mb_size, len_q, len_k); it is repeated
            once per head so it matches the (n_head*mb_size, len_q, len_k) logits.

        Returns:
            (layer_norm(proj(heads) + q) of shape (mb_size, len_q, d_model),
             attention weights of shape (n_head*mb_size, len_q, len_k))
        '''
        d_k, d_v = self.d_k, self.d_v
        n_head = self.n_head

        residual = q

        mb_size, len_q, d_model = q.size()
        mb_size, len_k, d_model = k.size()
        mb_size, len_v, d_model = v.size()

        # treat as a (n_head) size batch
        q_s = q.repeat(n_head, 1, 1).view(n_head, -1, d_model) # n_head x (mb_size*len_q) x d_model
        k_s = k.repeat(n_head, 1, 1).view(n_head, -1, d_model) # n_head x (mb_size*len_k) x d_model
        v_s = v.repeat(n_head, 1, 1).view(n_head, -1, d_model) # n_head x (mb_size*len_v) x d_model

        # treat the result as a (n_head * mb_size) size batch
        q_s = torch.bmm(q_s, self.w_qs).view(-1, len_q, d_k)   # (n_head*mb_size) x len_q x d_k
        k_s = torch.bmm(k_s, self.w_ks).view(-1, len_k, d_k)   # (n_head*mb_size) x len_k x d_k
        v_s = torch.bmm(v_s, self.w_vs).view(-1, len_v, d_v)   # (n_head*mb_size) x len_v x d_v

        # perform attention, result size = (n_head * mb_size) x len_q x d_v
        # Compare against None explicitly: truth-testing a multi-element mask
        # tensor raises "Boolean value of Tensor ... is ambiguous".
        if attn_mask is not None:
            outputs, attns = self.attention(q_s, k_s, v_s, attn_mask=attn_mask.repeat(n_head, 1, 1))
        else:
            outputs, attns = self.attention(q_s, k_s, v_s, attn_mask=None)

        # back to original mb_size batch, result size = mb_size x len_q x (n_head*d_v)
        outputs = torch.cat(torch.split(outputs, mb_size, dim=0), dim=-1)

        # project back to residual size
        outputs = self.proj(outputs)
        outputs = self.dropout(outputs)

        return self.layer_norm(outputs + residual), attns

class PositionwiseFeedForward(nn.Module):
    ''' Two kernel-size-1 Conv1d layers (the same MLP at every position) with residual + layer norm. '''

    def __init__(self, d_hid, d_inner_hid, dropout=0.1):
        super().__init__()
        # Kernel size 1 => a per-position linear map over the channel axis.
        self.w_1 = nn.Conv1d(d_hid, d_inner_hid, 1)
        self.w_2 = nn.Conv1d(d_inner_hid, d_hid, 1)
        self.layer_norm = LayerNormalization(d_hid)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        """x: (batch, length, d_hid); Conv1d needs channels first, hence the transposes."""
        channels_first = x.transpose(1, 2)
        hidden = self.relu(self.w_1(channels_first))
        projected = self.dropout(self.w_2(hidden).transpose(1, 2))
        return self.layer_norm(projected + x)

class SupportEncoder(nn.Module):
    """Residual two-layer MLP (d_model -> d_inner -> d_model) followed by layer norm."""
    def __init__(self, d_model, d_inner, dropout=0.1):
        super().__init__()
        self.proj1 = nn.Linear(d_model, d_inner)
        self.proj2 = nn.Linear(d_inner, d_model)
        self.layer_norm = LayerNormalization(d_model)

        # Xavier-normal weights; biases keep nn.Linear's default init.
        for layer in (self.proj1, self.proj2):
            init.xavier_normal_(layer.weight)

        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return layer_norm(dropout(proj2(relu(proj1(x)))) + x)."""
        hidden = self.relu(self.proj1(x))
        projected = self.dropout(self.proj2(hidden))
        return self.layer_norm(projected + x)


class EncoderLayer(nn.Module):
    ''' A single self-attention sub-layer; the position-wise feed-forward sub-layer is disabled. '''

    def __init__(self, d_model, d_inner_hid, n_head, d_k, d_v, dropout=0.1):
        super().__init__()
        # d_inner_hid is accepted for the (disabled) feed-forward sub-layer and is unused.
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)

    def forward(self, enc_input, slf_attn_mask=None):
        """Self-attend ``enc_input`` (query = key = value); returns (output, attention weights)."""
        attended, attention = self.slf_attn(enc_input, enc_input, enc_input, attn_mask=slf_attn_mask)
        return attended, attention


class ContextAwareEncoder(nn.Module):
    """Stack of ``num_layers`` self-attention EncoderLayers applied in sequence."""
    def __init__(self, num_layers, d_model, d_inner_hid, n_head, d_k, d_v, dropout = 0.1):
        super().__init__()
        self.num_layers = num_layers
        layers = [EncoderLayer(d_model, d_inner_hid, n_head, d_k, d_v, dropout=dropout)
                  for _ in range(num_layers)]
        self.layer_stack = nn.ModuleList(layers)

    def forward(self, elements, enc_slf_attn_mask=None):
        """Return the last layer's output; per-layer attention weights are discarded."""
        hidden = elements
        for layer in self.layer_stack:
            hidden, _ = layer(hidden, slf_attn_mask=enc_slf_attn_mask)
        return hidden

class QueryEncoder(nn.Module):
    """Refine query embeddings by repeatedly attending over the support embeddings.

    An LSTMCell (hidden size 2*input_dim) is stepped ``process_step`` times. At
    each step, ``h = query + first half of the hidden state`` attends (softmax
    of dot products) over the support rows, and ``[h, read-out]`` becomes the
    next hidden state.
    """
    def __init__(self, input_dim, process_step=4):
        super(QueryEncoder, self).__init__()
        self.input_dim = input_dim
        self.process_step = process_step
        self.process = nn.LSTMCell(input_dim, 2*input_dim)

    def forward(self, support, query):
        '''
        support: (few, support_dim)
        query: (batch_size, query_dim)
        support_dim = query_dim

        return:
        (batch_size, query_dim); ``query`` itself when process_step == 0
        '''
        assert support.size()[1] == query.size()[1]

        if self.process_step == 0:
            return query

        batch_size = query.size()[0]
        # Initial state on the same device/dtype as the query rather than
        # assuming CUDA, so the encoder also runs on CPU or another GPU.
        h_r = query.new_zeros(batch_size, 2 * self.input_dim)
        c = query.new_zeros(batch_size, 2 * self.input_dim)
        for _ in range(self.process_step):
            h_r_, c = self.process(query, (h_r, c))
            h = query + h_r_[:, :self.input_dim]  # (batch_size, query_dim)
            attn = F.softmax(torch.matmul(h, support.t()), dim=1)
            r = torch.matmul(attn, support)  # (batch_size, support_dim)
            h_r = torch.cat((h, r), dim=1)

        return h

**5. Spectral Normalization from https://arxiv.org/abs/1802.05957**

In [6]:
class SpectralNorm(object):
    """Forward-pre-hook implementing spectral normalization of one module parameter.

    Installed by ``SpectralNorm.apply``. Before every forward it sets
    ``module.<name> = <name>_orig / sigma``, where ``sigma = u @ W @ v`` is the
    power-iteration estimate of the largest singular value of ``<name>_orig``
    reshaped to a matrix. ``u`` and ``v`` live in the ``<name>_u`` and
    ``<name>_v`` buffers.
    """
    # Invariant before and after each forward call:
    #   u = normalize(W @ v)
    # NB: At initialization, this invariant is not enforced

    _version = 1
    # At version 1:
    #   made  `W` not a buffer,
    #   added `v` as a buffer, and
    #   made eval mode recompute `W = W_orig / (u @ W_orig @ v)` from the stored
    #   `u`, `v` rather than using a stored `W`.

    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        # name: attribute name of the parameter to normalize.
        # n_power_iterations: power-iteration steps per training forward (must be > 0).
        # dim: weight dimension moved to the front before flattening to a matrix.
        # eps: numerical-stability epsilon passed to `normalize`.
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight):
        """Return `weight` as a 2-D matrix of shape (weight.size(self.dim), -1)."""
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(self.dim,
                                            *[d for d in range(weight_mat.dim()) if d != self.dim])
        height = weight_mat.size(0)
        return weight_mat.reshape(height, -1)

    def compute_weight(self, module, do_power_iteration):
        """Return `<name>_orig / sigma`; optionally refine `u`/`v` in place first."""
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        #     updated in power iteration **in-place**. This is very important
        #     because in `DataParallel` forward, the vectors (being buffers) are
        #     broadcast from the parallelized module to each module replica,
        #     which is a new module object created on the fly. And each replica
        #     runs its own spectral norm power iteration. So simply assigning
        #     the updated vectors to the module this function runs on will cause
        #     the update to be lost forever. And the next time the parallelized
        #     module is replicated, the same randomly initialized vectors are
        #     broadcast and used!
        #
        #     Therefore, to make the change propagate back, we rely on two
        #     important behaviors (also enforced via tests):
        #       1. `DataParallel` doesn't clone storage if the broadcast tensor
        #          is already on the correct device; and it makes sure that the
        #          parallelized module is already on `device[0]`.
        #       2. If the out tensor in `out=` kwarg has correct shape, it will
        #          just fill in the values.
        #     Therefore, since the same power iteration is performed on all
        #     devices, simply updating the tensors in-place will make sure that
        #     the module replica on `device[0]` will update the _u vector on the
        #     parallelized module (by shared storage).
        #
        #    However, after we update `u` and `v` in-place, we need to **clone**
        #    them before using them to normalize the weight. This is to support
        #    backpropagating through two forward passes, e.g., the common pattern in
        #    GAN training: loss = D(real) - D(fake). Otherwise, engine will
        #    complain that variables needed to do backward for the first forward
        #    (i.e., the `u` and `v` vectors) are changed in the second forward.
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
                    # are the first left and right singular vectors.
                    # This power iteration produces approximations of `u` and `v`.
                    v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v)
                    u = normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps, out=u)
                # Always true here: __init__ rejects n_power_iterations <= 0.
                if self.n_power_iterations > 0:
                    # See above on why we need to clone
                    u = u.clone()
                    v = v.clone()

        # sigma is computed outside no_grad so gradients flow through W_orig.
        sigma = torch.dot(u, torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, module):
        """Bake the current normalized weight into a plain Parameter and drop the extra attributes."""
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_v')
        delattr(module, self.name + '_orig')
        module.register_parameter(self.name, torch.nn.Parameter(weight.detach()))

    def __call__(self, module, inputs):
        # Forward pre-hook: power iteration runs only in training mode.
        setattr(module, self.name, self.compute_weight(module, do_power_iteration=module.training))

    def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        # Tries to returns a vector `v` s.t. `u = normalize(W @ v)`
        # (the invariant at top of this class) and `u @ W @ v = sigma`.
        # This uses pinverse in case W^T W is not invertible.
        # NOTE(review): torch.chain_matmul is deprecated in recent PyTorch
        # (torch.linalg.multi_dot is the documented replacement).
        v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)).squeeze(1)
        return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        """Install spectral normalization on `module.<name>` and return the hook.

        Moves the parameter to `<name>_orig`, registers random unit `u`/`v`
        buffers, the forward pre-hook and the state-dict (load) hooks.
        Raises RuntimeError if `<name>` already has a SpectralNorm hook.
        """
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError("Cannot register two spectral_norm hooks on "
                                   "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]

        with torch.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)

            h, w = weight_mat.size()
            # randomly initialize `u` and `v`
            u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
            v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a plain
        # attribute.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)

        # Private nn.Module hooks: record the format version on save, upgrade old checkpoints on load.
        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(SpectralNormLoadStateDictPreHook(fn))
        return fn


# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormLoadStateDictPreHook(object):
    """Load-state-dict pre-hook that upgrades pre-version-1 spectral-norm checkpoints.

    Old checkpoints store the normalized weight under `<prefix><name>` and have
    no `_v` buffer; this hook removes that entry and reconstructs `_v`.
    """
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn):
        # fn: the SpectralNorm hook whose parameter this pre-hook upgrades.
        self.fn = fn

    # For state_dict with version None, (assuming that it has gone through at
    # least one training forward), we have
    #
    #    u = normalize(W_orig @ v)
    #    W = W_orig / sigma, where sigma = u @ W_orig @ v
    #
    # To compute `v`, we solve `W_orig @ x = u` (least squares, via pinverse),
    # and let
    #    v = x / (u @ W_orig @ x) * sigma,   with sigma = mean(W_orig / W),
    # so that u @ W_orig @ v == sigma.
    def __call__(self, state_dict, prefix, local_metadata, strict,
                 missing_keys, unexpected_keys, error_msgs):
        fn = self.fn
        version = local_metadata.get('spectral_norm', {}).get(fn.name + '.version', None)
        if version is None or version < 1:
            with torch.no_grad():
                weight_orig = state_dict[prefix + fn.name + '_orig']
                # The stored normalized weight is not a buffer any more: pop it.
                weight = state_dict.pop(prefix + fn.name)
                sigma = (weight_orig / weight).mean()
                weight_mat = fn.reshape_weight_to_matrix(weight_orig)
                u = state_dict[prefix + fn.name + '_u']
                v = fn._solve_v_and_rescale(weight_mat, u, sigma)
                state_dict[prefix + fn.name + '_v'] = v


# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormStateDictHook(object):
    """State-dict hook that records the spectral-norm format version.

    On save it writes ``fn._version`` to
    ``local_metadata['spectral_norm'][fn.name + '.version']`` so that
    SpectralNormLoadStateDictPreHook can recognise older checkpoints.
    """
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata):
        versions = local_metadata.setdefault('spectral_norm', {})
        key = self.fn.name + '.version'
        if key in versions:
            raise RuntimeError("Unexpected key in metadata['spectral_norm']: {}".format(key))
        versions[key] = self.fn._version


def spectral_norm(module, name='weight', n_power_iterations=1, eps=1e-12, dim=None):
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with spectral norm :math:`\sigma` of the weight matrix calculated using
    power iteration method. If the dimension of the weight tensor is greater
    than 2, it is reshaped to 2D in power iteration method to get spectral
    norm. This is implemented via a hook that calculates spectral norm and
    rescales weight before every :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is 0, except for modules that are instances of
            ConvTranspose1/2/3d, when it is 1

    Returns:
        The original module with the spectral norm hook

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()   # weight is (40, 20), so u has one entry per row
        torch.Size([40])

    """
    if dim is None:
        # Transposed convolutions store outputs along dim 1 of their weight.
        transposed_convs = (torch.nn.ConvTranspose1d,
                            torch.nn.ConvTranspose2d,
                            torch.nn.ConvTranspose3d)
        dim = 1 if isinstance(module, transposed_convs) else 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module


def remove_spectral_norm(module, name='weight'):
    r"""Removes the spectral normalization reparameterization from a module.

    The current normalized weight is kept as a regular parameter and the
    SpectralNorm forward pre-hook for ``name`` is unregistered.

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter

    Raises:
        ValueError: if ``module`` has no spectral norm hook for ``name``.

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    hooks = module._forward_pre_hooks
    match = next(
        ((hook_id, hook) for hook_id, hook in hooks.items()
         if isinstance(hook, SpectralNorm) and hook.name == name),
        None,
    )
    if match is None:
        raise ValueError("spectral_norm of '{}' not found in {}".format(
            name, module))

    hook_id, hook = match
    hook.remove(module)
    del hooks[hook_id]
    return module

**6. Models, including Feature Extractor, Generator and Discriminator**

In [7]:
def weights_init(m):
    """Xavier-normal weights and zero bias for modules whose class name contains 'Linear'.

    Presumably used via ``model.apply(weights_init)``. Other modules are left
    untouched, and a missing bias (e.g. ``nn.Linear(..., bias=False)``) is
    skipped instead of crashing.
    """
    classname = m.__class__.__name__
    if 'Linear' in classname:
        init.xavier_normal_(m.weight.data)
        if getattr(m, 'bias', None) is not None:
            init.constant_(m.bias, 0.0)
class Extractor(nn.Module):
    """
    Feature extractor / matching network over frozen KB embeddings.

    An (entity1, entity2) pair is encoded as
    ``[left-neighbourhood | pair projection | right-neighbourhood]``
    (width ``2 * embed_dim``) and passed through ``SupportEncoder``. Queries
    are then scored by dot product with the mean encoded support pair.
    """

    def __init__(self, embed_dim, num_symbols, embed=None):
        """
        Args:
            embed_dim (int): width of the symbol embeddings.
            num_symbols (int): number of real symbols; id ``num_symbols`` is
                the padding id.
            embed (np.ndarray, optional): pre-trained table copied into the
                embedding layer, expected to be ``(num_symbols + 1, embed_dim)``.
                When None, the layer keeps its default initialisation.
        """
        super(Extractor, self).__init__()
        self.embed_dim = int(embed_dim)
        self.pad_idx = num_symbols
        self.symbol_emb = nn.Embedding(num_symbols + 1, embed_dim, padding_idx=num_symbols)
        self.num_symbols = num_symbols

        self.gcn_w = nn.Linear(self.embed_dim, int(self.embed_dim / 2))
        # Unused in forward(), but kept in the state_dict so saved checkpoints
        # still load. Zero-initialised: torch.FloatTensor(n) returns
        # uninitialised memory.
        self.gcn_b = nn.Parameter(torch.zeros(self.embed_dim))

        self.fc1 = nn.Linear(self.embed_dim, int(self.embed_dim / 2))
        self.fc2 = nn.Linear(self.embed_dim, int(self.embed_dim / 2))

        self.dropout = nn.Dropout(0.2)
        self.dropout_e = nn.Dropout(0.2)

        if embed is not None:
            self.symbol_emb.weight.data.copy_(torch.from_numpy(embed))

        # Embeddings are frozen; only the layers on top of them are trained.
        self.symbol_emb.weight.requires_grad = False

        # left + pair + right, each half being int(embed_dim / 2) wide.
        d_model = self.embed_dim * 2
        self.support_encoder = SupportEncoder(d_model, 2 * d_model, dropout=0.2)

    def neighbor_encoder(self, connections, num_neighbors):
        '''
        Average the projected neighbour-entity embeddings of each entity.

        connections: LongTensor (batch, max_neighbor, 2); column 1 holds the
            neighbour symbol ids, unused slots hold the padding id.
        num_neighbors: FloatTensor (batch,) with the real neighbour counts.
        return: (batch, int(embed_dim / 2)), tanh-activated.
        '''
        num_neighbors = num_neighbors.unsqueeze(1)
        # Index, don't squeeze: squeeze(-1) would drop the neighbour axis
        # when max_neighbor == 1.
        entities = connections[:, :, 1]
        ent_embeds = self.dropout(self.symbol_emb(entities))  # (batch, max_neighbor, embed_dim)

        out = self.gcn_w(ent_embeds)
        # Padding slots are included in the sum (each contributes gcn_w's bias).
        out = torch.sum(out, dim=1)
        # Isolated entities have degree 0; dividing by it gave NaN/inf that
        # poisoned the downstream scores, so clamp the divisor at 1.
        out = out / num_neighbors.clamp(min=1.0)
        return out.tanh()

    def entity_encoder(self, entity1, entity2):
        '''
        Project head (fc1) and tail (fc2) embeddings separately and concatenate.

        entity1, entity2: (batch, embed_dim) symbol embeddings.
        return: (batch, 2 * int(embed_dim / 2)), tanh-activated.
        '''
        entity1 = self.dropout_e(entity1)
        entity2 = self.dropout_e(entity2)
        entity1 = self.fc1(entity1)
        entity2 = self.fc2(entity2)
        entity = torch.cat((entity1, entity2), dim=-1)
        return entity.tanh()

    def forward(self, query, support, query_meta=None, support_meta=None):
        '''
        query: LongTensor (batch_size, 2) of (head, tail) symbol ids
        support: LongTensor (few, 2) of (head, tail) symbol ids
        query_meta / support_meta: 4-tuples (left_connections, left_degrees,
            right_connections, right_degrees), as built by Trainer.get_meta.
            They are required despite the None default.
        return: (query_g, matching_scores). query_g is the encoded query batch;
            matching_scores are dot products of query_g with the mean encoded
            support vector, squeezed to (batch_size,).
        '''
        query_left_connections, query_left_degrees, query_right_connections, query_right_degrees = query_meta
        support_left_connections, support_left_degrees, support_right_connections, support_right_degrees = support_meta

        query_e1 = self.symbol_emb(query[:, 0])
        query_e2 = self.symbol_emb(query[:, 1])
        query_e = self.entity_encoder(query_e1, query_e2)

        support_e1 = self.symbol_emb(support[:, 0])
        support_e2 = self.symbol_emb(support[:, 1])
        support_e = self.entity_encoder(support_e1, support_e2)

        query_left = self.neighbor_encoder(query_left_connections, query_left_degrees)
        query_right = self.neighbor_encoder(query_right_connections, query_right_degrees)

        support_left = self.neighbor_encoder(support_left_connections, support_left_degrees)
        support_right = self.neighbor_encoder(support_right_connections, support_right_degrees)

        query_neighbor = torch.cat((query_left, query_e, query_right), dim=-1)
        support_neighbor = torch.cat((support_left, support_e, support_right), dim=-1)

        # Output width is set by SupportEncoder (not visible here).
        support_g = self.support_encoder(support_neighbor)
        query_g = self.support_encoder(query_neighbor)

        support_g = torch.mean(support_g, dim=0, keepdim=True)

        # Unnormalised dot product with the mean support vector (not cosine).
        matching_scores = torch.matmul(query_g, support_g.t()).squeeze()

        return query_g, matching_scores


class Generator(nn.Module):
    """Map a relation's semantic description plus noise to a fake entity-pair embedding.

    Architecture: spectral-normalised Linear -> spectral-normalised Linear ->
    LayerNormalization. There is no non-linearity between the two linear layers.
    """

    def __init__(self, args, dropout=0.5):
        # ``dropout`` is accepted for interface compatibility but is not used.
        super(Generator, self).__init__()
        self.noise_dim = args.noise_dim
        self.fc1_dim = args.fc1_dim
        self.ep_dim = args.ep_dim

        in_features = args.input_dim + self.noise_dim
        self.fc1 = spectral_norm(nn.Linear(in_features, self.fc1_dim))
        self.fc2 = spectral_norm(nn.Linear(self.fc1_dim, self.ep_dim))
        self.layer_norm = LayerNormalization(self.ep_dim)

    def forward(self, description, noise):
        """
        Args:
            description: (batch, input_dim) relation semantic embeddings.
            noise: (batch, noise_dim) noise vectors.
        Returns:
            (batch, ep_dim) generated embeddings.
        """
        # Noise first: fc1's input columns are laid out as [noise | description].
        hidden = self.fc1(torch.cat((noise, description), dim=1))
        return self.layer_norm(self.fc2(hidden))


class Discriminator(nn.Module):
    """Critic over entity-pair embeddings.

    For each input row it returns the hidden feature, a real/fake logit, and
    similarity scores against every class centroid.
    """

    def __init__(self, args, dropout=0.3):
        # ``dropout`` is accepted for interface compatibility but is not used.
        super(Discriminator, self).__init__()
        width = args.ep_dim
        self.fc_middle = spectral_norm(nn.Linear(width, width))
        # Single-unit head: "True or False" score.
        self.fc_TF = spectral_norm(nn.Linear(width, 1))
        self.layer_norm = LayerNormalization(width)

    def _project(self, vectors):
        """Shared hidden transform: fc_middle -> leaky ReLU -> layer norm."""
        return self.layer_norm(F.leaky_relu(self.fc_middle(vectors)))

    def forward(self, ep_vec, centroid_matrix):
        """
        Args:
            ep_vec: (batch, ep_dim) entity-pair embeddings (real or generated).
            centroid_matrix: (num_classes, ep_dim) per-relation centroids.
        Returns:
            middle_vec: (batch, ep_dim) hidden features.
            logit_TF: (batch, 1) real/fake critic output.
            class_scores: (batch, num_classes) dot products with the
                projected centroids.
        """
        features = self._project(ep_vec)
        centroids = self._project(centroid_matrix)
        return features, self.fc_TF(features), features @ centroids.t()

**7. Model Training and Evaluation**

In [8]:
def calc_gradient_penalty(netD, real_data, fake_data, batchsize, centroid_matrix, gp_lambda=10.0):
    """WGAN-GP gradient penalty evaluated on random real/fake interpolates.

    Args:
        netD: critic, called as ``netD(x, centroid_matrix)``. It must return a
            3-tuple whose second element is the critic output for ``x``.
        real_data: (batchsize, dim) real embeddings (gradients not needed).
        fake_data: (batchsize, dim) generated embeddings, same shape and device.
        batchsize (int): number of rows in ``real_data``.
        centroid_matrix: forwarded unchanged to ``netD``.
        gp_lambda (float, optional): penalty weight (default 10.0).

    Returns:
        Scalar tensor ``gp_lambda * mean((||grad_x netD||_2 - 1) ** 2)``,
        differentiable w.r.t. netD's parameters.
    """
    device = real_data.device
    # One mixing coefficient per row, broadcast across the feature axis. Drawn
    # from the CPU generator and then moved, so that seeded runs stay
    # reproducible.
    alpha = torch.rand(batchsize, 1).to(device)
    alpha = alpha.expand(real_data.size())

    interpolates = alpha * real_data + (1 - alpha) * fake_data
    # Fresh leaf tensor so that we can differentiate the critic w.r.t. its input.
    interpolates = interpolates.detach().requires_grad_(True)

    _, disc_interpolates, _ = netD(interpolates, centroid_matrix)

    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                    grad_outputs=torch.ones_like(disc_interpolates),
                                    create_graph=True, retain_graph=True)[0]
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * gp_lambda
    return gradient_penalty

def reset_grad(nets):
    """Clear the accumulated gradients of every module in ``nets``.

    Args:
        nets: iterable of ``nn.Module`` objects, e.g. ``[Generator, Discriminator]``.
    """
    for module in nets:
        module.zero_grad()


class Trainer(object):
    """Owns the data, models and optimisers for the zero-shot KGC GAN.

    ``train`` does the following, in order:

    1. Optionally pre-trains the ``Extractor`` with a margin ranking loss.
    2. Reloads the saved Extractor.
    3. Computes one centroid embedding per training relation.
    4. Alternates ``D_epoch`` Discriminator updates with ``G_epoch`` Generator
       updates, periodically calling ``eval`` on the test relations.

    Every attribute of ``args`` is also copied onto ``self``, so e.g.
    ``self.noise_dim`` and ``self.pretrain_batch_size`` exist. All models and
    tensors are moved to CUDA unconditionally.
    """

    def __init__(self, args):
        super(Trainer, self).__init__()
        for k, v in vars(args).items():
            setattr(self, k, v)
        self.args = args

        self.train_tasks = self._load_json(os.path.join(args.data_path, 'datasplit', 'train_tasks.json'))
        self.rel2id = self._load_json(os.path.join(args.data_path, 'relation2ids'))

        # One semantic embedding row per relation id. Its width is the
        # Generator's conditioning input size.
        self.rela_matrix = load_semantic_embed(args.data_path, args.dataset, args.semantic_type)
        self.args.input_dim = self.rela_matrix.shape[1]

        self.ent2id = self._load_json(os.path.join(args.data_path, 'entity2id'))

        print('##LOADING CANDIDATES ENTITIES##')
        self.rel2candidates = self._load_json(os.path.join(args.data_path, 'rel2candidates_all.json'))

        # load answer dict
        self.e1rel_e2 = self._load_json(os.path.join(args.data_path, 'e1rel_e2_all.json'))

        # Drawn once here and reused by every eval() call.
        noises = Variable(torch.randn(args.test_sample, args.noise_dim)).cuda()
        self.test_noises = 0.1 * noises

        self.label_num = len(self.train_tasks.keys())

        # Training relation name -> class index, in sorted-name order.
        # Passed to centroid_generate / train_generate_decription below.
        self.rela2label = {rela: int(i) for i, rela in enumerate(sorted(self.train_tasks.keys()))}

        print('##LOADING SYMBOL ID AND SYMBOL EMBEDDING')
        if args.load_trained_embed:
            self.load_embed()
        else:
            self.read_embed()

        # symbol2id ends with a 'PAD' entry, so the real symbols are
        # 0..num_symbols-1 and num_symbols is the padding id.
        self.num_symbols = len(self.symbol2id.keys()) - 1
        print("num symbols:", self.num_symbols)
        self.pad_id = self.num_symbols

        print('##DEFINE FEATURE EXTRACTOR')
        self.Extractor = Extractor(args.embed_dim, self.num_symbols, embed=self.symbol2vec)
        self.Extractor.cuda()
        self.Extractor.apply(weights_init)
        self.E_parameters = filter(lambda p: p.requires_grad, self.Extractor.parameters())
        self.optim_E = optim.Adam(self.E_parameters, lr=args.lr_E)

        print('##DEFINE GENERATOR')
        self.Generator = Generator(self.args)
        self.Generator.cuda()
        self.Generator.apply(weights_init)
        self.G_parameters = filter(lambda p: p.requires_grad, self.Generator.parameters())
        self.optim_G = optim.Adam(self.G_parameters, lr=args.lr_G, betas=(0.5, 0.9))
        self.scheduler_G = optim.lr_scheduler.MultiStepLR(self.optim_G, milestones=[4000], gamma=0.2)

        print('##DEFINE DISCRIMINATOR')
        self.Discriminator = Discriminator(self.args)
        self.Discriminator.cuda()
        self.Discriminator.apply(weights_init)
        self.D_parameters = filter(lambda p: p.requires_grad, self.Discriminator.parameters())
        self.optim_D = optim.Adam(self.D_parameters, lr=args.lr_D, betas=(0.5, 0.9))
        self.scheduler_D = optim.lr_scheduler.MultiStepLR(self.optim_D, milestones=[20000], gamma=0.2)

        self.num_ents = len(self.ent2id.keys())

        print('##BUILDING CONNECTION MATRIX')
        self.build_connection(max_=args.max_neighbor)

    @staticmethod
    def _load_json(path):
        """Read a JSON file, closing the handle afterwards."""
        with open(path) as f:
            return json.load(f)

    def ensure_path(self, path):
        """Create directory ``path``; if it exists, interactively offer to wipe and recreate it."""
        print(path)
        if osp.exists(path):
            if input('{} exists, remove? ([y]/n)'.format(path)) != 'n':
                shutil.rmtree(path)
                os.mkdir(path)
        else:
            os.mkdir(path)

    def load_embed(self):
        """Build ``symbol2id`` / ``symbol2vec`` from pre-trained KG embeddings and cache them.

        Relations get ids first, then entities, then a final all-zero 'PAD'
        row. Keys '' and 'OOV' are skipped. The table is written to
        ``<data_path>/Embed_used/<embed_model>.npz`` and the id map to
        ``<embed_model>2id``, which is the layout ``read_embed`` reads back.

        Raises:
            ValueError: if ``args.embed_model`` is not 'DistMult' or 'TransE'.
        """
        symbol_id = {}

        print('##LOADING PRE-TRAINED EMBEDDING')
        if self.args.embed_model not in ['DistMult', 'TransE']:
            raise ValueError('Unsupported embed_model {!r}; expected DistMult or TransE'.format(
                self.args.embed_model))

        embed_all = np.load(os.path.join(self.args.data_path, self.args.embed_model + '_embed.npz'))
        ent_embed = embed_all['eM']
        rel_embed = embed_all['rM']
        print('    ent_embed shape is {}, the number of entity is {}'.format(ent_embed.shape,
                                                                             len(self.ent2id.keys())))
        print('    rel_embed shape is {}, the number of relation is {}'.format(rel_embed.shape,
                                                                               len(self.rel2id.keys())))

        i = 0
        embeddings = []
        for key in self.rel2id.keys():
            if key not in ['', 'OOV']:
                symbol_id[key] = i
                i += 1
                embeddings.append(list(rel_embed[self.rel2id[key], :]))

        for key in self.ent2id.keys():
            if key not in ['', 'OOV']:
                symbol_id[key] = i
                i += 1
                embeddings.append(list(ent_embed[self.ent2id[key], :]))

        symbol_id['PAD'] = i
        embeddings.append(list(np.zeros((rel_embed.shape[1],))))
        embeddings = np.array(embeddings)

        np.savez(os.path.join(self.args.data_path, 'Embed_used', self.args.embed_model), embeddings)
        with open(os.path.join(self.args.data_path, 'Embed_used', self.args.embed_model + '2id'), 'w') as f:
            json.dump(symbol_id, f)

        self.symbol2id = symbol_id
        self.symbol2vec = embeddings

    def read_embed(self):
        """Load the ``symbol2id`` / ``symbol2vec`` cache previously written by ``load_embed``."""
        symbol_id = self._load_json(
            os.path.join(self.args.data_path, 'Embed_used', self.args.embed_model + '2id'))
        with np.load(os.path.join(self.args.data_path, 'Embed_used', self.args.embed_model + '.npz')) as data:
            embeddings = data['arr_0']

        self.symbol2id = symbol_id
        self.symbol2vec = embeddings

    #  build neighbor connection
    def build_connection(self, max_=100):
        """Build the padded neighbour table used by ``get_meta``.

        Reads ``path_graph`` (one whitespace-separated ``e1 rel e2`` triple per
        line). For both endpoints it records a (relation symbol id,
        other-entity symbol id) pair. Each entity keeps at most ``max_``
        neighbours, namely the first ones read; unused slots hold
        ``self.pad_id``.

        Sets ``self.connections`` (num_ents, max_, 2), indexed by ent2id id,
        plus ``self.e1_rele2`` and ``self.e1_degrees`` (ent2id id -> number of
        neighbours kept).

        Returns:
            dict mapping entity name -> number of neighbours kept.
        """
        self.connections = (np.ones((self.num_ents, max_, 2)) * self.pad_id).astype(int)
        self.e1_rele2 = defaultdict(list)
        self.e1_degrees = defaultdict(int)
        with open(os.path.join(self.args.data_path, 'path_graph')) as f:
            lines = f.readlines()
        for line in tqdm(lines):
            e1, rel, e2 = line.rstrip().split()
            self.e1_rele2[e1].append((self.symbol2id[rel], self.symbol2id[e2]))
            self.e1_rele2[e2].append((self.symbol2id[rel], self.symbol2id[e1]))

        degrees = {}
        for ent, id_ in self.ent2id.items():
            neighbors = self.e1_rele2[ent][:max_]
            degrees[ent] = len(neighbors)
            self.e1_degrees[id_] = len(neighbors)
            for idx, (rel_sym, ent_sym) in enumerate(neighbors):
                self.connections[id_, idx, 0] = rel_sym
                self.connections[id_, idx, 1] = ent_sym

        return degrees

    def save_pretrain(self):
        """Save the Extractor state_dict to ``<data_path>/FE_models_trained/<embed_model>_Extractor``."""
        torch.save(self.Extractor.state_dict(), os.path.join(self.args.data_path, 'FE_models_trained', self.args.embed_model + '_Extractor'))

    def load_pretrain(self):
        """Load the saved Extractor weights onto GPU ``args.gpu`` and switch it to eval mode."""
        self.Extractor.load_state_dict(torch.load(os.path.join(self.args.data_path, 'FE_models_trained', self.args.embed_model + '_Extractor'), map_location=lambda storage, loc: storage.cuda(self.args.gpu)))
        self.Extractor.eval()

    def save_model(self):
        """Save the Generator and Discriminator state_dicts under ``args.save_path``."""
        path = self.args.save_path

        torch.save(self.Generator.state_dict(), os.path.join(path, self.args.embed_model + '_Generator'))
        torch.save(self.Discriminator.state_dict(), os.path.join(path, self.args.embed_model + '_Discriminator'))

    def load_model(self):
        """Restore the Generator and Discriminator saved by ``save_model``."""
        self.Generator.load_state_dict(torch.load(os.path.join(self.args.save_path, self.args.embed_model + '_Generator')))
        self.Discriminator.load_state_dict(torch.load(os.path.join(self.args.save_path, self.args.embed_model + '_Discriminator')))

    def get_meta(self, left, right):
        """Gather neighbour tables and degrees for the head (``left``) and tail (``right``) entities.

        Args:
            left, right: sequences of ent2id ids (rows of ``self.connections``).

        Returns:
            (left_connections, left_degrees, right_connections, right_degrees),
            all on CUDA. The connections are LongTensors (n, max_neighbor, 2);
            the degrees are FloatTensors (n,).
        """
        left_connections = Variable(
            torch.LongTensor(np.stack([self.connections[_, :, :] for _ in left], axis=0))).cuda()
        left_degrees = Variable(torch.FloatTensor([self.e1_degrees[_] for _ in left])).cuda()
        right_connections = Variable(
            torch.LongTensor(np.stack([self.connections[_, :, :] for _ in right], axis=0))).cuda()
        right_degrees = Variable(torch.FloatTensor([self.e1_degrees[_] for _ in right])).cuda()
        return (left_connections, left_degrees, right_connections, right_degrees)

    def pretrain_Extractor(self):
        """Pre-train the Extractor with a margin ranking loss, then save it.

        Each batch from ``Extractor_generate`` provides support, positive query
        and false (negative) pairs. The loss is
        ``relu(pretrain_margin - (score(query) - score(false))).mean()``.
        Runs exactly ``args.pretrain_times`` optimisation steps.
        """
        print('\n##PRETRAINING FEATURE EXTRACTOR ....')

        # Rolling window for the printed loss average.
        pretrain_losses = deque([], 100)

        batches = Extractor_generate(self.args.manual_seed, self.args.data_path, self.train_tasks,
                                     self.pretrain_batch_size, self.symbol2id, self.ent2id,
                                     self.e1rel_e2, self.pretrain_few, self.pretrain_subepoch)
        for i, data in enumerate(batches, start=1):
            support, query, false, support_left, support_right, query_left, query_right, false_left, false_right = data

            support_meta = self.get_meta(support_left, support_right)
            query_meta = self.get_meta(query_left, query_right)
            false_meta = self.get_meta(false_left, false_right)

            support = Variable(torch.LongTensor(support)).cuda()
            query = Variable(torch.LongTensor(query)).cuda()
            false = Variable(torch.LongTensor(false)).cuda()

            query_ep, query_scores = self.Extractor(query, support, query_meta, support_meta)
            false_ep, false_scores = self.Extractor(false, support, false_meta, support_meta)

            margin_ = query_scores - false_scores
            pretrain_loss = F.relu(self.args.pretrain_margin - margin_).mean()

            self.optim_E.zero_grad()
            pretrain_loss.backward()
            pretrain_losses.append(pretrain_loss.item())

            if i % self.args.pretrain_loss_every == 0:
                print("Step: %d, Feature Extractor Pretraining loss: %.10f" % (i, np.mean(pretrain_losses)))

            self.optim_E.step()

            if i >= self.args.pretrain_times:
                break

        self.save_pretrain()
        print('SAVE FEATURE EXTRACTOR PRETRAINING MODEL!!!')

    def train(self):
        """Adversarial training loop; the Extractor is only used, never updated, here."""
        print('\n##START ADVERSARIAL TRAINING...')

        # Pretraining step to obtain reasonable real data embeddings
        if self.args.pretrain_feature_extractor:
            self.pretrain_Extractor()
            print('Finish Pretraining!\n')

        self.load_pretrain()

        # Row label_id = mean Extractor embedding of that training relation's pairs.
        self.centroid_matrix = torch.zeros((len(self.train_tasks), self.args.ep_dim))
        self.centroid_matrix = self.centroid_matrix.cuda()

        for relname in self.train_tasks.keys():
            query, query_left, query_right, label_id = centroid_generate(self.args.data_path, relname, self.symbol2id,
                                                                         self.ent2id, self.train_tasks, self.rela2label)
            query_meta = self.get_meta(query_left, query_right)
            query = Variable(torch.LongTensor(query)).cuda()
            query_ep, _ = self.Extractor(query, query, query_meta, query_meta)
            self.centroid_matrix[label_id] = query_ep.data.mean(dim=0)
        self.centroid_matrix = Variable(self.centroid_matrix)

        D_every = self.args.D_epoch * self.args.loss_every
        D_losses = deque([], D_every)
        D_real_losses, D_real_class_losses, D_fake_losses, D_fake_class_losses \
            = deque([], D_every), deque([], D_every), deque([], D_every), deque([], D_every)

        # loss_G_fake + loss_G_class + loss_VP
        G_every = self.args.G_epoch * self.args.loss_every
        G_losses = deque([], G_every)
        G_fake_losses, G_class_losses, G_VP_losses, G_real_class_losses \
            = deque([], G_every), deque([], G_every), deque([], G_every), deque([], G_every)

        G_data = train_generate_decription(self.args.data_path, self.train_tasks, self.args.G_batch_size, self.symbol2id, self.ent2id,
                                           self.e1rel_e2, self.rel2id, self.args, self.rela2label, self.rela_matrix)

        nets = [self.Generator, self.Discriminator]
        reset_grad(nets)

        for epoch in range(1, (self.args.train_times + 1)):

            # ---- train Discriminator (Generator in eval mode) ----
            self.Discriminator.train()
            self.Generator.eval()
            for _ in range(self.args.D_epoch):
                D_descriptions, query, query_left, query_right, D_false, D_false_left, D_false_right, D_labels = next(G_data)

                # real part: Extractor embeddings of the true pairs
                query_meta = self.get_meta(query_left, query_right)
                query = Variable(torch.LongTensor(query)).cuda()
                D_real, _ = self.Extractor(query, query, query_meta, query_meta)

                # fake part: Generator output conditioned on the descriptions
                noises = Variable(torch.randn(len(query), self.noise_dim)).cuda()
                D_descriptions = Variable(torch.FloatTensor(D_descriptions)).cuda()
                D_fake = self.Generator(D_descriptions, noises)

                # neg part: Extractor embeddings of the false pairs
                D_false_meta = self.get_meta(D_false_left, D_false_right)
                D_false = Variable(torch.LongTensor(D_false)).cuda()
                D_neg, _ = self.Extractor(D_false, D_false, D_false_meta, D_false_meta)

                centroid_matrix_ = self.centroid_matrix
                _, D_real_decision, D_real_class = self.Discriminator(D_real.detach(), centroid_matrix_)
                _, D_fake_decision, D_fake_class = self.Discriminator(D_fake.detach(), centroid_matrix_)
                _, _, D_neg_class = self.Discriminator(D_neg.detach(), centroid_matrix_)

                # real adversarial training loss
                loss_D_real = -torch.mean(D_real_decision)

                # fake adversarial training loss
                loss_D_fake = torch.mean(D_fake_decision)

                # real classification loss (true pair should beat false pair on its label)
                D_real_scores = D_real_class[range(len(query)), D_labels]
                D_neg_scores = D_neg_class[range(len(query)), D_labels]
                D_margin_real = D_real_scores - D_neg_scores
                loss_rela_class = F.relu(self.args.pretrain_margin - D_margin_real).mean()

                # fake classification loss
                D_fake_scores = D_fake_class[range(len(query)), D_labels]
                D_margin_fake = D_fake_scores - D_neg_scores
                loss_fake_class = F.relu(self.args.pretrain_margin - D_margin_fake).mean()

                grad_penalty = calc_gradient_penalty(self.Discriminator, D_real.data, D_fake.data, len(query),
                                                     self.centroid_matrix)

                loss_D = loss_D_real + 0.5 * loss_rela_class + loss_D_fake + grad_penalty + 0.5 * loss_fake_class

                D_losses.append(loss_D.item())
                D_real_losses.append(loss_D_real.item())
                D_real_class_losses.append(loss_rela_class.item())
                D_fake_losses.append(loss_D_fake.item())
                D_fake_class_losses.append(loss_fake_class.item())

                loss_D.backward()
                # optimizer.step() must come before scheduler.step() (PyTorch >= 1.1).
                self.optim_D.step()
                self.scheduler_D.step()
                reset_grad(nets)

            # ---- train Generator (Discriminator in eval mode) ----
            self.Discriminator.eval()
            self.Generator.train()
            for _ in range(self.args.G_epoch):

                G_descriptions, query, query_left, query_right, G_false, G_false_left, G_false_right, G_labels = next(G_data)

                # G sample
                noises = Variable(torch.randn(len(query), self.args.noise_dim)).cuda()
                G_descriptions = Variable(torch.FloatTensor(G_descriptions)).cuda()
                G_sample = self.Generator(G_descriptions, noises)

                # real data
                query_meta = self.get_meta(query_left, query_right)
                query = Variable(torch.LongTensor(query)).cuda()
                G_real, _ = self.Extractor(query, query, query_meta, query_meta)

                # negatives for the classification loss
                G_false_meta = self.get_meta(G_false_left, G_false_right)
                G_false = Variable(torch.LongTensor(G_false)).cuda()
                G_neg, _ = self.Extractor(G_false, G_false, G_false_meta, G_false_meta)

                centroid_matrix_ = self.centroid_matrix
                _, G_decision, G_class = self.Discriminator(G_sample, centroid_matrix_)
                _, _, G_real_class = self.Discriminator(G_real.detach(), centroid_matrix_)
                _, _, G_neg_class = self.Discriminator(G_neg.detach(), centroid_matrix_)

                # adversarial training loss
                loss_G_fake = - torch.mean(G_decision)

                # G sample (fake) classification loss
                G_scores = G_class[range(len(query)), G_labels]
                G_neg_scores = G_neg_class[range(len(query)), G_labels]
                G_margin_ = G_scores - G_neg_scores
                loss_G_class = F.relu(self.args.pretrain_margin - G_margin_).mean()

                # real classification loss (logged only)
                G_real_scores = G_real_class[range(len(query)), G_labels]
                G_margin_real = G_real_scores - G_neg_scores
                loss_rela_class_ = F.relu(self.args.pretrain_margin - G_margin_real).mean()

                # Visual Pivot Regularization: pull each label's mean fake
                # sample towards that label's centroid.
                count = 0
                loss_VP = Variable(torch.Tensor([0.0])).cuda()
                labels_arr = np.array(G_labels)
                for label in range(len(self.train_tasks.keys())):
                    sample_idx = (labels_arr == label).nonzero()[0]
                    count += len(sample_idx)
                    if len(sample_idx) > 0:
                        G_sample_cls = G_sample[sample_idx, :]
                        loss_VP += (G_sample_cls.mean(dim=0) - self.centroid_matrix[label]).pow(2).sum().sqrt()
                assert count == len(query)
                loss_VP *= float(1.0 / self.args.gan_batch_rela)

                # Generator loss function. args.REG_W / args.REG_Wz weight
                # regularisers are not applied.
                loss_G = loss_G_fake + loss_G_class + 3.0 * loss_VP

                G_losses.append(loss_G.item())
                G_fake_losses.append(loss_G_fake.item())
                G_class_losses.append(loss_G_class.item())
                G_real_class_losses.append(loss_rela_class_.item())
                G_VP_losses.append(loss_VP.item())

                loss_G.backward()
                # optimizer.step() must come before scheduler.step() (PyTorch >= 1.1).
                self.optim_G.step()
                self.scheduler_G.step()
                reset_grad(nets)

            if epoch % self.args.loss_every == 0:
                D_screen = [np.mean(D_real_losses), np.mean(D_real_class_losses), np.mean(D_fake_losses),
                            np.mean(D_fake_class_losses)]
                G_screen = [np.mean(G_fake_losses), np.mean(G_class_losses), np.mean(G_real_class_losses),
                            np.mean(G_VP_losses)]
                print("Epoch: %d, D_loss: %.2f [%.2f, %.2f, %.2f, %.2f], G_loss: %.2f [%.2f, %.2f, %.2f, %.2f]" \
                      % (
                      epoch, np.mean(D_losses), D_screen[0], D_screen[1], D_screen[2], D_screen[3], np.mean(G_losses),
                      G_screen[0], G_screen[1], G_screen[2], G_screen[3]))

            if epoch >= 1000 and epoch % self.args.eval_every == 0:
                self.eval(mode='test', epoch=epoch)

    def eval(self, mode='dev', epoch=0):
        """Rank the candidate tails of every relation in ``test_candidates.json``.

        For each relation, ``args.test_sample`` fake embeddings are generated
        from its semantic description using the fixed ``self.test_noises``.
        Each (head, candidate tail) pair is encoded by the Extractor and scored
        by its mean cosine similarity to those fakes; ``tail_candidates[0]`` is
        the true tail. Prints HITS@10/5/1 and the mean reciprocal rank.

        Args:
            mode (str): label used in the printed header.
            epoch (int): epoch number used in the printed header.
        """
        self.Generator.eval()
        self.Discriminator.eval()
        symbol2id = self.symbol2id

        print('##EVALUATING ON %s DATA' % mode.upper())
        test_candidates = self._load_json(os.path.join(self.args.data_path, "test_candidates.json"))

        hits10 = []
        hits5 = []
        hits1 = []
        mrr = []

        # Nothing is trained here, so skip autograd bookkeeping.
        with torch.no_grad():
            for query_ in sorted(test_candidates.keys()):

                description = self.rela_matrix[self.rel2id[query_]]
                description = np.expand_dims(description, axis=0)
                descriptions = np.tile(description, (self.args.test_sample, 1))
                descriptions = Variable(torch.FloatTensor(descriptions)).cuda()
                relation_vecs = self.Generator(descriptions, self.test_noises)
                relation_vecs = relation_vecs.data.cpu().numpy()

                for e1_rel, tail_candidates in test_candidates[query_].items():
                    # Keys are tab-separated with the head entity first
                    # (3 fields for NELL, 2 for Wiki).
                    head = e1_rel.split('\t')[0]

                    true = tail_candidates[0]
                    if head not in symbol2id or true not in symbol2id:
                        continue
                    query_pairs = [[symbol2id[head], symbol2id[true]]]
                    query_left = [self.ent2id[head]]
                    query_right = [self.ent2id[true]]

                    for tail in tail_candidates[1:]:
                        if tail not in symbol2id:
                            continue
                        query_pairs.append([symbol2id[head], symbol2id[tail]])
                        query_left.append(self.ent2id[head])
                        query_right.append(self.ent2id[tail])

                    query = Variable(torch.LongTensor(query_pairs)).cuda()

                    query_meta = self.get_meta(query_left, query_right)
                    candidate_vecs, _ = self.Extractor(query, query, query_meta, query_meta)
                    candidate_vecs = candidate_vecs.data.cpu().numpy()

                    # Mean cosine similarity of each candidate to the generated samples.
                    scores = cosine_similarity(candidate_vecs, relation_vecs).mean(axis=1)

                    assert scores.shape == (len(query_pairs),)

                    # Row 0 is the true tail; rank = its 1-based position by descending score.
                    sort = list(np.argsort(scores))[::-1]
                    rank = sort.index(0) + 1
                    hits10.append(1.0 if rank <= 10 else 0.0)
                    hits5.append(1.0 if rank <= 5 else 0.0)
                    hits1.append(1.0 if rank <= 1 else 0.0)
                    mrr.append(1.0 / rank)

        print('\n############   ' + mode + ' ' + str(epoch) + '    #############')
        # NOTE: the value printed as "MAP" is the mean reciprocal rank.
        print('HITS10: {:.3f}, HITS5: {:.3f}, HITS1: {:.3f}, MAP: {:.3f}'.format(np.mean(hits10),
                                                                                 np.mean(hits5),
                                                                                 np.mean(hits1),
                                                                                 np.mean(mrr)))
        print('###################################')

**8. Parameter Settings and Run Model**

In [None]:
def build_arg_parser():
    """Create the argument parser holding every OntoZSL hyper-parameter.

    The defaults are the values used for the NELL + RDFS run in this notebook
    (``--dataset NELL``, ``--semantic_type rdfs``); override them with flags or by
    editing the defaults here.

    Returns:
        argparse.ArgumentParser: parser with all options registered.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", default="/content/drive/MyDrive/ISWC_demo/ZS_KGC/data/", type=str)

    parser.add_argument("--dataset", default="NELL", type=str)
    parser.add_argument("--embed_model", default='TransE', type=str)

    # embedding dimension
    parser.add_argument("--embed_dim", default=100, type=int, help='dimension of triple embedding')
    parser.add_argument("--ep_dim", default=200, type=int, help='dimension of entity pair embedding')
    parser.add_argument("--fc1_dim", default=250, type=int, help='dimension of hidden units in generator')
    parser.add_argument("--noise_dim", default=15, type=int)

    # feature extractor pretraining related
    parser.add_argument("--pretrain_batch_size", default=64, type=int)
    parser.add_argument("--pretrain_few", default=30, type=int)
    parser.add_argument("--pretrain_subepoch", default=20, type=int)
    parser.add_argument("--pretrain_margin", default=10.0, type=float, help='pretraining margin loss')
    parser.add_argument("--pretrain_times", default=16000, type=int, help='total training steps for pretraining')
    parser.add_argument("--pretrain_loss_every", default=500, type=int)

    # adversarial training related
    # batch size
    parser.add_argument("--D_batch_size", default=256, type=int)
    parser.add_argument("--G_batch_size", default=256, type=int)
    parser.add_argument("--gan_batch_rela", default=2, type=int)
    # learning rate
    parser.add_argument("--lr_G", default=0.0001, type=float)
    parser.add_argument("--lr_D", default=0.0001, type=float)
    parser.add_argument("--lr_E", default=0.0005, type=float)
    # training times
    parser.add_argument("--train_times", default=8000, type=int)
    parser.add_argument("--D_epoch", default=5, type=int)
    parser.add_argument("--G_epoch", default=1, type=int)
    # log
    parser.add_argument("--loss_every", default=50, type=int)
    parser.add_argument("--eval_every", default=200, type=int)
    # hyper-parameter
    parser.add_argument("--test_sample", default=20, type=int, help='number of synthesized samples')
    parser.add_argument("--dropout", default=0.5, type=float)
    parser.add_argument('--REG_W', default=0.001, type=float)
    parser.add_argument('--REG_Wz', default=0.0001, type=float)
    parser.add_argument("--max_neighbor", default=50, type=int, help='neighbor number of each entity')
    parser.add_argument("--grad_clip", default=5.0, type=float)
    parser.add_argument("--weight_decay", default=0.0, type=float)

    parser.add_argument("--fine_tune", action='store_true')
    parser.add_argument("--aggregate", default='max', type=str)
    parser.add_argument("--semantic_type", default='rdfs', type=str, help='the type of relation embedding to input, options: {text, rdfs, rdfs_hie, rdfs_cons, rdfs_text}')
    # switch
    parser.add_argument("--pretrain_feature_extractor", action='store_true')
    parser.add_argument("--load_trained_embed", action='store_true', help='load well trained kg embeddings, such as TransE')

    parser.add_argument("--manual_seed", type=int, default=6096)
    parser.add_argument('--gpu', type=int, default=0, help='device to use for iterate data, -1 means cpu [default: 0]')
    return parser


def resolve_derived_args(args):
    """Fill in the arguments derived from the parsed ones; mutates and returns ``args``.

    - ``data_path``: ``<data_dir>/<dataset>``
    - ``save_path``: ``<data_path>/expri_data/models_train``
    - ``manual_seed``: replaced by ``random.randint(1, 10000)`` when it is ``None``.
      With ``build_arg_parser`` this only happens if the ``--manual_seed`` default is
      changed to ``None``, since its current default is 6096.
    """
    args.data_path = os.path.join(args.data_dir, args.dataset)
    args.save_path = os.path.join(args.data_path, 'expri_data', 'models_train')
    if args.manual_seed is None:
        args.manual_seed = random.randint(1, 10000)
    return args


def seed_everything(seed, gpu=0):
    """Seed Python's ``random``, NumPy and PyTorch, and select the CUDA device.

    Args:
        seed (int): seed shared by all random number generators.
        gpu (int): CUDA device index. A negative value means "use the CPU", as the
            ``--gpu`` help text documents, so no CUDA device is selected for it.
    """
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        print("GPU is not available!")
        return
    if gpu >= 0:
        torch.cuda.set_device(gpu)
        print('using gpu {}'.format(gpu))
    else:
        # -1 selects the CPU (see --gpu help), so do not report a GPU index here.
        print('using cpu (--gpu {})'.format(gpu))
    # NOTE(review): the Trainer may still put tensors on the default CUDA device
    # regardless of --gpu; CUDA seeding is kept so such runs stay reproducible.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # deterministic=True alone does not stop cuDNN from benchmarking/auto-selecting
    # algorithms; disable it explicitly for reproducible runs.
    torch.backends.cudnn.benchmark = False


if __name__ == '__main__':

    parser = build_arg_parser()
    # parse_known_args tolerates unrecognized arguments in sys.argv (e.g. ones a
    # notebook kernel adds), unlike parse_args which would exit on them.
    args = parser.parse_known_args()[0]
    args = resolve_derived_args(args)

    print("------HYPERPARAMETERS-------")
    for k, v in vars(args).items():
        print(k + ': ' + str(v))
    print("----------------------------")

    seed_everything(args.manual_seed, args.gpu)

    trainer = Trainer(args)
    trainer.train()
    # trainer.test_()

------HYPERPARAMETERS-------
data_dir: /content/drive/MyDrive/ISWC_demo/ZS_KGC/data/
dataset: NELL
embed_model: TransE
embed_dim: 100
ep_dim: 200
fc1_dim: 250
noise_dim: 15
pretrain_batch_size: 64
pretrain_few: 30
pretrain_subepoch: 20
pretrain_margin: 10.0
pretrain_times: 16000
pretrain_loss_every: 500
D_batch_size: 256
G_batch_size: 256
gan_batch_rela: 2
lr_G: 0.0001
lr_D: 0.0001
lr_E: 0.0005
train_times: 8000
D_epoch: 5
G_epoch: 1
loss_every: 50
eval_every: 200
test_sample: 20
dropout: 0.5
REG_W: 0.001
REG_Wz: 0.0001
max_neighbor: 50
grad_clip: 5.0
weight_decay: 0.0
fine_tune: False
aggregate: max
semantic_type: rdfs
pretrain_feature_extractor: False
load_trained_embed: False
manual_seed: 6096
gpu: 0
data_path: /content/drive/MyDrive/ISWC_demo/ZS_KGC/data/NELL
save_path: /content/drive/MyDrive/ISWC_demo/ZS_KGC/data/NELL/expri_data/models_train
----------------------------
using gpu 0
##LOADING CANDIDATES ENTITIES##
##LOADING SYMBOL ID AND SYMBOL EMBEDDING
num symbols: 65748
##DEFINE

100%|██████████| 181053/181053 [00:00<00:00, 322462.45it/s]



##START ADVERSARIAL TRAINING...
##LOADING TRAINING DATA
##LOADING CANDIDATES




Epoch: 50, D_loss: 1.14 [-4.77, 0.06, 0.77, 8.95], G_loss: 42.98 [-0.71, 8.10, 0.05, 11.86]
Epoch: 100, D_loss: -4.12 [-7.90, 0.06, 2.49, 1.74], G_loss: 27.53 [-2.42, 1.60, 0.04, 9.45]
Epoch: 150, D_loss: -2.46 [-5.99, 0.06, 3.08, 0.44], G_loss: 19.94 [-3.04, 0.42, 0.08, 7.52]
Epoch: 200, D_loss: -1.73 [-2.36, 0.05, 0.40, 0.20], G_loss: 21.48 [-0.31, 0.25, 0.04, 7.18]
Epoch: 250, D_loss: -2.50 [1.80, 0.05, -4.61, 0.22], G_loss: 26.23 [4.68, 0.24, 0.05, 7.10]
Epoch: 300, D_loss: -3.28 [2.71, 0.05, -6.33, 0.15], G_loss: 27.29 [6.34, 0.17, 0.06, 6.93]
Epoch: 350, D_loss: -3.95 [2.92, 0.05, -7.26, 0.14], G_loss: 26.67 [7.27, 0.16, 0.04, 6.41]
Epoch: 400, D_loss: -4.16 [3.34, 0.05, -7.91, 0.12], G_loss: 28.39 [7.92, 0.17, 0.04, 6.77]
Epoch: 450, D_loss: -4.22 [3.72, 0.04, -8.35, 0.11], G_loss: 26.71 [8.41, 0.08, 0.03, 6.07]
Epoch: 500, D_loss: -4.31 [4.06, 0.04, -8.80, 0.11], G_loss: 27.04 [8.82, 0.09, 0.03, 6.04]
Epoch: 550, D_loss: -4.17 [4.61, 0.05, -9.23, 0.10], G_loss: 28.37 [9.28, 0.1

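**Parameters in Other Settings**

Compared with the NELL + RDFS run above, NELL with "RDFS+literal" only switches `semantic_type` to `rdfs_text`. The two Wiki settings additionally use smaller dimensions (`embed_dim` 50, `ep_dim` 100, `fc1_dim` 200) and smaller batches (`D_batch_size`/`G_batch_size` 64, `gan_batch_rela` 8). To reproduce a setting, change the corresponding defaults in the parameter cell above (or pass them as command-line flags when running as a script).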


---

*   **run OntoZSL on NELL with "RDFS+literal"**

------HYPERPARAMETERS-------
data_dir: /content/drive/MyDrive/ISWC_demo/ZS_KGC/data;
dataset: NELL;
embed_model: TransE;
embed_dim: 100;
ep_dim: 200;
fc1_dim: 250;
noise_dim: 15;
pretrain_batch_size: 64;
pretrain_few: 30;
pretrain_subepoch: 20;
pretrain_margin: 10.0;
pretrain_times: 16000;
pretrain_loss_every: 500;
D_batch_size: 256;
G_batch_size: 256;
gan_batch_rela: 2;
lr_G: 0.0001;
lr_D: 0.0001;
lr_E: 0.0005;
train_times: 8000;
D_epoch: 5;
G_epoch: 1;
loss_every: 50;
eval_every: 200;
test_sample: 20;
dropout: 0.5;
REG_W: 0.001;
REG_Wz: 0.0001;
max_neighbor: 50;
grad_clip: 5.0;
weight_decay: 0.0;
fine_tune: False;
aggregate: max;
semantic_type: rdfs_text;
pretrain_feature_extractor: False;
load_trained_embed: False;
manual_seed: 6096;


*   **run OntoZSL on Wiki with "RDFS"**

------HYPERPARAMETERS-------
data_dir: /content/drive/MyDrive/ISWC_demo/ZS_KGC/data;
dataset: Wiki;
embed_model: TransE;
embed_dim: 50;
ep_dim: 100;
fc1_dim: 200;
noise_dim: 15;
pretrain_batch_size: 64;
pretrain_few: 30;
pretrain_subepoch: 20;
pretrain_margin: 10.0;
pretrain_times: 16000;
pretrain_loss_every: 500;
D_batch_size: 64;
G_batch_size: 64;
gan_batch_rela: 8;
lr_G: 0.0001;
lr_D: 0.0001;
lr_E: 0.0005;
train_times: 8000;
D_epoch: 5;
G_epoch: 1;
loss_every: 50;
eval_every: 200;
test_sample: 20;
dropout: 0.5;
REG_W: 0.001;
REG_Wz: 0.0001;
max_neighbor: 50;
grad_clip: 5.0;
weight_decay: 0.0;
fine_tune: False;
aggregate: max;
semantic_type: rdfs;
pretrain_feature_extractor: False;
load_trained_embed: False;
manual_seed: 6096;
gpu: 1;


*   **run OntoZSL on Wiki with "RDFS+literal"**

------HYPERPARAMETERS-------
data_dir: /content/drive/MyDrive/ISWC_demo/ZS_KGC/data;
dataset: Wiki;
embed_model: TransE;
embed_dim: 50;
ep_dim: 100;
fc1_dim: 200;
noise_dim: 15;
pretrain_batch_size: 64;
pretrain_few: 30;
pretrain_subepoch: 20;
pretrain_margin: 10.0;
pretrain_times: 16000;
pretrain_loss_every: 500;
D_batch_size: 64;
G_batch_size: 64;
gan_batch_rela: 8;
lr_G: 0.0001;
lr_D: 0.0001;
lr_E: 0.0005;
train_times: 8000;
D_epoch: 5;
G_epoch: 1;
loss_every: 50;
eval_every: 200;
test_sample: 20;
dropout: 0.5;
REG_W: 0.001;
REG_Wz: 0.0001;
max_neighbor: 50;
grad_clip: 5.0;
weight_decay: 0.0;
fine_tune: False;
aggregate: max;
semantic_type: rdfs_text;
pretrain_feature_extractor: False;
load_trained_embed: False;
manual_seed: 6096;
