# Импорт библиотек

In [None]:
import json
import os
import sys
from collections import OrderedDict
from datetime import datetime
from typing import Sequence, overload, Union, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

In [None]:
import functools
import itertools
import os
import sys
from fractions import Fraction

from music21 import chord, duration, instrument, note, stream
from fractions import Fraction

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Directory with the processed per-piece files read by `load`; the train/val
# split index lives next to it as DATA_DIR + '.csv'.
DATA_DIR = 'dataset/generation_music_dataset/processed_midi'
# Per-channel vocabularies are saved/loaded here as '<channel>.vocab' (JSON).
VOCAB_DIR = 'vocabs'
# Encoded tensors 'train.pt' and 'validation.pt' are saved/loaded here.
DATASET_DIR = 'dataset'
# Model/optimizer checkpoints: 'generator-checkpoint-<epoch>.model/.optimizer'.
CHECKPOINT_DIR = 'checkpoints/VAE'
# Generated MIDI files are written here.
GENERATED_DIR = 'generated/VAE'

# NOTE(review): not referenced by any cell visible here — confirm it is still needed.
PROCESS_MIDI_FILES = False

# When True, regenerate the train/validation split index.
SPLIT_DATASET = False
# When True, rebuild vocabularies and tensors from DATA_DIR; otherwise load them.
BUILD_DATASET = False

# Token channels: column names of the processed files and the order of the
# last axis of every encoded tensor.
CHANNEL_NAMES = ('note', 'chord', 'duration', 'shift')

# Pieces are cut into windows of this many events.
MAX_SEQ_LEN = 128
BATCH_SIZE = 256

# Linear learning-rate scaling relative to a reference batch size of 512.
BASE_LEARNING_RATE = 1e-5
LEARNING_RATE = BASE_LEARNING_RATE * BATCH_SIZE / 512

In [None]:
# Epoch to resume from: when > 0 the checkpoint files
# 'generator-checkpoint-{START_EPOCH}.model/.optimizer' are loaded from CHECKPOINT_DIR.
START_EPOCH = 107
# Number of epochs to run in this session (after START_EPOCH).
EPOCH_COUNT = 200

# Объявление базовых методов

In [None]:
def split_dataset(path: str, val_prop=0.1, *, seed: int = 42):
    """Write a train/validation split index for the processed files in *path*.

    The index is saved as ``path + '.csv'`` with columns ``filename`` and
    ``split`` (``'train'`` or ``'val'``), the format ``load_processed_data``
    reads.

    Args:
        path: Directory containing the processed files (names ending in 'abc').
        val_prop: Fraction of files assigned to the validation split.
        seed: Seed of the random validation selection.
    """
    # os.listdir order is filesystem-dependent; sort it so the seeded RNG
    # selects the same validation files on every machine.
    filenames = sorted(x for x in os.listdir(path) if x.endswith('abc'))
    df = pd.DataFrame(filenames, columns=['filename'])
    df['split'] = 'train'

    n_val = int(val_prop * len(df))
    val_rows = np.random.default_rng(seed).choice(len(df), n_val, replace=False)
    df.loc[val_rows, 'split'] = 'val'
    df.to_csv(path + '.csv', index=False)

In [None]:
def load(path: str, max_seq_len, drop_last, filename: str):
    """Read one processed piece and optionally cut it into fixed-size windows.

    ``path/filename`` is parsed as a headerless CSV whose columns are
    ``CHANNEL_NAMES``; cells are read as strings.  This is the worker
    function of ``load_processed_data``'s process pool.

    Returns:
        A list of 2-D arrays (rows = events, columns = channels).  With a
        negative *max_seq_len* the whole piece is one array.  Otherwise the
        piece is split into consecutive windows of *max_seq_len* rows, and a
        trailing shorter window is dropped when *drop_last* is true.
    """
    frame = pd.read_csv(os.path.join(path, filename), names=CHANNEL_NAMES, dtype=str)
    if max_seq_len < 0:
        return [frame.values]

    total = len(frame)
    windows = []
    for begin in range(0, total, max_seq_len):
        stop = min(begin + max_seq_len, total)
        if drop_last and stop - begin < max_seq_len:
            break
        windows.append(frame.iloc[begin:stop].values)
    return windows

def load_processed_data(path: str, *, split='train', max_seq_len: int = -1, drop_last: bool = True,
                        processes: Optional[int] = None) -> list:
    """Load every file of one split listed in ``path + '.csv'``.

    Args:
        path: Directory with the processed files; its index (written by
            ``split_dataset``) is ``path + '.csv'``.
        split: Value of the index's ``split`` column to select.
        max_seq_len: Window length passed to ``load``; negative keeps whole pieces.
        drop_last: Whether ``load`` drops a trailing short window.
        processes: Number of worker processes; defaults to ``cpu_count()``.

    Returns:
        A flat list of windows (2-D arrays), ordered as in the index.
    """
    from multiprocessing import Pool, cpu_count

    index = pd.read_csv(path + '.csv')
    filenames = index.loc[index['split'] == split, 'filename'].values

    if processes is None:
        # CSV parsing is CPU-bound: more workers than cores only adds overhead.
        processes = cpu_count()

    worker = functools.partial(load, path, max_seq_len, drop_last)
    data = []
    with Pool(processes=processes) as pool:
        # imap (not imap_unordered) keeps results in index order, so the
        # encoded dataset built from them is identical from run to run.
        for slices in tqdm(pool.imap(worker, filenames, chunksize=16), total=len(filenames)):
            data.extend(slices)

    return data

