## Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms 
from torch.utils.data import DataLoader, Dataset
from torch import optim
from torch.cuda.amp import autocast, GradScaler
import copy
import numpy as np
import math
from PIL import Image
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm
import os

try:
    import cPickle as pickle
except ModuleNotFoundError:
    import pickle

## Utils

In [2]:
def plot_images(images):
    """Display a batch of images side by side in one 4x4-inch matplotlib figure.

    ``images`` is iterated along its first dimension and every element is laid out
    left to right by concatenating along the last (width) axis. Each element is
    presumably a (C, H, W) tensor with C in {1, 3, 4} as ``imshow`` requires —
    TODO confirm against callers.
    """
    plt.figure(figsize=(4, 4))
    strip = torch.cat(list(images.cpu()), dim=-1)
    # (C, H, W*N) -> (H, W*N, C), the channel-last layout imshow expects.
    plt.imshow(strip.permute(1, 2, 0))
    plt.show()


def save_images(images, path, **kwargs):
    """Tile ``images`` into a single grid and write it to ``path`` with PIL.

    Extra keyword arguments are forwarded to ``torchvision.utils.make_grid``
    (e.g. ``nrow``, ``padding``). The grid is converted to a channel-last NumPy
    array on the CPU before being handed to ``Image.fromarray``; callers in this
    notebook pass ``uint8`` tensors (TODO confirm other dtypes are never used).
    """
    grid_hwc = torchvision.utils.make_grid(images, **kwargs).permute(1, 2, 0)
    Image.fromarray(grid_hwc.to('cpu').numpy()).save(path)

## Dataset

In [3]:
class MinMaxScaler:
    """Min-max scale values into [0, 1] or, with ``normalize=True``, into [-1, 1].

    Args:
        codebook: ``'top'`` or ``'bot'``; selects a preset value range that is used
            whenever ``min_val`` and ``max_val`` are not BOTH given.
        normalize: if True, map the [0, 1] result onward to [-1, 1].
        min_val, max_val: explicit range; when both are provided they override the
            preset and ``codebook`` is not consulted.

    Raises:
        ValueError: if no explicit range is given and ``codebook`` has no preset.
    """

    # Preset (min, max) per codebook. NOTE(review): these look like the observed
    # extremes of the stored top/bottom latents — confirm where they came from.
    _CODEBOOK_RANGES = {
        'top': (-0.5102538466453552, 0.5855448842048645),
        'bot': (-3.064223527908325, 2.88156795501709),
    }

    def __init__(self, codebook: str, normalize=True, min_val=None, max_val=None):
        if min_val is None or max_val is None:
            if codebook not in self._CODEBOOK_RANGES:
                # Previously this left min_val/max_val unset and failed later at call time.
                raise ValueError(
                    f"unknown codebook {codebook!r}; expected one of "
                    f"{sorted(self._CODEBOOK_RANGES)} or explicit min_val/max_val"
                )
            self.min_val, self.max_val = self._CODEBOOK_RANGES[codebook]
        else:
            self.min_val = min_val
            self.max_val = max_val
        self.normalize = normalize

    def __call__(self, batch):
        """Return ``batch`` rescaled with the stored range (the input is not modified)."""
        out = (batch - self.min_val) / (self.max_val - self.min_val)
        if self.normalize:
            out = (out - 0.5) * 2  # [0, 1] -> [-1, 1]
        return out
    
class MinMaxScalerCollate:
    """DataLoader ``collate_fn`` that stacks and min-max scales (top, bottom) pairs.

    Every item in the incoming batch is indexed as ``item[0]`` (top tensor) and
    ``item[1]`` (bottom tensor). Each half is stacked along a new leading batch
    dimension and scaled with the ``MinMaxScaler`` preset for its codebook.

    Note the returned order is ``(bottom, top)`` — the reverse of the item order.
    """

    def __init__(self):
        self.scaler_t = MinMaxScaler("top")
        self.scaler_b = MinMaxScaler("bot")

    def __call__(self, batch):
        stacked_top = torch.stack([item[0] for item in batch])
        stacked_bottom = torch.stack([item[1] for item in batch])
        return self.scaler_b(stacked_bottom), self.scaler_t(stacked_top)
    
class InverseMinMaxScaler:
    """Undo min-max scaling: map values from [-1, 1] (or [0, 1]) back to [min_val, max_val].

    Args:
        codebook: ``'top'`` or ``'bot'``; selects a preset value range that is used
            whenever ``min_val`` and ``max_val`` are not BOTH given.
        normalize: if True the input is assumed to be in [-1, 1] and is first mapped
            to [0, 1]; if False it is assumed to already be in [0, 1].
        min_val, max_val: explicit range; when both are provided they override the preset.

    Raises:
        ValueError: if no explicit range is given and ``codebook`` has no preset.
    """

    # Same presets as the forward scaler so the two round-trip.
    _CODEBOOK_RANGES = {
        'top': (-0.5102538466453552, 0.5855448842048645),
        'bot': (-3.064223527908325, 2.88156795501709),
    }

    def __init__(self, codebook: str, normalize = True, min_val = None, max_val = None):
        if min_val is None or max_val is None:
            if codebook not in self._CODEBOOK_RANGES:
                raise ValueError(
                    f"unknown codebook {codebook!r}; expected one of "
                    f"{sorted(self._CODEBOOK_RANGES)} or explicit min_val/max_val"
                )
            self.min_val, self.max_val = self._CODEBOOK_RANGES[codebook]
        else:
            self.min_val = min_val
            self.max_val = max_val
        self.normalize = normalize

    def __call__(self, batch):
        """Return ``batch`` mapped back to the original value range."""
        # Start from the input so normalize=False no longer hits an unbound ``out``.
        out = batch
        if self.normalize:
            out = (out / 2) + 0.5  # [-1, 1] -> [0, 1]
        return (out * (self.max_val - self.min_val)) + self.min_val
    
