# Music 103 diffusion version

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import numpy as np
import copy
import pandas as pd
from tqdm import tqdm
from os.path import exists
from os import remove, chdir
import pickle

In [2]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output
    
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))
    
class PositionalEncoding(nn.Module):
    def __init__(self, d_model):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model

    def forward(self, x):
        max_len = x.size(1)
        pe = torch.zeros(max_len, self.d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * -(math.log(10000.0) / self.d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0).to(x.device)
        return x + pe

    
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

class EmbedHead(nn.Module):
    def __init__(
        self,
        input_dim,
        inner_dim_1,
        inner_dim_2,
        out_dim
    ):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, inner_dim_1)
        self.linear2 = nn.Linear(inner_dim_1, inner_dim_2)
        self.linear3 = nn.Linear(inner_dim_2, out_dim)
        self.activation_fn = nn.functional.gelu

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation_fn(x)
        x = self.linear2(x)
        x = self.activation_fn(x)
        x = self.linear3(x)
        return x
    

class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        return self.model(x)


class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super().__init__()
        self.encoder_embedding = EmbedHead(src_vocab_size + tgt_vocab_size, d_model, d_model, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.time_embeddings = nn.ModuleList([EmbedFC(1, d_model) for _ in range(num_layers)])

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.output_layer = nn.Sequential(PositionWiseFeedForward(d_model, d_ff), nn.Linear(d_model, tgt_vocab_size))
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, x, time):
        embedding = self.dropout(self.positional_encoding(self.encoder_embedding(torch.cat([src, x], dim=-1))))

        for i, enc_layer in enumerate(self.encoder_layers):
            time_embedding = self.time_embeddings[i](time).unsqueeze(1)
            embedding = enc_layer(embedding + time_embedding, None)
        
        output = self.output_layer(embedding)
        
        return output
    


