In [1]:
import torch
import torch.nn as nn

In [2]:
def conv(input_channels, output_channels, kernel_size, strides, padding):
    """Build a Conv2d -> ReLU -> BatchNorm2d stage as one ``nn.Sequential``.

    Args:
        input_channels: Number of channels of the incoming feature map.
        output_channels: Number of channels produced by the convolution; also
            the number of features normalised by the batch-norm layer.
        kernel_size: Kernel size forwarded to ``nn.Conv2d``.
        strides: Stride forwarded to ``nn.Conv2d``.
        padding: Padding forwarded to ``nn.Conv2d``.

    Returns:
        An ``nn.Sequential`` whose children are, in order, the convolution,
        an in-place ReLU and a 2-D batch normalisation.
    """
    stages = [
        nn.Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=kernel_size,
            stride=strides,
            padding=padding,
        ),
        # The activation is applied before the normalisation in this stage.
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(num_features=output_channels),
    ]
    return nn.Sequential(*stages)

In [5]:
class FeatureExtractor(nn.Module):
    """Convolutional image encoder made of four stacked ``conv`` stages.

    Every stage uses a 4x4 kernel with stride 2 and padding 1, so each stage
    halves the spatial resolution of an even-sized input. The first stage maps
    3 input channels to 128; the remaining three keep 128 channels.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels) for each stage. The trailing notes give
        # the spatial size after the stage, assuming a 224x224 input image
        # (presumably the intended size -- confirm against the data pipeline).
        stage_channels = (
            (3, 128),    # -> 112x112
            (128, 128),  # -> 56x56
            (128, 128),  # -> 28x28
            (128, 128),  # -> 14x14
        )

        self.convs = nn.Sequential(*(
            conv(input_channels=channels_in,
                 output_channels=channels_out,
                 kernel_size=4,
                 strides=2,
                 padding=1)
            for channels_in, channels_out in stage_channels
        ))

    def forward(self, x):
        """Run ``x`` through the four convolution stages and return the result."""
        features = self.convs(x)
        return features


In [10]:
class FilmBlock(nn.Module):
    """Feature-wise linear modulation: per-channel scale and shift of ``x``.

    The module holds no parameters; ``gamma`` and ``beta`` are supplied by the
    caller on every forward pass.
    """

    def forward(self, x, beta, gamma):
        """Return ``x * gamma + beta`` with ``gamma``/``beta`` broadcast per channel.

        Args:
            x: Feature map whose first two dimensions are read as
                (batch, channels); presumably 4-D (batch, channels, H, W).
            beta: Shift values; must contain ``batch * channels`` elements.
            gamma: Scale values; must contain ``batch * channels`` elements.

        Returns:
            The modulated feature map.
        """
        batch, channels = x.size(0), x.size(1)

        # Reshape to (batch, channels, 1, 1) so both terms broadcast over the
        # trailing spatial dimensions.
        shift = beta.view(batch, channels, 1, 1)
        scale = gamma.view(batch, channels, 1, 1)

        modulated = x * scale + shift  # element-wise (Hadamard) product, then shift
        return modulated

In [9]:
class ResBlock(nn.Module):
    """Residual block whose second convolution is modulated by FiLM.

    Layout: ``conv1 (1x1) -> ReLU`` produces the residual branch input, which
    is also kept as the skip connection; ``conv2 (3x3) -> BatchNorm -> FiLM ->
    ReLU`` produces the residual, and the two are summed.

    Args:
        input_channels: Number of channels of the incoming feature map.
        output_channels: Number of channels produced by the block.
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()

        # The constructor arguments are used directly; they are never stored
        # on ``self``, so reading ``self.input_channels`` here would raise.
        self.conv1 = nn.Conv2d(in_channels=input_channels,
                               out_channels=output_channels,
                               kernel_size=1,
                               stride=1,
                               padding=0)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=output_channels,
                               out_channels=output_channels,
                               kernel_size=3,
                               stride=1,
                               padding=1)
        # NOTE(review): this BatchNorm keeps its default learnable affine
        # parameters in addition to the FiLM scale/shift -- confirm intended.
        self.batchnorm = nn.BatchNorm2d(output_channels)

        # FilmBlock is a module-level class, not an attribute of this module.
        self.film = FilmBlock()
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x, beta, gamma):
        """Apply the block.

        Args:
            x: Input feature map with ``input_channels`` channels.
            beta: FiLM shift values for this block, forwarded to ``FilmBlock``.
            gamma: FiLM scale values for this block, forwarded to ``FilmBlock``.

        Returns:
            Feature map with ``output_channels`` channels: the FiLM-modulated
            residual plus the post-``conv1`` activation.
        """
        x = self.conv1(x)
        x = self.relu1(x)

        # The skip connection starts after the 1x1 projection so that both
        # summands have ``output_channels`` channels.
        identity = x

        y = self.conv2(x)
        y = self.batchnorm(y)
        y = self.film(y, beta, gamma)
        y = self.relu2(y)

        return y + identity


