In [1]:
# 2022-09-02 16:23 Seoul

# --- import dataset ---
from utils.dataloader import mel_dataset
from utils.losses import *
from torch.utils.data import DataLoader, random_split

# --- import model ---
from model.supervised_model import *

# --- import framework ---
import flax 
import flax.linen as nn
from flax.training import train_state
import jax
import numpy as np
import jax.numpy as jnp
import optax

import cloudpickle
import argparse
from tqdm import tqdm
import os
import wandb
import matplotlib.pyplot as plt


# --- collate batch for dataloader ---
def collate_batch(batch):
    """Collate a list of ``(input, label)`` pairs into two NumPy arrays.

    Used as ``collate_fn`` for the DataLoaders so batches arrive as NumPy
    arrays rather than torch tensors.

    Args:
        batch: sequence of 2-tuples ``(x, y)`` yielded by the dataset.

    Returns:
        ``(inputs, labels)``, each built with ``np.array`` from the
        corresponding tuple elements, in the original order.
    """
    inputs = np.array([sample for sample, _label in batch])
    labels = np.array([label for _sample, label in batch])
    return inputs, labels




# Init state

In [2]:
# --- define init state ---
def init_state(model, x_shape, key, lr) -> train_state.TrainState:
    """Initialize ``model`` and wrap its variables in an Adam ``TrainState``.

    Args:
        model: Flax module whose ``__call__`` takes ``(x, rng)``; ``key`` is
            used both as the ``'params'`` init stream and as that ``rng`` argument.
        x_shape: shape of the dummy all-ones input used to trace the model.
        key: JAX PRNG key.
        lr: Adam learning rate.

    Returns:
        A ``TrainState`` whose ``params`` is the full variable dict returned by
        ``model.init`` (every collection, not only ``'params'``).
    """
    dummy_input = jnp.ones(x_shape)
    variables = model.init({'params': key}, dummy_input, key)
    adam = optax.adam(learning_rate=lr)
    return train_state.TrainState.create(apply_fn=model.apply, params=variables, tx=adam)

# Train step

In [8]:
# --- define train_step ---
@jax.jit
def train_step(state, x, z_rng):
    """Run one VAE optimization step on a batch.

    Loss = mean squared reconstruction error + mean of ``kl_divergence(mean, logvar)``
    (unweighted sum).

    The global ``model`` is captured when the function is first traced. Rebinding
    ``model`` afterwards does not retrace this jitted function for same-shaped inputs.

    Args:
        state: ``TrainState`` holding the model's variables.
        x: input batch.
        z_rng: PRNG key passed through to ``model.apply`` (presumably for the
            latent sample — TODO confirm against the model).

    Returns:
        ``(new_state, loss)``. ``loss`` is evaluated before the update is applied.
    """
    # NOTE(review): state.params appears to hold a 'batch_stats' collection next to
    # 'params'. Differentiating the whole dict lets Adam move running statistics —
    # confirm whether gradients should be restricted to the 'params' collection.
    def vae_loss(variables):
        recon, mu, log_var = model.apply(variables, x, z_rng)
        kl_term = jnp.mean(kl_divergence(mu, log_var))
        reconstruction = jnp.mean((recon - x) ** 2)
        return reconstruction + kl_term

    loss, grads = jax.value_and_grad(vae_loss)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, loss


# Eval step

In [4]:
# --- define eval step ---
@jax.jit
def eval_step(state, x, z_rng):
    """Evaluate the VAE on a batch without updating any variables.

    Args:
        state: ``TrainState`` holding the model's variables.
        x: input batch.
        z_rng: PRNG key passed through to the global ``model``'s ``apply``.

    Returns:
        ``(recon_x, loss, mse_loss, kld_loss)`` where ``loss = mse_loss + kld_loss``.
    """
    recon_x, mu, log_var = model.apply(state.params, x, z_rng)
    kld_loss = jnp.mean(kl_divergence(mu, log_var))
    mse_loss = jnp.mean((recon_x - x) ** 2)
    return recon_x, mse_loss + kld_loss, mse_loss, kld_loss