In [None]:
def make_midi(x: torch.Tensor, vocab_map: OrderedDict):
    """Decode one token sequence into a music21 stream.

    Args:
        x: Integer tensor with one row per event; row entries are ids of the
            note, chord, duration and shift channels, in *vocab_map* order.
        vocab_map: Ordered mapping channel name -> Vocab, used to turn ids
            back into string tokens.

    Returns:
        A ``music21.stream.Stream`` of notes, chords and rests.
    """
    midi_stream = stream.Stream()

    tokens = x.detach().cpu().numpy()
    vocabs = list(vocab_map.values())
    next_offset = None
    for line in tokens:
        p, c, d, s = (vocab[int(token)] for vocab, token in zip(vocabs, line))

        # Skip padding/start rows and rows containing an unknown token.  The
        # former test ``np.any(line <= 1)`` also matched note id 1, which is
        # '<rest>' in the note vocabulary, so every rest was dropped and the
        # Rest branch below could never run.
        if any(token in ('<start>', '<unk>') for token in (p, c, d, s)):
            continue

        if p == '<rest>':
            element = note.Rest()
        else:
            if c == '0':
                element = note.Note(p)
            else:
                # Chord tokens are '-'-separated semitone offsets from the root pitch.
                root = note.pitch.Pitch(p).midi
                element = chord.Chord([note.Note(root + int(step)) for step in c.split('-')])

            element.storedInstrument = instrument.Piano()

        element.duration = duration.Duration(Fraction(d))

        midi_stream.append(element)
        if next_offset is not None:
            element.offset = next_offset
            next_offset = None

        # Any shift token other than 'f' places the next element at this
        # element's offset plus the shift.
        if s != 'f':
            next_offset = Fraction(element.offset) + Fraction(s)

    return midi_stream

In [None]:
def get_frequency_elements(elements: list, threshold: int = 50) -> list:
    """Return the distinct *elements* that occur at least *threshold* times.

    The result is sorted by descending count; equal counts are ordered by the
    elements' natural ascending order.
    """
    counts = dict()
    for element in elements:
        counts[element] = counts.get(element, 0) + 1

    frequent = [(element, count) for element, count in counts.items() if count >= threshold]
    frequent.sort(key=lambda pair: (-pair[1], pair[0]))
    return [element for element, _ in frequent]

In [None]:
class Vocab:
    """Bidirectional mapping between tokens and integer ids.

    The id of a token is its position in *elements*.  Looking up an unknown
    token returns the id of *unk*.  ``num_special`` records how many leading
    ids are reserved tokens; by default that is every id up to and
    including *unk*.
    """

    def __init__(self, elements: Sequence[str], *, unk: str = '<unk>', num_special: Optional[int] = None):
        self.int_to_element = tuple(elements)
        self.element_to_int = dict(zip(self.int_to_element, range(len(self.int_to_element))))
        self.unk_id = self.element_to_int[unk]
        if num_special is None:
            num_special = self.unk_id + 1

        self.num_special = num_special

    @overload
    def __getitem__(self, key: int) -> str:
        ...

    @overload
    def __getitem__(self, key: str) -> int:
        ...

    def __getitem__(self, key):
        # NumPy integers are not ``int`` instances.  Without accepting them
        # here, an id such as np.int64(5) was treated as an unknown token and
        # silently mapped to unk_id.
        if isinstance(key, (int, np.integer)):
            return self.int_to_element[key]
        return self.element_to_int.get(key, self.unk_id)

    @classmethod
    def load(cls, path):
        """Create a vocabulary from the JSON file written by ``save``."""
        with open(path, 'r', encoding='utf-8') as file:
            data = json.load(file)

        return cls(data['elements'], unk=data['unk'], num_special=data['num_special'])

    def save(self, path) -> None:
        """Write the elements, the unknown token and num_special to *path* as JSON."""
        data = dict(
            elements=self.int_to_element,
            unk=self[self.unk_id],
            num_special=self.num_special,
        )
        # The with-block closes the file; no explicit close() is needed.
        with open(path, 'w', encoding='utf-8') as file:
            json.dump(data, file)

    def __len__(self):
        # Number of ids, i.e. the size an embedding table needs.  Counted from
        # the id list so duplicate elements cannot make it smaller than the
        # largest id + 1.
        return len(self.int_to_element)

In [None]:
def build_frequency_dictionary(elements):
    """Build a Vocab of the frequent *elements*, preceded by '<start>' and '<unk>'.

    Ids 0 and 1 are the two special tokens.  The remaining ids follow the
    order returned by ``get_frequency_elements`` (most frequent first).
    """
    specials = ['<start>', '<unk>']
    return Vocab(specials + get_frequency_elements(elements))

