In [2]:
import sys
sys.path.append("..")
import torch
import torch.nn.functional as F
import torchaudio
from omegaconf import OmegaConf
from ccreaim.model import operate
from ccreaim.utils import dataset, cfg_classes, audio_tools, util

In [3]:
## Maestro
# Dataset archives and classifier checkpoint for the Maestro bank.
# NOTE: the "Sounds" configuration cell assigns the very same variable names, so
# whichever of the two configuration cells was executed last decides what the rest
# of the notebook loads. Run only one of them.
data_tar = "/home/ccreaim/ccreaim_data/maestro/maestro_bank_training_aug.tar"
data_tar_non_aug = "/home/ccreaim/ccreaim_data/maestro/maestro_bank_training.tar"
samples_tar = "/home/ccreaim/ccreaim_data/maestro/maestro_bank_samples.tar"

# Checkpoint read by torch.load in the "Load transformer" cell.
load_trf_path = "/home/ccreaim/ccreaim_models/bank-classifier/maestro_final.pt"
# Number of feature vectors per context; later cells use context_length // 2 as the prompt length.
context_length = 8
# Rate (Hz) passed to the sample dataset and to audio playback/saving.
sample_rate = 44100

In [20]:
## Sounds
# Alternative configuration: same variable names as the "Maestro" cell, pointing at
# the sounds bank instead. Executing this cell overrides the Maestro settings.
data_tar = "/home/ccreaim/ccreaim_data/sounds/samples_sound_bank_training_aug.tar"
data_tar_non_aug = "/home/ccreaim/ccreaim_data/sounds/samples_sound_bank_training.tar"
samples_tar = "/home/ccreaim/ccreaim_data/sounds/sounds_bank_samples.tar"

# With pitch shift
load_trf_path = "/home/ccreaim/ccreaim_models/bank-classifier/samples_final.pt"

# Number of feature vectors per context; later cells use context_length // 2 as the prompt length.
context_length = 8
# Rate (Hz) passed to the sample dataset and to audio playback/saving.
sample_rate = 44100

In [4]:
# Load transformer
# Rebuilds the model described by the checkpoint's stored hyper-config, loads its
# weights and moves it to the GPU when one is available.

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): torch.load unpickles the file, so only open checkpoints from a trusted source.
trf_checkpoint = torch.load(load_trf_path, map_location="cpu")
trf_state_dict = trf_checkpoint["model_state_dict"]
hyper_cfg_schema = OmegaConf.structured(cfg_classes.HyperConfig)
trf_conf = OmegaConf.create(trf_checkpoint["hyper_config"])
# Remove these keys from the stored config before merging it into the schema
# (presumably they are no longer part of HyperConfig - TODO confirm).
# `del` raises if a key is absent from the checkpoint config.
del trf_conf["num_seq"]
del trf_conf["seq_cat"]
del trf_conf["pre_trained_ae_path"]
del trf_conf["pre_trained_vqvae_path"]
trf_hyper_cfg = OmegaConf.merge(hyper_cfg_schema, trf_conf)
trf_hyper_cfg.transformer.dropout = 0.0 # Since we're using trf.train() for no cacheing, set this to avoid dropout in inference
# get_model_init_function returns a zero-argument factory for the configured model.
get_trf = operate.get_model_init_function(trf_hyper_cfg)
trf = get_trf()
trf.load_state_dict(trf_state_dict)
trf = trf.to(device)
# Bare expression: displays the module structure as the cell output.
trf

