In [1]:
import math, random, torch, torch.nn as nn, torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm


def mobius_add(x, y, eps=1e-5):
    xy = (x * y).sum(dim=-1, keepdim=True)
    x2 = (x * x).sum(dim=-1, keepdim=True)
    y2 = (y * y).sum(dim=-1, keepdim=True)
    num = (1 + 2 * xy + y2) * x + (1 - x2) * y
    denom = 1 + 2 * xy + x2 * y2
    return num / (denom.clamp_min(eps))

def hyp_distance(x, y, eps=1e-5):
    x2 = (x * x).sum(dim=-1)
    y2 = (y * y).sum(dim=-1)
    diff2 = ((x - y) ** 2).sum(dim=-1)
    denom = (1 - x2) * (1 - y2)
    z = 1 + 2 * diff2 / denom.clamp_min(eps)
    return torch.acosh(z.clamp_min(1 + eps))

class MuRP(nn.Module):
    def __init__(self, num_ent, num_rel, dim):
        super().__init__()
        self.ent = nn.Embedding(num_ent, dim)
        self.rel = nn.Embedding(num_rel, dim)
        nn.init.uniform_(self.ent.weight, -0.001, 0.001)
        nn.init.uniform_(self.rel.weight, -0.001, 0.001)

    def forward(self, h, r, t):
        h_e = self.ent(h)
        r_e = self.rel(r).tanh()         
        t_e = self.ent(t)
        # translate: h ⊕ r
        h_r = mobius_add(h_e, r_e)
        dist = hyp_distance(h_r, t_e)
        return -dist                    

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_triples(path):
    triples = []
    with open(path) as f:
        for line in f:
            h,r,t = line.strip().split()[:3]
            triples.append((h,r,t))
    return triples

train_triples = load_triples("../graph_data/train.tsv")
dev_triples =  load_triples("../graph_data/dev.tsv")
test_triples =  load_triples("../graph_data/test.tsv")

entities = sorted({e for h,_,t in train_triples+dev_triples+test_triples for e in (h,t)})
relations = sorted({r for _,r,_ in train_triples+dev_triples+test_triples})
ent2id = {e:i for i,e in enumerate(entities)}
rel2id = {r:i for i,r in enumerate(relations)}


def encode(triples):
    return torch.tensor([[ent2id[h], rel2id[r], ent2id[t]] for h,r,t in triples], dtype=torch.long)

train_data = encode(train_triples)
dev_data   = encode(dev_triples)
test_data  = encode(test_triples)

In [3]:

batch_size = 512
train_loader = DataLoader(TensorDataset(train_data), batch_size=batch_size, shuffle=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MuRP(len(entities), len(relations), dim=100).to(device)
optimizer = AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)


def neg_sample(batch, n_ent):
    neg = batch.clone()
    mask = torch.rand(batch.size(0), device=batch.device) < 0.5
    neg[mask,0] = torch.randint(0,n_ent,(mask.sum().item(),),device=batch.device)
    neg[~mask,2] = torch.randint(0,n_ent,((~mask).sum().item(),),device=batch.device)
    return neg


In [4]:
def eval_rank(data_tensor, n_neg=100, ks=(1,3,10)):
    model.eval()
    ranks = []
    with torch.no_grad():
        for h,r,t in tqdm(data_tensor, desc='Eval', leave=False):
            h = torch.tensor([h], device=device)
            r = torch.tensor([r], device=device)
            t = torch.tensor([t], device=device)
            pos = model(h, r, t).item()
            scores = [pos]
            for _ in range(n_neg):
                if random.random()<0.5:
                    h2, t2 = random.randrange(len(entities)), t.item()
                else:
                    h2, t2 = h.item(), random.randrange(len(entities))
                scores.append(model(
                    torch.tensor([h2], device=device),
                    r,
                    torch.tensor([t2], device=device)
                ).item())
            rank = 1 + sum(s>pos for s in scores[1:])
            ranks.append(rank)
    ranks = torch.tensor(ranks, dtype=torch.float)
    mrr = (1.0/ranks).mean().item()
    hits = {f"Hits@{k}": (ranks<=k).float().mean().item() for k in ks}
    return mrr, hits

