In [1]:
from pathlib import Path
from tqdm import tqdm
import polars as pl

In [2]:
# --- Run configuration: everything a reader might want to tune ---

# Passed to every model wrapper as its `cache_dir` argument.
CACHE_DIR = "cache"
# When True the model keeps its chat history within a tutorial; it also selects
# the results sub-folder and the suffix of the exported CSV name.
USE_CHAT_HISTORY = False
# When True, each frame's questions, expectation and model reply are printed.
PRINT_QA = False
# Tutorial ids to evaluate; frames are selected by name prefix `str(tutorial)`.
TUTORIALS = [1, 2, 3, 4]
# "dragonfly", "cogagent", or otherwise a key of `InternVL.VARIANTS`.
MODEL_NAME = "dragonfly"

In [3]:
# Keyword arguments that every model wrapper below receives unchanged.
shared_kwargs = dict(use_history=USE_CHAT_HISTORY, cache_dir=CACHE_DIR)

# Each wrapper is imported inside its own branch, so only the module of the
# selected model is imported.
if MODEL_NAME == "dragonfly":
    from models.dragonfly import Dragonfly
    model = Dragonfly(**shared_kwargs)
elif MODEL_NAME == "cogagent":
    from models.cogagent import CogAgent
    model = CogAgent(**shared_kwargs)
else:
    # Any other name must be one of the InternVL variants.
    from models.internvl import InternVL
    assert MODEL_NAME in InternVL.VARIANTS.keys(), f"Keys are {InternVL.VARIANTS.keys()}"
    model = InternVL(model_name=MODEL_NAME, **shared_kwargs)

  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Initialize Vision Encoder


In [4]:
# Results are stored per model and per history mode:
#   results/<model name, lower-cased>/<history | no_history>/
final_foldername = 'history' if USE_CHAT_HISTORY else 'no_history'
res_folder = Path('results') / MODEL_NAME.lower() / final_foldername
# `exist_ok=True` replaces the manual `is_dir()` check: re-running the cell is a
# no-op, and there is no window between the check and the creation. A
# FileExistsError is still raised if the path exists but is not a directory.
res_folder.mkdir(parents=True, exist_ok=True)

# Instruction placed in front of the questions sent to the model (see the
# inference loop, where it is prepended to the per-frame questions).
initial_prompt = "You are a gamer. You are playing a game tutorial. I will provide you some screenshots of the tutorial. Answer the questions related to the screenshot. Be concise and direct."

In [5]:
def get_img_path(basepath: Path, frame: str) -> Path:
    """Return the path of the PNG screenshot for ``frame`` inside ``basepath``.

    ``<frame>.png`` is tried first; if it is not a file, the variant
    ``<frame>:.png`` (frame label followed by a colon) is used instead.

    Args:
        basepath: Directory that contains the frame images.
        frame: Frame label, without the ``.png`` extension.

    Returns:
        The ``Path`` of the existing image file.

    Raises:
        AssertionError: If neither candidate file exists.
    """
    filepath = basepath / f'{frame}.png'
    if not filepath.is_file():
        # Fall back to the alternative naming with a trailing colon.
        filepath = basepath / f'{frame}:.png'
    assert filepath.is_file(), f'File {str(filepath)} not found!'

    return filepath

In [6]:
# The export suffix depends only on the history flag, so it is computed once.
postfix_export = '_with_history' if USE_CHAT_HISTORY else ''

for revision in ('8832ec1f3d1e27aebefa9228dbbc57474edd94cb', 'last'):
    BASEPATH = Path(f'data/frames/{revision}')
    # Only the first 7 characters of the revision id go into the file name.
    filename = f'{MODEL_NAME.lower()}_{revision[:7]}{postfix_export}'

    # Collapse the labels to one row per frame: every non-"frame" column is
    # gathered into a list, its entries are numbered "1) ...", "2) ...", and the
    # list is joined with newlines into a single string.
    original = (
        pl.read_csv(BASEPATH / "frame_labels.csv")
        .drop_nulls()
        .group_by("frame")
        .agg(pl.col("*"))
        .with_columns(
            pl.exclude("frame")
            .map_elements(
                lambda items: [f"{n}) {q}" for n, q in enumerate(items, start=1)],
                return_dtype=pl.List(pl.Utf8),
            )
            .list.join("\n")
        )
    )

    responses = []
    for tutorial in TUTORIALS:
        # Every tutorial starts from an empty conversation.
        model.clean_history()
        filtered = (
            original
            .filter(pl.col("frame").str.starts_with(str(tutorial)))
            .sort("frame")
        )
        rows = enumerate(filtered.iter_rows())
        for i, (frame_label, questions, expectation) in tqdm(rows, total=filtered.height):
            img_path = get_img_path(BASEPATH, frame_label)

            # With history enabled, only the first frame of a tutorial carries
            # the initial prompt; without history every request carries it.
            if USE_CHAT_HISTORY and i > 0:
                question = questions
            else:
                question = f"{initial_prompt}\n{questions}"
            reply = model.generate(img_path, question)

            if PRINT_QA:
                print(
                    "@" * 10,
                    f"Frame: {frame_label}",
                    f"Question: {questions}",
                    f"Expectation: {expectation}",
                    f"Response: {reply}",
                    "@" * 10,
                    sep="\n",
                )

            resp = {
                "frame": frame_label,
                "question": questions,
                "expectation": expectation,
                "reply": reply,
            }
            responses.append(resp)

            # Without history, each frame is asked in a fresh conversation.
            if not USE_CHAT_HISTORY:
                model.clean_history()

    # One CSV per revision, containing the replies for all tutorials.
    pl.DataFrame(responses).write_csv(res_folder / f"{filename}.csv")

100%|██████████| 9/9 [00:22<00:00,  2.51s/it]
 83%|████████▎ | 5/6 [00:09<00:02,  2.40s/it]

100%|██████████| 6/6 [00:14<00:00,  2.45s/it]
100%|██████████| 6/6 [00:09<00:00,  1.53s/it]
100%|██████████| 7/7 [00:17<00:00,  2.51s/it]
100%|██████████| 10/10 [00:28<00:00,  2.84s/it]
100%|██████████| 8/8 [00:20<00:00,  2.53s/it]
100%|██████████| 5/5 [00:14<00:00,  2.90s/it]
100%|██████████| 11/11 [00:27<00:00,  2.53s/it]
