In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler
import math
import glob
import random
import json

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image
import os
import json
from tqdm import tqdm
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline
from transformers import CLIPTokenizer, CLIPTextModel

In [6]:
import torch
import random
import numpy as np

def seed_everything(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)
# Sinusoidal timestep embedding for diffusion steps
def get_timestep_embedding(timesteps, embedding_dim):
    half_dim = embedding_dim // 2
    emb = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) *
        -(torch.log(torch.tensor(10000.0)) / half_dim)
    )
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # Handle odd embedding dimensions
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)
    return emb

# Residual block with time and context embeddings
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim, context_dim=None):
        super().__init__()
        self.norm1 = nn.GroupNorm(min(32, in_channels), in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(min(32, out_channels), out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.context_proj = nn.Linear(context_dim, out_channels) if context_dim else None
        self.shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

    def forward(self, x, t_emb, context=None):
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)

        # Add time embedding
        t_proj = self.time_mlp(t_emb)[:, :, None, None]
        h = h + t_proj

        # Add context embedding if available
        if self.context_proj is not None and context is not None:
            context_pooled = context.mean(dim=1)  # [batch, context_dim]
            context_proj = self.context_proj(context_pooled)[:, :, None, None]
            h = h + context_proj

        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)

        return h + self.shortcut(x)

# Cross-attention to integrate text embeddings
class CrossAttention(nn.Module):
    def __init__(self, channels, context_dim):
        super().__init__()
        self.channels = channels
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(context_dim, channels)
        self.value = nn.Linear(context_dim, channels)
        self.out = nn.Linear(channels, channels)
        self.norm = nn.LayerNorm(channels)

    def forward(self, x, context):
        if context is None:
            return x

        B, C, H, W = x.shape
        x_flat = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
        x_norm = self.norm(x_flat)

        q = self.query(x_norm)  # [B, H*W, C]
        k = self.key(context)   # [B, seq_len, C]
        v = self.value(context) # [B, seq_len, C]

        scale = (C ** -0.5)
        attn_weights = torch.bmm(q, k.transpose(1, 2)) * scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_out = torch.bmm(attn_weights, v)
        attn_out = self.out(attn_out)

        attn_out = attn_out.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return x + attn_out

# Self-attention block for image features
class AttentionBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(min(32, channels), channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.norm(x)
        qkv = self.qkv(h).reshape(B, 3, C, H * W)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]

        scale = (C ** -0.5)
        attn = torch.bmm(q.transpose(1, 2), k) * scale
        attn = F.softmax(attn, dim=-1)

        out = torch.bmm(v, attn.transpose(1, 2))
        out = out.reshape(B, C, H, W)
        return self.proj(out) + x

# U-Net model updated for 256x256 latents
class UNetConditional(nn.Module):
    def __init__(self, in_channels=4, base_channels=128, context_dim=768):
        super().__init__()
        self.time_emb_dim = base_channels * 4
        from types import SimpleNamespace
        self.config = SimpleNamespace()
        self.config._diffusers_version = "0.34.0"
        self.config.in_channels = in_channels
        self.config.out_channels = in_channels
        self.config.sample_size = 256  # Updated for 256x256 latents
        self.config.layers_per_block = 2
        self.config.block_out_channels = [base_channels, base_channels * 2, base_channels * 4, base_channels * 8]
        self.config.attention_head_dim = 8
        self.config.cross_attention_dim = context_dim

        # Time embedding MLP
        self.time_mlp = nn.Sequential(
            nn.Linear(base_channels, self.time_emb_dim),
            nn.SiLU(),
            nn.Linear(self.time_emb_dim, self.time_emb_dim),
        )

        # Input projection
        self.input_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Encoder
        self.down1 = ResidualBlock(base_channels, base_channels * 2, self.time_emb_dim, context_dim)
        self.downsample1 = nn.Conv2d(base_channels * 2, base_channels * 2, 3, stride=2, padding=1)
        self.cross1 = CrossAttention(base_channels * 2, context_dim)

        self.down2 = ResidualBlock(base_channels * 2, base_channels * 4, self.time_emb_dim, context_dim)
        self.downsample2 = nn.Conv2d(base_channels * 4, base_channels * 4, 3, stride=2, padding=1)
        self.cross2 = CrossAttention(base_channels * 4, context_dim)

        self.down3 = ResidualBlock(base_channels * 4, base_channels * 8, self.time_emb_dim, context_dim)
        self.downsample3 = nn.Conv2d(base_channels * 8, base_channels * 8, 3, stride=2, padding=1)
        self.cross3 = CrossAttention(base_channels * 8, context_dim)

        # Middle
        self.middle1 = ResidualBlock(base_channels * 8, base_channels * 8, self.time_emb_dim, context_dim)
        self.middle_attn = AttentionBlock(base_channels * 8)
        self.middle2 = ResidualBlock(base_channels * 8, base_channels * 8, self.time_emb_dim, context_dim)

        # Decoder
        self.up3 = ResidualBlock(base_channels * 16, base_channels * 4, self.time_emb_dim, context_dim)
        self.upsample3 = nn.ConvTranspose2d(base_channels * 4, base_channels * 4, 4, stride=2, padding=1)
        self.cross_up3 = CrossAttention(base_channels * 4, context_dim)

        self.up2 = ResidualBlock(base_channels * 8, base_channels * 2, self.time_emb_dim, context_dim)
        self.upsample2 = nn.ConvTranspose2d(base_channels * 2, base_channels * 2, 4, stride=2, padding=1)
        self.cross_up2 = CrossAttention(base_channels * 2, context_dim)

        self.up1 = ResidualBlock(base_channels * 4, base_channels, self.time_emb_dim, context_dim)
        self.upsample1 = nn.ConvTranspose2d(base_channels, base_channels, 4, stride=2, padding=1)

        # Output
        self.output_conv = nn.Sequential(
            nn.GroupNorm(min(32, base_channels), base_channels),
            nn.SiLU(),
            nn.Conv2d(base_channels, in_channels, 3, padding=1)
        )

    def forward(self, x, t, context, cfg_scale=1.0):
        t_emb = get_timestep_embedding(t, self.time_emb_dim // 4)
        t_emb = self.time_mlp(t_emb)

        def denoise(x, t_emb, context):
            h = self.input_conv(x)

            # Encoder
            h1 = self.down1(h, t_emb, context)
            h1_cross = self.cross1(h1, context)
            h1_down = self.downsample1(h1_cross)

            h2 = self.down2(h1_down, t_emb, context)
            h2_cross = self.cross2(h2, context)
            h2_down = self.downsample2(h2_cross)

            h3 = self.down3(h2_down, t_emb, context)
            h3_cross = self.cross3(h3, context)
            h3_down = self.downsample3(h3_cross)

            # Middle
            h_mid = self.middle1(h3_down, t_emb, context)
            h_mid = self.middle_attn(h_mid)
            h_mid = self.middle2(h_mid, t_emb, context)

            # Decoder
            h3_cross_resized = F.interpolate(h3_cross, size=h_mid.shape[-2:], mode='nearest')
            h = self.up3(torch.cat([h_mid, h3_cross_resized], dim=1), t_emb, context)
            h = self.upsample3(h)
            h = self.cross_up3(h, context)

            h2_cross_resized = F.interpolate(h2_cross, size=h.shape[-2:], mode='nearest')
            h = self.up2(torch.cat([h, h2_cross_resized], dim=1), t_emb, context)
            h = self.upsample2(h)
            h = self.cross_up2(h, context)

            h1_cross_resized = F.interpolate(h1_cross, size=h.shape[-2:], mode='nearest')
            h = self.up1(torch.cat([h, h1_cross_resized], dim=1), t_emb, context)
            h = self.upsample1(h)

            return self.output_conv(h)

        if cfg_scale == 1.0 or context is None:
            return denoise(x, t_emb, context)

        uncond = denoise(x, t_emb, context=None)
        cond = denoise(x, t_emb, context)
        return uncond + cfg_scale * (cond - uncond)

# Dataset for pixel art images
class PixelArtDataset(Dataset):
    def __init__(self, images_folder, captions, image_size=256):
        self.images_folder = images_folder
        # Adjusting to handle the provided JSON format
        self.captions = [(k, v) for k, v in captions.items() if isinstance(v, str)]
        self.image_size = image_size
        self.transform = T.Compose([
            T.Resize((image_size, image_size), interpolation=T.InterpolationMode.LANCZOS),
            T.ToTensor(),
            T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
        print(f"Initialized dataset with {len(self.captions)} valid image-caption pairs")


    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        max_retries = 5
        for retry in range(max_retries):
            try:
                current_idx = (idx + retry) % len(self.captions)
                image_name, prompt = self.captions[current_idx]
                image_path = os.path.join(self.images_folder, image_name)

                if not os.path.exists(image_path):
                    print(f"Warning: Image not found: {image_path}")
                    continue

                image = Image.open(image_path).convert("RGB")
                image_tensor = self.transform(image)

                return {
                    "image": image_tensor,
                    "prompt": prompt
                }
            except Exception as e:
                print(f"Error loading item {current_idx}: {e}")
                if retry == max_retries - 1:
                    return {
                        "image": torch.zeros(3, self.image_size, self.image_size),
                        "prompt": "a pixel art image"
                    }
                continue

# Collate function for batching
def collate_fn(batch):
    images = torch.stack([item['image'] for item in batch])
    prompts = [item['prompt'] for item in batch]
    return {'image': images, 'prompt': prompts}

# Generate test images with the trained model
def generate_test_images(vae, text_encoder, tokenizer, unet, scheduler, device="cuda", num_inference_steps=20, guidance_scale=7.5):
    print(f"Using device: {device}")

    test_prompts = [
        "a pixel art cat",
        "a pixel art house",
        "a pixel art tree",
        "james graham ballard, highrise, sustainability, octane render, highly detailed",
        "pixel art landscape"
    ]

    print("🎨 Generating test images directly...")

    for i, prompt in enumerate(test_prompts):
        print(f"Generating: {prompt}")

        try:
            with torch.no_grad():
                # Encode prompt
                inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    truncation=True,
                    max_length=tokenizer.model_max_length,
                    return_tensors="pt"
                )
                text_embeddings = text_encoder(inputs.input_ids.to(device))[0]

                # Generate unconditional embeddings for classifier-free guidance
                if guidance_scale > 1.0:
                    uncond_inputs = tokenizer(
                        [""], # Use an empty string for unconditional generation
                        padding="max_length",
                        truncation=True,
                        max_length=tokenizer.model_max_length,
                        return_tensors="pt"
                    )
                    uncond_embeddings = text_encoder(uncond_inputs.input_ids.to(device))[0]
                    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

                # Initial random latents (batch size 1, 4 channels, 32x32 for 256x256 output)
                latents = torch.randn(
                    1,
                    unet.config.in_channels,
                    unet.config.sample_size // 8, # VAE scaling
                    unet.config.sample_size // 8,
                    generator=torch.Generator(device=device).manual_seed(42 + i) # Set seed for reproducibility
                ).to(device)

                # Set timesteps
                scheduler.set_timesteps(num_inference_steps, device=device)
                timesteps = scheduler.timesteps

                # Denoising loop
                for t in tqdm(timesteps, desc=f"Denoising {prompt}"):
                    # Expand latents for classifier-free guidance
                    latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
                    latent_model_input = scheduler.scale_model_input(latent_model_input, t)

                    # Predict noise
                    # Ensure timestep is on the same device as latents and unet
                    t_tensor = torch.tensor([t], device=latents.device)
                    if guidance_scale > 1.0:
                         # Pass the concatenated text embeddings
                        noise_pred = unet(latent_model_input, t_tensor.repeat(2), context=text_embeddings)
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                    else:
                         noise_pred = unet(latent_model_input, t_tensor, context=text_embeddings)


                    # Compute the previous noisy sample x_t -> x_t-1
                    latents = scheduler.step(noise_pred, t, latents).prev_sample

                # Decode latents to image
                latents = 1 / vae.config.scaling_factor * latents # Rescale latents
                image = vae.decode(latents).sample

                # Post-processing
                image = (image / 2 + 0.5).clamp(0, 1) # Denormalize and clamp
                image = image.cpu().permute(0, 2, 3, 1).numpy()[0] # Move to CPU, permute, convert to numpy

                # Convert to PIL Image and save
                image = Image.fromarray((image * 255).astype(np.uint8))
                safe_filename = f"/content/generated_256_{i+1}_{prompt[:30].replace(' ', '_').replace(',', '').replace('.', '')}.png"
                image.save(safe_filename)
                print(f"✅ Saved: {safe_filename}")

        except Exception as e:
            print(f"❌ Error generating '{prompt}': {e}")
            print(f"Error type: {type(e).__name__}")
            # Continue with next prompt
            continue

# Main training function
def train_diffusion_model():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load dataset
    images_folder = 'kkm'
    captions_file = 'kahab.json'
    with open(captions_file, 'r') as f:
        captions = json.load(f)
    print(f"Loaded {len(captions)} image-caption pairs")

    dataset = PixelArtDataset(images_folder, captions, image_size=256)
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
        collate_fn=collate_fn,
        drop_last=True
    )

    # Initialize models
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
    vae = vae.eval().requires_grad_(False).to(device)

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = text_encoder.eval().requires_grad_(False).to(device)

    unet = UNetConditional(in_channels=4, base_channels=128, context_dim=768).to(device)

    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
    optimizer = optim.AdamW(unet.parameters(), lr=1e-4, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000, eta_min=1e-6)

    num_epochs = 600
    save_every = 100
    print("Starting training...")

    for epoch in range(num_epochs):
        epoch_loss = 0
        successful_batches = 0
        unet.train()
        progress_bar = tqdm(loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch_idx, batch in enumerate(progress_bar):
            try:
                images = batch['image'].to(device, non_blocking=True)
                prompts = batch['prompt']

                if images.shape[0] == 0:
                    print(f"Empty batch {batch_idx}, skipping...")
                    continue

                inputs = tokenizer(
                    prompts,
                    padding="max_length",
                    truncation=True,
                    max_length=77,
                    return_tensors="pt"
                )
                inputs = {k: v.to(device, non_blocking=True) for k, v in inputs.items()}

                with torch.no_grad():
                    text_embeddings = text_encoder(**inputs).last_hidden_state
                    latents = vae.encode(images).latent_dist.sample() * 0.18215

                noise = torch.randn_like(latents)
                timesteps = torch.randint(
                    0, noise_scheduler.config.num_train_timesteps,
                    (latents.size(0),), device=device
                ).long()

                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                pred_noise = unet(noisy_latents, timesteps, text_embeddings)

                if pred_noise.shape != noise.shape:
                    print(f"Shape mismatch: pred_noise {pred_noise.shape} vs noise {noise.shape}")
                    continue

                loss = F.mse_loss(pred_noise, noise)
                if torch.isnan(loss):
                    print(f"NaN loss detected in batch {batch_idx}, skipping...")
                    continue

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
                optimizer.step()

                epoch_loss += loss.item()
                successful_batches += 1
                progress_bar.set_postfix({
                    "Loss": f"{loss.item():.4f}",
                    "Success": f"{successful_batches}/{batch_idx+1}"
                })

            except Exception as e:
                print(f"Error in batch {batch_idx}: {e}")
                continue

        if successful_batches > 0:
            scheduler.step()
            avg_loss = epoch_loss / successful_batches
            print(f"Epoch {epoch+1} completed. Average Loss: {avg_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.2e}")
        else:
            print(f"Epoch {epoch+1} failed - no successful batches")

        if (epoch + 1) % save_every == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': unet.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': avg_loss if successful_batches > 0 else float('inf'),
            }
            torch.save(checkpoint, f'output/checkpoint_epoch_{epoch+1}.pt')
            print(f"Checkpoint saved at epoch {epoch+1}")

    print("Training completed!")
    return vae, text_encoder, tokenizer, unet, noise_scheduler

if __name__ == "__main__":
    try:
        vae, text_encoder, tokenizer, trained_unet, scheduler = train_diffusion_model()
        print("Training completed successfully!")

        generate_test_images(vae, text_encoder, tokenizer, trained_unet, scheduler)
    
        torch.save(trained_unet.state_dict(), "output/unet_weights.pt")
        print(" UNet weights saved to output/unet_weights.pt")

    except Exception as e:
        print(f" Training failed: {e}")

Using device: cuda
Loaded 37 image-caption pairs
Initialized dataset with 37 valid image-caption pairs
Starting training...


Epoch 1/600: 100%|█████████████████████████████████████████| 37/37 [00:12<00:00,  3.06it/s, Loss=0.8852, Success=37/37]


Epoch 1 completed. Average Loss: 0.9080, LR: 1.00e-04


Epoch 2/600: 100%|█████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.4182, Success=37/37]


Epoch 2 completed. Average Loss: 0.6114, LR: 1.00e-04


Epoch 3/600: 100%|█████████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.1970, Success=37/37]


Epoch 3 completed. Average Loss: 0.4840, LR: 1.00e-04


Epoch 4/600: 100%|█████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.5956, Success=37/37]


Epoch 4 completed. Average Loss: 0.3293, LR: 1.00e-04


Epoch 5/600: 100%|█████████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.2913, Success=37/37]


Epoch 5 completed. Average Loss: 0.3789, LR: 1.00e-04


Epoch 6/600: 100%|█████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.4226, Success=37/37]


Epoch 6 completed. Average Loss: 0.2568, LR: 1.00e-04


Epoch 7/600: 100%|█████████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.2903, Success=37/37]


Epoch 7 completed. Average Loss: 0.2839, LR: 1.00e-04


Epoch 8/600: 100%|█████████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.3623, Success=37/37]


