In [19]:
import sys
sys.path.append("..")
import torch
import torch.nn.functional as F
import torchaudio
from omegaconf import OmegaConf
from ccreaim.model import operate
from ccreaim.utils import dataset, cfg_classes, audio_tools, util

In [5]:
## Maestro
data_tar = "/home/ccreaim/ccreaim_data/maestro/maestro_bank_training_aug.tar"
data_tar_non_aug = "/home/ccreaim/ccreaim_data/maestro/maestro_bank_training.tar"
samples_tar = "/home/ccreaim/ccreaim_data/maestro/maestro_bank_samples.tar"

load_trf_path = "/home/ccreaim/ccreaim_models/bank-classifier/maestro_final.pt"
context_length = 8
sample_rate = 44100

In [20]:
## Sounds
data_tar = "/home/ccreaim/ccreaim_data/sounds/samples_sound_bank_training_aug.tar"
data_tar_non_aug = "/home/ccreaim/ccreaim_data/sounds/samples_sound_bank_training.tar"
samples_tar = "/home/ccreaim/ccreaim_data/sounds/sounds_bank_samples.tar"

# With pitch shift
load_trf_path = "/home/ccreaim/ccreaim_models/bank-classifier/samples_final.pt"

context_length = 8
sample_rate = 44100

In [21]:
# Load transformer 

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

trf_checkpoint = torch.load(load_trf_path, map_location="cpu")
trf_state_dict = trf_checkpoint["model_state_dict"]
hyper_cfg_schema = OmegaConf.structured(cfg_classes.HyperConfig)
trf_conf = OmegaConf.create(trf_checkpoint["hyper_config"])
del trf_conf["num_seq"]
del trf_conf["seq_cat"]
del trf_conf["pre_trained_ae_path"]
del trf_conf["pre_trained_vqvae_path"]
trf_hyper_cfg = OmegaConf.merge(hyper_cfg_schema, trf_conf)
trf_hyper_cfg.transformer.dropout = 0.0 # Since we're using trf.train() for no cacheing, set this to avoid dropout in inference
get_trf = operate.get_model_init_function(trf_hyper_cfg)
trf = get_trf()
trf.load_state_dict(trf_state_dict)
trf = trf.to(device)
trf

