In [1]:
from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
from flaxdiff.predictors import KarrasPredictionTransform
from flaxdiff.models.simple_unet import Unet
from flaxdiff.trainer.general_diffusion_trainer import GeneralDiffusionTrainer, ConditionalInputConfig
from flaxdiff.data.datasets import get_dataset_grain
from flaxdiff.utils import defaultTextEncodeModel
from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE
from flaxdiff.samplers.euler import EulerAncestralSampler
import jax
import jax.numpy as jnp
import optax
from datetime import datetime
import argparse
import os

# Batch size: passed to the dataset loader and used to derive `batches`
# (the per-epoch batch count handed to `trainer.fit`).
BATCH_SIZE = 16
# Image side length: used as the dataset `image_scale`, as the height/width of
# `sample_data_shape`, and as the trainer's `native_resolution`.
IMAGE_SIZE = 256

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Dataset: Oxford Flowers 102, loaded with the notebook-wide batch size and image scale.
data = get_dataset_grain(
    "oxford_flowers102",
    batch_size=BATCH_SIZE,
    image_scale=IMAGE_SIZE,
)

# Number of whole batches that fit in `train_len` (floor division drops any remainder).
datalen = data['train_len']
batches = datalen // BATCH_SIZE

# Text encoder used for conditioning, and the pretrained VAE used as the autoencoder.
text_encoder = defaultTextEncodeModel()
autoencoder = StableDiffusionVAE(modelname="pcuenq/sd-vae-ft-mse-flax")

# Fixed validation prompts (64 in total). They are laid out as 8 rows of 8, which
# matches the default `batch_size=8` of `get_val_dataset`; order and spelling
# (including the leading spaces) are kept exactly as they were.
val_prompts = [
    'water tulip', ' a water lily', ' a water lily', ' a photo of a rose',
    ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold',

    ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower',
    ' a photo of a lotus', ' columbine', ' columbine', ' an orchid',

    ' an orchid', ' an orchid', ' a water lily', ' a water lily',
    ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower',

    ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus',
    ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose',

    ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose',
    ' a lenten rose', ' a water lily', ' a water lily', ' a water lily',

    ' a water lily', ' an orchid', ' an orchid', ' an orchid',
    ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose',

    ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose',
    ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose',

    ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily',
    ' a water lily', ' a water lily', ' a red rose', ' a red rose',
]

def get_val_dataset(batch_size=8):
    """Yield the validation prompts as tokenized batches.

    Walks `val_prompts` in order, `batch_size` prompts at a time (the final
    chunk is shorter when the list length is not a multiple of `batch_size`),
    tokenizes each chunk with `text_encoder.tokenize`, and yields it as
    ``{"text": tokens}``.
    """
    offsets = range(0, len(val_prompts), batch_size)
    # Generator expression keeps slicing lazy: a chunk is cut only when requested.
    chunks = (val_prompts[offset:offset + batch_size] for offset in offsets)
    for chunk in chunks:
        yield {"text": text_encoder.tokenize(chunk)}

# Register the validation source on the dataset mapping. Note that the generator
# *function* itself is stored (not a generator instance), together with the total
# number of validation prompts.
data['test'] = get_val_dataset
data['test_len'] = len(val_prompts)


2025-04-12 13:11:02.488048: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744463463.166512   23190 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744463463.483736   23190 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744463466.930492   23190 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744463466.930522   23190 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744463466.930524   23190 computation_placer.cc:177] computation placer alr

Scaling factor: 0.18215
Calculating downscale factor...
Downscale factor: 8
Latent channels: 4


In [3]:
from flax import linen as nn
from diffusers import FlaxUNet2DConditionModel
from flaxdiff.inputs import DiffusionInputConfig, ConditionalInputConfig

# Text conditioning: batches carry already-tokenized prompts under 'text'
# (pretokenized=True), the empty string is the unconditional input, and the
# condition is exposed under the model key "textcontext".
text_condition = ConditionalInputConfig(
    encoder=text_encoder,
    conditioning_data_key='text',
    pretokenized=True,
    unconditional_input="",
    model_key_override="textcontext",
)

# Samples are read from the 'image' key with shape IMAGE_SIZE x IMAGE_SIZE x 3.
input_config = DiffusionInputConfig(
    sample_data_key='image',
    sample_data_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
    conditions=[text_condition],
)

# Shapes of the model inputs given the autoencoder; the call prints them
# (this run reported 'x': (32, 32, 4), 'temb': (), 'textcontext': (77, 768)).
input_shapes = input_config.get_input_shapes(autoencoder=autoencoder)

# Diffusers UNet sized from the computed input shapes. Because those shapes are
# derived with the autoencoder, the model works on latents (this run reported
# x as (32, 32, 4)), not on raw IMAGE_SIZE x IMAGE_SIZE RGB pixels.
unet_model = FlaxUNet2DConditionModel(
    sample_size=input_shapes["x"][1],  # spatial size of the model input (latent resolution here)
    in_channels=input_shapes["x"][2],  # channels of the model input (latent channels, not 3 RGB channels)
    out_channels=input_shapes["x"][2],  # the prediction has the same channel count as the input
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(64, 128, 256, 512),  # the number of output channels for each UNet block
    # Feature width of the conditioning sequence. Taken from the computed
    # text-context shape rather than a hard-coded 512, which disagreed with the
    # (77, 768) shape the text encoder actually produces.
    cross_attention_dim=input_shapes["textcontext"][-1],
    dtype=jnp.bfloat16,
    use_memory_efficient_attention=True,
)
        