In [3]:
class VQVAE(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, dropout, codebook_size, d_codebook):
        super().__init__()
        self.encoder_embedding = EmbedHead(vocab_size, d_model, d_model, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.encoder_output = nn.Linear(d_model, d_codebook)
        self.codebook = nn.Embedding(codebook_size, d_codebook)
        self.codebook.weight.data.uniform_(-1/d_codebook, 1/d_codebook)
        self.decoder_embedding = EmbedHead(d_codebook, d_model, d_model, d_model)
        self.decoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_output = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def encode(self, x):
        embedding = self.dropout(self.positional_encoding(self.encoder_embedding(x)))
        for i, enc_layer in enumerate(self.encoder_layers):
            embedding = enc_layer(embedding, None)
        return self.encoder_output(embedding)
    
    def vq(self, z):
        # z: [batch_size, seq_length, d_codebook]
        distance = (z.unsqueeze(2) - self.codebook.weight.unsqueeze(0).unsqueeze(0)).pow(2).mean(dim=-1)
        _, indices = torch.min(distance, dim=-1)
        # print(indices)
        return self.codebook(indices)

    def decode(self, z):
        embedding = self.dropout(self.positional_encoding(self.decoder_embedding(z)))
        for i, dec_layer in enumerate(self.decoder_layers):
            embedding = dec_layer(embedding, None)
        return torch.sigmoid(self.decoder_output(embedding))
    
    def forward(self, x):
        # x: [batch_size, seq_length, vocab_size]
        z = self.encode(x)
        z_vq = self.vq(z)
        z_straight_through = (z_vq - z).detach() + z
        x_recon = self.decode(z_straight_through)
        recon_loss = nn.functional.binary_cross_entropy(x_recon, x)
        embed_loss = nn.functional.mse_loss(z_vq, z.detach())
        commit_loss = nn.functional.mse_loss(z, z_vq.detach())
        return x_recon, recon_loss, embed_loss, commit_loss


In [4]:
from tqdm import tqdm

def train_VQVAE(vqvae, optim, trainset, validset, lr, n_epoch, device, patience, model_path, alpha=0.5, beta=1):
    wait = 0
    min_valid_loss = float('inf')
    for ep in tqdm(range(n_epoch)):
        vqvae.train()

        # linear lrate decay
        optim.param_groups[0]['lr'] = lr*(1-ep/n_epoch)
        loss_ema = None
        # train
        for idx, src, tgt in trainset:
            optim.zero_grad()
            tgt = tgt.to(device)
            src = src.to(device)
            _, recon_loss, embed_loss, commit_loss = vqvae(tgt)
            loss = recon_loss + beta * embed_loss + alpha * commit_loss
            loss.backward()
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            optim.step()
            
        # validation
        vqvae.eval()
        total_loss = 0
        with torch.no_grad():
            for idx, src, tgt in validset:
                tgt = tgt.to(device)
                src = src.to(device)
                _, recon_loss, embed_loss, commit_loss = vqvae(tgt)
                loss = recon_loss
                total_loss += loss.item()
        avg_valid_loss = total_loss / len(validset)

        # early stopping
        if avg_valid_loss < min_valid_loss:
            min_valid_loss = avg_valid_loss
            torch.save(vqvae.state_dict(), model_path)
            print(f'epoch {ep}, train_loss: {loss_ema:.4f}, valid loss: {avg_valid_loss:.4f}')
            wait = 0
        else:
            print(f'epoch {ep}, train_loss: {loss_ema:.4f}, valid loss: {avg_valid_loss:.4f}, min_valid_loss: {min_valid_loss:.4f}, wait: {wait} / {patience}')
            wait += 1
        if wait >= patience:
            break

In [5]:
# hardcoding these here
n_epoch = 1000
n_T = 1000
n_feat = 128
lr = 1e-4
ws_test = [0.0, 0.5, 2.0]
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

src_vocab_size = 12
tgt_vocab_size = 12
d_model = 512
num_heads = 8
num_layers = 4
d_ff = 4096//8
max_seq_length = 2400
dropout = 0.1
batchsize = 16
mode = "train"


if exists("trainset_w.pkl") and exists("validset_w.pkl") and exists("testset_w.pkl"):
    print("splitted dataset found!")
    with open("trainset_w.pkl", "rb") as f:
        trainset = pickle.load(f)
    with open("validset_w.pkl", "rb") as f:
        validset = pickle.load(f)
    with open("testset_w.pkl", "rb") as f:
        testset = pickle.load(f)
else:
    print("?")

def collate_fn(batch):
    # Unpack batch into individual components
    idx, src_data, tgt_data, w = zip(*batch)
    #print(len(rates[0]), len(tgt_data[0]), len(src_data[0]))
    
    # Convert `src_data`, `tgt_data`, and `rates` to tensors if they are not already
    src_data = [torch.tensor(s, dtype=torch.float32) if not isinstance(s, torch.Tensor) else s for s in src_data]
    tgt_data = [torch.tensor(t, dtype=torch.float32) if not isinstance(t, torch.Tensor) else t for t in tgt_data]

    src_data = [torch.cat([s], dim=-1) for s in src_data]
    tgt_data = [torch.cat([t], dim=-1) for t in tgt_data]

    # Pad src_data
    src_data = nn.utils.rnn.pad_sequence(src_data, batch_first=True, padding_value=0.).to(DEVICE)

    # Pad tgt_data
    tgt_data = nn.utils.rnn.pad_sequence(tgt_data, batch_first=True, padding_value=0).to(DEVICE)

    # Extract the last dimension and one-hot encode it
    return idx, src_data, tgt_data


trainset = data.DataLoader(trainset, batch_size=batchsize, collate_fn=collate_fn)
validset = data.DataLoader(validset, batch_size=1, collate_fn=collate_fn)
testset = data.DataLoader(testset, batch_size=1, collate_fn=collate_fn)

splitted dataset found!


In [11]:
lr = 1e-3
vqvae = VQVAE(tgt_vocab_size, 512, num_heads, 1, d_ff, dropout, 128, 12).to(DEVICE)
optim = torch.optim.Adam(vqvae.parameters(), lr=lr)
train_VQVAE(vqvae, optim, trainset, validset, lr, n_epoch, DEVICE, 20, "model_best_vqvae_128.pt")

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 1/1000 [00:02<38:58,  2.34s/it]

epoch 0, train_loss: 0.5377, valid loss: 0.5933


  0%|          | 2/1000 [00:04<38:51,  2.34s/it]

epoch 1, train_loss: 0.8322, valid loss: 0.4648


  0%|          | 3/1000 [00:07<38:53,  2.34s/it]

epoch 2, train_loss: 0.6637, valid loss: 0.4245


  0%|          | 4/1000 [00:09<38:47,  2.34s/it]

epoch 3, train_loss: 0.7065, valid loss: 0.3715


  0%|          | 5/1000 [00:11<38:40,  2.33s/it]

epoch 4, train_loss: 0.5143, valid loss: 0.3536


  1%|          | 6/1000 [00:13<38:34,  2.33s/it]

epoch 5, train_loss: 0.4842, valid loss: 0.3082


  1%|          | 7/1000 [00:16<38:30,  2.33s/it]

epoch 6, train_loss: 0.3760, valid loss: 0.3610, min_valid_loss: 0.3082, wait: 0 / 20


  1%|          | 8/1000 [00:18<38:25,  2.32s/it]

epoch 7, train_loss: 0.3826, valid loss: 0.3064


  1%|          | 9/1000 [00:20<38:21,  2.32s/it]

epoch 8, train_loss: 0.3057, valid loss: 0.2831


  1%|          | 10/1000 [00:23<38:23,  2.33s/it]

epoch 9, train_loss: 0.3264, valid loss: 0.2871, min_valid_loss: 0.2831, wait: 0 / 20


  1%|          | 11/1000 [00:25<38:21,  2.33s/it]

epoch 10, train_loss: 0.2794, valid loss: 0.2352


  1%|          | 12/1000 [00:27<38:17,  2.33s/it]

epoch 11, train_loss: 0.2392, valid loss: 0.2494, min_valid_loss: 0.2352, wait: 0 / 20


  1%|▏         | 13/1000 [00:30<38:12,  2.32s/it]

epoch 12, train_loss: 0.2022, valid loss: 0.2151


  1%|▏         | 14/1000 [00:32<38:09,  2.32s/it]

epoch 13, train_loss: 0.1828, valid loss: 0.2024


  2%|▏         | 15/1000 [00:34<38:02,  2.32s/it]

epoch 14, train_loss: 0.1793, valid loss: 0.1921


  2%|▏         | 16/1000 [00:37<37:41,  2.30s/it]

epoch 15, train_loss: 0.1560, valid loss: 0.1981, min_valid_loss: 0.1921, wait: 0 / 20


  2%|▏         | 17/1000 [00:39<37:35,  2.29s/it]

epoch 16, train_loss: 0.1519, valid loss: 0.1718


  2%|▏         | 18/1000 [00:41<37:30,  2.29s/it]

epoch 17, train_loss: 0.1329, valid loss: 0.1504


  2%|▏         | 19/1000 [00:43<37:24,  2.29s/it]

epoch 18, train_loss: 0.1224, valid loss: 0.1453


  2%|▏         | 20/1000 [00:46<37:18,  2.28s/it]

epoch 19, train_loss: 0.1132, valid loss: 0.1341


  2%|▏         | 21/1000 [00:48<37:25,  2.29s/it]

epoch 20, train_loss: 0.1060, valid loss: 0.1222


  2%|▏         | 22/1000 [00:50<37:11,  2.28s/it]

epoch 21, train_loss: 0.1066, valid loss: 0.1301, min_valid_loss: 0.1222, wait: 0 / 20


  2%|▏         | 23/1000 [00:53<37:11,  2.28s/it]

epoch 22, train_loss: 0.1046, valid loss: 0.1094


  2%|▏         | 24/1000 [00:55<37:01,  2.28s/it]

epoch 23, train_loss: 0.0928, valid loss: 0.1094, min_valid_loss: 0.1094, wait: 0 / 20


  2%|▎         | 25/1000 [00:57<36:59,  2.28s/it]

epoch 24, train_loss: 0.0872, valid loss: 0.1067


  3%|▎         | 26/1000 [00:59<37:00,  2.28s/it]

epoch 25, train_loss: 0.0833, valid loss: 0.1038


  3%|▎         | 27/1000 [01:02<36:59,  2.28s/it]

epoch 26, train_loss: 0.0816, valid loss: 0.0971


  3%|▎         | 28/1000 [01:04<36:57,  2.28s/it]

epoch 27, train_loss: 0.0728, valid loss: 0.0942


  3%|▎         | 29/1000 [01:06<36:52,  2.28s/it]

epoch 28, train_loss: 0.0710, valid loss: 0.0977, min_valid_loss: 0.0942, wait: 0 / 20


  3%|▎         | 30/1000 [01:09<36:56,  2.29s/it]

epoch 29, train_loss: 0.0708, valid loss: 0.0900


  3%|▎         | 31/1000 [01:11<36:44,  2.27s/it]

epoch 30, train_loss: 0.0778, valid loss: 0.1082, min_valid_loss: 0.0900, wait: 0 / 20


  3%|▎         | 32/1000 [01:13<36:36,  2.27s/it]

epoch 31, train_loss: 0.0772, valid loss: 0.0902, min_valid_loss: 0.0900, wait: 1 / 20


  3%|▎         | 33/1000 [01:15<36:30,  2.27s/it]

epoch 32, train_loss: 0.0717, valid loss: 0.0912, min_valid_loss: 0.0900, wait: 2 / 20


  3%|▎         | 34/1000 [01:18<36:33,  2.27s/it]

epoch 33, train_loss: 0.0688, valid loss: 0.0872


  4%|▎         | 35/1000 [01:20<36:33,  2.27s/it]

epoch 34, train_loss: 0.0666, valid loss: 0.0873, min_valid_loss: 0.0872, wait: 0 / 20


  4%|▎         | 36/1000 [01:22<36:54,  2.30s/it]

epoch 35, train_loss: 0.0656, valid loss: 0.0817


  4%|▎         | 37/1000 [01:25<36:41,  2.29s/it]

epoch 36, train_loss: 0.0618, valid loss: 0.0821, min_valid_loss: 0.0817, wait: 0 / 20


  4%|▍         | 38/1000 [01:27<36:39,  2.29s/it]

epoch 37, train_loss: 0.0583, valid loss: 0.0744


  4%|▍         | 39/1000 [01:29<36:28,  2.28s/it]

epoch 38, train_loss: 0.0558, valid loss: 0.0745, min_valid_loss: 0.0744, wait: 0 / 20


  4%|▍         | 40/1000 [01:31<36:20,  2.27s/it]

epoch 39, train_loss: 0.0580, valid loss: 0.0763, min_valid_loss: 0.0744, wait: 1 / 20


  4%|▍         | 41/1000 [01:34<36:12,  2.27s/it]

epoch 40, train_loss: 0.0588, valid loss: 0.0751, min_valid_loss: 0.0744, wait: 2 / 20


  4%|▍         | 42/1000 [01:36<36:14,  2.27s/it]

epoch 41, train_loss: 0.0561, valid loss: 0.0688


  4%|▍         | 43/1000 [01:38<36:08,  2.27s/it]

epoch 42, train_loss: 0.0558, valid loss: 0.0756, min_valid_loss: 0.0688, wait: 0 / 20


  4%|▍         | 44/1000 [01:40<36:02,  2.26s/it]

epoch 43, train_loss: 0.0574, valid loss: 0.0757, min_valid_loss: 0.0688, wait: 1 / 20


  4%|▍         | 45/1000 [01:43<36:27,  2.29s/it]

epoch 44, train_loss: 0.0557, valid loss: 0.0658


  5%|▍         | 46/1000 [01:45<36:33,  2.30s/it]

epoch 45, train_loss: 0.0535, valid loss: 0.0724, min_valid_loss: 0.0658, wait: 0 / 20


  5%|▍         | 47/1000 [01:47<36:37,  2.31s/it]

epoch 46, train_loss: 0.0524, valid loss: 0.0687, min_valid_loss: 0.0658, wait: 1 / 20


  5%|▍         | 48/1000 [01:50<36:20,  2.29s/it]

epoch 47, train_loss: 0.0516, valid loss: 0.0674, min_valid_loss: 0.0658, wait: 2 / 20


  5%|▍         | 49/1000 [01:52<36:07,  2.28s/it]

epoch 48, train_loss: 0.0508, valid loss: 0.0768, min_valid_loss: 0.0658, wait: 3 / 20


  5%|▌         | 50/1000 [01:54<35:55,  2.27s/it]

epoch 49, train_loss: 0.0517, valid loss: 0.0757, min_valid_loss: 0.0658, wait: 4 / 20


  5%|▌         | 51/1000 [01:56<35:49,  2.26s/it]

epoch 50, train_loss: 0.0517, valid loss: 0.0736, min_valid_loss: 0.0658, wait: 5 / 20


  5%|▌         | 52/1000 [01:59<35:51,  2.27s/it]

epoch 51, train_loss: 0.0512, valid loss: 0.0657


  5%|▌         | 53/1000 [02:01<35:45,  2.27s/it]

epoch 52, train_loss: 0.0509, valid loss: 0.0679, min_valid_loss: 0.0657, wait: 0 / 20


  5%|▌         | 54/1000 [02:03<36:02,  2.29s/it]

epoch 53, train_loss: 0.0540, valid loss: 0.0720, min_valid_loss: 0.0657, wait: 1 / 20


  6%|▌         | 55/1000 [02:06<36:03,  2.29s/it]

epoch 54, train_loss: 0.0516, valid loss: 0.0639


  6%|▌         | 56/1000 [02:08<36:00,  2.29s/it]

epoch 55, train_loss: 0.0483, valid loss: 0.0658, min_valid_loss: 0.0639, wait: 0 / 20


  6%|▌         | 57/1000 [02:10<35:48,  2.28s/it]

epoch 56, train_loss: 0.0511, valid loss: 0.0672, min_valid_loss: 0.0639, wait: 1 / 20


  6%|▌         | 58/1000 [02:12<35:40,  2.27s/it]

epoch 57, train_loss: 0.0481, valid loss: 0.0670, min_valid_loss: 0.0639, wait: 2 / 20


  6%|▌         | 59/1000 [02:15<35:44,  2.28s/it]

epoch 58, train_loss: 0.0493, valid loss: 0.0622


  6%|▌         | 60/1000 [02:17<35:50,  2.29s/it]

epoch 59, train_loss: 0.0467, valid loss: 0.0619


  6%|▌         | 61/1000 [02:19<35:39,  2.28s/it]

epoch 60, train_loss: 0.0448, valid loss: 0.0673, min_valid_loss: 0.0619, wait: 0 / 20


  6%|▌         | 62/1000 [02:21<35:31,  2.27s/it]

epoch 61, train_loss: 0.0460, valid loss: 0.0631, min_valid_loss: 0.0619, wait: 1 / 20


  6%|▋         | 63/1000 [02:24<35:50,  2.30s/it]

epoch 62, train_loss: 0.0467, valid loss: 0.0727, min_valid_loss: 0.0619, wait: 2 / 20


  6%|▋         | 64/1000 [02:26<36:03,  2.31s/it]

epoch 63, train_loss: 0.0441, valid loss: 0.0752, min_valid_loss: 0.0619, wait: 3 / 20


  6%|▋         | 65/1000 [02:28<36:07,  2.32s/it]

epoch 64, train_loss: 0.0466, valid loss: 0.0689, min_valid_loss: 0.0619, wait: 4 / 20


  7%|▋         | 66/1000 [02:31<36:07,  2.32s/it]

epoch 65, train_loss: 0.0472, valid loss: 0.0706, min_valid_loss: 0.0619, wait: 5 / 20


  7%|▋         | 67/1000 [02:33<36:15,  2.33s/it]

epoch 66, train_loss: 0.0460, valid loss: 0.0594


  7%|▋         | 68/1000 [02:36<36:12,  2.33s/it]

epoch 67, train_loss: 0.0468, valid loss: 0.0628, min_valid_loss: 0.0594, wait: 0 / 20


  7%|▋         | 69/1000 [02:38<36:10,  2.33s/it]

epoch 68, train_loss: 0.0423, valid loss: 0.0666, min_valid_loss: 0.0594, wait: 1 / 20


  7%|▋         | 70/1000 [02:40<36:07,  2.33s/it]

epoch 69, train_loss: 0.0425, valid loss: 0.0614, min_valid_loss: 0.0594, wait: 2 / 20


  7%|▋         | 71/1000 [02:42<36:05,  2.33s/it]

epoch 70, train_loss: 0.0389, valid loss: 0.0645, min_valid_loss: 0.0594, wait: 3 / 20


  7%|▋         | 72/1000 [02:45<36:03,  2.33s/it]

epoch 71, train_loss: 0.0369, valid loss: 0.0608, min_valid_loss: 0.0594, wait: 4 / 20


  7%|▋         | 73/1000 [02:47<36:01,  2.33s/it]

epoch 72, train_loss: 0.0367, valid loss: 0.0697, min_valid_loss: 0.0594, wait: 5 / 20


  7%|▋         | 74/1000 [02:49<35:59,  2.33s/it]

epoch 73, train_loss: 0.0381, valid loss: 0.0676, min_valid_loss: 0.0594, wait: 6 / 20


  8%|▊         | 75/1000 [02:52<35:57,  2.33s/it]

epoch 74, train_loss: 0.0415, valid loss: 0.0665, min_valid_loss: 0.0594, wait: 7 / 20


  8%|▊         | 76/1000 [02:54<35:54,  2.33s/it]

epoch 75, train_loss: 0.0395, valid loss: 0.0650, min_valid_loss: 0.0594, wait: 8 / 20


  8%|▊         | 77/1000 [02:56<35:50,  2.33s/it]

epoch 76, train_loss: 0.0423, valid loss: 0.0623, min_valid_loss: 0.0594, wait: 9 / 20


  8%|▊         | 78/1000 [02:59<35:54,  2.34s/it]

epoch 77, train_loss: 0.0415, valid loss: 0.0578


  8%|▊         | 79/1000 [03:01<35:49,  2.33s/it]

epoch 78, train_loss: 0.0389, valid loss: 0.0597, min_valid_loss: 0.0578, wait: 0 / 20


  8%|▊         | 80/1000 [03:03<35:47,  2.33s/it]

epoch 79, train_loss: 0.0358, valid loss: 0.0670, min_valid_loss: 0.0578, wait: 1 / 20


  8%|▊         | 81/1000 [03:06<35:53,  2.34s/it]

epoch 80, train_loss: 0.0354, valid loss: 0.0558


  8%|▊         | 82/1000 [03:08<35:46,  2.34s/it]

epoch 81, train_loss: 0.0344, valid loss: 0.0617, min_valid_loss: 0.0558, wait: 0 / 20


  8%|▊         | 83/1000 [03:11<35:41,  2.34s/it]

epoch 82, train_loss: 0.0349, valid loss: 0.0578, min_valid_loss: 0.0558, wait: 1 / 20


  8%|▊         | 84/1000 [03:13<35:40,  2.34s/it]

epoch 83, train_loss: 0.0337, valid loss: 0.0657, min_valid_loss: 0.0558, wait: 2 / 20


  8%|▊         | 85/1000 [03:15<35:37,  2.34s/it]

epoch 84, train_loss: 0.0339, valid loss: 0.0610, min_valid_loss: 0.0558, wait: 3 / 20


  9%|▊         | 86/1000 [03:18<35:35,  2.34s/it]

epoch 85, train_loss: 0.0340, valid loss: 0.0628, min_valid_loss: 0.0558, wait: 4 / 20


  9%|▊         | 87/1000 [03:20<35:32,  2.34s/it]

epoch 86, train_loss: 0.0356, valid loss: 0.0593, min_valid_loss: 0.0558, wait: 5 / 20


  9%|▉         | 88/1000 [03:22<35:29,  2.33s/it]

epoch 87, train_loss: 0.0374, valid loss: 0.0621, min_valid_loss: 0.0558, wait: 6 / 20


  9%|▉         | 89/1000 [03:25<35:25,  2.33s/it]

epoch 88, train_loss: 0.0354, valid loss: 0.0668, min_valid_loss: 0.0558, wait: 7 / 20


  9%|▉         | 90/1000 [03:27<35:21,  2.33s/it]

epoch 89, train_loss: 0.0352, valid loss: 0.0609, min_valid_loss: 0.0558, wait: 8 / 20


  9%|▉         | 91/1000 [03:29<35:20,  2.33s/it]

epoch 90, train_loss: 0.0329, valid loss: 0.0672, min_valid_loss: 0.0558, wait: 9 / 20


  9%|▉         | 92/1000 [03:32<35:18,  2.33s/it]

epoch 91, train_loss: 0.0316, valid loss: 0.0590, min_valid_loss: 0.0558, wait: 10 / 20


  9%|▉         | 93/1000 [03:34<35:14,  2.33s/it]

epoch 92, train_loss: 0.0319, valid loss: 0.0615, min_valid_loss: 0.0558, wait: 11 / 20


  9%|▉         | 94/1000 [03:36<35:12,  2.33s/it]

epoch 93, train_loss: 0.0309, valid loss: 0.0566, min_valid_loss: 0.0558, wait: 12 / 20


 10%|▉         | 95/1000 [03:39<35:09,  2.33s/it]

epoch 94, train_loss: 0.0322, valid loss: 0.0618, min_valid_loss: 0.0558, wait: 13 / 20


 10%|▉         | 96/1000 [03:41<35:06,  2.33s/it]

epoch 95, train_loss: 0.0330, valid loss: 0.0634, min_valid_loss: 0.0558, wait: 14 / 20


 10%|▉         | 97/1000 [03:43<35:03,  2.33s/it]

epoch 96, train_loss: 0.0321, valid loss: 0.0642, min_valid_loss: 0.0558, wait: 15 / 20


 10%|▉         | 98/1000 [03:45<35:02,  2.33s/it]

epoch 97, train_loss: 0.0343, valid loss: 0.0667, min_valid_loss: 0.0558, wait: 16 / 20


 10%|▉         | 99/1000 [03:48<34:58,  2.33s/it]

epoch 98, train_loss: 0.0320, valid loss: 0.0637, min_valid_loss: 0.0558, wait: 17 / 20


 10%|█         | 100/1000 [03:50<34:57,  2.33s/it]

epoch 99, train_loss: 0.0351, valid loss: 0.0650, min_valid_loss: 0.0558, wait: 18 / 20


 10%|█         | 100/1000 [03:52<34:56,  2.33s/it]

epoch 100, train_loss: 0.0387, valid loss: 0.0749, min_valid_loss: 0.0558, wait: 19 / 20





In [7]:
def eval_vqvae(vqvae, checkpoint, testset):
    vqvae.load_state_dict(torch.load(checkpoint))
    vqvae.eval()
    x_gens = []
    count = 0
    with torch.no_grad():
        for idx, src, tgt in tqdm(testset, total=len(testset)):
            if count > 10:
                break
            x_gen, _, _, _ = vqvae(tgt)
            x_gen = (x_gen >= 0.5).long()
            
            x_gens.append((idx, x_gen))
            count += 1

    torch.save(x_gens, "song_test_music103.pt")

In [12]:
# vqvae = VQVAE(tgt_vocab_size, 256, num_heads, 2, d_ff, dropout, 128, 12).to(DEVICE)
eval_vqvae(vqvae, "model_best_vqvae_128.pt", testset)

  0%|          | 0/100 [00:00<?, ?it/s]

 11%|█         | 11/100 [00:00<00:00, 236.93it/s]


# **1. DDPM**


# a. Building Blocks

# b. DDPM Schedules

In [6]:
def ddpm_schedules(beta1, beta2, T):
    assert beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"

    ##################
    ### Problem 1 (a): Implement ddpm_schedules()
    beta_t = torch.linspace(beta1, beta2, T).float()

    alpha_t = 1 - beta_t
    oneover_sqrta = 1 / torch.sqrt(alpha_t)
    sqrt_beta_t = torch.sqrt(beta_t)
    alphabar_t = torch.cumprod(alpha_t, dim=0)
    sqrtab = torch.sqrt(alphabar_t)
    sqrtmab = torch.sqrt(1 - alphabar_t)
    mab_over_sqrtmab_inv = (1 - alpha_t) / torch.sqrt(1 - alphabar_t)
    ##################
    ##################

    return {
        "alpha_t": alpha_t,  # \alpha_t
        "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
        "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
        "alphabar_t": alphabar_t,  # \bar{\alpha_t}
        "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
        "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
        "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
    }

# c. DDPM Main Module



Here the noise $\sigma_t^2=\beta_t$

In [7]:
class DDPM(nn.Module):
    def __init__(self, nn_model, betas, n_T, device, n_inference=None, drop_prob=0.1):
        super(DDPM, self).__init__()
        self.nn_model = nn_model.to(device)

        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)
        
        self.n_T = n_T
        self.n_inference = n_inference if n_inference else n_T 
        
        for k, v in ddpm_schedules(betas[0], betas[1], self.n_inference).items():
            self.register_buffer(k+'_KAIMING', v)

        self.device = device
        self.drop_prob = drop_prob
        self.loss_mse = nn.MSELoss()

    def forward(self, src, tgt):
        ##################
        ### Problem 1 (b): Implement forward()
        t = torch.randint(0, self.n_T, (tgt.size(0),), device=self.device)
        sqrtab_t, sqrtmab_t = self.sqrtab[t].view(-1, 1, 1), \
            self.sqrtmab[t].view(-1, 1, 1)

        noise = torch.randn_like(tgt).to(self.device)  # Define noise tensor
        x_t = sqrtab_t * tgt + sqrtmab_t * noise

        # mask out with probability
        # context_mask = torch.bernoulli(torch.zeros(src.shape[0])+self.drop_prob).unsqueeze(-1).unsqueeze(-1).to(self.device)

        pred_noise = self.nn_model(src, x_t, t / (self.n_T - 1))
        loss = self.loss_mse(pred_noise, noise) 
        ##################
        ##################

        return loss

    @torch.no_grad()
    def sample(self, src, guide_w=0.0):
        n_sample = src.shape[0]
        x_i = torch.randn(*src.shape).to(self.device)

        for i in range(int(self.n_inference), 0, -1):

            ##################
            ### Problem 1 (c): Implement sample()
            t = torch.full((n_sample,), (i - 1) / (self.n_inference - 1)).to(self.device).float()
            t_i = t.view(-1, 1, 1)

            # double batch

            z = torch.randn(*src.shape).to(self.device) if i > 1 else 0 # if last step, no noise
            # classifier-free guidance
            pred_full = self.nn_model(src, x_i, t_i)
            x_i = self.oneover_sqrta_KAIMING[i - 1] * (x_i - pred_full * self.mab_over_sqrtmab_KAIMING[i - 1])\
                + self.sqrt_beta_t_KAIMING[i - 1] * z
        return x_i

