We're going to use the Generator and discriminators to thus train on the LJSpeech dataset. The dataset consists of around 13,000 samples reading from 7 non-fiction books. They vary from 1 - 10 seconds and total around 24h.  



In [4]:
import argparse
import jax
import os
import librosa
import optax
import wandb
from tqdm import tqdm
import equinox as eqx

from Generator import Generator
from Discriminators import MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss, discriminator_loss

def create_parser():
    parser = argparse.ArgumentParser(description="Arguments for training HiFiGaN")

    parser.add_argument("--dataset_path", "Path to the dataset to use (LJSpeech)")

    parser.add_argument("--learning_rate", "Learning rate during training")

    parser.add_argument("--output_path", "Path to store model weights")

    return parser

def save_model(model, path):
    eqx.tree_serialise_leaves(path, model)

def get_dataset(dataset_path):
    mel_dir = os.path.join(dataset_path, 'mel_spectrograms')
    wav_dir = os.path.join(dataset_path, 'processed_wavs')

    mels = []
    wavs = []

    for filename in os.listdir(mel_dir):
        np_data = jax.numpy.load(os.path.join(mel_dir, filename))
        # print(np_data.shape)
        mels.append(np_data[:,:-1])

    for filename in os.listdir(wav_dir):
        wav_data, _ = librosa.load(os.path.join(wav_dir, filename[:-4]+".mp3"))
        wav_data = jax.numpy.array(wav_data)
        wav_data = jax.numpy.expand_dims(wav_data, 0)
        wavs.append(wav_data)


    return jax.numpy.array(mels), jax.numpy.array(wavs)



@eqx.filter_value_and_grad
def calculate_gan_loss(gan, period, scale, x, y):

    gan_result = jax.vmap(gan)(x)
    fake_scale, _, _, _ = jax.vmap(scale)(gan_result, y)
    fake_period, _, _, _ = jax.vmap(period)(gan_result, y)

    l1_loss = jax.numpy.mean(jax.numpy.abs(gan_result - y)) # L1 loss
    G_loss = 0
    for fake in fake_period:
        G_loss += jax.numpy.mean((fake - 1) ** 2)
    for fake in fake_scale:
        G_loss += jax.numpy.mean((fake - 1) ** 2)
    # G_loss_scale = jax.numpy.mean((fake_scale - jax.numpy.ones(batch_size)) ** 2)

    return G_loss + 30 * l1_loss

@eqx.filter_value_and_grad
def calculate_disc_loss(model, fake, real):
    fake_result, real_result, _, _ = jax.vmap(model)(fake, real)
    loss = 0
    for fake_res, real_res in zip(fake_result, real_result):
        fake_loss = jax.numpy.mean((fake_res)**2)
        real_loss = jax.numpy.mean((real_res-1)**2)
        loss += (fake_loss + real_loss)

    return loss

@eqx.filter_jit
def make_step(gan, period_disc, scale_disc, x, y, gan_optim, period_optim, scale_optim, optim1, optim2, optim3):

    result = jax.vmap(gan)(x)

    trainable_scale, _ = eqx.partition(scale_disc, eqx.is_inexact_array)
    trainable_period, _ = eqx.partition(period_disc, eqx.is_inexact_array)

    loss_scale, grads_scale = calculate_disc_loss(scale_disc, result, y)
    updates, scale_optim = optim2.update(grads_scale, scale_optim)
    scale_disc = eqx.apply_updates(trainable_scale, updates)

    loss_period, grads_period = calculate_disc_loss(period_disc, result, y)
    updates, period_optim = optim3.update(grads_period, period_optim)
    period_disc = eqx.apply_updates(trainable_period, updates)

    loss_gan, grads_gan = calculate_gan_loss(gan, period_disc, scale_disc, x, y)
    updates, gan_optim = optim1.update(grads_gan, gan_optim)
    gan = eqx.apply_updates(gan, updates)
    
    return loss_gan, loss_period, loss_scale, gan, period_disc, scale_disc, gan_optim, period_optim, scale_optim

