In [1]:
import math
import json

import torch
import torch.nn as nn

In [2]:
class Config(dict):
    """A ``dict`` whose keys can also be read and written as attributes.

    ``config.d_model`` is the same as ``config['d_model']``. A missing key
    raises ``AttributeError`` on attribute access (``KeyError`` on item
    access). This lets ``hasattr``, ``getattr(obj, name, default)``,
    ``copy.deepcopy`` and pickling work as they do for ordinary objects.
    """

    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails, so real dict methods
        # (keys, items, ...) are unaffected. Translate KeyError into the
        # AttributeError the attribute protocol requires.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    @classmethod
    def load(cls, file):
        """Read a JSON object from ``file`` and return it as an instance of ``cls``."""
        with open(file, 'r', encoding='utf-8') as f:
            return cls(json.load(f))

    @classmethod
    def save(cls, config_dict, file):
        """Write ``config_dict`` to ``file`` as indented JSON and report the path."""
        with open(file, 'w', encoding='utf-8') as f:
            json.dump(config_dict, f, indent=4)
        # Printed after the file is closed, i.e. once the data is flushed.
        print('config file is saved on {}'.format(file))

In [3]:
# Hyper-parameters shared by every model built in this notebook.
# NOTE: 'drop_out_raito' is misspelled, but the model classes read
# `config.drop_out_raito`, so the key must keep exactly this name.
config_dict = dict(
    vocab_size=30000,  # tokenizer.vocab_size,
    d_model=768,
    max_position_embeddings=512,
    type_vocab_size=2,
    num_labels=2,
    pad_token_id=0,  # tokenizer.pad_token_id,
    bos_token_id=2,  # tokenizer.bos_token_id,
    eos_token_id=3,  # tokenizer.eos_token_id,
    share_embedding=True,
    init_std=2e-2,
    layer_norm_eps=1e-12,
    drop_out_raito=0.1,
    num_enc_layers=12,
    num_dec_layers=12,
    num_att_heads=12,
    feed_forward_dim=2048,
    has_relative_attention_bias=True,  # Only T5
    relative_attention_num_buckets=32,  # Only T5
)

config = Config(config_dict)

In [40]:
# Toy batch containing one sequence: BOS (2) ... EOS (3) followed by PAD (0).
input_ids = torch.tensor([[2, 4, 5, 6, 7, 8, 9, 3, 0, 0]])
token_type_ids = torch.zeros_like(input_ids)
attention_mask = torch.tensor([[1] * 8 + [0] * 2])
dec_input_ids = torch.tensor([[2, 4, 5, 6, 3, 0, 0]])

In [41]:
tuple(t.shape for t in (input_ids, token_type_ids, attention_mask, dec_input_ids))

(torch.Size([1, 10]),
 torch.Size([1, 10]),
 torch.Size([1, 10]),
 torch.Size([1, 7]))

In [5]:
!git clone https://github.com/Taeksu-Kim/Various_Transformers.git

fatal: destination path 'Various_Transformers' already exists and is not an empty directory.


In [6]:
cd Various_Transformers/original_transformer

/content/Various_Transformers/original_transformer


In [7]:
from transformer import *

In [8]:
# Build the reference encoder-decoder. `Transformer` presumably comes from the
# `from transformer import *` cell (repo's original_transformer dir) — confirm.
model = Transformer(config)

In [9]:
# Arguments are passed positionally; their order follows Transformer.forward.
result = model(input_ids, dec_input_ids)

In [10]:
result[0].size()

torch.Size([1, 7, 30000])

In [11]:
import math
import torch
import torch.nn as nn

def get_extended_attention_mask(attention_mask, autoregressive=False, dtype=torch.float16):
    """Turn a 2-D padding mask into an additive attention bias.

    Args:
        attention_mask: ``[batch, seq_len]`` tensor, non-zero/True where a token
            may be attended to and 0/False for padding. Integer and bool masks
            are both supported.
        autoregressive: if True, also apply a causal mask so that query
            position ``i`` only attends to key positions ``<= i``.
        dtype: floating dtype of the returned bias. The default ``float16``
            keeps the previous (fp16-compatible) behaviour.

    Returns:
        A tensor of shape ``[batch, 1, 1, seq_len]``, or
        ``[batch, 1, seq_len, seq_len]`` when ``autoregressive`` is True. It
        holds 0 for allowed positions and ``torch.finfo(dtype).min`` for masked
        ones, and is meant to be added to raw attention scores before softmax.
    """
    extended_attention_mask = attention_mask[:, None, None, :]

    if autoregressive:
        seq_len = attention_mask.size(1)
        # Lower-triangular (incl. diagonal): query i may see keys 0..i.
        causal_mask = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=attention_mask.device)
        )
        # Combine padding and causal masks as a logical AND. Doing this in bool
        # also makes bool padding masks work: adding two bool tensors is a
        # logical OR, so the former "sum > 1" test never fired for them.
        extended_attention_mask = extended_attention_mask.bool() & causal_mask

    extended_attention_mask = extended_attention_mask.to(dtype=dtype)
    return (1.0 - extended_attention_mask) * torch.finfo(dtype).min

