## 

In [1]:
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Show the PyTorch version this notebook was executed with.
torch.__version__

'1.4.0'

In [3]:
class Encoder(nn.Module):
    """Bidirectional LSTM encoder.

    Token ids are sequence-first ([source length, batch size]) because both the LSTM
    and ``pack_padded_sequence`` are used with their default ``batch_first=False``.
    The returned encoder output is transposed to batch-first so that the attention
    module can consume it directly.
    """

    def __init__(self, input_dim, 
                       embed_dim, 
                       enc_hidden_dim, 
                       dec_hidden_dim, 
                       dropout):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, embed_dim)
        self.lstm = nn.LSTM(embed_dim,
                            enc_hidden_dim,
                            bidirectional=True)
        # Projects [final forward state ; final backward state] to the decoder's hidden size.
        self.fc = nn.Linear(enc_hidden_dim * 2, dec_hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, source, source_len):
        '''
        Encode a padded batch of source sequences.

        source     = [source length, batch size]   token ids
        source_len = [batch size]                  true sequence lengths; with the
                     default enforce_sorted=True of pack_padded_sequence they must be
                     sorted in decreasing order

        Returns
        output     = [batch size, source length, enc hidden dim * 2]  (zeros at padding)
        hidden     = [batch size, dec hidden dim]
        '''
        embedded = self.dropout(self.embedding(source))
        # embedded = [source length, batch size, embed dim]

        packed_embedded = pack_padded_sequence(embedded, source_len)
        packed_output, (hidden, _) = self.lstm(packed_embedded)
        # total_length keeps the time axis equal to source.size(0), so attention scores
        # line up with a padding mask built from `source`.
        output, _ = pad_packed_sequence(packed_output, total_length=source.size(0))
        # output = [source length, batch size, enc hidden dim * 2]
        # hidden = [num layers * 2, batch size, enc hidden dim], ordered
        #          [forward_1, backward_1, ..., forward_n, backward_n]

        output = output.permute(1, 0, 2)
        # output = [batch size, source length, enc hidden dim * 2]

        # The layer/direction axis of h_n is dim 0: take the top layer's final
        # forward (-2) and backward (-1) states.
        hidden = torch.tanh(self.fc(torch.cat((hidden[-2], hidden[-1]), dim=1)))
        # hidden = [batch size, dec hidden dim]
        return output, hidden


class Attention(nn.Module):
    """Additive attention of a decoder hidden state over batch-first encoder outputs.

    The inner Linear's ``weight`` and ``bias`` are also registered directly on this
    module (as ``weight`` / ``bias``), next to the scoring vector ``v``, so that
    all three can be targeted by name with ``torch.nn.utils.prune``.
    """

    def __init__(self, enc_hidden_dim, 
                       dec_hidden_dim):
        super(Attention, self).__init__()
        self.attention = nn.Linear(dec_hidden_dim + (enc_hidden_dim * 2), dec_hidden_dim)
        self.weight = self.attention.weight
        self.bias = self.attention.bias
        self.v = nn.Parameter(torch.rand(dec_hidden_dim))

    def forward(self, hidden, encoder_output, mask):
        '''
        Given the decoder's previous hidden state and the encoder outputs, compute how
        much the decoder should attend to each source position.

        hidden         = [batch size, dec hidden dim]
        encoder_output = [batch size, source length, enc hidden dim * 2]
        mask           = [batch size, source length]   1/True marks positions to ignore

        Returns
        attention      = [batch size, source length]   softmax-normalised over dim 1
        '''
        # encoder_output is 3-D; only the first two sizes are needed here.
        batch_size, src_len = encoder_output.shape[:2]

        # Repeat hidden src_len times so it lines up with every encoder position.
        hidden = hidden.unsqueeze(1).repeat(1, src_len, 1)
        # hidden = [batch size, source length, dec hidden dim]

        # Score how well the decoder state matches each encoder output.
        # self.weight / self.bias are used instead of self.attention(...): pruning those
        # names on this module replaces only these attributes (the inner Linear keeps the
        # unmasked tensors), so calling the Linear would silently ignore the pruning.
        energy = torch.tanh(F.linear(torch.cat((hidden, encoder_output), dim=2),
                                     self.weight, self.bias))
        # energy = [batch size, source length, dec hidden dim]

        energy = energy.permute(0, 2, 1)
        # energy = [batch size, dec hidden dim, source length]

        # Reduce the dec hidden dim axis with the learnable vector v
        # ("parameterised attention"); self.v is read here so pruning of `v` applies.
        v = self.v.repeat(batch_size, 1).unsqueeze(1)
        # v = [batch size, 1, dec hidden dim]

        attention = torch.bmm(v, energy).squeeze(1)
        # attention = [batch size, source length]

        # Effectively exclude masked (padding) positions from the softmax.
        attention = attention.masked_fill(mask == 1, -1e10)

        return F.softmax(attention, dim=1)


