# In[1]:
import mindspore
import mindspore.ops as ops
import mindspore.nn as nn



# In[2]:
class vgg(nn.Cell):
    """VGG-style convolutional network used as an image feature extractor.

    Six convolutional stages (``conv0``-``conv5``) and a three-layer
    classifier head (``fc6``-``fc8`` plus ``softmax``) are instantiated, but
    ``construct`` only applies ``conv0``-``conv4`` and returns that feature
    map. Every stage ends in a 2x2, stride-2 max-pool, so the returned map has
    512 channels and a spatial size reduced by five poolings.
    """

    def _make_conv_stage(self, input_size, output_size, num_convs):
        """Return ``num_convs`` 3x3 conv + ReLU pairs followed by 2x2 max-pooling."""
        layers = []
        in_channels = input_size
        for _ in range(num_convs):
            # MindSpore's Conv2d defaults to pad_mode="same" and rejects a
            # non-zero ``padding`` unless pad_mode="pad" is given explicitly.
            layers.append(
                nn.Conv2d(
                    in_channels,
                    output_size,
                    kernel_size=3,
                    pad_mode="pad",
                    padding=1,
                )
            )
            layers.append(nn.ReLU())
            in_channels = output_size
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.SequentialCell(layers)

    def make_2layer_conv(self, input_size, output_size):
        """Return a stage of two 3x3 conv + ReLU pairs and a 2x2 max-pool.

        Args:
            input_size: Number of input channels of the first convolution.
            output_size: Number of output channels of both convolutions.
        """
        return self._make_conv_stage(input_size, output_size, 2)

    def make_3layer_conv(self, input_size, output_size):
        """Return a stage of three 3x3 conv + ReLU pairs and a 2x2 max-pool.

        Args:
            input_size: Number of input channels of the first convolution.
            output_size: Number of output channels of all three convolutions.
        """
        return self._make_conv_stage(input_size, output_size, 3)

    def make_3layer_fc(self, input_size, output_size):
        """Return a Dense -> ReLU -> Dropout stack.

        Despite the name, this builds a single fully connected layer; the
        "3" counts the three cells in the returned sequence.

        Args:
            input_size: Number of input features of the Dense layer.
            output_size: Number of output features of the Dense layer.
        """
        return nn.SequentialCell(
            [
                nn.Dense(input_size, output_size),
                nn.ReLU(),
                # NOTE(review): ``keep_prob`` is the older Dropout argument;
                # newer MindSpore releases appear to use ``p`` instead -
                # confirm against the installed version.
                nn.Dropout(keep_prob=0.5)
            ]
        )

    def __init__(self):
        # Cell.__init__ must run before any sub-cell is assigned; otherwise
        # nn.Cell refuses the attribute assignment.
        super().__init__()
        self.conv0 = self.make_2layer_conv(3, 32)
        self.conv1 = self.make_2layer_conv(32, 64)
        self.conv2 = self.make_2layer_conv(64, 128)
        self.conv3 = self.make_3layer_conv(128, 256)
        self.conv4 = self.make_3layer_conv(256, 512)
        # The cells below are created but not applied by ``construct``.
        self.conv5 = self.make_3layer_conv(512, 512)
        self.fc6 = self.make_3layer_fc(512*7*7, 4096)
        self.fc7 = self.make_3layer_fc(4096, 4096)
        self.fc8 = nn.Dense(4096, 1000)
        self.softmax = nn.Softmax()

    def construct(self, x):
        """Apply ``conv0``-``conv4`` to ``x`` and return the resulting feature map.

        Args:
            x: Image tensor with 3 channels (the input width of ``conv0``).

        Returns:
            The 512-channel output of ``conv4``.
        """
        x = self.conv0(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        # Classifier head, currently disabled so that the spatial feature map
        # is returned instead of class scores:
        # x = self.conv5(x).view(1, -1)
        # x = self.fc6(x)
        # x = self.fc7(x)
        # x = self.fc8(x)
        # x = self.softmax(x)
        return x

class block(nn.Cell):
    """One attention layer that refines a query vector with image features.

    The image feature map is flattened into a list of 512-dimensional region
    vectors. Query and regions are projected, combined through ``tanh``, and
    turned into per-region weights by a softmax over the region axis. The
    weighted sum of the projected regions is added to the incoming query.

    The 512 feature width is hard-coded throughout, and the reshapes in
    ``construct`` assume a batch size of 1.
    """

    def __init__(self, embed_size=512):
        """Create the projection layers.

        Args:
            embed_size: Width of the incoming query vector. It must be 512 for
                the residual addition in ``construct`` to be shape-compatible.
        """
        # Cell.__init__ must run before any sub-cell is assigned.
        super().__init__()
        self.Wq = nn.Dense(embed_size, 512)
        self.Wi = nn.Dense(512, 512, has_bias=False)
        self.Wp = nn.Dense(512, 512)
        self.transpose = ops.Transpose()
        self.tanh = nn.Tanh()
        # Axis 0 is the region axis after the ``view(-1, 512)`` in construct.
        self.softmax = nn.Softmax(0)
        self.reduce_sum = ops.ReduceSum()

    def construct(self, v_i, v_q):
        """Return the attention-refined query.

        Args:
            v_i: Image feature map with 512 channels; assumed to be
                ``(1, 512, H, W)`` or ``(512, H, W)`` - TODO confirm with caller.
            v_q: Query tensor broadcastable to ``(1, 1, 512)``.

        Returns:
            Tensor of shape ``(1, 1, 512)``: attended image vector plus ``v_q``.
        """
        # Flatten the spatial grid and move channels last so each row is one
        # region: (512, H*W) -> (H*W, 512). The Dense layers act on the last
        # axis, so it must be the 512-wide channel axis.
        regions = self.transpose(v_i.view(512, -1), (1, 0))
        encoded_q = self.Wq(v_q)
        encoded_i = self.Wi(regions)
        hA = self.tanh(encoded_q + encoded_i)
        pI = self.softmax(self.Wp(hA).view(-1, 512))
        # Weight the image regions, not the query: the weights sum to 1 over
        # axis 0, so weighting the single query row would just return the
        # query and ignore the image entirely.
        vI = self.reduce_sum(encoded_i * pI, 0)
        u = vI.view(1, 1, 512) + v_q
        return u

class san(nn.Cell):
    """Stacked-attention model scoring an answer against an image/question pair.

    The image is encoded by ``vgg``, the question and the answer by separate
    LSTMs, and four attention ``block`` cells successively refine the question
    vector using the image features. The result is the dot product between
    the refined vector and the answer encoding.
    """

    def __init__(self, vocab_size):
        """Create all sub-networks.

        Args:
            vocab_size: Number of rows of the shared token embedding table.
        """
        # Cell.__init__ must run before any sub-cell is assigned.
        super().__init__()
        self.vgg = vgg()
        self.embed = nn.Embedding(vocab_size, 512)
        self.lstm_q = nn.LSTM(512, 512)
        self.lstm_ans = nn.LSTM(512, 512)
        self.block0 = block()
        self.block1 = block()
        self.block2 = block()
        self.block3 = block()
        self.matmul = ops.MatMul()

    def construct(self, image, question, answer):
        """Return the matching score of ``answer`` for ``image`` and ``question``.

        Args:
            image: Image tensor passed to ``vgg``.
            question: Integer token ids of the question; assumed to be laid
                out as ``(seq_len, batch)`` to match the LSTM's default
                sequence-first input - TODO confirm with the data pipeline.
            answer: Integer token ids of the answer, same layout assumption.

        Returns:
            The raw (unnormalised) product of the answer encoding and the
            final attention output. The reshapes assume a batch size of 1.
        """
        v_i = self.vgg(image)
        # Token ids must go through the embedding table first: the LSTMs
        # expect 512-wide float vectors, not raw indices.
        _, (v_q, __) = self.lstm_q(self.embed(question))
        u0 = self.block0(v_i, v_q)
        u1 = self.block1(v_i, u0)
        u2 = self.block2(v_i, u1)
        u3 = self.block3(v_i, u2)

        _, (v_ans, __) = self.lstm_ans(self.embed(answer))
        ans_prob = self.matmul(v_ans.view(-1, 512), u3.view(512, 1))
        return ans_prob