In [None]:
def build_note_dictionary(elements):
    """Build the note Vocab: '<start>', '<rest>' and a contiguous pitch range.

    The pitch range spans every MIDI number between the lowest and the
    highest pitch that passes ``get_frequency_elements``, so rarer pitches
    inside that range still get an id.  '<rest>' also serves as the unknown
    token.

    Raises:
        ValueError: if no pitch passes the frequency threshold.
    """
    from music21 import note

    midis = [note.Pitch(name).midi for name in get_frequency_elements(elements) if name != '<rest>']
    if not midis:
        # Previously this surfaced as a TypeError from range(None, ...).
        raise ValueError('no pitch occurs often enough to build the note vocabulary')

    element_names = ['<start>', '<rest>']
    element_names.extend(note.Pitch(midi).nameWithOctave for midi in range(min(midis), max(midis) + 1))

    return Vocab(element_names, unk='<rest>')

In [None]:
def get_vocab_builder(channel_name: str):
    """Return the vocabulary factory for *channel_name*.

    'note' gets the contiguous pitch vocabulary; every other channel gets a
    frequency-ranked vocabulary.
    """
    return build_note_dictionary if channel_name == 'note' else build_frequency_dictionary

In [None]:
def build_dataset(data, vocab_map):
    """Encode windows of string tokens into one integer tensor.

    Args:
        data: Sequence of windows; each window is a sequence of rows, and
            column ``i`` of a row is encoded by the i-th vocabulary.
        vocab_map: Ordered mapping channel -> vocabulary (token -> id).

    Returns:
        An int64 tensor of shape ``(len(data), longest_window, n_channels)``.
        Shorter windows are right-padded with id 0.
    """
    channel_tensors = []
    for column, vocab in enumerate(vocab_map.values()):
        encoded = []
        for window in data:
            ids = np.asarray([vocab[row[column]] for row in window], dtype=np.int64)
            encoded.append(torch.from_numpy(ids))
        channel_tensors.append(torch.nn.utils.rnn.pad_sequence(encoded, batch_first=True))

    return torch.stack(channel_tensors, dim=-1)

# Подготовка данных

In [None]:
# Regenerate the train/validation index (DATA_DIR + '.csv') only when requested.
if SPLIT_DATASET:
    split_dataset(DATA_DIR)

In [None]:
# Build (or load) one vocabulary per channel and the encoded train/validation
# tensors.  vocab_map follows CHANNEL_NAMES order, which is also the order of
# the tensors' last axis.
vocab_map = OrderedDict()

if BUILD_DATASET:
    # The output directories may not exist yet on a fresh checkout.
    os.makedirs(VOCAB_DIR, exist_ok=True)
    os.makedirs(DATASET_DIR, exist_ok=True)

    print("Loading training dataset")
    data = load_processed_data(DATA_DIR, max_seq_len=MAX_SEQ_LEN, drop_last=True)

    print(f'Dataset size {len(data)}')

    for idx, name in enumerate(CHANNEL_NAMES):
        print(f'Building {name} vocab')
        # Vocabularies are built from the training split only.
        vocab = get_vocab_builder(name)([row[idx] for window in data for row in window])
        vocab.save(os.path.join(VOCAB_DIR, f'{name}.vocab'))
        vocab_map[name] = vocab

    print("Building training dataset")
    X_train = build_dataset(data, vocab_map)
    del data  # the raw string windows are no longer needed

    print("Saving training dataset")
    torch.save(X_train, os.path.join(DATASET_DIR, 'train.pt'))

    print("Building validation dataset")
    X_val = build_dataset(load_processed_data(DATA_DIR, split='val', max_seq_len=MAX_SEQ_LEN), vocab_map)

    print("Saving validation dataset")
    torch.save(X_val, os.path.join(DATASET_DIR, 'validation.pt'))
else:
    for name in CHANNEL_NAMES:
        print(f'Loading {name} vocab')
        vocab_map[name] = Vocab.load(os.path.join(VOCAB_DIR, f'{name}.vocab'))

    X_train = torch.load(os.path.join(DATASET_DIR, 'train.pt'))
    X_val = torch.load(os.path.join(DATASET_DIR, 'validation.pt'))

Loading note vocab
Loading chord vocab
Loading duration vocab
Loading shift vocab


In [None]:
class DataSet(torch.utils.data.Dataset):
    """Map-style dataset over the first axis of a pre-encoded tensor (used for validation)."""

    def __init__(self, X):
        super().__init__()
        self.X = X

    def __getitem__(self, idx):
        return self.X[idx]

    def __len__(self):
        return len(self.X)


