In [1]:
from dataloader import mel_dataset
from torch.utils.data import DataLoader, random_split
from Conv2d_model import Conv2d_VAE
from linear_evaluation import linear_evaluation

import flax 
import flax.linen as nn
from flax.training import train_state

import jax
import numpy as np
import jax.numpy as jnp
import optax
from tqdm import tqdm
import os
import wandb
import matplotlib.pyplot as plt


def collate_batch(batch):
    """Collate a list of ``(x, y)`` samples into a pair of NumPy arrays.

    Args:
        batch: Sequence of two-element samples ``(x, y)``.

    Returns:
        Tuple ``(x_batch, y_batch)``; each entry is ``np.array`` applied to the
        corresponding components, kept in the order they appear in ``batch``.
    """
    features = np.array([sample_x for sample_x, _ in batch])
    labels = np.array([sample_y for _, sample_y in batch])
    return features, labels



if __name__ == "__main__":
    # ---Hyper-parameters and the root PRNG key used by the later cells---
    batch_size = 16
    lr = 0.0001
    rng = jax.random.PRNGKey(303)

    print('\n')
    # ---Load dataset---
    print("Loading dataset...")
    home_dir = os.path.expanduser('~')
    dataset_dir = os.path.join(home_dir, 'dataset')
    data = mel_dataset(dataset_dir)
    print(f'Loaded data : {len(data)}\n')
    # First component of the first sample, with a new leading axis.
    # The original author's note gives the item shape as (48, 1876).
    target = jnp.expand_dims(data[0][0], axis=0)

    # ---Split: 80% of the items for training, the remainder for testing---
    dataset_size = len(data)
    train_size = int(dataset_size * 0.8)
    test_size = dataset_size - train_size
    train_dataset, test_dataset = random_split(data, [train_size, test_size])

    # Options shared by both loaders; only the batch size differs
    # (the test loader uses a quarter of the training batch size).
    loader_options = dict(shuffle=True, num_workers=0, collate_fn=collate_batch)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, **loader_options)
    test_dataloader = DataLoader(test_dataset, batch_size=int(batch_size/4), **loader_options)

    print(
        f'batch_size = {batch_size}',
        f'learning rate = {lr}',
        f'train_size = {train_size}',
        f'test_size = {test_size}',
        sep='\n',
    )

    print("Initialize complete!!\n")
    # ---train model---
   





Loading dataset...


KeyboardInterrupt: 

In [4]:
def init_state(model, x_shape, key, lr) -> train_state.TrainState:
    """Initialise ``model`` and wrap the result in a Flax ``TrainState``.

    Args:
        model: Module exposing ``init`` and ``apply``.
        x_shape: Shape of the all-ones dummy input used for initialisation.
        key: JAX PRNG key. It seeds the ``'params'`` stream and is also passed
            to ``model.init`` as the extra positional argument after the input.
        lr: Learning rate for the Adam optimiser.

    Returns:
        A ``train_state.TrainState`` whose ``params`` field holds everything
        returned by ``model.init`` and whose ``apply_fn`` is ``model.apply``.
    """
    dummy_input = jnp.ones(x_shape)
    variables = model.init({'params': key}, dummy_input, key)
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=variables,
        tx=optax.adam(learning_rate=lr),
    )

In [11]:
@jax.vmap
def kl_divergence(mean, logvar):
    """Per-example value of ``-0.5 * sum(1 + logvar - mean^2 - exp(logvar))``.

    The function is wrapped in ``jax.vmap``, so it is mapped over the leading
    axis of both arguments and the sum covers the remaining axes.

    Args:
        mean: Array of means, one row per example.
        logvar: Array of log-variances with the same layout as ``mean``.

    Returns:
        One value per entry along the leading axis.
    """
    per_element = 1 + logvar - jnp.square(mean) - jnp.exp(logvar)
    return -0.5 * jnp.sum(per_element)