#

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MuRP(len(entities), len(relations), dim=100).to(device)
optimizer = AdamW(model.parameters(), lr=5e-4, weight_decay=1e-4)
best_dev_mrr = 0.0

for epoch in range(1, 201):
    model.train(); total_loss=0; total=0
    for batch in train_loader:
        pos = batch[0].to(device)
        neg = neg_sample(pos, len(entities))
        pos_s = model(pos[:,0], pos[:,1], pos[:,2])
        neg_s = model(neg[:,0], neg[:,1], neg[:,2])
        loss = - (F.logsigmoid(pos_s).mean() + F.logsigmoid(-neg_s).mean())
        optimizer.zero_grad(); loss.backward(); optimizer.step()
        total_loss += loss.item()*pos.size(0); total+=pos.size(0)
    if epoch % 1 == 0:
        dev_mrr, _ = eval_rank(dev_data, n_neg=100)
        if dev_mrr > best_dev_mrr:
            best_dev_mrr = dev_mrr
            torch.save(model.state_dict(), 'best_murp.pt')
        print(f"Epoch {epoch:02d} | Train Loss: {total_loss/total:.4f} | Dev MRR: {dev_mrr:.4f}")


                                                     

Epoch 01 | Train Loss: 1.3863 | Dev MRR: 0.0758


                                                     

Epoch 02 | Train Loss: 1.3812 | Dev MRR: 0.2235


                                                     

Epoch 03 | Train Loss: 1.3718 | Dev MRR: 0.4292


                                                     

Epoch 04 | Train Loss: 1.3612 | Dev MRR: 0.5247


                                                     

Epoch 05 | Train Loss: 1.3497 | Dev MRR: 0.5825


                                                     

Epoch 06 | Train Loss: 1.3382 | Dev MRR: 0.5994


                                                     

Epoch 07 | Train Loss: 1.3265 | Dev MRR: 0.5950


                                                     

Epoch 08 | Train Loss: 1.3151 | Dev MRR: 0.6520


                                                     

Epoch 09 | Train Loss: 1.3036 | Dev MRR: 0.6407


                                                     

Epoch 10 | Train Loss: 1.2926 | Dev MRR: 0.6297


                                                     

Epoch 11 | Train Loss: 1.2818 | Dev MRR: 0.6566


                                                     

Epoch 12 | Train Loss: 1.2709 | Dev MRR: 0.6446


                                                     

Epoch 13 | Train Loss: 1.2606 | Dev MRR: 0.6811


                                                     

Epoch 14 | Train Loss: 1.2501 | Dev MRR: 0.6386


                                                     

Epoch 15 | Train Loss: 1.2405 | Dev MRR: 0.6595


                                                     

Epoch 16 | Train Loss: 1.2304 | Dev MRR: 0.6701


                                                     

Epoch 17 | Train Loss: 1.2209 | Dev MRR: 0.6699


                                                     

Epoch 18 | Train Loss: 1.2114 | Dev MRR: 0.6701


                                                     

Epoch 19 | Train Loss: 1.2011 | Dev MRR: 0.6700


                                                     

Epoch 20 | Train Loss: 1.1930 | Dev MRR: 0.6705


                                                     

Epoch 21 | Train Loss: 1.1836 | Dev MRR: 0.6598


                                                     

Epoch 22 | Train Loss: 1.1743 | Dev MRR: 0.6816


                                                     

Epoch 23 | Train Loss: 1.1657 | Dev MRR: 0.6810


                                                     

Epoch 24 | Train Loss: 1.1572 | Dev MRR: 0.6558


                                                     

Epoch 25 | Train Loss: 1.1487 | Dev MRR: 0.6594


                                                     

Epoch 26 | Train Loss: 1.1401 | Dev MRR: 0.6813


                                                     

Epoch 27 | Train Loss: 1.1328 | Dev MRR: 0.6700


                                                     

Epoch 28 | Train Loss: 1.1242 | Dev MRR: 0.6697


                                                     

Epoch 29 | Train Loss: 1.1166 | Dev MRR: 0.6487


                                                     

Epoch 30 | Train Loss: 1.1090 | Dev MRR: 0.6818


                                                     

Epoch 31 | Train Loss: 1.1012 | Dev MRR: 0.6698


                                                     

Epoch 32 | Train Loss: 1.0927 | Dev MRR: 0.6811


                                                     