# c. Training Function

In [8]:
from tqdm import tqdm

def train_main_loop(ddpm, vqvae, optim, trainset, validset, lr, n_epoch, device, guide_w, patience):
    wait = 0
    min_valid_loss = float('inf')
    for ep in tqdm(range(n_epoch)):
        ddpm.train()

        # linear lrate decay
        optim.param_groups[0]['lr'] = lr*(1-ep/n_epoch)
        loss_ema = None
        # train
        for idx, src, tgt in trainset:
            optim.zero_grad()
            tgt = tgt.to(device)
            src = src.to(device)
            tgt_enc = vqvae.encode(tgt)
            loss = ddpm(src, tgt_enc)
            loss.backward()
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            optim.step()
            
        # validation
        ddpm.eval()
        total_loss = 0
        with torch.no_grad():
            for idx, src, tgt in validset:
                tgt = tgt.to(device)
                src = src.to(device)
                tgt_enc = vqvae.encode(tgt)
                loss = ddpm(src, tgt_enc)
                total_loss += loss.item()
        avg_valid_loss = total_loss / len(validset)

        # early stopping
        if avg_valid_loss < min_valid_loss:
            min_valid_loss = avg_valid_loss
            torch.save(ddpm.nn_model.state_dict(), f"model_best_diffusion.pt")
            print(f'epoch {ep}, train_loss: {loss_ema:.4f}, valid loss: {avg_valid_loss:.4f}')
            wait = 0
        else:
            print(f'epoch {ep}, train_loss: {loss_ema:.4f}, valid loss: {avg_valid_loss:.4f}, min_valid_loss: {min_valid_loss:.4f}, wait: {wait} / {patience}')
            wait += 1
        if wait >= patience:
            break

    # # eval
    # ddpm.eval()
    # x_gens = []
    # count = 0
    # with torch.no_grad():
    #     for idx, src, tgt in tqdm(testset, total=len(testset)):
    #         if count > 3:
    #             break
    #         x_gens.append((idx, (ddpm.sample(src, guide_w) >= 0.5).long()))
    #         count += 1

    # torch.save(x_gens, "song_test_music103.pt")