def train_hifigan(dataset_path, output_path, learning_rate=1e-4, batch_size=32, epochs=100, seed=69):
    run = wandb.init(
    # Set the project where this run will be logged
    project="HiFiGaN JAX",
    # Track hyperparameters and run metadata
    config={
        "learning_rate": learning_rate,
        "epochs": epochs,
        "batch_size": batch_size,
        "PRNG_SEED": seed,
    },
)

    key = jax.random.PRNGKey(seed)

    key1, key2, key3 = jax.random.split(key, 3)

    generator = Generator(channels_in=80, channels_out=1, key=key1)
    generator = eqx.tree_deserialise_leaves("checkpoints_trained_to_100/generator_epoch_90.eqx", like=generator) 
    scale_disc = MultiScaleDiscriminator(key=key2)
    period_disc = MultiPeriodDiscriminator(key=key3)

    mels, wavs = get_dataset(dataset_path)


    optim1 = optax.adam(2e-4, b1=0.8, b2=0.99)
    gan_optim = optim1.init(generator)
    
    optim2 = optax.adam(learning_rate)
    trainable_scale, _ = eqx.partition(scale_disc, eqx.is_inexact_array)
    scale_optim = optim2.init(trainable_scale)

    optim3= optax.adam(learning_rate)
    trainable_period, _ = eqx.partition(period_disc, eqx.is_inexact_array)
    period_optim = optim3.init(trainable_period)

    dataset_size = len(wavs)

    for epoch in tqdm(range(epochs)):
        avg_gan_loss = 0
        avg_scale_loss = 0
        avg_period_loss = 0

        key, subkey = jax.random.split(key)
        perm = jax.random.permutation(subkey, dataset_size)
        
        for batch_start in range(0, dataset_size, batch_size):
            batch_indices = perm[batch_start: batch_start + batch_size]
            x = mels.take(batch_indices, axis=0)
            y = wavs.take(batch_indices, axis=0)
            
            # Display batch indices and data
            gan_loss, period_loss,scale_loss, generator, period_disc, scale_disc, gan_optim, period_optim, scale_optim = make_step(generator, period_disc, scale_disc, x, y, gan_optim, period_optim,scale_optim, optim1, optim2, optim3)
            # print(grads)        loss = loss.item()
            wandb.log({"GAN loss": gan_loss, "SCALE loss": scale_loss, "PERIOD loss": period_loss})

            avg_gan_loss += gan_loss
            # avg_scale_loss += scale_loss
            avg_period_loss += period_loss
            # print(batch_data.shape)

            # print(res.shape)

        if epoch % 10 == 0: 
            save_model(generator, os.path.join(output_path, f"generator_epoch_{epoch}.eqx"))
        print(f"Average gan loss: {avg_gan_loss/dataset_size}, Avg scale loss: {avg_scale_loss/dataset_size}, Avg period loss: {avg_period_loss/dataset_size}")

# if __name__ == "__main__":
#     parser = create_parser()
#     args = parser.parse_args()

#     train_hifigan(dataset_path=args.dataset_path, output_path=args.output_path, learning_rate=args.learning_rate)


In [5]:
train_hifigan(dataset_path="dataset", output_path="checkpoint", learning_rate=1e-5)

  0%|          | 0/100 [00:00<?, ?it/s]2024-10-27 15:35:45.840517: W external/xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 33.03GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2024-10-27 15:35:46.820631: W external/xla/xla/tsl/framework/bfc_allocator.cc:306] Allocator (GPU_0_bfc) ran out of memory trying to allocate 33.03GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
  1%|          | 1/100 [01:56<3:12:40, 116.77s/it]

Average gan loss: 0.19403529167175293, Avg scale loss: 0.0, Avg period loss: 0.07992879301309586


  2%|▏         | 2/100 [03:11<2:30:48, 92.33s/it] 

Average gan loss: 0.14877063035964966, Avg scale loss: 0.0, Avg period loss: 0.07096628844738007


  3%|▎         | 3/100 [04:27<2:16:39, 84.53s/it]

