# Music 103 diffusion version

In [34]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import numpy as np
import copy
import pandas as pd
from tqdm import tqdm
from os.path import exists
from os import remove, chdir
import pickle

In [35]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output
    
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))
    
class PositionalEncoding(nn.Module):
    def __init__(self, d_model):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model

    def forward(self, x):
        max_len = x.size(1)
        pe = torch.zeros(max_len, self.d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * -(math.log(10000.0) / self.d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0).to(x.device)
        return x + pe


class DecoderPositionalEncoding(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

    def forward(self, x, tgt):
        # max_len = x.size(1)
        tgt_one_hot = tgt[:, :, 12:]
        tgt_class = torch.argmax(tgt_one_hot, dim=-1)
        pe = torch.zeros_like(x)
        position = torch.cumsum(tgt_class, dim=1).unsqueeze(-1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * -(math.log(10000.0) / self.d_model)).to(position.device)
        
        pe[:, :, 0::2] = torch.sin(position * div_term)
        pe[:, :, 1::2] = torch.cos(position * div_term)
        return x + pe

    
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

class EmbedHead(nn.Module):
    def __init__(
        self,
        input_dim,
        inner_dim_1,
        inner_dim_2,
        out_dim
    ):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, inner_dim_1)
        self.linear2 = nn.Linear(inner_dim_1, inner_dim_2)
        self.linear3 = nn.Linear(inner_dim_2, out_dim)
        self.activation_fn = nn.functional.gelu

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation_fn(x)
        x = self.linear2(x)
        x = self.activation_fn(x)
        x = self.linear3(x)
        return x
    

class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        return self.model(x)


class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super().__init__()
        self.encoder_embedding = EmbedHead(src_vocab_size, d_model, d_model, d_model)
        self.decoder_embedding = EmbedHead(tgt_vocab_size, d_model, d_model, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.time_embeddings = nn.ModuleList([EmbedFC(1, d_model) for _ in range(num_layers)])
        self.decoder_positional_encoding = PositionalEncoding(d_model)

        self.encoder_layers_src = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.encoder_layers_noise = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.encoder_layers_tgt = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.output_layer = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, x, time):
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        x_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(x)))

        enc_src = src_embedded
        for enc_layer in self.encoder_layers_src:
            enc_src = enc_layer(enc_src, None)

        enc_noise = x_embedded
        for enc_layer in self.encoder_layers_noise:
            enc_noise = enc_layer(enc_noise, None)

        enc_tgt = enc_noise
        for i, enc_layer in enumerate(self.encoder_layers_tgt):
            time_embedding = self.time_embeddings[i](time).unsqueeze(1)
            enc_tgt = enc_tgt + enc_src + time_embedding
            enc_tgt = enc_layer(enc_tgt, None)
        
        output = self.output_layer(enc_tgt)
        
        return output
    


# **1. DDPM**


# a. Building Blocks

# b. DDPM Schedules

In [36]:
def ddpm_schedules(beta1, beta2, T):
    assert beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"

    ##################
    ### Problem 1 (a): Implement ddpm_schedules()
    beta_t = torch.linspace(beta1, beta2, T).float()

    alpha_t = 1 - beta_t
    oneover_sqrta = 1 / torch.sqrt(alpha_t)
    sqrt_beta_t = torch.sqrt(beta_t)
    alphabar_t = torch.cumprod(alpha_t, dim=0)
    sqrtab = torch.sqrt(alphabar_t)
    sqrtmab = torch.sqrt(1 - alphabar_t)
    mab_over_sqrtmab_inv = (1 - alpha_t) / torch.sqrt(1 - alphabar_t)
    ##################
    ##################

    return {
        "alpha_t": alpha_t,  # \alpha_t
        "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
        "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
        "alphabar_t": alphabar_t,  # \bar{\alpha_t}
        "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
        "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
        "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
    }

# c. DDPM Main Module



Here the noise $\sigma_t^2=\beta_t$