class BCHWModelWrapper(nn.Module):
    """Adapter that lets a BCHW model be driven with BHWC tensors.

    The wrapped `model` is called with keyword arguments `sample`,
    `timesteps` and `encoder_hidden_states`, and its result must expose a
    `.sample` attribute. Inputs are permuted BHWC -> BCHW before the call and
    the result is permuted back BCHW -> BHWC afterwards.
    """
    model: nn.Module

    @nn.compact
    def __call__(self, x, temb, textcontext):
        # Axis permutations for 4-D tensors.
        to_bchw = (0, 3, 1, 2)
        to_bhwc = (0, 2, 3, 1)

        # Channels-last in, channels-first for the wrapped model.
        prediction = self.model(
            sample=jnp.transpose(x, to_bchw),
            timesteps=temb,
            encoder_hidden_states=textcontext,
        )
        # Back to channels-last for the caller.
        return jnp.transpose(prediction.sample, to_bhwc)
    
# Wrap the diffusers UNet so it can be called with BHWC inputs.
# NOTE(review): the next cell (also labelled In [3]) rebinds `unet` to a flaxdiff
# `Unet`, so which model the trainer receives depends on which cell ran last.
unet = BCHWModelWrapper(unet_model)

Calculated input shapes: {'x': (32, 32, 4), 'temb': (), 'textcontext': (77, 768)}


In [3]:
from flaxdiff.inputs import DiffusionInputConfig, ConditionalInputConfig

# Text conditioning: batches carry already-tokenized prompts under 'text'
# (pretokenized=True), the empty string is the unconditional input, and the
# condition is exposed under the model key "textcontext".
text_condition = ConditionalInputConfig(
    encoder=text_encoder,
    conditioning_data_key='text',
    pretokenized=True,
    unconditional_input="",
    model_key_override="textcontext",
)

# Samples are read from the 'image' key with shape IMAGE_SIZE x IMAGE_SIZE x 3.
input_config = DiffusionInputConfig(
    sample_data_key='image',
    sample_data_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
    conditions=[text_condition],
)

# Shapes of the model inputs given the autoencoder; the call prints them
# (this run reported 'x': (32, 32, 4), 'temb': (), 'textcontext': (77, 768)).
input_shapes = input_config.get_input_shapes(autoencoder=autoencoder)

# flaxdiff UNet with five feature depths and one attention config per depth:
# none for the first, then four 8-head float32 configs of which the first three
# set use_self_and_cross=True and the last sets it to False.
# The output channel count follows the model input's channel dimension.
unet = Unet(
    emb_features=256,
    feature_depths=[64, 64, 128, 256, 512],
    attention_configs=[None] + [
        {
            "heads": 8,
            "dtype": jnp.float32,
            "flash_attention": False,
            "use_projection": False,
            "use_self_and_cross": self_and_cross,
        }
        for self_and_cross in (True, True, True, False)
    ],
    num_res_blocks=2,
    num_middle_res_blocks=1,
    dtype=jnp.bfloat16,
    output_channels=input_shapes["x"][2],
)

Calculated input shapes: {'x': (32, 32, 4), 'temb': (), 'textcontext': (77, 768)}


In [4]:
# Noise schedules. Both are built with identical sigma_max / rho / sigma_data
# settings; the EDM schedule goes to the trainer, the Karras VE schedule is
# passed to `fit` as the sampling noise schedule.
schedule_kwargs = dict(sigma_max=80, rho=7, sigma_data=0.5)
edm_schedule = EDMNoiseScheduler(1, **schedule_kwargs)
karas_ve_schedule = KarrasVENoiseScheduler(1, **schedule_kwargs)

# Optimizer: Adam with a constant learning rate.
solver = optax.adam(learning_rate=2e-4)

# Run name stamped with the current local time; used for the trainer and the W&B run.
run_timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
experiment_name = f"General_Diffusion_{run_timestamp}"

# Model-output transform, parameterised with the training schedule's sigma_data.
prediction_transform = KarrasPredictionTransform(
    sigma_data=edm_schedule.sigma_data,
)

# Weights & Biases run settings; the run carries the experiment name and records
# the batch size, image size and architecture label.
wandb_config = {
    "project": 'mlops-msml605-project',
    "entity": 'umd-projects',
    "name": experiment_name,
    "config": {
        "batch_size": BATCH_SIZE,
        "image_size": IMAGE_SIZE,
        "architecture": "unet",
    },
}

# Assemble the trainer from the model, optimizer, training noise schedule,
# autoencoder and input configuration defined in the cells above.
trainer = GeneralDiffusionTrainer(
    unet,
    optimizer=solver,
    noise_schedule=edm_schedule,
    autoencoder=autoencoder,
    input_config=input_config,
    rngs=jax.random.PRNGKey(42),  # fixed seed
    name=experiment_name,
    model_output_transform=prediction_transform,
    distributed_training=True,
    wandb_config=wandb_config,
    native_resolution=IMAGE_SIZE,
)

