In [None]:
# Install the Python packages this notebook needs on top of the base runtime.
# NOTE(review): versions are unpinned, so a later re-run may pick up different releases.
!pip install -q transformers librosa wandb
# Fetch the Wav2Lip source tree; stderr is hidden and a failed clone
# (e.g. the directory already exists) is deliberately ignored.
!git clone https://github.com/Rudrabha/Wav2Lip.git 2>/dev/null || true
!mkdir -p Wav2Lip/checkpoints

# Download the pretrained Wav2Lip generator and lip-sync expert checkpoints.
# The quoted arguments are presumably Google Drive file IDs — update them if the files move.
# stderr is discarded, so on failure only the "Download ... manually" hint is printed.
# NOTE(review): gdown is not in the pip install above — presumably preinstalled (e.g. on Colab); confirm.
!gdown "1qjKBoCnOQ5IrjeirFtgDAaDM4My4A9Ci" -O Wav2Lip/checkpoints/wav2lip.pth 2>/dev/null || echo "Download wav2lip.pth manually"
!gdown "14-EvGLdMQqPPZf7g5wTMdZ4TrPBdRE3C" -O Wav2Lip/checkpoints/lipsync_expert.pth 2>/dev/null || echo "Download lipsync_expert.pth manually"

In [None]:
import sys
sys.path.insert(0, "/content")
sys.path.insert(0, "/content/Wav2Lip")

import gc
import json
import warnings
from pathlib import Path

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import wandb
from torch.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from emotion_utils import (
    CrossModalEmotionLoss,
    DifferentiableVideoPreprocess,
    EmotionAgreementMetric,
    load_frozen_audio_encoder,
    load_frozen_video_encoder,
    extract_audio_embedding,
    extract_video_embedding,
)
from models.wav2lip import Wav2Lip as Wav2LipModel

# NOTE(review): this silences every warning category for the rest of the session,
# including deprecation and numerical warnings raised by the libraries imported above.
warnings.filterwarnings("ignore")

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Input locations: dataset metadata, the pretrained generator checkpoint, and the
# two emotion encoders passed to load_frozen_audio_encoder / load_frozen_video_encoder.
# NOTE(review): absolute /content paths — presumably a Colab runtime; confirm before running elsewhere.
METADATA = "/content/processed_data/metadata.json"
WAV2LIP_CKPT = "/content/Wav2Lip/checkpoints/wav2lip.pth"
BEST_AUDIO_PATH = "/content/trained_encoders_v2/w2v2-lg-lr2e5"
BEST_VIDEO_PATH = "/content/trained_encoders_v2/tsf-lr3e5-16f-nf"
# Root directory for fine-tuned checkpoints; created eagerly so later saves cannot fail on a missing parent.
OUT_DIR = Path("/content/wav2lip_finetuned")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Samples whose "emotion_idx" is in EXCLUDE are dropped by Wav2LipDataset;
# the indices that remain are renumbered through REMAP before being returned as labels.
EXCLUDE = {0, 1, 3, 5, 7}
REMAP = {2: 0, 4: 1, 6: 2}
# Presumably the class names for remapped labels 0, 1, 2 — TODO confirm against the metadata's label scheme.
EMOTIONS = ["happy", "angry", "disgust"]

print(f"Device: {DEVICE}")

In [None]:
# Side length (pixels) that face crops are resized to before being fed to the generator.
IMG_SIZE = 96
# Number of mel-spectrogram columns paired with each video frame.
MEL_STEP = 16
# Sample rate (Hz) the audio is resampled to when computing the mel spectrogram.
SR = 16000
# Video frame rate assumed when converting a frame index into an audio position.
FPS = 25

