In [1]:
# !pip install transformers


In [2]:
# !pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html

In [3]:
# !pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html

In [4]:
# !pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html

In [5]:
# !pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html

In [6]:
# !pip install torch_geometric

In [7]:
# from transformers import AutoTokenizer, AutoModelForTokenClassification
# from transformers import pipeline

# tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")
# model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")

# nlp = pipeline("ner", model=model, tokenizer=tokenizer)
# example = "My name is Wolfgang and I live in Berlin"

# ner_results = nlp(example)
# print(ner_results)

In [8]:
# !pip install wandb
import os

OUTPUT_DIR = './autodl-tmp/plus/'
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

In [9]:
from __future__ import absolute_import, division, print_function
import wandb
import datetime
import argparse
import csv
import logging
import os
import random
import sys
import pickle
import numpy as np
import torch
import torch.nn as nn
import json
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch_geometric.nn import GCNConv,GATConv,GraphConv,GATv2Conv,RGATConv,RGCNConv
from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import OneCycleLR
from tqdm import tqdm, trange
import glob
import gc

from collections import defaultdict
from torch.nn import CrossEntropyLoss, MSELoss
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score

from transformers import AutoTokenizer, AutoModel, AutoConfig,get_cosine_schedule_with_warmup,DebertaV2TokenizerFast
#from modeling import (ElectraForMultipleChoicePlus, Baseline, BertBaseline, RobertaBaseline, BertForMultipleChoicePlus, RobertaForMultipleChoicePlus)
from transformers import (get_linear_schedule_with_warmup, WEIGHTS_NAME, CONFIG_NAME)
import re
import os

logger = logging.getLogger(__name__)

# Random Seed Initialize
RANDOM_SEED = 42

def seed_everything(seed=RANDOM_SEED):
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
seed_everything()

In [10]:
# Device Optimization
if torch.cuda.is_available():
    device = torch.device('cuda')
    print("GPU found (", torch.cuda.get_device_name(torch.cuda.current_device()), ")")
    print("num device avail: ", torch.cuda.device_count())
else:
    device = torch.device('cpu')
    
print(f'Using device: {device}')

GPU found ( NVIDIA A100-SXM-80GB )
num device avail:  1
Using device: cuda


In [11]:
class Config:
    wandb=False
    data_dir = "input_plus"
    max_seq_length = 512
    max_utterance_num = 30
    model_name = "microsoft/deberta-v2-xxlarge"
    epochs = 10
    lr = 2e-6
    output_dir= "output"
    batch_size = 2
    h_dim = 1536
    max_grad_norm = 100
    eps=1e-6
    betas=(0.9, 0.999)
    params = {
    'scheduler_name': 'OneCycleLR',
    'max_lr': 2e-6,                 # OneCycleLR
    'pct_start': 0.1,               # OneCycleLR
    'anneal_strategy': 'cos',       # OneCycleLR
    'div_factor': 1e2,              # OneCycleLR
    'final_div_factor': 1e2,        # OneCycleLR
}
    
# def get_logger(filename=OUTPUT_DIR+'train'):
#     from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
#     logger = getLogger(__name__)
#     logger.setLevel(INFO)
#     handler1 = StreamHandler()
#     handler1.setFormatter(Formatter("%(message)s"))
#     handler2 = FileHandler(filename=f"{filename}.log")
#     handler2.setFormatter(Formatter("%(message)s"))
#     logger.addHandler(handler1)
#     logger.addHandler(handler2)
#     return logger

# LOGGER = get_logger()

In [12]:
if Config.wandb:
    
    import wandb

    try:
        from kaggle_secrets import UserSecretsClient
        user_secrets = UserSecretsClient()
        secret_value_0 = user_secrets.get_secret("wandb_api")
        wandb.login(key='23a2d4b438741f956162624b0da65448022b5571')
        anony = None
    except:
        anony = "must"
        print('If you want to use your W&B account, go to Add-ons -> Secrets and provide your W&B access token. Use the Label name as wandb_api. \nGet your W&B access token from here: https://wandb.ai/authorize')


    def class2dict(f):
        return dict((name, getattr(f, name)) for name in dir(f) if not name.startswith('__'))

    run = wandb.init(project='EMNLP', 
                     name='test',
                     config=class2dict(Config),
                     group=Config.model_name,
                     job_type="调参",
                     anonymous=anony)

In [13]:
class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs a InputExample.
        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label

class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, example_id, choices_features, label):
        self.example_id = example_id
        self.choices_features = [
            {
                'input_ids': input_ids,
                'input_mask': input_mask,
                'segment_ids': segment_ids,
                'sep_pos': sep_pos,
                'turn_ids': turn_ids,
                'cls_pos':cls_pos
            }
            for input_ids, input_mask, segment_ids, sep_pos, turn_ids, cls_pos in choices_features
        ]
        self.label = label

class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

