In [None]:
from transformers import AutoModel, BertConfig
from transformers import BertModel, BertPreTrainedModel
import torch
from torch import nn
import numpy as np
from torch.nn import CrossEntropyLoss
import pytorch_pretrained_bert

In [None]:
# Execute loss.ipynb in this kernel so its top-level names become available
# here; presumably it defines LabelSmoothingCrossEntropy and FocalLoss, which
# the token-classification models below instantiate — TODO confirm.
%run ./loss.ipynb

In [None]:
# Execute layer.ipynb in this kernel; presumably it defines the CRF module
# used by BERT_CRF / BERT_CRF_pre — TODO confirm.
%run ./layer.ipynb

In [None]:
class BERT_CRF(BertPreTrainedModel):
    """Transformers encoder + linear emission layer + CRF for sequence tagging.

    The encoder is loaded from the notebook-global ``args.pretrained_model``;
    the dropout rate (``args.drop``) and tag set (``args.tag2idx``) also come
    from ``args``. ``config`` is only used to initialize the
    ``BertPreTrainedModel`` base class.
    """

    def __init__(self, config):
        super(BERT_CRF, self).__init__(config)

        self.bert = AutoModel.from_pretrained(args.pretrained_model)
        self.dropout = nn.Dropout(args.drop)
        # Size the emission layer from the loaded encoder instead of hard-coding 768.
        self.classifier = nn.Linear(self.bert.config.hidden_size, len(args.tag2idx))
        self.crf = CRF(num_tags=len(args.tag2idx), batch_first=True)
        # Initialize only the newly created emission layer. self.init_weights()
        # applies _init_weights to every submodule, which in some transformers
        # versions re-randomizes the pretrained encoder loaded just above.
        self._init_weights(self.classifier)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode the batch, compute per-token emissions and the CRF loss.

        Returns:
            ``(loss, logits)``. ``logits`` are the per-token emission scores
            (last dim = number of tags). ``loss`` is the negated CRF output
            when ``labels`` is given, otherwise ``None`` (inference).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        if labels is None:
            # No gold tags: nothing for the CRF to score.
            return (None, logits)

        # The CRF expects a boolean mask (e.g. torch.where needs bool conditions);
        # attention_mask from a tokenizer is an integer tensor.
        mask = attention_mask.bool() if attention_mask is not None else None
        # Presumably the CRF returns a log-likelihood, hence the negation to get
        # a loss to minimize.
        log_likelihood = self.crf(emissions=logits, tags=labels, mask=mask)
        return (-log_likelihood, logits)

In [None]:
import pytorch_pretrained_bert

In [None]:
# Display the installed pytorch_pretrained_bert version (the *_pre models below
# load their encoders through this package).
pytorch_pretrained_bert.__version__

In [None]:
class BERT_CRF_pre(nn.Module):
    """pytorch_pretrained_bert encoder + linear emission layer + CRF.

    Encoder weights come from ``config.pretrain_path``; the dropout rate
    (``args.drop``) and tag set (``args.tag2idx``) come from the
    notebook-global ``args``.
    """

    def __init__(self, config):
        super(BERT_CRF_pre, self).__init__()

        self.bert = pytorch_pretrained_bert.BertModel.from_pretrained(config.pretrain_path)
        self.dropout = nn.Dropout(args.drop)
        # Size the emission layer from the loaded encoder instead of hard-coding 768.
        self.classifier = nn.Linear(self.bert.config.hidden_size, len(args.tag2idx))
        self.crf = CRF(num_tags=len(args.tag2idx), batch_first=True)

    def forward(self, input_ids, attention_mask, labels):
        """Encode the batch, compute per-token emissions and the CRF loss.

        Returns:
            ``(loss, logits)``. ``loss`` is the negated CRF output when
            ``labels`` is given, otherwise ``None`` (inference).
        """
        # With output_all_encoded_layers=False the first element is the last
        # encoder layer's per-token output.
        outputs = self.bert(input_ids, attention_mask=attention_mask, output_all_encoded_layers=False)
        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        if labels is None:
            return (None, logits)

        # The CRF expects a boolean mask; attention_mask is typically integer.
        mask = attention_mask.bool() if attention_mask is not None else None
        # Presumably the CRF returns a log-likelihood, hence the negation.
        log_likelihood = self.crf(emissions=logits, tags=labels, mask=mask)
        return (-log_likelihood, logits)


