## Import Needed Libraries


In [1]:
import re
import io
import os
import sys
import math
import requests
import numpy as np
 
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from pathlib import Path
from datasets import load_dataset
from torch.utils.data import DataLoader, random_split

In [2]:
device = torch.device('cuda:0')

In [3]:
torch.manual_seed(41648)

<torch._C.Generator at 0x7f7df7aae6f0>

## MIDITOK as MIDI Encoder

In [4]:
from miditok import REMIPlus, TokenizerConfig, REMI
from miditoolkit import MidiFile

In [5]:
PITCH_RANGE = (21, 109)
BEAT_RES = {(0, 1): 8, (1, 2): 4, (2, 4): 2, (4, 8): 1}
NUM_VELOCITIES = 24
SPECIAL_TOKENS = ["PAD", "MASK", "BOS", "EOS"]
USE_CHORDS = True
USE_RESTS = False
USE_TEMPOS = True
USE_TIME_SIGNATURE = False
USE_PROGRAMS = True
NUM_TEMPOS = 32
TEMPO_RANGE = (50, 200)  # (min_tempo, max_tempo)
TOKENIZER_PARAMS = {
    "pitch_range": PITCH_RANGE,
    "beat_res": BEAT_RES,
    "num_velocities": NUM_VELOCITIES,
    "special_tokens": SPECIAL_TOKENS,
    "use_chords": USE_CHORDS,
    "use_rests": USE_RESTS,
    "use_tempos": USE_TEMPOS,
    "use_time_signatures": USE_TIME_SIGNATURE,
    "use_programs": USE_PROGRAMS,
    "num_tempos": NUM_TEMPOS,
    "tempo_range": TEMPO_RANGE,
}
config = TokenizerConfig(**TOKENIZER_PARAMS)

In [6]:
midi_tokenizer = REMI(config)

In [7]:
midi = MidiFile("../data/midi/Maestro/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi")
tokens = midi_tokenizer(midi)

In [8]:
midi_vocab_len = len(midi_tokenizer.vocab)
print(f"midi has {midi_vocab_len} vocabularies")

midi has 344 vocabularies


In [9]:
# one_hot_midi_tokens = F.one_hot(torch.Tensor(token_ids).long(), num_classes=midi_vocab_len)
# print(one_hot_midi_tokens.shape) # [seq_len, num_classes]

## Text LLM

In [10]:
# from transformers import LlamaTokenizer, LlamaForCausalLM
# import transformers
# import torch

# llm = "meta-llama/Llama-2-7b-hf"
# model = LlamaForCausalLM.from_pretrained(llm)
# tokenizer = LlamaTokenizer.from_pretrained(llm)

In [11]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load pre-trained GPT-2 model and tokenizer
llm = "gpt2"
model = GPT2LMHeadModel.from_pretrained(llm)
llm_tokenizer = GPT2Tokenizer.from_pretrained(llm)

In [12]:
embeddings = model.lm_head.weight
# embedding_matrix = model.transformer.wte.weight
llm_feature_dim = model.config.hidden_size
llm_vocab_len = model.config.vocab_size
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [13]:
for param in model.parameters():
    param.requires_grad = False

In [14]:
# embeddings = embeddings.to(device)

In [15]:
print("gpt2 feature dim length:", llm_feature_dim)
print("gpt2 vocabulary length:", llm_vocab_len)
print("gpt2 embedding shape:", embeddings.shape)

gpt2 feature dim length: 768
gpt2 vocabulary length: 50257
gpt2 embedding shape: torch.Size([50257, 768])


## Mapper Network

map some modality to text token's feature dimension

In [16]:
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd

In [17]:
# ONLY FOR BPE
# midi_vocab_len = 10000

In [18]:
class TokenMapper(nn.Module):
    def __init__(self, input_dim, output_dim, device="cpu"):
        super().__init__()
        self.mapper = nn.Linear(input_dim, output_dim, bias=False)
        self.mapper.to(device)

    def forward(self, one_hot_token):
        return self.mapper(one_hot_token)