def eval_main_loop(ddpm, vqvae, checkpoint, testset, device, guide_w, rate=0.5):
    ddpm.nn_model.load_state_dict(torch.load(checkpoint))
    ddpm.eval()
    x_gens = []
    count = 0
    with torch.no_grad():
        for idx, src, tgt in tqdm(testset, total=len(testset)):
            if count > 3:
                break
            x_gen = ddpm.sample(src, guide_w)
            x_gen = vqvae.decode(vqvae.vq(x_gen))
            x_gen = (x_gen >= rate).long()
            x_gens.append((idx, x_gen))
            count += 1

    torch.save(x_gens, "song_test_music103.pt")

# e. Training


In [16]:
transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout).to(DEVICE)
lr = 1e-4
ddpm = DDPM(nn_model=transformer, betas=(1e-4, 0.02), n_T=n_T, \
    device=DEVICE, n_inference=1000, drop_prob=0.1)
ddpm.to(DEVICE)
optim = torch.optim.Adam(ddpm.parameters(), lr=lr)
vqvae = VQVAE(tgt_vocab_size, 256, num_heads, 1, d_ff, dropout, 512, 12).to(DEVICE)
vqvae.load_state_dict(torch.load("model_best_vqvae.pt"))
train_main_loop(ddpm, vqvae,optim, trainset, validset, lr, n_epoch, DEVICE, 0, 50)


  0%|          | 1/1000 [00:04<1:21:13,  4.88s/it]

epoch 0, train_loss: 0.7965, valid loss: 0.5216


  0%|          | 2/1000 [00:09<1:18:12,  4.70s/it]

epoch 1, train_loss: 0.3239, valid loss: 0.1868


  0%|          | 3/1000 [00:14<1:18:01,  4.70s/it]

epoch 2, train_loss: 0.1764, valid loss: 0.1527


  0%|          | 4/1000 [00:18<1:17:10,  4.65s/it]

epoch 3, train_loss: 0.1859, valid loss: 0.1160


  0%|          | 5/1000 [00:23<1:16:28,  4.61s/it]

epoch 4, train_loss: 0.1413, valid loss: 0.1256, min_valid_loss: 0.1160, wait: 0 / 50


  1%|          | 6/1000 [00:27<1:16:23,  4.61s/it]

epoch 5, train_loss: 0.1030, valid loss: 0.1151


  1%|          | 7/1000 [00:32<1:17:00,  4.65s/it]

epoch 6, train_loss: 0.0934, valid loss: 0.1149


  1%|          | 8/1000 [00:37<1:16:44,  4.64s/it]

epoch 7, train_loss: 0.1044, valid loss: 0.1204, min_valid_loss: 0.1149, wait: 0 / 50


  1%|          | 9/1000 [00:41<1:16:38,  4.64s/it]

epoch 8, train_loss: 0.0907, valid loss: 0.0776


  1%|          | 10/1000 [00:46<1:16:17,  4.62s/it]

epoch 9, train_loss: 0.0853, valid loss: 0.0863, min_valid_loss: 0.0776, wait: 0 / 50


  1%|          | 11/1000 [00:51<1:16:04,  4.62s/it]

epoch 10, train_loss: 0.0818, valid loss: 0.0916, min_valid_loss: 0.0776, wait: 1 / 50


  1%|          | 12/1000 [00:55<1:16:12,  4.63s/it]

epoch 11, train_loss: 0.0821, valid loss: 0.0776


  1%|▏         | 13/1000 [01:00<1:15:57,  4.62s/it]