In [37]:
class DDPM(nn.Module):
    def __init__(self, nn_model, betas, n_T, device, n_inference=None, drop_prob=0.1):
        super(DDPM, self).__init__()
        self.nn_model = nn_model.to(device)

        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)
        
        self.n_T = n_T
        self.n_inference = n_inference if n_inference else n_T 
        
        for k, v in ddpm_schedules(betas[0], betas[1], self.n_inference).items():
            self.register_buffer(k+'_KAIMING', v)

        self.device = device
        self.drop_prob = drop_prob
        self.loss_mse = nn.MSELoss()

    def forward(self, src, tgt):
        ##################
        ### Problem 1 (b): Implement forward()
        t = torch.randint(0, self.n_T, (tgt.size(0),), device=self.device)
        sqrtab_t, sqrtmab_t = self.sqrtab[t].view(-1, 1, 1), \
            self.sqrtmab[t].view(-1, 1, 1)

        noise = torch.randn_like(tgt).to(self.device)  # Define noise tensor
        x_t = sqrtab_t * tgt + sqrtmab_t * noise

        # mask out with probability
        context_mask = torch.bernoulli(torch.zeros(src.shape[0])+1 - self.drop_prob).unsqueeze(-1).unsqueeze(-1).to(self.device)

        pred_noise = self.nn_model(src * context_mask, x_t, t / (self.n_T - 1))
        loss = self.loss_mse(pred_noise, noise) 
        ##################
        ##################

        return loss

    @torch.no_grad()
    def sample(self, src, guide_w=0.0):
        n_sample = src.shape[0]
        x_i = torch.randn(*src.shape).to(self.device)
        c_i = src.to(self.device).clone()
        c_i = c_i.repeat(2, 1, 1)
        context_mask = torch.zeros_like(c_i).to(self.device)
        context_mask[:n_sample] = 1.0  # second half context-free

        for i in range(int(self.n_inference), 0, -1):

            ##################
            ### Problem 1 (c): Implement sample()
            t = torch.full((n_sample,), (i - 1) / (self.n_inference - 1)).to(self.device).float()
            t_i = t.view(-1, 1, 1)

            # double batch
            x_i = x_i.repeat(2, 1, 1)
            t_i = t_i.repeat(2, 1, 1)

            z = torch.randn(*src.shape).to(self.device) if i > 1 else 0 # if last step, no noise
            # classifier-free guidance
            pred_full = self.nn_model(c_i * context_mask, x_i, t_i)
            pred_1, pred_2 = pred_full[:n_sample], pred_full[n_sample:]
            pred_noise = (1 + guide_w) * pred_1 - guide_w * pred_2
            x_i = x_i[:n_sample]
            x_i = self.oneover_sqrta_KAIMING[i - 1] * (x_i - pred_noise * self.mab_over_sqrtmab_KAIMING[i - 1])\
                + self.sqrt_beta_t_KAIMING[i - 1] * z
        return x_i

# c. Training Function

In [38]:
from tqdm import tqdm

def train_main_loop(ddpm, optim, trainset, validset, testset, lr, n_epoch, device, guide_w, patience):
    wait = 0
    min_valid_loss = float('inf')
    for ep in tqdm(range(n_epoch)):
        ddpm.train()

        # linear lrate decay
        optim.param_groups[0]['lr'] = lr*(1-ep/n_epoch)
        loss_ema = None
        # train
        for idx, src, tgt in trainset:
            optim.zero_grad()
            tgt = tgt.to(device)
            src = src.to(device)
            loss = ddpm(src, tgt)
            loss.backward()
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            optim.step()
            
        # validation
        ddpm.eval()
        total_loss = 0
        with torch.no_grad():
            for idx, src, tgt in validset:
                tgt = tgt.to(device)
                src = src.to(device)
                loss = ddpm(src, tgt)
                total_loss += loss.item()
        avg_valid_loss = total_loss / len(validset)

        # early stopping
        if avg_valid_loss < min_valid_loss:
            min_valid_loss = avg_valid_loss
            torch.save(ddpm.nn_model.state_dict(), f"model_best_diffusion.pt")
            print(f'epoch {ep}, train_loss: {loss_ema:.4f}, valid loss: {avg_valid_loss:.4f}')
            wait = 0
        else:
            print(f'epoch {ep}, train_loss: {loss_ema:.4f}, valid loss: {avg_valid_loss:.4f}, min_valid_loss: {min_valid_loss:.4f}, wait: {wait} / {patience}')
            wait += 1
        if wait >= patience:
            break

    # # eval
    # ddpm.eval()
    # x_gens = []
    # count = 0
    # with torch.no_grad():
    #     for idx, src, tgt in tqdm(testset, total=len(testset)):
    #         if count > 3:
    #             break
    #         x_gens.append((idx, (ddpm.sample(src, guide_w) >= 0.5).long()))
    #         count += 1

    # torch.save(x_gens, "song_test_music103.pt")