Epoch 33 | Train Loss: 1.0861 | Dev MRR: 0.6487


                                                     

Epoch 34 | Train Loss: 1.0784 | Dev MRR: 0.6701


                                                     

Epoch 35 | Train Loss: 1.0707 | Dev MRR: 0.6591


                                                     

Epoch 36 | Train Loss: 1.0645 | Dev MRR: 0.6270


                                                     

Epoch 37 | Train Loss: 1.0577 | Dev MRR: 0.6592


                                                     

Epoch 38 | Train Loss: 1.0505 | Dev MRR: 0.6594


                                                     

Epoch 39 | Train Loss: 1.0435 | Dev MRR: 0.6700


                                                     

Epoch 40 | Train Loss: 1.0365 | Dev MRR: 0.6702


                                                     

Epoch 41 | Train Loss: 1.0298 | Dev MRR: 0.6451


                                                     

Epoch 42 | Train Loss: 1.0231 | Dev MRR: 0.6812


                                                     

Epoch 43 | Train Loss: 1.0175 | Dev MRR: 0.6707


                                                     

Epoch 44 | Train Loss: 1.0112 | Dev MRR: 0.6813


                                                     

Epoch 45 | Train Loss: 1.0056 | Dev MRR: 0.6705


                                                     

Epoch 46 | Train Loss: 0.9992 | Dev MRR: 0.6702


                                                     

Epoch 47 | Train Loss: 0.9936 | Dev MRR: 0.6703


                                                     

Epoch 48 | Train Loss: 0.9875 | Dev MRR: 0.6701


                                                     

Epoch 49 | Train Loss: 0.9820 | Dev MRR: 0.6814


                                                     

Epoch 50 | Train Loss: 0.9759 | Dev MRR: 0.6813


                                                     

Epoch 51 | Train Loss: 0.9707 | Dev MRR: 0.6703


                                                     

Epoch 52 | Train Loss: 0.9650 | Dev MRR: 0.6813


                                                     

Epoch 53 | Train Loss: 0.9602 | Dev MRR: 0.6819


                                                     

Epoch 54 | Train Loss: 0.9547 | Dev MRR: 0.6704


                                                     

Epoch 55 | Train Loss: 0.9495 | Dev MRR: 0.6812


                                                     

Epoch 56 | Train Loss: 0.9448 | Dev MRR: 0.6703


                                                     

Epoch 57 | Train Loss: 0.9400 | Dev MRR: 0.6704


                                                     

Epoch 58 | Train Loss: 0.9344 | Dev MRR: 0.6808


                                                     

Epoch 59 | Train Loss: 0.9303 | Dev MRR: 0.6813


                                                     

Epoch 60 | Train Loss: 0.9254 | Dev MRR: 0.6810


                                                     

Epoch 61 | Train Loss: 0.9216 | Dev MRR: 0.6592


                                                     

Epoch 62 | Train Loss: 0.9158 | Dev MRR: 0.6700


                                                     

Epoch 63 | Train Loss: 0.9123 | Dev MRR: 0.6811


                                                     

Epoch 64 | Train Loss: 0.9076 | Dev MRR: 0.6805


                                                     

Epoch 65 | Train Loss: 0.9040 | Dev MRR: 0.6594


                                                     

Epoch 66 | Train Loss: 0.8992 | Dev MRR: 0.6810


                                                     

Epoch 67 | Train Loss: 0.8950 | Dev MRR: 0.6808


                                                     

Epoch 68 | Train Loss: 0.8912 | Dev MRR: 0.6702


                                                     

Epoch 69 | Train Loss: 0.8873 | Dev MRR: 0.6551


                                                     

Epoch 70 | Train Loss: 0.8846 | Dev MRR: 0.6806


                                                     

Epoch 71 | Train Loss: 0.8806 | Dev MRR: 0.6706


                                                     

Epoch 72 | Train Loss: 0.8767 | Dev MRR: 0.6810


                                                     

Epoch 73 | Train Loss: 0.8735 | Dev MRR: 0.6483


                                                     

Epoch 74 | Train Loss: 0.8693 | Dev MRR: 0.6805


                                                     

Epoch 75 | Train Loss: 0.8666 | Dev MRR: 0.6696


                                                     

Epoch 76 | Train Loss: 0.8631 | Dev MRR: 0.6591


                                                     