In [19]:
# Create the mapper
# mapper maps vocabulary_size of target modality to feature_dimension size of llm
# mapper = TokenMapper(midi_vocab_len, llm_feature_dim, device=device)
mapper = TokenMapper(midi_vocab_len, llm_feature_dim, device=device)

In [20]:
mapper

TokenMapper(
  (mapper): Linear(in_features=344, out_features=768, bias=False)
)

In [21]:
reverseMapper = TokenMapper(midi_vocab_len, llm_feature_dim, device=device)

In [22]:
reverseMapper

TokenMapper(
  (mapper): Linear(in_features=344, out_features=768, bias=False)
)

## Prompt Network
give some prompts for training

In [23]:
prompt_len = 0

In [24]:
class Prompt(nn.Module):
    def __init__(self, input_dim, output_dim, device="cpu"):
        super().__init__()
        self.model = nn.Linear(input_dim, output_dim, bias=False)
        self.model.to(device)

    def forward(self, one_hot_token):
        return self.model(one_hot_token)

In [25]:
if prompt_len!=0:
    prompt = Prompt(prompt_len, llm_feature_dim, device=device)
    prompt_inputs = F.one_hot(torch.arange(prompt_len), num_classes=prompt_len).float().to(device)

## Generate Ground Truth

In [26]:
def generate_next_token_predictions(token_sequences):
    
    # Get model predictions
    with torch.no_grad():
        outputs = model(input_ids=token_sequences, output_hidden_states=True)
    
    return outputs.hidden_states[-1]

In [27]:
def generate_next_token_predictions_withfv(token_fv):
    
    # Get model predictions
    outputs = model(inputs_embeds=token_fv, output_hidden_states=True)
    
    return outputs.hidden_states[-1]

In [28]:
def translate(batch_feature_vectors, embeddings, temperature=1.0):
    batch_size, seq_len, embedding_dim = batch_feature_vectors.shape

    # Normalize the embedding matrix
    # embedding_matrix_norm = F.normalize(embeddings, dim=1)

    # batch_feature_vector_norm = F.normalize(batch_feature_vectors, dim=2)
    # cosine_similarities = torch.matmul(batch_feature_vector_norm, embedding_matrix_norm.T)
    cosine_similarities = torch.matmul(batch_feature_vectors, embeddings.T)
    sfmx = torch.softmax(cosine_similarities/temperature, dim=2)
    closest_tokens = torch.argmax(sfmx, dim=2)
    
    mm = torch.matmul(sfmx, embeddings)
    # closest_tokens1 = torch.zeros((batch_size, seq_len), dtype=torch.float).to(device)
    # mm1 = torch.zeros((batch_size, seq_len, embeddings.size(1)), dtype=torch.float).to(device)
    
    # for i in range(batch_size):
    #     # Normalize the feature vectors for the i-th sample in the batch
    #     feature_vectors_norm = F.normalize(batch_feature_vectors[i], dim=1)

    #     # Compute cosine similarity for the entire sequence at once
    #     cosine_similarities = torch.matmul(feature_vectors_norm, embedding_matrix_norm.T)

    #     # Find the token with the highest similarity for each feature vector
    #     closest_tokens1[i] = torch.argmax(cosine_similarities, dim=1)

    #     mm1[i] = torch.matmul(torch.softmax(cosine_similarities / temperature, dim=1), embeddings)


    return mm, cosine_similarities, closest_tokens

# Get Midi Dataset

In [29]:
from miditok.pytorch_data.datasets import DatasetTok

In [30]:
# dataset_path = Path("../data/midi/MMD_MIDI")
# tokens_path = Path("../data/midi/MMD_MIDI_no_bpe")
# pattern = re.compile(r"/\._")

# # Use glob to find all .mid files and filter out the undesired ones
# midi_files = [file for file in dataset_path.glob("**/*.mid") if not pattern.search(str(file))]

In [31]:
midi_paths = list(Path('../data/midi/Maestro').glob('**/*.mid')) + list(Path('../data/midi/Maestro').glob('**/*.midi'))

In [32]:
midi_paths[0]

