## Download the VQ-VAE from DALL-E
Testing usage:
```python
enc = encoder
dec = decoder
```

In [None]:
import io
import os, sys
import requests
import PIL

import torch
import torch.nn.functional as F
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from dall_e import map_pixels, unmap_pixels, load_model
from IPython.display import display, display_markdown

In [None]:
# Side length, in pixels, of the square crop that `preprocess` produces for the
# DALL-E encoder.
target_image_size = 256

def download_image(url, timeout=30):
    """Fetch ``url`` over HTTP and open the response body as a PIL image.

    Args:
        url: address of the image.
        timeout: seconds passed to ``requests.get``; without it a stalled
            server can block the notebook indefinitely.

    Raises:
        requests.HTTPError: on a 4xx/5xx response (via ``raise_for_status``).
        requests.Timeout: if the server does not answer within ``timeout``.
    """
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    return PIL.Image.open(io.BytesIO(resp.content))

def preprocess(img):
    """Turn a PIL image into an input tensor for the DALL-E encoder.

    The shorter side is scaled to ``target_image_size`` (Lanczos resampling),
    the centered ``target_image_size`` square is cropped, and the crop is
    converted to a (1, 3, S, S) float tensor that is passed through
    ``map_pixels``.

    Raises:
        ValueError: if the shorter side is smaller than ``target_image_size``
            (the image is only ever scaled down).
    """
    s = min(img.size)

    if s < target_image_size:
        raise ValueError(f'min dim for image {s} < {target_image_size}')

    # Force three channels: grayscale, palette or RGBA images would otherwise
    # become 1- or 4-channel tensors instead of RGB.
    img = img.convert('RGB')

    # PIL sizes are (width, height) while TF.resize expects (height, width).
    r = target_image_size / s
    new_size = (round(r * img.size[1]), round(r * img.size[0]))
    img = TF.resize(img, new_size, interpolation=PIL.Image.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    img = torch.unsqueeze(T.ToTensor()(img), 0)
    return map_pixels(img)

In [None]:
# Prefer the first GPU but fall back to the CPU so the notebook still runs
# on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

enc = load_model("https://cdn.openai.com/dall-e/encoder.pkl", device)
# The decoder is required by the reconstruction cell (`dec(z_)`).
dec = load_model("https://cdn.openai.com/dall-e/decoder.pkl", device)

In [None]:
# Download a sample photo, preprocess it for the DALL-E encoder and show it.
# `x` is the map_pixels()-ed tensor from `preprocess`; only its first batch
# element is displayed.
x = preprocess(download_image('https://assets.bwbx.io/images/users/iqjWHBFdfxIU/iKIWgaiJUtss/v2/1000x-1.jpg'))
display_markdown('Original image:')
display(T.ToPILImage(mode='RGB')(x[0]))

In [None]:
# Input tensor shape (the leading batch dimension comes from torch.unsqueeze
# in `preprocess`).
x.shape

In [None]:
# Number of entries in the DALL-E encoder's codebook; used below as the
# one-hot width of image tokens.
imageCodebook_len = enc.vocab_size
imageCodebook_len

In [None]:
# Round-trip the sample image through the DALL-E VAE: encode to one codebook
# index per spatial position, one-hot the indices, decode, and display.
# Requires `dec` (the DALL-E decoder) to be loaded.
x = x.to(device)

z_logits = enc(x)
# Hard code assignment: highest-scoring codebook entry at each position.
z = torch.argmax(z_logits, axis=1)

# One-hot codes laid out channels-first, the form `dec` is called with here.
z_ = F.one_hot(z, num_classes=imageCodebook_len).permute(0, 3, 1, 2).float()

x_stats = dec(z_).float()
# Only the first 3 decoder output channels are used; sigmoid followed by
# unmap_pixels (presumably the inverse of map_pixels) gives displayable pixels.
x_rec = unmap_pixels(torch.sigmoid(x_stats[:, :3]))
x_rec = T.ToPILImage(mode='RGB')(x_rec[0])

display_markdown('Reconstructed image:')
display(x_rec)

## Load the LLM

In [None]:
import transformers

from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer

In [None]:
# Hugging Face checkpoint for the text LLM; an alternative checkpoint is left
# commented out.
checkpoint = "gpt2-large"
# checkpoint = "princeton-nlp/Sheared-LLaMA-1.3B"

In [None]:
# Load the causal LM onto the same device as the DALL-E encoder.
llm_model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

In [None]:
llm_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# llm_tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
# Text vocabulary size as reported by len(tokenizer); used later as the
# LLM-side vocabulary size.
textVocab_len = len(llm_tokenizer)
textVocab_len

## Image Tokens as Text

In [None]:
# Flatten the DALL-E codes of the sample image into a 1-D id sequence.
# shape of image_token = (1, 32, 32)
image_token = z

In [None]:
# shape of image_token = (32, 32)
image_token = image_token.squeeze()

In [None]:
# shape of image_token = (1024)
image_token = image_token.view(-1)

In [None]:
image_token.shape

In [None]:
image_token

In [None]:
# Exploratory: decode image codebook indices as if they were LLM token ids.
# NOTE(review): the image codebook and the text vocabulary are separate id
# spaces, so the decoded text is not expected to be meaningful.
print(llm_tokenizer.decode(image_token.tolist()))

In [None]:
# Tokenize a short prompt into PyTorch tensors; `tokens['input_ids']` is fed
# to `generate` below.
tokens = llm_tokenizer('hi', return_tensors="pt")

In [None]:
tokens

In [None]:
# Continue the prompt up to 20 tokens in total. The model lives on `device`,
# so the inputs go there too. The attention mask and pad id are passed
# explicitly: the GPT-2 tokenizer defines no pad token, and otherwise
# `generate` has to infer both and warns about it.
outs = llm_model.generate(
    tokens['input_ids'].to(device),
    attention_mask=tokens['attention_mask'].to(device),
    max_length=20,
    pad_token_id=llm_tokenizer.eos_token_id,
)

In [None]:
# Raw token ids returned by `generate`.
outs

In [None]:
# Map the first sequence's ids to token strings...
outs1 = llm_tokenizer.convert_ids_to_tokens(outs[0])

In [None]:
# ...and join the tokens back into text.
print(llm_tokenizer.convert_tokens_to_string(outs1))

In [None]:
def convert_id_to_string(ids, tokenizer=None):
    """Decode token ids into text via ids -> token strings -> string.

    Args:
        ids: a single int, a sequence of ints, or an object with ``tolist()``
            (e.g. a 1-D torch tensor or numpy array).
        tokenizer: tokenizer providing ``convert_ids_to_tokens`` and
            ``convert_tokens_to_string``. Defaults to the notebook's
            ``llm_tokenizer``, looked up at call time (not definition time),
            so re-loading the tokenizer is picked up.

    Returns:
        The decoded string.
    """
    if tokenizer is None:
        tokenizer = llm_tokenizer
    # Normalize tensors/arrays to plain Python ints.
    if hasattr(ids, 'tolist'):
        ids = ids.tolist()
    # A lone id would make convert_ids_to_tokens return a bare string.
    if isinstance(ids, int):
        ids = [ids]
    tokens = tokenizer.convert_ids_to_tokens(ids)
    return tokenizer.convert_tokens_to_string(tokens)

## Changing Image tokens to Text tokens

In [None]:
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd

In [None]:
class TokenMapper(nn.Module):
    """Learned linear projection from one token space into another.

    Holds a single ``nn.Linear(input_dim, output_dim)`` as ``self.mapper``,
    moved to ``device`` when the module is built.
    """

    def __init__(self, input_dim, output_dim, device="cpu"):
        super().__init__()
        linear = nn.Linear(input_dim, output_dim)
        # Module.to returns the module itself, so this stores the moved layer.
        self.mapper = linear.to(device)

    def forward(self, one_hot_token):
        """Project the last dimension of ``one_hot_token`` to ``output_dim``."""
        projected = self.mapper(one_hot_token)
        return projected

## Text LLM

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Load pre-trained GPT-2 model and tokenizer (the "gpt2" checkpoint; this is
# separate from `llm_model` loaded earlier).
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Input token-embedding matrix: one row per vocabulary entry.
gpt2_embeddings = gpt2_model.get_input_embeddings().weight

# NOTE: despite the name, this is the embedding width (`n_embd`), not a
# vocabulary size; it is used as the mapper's output dimension.
gpt2Codebook_len = gpt2_model.config.n_embd
# NOTE(review): `gpt2_embeddings` was taken before this move; Module.to updates
# parameters in place, so it should follow the model onto `device` — confirm.
gpt2_model.to(device)

In [None]:
# GPT-2 embedding width (config.n_embd).
gpt2Codebook_len

In [None]:
# (vocabulary size, embedding width) of GPT-2's input embedding matrix.
gpt2_embeddings.shape

## Straight-Through Gradient

In [None]:
import pickle

def find_closest_tokens(mapped_vectors, text_token_embeddings, chunk_size=4096):
    """Return, for each query row, the index of the most cosine-similar embedding.

    Args:
        mapped_vectors: (N, D) query vectors.
        text_token_embeddings: (V, D) embedding table.
        chunk_size: number of query rows scored at a time; bounds peak memory
            to a (chunk_size, V) similarity matrix.

    Returns:
        (N,) long tensor of row indices into ``text_token_embeddings``.
    """
    # Cosine similarity == dot product of L2-normalized vectors. Doing it as a
    # matmul avoids broadcasting to an (N, V, D) intermediate tensor.
    with torch.no_grad():  # argmax indices carry no gradient anyway
        queries = F.normalize(mapped_vectors, dim=-1)
        table = F.normalize(text_token_embeddings, dim=-1)
        if queries.shape[0] == 0:
            return torch.empty(0, dtype=torch.long, device=queries.device)
        return torch.cat(
            [(chunk @ table.T).argmax(dim=1) for chunk in queries.split(chunk_size)]
        )


# Sizes the mapper translates between: the DALL-E codebook size on the input
# side, the GPT-2 embedding width (config.n_embd) on the output side.
image_token_dim = imageCodebook_len
text_token_dim = gpt2Codebook_len

# Linear image-token -> GPT-2-embedding projection, placed on `device`.
mapper = TokenMapper(
    input_dim=image_token_dim,
    output_dim=text_token_dim,
    device=device,
)

def process_image_with_dalle_encoder(image):
    """Encode images with the DALL-E encoder into one-hot codebook tokens.

    The encoder logits are reduced with argmax over dim 1 to a single code per
    spatial position, then one-hot encoded with ``image_token_dim`` classes
    and returned channels-first as float: (batch, image_token_dim, H, W).
    """
    codes = enc(image).argmax(dim=1)
    one_hot = F.one_hot(codes, num_classes=image_token_dim)
    return one_hot.permute(0, 3, 1, 2).float()



## Generate Ground Truth

In [None]:
def find_closest_indices(mapped_feature_vector, text_token_embeddings=None, chunk_size=4096):
    """Snap every vector to the id of its most cosine-similar token embedding.

    Args:
        mapped_feature_vector: (..., D) tensor, e.g. (batch, seq_len, 768).
        text_token_embeddings: (V, D) embedding table; defaults to the
            notebook's ``gpt2_embeddings`` (looked up at call time).
        chunk_size: rows scored per matmul; bounds peak memory to a
            (chunk_size, V) similarity matrix.

    Returns:
        Long tensor with the leading shape of ``mapped_feature_vector``
        (e.g. (batch, seq_len)) holding row indices into the table.
    """
    if text_token_embeddings is None:
        text_token_embeddings = gpt2_embeddings
    *lead_shape, dim = mapped_feature_vector.shape

    # argmax indices carry no gradient, so skip building an autograd graph.
    with torch.no_grad():
        # reshape (not view) also accepts non-contiguous inputs.
        queries = F.normalize(mapped_feature_vector.reshape(-1, dim), dim=-1)
        table = F.normalize(text_token_embeddings, dim=-1)
        if queries.shape[0] == 0:
            closest = torch.empty(0, dtype=torch.long, device=queries.device)
        else:
            # Cosine similarity as a matmul of normalized vectors: avoids the
            # (N, V, D) broadcast that F.cosine_similarity would allocate.
            closest = torch.cat(
                [(chunk @ table.T).argmax(dim=1) for chunk in queries.split(chunk_size)]
            )
    return closest.reshape(lead_shape)


In [None]:
def generate_next_token_predictions(token_sequences, model=None):
    """Greedy next-token prediction after every prefix of every sequence.

    ``predictions[b, i]`` is the argmax next token given
    ``token_sequences[b, :i + 1]``. A causal LM such as GPT-2 only attends to
    earlier positions, so one forward pass over the full sequence yields the
    logits of every prefix. Up to floating-point differences, this matches
    running the model once per prefix, but costs one pass instead of seq_len.

    Args:
        token_sequences: (batch, seq_len) long tensor of token ids.
        model: causal LM called as ``model(input_ids=...)`` whose output has
            ``.logits`` of shape (batch, seq_len, vocab); defaults to the
            notebook's ``gpt2_model`` (looked up at call time).

    Returns:
        (batch, seq_len) long tensor on the CPU.
    """
    if model is None:
        model = gpt2_model
    if token_sequences.size(1) == 0:
        return torch.zeros(token_sequences.size(), dtype=torch.long)

    with torch.no_grad():
        logits = model(input_ids=token_sequences).logits

    return logits.argmax(dim=-1).to(device="cpu", dtype=torch.long)

In [None]:
def get_gpt2_ground_truth(mapped_feature_vector):
    """Label mapped vectors with GPT-2's own next-token predictions.

    Each vector is first snapped to its most cosine-similar GPT-2 token id
    (``find_closest_indices``); GPT-2 then predicts the next token after every
    prefix of that id sequence (``generate_next_token_predictions``).
    """
    return generate_next_token_predictions(find_closest_indices(mapped_feature_vector))


## Get Image Dataset

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torchvision
from torchvision import transforms, datasets

from torch.utils.data import DataLoader

In [None]:
# Resize to 128x128 and normalize every channel with mean 0.5 / std 0.5,
# which maps ToTensor's [0, 1] range to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((128, 128)),  # Resize to a fixed size; adjust as needed
    transforms.ToTensor(),          # Convert images to PyTorch tensors
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize (mean, std) for each color channel
])