In [None]:
class TrainDataSetIterator:
    """One training epoch over *X* in random order, with random transposition.

    Every sample is visited exactly once, in a random permutation.  Roughly
    half of the samples get their note ids (channel 0, last axis) shifted by
    an integer offset drawn from N(0, 12) and truncated toward zero.  The
    offset is folded back so shifted ids stay within ``[min_note, max_note)``.

    Args:
        X: Indexable collection of samples; ``X[i]`` is a tensor whose last
            axis holds the channels, with note ids in channel 0.
        min_note: Smallest id that is a real pitch.  Smaller ids (special
            tokens) are never shifted.
        max_note: One past the largest pitch id.
        seed: Optional seed for a reproducible order and transpositions.
    """

    def __init__(self, X, min_note, max_note, seed: Optional[int] = None):
        super().__init__()

        self.X = X
        self.min_note = min_note
        self.max_note = max_note

        self._size = len(self.X)
        self._cursor = 0

        rng = np.random.default_rng(seed)
        # A permutation rather than rng.integers(n, size=n): the latter samples
        # with replacement, so an epoch saw only about 63% of distinct samples.
        self._indices = rng.permutation(self._size)

        transposed = rng.random(self._size) > 0.5
        t_off = np.zeros(self._size, dtype=np.int32)
        t_off[transposed] = rng.normal(size=int(transposed.sum()), scale=12).astype(np.int32)
        self._t_off = t_off

    def __iter__(self):
        return self

    def __next__(self):
        cursor = self._cursor
        if cursor >= self._size:
            raise StopIteration()

        self._cursor = cursor + 1

        x = self.X[self._indices[cursor]]
        t_off = int(self._t_off[cursor])
        if t_off != 0:
            x = self.transpose(x, t_off)

        return x

    def transpose(self, x: torch.Tensor, offset):
        """Return a copy of *x* with every pitch id (``>= min_note``) shifted by *offset*.

        If the shift would move the lowest or highest pitch outside
        ``[min_note, max_note)``, it is folded back into the allowed range.
        Samples with no pitch at all are returned unchanged.
        """
        notes = x[..., 0]
        mask = notes >= self.min_note
        if not bool(mask.any()):
            # Nothing to transpose; notes[mask].min() would raise on an empty tensor.
            return x

        max_n = notes.max().item()
        min_n = notes[mask].min().item()

        # Offsets that keep all pitches in range form the interval [min_t, max_t).
        min_t = self.min_note - min_n
        max_t = self.max_note - max_n
        span = max_t - min_t

        offset = int(offset)
        if offset >= max_t:
            offset = min_t + (offset - min_t) % span
        elif offset < min_t:
            offset = max_t - 1 - (max_t - offset) % span

        x = x.clone()  # never mutate the shared dataset tensor
        x[..., 0][mask] += offset
        return x

In [None]:
class TrainDataSet(torch.utils.data.IterableDataset):
    """Iterable training dataset; every ``iter()`` starts a new randomised epoch."""

    def __init__(self, X, min_note, max_note):
        super().__init__()
        self.X, self.min_note, self.max_note = X, min_note, max_note

    def __getitem__(self, index):
        # Iterable-style dataset: random access is deliberately unsupported.
        raise NotImplementedError()

    def __iter__(self):
        return TrainDataSetIterator(self.X, self.min_note, self.max_note)

    def __len__(self):
        return len(self.X)

In [None]:
def collate_seq(batch):
    """Stack variable-length sequence tensors along a new first axis, right-padding with 0."""
    return torch.nn.utils.rnn.pad_sequence(batch, padding_value=0, batch_first=True)


def collate(batch):
    """DataLoader ``collate_fn``: each batch item is a single sequence tensor."""
    padded = collate_seq(batch)
    return padded

# Модели

In [None]:
class Encoder(torch.nn.Module):
    """Convolutional encoder mapping multi-channel token ids to a sampled latent vector.

    Each channel's ids are embedded separately and the embeddings are
    concatenated.  A stack of ``Conv1d(kernel=4, stride=2, padding=1)``
    layers then halves the sequence length per layer.  Two linear heads
    produce ``mu`` and ``log_var``, and ``forward`` returns the
    reparameterised sample ``mu + eps * exp(0.5 * log_var)``.

    NOTE(review): only the sample is returned (mu/log_var are not exposed),
    so no KL-divergence term can be computed from this module's output.

    Args:
        channel_names: Channel keys; ``x[..., i]`` holds the ids of channel ``i``.
        embedding_sizes: Vocabulary size of each channel.
        embedding_dims: Embedding width of each channel.
        layers_filters: Output channels of each conv layer.
        latent_dim: Size of the latent vector.
        seq_len: Input sequence length the linear heads are sized for.
        dropout: Dropout probability applied before every conv layer.

    Raises:
        ValueError: if *seq_len* is shorter than ``2 ** len(layers_filters)``
            (the conv stack would reduce it to zero length).
    """

    def __init__(self,
                 channel_names: Sequence[str],
                 embedding_sizes: Sequence[int],
                 embedding_dims: Sequence[int],
                 layers_filters: Sequence[int],
                 latent_dim: int = 128,
                 seq_len: int = 128,
                 dropout: float = 0.1):
        super().__init__()

        reduced_len = seq_len // (2 ** len(layers_filters))
        if reduced_len < 1:
            raise ValueError(
                f'seq_len={seq_len} is too short for {len(layers_filters)} stride-2 conv layers')

        self.seq_len = seq_len
        self.channel_names = channel_names

        # Module names and registration order are kept as before so existing
        # checkpoints (and optimizer states) still load.
        self.embeddings = nn.ModuleDict([
            (channel, nn.Embedding(num_embeddings=embedding_sizes[idx], embedding_dim=embedding_dims[idx]))
            for idx, channel in enumerate(channel_names)
        ])

        in_channels = sum(embedding_dims)
        modules = []
        for filters in layers_filters:
            modules.append(nn.Sequential(
                nn.Dropout(dropout),
                nn.Conv1d(in_channels, out_channels=filters, kernel_size=4, stride=2, padding=1),
                nn.ELU(),
            ))
            in_channels = filters

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(in_channels * reduced_len, latent_dim)
        self.fc_var = nn.Linear(in_channels * reduced_len, latent_dim)

    def forward(self, x: torch.Tensor):
        """Encode ids of shape ``(batch, seq_len, n_channels)`` into a ``(batch, latent_dim)`` sample."""
        emb = [self.embeddings[channel](x[..., idx]) for idx, channel in enumerate(self.channel_names)]

        # (batch, seq, emb) -> (batch, emb, seq): Conv1d expects channels first.
        h = torch.cat(emb, dim=-1).permute((0, 2, 1))
        h = torch.flatten(self.encoder(h), start_dim=1)

        mu = self.fc_mu(h)
        log_var = self.fc_var(h)

        # Reparameterisation trick: random sample that is differentiable w.r.t. mu and log_var.
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu


In [None]:
class Decoder(torch.nn.Module):
    """Transposed-convolution decoder from latent vectors to per-channel log-probabilities.

    The latent vector is projected and reshaped to
    ``(batch, layers_filters[-1], seq_len // 2**len(layers_filters))``.  Each
    ``ConvTranspose1d(kernel=4, stride=2, padding=1)`` doubles the length,
    giving ``seq_len`` at the end.  The final layer emits
    ``sum(embedding_sizes)`` channels, which are split per vocabulary and
    log-softmax-normalised.

    Args:
        embedding_sizes: Vocabulary size of each output channel.
        layers_filters: Encoder filter counts; used in reverse order here.
        latent_dim: Size of the input latent vector.
        seq_len: Length of the produced sequences.
        dropout: Dropout probability applied before every transposed conv.

    Raises:
        ValueError: if *seq_len* is not a positive multiple of
            ``2 ** len(layers_filters)`` (the output length would differ from it).
    """

    def __init__(self,
                 embedding_sizes: Sequence[int],
                 layers_filters: Sequence[int],
                 latent_dim: int = 128,
                 seq_len: int = 128,
                 dropout: float = 0.1):
        super().__init__()

        scale = 2 ** len(layers_filters)
        if seq_len < scale or seq_len % scale:
            raise ValueError(
                f'seq_len={seq_len} must be a positive multiple of {scale} '
                f'for {len(layers_filters)} stride-2 layers')

        self.embedding_sizes = embedding_sizes
        channels_num = sum(embedding_sizes)

        unflatten_len = seq_len // scale

        # Module names and registration order match the previous version so
        # existing checkpoints still load.
        self.decoder_input = nn.Linear(latent_dim, unflatten_len * layers_filters[-1])
        self.unflatten = nn.Unflatten(1, (layers_filters[-1], unflatten_len))

        in_channels = layers_filters[-1]
        modules = []
        for filters in reversed(layers_filters[:-1]):
            modules.append(nn.Sequential(
                nn.Dropout(dropout),
                nn.ConvTranspose1d(in_channels, filters, kernel_size=4, stride=2, padding=1),
                nn.ELU(),
            ))
            in_channels = filters

        # Output layer: no activation, raw logits for every vocabulary.
        modules.append(nn.Sequential(
            nn.Dropout(dropout),
            nn.ConvTranspose1d(in_channels, channels_num, kernel_size=4, stride=2, padding=1),
        ))

        self.decoder = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor, rng: Optional[torch.Generator] = None):
        """Decode latents of shape ``(batch, latent_dim)``.

        Returns:
            A list with one tensor per vocabulary, each of shape
            ``(batch, vocab_size, seq_len)`` holding log-probabilities over dim 1.
            *rng* is accepted for interface compatibility and is unused.
        """
        h = self.decoder(self.unflatten(self.decoder_input(x)))
        return [nn.functional.log_softmax(logits, dim=1)
                for logits in torch.split(h, self.embedding_sizes, dim=1)]

In [None]:
class Generator(torch.nn.Module):
    """Encoder/decoder model over multi-channel token sequences.

    Args:
        channel_names: Channel keys, in the order of the input's last axis.
        embedding_sizes: Vocabulary size per channel (encoder inputs and decoder outputs).
        embedding_dims: Embedding width per channel (encoder only).
        filters: Conv filter counts; the decoder uses them in reverse.
        latent_dim: Size of the latent vector.
        seq_len: Sequence length both halves are sized for.
        dropout: Dropout probability for both halves.
    """

    def __init__(self,
                 *,
                 channel_names: Sequence[str],
                 embedding_sizes: Sequence[int],
                 embedding_dims: Sequence[int],
                 filters: Sequence[int],
                 latent_dim: int = 128,
                 seq_len: int = 128,
                 dropout=0.1):
        super().__init__()

        self.channel_names = channel_names
        self.latent_dim = latent_dim

        self.encoder = Encoder(channel_names, embedding_sizes, embedding_dims, filters, latent_dim, seq_len, dropout)
        self.decoder = Decoder(embedding_sizes, filters, latent_dim, seq_len, dropout)

    def forward(self, x: torch.Tensor) -> list:
        """Reconstruct *x*; returns the decoder's list of per-channel log-probabilities."""
        return self.decoder(self.encoder(x))

    @staticmethod
    def _channels_to_tokens(channels) -> torch.Tensor:
        """Greedy-decode per-channel ``(batch, vocab, seq)`` scores into ids of shape ``(batch, seq, n_channels)``."""
        return torch.stack([channel.argmax(dim=1) for channel in channels], dim=-1)

    def make_variations(self, x: torch.Tensor) -> torch.Tensor:
        """Encode and decode *x*, returning the argmax token ids."""
        return self._channels_to_tokens(self.forward(x))

    def predict(self, batch_size, rng: Optional[torch.Generator] = None) -> torch.Tensor:
        """Decode *batch_size* latent vectors drawn from N(0, I) into token ids.

        Args:
            batch_size: Number of sequences to generate.
            rng: Optional ``torch.Generator`` (on the model's device) for
                reproducible sampling.
        """
        device = next(self.parameters()).device

        # The encoder's noise is eps ~ N(0, 1), so sample latents from a
        # standard normal.  torch.rand (uniform on [0, 1)) only reached the
        # all-positive corner of the latent space.
        noise = torch.randn(batch_size, self.latent_dim, device=device, generator=rng)

        return self._channels_to_tokens(self.decoder(noise))