Epoch 77 | Train Loss: 0.8601 | Dev MRR: 0.6700


                                                     

Epoch 78 | Train Loss: 0.8574 | Dev MRR: 0.6553


                                                     

Epoch 79 | Train Loss: 0.8539 | Dev MRR: 0.6698


                                                     

Epoch 80 | Train Loss: 0.8506 | Dev MRR: 0.6809


                                                     

Epoch 81 | Train Loss: 0.8472 | Dev MRR: 0.6593


                                                     

Epoch 82 | Train Loss: 0.8451 | Dev MRR: 0.6809


                                                     

Epoch 83 | Train Loss: 0.8428 | Dev MRR: 0.6811


                                                     

Epoch 84 | Train Loss: 0.8393 | Dev MRR: 0.6807


                                                     

Epoch 85 | Train Loss: 0.8364 | Dev MRR: 0.6803


                                                     

Epoch 86 | Train Loss: 0.8345 | Dev MRR: 0.6695


                                                     

Epoch 87 | Train Loss: 0.8333 | Dev MRR: 0.6812


                                                     

Epoch 88 | Train Loss: 0.8307 | Dev MRR: 0.6702


                                                     

Epoch 89 | Train Loss: 0.8279 | Dev MRR: 0.6809


                                                     

Epoch 90 | Train Loss: 0.8252 | Dev MRR: 0.6589


                                                     

Epoch 91 | Train Loss: 0.8235 | Dev MRR: 0.6594


                                                     

Epoch 92 | Train Loss: 0.8213 | Dev MRR: 0.6703


                                                     

Epoch 93 | Train Loss: 0.8188 | Dev MRR: 0.6812


                                                     

Epoch 94 | Train Loss: 0.8176 | Dev MRR: 0.6590


                                                     

Epoch 95 | Train Loss: 0.8152 | Dev MRR: 0.6590


                                                     

Epoch 96 | Train Loss: 0.8132 | Dev MRR: 0.6812


                                                     

Epoch 97 | Train Loss: 0.8110 | Dev MRR: 0.6701


                                                     

Epoch 98 | Train Loss: 0.8100 | Dev MRR: 0.6701


                                                     

Epoch 99 | Train Loss: 0.8076 | Dev MRR: 0.6592


                                                     

Epoch 100 | Train Loss: 0.8065 | Dev MRR: 0.6445


                                                     

Epoch 101 | Train Loss: 0.8057 | Dev MRR: 0.6701


                                                     

Epoch 102 | Train Loss: 0.8038 | Dev MRR: 0.6588


                                                     

Epoch 103 | Train Loss: 0.8021 | Dev MRR: 0.6701


                                                     

Epoch 104 | Train Loss: 0.8009 | Dev MRR: 0.6810


                                                     

Epoch 105 | Train Loss: 0.7990 | Dev MRR: 0.6700


                                                     

Epoch 106 | Train Loss: 0.7984 | Dev MRR: 0.6371


                                                     

Epoch 107 | Train Loss: 0.7966 | Dev MRR: 0.6807


                                                     

Epoch 108 | Train Loss: 0.7950 | Dev MRR: 0.6803


                                                     

Epoch 109 | Train Loss: 0.7943 | Dev MRR: 0.6699


                                                     

Epoch 110 | Train Loss: 0.7933 | Dev MRR: 0.6590


                                                     

Epoch 111 | Train Loss: 0.7913 | Dev MRR: 0.6697


                                                     

Epoch 112 | Train Loss: 0.7909 | Dev MRR: 0.6697


                                                     

Epoch 113 | Train Loss: 0.7898 | Dev MRR: 0.6585


                                                     

Epoch 114 | Train Loss: 0.7887 | Dev MRR: 0.6809


                                                     

Epoch 115 | Train Loss: 0.7873 | Dev MRR: 0.6698


                                                     

Epoch 116 | Train Loss: 0.7867 | Dev MRR: 0.6805


                                                     

Epoch 117 | Train Loss: 0.7857 | Dev MRR: 0.6591


                                                     

Epoch 118 | Train Loss: 0.7849 | Dev MRR: 0.6588


                                                     

Epoch 119 | Train Loss: 0.7840 | Dev MRR: 0.6698


                                                     

