## Import Required Libraries


In [None]:
import re
import io
import os
import sys
import requests
import numpy as np

import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from pathlib import Path
from datasets import load_dataset
from torch.utils.data import DataLoader, random_split

In [None]:
from PIL import Image
from IPython.display import display, display_markdown

In [None]:
# Run the LLM and the mapper on CUDA device index 1 (the second GPU).
device = torch.device('cuda', 1)

In [None]:
# Seed PyTorch's global RNG so that later random draws (e.g. the mapper's initial weights) are reproducible.
torch.manual_seed(seed=42)

## VQGAN as Image Encoder

In [None]:
from vqgan_jax.modeling_flax_vqgan import VQModel

# Pre-trained Flax VQGAN from the "dalle-mini" hub repo (only the model is loaded; there is no separate processor).
checkpoint = "dalle-mini/vqgan_imagenet_f16_16384"
vqmodel = VQModel.from_pretrained(checkpoint)

In [None]:
def download_image(url, timeout=30):
    """Download ``url`` and open the response body as a PIL image.

    :param url: HTTP(S) URL of the image.
    :param timeout: Seconds to wait on the server (connect and read) before giving up.
        Without a timeout, ``requests`` can block indefinitely on an unresponsive host.
    :return: A PIL ``Image`` backed by the downloaded bytes.
    :raises requests.HTTPError: If the server answers with an error status.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    # Wrap the bytes in a seekable in-memory buffer so PIL can read them.
    return Image.open(io.BytesIO(resp.content))

def preprocess_vqgan(x):
    """Affinely rescale values from the [0, 1] range onto [-1, 1] (``2 * x - 1``)."""
    return x * 2.0 - 1.0

def custom_to_pil(x):
    """Convert an array of pixel values in [-1, 1] to an RGB PIL image.

    Values outside [-1, 1] are clipped. The result is mapped to [0, 255] and
    truncated to ``uint8``. Non-RGB images (e.g. single-channel) are converted to RGB.
    """
    unit_range = (np.clip(x, -1.0, 1.0) + 1.0) / 2.0
    pil_image = Image.fromarray((255 * unit_range).astype(np.uint8))
    return pil_image if pil_image.mode == "RGB" else pil_image.convert("RGB")

def preprocess(img, target_image_size=256):
    """Resize and centre-crop a PIL image into a channels-last batch of one.

    The shorter side is scaled to ``target_image_size`` with Lanczos resampling,
    then the centre ``target_image_size`` x ``target_image_size`` square is kept.

    :return: numpy array of shape (1, target_image_size, target_image_size, C),
        produced by ``to_tensor`` (values in [0, 1] for 8-bit images).
    :raises ValueError: If the image's shorter side is below ``target_image_size``.
    """
    shortest_side = min(img.size)
    if shortest_side < target_image_size:
        raise ValueError(f'min dim for image {shortest_side} < {target_image_size}')

    width, height = img.size
    scale = target_image_size / shortest_side
    # torchvision expects (height, width) for the output size.
    resized = TF.resize(img, (round(scale * height), round(scale * width)), interpolation=Image.LANCZOS)
    cropped = TF.center_crop(resized, output_size=[target_image_size, target_image_size])

    batch = TF.to_tensor(cropped)[None]        # (1, C, H, W)
    return batch.permute(0, 2, 3, 1).numpy()   # (1, H, W, C)

In [None]:
# Number of distinct VQGAN token ids; used as the one-hot width of the image tokens.
image_vocab_len = vqmodel.config.n_embed
print(f"image_vocab_len: {image_vocab_len}")

In [None]:
# Sample image for sanity-checking the VQGAN encode/decode round trip.
url = 'https://heibox.uni-heidelberg.de/f/7bb608381aae4539ba7a/?dl=1'
size = 256
image = preprocess(download_image(url), size)

In [None]:
# Visual check of the preprocessed input: preprocess_vqgan rescales [0, 1] -> [-1, 1],
# which custom_to_pil then converts back to an 8-bit RGB image.
custom_to_pil(preprocess_vqgan(image[0]))

In [None]:
# Encode the preprocessed (1, H, W, C) image; `indices` are the discrete VQGAN token ids
# (the same output the training loop turns into one-hot image tokens).
quant_states, indices = vqmodel.encode(image)
indices.shape

In [None]:
# Decode the quantised latents back to an image batch for a reconstruction check.
# NOTE(review): presumably channels-last like the encoder input, given rec[0] is shown directly below — confirm.
rec = vqmodel.decode(quant_states)

In [None]:
# Visual check of the reconstruction, converted exactly like the input image.
# NOTE(review): assumes the decoder output is in [0, 1] like the encoder input — confirm.
custom_to_pil(preprocess_vqgan(np.asarray(rec[0])))

## Text LLM

In [None]:
# from transformers import LlamaTokenizer, LlamaForCausalLM
# import transformers
# import torch

# llm = "meta-llama/Llama-2-7b-hf"
# model = LlamaForCausalLM.from_pretrained(llm)
# tokenizer = LlamaTokenizer.from_pretrained(llm)

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Text LLM: the pre-trained "gpt2" checkpoint and its matching tokenizer.
llm = "gpt2"
llm_tokenizer = GPT2Tokenizer.from_pretrained(llm)
model = GPT2LMHeadModel.from_pretrained(llm)

In [None]:
# Move GPT-2 to the training device and freeze it. Only the mapper is optimised, so LLM
# parameters must not track gradients. Otherwise the RL step's `embeddings.T` matmul
# would accumulate a never-cleared .grad on the (vocab_size, hidden_size) embedding matrix.
model.to(device)
model.requires_grad_(False)

# Output (LM-head) embedding matrix, shape (vocab_size, hidden_size). It is used to snap
# mapped feature vectors to their nearest text tokens. Read it after .to(device) so it
# is unambiguously on the model's device.
embeddings = model.lm_head.weight
# embedding_matrix = model.transformer.wte.weight
llm_feature_dim = model.config.hidden_size
llm_vocab_len = model.config.vocab_size
model.eval()

In [None]:
# embeddings = embeddings.to(device)

In [None]:
# Sanity-check the LLM dimensions that the mapper is built against.
print(f"gpt2 feature dim length: {llm_feature_dim}")
print(f"gpt2 vocabulary length: {llm_vocab_len}")
print(f"gpt2 embedding shape: {embeddings.shape}")

## Mapper Network

Map tokens from another modality (here, VQGAN image tokens) into the feature dimension of the LLM's text tokens.

In [None]:
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd

In [None]:
class TokenMapper(nn.Module):
    """Linear projection from one-hot tokens of another modality into LLM feature space.

    Args:
        input_dim: Vocabulary size of the source modality (width of the one-hot vectors).
        output_dim: Feature dimension of the LLM.
        device: Device the projection's parameters are moved to after creation.
    """

    def __init__(self, input_dim, output_dim, device="cpu"):
        super().__init__()
        # Parameters are initialised on the default device, then moved to `device`.
        self.mapper = nn.Linear(input_dim, output_dim).to(device)

    def forward(self, one_hot_token):
        """Project ``one_hot_token`` of shape (..., input_dim) to (..., output_dim)."""
        projection = self.mapper
        return projection(one_hot_token)

In [None]:
# Instantiate the mapper: VQGAN vocabulary size (one-hot image tokens) -> LLM feature dimension.
# For a MIDI experiment, pass `midi_vocab_len` instead of `image_vocab_len`.
mapper = TokenMapper(
    input_dim=image_vocab_len,
    output_dim=llm_feature_dim,
    device=device,
)

In [None]:
# Display the mapper's module structure.
mapper

## Generate Ground Truth

In [None]:
def generate_next_token_predictions(token_sequences):
    """Return the LLM's final-layer hidden states for a batch of token ids.

    Despite the name, no tokens are predicted. The result is the output of the
    model's last layer, shape (batch, seq_len, hidden_size), computed without
    gradient tracking. Uses the notebook-level ``model``.

    :param token_sequences: Long tensor of token ids, shape (batch, seq_len).
    """
    with torch.no_grad():
        # Run only the base transformer (e.g. GPT-2's `transformer`, Llama's `model`).
        # Its `last_hidden_state` is the same tensor the LM-head model reports as
        # `hidden_states[-1]`. Skipping the LM head avoids materialising a
        # (batch, seq_len, vocab_size) logits tensor that was discarded anyway.
        outputs = model.base_model(input_ids=token_sequences, use_cache=False)

    return outputs.last_hidden_state

In [None]:
def translate(batch_feature_vectors, embeddings):
    """Map each feature vector to the id of its most cosine-similar embedding row.

    :param batch_feature_vectors: Tensor of shape [batch_size, seq_len, embedding_dim].
    :param embeddings: Embedding matrix of shape [vocab_size, embedding_dim].
    :return: ``torch.long`` tensor of shape [batch_size, seq_len] holding row indices
        into ``embeddings``, on the same device as ``batch_feature_vectors``.
    """
    batch_size, seq_len, _ = batch_feature_vectors.shape
    closest_tokens = torch.empty(
        (batch_size, seq_len), dtype=torch.long, device=batch_feature_vectors.device
    )

    # argmax is not differentiable, so skip autograd bookkeeping entirely. Without this,
    # an `embeddings` that requires grad would build a throwaway graph on every call.
    with torch.no_grad():
        # Normalise the embedding matrix once so each row is a unit vector.
        embedding_matrix_norm = F.normalize(embeddings, dim=1)

        # Process one sample at a time. The peak allocation is then a
        # [seq_len, vocab_size] similarity matrix rather than [batch_size, seq_len, vocab_size].
        for i, feature_vectors in enumerate(batch_feature_vectors):
            feature_vectors_norm = F.normalize(feature_vectors, dim=1)
            cosine_similarities = feature_vectors_norm @ embedding_matrix_norm.T
            closest_tokens[i] = cosine_similarities.argmax(dim=1)

    return closest_tokens

## Get Image Dataset

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torchvision
from torchvision import transforms, datasets

In [None]:
# transform = transforms.Compose([
#     transforms.Resize((128, 128)),  # Resize to a fixed size; adjust as needed
#     transforms.ToTensor(),          # Convert images to PyTorch tensors
#     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize (mean, std) for each color channel
# ])

In [None]:
image_size = 256

def resize_and_crop(img):
    """Scale a PIL image so its shorter side equals ``image_size``, then centre-crop a square.

    Unlike ``preprocess``, images smaller than ``image_size`` are upscaled rather than rejected.
    """
    width, height = img.size
    scale = image_size / min(width, height)
    # torchvision expects (height, width) for the output size.
    target_hw = (round(scale * height), round(scale * width))
    resized = TF.resize(img, target_hw, interpolation=Image.LANCZOS)
    return TF.center_crop(resized, output_size=[image_size, image_size])

def modified_map_pixels(img):
    """Apply ``map_pixels`` to a single CHW tensor by temporarily adding a batch dimension.

    NOTE(review): ``map_pixels`` is not defined or imported anywhere in this notebook.
    It looks like DALL-E's ``dall_e.map_pixels`` helper. Calling this raises
    NameError unless that name is provided elsewhere — confirm.
    """
    # Add a batch dimension, apply map_pixels, and then remove the batch dimension
    img = img.unsqueeze(0)
    img = map_pixels(img)
    return img.squeeze(0)

# Dataset transform: resize/centre-crop to `image_size`, then convert to a float CHW tensor.
# For 8-bit images this gives values in [0, 1], the same range `preprocess` produces for
# the VQGAN sanity check; the training loop only permutes it to channels-last.
#
# There is deliberately no `map_pixels` step. That name is never defined or imported in
# this notebook (NameError on the first batch). Its rescaling also does not match the
# [0, 1] input the VQGAN is given in the sanity check.
transform = transforms.Compose([
    transforms.Lambda(resize_and_crop),
    transforms.ToTensor(),
])

In [None]:
# Root directory of the local LSUN download; adjust for your machine.
dataset_path = '../data/lsun'

lsun_dataset = datasets.LSUN(
    root=dataset_path,
    classes=['classroom_train'],
    transform=transform,
)

In [None]:
batch_size = 5  # Adjust based on your memory availability and requirements
lsun_loader = DataLoader(lsun_dataset, batch_size=batch_size, shuffle=True)
# len(lsun_loader) is the number of batches, not images, so report both.
print('dataset size:', len(lsun_dataset))
print('batches per epoch:', len(lsun_loader))

## Reinforce Loss

In [None]:
def Reinforce_Loss(logits, translated, loss, gamma=0.9):
    """
    Compute a REINFORCE-style policy-gradient loss using discounted future costs.

    Each action ``translated[b, t]`` is weighted by the normalised discounted sum of
    the per-step costs from step ``t`` onwards::

        cost_to_go[b, t] = sum_{j>=t} gamma**(j-t) * loss[b, j] / sum_{j>=t} gamma**(j-t)

    The result is ``mean_{b,t}(log_softmax(logits)[b, t, translated[b, t]] * cost_to_go[b, t])``,
    with ``cost_to_go`` detached. Minimising it lowers the probability of actions that
    are followed by high cost.

    :param logits: Policy logits, shape [batch_size, seq_len, vocab_size].
    :param translated: Chosen action (token id) per step, shape [batch_size, seq_len].
    :param loss: Per-step cost (higher is worse), shape [batch_size, seq_len].
        Treated as a constant; no gradient flows into it.
    :param gamma: Discount factor for future costs.
    :return: Scalar loss to be minimised.
    :raises ValueError: If ``translated`` or ``loss`` is not shaped [batch_size, seq_len].
    """
    batch_size, seq_len, _ = logits.shape
    expected_shape = (batch_size, seq_len)
    if tuple(translated.shape) != expected_shape or tuple(loss.shape) != expected_shape:
        raise ValueError(
            f"translated {tuple(translated.shape)} and loss {tuple(loss.shape)} must both "
            f"have shape {expected_shape} to match logits {tuple(logits.shape)}"
        )

    # Log-probability of the chosen action at every step: [batch_size, seq_len].
    log_probs = F.log_softmax(logits, dim=-1)
    log_probs_targets = log_probs.gather(2, translated.unsqueeze(2)).squeeze(2)

    # discount_matrix[i, j] = gamma ** (j - i) for j >= i, else 0. It is built in one
    # vectorised op instead of seq_len**2 individual element writes (very slow on a GPU).
    steps = torch.arange(seq_len, device=logits.device)
    offsets = (steps.unsqueeze(0) - steps.unsqueeze(1)).clamp(min=0)  # j - i, floored at 0
    discount_matrix = torch.pow(gamma, offsets.to(torch.float64)).triu()
    discount_matrix = discount_matrix.to(torch.get_default_dtype())

    normalize_factor = discount_matrix.sum(dim=1)

    # cumulative_loss[b, i] = sum_j discount_matrix[i, j] * loss[b, j] / normalize_factor[i]
    discounted_loss = loss.unsqueeze(1) * discount_matrix
    cumulative_loss = discounted_loss.sum(dim=-1) / normalize_factor

    # Costs act as constant weights in the policy gradient, hence detach().
    total_loss = torch.sum(log_probs_targets * cumulative_loss.detach()) / batch_size / seq_len
    return total_loss

## Train Model

In [None]:
# Hyper Parameters
epochs = 1              # passes over lsun_loader
learning_rate = 1e-5    # Adam step size for the mapper
gamma = 0.95            # ExponentialLR decay factor, applied once per epoch (not the REINFORCE discount)

In [None]:
# `experiment` also selects the training objective: the loop checks for the
# substrings 'base' (supervised cross-entropy logging) and 'rl' (extra REINFORCE step).
experiment = "base_test"
exp_type = "image"
name = "vqgan"
# Run identifier: <exp_type>/<experiment>/<name>/model=<llm>_lr=<learning_rate>
experiment_name = "/".join([exp_type, experiment, name, f"model={llm}_lr={learning_rate}"])

In [None]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard event files are written to ../runs/<experiment_name>.
writer = SummaryWriter(log_dir=f'../runs/{experiment_name}')

In [None]:
# Only the mapper's parameters are handed to the optimiser.
optimizer = optim.Adam(params=mapper.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

criterion = nn.CrossEntropyLoss()                      # mean CE for the supervised step
rl_criterion = nn.CrossEntropyLoss(reduction='none')   # per-token CE, used as the REINFORCE cost

In [None]:
for epoch in range(epochs):
    mapper.train()
    for i, (images, _) in enumerate(lsun_loader):
        global_step = epoch * len(lsun_loader) + i
        # Use the real batch size; the last batch of an epoch can be smaller than `batch_size`.
        n_images = images.shape[0]

        optimizer.zero_grad()

        # --- Tokenise the images with the VQGAN ---------------------------------
        # The VQGAN is fed channels-last images: (B, C, H, W) -> (B, H, W, C).
        images = images.permute(0, 2, 3, 1)
        _, ground_truth_tokens = vqmodel.encode(images)
        # Hand the JAX index array over to PyTorch via DLPack.
        ground_truth_tokens = torch.from_dlpack(ground_truth_tokens)
        ground_truth_tokens = ground_truth_tokens.to(torch.int64).to(device)
        one_hot_tokens = F.one_hot(ground_truth_tokens, num_classes=image_vocab_len).float()

        # Logits at position t are compared with the ground-truth token at t + 1.
        ground_truth_tokens = ground_truth_tokens[:, 1:]

        # --- Supervised step ----------------------------------------------------
        # Map image tokens into LLM feature space and snap them to the nearest text
        # tokens. Then take the LLM's last-layer features for that text (no grad).
        mapped_feature_vector = mapper(one_hot_tokens)
        translated_text_tokens = translate(mapped_feature_vector, embeddings)
        final_layer_fv = generate_next_token_predictions(translated_text_tokens)

        # The mapper weight, shape (llm_feature_dim, image_vocab_len), doubles as the
        # output projection from LLM features back onto the image vocabulary.
        logits = torch.matmul(final_layer_fv, mapper.mapper.weight)
        logits = logits[:, :-1]
        logits_ = logits.reshape(-1, image_vocab_len)
        ground_truth_tokens = ground_truth_tokens.reshape(-1)

        loss = criterion(logits_, ground_truth_tokens)
        loss.backward()
        optimizer.step()

        if 'base' in experiment:
            writer.add_scalar("training/cross_entropy", loss.item(), global_step)

        # --- Optional REINFORCE step --------------------------------------------
        if 'rl' in experiment:
            # Clear the supervised gradients that were just applied. Otherwise this
            # second optimizer.step() would apply them again on top of the RL gradient.
            optimizer.zero_grad()

            # ce_loss[:, t] scores the prediction made after translated tokens 0..t,
            # so it has seq_len - 1 entries. The final translated token never affects
            # a scored prediction (logits[:, -1] is dropped above), so only the first
            # seq_len - 1 actions are used. This keeps actions and costs aligned.
            action_logits = torch.matmul(mapped_feature_vector[:, :-1], embeddings.T)
            actions = translated_text_tokens[:, :-1]
            with torch.no_grad():
                ce_loss = rl_criterion(logits_, ground_truth_tokens).reshape(n_images, -1)

            loss = Reinforce_Loss(action_logits, actions, ce_loss)
            loss.backward()
            optimizer.step()

            # Log the losses
            writer.add_scalars(
                "training",
                {
                    "loss": loss.item(),
                    "cross_entropy": ce_loss.mean().item(),
                },
                global_step,
            )

        if i % 50 == 0:
            print(f"Epoch {epoch+1}, Batch {i}, Loss: {loss.item()}")

    scheduler.step()
    print(f"Epoch {epoch+1}/{epochs} completed.")
writer.close()

In [None]:
# Only the mapper is optimised, so its state_dict is what gets checkpointed.
Path(f"../models/{experiment_name}").mkdir(exist_ok=True, parents=True)
torch.save(obj=mapper.state_dict(), f=f"../models/{experiment_name}/model.pt")

In [None]:
# Close the TensorBoard writer.
# NOTE(review): the training cell already calls writer.close(); this second call is
# presumably a harmless safety net for interrupted runs — confirm.
writer.close()