Epoch 8 completed. Average Loss: 0.2489, LR: 1.00e-04


Epoch 9/600: 100%|█████████████████████████████████████████| 37/37 [00:11<00:00,  3.28it/s, Loss=0.4184, Success=37/37]


Epoch 9 completed. Average Loss: 0.2537, LR: 1.00e-04


Epoch 10/600: 100%|████████████████████████████████████████| 37/37 [00:11<00:00,  3.32it/s, Loss=0.8030, Success=37/37]


Epoch 10 completed. Average Loss: 0.2156, LR: 1.00e-04


Epoch 11/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.36it/s, Loss=0.1605, Success=37/37]


Epoch 11 completed. Average Loss: 0.2156, LR: 1.00e-04


Epoch 12/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.7339, Success=37/37]


Epoch 12 completed. Average Loss: 0.2466, LR: 1.00e-04


Epoch 13/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.4112, Success=37/37]


Epoch 13 completed. Average Loss: 0.1876, LR: 1.00e-04


Epoch 14/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0315, Success=37/37]


Epoch 14 completed. Average Loss: 0.2538, LR: 1.00e-04


Epoch 15/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.1166, Success=37/37]


Epoch 15 completed. Average Loss: 0.1702, LR: 9.99e-05


Epoch 16/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0877, Success=37/37]


Epoch 16 completed. Average Loss: 0.1751, LR: 9.99e-05


Epoch 17/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.2893, Success=37/37]


Epoch 17 completed. Average Loss: 0.1823, LR: 9.99e-05


Epoch 18/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0258, Success=37/37]


Epoch 18 completed. Average Loss: 0.1586, LR: 9.99e-05


Epoch 19/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.36it/s, Loss=0.0430, Success=37/37]


Epoch 19 completed. Average Loss: 0.2326, LR: 9.99e-05


Epoch 20/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0305, Success=37/37]


Epoch 20 completed. Average Loss: 0.2568, LR: 9.99e-05


Epoch 21/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.3895, Success=37/37]


Epoch 21 completed. Average Loss: 0.2709, LR: 9.99e-05


Epoch 22/600: 100%|████████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0276, Success=37/37]


Epoch 22 completed. Average Loss: 0.1959, LR: 9.99e-05


Epoch 23/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.5903, Success=37/37]


Epoch 23 completed. Average Loss: 0.1908, LR: 9.99e-05


Epoch 24/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0215, Success=37/37]


Epoch 24 completed. Average Loss: 0.2056, LR: 9.99e-05


Epoch 25/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.4962, Success=37/37]


Epoch 25 completed. Average Loss: 0.1582, LR: 9.98e-05


Epoch 26/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.3102, Success=37/37]


Epoch 26 completed. Average Loss: 0.2375, LR: 9.98e-05


Epoch 27/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.44it/s, Loss=0.2632, Success=37/37]


Epoch 27 completed. Average Loss: 0.1783, LR: 9.98e-05


Epoch 28/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0313, Success=37/37]


Epoch 28 completed. Average Loss: 0.2329, LR: 9.98e-05


Epoch 29/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0216, Success=37/37]


Epoch 29 completed. Average Loss: 0.2032, LR: 9.98e-05


Epoch 30/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.8924, Success=37/37]


Epoch 30 completed. Average Loss: 0.2258, LR: 9.98e-05


Epoch 31/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0195, Success=37/37]


Epoch 31 completed. Average Loss: 0.1902, LR: 9.98e-05


Epoch 32/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0685, Success=37/37]


Epoch 32 completed. Average Loss: 0.1856, LR: 9.98e-05


Epoch 33/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.6296, Success=37/37]


Epoch 33 completed. Average Loss: 0.1870, LR: 9.97e-05


Epoch 34/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.1229, Success=37/37]


Epoch 34 completed. Average Loss: 0.1411, LR: 9.97e-05


Epoch 35/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0245, Success=37/37]


Epoch 35 completed. Average Loss: 0.2871, LR: 9.97e-05


Epoch 36/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0178, Success=37/37]


Epoch 36 completed. Average Loss: 0.2403, LR: 9.97e-05


Epoch 37/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.2088, Success=37/37]


Epoch 37 completed. Average Loss: 0.1925, LR: 9.97e-05


Epoch 38/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.3533, Success=37/37]


Epoch 38 completed. Average Loss: 0.3055, LR: 9.96e-05


Epoch 39/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0270, Success=37/37]


Epoch 39 completed. Average Loss: 0.1694, LR: 9.96e-05


Epoch 40/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.1548, Success=37/37]


Epoch 40 completed. Average Loss: 0.1561, LR: 9.96e-05


Epoch 41/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0175, Success=37/37]


Epoch 41 completed. Average Loss: 0.2526, LR: 9.96e-05


Epoch 42/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.4942, Success=37/37]


Epoch 42 completed. Average Loss: 0.2582, LR: 9.96e-05


Epoch 43/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0198, Success=37/37]


Epoch 43 completed. Average Loss: 0.2010, LR: 9.95e-05


Epoch 44/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.2760, Success=37/37]


Epoch 44 completed. Average Loss: 0.1489, LR: 9.95e-05


Epoch 45/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.4502, Success=37/37]


Epoch 45 completed. Average Loss: 0.1618, LR: 9.95e-05


Epoch 46/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0525, Success=37/37]


Epoch 46 completed. Average Loss: 0.2645, LR: 9.95e-05


Epoch 47/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.1603, Success=37/37]


Epoch 47 completed. Average Loss: 0.1487, LR: 9.95e-05


Epoch 48/600: 100%|████████████████████████████████████████| 37/37 [00:11<00:00,  3.35it/s, Loss=0.0184, Success=37/37]


Epoch 48 completed. Average Loss: 0.1031, LR: 9.94e-05


Epoch 49/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0310, Success=37/37]


Epoch 49 completed. Average Loss: 0.2587, LR: 9.94e-05


Epoch 50/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.2229, Success=37/37]


Epoch 50 completed. Average Loss: 0.1470, LR: 9.94e-05


Epoch 51/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0556, Success=37/37]


Epoch 51 completed. Average Loss: 0.1940, LR: 9.94e-05


Epoch 52/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0293, Success=37/37]


Epoch 52 completed. Average Loss: 0.2020, LR: 9.93e-05


Epoch 53/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0264, Success=37/37]


Epoch 53 completed. Average Loss: 0.1331, LR: 9.93e-05


Epoch 54/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.1333, Success=37/37]


Epoch 54 completed. Average Loss: 0.1306, LR: 9.93e-05


Epoch 55/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.1958, Success=37/37]


Epoch 55 completed. Average Loss: 0.1931, LR: 9.93e-05


Epoch 56/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0690, Success=37/37]


Epoch 56 completed. Average Loss: 0.1491, LR: 9.92e-05


Epoch 57/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0180, Success=37/37]


Epoch 57 completed. Average Loss: 0.1787, LR: 9.92e-05


Epoch 58/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=1.0095, Success=37/37]


Epoch 58 completed. Average Loss: 0.2032, LR: 9.92e-05


Epoch 59/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0268, Success=37/37]


Epoch 59 completed. Average Loss: 0.1596, LR: 9.92e-05


Epoch 60/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0162, Success=37/37]


Epoch 60 completed. Average Loss: 0.1936, LR: 9.91e-05


Epoch 61/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0902, Success=37/37]


Epoch 61 completed. Average Loss: 0.1154, LR: 9.91e-05


Epoch 62/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.9524, Success=37/37]


Epoch 62 completed. Average Loss: 0.1500, LR: 9.91e-05


Epoch 63/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0188, Success=37/37]


Epoch 63 completed. Average Loss: 0.2192, LR: 9.90e-05


Epoch 64/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0147, Success=37/37]


Epoch 64 completed. Average Loss: 0.1655, LR: 9.90e-05


Epoch 65/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0227, Success=37/37]


Epoch 65 completed. Average Loss: 0.1443, LR: 9.90e-05


Epoch 66/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.2711, Success=37/37]


Epoch 66 completed. Average Loss: 0.1311, LR: 9.89e-05


Epoch 67/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0632, Success=37/37]


Epoch 67 completed. Average Loss: 0.1546, LR: 9.89e-05


Epoch 68/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0365, Success=37/37]


Epoch 68 completed. Average Loss: 0.1507, LR: 9.89e-05


Epoch 69/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0137, Success=37/37]


Epoch 69 completed. Average Loss: 0.1228, LR: 9.88e-05


Epoch 70/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0216, Success=37/37]


Epoch 70 completed. Average Loss: 0.1090, LR: 9.88e-05


Epoch 71/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.5672, Success=37/37]


Epoch 71 completed. Average Loss: 0.1919, LR: 9.88e-05


Epoch 72/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.6859, Success=37/37]


Epoch 72 completed. Average Loss: 0.1259, LR: 9.87e-05


Epoch 73/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0117, Success=37/37]


Epoch 73 completed. Average Loss: 0.1228, LR: 9.87e-05


Epoch 74/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0442, Success=37/37]


Epoch 74 completed. Average Loss: 0.2421, LR: 9.87e-05


Epoch 75/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0138, Success=37/37]


Epoch 75 completed. Average Loss: 0.2013, LR: 9.86e-05


Epoch 76/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0132, Success=37/37]


Epoch 76 completed. Average Loss: 0.1835, LR: 9.86e-05


Epoch 77/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0134, Success=37/37]


Epoch 77 completed. Average Loss: 0.1470, LR: 9.86e-05


Epoch 78/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0320, Success=37/37]


Epoch 78 completed. Average Loss: 0.1350, LR: 9.85e-05


Epoch 79/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0299, Success=37/37]


Epoch 79 completed. Average Loss: 0.1541, LR: 9.85e-05


Epoch 80/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0394, Success=37/37]


Epoch 80 completed. Average Loss: 0.0971, LR: 9.84e-05


Epoch 81/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.3152, Success=37/37]


Epoch 81 completed. Average Loss: 0.2189, LR: 9.84e-05


Epoch 82/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0108, Success=37/37]


Epoch 82 completed. Average Loss: 0.2666, LR: 9.84e-05


Epoch 83/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0106, Success=37/37]


Epoch 83 completed. Average Loss: 0.1683, LR: 9.83e-05


Epoch 84/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.44it/s, Loss=0.0997, Success=37/37]


Epoch 84 completed. Average Loss: 0.1834, LR: 9.83e-05


Epoch 85/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.2304, Success=37/37]


Epoch 85 completed. Average Loss: 0.1076, LR: 9.82e-05


Epoch 86/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0115, Success=37/37]


Epoch 86 completed. Average Loss: 0.1232, LR: 9.82e-05


Epoch 87/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.7482, Success=37/37]


Epoch 87 completed. Average Loss: 0.1312, LR: 9.82e-05


Epoch 88/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0140, Success=37/37]


Epoch 88 completed. Average Loss: 0.1397, LR: 9.81e-05


Epoch 89/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0138, Success=37/37]


Epoch 89 completed. Average Loss: 0.1587, LR: 9.81e-05


Epoch 90/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0092, Success=37/37]


Epoch 90 completed. Average Loss: 0.0979, LR: 9.80e-05


Epoch 91/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0126, Success=37/37]


Epoch 91 completed. Average Loss: 0.1243, LR: 9.80e-05


Epoch 92/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0114, Success=37/37]


Epoch 92 completed. Average Loss: 0.2098, LR: 9.79e-05


Epoch 93/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0115, Success=37/37]


Epoch 93 completed. Average Loss: 0.1237, LR: 9.79e-05


Epoch 94/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0157, Success=37/37]


Epoch 94 completed. Average Loss: 0.1457, LR: 9.79e-05


Epoch 95/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0114, Success=37/37]


Epoch 95 completed. Average Loss: 0.1426, LR: 9.78e-05


Epoch 96/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0222, Success=37/37]


Epoch 96 completed. Average Loss: 0.1096, LR: 9.78e-05


Epoch 97/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0114, Success=37/37]


Epoch 97 completed. Average Loss: 0.1694, LR: 9.77e-05


Epoch 98/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0129, Success=37/37]


Epoch 98 completed. Average Loss: 0.1517, LR: 9.77e-05


Epoch 99/600: 100%|████████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0156, Success=37/37]


Epoch 99 completed. Average Loss: 0.1447, LR: 9.76e-05


Epoch 100/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0281, Success=37/37]


Epoch 100 completed. Average Loss: 0.1288, LR: 9.76e-05
Checkpoint saved at epoch 100


Epoch 101/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.35it/s, Loss=0.1226, Success=37/37]


Epoch 101 completed. Average Loss: 0.1160, LR: 9.75e-05


Epoch 102/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0126, Success=37/37]


Epoch 102 completed. Average Loss: 0.1093, LR: 9.75e-05


Epoch 103/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0100, Success=37/37]


Epoch 103 completed. Average Loss: 0.1420, LR: 9.74e-05


Epoch 104/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0102, Success=37/37]


Epoch 104 completed. Average Loss: 0.0778, LR: 9.74e-05


Epoch 105/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.1018, Success=37/37]


Epoch 105 completed. Average Loss: 0.0865, LR: 9.73e-05


Epoch 106/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.6552, Success=37/37]


Epoch 106 completed. Average Loss: 0.1516, LR: 9.73e-05


Epoch 107/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0133, Success=37/37]


Epoch 107 completed. Average Loss: 0.0954, LR: 9.72e-05


Epoch 108/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0113, Success=37/37]


Epoch 108 completed. Average Loss: 0.1042, LR: 9.72e-05


Epoch 109/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.4785, Success=37/37]


Epoch 109 completed. Average Loss: 0.1524, LR: 9.71e-05


Epoch 110/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.9280, Success=37/37]


Epoch 110 completed. Average Loss: 0.1523, LR: 9.71e-05


Epoch 111/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0130, Success=37/37]


Epoch 111 completed. Average Loss: 0.1441, LR: 9.70e-05


Epoch 112/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0104, Success=37/37]


Epoch 112 completed. Average Loss: 0.1129, LR: 9.70e-05


Epoch 113/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0313, Success=37/37]


Epoch 113 completed. Average Loss: 0.1058, LR: 9.69e-05


Epoch 114/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0493, Success=37/37]


Epoch 114 completed. Average Loss: 0.0714, LR: 9.69e-05


Epoch 115/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.8640, Success=37/37]


Epoch 115 completed. Average Loss: 0.1137, LR: 9.68e-05


Epoch 116/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0101, Success=37/37]


Epoch 116 completed. Average Loss: 0.0866, LR: 9.67e-05


Epoch 117/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.3204, Success=37/37]


Epoch 117 completed. Average Loss: 0.0557, LR: 9.67e-05


Epoch 118/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0170, Success=37/37]


Epoch 118 completed. Average Loss: 0.1502, LR: 9.66e-05


Epoch 119/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0134, Success=37/37]


Epoch 119 completed. Average Loss: 0.0925, LR: 9.66e-05


Epoch 120/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0076, Success=37/37]


Epoch 120 completed. Average Loss: 0.1035, LR: 9.65e-05


Epoch 121/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0231, Success=37/37]


Epoch 121 completed. Average Loss: 0.1253, LR: 9.65e-05


Epoch 122/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0233, Success=37/37]


Epoch 122 completed. Average Loss: 0.1026, LR: 9.64e-05


Epoch 123/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0087, Success=37/37]


Epoch 123 completed. Average Loss: 0.0486, LR: 9.64e-05


Epoch 124/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0118, Success=37/37]


Epoch 124 completed. Average Loss: 0.0758, LR: 9.63e-05


Epoch 125/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.1343, Success=37/37]


Epoch 125 completed. Average Loss: 0.0973, LR: 9.62e-05


Epoch 126/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0073, Success=37/37]


Epoch 126 completed. Average Loss: 0.0759, LR: 9.62e-05


Epoch 127/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0100, Success=37/37]


Epoch 127 completed. Average Loss: 0.1636, LR: 9.61e-05


Epoch 128/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0117, Success=37/37]


Epoch 128 completed. Average Loss: 0.1605, LR: 9.61e-05


Epoch 129/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0201, Success=37/37]


Epoch 129 completed. Average Loss: 0.0672, LR: 9.60e-05


Epoch 130/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0557, Success=37/37]


Epoch 130 completed. Average Loss: 0.1537, LR: 9.59e-05


Epoch 131/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0123, Success=37/37]


Epoch 131 completed. Average Loss: 0.1611, LR: 9.59e-05


Epoch 132/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0107, Success=37/37]


Epoch 132 completed. Average Loss: 0.1150, LR: 9.58e-05


Epoch 133/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0242, Success=37/37]


Epoch 133 completed. Average Loss: 0.0302, LR: 9.57e-05


Epoch 134/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0090, Success=37/37]


Epoch 134 completed. Average Loss: 0.1103, LR: 9.57e-05


Epoch 135/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0107, Success=37/37]


Epoch 135 completed. Average Loss: 0.1225, LR: 9.56e-05


Epoch 136/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0165, Success=37/37]


Epoch 136 completed. Average Loss: 0.1179, LR: 9.56e-05


Epoch 137/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0542, Success=37/37]


Epoch 137 completed. Average Loss: 0.0791, LR: 9.55e-05


Epoch 138/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0748, Success=37/37]


Epoch 138 completed. Average Loss: 0.0900, LR: 9.54e-05


Epoch 139/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0250, Success=37/37]


Epoch 139 completed. Average Loss: 0.0627, LR: 9.54e-05


Epoch 140/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0070, Success=37/37]


Epoch 140 completed. Average Loss: 0.0775, LR: 9.53e-05


Epoch 141/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0553, Success=37/37]


Epoch 141 completed. Average Loss: 0.0943, LR: 9.52e-05