def eval_main_loop(ddpm, checkpoint, testset, device, guide_w, rate=0.5):
    ddpm.nn_model.load_state_dict(torch.load(checkpoint))
    ddpm.eval()
    x_gens = []
    count = 0
    with torch.no_grad():
        for idx, src, tgt in tqdm(testset, total=len(testset)):
            if count > 3:
                break
            x_gen = ddpm.sample(src, guide_w)
            x_gen = ((x_gen >= torch.quantile(x_gen, 0.66, dim=-1, keepdim=True)) & (x_gen >= rate)).long()
            x_gens.append((idx, x_gen))
            count += 1

    torch.save(x_gens, "song_test_music103.pt")

# e. Training


In [39]:
# hardcoding these here
n_epoch = 200
n_T = 1000
n_feat = 128
lr = 1e-4
ws_test = [0.0, 0.5, 2.0]
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

src_vocab_size = 12
tgt_vocab_size = 12
d_model = 512
num_heads = 8
num_layers = 4
d_ff = 4096//8
max_seq_length = 2400
dropout = 0.1
batchsize = 16
mode = "train"


if exists("trainset_w.pkl") and exists("validset_w.pkl") and exists("testset_w.pkl"):
    print("splitted dataset found!")
    with open("trainset_w.pkl", "rb") as f:
        trainset = pickle.load(f)
    with open("validset_w.pkl", "rb") as f:
        validset = pickle.load(f)
    with open("testset_w.pkl", "rb") as f:
        testset = pickle.load(f)
else:
    print("?")

def collate_fn(batch):
    # Unpack batch into individual components
    idx, src_data, tgt_data, w = zip(*batch)
    #print(len(rates[0]), len(tgt_data[0]), len(src_data[0]))
    
    # Convert `src_data`, `tgt_data`, and `rates` to tensors if they are not already
    src_data = [torch.tensor(s, dtype=torch.float32) if not isinstance(s, torch.Tensor) else s for s in src_data]
    tgt_data = [torch.tensor(t, dtype=torch.float32) if not isinstance(t, torch.Tensor) else t for t in tgt_data]

    src_data = [torch.cat([s], dim=-1) for s in src_data]
    tgt_data = [torch.cat([t], dim=-1) for t in tgt_data]

    # Pad src_data
    src_data = nn.utils.rnn.pad_sequence(src_data, batch_first=True, padding_value=0.).to(DEVICE)

    # Pad tgt_data
    tgt_data = nn.utils.rnn.pad_sequence(tgt_data, batch_first=True, padding_value=0).to(DEVICE)

    # Extract the last dimension and one-hot encode it
    return idx, src_data, tgt_data


trainset = data.DataLoader(trainset, batch_size=batchsize, collate_fn=collate_fn)
validset = data.DataLoader(validset, batch_size=1, collate_fn=collate_fn)
testset = data.DataLoader(testset, batch_size=1, collate_fn=collate_fn)


splitted dataset found!


In [40]:
transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout).to(DEVICE)
ddpm = DDPM(nn_model=transformer, betas=(1e-4, 0.02), n_T=n_T, \
    device=DEVICE, n_inference=1000, drop_prob=0.1)