In [None]:
# Root directory containing the LSUN data; only the 'classroom_train' split
# is loaded. Point this at your local LSUN copy.
dataset_path = './data'

lsun_dataset = datasets.LSUN(root=dataset_path, classes=['classroom_train'], transform=transform)

In [None]:
batch_size = 64  # Adjust based on your memory availability and requirements
lsun_loader = DataLoader(lsun_dataset, batch_size=batch_size, shuffle=True, num_workers=4)

In [None]:
def imshow(img, unnormalize=False):
    """Show an image tensor with matplotlib.

    Args:
        img: a (C, H, W) tensor, or a (N, C, H, W) batch whose images are
            placed side by side. It is detached and moved to the CPU first,
            so GPU tensors and tensors that require grad are accepted.
        unnormalize: if True, undo ``Normalize((0.5,)*3, (0.5,)*3)``, i.e.
            map [-1, 1] back to [0, 1] before plotting.

    Raises:
        ValueError: if given an empty batch.
    """
    img = img.detach().cpu()
    if img.dim() == 4:
        if img.shape[0] == 0:
            raise ValueError('imshow got an empty batch')
        # Tile the batch horizontally: N x (C, H, W) -> (C, H, N*W).
        img = torch.cat(tuple(img), dim=-1)
    if unnormalize:
        img = img * 0.5 + 0.5
    npimg = img.numpy()
    # matplotlib expects channels last: (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