def wav_to_mel(wav_path, sr=SR):
    """Load an audio file and return a Wav2Lip-style normalised mel spectrogram.

    The pretrained Wav2Lip generator was trained on mel features produced by a
    specific recipe (pre-emphasis, 800-point STFT with hop 200, 80 mel bands
    between 55 Hz and 7600 Hz, amplitude in dB, rescaled and clipped to
    [-4, 4]). Feeding it a differently scaled spectrogram (e.g. power dB
    relative to the clip maximum, range [-80, 0]) puts the audio branch far
    outside the distribution its weights expect, so this function reproduces
    that recipe.

    Args:
        wav_path: Path of the audio file to load.
        sr: Sample rate to resample to while loading.

    Returns:
        ``np.float32`` array of shape ``(80, n_columns)`` with values in
        ``[-4, 4]``; one column per 200 samples.
    """
    y, _ = librosa.load(wav_path, sr=sr)

    # Pre-emphasis filter y[n] - 0.97 * y[n-1]; the first sample passes through unchanged.
    if y.size > 1:
        y = np.append(y[:1], y[1:] - 0.97 * y[:-1])

    # power=1.0 -> mel filterbank applied to the magnitude (not power) spectrogram.
    # n_fft is set explicitly: librosa's default (2048) would not match win_length=800.
    mel = librosa.feature.melspectrogram(
        y=y, sr=sr, n_fft=800, hop_length=200, win_length=800,
        n_mels=80, fmin=55, fmax=7600, power=1.0)

    min_level_db, ref_level_db, max_abs = -100.0, 20.0, 4.0
    # Floor the amplitude at min_level_db so log10 never sees zero.
    amp_floor = 10.0 ** (min_level_db / 20.0)
    db = 20.0 * np.log10(np.maximum(amp_floor, mel)) - ref_level_db

    # Map [min_level_db, 0] dB linearly onto [-max_abs, +max_abs] and clip.
    scaled = 2.0 * max_abs * ((db - min_level_db) / -min_level_db) - max_abs
    return np.clip(scaled, -max_abs, max_abs).astype(np.float32)


