# Установка

In [1]:
!pip install transformers
!pip install datasets

Collecting transformers
  Downloading transformers-4.18.0-py3-none-any.whl (4.0 MB)
[K     |████████████████████████████████| 4.0 MB 12.5 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.5.1-py3-none-any.whl (77 kB)
[K     |████████████████████████████████| 77 kB 6.4 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 42.2 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 43.5 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.49-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 30.9 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found e

## Датасеты

In [2]:
from datasets import load_dataset

datasets = load_dataset("snli")

Downloading builder script:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/938 [00:00<?, ?B/s]

Downloading and preparing dataset snli/plain_text (download: 90.17 MiB, generated: 65.51 MiB, post-processed: Unknown size, total: 155.68 MiB) to /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b...


Downloading:   0%|          | 0.00/1.93k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.26M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/65.9M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.26M [00:00<?, ?B/s]

Dataset snli downloaded and prepared to /root/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
datasets

DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 550152
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
})

In [4]:
check_dataset = datasets["validation"]
check_dataset

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 10000
})

## Токенизатор

In [5]:
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
inputs = tokenizer("Hello, my dog is cute", "kek lol", return_tensors="pt")
inputs

Downloading:   0%|          | 0.00/481 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

{'input_ids': tensor([[    0, 31414,     6,   127,  2335,    16, 11962,     2,     2,  1071,
           330, 29784,     2]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

## Граф

In [6]:
%reload_ext autoreload

In [7]:
%autoreload 2
from knowledge_graph import CauseNet

In [8]:
cause_net = CauseNet()
cause_net.get_tokens('vegan diet leads to poor nutrition', 'vegan diet leads to diabetes')

Downloading CauseNet...
Resulting size of causenet.jsonl.bz2:
Removing old file...
Unpacking bz2 file...
Resulting size of causenet.jsonl:
Extracting connections...
Amount of edges: 197806
Building graph...
Amount of vertexes: 51863


({'anxiety',
  'cancer',
  'cardiovascular_disease',
  'changes',
  'childhood_obesity',
  'condition',
  'conditions',
  'depression',
  'diabetes',
  'disease',
  'diseases',
  'health_problems',
  'heart_disease',
  'illness',
  'illnesses',
  'imbalances',
  'increase',
  'inflammation',
  'injury',
  'insomnia',
  'nutrition',
  'obesity',
  'overweight',
  'poor_health',
  'poor_nutrition',
  'problems',
  'stress',
  'symptoms',
  'type_2_diabetes',
  'weight_gain'},
 {('anxiety', 'diabetes'),
  ('cancer', 'diabetes'),
  ('cardiovascular_disease', 'diabetes'),
  ('changes', 'diabetes'),
  ('childhood_obesity', 'diabetes'),
  ('condition', 'diabetes'),
  ('conditions', 'diabetes'),
  ('depression', 'diabetes'),
  ('diabetes', 'diabetes'),
  ('disease', 'diabetes'),
  ('diseases', 'diabetes'),
  ('health_problems', 'diabetes'),
  ('heart_disease', 'diabetes'),
  ('illness', 'diabetes'),
  ('illnesses', 'diabetes'),
  ('imbalances', 'diabetes'),
  ('increase', 'diabetes'),
  ('infl

# Код классов

## Датасет

In [9]:
MAX_OBJECTS = 20

In [64]:
import numpy as np
import tqdm

class KGraphDataset(torch.utils.data.Dataset):
    def __init__(self, nli_dataset, kg, tokenizer, debug=False):
      self.tokens = []
      self.kg_kg_mask = []
      self.full_mask = []
      self.entity_tokens = []
      
      dataset_len = len(nli_dataset)
      if debug:
        print("1 sample:\n")

      for data_ind in tqdm.tqdm(range(dataset_len)):
        # tokenize
        s_prem = nli_dataset[data_ind]['premise']
        s_hyp = nli_dataset[data_ind]['hypothesis']
        #s_prem = "poor nutrition"
        #s_hyp = ""
        if debug and (data_ind == 0):
          print("premise:", s_prem)
          print("hypothesis:", s_hyp)
        inputs = tokenizer(s_prem, s_hyp)
        
        # split to words
        word_ids = inputs.word_ids()
        seq_ids = inputs.sequence_ids()
        word_cnts = [0, 0]
        for pos, elem in enumerate(seq_ids):
          if elem is not None:
            word_cnts[elem] = max(word_cnts[elem], word_ids[pos] + 1)

        bad_syms = ",.!?- "
        word_lst = []
        for s_ind in range(2):
          if (s_ind == 0):
            s  = s_prem
          else:
            s = s_hyp
          for num in range(0, word_cnts[s_ind]):
            inter = inputs.word_to_chars(batch_or_word_index=num, sequence_index=s_ind)
            word = s[inter.start:inter.end]
            for sym in bad_syms:
              word = word.replace(sym, "")
            #if len(word) <= 3:
            #  word = ""
            word_lst.append(word.lower())
        
        if debug and (data_ind == 0):
          print("Words:", word_lst)
        
        # create mask matrices
        # example https://github.com/microsoft/CodeBERT/blob/master/GraphCodeBERT/codesearch/run.py#L206

        all_objects, all_edges = kg.get_tokens(s_prem, s_hyp)
        if debug and (data_ind == 0):
          print("From KG")
          print("Objects:", all_objects)
          print("Edges:", all_edges)
        object_to_id = {}
        for id, obj in enumerate(all_objects):
          object_to_id[obj] = id

        # kg x kg
        obj_cnt = len(all_objects)
        kg_kg_mask = np.eye(MAX_OBJECTS, dtype=bool)
        for edge in all_edges:
          obj_from, obj_to = edge
          if ((obj_from in all_objects) and (obj_to in all_objects)):
            id_from = object_to_id[obj_from]
            id_to = object_to_id[obj_to]
            kg_kg_mask[id_from][id_to] = True
        
        if debug and (data_ind == 0):
          print()
          print("kg2kg mask:")
          print(kg_kg_mask * 1)

        # kg x text
        tokens_cnt = len(inputs['input_ids'])
        kg_text_mask = np.eye(tokens_cnt + MAX_OBJECTS, dtype=bool)

        for token_pos in range(tokens_cnt):
          word_id = word_ids[token_pos]
          seq_id = seq_ids[token_pos]
          if word_id is None:
            # special token (False row)
            kg_text_mask[token_pos][token_pos] = False
          else:
            correct_id = word_id + seq_id * word_cnts[0]
            word = word_lst[correct_id]
            for id, obj in enumerate(all_objects):
              if (word in obj):
                kg_text_mask[token_pos][tokens_cnt + id] = True

        if debug and (data_ind == 0):
          print()
          print("Full mask")
          print(kg_text_mask * 1)

        obj_tokens = tokenizer(list(all_objects))['input_ids']
        if debug and (data_ind == 0):
          print()
          print("objects tokens:")
          print(obj_tokens)
        obj_cnt = len(obj_tokens)
        obj_dict = {}
        for obj_ind in range(MAX_OBJECTS):
          if obj_ind < obj_cnt:
            obj_dict["e" + str(obj_ind)] = torch.tensor(obj_tokens[obj_ind])
          else:
            obj_dict["e" + str(obj_ind)] = torch.tensor([0])
        
        if debug and (data_ind == 0):
          print()
          print("objects dics:")
          print(obj_dict)

        self.tokens.append(torch.tensor(inputs['input_ids']))
        self.kg_kg_mask.append(torch.tensor(kg_kg_mask))
        self.full_mask.append(torch.tensor(kg_text_mask))
        self.entity_tokens.append(obj_dict)

        if data_ind > 20:
          break

    def __getitem__(self, idx):
        return {
            "input_ids": self.tokens[idx],
            "kg_kg_mask": self.kg_kg_mask[idx],
            "full_mask": self.full_mask[idx],
            "e_input_ids": self.entity_tokens[idx]
        }

    def __len__(self):
        return len(self.tokens)

In [65]:
check_kgdata = KGraphDataset(check_dataset, cause_net, tokenizer, debug=False)

  0%|          | 21/10000 [00:00<04:28, 37.23it/s]


In [66]:
check_kgdata[0]

{'e_input_ids': {'e0': tensor([    0, 22197,     2]),
  'e1': tensor([    0, 48324,     2]),
  'e10': tensor([0]),
  'e11': tensor([0]),
  'e12': tensor([0]),
  'e13': tensor([0]),
  'e14': tensor([0]),
  'e15': tensor([0]),
  'e16': tensor([0]),
  'e17': tensor([0]),
  'e18': tensor([0]),
  'e19': tensor([0]),
  'e2': tensor([   0,   90, 8508, 5225,    2]),
  'e3': tensor([    0, 25785,  1825,     2]),
  'e4': tensor([   0, 9738, 2629,    2]),
  'e5': tensor([    0,  4892, 42216,     2]),
  'e6': tensor([0]),
  'e7': tensor([0]),
  'e8': tensor([0]),
  'e9': tensor([0])},
 'full_mask': tensor([[False, False, False,  ..., False, False, False],
         [False,  True, False,  ..., False, False, False],
         [False, False,  True,  ..., False, False, False],
         ...,
         [False, False, False,  ...,  True, False, False],
         [False, False, False,  ..., False,  True, False],
         [False, False, False,  ..., False, False,  True]]),
 'input_ids': tensor([    0,  9058,  

In [67]:
len(check_kgdata)

22

## Collator

https://huggingface.co/course/chapter7/3?fw=pt#finetuning-distilbert-with-the-trainer-api

In [69]:
import collections
import numpy as np

#from transformers import default_data_collator
from transformers import DataCollatorForLanguageModeling
from transformers import DataCollatorWithPadding

lm_data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
pad_data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

def lm_kg_data_collator(features):
  input_ids_list = []
  obj_size = [0 for ind in range(MAX_OBJECTS)]
  mask1_size = 0
  mask2_size = 0
  for feature in features:
    input_ids_list.append({"input_ids": feature["input_ids"]})
    mask1_size = max(mask1_size, feature["kg_kg_mask"].shape[0])
    mask2_size = max(mask2_size, feature["full_mask"].shape[0])
    for obj_ind in range(MAX_OBJECTS):
      obj_size[obj_ind] = max(obj_size[obj_ind], feature["e_input_ids"]["e" + str(obj_ind)].shape[0]) 

  for ind, feature in enumerate(features):
    mask1 = torch.zeros((mask1_size, mask1_size))
    cur_size = feature["kg_kg_mask"].shape[0]
    mask1[:cur_size, :cur_size] = feature["kg_kg_mask"]
    input_ids_list[ind]["kg_kg_mask"] = mask1

    mask2 = torch.zeros((mask2_size, mask2_size))
    cur_size = feature["full_mask"].shape[0]
    mask2[:cur_size, :cur_size] = feature["full_mask"]
    input_ids_list[ind]["full_mask"] = mask2

    for obj_ind in range(MAX_OBJECTS):
      new_tokens = torch.ones(obj_size[obj_ind], dtype=int)
      name = "e" + str(obj_ind)
      cur_size = feature["e_input_ids"][name].shape[0]
      new_tokens[:cur_size] = feature["e_input_ids"][name]
      input_ids_list[ind][name] = new_tokens

  batch_data = lm_data_collator(input_ids_list)
  _ = batch_data.pop("attention_mask")
  return batch_data

In [70]:
samples = [check_kgdata[i] for i in range(2)]
batch = lm_kg_data_collator(samples)
print(batch)

for chunk in batch["input_ids"]:
  print(f"\n'>>> {tokenizer.decode(chunk)}'")

{'input_ids': tensor([[    0, 50264,   390,    32, 16105,   150,  1826,     7,   213,  8368,
             4,     2,     2, 50264,  7502,    32, 31164, 15364, 50264,  1826,
             7,   213,  8368,    71,    95,  4441,  4592,     4,     2],
        [    0, 50264, 50264,    32, 16105,   150,  1826,     7,   213,  8368,
             4,     2,     2, 50264,   693,    32,  1826,  8368,     4,     2,
             1,     1,     1,     1,     1,     1,     1,     1,     1]]), 'kg_kg_mask': tensor([[[1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0.],
         [0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0.],
   

In [71]:
from torch.utils.data import DataLoader


check_loader = DataLoader(check_kgdata, batch_size=2, shuffle=False, collate_fn=lm_kg_data_collator)

In [72]:
for batch in check_loader:
  #print(batch)
  for chunk in batch["input_ids"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")
  for chunk in batch["e0"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")
  for chunk in batch["e1"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")
  for chunk in batch["e2"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")
  for chunk in batch["e8"]:
    print(f"\n'>>> {tokenizer.decode(chunk)}'")
  break


'>>> <s>Two women<mask><mask> while holding to go packages.</s></s>The sisters are hugging goodbye while holding to go packages after just eating<mask>.</s>'

'>>> <s>Two women<mask> embracing while holding to<mask> packagesizo</s></s>Two woman are holding packages.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad>'

'>>> <s>women</s>'

'>>> <s>women</s>'

'>>> <s>packages</s>'

'>>> <s>packages</s>'

'>>> <s>trouble</s>'

'>>> <s>characteristics</s><pad>'

'>>> <s>'

'>>> <s>'


## Модель

### KENLI

In [73]:
import torch.nn as nn
import torch  

from transformers.models.bert.modeling_bert import BertAttention

class KE_cell(nn.Module):   
  def __init__(self, config):
    super(KE_cell, self).__init__()
    self.TR1 = BertAttention(config)
    self.TR2 = BertAttention(config)
    
  def forward(self, last_hidden_state_orig, last_hidden_state_kg, extended_attention_mask1, extended_attention_mask2): 
    # TR1
    new_hidden_state_kg = self.TR1(hidden_states=last_hidden_state_kg, attention_mask=extended_attention_mask1)[0]
    # TR2
    new_hidden_state_all = torch.cat((last_hidden_state_orig, new_hidden_state_kg), 1)
    out_hidden_state_all = self.TR2(hidden_states=new_hidden_state_all, attention_mask=extended_attention_mask2)[0]
    # make output
    out_hidden_state_orig = out_hidden_state_all[:, 0:last_hidden_state_orig.shape[1], :]
    out_hidden_state_kg = new_hidden_state_kg # no update in TR2
    return out_hidden_state_orig, out_hidden_state_kg

In [74]:
class KE_net(nn.Module):   
  def __init__(self, config, num_cells=2):
    super(KE_net, self).__init__()
    self.num_cells = num_cells
    self.KE_cells = nn.ModuleList([KE_cell(config) for i in range(self.num_cells)])
    
  def get_extended_attention_mask(self, attention_mask: torch.Tensor) -> torch.Tensor:
        # https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_utils.py#L250
        extended_attention_mask = attention_mask[:, None, :, :]
        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask
    
  def forward(self, last_hidden_state_orig, last_hidden_state_kg, attention_mask1, attention_mask2):
    hidden_state_orig = last_hidden_state_orig
    hidden_state_kg = last_hidden_state_kg
    extended_attention_mask1 = self.get_extended_attention_mask(attention_mask1)
    extended_attention_mask2 = self.get_extended_attention_mask(attention_mask2)

    for i, cell in enumerate(self.KE_cells):
      hidden_state_orig, hidden_state_kg = cell(hidden_state_orig, hidden_state_kg, extended_attention_mask1, extended_attention_mask2)

    return hidden_state_orig

### Base model

In [87]:
from transformers import RobertaTokenizer, RobertaModel

base_model = RobertaModel.from_pretrained("roberta-base")
inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = base_model(**inputs)
last_hidden_states = outputs.last_hidden_state

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Check

https://github.com/huggingface/transformers/blob/v4.18.0/src/transformers/models/roberta/modeling_roberta.py#L844

In [88]:
for batch in check_loader:
  save_batch = batch
  break
print(save_batch)

{'input_ids': tensor([[    0,  9058,   390,    32, 16105,   150,  1826,     7,   213,  8368,
             4,     2,     2,   133,  7502,    32, 31164, 15364,   150,  1826,
             7,   213,  8368,    71,    95,  4441,  4592,     4,     2],
        [    0,  9058,   390, 50264, 16105,   150,  1826,     7,   213, 50264,
             4,     2,     2,  9058,   693,    32,  1826, 50264,     4,     2,
             1,     1,     1,     1,     1,     1,     1,     1,     1]]), 'kg_kg_mask': tensor([[[1., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0.],
         [0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0.],
   

In [91]:
base_model.encoder

RobertaEncoder(
  (layer): ModuleList(
    (0): RobertaLayer(
      (attention): RobertaAttention(
        (self): RobertaSelfAttention(
          (query): Linear(in_features=768, out_features=768, bias=True)
          (key): Linear(in_features=768, out_features=768, bias=True)
          (value): Linear(in_features=768, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (output): RobertaSelfOutput(
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (intermediate): RobertaIntermediate(
        (dense): Linear(in_features=768, out_features=3072, bias=True)
        (intermediate_act_fn): GELUActivation()
      )
      (output): RobertaOutput(
        (dense): Linear(in_features=3072, out_features=768, bias=True)
        (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise

In [89]:
base_model.eval()
base_model(input_ids=save_batch['input_ids'])

BaseModelOutputWithPoolingAndCrossAttentions([('last_hidden_state',
                                               tensor([[[-0.0660,  0.0997,  0.0027,  ..., -0.0710, -0.0715, -0.0303],
                                                        [ 0.0174,  0.1783, -0.3063,  ..., -0.1834, -0.1642,  0.0852],
                                                        [-0.0605,  0.2268, -0.0690,  ..., -0.2827, -0.1948,  0.2758],
                                                        ...,
                                                        [ 0.1107, -0.0557,  0.0218,  ...,  0.1439, -0.1075,  0.1116],
                                                        [-0.0554,  0.0943, -0.0205,  ..., -0.1013, -0.0826, -0.0552],
                                                        [-0.0902,  0.0641, -0.0218,  ...,  0.1019, -0.1305,  0.0712]],
                                               
                                                       [[-0.0797,  0.0936,  0.0028,  ..., -0.0699, -0.0713,  0.021

In [90]:
embedding_output = base_model.embeddings(input_ids=save_batch['input_ids'])

#encoder_outputs = base_model.encoder(embedding_output)
hidden_states = embedding_output
for i in range(12):
  layer_module = base_model.encoder.layer[i]
  layer_outputs = layer_module(hidden_states)
  hidden_states = layer_outputs[0]
encoder_outputs = layer_outputs

sequence_output = encoder_outputs[0]
pooled_output = base_model.pooler(sequence_output)
(sequence_output, pooled_output)

(tensor([[[-0.0660,  0.0997,  0.0027,  ..., -0.0710, -0.0715, -0.0303],
          [ 0.0174,  0.1783, -0.3063,  ..., -0.1834, -0.1642,  0.0852],
          [-0.0605,  0.2268, -0.0690,  ..., -0.2827, -0.1948,  0.2758],
          ...,
          [ 0.1107, -0.0557,  0.0218,  ...,  0.1439, -0.1075,  0.1116],
          [-0.0554,  0.0943, -0.0205,  ..., -0.1013, -0.0826, -0.0552],
          [-0.0902,  0.0641, -0.0218,  ...,  0.1019, -0.1305,  0.0712]],
 
         [[-0.0797,  0.0936,  0.0028,  ..., -0.0699, -0.0713,  0.0216],
          [-0.0295,  0.2340, -0.3292,  ..., -0.2086, -0.2004,  0.0538],
          [-0.1984,  0.3206, -0.1063,  ..., -0.4021, -0.2245,  0.2310],
          ...,
          [ 0.0672,  0.0668,  0.0516,  ..., -0.2326, -0.1463,  0.2416],
          [ 0.0672,  0.0668,  0.0516,  ..., -0.2326, -0.1463,  0.2416],
          [ 0.0672,  0.0668,  0.0516,  ..., -0.2326, -0.1463,  0.2416]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[-0.0197, -0.2081, -0.2334,  ..., -0.1573, -0.0

### Full model

In [92]:
class Model_KG(nn.Module):   
  def __init__(self, config, base_bert):
    super(Model_KG, self).__init__()
    self.base_bert = base_bert
    self.ke_net = KE_net(config)
    
  def forward(self, input_ids, e_input_ids_list, kg_kg_mask, full_mask, labels=None):
    # get vectors for KG objects
    with torch.no_grad():
      kg_features = []
      for e_input_ids in e_input_ids_list:
        kg_embedding_output = self.base_bert.embeddings(input_ids=e_input_ids)
        kg_feature = self.base_bert.encoder(kg_embedding_output)[0]
        kg_feature = torch.mean(kg_feature, axis=1, keepdim=True)
        kg_features.append(kg_feature)
      kg_hidden_states = torch.cat(kg_features, dim=1)

    # embeddings
    embedding_output = self.base_bert.embeddings(input_ids=input_ids)

    # encoder first part
    hidden_states = embedding_output
    for i in range(6):
      layer_module = self.base_bert.encoder.layer[i]
      layer_outputs = layer_module(hidden_states)
      hidden_states = layer_outputs[0]

    # KE net
    hidden_states = self.ke_net(hidden_states, kg_hidden_states, kg_kg_mask, full_mask)

    # encoder second part
    for i in range(6, 12):
      layer_module = self.base_bert.encoder.layer[i]
      layer_outputs = layer_module(hidden_states)
      hidden_states = layer_outputs[0]
    encoder_outputs = layer_outputs

    # pooler
    sequence_output = encoder_outputs[0]
    pooled_output = self.base_bert.pooler(sequence_output)
    return (sequence_output, pooled_output)

In [94]:
model = Model_KG(base_model.config, base_model)

In [96]:
model.eval()
input_ids = save_batch["input_ids"]
e_input_ids_list = [save_batch["e" + str(i)] for i in range(MAX_OBJECTS)]
kg_kg_mask = save_batch["kg_kg_mask"]
full_mask = save_batch["full_mask"]
model(input_ids, e_input_ids_list, kg_kg_mask, full_mask)

(tensor([[[-0.0205,  0.0995,  0.0070,  ..., -0.1532, -0.1575,  0.0663],
          [-0.0193,  0.1826,  0.0542,  ..., -0.2380, -0.1368,  0.0270],
          [-0.0897,  0.2068,  0.0977,  ..., -0.3018, -0.2212,  0.0691],
          ...,
          [ 0.1066,  0.1178,  0.1226,  ..., -0.1081, -0.2112,  0.0888],
          [-0.0582,  0.0982, -0.0502,  ..., -0.1726, -0.1558,  0.0279],
          [-0.0246,  0.0586,  0.0756,  ..., -0.1726, -0.1744,  0.1800]],
 
         [[-0.0707,  0.0405, -0.0061,  ..., -0.1496, -0.1303,  0.0394],
          [-0.0382,  0.0903, -0.0107,  ..., -0.2419, -0.1378,  0.0144],
          [-0.1394,  0.1180,  0.0828,  ..., -0.2581, -0.1193,  0.0287],
          ...,
          [ 0.0606, -0.0157,  0.1719,  ..., -0.2911, -0.2711,  0.0307],
          [ 0.0606, -0.0157,  0.1719,  ..., -0.2911, -0.2711,  0.0307],
          [ 0.0606, -0.0157,  0.1719,  ..., -0.2911, -0.2711,  0.0307]]],
        grad_fn=<NativeLayerNormBackward0>),
 tensor([[-0.0149, -0.2043, -0.1803,  ..., -0.1007, -0.0

# Обучение

In [99]:
model

Model_KG(
  (base_bert): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0): RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,)

In [97]:
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=5e-5)

In [None]:
num_epochs = 3
num_training_steps = num_epochs * len(check_loader)

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
print(device)

cpu


In [None]:
from tqdm.auto import tqdm

progress_bar = tqdm(range(num_training_steps))

model.train()

for epoch in range(num_epochs):
  for batch in check_loader:
    batch = {k: v.to(device) for k, v in batch.items()}
    outputs = model(batch['input_ids'])
    print(outputs)
    #loss = outputs.loss
    #loss.backward()
    #optimizer.step()
    #lr_scheduler.step()
    #optimizer.zero_grad()
    #progress_bar.update(1)
    break
  break

  0%|          | 0/33 [00:00<?, ?it/s]

BaseModelOutputWithPoolingAndCrossAttentions(last_hidden_state=tensor([[[-5.8904e-02,  1.0552e-01,  1.2006e-02,  ..., -8.2651e-02,
          -7.2221e-02, -2.6057e-02],
         [ 5.0650e-02,  3.5228e-02, -6.3365e-02,  ..., -4.1882e-01,
          -2.8699e-02,  2.9377e-02],
         [-1.4988e-01,  7.8505e-02, -6.8389e-02,  ..., -2.2490e-01,
          -3.0423e-01,  1.6624e-01],
         ...,
         [ 1.2868e-01, -9.4526e-02, -1.3775e-02,  ...,  2.1351e-01,
          -1.4605e-01, -3.7178e-02],
         [-6.5021e-02,  1.1034e-01, -1.0636e-02,  ..., -1.0035e-01,
          -1.2613e-01, -4.3427e-02],
         [-1.0618e-01,  2.5286e-02, -7.3273e-02,  ...,  5.5727e-02,
          -1.6850e-01,  6.0278e-02]],

        [[-8.6268e-02,  1.2173e-01, -3.9041e-02,  ..., -5.8905e-02,
          -7.9802e-02, -1.8740e-02],
         [-7.1007e-03,  2.2275e-01, -2.1930e-01,  ..., -2.2412e-02,
          -2.0509e-01,  9.2653e-02],
         [-1.6151e-02,  1.7875e-01, -4.8446e-02,  ..., -8.7615e-02,
          -7.