Epoch 120 | Train Loss: 0.7833 | Dev MRR: 0.6699


                                                     

Epoch 121 | Train Loss: 0.7829 | Dev MRR: 0.6587


                                                     

Epoch 122 | Train Loss: 0.7818 | Dev MRR: 0.6802


                                                     

Epoch 123 | Train Loss: 0.7806 | Dev MRR: 0.6809


                                                     

Epoch 124 | Train Loss: 0.7797 | Dev MRR: 0.6588


                                                     

Epoch 125 | Train Loss: 0.7800 | Dev MRR: 0.6377


                                                     

Epoch 126 | Train Loss: 0.7793 | Dev MRR: 0.6550


                                                     

Epoch 127 | Train Loss: 0.7782 | Dev MRR: 0.6697


                                                     

Epoch 128 | Train Loss: 0.7777 | Dev MRR: 0.6807


                                                     

Epoch 129 | Train Loss: 0.7768 | Dev MRR: 0.6585


                                                     

Epoch 130 | Train Loss: 0.7759 | Dev MRR: 0.6664


                                                     

Epoch 131 | Train Loss: 0.7762 | Dev MRR: 0.6803


                                                     

Epoch 132 | Train Loss: 0.7757 | Dev MRR: 0.6693


                                                     

Epoch 133 | Train Loss: 0.7750 | Dev MRR: 0.6478


                                                     

Epoch 134 | Train Loss: 0.7747 | Dev MRR: 0.6698


                                                     

Epoch 135 | Train Loss: 0.7739 | Dev MRR: 0.6804


                                                     

Epoch 136 | Train Loss: 0.7732 | Dev MRR: 0.6802


                                                     

Epoch 137 | Train Loss: 0.7740 | Dev MRR: 0.6587


                                                     

Epoch 138 | Train Loss: 0.7728 | Dev MRR: 0.6260


                                                     

Epoch 139 | Train Loss: 0.7726 | Dev MRR: 0.6698


                                                     

Epoch 140 | Train Loss: 0.7724 | Dev MRR: 0.6804


                                                     

Epoch 141 | Train Loss: 0.7716 | Dev MRR: 0.6695


                                                     

Epoch 142 | Train Loss: 0.7713 | Dev MRR: 0.6587


                                                     

Epoch 143 | Train Loss: 0.7705 | Dev MRR: 0.6696


                                                     

Epoch 144 | Train Loss: 0.7709 | Dev MRR: 0.6661


                                                     

Epoch 145 | Train Loss: 0.7699 | Dev MRR: 0.6587


                                                     

Epoch 146 | Train Loss: 0.7695 | Dev MRR: 0.6809


                                                     

Epoch 147 | Train Loss: 0.7693 | Dev MRR: 0.6806


                                                     

Epoch 148 | Train Loss: 0.7692 | Dev MRR: 0.6806


                                                     

Epoch 149 | Train Loss: 0.7687 | Dev MRR: 0.6700


                                                     

Epoch 150 | Train Loss: 0.7687 | Dev MRR: 0.6695


                                                     

Epoch 151 | Train Loss: 0.7682 | Dev MRR: 0.6695


                                                     

Epoch 152 | Train Loss: 0.7682 | Dev MRR: 0.6803


                                                     

Epoch 153 | Train Loss: 0.7675 | Dev MRR: 0.6806


                                                     

Epoch 154 | Train Loss: 0.7666 | Dev MRR: 0.6697


                                                     

Epoch 155 | Train Loss: 0.7668 | Dev MRR: 0.6698


                                                     

Epoch 156 | Train Loss: 0.7665 | Dev MRR: 0.6807


                                                     

Epoch 157 | Train Loss: 0.7665 | Dev MRR: 0.6703


                                                     

Epoch 158 | Train Loss: 0.7668 | Dev MRR: 0.6804


                                                     

Epoch 159 | Train Loss: 0.7661 | Dev MRR: 0.6697


                                                     

Epoch 160 | Train Loss: 0.7665 | Dev MRR: 0.6590


                                                     

Epoch 161 | Train Loss: 0.7652 | Dev MRR: 0.6697


                                                     

Epoch 162 | Train Loss: 0.7652 | Dev MRR: 0.6696


                                                     

Epoch 163 | Train Loss: 0.7651 | Dev MRR: 0.6587


                                                     