Epoch 142/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0227, Success=37/37]


Epoch 142 completed. Average Loss: 0.0938, LR: 9.52e-05


Epoch 143/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0477, Success=37/37]


Epoch 143 completed. Average Loss: 0.0840, LR: 9.51e-05


Epoch 144/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0343, Success=37/37]


Epoch 144 completed. Average Loss: 0.1184, LR: 9.50e-05


Epoch 145/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0071, Success=37/37]


Epoch 145 completed. Average Loss: 0.1183, LR: 9.50e-05


Epoch 146/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0214, Success=37/37]


Epoch 146 completed. Average Loss: 0.1140, LR: 9.49e-05


Epoch 147/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0101, Success=37/37]


Epoch 147 completed. Average Loss: 0.1376, LR: 9.48e-05


Epoch 148/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0702, Success=37/37]


Epoch 148 completed. Average Loss: 0.1165, LR: 9.47e-05


Epoch 149/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0074, Success=37/37]


Epoch 149 completed. Average Loss: 0.0613, LR: 9.47e-05


Epoch 150/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.8594, Success=37/37]


Epoch 150 completed. Average Loss: 0.1582, LR: 9.46e-05


Epoch 151/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.5476, Success=37/37]


Epoch 151 completed. Average Loss: 0.1128, LR: 9.45e-05


Epoch 152/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.7066, Success=37/37]


Epoch 152 completed. Average Loss: 0.1067, LR: 9.45e-05


Epoch 153/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0175, Success=37/37]


Epoch 153 completed. Average Loss: 0.0727, LR: 9.44e-05


Epoch 154/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0153, Success=37/37]


Epoch 154 completed. Average Loss: 0.0599, LR: 9.43e-05


Epoch 155/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0235, Success=37/37]


Epoch 155 completed. Average Loss: 0.1088, LR: 9.42e-05


Epoch 156/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0107, Success=37/37]


Epoch 156 completed. Average Loss: 0.0771, LR: 9.42e-05


Epoch 157/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0077, Success=37/37]


Epoch 157 completed. Average Loss: 0.0864, LR: 9.41e-05


Epoch 158/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.8917, Success=37/37]


Epoch 158 completed. Average Loss: 0.1032, LR: 9.40e-05


Epoch 159/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0196, Success=37/37]


Epoch 159 completed. Average Loss: 0.1097, LR: 9.40e-05


Epoch 160/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0203, Success=37/37]


Epoch 160 completed. Average Loss: 0.0947, LR: 9.39e-05


Epoch 161/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0094, Success=37/37]


Epoch 161 completed. Average Loss: 0.0848, LR: 9.38e-05


Epoch 162/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0568, Success=37/37]


Epoch 162 completed. Average Loss: 0.1037, LR: 9.37e-05


Epoch 163/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.5551, Success=37/37]


Epoch 163 completed. Average Loss: 0.0927, LR: 9.37e-05


Epoch 164/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.44it/s, Loss=0.0060, Success=37/37]


Epoch 164 completed. Average Loss: 0.0300, LR: 9.36e-05


Epoch 165/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0655, Success=37/37]


Epoch 165 completed. Average Loss: 0.0842, LR: 9.35e-05


Epoch 166/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.31it/s, Loss=0.0075, Success=37/37]


Epoch 166 completed. Average Loss: 0.0828, LR: 9.34e-05


Epoch 167/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.23it/s, Loss=0.0109, Success=37/37]


Epoch 167 completed. Average Loss: 0.1143, LR: 9.33e-05


Epoch 168/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.33it/s, Loss=0.0189, Success=37/37]


Epoch 168 completed. Average Loss: 0.1250, LR: 9.33e-05


Epoch 169/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.32it/s, Loss=0.0302, Success=37/37]


Epoch 169 completed. Average Loss: 0.0495, LR: 9.32e-05


Epoch 170/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0399, Success=37/37]


Epoch 170 completed. Average Loss: 0.0463, LR: 9.31e-05


Epoch 171/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0094, Success=37/37]


Epoch 171 completed. Average Loss: 0.1325, LR: 9.30e-05


Epoch 172/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0069, Success=37/37]


Epoch 172 completed. Average Loss: 0.0547, LR: 9.29e-05


Epoch 173/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0067, Success=37/37]


Epoch 173 completed. Average Loss: 0.1445, LR: 9.29e-05


Epoch 174/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0091, Success=37/37]


Epoch 174 completed. Average Loss: 0.0732, LR: 9.28e-05


Epoch 175/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0089, Success=37/37]


Epoch 175 completed. Average Loss: 0.1130, LR: 9.27e-05


Epoch 176/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0321, Success=37/37]


Epoch 176 completed. Average Loss: 0.1271, LR: 9.26e-05


Epoch 177/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0507, Success=37/37]


Epoch 177 completed. Average Loss: 0.0503, LR: 9.25e-05


Epoch 178/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0067, Success=37/37]


Epoch 178 completed. Average Loss: 0.1205, LR: 9.25e-05


Epoch 179/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0281, Success=37/37]


Epoch 179 completed. Average Loss: 0.0504, LR: 9.24e-05


Epoch 180/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0150, Success=37/37]


Epoch 180 completed. Average Loss: 0.0882, LR: 9.23e-05


Epoch 181/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0122, Success=37/37]


Epoch 181 completed. Average Loss: 0.0623, LR: 9.22e-05


Epoch 182/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0084, Success=37/37]


Epoch 182 completed. Average Loss: 0.1014, LR: 9.21e-05


Epoch 183/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.1800, Success=37/37]


Epoch 183 completed. Average Loss: 0.1198, LR: 9.20e-05


Epoch 184/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0618, Success=37/37]


Epoch 184 completed. Average Loss: 0.1114, LR: 9.20e-05


Epoch 185/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0087, Success=37/37]


Epoch 185 completed. Average Loss: 0.1133, LR: 9.19e-05


Epoch 186/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0090, Success=37/37]


Epoch 186 completed. Average Loss: 0.0926, LR: 9.18e-05


Epoch 187/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0097, Success=37/37]


Epoch 187 completed. Average Loss: 0.1140, LR: 9.17e-05


Epoch 188/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0590, Success=37/37]


Epoch 188 completed. Average Loss: 0.0530, LR: 9.16e-05


Epoch 189/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.44it/s, Loss=0.0281, Success=37/37]


Epoch 189 completed. Average Loss: 0.0634, LR: 9.15e-05


Epoch 190/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0084, Success=37/37]


Epoch 190 completed. Average Loss: 0.0826, LR: 9.14e-05


Epoch 191/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.2838, Success=37/37]


Epoch 191 completed. Average Loss: 0.0782, LR: 9.14e-05


Epoch 192/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0201, Success=37/37]


Epoch 192 completed. Average Loss: 0.0682, LR: 9.13e-05


Epoch 193/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0607, Success=37/37]


Epoch 193 completed. Average Loss: 0.0579, LR: 9.12e-05


Epoch 194/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.1179, Success=37/37]


Epoch 194 completed. Average Loss: 0.0689, LR: 9.11e-05


Epoch 195/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0073, Success=37/37]


Epoch 195 completed. Average Loss: 0.0858, LR: 9.10e-05


Epoch 196/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0260, Success=37/37]


Epoch 196 completed. Average Loss: 0.0586, LR: 9.09e-05


Epoch 197/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0043, Success=37/37]


Epoch 197 completed. Average Loss: 0.0282, LR: 9.08e-05


Epoch 198/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0056, Success=37/37]


Epoch 198 completed. Average Loss: 0.0725, LR: 9.07e-05


Epoch 199/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0051, Success=37/37]


Epoch 199 completed. Average Loss: 0.1486, LR: 9.06e-05


Epoch 200/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.2128, Success=37/37]


Epoch 200 completed. Average Loss: 0.0883, LR: 9.05e-05
Checkpoint saved at epoch 200


Epoch 201/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0063, Success=37/37]


Epoch 201 completed. Average Loss: 0.0373, LR: 9.05e-05


Epoch 202/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0274, Success=37/37]


Epoch 202 completed. Average Loss: 0.1119, LR: 9.04e-05


Epoch 203/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0058, Success=37/37]


Epoch 203 completed. Average Loss: 0.0670, LR: 9.03e-05


Epoch 204/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0743, Success=37/37]


Epoch 204 completed. Average Loss: 0.0847, LR: 9.02e-05


Epoch 205/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0198, Success=37/37]


Epoch 205 completed. Average Loss: 0.1171, LR: 9.01e-05


Epoch 206/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0056, Success=37/37]


Epoch 206 completed. Average Loss: 0.0303, LR: 9.00e-05


Epoch 207/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0347, Success=37/37]


Epoch 207 completed. Average Loss: 0.0659, LR: 8.99e-05


Epoch 208/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0067, Success=37/37]


Epoch 208 completed. Average Loss: 0.0797, LR: 8.98e-05


Epoch 209/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0099, Success=37/37]


Epoch 209 completed. Average Loss: 0.1395, LR: 8.97e-05


Epoch 210/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0052, Success=37/37]


Epoch 210 completed. Average Loss: 0.0327, LR: 8.96e-05


Epoch 211/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0073, Success=37/37]


Epoch 211 completed. Average Loss: 0.1228, LR: 8.95e-05


Epoch 212/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0052, Success=37/37]


Epoch 212 completed. Average Loss: 0.0450, LR: 8.94e-05


Epoch 213/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0122, Success=37/37]


Epoch 213 completed. Average Loss: 0.1049, LR: 8.93e-05


Epoch 214/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0400, Success=37/37]


Epoch 214 completed. Average Loss: 0.0574, LR: 8.92e-05


Epoch 215/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.1487, Success=37/37]


Epoch 215 completed. Average Loss: 0.1218, LR: 8.91e-05


Epoch 216/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0097, Success=37/37]


Epoch 216 completed. Average Loss: 0.0674, LR: 8.90e-05


Epoch 217/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=1.0057, Success=37/37]


Epoch 217 completed. Average Loss: 0.0542, LR: 8.89e-05


Epoch 218/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0063, Success=37/37]


Epoch 218 completed. Average Loss: 0.0950, LR: 8.88e-05


Epoch 219/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0060, Success=37/37]


Epoch 219 completed. Average Loss: 0.0915, LR: 8.87e-05


Epoch 220/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0216, Success=37/37]


Epoch 220 completed. Average Loss: 0.1066, LR: 8.86e-05


Epoch 221/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.1272, Success=37/37]


Epoch 221 completed. Average Loss: 0.1106, LR: 8.85e-05


Epoch 222/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0058, Success=37/37]


Epoch 222 completed. Average Loss: 0.0512, LR: 8.84e-05


Epoch 223/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.35it/s, Loss=0.0074, Success=37/37]


Epoch 223 completed. Average Loss: 0.0550, LR: 8.83e-05


Epoch 224/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0170, Success=37/37]


Epoch 224 completed. Average Loss: 0.0528, LR: 8.82e-05


Epoch 225/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0092, Success=37/37]


Epoch 225 completed. Average Loss: 0.0851, LR: 8.81e-05


Epoch 226/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0056, Success=37/37]


Epoch 226 completed. Average Loss: 0.0438, LR: 8.80e-05


Epoch 227/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.33it/s, Loss=0.0606, Success=37/37]


Epoch 227 completed. Average Loss: 0.0597, LR: 8.79e-05


Epoch 228/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0057, Success=37/37]


Epoch 228 completed. Average Loss: 0.0618, LR: 8.78e-05


Epoch 229/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0712, Success=37/37]


Epoch 229 completed. Average Loss: 0.0358, LR: 8.77e-05


Epoch 230/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0065, Success=37/37]


Epoch 230 completed. Average Loss: 0.0541, LR: 8.76e-05


Epoch 231/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.5719, Success=37/37]


Epoch 231 completed. Average Loss: 0.0973, LR: 8.75e-05


Epoch 232/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0067, Success=37/37]


Epoch 232 completed. Average Loss: 0.0740, LR: 8.74e-05


Epoch 233/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0071, Success=37/37]


Epoch 233 completed. Average Loss: 0.0739, LR: 8.73e-05


Epoch 234/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.33it/s, Loss=0.0106, Success=37/37]


Epoch 234 completed. Average Loss: 0.0732, LR: 8.72e-05


Epoch 235/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.31it/s, Loss=0.0161, Success=37/37]


Epoch 235 completed. Average Loss: 0.0271, LR: 8.71e-05


Epoch 236/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0359, Success=37/37]


Epoch 236 completed. Average Loss: 0.0362, LR: 8.70e-05


Epoch 237/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0214, Success=37/37]


Epoch 237 completed. Average Loss: 0.0619, LR: 8.69e-05


Epoch 238/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0077, Success=37/37]


Epoch 238 completed. Average Loss: 0.1169, LR: 8.68e-05


Epoch 239/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0087, Success=37/37]


Epoch 239 completed. Average Loss: 0.0769, LR: 8.67e-05


Epoch 240/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0076, Success=37/37]


Epoch 240 completed. Average Loss: 0.0631, LR: 8.66e-05


Epoch 241/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0055, Success=37/37]


Epoch 241 completed. Average Loss: 0.1122, LR: 8.65e-05


Epoch 242/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0089, Success=37/37]


Epoch 242 completed. Average Loss: 0.0688, LR: 8.64e-05


Epoch 243/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0494, Success=37/37]


Epoch 243 completed. Average Loss: 0.0899, LR: 8.63e-05


Epoch 244/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0075, Success=37/37]


Epoch 244 completed. Average Loss: 0.1005, LR: 8.62e-05


Epoch 245/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0092, Success=37/37]


Epoch 245 completed. Average Loss: 0.0956, LR: 8.60e-05


Epoch 246/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0771, Success=37/37]


Epoch 246 completed. Average Loss: 0.0675, LR: 8.59e-05


Epoch 247/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0148, Success=37/37]


Epoch 247 completed. Average Loss: 0.0476, LR: 8.58e-05


Epoch 248/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0330, Success=37/37]


Epoch 248 completed. Average Loss: 0.0875, LR: 8.57e-05


Epoch 249/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0038, Success=37/37]


Epoch 249 completed. Average Loss: 0.0625, LR: 8.56e-05


Epoch 250/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0188, Success=37/37]


Epoch 250 completed. Average Loss: 0.0537, LR: 8.55e-05


Epoch 251/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.1762, Success=37/37]


Epoch 251 completed. Average Loss: 0.0392, LR: 8.54e-05


Epoch 252/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0078, Success=37/37]


Epoch 252 completed. Average Loss: 0.0754, LR: 8.53e-05


Epoch 253/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0083, Success=37/37]


Epoch 253 completed. Average Loss: 0.0783, LR: 8.52e-05


Epoch 254/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0158, Success=37/37]


Epoch 254 completed. Average Loss: 0.0613, LR: 8.51e-05


Epoch 255/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0272, Success=37/37]


Epoch 255 completed. Average Loss: 0.0969, LR: 8.49e-05


Epoch 256/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0049, Success=37/37]


Epoch 256 completed. Average Loss: 0.0660, LR: 8.48e-05


Epoch 257/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0047, Success=37/37]


Epoch 257 completed. Average Loss: 0.0479, LR: 8.47e-05


Epoch 258/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0057, Success=37/37]


Epoch 258 completed. Average Loss: 0.0709, LR: 8.46e-05


Epoch 259/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0046, Success=37/37]


Epoch 259 completed. Average Loss: 0.0292, LR: 8.45e-05


Epoch 260/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.34it/s, Loss=0.0041, Success=37/37]


Epoch 260 completed. Average Loss: 0.0444, LR: 8.44e-05


Epoch 261/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0132, Success=37/37]


Epoch 261 completed. Average Loss: 0.0629, LR: 8.43e-05


Epoch 262/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0108, Success=37/37]


Epoch 262 completed. Average Loss: 0.0600, LR: 8.42e-05


Epoch 263/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.2153, Success=37/37]


Epoch 263 completed. Average Loss: 0.0635, LR: 8.40e-05


Epoch 264/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0045, Success=37/37]


Epoch 264 completed. Average Loss: 0.1424, LR: 8.39e-05


Epoch 265/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0099, Success=37/37]


Epoch 265 completed. Average Loss: 0.0590, LR: 8.38e-05


Epoch 266/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.7690, Success=37/37]


Epoch 266 completed. Average Loss: 0.0880, LR: 8.37e-05


Epoch 267/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0332, Success=37/37]


Epoch 267 completed. Average Loss: 0.0235, LR: 8.36e-05


Epoch 268/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.35it/s, Loss=0.0106, Success=37/37]


Epoch 268 completed. Average Loss: 0.1124, LR: 8.35e-05


Epoch 269/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.1147, Success=37/37]


Epoch 269 completed. Average Loss: 0.0652, LR: 8.34e-05


Epoch 270/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.5643, Success=37/37]


Epoch 270 completed. Average Loss: 0.0434, LR: 8.32e-05


Epoch 271/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0269, Success=37/37]


Epoch 271 completed. Average Loss: 0.0198, LR: 8.31e-05


Epoch 272/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.35it/s, Loss=0.0104, Success=37/37]


Epoch 272 completed. Average Loss: 0.0330, LR: 8.30e-05


Epoch 273/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.7898, Success=37/37]


Epoch 273 completed. Average Loss: 0.1116, LR: 8.29e-05


Epoch 274/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0061, Success=37/37]


Epoch 274 completed. Average Loss: 0.0267, LR: 8.28e-05


Epoch 275/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0100, Success=37/37]


Epoch 275 completed. Average Loss: 0.0422, LR: 8.26e-05


Epoch 276/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0714, Success=37/37]


Epoch 276 completed. Average Loss: 0.0792, LR: 8.25e-05


