In [1]:
import sys
sys.path.append("..")
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from ccreaim.model import operate, ae
from ccreaim.utils import dataset, cfg_classes, audio_tools, util

In [299]:
data_tar = "/scratch/other/sopi/CCREAIM/datasets/maestro_bank_training_aug.tar"
data_tar_non_aug = "/scratch/other/sopi/CCREAIM/datasets/maestro_bank_training.tar"
samples_tar = "/scratch/other/sopi/CCREAIM/datasets/maestro_bank_samples.tar"

#load_trf_path = "/scratch/other/sopi/CCREAIM/logs/2023-03-10/bank-classifier_train_14-28-06/0/checkpoints/bank-classifier_seqlen-8_bs-32_lr-0.0001_seed-0_ep-1000.pt"
#load_trf_path = "/scratch/other/sopi/CCREAIM/logs/2023-03-15/bank-classifier_train_13-06-32/0/checkpoints/bank-classifier_seqlen-8_bs-4_lr-0.0001_seed-0_ep-030.pt"
#load_trf_path = "/scratch/other/sopi/CCREAIM/logs/2023-03-15/bank-classifier_train_15-26-42/0/checkpoints/bank-classifier_seqlen-8_bs-4_lr-0.0001_seed-0_ep-060.pt"
load_trf_path = "/scratch/other/sopi/CCREAIM/logs/2023-03-15/bank-classifier_train_15-26-42/1/checkpoints/bank-classifier_seqlen-8_bs-4_lr-0.0001_seed-0_final.pt"
context_length = 8
sample_rate = 44100

In [300]:
# Load transformer 

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

trf_checkpoint = torch.load(load_trf_path, map_location="cpu")
trf_state_dict = trf_checkpoint["model_state_dict"]
hyper_cfg_schema = OmegaConf.structured(cfg_classes.HyperConfig)
trf_conf = OmegaConf.create(trf_checkpoint["hyper_config"])
trf_hyper_cfg = OmegaConf.merge(hyper_cfg_schema, trf_conf)
trf_hyper_cfg.transformer.dropout = 0.0 # Since we're using trf.train() for no cacheing, set this to avoid dropout in inference
get_trf = operate.get_model_init_function(trf_hyper_cfg)
trf = get_trf()
trf.load_state_dict(trf_state_dict)
trf = trf.to(device)

In [207]:
# Prepare datasets
tmp_data_root_aug = dataset.prepare_dataset_on_tmp(data_tar)
data_aug = dataset.BankTransformerDataset(tmp_data_root_aug)

tmp_data_root_samples = dataset.prepare_dataset_on_tmp(samples_tar)
data_samples = dataset.AudioDataset(tmp_data_root_samples, sample_rate)

tmp_data_root_non_aug = dataset.prepare_dataset_on_tmp(data_tar_non_aug)
data_non_aug = dataset.BankTransformerDataset(tmp_data_root_non_aug)
#bank = dataset.Bank(tmp_data_root, context_length)

In [208]:
# list of indices with "mistakes" = 100, 250, 300, 350, 450
features, real_indices, _ = data_aug[100]

In [301]:
with torch.inference_mode():
    features_ = features.unsqueeze(0)
    features_ = features_.to(device)
    tgt = torch.cat(
            (
                torch.zeros_like(features_[:, 0:1, :], device=features_.device),
                features_[:, :-1, :],
            ),
            dim=1,
        )
    print(tgt.size())

    tgt_mask = util.get_tgt_mask(tgt.size(1))
    tgt_mask = tgt_mask.to(tgt.device)

    trf = trf.train()

    pred = trf(tgt, tgt_mask=tgt_mask)
    indices = pred.argmax(-1)
    print(indices)
    print(real_indices.long())
    probabilities = F.softmax(pred, dim=-1)
    print(probabilities.max(dim=2))
    print(probabilities[:,:,4075])

torch.Size([1, 8, 612])
tensor([[4048, 1785, 7393, 1699, 4236, 1701, 1702, 1703]])
tensor([8064, 8065, 8066, 8067, 8068, 8069, 8070, 8071])
torch.return_types.max(
values=tensor([[0.0059, 0.0303, 0.0381, 0.0648, 0.0645, 0.2002, 0.0259, 0.0120]]),
indices=tensor([[4048, 1785, 7393, 1699, 4236, 1701, 1702, 1703]]))
tensor([[5.8272e-08, 6.9155e-10, 7.5450e-07, 3.0644e-03, 1.0041e-05, 1.2254e-08,
         1.0613e-10, 7.2114e-08]])


