In [None]:
from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
from flaxdiff.predictors import KarrasPredictionTransform
from flaxdiff.models.simple_unet import Unet
from flaxdiff.trainer import DiffusionTrainer
from flaxdiff.data.datasets import get_dataset_grain
from flaxdiff.utils import defaultTextEncodeModel
from flaxdiff.samplers.euler import EulerAncestralSampler
import jax
import jax.numpy as jnp
import optax
from datetime import datetime

# Images per batch: passed to the dataset loader as `batch_size` and used to
# derive the number of training batches per epoch.
BATCH_SIZE = 16
# Image side length: passed to the dataset loader as `image_scale` and used for
# the height/width of the model's "x" input shape.
IMAGE_SIZE = 128

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Noise schedules. Both are built with the same settings: `edm_schedule` is
# handed to the trainer as its noise schedule, `karas_ve_schedule` is passed to
# `trainer.fit` as the sampling noise schedule.
schedule_settings = dict(sigma_max=80, rho=7, sigma_data=0.5)
edm_schedule = EDMNoiseScheduler(1, **schedule_settings)
karas_ve_schedule = KarrasVENoiseScheduler(1, **schedule_settings)

# Attention settings, one entry per value in `feature_depths`.
# The first level has no attention; the next three share one configuration;
# the last level switches off `use_projection` and `use_self_and_cross`.
shared_attention = {
    "heads": 8,
    "dtype": jnp.float16,
    "flash_attention": False,
    "use_projection": True,
    "use_self_and_cross": True,
}
attention_per_level = [None]
# Independent dict copies, so no two levels share the same config object.
attention_per_level.extend(dict(shared_attention) for _ in range(3))
attention_per_level.append(
    {**shared_attention, "use_projection": False, "use_self_and_cross": False}
)

# Model definition
unet = Unet(
    emb_features=256,
    feature_depths=[64, 64, 128, 256, 512],
    attention_configs=attention_per_level,
    num_res_blocks=2,
    num_middle_res_blocks=1,
)

In [3]:
# Load the Oxford Flowers 102 dataset, batched and scaled with the notebook constants.
data = get_dataset_grain(
    "oxford_flowers102",
    batch_size=BATCH_SIZE,
    image_scale=IMAGE_SIZE,
)
datalen = data['train_len']
# Whole batches per epoch; integer division drops any trailing partial batch.
batches = datalen // BATCH_SIZE

# Per-sample input shapes (no batch dimension) handed to the trainer.
# NOTE(review): (77, 768) presumably is (token count, embedding width) of the
# text encoder output — confirm against the encoder in use.
input_shapes = dict(
    x=(IMAGE_SIZE, IMAGE_SIZE, 3),
    temb=(),
    textcontext=(77, 768),
)

In [4]:
# Text encoder: its `tokenize` method is used to tokenize the validation
# prompts, and the object is passed to the trainer as `encoder`.
text_encoder = defaultTextEncodeModel()

2025-04-08 05:32:54.024023: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744090374.049239  527485 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744090374.056681  527485 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744090374.075269  527485 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744090374.075312  527485 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744090374.075314  527485 computation_placer.cc:177] computation placer alr

In [5]:
# Fixed text prompts that make up the validation set.
# They are laid out eight per row, matching the default batch size of
# `get_val_dataset`. The leading spaces in most prompts are part of the
# original strings and are kept unchanged.
val_prompts = [
    'water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold',
    ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower', ' a photo of a lotus', ' columbine', ' columbine', ' an orchid',
    ' an orchid', ' an orchid', ' a water lily', ' a water lily', ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower',
    ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose',
    ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose', ' a lenten rose', ' a water lily', ' a water lily', ' a water lily',
    ' a water lily', ' an orchid', ' an orchid', ' an orchid', ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose',
    ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose',
    ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily', ' a water lily', ' a water lily', ' a red rose', ' a red rose',
]

