In [1]:
import copy
import json
import os
import random
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Sampler


class BaseData:
    """Container for a continual-learning classification dataset with train/val/test splits.

    Holds the full label vocabulary read from ``id2label.json`` (``label_list``) plus an
    incrementally grown mapping of the labels registered so far (``id2label`` /
    ``label2id``) and the task in which each label id was introduced (``label2task_id``).
    Subclasses fill ``train_data`` / ``val_data`` / ``test_data`` in
    :meth:`read_and_preprocess`; each is a dict mapping a label to a list of example dicts.
    """

    # Number of examples kept per label when ``args.debug`` is truthy.
    DEBUG_SAMPLES_PER_LABEL = 10

    def __init__(self, args):
        self.args = args
        self.label_list = self._read_labels()
        self.id2label, self.label2id = [], {}
        self.label2task_id = {}
        self.train_data, self.val_data, self.test_data = None, None, None

    def _read_labels(self):
        """Read the label names from ``<data_path>/<dataset_name>/id2label.json``.

        :return: only the label names, so that ids can be assigned from 0 in the order
            labels are registered with :meth:`add_labels`.
        """
        path = os.path.join(self.args.data_path, self.args.dataset_name, 'id2label.json')
        # JSON is UTF-8 by spec; do not depend on the platform default encoding.
        with open(path, encoding='utf-8') as f:
            return json.load(f)

    def read_and_preprocess(self, **kwargs):
        raise NotImplementedError

    def add_labels(self, cur_labels, task_id):
        """Register labels not seen before, giving them consecutive ids tagged with ``task_id``."""
        for label in cur_labels:
            # label2id mirrors id2label, so use it for an O(1) membership test.
            if label not in self.label2id:
                self.id2label.append(label)
                self.label2id[label] = len(self.label2id)
                self.label2task_id[self.label2id[label]] = task_id

    def filter(self, labels, split='train'):
        """Collect the examples of ``labels`` from one split, with labels mapped to ids.

        Args:
            labels: a single label or a list of labels (must be registered via add_labels).
            split: ``'train'``, ``'dev'``/``'val'`` or ``'test'`` (case-insensitive).

        Returns:
            A list of deep-copied example dicts whose ``"labels"`` entry is replaced by the
            label id; the stored split data is not modified. When ``args.debug`` is truthy,
            only the first ``DEBUG_SAMPLES_PER_LABEL`` examples of each label are kept.

        Raises:
            ValueError: if ``split`` is not a known split name.
        """
        if not isinstance(labels, list):
            labels = [labels]
        split = split.lower()
        if split == 'train':
            split_data = self.train_data
        elif split in ('dev', 'val'):
            split_data = self.val_data
        elif split == 'test':
            split_data = self.test_data
        else:
            raise ValueError(f"unknown split {split!r}; expected 'train', 'dev'/'val' or 'test'")

        # `debug` is optional on args; without it, return the full data.
        limit = self.DEBUG_SAMPLES_PER_LABEL if getattr(self.args, 'debug', False) else None
        res = []
        for label in labels:
            # Slice before copying so debug mode does not deep-copy data it discards; the copy
            # keeps the id remapping below from mutating the cached split.
            res += copy.deepcopy(split_data[label][:limit])
        for example in res:
            example["labels"] = self.label2id[example["labels"]]
        return res


class BaseDataset(Dataset):
    """Map-style ``torch.utils.data.Dataset`` over an in-memory collection of examples.

    ``data`` is either a list of examples, or a dict whose values are lists of examples;
    a dict is flattened into one list by concatenating its values in key order.
    """

    def __init__(self, data):
        if isinstance(data, dict):
            data = [example for examples in data.values() for example in examples]
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Examples are returned exactly as stored; no augmentation happens here.
        return self.data[idx]

In [2]:
import copy
import json
import os
import random

from tqdm import tqdm