epoch 12, train_loss: 0.0775, valid loss: 0.1048, min_valid_loss: 0.0776, wait: 0 / 50


  1%|▏         | 14/1000 [01:04<1:15:33,  4.60s/it]

epoch 13, train_loss: 0.0731, valid loss: 0.0915, min_valid_loss: 0.0776, wait: 1 / 50


  2%|▏         | 15/1000 [01:09<1:15:30,  4.60s/it]

epoch 14, train_loss: 0.0753, valid loss: 0.0687


  2%|▏         | 16/1000 [01:14<1:15:17,  4.59s/it]

epoch 15, train_loss: 0.0686, valid loss: 0.0769, min_valid_loss: 0.0687, wait: 0 / 50


  2%|▏         | 17/1000 [01:18<1:15:20,  4.60s/it]

epoch 16, train_loss: 0.0704, valid loss: 0.0580


  2%|▏         | 18/1000 [01:23<1:15:05,  4.59s/it]

epoch 17, train_loss: 0.0569, valid loss: 0.0855, min_valid_loss: 0.0580, wait: 0 / 50


  2%|▏         | 19/1000 [01:27<1:15:12,  4.60s/it]

epoch 18, train_loss: 0.0649, valid loss: 0.0556


  2%|▏         | 20/1000 [01:32<1:15:00,  4.59s/it]

epoch 19, train_loss: 0.0671, valid loss: 0.0798, min_valid_loss: 0.0556, wait: 0 / 50


  2%|▏         | 21/1000 [01:37<1:14:54,  4.59s/it]

epoch 20, train_loss: 0.0687, valid loss: 0.0675, min_valid_loss: 0.0556, wait: 1 / 50


  2%|▏         | 22/1000 [01:41<1:15:18,  4.62s/it]

epoch 21, train_loss: 0.0623, valid loss: 0.0822, min_valid_loss: 0.0556, wait: 2 / 50


  2%|▏         | 23/1000 [01:46<1:15:04,  4.61s/it]

epoch 22, train_loss: 0.0578, valid loss: 0.0727, min_valid_loss: 0.0556, wait: 3 / 50


  2%|▏         | 24/1000 [01:50<1:14:52,  4.60s/it]

epoch 23, train_loss: 0.0582, valid loss: 0.0942, min_valid_loss: 0.0556, wait: 4 / 50


  2%|▎         | 25/1000 [01:55<1:14:45,  4.60s/it]

epoch 24, train_loss: 0.0600, valid loss: 0.0749, min_valid_loss: 0.0556, wait: 5 / 50


  3%|▎         | 26/1000 [02:00<1:14:41,  4.60s/it]

epoch 25, train_loss: 0.0599, valid loss: 0.0844, min_valid_loss: 0.0556, wait: 6 / 50


  3%|▎         | 27/1000 [02:04<1:14:34,  4.60s/it]

epoch 26, train_loss: 0.0668, valid loss: 0.0653, min_valid_loss: 0.0556, wait: 7 / 50


  3%|▎         | 28/1000 [02:09<1:14:30,  4.60s/it]

epoch 27, train_loss: 0.0681, valid loss: 0.0647, min_valid_loss: 0.0556, wait: 8 / 50


  3%|▎         | 29/1000 [02:13<1:14:25,  4.60s/it]

epoch 28, train_loss: 0.0523, valid loss: 0.0676, min_valid_loss: 0.0556, wait: 9 / 50


  3%|▎         | 30/1000 [02:18<1:14:23,  4.60s/it]

epoch 29, train_loss: 0.0566, valid loss: 0.0705, min_valid_loss: 0.0556, wait: 10 / 50


  3%|▎         | 31/1000 [02:23<1:14:17,  4.60s/it]

epoch 30, train_loss: 0.0644, valid loss: 0.0788, min_valid_loss: 0.0556, wait: 11 / 50


  3%|▎         | 32/1000 [02:27<1:14:10,  4.60s/it]

epoch 31, train_loss: 0.0565, valid loss: 0.0849, min_valid_loss: 0.0556, wait: 12 / 50


  3%|▎         | 33/1000 [02:32<1:14:07,  4.60s/it]

epoch 32, train_loss: 0.0611, valid loss: 0.0760, min_valid_loss: 0.0556, wait: 13 / 50


  3%|▎         | 34/1000 [02:36<1:14:02,  4.60s/it]

epoch 33, train_loss: 0.0646, valid loss: 0.0628, min_valid_loss: 0.0556, wait: 14 / 50


  4%|▎         | 35/1000 [02:41<1:13:56,  4.60s/it]

epoch 34, train_loss: 0.0523, valid loss: 0.0660, min_valid_loss: 0.0556, wait: 15 / 50


  4%|▎         | 36/1000 [02:46<1:13:53,  4.60s/it]

epoch 35, train_loss: 0.0561, valid loss: 0.0662, min_valid_loss: 0.0556, wait: 16 / 50


  4%|▎         | 37/1000 [02:50<1:13:48,  4.60s/it]

epoch 36, train_loss: 0.0655, valid loss: 0.0632, min_valid_loss: 0.0556, wait: 17 / 50


  4%|▍         | 38/1000 [02:55<1:13:49,  4.60s/it]

epoch 37, train_loss: 0.0587, valid loss: 0.0655, min_valid_loss: 0.0556, wait: 18 / 50


  4%|▍         | 39/1000 [02:59<1:13:45,  4.61s/it]

epoch 38, train_loss: 0.0593, valid loss: 0.0701, min_valid_loss: 0.0556, wait: 19 / 50


  4%|▍         | 40/1000 [03:04<1:13:45,  4.61s/it]

epoch 39, train_loss: 0.0502, valid loss: 0.0649, min_valid_loss: 0.0556, wait: 20 / 50


  4%|▍         | 41/1000 [03:09<1:13:42,  4.61s/it]

epoch 40, train_loss: 0.0607, valid loss: 0.0641, min_valid_loss: 0.0556, wait: 21 / 50


  4%|▍         | 42/1000 [03:13<1:13:37,  4.61s/it]

epoch 41, train_loss: 0.0622, valid loss: 0.0730, min_valid_loss: 0.0556, wait: 22 / 50


  4%|▍         | 43/1000 [03:18<1:13:39,  4.62s/it]

epoch 42, train_loss: 0.0527, valid loss: 0.0986, min_valid_loss: 0.0556, wait: 23 / 50


  4%|▍         | 44/1000 [03:22<1:13:34,  4.62s/it]

epoch 43, train_loss: 0.0638, valid loss: 0.0617, min_valid_loss: 0.0556, wait: 24 / 50


  4%|▍         | 45/1000 [03:27<1:13:44,  4.63s/it]

epoch 44, train_loss: 0.0508, valid loss: 0.0582, min_valid_loss: 0.0556, wait: 25 / 50


  5%|▍         | 46/1000 [03:32<1:13:34,  4.63s/it]

epoch 45, train_loss: 0.0561, valid loss: 0.0673, min_valid_loss: 0.0556, wait: 26 / 50


  5%|▍         | 47/1000 [03:36<1:13:27,  4.63s/it]

epoch 46, train_loss: 0.0494, valid loss: 0.0631, min_valid_loss: 0.0556, wait: 27 / 50


  5%|▍         | 48/1000 [03:41<1:13:19,  4.62s/it]

epoch 47, train_loss: 0.0535, valid loss: 0.0586, min_valid_loss: 0.0556, wait: 28 / 50


  5%|▍         | 49/1000 [03:46<1:13:11,  4.62s/it]

epoch 48, train_loss: 0.0532, valid loss: 0.0610, min_valid_loss: 0.0556, wait: 29 / 50


  5%|▌         | 50/1000 [03:50<1:13:04,  4.62s/it]

epoch 49, train_loss: 0.0463, valid loss: 0.0777, min_valid_loss: 0.0556, wait: 30 / 50


  5%|▌         | 51/1000 [03:55<1:12:59,  4.62s/it]

epoch 50, train_loss: 0.0552, valid loss: 0.0632, min_valid_loss: 0.0556, wait: 31 / 50


  5%|▌         | 52/1000 [03:59<1:13:10,  4.63s/it]

epoch 51, train_loss: 0.0502, valid loss: 0.0501


  5%|▌         | 53/1000 [04:04<1:13:00,  4.63s/it]

epoch 52, train_loss: 0.0590, valid loss: 0.0640, min_valid_loss: 0.0501, wait: 0 / 50


  5%|▌         | 54/1000 [04:09<1:12:53,  4.62s/it]

epoch 53, train_loss: 0.0657, valid loss: 0.0640, min_valid_loss: 0.0501, wait: 1 / 50


  6%|▌         | 55/1000 [04:13<1:12:49,  4.62s/it]

epoch 54, train_loss: 0.0537, valid loss: 0.0705, min_valid_loss: 0.0501, wait: 2 / 50


  6%|▌         | 56/1000 [04:18<1:12:46,  4.63s/it]

epoch 55, train_loss: 0.0562, valid loss: 0.0666, min_valid_loss: 0.0501, wait: 3 / 50


  6%|▌         | 57/1000 [04:23<1:12:41,  4.63s/it]