class MuTualProcessor(DataProcessor):
    """Processor for the MuTual data set."""

    def get_train_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} train".format(data_dir))
        file = os.path.join(data_dir, 'train')
        file = self._read_txt(file)
        return self._create_examples(file, 'train')

    def get_dev_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} dev".format(data_dir))
        file = os.path.join(data_dir, 'dev')
        file = self._read_txt(file)
        return self._create_examples(file, 'dev')

    def get_test_examples(self, data_dir):
        """See base class."""
        logger.info("LOOKING AT {} test".format(data_dir))
        file = os.path.join(data_dir, 'test')
        file = self._read_txt(file)
        return self._create_examples(file, 'test')

    def get_labels(self):
        """See base class."""
        return ["0", "1", "2", "3"]

    def _read_txt(self, input_dir):
        lines = []
        files = glob.glob(input_dir + "/*txt")
        for file in tqdm(files, desc="read files"):
            with open(file, 'r', encoding='utf-8') as fin:
                data_raw = json.load(fin)
                data_raw["id"] = file
                lines.append(data_raw)
        return lines
    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (_, data_raw) in enumerate(tqdm(lines,desc="create examples")):
            id = "%s-%s" % (set_type, data_raw["id"])
            article = data_raw["article"]

            article = re.split(r"(f : |m : |M: |F: )", article)  #分割完是数组
            #print(article)
            article = ["".join(i) for i in zip(article[1::2], article[2::2])]  #然后再拼接好，最终的结果是每个说话者的话是列表中的一个元素

            truth = str(ord(data_raw['answers']) - ord('A'))
            options = data_raw['options']

            examples.append(
                InputExample(
                    guid=id,
                    text_a = [options[0], options[1], options[2], options[3]],
                    text_b=article, # this is not efficient but convenient
                    label=truth))
        return examples

def convert_examples_to_features(examples, label_list, max_seq_length, max_utterance_num,
                                 tokenizer, output_mode=None):
    """Loads a data file into a list of `InputBatch`s."""

    label_map = {label : i for i, label in enumerate(label_list)}

    features = []
    
    for (ex_index, example) in enumerate(tqdm(examples,desc="create features")):
        #if ex_index % 10000 == 0:
        #    logger.info("Writing example %d of %d" % (ex_index, len(examples)))

        choices_features = []
        all_tokens = []
        text_a = example.text_a
        text_b = example.text_b

        #for ending_idx, (text_a, text_b) in enumerate(zip(example.text_a, example.text_b)): #zip() 函数用于将可迭代的对象作为参数，将对象中对应的元素打包成一个个元组，然后返回由这些元组组成的列表。
        text_a[0] = re.sub("f : |m : |M: |F: ","",text_a[0])
        text_a[1] = re.sub("f : |m : |M: |F: ","",text_a[1])
        text_a[2] = re.sub("f : |m : |M: |F: ","",text_a[2])
        text_a[3] = re.sub("f : |m : |M: |F: ","",text_a[3])
        

        tokens_a = tokenizer.tokenize(text_a[0])
        tokens_a = ["[CLS]"] + tokens_a + ["[SEP]"]
        tokens_b = tokenizer.tokenize(text_a[1])
        tokens_b = ["[CLS]"] + tokens_b + ["[SEP]"]
        tokens_c = tokenizer.tokenize(text_a[2])
        tokens_c = ["[CLS]"] + tokens_c + ["[SEP]"]
        tokens_d = tokenizer.tokenize(text_a[3])
        tokens_d = ["[CLS]"] + tokens_d
        tokens_options = tokens_a + tokens_b + tokens_c + tokens_d

        tokens_article = []

        for idx, text in enumerate(text_b):  
            if len(text.strip()) > 0:  #strip() 方法用于移除字符串头尾指定的字符（默认为空格或换行符）或字符序列。
                text = re.sub("f : |m : |M: |F: ","",text)
                tokens_article.extend(["[CLS]"]+tokenizer.tokenize(text) + ["[SEP]"])  #extend() 函数用于在列表末尾一次性追加另一个序列中的多个值（用新列表扩展原来的列表）。
                                                #而append是加入一个对象，如果这里用append的话那么可能会变成将列表加入到原来的列表当中
        tokens_article.pop(0)
        _truncate_seq_pair(tokens_options, tokens_article, max_seq_length-2)

        tokens = ["[CLS]"]
        turn_ids = [0]

        context_len = []
        sep_pos = []
        cls_pos = [0]

            
        tokens_article_raw = " ".join(tokens_article)
        tokens_article = []
        current_pos = 0
        for toks in tokens_article_raw.split("[SEP]")[-max_utterance_num - 1:-1]:
            context_len.append(len(toks.split()) + 1)
            tokens_article.extend(toks.split())
            tokens_article.extend(["[SEP]"])
            current_pos += context_len[-1]
            turn_ids += [len(sep_pos)] * context_len[-1]
            sep_pos.append(current_pos)
            cls_pos.append(current_pos+1)
        cls_pos.pop()
                
        tokens += tokens_article

        segment_ids = [0] * (len(tokens))

        tokens_options += ["[SEP]"]
        

        for index,toks in enumerate(tokens_options):
          if toks == "[CLS]":
            cls_pos.append(len(tokens)+index)

        tokens += tokens_options

        segment_ids += [1] * (len(tokens_options))
            
        turn_ids += [len(sep_pos)] * len(tokens_options) 
        sep_pos.append(len(tokens) - 1)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)

        input_mask = [1] * len(input_ids)

        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding
        turn_ids += padding

        context_len += [-1] * (max_utterance_num - len(context_len))
        sep_pos += [0] * (max_utterance_num + 1 - len(sep_pos))
        num_nodes = len(cls_pos)
        cls_pos += [0] * (max_utterance_num + 1 - len(cls_pos))
        cls_pos[-1] = num_nodes

        assert len(sep_pos) == max_utterance_num + 1
        assert len(cls_pos) == max_utterance_num + 1
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(context_len) == max_utterance_num 
        assert len(turn_ids) == max_seq_length 

        choices_features.append((input_ids, input_mask, segment_ids, sep_pos, turn_ids, cls_pos))   #turn_ids代表每个说话者句子的序号，
            #最前面的是历史轮对话，最后面的是响应答案和填充；sep_pos代表语句的结束位置；segment_ids代表历史论对话和响应答案，0，1，
            #input_mask代表有效的和填充的，1，0
        all_tokens.append(tokens)


        label_id = label_map[example.label] #数字型
        

        if ex_index < 10:  #打印一下输出
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            for choice_idx, (input_ids, input_mask, segment_ids, sep_pos, turn_ids, cls_pos) in enumerate(choices_features):
                logger.info("choice: {}".format(choice_idx))
                logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
                logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
                logger.info("tokens: %s" % " ".join([str(x) for x in all_tokens[choice_idx]]))
                logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
                logger.info("sep_pos: %s" % " ".join([str(x) for x in sep_pos]))
                logger.info("turn_ids: %s" % " ".join([str(x) for x in turn_ids]))
                logger.info("label: %s (id = %d)" % (example.label, label_id))

        features.append(
            InputFeatures(
                example_id = example.guid, 
                choices_features = choices_features,
                label = label_id
                )
        )

    return features