In [313]:
def play_prediction(dataset, indices, sample_rate, sample=None):
    print(indices)
    for ind in indices:
        if sample == None:
            sample, name = dataset[ind.item()]
        else:
            next_sample, name = dataset[ind.item()]
            sample = torch.cat((sample, next_sample),dim=1)
        print(name)
    audio_tools.play_audio(sample, sample_rate)
        
    
play_prediction(data_samples, indices.int().squeeze()[1:], sample_rate)

tensor([   0,    0, 4355, 4356, 7421, 1414], dtype=torch.int32)
/tmp/maestro_bank_samples/ind_00000_MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_01_Track01_wav_context_0_sample0.wav
/tmp/maestro_bank_samples/ind_00000_MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_01_Track01_wav_context_0_sample0.wav
/tmp/maestro_bank_samples/ind_04355_MIDI-Unprocessed_02_R1_2006_01-04_ORIG_MID--AUDIO_02_R1_2006_02_Track02_wav_context_166_sample3.wav
/tmp/maestro_bank_samples/ind_04356_MIDI-Unprocessed_02_R1_2006_01-04_ORIG_MID--AUDIO_02_R1_2006_02_Track02_wav_context_166_sample4.wav
/tmp/maestro_bank_samples/ind_07421_MIDI-Unprocessed_06_R1_2006_01-04_ORIG_MID--AUDIO_06_R1_2006_04_Track04_wav_context_26_sample5.wav
/tmp/maestro_bank_samples/ind_01414_MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_01_Track01_wav_context_176_sample6.wav


In [312]:
def save_prediction(path, dataset, indices, sample_rate, sample=None):
    print(indices)
    for ind in indices:
        if sample == None:
            sample, name = dataset[ind.item()]
        else:
            next_sample, name = dataset[ind.item()]
            sample = torch.cat((sample, next_sample),dim=1)
        print(name)
    torchaudio.save(path, sample, sample_rate)

In [90]:
# THIS DOESN'T WORK WITH THE AUGMENTED DATASET
sample_id = 450
features1, real_indices1, _ = data_aug[sample_id]
features2, real_indices2, _ = data_aug[sample_id+1]

In [91]:
# Transformer has seen 
# [80, 81, 82, 83, 84, 85, 86, 87]
# and 
# [88, 89, 90, 91, 92, 93, 94, 95]
# But let's show it 
# [84, 85, 86, 87, 88, 89, 90, 91]