class Decoder(nn.Module):
    """Single-step LSTM decoder with attention over the encoder outputs.

    ``input_dim`` is accepted for signature compatibility but is not used; the
    target vocabulary size is ``output_dim``.
    """

    def __init__(self, input_dim, 
                       embed_dim, 
                       enc_hidden_dim, 
                       dec_hidden_dim,
                       output_dim,
                       dropout,
                       attention
                ):
        super(Decoder, self).__init__()
        self.attention = attention
        self.embedding = nn.Embedding(output_dim, embed_dim)
        # Per-step LSTM input = [embedded token ; attention-weighted context].
        self.lstm = nn.LSTM(embed_dim + (enc_hidden_dim * 2), dec_hidden_dim)
        # Prediction head over [LSTM output ; context ; embedded token].
        self.fc = nn.Linear(dec_hidden_dim + embed_dim + (enc_hidden_dim * 2), output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, target, hidden, encoder_output, mask):
        '''
        Run one decoding step.

        target         = [batch size]                                  current input token ids
        hidden         = [batch size, dec hidden dim]                  previous decoder hidden state
        encoder_output = [batch size, source length, enc hidden dim * 2]
        mask           = [batch size, source length]                   forwarded to the attention module

        Returns
        prediction     = [batch size, output dim]                      unnormalised scores (fc output)
        hidden         = [batch size, dec hidden dim]                  new hidden state
        attention      = [batch size, source length]                   this step's attention weights
        '''
        # The LSTM is not batch_first, so a single step is laid out [1, batch size, ...].
        target = target.unsqueeze(0)
        # target = [1, batch size]

        embedded = self.dropout(self.embedding(target))
        # embedded = [1, batch size, embed dim]

        # Attention weights from the previous hidden state and the encoder outputs.
        attention = self.attention(hidden, encoder_output, mask).unsqueeze(1)
        # attention = [batch size, 1, source length]

        # Attention-weighted context vector, moved to the time-first layout.
        weighted = torch.bmm(attention, encoder_output).permute(1, 0, 2)
        # weighted = [1, batch size, enc hidden dim * 2]

        dec_input = torch.cat((embedded, weighted), dim=2)
        # dec_input = [1, batch size, embed dim + (enc hidden dim * 2)]

        # nn.LSTM takes an (h_0, c_0) tuple. Only the hidden state is threaded between
        # steps by this interface, so the cell state starts from zeros on every step.
        # NOTE(review): carrying c across steps would require changing the return value.
        h_0 = hidden.unsqueeze(0)
        output, (hidden, _) = self.lstm(dec_input, (h_0, torch.zeros_like(h_0)))
        # output = hidden = [1, batch size, dec hidden dim]

        assert (output == hidden).all()

        output = output.squeeze(0)      # [batch size, dec hidden dim]    : current decoder state
        embedded = embedded.squeeze(0)  # [batch size, embed dim]         : current input token
        weighted = weighted.squeeze(0)  # [batch size, enc hidden dim * 2]: attention context

        prediction = self.fc(torch.cat((output, weighted, embedded), dim=1))
        # prediction = [batch size, output dim]

        return prediction, hidden.squeeze(0), attention.squeeze(1)