class FewRelData(BaseData):
    """FewRel relation-extraction data whose sentences carry entity-marker tokens.

    Raw sentences are expected to wrap the subject in ``[E11] ... [E12]`` and the object in
    ``[E21] ... [E22]``. :meth:`read_and_preprocess` shuffles every relation's examples,
    splits them into train / val / test and tokenizes each split with :meth:`preprocess`.
    """

    def __init__(self, args):
        super().__init__(args)
        # Order matters: (subject start, subject end, object start, object end).
        self.entity_markers = ["[E11]", "[E12]", "[E21]", "[E22]"]

    def remove_entity_markers(self, input_ids, marker_ids=(30522, 30523, 30524, 30525)):
        """Strip the entity-marker ids from ``input_ids`` and locate the entity spans.

        Args:
            input_ids: token ids of one sentence.
            marker_ids: ids of the ``(subject_start, subject_end, object_start, object_end)``
                markers. The default presumably matches the ids the markers receive when
                appended to the ``bert-base-uncased`` vocabulary; pass the tokenizer's actual
                ids to be safe.

        Returns:
            ``(ids_without_markers, subject_st, subject_ed, object_st, object_ed)``. The
            positions index into ``ids_without_markers``; the ``*_ed`` positions are inclusive
            (last entity token).

        Raises:
            KeyError: if one of the markers is absent from ``input_ids``.
        """
        subject_start, subject_end, object_start, object_end = marker_ids
        start_markers = {subject_start, object_start}
        all_markers = set(marker_ids)
        stripped = []
        positions = {}
        for token_id in input_ids:
            if token_id not in all_markers:
                stripped.append(token_id)
            elif token_id in start_markers:
                # Start marker: the entity begins at the next kept token.
                positions[token_id] = len(stripped)
            else:
                # End marker: the entity ended at the previously kept token.
                positions[token_id] = len(stripped) - 1
        return (stripped, positions[subject_start], positions[subject_end],
                positions[object_start], positions[object_end])

    def preprocess(self, raw_data, tokenizer):
        """Tokenize one split of one relation and build the example dicts.

        Args:
            raw_data: ``{"sentence": [...], "labels": [...]}`` with parallel lists of
                marker-annotated sentences and their relation labels.
            tokenizer: tokenizer that knows the four entity markers as tokens.

        Returns:
            A list of example dicts, restricted to ``args.columns`` when that attribute is set.

        Raises:
            ValueError: if a tokenized sentence lacks one of the entity markers.
        """
        sentences = raw_data['sentence']
        if not sentences:
            # Nothing to tokenize, e.g. a relation with too few examples for this split.
            return []
        marker_ids = tuple(tokenizer.convert_tokens_to_ids(m) for m in self.entity_markers)
        subject_start_marker, _, object_start_marker, _ = marker_ids
        columns = getattr(self.args, 'columns', None)
        encoded = tokenizer(sentences)

        res = []
        for raw_sentence, input_ids, label in zip(sentences, encoded['input_ids'], raw_data['labels']):
            missing = [m for m, m_id in zip(self.entity_markers, marker_ids) if m_id not in input_ids]
            if missing:
                raise ValueError(f"entity marker(s) {missing} not found in tokenized sentence: {raw_sentence!r}")

            sentence = raw_sentence
            for marker in self.entity_markers:
                sentence = sentence.replace(marker, '')
            # Collapse every whitespace run left behind by the removed markers.
            sentence = ' '.join(sentence.split())

            input_ids_without_marker, subject_st, subject_ed, object_st, object_ed = \
                self.remove_entity_markers(input_ids, marker_ids)
            ins = {
                'sentence': sentence,
                'input_ids': input_ids,  # default: add marker to the head entity and tail entity
                'subject_marker_st': input_ids.index(subject_start_marker),
                'object_marker_st': input_ids.index(object_start_marker),
                'labels': label,
                'input_ids_without_marker': input_ids_without_marker,
                'subject_st': subject_st,
                'subject_ed': subject_ed,
                'object_st': object_st,
                'object_ed': object_ed,
            }
            if columns is not None:
                ins = {k: v for k, v in ins.items() if k in columns}
            res.append(ins)
        return res

    def read_and_preprocess(self, tokenizer, seed=None, train_size=420, val_size=140):
        """Load ``data_with_marker.json``, split every relation and tokenize the splits.

        Args:
            tokenizer: passed through to :meth:`preprocess`.
            seed: if given, the per-relation shuffle is reproducible. A private RNG is used,
                so the global ``random`` state is not modified.
            train_size: number of shuffled examples per relation used for training.
            val_size: number of following examples used for validation; the rest is test.

        Populates ``self.train_data`` / ``self.val_data`` / ``self.test_data`` as
        ``{relation label: [example dict, ...]}``.
        """
        path = os.path.join(self.args.data_path, self.args.dataset_name, 'data_with_marker.json')
        with open(path, encoding='utf-8') as f:
            raw_data = json.load(f)

        # random.Random(seed) yields the same shuffle as random.seed(seed) + random.shuffle.
        rng = random.Random(seed) if seed is not None else random

        train_data, val_data, test_data = {}, {}, {}
        for label in tqdm(raw_data.keys(), desc="Load FewRel data"):
            samples = raw_data[label]
            rng.shuffle(samples)
            train_raw = {"sentence": [], "labels": []}
            val_raw = {"sentence": [], "labels": []}
            test_raw = {"sentence": [], "labels": []}
            for idx, sample in enumerate(samples):
                if idx < train_size:
                    bucket = train_raw
                elif idx < train_size + val_size:
                    bucket = val_raw
                else:
                    bucket = test_raw
                bucket["sentence"].append(' '.join(sample["tokens"]))
                bucket["labels"].append(sample["relation"])

            train_data[label] = self.preprocess(train_raw, tokenizer)
            val_data[label] = self.preprocess(val_raw, tokenizer)
            test_data[label] = self.preprocess(test_raw, tokenizer)

        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data



