## Requirements

python: 3.7.12

GPU: P100

pytorch: 1.13.0

transformers: 4.26.1

In [None]:
import logging
import random
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import json
import logging
from torch.profiler import profile, record_function, ProfilerActivity
import psutil
from tqdm import tqdm,trange
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
import argparse
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from transformers import BertConfig,BertTokenizer,BertModel,BertPreTrainedModel,AdamW,get_linear_schedule_with_warmup
import os
import math
# List every file found under the Kaggle input directory so the available
# datasets are visible in the notebook output.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))
import wandb
# Weights & Biases credentials: fill in the API key below; leave it empty and
# ignore this cell if wandb is not used.
os.environ['WANDB_API_KEY']="" #if not use "wandb", please ignore 
os.environ['WANDB_MODE']="online"
wandb.login()

## Loss Function

In [None]:
def multilabel_categorical_crossentropy(y_pred, y_true):
    """Multi-label categorical cross-entropy over integer class labels.

    Formulation: https://kexue.fm/archives/7359

    Args:
        y_pred: float logits; the last dimension indexes the classes.
        y_true: integer class indices with shape ``y_pred.shape[:-1]``.

    Returns:
        Scalar tensor holding the loss summed over all leading positions.
    """
    # The one-hot width must follow the logits, not the labels that happen to
    # be present: inferring it from ``y_true`` yields a narrower matrix when
    # the highest class is missing from the batch, which then broadcasts
    # against ``y_pred`` and silently produces a wrong loss.
    y_true = F.one_hot(y_true.long(), num_classes=y_pred.shape[-1])
    y_true = y_true.float().detach()
    y_pred = (1 - 2 * y_true) * y_pred  # flip the sign of positive-class logits
    y_pred_neg = y_pred - y_true * 1e12  # mask the pred outputs of pos classes
    y_pred_pos = y_pred - (1 - y_true) * 1e12  # mask the pred outputs of neg classes
    # A zero logit is appended so each logsumexp includes the "+1" threshold term.
    zeros = torch.zeros_like(y_pred[..., :1])
    y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
    y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
    neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
    pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
    return (neg_loss + pos_loss).sum()

## Layers

In [None]:
class Biaffine(nn.Module):
    """Biaffine scorer producing ``n_out`` scores for every (x, y) position pair.

    The weight has shape ``(n_in + bias_x, n_out, n_in + bias_y)``; when a
    bias flag is set, a constant-one feature is appended to that input.
    """

    def __init__(self, n_in, n_out=1, bias_x=True, bias_y=True):
        super().__init__()
        self.n_in, self.n_out = n_in, n_out
        self.bias_x, self.bias_y = bias_x, bias_y
        shape = (n_in + int(bias_x), n_out, n_in + int(bias_y))
        initial = nn.init.xavier_normal_(torch.zeros(shape))
        self.weight = nn.Parameter(initial, requires_grad=True)

    @staticmethod
    def _with_ones(features):
        """Append a constant-one feature along the last dimension."""
        ones = torch.ones_like(features[..., :1])
        return torch.cat((features, ones), -1)

    def forward(self, x, y):
        """Return scores of shape ``[batch, len_x, len_y, n_out]``."""
        left = self._with_ones(x) if self.bias_x else x
        right = self._with_ones(y) if self.bias_y else y
        # The output axis is always kept, even when n_out == 1.
        return torch.einsum('bxi,ioj,byj->bxyo', left, self.weight, right)
class MLPla(nn.Module):
    """LayerNorm -> Linear -> activation.

    ``active_fun`` selects the activation by name ('GELU', 'ReLU', 'Softplus'
    or 'Sigmoid'); any other value, including ``None``, selects ``nn.Softmax()``.
    """

    def __init__(self,num_hiddens,out_size,active_fun=None):
        super().__init__()
        norm = nn.LayerNorm(num_hiddens)
        projection = nn.Linear(num_hiddens,out_size)
        if active_fun=='GELU':
            activation = nn.GELU()
        elif active_fun=='ReLU':
            activation = nn.ReLU()
        elif active_fun=='Softplus':
            activation = nn.Softplus()
        elif active_fun=='Sigmoid':
            activation = nn.Sigmoid()
        else:
            activation = nn.Softmax()
        self.mlplayer = nn.Sequential(norm, projection, activation)

    def forward(self,X):
        """Apply the normalised projection and activation to ``X``."""
        return self.mlplayer(X)
class MLP(nn.Module):
    """Linear -> activation.

    ``active_fun`` selects the activation by name ('GELU', 'ReLU', 'Softplus'
    or 'Sigmoid'); any other value, including ``None``, selects ``nn.Softmax()``.
    """

    def __init__(self,num_hiddens,out_size,active_fun=None):
        super().__init__()
        projection = nn.Linear(num_hiddens,out_size)
        if active_fun=='GELU':
            activation = nn.GELU()
        elif active_fun=='ReLU':
            activation = nn.ReLU()
        elif active_fun=='Softplus':
            activation = nn.Softplus()
        elif active_fun=='Sigmoid':
            activation = nn.Sigmoid()
        else:
            activation = nn.Softmax()
        self.mlplayer = nn.Sequential(projection, activation)

    def forward(self,X):
        """Apply the projection and activation to ``X``."""
        return self.mlplayer(X)
class MLPnorm(nn.Module):
    """BatchNorm over the flattened non-batch dims, then Linear and activation.

    The batch-norm width is fixed at ``num_hiddens * 127``, so a flattened
    sample must contain exactly that many values.
    """

    def __init__(self,num_hiddens,out_size,active_fun=None):
        super().__init__()
        self.banorm = nn.BatchNorm1d(num_hiddens*127)
        self.linear = nn.Linear(num_hiddens,out_size)
        if active_fun=='GELU':
            activation = nn.GELU()
        elif active_fun=='ReLU':
            activation = nn.ReLU()
        elif active_fun=='Softplus':
            activation = nn.Softplus()
        elif active_fun=='Sigmoid':
            activation = nn.Sigmoid()
        else:
            activation = nn.Softmax()
        self.active = activation

    def forward(self,X):
        """Normalise ``X`` (3-D or 4-D), project its last dim and activate."""
        if X.dim()==3:
            batch, length, hidden = X.size()
            restored_shape = (batch, length, hidden)
        else:
            # Four-way unpack: the restored view repeats the third dimension
            # for both grid axes.
            batch, _, length, hidden = X.size()
            restored_shape = (batch, length, length, hidden)
        normalised = self.banorm(X.view(batch, -1)).view(*restored_shape)
        return self.active(self.linear(normalised))

