<a href="https://colab.research.google.com/github/Utterbackian/Neuromatch2023_Medical_Imaging/blob/main/EEG2Image_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

GAN model from the EEG2Image paper, converted to PyTorch.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
import torchvision.transforms.functional as TF

In [None]:
class Generator(nn.Module):
    """Generator that maps a flat input vector to a 3-channel image.

    The input vector (e.g. noise concatenated with EEG features) is viewed as a
    ``(batch, in_channels, 1, 1)`` feature map and upsampled by a stack of
    spectrally-normalised transposed convolutions. With the default strides
    the spatial side grows 1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128.

    Args:
        n_class: Number of classes for the discrete-condition embedding
            (that path is not used by ``forward``).
        res: Currently unused; the output side is the product of ``strides``.
        in_channels: Length of the vectors passed to ``forward``. Defaults to
            1, the value the original port hard-coded for the first layer;
            set it to e.g. ``latent_dim + eeg_feature_dim``.
    """

    def __init__(self, n_class=10, res=128, in_channels=1):
        super(Generator, self).__init__()
        filters = [1024, 512, 256, 128, 64, 32]
        strides = [4, 2, 2, 2, 2, 2]

        self.cnn_depth = len(filters)

        # Discrete-condition path (embedding -> dense -> 1x8x8 map). forward()
        # does not use it; it is kept for parity with the original port.
        # in_features must equal the embedding size.
        embedding_dim = 50
        self.cond_embedding = nn.Embedding(num_embeddings=n_class, embedding_dim=embedding_dim)
        self.cond_flat = nn.Flatten()
        self.cond_dense = nn.Linear(in_features=embedding_dim, out_features=8 * 8 * 1)
        # PyTorch has no nn.Reshape; Unflatten reshapes the feature axis.
        self.cond_reshape = nn.Unflatten(1, (1, 8, 8))

        # Reference init hyperparameters (not applied in this port):
        # If only conv  : mean=0.0, std=0.02
        # If using bnorm: mean=1.0, std=0.02
        in_chs = [in_channels] + filters[:-1]
        # With kernel 3 and padding 1, output_padding = stride - 1 gives
        # out = in * stride (TF padding='same'); without it a 1x1 input stays 1x1.
        self.conv = nn.ModuleList([
            spectral_norm(nn.ConvTranspose2d(
                in_channels=in_ch, out_channels=out_ch, kernel_size=3,
                stride=stride, padding=1, output_padding=stride - 1, bias=False))
            for in_ch, out_ch, stride in zip(in_chs, filters, strides)
        ])

        self.act = nn.ModuleList([nn.LeakyReLU() for _ in range(self.cnn_depth)])
        # bnorm[0] is created but forward() skips it (first block is conv + act).
        self.bnorm = nn.ModuleList([nn.BatchNorm2d(f) for f in filters])

        self.last_conv = spectral_norm(nn.Conv2d(
            in_channels=filters[-1], out_channels=3, kernel_size=3,
            stride=1, padding=1, bias=False))

    def forward(self, X):
        """Generate images from flat input vectors.

        Args:
            X: ``(batch, in_channels)`` tensor.

        Returns:
            ``(batch, 3, 128, 128)`` tensor (with the default strides) taken
            straight from the last conv; no output activation is applied.
        """
        X = X.unsqueeze(2).unsqueeze(3)  # (B, D) -> (B, D, 1, 1)
        X = self.act[0](self.conv[0](X))

        for idx in range(1, self.cnn_depth):
            X = self.act[idx](self.bnorm[idx](self.conv[idx](X)))
        X = self.last_conv(X)
        return X