PosixPath('../data/midi/Maestro/2009/MIDI-Unprocessed_15_R1_2009_03-06_ORIG_MID--AUDIO_15_R1_2009_15_R1_2009_06_WAV.midi')

In [33]:
# midi_tokenizer.tokenize_midi_dataset(midi_paths, tokens_path)

In [34]:
# tokens_path = Path('../data/midi/Maestro_tokens_no_bpe')
# tokens_bpe_path = Path('../data/midi/Maestro_tokens_bpe')
# tokens_bpe_path.mkdir(exist_ok=True, parents=True)
# midi_tokenizer.learn_bpe(
#     vocab_size=10000,
#     tokens_paths=list(tokens_path.glob("**/*.json")),
#     start_from_empty_voc=False,
# )

In [35]:
# midi_tokenizer.save_params("tokenizer_bpe.conf")
# midi_tokenizer.apply_bpe_to_dataset(
#     tokens_path,
#     tokens_bpe_path,
# )

In [36]:
# tokens_paths

In [37]:
tokens_paths = list(Path('../data/midi/Maestro_tokens_no_bpe').glob("**/*.json"))
# tokens_paths = list(Path('../data/midi/Maestro_tokens_bpe').glob("**/*.json"))
midi_dataset = DatasetTok(
    tokens_paths, max_seq_len=128, min_seq_len=128
)

Loading data: ../data/midi/Maestro_tokens_no_bpe/2009: 100%|███████████████████████| 1276/1276 [00:03<00:00, 323.94it/s]


In [38]:
train_size = int(len(midi_dataset)*0.8)
val_size = len(midi_dataset) - train_size

train_dataset, validation_dataset = random_split(midi_dataset, [train_size, val_size])

In [39]:
batch_size = 10
midi_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=True)

## REINFORCE Loss Function

In [40]:
def Reinforce_Loss(logits, translated, loss, gamma=0.9, alpha=1, temperature=1):
    """
    Calculate the REINFORCE loss for sequence prediction.

    :param logits: Logits from the model, shape [batch_size, seq_len, vocab_size].
    :param targets: Ground truth sequence, shape [batch_size, seq_len].
    :param rewards: Reward for each step in the sequence, shape [batch_size, seq_len].
    :param gamma: Discount factor for future rewards.
    :return: The REINFORCE loss (to be maximized).
    """
    batch_size, seq_len, _ = logits.shape
    translated = translated.to(torch.int64)
    # shape = [batch_size, seq_len, llm_vocab_len]
    log_probs = F.log_softmax(logits/temperature, dim=-1)
    log_probs_targets = log_probs.gather(2, translated.unsqueeze(2)).squeeze(2)
    
    # Create a discount matrix
    discount_matrix = torch.zeros((seq_len, seq_len)).to(device)

    # Fill the matrix according to the given pattern
    for i in range(seq_len):
        for j in range(i, seq_len):
            discount_matrix[i, j] = gamma ** (j - i)

    normalize_factor = discount_matrix.sum(dim=1)
    
    # Calculate discounted rewards
    discounted_loss = loss.unsqueeze(1) * discount_matrix
    cumulative_loss = discounted_loss.sum(dim=-1) / normalize_factor / alpha
    
    # Calculate loss
    total_loss = torch.sum(log_probs_targets * cumulative_loss) / batch_size / seq_len

    return total_loss

## Train Model

In [41]:
# Hyper Parameters
learning_rate = 1e-5
epochs = 1
gamma = 0.1
temperature = 0.01
alpha = 1

In [42]:
experiment = "testing_nobpe_dual"
algo = "rl"
exp_type = "midi"
name = "remi"
experiment_name = f"{exp_type}/{algo}/{experiment}/{name}/{llm}/lr={learning_rate},gamma={gamma},temp={temperature},promptlen={prompt_len}"

In [43]:
experiment_name

'midi/testing_nobpe_dual/remi/gpt2/lr=1e-05,gamma=0.1,temp=0.01,promptlen=0'

In [44]:
from torch.utils.tensorboard import SummaryWriter

