<a href="https://colab.research.google.com/github/Ahtesham519/Genrative_Deep_learning_v2_2023/blob/main/Musegan_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt

plt.style.use("seaborn-v0_8-colorblind")

import os
import tensorflow as tf
from tensorflow.keras import (
    layers,
    models,
    optimizers,
    callbacks,
    initializers,
    metrics,
)


# 0. Parameters

In [None]:
# --- Data batching ----------------------------------------------------------
BATCH_SIZE = 64

# --- Shape of one training example ------------------------------------------
# Number of bars kept per chorale, and time steps per bar.
N_BARS, N_STEPS_PER_BAR = 2, 16
# Largest pitch index; also the fill value used for NaN entries during
# preprocessing. One-hot vectors therefore need MAX_PITCH + 1 slots.
MAX_PITCH = 83
N_PITCHES = 1 + MAX_PITCH
# Width of each latent input vector fed to the generator networks.
Z_DIM = 32

# --- Optimisation settings --------------------------------------------------
CRITIC_STEPS = 5
GP_WEIGHT = 10
CRITIC_LEARNING_RATE = 0.001
GENERATOR_LEARNING_RATE = 0.001
ADAM_BETA_1, ADAM_BETA_2 = 0.5, 0.9
EPOCHS = 6000
LOAD_MODEL = False


# 1. Prepare the Data

In [None]:
# Load the data.
# os.path.join is given the individual path components; the original passed a
# single pre-joined string, which made the call a no-op.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
# objects stored in the archive -- only point this at a trusted file.
# encoding="bytes" presumably exists because the archive holds object arrays
# pickled under Python 2 -- confirm.
file = os.path.join("/app/data", "bach-chorales", "Jsb16thSeparated.npz")
with np.load(file, encoding="bytes", allow_pickle=True) as f:
  # Only the "train" entry of the archive is used.
  data = f["train"]

In [None]:
# Inspect the dataset: how many chorales, and what the first one looks like.
N_SONGS = len(data)
print(f"{N_SONGS} chorales in the dataset")

chorale = data[0]
# Unpacking fails with ValueError unless chorale 0 is 2-D. Rows index time
# steps (the axis sliced later); columns index tracks.
# NOTE(review): N_BEATS counts rows; given "16th" in the file name these are
# presumably sixteenth-note steps rather than beats -- confirm.
N_BEATS, N_TRACKS = chorale.shape
# The f-string expression is a tuple, so this prints "(rows, cols) shape ...".
print(f"{N_BEATS, N_TRACKS} shape of chorale 0")
# Header line followed by the first eight rows.
print("\nChorale 0", chorale[:8], sep="\n")

In [None]:
# Keep only the opening N_BARS bars (N_BARS * N_STEPS_PER_BAR rows) of every
# chorale.
# NOTE(review): a chorale shorter than that would make this list ragged.
two_bars = np.array([song[: N_BARS * N_STEPS_PER_BAR] for song in data])
# NaN cells (presumably rests / silent tracks -- confirm) become MAX_PITCH so
# the whole array can be cast to integer pitch indices.
two_bars = np.nan_to_num(two_bars, nan=MAX_PITCH).astype(int)
# Split the flat time axis into (bar, step-within-bar).
two_bars = two_bars.reshape([N_SONGS, N_BARS, N_STEPS_PER_BAR, N_TRACKS])
print(f"Two bars shape {two_bars.shape}")

In [None]:
# One-hot encode each pitch index over N_PITCHES classes, then rescale
# {0, 1} -> {-1, 1} (presumably to match the generator's tanh output range).
# Shape: (songs, bars, steps, tracks, pitches).
data_binary = 2.0 * np.eye(N_PITCHES)[two_bars] - 1.0
# Swap the last two axes -> (songs, bars, steps, pitches, tracks).
data_binary = data_binary.swapaxes(3, 4)
print(f"Data binary shape{data_binary.shape}")

# 2. Build the GAN

In [None]:
# Some helper functions

# Weight initialiser shared by the conv / conv_t helpers below:
# Gaussian with mean 0 and standard deviation 0.02.
initializer = initializers.RandomNormal(stddev=0.02, mean=0.0)

def conv(x, f, k, s, p):
  """Apply a Conv3D layer followed by a LeakyReLU activation.

  Args:
    x: input Keras tensor.
    f: number of convolution filters.
    k: kernel size.
    s: strides.
    p: padding mode passed to Conv3D (e.g. "same" or "valid").

  Returns:
    The LeakyReLU-activated output tensor.
  """
  conv_layer = layers.Conv3D(
      filters=f,
      kernel_size=k,
      strides=s,
      padding=p,
      kernel_initializer=initializer,
  )
  activation = layers.LeakyReLU()
  return activation(conv_layer(x))