In [None]:
class ConvolutionLayer(nn.Module):
    """Convolution layer over a token-pair grid.

    A 1x1 convolution maps ``input_size`` features to ``channels``; a chain of
    depthwise 11x11 convolutions (one per entry of ``dilation``) follows, each
    applied to the previous one's output. The activations of every stage are
    concatenated along the channel axis.

    Args:
        input_size: number of input features per grid cell.
        channels: width of each convolution stage.
        dilation: dilation rate of each depthwise stage (any iterable of ints).
        dropout: probability for the ``nn.Dropout2d`` applied to the input.
    """

    # The default is an immutable tuple so it cannot be shared and mutated
    # across instances; lists are still accepted.
    def __init__(self, input_size, channels=96, dilation=(1, 2, 3), dropout=0.0):
        super(ConvolutionLayer, self).__init__()

        self.base = nn.Sequential(
            nn.Dropout2d(dropout),
            nn.Conv2d(input_size, channels, kernel_size=1),
            nn.GELU(),
        )

        # padding = 5 * d keeps the spatial size unchanged for an 11x11 kernel
        # with dilation d.
        self.convs = nn.ModuleList(
            [nn.Conv2d(channels, channels, kernel_size=11, groups=channels, dilation=d, padding=5*d) for d in dilation])

    def forward(self, x):
        """Map ``[batch, H, W, input_size]`` to ``[batch, H, W, channels * len(dilation)]``."""
        x = x.permute(0, 3, 1, 2).contiguous()  # channels-first for Conv2d
        x = self.base(x)
        outputs = []
        for conv in self.convs:
            # Stages are chained: each consumes the previous stage's output.
            x = conv(x)
            x = F.gelu(x)
            outputs.append(x)
        outputs = torch.cat(outputs, dim=1)
        outputs = outputs.permute(0, 2, 3, 1).contiguous()  # back to channels-last
        return outputs
    
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension, optionally conditional.

    In conditional mode the shift (``beta``) and scale (``gamma``) are the
    learned vectors plus a linear projection of a condition tensor. Those
    projections are zero-initialised, so a fresh conditional layer behaves
    like a plain layer norm.

    Args:
        input_dim: ``inputs.shape[-1]``.
        cond_dim: ``cond.shape[-1]`` (conditional mode only).
        center: subtract the mean and add ``beta``.
        scale: divide by the standard deviation and multiply by ``gamma``.
        epsilon: numerical-stability term; falls back to ``1e-12``.
        conditional: enable the condition-dependent ``beta``/``gamma``.
        hidden_units: if set, the condition is first projected to this width.
        hidden_activation: accepted for API compatibility; not used here.
        hidden_initializer: 'normal' or 'xavier' initialisation of the hidden
            projection. The historical misspelling 'xaiver' (the default) is
            accepted as an alias of 'xavier'.
    """

    def __init__(self, input_dim, cond_dim=0, center=True, scale=True, epsilon=None, conditional=False,
                 hidden_units=None, hidden_activation='linear', hidden_initializer='xaiver', **kwargs):
        super(LayerNorm, self).__init__()
        self.center = center
        self.scale = scale
        self.conditional = conditional
        self.hidden_units = hidden_units
        self.hidden_initializer = hidden_initializer
        self.epsilon = epsilon or 1e-12
        self.input_dim = input_dim
        self.cond_dim = cond_dim

        if self.center:
            self.beta = nn.Parameter(torch.zeros(input_dim))
        if self.scale:
            self.gamma = nn.Parameter(torch.ones(input_dim))

        if self.conditional:
            # beta_dense / gamma_dense consume the condition *after* the
            # optional hidden projection, so their input width must follow it.
            dense_in = self.cond_dim
            if self.hidden_units is not None:
                self.hidden_dense = nn.Linear(in_features=self.cond_dim, out_features=self.hidden_units, bias=False)
                dense_in = self.hidden_units
            if self.center:
                self.beta_dense = nn.Linear(in_features=dense_in, out_features=input_dim, bias=False)
            if self.scale:
                self.gamma_dense = nn.Linear(in_features=dense_in, out_features=input_dim, bias=False)

        self.initialize_weights()

    def initialize_weights(self):
        """Initialise the conditional projections (no-op in plain mode)."""
        if self.conditional:
            if self.hidden_units is not None:
                if self.hidden_initializer == 'normal':
                    torch.nn.init.normal_(self.hidden_dense.weight)
                elif self.hidden_initializer in ('xavier', 'xaiver'):  # glorot_uniform
                    torch.nn.init.xavier_uniform_(self.hidden_dense.weight)

            # Zero init: the condition initially contributes nothing.
            if self.center:
                torch.nn.init.constant_(self.beta_dense.weight, 0)
            if self.scale:
                torch.nn.init.constant_(self.gamma_dense.weight, 0)

    def forward(self, inputs, cond=None):
        """Normalise ``inputs``; ``cond`` is required when ``conditional`` is set."""
        if self.conditional:
            if self.hidden_units is not None:
                cond = self.hidden_dense(cond)

            # Insert singleton axes after the batch axis until cond has the
            # same rank as inputs, so it broadcasts over the remaining axes.
            for _ in range(len(inputs.shape) - len(cond.shape)):
                cond = cond.unsqueeze(1)

            if self.center:
                beta = self.beta_dense(cond) + self.beta
            if self.scale:
                gamma = self.gamma_dense(cond) + self.gamma
        else:
            if self.center:
                beta = self.beta
            if self.scale:
                gamma = self.gamma

        outputs = inputs
        if self.center:
            mean = torch.mean(outputs, dim=-1).unsqueeze(-1)
            outputs = outputs - mean
        if self.scale:
            variance = torch.mean(outputs ** 2, dim=-1).unsqueeze(-1)
            std = (variance + self.epsilon) ** 0.5
            outputs = outputs / std
            outputs = outputs * gamma
        if self.center:
            outputs = outputs + beta

        return outputs

### InformationFusionLayer

In [None]:
class InformationFusionLayer(nn.Module):
    """Fuse trigger and event-type information into the token representations.

    The trigger tokens of each sample are max-pooled into one vector, which is
    scored against every type embedding. The selected type embedding (passed
    through an MLP) and the pooled trigger vector are concatenated and used
    as the condition of a conditional layer norm over the tokens.

    NOTE(review): ``type_hiddens`` has to equal the token embedding width,
    because the pooled trigger vector is dotted with the type embeddings and
    the type embedding is expanded to the token shape.
    """

    def __init__(self,type_size,num_hiddens,type_hiddens=768):
        super(InformationFusionLayer,self).__init__()
        self.type_embedding=nn.Embedding(type_size,type_hiddens)

        self.clnlayer=LayerNorm(num_hiddens,2*num_hiddens,conditional=True)

        self.type_embedding.weight.requires_grad = True

        self.mlp=MLP(type_hiddens,type_hiddens)

    def forward(self,X,overlap_ids=None,attention_mask=None):
        """Return ``(fused_tokens, None)``.

        Args:
            X: token representations ``[bs, length, emb]``; position 0 is dropped.
            overlap_ids: ``[bs, length-1]`` integer ids; values > 0 mark trigger
                tokens and a value > 1 anywhere in a row selects the
                second-best type for that sample.
            attention_mask: accepted for interface compatibility; unused.

        Raises:
            ValueError: if a sample has no trigger token.
        """
        X_H = X[:, 1:]  # drop the first position
        # Only strictly positive ids are triggers. ``overlap_ids.bool()`` would
        # also be True for the -1 padding value and pull padding tokens into
        # the pooled vector. This changes the pooled vector compared with the
        # previous behaviour, so checkpoints trained with the old mask may
        # need retraining.
        tri_mask = overlap_ids > 0
        if not bool(tri_mask.any(dim=-1).all()):
            raise ValueError("every sample needs at least one trigger token (overlap id > 0)")

        # Max-pool the trigger tokens of every sample in one shot: non-trigger
        # positions are pushed to -inf so they can never win the max.
        masked = X_H.masked_fill(~tri_mask.unsqueeze(-1), float('-inf'))
        trigger_word_vecs = masked.max(dim=1).values  # [bs, emb]

        score = torch.einsum('be,te->bt', trigger_word_vecs, self.type_embedding.weight)  # [bs, ts]
        _, max_indices = score.topk(k=2, dim=-1)  # [bs, 2]

        # Rows containing an overlap id > 1 take the second-best type,
        # all other rows the best one.
        has_overlap = overlap_ids.max(dim=-1).values > 1
        type_idx = torch.where(has_overlap, max_indices[:, 1], max_indices[:, 0])
        type_emb = self.mlp(self.type_embedding(type_idx))  # [bs, emb]

        type_emb = type_emb.unsqueeze(1).expand_as(X_H)
        trigger_cond = trigger_word_vecs.unsqueeze(1).expand_as(X_H)
        cond = torch.cat((trigger_cond, type_emb), dim=-1)  # [bs, length-1, 2*emb]
        outputs_tr_ty = self.clnlayer(X_H, cond)  # [bs, length-1, emb]
        return outputs_tr_ty, None

### JointPredictionLayer

In [None]:
def get_arg_pe(L):
   
    _arg_rpe = torch.zeros((L, L),dtype=torch.float32)
    for k in range(L):
        _arg_rpe[k, :] += k
        _arg_rpe[:, k] -= k
    P = torch.zeros((L,L))
    for i in range(L):
        for j in range(L):
            if j>i:
                v =_arg_rpe[i][j]-9
            elif j==i:
                v=_arg_rpe[i][j]-19
            else:
                v=_arg_rpe[i][j]
            w= 1 / (1000 ** (v / L))
            if v % 2 == 0:
                P[i,j] =torch.sin(w * i)
            else:
                P[i,j] = torch.cos(w * i)
    return P

class JointPredictionLayer(nn.Module):
    """Score every token pair for span labels and for relation labels.

    For each task two branches are computed and summed:
      * a grid branch: conditional layer norm over token pairs, concatenated
        with distance and region embeddings, passed through ``ConvolutionLayer``,
        an MLP and a linear classifier;
      * a biaffine branch over per-token start/end representations that are
        augmented with the fixed table returned by ``get_arg_pe``.

    NOTE(review): the sequence length is hard-coded as 127 in several layer
    widths and in the positional table — presumably ``max_seq_length - 1``
    with ``max_seq_length == 128``; confirm before changing either.
    NOTE(review): the dropout probability is read from the module-level
    ``args`` object, which must exist before this layer is constructed.

    Args:
        num_hiddens: width of the incoming token representations.
        span_size: number of span classes.
        relation_size: number of relation classes.
        active_fun: activation name forwarded to the MLP blocks.
        mlp_name: name of the MLP class to instantiate, looked up in ``globals()``.
    """

    def __init__(self,num_hiddens,span_size,relation_size,active_fun,mlp_name=None):
        super().__init__()
        mlp_cls = globals()[mlp_name]

        self.dis_embs = nn.Embedding(20, 20)
        self.reg_embs = nn.Embedding(3, 20)
        self.reg_embs1 = nn.Embedding(4, 20)
        self.span_size=span_size
        self.relation_size=relation_size

        self.span_cln = LayerNorm(num_hiddens,num_hiddens,conditional=True)
        self.relation_cln = LayerNorm(num_hiddens,num_hiddens,conditional=True)

        self.sconvLayer = ConvolutionLayer(20+20+num_hiddens,96)
        self.rconvLayer = ConvolutionLayer(20+20+num_hiddens,96)

        self.S1_mlplayer=mlp_cls(num_hiddens,4,active_fun='Sigmoid')
        self.spans_mlpayer=mlp_cls(num_hiddens+2,512,active_fun)
        self.spane_mlpayer=mlp_cls(num_hiddens+2,512,active_fun)

        self.R1_mlplayer=mlp_cls(num_hiddens,4,active_fun='Sigmoid')
        self.relations_mlpayer=mlp_cls(num_hiddens+4,512,active_fun)
        self.relatione_mlpayer=mlp_cls(num_hiddens+127,512,active_fun)

        self.spe1_mlpayer=mlp_cls(127+512,512,active_fun)
        self.spe2_mlpayer=mlp_cls(127+512,512,active_fun)
        self.rpe1_mlpayer=mlp_cls(127+512,512,active_fun)
        self.rpe2_mlpayer=mlp_cls(127+512,512,active_fun)

        self.SCLN_mlplayer=mlp_cls(3*96,288,active_fun)
        self.RCLN_mlplayer=mlp_cls(3*96,288,active_fun)

        self.S_biaffine=Biaffine(512,span_size)
        self.R_biaffine=Biaffine(512,relation_size)
        self.W_relation = nn.Parameter(torch.randn(num_hiddens,num_hiddens),requires_grad = True)
        self.dropout=nn.Dropout(p=args.dropout_p)
        self.sigmoid=nn.Sigmoid()

        # The third positional argument of nn.Linear is ``bias``, not an
        # activation. Passing ``active_fun`` there only kept the bias because
        # a non-empty string is truthy, and silently dropped it for ``None``.
        self.Wr=nn.Linear(288,relation_size)
        self.Ws=nn.Linear(288,span_size)
        self.P=get_arg_pe(L=127)

    def forward(self,X,Y,_dist_inputs,orlp_ids,span_labels):
        """Return ``(span_logits, relation_logits)`` for every token pair.

        Args:
            X, Y: token representations ``[bs, length, emb]`` for the span and
                relation grid branches.
            _dist_inputs: pairwise distance-bucket ids (indices into ``dis_embs``).
            orlp_ids: per-token ids; tokens with a value > 0 mark their grid
                column with region id 3 in the relation branch.
            span_labels: pairwise labels; ``-1`` marks padded cells, which are
                zeroed before and after the convolutions.
        """
        pad_mask = span_labels==-1
        span_grid_mask2d = span_labels!=-1

        # Region ids: 0 = padding, 1 = valid upper triangle,
        # 2 = valid lower triangle including the diagonal.
        grid_long = span_grid_mask2d.long()
        reg_inputs = torch.tril(grid_long) + grid_long

        reg_emb = self.reg_embs(reg_inputs)
        dis_emb = self.dis_embs(_dist_inputs)

        # ---- span grid branch -------------------------------------------
        span_cln=self.span_cln(X.unsqueeze(2),X)
        span_conv_inputs = torch.cat([dis_emb, reg_emb, span_cln], dim=-1)
        span_conv_inputs[pad_mask]=0
        span_conv_outputs=self.sconvLayer(span_conv_inputs)
        span_conv_outputs[pad_mask]=0
        cln_span=self.SCLN_mlplayer(self.dropout(span_conv_outputs))
        cln_span_logits=self.Ws(cln_span)

        # ---- relation grid branch ---------------------------------------
        orlp_inputs=orlp_ids.unsqueeze(1).expand_as(span_grid_mask2d).clone()

        # Same region ids as above, with an extra id 3 for columns whose
        # token has orlp id > 0; padding cells are reset to 0 afterwards.
        reg_inputs1 = reg_inputs.clone()
        reg_inputs1[orlp_inputs>0]=3
        reg_inputs1[pad_mask]=0

        reg_emb1=self.reg_embs1(reg_inputs1)
        relation_cln=self.relation_cln(Y.unsqueeze(2),Y)

        relation_conv_inputs = torch.cat([dis_emb,reg_emb1,relation_cln], dim=-1)
        relation_conv_inputs[pad_mask]=0
        relation_conv_outputs=self.rconvLayer(relation_conv_inputs)
        relation_conv_outputs[pad_mask]=0
        cln_relation=self.RCLN_mlplayer(self.dropout(relation_conv_outputs))
        cln_relation_logits=self.Wr(cln_relation)

        # ---- biaffine branch: spans -------------------------------------
        # The fixed 127x127 table is shared by every sample in the batch.
        P=self.P.unsqueeze(0).expand(X.shape[0], -1, -1)
        rp_emb=P.to(X.device)
        X_args=self.S1_mlplayer(self.dropout(X))  # [..., 4]

        X_arg_start, X_arg_end = torch.chunk(X_args, 2, dim=-1)  # [..., 2] each

        X_span_start=torch.cat((X,X_arg_start),dim=-1)  # [..., emb+2]
        X_span_end=torch.cat((X,X_arg_end),dim=-1)

        X_info_span1=self.spans_mlpayer(self.dropout(X_span_start))  # [..., 512]
        X_info_span2=self.spane_mlpayer(self.dropout(X_span_end))
        X_info_span1=torch.cat((X_info_span1,rp_emb),dim=-1)  # [..., 512+127]
        X_info_span2=torch.cat((X_info_span2,rp_emb),dim=-1)
        X_info_s1=self.spe1_mlpayer(X_info_span1)
        X_info_s2=self.spe2_mlpayer(X_info_span2)

        # ---- biaffine branch: relations ---------------------------------
        Y_args=self.R1_mlplayer(self.dropout(Y))

        Y_arg_start, Y_arg_end = torch.chunk(Y_args, 2, dim=-1)

        X_info_arg=torch.cat((X,Y_arg_start,Y_arg_end),dim=-1)  # [..., emb+2+2]

        Y_info_real1=self.relations_mlpayer(self.dropout(X_info_arg))  # [..., 512]
        Y_info_real1=torch.cat((Y_info_real1,rp_emb),dim=-1)
        Y_info_r1=self.rpe1_mlpayer(Y_info_real1)

        # Bilinear token-to-token affinity, used as extra per-token features.
        Y_info_arg=torch.einsum('bxi,ij,byj->bxy', X, self.W_relation, X)  # [bs, length, length]
        Y_info_r=self.sigmoid(self.dropout(Y_info_arg))
        Y_info_r=torch.cat((X,Y_info_r),dim=-1)  # [..., emb+127]

        Y_info_real2=self.relatione_mlpayer(self.dropout(Y_info_r))  # [..., 512]
        Y_info_real2=torch.cat((Y_info_real2,rp_emb),dim=-1)
        Y_info_r2=self.rpe2_mlpayer(Y_info_real2)

        # Biaffine classifiers
        bilinar_span_logits=self.S_biaffine(X_info_s1,X_info_s2)
        bilinar_relation_logits=self.R_biaffine(Y_info_r1,Y_info_r2)

        # span_logits: [bs, length, length, span_size]
        # relation_logits: [bs, length, length, relation_size]
        return bilinar_span_logits+cln_span_logits,bilinar_relation_logits+cln_relation_logits

## Model

In [None]:
class BERTETANET(BertPreTrainedModel):
    """BERT encoder followed by information fusion and joint span/relation prediction.

    Args:
        config: BERT configuration.
        num_hiddens: encoder hidden width.
        type_size: number of type embeddings in the fusion layer.
        span_size: number of span classes.
        relation_size: number of relation classes.
        loss_name: stored on the module; not used by the code in this class.
        active_fun: activation name forwarded to the MLP blocks.
        mlp_name: name of the MLP class to instantiate, looked up in ``globals()``.
        weights: ``(span_weight, relation_weight)`` applied to the two losses;
            ``None`` means ``(1.0, 1.0)``.
    """

    def __init__(self,config,num_hiddens,type_size,span_size,relation_size,loss_name,active_fun=None,mlp_name=None,weights=None):
        super().__init__(config)
        self.bert=BertModel.from_pretrained('bert-base-multilingual-cased',config=config)

        self.tri_rpe_embed=nn.Embedding(128,20)

        self.mlplayer=globals()[mlp_name](num_hiddens+20,num_hiddens,active_fun)

        self.infofulayer=InformationFusionLayer(type_size,num_hiddens)

        self.predlayer=JointPredictionLayer(num_hiddens,span_size,relation_size,active_fun,mlp_name)
        self.span_size=span_size
        self.relation_size=relation_size
        self.loss_name=loss_name
        # The ``None`` default cannot be unpacked; fall back to unit weights.
        self.s_weight,self.r_weight=weights if weights is not None else (1.0, 1.0)

    def forward(self,input_ids=None,attention_mask=None,token_type_ids=None, position_ids=None,
                head_mask=None, inputs_embeds=None,span_labels=None,relation_labels=None,overlap_ids=None,_dist_inputs=None,_tir_rpe=None):
        """Run the model.

        Returns:
            ``(loss, span_logits, relation_logits)`` when ``span_labels`` is
            given, otherwise ``(span_logits, relation_logits)``.

        NOTE(review): the prediction layer builds its padding mask from
        ``span_labels``, so in practice labels (with -1 padding) appear to be
        required even at inference time — confirm against the evaluation code.
        """
        bert_embs =self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        outputs = bert_embs[0]  # first element of the BERT output tuple

        info_outs,tri_logits = self.infofulayer(outputs,overlap_ids,attention_mask)

        # Boolean mask of non-padding tokens, broadcastable over the feature axis.
        orlp_ids=(overlap_ids!=-1).clone().unsqueeze(-1)

        loss_t=None
        if tri_logits is not None:
            orlp_ids[orlp_ids>0]=1
            tri_logits=tri_logits[orlp_ids!=-1]
            orlp_ids=orlp_ids[orlp_ids!=-1]
            loss_t=multilabel_categorical_crossentropy(tri_logits,orlp_ids)

        # Append the embedded trigger-relative position to every token.
        tri_rpe_input=self.tri_rpe_embed(_tir_rpe)
        outputs_tri_pe=torch.cat((info_outs,tri_rpe_input),dim=-1)

        # Padding tokens are zeroed before and after the MLP.
        outputs_tri=outputs_tri_pe*orlp_ids
        outputs1=self.mlplayer(outputs_tri)
        outputs2=outputs1*orlp_ids

        span_logits,relation_logits=self.predlayer(outputs2,outputs2,_dist_inputs,overlap_ids,span_labels)

        if span_labels is not None:
            # Masks are built only once labels are known to be present.
            span_label_mask = span_labels != -1
            relation_label_mask = relation_labels != -1

            span_outs = span_logits[span_label_mask]
            span_label = span_labels[span_label_mask]

            relation_outs = relation_logits[relation_label_mask]
            relation_label = relation_labels[relation_label_mask]

            loss_s=multilabel_categorical_crossentropy(span_outs,span_label)
            loss_r=multilabel_categorical_crossentropy(relation_outs,relation_label)

            if loss_t is not None:
                loss = loss_t+loss_s+loss_r
            else:
                # The relation loss is scaled by its own weight; the span
                # weight used to be applied to both terms.
                loss=self.s_weight*loss_s+self.r_weight*loss_r
            return (loss,)+(span_logits,relation_logits)

        return (span_logits,relation_logits)


## DataUtils

In [None]:
class InputExample:
    """One raw example: id, text, per-token overlap ids and the two label grids."""

    def __init__(self,example_id,content,overlap_id,span_label,relation_label,):
        self.example_id, self.content = example_id, content
        self.overlap_id = overlap_id
        self.span_label, self.relation_label = span_label, relation_label
        
class InputFeatures:
    """Model-ready features of one example (token ids, masks, labels, position inputs)."""

    def __init__(self,example_id,input_ids,attention_mask,overlap_id,span_label,relation_label,_dist_inputs,_tir_rpe,):
        self.example_id = example_id
        self.attention_mask, self.input_ids = attention_mask, input_ids
        self.overlap_id = overlap_id
        self.span_label, self.relation_label = span_label, relation_label
        self._dist_inputs, self._tir_rpe = _dist_inputs, _tir_rpe
        
class DataProcessor:
    """Abstract interface for dataset converters; every method must be overridden."""

    def get_train_examples(self, data_dir):
        """Return the `InputExample`s of the train set."""
        raise NotImplementedError

    def get_dev_examples(self, data_dir):
        """Return the `InputExample`s of the dev set."""
        raise NotImplementedError

    def get_test_examples(self, data_dir):
        """Return the `InputExample`s of the test set."""
        raise NotImplementedError

    def get_labels(self):
        """Return the label lists of this data set."""
        raise NotImplementedError

    def get_MI_examples(self,data_dir):
        """Not implemented in the base class."""
        raise NotImplementedError

    def _create_examples(self):
        """Build `InputExample`s from parsed records; subclasses define the arguments."""
        raise NotImplementedError
        
class FewFCProcessor(DataProcessor):
    """Processor for the FewFC data stored as JSON-lines label files."""

    @staticmethod
    def _read_json_lines(path):
        """Parse one JSON object per non-blank line of ``path``."""
        # The context manager closes the handle even if parsing fails.
        with open(path, "r", encoding="utf-8") as handle:
            return [json.loads(line) for line in handle if line.strip()]

    @staticmethod
    def _to_dense(sparse_label):
        """Turn a stored ``[indices, values, size]`` triple into a dense tensor."""
        indices, values, size = sparse_label[0], sparse_label[1], sparse_label[2]
        return torch.sparse_coo_tensor(indices, values, size).to_dense()

    def get_train_examples(self,data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        datas = self._read_json_lines(os.path.join(data_dir,'train_labels.json'))
        return self._create_examples(datas,'train')

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        datas = self._read_json_lines(os.path.join(data_dir,'dev_labels.json'))
        return self._create_examples(datas,'dev')

    def get_test_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the test set (not available)."""
        # ``raise None`` is itself a TypeError; signal the missing split the
        # same way the base class does.
        raise NotImplementedError()

    def get_labels(self):
        """Gets the span label list and the relation label list."""
        return ['None','S','T'],['None','R']

    def _create_examples(self,lines,set_type):
        """Convert parsed records into `InputExample`s.

        Each record provides ``id``, ``content``, ``overlap_ids`` and the
        sparse ``span_label`` / ``relation_label`` triples.
        """
        examples=[]
        for data in lines:
            examples.append(InputExample(
                example_id="%s-%s" % (set_type, str(data['id'])),
                content=data['content'],
                overlap_id=data['overlap_ids'],
                span_label=self._to_dense(data['span_label']),
                relation_label=self._to_dense(data['relation_label']),
            ))
        return examples