Average gan loss: 0.15316516160964966, Avg scale loss: 0.0, Avg period loss: 0.07150690257549286


  4%|▍         | 4/100 [05:42<2:09:27, 80.91s/it]

Average gan loss: 0.15544641017913818, Avg scale loss: 0.0, Avg period loss: 0.0700031965970993


  5%|▌         | 5/100 [06:57<2:04:57, 78.92s/it]

Average gan loss: 0.1619253307580948, Avg scale loss: 0.0, Avg period loss: 0.06883680075407028


  6%|▌         | 6/100 [08:13<2:01:40, 77.67s/it]

Average gan loss: 0.16180655360221863, Avg scale loss: 0.0, Avg period loss: 0.0684453621506691


  7%|▋         | 7/100 [09:28<1:59:18, 76.97s/it]

Average gan loss: 0.16317789256572723, Avg scale loss: 0.0, Avg period loss: 0.06919680535793304


  8%|▊         | 8/100 [10:44<1:57:16, 76.49s/it]

Average gan loss: 0.16158124804496765, Avg scale loss: 0.0, Avg period loss: 0.06738818436861038


  9%|▉         | 9/100 [11:59<1:55:31, 76.17s/it]

Average gan loss: 0.1624322384595871, Avg scale loss: 0.0, Avg period loss: 0.06672374159097672


 10%|█         | 10/100 [13:14<1:53:48, 75.87s/it]

Average gan loss: 0.16497038304805756, Avg scale loss: 0.0, Avg period loss: 0.06740179657936096


 11%|█         | 11/100 [14:30<1:52:23, 75.77s/it]

Average gan loss: 0.16380207240581512, Avg scale loss: 0.0, Avg period loss: 0.06744048744440079


 12%|█▏        | 12/100 [15:45<1:50:58, 75.67s/it]

Average gan loss: 0.1627483367919922, Avg scale loss: 0.0, Avg period loss: 0.06731497496366501


 13%|█▎        | 13/100 [17:01<1:49:37, 75.60s/it]

Average gan loss: 0.16508182883262634, Avg scale loss: 0.0, Avg period loss: 0.06547105312347412


 14%|█▍        | 14/100 [18:16<1:48:17, 75.55s/it]

Average gan loss: 0.16632242500782013, Avg scale loss: 0.0, Avg period loss: 0.06656733900308609


 15%|█▌        | 15/100 [19:31<1:46:52, 75.45s/it]

Average gan loss: 0.1636766940355301, Avg scale loss: 0.0, Avg period loss: 0.0670621320605278


 16%|█▌        | 16/100 [20:47<1:45:37, 75.45s/it]

Average gan loss: 0.16513197124004364, Avg scale loss: 0.0, Avg period loss: 0.06730874627828598


 17%|█▋        | 17/100 [22:02<1:44:20, 75.42s/it]

Average gan loss: 0.1639147847890854, Avg scale loss: 0.0, Avg period loss: 0.06508345901966095


 18%|█▊        | 18/100 [23:18<1:43:05, 75.43s/it]

Average gan loss: 0.16690604388713837, Avg scale loss: 0.0, Avg period loss: 0.06614424288272858


 19%|█▉        | 19/100 [24:33<1:41:48, 75.41s/it]

Average gan loss: 0.16603098809719086, Avg scale loss: 0.0, Avg period loss: 0.06604374200105667


 20%|██        | 20/100 [25:48<1:40:26, 75.33s/it]

Average gan loss: 0.16451285779476166, Avg scale loss: 0.0, Avg period loss: 0.06482987105846405


 21%|██        | 21/100 [27:04<1:39:18, 75.43s/it]

Average gan loss: 0.16306641697883606, Avg scale loss: 0.0, Avg period loss: 0.06889234483242035


 22%|██▏       | 22/100 [28:19<1:38:04, 75.44s/it]

Average gan loss: 0.16450852155685425, Avg scale loss: 0.0, Avg period loss: 0.0668235719203949


 23%|██▎       | 23/100 [29:35<1:36:48, 75.44s/it]