# Объявление методов обучения

In [None]:
def make_midi(x: torch.Tensor, vocab_map: OrderedDict):
    """Decode one token sequence into a music21 stream.

    Args:
        x: Integer tensor with one row per event; row entries are ids of the
            note, chord, duration and shift channels, in *vocab_map* order.
        vocab_map: Ordered mapping channel name -> Vocab, used to turn ids
            back into string tokens.

    Returns:
        A ``music21.stream.Stream`` of notes, chords and rests.
    """
    from music21 import chord, duration, instrument, note, stream
    from fractions import Fraction

    midi_stream = stream.Stream()

    tokens = x.detach().cpu().numpy()
    vocabs = list(vocab_map.values())
    next_offset = None
    for line in tokens:
        p, c, d, s = (vocab[int(token)] for vocab, token in zip(vocabs, line))

        # Skip padding/start rows and rows containing an unknown token.  The
        # former test ``np.any(line <= 1)`` also matched note id 1, which is
        # '<rest>' in the note vocabulary, so every rest was dropped and the
        # Rest branch below could never run.
        if any(token in ('<start>', '<unk>') for token in (p, c, d, s)):
            continue

        if p == '<rest>':
            element = note.Rest()
        else:
            if c == '0':
                element = note.Note(p)
            else:
                # Chord tokens are '-'-separated semitone offsets from the root pitch.
                root = note.pitch.Pitch(p).midi
                element = chord.Chord([note.Note(root + int(step)) for step in c.split('-')])

            element.storedInstrument = instrument.Piano()

        element.duration = duration.Duration(Fraction(d))

        midi_stream.append(element)
        if next_offset is not None:
            element.offset = next_offset
            next_offset = None

        # Any shift token other than 'f' places the next element at this
        # element's offset plus the shift.
        if s != 'f':
            next_offset = Fraction(element.offset) + Fraction(s)

    return midi_stream


In [None]:
def _prune_checkpoints(keep_epoch: int) -> None:
    """Delete checkpoints in CHECKPOINT_DIR whose epoch number is not a multiple of 5.

    The checkpoint numbered *keep_epoch* is always kept, so the newest state
    survives even when its number is not a multiple of 5.  Files whose names
    do not end in ``-<int>.<ext>`` are left alone.
    """
    for name in os.listdir(CHECKPOINT_DIR):
        try:
            checkpoint_epoch = int(name.split('-')[-1].split('.')[0])
        except ValueError:
            continue  # not a checkpoint file
        if checkpoint_epoch % 5 != 0 and checkpoint_epoch != keep_epoch:
            os.remove(os.path.join(CHECKPOINT_DIR, name))


