# Import Libraries

In [None]:
### Importing basic libraries
import os   
import sys
import torch
import glob
import numpy as np
import matplotlib.pyplot as plt
import healpy as hp
from tqdm.auto import tqdm
from packaging import version
import datetime
from collections import namedtuple
from functools import wraps, partial
from inspect import isfunction

In [None]:
### Importing torch libraries
import torch
import torch.nn as nn
from torch import einsum
import torch.nn.functional as F
import torch.utils.data as data
import pytorch_lightning as pl
from torchvision.transforms import Compose, ToTensor, Normalize
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger


In [None]:
### Importing needed libraries
from einops import rearrange
import pickle

# Prefix Parameters

In [None]:
pl.seed_everything(1234)

In [None]:
### training params
num_epochs = 300
batch_size = 2
learning_rate = 1e-5
learning_rate_decay = 0.99

### diffusion params
timesteps = 1000
beta_start = 0.0001
beta_end = 0.02

### data params
LR_dir = "/gpfs02/work/tanimura/ana/UNet/data/dens_magneticum_snap25_Box128_grid32_CIC_noRSD/"
HR_dir = "/gpfs02/work/tanimura/ana/UNet/data/dens_magneticum_snap25_Box128_grid128_CIC_noRSD/"
mid_dir = "/gpfs02/work/tanimura/ana/UNet/data/dens_magneticum_snap25_Box128_grid64_CIC_noRSD/"

n_maps=100
rate_train =0.8

### model parameters
save_dir = './tests/'

In [None]:
### model parameters
save_dir = '/gpfs02/work/akira.tokiwa/gpgpu/instruction/'
logger = TensorBoardLogger(save_dir=save_dir, name='diffusion')

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Diffusion Code

## Diffusion step
Code snippets ported from:
https://huggingface.co/blog/annotated-diffusion

In [None]:
### Functions for diffusion
def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

def cast_tuple(t, length = 1):
    if isinstance(t, tuple):
        return t
    return ((t,) * length)

def identity(t, *args, **kwargs):
    return t

def cycle(dl):
    while True:
        for data in dl:
            yield data

def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    shape = (batch_size, *((1,) * (len(x_shape) - 1)))
    return out.reshape(shape).to(t.device)

In [None]:
class Diffusion():
    def __init__(self, betas):
        #NOTE: you're choosing to make loss-type an argument of p_losses and not the diffusion model class
        #you should then save it separately when saving a run
        # calculations for diffusion q(x_t | x_0) and others
        self.betas = betas
        self.alphas = 1. - betas
        self.alphas_cumprod = torch.cumprod(self.alphas, axis=0)  # alpha_bar
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
        # x_t = sqrt_alphas_cumprod* x_0 + sqrt_one_minus_alphas_cumprod * eps_t
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        self.timesteps = len(self.betas)

    # forward diffusion (using the nice property)
    def q_sample(self, x_start, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x_start)

        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x_start.shape
        )
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def p_losses(self, denoise_model, x_start, t, noise=None, loss_type="l1", labels=None):
        # L_CE <= L_VLB ~ Sum[eps_t - MODEL(x_t(x_0, eps_t), t) ]
        if noise is None:
            noise = torch.randn_like(x_start)

        x_t = self.q_sample(x_start=x_start, t=t, noise=noise)
        predicted_noise = denoise_model(x_t, t, labels)

        if loss_type == 'l1':
            loss = F.l1_loss(noise, predicted_noise)
        elif loss_type == 'l2':
            loss = F.mse_loss(noise, predicted_noise)
        elif loss_type == "huber":
            loss = F.smooth_l1_loss(noise, predicted_noise)
        else:
            raise NotImplementedError()
        return loss

    @torch.no_grad()
    def timewise_loss(self, denoise_model, x_start, t, noise=None, loss_type="l1", labels=None):
        if noise is None:
            noise = torch.randn_like(x_start)

        x_t = self.q_sample(x_start=x_start, t=t, noise=noise)
        predicted_noise = denoise_model(x_t, t, labels)
        if loss_type == 'l1':
            loss = F.l1_loss(noise, predicted_noise, reduction='none')
        elif loss_type == 'l2':
            loss = F.mse_loss(noise, predicted_noise, reduction='none')
        elif loss_type == "huber":
            loss = F.smooth_l1_loss(noise, predicted_noise, reduction='none')
        else:
            raise NotImplementedError()
        loss = torch.mean(loss, dim=[-3, -2, -1]) #mean over all spatial dims
        return loss

    @torch.no_grad()
    def p_sample(self, model, x, t, t_index, label=None):
        betas_t = extract(self.betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(
            self.sqrt_one_minus_alphas_cumprod, t, x.shape
        )
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)

        # Equation 11 in the paper
        # Use our model (noise predictor) to predict the mean
        model_output = model(x, t) if label is None else model(x, t, label)
        model_mean = sqrt_recip_alphas_t * (
                x - betas_t * model_output / sqrt_one_minus_alphas_cumprod_t
        )

        if t_index == 0:
            return model_mean
        else:
            posterior_variance_t = extract(self.posterior_variance, t, x.shape)
            noise = torch.randn_like(x)
            # Algorithm 2 line 4:
            return model_mean + torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def p_sample_loop(self, model, shape, labels=None):
        device = next(model.parameters()).device
        print('sample device', device)
        b = shape[0]
        # start from pure noise (for each example in the batch)
        img = torch.randn(shape, device=device)
        imgs = []
        if labels is not None:
            assert labels.shape[0] == shape[0]

        for i in tqdm(reversed(range(0, self.timesteps)), desc='sampling loop time step', total=self.timesteps):
            img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i, labels)
            imgs.append(img.cpu().numpy())
        return imgs

    @torch.no_grad()
    def sample(self, model, image_size, batch_size=16, channels=1, labels=None):
        return self.p_sample_loop(model, shape=(batch_size, image_size, channels), labels=labels)

## Diffusion scheduler
based on https://github.com/openai/improved-diffusion/blob/783b6740edb79fdb7d063250db2c51cc9545dcd1/improved_diffusion/resample.py

In [None]:
### Functions for scheduler

def linear_beta_schedule(timesteps, beta_start, beta_end):
    #beta_start = 0.0001
    #beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def quadratic_beta_schedule(timesteps, beta_start, beta_end):
    #beta_start = 0.0001
    #beta_end = 0.02
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps, beta_start, beta_end):
    #beta_start = 0.0001
    #beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start

def cosine_beta_schedule(timesteps, s=0.015):
    """
    cosine schedule as proposed in https://arxiv.org/abs/2102.09672
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0, 0.999) 

In [None]:
class TimestepSampler():
    def __init__(self, sampler_type='uniform', history=None, nstart=None, timesteps=None, uniweight=None, device='cuda'):
        self.type= sampler_type
        print('Sampler type', self.type)
        if self.type not in ['uniform', 'loss_aware']:
            raise NotImplementedError()
        if self.type=='loss_aware':
            self.sqloss_history = torch.ones((history, timesteps), device=device)*np.nan #L^2[b, t]
            self.nstart = nstart
            self.uniweight = 1/timesteps if uniweight is None else uniweight
            print('Nstart', nstart, type(nstart))
            print('History', history, type(history))
            print('Uniweight', self.uniweight, type(self.uniweight))

        self.timesteps = timesteps
        self.uniform = 1/timesteps
        self.device=device
        self.history_per_term = torch.zeros(timesteps, device=device, dtype=int)
        self.not_enough_history = True

    def get_weights(self, batch_size, iteration):
        if (iteration<self.nstart) or self.not_enough_history:
            return np.ones(self.timesteps)*self.uniform
        else:
            laweights = torch.sqrt(torch.mean(self.sqloss_history**2, dim=0))
            laweights /= laweights.sum()
            laweights *= (1-self.uniweight)
            laweights += self.uniweight/self.timesteps
            return laweights
        #fast way to evaluate / store the loss for different timesteps??
        #do you need a different sampler for storing history?

    def update_history(self, tl, loss_timewise):
        if self.not_enough_history:  # not-full loss history array
            for (t, tloss) in zip(tl, loss_timewise):
                if self.history_per_term[t] == self.sqloss_history.shape[0]:  # enough history
                    self.sqloss_history[:-1, t] = self.sqloss_history[1:, t]
                    self.sqloss_history[-1, t] = tloss
                else:
                    self.sqloss_history[self.history_per_term[t], t] = tloss
                    self.history_per_term[t] += 1
                    if self.history_per_term.min()==self.sqloss_history.shape[0]:
                        self.not_enough_history = False
                        print('Enough history for all')
        else:#enough history for all terms
            #test if this works fine
            self.sqloss_history[:-1, tl] = self.sqloss_history[1:, tl]
            self.sqloss_history[-1, tl] = loss_timewise
        return

    def get_timesteps(self, batch_size, iteration):
        if self.type=='uniform':
            return torch.randint(0, self.timesteps, (batch_size,), device=self.device).long()
        elif self.type=='loss_aware':
            weights = self.get_weights(batch_size, iteration)
            return torch.tensor(list(torch.utils.data.WeightedRandomSampler(weights, batch_size, replacement=True)), device=self.device).long()
        else:
            raise NotImplementedError()

# U-net Architecture 

In [None]:
def Upsample(dim, dim_out = None, kernel_size=3):
    return nn.Sequential(
        nn.Upsample(scale_factor = 2, mode='trilinear'),
        nn.Conv3d(dim, default(dim_out, dim), kernel_size, padding=1)
    )

"""
def Downsample(dim, dim_out = None, kernel_size=1):
    return nn.Sequential(
        partial(rearrange, pattern='b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w', p1 = 2, p2 = 2, p3 = 2),
        nn.Conv3d(dim * 4, default(dim_out, dim), kernel_size)
    )