class Discriminator(nn.Module):
    """Conditional discriminator that scores (image, condition) pairs.

    The condition vector ``C`` is broadcast over the image's spatial grid and
    concatenated to it along the channel axis (NCHW). The result goes through
    a stack of spectrally-normalised convolutions and a linear output head.

    Args:
        n_class: Number of classes for the discrete-condition embedding
            (that path is not used by ``forward``).
        res: Side length of the square input image; sizes the linear head.
        in_channels: Channels after concatenation, i.e. image channels plus
            ``C.shape[1]``. Defaults to 1, the value the original port
            hard-coded for the first layer.
    """

    def __init__(self, n_class=10, res=128, in_channels=1):
        super(Discriminator, self).__init__()
        filters = [64, 128, 256, 512, 1024, 1]
        strides = [2, 2, 2, 2, 1, 1]
        self.cnn_depth = len(filters)

        # Discrete-condition path (embedding -> dense -> 1 x res x res map).
        # forward() does not use it; it is kept for parity with the original
        # port. in_features is the embedding size (a res*res -> res*res Linear
        # would hold res**4 weights, ~1 GB at res=128).
        embedding_dim = 50
        self.cond_embedding = nn.Embedding(num_embeddings=n_class, embedding_dim=embedding_dim)
        self.cond_flat = nn.Flatten()
        self.cond_dense = nn.Linear(in_features=embedding_dim, out_features=res * res * 1)
        # PyTorch has no nn.Reshape; Unflatten gives the NCHW map (1, res, res).
        self.cond_reshape = nn.Unflatten(1, (1, res, res))

        in_chs = [in_channels] + filters[:-1]
        self.cnn_conv = nn.ModuleList([
            spectral_norm(nn.Conv2d(
                in_channels=in_ch, out_channels=out_ch, kernel_size=3,
                stride=stride, padding=1, bias=False))
            for in_ch, out_ch, stride in zip(in_chs, filters, strides)
        ])

        self.cnn_bnorm = nn.ModuleList([nn.BatchNorm2d(f) for f in filters])
        self.cnn_act = nn.ModuleList([nn.LeakyReLU(negative_slope=0.2) for _ in range(self.cnn_depth)])

        # Spatial side after the conv stack: each k=3, p=1 conv maps
        # side -> floor((side - 1) / stride) + 1 (128 -> 8 with these strides).
        side = res
        for stride in strides:
            side = (side - 1) // stride + 1

        self.flat = nn.Flatten()
        self.disc_out = nn.Linear(in_features=side * side * filters[-1], out_features=1)

    def forward(self, x, C, training=None):
        """Score images ``x`` under conditions ``C``.

        Args:
            x: Images, ``(batch, img_channels, H, W)``; H and W should equal
                ``res`` so the flattened features match ``disc_out``.
            C: Condition vectors, ``(batch, cond_dim)``.
            training: Optional Keras-style mode flag (the training code in
                this notebook passes it). If given, the module runs in that
                mode for this call only and its previous mode is restored.

        Returns:
            ``(logits, None)``: ``logits`` has shape ``(batch, 1)``; the second
            element (``reconst_x``) is always ``None``.
        """
        if training is None:
            return self._score(x, C)
        was_training = self.training
        self.train(training)
        try:
            return self._score(x, C)
        finally:
            self.train(was_training)

    def _score(self, x, C):
        """Run the conditional conv stack and linear head in the current mode."""
        # (B, c) -> (B, c, H, W), then stack as extra channels (NCHW: dim=1).
        C = C[:, :, None, None].expand(-1, -1, x.shape[2], x.shape[3])
        x = torch.cat([x, C], dim=1)

        for conv, bnorm, act in zip(self.cnn_conv, self.cnn_bnorm, self.cnn_act):
            x = act(bnorm(conv(x)))

        reconst_x = None
        x = self.disc_out(self.flat(x))

        return x, reconst_x

class DCGAN(nn.Module):
    """Container holding a ``Generator`` and a ``Discriminator``.

    Both sub-models are constructed with their default arguments and are
    exposed as ``self.gen`` and ``self.disc``.
    """

    def __init__(self):
        super().__init__()
        self.gen = Generator()
        self.disc = Discriminator()

    def forward(self, X, C):
        """Run the generator and the discriminator on the same ``X``.

        X:  Passed unchanged to both ``self.gen`` and ``self.disc``.
        C:  Condition forwarded to ``self.disc`` only.

        Returns a 2-tuple ``(self.gen(X), self.disc(X, C))``.

        NOTE(review): the original docstring described X as real or fake images
        (discriminator inputs) and C as EEG features concatenated with noise,
        which is what the generator consumes. Here X is also fed to the
        generator; confirm the intended wiring.
        """
        generated = self.gen(X)
        scored = self.disc(X, C)
        return generated, scored

