In [22]:
# 2022-09-02 16:23 Seoul

# --- import dataset ---
from utils.dataloader import mel_dataset
from utils.losses import *
from torch.utils.data import DataLoader, random_split

# --- import model ---
from model.supervised_model import *
from model.Conv1d_model import Conv1d_VAE
# from model.Conv1d_model import Encoder as Encoder1d

from model.Conv2d_model import Conv2d_VAE
from model.Conv2d_model import Encoder

# --- import framework ---
import flax
import flax.linen as nn
from flax.training import train_state
import jax
import numpy as np
import jax.numpy as jnp
import optax

import cloudpickle
import argparse
from functools import partial
from tqdm import tqdm
import os
import wandb
import matplotlib.pyplot as plt

In [2]:
# --- collate batch for dataloader ---
def collate_batch(batch):
    """Stack a list of ``(x, y)`` samples into two NumPy arrays.

    Used as the ``collate_fn`` of the DataLoaders so that batches come out as
    NumPy arrays rather than torch tensors.

    Args:
        batch: sequence of ``(x, y)`` pairs.

    Returns:
        ``(np.ndarray of all x, np.ndarray of all y)``, in input order.
    """
    inputs, targets = [], []
    for sample_x, sample_y in batch:
        inputs.append(sample_x)
        targets.append(sample_y)

    return np.array(inputs), np.array(targets)



# --- define init state ---
def init_state(model, x_shape, key, lr) -> train_state.TrainState:
    """Create a fresh ``TrainState`` for ``model`` with an Adam optimizer.

    Args:
        model: module exposing ``init`` and ``apply``.
        x_shape: shape of the all-ones dummy input used to initialise the model.
        key: PRNG key; used both as the ``'params'`` rng and as the extra
            positional argument passed to ``model.init`` after the input.
        lr: Adam learning rate.

    Returns:
        A ``TrainState`` whose ``params`` field holds everything returned by
        ``model.init`` (the whole variable collection, not only ``'params'``).
    """
    dummy_input = jnp.ones(x_shape)
    variables = model.init({'params': key}, dummy_input, key)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables,
        tx=optax.adam(learning_rate=lr),
    )




# --- define train_step ---
@jax.jit
def train_step(state, x, z_rng):
    """Perform one gradient update of the VAE on batch ``x``.

    The forward pass goes through the module-level ``model`` (not
    ``state.apply_fn``). The loss is the mean squared reconstruction error
    plus the mean of ``kl_divergence(mean, logvar)``.

    Args:
        state: train state whose ``params`` are passed to ``model.apply``.
        x: input batch; also the reconstruction target.
        z_rng: PRNG key forwarded to ``model.apply`` (presumably for latent
            sampling -- confirm against the model).

    Returns:
        ``(updated_state, loss)``.
    """
    # NOTE(review): elsewhere in this notebook `state.params` is indexed with
    # both 'params' and 'batch_stats', so gradients here appear to be taken
    # with respect to the batch statistics as well -- confirm this is intended.
    def loss_fn(params):
        reconstruction, mu, log_var = model.apply(params, x, z_rng)
        reconstruction_term = ((reconstruction - x)**2).mean()
        divergence_term = kl_divergence(mu, log_var).mean()
        return reconstruction_term + divergence_term

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    updated_state = state.apply_gradients(grads=grads)

    return updated_state, loss