Epoch 277/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.2968, Success=37/37]


Epoch 277 completed. Average Loss: 0.0566, LR: 8.24e-05


Epoch 278/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0080, Success=37/37]


Epoch 278 completed. Average Loss: 0.0352, LR: 8.23e-05


Epoch 279/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0089, Success=37/37]


Epoch 279 completed. Average Loss: 0.0532, LR: 8.22e-05


Epoch 280/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.1083, Success=37/37]


Epoch 280 completed. Average Loss: 0.0265, LR: 8.21e-05


Epoch 281/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0031, Success=37/37]


Epoch 281 completed. Average Loss: 0.0499, LR: 8.19e-05


Epoch 282/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0057, Success=37/37]


Epoch 282 completed. Average Loss: 0.0852, LR: 8.18e-05


Epoch 283/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0098, Success=37/37]


Epoch 283 completed. Average Loss: 0.0896, LR: 8.17e-05


Epoch 284/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0079, Success=37/37]


Epoch 284 completed. Average Loss: 0.0274, LR: 8.16e-05


Epoch 285/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0047, Success=37/37]


Epoch 285 completed. Average Loss: 0.0366, LR: 8.14e-05


Epoch 286/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0120, Success=37/37]


Epoch 286 completed. Average Loss: 0.0618, LR: 8.13e-05


Epoch 287/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0055, Success=37/37]


Epoch 287 completed. Average Loss: 0.0251, LR: 8.12e-05


Epoch 288/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0045, Success=37/37]


Epoch 288 completed. Average Loss: 0.0473, LR: 8.11e-05


Epoch 289/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0060, Success=37/37]


Epoch 289 completed. Average Loss: 0.0307, LR: 8.10e-05


Epoch 290/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0085, Success=37/37]


Epoch 290 completed. Average Loss: 0.0676, LR: 8.08e-05


Epoch 291/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0141, Success=37/37]


Epoch 291 completed. Average Loss: 0.0178, LR: 8.07e-05


Epoch 292/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0172, Success=37/37]


Epoch 292 completed. Average Loss: 0.0770, LR: 8.06e-05


Epoch 293/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0075, Success=37/37]


Epoch 293 completed. Average Loss: 0.0371, LR: 8.05e-05


Epoch 294/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0051, Success=37/37]


Epoch 294 completed. Average Loss: 0.0407, LR: 8.03e-05


Epoch 295/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0036, Success=37/37]


Epoch 295 completed. Average Loss: 0.0213, LR: 8.02e-05


Epoch 296/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.1081, Success=37/37]


Epoch 296 completed. Average Loss: 0.0561, LR: 8.01e-05


Epoch 297/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0055, Success=37/37]


Epoch 297 completed. Average Loss: 0.0424, LR: 8.00e-05


Epoch 298/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0089, Success=37/37]


Epoch 298 completed. Average Loss: 0.0419, LR: 7.98e-05


Epoch 299/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0051, Success=37/37]


Epoch 299 completed. Average Loss: 0.0481, LR: 7.97e-05


Epoch 300/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0049, Success=37/37]


Epoch 300 completed. Average Loss: 0.1246, LR: 7.96e-05
Checkpoint saved at epoch 300


Epoch 301/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.33it/s, Loss=0.0045, Success=37/37]


Epoch 301 completed. Average Loss: 0.0377, LR: 7.95e-05


Epoch 302/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0052, Success=37/37]


Epoch 302 completed. Average Loss: 0.0765, LR: 7.93e-05


Epoch 303/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=1.0057, Success=37/37]


Epoch 303 completed. Average Loss: 0.0877, LR: 7.92e-05


Epoch 304/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0081, Success=37/37]


Epoch 304 completed. Average Loss: 0.0689, LR: 7.91e-05


Epoch 305/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0131, Success=37/37]


Epoch 305 completed. Average Loss: 0.0461, LR: 7.90e-05


Epoch 306/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.2394, Success=37/37]


Epoch 306 completed. Average Loss: 0.1270, LR: 7.88e-05


Epoch 307/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0039, Success=37/37]


Epoch 307 completed. Average Loss: 0.0174, LR: 7.87e-05


Epoch 308/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0052, Success=37/37]


Epoch 308 completed. Average Loss: 0.0883, LR: 7.86e-05


Epoch 309/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.4544, Success=37/37]


Epoch 309 completed. Average Loss: 0.0529, LR: 7.85e-05


Epoch 310/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0064, Success=37/37]


Epoch 310 completed. Average Loss: 0.1252, LR: 7.83e-05


Epoch 311/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.2585, Success=37/37]


Epoch 311 completed. Average Loss: 0.0582, LR: 7.82e-05


Epoch 312/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0059, Success=37/37]


Epoch 312 completed. Average Loss: 0.0223, LR: 7.81e-05


Epoch 313/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.1140, Success=37/37]


Epoch 313 completed. Average Loss: 0.1004, LR: 7.79e-05


Epoch 314/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0276, Success=37/37]


Epoch 314 completed. Average Loss: 0.0456, LR: 7.78e-05


Epoch 315/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0063, Success=37/37]


Epoch 315 completed. Average Loss: 0.0507, LR: 7.77e-05


Epoch 316/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0082, Success=37/37]


Epoch 316 completed. Average Loss: 0.0502, LR: 7.75e-05


Epoch 317/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0044, Success=37/37]


Epoch 317 completed. Average Loss: 0.0525, LR: 7.74e-05


Epoch 318/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0291, Success=37/37]


Epoch 318 completed. Average Loss: 0.0384, LR: 7.73e-05


Epoch 319/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0473, Success=37/37]


Epoch 319 completed. Average Loss: 0.0313, LR: 7.72e-05


Epoch 320/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0300, Success=37/37]


Epoch 320 completed. Average Loss: 0.0491, LR: 7.70e-05


Epoch 321/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0028, Success=37/37]


Epoch 321 completed. Average Loss: 0.0311, LR: 7.69e-05


Epoch 322/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0080, Success=37/37]


Epoch 322 completed. Average Loss: 0.0889, LR: 7.68e-05


Epoch 323/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0140, Success=37/37]


Epoch 323 completed. Average Loss: 0.0372, LR: 7.66e-05


Epoch 324/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0059, Success=37/37]


Epoch 324 completed. Average Loss: 0.0328, LR: 7.65e-05


Epoch 325/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0035, Success=37/37]


Epoch 325 completed. Average Loss: 0.0606, LR: 7.64e-05


Epoch 326/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0200, Success=37/37]


Epoch 326 completed. Average Loss: 0.0272, LR: 7.62e-05


Epoch 327/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0063, Success=37/37]


Epoch 327 completed. Average Loss: 0.0774, LR: 7.61e-05


Epoch 328/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0045, Success=37/37]


Epoch 328 completed. Average Loss: 0.0343, LR: 7.60e-05


Epoch 329/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0107, Success=37/37]


Epoch 329 completed. Average Loss: 0.0720, LR: 7.58e-05


Epoch 330/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0074, Success=37/37]


Epoch 330 completed. Average Loss: 0.0198, LR: 7.57e-05


Epoch 331/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0040, Success=37/37]


Epoch 331 completed. Average Loss: 0.0711, LR: 7.56e-05


Epoch 332/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0045, Success=37/37]


Epoch 332 completed. Average Loss: 0.0513, LR: 7.54e-05


Epoch 333/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.3118, Success=37/37]


Epoch 333 completed. Average Loss: 0.0686, LR: 7.53e-05


Epoch 334/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0040, Success=37/37]


Epoch 334 completed. Average Loss: 0.0479, LR: 7.52e-05


Epoch 335/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.35it/s, Loss=0.0113, Success=37/37]


Epoch 335 completed. Average Loss: 0.0241, LR: 7.50e-05


Epoch 336/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0055, Success=37/37]


Epoch 336 completed. Average Loss: 0.0206, LR: 7.49e-05


Epoch 337/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.7270, Success=37/37]


Epoch 337 completed. Average Loss: 0.0600, LR: 7.48e-05


Epoch 338/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0040, Success=37/37]


Epoch 338 completed. Average Loss: 0.0245, LR: 7.46e-05


Epoch 339/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0040, Success=37/37]


Epoch 339 completed. Average Loss: 0.0887, LR: 7.45e-05


Epoch 340/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.1146, Success=37/37]


Epoch 340 completed. Average Loss: 0.0322, LR: 7.43e-05


Epoch 341/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0090, Success=37/37]


Epoch 341 completed. Average Loss: 0.0639, LR: 7.42e-05


Epoch 342/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0097, Success=37/37]


Epoch 342 completed. Average Loss: 0.1197, LR: 7.41e-05


Epoch 343/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0062, Success=37/37]


Epoch 343 completed. Average Loss: 0.0623, LR: 7.39e-05


Epoch 344/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0034, Success=37/37]


Epoch 344 completed. Average Loss: 0.1196, LR: 7.38e-05


Epoch 345/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0046, Success=37/37]


Epoch 345 completed. Average Loss: 0.0924, LR: 7.37e-05


Epoch 346/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0131, Success=37/37]


Epoch 346 completed. Average Loss: 0.0345, LR: 7.35e-05


Epoch 347/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0043, Success=37/37]


Epoch 347 completed. Average Loss: 0.0406, LR: 7.34e-05


Epoch 348/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0046, Success=37/37]


Epoch 348 completed. Average Loss: 0.0433, LR: 7.32e-05


Epoch 349/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.30it/s, Loss=0.0106, Success=37/37]


Epoch 349 completed. Average Loss: 0.0649, LR: 7.31e-05


Epoch 350/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0079, Success=37/37]


Epoch 350 completed. Average Loss: 0.0516, LR: 7.30e-05


Epoch 351/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0073, Success=37/37]


Epoch 351 completed. Average Loss: 0.0667, LR: 7.28e-05


Epoch 352/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0066, Success=37/37]


Epoch 352 completed. Average Loss: 0.0881, LR: 7.27e-05


Epoch 353/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0060, Success=37/37]


Epoch 353 completed. Average Loss: 0.0766, LR: 7.26e-05


Epoch 354/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0043, Success=37/37]


Epoch 354 completed. Average Loss: 0.0792, LR: 7.24e-05


Epoch 355/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0233, Success=37/37]


Epoch 355 completed. Average Loss: 0.0347, LR: 7.23e-05


Epoch 356/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0030, Success=37/37]


Epoch 356 completed. Average Loss: 0.0560, LR: 7.21e-05


Epoch 357/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0682, Success=37/37]


Epoch 357 completed. Average Loss: 0.0106, LR: 7.20e-05


Epoch 358/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0065, Success=37/37]


Epoch 358 completed. Average Loss: 0.0725, LR: 7.19e-05


Epoch 359/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0074, Success=37/37]


Epoch 359 completed. Average Loss: 0.0695, LR: 7.17e-05


Epoch 360/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0036, Success=37/37]


Epoch 360 completed. Average Loss: 0.0499, LR: 7.16e-05


Epoch 361/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0779, Success=37/37]


Epoch 361 completed. Average Loss: 0.0347, LR: 7.14e-05


Epoch 362/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0151, Success=37/37]


Epoch 362 completed. Average Loss: 0.0411, LR: 7.13e-05


Epoch 363/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.4795, Success=37/37]


Epoch 363 completed. Average Loss: 0.0543, LR: 7.12e-05


Epoch 364/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0896, Success=37/37]


Epoch 364 completed. Average Loss: 0.0565, LR: 7.10e-05


Epoch 365/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.9818, Success=37/37]


Epoch 365 completed. Average Loss: 0.1011, LR: 7.09e-05


Epoch 366/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.6808, Success=37/37]


Epoch 366 completed. Average Loss: 0.0407, LR: 7.07e-05


Epoch 367/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0176, Success=37/37]


Epoch 367 completed. Average Loss: 0.0471, LR: 7.06e-05


Epoch 368/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0050, Success=37/37]


Epoch 368 completed. Average Loss: 0.0176, LR: 7.04e-05


Epoch 369/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0914, Success=37/37]


Epoch 369 completed. Average Loss: 0.0464, LR: 7.03e-05


Epoch 370/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.1428, Success=37/37]


Epoch 370 completed. Average Loss: 0.0388, LR: 7.02e-05


Epoch 371/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0213, Success=37/37]


Epoch 371 completed. Average Loss: 0.0282, LR: 7.00e-05


Epoch 372/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0029, Success=37/37]


Epoch 372 completed. Average Loss: 0.0759, LR: 6.99e-05


Epoch 373/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0103, Success=37/37]


Epoch 373 completed. Average Loss: 0.0450, LR: 6.97e-05


Epoch 374/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0024, Success=37/37]


Epoch 374 completed. Average Loss: 0.0207, LR: 6.96e-05


Epoch 375/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0125, Success=37/37]


Epoch 375 completed. Average Loss: 0.0335, LR: 6.94e-05


Epoch 376/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.4076, Success=37/37]


Epoch 376 completed. Average Loss: 0.0796, LR: 6.93e-05


Epoch 377/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0059, Success=37/37]


Epoch 377 completed. Average Loss: 0.0560, LR: 6.92e-05


Epoch 378/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0071, Success=37/37]


Epoch 378 completed. Average Loss: 0.0283, LR: 6.90e-05


Epoch 379/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0305, Success=37/37]


Epoch 379 completed. Average Loss: 0.0435, LR: 6.89e-05


Epoch 380/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0033, Success=37/37]


Epoch 380 completed. Average Loss: 0.0219, LR: 6.87e-05


Epoch 381/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0039, Success=37/37]


Epoch 381 completed. Average Loss: 0.0224, LR: 6.86e-05


Epoch 382/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.27it/s, Loss=0.2972, Success=37/37]


Epoch 382 completed. Average Loss: 0.0436, LR: 6.84e-05


Epoch 383/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.30it/s, Loss=0.0033, Success=37/37]


Epoch 383 completed. Average Loss: 0.0569, LR: 6.83e-05


Epoch 384/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.32it/s, Loss=0.0053, Success=37/37]


Epoch 384 completed. Average Loss: 0.0973, LR: 6.81e-05


Epoch 385/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0040, Success=37/37]


Epoch 385 completed. Average Loss: 0.0550, LR: 6.80e-05


Epoch 386/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0303, Success=37/37]


Epoch 386 completed. Average Loss: 0.0894, LR: 6.79e-05


Epoch 387/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0208, Success=37/37]


Epoch 387 completed. Average Loss: 0.0236, LR: 6.77e-05


Epoch 388/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0046, Success=37/37]


Epoch 388 completed. Average Loss: 0.0326, LR: 6.76e-05


Epoch 389/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0183, Success=37/37]


Epoch 389 completed. Average Loss: 0.0152, LR: 6.74e-05


Epoch 390/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0037, Success=37/37]


Epoch 390 completed. Average Loss: 0.0562, LR: 6.73e-05


Epoch 391/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0169, Success=37/37]


Epoch 391 completed. Average Loss: 0.0898, LR: 6.71e-05


Epoch 392/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0066, Success=37/37]


Epoch 392 completed. Average Loss: 0.0779, LR: 6.70e-05


Epoch 393/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0036, Success=37/37]


Epoch 393 completed. Average Loss: 0.0701, LR: 6.68e-05


Epoch 394/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0598, Success=37/37]


Epoch 394 completed. Average Loss: 0.0651, LR: 6.67e-05


Epoch 395/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0208, Success=37/37]


Epoch 395 completed. Average Loss: 0.0905, LR: 6.65e-05


Epoch 396/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0025, Success=37/37]


Epoch 396 completed. Average Loss: 0.0195, LR: 6.64e-05


Epoch 397/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.4575, Success=37/37]


Epoch 397 completed. Average Loss: 0.0539, LR: 6.62e-05


Epoch 398/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0032, Success=37/37]


Epoch 398 completed. Average Loss: 0.0453, LR: 6.61e-05


Epoch 399/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0030, Success=37/37]


Epoch 399 completed. Average Loss: 0.0325, LR: 6.59e-05


Epoch 400/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0152, Success=37/37]


Epoch 400 completed. Average Loss: 0.0233, LR: 6.58e-05
Checkpoint saved at epoch 400


Epoch 401/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.32it/s, Loss=0.9816, Success=37/37]


Epoch 401 completed. Average Loss: 0.0534, LR: 6.56e-05


Epoch 402/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0081, Success=37/37]


Epoch 402 completed. Average Loss: 0.0313, LR: 6.55e-05


Epoch 403/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.35it/s, Loss=0.0409, Success=37/37]


Epoch 403 completed. Average Loss: 0.0585, LR: 6.54e-05


Epoch 404/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0040, Success=37/37]


Epoch 404 completed. Average Loss: 0.0149, LR: 6.52e-05


Epoch 405/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0083, Success=37/37]


Epoch 405 completed. Average Loss: 0.0687, LR: 6.51e-05


Epoch 406/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0160, Success=37/37]


Epoch 406 completed. Average Loss: 0.0481, LR: 6.49e-05


Epoch 407/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0367, Success=37/37]


Epoch 407 completed. Average Loss: 0.1263, LR: 6.48e-05


Epoch 408/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0042, Success=37/37]


Epoch 408 completed. Average Loss: 0.0677, LR: 6.46e-05


Epoch 409/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0054, Success=37/37]


Epoch 409 completed. Average Loss: 0.0752, LR: 6.45e-05


Epoch 410/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.2330, Success=37/37]


Epoch 410 completed. Average Loss: 0.0840, LR: 6.43e-05


Epoch 411/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0139, Success=37/37]


Epoch 411 completed. Average Loss: 0.0273, LR: 6.42e-05


Epoch 412/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0890, Success=37/37]


Epoch 412 completed. Average Loss: 0.0265, LR: 6.40e-05


