In [3]:
import os
import json 

os.environ["HF_HOME"] = "/scratch/mmm9912/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "/scratch/mmm9912/hf_cache/datasets"
os.environ["XDG_CACHE_HOME"] = "/scratch/mmm9912/hf_cache"

# Verify that the cache paths are set
for var in ["HF_HOME", "HF_DATASETS_CACHE", "XDG_CACHE_HOME"]:
    print(f"{var}: {os.environ.get(var)}")

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

HF_HOME: /scratch/mmm9912/hf_cache
HF_DATASETS_CACHE: /scratch/mmm9912/hf_cache/datasets
XDG_CACHE_HOME: /scratch/mmm9912/hf_cache


In [4]:
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    load_in_8bit=True,          # Enable 8-bit quantization for inference
    device_map="sequential",    # Load layers sequentially
    offload_folder="/scratch/mmm9912/qwen25vl_offload"  # Offload weights to CPU as needed
)


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 5/5 [00:43<00:00,  8.73s/it]


In [7]:
# Also lower the expected visual resolution
min_pixels = 128 * 28 * 28   # Drastically reduced from defaults
max_pixels = 256 * 28 * 28

processor = AutoProcessor.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    min_pixels=min_pixels,
    max_pixels=max_pixels
)


In [84]:
def answer(model, processor, inputs):
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=400,  # Reduced from 1000
            temperature=0.1,      # Added for determinism
            top_p=0.9,
            do_sample=False,      # Use greedy decoding
            num_return_sequences=1
        )
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        
        # Decode the generated tokens into text
        output_text = processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        return output_text[0]


In [110]:
def format_conversation(db_messages):
    """
    Given a list of messages from the database, format them into the proper conversation structure.
    
    Each message is represented as a dict with a role and a list of content blocks.
    The conversation MUST begin with a system message that instructs the assistant to output
    a single valid JSON object with exactly these three keys:
      - "direct_answer_to_frontend": a string that will be sent to the user.
      - "function_call_to_video_processor": an object (with "name" and "arguments") to trigger frame extraction, or null.
      - "function_call_to_extracted_frames_processor": an object (with "name" and "arguments") to trigger expert inference, or null.
    
    Allowed expert inference options:
      - For frequency inconsistencies (SPSL/frequency), use model_name "spsl".
      - For spatial inconsistencies (UCF/spatial), use model_name "ucf".
      - For general inconsistencies (naïve/xception), use model_name "xception".
    
    IMPORTANT:
      - NEVER use your own reasoning to produce a final answer unless no expert analysis is required.
      - Use the video solely for contextual description.
      - If the user's request requires expert analysis (e.g. "What does SPSL say about this video?"), 
        you MUST output a JSON function call (with the appropriate inference option) instead of a final answer.
      - Your output MUST be ONLY a valid JSON object (with no extra text, markdown formatting, code fences, or commentary).
    
    EXACT EXAMPLES (output exactly as shown):
    
    *For frame extraction:*
    {
      "direct_answer_to_frontend": "",
      "function_call_to_video_processor": {
         "name": "extract_frames_from_video",
         "arguments": {
            "video_path": "<path_to_video>",
            "num_frames": 4
         }
      },
      "function_call_to_extracted_frames_processor": null
    }
    
    *For expert inference (e.g. SPSL/frequency):*
    {
      "direct_answer_to_frontend": "",
      "function_call_to_video_processor": null,
      "function_call_to_extracted_frames_processor": {
         "name": "run_inference_on_images_with_old_preprocess",
         "arguments": {
            "model_name": "spsl",
            "image_paths": ["<path1>", "<path2>", "..."],
            "cuda": true,
            "manual_seed": 42
         }
      }
    }
    
    *For a final answer (when no expert analysis is required):*
    {
      "direct_answer_to_frontend": "<your final answer>",
      "function_call_to_video_processor": null,
      "function_call_to_extracted_frames_processor": null
    }
    
    Follow these instructions exactly.
    """
    
    system_message = {
        "role": "system",
        "content": [{
            "type": "text",
            "text": (
                "You MUST respond with EXACTLY one JSON object containing: "
                "'direct_answer_to_frontend', 'function_call_to_video_processor', "
                "and 'function_call_to_extracted_frames_processor'. "
                "NEVER include markdown, additional text, or explanations. "
                "Follow these examples rigidly:\n\n"
                "User asks about models: "
                '{"direct_answer_to_frontend": "", '
                '"function_call_to_video_processor": null, '
                '"function_call_to_extracted_frames_processor": {'
                '"name": "run_inference...", "arguments": {"model_name": "spsl"}}}\n\n'
                "User says 'Hello': "
                '{"direct_answer_to_frontend": "Hello! How can I assist?", '
                '"function_call...": null}'
            )
        }]
    }
    
    formatted = [system_message]
    sorted_messages = sorted(db_messages, key=lambda msg: msg["id"])
    
    for msg in sorted_messages:
        entry = {"role": msg.get("role", "user"), "content": []}
        if msg.get("media_type") in ["image", "video"] and msg.get("media_url"):
            entry["content"].append({
                "type": msg["media_type"],
                "image" if msg["media_type"] == "image" else "video": f"file://{msg['media_url']}"
            })
        text = msg.get("content", "").strip()
        if text:
            entry["content"].append({"type": "text", "text": text})
        formatted.append(entry)
    
    return formatted


