# Tutorial
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

https://github.com/nashory/gans-awesome-applications

https://github.com/soumith/ganhacks


`ffmpeg -i notebooks/out.mp4 -i data/audio/sail.mp3 -c copy -map 0:v:0 -map 1:a:0 output.mp4`

`youtube-dl -f bestaudio --extract-audio --audio-format mp3 --audio-quality 0 --verbose tgIqecROs5M`

# Imports

In [None]:
import random
import numpy as np
import cv2
import subprocess

from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data

from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose
import torchvision.transforms as transforms
import torchvision.utils as vutils

import torchaudio
from torchaudio.transforms import MelSpectrogram, AmplitudeToDB

import matplotlib.pyplot as plt
import matplotlib.animation as animation

%matplotlib notebook

# device

In [None]:
# Use the first CUDA GPU when PyTorch reports one, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Seed

In [None]:
# manualSeed = 999
# random.seed(manualSeed)
# torch.manual_seed(manualSeed);

# Settings

In [None]:
# Folder passed to ImageFolder as the dataset root
dataroot = "../data"

# Worker processes used by the DataLoader
num_workers = 4

# Number of images per training batch
batch_size = 128

# Side length that every training image is resized and cropped to
#   by the preprocessing transform.
image_size = 64

# Length of the latent vector z (the generator's input channels)
z_size = 128

# Base number of feature maps in the generator
ngf = 64

# Base number of feature maps in the discriminator
ndf = 64

# Number of training epochs (the training-loop cell assigns its own value)
num_epochs = 5

# Learning rate shared by both Adam optimizers
lr = 2e-4

# beta1 coefficient shared by both Adam optimizers
beta1 = 0.5

# Data Loader

In [None]:
# Preprocessing applied to every image: resize, center-crop to
# image_size, convert to a tensor, then normalize each of the three
# channels with mean 0.5 and std 0.5.
transform = Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Images are read from the sub-folders of dataroot
dataset = ImageFolder(dataroot, transform=transform)

# Shuffled mini-batches served by num_workers loader processes
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

In [None]:
# Preview: draw up to 64 images from one training batch as a single grid
real_batch = next(iter(dataloader))
plt.figure(figsize=(8, 8))
plt.title("Training Images")
plt.axis("off")
plt.imshow(
    np.transpose(
        vutils.make_grid(
            real_batch[0].to(device)[:64],
            padding=2,
            normalize=True,
        ).cpu(),
        (1, 2, 0),  # channels-first grid -> channels-last for imshow
    )
);

# Shape Printing Helper Module

In [None]:
class PrintShape(nn.Module):
    """Pass-through layer that optionally prints the shape of its input.

    Meant to be dropped between the layers of an ``nn.Sequential`` to
    inspect intermediate tensor shapes; the tensor itself is returned
    unchanged.
    """

    def __init__(self,enabled = False):
        super().__init__()
        # When False the layer is completely silent
        self.enabled = enabled

    def forward(self,x):
        if not self.enabled:
            return x
        print(x.shape)
        return x

# Weight Initialization Function