In [None]:
# Relative-distance buckets: dis2idx[d] is the bucket id of distance d.
# Distance 0 stays in bucket 0; bucket k (1..9) starts at distance 2**(k-1).
dis2idx = np.zeros(1000, dtype='int64')
for _bucket, _start in enumerate((1, 2, 4, 8, 16, 32, 64, 128, 256), start=1):
    dis2idx[_start:] = _bucket
def convert_examples_to_features(examples,label_list,max_length,tokenizer):
    """Turn `InputExample`s into padded `InputFeatures`.

    Token ids are padded to ``max_length`` by the tokenizer; every other
    field is padded to ``max_length - 1`` (the first position is left out).

    Args:
        examples: iterable of `InputExample`.
        label_list: accepted for interface compatibility; unused.
        max_length: padded token length.
        tokenizer: object providing ``encode_plus``.

    Raises:
        ValueError: if an example has no trigger token (overlap id > 0).
    """
    features = []

    for example in tqdm(examples,desc="convert examples to features"):
        inputs=tokenizer.encode_plus(example.content,add_special_tokens=True,max_length=max_length,padding='max_length',truncation=True)
        inputs_ids,attention_mask=inputs['input_ids'],inputs['attention_mask']

        overlap_id=torch.tensor(example.overlap_id)
        length=len(overlap_id)
        pad_len=max_length-length-1

        # Relative position of every token with respect to the trigger span:
        # 0 inside the span, growing with the distance on either side.
        trigger_positions=torch.argwhere(overlap_id>0)
        if trigger_positions.numel()==0:
            raise ValueError("example %s has no trigger token (overlap id > 0)" % example.example_id)
        _tri_start,_tri_end=trigger_positions.min().item(),trigger_positions.max().item()
        positions=torch.arange(length)
        _tir_rpe=(_tri_start-positions).clamp(min=0)+(positions-_tri_end).clamp(min=0)
        _tir_rpe=F.pad(_tir_rpe,pad=(0,pad_len),value=max_length-1)

        overlap_id=F.pad(overlap_id,pad=(0,pad_len),value=-1)

        # Pairwise distance buckets: offsets[i, j] = i - j. Non-negative
        # offsets map to dis2idx, negative ones to dis2idx + 9, and the
        # remaining zeros (the diagonal) to 19.
        # The zero -> 19 rewrite is applied once, after the whole grid is
        # bucketed. Applying it after every row turned the not-yet-processed
        # diagonal cells into raw distance 19, which was then bucketed as
        # dis2idx[19] == 5. Features cached by the old code should be rebuilt.
        offsets=np.arange(length)[:,None]-np.arange(length)[None,:]
        _dist_inputs=dis2idx[np.abs(offsets)]+np.where(offsets<0,9,0)
        _dist_inputs[_dist_inputs==0]=19
        _dist_inputs=F.pad(torch.as_tensor(_dist_inputs),pad=(0,pad_len,0,pad_len),value=0)

        # Label grids are padded with -1 on the right and at the bottom.
        span_pad=max_length-len(example.span_label)-1
        extend_label_s=F.pad(example.span_label,pad=(0,span_pad,0,span_pad),value=-1)

        relation_label=example.relation_label.transpose(1,0)
        relation_pad=max_length-len(example.relation_label)-1
        extend_label_r=F.pad(relation_label,pad=(0,relation_pad,0,relation_pad),value=-1)

        features.append(InputFeatures(example.example_id,inputs_ids,attention_mask,overlap_id,
                                      extend_label_s,extend_label_r,_dist_inputs,_tir_rpe,))
    return features
        