## Train Model

In [None]:
optimizer = optim.Adam(mapper.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

epochs = 2

# Train the image-token -> GPT-2-embedding mapper. Targets are GPT-2's greedy
# next-token predictions on the mapped sequence snapped to GPT-2 token ids.
for epoch in range(epochs):
    for i, (images, _) in enumerate(lsun_loader):
        optimizer.zero_grad()

        # `transform` normalizes to [-1, 1]; undo that and apply map_pixels so
        # the encoder sees the same input range that `preprocess` produces.
        encoder_input = map_pixels(images * 0.5 + 0.5).to(device)

        # One-hot image tokens, channels-first: (B, image_token_dim, H, W).
        one_hot_image_tokens = process_image_with_dalle_encoder(encoder_input)

        # Move the code axis last before flattening so each row is one
        # position's one-hot vector: (B, H*W, image_token_dim). Reshaping the
        # channels-first tensor directly would mix values from different positions.
        flattened_tokens = one_hot_image_tokens.permute(0, 2, 3, 1).reshape(
            one_hot_image_tokens.size(0), -1, image_token_dim
        )

        # Project into GPT-2's embedding space: (B, H*W, text_token_dim).
        mapped_feature_vector = mapper(flattened_tokens)

        # Targets are argmax indices, so no gradient flows through them; build
        # them without an autograd graph.
        with torch.no_grad():
            ground_truth_token = get_gpt2_ground_truth(mapped_feature_vector.detach()).to(device)

        # Vocabulary logits (B, L, V): score each mapped vector against every
        # GPT-2 token embedding (GPT-2's LM head is tied to these embeddings).
        # The table is detached because only `mapper` is being trained.
        logits = mapped_feature_vector @ gpt2_embeddings.detach().T

        # CrossEntropyLoss expects the class axis second: (B, V, L) vs (B, L).
        loss = criterion(logits.permute(0, 2, 1), ground_truth_token)

        # Backward pass and update
        loss.backward()
        optimizer.step()

        if i % 10 == 0:  # Print loss every 10 batches
            print(f"Epoch {epoch+1}, Batch {i+1}, Loss: {loss.item()}")

    # Decay the learning rate once per epoch.
    scheduler.step()
    print(f"Epoch {epoch+1}/{epochs} completed.")


In [None]:
# Debug: shapes of the globals left behind by the last training step.
flattened_tokens.shape

In [None]:
# NOTE(review): `token` is not assigned in any visible cell of this notebook,
# so this cell raises NameError as written.
token.shape

In [None]:
mapped_feature_vector.shape

In [None]:
ground_truth_token.shape

In [None]:
gpt2_embeddings.shape

In [None]:
# NOTE(review): mapped_feature_vector is 3-D (batch, positions, width), so
# unsqueeze(1) broadcasts (B, 1, L, D) against (1, V, D); that only works when
# L == V or one of them is 1 — confirm the intended input (e.g. a single
# flattened (N, D) slice).
temp = F.cosine_similarity(mapped_feature_vector.unsqueeze(1).cpu(), gpt2_embeddings.unsqueeze(0).cpu(), dim=2)

In [None]:
temp = torch.argmax(temp, dim=1)
temp

In [None]:
temp.shape

In [None]:
# Images from the last batch (still normalized by `transform`).
images[0].shape

In [None]:
imshow(images[4].cpu())

In [None]:
# NOTE(review): no definition of `train_on_lsun` is visible in this notebook
# (the training loop runs at top level), so this call raises NameError.
train_on_lsun(lsun_loader, epochs=2)

In [None]:
def init_weights(m):
    """Kaiming-normal init for Linear layers (fan_out, ReLU gain); zero biases.

    Meant for ``module.apply(init_weights)``; any module that is not an
    ``nn.Linear`` is left untouched.
    """
    if not isinstance(m, torch.nn.Linear):
        return
    torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    if m.bias is not None:
        torch.nn.init.zeros_(m.bias)

In [None]:
class Image2Text(torch.nn.Module):
    """Turn image-encoder outputs into per-position logits over an LLM vocabulary.

    ``forward``: frozen image encoder -> BatchNorm over the image-codebook
    channels -> trainable ``shift`` module. ``getLabel`` builds training
    targets by feeding the encoder's argmax codes to the frozen LLM.

    Only ``norm`` and ``shift`` are trainable. ``image_encoder`` and ``llm``
    have ``requires_grad`` disabled and stay in eval mode even after
    ``.train()``, so dropout in the frozen LLM cannot make the ``getLabel``
    targets random.

    Args:
        image_encoder: frozen image encoder producing (N, C, H, W) logits.
        shift: trainable head, presumably mapping ``imageVocab_size`` features
            to ``textVocab_size`` logits; re-initialized with ``init_weights``.
        llm: frozen causal LM whose output exposes ``.logits``.
        llm_tokenizer: tokenizer paired with ``llm`` (stored only).
        imageToken_size: accepted for compatibility; not used.
        imageVocab_size: number of channels normalized by ``norm``.
        textVocab_size: LLM vocabulary size (stored only).
        device: device every submodule is moved to.
    """

    def __init__(self, image_encoder, shift, llm, llm_tokenizer, imageToken_size, imageVocab_size=8192, textVocab_size=50257, device="cpu"):
        super(Image2Text, self).__init__()

        self.image_encoder = image_encoder.to(device)
        self.shift = shift.to(device)
        self.norm = nn.BatchNorm1d(imageVocab_size).to(device)
        self.relu = nn.ReLU()  # not used by forward
        # Not used by forward; dim=1 is the vocabulary axis of forward's output.
        self.softmax = nn.Softmax(dim=1)
        self.llm = llm.to(device)
        self.llm_tokenizer = llm_tokenizer
        self.imageVocab_size = imageVocab_size
        self.textVocab_size = textVocab_size
        self.device = device

        for params in self.image_encoder.parameters():
            params.requires_grad = False

        for params in self.llm.parameters():
            params.requires_grad = False

        self.shift.apply(init_weights)
        self.image_encoder.eval()
        self.llm.eval()

    def train(self, mode=True):
        """Set train/eval mode, but always keep the frozen backbones in eval mode."""
        super().train(mode)
        self.image_encoder.eval()
        self.llm.eval()
        return self

    def forward(self, x):
        """Return (N, out_features, H*W) logits, class axis second for CrossEntropyLoss."""
        x = self.image_encoder(x)                   # (N, C, H, W)
        x = x.permute(0, 2, 3, 1)                   # (N, H, W, C)
        x = x.reshape(x.shape[0], -1, x.shape[-1])  # (N, L, C), L = H*W

        x = x.permute(0, 2, 1)  # Change dims to (N, F, L)
        x = self.norm(x)
        x = x.permute(0, 2, 1)  # Change dims back to (N, L, F)

        x = self.shift(x)
        x = x.permute(0, 2, 1)

        return x

    @torch.no_grad()
    def getLabel(self, x):
        """Return the frozen LLM's argmax next-token ids for the image's code sequence.

        Returns an (N, H*W) long tensor. Computed without autograd: the
        targets are argmax indices, so no gradient could flow through them.
        """
        logits = self.image_encoder(x)
        img_tokens = torch.argmax(logits, axis=1)
        img_tokens = img_tokens.reshape(img_tokens.shape[0], -1)

        gpt2_out = self.llm(img_tokens)
        gpt2_logits = gpt2_out.logits
        gpt2_ids = torch.argmax(gpt2_logits, axis=2)

        return gpt2_ids

In [None]:
# NOTE(review): `Shift`, `imageToken_size` and `trainloader` are not defined in
# any visible cell of this notebook — confirm where they are created.
shift = Shift(imageCodebook_len, textVocab_len)
model = Image2Text(enc, shift, llm_model, llm_tokenizer, imageToken_size, imageCodebook_len, textVocab_len, device=device)
num_epochs = 2  # or whatever number you choose
# Autograd anomaly detection is a debugging aid that slows training down a
# lot; enable it only while hunting NaN/inf gradients.
debug_anomaly = False

# Only the unfrozen parameters (norm + shift) are handed to the optimizer.
optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=5e-7)
criterion = nn.CrossEntropyLoss()
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