class TensorFolderDataset(Dataset):
    """Dataset over a directory in which every entry is one ``torch.save``-d tensor.

    Entries are listed once at construction and ordered by file name. Each loaded
    tensor is indexed as a 3-D tensor: channels ``3:`` restricted to the top-left
    32x32 window form the "top" item, channels ``:3`` at full spatial size form the
    "bottom" item.

    NOTE(review): every entry is passed to ``torch.load`` (which unpickles), so the
    folder must contain only trusted files — and only files.
    """

    def __init__(self, folder_path):
        self.folder_path = folder_path
        self.tensor_files = sorted(os.listdir(folder_path))

    def __len__(self):
        return len(self.tensor_files)

    def __getitem__(self, index):
        stored = torch.load(os.path.join(self.folder_path, self.tensor_files[index]))
        top_part = stored[3:, :32, :32]
        bottom_part = stored[:3, :, :]
        return top_part, bottom_part
    
def get_data(config):
    """Build the shuffled training DataLoader described by ``config``.

    Reads ``config.dataset_path`` and ``config.batch_size``. Batches are assembled
    by ``MinMaxScalerCollate`` and therefore arrive as ``(bottom, top)`` pairs.
    """
    return DataLoader(
        TensorFolderDataset(config.dataset_path),
        batch_size=config.batch_size,
        shuffle=True,
        collate_fn=MinMaxScalerCollate(),
    )

## Modules

In [4]:
class EMA:
    """Exponential moving average of a model's weights, with a warm-up copy phase.

    For the first ``step_start_ema`` calls to ``step_ema`` the EMA model is simply
    overwritten with the live model's full state; afterwards every parameter is
    updated as ``ema = beta * ema + (1 - beta) * current`` and every buffer is
    copied from the live model.

    Args:
        beta: decay of the moving average (closer to 1 = slower-moving average).
        step: initial step counter, e.g. to resume mid-run without repeating warm-up.
    """

    def __init__(self, beta, step = 0):
        self.beta = beta
        self.step = step

    @torch.no_grad()
    def update_model_average(self, ema_model, current_model):
        """Blend ``current_model``'s parameters into ``ema_model`` in place.

        Buffers (e.g. BatchNorm running statistics) are not averaged; they are copied
        so the EMA model does not keep the statistics from the end of warm-up forever.
        Assumes both models have identical parameter/buffer ordering (e.g. one is a
        deepcopy of the other).
        """
        for current_params, ema_params in zip(current_model.parameters(), ema_model.parameters()):
            # In place: avoids allocating a fresh tensor for every parameter each step.
            ema_params.mul_(self.beta).add_(current_params, alpha=1 - self.beta)
        for current_buf, ema_buf in zip(current_model.buffers(), ema_model.buffers()):
            ema_buf.copy_(current_buf)

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new``, or ``new`` when ``old`` is None."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        """Advance the EMA by one training step and increment ``self.step``.

        While ``self.step < step_start_ema`` the EMA model is reset to an exact copy
        of ``model``; afterwards the moving average is applied.
        """
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        """Overwrite ``ema_model`` with ``model``'s state dict (parameters and buffers)."""
        ema_model.load_state_dict(model.state_dict())
        
        

class HeadAttention(nn.Module):
    def __init__(self, channels, size, heads):
        super(HeadAttention, self).__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, heads, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
        x_ln = self.ln(x)
        attention_value, _ = self.mha(x_ln, x_ln, x_ln)
        attention_value = attention_value + x
        attention_value = self.ff_self(attention_value) + attention_value
        return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)
    


class SkipAttention(nn.Module):
    def __init__(self, f_in_g, f_in_x, f_out):
        super().__init__()
        
        self.w_g = nn.Sequential(
            nn.Conv2d(f_in_g, f_out, kernel_size = 1, stride = 1, padding = 0),
            nn.BatchNorm2d(f_out)
        )
        
        self.w_x = nn.Sequential(
            nn.Conv2d(f_in_x, f_out, kernel_size = 1, stride = 1, padding = 0),
            nn.BatchNorm2d(f_out)
        )
        
        self.relu = nn.ReLU(inplace=True)

        self.psi = nn.Sequential(
            nn.Conv2d(f_out, 1, kernel_size = 1, stride = 1, padding = 0),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
         
    def forward(self, g, x):
        g1 = self.w_g(g)
        x1 = self.w_x(x)
        g1 = F.interpolate(g1, size=x.size()[2:], mode='bilinear')
        psi = self.relu(g1+x1)
        psi = self.psi(psi)
        return psi*x
    
    
    
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, mid_channels),
            nn.GELU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        )

    def forward(self, x):
        if self.residual:
            return F.gelu(x + self.double_conv(x))
        else:
            return self.double_conv(x)
        
        
