In [1]:
import sys
sys.path.insert(0, '..')

from torchkge.models import TransEModel
from torchkge.sampling import BernoulliNegativeSampler
from torchkge.utils.datasets import load_fb15k
from torchkge.utils import MarginLoss, DataLoader

In [2]:
kg_train, _, kg_test = load_fb15k('../data')
dataloader = DataLoader(kg_train, batch_size=5)
data = next(iter(dataloader))
data

(tensor([ 3920,   839, 10094,  2587, 11748]),
 tensor([9220, 9523,  775, 7238, 4833]),
 tensor([ 791, 1273,  846,   95, 1100]))

## 1. Negative Sampling

    False Triplet 을 만들어 주기 위해서 Negative Sampling 을 한다.
    Head / Tail 중 하나를 임의의 Entity로 바꿔줌으로써, False Triplet 을 만든다.
        Head / Tail 중 하나를 랜덤하게 고르고 (P(head) = 0.5), 이를 다른 Entity로 바꿔준다. => Uniform Netagive Sampling
    
    하지만 False Triple 이 실제로 존재하는 Triple일 경우가 발생한다. (False Negative)
    False Negative Sampling 을 줄이기 위해서, Head / Tail 중 하나를 고를때 Relation의 특징을 이용한다.
    
        1. 1-to-N relation 을 가지는 triple의 경우, head 를 바꿔준다.
        2. N-to-1 relation 을 가지는 triple의 경우, tail 을 바꿔준다.
        
        모든 triple을 검토해 relation 에 해당되는 확률을 계산하고, 이를 sampling 에 반영한다.
    

In [3]:
from torchkge.utils.operations import get_bernoulli_probs
from torch import tensor, bernoulli, cat, randint
from pandas import DataFrame

In [4]:
t = cat((kg_train.head_idx.view(-1,1), 
         kg_train.tail_idx.view(-1,1),
         kg_train.relations.view(-1,1)), dim=1)
t

tensor([[ 3920,  9220,   791],
        [  839,  9523,  1273],
        [10094,   775,   846],
        ...,
        [ 8170, 13876,   100],
        [13723,  8121,  1030],
        [ 6647,  6817,   851]])

### 1.1 Compute bernoulli distribution for each relation

In [5]:
# get average number of head per tail for each relation

df = DataFrame(t.numpy(), columns=['from','to','rel'])
df = df.groupby(['rel','to']).count().groupby('rel').mean()
df.reset_index(inplace=True)
hpt = {i:v for i,v in df.values}
hpt

