In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

from transformers import BertTokenizer, BertModel, BertConfig
from torch.nn.utils.rnn import pack_padded_sequence
from torchvision.models import resnet101

In [3]:
# 특성 추출기 layer (feature-extractor layers)

# vgg_config = [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']

def get_vgg_layer(config, batch_norm, in_channels=3):
    """Build a VGG-style convolutional feature extractor.

    Args:
        config: sequence of layer specs. An int ``c`` adds a 3x3 convolution
            (padding 1) with ``c`` output channels followed by ReLU, with a
            BatchNorm2d between them when ``batch_norm`` is true. The string
            ``'M'`` adds a 2x2 max-pooling layer.
        batch_norm: whether to insert BatchNorm2d after each convolution.
        in_channels: channel count of the input tensor (default 3).

    Returns:
        nn.Sequential containing the layers in ``config`` order.

    Raises:
        AssertionError: if an entry of ``config`` is neither ``'M'`` nor an int.
    """
    layers = []

    for c in config:
        assert c == 'M' or isinstance(c, int), f"invalid layer spec: {c!r}"

        if c == 'M':  # 'M' -> max pooling
            layers += [nn.MaxPool2d(kernel_size=2)]
        else:  # int -> convolution
            conv2d = nn.Conv2d(in_channels, c, kernel_size=3, padding=1)

            # optionally apply batch normalization between conv and ReLU
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(c), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]

            # this conv's output channels feed the next conv
            in_channels = c

    return nn.Sequential(*layers)

In [2]:
# Encoder
class Encoder(nn.Module):
    """ResNet-101 image encoder that also builds the decoder's initial LSTM state.

    Args:
        config: object exposing ``enc_hidden_size`` (side length of the pooled
            feature map), ``dec_hidden_size`` and ``dec_num_layers``.
    """
    def __init__(self, config):
        super(Encoder, self).__init__()
        # hidden size, number of decoder layers, number of pooled pixels
        self.enc_hidden_size = config.enc_hidden_size
        self.dec_hidden_size = config.dec_hidden_size
        self.dec_num_layers = config.dec_num_layers
        self.pixel_size = self.enc_hidden_size * self.enc_hidden_size

        # NOTE(review): `pretrained=` is deprecated in newer torchvision in favour
        # of `weights=`; kept for compatibility with the installed version.
        base_model = resnet101(pretrained=True, progress=False)
        # Drop the last two children (the classifier head) and keep the conv trunk.
        base_model = list(base_model.children())[:-2]
        self.resnet = nn.Sequential(*base_model)
        # Resize the feature map to enc_hidden_size x enc_hidden_size.
        self.pooling = nn.AdaptiveAvgPool2d((self.enc_hidden_size, self.enc_hidden_size))

        self.relu = nn.ReLU()
        # Collapse the pixel axis into one slot per decoder LSTM layer so that
        # h0/c0 come out as (dec_num_layers, batch, dec_hidden_size), which is
        # the hidden-state shape nn.LSTM requires (also with batch_first=True).
        self.hidden_dim_changer = nn.Sequential(
            nn.Linear(self.pixel_size, self.dec_num_layers),
            nn.ReLU()
        )
        # Project the 2048-channel features to the decoder hidden size.
        self.h_mlp = nn.Linear(2048, self.dec_hidden_size)
        self.c_mlp = nn.Linear(2048, self.dec_hidden_size)

        self.fine_tune(True)

    def fine_tune(self, fine_tune=True):
        """Freeze the whole trunk, then set ``requires_grad = fine_tune`` on every
        trunk child from index 5 onward (for ResNet-101: layer2, layer3, layer4)."""
        for p in self.resnet.parameters():
            p.requires_grad = False

        for c in list(self.resnet.children())[5:]:
            for p in c.parameters():
                p.requires_grad = fine_tune

    def forward(self, x):
        """Encode images and derive the decoder's initial LSTM state.

        Args:
            x: image batch; presumably (batch, 3, H, W) — TODO confirm.

        Returns:
            features: (batch, 2048, enc_hidden_size ** 2) pooled features.
            (h0, c0): each of shape (dec_num_layers, batch, dec_hidden_size).
        """
        batch_size = x.size(0)

        x = self.resnet(x)
        x = self.pooling(x)
        x = x.view(batch_size, 2048, -1)

        if self.dec_num_layers != 1:
            # (batch, 2048, pixels) -> (batch, 2048, dec_num_layers)
            layer_feats = self.hidden_dim_changer(self.relu(x))
        else:
            # single layer: average over pixels -> (batch, 2048, 1)
            layer_feats = torch.mean(x, dim=2, keepdim=True)
        # -> (dec_num_layers, batch, 2048)
        layer_feats = torch.permute(layer_feats, (2, 0, 1))
        h0 = self.h_mlp(layer_feats)
        c0 = self.c_mlp(layer_feats)

        return x, (h0, c0)

