# UnQWorkFlow - AI Video Generation Engine

This notebook provides an automated pipeline for generating AI videos using the wan2.1 text-to-video model. 

## Features
- Smart caching of the large model in Google Drive
- Automatic code synchronization with GitHub
- Idempotent design (can be run multiple times without unnecessary re-downloads)
- Efficient dependency management

**IMPORTANT**: To use this notebook, you need at least 15GB of free space in your Google Drive.


## Cell 1: Initial Setup & Configuration

In [None]:
import os
import json
import sys
import shutil
from pathlib import Path
from datetime import datetime
import logging

# Logging setup: INFO and above, each record prefixed with a timestamp and level
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger('unqworkflow')

# --- CONFIGURATION ---
# Every project location is derived from the Drive mount point.
GDRIVE_MOUNT_PATH = '/content/drive'
GDRIVE_PROJECT_PATH = GDRIVE_MOUNT_PATH + '/MyDrive/UnQWorkFlow'
REPO_PATH = GDRIVE_PROJECT_PATH + '/code'
MODEL_PATH = GDRIVE_PROJECT_PATH + '/models/wan2.1'
OUTPUT_PATH = GDRIVE_PROJECT_PATH + '/outputs'
GITHUB_REPO_URL = 'https://github.com/Sandeepgaddam5432/unq-content-flow.git'
# --- END CONFIGURATION ---

# Echo the effective configuration so it can be checked in the cell output
logger.info("Starting UnQWorkFlow Video Generator")
logger.info("Google Drive Mount Path: %s", GDRIVE_MOUNT_PATH)
logger.info("Project Path: %s", GDRIVE_PROJECT_PATH)
logger.info("Repository Path: %s", REPO_PATH)
logger.info("Model Path: %s", MODEL_PATH)
logger.info("Output Path: %s", OUTPUT_PATH)

# Remember when the session began (kept as a formatted string)
session_start_time = f"{datetime.now():%Y-%m-%d %H:%M:%S}"
logger.info("Session started at: %s", session_start_time)

## Cell 2: Mount Google Drive & Create Directory Structure

In [None]:
# Mount Google Drive
# NOTE(review): google.colab is presumably only importable inside a Colab
# runtime, so this cell is not expected to run elsewhere - confirm.
from google.colab import drive
# All project paths in the configuration are located under this mount point.
drive.mount(GDRIVE_MOUNT_PATH)
# Logged as soon as drive.mount() returns without raising.
logger.info("Google Drive mounted successfully.")

# Function to create directories if they don't exist
def ensure_dir_exists(dir_path):
    """Make sure ``dir_path`` exists as a directory, creating parents as needed.

    Args:
        dir_path: Path of the directory to create.

    Raises:
        FileExistsError: If ``dir_path`` exists but is not a directory.
        OSError: If the directory cannot be created.
    """
    # Only used to choose the log message; creation below does not rely on it.
    already_present = os.path.isdir(dir_path)

    # exist_ok=True makes creation safe when another thread or process creates
    # the directory between the check above and this call.
    os.makedirs(dir_path, exist_ok=True)

    if already_present:
        logger.info(f"Directory already exists: {dir_path}")
    else:
        logger.info(f"Created directory: {dir_path}")

# Extra directory that holds one log file per session
LOGS_PATH = GDRIVE_PROJECT_PATH + "/logs"

# Create the project tree (parent first), then the logs directory
for _required_dir in (
    GDRIVE_PROJECT_PATH,
    REPO_PATH,
    MODEL_PATH,
    OUTPUT_PATH,
    LOGS_PATH,
):
    ensure_dir_exists(_required_dir)

# One log file per session, named after the moment this cell runs
session_id = f"{datetime.now():%Y%m%d_%H%M%S}"
session_log_file = LOGS_PATH + "/session_" + session_id + ".log"

# Start the session log with a short header (truncates any existing file)
with open(session_log_file, 'w') as f:
    f.writelines([
        f"UnQWorkFlow Session - {session_start_time}\n",
        f"Google Drive mounted at: {GDRIVE_MOUNT_PATH}\n",
    ])

logger.info(f"Session log created at: {session_log_file}")

## Cell 3: Sync GitHub Repository (Smart Clone/Pull Logic)

In [None]:
import subprocess

def run_command(cmd, cwd=None):
    """Run a command and return its standard output with whitespace stripped.

    Args:
        cmd: Either a command string, which is executed through the shell,
            or a sequence of arguments, which is executed directly without a
            shell so arguments containing spaces or shell metacharacters
            need no quoting.
        cwd: Optional working directory for the child process.

    Returns:
        str: The command's stdout, stripped of leading/trailing whitespace.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status. The captured ``stdout`` and ``stderr`` are available on
            the exception.
    """
    completed = subprocess.run(
        cmd,
        cwd=cwd,
        # Only plain strings go through the shell; argument lists are
        # executed as-is.
        shell=isinstance(cmd, str),
        check=True,
        capture_output=True,
        text=True,
    )
    return completed.stdout.strip()

# Check if repo exists and sync accordingly
import shlex  # used to quote the values interpolated into shell command strings

if os.path.exists(os.path.join(REPO_PATH, '.git')):
    # Repository exists, pull latest changes
    logger.info(f"Repository exists at {REPO_PATH}. Pulling latest changes...")
    try:
        output = run_command("git pull", cwd=REPO_PATH)
        logger.info(f"Git pull output: {output}")
    except subprocess.CalledProcessError as e:
        logger.error(f"Git pull failed: {e}")
        # The exception text only carries the exit status; git's own
        # explanation is in the captured stderr.
        if e.stderr:
            logger.error(f"Git stderr: {e.stderr.strip()}")
        # Handle potential conflicts by discarding local state.
        # NOTE(review): this assumes the remote default branch is 'main'.
        logger.warning("Attempting to reset and pull...")
        run_command("git fetch origin", cwd=REPO_PATH)
        run_command("git reset --hard origin/main", cwd=REPO_PATH)
        logger.info("Repository reset to main branch head")
else:
    # Repository doesn't exist, clone it
    logger.info(f"Repository doesn't exist at {REPO_PATH}. Cloning from {GITHUB_REPO_URL}...")

    # First, ensure the directory is empty or doesn't exist
    if os.path.exists(REPO_PATH):
        # If it exists but is not a git repo, clean it
        shutil.rmtree(REPO_PATH)
        os.makedirs(REPO_PATH)

    try:
        # The command runs through a shell, so quote both values: a path or
        # URL containing spaces or shell metacharacters would otherwise be
        # split or interpreted by the shell.
        clone_command = f"git clone {shlex.quote(GITHUB_REPO_URL)} {shlex.quote(REPO_PATH)}"
        output = run_command(clone_command)
        logger.info(f"Git clone output: {output}")
        logger.info(f"Repository cloned successfully to {REPO_PATH}")
    except subprocess.CalledProcessError as e:
        logger.error(f"Git clone failed: {e}")
        if e.stderr:
            logger.error(f"Git stderr: {e.stderr.strip()}")
        raise

# Verify the repo is properly synced
repo_files = os.listdir(REPO_PATH)
logger.info(f"Repository contains {len(repo_files)} files/directories")
if 'requirements.txt' in repo_files:
    logger.info("Found requirements.txt in the repository")
else:
    logger.warning("requirements.txt not found in the repository!")

## Cell 4: Install Dependencies (Efficiently)

In [None]:
# Function to install dependencies from requirements.txt
def install_requirements(requirements_path):
    """Install the packages listed in a requirements file using pip.

    pip is started as a child process of the running interpreter and its exit
    status is checked, so a failed installation is reported as ``False``
    instead of being mistaken for success.

    Args:
        requirements_path: Path to the requirements file.

    Returns:
        bool: True if pip exited successfully; False if the file is missing
        or the installation failed.
    """
    # Imported here so the function does not depend on other cells' imports.
    import subprocess
    import sys

    if not os.path.exists(requirements_path):
        logger.error(f"Requirements file not found at {requirements_path}")
        return False

    logger.info(f"Installing dependencies from {requirements_path}...")
    try:
        subprocess.run(
            # "-m pip" targets the environment of the running interpreter.
            [sys.executable, "-m", "pip", "install", "-q", "-r", requirements_path],
            check=True,
            capture_output=True,
            text=True,
        )
    except (subprocess.CalledProcessError, OSError) as e:
        logger.error(f"Failed to install dependencies: {e}")
        # pip's own diagnostics are on stderr; OSError has no such attribute.
        pip_stderr = getattr(e, "stderr", None)
        if pip_stderr:
            logger.error(pip_stderr.strip())
        return False

    logger.info("Dependencies installed successfully")
    return True

# Path to requirements.txt in the repo
requirements_path = os.path.join(REPO_PATH, 'requirements.txt')

# Install dependencies
# install_requirements() returns False when the file is missing or it reports
# a failure; that result selects the fallback branch below.
install_success = install_requirements(requirements_path)

if not install_success:
    logger.warning("Installing fallback dependencies...")
    
    # Install core dependencies directly if requirements.txt fails
    # NOTE(review): these are IPython shell escapes; they presumably do not
    # raise when pip exits non-zero, so the message logged below does not by
    # itself prove the packages were installed - confirm.
    !pip install -q torch torchvision torchaudio transformers diffusers accelerate
    !pip install -q opencv-python moviepy ffmpeg-python
    !pip install -q numpy scipy pillow tqdm requests
    
    logger.info("Fallback dependencies installed")

## Cell 5: Smart Model Caching (The Key Optimization)

In [None]:
import time
import requests
from tqdm.notebook import tqdm

# Define the key file that indicates the model is fully downloaded
MODEL_KEY_FILE = os.path.join(MODEL_PATH, 'model.safetensors')

# Function to check if model is already downloaded
def check_model_downloaded():
    """Report whether a usable model copy is already stored under MODEL_PATH.

    The model counts as present when either marker file exists and is
    non-empty:

    * ``MODEL_KEY_FILE`` (a single-file checkpoint), or
    * ``model_index.json`` in MODEL_PATH. A diffusers pipeline saved with
      ``save_pretrained()`` is expected to have this index at the top of its
      directory, with weights in per-component subfolders.

    Returns:
        bool: True if a marker was found (the total size of MODEL_PATH is
        logged), otherwise False.
    """
    marker_files = (
        MODEL_KEY_FILE,
        os.path.join(MODEL_PATH, 'model_index.json'),
    )
    # A zero-byte marker suggests an interrupted write; treat it as missing.
    if not any(os.path.isfile(m) and os.path.getsize(m) > 0 for m in marker_files):
        return False

    # Sum the whole directory: for a pipeline layout the marker itself is tiny.
    total_bytes = 0
    for root, _dirs, files in os.walk(MODEL_PATH):
        for name in files:
            try:
                total_bytes += os.path.getsize(os.path.join(root, name))
            except OSError:
                # Unreadable entries are skipped; the size is informational.
                continue

    file_size_mb = total_bytes / (1024 * 1024)  # Size in MB
    logger.info(f"Model found in Google Drive. Size: {file_size_mb:.2f} MB")
    return True

# Function to download the model using Hugging Face or another source
def download_model():
    """Download the text-to-video pipeline and save it under MODEL_PATH.

    The pipeline is fetched with ``DiffusionPipeline.from_pretrained``, its
    scheduler is replaced by a ``DPMSolverMultistepScheduler`` built from the
    original scheduler's config, and the result is written with
    ``save_pretrained(MODEL_PATH)``.

    Returns:
        bool: True if the pipeline was saved; False if any step raised. The
        failure is logged with its traceback.
    """
    try:
        start_time = time.time()
        logger.info("Starting model download. This might take a while...")

        # For wan2.1 model, we'll use the Hugging Face transformers library
        # This is a placeholder - you should replace with the actual model download code
        import torch
        from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

        # Note: The actual model ID may differ - replace with correct one
        model_id = "wanonly/wan2.1"  # Example - verify the correct model ID

        pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

        # Runtime-only optimisations (e.g. xformers attention) are deliberately
        # not enabled here: they are not needed to save the pipeline, and the
        # call raises when xformers is unavailable, which would abort the
        # download before anything is written.

        # Save the model to our MODEL_PATH
        pipe.save_pretrained(MODEL_PATH)

        # Drop the in-memory copy; later steps reload the model from MODEL_PATH.
        del pipe

        elapsed_minutes = (time.time() - start_time) / 60
        logger.info(f"Model downloaded and saved successfully to {MODEL_PATH}")
        logger.info(f"Download took {elapsed_minutes:.2f} minutes")

        return True
    except Exception as e:
        # logger.exception logs at ERROR level and includes the traceback.
        logger.exception(f"Model download failed: {e}")
        return False

# Model management: download only when no cached copy is found
if not check_model_downloaded():
    logger.info("Model not found in Google Drive. Starting download...")
    success = download_model()

    if not success:
        logger.error("Failed to download the model. Please check logs and try again.")
        # You might want to halt execution here if the model is essential
        # raise Exception("Model download failed, cannot continue.")
    else:
        logger.info("Model download completed successfully.")
else:
    logger.info("Model already exists in Google Drive. Skipping download.")

## Cell 6: The Main Video Generation Function

In [None]:
def generate_video(prompt: str, duration_seconds: int, output_filename: str) -> dict:
    """
    Generates a video using the wan2.1 model.

    The pipeline is loaded from MODEL_PATH on every call, run for
    ``duration_seconds * 30`` frames at 576x320, and the frames are encoded
    to an MP4 file in OUTPUT_PATH. A summary is appended to the session log.

    Args:
        prompt (str): The text prompt for the video.
        duration_seconds (int): The desired duration of the video. Must be a
            positive number.
        output_filename (str): The name for the output video file (e.g., 'space_video_1.mp4').

    Returns:
        dict: On success, ``status`` ("success"), ``video_path``,
        ``generation_time_minutes``, ``prompt`` and ``duration``. On failure,
        ``status`` ("error"), ``message`` and ``prompt``. Exceptions raised
        during generation are caught and reported through the error dict.
    """
    try:
        logger.info(f"Starting video generation for prompt: {prompt}")
        start_time = time.time()

        # Fail early with a clear message instead of deep inside the pipeline
        # (bool is excluded explicitly because it is a subclass of int).
        if (
            isinstance(duration_seconds, bool)
            or not isinstance(duration_seconds, (int, float))
            or duration_seconds <= 0
        ):
            raise ValueError("duration_seconds must be a positive number")

        # Ensure the output directory exists
        ensure_dir_exists(OUTPUT_PATH)

        # Prepare output path
        final_video_path = os.path.join(OUTPUT_PATH, output_filename)

        # 1. Load the wan2.1 model from MODEL_PATH
        logger.info("Loading the model...")
        import torch
        from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

        # Calculate number of frames based on duration (assuming 30fps)
        fps = 30
        num_frames = max(1, int(round(duration_seconds * fps)))

        # Single source of truth for the frame size: the same values must be
        # given to the pipeline and to the video writer.
        frame_width, frame_height = 576, 320

        # Load the pipeline
        pipe = DiffusionPipeline.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.float16
        )

        # Move the model to GPU if available
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")
        pipe = pipe.to(device)

        # Set up the scheduler for better quality
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

        # Optimize memory usage if possible. This is best-effort: the call
        # raises when xformers is unavailable, and generation should then
        # continue without it.
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as xformers_error:
            logger.warning(f"xformers attention not enabled: {xformers_error}")

        # 2. Execute the text-to-video generation process
        logger.info("Generating video frames...")

        # Generate the video frames (note: actual API might differ)
        # NOTE(review): depending on the pipeline, `.frames` may be a batch of
        # videos rather than a flat list of frames - confirm for the model
        # actually used.
        result = pipe(
            prompt=prompt,
            num_frames=num_frames,
            guidance_scale=7.5,  # You may tune this parameter
            num_inference_steps=50,  # You may tune this parameter
            height=frame_height,  # Adjust based on model capabilities
            width=frame_width     # Adjust based on model capabilities
        ).frames

        # 3. Save the final video to the OUTPUT_PATH
        logger.info(f"Processing and saving video to {final_video_path}...")

        # Convert frames to video using opencv
        import cv2
        import numpy as np

        # Create a VideoWriter object
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(final_video_path, fourcc, fps, (frame_width, frame_height))
        try:
            # Without this check a writer that failed to open would drop
            # every frame and the function would still report success.
            if not out.isOpened():
                raise IOError(f"Could not open video writer for {final_video_path}")

            # Write all frames to the video file
            for frame in result:
                # Convert PIL Image to OpenCV format
                frame_np = np.array(frame)
                # Convert RGB to BGR (OpenCV uses BGR)
                frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
                out.write(frame_bgr)
        finally:
            # Always release the writer, also when a frame fails to convert.
            out.release()

        end_time = time.time()
        generation_time = (end_time - start_time) / 60  # time in minutes

        logger.info(f"Video generated successfully in {generation_time:.2f} minutes: {final_video_path}")

        # Record the generation in the session log. The video already exists
        # at this point, so a log write failure is only a warning.
        try:
            with open(session_log_file, 'a') as f:
                f.write(f"\nVideo generated at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write(f"Prompt: {prompt}\n")
                f.write(f"Duration: {duration_seconds} seconds\n")
                f.write(f"Output file: {output_filename}\n")
                f.write(f"Generation time: {generation_time:.2f} minutes\n")
        except OSError as log_error:
            logger.warning(f"Could not update session log: {log_error}")

        return {
            "status": "success",
            "video_path": final_video_path,
            "generation_time_minutes": generation_time,
            "prompt": prompt,
            "duration": duration_seconds
        }

    except Exception as e:
        # logger.exception logs at ERROR level and includes the traceback.
        logger.exception(f"Error during video generation: {e}")

        return {
            "status": "error",
            "message": str(e),
            "prompt": prompt
        }

# Example of how to call the function (uncomment to test)
# result = generate_video(
#     prompt="A cinematic shot of a rocket launching into space, with flames and smoke billowing from the engines",
#     duration_seconds=10,
#     output_filename="rocket_launch_demo.mp4"
# )
# print(json.dumps(result, indent=2))  # Print result as a formatted JSON string

## Cell 7: API Integration Example

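This cell provides a simple interactive widget UI for testing video generation from inside the notebook: enter a prompt, choose a duration and an output filename, then click the button to call `generate_video`. The HTTP API that the UnQWorkFlow application calls is set up in the backend server section at the end of the notebook.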

In [None]:
from IPython.display import HTML, display
import ipywidgets as widgets
import threading

# Create a simple UI for testing the video generation
# Free-text prompt; its current value is read when the button is clicked.
prompt_input = widgets.Textarea(
    value='A cinematic shot of a rocket launching into space',
    placeholder='Enter your video prompt here',
    description='Prompt:',
    disabled=False,
    layout=widgets.Layout(width='90%', height='80px')
)

# Requested duration in whole seconds, limited to 3-30 by the slider.
duration_input = widgets.IntSlider(
    value=5,
    min=3,
    max=30,
    step=1,
    description='Duration (s):',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True
)

# Name of the file to write; passed to generate_video as output_filename.
filename_input = widgets.Text(
    value='generated_video.mp4',
    placeholder='output.mp4',
    description='Filename:',
    disabled=False
)

# Output area for messages and the preview, the trigger button, and a
# progress bar on a 0-100 scale.
output = widgets.Output()
button = widgets.Button(description="Generate Video")
progress = widgets.FloatProgress(value=0, min=0, max=100, description='Processing:')

def on_button_click(b):
    """Run one video generation from the widget values and show the result.

    Clears the output area, shows a simulated progress bar that advances in a
    background thread, calls ``generate_video`` with the current prompt,
    duration and filename, then prints the outcome and embeds a preview of
    the generated file.

    Args:
        b: The button widget that was clicked. It is disabled while the
            generation is running.
    """
    import base64

    output.clear_output()
    with output:
        progress.value = 0
        display(progress)

        # Set once generation ends so the progress thread stops promptly.
        generation_done = threading.Event()

        # Start a thread to show progress
        def update_progress():
            import random
            # Simulate progress: the bar is not driven by the real pipeline.
            # wait() doubles as the 0.5 s tick and the stop signal.
            while progress.value < 99 and not generation_done.wait(0.5):
                # Randomly increment progress to simulate processing
                if progress.value < 30:
                    progress.value += random.uniform(0.5, 2)
                elif progress.value < 70:
                    progress.value += random.uniform(0.1, 0.8)
                else:
                    progress.value += random.uniform(0.05, 0.2)

        progress_thread = threading.Thread(target=update_progress, daemon=True)
        progress_thread.start()

        # Call the actual generation function
        print(f"Generating video for prompt: {prompt_input.value}")
        print(f"Duration: {duration_input.value} seconds")
        print(f"Filename: {filename_input.value}")
        print("\nThis may take several minutes. Please wait...")

        # Prevent a second click from queueing another generation.
        b.disabled = True
        try:
            result = generate_video(
                prompt=prompt_input.value,
                duration_seconds=duration_input.value,
                output_filename=filename_input.value
            )
        finally:
            generation_done.set()
            b.disabled = False

        progress.value = 100

        if result["status"] == "success":
            print(f"\n✅ Video generated successfully!")
            print(f"📂 Path: {result['video_path']}")
            print(f"⏱️ Generation time: {result['generation_time_minutes']:.2f} minutes")

            # Display the video if possible
            try:
                # The path on the runtime's filesystem is not a URL the
                # browser can load, so the file is embedded as a base64 data
                # URI. The whole file ends up in the cell output.
                # NOTE(review): the file is written with the 'mp4v' codec;
                # some browsers may not play it - confirm.
                with open(result['video_path'], 'rb') as video_file:
                    encoded_video = base64.b64encode(video_file.read()).decode('ascii')
                display(HTML(f"""
                <div style="border:2px solid #ddd; padding:10px; border-radius:10px; margin-top:20px;">
                  <h3 style="margin-top:0">Generated Video Preview</h3>
                  <video width="100%" height="auto" controls>
                    <source src="data:video/mp4;base64,{encoded_video}" type="video/mp4">
                    Your browser does not support the video tag.
                  </video>
                </div>
                """))
            except Exception as e:
                print(f"Could not display video preview: {e}")

        else:
            print(f"\n❌ Error: {result['message']}")

# Run on_button_click every time the button is pressed.
button.on_click(on_button_click)

# Display the UI
# Order on screen: title, prompt box, duration slider and filename field side
# by side, the button, then the output area that receives messages and the
# preview.
display(widgets.HTML("<h2>UnQWorkFlow Video Generator</h2>"))
display(prompt_input)
display(widgets.HBox([duration_input, filename_input]))
display(button)
display(output)

## Cell 8: Session Summary and Cleanup

This cell provides a summary of the current session and handles any necessary cleanup.

In [None]:
# Calculate the session duration
# session_start_time is stored as a formatted string, so parse it back first.
session_end_time = datetime.now()
session_start_time_dt = datetime.strptime(session_start_time, '%Y-%m-%d %H:%M:%S')
session_duration = session_end_time - session_start_time_dt

logger.info("\nSession Summary:")
logger.info(f"Session started at: {session_start_time}")
logger.info(f"Session ended at: {session_end_time.strftime('%Y-%m-%d %H:%M:%S')}")
logger.info(f"Session duration: {session_duration}")

# Write session summary to the log file
with open(session_log_file, 'a') as f:
    f.write(f"\nSession ended at: {session_end_time.strftime('%Y-%m-%d %H:%M:%S')}\n")
    f.write(f"Session duration: {session_duration}\n")

# Display session info
print("\nSession Information:")
print(f"Session ID: {session_id}")
print(f"Duration: {session_duration}")
print(f"Log file: {session_log_file}")

# Optional cleanup to free memory
import gc
gc.collect()
# torch is only imported inside functions elsewhere in this notebook, so it
# has to be imported here before it can be used at cell level.
try:
    import torch
except ImportError:
    logger.info("torch is not installed; skipping CUDA cache cleanup")
else:
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("\nUnQWorkFlow video generation session completed.")
print("The notebook is now ready for the next generation task.")

## Cell 9: Launch the Backend Server & Cloudflare Tunnel

This section sets up a Flask web server and exposes it to the internet using a Cloudflare Tunnel. This allows the frontend application to communicate directly with this notebook in real-time.

In [None]:
# Cell: Download Cloudflared
# Downloads the latest linux-amd64 cloudflared release binary into the current
# working directory as ./cloudflared and marks it executable.
# NOTE(review): the "latest" release is unpinned and the download is not
# checksum-verified here.
!wget https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64 -O cloudflared
!chmod +x cloudflared
print("🚀 Cloudflared client downloaded.")

In [None]:
# Cell: Create and Run Flask API Server
from flask import Flask, request, jsonify
import threading
import subprocess
import re
import time
from flask_cors import CORS

# --- 1. Define the Flask App ---
# Application object that the route decorators below register on.
app = Flask(__name__)
# NOTE(review): CORS is applied with flask_cors defaults, which presumably
# allow any origin, and no route performs authentication - confirm this is
# acceptable for a publicly tunnelled server.
CORS(app)  # Enable CORS for all routes

@app.route('/api/generate-video', methods=['POST'])
def handle_video_generation():
    """Handle POST /api/generate-video: validate the request and run a generation.

    Expects a JSON object with a non-empty string ``prompt`` and an optional
    positive numeric ``duration`` in seconds (default 15). The request blocks
    until ``generate_video`` returns.

    Returns:
        A JSON response containing the dict returned by ``generate_video``,
        plus ``video_url`` on success; or a JSON error with status 400 when
        the request body is invalid.
    """
    import uuid

    print("Received request at /api/generate-video")
    # silent=True yields None for a missing or malformed JSON body, so the
    # client gets the JSON error below rather than a framework error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'prompt' not in data:
        return jsonify({"status": "error", "message": "Missing 'prompt' in request"}), 400

    prompt = data.get('prompt')
    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({"status": "error", "message": "'prompt' must be a non-empty string"}), 400

    duration = data.get('duration', 15) # Default duration 15s
    # bool is excluded explicitly because it is a subclass of int.
    # NOTE(review): there is no upper bound here, although the notebook's own
    # UI slider limits the duration to 30 seconds - consider enforcing one.
    if isinstance(duration, bool) or not isinstance(duration, (int, float)) or duration <= 0:
        return jsonify({"status": "error", "message": "'duration' must be a positive number"}), 400

    # A timestamp alone collides when two requests arrive within the same
    # second; the random suffix keeps their output files apart.
    job_id = f"job_{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_filename = f"{job_id}.mp4"

    # Call our existing video generation function
    result = generate_video(prompt=prompt, duration_seconds=duration, output_filename=output_filename)

    # Make the output file path accessible via a relative URL
    if result["status"] == "success":
        # Add a URL that can be accessed through the Flask server
        result["video_url"] = f"/api/videos/{output_filename}"

    return jsonify(result)

@app.route('/api/videos/<filename>', methods=['GET'])
def serve_video(filename):
    """Handle GET /api/videos/<filename>: look up a generated video by name.

    Args:
        filename: Bare file name of a video inside OUTPUT_PATH.

    Returns:
        A JSON response with ``video_path`` when the file exists; a JSON
        error with status 400 for an invalid name or 404 when no such file
        exists.
    """
    # Accept only a bare file name. '.' and '..' contain no separator but
    # still resolve to directories outside the intended set of files.
    if filename in ('', '.', '..') or '/' in filename or '\\' in filename:
        return jsonify({"error": "Invalid filename"}), 400

    video_path = os.path.join(OUTPUT_PATH, filename)
    # isfile() rather than exists(): a directory is not a video.
    if not os.path.isfile(video_path):
        return jsonify({"error": "Video not found"}), 404

    # In a real production environment, you'd use send_file here
    # For the Colab environment, we'll return the full path for now
    # NOTE(review): the response carries the server-side path, not the video
    # bytes, so a remote client cannot play the video from this endpoint.
    return jsonify({"video_path": video_path})

@app.route('/api/health', methods=['GET'])
def health_check():
    """Handle GET /api/health: report that the server is up.

    Returns:
        A JSON response with ``status``, the current ``timestamp`` in ISO
        format, the ``model`` name and the compute ``device`` ("cuda", "cpu",
        or "unknown" when torch cannot be imported).
    """
    # torch is only imported inside functions elsewhere in this notebook, so
    # it is imported here instead of relying on a module-level name.
    try:
        import torch
        device = "cuda" if torch.cuda.is_available() else "cpu"
    except ImportError:
        device = "unknown"

    return jsonify({
        "status": "online",
        "timestamp": datetime.now().isoformat(),
        "model": "wan2.1",
        "device": device
    })

# --- 2. Function to Run Flask App ---
def run_flask():
    """Serve ``app`` with Flask's built-in server; blocks until it stops."""
    # Listens on port 5000 on every interface of the runtime (0.0.0.0).
    app.run(host='0.0.0.0', port=5000)

# --- 3. Function to Run Cloudflare Tunnel ---
def run_cloudflared():
    """Run a cloudflared quick tunnel to the local server and announce its URL.

    Starts ``./cloudflared tunnel --url http://localhost:5000``, scans its
    stderr for the public ``trycloudflare.com`` address and prints it once.
    The function keeps reading cloudflared's output and returns only when the
    process exits, so a thread running it stays alive for as long as the
    tunnel does.
    """
    url_pattern = re.compile(r'https?://[a-zA-Z0-9-]+\.trycloudflare\.com')

    try:
        process = subprocess.Popen(
            ['./cloudflared', 'tunnel', '--url', 'http://localhost:5000'],
            # stdout is never inspected; discarding it avoids an unread pipe
            # filling up and blocking the child process.
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,
            text=True
        )
    except OSError as e:
        # e.g. the binary was not downloaded or is not executable.
        print(f"Could not start cloudflared: {e}")
        return

    # The context manager closes the pipe and waits for the process on exit.
    with process:
        url_announced = False
        # Keep consuming stderr after the URL was found: cloudflared keeps
        # logging, and an unread pipe would eventually block it.
        for line in process.stderr:
            if url_announced:
                continue
            match = url_pattern.search(line)
            if match is None:
                continue
            public_url = match.group(0)
            print("=====================================================================================")
            print(f"🚀 Your Public Backend URL is LIVE: {public_url}")
            print("COPY THIS URL and PASTE it into the UnQWorkFlow website.")
            print("=====================================================================================")
            url_announced = True

# --- 4. Start Both in Threads ---
# Both workers are daemon threads, so they never keep the interpreter alive.
print("Starting Flask server in the background...")
flask_thread = threading.Thread(target=run_flask, daemon=True)
flask_thread.start()

print("Starting Cloudflare Tunnel...")
time.sleep(2)  # Give Flask a moment to start
cloudflared_thread = threading.Thread(target=run_cloudflared, daemon=True)
cloudflared_thread.start()

# Hold this cell until the run_cloudflared thread returns
print("Waiting for Cloudflare tunnel to be established...")
cloudflared_thread.join()

# Reached once the run_cloudflared thread has finished
print("\n⚠️ The Cloudflare tunnel has terminated. Please restart the notebook if you need to reconnect.")