# Few-shot Relation Extraction Tutorial

> Tutorial author: Zhoubo Li (黎洲波, zhoubo.li@zju.edu.cn)

In this tutorial, we use [KnowPrompt](https://arxiv.org/abs/2104.07650v5) to extract relational triples after being trained on few-shot datasets. We hope this tutorial can help you understand the process of few-shot relation extraction.

This tutorial uses `Python3`.

## RE
**Relation extraction** (RE), a key task in information extraction, predicts semantic relations between pairs of entities in unstructured text.

## Few-shot RE
Few-shot relation extraction in DeepKE follows the *pre-train, prompt, and predict* paradigm: it feeds prompt parameters into the attention and cross-attention layers of BERT and fine-tunes them on few-shot datasets, which yields strong performance in low-resource scenarios. The prompt-tuning method is shown in the following figure:
![Prompt-tuning in relation extraction](img/img1.png)

## Dataset
Few-shot RE datasets include RETACRED, SEMEVAL, TACREV, and WIKI80. This tutorial uses [SEMEVAL](https://semeval2.fbk.eu/semeval2.php?location=tasks#T11), which comes from [SemEval-2010 Task 8: Multi-Way Classification of Semantic Relations between Pairs of Nominals](https://arxiv.org/abs/1911.10422). The dataset folder `./data/` is structured as follows:

```
.
├── rel2id.json                     # Relation Label - ID Map
├── temp.txt                        # Relation Label
├── test.txt                        # Test Set
├── train.txt                       # Training Set
└── val.txt                         # Validation Set
```

Each SEMEVAL example has the following format:

```
Data Format:
{
    'token': [tokens in a sentence],
    "h": {
        "name": mention_name,
        "pos" : [start, end]   # token offsets of the mention in the sentence, end exclusive
    },
    "t": {
        "name": mention_name,
        "pos" : [start, end]   # token offsets of the mention in the sentence, end exclusive
    },
    "relation": relation
}
```
SEMEVAL has 9+1 relation types (nine relations plus `Other`); their proportions are shown in the following figure:

![Proportion of each relation type in the dataset](img/img2.png)

## KnowPrompt
DeepKE uses a prompt-tuning method that exploits the semantics of the relation labels, called Knowledge-aware Prompt-tuning (KnowPrompt). The following figure compares the frameworks of fine-tuning (Fig. a), prompt-tuning (Fig. b), and KnowPrompt (Fig. c). The answer words in the prompt are virtual answer words.

![Architecture of low-resource relation extraction](img/img3.png)

## Prepare the runtime environment

In [None]:
# Notebook shell commands: install DeepKE, then download and unpack the few-shot
# RE data archive (presumably this provides the ./data/ folder read below — confirm).
!pip install deepke
!wget 120.27.214.45/Data/re/few_shot/data.tar.gz
!tar -xzvf data.tar.gz

## Import modules

In [None]:
import os
import json
import csv
import time
import pickle
import logging
import shutil
import numpy as np
from tqdm import tqdm
from functools import partial
from collections import Counter
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoConfig, AutoModel, AutoTokenizer
from transformers.modeling_utils import PreTrainedModel
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
from transformers import BertTokenizer, AutoModelForMaskedLM

# Root logging: INFO level, timestamped "time - level - logger - message" lines.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
)
# Module-level logger, used by the preprocessing code.
logger = logging.getLogger(__name__)

## Config parameters

In [None]:
class Config(object):
    """Hyper-parameters and run settings for the few-shot RE tutorial.

    All values are class attributes, so ``cfg = Config()`` simply exposes them on
    an instance. Many fields appear to mirror PyTorch-Lightning ``Trainer``
    arguments and are not all read by this notebook.
    """
    # Read only on the GPT path of convert_examples_to_features, where it selects
    # the "[T1] ... [T5]" prompt template; declared so that path has a value.
    CT_CL = False
    accelerator = None
    accumulate_grad_batches = 1
    amp_backend = 'native'
    amp_level = 'O2'
    auto_lr_find = False
    auto_scale_batch_size = False
    auto_select_gpus = False
    batch_size = 16
    benchmark = False
    check_val_every_n_epoch = 3  # epoch counts are integers
    checkpoint_callback = True
    data_class = 'REDataset'
    data_dir = 'data/'
    default_root_dir = None
    deterministic = False
    devices = None
    distributed_backend = None
    fast_dev_run = False
    flush_logs_every_n_steps = 100
    gpus = None
    gradient_accumulation_steps = 1
    gradient_clip_algorithm = 'norm'
    gradient_clip_val = 0.0
    ipus = None
    limit_predict_batches = 1.0
    limit_test_batches = 1.0
    limit_train_batches = 1.0
    limit_val_batches = 1.0
    litmodel_class = 'BertLitModel'
    load_checkpoint = None
    log_dir = './model_bert.log'
    log_every_n_steps = 50
    log_gpu_memory = None
    logger = True
    lr = 3e-05
    lr_2 = 3e-05
    max_epochs = 30  # epoch counts are integers
    max_seq_length = 256
    max_steps = None
    max_time = None
    min_epochs = None
    min_steps = None
    model_class = 'BertForMaskedLM'
    model_name_or_path = 'bert-base-uncased'
    move_metrics_to_cpu = False
    multiple_trainloader_mode = 'max_size_cycle'
    num_nodes = 1
    num_processes = 1
    num_sanity_val_steps = 2
    num_train_epochs = 30
    num_workers = 8
    optimizer = 'AdamW'  # looked up by name in torch.optim
    overfit_batches = 0.0
    plugins = None
    precision = 32
    prepare_data_per_node = True
    process_position = 0
    profiler = None
    progress_bar_refresh_rate = None
    ptune_k = 7
    reload_dataloaders_every_epoch = False
    reload_dataloaders_every_n_epochs = 0
    replace_sampler_ddp = True
    resume_from_checkpoint = None
    save_path = './model_bert.pt'
    seed = 666
    stochastic_weight_avg = False
    sync_batchnorm = False
    t_lambda = 0.001
    task_name = 'wiki80'  # the SEMEVAL files use the wiki80 record format
    terminate_on_nan = False
    tpu_cores = None
    track_grad_norm = -1
    train_from_saved_model = ''
    truncated_bptt_steps = None
    two_steps = False
    use_prompt = True
    val_check_interval = 1.0
    wandb = False
    weight_decay = 0.01
    weights_save_path = None
    weights_summary = 'top'
    load_path = './model_bert.pt'

cfg = Config()

## Preprocess Dataset

In [None]:
class InputExampleWiki80(object):
    """One relation example: a sentence, two entity spans and their relation label.

    Attributes:
        guid: optional identifier of the example.
        sentence: the sentence as produced by the processor (its ``token`` list).
        span1: subject span as a ``(start, end)`` pair.
        span2: object span as a ``(start, end)`` pair.
        ner1: subject entity type, if any.
        ner2: object entity type, if any.
        label: relation label name.
    """

    def __init__(self, guid, sentence, span1, span2, ner1, ner2, label):
        self.guid, self.sentence = guid, sentence
        self.span1, self.span2 = span1, span2
        self.ner1, self.ner2 = ner1, ner2
        self.label = label

class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Read a tab-separated UTF-8 file and return its rows as lists of strings.

        Args:
            input_file: path of the file to read.
            quotechar: quote character for csv; ``None`` (the default) makes csv
                treat quotes as ordinary characters.
        """
        # newline='' is required by the csv module so that line endings inside
        # quoted fields are preserved instead of being translated.
        with open(input_file, "r", encoding="utf-8", newline="") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))

class wiki80Processor(DataProcessor):
    """Processor for wiki80-format relation data (used here for the SEMEVAL split).

    Each split file (``train.txt``, ``val.txt``, ``test.txt``) holds one example
    per line as a Python dict literal with ``token``, ``h``, ``t`` and
    ``relation`` keys; the label mapping is read from ``rel2id.json``.
    """

    def __init__(self, tokenizer, data_path, use_prompt):
        super().__init__()
        # ``tokenizer`` and ``use_prompt`` are accepted for interface compatibility
        # but are not used by this processor.
        self.data_dir = data_path

    @classmethod
    def _read_json(cls, input_file):
        """Parse *input_file* as one Python-literal dict per line.

        ``ast.literal_eval`` is used instead of ``eval`` so that the data file can
        only contain literals and is never executed as code. Blank lines are skipped.
        """
        import ast

        data = []
        with open(input_file, "r", encoding='utf-8') as reader:
            for line in reader:
                stripped = line.strip()
                if stripped:
                    data.append(ast.literal_eval(stripped))
        return data

    def get_train_examples(self, data_dir):
        """Return the examples stored in ``<data_dir>/train.txt``."""
        return self._create_examples(
            self._read_json(os.path.join(data_dir, "train.txt")), "train")

    def get_dev_examples(self, data_dir):
        """Return the examples stored in ``<data_dir>/val.txt``."""
        return self._create_examples(
            self._read_json(os.path.join(data_dir, "val.txt")), "dev")

    def get_test_examples(self, data_dir):
        """Return the examples stored in ``<data_dir>/test.txt``."""
        return self._create_examples(
            self._read_json(os.path.join(data_dir, "test.txt")), "test")

    def get_labels(self, negative_label="no_relation"):
        """Return the relation-name -> id mapping from ``<data_dir>/rel2id.json``.

        ``negative_label`` is accepted for interface compatibility and is unused.
        """
        with open(os.path.join(self.data_dir, 'rel2id.json'), "r", encoding='utf-8') as reader:
            return json.load(reader)

    def _create_examples(self, dataset, set_type):
        """Wrap parsed records in ``InputExampleWiki80`` objects.

        ``set_type`` is unused. The ``pos`` pairs are passed through unchanged;
        ``convert_examples_to_features`` slices the sentence with them, i.e. treats
        them as end-exclusive ``(start, end)`` token offsets.
        """
        return [
            InputExampleWiki80(
                guid=None,
                sentence=record['token'],
                span1=(record['h']['pos'][0], record['h']['pos'][1]),
                span2=(record['t']['pos'][0], record['t']['pos'][1]),
                ner1=None,
                ner2=None,
                label=record['relation'],
            )
            for record in dataset
        ]

class BaseDataModule(nn.Module):
    """
    Base DataModule: holds the train/val/test datasets and builds their DataLoaders.
    """

    def __init__(self, cfg) -> None:
        """Read ``batch_size`` and ``num_workers`` from *cfg*.

        Raises:
            ValueError: if *cfg* is None (every setting is read from it).
        """
        super().__init__()
        if cfg is None:
            raise ValueError("BaseDataModule requires a cfg object")
        self.cfg = cfg
        self.batch_size = self.cfg.batch_size
        self.num_workers = self.cfg.num_workers

    def get_data_config(self):
        """Return important settings of the dataset, which will be passed to instantiate models.

        Relies on ``self.num_labels``, which subclasses must set.
        """
        return {"num_labels": self.num_labels}

    def prepare_data(self):
        """
        Use this method to do things that might write to disk or that need to be done only from a single GPU in distributed settings (so don't set state `self.x = y`).
        """
        pass

    def setup(self, stage=None):
        """
        Split into train, val, test, and set dims.
        Should assign `torch Dataset` objects to self.data_train, self.data_val, and optionally self.data_test.
        """
        self.data_train = None
        self.data_val = None
        self.data_test = None

    def _eval_dataloader(self, dataset):
        """Unshuffled loader that keeps the final partial batch."""
        # Dropping the last batch here would silently exclude up to
        # batch_size - 1 examples from the evaluation metrics.
        return DataLoader(dataset, shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, drop_last=False)

    def train_dataloader(self):
        # Training drops the last incomplete batch so every step sees a full batch.
        return DataLoader(self.data_train, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, drop_last=True)

    def val_dataloader(self):
        return self._eval_dataloader(self.data_val)

    def test_dataloader(self):
        return self._eval_dataloader(self.data_test)

def convert_examples_to_features(examples, max_seq_length, tokenizer, cfg, rel2id):
    """Convert relation examples into a ``TensorDataset`` of prompt-tuning features.

    Each sentence gets entity markers around its subject and object. It is paired
    with a prompt built from the subject, the tokenizer's mask (BERT-like) or CLS
    (GPT) token, and the object. The pair is tokenized and padded/truncated to
    ``max_seq_length``. The converted features are also pickled to
    ``data/cached_wiki80.pkl``; that file is overwritten on every call and is
    not read back here.

    Args:
        examples: list of ``InputExampleWiki80``-like objects with ``sentence``,
            ``span1``/``span2`` (end-exclusive ``(start, end)``) and ``label``.
        max_seq_length: maximum tokenized length of the (prompt, sentence) pair.
        tokenizer: HuggingFace tokenizer; its class name selects BERT/GPT handling.
        cfg: configuration; ``cfg.CT_CL`` is read on the GPT path only
            (treated as False when missing).
        rel2id: mapping from relation label to integer id.

    Returns:
        TensorDataset of
        ``(input_ids, attention_mask, cls_idx, labels)`` for GPT tokenizers,
        ``(input_ids, attention_mask, token_type_ids, labels, so)`` for BERT tokenizers,
        ``(input_ids, attention_mask, labels, so)`` otherwise,
        where ``so`` holds ``[sub_start, sub_end, obj_start, obj_end]`` token positions.

    Raises:
        ValueError: if the subject or object tokens cannot be found in the
            tokenized input (e.g. truncated away).
    """
    save_file = "data/cached_wiki80.pkl"

    SUBJECT_START = "[subject_start]"
    SUBJECT_END = "[subject_end]"
    OBJECT_START = "[object_start]"
    OBJECT_END = "[object_end]"

    num_tokens = 0
    num_fit_examples = 0
    instances = []

    tokenizer_name = tokenizer.__class__.__name__
    use_bert = "BertTokenizer" in tokenizer_name
    use_gpt = "GPT" in tokenizer_name
    assert not (use_bert and use_gpt), "model cannot be gpt and bert together"

    print('loading..')
    for ex_index, example in enumerate(examples):
        if ex_index % 10000 == 0:
            logger.info("Writing example %d of %d" % (ex_index, len(examples)))

        sentence = example.sentence
        sub_start, sub_end = example.span1[0], example.span1[1]
        obj_start, obj_end = example.span2[0], example.span2[1]

        # Wrap the entities with markers. Spans are end-exclusive, so an end marker
        # is placed in front of the first word *after* the entity.
        tokens = []
        for i, token in enumerate(sentence):
            if i == sub_start:
                tokens.append(SUBJECT_START)
            if i == obj_start:
                tokens.append(OBJECT_START)
            if i == sub_end:
                tokens.append(SUBJECT_END)
            if i == obj_end:
                tokens.append(OBJECT_END)
            tokens.append(token)
        # An entity ending at the last word has end == len(sentence), an index the
        # loop above never reaches; close its marker here so markers stay paired.
        if sub_end == len(sentence):
            tokens.append(SUBJECT_END)
        if obj_end == len(sentence):
            tokens.append(OBJECT_END)

        SUBJECT = " ".join(sentence[sub_start:sub_end])
        OBJECT = " ".join(sentence[obj_start:obj_end])
        SUBJECT_ids = tokenizer(" " + SUBJECT, add_special_tokens=False)['input_ids']
        OBJECT_ids = tokenizer(" " + OBJECT, add_special_tokens=False)['input_ids']

        if use_gpt:
            if getattr(cfg, "CT_CL", False):
                prompt = f"[T1] [T2] [T3] [sub] {OBJECT} [sub] [T4] [obj] {SUBJECT} [obj] [T5] {tokenizer.cls_token}"
            else:
                prompt = f"The relation between [sub] {SUBJECT} [sub] and [obj] {OBJECT} [obj] is {tokenizer.cls_token} ."
        else:
            # Marked subject, mask token, marked object; presumably the model
            # predicts the relation at the mask position (see the lit model).
            prompt = f"[sub] {SUBJECT} [sub] {tokenizer.mask_token} [obj] {OBJECT} [obj] ."

        if ex_index == 0:
            logger.info(f"input text : {' '.join(tokens)}")
            logger.info(f"prompt : {prompt}")
            logger.info(f"label : {example.label}")

        inputs = tokenizer(
            prompt,
            " ".join(tokens),
            truncation="longest_first",
            max_length=max_seq_length,
            padding="max_length",
            add_special_tokens=True,
        )
        ids = inputs['input_ids']
        if use_gpt:
            cls_token_location = ids.index(tokenizer.cls_token_id)

        # First occurrence of the subject/object token sequences (the prompt is
        # the first segment, so matches there are found before the sentence).
        sub_st = sub_ed = obj_st = obj_ed = -1
        for i in range(len(ids)):
            if sub_st == -1 and ids[i:i + len(SUBJECT_ids)] == SUBJECT_ids:
                sub_st, sub_ed = i, i + len(SUBJECT_ids)
            if obj_st == -1 and ids[i:i + len(OBJECT_ids)] == OBJECT_ids:
                obj_st, obj_ed = i, i + len(OBJECT_ids)
            if sub_st != -1 and obj_st != -1:
                break
        if sub_st == -1 or obj_st == -1:
            raise ValueError(
                f"example {ex_index}: subject {SUBJECT!r} or object {OBJECT!r} "
                f"not found in the tokenized input"
            )

        used_len = sum(inputs['attention_mask'])
        num_tokens += used_len
        if used_len <= max_seq_length:
            num_fit_examples += 1

        x = OrderedDict()
        x['input_ids'] = ids
        if use_bert:
            x['token_type_ids'] = inputs['token_type_ids']
        x['attention_mask'] = inputs['attention_mask']
        x['label'] = rel2id[example.label]
        if use_gpt:
            x['cls_token_location'] = cls_token_location
        x['so'] = [sub_st, sub_ed, obj_st, obj_ed]
        instances.append(x)

    with open(file=save_file, mode='wb') as fw:
        pickle.dump(instances, fw)
    print('Finish save preprocessed data to {}.'.format(save_file))

    input_ids = torch.tensor([o['input_ids'] for o in instances])
    attention_mask = torch.tensor([o['attention_mask'] for o in instances])
    labels = torch.tensor([o['label'] for o in instances])
    so = torch.tensor([o['so'] for o in instances])

    # Avoid ZeroDivisionError in the statistics when no examples were given.
    denom = max(len(examples), 1)
    logger.info("Average #tokens: %.2f" % (num_tokens * 1.0 / denom))
    logger.info("%d (%.2f %%) examples can fit max_seq_length = %d" % (num_fit_examples,
                num_fit_examples * 100.0 / denom, max_seq_length))

    if use_gpt:
        cls_idx = torch.tensor([o['cls_token_location'] for o in instances])
        return TensorDataset(input_ids, attention_mask, cls_idx, labels)
    if use_bert:
        token_type_ids = torch.tensor([o['token_type_ids'] for o in instances])
        return TensorDataset(input_ids, attention_mask, token_type_ids, labels, so)
    return TensorDataset(input_ids, attention_mask, labels, so)

def get_dataset(mode, cfg, tokenizer, processor):
    """Load one split with *processor* and convert it into a ``TensorDataset``.

    Args:
        mode: ``"train"``, ``"dev"`` or ``"test"``.
        cfg: configuration; ``cfg.data_dir`` and ``cfg.max_seq_length`` are read.
        tokenizer: tokenizer forwarded to ``convert_examples_to_features``.
        processor: object providing ``get_train_examples``, ``get_dev_examples``,
            ``get_test_examples`` and ``get_labels``.

    Raises:
        ValueError: if *mode* is not one of the supported splits.
    """
    loader_names = {
        "train": "get_train_examples",
        "dev": "get_dev_examples",
        "test": "get_test_examples",
    }
    if mode not in loader_names:
        raise ValueError(f"mode must be in choice [train, dev, test], got {mode!r}")
    examples = getattr(processor, loader_names[mode])(cfg.data_dir)
    return convert_examples_to_features(
        examples, cfg.max_seq_length, tokenizer, cfg, processor.get_labels()
    )

class REDataset(BaseDataModule):
    """Data module for the few-shot RE splits.

    Builds the tokenizer, registers the marker/class/prompt special tokens, and
    creates the train/dev/test datasets in ``setup``.
    """

    def __init__(self, cfg) -> None:
        super().__init__(cfg)
        self.cfg = cfg
        tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path)
        self.tokenizer = tokenizer
        self.processor = wiki80Processor(tokenizer, cfg.data_dir, cfg.use_prompt)

        self.num_labels = len(self.processor.get_labels())

        # Special tokens are registered one group per call, in a fixed order,
        # so every token keeps the same vocabulary id across runs.
        marker_tokens = ["[object_start]", "[object_end]", "[subject_start]", "[subject_end]"]
        class_tokens = [f"[class{i}]" for i in range(1, self.num_labels + 1)]
        for group in (marker_tokens, class_tokens):
            tokenizer.add_special_tokens({'additional_special_tokens': group})

        if "gpt" in cfg.model_name_or_path:
            # GPT-style tokenizers presumably lack CLS/PAD tokens, so add them.
            tokenizer.add_special_tokens({'cls_token': "[CLS]"})
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})

        entity_tokens = ["[sub]", "[obj]"]
        template_tokens = [f"[T{i}]" for i in range(1, 6)]
        for group in (entity_tokens, template_tokens):
            tokenizer.add_special_tokens({'additional_special_tokens': group})

    def setup(self, stage=None):
        """Build the train/dev/test datasets (``stage`` is ignored)."""
        shared = (self.cfg, self.tokenizer, self.processor)
        self.data_train = get_dataset("train", *shared)
        self.data_val = get_dataset("dev", *shared)
        self.data_test = get_dataset("test", *shared)

    def prepare_data(self):
        """Nothing to prepare: all data is read in ``setup``."""
        pass

    def get_tokenizer(self):
        """Return the tokenizer with the special tokens registered."""
        return self.tokenizer

## Metric Functions

In [None]:
def _threshold_predictions(probs, na_id, T1=0.5, T2=0.4):
    """Turn per-class probabilities into a list of predicted class ids per example.

    Every class with probability above ``T1`` is predicted. If none is, the most
    probable class is predicted when its probability exceeds ``T2``; otherwise the
    example is predicted as ``[na_id]`` (no relation).
    """
    preds = []
    for row in probs:
        chosen = [j for j, p in enumerate(row) if p > T1]
        if not chosen:
            best_j = int(np.argmax(row))  # first maximum, as a strict '>' scan would pick
            chosen = [best_j] if row[best_j] > T2 else [na_id]
        preds.append(chosen)
    return preds


def _micro_f1(preds, golds, na_id):
    """Micro F1 of predicted vs. gold id lists, ignoring ``na_id``.

    Precision is taken as 1 when nothing but ``na_id`` is predicted; recall is 0
    when there are no gold relations.
    """
    correct_sys = all_sys = correct_gt = 0
    for pred, gold in zip(preds, golds):
        for rel in gold:
            if rel != na_id:
                correct_gt += 1
                if rel in pred:
                    correct_sys += 1
        all_sys += sum(1 for rel in pred if rel != na_id)

    precision = 1 if all_sys == 0 else correct_sys / all_sys
    recall = 0 if correct_gt == 0 else correct_sys / correct_gt
    if precision + recall == 0:
        return 0
    return 2 * precision * recall / (precision + recall)


def _best_threshold_f1(logits, labels, num_classes):
    """Grid-search ``T2`` in {0.00, 0.01, ..., 0.50}; return ``(best_f1, best_T2)``.

    ``logits`` are converted to probabilities with a sigmoid. Each label row is a
    multi-hot vector whose first ``num_classes`` entries are read; a row with no
    positive class is treated as the NA class ``num_classes``.

    Raises:
        ValueError: if logits and labels have a different number of rows.
    """
    na_id = num_classes
    probs = 1 / (1 + np.exp(-np.asarray(logits)))

    golds = []
    for row in labels:
        positives = [i for i in range(num_classes) if row[i] == 1]
        golds.append(positives or [na_id])
    if len(golds) != len(probs):
        raise ValueError(f"got {len(probs)} logit rows but {len(golds)} label rows")

    best_f1 = best_T2 = 0
    for step in range(51):
        T2 = step / 100.
        f1 = _micro_f1(_threshold_predictions(probs, na_id, T2=T2), golds, na_id)
        # Strict '>' keeps the smallest threshold among ties.
        if f1 > best_f1:
            best_f1, best_T2 = f1, T2
    return best_f1, best_T2


def dialog_f1_eval(logits, labels, num_classes=36):
    """Threshold-searched micro F1 for multi-label relation prediction.

    Args:
        logits: ``(num_examples, num_classes)`` scores (sigmoid is applied here).
        labels: multi-hot label rows, one per example.
        num_classes: number of real relation classes; id ``num_classes`` is NA.
            The default of 36 is the class count these metrics were written for
            (used for the "dialogue" data in BertLitModel).

    Returns:
        ``dict(f1=best_f1, T2=best_T2)``.
    """
    best_f1, best_T2 = _best_threshold_f1(logits, labels, num_classes)
    return dict(f1=best_f1, T2=best_T2)


def f1_eval(logits, labels, num_classes=36):
    """Same metric as ``dialog_f1_eval`` but returned as a ``(best_f1, best_T2)`` tuple.

    Args:
        logits: ``(num_examples, num_classes)`` scores (sigmoid is applied here).
        labels: multi-hot label rows, one per example.
        num_classes: number of real relation classes; id ``num_classes`` is NA.
    """
    return _best_threshold_f1(logits, labels, num_classes)


def f1_score(output, label, rel_num=42, na_num=13):
    """Micro-averaged F1 over all relations except the "no relation" class.

    Args:
        output: per-example class scores; the prediction is ``argmax`` over the
            last axis.
        label: gold class ids, one per example.
        rel_num: number of relation classes. Kept for backward compatibility; the
            micro F1 does not depend on it.
        na_num: id of the "no relation" class, excluded from precision/recall.

    Returns:
        ``dict(f1=micro_f1)``, with ``f1`` equal to 0 when no prediction is correct.
    """
    def to_internal(rel_id):
        # Remap so NA becomes 0 and every real relation is non-zero: ids below NA
        # shift up by one, ids above NA are unchanged.
        if rel_id == na_num:
            return 0
        return rel_id + 1 if rel_id < na_num else rel_id

    n_guess = n_gold = n_correct = 0
    for guess, gold in zip(np.argmax(output, axis=-1), label):
        guess, gold = to_internal(guess), to_internal(gold)
        if guess != 0:
            n_guess += 1
        if gold != 0:
            n_gold += 1
            if guess == gold:
                n_correct += 1

    # n_correct > 0 implies n_gold > 0 and n_guess > 0, so both divisions are safe.
    if n_correct == 0:
        return dict(f1=0)
    recall = n_correct / n_gold
    prec = n_correct / n_guess
    return dict(f1=2 * recall * prec / (recall + prec))

## Model Construction

### Base Model Class

In [None]:
OPTIMIZER = "AdamW"
LR = 5e-5
LOSS = "cross_entropy"
ONE_CYCLE_TOTAL_STEPS = 100

class BaseLitModel(nn.Module):
    """
    Generic training wrapper, modelled on a PyTorch-Lightning module but built on a
    plain ``nn.Module``, that must be initialized with a PyTorch module.

    The ``*_step`` methods are placeholders that use the batch input directly as
    the prediction with an element-wise squared-error "loss"; subclasses are
    presumably meant to override them.
    """

    def __init__(self, model, cfg, device: torch.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')):
        """
        Args:
            model: the wrapped module. If it has a ``.module`` attribute (e.g. a
                ``DataParallel`` wrapper), that inner module is exposed as ``cur_model``.
            cfg: configuration; ``optimizer`` (a ``torch.optim`` class name), ``lr``
                and ``weight_decay`` are read from it.
            device: device the step methods move batches to. The default is
                evaluated once, when this class is defined.

        Raises:
            ValueError: if *cfg* is None.
        """
        super().__init__()
        if cfg is None:
            raise ValueError("BaseLitModel requires a cfg object")
        self.model = model
        self.cur_model = model.module if hasattr(model, 'module') else model
        self.device = device
        self.cfg = cfg
        self.optimizer_class = getattr(torch.optim, self.cfg.optimizer)
        self.lr = self.cfg.lr

    def forward(self, x):
        return self.model(x)

    def _squared_error(self, batch):
        """Move ``(x, y)`` to ``self.device`` and return the element-wise ``(x - y) ** 2``."""
        x, y = batch
        # Tensor.to() returns a new tensor instead of moving in place, so rebind.
        x = x.to(self.device)
        y = y.to(self.device)
        logits = x
        return (logits - y) ** 2

    def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        loss = self._squared_error(batch)
        print("train_loss: ", loss)
        return loss

    def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        loss = self._squared_error(batch)
        print("val_loss: ", loss)

    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        loss = self._squared_error(batch)
        print("test_loss: ", loss)

    def configure_optimizers(self):
        """Build ``optimizer_class`` over ``self.model``'s parameters in two groups.

        Biases and LayerNorm weights get no weight decay; all other parameters use
        ``cfg.weight_decay``.
        """
        no_decay_param = ["bias", "LayerNorm.weight"]

        optimizer_group_parameters = [
            {"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay_param)], "weight_decay": self.cfg.weight_decay},
            {"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay_param)], "weight_decay": 0}
        ]

        return self.optimizer_class(optimizer_group_parameters, lr=self.lr, eps=1e-8)

### Model Subclass

In [None]:
def multilabel_categorical_crossentropy(y_pred, y_true):
    """Multi-label categorical cross-entropy, averaged over the batch.

    ``y_true`` is a 0/1 tensor with the same shape as ``y_pred``. For each row the
    loss is ``log(1 + sum_neg exp(s)) + log(1 + sum_pos exp(-s))``. Both terms are
    computed with ``logsumexp`` over the scores plus an extra zero column, masking
    the excluded classes with a large negative offset.
    """
    # Negate the scores of positive classes so both terms become "log(1 + sum exp)".
    signed = y_pred * (1 - 2 * y_true)
    big = 1e12
    zero_col = torch.zeros_like(signed[..., :1])
    negatives = torch.cat([signed - big * y_true, zero_col], dim=-1)
    positives = torch.cat([signed - big * (1 - y_true), zero_col], dim=-1)
    per_row = torch.logsumexp(negatives, dim=-1) + torch.logsumexp(positives, dim=-1)
    return per_row.mean()

class BertLitModel(BaseLitModel):
    """Prompt-tuning wrapper around a masked LM (``AutoModelForMaskedLM``).

    Relation scores are read from the vocabulary logits at the ``[MASK]``
    position and restricted to the virtual label tokens ``[class1]..[classN]``
    (see ``pvp``). Training adds a knowledge-embedding term (see ``ke_loss``)
    weighted by ``cfg.t_lambda``.
    """
    def __init__(self, model, cfg, tokenizer):
        """Set up loss/eval functions and initialise the virtual-token embeddings.

        Args:
            model: masked LM to wrap (forwarded to the base class).
            cfg: experiment config; this class reads ``data_dir``, ``t_lambda``,
                ``model_name_or_path``, ``weight_decay`` and ``two_steps``.
            tokenizer: tokenizer used to look up ``[MASK]``, ``[classK]``,
                ``[sub]`` and ``[obj]``; presumably these virtual tokens have
                already been added to it (TODO confirm with the data module).
        """
        super().__init__(model, cfg)
        self.tokenizer = tokenizer
        with open(f"{cfg.data_dir}/rel2id.json","r") as file:
            rel2id = json.load(file)

        # Id of the "no relation" class; it is ignored when computing F1.
        Na_num = 0
        for k, v in rel2id.items():
            if k == "NA" or k == "no_relation" or k == "Other":
                Na_num = v
                break
        num_relation = len(rel2id)
        # Data dirs containing "dialogue" use the multi-label loss/metric;
        # everything else is single-label cross-entropy.
        self.loss_fn = multilabel_categorical_crossentropy if "dialogue" in cfg.data_dir else nn.CrossEntropyLoss()
        self.eval_fn = f1_eval if "dialogue" in cfg.data_dir else partial(f1_score, rel_num=num_relation, na_num=Na_num)
        self.best_f1 = 0
        self.t_lambda = cfg.t_lambda

        # Token id of [class1]; ke_loss assumes [class1..N] have consecutive ids.
        self.label_st_id = tokenizer("[class1]", add_special_tokens=False)['input_ids'][0]

        self._init_label_word()

    def _init_label_word(self):
        """Initialise the embeddings of the virtual prompt tokens.

        Each ``[classK]`` receives the mean embedding of relation K's word pieces,
        loaded from ``data/<model basename>.pt``. ``[sub]`` and ``[obj]`` receive
        the mean embedding of a few generic entity-type words. Sets
        ``self.word2label`` to the ids of ``[class1]..[classN]``.
        """
        cfg = self.cfg
        model_name_or_path = cfg.model_name_or_path.split("/")[-1]
        label_path = f"data/{model_name_or_path}.pt"
        # [num_labels, num_tokens] matrix of word-piece ids per relation.
        label_word_idx = torch.load(label_path)
        if "dialogue" in cfg.data_dir:
            # ignore the last (unanswerable) label
            label_word_idx = label_word_idx[:-1]

        num_labels = len(label_word_idx)

        self.cur_model.resize_token_embeddings(len(self.tokenizer))
        with torch.no_grad():
            word_embeddings = self.cur_model.get_input_embeddings()
            continous_label_word = [a[0] for a in self.tokenizer([f"[class{i}]" for i in range(1, num_labels+1)], add_special_tokens=False)['input_ids']]
            # NOTE(review): rows are padded with id 0, so the padding token's
            # embedding is included in this mean; confirm that is intended.
            for label_token, idx in zip(continous_label_word, label_word_idx):
                word_embeddings.weight[label_token] = torch.mean(word_embeddings.weight[idx], dim=0)
            so_word = [a[0] for a in self.tokenizer(["[obj]","[sub]"], add_special_tokens=False)['input_ids']]
            meaning_word = [a[0] for a in self.tokenizer(["person","organization", "location", "date", "country"], add_special_tokens=False)['input_ids']]

            # [sub]/[obj] both start from the average of generic entity-type words.
            entity_type_embedding = torch.mean(word_embeddings.weight[meaning_word], dim=0)
            for idx in so_word:
                word_embeddings.weight[idx] = entity_type_embedding
            # Input and output embeddings are expected to be tied.
            assert torch.equal(self.cur_model.get_input_embeddings().weight, word_embeddings.weight)
            assert torch.equal(self.cur_model.get_input_embeddings().weight, self.cur_model.get_output_embeddings().weight)

        self.word2label = continous_label_word # a continous list

    def forward(self, x):
        """Run the wrapped model on ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        """Return classification loss plus ``t_lambda`` times the KE loss for one batch.

        ``batch`` is ``(input_ids, attention_mask, token_type_ids, labels, so)``;
        ``so`` holds subject/object span boundaries per example (see ``ke_loss``).
        """
        input_ids, attention_mask, token_type_ids , labels, so = batch
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        token_type_ids = token_type_ids.to(self.device)
        labels = labels.to(self.device)
        so = so.to(self.device)
        result = self.model(input_ids, attention_mask, token_type_ids, return_dict=True, output_hidden_states=True)
        logits = result.logits
        # Last-layer hidden states feed the knowledge-embedding loss.
        output_embedding = result.hidden_states[-1]
        logits = self.pvp(logits, input_ids)
        loss = self.loss_fn(logits, labels) + self.t_lambda * self.ke_loss(output_embedding, labels, so)
        return loss

    def validation_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        """Compute the classification loss and collect logits/labels (as numpy) for F1."""
        input_ids, attention_mask, token_type_ids , labels, _ = batch
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        token_type_ids = token_type_ids.to(self.device)
        labels = labels.to(self.device)
        logits = self.model(input_ids, attention_mask, token_type_ids, return_dict=True).logits
        logits = self.pvp(logits, input_ids)
        loss = self.loss_fn(logits, labels)
        return {"loss": loss, "eval_logits": logits.detach().cpu().numpy(), "eval_labels": labels.detach().cpu().numpy()}

    def validation_epoch_end(self, outputs):
        """Compute dev F1 over all ``validation_step`` outputs and track the best.

        Returns:
            ``(f1, best_f1, self.best_f1)``. The middle value is the new best
            when this epoch improved on ``self.best_f1``, otherwise ``-1``.
        """
        logits = np.concatenate([o["eval_logits"] for o in outputs])
        labels = np.concatenate([o["eval_labels"] for o in outputs])

        f1 = self.eval_fn(logits, labels)['f1']
        best_f1 = -1
        if f1 > self.best_f1:
            self.best_f1 = f1
            best_f1 = self.best_f1
        return f1, best_f1, self.best_f1

    def test_step(self, batch, batch_idx):  # pylint: disable=unused-argument
        """Collect label-restricted logits and labels (as numpy) for one test batch."""
        input_ids, attention_mask, token_type_ids , labels, _ = batch
        input_ids = input_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)
        token_type_ids = token_type_ids.to(self.device)
        labels = labels.to(self.device)
        logits = self.model(input_ids, attention_mask, token_type_ids, return_dict=True).logits
        logits = self.pvp(logits, input_ids)
        return {"test_logits": logits.detach().cpu().numpy(), "test_labels": labels.detach().cpu().numpy()}

    def test_epoch_end(self, outputs):
        """Return the F1 score over all ``test_step`` outputs."""
        logits = np.concatenate([o["test_logits"] for o in outputs])
        labels = np.concatenate([o["test_labels"] for o in outputs])

        f1 = self.eval_fn(logits, labels)['f1']
        return f1

    def pvp(self, logits, input_ids):
        """Read MLM logits at the ``[MASK]`` position, restricted to label tokens.

        Converts ``[batch_size, seq_len, vocab_size]`` logits into
        ``[batch_size, num_labels]`` scores. Every sequence must contain
        exactly one mask token.
        """
        mask_id = self.tokenizer(self.tokenizer.mask_token, add_special_tokens = False)['input_ids'][0]
        _, mask_idx = (input_ids == mask_id).nonzero(as_tuple=True)
        bs = input_ids.shape[0]
        # Validate before indexing: otherwise a missing/extra mask surfaces as an
        # opaque shape-mismatch IndexError instead of this message.
        assert mask_idx.shape[0] == bs, "only one mask in sequence!"
        mask_output = logits[torch.arange(bs), mask_idx]
        final_output = mask_output[:,self.word2label]

        return final_output

    def ke_loss(self, logits, labels, so):
        """Knowledge-embedding loss ``||s + r - o||_2`` over the batch.

        Args:
            logits: last-layer hidden states, indexed as ``[batch, position, hidden]``.
            labels: relation ids; ``r`` is the output embedding of token
                ``labels + label_st_id``.
            so: per example ``[sub_start, sub_end, obj_start, obj_end]``; ``s``/``o``
                are mean hidden states over those spans.
        """
        subject_embedding = []
        object_embedding = []
        bsz = logits.shape[0]
        for i in range(bsz):
            subject_embedding.append(torch.mean(logits[i, so[i][0]:so[i][1]], dim=0))
            object_embedding.append(torch.mean(logits[i, so[i][2]:so[i][3]], dim=0))

        subject_embedding = torch.stack(subject_embedding)
        object_embedding = torch.stack(object_embedding)
        # trick: the [classK] token ids are consecutive, so label id + offset = token id
        relation_embedding = self.cur_model.get_output_embeddings().weight[labels+self.label_st_id]

        loss = torch.norm(subject_embedding + relation_embedding - object_embedding, p=2)

        return loss

    def configure_optimizers(self):
        """Create the optimizer with decay / no-decay parameter groups.

        When ``cfg.two_steps`` is set, only the model's first named parameter
        (per the original comment, ``cur_model.bert.embeddings.weight``) is
        optimised. Biases and ``LayerNorm.weight`` get no weight decay.

        NOTE: earlier versions exhausted the parameter generator in the first
        group, leaving the no-decay group empty. Optimizer state saved by those
        versions may not load into this two-group optimizer.
        """
        no_decay_param = ["bias", "LayerNorm.weight"]

        if not self.cfg.two_steps:
            # Materialise: the list is iterated twice below, and a bare generator
            # would leave the second (no-decay) group empty.
            parameters = list(self.cur_model.named_parameters())
        else:
            # cur_model.bert.embeddings.weight
            parameters = [next(self.cur_model.named_parameters())]
        optimizer_group_parameters = [
            {"params": [p for n, p in parameters if not any(nd in n for nd in no_decay_param)], "weight_decay": self.cfg.weight_decay},
            {"params": [p for n, p in parameters if any(nd in n for nd in no_decay_param)], "weight_decay": 0}
        ]

        optimizer = self.optimizer_class(optimizer_group_parameters, lr=self.lr, eps=1e-8)
        return optimizer

## Preprocess the inputs

### Few-shot sampling

In [None]:
# Random seeds for the k-shot sampling; one split directory is generated per seed.
Seed = list(range(1, 6))
# Name of the output sub-directory for the sampled splits.
mode = 'k-shot'
# Source file the k-shot training examples are drawn from.
data_file = 'train.txt'

def get_labels(path, name,  negative_label="no_relation"):
    """Read a relation-extraction data file into a list of example dicts.

    Despite its name, this returns the parsed examples, not a label list.
    Each non-empty line of ``path + "/" + name`` is parsed as one literal
    (JSON-style) dict.

    Args:
        path: directory containing the file.
        name: file name inside ``path``.
        negative_label: unused; kept for backward compatibility.

    Returns:
        List of parsed examples, in file order.
    """
    import ast

    features = []
    with open(path + "/" + name, "r", encoding="utf-8") as f:
        for line in f:
            line = line.rstrip()
            if line:
                # literal_eval parses only literals, unlike eval(), which would
                # execute arbitrary code contained in the data file.
                features.append(ast.literal_eval(line))
    return features

path = 'data'

output_dir = os.path.join(path, mode)
dataset = get_labels(path, data_file)

for seed in Seed:

    # NOTE: `dataset` is shuffled in place, so each seed's permutation is applied
    # on top of the previous seed's (still deterministic for a fixed Seed list).
    np.random.seed(seed)
    np.random.shuffle(dataset)

    # Set up dir
    k = 8
    setting_dir = os.path.join(output_dir, f"{k}-{seed}")
    os.makedirs(setting_dir, exist_ok=True)

    # Group examples by relation, preserving the shuffled order.
    label_list = {}
    for line in dataset:
        label_list.setdefault(line['relation'], []).append(line)

    # Take the first k examples of every relation as the k-shot training set.
    with open(os.path.join(setting_dir, "train.txt"), "w") as f:
        for label in label_list:
            for line in label_list[label][:k]:
                f.write(json.dumps(line) + '\n')

    # Every split directory needs the label map and the full val/test sets
    # (previously only the hard-coded 8-1 directory received them).
    for split_file in ('rel2id.json', 'val.txt', 'test.txt'):
        shutil.copyfile(os.path.join(path, split_file), os.path.join(setting_dir, split_file))

### Obtain labels

In [None]:
def split_label_words(tokenizer, label_list):
    """Tokenize every relation label into word-piece ids and pad them into a matrix.

    Labels are normalised before tokenizing:
    - lower-cased;
    - any "(...)" direction suffix dropped;
    - ":" and "_" turned into spaces;
    - the whole-word prefixes "per"/"org" expanded to "person"/"organization".
    ``no_relation`` maps to the word "None". The normalised label and its ids
    are printed.

    Args:
        tokenizer: HuggingFace-style tokenizer (``encode`` and ``__call__``).
        label_list: relation label strings, in label-id order.

    Returns:
        ``[len(label_list), max_len]`` tensor of token ids, right-padded with 0.
    """
    import re

    label_word_list = []
    for label in label_list:
        if label == 'no_relation':
            label_word_id = tokenizer.encode('None', add_special_tokens=False)
        else:
            label = label.lower().split("(")[0]
            label = label.replace(":", " ").replace("_", " ")
            # Expand abbreviations only as whole words. A substring replace would
            # also rewrite e.g. WIKI80's "performer"/"operator" ("personformer").
            label = re.sub(r"\bper\b", "person", label)
            label = re.sub(r"\borg\b", "organization", label)
            label_word_id = tokenizer(label, add_special_tokens=False)['input_ids']
            print(label, label_word_id)
        label_word_list.append(torch.tensor(label_word_id))
    return pad_sequence(label_word_list, batch_first=True, padding_value=0)


model_name_or_path = cfg.model_name_or_path

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
with open("data/rel2id.json", "r") as file:
    # rel2id keys, in file order, are the relation labels.
    label_list = list(json.load(file))

t = split_label_words(tokenizer, label_list)

# Save under the model's base name: BertLitModel._init_label_word loads
# f"data/{cfg.model_name_or_path.split('/')[-1]}.pt". Using the full path here
# breaks for hub ids like "org/model" or local paths (missing directory, and
# the saved file would never be found at load time).
label_word_path = os.path.join("data", f"{model_name_or_path.split('/')[-1]}.pt")
with open(label_word_path, "wb") as file:
    torch.save(t, file)

## Auxiliary functions

In [None]:
def set_seed(cfg):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        cfg: config object providing an integer ``seed``.
    """
    import random

    random.seed(cfg.seed)
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)
    torch.cuda.manual_seed_all(cfg.seed)

def logging(log_dir, s, print_=True, log_=True):
    """Print ``s`` and/or append it as one line to the file ``log_dir``.

    NOTE: this shadows the stdlib ``logging`` module name within this notebook.

    Args:
        log_dir: path of the log file (despite the name, a file path);
            ``''`` disables file output.
        s: message string.
        print_: echo ``s`` to stdout when true.
        log_: append ``s`` to ``log_dir`` when true.
    """
    if print_:
        print(s)
    if log_dir == '' or not log_:
        return
    with open(log_dir, 'a+') as log_file:
        log_file.write(s + '\n')

## Train the model
### Model training

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

data = REDataset(cfg)
data_config = data.get_data_config()

config = AutoConfig.from_pretrained(cfg.model_name_or_path)
config.num_labels = data_config["num_labels"]

model = AutoModelForMaskedLM.from_pretrained(cfg.model_name_or_path, config=config)

# if torch.cuda.device_count() > 1:
#     print("Let's use", torch.cuda.device_count(), "GPUs!")
#     model = torch.nn.DataParallel(model, device_ids = list(range(torch.cuda.device_count())))

model.to(device)

lit_model = BertLitModel(model, cfg, data.tokenizer)
data.setup()

# Load the checkpoint once (rather than once per field). Map it onto the current
# device so that a GPU-saved checkpoint can be resumed on CPU.
checkpoint = None
if cfg.train_from_saved_model != '':
    checkpoint = torch.load(cfg.train_from_saved_model, map_location=device)
    model.load_state_dict(checkpoint["checkpoint"])
    print("load saved model from {}.".format(cfg.train_from_saved_model))
    lit_model.best_f1 = checkpoint["best_f1"]

optimizer = lit_model.configure_optimizers()
if checkpoint is not None:
    optimizer.load_state_dict(checkpoint["optimizer"])
    print("load saved optimizer from {}.".format(cfg.train_from_saved_model))

# One optimizer/scheduler step per `grad_accum_steps` batches. A trailing partial
# window at the end of an epoch also steps, hence the ceiling division. The
# scheduler horizon must match the real number of optimizer steps; otherwise the
# LR would hit 0 early.
grad_accum_steps = max(1, int(cfg.gradient_accumulation_steps))
steps_per_epoch = (len(data.train_dataloader()) + grad_accum_steps - 1) // grad_accum_steps
num_training_steps = steps_per_epoch * cfg.num_train_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_training_steps * 0.1, num_training_steps=num_training_steps)
log_step = 20


logging(cfg.log_dir,'-' * 89, print_=True)
logging(cfg.log_dir, time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + ' INFO : START TO TRAIN ', print_=True)
logging(cfg.log_dir,'-' * 89, print_=True)

for epoch in range(cfg.num_train_epochs):
    model.train()
    train_loader = data.train_dataloader()
    num_batch = len(train_loader)
    total_loss = 0
    log_loss = 0
    for index, train_batch in enumerate(train_loader):
        loss = lit_model.training_step(train_batch, index)
        total_loss += loss.item()
        log_loss += loss.item()
        # Scale so that the accumulated gradient is the mean over the window.
        (loss / grad_accum_steps).backward()

        if (index + 1) % grad_accum_steps == 0 or (index + 1) == num_batch:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

        if log_step > 0 and (index+1) % log_step == 0:
            cur_loss = log_loss / log_step
            logging(cfg.log_dir,
                '| epoch {:2d} | step {:4d} | lr {} | train loss {:5.3f}'.format(
                    epoch, (index+1), scheduler.get_last_lr(), cur_loss)
                , print_=True)
            log_loss = 0
    avrg_loss = total_loss / num_batch
    logging(cfg.log_dir,
        '| epoch {:2d} | train loss {:5.3f}'.format(
            epoch, avrg_loss))

    model.eval()
    with torch.no_grad():
        val_outputs = []
        for val_index, val_batch in enumerate(tqdm(data.val_dataloader())):
            val_outputs.append(lit_model.validation_step(val_batch, val_index))
        # `best` is -1 unless this epoch improved the best dev F1.
        f1, best, best_f1 = lit_model.validation_epoch_end(val_outputs)
        logging(cfg.log_dir,'-' * 89)
        logging(cfg.log_dir,
            '| epoch {:2d} | dev_result: {}'.format(epoch, f1))
        logging(cfg.log_dir,'-' * 89)
        logging(cfg.log_dir,
            '| best_f1: {}'.format(best_f1))
        logging(cfg.log_dir,'-' * 89)
        if cfg.save_path != "" and best != -1:
            save_path = cfg.save_path
            torch.save({
                'epoch': epoch,
                'checkpoint': model.state_dict(),
                'best_f1': best_f1,
                'optimizer': optimizer.state_dict()
            }, save_path
            , _use_new_zipfile_serialization=False)
            logging(cfg.log_dir,
                '| successfully save model at: {}'.format(save_path))
            logging(cfg.log_dir,'-' * 89)

### Model prediction

In [None]:
def test(cfg, model, lit_model, data):
    """Evaluate ``lit_model`` on the test split, log the result and return it.

    Args:
        cfg: config providing ``log_dir`` for ``logging``.
        model: underlying network; switched to eval mode before inference.
        lit_model: wrapper exposing ``test_step`` / ``test_epoch_end``.
        data: data module exposing ``test_dataloader()``.

    Returns:
        The value of ``lit_model.test_epoch_end`` (the test F1). Previously it
        was only logged and then discarded.
    """
    model.eval()
    with torch.no_grad():
        test_outputs = [
            lit_model.test_step(test_batch, test_index)
            for test_index, test_batch in enumerate(tqdm(data.test_dataloader()))
        ]
        f1 = lit_model.test_epoch_end(test_outputs)
    logging(cfg.log_dir,
        '| test_result: {}'.format(f1))
    logging(cfg.log_dir,'-' * 89)
    return f1

test(cfg, model, lit_model, data)