In [1]:
import os
import random
import sys
import time

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.patches as mpatches
from mpl_toolkits.mplot3d import Axes3D
from PIL import Image
from tqdm.auto import tqdm
from einops import rearrange

from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

sys.path.append('../..')
from fast_nystrom_attention import LlavaNextForConditionalGenerationFNA

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def seed_everything(seed: int = 42) -> None:
    """Seed every random number generator the notebook may touch.

    Seeds Python's stdlib ``random`` module, NumPy's legacy global RNG, and
    PyTorch's CPU and all-CUDA-device generators, then switches cuDNN to
    deterministic, non-autotuning mode so reruns reproduce the same results.

    Args:
        seed: Seed value applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Documented as safe (silently ignored) when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN's autotuned kernel selection for run-to-run determinism.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


SEED = 123
seed_everything(SEED)

In [3]:
# Checkpoint to load, plus the precision and device its weights are placed on.
MODEL_ID = "llava-hf/llava-v1.6-vicuna-7b-hf"
DTYPE = torch.float32  # NOTE(review): full precision (4 bytes/param); fp16/bf16 halves weight memory
DEVICE = torch.device("cuda:0")

# Options handed to the Fast Nystrom Attention (FNA) model wrapper.
# NOTE(review): the exact meaning of 'num_sample' and 'resample_fps' is defined
# in fast_nystrom_attention -- confirm there before tuning.
fna_config = dict(
    fna_layers=range(12, 32),  # layer indices 12 through 31 (range stop is exclusive)
    num_sample=256,
    resample_fps=False,
)

processor = LlavaNextProcessor.from_pretrained(MODEL_ID, use_fast=False)
model = LlavaNextForConditionalGenerationFNA.from_pretrained(
    MODEL_ID,
    fna_config=fna_config,
    torch_dtype=DTYPE,
    device_map=DEVICE,
)

Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.42it/s]


In [5]:
# Greedy generation on one local image, reporting end-to-end throughput
# (prefill + decode of `model.generate`). The GPU is synchronized on both sides
# of the timed region so the interval covers only finished device work.
MAX_NEW_TOKENS = 500

# Decode eagerly inside `with` so the file handle is released, and force a
# 3-channel RGB image regardless of the file's stored mode.
with Image.open('forest.jpg') as img:
    image = img.convert("RGB")

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Write a story about this image."},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
# Pass by keyword: the positional order of `images` / `text` in processor
# __call__ has changed between transformers releases.
inputs = processor(images=image, text=prompt, return_tensors="pt").to(DEVICE)

# Autoregressively complete the prompt (greedy decoding: do_sample=False).
if DEVICE.type == "cuda":
    torch.cuda.synchronize(DEVICE)  # don't start the clock with device copies still queued
start_time = time.perf_counter()  # monotonic, high-resolution clock for intervals
output = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)
if DEVICE.type == "cuda":
    torch.cuda.synchronize(DEVICE)  # make sure all generation kernels have finished
end_time = time.perf_counter()

generation_time = end_time - start_time
# The printed decode below includes the prompt, i.e. `output` is prompt + continuation;
# subtract the prompt length to count only newly generated tokens.
num_generated_tokens = output.shape[1] - inputs["input_ids"].shape[1]
tokens_per_second = num_generated_tokens / generation_time

print(f"Generated {num_generated_tokens} tokens in {generation_time:.2f}s ({tokens_per_second:.2f} tokens/s)")

# NOTE(review): the recorded greedy output collapses into one repeated sentence
# until MAX_NEW_TOKENS is reached; compare against the non-FNA baseline before
# attributing this to FNA.
print(processor.decode(output[0], skip_special_tokens=True))

Generated 500 tokens in 14.27s (35.03 tokens/s)
USER: 
Write a story about this image. ASSISTANT: The image shows a serene scene of a forest path. The path is lined with trees, and the ground is covered in fallen leaves. The path is flanked by trees, and the ground is covered in fallen leaves. The path is flanked by trees, and the ground is covered in fallen leaves. The path is flanked by trees, and the ground is covered in fallen leaves. The path is flanked by trees, and the ground is covered in fallen leaves. The path is flanked by trees, and the ground is covered in fallen leaves. The path is flanked by trees, and the ground is covered in fallen leaves. The path is flanked by trees, and the ground is covered in fallen leaves. The path is flanked by trees, and the ground is covered in fallen leaves. The path is flanked by trees, and the ground is covered in fallen leaves. The path is flanked by trees, and the ground is covered in fallen leaves. The path is flanked by trees, and the g