In [2]:
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from tqdm import tqdm
_train_path = Path("../WN18RR/train.txt")
_test_path = Path("../WN18RR/test.txt")
_valid_path = Path("../WN18RR/valid.txt")

def load_dataset(path:Path) -> list[tuple]:
    """
    Parse a tab-separated triple file into a list of (head, relation, tail) tuples.

    Args:
        path: File with one ``head<TAB>relation<TAB>tail`` triple per line.

    Returns:
        The triples in file order, each as a tuple of three strings.

    Raises:
        ValueError: If a non-blank line does not have exactly three
            tab-separated fields; the message names the file and line number.
    """
    datalist = []
    # Explicit encoding so that parsing does not depend on the platform default.
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            fields = line.split("\t")
            if len(fields) != 3:
                raise ValueError(
                    f"{path}:{lineno}: expected 3 tab-separated fields, got {len(fields)}"
                )
            datalist.append(tuple(fields))

    return datalist

# Parse the train and validation splits into lists of (head, relation, tail) string triples.
train_dataset, valid_dataset = (
    load_dataset(split_path) for split_path in (_train_path, _valid_path)
)

# Dataset Preparation

1. Collapse all relations into a single edge type to measure pure GAT performance
2. Add self-loops so that each node's own features take part in the information flow
3. Move the tensors to CUDA (when available)

In [None]:
# Build ID maps.
# Relations are collapsed in the encoder, but their IDs are kept in case we want to experiment.
entities, relations = set(), set()
for h, r, t in train_dataset + valid_dataset:
    entities.update((h, t))
    relations.add(r)

# Sorting before numbering makes the ID assignment deterministic across runs.
ent2id = dict(zip(sorted(entities), range(len(entities))))
rel2id = dict(zip(sorted(relations), range(len(relations))))

num_entities, num_relations = len(ent2id), len(rel2id)
print(f"#entities={num_entities}, #relations={num_relations}")

def triples_to_tensor(triples_list):
    """
    Map string triples to an integer ID tensor.

    Args:
        triples_list: Iterable of (head, relation, tail) string triples.

    Returns:
        LongTensor of shape [N, 3] whose rows are (h_id, r_id, t_id), looked up
        in the notebook-level ``ent2id`` / ``rel2id`` maps.

    Raises:
        KeyError: If an entity or relation is missing from the ID maps.
    """
    arr = np.array([(ent2id[h], rel2id[r], ent2id[t]) for h,r,t in triples_list], dtype=np.int64)
    # An empty input would otherwise produce shape [0]; keep the [N, 3] layout
    # so that column indexing downstream still works.
    return torch.from_numpy(arr.reshape(-1, 3))

from torch_geometric.utils import add_self_loops

# Integer ID tensors for both splits, each of shape [N, 3].
train_triples, valid_triples = map(triples_to_tensor, (train_dataset, valid_dataset))

# --------- Build collapsed edge_index from TRAIN triples ----------
# Relation types are dropped; every (h, t) pair is followed by its reverse (t, h)
# so that messages can flow in both directions.
edges = [
    pair
    for h, _r, t in train_triples.tolist()
    for pair in ((h, t), (t, h))
]
edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()  # [2, E]
print("edge_index:", edge_index.shape)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Add self-loops to the graph, then move the graph and both splits to the device.
edge_index = add_self_loops(edge_index, num_nodes=num_entities)[0].to(device)
train_triples = train_triples.to(device)
valid_triples = valid_triples.to(device)

# Model Definition and Decoder

GNN models need a decoder because the encoder's output is an embedding vector per entity, not a prediction.

take this vector -> guess the tail/head (decoder)

Since we collapse everything into a single relation, there is no need for specialised decoders; a dot product is sufficient to find the closest neighbour.

The attention layer is PyTorch Geometric's `GATConv(emb_dim, hidden_dim)`, the graph attentional operator from the ["Graph Attention Networks"](https://arxiv.org/abs/1710.10903) paper.