Average gan loss: 0.16832588613033295, Avg scale loss: 0.0, Avg period loss: 0.06314294040203094


 24%|██▍       | 24/100 [30:50<1:35:29, 75.39s/it]

Average gan loss: 0.16721899807453156, Avg scale loss: 0.0, Avg period loss: 0.06367617845535278


 25%|██▌       | 25/100 [32:06<1:34:15, 75.41s/it]

Average gan loss: 0.16748204827308655, Avg scale loss: 0.0, Avg period loss: 0.06430180370807648


 26%|██▌       | 26/100 [33:21<1:32:59, 75.40s/it]

Average gan loss: 0.16479118168354034, Avg scale loss: 0.0, Avg period loss: 0.06552526354789734


 27%|██▋       | 27/100 [34:36<1:31:44, 75.40s/it]

Average gan loss: 0.16624528169631958, Avg scale loss: 0.0, Avg period loss: 0.06386009603738785


 28%|██▊       | 28/100 [35:52<1:30:29, 75.42s/it]

Average gan loss: 0.17001213133335114, Avg scale loss: 0.0, Avg period loss: 0.06353356689214706


 29%|██▉       | 29/100 [37:07<1:29:09, 75.35s/it]

Average gan loss: 0.1672668606042862, Avg scale loss: 0.0, Avg period loss: 0.06476273387670517


 30%|███       | 30/100 [38:22<1:27:56, 75.37s/it]

Average gan loss: 0.16839757561683655, Avg scale loss: 0.0, Avg period loss: 0.062308914959430695


 31%|███       | 31/100 [39:38<1:26:43, 75.42s/it]

Average gan loss: 0.16926661133766174, Avg scale loss: 0.0, Avg period loss: 0.0638468936085701


 32%|███▏      | 32/100 [40:53<1:25:28, 75.41s/it]

Average gan loss: 0.16814735531806946, Avg scale loss: 0.0, Avg period loss: 0.06409817934036255


 33%|███▎      | 33/100 [42:09<1:24:11, 75.40s/it]

Average gan loss: 0.17003175616264343, Avg scale loss: 0.0, Avg period loss: 0.06287665665149689


 34%|███▍      | 34/100 [43:24<1:22:51, 75.33s/it]

Average gan loss: 0.1684846431016922, Avg scale loss: 0.0, Avg period loss: 0.06422221660614014


 35%|███▌      | 35/100 [44:39<1:21:36, 75.33s/it]

Average gan loss: 0.169595867395401, Avg scale loss: 0.0, Avg period loss: 0.06382512301206589


 36%|███▌      | 36/100 [45:55<1:20:21, 75.34s/it]

Average gan loss: 0.16801653802394867, Avg scale loss: 0.0, Avg period loss: 0.061973486095666885


 37%|███▋      | 37/100 [47:10<1:19:06, 75.35s/it]

Average gan loss: 0.16999436914920807, Avg scale loss: 0.0, Avg period loss: 0.06375285238027573


 38%|███▊      | 38/100 [48:25<1:17:46, 75.27s/it]

Average gan loss: 0.1697048544883728, Avg scale loss: 0.0, Avg period loss: 0.061803098767995834


 39%|███▉      | 39/100 [49:40<1:16:33, 75.30s/it]

Average gan loss: 0.17295080423355103, Avg scale loss: 0.0, Avg period loss: 0.06193727254867554


 40%|████      | 40/100 [50:56<1:15:20, 75.34s/it]

Average gan loss: 0.1724739521741867, Avg scale loss: 0.0, Avg period loss: 0.061514124274253845


 41%|████      | 41/100 [52:11<1:14:07, 75.39s/it]

Average gan loss: 0.17288519442081451, Avg scale loss: 0.0, Avg period loss: 0.05999084562063217


 42%|████▏     | 42/100 [53:27<1:12:52, 75.38s/it]

Average gan loss: 0.17212103307247162, Avg scale loss: 0.0, Avg period loss: 0.06266217678785324


 43%|████▎     | 43/100 [54:42<1:11:32, 75.30s/it]