"""

class Downsample(nn.Module):
    def __init__(self, dim, dim_out = None, kernel_size=1):
        super().__init__()
        self.conv = nn.Conv3d(dim*4, default(dim_out, dim), kernel_size)

    def forward(self, x):
        x = rearrange(x, 'b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w', p1 = 2, p2 = 2, p3 = 2)
        x = self.conv(x)
        return x
    

In [None]:
class RMSNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))

    def forward(self, x):
        return F.normalize(x, dim = 1) * self.g * (x.shape[1] ** 0.5)

In [None]:
AttentionConfig = namedtuple('AttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)
class Attend(nn.Module):
    def __init__(
        self,
        dropout = 0.,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.flash = flash
        assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = AttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = AttentionConfig(False, True, True)

    def flash_attn(self, q, k, v):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

        q, k, v = map(lambda t: t.contiguous(), (q, k, v))

        # Check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p = self.dropout if self.training else 0.
            )

        return out

    def forward(self, q, k, v):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """

        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        if self.flash:
            return self.flash_attn(q, k, v)

        scale = q.shape[-1] ** -0.5

        # similarity

        sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale

        # attention

        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate values

        out = einsum(f"b h i j, b h j d -> b h i d", attn, v)

        return out

In [None]:
class SinusoidalPositionEmbeddings(nn.Module):
    #embeds time in the phase
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = np.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None].float() * embeddings[None, :] #t1: [40, 1], t2: [1, 32]. Works on cpu, not on mps
        #^ is matmul: torch.allclose(res, torch.matmul(t1.float(), t2)): True when cpu
        #NM: added float for mps
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings #Bx64