class Wav2LipDataset(Dataset):
    """Random ``T``-frame training windows for Wav2Lip fine-tuning.

    The metadata file is a JSON list of records; the fields read here are
    ``split``, ``emotion_idx``, ``audio_path`` and ``frames_path``. Records
    whose ``emotion_idx`` is in ``EXCLUDE`` are dropped, and the remaining
    indices are renumbered through ``REMAP``.

    Each item is a dict with:
        ``mel``         (T, 1, n_mels, MEL_STEP) mel chunk per frame
        ``face_input``  (T, 6, IMG_SIZE, IMG_SIZE): lower-half-masked target
                        frames in channels 0-2, reference frame in channels 3-5
        ``gt``          (T, 3, IMG_SIZE, IMG_SIZE) target frames
        ``audio``       1-D waveform of the whole clip
        ``emotion``     remapped integer label

    Assumptions to confirm against the preprocessing step:
        * ``frames_path`` holds an array shaped (n_frames, H, W, 3) with
          values in 0..255 (implied by the ``/ 255.0`` and the permutes).
        * NOTE(review): the original Wav2Lip pipeline reads frames with
          OpenCV (BGR); verify the channel order stored on disk matches.
    """

    # Hop length (samples) used by wav_to_mel; converts frame times to mel columns.
    MEL_HOP = 200

    def __init__(self, metadata_path, split, T=5):
        """Read the metadata and keep the records of ``split`` with a retained emotion."""
        with open(metadata_path) as f:
            data = json.load(f)
        self.samples = [s for s in data
                        if s["split"] == split and s["emotion_idx"] not in EXCLUDE]
        self.T = T

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _mel_chunks(mel, first_frame, n_chunks, fps, sr, hop, step):
        """Cut one ``step``-column mel chunk per video frame, starting at that frame's timestamp.

        Returns an array shaped (n_chunks, 1, n_mels, step). Chunks that run
        past the end of ``mel`` are padded with the spectrogram's minimum
        (its quietest value); padding with 0 would inject peak-level energy
        on a dB scale whose maximum is 0.
        """
        pad_value = float(mel.min()) if mel.size else 0.0
        chunks = []
        for t in range(n_chunks):
            # Integer arithmetic avoids float round-off at exact column boundaries.
            col = int((first_frame + t) * sr // (fps * hop))
            chunk = mel[:, col:col + step]
            missing = step - chunk.shape[1]
            if missing > 0:
                chunk = np.pad(chunk, ((0, 0), (0, missing)), constant_values=pad_value)
            chunks.append(chunk)
        return np.stack(chunks, axis=0)[:, None]

    @staticmethod
    def _to_model_size(images):
        """Bilinearly resize an (N, C, H, W) batch to IMG_SIZE x IMG_SIZE if needed."""
        if images.shape[2] != IMG_SIZE or images.shape[3] != IMG_SIZE:
            images = F.interpolate(images, size=(IMG_SIZE, IMG_SIZE),
                                   mode="bilinear", align_corners=False)
        return images

    def __getitem__(self, idx):
        """Build one training sample; see the class docstring for the returned keys.

        Raises:
            ValueError: if the clip's frame array is empty.
        """
        s = self.samples[idx]

        wav, sr = torchaudio.load(s["audio_path"])
        # Down-mix multi-channel audio so the result is always a 1-D waveform.
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)
        # The embedding helper receives no sample rate, so hand it audio at the
        # notebook's rate. NOTE(review): assumes the audio encoder expects SR Hz — confirm.
        if sr != SR:
            wav = torchaudio.functional.resample(wav, sr, SR)
        audio_1d = wav.squeeze(0)

        mel = wav_to_mel(s["audio_path"])

        frames = np.load(s["frames_path"]).astype(np.float32) / 255.0
        n_frames = frames.shape[0]
        if n_frames == 0:
            raise ValueError(f"No frames found in {s['frames_path']}")

        # randint's upper bound is exclusive, hence +1 so the last full window can be drawn.
        start = np.random.randint(0, max(1, n_frames - self.T + 1))
        face_window = frames[start:start + self.T]
        if face_window.shape[0] < self.T:
            # Clip shorter than T frames: repeat the last frame to fill the window.
            pad = np.repeat(face_window[-1:], self.T - face_window.shape[0], axis=0)
            face_window = np.concatenate([face_window, pad], axis=0)

        # One chunk per frame, each starting at that frame's own timestamp. Stepping
        # MEL_STEP columns per frame would advance the audio 0.2 s for every 0.04 s of video.
        mel_tensor = torch.from_numpy(
            self._mel_chunks(mel, start, self.T, FPS, SR, self.MEL_HOP, MEL_STEP))

        gt = self._to_model_size(torch.from_numpy(face_window).permute(0, 3, 1, 2))

        # The generator must inpaint the mouth region: zero the lower half of each target.
        masked = gt.clone()
        masked[:, :, IMG_SIZE // 2:, :] = 0.0

        # A single random frame of the same clip serves as identity reference for all T frames.
        ref_idx = np.random.randint(0, n_frames)
        ref = torch.from_numpy(frames[ref_idx]).permute(2, 0, 1).unsqueeze(0)
        ref = self._to_model_size(ref).expand(self.T, -1, -1, -1)

        # Channel order must match the pretrained weights: masked target first, reference second.
        face_input = torch.cat([masked, ref], dim=1)

        return {
            "mel": mel_tensor,
            "face_input": face_input,
            "gt": gt,
            "audio": audio_1d,
            "emotion": REMAP[s["emotion_idx"]],
        }


def collate_wav2lip(batch):
    """Merge per-sample dicts into one batch dict.

    Fixed-shape entries (``mel``, ``face_input``, ``gt``) are stacked along a
    new leading batch axis. ``audio`` stays a plain list of per-sample tensors
    because it is not stacked here, and ``emotion`` becomes a 1-D tensor of labels.
    """
    collated = {
        key: torch.stack([sample[key] for sample in batch])
        for key in ("mel", "face_input", "gt")
    }
    collated["audio"] = [sample["audio"] for sample in batch]
    collated["emotion"] = torch.tensor([sample["emotion"] for sample in batch])
    return collated

In [None]:
def load_wav2lip(ckpt_path, device):
    """Build a Wav2Lip generator and load weights from ``ckpt_path``.

    Accepts either a bare state dict or a checkpoint dict that stores the
    weights under ``"state_dict"``. Keys saved from a ``DataParallel``-wrapped
    model carry a leading ``"module."``; that prefix is stripped so the
    weights load into the unwrapped model instead of failing the strict key check.

    Args:
        ckpt_path: Path to the checkpoint file.
        device: Device the returned model is moved to.

    Returns:
        The model with weights loaded, on ``device``.
    """
    model = Wav2LipModel()
    # SECURITY: weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints obtained from a source you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    state = ckpt["state_dict"] if "state_dict" in ckpt else ckpt

    prefix = "module."
    state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }

    model.load_state_dict(state)
    return model.to(device)

# Load the pretrained generator once to confirm the checkpoint is readable and to report its size.
# NOTE(review): `wav2lip` is not referenced again in the visible cells (each training
# run loads its own copy), yet it stays resident on DEVICE — consider freeing it.
wav2lip = load_wav2lip(WAV2LIP_CKPT, DEVICE)
print(f"Wav2Lip loaded: {sum(p.numel() for p in wav2lip.parameters()) / 1e6:.1f}M params")

