In [1]:
## conditional diffusion model

In [2]:
from typing import Dict, Tuple
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader,Dataset
from torchvision import models, transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
from PIL import Image
import pandas as pd

In [3]:
class Encoder(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(Encoder, self).__init__()

        self.FC_input = nn.Linear(input_dim, hidden_dim)
        self.FC_input2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_input3 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_mean  = nn.Linear(hidden_dim, latent_dim)
        self.FC_var   = nn.Linear (hidden_dim, latent_dim)
        
        self.LeakyReLU = nn.LeakyReLU(0.2)
        
        self.training = True
        
    def forward(self, x):
        h_       = self.LeakyReLU(self.FC_input(x))
        h_       = self.LeakyReLU(self.FC_input2(h_))
        h_       = self.LeakyReLU(self.FC_input3(h_))
        mean     = self.FC_mean(h_)
        log_var  = self.FC_var(h_)                     # encoder produces mean and log of variance 
                                                       #             (i.e., parateters of simple tractable normal distribution "q"
        
        return mean, log_var

class Decoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.FC_hidden = nn.Linear(latent_dim, hidden_dim)
        self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_hidden3 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_output = nn.Linear(hidden_dim, output_dim)
        
        self.LeakyReLU = nn.LeakyReLU(0.2)
        
    def forward(self, x):
        h     = self.LeakyReLU(self.FC_hidden(x))
        h     = self.LeakyReLU(self.FC_hidden2(h))
        h     = self.LeakyReLU(self.FC_hidden3(h))
        
        x_hat = torch.sigmoid(self.FC_output(h))
        return x_hat

class VAE(nn.Module):
    def __init__(self, Encoder, Decoder):
        super(Model, self).__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder
        
    def reparameterization(self, mean, var):
        epsilon = torch.randn_like(var).to(DEVICE)        # sampling epsilon        
        z = mean + var*epsilon                          # reparameterization trick
        return z
        
                
    def forward(self, x):
        mean, log_var = self.Encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var)) # takes exponential function (log var -> var)
        x_hat            = self.Decoder(z)
        
        return x_hat, mean, log_var

In [4]:
class ResidualConvBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()
        '''
        standard ResNet style convolutional block
        '''
        self.same_channels = in_channels==out_channels
        self.is_res = is_res
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.GELU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.is_res:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            # this adds on correct residual in case channels have increased
            if self.same_channels:
                out = x + x2
            else:
                out = x1 + x2 
            return out / 1.414
        else:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x2


class UnetDown(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetDown, self).__init__()
        '''
        process and downscale the image feature maps
        '''
        layers = [ResidualConvBlock(in_channels, out_channels), nn.MaxPool2d(2)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UnetUp(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetUp, self).__init__()
        '''
        process and upscale the image feature maps
        '''
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        x = torch.cat((x, skip), 1)
        x = self.model(x)
        return x


class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()
        '''
        generic one layer FC NN for embedding things  
        '''
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        return self.model(x)


In [5]:

def ddpm_schedules(beta1, beta2, T):
    """
    Returns pre-computed schedules for DDPM sampling, training process.
    """
    assert beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"

    beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32) / T + beta1
    sqrt_beta_t = torch.sqrt(beta_t)
    alpha_t = 1 - beta_t
    log_alpha_t = torch.log(alpha_t)
    alphabar_t = torch.cumsum(log_alpha_t, dim=0).exp()

    sqrtab = torch.sqrt(alphabar_t)
    oneover_sqrta = 1 / torch.sqrt(alpha_t)

    sqrtmab = torch.sqrt(1 - alphabar_t)
    mab_over_sqrtmab_inv = (1 - alpha_t) / sqrtmab

    return {
        "alpha_t": alpha_t,  # \alpha_t
        "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
        "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
        "alphabar_t": alphabar_t,  # \bar{\alpha_t}
        "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
        "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
        "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
    }

In [6]:
class ContextUnet(nn.Module):
    def __init__(self, in_channels, n_feat = 256, n_classes=10):
        super(ContextUnet, self).__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_classes = n_classes

        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)

        self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())

        self.timeembed1 = EmbedFC(1, 2*n_feat)
        self.timeembed2 = EmbedFC(1, 1*n_feat)
        self.contextembed1 = EmbedFC(n_classes, 2*n_feat)
        self.contextembed2 = EmbedFC(n_classes, 1*n_feat)

        self.up0 = nn.Sequential(
            # nn.ConvTranspose2d(6 * n_feat, 2 * n_feat, 7, 7), # when concat temb and cemb end up w 6*n_feat
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 7, 7), # otherwise just have 2*n_feat
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )

        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, c, t, context_mask):
        # x is (noisy) image, c is context label, t is timestep, 
        # context_mask says which samples to block the context on

        x = self.init_conv(x)
        down1 = self.down1(x)
        down2 = self.down2(down1)
        hiddenvec = self.to_vec(down2)

        # convert context to one hot embedding
        # c = nn.functional.one_hot(c, num_classes=self.n_classes).type(torch.float)

        ## c should be of the form <number of context * context length
        
        # mask out context if context_mask == 1
        context_mask = context_mask[:, None]
        context_mask = context_mask.repeat(1,self.n_classes)
        context_mask = (-1*(1-context_mask)) # need to flip 0 <-> 1
        c = c * context_mask
        
        # embed context, time step
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        # could concatenate the context embedding here instead of adaGN
        # hiddenvec = torch.cat((hiddenvec, temb1, cemb1), 1)

        up1 = self.up0(hiddenvec)
        # up2 = self.up1(up1, down2) # if want to avoid add and multiply embeddings
        up2 = self.up1(cemb1*up1+ temb1, down2)  # add and multiply embeddings
        up3 = self.up2(cemb2*up2+ temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        return out


In [7]:

class DDPM(nn.Module):
    def __init__(self, nn_model, betas, n_T, device, drop_prob=0.1):
        super(DDPM, self).__init__()
        self.nn_model = nn_model.to(device)

        # register_buffer allows accessing dictionary produced by ddpm_schedules
        # e.g. can access self.sqrtab later
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.n_T = n_T
        self.device = device
        self.drop_prob = drop_prob
        self.loss_mse = nn.MSELoss()

    def forward(self, x, c):
        """
        this method is used in training, so samples t and noise randomly
        """

        _ts = torch.randint(1, self.n_T+1, (x.shape[0],)).to(self.device)  # t ~ Uniform(0, n_T)
        noise = torch.randn_like(x)  # eps ~ N(0, 1)

        x_t = (
            self.sqrtab[_ts, None, None, None] * x
            + self.sqrtmab[_ts, None, None, None] * noise
        )  # This is the x_t, which is sqrt(alphabar) x_0 + sqrt(1-alphabar) * eps
        # We should predict the "error term" from this x_t. Loss is what we return.

        # dropout context with some probability
        context_mask = torch.bernoulli(torch.zeros(c.size()[0])+self.drop_prob).to(self.device)
        
        # return MSE between added noise, and our predicted noise
        return self.loss_mse(noise, self.nn_model(x_t, c, _ts / self.n_T, context_mask))

    def sample(self, c_i, n_sample, size, device, guide_w = 0.0):
        # we follow the guidance sampling scheme described in 'Classifier-Free Diffusion Guidance'
        # to make the fwd passes efficient, we concat two versions of the dataset,
        # one with context_mask=0 and the other context_mask=1
        # we then mix the outputs with the guidance scale, w
        # where w>0 means more guidance

        x_i = torch.randn(n_sample, *size).to(device)  # x_T ~ N(0, 1), sample initial noise
        # c_i = torch.arange(0,n_sample).to(device) # context for us just cycles throught the mnist labels
        # c_i = c_i.repeat(int(n_sample/c_i.shape[0]))

        ## create a copy of the context vector and append
        c_i = torch.cat((c_i, c_i), dim=0).to(device)

        # don't drop context at test time
        context_mask = torch.zeros(c_i.size()[0]).to(device)

        # double the batch
        # c_i = c_i.repeat(2)
        # context_mask = context_mask.repeat(2)
        context_mask[n_sample:] = 1. # makes second half of batch context free

        x_i_store = [] # keep track of generated steps in case want to plot something 
        print()
        for i in range(self.n_T, 0, -1):
            print(f'sampling timestep {i}',end='\r')
            t_is = torch.tensor([i / self.n_T]).to(device)
            t_is = t_is.repeat(n_sample,1,1,1)

            # double batch
            x_i = x_i.repeat(2,1,1,1)
            t_is = t_is.repeat(2,1,1,1)

            z = torch.randn(n_sample, *size).to(device) if i > 1 else 0

            # split predictions and compute weighting
            eps = self.nn_model(x_i, c_i, t_is, context_mask)
            eps1 = eps[:n_sample]
            eps2 = eps[n_sample:]
            eps = (1+guide_w)*eps1 - guide_w*eps2
            x_i = x_i[:n_sample]
            x_i = (
                self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])
                + self.sqrt_beta_t[i] * z
            )
            if i%20==0 or i==self.n_T or i<8:
                x_i_store.append(x_i.detach().cpu().numpy())
        
        x_i_store = np.array(x_i_store)
        return x_i, x_i_store