ddpm.to(DEVICE)
optim = torch.optim.Adam(ddpm.parameters(), lr=lr)
train_main_loop(ddpm, optim, trainset, validset, testset, lr, n_epoch, DEVICE, 0, 20)


  0%|          | 1/200 [00:11<36:54, 11.13s/it]

epoch 0, train_loss: 1.0514, valid loss: 0.9746


  1%|          | 2/200 [00:22<36:15, 10.99s/it]

epoch 1, train_loss: 0.5408, valid loss: 0.2044


  2%|▏         | 3/200 [00:32<35:58, 10.95s/it]

epoch 2, train_loss: 0.2086, valid loss: 0.1763


  2%|▏         | 4/200 [00:43<35:44, 10.94s/it]

epoch 3, train_loss: 0.1666, valid loss: 0.1193


  2%|▎         | 5/200 [00:54<35:22, 10.89s/it]

epoch 4, train_loss: 0.1453, valid loss: 0.1404, min_valid_loss: 0.1193, wait: 0 / 20


  3%|▎         | 6/200 [01:05<35:15, 10.91s/it]

epoch 5, train_loss: 0.1470, valid loss: 0.1182


  4%|▎         | 7/200 [01:16<35:09, 10.93s/it]

epoch 6, train_loss: 0.1257, valid loss: 0.0824


  4%|▍         | 8/200 [01:27<34:51, 10.89s/it]

epoch 7, train_loss: 0.1086, valid loss: 0.1606, min_valid_loss: 0.0824, wait: 0 / 20


  4%|▍         | 9/200 [01:38<34:36, 10.87s/it]

epoch 8, train_loss: 0.1225, valid loss: 0.1599, min_valid_loss: 0.0824, wait: 1 / 20


  5%|▌         | 10/200 [01:49<34:22, 10.85s/it]

epoch 9, train_loss: 0.1018, valid loss: 0.0985, min_valid_loss: 0.0824, wait: 2 / 20


  6%|▌         | 11/200 [01:59<34:10, 10.85s/it]

epoch 10, train_loss: 0.0986, valid loss: 0.1377, min_valid_loss: 0.0824, wait: 3 / 20


  6%|▌         | 12/200 [02:10<33:58, 10.85s/it]

epoch 11, train_loss: 0.0827, valid loss: 0.1493, min_valid_loss: 0.0824, wait: 4 / 20


  6%|▋         | 13/200 [02:21<33:47, 10.84s/it]

epoch 12, train_loss: 0.0847, valid loss: 0.0839, min_valid_loss: 0.0824, wait: 5 / 20


  7%|▋         | 14/200 [02:32<33:36, 10.84s/it]

epoch 13, train_loss: 0.0801, valid loss: 0.0922, min_valid_loss: 0.0824, wait: 6 / 20


  8%|▊         | 15/200 [02:43<33:38, 10.91s/it]

epoch 14, train_loss: 0.0766, valid loss: 0.0739


  8%|▊         | 16/200 [02:54<33:23, 10.89s/it]

epoch 15, train_loss: 0.0726, valid loss: 0.0985, min_valid_loss: 0.0739, wait: 0 / 20


  8%|▊         | 17/200 [03:05<33:10, 10.88s/it]

epoch 16, train_loss: 0.0632, valid loss: 0.0821, min_valid_loss: 0.0739, wait: 1 / 20


  9%|▉         | 18/200 [03:15<32:58, 10.87s/it]

epoch 17, train_loss: 0.0688, valid loss: 0.0891, min_valid_loss: 0.0739, wait: 2 / 20


 10%|▉         | 19/200 [03:26<32:53, 10.90s/it]

epoch 18, train_loss: 0.0688, valid loss: 0.0716


 10%|█         | 20/200 [03:37<32:49, 10.94s/it]

epoch 19, train_loss: 0.0596, valid loss: 0.0638


 10%|█         | 21/200 [03:48<32:41, 10.96s/it]

