In [1]:
import ollama 
import os
from tqdm import tqdm
import json
import signal
import argparse
import wandb
import pandas as pd

import sys

from  utilities import *

In [2]:
from PIL import Image
import matplotlib.pyplot as plt

In [3]:
# # login to wandb and set up a run
# wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mramytrm[0m ([33mvocalwell-vigir[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
# sys.argv = [
#     'notebook',  
#     '--modelname', 'llava:7b',
#     '--prompt', 'Is this an offensive meme? Please answer with YES or NO. DO NOT mention the reason:',
#     '--data', 'hateful_memes',
#     '--data_path','/home1/pupil/goowfd/CVPR_2025/hateful_memes/',
#     "--data_samples", 'train.jsonl',
#     "--subset", 'train',
#     '--results_dir', '/home1/pupil/goowfd/CVPR_2025/hateful_memes/results/baselineExp',
#     '--timeout', '20',
#     '--model_unloading'
# ]


# sys.argv = [
#     'notebook',  
#     '--modelname', 'llava:7b',
#     '--prompt', 'Classify the satellite image into one of the following categories: SeaLake, PermanentCrop, River, Residential, Pasture, Industrial, Highway, HerbaceousVegetation, Forest, or AnnualCrop. Provide only the class name as your answer.',
#     '--data', 'eurosat',
#     '--data_path','/home1/pupil/rmf3mc/eurosat/2750/',
#     "--data_samples", '/home1/pupil/rmf3mc/test_access/ViGIR_CVPR_LLM/data_split/split_zhou_EuroSAT.json',
#     "--subset", 'train',
#     '--results_dir', '/home1/pupil/rmf3mc/eurosat/results/baselineExp',
#     '--timeout', '20',
#     '--model_unloading'
# ]





In [5]:
parser = argparse.ArgumentParser(description="A script to run V-LLMs on different image classification datasets")

In [6]:
parser.add_argument("--modelname", type=str, required=True, help="The name of the V-LLM model")
parser.add_argument("--prompt", type=str, required=True, help="The prompt that you want to give to the V-LLM")
parser.add_argument("--data", type=str, required=True, help="Dataset name")
parser.add_argument("--data_path", type=str, required=True, help="Path to the image data dir")
parser.add_argument("--data_samples", type=str, required=True, help="Name of the samples to run on")
parser.add_argument("--subset", type=str, required=True, help="train, test or validation set")
parser.add_argument("--results_dir", type=str, required=True, help="Folder name to save results")
parser.add_argument("--timeout", type=int, default=40, help="time out duration to skip one sample")
parser.add_argument("--model_unloading", action="store_true", help="Enables unloading mode. Every 100 sampels it unloades the model from the GPU to avoid carshing.")


args = parser.parse_args()

In [7]:
print("Parsed arguments:")
for arg, value in vars(args).items():
    print(f"{arg}: {value}")

Parsed arguments:
modelname: llava:7b
prompt: Classify the satellite image into one of the following categories: SeaLake, PermanentCrop, River, Residential, Pasture, Industrial, Highway, HerbaceousVegetation, Forest, or AnnualCrop. Provide only the class name as your answer.
data: eurosat
data_path: /home1/pupil/rmf3mc/eurosat/2750/
data_samples: /home1/pupil/rmf3mc/test_access/ViGIR_CVPR_LLM/data_split/split_zhou_EuroSAT.json
subset: train
results_dir: /home1/pupil/rmf3mc/eurosat/results/baselineExp
timeout: 20
model_unloading: True


In [8]:
# Set up the run 
run = wandb.init(
    entity="ramytrm",
    project="CVPR-2025",
    name="run_test_" + args.data+"-"+args.modelname+"-"+args.subset
)