In [8]:

def train_mnist():

    # hardcoding these here
    n_epoch = 20000
    batch_size = 32
    n_T = 400 # 500
    device = "cuda"
    n_classes = 768 # text emb length
    n_feat = 256 # text emb projection layer
    lrate = 1e-4
    save_model = True
    save_dir = './diffusion_outputs/'
    ws_test = [0.5] # 0, .5, 2

    ddpm = DDPM(nn_model=ContextUnet(in_channels=1, n_feat=n_feat, n_classes=n_classes), betas=(1e-4, 0.02), n_T=n_T, device=device, drop_prob=0.1)
    ddpm.to(device)

    # ddpm.load_state_dict(torch.load("./diffusion_outputs/model.pth"))

    tf = transforms.Compose([transforms.Grayscale(), transforms.Resize((28,28)), transforms.ToTensor()]) # mnist is already normalised 0 to 1

    # dataset = MNIST("./data", train=True, download=True, transform=tf)
    dataset = CustomDataset(pd.read_csv("./captions.csv"), transform=tf)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=5)
    optim = torch.optim.Adam(ddpm.parameters(), lr=lrate)

    for ep in range(n_epoch):
        print(f'epoch {ep}')
        ddpm.train()

        # linear lrate decay
        optim.param_groups[0]['lr'] = lrate*(1-ep/n_epoch)

        pbar = tqdm(dataloader)
        loss_ema = None
        for x, c in pbar:
            optim.zero_grad()
            x = x.to(device)
            c = c.to(device)
            loss = ddpm(x, c)
            loss.backward()
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            pbar.set_description(f"loss: {loss_ema:.4f}")
            optim.step()
        
        # for eval, save an image of currently generated samples (top rows)
        # followed by real images (bottom rows)
        if ep % 1000 == 0 or ep == n_epoch-1:
            ddpm.eval()
            with torch.no_grad():
                
                for w_i, w in enumerate(ws_test):
                    context_sample = torch.cat((torch.load("./pipeline/text/55.pt").unsqueeze(0),
                                    torch.load("./pipeline/text/134.pt").unsqueeze(0),
                                    torch.load("./pipeline/text/407.pt").unsqueeze(0)), dim=0) # sample: pizza, photo, birthday
                    n_sample = context_sample.size()[0]
                    x_gen, x_gen_store = ddpm.sample(context_sample, n_sample, (1, 28, 28), device, guide_w=w)
        
                    grid = make_grid(x_gen*-1 + 1, nrow=n_sample)
                    save_image(grid, save_dir + f"image_ep{ep}_w{w}_sub.png")
                    print('saved image at ' + save_dir + f"image_ep{ep}_w{w}_sub.png")
    
        if save_model and ep == n_epoch-1:
            torch.save(ddpm.state_dict(), save_dir + f"model.pth")
            print('saved model at ' + save_dir + f"model.pth")
                

In [9]:
class CustomDataset(Dataset):

    def __init__(self, df, transform=None):
        value_counts = df['id'].value_counts()
        values_to_keep = value_counts[(value_counts >= 75) & (value_counts <= 100)].index
        df = df[df['id'].isin(values_to_keep)]
        df['id'] = pd.factorize(df['caption'])[0]
        self.df = df
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):

        # label = self.df['id'].iloc[idx]
        emb = self.df["emb"].iloc[idx]
        emb = torch.load("./pipeline/text/"+emb).squeeze()
        
        img = Image.open("./cocodata/"+str(self.df['image'].iloc[idx]))
        if self.transform:
            img = self.transform(img)

        return img, emb

In [None]:
train_mnist()

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df['id'] = pd.factorize(df['caption'])[0]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 0


loss: 0.3722: 100%|██████████| 30/30 [00:02<00:00, 14.01it/s]



sampling timestep 490

  0%|          | 0/30 [00:00<?, ?it/s]

saved image at ./diffusion_outputs/image_ep0_w0.5_sub.png
epoch 1