class PoswiseFeedForward(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> GELU -> Linear -> Dropout.

    Expands ``d_model`` to ``feed_forward_dim`` and projects back to ``d_model``.
    """

    def __init__(self, config):
        super().__init__()
        hidden_dim = config.feed_forward_dim
        layers = [
            nn.Linear(config.d_model, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, config.d_model),
            nn.Dropout(config.drop_out_raito),
        ]
        # A single Sequential named `feed_forward` keeps parameter names
        # (feed_forward.0.*, feed_forward.2.*) unchanged.
        self.feed_forward = nn.Sequential(*layers)

    def forward(self, inputs):
        out = self.feed_forward(inputs)
        return out
    
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.

    The three embeddings are summed, then passed through LayerNorm and dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.d_model)

        self.LayerNorm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.drop_out_raito)

        # A non-persistent buffer follows `module.to(device)` together with the
        # parameters, without adding an entry to `state_dict`.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        ):
        """Embed ``input_ids``.

        Args:
            input_ids: ``[batch, seq_len]`` token ids (required despite the
                ``None`` default).
            token_type_ids: optional segment ids shaped like ``input_ids``;
                defaults to all zeros.
            position_ids: optional position ids broadcastable to
                ``[batch, seq_len]``; defaults to ``0 .. seq_len - 1``.

        Returns:
            ``[batch, seq_len, d_model]`` embeddings.
        """
        batch_size, seq_len = input_ids.size()
        device = input_ids.device

        if token_type_ids is None:
            token_type_ids = torch.zeros([batch_size, seq_len], dtype=torch.long, device=device)

        # Honour caller-supplied position ids; fall back to 0..seq_len-1.
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_len]
        position_ids = position_ids.to(device)

        word_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = word_embeds + token_type_embeddings + position_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings

class BertModel(nn.Module):
    """Encoder-only BERT model: ``BertEmbeddings`` feeding a ``BertEncoder``.

    All weights are re-initialised in ``__init__`` via ``init_weights``.
    """

    def __init__(self, config):
        super().__init__()
        # Must be stored before init_weights(): init_layer_weights reads
        # self.config.init_std and self.config.layer_norm_eps. Without this
        # line, construction raised AttributeError.
        self.config = config

        self.embedding = BertEmbeddings(config)
        self.encoder = BertEncoder(config, self.embedding)

        self.init_weights()

    def init_weights(self):
        """Initialize weights for each layer (recursively via ``apply``)."""
        self.apply(self.init_layer_weights)

    # ref huggingface
    # https://huggingface.co/transformers/v4.9.2/_modules/transformers/models/electra/modeling_electra.html#ElectraPreTrainedModel
    def init_layer_weights(self, module):
        """Initialise a single module in place.

        Linear and Embedding weights are drawn from N(0, init_std). Linear
        biases and the Embedding padding row are zeroed. LayerNorm gets
        weight 1, bias 0 and eps = layer_norm_eps.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            module.eps = self.config.layer_norm_eps

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                ):
        """Encode ``input_ids``.

        Returns:
            ``(hidden_states, self_attn_probs)``, where ``self_attn_probs``
            is the list of per-layer attention probabilities from the encoder.
        """
        outputs, self_attn_probs, _ = self.encoder(input_ids=input_ids,
                                                   token_type_ids=token_type_ids,
                                                   attention_mask=attention_mask,
                                                   )

        return outputs, self_attn_probs

class BertEncoder(nn.Module):
    """Embedding layer followed by ``config.num_enc_layers`` encoder layers."""

    def __init__(self, config, embedding):
        super().__init__()
        self.config = config
        # The embedding module is injected so callers can share it.
        self.embedding = embedding
        layer_list = [BertEncoderLayer(config) for _ in range(config.num_enc_layers)]
        self.layers = nn.ModuleList(layer_list)

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Encode ``input_ids``.

        Returns:
            ``(hidden_states, per_layer_self_attn_probs, additive_self_attention_mask)``.
            The additive mask is returned so it can be reused for cross-attention.
        """
        if attention_mask is None:
            # Derive the padding mask from the pad token id.
            attention_mask = input_ids.ne(self.config.pad_token_id).int()
        self_attention_mask = get_extended_attention_mask(attention_mask, autoregressive=False)

        hidden = self.embedding(input_ids, token_type_ids=token_type_ids)

        self_attn_probs = []
        for layer in self.layers:
            hidden, layer_probs = layer(inputs=hidden, self_attention_mask=self_attention_mask)
            self_attn_probs.append(layer_probs)

        return hidden, self_attn_probs, self_attention_mask
    
class BertEncoderLayer(nn.Module):
    """One encoder block: self-attention, then feed-forward.

    Each sub-layer is followed by a residual connection and LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        self.self_attention = BertAttention(config)
        self.attention_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.feed_forward = PoswiseFeedForward(config)
        self.feed_forward_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)

    def forward(self, inputs, self_attention_mask):
        """Return ``(hidden_states, self_attention_probs)``."""
        attn_out, self_attn_prob = self.self_attention(
            query=inputs,
            key=None,
            value=None,
            attention_mask=self_attention_mask,
        )
        hidden = self.attention_norm(inputs + attn_out)

        ff_out = self.feed_forward(hidden)
        hidden = self.feed_forward_norm(hidden + ff_out)

        return hidden, self_attn_prob

class BertAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Serves as self-attention (``key``/``value`` omitted) and as
    cross-attention (``key``/``value`` given, e.g. encoder outputs).
    """

    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.num_att_heads = config.num_att_heads
        assert self.d_model % self.num_att_heads == 0, "d_model({}) % num_att_heads({}) = {}. It should be 0.".format(self.d_model, self.num_att_heads, self.d_model % self.num_att_heads)
        self.d_head = self.d_model // self.num_att_heads
        self.scale = self.d_head ** 0.5

        self.query_proj = nn.Linear(self.d_model, self.num_att_heads * self.d_head)
        self.key_proj = nn.Linear(self.d_model, self.num_att_heads * self.d_head)
        self.value_proj = nn.Linear(self.d_model, self.num_att_heads * self.d_head)

        self.attn_dropout = nn.Dropout(config.drop_out_raito)

        self.fc = nn.Linear(self.d_head * self.num_att_heads, self.d_model)
        self.context_dropout = nn.Dropout(config.drop_out_raito)

    def forward(self,
                query,
                key=None,
                value=None,
                attention_mask=None,
                ):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: ``[batch, query_len, d_model]``.
            key: ``[batch, key_len, d_model]``; defaults to ``query``.
            value: ``[batch, key_len, d_model]``; defaults to ``key``.
            attention_mask: optional additive bias broadcastable to
                ``[batch, num_heads, query_len, key_len]``; ``None`` disables masking.

        Returns:
            ``(context, attn_prob)``. ``context`` is ``[batch, query_len, d_model]``.
            ``attn_prob`` is ``[batch, num_heads, query_len, key_len]``, taken
            after attention dropout.
        """
        if key is None:
            key = query
        if value is None:
            value = key

        batch_size = query.size(0)

        query = self.query_proj(query).view(batch_size, -1, self.num_att_heads, self.d_head).transpose(1, 2)  # [bs, num_heads, query_len, d_head]
        key = self.key_proj(key).view(batch_size, -1, self.num_att_heads, self.d_head).transpose(1, 2)  # [bs, num_heads, key_len, d_head]
        value = self.value_proj(value).view(batch_size, -1, self.num_att_heads, self.d_head).transpose(1, 2)  # [bs, num_heads, value_len, d_head]

        scores = torch.matmul(query, key.transpose(-2, -1)) / self.scale  # [bs, num_heads, query_len, key_len]
        if attention_mask is not None:
            scores = scores + attention_mask

        attn_prob = torch.softmax(scores, dim=-1)
        attn_prob = self.attn_dropout(attn_prob)

        context = torch.matmul(attn_prob, value)  # [bs, num_heads, query_len, d_head]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.num_att_heads * self.d_head)

        context = self.fc(context)
        context = self.context_dropout(context)

        return context, attn_prob

class BertDecoder(nn.Module):
    """Embedding layer, ``config.num_dec_layers`` decoder layers, and an
    output projection to vocabulary-size logits."""

    def __init__(self, config, embedding):
        super().__init__()
        self.config = config
        self.embedding = embedding
        # The decoder depth follows its own setting, num_dec_layers (it
        # previously used num_enc_layers), matching GPT1Decoder.
        self.layers = nn.ModuleList(
            [BertDecoderLayer(config) for _ in range(config.num_dec_layers)]
        )
        self.fc = nn.Linear(config.d_model, config.vocab_size)

    def forward(self,
                input_ids,
                attention_mask=None,
                enc_outputs=None,
                enc_attention_mask=None):
        """Decode ``input_ids`` while attending to ``enc_outputs``.

        Args:
            input_ids: decoder token ids.
            attention_mask: optional decoder padding mask; derived from
                ``pad_token_id`` when omitted. A causal mask is always added.
            enc_outputs: encoder hidden states used as cross-attention key/value.
            enc_attention_mask: additive mask used for cross-attention.

        Returns:
            ``(logits, self_attn_probs, cross_attn_probs)``. The two prob lists
            hold one entry per layer.
        """
        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id).int()

        self_attention_mask = get_extended_attention_mask(attention_mask, autoregressive=True)

        outputs = self.embedding(input_ids,
                                 token_type_ids=None)

        self_attn_probs, cross_attn_probs = [], []
        for layer in self.layers:
            outputs, self_attn_prob, cross_attn_prob = layer(inputs=outputs,
                                                             self_attention_mask=self_attention_mask,
                                                             enc_outputs=enc_outputs,
                                                             cross_attention_mask=enc_attention_mask,
                                                             )
            self_attn_probs.append(self_attn_prob)
            cross_attn_probs.append(cross_attn_prob)

        outputs = self.fc(outputs)

        return outputs, self_attn_probs, cross_attn_probs

class BertDecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention over the encoder
    outputs, then feed-forward.

    Each sub-layer is followed by a residual connection and LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        self.self_attention = BertAttention(config)
        self.self_attention_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)

        self.cross_attention = BertAttention(config)
        self.cross_attention_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)

        self.feed_forward = PoswiseFeedForward(config)
        self.feed_forward_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)

    def forward(self, inputs, self_attention_mask, enc_outputs, cross_attention_mask):
        """Return ``(hidden_states, self_attention_probs, cross_attention_probs)``."""
        self_out, self_attn_prob = self.self_attention(
            query=inputs,
            key=None,
            value=None,
            attention_mask=self_attention_mask,
        )
        hidden = self.self_attention_norm(inputs + self_out)

        cross_out, cross_attn_prob = self.cross_attention(
            query=hidden,
            key=enc_outputs,
            value=enc_outputs,
            attention_mask=cross_attention_mask,
        )
        hidden = self.cross_attention_norm(hidden + cross_out)

        ff_out = self.feed_forward(hidden)
        hidden = self.feed_forward_norm(hidden + ff_out)

        return hidden, self_attn_prob, cross_attn_prob

class Bert_Encoder_Decoder_Model(nn.Module):
    """BERT-style encoder-decoder (seq2seq) model.

    If ``config.share_embedding is True``, encoder and decoder use one
    embedding module. Otherwise each gets its own.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        if config.share_embedding is True:
            self.shared_embedding = BertEmbeddings(config)
            enc_embedding = dec_embedding = self.shared_embedding
        else:
            self.encoder_embedding = BertEmbeddings(config)
            self.decoder_embedding = BertEmbeddings(config)
            enc_embedding, dec_embedding = self.encoder_embedding, self.decoder_embedding

        self.encoder = BertEncoder(config, enc_embedding)
        self.decoder = BertDecoder(config, dec_embedding)

        self.init_weights()

    def init_weights(self):
        """Apply ``init_layer_weights`` recursively to every sub-module."""
        self.apply(self.init_layer_weights)

    # Initialisation scheme follows Hugging Face's ElectraPreTrainedModel:
    # https://huggingface.co/transformers/v4.9.2/_modules/transformers/models/electra/modeling_electra.html#ElectraPreTrainedModel
    def init_layer_weights(self, module):
        """Initialise a single module in place (Linear / Embedding / LayerNorm)."""
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            module.eps = self.config.layer_norm_eps
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                # Keep the padding row of the embedding table at zero.
                module.weight.data[module.padding_idx].zero_()

    def forward(self,
                enc_input_ids,
                enc_token_type_ids=None,
                enc_attention_mask=None,
                dec_input_ids=None,
                dec_attention_mask=None,
                ):
        """Run the encoder, then the decoder over its outputs.

        Returns:
            ``(logits, enc_self_attn_probs, dec_self_attn_probs, dec_cross_attn_probs)``.
        """
        # The encoder hands back its additive mask, reused for cross-attention.
        enc_hidden, enc_self_probs, enc_additive_mask = self.encoder(
            input_ids=enc_input_ids,
            token_type_ids=enc_token_type_ids,
            attention_mask=enc_attention_mask,
        )

        logits, dec_self_probs, dec_cross_probs = self.decoder(
            input_ids=dec_input_ids,
            attention_mask=dec_attention_mask,
            enc_outputs=enc_hidden,
            enc_attention_mask=enc_additive_mask,
        )

        return logits, enc_self_probs, dec_self_probs, dec_cross_probs


In [12]:
cd ../bert_style

/content/Various_Transformers/bert_style


In [13]:
from bert import *

In [14]:
import math
import torch
import torch.nn as nn

def get_extended_attention_mask(attention_mask, autoregressive=False, dtype=torch.float16):
    """Turn a 2-D padding mask into an additive attention bias.

    Args:
        attention_mask: ``[batch, seq_len]`` tensor, non-zero/True where a token
            may be attended to and 0/False for padding. Integer and bool masks
            are both supported.
        autoregressive: if True, also apply a causal mask so that query
            position ``i`` only attends to key positions ``<= i``.
        dtype: floating dtype of the returned bias. The default ``float16``
            keeps the previous (fp16-compatible) behaviour.

    Returns:
        A tensor of shape ``[batch, 1, 1, seq_len]``, or
        ``[batch, 1, seq_len, seq_len]`` when ``autoregressive`` is True. It
        holds 0 for allowed positions and ``torch.finfo(dtype).min`` for masked
        ones, and is meant to be added to raw attention scores before softmax.
    """
    extended_attention_mask = attention_mask[:, None, None, :]

    if autoregressive:
        seq_len = attention_mask.size(1)
        # Lower-triangular (incl. diagonal): query i may see keys 0..i.
        causal_mask = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=attention_mask.device)
        )
        # Combine padding and causal masks as a logical AND. Doing this in bool
        # also makes bool padding masks work: adding two bool tensors is a
        # logical OR, so the former "sum > 1" test never fired for them.
        extended_attention_mask = extended_attention_mask.bool() & causal_mask

    extended_attention_mask = extended_attention_mask.to(dtype=dtype)
    return (1.0 - extended_attention_mask) * torch.finfo(dtype).min

class PoswiseFeedForward(nn.Module):
    """Position-wise feed-forward sub-layer: Linear -> GELU -> Linear -> Dropout.

    Expands ``d_model`` to ``feed_forward_dim`` and projects back to ``d_model``.
    """

    def __init__(self, config):
        super().__init__()
        hidden_dim = config.feed_forward_dim
        layers = [
            nn.Linear(config.d_model, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, config.d_model),
            nn.Dropout(config.drop_out_raito),
        ]
        # A single Sequential named `feed_forward` keeps parameter names
        # (feed_forward.0.*, feed_forward.2.*) unchanged.
        self.feed_forward = nn.Sequential(*layers)

    def forward(self, inputs):
        out = self.feed_forward(inputs)
        return out
    
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.

    The three embeddings are summed, then passed through LayerNorm and dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.d_model)

        self.LayerNorm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.drop_out_raito)

        # A non-persistent buffer follows `module.to(device)` together with the
        # parameters, without adding an entry to `state_dict`.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        ):
        """Embed ``input_ids``.

        Args:
            input_ids: ``[batch, seq_len]`` token ids (required despite the
                ``None`` default).
            token_type_ids: optional segment ids shaped like ``input_ids``;
                defaults to all zeros.
            position_ids: optional position ids broadcastable to
                ``[batch, seq_len]``; defaults to ``0 .. seq_len - 1``.

        Returns:
            ``[batch, seq_len, d_model]`` embeddings.
        """
        batch_size, seq_len = input_ids.size()
        device = input_ids.device

        if token_type_ids is None:
            token_type_ids = torch.zeros([batch_size, seq_len], dtype=torch.long, device=device)

        # Honour caller-supplied position ids; fall back to 0..seq_len-1.
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_len]
        position_ids = position_ids.to(device)

        word_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = word_embeds + token_type_embeddings + position_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings

class BertModel(nn.Module):
    """Encoder-only BERT: a ``BertEmbeddings`` layer feeding a ``BertEncoder``."""

    def __init__(self, config):
        super().__init__()
        self.config = config  # read by init_layer_weights during init_weights()

        self.embedding = BertEmbeddings(config)
        self.encoder = BertEncoder(config, self.embedding)

        self.init_weights()

    def init_weights(self):
        """Apply ``init_layer_weights`` recursively to every sub-module."""
        self.apply(self.init_layer_weights)

    # Initialisation scheme follows Hugging Face's ElectraPreTrainedModel:
    # https://huggingface.co/transformers/v4.9.2/_modules/transformers/models/electra/modeling_electra.html#ElectraPreTrainedModel
    def init_layer_weights(self, module):
        """Initialise a single module in place (Linear / Embedding / LayerNorm)."""
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            module.eps = self.config.layer_norm_eps
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                # Keep the padding row of the embedding table at zero.
                module.weight.data[module.padding_idx].zero_()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Return ``(hidden_states, per_layer_self_attention_probs)``."""
        hidden_states, attn_probs, _mask = self.encoder(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        return hidden_states, attn_probs

class BertEncoder(nn.Module):
    """Embedding layer followed by ``config.num_enc_layers`` encoder layers."""

    def __init__(self, config, embedding):
        super().__init__()
        self.config = config
        # The embedding module is injected so callers can share it.
        self.embedding = embedding
        layer_list = [BertEncoderLayer(config) for _ in range(config.num_enc_layers)]
        self.layers = nn.ModuleList(layer_list)

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Encode ``input_ids``.

        Returns:
            ``(hidden_states, per_layer_self_attn_probs, additive_self_attention_mask)``.
            The additive mask is returned so it can be reused for cross-attention.
        """
        if attention_mask is None:
            # Derive the padding mask from the pad token id.
            attention_mask = input_ids.ne(self.config.pad_token_id).int()
        self_attention_mask = get_extended_attention_mask(attention_mask, autoregressive=False)

        hidden = self.embedding(input_ids, token_type_ids=token_type_ids)

        self_attn_probs = []
        for layer in self.layers:
            hidden, layer_probs = layer(inputs=hidden, self_attention_mask=self_attention_mask)
            self_attn_probs.append(layer_probs)

        return hidden, self_attn_probs, self_attention_mask
    
class BertEncoderLayer(nn.Module):
    """One encoder block: self-attention, then feed-forward.

    Each sub-layer is followed by a residual connection and LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        self.self_attention = BertAttention(config)
        self.attention_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.feed_forward = PoswiseFeedForward(config)
        self.feed_forward_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)

    def forward(self, inputs, self_attention_mask):
        """Return ``(hidden_states, self_attention_probs)``."""
        attn_out, self_attn_prob = self.self_attention(
            query=inputs,
            key=None,
            value=None,
            attention_mask=self_attention_mask,
        )
        hidden = self.attention_norm(inputs + attn_out)

        ff_out = self.feed_forward(hidden)
        hidden = self.feed_forward_norm(hidden + ff_out)

        return hidden, self_attn_prob

class BertAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Serves as self-attention (``key``/``value`` omitted) and as
    cross-attention (``key``/``value`` given, e.g. encoder outputs).
    """

    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.num_att_heads = config.num_att_heads
        assert self.d_model % self.num_att_heads == 0, "d_model({}) % num_att_heads({}) = {}. It should be 0.".format(self.d_model, self.num_att_heads, self.d_model % self.num_att_heads)
        self.d_head = self.d_model // self.num_att_heads
        self.scale = self.d_head ** 0.5

        self.query_proj = nn.Linear(self.d_model, self.num_att_heads * self.d_head)
        self.key_proj = nn.Linear(self.d_model, self.num_att_heads * self.d_head)
        self.value_proj = nn.Linear(self.d_model, self.num_att_heads * self.d_head)

        self.attn_dropout = nn.Dropout(config.drop_out_raito)

        self.fc = nn.Linear(self.d_head * self.num_att_heads, self.d_model)
        self.context_dropout = nn.Dropout(config.drop_out_raito)

    def forward(self,
                query,
                key=None,
                value=None,
                attention_mask=None,
                ):
        """Attend from ``query`` over ``key``/``value``.

        Args:
            query: ``[batch, query_len, d_model]``.
            key: ``[batch, key_len, d_model]``; defaults to ``query``.
            value: ``[batch, key_len, d_model]``; defaults to ``key``.
            attention_mask: optional additive bias broadcastable to
                ``[batch, num_heads, query_len, key_len]``; ``None`` disables masking.

        Returns:
            ``(context, attn_prob)``. ``context`` is ``[batch, query_len, d_model]``.
            ``attn_prob`` is ``[batch, num_heads, query_len, key_len]``, taken
            after attention dropout.
        """
        if key is None:
            key = query
        if value is None:
            value = key

        batch_size = query.size(0)

        query = self.query_proj(query).view(batch_size, -1, self.num_att_heads, self.d_head).transpose(1, 2)  # [bs, num_heads, query_len, d_head]
        key = self.key_proj(key).view(batch_size, -1, self.num_att_heads, self.d_head).transpose(1, 2)  # [bs, num_heads, key_len, d_head]
        value = self.value_proj(value).view(batch_size, -1, self.num_att_heads, self.d_head).transpose(1, 2)  # [bs, num_heads, value_len, d_head]

        scores = torch.matmul(query, key.transpose(-2, -1)) / self.scale  # [bs, num_heads, query_len, key_len]
        if attention_mask is not None:
            scores = scores + attention_mask

        attn_prob = torch.softmax(scores, dim=-1)
        attn_prob = self.attn_dropout(attn_prob)

        context = torch.matmul(attn_prob, value)  # [bs, num_heads, query_len, d_head]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.num_att_heads * self.d_head)

        context = self.fc(context)
        context = self.context_dropout(context)

        return context, attn_prob

class BertDecoder(nn.Module):
    """Embedding layer, ``config.num_dec_layers`` decoder layers, and an
    output projection to vocabulary-size logits."""

    def __init__(self, config, embedding):
        super().__init__()
        self.config = config
        self.embedding = embedding
        # The decoder depth follows its own setting, num_dec_layers (it
        # previously used num_enc_layers), matching GPT1Decoder.
        self.layers = nn.ModuleList(
            [BertDecoderLayer(config) for _ in range(config.num_dec_layers)]
        )
        self.fc = nn.Linear(config.d_model, config.vocab_size)

    def forward(self,
                input_ids,
                attention_mask=None,
                enc_outputs=None,
                enc_attention_mask=None):
        """Decode ``input_ids`` while attending to ``enc_outputs``.

        Args:
            input_ids: decoder token ids.
            attention_mask: optional decoder padding mask; derived from
                ``pad_token_id`` when omitted. A causal mask is always added.
            enc_outputs: encoder hidden states used as cross-attention key/value.
            enc_attention_mask: additive mask used for cross-attention.

        Returns:
            ``(logits, self_attn_probs, cross_attn_probs)``. The two prob lists
            hold one entry per layer.
        """
        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id).int()

        self_attention_mask = get_extended_attention_mask(attention_mask, autoregressive=True)

        outputs = self.embedding(input_ids,
                                 token_type_ids=None)

        self_attn_probs, cross_attn_probs = [], []
        for layer in self.layers:
            outputs, self_attn_prob, cross_attn_prob = layer(inputs=outputs,
                                                             self_attention_mask=self_attention_mask,
                                                             enc_outputs=enc_outputs,
                                                             cross_attention_mask=enc_attention_mask,
                                                             )
            self_attn_probs.append(self_attn_prob)
            cross_attn_probs.append(cross_attn_prob)

        outputs = self.fc(outputs)

        return outputs, self_attn_probs, cross_attn_probs

class BertDecoderLayer(nn.Module):
    """One decoder block: self-attention, cross-attention over the encoder
    outputs, then feed-forward.

    Each sub-layer is followed by a residual connection and LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        self.self_attention = BertAttention(config)
        self.self_attention_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)

        self.cross_attention = BertAttention(config)
        self.cross_attention_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)

        self.feed_forward = PoswiseFeedForward(config)
        self.feed_forward_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)

    def forward(self, inputs, self_attention_mask, enc_outputs, cross_attention_mask):
        """Return ``(hidden_states, self_attention_probs, cross_attention_probs)``."""
        self_out, self_attn_prob = self.self_attention(
            query=inputs,
            key=None,
            value=None,
            attention_mask=self_attention_mask,
        )
        hidden = self.self_attention_norm(inputs + self_out)

        cross_out, cross_attn_prob = self.cross_attention(
            query=hidden,
            key=enc_outputs,
            value=enc_outputs,
            attention_mask=cross_attention_mask,
        )
        hidden = self.cross_attention_norm(hidden + cross_out)

        ff_out = self.feed_forward(hidden)
        hidden = self.feed_forward_norm(hidden + ff_out)

        return hidden, self_attn_prob, cross_attn_prob