@partial(jax.jit, static_argnames=('encoder',))
def linear_train_step(encoder,
                      enc_state,
                      enc_batch,
                      linear_state,
                      x, y):
    """One optimisation step of the linear probe on top of a fixed encoder.

    The encoder is applied outside the differentiated function, so only the
    linear head's parameters receive gradients.

    Args:
        encoder: encoder module. It is marked static for ``jax.jit`` because a
            module object is not an array pytree and cannot be traced; it
            therefore has to be hashable.
        enc_state: the encoder's ``'params'`` collection.
        enc_batch: the encoder's ``'batch_stats'`` collection.
        linear_state: train state of the linear head.
        x: input batch fed to the encoder.
        y: targets given to ``optax.softmax_cross_entropy`` (assumed to be
            one-hot / probability vectors -- confirm against the dataset).

    Returns:
        ``(updated_linear_state, loss)``.
    """
    # NOTE(review): other cells in this notebook unpack three values from
    # `Encoder2d().apply(..., x, rng)`. If `encoder` is that module, this call
    # also needs the rng argument and tuple unpacking -- confirm.
    latent = encoder.apply({'params': enc_state, 'batch_stats': enc_batch}, x)

    def loss_fn(params):
        logits = linear_evaluation().apply(params, latent)
        return jnp.mean(optax.softmax_cross_entropy(logits, y))

    loss, grads = jax.value_and_grad(loss_fn)(linear_state.params)

    # Update the linear head's own state (the original referenced an
    # unrelated global `state` and an undefined `labels`).
    return linear_state.apply_gradients(grads=grads), loss



# --- define eval step ---
@jax.jit
def eval_step(state, x, z_rng):
    """Evaluate the VAE on batch ``x`` without updating any parameters.

    Uses the module-level ``model`` for the forward pass and computes the
    same loss as the training step: mean squared reconstruction error plus
    the mean of ``kl_divergence(mean, logvar)``.

    Args:
        state: train state whose ``params`` are passed to ``model.apply``.
        x: input batch; also the reconstruction target.
        z_rng: PRNG key forwarded to ``model.apply``.

    Returns:
        ``(recon_x, loss, mse_loss, kld_loss)``.
    """
    reconstruction, mu, log_var = model.apply(state.params, x, z_rng)

    squared_error = (reconstruction - x)**2
    mse_term = squared_error.mean()
    kld_term = kl_divergence(mu, log_var).mean()
    total = mse_term + kld_term

    return reconstruction, total, mse_term, kld_term

@jax.jit
def linear_eval_step(state, x, y):
    """Evaluate the linear probe on a batch of inputs.

    Args:
        state: train state of the linear head; ``state.params`` is passed to
            ``linear_evaluation().apply``.
        x: inputs to the linear head.
        y: targets; compared via ``argmax`` over the last axis, so they are
            assumed to be one-hot / probability vectors -- confirm.

    Returns:
        ``(loss, accuracy)``: mean softmax cross-entropy and the fraction of
        rows whose predicted class matches the target class.
    """
    logits = linear_evaluation().apply(state.params, x)
    loss = jnp.mean(optax.softmax_cross_entropy(logits, y))
    # Compare against the `y` argument; the original used an undefined `labels`.
    accuracy = jnp.mean(jnp.argmax(logits, -1) == jnp.argmax(y, -1))

    return loss, accuracy



In [None]:
# --- hyperparameters ---
batch_size = 16
lr = 0.0001
dilation = True


model = Conv2d_VAE(dilation=dilation)
rng = jax.random.PRNGKey(303)


# ---Load dataset---
dataset_dir = os.path.join(os.path.expanduser('~'),'dataset')

print("Loading dataset...")
data = mel_dataset(dataset_dir)
print(f'Loaded data : {len(data)}\n')

# 80 / 20 train / test split.
dataset_size = len(data)
train_size = int(dataset_size * 0.8)
test_size = dataset_size - train_size

train_dataset, test_dataset = random_split(data, [train_size, test_size])