# Emotion encoders shared by the emotion loss and the agreement metric.
# The loaders are named "frozen"; whether they actually disable gradients
# is defined in emotion_utils and not visible here.
audio_enc, audio_proc = load_frozen_audio_encoder(BEST_AUDIO_PATH, DEVICE)
video_enc = load_frozen_video_encoder(BEST_VIDEO_PATH, DEVICE)
# 224 is presumably the input resolution the video encoder expects — TODO confirm.
video_preprocess = DifferentiableVideoPreprocess(224).to(DEVICE)
print("Frozen emotion encoders loaded.")

In [None]:
# Authenticate with Weights & Biases before any run is created.
wandb.login()

# One fine-tuning run per entry. "lambda_emo" is the weight applied to the
# emotion loss in the total loss; 0.0 skips the emotion branch entirely and
# yields the reconstruction-only baseline.
CONFIGS = [
    {"name": "wav2lip-baseline", "lambda_emo": 0.0},
    {"name": "wav2lip-emo-001",  "lambda_emo": 0.01},
    {"name": "wav2lip-emo-005",  "lambda_emo": 0.05},
    {"name": "wav2lip-emo-01",   "lambda_emo": 0.1},
]

# Hyperparameters shared by every run in CONFIGS.
LR = 1e-4        # AdamW learning rate
EPOCHS = 20      # maximum number of epochs per run
BATCH_SIZE = 4   # samples per batch for both loaders
PATIENCE = 5     # epochs without validation improvement before early stopping
T_FRAMES = 5     # consecutive frames per dataset sample

In [None]:
def train_one_epoch(model, loader, optimizer, scaler, emotion_loss_fn, lambda_emo):
    """Run one optimisation pass over ``loader`` and return the mean losses.

    For every batch the generator is run once per frame, the L1 reconstruction
    loss is averaged over the ``T`` frames, and — when ``lambda_emo > 0`` — an
    emotion loss between the clip's audio embedding and the embedding of the
    generated frames is added with weight ``lambda_emo``.

    Args:
        model: Generator being fine-tuned; called as ``model(mel, face_input)``.
        loader: Yields batches as produced by ``collate_wav2lip``.
        optimizer: Optimiser over ``model``'s parameters.
        scaler: ``GradScaler`` used for mixed-precision loss scaling.
        emotion_loss_fn: Callable taking ``(audio_emb, video_emb)``.
        lambda_emo: Weight of the emotion term; values <= 0 skip that branch.

    Returns:
        Dict with the per-batch means of ``"recon"``, ``"emotion"`` and ``"total"``.
    """
    model.train()
    total_recon, total_emo, total_loss = 0.0, 0.0, 0.0

    for batch in tqdm(loader, leave=False):
        mel = batch["mel"].to(DEVICE)
        face_in = batch["face_input"].to(DEVICE)
        gt = batch["gt"].to(DEVICE)
        T = mel.shape[1]

        optimizer.zero_grad(set_to_none=True)

        all_gen = []
        recon = 0.0
        with autocast("cuda", enabled=DEVICE == "cuda"):
            for t in range(T):
                gen = model(mel[:, t], face_in[:, t])
                recon += F.l1_loss(gen, gt[:, t])
                all_gen.append(gen)
            recon = recon / T

            emo = torch.tensor(0.0, device=DEVICE)
            if lambda_emo > 0:
                gen_video = torch.stack(all_gen, dim=1)
                # The audio embedding is only a fixed target (it is detached below),
                # so skip building an autograd graph through the audio encoder.
                with torch.no_grad():
                    audio_emb = extract_audio_embedding(
                        audio_enc, audio_proc, batch["audio"], device=DEVICE)
                # Gradients must flow through this call back into the generator.
                video_emb = extract_video_embedding(
                    video_enc, video_preprocess, gen_video, device=DEVICE)
                emo = emotion_loss_fn(audio_emb.detach(), video_emb)

            loss = recon + lambda_emo * emo

        scaler.scale(loss).backward()
        # Unscale before clipping so the norm threshold applies to the true gradients.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()

        total_recon += recon.item()
        total_emo += emo.item()
        total_loss += loss.item()

    n = len(loader)
    return {"recon": total_recon / n, "emotion": total_emo / n, "total": total_loss / n}