epoch 56, train_loss: 0.0505, valid loss: 0.0760, min_valid_loss: 0.0501, wait: 4 / 50


  6%|▌         | 58/1000 [04:27<1:12:33,  4.62s/it]

epoch 57, train_loss: 0.0522, valid loss: 0.0627, min_valid_loss: 0.0501, wait: 5 / 50


  6%|▌         | 59/1000 [04:32<1:12:29,  4.62s/it]

epoch 58, train_loss: 0.0607, valid loss: 0.0650, min_valid_loss: 0.0501, wait: 6 / 50


  6%|▌         | 60/1000 [04:36<1:12:25,  4.62s/it]

epoch 59, train_loss: 0.0510, valid loss: 0.0613, min_valid_loss: 0.0501, wait: 7 / 50


  6%|▌         | 61/1000 [04:41<1:12:21,  4.62s/it]

epoch 60, train_loss: 0.0537, valid loss: 0.0817, min_valid_loss: 0.0501, wait: 8 / 50


  6%|▌         | 62/1000 [04:46<1:12:15,  4.62s/it]

epoch 61, train_loss: 0.0617, valid loss: 0.0683, min_valid_loss: 0.0501, wait: 9 / 50


  6%|▋         | 63/1000 [04:50<1:12:11,  4.62s/it]

epoch 62, train_loss: 0.0494, valid loss: 0.0544, min_valid_loss: 0.0501, wait: 10 / 50


  6%|▋         | 64/1000 [04:55<1:12:04,  4.62s/it]

epoch 63, train_loss: 0.0527, valid loss: 0.0612, min_valid_loss: 0.0501, wait: 11 / 50


  6%|▋         | 65/1000 [05:00<1:12:01,  4.62s/it]

epoch 64, train_loss: 0.0517, valid loss: 0.0574, min_valid_loss: 0.0501, wait: 12 / 50


  7%|▋         | 66/1000 [05:04<1:12:29,  4.66s/it]

epoch 65, train_loss: 0.0518, valid loss: 0.0636, min_valid_loss: 0.0501, wait: 13 / 50


  7%|▋         | 67/1000 [05:09<1:13:08,  4.70s/it]

epoch 66, train_loss: 0.0474, valid loss: 0.0458


  7%|▋         | 68/1000 [05:14<1:12:52,  4.69s/it]

epoch 67, train_loss: 0.0541, valid loss: 0.0612, min_valid_loss: 0.0458, wait: 0 / 50


  7%|▋         | 69/1000 [05:18<1:12:25,  4.67s/it]

epoch 68, train_loss: 0.0481, valid loss: 0.0684, min_valid_loss: 0.0458, wait: 1 / 50


  7%|▋         | 70/1000 [05:23<1:12:04,  4.65s/it]

epoch 69, train_loss: 0.0484, valid loss: 0.0531, min_valid_loss: 0.0458, wait: 2 / 50


  7%|▋         | 71/1000 [05:28<1:12:04,  4.65s/it]

epoch 70, train_loss: 0.0501, valid loss: 0.0642, min_valid_loss: 0.0458, wait: 3 / 50


  7%|▋         | 72/1000 [05:32<1:12:11,  4.67s/it]

epoch 71, train_loss: 0.0449, valid loss: 0.0543, min_valid_loss: 0.0458, wait: 4 / 50


  7%|▋         | 73/1000 [05:37<1:11:53,  4.65s/it]

epoch 72, train_loss: 0.0522, valid loss: 0.0677, min_valid_loss: 0.0458, wait: 5 / 50


  7%|▋         | 74/1000 [05:42<1:11:35,  4.64s/it]

epoch 73, train_loss: 0.0489, valid loss: 0.0632, min_valid_loss: 0.0458, wait: 6 / 50


  8%|▊         | 75/1000 [05:46<1:11:22,  4.63s/it]

epoch 74, train_loss: 0.0425, valid loss: 0.0575, min_valid_loss: 0.0458, wait: 7 / 50


  8%|▊         | 76/1000 [05:51<1:11:10,  4.62s/it]

epoch 75, train_loss: 0.0462, valid loss: 0.0627, min_valid_loss: 0.0458, wait: 8 / 50


  8%|▊         | 77/1000 [05:55<1:11:01,  4.62s/it]

epoch 76, train_loss: 0.0483, valid loss: 0.0500, min_valid_loss: 0.0458, wait: 9 / 50


  8%|▊         | 78/1000 [06:00<1:10:53,  4.61s/it]

epoch 77, train_loss: 0.0497, valid loss: 0.0519, min_valid_loss: 0.0458, wait: 10 / 50


  8%|▊         | 79/1000 [06:05<1:10:50,  4.62s/it]

epoch 78, train_loss: 0.0560, valid loss: 0.0572, min_valid_loss: 0.0458, wait: 11 / 50


  8%|▊         | 80/1000 [06:09<1:10:39,  4.61s/it]

epoch 79, train_loss: 0.0503, valid loss: 0.0571, min_valid_loss: 0.0458, wait: 12 / 50


  8%|▊         | 81/1000 [06:14<1:10:30,  4.60s/it]

epoch 80, train_loss: 0.0554, valid loss: 0.0628, min_valid_loss: 0.0458, wait: 13 / 50


  8%|▊         | 82/1000 [06:18<1:10:22,  4.60s/it]

epoch 81, train_loss: 0.0493, valid loss: 0.0709, min_valid_loss: 0.0458, wait: 14 / 50


  8%|▊         | 83/1000 [06:23<1:10:16,  4.60s/it]

epoch 82, train_loss: 0.0558, valid loss: 0.0678, min_valid_loss: 0.0458, wait: 15 / 50


  8%|▊         | 84/1000 [06:28<1:10:09,  4.60s/it]

epoch 83, train_loss: 0.0503, valid loss: 0.0613, min_valid_loss: 0.0458, wait: 16 / 50


  8%|▊         | 85/1000 [06:32<1:10:06,  4.60s/it]

epoch 84, train_loss: 0.0532, valid loss: 0.0561, min_valid_loss: 0.0458, wait: 17 / 50


  9%|▊         | 86/1000 [06:37<1:10:01,  4.60s/it]

epoch 85, train_loss: 0.0521, valid loss: 0.0570, min_valid_loss: 0.0458, wait: 18 / 50


  9%|▊         | 87/1000 [06:41<1:09:59,  4.60s/it]

epoch 86, train_loss: 0.0467, valid loss: 0.0622, min_valid_loss: 0.0458, wait: 19 / 50


  9%|▉         | 88/1000 [06:46<1:09:57,  4.60s/it]

epoch 87, train_loss: 0.0500, valid loss: 0.0608, min_valid_loss: 0.0458, wait: 20 / 50


  9%|▉         | 89/1000 [06:51<1:09:50,  4.60s/it]

epoch 88, train_loss: 0.0444, valid loss: 0.0593, min_valid_loss: 0.0458, wait: 21 / 50


  9%|▉         | 90/1000 [06:55<1:09:45,  4.60s/it]

epoch 89, train_loss: 0.0459, valid loss: 0.0650, min_valid_loss: 0.0458, wait: 22 / 50


  9%|▉         | 91/1000 [07:00<1:09:40,  4.60s/it]

epoch 90, train_loss: 0.0440, valid loss: 0.0612, min_valid_loss: 0.0458, wait: 23 / 50


  9%|▉         | 92/1000 [07:04<1:09:35,  4.60s/it]

epoch 91, train_loss: 0.0444, valid loss: 0.0596, min_valid_loss: 0.0458, wait: 24 / 50


  9%|▉         | 93/1000 [07:09<1:09:29,  4.60s/it]

epoch 92, train_loss: 0.0555, valid loss: 0.0670, min_valid_loss: 0.0458, wait: 25 / 50


  9%|▉         | 94/1000 [07:14<1:09:26,  4.60s/it]

epoch 93, train_loss: 0.0519, valid loss: 0.0597, min_valid_loss: 0.0458, wait: 26 / 50


 10%|▉         | 95/1000 [07:18<1:09:21,  4.60s/it]

epoch 94, train_loss: 0.0481, valid loss: 0.0674, min_valid_loss: 0.0458, wait: 27 / 50


 10%|▉         | 96/1000 [07:23<1:09:17,  4.60s/it]

epoch 95, train_loss: 0.0490, valid loss: 0.0721, min_valid_loss: 0.0458, wait: 28 / 50


 10%|▉         | 97/1000 [07:27<1:09:13,  4.60s/it]

epoch 96, train_loss: 0.0423, valid loss: 0.0623, min_valid_loss: 0.0458, wait: 29 / 50


 10%|▉         | 98/1000 [07:32<1:09:10,  4.60s/it]

epoch 97, train_loss: 0.0464, valid loss: 0.0600, min_valid_loss: 0.0458, wait: 30 / 50


 10%|▉         | 99/1000 [07:37<1:09:04,  4.60s/it]

epoch 98, train_loss: 0.0564, valid loss: 0.0717, min_valid_loss: 0.0458, wait: 31 / 50


 10%|█         | 100/1000 [07:41<1:08:59,  4.60s/it]

epoch 99, train_loss: 0.0531, valid loss: 0.0576, min_valid_loss: 0.0458, wait: 32 / 50


 10%|█         | 101/1000 [07:46<1:08:54,  4.60s/it]