def conv_t(x, f, k, s, a, p, bn):
  """Apply Conv2DTranspose, optional BatchNormalization, then an activation.

  Args:
    x: input Keras tensor.
    f: number of filters of the transposed convolution.
    k: kernel size.
    s: strides.
    a: activation name passed to layers.Activation (e.g. "relu", "tanh").
    p: padding mode (e.g. "same" or "valid").
    bn: if truthy, insert BatchNormalization (momentum 0.9) before the
      activation.

  Returns:
    The output tensor of the final activation layer.
  """
  # Layers are created in the same order they are applied.
  pipeline = [
      layers.Conv2DTranspose(
          filters=f,
          kernel_size=k,
          strides=s,
          padding=p,
          kernel_initializer=initializer,
      )
  ]
  if bn:
    pipeline.append(layers.BatchNormalization(momentum=0.9))
  pipeline.append(layers.Activation(a))

  for layer in pipeline:
    x = layer(x)
  return x



In [None]:
def TemporalNetwork():
  """Build a model that expands one latent vector into one vector per bar.

  Input:  tensor of shape (Z_DIM,), named "temporal_input".
  Output: tensor of shape (N_BARS, Z_DIM).
  """
  z = layers.Input(shape=(Z_DIM,), name="temporal_input")
  # View the vector as a 1x1 grid with Z_DIM channels so transposed
  # convolutions can grow it along the first (bar) axis.
  h = layers.Reshape([1, 1, Z_DIM])(z)
  # Valid padding, stride 1: height 1 -> 1 + 2 - 1 = 2.
  h = conv_t(h, f=1024, k=(2, 1), s=(1, 1), p="valid", a="relu", bn=True)
  # Height 2 -> 2 + (N_BARS - 1) - 1 = N_BARS.
  # NOTE(review): N_BARS == 1 would give a zero-height kernel; presumably
  # N_BARS >= 2 is required.
  h = conv_t(
      h, f=Z_DIM, k=(N_BARS - 1, 1), s=(1, 1), p="valid", a="relu", bn=True
  )
  per_bar = layers.Reshape([N_BARS, Z_DIM])(h)
  return models.Model(z, per_bar)

TemporalNetwork().summary()

In [None]:
def BarGenerator():
  """Build a model mapping a (4 * Z_DIM)-vector to one bar of piano roll.

  Input:  tensor of shape (Z_DIM * 4,), named "bar_generator_input".
  Output: tensor of shape (1, N_STEPS_PER_BAR, N_PITCHES, 1), squashed by a
    final tanh.

  NOTE(review): the strides below hard-code 2*2*2*2 = 16 time steps and
  7*12 = 84 pitches; they must equal N_STEPS_PER_BAR and N_PITCHES or the
  final Reshape fails.
  """
  z = layers.Input(shape=(Z_DIM * 4,), name="bar_generator_input")

  h = layers.Dense(1024)(z)
  h = layers.BatchNormalization(momentum=0.9)(h)
  h = layers.Activation("relu")(h)
  # 1024 units -> grid of (time=2, pitch=1) with 512 channels.
  h = layers.Reshape([2, 1, 512])(h)

  # Double the time axis three times: 2 -> 4 -> 8 -> 16.
  for n_filters in (512, 256, 256):
    h = conv_t(h, f=n_filters, k=(2, 1), s=(2, 1), p="same", a="relu", bn=True)
  # Grow the pitch axis: 1 -> 7 -> 84, ending in a single tanh channel.
  h = conv_t(h, f=256, k=(1, 7), s=(1, 7), p="same", a="relu", bn=True)
  h = conv_t(h, f=1, k=(1, 12), s=(1, 12), p="same", a="tanh", bn=False)

  bar = layers.Reshape([1, N_STEPS_PER_BAR, N_PITCHES, 1])(h)
  return models.Model(z, bar)

BarGenerator().summary()

In [None]:
def Generator():
  chords_input = layers.Input(shape = (Z_DIM , ), name="chords_input")
  style_input = layers.Input(shape = (Z_DIM , ), name = "style_input")
  melody_input = layers.Input(shape = (N_TRACKS , Z_DIM ), name = "melody_input" )
  groove_input = layers.Input(shape = (N_TRACKS, Z_DIM ), name = "groove_input")

  #CHORDS -> TEMPORAL NETWORK
  chords_tempNetwork = TemporalNetwork()
  chords_over_time = chords_tempNetwork(chords_input) #

  #MELODY -> TEMPORAL NETWORK
  melody_over_time = [
      None
  ] * N_TRACKS
  melody_tempNetwork = [None] * N_TRACKS
  for track in range(N_TRACKS):
    melody_tempNetwork[track] = TemporalNetwork()
    melody_track = layers.Lambda(lambda x, track = track : x[:, track, :])(
        melody_input
    )
    melody_over_time[track] = melody_tempNetwork[track](melody_track)