Epoch 413/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0050, Success=37/37]


Epoch 413 completed. Average Loss: 0.0287, LR: 6.39e-05


Epoch 414/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0182, Success=37/37]


Epoch 414 completed. Average Loss: 0.0585, LR: 6.37e-05


Epoch 415/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.5013, Success=37/37]


Epoch 415 completed. Average Loss: 0.0319, LR: 6.36e-05


Epoch 416/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0146, Success=37/37]


Epoch 416 completed. Average Loss: 0.0302, LR: 6.34e-05


Epoch 417/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0036, Success=37/37]


Epoch 417 completed. Average Loss: 0.0234, LR: 6.33e-05


Epoch 418/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0071, Success=37/37]


Epoch 418 completed. Average Loss: 0.0420, LR: 6.31e-05


Epoch 419/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0032, Success=37/37]


Epoch 419 completed. Average Loss: 0.0250, LR: 6.30e-05


Epoch 420/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0187, Success=37/37]


Epoch 420 completed. Average Loss: 0.0712, LR: 6.28e-05


Epoch 421/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0049, Success=37/37]


Epoch 421 completed. Average Loss: 0.0347, LR: 6.27e-05


Epoch 422/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0246, Success=37/37]


Epoch 422 completed. Average Loss: 0.0390, LR: 6.25e-05


Epoch 423/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0037, Success=37/37]


Epoch 423 completed. Average Loss: 0.0263, LR: 6.24e-05


Epoch 424/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0082, Success=37/37]


Epoch 424 completed. Average Loss: 0.0681, LR: 6.22e-05


Epoch 425/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0032, Success=37/37]


Epoch 425 completed. Average Loss: 0.0304, LR: 6.21e-05


Epoch 426/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0062, Success=37/37]


Epoch 426 completed. Average Loss: 0.0229, LR: 6.19e-05


Epoch 427/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.44it/s, Loss=0.0044, Success=37/37]


Epoch 427 completed. Average Loss: 0.0787, LR: 6.18e-05


Epoch 428/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0040, Success=37/37]


Epoch 428 completed. Average Loss: 0.0235, LR: 6.16e-05


Epoch 429/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0057, Success=37/37]


Epoch 429 completed. Average Loss: 0.0371, LR: 6.14e-05


Epoch 430/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.7692, Success=37/37]


Epoch 430 completed. Average Loss: 0.0537, LR: 6.13e-05


Epoch 431/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0059, Success=37/37]


Epoch 431 completed. Average Loss: 0.0552, LR: 6.11e-05


Epoch 432/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0044, Success=37/37]


Epoch 432 completed. Average Loss: 0.0661, LR: 6.10e-05


Epoch 433/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0134, Success=37/37]


Epoch 433 completed. Average Loss: 0.0389, LR: 6.08e-05


Epoch 434/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0238, Success=37/37]


Epoch 434 completed. Average Loss: 0.0122, LR: 6.07e-05


Epoch 435/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0067, Success=37/37]


Epoch 435 completed. Average Loss: 0.0255, LR: 6.05e-05


Epoch 436/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0041, Success=37/37]


Epoch 436 completed. Average Loss: 0.0496, LR: 6.04e-05


Epoch 437/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0041, Success=37/37]


Epoch 437 completed. Average Loss: 0.0538, LR: 6.02e-05


Epoch 438/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0076, Success=37/37]


Epoch 438 completed. Average Loss: 0.0646, LR: 6.01e-05


Epoch 439/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0030, Success=37/37]


Epoch 439 completed. Average Loss: 0.0308, LR: 5.99e-05


Epoch 440/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0122, Success=37/37]


Epoch 440 completed. Average Loss: 0.0241, LR: 5.98e-05


Epoch 441/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0030, Success=37/37]


Epoch 441 completed. Average Loss: 0.0380, LR: 5.96e-05


Epoch 442/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.7243, Success=37/37]


Epoch 442 completed. Average Loss: 0.0550, LR: 5.95e-05


Epoch 443/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.3909, Success=37/37]


Epoch 443 completed. Average Loss: 0.0448, LR: 5.93e-05


Epoch 444/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0029, Success=37/37]


Epoch 444 completed. Average Loss: 0.0432, LR: 5.92e-05


Epoch 445/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0060, Success=37/37]


Epoch 445 completed. Average Loss: 0.0358, LR: 5.90e-05


Epoch 446/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0026, Success=37/37]


Epoch 446 completed. Average Loss: 0.0262, LR: 5.89e-05


Epoch 447/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.3185, Success=37/37]


Epoch 447 completed. Average Loss: 0.0273, LR: 5.87e-05


Epoch 448/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0045, Success=37/37]


Epoch 448 completed. Average Loss: 0.0352, LR: 5.86e-05


Epoch 449/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0026, Success=37/37]


Epoch 449 completed. Average Loss: 0.0561, LR: 5.84e-05


Epoch 450/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0396, Success=37/37]


Epoch 450 completed. Average Loss: 0.0586, LR: 5.82e-05


Epoch 451/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0054, Success=37/37]


Epoch 451 completed. Average Loss: 0.0442, LR: 5.81e-05


Epoch 452/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0023, Success=37/37]


Epoch 452 completed. Average Loss: 0.0397, LR: 5.79e-05


Epoch 453/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0030, Success=37/37]


Epoch 453 completed. Average Loss: 0.0797, LR: 5.78e-05


Epoch 454/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0026, Success=37/37]


Epoch 454 completed. Average Loss: 0.0298, LR: 5.76e-05


Epoch 455/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0040, Success=37/37]


Epoch 455 completed. Average Loss: 0.0769, LR: 5.75e-05


Epoch 456/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0025, Success=37/37]


Epoch 456 completed. Average Loss: 0.0299, LR: 5.73e-05


Epoch 457/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0037, Success=37/37]


Epoch 457 completed. Average Loss: 0.0426, LR: 5.72e-05


Epoch 458/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0261, Success=37/37]


Epoch 458 completed. Average Loss: 0.0137, LR: 5.70e-05


Epoch 459/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.2033, Success=37/37]


Epoch 459 completed. Average Loss: 0.0257, LR: 5.69e-05


Epoch 460/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0044, Success=37/37]


Epoch 460 completed. Average Loss: 0.0131, LR: 5.67e-05


Epoch 461/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0337, Success=37/37]


Epoch 461 completed. Average Loss: 0.0417, LR: 5.65e-05


Epoch 462/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0050, Success=37/37]


Epoch 462 completed. Average Loss: 0.0448, LR: 5.64e-05


Epoch 463/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0076, Success=37/37]


Epoch 463 completed. Average Loss: 0.0826, LR: 5.62e-05


Epoch 464/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0046, Success=37/37]


Epoch 464 completed. Average Loss: 0.0646, LR: 5.61e-05


Epoch 465/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.20it/s, Loss=0.0336, Success=37/37]


Epoch 465 completed. Average Loss: 0.0450, LR: 5.59e-05


Epoch 466/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0036, Success=37/37]


Epoch 466 completed. Average Loss: 0.0324, LR: 5.58e-05


Epoch 467/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.34it/s, Loss=0.0074, Success=37/37]


Epoch 467 completed. Average Loss: 0.0168, LR: 5.56e-05


Epoch 468/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0191, Success=37/37]


Epoch 468 completed. Average Loss: 0.0334, LR: 5.55e-05


Epoch 469/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0092, Success=37/37]


Epoch 469 completed. Average Loss: 0.0119, LR: 5.53e-05


Epoch 470/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0037, Success=37/37]


Epoch 470 completed. Average Loss: 0.0397, LR: 5.52e-05


Epoch 471/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0026, Success=37/37]


Epoch 471 completed. Average Loss: 0.0507, LR: 5.50e-05


Epoch 472/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0052, Success=37/37]


Epoch 472 completed. Average Loss: 0.1528, LR: 5.48e-05


Epoch 473/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0030, Success=37/37]


Epoch 473 completed. Average Loss: 0.0459, LR: 5.47e-05


Epoch 474/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0077, Success=37/37]


Epoch 474 completed. Average Loss: 0.0260, LR: 5.45e-05


Epoch 475/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0045, Success=37/37]


Epoch 475 completed. Average Loss: 0.0457, LR: 5.44e-05


Epoch 476/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0039, Success=37/37]


Epoch 476 completed. Average Loss: 0.0491, LR: 5.42e-05


Epoch 477/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0057, Success=37/37]


Epoch 477 completed. Average Loss: 0.0228, LR: 5.41e-05


Epoch 478/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0037, Success=37/37]


Epoch 478 completed. Average Loss: 0.0157, LR: 5.39e-05


Epoch 479/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.43it/s, Loss=0.0196, Success=37/37]


Epoch 479 completed. Average Loss: 0.0198, LR: 5.38e-05


Epoch 480/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0028, Success=37/37]


Epoch 480 completed. Average Loss: 0.0327, LR: 5.36e-05


Epoch 481/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0020, Success=37/37]


Epoch 481 completed. Average Loss: 0.0252, LR: 5.35e-05


Epoch 482/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0033, Success=37/37]


Epoch 482 completed. Average Loss: 0.0608, LR: 5.33e-05


Epoch 483/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0581, Success=37/37]


Epoch 483 completed. Average Loss: 0.0190, LR: 5.31e-05


Epoch 484/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.18it/s, Loss=0.1739, Success=37/37]


Epoch 484 completed. Average Loss: 0.0422, LR: 5.30e-05


Epoch 485/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0042, Success=37/37]


Epoch 485 completed. Average Loss: 0.0304, LR: 5.28e-05


Epoch 486/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0085, Success=37/37]


Epoch 486 completed. Average Loss: 0.0064, LR: 5.27e-05


Epoch 487/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.32it/s, Loss=0.0020, Success=37/37]


Epoch 487 completed. Average Loss: 0.0267, LR: 5.25e-05


Epoch 488/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0025, Success=37/37]


Epoch 488 completed. Average Loss: 0.0614, LR: 5.24e-05


Epoch 489/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0040, Success=37/37]


Epoch 489 completed. Average Loss: 0.0387, LR: 5.22e-05


Epoch 490/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0032, Success=37/37]


Epoch 490 completed. Average Loss: 0.0635, LR: 5.21e-05


Epoch 491/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0090, Success=37/37]


Epoch 491 completed. Average Loss: 0.0234, LR: 5.19e-05


Epoch 492/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0022, Success=37/37]


Epoch 492 completed. Average Loss: 0.0086, LR: 5.17e-05


Epoch 493/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=1.0150, Success=37/37]


Epoch 493 completed. Average Loss: 0.0511, LR: 5.16e-05


Epoch 494/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0026, Success=37/37]


Epoch 494 completed. Average Loss: 0.0318, LR: 5.14e-05


Epoch 495/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0022, Success=37/37]


Epoch 495 completed. Average Loss: 0.0930, LR: 5.13e-05


Epoch 496/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0140, Success=37/37]


Epoch 496 completed. Average Loss: 0.0741, LR: 5.11e-05


Epoch 497/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0027, Success=37/37]


Epoch 497 completed. Average Loss: 0.0165, LR: 5.10e-05


Epoch 498/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0188, Success=37/37]


Epoch 498 completed. Average Loss: 0.0213, LR: 5.08e-05


Epoch 499/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0206, Success=37/37]


Epoch 499 completed. Average Loss: 0.0295, LR: 5.07e-05


Epoch 500/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0217, Success=37/37]


Epoch 500 completed. Average Loss: 0.0386, LR: 5.05e-05
Checkpoint saved at epoch 500


Epoch 501/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.32it/s, Loss=0.0017, Success=37/37]


Epoch 501 completed. Average Loss: 0.0138, LR: 5.03e-05


Epoch 502/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0038, Success=37/37]


Epoch 502 completed. Average Loss: 0.0528, LR: 5.02e-05


Epoch 503/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0021, Success=37/37]


Epoch 503 completed. Average Loss: 0.0332, LR: 5.00e-05


Epoch 504/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0276, Success=37/37]


Epoch 504 completed. Average Loss: 0.0200, LR: 4.99e-05


Epoch 505/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0268, Success=37/37]


Epoch 505 completed. Average Loss: 0.0656, LR: 4.97e-05


Epoch 506/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0019, Success=37/37]


Epoch 506 completed. Average Loss: 0.0187, LR: 4.96e-05


Epoch 507/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0018, Success=37/37]


Epoch 507 completed. Average Loss: 0.0190, LR: 4.94e-05


Epoch 508/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0036, Success=37/37]


Epoch 508 completed. Average Loss: 0.0393, LR: 4.93e-05


Epoch 509/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0066, Success=37/37]


Epoch 509 completed. Average Loss: 0.0221, LR: 4.91e-05


Epoch 510/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0029, Success=37/37]


Epoch 510 completed. Average Loss: 0.0210, LR: 4.89e-05


Epoch 511/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0112, Success=37/37]


Epoch 511 completed. Average Loss: 0.0252, LR: 4.88e-05


Epoch 512/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0151, Success=37/37]


Epoch 512 completed. Average Loss: 0.0074, LR: 4.86e-05


Epoch 513/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0037, Success=37/37]


Epoch 513 completed. Average Loss: 0.0453, LR: 4.85e-05


Epoch 514/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.1327, Success=37/37]


Epoch 514 completed. Average Loss: 0.0530, LR: 4.83e-05


Epoch 515/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0048, Success=37/37]


Epoch 515 completed. Average Loss: 0.0390, LR: 4.82e-05


Epoch 516/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0036, Success=37/37]


Epoch 516 completed. Average Loss: 0.0375, LR: 4.80e-05


Epoch 517/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.3531, Success=37/37]


Epoch 517 completed. Average Loss: 0.0303, LR: 4.79e-05


Epoch 518/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0023, Success=37/37]


Epoch 518 completed. Average Loss: 0.0316, LR: 4.77e-05


Epoch 519/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0027, Success=37/37]


Epoch 519 completed. Average Loss: 0.0601, LR: 4.75e-05


Epoch 520/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0034, Success=37/37]


Epoch 520 completed. Average Loss: 0.0397, LR: 4.74e-05


Epoch 521/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.5513, Success=37/37]


Epoch 521 completed. Average Loss: 0.0428, LR: 4.72e-05


Epoch 522/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0030, Success=37/37]


Epoch 522 completed. Average Loss: 0.0340, LR: 4.71e-05


Epoch 523/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0026, Success=37/37]


Epoch 523 completed. Average Loss: 0.0454, LR: 4.69e-05


Epoch 524/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0018, Success=37/37]


Epoch 524 completed. Average Loss: 0.0302, LR: 4.68e-05


Epoch 525/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.36it/s, Loss=0.0021, Success=37/37]


Epoch 525 completed. Average Loss: 0.0395, LR: 4.66e-05


Epoch 526/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0121, Success=37/37]


Epoch 526 completed. Average Loss: 0.0647, LR: 4.65e-05


Epoch 527/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0076, Success=37/37]


Epoch 527 completed. Average Loss: 0.0380, LR: 4.63e-05


Epoch 528/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0442, Success=37/37]


Epoch 528 completed. Average Loss: 0.0278, LR: 4.62e-05


Epoch 529/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0035, Success=37/37]


Epoch 529 completed. Average Loss: 0.0125, LR: 4.60e-05


Epoch 530/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0163, Success=37/37]


Epoch 530 completed. Average Loss: 0.0246, LR: 4.58e-05


Epoch 531/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0033, Success=37/37]


Epoch 531 completed. Average Loss: 0.0417, LR: 4.57e-05


Epoch 532/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.35it/s, Loss=0.0409, Success=37/37]


Epoch 532 completed. Average Loss: 0.0317, LR: 4.55e-05


Epoch 533/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0032, Success=37/37]


Epoch 533 completed. Average Loss: 0.0357, LR: 4.54e-05


Epoch 534/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.9447, Success=37/37]


Epoch 534 completed. Average Loss: 0.0612, LR: 4.52e-05


Epoch 535/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0031, Success=37/37]


Epoch 535 completed. Average Loss: 0.0144, LR: 4.51e-05


Epoch 536/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0023, Success=37/37]


Epoch 536 completed. Average Loss: 0.0255, LR: 4.49e-05


Epoch 537/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0308, Success=37/37]


Epoch 537 completed. Average Loss: 0.0175, LR: 4.48e-05


Epoch 538/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0038, Success=37/37]


Epoch 538 completed. Average Loss: 0.0903, LR: 4.46e-05


Epoch 539/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0038, Success=37/37]


Epoch 539 completed. Average Loss: 0.0286, LR: 4.45e-05


Epoch 540/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0022, Success=37/37]


Epoch 540 completed. Average Loss: 0.0673, LR: 4.43e-05


Epoch 541/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0021, Success=37/37]


Epoch 541 completed. Average Loss: 0.0421, LR: 4.41e-05


Epoch 542/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0115, Success=37/37]


Epoch 542 completed. Average Loss: 0.0670, LR: 4.40e-05


Epoch 543/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0300, Success=37/37]


Epoch 543 completed. Average Loss: 0.0107, LR: 4.38e-05


Epoch 544/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0047, Success=37/37]


Epoch 544 completed. Average Loss: 0.0095, LR: 4.37e-05


Epoch 545/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0026, Success=37/37]


Epoch 545 completed. Average Loss: 0.0060, LR: 4.35e-05


Epoch 546/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0021, Success=37/37]


Epoch 546 completed. Average Loss: 0.0519, LR: 4.34e-05


Epoch 547/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0023, Success=37/37]


Epoch 547 completed. Average Loss: 0.0538, LR: 4.32e-05


Epoch 548/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0041, Success=37/37]


Epoch 548 completed. Average Loss: 0.0229, LR: 4.31e-05


