In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from torch import nn
import torch.nn.functional as F

class Generator(nn.Module):
    def __init__(self, input_size: int = 128, output_size: int = 512):
        super().__init__()

        self.seq = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.ReLU(),
            nn.Linear(256, output_size),
            nn.Tanh()
        )

    def forward(self, input: torch.Tensor):
        return self.seq(input)


class Discriminator(nn.Module):
    def __init__(self, input_size: int = 512):
        super().__init__()

        self.seq = nn.Sequential(
            nn.Linear(input_size, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 128),
            nn.LeakyReLU(),
            nn.Linear(128, 32),
            nn.LeakyReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

    def forward(self, input: torch.Tensor):
        return self.seq(input)

In [51]:
class BiLSTMDifferentInputs(nn.Module):
    def __init__(self, input_size_1: int = 512, input_size_2: int = 1):
        super().__init__()
        self.lstm1 = nn.LSTM(input_size_1, 256)
        self.lstm2 = nn.LSTM(input_size_2, 32)


    def forward(self, input_1: torch.Tensor, input_2: torch.Tensor):
        out1, (hn1, cn1) = self.lstm1(input_1)
        out2, (hn2, cn2) = self.lstm2(input_2)
        out = []
        for i in range(len(out1)):
            t = torch.cat((out1[i], out2[len(out2)-i-1]), 1)
            out.append(t)
        return out

In [52]:
from torch.utils.data import Dataset

class RandomDataGen(Dataset):
    def __init__(self, input_size: int = 128, num_tensors = 512):
        inputs = []
        for i in range(num_tensors):
            ten = torch.normal(0, torch.ones(1, input_size))
            inputs.append(ten)

        self.len = num_tensors

        self.data = torch.stack(inputs, dim=0)

    def __len__(self):
        return self.len
    
    def __getitem__(self, idx: int):
        return self.data[idx]

In [53]:
from torch.utils.data import DataLoader

def get_gan_outputs():
    G = Generator()
    D = Discriminator()
    LSTM = BiLSTMDifferentInputs()

    dataset = RandomDataGen()
    dl = DataLoader(dataset, batch_size=32, shuffle=True)

    for step, x_list in enumerate(dl):
        g_outputs = []
        d_outputs = []

        for x in x_list:
            out = G(x)
            g_outputs.append(out)
            out = D(out)
            d_outputs.append(out)
            # print(out.squeeze().unsqueeze())
        
        d_outputs.reverse()

        g_outputs = torch.stack(g_outputs)
        d_outputs = torch.stack(d_outputs)

        out = LSTM(g_outputs, d_outputs)

        return out


In [54]:
get_gan_outputs()

1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1


[tensor([[-4.3287e-02,  7.6542e-02, -6.6061e-02, -4.2252e-02, -7.3386e-02,
          -2.5715e-02,  5.3953e-02,  1.0311e-01,  1.7680e-02, -5.5460e-02,
           4.2085e-02,  3.4359e-02,  7.6865e-03, -2.0326e-02, -2.1861e-02,
          -7.0709e-02,  6.2905e-02,  3.9618e-02, -4.0952e-02,  2.3800e-02,
          -7.4954e-03,  2.1727e-02,  1.3824e-02,  3.1360e-02, -1.0935e-02,
           3.1468e-02, -6.1533e-03,  7.4034e-03, -1.7850e-02, -7.2767e-02,
           4.0234e-02, -1.9082e-02, -5.2544e-02,  6.1610e-02,  6.2377e-02,
          -3.3574e-03, -2.7670e-02,  2.3304e-02,  6.4804e-02,  3.2205e-02,
           5.2742e-02, -3.0560e-04,  4.3050e-02, -9.1649e-02,  2.5642e-02,
          -3.1054e-03,  4.8858e-02,  2.4999e-02, -2.4158e-02, -1.7436e-02,
          -1.4974e-02, -1.4247e-02,  6.9693e-02,  2.5855e-02, -2.4566e-02,
          -8.6725e-03, -3.8083e-02, -9.7424e-02,  6.9184e-02, -7.5501e-02,
          -3.8128e-02,  1.2891e-02, -1.4231e-02, -5.2258e-03, -8.8086e-02,
          -1.1854e-01,  8