Average gan loss: 0.1727130264043808, Avg scale loss: 0.0, Avg period loss: 0.05979642644524574


 44%|████▍     | 44/100 [55:57<1:10:17, 75.32s/it]

Average gan loss: 0.17452049255371094, Avg scale loss: 0.0, Avg period loss: 0.05976242199540138


 45%|████▌     | 45/100 [57:13<1:09:04, 75.35s/it]

Average gan loss: 0.17524348199367523, Avg scale loss: 0.0, Avg period loss: 0.058840468525886536


 46%|████▌     | 46/100 [58:28<1:07:50, 75.37s/it]

Average gan loss: 0.17342333495616913, Avg scale loss: 0.0, Avg period loss: 0.06035872921347618


 47%|████▋     | 47/100 [59:43<1:06:35, 75.39s/it]

Average gan loss: 0.1704658716917038, Avg scale loss: 0.0, Avg period loss: 0.0630633682012558


 48%|████▊     | 48/100 [1:00:59<1:05:18, 75.35s/it]

Average gan loss: 0.1669086515903473, Avg scale loss: 0.0, Avg period loss: 0.06640580296516418


 49%|████▉     | 49/100 [1:02:14<1:04:03, 75.35s/it]

Average gan loss: 0.17163367569446564, Avg scale loss: 0.0, Avg period loss: 0.06259941309690475


 50%|█████     | 50/100 [1:03:29<1:02:48, 75.37s/it]

Average gan loss: 0.17190413177013397, Avg scale loss: 0.0, Avg period loss: 0.06338431686162949


 51%|█████     | 51/100 [1:04:45<1:01:35, 75.42s/it]

Average gan loss: 0.16767708957195282, Avg scale loss: 0.0, Avg period loss: 0.06644237786531448


 52%|█████▏    | 52/100 [1:06:00<1:00:17, 75.36s/it]

Average gan loss: 0.16909605264663696, Avg scale loss: 0.0, Avg period loss: 0.06340790539979935


 53%|█████▎    | 53/100 [1:07:16<59:02, 75.37s/it]  

Average gan loss: 0.16844171285629272, Avg scale loss: 0.0, Avg period loss: 0.06547370553016663


 54%|█████▍    | 54/100 [1:08:31<57:47, 75.38s/it]

Average gan loss: 0.16919799149036407, Avg scale loss: 0.0, Avg period loss: 0.06474242359399796


 55%|█████▌    | 55/100 [1:09:46<56:32, 75.38s/it]

Average gan loss: 0.16738589107990265, Avg scale loss: 0.0, Avg period loss: 0.06647446751594543


 56%|█████▌    | 56/100 [1:11:02<55:16, 75.38s/it]

Average gan loss: 0.16833144426345825, Avg scale loss: 0.0, Avg period loss: 0.06511909514665604


 57%|█████▋    | 57/100 [1:12:17<53:58, 75.32s/it]

Average gan loss: 0.16804754734039307, Avg scale loss: 0.0, Avg period loss: 0.06494808197021484


 58%|█████▊    | 58/100 [1:13:32<52:44, 75.34s/it]

Average gan loss: 0.16710612177848816, Avg scale loss: 0.0, Avg period loss: 0.06544008105993271


 59%|█████▉    | 59/100 [1:14:48<51:28, 75.34s/it]

Average gan loss: 0.1681547909975052, Avg scale loss: 0.0, Avg period loss: 0.06605041027069092


 60%|██████    | 60/100 [1:16:03<50:13, 75.33s/it]

Average gan loss: 0.16602292656898499, Avg scale loss: 0.0, Avg period loss: 0.06495950371026993


 61%|██████    | 61/100 [1:17:19<49:01, 75.41s/it]

Average gan loss: 0.16894780099391937, Avg scale loss: 0.0, Avg period loss: 0.06435863673686981


 62%|██████▏   | 62/100 [1:18:34<47:42, 75.33s/it]

