## Setting up the Colab environment


### Libraries used (fixed)

*   [torch==1.9.0](https://pytorch.org/)
*   [transformers](https://pypi.org/project/transformers/)

In [1]:
!pip3 install torch==1.9.0 torchvision torchaudio
!pip3 install transformers

Collecting torchaudio
  Downloading torchaudio-0.10.0-cp37-cp37m-manylinux1_x86_64.whl (2.9 MB)
  Downloading torchaudio-0.9.1-cp37-cp37m-manylinux1_x86_64.whl (1.9 MB)
  Downloading torchaudio-0.9.0-cp37-cp37m-manylinux1_x86_64.whl (1.9 MB)
Installing collected packages: torchaudio
Successfully installed torchaudio-0.9.0
Collecting transformers
  Downloading transformers-4.12.3-py3-none-any.whl (3.1 MB)
Collecting tokenizers<0.11,>=0.10.1
  Downloading tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3 MB)
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.1.1-py3-none-any.whl (59 k

### Importing the required libraries

In [2]:
import math
import torch
from torch import nn
from transformers import (BertTokenizer, BertConfig, 
                          apply_chunking_to_forward, set_seed
                          )
from transformers.modeling_utils import find_pruneable_heads_and_indices, prune_linear_layer
from transformers.activations import gelu
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions


1. Reproducibility: set the random seed to 42 so the results can be reproduced

```
set_seed(42)
```
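`set_seed` is a transformers helper; roughly, it seeds Python's `random`, NumPy and PyTorch (including all CUDA devices). A minimal sketch of the same idea using only `random` and `torch`. The helper name `seed_everything` is hypothetical, and this is not the library's own implementation:

```python
import random

import torch


def seed_everything(seed: int) -> None:
    """Hypothetical helper: seed the Python and PyTorch RNGs used in this notebook."""
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable


seed_everything(42)
a, r1 = torch.rand(3), random.random()
seed_everything(42)
b, r2 = torch.rand(3), random.random()
print(torch.equal(a, b), r1 == r2)  # True True
```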


2. Set up the BERT [tokenizer](https://huggingface.co/transformers/model_doc/bert.html#berttokenizer) and [configuration (config)](https://huggingface.co/transformers/model_doc/bert.html#bertconfig)

```
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
bert_configuraiton = BertConfig.from_pretrained('bert-base-cased')
```


3. Tokenize the input sequences

```
input_texts = ['I love cats!', 'He hates pineapple pizza.']

input_sequences = tokenizer(text=input_texts, add_special_tokens=True, padding=True, truncation=True, return_tensors='pt')
```
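With `padding=True`, shorter sequences are padded to the length of the longest one in the batch, and `attention_mask` marks real tokens with 1 and padding with 0. A pure-Python sketch of that behaviour; the helper `pad_batch` is hypothetical, the token ids are made up (not real `bert-base-cased` ids), and a pad id of 0 (BERT's `[PAD]`) is assumed:

```python
from typing import Dict, List


def pad_batch(batch: List[List[int]], pad_id: int = 0) -> Dict[str, List[List[int]]]:
    """Hypothetical helper: right-pad token id lists and build the attention mask."""
    max_len = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


# Made-up token ids, for illustration only
encoded = pad_batch([[101, 7, 8, 9, 102], [101, 3, 4, 5, 6, 11, 102]])
print(encoded["attention_mask"])
# [[1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1]]
```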



In [3]:
# Set seed for reproducibility
set_seed(42)

# GELU Activation function.
ACT2FN = {"gelu": gelu}

# Define BertLayerNorm.
BertLayerNorm = nn.LayerNorm

# Create BertTokenizer, Configuration
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
bert_configuraiton = BertConfig.from_pretrained('bert-base-cased')

# Create input sequence using tokenizer
input_texts = ['I love cats!', 'He hates pineapple pizza.']
labels = [1, 0]
input_sequences = tokenizer(text=input_texts, add_special_tokens=True, padding=True, truncation=True, return_tensors='pt')

# input_sequences is a dict-like BatchEncoding, so the labels can be added to it as well;
# every value should be a tensor, hence torch.tensor(labels)
input_sequences.update({'labels':torch.tensor(labels)})
# print(input_sequences)


### Embedding the input sequences (Hugging Face code)
- No changes needed; used as-is

1. Create a BertEmbeddings block from the configuration
2. Run a forward pass through the embedding block
3. Embedding output shape: [batch_size, seq_len, hidden_size (768)]

```
# Create Bert embedding layer
bert_embeddings_block = BertEmbeddings(bert_configuraiton)

# Perform a forward pass
embedding_output = bert_embeddings_block(input_ids=input_sequences['input_ids'], token_type_ids=input_sequences['token_type_ids'])
```
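Conceptually, `BertEmbeddings` sums three lookups (token, segment and position embeddings) and then applies LayerNorm and dropout. A small sketch with toy sizes (vocabulary 100, hidden size 8, 16 positions and 2 segment types are assumptions, not BERT's values) that reproduces the `[batch_size, seq_len, hidden_size]` output shape:

```python
import torch
from torch import nn

torch.manual_seed(0)
vocab_size, hidden_size, max_positions, type_vocab_size = 100, 8, 16, 2  # toy sizes

word_emb = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
pos_emb = nn.Embedding(max_positions, hidden_size)
type_emb = nn.Embedding(type_vocab_size, hidden_size)
layer_norm = nn.LayerNorm(hidden_size, eps=1e-12)
dropout = nn.Dropout(0.1)

input_ids = torch.tensor([[5, 6, 7, 0], [8, 9, 10, 11]])      # [batch=2, seq_len=4]
token_type_ids = torch.zeros_like(input_ids)                  # single-segment inputs
position_ids = torch.arange(input_ids.size(1)).unsqueeze(0)   # [1, seq_len], broadcast over the batch

embeddings = word_emb(input_ids) + type_emb(token_type_ids) + pos_emb(position_ids)
embeddings = dropout(layer_norm(embeddings))
print(embeddings.shape)  # torch.Size([2, 4, 8])
```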



In [4]:
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

    def forward(
        self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
    ):
        # print('============== BertEmbeddings ==============')
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]       
        #print('Created Tokens Positions IDs: ', position_ids) # ADDED
        

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        # ADDED
        # print('Tokens IDs: ', input_ids.shape)
        # print('Tokens Type IDs: ', token_type_ids.shape)
        # print('Word Embeddings: ', inputs_embeds.shape)

        embeddings = inputs_embeds + token_type_embeddings
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            # print('Position Embeddings: ', position_embeddings.shape) # ADDED

            embeddings += position_embeddings

        # ADDED
        # print('Token Types Embeddings: ', token_type_embeddings.shape)
        # print('Sum Up All Embeddings: ', embeddings.shape)

        embeddings = self.LayerNorm(embeddings)
        # print('Embeddings Layer Normalization: ', embeddings.shape) # ADDED

        embeddings = self.dropout(embeddings)
        # print('Embeddings Dropout Layer: ', embeddings.shape) # ADDED
        
        return embeddings

In [5]:
def get_extended_attention_mask(attention_mask, input_shape, device):
        """
        Makes a broadcastable additive attention mask so that padded tokens are ignored.
        (The causal-mask branch needed for decoders is omitted in this version.)
        Arguments:
            attention_mask (:obj:`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (:obj:`Tuple[int]`):
                The shape of the input to the model (used only in the error message).
            device: (:obj:`torch.device`):
                The device of the input to the model.
        Returns:
            :obj:`torch.Tensor` The extended attention mask as a floating-point tensor on :obj:`device`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - a decoder would also need a causal mask; that case is not handled here
            # - for an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask.to(device)
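A worked example of the transformation above: the 2-D padding mask becomes a `[batch, 1, 1, seq_len]` additive mask with 0 for real tokens and -10000 for padding. It broadcasts over heads and query positions and drives the softmax weight of the padded column to (numerically) zero. The sizes (2 heads, length 4) are arbitrary toy choices:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])                   # [batch=1, seq_len=4]; last token is padding
extended = (1.0 - attention_mask[:, None, None, :]) * -10000.0  # [1, 1, 1, 4]
print(extended.shape)  # torch.Size([1, 1, 1, 4])

scores = torch.zeros(1, 2, 4, 4)          # [batch, heads, query_len, key_len], all raw scores equal
probs = torch.softmax(scores + extended, dim=-1)
print(probs[0, 0, 0])                     # the padding column gets ~0 probability
```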

### BertSelfAttention (Hugging Face code)

- Performs self-attention

- Resulting output shape: [batch_size, seq_len, 768]
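For reference, the computation `BertSelfAttention` performs for each head is scaled dot-product attention. Here $X$ is `hidden_states`, biases are omitted, $M$ is the additive mask from `get_extended_attention_mask` (0 for kept positions, -10000 for padding), and $d_k$ = `attention_head_size` = 768 / 12 = 64 for BERT-base:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}} + M\right) V,
\qquad Q = X W_Q,\quad K = X W_K,\quad V = X W_V
$$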

In [6]:
class BertSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # ADDED
        # print('============== BertSelfAttention ==============')
        # print('Attention Head Size: ', self.attention_head_size)
        # print('Combined Attentions Head Size: ', self.all_head_size)

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        
        # print('Hidden States: ', hidden_states.shape) # ADDED

        mixed_query_layer = self.query(hidden_states)

        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention and past_key_value is not None:
            # ADDED
            # print('Query Linear Layer: ', mixed_query_layer.shape)
            # print('Key Linear Layer: ', past_key_value[0].shape)
            # print('Value Linear Layer: ', past_key_value[1].shape)

            # reuse k,v, cross_attentions
            key_layer = past_key_value[0]
            value_layer = past_key_value[1]
            attention_mask = encoder_attention_mask
        
        elif is_cross_attention:
            # ADDED
            # print('Query Linear Layer: ', mixed_query_layer.shape)
            # print('Key Linear Layer: ', self.key(encoder_hidden_states).shape)
            # print('Value Linear Layer: ', self.value(encoder_hidden_states).shape)

            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        
        elif past_key_value is not None:
            # ADDED
            # print('Query Linear Layer: ', mixed_query_layer.shape)
            # print('Key Linear Layer: ', self.key(hidden_states).shape)
            # print('Value Linear Layer: ', self.value(hidden_states).shape)

            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        
        else:
            # ADDED
            # print('Query Linear Layer: ', mixed_query_layer.shape)
            # print('Key Linear Layer: ', self.key(hidden_states).shape)
            # print('Value Linear Layer: ', self.value(hidden_states).shape)

            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # ADDED
        # print('Query: ', query_layer.shape)
        # print('Key: ', key_layer.shape)
        # print('Value: ', value_layer.shape)

        if self.is_decoder:
            # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
            # Further calls to cross_attention layer can then reuse all cross-attention
            # key/value_states (first "if" case)
            # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
            # all previous decoder key/value_states. Further calls to uni-directional self-attention
            # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
            # if encoder bi-directional self-attention `past_key_value` is always `None`
            past_key_value = (key_layer, value_layer)

        # ADDED
        # print('Key Transposed: ', key_layer.transpose(-1, -2).shape)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # ADDED
        # print('Attention Scores: ', attention_scores.shape)

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size) # scale by sqrt(d_k)
        # print('Attention Scores Divided by Scalar: ', attention_scores.shape) # ADDED

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)
        # print('Attention Probabilities Softmax Layer: ', attention_probs.shape) # ADDED

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)
        # print('Attention Probabilities Dropout Layer: ', attention_probs.shape) # ADDED

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)
        # print('Context: ', context_layer.shape) # ADDED

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # print('Context Permute: ', context_layer.shape) # ADDED

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        # print('Context Reshaped: ', context_layer.shape) # ADDED
        
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        if self.is_decoder:
            outputs = outputs + (past_key_value,)
        return outputs
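The reshaping in `transpose_for_scores` and at the end of `forward` is what splits the 768-dimensional hidden state into 12 heads of 64 and merges them back. A standalone shape check; the batch size 2 and sequence length 9 are arbitrary choices for this sketch:

```python
import torch

batch_size, seq_len, num_heads, head_dim = 2, 9, 12, 64
hidden = torch.randn(batch_size, seq_len, num_heads * head_dim)  # [2, 9, 768]

# Split: [B, L, H*D] -> [B, L, H, D] -> [B, H, L, D]
split = hidden.view(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
print(split.shape)  # torch.Size([2, 12, 9, 64])

# Merge: [B, H, L, D] -> [B, L, H, D] -> [B, L, H*D]
# As in BertSelfAttention, call .contiguous() on the permuted tensor before .view()
merged = split.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, num_heads * head_dim)
print(torch.equal(merged, hidden))  # True
```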

### BertSelfOutput (Hugging Face code)

- Performs Add & Norm (dense projection, dropout, residual addition, then LayerNorm)

- Resulting output shape: [batch_size, seq_len, 768]
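Add & Norm as used in `BertSelfOutput` (and again in `BertOutput`) projects the sublayer output, applies dropout, adds the residual input, then applies LayerNorm. A toy sketch (hidden size 8 and random inputs are assumptions) that also shows LayerNorm's effect right after initialisation, when its weight is 1 and its bias 0: every position ends up with roughly zero mean and unit variance.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden_size = 8
dense = nn.Linear(hidden_size, hidden_size)
dropout = nn.Dropout(0.1)
layer_norm = nn.LayerNorm(hidden_size, eps=1e-12)

residual = torch.randn(2, 5, hidden_size)       # input to the sublayer (the residual branch)
sublayer_out = torch.randn(2, 5, hidden_size)   # e.g. the self-attention context

out = layer_norm(dropout(dense(sublayer_out)) + residual)
print(out.mean(dim=-1).abs().max())             # close to 0
print(out.var(dim=-1, unbiased=False).mean())   # close to 1
```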

In [7]:
class BertSelfOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        # print('Hidden States: ', hidden_states.shape)

        hidden_states = self.dense(hidden_states)
        # print('Hidden States Linear Layer: ', hidden_states.shape)

        hidden_states = self.dropout(hidden_states)
        # print('Hidden States Dropout Layer: ', hidden_states.shape)

        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        # print('Hidden States Normalization Layer: ', hidden_states.shape)

        return hidden_states

In [8]:
class BertAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.self.query = prune_linear_layer(self.self.query, index)
        self.self.key = prune_linear_layer(self.self.key, index)
        self.self.value = prune_linear_layer(self.self.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
        self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs

### BertIntermediate (Hugging Face code)

- First half of the position-wise feed-forward network: a linear layer followed by GELU

- Resulting output shape: [batch_size, seq_len, 4*768] (= `intermediate_size`, 3072)
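`BertIntermediate` expands each token from 768 to `intermediate_size` = 3072 features and applies GELU. A small check of the shape and of the GELU formula, assuming PyTorch's `nn.functional.gelu` default (the exact erf form, x·Φ(x)); the input sizes are arbitrary:

```python
import math

import torch
from torch import nn

torch.manual_seed(0)
dense = nn.Linear(768, 3072)          # hidden_size -> intermediate_size
x = torch.randn(2, 9, 768)            # [batch, seq_len, hidden_size]
h = nn.functional.gelu(dense(x))
print(h.shape)  # torch.Size([2, 9, 3072])

# GELU(z) = z * Phi(z) = 0.5 * z * (1 + erf(z / sqrt(2)))
z = torch.linspace(-3, 3, 7)
exact = 0.5 * z * (1 + torch.erf(z / math.sqrt(2)))
print(torch.allclose(nn.functional.gelu(z), exact, atol=1e-5))  # True
```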

In [9]:
class BertIntermediate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states):
        # print('Hidden States: ', hidden_states.shape)

        hidden_states = self.dense(hidden_states)
        # print('Hidden States Linear Layer: ', hidden_states.shape)

        hidden_states = self.intermediate_act_fn(hidden_states)
        # print('Hidden States Gelu Activation Function: ', hidden_states.shape)

        return hidden_states

### BertOutput (Hugging Face code)

- Second half of the feed-forward network (3072 -> 768), followed by Add & Norm

- Resulting output shape: [batch_size, seq_len, 768]

In [10]:
class BertOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        # print('Hidden States: ', hidden_states.shape)

        hidden_states = self.dense(hidden_states)
        # print('Hidden States Linear Layer: ', hidden_states.shape)

        hidden_states = self.dropout(hidden_states)
        # print('Hidden States Dropout Layer: ', hidden_states.shape)

        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        # print('Hidden States Layer Normalization: ', hidden_states.shape)

        return hidden_states

In [11]:
class BertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
            self.crossattention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            assert hasattr(
                self, "crossattention"
            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
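`apply_chunking_to_forward` lets `BertLayer` run the feed-forward block on slices of the sequence dimension (`seq_len_dim = 1`) to save memory; with the default `chunk_size_feed_forward = 0` it simply calls the function once. Because the feed-forward network is position-wise, chunking does not change the result. A hypothetical re-implementation of the idea with `torch.chunk`/`torch.cat` (not transformers' actual code), using a toy FFN:

```python
import torch
from torch import nn


def chunked_forward(fn, chunk_size: int, dim: int, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: apply fn to chunks of x along dim, then concatenate."""
    if chunk_size <= 0:
        return fn(x)  # no chunking, like chunk_size_feed_forward = 0
    assert x.size(dim) % chunk_size == 0, "sequence length must be divisible by chunk_size"
    num_chunks = x.size(dim) // chunk_size
    return torch.cat([fn(c) for c in x.chunk(num_chunks, dim=dim)], dim=dim)


torch.manual_seed(0)
ffn = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))  # toy position-wise FFN
x = torch.randn(2, 8, 16)                                              # [batch, seq_len, hidden]

full = ffn(x)
chunked = chunked_forward(ffn, chunk_size=2, dim=1, x=x)
print(torch.allclose(full, chunked, atol=1e-5))  # True
```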

### BertEncoder (Hugging Face code)

- Stacks `num_hidden_layers` BertLayer modules (12 layers for BERT-base)
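`BertEncoder` feeds the output of each layer into the next. When `output_hidden_states=True` it records the input to every layer plus the final output, so the tuple has `num_hidden_layers + 1` entries (13 for BERT-base). A stripped-down sketch of that loop, using `nn.Linear` as a stand-in for `BertLayer` (an assumption for brevity) and a toy hidden size:

```python
import torch
from torch import nn

num_hidden_layers, hidden_size = 12, 16
layers = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(num_hidden_layers)])

hidden_states = torch.randn(2, 9, hidden_size)   # stands in for the embedding output
all_hidden_states = ()
for layer in layers:
    all_hidden_states = all_hidden_states + (hidden_states,)  # input to this layer
    hidden_states = layer(hidden_states)
all_hidden_states = all_hidden_states + (hidden_states,)      # output of the last layer

print(len(all_hidden_states))  # 13
```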

In [12]:
class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) # repeated config.num_hidden_layers times

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):

            # ADDED
            print('----------------- BERT LAYER %d -----------------'%(i+1))

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None
            if getattr(self.config, "gradient_checkpointing", False):

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

In [13]:
# -- 1. Bert Embedding (huggingface) -- #
bert_embeddings_block = BertEmbeddings(bert_configuraiton)

# -- 2. Bert Encoder (huggingface) -- #
bert_encoder_block = BertEncoder(bert_configuraiton)
bert_encoder_block.eval()

BertEncoder(
  (layer): ModuleList(
    (0): BertLayer(
      (attention): BertAttention(
        (self): BertSelfAttention(
          (query): Linear(in_features=768, out_features=768, bias=True)
          (key): Linear(in_features=768, out_features=768, bias=True)
          (value): Linear(in_features=768, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (output): BertSelfOutput(
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (intermediate): BertIntermediate(
        (dense): Linear(in_features=768, out_features=3072, bias=True)
      )
      (output): BertOutput(
        (dense): Linear(in_features=3072, out_features=768, bias=True)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      

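The printed module tree is enough to count the parameters of one `BertLayer` by hand: weight plus bias for each `Linear`, weight plus bias for each `LayerNorm` (`elementwise_affine=True`). A small arithmetic check for the BERT-base sizes shown above (embedding parameters are not included):

```python
hidden, intermediate, num_layers = 768, 3072, 12


def linear_params(n_in: int, n_out: int) -> int:
    """Weight matrix plus bias vector of an nn.Linear."""
    return n_in * n_out + n_out


layer_norm_params = 2 * hidden  # weight and bias

attention_p = 3 * linear_params(hidden, hidden)                        # query, key, value
self_output_p = linear_params(hidden, hidden) + layer_norm_params      # dense + LayerNorm
intermediate_p = linear_params(hidden, intermediate)                   # 768 -> 3072
output_p = linear_params(intermediate, hidden) + layer_norm_params     # 3072 -> 768 + LayerNorm

per_layer = attention_p + self_output_p + intermediate_p + output_p
print(per_layer)               # 7087872
print(per_layer * num_layers)  # 85054464
```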
### Custom Transformer Encoder

- Encoder
  - Encoder Layer * 12
    - MultiHeadAttention : Self-Attention
    - FeedForward: same role as Hugging Face's BertIntermediate (a feed-forward layer followed by GELU)
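For orientation, one way these pieces could be wired into a single post-LayerNorm encoder layer, in the same order as `BertLayer` (attention, Add & Norm, feed-forward, Add & Norm). This sketch swaps in PyTorch's built-in `nn.MultiheadAttention` (`batch_first=True` requires PyTorch 1.9 or later) instead of the custom `MultiHeadAttention`; the class name `SketchEncoderLayer` and the toy sizes are hypothetical:

```python
import torch
from torch import nn


class SketchEncoderLayer(nn.Module):
    """Hypothetical post-LN encoder layer: attention -> Add&Norm -> FFN -> Add&Norm."""

    def __init__(self, hidden_size: int = 768, num_heads: int = 12,
                 intermediate_size: int = 3072, dropout: float = 0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(hidden_size, eps=1e-12)
        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),  # like BertIntermediate.dense
            nn.GELU(),
            nn.Linear(intermediate_size, hidden_size),  # like BertOutput.dense
        )
        self.norm2 = nn.LayerNorm(hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        # nn.MultiheadAttention expects True where a key position should be IGNORED
        key_padding_mask = None if attention_mask is None else attention_mask == 0
        attn_out, _ = self.attn(x, x, x, key_padding_mask=key_padding_mask)
        x = self.norm1(x + self.dropout(attn_out))       # Add & Norm after attention
        x = self.norm2(x + self.dropout(self.ffn(x)))    # Add & Norm after the FFN
        return x


torch.manual_seed(0)
layer = SketchEncoderLayer(hidden_size=64, num_heads=4, intermediate_size=256).eval()
x = torch.randn(2, 9, 64)
mask = torch.ones(2, 9, dtype=torch.long)
mask[0, 6:] = 0  # the first sequence ends with 3 padding tokens
print(layer(x, mask).shape)  # torch.Size([2, 9, 64])
```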

In [17]:
import torch
from torch import nn
from transformers.activations import gelu
import math

""" 
BERT base config
hidden_size = 768
num_attention_heads = 12
num_hidden_layers = 12
hidden_dropout_prob = 0.1
hidden_act = gelu
intermediate_size = 3072
attention_probs_dropout_prob = 0.1
"""

class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()      
        assert config.hidden_size % config.num_attention_heads == 0
        
        self.hidden_size = config.hidden_size # 768
        self.num_attention_heads = config.num_attention_heads # 12
        self.head_dim = config.hidden_size // config.num_attention_heads 
        
        self.query = nn.Linear(self.hidden_size, self.hidden_size)
        self.key = nn.Linear(self.hidden_size, self.hidden_size)
        self.value = nn.Linear(self.hidden_size, self.hidden_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        
    
    def forward(self, hidden_states, attention_mask=None):
      
        print('Hidden States: ', hidden_states.shape) # ADDED

        Q = self.query(hidden_states)
        K = self.key(hidden_states)
        V = self.value(hidden_states)
        print('Q.size',Q.size())
        print('K.size',K.size())
        print('V.size',V.size())
       
        batch_size = hidden_states.shape[0]

        # split heads: [batch, seq_len, hidden] -> [batch, heads, seq_len, head_dim]
        Q = Q.view(batch_size, -1, self.num_attention_heads, self.head_dim).permute(0,2,1,3)
        K = K.view(batch_size, -1, self.num_attention_heads, self.head_dim).permute(0,2,1,3)
        V = V.view(batch_size, -1, self.num_attention_heads, self.head_dim).permute(0,2,1,3)
        print('Q.size',Q.size())
        print('K.size',K.size())
        print('V.size',V.size())
        
        d_k = self.head_dim # d_k
        print('dk',d_k)
        print('transpose k', K.transpose(-2,-1).size())
        attention_score = torch.matmul(Q, K.transpose(-1,-2)) # Q x K^T
        attention_score = attention_score / math.sqrt(d_k) 
        print('attention score: ', attention_score.size())
        
        if attention_mask is not None:
            # additive mask, broadcast over heads and query positions
            attention_score = attention_score + attention_mask
        
        attention = nn.functional.softmax(attention_score, dim=-1) 
        print('softmax attention score: ', attention.size())
        
        attention = self.dropout(attention)
        
        output = torch.matmul(attention,V) 
        print('score*v',output.size())

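        # [batch, heads, seq_len, head_dim] -> [batch, seq_len, heads, head_dim]
        output = output.permute(0, 2, 1, 3) 
        print('permute output',output.size())

        # merge heads into [batch, seq_len, hidden_size] for any batch size and sequence length
        output = output.reshape(batch_size, -1, self.hidden_size)
        print('reshape output: ', output.size())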

        return output
        

class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        
        self.linear = nn.Linear(config.hidden_size, config.intermediate_size) # 768 -> 3072

        # note: conv1 and conv2 are not used in forward() and receive no copied weights
        self.conv1 = nn.Conv1d(config.hidden_size,config.intermediate_size,1)
        self.conv2 = nn.Conv2d(config.intermediate_size,config.hidden_size,1)
        self.activate = nn.functional.gelu
        # BertIntermediate applies no dropout here; this Dropout is a no-op in eval() mode
        self.dropout = nn.Dropout(0.1)
                
    def forward(self, hidden_states):
        # hidden_states = [batch size, seq len, hid dim]
        
        # linear -> gelu
        ### Custom code ###
        print('hidden_states(input)',hidden_states.size())
        output = self.linear(hidden_states)
        print('linear output',output.size())
        output = self.activate(output)
        print('activate output',output.size())
        output = self.dropout(output)
        print('dropout output',output.size())
        
        return output

class EncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.self_attention = MultiHeadAttention(config)
        self.feedforward = FeedForward(config)  

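        self.linear_1 = nn.Linear(config.hidden_size, config.hidden_size)       # attention output projection (BertSelfOutput.dense)
        self.linear_2 = nn.Linear(config.intermediate_size, config.hidden_size) # 3072 -> 768 (BertOutput.dense)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # one LayerNorm shared by both add & norm steps; BERT has separate LayerNorms in
        # BertSelfOutput and BertOutput, and neither is copied below, so outputs only match
        # Hugging Face while its LayerNorms keep their initial values (weight=1, bias=0)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)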
        
    def forward(self, hidden_states, attention_mask=None):
        # hidden_states = [batch_size, seq_len, hidden_size]
                
        # 1. multi-head attention
        attention_output = self.self_attention(hidden_states, attention_mask)

        # 2. add & norm : linear -> dropout -> residual connection and layer norm
        ### Custom code ###
        attention_output = self.linear_1(attention_output)
        attention_output = self.dropout(attention_output)
        attention_output = self.layer_norm(hidden_states + attention_output)
        print('attention_output',attention_output.size())

        # 3. feedforward
        feedforward_output = self.feedforward(attention_output)

        # 4. add & norm
        ### Custom code ###
        feedforward_output = self.linear_2(feedforward_output)
        feedforward_output = self.dropout(feedforward_output)
        feedforward_output = self.layer_norm(feedforward_output + attention_output) 
        print('feedforward_output',feedforward_output.size())

        return feedforward_output

class Encoder(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.layer = nn.ModuleList([EncoderLayer(config) for _ in range(config.num_hidden_layers)]) # one EncoderLayer per hidden layer (12 for BERT-base)

    def forward(self, hidden_states, attention_mask=None):
        
        for layer in self.layer:
            hidden_states = layer(hidden_states, attention_mask)
            
        return hidden_states
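The merge step in `MultiHeadAttention.forward` has to undo the earlier head split exactly. A minimal sketch with random tensors (reusing one tensor as Q, K and V for brevity, which the real layer does not do) shows that `view` + `permute(0, 2, 1, 3)` followed by `permute(0, 2, 1, 3)` + `reshape(batch_size, -1, hidden_size)` recovers the original `[batch, seq_len, hidden]` tensor, and that each row of the softmax weights sums to 1.

```python
import math
import torch

# Shapes matching the notebook's run: batch 2, 9 tokens, BERT-base width and heads.
batch_size, seq_len, hidden_size, num_heads = 2, 9, 768, 12
head_dim = hidden_size // num_heads  # 64

x = torch.randn(batch_size, seq_len, hidden_size)

# Split: [B, L, H] -> [B, L, heads, head_dim] -> [B, heads, L, head_dim]
heads = x.view(batch_size, -1, num_heads, head_dim).permute(0, 2, 1, 3)

# Scaled dot-product scores per head: [B, heads, L, L]
scores = torch.matmul(heads, heads.transpose(-1, -2)) / math.sqrt(head_dim)
weights = torch.softmax(scores, dim=-1)

# Merge: [B, heads, L, head_dim] -> [B, L, heads, head_dim] -> [B, L, H]
merged = heads.permute(0, 2, 1, 3).reshape(batch_size, -1, hidden_size)

print(heads.shape)             # torch.Size([2, 12, 9, 64])
print(scores.shape)            # torch.Size([2, 12, 9, 9])
print(torch.equal(merged, x))  # True
```

Because only `-1` and `hidden_size` appear in the reshape, the same code works for other batch sizes and sequence lengths.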

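- Declare the Encoder
- Copy the Hugging Face BertEncoder parameters into the custom encoder (parameter copy)
- Minimize randomness
  - eval() mode (disables dropout)
  - with torch.no_grad() (no gradient tracking during the comparison)

- Compare the outputs of the two encoders' forward passes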
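The copy-then-compare procedure relies on two PyTorch behaviours: `load_state_dict` copies a module's weights and biases when keys and shapes match, and `eval()` turns `nn.Dropout` into the identity. A toy sketch, with two small `nn.Linear` layers (sizes chosen arbitrarily) standing in for the Hugging Face and custom layers:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Two independently initialised stacks standing in for a Hugging Face layer
# (source) and its custom counterpart (target).
source = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1))
target = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1))

# Copy weight and bias; load_state_dict requires matching keys and shapes.
target[0].load_state_dict(source[0].state_dict())

# eval() makes Dropout the identity; no_grad() skips autograd bookkeeping.
source.eval()
target.eval()

x = torch.randn(4, 8)
with torch.no_grad():
    same = torch.equal(source(x), target(x))
print(same)  # True
```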

In [18]:
# -- custom transformer encoder -- #
custom_encoder = Encoder(bert_configuraiton)

# -- Parameter Copying -- #
# -- copy the Hugging Face transformers parameters into the custom encoder -- #
for layer_num, enc_layer in enumerate(bert_encoder_block.layer):
    # <<< to MultiHeadAttention (wq, wk, wv) >>>
    custom_encoder.layer[layer_num].self_attention.query.load_state_dict(enc_layer.attention.self.query.state_dict()) # wq
    custom_encoder.layer[layer_num].self_attention.key.load_state_dict(enc_layer.attention.self.key.state_dict()) # wk
    custom_encoder.layer[layer_num].self_attention.value.load_state_dict(enc_layer.attention.self.value.state_dict()) # wv

    # <<< attention output projection and position-wise feedforward >>>
    custom_encoder.layer[layer_num].linear_1.load_state_dict(enc_layer.attention.output.dense.state_dict()) # attention output dense (768 -> 768)
    custom_encoder.layer[layer_num].feedforward.linear.load_state_dict(enc_layer.intermediate.dense.state_dict()) # intermediate dense (768 -> 3072)
    custom_encoder.layer[layer_num].linear_2.load_state_dict(enc_layer.output.dense.state_dict()) # output dense (3072 -> 768)

# set eval mode (disables dropout)
custom_encoder.eval()

Encoder(
  (layer): ModuleList(
    (0): EncoderLayer(
      (self_attention): MultiHeadAttention(
        (query): Linear(in_features=768, out_features=768, bias=True)
        (key): Linear(in_features=768, out_features=768, bias=True)
        (value): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (feedforward): FeedForward(
        (linear): Linear(in_features=768, out_features=3072, bias=True)
        (conv1): Conv1d(768, 3072, kernel_size=(1,), stride=(1,))
        (conv2): Conv2d(3072, 768, kernel_size=(1, 1), stride=(1, 1))
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (linear_1): Linear(in_features=768, out_features=768, bias=True)
      (linear_2): Linear(in_features=3072, out_features=768, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    )
    (1): EncoderLayer(
      (self_attention): MultiHeadAttention(
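The printed structure shows a single `layer_norm` per `EncoderLayer`, whereas Hugging Face's `BertLayer` keeps one `LayerNorm` in `BertSelfOutput` and another in `BertOutput`. A hypothetical variant with one LayerNorm per add & norm step is sketched below; the class and attribute names are illustrative, and attention is injected as a module so the example runs without the rest of the notebook. With such a layer, the two norms could be copied from `enc_layer.attention.output.LayerNorm` and `enc_layer.output.LayerNorm` in the parameter-copy loop.

```python
import torch
from torch import nn

class PostLNEncoderLayer(nn.Module):
    """Hypothetical post-LN encoder layer with one LayerNorm per sub-layer,
    mirroring the BertSelfOutput / BertOutput split. `attention` is any module
    mapping ([B, L, H], mask) -> [B, L, H]."""

    def __init__(self, attention: nn.Module, hidden_size: int = 768,
                 intermediate_size: int = 3072, dropout: float = 0.1, eps: float = 1e-12):
        super().__init__()
        self.attention = attention
        self.attn_out = nn.Linear(hidden_size, hidden_size)         # like BertSelfOutput.dense
        self.attn_norm = nn.LayerNorm(hidden_size, eps=eps)         # like BertSelfOutput.LayerNorm
        self.intermediate = nn.Linear(hidden_size, intermediate_size)
        self.ffn_out = nn.Linear(intermediate_size, hidden_size)    # like BertOutput.dense
        self.ffn_norm = nn.LayerNorm(hidden_size, eps=eps)          # like BertOutput.LayerNorm
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states, attention_mask=None):
        attn = self.attention(hidden_states, attention_mask)
        # add & norm 1: projection -> dropout -> residual -> its own LayerNorm
        attn = self.attn_norm(hidden_states + self.dropout(self.attn_out(attn)))
        # feedforward: 768 -> 3072 -> GELU -> 768
        ffn = nn.functional.gelu(self.intermediate(attn))
        # add & norm 2 with a separate LayerNorm
        return self.ffn_norm(attn + self.dropout(self.ffn_out(ffn)))

class IdentityAttention(nn.Module):
    """Stand-in attention for the usage example (returns its input unchanged)."""
    def forward(self, hidden_states, attention_mask=None):
        return hidden_states

layer = PostLNEncoderLayer(IdentityAttention()).eval()
with torch.no_grad():
    out = layer(torch.randn(2, 9, 768))
print(out.shape)  # torch.Size([2, 9, 768])
```

Injecting the attention module keeps the sketch independent of `MultiHeadAttention`; in the notebook it would be `MultiHeadAttention(config)`.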


In [19]:
with torch.no_grad():
    input_shape= input_sequences.input_ids.shape
    attention_mask = get_extended_attention_mask(input_sequences.attention_mask, input_shape, input_sequences.input_ids.device)

    # [Huggingface] Perform a forward pass
    embedding_output = bert_embeddings_block.forward(input_ids=input_sequences['input_ids'], token_type_ids=input_sequences['token_type_ids'])
    encoder_embedding = bert_encoder_block.forward(hidden_states=embedding_output, attention_mask=attention_mask)

    # [Custom] Perform a forward pass
    layer_output_custom = custom_encoder.forward(hidden_states=embedding_output, attention_mask=attention_mask)

print('[Huggingface] Output : ', encoder_embedding)
print('[Custom] Output : ', layer_output_custom)

----------------- BERT LAYER 1 -----------------
----------------- BERT LAYER 2 -----------------
----------------- BERT LAYER 3 -----------------
----------------- BERT LAYER 4 -----------------
----------------- BERT LAYER 5 -----------------
----------------- BERT LAYER 6 -----------------
----------------- BERT LAYER 7 -----------------
----------------- BERT LAYER 8 -----------------
----------------- BERT LAYER 9 -----------------
----------------- BERT LAYER 10 -----------------
----------------- BERT LAYER 11 -----------------
----------------- BERT LAYER 12 -----------------
Hidden States:  torch.Size([2, 9, 768])
Q.size torch.Size([2, 9, 768])
K.size torch.Size([2, 9, 768])
V.size torch.Size([2, 9, 768])
Q.size torch.Size([2, 12, 9, 64])
K.size torch.Size([2, 12, 9, 64])
V.size torch.Size([2, 12, 9, 64])
dk 64
transpose k torch.Size([2, 12, 64, 9])
attention score:  torch.Size([2, 12, 9, 9])
softmax attention score:  torch.Size([2, 12, 9, 9])
score*v torch.Size([2, 12, 9, 64]
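`MultiHeadAttention` adds `attention_mask` directly to the scores, so it expects an additive, broadcastable mask rather than the tokenizer's 0/1 mask. How `get_extended_attention_mask` is defined earlier in the notebook is not shown here; the sketch below assumes the `(1.0 - mask) * -10000.0` convention that transformers releases of this period (such as 4.12) used for BERT, reshaped to `[batch, 1, 1, seq_len]` so it broadcasts over heads and query positions.

```python
import torch

# Padding mask as produced by a tokenizer: 1 = real token, 0 = padding.
attention_mask = torch.tensor([[1, 1, 1, 0, 0]])

# Additive mask [B, 1, 1, L]: 0 where attended, -10000 where padded
# (assumed convention, see lead-in).
extended = (1.0 - attention_mask[:, None, None, :].float()) * -10000.0

scores = torch.zeros(1, 1, 5, 5)  # hypothetical raw scores [B, heads, L, L]
weights = torch.softmax(scores + extended, dim=-1)

# Each query spreads its weight over the 3 real tokens; padded keys get ~0.
print(weights[0, 0, 0])
```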

In [20]:
print(torch.eq(encoder_embedding[0], layer_output_custom))

tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]])
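`torch.eq` prints an element-wise boolean tensor, so the `...` in its output hides most elements. A single boolean is easier to read: `torch.equal` checks shape and exact equality, and `torch.allclose` tolerates tiny floating-point differences (useful if, for example, the two encoders ever ran on different devices or kernels). A small helper, with a hypothetical name and default tolerance:

```python
import torch

def outputs_match(reference: torch.Tensor, candidate: torch.Tensor,
                  atol: float = 1e-6) -> tuple:
    """Return (exact, close) for two tensors."""
    exact = torch.equal(reference, candidate)                # same shape, every element identical
    close = torch.allclose(reference, candidate, atol=atol)  # equal within tolerance
    return exact, close

a = torch.randn(2, 9, 768)
print(outputs_match(a, a.clone()))  # (True, True)
print(outputs_match(a, a + 1e-7))   # close is True even where exact equality fails
```

In this notebook it could be called as `outputs_match(encoder_embedding[0], layer_output_custom)`, where index 0 of the Hugging Face output is `last_hidden_state`.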