def train(*,
          start_epoch: int,
          epoch_count: int,
          device: torch.device,
          generator: Generator,
          generator_optimizer,
          train_data_loader: TrainDataSet,
          val_data_loader: DataSet,
          summary_writer: SummaryWriter):
    """Train *generator* for *epoch_count* epochs, starting after *start_epoch*.

    Each epoch minimises the sum of per-channel NLL losses, logs batch and
    epoch losses to *summary_writer*, runs ``validate`` and saves
    model/optimizer checkpoints as ``generator-checkpoint-{epoch + 1}.*`` in
    CHECKPOINT_DIR.  Every 10th epoch and after the last one, checkpoints
    whose number is not a multiple of 5 are pruned; the one just saved is kept.

    Raises:
        ValueError: if *train_data_loader* yields no batches.
    """
    use_cuda = torch.device(device).type == 'cuda'

    def save(e):
        if generator_optimizer is not None:
            # Make sure the directory exists before the first save.
            os.makedirs(CHECKPOINT_DIR, exist_ok=True)
            torch.save(generator.state_dict(),
                       os.path.join(CHECKPOINT_DIR, f'generator-checkpoint-{e + 1}.model'))
            torch.save(generator_optimizer.state_dict(),
                       os.path.join(CHECKPOINT_DIR, f'generator-checkpoint-{e + 1}.optimizer'))

    batch_index = 0
    last_epoch = start_epoch + epoch_count

    for epoch in range(start_epoch, last_epoch):
        if use_cuda:
            # The peak-memory API is CUDA-only; calling it on a CPU run raises.
            torch.cuda.reset_peak_memory_stats(device)

        metrics = {channel: 0. for channel in CHANNEL_NAMES}
        epoch_loss = 0.
        train_samples = 0

        generator.train()

        for y_batch in tqdm(train_data_loader):
            y_batch = y_batch.to(device)
            batch_size = y_batch.shape[0]

            generator_optimizer.zero_grad()
            pred = generator(y_batch)

            # Total loss = sum of the per-channel negative log-likelihoods.
            loss = None
            for idx, channel in enumerate(CHANNEL_NAMES):
                channel_loss = F.nll_loss(pred[idx], y_batch[..., idx])
                loss = channel_loss if loss is None else loss + channel_loss

                metrics[channel] += channel_loss.item() * batch_size
                summary_writer.add_scalar(f'batch_loss/{channel}', channel_loss.item(), batch_index, new_style=True)

            summary_writer.add_scalar('batch_loss/total', loss.item(), batch_index, new_style=True)
            epoch_loss += loss.item() * batch_size

            loss.backward()
            torch.nn.utils.clip_grad_norm_(generator.parameters(), 2.0)
            generator_optimizer.step()

            train_samples += batch_size
            batch_index += 1

        if train_samples == 0:
            raise ValueError('training data loader yielded no batches')

        for channel, total in metrics.items():
            summary_writer.add_scalar(f'epoch_loss/{channel}', total / train_samples, epoch, new_style=True)

        epoch_loss /= train_samples
        summary_writer.add_scalar('epoch_loss/total', epoch_loss, epoch, new_style=True)

        val_loss, val_acc = validate(epoch, device, generator, val_data_loader, summary_writer)

        if use_cuda:
            peak_mb = torch.cuda.memory_stats(device)["allocated_bytes.all.peak"] // (1024 * 1024)
            print(f'cuda memory {peak_mb} MB')
        print(f'Epoch {epoch + 1}; loss {epoch_loss}; val_loss {val_loss}; val_acc {val_acc}')

        print('save checkpoint')
        save(epoch)

        if (epoch + 1) % 10 == 0 or epoch + 1 == last_epoch:
            print('clean up unnecessary checkpoints')
            os.makedirs(CHECKPOINT_DIR, exist_ok=True)
            _prune_checkpoints(keep_epoch=epoch + 1)

In [None]:
@torch.inference_mode()
def validate(epoch: int,
             device: torch.device,
             generator: Generator,
             data_loader: torch.utils.data.DataLoader,
             summary_writer: SummaryWriter):
    """Evaluate *generator* on *data_loader* and log the metrics for *epoch*.

    Logs per-channel and total negative log-likelihood (``val_loss/*``) and
    per-channel and joint token accuracy (``val_acc/*``).  The total loss is
    the sum of the per-channel losses.  The joint accuracy counts a position
    as correct only when every channel matches.  All metrics are averaged
    over samples.

    Returns:
        ``(total_loss, joint_accuracy)``.

    Raises:
        ValueError: if *data_loader* yields no samples.
    """
    loss_metrics = {channel: 0. for channel in CHANNEL_NAMES}
    acc_metrics = {channel: 0. for channel in CHANNEL_NAMES}
    epoch_loss = 0.
    epoch_acc = 0.
    samples = 0

    generator.eval()

    # The inference_mode decorator already disables autograd; no extra no_grad needed.
    for y_batch in tqdm(data_loader):
        y_batch = y_batch.to(device)
        batch_size, seq_len = y_batch.shape[0], y_batch.shape[1]

        pred = generator(y_batch)

        batch_loss = 0.
        for idx, channel in enumerate(CHANNEL_NAMES):
            channel_loss = F.nll_loss(pred[idx], y_batch[..., idx]).item()
            loss_metrics[channel] += channel_loss * batch_size
            batch_loss += channel_loss
        epoch_loss += batch_loss * batch_size

        # Greedy prediction: argmax over each channel's vocabulary axis (dim 1).
        tokens = torch.stack([channel.argmax(dim=1) for channel in pred], dim=-1)
        correct = tokens == y_batch

        for idx, channel in enumerate(CHANNEL_NAMES):
            acc_metrics[channel] += correct[..., idx].float().sum().item() / seq_len
        epoch_acc += correct.all(dim=-1).float().sum().item() / seq_len

        samples += batch_size

    if samples == 0:
        raise ValueError('validation data loader yielded no samples')

    for channel, total in loss_metrics.items():
        summary_writer.add_scalar(f'val_loss/{channel}', total / samples, epoch, new_style=True)

    epoch_loss /= samples
    summary_writer.add_scalar('val_loss/total', epoch_loss, epoch, new_style=True)

    for channel, total in acc_metrics.items():
        summary_writer.add_scalar(f'val_acc/{channel}', total / samples, epoch, new_style=True)

    epoch_acc /= samples
    summary_writer.add_scalar('val_acc/total', epoch_acc, epoch, new_style=True)

    return epoch_loss, epoch_acc

# Обучение