class Down(nn.Module):
    """Encoder stage: 2x2 max-pool, residual DoubleConv, DoubleConv to ``out_channels``,
    then a per-channel bias computed from the timestep embedding.

    Args:
        in_channels, out_channels: channel counts before / after the stage.
        emb_dim: last-dimension width of the embedding ``t`` passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        pool = nn.MaxPool2d(2)
        refine = DoubleConv(in_channels, in_channels, residual=True)
        project = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, refine, project)
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, t):
        features = self.maxpool_conv(x)
        # (B, out_channels) -> (B, out_channels, 1, 1); broadcasting spreads it over H and W.
        time_bias = self.emb_layer(t)[:, :, None, None]
        return features + time_bias


class Up(nn.Module):
    """Decoder stage: bilinear 2x upsample, concat with the skip features, residual
    DoubleConv, DoubleConv to ``out_channels``, then a timestep-embedding bias.

    Args:
        in_channels: channels after concatenation (skip + upsampled, half each).
        out_channels: output channels.
        gated_attention: if True, gate ``skip_x`` with ``SkipAttention`` before concatenating.
        emb_dim: last-dimension width of the embedding ``t`` passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, gated_attention, emb_dim=256):
        super().__init__()
        self.gated_attention = gated_attention
        half = in_channels // 2
        # Always constructed (so its weights are part of the state dict), but only
        # applied in forward when gated_attention is True.
        self.attention = SkipAttention(half, half, half)
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, half),
        )
        self.emb_layer = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, out_channels))

    def forward(self, x, skip_x, t):
        if self.gated_attention:
            # The gate sees x at its pre-upsampling resolution; SkipAttention resizes it.
            skip_x = self.attention(g=x, x=skip_x)
        merged = torch.cat([skip_x, self.up(x)], dim=1)
        features = self.conv(merged)
        return features + self.emb_layer(t)[:, :, None, None]
    
    
class UNet(nn.Module):
    """Timestep-conditioned denoising U-Net with self-attention after every resolution change.

    Channel widths are 128/256/384/512. Three ``Down`` stages, a three-``DoubleConv``
    bottleneck and three ``Up`` stages are each followed by a ``HeadAttention`` block.
    The attention sizes are hard-coded (32/16/8 on the way down, 16/32/64 on the way
    up), which assumes a 64x64 input.

    Args:
        device: stored as ``self.device`` for backward compatibility only; the timestep
            encoding is created on ``t``'s device so the module also works when
            replicated across GPUs (e.g. by ``nn.DataParallel``).
        batch_size: stored as ``self.batch_size``; not used by the network.
        heads: attention heads in every ``HeadAttention`` block.
        gated_attention: enable ``SkipAttention`` gating inside the ``Up`` stages.
        c_in, c_out: input / output channels.
        time_dim: width of the sinusoidal timestep embedding.
    """

    def __init__(self, device, batch_size, heads = 4, gated_attention = False, c_in=6, c_out=3, time_dim=256):
        super().__init__()
        self.channels = [128*k for k in range(1, 5)]  # [128, 256, 384, 512]
        self.device = device
        self.batch_size = batch_size
        self.time_dim = time_dim

        # Encoder: 64 -> 32 -> 16 -> 8 spatially (for a 64x64 input).
        self.inc = DoubleConv(c_in, self.channels[0])
        self.down1 = Down(self.channels[0], self.channels[1])
        self.sa1 = HeadAttention(self.channels[1], 32, heads)
        self.down2 = Down(self.channels[1], self.channels[2])
        self.sa2 = HeadAttention(self.channels[2], 16, heads)
        self.down3 = Down(self.channels[2], self.channels[2])
        self.sa3 = HeadAttention(self.channels[2], 8, heads)

        # Bottleneck at 8x8.
        self.bot1 = DoubleConv(self.channels[2], self.channels[3])
        self.bot2 = DoubleConv(self.channels[3], self.channels[3])
        self.bot3 = DoubleConv(self.channels[3], self.channels[2])

        # Decoder: 8 -> 16 -> 32 -> 64; each Up sees skip + upsampled features.
        self.up1 = Up(2*self.channels[2], self.channels[1], gated_attention)
        self.sa4 = HeadAttention(self.channels[1], 16, heads)
        self.up2 = Up(2*self.channels[1], self.channels[0], gated_attention)
        self.sa5 = HeadAttention(self.channels[0], 32, heads)
        self.up3 = Up(2*self.channels[0], self.channels[0], gated_attention)
        self.sa6 = HeadAttention(self.channels[0], 64, heads)
        self.outc = nn.Conv2d(self.channels[0], c_out, kernel_size=1)

    def pos_encoding(self, t, channels):
        """Sinusoidal timestep embedding.

        Args:
            t: float tensor of shape (B, 1).
            channels: embedding width; expected to be even.

        Returns:
            (B, channels) tensor: ``channels // 2`` sine features followed by
            ``channels // 2`` cosine features.
        """
        # Build the frequencies on t's device rather than self.device: under
        # nn.DataParallel every replica receives t on its own GPU.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2, device=t.device).float() / channels))
        angles = t * inv_freq  # (B, 1) * (channels // 2,) -> (B, channels // 2)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, x, t):
        """Predict noise for input ``x`` at integer timesteps ``t`` (shape (B,))."""
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x2 = self.sa1(x2)
        x3 = self.down2(x2, t)
        x3 = self.sa2(x3)
        x4 = self.down3(x3, t)
        x4 = self.sa3(x4)

        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        x = self.up1(x4, x3, t)
        x = self.sa4(x)
        x = self.up2(x, x2, t)
        x = self.sa5(x)
        x = self.up3(x, x1, t)
        x = self.sa6(x)
        output = self.outc(x)
        return output