class Bert_Encoder_Decoder_Model(nn.Module):
    """BERT-style encoder-decoder (seq2seq) model.

    If ``config.share_embedding is True``, encoder and decoder use one
    embedding module. Otherwise each gets its own.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        if config.share_embedding is True:
            self.shared_embedding = BertEmbeddings(config)
            enc_embedding = dec_embedding = self.shared_embedding
        else:
            self.encoder_embedding = BertEmbeddings(config)
            self.decoder_embedding = BertEmbeddings(config)
            enc_embedding, dec_embedding = self.encoder_embedding, self.decoder_embedding

        self.encoder = BertEncoder(config, enc_embedding)
        self.decoder = BertDecoder(config, dec_embedding)

        self.init_weights()

    def init_weights(self):
        """Apply ``init_layer_weights`` recursively to every sub-module."""
        self.apply(self.init_layer_weights)

    # Initialisation scheme follows Hugging Face's ElectraPreTrainedModel:
    # https://huggingface.co/transformers/v4.9.2/_modules/transformers/models/electra/modeling_electra.html#ElectraPreTrainedModel
    def init_layer_weights(self, module):
        """Initialise a single module in place (Linear / Embedding / LayerNorm)."""
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            module.eps = self.config.layer_norm_eps
        elif isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                # Keep the padding row of the embedding table at zero.
                module.weight.data[module.padding_idx].zero_()

    def forward(self,
                enc_input_ids,
                enc_token_type_ids=None,
                enc_attention_mask=None,
                dec_input_ids=None,
                dec_attention_mask=None,
                ):
        """Run the encoder, then the decoder over its outputs.

        Returns:
            ``(logits, enc_self_attn_probs, dec_self_attn_probs, dec_cross_attn_probs)``.
        """
        # The encoder hands back its additive mask, reused for cross-attention.
        enc_hidden, enc_self_probs, enc_additive_mask = self.encoder(
            input_ids=enc_input_ids,
            token_type_ids=enc_token_type_ids,
            attention_mask=enc_attention_mask,
        )

        logits, dec_self_probs, dec_cross_probs = self.decoder(
            input_ids=dec_input_ids,
            attention_mask=dec_attention_mask,
            enc_outputs=enc_hidden,
            enc_attention_mask=enc_additive_mask,
        )

        return logits, enc_self_probs, dec_self_probs, dec_cross_probs

In [15]:
model = BertModel(config=config)

In [16]:
result = model(
    input_ids=input_ids,
    token_type_ids=token_type_ids,
    attention_mask=attention_mask,
)

In [17]:
result[0].size()

torch.Size([1, 10, 768])

In [18]:
model = Bert_Encoder_Decoder_Model(config=config)

In [19]:
result = model(
    enc_input_ids=input_ids,
    enc_token_type_ids=token_type_ids,
    enc_attention_mask=attention_mask,
    dec_input_ids=dec_input_ids,
)

In [20]:
result[0].size()

torch.Size([1, 7, 30000])

In [21]:
cd ../gpt1_style

/content/Various_Transformers/gpt1_style


In [22]:
from gpt1 import *

In [23]:
class GPT1Model(nn.Module):
    """GPT-1 style decoder-only model: shared embedding module + decoder stack.

    Args:
        config: model configuration. Besides what the embedding/decoder read,
            this class reads ``init_std`` and ``layer_norm_eps`` to initialise
            weights.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding = GPT1Embeddings(config)
        self.decoder = GPT1Decoder(config, self.embedding)
        # `config.init_std` was previously never applied, leaving PyTorch's
        # default init (e.g. N(0, 1) for nn.Embedding). Use the same
        # normal/zero/one scheme as the other models in this notebook.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one sub-module in place (used via ``nn.Module.apply``).

        Linear/Embedding weights ~ N(0, init_std), Linear biases zeroed,
        the Embedding padding row zeroed, and LayerNorm reset to weight=1,
        bias=0 with ``eps = layer_norm_eps``.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # weight/bias are None when elementwise_affine=False.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)
            module.eps = self.config.layer_norm_eps

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                ):
        """Return ``(outputs, self_attn_probs)`` from the decoder stack.

        All arguments are forwarded unchanged to ``self.decoder``. ``outputs``
        is the decoder's final vocab_size projection; ``self_attn_probs`` is its
        list of per-layer attention probabilities.
        """
        outputs, self_attn_probs = self.decoder(input_ids,
                                                token_type_ids,
                                                attention_mask,
                                                )

        return outputs, self_attn_probs