In [None]:
# Use the first GPU when one is available, otherwise the CPU.  The function
# must be called: the bare attribute torch.cuda.is_available is always truthy,
# which forced 'cuda:0' even on CPU-only machines.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

generator = Generator(
    channel_names=CHANNEL_NAMES,
    embedding_sizes=tuple(len(v) for v in vocab_map.values()),
    embedding_dims=(48, 128, 32, 32),
    filters=[384, 512, 256, 256],
    latent_dim=2048,
    seq_len=MAX_SEQ_LEN,
    dropout=0.15,
)

if START_EPOCH > 0:
    # Resume: load weights onto the CPU first, then move the model below.
    generator.load_state_dict(torch.load(
        os.path.join(CHECKPOINT_DIR, f'generator-checkpoint-{START_EPOCH}.model'),
        map_location='cpu'
    ))

generator.to(device)

# The optimizer is created after the move so it references the device parameters.
generator_optimizer = torch.optim.AdamW(generator.parameters(), lr=LEARNING_RATE)
if START_EPOCH > 0:
    # The restored param groups include the saved learning rate, which takes
    # precedence over LEARNING_RATE.
    generator_optimizer.load_state_dict(torch.load(
        os.path.join(CHECKPOINT_DIR, f'generator-checkpoint-{START_EPOCH}.optimizer'),
        map_location='cpu'
    ))


In [None]:
batch_size = BATCH_SIZE
epoch_count = EPOCH_COUNT

# Pitch ids start right after the note vocabulary's special tokens, i.e. at
# num_special (2 for '<start>', '<rest>'); len() is one past the last pitch id.
# The previous hard-coded 3 meant the lowest pitch id was never transposed.
note_vocab = vocab_map['note']

# shuffle=False: TrainDataSet is an IterableDataset whose iterator randomises
# the order itself (DataLoader rejects shuffle=True for iterable datasets).
train_data_loader = torch.utils.data.DataLoader(TrainDataSet(X_train, note_vocab.num_special, len(note_vocab)),
                                                batch_size, shuffle=False,
                                                collate_fn=collate,
                                                drop_last=True,
                                                pin_memory=True)

val_data_loader = torch.utils.data.DataLoader(DataSet(X_val),
                                              batch_size * 4, shuffle=False,
                                              collate_fn=collate,
                                              pin_memory=True)

# One TensorBoard log directory per run, named after its start time.
run_name = f'run-{datetime.now().strftime("%Y-%m-%d %H-%M-%S")}'
summary_writer = SummaryWriter(log_dir=os.path.join('logs', run_name))

In [None]:
print("Training")
# Resume after START_EPOCH and run EPOCH_COUNT epochs; checkpoints go to
# CHECKPOINT_DIR and metrics to the TensorBoard writer created above.
train(
    start_epoch=START_EPOCH,
    epoch_count=EPOCH_COUNT,
    device=device,
    generator=generator,
    generator_optimizer=generator_optimizer,
    train_data_loader=train_data_loader,
    val_data_loader=val_data_loader,
    summary_writer=summary_writer
)

Training


100%|████████████████████████████████████████████████████████████████| 707/707 [05:07<00:00,  2.30it/s]
100%|██████████████████████████████████████████████████████████████████| 22/22 [00:10<00:00,  2.17it/s]


cuda memory 1980 MB
Epoch 108; loss 6.91065745711158; val_loss 6.58783307787958; val_acc 0.09766090626719237
save checkpoint


100%|████████████████████████████████████████████████████████████████| 707/707 [05:06<00:00,  2.31it/s]
100%|██████████████████████████████████████████████████████████████████| 22/22 [00:10<00:00,  2.19it/s]


cuda memory 1980 MB
Epoch 109; loss 6.8941836728064985; val_loss 6.576181836845721; val_acc 0.09825010315422703
save checkpoint


100%|████████████████████████████████████████████████████████████████| 707/707 [05:08<00:00,  2.29it/s]
100%|██████████████████████████████████████████████████████████████████| 22/22 [00:10<00:00,  2.19it/s]


cuda memory 1980 MB
Epoch 110; loss 6.887736999668191; val_loss 6.565468013210077; val_acc 0.09857711638089126
save checkpoint
clean up unnecessary сheckpoints


100%|████████████████████████████████████████████████████████████████| 707/707 [05:07<00:00,  2.30it/s]
100%|██████████████████████████████████████████████████████████████████| 22/22 [00:10<00:00,  2.17it/s]


cuda memory 1980 MB
Epoch 111; loss 6.875609531429566; val_loss 6.554775637223177; val_acc 0.09869388123509995
save checkpoint


 40%|█████████████████████████▉                                      | 286/707 [02:05<03:06,  2.26it/s]

# Тестирование

In [None]:
# Decode a small batch of sequences sampled from the latent space and write
# each one to GENERATED_DIR as a MIDI file.
with torch.inference_mode():
    generator.eval()
    music_batch = generator.predict(batch_size=8)

os.makedirs(GENERATED_DIR, exist_ok=True)
for idx, sequence in enumerate(music_batch):
    target = os.path.join(GENERATED_DIR, f'midi-{idx}.midi')
    make_midi(sequence, vocab_map).write('midi', target)

In [None]:
# Inspect the generated batch: (batch, seq_len, n_channels) token ids.
music_batch.shape