In [None]:
import random
import os
import numpy as np
from datasets import load_dataset
import torch
import note_seq
import sys
sys.path.append("..")
from source.tokenizer import Tokenizer
from source.transformer import Transformer
from source.noteseqhelpers import token_sequence_to_note_sequence

In [None]:
# Dataset identifier handed to load_dataset() in the next cell.
dataset_id = "TristanBehrens/js-fakes-4bars"
# Directory that the tokenizer config ("tokenizer.json") and the checkpoint
# ("ckpt.pt") are read from later; fail early here if it is missing.
model_path = "../models/transformer_variational_20230930-1235"
assert os.path.exists(model_path), "Model path does not exist."

In [None]:
# Load the dataset named by dataset_id through the `datasets` library.
# The result is indexed by split name below; only "test" is used here.
split_dataset = load_dataset(dataset_id)

def random_sample():
    """Return the ``"text"`` field of one randomly chosen test-split row.

    Reads the module-level ``split_dataset`` and draws the row index with
    the global ``random`` generator.
    """
    test_split = split_dataset["test"]
    last_index = len(test_split) - 1
    # randint() includes both endpoints, so the upper bound is the last index.
    chosen_row = test_split[random.randint(0, last_index)]
    return chosen_row["text"]

# Draw one sample; as the cell's last expression it is shown as the output.
random_sample()

In [None]:
# Build the tokenizer from the config file stored inside model_path,
# then print its vocabulary for inspection.
tokenizer_config_path = os.path.join(model_path, "tokenizer.json")
tokenizer = Tokenizer.from_config_file(tokenizer_config_path)
print(tokenizer.vocabulary)

In [None]:
# Restore the model from the checkpoint file stored inside model_path.
checkpoint_path = os.path.join(model_path, "ckpt.pt")
model = Transformer.load(checkpoint_path)

In [None]:
# Draw one latent vector from a standard normal distribution. The leading 1
# is prepended to the bottleneck shape reported by the model.
bottleneck_shape = model.get_bottleneck_shape()
bottleneck_z = torch.randn(1, *bottleneck_shape)
print(f"bottleneck shape: {bottleneck_shape}, numbers {np.prod(bottleneck_shape)}")

# Seed the generation with the start token only.
start_sequence = "PIECE_START"
start_sequence_indices = tokenizer.encode_sequence(start_sequence)
print(f"Start sequence: {start_sequence}")
print(f"Start sequence indices: {start_sequence_indices}")

# All generation settings in one place so they are easy to spot and tweak.
generation_kwargs = {
    "idx": start_sequence_indices,
    "max_new_tokens": 512,
    "end_token_id": tokenizer.encode_token("TRACK_END"),
    "bottleneck_condition": bottleneck_z,
    "temperature": 0.2,
    "top_k": None,
}

# Generate conditioned on the sampled latent and keep the first returned item.
generated = model.generate(**generation_kwargs)
result_ids = generated[0]
print(f"Result ids: {result_ids}")

# Map the generated ids back to tokens.
result_sequence = tokenizer.decode_sequence(result_ids, join=True)
print(f"Result sequence: {result_sequence}")

# Convert the tokens to a note sequence, then plot it and play it back
# using note_seq's fluidsynth synth.
note_sequence = token_sequence_to_note_sequence(result_sequence)
note_seq.plot_sequence(note_sequence)
note_seq.play_sequence(note_sequence, synth=note_seq.fluidsynth)