epoch 20, train_loss: 0.0635, valid loss: 0.0629


 11%|█         | 22/200 [03:59<32:27, 10.94s/it]

epoch 21, train_loss: 0.0681, valid loss: 0.0709, min_valid_loss: 0.0629, wait: 0 / 20


 12%|█▏        | 23/200 [04:10<32:12, 10.92s/it]

epoch 22, train_loss: 0.0597, valid loss: 0.0824, min_valid_loss: 0.0629, wait: 1 / 20


 12%|█▏        | 24/200 [04:21<31:58, 10.90s/it]

epoch 23, train_loss: 0.0544, valid loss: 0.0668, min_valid_loss: 0.0629, wait: 2 / 20


 12%|█▎        | 25/200 [04:32<31:51, 10.93s/it]

epoch 24, train_loss: 0.0573, valid loss: 0.0486


 13%|█▎        | 26/200 [04:43<31:37, 10.90s/it]

epoch 25, train_loss: 0.0576, valid loss: 0.0765, min_valid_loss: 0.0486, wait: 0 / 20


 14%|█▎        | 27/200 [04:54<31:23, 10.89s/it]

epoch 26, train_loss: 0.0518, valid loss: 0.0598, min_valid_loss: 0.0486, wait: 1 / 20


 14%|█▍        | 28/200 [05:05<31:12, 10.89s/it]

epoch 27, train_loss: 0.0535, valid loss: 0.0547, min_valid_loss: 0.0486, wait: 2 / 20


 14%|█▍        | 29/200 [05:16<31:00, 10.88s/it]

epoch 28, train_loss: 0.0559, valid loss: 0.0623, min_valid_loss: 0.0486, wait: 3 / 20


 15%|█▌        | 30/200 [05:27<30:54, 10.91s/it]

epoch 29, train_loss: 0.0533, valid loss: 0.0413


 16%|█▌        | 31/200 [05:37<30:45, 10.92s/it]

epoch 30, train_loss: 0.0527, valid loss: 0.0600, min_valid_loss: 0.0413, wait: 0 / 20


 16%|█▌        | 32/200 [05:48<30:31, 10.90s/it]

epoch 31, train_loss: 0.0633, valid loss: 0.0662, min_valid_loss: 0.0413, wait: 1 / 20


 16%|█▋        | 33/200 [05:59<30:19, 10.89s/it]

epoch 32, train_loss: 0.0576, valid loss: 0.0569, min_valid_loss: 0.0413, wait: 2 / 20


 17%|█▋        | 34/200 [06:10<30:07, 10.89s/it]

epoch 33, train_loss: 0.0560, valid loss: 0.0494, min_valid_loss: 0.0413, wait: 3 / 20


 18%|█▊        | 35/200 [06:21<29:55, 10.88s/it]

epoch 34, train_loss: 0.0474, valid loss: 0.0591, min_valid_loss: 0.0413, wait: 4 / 20


 18%|█▊        | 36/200 [06:32<29:44, 10.88s/it]

epoch 35, train_loss: 0.0473, valid loss: 0.0530, min_valid_loss: 0.0413, wait: 5 / 20


 18%|█▊        | 37/200 [06:43<29:32, 10.87s/it]

epoch 36, train_loss: 0.0518, valid loss: 0.0629, min_valid_loss: 0.0413, wait: 6 / 20


 19%|█▉        | 38/200 [06:54<29:20, 10.87s/it]

epoch 37, train_loss: 0.0443, valid loss: 0.0513, min_valid_loss: 0.0413, wait: 7 / 20


 20%|█▉        | 39/200 [07:04<29:10, 10.87s/it]

epoch 38, train_loss: 0.0555, valid loss: 0.0542, min_valid_loss: 0.0413, wait: 8 / 20


 20%|██        | 40/200 [07:15<28:58, 10.87s/it]

epoch 39, train_loss: 0.0583, valid loss: 0.0523, min_valid_loss: 0.0413, wait: 9 / 20


 20%|██        | 41/200 [07:26<28:54, 10.91s/it]