model.train()
# Keep the frozen encoder and LLM in eval mode so dropout cannot perturb the
# targets produced by getLabel.
model.image_encoder.eval()
model.llm.eval()

for epoch in range(num_epochs):
    for i, (inputs, _) in enumerate(trainloader):
        inputs = inputs.to(device)
        # Targets: the frozen LLM's argmax prediction at every image position.
        labels = model.getLabel(inputs)

        optimizer.zero_grad()

        # Logits with the vocabulary on dim 1, as CrossEntropyLoss expects.
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        with torch.autograd.set_detect_anomaly(debug_anomaly):
            loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Log every 5 steps.
        if i % 5 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs} Step:{i}], Loss: {loss.item():.4f}")

    # Decay the learning rate once per epoch.
    scheduler.step()


In [None]:
# Debug: inspect the last training batch and intermediate shapes.
# NOTE(review): `inputs` is a whole batch here; confirm `imshow` accepts
# batched (N, C, H, W) input.
imshow(inputs.cpu())

In [None]:
loss

In [None]:
outputs.shape

In [None]:
# Re-run the frozen encoder on the last batch.
# NOTE(review): `.cuda()` assumes a GPU; other cells use `device`.
x = model.image_encoder(inputs.cuda())

In [None]:
x.shape

In [None]:
# Same flattening as Image2Text.forward: (N, C, H, W) -> (N, H*W, C).
x = x.permute(0,2,3,1)
x = x.reshape(x.shape[0], -1, x.shape[-1])