In [None]:
def load_and_cache_examples(args, task, tokenizer, evaluate=False, test=False):
    """Load (or build and cache) the features of one split as a `TensorDataset`.

    Args:
        args: run configuration; reads ``model_name_or_path``,
            ``overwrite_cache``, ``data_dir``, ``max_seq_length`` and
            ``local_rank``.
        task: key into the module-level ``processors`` mapping.
        tokenizer: tokenizer forwarded to ``convert_examples_to_features``.
        evaluate: load the dev split.
        test: load the test split (only consulted when ``evaluate`` is false).

    Returns:
        The dataset, or ``(dataset, example_ids)`` for the dev/test splits.
    """
    def _stack_long(items):
        """Stack per-example tensors into one long tensor (empty-safe)."""
        if not items:
            return torch.empty(0, dtype=torch.long)
        return torch.stack([torch.as_tensor(item) for item in items]).to(torch.long)

    processor = processors[task]()
    if evaluate:
        cached_mode = "dev"
    elif test:
        cached_mode = "test"
    else:
        cached_mode = "train"

    # Cache file name: split, last path component of the model name, task.
    cached_features_file = os.path.join(
        features_dir,
        "cached_{}_{}_{}".format(
            cached_mode,
            list(filter(None, args.model_name_or_path.split("/"))).pop(),
            str(task),
        ),
    )

    if os.path.exists(cached_features_file) and not args.overwrite_cache:
        logger.info("Loading features from cached file %s", cached_features_file)
        # NOTE(review): torch.load unpickles arbitrary objects; only point
        # this at cache files produced by this code.
        features = torch.load(cached_features_file)
    else:
        print("Creating features from dataset file at %s" % args.data_dir)
        logger.info("Creating features from dataset file at %s", args.data_dir)
        label_list = processor.get_labels()
        if evaluate:
            examples = processor.get_dev_examples(args.data_dir)
        elif test:
            examples = processor.get_test_examples(args.data_dir)
        else:
            examples = processor.get_train_examples(args.data_dir)

        logger.info("Training number: %s", str(len(examples)))
        features=convert_examples_to_features(examples,label_list,args.max_seq_length,tokenizer)
        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save(features, cached_features_file)

    # Convert to tensors and build the dataset. Tensor-valued fields are
    # stacked directly instead of being round-tripped through nested Python
    # lists, which is very slow for the per-example label grids.
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_overlap_ids = _stack_long([f.overlap_id for f in features])
    all_span_labels = _stack_long([f.span_label for f in features])
    all_dist_inputs = _stack_long([f._dist_inputs for f in features])
    all_tir_rpe = _stack_long([f._tir_rpe for f in features])
    all_relation_labels = _stack_long([f.relation_label for f in features])

    dataset = TensorDataset(all_input_ids, all_attention_mask, all_overlap_ids,
                            all_span_labels,all_relation_labels,all_dist_inputs,all_tir_rpe,)
    if evaluate or test:
        return dataset, [f.example_id for f in features]
    return dataset