def get_val_dataset(batch_size=8, prompts=None, encoder=None):
    """Yield tokenized validation prompts in consecutive batches.

    Args:
        batch_size: Number of prompts per yielded batch; must be positive.
            The final batch is shorter when the prompt count is not a
            multiple of ``batch_size``.
        prompts: Optional sequence of prompt strings. When omitted, the
            notebook-level ``val_prompts`` list is used.
        encoder: Optional object exposing ``tokenize(list_of_prompts)``.
            When omitted, the notebook-level ``text_encoder`` is used.

    Yields:
        The value returned by ``encoder.tokenize`` for each slice of prompts,
        in the original prompt order.

    Raises:
        ValueError: If ``batch_size`` is not positive. As this is a generator,
            the error surfaces when iteration starts.
    """
    # A negative step would make range() yield nothing, so validation would
    # silently see no data; a zero step fails with an unrelated-looking
    # range() error. Reject both explicitly.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    # Fall back to the notebook globals only when no override is supplied.
    prompts = val_prompts if prompts is None else prompts
    encoder = text_encoder if encoder is None else encoder

    for start in range(0, len(prompts), batch_size):
        yield encoder.tokenize(prompts[start:start + batch_size])

# Register the validation source: the generator *function* itself is stored
# under 'test' (it is not called here).
data['test'] = get_val_dataset
# Total number of validation prompts (individual prompts, not batches).
data['test_len'] = len(val_prompts)

In [6]:
# Define optimizer: Adam with a fixed learning rate of 2e-4.
solver = optax.adam(2e-4)

# Take the timestamp once so the trainer name and the wandb run name always
# carry the same suffix. Two separate datetime.now() calls can straddle a
# second boundary and make the two names disagree.
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H:%M:%S")