In [None]:
class GATLinkEncoder(nn.Module):
    """
    Two-layer GAT encoder that produces one embedding per entity.

    The learnable entity embedding table is used as the node feature matrix.
    It goes through a multi-head GATConv, ELU, dropout and a single-head
    GATConv; a linear projection of the input embeddings is then added as a
    residual and the sum is layer-normalised.
    """

    def __init__(self, num_entities, emb_dim=128, hidden_dim=128, out_dim=256,
                 heads=4, dropout=0.2):
        super().__init__()
        # Node features are learned from scratch (Xavier-uniform initialised).
        self.entity_emb = nn.Embedding(num_entities, emb_dim)
        nn.init.xavier_uniform_(self.entity_emb.weight)

        # The second layer is built for hidden_dim * heads input features,
        # i.e. the first layer's heads concatenated.
        self.gat1 = GATConv(emb_dim, hidden_dim, heads=heads, dropout=dropout)
        self.gat2 = GATConv(hidden_dim * heads, out_dim, heads=1, concat=False, dropout=dropout)

        # Projects the raw embeddings to out_dim so they can be added as a residual.
        self.res_proj = nn.Linear(emb_dim, out_dim)
        self.ln = nn.LayerNorm(out_dim)
        self.drop = nn.Dropout(dropout)

    def forward(self, edge_index):
        """Encode every entity over `edge_index`; returns a [N, out_dim] tensor."""
        node_feats = self.entity_emb.weight                        # [N, emb_dim]
        hidden = self.drop(F.elu(self.gat1(node_feats, edge_index)))
        attended = self.gat2(hidden, edge_index)                   # [N, out_dim]
        return self.ln(attended + self.res_proj(node_feats))       # residual + norm

class DotProductDecoder(nn.Module):
    """Scores a (head, tail) pair by the inner product of the two embeddings."""

    def forward(self, e_h, e_t):
        # Element-wise product summed over dim 1: one score per row -> [B]
        return torch.sum(e_h * e_t, dim=1)

class LinkPredictor(nn.Module):
    """
    GAT encoder + dot-product decoder for relation-agnostic link prediction.

    Args:
        num_entities: Number of nodes in the graph.
        edge_index: LongTensor [2, E] with the message-passing edges.
        out_dim: Size of the entity embeddings produced by the encoder.
        **gat_kwargs: Forwarded to ``GATLinkEncoder``.
    """

    def __init__(self, num_entities, edge_index, out_dim=200, **gat_kwargs):
        super().__init__()
        self.encoder = GATLinkEncoder(num_entities, out_dim=out_dim, **gat_kwargs)
        self.decoder = DotProductDecoder()
        # Registered as a buffer so that `.to(device)` / `.cuda()` move the graph
        # together with the weights. persistent=False keeps it out of
        # state_dict(), so checkpoints keep exactly the same keys as before.
        self.register_buffer("edge_index", edge_index, persistent=False)

    def forward(self, triples):
        """
        Score a batch of triples.

        Args:
            triples: LongTensor [B, 3] of (h, r, t) IDs.

        Returns:
            Tensor [B] of raw scores (logits).
        """
        # triples: [B, 3] but we ignore relation (column 1) for scoring
        h = triples[:, 0]
        t = triples[:, 2]
        # NOTE: the encoder runs over the whole graph on every call.
        ent = self.encoder(self.edge_index)  # [N, d]
        e_h, e_t = ent[h], ent[t]
        return self.decoder(e_h, e_t)  # raw scores (logits)

# Negative Sampling

The model also needs to see examples of false triples to learn what not to predict. Without negatives, the model would just assign high scores to everything.

The dataset doesn't provide negative examples, so we have to generate them ourselves.

In [None]:
@torch.no_grad()
def sample_negatives(pos_triples, num_entities, corrupt_tail=True):
    """
    Build one negative per positive by replacing one entity with a random one.

    The tail (column 2) is replaced when `corrupt_tail` is true, otherwise the
    head (column 0). Replacements are drawn uniformly from [0, num_entities)
    and are not checked against the true triples. The input is left untouched;
    the result has the same shape and lives on the same device.
    """
    column = 2 if corrupt_tail else 0
    n_rows = pos_triples.size(0)
    corrupted = pos_triples.clone()
    corrupted[:, column] = torch.randint(0, num_entities, (n_rows,), device=pos_triples.device)
    return corrupted

# Training and Batch Eval

A standard mini-batch training loop with negative sampling; there is not much to explain here.

