## Import Needed Libraries


In [1]:
import re
import io
import os, sys
import requests

import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from pathlib import Path
from torch.utils.data import DataLoader

In [None]:
# Prefer the second GPU (the original hard-coded choice), but degrade gracefully so the
# notebook still runs on single-GPU and CPU-only machines.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

## Dalle as Image Encoder
#### Download VQVAE from DALLE
| testing usage
```python
enc = encoder
dec = decoder
```

In [None]:
from PIL import Image
from dall_e import map_pixels, unmap_pixels, load_model
from IPython.display import display, display_markdown

In [None]:
# Side length (in pixels) of the square image produced by `preprocess`; smaller inputs are rejected there.
target_image_size = 256

def download_image(url, timeout=30):
    """
    Download an image over HTTP and open it with PIL.

    :param url: Address of the image to fetch.
    :param timeout: Seconds to wait for the server before giving up. Without a timeout
        `requests.get` can block indefinitely and stall the whole notebook.
    :return: The opened `PIL.Image.Image`, decoded from the in-memory response body.
    :raises requests.HTTPError: If the server answers with an error status.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return Image.open(io.BytesIO(resp.content))

def preprocess(img):
    """
    Turn a PIL image into a single-image batch ready for the DALL-E encoder.

    The image is converted to RGB, resized so its shorter side equals
    `target_image_size`, center-cropped to a square, converted to a tensor with a
    leading batch axis, and finally passed through `map_pixels`.

    :param img: Input PIL image.
    :return: The result of `map_pixels` applied to the batched image tensor.
    :raises ValueError: If the shorter side of the image is below `target_image_size`.
    """
    # Downloaded images may be RGBA, grayscale or palettized; the notebook later renders
    # the tensor with mode='RGB', which needs exactly three channels.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    shorter_side = min(img.size)

    if shorter_side < target_image_size:
        raise ValueError(f'min dim for image {shorter_side} < {target_image_size}')

    scale = target_image_size / shorter_side
    # PIL reports size as (width, height) while TF.resize expects (height, width).
    new_hw = (round(scale * img.size[1]), round(scale * img.size[0]))
    img = TF.resize(img, new_hw, interpolation=Image.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    img = torch.unsqueeze(T.ToTensor()(img), 0)
    return map_pixels(img)

In [None]:
# Load the pretrained DALL-E encoder from OpenAI's CDN onto `device`.
# The matching decoder load is kept below but disabled.
enc = load_model("https://cdn.openai.com/dall-e/encoder.pkl", device)
# dec = load_model("https://cdn.openai.com/dall-e/decoder.pkl", device)

In [None]:
# Fetch a sample image, preprocess it, and show the preprocessed result.
x = preprocess(download_image('https://assets.bwbx.io/images/users/iqjWHBFdfxIU/iKIWgaiJUtss/v2/1000x-1.jpg'))
# raw=True tells IPython the string itself is markdown source; without it the string is
# treated as a Python object to format and is not rendered as markdown.
display_markdown('Original image:', raw=True)
display(T.ToPILImage(mode='RGB')(x[0]))

In [None]:
# Number of discrete codes the image encoder can emit. The rest of the notebook refers to
# this value as `imageCodebook_len`, so define that name here; `imageVocab_len` is kept as
# an alias for backward compatibility.
imageCodebook_len = enc.vocab_size
imageVocab_len = imageCodebook_len

In [None]:
imageCodebook_len

In [None]:
def process_image_with_encoder(image):
    """
    Encode an image batch with `enc` and return its tokens as one-hot float planes.

    The encoder output is reduced with an argmax over axis 1, the resulting ids are
    one-hot encoded with `imageCodebook_len` classes, and the class axis is moved from
    the last position to position 1.

    :param image: Input passed directly to `enc`.
    :return: Float tensor holding the permuted one-hot encoding.
    """
    token_ids = torch.argmax(enc(image), axis=1)
    one_hot = F.one_hot(token_ids, num_classes=imageCodebook_len)
    return one_hot.permute(0, 3, 1, 2).float()

## VQGAN as Image Encoder

In [None]:
# from vqgan_jax.modeling_flax_vqgan import VQModel
# from transformers import VQGanForPreTraining
# from transformers import VQGanProcessor

# Load the pre-trained VQGAN model and its processor
# checkpoint = "dalle-mini/vqgan_imagenet_f16_16384"
# model = VQModel.from_pretrained(checkpoint)
# processor = VQGanProcessor.from_pretrained(checkpoint)


In [None]:
# def download_image(url):
#     resp = requests.get(url)
#     resp.raise_for_status()
#     return Image.open(io.BytesIO(resp.content))

# def preprocess_vqgan(x):
#   x = 2.*x - 1.
#   return x

# def custom_to_pil(x):
#   x = np.clip(x, -1., 1.)
#   x = (x + 1.)/2.
#   x = (255*x).astype(np.uint8)
#   x = Image.fromarray(x)
#   if not x.mode == "RGB":
#     x = x.convert("RGB")
#   return x

# def preprocess(img, target_image_size=256,):
#     s = min(img.size)
    
#     if s < target_image_size:
#         raise ValueError(f'min dim for image {s} < {target_image_size}')
        
#     r = target_image_size / s
#     s = (round(r * img.size[1]), round(r * img.size[0]))
#     img = TF.resize(img, s, interpolation=Image.LANCZOS)
#     img = TF.center_crop(img, output_size=2 * [target_image_size])
#     img = torch.unsqueeze(T.ToTensor()(img), 0)
#     return img.permute(0, 2, 3, 1)

In [None]:
# import numpy as np
# from torchvision.transforms import InterpolationMode
# def resize_image(image, size=256):
#     s = min(image.size)
#     r = size / s
#     s = (round(r * image.size[1]), round(r * image.size[0]))
#     image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)
#     image = TF.center_crop(image, output_size = 2 * [size])
#     image = np.expand_dims(np.array(image), axis=0)
#     return image

In [None]:
# url='https://heibox.uni-heidelberg.de/f/7bb608381aae4539ba7a/?dl=1'
# size=256
# image = download_image(url)
# image = resize_image(image)
# image.shape

In [None]:
# display(T.ToPILImage(mode='RGB')(image[0]))

In [None]:
# _, id = model.encode(image)

In [None]:
# enc = model.encode

In [None]:
# imageCodebook_len = 16384

## MIDITOK as MIDI Encoder

In [2]:
from miditok import REMIPlus, TokenizerConfig, REMI
from miditoolkit import MidiFile

In [20]:
# --- Note / timing resolution -------------------------------------------------
PITCH_RANGE = (21, 109)
BEAT_RES = {(0, 1): 8, (1, 2): 4, (2, 4): 2, (4, 8): 1}
NUM_VELOCITIES = 24

# --- Vocabulary extras --------------------------------------------------------
SPECIAL_TOKENS = ["PAD", "MASK", "BOS", "EOS"]

# --- Optional token families --------------------------------------------------
USE_CHORDS = True
USE_RESTS = False
USE_TEMPOS = True
USE_TIME_SIGNATURE = False
USE_PROGRAMS = True

# --- Tempo quantisation -------------------------------------------------------
NUM_TEMPOS = 32
TEMPO_RANGE = (50, 200)  # (min_tempo, max_tempo)

# Keyword arguments forwarded verbatim to TokenizerConfig.
TOKENIZER_PARAMS = dict(
    pitch_range=PITCH_RANGE,
    beat_res=BEAT_RES,
    num_velocities=NUM_VELOCITIES,
    special_tokens=SPECIAL_TOKENS,
    use_chords=USE_CHORDS,
    use_rests=USE_RESTS,
    use_tempos=USE_TEMPOS,
    use_time_signatures=USE_TIME_SIGNATURE,
    use_programs=USE_PROGRAMS,
    num_tempos=NUM_TEMPOS,
    tempo_range=TEMPO_RANGE,
)
config = TokenizerConfig(**TOKENIZER_PARAMS)

In [21]:
# Build the REMI tokenizer from the configuration assembled in the previous cell.
midi_tokenizer = REMI(config)

In [None]:
# Load one MAESTRO recording (path is relative to the notebook) and tokenize it as a smoke test.
midi = MidiFile("../data/midi/Maestro/2004/MIDI-Unprocessed_SMF_02_R1_2004_01-05_ORIG_MID--AUDIO_02_R1_2004_05_Track05_wav.midi")
tokens = midi_tokenizer(midi)

In [None]:
# Take the integer ids of the first returned sequence and preview the first 20 of them.
# NOTE(review): assumes the tokenizer returns an indexable collection of sequences; with
# USE_PROGRAMS enabled it may return a single sequence instead — confirm for the installed miditok.
token_ids = tokens[0].ids
print(token_ids[:20])

In [22]:
# Size of the MIDI tokenizer vocabulary; used later as the one-hot width and as the mapper's input dimension.
midi_vocab_len = len(midi_tokenizer.vocab)
print(f"midi has {midi_vocab_len} vocabularies")

midi has 344 vocabularies


In [None]:
# Build the index tensor directly as int64: `torch.Tensor(...)` would first create a
# float32 tensor and only then cast, which is indirect and lossy for very large ids.
one_hot_midi_tokens = F.one_hot(torch.tensor(token_ids, dtype=torch.long), num_classes=midi_vocab_len)
print(one_hot_midi_tokens.shape) # [seq_len, num_classes]

## Text LLM

In [None]:
# from transformers import LlamaTokenizer, LlamaForCausalLM
# import transformers
# import torch

# llm = "meta-llama/Llama-2-7b-hf"
# model = LlamaForCausalLM.from_pretrained(llm)
# tokenizer = LlamaTokenizer.from_pretrained(llm)

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load pre-trained GPT-2 model and tokenizer
# `llm` is the checkpoint name; it is also embedded in the experiment name further down.
llm = "gpt2"
model = GPT2LMHeadModel.from_pretrained(llm)
llm_tokenizer = GPT2Tokenizer.from_pretrained(llm)

In [None]:
# Use the LM head's weight matrix as the token table for similarity lookups.
embeddings = model.lm_head.weight
# embedding_matrix = model.transformer.wte.weight
# Hidden width and vocabulary size as reported by the model config.
llm_feature_dim = model.config.hidden_size
llm_vocab_len = model.config.vocab_size
# Move the LLM to the target device and switch it to evaluation mode.
model.to(device)
model.eval()

In [None]:
# embeddings = embeddings.to(device)

In [None]:
print("gpt2 feature dim length:", llm_feature_dim)
print("gpt2 vocabulary length:", llm_vocab_len)
print("gpt2 embedding shape:", embeddings.shape)

In [None]:
def forward_llm_with_embeddings(embeddings):
    """
    Forward pass through GPT-2 for sequential token prediction logits from embeddings.

    Position ``i`` of the result holds the logits predicted from the embeddings at
    positions ``0..i``. GPT-2 attention is causal, so one pass over the full sequence
    yields the same per-position logits (up to floating-point rounding) as re-running the
    model on every prefix, at a fraction of the cost.

    No gradient flows back through the LLM or into the input: the input is detached and
    the model runs under ``torch.no_grad()``. The returned tensor is a fresh leaf with
    ``requires_grad=True``.

    :param embeddings: Embeddings of the sequence, shape [batch_size, seq_len, embedding_dim].
    :return: Tensor of logits for token predictions, shape [batch_size, seq_len, vocab_size].
    """
    batch_size, seq_len, _ = embeddings.size()

    # The notebook's GPT-2 instance is called `model`.
    model.eval()

    if seq_len == 0:
        # Nothing to feed the model; keep the empty-result shape contract.
        empty_logits = torch.zeros((batch_size, 0, model.config.vocab_size), device=embeddings.device)
        empty_logits.requires_grad = True
        return empty_logits

    with torch.no_grad():
        logits = model(inputs_embeds=embeddings.detach()).logits

    predicted_logits = logits.clone()
    predicted_logits.requires_grad = True

    return predicted_logits


## Mapper Network

Map tokens from another modality (image or MIDI) into the feature dimension of the LLM's text-token embeddings.

In [None]:
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd

In [None]:
class TokenMapper(nn.Module):
    """
    Single linear projection from a token vector of width ``input_dim`` to ``output_dim``.

    The projection is stored as ``self.mapper`` and moved to ``device`` on construction.
    """

    def __init__(self, input_dim, output_dim, device="cpu"):
        """
        :param input_dim: Width of the incoming token vectors.
        :param output_dim: Width of the projected vectors.
        :param device: Device the linear layer is moved to.
        """
        super().__init__()
        projection = nn.Linear(input_dim, output_dim)
        self.mapper = projection.to(device)

    def forward(self, one_hot_token):
        """Apply the linear projection to ``one_hot_token`` and return the result."""
        projected = self.mapper(one_hot_token)
        return projected

In [None]:
# Create the mapper
# mapper maps vocabulary_size of target modality to feature_dimension size of llm
# NOTE(review): the input width here is the MIDI vocabulary size, but the training loop
# further down feeds one-hot vectors of width `imageCodebook_len` into this mapper —
# confirm which modality this instance is meant for before training.
mapper = TokenMapper(midi_vocab_len, llm_feature_dim, device=device)

In [None]:
mapper

## Generate Ground Truth

In [None]:
@torch.no_grad()
def generate_next_token_predictions(token_sequences):
    """
    Run the LLM on token ids and return its last entry of ``hidden_states``.

    The whole call executes with gradient tracking disabled.

    :param token_sequences: Token ids passed to the model as ``input_ids``.
    :return: ``outputs.hidden_states[-1]`` from the model's forward pass.
    """
    llm_outputs = model(input_ids=token_sequences, output_hidden_states=True)
    final_hidden = llm_outputs.hidden_states[-1]
    return final_hidden

In [None]:
def find_closest_token_logits(batch_feature_vectors, embeddings):
    """
    Find, for every feature vector, the LLM token whose embedding has the highest dot product.
    This acts as the action of our REINFORCE algorithm.

    Only the winning token ids are returned. The argmax is taken on the raw dot products:
    softmax is monotonic, so applying it first cannot change the winner and only risks
    creating ties through floating-point saturation.

    :param batch_feature_vectors: Feature vectors, shape (batch_size, seq_len, feature_dim).
    :param embeddings: Token embedding table, shape (llm_vocabulary_size, feature_dim).
    :return: Ids of the highest-scoring tokens, shape (batch_size, seq_len).
    """
    dot_product = torch.matmul(batch_feature_vectors, embeddings.T)
    closest_tokens = torch.argmax(dot_product, dim=-1)

    return closest_tokens

In [None]:
def get_ground_truth(mapped_feature_vector, embeddings):
    """
    Return the tokens chosen by `find_closest_token_logits` for the mapped features.

    :param mapped_feature_vector: Mapped feature vectors to look up.
    :param embeddings: Token embedding table to compare against.
    :return: Whatever `find_closest_token_logits` returns for these inputs.
    """
    return find_closest_token_logits(mapped_feature_vector, embeddings)



In [None]:
def translate(batch_feature_vectors, embeddings):
    """
    Map every feature vector to the id of the embedding with the highest cosine similarity.

    The whole batch is handled in one matrix product instead of a per-sample loop, and the
    result lives on the same device as the inputs rather than on a global device.

    :param batch_feature_vectors: Feature vectors, shape [batch_size, seq_len, embedding_dim].
    :param embeddings: Token embedding table, shape [vocab_size, embedding_dim].
    :return: Long tensor of token ids, shape [batch_size, seq_len].
    """
    # Unit-normalise both sides so the dot product equals cosine similarity.
    feature_vectors_norm = F.normalize(batch_feature_vectors, dim=-1)
    embedding_matrix_norm = F.normalize(embeddings, dim=1)

    # [batch_size, seq_len, vocab_size]
    cosine_similarities = torch.matmul(feature_vectors_norm, embedding_matrix_norm.T)

    # Best-matching token for each position.
    closest_tokens = torch.argmax(cosine_similarities, dim=-1)

    return closest_tokens

## Get Image Dataset

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torchvision
from torchvision import transforms, datasets

In [None]:
# transform = transforms.Compose([
#     transforms.Resize((128, 128)),  # Resize to a fixed size; adjust as needed
#     transforms.ToTensor(),          # Convert images to PyTorch tensors
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize (mean, std) for each color channel
# ])

In [None]:
# Side length (in pixels) of the square crops produced by `resize_and_crop` for the dataset images.
image_size = 128

def resize_and_crop(img):
    """
    Scale a PIL image so its shorter side equals `image_size`, then center-crop a square.

    :param img: Input PIL image.
    :return: The resized and cropped image, `image_size` pixels on each side.
    """
    width, height = img.size
    scale = image_size / min(width, height)
    # TF.resize takes (height, width), the reverse of PIL's (width, height).
    target_hw = (round(scale * height), round(scale * width))
    resized = TF.resize(img, target_hw, interpolation=Image.LANCZOS)
    return TF.center_crop(resized, output_size=[image_size, image_size])

def modified_map_pixels(img):
    """
    Apply `map_pixels` to a single image tensor.

    A leading axis is added before the call and removed again afterwards, so the result
    has the same number of dimensions as the input.

    :param img: Image tensor without a batch axis.
    :return: The mapped tensor with the temporary leading axis squeezed out.
    """
    batched = img.unsqueeze(0)
    return map_pixels(batched).squeeze(0)

# Per-image pipeline for the dataset: resize + crop on the PIL image, convert to a tensor,
# then apply the same pixel mapping used for the DALL-E encoder.
transform = transforms.Compose(
    [
        transforms.Lambda(resize_and_crop),
        transforms.ToTensor(),
        transforms.Lambda(modified_map_pixels),
    ]
)

In [None]:
# Root directory of the LSUN dataset, relative to the notebook; change it to match your setup.
# NOTE: the name `dataset_path` is reassigned to the MIDI dataset directory in a later cell.
dataset_path = '../data/lsun'

# Only the `bedroom_train` split is loaded; every image goes through `transform`.
lsun_dataset = datasets.LSUN(root=dataset_path, classes=['bedroom_train'], transform=transform)

In [None]:
batch_size = 5  # Adjust based on your memory availability and requirements
lsun_loader = DataLoader(lsun_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
# len(loader) counts batches, not samples, so report both quantities explicitly.
print('dataset size:', len(lsun_dataset))
print('number of batches:', len(lsun_loader))

In [None]:
def imshow(img):
    """
    Display a channel-first image tensor with matplotlib.

    :param img: Image tensor laid out as (channels, height, width).
    """
    # detach + cpu so tensors that live on the GPU or track gradients can be converted too.
    npimg = img.detach().cpu().numpy()
    # matplotlib expects channels last.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

## Get MIDI Dataset

In [6]:
from miditok.pytorch_data.datasets import DatasetTok

In [7]:
# Where the raw MIDI files live and where their tokenized JSON files are written.
dataset_path = Path("../data/midi/MMD_MIDI")
tokens_path = Path("../data/midi/MMD_MIDI_no_bpe")

# Matches any path containing a component that starts with "._" after a "/".
pattern = re.compile(r"/\._")

# Recursively collect every .mid file whose path does not match the pattern.
midi_files = list(
    filter(
        lambda midi_path: not pattern.search(str(midi_path)),
        dataset_path.glob("**/*.mid"),
    )
)

In [None]:
# for f in midi_files:
#     print(f)
#     midi_tokenizer.tokenize_midi_dataset(f, tokens_path, logging=False)

In [None]:
# Tokenize every collected MIDI file and write the results under `tokens_path`.
# Long-running: the recorded progress bar below estimates roughly 12 hours for the full set.
midi_tokenizer.tokenize_midi_dataset(midi_files, tokens_path)

Tokenizing MIDIs (midi/MMD_MIDI_no_bpe):   0%|                                  | 238/433527 [00:26<11:59:31, 10.04it/s]

In [14]:
# NOTE(review): `tokens_paths` is only assigned in a later cell, so on a fresh
# top-to-bottom run this cell raises NameError; the output below comes from an earlier session.
tokens_paths

[PosixPath('../data/midi/MMD_MIDI_no_bpe/4/4/4/444cffbf1886ef55bc479a54af2527bd.json'),
 PosixPath('../data/midi/MMD_MIDI_no_bpe/4/4/4/4444a23c1c638de34eff324cd6eae1eb.json'),
 PosixPath('../data/midi/MMD_MIDI_no_bpe/4/4/4/444ccacdca976c78f542128b2a187e92.json'),
 PosixPath('../data/midi/MMD_MIDI_no_bpe/4/4/4/4440004d0c862d4126b664d77f2e3037.json'),
 PosixPath('../data/midi/MMD_MIDI_no_bpe/4/4/4/4445a9554c7b8717b79f2c8e1a36475f.json'),
 PosixPath('../data/midi/MMD_MIDI_no_bpe/4/4/4/444bc98d2f607d96933e9e9d8332fe4f.json'),
 PosixPath('../data/midi/MMD_MIDI_no_bpe/4/4/4/4444aae7668e19689dbda4f3d45cb21a.json'),
 PosixPath('../data/midi/MMD_MIDI_no_bpe/4/4/4/444ef7504e626e7877c91512fc839809.json'),
 PosixPath('../data/midi/MMD_MIDI_no_bpe/4/4/4/444d3cc75ec4fb0adfa0163c55d31d3c.json'),
 PosixPath('../data/midi/MMD_MIDI_no_bpe/4/4/4/4447b5bd432461c15bd07abef223c9e2.json'),
 PosixPath('../data/midi/MMD_MIDI_no_bpe/4/4/4/444b62c6d69a1971e691cababca9a72e.json'),
 PosixPath('../data/midi/MMD_MID

In [18]:
# Collect every tokenized JSON file produced under `tokens_path`.
tokens_paths = list(tokens_path.glob("**/*.json"))

# Build fixed-length (256-token) training samples from the token files.
# NOTE(review): the next cell's recorded output is "No data loaded", i.e. no sample survived.
# Presumably `one_token_stream=False` does not match how the tokenizer (USE_PROGRAMS enabled)
# stored the ids, or most files are shorter than `min_seq_len` — confirm.
midi_dataset = DatasetTok(tokens_paths, min_seq_len=256, max_seq_len=256, one_token_stream=False)

Loading data: ../data/midi/MMD_MIDI_no_bpe/4/4/4: 100%|███████████████████████| 184163/184163 [01:53<00:00, 1624.70it/s]


In [19]:
midi_dataset

No data loaded

In [None]:
len(midi_files)

In [None]:
# NOTE(review): `augment_midi_dataset` and `midi_valid` are not imported or defined anywhere
# in this notebook — import/define them before running this cell.

# The original call passed this path as a positional argument after keyword arguments,
# which is a SyntaxError; it is used as the output location instead.
midi_aug_path = Path("to", "new", "location", "augmented")  # placeholder destination

augment_midi_dataset(
    midi_files,
    pitch_offsets=[-12, 12],
    velocity_offsets=[-4, 5],
    duration_offsets=[-0.5, 1],
    out_path=midi_aug_path,
)
midi_tokenizer.tokenize_midi_dataset(        # 2 velocity and 1 duration values
    midi_files,
    Path("path", "to", "tokens"),  # placeholder destination for the token files
    midi_valid,
)

## REINFORCE Loss Function

In [None]:
def Reinforce_Loss(logits, targets, loss, gamma=1.0):
    """
    Weight the log-probability of each target token by a discounted sum of later losses.

    For position ``t`` the weight is ``sum_{k >= t} gamma**k * loss[:, k]`` (the exponent is
    the absolute position ``k``). The result is the sum over all positions of
    ``log_prob(target) * weight``, divided by batch size and sequence length. No sign flip
    is applied.

    :param logits: Logits from the model, shape [batch_size, seq_len, vocab_size].
    :param targets: Target token ids, shape [batch_size, seq_len].
    :param loss: Per-position loss values, shape [batch_size, seq_len].
    :param gamma: Discount factor applied per position.
    :return: Scalar tensor with the weighted, normalised sum.
    """
    batch_size, seq_len, _ = logits.shape

    # Log-probability the model assigns to each target token: [batch_size, seq_len].
    target_log_probs = (
        F.log_softmax(logits, dim=2)
        .gather(2, targets.unsqueeze(2))
        .squeeze(2)
    )

    # Row t of the matrix holds gamma**k for k >= t and zero elsewhere.
    powers = gamma ** torch.arange(seq_len).float().to(target_log_probs.device)
    discount_matrix = torch.triu(powers.unsqueeze(0).expand(seq_len, seq_len))

    # [batch_size, 1, seq_len] * [seq_len, seq_len] -> sum over k gives [batch_size, seq_len].
    cumulative_loss = (loss.unsqueeze(1) * discount_matrix).sum(dim=2)

    return torch.sum(target_log_probs * cumulative_loss) / batch_size / seq_len

## Train Model

In [None]:
# Hyper Parameters
learning_rate = 1e-4  # step size for the Adam optimizer
epochs = 1            # passes over the data loader
gamma = 0.95          # decay factor handed to the exponential LR scheduler

In [None]:
# Identifiers used to build the TensorBoard run directory and the checkpoint path.
experiment = "test"
exp_type = "image"
# The name contains "/" separators, so it resolves to nested directories.
experiment_name = f"{exp_type}/{experiment}/model={llm}_lr={learning_rate}"

In [None]:
from torch.utils.tensorboard import SummaryWriter

# Create a SummaryWriter instance (logs will be saved in 'runs' folder)
# The scalar logging inside the training loop is currently commented out; the writer is
# only closed at the end of training.
writer = SummaryWriter(f'runs/{experiment_name}')

In [None]:
# Only the mapper's parameters are optimized.
optimizer = optim.Adam(mapper.parameters(), lr=learning_rate)
# Mean-reduced cross entropy; not referenced by the training loop in this notebook.
criterion = nn.CrossEntropyLoss()
# Per-element cross entropy (no reduction), used in the training loop.
rl_criterion = nn.CrossEntropyLoss(reduction='none')
# Multiplies the learning rate by `gamma` on every scheduler step.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

In [None]:
def CrossEntropySG_Loss(mapped_feature_vector, targets, reduction='mean'):
    """
    Custom cross-entropy loss with straight-through estimator.

    The embeddings of the target tokens replace the mapped vectors in the forward pass
    while the expression keeps the mapped vectors in the graph (straight-through). The
    result is fed to the LLM via `forward_llm_with_embeddings` and scored against `targets`.

    NOTE(review): `forward_llm_with_embeddings` detaches its input and runs without
    gradients, so this loss does not appear to back-propagate into
    `mapped_feature_vector` — confirm this is intended.

    :param mapped_feature_vector: Mapper output, shape [batch_size, seq_len, embedding_dim].
    :param targets: LLM token ids with batch_size * seq_len elements in total.
    :param reduction: Reduction mode forwarded to `F.cross_entropy`.
    :return: Loss value.
    """
    batch_size, seq_len, embedding_dim = mapped_feature_vector.shape

    # Flatten once so the same ids index the embedding table and serve as the
    # cross-entropy targets (which must match the flattened predictions).
    flat_targets = targets.reshape(-1)

    closest_embeddings = embeddings[flat_targets].reshape(batch_size, seq_len, embedding_dim)

    # Forward value equals closest_embeddings; shape [batch_size, seq_len, embedding_dim].
    ste_embeddings = (closest_embeddings - mapped_feature_vector.detach()) + mapped_feature_vector

    # The LLM forward helper defined in this notebook is `forward_llm_with_embeddings`.
    predictions = forward_llm_with_embeddings(ste_embeddings)
    predictions = predictions.reshape(batch_size * seq_len, -1)

    # Calculate cross-entropy loss
    loss = F.cross_entropy(predictions, flat_targets, reduction=reduction)

    return loss

In [None]:
def process_image_with_encoder_(image):
    """
    Variant for an encoder that returns a pair: one-hot encode its second output.

    The second element returned by `enc` is converted to a NumPy array, moved to `device`
    as a tensor, cast to long, and one-hot encoded with `imageCodebook_len` classes.

    :param image: Input passed directly to `enc`.
    :return: Float one-hot tensor with the class axis last.
    """
    _unused, code_indices = enc(image)
    index_tensor = torch.from_numpy(np.asarray(code_indices)).to(device)
    return F.one_hot(index_tensor.long(), num_classes=imageCodebook_len).float()

In [None]:
# NOTE(review): `ce_loss` is only assigned inside the training loop in the next cell, so
# this inspection fails with NameError on a fresh top-to-bottom run.
ce_loss.shape

In [None]:
# Train the mapper on image tokens.
# NOTE(review): the logits reshape below requires the mapper's input width to equal
# `imageCodebook_len`; make sure the mapper was constructed for the image modality.
for epoch in range(epochs):
    for i, (images, _) in enumerate(lsun_loader):

        optimizer.zero_grad()

        # for vqgan
        # images = images.permute(0, 2, 3, 1)

        # Process each image through DALL-E encoder to get image tokens.
        # The encoder output is only used through argmax, so no gradient is needed.
        with torch.no_grad():
            image_token_logits = enc(images.to(device))
        ground_truth_tokens = torch.argmax(image_token_logits, dim=1)

        # The last batch can be smaller than the configured batch size.
        current_batch_size = ground_truth_tokens.size(0)

        # one_hot appends the class axis last, so flattening the spatial axes directly
        # yields one one-hot vector per token. (Permuting the class axis forward before
        # reshaping, as before, scrambled the vectors.)
        one_hot_image_tokens = F.one_hot(ground_truth_tokens, num_classes=imageCodebook_len).float()
        flattened_tokens = one_hot_image_tokens.reshape(current_batch_size, -1, imageCodebook_len)

        ground_truth_tokens = ground_truth_tokens.reshape(-1)

        # Map tokens and get ground truth from LLM
        mapped_feature_vector = mapper(flattened_tokens)

        translated_text_tokens = translate(mapped_feature_vector, embeddings)

        # Calculate Representation of Last Layer in LLM
        final_layer_fv = generate_next_token_predictions(translated_text_tokens)

        # Calculate Logits with mapper function
        logits = torch.matmul(final_layer_fv, mapper.mapper.weight)
        logits_ = logits.reshape(-1, imageCodebook_len)

        # RL Loss
        # prediction_logits = prediction_logits.reshape(batch_size, -1, llm.vocab_len)
        ce_loss = rl_criterion(logits_, ground_truth_tokens)
        ground_truth_tokens = ground_truth_tokens.reshape(current_batch_size, -1)
        ce_loss = ce_loss.reshape(current_batch_size, -1)

        loss = Reinforce_Loss(logits, ground_truth_tokens, ce_loss)

        # Backward pass and update
        loss.backward()
        optimizer.step()

        # Log the losses
        # writer.add_scalars(
        #     "Training Metrics",
        #     {
        #         "loss": loss.item(),
        #         "cross_entropy": ce_loss[:,0].mean().item(),
        #     },
        #     epoch * len(lsun_loader) + i
        # )

        if i % 50 == 0:
            print(f"Epoch {epoch+1}, Batch {i}, Loss: {loss.item()}")

    scheduler.step()
    # `epochs` is the notebook's hyper-parameter; there is no `args` object here.
    print(f"Epoch {epoch+1}/{epochs} completed.")

# Persist the trained mapper and flush the TensorBoard writer.
Path(f"models/{exp_type}/{experiment}").mkdir(parents=True, exist_ok=True)
torch.save(mapper.state_dict(), f"models/{experiment_name}")
writer.close()

In [None]:
from pathlib import Path
# Stand-alone copy of the save step that also runs at the end of the training cell:
# create the checkpoint directory if needed and write the mapper's state dict.
Path(f"models/{exp_type}/{experiment}").mkdir(parents=True, exist_ok=True)
torch.save(mapper.state_dict(), f"models/{experiment_name}")