In [18]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
from PIL import Image
import torchvision
import torchvision.transforms as transforms
import numpy as np

In [19]:
IMAGE_HEIGHT = 2160
IMAGE_WIDTH = 3840
COLORS = 3

In [20]:
model_name = "TheBloke/Mistral-7B-Instruct-v0.2-GPTQ"

# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

In [21]:
encoded = tokenizer.encode("Hello world, how are you? Is this a lot?", return_tensors="pt")

In [22]:
encoded

tensor([[    1, 22557,  1526, 28725,   910,   460,   368, 28804,  1691,   456,
           264,  2055, 28804]])

In [23]:
total_image_size = IMAGE_HEIGHT * IMAGE_WIDTH * COLORS
difference = total_image_size - len(encoded.squeeze())
if difference < 0: raise Exception("NOT IMPLEMENTED")

left_half = (difference) // 2
right_half = left_half
if left_half != (difference - left_half):
  print(f"difference: {difference-left_half}, half: {left_half}")
  right_half = difference - left_half

normalized_encoding = encoded.to(torch.float64) / (255 * total_image_size)
padded = F.pad(normalized_encoding, [left_half, right_half])
reshaped_padded = padded.reshape((COLORS, IMAGE_HEIGHT, IMAGE_WIDTH))

difference: 12441594, half: 12441593


In [24]:
(normalized_encoding * (255 * total_image_size)).to(torch.int32)

tensor([[    1, 22557,  1526, 28724,   910,   459,   368, 28804,  1691,   456,
           264,  2055, 28804]], dtype=torch.int32)

In [25]:
tokenizer.decode((normalized_encoding * (255 * total_image_size)).to(torch.int32).squeeze().tolist())

'<s> Hello worldy how not you? Is this a lot?'

# Image Concatenation

In [26]:
transform = transforms.Compose([transforms.PILToTensor()])

image = transform(Image.open("BackgroundPicture1.jpg"))

In [27]:
image.shape

torch.Size([3, 2160, 3840])

In [28]:
(image + reshaped_padded).numpy().astype(np.uint8).shape

(3, 2160, 3840)

In [29]:
# result = Image.fromarray((image + reshaped_padded).numpy().astype(np.uint8).reshape(IMAGE_WIDTH, IMAGE_HEIGHT, COLORS))
# reshaped_padded[reshaped_padded != 0] = 200
result = transforms.ToPILImage()((image + reshaped_padded.to(torch.uint8)))
result.save("resultTest.jpg")

In [30]:
# TODO: Converting to uint8 causes all values to zero out. How do we avoid this?
reshaped_padded[reshaped_padded != 0].to(torch.uint8)

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.uint8)

In [31]:
test = (transforms.PILToTensor()(result) - image)
test[test != 0]

tensor([], dtype=torch.uint8)