In [1]:
import os
import json
import logging
import torch
import numpy as np
import cv2
import mediapipe as mp
from transformers import AutoModelForCausalLM, AutoTokenizer, ViTModel
from peft import PeftModel
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image
from moviepy.video.io.VideoFileClip import VideoFileClip
import yt_dlp

In [2]:
# Module logger: DEBUG level so the shape traces inside VALLRModel.forward are visible.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# Guard so re-running this cell does not stack another handler (each duplicate re-emits every record).
if not logger.handlers:
    logger.addHandler(logging.StreamHandler())

In [3]:
# Phoneme vocabulary: index 0 is the blank symbol (dropped when decoding), followed by the
# 39 CMUdict/ARPAbet phonemes without stress markers.
PHONEME_VOCAB = ['<blank>'] + (
    'AA AE AH AO AW AY B CH D DH EH ER EY F G HH IH IY JH K L M N '
    'NG OW OY P R S SH T TH UH UW V W Y Z ZH'
).split()
# Reverse lookup used to turn predicted class indices back into symbols.
IDX_TO_PHONEME = dict(enumerate(PHONEME_VOCAB))

In [4]:
class VALLRModel(torch.nn.Module):
    """Visual phoneme recognizer: per-frame ViT features -> temporal Conv1d adapter -> phoneme logits.

    Each frame is encoded independently by ``google/vit-base-patch16-224``. The patch tokens
    are averaged into one vector per frame. The resulting time series is passed through a
    Conv1d/ReLU/MaxPool1d stack, a LayerNorm and a linear classifier.

    Args:
        hidden_size: channel count fed to the first Conv1d; must equal the ViT hidden size.
        num_phonemes: number of output classes (defaults to ``len(PHONEME_VOCAB)``).
    """

    def __init__(self, hidden_size=768, num_phonemes=len(PHONEME_VOCAB)):
        super().__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224")

        # Layer order/indices must stay exactly as trained so checkpoint keys
        # (adapter.0, adapter.3, adapter.6, adapter.9) still line up.
        channel_plan = [(hidden_size, 384), (384, 192), (192, 48), (48, 16)]
        layers = []
        for stage, (in_ch, out_ch) in enumerate(channel_plan):
            layers.append(torch.nn.Conv1d(in_ch, out_ch, kernel_size=3, padding=1))
            layers.append(torch.nn.ReLU())
            # Every stage except the last halves the time axis.
            if stage < len(channel_plan) - 1:
                layers.append(torch.nn.MaxPool1d(kernel_size=2))
        self.adapter = torch.nn.Sequential(*layers)

        self.adapter_norm = torch.nn.LayerNorm(16)
        self.fc = torch.nn.Linear(16, num_phonemes)

    def forward(self, frames):
        """Map a batch of channels-last frame sequences to per-step phoneme logits.

        Args:
            frames: 5-D tensor ``[batch, num_frames, height, width, channels]``. No pixel
                scaling or normalisation is applied here; that is the caller's job.

        Returns:
            Logits ``[batch, num_phonemes, num_frames // 8]``. The three MaxPool1d(2) stages
            each floor-halve the time axis.

        Raises:
            ValueError: if ``frames`` is not 5-D.
        """
        logger.debug(f"Input frames shape to VALLRModel: {frames.shape}")
        if frames.ndim != 5:
            raise ValueError(f"Expected 5D input tensor [batch_size, num_frames, height, width, channels], got {frames.shape}")
        n_batch, n_time, height, width, channels = frames.shape

        # Channels-last -> channels-first, then fold time into batch so ViT sees independent images.
        frames = frames.permute(0, 1, 4, 2, 3)
        logger.debug(f"Frames shape after permute: {frames.shape}")
        frames = frames.view(-1, channels, height, width)
        logger.debug(f"Frames shape for ViT: {frames.shape}")

        # Drop token 0 (ViT's [CLS] token) and average the remaining patch tokens per frame.
        sequence_output = self.vit(pixel_values=frames).last_hidden_state[:, 1:, :]
        logger.debug(f"ViT output shape: {sequence_output.shape}")
        sequence_output = sequence_output.mean(dim=1)
        logger.debug(f"Pooled output shape: {sequence_output.shape}")

        # Unfold time and move features onto the Conv1d channel axis: [batch, hidden, time].
        sequence_output = sequence_output.view(n_batch, n_time, -1).transpose(1, 2)
        logger.debug(f"Sequence output shape for adapter: {sequence_output.shape}")

        # LayerNorm and the classifier act on the last axis, so work time-major, then return
        # logits as [batch, num_phonemes, seq_len].
        time_major = self.adapter(sequence_output).transpose(1, 2)
        normed = self.adapter_norm(time_major)
        return self.fc(normed).transpose(1, 2)