In [10]:
# --------- Training / Evaluation ----------
def batches(tensor, batch_size, shuffle=True):
    """
    Yield consecutive mini-batches of rows from `tensor`.

    Rows are visited in a random permutation when `shuffle` is true and in
    their original order otherwise. The last batch may be smaller than
    `batch_size`.
    """
    n_rows = tensor.size(0)
    if shuffle:
        order = torch.randperm(n_rows, device=tensor.device)
    else:
        order = torch.arange(n_rows, device=tensor.device)
    for start in range(0, n_rows, batch_size):
        yield tensor[order[start:start + batch_size]]


def sample_negatives_both(pos_triples, num_entities, k_neg=20):
    """
    Draw `k_neg` head-corrupted and `k_neg` tail-corrupted negatives per positive.

    Returns:
        (neg_head, neg_tail), each of shape [B * k_neg, 3]; the k_neg
        corruptions of one positive occupy consecutive rows.
    """
    n_pos = pos_triples.size(0)
    dev = pos_triples.device

    # Random tails are drawn before random heads.
    random_tails = torch.randint(0, num_entities, (n_pos, k_neg), device=dev)
    random_heads = torch.randint(0, num_entities, (n_pos, k_neg), device=dev)

    # tail corruption: every positive repeated k_neg times in a row -> [B*k, 3]
    neg_tail = pos_triples.repeat(1, k_neg).view(-1, 3)
    neg_tail[:, 2] = random_tails.view(-1)

    # head corruption
    neg_head = pos_triples.repeat(1, k_neg).view(-1, 3)
    neg_head[:, 0] = random_heads.view(-1)

    return neg_head, neg_tail  # [B*k,3] each