In [14]:
class Classifier(nn.Module):
    """Classification head: 1x1 conv -> global max pool -> 3-layer MLP.

    Args:
        input_channels: Number of channels of the incoming feature map.
        class_num: Number of output classes (size of the final linear layer).
    """

    def __init__(self, input_channels, class_num):
        super().__init__()

        # ``torch.nn`` has no ``conv`` attribute; the 2-D convolution layer is
        # ``nn.Conv2d``.
        self.conv = nn.Conv2d(in_channels=input_channels,
                              out_channels=512,
                              kernel_size=1,
                              stride=1,
                              padding=0)
        self.maxpool = nn.AdaptiveMaxPool2d((1, 1))
        self.fc = nn.Sequential(
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, class_num)
        )

    def forward(self, x):
        """Return class scores of shape ``(batch, class_num)``.

        Args:
            x: Feature map of shape ``(batch, input_channels, H, W)``.
        """
        x = self.conv(x)

        # Global max pooling collapses the spatial dimensions to 1x1, after
        # which they are dropped to leave a (batch, 512) matrix for the MLP.
        x = self.maxpool(x)
        x = x.view(x.size(0), x.size(1))

        return self.fc(x)

In [15]:
class FiLM(nn.Module):
    """FiLM network: image features modulated by a question encoding.

    A GRU encodes the question; a linear layer turns its final hidden state
    into one (beta, gamma) pair per residual block. The image is encoded by
    ``FeatureExtractor``, passed through ``res_blk_num`` FiLM-conditioned
    ``ResBlock`` modules (each fed two extra coordinate channels) and finally
    classified by ``Classifier``.

    Args:
        n_vocab: Vocabulary size of the question embedding.
        embed_hidden: Embedding dimension.
        gru_hidden: Hidden size of the GRU.
        res_blk_num: Number of residual blocks.
        class_num: Number of answer classes.
        img_channels: Channel count used inside the residual blocks. The
            first block receives the ``FeatureExtractor`` output, so this
            must equal that module's output channel count.
    """

    def __init__(self, n_vocab, embed_hidden, gru_hidden, res_blk_num, class_num, img_channels):
        super().__init__()

        self.embed = nn.Embedding(n_vocab, embed_hidden)
        self.gru = nn.GRU(embed_hidden, gru_hidden, batch_first=True)
        # One beta and one gamma vector of size img_channels per block.
        self.film_generator = nn.Linear(gru_hidden, 2 * res_blk_num * img_channels)

        self.featureextractor = FeatureExtractor()

        # ModuleList because the number of blocks is only known at
        # construction time. Each block takes img_channels + 2 inputs: the
        # two extra channels are the x and y coordinate maps.
        self.res_blks = nn.ModuleList(
            [ResBlock(img_channels + 2, img_channels) for _ in range(res_blk_num)]
        )

        self.classifier = Classifier(img_channels, class_num)

        self.res_blk_num = res_blk_num
        self.img_channels = img_channels

    def forward(self, x, question, question_len):
        """Compute class scores for a batch of (image, question) pairs.

        Args:
            x: Batch of images; fed to ``FeatureExtractor``.
            question: Padded token indices, batch-first (the embedding output
                is packed with ``batch_first=True``).
            question_len: Length of each question, as a list or tensor. The
                batch does not need to be sorted by length.

        Returns:
            The output of ``Classifier`` for the modulated image features.
        """
        batch_size = x.size(0)

        # 1) Extract image features.
        x = self.featureextractor(x)

        # 2) Encode the question with the GRU and project its final hidden
        #    state to the per-block beta and gamma vectors.
        embed = self.embed(question)
        if isinstance(question_len, torch.Tensor):
            # pack_padded_sequence requires the lengths tensor on the CPU.
            question_len = question_len.cpu()
        packed = nn.utils.rnn.pack_padded_sequence(
            embed, question_len, batch_first=True, enforce_sorted=False
        )
        _, h = self.gru(packed)
        # h is (num_layers, batch, hidden); take the last layer explicitly
        # instead of squeeze(), which would drop every size-1 dimension.
        film_vector = self.film_generator(h[-1]).view(
            batch_size, self.res_blk_num, 2, self.img_channels  # 2 = (beta, gamma)
        )

        # 3) Coordinate maps scaled to [-1, 1], concatenated to the input of
        #    every residual block to support spatial reasoning. They are
        #    created on the same device and dtype as the feature map.
        height, width = x.size(2), x.size(3)
        xs = torch.linspace(-1, 1, width, device=x.device, dtype=x.dtype)
        ys = torch.linspace(-1, 1, height, device=x.device, dtype=x.dtype)
        coordinate_x = xs.view(1, 1, 1, width).expand(batch_size, 1, height, width)
        coordinate_y = ys.view(1, 1, height, 1).expand(batch_size, 1, height, width)

        # 4) Feed each residual block the concatenated features together with
        #    its own beta and gamma.
        for i, res_block in enumerate(self.res_blks):
            beta = film_vector[:, i, 0, :]   # (batch, channels)
            gamma = film_vector[:, i, 1, :]  # (batch, channels)

            x = torch.cat([x, coordinate_x, coordinate_y], dim=1)
            x = res_block(x, beta, gamma)

        # 5) Classify.
        return self.classifier(x)