{0.0: 1.0,
 1.0: 1.0,
 2.0: 1.0,
 3.0: 1.0,
 4.0: 1.0,
 5.0: 1.0,
 6.0: 1.0,
 7.0: 1.5,
 8.0: 1.48,
 9.0: 1.8888888888888888,
 10.0: 1.4482758620689655,
 11.0: 2.0,
 12.0: 1.0,
 13.0: 2.5,
 14.0: 1.0,
 15.0: 2.0,
 16.0: 1.0,
 17.0: 1.0,
 18.0: 1.5,
 19.0: 1.0,
 20.0: 1.4285714285714286,
 21.0: 1.0,
 22.0: 40.6,
 23.0: 1.0,
 24.0: 35.8125,
 25.0: 1.0,
 26.0: 28.0,
 27.0: 2.6,
 28.0: 32.05,
 29.0: 2.642857142857143,
 30.0: 31.35,
 31.0: 1.5,
 32.0: 1.0,
 33.0: 1.0,
 34.0: 1.4,
 35.0: 1.625,
 36.0: 1.0,
 37.0: 1.0,
 38.0: 1.1428571428571428,
 39.0: 1.25,
 40.0: 3.0,
 41.0: 1.25,
 42.0: 1.5,
 43.0: 1.0,
 44.0: 1.5555555555555556,
 45.0: 1.0,
 46.0: 1.0,
 47.0: 1.0,
 48.0: 1.0,
 49.0: 1.0,
 50.0: 1.0,
 51.0: 4.0,
 52.0: 1.0,
 53.0: 4.0,
 54.0: 1.0,
 55.0: 2.0,
 56.0: 1.0,
 57.0: 3.0,
 58.0: 1.0,
 59.0: 2.6666666666666665,
 60.0: 1.0416666666666667,
 61.0: 1.6666666666666667,
 62.0: 3.5,
 63.0: 1.6666666666666667,
 64.0: 1.0,
 65.0: 1.0,
 66.0: 2.2857142857142856,
 67.0: 1.0,
 68.0: 2.0,
 69

In [6]:
# get average number of tail per head for each relation

df = DataFrame(t.numpy(), columns=['from','to','rel'])
df = df.groupby(['from','rel']).count().groupby('rel').mean()
df.reset_index(inplace=True)
tph = {i:v for i,v in df.values}
tph

{0.0: 1.0,
 1.0: 1.0,
 2.0: 1.0,
 3.0: 38.0,
 4.0: 1.0,
 5.0: 1.0,
 6.0: 1.0,
 7.0: 1.25,
 8.0: 2.3125,
 9.0: 1.3076923076923077,
 10.0: 2.625,
 11.0: 1.3333333333333333,
 12.0: 1.6666666666666667,
 13.0: 1.6666666666666667,
 14.0: 1.0,
 15.0: 1.4285714285714286,
 16.0: 1.0,
 17.0: 2.0,
 18.0: 1.0,
 19.0: 1.6666666666666667,
 20.0: 2.3076923076923075,
 21.0: 1.0,
 22.0: 9.36923076923077,
 23.0: 1.0,
 24.0: 8.815384615384616,
 25.0: 1.0,
 26.0: 1.0,
 27.0: 1.3928571428571428,
 28.0: 7.202247191011236,
 29.0: 1.48,
 30.0: 7.125,
 31.0: 1.5,
 32.0: 1.0,
 33.0: 1.0,
 34.0: 1.4,
 35.0: 1.4444444444444444,
 36.0: 1.0,
 37.0: 1.0,
 38.0: 1.0,
 39.0: 1.0,
 40.0: 1.0,
 41.0: 1.0,
 42.0: 1.0,
 43.0: 1.0,
 44.0: 1.4,
 45.0: 1.0,
 46.0: 1.0,
 47.0: 7.0,
 48.0: 1.0,
 49.0: 1.3333333333333333,
 50.0: 1.5,
 51.0: 1.0,
 52.0: 4.0,
 53.0: 1.0,
 54.0: 1.0,
 55.0: 1.0,
 56.0: 2.0,
 57.0: 1.2,
 58.0: 1.0,
 59.0: 2.2857142857142856,
 60.0: 5.0,
 61.0: 1.0,
 62.0: 1.75,
 63.0: 1.6666666666666667,
 64.0: 2.0

    relation   3: hpt = 1    , tph = 38   (1-to-N)
    relation 110: hpt = 47.86, tph = 1.23 (N-to-1)

In [7]:
bern_prob = [0.5]*kg_train.n_rel
for r in tph.keys():
    bern_prob[int(r)] = tph[r] / (tph[r] + hpt[r])

bern_prob = tensor(bern_prob).float()    
bern_prob

tensor([0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.1600, 0.4783])

    bern_prob[r] ~= 1 : tph[r] >> hpt[r] => 1-to-N
    bern_prob[r] ~= 0 : tph[r] << hpt[r] => N-to-1

### 1.2 Create corrupted triplets

In [8]:
n_neg = 1 # the number of invalid triple per triple 

heads, tails, relations = data[0], data[1], data[2] # Batch triplets
batch_size = heads.shape[0]
neg_heads = heads.repeat(n_neg)
neg_tails = tails.repeat(n_neg)

In [9]:
mask = bernoulli(bern_prob[relations].repeat(n_neg)).double()
mask

tensor([0., 1., 1., 0., 0.], dtype=torch.float64)

    [0,1,1,1, 0] : 0, 4 번째는 tail을 바꾸고, 1,2,3 번째는 head를 바꿔라

In [10]:
n_h_cor = int(mask.sum().item()) # the number of corrupted head

neg_heads[mask==1] = randint(1, kg_train.n_ent, (n_h_cor,))
neg_tails[mask==0] = randint(1, kg_train.n_ent, (batch_size*n_neg-n_h_cor,))

In [11]:
print(heads)
print(neg_heads)

tensor([ 3920,   839, 10094,  2587, 11748])
tensor([ 3920,  6023,  3639,  2587, 11748])


In [12]:
print(tails)
print(neg_tails)

tensor([9220, 9523,  775, 7238, 4833])
tensor([1010, 9523,  775, 4898, 4395])


## 2. Loss Function
    Knowledge Graph Embedding model is Energy-based model.
    Triplet is converted into a scalar energy.
    Based on the model, we use Margin Ranking Loss.

### 2.1 Get scalar energy using both triplets

In [13]:
print('Valid triplets')
print(f'heads    : {heads}')
print(f'relations: {relations}')
print(f'tails    : {tails}')
print('\n')
print('corrupted triplets')
print(f'heads    : {neg_heads}')
print(f'relations: {relations}')
print(f'tails    : {neg_tails}')

Valid triplets
heads    : tensor([ 3920,   839, 10094,  2587, 11748])
relations: tensor([ 791, 1273,  846,   95, 1100])
tails    : tensor([9220, 9523,  775, 7238, 4833])


corrupted triplets
heads    : tensor([ 3920,  6023,  3639,  2587, 11748])
relations: tensor([ 791, 1273,  846,   95, 1100])
tails    : tensor([1010, 9523,  775, 4898, 4395])


In [14]:
# Create Translation model

from torchkge.models import TransEModel

ent_emb_dim = 4

model = TransEModel(ent_emb_dim, kg_train.n_ent, kg_train.n_rel, dissimilarity_type='L2')

entity_embeddings = model.ent_emb
relation_embeddings = model.rel_emb

In [15]:
print(f'Entity Embedding Matrix  : {entity_embeddings}')
print(f'Relation Embedding Matrix: {relation_embeddings}')

Entity Embedding Matrix  : Embedding(14951, 4)
Relation Embedding Matrix: Embedding(1345, 4)


In [16]:
# positive triplet score

from torch.nn.functional import normalize

h = normalize(entity_embeddings(heads), p=2, dim=1) # layer normalization
t = normalize(entity_embeddings(tails), p=2, dim=1)
r = normalize(relation_embeddings(relations))

In [17]:
print(f'head embedding:    \n {h} \n')
print(f'tail embedding: \n {t} \n')
print(f'relation embedding: \n {r} \n')

head embedding:    
 tensor([[ 0.3858, -0.3258,  0.7736, -0.3827],
        [ 0.5567,  0.4163, -0.5309,  0.4847],
        [-0.5825, -0.1948, -0.4425,  0.6533],
        [-0.2206, -0.6302, -0.5828,  0.4632],
        [-0.8772, -0.4091, -0.2392,  0.0774]], grad_fn=<DivBackward0>) 

tail embedding: 
 tensor([[-0.0930, -0.5260,  0.7916, -0.2969],
        [ 0.1684, -0.6696, -0.6655, -0.2835],
        [ 0.5902,  0.7118,  0.0859, -0.3710],
        [-0.7117,  0.2081,  0.6239, -0.2468],
        [-0.7226, -0.1678, -0.0711, -0.6668]], grad_fn=<DivBackward0>) 

relation embedding: 
 tensor([[-0.7382,  0.3838,  0.3793, -0.4049],
        [-0.4717,  0.1630,  0.6574, -0.5645],
        [-0.2535,  0.9480, -0.0694,  0.1794],
        [ 0.6144,  0.4255,  0.3521, -0.5635],
        [-0.3249, -0.7234, -0.1065,  0.5999]], grad_fn=<DivBackward0>) 



In [18]:
pos_score = ((h+r)-t).norm(p=2, dim=-1)**2
pos_score

tensor([0.7797, 2.2353, 3.8423, 2.1443, 3.0423], grad_fn=<PowBackward0>)

In [19]:
nh = normalize(entity_embeddings(neg_heads), p=2, dim=1) # layer normalization
nt = normalize(entity_embeddings(neg_tails), p=2, dim=1)


print(f'negative head embedding:    \n {nh} \n')
print(f'negative tail embedding: \n {nt} \n')

negative head embedding:    
 tensor([[ 0.3858, -0.3258,  0.7736, -0.3827],
        [ 0.4108,  0.6540, -0.1436,  0.6188],
        [-0.6901, -0.3529,  0.4213, -0.4709],
        [-0.2206, -0.6302, -0.5828,  0.4632],
        [-0.8772, -0.4091, -0.2392,  0.0774]], grad_fn=<DivBackward0>) 

negative tail embedding: 
 tensor([[ 0.3841,  0.1421,  0.9045, -0.1192],
        [ 0.1684, -0.6696, -0.6655, -0.2835],
        [ 0.5902,  0.7118,  0.0859, -0.3710],
        [-0.5234, -0.0324,  0.5798,  0.6235],
        [-0.2394, -0.4171, -0.8747,  0.0605]], grad_fn=<DivBackward0>) 



In [20]:
neg_score = ((nh+r)-nt).norm(p=2, dim=-1)**2
neg_score

tensor([1.0579, 3.7672, 2.4435, 2.0518, 2.0987], grad_fn=<PowBackward0>)

### 2.2 Compute loss using both scores

In [22]:
from torch import max

loss = max(tensor([0]*len(pos_score)), (pos_score - neg_score) + 0.5)   # 0.5 : margin
loss

tensor([0.2218, 0.0000, 1.8988, 0.5925, 1.4436], grad_fn=<MaximumBackward>)

    실제 코드에서는
    loss_function = torch.nn.MarginRankingLoss(margin=0.5, reduction='sum')
    loss = loss_function(-pos_score, -neg_score, target=ones_like(postive_triplets))
    
    https://pytorch.org/docs/master/generated/torch.nn.MarginRankingLoss.html#torch.nn.MarginRankingLoss

## 3. Link Prediction

    5개의 triplet을 이용해 link prediction 을 하는 과정
    link prediction 을 측정하기 위해 2개의 criteria 를 정한다.
        1) mean rank
        2) Hit @ k