epoch 40, train_loss: 0.0488, valid loss: 0.0409


 21%|██        | 42/200 [07:37<28:41, 10.89s/it]

epoch 41, train_loss: 0.0467, valid loss: 0.0476, min_valid_loss: 0.0409, wait: 0 / 20


 22%|██▏       | 43/200 [07:48<28:28, 10.88s/it]

epoch 42, train_loss: 0.0434, valid loss: 0.0514, min_valid_loss: 0.0409, wait: 1 / 20


 22%|██▏       | 44/200 [07:59<28:16, 10.88s/it]

epoch 43, train_loss: 0.0414, valid loss: 0.0539, min_valid_loss: 0.0409, wait: 2 / 20


 22%|██▎       | 45/200 [08:10<28:04, 10.87s/it]

epoch 44, train_loss: 0.0449, valid loss: 0.0511, min_valid_loss: 0.0409, wait: 3 / 20


 23%|██▎       | 46/200 [08:21<27:53, 10.86s/it]

epoch 45, train_loss: 0.0439, valid loss: 0.0440, min_valid_loss: 0.0409, wait: 4 / 20


 24%|██▎       | 47/200 [08:31<27:41, 10.86s/it]

epoch 46, train_loss: 0.0442, valid loss: 0.0431, min_valid_loss: 0.0409, wait: 5 / 20


 24%|██▍       | 48/200 [08:42<27:31, 10.86s/it]

epoch 47, train_loss: 0.0549, valid loss: 0.0659, min_valid_loss: 0.0409, wait: 6 / 20


 24%|██▍       | 49/200 [08:53<27:19, 10.86s/it]

epoch 48, train_loss: 0.0472, valid loss: 0.0518, min_valid_loss: 0.0409, wait: 7 / 20


 25%|██▌       | 50/200 [09:04<27:08, 10.86s/it]

epoch 49, train_loss: 0.0471, valid loss: 0.0694, min_valid_loss: 0.0409, wait: 8 / 20


 26%|██▌       | 51/200 [09:15<26:58, 10.86s/it]

epoch 50, train_loss: 0.0415, valid loss: 0.0469, min_valid_loss: 0.0409, wait: 9 / 20


 26%|██▌       | 52/200 [09:26<26:47, 10.86s/it]

epoch 51, train_loss: 0.0448, valid loss: 0.0481, min_valid_loss: 0.0409, wait: 10 / 20


 26%|██▋       | 53/200 [09:37<26:42, 10.90s/it]

epoch 52, train_loss: 0.0495, valid loss: 0.0380


 27%|██▋       | 54/200 [09:48<26:29, 10.89s/it]

epoch 53, train_loss: 0.0475, valid loss: 0.0570, min_valid_loss: 0.0380, wait: 0 / 20


 28%|██▊       | 55/200 [09:58<26:17, 10.88s/it]

epoch 54, train_loss: 0.0401, valid loss: 0.0510, min_valid_loss: 0.0380, wait: 1 / 20


 28%|██▊       | 56/200 [10:09<26:05, 10.87s/it]

epoch 55, train_loss: 0.0483, valid loss: 0.0589, min_valid_loss: 0.0380, wait: 2 / 20


 28%|██▊       | 57/200 [10:20<26:02, 10.92s/it]

epoch 56, train_loss: 0.0421, valid loss: 0.0393, min_valid_loss: 0.0380, wait: 3 / 20


 29%|██▉       | 58/200 [10:31<25:57, 10.97s/it]

epoch 57, train_loss: 0.0400, valid loss: 0.0570, min_valid_loss: 0.0380, wait: 4 / 20


 30%|██▉       | 59/200 [10:42<25:41, 10.93s/it]

epoch 58, train_loss: 0.0423, valid loss: 0.0553, min_valid_loss: 0.0380, wait: 5 / 20


 30%|███       | 60/200 [10:53<25:27, 10.91s/it]

epoch 59, train_loss: 0.0393, valid loss: 0.0517, min_valid_loss: 0.0380, wait: 6 / 20


 30%|███       | 61/200 [11:04<25:13, 10.89s/it]