## Diffusion process and training

In [5]:
class Diffusion:
    """DDPM forward (noising) process and ancestral sampler.

    Reads ``generate_img_size``, ``noise_steps``, ``device`` and ``schedule`` from
    ``config``. ``beta``, ``alpha`` and ``alpha_hat`` are 1-D tensors of length
    ``noise_steps`` on ``device``.
    """

    def __init__(self, config):
        self.img_size = config.generate_img_size
        self.noise_steps = config.noise_steps
        self.device = config.device

        self.beta = self.prepare_noise_schedule(schedule=config.schedule, noise_steps=self.noise_steps).to(self.device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

    @staticmethod
    def prepare_noise_schedule(schedule, noise_steps, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
        """Return the per-step noise variances as a 1-D tensor of length ``noise_steps``.

        Args:
            schedule: ``"linear"`` (evenly spaced from ``linear_start`` to ``linear_end``)
                or ``"cosine"`` (offset ``cosine_s``, betas clamped to at most 0.999).

        Raises:
            ValueError: for any other ``schedule`` (previously an UnboundLocalError).
        """
        if schedule == "linear":
            return torch.linspace(linear_start, linear_end, noise_steps)

        if schedule == "cosine":
            # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
            timesteps = (torch.arange(noise_steps + 1) / noise_steps + cosine_s)
            alphas = timesteps / (1 + cosine_s) * math.pi / 2
            alphas = torch.cos(alphas).pow(2)
            alphas = alphas / alphas[0]  # normalise so the first cumulative value is 1
            betas = 1 - alphas[1:] / alphas[:-1]
            return betas.clamp(max=0.999)

        raise ValueError(f"unknown noise schedule {schedule!r}; expected 'linear' or 'cosine'")

    def noise_images(self, x, t):
        """Noise ``x`` to timesteps ``t``.

        Returns:
            ``(sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * eps, eps)`` with
            ``eps ~ N(0, I)`` shaped like ``x``.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        Ɛ = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * Ɛ, Ɛ

    def sample_timesteps(self, n):
        """Draw ``n`` integer timesteps uniformly from [1, noise_steps - 1] (0 is never drawn)."""
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    @torch.no_grad()
    def sample(self, title, model, n, epoch, condition=None):
        """Run the reverse diffusion chain, then save and plot the resulting batch.

        Args:
            title, epoch: used to name the output file ``sample{epoch}_{title}.jpg``.
            model: noise predictor, called as ``model(model_input, t)``.
            n: number of samples to draw.
            condition: optional tensor concatenated to the current sample along the
                channel axis before every model call, the same way the training loop
                builds the model input (``torch.cat([x_t, condition], dim=1)``). It must
                have batch size ``n`` and spatial size ``img_size``, i.e. it must already
                be resized the way training resizes it. A 6-input-channel conditioned
                UNet needs it; leave it as None for a 3-input-channel model.

        Returns:
            uint8 tensor of shape (n, 3, img_size, img_size), with values mapped from
            [-1, 1] (after clamping) to [0, 255].
        """
        # Remember the caller's mode: e.g. an EMA model must stay in eval mode.
        was_training = model.training
        model.eval()
        try:
            if condition is not None:
                condition = condition.to(self.device)

            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            for i in reversed(range(1, self.noise_steps)):
                t = (torch.ones(n) * i).long().to(self.device)
                model_input = x if condition is None else torch.cat([x, condition], dim=1)
                predicted_noise = model(model_input, t)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                # No fresh noise is injected on the last step.
                noise = torch.randn_like(x) if i > 1 else torch.zeros_like(x)
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise

            x = (x.clamp(-1, 1) + 1) / 2
            x = (x * 255).type(torch.uint8)

            save_images(x, f"sample{epoch}_{title}.jpg")
            plot_images(x)
        finally:
            model.train(was_training)

        return x


def _unwrap(module):
    """Return the module wrapped by ``nn.DataParallel``, or ``module`` itself."""
    return module.module if isinstance(module, nn.DataParallel) else module


def _build_scheduler(config, optimizer, dataloader, scheduler_path):
    """Return the LR scheduler requested by ``config.sched``, or None.

    Only ``"cycle"`` is handled: a pickled scheduler is restored from
    ``scheduler_path`` when one is given, otherwise a fresh ``CycleScheduler`` is built.
    """
    if config.sched != "cycle":
        return None
    if scheduler_path is not None:
        # NOTE(review): pickle.load executes arbitrary code; only load trusted files.
        with open(scheduler_path, "rb") as file:
            scheduler = pickle.load(file)
        # Unpickling also restores a detached copy of any optimizer the scheduler
        # referenced; rebind it so LR changes reach the live optimizer.
        if hasattr(scheduler, "optimizer"):
            scheduler.optimizer = optimizer
        return scheduler
    # NOTE(review): CycleScheduler is neither defined nor imported in this notebook;
    # it must be made available before training with sched="cycle".
    return CycleScheduler(
        optimizer,
        config.lr,
        n_iter=len(dataloader) * config.epochs,
        momentum=None,
        warmup_proportion=0.05,
    )


def train(config, scheduler_path = None, path = None):
    """Train the conditioned diffusion UNet, optionally resuming from a checkpoint.

    Each batch from ``get_data`` is ``(images, condition)``. ``condition`` is
    upsampled 2x and concatenated to the noisy ``images`` along the channel axis
    before being fed to the UNet.

    Args:
        config: ``ModelConfig``-like object. The data loader and model are built from
            it; when resuming, the checkpoint's saved config is then merged into it in
            place, except that the current ``config.device`` is kept.
        scheduler_path: pickled LR scheduler to resume; only used when
            ``config.sched == "cycle"``.
        path: checkpoint written by a previous run of this function, or None to
            start from scratch.

    Side effects: every second epoch, and after the final epoch, writes
    ``my_diff_e{epoch}.pth`` to the working directory. When an LR scheduler is
    active it also writes ``diff_sch_e{epoch}.pkl``.
    """
    # Read the checkpoint to CPU up front. Weights are restored into the bare model
    # before any nn.DataParallel wrapping, because checkpoints store un-prefixed state
    # dicts. Loading to CPU also means the load does not depend on the device the
    # checkpoint was written from.
    checkpoint = torch.load(path, map_location="cpu") if path is not None else None

    dataloader = get_data(config)
    print("loaded")

    model = UNet(config.device, config.batch_size, config.heads, config.gated_attention).to(config.device)
    if checkpoint is not None:
        model.load_state_dict(checkpoint['model_state_dict'])

    # The EMA copy stays unwrapped: it is only averaged into and saved here.
    ema_model = copy.deepcopy(model).eval().requires_grad_(False)
    if checkpoint is not None:
        ema_model.load_state_dict(checkpoint['ema_model_state_dict'])

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    optimizer = optim.AdamW(model.parameters(), lr=config.lr)
    mse = nn.MSELoss()

    losses = []
    start_epoch = 0
    if checkpoint is not None:
        # Keep this machine's device: the model has already been moved there.
        runtime_device = config.device
        config.__dict__.update(checkpoint["config"])
        config.device = runtime_device
        start_epoch = checkpoint['epoch']
        # Optimizer.load_state_dict moves the loaded state onto each parameter's device.
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        losses = checkpoint['loss']
        del checkpoint

    scheduler = _build_scheduler(config, optimizer, dataloader, scheduler_path)

    # Resume the EMA step counter so a resumed run does not redo the warm-up copies.
    ema = EMA(0.995, len(dataloader) * start_epoch)
    diffusion = Diffusion(config)
    scaler = GradScaler() if config.mix_precision else None
    # Built once instead of per batch; upsamples the condition 2x before concatenation.
    upsample_condition = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

    for epoch in range(start_epoch + 1, config.epochs + 1):
        epoch_loss = []
        model.train()

        for k, (images, condition) in enumerate(tqdm(dataloader)):
            optimizer.zero_grad()

            images = images.to(config.device)
            condition = upsample_condition(condition.to(config.device))

            t = diffusion.sample_timesteps(images.shape[0]).to(config.device)
            x_t, noise = diffusion.noise_images(images, t)
            # The UNet takes the noisy sample and its condition stacked on channels
            # (both precision paths now build the same input).
            conditioned_x_t = torch.cat([x_t, condition], dim=1)

            if scaler is not None:
                with autocast():
                    predicted_noise = model(conditioned_x_t, t)
                    loss = mse(noise, predicted_noise)
                scaler.scale(loss).backward()
                if scheduler is not None:
                    scheduler.step()
                scaler.step(optimizer)
                scaler.update()
            else:
                predicted_noise = model(conditioned_x_t, t)
                loss = mse(noise, predicted_noise)
                loss.backward()
                # Stepped before optimizer.step(), as in the original loop, so the
                # scheduled LR applies to this update.
                # NOTE(review): torch.optim.lr_scheduler classes expect the opposite order.
                if scheduler is not None:
                    scheduler.step()
                optimizer.step()
            ema.step_ema(ema_model, _unwrap(model))

            batch_loss = loss.item()
            if k % 100 == 0:
                print(f"e{epoch}  |  b{k}  |  MSE{batch_loss}")
            epoch_loss.append(batch_loss)

        # Stored as a plain float (not a NumPy scalar) in the checkpoint's loss history.
        epoch_loss = float(np.mean(epoch_loss))
        print(f"Epoch {epoch} loss: {epoch_loss}")
        losses.append(epoch_loss)

        # Also save after the final epoch so an odd epoch count does not lose it.
        if epoch % 2 == 0 or epoch == config.epochs:
            torch.save({
                'epoch': epoch,
                'model_state_dict': _unwrap(model).state_dict(),
                'ema_model_state_dict': _unwrap(ema_model).state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': losses,
                'config': config.__dict__,
            }, f'my_diff_e{str(epoch)}.pth')

            if scheduler is not None:
                with open(f"diff_sch_e{str(epoch)}.pkl", "wb") as file:
                    pickle.dump(scheduler, file, -1)

    
class ModelConfig:
    """Plain container for the training and sampling hyperparameters.

    Every constructor argument is stored unchanged as an attribute of the same name,
    in signature order. The default ``device`` is resolved once, when the class is
    defined (CUDA if available at that moment, otherwise CPU).
    """
    def __init__(self, batch_size=10, image_size=64, epochs=100, lr=1e-4, 
                 device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                 mix_precision = False,
                 gated_attention = False,
                 schedule = 'linear',
                 noise_steps = 1000,
                 generate_img_size = 64,
                 sched = None,
                 heads = 4,
                 dataset_path = "/kaggle/input/new-tensors-mse-00059/New tensors"):
        self.__dict__.update(
            batch_size=batch_size,
            image_size=image_size,
            epochs=epochs,
            lr=lr,
            device=device,
            mix_precision=mix_precision,
            gated_attention=gated_attention,
            schedule=schedule,
            noise_steps=noise_steps,
            generate_img_size=generate_img_size,
            sched=sched,
            heads=heads,
            dataset_path=dataset_path,
        )

def launch(path = None, sched_path = None):
    """Run ``train`` with a default ``ModelConfig``.

    Args:
        path: checkpoint to resume from, or None to start fresh.
        sched_path: pickled LR scheduler to resume, forwarded as ``scheduler_path``.
    """
    train(ModelConfig(), scheduler_path=sched_path, path=path)

In [6]:
# Build the UNet on CPU only to report its total number of parameters.
net = UNet(device="cpu", batch_size=10, heads=4, gated_attention=False)
print(sum(p.numel() for p in net.parameters()))

47597964


In [7]:
# Show the GPU(s) visible to this kernel (model, memory usage, utilisation) before training.
!nvidia-smi

Sun May  7 18:39:56 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   35C    P0    28W / 250W |      2MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [8]:
# Resume training from the saved epoch-90 checkpoint.
launch(path="/kaggle/input/ldm-64-new-conditioned/my_diff_e90.pth")

loaded


  0%|          | 0/2500 [00:00<?, ?it/s]

e91  |  b0  |  MSE0.0016600837698206306
e91  |  b100  |  MSE0.013996480032801628
e91  |  b200  |  MSE0.004101743921637535
e91  |  b300  |  MSE0.002570695709437132
e91  |  b400  |  MSE0.0024590680841356516
e91  |  b500  |  MSE0.01387505792081356
e91  |  b600  |  MSE0.00813427846878767
e91  |  b700  |  MSE0.04360262304544449
e91  |  b800  |  MSE0.004691008944064379
e91  |  b900  |  MSE0.04914599284529686
e91  |  b1000  |  MSE0.04231433942914009
e91  |  b1100  |  MSE0.009078649803996086
e91  |  b1200  |  MSE0.04605858400464058
e91  |  b1300  |  MSE0.011329641565680504
e91  |  b1400  |  MSE0.009625199250876904
e91  |  b1500  |  MSE0.03258274123072624
e91  |  b1600  |  MSE0.0831092894077301
e91  |  b1700  |  MSE0.007993682287633419
e91  |  b1800  |  MSE0.00899260863661766
e91  |  b1900  |  MSE0.010972564108669758
e91  |  b2000  |  MSE0.00965025182813406
e91  |  b2100  |  MSE0.0050200046971440315
e91  |  b2200  |  MSE0.009164582006633282
e91  |  b2300  |  MSE0.054727137088775635
e91  |  b240

  0%|          | 0/2500 [00:00<?, ?it/s]

e92  |  b0  |  MSE0.01631573960185051
e92  |  b100  |  MSE0.08069062978029251
e92  |  b200  |  MSE0.005535321310162544
e92  |  b300  |  MSE0.006869533099234104
e92  |  b400  |  MSE0.011161626316606998
e92  |  b500  |  MSE0.003999573644250631
e92  |  b600  |  MSE0.005114093888550997
e92  |  b700  |  MSE0.030538232997059822
e92  |  b800  |  MSE0.12216545641422272
e92  |  b900  |  MSE0.07401785999536514
e92  |  b1000  |  MSE0.01377419289201498
e92  |  b1100  |  MSE0.0026143412105739117
e92  |  b1200  |  MSE0.012114267796278
e92  |  b1300  |  MSE0.001726000802591443
e92  |  b1400  |  MSE0.0011142196599394083
e92  |  b1500  |  MSE0.005878890864551067
e92  |  b1600  |  MSE0.001563701662234962
e92  |  b1700  |  MSE0.03746391460299492
e92  |  b1800  |  MSE0.043976180255413055
e92  |  b1900  |  MSE0.035863786935806274
e92  |  b2000  |  MSE0.008406654000282288
e92  |  b2100  |  MSE0.09165630489587784
e92  |  b2200  |  MSE0.009387283585965633
e92  |  b2300  |  MSE0.004545432515442371
e92  |  b240

  0%|          | 0/2500 [00:00<?, ?it/s]

e93  |  b0  |  MSE0.004904171451926231
e93  |  b100  |  MSE0.057635772973299026
e93  |  b200  |  MSE0.12445355206727982
e93  |  b300  |  MSE0.005430819932371378
e93  |  b400  |  MSE0.019907329231500626
e93  |  b500  |  MSE0.009311903268098831
e93  |  b600  |  MSE0.002519810339435935
e93  |  b700  |  MSE0.023742027580738068
e93  |  b800  |  MSE0.003766910871490836
e93  |  b900  |  MSE0.0439077652990818
e93  |  b1000  |  MSE0.011690675280988216
e93  |  b1100  |  MSE0.01639682799577713
e93  |  b1200  |  MSE0.010645115748047829
e93  |  b1300  |  MSE0.005297385156154633
e93  |  b1400  |  MSE0.063979871571064
e93  |  b1500  |  MSE0.005002194084227085
e93  |  b1600  |  MSE0.0027137182187289
e93  |  b1700  |  MSE0.026743818074464798
e93  |  b1800  |  MSE0.01479863841086626
e93  |  b1900  |  MSE0.009153533726930618
e93  |  b2000  |  MSE0.06687797605991364
e93  |  b2100  |  MSE0.0012269000289961696
e93  |  b2200  |  MSE0.013020189478993416
e93  |  b2300  |  MSE0.07013069093227386
e93  |  b2400  

  0%|          | 0/2500 [00:00<?, ?it/s]

e94  |  b0  |  MSE0.010333588346838951
e94  |  b100  |  MSE0.019148031249642372
e94  |  b200  |  MSE0.005387489218264818
e94  |  b300  |  MSE0.04220897704362869
e94  |  b400  |  MSE0.11665863543748856
e94  |  b500  |  MSE0.011614315211772919
e94  |  b600  |  MSE0.007063842378556728
e94  |  b700  |  MSE0.03760123625397682
e94  |  b800  |  MSE0.013138115406036377
e94  |  b900  |  MSE0.01110171154141426
e94  |  b1000  |  MSE0.10317249596118927
e94  |  b1100  |  MSE0.017572572454810143
e94  |  b1200  |  MSE0.039016854017972946
e94  |  b1300  |  MSE0.04990663751959801
e94  |  b1400  |  MSE0.0027651567943394184
e94  |  b1500  |  MSE0.006258490961045027
e94  |  b1600  |  MSE0.038836345076560974
e94  |  b1700  |  MSE0.073277547955513
e94  |  b1800  |  MSE0.004116969183087349
e94  |  b1900  |  MSE0.016454845666885376
e94  |  b2000  |  MSE0.01587468385696411
e94  |  b2100  |  MSE0.0025086302775889635
e94  |  b2200  |  MSE0.01427058968693018
e94  |  b2300  |  MSE0.0037138930056244135
e94  |  b240

  0%|          | 0/2500 [00:00<?, ?it/s]

e95  |  b0  |  MSE0.002391395391896367
e95  |  b100  |  MSE0.005774557124823332
e95  |  b200  |  MSE0.031436264514923096
e95  |  b300  |  MSE0.009776837192475796
e95  |  b400  |  MSE0.0010712685761973262
e95  |  b500  |  MSE0.002417827257886529
e95  |  b600  |  MSE0.004604398272931576
e95  |  b700  |  MSE0.006040562409907579
e95  |  b800  |  MSE0.004440543241798878
e95  |  b900  |  MSE0.01861765794456005
e95  |  b1000  |  MSE0.0015076693380251527
e95  |  b1100  |  MSE0.005664126947522163
e95  |  b1200  |  MSE0.004344700835645199
e95  |  b1300  |  MSE0.00455993227660656
e95  |  b1400  |  MSE0.00353458384051919
e95  |  b1500  |  MSE0.04836656153202057
e95  |  b1600  |  MSE0.04776766523718834
e95  |  b1700  |  MSE0.033142343163490295
e95  |  b1800  |  MSE0.0067013828083872795
e95  |  b1900  |  MSE0.03746282309293747
e95  |  b2000  |  MSE0.04379573091864586
e95  |  b2100  |  MSE0.008409343659877777
e95  |  b2200  |  MSE0.05009991675615311
e95  |  b2300  |  MSE0.006963175255805254
e95  |  b

  0%|          | 0/2500 [00:00<?, ?it/s]

e96  |  b0  |  MSE0.015847887843847275
e96  |  b100  |  MSE0.012496777810156345
e96  |  b200  |  MSE0.011461284942924976
e96  |  b300  |  MSE0.021755022928118706
e96  |  b400  |  MSE0.0713651105761528
e96  |  b500  |  MSE0.0640735775232315
e96  |  b600  |  MSE0.022111471742391586
e96  |  b700  |  MSE0.006327719893306494
e96  |  b800  |  MSE0.03866693750023842
e96  |  b900  |  MSE0.013163959607481956
e96  |  b1000  |  MSE0.06016137823462486
e96  |  b1100  |  MSE0.015919890254735947
e96  |  b1200  |  MSE0.011107627302408218
e96  |  b1300  |  MSE0.007112967316061258
e96  |  b1400  |  MSE0.0015363091370090842
e96  |  b1500  |  MSE0.008252931758761406
e96  |  b1600  |  MSE0.008502128534018993
e96  |  b1700  |  MSE0.02289471961557865
e96  |  b1800  |  MSE0.08979067951440811
e96  |  b1900  |  MSE0.004561217036098242
e96  |  b2000  |  MSE0.0012915144907310605
e96  |  b2100  |  MSE0.02837110124528408
e96  |  b2200  |  MSE0.002412390196695924
e96  |  b2300  |  MSE0.0113490866497159
e96  |  b2400

  0%|          | 0/2500 [00:00<?, ?it/s]

e97  |  b0  |  MSE0.03319690749049187
e97  |  b100  |  MSE0.038936011493206024
e97  |  b200  |  MSE0.002226564334705472
e97  |  b300  |  MSE0.06855230778455734
e97  |  b400  |  MSE0.00375755806453526
e97  |  b500  |  MSE0.003858183976262808
e97  |  b600  |  MSE0.07793399691581726
e97  |  b700  |  MSE0.006548517383635044
e97  |  b800  |  MSE0.016921445727348328
e97  |  b900  |  MSE0.0025590830482542515
e97  |  b1000  |  MSE0.00993636529892683
e97  |  b1100  |  MSE0.009172116406261921
e97  |  b1200  |  MSE0.05770222470164299
e97  |  b1300  |  MSE0.007694622501730919
e97  |  b1400  |  MSE0.06423182040452957
e97  |  b1500  |  MSE0.03549901023507118
e97  |  b1600  |  MSE0.04904976859688759
e97  |  b1700  |  MSE0.0006023462046869099
e97  |  b1800  |  MSE0.06496770679950714
e97  |  b1900  |  MSE0.028650548309087753
e97  |  b2000  |  MSE0.041657883673906326
e97  |  b2100  |  MSE0.00744887487962842
e97  |  b2200  |  MSE0.010148005560040474
e97  |  b2300  |  MSE0.010120412334799767
e97  |  b2400

  0%|          | 0/2500 [00:00<?, ?it/s]

e98  |  b0  |  MSE0.024672450497746468
e98  |  b100  |  MSE0.00604369817301631
e98  |  b200  |  MSE0.01262713223695755
e98  |  b300  |  MSE0.01997826062142849
e98  |  b400  |  MSE0.04035003483295441
e98  |  b500  |  MSE0.015922075137495995
e98  |  b600  |  MSE0.04232373461127281
e98  |  b700  |  MSE0.021628519520163536
e98  |  b800  |  MSE0.011372942477464676
e98  |  b900  |  MSE0.014498580247163773
e98  |  b1000  |  MSE0.003439299063757062
e98  |  b1100  |  MSE0.026975581422448158
e98  |  b1200  |  MSE0.030014993622899055
e98  |  b1300  |  MSE0.015670549124479294
e98  |  b1400  |  MSE0.08163461834192276
e98  |  b1500  |  MSE0.0019107525004073977
e98  |  b1600  |  MSE0.0006006022449582815
e98  |  b1700  |  MSE0.058070264756679535
e98  |  b1800  |  MSE0.042048435658216476
e98  |  b1900  |  MSE0.007251713890582323
e98  |  b2000  |  MSE0.01822226494550705
e98  |  b2100  |  MSE0.006880205124616623
e98  |  b2200  |  MSE0.0803195908665657
e98  |  b2300  |  MSE0.0047015841118991375
e98  |  b2

  0%|          | 0/2500 [00:00<?, ?it/s]

e99  |  b0  |  MSE0.0019995432812720537
e99  |  b100  |  MSE0.015234863385558128
e99  |  b200  |  MSE0.042410146445035934
e99  |  b300  |  MSE0.013651612214744091
e99  |  b400  |  MSE0.008308667689561844
e99  |  b500  |  MSE0.056378960609436035
e99  |  b600  |  MSE0.018912531435489655
e99  |  b700  |  MSE0.021324411034584045
e99  |  b800  |  MSE0.0023425689432770014
e99  |  b900  |  MSE0.041789259761571884
e99  |  b1000  |  MSE0.040779728442430496
e99  |  b1100  |  MSE0.008366849273443222
e99  |  b1200  |  MSE0.010743838734924793
e99  |  b1300  |  MSE0.0039884187281131744
e99  |  b1400  |  MSE0.01345829013735056
e99  |  b1500  |  MSE0.053536608815193176
e99  |  b1600  |  MSE0.011358126997947693
e99  |  b1700  |  MSE0.0008698191959410906
e99  |  b1800  |  MSE0.027095012366771698
e99  |  b1900  |  MSE0.0031756998505443335
e99  |  b2000  |  MSE0.05837380513548851
e99  |  b2100  |  MSE0.013920065946877003
e99  |  b2200  |  MSE0.006082416977733374
e99  |  b2300  |  MSE0.01663113385438919
e9

  0%|          | 0/2500 [00:00<?, ?it/s]

e100  |  b0  |  MSE0.00863195862621069
e100  |  b100  |  MSE0.02022739127278328
e100  |  b200  |  MSE0.040217868983745575
e100  |  b300  |  MSE0.051056452095508575
e100  |  b400  |  MSE0.01661369577050209
e100  |  b500  |  MSE0.07109842449426651
e100  |  b600  |  MSE0.03213052451610565
e100  |  b700  |  MSE0.018043940886855125
e100  |  b800  |  MSE0.0185320395976305
e100  |  b900  |  MSE0.02219568006694317
e100  |  b1000  |  MSE0.00757207116112113
e100  |  b1100  |  MSE0.007556573022156954
e100  |  b1200  |  MSE0.021535594016313553
e100  |  b1300  |  MSE0.007548393215984106
e100  |  b1400  |  MSE0.0033771938178688288
e100  |  b1500  |  MSE0.06775306910276413
e100  |  b1600  |  MSE0.003959372639656067
e100  |  b1700  |  MSE0.0027387377340346575
e100  |  b1800  |  MSE0.006385491229593754
e100  |  b1900  |  MSE0.021538782864809036
e100  |  b2000  |  MSE0.030510157346725464
e100  |  b2100  |  MSE0.007773086428642273
e100  |  b2200  |  MSE0.10043734312057495
e100  |  b2300  |  MSE0.02468210