In [24]:
h_emb = entity_embeddings(heads)
t_emb = entity_embeddings(tails)
r_emb = relation_embeddings(relations)
candidates = entity_embeddings.weight.data.view(1, kg_test.n_ent, 4)
candidates = candidates.expand(5, kg_test.n_ent, 4)  # (batch_size, # of entity, emb_dim)

In [25]:
print(f'head     : {h_emb.size()}')
print(f'tail     : {t_emb.size()}')
print(f'relation : {r_emb.size()}')
print(f'candidate: {candidates.size()}') # Entity Embedding matrix 

head     : torch.Size([5, 4])
tail     : torch.Size([5, 4])
relation : torch.Size([5, 4])
candidate: torch.Size([5, 14951, 4])


### LP Compute Ranks

In [30]:
# Dictionary to check the valid triplet
# (head_idx, relation_idx) : [tail_idx1, tail_idx2, ... ]

dict_of_tails = kg_test.dict_of_tails
dict_of_tails

defaultdict(set,
            {(3920, 791): {1546, 3799, 9220},
             (839, 1273): {7489, 9523},
             (10094, 846): {57,
              58,
              61,
              62,
              64,
              67,
              68,
              69,
              72,
              73,
              95,
              322,
              357,
              496,
              541,
              684,
              687,
              697,
              739,
              745,
              747,
              752,
              755,
              757,
              758,
              766,
              771,
              773,
              774,
              775,
              778,
              780,
              837,
              1035,
              1051,
              1127,
              1201,
              1260,
              1328,
              1343,
              1355,
              1486,
              1520,
              1536,
              1618,
              1629,
       

In [50]:
print(f"({heads[0]},{relations[0]},{tails[0]})")
print(f"({heads[0].item()},{relations[0].item()}) : {dict_of_tails[heads[0].item(), relations[0].item()]}")

(3920,791,9220)
(3920,791) : {1546, 9220, 3799}


    전체 fact 에 대해서, (heat, relation) 이 주어질때 존재하는 tail의 집합이다.
    head_idx = 3920, relation_idx = 791 일때 존재하는 tail_idx는 [1546, 3799, 9220] 이다.

In [26]:
# The scores of candidate entities to predict tail entity

hr = (h_emb + r_emb).view(5, 1, r_emb.shape[1])
tail_candidate_scores = -((hr-candidates).norm(p=2, dim=-1)**2) # Equation for TransE
tail_candidate_scores.size()

torch.Size([5, 14951])

    head, relation 을 고정한채 모든 tail entity candidate 에 대해서 score 를 계산해준다.
    5개의 triple에 대해 14951개에 해당되는 entity들의 score가 계산됨을 볼 수 있다.

    우리의 목표는 각 triplet 마다 정답 tail이 높은 값을 가지고, 정답이 아닌 경우에는 낮은 값을 가지는것을 원한다.
    (논문에서는 낮은 값을 가지는것으로 표현했지만, torchkge는 score 값에 - 를 곱해줌으로써 높은 값을 가지는것으로 문제를 변경한다.)
    
    (head, relation) 마다 여러개의 정답을 가질 수 있기 때문에, triplet에 해당되는 않는 정답들에 대해서 filtering 을 해줘야 한다.

In [39]:
filtered_tail_candidate_scores = tail_candidate_scores.clone()

for i in range(5): # batch_size = 5
    true_target = dict_of_tails[heads[i].item(), relations[i].item()].copy()
    if len(true_target) == 1:
        true_target = None
    else:
        true_target.remove(tails[i].item())  # Delete true tail index (existing tail idx of the triplet)
        true_target = tensor(list(true_target)).long()

    if true_target is None:
        continue
    filtered_tail_candidate_scores[i][true_target] = -float('Inf')

In [62]:
# print(filtered_tail_candidate_scores[0])
print(tail_candidate_scores[0][list(dict_of_tails[heads[0].item(), relations[0].item()])])
print(filtered_tail_candidate_scores[0][list(dict_of_tails[heads[0].item(), relations[0].item()])])

tensor([-4.9000, -0.7797, -2.7698], grad_fn=<IndexBackward>)
tensor([   -inf, -0.7797,    -inf], grad_fn=<IndexBackward>)


    torchkge의 저자는 전체 entity의 랭킹을 구하는것이 아니라, 정답 tail entity의 ranking 만을 계산한다.
    
    i.e.
    correct tail entity score = 0.7
    total entity score = [0.1, 0.2, 0.3, 0.5, 0.7, 0.9]
    correct tail entity score <= total entity score = [False, False, False, True, True]
    sum( '' ) = 2
    
    즉, correct tail 의 rank 는 2가 된다.

In [53]:
true_data = filtered_tail_candidate_scores.gather(1, tails.long().view(-1,1))
true_data

tensor([[-0.7797],
        [-2.2353],
        [-3.8423],
        [-2.1443],
        [-3.0423]], grad_fn=<GatherBackward>)

In [59]:
rank_true_data = (filtered_tail_candidate_scores >= true_data).sum(dim=1)
rank_true_data   

tensor([  815, 13889,  9259, 14427,  4124])

    mean rank 를 계산할 때는 위 rank 의 평균값을 이용하고,
    hit @ k 의 경우는 rank_true_data < k 를 이용해 값을 구한다.
    
    성능 측정이 아닌 실제 link prediction 을 하기 위해서는 score값을 기준으로 높은것을 찾아야 하기 때문에 score 값을 기록하고 있거나, 매번 새로 연산을 해줘야 한다.
    (torchkge의 경우 정답에 대한 순위만을 기록하지, 전체 entity에 대한 score는 따로 기록하지 않는다.) => 구현 필요