## Train

In [None]:
def set_seed(args):
    """Seed the Python, NumPy and PyTorch RNGs from ``args.seed``.

    CUDA generators are seeded as well when ``args.n_gpu > 0``; cuDNN
    benchmarking is switched off and deterministic mode switched on.
    """
    seed = args.seed
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
        
def train(config,args, train_dataset, tokenizer):
    """Train a ``BERTETANET`` model and report progress to Weights & Biases.

    Builds the model, the optimizer (two parameter groups: with and without
    weight decay) and a linear warmup/decay schedule from ``args``, then runs
    the epoch loop with optional gradient accumulation and apex fp16.  Every
    ``args.logging_steps`` optimizer steps the running training loss and the
    learning rate are logged to wandb, together with dev-set metrics from
    ``evaluate`` when ``args.evaluate_during_training`` is set.

    Args:
        config: configuration object passed as the first argument to
            ``BERTETANET``.
        args: namespace with hyper-parameters and runtime settings.  It is
            mutated: ``train_batch_size`` is set, and ``num_train_epochs`` is
            recomputed when ``max_steps > 0``.
        train_dataset: dataset whose items unpack into seven tensors, in the
            order input_ids, attention_mask, overlap_ids, span_labels,
            relation_labels, _dist_inputs, _tir_rpe.
        tokenizer: forwarded unchanged to ``evaluate``.

    Returns:
        None.  Metrics are reported through wandb only.
    """
    # Seed before the model is built so that freshly initialised layers are
    # reproducible as well (previously seeding only happened right before the
    # epoch loop, after the model had already been created).
    set_seed(args)

    # NOTE(review): the literals 3 and 2 presumably mirror args.span_size and
    # args.relation_size -- confirm against BERTETANET's signature.
    model = BERTETANET(config,args.hidden_size,args.type_num,3,2,args.loss_name,args.active_fun,args.mlp_name,(args.loss_s_weight,args.loss_r_weight))
    model.to(args.device)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
    else:
        # Imported locally: only distributed runs need it and the notebook's
        # import cell does not bring it into scope.
        from torch.utils.data.distributed import DistributedSampler
        train_sampler = DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Optimizer: biases and LayerNorm weights are excluded from weight decay.
    # NOTE: args.bert_lr is not used here; every parameter shares
    # args.learning_rate (a variant with a separate BERT learning rate and a
    # checkpoint-resume path existed only as disabled code).
    no_decay = ("bias", "LayerNorm.weight")
    decay_params, no_decay_params = [], []
    for param_name, param in model.named_parameters():
        if any(marker in param_name for marker in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    optimizer_grouped_parameters = [
        {"params": decay_params, "lr": args.learning_rate, "weight_decay": args.weight_decay},
        {"params": no_decay_params, "lr": args.learning_rate, "weight_decay": 0.0},
    ]

    # args.optim_type must be the name of an optimizer class in torch.optim.
    optimizer_class = getattr(torch.optim, args.optim_type)
    optimizer = optimizer_class(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)

    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_ratio*t_total, num_training_steps=t_total,
    )
    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") from err
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # Multi-GPU training (must come after apex fp16 initialization).
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    # NOTE(review): the model and optimizer above are built from `args` before
    # wandb.init(), so values chosen by a wandb sweep (wandb.config) do not
    # appear to reach them -- confirm whether `args` should be overridden from
    # wandb.config before the model is constructed.
    wandb.init(project='ETANET',config=args.__dict__)
    model.run_id=wandb.run.id
    wandb.watch(model,log='all')

    global_step = 0
    tr_loss, logging_loss, epoch_loss = 0.0, 0.0, 0.0
    best_dev_aif1 = 0.0
    best_steps = 0

    train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Re-seed right before training for reproducible batch order.

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        model.train()
        batches_seen = 0
        for step, batch in enumerate(epoch_iterator):
            batches_seen += 1
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "overlap_ids": batch[2],
                "span_labels": batch[3],
                "relation_labels":batch[4],
                "_dist_inputs": batch[5],
                "_tir_rpe": batch[6],
            }
            outputs = model(**inputs)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean()  # average the per-GPU losses from DataParallel

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()  # running sum of the (scaled) losses

            if (step + 1) % args.gradient_accumulation_steps == 0:
                # Clip once per optimizer step, on the fully accumulated
                # gradient, rather than after every micro-batch.
                clip_params = amp.master_params(optimizer) if args.fp16 else model.parameters()
                torch.nn.utils.clip_grad_norm_(clip_params, args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # update the learning rate schedule
                model.zero_grad()
                global_step += 1
                current_lr = optimizer.param_groups[0]['lr']
                print('loss:',loss,current_lr)

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    metrics = {
                        'average_train_loss': (tr_loss - logging_loss) / args.logging_steps,
                        'lr': current_lr,
                    }
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        # evaluate() switches the model to eval mode; switch
                        # back so dropout stays active for the rest of the epoch.
                        model.train()
                        if results["eval_ai_f1"] > best_dev_aif1:
                            best_dev_aif1 = results["eval_ai_f1"]
                            best_steps = global_step
                        metrics.update({
                            'eval_ai_f1': results["eval_ai_f1"],
                            'best_eval_ai_f1': best_dev_aif1,
                            'best_steps': best_steps,
                            'eval_loss': results["eval_loss"],
                        })
                    wandb.log(metrics)
                    # Reset the window only when a log entry was written, so
                    # the next average really spans `logging_steps` steps.
                    logging_loss = tr_loss

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        # Mean loss over the batches of this epoch (guarded against an empty
        # dataloader, which would otherwise divide by zero).
        if batches_seen:
            wandb.log({'train_loss': (tr_loss - epoch_loss) / batches_seen})
        epoch_loss = tr_loss
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    wandb.finish()


## Evaluate

In [None]:
def evaluate(args, model, tokenizer, prefix="", test=False):
    """Run ``model`` over the dev (or test) split and return span metrics.

    Args:
        args: namespace providing ``task_name``, ``output_dir``, ``device``,
            ``n_gpu``, ``local_rank`` and ``per_gpu_eval_batch_size``.
            ``args.eval_batch_size`` is set as a side effect.
        model: model returning ``(loss, span_logits, relation_logits)`` when
            called with the keyword inputs built below.
        tokenizer: forwarded to ``load_and_cache_examples``.
        prefix: free-text tag included in the log header.
        test: when True the test split is loaded instead of the dev split.

    Returns:
        dict with the keys ``eval_ti_precision``, ``eval_ti_recall``,
        ``eval_ti_f1``, ``eval_ai_precision``, ``eval_ai_recall``,
        ``eval_ai_f1`` and ``eval_loss`` (mean loss per batch).

    Note:
        The model is left in eval mode; callers that continue training must
        call ``model.train()`` afterwards.
    """
    eval_task_names = (args.task_name,)
    eval_outputs_dirs = (args.output_dir,)

    results = {}
    for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
        # load_and_cache_examples checks `evaluate` before `test`, so the dev
        # flag must be off for the test split to be selected.
        eval_dataset, _ = load_and_cache_examples(args, eval_task, tokenizer, evaluate=not test, test=test)

        if args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir, exist_ok=True)

        args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
        # Evaluate in dataset order: the metrics do not depend on order, and a
        # sequential sampler neither consumes global RNG state nor scrambles
        # the correspondence between predictions and example ids.
        eval_sampler = SequentialSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

        logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)

        total_loss = 0.0
        num_batches = 0
        # Per-batch 1-D chunks, concatenated once at the end (avoids the
        # quadratic cost of growing a tensor with torch.cat in the loop).
        span_pred_chunks = []
        span_gold_chunks = []

        model.eval()
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "overlap_ids": batch[2],
                "span_labels": batch[3],  # assumed [bs, len-1, len-1] per the original note -- TODO confirm
                "relation_labels": batch[4],
                "_dist_inputs": batch[5],
                "_tir_rpe": batch[6],
            }
            with torch.no_grad():
                # Relation logits are returned by the model but are not scored
                # here: calculate_scores only derives metrics from span labels.
                batch_loss, span_logits, _relation_logits = model(**inputs)

            # .mean() is a no-op for a scalar loss and reduces the per-GPU
            # loss vector produced under DataParallel.
            total_loss += batch_loss.mean().item()

            # Positions labelled -1 are padding and are excluded from scoring.
            span_mask = inputs['span_labels'] != -1
            span_pred_chunks.append(torch.argmax(span_logits[span_mask], dim=-1).cpu())
            span_gold_chunks.append(inputs['span_labels'][span_mask].cpu())
            num_batches += 1

        if span_pred_chunks:
            span_preds = torch.cat(span_pred_chunks, dim=0).numpy()
            span_labels = torch.cat(span_gold_chunks, dim=0).numpy()
        else:
            logger.warning("Evaluation dataloader produced no batches; all metrics will be 0.")
            span_preds = np.empty(0, dtype=np.int64)
            span_labels = np.empty(0, dtype=np.int64)
        eval_loss = total_loss / num_batches if num_batches else 0.0

        ti_p, ti_r, ti_f1, ai_p, ai_r, ai_f1 = calculate_scores(span_preds, span_labels)

        results.update({
            "eval_ti_precision": ti_p,
            "eval_ti_recall": ti_r,
            "eval_ti_f1": ti_f1,
            "eval_ai_precision": ai_p,
            "eval_ai_recall": ai_r,
            "eval_ai_f1": ai_f1,
            "eval_loss": eval_loss,
        })
    return results