def _truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Truncates a sequence pair in place to the maximum length."""

    # This is a simple heuristic which will always truncate the longer sequence
    # one token at a time. This makes more sense than truncating an equal percent
    # of tokens from each, since if one sequence is very short then each token
    # that's truncated likely contains more information than a longer sequence.
    while True:
        total_length = len(tokens_a) + len(tokens_b)
        if total_length <= max_length:
            break
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop(0)


In [14]:
# tokenizer = AutoTokenizer.from_pretrained(Config.model_name)
tokenizer = DebertaV2TokenizerFast.from_pretrained(Config.model_name)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [15]:
processor = MuTualProcessor()
label_list = processor.get_labels()
num_labels = len(label_list)
train_examples = processor.get_train_examples(Config.data_dir)


read files: 100%|██████████| 7088/7088 [00:00<00:00, 34287.87it/s]
create examples: 100%|██████████| 7088/7088 [00:00<00:00, 96672.24it/s]


In [16]:
val_examples = processor.get_dev_examples(Config.data_dir)


read files: 100%|██████████| 886/886 [00:00<00:00, 32082.27it/s]
create examples: 100%|██████████| 886/886 [00:00<00:00, 100084.93it/s]


In [17]:
train_features = convert_examples_to_features(
                train_examples, label_list, Config.max_seq_length, Config.max_utterance_num, tokenizer)

create features: 100%|██████████| 7088/7088 [00:07<00:00, 967.43it/s] 


In [18]:
val_features = convert_examples_to_features(
                val_examples, label_list, Config.max_seq_length, Config.max_utterance_num, tokenizer)

create features: 100%|██████████| 886/886 [00:00<00:00, 1001.26it/s]


In [19]:
class MuTualDataset(Dataset):
  def __init__(self, features):
    self.features = features
    self.length = len(features)

  def __len__(self): 
      return self.length
  
  def __getitem__(self, idx):
    input_ids = self.features[idx].choices_features[0]['input_ids']
    input_mask = self.features[idx].choices_features[0]['input_mask']
    segment_ids = self.features[idx].choices_features[0]['segment_ids']
    cls_pos = self.features[idx].choices_features[0]['cls_pos']
    label = self.features[idx].label
    

    label_id = np.zeros((4,1),dtype=int)

    label_id[label,0] = 1
    #length = len(cls_pos)
    
    #edge_index = _get_edge_index(length, edge_index)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    input_mask = torch.tensor(input_mask, dtype=torch.long)
    segment_ids =torch.tensor(segment_ids, dtype=torch.long)
    cls_pos = torch.tensor(cls_pos, dtype=torch.long)
    label_id = torch.tensor(label_id, dtype=torch.float)

    return {"input_ids":input_ids,"input_mask": input_mask,"segment_ids": segment_ids,"cls_pos":cls_pos,"label_id":label_id}
        

In [20]:
def get_edge_index(length):
    all_edges = set()
    all_options_edges = set()
    num_edges = (length-4)*(length-5) + (length-4)*4
    #edge_index=numpy.zeros((2,num_labels,dtype=int)

    for i in range(length-4,length):
      for j in range(length-4,length):
          all_options_edges.add((i,j))
  
    #cnt=0
    for i in range(length-4):
      for j in range(length-4):
        if i != j:
          all_edges.add((i,j))
    for i in range(length-4):
      for j in range(length-4,length):
          all_edges.add((i,j))
    
    assert num_edges == len(all_edges)

    return list(all_edges),len(all_edges),list(all_options_edges)


def batch_graphify(batch_output,batch_cls_pos):
  #edge_length_sum = 0
  batch_size = len(batch_output)
  #print("batch_output:",batch_output.shape)
  #print("batch_cls_pos:",batch_cls_pos.shape)

  node_length_sum = 0
  edge_index_batch = []
  options_edge_index_batch = []
  nodes_feature_batch_first = batch_output[:,0,:] #4,768
  #print("nodes_feature_batch_first:",nodes_feature_batch_first)
  nodes_feature_list = []
  options_cls_batch = []
  for i in range(batch_size):
    
      
    edges,edges_length,options_edges = get_edge_index(batch_cls_pos[i,-1])
    #print("edges:",edges)
    edges_s = [(item[0]+node_length_sum, item[1]+node_length_sum) for item in edges]
    #print("edges_s:",edges_s)
    for item in edges_s:
      edge_index_batch.append(torch.tensor([item[0], item[1]]))
    options_edges_s = [(item[0]+node_length_sum, item[1]+node_length_sum) for item in options_edges]
    #print("edges_s:",edges_s)
    for item in options_edges_s:
      options_edge_index_batch.append(torch.tensor([item[0], item[1]]))


    #edge_length_sum+=edges_length
    node_length_sum+=batch_cls_pos[i,-1]
    #print("node_length_sum:",node_length_sum)
    
    nodes_feature = nodes_feature_batch_first[i].unsqueeze(0) #1,768
    #print("nodes_feature:",nodes_feature.shape)
    for j in range(len(batch_cls_pos[i])-1):
      if batch_cls_pos[i,j] != 0:
        nodes_feature = torch.cat((nodes_feature,batch_output[i,batch_cls_pos[i,j]].unsqueeze(0)),0)
        #print(nodes_feature)
    nodes_feature_list.append(nodes_feature)
    for k in range(-4,0):
      options_cls_batch.append(node_length_sum+k)
    

  
  nodes_feature_batch = nodes_feature_list[0]
  for i in range(len(nodes_feature_list)):
    if i !=0:
      nodes_feature_batch = torch.cat((nodes_feature_batch,nodes_feature_list[i]),0)

  nodes_feature_batch = nodes_feature_batch.to(device)
  edge_index_batch = torch.stack(edge_index_batch).transpose(0, 1).to(device)
  options_edge_index_batch = torch.stack(options_edge_index_batch).transpose(0, 1).to(device)

  options_cls_batch = torch.tensor(options_cls_batch).to(device)
  #print("nodes_feature_batch:",nodes_feature_batch.shape)
  #print("edge_index_batch:",edge_index_batch)
  #print("options_cls_batch:",options_cls_batch)



  return nodes_feature_batch,edge_index_batch,options_cls_batch,options_edge_index_batch




  
  









In [21]:
class MetricMonitor:
    def __init__(self, float_precision=4):
        self.float_precision = float_precision
        self.reset()

    def reset(self):
        self.metrics = defaultdict(lambda: {"val": 0, "count": 0, "avg": 0})#defaultdict的作用是在于，当字典里的key不存在但被查找时，返回的不是keyError而是一个默认值

    def update(self, metric_name, val):
        metric = self.metrics[metric_name]

        metric["val"] += val
        metric["count"] += 1
        metric["avg"] = metric["val"] / metric["count"]

    def __str__(self):
        return " | ".join(
            [
                "{metric_name}: {avg:.{float_precision}f}".format(
                    metric_name=metric_name, avg=metric["avg"],
                    float_precision=self.float_precision
                )
                for (metric_name, metric) in self.metrics.items()
            ]
        )

In [22]:
class Co_attention_head(nn.Module):
    def __init__(self):
        super(Co_attention_head, self).__init__()
        self.q1 = nn.Linear(Config.h_dim,384)
        self.q2 = nn.Linear(Config.h_dim,384)
        self.q3 = nn.Linear(Config.h_dim,384)
        self.q4 = nn.Linear(Config.h_dim,384)
    
        self.k1 = nn.Linear(Config.h_dim,384)
        self.k2 = nn.Linear(Config.h_dim,384)
        self.k3 = nn.Linear(Config.h_dim,384)
        self.k4 = nn.Linear(Config.h_dim,384)
        
        self.v1 = nn.Linear(Config.h_dim,384)
        self.v2 = nn.Linear(Config.h_dim,384)
        self.v3 = nn.Linear(Config.h_dim,384)
        self.v4 = nn.Linear(Config.h_dim,384)
        
        self.layer_norm = nn.LayerNorm(Config.h_dim)
    def forward(self, node1,node2):
        q1 = self.q1(node1)
        q2 = self.q2(node1)
        q3 = self.q3(node1)
        q4 = self.q4(node1)
        
        k1 = self.k1(node2)
        k2 = self.k2(node2)
        k3 = self.k3(node2)
        k4 = self.k4(node2)
        
        v1 = self.v1(node2)
        v2 = self.v2(node2)
        v3 = self.v3(node2)
        v4 = self.v4(node2)
        
        r1 = (torch.matmul(q1,k1.permute(1,0))/((Config.h_dim/4)**0.5))*v1
        r2 = (torch.matmul(q2,k2.permute(1,0))/((Config.h_dim/4)**0.5))*v2
        r3 = (torch.matmul(q3,k3.permute(1,0))/((Config.h_dim/4)**0.5))*v3
        r4 = (torch.matmul(q4,k4.permute(1,0))/((Config.h_dim/4)**0.5))*v4
        
        return self.layer_norm(torch.cat((r1,r2,r3,r4),-1)) 

class MuTualModel(nn.Module):
  def __init__(self,model_name=Config.model_name):
    super(MuTualModel, self).__init__()
    self.model_name = model_name
    self.config = AutoConfig.from_pretrained(model_name)
    self.model = AutoModel.from_pretrained(model_name,config=self.config)
    #self.conv1 = GATConv(768, 768,dropout=0.2)
    #self.conv2 = GCNConv(768, 768)
    # self.conv2 = GATConv(768, 768,dropout=0.2)
    #self.conv2 = GCNConv(768, 768)
    self.conv3 = GraphConv(768,64)
    self.conv1 = RGCNConv(Config.h_dim,768,8)
    self.encoder_layer1 = nn.TransformerEncoderLayer(d_model=Config.h_dim, nhead=2,batch_first=True)
    self.encoder_layer2 = nn.TransformerEncoderLayer(d_model=64, nhead=2,batch_first=True)
    self.linear1 = nn.Linear(64, 64)
    self.linear2 = nn.Linear(64, 1)
    self.tanh = nn.Tanh()
    self.softmax = nn.Softmax(dim=1)

    self.coattention =Co_attention_head()
    #self.linear3 = nn.Linear(768*2,768)
    self.linear4 = nn.Linear(Config.h_dim,Config.h_dim)
    self.linear5 = nn.Linear(Config.h_dim,64)
    self.linear6 = nn.Linear(64,8)

    self.relu = nn.ReLU()
    self.softmax1 = nn.Softmax(dim=0)


  
  def forward(self, input_ids, input_mask,segment_ids, cls_pos):
    outputs = self.model(input_ids,attention_mask=input_mask,token_type_ids=segment_ids)
    sequence_output = outputs[0]  #拿出最后一层
    #print("sequence_output:",sequence_output.shape)

    
    nodes_feature_batch,edge_index_batch,options_cls_batch,options_edge_index_batch = batch_graphify(sequence_output,cls_pos)
    batch_size = len(sequence_output)
    options_batch_raw = torch.zeros((batch_size*4,Config.h_dim),dtype = torch.float32).to(device)
    for index,i in enumerate(options_cls_batch):
      options_batch_raw[index] = nodes_feature_batch[i]
    options_batch_raw = options_batch_raw.view(batch_size,4,Config.h_dim)
    #print(nodes_feature_batch)
    #print(options_batch_raw)
    options_batch_mutual = self.encoder_layer1(options_batch_raw)
   
    options_batch_mutual = options_batch_mutual.view(batch_size*4,Config.h_dim)

    for index,i in enumerate(options_cls_batch):
      nodes_feature_batch[i] = options_batch_mutual[index]



    #print(nodes_feature_batch)
    #print(nodes_feature_batch.shape)
    #print(edge_index_batch.shape)
    #print(options_cls_batch)
    
    
    
    
    #print(sequence_output.shape)
    #print(cls_pos.shape)

    #nodes_feature = sequence_output[:,0]
    # #print(nodes_feature.shape)

    # for i in range(len(cls_pos[0,:])-1):
    #   if cls_pos[0,i] != 0:
    #     nodes_feature = torch.cat((nodes_feature,sequence_output[:,cls_pos[0,i]]),0)
    
    #print(nodes_feature.shape)
    #print(nodes_feature)
    #print(cls_pos)
    
    # length = cls_pos[:,-1].item()
    # num_edges = (length-4)*(length-5) + (length-4)*4
    
    # edge_index=torch.zeros((2,num_edges),dtype=torch.long).to(device)
    # cnt=0
    # for i in range(length-4):
    #   for j in range(length-4):
    #     if i != j:
    #       edge_index[0,cnt] = i
    #       edge_index[1,cnt] = j
    #       cnt= cnt+1
    # for i in range(length-4):
    #   for j in range(length-4,length):
    #       edge_index[0,cnt] = i
    #       edge_index[1,cnt] = j
    #       cnt= cnt+1

    #nodes_feature = nodes_feature.unsqueeze(0)
    #print(nodes_feature.shape)
    #print(nodes_feature_batch[-4:,:])
    #output = self.conv1(nodes_feature_batch,options_edge_index_batch)
    #print(output[-4:,:])
    #print(output) 
    #print(output.shape)
    edge_type = torch.zeros(edge_index_batch.size(1),dtype=torch.long).to(device)

    for i in range(edge_index_batch.size(1)):
      start = edge_index_batch[0,i]
      end = edge_index_batch[1,i]
      temp = self.relu(self.coattention(nodes_feature_batch[start].unsqueeze(0),nodes_feature_batch[end].unsqueeze(0)))
      #temp = torch.cat((word_nodes_feature_batch_a[start],word_nodes_feature_batch_a[end]),-1)
      #temp = self.tanh(self.linear3(temp))
      temp = self.relu(self.linear4(temp))
      temp = self.relu(self.linear5(temp))
      result = self.softmax1(self.linear6(temp))
      edge_type[i]=torch.argmax(result,dim=1)
    
    # edge_index_batch = edge_index_batch[:,edge_type != 8]
    # edge_type = edge_type[edge_type != 8]

    output = self.conv1(nodes_feature_batch,edge_index_batch,edge_type)
    output = self.conv3(output,edge_index_batch)
    #print(output[-4:].shape)
    #print(output)
    #print(output.shape)

    options_batch = torch.zeros((batch_size*4,64),dtype = torch.float32).to(device)
    cnt = 0
    for i in options_cls_batch:
      options_batch[cnt] = output[i]
      cnt+=1
    #print(options_batch)
    options_batch = options_batch.view(batch_size,4,64)
    #print(output)
    #print(options_batch)
    options_batch = self.encoder_layer2(options_batch)
    options_batch = self.tanh(self.linear1(options_batch))
    #print(options_batch.shape)
    options_batch = self.linear2(options_batch)
    #print(options_batch)




    # logits = self.softmax(options_batch)
    #print(logits.shape)
    #print(logits)

    return options_batch




In [23]:
def criterion2(preds, labels):
  p1=0
  p2 = 0
  mrr = 0
  for i in range(len(preds)):
    j = sorted(list(preds[i]), reverse = True)
    for index,label in enumerate(labels[i]):
      if label == 1:
        if preds[i,index] == j[0]:
          p1+=1
          p2+=1
          mrr += 1 
          break
        elif preds[i,index] == j[1]:
          p2+=1
          mrr += 1/2
          break
        elif preds[i,index] == j[2]:
          mrr += 1/3
        elif preds[i,index] == j[3]:
          mrr += 1/4
    
  return p1 / len(preds), p2 / len(preds), mrr / len(preds)


# def criterion1(preds, targets):
#   return nn.BCELoss()(preds,targets)
def criterion1(preds, targets):
  return nn.CrossEntropyLoss(label_smoothing=0.05)(preds,targets)


In [24]:
def train_fn(train_loader, model, criterion1, optimizer ,scheduler=None):
    metric_monitor = MetricMonitor()
    model.train()
    stream = tqdm(train_loader)
    scaler = torch.cuda.amp.GradScaler(enabled=True)
    
    for step,batch in enumerate(stream):
        
        input_ids = batch["input_ids"].to(device)
        input_mask = batch["input_mask"].to(device)
        segment_ids = batch["segment_ids"].to(device)
        cls_pos = batch["cls_pos"].to(device)
        label = batch["label_id"].to(device)
        #edge_index = batch["edge_index"].to(device)
        with torch.cuda.amp.autocast(enabled=True):
            preds = model(input_ids,input_mask,segment_ids,cls_pos)

        #print(preds)
        #print(targets)
        #print(preds.shape)
        #print(label)
        loss = criterion1(preds, label)
        #print(loss.item())
        metric_monitor.update('Loss', loss.item())

        # loss.backward()
        scaler.scale(loss).backward()
        # grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), Config.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        
        if scheduler is not None:
            scheduler.step()
        
        optimizer.zero_grad()
        stream.set_description(f"Epoch: {epoch:02}. Train. {metric_monitor}")
        # wandb.log({f"loss": loss.item()})
        

In [25]:
def validate_fn(val_loader, model, criterion2):
    metric_monitor1 = MetricMonitor()
    metric_monitor2 = MetricMonitor()
    metric_monitor3 = MetricMonitor()
    #results =[]
    model.eval()
    stream = tqdm(val_loader)
    all_r4_1 = []
    all_r4_2=[]
    all_mrr=[]
    with torch.no_grad():
        for i, batch in enumerate(stream ):
          input_ids = batch["input_ids"].to(device)
          input_mask = batch["input_mask"].to(device)
          segment_ids = batch["segment_ids"].to(device)
          cls_pos = batch["cls_pos"].to(device)
          label = batch["label_id"].to(device)
          

          preds = model(input_ids,input_mask,segment_ids,cls_pos)
          #for i in range(len(preds)):
          #  results.append(np.argmax(preds.cpu().numpy(),axis=1))
          #print("label",label)
          #print("preds:",preds)




          
          r4_1,r4_2,mrr = criterion2(preds,label)

          #print("r4_1:",r4_1)
          #print("r4_2:",r4_2)
          #print("mrr:",mrr)
          
          all_r4_1.append(r4_1)
          all_r4_2.append(r4_2)
          all_mrr.append(mrr)
          metric_monitor1.update('R4_1', r4_1)
          metric_monitor2.update('R4_2', r4_2)
          metric_monitor3.update('MRR', mrr)
          stream.set_description(f"Epoch: {epoch:02}. Valid. {metric_monitor1} {metric_monitor2} {metric_monitor3}")
          # wandb.log({"R4_1": r4_1,
          #           "R4_2": r4_2,
          #           "MRR": mrr})
          
    #print(results)
    # LOGGER.info(f'R4_1:{(np.mean(all_r4_1))},R4_2:{(np.mean(all_r4_2))},MRR:{(np.mean(all_mrr))}')
    return np.mean(all_r4_1),np.mean(all_r4_2),np.mean(all_mrr)

In [26]:
def get_scheduler(optimizer, scheduler_params=Config.params):
    if scheduler_params['scheduler_name'] == 'CosineAnnealingWarmRestarts':
        scheduler = CosineAnnealingWarmRestarts(
            optimizer,
            T_0=scheduler_params['T_0'],
            eta_min=scheduler_params['min_lr'],
            last_epoch=-1
        )
    elif scheduler_params['scheduler_name'] == 'OneCycleLR':
        scheduler = OneCycleLR(
            optimizer,
            max_lr=scheduler_params['max_lr'],
            steps_per_epoch=int(7088 / Config.batch_size) + 1,
            epochs=Config.epochs,
            pct_start=scheduler_params['pct_start'],
            anneal_strategy=scheduler_params['anneal_strategy'],
            div_factor=scheduler_params['div_factor'],
            final_div_factor=scheduler_params['final_div_factor'],
        )
    return scheduler

In [27]:
gc.collect()
train_dataset = MuTualDataset(train_features)
train_dataloader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True, pin_memory=True)

val_dataset = MuTualDataset(val_features)
val_dataloader = DataLoader(val_dataset, batch_size=Config.batch_size, shuffle=True, pin_memory=True)

model = MuTualModel()

#model.load_state_dict(torch.load('autodl-tmp/plus/8352model.bin', map_location=torch.device(device)))
model = model.to(device)
# param_optimizer = list(model.named_parameters())
# no_decay = ['bias', 'LayerNorm.weight', 'LayerNorm.bias']
# optimizer_grouped_parameters = [
#     {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
#     {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
# ]
def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
    named_parameters = list(model.named_parameters())    
    parameters = []

    # increase lr every second layer
    increase_lr_every_k_layer = 1
    lrs = np.linspace(1, 2, 48 // increase_lr_every_k_layer)
    num = 0
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    for layer_num, (name, params) in enumerate(named_parameters):
        weight_decay = 0.0 if any(nd in name for nd in no_decay) else 0.01
        splitted_name = name.split('.')
        lr = encoder_lr
        if len(splitted_name) >= 4 and str.isdigit(splitted_name[3]):
            layer_num = int(splitted_name[3])
            lr = lrs[layer_num // increase_lr_every_k_layer] * encoder_lr
            # num+=1
            print(name,lr)
        if 'model' not in splitted_name:
            lr = lrs[-1]*encoder_lr
            print(name,lr)
#         if splitted_name[0] in ['fc']:
#             lr = 10*encoder_lr
#             print(name,lr)

#         if splitted_name[0] in ['head']:
#             lr = 10*encoder_lr
#             print(name,lr)
        # print(num)
        parameters.append({"params": params,
                           "weight_decay": weight_decay,
                           "lr": lr})
    return parameters
# def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
#     param_optimizer = list(model.named_parameters())
#     no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
#     optimizer_parameters = [
#         {'params': [p for n, p in model.model.named_parameters() if not any(nd in n for nd in no_decay)],
#          'lr': encoder_lr, 'weight_decay': weight_decay, 'initial_lr':encoder_lr},
#         {'params': [p for n, p in model.model.named_parameters() if any(nd in n for nd in no_decay)],
#          'lr': encoder_lr, 'weight_decay': 0.0, 'initial_lr':encoder_lr},
#         {'params': [p for n, p in model.named_parameters() if "model" not in n],
#          'lr': decoder_lr, 'weight_decay': 0.0, 'initial_lr':decoder_lr}
#     ]
#     return optimizer_parameters
optimizer_parameters = get_optimizer_params(model,
                                                encoder_lr=Config.lr, 
                                                decoder_lr=Config.lr,
                                                weight_decay=0.01)
optimizer = AdamW(optimizer_parameters)
# optimizer = optim.AdamW(optimizer_grouped_parameters, lr=Config.lr)
scheduler = get_scheduler(optimizer)
best=0.0
for epoch in range(1, Config.epochs + 1):#Config.epochs
  print(f'******************** Training Epoch: {epoch} ********************')
  train_fn(train_dataloader, model, criterion1, optimizer,scheduler)
  r4_1,r4_2,mrr = validate_fn(val_dataloader, model, criterion2)
  if r4_1 >= best:
        best = r4_1
        print(f'{r4_1} model saved')
        torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, f"model.bin"))

Some weights of the model checkpoint at microsoft/deberta-v2-xxlarge were not used when initializing DebertaV2Model: ['lm_predictions.lm_head.LayerNorm.weight', 'lm_predictions.lm_head.LayerNorm.bias', 'lm_predictions.lm_head.bias', 'lm_predictions.lm_head.dense.bias', 'lm_predictions.lm_head.dense.weight']
- This IS expected if you are initializing DebertaV2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DebertaV2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


model.encoder.layer.0.attention.self.query_proj.weight 2e-06
model.encoder.layer.0.attention.self.query_proj.bias 2e-06
model.encoder.layer.0.attention.self.key_proj.weight 2e-06
model.encoder.layer.0.attention.self.key_proj.bias 2e-06
model.encoder.layer.0.attention.self.value_proj.weight 2e-06
model.encoder.layer.0.attention.self.value_proj.bias 2e-06
model.encoder.layer.0.attention.output.dense.weight 2e-06
model.encoder.layer.0.attention.output.dense.bias 2e-06
model.encoder.layer.0.attention.output.LayerNorm.weight 2e-06
model.encoder.layer.0.attention.output.LayerNorm.bias 2e-06
model.encoder.layer.0.intermediate.dense.weight 2e-06
model.encoder.layer.0.intermediate.dense.bias 2e-06
model.encoder.layer.0.output.dense.weight 2e-06
model.encoder.layer.0.output.dense.bias 2e-06
model.encoder.layer.0.output.LayerNorm.weight 2e-06
model.encoder.layer.0.output.LayerNorm.bias 2e-06
model.encoder.layer.1.attention.self.query_proj.weight 2.0425531914893613e-06
model.encoder.layer.1.attent

Epoch: 01. Train. Loss: 1.3363: 100%|██████████| 3544/3544 [37:24<00:00,  1.58it/s]
Epoch: 01. Valid. R4_1: 0.5305 R4_2: 0.7833 MRR: 0.7212: 100%|██████████| 443/443 [01:44<00:00,  4.26it/s]


0.5304740406320542 model saved
******************** Training Epoch: 2 ********************


Epoch: 02. Train. Loss: 0.7064: 100%|██████████| 3544/3544 [37:16<00:00,  1.58it/s]
Epoch: 02. Valid. R4_1: 0.8160 R4_2: 0.9470 MRR: 0.8979: 100%|██████████| 443/443 [01:44<00:00,  4.26it/s]


0.8160270880361173 model saved
******************** Training Epoch: 3 ********************


Epoch: 03. Train. Loss: 0.4768: 100%|██████████| 3544/3544 [37:16<00:00,  1.58it/s]
Epoch: 03. Valid. R4_1: 0.8307 R4_2: 0.9537 MRR: 0.9065: 100%|██████████| 443/443 [01:44<00:00,  4.24it/s]


0.8306997742663657 model saved
******************** Training Epoch: 4 ********************


Epoch: 04. Train. Loss: 0.3685: 100%|██████████| 3544/3544 [37:19<00:00,  1.58it/s]
Epoch: 04. Valid. R4_1: 0.8330 R4_2: 0.9526 MRR: 0.9078: 100%|██████████| 443/443 [01:44<00:00,  4.24it/s]


0.8329571106094809 model saved
******************** Training Epoch: 5 ********************


Epoch: 05. Train. Loss: 0.3004: 100%|██████████| 3544/3544 [37:23<00:00,  1.58it/s]
Epoch: 06. Train. Loss: 0.2740: 100%|██████████| 3544/3544 [37:08<00:00,  1.59it/s]01:21<00:21,  4.10it/s]
Epoch: 06. Valid. R4_1: 0.8488 R4_2: 0.9526 MRR: 0.9154: 100%|██████████| 443/443 [01:43<00:00,  4.30it/s]


0.8487584650112867 model saved
******************** Training Epoch: 7 ********************


Epoch: 07. Train. Loss: 0.2496: 100%|██████████| 3544/3544 [37:15<00:00,  1.59it/s]
Epoch: 07. Valid. R4_1: 0.8499 R4_2: 0.9515 MRR: 0.9160: 100%|██████████| 443/443 [01:44<00:00,  4.24it/s]


0.8498871331828443 model saved
******************** Training Epoch: 8 ********************


Epoch: 08. Train. Loss: 0.2358: 100%|██████████| 3544/3544 [37:22<00:00,  1.58it/s]
Epoch: 08. Valid. R4_1: 0.8634 R4_2: 0.9582 MRR: 0.9240: 100%|██████████| 443/443 [01:44<00:00,  4.24it/s]


0.863431151241535 model saved
******************** Training Epoch: 9 ********************


Epoch: 09. Train. Loss: 0.2312: 100%|██████████| 3544/3544 [37:21<00:00,  1.58it/s]
Epoch: 09. Valid. R4_1: 0.8612 R4_2: 0.9582 MRR: 0.9232: 100%|██████████| 443/443 [01:44<00:00,  4.26it/s]


******************** Training Epoch: 10 ********************


Epoch: 10. Train. Loss: 0.2301: 100%|██████████| 3544/3544 [37:18<00:00,  1.58it/s]
Epoch: 10. Valid. R4_1: 0.8646 R4_2: 0.9571 MRR: 0.9247: 100%|██████████| 443/443 [01:43<00:00,  4.29it/s]


0.8645598194130926 model saved


In [28]:
# gc.collect()
# train_dataset = MuTualDataset(train_features)
# train_dataloader = DataLoader(train_dataset, batch_size=Config.batch_size, shuffle=True, pin_memory=True)

# val_dataset = MuTualDataset(val_features)
# val_dataloader = DataLoader(val_dataset, batch_size=Config.batch_size, shuffle=True, pin_memory=True)

# model = MuTualModel()
# model = model.to(device)
# #model.load_state_dict(torch.load('autodl-tmp/train/9020model.bin', map_location=torch.device(device)))
# # param_optimizer = list(model.named_parameters())
# # no_decay = ['bias', 'LayerNorm.weight', 'LayerNorm.bias']
# # optimizer_grouped_parameters = [
# #     {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
# #     {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
# # ]
# # def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
# #     named_parameters = list(model.named_parameters())    
# #     parameters = []

# #     # increase lr every second layer
# #     increase_lr_every_k_layer = 1
# #     lrs = np.linspace(1, 5, 48 // increase_lr_every_k_layer)
# #     num = 0
# #     no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
# #     for layer_num, (name, params) in enumerate(named_parameters):
# #         weight_decay = 0.0 if any(nd in name for nd in no_decay) else 0.01
# #         splitted_name = name.split('.')
# #         lr = encoder_lr
# #         if len(splitted_name) >= 4 and str.isdigit(splitted_name[3]):
# #             layer_num = int(splitted_name[3])
# #             lr = lrs[layer_num // increase_lr_every_k_layer] * encoder_lr
# #             # num+=1
# #             print(name,lr)
# #         if 'model' not in splitted_name:
# #             lr = lrs[-1]*encoder_lr
# #             print(name,lr)
# # #         if splitted_name[0] in ['fc']:
# # #             lr = 10*encoder_lr
# # #             print(name,lr)

# # #         if splitted_name[0] in ['head']:
# # #             lr = 10*encoder_lr
# # #             print(name,lr)
# #         # print(num)
# #         parameters.append({"params": params,
# #                            "weight_decay": weight_decay,
# #                            "lr": lr})
# #     return parameters
# def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
#     param_optimizer = list(model.named_parameters())
#     no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
#     optimizer_parameters = [
#         {'params': [p for n, p in model.model.named_parameters() if not any(nd in n for nd in no_decay)],
#          'lr': encoder_lr, 'weight_decay': weight_decay, 'initial_lr':encoder_lr},
#         {'params': [p for n, p in model.model.named_parameters() if any(nd in n for nd in no_decay)],
#          'lr': encoder_lr, 'weight_decay': 0.0, 'initial_lr':encoder_lr},
#         {'params': [p for n, p in model.named_parameters() if "model" not in n],
#          'lr': decoder_lr, 'weight_decay': 0.0, 'initial_lr':decoder_lr}
#     ]
#     return optimizer_parameters
# optimizer_parameters = get_optimizer_params(model,
#                                                 encoder_lr=Config.lr, 
#                                                 decoder_lr=Config.lr,
#                                                 weight_decay=0.01)
# optimizer = AdamW(optimizer_parameters)
# # optimizer = optim.AdamW(optimizer_grouped_parameters, lr=Config.lr)
# scheduler = get_scheduler(optimizer)
# best=0.0
# for epoch in range(1, Config.epochs + 1):
#   print(f'******************** Training Epoch: {epoch} ********************')
#   train_fn(train_dataloader, model, criterion1, optimizer,scheduler)
#   r4_1,r4_2,mrr = validate_fn(val_dataloader, model, criterion2)
#   if r4_1 >= best:
#         best = r4_1
#         print(f'{r4_1} model saved')
#         torch.save(model.state_dict(), os.path.join(OUTPUT_DIR, f"model.bin"))

In [29]:
#model = MuTualModel()

In [30]:
#model.load_state_dict(torch.load("/content/drive/MyDrive/Colab/MuTual/outputbert-base-gat_50_epoch.pth"))

In [31]:
#model = model.to(device)

In [32]:
#validate_fn(val_dataloader, model, criterion2)