def train_one_epoch(model, triples, optimizer, batch_size=2048, k_neg=20):
    """
    Run one pass over `triples` and return the mean BCE loss per scored triple.

    Every positive is paired with `k_neg` head-corrupted and `k_neg`
    tail-corrupted negatives; positives are up-weighted by 2 * k_neg to offset
    the resulting 1 : 2 * k_neg class imbalance. Negatives are sampled using
    the notebook-level `num_entities`.
    """
    model.train()
    dev = triples.device
    # pos:neg = 1 : (2*k_neg)
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([(2.0 * k_neg)], device=dev))

    loss_sum = 0.0
    n_scored = 0
    for pos in batches(triples, batch_size, shuffle=True):
        neg_h, neg_t = sample_negatives_both(pos, num_entities, k_neg=k_neg)
        n_pos = len(pos)
        n_neg = len(neg_h) + len(neg_t)

        # Positives first, then head- and tail-corrupted negatives: [B + 2Bk, 3]
        batch = torch.cat([pos, neg_h, neg_t], dim=0)
        targets = torch.cat([torch.ones(n_pos, device=dev), torch.zeros(n_neg, device=dev)])

        loss = criterion(model(batch), targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so the result is a per-item mean.
        loss_sum += loss.item() * (n_pos + n_neg)
        n_scored += n_pos + n_neg

    return loss_sum / n_scored

# Evaluation Function
Computes Hits@10 and ROC-AUC against randomly sampled negatives.

In [11]:
@torch.no_grad()
def evaluate_auc_hits(model, triples, batch_size=4096, hits_k=10):
    """Quick sanity metrics (UNFILTERED):
       - ROC-AUC on positives vs. one random tail-corrupted negative per positive
       - Hits@k for tail prediction: the true tail is ranked among itself plus 99 random tails

    Sampled negatives are not filtered against known triples, so a random
    candidate can coincide with a true tail.

    Args:
        model: Model exposing `encoder`, `decoder` and `edge_index` (e.g. LinkPredictor).
        triples: LongTensor [N, 3] of (h, r, t) IDs to evaluate on.
        batch_size: Number of positives scored per step.
        hits_k: Cut-off rank for Hits@k.

    Returns:
        dict with the keys "auc" and f"hits@{hits_k}".
    """
    from sklearn.metrics import roc_auc_score

    model.eval()

    # Nothing is updated during evaluation and model.eval() disables dropout,
    # so the entity embeddings are encoded once and reused for every batch
    # instead of re-running the full-graph encoder per batch.
    ent = model.encoder(model.edge_index)                         # [N,d]

    scores_all, labels_all = [], []

    # AUC: one negative per positive
    for pos in batches(triples, batch_size, shuffle=False):
        neg = sample_negatives(pos, num_entities, corrupt_tail=True)
        s_pos = model.decoder(ent[pos[:, 0]], ent[pos[:, 2]])
        s_neg = model.decoder(ent[neg[:, 0]], ent[neg[:, 2]])
        scores_all.append(torch.cat([s_pos, s_neg], 0).cpu().numpy())
        labels_all.append(np.concatenate([np.ones(len(pos)), np.zeros(len(neg))], 0))
    scores_all = np.concatenate(scores_all, 0)
    labels_all = np.concatenate(labels_all, 0)
    auc = roc_auc_score(labels_all, scores_all)

    # Hits@k (unfiltered): rank true tail among 1 positive + 99 negatives
    k = hits_k
    hits = 0
    n_trials = 0
    for pos in batches(triples, batch_size, shuffle=False):
        # build a set of 100 candidates per positive (1 true + 99 random tails)
        B = pos.size(0)
        true_t = pos[:, 2]
        rand_t = torch.randint(0, num_entities, (B, 99), device=pos.device)
        tails = torch.cat([true_t.unsqueeze(1), rand_t], dim=1)   # [B, 100]

        # score (h, ?) against all 100 tails
        e_h = ent[pos[:, 0]]                                      # [B,d]
        e_candidates = ent[tails]                                 # [B,100,d]
        s = (e_h.unsqueeze(1) * e_candidates).sum(dim=2)          # [B,100]
        # The true tail sits at candidate index 0; find where the sort puts it.
        ranks = (s.argsort(dim=1, descending=True) == 0).nonzero()[:,1] + 1  # position of true tail (1-based)
        hits += (ranks <= k).sum().item()
        n_trials += B
    hits_at_k = hits / n_trials

    return {"auc": float(auc), f"hits@{k}": float(hits_at_k)}

# Actual Training

In [7]:
# Collapsed-relation GAT with a dot-product decoder, optimised with Adam.
model = LinkPredictor(
    num_entities,
    edge_index,
    out_dim=200,
    emb_dim=128,
    hidden_dim=128,
    heads=4,
    dropout=0.3,
).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

EPOCHS = 20
for epoch in tqdm(range(1, EPOCHS+1), desc="epoch:"):
    loss = train_one_epoch(model, train_triples, opt, batch_size=2048)
    # Validate after the first epoch and after every even-numbered epoch.
    if epoch != 1 and epoch % 2 != 0:
        continue
    metrics = evaluate_auc_hits(model, valid_triples, batch_size=4096, hits_k=10)
    print(f"Epoch {epoch:02d} | loss={loss:.4f} | AUC={metrics['auc']:.4f} | Hits@10={metrics['hits@10']:.4f}")

#entities=40757, #relations=11
edge_index: torch.Size([2, 173670])


epoch::   5%|▌         | 1/20 [02:14<42:27, 134.08s/it]

Epoch 01 | loss=54.2459 | AUC=0.8087 | Hits@10=0.5992


epoch::  10%|█         | 2/20 [04:20<38:56, 129.78s/it]

Epoch 02 | loss=12.0317 | AUC=0.8116 | Hits@10=0.6094


epoch::  20%|██        | 4/20 [08:29<33:45, 126.58s/it]

Epoch 04 | loss=8.3247 | AUC=0.8137 | Hits@10=0.6111


epoch::  30%|███       | 6/20 [12:32<28:53, 123.85s/it]

Epoch 06 | loss=6.5370 | AUC=0.8070 | Hits@10=0.6117


epoch::  40%|████      | 8/20 [16:29<24:10, 120.87s/it]

Epoch 08 | loss=5.1960 | AUC=0.8039 | Hits@10=0.6134


epoch::  50%|█████     | 10/20 [20:44<20:37, 123.79s/it]

Epoch 10 | loss=3.9877 | AUC=0.8053 | Hits@10=0.6187


epoch::  60%|██████    | 12/20 [25:12<17:01, 127.72s/it]

Epoch 12 | loss=2.7728 | AUC=0.8033 | Hits@10=0.6177


epoch::  70%|███████   | 14/20 [28:50<11:48, 118.16s/it]

Epoch 14 | loss=1.3322 | AUC=0.7936 | Hits@10=0.6104


epoch::  80%|████████  | 16/20 [32:11<07:14, 108.74s/it]

Epoch 16 | loss=0.8742 | AUC=0.7973 | Hits@10=0.6282


epoch::  90%|█████████ | 18/20 [35:20<03:23, 101.64s/it]

Epoch 18 | loss=0.8306 | AUC=0.8103 | Hits@10=0.6454


epoch:: 100%|██████████| 20/20 [38:34<00:00, 115.72s/it]

Epoch 20 | loss=0.8081 | AUC=0.8153 | Hits@10=0.6543





# Relational GAT (Simplified)

R-GAT Code.
1. Uses the same embedding Matrix
2. Each relation gets its own attention layers (one `GATConv` per relation in each layer); the per-relation outputs are summed and fed forward.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
from torch_geometric.utils import add_self_loops
from pathlib import Path
import numpy as np
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from collections import defaultdict


Reload the dataset and rebuild the ID maps and tensors as a sanity check, so this section starts from a known state.

In [None]:
# Data loading
train_list = load_dataset(_train_path)
valid_list = load_dataset(_valid_path)

# Build ID maps over every entity / relation seen in train or valid.
entities, relations = set(), set()
for h, r, t in train_list + valid_list:
    entities |= {h, t}
    relations.add(r)

# IDs follow the sorted order of the names, so they are stable across runs.
ent2id = {name: idx for idx, name in enumerate(sorted(entities))}
rel2id = {name: idx for idx, name in enumerate(sorted(relations))}
num_entities = len(ent2id)
num_relations = len(rel2id)
print(f"#entities={num_entities}, #relations={num_relations}")

def triples_to_tensor(triples):
    """
    Convert (head, relation, tail) string triples to a LongTensor of IDs.

    Args:
        triples: Iterable of (head, relation, tail) string triples.

    Returns:
        LongTensor of shape [N, 3] with rows (h_id, r_id, t_id), looked up in
        the notebook-level ``ent2id`` / ``rel2id`` maps.

    Raises:
        KeyError: If an entity or relation is missing from the ID maps.
    """
    arr = np.array([(ent2id[h], rel2id[r], ent2id[t]) for h, r, t in triples], dtype=np.int64)
    # Keep the [N, 3] layout for empty input too (np.array([]) alone has shape [0]).
    return torch.from_numpy(arr.reshape(-1, 3))

train_triples = triples_to_tensor(train_list)  # [N,3]
valid_triples = triples_to_tensor(valid_list)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_triples, valid_triples = train_triples.to(device), valid_triples.to(device)

# Build relation-aware edge_index dict (add reverse edges + self loops).
# Stage 1: collect (src, dst) pairs per relation ID, each edge followed by its reverse.
rel_edge_index = defaultdict(list)
for h, r, t in train_triples.tolist():
    rel_edge_index[r] += [(h, t), (t, h)]

# Stage 2: replace every pair list with a [2, E] LongTensor on the device.
for r in range(num_relations):
    pairs = rel_edge_index[r]
    if pairs:
        eidx = torch.tensor(pairs, dtype=torch.long).t().contiguous()
        rel_edge_index[r] = add_self_loops(eidx, num_nodes=num_entities)[0].to(device)
    else:
        # ensure the key exists (as an empty [2, 0] edge set) even if the relation is absent in train
        rel_edge_index[r] = torch.empty((2, 0), dtype=torch.long, device=device)

In [None]:
class RelationalGATEncoder(nn.Module):
    """
    One GATConv per relation. Messages are summed over relations each layer.

    Pipeline: entity embeddings -> sum over relations of ELU(GATConv_1)
    -> dropout -> sum over relations of GATConv_2 -> add a linear projection
    of the input embeddings (residual) -> LayerNorm.

    Relations whose edge set is empty are skipped in both layers.
    """
    def __init__(self, num_entities, num_relations,
                 emb_dim=128, hidden_dim=128, out_dim=256,
                 heads=4, dropout=0.2):
        super().__init__()
        self.num_relations = num_relations
        self.entity_emb = nn.Embedding(num_entities, emb_dim)
        nn.init.xavier_uniform_(self.entity_emb.weight)

        # ModuleDict keys must be strings, hence str(r).
        self.gat1 = nn.ModuleDict({
            str(r): GATConv(emb_dim, hidden_dim, heads=heads, dropout=dropout)
            for r in range(num_relations)
        })
        self.gat2 = nn.ModuleDict({
            str(r): GATConv(hidden_dim * heads, out_dim, heads=1, concat=False, dropout=dropout)
            for r in range(num_relations)
        })

        self.res_proj = nn.Linear(emb_dim, out_dim)
        self.ln = nn.LayerNorm(out_dim)
        self.drop = nn.Dropout(dropout)

    def _sum_over_relations(self, layers, x, rel_edge_index, activation=None):
        """Sum the per-relation outputs of `layers`; None if no relation has edges."""
        total = None
        for r in range(self.num_relations):
            eidx = rel_edge_index[r]
            if eidx.numel() == 0:
                continue
            out = layers[str(r)](x, eidx)
            if activation is not None:
                out = activation(out)
            # Running sum instead of torch.stack(...).sum(0): same result up to
            # floating-point rounding, without the stacked [R, N, d] tensor.
            total = out if total is None else total + out
        return total

    def forward(self, rel_edge_index: dict[int, torch.Tensor]):
        """
        Encode all entities.

        Args:
            rel_edge_index: Maps every relation ID in range(num_relations) to
                a [2, E_r] edge index tensor (E_r may be 0).

        Returns:
            Tensor [N, out_dim] of entity embeddings.
        """
        x0 = self.entity_emb.weight  # [N, emb_dim]
        residual = self.res_proj(x0)

        # layer 1: per-relation attention then sum
        x = self._sum_over_relations(self.gat1, x0, rel_edge_index, activation=F.elu)
        if x is None:
            # No relation has any edge, so layer 2 would skip every relation as
            # well and contribute zeros: only the residual path remains.
            return self.ln(residual)
        x = self.drop(x)

        # layer 2: per-relation attention then sum (the same relations are
        # active as in layer 1, so the result cannot be None here)
        x = self._sum_over_relations(self.gat2, x, rel_edge_index)

        # residual + norm
        return self.ln(x + residual)  # [N, out_dim]

class DistMultDecoder(nn.Module):
    """
    DistMult scoring with one learnable vector w_r per relation:
    score(h, r, t) = sum_i e_h[i] * w_r[i] * e_t[i]
    """
    def __init__(self, num_relations, dim):
        super().__init__()
        # Relation vectors, Xavier-uniform initialised.
        self.rel_emb = nn.Embedding(num_relations, dim)
        nn.init.xavier_uniform_(self.rel_emb.weight)

    def forward(self, e_h, r, e_t):
        relation_vecs = self.rel_emb(r)               # [B, d]
        triple_product = e_h * relation_vecs * e_t    # [B, d]
        return triple_product.sum(dim=1)              # [B]

class RelationalGATLinkPredictor(nn.Module):
    """Relation-aware GAT encoder followed by a DistMult decoder."""

    def __init__(self, num_entities, num_relations, rel_edge_index,
                 out_dim=256, **gat_kwargs):
        super().__init__()
        self.encoder = RelationalGATEncoder(
            num_entities, num_relations, out_dim=out_dim, **gat_kwargs
        )
        self.decoder = DistMultDecoder(num_relations, out_dim)
        # Kept as a plain attribute and handed to the encoder on every forward pass.
        self.rel_edge_index = rel_edge_index

    def forward(self, triples):  # triples: [B,3] (h,r,t)
        """Return one logit per (h, r, t) row of `triples`."""
        entity_vecs = self.encoder(self.rel_edge_index)  # [N, d]
        heads, rels, tails = triples[:, 0], triples[:, 1], triples[:, 2]
        return self.decoder(entity_vecs[heads], rels, entity_vecs[tails])  # logits [B]

# -------------------------
# Utilities
# -------------------------
def batches(tensor, batch_size, shuffle=True):
    """
    Iterate over `tensor` in row chunks of at most `batch_size`.

    With `shuffle` the rows are drawn through a random permutation, otherwise
    in their stored order. The final chunk holds whatever rows remain.
    """
    total = tensor.size(0)
    make_index = torch.randperm if shuffle else torch.arange
    index = make_index(total, device=tensor.device)
    for lo in range(0, total, batch_size):
        hi = lo + batch_size
        yield tensor[index[lo:hi]]

@torch.no_grad()
def sample_negatives_both(pos_triples, num_entities, k_neg=10):
    """
    Corrupt every positive `k_neg` times on the head and `k_neg` times on the tail.

    Returns the flattened pair (neg_h, neg_t), each of shape [B*k, 3]; the
    corruptions of one positive are stored in consecutive rows.
    """
    def tiled():
        # k_neg consecutive copies of every positive, flattened to [B*k, 3]
        return pos_triples.repeat(1, k_neg).view(-1, 3)

    shape = (pos_triples.size(0), k_neg)
    dev = pos_triples.device

    # Tail corruption (the random tails are drawn before the random heads)
    neg_t = tiled()
    neg_t[:, 2] = torch.randint(0, num_entities, shape, device=dev).view(-1)

    # Head corruption
    neg_h = tiled()
    neg_h[:, 0] = torch.randint(0, num_entities, shape, device=dev).view(-1)

    return neg_h, neg_t

def train_one_epoch(model, triples, optimizer, batch_size=2048, k_neg=10):
    """
    Train for one epoch and return the mean BCE loss per scored triple.

    Each batch holds the positives plus `k_neg` head- and `k_neg`
    tail-corrupted negatives per positive (sampled with the notebook-level
    `num_entities`). Gradients are clipped to a total norm of 1.0 before
    every optimizer step.
    """
    model.train()
    dev = triples.device
    # Balance: 1 pos : (2*k_neg) neg
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([2.0 * k_neg], device=dev))

    running_loss = 0.0
    n_items = 0
    for pos in batches(triples, batch_size, shuffle=True):
        neg_h, neg_t = sample_negatives_both(pos, num_entities, k_neg=k_neg)
        n_pos, n_neg = len(pos), len(neg_h) + len(neg_t)

        # Targets follow the concatenation order: positives, then all negatives.
        labels = torch.cat([torch.ones(n_pos, device=dev), torch.zeros(n_neg, device=dev)], dim=0)
        logits = model(torch.cat([pos, neg_h, neg_t], dim=0))
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Weight each batch mean by its size so the result is a per-item mean.
        batch_items = labels.numel()
        running_loss += loss.item() * batch_items
        n_items += batch_items

    return running_loss / n_items