class Seq2SeqAttention(nn.Module):
    """Encoder / attention / decoder model with teacher forcing.

    Token tensors (``source``, ``target``) are sequence-first: [length, batch size].
    A single Attention instance is registered both as ``self.attention`` and inside
    ``self.decoder``.
    """

    def __init__(self, input_dim, 
                       embed_dim, 
                       enc_hidden_dim, 
                       dec_hidden_dim, 
                       dropout,
                       output_dim, 
                       pad_idx):
        super(Seq2SeqAttention, self).__init__()
        self.encoder = Encoder(
            input_dim,
            embed_dim, 
            enc_hidden_dim, 
            dec_hidden_dim, 
            dropout
        )
        self.attention = Attention(
            enc_hidden_dim, 
            dec_hidden_dim
        )
        self.decoder = Decoder(
            input_dim, 
            embed_dim, 
            enc_hidden_dim, 
            dec_hidden_dim,
            output_dim,
            dropout,
            self.attention
        )
        self.pad_idx = pad_idx
        self.output_dim = output_dim

    def create_mask(self, source):
        """Return a boolean tensor shaped like ``source`` that is True at padding tokens."""
        mask = (source == self.pad_idx)
        return mask

    def forward(self, source, source_length, target, teacher_forcing=0.5):
        '''
        source          = [source length, batch size]   token ids
        source_length   = [batch size]                  forwarded to the encoder
        target          = [target length, batch size]; target[0] is the first decoder
                          input (the <SOS> token)
        teacher_forcing = probability of feeding the ground-truth token at each step

        Returns
        outputs    = [batch size, target length, output dim]   (step 0 is left as zeros)
        attentions = [batch size, target length, source length]
        '''
        target_max_len, batch_size = target.size()
        source_max_len = source.size(0)

        # nn.Module has no `device` attribute; allocate next to the inputs instead.
        device = target.device
        outputs = torch.zeros(batch_size, target_max_len, self.output_dim, device=device)
        attentions = torch.zeros(batch_size, target_max_len, source_max_len, device=device)

        encoder_output, hidden = self.encoder(source, source_length)
        # encoder_output = [batch size, source length, enc hidden dim * 2]
        # hidden         = [batch size, dec hidden dim]
        # NOTE(review): assumes the encoder's time axis equals source.size(0) so each
        # attention row fits into `attentions`.

        dec_input = target[0, :]
        # dec_input = [batch size]

        # create_mask follows the [source length, batch size] layout of `source`;
        # the attention module expects [batch size, source length].
        mask = self.create_mask(source).permute(1, 0)

        for t in range(1, target_max_len):
            output, hidden, attention = self.decoder(dec_input, hidden, encoder_output, mask)
            # output    = [batch size, output dim]
            # hidden    = [batch size, dec hidden dim]
            # attention = [batch size, source length]

            # Time is dim 1 of the batch-first result tensors.
            outputs[:, t] = output
            attentions[:, t] = attention

            # With teacher forcing feed the ground-truth token, otherwise the
            # decoder's own highest-scoring prediction.
            teacher_force = random.random() < teacher_forcing
            top1 = output.argmax(1)
            dec_input = target[t] if teacher_force else top1

        return outputs, attentions

# Hyperparameters for the demo model; token id 0 is used as padding.
model_config = dict(
    input_dim=2000,
    embed_dim=200,
    enc_hidden_dim=200,
    dec_hidden_dim=200,
    dropout=0.2,
    output_dim=3000,
    pad_idx=0,
)
model = Seq2SeqAttention(**model_config)
print(model)