@jax.jit
def linear_eval_step(state, x, y):
    """Evaluate the linear probe on a batch of features.

    Args:
        state: ``TrainState`` holding the ``linear_evaluation`` parameters.
        x: features fed to the probe (presumably encoder latents, as in
            ``linear_train_step``).
        y: per-example targets for ``optax.softmax_cross_entropy``.
            NOTE(review): assumed one-hot. If rows are multi-hot, argmax accuracy
            only checks the first positive class.

    Returns:
        ``(loss, accuracy)``: mean softmax cross-entropy and top-1 accuracy.
    """
    logits = linear_evaluation().apply(state.params, x)
    loss = jnp.mean(optax.softmax_cross_entropy(logits, y))
    # Fixed: this previously compared against an undefined name `labels`, which
    # raised NameError when the function was traced. The targets are `y`.
    accuracy = jnp.mean(jnp.argmax(logits, -1) == jnp.argmax(y, -1))
    return loss, accuracy



In [3]:
from model.Conv2d_model import Conv2d_VAE, Encoder
import torch  # only to seed the train/test split and DataLoader shuffling

model = Conv2d_VAE(dilation=True)

# --- hyperparameters ---
batch_size = 16
lr = 0.0001

# One seed drives the JAX PRNG, the train/test split and the loaders' shuffling,
# so a Restart & Run All reproduces the same split.
SEED = 303
rng = jax.random.PRNGKey(SEED)
torch.manual_seed(SEED)  # DataLoader(shuffle=True) draws from torch's global RNG


# ---Load dataset---
dataset_dir = os.path.join(os.path.expanduser('~'),'dataset')

print("Loading dataset...")
data = mel_dataset(dataset_dir)
print(f'Loaded data : {len(data)}\n')

# 80/20 train/test split.
dataset_size = len(data)
train_size = int(dataset_size * 0.8)
test_size = dataset_size - train_size

# Without an explicit generator, random_split used torch's unseeded global RNG,
# so the split changed on every kernel restart despite the fixed JAX seed.
train_dataset, test_dataset = random_split(
    data, [train_size, test_size],
    generator=torch.Generator().manual_seed(SEED))