In [None]:
class BERT_Linear_pre(nn.Module):
    """pytorch_pretrained_bert encoder + per-token linear classifier.

    Encoder weights come from ``config.pretrain_path``. The loss type comes from
    the notebook-global ``args.loss_type``: 'lsr' (LabelSmoothingCrossEntropy),
    'fl' (FocalLoss) or 'ce' (CrossEntropyLoss).
    """

    def __init__(self, config):
        super(BERT_Linear_pre, self).__init__()

        self.num_labels = config.num_labels
        self.bert = pytorch_pretrained_bert.BertModel.from_pretrained(config.pretrain_path)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_type = args.loss_type

    def _compute_loss(self, logits, labels, attention_mask):
        """Token-level loss; positions with attention_mask != 1 are ignored.

        Ignored positions get the loss's ``ignore_index`` as their target, so
        the chosen loss class must expose that attribute.
        """
        assert self.loss_type in ['lsr', 'fl', 'ce']
        if self.loss_type == 'lsr':
            loss_fct = LabelSmoothingCrossEntropy()
        elif self.loss_type == 'fl':
            loss_fct = FocalLoss()
        else:
            loss_fct = CrossEntropyLoss()

        flat_logits = logits.view(-1, self.num_labels)
        flat_labels = labels.view(-1)
        if attention_mask is None:
            return loss_fct(flat_logits, flat_labels)

        active = attention_mask.view(-1) == 1
        ignore = torch.tensor(loss_fct.ignore_index).type_as(labels)
        return loss_fct(flat_logits, torch.where(active, flat_labels, ignore))

    def forward(self, input_ids, attention_mask, labels):
        """Return ``(loss, logits)``; ``loss`` is ``None`` when ``labels`` is ``None``.

        ``logits`` has ``num_labels`` scores per token.
        """
        # With output_all_encoded_layers=False the first element is the last
        # encoder layer's per-token output.
        outputs = self.bert(input_ids, attention_mask=attention_mask, output_all_encoded_layers=False)
        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        # Initialize so inference (labels=None) does not hit an unbound name.
        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels, attention_mask)
        return (loss, logits)



In [None]:
class BERT_Linear(BertPreTrainedModel):
    """Transformers ``BertModel`` encoder + per-token linear classifier.

    The encoder is built from ``config`` (no pooling layer) and all submodules
    are initialized by ``init_weights()``. The loss type comes from the
    notebook-global ``args.loss_type``: 'lsr' (LabelSmoothingCrossEntropy),
    'fl' (FocalLoss) or 'ce' (CrossEntropyLoss).
    """

    def __init__(self, config):
        super(BERT_Linear, self).__init__(config)

        self.num_labels = config.num_labels
        self.bert = BertModel(config, add_pooling_layer=False)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_type = args.loss_type

        self.init_weights()

    def _compute_loss(self, logits, labels, attention_mask):
        """Token-level loss; positions with attention_mask != 1 are ignored.

        Ignored positions get the loss's ``ignore_index`` as their target, so
        the chosen loss class must expose that attribute.
        """
        assert self.loss_type in ['lsr', 'fl', 'ce']
        if self.loss_type == 'lsr':
            loss_fct = LabelSmoothingCrossEntropy()
        elif self.loss_type == 'fl':
            loss_fct = FocalLoss()
        else:
            loss_fct = CrossEntropyLoss()

        flat_logits = logits.view(-1, self.num_labels)
        flat_labels = labels.view(-1)
        if attention_mask is None:
            return loss_fct(flat_logits, flat_labels)

        active = attention_mask.view(-1) == 1
        ignore = torch.tensor(loss_fct.ignore_index).type_as(labels)
        return loss_fct(flat_logits, torch.where(active, flat_labels, ignore))

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Return ``(loss, logits)``; ``loss`` is ``None`` when ``labels`` is ``None``.

        ``logits`` has ``num_labels`` scores per token.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = self.dropout(outputs[0])
        logits = self.classifier(sequence_output)

        # Initialize so inference (labels=None) does not hit an unbound name.
        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels, attention_mask)
        return (loss, logits)