In [None]:
# Custom weight initialization, applied to netG and netD via Module.apply
def weights_init(m):
    """Re-initialize ``m`` in place according to its class name.

    * class name contains "Conv": weights drawn from N(0.0, 0.02)
    * class name contains "BatchNorm": weights drawn from N(1.0, 0.02),
      bias set to 0
    * anything else is left untouched
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Generator

In [None]:
class Generator(nn.Module):
    """Transposed-convolution generator mapping a latent vector to an image.

    The stack reads the module-level ``z_size`` and ``ngf``. For a
    ``z_size x 1 x 1`` input the spatial size grows 1 -> 4 -> 8 -> 16
    -> 32 -> 64 and the final layer emits 3 channels squashed by Tanh.
    A ``PrintShape`` probe follows every convolution (and precedes the
    first one); ``print_enabled`` switches all of the probes on.
    """

    def __init__(self,print_enabled=False):
        super().__init__()

        # (in_channels, out_channels, stride, padding) for each
        # ConvTranspose2d -> BatchNorm2d -> ReLU stage; kernel size is 4.
        stages = (
            (z_size,  ngf * 8, 1, 0),   # -> (ngf*8) x 4 x 4
            (ngf * 8, ngf * 4, 2, 1),   # -> (ngf*4) x 8 x 8
            (ngf * 4, ngf * 2, 2, 1),   # -> (ngf*2) x 16 x 16
            (ngf * 2, ngf,     2, 1),   # -> (ngf) x 32 x 32
        )

        layers = [PrintShape(print_enabled)]
        for in_channels, out_channels, stride, padding in stages:
            layers.append(
                nn.ConvTranspose2d(in_channels, out_channels, 4, stride, padding, bias=False)
            )
            layers.append(PrintShape(print_enabled))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(True))

        # Output stage: 3 channels x 64 x 64, no batch norm
        layers.append(nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False))
        layers.append(PrintShape(print_enabled))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

In [None]:
# Shape check: with printing enabled, one forward pass on a random latent
# vector prints the tensor shape at every PrintShape probe.
g = Generator(print_enabled=True)
g(torch.rand((1, z_size, 1, 1)));

# Discriminator

In [None]:
class Discriminator(nn.Module):
    """Strided-convolution classifier ending in a Sigmoid.

    The stack reads the module-level ``ndf``. For a 3 x 64 x 64 input the
    spatial size shrinks 64 -> 32 -> 16 -> 8 -> 4 -> 1, leaving a single
    Sigmoid-squashed value per image. ``PrintShape`` probes sit at the
    input and after every stage; ``print_enabled`` switches them all on.
    """

    def __init__(self, print_enabled=False):
        super().__init__()

        # First stage has no batch norm: 3 x 64 x 64 -> (ndf) x 32 x 32
        layers = [
            PrintShape(print_enabled),
            nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            PrintShape(print_enabled),
        ]

        # Three downsampling stages, each doubling the channel count:
        # (ndf*2) x 16 x 16, (ndf*4) x 8 x 8, (ndf*8) x 4 x 4
        for multiplier in (1, 2, 4):
            in_channels = ndf * multiplier
            out_channels = in_channels * 2
            layers.extend([
                nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True),
                PrintShape(print_enabled),
            ])

        # Final 4x4 convolution collapses the map to one value per image
        layers.extend([
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
            PrintShape(print_enabled),
        ])

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        return self.main(input)

In [None]:
# Shape check: with printing enabled, one forward pass on a random
# 3 x 64 x 64 image prints the tensor shape at every PrintShape probe.
d = Discriminator(print_enabled=True)
d(torch.rand((1, 3, 64, 64)));

# Random Sample

In [None]:
def rand(n,z_size):
    """Sample ``n`` latent vectors with unit L2 norm.

    Each vector is drawn from a standard normal distribution and then
    divided by its L2 norm taken over dim 1. The result has shape
    ``(n, z_size, 1, 1)`` and is created on the module-level ``device``.
    """
    latent = torch.randn(n, z_size, 1, 1, device=device)
    lengths = latent.norm(p=2, dim=1, keepdim=True)
    return latent / lengths
    

# Training Setup

In [None]:

# State filled in by the training loop
img_list = []   # image grids generated from fixed_noise at intervals
G_losses = []   # generator loss, one entry per iteration
D_losses = []   # discriminator loss, one entry per iteration
epoch = 0       # kept outside the loop cell so training can be resumed


# Fresh networks on the selected device, re-initialized by weights_init
netG = Generator().to(device)
netG.apply(weights_init);

netD = Discriminator().to(device)
netD.apply(weights_init);


# Binary cross entropy: -( y*log(x) + (1-y)*log(1-x) )
criterion = nn.BCELoss()

# Fixed batch of latent vectors, reused at every snapshot so the
#  progression of the generator can be compared over time
fixed_noise = rand(64, z_size)

# Label convention for real and fake samples. These are floats on purpose:
# torch.full infers the tensor dtype from its fill value, so integer labels
# produce int64 targets, which nn.BCELoss rejects.
real_label = 1.0
fake_label = 0.0

# One Adam optimizer per network, sharing lr and beta1
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Training Loop

In [None]:

num_epochs = 25
iters = 0

print("Starting Training Loop...")
# `epoch` is defined outside this cell, so re-running the cell resumes from
# the epoch where the previous run stopped instead of starting over.
while epoch < num_epochs:
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        # Batch of real images (the class labels in data[1] are ignored)
        real_image_batch = data[0].to(device)

        # Batch size is read per batch because the last one may be smaller
        b,c,h,w = real_image_batch.shape

        # Target tensors. dtype is forced to float: torch.full infers the
        # dtype from the fill value, and BCELoss needs floating-point targets
        # even when real_label / fake_label are integers.
        real_label_batch = torch.full((b,), real_label, dtype=torch.float, device=device)
        fake_label_batch = torch.full((b,), fake_label, dtype=torch.float, device=device)


        ### GENERATE FAKE IMAGE BATCH ###
        # Unit-norm latent vectors, one per real image
        noise = rand(b, z_size)

        # Generate the fake images
        fake_image_batch = netG(noise)


        ### TRAIN DISCRIMINATOR ###
        # Score real and fake images; the fakes are detached so this step
        # does not backpropagate into the generator
        real_output_batch = netD( real_image_batch          ).view(-1)
        fake_output_batch = netD( fake_image_batch.detach() ).view(-1)

        D_x    = real_output_batch.mean().item()
        D_G_z1 = fake_output_batch.mean().item()

        # Loss for classifying reals as real and fakes as fake
        lossD_real = criterion(real_output_batch, real_label_batch)
        lossD_fake = criterion(fake_output_batch, fake_label_batch)
        lossD = lossD_real + lossD_fake

        # Accumulate the gradients of both terms for the discriminator
        netD.zero_grad()
        lossD_real.backward()
        lossD_fake.backward()

        # Step the optimizer for the discriminator
        optimizerD.step()


        ### TRAIN GENERATOR ###
        # Score the fake images again with the updated discriminator, this
        # time without detaching so gradients reach the generator
        fake_output_batch = netD( fake_image_batch ).view(-1)
        D_G_z2 = fake_output_batch.mean().item()

        # The generator is rewarded when its fakes are classified as real
        lossG = criterion(fake_output_batch, real_label_batch)

        # Accumulate gradients for the generator
        netG.zero_grad()
        lossG.backward()

        # Step the optimizer for the generator
        optimizerG.step()


        ### STATS ###
        # Output training stats every 50 batches
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     lossD.item(), lossG.item(), D_x, D_G_z1, D_G_z2))

        # Save losses for plotting later
        G_losses.append(lossG.item())
        D_losses.append(lossD.item())

        # Snapshot G's output on fixed_noise every 500 iterations and on the
        # very last batch of the last epoch
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
    epoch += 1
        
        
        
        

# Results

In [None]:
# Loss curves of both networks, one point per training iteration
plt.figure(figsize=(10, 5))
plt.plot(G_losses, label="G")
plt.plot(D_losses, label="D")
plt.title("Generator and Discriminator Loss During Training")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

# Imports

# Real Images vs. Fake Images

In [None]:
# One batch of real images for the left-hand panel
real_batch = next(iter(dataloader))

plt.figure(figsize=(15, 7))

# Left panel: grid of up to 64 real images
plt.subplot(1, 2, 1)
plt.title("Real Images")
plt.axis("off")
plt.imshow(
    np.transpose(
        vutils.make_grid(
            real_batch[0].to(device)[:64],
            padding=5,
            normalize=True,
        ).cpu(),
        (1, 2, 0),
    )
)

# Right panel: the most recent generator snapshot stored in img_list
plt.subplot(1, 2, 2)
plt.title("Fake Images")
plt.axis("off")
plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.show()

# Load Audio

In [None]:
class AudioLoader():
    """Pipeline stage that loads the file named by ``data_dict["audio_path"]``.

    Reads ``audio_path`` and writes ``audio_path`` (as a ``Path``),
    ``audio``, ``sample_rate`` and ``run_time`` back into the same dict,
    which is also returned.
    """

    def __call__(self,data_dict):
        # Normalize the path to a pathlib object before anything else
        path = Path(data_dict["audio_path"])
        data_dict["audio_path"] = path

        # Decode the file into a waveform tensor plus its sample rate
        waveform, rate = torchaudio.load(path)

        # Length of dim 1 divided by the sample rate
        # (presumably seconds, assuming dim 1 is the sample axis)
        duration = waveform.shape[1] / rate

        data_dict.update(audio=waveform, sample_rate=rate, run_time=duration)
        print(f"Path: {path},run_time: {duration}")
        return data_dict




# Spectrogram

In [None]:
class AudioToSpectrogram():
    """Pipeline stage that turns ``data_dict["audio"]`` into a mel spectrogram.

    Reads ``audio`` and ``sample_rate``; stores the result under
    ``spectrogram`` and prints its shape.
    """

    def __init__(self,n_fft,n_mels):
        # Both values are forwarded unchanged to MelSpectrogram
        self.n_fft, self.n_mels = n_fft, n_mels

    def __call__(self,data_dict):
        with torch.no_grad():
            # The transform is built per call because the sample rate
            # comes from the file being processed
            to_mel = MelSpectrogram(
                sample_rate=data_dict["sample_rate"],
                n_fft=self.n_fft,
                n_mels=self.n_mels,
            )
            mel = to_mel(data_dict["audio"])

            data_dict["spectrogram"] = mel
            print(mel.shape)

        return data_dict

    
class SpectrogramToZ():
    """Pipeline stage that converts a spectrogram into per-frame latent vectors.

    Reads ``run_time``, ``fps`` and ``spectrogram``; writes ``num_frames``
    and ``z``, where ``z`` is index 0 along the first dimension of the
    processed spectrogram.
    """

    def __call__(self,data_dict):
        with torch.no_grad():
            # One latent vector per video frame
            frame_count = round(data_dict["run_time"]*data_dict["fps"])
            data_dict["num_frames"] = frame_count

            # Average-pool the last axis down to exactly frame_count columns,
            # leaving the second-to-last axis at its original size
            pool = torch.nn.AdaptiveAvgPool2d((None, frame_count))
            pooled = pool(data_dict["spectrogram"])

            # Convert the amplitudes into decibel scale
            decibels = torchaudio.transforms.AmplitudeToDB()(pooled)

            # Zero-center along dim 2 (the mean is taken over that axis)
            centered = decibels - decibels.mean(dim=2, keepdim=True)

            # Scale so that every slice along dim 1 has unit L2 norm
            unit = centered / torch.norm(centered, p=2, dim=1, keepdim=True)

            data_dict["z"] = unit[0]

        return data_dict
    
#     plt.figure(figsize=(10,10))
#     plt.imshow(spectrogram[1,:,0:1000], aspect='auto')
    
#     spectrogram_length = spectrogram.shape[2]
#     print(spectrogram.shape)

# Generate Video

In [None]:
class GenerateVideo():
    """Pipeline stage that renders one generator image per latent column.

    Reads ``fps``, ``num_frames`` and ``z`` from the dict, writes the
    frames to ``temp.mp4`` in the working directory and records that path
    under ``temp_video_path``. Rendering is best-effort: a failure is
    reported on stdout and the dict is still returned.
    """

    def __init__(self,generator):
        # Callable mapping a (1, z, 1, 1) latent tensor to a batch of images
        self.generator = generator

    def __call__(self,data_dict):
        data_dict["temp_video_path"] = Path("temp.mp4")

        # Bound before the try so the finally clause can always test it,
        # even when creating the writer itself fails.
        writer = None
        try:
            fourcc = cv2.VideoWriter_fourcc(*"X264")
            writer = cv2.VideoWriter(str(data_dict["temp_video_path"]),fourcc,data_dict["fps"],(image_size,image_size))

            if not writer.isOpened():
                # An unopened writer drops every frame silently, so report it
                # and skip the (expensive) rendering.
                print(f"Warning: could not open a video writer for {data_dict['temp_video_path']}; no frames written")
            else:
                # Inference only: no autograd graph is needed
                with torch.no_grad():
                    for i in range(data_dict["num_frames"]):
                        # Column i of z becomes a (1, z, 1, 1) latent batch
                        z = data_dict["z"][:,i].reshape(1,-1,1,1).to(device)

                        fake = self.generator(z)[0]

                        # Map [-1, 1] onto [0, 255] and clamp before the cast:
                        # scaling by 128 sends +1.0 to 256, which wraps
                        # around to 0 in uint8.
                        # NOTE(review): assumes the generator output lies in
                        # [-1, 1], as the Tanh-terminated Generator in this
                        # file produces.
                        fake = (fake * 127.5 + 127.5).round().clamp(0, 255)
                        fake = fake.cpu().numpy().astype("uint8")

                        # channels-first RGB -> channels-last BGR for OpenCV
                        fake = fake.transpose(1,2,0)
                        fake = cv2.cvtColor(fake,cv2.COLOR_RGB2BGR)
                        writer.write(fake)
        except Exception as err:
            # Stay best-effort, but say what went wrong instead of hiding it.
            # KeyboardInterrupt / SystemExit are deliberately not caught.
            print(f"Warning: video generation failed: {err!r}")
        finally:
            if writer is not None:
                writer.release()
            cv2.destroyAllWindows()

        return data_dict


# Compose Video

In [None]:
class ComposeVideo():
    """Pipeline stage that muxes the rendered video with its audio via ffmpeg.

    Reads ``audio_path`` and ``temp_video_path`` from the dict and stores
    the destination under ``output_path``:
    ``<output_dir>/<audio file name without its last extension>.mp4``.
    The output directory is created when missing. An ffmpeg failure is
    reported on stdout and the dict is still returned.
    """

    def __init__(self,output_dir):
        self.output_dir = Path(output_dir)

    def __call__(self,data_dict):
        import shlex  # local import: only used to print the command safely quoted

        audio_path = Path(data_dict["audio_path"])
        video_path = str(data_dict["temp_video_path"])

        # Append ".mp4" to the stem instead of using with_suffix(), which
        # would cut a stem such as "song.v2" down to "song".
        output_path = self.output_dir / f"{audio_path.stem}.mp4"
        data_dict["output_path"] = output_path

        # ffmpeg does not create missing parent directories
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Argument list with no shell: paths containing spaces, quotes or
        # shell metacharacters reach ffmpeg unchanged.
        cmd = [
            "ffmpeg", "-y",
            "-i", video_path,
            "-i", str(audio_path),
            "-c", "copy",
            "-map", "0:v:0",
            "-map", "1:a:0",
            str(output_path),
        ]
        print(" ".join(shlex.quote(arg) for arg in cmd))

        try:
            return_code = subprocess.call(cmd)
        except OSError as err:
            # Typically: the ffmpeg executable is not on PATH
            print(f"Warning: could not run ffmpeg: {err}")
        else:
            if return_code != 0:
                print(f"Warning: ffmpeg exited with status {return_code}")

        return data_dict


In [None]:
# Audio-to-video pipeline: every stage receives the shared dict, adds its
# own keys and hands the dict on to the next stage.
pipeline = Compose(
    [
        AudioLoader(),
        AudioToSpectrogram(1024, z_size),   # n_fft=1024, n_mels=z_size
        SpectrogramToZ(),
        GenerateVideo(netG),
        ComposeVideo("../output_videos"),
    ]
)

# The same dict is reused for every file; each stage overwrites its keys
data_dict = dict(fps=30)
for audio_path in Path("../data/audio").glob("*.mp3"):
    data_dict.update(audio_path=audio_path)
    data_dict = pipeline(data_dict)