epoch 100, train_loss: 0.0476, valid loss: 0.0659, min_valid_loss: 0.0458, wait: 33 / 50


 10%|█         | 102/1000 [07:50<1:08:50,  4.60s/it]

epoch 101, train_loss: 0.0605, valid loss: 0.0585, min_valid_loss: 0.0458, wait: 34 / 50


 10%|█         | 103/1000 [07:55<1:08:48,  4.60s/it]

epoch 102, train_loss: 0.0490, valid loss: 0.0601, min_valid_loss: 0.0458, wait: 35 / 50


 10%|█         | 104/1000 [08:00<1:08:42,  4.60s/it]

epoch 103, train_loss: 0.0472, valid loss: 0.0612, min_valid_loss: 0.0458, wait: 36 / 50


 10%|█         | 105/1000 [08:04<1:08:38,  4.60s/it]

epoch 104, train_loss: 0.0501, valid loss: 0.0551, min_valid_loss: 0.0458, wait: 37 / 50


 11%|█         | 106/1000 [08:09<1:08:30,  4.60s/it]

epoch 105, train_loss: 0.0390, valid loss: 0.0587, min_valid_loss: 0.0458, wait: 38 / 50


 11%|█         | 107/1000 [08:13<1:08:25,  4.60s/it]

epoch 106, train_loss: 0.0447, valid loss: 0.0707, min_valid_loss: 0.0458, wait: 39 / 50


 11%|█         | 108/1000 [08:18<1:08:23,  4.60s/it]

epoch 107, train_loss: 0.0468, valid loss: 0.0578, min_valid_loss: 0.0458, wait: 40 / 50


 11%|█         | 109/1000 [08:23<1:08:20,  4.60s/it]

epoch 108, train_loss: 0.0524, valid loss: 0.0678, min_valid_loss: 0.0458, wait: 41 / 50


 11%|█         | 110/1000 [08:27<1:08:16,  4.60s/it]

epoch 109, train_loss: 0.0457, valid loss: 0.0482, min_valid_loss: 0.0458, wait: 42 / 50


 11%|█         | 111/1000 [08:32<1:08:12,  4.60s/it]

epoch 110, train_loss: 0.0474, valid loss: 0.0677, min_valid_loss: 0.0458, wait: 43 / 50


 11%|█         | 112/1000 [08:36<1:08:30,  4.63s/it]

epoch 111, train_loss: 0.0511, valid loss: 0.0437


 11%|█▏        | 113/1000 [08:41<1:08:26,  4.63s/it]

epoch 112, train_loss: 0.0449, valid loss: 0.0615, min_valid_loss: 0.0437, wait: 0 / 50


 11%|█▏        | 114/1000 [08:46<1:08:13,  4.62s/it]

epoch 113, train_loss: 0.0488, valid loss: 0.0634, min_valid_loss: 0.0437, wait: 1 / 50


 12%|█▏        | 115/1000 [08:50<1:08:03,  4.61s/it]

epoch 114, train_loss: 0.0429, valid loss: 0.0733, min_valid_loss: 0.0437, wait: 2 / 50


 12%|█▏        | 116/1000 [08:55<1:07:54,  4.61s/it]

epoch 115, train_loss: 0.0453, valid loss: 0.0538, min_valid_loss: 0.0437, wait: 3 / 50


 12%|█▏        | 117/1000 [09:00<1:07:48,  4.61s/it]

epoch 116, train_loss: 0.0477, valid loss: 0.0457, min_valid_loss: 0.0437, wait: 4 / 50


 12%|█▏        | 118/1000 [09:04<1:07:42,  4.61s/it]

epoch 117, train_loss: 0.0444, valid loss: 0.0566, min_valid_loss: 0.0437, wait: 5 / 50


 12%|█▏        | 119/1000 [09:09<1:07:43,  4.61s/it]

epoch 118, train_loss: 0.0468, valid loss: 0.0700, min_valid_loss: 0.0437, wait: 6 / 50


 12%|█▏        | 120/1000 [09:13<1:07:35,  4.61s/it]

epoch 119, train_loss: 0.0469, valid loss: 0.0598, min_valid_loss: 0.0437, wait: 7 / 50


 12%|█▏        | 121/1000 [09:18<1:07:34,  4.61s/it]

epoch 120, train_loss: 0.0433, valid loss: 0.0550, min_valid_loss: 0.0437, wait: 8 / 50


 12%|█▏        | 122/1000 [09:23<1:07:27,  4.61s/it]

epoch 121, train_loss: 0.0527, valid loss: 0.0621, min_valid_loss: 0.0437, wait: 9 / 50


 12%|█▏        | 123/1000 [09:27<1:07:19,  4.61s/it]

epoch 122, train_loss: 0.0480, valid loss: 0.0664, min_valid_loss: 0.0437, wait: 10 / 50


 12%|█▏        | 124/1000 [09:32<1:07:16,  4.61s/it]

epoch 123, train_loss: 0.0413, valid loss: 0.0753, min_valid_loss: 0.0437, wait: 11 / 50


 12%|█▎        | 125/1000 [09:36<1:07:10,  4.61s/it]

epoch 124, train_loss: 0.0466, valid loss: 0.0644, min_valid_loss: 0.0437, wait: 12 / 50


 13%|█▎        | 126/1000 [09:41<1:07:01,  4.60s/it]

epoch 125, train_loss: 0.0433, valid loss: 0.0567, min_valid_loss: 0.0437, wait: 13 / 50


 13%|█▎        | 127/1000 [09:46<1:06:56,  4.60s/it]

epoch 126, train_loss: 0.0486, valid loss: 0.0685, min_valid_loss: 0.0437, wait: 14 / 50


 13%|█▎        | 128/1000 [09:50<1:06:51,  4.60s/it]

epoch 127, train_loss: 0.0462, valid loss: 0.0503, min_valid_loss: 0.0437, wait: 15 / 50


 13%|█▎        | 129/1000 [09:55<1:06:48,  4.60s/it]

epoch 128, train_loss: 0.0479, valid loss: 0.0614, min_valid_loss: 0.0437, wait: 16 / 50


 13%|█▎        | 130/1000 [09:59<1:06:42,  4.60s/it]

epoch 129, train_loss: 0.0455, valid loss: 0.0637, min_valid_loss: 0.0437, wait: 17 / 50


 13%|█▎        | 131/1000 [10:04<1:06:37,  4.60s/it]

epoch 130, train_loss: 0.0413, valid loss: 0.0512, min_valid_loss: 0.0437, wait: 18 / 50


 13%|█▎        | 132/1000 [10:09<1:06:32,  4.60s/it]

epoch 131, train_loss: 0.0472, valid loss: 0.0584, min_valid_loss: 0.0437, wait: 19 / 50


 13%|█▎        | 133/1000 [10:13<1:06:26,  4.60s/it]

epoch 132, train_loss: 0.0572, valid loss: 0.0558, min_valid_loss: 0.0437, wait: 20 / 50


 13%|█▎        | 134/1000 [10:18<1:06:22,  4.60s/it]

epoch 133, train_loss: 0.0464, valid loss: 0.0532, min_valid_loss: 0.0437, wait: 21 / 50


 14%|█▎        | 135/1000 [10:22<1:06:33,  4.62s/it]

epoch 134, train_loss: 0.0462, valid loss: 0.0414


 14%|█▎        | 136/1000 [10:27<1:06:23,  4.61s/it]

epoch 135, train_loss: 0.0450, valid loss: 0.0658, min_valid_loss: 0.0414, wait: 0 / 50


 14%|█▎        | 137/1000 [10:32<1:06:19,  4.61s/it]

epoch 136, train_loss: 0.0432, valid loss: 0.0649, min_valid_loss: 0.0414, wait: 1 / 50


 14%|█▍        | 138/1000 [10:36<1:06:14,  4.61s/it]

epoch 137, train_loss: 0.0438, valid loss: 0.0638, min_valid_loss: 0.0414, wait: 2 / 50


 14%|█▍        | 139/1000 [10:41<1:06:08,  4.61s/it]

epoch 138, train_loss: 0.0464, valid loss: 0.0579, min_valid_loss: 0.0414, wait: 3 / 50


 14%|█▍        | 140/1000 [10:45<1:06:03,  4.61s/it]

epoch 139, train_loss: 0.0449, valid loss: 0.0470, min_valid_loss: 0.0414, wait: 4 / 50


 14%|█▍        | 141/1000 [10:50<1:05:58,  4.61s/it]

epoch 140, train_loss: 0.0444, valid loss: 0.0533, min_valid_loss: 0.0414, wait: 5 / 50


 14%|█▍        | 142/1000 [10:55<1:05:52,  4.61s/it]

epoch 141, train_loss: 0.0432, valid loss: 0.0585, min_valid_loss: 0.0414, wait: 6 / 50


 14%|█▍        | 143/1000 [10:59<1:05:49,  4.61s/it]

epoch 142, train_loss: 0.0520, valid loss: 0.0460, min_valid_loss: 0.0414, wait: 7 / 50


 14%|█▍        | 144/1000 [11:04<1:05:46,  4.61s/it]