Seq2SeqAttention(
  (encoder): Encoder(
    (embedding): Embedding(2000, 200)
    (lstm): LSTM(200, 200, bidirectional=True)
    (fc): Linear(in_features=400, out_features=200, bias=True)
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (attention): Attention(
    (attention): Linear(in_features=600, out_features=200, bias=True)
  )
  (decoder): Decoder(
    (attention): Attention(
      (attention): Linear(in_features=600, out_features=200, bias=True)
    )
    (embedding): Embedding(3000, 200)
    (lstm): LSTM(600, 200)
    (fc): Linear(in_features=800, out_features=3000, bias=True)
    (dropout): Dropout(p=0.2, inplace=False)
  )
)


### 모델 조사
이제 Pruning이 적용되지 않은 `attention` 레이어를 살펴봅시다. 해당 레이어는 현재 `weight`와 `bias` 그리고 `v`를 지니고 있습니다.

In [4]:
# Inspect the attention module before any pruning; `weight` and `bias` are listed
# as its own parameters.
module = model.attention
list(module.named_parameters())

[('weight', Parameter containing:
  tensor([[ 0.0042,  0.0097,  0.0081,  ..., -0.0141,  0.0324,  0.0173],
          [-0.0177, -0.0256, -0.0201,  ..., -0.0254, -0.0268, -0.0381],
          [-0.0349,  0.0042, -0.0222,  ...,  0.0366, -0.0176, -0.0316],
          ...,
          [-0.0004, -0.0259, -0.0041,  ...,  0.0257,  0.0156, -0.0021],
          [ 0.0181,  0.0133, -0.0258,  ...,  0.0261,  0.0254, -0.0005],
          [-0.0307, -0.0260,  0.0096,  ...,  0.0003, -0.0045,  0.0006]],
         requires_grad=True)), ('bias', Parameter containing:
  tensor([-0.0395,  0.0304, -0.0336, -0.0144,  0.0356, -0.0219,  0.0299,  0.0010,
           0.0209, -0.0233,  0.0165,  0.0103,  0.0139, -0.0338,  0.0012,  0.0333,
          -0.0162, -0.0030,  0.0391,  0.0257,  0.0054, -0.0264, -0.0029,  0.0186,
           0.0234,  0.0347, -0.0405,  0.0044,  0.0310,  0.0293,  0.0307,  0.0133,
          -0.0025,  0.0290, -0.0300,  0.0012,  0.0350, -0.0126,  0.0102, -0.0160,
          -0.0335,  0.0241, -0.0059,  0.0091, 

In [5]:
# Before pruning the module holds no buffers.
list(module.named_buffers())

[]

### 모델 Pruning


In [6]:
# Randomly prune 30% of the entries of `v`: the original values are kept as the
# parameter `v_orig`, a `v_mask` buffer is registered, and `v` becomes the masked tensor.
prune.random_unstructured(module, 'v', amount=0.3)

Attention(
  (attention): Linear(in_features=600, out_features=200, bias=True)
)

In [7]:
# List the parameters again after pruning `v`.
list(module.named_parameters())

[('weight', Parameter containing:
  tensor([[ 0.0042,  0.0097,  0.0081,  ..., -0.0141,  0.0324,  0.0173],
          [-0.0177, -0.0256, -0.0201,  ..., -0.0254, -0.0268, -0.0381],
          [-0.0349,  0.0042, -0.0222,  ...,  0.0366, -0.0176, -0.0316],
          ...,
          [-0.0004, -0.0259, -0.0041,  ...,  0.0257,  0.0156, -0.0021],
          [ 0.0181,  0.0133, -0.0258,  ...,  0.0261,  0.0254, -0.0005],
          [-0.0307, -0.0260,  0.0096,  ...,  0.0003, -0.0045,  0.0006]],
         requires_grad=True)), ('bias', Parameter containing:
  tensor([-0.0395,  0.0304, -0.0336, -0.0144,  0.0356, -0.0219,  0.0299,  0.0010,
           0.0209, -0.0233,  0.0165,  0.0103,  0.0139, -0.0338,  0.0012,  0.0333,
          -0.0162, -0.0030,  0.0391,  0.0257,  0.0054, -0.0264, -0.0029,  0.0186,
           0.0234,  0.0347, -0.0405,  0.0044,  0.0310,  0.0293,  0.0307,  0.0133,
          -0.0025,  0.0290, -0.0300,  0.0012,  0.0350, -0.0126,  0.0102, -0.0160,
          -0.0335,  0.0241, -0.0059,  0.0091, 

In [8]:
# `v_mask` is now a buffer: 0 marks a pruned entry of `v`, 1 a kept entry.
list(module.named_buffers())

[('v_mask',
  tensor([0., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1., 1., 1., 1., 0., 1., 1.,
          1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0.,
          1., 0., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0., 1., 0., 0., 1., 1.,
          0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 0., 0., 0., 1., 1.,
          1., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 1.,
          0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1., 0.,
          1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 1., 0., 0., 1.,
          1., 0., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1., 0., 1., 1.,
          1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1.,
          0., 0., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
          0., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1.,
          1., 1.]))]

In [9]:
# Pruning keeps the original shape of `v`.
module.v.size()

torch.Size([200])

In [11]:
# The pruned `v`: entries whose mask is 0 now read 0.
module.v

tensor([0.0000, 0.0000, 0.0000, 0.1632, 0.7836, 0.2296, 0.0000, 0.9788, 0.8731,
        0.0000, 0.0000, 0.1541, 0.9458, 0.0991, 0.0810, 0.0000, 0.1863, 0.8496,
        0.8108, 0.5251, 0.0000, 0.0000, 0.5540, 0.5456, 0.2588, 0.0536, 0.3693,
        0.8613, 0.0000, 0.5353, 0.5890, 0.3590, 0.9152, 0.4470, 0.7180, 0.0000,
        0.5198, 0.0000, 0.6132, 0.9551, 0.0000, 0.1698, 0.0000, 0.2517, 0.4096,
        0.2482, 0.0135, 0.3361, 0.0000, 0.0275, 0.0000, 0.0000, 0.6862, 0.6623,
        0.0000, 0.4214, 0.1704, 0.6022, 0.5915, 0.0000, 0.2501, 0.2400, 0.3429,
        0.5581, 0.2312, 0.0000, 0.6589, 0.0000, 0.0000, 0.0000, 0.3515, 0.6454,
        0.8333, 0.5264, 0.0000, 0.0000, 0.6894, 0.0000, 0.3962, 0.0000, 0.0000,
        0.0000, 0.0000, 0.8744, 0.2893, 0.1269, 0.0499, 0.7993, 0.0000, 0.6862,
        0.0000, 0.7840, 0.0000, 0.8050, 0.0980, 0.0000, 0.0000, 0.0000, 0.5931,
        0.0000, 0.7929, 0.4560, 0.8388, 0.7698, 0.7667, 0.6927, 0.0050, 0.0000,
        0.7969, 0.6506, 0.7902, 0.2139, 

In [10]:
# Count the zeroed entries: 30% of the 200 entries of `v`.
(module.v == 0).sum()

tensor(60)

In [12]:
# The pruning method is attached to the module as a forward pre-hook.
module._forward_pre_hooks

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured at 0x1fa2674b710>)])

In [13]:
# L1-magnitude pruning of `bias`: zero the entries with the smallest absolute values.
# Unlike the fraction 0.3 used for `v`, the integer amount=3 presumably means an
# absolute count of entries (see the torch.nn.utils.prune docs).
prune.l1_unstructured(module, name='bias', amount=3)

Attention(
  (attention): Linear(in_features=600, out_features=200, bias=True)
)

In [14]:
# List the parameters after pruning both `v` and `bias`.
list(module.named_parameters())

[('weight', Parameter containing:
  tensor([[ 0.0042,  0.0097,  0.0081,  ..., -0.0141,  0.0324,  0.0173],
          [-0.0177, -0.0256, -0.0201,  ..., -0.0254, -0.0268, -0.0381],
          [-0.0349,  0.0042, -0.0222,  ...,  0.0366, -0.0176, -0.0316],
          ...,
          [-0.0004, -0.0259, -0.0041,  ...,  0.0257,  0.0156, -0.0021],
          [ 0.0181,  0.0133, -0.0258,  ...,  0.0261,  0.0254, -0.0005],
          [-0.0307, -0.0260,  0.0096,  ...,  0.0003, -0.0045,  0.0006]],
         requires_grad=True)), ('v_orig', Parameter containing:
  tensor([0.0826, 0.1509, 0.5415, 0.1632, 0.7836, 0.2296, 0.8762, 0.9788, 0.8731,
          0.0524, 0.8753, 0.1541, 0.9458, 0.0991, 0.0810, 0.5743, 0.1863, 0.8496,
          0.8108, 0.5251, 0.3228, 0.9408, 0.5540, 0.5456, 0.2588, 0.0536, 0.3693,
          0.8613, 0.8083, 0.5353, 0.5890, 0.3590, 0.9152, 0.4470, 0.7180, 0.9864,
          0.5198, 0.1658, 0.6132, 0.9551, 0.5013, 0.1698, 0.3747, 0.2517, 0.4096,
          0.2482, 0.0135, 0.3361, 0.0129, 0.

In [30]:
# The pruned bias as seen through the module: pruned entries read (-)0.
module.bias

tensor([-0.0395,  0.0304, -0.0336, -0.0144,  0.0356, -0.0219,  0.0299,  0.0010,
         0.0209, -0.0233,  0.0165,  0.0103,  0.0139, -0.0338,  0.0012,  0.0333,
        -0.0162, -0.0030,  0.0391,  0.0257,  0.0054, -0.0264, -0.0029,  0.0186,
         0.0234,  0.0347, -0.0405,  0.0044,  0.0310,  0.0293,  0.0307,  0.0133,
        -0.0025,  0.0290, -0.0300,  0.0012,  0.0350, -0.0126,  0.0102, -0.0160,
        -0.0335,  0.0241, -0.0059,  0.0091,  0.0217,  0.0389, -0.0000,  0.0115,
        -0.0384,  0.0156, -0.0329, -0.0095,  0.0238, -0.0066,  0.0300, -0.0069,
         0.0075,  0.0086,  0.0130,  0.0365,  0.0107,  0.0020, -0.0300, -0.0015,
         0.0209,  0.0279,  0.0372, -0.0039, -0.0000, -0.0171,  0.0068, -0.0359,
         0.0376,  0.0319,  0.0139, -0.0092, -0.0270, -0.0033,  0.0162,  0.0120,
        -0.0330, -0.0362, -0.0104, -0.0193,  0.0196,  0.0170, -0.0056, -0.0020,
        -0.0400, -0.0007, -0.0402, -0.0094, -0.0315,  0.0070,  0.0347, -0.0346,
        -0.0260,  0.0197, -0.0056,  0.02

In [15]:
# Both masks are now registered as buffers: `v_mask` and `bias_mask`.
list(module.named_buffers())

[('v_mask',
  tensor([0., 0., 0., 1., 1., 1., 0., 1., 1., 0., 0., 1., 1., 1., 1., 0., 1., 1.,
          1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0.,
          1., 0., 1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0., 1., 0., 0., 1., 1.,
          0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 0., 0., 0., 1., 1.,
          1., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 1.,
          0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1., 0.,
          1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 0., 0., 1., 1., 1., 0., 0., 1.,
          1., 0., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1., 0., 1., 1.,
          1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1.,
          0., 0., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
          0., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1.,
          1., 1.])),
 ('bias_mask',
  tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 

In [22]:
# Indices of the bias entries removed by L1 pruning (where `bias_mask` is 0).
# The buffer is fetched by name rather than by its position in named_buffers().
pruned = (module.bias_mask == 0).nonzero()

In [23]:
# Index of the first pruned bias entry.
pruned[0]

tensor([46])

In [28]:
# Reading `bias` through the module returns the masked value for a pruned index.
module.bias[46]

tensor(-0., grad_fn=<SelectBackward>)

In [29]:
# The un-pruned value is preserved in `bias_orig`; look it up by name instead of
# relying on it being the last entry of named_parameters().
module.bias_orig[46]

tensor(-0.0003, grad_fn=<SelectBackward>)

In [31]:
# Each pruning call added its own forward pre-hook: RandomUnstructured for `v`,
# L1Unstructured for `bias`.
module._forward_pre_hooks

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured at 0x1fa2674b710>),
             (1, <torch.nn.utils.prune.L1Unstructured at 0x1fa2679c320>)])

### 반복적인 Pruning