In [1]:
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch
import cv2 as cv

In [2]:
class LllmWrapper:
    """Convenience wrapper around a LLaVA-NeXT vision-language checkpoint.

    Loads the processor and the model once, moves the model to the GPU when
    one is available, and exposes ``analyser`` to run a fixed proctoring
    prompt against a single image file.
    """

    def __init__(self, model_id="llava-hf/llava-v1.6-mistral-7b-hf"):
        """Load the processor and model weights for ``model_id``.

        Args:
            model_id: Hugging Face hub id (or local path) of the LLaVA-NeXT
                checkpoint. Defaults to the Mistral-7B v1.6 checkpoint.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.processor = LlavaNextProcessor.from_pretrained(model_id)
        self.model = LlavaNextForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        self.model.to(self.device)

    def analyser(self, image, max_new_tokens=100):
        """Run the proctoring prompt on one image and return the decoded text.

        Args:
            image: Path of the image file to analyse.
            max_new_tokens: Upper bound on generated tokens (default 100).

        Returns:
            The first generated sequence decoded to text with special tokens
            removed. NOTE(review): this presumably still contains the echoed
            prompt in front of the answer - confirm before parsing.

        Raises:
            FileNotFoundError: If the image cannot be read from ``image``.
        """
        pixels_bgr = cv.imread(image)
        # cv.imread signals failure by returning None instead of raising.
        if pixels_bgr is None:
            raise FileNotFoundError(f"Could not read image file: {image!r}")
        # OpenCV decodes to BGR channel order; convert so the model receives RGB.
        pixels_rgb = cv.cvtColor(pixels_bgr, cv.COLOR_BGR2RGB)

        prompt = """[INST] <image>\n Analyse the image given and provide me ratings on the following actions detected in the image:
            People Count (Integer)
            Looking At ("at-system" | "left" | "right" | "up" | "down")
            Communication Device Present (Boolean)(Communication device refers to any other device like phone wired headphone etc which can be used for cheating). 
            The output should be in the given JSON Fromat only. 
            {
                "people_count": 5,
                "direction-looking": "at-system",
                "communication_device_present": false
            } [/INST]"""

        # Keyword arguments keep the call independent of the processor's
        # positional parameter order.
        inputs = self.processor(
            text=prompt, images=pixels_rgb, return_tensors="pt"
        ).to(self.device)
        output = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        return self.processor.decode(output[0], skip_special_tokens=True)
                

In [3]:
# Instantiate once: the constructor loads the processor and the model
# checkpoint, then moves the model to the selected device.
wrapper = LllmWrapper()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [6]:
# Run the proctoring prompt against one sample frame from the Kaggle dataset.
output = wrapper.analyser(
    image='/kaggle/input/llmwrapperdata/at-screen-4.jpg',
)

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


In [7]:
# Keep only the text after the final "[/INST]" so the example JSON embedded in
# the prompt is not mistaken for the answer (if the marker is absent, the whole
# decoded text is searched instead).
reply = output.rpartition('[/INST]')[2]
# Take the last {...} span rather than splitting on a ```json fence: the split
# raised IndexError whenever the model answered without a fence. The requested
# schema is flat, so the last "{" before the last "}" opens the answer object.
json_end = reply.rfind('}')
json_start = reply.rfind('{', 0, json_end) if json_end != -1 else -1
if json_start == -1:
    raise ValueError(f"No JSON object found in model reply: {reply!r}")
print(reply[json_start:json_end + 1])
{
    "people_count": 1,
    "direction-looking": "at-system",
    "communication_device_present": false
}


{
    "people_count": 1,
    "direction-looking": "at-system",
    "communication_device_present": false
}

