In [None]:
import torch 
import torch.nn as nn
import torchvision
import torchvision.transforms as tf
import matplotlib.pyplot as plt
import PIL
from google.colab import drive
import scipy.io as sio
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from torchvision.utils import make_grid

# Dataset Loaders

In [None]:
# Mount Google Drive so the project folders below are reachable from the Colab runtime.
drive.mount('/content/drive')
# Inputs (.mat spectrogram files) are read from Data/; checkpoints and generated images go to Results.
data_path = '/content/drive/My Drive/DeepLearning_2021/Final_project/proyecto/Data/'
results_path = '/content/drive/My Drive/DeepLearning_2021/Final_project/proyecto/Results'

**Spectrogram data loader definition**

In [None]:
#@title SPECTOGRAM DATASET CLASS
class SPECTOGRAM_128(torch.utils.data.Dataset):
    """Spectrogram dataset backed by a MATLAB ``.mat`` file.

    The file must provide a variable ``X`` whose last (4th) axis indexes the
    samples; sample ``k`` is the 2-D slice ``X[:, :, 0, k]``. A variable
    ``label`` is also loaded, but items are always returned with label ``1``.

    Args:
        dataDir: Path to the ``.mat`` file, read with ``scipy.io.loadmat``.
        transform: Optional callable applied to each 2-D slice before it is returned.
    """

    def __init__(self, dataDir, transform=None):
        contents = sio.loadmat(dataDir)
        self.data = contents['X']
        # Kept for reference; __getitem__ does not use it.
        self.labels = contents['label']
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(slice, 1)`` for sample ``index``, transformed if a transform was given."""
        spectrogram = self.data[:, :, 0, index]
        if self.transform is None:
            return spectrogram, 1
        return self.transform(spectrogram), 1

    def __len__(self):
        """Number of samples: the size of axis 3 of ``X``."""
        return self.data.shape[3]

**Visualization functions**

In [None]:
#@title Visualize image grid
def show_spectograms(image_batch):
    """Plot a batch of images as a single grid with ``plt.imshow``.

    Images are tiled 8 per row with 1-pixel padding by ``make_grid``. For
    single-channel input, ``make_grid`` replicates the channel, so averaging over
    the channel axis recovers one 2-D array that matplotlib draws with its
    default colormap.

    Args:
        image_batch: Batch accepted by ``torchvision.utils.make_grid``; must be on CPU
            because it is converted with ``.numpy()``.
    """
    grid = make_grid(image_batch, nrow=8, padding=1)
    # (C, H, W) -> (H, W); squeeze drops any further size-1 axes.
    gray = grid.mean(dim=0).detach().numpy().squeeze()
    plt.imshow(gray, interpolation='nearest')

In [None]:
#@title Load dataset
# Per-sample preprocessing: ToTensor turns each 2-D slice into a 1xHxW tensor
# (values are scaled to [0, 1] only when the array is uint8).
# NOTE(review): Normalize(0, 1) subtracts 0 and divides by 1, so it is a no-op.
tr = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0, 1)])
disco_pop_128 = SPECTOGRAM_128(data_path + 'Disco_Pop_128.mat', tr)

# Two shuffling loaders over the SAME dataset; only the batch size differs.
# NOTE(review): test_loader is not a held-out split — confirm this is intended.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=disco_pop_128, batch_size=size, shuffle=True)
    for size in (64, 32)
)


# GAN

In [None]:
# Discriminator similar to VAE encoder
class Discriminator(nn.Module):
  """Real/fake classifier built on the project's ``Encoder``.

  Args:
    base_channels: Channel width forwarded to ``Encoder`` as its second argument.
  """
  def __init__(self, base_channels=16):
    super().__init__()
    # NOTE(review): Encoder(1, ...) presumably yields one output (the real/fake score)
    # per image — confirm against Encoder's definition.
    self.classifier = Encoder(1, base_channels)

  def forward(self, x):
    """Return the sigmoid of the encoder output, read as P(image is real)."""
    scores = self.classifier(x)
    return scores.sigmoid()

# Generator is defined as VAE decoder
class Generator(nn.Module):
  """Maps latent vectors to images through the project's ``Decoder``.

  Args:
    in_features: Dimension of the latent vector ``z``.
    base_channels: Channel width forwarded to ``Decoder``.
  """
  def __init__(self, in_features, base_channels=16):
    super(Generator, self).__init__()
    self.base_channels = base_channels
    self.in_features = in_features
    self.decoder = Decoder(in_features, base_channels)

  def forward(self, z):
    """Decode ``z`` and squash the result into (0, 1) with a sigmoid."""
    return torch.sigmoid(self.decoder(z))

  def sample(self, n_samples=256, device='cpu'):
    """Decode ``n_samples`` latent vectors drawn from a standard normal.

    Goes through ``forward`` so sampled images get the same sigmoid output range
    as ``forward``; previously ``sample`` called the decoder directly and skipped
    the sigmoid. Checkpoints trained with the old behaviour may need retraining.

    Args:
      n_samples: Number of images to generate.
      device: Device on which the latent vectors are created.

    Returns:
      The generator output for a batch of ``n_samples`` latent vectors.
    """
    z = torch.randn(n_samples, self.in_features, device=device)
    return self(z)

In [None]:
def train_GAN(gen, disc,  train_loader, optimizer_gen, optim_disc,
              num_epochs=10, model_name='gan_disco_pop.ckpt', device='cpu',
              save_dir=None):
    """Train a GAN, alternating one generator step and one discriminator step.

    Both losses are computed on every batch. Within an epoch the generator is
    updated on batches 0, 2, 4, ... and the discriminator on batches 1, 3, 5, ...
    The generator's ``state_dict`` is saved after every epoch.

    Args:
        gen: Generator exposing ``sample(n_samples, device=...)``.
        disc: Discriminator mapping a batch of images to probabilities of being real.
        train_loader: Sized iterable of batches. Each batch is either an image tensor
            or an ``(images, labels)`` pair; labels are ignored.
        optimizer_gen: Optimizer over ``gen``'s parameters.
        optim_disc: Optimizer over ``disc``'s parameters.
        num_epochs: Number of passes over ``train_loader``.
        model_name: File name of the generator checkpoint.
        device: Device the models and batches are moved to.
        save_dir: Directory for the checkpoint; ``None`` uses the notebook's ``results_path``.

    Returns:
        List holding the mean discriminator loss of each epoch.
    """
    if save_dir is None:
        save_dir = results_path

    def _safe_log(p):
        # Clamp to the smallest positive normal float so an exact 0 (e.g. a saturated
        # sigmoid, or 1 - p rounding to 0) gives a finite loss instead of inf/NaN;
        # every other value passes through unchanged.
        return torch.log(p.clamp(min=torch.finfo(p.dtype).tiny))

    gen = gen.to(device)
    gen.train()
    disc = disc.to(device)
    disc.train()

    total_step = len(train_loader)
    losses_list = []

    for epoch in range(num_epochs):
        disc_loss_avg = 0.0
        gen_loss_avg = 0.0
        nBatches = 0
        update_generator = True

        for i, batch in enumerate(train_loader):
            # SPECTOGRAM_128 yields (image, label) pairs, so a DataLoader batch is an
            # [images, labels] list; the labels play no role in GAN training.
            real_images = batch[0] if isinstance(batch, (list, tuple)) else batch
            # NOTE(review): ToTensor already maps uint8 data to [0, 1]; this extra
            # division only fits data stored in [0, 255] as a non-uint8 dtype — confirm.
            real_images = real_images.float().to(device) / 255
            n_images = real_images.shape[0]

            # Forward pass: a fake batch of the same size, then discriminator scores.
            fake_images = gen.sample(n_images, device=device)
            prob_real = disc(real_images)
            prob_fake = disc(fake_images)

            gen_loss = -_safe_log(prob_fake).mean()
            disc_loss = -0.5 * (_safe_log(prob_real) + _safe_log(1 - prob_fake)).mean()

            if update_generator:
                # gen_loss also back-propagates into disc; those gradients are cleared
                # by optim_disc.zero_grad() before the next discriminator step.
                optimizer_gen.zero_grad()
                gen_loss.backward()
                optimizer_gen.step()
            else:
                # disc_loss also back-propagates into gen through fake_images; those
                # gradients are cleared by optimizer_gen.zero_grad() before the next
                # generator step.
                optim_disc.zero_grad()
                disc_loss.backward()
                optim_disc.step()
            update_generator = not update_generator

            disc_loss_avg += disc_loss.item()
            gen_loss_avg += gen_loss.item()
            nBatches += 1
            if (i + 1) % 200 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Gen. Loss: {:.4f}, Disc Loss: {:.4f}'
                      .format(epoch + 1, num_epochs, i + 1, total_step,
                              gen_loss_avg / nBatches, disc_loss_avg / nBatches))

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        denom = max(nBatches, 1)
        print('Epoch [{}/{}], Step [{}/{}], Gen. Loss: {:.4f}, Disc Loss: {:.4f}'
              .format(epoch + 1, num_epochs, nBatches, total_step,
                      gen_loss_avg / denom, disc_loss_avg / denom))
        losses_list.append(disc_loss_avg / denom)
        # Save the generator that was passed in (not a notebook-global model).
        torch.save(gen.state_dict(), f'{save_dir}/{model_name}')

    return losses_list

## Training a GAN

In [None]:
# Build the generator (32-dimensional latent input) and the discriminator.
gan_gen, gan_disc = Generator(32), Discriminator()

# One independent Adam optimizer per network, sharing the same hyper-parameters.
learning_rate = 5e-4
optimizer_gen = torch.optim.Adam(gan_gen.parameters(), lr=learning_rate, weight_decay=1e-5)
optimizer_disc = torch.optim.Adam(gan_disc.parameters(), lr=learning_rate, weight_decay=1e-5)

# Train on the first GPU when available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
loss_list = train_GAN(gan_gen, gan_disc, train_loader, optimizer_gen, optimizer_disc,
                      num_epochs=20, model_name='gan_disco_pop.ckpt', device=device)

## Visualize synthetic images

In [None]:
# Load the trained generator; map_location lets a GPU-trained checkpoint load on a CPU-only runtime.
gan_gen = Generator(32)
gan_gen.load_state_dict(torch.load(results_path+'/gan_disco_pop.ckpt',map_location=torch.device('cpu')))
gan_gen.eval()  # switch to evaluation mode before sampling
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
gan_gen = gan_gen.to(device)

# Generate 64 images from random latent vectors and show them as a grid.
# Sampling is inference only, so skip building an autograd graph.
with torch.no_grad():
    x_gen = gan_gen.sample(64, device=device)
show_spectograms(x_gen.cpu())

# Save the images

In [None]:
# Write every generated image to Results/ as an 8-bit grayscale PNG (Image is imported in the first cell).
for i, sample in enumerate(x_gen.detach().cpu().numpy()):
    data = sample.squeeze()
    # Min-max rescale to 0-255. Dividing by the range (not by data.max()) keeps every
    # value inside uint8 even when data has negative entries, which would otherwise wrap.
    lo, hi = data.min(), data.max()
    span = hi - lo
    if span > 0:
        rescaled = np.round((data - lo) / span * 255.0).astype(np.uint8)
    else:
        # Constant image: there is no contrast to stretch, and dividing by span would fail.
        rescaled = np.zeros(data.shape, dtype=np.uint8)
    # A 2-D uint8 array is stored by Pillow as mode 'L' (8-bit grayscale).
    im = Image.fromarray(rescaled)
    im.save(results_path + f'/test_{i}.png')