# The test loader uses a quarter of the training batch size. Integer division
# keeps it an int, and the floor of 1 keeps the DataLoader valid when
# batch_size < 4 (a batch size of 0 is rejected).
test_batch_size = max(1, batch_size // 4)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=True, num_workers=0, collate_fn=collate_batch)


print('Data load complete!\n')

# Pull a single batch once and reuse it for both the model summary and the
# shape used at initialisation, instead of loading a fresh batch each time.
sample_x, _ = next(iter(train_dataloader))
print(nn.tabulate(model, rngs={'params': rng})(sample_x, rng))



# ---initializing model---
print("Initializing model....")
state = init_state(model,
                   sample_x.shape,
                   rng,
                   lr)

print("Initialize complete!!\n")



Loading dataset...
Load song_meta.json...


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 707989/707989 [00:00<00:00, 822937.41it/s]


Load complete!

Load file list...


111it [00:07, 14.11it/s]


Loaded data : 93389

Data load complete!






Initializing model....
Initialize complete!!



In [6]:
def linear_init_state(model, x_shape, key, lr) -> train_state.TrainState:
    """Create a fresh ``TrainState`` for the linear head with Adam.

    Unlike ``init_state``, the model is initialised with the dummy input only
    (no extra rng argument is passed to ``model.init``).

    Args:
        model: module exposing ``init`` and ``apply``.
        x_shape: shape of the all-ones dummy input.
        key: PRNG key used as the ``'params'`` rng.
        lr: Adam learning rate.

    Returns:
        A ``TrainState`` whose ``params`` field holds everything returned by
        ``model.init``.
    """
    dummy_input = jnp.ones(x_shape)
    variables = model.init({'params': key}, dummy_input)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables,
        tx=optax.adam(learning_rate=lr),
    )


In [None]:
# Number of passes over the training set. The original cell printed `i+1`
# without ever defining the epoch index `i`; the explicit loop provides it.
num_epochs = 1

for i in range(num_epochs):
    train_data = iter(train_dataloader)
    test_data = iter(test_dataloader)

    train_loss_mean = 0
    test_loss_mean = 0
    num_steps = len(train_dataloader)

    print(f'\nEpoch {i+1}')

    for j in range(num_steps):
        # Derive independent keys for the train and eval steps; `rng` itself is
        # only ever split, never consumed, so no key is used twice.
        rng, train_key, eval_key = jax.random.split(rng, 3)
        x, y = next(train_data)
        try:
            test_x, test_y = next(test_data)
        except StopIteration:
            # The test loader may hold fewer batches than the train loader;
            # start over rather than aborting the epoch.
            test_data = iter(test_dataloader)
            test_x, test_y = next(test_data)

        # Constant +100 offset applied to the inputs (presumably shifting
        # dB-scaled mel values to be non-negative -- confirm against the dataset).
        x = (x + 100)
        test_x = (test_x + 100)

        state, train_loss = train_step(state, x, train_key)
        recon_x, test_loss, mse_loss, kld_loss = eval_step(state, test_x, eval_key)
        train_loss_mean += train_loss
        test_loss_mean += test_loss

        print(f'step : {j}/{len(train_dataloader)}, train_loss : {round(train_loss, 3)}, test_loss : {round(test_loss, 3)}', end='\r')

    # Both running sums received one term per training step, so both averages
    # divide by the number of steps taken (guarded against an empty loader).
    steps_taken = max(num_steps, 1)
    print(f'epoch {i+1} - average loss - train : {round(train_loss_mean/steps_taken, 3)}, test : {round(test_loss_mean/steps_taken, 3)}')


print('Pre train complete!\n\n\n')

In [26]:
@jax.jit
def linear_train_step(enc_state,
                      enc_batch,
                      linear_state,
                      x, y,
                      z_rng=None):
    """One optimisation step of the linear probe on ``Encoder2d`` latents.

    The encoder is applied outside the differentiated function, so only the
    linear head's parameters receive gradients.

    Args:
        enc_state: the encoder's ``'params'`` collection.
        enc_batch: the encoder's ``'batch_stats'`` collection.
        linear_state: train state of the linear head.
        x: input batch fed to the encoder.
        y: targets given to ``optax.softmax_cross_entropy`` (assumed to be
            one-hot / probability vectors -- confirm against the dataset).
        z_rng: optional PRNG key forwarded to the encoder. When omitted, the
            module-level ``rng`` is used, as in the original; note that under
            ``jax.jit`` that global is captured as a constant at trace time.

    Returns:
        ``(updated_linear_state, loss)``.
    """
    encoder_key = rng if z_rng is None else z_rng

    # The encoder call yields three values; only the first is used as the
    # probe's input.
    latent, _, _ = Encoder2d().apply(
        {'params': enc_state, 'batch_stats': enc_batch}, x, encoder_key)

    def loss_fn(params):
        logits = linear_evaluation().apply(params, latent)
        return jnp.mean(optax.softmax_cross_entropy(logits, y))

    loss, grads = jax.value_and_grad(loss_fn)(linear_state.params)

    # Update the linear head's own state (the original referenced an
    # unrelated global `state` and an undefined `labels`).
    return linear_state.apply_gradients(grads=grads), loss