In [111]:
import cv2
import uuid

def extract_frames_from_video(video_path, num_frames=4, output_dir="/scratch/mmm9912/Capstone/FRONT_END_STORAGE/images/"):
    """
    Extracts up to `num_frames` equally spaced frames from the provided video.
    If the video has fewer than num_frames, extracts all available frames.
    Saves each extracted frame as a PNG with a random unique filename in output_dir.
    
    Parameters:
        video_path (str): Full path to the video file.
        num_frames (int): Desired number of frames to extract.
        output_dir (str): Directory in which to save the frames.
    
    Returns:
        list: A list of file paths for the extracted frame images.
    """
    os.makedirs(output_dir, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise Exception("Cannot open video: " + video_path)
        
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Determine indices: if fewer frames than desired, use them all.
    if total_frames <= num_frames:
        indices = list(range(total_frames))
    else:
        indices = [int(i * total_frames / num_frames) for i in range(num_frames)]
    
    extracted_paths = []
    current_frame = 0
    ret = True
    while ret:
        ret, frame = cap.read()
        if not ret:
            break
        if current_frame in indices:
            filename = uuid.uuid4().hex + ".png"
            file_path = os.path.join(output_dir, filename)
            cv2.imwrite(file_path, frame)
            extracted_paths.append(file_path)
        current_frame += 1
    cap.release()
    return extracted_paths


In [112]:
def response_formatter(inference_results):
    """
    Given a list of inference_results (each a tuple of (overlay_path, confidence, prediction_message, red_percentage)),
    produce a list of message dictionaries. For each result, a message containing text and the overlay image is produced,
    followed by a final message that summarizes the overall analysis.
    
    Returns:
        list: A list of message dictionaries.
    """
    messages = []
    # For each inference result (assumed to be one per extracted frame)
    for overlay_path, confidence, prediction_message, red_percentage in inference_results:
        text_message = (
            f"Frame analysis: {prediction_message} "
            f"(Confidence: {confidence:.2f}, Red activation: {red_percentage:.2f}%)."
        )
        messages.append({
            "role": "assistant",
            "content": [
                {"type": "text", "text": text_message},
                {"type": "image", "image": f"file://{overlay_path}"}
            ]
        })
    
    # Final conclusive message
    final_conclusion = (
        "Based on the analyses of the extracted frames, the expert model indicates a high likelihood "
        "of deepfake manipulation. Please consider further forensic evaluation."
    )
    messages.append({
        "role": "assistant",
        "content": [
            {"type": "text", "text": final_conclusion}
        ]
    })
    return messages


In [113]:
def run_inference_on_images_with_old_preprocess(model_name, image_paths, cuda, manual_seed):
    """
    Simulated function call to run inference on images with old preprocess.
    In production, this function would run the deepfake detection model.
    
    Parameters:
        model_name (str): One of "spsl", "ucf", or "xception".
        image_paths (list): List of image paths.
        cuda (bool): Whether to use CUDA.
        manual_seed (int): Seed for reproducibility.
    
    Returns:
        list: A list of 4 tuples, each tuple containing:
            - overlay_path (str): Path to the Grad-CAM overlay image.
            - confidence (float): Softmax probability that the image is forged.
            - prediction_message (str): Verdict message from the model.
            - red_percentage (float): Percentage of red pixels in the Grad-CAM heatmap.
    """
    overlay_path = "/scratch/mmm9912/Capstone/FRONT_END_STORAGE/images/ca4227e5f59643179b25ba59c0483b9b.png"
    confidence = 0.75
    prediction_message = f"{model_name.upper()} model detected forgery."
    red_percentage = 10.0
    # Duplicate the dummy result 4 times.
    return [(overlay_path, confidence, prediction_message, red_percentage) for _ in range(4)]


In [114]:
def parse_json_response(response_text):
    # Attempt strict extraction of a JSON object from the response text.
    json_match = re.search(r"(\{.*\})", response_text, re.DOTALL)
    if not json_match:
        raise ValueError("No JSON found in response")
    
    json_str = json_match.group(1)
    try:
        parsed = json.loads(json_str)
        # Validate that all required keys are present.
        assert all(k in parsed for k in [
            "direct_answer_to_frontend",
            "function_call_to_video_processor",
            "function_call_to_extracted_frames_processor"
        ])
        return parsed
    except Exception as e:
        raise ValueError(f"Invalid JSON structure: {str(e)}") from e


# When an invalid JSON is detected, use a retry prompt:
def retry_with_correction(model, processor, conversation_messages):
    retry_prompt = (
        "Your previous response was invalid. You MUST respond with ONLY JSON matching "
        "the required format. Example valid response:\n"
        '{"direct_answer_to_frontend": "", '
        '"function_call_to_video_processor": null, '
        '"function_call_to_extracted_frames_processor": null}'
    )
    correction_message = {
        "role": "system",
        "content": [{
            "type": "text",
            "text": retry_prompt
        }]
    }
    # Append the correction prompt to the conversation history.
    updated_messages = conversation_messages + [correction_message]
    text = processor.apply_chat_template(updated_messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(
        text=[text],
        return_tensors="pt",
        padding=True
    ).to(model.device)
    return answer(model, processor, inputs)


###########################################
# 4. Enhanced Keyword Detection            #
###########################################
def enforce_inference_call(last_user_text, response_json):
    # Define expert keyword groups mapping to inference models.
    expert_keywords = {
        "spsl": ["spsl", "frequency", "spectral"],
        "ucf": ["ucf", "spatial", "artifact"],
        "xception": ["xception", "naïve", "general"]
    }
    frames_processor_call = response_json.get("function_call_to_extracted_frames_processor")
    # If expert keywords are present but no inference function call exists, force one.
    for model_name, keywords in expert_keywords.items():
        if any(kw in last_user_text.lower() for kw in keywords) and frames_processor_call is None:
            return {
                "direct_answer_to_frontend": "",
                "function_call_to_video_processor": None,
                "function_call_to_extracted_frames_processor": {
                    "name": "run_inference_on_images_with_old_preprocess",
                    "arguments": {"model_name": model_name}
                }
            }
    return response_json


###########################################
# 5. Conversation History Cleaning         #
###########################################
def clean_conversation_history(messages):
    """Remove error messages and non-essential system messages."""
    return [msg for msg in messages if not (
        msg.get("role") == "system" and 
        "SAMPLE RESPONSE FORMAT:" in msg.get("content", [{}])[0].get("text", "")
    )]


In [115]:
import os
import sys
import time
import pprint
import json
from supabase_wrapper import get_all_messages, create_conversation, insert_message

node_dir = "/scratch/mmm9912/condaDONOTDELETE/qwen25vl/bin"
os.environ["PATH"] = node_dir + os.pathsep + os.environ["PATH"]
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# If __file__ is not defined (e.g. in a notebook), use the current working directory.
try:
    base_dir = os.path.dirname(__file__)
except NameError:
    base_dir = os.getcwd()


In [116]:
import re

# Get the initial list of messages and determine the current highest message id.
initial_data = get_all_messages()
messages = initial_data.get("messages", [])

highest_id = max((message["id"] for message in messages), default=0)  # Ensure it starts from 0 if no messages exist

print("Initial highest message id:", highest_id)


Initial highest message id: 250


In [None]:
while True:
    data = get_all_messages()
    messages = data.get("messages", [])
    
    if not messages:
        print("No messages found in the database. Waiting...")
        time.sleep(1)
        continue

    current_highest = max(message["id"] for message in messages)
    if current_highest == highest_id:
        time.sleep(1)
        continue

    new_message = next((msg for msg in messages if msg["id"] == current_highest), None)
    if new_message is None:
        print("Error: Highest id message not found. Waiting...")
        time.sleep(1)
        continue

    # Only trigger a cycle if the new message is from a user.
    if new_message.get("role", "").lower() != "user":
        print("New message is not from a user. Skipping cycle.")
        highest_id = current_highest
        time.sleep(1)
        continue

    conversation_id = new_message["conversation_id"]
    conversation_messages = [msg for msg in messages if msg["conversation_id"] == conversation_id]
    conversation_messages_sorted = clean_conversation_history(
        sorted(conversation_messages, key=lambda msg: msg["id"])
    )
    print(f"New message detected in conversation {conversation_id}.")

    # Format the conversation (this prepends the strict system prompt).
    formatted_messages = format_conversation(conversation_messages_sorted)
    
    # Inject a sample response format message immediately after the user's last input.
    if conversation_messages_sorted[-1].get("role", "").lower() == "user":
        sample_response_message = {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": (
                        "SAMPLE RESPONSE FORMAT:\n"
                        "If you receive a video, and ONLY if you receive a video, first perform frame extraction:\n"
                        '{"direct_answer_to_frontend": "", "function_call_to_video_processor": {"name": "extract_frames_from_video", "arguments": {"video_path": "<path_to_video>", "num_frames": 4}}, "function_call_to_extracted_frames_processor": null}\n'
                        "Then send for expert inference (e.g. SPSL/frequency):\n"
                        '{"direct_answer_to_frontend": "", "function_call_to_video_processor": null, "function_call_to_extracted_frames_processor": {"name": "run_inference_on_images_with_old_preprocess", "arguments": {"model_name": "spsl", "image_paths": ["<path1>", "<path2>", "..."], "cuda": true, "manual_seed": 42}}}\n'
                        "For a final answer or remaining conversational. ALWAYS remain conversational when your response is directed at the user:\n"
                        '{"direct_answer_to_frontend": "<your final answer>", "function_call_to_video_processor": null, "function_call_to_extracted_frames_processor": null}\n'
                        "Follow these instructions exactly, while always writing grammatically correct sentences when your response is directed to the user. The user can never see the internal workings of your processes. Now tackle the last user query within the JSON format."
                    )
                }
            ]
        }
        formatted_messages.append(sample_response_message)
        
    print("Formatted messages:", formatted_messages)
    
    # Build the conversation prompt.
    text = processor.apply_chat_template(formatted_messages, tokenize=False, add_generation_prompt=True)
    
    image_inputs, video_inputs = process_vision_info(formatted_messages)
    
    inputs = processor(
        text=[text],
        images=image_inputs if image_inputs else None,
        videos=video_inputs if video_inputs else None,
        return_tensors="pt",
        padding=True
    ).to(model.device)
    
    raw_response_text = answer(model, processor, inputs)
    print("Raw VLM response:")
    print(raw_response_text)
    
    # Attempt to parse the JSON response, with a retry on error.
    try:
        response_json = parse_json_response(raw_response_text)
    except Exception as e:
        print("Error parsing JSON from VLM response:", e)
        # Retry using the correction prompt.
        raw_response_text = retry_with_correction(model, processor, formatted_messages)
        try:
            response_json = parse_json_response(raw_response_text)
        except Exception as inner_e:
            print("Retry failed:", inner_e)
            insert_message(
                conversation_id=conversation_id,
                role="assistant",
                content="An error occurred while processing the response.",
                media_url=None,
                media_type=None
            )
            highest_id = current_highest
            time.sleep(1)
            continue
    
    direct_answer = response_json.get("direct_answer_to_frontend", "")
    video_processor_call = response_json.get("function_call_to_video_processor")
    frames_processor_call = response_json.get("function_call_to_extracted_frames_processor")
    
    # Safely extract the latest user text for keyword checking.
    last_user_text = ""
    for msg in reversed(conversation_messages_sorted):
        if msg.get("role", "").lower() == "user":
            content = msg.get("content", "")
            if isinstance(content, list):
                for block in content:
                    if isinstance(block, dict) and block.get("type") == "text":
                        last_user_text += block.get("text", "")
                    elif isinstance(block, str):
                        last_user_text += block
            elif isinstance(content, str):
                last_user_text += content
            if last_user_text:
                break

    # If expert keywords are detected but no inference call is provided,
    # force the appropriate inference function call.
    expert_keywords = ["spsl", "frequency", "ucf", "spatial", "xception", "naïve"]
    if any(kw in last_user_text.lower() for kw in expert_keywords) and frames_processor_call is None:
        print("Expert analysis query detected without inference call; forcing inference call.")
        response_json = enforce_inference_call(last_user_text, response_json)
        frames_processor_call = response_json.get("function_call_to_extracted_frames_processor")
    
    # Chain 1: Handle video processing call (frame extraction).
    if video_processor_call is not None:
        func_name = video_processor_call.get("name")
        func_args = video_processor_call.get("arguments", {})
        if func_name == "extract_frames_from_video":
            try:
                extracted_frames = extract_frames_from_video(**func_args)
                function_result_message = {
                    "role": "function",
                    "name": func_name,
                    "content": json.dumps({"extracted_frames": extracted_frames})
                }
                formatted_messages.append(function_result_message)
                text = processor.apply_chat_template(formatted_messages, tokenize=False, add_generation_prompt=True)
                inputs = processor(
                    text=[text],
                    images=image_inputs if image_inputs else None,
                    videos=video_inputs if video_inputs else None,
                    return_tensors="pt",
                    padding=True
                ).to(model.device)
                raw_response_text = answer(model, processor, inputs)
                print("VLM response after frame extraction:")
                print(raw_response_text)
                try:
                    response_json = parse_json_response(raw_response_text)
                except Exception as e:
                    raise ValueError("Error after frame extraction: " + str(e))
                direct_answer = response_json.get("direct_answer_to_frontend", "")
                frames_processor_call = response_json.get("function_call_to_extracted_frames_processor")
            except Exception as e:
                print("Error executing extract_frames_from_video:", e)
    
    # Chain 2: Handle inference call (expert analysis using extracted frames).
    if frames_processor_call is not None:
        func_name = frames_processor_call.get("name")
        func_args = frames_processor_call.get("arguments", {})
        if func_name == "run_inference_on_images_with_old_preprocess":
            try:
                inference_results = run_inference_on_images_with_old_preprocess(**func_args)
                formatted_responses = response_formatter(inference_results)
                for resp in formatted_responses:
                    insert_message(
                        conversation_id=conversation_id,
                        role="assistant",
                        content=json.dumps(resp),
                        media_url=None,
                        media_type=None
                    )
                print("Inserted formatted responses from inference results.")
                highest_id = current_highest
                time.sleep(1)
                continue
            except Exception as e:
                print("Error executing run_inference_on_images_with_old_preprocess:", e)
    
    # If no function call is triggered and a direct answer is provided,
    # insert ONLY the string inside "direct_answer_to_frontend" into the database.
    if direct_answer and (video_processor_call is None and frames_processor_call is None):
        insert_message(
            conversation_id=conversation_id,
            role="assistant",
            content=direct_answer,
            media_url=None,
            media_type=None
        )
        print("Inserted direct answer to frontend.")
    
    print("Cycle complete.\n")
    highest_id = current_highest
    time.sleep(1)