Average gan loss: 0.16857613623142242, Avg scale loss: 0.0, Avg period loss: 0.06282250583171844


 63%|██████▎   | 63/100 [1:19:49<46:27, 75.33s/it]

Average gan loss: 0.17041970789432526, Avg scale loss: 0.0, Avg period loss: 0.06221742182970047


 64%|██████▍   | 64/100 [1:21:05<45:13, 75.37s/it]

Average gan loss: 0.17289337515830994, Avg scale loss: 0.0, Avg period loss: 0.059907786548137665


 65%|██████▌   | 65/100 [1:22:20<43:59, 75.41s/it]

Average gan loss: 0.1761462241411209, Avg scale loss: 0.0, Avg period loss: 0.057457294315099716


 66%|██████▌   | 66/100 [1:23:35<42:41, 75.33s/it]

Average gan loss: 0.17980948090553284, Avg scale loss: 0.0, Avg period loss: 0.055510446429252625


 67%|██████▋   | 67/100 [1:24:51<41:26, 75.35s/it]

Average gan loss: 0.18335282802581787, Avg scale loss: 0.0, Avg period loss: 0.05222637951374054


 68%|██████▊   | 68/100 [1:26:06<40:12, 75.38s/it]

Average gan loss: 0.18615971505641937, Avg scale loss: 0.0, Avg period loss: 0.05067170783877373


 69%|██████▉   | 69/100 [1:27:21<38:57, 75.39s/it]

Average gan loss: 0.18861985206604004, Avg scale loss: 0.0, Avg period loss: 0.04873603209853172


 70%|███████   | 70/100 [1:28:37<37:42, 75.41s/it]

Average gan loss: 0.19135895371437073, Avg scale loss: 0.0, Avg period loss: 0.04698837548494339


 71%|███████   | 71/100 [1:29:52<36:26, 75.39s/it]

Average gan loss: 0.19197575747966766, Avg scale loss: 0.0, Avg period loss: 0.04687337204813957


 72%|███████▏  | 72/100 [1:31:08<35:10, 75.39s/it]

Average gan loss: 0.19205766916275024, Avg scale loss: 0.0, Avg period loss: 0.047354139387607574


 73%|███████▎  | 73/100 [1:32:23<33:55, 75.40s/it]

Average gan loss: 0.19269372522830963, Avg scale loss: 0.0, Avg period loss: 0.04720839858055115


 74%|███████▍  | 74/100 [1:33:38<32:40, 75.41s/it]

Average gan loss: 0.1943623423576355, Avg scale loss: 0.0, Avg period loss: 0.04571434110403061


 75%|███████▌  | 75/100 [1:34:54<31:25, 75.44s/it]

Average gan loss: 0.1951626092195511, Avg scale loss: 0.0, Avg period loss: 0.04578141123056412


 76%|███████▌  | 76/100 [1:36:09<30:08, 75.37s/it]

Average gan loss: 0.19714704155921936, Avg scale loss: 0.0, Avg period loss: 0.0440371073782444


 77%|███████▋  | 77/100 [1:37:25<28:53, 75.38s/it]

Average gan loss: 0.19869700074195862, Avg scale loss: 0.0, Avg period loss: 0.04345095157623291


 78%|███████▊  | 78/100 [1:38:40<27:38, 75.40s/it]

Average gan loss: 0.19997115433216095, Avg scale loss: 0.0, Avg period loss: 0.04291647672653198


 79%|███████▉  | 79/100 [1:39:55<26:23, 75.41s/it]

Average gan loss: 0.2023390829563141, Avg scale loss: 0.0, Avg period loss: 0.042054519057273865


 80%|████████  | 80/100 [1:41:11<25:06, 75.34s/it]

Average gan loss: 0.2042093724012375, Avg scale loss: 0.0, Avg period loss: 0.04078497737646103


 81%|████████  | 81/100 [1:42:26<23:52, 75.39s/it]

Average gan loss: 0.20575007796287537, Avg scale loss: 0.0, Avg period loss: 0.04013194516301155


 82%|████████▏ | 82/100 [1:43:41<22:36, 75.37s/it]

