# Music 103 diffusion version

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import numpy as np
import copy
import pandas as pd
from tqdm import tqdm
from os.path import exists
from os import remove, chdir
import pickle

In [14]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output
    
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))
    
class PositionalEncoding(nn.Module):
    def __init__(self, d_model):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model

    def forward(self, x):
        max_len = x.size(1)
        pe = torch.zeros(max_len, self.d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * -(math.log(10000.0) / self.d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0).to(x.device)
        return x + pe

    
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

class EmbedHead(nn.Module):
    def __init__(
        self,
        input_dim,
        inner_dim_1,
        inner_dim_2,
        out_dim
    ):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, inner_dim_1)
        self.linear2 = nn.Linear(inner_dim_1, inner_dim_2)
        self.linear3 = nn.Linear(inner_dim_2, out_dim)
        self.activation_fn = nn.functional.gelu

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation_fn(x)
        x = self.linear2(x)
        x = self.activation_fn(x)
        x = self.linear3(x)
        return x
    

class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        return self.model(x)


class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super().__init__()
        self.encoder_embedding = EmbedHead(src_vocab_size + tgt_vocab_size, d_model, d_model, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.time_embeddings = nn.ModuleList([EmbedFC(1, d_model) for _ in range(num_layers)])

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.output_layer = nn.Sequential(PositionWiseFeedForward(d_model, d_ff), nn.Linear(d_model, tgt_vocab_size))
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, x, time):
        embedding = self.dropout(self.positional_encoding(self.encoder_embedding(torch.cat([src, x], dim=-1))))

        for i, enc_layer in enumerate(self.encoder_layers):
            time_embedding = self.time_embeddings[i](time).unsqueeze(1)
            embedding = enc_layer(embedding + time_embedding, None)
        
        output = self.output_layer(embedding)
        
        return output
    