In [24]:
class GPT1Decoder(nn.Module):
    """Stack of GPT-1 decoder layers followed by a projection to ``vocab_size``.

    Args:
        config: model configuration; reads ``num_dec_layers``, ``d_model``,
            ``vocab_size``, ``pad_token_id`` and optionally ``share_embedding``.
        embedding: module called as ``embedding(input_ids, token_type_ids=...)``.
            When ``share_embedding`` is true it must expose ``word_embeddings``.
    """

    def __init__(self, config, embedding):
        super().__init__()
        self.config = config
        self.embedding = embedding

        self.layers = nn.ModuleList(
            [GPT1DecoderLayer(config) for _ in range(config.num_dec_layers)]
        )

        self.fc = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # GPT-1 reuses the token embedding matrix as the output projection.
        # The config's `share_embedding` flag was previously ignored. Configs
        # without the key keep the old untied behaviour.
        if self._share_embedding(config):
            self.fc.weight = self.embedding.word_embeddings.weight

    @staticmethod
    def _share_embedding(config):
        """Return the ``share_embedding`` flag, defaulting to False when absent.

        ``Config.__getattr__`` is ``dict.__getitem__`` and raises KeyError, not
        AttributeError. ``getattr(config, name, default)`` alone is therefore
        not enough for dict-based configs.
        """
        if isinstance(config, dict):
            return bool(config.get('share_embedding', False))
        return bool(getattr(config, 'share_embedding', False))

    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                ):
        """Embed ``input_ids``, run every decoder layer, and project to vocab.

        Args:
            input_ids: token ids passed to ``self.embedding``.
            token_type_ids: optional segment ids forwarded to the embedding.
            attention_mask: optional padding mask. When omitted it is derived
                as ``input_ids != config.pad_token_id``.

        Returns:
            ``(outputs, self_attn_probs)``. ``outputs`` is ``self.fc`` applied
            to the last layer's output (size ``vocab_size`` on the last axis).
            ``self_attn_probs`` is the list of per-layer attention probabilities.
        """
        if attention_mask is None:
            attention_mask = input_ids.ne(self.config.pad_token_id).int()

        # autoregressive=True: presumably adds a causal (look-ahead) mask on top
        # of the padding mask — confirm in get_extended_attention_mask.
        self_attention_mask = get_extended_attention_mask(attention_mask, autoregressive=True)

        outputs = self.embedding(input_ids,
                                 token_type_ids=token_type_ids)

        self_attn_probs = []
        for layer in self.layers:
            outputs, self_attn_prob = layer(inputs=outputs,
                                            self_attention_mask=self_attention_mask,
                                            )
            self_attn_probs.append(self_attn_prob)

        outputs = self.fc(outputs)

        return outputs, self_attn_probs

