In [1]:
import torch
from torch_geometric.data import Data, InMemoryDataset
import transformers
from transformers import Trainer, TrainingArguments, HfArgumentParser

from script.build_model import build_model, build_tokenizer_model, build_tokenizer
from src.trainer.metric import ROUGE, metric_fn
from src.trainer.trainer import KGLLMTrainer
from config.config import Config
from src.data.datasets import FB15k237Inductive
from src.data.types import CustomData
from src.ultra import tasks, util
from src.ultra.models import Ultra


def parse_args(config_path: str) -> Config:
    parser = HfArgumentParser(Config)
    cfg: Config = parser.parse_yaml_file(config_path)[0]
    cfg.train = cfg.train.set_dataloader(train_batch_size=cfg.train.batch_size, eval_batch_size=cfg.train.batch_size)

    # get_logger().
    return cfg


def get_data(cfg: Config) -> tuple[InMemoryDataset, CustomData, CustomData, CustomData]:
    dataset = util.build_dataset(cfg)
    return dataset, dataset[0], dataset[1], dataset[2]

  from .autonotebook import tqdm as notebook_tqdm
  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


In [7]:
from src.data.special_tokens import SpecialToken


cfg = parse_args("config/pretrain/pretrain_0.yaml")
transformers.set_seed(cfg.train.seed)

task_name = cfg.task.name

# data sampler, loader, collator -> custom trainer
dataset, train_data, valid_data, test_data = get_data(cfg=cfg)

# tokenizer, model = build_tokenizer_model(cfg)
tokenizer = build_tokenizer(cfg)
SpecialToken.add_tokens(tokenizer)

4

In [8]:
from src.data.pretrain import PretrainDataset


data = PretrainDataset(train_data, tokenizer, cfg)

In [9]:
from src.data.evaluate import EvaluateDataset

eval_data = EvaluateDataset(valid_data, tokenizer, cfg)

In [10]:
cfg.train.fast_test = -1

In [12]:
from tqdm import tqdm
_d = data
for i in tqdm(range(len(_d))):
    try:
        _d[i]
    except Exception as e:
        print(e)
        break

100%|██████████| 2000/2000 [04:04<00:00,  8.18it/s]


In [10]:
idx = i % data.data.target_edge_index.shape[1]
i, idx

(1163, 1163)

In [26]:
triple = (
            torch.cat(
                [
                    data.data.target_edge_index[:, idx],
                    data.data.target_edge_type[idx].unsqueeze(0),
                ]
            )
            .t()
            .view(-1, 3)
        )
entities = torch.cat([triple[:, 0], triple[:, 1]]).unique()
subg = data.sample_from_edge_index(entities)

# 采样子图中要预测的 triples，以及对应的负样本
# cfg task num_mask
edge_mask = torch.randperm(subg.target_edge_index.shape[1])[:data.cfg.task.num_mask]
# mask_triples: tris x 3
mask_triples = (
    torch.cat(
        [
            subg.target_edge_index[:, edge_mask],
            subg.target_edge_type[edge_mask].unsqueeze(0),
        ]
    )
    .t()
    .view(-1, 3)
)

In [29]:
from random import random
_data = data.data
batch = mask_triples.view(-1, 3)
num_negative = data.cfg.task.num_negative
strict = data.cfg.task.strict_negative
limit_nodes = subg.n_id

In [40]:
batch_size = len(batch)
_bs = batch_size
if _bs == 1:
    batch_size = 2


pos_h_index, pos_t_index, pos_r_index = batch.t()

# strict negative sampling vs random negative sampling
if strict:
    t_mask, h_mask = tasks.strict_negative_mask(_data, batch)
    t_mask = t_mask[:batch_size // 2]
    if limit_nodes is not None:
        t_mask = t_mask[:, limit_nodes]
    neg_t_candidate = t_mask.nonzero()[:, 1]
    num_t_candidate = t_mask.sum(dim=-1)
    # draw samples for negative tails
    rand = torch.rand(len(t_mask), num_negative, device=batch.device)
    index = (rand * num_t_candidate.unsqueeze(-1)).long()
    index = index + (num_t_candidate.cumsum(0) - num_t_candidate).unsqueeze(-1)
    neg_t_index = neg_t_candidate[index]
    if limit_nodes is not None:
        neg_t_index = limit_nodes[neg_t_index]

    h_mask = h_mask[batch_size // 2:]
    if limit_nodes is not None:
        h_mask = h_mask[:, limit_nodes]
    neg_h_candidate = h_mask.nonzero()[:, 1]
    num_h_candidate = h_mask.sum(dim=-1)
    # draw samples for negative heads
    rand = torch.rand(len(h_mask), num_negative, device=batch.device)
    index = (rand * num_h_candidate.unsqueeze(-1)).long()
    index = index + (num_h_candidate.cumsum(0) - num_h_candidate).unsqueeze(-1)
    neg_h_index = neg_h_candidate[index]
    if limit_nodes is not None:
        neg_h_index = limit_nodes[neg_h_index]
else:
    neg_index = torch.randint(_data.num_nodes, (batch_size, num_negative), device=batch.device)
    neg_t_index, neg_h_index = neg_index[:batch_size // 2], neg_index[batch_size // 2:]

h_index = pos_h_index.unsqueeze(-1).repeat(1, num_negative + 1)
t_index = pos_t_index.unsqueeze(-1).repeat(1, num_negative + 1)
r_index = pos_r_index.unsqueeze(-1).repeat(1, num_negative + 1)

# TODO FIXME 为了应对 bs 为1的情况，随机替换 h / t
if _bs == 1:
    # t_index[0, 1:] = neg_t_index
    h_index[0, 1:] = neg_h_index
    # if random() > 0.5:
    #     # 替换 t    (预测 t)
    #     t_index[0, 1:] = neg_t_index
    # else:
    #     # 替换 h    (预测 h)
    #     h_index[0, 1:] = neg_h_index
else:
    t_index[:batch_size // 2, 1:] = neg_t_index
    h_index[batch_size // 2:, 1:] = neg_h_index

RuntimeError: expand(torch.LongTensor{[0, 8]}, size=[8]): the number of sizes provided (1) must be greater or equal to the number of dimensions in the tensor (2)

In [44]:
neg_h_candidate = torch.tensor([1])

In [52]:
_1, _2 = tasks.strict_negative_mask(_data, batch)
_2, _2.shape, h_mask[batch_size // 2:], neg_t_index

(tensor([[True, True, True,  ..., True, True, True]]),
 torch.Size([1, 1594]),
 tensor([], size=(0, 2), dtype=torch.bool),
 tensor([[952, 952, 952, 952, 952, 952, 952, 952]]))

In [42]:
h_index, h_index.shape, neg_h_index, neg_h_index.shape

(tensor([[952, 952, 952, 952, 952, 952, 952, 952, 952]]),
 torch.Size([1, 9]),
 tensor([], size=(0, 8), dtype=torch.int64),
 torch.Size([0, 8]))