epoch 143, train_loss: 0.0462, valid loss: 0.0770, min_valid_loss: 0.0414, wait: 8 / 50


 14%|█▍        | 145/1000 [11:09<1:05:45,  4.61s/it]

epoch 144, train_loss: 0.0463, valid loss: 0.0620, min_valid_loss: 0.0414, wait: 9 / 50


 15%|█▍        | 146/1000 [11:13<1:06:10,  4.65s/it]

epoch 145, train_loss: 0.0444, valid loss: 0.0622, min_valid_loss: 0.0414, wait: 10 / 50


 15%|█▍        | 147/1000 [11:18<1:05:52,  4.63s/it]

epoch 146, train_loss: 0.0457, valid loss: 0.0626, min_valid_loss: 0.0414, wait: 11 / 50


 15%|█▍        | 148/1000 [11:22<1:05:40,  4.63s/it]

epoch 147, train_loss: 0.0466, valid loss: 0.0574, min_valid_loss: 0.0414, wait: 12 / 50


 15%|█▍        | 149/1000 [11:27<1:05:31,  4.62s/it]

epoch 148, train_loss: 0.0399, valid loss: 0.0512, min_valid_loss: 0.0414, wait: 13 / 50


 15%|█▌        | 150/1000 [11:32<1:05:24,  4.62s/it]

epoch 149, train_loss: 0.0396, valid loss: 0.0734, min_valid_loss: 0.0414, wait: 14 / 50


 15%|█▌        | 151/1000 [11:36<1:05:17,  4.61s/it]

epoch 150, train_loss: 0.0426, valid loss: 0.0533, min_valid_loss: 0.0414, wait: 15 / 50


 15%|█▌        | 152/1000 [11:41<1:05:19,  4.62s/it]

epoch 151, train_loss: 0.0485, valid loss: 0.0478, min_valid_loss: 0.0414, wait: 16 / 50


 15%|█▌        | 153/1000 [11:46<1:05:10,  4.62s/it]

epoch 152, train_loss: 0.0440, valid loss: 0.0536, min_valid_loss: 0.0414, wait: 17 / 50


 15%|█▌        | 154/1000 [11:50<1:05:01,  4.61s/it]

epoch 153, train_loss: 0.0465, valid loss: 0.0481, min_valid_loss: 0.0414, wait: 18 / 50


 16%|█▌        | 155/1000 [11:55<1:04:57,  4.61s/it]

epoch 154, train_loss: 0.0388, valid loss: 0.0645, min_valid_loss: 0.0414, wait: 19 / 50


 16%|█▌        | 156/1000 [11:59<1:05:24,  4.65s/it]

epoch 155, train_loss: 0.0436, valid loss: 0.0523, min_valid_loss: 0.0414, wait: 20 / 50


 16%|█▌        | 157/1000 [12:04<1:05:40,  4.67s/it]

epoch 156, train_loss: 0.0489, valid loss: 0.0597, min_valid_loss: 0.0414, wait: 21 / 50


 16%|█▌        | 158/1000 [12:09<1:05:52,  4.69s/it]

epoch 157, train_loss: 0.0424, valid loss: 0.0562, min_valid_loss: 0.0414, wait: 22 / 50


 16%|█▌        | 159/1000 [12:14<1:05:58,  4.71s/it]

epoch 158, train_loss: 0.0429, valid loss: 0.0567, min_valid_loss: 0.0414, wait: 23 / 50


 16%|█▌        | 160/1000 [12:18<1:06:00,  4.71s/it]

epoch 159, train_loss: 0.0410, valid loss: 0.0560, min_valid_loss: 0.0414, wait: 24 / 50


 16%|█▌        | 161/1000 [12:23<1:05:25,  4.68s/it]

epoch 160, train_loss: 0.0430, valid loss: 0.0555, min_valid_loss: 0.0414, wait: 25 / 50


 16%|█▌        | 162/1000 [12:28<1:05:14,  4.67s/it]

epoch 161, train_loss: 0.0414, valid loss: 0.0570, min_valid_loss: 0.0414, wait: 26 / 50


 16%|█▋        | 163/1000 [12:32<1:04:52,  4.65s/it]

epoch 162, train_loss: 0.0450, valid loss: 0.0523, min_valid_loss: 0.0414, wait: 27 / 50


 16%|█▋        | 164/1000 [12:37<1:04:36,  4.64s/it]

epoch 163, train_loss: 0.0422, valid loss: 0.0539, min_valid_loss: 0.0414, wait: 28 / 50


 16%|█▋        | 165/1000 [12:41<1:04:25,  4.63s/it]

epoch 164, train_loss: 0.0523, valid loss: 0.0556, min_valid_loss: 0.0414, wait: 29 / 50


 17%|█▋        | 166/1000 [12:46<1:04:14,  4.62s/it]

epoch 165, train_loss: 0.0489, valid loss: 0.0591, min_valid_loss: 0.0414, wait: 30 / 50


 17%|█▋        | 167/1000 [12:51<1:04:08,  4.62s/it]

epoch 166, train_loss: 0.0443, valid loss: 0.0601, min_valid_loss: 0.0414, wait: 31 / 50


 17%|█▋        | 168/1000 [12:55<1:04:01,  4.62s/it]

epoch 167, train_loss: 0.0419, valid loss: 0.0627, min_valid_loss: 0.0414, wait: 32 / 50


 17%|█▋        | 169/1000 [13:00<1:03:55,  4.62s/it]

epoch 168, train_loss: 0.0448, valid loss: 0.0480, min_valid_loss: 0.0414, wait: 33 / 50


 17%|█▋        | 170/1000 [13:05<1:03:49,  4.61s/it]

epoch 169, train_loss: 0.0533, valid loss: 0.0666, min_valid_loss: 0.0414, wait: 34 / 50


 17%|█▋        | 171/1000 [13:09<1:03:45,  4.61s/it]

epoch 170, train_loss: 0.0425, valid loss: 0.0655, min_valid_loss: 0.0414, wait: 35 / 50


 17%|█▋        | 172/1000 [13:14<1:03:40,  4.61s/it]

epoch 171, train_loss: 0.0444, valid loss: 0.0470, min_valid_loss: 0.0414, wait: 36 / 50


 17%|█▋        | 173/1000 [13:18<1:03:37,  4.62s/it]

epoch 172, train_loss: 0.0453, valid loss: 0.0493, min_valid_loss: 0.0414, wait: 37 / 50


 17%|█▋        | 174/1000 [13:23<1:03:31,  4.61s/it]

epoch 173, train_loss: 0.0422, valid loss: 0.0645, min_valid_loss: 0.0414, wait: 38 / 50


 18%|█▊        | 175/1000 [13:28<1:03:25,  4.61s/it]

epoch 174, train_loss: 0.0430, valid loss: 0.0482, min_valid_loss: 0.0414, wait: 39 / 50


 18%|█▊        | 176/1000 [13:32<1:03:20,  4.61s/it]

epoch 175, train_loss: 0.0453, valid loss: 0.0541, min_valid_loss: 0.0414, wait: 40 / 50


 18%|█▊        | 177/1000 [13:37<1:03:14,  4.61s/it]

epoch 176, train_loss: 0.0435, valid loss: 0.0630, min_valid_loss: 0.0414, wait: 41 / 50


 18%|█▊        | 178/1000 [13:41<1:03:04,  4.60s/it]

epoch 177, train_loss: 0.0419, valid loss: 0.0611, min_valid_loss: 0.0414, wait: 42 / 50


 18%|█▊        | 179/1000 [13:46<1:02:58,  4.60s/it]

epoch 178, train_loss: 0.0435, valid loss: 0.0496, min_valid_loss: 0.0414, wait: 43 / 50


 18%|█▊        | 180/1000 [13:51<1:02:51,  4.60s/it]

epoch 179, train_loss: 0.0448, valid loss: 0.0775, min_valid_loss: 0.0414, wait: 44 / 50


 18%|█▊        | 181/1000 [13:55<1:02:46,  4.60s/it]

epoch 180, train_loss: 0.0462, valid loss: 0.0435, min_valid_loss: 0.0414, wait: 45 / 50


 18%|█▊        | 182/1000 [14:00<1:02:43,  4.60s/it]

epoch 181, train_loss: 0.0416, valid loss: 0.0520, min_valid_loss: 0.0414, wait: 46 / 50


 18%|█▊        | 183/1000 [14:04<1:02:40,  4.60s/it]

epoch 182, train_loss: 0.0412, valid loss: 0.0433, min_valid_loss: 0.0414, wait: 47 / 50


 18%|█▊        | 184/1000 [14:09<1:02:38,  4.61s/it]

epoch 183, train_loss: 0.0409, valid loss: 0.0543, min_valid_loss: 0.0414, wait: 48 / 50


 18%|█▊        | 184/1000 [14:14<1:03:07,  4.64s/it]

epoch 184, train_loss: 0.0459, valid loss: 0.0806, min_valid_loss: 0.0414, wait: 49 / 50





In [17]:
eval_main_loop(ddpm, vqvae, "model_best_diffusion.pt", testset, DEVICE, 0.1, 0.5)

  0%|          | 0/100 [00:00<?, ?it/s]

  4%|▍         | 4/100 [00:15<06:06,  3.82s/it]


: 