CachedDecoderOnly(
  (positional_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (transformer_decoder): CachedTransformerEncoder(
    (layers): ModuleList(
      (0): CachedTransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=612, out_features=612, bias=True)
        )
        (linear1): Linear(in_features=612, out_features=1024, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Linear(in_features=1024, out_features=612, bias=True)
        (norm1): LayerNorm((612,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((612,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.0, inplace=False)
        (dropout2): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (trf_out_to_tokens): Linear(in_features=612, out_features=8192, bias=True)
)

In [22]:
# Prepare datasets
tmp_data_root_aug = dataset.prepare_dataset_on_tmp(data_tar)
data_aug = dataset.BankTransformerDataset(tmp_data_root_aug)
tmp_data_root_non_aug = dataset.prepare_dataset_on_tmp(data_tar_non_aug)
data_non_aug = dataset.BankTransformerDataset(tmp_data_root_non_aug)
# bank = dataset.Bank(tmp_data_root, context_length)

/tmp/samples_sound_bank_training_aug 11264
/tmp/samples_sound_bank_training 1024


In [23]:
tmp_data_root_samples = dataset.prepare_dataset_on_tmp(samples_tar)
data_samples = dataset.AudioDataset(tmp_data_root_samples, sample_rate)
print(len(data_samples))

/tmp/sounds_bank_samples 8192
8192


In [24]:
# list of indices with "mistakes" = 100, 250, 300, 350, 450
features, real_indices, _ = data_aug[100]

In [29]:
with torch.inference_mode():
    features_ = features.unsqueeze(0)
    features_ = features_.to(device)
    tgt = torch.cat(
            (
                torch.zeros_like(features_[:, 0:1, :], device=features_.device),
                features_[:, :-1, :],
            ),
            dim=1,
        )
    print(tgt.size())

    tgt_mask = util.get_tgt_mask(tgt.size(1))
    tgt_mask = tgt_mask.to(tgt.device)

    trf = trf.train()

    pred, _ = trf(tgt, tgt_mask=tgt_mask)
    indices = pred.argmax(-1)
    print(indices)
    print(real_indices.long())
    probabilities = F.softmax(pred, dim=-1)
    print(probabilities.max(dim=2))
    print(probabilities[:,:,4075])

torch.Size([1, 8, 612])
weights, tensor([[[[1.0000e+00, 1.0971e-08, 2.0641e-07, 3.6103e-07, 1.2944e-07,
           1.0638e-08, 1.0996e-09, 2.5016e-10],
          [3.4107e-23, 1.7018e-18, 6.0317e-15, 2.3865e-12, 1.0967e-07,
           5.3606e-04, 1.4581e-01, 8.5365e-01],
          [1.5156e-23, 4.8620e-19, 2.2752e-15, 8.3772e-13, 5.6568e-08,
           4.3791e-04, 1.1395e-01, 8.8561e-01],
          [5.2161e-26, 3.9955e-21, 2.0433e-17, 1.9500e-14, 3.2632e-09,
           8.4272e-05, 5.8298e-02, 9.4162e-01],
          [4.6168e-30, 3.3179e-24, 1.3783e-20, 9.1729e-18, 8.0493e-12,
           2.4435e-06, 1.1732e-02, 9.8827e-01],
          [2.5733e-33, 7.8725e-27, 1.9420e-23, 8.2369e-21, 2.6598e-14,
           8.5370e-08, 2.4468e-03, 9.9755e-01],
          [4.9019e-33, 2.7800e-27, 4.9796e-24, 1.7590e-21, 5.5623e-15,
           2.3074e-08, 1.3871e-03, 9.9861e-01],
          [1.8002e-30, 6.8980e-26, 6.4151e-23, 1.5309e-20, 1.7296e-14,
           2.9368e-08, 1.3245e-03, 9.9868e-01]],

         [[1.

In [30]:
def play_prediction(dataset, indices, sample_rate, sample=None):
    print(indices)
    for ind in indices:
        if sample == None:
            sample, name = dataset[ind.item()]
        else:
            next_sample, name = dataset[ind.item()]
            sample = torch.cat((sample, next_sample),dim=1)
        print(name)
    audio_tools.play_audio(sample, sample_rate)
        
    
play_prediction(data_samples, indices.int().squeeze()[1:], sample_rate)

tensor([882, 887, 879, 879, 879, 879, 879], dtype=torch.int32)
/tmp/sounds_bank_samples/ind_00882_320529__vumseplutten1709__cargoshipstarfleet_context_14_sample2.wav
/tmp/sounds_bank_samples/ind_00887_320529__vumseplutten1709__cargoshipstarfleet_context_14_sample7.wav
/tmp/sounds_bank_samples/ind_00879_320529__vumseplutten1709__cargoshipstarfleet_context_13_sample7.wav
/tmp/sounds_bank_samples/ind_00879_320529__vumseplutten1709__cargoshipstarfleet_context_13_sample7.wav
/tmp/sounds_bank_samples/ind_00879_320529__vumseplutten1709__cargoshipstarfleet_context_13_sample7.wav
/tmp/sounds_bank_samples/ind_00879_320529__vumseplutten1709__cargoshipstarfleet_context_13_sample7.wav
/tmp/sounds_bank_samples/ind_00879_320529__vumseplutten1709__cargoshipstarfleet_context_13_sample7.wav


In [10]:
def save_prediction(path, dataset, indices, sample_rate, sample=None):
    print(indices)
    for ind in indices:
        if sample == None:
            sample, name = dataset[ind.item()]
        else:
            next_sample, name = dataset[ind.item()]
            sample = torch.cat((sample, next_sample),dim=1)
        print(name)
    torchaudio.save(path, sample, sample_rate)

In [31]:
# THIS DOESN'T WORK WITH THE AUGMENTED DATASET
sample_id = 450
features1, real_indices1, _ = data_aug[sample_id]
features2, real_indices2, _ = data_aug[sample_id+1]

In [33]:
# Transformer has seen 
# [80, 81, 82, 83, 84, 85, 86, 87]
# and 
# [88, 89, 90, 91, 92, 93, 94, 95]
# But let's show it 
# [84, 85, 86, 87, 88, 89, 90, 91]

# Result: The transformer can't even predict 88 from 85,86,87
# so it hasn't really even learned anything :D 
with torch.inference_mode():
    features3 = torch.cat(
                  (
                      features1[context_length//2:context_length, :],
                      features2[0:context_length//2, :],
                  ),
                  dim=0,
        )
    print(features3.size())
    features3_ = features3.unsqueeze(0)
    features3_ = features3_.to(device)
    tgt = torch.cat(
            (
                torch.zeros_like(features3_[:, 0:1, :], device=features3_.device),
                features3_[:, :-1, :],
            ),
            dim=1,
        )
    print(tgt.size())

    tgt_mask = util.get_tgt_mask(tgt.size(1))
    tgt_mask = tgt_mask.to(tgt.device)

    trf = trf.train()
    pred, _ = trf(tgt, tgt_mask=tgt_mask)
    print(pred.shape)
    indices = pred.argmax(-1)
    print(real_indices1)
    print(real_indices2)
    print(indices)

torch.Size([8, 612])
torch.Size([1, 8, 612])
weights, tensor([[[[1.0000e+00, 7.1476e-09, 5.8938e-08, 4.3771e-08, 7.2142e-08,
           5.0035e-09, 2.8530e-10, 4.2078e-11],
          [1.6723e-24, 9.9274e-20, 1.1912e-16, 1.4293e-12, 7.3347e-09,
           1.0239e-04, 5.4336e-05, 9.9984e-01],
          [1.4205e-26, 4.0161e-21, 8.6949e-18, 2.3008e-13, 1.4750e-09,
           3.9182e-05, 4.1185e-05, 9.9992e-01],
          [7.0431e-30, 3.5720e-23, 4.9337e-20, 2.5901e-15, 2.7409e-11,
           4.0939e-06, 2.3070e-05, 9.9997e-01],
          [2.5353e-31, 7.1826e-25, 5.5132e-22, 4.0406e-17, 8.0705e-13,
           4.7941e-07, 1.8236e-05, 9.9998e-01],
          [1.2966e-33, 6.2874e-27, 1.8987e-24, 1.7281e-19, 5.3311e-15,
           3.5647e-08, 1.0212e-05, 9.9999e-01],
          [7.4292e-33, 2.3445e-26, 1.5447e-24, 4.9065e-20, 8.8928e-16,
           1.0399e-08, 3.0209e-05, 9.9997e-01],
          [1.8106e-33, 1.7053e-27, 2.1442e-25, 1.3954e-20, 4.9500e-16,
           7.1874e-09, 9.0304e-06, 9.9999e

In [34]:
# This hacky function works only with the non-augmented dataset, can be used for POC generation
def get_features_from_index(ind, dataset, context_length):
    context, indices, a = dataset[ind//context_length]
    assert ind == indices[ind % context_length]
    return context[ind % context_length]
    

In [35]:
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

In [36]:
def create_feature_vec_from_clip(x, Fs, frame_size, frame_step, deltas):
    F, f_names = ShortTermFeatures.feature_extraction(
        x, Fs, frame_size * Fs, frame_step * Fs, deltas=deltas
    )
    features = torch.tensor(F).view(-1)
    return features

In [37]:
# Returns both a scipy.io wavfile output and torchaudio.load output
def return_clip_from_path(path, starting_point_sec, clip_length_sec):
    torch_data, samp_rate = torchaudio.load(path)
    [Fs, x] = audioBasicIO.read_audio_file(path)
    
    if len(x.shape) != 1:
        x = numpy.average(x, axis=1).astype(x.dtype)
    if len(torch_data.shape) != 1:
        torch_data = torch.mean(torch_data, dim=0)

    torch_clip = torch_data[samp_rate*starting_point_sec : samp_rate*(starting_point_sec + clip_length_sec)]
    clip = x[samp_rate*starting_point_sec : samp_rate*(starting_point_sec + clip_length_sec)]
    return clip, torch_clip.unsqueeze(0), samp_rate

In [41]:
from pyAudioAnalysis import ShortTermFeatures, audioBasicIO
import numpy
frame_size = 0.2
frame_step = 0.1
deltas = True
latent_dim = 612 # number of features returned per 1 second clip by ShortTermFeatures

# external_sample_path = "/scratch/other/sopi/CCREAIM/datasets/maestro-v3.0.0/2018/MIDI-Unprocessed_Schubert7-9_MID--AUDIO_11_R2_2018_wav.wav"
# external_sample_path = "/scratch/other/sopi/CCREAIM/datasets/maestro-v3.0.0/2018/MIDI-Unprocessed_Recital17-19_MID--AUDIO_17_R1_2018_wav--2.wav"
# external_sample_path = "/scratch/other/sopi/CCREAIM/datasets/maestro-v3.0.0/2018/MIDI-Unprocessed_Recital1-3_MID--AUDIO_01_R1_2018_wav--1.wav"
# external_sample_path = "/scratch/other/sopi/CCREAIM/datasets/maestro-v3.0.0/2017/MIDI-Unprocessed_083_PIANO083_MID--AUDIO-split_07-09-17_Piano-e_2_-06_wav--4.wav"
# external_sample_path = "/scratch/other/sopi/CCREAIM/datasets/maestro-v3.0.0/2018/MIDI-Unprocessed_Chamber2_MID--AUDIO_09_R3_2018_wav--1.wav"
# external_sample_path = "/scratch/other/sopi/CCREAIM/datasets/samples/99-12-2009.wav"

# test_fn = "ethereal-teleport"
# test_fn = "across"
# test_fn = "quartet"
# test_fn = "piano_man"
# test_fn = "train-upon-us"
test_fn = "waves-of-hawaii"
external_sample_path = "/home/ccreaim/ccreaim_data/sounds/test_files/{}.wav".format(test_fn)

# Generation
top_p = 0
top_k = 5
temperature = 1.0

external = True
start_sec = 0

sample_id = 1385

n = context_length // 2
torch_sample = None
with torch.inference_mode():
    # Pick half of context length, generate rest
    if external:
        features_ = torch.zeros(1, n, latent_dim)
        indices = torch.zeros(n)
        for i in range(n):
            clip, torch_clip, samp_rate = return_clip_from_path(external_sample_path, start_sec+i, 1)
            print(clip.shape, torch_clip.shape, samp_rate)
            feature_vec = create_feature_vec_from_clip(clip, samp_rate, frame_size, frame_step, deltas)
            print(feature_vec.shape)
            features_[:,i,:] = feature_vec
            if torch_sample == None:
                torch_sample = torch_clip
            else:
                torch_sample = torch.cat((torch_sample, torch_clip),dim=1)
    else:
        features, indices, _ = data_aug[sample_id]
        features_ = features.unsqueeze(0)
    features_ = features_.to(device)
    
    features_half = features_[:, :n, :]
    tgt = torch.cat(
            (
                torch.zeros_like(features_half[:, 0:1, :], device=features_.device),
                features_half[:, :-1, :],
            ),
            dim=1,
        )
    print(tgt.size())
    
    # No cacheing
    trf = trf.train()
    indices = indices[0:(n-1)]
    print(indices)
    for i in range(n):
        print(tgt.shape)
        trf_out, attn = trf(tgt)
        print("hello", trf_out.shape)
        trf_out_filtered = top_k_top_p_filtering(trf_out[:,-1,:].squeeze()/temperature, top_k=top_k, top_p=top_p)
        probabilities = F.softmax(trf_out_filtered)
        print("max_prob:", max(probabilities))
        emb_ind = torch.multinomial(probabilities,1)
        print(emb_ind)
        # print("optional_emb_ind:", trf_out[:,-1,:].argmax(-1))
        #emb_ind = trf_out[:,-1,:].argmax(-1)
        indices = torch.cat((indices, emb_ind), dim=0)
        next_feature = get_features_from_index(emb_ind.item(), data_non_aug, context_length)
        print("features:", next_feature.shape)
        tgt = torch.cat(
                (
                    tgt,
                    next_feature.unsqueeze(0).unsqueeze(0)
                ),
                dim=1
            )
        probabilities = F.softmax(trf_out, dim=-1)
        print(probabilities.max(dim=2)[0][-1][-1])
        print(probabilities.max(dim=2)[1][-1][-1])
    print(indices)

  sampling_rate, signal = wavfile.read(input_file) # from scipy.io


(48000,) torch.Size([1, 48000]) 48000
torch.Size([612])
(48000,) torch.Size([1, 48000]) 48000
torch.Size([612])
(48000,) torch.Size([1, 48000]) 48000
torch.Size([612])
(48000,) torch.Size([1, 48000]) 48000
torch.Size([612])
torch.Size([1, 4, 612])
tensor([0., 0., 0.])
torch.Size([1, 4, 612])
weights, tensor([[[[9.9999e-01, 1.9165e-07, 2.4804e-06, 5.8560e-06],
          [1.7269e-07, 9.9528e-07, 3.2019e-04, 9.9968e-01],
          [2.5501e-07, 7.2617e-07, 2.8986e-04, 9.9971e-01],
          [7.1740e-08, 4.8607e-07, 1.8195e-04, 9.9982e-01]],

         [[1.0000e+00, 2.6505e-08, 5.5112e-08, 6.9265e-08],
          [1.0000e+00, 6.6255e-20, 1.1092e-18, 1.9372e-17],
          [1.0000e+00, 9.4039e-25, 1.0788e-23, 7.7890e-23],
          [1.0000e+00, 4.1493e-26, 3.8557e-25, 1.6077e-24]],

         [[9.9122e-01, 1.4157e-04, 1.5806e-03, 7.0531e-03],
          [6.0601e-11, 3.1544e-06, 4.9449e-04, 9.9950e-01],
          [1.5944e-09, 2.7129e-06, 4.8015e-04, 9.9952e-01],
          [5.6230e-08, 2.3238e-06,

AttributeError: 'tuple' object has no attribute 'shape'

In [18]:
play_prediction(data_samples, indices[3:].int(), sample_rate, torch_sample)

tensor([4523, 4524, 4525, 4526], dtype=torch.int32)
/tmp/sounds_bank_samples/ind_04523_recording 2021-11-22 13-59-28_context_83_sample3.wav
/tmp/sounds_bank_samples/ind_04524_recording 2021-11-22 13-59-28_context_83_sample4.wav
/tmp/sounds_bank_samples/ind_04525_recording 2021-11-22 13-59-28_context_83_sample5.wav
/tmp/sounds_bank_samples/ind_04526_recording 2021-11-22 13-59-28_context_83_sample6.wav


In [19]:
import torchaudio
save = True
#save = False
save_path_root = "bank_outputs/maestro"
if save:
    save_prediction("{}/{}/{}-gen3-sec{}.wav".format(save_path_root, test_fn, test_fn, start_sec), data_samples, indices[3:].int(), sample_rate, torch_sample)

tensor([4523, 4524, 4525, 4526], dtype=torch.int32)
/tmp/sounds_bank_samples/ind_04523_recording 2021-11-22 13-59-28_context_83_sample3.wav
/tmp/sounds_bank_samples/ind_04524_recording 2021-11-22 13-59-28_context_83_sample4.wav
/tmp/sounds_bank_samples/ind_04525_recording 2021-11-22 13-59-28_context_83_sample5.wav
/tmp/sounds_bank_samples/ind_04526_recording 2021-11-22 13-59-28_context_83_sample6.wav


RuntimeError: Error saving audio file: failed to open file bank_outputs/maestro/waves-of-hawaii/waves-of-hawaii-gen3-sec0.wav