In [2]:
from dataloader import mel_dataset
from torch.utils.data import DataLoader, random_split
from Conv2d_model import Conv2d_VAE
from linear_evaluation import linear_evaluation

import flax 
import flax.linen as nn
from flax.training import train_state

import jax
import numpy as np
import jax.numpy as jnp
import optax
from tqdm import tqdm
import os
import wandb
import matplotlib.pyplot as plt


def collate_batch(batch):
    x_train = [x for x, _ in batch]
    y_train = [y for _, y in batch]                  
        
    return np.array(x_train), np.array(y_train)



if __name__ == "__main__":
    batch_size = 16
    lr = 0.0001
    rng = jax.random.PRNGKey(303)
    
    print('\n')
    # ---Load dataset---
    print("Loading dataset...")
    dataset_dir = os.path.join(os.path.expanduser('~'),'dataset')
    data = mel_dataset(dataset_dir)
    print(f'Loaded data : {len(data)}\n')
    target = data[0][0] # (48, 1876)
    target = jnp.expand_dims(target, axis = 0)
    
    dataset_size = len(data)
    train_size = int(dataset_size * 0.8)
    test_size = dataset_size - train_size
    
    train_dataset, test_dataset = random_split(data, [train_size, test_size])

    
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=collate_batch)
    test_dataloader = DataLoader(test_dataset, batch_size=int(batch_size/4), shuffle=True, num_workers=0, collate_fn=collate_batch)
    
    print(f'batch_size = {batch_size}')
    print(f'learning rate = {lr}')
    print(f'train_size = {train_size}')
    print(f'test_size = {test_size}')
    
   
    
    print("Initialize complete!!\n")
    # ---train model---
    epoch = 1
    # checkpoint_dir = str(input('checkpoint dir : '))
    
    
    for i in range(epoch):
        train_data = iter(train_dataloader)
        test_data = iter(test_dataloader)
        
        train_loss_mean = 0
        test_loss_mean = 0
        
        
        print(f'\nEpoch {i+1}')
        
        for j in range(len(train_dataloader)):
            rng, key = jax.random.split(rng)
            x, y = next(train_data)
            test_x, test_y = next(test_data)
            
            # x = x + 100
            # test_x = test_x + 100
            
            state, train_loss = train_step(state, x, rng)           
            _, test_loss, mse_loss, kld_loss = eval_step(state, test_x, rng)
            
            recon_x, _, _, _ = eval_step(state, target, rng)
            train_loss_mean += train_loss
            test_loss_mean += test_loss
            
                        
        
                
                
            print(f'step : {j}/{len(train_dataloader)}, train_loss : {round(train_loss, 3)}, test_loss : {round(test_loss, 3)}', end='\r')

        print(f'epoch {i+1} - average loss - train : {round(train_loss_mean/len(train_dataloader), 3)}, test : {round(test_loss_mean/len(test_dataloader), 3)}')
        
    # ---Linear evaluation---
    
           






Loading dataset...
Loaded data : 93389

batch_size = 16
learning rate = 0.0001
train_size = 74711
test_size = 18678
Data load complete!

Initializing model....
Initialize complete!!


Epoch 1


ScopeParamShapeError: Inconsistent shapes between value and initializer for parameter "kernel" in "/encoder/Conv_0": (3, 3, 1876, 512), (3, 3, 1, 512). (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

In [4]:
# ---initializing model---
model = Conv2d_VAE()
print("Initializing model....")
state = init_state(model, 
                   np.expand_dims(next(iter(train_dataloader))[0], axis=-1).shape,
                   rng, 
                   lr)

Initializing model....


tcmalloc: large alloc 1475346432 bytes == 0x10fc30000 @  0x7f1251b11680 0x7f1251b32824 0x7f1251b32b8a 0x7f10b0ad6a8c 0x7f10ac252b35 0x7f10ac256acd 0x7f10ac1902a9 0x7f10abf597fd 0x7f10abf42781 0x5f6929 0x5f74f6 0x50c383 0x570b26 0x569dba 0x5f6eb3 0x5f8892 0x66931d 0x5f627e 0x56d2d5 0x569dba 0x5f6eb3 0x5f6082 0x56d2d5 0x569dba 0x5f6eb3 0x5f8892 0x66931d 0x5f627e 0x56d2d5 0x5f6cd6 0x56bbfa
tcmalloc: large alloc 1475354624 bytes == 0x167b30000 @  0x7f1251b11680 0x7f1251b32824 0x7f10ac9cb045 0x7f10aca6cf46 0x7f10aca6d718 0x7f1014774096 0x7f10ac256acd 0x7f10ac1902a9 0x7f10abf597fd 0x7f10abf42781 0x5f6929 0x5f74f6 0x50c383 0x570b26 0x569dba 0x5f6eb3 0x5f8892 0x66931d 0x5f627e 0x56d2d5 0x569dba 0x5f6eb3 0x5f6082 0x56d2d5 0x569dba 0x5f6eb3 0x5f8892 0x66931d 0x5f627e 0x56d2d5 0x5f6cd6
tcmalloc: large alloc 1475346432 bytes == 0x167b30000 @  0x7f1251b11680 0x7f1251b32824 0x7f1251b32b8a 0x7f10b0ad6a8c 0x7f10ac252b35 0x7f10ac256acd 0x7f10ac1902a9 0x7f10abf597fd 0x7f10abf42781 0x5f6929 0x5f74f6 0x50

In [None]:
enc_params = state.params['params']['encoder']
enc_batch_stats = state.params['batch_stats']['encoder']

@jax.jit
def latent_vector(data):
    whole_latent = []
    for linear_data in data:
        linear_data = jnp.expand_dims(linear_data, axis=0)
        latent_vector = Encoder().apply({'params':enc_params,'batch_stats':enc_batch_stats}, linear_data, )