# Result: The transformer can't even predict 88 from 85,86,87
# so it hasn't really even learned anything :D 
with torch.inference_mode():
    features3 = torch.cat(
                  (
                      features1[context_length//2:context_length, :],
                      features2[0:context_length//2, :],
                  ),
                  dim=0,
        )
    print(features3.size())
    features3_ = features3.unsqueeze(0)
    features3_ = features3_.to(device)
    tgt = torch.cat(
            (
                torch.zeros_like(features3_[:, 0:1, :], device=features3_.device),
                features3_[:, :-1, :],
            ),
            dim=1,
        )
    print(tgt.size())

    tgt_mask = util.get_tgt_mask(tgt.size(1))
    tgt_mask = tgt_mask.to(tgt.device)

    trf = trf.train()

    pred = trf(tgt, tgt_mask=tgt_mask)

    indices = pred.argmax(-1)
    print(real_indices1)
    print(real_indices2)
    print(indices)

torch.Size([8, 612])
torch.Size([1, 8, 612])
tensor([4112., 4113., 4114., 4115., 4116., 4117., 4118., 4119.])
tensor([4112., 4113., 4114., 4115., 4116., 4117., 4118., 4119.])
tensor([[7960, 7689, 4306, 7859, 7444, 4109, 4198, 4207]])


In [94]:
# This hacky function works only with the non-augmented dataset, can be used for POC generation
def get_features_from_index(ind, dataset, context_length):
    context, indices, _ = dataset[ind//context_length]
    assert ind == indices[ind % context_length]
    return context[ind % context_length]
    

In [132]:
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k >0: keep only top k tokens with highest probability (top-k filtering).
            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

In [304]:
def create_feature_vec_from_clip(x, Fs, frame_size, frame_step, deltas):
    F, f_names = ShortTermFeatures.feature_extraction(
        x, Fs, frame_size * Fs, frame_step * Fs, deltas=deltas
    )
    features = torch.tensor(F).view(-1)
    return features

In [319]:
# Returns both a scipy.io wavfile output and torchaudio.load output
def return_clip_from_path(path, starting_point_sec, clip_length_sec):
    torch_data, samp_rate = torchaudio.load(path)
    [Fs, x] = audioBasicIO.read_audio_file(path)
    
    if len(x.shape) != 1:
        x = numpy.average(x, axis=1).astype(x.dtype)
    if len(torch_data.shape) != 1:
        torch_data = torch.mean(torch_data, dim=0)

    torch_clip = torch_data[samp_rate*starting_point_sec : samp_rate*(starting_point_sec + clip_length_sec)]
    clip = x[samp_rate*starting_point_sec : samp_rate*(starting_point_sec + clip_length_sec)]
    return clip, torch_clip.unsqueeze(0), samp_rate

In [345]:
from pyAudioAnalysis import ShortTermFeatures, audioBasicIO
import numpy
frame_size = 0.2
frame_step = 0.1
deltas = True
latent_dim = 612 # number of features returned per 1 second clip by ShortTermFeatures

external_sample_path = "/scratch/other/sopi/CCREAIM/datasets/maestro-v3.0.0/2018/MIDI-Unprocessed_Schubert7-9_MID--AUDIO_11_R2_2018_wav.wav"

# Generation
top_p = 0
top_k = 5
temperature = 1.0

external = True
start_sec = 190

sample_id = 1385

n = context_length // 2
torch_sample = None
with torch.inference_mode():
    # Pick half of context length, generate rest
    if external:
        features_ = torch.zeros(1, n, latent_dim)
        indices = torch.zeros(n)
        for i in range(n):
            clip, torch_clip, samp_rate = return_clip_from_path(external_sample_path, start_sec+i, 1)
            feature_vec = create_feature_vec_from_clip(clip, samp_rate, frame_size, frame_step, deltas)
            features_[:,i,:] = feature_vec
            if torch_sample == None:
                torch_sample = torch_clip
            else:
                torch_sample = torch.cat((torch_sample, torch_clip),dim=1)
    else:
        features, indices, _ = data_aug[sample_id]
        features_ = features.unsqueeze(0)
    features_ = features_.to(device)
    
    features_half = features_[:, :n, :]
    tgt = torch.cat(
            (
                torch.zeros_like(features_half[:, 0:1, :], device=features_.device),
                features_half[:, :-1, :],
            ),
            dim=1,
        )
    print(tgt.size())
    
    # No cacheing
    trf = trf.train()
    indices = indices[0:(n-1)]
    print(indices)
    for i in range(n):
        trf_out = trf(tgt)
        trf_out_filtered = top_k_top_p_filtering(trf_out[:,-1,:].squeeze()/temperature, top_k=top_k, top_p=top_p)
        probabilities = F.softmax(trf_out_filtered)
        emb_ind = torch.multinomial(probabilities,1)
        
        #emb_ind = trf_out[:,-1,:].argmax(-1)
        indices = torch.cat((indices, emb_ind), dim=0)
        next_feature = get_features_from_index(emb_ind.item(), data_non_aug, context_length)
        tgt = torch.cat(
                (
                    tgt,
                    next_feature.unsqueeze(0).unsqueeze(0)
                ),
                dim=1
            )
        probabilities = F.softmax(trf_out, dim=-1)
        print(probabilities.max(dim=2)[0][-1][-1])
        print(probabilities.max(dim=2)[1][-1][-1])
    print(indices)

torch.Size([1, 4, 612])
tensor([0., 0., 0.])
tensor(0.0402)
tensor(8003)
tensor(0.8507)
tensor(1516)
tensor(0.3625)
tensor(1517)
tensor(0.1356)
tensor(1518)
tensor([   0.,    0.,    0., 1515., 1516., 1517., 1518.])


  probabilities = F.softmax(trf_out_filtered)


In [346]:
play_prediction(data_samples, indices[3:].int(), sample_rate, torch_sample)

tensor([1515, 1516, 1517, 1518], dtype=torch.int32)
/tmp/maestro_bank_samples/ind_01515_MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_01_Track01_wav_context_189_sample3.wav
/tmp/maestro_bank_samples/ind_01516_MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_01_Track01_wav_context_189_sample4.wav
/tmp/maestro_bank_samples/ind_01517_MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_01_Track01_wav_context_189_sample5.wav
/tmp/maestro_bank_samples/ind_01518_MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_01_Track01_wav_context_189_sample6.wav


In [347]:
import torchaudio
save = True
#save = False
if save:
    save_prediction("bank_samples/sample6.wav", data_samples, indices[3:].int(), sample_rate, torch_sample)

tensor([1515, 1516, 1517, 1518], dtype=torch.int32)
/tmp/maestro_bank_samples/ind_01515_MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_01_Track01_wav_context_189_sample3.wav
/tmp/maestro_bank_samples/ind_01516_MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_01_Track01_wav_context_189_sample4.wav
/tmp/maestro_bank_samples/ind_01517_MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_01_Track01_wav_context_189_sample5.wav
/tmp/maestro_bank_samples/ind_01518_MIDI-Unprocessed_01_R1_2006_01-09_ORIG_MID--AUDIO_01_R1_2006_01_Track01_wav_context_189_sample6.wav