In [None]:
x.shape

In [None]:
# NOTE(review): Image2Text.forward transposes to (N, C, L) before `norm`
# (BatchNorm1d over the codebook channels). Applied to the (N, L, C) tensor
# here, it would treat L as the channel axis and likely fail — permute first.
n = model.norm(x)

In [None]:
# NOTE(review): `linear1` must be an attribute of the `Shift` module, whose
# definition is not visible here — confirm. Same layout caveat as above.
y1 = model.shift.linear1(model.norm(x))

In [None]:
     # Print the gradients of the trainable parameters from the last backward
     # pass (frozen parameters are skipped). The common leading indent is
     # stripped by IPython when the cell runs.
     for name, param in model.named_parameters():
            if param.requires_grad:
                print(f"{name} gradients:")
                print(param.grad)

In [None]:
# Evaluate on the test set: compare the model's per-position vocabulary
# logits against the labels produced by getLabel.
# NOTE(review): `testloader` is not defined in any visible cell — confirm.
# Set the model to evaluation mode
model.eval()

# Metrics
test_loss = 0.0  # running sum of per-batch losses
matrix_loss = []  # per-batch losses, plotted by the next cell

# Disable gradient computation
with torch.no_grad():
    for i, (inputs, _) in enumerate(testloader):
        # Move data to the same device as the model
        inputs = inputs.to(device)
        labels = model.getLabel(inputs)

        # Forward pass
        outputs = model(inputs)

        # Compute loss
        loss = criterion(outputs, labels)
        test_loss += loss.item()
        matrix_loss.append(loss.item())
        
        # if i%50==0:
        #     print(f"Step:{i}, Loss: {loss.item():.4f}")