loss: 0.0962: 100%|██████████| 30/30 [00:01<00:00, 23.20it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 2


loss: 0.0857: 100%|██████████| 30/30 [00:01<00:00, 23.03it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 3


loss: 0.0624: 100%|██████████| 30/30 [00:01<00:00, 23.15it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 4


loss: 0.0940: 100%|██████████| 30/30 [00:01<00:00, 22.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 5


loss: 0.0595: 100%|██████████| 30/30 [00:01<00:00, 23.04it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 6


loss: 0.0624: 100%|██████████| 30/30 [00:01<00:00, 23.05it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 7


loss: 0.0655: 100%|██████████| 30/30 [00:01<00:00, 23.27it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 8


loss: 0.0596: 100%|██████████| 30/30 [00:01<00:00, 23.35it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 9


loss: 0.0638: 100%|██████████| 30/30 [00:01<00:00, 22.94it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 10


loss: 0.0531: 100%|██████████| 30/30 [00:01<00:00, 23.05it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 11


loss: 0.0625: 100%|██████████| 30/30 [00:01<00:00, 23.04it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 12


loss: 0.0679: 100%|██████████| 30/30 [00:01<00:00, 22.96it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 13


loss: 0.0600: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 14


loss: 0.0807: 100%|██████████| 30/30 [00:01<00:00, 23.11it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 15


loss: 0.0565: 100%|██████████| 30/30 [00:01<00:00, 23.33it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 16


loss: 0.0613: 100%|██████████| 30/30 [00:01<00:00, 23.13it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 17


loss: 0.0613: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 18


loss: 0.0614: 100%|██████████| 30/30 [00:01<00:00, 23.09it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 19


loss: 0.0618: 100%|██████████| 30/30 [00:01<00:00, 22.98it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 20


loss: 0.0517: 100%|██████████| 30/30 [00:01<00:00, 23.04it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 21


loss: 0.0536: 100%|██████████| 30/30 [00:01<00:00, 22.95it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 22


loss: 0.0540: 100%|██████████| 30/30 [00:01<00:00, 23.11it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 23


loss: 0.0506: 100%|██████████| 30/30 [00:01<00:00, 22.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 24


loss: 0.0582: 100%|██████████| 30/30 [00:01<00:00, 22.92it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 25


loss: 0.0621: 100%|██████████| 30/30 [00:01<00:00, 23.11it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 26


loss: 0.0550: 100%|██████████| 30/30 [00:01<00:00, 22.40it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 27


loss: 0.0581: 100%|██████████| 30/30 [00:01<00:00, 22.87it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 28


loss: 0.0519: 100%|██████████| 30/30 [00:01<00:00, 23.05it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 29


loss: 0.0587: 100%|██████████| 30/30 [00:01<00:00, 22.70it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 30


loss: 0.0567: 100%|██████████| 30/30 [00:01<00:00, 22.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 31


loss: 0.0433: 100%|██████████| 30/30 [00:01<00:00, 22.57it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 32


loss: 0.0512: 100%|██████████| 30/30 [00:01<00:00, 22.57it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 33


loss: 0.0500: 100%|██████████| 30/30 [00:01<00:00, 23.13it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 34


loss: 0.0637: 100%|██████████| 30/30 [00:01<00:00, 22.96it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 35


loss: 0.0615: 100%|██████████| 30/30 [00:01<00:00, 22.56it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 36


loss: 0.0548: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 37


loss: 0.0499: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 38


loss: 0.0462: 100%|██████████| 30/30 [00:01<00:00, 22.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 39


loss: 0.0477: 100%|██████████| 30/30 [00:01<00:00, 22.96it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 40


loss: 0.0466: 100%|██████████| 30/30 [00:01<00:00, 22.69it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 41


loss: 0.0507: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 42


loss: 0.0561: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 43


loss: 0.0560: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 44


loss: 0.0540: 100%|██████████| 30/30 [00:01<00:00, 22.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 45


loss: 0.0692: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 46


loss: 0.0646: 100%|██████████| 30/30 [00:01<00:00, 22.98it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 47


loss: 0.0471: 100%|██████████| 30/30 [00:01<00:00, 23.09it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 48


loss: 0.0547: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 49


loss: 0.0519: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 50


loss: 0.0549: 100%|██████████| 30/30 [00:01<00:00, 22.64it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 51


loss: 0.0526: 100%|██████████| 30/30 [00:01<00:00, 22.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 52


loss: 0.0478: 100%|██████████| 30/30 [00:01<00:00, 22.66it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 53


loss: 0.0530: 100%|██████████| 30/30 [00:01<00:00, 23.10it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 54


loss: 0.0460: 100%|██████████| 30/30 [00:01<00:00, 22.67it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 55


loss: 0.0485: 100%|██████████| 30/30 [00:01<00:00, 23.05it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 56


loss: 0.0492: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 57


loss: 0.0494: 100%|██████████| 30/30 [00:01<00:00, 22.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 58


loss: 0.1061: 100%|██████████| 30/30 [00:01<00:00, 22.96it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 59


loss: 0.0746: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 60


loss: 0.0601: 100%|██████████| 30/30 [00:01<00:00, 22.52it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 61


loss: 0.0534: 100%|██████████| 30/30 [00:01<00:00, 22.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 62


loss: 0.0429: 100%|██████████| 30/30 [00:01<00:00, 22.70it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 63


loss: 0.0629: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 64


loss: 0.0843: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 65


loss: 0.0588: 100%|██████████| 30/30 [00:01<00:00, 23.09it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 66


loss: 0.0455: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 67


loss: 0.0584: 100%|██████████| 30/30 [00:01<00:00, 22.61it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 68


loss: 0.0861: 100%|██████████| 30/30 [00:01<00:00, 22.69it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 69


loss: 0.0541: 100%|██████████| 30/30 [00:01<00:00, 23.05it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 70


loss: 0.0485: 100%|██████████| 30/30 [00:01<00:00, 22.55it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 71


loss: 0.0539: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 72


loss: 0.0529: 100%|██████████| 30/30 [00:01<00:00, 22.92it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 73


loss: 0.0484: 100%|██████████| 30/30 [00:01<00:00, 22.61it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 74


loss: 0.0432: 100%|██████████| 30/30 [00:01<00:00, 22.61it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 75


loss: 0.0484: 100%|██████████| 30/30 [00:01<00:00, 22.83it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 76


loss: 0.0428: 100%|██████████| 30/30 [00:01<00:00, 22.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 77


loss: 0.0472: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 78


loss: 0.0528: 100%|██████████| 30/30 [00:01<00:00, 22.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 79


loss: 0.0519: 100%|██████████| 30/30 [00:01<00:00, 22.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 80


loss: 0.0457: 100%|██████████| 30/30 [00:01<00:00, 22.55it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 81


loss: 0.0487: 100%|██████████| 30/30 [00:01<00:00, 22.83it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 82


loss: 0.0418: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 83


loss: 0.0422: 100%|██████████| 30/30 [00:01<00:00, 23.09it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 84


loss: 0.0502: 100%|██████████| 30/30 [00:01<00:00, 22.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 85


loss: 0.0473: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 86


loss: 0.0495: 100%|██████████| 30/30 [00:01<00:00, 23.05it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 87


loss: 0.0403: 100%|██████████| 30/30 [00:01<00:00, 22.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 88


loss: 0.0455: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 89


loss: 0.0522: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 90


loss: 0.0472: 100%|██████████| 30/30 [00:01<00:00, 22.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 91


loss: 0.0472: 100%|██████████| 30/30 [00:01<00:00, 22.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 92


loss: 0.0452: 100%|██████████| 30/30 [00:01<00:00, 22.49it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 93


loss: 0.0563: 100%|██████████| 30/30 [00:01<00:00, 23.05it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 94


loss: 0.0410: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 95


loss: 0.0506: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 96


loss: 0.0550: 100%|██████████| 30/30 [00:01<00:00, 22.64it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 97


loss: 0.0401: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 98


loss: 0.0450: 100%|██████████| 30/30 [00:01<00:00, 22.96it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 99


loss: 0.0518: 100%|██████████| 30/30 [00:01<00:00, 22.54it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 100


loss: 0.0428: 100%|██████████| 30/30 [00:01<00:00, 23.11it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 101


loss: 0.0432: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 102


loss: 0.0474: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 103


loss: 0.0414: 100%|██████████| 30/30 [00:01<00:00, 22.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 104


loss: 0.0417: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 105


loss: 0.0427: 100%|██████████| 30/30 [00:01<00:00, 22.99it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 106


loss: 0.0446: 100%|██████████| 30/30 [00:01<00:00, 23.11it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 107


loss: 0.0484: 100%|██████████| 30/30 [00:01<00:00, 22.77it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 108


loss: 0.0395: 100%|██████████| 30/30 [00:01<00:00, 22.63it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 109


loss: 0.0704: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 110


loss: 0.0653: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 111


loss: 0.0496: 100%|██████████| 30/30 [00:01<00:00, 23.11it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 112


loss: 0.0620: 100%|██████████| 30/30 [00:01<00:00, 22.99it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 113


loss: 0.0456: 100%|██████████| 30/30 [00:01<00:00, 22.77it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 114


loss: 0.0873: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 115


loss: 0.0503: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 116


loss: 0.0416: 100%|██████████| 30/30 [00:01<00:00, 23.04it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 117


loss: 0.0426: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 118


loss: 0.0657: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 119


loss: 0.0553: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 120


loss: 0.0401: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 121


loss: 0.0767: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 122


loss: 0.0519: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 123


loss: 0.0590: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 124


loss: 0.0410: 100%|██████████| 30/30 [00:01<00:00, 22.48it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 125


loss: 0.0433: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 126


loss: 0.0428: 100%|██████████| 30/30 [00:01<00:00, 23.01it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 127


loss: 0.0562: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 128


loss: 0.0520: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 129


loss: 0.0438: 100%|██████████| 30/30 [00:01<00:00, 23.08it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 130


loss: 0.0469: 100%|██████████| 30/30 [00:01<00:00, 22.57it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 131


loss: 0.0445: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 132


loss: 0.0401: 100%|██████████| 30/30 [00:01<00:00, 22.95it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 133


loss: 0.0475: 100%|██████████| 30/30 [00:01<00:00, 22.87it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 134


loss: 0.0477: 100%|██████████| 30/30 [00:01<00:00, 23.10it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 135


loss: 0.0500: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 136


loss: 0.0530: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 137


loss: 0.0423: 100%|██████████| 30/30 [00:01<00:00, 22.54it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 138


loss: 0.0482: 100%|██████████| 30/30 [00:01<00:00, 22.77it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 139


loss: 0.0408: 100%|██████████| 30/30 [00:01<00:00, 22.95it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 140


loss: 0.0474: 100%|██████████| 30/30 [00:01<00:00, 22.70it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 141


loss: 0.0491: 100%|██████████| 30/30 [00:01<00:00, 22.83it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 142


loss: 0.0447: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 143


loss: 0.0432: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 144


loss: 0.0411: 100%|██████████| 30/30 [00:01<00:00, 22.70it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 145


loss: 0.0576: 100%|██████████| 30/30 [00:01<00:00, 22.47it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 146


loss: 0.0430: 100%|██████████| 30/30 [00:01<00:00, 23.00it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 147


loss: 0.0451: 100%|██████████| 30/30 [00:01<00:00, 23.15it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 148


loss: 0.0413: 100%|██████████| 30/30 [00:01<00:00, 22.94it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 149


loss: 0.0452: 100%|██████████| 30/30 [00:01<00:00, 22.98it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 150


loss: 0.0530: 100%|██████████| 30/30 [00:01<00:00, 22.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 151


loss: 0.0525: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 152


loss: 0.0533: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 153


loss: 0.0435: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 154


loss: 0.0458: 100%|██████████| 30/30 [00:01<00:00, 23.08it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 155


loss: 0.0414: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 156


loss: 0.0424: 100%|██████████| 30/30 [00:01<00:00, 22.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 157


loss: 0.0453: 100%|██████████| 30/30 [00:01<00:00, 22.69it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 158


loss: 0.0399: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 159


loss: 0.0332: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 160


loss: 0.0461: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 161


loss: 0.0465: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 162


loss: 0.0498: 100%|██████████| 30/30 [00:01<00:00, 22.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 163


loss: 0.0459: 100%|██████████| 30/30 [00:01<00:00, 22.77it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 164


loss: 0.0512: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 165


loss: 0.0659: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 166


loss: 0.0500: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 167


loss: 0.0436: 100%|██████████| 30/30 [00:01<00:00, 23.06it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 168


loss: 0.0416: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 169


loss: 0.0395: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 170


loss: 0.0566: 100%|██████████| 30/30 [00:01<00:00, 23.06it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 171


loss: 0.0440: 100%|██████████| 30/30 [00:01<00:00, 22.69it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 172


loss: 0.0408: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 173


loss: 0.0496: 100%|██████████| 30/30 [00:01<00:00, 22.97it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 174


loss: 0.0444: 100%|██████████| 30/30 [00:01<00:00, 23.01it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 175


loss: 0.0444: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 176


loss: 0.0519: 100%|██████████| 30/30 [00:01<00:00, 22.77it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 177


loss: 0.0452: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 178


loss: 0.0416: 100%|██████████| 30/30 [00:01<00:00, 22.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 179


loss: 0.0430: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 180


loss: 0.0418: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 181


loss: 0.0417: 100%|██████████| 30/30 [00:01<00:00, 23.00it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 182


loss: 0.0455: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 183


loss: 0.0449: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 184


loss: 0.0443: 100%|██████████| 30/30 [00:01<00:00, 22.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 185


loss: 0.0416: 100%|██████████| 30/30 [00:01<00:00, 22.48it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 186


loss: 0.0424: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 187


loss: 0.0415: 100%|██████████| 30/30 [00:01<00:00, 23.00it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 188


loss: 0.0483: 100%|██████████| 30/30 [00:01<00:00, 22.66it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 189


loss: 0.0401: 100%|██████████| 30/30 [00:01<00:00, 22.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 190


loss: 0.0387: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 191


loss: 0.0518: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 192


loss: 0.0553: 100%|██████████| 30/30 [00:01<00:00, 22.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 193


loss: 0.0419: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 194


loss: 0.0429: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 195


loss: 0.0398: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 196


loss: 0.0479: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 197


loss: 0.0416: 100%|██████████| 30/30 [00:01<00:00, 22.50it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 198


loss: 0.0431: 100%|██████████| 30/30 [00:01<00:00, 22.70it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 199


loss: 0.0469: 100%|██████████| 30/30 [00:01<00:00, 22.77it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 200


loss: 0.0424: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 201


loss: 0.0442: 100%|██████████| 30/30 [00:01<00:00, 22.94it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 202


loss: 0.0468: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 203


loss: 0.0392: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 204


loss: 0.0478: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 205


loss: 0.0472: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 206


loss: 0.0458: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 207


loss: 0.0666: 100%|██████████| 30/30 [00:01<00:00, 22.51it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 208


loss: 0.0458: 100%|██████████| 30/30 [00:01<00:00, 22.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 209


loss: 0.0415: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 210


loss: 0.0490: 100%|██████████| 30/30 [00:01<00:00, 23.07it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 211


loss: 0.0466: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 212


loss: 0.0470: 100%|██████████| 30/30 [00:01<00:00, 22.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 213


loss: 0.0596: 100%|██████████| 30/30 [00:01<00:00, 22.45it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 214


loss: 0.0398: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 215


loss: 0.0361: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 216


loss: 0.0431: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 217


loss: 0.0430: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 218


loss: 0.0530: 100%|██████████| 30/30 [00:01<00:00, 22.65it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 219


loss: 0.0460: 100%|██████████| 30/30 [00:01<00:00, 22.57it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 220


loss: 0.0481: 100%|██████████| 30/30 [00:01<00:00, 23.05it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 221


loss: 0.0396: 100%|██████████| 30/30 [00:01<00:00, 22.55it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 222


loss: 0.0521: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 223


loss: 0.0429: 100%|██████████| 30/30 [00:01<00:00, 22.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 224


loss: 0.0489: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 225


loss: 0.0459: 100%|██████████| 30/30 [00:01<00:00, 22.87it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 226


loss: 0.0400: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 227


loss: 0.0389: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 228


loss: 0.0459: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 229


loss: 0.0370: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 230


loss: 0.0421: 100%|██████████| 30/30 [00:01<00:00, 23.03it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 231


loss: 0.0431: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 232


loss: 0.0422: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 233


loss: 0.0429: 100%|██████████| 30/30 [00:01<00:00, 22.83it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 234


loss: 0.0364: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 235


loss: 0.0401: 100%|██████████| 30/30 [00:01<00:00, 22.67it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 236


loss: 0.0458: 100%|██████████| 30/30 [00:01<00:00, 22.99it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 237


loss: 0.0442: 100%|██████████| 30/30 [00:01<00:00, 22.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 238


loss: 0.0465: 100%|██████████| 30/30 [00:01<00:00, 22.98it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 239


loss: 0.0603: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 240


loss: 0.0480: 100%|██████████| 30/30 [00:01<00:00, 22.50it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 241


loss: 0.0429: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 242


loss: 0.0412: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 243


loss: 0.0550: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 244


loss: 0.0347: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 245


loss: 0.0483: 100%|██████████| 30/30 [00:01<00:00, 22.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 246


loss: 0.0405: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 247


loss: 0.0453: 100%|██████████| 30/30 [00:01<00:00, 22.62it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 248


loss: 0.0345: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 249


loss: 0.0407: 100%|██████████| 30/30 [00:01<00:00, 22.42it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 250


loss: 0.0412: 100%|██████████| 30/30 [00:01<00:00, 22.94it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 251


loss: 0.0425: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 252


loss: 0.0479: 100%|██████████| 30/30 [00:01<00:00, 22.92it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 253


loss: 0.0390: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 254


loss: 0.0379: 100%|██████████| 30/30 [00:01<00:00, 22.96it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 255


loss: 0.0368: 100%|██████████| 30/30 [00:01<00:00, 23.04it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 256


loss: 0.0370: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 257


loss: 0.0432: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 258


loss: 0.0463: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 259


loss: 0.0385: 100%|██████████| 30/30 [00:01<00:00, 22.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 260


loss: 0.0373: 100%|██████████| 30/30 [00:01<00:00, 22.95it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 261


loss: 0.0453: 100%|██████████| 30/30 [00:01<00:00, 23.09it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 262


loss: 0.0475: 100%|██████████| 30/30 [00:01<00:00, 22.53it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 263


loss: 0.0396: 100%|██████████| 30/30 [00:01<00:00, 22.94it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 264


loss: 0.0529: 100%|██████████| 30/30 [00:01<00:00, 22.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 265


loss: 0.0343: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 266


loss: 0.0345: 100%|██████████| 30/30 [00:01<00:00, 23.09it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 267


loss: 0.0426: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 268


loss: 0.0402: 100%|██████████| 30/30 [00:01<00:00, 22.56it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 269


loss: 0.0446: 100%|██████████| 30/30 [00:01<00:00, 22.92it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 270


loss: 0.0448: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 271


loss: 0.0628: 100%|██████████| 30/30 [00:01<00:00, 22.37it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 272


loss: 0.0520: 100%|██████████| 30/30 [00:01<00:00, 22.62it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 273


loss: 0.0481: 100%|██████████| 30/30 [00:01<00:00, 22.99it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 274


loss: 0.0461: 100%|██████████| 30/30 [00:01<00:00, 22.69it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 275


loss: 0.0520: 100%|██████████| 30/30 [00:01<00:00, 22.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 276


loss: 0.0396: 100%|██████████| 30/30 [00:01<00:00, 23.11it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 277


loss: 0.0416: 100%|██████████| 30/30 [00:01<00:00, 22.94it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 278


loss: 0.0609: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 279


loss: 0.0427: 100%|██████████| 30/30 [00:01<00:00, 22.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 280


loss: 0.0482: 100%|██████████| 30/30 [00:01<00:00, 22.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 281


loss: 0.0381: 100%|██████████| 30/30 [00:01<00:00, 22.98it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 282


loss: 0.0399: 100%|██████████| 30/30 [00:01<00:00, 23.10it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 283


loss: 0.0388: 100%|██████████| 30/30 [00:01<00:00, 22.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 284


loss: 0.0398: 100%|██████████| 30/30 [00:01<00:00, 22.77it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 285


loss: 0.0371: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 286


loss: 0.0382: 100%|██████████| 30/30 [00:01<00:00, 22.97it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 287


loss: 0.0453: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 288


loss: 0.0407: 100%|██████████| 30/30 [00:01<00:00, 22.94it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 289


loss: 0.0347: 100%|██████████| 30/30 [00:01<00:00, 23.01it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 290


loss: 0.0375: 100%|██████████| 30/30 [00:01<00:00, 22.96it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 291


loss: 0.0392: 100%|██████████| 30/30 [00:01<00:00, 22.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 292


loss: 0.0390: 100%|██████████| 30/30 [00:01<00:00, 23.10it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 293


loss: 0.0382: 100%|██████████| 30/30 [00:01<00:00, 23.04it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 294


loss: 0.0477: 100%|██████████| 30/30 [00:01<00:00, 22.65it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 295


loss: 0.0419: 100%|██████████| 30/30 [00:01<00:00, 22.52it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 296


loss: 0.0493: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 297


loss: 0.0399: 100%|██████████| 30/30 [00:01<00:00, 22.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 298


loss: 0.0488: 100%|██████████| 30/30 [00:01<00:00, 22.94it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 299


loss: 0.0428: 100%|██████████| 30/30 [00:01<00:00, 22.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 300


loss: 0.0390: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 301


loss: 0.0493: 100%|██████████| 30/30 [00:01<00:00, 22.54it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 302


loss: 0.0593: 100%|██████████| 30/30 [00:01<00:00, 22.95it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 303


loss: 0.0371: 100%|██████████| 30/30 [00:01<00:00, 23.02it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 304


loss: 0.0377: 100%|██████████| 30/30 [00:01<00:00, 22.47it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 305


loss: 0.0430: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 306


loss: 0.0467: 100%|██████████| 30/30 [00:01<00:00, 23.09it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 307


loss: 0.0474: 100%|██████████| 30/30 [00:01<00:00, 22.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 308


loss: 0.0405: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 309


loss: 0.0414: 100%|██████████| 30/30 [00:01<00:00, 22.97it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 310


loss: 0.0378: 100%|██████████| 30/30 [00:01<00:00, 22.98it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 311


loss: 0.0381: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 312


loss: 0.0406: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 313


loss: 0.0364: 100%|██████████| 30/30 [00:01<00:00, 22.67it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 314


loss: 0.0361: 100%|██████████| 30/30 [00:01<00:00, 21.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 315


loss: 0.0431: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 316


loss: 0.0374: 100%|██████████| 30/30 [00:01<00:00, 21.12it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 317


loss: 0.0419: 100%|██████████| 30/30 [00:01<00:00, 20.64it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 318


loss: 0.0401: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 319


loss: 0.0425: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 320


loss: 0.0439: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 321


loss: 0.0388: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 322


loss: 0.0320: 100%|██████████| 30/30 [00:01<00:00, 22.96it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 323


loss: 0.0450: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 324


loss: 0.0553: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 325


loss: 0.0408: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 326


loss: 0.0471: 100%|██████████| 30/30 [00:01<00:00, 22.97it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 327


loss: 0.0378: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 328


loss: 0.0321: 100%|██████████| 30/30 [00:01<00:00, 22.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 329


loss: 0.0356: 100%|██████████| 30/30 [00:01<00:00, 22.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 330


loss: 0.0511: 100%|██████████| 30/30 [00:01<00:00, 23.14it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 331


loss: 0.0377: 100%|██████████| 30/30 [00:01<00:00, 22.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 332


loss: 0.0376: 100%|██████████| 30/30 [00:01<00:00, 22.94it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 333


loss: 0.0414: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 334


loss: 0.0361: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 335


loss: 0.0405: 100%|██████████| 30/30 [00:01<00:00, 23.00it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 336


loss: 0.0376: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 337


loss: 0.0494: 100%|██████████| 30/30 [00:01<00:00, 23.24it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 338


loss: 0.0405: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 339


loss: 0.0396: 100%|██████████| 30/30 [00:01<00:00, 22.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 340


loss: 0.0431: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 341


loss: 0.0391: 100%|██████████| 30/30 [00:01<00:00, 22.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 342


loss: 0.0463: 100%|██████████| 30/30 [00:01<00:00, 22.58it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 343


loss: 0.0395: 100%|██████████| 30/30 [00:01<00:00, 23.05it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 344


loss: 0.0346: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 345


loss: 0.0320: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 346


loss: 0.0456: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 347


loss: 0.0434: 100%|██████████| 30/30 [00:01<00:00, 22.69it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 348


loss: 0.0411: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 349


loss: 0.0371: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 350


loss: 0.0405: 100%|██████████| 30/30 [00:01<00:00, 22.65it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 351


loss: 0.0483: 100%|██████████| 30/30 [00:01<00:00, 22.83it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 352


loss: 0.0393: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 353


loss: 0.0412: 100%|██████████| 30/30 [00:01<00:00, 23.09it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 354


loss: 0.0428: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 355


loss: 0.0622: 100%|██████████| 30/30 [00:01<00:00, 22.94it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 356


loss: 0.0363: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 357


loss: 0.0389: 100%|██████████| 30/30 [00:01<00:00, 22.39it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 358


loss: 0.0409: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 359


loss: 0.0306: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 360


loss: 0.0322: 100%|██████████| 30/30 [00:01<00:00, 22.70it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 361


loss: 0.0354: 100%|██████████| 30/30 [00:01<00:00, 22.54it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 362


loss: 0.0427: 100%|██████████| 30/30 [00:01<00:00, 22.69it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 363


loss: 0.0332: 100%|██████████| 30/30 [00:01<00:00, 23.09it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 364


loss: 0.0424: 100%|██████████| 30/30 [00:01<00:00, 22.77it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 365


loss: 0.0388: 100%|██████████| 30/30 [00:01<00:00, 22.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 366


loss: 0.0406: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 367


loss: 0.0348: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 368


loss: 0.0373: 100%|██████████| 30/30 [00:01<00:00, 22.98it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 369


loss: 0.0382: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 370


loss: 0.0454: 100%|██████████| 30/30 [00:01<00:00, 23.09it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 371


loss: 0.0408: 100%|██████████| 30/30 [00:01<00:00, 22.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 372


loss: 0.0406: 100%|██████████| 30/30 [00:01<00:00, 22.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 373


loss: 0.0349: 100%|██████████| 30/30 [00:01<00:00, 22.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 374


loss: 0.0372: 100%|██████████| 30/30 [00:01<00:00, 23.00it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 375


loss: 0.0372: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 376


loss: 0.0336: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 377


loss: 0.0433: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 378


loss: 0.0572: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 379


loss: 0.0337: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 380


loss: 0.0461: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 381


loss: 0.0386: 100%|██████████| 30/30 [00:01<00:00, 23.16it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 382


loss: 0.0467: 100%|██████████| 30/30 [00:01<00:00, 22.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 383


loss: 0.0346: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 384


loss: 0.0407: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 385


loss: 0.0526: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 386


loss: 0.0404: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 387


loss: 0.0401: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 388


loss: 0.0383: 100%|██████████| 30/30 [00:01<00:00, 22.38it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 389


loss: 0.0441: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 390


loss: 0.0398: 100%|██████████| 30/30 [00:01<00:00, 23.08it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 391


loss: 0.0394: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 392


loss: 0.0493: 100%|██████████| 30/30 [00:01<00:00, 22.92it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 393


loss: 0.0396: 100%|██████████| 30/30 [00:01<00:00, 22.97it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 394


loss: 0.0407: 100%|██████████| 30/30 [00:01<00:00, 22.50it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 395


loss: 0.0470: 100%|██████████| 30/30 [00:01<00:00, 22.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 396


loss: 0.0349: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 397


loss: 0.0371: 100%|██████████| 30/30 [00:01<00:00, 22.56it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 398


loss: 0.0472: 100%|██████████| 30/30 [00:01<00:00, 22.53it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 399


loss: 0.0361: 100%|██████████| 30/30 [00:01<00:00, 22.58it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 400


loss: 0.0378: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 401


loss: 0.0379: 100%|██████████| 30/30 [00:01<00:00, 22.77it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 402


loss: 0.0380: 100%|██████████| 30/30 [00:01<00:00, 22.60it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 403


loss: 0.0375: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 404


loss: 0.0412: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 405


loss: 0.0390: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 406


loss: 0.0723: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 407


loss: 0.0402: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 408


loss: 0.0394: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 409


loss: 0.0376: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 410


loss: 0.0411: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 411


loss: 0.0355: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 412


loss: 0.0428: 100%|██████████| 30/30 [00:01<00:00, 22.98it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 413


loss: 0.0359: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 414


loss: 0.0445: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 415


loss: 0.0466: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 416


loss: 0.0320: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 417


loss: 0.0363: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 418


loss: 0.0368: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 419


loss: 0.0314: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 420


loss: 0.0306: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 421


loss: 0.0421: 100%|██████████| 30/30 [00:01<00:00, 22.65it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 422


loss: 0.0362: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 423


loss: 0.0378: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 424


loss: 0.0337: 100%|██████████| 30/30 [00:01<00:00, 22.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 425


loss: 0.0399: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 426


loss: 0.0408: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 427


loss: 0.0354: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 428


loss: 0.0344: 100%|██████████| 30/30 [00:01<00:00, 22.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 429


loss: 0.0376: 100%|██████████| 30/30 [00:01<00:00, 22.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 430


loss: 0.0374: 100%|██████████| 30/30 [00:01<00:00, 22.63it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 431


loss: 0.0379: 100%|██████████| 30/30 [00:01<00:00, 22.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 432


loss: 0.0361: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 433


loss: 0.0376: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 434


loss: 0.0400: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 435


loss: 0.0314: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 436


loss: 0.0330: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 437


loss: 0.0325: 100%|██████████| 30/30 [00:01<00:00, 22.45it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 438


loss: 0.0311: 100%|██████████| 30/30 [00:01<00:00, 22.98it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 439


loss: 0.0400: 100%|██████████| 30/30 [00:01<00:00, 22.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 440


loss: 0.0760: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 441


loss: 0.0395: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 442


loss: 0.0417: 100%|██████████| 30/30 [00:01<00:00, 23.00it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 443


loss: 0.0447: 100%|██████████| 30/30 [00:01<00:00, 22.77it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 444


loss: 0.0351: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 445


loss: 0.0354: 100%|██████████| 30/30 [00:01<00:00, 22.52it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 446


loss: 0.0376: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 447


loss: 0.0345: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 448


loss: 0.0346: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 449


loss: 0.0332: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 450


loss: 0.0370: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 451


loss: 0.0373: 100%|██████████| 30/30 [00:01<00:00, 22.70it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 452


loss: 0.0382: 100%|██████████| 30/30 [00:01<00:00, 22.99it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 453


loss: 0.0400: 100%|██████████| 30/30 [00:01<00:00, 22.87it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 454


loss: 0.0329: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 455


loss: 0.0348: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 456


loss: 0.0350: 100%|██████████| 30/30 [00:01<00:00, 22.62it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 457


loss: 0.0312: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 458


loss: 0.0329: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 459


loss: 0.0285: 100%|██████████| 30/30 [00:01<00:00, 22.95it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 460


loss: 0.0376: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 461


loss: 0.0354: 100%|██████████| 30/30 [00:01<00:00, 22.77it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 462


loss: 0.0357: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 463


loss: 0.0335: 100%|██████████| 30/30 [00:01<00:00, 22.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 464


loss: 0.0479: 100%|██████████| 30/30 [00:01<00:00, 22.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 465


loss: 0.0362: 100%|██████████| 30/30 [00:01<00:00, 23.00it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 466


loss: 0.0329: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 467


loss: 0.0418: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 468


loss: 0.0364: 100%|██████████| 30/30 [00:01<00:00, 22.94it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 469


loss: 0.0378: 100%|██████████| 30/30 [00:01<00:00, 22.30it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 470


loss: 0.0368: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 471


loss: 0.0286: 100%|██████████| 30/30 [00:01<00:00, 22.77it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 472


loss: 0.0365: 100%|██████████| 30/30 [00:01<00:00, 23.05it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 473


loss: 0.0366: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 474


loss: 0.0350: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 475


loss: 0.0297: 100%|██████████| 30/30 [00:01<00:00, 23.02it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 476


loss: 0.0514: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 477


loss: 0.0356: 100%|██████████| 30/30 [00:01<00:00, 22.50it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 478


loss: 0.0411: 100%|██████████| 30/30 [00:01<00:00, 23.02it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 479


loss: 0.0384: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 480


loss: 0.0336: 100%|██████████| 30/30 [00:01<00:00, 22.70it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 481


loss: 0.0396: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 482


loss: 0.0333: 100%|██████████| 30/30 [00:01<00:00, 22.65it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 483


loss: 0.0360: 100%|██████████| 30/30 [00:01<00:00, 22.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 484


loss: 0.0411: 100%|██████████| 30/30 [00:01<00:00, 23.01it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 485


loss: 0.0360: 100%|██████████| 30/30 [00:01<00:00, 22.69it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 486


loss: 0.0383: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 487


loss: 0.0353: 100%|██████████| 30/30 [00:01<00:00, 22.53it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 488


loss: 0.0410: 100%|██████████| 30/30 [00:01<00:00, 22.97it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 489


loss: 0.0370: 100%|██████████| 30/30 [00:01<00:00, 22.94it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 490


loss: 0.0321: 100%|██████████| 30/30 [00:01<00:00, 22.99it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 491


loss: 0.0438: 100%|██████████| 30/30 [00:01<00:00, 23.06it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 492


loss: 0.0411: 100%|██████████| 30/30 [00:01<00:00, 23.05it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 493


loss: 0.0414: 100%|██████████| 30/30 [00:01<00:00, 22.69it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 494


loss: 0.0380: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 495


loss: 0.0382: 100%|██████████| 30/30 [00:01<00:00, 22.97it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 496


loss: 0.0360: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 497


loss: 0.0358: 100%|██████████| 30/30 [00:01<00:00, 23.00it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 498


loss: 0.0360: 100%|██████████| 30/30 [00:01<00:00, 22.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 499


loss: 0.0374: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 500


loss: 0.0329: 100%|██████████| 30/30 [00:01<00:00, 22.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 501


loss: 0.0365: 100%|██████████| 30/30 [00:01<00:00, 22.87it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 502


loss: 0.0338: 100%|██████████| 30/30 [00:01<00:00, 23.00it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 503


loss: 0.0320: 100%|██████████| 30/30 [00:01<00:00, 22.96it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 504


loss: 0.0382: 100%|██████████| 30/30 [00:01<00:00, 22.99it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 505


loss: 0.0328: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 506


loss: 0.0346: 100%|██████████| 30/30 [00:01<00:00, 22.60it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 507


loss: 0.0354: 100%|██████████| 30/30 [00:01<00:00, 23.04it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 508


loss: 0.0398: 100%|██████████| 30/30 [00:01<00:00, 22.63it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 509


loss: 0.0325: 100%|██████████| 30/30 [00:01<00:00, 22.92it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 510


loss: 0.0354: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 511


loss: 0.0352: 100%|██████████| 30/30 [00:01<00:00, 22.64it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 512


loss: 0.0350: 100%|██████████| 30/30 [00:01<00:00, 22.64it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 513


loss: 0.0355: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 514


loss: 0.0316: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 515


loss: 0.0288: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 516


loss: 0.0343: 100%|██████████| 30/30 [00:01<00:00, 22.65it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 517


loss: 0.0361: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 518


loss: 0.0376: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 519


loss: 0.0353: 100%|██████████| 30/30 [00:01<00:00, 22.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 520


loss: 0.0298: 100%|██████████| 30/30 [00:01<00:00, 23.07it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 521


loss: 0.0419: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 522


loss: 0.0394: 100%|██████████| 30/30 [00:01<00:00, 22.92it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 523


loss: 0.0372: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 524


loss: 0.0433: 100%|██████████| 30/30 [00:01<00:00, 22.56it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 525


loss: 0.0379: 100%|██████████| 30/30 [00:01<00:00, 22.30it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 526


loss: 0.0320: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 527


loss: 0.0292: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 528


loss: 0.0342: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 529


loss: 0.0355: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 530


loss: 0.0325: 100%|██████████| 30/30 [00:01<00:00, 22.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 531


loss: 0.0362: 100%|██████████| 30/30 [00:01<00:00, 22.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 532


loss: 0.0284: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 533


loss: 0.0345: 100%|██████████| 30/30 [00:01<00:00, 23.06it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 534


loss: 0.0342: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 535


loss: 0.0379: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 536


loss: 0.0310: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 537


loss: 0.0337: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 538


loss: 0.0335: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 539


loss: 0.0330: 100%|██████████| 30/30 [00:01<00:00, 22.83it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 540


loss: 0.0831: 100%|██████████| 30/30 [00:01<00:00, 22.77it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 541


loss: 0.0450: 100%|██████████| 30/30 [00:01<00:00, 22.43it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 542


loss: 0.0348: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 543


loss: 0.0365: 100%|██████████| 30/30 [00:01<00:00, 22.66it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 544


loss: 0.0376: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 545


loss: 0.0326: 100%|██████████| 30/30 [00:01<00:00, 22.95it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 546


loss: 0.0354: 100%|██████████| 30/30 [00:01<00:00, 22.94it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 547


loss: 0.0389: 100%|██████████| 30/30 [00:01<00:00, 23.11it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 548


loss: 0.0354: 100%|██████████| 30/30 [00:01<00:00, 22.87it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 549


loss: 0.0311: 100%|██████████| 30/30 [00:01<00:00, 22.66it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 550


loss: 0.0323: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 551


loss: 0.0439: 100%|██████████| 30/30 [00:01<00:00, 22.97it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 552


loss: 0.0317: 100%|██████████| 30/30 [00:01<00:00, 22.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 553


loss: 0.0756: 100%|██████████| 30/30 [00:01<00:00, 22.92it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 554


loss: 0.0438: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 555


loss: 0.0291: 100%|██████████| 30/30 [00:01<00:00, 23.09it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 556


loss: 0.0353: 100%|██████████| 30/30 [00:01<00:00, 23.01it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 557


loss: 0.0415: 100%|██████████| 30/30 [00:01<00:00, 22.33it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 558


loss: 0.0306: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 559


loss: 0.0274: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 560


loss: 0.0265: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 561


loss: 0.0487: 100%|██████████| 30/30 [00:01<00:00, 22.95it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 562


loss: 0.0348: 100%|██████████| 30/30 [00:01<00:00, 22.98it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 563


loss: 0.0336: 100%|██████████| 30/30 [00:01<00:00, 23.02it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 564


loss: 0.0340: 100%|██████████| 30/30 [00:01<00:00, 22.39it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 565


loss: 0.0340: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 566


loss: 0.0265: 100%|██████████| 30/30 [00:01<00:00, 22.83it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 567


loss: 0.0324: 100%|██████████| 30/30 [00:01<00:00, 22.70it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 568


loss: 0.0398: 100%|██████████| 30/30 [00:01<00:00, 22.52it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 569


loss: 0.0313: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 570


loss: 0.0371: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 571


loss: 0.0462: 100%|██████████| 30/30 [00:01<00:00, 22.42it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 572


loss: 0.0356: 100%|██████████| 30/30 [00:01<00:00, 22.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 573


loss: 0.0338: 100%|██████████| 30/30 [00:01<00:00, 23.01it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 574


loss: 0.0371: 100%|██████████| 30/30 [00:01<00:00, 22.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 575


loss: 0.0434: 100%|██████████| 30/30 [00:01<00:00, 22.97it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 576


loss: 0.0329: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 577


loss: 0.0349: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 578


loss: 0.0332: 100%|██████████| 30/30 [00:01<00:00, 22.59it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 579


loss: 0.0293: 100%|██████████| 30/30 [00:01<00:00, 22.96it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 580


loss: 0.0376: 100%|██████████| 30/30 [00:01<00:00, 22.97it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 581


loss: 0.0324: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 582


loss: 0.0338: 100%|██████████| 30/30 [00:01<00:00, 22.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 583


loss: 0.0406: 100%|██████████| 30/30 [00:01<00:00, 22.95it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 584


loss: 0.0364: 100%|██████████| 30/30 [00:01<00:00, 22.66it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 585


loss: 0.0364: 100%|██████████| 30/30 [00:01<00:00, 22.95it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 586


loss: 0.0321: 100%|██████████| 30/30 [00:01<00:00, 22.77it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 587


loss: 0.0828: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 588


loss: 0.0320: 100%|██████████| 30/30 [00:01<00:00, 22.70it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 589


loss: 0.0364: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 590


loss: 0.0423: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 591


loss: 0.0356: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 592


loss: 0.0365: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 593


loss: 0.0381: 100%|██████████| 30/30 [00:01<00:00, 22.67it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 594


loss: 0.0324: 100%|██████████| 30/30 [00:01<00:00, 22.66it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 595


loss: 0.0297: 100%|██████████| 30/30 [00:01<00:00, 23.09it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 596


loss: 0.0318: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 597


loss: 0.0315: 100%|██████████| 30/30 [00:01<00:00, 22.99it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 598


loss: 0.0324: 100%|██████████| 30/30 [00:01<00:00, 22.70it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 599


loss: 0.0325: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 600


loss: 0.0316: 100%|██████████| 30/30 [00:01<00:00, 22.94it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 601


loss: 0.0350: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 602


loss: 0.0316: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 603


loss: 0.0314: 100%|██████████| 30/30 [00:01<00:00, 23.00it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 604


loss: 0.0257: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 605


loss: 0.0286: 100%|██████████| 30/30 [00:01<00:00, 22.61it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 606


loss: 0.0302: 100%|██████████| 30/30 [00:01<00:00, 22.59it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 607


loss: 0.0352: 100%|██████████| 30/30 [00:01<00:00, 22.92it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 608


loss: 0.0297: 100%|██████████| 30/30 [00:01<00:00, 22.66it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 609


loss: 0.0295: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 610


loss: 0.0303: 100%|██████████| 30/30 [00:01<00:00, 22.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 611


loss: 0.0357: 100%|██████████| 30/30 [00:01<00:00, 22.67it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 612


loss: 0.0304: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 613


loss: 0.0297: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 614


loss: 0.0291: 100%|██████████| 30/30 [00:01<00:00, 23.00it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 615


loss: 0.0354: 100%|██████████| 30/30 [00:01<00:00, 22.83it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 616


loss: 0.0246: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 617


loss: 0.0312: 100%|██████████| 30/30 [00:01<00:00, 22.56it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 618


loss: 0.0338: 100%|██████████| 30/30 [00:01<00:00, 22.83it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 619


loss: 0.0297: 100%|██████████| 30/30 [00:01<00:00, 22.96it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 620


loss: 0.0397: 100%|██████████| 30/30 [00:01<00:00, 22.83it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 621


loss: 0.0324: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 622


loss: 0.0313: 100%|██████████| 30/30 [00:01<00:00, 22.63it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 623


loss: 0.0299: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 624


loss: 0.0348: 100%|██████████| 30/30 [00:01<00:00, 22.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 625


loss: 0.0309: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 626


loss: 0.0383: 100%|██████████| 30/30 [00:01<00:00, 22.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 627


loss: 0.0332: 100%|██████████| 30/30 [00:01<00:00, 22.64it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 628


loss: 0.0282: 100%|██████████| 30/30 [00:01<00:00, 22.43it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 629


loss: 0.0281: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 630


loss: 0.0279: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 631


loss: 0.0287: 100%|██████████| 30/30 [00:01<00:00, 22.46it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 632


loss: 0.0247: 100%|██████████| 30/30 [00:01<00:00, 22.91it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 633


loss: 0.0267: 100%|██████████| 30/30 [00:01<00:00, 22.96it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 634


loss: 0.0282: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 635


loss: 0.0331: 100%|██████████| 30/30 [00:01<00:00, 22.98it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 636


loss: 0.0491: 100%|██████████| 30/30 [00:01<00:00, 22.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 637


loss: 0.0286: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 638


loss: 0.0281: 100%|██████████| 30/30 [00:01<00:00, 22.96it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 639


loss: 0.0300: 100%|██████████| 30/30 [00:01<00:00, 22.69it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 640


loss: 0.0304: 100%|██████████| 30/30 [00:01<00:00, 22.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 641


loss: 0.0284: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 642


loss: 0.0289: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 643


loss: 0.0401: 100%|██████████| 30/30 [00:01<00:00, 22.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 644


loss: 0.0373: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 645


loss: 0.0575: 100%|██████████| 30/30 [00:01<00:00, 22.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 646


loss: 0.0286: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 647


loss: 0.0295: 100%|██████████| 30/30 [00:01<00:00, 22.70it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 648


loss: 0.0272: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 649


loss: 0.0347: 100%|██████████| 30/30 [00:01<00:00, 23.15it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 650


loss: 0.0423: 100%|██████████| 30/30 [00:01<00:00, 22.63it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 651


loss: 0.0285: 100%|██████████| 30/30 [00:01<00:00, 22.49it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 652


loss: 0.0299: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 653


loss: 0.0302: 100%|██████████| 30/30 [00:01<00:00, 22.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 654


loss: 0.0335: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 655


loss: 0.0302: 100%|██████████| 30/30 [00:01<00:00, 22.97it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 656


loss: 0.0355: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 657


loss: 0.0388: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 658


loss: 0.0279: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 659


loss: 0.0360: 100%|██████████| 30/30 [00:01<00:00, 22.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 660


loss: 0.0318: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 661


loss: 0.0326: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 662


loss: 0.0281: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 663


loss: 0.0281: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 664


loss: 0.0352: 100%|██████████| 30/30 [00:01<00:00, 22.62it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 665


loss: 0.0318: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 666


loss: 0.0301: 100%|██████████| 30/30 [00:01<00:00, 22.56it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 667


loss: 0.0287: 100%|██████████| 30/30 [00:01<00:00, 23.14it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 668


loss: 0.0266: 100%|██████████| 30/30 [00:01<00:00, 22.67it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 669


loss: 0.0238: 100%|██████████| 30/30 [00:01<00:00, 22.67it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 670


loss: 0.0269: 100%|██████████| 30/30 [00:01<00:00, 22.83it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 671


loss: 0.0244: 100%|██████████| 30/30 [00:01<00:00, 22.70it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 672


loss: 0.0378: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 673


loss: 0.0338: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 674


loss: 0.0357: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 675


loss: 0.0306: 100%|██████████| 30/30 [00:01<00:00, 22.89it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 676


loss: 0.0273: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 677


loss: 0.0302: 100%|██████████| 30/30 [00:01<00:00, 22.83it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 678


loss: 0.0329: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 679


loss: 0.0289: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 680


loss: 0.0282: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 681


loss: 0.0277: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 682


loss: 0.0256: 100%|██████████| 30/30 [00:01<00:00, 23.01it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 683


loss: 0.0252: 100%|██████████| 30/30 [00:01<00:00, 23.03it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 684


loss: 0.0278: 100%|██████████| 30/30 [00:01<00:00, 22.52it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 685


loss: 0.0302: 100%|██████████| 30/30 [00:01<00:00, 22.66it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 686


loss: 0.0366: 100%|██████████| 30/30 [00:01<00:00, 22.86it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 687


loss: 0.0303: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 688


loss: 0.0290: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 689


loss: 0.0308: 100%|██████████| 30/30 [00:01<00:00, 23.10it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 690


loss: 0.0309: 100%|██████████| 30/30 [00:01<00:00, 22.83it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 691


loss: 0.0379: 100%|██████████| 30/30 [00:01<00:00, 22.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 692


loss: 0.0257: 100%|██████████| 30/30 [00:01<00:00, 23.05it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 693


loss: 0.0274: 100%|██████████| 30/30 [00:01<00:00, 22.81it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 694


loss: 0.0303: 100%|██████████| 30/30 [00:01<00:00, 22.87it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 695


loss: 0.0294: 100%|██████████| 30/30 [00:01<00:00, 22.97it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 696


loss: 0.0336: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 697


loss: 0.0320: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 698


loss: 0.0234: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 699


loss: 0.0267: 100%|██████████| 30/30 [00:01<00:00, 22.56it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 700


loss: 0.0288: 100%|██████████| 30/30 [00:01<00:00, 22.77it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 701


loss: 0.0280: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 702


loss: 0.0283: 100%|██████████| 30/30 [00:01<00:00, 22.50it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 703


loss: 0.0255: 100%|██████████| 30/30 [00:01<00:00, 23.22it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 704


loss: 0.0302: 100%|██████████| 30/30 [00:01<00:00, 23.08it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 705


loss: 0.0389: 100%|██████████| 30/30 [00:01<00:00, 22.97it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 706


loss: 0.0297: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 707


loss: 0.0277: 100%|██████████| 30/30 [00:01<00:00, 22.67it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 708


loss: 0.0301: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 709


loss: 0.0330: 100%|██████████| 30/30 [00:01<00:00, 22.92it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 710


loss: 0.0302: 100%|██████████| 30/30 [00:01<00:00, 22.70it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 711


loss: 0.0291: 100%|██████████| 30/30 [00:01<00:00, 22.41it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 712


loss: 0.0292: 100%|██████████| 30/30 [00:01<00:00, 22.74it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 713


loss: 0.0284: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 714


loss: 0.0363: 100%|██████████| 30/30 [00:01<00:00, 22.73it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 715


loss: 0.0272: 100%|██████████| 30/30 [00:01<00:00, 22.93it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 716


loss: 0.0291: 100%|██████████| 30/30 [00:01<00:00, 22.88it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 717


loss: 0.0263: 100%|██████████| 30/30 [00:01<00:00, 22.64it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 718


loss: 0.0253: 100%|██████████| 30/30 [00:01<00:00, 22.77it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 719


loss: 0.0235: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 720


loss: 0.0220: 100%|██████████| 30/30 [00:01<00:00, 22.83it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 721


loss: 0.0233: 100%|██████████| 30/30 [00:01<00:00, 23.04it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 722


loss: 0.0259: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 723


loss: 0.0241: 100%|██████████| 30/30 [00:01<00:00, 22.87it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 724


loss: 0.0301: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 725


loss: 0.0294: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 726


loss: 0.0256: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 727


loss: 0.0262: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 728


loss: 0.0289: 100%|██████████| 30/30 [00:01<00:00, 22.96it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 729


loss: 0.0247: 100%|██████████| 30/30 [00:01<00:00, 22.71it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 730


loss: 0.0259: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 731


loss: 0.0269: 100%|██████████| 30/30 [00:01<00:00, 22.97it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 732


loss: 0.0254: 100%|██████████| 30/30 [00:01<00:00, 22.90it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 733


loss: 0.0287: 100%|██████████| 30/30 [00:01<00:00, 22.98it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 734


loss: 0.0267: 100%|██████████| 30/30 [00:01<00:00, 22.47it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 735


loss: 0.0260: 100%|██████████| 30/30 [00:01<00:00, 22.84it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 736


loss: 0.0294: 100%|██████████| 30/30 [00:01<00:00, 22.75it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 737


loss: 0.0269: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 738


loss: 0.0278: 100%|██████████| 30/30 [00:01<00:00, 22.96it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 739


loss: 0.0296: 100%|██████████| 30/30 [00:01<00:00, 22.58it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 740


loss: 0.0268: 100%|██████████| 30/30 [00:01<00:00, 22.68it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 741


loss: 0.0265: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 742


loss: 0.0242: 100%|██████████| 30/30 [00:01<00:00, 22.79it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 743


loss: 0.0335: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 744


loss: 0.0510: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 745


loss: 0.0249: 100%|██████████| 30/30 [00:01<00:00, 22.85it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 746


loss: 0.0455: 100%|██████████| 30/30 [00:01<00:00, 23.01it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 747


loss: 0.0302: 100%|██████████| 30/30 [00:01<00:00, 22.82it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 748


loss: 0.0292: 100%|██████████| 30/30 [00:01<00:00, 22.78it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 749


loss: 0.0293: 100%|██████████| 30/30 [00:01<00:00, 23.09it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 750


loss: 0.0246: 100%|██████████| 30/30 [00:01<00:00, 22.87it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 751


loss: 0.0273: 100%|██████████| 30/30 [00:01<00:00, 22.98it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 752


loss: 0.0361: 100%|██████████| 30/30 [00:01<00:00, 22.76it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 753


loss: 0.0324: 100%|██████████| 30/30 [00:01<00:00, 22.80it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 754


loss: 0.0285: 100%|██████████| 30/30 [00:01<00:00, 22.69it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 755


loss: 0.0255: 100%|██████████| 30/30 [00:01<00:00, 22.52it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 756


loss: 0.0303: 100%|██████████| 30/30 [00:01<00:00, 22.72it/s]
  0%|          | 0/30 [00:00<?, ?it/s]

epoch 757


loss: 0.0292:  63%|██████▎   | 19/30 [00:00<00:00, 24.12it/s]

In [None]:
def inference():
    n_epoch = 10000
    batch_size = 32
    n_T = 400 # 500
    device = "cuda"
    n_classes = 768 # text emb length
    n_feat = 256 # text emb projection layer
    lrate = 1e-4
    save_model = True
    
    save_dir = './diffusion_outputs/'
    ws_test = [0.5] # 0, .5, 2

    ddpm = DDPM(nn_model=ContextUnet(in_channels=1, n_feat=n_feat, n_classes=n_classes), betas=(1e-4, 0.02), n_T=n_T, device=device, drop_prob=0.1)
    ddpm.to(device)
    ddpm.load_state_dict(torch.load("./diffusion_outputs/model.pth"))
    context_sample = torch.cat((torch.load("./pipeline/text/55.pt").unsqueeze(0),
                                    torch.load("./pipeline/text/134.pt").unsqueeze(0),
                                    torch.load("./pipeline/text/407.pt").unsqueeze(0),
                               torch.zeros((1,768)),
                               torch.zeros((1,768)),
                               torch.zeros((1,768))),dim=0)
    n_sample = context_sample.size()[0]
    with torch.no_grad():
        x_gen, x_gen_store = ddpm.sample(context_sample, n_sample, (1, 28, 28), device, guide_w=0.5)

    grid = make_grid(x_gen*-1 + 1, nrow=n_sample)
    save_image(grid, save_dir + f"inference_sub.png")
    print('saved image at ' + save_dir + f"inference_sub.png")

In [None]:
inference()