# Eval batches are a quarter of the train batch size. Integer division with a floor of 1
# avoids a zero batch size when batch_size < 4.
test_batch_size = max(1, batch_size // 4)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=0, collate_fn=collate_batch)

Loading dataset...
Load song_meta.json...


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 707989/707989 [00:00<00:00, 840478.87it/s]


Load complete!

Load file list...


111it [00:09, 11.21it/s]


Loaded data : 93389



In [5]:

print('Data load complete!\n')

# Fetch ONE sample batch and reuse it for both the summary table and the
# shape-based init. Each `next(iter(train_dataloader))` builds a fresh shuffled
# iterator and loads a whole batch.
sample_x, _ = next(iter(train_dataloader))
print(nn.tabulate(model, rngs={'params': rng})(sample_x, rng))


# ---initializing model---
print("Initializing model....")
state = init_state(model,
                   sample_x.shape,
                   rng,
                   lr)

print("Initialize complete!!\n")



Data load complete!






Initializing model....
Initialize complete!!



In [9]:
# ---train model---
# Smoke test: one optimization step and one eval step before a full training loop.

train_data = iter(train_dataloader)
test_data = iter(test_dataloader)

# Accumulators for a full training loop (not updated in this cell).
train_loss_mean = 0
test_loss_mean = 0

# Independent keys for the train and eval noise draws. Previously the split-off key
# was ignored and the carried `rng` was reused for both calls (and again in later cells).
rng, train_key, eval_key = jax.random.split(rng, 3)
x, y = next(train_data)
test_x, test_y = next(test_data)

# Constant shift applied to every mel input before it reaches the model.
# NOTE(review): presumably lifts log-mel (dB) values into a non-negative range —
# confirm against mel_dataset. Any later encoder call must see the same shift.
MEL_OFFSET = 100
x = x + MEL_OFFSET
test_x = test_x + MEL_OFFSET

state, train_loss = train_step(state, x, train_key)
recon_x, test_loss, mse_loss, kld_loss = eval_step(state, test_x, eval_key)

print('Pre train complete!\n\n\n')




Pre train complete!





In [11]:
# Inspect the reconstruction shape for the eval batch. eval_step's MSE subtracts it
# from test_x, so it is expected to match test_x.shape.
recon_x.shape

(4, 48, 1876)

In [14]:
def linear_init_state(model, x_shape, key, lr) -> train_state.TrainState:
    """Initialize a single-input model and wrap it in an Adam ``TrainState``.

    Differs from ``init_state`` only in that ``model.init`` receives just the
    dummy input (no extra rng argument), matching the ``linear_evaluation`` probe.

    Args:
        model: Flax module whose ``__call__`` takes only ``x``.
        x_shape: shape of the dummy all-ones input used to trace the model.
        key: PRNG key for the ``'params'`` init stream.
        lr: Adam learning rate.

    Returns:
        A ``TrainState`` holding the full variable dict from ``model.init``.
    """
    variables = model.init({'params': key}, jnp.ones(x_shape))
    adam = optax.adam(learning_rate=lr)
    return train_state.TrainState.create(apply_fn=model.apply, params=variables, tx=adam)


In [15]:
# Initialize the linear probe on a dummy batch shaped like the encoder latents.
# The batch dimension is tied to `batch_size` instead of a hard-coded 16, and the
# latent width is named (the encoder emits 512-d latents, per `latent.shape`).
LATENT_DIM = 512
# Dedicated init key, so the probe's init key is not also used as the encoder noise key.
rng, probe_init_key = jax.random.split(rng)
linear_state = linear_init_state(linear_evaluation(), (batch_size, LATENT_DIM), probe_init_key, lr)

In [13]:
# Pull the encoder's sub-trees out of the trained VAE variables so the encoder can be
# applied on its own for linear evaluation.
# NOTE: despite its name, `enc_state` holds raw parameters, not a TrainState.
enc_state, enc_batch = (state.params[collection]['encoder']
                        for collection in ('params', 'batch_stats'))

In [19]:
# Run only the trained encoder on the (offset) training batch to obtain latents for the probe.
# NOTE(review): Encoder() is built with its defaults, while the VAE was built as
# Conv2d_VAE(dilation=True). If `dilation` is forwarded to the encoder, these weights may
# be applied with a different conv configuration — confirm the constructor args match.
latent, mean, logvar = Encoder().apply({'params':enc_state, 'batch_stats':enc_batch}, x, rng)

In [20]:
# Encoder latent shape for the training batch.
latent.shape

(16, 512)

In [21]:
# Forward the latents through the freshly initialized (not yet trained) probe.
logits = linear_evaluation().apply(linear_state.params, latent)

In [24]:
# Label batch shape: one row of 30 class targets per example.
y.shape

(16, 30)

In [23]:
# Probe output shape; must match y for softmax_cross_entropy.
logits.shape

(16, 30)

In [25]:
# Baseline cross-entropy of the untrained probe. For reference, ln(30) is about 3.40,
# the loss of a uniform guess over 30 one-hot classes.
# NOTE(review): softmax_cross_entropy expects each row of y to be a probability
# distribution — confirm y is one-hot rather than multi-hot.
loss = jnp.mean(optax.softmax_cross_entropy(logits, y))

In [26]:
# Display the baseline probe loss.
loss

DeviceArray(3.823896, dtype=float32)

In [33]:
@jax.jit
def linear_train_step(enc_state,
                      enc_batch,
                      linear_state,
                      x, y, rng):
    """One optimization step of the linear probe on top of a frozen encoder.

    Args:
        enc_state: encoder ``'params'`` sub-tree taken from the trained VAE.
        enc_batch: encoder ``'batch_stats'`` sub-tree taken from the trained VAE.
        linear_state: ``TrainState`` of the ``linear_evaluation`` probe.
        x: input batch (must use the same preprocessing as VAE training).
        y: targets for ``optax.softmax_cross_entropy``.
        rng: PRNG key forwarded to ``Encoder.apply``. If the encoder samples its
            latent with it, the probe trains on sampled rather than mean latents —
            TODO confirm which is intended.

    Returns:
        ``(new_linear_state, loss)``. ``loss`` is evaluated before the update is applied.
    """
    # The latent is computed outside loss_fn, so gradients flow only into the
    # probe's params and the encoder stays frozen.
    # NOTE(review): Encoder() uses its default constructor args — confirm they match the
    # encoder inside Conv2d_VAE(dilation=True).
    latent, _mean, _logvar = Encoder().apply(
        {'params': enc_state, 'batch_stats': enc_batch}, x, rng)

    def loss_fn(params):
        logits = linear_evaluation().apply(params, latent)
        return jnp.mean(optax.softmax_cross_entropy(logits, y))

    loss, grads = jax.value_and_grad(loss_fn)(linear_state.params)
    # An accuracy value used to be computed here but was never returned (dead code);
    # use linear_eval_step for accuracy.
    return linear_state.apply_gradients(grads=grads), loss


In [34]:
# One probe update on the same training batch. The encoder stays frozen; the returned
# loss is the one computed before the update is applied.
linear_state, loss = linear_train_step(enc_state, enc_batch, linear_state, x, y, rng)

In [37]:
# Pre-update loss from the probe step above (value_and_grad runs before apply_gradients).
loss

DeviceArray(3.8246925, dtype=float32)