epoch 60, train_loss: 0.0410, valid loss: 0.0596, min_valid_loss: 0.0380, wait: 7 / 20


 31%|███       | 62/200 [11:15<25:01, 10.88s/it]

epoch 61, train_loss: 0.0405, valid loss: 0.0410, min_valid_loss: 0.0380, wait: 8 / 20


 32%|███▏      | 63/200 [11:26<24:49, 10.87s/it]

epoch 62, train_loss: 0.0459, valid loss: 0.0640, min_valid_loss: 0.0380, wait: 9 / 20


 32%|███▏      | 64/200 [11:37<24:37, 10.87s/it]

epoch 63, train_loss: 0.0383, valid loss: 0.0382, min_valid_loss: 0.0380, wait: 10 / 20


 32%|███▎      | 65/200 [11:47<24:26, 10.87s/it]

epoch 64, train_loss: 0.0387, valid loss: 0.0504, min_valid_loss: 0.0380, wait: 11 / 20


 33%|███▎      | 66/200 [11:58<24:21, 10.91s/it]

epoch 65, train_loss: 0.0370, valid loss: 0.0375


 34%|███▎      | 67/200 [12:09<24:08, 10.89s/it]

epoch 66, train_loss: 0.0415, valid loss: 0.0487, min_valid_loss: 0.0375, wait: 0 / 20


 34%|███▍      | 68/200 [12:20<23:55, 10.88s/it]

epoch 67, train_loss: 0.0380, valid loss: 0.0429, min_valid_loss: 0.0375, wait: 1 / 20


 34%|███▍      | 69/200 [12:31<23:44, 10.87s/it]

epoch 68, train_loss: 0.0394, valid loss: 0.0432, min_valid_loss: 0.0375, wait: 2 / 20


 35%|███▌      | 70/200 [12:42<23:32, 10.87s/it]

epoch 69, train_loss: 0.0420, valid loss: 0.0429, min_valid_loss: 0.0375, wait: 3 / 20


 36%|███▌      | 71/200 [12:53<23:29, 10.93s/it]

epoch 70, train_loss: 0.0439, valid loss: 0.0462, min_valid_loss: 0.0375, wait: 4 / 20


 36%|███▌      | 72/200 [13:04<23:23, 10.97s/it]

epoch 71, train_loss: 0.0385, valid loss: 0.0400, min_valid_loss: 0.0375, wait: 5 / 20


 36%|███▋      | 73/200 [13:15<23:24, 11.06s/it]

epoch 72, train_loss: 0.0377, valid loss: 0.0348


 37%|███▋      | 74/200 [13:26<23:13, 11.06s/it]

epoch 73, train_loss: 0.0398, valid loss: 0.0391, min_valid_loss: 0.0348, wait: 0 / 20


 38%|███▊      | 75/200 [13:37<23:03, 11.07s/it]

epoch 74, train_loss: 0.0409, valid loss: 0.0366, min_valid_loss: 0.0348, wait: 1 / 20


 38%|███▊      | 76/200 [13:48<22:48, 11.04s/it]

epoch 75, train_loss: 0.0391, valid loss: 0.0488, min_valid_loss: 0.0348, wait: 2 / 20


 38%|███▊      | 77/200 [13:59<22:39, 11.05s/it]

epoch 76, train_loss: 0.0404, valid loss: 0.0385, min_valid_loss: 0.0348, wait: 3 / 20


 39%|███▉      | 78/200 [14:11<22:36, 11.12s/it]

epoch 77, train_loss: 0.0444, valid loss: 0.0317


 40%|███▉      | 79/200 [14:22<22:24, 11.11s/it]

epoch 78, train_loss: 0.0464, valid loss: 0.0445, min_valid_loss: 0.0317, wait: 0 / 20


 40%|████      | 80/200 [14:33<22:12, 11.10s/it]

epoch 79, train_loss: 0.0364, valid loss: 0.0383, min_valid_loss: 0.0317, wait: 1 / 20


 40%|████      | 81/200 [14:44<22:01, 11.10s/it]