Epoch 549/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0039, Success=37/37]


Epoch 549 completed. Average Loss: 0.0108, LR: 4.29e-05


Epoch 550/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0194, Success=37/37]


Epoch 550 completed. Average Loss: 0.0176, LR: 4.28e-05


Epoch 551/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0039, Success=37/37]


Epoch 551 completed. Average Loss: 0.0116, LR: 4.26e-05


Epoch 552/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0023, Success=37/37]


Epoch 552 completed. Average Loss: 0.0619, LR: 4.24e-05


Epoch 553/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0112, Success=37/37]


Epoch 553 completed. Average Loss: 0.0286, LR: 4.23e-05


Epoch 554/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0196, Success=37/37]


Epoch 554 completed. Average Loss: 0.0150, LR: 4.21e-05


Epoch 555/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0016, Success=37/37]


Epoch 555 completed. Average Loss: 0.0152, LR: 4.20e-05


Epoch 556/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0023, Success=37/37]


Epoch 556 completed. Average Loss: 0.0144, LR: 4.18e-05


Epoch 557/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0027, Success=37/37]


Epoch 557 completed. Average Loss: 0.0181, LR: 4.17e-05


Epoch 558/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0085, Success=37/37]


Epoch 558 completed. Average Loss: 0.0196, LR: 4.15e-05


Epoch 559/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0065, Success=37/37]


Epoch 559 completed. Average Loss: 0.0182, LR: 4.14e-05


Epoch 560/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0024, Success=37/37]


Epoch 560 completed. Average Loss: 0.0948, LR: 4.12e-05


Epoch 561/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0030, Success=37/37]


Epoch 561 completed. Average Loss: 0.0145, LR: 4.11e-05


Epoch 562/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.42it/s, Loss=0.0058, Success=37/37]


Epoch 562 completed. Average Loss: 0.0139, LR: 4.09e-05


Epoch 563/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0070, Success=37/37]


Epoch 563 completed. Average Loss: 0.0180, LR: 4.08e-05


Epoch 564/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0093, Success=37/37]


Epoch 564 completed. Average Loss: 0.0277, LR: 4.06e-05


Epoch 565/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0105, Success=37/37]


Epoch 565 completed. Average Loss: 0.0275, LR: 4.05e-05


Epoch 566/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.35it/s, Loss=0.9402, Success=37/37]


Epoch 566 completed. Average Loss: 0.0323, LR: 4.03e-05


Epoch 567/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0050, Success=37/37]


Epoch 567 completed. Average Loss: 0.0238, LR: 4.02e-05


Epoch 568/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0032, Success=37/37]


Epoch 568 completed. Average Loss: 0.0216, LR: 4.00e-05


Epoch 569/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0051, Success=37/37]


Epoch 569 completed. Average Loss: 0.0201, LR: 3.99e-05


Epoch 570/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.32it/s, Loss=0.0019, Success=37/37]


Epoch 570 completed. Average Loss: 0.0304, LR: 3.97e-05


Epoch 571/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.35it/s, Loss=0.0039, Success=37/37]


Epoch 571 completed. Average Loss: 0.0341, LR: 3.96e-05


Epoch 572/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0029, Success=37/37]


Epoch 572 completed. Average Loss: 0.0133, LR: 3.94e-05


Epoch 573/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0024, Success=37/37]


Epoch 573 completed. Average Loss: 0.0121, LR: 3.92e-05


Epoch 574/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0026, Success=37/37]


Epoch 574 completed. Average Loss: 0.0407, LR: 3.91e-05


Epoch 575/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.7784, Success=37/37]


Epoch 575 completed. Average Loss: 0.0407, LR: 3.89e-05


Epoch 576/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0020, Success=37/37]


Epoch 576 completed. Average Loss: 0.0243, LR: 3.88e-05


Epoch 577/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0297, Success=37/37]


Epoch 577 completed. Average Loss: 0.0107, LR: 3.86e-05


Epoch 578/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.34it/s, Loss=0.4749, Success=37/37]


Epoch 578 completed. Average Loss: 0.0669, LR: 3.85e-05


Epoch 579/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.38it/s, Loss=0.0022, Success=37/37]


Epoch 579 completed. Average Loss: 0.0113, LR: 3.83e-05


Epoch 580/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0172, Success=37/37]


Epoch 580 completed. Average Loss: 0.0179, LR: 3.82e-05


Epoch 581/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0113, Success=37/37]


Epoch 581 completed. Average Loss: 0.0618, LR: 3.80e-05


Epoch 582/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0037, Success=37/37]


Epoch 582 completed. Average Loss: 0.0468, LR: 3.79e-05


Epoch 583/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.34it/s, Loss=0.0134, Success=37/37]


Epoch 583 completed. Average Loss: 0.0356, LR: 3.77e-05


Epoch 584/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.37it/s, Loss=0.0051, Success=37/37]


Epoch 584 completed. Average Loss: 0.0365, LR: 3.76e-05


Epoch 585/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0022, Success=37/37]


Epoch 585 completed. Average Loss: 0.0113, LR: 3.74e-05


Epoch 586/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0018, Success=37/37]


Epoch 586 completed. Average Loss: 0.0086, LR: 3.73e-05


Epoch 587/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0025, Success=37/37]


Epoch 587 completed. Average Loss: 0.0069, LR: 3.71e-05


Epoch 588/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0017, Success=37/37]


Epoch 588 completed. Average Loss: 0.0083, LR: 3.70e-05


Epoch 589/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0217, Success=37/37]


Epoch 589 completed. Average Loss: 0.0244, LR: 3.68e-05


Epoch 590/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0085, Success=37/37]


Epoch 590 completed. Average Loss: 0.0119, LR: 3.67e-05


Epoch 591/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.35it/s, Loss=0.0053, Success=37/37]


Epoch 591 completed. Average Loss: 0.0280, LR: 3.65e-05


Epoch 592/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0019, Success=37/37]


Epoch 592 completed. Average Loss: 0.0065, LR: 3.64e-05


Epoch 593/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.33it/s, Loss=0.0070, Success=37/37]


Epoch 593 completed. Average Loss: 0.0469, LR: 3.62e-05


Epoch 594/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.40it/s, Loss=0.0944, Success=37/37]


Epoch 594 completed. Average Loss: 0.0192, LR: 3.61e-05


Epoch 595/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.39it/s, Loss=0.0701, Success=37/37]


Epoch 595 completed. Average Loss: 0.0237, LR: 3.59e-05


Epoch 596/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.36it/s, Loss=0.0063, Success=37/37]


Epoch 596 completed. Average Loss: 0.0276, LR: 3.58e-05


Epoch 597/600: 100%|███████████████████████████████████████| 37/37 [00:10<00:00,  3.41it/s, Loss=0.0033, Success=37/37]


Epoch 597 completed. Average Loss: 0.0121, LR: 3.56e-05


Epoch 598/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.29it/s, Loss=0.0024, Success=37/37]


Epoch 598 completed. Average Loss: 0.0129, LR: 3.55e-05


Epoch 599/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.21it/s, Loss=0.0021, Success=37/37]


Epoch 599 completed. Average Loss: 0.0213, LR: 3.54e-05


Epoch 600/600: 100%|███████████████████████████████████████| 37/37 [00:11<00:00,  3.34it/s, Loss=0.0056, Success=37/37]


Epoch 600 completed. Average Loss: 0.0330, LR: 3.52e-05
Checkpoint saved at epoch 600
Training completed!
✅ Training completed successfully!
Using device: cuda
🎨 Generating test images directly...
Generating: a pixel art cat
❌ Error generating 'a pixel art cat': Expected a 'cpu' device type for generator but found 'cuda'
Error type: RuntimeError
Generating: a pixel art house
❌ Error generating 'a pixel art house': Expected a 'cpu' device type for generator but found 'cuda'
Error type: RuntimeError
Generating: a pixel art tree
❌ Error generating 'a pixel art tree': Expected a 'cpu' device type for generator but found 'cuda'
Error type: RuntimeError
Generating: james graham ballard, highrise, sustainability, octane render, highly detailed
❌ Error generating 'james graham ballard, highrise, sustainability, octane render, highly detailed': Expected a 'cpu' device type for generator but found 'cuda'
Error type: RuntimeError
Generating: pixel art landscape
❌ Error generating 'pixel art lands

In [3]:
import torch
import random
import numpy as np

def seed_everything(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)
# Sinusoidal timestep embedding for diffusion steps
def get_timestep_embedding(timesteps, embedding_dim):
    half_dim = embedding_dim // 2
    emb = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) *
        -(torch.log(torch.tensor(10000.0)) / half_dim)
    )
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # Handle odd embedding dimensions
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)
    return emb

# Residual block with time and context embeddings
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim, context_dim=None):
        super().__init__()
        self.norm1 = nn.GroupNorm(min(32, in_channels), in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(min(32, out_channels), out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.context_proj = nn.Linear(context_dim, out_channels) if context_dim else None
        self.shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

    def forward(self, x, t_emb, context=None):
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)

        # Add time embedding
        t_proj = self.time_mlp(t_emb)[:, :, None, None]
        h = h + t_proj

        # Add context embedding if available
        if self.context_proj is not None and context is not None:
            context_pooled = context.mean(dim=1)  # [batch, context_dim]
            context_proj = self.context_proj(context_pooled)[:, :, None, None]
            h = h + context_proj

        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)

        return h + self.shortcut(x)

# Cross-attention to integrate text embeddings
class CrossAttention(nn.Module):
    def __init__(self, channels, context_dim):
        super().__init__()
        self.channels = channels
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(context_dim, channels)
        self.value = nn.Linear(context_dim, channels)
        self.out = nn.Linear(channels, channels)
        self.norm = nn.LayerNorm(channels)

    def forward(self, x, context):
        if context is None:
            return x

        B, C, H, W = x.shape
        x_flat = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
        x_norm = self.norm(x_flat)

        q = self.query(x_norm)  # [B, H*W, C]
        k = self.key(context)   # [B, seq_len, C]
        v = self.value(context) # [B, seq_len, C]

        scale = (C ** -0.5)
        attn_weights = torch.bmm(q, k.transpose(1, 2)) * scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_out = torch.bmm(attn_weights, v)
        attn_out = self.out(attn_out)

        attn_out = attn_out.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return x + attn_out

# Self-attention block for image features
class AttentionBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(min(32, channels), channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.norm(x)
        qkv = self.qkv(h).reshape(B, 3, C, H * W)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]

        scale = (C ** -0.5)
        attn = torch.bmm(q.transpose(1, 2), k) * scale
        attn = F.softmax(attn, dim=-1)

        out = torch.bmm(v, attn.transpose(1, 2))
        out = out.reshape(B, C, H, W)
        return self.proj(out) + x

# U-Net model updated for 256x256 latents
class UNetConditional(nn.Module):
    def __init__(self, in_channels=4, base_channels=128, context_dim=768):
        super().__init__()
        self.time_emb_dim = base_channels * 4
        from types import SimpleNamespace
        self.config = SimpleNamespace()
        self.config._diffusers_version = "0.34.0"
        self.config.in_channels = in_channels
        self.config.out_channels = in_channels
        self.config.sample_size = 256  # Updated for 256x256 latents
        self.config.layers_per_block = 2
        self.config.block_out_channels = [base_channels, base_channels * 2, base_channels * 4, base_channels * 8]
        self.config.attention_head_dim = 8
        self.config.cross_attention_dim = context_dim

        # Time embedding MLP
        self.time_mlp = nn.Sequential(
            nn.Linear(base_channels, self.time_emb_dim),
            nn.SiLU(),
            nn.Linear(self.time_emb_dim, self.time_emb_dim),
        )

        # Input projection
        self.input_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Encoder
        self.down1 = ResidualBlock(base_channels, base_channels * 2, self.time_emb_dim, context_dim)
        self.downsample1 = nn.Conv2d(base_channels * 2, base_channels * 2, 3, stride=2, padding=1)
        self.cross1 = CrossAttention(base_channels * 2, context_dim)

        self.down2 = ResidualBlock(base_channels * 2, base_channels * 4, self.time_emb_dim, context_dim)
        self.downsample2 = nn.Conv2d(base_channels * 4, base_channels * 4, 3, stride=2, padding=1)
        self.cross2 = CrossAttention(base_channels * 4, context_dim)

        self.down3 = ResidualBlock(base_channels * 4, base_channels * 8, self.time_emb_dim, context_dim)
        self.downsample3 = nn.Conv2d(base_channels * 8, base_channels * 8, 3, stride=2, padding=1)
        self.cross3 = CrossAttention(base_channels * 8, context_dim)

        # Middle
        self.middle1 = ResidualBlock(base_channels * 8, base_channels * 8, self.time_emb_dim, context_dim)
        self.middle_attn = AttentionBlock(base_channels * 8)
        self.middle2 = ResidualBlock(base_channels * 8, base_channels * 8, self.time_emb_dim, context_dim)

        # Decoder
        self.up3 = ResidualBlock(base_channels * 16, base_channels * 4, self.time_emb_dim, context_dim)
        self.upsample3 = nn.ConvTranspose2d(base_channels * 4, base_channels * 4, 4, stride=2, padding=1)
        self.cross_up3 = CrossAttention(base_channels * 4, context_dim)

        self.up2 = ResidualBlock(base_channels * 8, base_channels * 2, self.time_emb_dim, context_dim)
        self.upsample2 = nn.ConvTranspose2d(base_channels * 2, base_channels * 2, 4, stride=2, padding=1)
        self.cross_up2 = CrossAttention(base_channels * 2, context_dim)

        self.up1 = ResidualBlock(base_channels * 4, base_channels, self.time_emb_dim, context_dim)
        self.upsample1 = nn.ConvTranspose2d(base_channels, base_channels, 4, stride=2, padding=1)

        # Output
        self.output_conv = nn.Sequential(
            nn.GroupNorm(min(32, base_channels), base_channels),
            nn.SiLU(),
            nn.Conv2d(base_channels, in_channels, 3, padding=1)
        )

    def forward(self, x, t, context, cfg_scale=1.0):
        t_emb = get_timestep_embedding(t, self.time_emb_dim // 4)
        t_emb = self.time_mlp(t_emb)

        def denoise(x, t_emb, context):
            h = self.input_conv(x)

            # Encoder
            h1 = self.down1(h, t_emb, context)
            h1_cross = self.cross1(h1, context)
            h1_down = self.downsample1(h1_cross)

            h2 = self.down2(h1_down, t_emb, context)
            h2_cross = self.cross2(h2, context)
            h2_down = self.downsample2(h2_cross)

            h3 = self.down3(h2_down, t_emb, context)
            h3_cross = self.cross3(h3, context)
            h3_down = self.downsample3(h3_cross)

            # Middle
            h_mid = self.middle1(h3_down, t_emb, context)
            h_mid = self.middle_attn(h_mid)
            h_mid = self.middle2(h_mid, t_emb, context)

            # Decoder
            h3_cross_resized = F.interpolate(h3_cross, size=h_mid.shape[-2:], mode='nearest')
            h = self.up3(torch.cat([h_mid, h3_cross_resized], dim=1), t_emb, context)
            h = self.upsample3(h)
            h = self.cross_up3(h, context)

            h2_cross_resized = F.interpolate(h2_cross, size=h.shape[-2:], mode='nearest')
            h = self.up2(torch.cat([h, h2_cross_resized], dim=1), t_emb, context)
            h = self.upsample2(h)
            h = self.cross_up2(h, context)

            h1_cross_resized = F.interpolate(h1_cross, size=h.shape[-2:], mode='nearest')
            h = self.up1(torch.cat([h, h1_cross_resized], dim=1), t_emb, context)
            h = self.upsample1(h)

            return self.output_conv(h)

        if cfg_scale == 1.0 or context is None:
            return denoise(x, t_emb, context)

        uncond = denoise(x, t_emb, context=None)
        cond = denoise(x, t_emb, context)
        return uncond + cfg_scale * (cond - uncond)


In [4]:
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTokenizer, CLIPTextModel
from tqdm import tqdm
import numpy as np
from PIL import Image
# Import UNetConditional from the cell where it's defined


def generate_images_direct(unet_path="output/checkpoint_epoch_600.pt", device="cuda"):
    """Generate 256x256 images directly without using StableDiffusionPipeline"""
    seed_everything(42)
    print(f"Using device: {device}")

    # Load components
    print("Loading VAE...")
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
    vae = vae.eval().requires_grad_(False).to(device)

    print("Loading tokenizer and text encoder...")
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = text_encoder.eval().requires_grad_(False).to(device)

    # Load your trained UNet with proper device handling
    print("Loading trained UNet...")
    # Correctly initialize the UNet with base_channels=128, using the class from the other cell
    unet = UNetConditional(in_channels=4, base_channels=128, context_dim=768)
    checkpoint = torch.load(unet_path, map_location=device, weights_only=True)
    # Load only the model_state_dict from the checkpoint
    unet.load_state_dict(checkpoint['model_state_dict'])
    unet = unet.eval().to(device)

    # Create scheduler
    scheduler = DDPMScheduler(num_train_timesteps=1000)

    # Test prompts
    test_prompts = ["A futuristic smartphone"]

    print("🎨 Generating 256x256 images directly...")

    for i, prompt in enumerate(test_prompts):
        print(f"Generating: {prompt}")

        try:
            with torch.no_grad():
                # Encode prompt - ensure all tensors on correct device
                inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    truncation=True,
                    max_length=77,
                    return_tensors="pt"
                )
                # Move tokenizer outputs to device
                inputs = {k: v.to(device) for k, v in inputs.items()}

                text_embeddings = text_encoder(**inputs).last_hidden_state
                print(f"Text embeddings shape: {text_embeddings.shape}, device: {text_embeddings.device}")

                # Create random latents for 256x256 output (256/8 = 32 due to VAE scaling)
                latents = torch.randn(1, 4, 32, 32, device=device, dtype=torch.float32)
                print(f"Initial latents shape: {latents.shape}, device: {latents.device}")

                # Set timesteps
                scheduler.set_timesteps(1000)

                # Denoising loop
                for t in tqdm(scheduler.timesteps, desc=f"Denoising {prompt}"):
                    # Ensure timestep is on correct device
                    t_tensor = torch.tensor([t], device=device, dtype=torch.long)

                    # Predict noise - all inputs should be on same device
                    # Pass text_embeddings as the context argument
                    noise_pred = unet(latents, t_tensor, context=text_embeddings)

                    # Remove noise
                    latents = scheduler.step(noise_pred, t, latents).prev_sample

                print(f"Final latents shape: {latents.shape}")

                # Decode latents to image
                latents = latents / 0.18215

                # No need to resize since UNet was trained on 256x256 (32x32 latents)
                images = vae.decode(latents).sample

                # Convert to PIL
                images = (images + 1) / 2  # Denormalize
                images = images.clamp(0, 1)
                images = images.cpu().permute(0, 2, 3, 1).numpy()

                image = Image.fromarray((images[0] * 255).astype(np.uint8))

                # Save
                filename = f"output/generated_256_{i+1}_{prompt.replace(' ', '_')}.png"
                image.save(filename)
                print(f"✅ Saved: {filename}")

        except Exception as e:
            print(f"❌ Error generating '{prompt}': {e}")
            print(f"Error type: {type(e).__name__}")
            # Continue with next prompt
            continue