In [None]:
# NOTE(review): a later cell in this notebook defines a class with the same
# name; whichever cell runs last is the one BERT_LAN instantiates.
class multihead_attention(nn.Module):
    """Scaled dot-product multi-head attention with ReLU projections and a residual.

    Q, K and V are each produced by ``Linear(num_units, num_units)`` + ReLU,
    split into ``num_heads`` chunks along the feature axis and attended per
    head. Query positions whose feature vector sums to exactly zero (e.g.
    zero-padded rows) have their attention rows zeroed. Padded *key* positions
    are not masked.
    """

    def __init__(self, num_units, num_heads=1, dropout_rate=0, gpu=True, causality=False):
        '''Applies multihead attention.

        Args:
            num_units: A scalar. Attention size; also the feature size expected
                for queries/keys/values (projections are num_units -> num_units).
            num_heads: An int. Number of heads; must divide num_units so the
                feature axis splits into equal chunks.
            dropout_rate: A floating point number. Dropout on the attention weights.
            gpu: If True and CUDA is available, move the projection layers to
                the GPU at construction time.
            causality: Boolean. Stored only; NOTE: forward does not apply any
                future-masking.
        '''
        super(multihead_attention, self).__init__()
        self.gpu = gpu
        self.num_units = num_units
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.causality = causality
        self.Q_proj = nn.Sequential(nn.Linear(self.num_units, self.num_units), nn.ReLU())
        self.K_proj = nn.Sequential(nn.Linear(self.num_units, self.num_units), nn.ReLU())
        self.V_proj = nn.Sequential(nn.Linear(self.num_units, self.num_units), nn.ReLU())
        # Guard the device move so the default gpu=True does not crash on
        # CPU-only machines.
        if self.gpu and torch.cuda.is_available():
            self.Q_proj = self.Q_proj.cuda()
            self.K_proj = self.K_proj.cuda()
            self.V_proj = self.V_proj.cuda()

        self.output_dropout = nn.Dropout(p=self.dropout_rate)

    def forward(self, queries, keys, values, last_layer=False):
        '''Attend from ``queries`` to ``keys``/``values``.

        Args:
            queries: (N, T_q, num_units) tensor.
            keys: (N, T_k, num_units) tensor.
            values: (N, T_k, num_units) tensor.
            last_layer: If True, skip the softmax and return the masked,
                dropped-out scaled scores of shape (h*N, T_q, T_k).

        Returns:
            (N, T_q, num_units): attended values plus the residual ``queries``,
            or the raw scores described above when ``last_layer`` is True.
        '''
        num_heads = self.num_heads
        Q = self.Q_proj(queries)  # (N, T_q, C)
        K = self.K_proj(keys)     # (N, T_k, C)
        V = self.V_proj(values)   # (N, T_k, C)
        # Split the feature axis into heads and stack the heads on the batch axis.
        Q_ = torch.cat(torch.chunk(Q, num_heads, dim=2), dim=0)  # (h*N, T_q, C/h)
        K_ = torch.cat(torch.chunk(K, num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)
        V_ = torch.cat(torch.chunk(V, num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)

        scores = torch.bmm(Q_, K_.transpose(1, 2)) / (K_.size(-1) ** 0.5)  # (h*N, T_q, T_k)
        if not last_layer:
            # torch.softmax avoids depending on an `F` alias this notebook never imports.
            scores = torch.softmax(scores, dim=-1)

        # Zero the rows of query positions whose features sum to zero.
        query_masks = torch.sign(torch.abs(torch.sum(queries, dim=-1)))      # (N, T_q)
        query_masks = query_masks.repeat(num_heads, 1)                      # (h*N, T_q)
        query_masks = query_masks.unsqueeze(2).repeat(1, 1, keys.size(1))  # (h*N, T_q, T_k)
        scores = self.output_dropout(scores * query_masks)
        if last_layer:
            return scores

        outputs = torch.bmm(scores, V_)                                       # (h*N, T_q, C/h)
        outputs = torch.cat(torch.chunk(outputs, num_heads, dim=0), dim=2)  # (N, T_q, C)
        outputs += queries  # residual connection
        return outputs

In [None]:
class BERT_LAN(BertPreTrainedModel):
    """BERT encoder + token self-attention + linear projection to tag logits.

    ``multihead_attention`` keeps the hidden width (it adds a residual
    connection), so a final ``nn.Linear(hidden_size, num_labels)`` turns its
    output into per-token label scores. The loss type comes from the
    notebook-global ``args.loss_type``: 'lsr' (LabelSmoothingCrossEntropy),
    'fl' (FocalLoss) or 'ce' (CrossEntropyLoss). ``config`` must also provide
    the custom fields ``head_num`` and ``drop_rate``.
    """

    def __init__(self, config):
        super(BERT_LAN, self).__init__(config)

        self.num_labels = config.num_labels
        self.bert = BertModel(config, add_pooling_layer=False)

        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Token self-attention over the encoder output; its width must equal the
        # encoder hidden size (was hard-coded as 768).
        self.classifier = multihead_attention(
            config.hidden_size, num_heads=config.head_num, dropout_rate=config.drop_rate
        )
        # Map the hidden_size-wide attention output to num_labels scores; without
        # this, logits.view(-1, num_labels) cannot line up with the labels.
        self.label_proj = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_type = args.loss_type
        # NOTE(review): label_dim / label_embedding are not used by forward; a
        # LAN-style label-attention head may have been intended — TODO confirm.
        self.label_dim = 200
        self.label_embedding = nn.Embedding(self.num_labels, self.label_dim)

        self.init_weights()

    def _compute_loss(self, logits, labels, attention_mask):
        """Token-level loss; positions with attention_mask != 1 are ignored.

        Ignored positions get the loss's ``ignore_index`` as their target, so
        the chosen loss class must expose that attribute.
        """
        assert self.loss_type in ['lsr', 'fl', 'ce']
        if self.loss_type == 'lsr':
            loss_fct = LabelSmoothingCrossEntropy()
        elif self.loss_type == 'fl':
            loss_fct = FocalLoss()
        else:
            loss_fct = CrossEntropyLoss()

        flat_logits = logits.view(-1, self.num_labels)
        flat_labels = labels.view(-1)
        if attention_mask is None:
            return loss_fct(flat_logits, flat_labels)

        active = attention_mask.view(-1) == 1
        ignore = torch.tensor(loss_fct.ignore_index).type_as(labels)
        return loss_fct(flat_logits, torch.where(active, flat_labels, ignore))

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Return ``(loss, logits)``; ``loss`` is ``None`` when ``labels`` is ``None``.

        ``logits`` has ``num_labels`` scores per token.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = self.dropout(outputs[0])
        attended = self.classifier(sequence_output, sequence_output, sequence_output)
        logits = self.label_proj(attended)

        # Initialize so inference (labels=None) does not hit an unbound name.
        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels, attention_mask)
        return (loss, logits)



In [None]:
# NOTE(review): an earlier cell in this notebook also defines a class with this
# name; this later definition shadows it, and whichever cell runs last is the
# one BERT_LAN instantiates.
class multihead_attention(nn.Module):
    """Scaled dot-product multi-head attention with ReLU projections and a residual.

    Q, K and V are each produced by ``Linear(num_units, num_units)`` + ReLU,
    split into ``num_heads`` chunks along the feature axis and attended per
    head. Query positions whose feature vector sums to exactly zero (e.g.
    zero-padded rows) have their attention rows zeroed. Padded *key* positions
    are not masked.
    """

    def __init__(self, num_units, num_heads=1, dropout_rate=0, gpu=True, causality=False):
        '''Applies multihead attention.

        Args:
            num_units: A scalar. Attention size; also the feature size expected
                for queries/keys/values (projections are num_units -> num_units).
            num_heads: An int. Number of heads; must divide num_units so the
                feature axis splits into equal chunks.
            dropout_rate: A floating point number. Dropout on the attention weights.
            gpu: If True and CUDA is available, move the projection layers to
                the GPU at construction time.
            causality: Boolean. Stored only; NOTE: forward does not apply any
                future-masking.
        '''
        super(multihead_attention, self).__init__()
        self.gpu = gpu
        self.num_units = num_units
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.causality = causality
        self.Q_proj = nn.Sequential(nn.Linear(self.num_units, self.num_units), nn.ReLU())
        self.K_proj = nn.Sequential(nn.Linear(self.num_units, self.num_units), nn.ReLU())
        self.V_proj = nn.Sequential(nn.Linear(self.num_units, self.num_units), nn.ReLU())
        # Guard the device move so the default gpu=True does not crash on
        # CPU-only machines.
        if self.gpu and torch.cuda.is_available():
            self.Q_proj = self.Q_proj.cuda()
            self.K_proj = self.K_proj.cuda()
            self.V_proj = self.V_proj.cuda()

        self.output_dropout = nn.Dropout(p=self.dropout_rate)

    def forward(self, queries, keys, values, last_layer=False):
        '''Attend from ``queries`` to ``keys``/``values``.

        Args:
            queries: (N, T_q, num_units) tensor.
            keys: (N, T_k, num_units) tensor.
            values: (N, T_k, num_units) tensor.
            last_layer: If True, skip the softmax and return the masked,
                dropped-out scaled scores of shape (h*N, T_q, T_k).

        Returns:
            (N, T_q, num_units): attended values plus the residual ``queries``,
            or the raw scores described above when ``last_layer`` is True.
        '''
        num_heads = self.num_heads
        Q = self.Q_proj(queries)  # (N, T_q, C)
        K = self.K_proj(keys)     # (N, T_k, C)
        V = self.V_proj(values)   # (N, T_k, C)
        # Split the feature axis into heads and stack the heads on the batch axis.
        Q_ = torch.cat(torch.chunk(Q, num_heads, dim=2), dim=0)  # (h*N, T_q, C/h)
        K_ = torch.cat(torch.chunk(K, num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)
        V_ = torch.cat(torch.chunk(V, num_heads, dim=2), dim=0)  # (h*N, T_k, C/h)

        scores = torch.bmm(Q_, K_.transpose(1, 2)) / (K_.size(-1) ** 0.5)  # (h*N, T_q, T_k)
        if not last_layer:
            # torch.softmax avoids depending on an `F` alias this notebook never imports.
            scores = torch.softmax(scores, dim=-1)

        # Zero the rows of query positions whose features sum to zero.
        query_masks = torch.sign(torch.abs(torch.sum(queries, dim=-1)))      # (N, T_q)
        query_masks = query_masks.repeat(num_heads, 1)                      # (h*N, T_q)
        query_masks = query_masks.unsqueeze(2).repeat(1, 1, keys.size(1))  # (h*N, T_q, T_k)
        scores = self.output_dropout(scores * query_masks)
        if last_layer:
            return scores

        outputs = torch.bmm(scores, V_)                                       # (h*N, T_q, C/h)
        outputs = torch.cat(torch.chunk(outputs, num_heads, dim=0), dim=2)  # (N, T_q, C)
        outputs += queries  # residual connection
        return outputs