# In [1]:
import torch.nn as nn
import torch

import torch
from torch import nn

# Pick the default compute device once at import time: the first CUDA GPU when
# one is available, otherwise the CPU. Note that the CPU fallback is the plain
# string "cpu" (not a torch.device), whereas the CUDA case is a torch.device.
device = torch.device("cuda:0") if torch.cuda.is_available() else "cpu"

# In [2]:
class GeneratorDCGAN_v1_Text(nn.Module):
    """Conditional generator producing a single-channel 105x8 map.

    A noise vector and a word embedding are each projected by their own
    linear layer to 105 * 8 values, reshaped into one-channel 105x8 maps,
    stacked into a two-channel input and passed through two 3x3 convolutions
    (stride 1, padding 1, so the 105x8 spatial size is preserved).

    Args:
        noise_dim: Size of the last dimension of the noise input.
        word_embedding_dim: Size of the last dimension of the word embedding.
    """

    def __init__(self, noise_dim, word_embedding_dim):
        super(GeneratorDCGAN_v1_Text, self).__init__()

        # Kept for backward compatibility with code that reads ``.device``.
        # It reflects CUDA availability at construction time, NOT where the
        # module's weights actually live, so ``forward`` does not rely on it.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.noise_dim = noise_dim
        self.word_embedding_dim = word_embedding_dim

        # Each conditioning input is projected to one 105x8 plane.
        self.fc_noise = nn.Linear(noise_dim, 105 * 8)
        self.fc_word_embedding = nn.Linear(word_embedding_dim, 105 * 8)
        # Two input channels: the noise plane and the word-embedding plane.
        self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, noise, word_embedding):
        """Generate a batch of maps from noise and word embeddings.

        Both inputs are moved to the device that holds this module's weights
        before use, so the call works wherever the module has been placed.

        Args:
            noise: Tensor whose rows flatten to ``noise_dim`` values per
                sample, e.g. shape ``(batch, noise_dim)``.
            word_embedding: Tensor of shape ``(batch, word_embedding_dim)``.

        Returns:
            Tensor of shape ``(batch, 1, 105, 8)``; the raw output of the
            last convolution (no output activation is applied).
        """
        # Use the weights' real device rather than a guess based on CUDA
        # availability; otherwise a CPU-resident model on a CUDA machine
        # would receive a CUDA tensor and fail with a device mismatch.
        target_device = self.fc_word_embedding.weight.device

        # Project the noise to a one-channel 105x8 plane.
        noise = self.fc_noise(noise.to(target_device))
        noise = noise.view(noise.size(0), 1, 105, 8)

        # Project the word embedding to a one-channel 105x8 plane.
        word_embedding = self.fc_word_embedding(word_embedding.to(target_device))
        word_embedding = word_embedding.view(word_embedding.size(0), 1, 105, 8)

        # Stack both planes along the channel axis -> (batch, 2, 105, 8).
        combined_input = torch.cat([noise, word_embedding], dim=1)

        # Size-preserving convolutions produce the single-channel output.
        z = self.conv1(combined_input)
        z = self.bn1(z)
        z = self.relu(z)
        z = self.conv2(z)

        return z

class DiscriminatorDCGAN_v1_Text(nn.Module):
    """Conditional discriminator scoring a 105x8 map against a word embedding.

    The word embedding is projected by a linear layer to 105 * 8 values and
    reshaped into a one-channel 105x8 plane, which is stacked with the input
    map to form a two-channel tensor. Three strided convolutions reduce it,
    and a final linear layer plus sigmoid yields one score per sample.

    Args:
        n_filters: Number of output channels of the first convolution; the
            next two convolutions use 2x and 4x this value.
        word_embedding_dim: Size of the last dimension of the word embedding.
    """

    def __init__(self, n_filters, word_embedding_dim):
        super(DiscriminatorDCGAN_v1_Text, self).__init__()

        self.word_embedding_dim = word_embedding_dim
        self.fc_word_embedding = nn.Linear(word_embedding_dim, 105 * 8)

        self.network = nn.Sequential(
            # (2, 105, 8) -> (n_filters, 52, 4)
            nn.Conv2d(2, n_filters, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2),

            # -> (2 * n_filters, 26, 2)
            nn.Conv2d(n_filters, n_filters*2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(n_filters * 2),
            nn.LeakyReLU(0.2),

            # -> (4 * n_filters, 13, 1)
            nn.Conv2d(n_filters*2, n_filters*4, kernel_size=3, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(n_filters*4),
            nn.LeakyReLU(0.2),

            nn.Flatten(),  # Flatten channel and spatial dimensions

            # Fully connected layer to reduce to a single value per sample.
            # 105 // 8 == 13 and 8 // 8 == 1 match the 13x1 map left after the
            # three stride-2 convolutions above.
            nn.Linear(n_filters*4 * (105 // 8) * (8 // 8), 1),
            nn.Sigmoid()
        )

    def forward(self, input, word_embedding):
        """Score a batch of maps conditioned on their word embeddings.

        Both inputs are moved to the device that holds this module's weights
        before use.

        Args:
            input: Tensor of shape ``(batch, 1, 105, 8)``.
            word_embedding: Tensor of shape ``(batch, word_embedding_dim)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with sigmoid outputs in [0, 1].
        """
        # The class never defined a ``device`` attribute, so reading
        # ``self.device`` here raised AttributeError on every call. Take the
        # device from the module's own parameters instead.
        target_device = self.fc_word_embedding.weight.device

        # Project the word embedding to a one-channel 105x8 plane.
        word_embedding = self.fc_word_embedding(word_embedding.to(target_device))
        word_embedding = word_embedding.view(word_embedding.size(0), 1, 105, 8)

        # Stack the map and the embedding plane -> (batch, 2, 105, 8).
        combined_input = torch.cat([input.to(target_device), word_embedding], dim=1)

        output = self.network(combined_input)
        return output