In [11]:
from transformers import AutoTokenizer, set_seed

# BERT tokenizer extended with the four entity markers that wrap the subject
# ([E11] ... [E12]) and the object ([E21] ... [E22]) spans in FewRel sentences.
tokenizer = AutoTokenizer.from_pretrained(
    'bert-base-uncased',
    additional_special_tokens=["[E11]", "[E12]", "[E21]", "[E22]"],
    use_fast=True,
)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [12]:
class Args:
    """Minimal configuration for loading the FewRel data.

    Attributes:
        data_path: root directory holding one sub-directory per dataset.
        dataset_name: dataset sub-directory (contains id2label.json and data_with_marker.json).
        debug: when truthy, ``BaseData.filter`` keeps only the first 10 examples per label.
        seed: seed for the per-relation train/val/test shuffle (None = unseeded).
    """

    def __init__(self, data_path="datasets", dataset_name="FewRel", debug=False, seed=42):
        self.data_path = data_path  # path to your data
        self.dataset_name = dataset_name  # name of your dataset
        # BaseData.filter reads args.debug, so the attribute must exist.
        self.debug = debug
        self.seed = seed

# Create the args object, then load and tokenize the FewRel splits reproducibly.
args = Args()

data = FewRelData(args=args)
data.read_and_preprocess(tokenizer, seed=args.seed)

Load FewRel data: 100%|██████████| 80/80 [00:02<00:00, 31.30it/s]


In [None]:
# Peek at a few preprocessed training examples. Dumping the whole training split
# prints every example of every relation and floods the notebook output.
NUM_LABELS_TO_SHOW = 2
NUM_SAMPLES_PER_LABEL = 2

for label, samples in list(data.train_data.items())[:NUM_LABELS_TO_SHOW]:
    print(f"Label: {label}")
    for sample in samples[:NUM_SAMPLES_PER_LABEL]:
        print(sample)
        print("\n")

In [9]:
from transformers import BertModel

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
import logging
from peft import get_peft_model, LoraConfig, TaskType, PeftModel

logger = logging.getLogger(name=__name__)  # module-level logger shared by the classes below