In [5]:
def detect_face_intervals(video_file):
    """Find the time spans (in seconds) during which a face is visible in a video.

    Frames are scanned sequentially with MediaPipe face detection. An interval opens on the
    first frame with a detection. It closes once no face has been detected for
    ``1.5 * fps`` consecutive frames. The recorded end is the frame where that threshold is
    reached, so it includes the trailing gap. An interval still open at end-of-video is
    closed at the last frame read.

    Args:
        video_file: path to a video readable by OpenCV.

    Returns:
        List of ``{"start": float, "end": float}`` dicts in seconds. Returns ``[]`` if the
        video cannot be read or any error occurs; the error is logged.
    """
    face_intervals = []
    cap = None
    try:
        cap = cv2.VideoCapture(video_file)
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            # OpenCV reports 0 FPS for missing/unreadable files; bail out instead of dividing by zero.
            logger.error(f"Error in detect_face_intervals: invalid FPS ({fps}) for {video_file}")
            return []
        frame_duration = 1 / fps
        # Tolerate ~1.5 s of missed detections before closing an interval, so brief misses don't split it.
        max_missing_frames = fps * 1.5

        face_present = False
        start_time = None
        no_face_frames = 0
        frame_index = 0

        mp_face_detection = mp.solutions.face_detection
        # Context manager closes the MediaPipe graph when done.
        with mp_face_detection.FaceDetection() as face_detection:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes BGR; MediaPipe expects RGB.
                results = face_detection.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                if results.detections:
                    if not face_present:
                        start_time = frame_index * frame_duration
                        face_present = True
                    no_face_frames = 0
                elif face_present:
                    no_face_frames += 1
                    if no_face_frames >= max_missing_frames:
                        face_intervals.append({"start": start_time, "end": frame_index * frame_duration})
                        face_present = False
                frame_index += 1

        if face_present:
            face_intervals.append({"start": start_time, "end": frame_index * frame_duration})
    except Exception as e:
        logger.error(f"Error in detect_face_intervals: {e}")
        return []
    finally:
        # Always release the capture, including on the early-return and error paths.
        if cap is not None:
            cap.release()
    return face_intervals


In [6]:
def _crop_first_face(frame, face_detection, size=224):
    """Return the first detected face in a BGR ``frame`` resized to ``size`` x ``size``, else None."""
    results = face_detection.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    if not results.detections:
        return None
    bbox = results.detections[0].location_data.relative_bounding_box
    h, w = frame.shape[:2]
    x, y = int(bbox.xmin * w), int(bbox.ymin * h)
    width, height = int(bbox.width * w), int(bbox.height * h)
    # Clamp negative coordinates (box partially outside the frame) so slicing doesn't wrap.
    face = frame[max(0, y):y + height, max(0, x):x + width]
    if face.size == 0:
        return None
    return cv2.resize(face, (size, size))


def _fit_to_length(frames_array, target_frames):
    """Uniformly subsample, or edge-pad, ``frames_array`` along axis 0 to exactly ``target_frames``."""
    n = len(frames_array)
    if n > target_frames:
        frames_array = frames_array[np.linspace(0, n - 1, target_frames, dtype=int)]
    elif n < target_frames:
        # Repeat the last frame; only the time axis is padded.
        pad_width = [(0, target_frames - n)] + [(0, 0)] * (frames_array.ndim - 1)
        frames_array = np.pad(frames_array, pad_width, mode="edge")
    return frames_array[:target_frames]