In [7]:
# Shape of the dummy input used to initialise the linear head.
# NOTE(review): assumes the encoder latent has 30 features -- the traceback at
# the end of this notebook mentions a float32[4,512] array, so confirm.
probe_input_shape = (batch_size, 30)
linear_state = linear_init_state(linear_evaluation(), probe_input_shape, rng, lr)

In [9]:
# Instantiate the 2-D encoder module with its default settings.
# NOTE(review): `Encoder2d` is not imported by name at the top of the notebook;
# presumably it is provided by one of the `import *` lines -- confirm.
encoder = Encoder2d()

In [13]:
print('Linear evalutaion step.')

# `state.params` holds the full variable dict of the pre-trained VAE; pull out
# the encoder's trainable parameters and its batch statistics separately.
vae_variables = state.params
enc_state = vae_variables['params']['encoder']
enc_batch = vae_variables['batch_stats']['encoder']


Linear evalutaion step.


In [16]:
# Fresh iterators over both loaders, plus zeroed running-loss accumulators.
train_data, test_data = iter(train_dataloader), iter(test_dataloader)

train_loss_mean = test_loss_mean = 0

In [17]:
# Pull a single (inputs, targets) batch from the training iterator for debugging.
x, y = next(train_data)

In [18]:
# Apply the same +100 input offset used in the pre-training loop.
# NOTE: this cell is not idempotent -- re-running it without re-fetching `x`
# adds the offset a second time.
x = x + 100

In [30]:
# Eager (non-jitted) encoder forward pass for inspection.
# NOTE(review): the jitted `linear_train_step` unpacks three values from this
# same call, so `latent` here is presumably the whole tuple rather than the
# latent array alone -- confirm before using `.shape` on it.
latent = Encoder2d().apply({'params':enc_state, 'batch_stats':enc_batch}, x, rng)

In [32]:
# Display the raw encoder output from the previous cell.
latent

AttributeError: 'tuple' object has no attribute 'shape'

In [None]:
# Eager (non-jitted) walk-through of a single linear-probe update, useful for
# debugging the jitted `linear_train_step`.

# The encoder call yields three values (the jitted step unpacks it the same
# way); only the first is fed to the linear head.
latent, _, _ = Encoder2d().apply({'params':enc_state, 'batch_stats':enc_batch}, x, rng)

def loss_fn(params):
    """Mean softmax cross-entropy of the linear head on `latent`, plus its logits."""
    logits = linear_evaluation().apply(params, latent)
    loss = jnp.mean(optax.softmax_cross_entropy(logits, y))
    return loss, logits

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, logits), grads = grad_fn(linear_state.params)
# Compare against the batch targets `y` (the original used an undefined `labels`).
accuracy = jnp.mean(jnp.argmax(logits, -1) == jnp.argmax(y, -1))

# A bare `return` is a syntax error at cell level; apply the update to the
# linear head's state directly instead.
linear_state = linear_state.apply_gradients(grads=grads)

In [29]:
# Run one jitted linear-probe update on the current batch and keep the
# updated head state together with the batch loss.
linear_state, loss = linear_train_step(
    enc_state=enc_state,
    enc_batch=enc_batch,
    linear_state=linear_state,
    x=x,
    y=y,
)

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[4,512])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function linear_train_step at /tmp/ipykernel_179226/3286191795.py:1 for jit. This concrete value was not available in Python because it depends on the values of the arguments 'enc_state', 'enc_batch', and 'x'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError