# Music 103 diffusion version

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import numpy as np
import copy
import pandas as pd
from tqdm import tqdm
from os.path import exists
from os import remove, chdir
import pickle

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output
        
    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1, 2)
        
    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        return x.transpose(1, 2).contiguous().view(batch_size, seq_length, self.d_model)
        
    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q))
        K = self.split_heads(self.W_k(K))
        V = self.split_heads(self.W_v(V))
        
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))
        return output
    
class PositionWiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super(PositionWiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))
    
class PositionalEncoding(nn.Module):
    def __init__(self, d_model):
        super(PositionalEncoding, self).__init__()
        self.d_model = d_model

    def forward(self, x):
        max_len = x.size(1)
        pe = torch.zeros(max_len, self.d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * -(math.log(10000.0) / self.d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        
        pe = pe.unsqueeze(0).to(x.device)
        return x + pe


class DecoderPositionalEncoding(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

    def forward(self, x, tgt):
        # max_len = x.size(1)
        tgt_one_hot = tgt[:, :, 12:]
        tgt_class = torch.argmax(tgt_one_hot, dim=-1)
        pe = torch.zeros_like(x)
        position = torch.cumsum(tgt_class, dim=1).unsqueeze(-1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * -(math.log(10000.0) / self.d_model)).to(position.device)
        
        pe[:, :, 0::2] = torch.sin(position * div_term)
        pe[:, :, 1::2] = torch.cos(position * div_term)
        return x + pe

    
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask):
        attn_output = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x
    
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        attn_output = self.cross_attn(x, enc_output, enc_output, src_mask)
        x = self.norm2(x + self.dropout(attn_output))
        ff_output = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x

class EmbedHead(nn.Module):
    def __init__(
        self,
        input_dim,
        inner_dim_1,
        inner_dim_2,
        out_dim
    ):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, inner_dim_1)
        self.linear2 = nn.Linear(inner_dim_1, inner_dim_2)
        self.linear3 = nn.Linear(inner_dim_2, out_dim)
        self.activation_fn = nn.functional.gelu

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation_fn(x)
        x = self.linear2(x)
        x = self.activation_fn(x)
        x = self.linear3(x)
        return x
    

class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        super(EmbedFC, self).__init__()
        self.input_dim = input_dim
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = x.view(-1, self.input_dim)
        return self.model(x)


class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout):
        super(Transformer, self).__init__()
        self.encoder_embedding = EmbedHead(src_vocab_size, d_model, d_model, d_model)
        self.decoder_embedding = EmbedHead(tgt_vocab_size, d_model, d_model, d_model)
        self.positional_encoding = PositionalEncoding(d_model)

        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        src_mask = (torch.sum(src, dim=2) > 0).unsqueeze(1).unsqueeze(2)
        tgt_mask = (torch.sum(tgt, dim=2) > 0).unsqueeze(1).unsqueeze(3)
        seq_length = tgt.size(1)
        
        d = 1
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal=d)).bool().to(src.device)
        tgt_mask = tgt_mask & nopeak_mask
        return src_mask, tgt_mask

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        if src_mask is None or tgt_mask is None:
            src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embedded = self.dropout(self.positional_encoding(self.encoder_embedding(src)))
        tgt_embedded = self.dropout(self.positional_encoding(self.decoder_embedding(tgt)))

        enc_output = src_embedded
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embedded
        for i, dec_layer in enumerate(self.decoder_layers):
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = self.fc(dec_output)
        return output



In [4]:
class VQVAE(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, dropout, codebook_size, d_codebook):
        super().__init__()
        self.codebook_size = codebook_size
        self.encoder_embedding = EmbedHead(vocab_size, d_model, d_model, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.encoder_output = nn.Linear(d_model, d_codebook)
        self.codebook = nn.Embedding(codebook_size, d_codebook)
        self.codebook.weight.data.uniform_(-1/d_codebook, 1/d_codebook)
        self.decoder_embedding = EmbedHead(d_codebook, d_model, d_model, d_model)
        self.decoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])
        self.decoder_output = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(dropout)

    def encode(self, x):
        embedding = self.dropout(self.positional_encoding(self.encoder_embedding(x)))
        for i, enc_layer in enumerate(self.encoder_layers):
            embedding = enc_layer(embedding, None)
        return self.encoder_output(embedding)
    
    def vq_indices(self, z):
        distance = (z.unsqueeze(2) - self.codebook.weight.unsqueeze(0).unsqueeze(0)).pow(2).mean(dim=-1)
        _, indices = torch.min(distance, dim=-1)
        return indices
    
    def vq_one_hot(self, z):
        indices = self.vq_indices(z)
        one_hot = torch.nn.functional.one_hot(indices, num_classes=self.codebook_size).float()
        return one_hot

    def vq(self, z):
        return self.codebook(self.vq_indices(z))

    def decode(self, z):
        embedding = self.dropout(self.positional_encoding(self.decoder_embedding(z)))
        for i, dec_layer in enumerate(self.decoder_layers):
            embedding = dec_layer(embedding, None)
        return torch.sigmoid(self.decoder_output(embedding))
    
    def forward(self, x):
        # x: [batch_size, seq_length, vocab_size]
        z = self.encode(x)
        z_vq = self.vq(z)
        z_straight_through = (z_vq - z).detach() + z
        x_recon = self.decode(z_straight_through)
        recon_loss = nn.functional.binary_cross_entropy(x_recon, x)
        embed_loss = nn.functional.mse_loss(z_vq, z.detach())
        commit_loss = nn.functional.mse_loss(z, z_vq.detach())
        return x_recon, recon_loss, embed_loss, commit_loss


# **1. DDPM**


# a. Building Blocks

# b. DDPM Schedules

# c. DDPM Main Module



Here the noise $\sigma_t^2=\beta_t$

# c. Training Function

In [5]:
def train_main_loop(transformer, vqvae, optim, trainset, validset, lr, n_epoch, device, patience):
    wait = 0
    min_valid_loss = float('inf')
    for ep in tqdm(range(n_epoch)):
        transformer.train()

        # linear lrate decay
        optim.param_groups[0]['lr'] = lr*(1-ep/n_epoch)
        # train
        criterion = nn.CrossEntropyLoss()
        for idx, src, tgt in trainset:
            optim.zero_grad()
            tgt = tgt.to(device)
            src = src.to(device)
            src_mask, tgt_mask = transformer.generate_mask(src, tgt[:, :-1, :])
            tgt_indices = vqvae.vq_indices(vqvae.encode(tgt))
            tgt_one_hot = torch.nn.functional.one_hot(tgt_indices, num_classes=vqvae.codebook_size).float()
            output = transformer(src, tgt_one_hot[:, :-1, :])
            loss = criterion(output.contiguous().view(-1, vqvae.codebook_size), tgt_indices[:, 1:].contiguous().view(-1))
            loss_train = loss.item()
            loss.backward()
            optim.step()
            
        # validation
        transformer.eval()
        total_loss = 0
        with torch.no_grad():
            for idx, src, tgt in validset:
                tgt = tgt.to(device)
                src = src.to(device)
                src_mask, tgt_mask = transformer.generate_mask(src, tgt[:, :-1, :])
                tgt_indices = vqvae.vq_indices(vqvae.encode(tgt))
                tgt_one_hot = torch.nn.functional.one_hot(tgt_indices, num_classes=vqvae.codebook_size).float()
                output = transformer(src, tgt_one_hot[:, :-1, :], src_mask, tgt_mask)
                loss = criterion(output.contiguous().view(-1, vqvae.codebook_size), tgt_indices[:, 1:].contiguous().view(-1))
                total_loss += loss.item()
        avg_valid_loss = total_loss / len(validset)

        # early stopping
        if avg_valid_loss < min_valid_loss:
            min_valid_loss = avg_valid_loss
            torch.save(transformer.state_dict(), f"model_best_autoreg.pt")
            print(f'epoch {ep}, train_loss: {loss_train:.4f}, valid loss: {avg_valid_loss:.4f}')
            wait = 0
        else:
            print(f'epoch {ep}, train_loss: {loss_train:.4f}, valid loss: {avg_valid_loss:.4f}, min_valid_loss: {min_valid_loss:.4f}, wait: {wait} / {patience}')
            wait += 1
        if wait >= patience:
            break

def eval_main_loop(transformer, vqvae, checkpoint, testset, device, guide_w, rate=0.5):
    transformer.load_state_dict(torch.load(checkpoint))
    transformer.eval()
    x_gens = []
    count = 0
    with torch.no_grad():
        for idx, src, tgt in tqdm(testset, total=len(testset)):
            if count > 10:
                break
            
            tgt_enc = vqvae.vq_one_hot(vqvae.encode(tgt))
            sampled_indices = []
            current_tgt_enc = tgt_enc[:, :1, :]
            for t in range(1, tgt_enc.size(1)):
                output = transformer(src, current_tgt_enc).detach()
                tgt_enc_prediction = torch.softmax(output[:, -1, :], dim=-1)
                # sample from categorical distribution
                sampled_index = torch.multinomial(tgt_enc_prediction, 1)
                sampled_indices.append(sampled_index)
                sampled_one_hot = torch.nn.functional.one_hot(sampled_index, num_classes=vqvae.codebook_size).float()
                current_tgt_enc = torch.cat([current_tgt_enc, sampled_one_hot], dim=1)
            x_gen = vqvae.decode(vqvae.codebook(torch.cat(sampled_indices, dim=1)))
            x_gen = (x_gen >= rate).long()
            x_gens.append((idx, x_gen))
            count += 1

    torch.save(x_gens, "song_test_music103.pt")

# e. Training


In [6]:
# hardcoding these here
n_epoch = 200
n_T = 1000
n_feat = 128
lr = 1e-4
ws_test = [0.0, 0.5, 2.0]
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

src_vocab_size = 12
tgt_vocab_size = 128
d_model = 512
num_heads = 8
num_layers = 4
d_ff = 4096//8
max_seq_length = 2400
dropout = 0.1
batchsize = 16
mode = "train"


if exists("trainset_w.pkl") and exists("validset_w.pkl") and exists("testset_w.pkl"):
    print("splitted dataset found!")
    with open("trainset_w.pkl", "rb") as f:
        trainset = pickle.load(f)
    with open("validset_w.pkl", "rb") as f:
        validset = pickle.load(f)
    with open("testset_w.pkl", "rb") as f:
        testset = pickle.load(f)
else:
    print("?")

def collate_fn(batch):
    # Unpack batch into individual components
    idx, src_data, tgt_data, w = zip(*batch)
    #print(len(rates[0]), len(tgt_data[0]), len(src_data[0]))
    
    # Convert `src_data`, `tgt_data`, and `rates` to tensors if they are not already
    src_data = [torch.tensor(s, dtype=torch.float32) if not isinstance(s, torch.Tensor) else s for s in src_data]
    tgt_data = [torch.tensor(t, dtype=torch.float32) if not isinstance(t, torch.Tensor) else t for t in tgt_data]

    tgt_data = [torch.cat([torch.zeros(1, 12), t], dim=0) for t in tgt_data]

    # Pad src_data
    src_data = nn.utils.rnn.pad_sequence(src_data, batch_first=True, padding_value=0.).to(DEVICE)

    # Pad tgt_data
    tgt_data = nn.utils.rnn.pad_sequence(tgt_data, batch_first=True, padding_value=0).to(DEVICE)

    # Extract the last dimension and one-hot encode it
    return idx, src_data, tgt_data


trainset = data.DataLoader(trainset, batch_size=batchsize, collate_fn=collate_fn)
validset = data.DataLoader(validset, batch_size=1, collate_fn=collate_fn)
testset = data.DataLoader(testset, batch_size=1, collate_fn=collate_fn)


splitted dataset found!


In [7]:
lr = 1e-4
transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, 3, d_ff, max_seq_length, dropout).to(DEVICE)
optim = torch.optim.Adam(transformer.parameters(), lr=lr)
vqvae = VQVAE(12, 512, num_heads, 1, d_ff, dropout, 128, 12).to(DEVICE)
vqvae.load_state_dict(torch.load("model_best_vqvae_128.pt"))
train_main_loop(transformer, vqvae, optim, trainset, validset, lr, n_epoch, DEVICE, 20)


  0%|          | 1/200 [00:08<29:11,  8.80s/it]

epoch 0, train_loss: 2.7433, valid loss: 3.4339


  1%|          | 2/200 [00:17<28:46,  8.72s/it]

epoch 1, train_loss: 2.5741, valid loss: 3.3317


  2%|▏         | 3/200 [00:25<28:17,  8.62s/it]

epoch 2, train_loss: 2.2084, valid loss: 2.9922


  2%|▏         | 4/200 [00:34<28:00,  8.58s/it]

epoch 3, train_loss: 1.8246, valid loss: 2.6139


  2%|▎         | 5/200 [00:43<28:00,  8.62s/it]

epoch 4, train_loss: 1.6812, valid loss: 2.5380


  3%|▎         | 6/200 [00:51<27:44,  8.58s/it]

epoch 5, train_loss: 1.5437, valid loss: 2.3626


  4%|▎         | 7/200 [01:00<27:32,  8.56s/it]

epoch 6, train_loss: 1.3831, valid loss: 2.1932


  4%|▍         | 8/200 [01:08<27:30,  8.60s/it]

epoch 7, train_loss: 1.2340, valid loss: 1.9960


  4%|▍         | 9/200 [01:17<27:20,  8.59s/it]

epoch 8, train_loss: 1.0965, valid loss: 1.8918


  5%|▌         | 10/200 [01:26<27:10,  8.58s/it]

epoch 9, train_loss: 1.0914, valid loss: 1.8285


  6%|▌         | 11/200 [01:34<27:03,  8.59s/it]

epoch 10, train_loss: 1.0263, valid loss: 1.7840


  6%|▌         | 12/200 [01:43<26:54,  8.59s/it]

epoch 11, train_loss: 1.0329, valid loss: 1.7537


  6%|▋         | 13/200 [01:51<26:46,  8.59s/it]

epoch 12, train_loss: 1.0117, valid loss: 1.7240


  7%|▋         | 14/200 [02:00<26:39,  8.60s/it]

epoch 13, train_loss: 0.9887, valid loss: 1.7049


  8%|▊         | 15/200 [02:09<26:31,  8.60s/it]

epoch 14, train_loss: 0.9485, valid loss: 1.6894


  8%|▊         | 16/200 [02:17<26:26,  8.62s/it]

epoch 15, train_loss: 0.9634, valid loss: 1.6775


  8%|▊         | 17/200 [02:26<26:18,  8.63s/it]

epoch 16, train_loss: 0.9576, valid loss: 1.6661


  9%|▉         | 18/200 [02:34<26:11,  8.63s/it]

epoch 17, train_loss: 0.9456, valid loss: 1.6604


 10%|▉         | 19/200 [02:43<26:02,  8.63s/it]

epoch 18, train_loss: 0.9341, valid loss: 1.6496


 10%|█         | 20/200 [02:52<25:53,  8.63s/it]

epoch 19, train_loss: 0.9351, valid loss: 1.6385


 10%|█         | 21/200 [03:00<25:41,  8.61s/it]

epoch 20, train_loss: 0.9209, valid loss: 1.6392, min_valid_loss: 1.6385, wait: 0 / 20


 11%|█         | 22/200 [03:09<25:34,  8.62s/it]

epoch 21, train_loss: 0.9149, valid loss: 1.6268


 12%|█▏        | 23/200 [03:18<25:27,  8.63s/it]

epoch 22, train_loss: 0.9004, valid loss: 1.6213


 12%|█▏        | 24/200 [03:26<25:20,  8.64s/it]

epoch 23, train_loss: 0.8916, valid loss: 1.6172


 12%|█▎        | 25/200 [03:35<25:13,  8.65s/it]

epoch 24, train_loss: 0.8806, valid loss: 1.6112


 13%|█▎        | 26/200 [03:43<24:59,  8.62s/it]

epoch 25, train_loss: 0.8858, valid loss: 1.6121, min_valid_loss: 1.6112, wait: 0 / 20


 14%|█▎        | 27/200 [03:52<24:53,  8.63s/it]

epoch 26, train_loss: 0.8750, valid loss: 1.5995


 14%|█▍        | 28/200 [04:01<24:40,  8.61s/it]

epoch 27, train_loss: 0.8691, valid loss: 1.5998, min_valid_loss: 1.5995, wait: 0 / 20


 14%|█▍        | 29/200 [04:09<24:29,  8.59s/it]

epoch 28, train_loss: 0.8571, valid loss: 1.5997, min_valid_loss: 1.5995, wait: 1 / 20


 15%|█▌        | 30/200 [04:18<24:24,  8.61s/it]

epoch 29, train_loss: 0.8459, valid loss: 1.5962


 16%|█▌        | 31/200 [04:27<24:16,  8.62s/it]

epoch 30, train_loss: 0.8578, valid loss: 1.5909


 16%|█▌        | 32/200 [04:35<24:10,  8.63s/it]

epoch 31, train_loss: 0.8261, valid loss: 1.5904


 16%|█▋        | 33/200 [04:44<23:57,  8.61s/it]

epoch 32, train_loss: 0.8005, valid loss: 1.5955, min_valid_loss: 1.5904, wait: 0 / 20


 17%|█▋        | 34/200 [04:52<23:46,  8.59s/it]

epoch 33, train_loss: 0.8209, valid loss: 1.6047, min_valid_loss: 1.5904, wait: 1 / 20


 18%|█▊        | 35/200 [05:01<23:36,  8.58s/it]

epoch 34, train_loss: 0.8192, valid loss: 1.6046, min_valid_loss: 1.5904, wait: 2 / 20


 18%|█▊        | 36/200 [05:09<23:26,  8.58s/it]

epoch 35, train_loss: 0.8192, valid loss: 1.5925, min_valid_loss: 1.5904, wait: 3 / 20


 18%|█▊        | 37/200 [05:18<23:21,  8.60s/it]

