## Import Needed Libraries


In [1]:
import io
import os, sys
import requests

import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from pathlib import Path

## Dalle as Image Encoder
#### Download VQVAE from DALLE
Testing usage (naming convention used below):
```python
enc = encoder
dec = decoder
```

In [None]:
from PIL import Image
from dall_e import map_pixels, unmap_pixels, load_model
from IPython.display import display, display_markdown

In [None]:
# Prefer the second GPU (the original hard-coded choice), but fall back to the
# first GPU or the CPU so the notebook still runs on machines without 'cuda:1'.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# Side length, in pixels, of the square crop produced by preprocess() below;
# preprocess() rejects images whose shorter side is smaller than this.
target_image_size = 256

def download_image(url, timeout=30):
    """
    Download an image over HTTP(S) and open it with PIL.

    :param url: Address of the image to fetch.
    :param timeout: Seconds to wait for the server before giving up. Without a
        timeout ``requests.get`` can block indefinitely and hang the kernel.
    :return: The image opened from the response bytes via ``Image.open``.
    :raises requests.HTTPError: If the server answers with an error status.
    :raises requests.Timeout: If the server does not respond within ``timeout``.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return Image.open(io.BytesIO(resp.content))

def preprocess(img):
    """
    Resize, center-crop and tensorize an image for the DALL-E encoder.

    The image is scaled so its shorter side equals ``target_image_size``,
    center-cropped to a square of that size, converted to a tensor with a
    leading batch axis of length 1, and passed through ``map_pixels``.

    :param img: Image object exposing ``size``; the code reads index 0 and 1 of it.
    :return: The result of ``map_pixels`` applied to the batched tensor.
    :raises ValueError: If the shorter side is smaller than ``target_image_size``.
    """
    s = min(img.size)
    if s < target_image_size:
        raise ValueError(f'min dim for image {s} < {target_image_size}')

    scale = target_image_size / s
    # The resize target is built from size[1] first, then size[0].
    resized_dims = (round(scale * img.size[1]), round(scale * img.size[0]))
    resized = TF.resize(img, resized_dims, interpolation=Image.LANCZOS)
    cropped = TF.center_crop(resized, output_size=[target_image_size, target_image_size])

    batched = T.ToTensor()(cropped).unsqueeze(0)
    return map_pixels(batched)

In [None]:
# Load the DALL-E encoder weights from OpenAI's CDN onto `device`.
enc = load_model("https://cdn.openai.com/dall-e/encoder.pkl", device)
# Loading the matching decoder is left disabled.
# dec = load_model("https://cdn.openai.com/dall-e/decoder.pkl", device)

In [None]:
# Sanity check: download one sample image, run it through preprocess(), and
# display the first (only) element of the resulting batch.
# NOTE(review): `x` has already been passed through map_pixels, so the picture
# shown is the mapped tensor rather than the untouched download.
x = preprocess(download_image('https://assets.bwbx.io/images/users/iqjWHBFdfxIU/iKIWgaiJUtss/v2/1000x-1.jpg'))
display_markdown('Original image:')
display(T.ToPILImage(mode='RGB')(x[0]))

In [None]:
# Vocabulary size reported by the loaded encoder's `vocab_size` attribute.
imageVocab_len = enc.vocab_size

In [None]:
# `imageCodebook_len` is used later as the one-hot width and as the mapper's
# input size, but no active cell assigned it (the only assignment lives in a
# commented-out VQGAN cell), so a fresh kernel raised NameError here.
# Bind it to the encoder's vocabulary size, then display it.
imageCodebook_len = enc.vocab_size
imageCodebook_len

In [None]:
def process_image_with_encoder(image):
    """
    Encode an image batch into one-hot image tokens.

    Runs the global encoder ``enc``, takes the argmax over axis 1 of its
    output, one-hot encodes the winning indices with ``imageCodebook_len``
    classes, and moves the class axis from last position to axis 1.

    :param image: Input accepted by ``enc``.
    :return: Float tensor; the 4-axis permute implies the token-id tensor is 3-D.
    """
    token_ids = enc(image).argmax(dim=1)
    one_hot_tokens = F.one_hot(token_ids, num_classes=imageCodebook_len)
    return one_hot_tokens.permute(0, 3, 1, 2).float()

## VQGAN as Image Encoder

In [None]:
# from vqgan_jax.modeling_flax_vqgan import VQModel
# from transformers import VQGanForPreTraining
# from transformers import VQGanProcessor

# Load the pre-trained VQGAN model and its processor
# checkpoint = "dalle-mini/vqgan_imagenet_f16_16384"
# model = VQModel.from_pretrained(checkpoint)
# processor = VQGanProcessor.from_pretrained(checkpoint)


In [None]:
# def download_image(url):
#     resp = requests.get(url)
#     resp.raise_for_status()
#     return Image.open(io.BytesIO(resp.content))

# def preprocess_vqgan(x):
#   x = 2.*x - 1.
#   return x

# def custom_to_pil(x):
#   x = np.clip(x, -1., 1.)
#   x = (x + 1.)/2.
#   x = (255*x).astype(np.uint8)
#   x = Image.fromarray(x)
#   if not x.mode == "RGB":
#     x = x.convert("RGB")
#   return x

# def preprocess(img, target_image_size=256,):
#     s = min(img.size)
    
#     if s < target_image_size:
#         raise ValueError(f'min dim for image {s} < {target_image_size}')
        
#     r = target_image_size / s
#     s = (round(r * img.size[1]), round(r * img.size[0]))
#     img = TF.resize(img, s, interpolation=Image.LANCZOS)
#     img = TF.center_crop(img, output_size=2 * [target_image_size])
#     img = torch.unsqueeze(T.ToTensor()(img), 0)
#     return img.permute(0, 2, 3, 1)

In [None]:
# import numpy as np
# from torchvision.transforms import InterpolationMode
# def resize_image(image, size=256):
#     s = min(image.size)
#     r = size / s
#     s = (round(r * image.size[1]), round(r * image.size[0]))
#     image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)
#     image = TF.center_crop(image, output_size = 2 * [size])
#     image = np.expand_dims(np.array(image), axis=0)
#     return image

In [None]:
# url='https://heibox.uni-heidelberg.de/f/7bb608381aae4539ba7a/?dl=1'
# size=256
# image = download_image(url)
# image = resize_image(image)
# image.shape

In [None]:
# display(T.ToPILImage(mode='RGB')(image[0]))

In [None]:
# _, id = model.encode(image)

In [None]:
# enc = model.encode

In [None]:
# imageCodebook_len = 16384

## MIDITOK as MIDI Encoder

In [2]:
from miditok import REMIPlus, TokenizerConfig
from miditoolkit import MidiFile

In [3]:
# Settings handed to miditok's TokenizerConfig, spelled as keyword arguments.
TOKENIZER_PARAMS = dict(
    pitch_range=(21, 109),
    beat_res={(0, 4): 8, (4, 12): 4},
    num_velocities=32,
    special_tokens=["PAD", "BOS", "EOS", "MASK"],
    use_chords=True,
    use_rests=False,
    use_tempos=True,
    use_time_signatures=False,
    use_programs=False,
    num_tempos=32,  # number of tempo bins
    tempo_range=(40, 250),  # (min, max)
)
# Build the tokenizer configuration object from the settings above.
config = TokenizerConfig(**TOKENIZER_PARAMS)

In [4]:
# Instantiate the REMIPlus MIDI tokenizer with the configuration built above.
tokenizer = REMIPlus(config)

In [12]:
# Load one sample MIDI file from the local data directory and tokenize it.
midi = MidiFile("../data/midi/MMD_MIDI/0/0/0/00000ec8a66b6bd2ef809b0443eeae41.mid")
tokens = tokenizer(midi)

In [21]:
# Pull out the integer ids of the tokenized file and peek at the first 20.
token_ids = tokens.ids
print(token_ids[:20])

[4, 420, 189, 267, 284, 62, 124, 126, 191, 284, 50, 124, 126, 193, 284, 66, 124, 126, 195, 284]


In [24]:
# Number of entries in the MIDI tokenizer's vocabulary.
midiVocab_len = len(tokenizer.vocab)
# Display the full vocabulary mapping (the rendered output is long).
tokenizer.vocab

{'PAD_None': 0,
 'BOS_None': 1,
 'EOS_None': 2,
 'MASK_None': 3,
 'Bar_None': 4,
 'Pitch_21': 5,
 'Pitch_22': 6,
 'Pitch_23': 7,
 'Pitch_24': 8,
 'Pitch_25': 9,
 'Pitch_26': 10,
 'Pitch_27': 11,
 'Pitch_28': 12,
 'Pitch_29': 13,
 'Pitch_30': 14,
 'Pitch_31': 15,
 'Pitch_32': 16,
 'Pitch_33': 17,
 'Pitch_34': 18,
 'Pitch_35': 19,
 'Pitch_36': 20,
 'Pitch_37': 21,
 'Pitch_38': 22,
 'Pitch_39': 23,
 'Pitch_40': 24,
 'Pitch_41': 25,
 'Pitch_42': 26,
 'Pitch_43': 27,
 'Pitch_44': 28,
 'Pitch_45': 29,
 'Pitch_46': 30,
 'Pitch_47': 31,
 'Pitch_48': 32,
 'Pitch_49': 33,
 'Pitch_50': 34,
 'Pitch_51': 35,
 'Pitch_52': 36,
 'Pitch_53': 37,
 'Pitch_54': 38,
 'Pitch_55': 39,
 'Pitch_56': 40,
 'Pitch_57': 41,
 'Pitch_58': 42,
 'Pitch_59': 43,
 'Pitch_60': 44,
 'Pitch_61': 45,
 'Pitch_62': 46,
 'Pitch_63': 47,
 'Pitch_64': 48,
 'Pitch_65': 49,
 'Pitch_66': 50,
 'Pitch_67': 51,
 'Pitch_68': 52,
 'Pitch_69': 53,
 'Pitch_70': 54,
 'Pitch_71': 55,
 'Pitch_72': 56,
 'Pitch_73': 57,
 'Pitch_74': 58,
 'Pitc

## Text LLM

In [None]:
# from transformers import LlamaTokenizer, LlamaForCausalLM
# import transformers
# import torch

# llm = "meta-llama/Llama-2-7b-hf"
# model = LlamaForCausalLM.from_pretrained(llm)
# tokenizer = LlamaTokenizer.from_pretrained(llm)

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load the pre-trained GPT-2 language model and its tokenizer by name.
# NOTE: this rebinds `tokenizer`, which earlier in the notebook referred to the
# REMIPlus MIDI tokenizer; after this cell it is the GPT-2 text tokenizer.
llm = "gpt2"
model = GPT2LMHeadModel.from_pretrained(llm)
tokenizer = GPT2Tokenizer.from_pretrained(llm)

In [None]:
# Inspect the loaded model's configuration object.
model.config

In [None]:
# Display the loaded model's module structure.
model

In [None]:
# import torch.nn as nn

# embed_tokens = nn.Embedding(model.config.vocab_size, model.config.hidden_size)

In [None]:
# Inspect the weight of the language-model head; the next cell stores it as `embeddings`.
model.lm_head.weight

In [None]:
# Use the LM head weight as the token-embedding lookup table for the rest of
# the notebook.
embeddings = model.lm_head.weight
# embedding_matrix = model.transformer.wte.weight
# NOTE: despite its name, `codebook_len` holds the model's hidden size (the
# embedding width), not a number of codes; `vocab_len` is the vocabulary size.
codebook_len = model.config.hidden_size
vocab_len = model.config.vocab_size
# Move the model to the chosen device and switch it to evaluation mode.
model.to(device)
model.eval()

In [None]:
# Print the shape of the embedding table followed by the tensor itself.
print(embeddings.shape)
print(embeddings)

In [None]:
# embeddings = embeddings.to(device)

In [None]:
# Summarize the sizes taken from the GPT-2 config and the embedding table.
# "codebook length" here is the value of `codebook_len`, i.e. the hidden size.
print("gpt2 codebook length:", codebook_len)
print("gpt2 vocabulary length:", vocab_len)
print("gpt2 embedding shape:", embeddings.shape)

In [None]:
def forward_llm_with_embeddings(embeddings):
    """
    Forward pass through the LLM for next-token logits at every position.

    The original implementation referenced an undefined ``gpt2_model`` (the
    notebook's model is the global ``model``) and re-ran the network once per
    prefix. With a causal language model, position ``i`` of one full pass only
    attends to positions ``0..i``, so a single pass yields the same per-position
    logits without the quadratic cost.

    The input is detached and the pass runs under ``torch.no_grad()``, so no
    gradient flows back to ``embeddings``; the returned tensor is a fresh leaf
    with ``requires_grad=True``, as before.

    :param embeddings: Embeddings of the sequence, shape [batch_size, seq_len, embedding_dim].
    :return: Tensor of logits for token predictions, shape [batch_size, seq_len, vocab_size].
    """
    batch_size, seq_len, _ = embeddings.size()
    vocab_size = model.config.vocab_size

    # Side effect kept from the original: the model is put in evaluation mode.
    model.eval()

    if seq_len == 0:
        # Nothing to predict; mirror the original's empty zero tensor.
        predicted_logits = torch.zeros((batch_size, 0, vocab_size), device=embeddings.device)
        predicted_logits.requires_grad = True
        return predicted_logits

    with torch.no_grad():
        outputs = model(inputs_embeds=embeddings.detach())

    # Match the original output's device and default float dtype.
    predicted_logits = outputs.logits.to(device=embeddings.device, dtype=torch.get_default_dtype())
    predicted_logits.requires_grad = True

    return predicted_logits


## Changing Image tokens to Text tokens

In [None]:
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd

In [None]:
class TokenMapper(nn.Module):
    """Single linear layer that maps one-hot tokens to embedding vectors."""

    def __init__(self, input_dim, output_dim, device="cpu"):
        """
        :param input_dim: Width of the incoming (one-hot) token vectors.
        :param output_dim: Width of the produced vectors.
        :param device: Device the linear layer is moved to after creation.
        """
        super().__init__()
        linear = nn.Linear(input_dim, output_dim)
        self.mapper = linear.to(device)

    def forward(self, one_hot_token):
        """Apply the linear layer to ``one_hot_token`` and return the result."""
        mapped = self.mapper(one_hot_token)
        return mapped

In [None]:
# Create the mapper: a linear layer taking one-hot image tokens of width
# `imageCodebook_len` to vectors of width `codebook_len` (the LLM hidden size).
# Per the original note this is 8192 -> 768 for DALL-E and GPT-2.
mapper = TokenMapper(imageCodebook_len, codebook_len, device=device)

In [None]:
# Display the mapper's module structure.
mapper

## Generate Ground Truth

In [None]:
def find_closest_indices_cosine(mapped_feature_vector, batch_size=10):
    """
    For every mapped vector, find the index of the most cosine-similar row of
    the global ``embeddings`` table.

    The original referenced ``gpt2_embeddings``, a name never defined in the
    notebook; the table is the global ``embeddings``. Vectors are processed in
    chunks of ``batch_size`` so the full similarity matrix is never held at once.

    :param mapped_feature_vector: Tensor whose last axis is the embedding
        dimension and whose first two axes are used to shape the result.
    :param batch_size: Number of vectors compared against the table per chunk.
    :return: Long tensor of shape (shape[0], shape[1]) on the global ``device``.
    """
    embedding_dim = mapped_feature_vector.shape[-1]
    # reshape (not view) so non-contiguous inputs are accepted too.
    flat_vectors = mapped_feature_vector.reshape(-1, embedding_dim)

    closest_chunks = []
    for start in range(0, flat_vectors.size(0), batch_size):
        chunk = flat_vectors[start:start + batch_size]

        # Similarity of every vector in the chunk to every embedding row.
        similarities = F.cosine_similarity(chunk.unsqueeze(1), embeddings.unsqueeze(0), dim=2)

        # Highest similarity wins.
        closest_chunks.append(torch.argmax(similarities, dim=1))

    if closest_chunks:
        closest_indices = torch.cat(closest_chunks, dim=0)
    else:
        # torch.cat rejects an empty list; return an empty index tensor instead.
        closest_indices = torch.zeros(0, dtype=torch.long, device=mapped_feature_vector.device)

    # Restore the leading (batch, sequence) layout of the input.
    return closest_indices.view(mapped_feature_vector.shape[0], mapped_feature_vector.shape[1]).to(device)


In [None]:
def find_closest_gpt2_token(batch_feature_vectors):
    """
    Find, for every feature vector, the token whose embedding has the highest
    cosine similarity to it.

    The embedding table is the global ``embeddings``; the original referenced
    an undefined ``gpt2_embeddings``. Each sequence is compared in one matrix
    product instead of one Python-level ``.item()`` call per element.

    :param batch_feature_vectors: Feature vectors from the mapper,
        shape (batch_size, seq_len, embedding_dim).
    :return: Long tensor of token ids, shape (batch_size, seq_len), allocated
        on the CPU as in the original.
    """
    batch_size, seq_len, _ = batch_feature_vectors.shape
    closest_tokens = torch.zeros((batch_size, seq_len), dtype=torch.long)

    # Unit-normalize the table once; dot products then equal cosine similarity.
    embedding_matrix_norm = F.normalize(embeddings, dim=1)

    for i in range(batch_size):
        # Normalize all vectors of sequence i together.
        feature_vectors_norm = F.normalize(batch_feature_vectors[i], dim=1)

        # (seq_len, vocab) similarities for the whole sequence.
        cosine_similarities = torch.matmul(feature_vectors_norm, embedding_matrix_norm.T)

        closest_tokens[i] = torch.argmax(cosine_similarities, dim=1)

    return closest_tokens


In [None]:
def generate_next_token_predictions(token_sequences):
    """
    Run the LLM on token ids and return its last hidden-state tensor.

    Despite the name, this returns hidden states (the final entry of
    ``hidden_states``), not logits or predicted tokens. The pass runs under
    ``torch.no_grad()``.

    :param token_sequences: Token ids passed to the model as ``input_ids``.
    :return: ``outputs.hidden_states[-1]`` from the global ``model``.
    """
    with torch.no_grad():
        llm_outputs = model(input_ids=token_sequences, output_hidden_states=True)
        final_hidden = llm_outputs.hidden_states[-1]

    return final_hidden

In [None]:
def find_closest_token_logits(batch_feature_vectors, embeddings):
    """
    Find the LLM token whose embedding has the highest dot product with each
    feature vector. This acts as the action of the REINFORCE algorithm.

    Only the winning token ids are returned. The earlier docstring also
    promised a probability tensor that was never returned. Softmax preserves
    ordering, so the argmax is taken on the raw dot products; this also avoids
    ties introduced by softmax rounding.

    :param batch_feature_vectors: Tensor of shape (..., embedding_dim).
    :param embeddings: Embedding table of shape (vocab_size, embedding_dim).
    :return: Long tensor of token ids with the leading shape of
        ``batch_feature_vectors`` (its last axis removed).
    """
    dot_product = torch.matmul(batch_feature_vectors, embeddings.T)
    closest_tokens = torch.argmax(dot_product, dim=-1)

    return closest_tokens

In [None]:
def get_ground_truth(mapped_feature_vector, embeddings):
    """
    Return the ground-truth token ids for the mapped vectors.

    Thin wrapper that delegates to ``find_closest_token_logits``.

    :param mapped_feature_vector: Vectors produced by the mapper.
    :param embeddings: Embedding table forwarded unchanged to the helper.
    :return: Whatever ``find_closest_token_logits`` returns for these inputs.
    """
    return find_closest_token_logits(mapped_feature_vector, embeddings)


In [None]:
def translate(batch_feature_vectors, embeddings):
    """
    Translate mapped feature vectors into the ids of their nearest tokens by
    cosine similarity.

    The result is allocated on the same device as ``batch_feature_vectors``
    instead of a global ``device``, so the function has no hidden dependency on
    notebook state. A duplicate, immediately overwritten allocation from the
    original was removed.

    :param batch_feature_vectors: Tensor of shape (batch_size, seq_len, embedding_dim).
    :param embeddings: Embedding table of shape (vocab_size, embedding_dim).
    :return: Long tensor of token ids, shape (batch_size, seq_len).
    """
    batch_size, seq_len, _ = batch_feature_vectors.shape

    # Unit-normalize the table once; dot products then equal cosine similarity.
    embedding_matrix_norm = F.normalize(embeddings, dim=1)

    closest_tokens = torch.zeros(
        (batch_size, seq_len), dtype=torch.long, device=batch_feature_vectors.device
    )

    # One sample at a time keeps the similarity matrix at (seq_len, vocab_size).
    for i in range(batch_size):
        feature_vectors_norm = F.normalize(batch_feature_vectors[i], dim=1)
        cosine_similarities = torch.matmul(feature_vectors_norm, embedding_matrix_norm.T)
        closest_tokens[i] = torch.argmax(cosine_similarities, dim=1)

    return closest_tokens

## Get Image Dataset

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torchvision
from torchvision import transforms, datasets

from torch.utils.data import DataLoader

In [None]:
# transform = transforms.Compose([
#     transforms.Resize((128, 128)),  # Resize to a fixed size; adjust as needed
#     transforms.ToTensor(),          # Convert images to PyTorch tensors
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize (mean, std) for each color channel
# ])

In [None]:
# Side length, in pixels, of the square training images produced by resize_and_crop().
image_size = 128

def resize_and_crop(img):
    """
    Scale an image so its shorter side equals ``image_size``, then center-crop
    it to an ``image_size`` x ``image_size`` square.

    :param img: Image object exposing ``size``; index 0 and 1 of it are read.
    :return: The resized and cropped image.
    """
    scale = image_size / min(img.size)
    # The resize target is built from size[1] first, then size[0].
    resized_dims = (round(scale * img.size[1]), round(scale * img.size[0]))
    resized = TF.resize(img, resized_dims, interpolation=Image.LANCZOS)
    return TF.center_crop(resized, output_size=[image_size, image_size])

def modified_map_pixels(img):
    """
    Apply ``map_pixels`` to a single, unbatched image tensor.

    A leading axis of length 1 is added before the call and removed again
    afterwards.

    :param img: Image tensor without a batch axis.
    :return: The mapped tensor, again without the added leading axis.
    """
    batched = img.unsqueeze(0)
    return map_pixels(batched).squeeze(0)

# Per-image pipeline applied by the dataset: resize + center-crop to a square,
# convert to a tensor, then apply map_pixels to the single image.
transform = transforms.Compose([
            transforms.Lambda(resize_and_crop),
            transforms.ToTensor(),
            transforms.Lambda(modified_map_pixels)
        ])

In [None]:
# Root directory of the LSUN dataset; change this to where your copy lives.
dataset_path = './data/lsun'

# Bedroom training split of LSUN, with the `transform` pipeline applied to every image.
lsun_dataset = datasets.LSUN(root=dataset_path, classes=['bedroom_train'], transform=transform)

In [None]:
batch_size = 5  # Adjust based on your memory availability and requirements
lsun_loader = DataLoader(lsun_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
# len() of a DataLoader is the number of batches, not the number of samples,
# so the sample count is taken from the dataset and both numbers are reported.
print('dataset size:', len(lsun_dataset))
print('batches per epoch:', len(lsun_loader))

In [None]:
def imshow(img):
    """
    Show an image tensor with matplotlib.

    The tensor is converted to a NumPy array and its axis 0 is moved to the
    end before plotting.

    :param img: Tensor with three axes that supports ``.numpy()``.
    """
    axes_reordered = np.transpose(img.numpy(), (1, 2, 0))
    plt.imshow(axes_reordered)
    plt.show()

## REINFORCE Loss Function

In [None]:
def Reinforce_Loss(logits, targets, loss, gamma=1.0):
    """
    Calculate the REINFORCE loss for sequence prediction.

    Each step's log-probability of its target is weighted by the discounted
    sum of the per-step losses from that step onward:

        weight[b, t] = sum_{k >= t} gamma**k * loss[b, k]

    Note the discount uses the absolute position ``k``, not ``k - t``. Because
    the weights are costs rather than rewards, the returned value is meant to
    be minimized.

    The weights are computed with a reversed cumulative sum, which is
    equivalent to multiplying by an upper-triangular discount matrix but avoids
    materializing a (seq_len, seq_len) tensor.

    :param logits: Logits from the model, shape [batch_size, seq_len, vocab_size].
    :param targets: Ground truth sequence, shape [batch_size, seq_len].
    :param loss: Per-step loss acting as the (negative) reward, shape [batch_size, seq_len].
    :param gamma: Discount factor for future steps.
    :return: Scalar tensor, the REINFORCE loss averaged over batch and sequence.
    """
    batch_size, seq_len, _ = logits.shape

    # Log-probability the model assigns to each target token: [batch, seq].
    log_probs = F.log_softmax(logits, dim=2)
    log_probs_targets = log_probs.gather(2, targets.unsqueeze(2)).squeeze(2)

    # gamma**k for k = 0..seq_len-1, created directly on the right device.
    discounts = gamma ** torch.arange(seq_len, device=log_probs.device).float()

    # Suffix sums of the discounted losses: flip, cumsum, flip back.
    discounted_loss = loss * discounts
    cumulative_loss = torch.flip(
        torch.cumsum(torch.flip(discounted_loss, dims=[1]), dim=1), dims=[1]
    )

    total_loss = torch.sum(log_probs_targets * cumulative_loss) / batch_size / seq_len

    return total_loss

## Train Model

In [None]:
# Hyper Parameters
learning_rate = 1e-4  # step size for the Adam optimizer created below
epochs = 1  # number of passes over the data loader
# `gamma` is passed to the ExponentialLR scheduler as its decay factor.
# NOTE(review): the training loop calls Reinforce_Loss without a gamma argument,
# so the RL discount stays at that function's default — confirm this is intended.
gamma = 0.95

In [None]:
# Labels for this run; they are combined into the path used for the
# TensorBoard log directory and the saved model file.
experiment = "test"
exp_type = "image"
experiment_name = f"{exp_type}/{experiment}/model={llm}_lr={learning_rate}"

In [None]:
from torch.utils.tensorboard import SummaryWriter

# Create a SummaryWriter instance; logs go under 'runs/<experiment_name>'.
writer = SummaryWriter(f'runs/{experiment_name}')

In [None]:
# Only the mapper's parameters are optimized; the encoder and the LLM are not
# handed to the optimizer.
optimizer = optim.Adam(mapper.parameters(), lr=learning_rate)
# NOTE(review): `criterion` does not appear to be used by the visible training loop.
criterion = nn.CrossEntropyLoss()
# Unreduced cross-entropy: one loss value per element, reshaped per sequence in the loop.
rl_criterion = nn.CrossEntropyLoss(reduction='none')
# Multiplies the learning rate by `gamma` on every scheduler.step() call.
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

In [None]:
def CrossEntropySG_Loss(mapped_feature_vector, targets, reduction='mean'):
    """
    Custom cross-entropy loss with a straight-through estimator.

    The target tokens' embeddings are substituted for the mapped vectors in the
    forward pass while the expression keeps ``mapped_feature_vector`` in the
    graph. The result is fed through ``forward_llm_with_embeddings`` (the
    original called the undefined name ``forward_with_embeddings``) and scored
    against ``targets`` with cross-entropy.

    NOTE(review): ``forward_llm_with_embeddings`` detaches its input and runs
    under ``torch.no_grad()``, so as written no gradient reaches
    ``mapped_feature_vector`` through this loss — confirm whether that is intended.

    :param mapped_feature_vector: Mapper output, shape [batch_size, seq_len, embedding_dim].
    :param targets: Target token ids with batch_size * seq_len elements; they
        index the global ``embeddings`` table.
    :param reduction: Reduction mode forwarded to ``F.cross_entropy``.
    :return: Loss value.
    """
    batch_size, seq_len, embedding_dim = mapped_feature_vector.shape

    # Embeddings of the target tokens, laid out like the mapped vectors.
    closest_embeddings = embeddings[targets].reshape(batch_size, seq_len, embedding_dim)

    # Straight-through: the value equals closest_embeddings, the graph keeps mapped_feature_vector.
    ste_embeddings = (closest_embeddings - mapped_feature_vector.detach()) + mapped_feature_vector

    predictions = forward_llm_with_embeddings(ste_embeddings)
    predictions = predictions.reshape(batch_size * seq_len, -1)

    # Flatten targets so both [batch, seq] and already-flat ids line up with predictions.
    loss = F.cross_entropy(predictions, targets.reshape(-1), reduction=reduction)

    return loss

In [None]:
def process_image_with_encoder_(image):
    """
    Variant of the image encoder wrapper for an encoder that returns a pair.

    The second element of ``enc(image)`` is converted to a NumPy array, moved
    to ``device`` as a tensor, and one-hot encoded with ``imageCodebook_len``
    classes. Unlike ``process_image_with_encoder`` no axis permutation is done.

    :param image: Input accepted by ``enc``.
    :return: Float one-hot tensor with the class axis last.
    """
    _, code_indices = enc(image)
    code_tensor = torch.from_numpy(np.asarray(code_indices)).to(device)
    return F.one_hot(code_tensor.long(), num_classes=imageCodebook_len).float()

In [None]:
# Debug leftover: `ce_loss` is only assigned inside the training loop in the
# following cell, so evaluating it here raised NameError on a fresh kernel.
# Re-enable after training if the shape needs inspecting.
# ce_loss.shape

In [None]:
# Train the mapper: encode images to tokens, map them into the LLM embedding
# space, run the LLM on the nearest text tokens, project its hidden states back
# onto the image vocabulary and apply the REINFORCE loss.
for epoch in range(epochs):
    for i, (images, _) in enumerate(lsun_loader):

        optimizer.zero_grad()

        # for vqgan
        # images = images.permute(0, 2, 3, 1)

        # The last batch of an epoch can hold fewer than `batch_size` images,
        # so every reshape below uses the size of the batch actually received.
        current_batch_size = images.size(0)

        # Process each image through DALL-E encoder to get image tokens.
        # The encoder is not optimized and argmax is non-differentiable, so no
        # autograd graph is needed for this step.
        with torch.no_grad():
            image_token_logits = enc(images.to(device))
            ground_truth_tokens = torch.argmax(image_token_logits, dim=1)

        # one_hot appends the class axis last: [B, H, W, V]. Flattening the
        # spatial axes directly keeps each row a valid one-hot vector. The
        # original permuted to [B, V, H, W] first, and reshaping that to
        # [B, H*W, V] interleaved classes and positions.
        one_hot_image_tokens = F.one_hot(ground_truth_tokens, num_classes=imageCodebook_len).float()
        flattened_tokens = one_hot_image_tokens.reshape(current_batch_size, -1, imageCodebook_len)

        ground_truth_tokens = ground_truth_tokens.reshape(-1)

        # Map tokens and get ground truth from LLM
        mapped_feature_vector = mapper(flattened_tokens)

        translated_text_tokens = translate(mapped_feature_vector, embeddings)

        # Calculate Representation of Last Layer in LLM
        final_layer_fv = generate_next_token_predictions(translated_text_tokens)

        # Calculate Logits with mapper function: project the hidden states back
        # onto the image vocabulary through the mapper's own weight matrix.
        logits = torch.matmul(final_layer_fv, mapper.mapper.weight)
        logits_ = logits.reshape(-1, imageCodebook_len)

        # RL Loss
        # prediction_logits = prediction_logits.reshape(batch_size, -1, llm.vocab_len)
        ce_loss = rl_criterion(logits_, ground_truth_tokens)
        ground_truth_tokens = ground_truth_tokens.reshape(current_batch_size, -1)
        ce_loss = ce_loss.reshape(current_batch_size, -1)

        # NOTE(review): no gamma is passed, so the discount is Reinforce_Loss's default.
        loss = Reinforce_Loss(logits, ground_truth_tokens, ce_loss)

        # Backward pass and update
        loss.backward()
        optimizer.step()

        # Log the losses
        # writer.add_scalars(
        #     "Training Metrics",
        #     {
        #         "loss": loss.item(),
        #         "cross_entropy": ce_loss[:,0].mean().item(),
        #     },
        #     epoch * len(lsun_loader) + i
        # )

        if i % 50 == 0:
            print(f"Epoch {epoch+1}, Batch {i}, Loss: {loss.item()}")

    scheduler.step()
    # `args` is not defined in this notebook; the epoch count lives in `epochs`.
    print(f"Epoch {epoch+1}/{epochs} completed.")
    
# Make sure the output directory exists, save the trained mapper's weights
# under models/<experiment_name>, and close the TensorBoard writer.
Path(f"models/{exp_type}/{experiment}").mkdir(parents=True, exist_ok=True)
torch.save(mapper.state_dict(), f"models/{experiment_name}")
writer.close()

In [None]:
from pathlib import Path
# Stand-alone save: repeats the directory creation and weight save performed at
# the end of the training cell, so the mapper can be re-saved without retraining.
Path(f"models/{exp_type}/{experiment}").mkdir(parents=True, exist_ok=True)
torch.save(mapper.state_dict(), f"models/{experiment_name}")