Average gan loss: 0.20800824463367462, Avg scale loss: 0.0, Avg period loss: 0.039551034569740295


 83%|████████▎ | 83/100 [1:44:57<21:21, 75.40s/it]

Average gan loss: 0.20902149379253387, Avg scale loss: 0.0, Avg period loss: 0.039127398282289505


 84%|████████▍ | 84/100 [1:46:12<20:06, 75.39s/it]

Average gan loss: 0.21072696149349213, Avg scale loss: 0.0, Avg period loss: 0.03755015507340431


 85%|████████▌ | 85/100 [1:47:27<18:49, 75.32s/it]

Average gan loss: 0.21174654364585876, Avg scale loss: 0.0, Avg period loss: 0.03795434534549713


 86%|████████▌ | 86/100 [1:48:43<17:34, 75.33s/it]

Average gan loss: 0.21375565230846405, Avg scale loss: 0.0, Avg period loss: 0.037155136466026306


 87%|████████▋ | 87/100 [1:49:58<16:19, 75.37s/it]

Average gan loss: 0.21531535685062408, Avg scale loss: 0.0, Avg period loss: 0.03611285984516144


 88%|████████▊ | 88/100 [1:51:14<15:05, 75.42s/it]

Average gan loss: 0.21518827974796295, Avg scale loss: 0.0, Avg period loss: 0.036118630319833755


 89%|████████▉ | 89/100 [1:52:29<13:49, 75.38s/it]

Average gan loss: 0.21571943163871765, Avg scale loss: 0.0, Avg period loss: 0.03585614636540413


 90%|█████████ | 90/100 [1:53:45<12:34, 75.40s/it]

Average gan loss: 0.21661078929901123, Avg scale loss: 0.0, Avg period loss: 0.03496289998292923


 91%|█████████ | 91/100 [1:55:00<11:19, 75.48s/it]

Average gan loss: 0.21816670894622803, Avg scale loss: 0.0, Avg period loss: 0.034609295427799225


 92%|█████████▏| 92/100 [1:56:16<10:03, 75.47s/it]

Average gan loss: 0.2181045114994049, Avg scale loss: 0.0, Avg period loss: 0.03434328734874725


 93%|█████████▎| 93/100 [1:57:31<08:48, 75.49s/it]

Average gan loss: 0.21859630942344666, Avg scale loss: 0.0, Avg period loss: 0.03402583673596382


 94%|█████████▍| 94/100 [1:58:46<07:32, 75.42s/it]

Average gan loss: 0.2190687358379364, Avg scale loss: 0.0, Avg period loss: 0.03359229490160942


 95%|█████████▌| 95/100 [2:00:02<06:17, 75.45s/it]

Average gan loss: 0.22093850374221802, Avg scale loss: 0.0, Avg period loss: 0.03325840085744858


 96%|█████████▌| 96/100 [2:01:17<05:01, 75.45s/it]

Average gan loss: 0.2216131091117859, Avg scale loss: 0.0, Avg period loss: 0.032506875693798065


 97%|█████████▋| 97/100 [2:02:33<03:46, 75.45s/it]

Average gan loss: 0.22250863909721375, Avg scale loss: 0.0, Avg period loss: 0.032657869160175323


 98%|█████████▊| 98/100 [2:03:48<02:30, 75.45s/it]

Average gan loss: 0.2242295742034912, Avg scale loss: 0.0, Avg period loss: 0.03231711685657501


 99%|█████████▉| 99/100 [2:05:04<01:15, 75.39s/it]

Average gan loss: 0.22289864718914032, Avg scale loss: 0.0, Avg period loss: 0.032436124980449677


100%|██████████| 100/100 [2:06:19<00:00, 75.79s/it]

Average gan loss: 0.2256462723016739, Avg scale loss: 0.0, Avg period loss: 0.03154291212558746





In [None]:
import jax
from Generator import Generator 

key1 = jax.random.PRNGKey(69)
generator = Generator(channels_in=80, channels_out=1, key=key1)
print(generator)