def _prf_for_label(predicted, gold, label):
    """Return (precision, recall, f1) of ``label``; 0 wherever a denominator is 0."""
    predicted_as_label = predicted == label
    n_correct = (predicted_as_label & (gold == predicted)).sum()
    n_predicted = predicted_as_label.sum()
    n_gold = (gold == label).sum()

    precision = n_correct / n_predicted if n_predicted else 0
    recall = n_correct / n_gold if n_gold else 0
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0
    return precision, recall, f1


def calculate_scores(pred_span_labels,span_labels,pred_relation_labels=None,relation_labels=None):
    """Compute precision/recall/F1 for trigger and argument identification.

    Label 2 is scored as TI (Trigger Identification) and label 1 as
    AI (Argument Identification).  ``pred_span_labels`` and ``span_labels``
    are element-wise comparable arrays of label ids.  The two relation
    arguments are accepted but not used.

    Returns:
        (ti_precision, ti_recall, ti_f1, ai_precision, ai_recall, ai_f1)
    """
    trigger_scores = _prf_for_label(pred_span_labels, span_labels, 2)
    argument_scores = _prf_for_label(pred_span_labels, span_labels, 1)
    return trigger_scores + argument_scores


## Main

In [None]:

# NOTE(review): presumably the directory where cached features are written;
# its use is not visible in this cell -- confirm.
features_dir='/kaggle/working/'
# Task name -> data processor class.
processors={"fewfc": FewFCProcessor}
TYPE_SIZE=30
NUM_HIDDENS=768
# Model type -> (config class, model class, tokenizer class).
MODEL_CLASSES = {'bert':(BertConfig, BertModel, BertTokenizer)}
logger = logging.getLogger(__name__)

