## Setup Environment

In [None]:
import torch

### install packages

In [None]:
# Install TensorRT and the audio/image/video helper packages (takes about 0.5-1 min).
# Only tensorrt and numpy are pinned to exact versions; the other packages resolve to
# whatever pip selects at install time.
!pip install tensorrt==8.6.1 librosa tqdm filetype imageio opencv_python_headless scikit-image cython cuda-python imageio-ffmpeg colored polygraphy numpy==2.0.1

In [None]:
# Install the ngrok helpers. pyngrok is the package the server cell imports to open
# the public tunnel.
# NOTE(review): flask-ngrok is not imported by any visible cell — confirm it is still needed.
!pip install flask-ngrok
!pip install pyngrok

In [None]:
# Refresh the apt package index, then install ffmpeg without prompting (-y).
!apt update
!apt install ffmpeg -y

In [None]:
# Install the cuDNN 8 system library without prompting (-y).
# NOTE(review): presumably needed by the TensorRT runtime installed above — confirm.
!apt install -y libcudnn8

### restart runtime

In [None]:
# !!!
# Restart the runtime so the newly installed environment takes effect.
# Signal 9 (SIGKILL) ends the current kernel process immediately: all in-memory state is
# lost, so run this only after the install cells above have finished, then continue
# from the next cell.
import os
os.kill(os.getpid(), 9)

### Environment Check

In [None]:
import numpy as np
import torch
import tensorrt as trt
# Print the version of each core dependency, in the order numpy, torch, tensorrt,
# to confirm that all of them were installed and can be imported.
for module in (np, torch, trt):
    print(module.__version__)

## Downloading Models

In [None]:
import os
# Work from /content and clone the ditto-talkinghead repository only if it is not
# already present, so the cell can be re-run safely.
%cd /content
if not os.path.isdir("ditto-talkinghead"):
    !git clone https://github.com/antgroup/ditto-talkinghead.git
else:
    print("ditto-talkinghead already cloned.")

# Enter the repository (later cells use paths relative to it), pull the latest
# commits and list its contents.
%cd ditto-talkinghead
!git pull
!ls

In [None]:
# Fetch the model checkpoints from Hugging Face (takes about 1-2 min).
# Uses `os` imported by the previous cell and expects the working directory to be
# the ditto-talkinghead repository root.
!git lfs install
if not os.path.isdir("checkpoints"):
    !git clone https://huggingface.co/digital-avatar/ditto-talkinghead checkpoints

# Enter the checkpoints clone, pull any updates and list its contents.
%cd checkpoints
!git pull
!ls

# Go back to the parent directory and list it.
%cd ..
!ls

### check GPU architecture

In [None]:
# about 1~2min
import os
import torch

def cvt_custom_trt():
    """Run the repository's ONNX-to-TensorRT conversion script for this machine.

    Calls ``scripts.cvt_onnx_to_trt.main`` with the ONNX directory, the output
    directory (created if missing) and the path of the grid-sample plugin
    library located inside the ONNX directory.

    Returns:
        str: The output directory, ``"./checkpoints/ditto_trt_custom"``.

    Raises:
        AssertionError: If ``./checkpoints/ditto_onnx`` is not a directory.
    """
    # Imported inside the function so the cell can define it even when the
    # conversion script is never used.
    from scripts.cvt_onnx_to_trt import main as cvt_trt

    source_dir = "./checkpoints/ditto_onnx"
    engine_dir = "./checkpoints/ditto_trt_custom"

    assert os.path.isdir(source_dir)
    os.makedirs(engine_dir, exist_ok=True)

    plugin_path = os.path.join(source_dir, "libgrid_sample_3d_plugin.so")
    cvt_trt(source_dir, engine_dir, plugin_path)
    return engine_dir


def download_Non_Ampere_trt():
    !pip install --upgrade --no-cache-dir gdown
    !gdown https://drive.google.com/drive/folders/1-1qnqy0D9ICgRh8iNY_22j9ieNRC0-zf?usp=sharing -O ./checkpoints/ditto_trt --folder
    trt_dir = "./checkpoints/ditto_trt"
    return trt_dir


# Choose the TensorRT model directory from the GPU's major compute capability:
# 8 or higher uses the "Ampere_Plus" directory under ./checkpoints, anything lower
# needs files built or downloaded for the older architecture.
if torch.cuda.get_device_capability()[0] >= 8:
    data_root = "./checkpoints/ditto_trt_Ampere_Plus"
else:
    # Converting locally is slow, so the pre-converted files are downloaded instead.
    # To convert on this machine, use:
    # data_root = cvt_custom_trt()    # cvt
    data_root = download_Non_Ampere_trt()

## Inference