In [19]:
class VQVAE(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, dropout, codebook_size, d_codebook):
        super().__init__()
        self.encoder_embedding = EmbedHead(vocab_size, d_model, d_model, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.encoder_output = nn.Linear(d_model, d_codebook)
        self.codebook = nn.Embedding(codebook_size, d_codebook)
        self.decoder_embedding = EmbedHead(d_codebook, d_model, d_model, d_model)
        self.decoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_output = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def encode(self, x):
        embedding = self.dropout(self.positional_encoding(self.encoder_embedding(x)))
        for i, enc_layer in enumerate(self.encoder_layers):
            embedding = enc_layer(embedding, None)
        return self.encoder_output(embedding)
    
    def vq(self, z):
        # z: [batch_size, seq_length, d_codebook]
        distance = (z.unsqueeze(2) - self.codebook.weight.unsqueeze(0).unsqueeze(0)).pow(2).mean(dim=-1)
        _, indices = torch.min(distance, dim=-1)
        return self.codebook(indices)

    def decode(self, z):
        embedding = self.dropout(self.positional_encoding(self.decoder_embedding(z)))
        for i, dec_layer in enumerate(self.decoder_layers):
            embedding = dec_layer(embedding, None)
        return torch.sigmoid(self.decoder_output(embedding))
    
    def forward(self, x):
        # x: [batch_size, seq_length, vocab_size]
        z = self.encode(x)
        z_vq = self.vq(z)
        z_straight_through = (z_vq - z).detach() + z
        x_recon = self.decode(z_straight_through)
        recon_loss = nn.functional.mse_loss(x_recon, x)
        embed_loss = nn.functional.mse_loss(z_vq, z.detach())
        commit_loss = nn.functional.mse_loss(z, z_vq.detach())
        return x_recon, recon_loss, embed_loss, commit_loss


In [20]:
from tqdm import tqdm

def train_VQVAE(vqvae, optim, trainset, validset, lr, n_epoch, device, patience, alpha=0.5,beta=1):
    wait = 0
    min_valid_loss = float('inf')
    for ep in tqdm(range(n_epoch)):
        vqvae.train()

        # linear lrate decay
        optim.param_groups[0]['lr'] = lr*(1-ep/n_epoch)
        loss_ema = None
        # train
        for idx, src, tgt in trainset:
            optim.zero_grad()
            tgt = tgt.to(device)
            src = src.to(device)
            _, recon_loss, embed_loss, commit_loss = vqvae(tgt)
            loss = recon_loss + beta * embed_loss + alpha * commit_loss
            loss.backward()
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            optim.step()
            
        # validation
        vqvae.eval()
        total_loss = 0
        with torch.no_grad():
            for idx, src, tgt in validset:
                tgt = tgt.to(device)
                src = src.to(device)
                _, recon_loss, embed_loss, commit_loss = vqvae(tgt)
                loss = recon_loss
                total_loss += loss.item()
        avg_valid_loss = total_loss / len(validset)

        # early stopping
        if avg_valid_loss < min_valid_loss:
            min_valid_loss = avg_valid_loss
            torch.save(vqvae.state_dict(), f"model_best_vqvae.pt")
            print(f'epoch {ep}, train_loss: {loss_ema:.4f}, valid loss: {avg_valid_loss:.4f}')
            wait = 0
        else:
            print(f'epoch {ep}, train_loss: {loss_ema:.4f}, valid loss: {avg_valid_loss:.4f}, min_valid_loss: {min_valid_loss:.4f}, wait: {wait} / {patience}')
            wait += 1
        if wait >= patience:
            break

In [17]:
# hardcoding these here
n_epoch = 200
n_T = 1000
n_feat = 128
lr = 1e-4
ws_test = [0.0, 0.5, 2.0]
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

src_vocab_size = 12
tgt_vocab_size = 12
d_model = 512
num_heads = 8
num_layers = 4
d_ff = 4096//8
max_seq_length = 2400
dropout = 0.1
batchsize = 16
mode = "train"


if exists("trainset_w.pkl") and exists("validset_w.pkl") and exists("testset_w.pkl"):
    print("splitted dataset found!")
    with open("trainset_w.pkl", "rb") as f:
        trainset = pickle.load(f)
    with open("validset_w.pkl", "rb") as f:
        validset = pickle.load(f)
    with open("testset_w.pkl", "rb") as f:
        testset = pickle.load(f)
else:
    print("?")

def collate_fn(batch):
    # Unpack batch into individual components
    idx, src_data, tgt_data, w = zip(*batch)
    #print(len(rates[0]), len(tgt_data[0]), len(src_data[0]))
    
    # Convert `src_data`, `tgt_data`, and `rates` to tensors if they are not already
    src_data = [torch.tensor(s, dtype=torch.float32) if not isinstance(s, torch.Tensor) else s for s in src_data]
    tgt_data = [torch.tensor(t, dtype=torch.float32) if not isinstance(t, torch.Tensor) else t for t in tgt_data]

    src_data = [torch.cat([s], dim=-1) for s in src_data]
    tgt_data = [torch.cat([t], dim=-1) for t in tgt_data]

    # Pad src_data
    src_data = nn.utils.rnn.pad_sequence(src_data, batch_first=True, padding_value=0.).to(DEVICE)

    # Pad tgt_data
    tgt_data = nn.utils.rnn.pad_sequence(tgt_data, batch_first=True, padding_value=0).to(DEVICE)

    # Extract the last dimension and one-hot encode it
    return idx, src_data, tgt_data


trainset = data.DataLoader(trainset, batch_size=batchsize, collate_fn=collate_fn)
validset = data.DataLoader(validset, batch_size=1, collate_fn=collate_fn)
testset = data.DataLoader(testset, batch_size=1, collate_fn=collate_fn)

splitted dataset found!


In [22]:
vqvae = VQVAE(tgt_vocab_size, d_model, num_heads, 4, d_ff, dropout, 256, 12).to(DEVICE)
optim = torch.optim.Adam(vqvae.parameters(), lr=lr)
train_VQVAE(vqvae, optim, trainset, validset, lr, n_epoch, DEVICE, 20)


  0%|          | 1/200 [00:19<1:03:57, 19.28s/it]

epoch 0, train_loss: 0.2588, valid loss: 0.1958


  1%|          | 2/200 [00:38<1:02:42, 19.00s/it]

epoch 1, train_loss: 0.1498, valid loss: 0.1937


  2%|▏         | 3/200 [00:57<1:03:14, 19.26s/it]

epoch 2, train_loss: 0.1429, valid loss: 0.1929


  2%|▏         | 4/200 [01:16<1:02:43, 19.20s/it]

epoch 3, train_loss: 0.1395, valid loss: 0.1923


  2%|▎         | 5/200 [01:34<1:01:14, 18.84s/it]

epoch 4, train_loss: 0.1369, valid loss: 0.1920


  3%|▎         | 6/200 [01:53<1:00:15, 18.64s/it]

epoch 5, train_loss: 0.1348, valid loss: 0.1919


  4%|▎         | 7/200 [02:11<59:30, 18.50s/it]  

epoch 6, train_loss: 0.1329, valid loss: 0.1918


  4%|▍         | 8/200 [02:29<59:10, 18.49s/it]

epoch 7, train_loss: 0.1312, valid loss: 0.1918, min_valid_loss: 0.1918, wait: 0 / 20


  4%|▍         | 9/200 [02:48<58:36, 18.41s/it]

epoch 8, train_loss: 0.1300, valid loss: 0.1918, min_valid_loss: 0.1918, wait: 1 / 20


  5%|▌         | 10/200 [03:06<58:36, 18.51s/it]

epoch 9, train_loss: 0.1290, valid loss: 0.1919, min_valid_loss: 0.1918, wait: 2 / 20


  6%|▌         | 11/200 [03:25<58:00, 18.41s/it]

epoch 10, train_loss: 0.1284, valid loss: 0.1920, min_valid_loss: 0.1918, wait: 3 / 20


  6%|▌         | 12/200 [03:43<57:33, 18.37s/it]

epoch 11, train_loss: 0.1280, valid loss: 0.1922, min_valid_loss: 0.1918, wait: 4 / 20


  6%|▋         | 13/200 [04:01<57:20, 18.40s/it]

epoch 12, train_loss: 0.1276, valid loss: 0.1923, min_valid_loss: 0.1918, wait: 5 / 20


  7%|▋         | 14/200 [04:19<56:49, 18.33s/it]

epoch 13, train_loss: 0.1274, valid loss: 0.1925, min_valid_loss: 0.1918, wait: 6 / 20


  8%|▊         | 15/200 [04:37<56:06, 18.20s/it]

epoch 14, train_loss: 0.1272, valid loss: 0.1927, min_valid_loss: 0.1918, wait: 7 / 20


  8%|▊         | 16/200 [04:56<55:48, 18.20s/it]

epoch 15, train_loss: 0.1271, valid loss: 0.1928, min_valid_loss: 0.1918, wait: 8 / 20


  8%|▊         | 17/200 [05:14<56:01, 18.37s/it]

epoch 16, train_loss: 0.1270, valid loss: 0.1930, min_valid_loss: 0.1918, wait: 9 / 20


  9%|▉         | 18/200 [05:32<55:16, 18.22s/it]

epoch 17, train_loss: 0.1269, valid loss: 0.1931, min_valid_loss: 0.1918, wait: 10 / 20


 10%|▉         | 19/200 [05:51<55:05, 18.26s/it]

epoch 18, train_loss: 0.1268, valid loss: 0.1934, min_valid_loss: 0.1918, wait: 11 / 20


 10%|█         | 20/200 [06:09<55:09, 18.39s/it]

epoch 19, train_loss: 0.1267, valid loss: 0.1935, min_valid_loss: 0.1918, wait: 12 / 20


 10%|█         | 21/200 [06:27<54:25, 18.25s/it]

epoch 20, train_loss: 0.1266, valid loss: 0.1937, min_valid_loss: 0.1918, wait: 13 / 20


 11%|█         | 22/200 [06:45<53:46, 18.13s/it]

epoch 21, train_loss: 0.1266, valid loss: 0.1939, min_valid_loss: 0.1918, wait: 14 / 20


 12%|█▏        | 23/200 [07:03<53:30, 18.14s/it]

epoch 22, train_loss: 0.1265, valid loss: 0.1941, min_valid_loss: 0.1918, wait: 15 / 20


 12%|█▏        | 23/200 [07:07<54:49, 18.59s/it]


KeyboardInterrupt: 

In [37]:
def eval_vqvae(vqvae, checkpoint, testset):
    vqvae.load_state_dict(torch.load(checkpoint))
    vqvae.eval()
    x_gens = []
    count = 0
    with torch.no_grad():
        for idx, src, tgt in tqdm(testset, total=len(testset)):
            if count > 3:
                break
            x_gen, _, _, _ = vqvae(tgt)
            print(x_gen)
            x_gen = (x_gen >= torch.quantile(x_gen, 0.66, dim=-1, keepdim=True)).long()
            
            x_gens.append((idx, x_gen))
            count += 1

    torch.save(x_gens, "song_test_music103.pt")

In [38]:
eval_vqvae(vqvae, "model_best_vqvae.pt", testset)

  4%|▍         | 4/100 [00:00<00:02, 45.08it/s]

tensor([[[0.1133, 0.1141, 0.0905,  ..., 0.0930, 0.0840, 0.1269],
         [0.1173, 0.1253, 0.0941,  ..., 0.1006, 0.0864, 0.1244],
         [0.1398, 0.1426, 0.1165,  ..., 0.1231, 0.1110, 0.1434],
         ...,
         [0.0379, 0.0384, 0.0307,  ..., 0.0345, 0.0251, 0.0398],
         [0.0400, 0.0372, 0.0323,  ..., 0.0339, 0.0257, 0.0404],
         [0.0420, 0.0360, 0.0330,  ..., 0.0323, 0.0260, 0.0412]]])
tensor([[[0.1137, 0.1135, 0.0914,  ..., 0.0937, 0.0842, 0.1270],
         [0.1177, 0.1246, 0.0951,  ..., 0.1013, 0.0866, 0.1245],
         [0.1403, 0.1419, 0.1178,  ..., 0.1241, 0.1114, 0.1437],
         ...,
         [0.0376, 0.0361, 0.0275,  ..., 0.0345, 0.0242, 0.0386],
         [0.0379, 0.0356, 0.0278,  ..., 0.0329, 0.0248, 0.0388],
         [0.0376, 0.0357, 0.0279,  ..., 0.0312, 0.0249, 0.0392]]])
tensor([[[0.1456, 0.1365, 0.1331,  ..., 0.1405, 0.1132, 0.1604],
         [0.1503, 0.1469, 0.1378,  ..., 0.1503, 0.1164, 0.1581],
         [0.1683, 0.1581, 0.1584,  ..., 0.1675, 0.1398, 0.




# **1. DDPM**


# a. Building Blocks

# b. DDPM Schedules

In [45]:
def ddpm_schedules(beta1, beta2, T):
    assert beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"

    ##################
    ### Problem 1 (a): Implement ddpm_schedules()
    beta_t = torch.linspace(beta1, beta2, T).float()

    alpha_t = 1 - beta_t
    oneover_sqrta = 1 / torch.sqrt(alpha_t)
    sqrt_beta_t = torch.sqrt(beta_t)
    alphabar_t = torch.cumprod(alpha_t, dim=0)
    sqrtab = torch.sqrt(alphabar_t)
    sqrtmab = torch.sqrt(1 - alphabar_t)
    mab_over_sqrtmab_inv = (1 - alpha_t) / torch.sqrt(1 - alphabar_t)
    ##################
    ##################

    return {
        "alpha_t": alpha_t,  # \alpha_t
        "oneover_sqrta": oneover_sqrta,  # 1/\sqrt{\alpha_t}
        "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
        "alphabar_t": alphabar_t,  # \bar{\alpha_t}
        "sqrtab": sqrtab,  # \sqrt{\bar{\alpha_t}}
        "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
        "mab_over_sqrtmab": mab_over_sqrtmab_inv,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
    }

# c. DDPM Main Module



Here the noise $\sigma_t^2=\beta_t$

In [46]:
class DDPM(nn.Module):
    def __init__(self, nn_model, betas, n_T, device, n_inference=None, drop_prob=0.1):
        super(DDPM, self).__init__()
        self.nn_model = nn_model.to(device)

        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)
        
        self.n_T = n_T
        self.n_inference = n_inference if n_inference else n_T 
        
        for k, v in ddpm_schedules(betas[0], betas[1], self.n_inference).items():
            self.register_buffer(k+'_KAIMING', v)

        self.device = device
        self.drop_prob = drop_prob
        self.loss_mse = nn.MSELoss()

    def forward(self, src, tgt):
        ##################
        ### Problem 1 (b): Implement forward()
        t = torch.randint(0, self.n_T, (tgt.size(0),), device=self.device)
        sqrtab_t, sqrtmab_t = self.sqrtab[t].view(-1, 1, 1), \
            self.sqrtmab[t].view(-1, 1, 1)

        noise = torch.randn_like(tgt).to(self.device)  # Define noise tensor
        x_t = sqrtab_t * tgt + sqrtmab_t * noise

        # mask out with probability
        # context_mask = torch.bernoulli(torch.zeros(src.shape[0])+self.drop_prob).unsqueeze(-1).unsqueeze(-1).to(self.device)

        pred_noise = self.nn_model(src, x_t, t / (self.n_T - 1))
        loss = self.loss_mse(pred_noise, noise) 
        ##################
        ##################

        return loss

    @torch.no_grad()
    def sample(self, src, guide_w=0.0):
        n_sample = src.shape[0]
        x_i = torch.randn(*src.shape).to(self.device)

        for i in range(int(self.n_inference), 0, -1):

            ##################
            ### Problem 1 (c): Implement sample()
            t = torch.full((n_sample,), (i - 1) / (self.n_inference - 1)).to(self.device).float()
            t_i = t.view(-1, 1, 1)

            # double batch

            z = torch.randn(*src.shape).to(self.device) if i > 1 else 0 # if last step, no noise
            # classifier-free guidance
            pred_full = self.nn_model(src, x_i, t_i)
            x_i = self.oneover_sqrta_KAIMING[i - 1] * (x_i - pred_full * self.mab_over_sqrtmab_KAIMING[i - 1])\
                + self.sqrt_beta_t_KAIMING[i - 1] * z
        return x_i

# c. Training Function

In [47]:
from tqdm import tqdm

def train_main_loop(ddpm, optim, trainset, validset, testset, lr, n_epoch, device, guide_w, patience):
    wait = 0
    min_valid_loss = float('inf')
    for ep in tqdm(range(n_epoch)):
        ddpm.train()

        # linear lrate decay
        optim.param_groups[0]['lr'] = lr*(1-ep/n_epoch)
        loss_ema = None
        # train
        for idx, src, tgt in trainset:
            optim.zero_grad()
            tgt = tgt.to(device)
            src = src.to(device)
            loss = ddpm(src, tgt)
            loss.backward()
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
            optim.step()
            
        # validation
        ddpm.eval()
        total_loss = 0
        with torch.no_grad():
            for idx, src, tgt in validset:
                tgt = tgt.to(device)
                src = src.to(device)
                loss = ddpm(src, tgt)
                total_loss += loss.item()
        avg_valid_loss = total_loss / len(validset)

        # early stopping
        if avg_valid_loss < min_valid_loss:
            min_valid_loss = avg_valid_loss
            torch.save(ddpm.nn_model.state_dict(), f"model_best_diffusion.pt")
            print(f'epoch {ep}, train_loss: {loss_ema:.4f}, valid loss: {avg_valid_loss:.4f}')
            wait = 0
        else:
            print(f'epoch {ep}, train_loss: {loss_ema:.4f}, valid loss: {avg_valid_loss:.4f}, min_valid_loss: {min_valid_loss:.4f}, wait: {wait} / {patience}')
            wait += 1
        if wait >= patience:
            break

    # # eval
    # ddpm.eval()
    # x_gens = []
    # count = 0
    # with torch.no_grad():
    #     for idx, src, tgt in tqdm(testset, total=len(testset)):
    #         if count > 3:
    #             break
    #         x_gens.append((idx, (ddpm.sample(src, guide_w) >= 0.5).long()))
    #         count += 1

    # torch.save(x_gens, "song_test_music103.pt")

def eval_main_loop(ddpm, checkpoint, testset, device, guide_w, rate=0.5):
    ddpm.nn_model.load_state_dict(torch.load(checkpoint))
    ddpm.eval()
    x_gens = []
    count = 0
    with torch.no_grad():
        for idx, src, tgt in tqdm(testset, total=len(testset)):
            if count > 3:
                break
            x_gen = ddpm.sample(src, guide_w)
            x_gen = ((x_gen >= torch.quantile(x_gen, 0.66, dim=-1, keepdim=True)) & (x_gen >= rate)).long()
            x_gens.append((idx, x_gen))
            count += 1

    torch.save(x_gens, "song_test_music103.pt")

# e. Training


splitted dataset found!


In [49]:
transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout).to(DEVICE)
ddpm = DDPM(nn_model=transformer, betas=(1e-4, 0.02), n_T=n_T, \
    device=DEVICE, n_inference=1000, drop_prob=0.1)
ddpm.to(DEVICE)
optim = torch.optim.Adam(ddpm.parameters(), lr=lr)
train_main_loop(ddpm, optim, trainset, validset, testset, lr, n_epoch, DEVICE, 0, 20)


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 1/200 [00:03<13:12,  3.98s/it]

epoch 0, train_loss: 0.7961, valid loss: 0.4862


  1%|          | 2/200 [00:07<13:09,  3.99s/it]

epoch 1, train_loss: 0.2770, valid loss: 0.1736


  2%|▏         | 3/200 [00:11<13:05,  3.99s/it]

epoch 2, train_loss: 0.1961, valid loss: 0.1186


  2%|▏         | 4/200 [00:15<13:01,  3.99s/it]

epoch 3, train_loss: 0.1822, valid loss: 0.1173


  2%|▎         | 5/200 [00:19<12:54,  3.97s/it]

epoch 4, train_loss: 0.1443, valid loss: 0.1291, min_valid_loss: 0.1173, wait: 0 / 20


  3%|▎         | 6/200 [00:23<12:47,  3.96s/it]

epoch 5, train_loss: 0.1629, valid loss: 0.1384, min_valid_loss: 0.1173, wait: 1 / 20


  4%|▎         | 7/200 [00:27<12:42,  3.95s/it]

epoch 6, train_loss: 0.1139, valid loss: 0.1499, min_valid_loss: 0.1173, wait: 2 / 20


  4%|▍         | 8/200 [00:31<12:41,  3.96s/it]

epoch 7, train_loss: 0.1186, valid loss: 0.0975


  4%|▍         | 9/200 [00:35<12:35,  3.96s/it]

epoch 8, train_loss: 0.1358, valid loss: 0.1037, min_valid_loss: 0.0975, wait: 0 / 20


  5%|▌         | 10/200 [00:39<12:31,  3.95s/it]

epoch 9, train_loss: 0.1357, valid loss: 0.1358, min_valid_loss: 0.0975, wait: 1 / 20


  6%|▌         | 11/200 [00:43<12:29,  3.97s/it]

epoch 10, train_loss: 0.0975, valid loss: 0.0846


  6%|▌         | 12/200 [00:47<12:27,  3.98s/it]

epoch 11, train_loss: 0.0953, valid loss: 0.0795


  6%|▋         | 13/200 [00:51<12:22,  3.97s/it]

epoch 12, train_loss: 0.0891, valid loss: 0.1167, min_valid_loss: 0.0795, wait: 0 / 20


  7%|▋         | 14/200 [00:55<12:16,  3.96s/it]

epoch 13, train_loss: 0.0764, valid loss: 0.0943, min_valid_loss: 0.0795, wait: 1 / 20


  8%|▊         | 15/200 [00:59<12:16,  3.98s/it]

epoch 14, train_loss: 0.0756, valid loss: 0.1029, min_valid_loss: 0.0795, wait: 2 / 20


  8%|▊         | 16/200 [01:03<12:14,  3.99s/it]

epoch 15, train_loss: 0.0921, valid loss: 0.0758


  8%|▊         | 17/200 [01:07<12:08,  3.98s/it]

epoch 16, train_loss: 0.0868, valid loss: 0.0844, min_valid_loss: 0.0758, wait: 0 / 20


  9%|▉         | 18/200 [01:11<12:02,  3.97s/it]

epoch 17, train_loss: 0.0673, valid loss: 0.0866, min_valid_loss: 0.0758, wait: 1 / 20


 10%|▉         | 19/200 [01:15<11:56,  3.96s/it]

epoch 18, train_loss: 0.0808, valid loss: 0.0853, min_valid_loss: 0.0758, wait: 2 / 20


 10%|█         | 20/200 [01:19<11:54,  3.97s/it]

epoch 19, train_loss: 0.0734, valid loss: 0.0664


In [34]:
eval_main_loop(ddpm, "model_best_diffusion.pt", testset, DEVICE, 0.1, 0.5)

  0%|          | 0/100 [00:00<?, ?it/s]

  4%|▍         | 4/100 [00:19<07:57,  4.98s/it]