# Weights & Biases sweep definition: random search maximising `eval_ai_f1`.
metric={'name':'eval_ai_f1','goal':'maximize'}
sweep_config={
    'method':"random",
    'metric':metric,
    'parameters':{
        'project_name':{'value':'ETANET'},
        # train() resolves this name as an attribute of torch.optim, so every
        # value must be spelled exactly like a torch.optim class.  The former
        # entries 'Adamw' / 'AdaDelta' were mis-capitalised and 'AdaDeltaW'
        # does not exist in torch.optim.
        'optim_type':{'values':['AdamW','Adadelta','Adamax']},
        'mlp_name':{'values':['MLP','MLPnorm','MLPla']},
        'active_fun':{'values':['ReLU','GELU','Softplus']},
        'type_num':{
            'values':[10,20,30,70,100,200,300]
        },
        'learning_rate':{
            'distribution':'log_uniform_values',
            'min':5e-5,
            'max':0.0001
        },
        'weight_decay':{
            'distribution':'log_uniform_values',
            'min':1e-5,
            'max':0.0001
        },
        'loss_s_weight':{
            'distribution':'q_uniform',
            'q':1,
            'min':1,
            'max':3
        },
        'loss_r_weight':{
            'distribution':'q_uniform',
            'q':1,
            'min':1,
            'max':3
        },
        'dropout_p':{
            'distribution':'q_uniform',
            'q':0.1,
            'min':0.1,
            'max':0.3
        },
        # Search dimensions that are currently disabled:
        #   'num_train_epochs': {'value': 6}
        #   'loss_name': {'values': ['ce loss', 'dice loss']}
        #   'per_gpu_train_batch_size': {'distribution': 'q_uniform', 'q': 4, 'min': 8, 'max': 16}
        #   'per_gpu_eval_batch_size': {'distribution': 'q_uniform', 'q': 4, 'min': 8, 'max': 16}
        #   'gradient_accumulation_steps': {'distribution': 'q_uniform', 'q': 1, 'min': 1, 'max': 4}
        #   'max_grad_norm': {'distribution': 'q_uniform', 'q': 1, 'min': 4, 'max': 5}
    },
}