class PeftFeatureExtractor(nn.Module):
    """BERT encoder wrapper that turns a batch of token ids into per-example features.

    The encoder can be run in three ways:
      * the frozen copy of the pretrained BERT (``use_origin=True``; requires ``config.frozen``),
      * a parameter-efficient variant chosen by ``config.peft_type`` (``"lora"`` adapters
        managed through PEFT, or soft ``"prompt"`` vectors) when ``indices`` are given,
      * the plain ``self.bert`` otherwise.

    The encoder output is reduced according to ``extract_mode``: ``"cls"``,
    ``"mean_pooling"``, ``"mask"``, ``"entity"`` or ``"entity_marker"``.
    """

    def __init__(self, config):
        super().__init__()
        self.device = config.device
        self.dataset = config.dataset_name

        self.bert = BertModel.from_pretrained(config.model_name_or_path)
        # Make room in the embedding matrix for the additional special tokens.
        self.bert.resize_token_embeddings(
            self.bert.config.vocab_size + config.additional_special_tokens_len
        )

        self.origin_bert = None
        self.peft_bert = None
        self.peft_type = getattr(config, "peft_type", None)
        self.peft_init = getattr(config, "peft_init", None)

        # Soft prompts for peft_type == "prompt", selected per example by `indices` in forward.
        self.prompts = nn.ParameterList()
        self.pre_seq_len = getattr(config, "pre_seq_len", None)
        self.n_layer = self.bert.config.num_hidden_layers
        self.n_head = self.bert.config.num_attention_heads
        self.n_embd = self.bert.config.hidden_size // self.bert.config.num_attention_heads
        self.hidden_size = self.bert.config.hidden_size
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)

        if config.task_name == "RelationExtraction":
            self.extract_mode = "entity_marker"
        else:
            raise NotImplementedError

        if config.frozen:
            logger.info("freeze the parameters of the pretrained language model.")
            for param in self.bert.parameters():
                param.requires_grad = False
            # Separate frozen copy (taken at construction time) used by use_origin=True.
            self.origin_bert = copy.deepcopy(self.bert)
            for param in self.origin_bert.parameters():
                param.requires_grad = False

    def add_adapter(self, task_id):
        """Wrap a fresh copy of ``self.bert`` with a new LoRA adapter named ``task-{task_id}``.

        This replaces ``self.peft_bert``; previously injected adapters are not carried over.
        """
        # TODO: support more peft types like prefix tuning, prompt tuning and so on.
        peft_config = LoraConfig(
            task_type=TaskType.FEATURE_EXTRACTION, inference_mode=False, r=8, lora_alpha=16, lora_dropout=0.1,
            target_modules=["key", "query", "value"],
        )
        adapter_name = f"task-{task_id}"
        self.peft_bert = get_peft_model(copy.deepcopy(self.bert), peft_config, adapter_name)
        self.peft_bert.print_trainable_parameters()
        logger.info(f"inject {self.peft_type} into the pretrain model, name is {adapter_name}")

    def save_and_load_all_adapters(self, task_id, save_dir, save=True):
        """Optionally save ``peft_bert``, then rebuild it with adapters ``task-0`` .. ``task-{task_id}``.

        Adapters are loaded from ``{save_dir}/task-{i}`` onto a fresh copy of ``self.bert``.
        """
        if save:
            self.peft_bert.save_pretrained(save_dir)
        self.peft_bert = PeftModel.from_pretrained(
            copy.deepcopy(self.bert),
            f"{save_dir}/task-0",
            adapter_name="task-0"
        )
        for i in range(1, task_id + 1):
            self.peft_bert.load_adapter(f"{save_dir}/task-{i}", adapter_name=f"task-{i}")

    def load_adapter(self, task_id):
        """Activate the adapter ``task-{task_id}`` on ``peft_bert``."""
        adapter_name = f"task-{task_id}"
        self.peft_bert.set_adapter(adapter_name)

    def get_prompts_by_indices(self, indices, attention_mask):
        """Gather each example's soft prompt and extend the attention mask to cover it.

        Args:
            indices: one prompt index per example (ints or 0-d tensors).
            attention_mask: ``(batch, seq_len)`` mask of the real tokens.

        Returns:
            ``(past_key_values, attention_mask, prompt)``. ``past_key_values`` is always
            ``None`` (prefix tuning is not implemented). ``attention_mask`` is
            ``(batch, pre_seq_len + seq_len)`` with the mask's dtype and device. ``prompt``
            stacks ``self.prompts[i]`` for every index.
        """
        batch_size = attention_mask.size(0)
        prompt_attention_mask = torch.ones(
            batch_size, self.pre_seq_len, dtype=attention_mask.dtype, device=attention_mask.device
        )
        attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)

        prompt = torch.stack([self.prompts[int(idx)] for idx in indices])
        past_key_values = None

        return past_key_values, attention_mask, prompt

    def forward(
            self,
            input_ids,
            inputs_embeds=None,
            attention_mask=None,
            extract_mode=None,
            use_origin=False,
            indices=None,
            **kwargs
    ):
        """Encode ``input_ids`` and reduce the encoder output to one feature vector per example.

        Args:
            input_ids: ``(batch, seq_len)`` token ids.
            inputs_embeds: not used as an input; kept for interface compatibility.
            attention_mask: optional mask. When omitted, every non-zero id is attended
                (id 0 is treated as padding).
            extract_mode: feature-extraction mode; defaults to ``self.extract_mode``.
            use_origin: run the frozen copy of the pretrained encoder instead.
            indices: per-example task/prompt indices; together with ``peft_type`` they select
                the PEFT path.
            **kwargs: mode-specific inputs (``mask_pos``; ``subject_st``/``subject_ed``/
                ``object_st``/``object_ed``; ``subject_marker_st``/``object_marker_st``) and
                optionally ``past_key_values`` for the plain encoder.

        Returns:
            Features of shape ``(batch, dim)``, or ``(batch, 2 * dim)`` for the entity modes.

        Raises:
            RuntimeError: ``use_origin=True`` without a frozen copy of BERT.
            NotImplementedError: unsupported ``peft_type`` (including ``"prefix"``) or ``extract_mode``.
        """
        if attention_mask is None:
            # Long (not bool) so it concatenates cleanly with the prompt mask.
            attention_mask = (input_ids != 0).long()

        prompt_len = 0
        if use_origin:
            if self.origin_bert is None:
                raise RuntimeError(
                    "use_origin=True needs the frozen copy of BERT, which is only created when config.frozen is set"
                )
            outputs = self.origin_bert(
                input_ids,
                attention_mask=attention_mask,
            )
        elif self.peft_type is not None and indices is not None:
            if self.peft_type == "lora":
                # One adapter per forward pass: the batch is assumed to share indices[0]'s task.
                self.load_adapter(int(indices[0]))
                outputs = self.peft_bert(
                    input_ids,
                    attention_mask=attention_mask,
                )
            elif self.peft_type == "prefix":
                # get_prompts_by_indices does not build past_key_values, and a prompt-extended
                # mask without them cannot match the attention scores.
                raise NotImplementedError(
                    "prefix tuning is not implemented: get_prompts_by_indices does not build past_key_values"
                )
            elif self.peft_type == "prompt":
                _, prompt_attention_mask, prompt = self.get_prompts_by_indices(indices, attention_mask)
                prompt_len = prompt.shape[1]
                token_embeds = self.bert.embeddings.word_embeddings(input_ids)
                outputs = self.bert(
                    inputs_embeds=torch.cat([prompt, token_embeds], dim=1),  # (batch, prompt_len + sent_len, dim)
                    attention_mask=prompt_attention_mask,
                )
            else:
                raise NotImplementedError
        else:
            # only for tuning
            outputs = self.bert(
                input_ids,
                attention_mask=attention_mask,
                past_key_values=kwargs.get("past_key_values"),
            )

        # `outputs[0] = ...` fails on transformers ModelOutput and on tuples, so slice into a
        # local instead. Dropping the soft-prompt positions realigns them with input_ids.
        sequence_output = outputs[0][:, prompt_len:, :] if prompt_len else outputs[0]

        if extract_mode is None:
            extract_mode = self.extract_mode
        return self._extract_features(extract_mode, outputs, sequence_output, attention_mask, kwargs)

    def _extract_features(self, extract_mode, outputs, sequence_output, attention_mask, kwargs):
        """Reduce the encoder outputs to per-example features according to ``extract_mode``."""
        if extract_mode == "cls":
            return outputs[1]  # pooler output, (batch, dim)
        if extract_mode == "mean_pooling":
            mask = attention_mask.unsqueeze(-1).to(sequence_output.dtype)
            # clamp avoids 0/0 for rows without any attended token
            return (sequence_output * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

        batch_idx = torch.arange(sequence_output.size(0), device=sequence_output.device)
        if extract_mode == "mask":
            return sequence_output[batch_idx, kwargs["mask_pos"]]
        if extract_mode == "entity":
            subj_st, subj_ed = kwargs["subject_st"], kwargs["subject_ed"]
            obj_st, obj_ed = kwargs["object_st"], kwargs["object_ed"]
            features = []
            for i in range(sequence_output.size(0)):
                # Average the token states inside each (inclusive) entity span.
                subj = sequence_output[i][subj_st[i]: subj_ed[i] + 1].mean(0)
                obj = sequence_output[i][obj_st[i]: obj_ed[i] + 1].mean(0)
                features.append(torch.cat([subj, obj]))
            return torch.stack(features, dim=0)
        if extract_mode == "entity_marker":
            ss_emb = sequence_output[batch_idx, kwargs["subject_marker_st"]]
            os_emb = sequence_output[batch_idx, kwargs["object_marker_st"]]
            return torch.cat([ss_emb, os_emb], dim=-1)  # (batch, 2 * dim)
        raise NotImplementedError


In [17]:
from copy import deepcopy
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.modeling_outputs import ModelOutput


@dataclass
class ExpertOutput(ModelOutput):
    """Output of ``ExpertModel.forward``.

    transformers requires ``ModelOutput`` subclasses to be dataclasses; without the
    decorator, instantiating this class raises ``TypeError`` in current versions.

    Attributes:
        loss: cross-entropy loss; ``None`` when no labels were given.
        logits: classifier scores, ``(batch, num_labels)``.
        hidden_states: features fed to the classifier, ``(batch, classifier_hidden_size)``.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[torch.FloatTensor] = None


class ExpertModel(nn.Module):
    """Feature extractor plus a linear classifier that grows by one block of labels per task."""

    def __init__(self, config):
        super().__init__()
        self.device = config.device

        self.feature_extractor = PeftFeatureExtractor(config)
        bert_hidden = self.feature_extractor.bert.config.hidden_size
        self.hidden_size = bert_hidden

        self.num_old_labels = 0
        self.num_labels = 0

        # Relation features concatenate subject and object vectors, doubling the width.
        width_factor = 2 if config.task_name == "RelationExtraction" else 1
        self.classifier_hidden_size = width_factor * bert_hidden

        self.classifier = nn.Linear(self.classifier_hidden_size, self.num_labels)

    @torch.no_grad()
    def new_task(self, num_labels):
        """Widen the classifier by ``num_labels`` outputs, keeping the weights of existing labels."""
        old_weight = self.classifier.weight.data.clone()
        old_bias = self.classifier.bias.data.clone()
        self.num_old_labels, self.num_labels = self.num_labels, self.num_labels + num_labels
        grown = nn.Linear(self.classifier_hidden_size, self.num_labels, device=self.device)
        grown.weight.data[:self.num_old_labels] = old_weight
        grown.bias.data[:self.num_old_labels] = old_bias
        self.classifier = grown

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Classify a batch; the loss is only computed when ``labels`` is given."""
        features = self.feature_extractor(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        logits = self.classifier(features)
        loss = F.cross_entropy(logits, labels) if labels is not None else None
        return ExpertOutput(loss=loss, logits=logits, hidden_states=features)


In [18]:
from typing import List, Dict, Any, Optional, Union

import torch
from transformers import PreTrainedTokenizerBase
from transformers.file_utils import PaddingStrategy

@dataclass
class CustomCollatorWithPadding:
    """Collate a list of example dicts into one batch dict, right-padding sequences with 0.

    The class is a dataclass so ``CustomCollatorWithPadding(tokenizer)`` works; without the
    decorator, the annotated fields create no ``__init__`` and the call raises ``TypeError``.

    Integer fields become one value per example. List fields are right-padded with 0 to the
    longest sequence in the batch, rounded up to a multiple of ``pad_to_multiple_of`` when it
    is set. With ``return_tensors="pt"`` every field becomes a ``torch.LongTensor``;
    otherwise plain lists are returned. ``tokenizer``, ``padding`` and ``max_length`` mirror
    the fields of transformers' padding collator but are not used by this padding logic.
    """
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def pad_to_same_length(self, batch_data):
        """Pad (or pass through) one field's values collected across the batch."""
        if isinstance(batch_data[0], int):
            return torch.LongTensor(batch_data) if self.return_tensors == "pt" else batch_data
        target_len = max(len(seq) for seq in batch_data)
        if self.pad_to_multiple_of:
            # Integer ceiling: round target_len up to the next multiple.
            target_len = -(-target_len // self.pad_to_multiple_of) * self.pad_to_multiple_of
        padded = [seq + [0] * (target_len - len(seq)) for seq in batch_data]
        return torch.LongTensor(padded) if self.return_tensors == "pt" else padded

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        batch = {key: [ins[key] for ins in features] for key in features[0].keys()}
        return {key: self.pad_to_same_length(values) for key, values in batch.items()}



In [None]:
from utils import relation_data_augmentation, CustomCollatorWithPadding

class ExpertTrainer:
    """Trains and evaluates an ``ExpertModel`` on continual relation-extraction tasks."""

    def __init__(self, args, **kwargs):
        self.optimizer = None
        self.task_idx = 0
        self.args = args

    def run(self, data, model, tokenizer, label_order, seed=None):
        """Train and evaluate on the first task of ``label_order``.

        Registers the task's labels on ``data``, augments the training examples with
        ``relation_data_augmentation``, grows the classifier, trains, evaluates on the task's
        test split and saves a checkpoint. The loop exits after the first task.

        Returns:
            dict with per-task lists ``cur_acc`` and ``total_acc`` (both the micro-F1 of the
            evaluated task) and ``total_hit`` (1 for evaluated tasks, else 0).
        """
        if seed is not None:
            set_seed(seed)
        default_data_collator = CustomCollatorWithPadding(tokenizer)

        seen_labels = []
        all_cur_acc = [0] * self.args.num_tasks
        all_total_acc = [0] * self.args.num_tasks
        all_total_hit = [0] * self.args.num_tasks
        marker_ids = tuple([tokenizer.convert_tokens_to_ids(c) for c in self.args.additional_special_tokens])
        logger.info(f"marker ids: {marker_ids}")
        for task_idx in range(self.args.num_tasks):
            self.task_idx = task_idx
            cur_labels = [data.label_list[c] for c in label_order[task_idx]]
            data.add_labels(cur_labels, task_idx)
            seen_labels += cur_labels

            logger.info(f"***** Task-{task_idx + 1} *****")
            logger.info(f"Current classes: {' '.join(cur_labels)}")

            train_data = data.filter(cur_labels, "train")
            # data augmentation
            train_data, num_train_labels = relation_data_augmentation(
                train_data, len(seen_labels), copy.deepcopy(data.id2label), marker_ids, self.args.augment_type
            )
            train_dataset = BaseDataset(train_data)

            model.new_task(num_train_labels)

            self.train(
                model=model,
                train_dataset=train_dataset,
                data_collator=default_data_collator
            )
            cur_test_data = data.filter(cur_labels, 'test')
            cur_test_dataset = BaseDataset(cur_test_data)
            cur_result = self.eval(
                model=model,
                eval_dataset=cur_test_dataset,
                data_collator=default_data_collator,
                seen_labels=seen_labels,
            )

            os.makedirs(self.args.save_model_dir, exist_ok=True)
            save_model_name = f"{self.args.dataset_name}_{seed}_{self.args.augment_type}.pth"
            save_model_path = os.path.join(self.args.save_model_dir, save_model_name)
            logger.info(f"save expert model to {save_model_path}")
            self.save_model(model, save_model_path)

            all_cur_acc[self.task_idx] = cur_result
            all_total_acc[self.task_idx] = cur_result
            all_total_hit[self.task_idx] = 1
            # only for the first task
            if self.task_idx == 0:
                break

        return {
            "cur_acc": all_cur_acc,
            "total_acc": all_total_acc,
            "total_hit": all_total_hit,
        }

    def _build_optimizer(self, model):
        """Build AdamW over the trainable parameters of ``model``.

        Parameters under ``feature_extractor`` use ``args.learning_rate``; the rest use
        ``args.classifier_learning_rate``. Biases and LayerNorm weights get no weight decay.
        Frozen parameters and empty groups are left out.
        """
        no_decay = ("bias", "LayerNorm.weight")

        def select(in_encoder, decay):
            return [
                p for n, p in model.named_parameters()
                if p.requires_grad
                and ('feature_extractor' in n) == in_encoder
                and (not any(nd in n for nd in no_decay)) == decay
            ]

        candidate_groups = [
            {'params': select(True, True), 'lr': self.args.learning_rate, 'weight_decay': 1e-2},
            {'params': select(True, False), 'lr': self.args.learning_rate, 'weight_decay': 0.0},
            {'params': select(False, True), 'lr': self.args.classifier_learning_rate, 'weight_decay': 1e-2},
            {'params': select(False, False), 'lr': self.args.classifier_learning_rate, 'weight_decay': 0.0},
        ]
        return torch.optim.AdamW([g for g in candidate_groups if g['params']])

    def train(self, model, train_dataset, data_collator):
        """Run ``args.num_train_epochs`` epochs of AdamW training with gradient clipping."""
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=self.args.train_batch_size,
            shuffle=True,
            collate_fn=data_collator
        )
        num_examples = len(train_dataset)
        max_steps = len(train_dataloader) * self.args.num_train_epochs

        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Num Epochs = {self.args.num_train_epochs}")
        logger.info(f"  Train batch size = {self.args.train_batch_size}")
        logger.info(f"  Total optimization steps = {max_steps}")

        self.optimizer = self._build_optimizer(model)

        progress_bar = tqdm(range(max_steps))
        for epoch in range(self.args.num_train_epochs):
            model.train()
            for inputs in train_dataloader:
                self.optimizer.zero_grad()

                inputs = {k: v.to(self.args.device) for k, v in inputs.items()}
                loss = model(**inputs).loss
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                self.optimizer.step()

                progress_bar.update(1)
                progress_bar.set_postfix({"Loss": loss.item()})

        progress_bar.close()

    @staticmethod
    def _micro_f1(golds, preds):
        """Micro-averaged F1 for single-label multi-class predictions.

        With exactly one gold and one predicted label per example, micro precision,
        recall and F1 all equal accuracy. Returns 0.0 for an empty evaluation set.
        """
        if not golds:
            return 0.0
        correct = sum(int(g == p) for g, p in zip(golds, preds))
        return correct / len(golds)

    @torch.no_grad()
    def eval(self, model, eval_dataset, data_collator, seen_labels):
        """Evaluate ``model`` on ``eval_dataset`` and return micro-F1 (``seen_labels`` is unused)."""
        eval_dataloader = DataLoader(
            eval_dataset,
            batch_size=self.args.eval_batch_size,
            shuffle=False,
            collate_fn=data_collator,
        )

        num_examples = len(eval_dataset)

        logger.info("***** Running evaluating *****")
        logger.info(f"  Num examples = {num_examples}")
        logger.info(f"  Eval batch size = {self.args.eval_batch_size}")

        progress_bar = tqdm(range(len(eval_dataloader)))

        golds = []
        preds = []

        model.eval()
        for inputs in eval_dataloader:
            labels = inputs.pop('labels')
            inputs = {k: v.to(self.args.device) for k, v in inputs.items()}

            logits = model(**inputs).logits
            preds.extend(logits.argmax(dim=-1).cpu().tolist())
            golds.extend(labels.cpu().tolist())

            progress_bar.update(1)
        progress_bar.close()

        micro_f1 = self._micro_f1(golds, preds)
        logger.info("Micro F1 {}".format(micro_f1))

        return micro_f1

    def save_model(self, model, save_path):
        """Save the encoder and classifier state dicts to ``save_path``."""
        bert_state_dict = model.feature_extractor.bert.state_dict()
        linear_state_dict = model.classifier.state_dict()
        torch.save({
            "model": bert_state_dict,
            "linear": linear_state_dict,
        }, save_path)



In [None]:
# Train and evaluate the expert model (ExpertTrainer.run stops after the first task).
# NOTE(review): `trainer`, `model`, `task_seq` and `exp_seed` are not created anywhere in the
# visible notebook, so fail with an actionable message instead of a bare NameError.
_missing = [name for name in ("trainer", "model", "task_seq", "exp_seed") if name not in globals()]
if _missing:
    raise NameError(
        "define " + ", ".join(_missing) + " before running this cell "
        "(e.g. model = ExpertModel(config); trainer = ExpertTrainer(config); "
        "task_seq = per-task lists of label indices into data.label_list; exp_seed = an int seed)"
    )

exp_result = trainer.run(
    data=data,
    model=model,
    tokenizer=tokenizer,
    label_order=task_seq,
    seed=exp_seed
)