In [None]:
class Block(nn.Module):
    """
    Basic building block for the Unet architecture.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, num_groups=8):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding)
        self.norm = nn.GroupNorm(num_groups, out_channels)
        self.act = nn.LeakyReLU(0.1) if out_channels > 1 else nn.Identity()

    def forward(self, x, scale_shift = None):
        x = self.conv(x)
        x = self.norm(x) 

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

In [None]:
class ResnetBlock(nn.Module):
    """
    Residual block composed of two basic blocks. https://arxiv.org/abs/1512.03385
    """
    def __init__(self, in_channels, out_channels, time_emb_dim, kernel_size=3, padding=1, groups = 8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_channels * 2))
            if exists(time_emb_dim)
            else None
        )
        self.block1 = Block(in_channels, out_channels, kernel_size, padding, groups)
        self.block2 = Block(out_channels, out_channels, kernel_size, padding, groups) 
        self.res_conv = nn.Conv3d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()


    def forward(self, x, time_emb = None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
            scale_shift = time_emb.chunk(2, dim = 1)

        h = self.block1(x, scale_shift = scale_shift)

        h = self.block2(h)

        return h + self.res_conv(x)

In [None]:
class LinearAttention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 32
    ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Conv3d(dim, hidden_dim * 3, 1, bias = False)

        self.to_out = nn.Sequential(
            nn.Conv3d(hidden_dim, dim, 1),
            RMSNorm(dim)
        )

    def forward(self, x):
        b, c, d, h, w = x.shape

        #x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z -> b h c (x y z)', h = self.heads), qkv)

        q = q.softmax(dim = -2)
        k = k.softmax(dim = -1)

        q = q * self.scale

        context = torch.einsum('b h d n, b h e n -> b h d e', k, v)

        out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
        out = rearrange(out, 'b h c (x y z) -> b (h c) x y z', h = self.heads, x = d, y = h, z = w)
        return self.to_out(out)

In [None]:
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 4,
        dim_head = 32,
        flash = False
    ):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads

        self.norm = RMSNorm(dim)
        self.attend = Attend(flash = flash)

        self.to_qkv = nn.Conv3d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv3d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, d, h, w = x.shape

        #x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z -> b h (x y z) c', h = self.heads), qkv)

        out = self.attend(q, k, v)

        out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = d, y = h, z = w, h = self.heads)
        return self.to_out(out)

In [None]:
class Unet(pl.LightningModule):
    """
    Full Unet architecture composed of an encoder (downsampler), a bottleneck, and a decoder (upsampler).
    """
    def __init__(
        self,
        dim,
        init_dim = None,
        out_dim = None,
        dim_mults = (1, 2, 4),
        channels = 3,
        self_condition = False,
        resnet_block_groups = 8,
        full_attn = (False, False, True),
        flash_attn = False
    ):
        super().__init__()

        # determine dimensions

        self.channels = channels
        self.self_condition = self_condition
        input_channels =  channels * (2 if self_condition else 1)

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv3d(input_channels, init_dim, 7, padding = 3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups = resnet_block_groups)

        # time embeddings

        time_dim = dim * 4
        fourier_dim = dim

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(fourier_dim),
            nn.Linear(fourier_dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        # attention

        full_attn = cast_tuple(full_attn, length = len(dim_mults))
        assert len(full_attn) == len(dim_mults)

        FullAttention = partial(Attention, flash = flash_attn)

        # layers

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, ((dim_in, dim_out), layer_full_attn) in enumerate(zip(in_out, full_attn)):
            is_last = ind >= (num_resolutions - 1)

            attn_klass = FullAttention if layer_full_attn else LinearAttention

            self.downs.append(nn.ModuleList([
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                block_klass(dim_in, dim_in, time_emb_dim = time_dim),
                attn_klass(dim_in),
                Downsample(dim_in, dim_out) if not is_last else nn.Conv3d(dim_in, dim_out, 3, padding = 1)
            ]))

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)
        self.mid_attn = FullAttention(mid_dim)
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim)

        for ind, ((dim_in, dim_out), layer_full_attn) in enumerate(zip(reversed(in_out), reversed(full_attn))):
            is_last = ind == (len(in_out) - 1)

            attn_klass = FullAttention if layer_full_attn else LinearAttention

            self.ups.append(nn.ModuleList([
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim),
                attn_klass(dim_out),
                Upsample(dim_out, dim_in) if not is_last else  nn.Conv3d(dim_out, dim_in, 3, padding = 1)
            ]))

        default_out_dim = channels * 2
        self.out_dim = default(out_dim, default_out_dim)

        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim)
        self.final_conv = nn.Conv3d(dim, self.out_dim, 1)

    def forward(self, x, time, x_self_cond=None):
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x, x_self_cond), dim = 1)
            
        x = self.init_conv(x)
        r = x.clone()

        t = self.time_mlp(time)

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            print("x_block1", x.shape)
            h.append(x)

            x = block2(x, t)
            print("x_block2", x.shape)
            #tmp_attn = attn(x)
            #print("tmp_attn", tmp_attn.shape)
            #x = tmp_attn + x
            #print("x_attn", x.shape)
            h.append(x)

            x = downsample(x)
            print("x_downsample", x.shape)

        x = self.mid_block1(x, t)
        print("x_mid_block1", x.shape)
        #x = self.mid_attn(x) + x
        x = self.mid_block2(x, t)
        print("x_mid_block2", x.shape)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim = 1)
            print("x_cat", x.shape)
            x = block1(x, t)
            print("x_block1", x.shape)

            x = torch.cat((x, h.pop()), dim = 1)
            print("x_cat", x.shape)
            x = block2(x, t)
            print("x_block2", x.shape)
            #x = attn(x) + x

            x = upsample(x)
            print("x_upsample", x.shape)

        x = torch.cat((x, r), dim = 1)

        x = self.final_res_block(x, t)
        return self.final_conv(x)

In [None]:
class Unet_pl(pl.LightningModule):
    def __init__(self, 
                channels = 1,  
                dim = 64,
                init_dim = None,
                out_dim = None,
                batch_size = 64,
                learning_rate = 1e-5,
                learning_rate_decay = 0.99,
                num_epochs = 300,
                timesteps = 1000,
                beta_start = 0.0001,
                beta_end = 0.02,
                loss_type="huber", 
                sampler=None, 
                conditional=False):
        super().__init__()
        self.model = Unet(dim=dim, channels=dim, init_dim=init_dim, out_dim=out_dim, self_condition=conditional)
        self.batch_size = batch_size
        self.lr_init = learning_rate
        self.learning_rate_decay = learning_rate_decay
        self.num_epochs = num_epochs

        betas = linear_beta_schedule(timesteps, beta_start, beta_end)
        self.diffusion = Diffusion(betas)

        self.init_conv = nn.Conv3d(channels, dim, 7, padding = 3)
        self.init_conv_condition = nn.Sequential(
            Upsample(channels, dim),
            Upsample(dim, dim)
        )

        self.loss_type = loss_type
        self.sampler = sampler
        self.conditional = conditional
        self.loss_spike_flg = 0

    def training_step(self, batch, batch_idx):
        if self.conditional:
            hr, lr = batch 
            hr = self.init_conv(hr)
            lr = self.init_conv_condition(lr)
            x = hr -lr
            labels = lr
        else:
            x = batch

        t = self.sampler.get_timesteps(x.shape[0], self.current_epoch)
        loss = self.diffusion.p_losses(self.model, x, t, loss_type=self.loss_type, labels=labels if self.conditional else None)
        self.log('train_loss', loss)

        if self.sampler.type == 'loss_aware':
            loss_timewise = self.diffusion.timewise_loss(self.model, x, t, loss_type=self.loss_type, labels=labels if self.conditional else None)
            self.sampler.update_history(t, loss_timewise)

        if loss.item() > 0.1 and self.current_epoch > 300 and (self.loss_spike_flg < 2):
            badbdict = {'batch': batch.detach().cpu().numpy(), 'itn': self.current_epoch, 't': t.detach().cpu().numpy(), 'loss': loss.item()}
            pickle.dump(badbdict, open(f'largeloss_{self.current_epoch}.pkl', 'wb'))
            self.loss_spike_flg += 1
        return loss
    
    def validation_step(self, batch, batch_idx):
        if self.conditional:
            hr, lr = batch 
            lr = self.init_conv_condition(lr)
            x = hr - lr
            labels = lr
        else:
            x = batch

        t = self.sampler.get_timesteps(x.shape[0], self.current_epoch)
        loss = self.diffusion.p_losses(self.model, x, t, loss_type=self.loss_type, labels=labels if self.conditional else None)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr_init)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=self.learning_rate_decay)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}  

# prepare data

## Functions

In [None]:
def get_minmax_transform(rangemin, rangemax):
    """
    Function to get a pair of transforms that normalize and denormalize tensors.
    """
    transform = Compose([lambda t: (t - rangemin) / (rangemax - rangemin) * 2 - 1])
    inverse_transform = Compose([lambda t: (t + 1) / 2 * (rangemax - rangemin) + rangemin])
    
    return transform, inverse_transform

In [None]:
class MapDataset(data.Dataset):
    """
    Class for the map dataset.
    """
    def __init__(self, mapdir, n_maps=100):
        self.maps = sorted(glob.glob(f'{mapdir}*.npy'))[:n_maps]

    def __getitem__(self, index):
        dmaps = np.array([np.load(map) for map in self.maps])  
        tensor_map = torch.from_numpy(dmaps).float()
        
        return tensor_map

## Data

In [24]:
data_hr = MapDataset(HR_dir, n_maps=10).__getitem__(0)
data_lr = MapDataset(LR_dir, n_maps=10).__getitem__(0)

In [25]:
RANGE_MIN, RANGE_MAX = data_hr.min().clone().detach(), data_hr.max().clone().detach()
transforms, inverse_transforms = get_minmax_transform(RANGE_MIN, RANGE_MAX)
combined_dataset = data.TensorDataset(transforms(data_hr).unsqueeze(1), transforms(data_lr).unsqueeze(1))

In [26]:
len_train = int(rate_train * len(data_lr))
len_val = len(data_lr) - len_train
train, val = data.random_split(combined_dataset, [len_train, len_val])
loaders = {x: data.DataLoader(ds, batch_size=batch_size, shuffle=x=='train', num_workers=os.cpu_count()) for x, ds in zip(('train', 'val'), (train, val))}

train_loader, val_loader= loaders['train'], loaders['val']

# Train Model

In [None]:
!pwd

In [27]:
### Functions for training
def setup_trainer(num_epochs, logger=None, patience=10):
    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        patience=patience,
        verbose=0,
        mode="min"
    )

    dt = datetime.datetime.now()
    name = dt.strftime('Run_%m-%d_%H-%M')

    checkpoint_callback = ModelCheckpoint(
        filename= name + "{epoch:02d}-{val_loss:.2f}",
        save_top_k=1,
        monitor="val_loss",
        save_last=True,
        mode="min"
    )

    trainer = pl.Trainer(
        max_epochs=num_epochs,
        callbacks=[checkpoint_callback, early_stop_callback],
        num_sanity_val_steps=0,
        accelerator='gpu', 
        devices=1,
        logger=logger
    )
    return trainer

In [28]:
sampler = TimestepSampler(sampler_type='uniform', timesteps=timesteps, device=device)

Sampler type uniform


In [29]:
model = Unet_pl(dim = 64,
                init_dim = 64,
                out_dim = 1,
                batch_size = 2,
                learning_rate = 1e-5,
                learning_rate_decay = 0.99,
                num_epochs = 300,
                timesteps = 1000,
                beta_start = 0.0001,
                beta_end = 0.02,
                loss_type="huber", 
                sampler=sampler, 
                conditional=True).to(device)

In [32]:
model

Unet_pl(
  (model): Unet(
    (init_conv): Conv3d(128, 64, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3))
    (time_mlp): Sequential(
      (0): SinusoidalPositionEmbeddings()
      (1): Linear(in_features=64, out_features=256, bias=True)
      (2): GELU(approximate='none')
      (3): Linear(in_features=256, out_features=256, bias=True)
    )
    (downs): ModuleList(
      (0): ModuleList(
        (0-1): 2 x ResnetBlock(
          (mlp): Sequential(
            (0): SiLU()
            (1): Linear(in_features=256, out_features=128, bias=True)
          )
          (block1): Block(
            (conv): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (norm): GroupNorm(8, 64, eps=1e-05, affine=True)
            (act): LeakyReLU(negative_slope=0.1)
          )
          (block2): Block(
            (conv): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            (norm): GroupNorm(8, 64, eps=1e-05, affine=True)


In [31]:
trainer = setup_trainer(num_epochs, logger=logger)
trainer.fit(model, train_loader, val_loader)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]

  | Name                | Type       | Params
---------------------------------------------------
0 | model               | Unet       | 28.6 M
1 | init_conv           | Conv3d     | 22.0 K
2 | init_conv_condition | Sequential | 112 K 
---------------------------------------------------
28.8 M    Trainable params
0         Non-trainable params
28.8 M    Total params
115.045   Total estimated model params size (MB)
  rank_zero_warn(


Epoch 0:   0%|          | 0/5 [00:00<?, ?it/s] x_block1 torch.Size([2, 64, 128, 128, 128])
x_block2 torch.Size([2, 64, 128, 128, 128])


RuntimeError: Given groups=1, weight of size [64, 256, 1, 1, 1], expected input[2, 512, 64, 64, 64] to have 256 channels, but got 512 channels instead

# Load Model

In [None]:
map_dir = '/gpfs02/work/akira.tokiwa/gpgpu/instruction/maps/'
if not os.path.exists(map_dir):
    os.makedirs(map_dir)

In [None]:
ckpt_path = trainer.checkpoint_callback.best_model_path

In [None]:
model.load_state_dict(torch.load(ckpt_path)['state_dict'], strict=False)

_IncompatibleKeys(missing_keys=['model.init_conv.laplacian', 'model.init_conv_lr.laplacian', 'model.down_blocks.0.0.block1.conv.laplacian', 'model.down_blocks.0.0.block2.conv.laplacian', 'model.down_blocks.0.1.block1.conv.laplacian', 'model.down_blocks.0.1.block2.conv.laplacian', 'model.down_blocks.1.0.block1.conv.laplacian', 'model.down_blocks.1.0.block2.conv.laplacian', 'model.down_blocks.1.0.res_conv.laplacian', 'model.down_blocks.1.1.block1.conv.laplacian', 'model.down_blocks.1.1.block2.conv.laplacian', 'model.down_blocks.2.0.block1.conv.laplacian', 'model.down_blocks.2.0.block2.conv.laplacian', 'model.down_blocks.2.0.res_conv.laplacian', 'model.down_blocks.2.1.block1.conv.laplacian', 'model.down_blocks.2.1.block2.conv.laplacian', 'model.down_blocks.3.0.block1.conv.laplacian', 'model.down_blocks.3.0.block2.conv.laplacian', 'model.down_blocks.3.0.res_conv.laplacian', 'model.down_blocks.3.1.block1.conv.laplacian', 'model.down_blocks.3.1.block2.conv.laplacian', 'model.mid_block1.block

In [None]:
model = model.to(device)

In [None]:
i = 0
tmp_sample =combined_dataset.tensors[0][i].to(device)
tmp_lr = combined_dataset.tensors[1][i].to(device)
q_sample = model.diffusion.q_sample(tmp_sample, torch.full((1,), timesteps-1, device=device))
img = torch.randn(tmp_sample.shape, device=device)
for j in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
    t = torch.full((1,), j, device=device, dtype=torch.long)
    model.eval()
    with torch.inference_mode():
        loss = model.diffusion.p_losses(model.model, tmp_sample, t, tmp_lr)
        img = model.diffusion.p_sample(model.model, img, t, tmp_lr, j)
    print('Step {}, Loss {}'.format(j, loss), flush=True)
    if j % 10 == 0:
        diffmap = img.detach().cpu().numpy()
        np.save(map_dir+"diffused_step{}.npy".format(j), diffmap)
        print('Saved step {}'.format(j), flush=True)



sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s][A[A

Step 999, Loss 0.341176837682724



[A

Step 998, Loss 0.33622345328330994



[A

Step 997, Loss 0.33832472562789917



[A

Step 996, Loss 0.3400003910064697



[A

Step 995, Loss 0.3383110463619232



[A

Step 994, Loss 0.3426855802536011



[A

Step 993, Loss 0.34216225147247314



[A

Step 992, Loss 0.3374362587928772



[A

Step 991, Loss 0.34147679805755615



[A

Step 990, Loss 0.3398386240005493



[A

Step 989, Loss 0.33934730291366577



[A

Step 988, Loss 0.3400684595108032



[A

Step 987, Loss 0.3446684777736664



[A

Step 986, Loss 0.34072956442832947



[A

Step 985, Loss 0.3388824462890625



[A

Step 984, Loss 0.34096765518188477



[A

Step 983, Loss 0.3392423093318939



[A

Step 982, Loss 0.3387162387371063



[A

Step 981, Loss 0.34063833951950073



[A

Step 980, Loss 0.3432810306549072



[A

Step 979, Loss 0.33851179480552673



[A

Step 978, Loss 0.3386378288269043



[A

Step 977, Loss 0.3411961495876312



[A

Step 976, Loss 0.3417409658432007



[A

Step 975, Loss 0.3418332040309906



[A

Step 974, Loss 0.34300553798675537



[A

Step 973, Loss 0.33862385153770447



[A

Step 972, Loss 0.338980495929718



[A

Step 971, Loss 0.3384357690811157



[A

Step 970, Loss 0.3433971107006073



[A

Step 969, Loss 0.342424213886261



[A

Step 968, Loss 0.34020060300827026



[A

Step 967, Loss 0.3384147584438324



[A

Step 966, Loss 0.3361092507839203



[A

Step 965, Loss 0.3396255373954773



[A

Step 964, Loss 0.34149226546287537



[A

Step 963, Loss 0.3404814600944519



[A

Step 962, Loss 0.3462550640106201



[A

Step 961, Loss 0.3471677005290985



[A

Step 960, Loss 0.3461998403072357



[A

Step 959, Loss 0.34591126441955566



[A

Step 958, Loss 0.34132057428359985



[A

Step 957, Loss 0.3411778211593628



[A

Step 956, Loss 0.3387313783168793



[A

Step 955, Loss 0.33856701850891113



[A

Step 954, Loss 0.33774709701538086



[A

Step 953, Loss 0.3360164761543274



[A

Step 952, Loss 0.3410879075527191



[A

Step 951, Loss 0.3436765670776367



[A

Step 950, Loss 0.34556031227111816



[A

Step 949, Loss 0.3396762013435364



[A

Step 948, Loss 0.3379553258419037



[A

Step 947, Loss 0.3404843807220459



[A

Step 946, Loss 0.3389289379119873



[A

Step 945, Loss 0.3396189212799072



[A

Step 944, Loss 0.3406997621059418



[A

Step 943, Loss 0.3403078019618988



[A

Step 942, Loss 0.3346518874168396



[A

Step 941, Loss 0.3406488597393036



[A

Step 940, Loss 0.3402664065361023



[A

Step 939, Loss 0.3430248200893402



[A

Step 938, Loss 0.3407890796661377



[A

Step 937, Loss 0.3482699990272522



[A

Step 936, Loss 0.3491508960723877



[A

Step 935, Loss 0.3527534604072571



[A

Step 934, Loss 0.34452497959136963



[A

Step 933, Loss 0.33697009086608887



[A

Step 932, Loss 0.33765971660614014



[A

Step 931, Loss 0.3402830958366394



[A

Step 930, Loss 0.34370145201683044



[A

Step 929, Loss 0.34100717306137085



[A

Step 928, Loss 0.34054097533226013



[A

Step 927, Loss 0.3396420478820801



[A

Step 926, Loss 0.3391225337982178



[A

Step 925, Loss 0.33968544006347656



[A

Step 924, Loss 0.3419671952724457



[A

Step 923, Loss 0.3476843237876892



[A

Step 922, Loss 0.3493131995201111



[A

Step 921, Loss 0.3496004343032837



[A

Step 920, Loss 0.3454461097717285



[A

Step 919, Loss 0.3439213037490845



[A

Step 918, Loss 0.34073710441589355



[A

Step 917, Loss 0.33633753657341003



[A

Step 916, Loss 0.33505257964134216



[A

Step 915, Loss 0.33627891540527344



[A

Step 914, Loss 0.3374924063682556



[A

Step 913, Loss 0.34092292189598083



[A

Step 912, Loss 0.33878397941589355



[A

Step 911, Loss 0.3393421769142151



[A

Step 910, Loss 0.33812254667282104



[A

Step 909, Loss 0.3370034694671631



[A

Step 908, Loss 0.3380507826805115



[A

Step 907, Loss 0.3363484740257263



[A

Step 906, Loss 0.3388884961605072



[A

Step 905, Loss 0.3450305461883545



[A

Step 904, Loss 0.3412679135799408



[A

Step 903, Loss 0.34071168303489685



[A

Step 902, Loss 0.3395170569419861



[A

Step 901, Loss 0.34002935886383057



[A

Step 900, Loss 0.3388676643371582



[A

Step 899, Loss 0.339433491230011



[A

Step 898, Loss 0.33822101354599



[A

Step 897, Loss 0.33663225173950195



[A

Step 896, Loss 0.3365822732448578



[A

Step 895, Loss 0.3362085521221161



[A

Step 894, Loss 0.3390604853630066



[A

Step 893, Loss 0.3423558473587036



[A

Step 892, Loss 0.33855676651000977



[A

Step 891, Loss 0.337406724691391



[A

Step 890, Loss 0.33731478452682495



[A

Step 889, Loss 0.3372548818588257



[A

Step 888, Loss 0.34385859966278076



[A

Step 887, Loss 0.3377331495285034



[A

Step 886, Loss 0.34204035997390747



[A

Step 885, Loss 0.3380011022090912



[A

Step 884, Loss 0.3356381058692932



[A

Step 883, Loss 0.33567190170288086



[A

Step 882, Loss 0.3361177146434784



[A

Step 881, Loss 0.3391723036766052



[A

Step 880, Loss 0.3435870409011841



[A

Step 879, Loss 0.33768290281295776



[A

Step 878, Loss 0.3409147262573242



[A

Step 877, Loss 0.3378428816795349



[A

Step 876, Loss 0.335845410823822



[A

Step 875, Loss 0.33814844489097595



[A

Step 874, Loss 0.3344738483428955



[A

Step 873, Loss 0.337877482175827



[A

Step 872, Loss 0.3387138843536377



[A

Step 871, Loss 0.3381115794181824



[A

Step 870, Loss 0.33780455589294434



[A

Step 869, Loss 0.3380551338195801



[A

Step 868, Loss 0.34178489446640015



[A

Step 867, Loss 0.3395804166793823



[A

Step 866, Loss 0.336057186126709



[A

Step 865, Loss 0.33697545528411865



[A

Step 864, Loss 0.34075912833213806



[A

Step 863, Loss 0.3352383077144623



[A

Step 862, Loss 0.33893272280693054



[A

Step 861, Loss 0.34625524282455444



[A

Step 860, Loss 0.3459366261959076



[A

Step 859, Loss 0.3446551263332367



[A

Step 858, Loss 0.34497612714767456



[A

Step 857, Loss 0.3421752452850342



[A

Step 856, Loss 0.3407459259033203



[A

Step 855, Loss 0.33676955103874207



[A

Step 854, Loss 0.33641165494918823



[A

Step 853, Loss 0.33517593145370483



[A

Step 852, Loss 0.33433806896209717



[A

Step 851, Loss 0.33514994382858276



[A

Step 850, Loss 0.33539408445358276



[A

Step 849, Loss 0.33578357100486755



[A

Step 848, Loss 0.3334846496582031



[A

Step 847, Loss 0.33441147208213806



[A

Step 846, Loss 0.3335418701171875



[A

Step 845, Loss 0.3354220390319824



[A

Step 844, Loss 0.34046047925949097



[A

Step 843, Loss 0.339881032705307



[A

Step 842, Loss 0.3373025357723236



[A

Step 841, Loss 0.33646059036254883



[A

Step 840, Loss 0.3374571204185486



[A

Step 839, Loss 0.341079443693161



[A

Step 838, Loss 0.3392149806022644



[A

Step 837, Loss 0.3403313159942627



[A

Step 836, Loss 0.33804023265838623



[A

Step 835, Loss 0.3367322087287903



[A

Step 834, Loss 0.33811816573143005



[A

Step 833, Loss 0.3379662036895752



[A

Step 832, Loss 0.33879756927490234



[A

Step 831, Loss 0.33863377571105957



[A

Step 830, Loss 0.3364807069301605



[A

Step 829, Loss 0.33678120374679565



[A

Step 828, Loss 0.3367980122566223



[A

Step 827, Loss 0.3370339870452881



[A

Step 826, Loss 0.336214542388916



[A

Step 825, Loss 0.33572185039520264



[A

Step 824, Loss 0.33633124828338623



[A

Step 823, Loss 0.3364127576351166



[A

Step 822, Loss 0.3379003405570984



[A

Step 821, Loss 0.33751600980758667



[A

Step 820, Loss 0.33801981806755066



[A

Step 819, Loss 0.3334479331970215



[A

Step 818, Loss 0.34034380316734314



[A

Step 817, Loss 0.3402749001979828



[A

Step 816, Loss 0.3392210006713867



[A

Step 815, Loss 0.34085482358932495



[A

Step 814, Loss 0.3413517475128174



[A

Step 813, Loss 0.33828848600387573



[A

Step 812, Loss 0.3349343538284302



[A

Step 811, Loss 0.3360392451286316



[A

Step 810, Loss 0.3353254795074463



[A

Step 809, Loss 0.3360246419906616



[A

Step 808, Loss 0.33440670371055603



[A

Step 807, Loss 0.33838605880737305



[A

Step 806, Loss 0.3399772644042969



[A

Step 805, Loss 0.3352971076965332



[A

Step 804, Loss 0.33756789565086365



[A

Step 803, Loss 0.33623838424682617



[A

Step 802, Loss 0.33747124671936035



[A

Step 801, Loss 0.3343507647514343



[A

Step 800, Loss 0.33514404296875



[A

Step 799, Loss 0.33383429050445557



[A

Step 798, Loss 0.33674880862236023



[A

Step 797, Loss 0.33924493193626404



[A

Step 796, Loss 0.33858275413513184



[A

Step 795, Loss 0.34372419118881226



[A

Step 794, Loss 0.332633376121521



[A

Step 793, Loss 0.3335488736629486



[A

Step 792, Loss 0.33333173394203186



[A

Step 791, Loss 0.33234232664108276



[A

Step 790, Loss 0.33505138754844666



[A

Step 789, Loss 0.3349991738796234



[A

Step 788, Loss 0.33439335227012634



[A

Step 787, Loss 0.3332110047340393



[A

Step 786, Loss 0.3333038091659546



[A

Step 785, Loss 0.33477306365966797



[A

Step 784, Loss 0.33177971839904785



[A

Step 783, Loss 0.3362722396850586



[A

Step 782, Loss 0.3356945514678955



[A

Step 781, Loss 0.3358950614929199



[A

Step 780, Loss 0.3323124051094055



[A

Step 779, Loss 0.3313346207141876



[A

Step 778, Loss 0.3317425847053528



[A

Step 777, Loss 0.33251819014549255



[A

Step 776, Loss 0.3334903419017792



[A

Step 775, Loss 0.3338756561279297



[A

Step 774, Loss 0.337207168340683



[A

Step 773, Loss 0.33331820368766785



[A

Step 772, Loss 0.33441227674484253



[A

Step 771, Loss 0.33484625816345215



[A

Step 770, Loss 0.3364109396934509



[A

Step 769, Loss 0.3321871757507324



[A

Step 768, Loss 0.33478015661239624



[A

Step 767, Loss 0.3355928659439087



[A

Step 766, Loss 0.33529338240623474



[A

Step 765, Loss 0.33365321159362793



[A

Step 764, Loss 0.33413296937942505



[A

Step 763, Loss 0.3348512351512909



[A

Step 762, Loss 0.33888256549835205



[A

Step 761, Loss 0.3380705714225769



[A

Step 760, Loss 0.33890223503112793



[A

Step 759, Loss 0.33765825629234314



[A

Step 758, Loss 0.3356688618659973



[A

Step 757, Loss 0.33484411239624023



[A

Step 756, Loss 0.3344026505947113



[A

Step 755, Loss 0.33745336532592773



[A

Step 754, Loss 0.34220609068870544



[A

Step 753, Loss 0.34464138746261597



[A

Step 752, Loss 0.35065650939941406



[A

Step 751, Loss 0.34714555740356445



[A

Step 750, Loss 0.3460221290588379



[A

Step 749, Loss 0.33593565225601196



[A

Step 748, Loss 0.33344554901123047



[A

Step 747, Loss 0.33705782890319824



[A

Step 746, Loss 0.3351998031139374



[A

Step 745, Loss 0.3369563817977905



[A

Step 744, Loss 0.3333759903907776



[A

Step 743, Loss 0.33375489711761475



[A

Step 742, Loss 0.3342721462249756



[A

Step 741, Loss 0.3345921039581299



[A

Step 740, Loss 0.33575379848480225



[A

Step 739, Loss 0.3335968852043152



[A

Step 738, Loss 0.33645230531692505



[A

Step 737, Loss 0.3352808952331543



[A

Step 736, Loss 0.33388084173202515



[A

Step 735, Loss 0.3338603377342224



[A

Step 734, Loss 0.33576077222824097



[A

Step 733, Loss 0.338758647441864



[A

Step 732, Loss 0.33359643816947937



[A

Step 731, Loss 0.3347809314727783



[A

Step 730, Loss 0.3328700065612793



[A

Step 729, Loss 0.3354741930961609



[A

Step 728, Loss 0.3341636061668396



[A

Step 727, Loss 0.33819088339805603



[A

Step 726, Loss 0.33537954092025757



[A

Step 725, Loss 0.3348996639251709



[A

Step 724, Loss 0.33331459760665894



[A

Step 723, Loss 0.33624300360679626



[A

Step 722, Loss 0.33572137355804443



[A

Step 721, Loss 0.33276480436325073



[A

Step 720, Loss 0.3377866744995117



[A

Step 719, Loss 0.33624088764190674



[A

Step 718, Loss 0.3348715901374817



[A

Step 717, Loss 0.3349688947200775



[A

Step 716, Loss 0.33413586020469666



[A

Step 715, Loss 0.33634087443351746



[A

Step 714, Loss 0.33830317854881287



[A

Step 713, Loss 0.3349853456020355



[A

Step 712, Loss 0.3354746103286743



[A

Step 711, Loss 0.3350636959075928



[A

Step 710, Loss 0.33531898260116577



[A

Step 709, Loss 0.33537325263023376



[A

Step 708, Loss 0.3380258083343506



[A

Step 707, Loss 0.33440348505973816



[A

Step 706, Loss 0.3350193500518799



[A

Step 705, Loss 0.33548030257225037



[A

Step 704, Loss 0.3339807987213135



[A

Step 703, Loss 0.3381747603416443



[A

Step 702, Loss 0.33581265807151794



[A

Step 701, Loss 0.33431100845336914



[A

Step 700, Loss 0.3361392915248871



[A

Step 699, Loss 0.3395026624202728



[A

Step 698, Loss 0.3452836275100708



[A

Step 697, Loss 0.33740782737731934



[A

Step 696, Loss 0.33654269576072693



[A

Step 695, Loss 0.33832883834838867



[A

Step 694, Loss 0.34123682975769043



[A

Step 693, Loss 0.34032031893730164



[A

Step 692, Loss 0.33514827489852905



[A

Step 691, Loss 0.33443817496299744



[A

Step 690, Loss 0.333039253950119



[A

Step 689, Loss 0.3375348448753357



[A

Step 688, Loss 0.3363809287548065



[A

Step 687, Loss 0.33865150809288025



[A

Step 686, Loss 0.34110674262046814



[A

Step 685, Loss 0.34315335750579834



[A

Step 684, Loss 0.3440260589122772



[A

Step 683, Loss 0.3401539921760559



[A

Step 682, Loss 0.3360365033149719



[A

Step 681, Loss 0.33776193857192993



[A

Step 680, Loss 0.33929765224456787



[A

Step 679, Loss 0.3395094871520996



[A

Step 678, Loss 0.3455495834350586



[A

Step 677, Loss 0.3487134575843811



[A

Step 676, Loss 0.34495383501052856



[A

Step 675, Loss 0.339169979095459



[A

Step 674, Loss 0.33897802233695984



[A

Step 673, Loss 0.3345661163330078



[A

Step 672, Loss 0.33644363284111023



[A

Step 671, Loss 0.33134588599205017



[A

Step 670, Loss 0.33605870604515076



[A

Step 669, Loss 0.33545932173728943



[A

Step 668, Loss 0.33641475439071655



[A

Step 667, Loss 0.3407156467437744



[A

Step 666, Loss 0.3380245566368103



[A

Step 665, Loss 0.3344898223876953



[A

Step 664, Loss 0.3339179456233978



[A

Step 663, Loss 0.3338772654533386



[A

Step 662, Loss 0.333406537771225



[A

Step 661, Loss 0.33751219511032104



[A

Step 660, Loss 0.3397075831890106



[A

Step 659, Loss 0.33524996042251587



[A

Step 658, Loss 0.3357071280479431



[A

Step 657, Loss 0.33494263887405396



[A

Step 656, Loss 0.33749669790267944



[A

Step 655, Loss 0.3397529125213623



[A

Step 654, Loss 0.3400517404079437



[A

Step 653, Loss 0.3338031768798828



[A

Step 652, Loss 0.3321954905986786



[A

Step 651, Loss 0.3356492519378662



[A

Step 650, Loss 0.3348531126976013



[A

Step 649, Loss 0.3383382558822632



[A

Step 648, Loss 0.3342732787132263



[A

Step 647, Loss 0.33215686678886414



[A

Step 646, Loss 0.3341992497444153



[A

Step 645, Loss 0.33632808923721313



[A

Step 644, Loss 0.3354756534099579



[A

Step 643, Loss 0.33751147985458374



[A

Step 642, Loss 0.3444758355617523



[A

Step 641, Loss 0.34064924716949463



[A

Step 640, Loss 0.3366036117076874



[A

Step 639, Loss 0.3346540927886963



[A

Step 638, Loss 0.3339359760284424



[A

Step 637, Loss 0.33362090587615967



[A

Step 636, Loss 0.3396528363227844



[A

Step 635, Loss 0.33559805154800415



[A

Step 634, Loss 0.3369120955467224



[A

Step 633, Loss 0.3364107012748718



[A

Step 632, Loss 0.3355782926082611



[A

Step 631, Loss 0.34091028571128845



[A

Step 630, Loss 0.34033870697021484



[A

Step 629, Loss 0.3354896008968353



[A

Step 628, Loss 0.33565449714660645



[A

Step 627, Loss 0.33644723892211914



[A

Step 626, Loss 0.33706969022750854



[A

Step 625, Loss 0.3369307518005371



[A

Step 624, Loss 0.3382767140865326



[A

Step 623, Loss 0.3422735631465912



[A

Step 622, Loss 0.3364444673061371



[A

Step 621, Loss 0.3342389464378357



[A

Step 620, Loss 0.33527088165283203



[A

Step 619, Loss 0.33589452505111694



[A

Step 618, Loss 0.3355110287666321



[A

Step 617, Loss 0.3379163444042206



[A

Step 616, Loss 0.3395838141441345



[A

Step 615, Loss 0.33805108070373535



[A

Step 614, Loss 0.3369009494781494



[A

Step 613, Loss 0.33643922209739685



[A

Step 612, Loss 0.33772867918014526



[A

Step 611, Loss 0.3337039351463318



[A

Step 610, Loss 0.3311755657196045



[A

Step 609, Loss 0.33603712916374207



[A

Step 608, Loss 0.3383285701274872



[A

Step 607, Loss 0.3376418948173523



[A

Step 606, Loss 0.33413803577423096



[A

Step 605, Loss 0.33450016379356384



[A

Step 604, Loss 0.3344636857509613



[A

Step 603, Loss 0.33510148525238037



[A

Step 602, Loss 0.3338049650192261



[A

Step 601, Loss 0.33947262167930603



[A

Step 600, Loss 0.3337740898132324



[A

Step 599, Loss 0.3400443196296692



[A

Step 598, Loss 0.3421010375022888



[A

Step 597, Loss 0.33965420722961426



[A

Step 596, Loss 0.3336929678916931



[A

Step 595, Loss 0.33502936363220215



[A

Step 594, Loss 0.3332076072692871



[A

Step 593, Loss 0.3340836763381958



[A

Step 592, Loss 0.33190494775772095



[A

Step 591, Loss 0.33295920491218567



[A

Step 590, Loss 0.3338705003261566



[A

Step 589, Loss 0.3316090703010559



[A

Step 588, Loss 0.33422374725341797



[A

Step 587, Loss 0.33608531951904297



[A

Step 586, Loss 0.3362273573875427



[A

Step 585, Loss 0.33035650849342346



[A

Step 584, Loss 0.3318663239479065



[A

Step 583, Loss 0.3313625752925873



[A

Step 582, Loss 0.3336865305900574



[A

Step 581, Loss 0.3349255919456482



[A

Step 580, Loss 0.33094996213912964



[A

Step 579, Loss 0.3364056944847107



[A

Step 578, Loss 0.3320040702819824



[A

Step 577, Loss 0.33116409182548523



[A

Step 576, Loss 0.33363184332847595



[A

Step 575, Loss 0.332075834274292



[A

Step 574, Loss 0.33443722128868103



[A

Step 573, Loss 0.33484071493148804



[A

Step 572, Loss 0.3342137336730957



[A

Step 571, Loss 0.3340397775173187



[A

Step 570, Loss 0.33443665504455566



[A

Step 569, Loss 0.3356831669807434



[A

Step 568, Loss 0.33721303939819336



[A

Step 567, Loss 0.3343338966369629



[A

Step 566, Loss 0.3293558955192566



[A

Step 565, Loss 0.3355075716972351



[A

Step 564, Loss 0.33517563343048096



[A

Step 563, Loss 0.3329065442085266



[A

Step 562, Loss 0.3317659795284271



[A

Step 561, Loss 0.33295565843582153



[A

Step 560, Loss 0.32984405755996704



[A

Step 559, Loss 0.3313829302787781



[A

Step 558, Loss 0.33044958114624023



[A

Step 557, Loss 0.33768269419670105



[A

Step 556, Loss 0.3403933048248291



[A

Step 555, Loss 0.3431636691093445



[A

Step 554, Loss 0.34614962339401245



[A

Step 553, Loss 0.3406497836112976



[A

Step 552, Loss 0.3386705815792084



[A

Step 551, Loss 0.3385505676269531



[A

Step 550, Loss 0.33741068840026855



[A

Step 549, Loss 0.33772191405296326



[A

Step 548, Loss 0.3361828923225403



[A

Step 547, Loss 0.3337903916835785



[A

Step 546, Loss 0.33457398414611816



[A

Step 545, Loss 0.3362330198287964



[A

Step 544, Loss 0.3358094394207001



[A

Step 543, Loss 0.3334617018699646



[A

Step 542, Loss 0.3346775472164154



[A

Step 541, Loss 0.33710917830467224



[A

Step 540, Loss 0.3294790983200073



[A

Step 539, Loss 0.33057332038879395



[A

Step 538, Loss 0.3317612409591675



[A

Step 537, Loss 0.3335246741771698



[A

Step 536, Loss 0.3333459794521332



[A

Step 535, Loss 0.3335551619529724



[A

Step 534, Loss 0.3390898108482361



[A

Step 533, Loss 0.3334153890609741



[A

Step 532, Loss 0.33726435899734497



[A

Step 531, Loss 0.3337087035179138



[A

Step 530, Loss 0.33055853843688965



[A

Step 529, Loss 0.3306553065776825



[A

Step 528, Loss 0.3283814489841461



[A

Step 527, Loss 0.32995492219924927



[A

Step 526, Loss 0.3287634551525116



[A

Step 525, Loss 0.3312620520591736



[A

Step 524, Loss 0.32875245809555054



[A

Step 523, Loss 0.3292257487773895



[A

Step 522, Loss 0.3311108946800232



[A

Step 521, Loss 0.3332168757915497



[A

Step 520, Loss 0.3302452266216278



[A

Step 519, Loss 0.33118170499801636



[A

Step 518, Loss 0.3308018147945404



[A

Step 517, Loss 0.3316985070705414



[A

Step 516, Loss 0.33317169547080994



[A

Step 515, Loss 0.33091312646865845



[A

Step 514, Loss 0.3319382667541504



[A

Step 513, Loss 0.33309781551361084



[A

Step 512, Loss 0.3333108127117157



[A

Step 511, Loss 0.3318144381046295



[A

Step 510, Loss 0.332832396030426



[A

Step 509, Loss 0.3382225036621094



[A

Step 508, Loss 0.33804044127464294



[A

Step 507, Loss 0.3346274793148041



[A

Step 506, Loss 0.33685392141342163



[A

Step 505, Loss 0.3323262631893158



[A

Step 504, Loss 0.3293450176715851



[A

Step 503, Loss 0.3326628506183624



[A

Step 502, Loss 0.3326198160648346



[A

Step 501, Loss 0.33717411756515503



[A

Step 500, Loss 0.333661288022995



[A

Step 499, Loss 0.3338153064250946



[A

Step 498, Loss 0.32814717292785645



[A

Step 497, Loss 0.3315688371658325



[A

Step 496, Loss 0.3333587646484375



[A

Step 495, Loss 0.334867000579834



[A

Step 494, Loss 0.3404940068721771



[A

Step 493, Loss 0.34423351287841797



[A

Step 492, Loss 0.3414914608001709



[A

Step 491, Loss 0.3412584662437439



[A

Step 490, Loss 0.34374141693115234



[A

Step 489, Loss 0.34252190589904785



[A

Step 488, Loss 0.33896422386169434



[A

Step 487, Loss 0.33795151114463806



[A

Step 486, Loss 0.33334916830062866



[A

Step 485, Loss 0.331037700176239



[A

Step 484, Loss 0.32832056283950806



[A

Step 483, Loss 0.33096206188201904



[A

Step 482, Loss 0.3323609232902527



[A

Step 481, Loss 0.32691776752471924



[A

Step 480, Loss 0.327556312084198



[A

Step 479, Loss 0.3309246301651001



[A

Step 478, Loss 0.3301544785499573



[A

Step 477, Loss 0.33182093501091003



[A

Step 476, Loss 0.33557552099227905



[A

Step 475, Loss 0.33336424827575684



[A

Step 474, Loss 0.3343530297279358



[A

Step 473, Loss 0.3312820792198181



[A

Step 472, Loss 0.3277028203010559



[A

Step 471, Loss 0.32922083139419556



[A

Step 470, Loss 0.33127689361572266



[A

Step 469, Loss 0.3319932222366333



[A

Step 468, Loss 0.33354780077934265



[A

Step 467, Loss 0.3312177360057831



[A

Step 466, Loss 0.3313240706920624



[A

Step 465, Loss 0.3308277130126953



[A

Step 464, Loss 0.3300251364707947



[A

Step 463, Loss 0.33274418115615845



[A

Step 462, Loss 0.3282617926597595



[A

Step 461, Loss 0.33083048462867737



[A

Step 460, Loss 0.331047385931015



[A

Step 459, Loss 0.33224746584892273



[A

Step 458, Loss 0.3344336450099945



[A

Step 457, Loss 0.3311740458011627



[A

Step 456, Loss 0.33237048983573914



[A

Step 455, Loss 0.32958292961120605



[A

Step 454, Loss 0.3278998136520386



[A

Step 453, Loss 0.33056995272636414



[A

Step 452, Loss 0.32817184925079346



[A

Step 451, Loss 0.3300541043281555



[A

Step 450, Loss 0.3330804705619812



[A

Step 449, Loss 0.330452024936676



[A

Step 448, Loss 0.3315065801143646



[A

Step 447, Loss 0.3283970057964325



[A

Step 446, Loss 0.3318241536617279



[A

Step 445, Loss 0.33115506172180176



[A

Step 444, Loss 0.33525794744491577



[A

Step 443, Loss 0.3324061334133148



[A

Step 442, Loss 0.32870107889175415



[A

Step 441, Loss 0.32927650213241577



[A

Step 440, Loss 0.3320741355419159



[A

Step 439, Loss 0.3283154368400574



[A

Step 438, Loss 0.33051586151123047



[A

Step 437, Loss 0.3307558000087738



[A

Step 436, Loss 0.3282122015953064



[A

Step 435, Loss 0.3299342393875122



[A

Step 434, Loss 0.32790055871009827



[A

Step 433, Loss 0.3279605209827423



[A

Step 432, Loss 0.327930212020874



[A

Step 431, Loss 0.328918993473053



[A

Step 430, Loss 0.3305087685585022



[A

Step 429, Loss 0.33203983306884766



[A

Step 428, Loss 0.3304564654827118



[A

Step 427, Loss 0.3286920487880707



[A

Step 426, Loss 0.3272891342639923



[A

Step 425, Loss 0.32746565341949463



[A

Step 424, Loss 0.3315585255622864



[A

Step 423, Loss 0.33312058448791504



[A

Step 422, Loss 0.336760938167572



[A

Step 421, Loss 0.3311271369457245



[A

Step 420, Loss 0.3285459876060486



[A

Step 419, Loss 0.3280845582485199



[A

Step 418, Loss 0.3282880485057831



[A

Step 417, Loss 0.3297218382358551



[A

Step 416, Loss 0.32850539684295654



[A

Step 415, Loss 0.3319805860519409



[A

Step 414, Loss 0.3285368084907532



[A

Step 413, Loss 0.3314141035079956



[A

Step 412, Loss 0.3303731083869934



[A

Step 411, Loss 0.33225929737091064



[A

Step 410, Loss 0.3272860050201416



[A

Step 409, Loss 0.32549357414245605



[A

Step 408, Loss 0.32686662673950195



[A

Step 407, Loss 0.3265332877635956



[A

Step 406, Loss 0.33074840903282166



[A

Step 405, Loss 0.3283050060272217



[A

Step 404, Loss 0.3278920650482178



[A

Step 403, Loss 0.32734447717666626



[A

Step 402, Loss 0.3265657424926758



[A

Step 401, Loss 0.32677698135375977



[A

Step 400, Loss 0.32830581068992615



[A

Step 399, Loss 0.32827651500701904



[A

Step 398, Loss 0.32966944575309753



[A

Step 397, Loss 0.32969069480895996



[A

Step 396, Loss 0.3304564654827118



[A

Step 395, Loss 0.32726824283599854



[A

Step 394, Loss 0.3296462297439575



[A

Step 393, Loss 0.3290029466152191



[A

Step 392, Loss 0.3283516764640808



[A

Step 391, Loss 0.32836371660232544



[A

Step 390, Loss 0.3294311761856079



[A

Step 389, Loss 0.33298251032829285



[A

Step 388, Loss 0.3280172646045685



[A

Step 387, Loss 0.33001309633255005



[A

Step 386, Loss 0.32822686433792114



[A

Step 385, Loss 0.32599350810050964



[A

Step 384, Loss 0.3261011838912964



[A

Step 383, Loss 0.32688242197036743



[A

Step 382, Loss 0.33150139451026917



[A

Step 381, Loss 0.33008038997650146



[A

Step 380, Loss 0.32682865858078003



[A

Step 379, Loss 0.32755047082901



[A

Step 378, Loss 0.3304448127746582



[A

Step 377, Loss 0.32802248001098633



[A

Step 376, Loss 0.3274053931236267



[A

Step 375, Loss 0.33395954966545105



[A

Step 374, Loss 0.33555182814598083



[A

Step 373, Loss 0.32734933495521545



[A

Step 372, Loss 0.33158132433891296



[A

Step 371, Loss 0.33140015602111816



[A

Step 370, Loss 0.3271428942680359



[A

Step 369, Loss 0.3282403349876404



[A

Step 368, Loss 0.331589937210083



[A

Step 367, Loss 0.3285702168941498



[A

Step 366, Loss 0.32747477293014526



[A

Step 365, Loss 0.32912755012512207



[A

Step 364, Loss 0.32913923263549805



[A

Step 363, Loss 0.33133310079574585



[A

Step 362, Loss 0.3319506049156189



[A

Step 361, Loss 0.33076703548431396



[A

Step 360, Loss 0.3350111246109009



[A

Step 359, Loss 0.3311135768890381



[A

Step 358, Loss 0.32908034324645996



[A

Step 357, Loss 0.3313828408718109



[A

Step 356, Loss 0.33305251598358154



[A

Step 355, Loss 0.3381170928478241



[A

Step 354, Loss 0.34454768896102905



[A

Step 353, Loss 0.34108442068099976



[A

Step 352, Loss 0.3403007984161377



[A

Step 351, Loss 0.3354158401489258



[A

Step 350, Loss 0.33266526460647583



[A

Step 349, Loss 0.3364734351634979



[A

Step 348, Loss 0.3384665250778198



[A

Step 347, Loss 0.3396714925765991



[A

Step 346, Loss 0.34596771001815796



[A

Step 345, Loss 0.3364120423793793



[A

Step 344, Loss 0.33179253339767456



[A

Step 343, Loss 0.3357923924922943



[A

Step 342, Loss 0.32795965671539307



[A

Step 341, Loss 0.32860442996025085



[A

Step 340, Loss 0.3314055800437927



[A

Step 339, Loss 0.33421969413757324



[A

Step 338, Loss 0.33221814036369324



[A

Step 337, Loss 0.3316372036933899



[A

Step 336, Loss 0.33267492055892944



[A

Step 335, Loss 0.33616209030151367



[A

Step 334, Loss 0.3327793776988983



[A

Step 333, Loss 0.3339190185070038



[A

Step 332, Loss 0.3355534076690674



[A

Step 331, Loss 0.3332248330116272



[A

Step 330, Loss 0.32954075932502747



[A

Step 329, Loss 0.3312954902648926



[A

Step 328, Loss 0.3342779576778412



[A

Step 327, Loss 0.33315321803092957



[A

Step 326, Loss 0.3305068016052246



[A

Step 325, Loss 0.3295707702636719



[A

Step 324, Loss 0.33196431398391724



[A

Step 323, Loss 0.3321410119533539



[A

Step 322, Loss 0.3304833769798279



[A

Step 321, Loss 0.33307528495788574



[A

Step 320, Loss 0.3325116038322449



[A

Step 319, Loss 0.33200591802597046



[A

Step 318, Loss 0.3316163420677185



[A

Step 317, Loss 0.3325035572052002



[A

Step 316, Loss 0.3335912823677063



[A

Step 315, Loss 0.3367273807525635



[A

Step 314, Loss 0.338029146194458



[A

Step 313, Loss 0.3364661633968353



[A

Step 312, Loss 0.33487123250961304



[A

Step 311, Loss 0.3337620198726654



[A

Step 310, Loss 0.3310790956020355



[A

Step 309, Loss 0.333326518535614



[A

Step 308, Loss 0.33105212450027466



[A

Step 307, Loss 0.33160924911499023



[A

Step 306, Loss 0.33470794558525085



[A

Step 305, Loss 0.33315348625183105



[A

Step 304, Loss 0.33059966564178467



[A

Step 303, Loss 0.33358773589134216



[A

Step 302, Loss 0.3359981179237366



[A

Step 301, Loss 0.33657845854759216



[A

Step 300, Loss 0.33225947618484497



[A

Step 299, Loss 0.33353838324546814



[A

Step 298, Loss 0.332461416721344



[A

Step 297, Loss 0.333845853805542



[A

Step 296, Loss 0.33489760756492615



[A

Step 295, Loss 0.33374160528182983



[A

Step 294, Loss 0.3329122066497803



[A

Step 293, Loss 0.33480024337768555



[A

Step 292, Loss 0.33360394835472107



[A

Step 291, Loss 0.33387064933776855



[A

Step 290, Loss 0.3364585340023041



[A

Step 289, Loss 0.33909446001052856



[A

Step 288, Loss 0.3400702476501465



[A

Step 287, Loss 0.3395504653453827



[A

Step 286, Loss 0.34147095680236816



[A

Step 285, Loss 0.33785754442214966



[A

Step 284, Loss 0.3407513499259949



[A

Step 283, Loss 0.33990252017974854



[A

Step 282, Loss 0.33620327711105347



[A

Step 281, Loss 0.34007394313812256



[A

Step 280, Loss 0.339913010597229



[A

Step 279, Loss 0.33716195821762085



[A

Step 278, Loss 0.33720049262046814



[A

Step 277, Loss 0.34163951873779297



[A

Step 276, Loss 0.3390040397644043



[A

Step 275, Loss 0.33822101354599



[A

Step 274, Loss 0.34104102849960327



[A

Step 273, Loss 0.3390609920024872



[A

Step 272, Loss 0.33865249156951904



[A

Step 271, Loss 0.34369534254074097



[A

Step 270, Loss 0.34890902042388916



[A

Step 269, Loss 0.3453628718852997



[A

Step 268, Loss 0.3422169089317322



[A

Step 267, Loss 0.3388223648071289



[A

Step 266, Loss 0.34429726004600525



[A

Step 265, Loss 0.34175729751586914



[A

Step 264, Loss 0.3467656970024109



[A

Step 263, Loss 0.3434700667858124



[A

Step 262, Loss 0.3431296944618225



[A

Step 261, Loss 0.3413029611110687



[A

Step 260, Loss 0.3403480648994446



[A

Step 259, Loss 0.3428420424461365



[A

Step 258, Loss 0.34456154704093933



[A

Step 257, Loss 0.3449002206325531



[A

Step 256, Loss 0.3442048728466034



[A

Step 255, Loss 0.3455236852169037



[A

Step 254, Loss 0.3461833894252777



[A

Step 253, Loss 0.35016965866088867



[A

Step 252, Loss 0.3485791087150574



[A

Step 251, Loss 0.3488113284111023



[A

Step 250, Loss 0.34826549887657166



[A

Step 249, Loss 0.35177820920944214



[A

Step 248, Loss 0.35028934478759766



[A

Step 247, Loss 0.3528895378112793



[A

Step 246, Loss 0.35068297386169434



[A

Step 245, Loss 0.34819817543029785



[A

Step 244, Loss 0.3497353196144104



[A

Step 243, Loss 0.3491672873497009



[A

Step 242, Loss 0.3482566773891449



[A

Step 241, Loss 0.3499225974082947



[A

Step 240, Loss 0.34977826476097107



[A

Step 239, Loss 0.34927988052368164



[A

Step 238, Loss 0.3512624204158783



[A

Step 237, Loss 0.35509395599365234



[A

Step 236, Loss 0.35460853576660156



[A

Step 235, Loss 0.35504743456840515



[A

Step 234, Loss 0.3612787127494812



[A

Step 233, Loss 0.35973918437957764



[A

Step 232, Loss 0.35438400506973267



[A

Step 231, Loss 0.35543093085289



[A

Step 230, Loss 0.35749518871307373



[A

Step 229, Loss 0.3564293086528778



[A

Step 228, Loss 0.3567688465118408



[A

Step 227, Loss 0.36274445056915283



[A

Step 226, Loss 0.36398574709892273



[A

Step 225, Loss 0.35912322998046875



[A

Step 224, Loss 0.3566930890083313



[A

Step 223, Loss 0.36036252975463867



[A

Step 222, Loss 0.35897600650787354



[A

Step 221, Loss 0.3573196530342102



[A

Step 220, Loss 0.35856637358665466



[A

Step 219, Loss 0.36002835631370544



[A

Step 218, Loss 0.35823389887809753



[A

Step 217, Loss 0.3612877130508423



[A

Step 216, Loss 0.3598073422908783



[A

Step 215, Loss 0.36239275336265564



[A

Step 214, Loss 0.3652436435222626



[A

Step 213, Loss 0.3614407479763031



[A

Step 212, Loss 0.36575520038604736



[A

Step 211, Loss 0.3699125647544861



[A

Step 210, Loss 0.36963075399398804



[A

Step 209, Loss 0.36939629912376404



[A

Step 208, Loss 0.37289461493492126



[A

Step 207, Loss 0.37043988704681396



[A

Step 206, Loss 0.37361621856689453



[A

Step 205, Loss 0.3718486726284027



[A

Step 204, Loss 0.36971062421798706



[A

Step 203, Loss 0.36881232261657715



[A

Step 202, Loss 0.3732272982597351



[A

Step 201, Loss 0.37245890498161316



[A

Step 200, Loss 0.36913514137268066



[A

Step 199, Loss 0.3740355968475342



[A

Step 198, Loss 0.3754473626613617



[A

Step 197, Loss 0.37162312865257263



[A

Step 196, Loss 0.3754384517669678



[A

Step 195, Loss 0.37975165247917175



[A

Step 194, Loss 0.37612542510032654



[A

Step 193, Loss 0.37838584184646606



[A

Step 192, Loss 0.3788759112358093



[A

Step 191, Loss 0.3785652816295624



[A

Step 190, Loss 0.3774833679199219



[A

Step 189, Loss 0.3781750202178955



[A

Step 188, Loss 0.3790382742881775



[A

Step 187, Loss 0.3797778785228729



[A

Step 186, Loss 0.38374584913253784



[A

Step 185, Loss 0.38499176502227783



[A

Step 184, Loss 0.38097497820854187



[A

Step 183, Loss 0.38241466879844666



[A

Step 182, Loss 0.38473570346832275



[A

Step 181, Loss 0.3934783339500427



[A

Step 180, Loss 0.39405497908592224



[A

Step 179, Loss 0.3962433338165283



[A

Step 178, Loss 0.3899444341659546



[A

Step 177, Loss 0.38606390357017517



[A

Step 176, Loss 0.38980334997177124



[A

Step 175, Loss 0.39011678099632263



[A

Step 174, Loss 0.39427828788757324



[A

Step 173, Loss 0.3978552222251892



[A

Step 172, Loss 0.39522624015808105



[A

Step 171, Loss 0.3963155150413513



[A

Step 170, Loss 0.3922078013420105



[A

Step 169, Loss 0.39358216524124146



[A

Step 168, Loss 0.3952823579311371



[A

Step 167, Loss 0.3979932963848114



[A

Step 166, Loss 0.39810681343078613



[A

Step 165, Loss 0.39671215415000916



[A

Step 164, Loss 0.3979277014732361



[A

Step 163, Loss 0.40320849418640137



[A

Step 162, Loss 0.4041416645050049



[A

Step 161, Loss 0.40588805079460144



[A

Step 160, Loss 0.40650293231010437



[A

Step 159, Loss 0.405381441116333



[A

Step 158, Loss 0.4102528691291809



[A

Step 157, Loss 0.4112686514854431



[A

Step 156, Loss 0.41238078474998474



[A

Step 155, Loss 0.41672763228416443



[A

Step 154, Loss 0.421801894903183



[A

Step 153, Loss 0.4121135175228119



[A

Step 152, Loss 0.41755884885787964



[A

Step 151, Loss 0.418557733297348



[A

Step 150, Loss 0.4230058193206787



[A

Step 149, Loss 0.4178488254547119



[A

Step 148, Loss 0.4215964674949646



[A

Step 147, Loss 0.4226348400115967



[A

Step 146, Loss 0.4272482395172119



[A

Step 145, Loss 0.42789626121520996



[A

Step 144, Loss 0.43120431900024414



[A

Step 143, Loss 0.43043240904808044



[A

Step 142, Loss 0.4323112964630127



[A

Step 141, Loss 0.43185752630233765



[A

Step 140, Loss 0.43464744091033936



[A

Step 139, Loss 0.4342653453350067



[A

Step 138, Loss 0.433014839887619



[A

Step 137, Loss 0.43684589862823486



[A

Step 136, Loss 0.4397760331630707



[A

Step 135, Loss 0.44241863489151



[A

Step 134, Loss 0.444198340177536



[A

Step 133, Loss 0.44332337379455566



[A

Step 132, Loss 0.4443429708480835



[A

Step 131, Loss 0.44728517532348633



[A

Step 130, Loss 0.4490160346031189



[A

Step 129, Loss 0.4474831223487854



[A

Step 128, Loss 0.4475143849849701



[A

Step 127, Loss 0.4544057250022888



[A

Step 126, Loss 0.4529035687446594



[A

Step 125, Loss 0.4533970355987549



[A

Step 124, Loss 0.45625561475753784



[A

Step 123, Loss 0.4601375460624695



[A

Step 122, Loss 0.46121737360954285



[A

Step 121, Loss 0.4698983430862427



[A

Step 120, Loss 0.47535407543182373



[A

Step 119, Loss 0.4692092835903168



[A

Step 118, Loss 0.4748442769050598



[A

Step 117, Loss 0.47559961676597595



[A

Step 116, Loss 0.47622808814048767



[A

Step 115, Loss 0.48081696033477783



[A

Step 114, Loss 0.48558053374290466



[A

Step 113, Loss 0.4885500371456146



[A

Step 112, Loss 0.4906456470489502



[A

Step 111, Loss 0.4940055012702942



[A

Step 110, Loss 0.4926561117172241



[A

Step 109, Loss 0.48871278762817383



[A

Step 108, Loss 0.48990485072135925



[A

Step 107, Loss 0.4937896132469177



[A

Step 106, Loss 0.49299296736717224



[A

Step 105, Loss 0.4948795437812805



[A

Step 104, Loss 0.4960366487503052



[A

Step 103, Loss 0.4961708188056946



[A

Step 102, Loss 0.502675473690033



[A

Step 101, Loss 0.5024235248565674



[A

Step 100, Loss 0.5065786242485046



[A

Step 99, Loss 0.5108429193496704



[A

Step 98, Loss 0.5113998651504517



[A

Step 97, Loss 0.5190960764884949



[A

Step 96, Loss 0.5199183225631714



[A

Step 95, Loss 0.5184253454208374



[A

Step 94, Loss 0.5204247832298279



[A

Step 93, Loss 0.5263943076133728



[A

Step 92, Loss 0.5301787257194519



[A

Step 91, Loss 0.534026026725769



[A

Step 90, Loss 0.5351280570030212



[A

Step 89, Loss 0.5317753553390503



[A

Step 88, Loss 0.5355220437049866



[A

Step 87, Loss 0.5376803874969482



[A

Step 86, Loss 0.5369122624397278



[A

Step 85, Loss 0.5411431789398193



[A

Step 84, Loss 0.5455016493797302



[A

Step 83, Loss 0.5481473803520203



[A

Step 82, Loss 0.5494570732116699



[A

Step 81, Loss 0.5550726652145386



[A

Step 80, Loss 0.5577758550643921



[A

Step 79, Loss 0.5642640590667725



[A

Step 78, Loss 0.5650903582572937



[A

Step 77, Loss 0.5677951574325562



[A

Step 76, Loss 0.563805103302002



[A

Step 75, Loss 0.568446159362793



[A

Step 74, Loss 0.5702685713768005



[A

Step 73, Loss 0.5756192803382874



[A

Step 72, Loss 0.5791404247283936



[A

Step 71, Loss 0.5818095803260803



[A

Step 70, Loss 0.5927035808563232



[A

Step 69, Loss 0.6016161441802979



[A

Step 68, Loss 0.5993934869766235



[A

Step 67, Loss 0.5956279635429382



[A

Step 66, Loss 0.5957114696502686



[A

Step 65, Loss 0.5953987836837769



[A

Step 64, Loss 0.6020625829696655



[A

Step 63, Loss 0.6116011738777161



[A

Step 62, Loss 0.6138014793395996



[A

Step 61, Loss 0.6128713488578796



[A

Step 60, Loss 0.6104363203048706



[A

Step 59, Loss 0.6196527481079102



[A

Step 58, Loss 0.6240250468254089



[A

Step 57, Loss 0.627185583114624



[A

Step 56, Loss 0.6330695152282715



[A

Step 55, Loss 0.6353245973587036



[A

Step 54, Loss 0.6401259303092957



[A

Step 53, Loss 0.6417509913444519



[A

Step 52, Loss 0.6453143954277039



[A

Step 51, Loss 0.6479880809783936



[A

Step 50, Loss 0.6535022258758545



[A

Step 49, Loss 0.6585737466812134



[A

Step 48, Loss 0.6645470857620239



[A

Step 47, Loss 0.663236677646637



[A

Step 46, Loss 0.6692778468132019



[A

Step 45, Loss 0.6704959273338318



[A

Step 44, Loss 0.6743940114974976



[A

Step 43, Loss 0.6781694889068604



[A

Step 42, Loss 0.6813501119613647



[A

Step 41, Loss 0.6841152906417847



[A

Step 40, Loss 0.6890108585357666



[A

Step 39, Loss 0.6939224004745483



[A

Step 38, Loss 0.7016837000846863



[A

Step 37, Loss 0.7053886651992798



[A

Step 36, Loss 0.7078296542167664



[A

Step 35, Loss 0.7190321683883667



[A

Step 34, Loss 0.720808207988739



[A

Step 33, Loss 0.7222224473953247



[A

Step 32, Loss 0.7251603603363037



[A

Step 31, Loss 0.7315663695335388



[A

Step 30, Loss 0.7317788004875183



[A

Step 29, Loss 0.7338924407958984



[A

Step 28, Loss 0.7408813238143921



[A

Step 27, Loss 0.7437421083450317



[A

Step 26, Loss 0.7439473867416382



[A

Step 25, Loss 0.7513437271118164



[A

Step 24, Loss 0.7501746416091919



[A

Step 23, Loss 0.7608731985092163



[A

Step 22, Loss 0.7600375413894653



[A

Step 21, Loss 0.7653735876083374



[A

Step 20, Loss 0.7696933746337891



[A

Step 19, Loss 0.7731221914291382



[A

Step 18, Loss 0.7797813415527344



[A

Step 17, Loss 0.7842082381248474



[A

Step 16, Loss 0.7866182327270508



[A

Step 15, Loss 0.7927641868591309



[A

Step 14, Loss 0.7999862432479858



[A

Step 13, Loss 0.8053455352783203



[A

Step 12, Loss 0.8089886903762817



[A

Step 11, Loss 0.8130039572715759



[A

Step 10, Loss 0.8177706003189087



[A

Step 9, Loss 0.8226457834243774



[A

Step 8, Loss 0.8270779848098755



[A

Step 7, Loss 0.8284484148025513



[A

Step 6, Loss 0.8344659209251404



[A

Step 5, Loss 0.8389419317245483



[A

Step 4, Loss 0.8432345390319824



[A

Step 3, Loss 0.8446857333183289



[A

Step 2, Loss 0.8473138809204102



[A

Step 1, Loss 0.856032133102417



[A

Step 0, Loss 0.866870641708374



sampling loop time step: 100%|██████████| 1000/1000 [08:23<00:00,  1.99it/s]