In [11]:
generate_images_direct()


Using device: cuda
Loading VAE...
Loading tokenizer and text encoder...
Loading trained UNet...


FileNotFoundError: [Errno 2] No such file or directory: 'output/checkpoint_epoch_600.pt'

In [None]:
from diffusers import StableDiffusionPipeline, AutoencoderKL, DDIMScheduler
from transformers import CLIPTokenizer, CLIPTextModel
import torch
import torch
import random
import numpy as np

def seed_everything(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)
# Sinusoidal timestep embedding for diffusion steps
def get_timestep_embedding(timesteps, embedding_dim):
    half_dim = embedding_dim // 2
    emb = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) *
        -(torch.log(torch.tensor(10000.0)) / half_dim)
    )
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # Handle odd embedding dimensions
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)
    return emb

# Residual block with time and context embeddings
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim, context_dim=None):
        super().__init__()
        self.norm1 = nn.GroupNorm(min(32, in_channels), in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(min(32, out_channels), out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.context_proj = nn.Linear(context_dim, out_channels) if context_dim else None
        self.shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

    def forward(self, x, t_emb, context=None):
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)

        # Add time embedding
        t_proj = self.time_mlp(t_emb)[:, :, None, None]
        h = h + t_proj

        # Add context embedding if available
        if self.context_proj is not None and context is not None:
            context_pooled = context.mean(dim=1)  # [batch, context_dim]
            context_proj = self.context_proj(context_pooled)[:, :, None, None]
            h = h + context_proj

        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)

        return h + self.shortcut(x)

# Cross-attention to integrate text embeddings
class CrossAttention(nn.Module):
    def __init__(self, channels, context_dim):
        super().__init__()
        self.channels = channels
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(context_dim, channels)
        self.value = nn.Linear(context_dim, channels)
        self.out = nn.Linear(channels, channels)
        self.norm = nn.LayerNorm(channels)

    def forward(self, x, context):
        if context is None:
            return x

        B, C, H, W = x.shape
        x_flat = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
        x_norm = self.norm(x_flat)

        q = self.query(x_norm)  # [B, H*W, C]
        k = self.key(context)   # [B, seq_len, C]
        v = self.value(context) # [B, seq_len, C]

        scale = (C ** -0.5)
        attn_weights = torch.bmm(q, k.transpose(1, 2)) * scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_out = torch.bmm(attn_weights, v)
        attn_out = self.out(attn_out)

        attn_out = attn_out.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return x + attn_out

# Self-attention block for image features
class AttentionBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(min(32, channels), channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.norm(x)
        qkv = self.qkv(h).reshape(B, 3, C, H * W)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]

        scale = (C ** -0.5)
        attn = torch.bmm(q.transpose(1, 2), k) * scale
        attn = F.softmax(attn, dim=-1)

        out = torch.bmm(v, attn.transpose(1, 2))
        out = out.reshape(B, C, H, W)
        return self.proj(out) + x

# U-Net model updated for 256x256 latents
class UNetConditional(nn.Module):
    def __init__(self, in_channels=4, base_channels=128, context_dim=768):
        super().__init__()
        self.time_emb_dim = base_channels * 4
        from types import SimpleNamespace
        self.config = SimpleNamespace()
        self.config._diffusers_version = "0.34.0"
        self.config.in_channels = in_channels
        self.config.out_channels = in_channels
        self.config.sample_size = 256  # Updated for 256x256 latents
        self.config.layers_per_block = 2
        self.config.block_out_channels = [base_channels, base_channels * 2, base_channels * 4, base_channels * 8]
        self.config.attention_head_dim = 8
        self.config.cross_attention_dim = context_dim

        # Time embedding MLP
        self.time_mlp = nn.Sequential(
            nn.Linear(base_channels, self.time_emb_dim),
            nn.SiLU(),
            nn.Linear(self.time_emb_dim, self.time_emb_dim),
        )

        # Input projection
        self.input_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Encoder
        self.down1 = ResidualBlock(base_channels, base_channels * 2, self.time_emb_dim, context_dim)
        self.downsample1 = nn.Conv2d(base_channels * 2, base_channels * 2, 3, stride=2, padding=1)
        self.cross1 = CrossAttention(base_channels * 2, context_dim)

        self.down2 = ResidualBlock(base_channels * 2, base_channels * 4, self.time_emb_dim, context_dim)
        self.downsample2 = nn.Conv2d(base_channels * 4, base_channels * 4, 3, stride=2, padding=1)
        self.cross2 = CrossAttention(base_channels * 4, context_dim)

        self.down3 = ResidualBlock(base_channels * 4, base_channels * 8, self.time_emb_dim, context_dim)
        self.downsample3 = nn.Conv2d(base_channels * 8, base_channels * 8, 3, stride=2, padding=1)
        self.cross3 = CrossAttention(base_channels * 8, context_dim)

        # Middle
        self.middle1 = ResidualBlock(base_channels * 8, base_channels * 8, self.time_emb_dim, context_dim)
        self.middle_attn = AttentionBlock(base_channels * 8)
        self.middle2 = ResidualBlock(base_channels * 8, base_channels * 8, self.time_emb_dim, context_dim)

        # Decoder
        self.up3 = ResidualBlock(base_channels * 16, base_channels * 4, self.time_emb_dim, context_dim)
        self.upsample3 = nn.ConvTranspose2d(base_channels * 4, base_channels * 4, 4, stride=2, padding=1)
        self.cross_up3 = CrossAttention(base_channels * 4, context_dim)

        self.up2 = ResidualBlock(base_channels * 8, base_channels * 2, self.time_emb_dim, context_dim)
        self.upsample2 = nn.ConvTranspose2d(base_channels * 2, base_channels * 2, 4, stride=2, padding=1)
        self.cross_up2 = CrossAttention(base_channels * 2, context_dim)

        self.up1 = ResidualBlock(base_channels * 4, base_channels, self.time_emb_dim, context_dim)
        self.upsample1 = nn.ConvTranspose2d(base_channels, base_channels, 4, stride=2, padding=1)

        # Output
        self.output_conv = nn.Sequential(
            nn.GroupNorm(min(32, base_channels), base_channels),
            nn.SiLU(),
            nn.Conv2d(base_channels, in_channels, 3, padding=1)
        )

    def forward(self, x, t, context, cfg_scale=1.0):
        t_emb = get_timestep_embedding(t, self.time_emb_dim // 4)
        t_emb = self.time_mlp(t_emb)

        def denoise(x, t_emb, context):
            h = self.input_conv(x)

            # Encoder
            h1 = self.down1(h, t_emb, context)
            h1_cross = self.cross1(h1, context)
            h1_down = self.downsample1(h1_cross)

            h2 = self.down2(h1_down, t_emb, context)
            h2_cross = self.cross2(h2, context)
            h2_down = self.downsample2(h2_cross)

            h3 = self.down3(h2_down, t_emb, context)
            h3_cross = self.cross3(h3, context)
            h3_down = self.downsample3(h3_cross)

            # Middle
            h_mid = self.middle1(h3_down, t_emb, context)
            h_mid = self.middle_attn(h_mid)
            h_mid = self.middle2(h_mid, t_emb, context)

            # Decoder
            h3_cross_resized = F.interpolate(h3_cross, size=h_mid.shape[-2:], mode='nearest')
            h = self.up3(torch.cat([h_mid, h3_cross_resized], dim=1), t_emb, context)
            h = self.upsample3(h)
            h = self.cross_up3(h, context)

            h2_cross_resized = F.interpolate(h2_cross, size=h.shape[-2:], mode='nearest')
            h = self.up2(torch.cat([h, h2_cross_resized], dim=1), t_emb, context)
            h = self.upsample2(h)
            h = self.cross_up2(h, context)

            h1_cross_resized = F.interpolate(h1_cross, size=h.shape[-2:], mode='nearest')
            h = self.up1(torch.cat([h, h1_cross_resized], dim=1), t_emb, context)
            h = self.upsample1(h)

            return self.output_conv(h)

        if cfg_scale == 1.0 or context is None:
            return denoise(x, t_emb, context)

        uncond = denoise(x, t_emb, context=None)
        cond = denoise(x, t_emb, context)
        return uncond + cfg_scale * (cond - uncond)
# Your UNetConditional and related classes (ResidualBlock, CrossAttention, AttentionBlock, etc.) go here
# ... (Your provided UNet code)

import torch
import torch.nn as nn
from diffusers import StableDiffusionPipeline, AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from transformers import CLIPTokenizer, CLIPTextModel
from tqdm import tqdm
import numpy as np
from PIL import Image
import random
import os



import torch
from diffusers import StableDiffusionPipeline, AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from transformers import CLIPTokenizer, CLIPTextModel
from tqdm import tqdm
import numpy as np
from PIL import Image
import random
import os



# Set random seed
def seed_everything(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def create_and_save_pipeline(unet_path="output/checkpoint_epoch_600.pt", output_dir="KahabMiniGenT2Im-v1", device="cuda"):
    seed_everything(42)

    # Load components
    print("Loading VAE...")
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", torch_dtype=torch.float16)
    vae = vae.eval().requires_grad_(False).to(device)

    print("Loading tokenizer and text encoder...")
    tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder", torch_dtype=torch.float16)
    text_encoder = text_encoder.eval().requires_grad_(False).to(device)

    print("Loading scheduler...")
    scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

    print("Loading trained UNet...")
    unet = UNetConditional(in_channels=4, base_channels=128, context_dim=768)
    checkpoint = torch.load(unet_path, map_location=device, weights_only=True)
    try:
        unet.load_state_dict(checkpoint['model_state_dict'])
    except Exception as e:
        print(f"❌ Error loading UNet weights: {e}")
        print("Checkpoint keys (first 10):", list(checkpoint['model_state_dict'].keys())[:10])
        print("UNet keys (first 10):", list(unet.state_dict().keys())[:10])
        return
    unet = unet.eval().to(device, dtype=torch.float32)

    # Create diffusers UNet
    print("Creating UNet2DConditionModel for pipeline...")
    unet_config = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").config
    unet_diffusers = UNet2DConditionModel(**unet_config).to(device, dtype=torch.float32)
    try:
        unet_diffusers.load_state_dict(unet.state_dict(), strict=False)
    except Exception as e:
        print(f"❌ Error transferring UNet weights: {e}")
        print("Custom UNet keys (first 10):", list(unet.state_dict().keys())[:10])
        print("Diffusers UNet keys (first 10):", list(unet_diffusers.state_dict().keys())[:10])
        return

    # Minimal pipeline
    class CustomStableDiffusionPipeline(StableDiffusionPipeline):
        def __call__(self, prompt, num_inference_steps=50, guidance_scale=7.5, **kwargs):
            device = self.unet.device
            text_inputs = self.tokenizer(
                prompt, padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt"
            )
            text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
            with torch.no_grad():
                text_embeddings = self.text_encoder(text_inputs["input_ids"])[0].to(torch.float32)

            empty_inputs = self.tokenizer(
                "", padding="max_length", max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="pt"
            )
            empty_inputs = {k: v.to(device) for k, v in empty_inputs.items()}
            with torch.no_grad():
                empty_embeddings = self.text_encoder(empty_inputs["input_ids"])[0].to(torch.float32)

            latents = torch.randn(
                (1, self.unet.config.in_channels, 32, 32), device=device, dtype=torch.float32
            )
            latents = latents * self.scheduler.init_noise_sigma

            self.scheduler.set_timesteps(num_inference_steps)

            for t in tqdm(self.scheduler.timesteps, desc=f"Denoising {prompt}"):
                latent_model_input = self.scheduler.scale_model_input(latents, t)
                t_tensor = torch.tensor([t], device=device, dtype=torch.float32)
                with torch.no_grad():
                    noise_pred = self.unet(latent_model_input, t_tensor, encoder_hidden_states=text_embeddings)['sample']
                latents = self.scheduler.step(noise_pred, t, latents).prev_sample

            latents = latents / 0.18215
            latents = latents.to(dtype=torch.float16)
            with torch.no_grad():
                images = self.vae.decode(latents).sample
            images = (images / 2 + 0.5).clamp(0, 1)
            images = images.cpu().permute(0, 2, 3, 1).numpy()
            images = [Image.fromarray((img * 255).astype(np.uint8)) for img in images]
            return {"images": images}

    # Create pipeline
    try:
        pipeline = CustomStableDiffusionPipeline(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet_diffusers,
            scheduler=scheduler,
            safety_checker=None,
            feature_extractor=None
        ).to(device)
        print("✅ Pipeline created successfully")
        print("Pipeline components:", list(pipeline.components.keys()))
    except Exception as e:
        print(f"❌ Error creating pipeline: {e}")
        return

    # Test pipeline
    try:
        print("Testing pipeline...")
        image = pipeline("A futuristic smartphone", num_inference_steps=50, guidance_scale=7.5)["images"][0]
        image.save("test_pipeline_output.png")
        print("✅ Test image generated successfully")
    except Exception as e:
        print(f"❌ Error testing pipeline: {e}")
        return

    # Save pipeline
    try:
        os.makedirs(output_dir, exist_ok=True)
        pipeline.save_pretrained(output_dir, safe_serialization=True)
        print(f"✅ Pipeline saved to {output_dir}")
    except Exception as e:
        print(f"❌ Error saving pipeline: {e}")
        return

if __name__ == "__main__":
    create_and_save_pipeline()

Loading VAE...
Loading tokenizer and text encoder...
Loading scheduler...
Loading trained UNet...
Creating UNet2DConditionModel for pipeline...


You have disabled the safety checker for <class '__main__.create_and_save_pipeline.<locals>.CustomStableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


✅ Pipeline created successfully
Pipeline components: ['vae', 'text_encoder', 'tokenizer', 'unet', 'scheduler', 'safety_checker', 'feature_extractor', 'image_encoder']
Testing pipeline...


Denoising A futuristic smartphone:  46%|█████████████████████▌                         | 23/50 [00:19<00:24,  1.12it/s]

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler
import math
import glob
import random
import json

# -----------------------
# Advanced U-Net
# -----------------------
class SinusoidalTimeEmbedding(nn.Module):
    def _init_(self, dim):
        super()._init_()
        self.linear1 = nn.Linear(dim, dim * 4)
        self.linear2 = nn.Linear(dim * 4, dim)

    def forward(self, t):
        half_dim = self.linear1.in_features // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        return self.linear2(F.silu(self.linear1(emb)))

class CrossAttentionBlock(nn.Module):
    def _init_(self, dim, context_dim):
        super()._init_()
        self.norm = nn.LayerNorm(dim)
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(context_dim, dim)
        self.v = nn.Linear(context_dim, dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, context):
        b, n, c = x.shape
        q = self.q(self.norm(x))
        k = self.k(context)
        v = self.v(context)
        attn = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(c), dim=-1)
        return self.proj(attn @ v)

class ResBlock(nn.Module):
    def _init_(self, in_ch, out_ch, time_emb_dim):
        super()._init_()
        self.norm1 = nn.GroupNorm(32, in_ch)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.time_emb_proj = nn.Linear(time_emb_dim, out_ch)
        self.norm2 = nn.GroupNorm(32, out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb):
        h = F.silu(self.norm1(x))
        h = self.conv1(h)
        h += self.time_emb_proj(t_emb).unsqueeze(-1).unsqueeze(-1)
        h = F.silu(self.norm2(h))
        h = self.conv2(h)
        return h + self.skip(x)

class AttentionBlock(nn.Module):
    def _init_(self, dim):
        super()._init_()
        self.norm = nn.GroupNorm(32, dim)
        self.q = nn.Conv2d(dim, dim, 1)
        self.k = nn.Conv2d(dim, dim, 1)
        self.v = nn.Conv2d(dim, dim, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.norm(x)
        q = self.q(h).reshape(B, C, H * W).permute(0, 2, 1)
        k = self.k(h).reshape(B, C, H * W)
        v = self.v(h).reshape(B, C, H * W).permute(0, 2, 1)
        attn = torch.softmax(q @ k / math.sqrt(C), dim=-1)
        out = attn @ v
        out = out.permute(0, 2, 1).reshape(B, C, H, W)
        return x + self.proj(out)

class AdvancedUNet(nn.Module):
    def _init_(self, in_ch=4, base_ch=320, context_dim=768, time_dim=1280):
        super()._init_()
        self.time_mlp = SinusoidalTimeEmbedding(time_dim)
        self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1)
        self.down1 = nn.Sequential(ResBlock(base_ch, base_ch, time_dim), AttentionBlock(base_ch), nn.AvgPool2d(2))
        self.down2 = nn.Sequential(ResBlock(base_ch, base_ch * 2, time_dim), AttentionBlock(base_ch * 2), nn.AvgPool2d(2))
        self.mid = nn.Sequential(ResBlock(base_ch * 2, base_ch * 4, time_dim), AttentionBlock(base_ch * 4), ResBlock(base_ch * 4, base_ch * 2, time_dim))
        self.cross_attn = CrossAttentionBlock(base_ch * 2, context_dim)
        self.up1 = nn.Sequential(ResBlock(base_ch * 2, base_ch, time_dim), AttentionBlock(base_ch), nn.Upsample(scale_factor=2, mode="nearest"))
        self.up2 = nn.Sequential(ResBlock(base_ch, base_ch, time_dim), AttentionBlock(base_ch), nn.Upsample(scale_factor=2, mode="nearest"))
        self.conv_out = nn.Sequential(nn.GroupNorm(32, base_ch), nn.SiLU(), nn.Conv2d(base_ch, in_ch, 3, padding=1))

    def forward(self, x, t, context):
        t_emb = self.time_mlp(t)
        x = self.conv_in(x)
        d1 = self.down1(x)
        d2 = self.down2(d1)
        mid = self.mid(d2)
        B, C, H, W = mid.shape
        context_proj = self.cross_attn(mid.view(B, C, -1).permute(0, 2, 1), context)
        mid = mid + context_proj.permute(0, 2, 1).view(B, C, H, W)
        u1 = self.up1(mid)
        u2 = self.up2(u1)
        return self.conv_out(u2)

# -----------------------
# Dataset & Preprocessing
# -----------------------
class FaceLatentDataset(Dataset):
    def _init_(self, image_dir, caption_file, vae, tokenizer, text_encoder, size=512):
        self.image_paths = glob.glob(os.path.join(image_dir, "*.jpg"))[:10]
        self.caption_data = json.load(open(caption_file))
        self.vae = vae
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder
        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

    def _len_(self):
        return len(self.image_paths)

    def _getitem_(self, idx):
        path = self.image_paths[idx]
        image = self.transform(Image.open(path).convert("RGB")).unsqueeze(0).cuda()
        with torch.no_grad():
            latent = self.vae.encode(image).latent_dist.sample() * 0.18215
            text = self.caption_data[os.path.basename(path)]
            text_token = self.tokenizer([text], padding="max_length", max_length=77, return_tensors="pt").input_ids.cuda()
            text_embedding = self.text_encoder(text_token)[0]
        return latent.squeeze(0), text_embedding.squeeze(0)

# -----------------------
# Training Setup
# -----------------------
def train(model, dataloader, noise_scheduler, optimizer, epochs=100):
    model.train()
    for epoch in range(epochs):
        for latents, text_embeds in dataloader:
            noise = torch.randn_like(latents)
            t = torch.randint(0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, t)
            preds = model(noisy_latents, t, text_embeds)
            loss = F.mse_loss(preds, noise)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        print(f"[Epoch {epoch+1}] Loss: {loss.item():.4f}")

# -----------------------
# MAIN
# -----------------------
def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-5", subfolder="vae").to(device)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    model = AdvancedUNet().to(device)
    dataset = FaceLatentDataset("/content/images/", "/content/exprtion.json", vae, tokenizer, text_encoder)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
    noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    train(model, dataloader, noise_scheduler, optimizer, epochs=10)
    torch.save(model.state_dict(), "my_custom_unet.pth")


main()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


OSError: CompVis/stable-diffusion-v1-5 is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo with `token` or log in with `huggingface-cli login`.

In [1]:
import os
import shutil

# Define source and destination paths
image_source_dir = '.'  # Current directory where image files are located
image_dest_dir = 'images' # Destination folder for image files
json_source_path = 'captions.json' # Source path for the JSON file
json_dest_path = '../captions.json' # Destination path for the JSON file (one level up)

# Create the destination directory for images if it doesn't exist
os.makedirs(image_dest_dir, exist_ok=True)

# Move image files
for filename in os.listdir(image_source_dir):
    if filename.endswith(('.jpg', '.jpeg', '.png', '.gif')): # Add other image extensions if needed
        shutil.move(os.path.join(image_source_dir, filename), os.path.join(image_dest_dir, filename))
        print(f"Moved {filename} to {image_dest_dir}")

# Move the JSON file
if os.path.exists(json_source_path):
    shutil.move(json_source_path, json_dest_path)
    print(f"Moved {json_source_path} to {json_dest_path}")
else:
    print(f"{json_source_path} not found.")

Moved expression_02.jpg to images
Moved expression_07.jpg to images
Moved expression_06.jpg to images
Moved expression_08.jpg to images
Moved expression_01.jpg to images
Moved expression_09.jpg to images
Moved expression_05.jpg to images
Moved expression_04.jpg to images
Moved expression_10.jpg to images
Moved expression_03.jpg to images
captions.json not found.


In [4]:
seed_everything(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load components
print("Loading VAE...")
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device).eval().requires_grad_(False)

print("Loading tokenizer and text encoder...")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device).eval().requires_grad_(False)