def main(args=None):
    """Entry point: set up the device, load config/tokenizer and launch training.

    Args:
        args: namespace of settings.  When ``None`` it is obtained from
            ``Config_parser()``.  The namespace is mutated: ``n_gpu``,
            ``device``, ``task_name`` (lower-cased) and ``model_type``
            (lower-cased) are written back to it.

    When ``args.do_train`` is set, the training set is loaded, a wandb sweep is
    created from ``sweep_config`` and a wandb agent runs ``train`` 10 times.

    Raises:
        KeyError: if ``args.task_name`` or ``args.model_type`` is not
            registered in ``processors`` / ``MODEL_CLASSES``.
    """
    if args is None:
        args = Config_parser()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda:0" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:  # Initializes the distributed backend which synchronizes nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Prepare task: fail early on an unknown task name.
    args.task_name = args.task_name.lower()
    if args.task_name not in processors:
        raise KeyError("Unknown task %r; available tasks: %s" % (args.task_name, sorted(processors)))

    # Load pretrained config and tokenizer (the model itself is built in train()).
    args.model_type = args.model_type.lower()
    config_class, _, tokenizer_class = MODEL_CLASSES[args.model_type]
    cache_dir = args.cache_dir if args.cache_dir else None
    config = config_class.from_pretrained(
        args.model_name_or_path,
        finetuning_task=args.task_name,
        cache_dir=cache_dir,
    )
    config.output_hidden_states=True
    tokenizer = tokenizer_class.from_pretrained(
        args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=cache_dir,
    )

    # Training
    if args.do_train:
        print("--do train--")
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        # Create the sweep only when training is requested, so that runs with
        # do_train=False do not register an empty sweep with wandb.
        sweep_id = wandb.sweep(sweep_config, project='ETANET')
        new_train_func = partial(train, config=config, args=args, train_dataset=train_dataset, tokenizer=tokenizer)
        wandb.agent(sweep_id, new_train_func, count=10)

# Default run settings used when the notebook cell is executed directly.
args = argparse.Namespace(
    # Data / model selection
    data_dir="/kaggle/input/cassdataset",
    model_type="bert",
    model_name_or_path="bert-base-multilingual-cased",
    task_name="fewfc",
    output_dir="/kaggle/working/",
    cache_dir='',
    do_lower_case=False,
    overwrite_cache=True,
    max_seq_length=128,
    # Model head configuration
    type_num=30,
    hidden_size=768,
    span_size=3,
    relation_size=2,
    dropout_p=0.1,
    active_fun='ReLU',
    mlp_name='MLP',
    loss_name='ce loss',
    loss_s_weight=1,
    loss_r_weight=1,
    # Optimisation
    optim_type='AdamW',
    bert_lr=2e-5,
    learning_rate=5e-5,
    weight_decay=0.001,
    adam_epsilon=1e-7,
    max_grad_norm=5.0,
    warmup_ratio=0.1,
    num_train_epochs=6,
    max_steps=-1,
    gradient_accumulation_steps=1,
    per_gpu_train_batch_size=16,
    per_gpu_eval_batch_size=16,
    seed=100,
    # Runtime / hardware
    no_cuda=False,
    local_rank=-1,
    fp16=False,
    # Read by train() when fp16=True; previously missing, which would raise
    # AttributeError as soon as fp16 was switched on.
    fp16_opt_level="O1",
    # Run modes and logging
    do_train=True,
    do_pretrain=False,
    do_eval=False,
    do_test=False,
    evaluate_during_training=True,
    eval_all_checkpoints=False,
    logging_steps=100,
    save_steps=100,
)
if __name__ == "__main__":
    main(args)