In [1]:
import logging
import os
import sys
import pdb
import subprocess
import torch.nn as nn
import torch
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from transformers.modeling_outputs import TokenClassifierOutput, SequenceClassifierOutput
import numpy as np
from seqeval.metrics import f1_score, precision_score, recall_score
from torch import nn
from typing import Union
import ipdb
from transformers import (
    AutoConfig,
    AutoModelForTokenClassification,
    AutoModel,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
    PreTrainedModel,
    BertLayer
    
)
from utils_ner import NerDataset, Split, get_labels
from torchcrf import CRF

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch_geometric.nn import GATConv


class PSumGraph(PreTrainedModel):
    """Token classifier that averages GAT-based predictions over the last BERT layers.

    A pretrained token-classification backbone produces hidden states for every
    layer. Each of the last ``num_psum_layers`` hidden states is passed through
    its own ``BertLayer``, then through a shared ``GATConv`` and a shared linear
    classifier. The resulting per-layer logits are averaged.
    """

    def __init__(self, model_name, from_tf, config, cache_dir, num_labels):
        """Build the backbone, the per-layer transformer blocks and the graph head.

        Args:
            model_name: Name or path handed to ``from_pretrained`` for the backbone.
            from_tf: Forwarded to ``from_pretrained``.
            config: Model configuration; also used to build the extra ``BertLayer``
                blocks, so it is assumed to be a BERT-style config — TODO confirm.
            cache_dir: Forwarded to ``from_pretrained``.
            num_labels: Size of the classifier output.
        """
        super().__init__(config)
        self.num_labels = num_labels
        self.config = config
        self.biobert = AutoModelForTokenClassification.from_pretrained(
            model_name,
            from_tf=from_tf,
            config=config,
            cache_dir=cache_dir,
        )
        self.num_psum_layers = 4
        # One extra transformer block per tapped hidden state. The list must be
        # populated here: forward() indexes it, and an empty ModuleList raises
        # IndexError on the first access.
        self.pre_layers = nn.ModuleList(
            [BertLayer(config) for _ in range(self.num_psum_layers)]
        )
        self.loss_fct = nn.CrossEntropyLoss()
        self.gat_1 = GATConv(self.config.hidden_size, self.config.hidden_size // 2, heads=2, concat=False, dropout=0.2)

        self.classify = nn.Linear(self.config.hidden_size // 2, self.num_labels)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, graph_data=None, return_dict=None):
        """Run the backbone and return logits averaged over the tapped layers.

        Args:
            input_ids: Token ids forwarded to the backbone.
            attention_mask: Padding mask forwarded to the backbone; assumed to be a
                2-D (batch, seq_len) tensor with 1 for real tokens — TODO confirm.
            token_type_ids: Segment ids forwarded to the backbone.
            labels: Optional target ids; when given, a cross-entropy loss over the
                averaged logits is returned.
            graph_data: Object exposing ``edge_index``. Assumed to be one
                ``[2, num_edges]`` graph over token positions that is shared by
                every sequence in the batch, with all node ids below seq_len —
                TODO confirm against the caller.
            return_dict: Whether to return a ``TokenClassifierOutput``; defaults to
                ``config.use_return_dict``.

        Returns:
            ``TokenClassifierOutput(loss, logits)``, or the tuple ``(loss, logits)``
            / ``(logits,)`` when ``return_dict`` is false.

        Raises:
            ValueError: If ``graph_data`` is not provided.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if graph_data is None:
            raise ValueError("PSumGraph.forward requires `graph_data` with an `edge_index` attribute")

        # Only the hidden states are needed from the backbone: its own loss and its
        # attention maps are not used, so they are not requested. return_dict=True
        # guarantees attribute access works regardless of the config default.
        outputs = self.biobert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_hidden_states=True,
            return_dict=True,
        )
        hidden_states = outputs.hidden_states
        edge_index = graph_data.edge_index.to(hidden_states[-1].device)

        # BertLayer adds its attention_mask to the raw attention scores, so it
        # needs an additive, broadcastable padding mask, not attention
        # probabilities from an earlier layer.
        extended_mask = None
        if attention_mask is not None:
            mask_dtype = hidden_states[-1].dtype
            keep = attention_mask[:, None, None, :].to(mask_dtype)
            extended_mask = (1.0 - keep) * torch.finfo(mask_dtype).min

        all_logits = []
        for i in range(self.num_psum_layers):
            out = self.pre_layers[-1 - i](
                hidden_states=hidden_states[-i - 1],
                attention_mask=extended_mask,
            )[0]

            # GATConv only accepts 2-D node features, so the graph convolution is
            # applied to each sequence separately and the results re-stacked.
            out = torch.stack([self.gat_1(sequence, edge_index) for sequence in out], dim=0)
            all_logits.append(self.classify(out))

        avg_logits = torch.mean(torch.stack(all_logits, dim=0), dim=0)

        loss = None
        if labels is not None:
            loss = self.loss_fct(avg_logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (avg_logits,)
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=avg_logits
        )

In [3]:
# Label set for the NER task stored under sdNER/task10_data.
DATA_DIR = '../datasets/NER'
ENTITY = 'sdNER/task10_data'

# Path of the file that lists the labels for this task.
labels_dir = f'{DATA_DIR}/{ENTITY}/labels.txt'
labels = get_labels(labels_dir)

num_labels = len(labels)
# Index -> label lookup, in the order the labels were returned.
label_map: Dict[int, str] = dict(enumerate(labels))

In [4]:
# Model configuration carrying the label vocabulary in both directions
# (index -> label and label -> index).
config = AutoConfig.from_pretrained(
    'dmis-lab/biobert-base-cased-v1.2',
    cache_dir=None,
    num_labels=num_labels,
    label2id={name: idx for idx, name in enumerate(labels)},
    id2label=label_map,
)

Downloading: 100%|██████████| 1.08k/1.08k [00:00<00:00, 582kB/s]


In [5]:
# Tokenizer for the same checkpoint as the config; the fast implementation is
# explicitly disabled.
tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-base-cased-v1.2', use_fast=False, cache_dir=None)

In [None]:
# while len(src) < MAX_SIZE:
    #     src.append(0)
    #     trg.append(1)
        
    # edge_index = torch.tensor([src, trg], dtype=torch.long).t().contiguous()  
    

In [6]:
def _load_ner_split(split):
    """Build the NerDataset for one split, using the settings shared by all splits."""
    return NerDataset(
        data_dir='../datasets/NER/sdNER/task10_data',
        tokenizer=tokenizer,
        labels=labels,
        model_type=config.model_type,
        max_seq_length=256,
        overwrite_cache=True,
        mode=split,
    )


# Get datasets: training split first (its size is printed), then the dev split.
train_dataset = _load_ner_split(Split.train)
print(len(train_dataset))
eval_dataset = _load_ner_split(Split.dev)

89963


In [None]:
def align_predictions(
    predictions: np.ndarray,
    label_ids: np.ndarray,
    id2label: Optional[Dict[int, str]] = None,
) -> Tuple[List[List[str]], List[List[str]]]:
    """Convert raw scores and gold ids into per-sequence label-name lists.

    Positions whose gold id equals the cross-entropy ``ignore_index`` are dropped
    from both the predictions and the gold labels.

    Args:
        predictions: Score array; the arg-max is taken over axis 2.
        label_ids: Gold label ids, indexed as ``[sequence, position]``.
        id2label: Mapping from label id to label name. When omitted, the
            module-level ``label_map`` is used.

    Returns:
        ``(preds_list, out_label_list)``: two lists with one inner list of label
        names per sequence.
    """
    if id2label is None:
        id2label = label_map

    # Read the sentinel once instead of building a loss module for every token.
    ignore_index = nn.CrossEntropyLoss().ignore_index
    pred_ids = np.argmax(predictions, axis=2)

    preds_list: List[List[str]] = []
    out_label_list: List[List[str]] = []
    for pred_row, gold_row in zip(pred_ids, label_ids):
        keep = gold_row != ignore_index
        preds_list.append([id2label[int(idx)] for idx in pred_row[keep]])
        out_label_list.append([id2label[int(idx)] for idx in gold_row[keep]])

    return preds_list, out_label_list

def compute_metrics(p: EvalPrediction) -> Dict:
    """Score an evaluation pass with entity-level precision, recall and F1.

    The predictions and label ids held by ``p`` are first turned into label-name
    sequences by ``align_predictions``; each metric is then called with the gold
    sequences first and the predicted sequences second.
    """
    predicted, gold = align_predictions(p.predictions, p.label_ids)

    scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    return {metric: score(gold, predicted) for metric, score in scorers}


In [9]:
# Peek at the first item yielded by train_dataset. The loop stops after one
# iteration, so `d` stays bound to that item for later cells.
for d in train_dataset:
    print(d)
    break

InputFeatures(input_ids=[101, 225, 8362, 1161, 107, 1425, 22313, 107, 1231, 25019, 4487, 27466, 21359, 6870, 2225, 175, 23105, 136, 117, 11902, 1389, 3263, 8468, 1185, 20347, 15027, 1995, 2758, 8362, 172, 1874, 2093, 185, 19773, 15027, 186, 6592, 10771, 1231, 25019, 26852, 119, 1508, 1394, 131, 187, 1361, 1465, 5871, 1231, 25019, 4487, 2572, 2495, 5748, 1611, 191, 7409, 9291, 14255, 4487, 2495, 1884, 18312, 118, 1627, 4035, 8468, 182, 6775, 1186, 18630, 131, 120, 120, 189, 119, 1884, 120, 124, 1665, 1179, 2246, 4426, 2087, 1643, 1559, 1964, 1181, 170, 189, 22116, 1116, 1260, 137, 174, 13488, 14426, 7346, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [26]:
# Build parallel source/target lists holding every node pair (i, j) with i < j.
# Alternative kept from the original cell: n = d.input_ids.index(0)
# NOTE(review): nodes run from 0 to n inclusive, i.e. n + 1 = 257 nodes, while
# the datasets use max_seq_length=256 — confirm node 256 is intended.
n = 256
node_count = n + 1
src, trg = [], []
for i in range(node_count):
    successors = range(i + 1, node_count)
    src.extend([i] * len(successors))
    trg.extend(successors)

In [27]:
# Sanity check: number of (source, target) pairs collected in `src`.
len(src)

32896

In [22]:
from torch_geometric.data import Data

# `src` and `trg` are two parallel rows, so stacking them already yields the
# [2, num_edges] COO layout that PyG expects. Transposing here would store an
# [num_edges, 2] tensor, which is the wrong layout for edge_index.
edge_index = torch.tensor([src, trg], dtype=torch.long)
# No node features are attached; only the connectivity is stored.
data = Data(x=None, edge_index=edge_index.contiguous())