@torch.no_grad()
def evaluate(model, loader, emotion_loss_fn, lambda_emo):
    """Score ``model`` on ``loader`` without updating it.

    The emotion loss and the audio/video agreement metric are computed for
    every run, including the ``lambda_emo == 0`` baseline, so that all
    configurations report the same set of validation numbers. The emotion
    term only enters ``"total"`` when ``lambda_emo > 0``.

    Args:
        model: Generator to evaluate; called as ``model(mel, face_input)``.
        loader: Yields batches as produced by ``collate_wav2lip``.
        emotion_loss_fn: Callable taking ``(audio_emb, video_emb)``.
        lambda_emo: Weight of the emotion term in ``"total"``.

    Returns:
        Dict with the per-batch means of ``"recon"``, ``"emotion"`` and
        ``"total"``, plus whatever ``EmotionAgreementMetric.compute()`` returns.
    """
    model.eval()
    total_recon, total_emo, total_loss = 0.0, 0.0, 0.0
    metric = EmotionAgreementMetric()

    for batch in tqdm(loader, leave=False):
        mel = batch["mel"].to(DEVICE)
        face_in = batch["face_input"].to(DEVICE)
        gt = batch["gt"].to(DEVICE)
        T = mel.shape[1]

        all_gen = []
        recon = 0.0
        with autocast("cuda", enabled=DEVICE == "cuda"):
            for t in range(T):
                gen = model(mel[:, t], face_in[:, t])
                recon += F.l1_loss(gen, gt[:, t])
                all_gen.append(gen)
            recon = recon / T

            gen_video = torch.stack(all_gen, dim=1)
            audio_emb = extract_audio_embedding(
                audio_enc, audio_proc, batch["audio"], device=DEVICE)
            video_emb = extract_video_embedding(
                video_enc, video_preprocess, gen_video, device=DEVICE)
            emo = emotion_loss_fn(audio_emb, video_emb)
            metric.update(audio_emb, video_emb)

            # Keep the emotion term out of the total when it is unweighted:
            # 0 * (inf or nan) would otherwise poison the baseline's loss.
            loss = recon + lambda_emo * emo if lambda_emo > 0 else recon

        total_recon += recon.item()
        total_emo += emo.item()
        total_loss += loss.item()

    n = len(loader)
    result = {"recon": total_recon / n, "emotion": total_emo / n, "total": total_loss / n}
    result.update(metric.compute())
    return result

In [None]:
train_ds = Wav2LipDataset(METADATA, "train", T=T_FRAMES)
val_ds = Wav2LipDataset(METADATA, "val", T=T_FRAMES)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=0, collate_fn=collate_wav2lip)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=0, collate_fn=collate_wav2lip)
print(f"Train: {len(train_ds)}, Val: {len(val_ds)}")

# The dataset samples its windows from NumPy's global RNG and the shuffling
# loader draws from PyTorch's. Fixed seeds give every ablation arm the same
# training stream and make every validation pass score the same windows, so
# differences between runs/epochs come from the model rather than the sampling.
SEED = 0
VAL_SEED = 1234

all_results = []