def preprocess_video_to_frames(video_path, target_frames=75):
    """Extract a fixed-length stack of 224x224 face crops from a video.

    Face intervals come from ``detect_face_intervals``. If none are found, the whole video is
    scanned. Within those spans, every frame with a detection contributes a crop of its first
    detected face. The crops are then uniformly subsampled, or padded by repeating the last
    crop, to exactly ``target_frames``.

    NOTE(review): crops keep OpenCV's BGR channel order. Confirm this matches how the
    VALLR checkpoint was trained.

    Args:
        video_path: path to a video readable by OpenCV.
        target_frames: number of frames in the returned stack.

    Returns:
        ``np.ndarray`` of shape ``(target_frames, 224, 224, channels)``. Returns ``None`` if
        no face crop was produced or an error occurred; the error is logged.
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            logger.error(f"Error in preprocess_video_to_frames: invalid FPS ({fps}) for {video_path}")
            return None

        face_intervals = detect_face_intervals(video_path)
        if not face_intervals:
            logger.warning("No face intervals detected, processing all frames")
            face_intervals = [{"start": 0, "end": cap.get(cv2.CAP_PROP_FRAME_COUNT) / fps}]

        frames = []
        with mp.solutions.face_detection.FaceDetection() as face_detection:
            for interval in face_intervals:
                # round(), not int(): times were produced as index * (1 / fps). The float
                # round-trip can land just below an integer (e.g. 9.9999...), which int()
                # would truncate to the previous frame.
                start_frame = round(interval["start"] * fps)
                end_frame = round(interval["end"] * fps)
                cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
                for _ in range(start_frame, end_frame):
                    ret, frame = cap.read()
                    if not ret:
                        break
                    face = _crop_first_face(frame, face_detection)
                    if face is not None:
                        frames.append(face)

        if not frames:
            logger.error("No valid face frames detected")
            return None
        return _fit_to_length(np.array(frames), target_frames)
    except Exception as e:
        logger.error(f"Error in preprocess_video_to_frames: {e}")
        return None
    finally:
        if cap is not None:
            cap.release()


In [7]:
class InferenceDataset(Dataset):
    """Single-item dataset wrapping one clip of face frames for inference.

    Args:
        input_data: path to a video (processed with ``preprocess_video_to_frames``) or, when
            ``is_npz`` is True, a path/file object for an ``.npz`` containing a ``frames`` array.
        is_npz: whether ``input_data`` is an ``.npz`` archive rather than a video.
        transform: optional callable applied to each frame as a PIL image. When given, item 0
            is a stacked ``torch.Tensor``; otherwise it is the raw numpy frame array.

    Raises:
        ValueError: if video preprocessing returned no frames.
    """

    def __init__(self, input_data, is_npz=False, transform=None):
        if is_npz:
            # The context manager closes the NpzFile handle; npz['frames'] is read into memory first.
            with np.load(input_data) as npz:
                self.frames = npz['frames']
        else:
            self.frames = preprocess_video_to_frames(input_data)
        if self.frames is None:
            raise ValueError("Failed to preprocess video or load .npz")
        self.transform = transform

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        # Only one item exists. Raising IndexError for anything else also terminates Python's
        # legacy __getitem__ iteration protocol, which otherwise loops forever.
        if idx not in (0, -1):
            raise IndexError(f"InferenceDataset has a single item; got index {idx}")
        frames = self.frames
        if self.transform:
            frames = torch.stack([self.transform(Image.fromarray(frame)) for frame in frames])
        return frames

In [8]:
def _load_llama(model_dir, device, quantize):
    """Load the local tokenizer, the Llama-3.2-1B-Instruct base model and the local LoRA adapter."""
    logger.debug(f"Resolved model_dir: {model_dir}")

    # Fail early with a clear message rather than deep inside from_pretrained.
    required_files = ['adapter_config.json', 'adapter_model.safetensors', 'tokenizer.json']
    for f in required_files:
        if not os.path.exists(os.path.join(model_dir, f)):
            raise FileNotFoundError(f"Required file {f} not found in {model_dir}")

    tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)

    base_model_name = "meta-llama/Llama-3.2-1B-Instruct"
    logger.debug(f"Loading base model: {base_model_name}")
    quantization_config = None
    if quantize:
        # Imported lazily: 4-bit loading (and bitsandbytes) is only needed on the GPU path.
        # Replaces the deprecated `load_in_4bit=` keyword of from_pretrained.
        from transformers import BitsAndBytesConfig
        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        quantization_config=quantization_config,
        device_map={"": 0} if quantize else "cpu",
        token=os.environ.get("HF_TOKEN"),  # only needed if the base model is gated for this account
    )

    logger.debug(f"Loading LoRA adapters from: {model_dir}")
    model = PeftModel.from_pretrained(model, model_dir, local_files_only=True)
    model.to(device)
    return model, tokenizer


def _load_vallr(checkpoint_path, device):
    """Build a VALLRModel and load its state_dict from a .pth file, stripping DataParallel prefixes."""
    model = VALLRModel()
    # The file is passed straight to load_state_dict, i.e. it is a plain state_dict, so
    # weights_only=True refuses arbitrary pickled objects at no cost.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    prefix = "module."  # added to every key when the model was wrapped in nn.DataParallel
    if any(k.startswith(prefix) for k in state_dict):
        state_dict = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model.to(device)
    return model


def load_model_local(model_dir, model_type="llama", use_gpu=True):
    """Load either the Llama+LoRA phoneme-to-text model or the VALLR visual model.

    Args:
        model_dir: for ``model_type="llama"``, a directory with ``adapter_config.json``,
            ``adapter_model.safetensors`` and ``tokenizer.json``. For any other type, the path
            to the VALLR ``.pth`` checkpoint file.
        model_type: ``"llama"``, or anything else for VALLR.
        use_gpu: use ``cuda:0`` (and 4-bit quantization for Llama) when CUDA is available.

    Returns:
        ``(model, tokenizer)`` for Llama, ``(model, None)`` for VALLR.

    Raises:
        FileNotFoundError: if a required Llama adapter/tokenizer file is missing.
    """
    on_gpu = use_gpu and torch.cuda.is_available()
    device = torch.device("cuda:0" if on_gpu else "cpu")
    logger.info(f"Loading {model_type} model to {device}")
    if model_type == "llama":
        return _load_llama(model_dir, device, quantize=on_gpu)
    return _load_vallr(model_dir, device), None

In [9]:
def predict_phonemes(vallr_model, frames, device, collapse_repeats=False, idx_to_phoneme=None):
    """Greedy-decode a phoneme string from VALLR logits.

    Args:
        vallr_model: module returning logits ``[batch, num_phonemes, seq_len]``; the class
            axis is dim 1.
        frames: model input; moved to ``device`` before the forward pass so it matches the
            model's device.
        device: device the model lives on.
        collapse_repeats: if True, merge consecutive identical predictions before dropping
            blanks (CTC-style greedy decoding). Only appropriate if the model was trained with
            a CTC objective, which is not verifiable here. Default False keeps the original
            "drop blanks only" behaviour.
        idx_to_phoneme: index -> symbol mapping; defaults to the module-level ``IDX_TO_PHONEME``.

    Returns:
        Space-separated phoneme symbols for the first batch item, with index 0 (blank) removed.
    """
    if idx_to_phoneme is None:
        idx_to_phoneme = IDX_TO_PHONEME
    vallr_model.eval()
    # Previously commented out, which made the `device` argument a no-op and broke CUDA models.
    frames = frames.to(device)
    with torch.no_grad():
        logits = vallr_model(frames)
    best_path = torch.argmax(logits, dim=1)[0].cpu().tolist()
    if collapse_repeats:
        best_path = [idx for i, idx in enumerate(best_path) if i == 0 or idx != best_path[i - 1]]
    return " ".join(idx_to_phoneme[idx] for idx in best_path if idx != 0)

In [10]:
def predict_text(llama_model, tokenizer, phoneme_sequence, device):
    """Translate a space-separated phoneme string into text with the fine-tuned causal LM.

    The prompt is ``"<phonemes> ->"``. Generation uses 5-beam search with no repeated
    bigrams, and ``max_length=256`` total tokens (prompt included).

    Returns:
        The decoded output after the first ``->`` (stripped), or the whole decoded output if
        it contains no ``->``.
    """
    llama_model.eval()
    prompt = f"{phoneme_sequence} ->"
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256,
        return_attention_mask=True,
    ).to(device)

    generation_kwargs = dict(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        pad_token_id=tokenizer.pad_token_id,
        max_length=256,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    with torch.no_grad():
        generated_ids = llama_model.generate(**generation_kwargs)

    decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    # The decoded text echoes the prompt; keep only what follows the first separator.
    _, separator, continuation = decoded.partition('->')
    if separator:
        return continuation.strip()
    return decoded

In [11]:
def end_to_end_inference(vallr_model, llama_model, tokenizer, input_data, is_npz=False, use_gpu=True):
    """Run the full pipeline: video (or .npz frames) -> phoneme string -> text.

    Args:
        vallr_model: loaded VALLRModel.
        llama_model, tokenizer: loaded phoneme-to-text LM and its tokenizer.
        input_data: video path, or ``.npz`` path when ``is_npz`` is True.
        is_npz: whether ``input_data`` is an ``.npz`` with a ``frames`` array.
        use_gpu: use ``cuda:0`` when available. Must match how the models were loaded.

    Returns:
        ``(phoneme_sequence, text)``.
    """
    device = torch.device("cuda:0" if use_gpu and torch.cuda.is_available() else "cpu")
    # No transform: VALLRModel.forward expects channels-last frames [B, T, H, W, C], whereas
    # torchvision's ToTensor would yield channels-first [C, H, W] images. Without a transform,
    # the dataset returns a numpy array, so convert it to a tensor before unsqueezing.
    dataset = InferenceDataset(input_data, is_npz=is_npz)
    frames = torch.as_tensor(dataset[0]).unsqueeze(0).to(device)
    phoneme_sequence = predict_phonemes(vallr_model, frames, device)
    text = predict_text(llama_model, tokenizer, phoneme_sequence, device)
    return phoneme_sequence, text

In [12]:
def download_youtube_video_yt_dlp(url):
    """Download a video as a single pre-merged (video + audio) file into the working directory.

    The file is named ``<video id>.<ext>`` by yt-dlp itself. This works for any URL form
    yt-dlp accepts, not just ``watch?v=``, and uses whatever container ``format: best``
    actually selected.

    Args:
        url: video URL.

    Returns:
        Path of the downloaded file, or ``None`` if the download failed (the error is printed).
    """
    ydl_opts = {
        "outtmpl": "%(id)s.%(ext)s",  # yt-dlp fills in the real video ID and file extension
        "format": "best",  # Select the best single file (video + audio)
        "merge_output_format": None,  # Avoid merging, stick to single stream
    }

    try:
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info = ydl.extract_info(url, download=True)
            # The actual on-disk name; the extension may not be .mp4.
            filename = ydl.prepare_filename(info)
        print("Download completed successfully!")
        return filename
    except Exception as e:
        print(f"An error occurred: {e}")
        return None

In [13]:
# Download the source video, cut a short test clip, then delete the full download.
full_video = 'https://www.youtube.com/watch?v=AJsOA4Zl6Io'
CLIP_START_S, CLIP_END_S = 65, 72  # seconds of the source video to keep
chunk_filename = "testing_vid.mp4"

video_file = download_youtube_video_yt_dlp(full_video)
if video_file is None:
    raise RuntimeError(f"Download failed for {full_video}")
# Closing the clip releases its file reader before the source file is deleted.
with VideoFileClip(video_file) as video:
    chunk = video.subclipped(CLIP_START_S, CLIP_END_S)
    chunk.write_videofile(chunk_filename, codec="libx264")
os.remove(video_file)

[youtube] Extracting URL: https://www.youtube.com/watch?v=AJsOA4Zl6Io
[youtube] AJsOA4Zl6Io: Downloading webpage
[youtube] AJsOA4Zl6Io: Downloading tv client config
[youtube] AJsOA4Zl6Io: Downloading player b2858d36-main
[youtube] AJsOA4Zl6Io: Downloading tv player API JSON
[youtube] AJsOA4Zl6Io: Downloading ios player API JSON
[youtube] AJsOA4Zl6Io: Downloading m3u8 information
[info] AJsOA4Zl6Io: Downloading 1 format(s): 18
[download] Destination: AJsOA4Zl6Io.mp4
[download] 100% of    5.41MiB in 00:00:00 at 8.40MiB/s   
Download completed successfully!
{'video_found': True, 'audio_found': True, 'metadata': {'major_brand': 'mp42', 'minor_version': '0', 'compatible_brands': 'isommp42', 'creation_time': '2025-03-04T20:57:47.000000Z', 'encoder': 'Google'}, 'inputs': [{'streams': [{'input_number': 0, 'stream_number': 0, 'stream_type': 'video', 'language': None, 'default': True, 'size': [640, 360], 'bitrate': 270, 'fps': 29.97002997002997, 'codec_name': 'h264', 'profile': '(Main)', 'metada

                                                        

MoviePy - Done.
MoviePy - Writing video testing_vid.mp4



                                                                          

MoviePy - Done !
MoviePy - video ready testing_vid.mp4




In [14]:
# Test clip written by the download cell, and whether CUDA can be used for inference.
input_file = os.path.join(os.getcwd(), "testing_vid.mp4")
use_gpu = torch.cuda.is_available()

In [15]:
# Path to the VALLR .pth checkpoint *file* (not a directory, despite the variable name).
vallr_model_dir = os.path.join(os.getcwd(), "vallr_models", "model.pth")
vallr_model, _ = load_model_local(vallr_model_dir, model_type="vallr", use_gpu=use_gpu)

Loading vallr model to cpu
Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
# Directory holding the LoRA adapter and tokenizer files of the phoneme-to-text model.
llama_model_dir = os.path.join(os.getcwd(), "vallr_models", "checkpoint-29750")
llama_model, tokenizer = load_model_local(llama_model_dir, model_type="llama", use_gpu=use_gpu)

Loading llama model to cpu
Resolved model_dir: /Users/emmettstorts/Documents/slip-ml/inference/vallr_models/checkpoint-29750/
Loading base model: meta-llama/Llama-3.2-1B-Instruct
Loading LoRA adapters from: /Users/emmettstorts/Documents/slip-ml/inference/vallr_models/checkpoint-29750/


In [17]:
# Predict phonemes for the test clip. Frames must be on the device the model was loaded to:
# load_model_local uses cuda:0 whenever use_gpu is True (use_gpu is torch.cuda.is_available()).
vallr_device = torch.device("cuda:0" if use_gpu else "cpu")
dataset = InferenceDataset(input_file, is_npz=False)
frames = torch.tensor(dataset[0]).unsqueeze(0).to(vallr_device)
phoneme_sequence = predict_phonemes(vallr_model, frames, vallr_device)

I0000 00:00:1747348395.837760 11921918 gl_context.cc:369] GL version: 2.1 (2.1 Metal - 89.4), renderer: Apple M1 Pro
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
W0000 00:00:1747348395.840836 11922590 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
I0000 00:00:1747348395.846643 11921918 gl_context.cc:369] GL version: 2.1 (2.1 Metal - 89.4), renderer: Apple M1 Pro
W0000 00:00:1747348395.847907 11922606 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
Input frames shape to VALLRModel: torch.Size([1, 75, 224, 224, 3])
Frames shape after permute: torch.Size([1, 75, 3, 224, 224])
Frames shape for ViT: torch.Size([75, 3, 224, 224])
ViT output shape: torch.Size([75, 196, 768])
Pooled output shape: torch.Size([75, 768])
Sequence output shape for adapter: torch.Size([1, 768, 75])


In [18]:
# Decode phonemes to text on the device the Llama model was loaded to (cuda:0 when use_gpu).
text = predict_text(llama_model, tokenizer, phoneme_sequence, torch.device("cuda:0" if use_gpu else "cpu"))

In [19]:
# predict_text already strips everything up to the prompt's '->'; a second split could cut generated text.
text

"unmoved weddings without doing what you'd want to, travel to the plazma, make more problems out of doing that. what u'd wanted to do, weddings, unmove, move, can't convalve yourself, made more, more. that's it. wow, wow. no, wait, t. w. h. d. unwedded, married, divorced, nobody, no. nobody. wait. welcome, welcome. yes, yes. you, m.m. love, wedding, unmarried,"