Loading VAE...
Loading tokenizer and text encoder...


In [5]:
import torch
import random
import numpy as np

def seed_everything(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(42)
# Sinusoidal timestep embedding for diffusion steps
def get_timestep_embedding(timesteps, embedding_dim):
    half_dim = embedding_dim // 2
    emb = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) *
        -(torch.log(torch.tensor(10000.0)) / half_dim)
    )
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # Handle odd embedding dimensions
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)
    return emb

# Residual block with time and context embeddings
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_emb_dim, context_dim=None):
        super().__init__()
        self.norm1 = nn.GroupNorm(min(32, in_channels), in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(min(32, out_channels), out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.context_proj = nn.Linear(context_dim, out_channels) if context_dim else None
        self.shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()

    def forward(self, x, t_emb, context=None):
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)

        # Add time embedding
        t_proj = self.time_mlp(t_emb)[:, :, None, None]
        h = h + t_proj

        # Add context embedding if available
        if self.context_proj is not None and context is not None:
            context_pooled = context.mean(dim=1)  # [batch, context_dim]
            context_proj = self.context_proj(context_pooled)[:, :, None, None]
            h = h + context_proj

        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)

        return h + self.shortcut(x)

# Cross-attention to integrate text embeddings
class CrossAttention(nn.Module):
    def __init__(self, channels, context_dim):
        super().__init__()
        self.channels = channels
        self.query = nn.Linear(channels, channels)
        self.key = nn.Linear(context_dim, channels)
        self.value = nn.Linear(context_dim, channels)
        self.out = nn.Linear(channels, channels)
        self.norm = nn.LayerNorm(channels)

    def forward(self, x, context):
        if context is None:
            return x

        B, C, H, W = x.shape
        x_flat = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
        x_norm = self.norm(x_flat)

        q = self.query(x_norm)  # [B, H*W, C]
        k = self.key(context)   # [B, seq_len, C]
        v = self.value(context) # [B, seq_len, C]

        scale = (C ** -0.5)
        attn_weights = torch.bmm(q, k.transpose(1, 2)) * scale
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_out = torch.bmm(attn_weights, v)
        attn_out = self.out(attn_out)

        attn_out = attn_out.reshape(B, H, W, C).permute(0, 3, 1, 2)
        return x + attn_out

# Self-attention block for image features
class AttentionBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(min(32, channels), channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.norm(x)
        qkv = self.qkv(h).reshape(B, 3, C, H * W)
        q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]

        scale = (C ** -0.5)
        attn = torch.bmm(q.transpose(1, 2), k) * scale
        attn = F.softmax(attn, dim=-1)

        out = torch.bmm(v, attn.transpose(1, 2))
        out = out.reshape(B, C, H, W)
        return self.proj(out) + x

# U-Net model updated for 256x256 latents
class UNetConditional(nn.Module):
    def __init__(self, in_channels=4, base_channels=128, context_dim=768):
        super().__init__()
        self.time_emb_dim = base_channels * 4
        from types import SimpleNamespace
        self.config = SimpleNamespace()
        self.config._diffusers_version = "0.34.0"
        self.config.in_channels = in_channels
        self.config.out_channels = in_channels
        self.config.sample_size = 256  # Updated for 256x256 latents
        self.config.layers_per_block = 2
        self.config.block_out_channels = [base_channels, base_channels * 2, base_channels * 4, base_channels * 8]
        self.config.attention_head_dim = 8
        self.config.cross_attention_dim = context_dim

        # Time embedding MLP
        self.time_mlp = nn.Sequential(
            nn.Linear(base_channels, self.time_emb_dim),
            nn.SiLU(),
            nn.Linear(self.time_emb_dim, self.time_emb_dim),
        )

        # Input projection
        self.input_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Encoder
        self.down1 = ResidualBlock(base_channels, base_channels * 2, self.time_emb_dim, context_dim)
        self.downsample1 = nn.Conv2d(base_channels * 2, base_channels * 2, 3, stride=2, padding=1)
        self.cross1 = CrossAttention(base_channels * 2, context_dim)

        self.down2 = ResidualBlock(base_channels * 2, base_channels * 4, self.time_emb_dim, context_dim)
        self.downsample2 = nn.Conv2d(base_channels * 4, base_channels * 4, 3, stride=2, padding=1)
        self.cross2 = CrossAttention(base_channels * 4, context_dim)

        self.down3 = ResidualBlock(base_channels * 4, base_channels * 8, self.time_emb_dim, context_dim)
        self.downsample3 = nn.Conv2d(base_channels * 8, base_channels * 8, 3, stride=2, padding=1)
        self.cross3 = CrossAttention(base_channels * 8, context_dim)

        # Middle
        self.middle1 = ResidualBlock(base_channels * 8, base_channels * 8, self.time_emb_dim, context_dim)
        self.middle_attn = AttentionBlock(base_channels * 8)
        self.middle2 = ResidualBlock(base_channels * 8, base_channels * 8, self.time_emb_dim, context_dim)

        # Decoder
        self.up3 = ResidualBlock(base_channels * 16, base_channels * 4, self.time_emb_dim, context_dim)
        self.upsample3 = nn.ConvTranspose2d(base_channels * 4, base_channels * 4, 4, stride=2, padding=1)
        self.cross_up3 = CrossAttention(base_channels * 4, context_dim)

        self.up2 = ResidualBlock(base_channels * 8, base_channels * 2, self.time_emb_dim, context_dim)
        self.upsample2 = nn.ConvTranspose2d(base_channels * 2, base_channels * 2, 4, stride=2, padding=1)
        self.cross_up2 = CrossAttention(base_channels * 2, context_dim)

        self.up1 = ResidualBlock(base_channels * 4, base_channels, self.time_emb_dim, context_dim)
        self.upsample1 = nn.ConvTranspose2d(base_channels, base_channels, 4, stride=2, padding=1)

        # Output
        self.output_conv = nn.Sequential(
            nn.GroupNorm(min(32, base_channels), base_channels),
            nn.SiLU(),
            nn.Conv2d(base_channels, in_channels, 3, padding=1)
        )

    def forward(self, x, t, context, cfg_scale=1.0):
        t_emb = get_timestep_embedding(t, self.time_emb_dim // 4)
        t_emb = self.time_mlp(t_emb)

        def denoise(x, t_emb, context):
            h = self.input_conv(x)

            # Encoder
            h1 = self.down1(h, t_emb, context)
            h1_cross = self.cross1(h1, context)
            h1_down = self.downsample1(h1_cross)

            h2 = self.down2(h1_down, t_emb, context)
            h2_cross = self.cross2(h2, context)
            h2_down = self.downsample2(h2_cross)

            h3 = self.down3(h2_down, t_emb, context)
            h3_cross = self.cross3(h3, context)
            h3_down = self.downsample3(h3_cross)

            # Middle
            h_mid = self.middle1(h3_down, t_emb, context)
            h_mid = self.middle_attn(h_mid)
            h_mid = self.middle2(h_mid, t_emb, context)

            # Decoder
            h3_cross_resized = F.interpolate(h3_cross, size=h_mid.shape[-2:], mode='nearest')
            h = self.up3(torch.cat([h_mid, h3_cross_resized], dim=1), t_emb, context)
            h = self.upsample3(h)
            h = self.cross_up3(h, context)

            h2_cross_resized = F.interpolate(h2_cross, size=h.shape[-2:], mode='nearest')
            h = self.up2(torch.cat([h, h2_cross_resized], dim=1), t_emb, context)
            h = self.upsample2(h)
            h = self.cross_up2(h, context)

            h1_cross_resized = F.interpolate(h1_cross, size=h.shape[-2:], mode='nearest')
            h = self.up1(torch.cat([h, h1_cross_resized], dim=1), t_emb, context)
            h = self.upsample1(h)

            return self.output_conv(h)

        if cfg_scale == 1.0 or context is None:
            return denoise(x, t_emb, context)

        uncond = denoise(x, t_emb, context=None)
        cond = denoise(x, t_emb, context)
        return uncond + cfg_scale * (cond - uncond)
import torch
from diffusers import AutoencoderKL, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from PIL import Image
import numpy as np
from tqdm import tqdm
import argparse
import sys



def seed_everything(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

def generate_images_direct(unet_path="output/KahabMinGenT2Im-v1.pt", device="cuda", output_dir="output", prompt=None):
    """Generate 256x256 images with a custom UNet and user-specified text prompt"""
    seed_everything(42)
    print(f"Using device: {device}")

 
    print("Loading trained UNet...")
    unet = UNetConditional(in_channels=4, base_channels=128, context_dim=768)
    checkpoint = torch.load(unet_path, map_location=device, weights_only=True)
    unet.load_state_dict(checkpoint['model_state_dict'])
    unet = unet.to(device).eval()

    # Create scheduler
    scheduler = DDPMScheduler(num_train_timesteps=1000)

    # Get prompt from user if not provided
    if prompt is None:
        # Check if running in Jupyter
        if 'ipykernel' in sys.modules:
            prompt = input("Enter your text prompt (e.g., 'A friendly dragon'): ").strip()
        else:
            prompt = ""  # Will be handled by argparse default or user input
        if not prompt:
            prompt = "A friendly dragon"  # Default prompt if empty

    test_prompts = [prompt]

    print("🎨 Generating 256x256 images...")
    for i, prompt in enumerate(test_prompts):
        print(f"Generating: {prompt}")
        try:
            with torch.no_grad():
                # Encode prompt
                inputs = tokenizer(
                    prompt,
                    padding="max_length",
                    truncation=True,
                    max_length=77,
                    return_tensors="pt"
                )
                inputs = {k: v.to(device) for k, v in inputs.items()}
                text_embeddings = text_encoder(**inputs).last_hidden_state
                print(f"Text embeddings shape: {text_embeddings.shape}, device: {text_embeddings.device}")

                # Create random latents for 256x256 output (256/8 = 32 due to VAE scaling)
                latents = torch.randn(1, 4, 32, 32, device=device, dtype=torch.float32)
                print(f"Initial latents shape: {latents.shape}, device: {latents.device}")

                # Set timesteps
                scheduler.set_timesteps(500)

                # Denoising loop
                for t in tqdm(scheduler.timesteps, desc=f"Denoising {prompt}"):
                    t_tensor = torch.tensor([t], device=device, dtype=torch.long)
                    noise_pred = unet(latents, t_tensor, context=text_embeddings)
                    latents = scheduler.step(noise_pred, t, latents).prev_sample

                print(f"Final latents shape: {latents.shape}")

                # Decode latents to image
                latents = latents / 0.18215
                images = vae.decode(latents).sample
                images = (images / 2 + 0.5).clamp(0, 1)  # Denormalize
                images = images.cpu().permute(0, 2, 3, 1).numpy()
                image = Image.fromarray((images[0] * 255).astype(np.uint8))

                # Save
                filename = f"{output_dir}/generated_256_.png"
                image.save(filename)
                print(f"✅ Saved: {filename}")

        except Exception as e:
            print(f"❌ Error generating '{prompt}': {e}")
            print(f"Error type: {type(e).__name__}")
            continue

def main():
    # Check if running in Jupyter
    if 'ipykernel' in sys.modules:
        generate_images_direct(
            unet_path="output/KahabMinGenT2Im-v1.pt",
            device="cuda" if torch.cuda.is_available() else "cpu",
            output_dir="output",
            prompt=None
        )
    else:
        parser = argparse.ArgumentParser(description="Generate images with custom UNet and text prompt")
        parser.add_argument("--unet_path", type=str, default="output/KahabMinGenT2Im-v1.pt", help="Path to UNet checkpoint")
        parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use (cuda or cpu)")
        parser.add_argument("--output_dir", type=str, default="output", help="Output directory for generated images")
        parser.add_argument("--prompt", type=str, default=None, help="Text prompt for image generation")
        args = parser.parse_args()

        generate_images_direct(
            unet_path=args.unet_path,
            device=args.device,
            output_dir=args.output_dir,
            prompt=args.prompt
        )



In [6]:
generate_images_direct(prompt="a dog plyaing")

Using device: cuda
Loading trained UNet...


FileNotFoundError: [Errno 2] No such file or directory: 'output/KahabMinGenT2Im-v1.pt'