In [None]:
# Import the pipeline entry points from the repository's modules. Module names are
# written without the ".py" suffix: "from stream_pipeline_offline.py import ..." makes
# Python look for a submodule named "py" and fails with ModuleNotFoundError.
from stream_pipeline_offline import StreamSDK
from inference import run
# data_root = "./checkpoints/ditto_trt_custom"   # model dir
cfg_pkl = "./checkpoints/ditto_cfg/v0.4_hubert_cfg_trt.pkl"     # cfg pkl
print(data_root)
print(cfg_pkl)
# Build one SDK instance up front; a constructor failure here points to a bad
# data_root or cfg_pkl before the server is started.
SDK = StreamSDK(cfg_pkl, data_root)

### Ngrok Tunneling to connect to Frontend

In [None]:
from flask import Flask, request, send_file
from pyngrok import ngrok
import os
import numpy as np
import librosa

# NGROK SETUP
# The auth token is read from the NGROK_AUTH_TOKEN environment variable so it never has
# to be saved inside the notebook. Alternatively, replace the empty-string fallback
# below with your Ngrok auth token.
ngrok_auth_token = os.environ.get("NGROK_AUTH_TOKEN", "")
if ngrok_auth_token:
    ngrok.set_auth_token(ngrok_auth_token)
else:
    # Do not register an empty token; ngrok then uses whatever it already has configured.
    print("NGROK_AUTH_TOKEN is not set; using the existing ngrok configuration, if any.")
app = Flask(__name__)
# Open a public tunnel to local port 5000, the port the Flask server listens on.
public_url = ngrok.connect(5000)
print(f"🔗 Public URL: {public_url}")

def get_seq_len_from_audio(audio_path, fps=25, target_sr=16000):
    """Return how many video frames are needed to cover an audio file.

    Args:
        audio_path: Path of the audio file, passed to ``librosa.load``.
        fps: Video frame rate in frames per second.
        target_sr: Sample rate passed to ``librosa.load`` as ``sr``.

    Returns:
        int: The audio duration in seconds multiplied by ``fps``, truncated
        towards zero.
    """
    samples, sample_rate = librosa.load(audio_path, sr=target_sr)
    # Duration is the sample count divided by the sample rate that librosa reports.
    return int(len(samples) / sample_rate * fps)

@app.route('/generate', methods=['POST'])
def generate():
    """Generate a video from an uploaded image, audio clip and emotion index.

    Expects a multipart POST with the file fields ``image`` and ``audio`` and
    the form field ``emotion`` (an integer from 0 to 7).

    Returns:
        The generated MP4 as an attachment on success; a 400 response when an
        input is missing or the emotion index is invalid; a 500 response when
        inference raises or no output file is produced.
    """
    num_emotions = 8  # width of the one-hot emotion vector; valid indices are 0..7

    image = request.files.get('image')
    audio = request.files.get('audio')
    emotion_raw = request.form.get('emotion')

    if not image or not audio or emotion_raw is None:
        return "Missing image, audio, or emotion", 400

    try:
        emotion_index = int(emotion_raw)
    except ValueError:
        return "Emotion index must be an integer", 400
    if not (0 <= emotion_index < num_emotions):
        return "Emotion index must be between 0 and 7", 400

    # NOTE(review): these paths are shared by every request, so concurrent requests
    # would overwrite each other's files.
    image_path = "/content/image.png"
    audio_path = "/content/audio.wav"
    output_path = "/content/result.mp4"

    # Remove any video left by an earlier request so a stale result is never
    # returned when this run fails to write a new file.
    try:
        os.remove(output_path)
    except FileNotFoundError:
        pass

    image.save(image_path)
    audio.save(audio_path)

    try:
        # 1. Number of frames, derived from the audio duration
        seq_len = get_seq_len_from_audio(audio_path)

        # 2. One-hot emotion vector, repeated for every frame
        emo_arr = np.zeros((seq_len, num_emotions), dtype=np.float32)
        emo_arr[:, emotion_index] = 1.0

        # 3. Extra keyword arguments forwarded to the pipeline
        more_kwargs = {
            "setup_kwargs": {"emo": emo_arr},
            "run_kwargs": {}
        }

        # 4. Run the model with a fresh SDK instance for this request
        sdk = StreamSDK(cfg_pkl, data_root)
        try:
            run(sdk, audio_path, image_path, output_path, more_kwargs=more_kwargs)
        finally:
            # Drop the reference even when run() raises, so the SDK (and presumably
            # the GPU memory it holds) can be reclaimed.
            del sdk

    except Exception as e:
        return f"Inference failed: {e}", 500

    if not os.path.exists(output_path):
        return "Video not found", 500

    return send_file(output_path, mimetype="video/mp4", as_attachment=True)

# Start Flask's built-in server on port 5000, the same port the ngrok tunnel was
# opened for. This call blocks the cell until the server is stopped.
app.run(port=5000)