### Installing requirements

In [None]:
!pip install --upgrade transformers
!pip install --upgrade accelerate
!pip install --upgrade bitsandbytes
!pip install --upgrade requests
!pip install --upgrade pillow
!pip install --upgrade matplotlib
!pip install torch==2.6.0
!pip install torchvision
!pip install fastai

### Model loading

In [None]:
import io

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

In [None]:
model_name = "./finetuned_model"

# One floating-point dtype for the 4-bit compute path AND for every module that
# stays un-quantized (e.g. embeddings / vision layers). Inputs must be cast to
# the model's dtype, so keeping both settings identical avoids dtype-mismatch
# errors. Use torch.float16 instead on GPUs without native bfloat16 support
# (pre-Ampere, e.g. T4 / V100).
MODEL_DTYPE = torch.bfloat16

# NF4 4-bit weights with nested (double) quantization of the quant constants.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=MODEL_DTYPE,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # Without this argument the config above is ignored and the model is
    # loaded un-quantized.
    quantization_config=quantization_config,
    torch_dtype=MODEL_DTYPE,
    device_map="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
).eval()

### Defining the image-preprocessing functions

In [None]:
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from PIL import Image

In [None]:
# ImageNet normalization statistics, one value per channel in (R, G, B) order:
# the first tuple is the mean, the second the standard deviation. Used by the
# T.Normalize step of the preprocessing pipeline.
IMAGENET_MEAN, IMAGENET_STD = (
    (0.485, 0.456, 0.406),
    (0.229, 0.224, 0.225),
)

In [None]:
def build_transform(input_size=448):
    '''Return the torchvision pipeline that turns a PIL image into model input.

    Steps, in order:
      1. resize to a square ``input_size`` x ``input_size`` image using bicubic
         interpolation (the aspect ratio is not preserved);
      2. convert the PIL image to a float tensor (``T.ToTensor``);
      3. normalize each channel with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Args:
        input_size: side length, in pixels, of the square output.

    Returns:
        A ``T.Compose`` callable that applies the steps above.
    '''
    square = (input_size, input_size)
    steps = [
        T.Resize(square, interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def preprocess_image(image: Image.Image, image_size=448):
    '''Turn a PIL image into a single-item batch tensor for the model.

    The image is converted to RGB, run through ``build_transform(image_size)``,
    and given a leading batch axis, giving shape
    ``(1, 3, image_size, image_size)``.

    Args:
        image: any PIL image. It is left untouched because ``convert`` returns
            a copy.
        image_size: side length of the square tensor produced.

    Returns:
        A CPU float tensor. Move it to the model's device/dtype before inference.
    '''
    rgb = image.convert('RGB')
    batch_of_one = build_transform(image_size)(rgb).unsqueeze(0)
    return batch_of_one

### Inference

In [None]:
import requests
import matplotlib.pyplot as plt

In [None]:
image_url = "https://www.ilankelman.org/stopsigns/australia.jpg"
# Greedy decoding (no sampling), capped at 128 new tokens.
generation_config = dict(max_new_tokens=128, do_sample=False)

# Image loading.
# - The timeout stops a stalled server from hanging the cell forever.
# - raise_for_status() surfaces HTTP errors directly, instead of letting PIL
#   fail on an error page.
# - Reading .content (rather than the raw socket stream) lets requests undo any
#   gzip/deflate Content-Encoding.
response = requests.get(image_url, timeout=30)
response.raise_for_status()
image = Image.open(io.BytesIO(response.content)).convert('RGB')

# Image display
plt.imshow(image)
plt.axis('off')
plt.show()

# Preprocessing: follow the model's own floating-point dtype (whatever it was
# loaded with) instead of hard-coding one, so the vision layers never see a
# dtype mismatch.
pixel_values = preprocess_image(image).to(dtype=model.dtype, device=model.device)

# Text input to model
additional_text = "urban life in Beijing"

# Prompt: the "<image>" placeholder marks where the image goes in the prompt
# passed to model.chat.
question = (
    "<image>\n"
    "Please describe the visual content of the image, "
    "and provide a short paragraph that connects the image to "
    f"the following text concept: '{additional_text}'."
)

# Model inference. `chat` is provided by the model's remote code
# (trust_remote_code=True); its signature is not visible in this notebook.
response = model.chat(tokenizer, pixel_values, question, generation_config)

print("Question:\n" + question.replace('<image>', '[IMAGE]'))
print("Answer:\n" + response)