@jax.vmap
def binary_cross_entropy_with_logits(logits, labels):
    """Per-example binary cross-entropy computed from raw logits.

    Wrapped in ``jax.vmap``: mapped over the leading axis of both arguments,
    with the sum taken over the remaining axes.

    Args:
        logits: Unnormalised scores.
        labels: Targets with the same layout as ``logits``.

    Returns:
        One summed loss value per entry along the leading axis.
    """
    log_p = nn.log_sigmoid(logits)
    # log(1 - p) obtained from log(p) as log(-expm1(log p)).
    log_one_minus_p = jnp.log(-jnp.expm1(log_p))
    return -jnp.sum(labels * log_p + (1. - labels) * log_one_minus_p)


@jax.jit
def train_step(state, x, z_rng):
    """Run one gradient step of the VAE objective on a single batch.

    Args:
        state: ``TrainState``. ``state.params`` is passed as the variables to
            ``state.apply_fn`` and receives the gradient update.
        x: Input batch without a channel axis; a trailing axis of size 1 is
            appended here before the forward pass.
        z_rng: PRNG key forwarded to the model's forward pass.

    Returns:
        ``(new_state, loss)`` where ``loss`` is the mean squared reconstruction
        error plus the batch-mean KL term.
    """
    x = jnp.expand_dims(x, axis=-1)

    def loss_fn(params):
        # Use the apply function stored in the state rather than building a
        # fresh default-configured Conv2d_VAE here, so the step always runs
        # the same model the state was initialised from.
        recon_x, mean, logvar = state.apply_fn(params, x, z_rng)

        mse_loss = ((recon_x - x) ** 2).mean()
        kld_loss = kl_divergence(mean, logvar).mean()
        return mse_loss + kld_loss

    loss, grads = jax.value_and_grad(loss_fn)(state.params)

    return state.apply_gradients(grads=grads), loss

@jax.jit
def eval_step(state, x, z_rng):
    """Evaluate the VAE objective on a single batch without updating ``state``.

    Args:
        state: ``TrainState`` whose ``params`` are passed to ``state.apply_fn``.
        x: Input batch without a channel axis; a trailing axis of size 1 is
            appended here before the forward pass.
        z_rng: PRNG key forwarded to the model's forward pass.

    Returns:
        ``(recon_x, loss, mse_loss, kld_loss)`` where ``loss`` is
        ``mse_loss + kld_loss``.
    """
    x = jnp.expand_dims(x, axis=-1)
    # Use the apply function stored in the state rather than a freshly built
    # default-configured Conv2d_VAE, so evaluation runs the same model the
    # state was initialised from.
    recon_x, mean, logvar = state.apply_fn(state.params, x, z_rng)
    mse_loss = ((recon_x - x) ** 2).mean()
    kld_loss = kl_divergence(mean, logvar).mean()
    loss = mse_loss + kld_loss

    return recon_x, loss, mse_loss, kld_loss

In [5]:
# ---initializing model---
# Build the VAE and its training state from an all-ones dummy input.
# The dummy shape matches what train_step/eval_step feed the model: the
# (1, 48, 1876) batch they receive plus the trailing axis they append.
model = Conv2d_VAE()
print("Initializing model....")
state = init_state(model, (1, 48, 1876, 1), rng, lr)

Initializing model....


In [14]:
epoch = 1
# checkpoint_dir = str(input('checkpoint dir : '))

# Running totals for the losses (a single step is accumulated here).
train_loss_mean = 0
test_loss_mean = 0

# Smoke test: one train step and two eval steps on random input of shape
# (1, 48, 1876).
# Each stochastic call gets its own subkey. Previously the freshly split `key`
# was discarded and the carried `rng` itself was passed to all three calls,
# so the same key was reused and would also be split again by later cells.
rng, key, test_key, recon_key = jax.random.split(rng, 4)

state, train_loss = train_step(state, np.random.randn(1, 48, 1876), key)
_, test_loss, mse_loss, kld_loss = eval_step(state, np.random.randn(1, 48, 1876), test_key)
recon_x, _, _, _ = eval_step(state, np.random.randn(1, 48, 1876), recon_key)

train_loss_mean += train_loss
test_loss_mean += test_loss