@torch.no_grad()
def evaluate_auc_hits(model, triples, batch_size=4096, hits_k=10):
    """
    Quick, UNFILTERED sanity metrics for a relational link predictor.

    - ROC-AUC on positives vs. one random tail-corrupted negative per positive
    - Hits@k for tail prediction: the true tail is ranked among itself plus 99 random tails

    Sampled negatives are not filtered against known triples, so a random
    candidate can coincide with a true tail.

    Args:
        model: Model exposing `encoder`, `decoder` (with `rel_emb`) and `rel_edge_index`.
        triples: LongTensor [N, 3] of (h, r, t) IDs to evaluate on.
        batch_size: Number of positives scored per step.
        hits_k: Cut-off rank for Hits@k.

    Returns:
        dict with the keys "auc" and f"hits@{hits_k}".
    """
    model.eval()

    # Nothing is updated during evaluation and model.eval() disables dropout,
    # so the entity embeddings are encoded once and shared by both metrics
    # instead of re-running the relational encoder for every scored batch.
    ent = model.encoder(model.rel_edge_index)  # [N, d]

    # --- AUC: 1 negative per positive (tail corruption) ---
    scores_all, labels_all = [], []
    for pos in batches(triples, batch_size, shuffle=False):
        B = pos.size(0)
        neg = pos.clone()
        neg[:, 2] = torch.randint(0, num_entities, (B,), device=pos.device)

        s_pos = model.decoder(ent[pos[:, 0]], pos[:, 1], ent[pos[:, 2]])
        s_neg = model.decoder(ent[neg[:, 0]], neg[:, 1], ent[neg[:, 2]])

        scores_all.append(torch.cat([s_pos, s_neg], dim=0).cpu().numpy())
        labels_all.append(np.concatenate([np.ones(B), np.zeros(B)], axis=0))

    auc = roc_auc_score(np.concatenate(labels_all), np.concatenate(scores_all))

    # --- Hits@k (unfiltered): rank true tail among 99 random tails + 1 true ---
    hits, trials = 0, 0
    for pos in batches(triples, batch_size, shuffle=False):
        B = pos.size(0)
        h = pos[:, 0]; r = pos[:, 1]; t_true = pos[:, 2]

        # 99 random negatives; the true tail is candidate index 0
        rand_t = torch.randint(0, num_entities, (B, 99), device=pos.device)
        cand_t = torch.cat([t_true.unsqueeze(1), rand_t], dim=1)  # [B,100]

        e_h = ent[h]                          # [B,d]
        w_r = model.decoder.rel_emb(r)        # [B,d]
        e_c = ent[cand_t]                     # [B,100,d]

        # DistMult score(h,r,?) for all candidates
        s = ((e_h * w_r).unsqueeze(1) * e_c).sum(dim=2)  # [B,100]
        # 1-based position of candidate 0 (the true tail) in the sorted order
        ranks = (s.argsort(dim=1, descending=True) == 0).nonzero()[:, 1] + 1
        hits += (ranks <= hits_k).sum().item()
        trials += B

    return {"auc": float(auc), f"hits@{hits_k}": hits / trials}