In [None]:
def dist_train_step(mirrored_strategy, model, model_gopt, model_copt, X, C, latent_dim=96, batch_size=64):
    diff_augment_policies = "color,translation"
    noise_vector = torch.rand(batch_size, latent_dim, device=X.device) * 2 - 1
    noise_vector_2 = torch.rand(batch_size, latent_dim, device=X.device) * 2 - 1
    noise_vector = torch.cat([noise_vector, C], dim=-1)
    noise_vector_2 = torch.cat([noise_vector_2, C], dim=-1)

    def disc_hinge(D_real, D_fake):
        return (torch.mean(nn.ReLU()(1 - D_real)) + torch.mean(nn.ReLU()(1 + D_fake))) / 2

    def gen_hinge(D_fake):
        return -torch.mean(D_fake)

    def train_step_disc(model, model_gopt, model_copt, X, C, latent_dim=96, batch_size=64):
        model_copt.zero_grad()

        fake_img = model.gen(noise_vector)

        X_aug = TF.affine(TF.to_pil_image(X), *TF._get_inverse_affine_args(*TF._random_affine(TF.to_tensor(X), **TF.random_affine_params(0, 0.3, 0.1, 0.1, 5, False)), TF.to_pil_image(X).size, interpolation=TF.InterpolationMode.BILINEAR))
        X_aug = TF.to_tensor(X_aug).to(X.device)
        fake_img = TF.affine(TF.to_pil_image(fake_img), *TF._get_inverse_affine_args(*TF._random_affine(TF.to_tensor(fake_img), **TF.random_affine_params(0, 0.3, 0.1, 0.1, 5, False)), TF.to_pil_image(fake_img).size, interpolation=TF.InterpolationMode.BILINEAR))
        fake_img = TF.to_tensor(fake_img).to(X.device)

        D_real, X_recon = model.disc(X_aug, C, training=True)
        D_fake, _ = model.disc(fake_img, C, training=True)

        c_loss = disc_hinge(D_real, D_fake)

        c_loss.backward()
        model_copt.step()
        return c_loss.item()

    def train_step_gen(model, model_gopt, model_copt, X, C, latent_dim=96, batch_size=64):
        model_gopt.zero_grad()

        fake_img_o = model.gen(noise_vector)
        fake_img_2_o = model.gen(noise_vector_2)
        fake_img = TF.affine(TF.to_pil_image(fake_img_o), *TF._get_inverse_affine_args(*TF._random_affine(TF.to_tensor(fake_img_o), **TF.random_affine_params(0, 0.3, 0.1, 0.1, 5, False)), TF.to_pil_image(fake_img_o).size, interpolation=TF.InterpolationMode.BILINEAR))
        fake_img = TF.to_tensor(fake_img).to(X.device)
        fake_img_2 = TF.affine(TF.to_pil_image(fake_img_2_o), *TF._get_inverse_affine_args(*TF._random_affine(TF.to_tensor(fake_img_2_o), **TF.random_affine_params(0, 0.3, 0.1, 0.1, 5, False)), TF.to_pil_image(fake_img_2_o).size, interpolation=TF.InterpolationMode.BILINEAR))
        fake_img_2 = TF.to_tensor(fake_img_2).to(X.device)

        D_fake, _ = model.disc(fake_img, C, training=False)
        D_fake_2, _ = model.disc(fake_img_2, C, training=False)
        g_loss = gen_hinge(D_fake) + gen_hinge(D_fake_2)
        mode_loss = torch.mean(torch.abs(fake_img_2_o - fake_img_o)) / torch.mean(torch.abs(noise_vector_2 - noise_vector))
        mode_loss = 1.0 / (mode_loss + 1e-5)
        g_loss = g_loss + 1.0 * mode_loss

        g_loss.backward()
        model_gopt.step()
        return g_loss.item()

    model.train()
    model_gopt = optim.Adam(model.gen.parameters())
    model_copt = optim.Adam(model.disc.parameters())
    per_replica_loss_disc = mirrored_strategy.run(train_step_disc, args=(model, model_gopt, model_copt, X, C, latent_dim, batch_size,))
    per_replica_loss_gen = mirrored_strategy.run(train_step