# Average of the per-batch losses (no accuracy metric is computed).
avg_test_loss = test_loss / len(testloader)
print(f'average test loss: {avg_test_loss}')

In [None]:
def moving_average(data, window_size):
    """Return the simple moving average of ``data`` over a sliding window.

    Entry k is ``mean(data[k:k + window_size])``. The result has
    ``len(data) - window_size + 1`` entries, or is empty when ``data`` has
    fewer than ``window_size`` points. Without that guard, numpy's 'valid'
    mode would swap the roles of the two inputs and return meaningless values.

    Raises:
        ValueError: if ``window_size`` < 1.
    """
    if window_size < 1:
        raise ValueError(f'window_size must be >= 1, got {window_size}')
    data = np.asarray(data, dtype=float)
    if data.size < window_size:
        return np.empty(0)
    return np.convolve(data, np.ones(window_size) / window_size, mode='valid')

window_size = 5
smoothed_values = moving_average(matrix_loss, window_size)

# Plot raw per-batch test losses with their moving average overlaid. The
# smoothed curve is aligned to the end of each window, starting at index
# window_size - 1.
smoothed_x = np.arange(window_size - 1, len(matrix_loss))
plt.figure(figsize=(12, 6))
plt.plot(matrix_loss, label="Original Loss")
plt.plot(
    smoothed_x,
    smoothed_values,
    label=f"Smoothed Loss (window size={window_size})",
    linewidth=2,
)
plt.legend()
plt.show()


## What does the image say?

In [None]:
# Take one batch from the test loader and display it.
# NOTE(review): `image` is presumably a whole (N, C, H, W) batch; confirm
# `imshow` handles batched input.
dataiter = iter(testloader)
image, label = next(dataiter)

imshow(image)

In [None]:
# Per-position vocabulary logits for the batch. This is inference only, so
# autograd is disabled; the batch goes to `device` rather than assuming CUDA.
with torch.no_grad():
    logits = model(image.to(device))

In [None]:
# Presumably (N, textVocab_size, positions): Image2Text.forward moves the
# vocabulary axis to dim 1.
logits.shape

In [None]:
# Most likely LLM token id at every position (argmax over the vocab axis).
img_tokens = torch.argmax(logits, axis=1)

In [None]:
# Decode each image's token sequence into text.
for i in range(img_tokens.shape[0]):
    print(convert_id_to_string(img_tokens[i], tokenizer=llm_tokenizer))