epoch 80, train_loss: 0.0386, valid loss: 0.0461, min_valid_loss: 0.0317, wait: 2 / 20


 41%|████      | 82/200 [14:55<21:49, 11.09s/it]

epoch 81, train_loss: 0.0397, valid loss: 0.0427, min_valid_loss: 0.0317, wait: 3 / 20


 42%|████▏     | 83/200 [15:06<21:37, 11.09s/it]

epoch 82, train_loss: 0.0430, valid loss: 0.0410, min_valid_loss: 0.0317, wait: 4 / 20


 42%|████▏     | 84/200 [15:17<21:26, 11.09s/it]

epoch 83, train_loss: 0.0438, valid loss: 0.0401, min_valid_loss: 0.0317, wait: 5 / 20


 42%|████▎     | 85/200 [15:28<21:15, 11.10s/it]

epoch 84, train_loss: 0.0449, valid loss: 0.0414, min_valid_loss: 0.0317, wait: 6 / 20


 43%|████▎     | 86/200 [15:39<21:05, 11.10s/it]

epoch 85, train_loss: 0.0534, valid loss: 0.0477, min_valid_loss: 0.0317, wait: 7 / 20


 44%|████▎     | 87/200 [15:51<20:54, 11.10s/it]

epoch 86, train_loss: 0.0435, valid loss: 0.0532, min_valid_loss: 0.0317, wait: 8 / 20


 44%|████▍     | 88/200 [16:02<20:43, 11.10s/it]

epoch 87, train_loss: 0.0378, valid loss: 0.0443, min_valid_loss: 0.0317, wait: 9 / 20


 44%|████▍     | 89/200 [16:13<20:32, 11.11s/it]

epoch 88, train_loss: 0.0339, valid loss: 0.0470, min_valid_loss: 0.0317, wait: 10 / 20


 45%|████▌     | 90/200 [16:24<20:21, 11.10s/it]

epoch 89, train_loss: 0.0351, valid loss: 0.0566, min_valid_loss: 0.0317, wait: 11 / 20


 46%|████▌     | 91/200 [16:35<20:09, 11.10s/it]

epoch 90, train_loss: 0.0364, valid loss: 0.0443, min_valid_loss: 0.0317, wait: 12 / 20


 46%|████▌     | 92/200 [16:46<19:58, 11.09s/it]

epoch 91, train_loss: 0.0369, valid loss: 0.0550, min_valid_loss: 0.0317, wait: 13 / 20


 46%|████▋     | 93/200 [16:57<19:46, 11.09s/it]

epoch 92, train_loss: 0.0337, valid loss: 0.0408, min_valid_loss: 0.0317, wait: 14 / 20


 47%|████▋     | 94/200 [17:08<19:35, 11.09s/it]

epoch 93, train_loss: 0.0369, valid loss: 0.0425, min_valid_loss: 0.0317, wait: 15 / 20


 48%|████▊     | 95/200 [17:19<19:23, 11.08s/it]

epoch 94, train_loss: 0.0419, valid loss: 0.0349, min_valid_loss: 0.0317, wait: 16 / 20


 48%|████▊     | 96/200 [17:30<19:12, 11.08s/it]

epoch 95, train_loss: 0.0424, valid loss: 0.0476, min_valid_loss: 0.0317, wait: 17 / 20


 48%|████▊     | 97/200 [17:41<19:01, 11.08s/it]

epoch 96, train_loss: 0.0396, valid loss: 0.0509, min_valid_loss: 0.0317, wait: 18 / 20


 48%|████▊     | 97/200 [17:52<18:59, 11.06s/it]

epoch 97, train_loss: 0.0360, valid loss: 0.0580, min_valid_loss: 0.0317, wait: 19 / 20





In [43]:
eval_main_loop(ddpm, "model_best_diffusion.pt", testset, DEVICE, 0.5, 0.5)

  4%|▍         | 4/100 [00:40<16:18, 10.19s/it]