In [4]:
# Attention
class Attention(nn.Module):
    """Attention module stub.

    Only the constructor exists so far; no ``forward`` is defined, so calling an
    instance raises ``NotImplementedError`` (nn.Module's default behaviour).
    TODO: implement the attention computation.

    Args:
        hidden_size: decoder hidden size. It is optional so that the original
            no-argument construction still works; callers such as
            ``Attention(dec_hidden_size)`` no longer raise ``TypeError``.
    """
    def __init__(self, hidden_size=None):
        super(Attention, self).__init__()
        self.hidden_size = hidden_size

In [None]:
# Decoder
class Decoder(nn.Module):
    """LSTM decoder over a tokenizer vocabulary (constructor and weight init only).

    Args:
        config: object exposing ``enc_hidden_size``, ``dec_hidden_size``,
            ``dec_num_layers``, ``dropout`` and ``is_attn``.
        tokenizer: object exposing ``pad_token_id`` and ``vocab_size``
            (presumably a Hugging Face tokenizer such as BertTokenizer).
    """
    def __init__(self, config, tokenizer):
        super(Decoder, self).__init__()
        self.pixel_size = config.enc_hidden_size * config.enc_hidden_size
        self.dec_hidden_size = config.dec_hidden_size
        self.dec_num_layers = config.dec_num_layers

        self.dropout = config.dropout
        self.is_attn = config.is_attn  # whether an attention layer is used

        self.pad_token_id = tokenizer.pad_token_id
        self.vocab_size = tokenizer.vocab_size

        # attention layer
        if self.is_attn:
            self.attention = Attention(self.dec_hidden_size)
        # LSTM input size depends on whether attention is used: with attention,
        # each step gets 2048 extra input features (presumably the attended
        # encoder context concatenated to the embedding; forward not shown).
        self.input_size = self.dec_hidden_size + 2048 if self.is_attn else self.dec_hidden_size

        self.embedding = nn.Embedding(self.vocab_size, self.dec_hidden_size, padding_idx=self.pad_token_id)
        self.lstm = nn.LSTM(input_size=self.input_size,
                            hidden_size=self.dec_hidden_size,
                            num_layers=self.dec_num_layers,
                            batch_first=True)

        self.dropout_layer = nn.Dropout(self.dropout)
        self.relu = nn.ReLU()
        # hidden state -> 2048 sigmoid gate (presumably scales the context vector)
        self.beta = nn.Sequential(
            nn.ReLU(),
            nn.Linear(self.dec_hidden_size, 2048),
            nn.Sigmoid()
        )

        # hidden state -> vocabulary logits
        self.fc = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.dec_hidden_size, self.vocab_size)
        )

        self.embedding.apply(self.init_weights)
        self.fc.apply(self.init_weights)

    def init_weights(self, m):
        """Initialise Linear/Embedding weights uniformly in [-0.1, 0.1].

        Linear biases are zeroed. For embeddings, the ``padding_idx`` row is reset
        to zero after the uniform fill: its gradient is always zero, so a random
        value written there would otherwise persist for the whole of training.
        """
        if isinstance(m, nn.Linear):
            if m.bias is not None:
                m.bias.data.fill_(0)
            m.weight.data.uniform_(-0.1, 0.1)
        if isinstance(m, nn.Embedding):
            m.weight.data.uniform_(-0.1, 0.1)
            if m.padding_idx is not None:
                m.weight.data[m.padding_idx].zero_()