epoch 36, train_loss: 0.8235, valid loss: 1.5869


 19%|█▉        | 38/200 [05:27<23:10,  8.58s/it]

epoch 37, train_loss: 0.8168, valid loss: 1.6000, min_valid_loss: 1.5869, wait: 0 / 20


 20%|█▉        | 39/200 [05:35<23:00,  8.58s/it]

epoch 38, train_loss: 0.8006, valid loss: 1.6185, min_valid_loss: 1.5869, wait: 1 / 20


 20%|██        | 40/200 [05:44<22:51,  8.57s/it]

epoch 39, train_loss: 0.7942, valid loss: 1.6259, min_valid_loss: 1.5869, wait: 2 / 20


 20%|██        | 41/200 [05:52<22:43,  8.57s/it]

epoch 40, train_loss: 0.7988, valid loss: 1.6140, min_valid_loss: 1.5869, wait: 3 / 20


 21%|██        | 42/200 [06:01<22:34,  8.57s/it]

epoch 41, train_loss: 0.7832, valid loss: 1.6147, min_valid_loss: 1.5869, wait: 4 / 20


 22%|██▏       | 43/200 [06:09<22:25,  8.57s/it]

epoch 42, train_loss: 0.7985, valid loss: 1.6306, min_valid_loss: 1.5869, wait: 5 / 20


 22%|██▏       | 44/200 [06:18<22:16,  8.57s/it]

epoch 43, train_loss: 0.7705, valid loss: 1.6282, min_valid_loss: 1.5869, wait: 6 / 20


 22%|██▎       | 45/200 [06:27<22:07,  8.57s/it]

epoch 44, train_loss: 0.7760, valid loss: 1.6192, min_valid_loss: 1.5869, wait: 7 / 20


 23%|██▎       | 46/200 [06:35<22:00,  8.57s/it]

epoch 45, train_loss: 0.7570, valid loss: 1.6376, min_valid_loss: 1.5869, wait: 8 / 20


 24%|██▎       | 47/200 [06:44<21:51,  8.57s/it]

epoch 46, train_loss: 0.7523, valid loss: 1.6511, min_valid_loss: 1.5869, wait: 9 / 20


 24%|██▍       | 48/200 [06:52<21:43,  8.57s/it]

epoch 47, train_loss: 0.7562, valid loss: 1.6582, min_valid_loss: 1.5869, wait: 10 / 20


 24%|██▍       | 49/200 [07:01<21:34,  8.57s/it]

epoch 48, train_loss: 0.7478, valid loss: 1.6467, min_valid_loss: 1.5869, wait: 11 / 20


 25%|██▌       | 50/200 [07:10<21:32,  8.62s/it]

epoch 49, train_loss: 0.7477, valid loss: 1.6485, min_valid_loss: 1.5869, wait: 12 / 20


 26%|██▌       | 51/200 [07:18<21:20,  8.60s/it]

epoch 50, train_loss: 0.7217, valid loss: 1.6360, min_valid_loss: 1.5869, wait: 13 / 20


 26%|██▌       | 52/200 [07:27<21:10,  8.58s/it]

epoch 51, train_loss: 0.7084, valid loss: 1.6399, min_valid_loss: 1.5869, wait: 14 / 20


 26%|██▋       | 53/200 [07:35<21:00,  8.58s/it]

epoch 52, train_loss: 0.7216, valid loss: 1.6456, min_valid_loss: 1.5869, wait: 15 / 20


 27%|██▋       | 54/200 [07:44<20:51,  8.57s/it]

epoch 53, train_loss: 0.7226, valid loss: 1.6604, min_valid_loss: 1.5869, wait: 16 / 20


 28%|██▊       | 55/200 [07:52<20:41,  8.56s/it]

epoch 54, train_loss: 0.7140, valid loss: 1.6735, min_valid_loss: 1.5869, wait: 17 / 20


 28%|██▊       | 56/200 [08:01<20:32,  8.56s/it]

epoch 55, train_loss: 0.7185, valid loss: 1.6504, min_valid_loss: 1.5869, wait: 18 / 20


 28%|██▊       | 56/200 [08:10<21:00,  8.75s/it]

epoch 56, train_loss: 0.7161, valid loss: 1.6682, min_valid_loss: 1.5869, wait: 19 / 20





In [9]:
eval_main_loop(transformer, vqvae,"model_best_autoreg.pt", testset, DEVICE, 2, 0.5)

  0%|          | 0/100 [00:00<?, ?it/s]

 11%|█         | 11/100 [00:29<03:57,  2.67s/it]