[34m[1mwandb[0m: Currently logged in as: [33mramytrm[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
if args.data == 'hateful_memes':
    data ={}
    data_samples_path=args.data_samples
    images_dir=os.path.join(args.data_path,'img')
    
    with open(data_samples_path, 'r') as file:
        for line in file: 
            record = json.loads(line)
            if record['label']==1:
                data[os.path.join(images_dir,f'{record["img"][4:]}')]={"label":record['label'],"class":'hateful_meme'}
            else:
                data[os.path.join(images_dir,f'{record["img"][4:]}')]={"label":record['label'],"class":'not_hateful_meme'}
                


            
elif args.data == 'eurosat':
    data ={}
    data_samples_path=args.data_samples
    print(data_samples_path)
    images_dir=os.path.join(args.data_path)
    print(images_dir)
    
    with open(data_samples_path, 'r') as file:
        raw_data = json.load(file)
    
    raw_data=raw_data[args.subset]
    for sample in raw_data:
        #print(sample)
        data[os.path.join(images_dir,sample[0])]={"label":sample[1],"class":sample[2]}


/home1/pupil/rmf3mc/test_access/ViGIR_CVPR_LLM/data_split/split_zhou_EuroSAT.json
/home1/pupil/rmf3mc/eurosat/2750/


In [10]:
model_name = args.modelname
results_dir=os.path.join(args.results_dir)
os.makedirs(results_dir, exist_ok=True)

In [11]:
results_file_name=os.path.join(results_dir,f"{args.data}-{model_name}-{args.subset}.json")
raw_image_info=os.path.join(results_dir,f"{args.data}-{model_name}-{args.subset}-raw_info.json")

In [12]:
ollama.pull(model_name)

timeout_duration = args.timeout

options= {  # new
            "seed": 123,
            "temperature": 0,
            "num_ctx": 2048, # must be set, otherwise slightly random output
        }

model_labels = {}
prompt = args.prompt
count = 0

In [13]:
for key,info in tqdm(data.items()):
    print(type(key))
    count = count + 1
    image_path = key

    #disp_img(image_path)
    
    signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(timeout_duration)  

    try:
        if args.model_unloading and count % 99 == 0:
            response = ollama.generate(model=model_name, prompt=prompt, images=[image_path], options=options, keep_alive=0)
        else:
            response = ollama.generate(model=model_name, prompt=prompt, images=[image_path], options=options)
    
    except TimeoutException:
        print(f"Prompt for {image_path} took longer than {timeout_duration} seconds. Moving to the next one.")
        label = None
        
    finally:
        signal.alarm(0)  

    model_labels[image_path] = response['response']

    
    

  0%|                                          | 0/13500 [00:00<?, ?it/s]

<class 'str'>


  0%|                                | 1/13500 [00:01<6:18:51,  1.68s/it]

<class 'str'>


  0%|                                | 2/13500 [00:02<3:38:18,  1.03it/s]

<class 'str'>


  0%|                                | 3/13500 [00:02<2:36:03,  1.44it/s]

<class 'str'>


  0%|                                | 4/13500 [00:02<2:13:28,  1.69it/s]

<class 'str'>


  0%|                                | 5/13500 [00:03<2:00:59,  1.86it/s]

<class 'str'>


  0%|                                | 6/13500 [00:03<1:47:20,  2.10it/s]

<class 'str'>


  0%|                                | 7/13500 [00:04<1:47:52,  2.08it/s]

<class 'str'>


  0%|                                | 8/13500 [00:04<1:39:06,  2.27it/s]

<class 'str'>


  0%|                                | 9/13500 [00:05<1:38:40,  2.28it/s]

<class 'str'>


  0%|                               | 10/13500 [00:05<1:39:17,  2.26it/s]

<class 'str'>


  0%|                               | 11/13500 [00:05<1:33:01,  2.42it/s]

<class 'str'>


  0%|                               | 12/13500 [00:06<2:00:29,  1.87it/s]


<class 'str'>


KeyboardInterrupt: 

In [15]:
with open(results_file_name, 'w') as fp:
    json.dump(model_labels, fp, indent=4)
    
with open(raw_image_info, 'w') as fp:
    json.dump(data, fp, indent=4)

In [20]:
#model_labels_df = pd.DataFrame(model_labels, index=[0])
#data_df = pd.DataFrame(data, index=[0])

model_labels_df = pd.DataFrame(model_labels.items(), columns=["File Path", "Response"])
model_labels_df["Image Name"] = model_labels_df["File Path"].apply(lambda x: x.split('/')[-1])
model_labels_wandb = wandb.Table(data=model_labels_df)

data_df = pd.DataFrame.from_dict(data, orient="index").reset_index()
data_df.columns = ["File Path", "Label", "Class"]
data_df_wandb = wandb.Table(data=data_df)

In [21]:
# Log JSON data as a table or dictionary to WandB
#model_labels_df = pd.DataFrame(model_labels, index=[0])
#data_df = pd.DataFrame(data_df, index=[0])
wandb.log({"results": model_labels_wandb})
wandb.log({"raw_images": data_df_wandb})

# Optionally, if you want to save the JSON as an artifact
artifact = wandb.Artifact("json_file", type="dataset")
artifact.add_file(results_file_name)
wandb.log_artifact(artifact)

artifact = wandb.Artifact("json_file", type="dataset")
artifact.add_file(raw_image_info)
wandb.log_artifact(artifact)

<Artifact json_file>