In [8]:
# Relational GAT encoder + DistMult decoder, trained with the same schedule as the GAT above.
num_entities = len(ent2id)
model = RelationalGATLinkPredictor(
    num_entities=num_entities,
    num_relations=num_relations,
    rel_edge_index=rel_edge_index,
    emb_dim=128,
    hidden_dim=128,
    out_dim=256,
    heads=4,
    dropout=0.2,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
EPOCHS = 20

for epoch in tqdm(range(1, EPOCHS + 1), desc="epoch"):
    loss = train_one_epoch(model, train_triples, optimizer, batch_size=2048, k_neg=10)
    # Validate after the first epoch and after every even-numbered epoch.
    if epoch != 1 and epoch % 2 != 0:
        continue
    metrics = evaluate_auc_hits(model, valid_triples, batch_size=4096, hits_k=10)
    print(f"Epoch {epoch:02d} | loss={loss:.4f} | AUC={metrics['auc']:.4f} | Hits@10={metrics['hits@10']:.4f}")

#entities=40757, #relations=11


epoch:   5%|▌         | 1/20 [07:05<2:14:52, 425.90s/it]

Epoch 01 | loss=1.5246 | AUC=0.7878 | Hits@10=0.4466


epoch:  10%|█         | 2/20 [12:49<1:53:18, 377.72s/it]

Epoch 02 | loss=0.7209 | AUC=0.8515 | Hits@10=0.5969


epoch:  20%|██        | 4/20 [24:58<1:39:28, 373.05s/it]

Epoch 04 | loss=0.2300 | AUC=0.8788 | Hits@10=0.7248


epoch:  30%|███       | 6/20 [37:15<1:26:36, 371.18s/it]

Epoch 06 | loss=0.1423 | AUC=0.8764 | Hits@10=0.7561


epoch:  40%|████      | 8/20 [48:23<1:10:14, 351.21s/it]

Epoch 08 | loss=0.1145 | AUC=0.8755 | Hits@10=0.7663


epoch:  50%|█████     | 10/20 [59:27<57:05, 342.53s/it] 

Epoch 10 | loss=0.0976 | AUC=0.8777 | Hits@10=0.7779


epoch:  60%|██████    | 12/20 [1:10:18<44:33, 334.15s/it]

Epoch 12 | loss=0.0823 | AUC=0.8735 | Hits@10=0.7630


epoch:  70%|███████   | 14/20 [1:21:54<34:21, 343.54s/it]

Epoch 14 | loss=0.0787 | AUC=0.8862 | Hits@10=0.7894


epoch:  80%|████████  | 16/20 [1:32:46<22:16, 334.22s/it]

Epoch 16 | loss=0.0673 | AUC=0.8777 | Hits@10=0.7815


epoch:  90%|█████████ | 18/20 [1:43:39<11:01, 330.55s/it]

Epoch 18 | loss=0.0650 | AUC=0.8856 | Hits@10=0.7844


epoch: 100%|██████████| 20/20 [1:55:02<00:00, 345.11s/it]

Epoch 20 | loss=0.0575 | AUC=0.8779 | Hits@10=0.7963





In [9]:
# torch.save does not create missing parent directories, so make sure the
# output directory exists first.
Path("model").mkdir(parents=True, exist_ok=True)
# NOTE(review): this pickles the whole module object, so loading it needs the
# same class definitions to be importable; saving model.state_dict() is the
# more portable option if the format can change.
torch.save(model, "model/r-gat.pth")