Epoch 164 | Train Loss: 0.7643 | Dev MRR: 0.6478


                                                     

Epoch 165 | Train Loss: 0.7645 | Dev MRR: 0.6589


                                                     

Epoch 166 | Train Loss: 0.7644 | Dev MRR: 0.6695


                                                     

Epoch 167 | Train Loss: 0.7645 | Dev MRR: 0.6805


                                                     

Epoch 168 | Train Loss: 0.7645 | Dev MRR: 0.6701


                                                     

Epoch 169 | Train Loss: 0.7642 | Dev MRR: 0.6480


                                                     

Epoch 170 | Train Loss: 0.7636 | Dev MRR: 0.6588


                                                     

Epoch 171 | Train Loss: 0.7635 | Dev MRR: 0.6696


                                                     

Epoch 172 | Train Loss: 0.7633 | Dev MRR: 0.6805


                                                     

Epoch 173 | Train Loss: 0.7632 | Dev MRR: 0.6480


                                                     

Epoch 174 | Train Loss: 0.7636 | Dev MRR: 0.6703


                                                     

Epoch 175 | Train Loss: 0.7631 | Dev MRR: 0.6486


                                                     

Epoch 176 | Train Loss: 0.7627 | Dev MRR: 0.6803


                                                     

Epoch 177 | Train Loss: 0.7631 | Dev MRR: 0.6696


                                                     

Epoch 178 | Train Loss: 0.7623 | Dev MRR: 0.6370


                                                     

Epoch 179 | Train Loss: 0.7619 | Dev MRR: 0.6587


                                                     

Epoch 180 | Train Loss: 0.7621 | Dev MRR: 0.6552


                                                     

Epoch 181 | Train Loss: 0.7616 | Dev MRR: 0.6589


                                                     

Epoch 182 | Train Loss: 0.7612 | Dev MRR: 0.6589


                                                     

Epoch 183 | Train Loss: 0.7613 | Dev MRR: 0.6695


                                                     

Epoch 184 | Train Loss: 0.7620 | Dev MRR: 0.6812


                                                     

Epoch 185 | Train Loss: 0.7614 | Dev MRR: 0.6809


                                                     

Epoch 186 | Train Loss: 0.7610 | Dev MRR: 0.6701


                                                     

Epoch 187 | Train Loss: 0.7614 | Dev MRR: 0.6479


                                                     

Epoch 188 | Train Loss: 0.7614 | Dev MRR: 0.6696


                                                     

Epoch 189 | Train Loss: 0.7600 | Dev MRR: 0.6696


                                                     

Epoch 190 | Train Loss: 0.7603 | Dev MRR: 0.6696


                                                     

Epoch 191 | Train Loss: 0.7604 | Dev MRR: 0.6809


                                                     

Epoch 192 | Train Loss: 0.7598 | Dev MRR: 0.6803


                                                     

Epoch 193 | Train Loss: 0.7599 | Dev MRR: 0.6804


                                                     

Epoch 194 | Train Loss: 0.7593 | Dev MRR: 0.6813


                                                     

Epoch 195 | Train Loss: 0.7591 | Dev MRR: 0.6804


                                                     

Epoch 196 | Train Loss: 0.7593 | Dev MRR: 0.6695


                                                     

Epoch 197 | Train Loss: 0.7594 | Dev MRR: 0.6801


                                                     

Epoch 198 | Train Loss: 0.7596 | Dev MRR: 0.6482


                                                     

Epoch 199 | Train Loss: 0.7585 | Dev MRR: 0.6807


                                                     

Epoch 200 | Train Loss: 0.7593 | Dev MRR: 0.6701




In [None]:

model.load_state_dict(torch.load('best_murp.pt'))
test_mrr, test_hits = eval_rank(test_data.to(device), n_neg=100)
print(f"Test MRR: {test_mrr:.4f} | Hits@1: {test_hits['Hits@1']:.4f} | Hits@3: {test_hits['Hits@3']:.4f} | Hits@10: {test_hits['Hits@10']:.4f}")

  model.load_state_dict(torch.load('best_murp.pt'))
                                                     

Test MRR: 0.6024 | Hits@1: 0.5870 | Hits@3: 0.5870 | Hits@10: 0.6304