In [25]:
class GPT1Embeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings.

    The three lookups are summed and passed through dropout (no LayerNorm).

    Args:
        config: reads ``vocab_size``, ``d_model``, ``pad_token_id``,
            ``max_position_embeddings``, ``type_vocab_size`` and
            ``drop_out_raito`` (spelling as used by the config cell).
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.d_model, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.d_model)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.d_model)

        self.dropout = nn.Dropout(config.drop_out_raito)

        # Default position ids 0..max_position_embeddings-1, shape (1, max_len).
        # As a non-persistent buffer it follows `model.to(device)` like the
        # parameters, but it is not saved to (or expected in) the state_dict.
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False,
        )

    def forward(
        self,
        input_ids=None,
        token_type_ids=None,
        position_ids=None,
        ):
        """Embed a batch of token ids.

        Args:
            input_ids: 2-D integer tensor ``(batch, seq_len)``; required.
            token_type_ids: optional segment ids of the same shape; zeros if None.
            position_ids: optional position ids broadcastable to ``input_ids``.
                Defaults to ``0..seq_len-1``. This argument was previously
                ignored.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.

        Raises:
            ValueError: if ``input_ids`` is None, or if default positions are
                used and ``seq_len`` exceeds ``max_position_embeddings``.
        """
        if input_ids is None:
            raise ValueError("input_ids is required")

        batch_size, seq_len = input_ids.size()
        device = input_ids.device

        if token_type_ids is None:
            token_type_ids = torch.zeros([batch_size, seq_len], dtype=torch.long, device=device)

        if position_ids is None:
            max_len = self.position_ids.size(1)
            if seq_len > max_len:
                raise ValueError(
                    "sequence length {} exceeds max_position_embeddings ({})".format(seq_len, max_len)
                )
            # .to(device) is a no-op once the module lives on the input's device.
            position_ids = self.position_ids[:, :seq_len].to(device)

        inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        position_embeddings = self.position_embeddings(position_ids)

        # position_embeddings is (1, seq_len, d_model) by default and broadcasts
        # over the batch.
        embeddings = inputs_embeds + token_type_embeddings + position_embeddings

        embeddings = self.dropout(embeddings)

        return embeddings

In [26]:
model = GPT1Model(config)

In [27]:
result = model(dec_input_ids)

In [28]:
result[0].shape

torch.Size([1, 7, 30000])

In [29]:
cd ../gpt2_style

/content/Various_Transformers/gpt2_style


In [30]:
from gpt2 import *

In [31]:
model = GPT2Model(config)

In [32]:
result = model(dec_input_ids)

In [33]:
result[0].shape

torch.Size([1, 7, 30000])

In [34]:
cd ../t5_style

/content/Various_Transformers/t5_style


In [35]:
from t5 import *

In [36]:
model = T5_Model(config)

In [37]:
result = model(input_ids,
               attention_mask,
               dec_input_ids,
               )

In [38]:
result[0].shape

torch.Size([1, 7, 30000])