Calculated input shapes: {'x': (32, 32, 4), 'temb': (), 'textcontext': (77, 768)}


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mashishkumar4[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Generating states for DiffusionTrainer


In [5]:
# Run training for 2000 epochs, passing the per-epoch batch count computed when
# the dataset was loaded. The Euler ancestral sampler is paired with the
# Karras VE schedule via the sampler arguments.
final_state = trainer.fit(
    data,
    batches,
    epochs=2000,
    sampler_class=EulerAncestralSampler,
    sampling_noise_schedule=karas_ve_schedule,
)

		Epoch 1475: 600step [00:26, 22.34step/s, loss=0.5232]                                             

[32mEpoch done on index 0 => 1475 Loss: 0.4699305593967438[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1475 completed. Avg Loss: 0.4699305593967438, Time: 26.86s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.51it/s]


[32mValidation done on process index 0[0m

Epoch 1476/2000


		Epoch 1476:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4465]

First batch loaded at step 754236
Training started for process index 0 at step 754236


		Epoch 1476: 600step [00:27, 22.20step/s, loss=0.4653]                                             

[32mEpoch done on index 0 => 1476 Loss: 0.46798717975616455[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1476 completed. Avg Loss: 0.46798717975616455, Time: 27.03s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.08it/s]


[32mValidation done on process index 0[0m

Epoch 1477/2000


		Epoch 1477:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.5279]

First batch loaded at step 754747
Training started for process index 0 at step 754747


		Epoch 1477: 600step [00:26, 22.60step/s, loss=0.5680]                                             

[32mEpoch done on index 0 => 1477 Loss: 0.46927326917648315[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1477 completed. Avg Loss: 0.46927326917648315, Time: 26.55s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.56it/s]


[32mValidation done on process index 0[0m

Epoch 1478/2000


		Epoch 1478:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4609]

First batch loaded at step 755258
Training started for process index 0 at step 755258


		Epoch 1478: 600step [00:26, 22.45step/s, loss=0.4745]                                             

[32mEpoch done on index 0 => 1478 Loss: 0.46889442205429077[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1478 completed. Avg Loss: 0.46889442205429077, Time: 26.73s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 111.09it/s]


[32mValidation done on process index 0[0m

Epoch 1479/2000


		Epoch 1479:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4646]

First batch loaded at step 755769
Training started for process index 0 at step 755769


		Epoch 1479: 600step [00:27, 22.19step/s, loss=0.4725]                                             

[32mEpoch done on index 0 => 1479 Loss: 0.4689628481864929[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1479 completed. Avg Loss: 0.4689628481864929, Time: 27.04s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.36it/s]


[32mValidation done on process index 0[0m

Epoch 1480/2000


		Epoch 1480:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.5225]

First batch loaded at step 756280
Training started for process index 0 at step 756280


		Epoch 1480: 600step [00:26, 22.30step/s, loss=0.4822]                                             

[32mEpoch done on index 0 => 1480 Loss: 0.46772485971450806[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1480 completed. Avg Loss: 0.46772485971450806, Time: 26.91s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.74it/s]


[32mValidation done on process index 0[0m

Epoch 1481/2000


		Epoch 1481:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4535]

First batch loaded at step 756791
Training started for process index 0 at step 756791


		Epoch 1481: 600step [00:26, 22.42step/s, loss=0.4467]                                             

[32mEpoch done on index 0 => 1481 Loss: 0.4691937565803528[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1481 completed. Avg Loss: 0.4691937565803528, Time: 26.76s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 117.81it/s]


[32mValidation done on process index 0[0m

Epoch 1482/2000


		Epoch 1482:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4897]

First batch loaded at step 757302
Training started for process index 0 at step 757302


		Epoch 1482: 600step [00:27, 21.87step/s, loss=0.5098]                                             

[32mEpoch done on index 0 => 1482 Loss: 0.468335896730423[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1482 completed. Avg Loss: 0.468335896730423, Time: 27.43s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 117.60it/s]


[32mValidation done on process index 0[0m

Epoch 1483/2000


		Epoch 1483:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4462]

First batch loaded at step 757813
Training started for process index 0 at step 757813


		Epoch 1483: 600step [00:26, 23.01step/s, loss=0.4715]                                             

[32mEpoch done on index 0 => 1483 Loss: 0.47192198038101196[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1483 completed. Avg Loss: 0.47192198038101196, Time: 26.08s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 115.83it/s]


[32mValidation done on process index 0[0m

Epoch 1484/2000


		Epoch 1484:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.5057]

First batch loaded at step 758324
Training started for process index 0 at step 758324


		Epoch 1484: 600step [00:27, 21.79step/s, loss=0.4352]                                             

[32mEpoch done on index 0 => 1484 Loss: 0.4703725278377533[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1484 completed. Avg Loss: 0.4703725278377533, Time: 27.53s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 116.86it/s]


[32mValidation done on process index 0[0m

Epoch 1485/2000


		Epoch 1485:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4507]

First batch loaded at step 758835
Training started for process index 0 at step 758835


		Epoch 1485: 600step [00:26, 22.49step/s, loss=0.4180]                                             

[32mEpoch done on index 0 => 1485 Loss: 0.4677030146121979[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1485 completed. Avg Loss: 0.4677030146121979, Time: 26.68s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.61it/s]


[32mValidation done on process index 0[0m

Epoch 1486/2000


		Epoch 1486:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4836]

First batch loaded at step 759346
Training started for process index 0 at step 759346


		Epoch 1486: 600step [00:26, 22.65step/s, loss=0.4840]                                             

[32mEpoch done on index 0 => 1486 Loss: 0.46962007880210876[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1486 completed. Avg Loss: 0.46962007880210876, Time: 26.49s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.26it/s]


[32mValidation done on process index 0[0m

Epoch 1487/2000


		Epoch 1487:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4580]

First batch loaded at step 759857
Training started for process index 0 at step 759857


		Epoch 1487: 600step [00:26, 22.42step/s, loss=0.4657]                                             

[32mEpoch done on index 0 => 1487 Loss: 0.46994540095329285[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1487 completed. Avg Loss: 0.46994540095329285, Time: 26.76s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 116.70it/s]


[32mValidation done on process index 0[0m

Epoch 1488/2000


		Epoch 1488:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.5313]

First batch loaded at step 760368
Training started for process index 0 at step 760368


		Epoch 1488: 600step [00:26, 22.55step/s, loss=0.4569]                                             

[32mEpoch done on index 0 => 1488 Loss: 0.46668046712875366[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1488 completed. Avg Loss: 0.46668046712875366, Time: 26.61s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.08it/s]


[32mValidation done on process index 0[0m

Epoch 1489/2000


		Epoch 1489:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4368]

First batch loaded at step 760879
Training started for process index 0 at step 760879


		Epoch 1489: 600step [00:26, 22.28step/s, loss=0.4446]                                             

[32mEpoch done on index 0 => 1489 Loss: 0.46838825941085815[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1489 completed. Avg Loss: 0.46838825941085815, Time: 27.00s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 117.10it/s]


[32mValidation done on process index 0[0m

Epoch 1490/2000


		Epoch 1490:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4817]

First batch loaded at step 761390
Training started for process index 0 at step 761390


		Epoch 1490: 600step [00:27, 21.86step/s, loss=0.4779]                                             

[32mEpoch done on index 0 => 1490 Loss: 0.4682980477809906[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1490 completed. Avg Loss: 0.4682980477809906, Time: 27.45s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 117.80it/s]


[32mValidation done on process index 0[0m

Epoch 1491/2000


		Epoch 1491:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4931]

First batch loaded at step 761901
Training started for process index 0 at step 761901


		Epoch 1491: 600step [00:27, 21.89step/s, loss=0.5158]                                             

[32mEpoch done on index 0 => 1491 Loss: 0.4662906229496002[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1491 completed. Avg Loss: 0.4662906229496002, Time: 27.41s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 116.74it/s]


[32mValidation done on process index 0[0m

Epoch 1492/2000


		Epoch 1492:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4321]

First batch loaded at step 762412
Training started for process index 0 at step 762412


		Epoch 1492: 600step [00:27, 21.98step/s, loss=0.4510]                                             

[32mEpoch done on index 0 => 1492 Loss: 0.46941182017326355[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1492 completed. Avg Loss: 0.46941182017326355, Time: 27.30s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 117.50it/s]


[32mValidation done on process index 0[0m

Epoch 1493/2000


		Epoch 1493:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.5106]

First batch loaded at step 762923
Training started for process index 0 at step 762923


		Epoch 1493: 600step [00:27, 22.03step/s, loss=0.4292]                                             

[32mEpoch done on index 0 => 1493 Loss: 0.47032368183135986[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1493 completed. Avg Loss: 0.47032368183135986, Time: 27.24s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.59it/s]


[32mValidation done on process index 0[0m

Epoch 1494/2000


		Epoch 1494:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4439]

First batch loaded at step 763434
Training started for process index 0 at step 763434


		Epoch 1494: 600step [00:26, 22.26step/s, loss=0.4882]                                             

[32mEpoch done on index 0 => 1494 Loss: 0.4663403630256653[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1494 completed. Avg Loss: 0.4663403630256653, Time: 26.96s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 109.25it/s]


[32mValidation done on process index 0[0m

Epoch 1495/2000


		Epoch 1495:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4712]

First batch loaded at step 763945
Training started for process index 0 at step 763945


		Epoch 1495: 600step [00:28, 21.42step/s, loss=0.4837]                                             

[32mEpoch done on index 0 => 1495 Loss: 0.46778395771980286[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1495 completed. Avg Loss: 0.46778395771980286, Time: 28.02s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 109.02it/s]


[32mValidation done on process index 0[0m

Epoch 1496/2000


		Epoch 1496:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4423]

First batch loaded at step 764456
Training started for process index 0 at step 764456


		Epoch 1496: 600step [00:27, 22.05step/s, loss=0.4289]                                             

[32mEpoch done on index 0 => 1496 Loss: 0.46825721859931946[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1496 completed. Avg Loss: 0.46825721859931946, Time: 27.21s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 116.77it/s]


[32mValidation done on process index 0[0m

Epoch 1497/2000


		Epoch 1497:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4667]

First batch loaded at step 764967
Training started for process index 0 at step 764967


		Epoch 1497: 600step [00:27, 22.04step/s, loss=0.4367]                                             

[32mEpoch done on index 0 => 1497 Loss: 0.46788260340690613[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1497 completed. Avg Loss: 0.46788260340690613, Time: 27.23s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 115.35it/s]


[32mValidation done on process index 0[0m

Epoch 1498/2000


		Epoch 1498:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4852]

First batch loaded at step 765478
Training started for process index 0 at step 765478


		Epoch 1498: 600step [00:27, 22.00step/s, loss=0.4712]                                             

[32mEpoch done on index 0 => 1498 Loss: 0.46709874272346497[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1498 completed. Avg Loss: 0.46709874272346497, Time: 27.27s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.23it/s]


[32mValidation done on process index 0[0m

Epoch 1499/2000


		Epoch 1499:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4530]

First batch loaded at step 765989
Training started for process index 0 at step 765989


		Epoch 1499: 600step [00:27, 21.66step/s, loss=0.4348]                                             

[32mEpoch done on index 0 => 1499 Loss: 0.46812382340431213[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1499 completed. Avg Loss: 0.46812382340431213, Time: 27.71s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 115.78it/s]


[32mValidation done on process index 0[0m

Epoch 1500/2000


		Epoch 1500:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.5141]

First batch loaded at step 766500
Training started for process index 0 at step 766500


		Epoch 1500: 600step [00:27, 21.54step/s, loss=0.4639]                                             

[32mEpoch done on index 0 => 1500 Loss: 0.47118183970451355[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1500 completed. Avg Loss: 0.47118183970451355, Time: 27.86s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.83it/s]


[32mValidation done on process index 0[0m

Epoch 1501/2000


		Epoch 1501:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.5069]

First batch loaded at step 767011
Training started for process index 0 at step 767011


		Epoch 1501: 600step [00:27, 22.03step/s, loss=0.4503]                                             

[32mEpoch done on index 0 => 1501 Loss: 0.4683796167373657[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1501 completed. Avg Loss: 0.4683796167373657, Time: 27.24s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 117.95it/s]


[32mValidation done on process index 0[0m

Epoch 1502/2000


		Epoch 1502:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4610]

First batch loaded at step 767522
Training started for process index 0 at step 767522


		Epoch 1502: 600step [00:27, 21.98step/s, loss=0.4934]                                             

[32mEpoch done on index 0 => 1502 Loss: 0.4709969758987427[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1502 completed. Avg Loss: 0.4709969758987427, Time: 27.30s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.30it/s]


[32mValidation done on process index 0[0m

Epoch 1503/2000


		Epoch 1503:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4881]

First batch loaded at step 768033
Training started for process index 0 at step 768033


		Epoch 1503: 600step [00:26, 22.23step/s, loss=0.4751]                                             

[32mEpoch done on index 0 => 1503 Loss: 0.4683407247066498[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1503 completed. Avg Loss: 0.4683407247066498, Time: 27.00s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 117.58it/s]


[32mValidation done on process index 0[0m

Epoch 1504/2000


		Epoch 1504:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.5252]

First batch loaded at step 768544
Training started for process index 0 at step 768544


		Epoch 1504: 600step [00:26, 22.50step/s, loss=0.4556]                                             

[32mEpoch done on index 0 => 1504 Loss: 0.46880778670310974[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1504 completed. Avg Loss: 0.46880778670310974, Time: 26.67s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 119.42it/s]


[32mValidation done on process index 0[0m

Epoch 1505/2000


		Epoch 1505:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4798]

First batch loaded at step 769055
Training started for process index 0 at step 769055


		Epoch 1505: 600step [00:28, 21.32step/s, loss=0.4614]                                             

[32mEpoch done on index 0 => 1505 Loss: 0.46843093633651733[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1505 completed. Avg Loss: 0.46843093633651733, Time: 28.15s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 116.24it/s]


[32mValidation done on process index 0[0m

Epoch 1506/2000


		Epoch 1506:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4150]

First batch loaded at step 769566
Training started for process index 0 at step 769566


		Epoch 1506: 600step [00:27, 22.05step/s, loss=0.4257]                                             

[32mEpoch done on index 0 => 1506 Loss: 0.47002676129341125[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1506 completed. Avg Loss: 0.47002676129341125, Time: 27.21s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 116.16it/s]


[32mValidation done on process index 0[0m

Epoch 1507/2000


		Epoch 1507:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4706]

First batch loaded at step 770077
Training started for process index 0 at step 770077


		Epoch 1507: 600step [00:27, 22.09step/s, loss=0.4259]                                             

[32mEpoch done on index 0 => 1507 Loss: 0.46826231479644775[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1507 completed. Avg Loss: 0.46826231479644775, Time: 27.16s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.22it/s]


[32mValidation done on process index 0[0m

Epoch 1508/2000


		Epoch 1508:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4079]

First batch loaded at step 770588
Training started for process index 0 at step 770588


		Epoch 1508: 600step [00:27, 22.19step/s, loss=0.3936]                                             

[32mEpoch done on index 0 => 1508 Loss: 0.4718356132507324[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1508 completed. Avg Loss: 0.4718356132507324, Time: 27.05s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.39it/s]


[32mValidation done on process index 0[0m

Epoch 1509/2000


		Epoch 1509:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4378]

First batch loaded at step 771099
Training started for process index 0 at step 771099


		Epoch 1509: 600step [00:26, 22.32step/s, loss=0.4786]                                             

[32mEpoch done on index 0 => 1509 Loss: 0.46852320432662964[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1509 completed. Avg Loss: 0.46852320432662964, Time: 26.88s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.55it/s]


[32mValidation done on process index 0[0m

Epoch 1510/2000


		Epoch 1510:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4906]

First batch loaded at step 771610
Training started for process index 0 at step 771610


		Epoch 1510: 600step [00:27, 21.91step/s, loss=0.4556]                                             

[32mEpoch done on index 0 => 1510 Loss: 0.4670238196849823[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1510 completed. Avg Loss: 0.4670238196849823, Time: 27.39s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.59it/s]


[32mValidation done on process index 0[0m

Epoch 1511/2000


		Epoch 1511:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.5447]

First batch loaded at step 772121
Training started for process index 0 at step 772121


		Epoch 1511: 600step [00:27, 22.15step/s, loss=0.5134]                                             

[32mEpoch done on index 0 => 1511 Loss: 0.46990641951560974[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1511 completed. Avg Loss: 0.46990641951560974, Time: 27.09s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.80it/s]


[32mValidation done on process index 0[0m

Epoch 1512/2000


		Epoch 1512:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4616]

First batch loaded at step 772632
Training started for process index 0 at step 772632


		Epoch 1512: 600step [00:27, 22.02step/s, loss=0.4448]                                             

[32mEpoch done on index 0 => 1512 Loss: 0.4684750437736511[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1512 completed. Avg Loss: 0.4684750437736511, Time: 27.25s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.44it/s]


[32mValidation done on process index 0[0m

Epoch 1513/2000


		Epoch 1513:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4918]

First batch loaded at step 773143
Training started for process index 0 at step 773143


		Epoch 1513: 600step [00:27, 22.01step/s, loss=0.4863]                                             

[32mEpoch done on index 0 => 1513 Loss: 0.46790367364883423[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1513 completed. Avg Loss: 0.46790367364883423, Time: 27.27s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 119.56it/s]


[32mValidation done on process index 0[0m

Epoch 1514/2000


		Epoch 1514:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4455]

First batch loaded at step 773654
Training started for process index 0 at step 773654


		Epoch 1514: 600step [00:26, 22.49step/s, loss=0.4350]                                             

[32mEpoch done on index 0 => 1514 Loss: 0.4691774249076843[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1514 completed. Avg Loss: 0.4691774249076843, Time: 26.69s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 114.09it/s]


[32mValidation done on process index 0[0m

Epoch 1515/2000


		Epoch 1515:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4406]

First batch loaded at step 774165
Training started for process index 0 at step 774165


		Epoch 1515: 600step [00:26, 22.78step/s, loss=0.4731]                                             

[32mEpoch done on index 0 => 1515 Loss: 0.4703384041786194[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1515 completed. Avg Loss: 0.4703384041786194, Time: 26.34s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 121.04it/s]


[32mValidation done on process index 0[0m

Epoch 1516/2000


		Epoch 1516:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.5213]

First batch loaded at step 774676
Training started for process index 0 at step 774676


		Epoch 1516: 600step [00:27, 21.78step/s, loss=0.4731]                                             

[32mEpoch done on index 0 => 1516 Loss: 0.46927598118782043[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1516 completed. Avg Loss: 0.46927598118782043, Time: 27.56s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 112.93it/s]


[32mValidation done on process index 0[0m

Epoch 1517/2000


		Epoch 1517:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4650]

First batch loaded at step 775187
Training started for process index 0 at step 775187


		Epoch 1517: 600step [00:26, 22.39step/s, loss=0.5116]                                             

[32mEpoch done on index 0 => 1517 Loss: 0.46821358799934387[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1517 completed. Avg Loss: 0.46821358799934387, Time: 26.80s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 119.68it/s]


[32mValidation done on process index 0[0m

Epoch 1518/2000


		Epoch 1518:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4835]

First batch loaded at step 775698
Training started for process index 0 at step 775698


		Epoch 1518: 600step [00:26, 22.36step/s, loss=0.4948]                                             

[32mEpoch done on index 0 => 1518 Loss: 0.4663558900356293[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1518 completed. Avg Loss: 0.4663558900356293, Time: 26.84s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.17it/s]


[32mValidation done on process index 0[0m

Epoch 1519/2000


		Epoch 1519:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4903]

First batch loaded at step 776209
Training started for process index 0 at step 776209


		Epoch 1519: 600step [00:26, 22.31step/s, loss=0.4663]                                             

[32mEpoch done on index 0 => 1519 Loss: 0.4685025215148926[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1519 completed. Avg Loss: 0.4685025215148926, Time: 26.89s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 119.68it/s]


[32mValidation done on process index 0[0m

Epoch 1520/2000


		Epoch 1520:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4092]

First batch loaded at step 776720
Training started for process index 0 at step 776720


		Epoch 1520: 600step [00:27, 22.17step/s, loss=0.5366]                                             

[32mEpoch done on index 0 => 1520 Loss: 0.46999451518058777[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1520 completed. Avg Loss: 0.46999451518058777, Time: 27.07s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.80it/s]


[32mValidation done on process index 0[0m

Epoch 1521/2000


		Epoch 1521:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4949]

First batch loaded at step 777231
Training started for process index 0 at step 777231


		Epoch 1521: 600step [00:26, 22.24step/s, loss=0.4603]                                             

[32mEpoch done on index 0 => 1521 Loss: 0.4683949053287506[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1521 completed. Avg Loss: 0.4683949053287506, Time: 26.98s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 113.30it/s]


[32mValidation done on process index 0[0m

Epoch 1522/2000


		Epoch 1522:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4949]

First batch loaded at step 777742
Training started for process index 0 at step 777742


		Epoch 1522: 600step [00:26, 22.24step/s, loss=0.5045]                                             

[32mEpoch done on index 0 => 1522 Loss: 0.4679831564426422[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1522 completed. Avg Loss: 0.4679831564426422, Time: 26.98s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 119.04it/s]


[32mValidation done on process index 0[0m

Epoch 1523/2000


		Epoch 1523:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4812]

First batch loaded at step 778253
Training started for process index 0 at step 778253


		Epoch 1523: 600step [00:27, 22.01step/s, loss=0.4796]                                             

[32mEpoch done on index 0 => 1523 Loss: 0.4688240587711334[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1523 completed. Avg Loss: 0.4688240587711334, Time: 27.27s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.91it/s]


[32mValidation done on process index 0[0m

Epoch 1524/2000


		Epoch 1524:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4704]

First batch loaded at step 778764
Training started for process index 0 at step 778764


		Epoch 1524: 600step [00:26, 22.26step/s, loss=0.4859]                                             

[32mEpoch done on index 0 => 1524 Loss: 0.4659760296344757[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1524 completed. Avg Loss: 0.4659760296344757, Time: 26.96s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 119.38it/s]


[32mValidation done on process index 0[0m

Epoch 1525/2000


		Epoch 1525:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4356]

First batch loaded at step 779275
Training started for process index 0 at step 779275


		Epoch 1525: 600step [00:27, 22.18step/s, loss=0.5079]                                             

[32mEpoch done on index 0 => 1525 Loss: 0.46703651547431946[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1525 completed. Avg Loss: 0.46703651547431946, Time: 27.05s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.48it/s]


[32mValidation done on process index 0[0m

Epoch 1526/2000


		Epoch 1526:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4901]

First batch loaded at step 779786
Training started for process index 0 at step 779786


		Epoch 1526: 600step [00:28, 21.39step/s, loss=0.4779]                                             

[32mEpoch done on index 0 => 1526 Loss: 0.46822288632392883[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1526 completed. Avg Loss: 0.46822288632392883, Time: 28.06s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 119.06it/s]


[32mValidation done on process index 0[0m

Epoch 1527/2000


		Epoch 1527:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4482]

First batch loaded at step 780297
Training started for process index 0 at step 780297


		Epoch 1527: 600step [00:26, 22.31step/s, loss=0.5028]                                             

[32mEpoch done on index 0 => 1527 Loss: 0.46881717443466187[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1527 completed. Avg Loss: 0.46881717443466187, Time: 26.89s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 119.06it/s]


[32mValidation done on process index 0[0m

Epoch 1528/2000


		Epoch 1528:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4751]

First batch loaded at step 780808
Training started for process index 0 at step 780808


		Epoch 1528: 600step [00:26, 22.32step/s, loss=0.5028]                                             

[32mEpoch done on index 0 => 1528 Loss: 0.4676021933555603[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1528 completed. Avg Loss: 0.4676021933555603, Time: 26.88s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.41it/s]


[32mValidation done on process index 0[0m

Epoch 1529/2000


		Epoch 1529:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4529]

First batch loaded at step 781319
Training started for process index 0 at step 781319


		Epoch 1529: 600step [00:26, 22.32step/s, loss=0.4536]                                             

[32mEpoch done on index 0 => 1529 Loss: 0.4688307046890259[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1529 completed. Avg Loss: 0.4688307046890259, Time: 26.89s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.28it/s]


[32mValidation done on process index 0[0m

Epoch 1530/2000


		Epoch 1530:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4856]

First batch loaded at step 781830
Training started for process index 0 at step 781830


		Epoch 1530: 600step [00:27, 22.17step/s, loss=0.5035]                                             

[32mEpoch done on index 0 => 1530 Loss: 0.47000652551651[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1530 completed. Avg Loss: 0.47000652551651, Time: 27.07s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.60it/s]


[32mValidation done on process index 0[0m

Epoch 1531/2000


		Epoch 1531:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4235]

First batch loaded at step 782341
Training started for process index 0 at step 782341


		Epoch 1531: 600step [00:26, 22.34step/s, loss=0.4712]                                             

[32mEpoch done on index 0 => 1531 Loss: 0.4685263931751251[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1531 completed. Avg Loss: 0.4685263931751251, Time: 26.86s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 116.57it/s]


[32mValidation done on process index 0[0m

Epoch 1532/2000


		Epoch 1532:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4291]

First batch loaded at step 782852
Training started for process index 0 at step 782852


		Epoch 1532: 600step [00:27, 22.08step/s, loss=0.4345]                                             

[32mEpoch done on index 0 => 1532 Loss: 0.4677729308605194[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1532 completed. Avg Loss: 0.4677729308605194, Time: 27.18s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 120.36it/s]


[32mValidation done on process index 0[0m

Epoch 1533/2000


		Epoch 1533:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4835]

First batch loaded at step 783363
Training started for process index 0 at step 783363


		Epoch 1533: 600step [00:26, 22.28step/s, loss=0.4484]                                             

[32mEpoch done on index 0 => 1533 Loss: 0.4668722450733185[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1533 completed. Avg Loss: 0.4668722450733185, Time: 26.93s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 119.34it/s]


[32mValidation done on process index 0[0m

Epoch 1534/2000


		Epoch 1534:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4369]

First batch loaded at step 783874
Training started for process index 0 at step 783874


		Epoch 1534: 600step [00:27, 22.08step/s, loss=0.4632]                                             

[32mEpoch done on index 0 => 1534 Loss: 0.4679631292819977[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1534 completed. Avg Loss: 0.4679631292819977, Time: 27.18s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.47it/s]


[32mValidation done on process index 0[0m

Epoch 1535/2000


		Epoch 1535:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4497]

First batch loaded at step 784385
Training started for process index 0 at step 784385


		Epoch 1535: 600step [00:26, 22.46step/s, loss=0.4681]                                             

[32mEpoch done on index 0 => 1535 Loss: 0.46574872732162476[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1535 completed. Avg Loss: 0.46574872732162476, Time: 26.72s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 119.66it/s]


[32mValidation done on process index 0[0m

Epoch 1536/2000


		Epoch 1536:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4392]

First batch loaded at step 784896
Training started for process index 0 at step 784896


		Epoch 1536: 600step [00:27, 22.13step/s, loss=0.4408]                                             

[32mEpoch done on index 0 => 1536 Loss: 0.4671768546104431[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1536 completed. Avg Loss: 0.4671768546104431, Time: 27.11s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.63it/s]


[32mValidation done on process index 0[0m

Epoch 1537/2000


		Epoch 1537:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4653]

First batch loaded at step 785407
Training started for process index 0 at step 785407


		Epoch 1537: 600step [00:27, 22.11step/s, loss=0.4575]                                             

[32mEpoch done on index 0 => 1537 Loss: 0.4697463810443878[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1537 completed. Avg Loss: 0.4697463810443878, Time: 27.13s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 119.77it/s]


[32mValidation done on process index 0[0m

Epoch 1538/2000


		Epoch 1538:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4564]

First batch loaded at step 785918
Training started for process index 0 at step 785918


		Epoch 1538: 600step [00:28, 21.28step/s, loss=0.4908]                                             

[32mEpoch done on index 0 => 1538 Loss: 0.46610766649246216[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1538 completed. Avg Loss: 0.46610766649246216, Time: 28.20s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.89it/s]


[32mValidation done on process index 0[0m

Epoch 1539/2000


		Epoch 1539:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4456]

First batch loaded at step 786429
Training started for process index 0 at step 786429


		Epoch 1539: 600step [00:25, 23.16step/s, loss=0.5181]                                             

[32mEpoch done on index 0 => 1539 Loss: 0.4708362817764282[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1539 completed. Avg Loss: 0.4708362817764282, Time: 25.91s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 117.70it/s]


[32mValidation done on process index 0[0m

Epoch 1540/2000


		Epoch 1540:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4338]

First batch loaded at step 786940
Training started for process index 0 at step 786940


		Epoch 1540: 600step [00:27, 22.18step/s, loss=0.4380]                                             

[32mEpoch done on index 0 => 1540 Loss: 0.4671047627925873[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1540 completed. Avg Loss: 0.4671047627925873, Time: 27.05s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 117.90it/s]


[32mValidation done on process index 0[0m

Epoch 1541/2000


		Epoch 1541:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4824]

First batch loaded at step 787451
Training started for process index 0 at step 787451


		Epoch 1541: 600step [00:26, 22.26step/s, loss=0.4518]                                             

[32mEpoch done on index 0 => 1541 Loss: 0.467400461435318[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1541 completed. Avg Loss: 0.467400461435318, Time: 26.96s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.31it/s]


[32mValidation done on process index 0[0m

Epoch 1542/2000


		Epoch 1542:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4646]

First batch loaded at step 787962
Training started for process index 0 at step 787962


		Epoch 1542: 600step [00:27, 22.04step/s, loss=0.5106]                                             

[32mEpoch done on index 0 => 1542 Loss: 0.4687298834323883[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1542 completed. Avg Loss: 0.4687298834323883, Time: 27.23s, Best Loss: 0.4655294418334961[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.26it/s]


[32mValidation done on process index 0[0m

Epoch 1543/2000


		Epoch 1543:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.5002]

First batch loaded at step 788473
Training started for process index 0 at step 788473


		Epoch 1543: 600step [00:27, 21.82step/s, loss=0.4802]                                             

[32mEpoch done on index 0 => 1543 Loss: 0.46508604288101196[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 1543 step 788984





[32m
	Epoch 1543 completed. Avg Loss: 0.46508604288101196, Time: 27.50s, Best Loss: 0.46508604288101196[0m
Validation started for process index 0


100%|██████████| 200/200 [00:01<00:00, 114.16it/s]


[32mValidation done on process index 0[0m

Epoch 1544/2000


		Epoch 1544:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4768]

First batch loaded at step 788984
Training started for process index 0 at step 788984


		Epoch 1544: 600step [00:27, 22.22step/s, loss=0.4652]                                             

[32mEpoch done on index 0 => 1544 Loss: 0.46932077407836914[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1544 completed. Avg Loss: 0.46932077407836914, Time: 27.01s, Best Loss: 0.46508604288101196[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 115.99it/s]


[32mValidation done on process index 0[0m

Epoch 1545/2000


		Epoch 1545:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4557]

First batch loaded at step 789495
Training started for process index 0 at step 789495


		Epoch 1545: 600step [00:25, 23.30step/s, loss=0.5102]                                             

[32mEpoch done on index 0 => 1545 Loss: 0.4684477746486664[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1545 completed. Avg Loss: 0.4684477746486664, Time: 25.76s, Best Loss: 0.46508604288101196[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 121.79it/s]


[32mValidation done on process index 0[0m

Epoch 1546/2000


		Epoch 1546:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4899]

First batch loaded at step 790006
Training started for process index 0 at step 790006


		Epoch 1546: 600step [00:26, 22.70step/s, loss=0.4653]                                             

[32mEpoch done on index 0 => 1546 Loss: 0.46814167499542236[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1546 completed. Avg Loss: 0.46814167499542236, Time: 26.43s, Best Loss: 0.46508604288101196[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 121.79it/s]


[32mValidation done on process index 0[0m

Epoch 1547/2000


		Epoch 1547:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4551]

First batch loaded at step 790517
Training started for process index 0 at step 790517


		Epoch 1547: 600step [00:27, 21.85step/s, loss=0.4790]                                             

[32mEpoch done on index 0 => 1547 Loss: 0.4689295291900635[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1547 completed. Avg Loss: 0.4689295291900635, Time: 27.47s, Best Loss: 0.46508604288101196[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 119.49it/s]


[32mValidation done on process index 0[0m

Epoch 1548/2000


		Epoch 1548:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4479]

First batch loaded at step 791028
Training started for process index 0 at step 791028


		Epoch 1548: 600step [00:26, 22.26step/s, loss=0.5068]                                             

[32mEpoch done on index 0 => 1548 Loss: 0.46703776717185974[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1548 completed. Avg Loss: 0.46703776717185974, Time: 26.96s, Best Loss: 0.46508604288101196[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 120.78it/s]


[32mValidation done on process index 0[0m

Epoch 1549/2000


		Epoch 1549:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4266]

First batch loaded at step 791539
Training started for process index 0 at step 791539


		Epoch 1549: 600step [00:26, 22.45step/s, loss=0.5074]                                             

[32mEpoch done on index 0 => 1549 Loss: 0.46821272373199463[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1549 completed. Avg Loss: 0.46821272373199463, Time: 26.73s, Best Loss: 0.46508604288101196[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 118.85it/s]


[32mValidation done on process index 0[0m

Epoch 1550/2000


		Epoch 1550:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4748]

First batch loaded at step 792050
Training started for process index 0 at step 792050


		Epoch 1550: 600step [00:26, 22.29step/s, loss=0.4429]                                             

[32mEpoch done on index 0 => 1550 Loss: 0.46687814593315125[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 1550 completed. Avg Loss: 0.46687814593315125, Time: 26.92s, Best Loss: 0.46508604288101196[0m
Validation started for process index 0



100%|██████████| 200/200 [00:01<00:00, 119.06it/s]


[32mValidation done on process index 0[0m

Epoch 1551/2000


		Epoch 1551:   0%|                                          | 0/511 [00:00<?, ?step/s, loss=0.4530]

First batch loaded at step 792561
Training started for process index 0 at step 792561


		Epoch 1551:  98%|███████████████████████████████▎| 500/511 [00:21<00:00, 21.40step/s, loss=0.4651]Process SpawnProcess-21:
Process SpawnProcess-25:
Process SpawnProcess-8:
Process SpawnProcess-20:
Process SpawnProcess-11:
Process SpawnProcess-29:
Process SpawnProcess-16:
Process SpawnProcess-18:
Process SpawnProcess-31:
Process SpawnProcess-17:
Process SpawnProcess-26:
Process SpawnProcess-23:
Process SpawnProcess-7:
Process SpawnProcess-27:
Process SpawnProcess-12:
Process SpawnProcess-5:
Process SpawnProcess-14:
Process SpawnProcess-19:
Process SpawnProcess-30:
Process SpawnProcess-10:
Process SpawnProcess-28:
Process SpawnProcess-2:
Process SpawnProcess-6:
Process SpawnProcess-15:
Process SpawnProcess-13:
Process SpawnProcess-32:
Process SpawnProcess-3:
Process SpawnProcess-24:
Process SpawnProcess-1:
Process SpawnProcess-4:
Process SpawnProcess-9:
Process SpawnProcess-22:
Traceback (most recent call last):
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/multip

KeyboardInterrupt: 

"/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/multiprocessing/queues.py", line 89, in put
    if not self._sem.acquire(block, timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/multiprocessing/queues.py", line 89, in put
    if not self._sem.acquire(block, timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/grain/_src/python/grain_pool.py", line 236, in _worker_loop
    if not multiprocessing_common.add_element_to_queue(  # pytype: disable=wrong-arg-types
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/mrwhite0racle/miniconda3/envs/flaxdiff/lib/python3.11/site-packages/grain/_src/python/grain_pool.py", line 236, in _worker_loop
    if not multiprocessing_common.add_element_to_queue(  # pytype: disable=wrong-arg-types
           ^^^^^^^^^^^^^^^

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7f2b1c15dc90>> (for post_run_cell), with arguments args (<ExecutionResult object at 7f4d2413fbd0, execution_count=5 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 7f4563d634d0, raw_cell="# Train the model
final_state = trainer.fit(data, .." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2Btpu-v4-8/home/mrwhite0racle/persist/FlaxDiff/prototype_general_pipeline.ipynb#W5sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


MailboxClosedError: 