In [1]:
from typing import OrderedDict

import torch as th
import torch.nn as nn

from music_gan.networks import Generator, Discriminator

In [2]:
# Number of time steps in the dummy latent sequence fed to the recurrent generator below.
nb_steps = 16

In [3]:
# Throwaway convolutional generator, built only so that its state dict can be handed
# to RecurrentGenerator further down.
# NOTE(review): the positional 7 is presumably `end_layer` (RecurrentGenerator passes
# end_layer=7 by keyword) — confirm against Generator's signature.
tmp_gen = Generator(16, 7)

In [4]:
# Weights of the plain generator; later loaded into the Generator built inside RecurrentGenerator.
state_dict = tmp_gen.state_dict()

In [5]:
class RecurrentGenerator(nn.Module):
    """RNN front-end feeding the convolutional blocks of a ``Generator``.

    A vanilla ``nn.RNN`` maps a latent sequence to per-step hidden features.
    Those features are folded into a 4-D feature map and passed through the
    ``conv_blocks`` and ``end_block`` of a ``Generator`` whose weights are
    loaded from ``cnn_state_dict``.
    """

    def __init__(
            self,
            input_size: int,
            conv_rand_channels: int,
            cnn_state_dict: OrderedDict[str, th.Tensor]
    ):
        """Build the RNN and borrow the convolutional stack of a ``Generator``.

        Args:
            input_size: Number of features per time step of the latent sequence.
            conv_rand_channels: First argument given to ``Generator``; also the
                width of the chunks the RNN output is cut into in ``forward``.
            cnn_state_dict: State dict loaded into the internal
                ``Generator(conv_rand_channels, end_layer=7)``.
        """
        super().__init__()

        gen = Generator(
            conv_rand_channels, end_layer=7
        )

        gen.load_state_dict(cnn_state_dict)

        # Only the sub-modules are kept; the wrapping Generator is discarded.
        self.__conv_blocks = gen.conv_blocks
        self.__end_block = gen.end_block

        self.__input_size = input_size
        # Remembered so that forward() can cut the RNN output into chunks of
        # exactly this width instead of relying on a hard-coded 16.
        self.__conv_rand_channels = conv_rand_channels
        # Twice the channel count: forward() folds the hidden vector into
        # two chunks of `conv_rand_channels` features each.
        self.__hidden_size = conv_rand_channels * 2

        self.__rnn = nn.RNN(
            self.__input_size,
            self.__hidden_size,
            batch_first=True
        )

    def forward(self, z_rec: th.Tensor) -> th.Tensor:
        """Generate from a latent sequence.

        Args:
            z_rec: Batched latent sequence of shape
                ``(batch, steps, input_size)`` (the RNN is ``batch_first``).

        Returns:
            Output of the generator's end block applied to the folded RNN
            features.
        """
        # (batch, steps, input_size) -> (batch, steps, hidden_size)
        out_rec, _ = self.__rnn(z_rec)

        # Cut the hidden axis into chunks of `conv_rand_channels` features
        # (two chunks, since hidden_size == 2 * conv_rand_channels), stack them
        # on a new axis -> (batch, 2, steps, conv_rand_channels), then move the
        # channel axis first -> (batch, conv_rand_channels, 2, steps).
        chunks = out_rec.split(self.__conv_rand_channels, dim=-1)
        out = th.stack(chunks, dim=1).permute(0, 3, 1, 2)

        for layer in self.__conv_blocks:
            out = layer(out)

        out = self.__end_block(out)

        return out


In [6]:
# input_size=32 latent features per step, conv_rand_channels=16.
# NOTE(review): 16 presumably has to match the value used for `tmp_gen`, otherwise
# loading `state_dict` inside the constructor would fail — confirm.
rec_gen = RecurrentGenerator(32, 16, state_dict)

In [7]:
# Random latent batch of shape (5, nb_steps, 32); the last dim matches input_size=32 above.
# NOTE(review): no seed is set, so the values differ from run to run.
x = th.randn(5, nb_steps, 32)

In [8]:
# Forward pass of the recurrent generator on the dummy latent batch (autograd is left enabled).
o = rec_gen(x)

In [9]:
# Shape check of the generated batch; the recorded run printed torch.Size([5, 2, 512, 4096]).
o.size()

torch.Size([5, 2, 512, 4096])

In [10]:
# State dict of a freshly constructed Discriminator, used below to initialise RecurrentDiscriminator.
disc_state_dict = Discriminator(start_layer=0).state_dict()

In [22]:
class RecurrentDiscriminator(nn.Module):
    """Convolutional feature extractor followed by an RNN and a per-step scorer.

    The ``start_block`` and all but the last of the ``conv_blocks`` of a
    ``Discriminator`` (weights loaded from ``cnn_state_dict``) produce a
    feature map. Its last axis is treated as time and fed to a ReLU
    ``nn.RNN``; every step is scored by a linear layer, and the per-step
    log-sigmoid scores are summed over time.
    """

    def __init__(self, cnn_state_dict: OrderedDict[str, th.Tensor]):
        """Build the model around a ``Discriminator(start_layer=0)``.

        Args:
            cnn_state_dict: State dict loaded into the internal
                ``Discriminator`` before its blocks are borrowed.
        """
        super().__init__()

        conv_disc = Discriminator(start_layer=0)
        conv_disc.load_state_dict(cnn_state_dict)

        self.__start_block = conv_disc.start_block
        # The final conv block is dropped; the RNN below replaces it.
        self.__conv_blocks = conv_disc.conv_blocks[:-1]

        rnn_out_size = 64

        # Input size is channels * 2 because forward() flattens the channel
        # axis together with the following axis.
        # NOTE(review): this presumably assumes that axis has length 2 after
        # the kept conv blocks — confirm against Discriminator.
        self.__rnn = nn.RNN(
            conv_disc.end_layer_channels * 2,
            rnn_out_size,
            batch_first=True,
            nonlinearity="relu"
        )

        # LogSigmoid computes log(sigmoid(x)) in one numerically stable step.
        # A separate Sigmoid followed by .log() underflows to log(0) = -inf
        # for strongly negative logits. LogSigmoid has no parameters, so the
        # state-dict keys of this Sequential are unchanged.
        self.__clf = nn.Sequential(
            nn.Linear(
                rnn_out_size,
                1
            ),
            nn.LogSigmoid()
        )

    def forward(self, data: th.Tensor) -> th.Tensor:
        """Score a batch of inputs.

        Args:
            data: Input batch accepted by the discriminator's start block.

        Returns:
            Tensor of shape ``(batch, 1)``: the sum over time steps of
            ``log(sigmoid(logit))``.
        """
        out = self.__start_block(data)
        for layer in self.__conv_blocks:
            out = layer(out)

        # Merge dims 1 and 2, then swap so the last axis becomes the RNN time
        # axis: (batch, d1, d2, steps) -> (batch, steps, d1 * d2).
        out = th.flatten(out, 1, 2).permute(0, 2, 1)

        out, _ = self.__rnn(out)

        # sigmoid + log proba sum seems not correct with WGAN...
        # Per-step log-probabilities (batch, steps, 1) summed over the time axis.
        out = self.__clf(out).sum(dim=1)

        return out

In [23]:
# Recurrent discriminator initialised from the plain discriminator's weights.
rec_disc = RecurrentDiscriminator(disc_state_dict)

In [24]:
# Score the generated batch `o` with the recurrent discriminator.
out_disc = rec_disc(o)

In [25]:
# Shape check: one score per sample; the recorded run printed torch.Size([5, 1]).
out_disc.size()

torch.Size([5, 1])

In [26]:
# Display the raw scores.
# NOTE(review): in the recorded run all five scores are identical (-11.5789) although the
# latent inputs are random — worth checking whether the RNN/classifier output has collapsed.
out_disc

tensor([[-11.5789],
        [-11.5789],
        [-11.5789],
        [-11.5789],
        [-11.5789]], grad_fn=<SumBackward1>)

In [28]:
# Constant tensors of shape (5, 16, 1): every entry of `a` is 0.9, every entry of `b` is 0.1.
# NOTE(review): presumably stand-ins for per-step sigmoid outputs (batch 5, 16 steps) — confirm.
a=th.ones(5, 16, 1) - 0.1
b=th.zeros(5, 16, 1) + 0.1

In [29]:
# Negated difference between the batch-mean of the summed logs of `a` and of `b`:
# -(16*ln(0.9) - 16*ln(0.1)).
# NOTE(review): presumably a trial of the loss sign convention — confirm.
res = -(a.log().sum(dim=1).mean() - b.log().sum(dim=1).mean())

In [30]:
# Display the value; the recorded run printed tensor(-35.1556).
res

tensor(-35.1556)

In [31]:
# Same expression as `res` with the roles of `a` and `b` swapped, so the sign flips.
res_2 = -(b.log().sum(dim=1).mean() - a.log().sum(dim=1).mean())

In [32]:
# Display the value; the recorded run printed tensor(35.1556).
res_2

tensor(35.1556)