CachedDecoderOnly(
  (positional_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (transformer_decoder): CachedTransformerEncoder(
    (layers): ModuleList(
      (0): CachedTransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=612, out_features=612, bias=True)
        )
        (linear1): Linear(in_features=612, out_features=1024, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Linear(in_features=1024, out_features=612, bias=True)
        (norm1): LayerNorm((612,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((612,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.0, inplace=False)
        (dropout2): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (trf_out_to_tokens): Linear(in_features=612, out_features=8192, bias=True)
)

In [5]:
# Prepare datasets
# Builds one BankTransformerDataset from the augmented archive and one from the
# non-augmented archive. prepare_dataset_on_tmp returns the dataset root
# (the printed roots are under /tmp; presumably the tar is extracted there - TODO confirm).
tmp_data_root_aug = dataset.prepare_dataset_on_tmp(data_tar)
data_aug = dataset.BankTransformerDataset(tmp_data_root_aug)
tmp_data_root_non_aug = dataset.prepare_dataset_on_tmp(data_tar_non_aug)
# The non-augmented dataset is the one get_features_from_index is used with later on.
data_non_aug = dataset.BankTransformerDataset(tmp_data_root_non_aug)
# bank = dataset.Bank(tmp_data_root, context_length)

/tmp/maestro_bank_training_aug 2464
/tmp/maestro_bank_training 1024


In [6]:
# Audio clips of the bank: data_samples[i] is unpacked as (waveform, name) by
# play_prediction / save_prediction.
tmp_data_root_samples = dataset.prepare_dataset_on_tmp(samples_tar)
data_samples = dataset.AudioDataset(tmp_data_root_samples, sample_rate)
# Number of clips in the bank.
print(len(data_samples))

/tmp/maestro_bank_samples 8192
8192


In [7]:
# Pick one context from the augmented dataset for the teacher-forced check below.
# list of indices with "mistakes" = 100, 250, 300, 350, 450
# Each item unpacks into (features, real_indices, <unused third value>).
features, real_indices, _ = data_aug[100]

In [8]:
# Teacher-forced pass over one context: feed the context shifted right by one step
# and compare the predicted bank indices with the real ones.
with torch.inference_mode():
    # Add a batch dimension and move the features to the model's device.
    features_ = features.unsqueeze(0)
    features_ = features_.to(device)
    # Shift right: an all-zero vector is prepended and the last feature vector is dropped.
    tgt = torch.cat(
            (
                torch.zeros_like(features_[:, 0:1, :], device=features_.device),
                features_[:, :-1, :],
            ),
            dim=1,
        )
    print(tgt.size())

    # Mask sized by the sequence length (presumably a causal mask - confirm in util.get_tgt_mask).
    tgt_mask = util.get_tgt_mask(tgt.size(1))
    tgt_mask = tgt_mask.to(tgt.device)

    # train() is used on purpose to run without caching; dropout was set to 0.0 when
    # the model config was built, so this does not add dropout noise.
    trf = trf.train()

    pred, _ = trf(tgt, tgt_mask=tgt_mask)
    # Most likely bank index at every position.
    indices = pred.argmax(-1)
    print(indices)
    print(real_indices.long())
    probabilities = F.softmax(pred, dim=-1)
    # Highest probability (and its index) per position.
    print(probabilities.max(dim=2))
    # Probability given to bank index 4075 at every position
    # (NOTE(review): hand-picked index, meaning not stated here).
    print(probabilities[:,:,4075])

torch.Size([1, 8, 612])
attn dims: (batch_size,num_heads,target seqlen,source seqlen)
attn weights: torch.Size([1, 4, 8, 8]) tensor([[[[1.0000e+00, 4.0150e-10, 2.0487e-08, 3.0131e-08, 1.8721e-09,
           2.7595e-10, 9.6897e-12, 2.3302e-12],
          [1.6910e-20, 2.2257e-16, 5.8079e-13, 1.0433e-09, 3.2868e-06,
           2.5859e-03, 5.1461e-02, 9.4595e-01],
          [3.8133e-22, 1.3642e-18, 1.1217e-14, 8.0637e-11, 9.7538e-07,
           3.6868e-03, 9.8617e-02, 8.9770e-01],
          [9.2088e-27, 3.8826e-23, 2.5508e-19, 8.0091e-15, 1.6193e-09,
           1.7409e-04, 4.6521e-02, 9.5330e-01],
          [8.0020e-29, 3.5801e-26, 4.7207e-23, 1.0773e-18, 1.0393e-12,
           2.0416e-06, 9.3570e-03, 9.9064e-01],
          [3.5927e-26, 1.1284e-25, 1.3520e-23, 3.5185e-20, 1.1086e-14,
           6.9440e-08, 2.1147e-03, 9.9789e-01],
          [3.2736e-23, 8.3860e-24, 1.4862e-22, 3.7783e-20, 1.8652e-15,
           6.6461e-09, 5.8207e-04, 9.9942e-01],
          [1.1534e-20, 1.7319e-21, 1.5291e

In [9]:
def play_prediction(dataset, indices, sample_rate, sample=None):
    """Concatenate the bank clips selected by ``indices`` and play the result.

    Args:
        dataset: Indexable collection; ``dataset[i]`` must unpack into
            ``(waveform, name)``. Waveforms are joined along dim 1
            (assumed to be the time axis - TODO confirm).
        indices: Iterable of scalars exposing ``.item()`` (e.g. a 1-D tensor).
        sample_rate: Rate handed to ``audio_tools.play_audio``.
        sample: Optional waveform placed in front of the selected clips.

    Raises:
        ValueError: If ``indices`` is empty and no ``sample`` was given, since
            there would be nothing to play.
    """
    print(indices)
    # `is None` instead of `== None`: comparing a tensor with `==` is not an identity test.
    pieces = [] if sample is None else [sample]
    for ind in indices:
        clip, name = dataset[ind.item()]
        pieces.append(clip)
        print(name)
    if not pieces:
        raise ValueError("play_prediction needs at least one index or an initial sample")
    # Join once at the end instead of re-copying the growing waveform every iteration.
    audio = pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=1)
    audio_tools.play_audio(audio, sample_rate)
        
    
# Play the clips for the notebook-global `indices` (set by the most recently run
# prediction cell). `[1:]` skips the first predicted position
# (NOTE(review): the reason is not stated in this notebook - confirm).
play_prediction(data_samples, indices.int().squeeze()[1:], sample_rate)

tensor([8071, 8071, 8071, 8071, 8071, 8071, 8071], dtype=torch.int32)
/tmp/maestro_bank_samples/ind_08071_MIDI-Unprocessed_06_R1_2006_01-04_ORIG_MID--AUDIO_06_R1_2006_04_Track04_wav_context_107_sample7.wav
/tmp/maestro_bank_samples/ind_08071_MIDI-Unprocessed_06_R1_2006_01-04_ORIG_MID--AUDIO_06_R1_2006_04_Track04_wav_context_107_sample7.wav
/tmp/maestro_bank_samples/ind_08071_MIDI-Unprocessed_06_R1_2006_01-04_ORIG_MID--AUDIO_06_R1_2006_04_Track04_wav_context_107_sample7.wav
/tmp/maestro_bank_samples/ind_08071_MIDI-Unprocessed_06_R1_2006_01-04_ORIG_MID--AUDIO_06_R1_2006_04_Track04_wav_context_107_sample7.wav
/tmp/maestro_bank_samples/ind_08071_MIDI-Unprocessed_06_R1_2006_01-04_ORIG_MID--AUDIO_06_R1_2006_04_Track04_wav_context_107_sample7.wav
/tmp/maestro_bank_samples/ind_08071_MIDI-Unprocessed_06_R1_2006_01-04_ORIG_MID--AUDIO_06_R1_2006_04_Track04_wav_context_107_sample7.wav
/tmp/maestro_bank_samples/ind_08071_MIDI-Unprocessed_06_R1_2006_01-04_ORIG_MID--AUDIO_06_R1_2006_04_Track04_wav_co

In [10]:
def save_prediction(path, dataset, indices, sample_rate, sample=None):
    """Concatenate the bank clips selected by ``indices`` and write them to ``path``.

    Args:
        path: Destination file handed to ``torchaudio.save``.
        dataset: Indexable collection; ``dataset[i]`` must unpack into
            ``(waveform, name)``. Waveforms are joined along dim 1
            (assumed to be the time axis - TODO confirm).
        indices: Iterable of scalars exposing ``.item()`` (e.g. a 1-D tensor).
        sample_rate: Rate handed to ``torchaudio.save``.
        sample: Optional waveform placed in front of the selected clips.

    Raises:
        ValueError: If ``indices`` is empty and no ``sample`` was given, since
            there would be nothing to save.
    """
    print(indices)
    # `is None` instead of `== None`: comparing a tensor with `==` is not an identity test.
    pieces = [] if sample is None else [sample]
    for ind in indices:
        clip, name = dataset[ind.item()]
        pieces.append(clip)
        print(name)
    if not pieces:
        raise ValueError("save_prediction needs at least one index or an initial sample")
    # Join once at the end instead of re-copying the growing waveform every iteration.
    audio = pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=1)
    torchaudio.save(path, audio, sample_rate)

In [11]:
# THIS DOESN'T WORK WITH THE AUGMENTED DATASET
# NOTE(review): the lines below nevertheless read from `data_aug`; confirm whether
# `data_non_aug` was intended.
# Fetch two consecutive contexts so the next cell can splice them together.
sample_id = 450
features1, real_indices1, _ = data_aug[sample_id]
features2, real_indices2, _ = data_aug[sample_id+1]

In [12]:
# Experiment: show the transformer a context window that straddles two contexts.
# Transformer has seen
# [80, 81, 82, 83, 84, 85, 86, 87]
# and
# [88, 89, 90, 91, 92, 93, 94, 95]
# But let's show it
# [84, 85, 86, 87, 88, 89, 90, 91]

# Result: The transformer can't even predict 88 from 85,86,87
# so it hasn't really even learned anything :D
with torch.inference_mode():
    # Second half of the first context followed by the first half of the second one.
    features3 = torch.cat(
                  (
                      features1[context_length//2:context_length, :],
                      features2[0:context_length//2, :],
                  ),
                  dim=0,
        )
    print(features3.size())
    # Add a batch dimension and move the features to the model's device.
    features3_ = features3.unsqueeze(0)
    features3_ = features3_.to(device)
    # Shift right: an all-zero vector is prepended and the last feature vector is dropped.
    tgt = torch.cat(
            (
                torch.zeros_like(features3_[:, 0:1, :], device=features3_.device),
                features3_[:, :-1, :],
            ),
            dim=1,
        )
    print(tgt.size())

    # Mask sized by the sequence length (presumably a causal mask - confirm in util.get_tgt_mask).
    tgt_mask = util.get_tgt_mask(tgt.size(1))
    tgt_mask = tgt_mask.to(tgt.device)

    # train() is used on purpose to run without caching (dropout is configured as 0.0).
    trf = trf.train()
    pred, _ = trf(tgt, tgt_mask=tgt_mask)
    print(pred.shape)
    # Most likely bank index per position, printed next to the indices of both source contexts.
    indices = pred.argmax(-1)
    print(real_indices1)
    print(real_indices2)
    print(indices)

torch.Size([8, 612])
torch.Size([1, 8, 612])
attn dims: (batch_size,num_heads,target seqlen,source seqlen)
attn weights: torch.Size([1, 4, 8, 8]) tensor([[[[1.0000e+00, 2.1559e-09, 7.5646e-09, 2.9375e-08, 6.0035e-09,
           1.0816e-10, 1.7474e-11, 1.0235e-12],
          [6.8400e-20, 6.4645e-16, 9.0902e-13, 5.6503e-09, 3.0817e-05,
           9.9954e-04, 1.6446e-02, 9.8252e-01],
          [2.0597e-24, 5.6248e-20, 2.5702e-16, 9.2769e-12, 8.1130e-07,
           2.0529e-04, 1.1692e-02, 9.8810e-01],
          [5.7273e-27, 9.2032e-24, 2.9878e-20, 3.5265e-15, 2.5310e-09,
           8.3577e-06, 3.9968e-03, 9.9599e-01],
          [7.3617e-29, 2.4800e-27, 1.5033e-24, 1.8884e-19, 4.5383e-13,
           2.2502e-08, 3.1849e-04, 9.9968e-01],
          [2.0529e-29, 3.7110e-28, 3.1356e-26, 1.2454e-21, 1.5127e-15,
           3.0476e-10, 3.7886e-05, 9.9996e-01],
          [3.7252e-24, 4.0090e-25, 2.2897e-24, 9.3248e-21, 7.2233e-16,
           5.9800e-11, 1.7455e-05, 9.9998e-01],
          [5.4133e-23

In [13]:
# This hacky function works only with the non-augmented dataset, can be used for POC generation
def get_features_from_index(ind, dataset, context_length):
    """Look up the feature vector stored for bank index ``ind``.

    The dataset item at ``ind // context_length`` is fetched and the entry at
    offset ``ind % context_length`` inside it is returned.

    Args:
        ind: Bank index to look up.
        dataset: Indexable collection whose items unpack into
            ``(context, indices, <unused>)``.
        context_length: Number of entries per dataset item.

    Returns:
        ``context[ind % context_length]`` of the selected item.

    Raises:
        AssertionError: If the index stored at that offset is not ``ind``, i.e.
            the dataset is not laid out in consecutive index order.
    """
    item_no = ind // context_length
    offset = ind % context_length
    context, indices, _ = dataset[item_no]
    # Debug trace: requested index, indices stored in the item, offset used.
    print(ind, indices, offset)
    assert ind == indices[offset]
    return context[offset]
    

In [14]:
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    The input tensor is left untouched; a filtered copy is returned.

    Args:
        logits: 1-D tensor of logits, shape (vocabulary size).
        top_k: If > 0, keep only the ``top_k`` tokens with the highest logit
            (top-k filtering). Values larger than the vocabulary are clamped.
        top_p: If > 0.0, keep the smallest set of most likely tokens whose
            cumulative probability reaches ``top_p`` (nucleus filtering).
            Nucleus filtering is described in Holtzman et al.
            (http://arxiv.org/abs/1904.09751).
        filter_value: Value written into every filtered-out position.

    Returns:
        A new tensor shaped like ``logits`` with the removed positions set to
        ``filter_value``.

    Raises:
        AssertionError: If ``logits`` is not 1-D (batch size 1 only).
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    # Work on a copy so that the caller's tensor is not overwritten in place.
    logits = logits.clone()
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a logit below the k-th largest one.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        # The most likely token always survives.
        sorted_indices_to_remove[..., 0] = False

        logits[sorted_indices[sorted_indices_to_remove]] = filter_value
    return logits

In [15]:
def create_feature_vec_from_clip(x, Fs, frame_size, frame_step, deltas):
    """Turn an audio clip into one flat feature vector.

    Args:
        x: Audio signal handed to ``ShortTermFeatures.feature_extraction``.
        Fs: Sampling rate of ``x``.
        frame_size: Analysis window length in seconds (multiplied by ``Fs``
            before being passed on).
        frame_step: Hop between windows in seconds (multiplied by ``Fs``
            before being passed on).
        deltas: Forwarded as the ``deltas`` keyword of the extractor.

    Returns:
        1-D tensor holding all values of the extracted feature matrix.
    """
    window_len = frame_size * Fs
    hop_len = frame_step * Fs
    # The extractor also returns the feature names, which are not needed here.
    feature_matrix, _ = ShortTermFeatures.feature_extraction(
        x, Fs, window_len, hop_len, deltas=deltas
    )
    return torch.tensor(feature_matrix).view(-1)

In [16]:
# Returns both a scipy.io wavfile output and torchaudio.load output
def return_clip_from_path(path, starting_point_sec, clip_length_sec):
    """Cut the same time window out of a file, once as an array and once as a tensor.

    The file is read twice: with ``torchaudio.load`` and with
    ``audioBasicIO.read_audio_file``. Both results are reduced to a single
    channel by averaging and then sliced with the rate reported by torchaudio.
    NOTE(review): this assumes both readers report the same sampling rate.

    Args:
        path: Audio file to read.
        starting_point_sec: Start of the clip in seconds (used in slice
            arithmetic, so it has to keep the slice bounds integral).
        clip_length_sec: Length of the clip in seconds.

    Returns:
        Tuple ``(clip, torch_clip, samp_rate)``: the sliced
        ``read_audio_file`` signal, the sliced torchaudio waveform with a
        leading dimension of size 1 added, and torchaudio's sampling rate.
    """
    waveform, samp_rate = torchaudio.load(path)
    _fs, signal = audioBasicIO.read_audio_file(path)

    # Down-mix to mono when more than one dimension is present.
    if signal.ndim != 1:
        signal = numpy.average(signal, axis=1).astype(signal.dtype)
    if waveform.dim() != 1:
        waveform = torch.mean(waveform, dim=0)

    first = samp_rate * starting_point_sec
    last = samp_rate * (starting_point_sec + clip_length_sec)
    return signal[first:last], waveform[first:last].unsqueeze(0), samp_rate

In [21]:
# Settings for generating a continuation from an audio prompt.
# NOTE: these imports are needed by create_feature_vec_from_clip and
# return_clip_from_path, so this cell has to run before either helper is called.
from pyAudioAnalysis import ShortTermFeatures, audioBasicIO
import numpy
# Feature extraction window and hop, in seconds (multiplied by the sampling rate in
# create_feature_vec_from_clip).
frame_size = 0.2
frame_step = 0.1
deltas = True
latent_dim = 612 # number of features returned per 1 second clip by ShortTermFeatures

# Alternative prompt files that were tried; kept for reference.
# external_sample_path = "/scratch/other/sopi/CCREAIM/datasets/maestro-v3.0.0/2018/MIDI-Unprocessed_Schubert7-9_MID--AUDIO_11_R2_2018_wav.wav"
# external_sample_path = "/scratch/other/sopi/CCREAIM/datasets/maestro-v3.0.0/2018/MIDI-Unprocessed_Recital17-19_MID--AUDIO_17_R1_2018_wav--2.wav"
# external_sample_path = "/scratch/other/sopi/CCREAIM/datasets/maestro-v3.0.0/2018/MIDI-Unprocessed_Recital1-3_MID--AUDIO_01_R1_2018_wav--1.wav"
# external_sample_path = "/scratch/other/sopi/CCREAIM/datasets/maestro-v3.0.0/2017/MIDI-Unprocessed_083_PIANO083_MID--AUDIO-split_07-09-17_Piano-e_2_-06_wav--4.wav"
# external_sample_path = "/scratch/other/sopi/CCREAIM/datasets/maestro-v3.0.0/2018/MIDI-Unprocessed_Chamber2_MID--AUDIO_09_R3_2018_wav--1.wav"
# external_sample_path = "/scratch/other/sopi/CCREAIM/datasets/samples/99-12-2009.wav"

# Name of the prompt file (without extension); also used for the output folder and file names.
# test_fn = "ethereal-teleport"
# test_fn = "across"
# test_fn = "quartet"
test_fn = "piano_man"
# test_fn = "train-upon-us"
# test_fn = "waves-of-hawaii"
external_sample_path = "/home/ccreaim/ccreaim_data/sounds/test_files/{}.wav".format(test_fn)

# Generation
# Sampling settings forwarded to top_k_top_p_filtering; 0 disables the respective filter.
top_p = 0
top_k = 5
# Logits are divided by this value before filtering.
temperature = 1.0

# True: build the prompt from `external_sample_path`; False: use data_aug[sample_id].
external = True
# Offset (seconds) into the external file at which the prompt starts.
start_sec = 0

sample_id = 1385

# Prompt the transformer with half a context and sample the remaining bank indices
# one step at a time.
n = context_length // 2
torch_sample = None
with torch.inference_mode():
    # Pick half of context length, generate rest
    if external:
        # Prompt built from n consecutive 1-second clips of an external file.
        features_ = torch.zeros(1, n, latent_dim)
        # Placeholder indices: the external clips do not belong to the bank.
        indices = torch.zeros(n)
        for i in range(n):
            clip, torch_clip, samp_rate = return_clip_from_path(external_sample_path, start_sec+i, 1)
            print(clip.shape, torch_clip.shape, samp_rate)
            feature_vec = create_feature_vec_from_clip(clip, samp_rate, frame_size, frame_step, deltas)
            print(feature_vec.shape)
            features_[:,i,:] = feature_vec
            # The prompt audio is later joined with bank clips and played/saved at
            # `sample_rate`, so bring it to that rate when the file uses another one.
            if samp_rate != sample_rate:
                torch_clip = torchaudio.functional.resample(torch_clip, samp_rate, sample_rate)
            # `is None` rather than `== None`: this is an identity test on a tensor-or-None value.
            if torch_sample is None:
                torch_sample = torch_clip
            else:
                torch_sample = torch.cat((torch_sample, torch_clip),dim=1)
    else:
        features, indices, _ = data_aug[sample_id]
        features_ = features.unsqueeze(0)
    features_ = features_.to(device)

    # Shift right: zero vector first, then all prompt features except the last one.
    features_half = features_[:, :n, :]
    tgt = torch.cat(
            (
                torch.zeros_like(features_half[:, 0:1, :]),
                features_half[:, :-1, :],
            ),
            dim=1,
        )
    print("size", tgt.size())

    # No cacheing
    trf = trf.train()
    # Keep the indices of the n-1 prompt steps actually fed to the model; the n
    # sampled indices are appended behind them.
    indices = indices[0:(n-1)]
    print(indices)
    for i in range(n):
        # NOTE(review): no tgt_mask is passed here, unlike the teacher-forced cells -
        # confirm that unmasked attention is intended during generation.
        trf_out, attn = trf(tgt)
        print("hello", trf_out.shape)
        # Only the last position is used to choose the next bank index.
        trf_out_filtered = top_k_top_p_filtering(trf_out[:,-1,:].squeeze()/temperature, top_k=top_k, top_p=top_p)
        # Explicit dim: the implicit choice is deprecated and emits a warning.
        probabilities = F.softmax(trf_out_filtered, dim=-1)
        print("max_prob:", probabilities.max())
        emb_ind = torch.multinomial(probabilities,1)
        print(emb_ind)
        # print("optional_emb_ind:", trf_out[:,-1,:].argmax(-1))
        #emb_ind = trf_out[:,-1,:].argmax(-1)
        # The sampled index lives on the model's device; align it with `indices` before joining.
        indices = torch.cat((indices, emb_ind.to(indices.device)), dim=0)
        next_feature = get_features_from_index(emb_ind.item(), data_non_aug, context_length)
        print("features:", next_feature.shape)
        # Dataset features may sit on another device than `tgt`; move them before appending.
        tgt = torch.cat(
                (
                    tgt,
                    next_feature.unsqueeze(0).unsqueeze(0).to(tgt.device)
                ),
                dim=1
            )
        print("new_size", tgt.size())
        # Unfiltered distribution: report the top probability and index at the last position.
        probabilities = F.softmax(trf_out, dim=-1)
        print(probabilities.max(dim=2)[0][-1][-1])
        print(probabilities.max(dim=2)[1][-1][-1])
    print(indices)

(48000,) torch.Size([1, 48000]) 48000
torch.Size([612])
(48000,) torch.Size([1, 48000]) 48000
torch.Size([612])
(48000,) torch.Size([1, 48000]) 48000
torch.Size([612])
(48000,) torch.Size([1, 48000]) 48000
torch.Size([612])
size torch.Size([1, 4, 612])
tensor([0., 0., 0.])
torch.Size([1, 4, 612])
attn dims: (batch_size,num_heads,target seqlen,source seqlen)
attn weights: torch.Size([1, 4, 4, 4]) tensor([[[[1.0000e+00, 1.0677e-09, 7.8223e-08, 3.9288e-08],
          [1.4013e-45, 1.0000e+00, 1.3560e-39, 6.9397e-37],
          [4.8400e-12, 3.9257e-06, 4.4395e-05, 9.9995e-01],
          [1.2410e-12, 3.7085e-08, 7.5053e-06, 9.9999e-01]],

         [[1.6216e-05, 1.7817e-08, 2.7275e-04, 9.9971e-01],
          [5.4272e-12, 1.4682e-06, 8.3255e-08, 1.0000e+00],
          [2.3638e-15, 7.8922e-20, 9.4022e-08, 1.0000e+00],
          [7.3439e-16, 2.6818e-18, 5.6383e-08, 1.0000e+00]],

         [[1.0000e+00, 0.0000e+00, 2.3053e-37, 2.1500e-40],
          [1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00

  probabilities = F.softmax(trf_out_filtered)


In [18]:
# Play the prompt audio followed by the generated clips. The first n - 1 entries of
# `indices` belong to the prompt, so the generated part starts at n - 1 (this stays
# correct when context_length changes, unlike a hard-coded offset).
play_prediction(data_samples, indices[n - 1:].int(), sample_rate, torch_sample)

tensor([7500, 7501, 7502, 7503], dtype=torch.int32)
/tmp/maestro_bank_samples/ind_07500_MIDI-Unprocessed_06_R1_2006_01-04_ORIG_MID--AUDIO_06_R1_2006_04_Track04_wav_context_36_sample4.wav
/tmp/maestro_bank_samples/ind_07501_MIDI-Unprocessed_06_R1_2006_01-04_ORIG_MID--AUDIO_06_R1_2006_04_Track04_wav_context_36_sample5.wav
/tmp/maestro_bank_samples/ind_07502_MIDI-Unprocessed_06_R1_2006_01-04_ORIG_MID--AUDIO_06_R1_2006_04_Track04_wav_context_36_sample6.wav
/tmp/maestro_bank_samples/ind_07503_MIDI-Unprocessed_06_R1_2006_01-04_ORIG_MID--AUDIO_06_R1_2006_04_Track04_wav_context_36_sample7.wav


In [19]:
# Save the generated audio together with the attention weights of the last forward pass.
import torchaudio
import os
import json
save = True
#save = False
save_path_root = "{}/outputs".format(os.getcwd())
# Running number that distinguishes several generations from the same prompt.
gen = 1
if save:
    # One folder per prompt file; exist_ok avoids the check-then-create race.
    out_dir = "{}/{}".format(save_path_root, test_fn)
    os.makedirs(out_dir, exist_ok=True)
    # Common stem so the audio file and both attention dumps share one name.
    out_stem = "{}/{}-gen{}-sec{}".format(out_dir, test_fn, gen, start_sec)
    # The first n - 1 entries of `indices` belong to the prompt; the rest was generated.
    save_prediction(out_stem + ".wav", data_samples, indices[n - 1:].int(), sample_rate, torch_sample)
    torch.save(attn, out_stem + "-attn_weights.pt")
    with open(out_stem + "-attn_weights.json", "w") as f:
        json.dump(attn.cpu().detach().numpy().tolist(), f)

tensor([7500, 7501, 7502, 7503], dtype=torch.int32)
/tmp/maestro_bank_samples/ind_07500_MIDI-Unprocessed_06_R1_2006_01-04_ORIG_MID--AUDIO_06_R1_2006_04_Track04_wav_context_36_sample4.wav
/tmp/maestro_bank_samples/ind_07501_MIDI-Unprocessed_06_R1_2006_01-04_ORIG_MID--AUDIO_06_R1_2006_04_Track04_wav_context_36_sample5.wav
/tmp/maestro_bank_samples/ind_07502_MIDI-Unprocessed_06_R1_2006_01-04_ORIG_MID--AUDIO_06_R1_2006_04_Track04_wav_context_36_sample6.wav
/tmp/maestro_bank_samples/ind_07503_MIDI-Unprocessed_06_R1_2006_01-04_ORIG_MID--AUDIO_06_R1_2006_04_Track04_wav_context_36_sample7.wav


torch.Size([1, 4, 7, 7])