# Create trainer
trainer = DiffusionTrainer(
    unet,
    optimizer=solver,
    input_shapes=input_shapes,
    noise_schedule=edm_schedule,
    rngs=jax.random.PRNGKey(4),  # fixed PRNG seed
    name="Diffusion_SDE_VE_" + run_timestamp,
    # The prediction transform uses the same sigma_data as the training schedule.
    model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data),
    encoder=text_encoder,
    distributed_training=True,
    wandb_config={
        "project": 'mlops-msml605-project',
        "name": f"prototype-{run_timestamp}",
    },
)


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mashishkumar4[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




Generating states for DiffusionTrainer


In [None]:
# Train the model.
# `data` and `batches` are passed positionally; everything else is collected as
# keyword options. The Karras VE schedule is supplied as the sampling noise
# schedule, together with the Euler ancestral sampler class.
fit_options = dict(
    epochs=2000,
    sampler_class=EulerAncestralSampler,
    sampling_noise_schedule=karas_ve_schedule,
)
final_state = trainer.fit(data, batches, **fit_options)

		Epoch 96: 600step [00:30, 19.58step/s, loss=0.0896]                                               

[32mEpoch done on index 0 => 96 Loss: 0.07253348082304001[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 96 completed. Avg Loss: 0.07253348082304001, Time: 30.64s, Best Loss: 0.07236480712890625[0m
Validation started for process index 0



100%|██████████| 200/200 [00:25<00:00,  7.85it/s]


[32mValidation done on process index 0[0m

Epoch 97/2000


		Epoch 97:   0%|                                            | 0/511 [00:00<?, ?step/s, loss=0.0821]

First batch loaded at step 49567
Training started for process index 0 at step 49567


		Epoch 97: 600step [00:32, 18.65step/s, loss=0.0585]                                               

[32mEpoch done on index 0 => 97 Loss: 0.07285355776548386[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 97 completed. Avg Loss: 0.07285355776548386, Time: 32.17s, Best Loss: 0.07236480712890625[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.46it/s]


[32mValidation done on process index 0[0m

Epoch 98/2000


		Epoch 98:   0%|                                            | 0/511 [00:00<?, ?step/s, loss=0.0638]

First batch loaded at step 50078
Training started for process index 0 at step 50078


		Epoch 98: 600step [00:33, 18.15step/s, loss=0.0703]                                               

[32mEpoch done on index 0 => 98 Loss: 0.07217347621917725[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 98 step 50589





[32m
	Epoch 98 completed. Avg Loss: 0.07217347621917725, Time: 33.07s, Best Loss: 0.07217347621917725[0m
Validation started for process index 0


100%|██████████| 200/200 [00:21<00:00,  9.47it/s]


[32mValidation done on process index 0[0m

Epoch 99/2000


		Epoch 99:   0%|                                            | 0/511 [00:00<?, ?step/s, loss=0.0725]

First batch loaded at step 50589
Training started for process index 0 at step 50589


		Epoch 99: 600step [00:31, 19.33step/s, loss=0.0665]                                               

[32mEpoch done on index 0 => 99 Loss: 0.0727015808224678[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 99 completed. Avg Loss: 0.0727015808224678, Time: 31.04s, Best Loss: 0.07217347621917725[0m
Validation started for process index 0



100%|██████████| 200/200 [00:23<00:00,  8.56it/s]


[32mValidation done on process index 0[0m

Epoch 100/2000


		Epoch 100:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0771]

First batch loaded at step 51100
Training started for process index 0 at step 51100


		Epoch 100: 600step [00:30, 19.93step/s, loss=0.0671]                                              

[32mEpoch done on index 0 => 100 Loss: 0.07300964742898941[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 100 completed. Avg Loss: 0.07300964742898941, Time: 30.10s, Best Loss: 0.07217347621917725[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.14it/s]


[32mValidation done on process index 0[0m

Epoch 101/2000


		Epoch 101:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0830]

First batch loaded at step 51611
Training started for process index 0 at step 51611


		Epoch 101: 600step [00:29, 20.05step/s, loss=0.1006]                                              

[32mEpoch done on index 0 => 101 Loss: 0.07275137305259705[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 101 completed. Avg Loss: 0.07275137305259705, Time: 29.94s, Best Loss: 0.07217347621917725[0m
Validation started for process index 0



100%|██████████| 200/200 [00:22<00:00,  8.93it/s]


[32mValidation done on process index 0[0m

Epoch 102/2000


		Epoch 102:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0616]

First batch loaded at step 52122
Training started for process index 0 at step 52122


		Epoch 102: 600step [00:30, 19.72step/s, loss=0.0661]                                              

[32mEpoch done on index 0 => 102 Loss: 0.07347641885280609[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 102 completed. Avg Loss: 0.07347641885280609, Time: 30.42s, Best Loss: 0.07217347621917725[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.19it/s]


[32mValidation done on process index 0[0m

Epoch 103/2000


		Epoch 103:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0899]

First batch loaded at step 52633
Training started for process index 0 at step 52633


		Epoch 103: 600step [00:31, 19.12step/s, loss=0.0633]                                              

[32mEpoch done on index 0 => 103 Loss: 0.07193342596292496[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 103 step 53144





[32m
	Epoch 103 completed. Avg Loss: 0.07193342596292496, Time: 31.39s, Best Loss: 0.07193342596292496[0m
Validation started for process index 0


100%|██████████| 200/200 [00:21<00:00,  9.35it/s]


[32mValidation done on process index 0[0m

Epoch 104/2000


		Epoch 104:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0651]

First batch loaded at step 53144
Training started for process index 0 at step 53144


		Epoch 104: 600step [00:31, 19.13step/s, loss=0.0980]                                              

[32mEpoch done on index 0 => 104 Loss: 0.07361845672130585[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 104 completed. Avg Loss: 0.07361845672130585, Time: 31.37s, Best Loss: 0.07193342596292496[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.25it/s]


[32mValidation done on process index 0[0m

Epoch 105/2000


		Epoch 105:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0601]

First batch loaded at step 53655
Training started for process index 0 at step 53655


		Epoch 105: 600step [00:29, 20.58step/s, loss=0.0576]                                              

[32mEpoch done on index 0 => 105 Loss: 0.07254830002784729[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 105 completed. Avg Loss: 0.07254830002784729, Time: 29.16s, Best Loss: 0.07193342596292496[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.20it/s]


[32mValidation done on process index 0[0m

Epoch 106/2000


		Epoch 106:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0554]

First batch loaded at step 54166
Training started for process index 0 at step 54166


		Epoch 106: 600step [00:30, 19.77step/s, loss=0.0821]                                              

[32mEpoch done on index 0 => 106 Loss: 0.07343030720949173[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 106 completed. Avg Loss: 0.07343030720949173, Time: 30.35s, Best Loss: 0.07193342596292496[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.36it/s]


[32mValidation done on process index 0[0m

Epoch 107/2000


		Epoch 107:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0768]

First batch loaded at step 54677
Training started for process index 0 at step 54677


		Epoch 107: 600step [00:30, 19.70step/s, loss=0.0593]                                              

[32mEpoch done on index 0 => 107 Loss: 0.07227365672588348[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 107 completed. Avg Loss: 0.07227365672588348, Time: 30.46s, Best Loss: 0.07193342596292496[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.25it/s]


[32mValidation done on process index 0[0m

Epoch 108/2000


		Epoch 108:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0691]

First batch loaded at step 55188
Training started for process index 0 at step 55188


		Epoch 108: 600step [00:30, 19.76step/s, loss=0.0777]                                              

[32mEpoch done on index 0 => 108 Loss: 0.07215934246778488[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 108 completed. Avg Loss: 0.07215934246778488, Time: 30.37s, Best Loss: 0.07193342596292496[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.26it/s]


[32mValidation done on process index 0[0m

Epoch 109/2000


		Epoch 109:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0721]

First batch loaded at step 55699
Training started for process index 0 at step 55699


		Epoch 109: 600step [00:30, 19.47step/s, loss=0.0820]                                              

[32mEpoch done on index 0 => 109 Loss: 0.07310600578784943[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 109 completed. Avg Loss: 0.07310600578784943, Time: 30.83s, Best Loss: 0.07193342596292496[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.29it/s]


[32mValidation done on process index 0[0m

Epoch 110/2000


		Epoch 110:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0875]

First batch loaded at step 56210
Training started for process index 0 at step 56210


		Epoch 110: 600step [00:30, 19.83step/s, loss=0.0729]                                              

[32mEpoch done on index 0 => 110 Loss: 0.07238931953907013[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 110 completed. Avg Loss: 0.07238931953907013, Time: 30.26s, Best Loss: 0.07193342596292496[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.28it/s]


[32mValidation done on process index 0[0m

Epoch 111/2000


		Epoch 111:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0738]

First batch loaded at step 56721
Training started for process index 0 at step 56721


		Epoch 111: 600step [00:30, 19.89step/s, loss=0.0919]                                              

[32mEpoch done on index 0 => 111 Loss: 0.07308177649974823[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 111 completed. Avg Loss: 0.07308177649974823, Time: 30.18s, Best Loss: 0.07193342596292496[0m
Validation started for process index 0



100%|██████████| 200/200 [00:22<00:00,  8.77it/s]


[32mValidation done on process index 0[0m

Epoch 112/2000


		Epoch 112:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0695]

First batch loaded at step 57232
Training started for process index 0 at step 57232


		Epoch 112: 600step [00:32, 18.32step/s, loss=0.0778]                                              

[32mEpoch done on index 0 => 112 Loss: 0.07218995690345764[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 112 completed. Avg Loss: 0.07218995690345764, Time: 32.75s, Best Loss: 0.07193342596292496[0m
Validation started for process index 0



100%|██████████| 200/200 [00:22<00:00,  8.99it/s]


[32mValidation done on process index 0[0m

Epoch 113/2000


		Epoch 113:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0754]

First batch loaded at step 57743
Training started for process index 0 at step 57743


		Epoch 113: 600step [00:31, 19.28step/s, loss=0.0662]                                              

[32mEpoch done on index 0 => 113 Loss: 0.07230154424905777[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 113 completed. Avg Loss: 0.07230154424905777, Time: 31.13s, Best Loss: 0.07193342596292496[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.25it/s]


[32mValidation done on process index 0[0m

Epoch 114/2000


		Epoch 114:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0592]

First batch loaded at step 58254
Training started for process index 0 at step 58254


		Epoch 114: 600step [00:31, 19.08step/s, loss=0.0677]                                              

[32mEpoch done on index 0 => 114 Loss: 0.07148570567369461[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 114 step 58765





[32m
	Epoch 114 completed. Avg Loss: 0.07148570567369461, Time: 31.45s, Best Loss: 0.07148570567369461[0m
Validation started for process index 0


100%|██████████| 200/200 [00:21<00:00,  9.14it/s]


[32mValidation done on process index 0[0m

Epoch 115/2000


		Epoch 115:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0771]

First batch loaded at step 58765
Training started for process index 0 at step 58765


		Epoch 115: 600step [00:31, 19.12step/s, loss=0.0706]                                              

[32mEpoch done on index 0 => 115 Loss: 0.07351625710725784[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 115 completed. Avg Loss: 0.07351625710725784, Time: 31.38s, Best Loss: 0.07148570567369461[0m
Validation started for process index 0



100%|██████████| 200/200 [00:22<00:00,  8.95it/s]


[32mValidation done on process index 0[0m

Epoch 116/2000


		Epoch 116:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0655]

First batch loaded at step 59276
Training started for process index 0 at step 59276


		Epoch 116: 600step [00:29, 20.21step/s, loss=0.0746]                                              

[32mEpoch done on index 0 => 116 Loss: 0.07141997665166855[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 116 step 59787





[32m
	Epoch 116 completed. Avg Loss: 0.07141997665166855, Time: 29.69s, Best Loss: 0.07141997665166855[0m
Validation started for process index 0


100%|██████████| 200/200 [00:21<00:00,  9.37it/s]


[32mValidation done on process index 0[0m

Epoch 117/2000


		Epoch 117:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0779]

First batch loaded at step 59787
Training started for process index 0 at step 59787


		Epoch 117: 600step [00:30, 19.49step/s, loss=0.0711]                                              

[32mEpoch done on index 0 => 117 Loss: 0.07231315225362778[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 117 completed. Avg Loss: 0.07231315225362778, Time: 30.79s, Best Loss: 0.07141997665166855[0m
Validation started for process index 0



100%|██████████| 200/200 [00:23<00:00,  8.47it/s]


[32mValidation done on process index 0[0m

Epoch 118/2000


		Epoch 118:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0730]

First batch loaded at step 60298
Training started for process index 0 at step 60298


		Epoch 118: 600step [00:30, 19.50step/s, loss=0.0864]                                              

[32mEpoch done on index 0 => 118 Loss: 0.07244674116373062[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 118 completed. Avg Loss: 0.07244674116373062, Time: 30.78s, Best Loss: 0.07141997665166855[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.24it/s]


[32mValidation done on process index 0[0m

Epoch 119/2000


		Epoch 119:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0884]

First batch loaded at step 60809
Training started for process index 0 at step 60809


		Epoch 119: 600step [00:30, 19.44step/s, loss=0.0938]                                              

[32mEpoch done on index 0 => 119 Loss: 0.07229142636060715[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 119 completed. Avg Loss: 0.07229142636060715, Time: 30.87s, Best Loss: 0.07141997665166855[0m
Validation started for process index 0



100%|██████████| 200/200 [00:22<00:00,  9.08it/s]


[32mValidation done on process index 0[0m

Epoch 120/2000


		Epoch 120:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0765]

First batch loaded at step 61320
Training started for process index 0 at step 61320


		Epoch 120: 600step [00:32, 18.74step/s, loss=0.0713]                                              

[32mEpoch done on index 0 => 120 Loss: 0.07193206995725632[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 120 completed. Avg Loss: 0.07193206995725632, Time: 32.02s, Best Loss: 0.07141997665166855[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.26it/s]


[32mValidation done on process index 0[0m

Epoch 121/2000


		Epoch 121:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0858]

First batch loaded at step 61831
Training started for process index 0 at step 61831


		Epoch 121: 600step [00:32, 18.47step/s, loss=0.0739]                                              

[32mEpoch done on index 0 => 121 Loss: 0.07297328859567642[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 121 completed. Avg Loss: 0.07297328859567642, Time: 32.49s, Best Loss: 0.07141997665166855[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.13it/s]


[32mValidation done on process index 0[0m

Epoch 122/2000


		Epoch 122:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0694]

First batch loaded at step 62342
Training started for process index 0 at step 62342


		Epoch 122: 600step [00:31, 19.23step/s, loss=0.0571]                                              

[32mEpoch done on index 0 => 122 Loss: 0.07336069643497467[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 122 completed. Avg Loss: 0.07336069643497467, Time: 31.20s, Best Loss: 0.07141997665166855[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.25it/s]


[32mValidation done on process index 0[0m

Epoch 123/2000


		Epoch 123:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0778]

First batch loaded at step 62853
Training started for process index 0 at step 62853


		Epoch 123: 600step [00:33, 17.80step/s, loss=0.0709]                                              

[32mEpoch done on index 0 => 123 Loss: 0.07186485081911087[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 123 completed. Avg Loss: 0.07186485081911087, Time: 33.72s, Best Loss: 0.07141997665166855[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.41it/s]


[32mValidation done on process index 0[0m

Epoch 124/2000


		Epoch 124:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0494]

First batch loaded at step 63364
Training started for process index 0 at step 63364


		Epoch 124: 600step [00:30, 19.44step/s, loss=0.0813]                                              

[32mEpoch done on index 0 => 124 Loss: 0.07174386829137802[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 124 completed. Avg Loss: 0.07174386829137802, Time: 30.87s, Best Loss: 0.07141997665166855[0m
Validation started for process index 0



100%|██████████| 200/200 [00:23<00:00,  8.67it/s]


[32mValidation done on process index 0[0m

Epoch 125/2000


		Epoch 125:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0514]

First batch loaded at step 63875
Training started for process index 0 at step 63875


		Epoch 125: 600step [00:32, 18.27step/s, loss=0.0701]                                              

[32mEpoch done on index 0 => 125 Loss: 0.0722045749425888[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 125 completed. Avg Loss: 0.0722045749425888, Time: 32.84s, Best Loss: 0.07141997665166855[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.39it/s]


[32mValidation done on process index 0[0m

Epoch 126/2000


		Epoch 126:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0680]

First batch loaded at step 64386
Training started for process index 0 at step 64386


		Epoch 126: 600step [00:31, 19.19step/s, loss=0.0832]                                              

[32mEpoch done on index 0 => 126 Loss: 0.07238595932722092[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 126 completed. Avg Loss: 0.07238595932722092, Time: 31.28s, Best Loss: 0.07141997665166855[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.40it/s]


[32mValidation done on process index 0[0m

Epoch 127/2000


		Epoch 127:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0565]

First batch loaded at step 64897
Training started for process index 0 at step 64897


		Epoch 127: 600step [00:31, 19.04step/s, loss=0.0461]                                              

[32mEpoch done on index 0 => 127 Loss: 0.07233043015003204[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 127 completed. Avg Loss: 0.07233043015003204, Time: 31.51s, Best Loss: 0.07141997665166855[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.13it/s]


[32mValidation done on process index 0[0m

Epoch 128/2000


		Epoch 128:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0834]

First batch loaded at step 65408
Training started for process index 0 at step 65408


		Epoch 128: 600step [00:30, 19.82step/s, loss=0.0681]                                              

[32mEpoch done on index 0 => 128 Loss: 0.07242085039615631[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 128 completed. Avg Loss: 0.07242085039615631, Time: 30.28s, Best Loss: 0.07141997665166855[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.29it/s]


[32mValidation done on process index 0[0m

Epoch 129/2000


		Epoch 129:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0679]

First batch loaded at step 65919
Training started for process index 0 at step 65919


		Epoch 129: 600step [00:31, 19.31step/s, loss=0.0738]                                              

[32mEpoch done on index 0 => 129 Loss: 0.07110010832548141[0m
[32mEpoch done on process index 0[0m
Saving model at epoch 129 step 66430





[32m
	Epoch 129 completed. Avg Loss: 0.07110010832548141, Time: 31.07s, Best Loss: 0.07110010832548141[0m
Validation started for process index 0


100%|██████████| 200/200 [00:21<00:00,  9.39it/s]


[32mValidation done on process index 0[0m

Epoch 130/2000


		Epoch 130:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0819]

First batch loaded at step 66430
Training started for process index 0 at step 66430


		Epoch 130: 600step [00:32, 18.64step/s, loss=0.0956]                                              

[32mEpoch done on index 0 => 130 Loss: 0.0730186402797699[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 130 completed. Avg Loss: 0.0730186402797699, Time: 32.19s, Best Loss: 0.07110010832548141[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.22it/s]


[32mValidation done on process index 0[0m

Epoch 131/2000


		Epoch 131:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0796]

First batch loaded at step 66941
Training started for process index 0 at step 66941


		Epoch 131: 600step [00:30, 19.87step/s, loss=0.0655]                                              

[32mEpoch done on index 0 => 131 Loss: 0.07194973528385162[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 131 completed. Avg Loss: 0.07194973528385162, Time: 30.20s, Best Loss: 0.07110010832548141[0m
Validation started for process index 0



100%|██████████| 200/200 [00:20<00:00,  9.55it/s]


[32mValidation done on process index 0[0m

Epoch 132/2000


		Epoch 132:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0786]

First batch loaded at step 67452
Training started for process index 0 at step 67452


		Epoch 132: 600step [00:31, 19.23step/s, loss=0.0799]                                              

[32mEpoch done on index 0 => 132 Loss: 0.07239669561386108[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 132 completed. Avg Loss: 0.07239669561386108, Time: 31.20s, Best Loss: 0.07110010832548141[0m
Validation started for process index 0



100%|██████████| 200/200 [00:22<00:00,  9.02it/s]


[32mValidation done on process index 0[0m

Epoch 133/2000


		Epoch 133:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0783]

First batch loaded at step 67963
Training started for process index 0 at step 67963


		Epoch 133: 600step [00:30, 19.72step/s, loss=0.0749]                                              

[32mEpoch done on index 0 => 133 Loss: 0.07230673730373383[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 133 completed. Avg Loss: 0.07230673730373383, Time: 30.42s, Best Loss: 0.07110010832548141[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.49it/s]


[32mValidation done on process index 0[0m

Epoch 134/2000


		Epoch 134:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0597]

First batch loaded at step 68474
Training started for process index 0 at step 68474


		Epoch 134: 600step [00:30, 19.48step/s, loss=0.0861]                                              

[32mEpoch done on index 0 => 134 Loss: 0.07289659231901169[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 134 completed. Avg Loss: 0.07289659231901169, Time: 30.81s, Best Loss: 0.07110010832548141[0m
Validation started for process index 0



100%|██████████| 200/200 [00:22<00:00,  9.08it/s]


[32mValidation done on process index 0[0m

Epoch 135/2000


		Epoch 135:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0860]

First batch loaded at step 68985
Training started for process index 0 at step 68985


		Epoch 135: 600step [00:32, 18.64step/s, loss=0.0854]                                              

[32mEpoch done on index 0 => 135 Loss: 0.07352910935878754[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 135 completed. Avg Loss: 0.07352910935878754, Time: 32.19s, Best Loss: 0.07110010832548141[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.27it/s]


[32mValidation done on process index 0[0m

Epoch 136/2000


		Epoch 136:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0811]

First batch loaded at step 69496
Training started for process index 0 at step 69496


		Epoch 136: 600step [00:30, 19.80step/s, loss=0.1005]                                              

[32mEpoch done on index 0 => 136 Loss: 0.07220148295164108[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 136 completed. Avg Loss: 0.07220148295164108, Time: 30.31s, Best Loss: 0.07110010832548141[0m
Validation started for process index 0



100%|██████████| 200/200 [00:21<00:00,  9.46it/s]


[32mValidation done on process index 0[0m

Epoch 137/2000


		Epoch 137:   0%|                                           | 0/511 [00:00<?, ?step/s, loss=0.0691]

First batch loaded at step 70007
Training started for process index 0 at step 70007


		Epoch 137: 600step [00:31, 19.33step/s, loss=0.0813]                                              

[32mEpoch done on index 0 => 137 Loss: 0.07264856994152069[0m
[32mEpoch done on process index 0[0m
[32m
	Epoch 137 completed. Avg Loss: 0.07264856994152069, Time: 31.04s, Best Loss: 0.07110010832548141[0m
Validation started for process index 0



  0%|          | 0/200 [00:00<?, ?it/s]