# # Create a SummaryWriter instance (logs will be saved in 'runs' folder)
writer = SummaryWriter(log_dir = f'../runs/{experiment_name}')

In [45]:
# optimizer = optim.Adam(mapper.parameters(), lr=learning_rate)
# multioptimizer = optim.Adam(list(mapper.parameters()) + list(prompt.parameters()), lr=learning_rate)
optimizer = optim.Adam(list(mapper.parameters()) + list(reverseMapper.parameters()), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
rl_criterion = nn.CrossEntropyLoss(reduction='none')
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

In [46]:
writer.add_hparams(
    {
        "lr": learning_rate,
        "data_type": exp_type,
        "algo": algo,
        "gamma": gamma,
        "temperature": temperature,
        "scale_rate": alpha,
        "prompt_length": prompt_len,
    },
    {},
    run_name=experiment_name
)

In [47]:
if 'base' in algo:
    print("Calculating with BASE LOSS\n")
elif 'rl' in algo:
    print("Calculating with REINFORCE LOSS\n")
    
for epoch in range(epochs):
    mapper.train()
    # mapper.eval()
    for i, midi in enumerate(midi_loader):
    
        optimizer.zero_grad()
        # multioptimizer.zero_grad()
        
        # midi_one_hot shape -> [batch_size, seq_len, feature_dim]
        ground_truth_tokens = midi["input_ids"].to(device)
        one_hot_tokens = F.one_hot(ground_truth_tokens, num_classes=midi_vocab_len).float().to(device)
        batch_len = one_hot_tokens.size(0)
        
        # break
        # Logits are to be compared with the next ground truth tokens
        ground_truth_tokens = ground_truth_tokens[:,1:]
        inputs_feature_vector = mapper(one_hot_tokens)
        
        # Add prompt to input
        # prompt_feature_vector = prompt(prompt_inputs)
        # prompt_feature_vector = prompt_feature_vector.unsqueeze(0).repeat(batch_len, 1, 1)
        # inputs_feature_vector = torch.cat((prompt_feature_vector, mapped_feature_vector), dim=1)

        # Map tokens and get ground truth from LLM
        translated_feature_vector, translated_logits, translated_text_tokens = translate(inputs_feature_vector, embeddings.detach(), temperature=temperature)
        # translated_feature_vector, translated_logits, translated_text_tokens = translate(inputs_feature_vector, embeddings.detach(), temperature=temperature)

        # Calculate Representation of Last Layer in LLM
        final_layer_fv = generate_next_token_predictions_withfv(translated_feature_vector)
        # final_layer_fv = generate_next_token_predictions(translated_text_tokens.long()).to(device)

        # Calculate Logits with mapper function
        # final_layer_fv = F.normalize(final_layer_fv, dim=-1)
        # mapper_embeds = F.normalize(mapper.mapper.weight, dim=0)
        logits = torch.matmul(final_layer_fv, reverseMapper.mapper.weight)
        # logits = logits[:,prompt_len:-1]
        logits = logits[:,:-1]
        logits_ = logits.reshape(-1, midi_vocab_len)
        ground_truth_tokens = ground_truth_tokens.reshape(-1)        
        ce_loss = criterion(logits_, ground_truth_tokens)
        ce_loss.backward()
        optimizer.step()
        if 'base' in algo:
            # ce_loss.backward()
            # optimizer.step()
            writer.add_scalar("training/cross_entropy_base", ce_loss.item(), epoch*len(midi_loader)+i)
            if i%50==0:
                print(f"Epoch {epoch+1}, Batch {i}, CE Loss: {ce_loss.mean().item()}")
        # RL Loss
        if 'rl' in algo:
            optimizer.zero_grad()
            # action_logits = torch.matmul(mapped_feature_vector, embeddings.T.detach())
            translated_feature_vector, translate_logits, translated_text_tokens = translate(inputs_feature_vector, embeddings.detach(), temperature=temperature)

            with torch.no_grad():
                final_layer_fv = generate_next_token_predictions_withfv(translated_feature_vector)
                logits = torch.matmul(final_layer_fv, reverseMapper.mapper.weight)
                logits = logits[:,prompt_len:-1]
                logits_ = logits.reshape(-1, midi_vocab_len)
                ce_loss = rl_criterion(logits_, ground_truth_tokens)
                ce_loss = ce_loss.reshape(-1, logits.size(1))
                
            rl_loss = Reinforce_Loss(translate_logits[:,1:], translated_text_tokens[:,1:].detach(), ce_loss, alpha=alpha, gamma=gamma, temperature=temperature)
            
            rl_loss.backward()
            optimizer.step()
            # writer.add_scalar("training_rl", rl_loss.item(), epoch*len(midi_loader)+i)
            # Log the losses
            writer.add_scalars(
                "training",
                {
                    "rl_loss": rl_loss.item(),
                    "cross_entropy_rl": ce_loss.mean().item(),
                },
                epoch * len(midi_loader) + i
            )
            
            if i % 50 == 0:
                print(f"Epoch {epoch+1}, Batch {i}, CE Loss: {ce_loss.mean().item()}, RL Loss: {rl_loss.item()}")

    scheduler.step()
    print(f"Epoch {epoch+1}/{epochs} completed.")
writer.close()

Calculating with REINFORCE LOSS



RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [55]:
Path(f"../models/{experiment_name}").mkdir(parents=True, exist_ok=True)
torch.save(mapper.state_dict(), f"../models/{experiment_name}/model.pt")

In [56]:
writer.close()

## Compression Testing

In [None]:
mapper.load_state_dict(torch.load(f"../models/{experiment_name}/model.pt"))
mapper.eval()

criterion = nn.CrossEntropyLoss()

In [None]:
def get_compression(total_loss, tokens, compress_bits, file_size):
    token_len = len(tokens.ids)
    return (total_loss / token_len) / math.log2(midi_vocab_len)

In [None]:
total_loss = 0
for i, midi in enumerate(val_loader):

    # midi_one_hot shape -> [batch_size, seq_len, feature_dim]
    ground_truth_tokens = midi["input_ids"].to(device)
    one_hot_tokens = F.one_hot(ground_truth_tokens, num_classes=midi_vocab_len).float().to(device)
    batch_len = one_hot_tokens.size(0)
    
    # break
    # Logits are to be compared with the next ground truth tokens
    ground_truth_tokens = ground_truth_tokens[:,1:]
    inputs_feature_vector = mapper(one_hot_tokens)
    
    # Add prompt to input
    # prompt_feature_vector = prompt(prompt_inputs)
    # prompt_feature_vector = prompt_feature_vector.unsqueeze(0).repeat(batch_len, 1, 1)
    # inputs_feature_vector = torch.cat((prompt_feature_vector, mapped_feature_vector), dim=1)

    # Map tokens and get ground truth from LLM
    translated_feature_vector, translated_logits, translated_text_tokens = translate(inputs_feature_vector, embeddings.detach(), temperature=temperature)
    # translated_feature_vector, translated_logits, translated_text_tokens = translate(inputs_feature_vector, embeddings.detach(), temperature=temperature)

    # Calculate Representation of Last Layer in LLM
    final_layer_fv = generate_next_token_predictions_withfv(translated_feature_vector)
    # final_layer_fv = generate_next_token_predictions(translated_text_tokens.long()).to(device)
          
    # Calculate Logits with mapper function
    # final_layer_fv = F.normalize(final_layer_fv, dim=-1)
    # mapper_embeds = F.normalize(mapper.mapper.weight, dim=0)
    logits = torch.matmul(final_layer_fv, mapper.mapper.weight)
    # logits = logits[:,prompt_len:-1]
    logits = logits[:,:-1]
    logits_ = logits.reshape(-1, midi_vocab_len)
    ground_truth_tokens = ground_truth_tokens.reshape(-1)

    loss = criterion(logits_, ground_truth_tokens)

    total_loss += loss.item()

print("testing loss avg:", total_loss / val_size)

In [None]:
midi_path = "../data/midi/Maestro/2009/MIDI-Unprocessed_15_R1_2009_03-06_ORIG_MID--AUDIO_15_R1_2009_15_R1_2009_06_WAV.midi"

In [None]:
midi_tokenizer

In [None]:
midi = MidiFile(midi_path)
tokens = midi_tokenizer(midi)

In [None]:
compress_bits = total_loss / math.log(2)

In [None]:
tokens

In [None]:
len(tokens.ids)

In [None]:
length = 128

In [None]:
# midi_one_hot shape -> [batch_size, seq_len, feature_dim]
total_loss = 0
for i in range(0, len(tokens.ids), length//2):
    if i + length >= len(tokens.ids):
        break
        ground_truth_tokens = torch.tensor(tokens.ids[i:]).to(device)
    else:
        ground_truth_tokens = torch.tensor(tokens.ids[i:i+length]).to(device)
    ground_truth_tokens = ground_truth_tokens.unsqueeze(dim=0)
    
    if i == 0:  
        one_hot_tokens = F.one_hot(ground_truth_tokens, num_classes=midi_vocab_len).float().to(device)
    else:
        one_hot_tokens = F.one_hot(ground_truth_tokens[:,length//2:], num_classes=midi_vocab_len).float().to(device)
        one_hot_tokens = torch.cat((logits[:,length//2:], one_hot_tokens), dim=1)
    
    # break
    # Logits are to be compared with the next ground truth tokens
    ground_truth_tokens = ground_truth_tokens[:,1:]

    
    inputs_feature_vector = mapper(one_hot_tokens)

    # if i!=0:
    #     inputs_feature_vector = torch.cat((logits[:,length//2], inputs_feature_vector), dim=1)

    # print(inputs_feature_vector.shape)
    # input('stop')
    
    # Add prompt to input
    # prompt_feature_vector = prompt(prompt_inputs)
    # prompt_feature_vector = prompt_feature_vector.unsqueeze(0).repeat(batch_len, 1, 1)
    # inputs_feature_vector = torch.cat((prompt_feature_vector, inputs_feature_vector), dim=1)

    with torch.no_grad():
        # Map tokens and get ground truth from LLM
        translated_feature_vector, translated_logits, translated_text_tokens = translate(inputs_feature_vector, embeddings.detach(), temperature=temperature)
        # translated_feature_vector, translated_logits, translated_text_tokens = translate(inputs_feature_vector, embeddings.detach(), temperature=temperature)
        
        # Calculate Representation of Last Layer in LLM
        final_layer_fv = generate_next_token_predictions_withfv(translated_feature_vector)
        # final_layer_fv = generate_next_token_predictions(translated_text_tokens.long()).to(device)
              
        # Calculate Logits with mapper function
        # final_layer_fv = F.normalize(final_layer_fv, dim=-1)
        # mapper_embeds = F.normalize(mapper.mapper.weight, dim=0)
        logits = torch.matmul(final_layer_fv, mapper.mapper.weight)
        # logits = logits[:,prompt_len:-1]
        logits_ = logits[:,:-1].reshape(-1, midi_vocab_len)
        ground_truth_tokens = ground_truth_tokens.reshape(-1)

    if i==0:
        loss = criterion(logits_, ground_truth_tokens)
        print(f"loss from tokens 0 to {i+length} = {loss.item()}")
        total_loss += loss.item()*length
    else:
        loss = criterion(logits_[length//2:], ground_truth_tokens[length//2:])
        print(f"loss from tokens {i+length//2} to {i+length} = {loss.item()}")
        total_loss += loss.item()*length/2

print("testing loss:", total_loss/len(tokens.ids))

In [None]:
get_compression(total_loss, tokens, compress_bits, file_size)

In [None]:
def get_midi_file_size(file_path):
    try:
        # Open the file in binary mode and read its contents
        with open(file_path, 'rb') as file:
            file_contents = file.read()
            # Return the size of the file
            return len(file_contents)
    except FileNotFoundError:
        return "File not found."

# Example usage
file_size = get_midi_file_size(midi_path)
# print(f"The MIDI file size is: {file_size} bytes")