In [1]:
import os
import sys
import time

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.patches as mpatches
from mpl_toolkits.mplot3d import Axes3D
from PIL import Image
from tqdm.auto import tqdm
from einops import rearrange

from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

sys.path.append('../..')
from fast_nystrom_attention import LlavaNextForConditionalGenerationFNA

In [2]:
def seed_everything(seed: int = 42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's built-in ``random`` module, NumPy's legacy global RNG and
    PyTorch (CPU generator plus the CUDA generators), then puts cuDNN into
    deterministic mode with benchmarking turned off.

    Args:
        seed: Value applied to every generator. Defaults to 42.
    """
    # Imported locally: the notebook's import cell does not bring in `random`.
    import random

    # Without this, anything drawing from the stdlib RNG stays unseeded.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Prefer repeatable cuDNN behaviour over autotuned kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(123)

In [3]:
# Checkpoint to load, plus the precision and device it is placed on.
MODEL_ID = "llava-hf/llava-v1.6-vicuna-7b-hf"
DTYPE = torch.float32
DEVICE = torch.device("cuda:0")

# Settings handed to the FNA model class through its `fna_config` argument.
fna_config = dict(
    fna_layers=range(16, 32),
    num_sample=512,
    sampling_strategy='fps',
    sampling_features='q',
    resample_fps=False,
)

# Processor and model are both loaded from the same checkpoint id.
processor = LlavaNextProcessor.from_pretrained(MODEL_ID, use_fast=False)
model = LlavaNextForConditionalGenerationFNA.from_pretrained(
    MODEL_ID, fna_config=fna_config, torch_dtype=DTYPE, device_map=DEVICE
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
# Image the model is asked to write about.
image = Image.open('forest.jpg')

# Single-turn chat: one image placeholder followed by the text instruction.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Write a long story about this image."},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
# Pass image and text by keyword: the positional order of these two arguments
# has not been stable across transformers releases.
inputs = processor(images=image, text=prompt, return_tensors="pt").to(DEVICE)

# Autoregressively complete the prompt and time it.
# CUDA work is queued asynchronously, so drain the device on both sides of the
# timed region; perf_counter is monotonic, unlike the wall clock.
if DEVICE.type == "cuda":
    torch.cuda.synchronize(DEVICE)
start_time = time.perf_counter()
output = model.generate(**inputs, max_new_tokens=500, do_sample=False)
if DEVICE.type == "cuda":
    torch.cuda.synchronize(DEVICE)
end_time = time.perf_counter()

generation_time = end_time - start_time
# `output` holds prompt tokens followed by the newly generated ones.
num_generated_tokens = len(output[0]) - len(inputs.input_ids[0])
tokens_per_second = num_generated_tokens / generation_time

print(f"Generated {num_generated_tokens} tokens in {generation_time:.2f}s ({tokens_per_second:.2f} tokens/s)")

print(processor.decode(output[0], skip_special_tokens=True))

Generated 500 tokens in 13.65s (36.63 tokens/s)
USER: 
Write a long story about this image. ASSISTANT: In the heart of a mystical forest, where the trees stood tall and the air was thick with the scent of ancient secrets, there was a bridge that spanned the deepest part of the forest. It was a bridge that connected the known world to the unknown, a bridge that had stood for centuries, bearing witness to the passage of time and the whispers of the forest.

The bridge was not just a physical structure; it was a bridge of the mind, a bridge of the heart, a bridge of the spirit. It was a place where the wild creatures came to drink from the stream that flowed beneath it, and where the wise old trees whispered their secrets to those who dared to listen.

The forest was alive with stories, and the bridge was the heart of it all. It was a place where the spirits of the forest roamed, where the ghosts of the past haunted the present, and where the dreams of the future whispered in the wind.

O