for cfg in CONFIGS:
    name = cfg["name"]
    lambda_emo = cfg["lambda_emo"]
    print(f"\n{'='*60}\n{name} (lambda_emo={lambda_emo})\n{'='*60}")

    torch.manual_seed(SEED)
    np.random.seed(SEED)

    wandb.init(project="uncanny-valley-wav2lip", name=name,
               config={**cfg, "lr": LR, "epochs": EPOCHS}, reinit=True)

    best_val, best_stats, patience_cnt = float("inf"), {}, 0
    save_path = OUT_DIR / name

    # try/finally so the W&B run is closed even if training raises part-way.
    try:
        # Every run starts again from the pretrained weights.
        model = load_wav2lip(WAV2LIP_CKPT, DEVICE)
        optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
        scaler = GradScaler(enabled=DEVICE == "cuda")
        emotion_loss_fn = CrossModalEmotionLoss(weight=1.0)

        for epoch in range(EPOCHS):
            train_stats = train_one_epoch(model, train_loader, optimizer, scaler,
                                          emotion_loss_fn, lambda_emo)

            # Validate on a fixed set of windows, then resume the training RNG stream.
            train_rng_state = np.random.get_state()
            np.random.seed(VAL_SEED)
            try:
                val_stats = evaluate(model, val_loader, emotion_loss_fn, lambda_emo)
            finally:
                np.random.set_state(train_rng_state)

            wandb.log({
                "epoch": epoch + 1,
                "train/recon": train_stats["recon"],
                "train/emotion": train_stats["emotion"],
                "train/total": train_stats["total"],
                "val/recon": val_stats["recon"],
                "val/emotion": val_stats["emotion"],
                "val/total": val_stats["total"],
                **{f"val/{k}": val_stats[k]
                   for k in ["avg_cosine_sim", "agreement_rate"] if k in val_stats},
            })

            print(f"  [{epoch+1:2d}/{EPOCHS}] "
                  f"t_loss={train_stats['total']:.4f} v_loss={val_stats['total']:.4f} "
                  f"v_recon={val_stats['recon']:.4f}"
                  + (f" cos_sim={val_stats.get('avg_cosine_sim', 0):.3f}" if lambda_emo > 0 else ""))

            if val_stats["total"] < best_val:
                best_val = val_stats["total"]
                best_stats = val_stats
                save_path.mkdir(parents=True, exist_ok=True)
                torch.save(model.state_dict(), save_path / "wav2lip.pth")
                patience_cnt = 0
            else:
                patience_cnt += 1
                if patience_cnt >= PATIENCE:
                    print(f"  Early stopping at epoch {epoch+1}")
                    break
    finally:
        wandb.finish()

    # Drop the references first so the cache release can actually return the memory.
    del model, optimizer, scaler
    gc.collect()
    torch.cuda.empty_cache()

    # "best_val" includes lambda_emo * emotion, so it is not on the same scale
    # across configs; the reconstruction loss and cosine similarity at the best
    # epoch are recorded alongside it for a like-for-like comparison.
    all_results.append({
        "name": name,
        "lambda_emo": lambda_emo,
        "best_val": best_val,
        "best_val_recon": best_stats.get("recon"),
        "best_val_cosine_sim": best_stats.get("avg_cosine_sim"),
    })
    print(f"  Best val loss: {best_val:.4f} -> {save_path}")

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Rank the runs by their best validation total loss and show them as a table.
# NOTE(review): the total is recon + lambda_emo * emotion, so runs with a larger
# lambda_emo carry an extra term; this ranking is not a like-for-like comparison
# across configs and tends to favour the baseline.
df = pd.DataFrame(all_results).sort_values("best_val")
print(df.to_string(index=False))

fig, ax = plt.subplots(figsize=(8, 4))
ax.bar(df["name"], df["best_val"], color="steelblue")
ax.set_ylabel("Best Val Loss")
# The lambda is written as an escape so the title survives any re-encoding
# of the notebook file (it had been corrupted into a two-character sequence).
ax.set_title("Wav2Lip Fine-tuning: \u03bb_emo Ablation")
plt.xticks(rotation=30, ha="right")
plt.tight_layout()
plt.show()

In [None]:
# Reload the top-ranked run (first row of the sorted results table) for a final report.
best_name = df.iloc[0]["name"]
best_model = load_wav2lip(WAV2LIP_CKPT, DEVICE)
best_model.load_state_dict(torch.load(OUT_DIR / best_name / "wav2lip.pth", map_location=DEVICE, weights_only=True))
best_model.eval()
print(f"Loaded best model: {best_name}")

metric = EmotionAgreementMetric()
all_recon = []

# The dataset picks its frame windows with NumPy's global RNG, so without a
# fixed seed every execution of this cell would score different windows and
# report different numbers. Seed for the evaluation, then restore the RNG state.
saved_rng_state = np.random.get_state()
np.random.seed(0)
try:
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Evaluating best"):
            mel = batch["mel"].to(DEVICE)
            face_in = batch["face_input"].to(DEVICE)
            gt = batch["gt"].to(DEVICE)
            n_steps = mel.shape[1]

            all_gen = []
            for step in range(n_steps):
                gen = best_model(mel[:, step], face_in[:, step])
                all_gen.append(gen)
                # One L1 value per (batch, frame) pair; averaged over all pairs below.
                all_recon.append(F.l1_loss(gen, gt[:, step]).item())

            gen_video = torch.stack(all_gen, dim=1)
            audio_emb = extract_audio_embedding(audio_enc, audio_proc, batch["audio"], device=DEVICE)
            video_emb = extract_video_embedding(video_enc, video_preprocess, gen_video, device=DEVICE)
            metric.update(audio_emb, video_emb)
finally:
    np.random.set_state(saved_rng_state)

agreement = metric.compute()
print("\nBest model evaluation:")
print(f"  Avg L1 recon:     {np.mean(all_recon):.4f}")
print(f"  Avg cosine sim:   {agreement['avg_cosine_sim']:.4f}")
print(f"  Agreement rate:   {agreement['agreement_rate']:.4f}")
print(f"  Std cosine sim:   {agreement['std_cosine_sim']:.4f}")

# Release the evaluation copy of the model.
del best_model
torch.cuda.empty_cache()