# Format as Python floats: round() on a float32 JAX scalar printed values such
# as 1.0460000038146973 instead of three decimals.
print(f' {float(train_loss):.3f}, test_loss : {float(test_loss):.3f}', end='\r')


# ---Linear evaluation--

 1.0460000038146973, test_loss : 1.0040000677108765

In [20]:


@jax.jit
def encoder_apply(data, rng):
    """Apply the encoder, constructed with ``train=False``, to ``data``.

    ``enc_params`` and ``enc_batch_stats`` are not arguments; they are looked
    up from the surrounding notebook scope and must be defined before the
    first call.

    NOTE(review): because this function is jitted, those globals are presumably
    captured when it is first traced, so rebinding them afterwards may not take
    effect for already-compiled input shapes - confirm.

    Args:
        data: Input batch passed to the encoder.
        rng: PRNG key passed to the encoder as an extra positional argument.

    Returns:
        Whatever ``Encoder.apply`` returns for this input.
    """
    encoder = Encoder(train=False)
    variables = {'params': enc_params, 'batch_stats': enc_batch_stats}
    return encoder.apply(variables, data, rng)




In [23]:
from Conv2d_model import Encoder

In [26]:
# Encode the first component of the first dataset item; np.expand_dims adds a
# leading and a trailing axis of size 1 before the call.
latent_vector = encoder_apply(np.expand_dims(data[0][0], axis=(0,-1)), rng)

In [34]:
# Inspect the shape of the first element of the encoder output.
latent_vector[0].shape

(1, 512)

In [39]:
from tqdm import tqdm

In [44]:
# Collect the first component of every (x, y) item in the dataset into a list.
x_array = [x for x, y in data]

In [45]:
# Convert the collected list into a single NumPy array.
# NOTE(review): the recorded output for this cell shows an allocation of
# 33637974016 bytes (~33.6 GB) followed by a KeyboardInterrupt, so
# materialising the whole dataset this way may not fit in memory.
x_array = np.array(x_array)

tcmalloc: large alloc 33637974016 bytes == 0x7fd60a000 @  0x7faf7ca3a680 0x7faf7ca5b824 0x7faf20d14064 0x7faf20d147ff 0x7faf20d72fc5 0x7faf20d761ea 0x7faf20d766e7 0x7faf20e13925 0x5c6617 0x570b26 0x569dba 0x6902a7 0x6023c4 0x5c6730 0x56bacd 0x501488 0x56d4d6 0x501488 0x56d4d6 0x501488 0x505166 0x56bbfa 0x5f6cd6 0x56bacd 0x5f6cd6 0x56bbfa 0x569dba 0x5f6eb3 0x50bc2c 0x5f6082 0x56d2d5


KeyboardInterrupt: 

In [40]:
# Encode every dataset item one at a time and collect the encoder outputs.
whole_latent = []
for sing in tqdm(data):
    # Advance the carried key and give the encoder the fresh subkey.
    # Previously `key` was discarded and the carried `rng` was passed to the
    # encoder and then split again on the next iteration (key reuse).
    rng, key = jax.random.split(rng)
    latent_vector = encoder_apply(np.expand_dims(sing[0], axis=(0,-1)), key)
    whole_latent.append(latent_vector)

  3%|████▌                                                                                                                                                                 | 2552/93389 [03:04<1:49:19, 13.85it/s]


KeyboardInterrupt: 

In [None]:
# NOTE(review): `linear_init_state` is neither defined nor imported in the
# visible cells, and only the model argument is supplied here - this cell
# looks unfinished. It also rebinds `state`, replacing the VAE training state
# created earlier; confirm that is intended.
state = linear_init_state(linear_evaluation(), )


In [None]:
# Pull the encoder sub-trees of the 'params' and 'batch_stats' collections out
# of state.params; encoder_apply reads these two globals.
# NOTE(review): on a fresh kernel this cell must run before the first call to
# encoder_apply, and `state` must still be the VAE state at that point (the
# preceding cell rebinds `state`) - confirm the intended cell order.
enc_params = state.params['params']['encoder']
enc_batch_stats = state.params['batch_stats']['encoder']


