In [None]:
# ResNet3Small adaptado a discriminador

In [None]:
import torch
from torch import nn
from torch.nn import functional as F

class _ResidualBlock(nn.Module):
    """Two-convolution residual block computing ``relu(body(x) + shortcut(x))``.

    The body is conv -> BN -> ReLU -> conv -> BN. The shortcut is the identity
    when the block keeps both the channel count and the spatial size; otherwise
    it is a 1x1 convolution + BN that uses the same stride as the body, so both
    branches produce tensors of the same shape and can be summed.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, downsample: bool) -> None:
        super().__init__()
        stride = 2 if downsample else 1
        # "Same" padding for odd kernels (1 for the 3x3 kernels used here), so the
        # body only changes the spatial size through its stride.
        padding = kernel_size // 2

        self.body = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        if downsample or in_channels != out_channels:
            # Projection shortcut: must follow the body's stride, not a fixed one,
            # otherwise the two branches disagree in spatial size.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The skip connection is an addition with the block *input*, followed by
        # the final non-linearity.
        return F.relu(self.body(x) + self.shortcut(x))


class Discriminator(nn.Module):
    """ResNet-style discriminator that maps an image batch to a probability.

    The feature extractor is five stages of residual blocks (16, 32, 64, 128
    and 256 channels) separated by 2x2 max-pooling, followed by a global
    average pool and a single linear unit with a sigmoid.

    Args:
        width: Width in pixels of the images the network will receive.
        height: Height in pixels of the images the network will receive.
        input_channels: Number of channels of the input images.

    Raises:
        ValueError: If ``width`` x ``height`` is too small to survive the
            pooling stack (checked once with a dry run at construction time).
    """

    def __init__(self, width: int, height: int, input_channels: int) -> None:
        super().__init__()
        self.model = nn.Sequential(
            self._add_res_net_block(input_channels, 16, 3, False),
            nn.MaxPool2d(2, 2),

            self._add_res_net_block(16, 32, 3, True),
            self._add_res_net_block(32, 32, 3, False),
            nn.MaxPool2d(2, 2),

            self._add_res_net_block(32, 64, 3, True),
            self._add_res_net_block(64, 64, 3, False),
            self._add_res_net_block(64, 64, 3, False),
            nn.MaxPool2d(2, 2),

            self._add_res_net_block(64, 128, 3, True),
            self._add_res_net_block(128, 128, 3, False),
            self._add_res_net_block(128, 128, 3, False),
            nn.MaxPool2d(2, 2),

            self._add_res_net_block(128, 256, 3, True),
            self._add_res_net_block(256, 256, 3, False),
            self._add_res_net_block(256, 256, 3, False),
            # Global average pooling: collapses whatever spatial size is left to
            # 1x1, so the classifier input no longer depends on a hand-computed
            # downsampling factor.
            nn.AdaptiveAvgPool2d(1),
        )

        # Fully connected layer for the final real/fake decision. After global
        # pooling the feature vector has one value per channel of the last stage.
        self.feature_vector_length = 256
        self.fc = nn.Linear(self.feature_vector_length, 1)

        self._check_input_size(width, height, input_channels)

    def _check_input_size(self, width: int, height: int, input_channels: int) -> None:
        """Dry-run the feature extractor once to reject unusably small images."""
        probe = torch.zeros(1, input_channels, height, width)
        was_training = self.training
        # Eval mode: BatchNorm neither needs more than one value per channel nor
        # updates its running statistics during the probe.
        self.eval()
        try:
            with torch.no_grad():
                self.model(probe)
        except RuntimeError as err:
            raise ValueError(
                f"Input size {width}x{height} is too small for the discriminator's pooling stack"
            ) from err
        finally:
            self.train(was_training)

    def _add_convolution(self, in_channels, out_channels, kernel_size):
        """Return a conv -> BatchNorm -> ReLU unit (padding fixed to 1, no conv bias)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def _add_res_net_block(self, in_channels, out_channels, kernel_size, downsample):
        """Return a residual block; ``downsample=True`` halves the spatial size."""
        return _ResidualBlock(in_channels, out_channels, kernel_size, downsample)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Return a ``(batch, 1)`` tensor of sigmoid scores for ``img``."""
        features = self.model(img)
        # Flatten the pooled feature map for the fully connected layer.
        features = torch.flatten(features, 1)
        validity = self.fc(features)
        return torch.sigmoid(validity)


# Spatial size (in pixels) of the images handled by the discriminator.
# Example instantiation, left disabled:
#   input_channels = 3
#   discriminator = Discriminator(width=width, height=height, input_channels=input_channels)
width = 96
height = 96


In [None]:
# Generador
import torch
from torch import nn
import numpy as np

class Generator(nn.Module):
    """Conditional MLP generator: (noise, class label) -> image.

    The label is embedded into an ``n_classes``-dimensional vector, concatenated
    with the noise vector, and pushed through a stack of fully connected layers
    ending in ``tanh``; the result is reshaped to ``img_shape``.

    Args:
        img_shape: Shape of one generated image, excluding the batch dimension.
        latent_dim: Length of the noise vector.
        n_classes: Number of distinct class labels.
    """

    def __init__(self, img_shape, latent_dim, n_classes):
        super().__init__()
        self.img_shape = img_shape
        self.label_emb = nn.Embedding(n_classes, n_classes)

        def block(in_feat, out_feat, normalize=True):
            """Return the layers of one Linear [-> BatchNorm] -> LeakyReLU stage."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # The second positional argument of BatchNorm1d is `eps`, so the
                # former `BatchNorm1d(out_feat, 0.8)` set eps=0.8 rather than a
                # momentum. NOTE(review): 0.8 was presumably meant as a
                # Keras-style momentum, i.e. momentum=0.2 in PyTorch's
                # convention - confirm the intended value.
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.2))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim + n_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        """Generate one image per (noise row, label) pair.

        Returns a tensor of shape ``(batch, *img_shape)`` with values in
        ``[-1, 1]`` (tanh output).
        """
        # Concatenate label embedding and noise to produce the network input.
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(gen_input)
        img = img.view(img.size(0), *self.img_shape)
        return img




In [None]:
# Image geometry: assuming single-channel (grayscale) 96x96 images.
img_shape = (1, 96, 96)
# Length of the noise vector.
latent_dim = 100
# Number of distinct musical symbols.
n_classes = 80

generator = Generator(img_shape, latent_dim, n_classes)

# Build an example batch of noise vectors and class labels to generate images from.
batch_size = 16
noise = torch.randn((batch_size, latent_dim))
labels = torch.randint(low=0, high=n_classes, size=(batch_size,))

# Generate the images and report their shape; it should be (16, 1, 96, 96).
generated_images = generator(noise, labels)
print(generated_images.size())