In [None]:
import os
import cv2
import numpy as np
import subprocess
from tqdm import tqdm

# Define paths - replace these with your actual paths
# DATA_ROOT is the folder process_dataset scans for '*_processed' subject directories of .mpg videos.
DATA_ROOT = "E:\\training data"
# OUTPUT_DIR receives the per-video frame/audio folders and the 'filelists' directory.
OUTPUT_DIR = "E:\\play_6.0\\processed_data"

def extract_frames_and_audio(video_path, frames_dir, audio_path):
    """Dump every frame of a video as JPEG files and extract its audio track with ffmpeg.

    Args:
        video_path: Path of the input video file.
        frames_dir: Directory (created if missing) that receives `0.jpg`, `1.jpg`, ...
        audio_path: Path of the audio file ffmpeg writes (overwritten if it exists, `-y`).

    Returns:
        None. A warning is issued when ffmpeg exits with a non-zero status.
    """
    import warnings

    os.makedirs(frames_dir, exist_ok=True)

    # Extract frames; the capture handle is released even if writing a frame fails.
    video = cv2.VideoCapture(video_path)
    try:
        frame_idx = 0
        while True:
            ret, frame = video.read()
            if not ret:
                break
            cv2.imwrite(os.path.join(frames_dir, f'{frame_idx}.jpg'), frame)
            frame_idx += 1
    finally:
        video.release()

    # Extract audio. The arguments are passed as a list (no shell), so paths containing
    # spaces, quotes or shell metacharacters are handed to ffmpeg verbatim.
    command = ['ffmpeg', '-y', '-i', video_path, '-strict', '-2', audio_path]
    return_code = subprocess.call(command)
    if return_code != 0:
        warnings.warn(f'ffmpeg exited with status {return_code} for {video_path}')

def process_dataset(data_root, output_dir, train_fraction=0.8, seed=None):
    """Extract frames/audio for every video of every subject and write train/val filelists.

    Subject folders are the sub-directories of `data_root` whose name ends in `_processed`;
    inside each, every `.mpg` file is processed into `<output_dir>/<subject>/<video name>/`.
    A matching `align/<video name>.align` file, when present, is copied alongside.

    Args:
        data_root: Folder containing the subject directories.
        output_dir: Folder that receives the processed data and `filelists/{train,val}.txt`.
        train_fraction: Probability with which a video is assigned to the train list
            (the remainder goes to the validation list).
        seed: Optional integer seed for a reproducible split. With the default `None`
            the global NumPy random state is used.
    """
    import shutil

    os.makedirs(output_dir, exist_ok=True)

    # Create filelists directory
    filelists_dir = os.path.join(output_dir, 'filelists')
    os.makedirs(filelists_dir, exist_ok=True)

    train_filelist = []
    val_filelist = []

    # A private generator when a seed is given, the global NumPy state otherwise.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Sorted so that processing order (and therefore a seeded split) does not depend on
    # the arbitrary order in which os.listdir returns entries.
    subject_dirs = sorted(
        d for d in os.listdir(data_root)
        if os.path.isdir(os.path.join(data_root, d)) and d.endswith('_processed')
    )

    for subject_dir in tqdm(subject_dirs):
        subject_path = os.path.join(data_root, subject_dir)

        # Get all video files
        video_files = sorted(f for f in os.listdir(subject_path) if f.endswith('.mpg'))

        # Process each video
        for video_file in tqdm(video_files, desc=f'Processing {subject_dir}'):
            video_name = os.path.splitext(video_file)[0]
            video_path = os.path.join(subject_path, video_file)

            # Frames and audio share one output folder per video.
            video_output_dir = os.path.join(output_dir, subject_dir, video_name)
            audio_path = os.path.join(video_output_dir, 'audio.wav')

            # Extract frames and audio
            extract_frames_and_audio(video_path, video_output_dir, audio_path)

            # Copy alignment file if it exists (byte-for-byte, no text decoding involved)
            align_file = os.path.join(subject_path, 'align', f'{video_name}.align')
            if os.path.exists(align_file):
                align_output_dir = os.path.join(video_output_dir, 'align')
                os.makedirs(align_output_dir, exist_ok=True)
                shutil.copyfile(align_file, os.path.join(align_output_dir, f'{video_name}.align'))

            # Add to filelist (train_fraction train, remainder val)
            entry = os.path.join(subject_dir, video_name)
            if rng.rand() < train_fraction:
                train_filelist.append(entry)
            else:
                val_filelist.append(entry)

    # Write filelists
    with open(os.path.join(filelists_dir, 'train.txt'), 'w') as f:
        f.write('\n'.join(train_filelist))

    with open(os.path.join(filelists_dir, 'val.txt'), 'w') as f:
        f.write('\n'.join(val_filelist))

# Run the preprocessing
# Processes every '*_processed' subject folder under DATA_ROOT; this is a long-running call.
process_dataset(DATA_ROOT, OUTPUT_DIR)
print("Preprocessing completed!")


Processing s10_processed: 100%|██████████| 1000/1000 [11:51<00:00,  1.41it/s]
Processing s11_processed: 100%|██████████| 1000/1000 [10:17<00:00,  1.62it/s]
Processing s12_processed: 100%|██████████| 1000/1000 [10:04<00:00,  1.65it/s]
Processing s13_processed: 100%|██████████| 1000/1000 [11:15<00:00,  1.48it/s]
Processing s14_processed: 100%|██████████| 1000/1000 [11:53<00:00,  1.40it/s]
Processing s15_processed: 100%|██████████| 1000/1000 [10:32<00:00,  1.58it/s]
Processing s16_processed: 100%|██████████| 1000/1000 [09:44<00:00,  1.71it/s]
Processing s17_processed: 100%|██████████| 1000/1000 [11:55<00:00,  1.40it/s]
Processing s18_processed: 100%|██████████| 1000/1000 [09:48<00:00,  1.70it/s]
Processing s19_processed: 100%|██████████| 1000/1000 [09:39<00:00,  1.73it/s]
Processing s1_processed: 100%|██████████| 1000/1000 [11:55<00:00,  1.40it/s]
Processing s20_processed: 100%|██████████| 1000/1000 [11:43<00:00,  1.42it/s]
 36%|███▋      | 12/33 [2:10:41<3:53:02, 665.83s/it]

In [None]:
import os
import torch
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
import numpy as np
import argparse
import sys
sys.path.append('E:\\play_6.0\\Wav2Lip')

# Import Wav2Lip modules
from hparams import hparams
from models import SyncNet_color as SyncNet
from models import Wav2Lip as Wav2Lip
from models import Wav2Lip_disc_qual as Discriminator

# Define paths
DATA_ROOT = "E:\\play_6.0\\processed_data"
CHECKPOINT_DIR = "E:\\play_6.0\\checkpoints"
SYNCNET_PATH = "E:\\play_6.0\\Wav2Lip\\checkpoints\\syncnet.pth"
TRAIN_SCRIPT = "E:\\play_6.0\\Wav2Lip\\wav2lip_train.py"

# Set improved hyperparameters
# NOTE(review): these calls only modify the `hparams` object inside this Python process.
# The training script below runs in a separate Python process that imports `hparams`
# on its own, so it does not see these overrides unless the same values are also set in
# the hparams module it loads.
hparams.set_hparam('batch_size', 8)  # Smaller batch size for better quality
hparams.set_hparam('syncnet_wt', 0.03)  # Start with higher sync weight
hparams.set_hparam('disc_wt', 0.1)  # Increase discriminator weight for better visual quality
hparams.set_hparam('img_size', 128)  # Increase image size for better resolution
hparams.set_hparam('initial_learning_rate', 5e-5)  # Lower learning rate for more stable training

# Run the training command
import subprocess

# Arguments are passed as a list (no shell), so paths containing spaces stay intact, and the
# interpreter running this notebook is reused instead of whatever `python` is on PATH.
command = [
    sys.executable, TRAIN_SCRIPT,
    '--data_root', DATA_ROOT,
    '--checkpoint_dir', CHECKPOINT_DIR,
    '--syncnet_checkpoint_path', SYNCNET_PATH,
]
# check=True raises CalledProcessError when training exits with a non-zero status
# instead of letting the failure pass unnoticed.
subprocess.run(command, check=True)


In [None]:
import os
import torch
from torch import nn
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
import numpy as np
from glob import glob
import os, random, cv2, argparse
from hparams import hparams, get_image_list

from models import SyncNet_color as SyncNet
from models import Wav2Lip as Wav2Lip
from models import Wav2Lip_disc_qual as Discriminator
import audio

# Command-line options: the three path options are mandatory, the two resume checkpoints
# are optional and default to None.
# NOTE(review): parse_args() reads sys.argv, so this expects to be run as a script with
# these flags; executed as a notebook cell the required options would be missing — confirm
# how this cell is meant to be run.
parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model with the visual quality discriminator')
parser.add_argument("--data_root", help="Root folder of the preprocessed dataset", required=True)
parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True)
parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True)
parser.add_argument('--checkpoint_path', help='Resume from this checkpoint', default=None)
parser.add_argument('--disc_checkpoint_path', help='Resume discriminator from this checkpoint', default=None)
args = parser.parse_args()

# Progress counters; train() declares both as `global`.
global_step = 0
global_epoch = 0
# CUDA availability decides which torch device the main block selects.
use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))

# NOTE(review): neither constant is referenced by the code visible in this cell —
# presumably they are consumed by the omitted dataset / sync-loss code; confirm.
syncnet_T = 5
syncnet_mel_step_size = 16

# Rest of the code from wav2lip_train.py with modifications for GAN training

# Add discriminator loss
def train(device, model, disc_model, train_data_loader, test_data_loader, optimizer, disc_optimizer,
          checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
    """Alternate generator and discriminator updates over the training set.

    For every batch the generator (`model`) is stepped on a weighted sum of the sync loss,
    the L1 reconstruction loss and the discriminator-based perceptual loss. When
    `hparams.disc_wt > 0` the discriminator (`disc_model`) is then stepped on the mean of
    a real-sample and a generated-sample binary cross-entropy loss.

    Args:
        device: torch device the batch tensors are moved to.
        model: generator, called as `model(indiv_mels, x)`.
        disc_model: discriminator; must offer `perceptual_forward` and a plain forward pass.
        train_data_loader: iterable yielding `(x, indiv_mels, mel, gt)` batches.
        test_data_loader: accepted but not used by the part of the loop present here.
        optimizer: optimizer for `model`.
        disc_optimizer: optimizer for `disc_model`.
        checkpoint_dir: accepted but not used by the part of the loop present here.
        checkpoint_interval: accepted but not used by the part of the loop present here.
        nepochs: training stops once the global epoch counter reaches this value.

    Relies on `hparams`, `get_sync_loss` and `recon_loss` being available at module level.
    """
    global global_step, global_epoch
    # `tqdm` is not among this script's imports, so bring it into scope here.
    from tqdm import tqdm

    resumed_step = global_step
    # `F` (torch.nn.functional) is never imported by this script; reach it through `nn`.
    bce_loss = nn.functional.binary_cross_entropy

    while global_epoch < nepochs:
        print('Starting Epoch: {}'.format(global_epoch))
        running_sync_loss, running_l1_loss, running_disc_loss, running_perceptual_loss = 0., 0., 0., 0.
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, indiv_mels, mel, gt) in prog_bar:
            model.train()
            disc_model.train()
            optimizer.zero_grad()
            disc_optimizer.zero_grad()

            # Move data to the training device
            x = x.to(device)
            mel = mel.to(device)
            indiv_mels = indiv_mels.to(device)
            gt = gt.to(device)

            g = model(indiv_mels, x)

            # Sync loss
            if hparams.syncnet_wt > 0.:
                sync_loss = get_sync_loss(mel, g)
            else:
                sync_loss = 0.

            # L1 loss
            l1loss = recon_loss(g, gt)

            # Combined generator loss; the perceptual (GAN) term is only added when enabled.
            loss = hparams.syncnet_wt * sync_loss + (1 - hparams.syncnet_wt - hparams.disc_wt) * l1loss
            if hparams.disc_wt > 0.:
                perceptual_loss = disc_model.perceptual_forward(g)
                loss += hparams.disc_wt * perceptual_loss
            else:
                perceptual_loss = 0.

            loss.backward()
            optimizer.step()

            # Train discriminator
            if hparams.disc_wt > 0.:
                # Clears the discriminator gradients accumulated by the generator's backward pass.
                disc_optimizer.zero_grad()

                # Real samples - should be classified as real (1)
                real_pred = disc_model(gt)
                real_loss = bce_loss(real_pred, torch.ones((len(real_pred), 1), device=device))

                # Fake samples - should be classified as fake (0); detach() keeps this
                # update from back-propagating into the generator.
                fake_pred = disc_model(g.detach())
                fake_loss = bce_loss(fake_pred, torch.zeros((len(fake_pred), 1), device=device))

                disc_loss = (real_loss + fake_loss) / 2
                disc_loss.backward()
                disc_optimizer.step()
            else:
                disc_loss = 0.

            # One more optimisation step done.
            global_step += 1

            # Rest of the training loop...

        # Advance the epoch counter; without this the enclosing `while` never terminates.
        global_epoch += 1

# Main execution
if __name__ == "__main__":
    # Checkpoints go to the directory given on the command line; create it if needed.
    checkpoint_dir = args.checkpoint_dir
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Dataset and Dataloader setup
    # NOTE(review): `Dataset` is not defined or imported in this cell — presumably it
    # comes from the omitted wav2lip_train.py code; confirm before running.
    train_dataset = Dataset('train')
    test_dataset = Dataset('val')

    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=hparams.batch_size, shuffle=True,
        num_workers=hparams.num_workers)

    # The validation loader is not shuffled and uses a fixed worker count of 4.
    test_data_loader = data_utils.DataLoader(
        test_dataset, batch_size=hparams.batch_size,
        num_workers=4)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Model
    model = Wav2Lip().to(device)
    disc_model = Discriminator().to(device)
    
    print('Generator: total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
    print('Discriminator: total trainable params {}'.format(sum(p.numel() for p in disc_model.parameters() if p.requires_grad)))

    # Separate Adam optimizers (and learning rates) for generator and discriminator.
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=hparams.initial_learning_rate)
    disc_optimizer = optim.Adam([p for p in disc_model.parameters() if p.requires_grad],
                           lr=hparams.disc_initial_learning_rate)

    # Optional resume of generator and discriminator from earlier checkpoints.
    # NOTE(review): `load_checkpoint` and `syncnet` are not defined in this cell —
    # presumably provided by the omitted wav2lip_train.py code; confirm.
    if args.checkpoint_path is not None:
        load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)
        
    if args.disc_checkpoint_path is not None:
        load_checkpoint(args.disc_checkpoint_path, disc_model, disc_optimizer, reset_optimizer=False, overwrite_global_states=False)
        
    # The sync-expert checkpoint is always loaded; no optimizer is passed for it.
    load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True, overwrite_global_states=False)

    # Train!
    train(device, model, disc_model, train_data_loader, test_data_loader, optimizer, disc_optimizer,
              checkpoint_dir=checkpoint_dir,
              checkpoint_interval=hparams.checkpoint_interval,
              nepochs=hparams.nepochs)


In [None]:
import os
import torch
from torch import nn
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
import numpy as np
from glob import glob
import os, random, cv2, argparse
from hparams import hparams, get_image_list

from models import SyncNet_color as SyncNet
from models import Wav2Lip as Wav2Lip
from models import Wav2Lip_disc_qual as Discriminator
import audio

# Command-line options: the three path options are mandatory, the two resume checkpoints
# are optional and default to None.
# NOTE(review): parse_args() reads sys.argv, so this expects to be run as a script with
# these flags; executed as a notebook cell the required options would be missing — confirm
# how this cell is meant to be run.
parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model with the visual quality discriminator')
parser.add_argument("--data_root", help="Root folder of the preprocessed dataset", required=True)
parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True)
parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True)
parser.add_argument('--checkpoint_path', help='Resume from this checkpoint', default=None)
parser.add_argument('--disc_checkpoint_path', help='Resume discriminator from this checkpoint', default=None)
args = parser.parse_args()

# Progress counters; train() declares both as `global`.
global_step = 0
global_epoch = 0
# CUDA availability decides which torch device the main block selects.
use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))

# NOTE(review): neither constant is referenced by the code visible in this cell —
# presumably they are consumed by the omitted dataset / sync-loss code; confirm.
syncnet_T = 5
syncnet_mel_step_size = 16

# Rest of the code from wav2lip_train.py with modifications for GAN training

# Add discriminator loss
def train(device, model, disc_model, train_data_loader, test_data_loader, optimizer, disc_optimizer,
          checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
    """Alternate generator and discriminator updates over the training set.

    For every batch the generator (`model`) is stepped on a weighted sum of the sync loss,
    the L1 reconstruction loss and the discriminator-based perceptual loss. When
    `hparams.disc_wt > 0` the discriminator (`disc_model`) is then stepped on the mean of
    a real-sample and a generated-sample binary cross-entropy loss.

    Args:
        device: torch device the batch tensors are moved to.
        model: generator, called as `model(indiv_mels, x)`.
        disc_model: discriminator; must offer `perceptual_forward` and a plain forward pass.
        train_data_loader: iterable yielding `(x, indiv_mels, mel, gt)` batches.
        test_data_loader: accepted but not used by the part of the loop present here.
        optimizer: optimizer for `model`.
        disc_optimizer: optimizer for `disc_model`.
        checkpoint_dir: accepted but not used by the part of the loop present here.
        checkpoint_interval: accepted but not used by the part of the loop present here.
        nepochs: training stops once the global epoch counter reaches this value.

    Relies on `hparams`, `get_sync_loss` and `recon_loss` being available at module level.
    """
    global global_step, global_epoch
    # `tqdm` is not among this script's imports, so bring it into scope here.
    from tqdm import tqdm

    resumed_step = global_step
    # `F` (torch.nn.functional) is never imported by this script; reach it through `nn`.
    bce_loss = nn.functional.binary_cross_entropy

    while global_epoch < nepochs:
        print('Starting Epoch: {}'.format(global_epoch))
        running_sync_loss, running_l1_loss, running_disc_loss, running_perceptual_loss = 0., 0., 0., 0.
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, indiv_mels, mel, gt) in prog_bar:
            model.train()
            disc_model.train()
            optimizer.zero_grad()
            disc_optimizer.zero_grad()

            # Move data to the training device
            x = x.to(device)
            mel = mel.to(device)
            indiv_mels = indiv_mels.to(device)
            gt = gt.to(device)

            g = model(indiv_mels, x)

            # Sync loss
            if hparams.syncnet_wt > 0.:
                sync_loss = get_sync_loss(mel, g)
            else:
                sync_loss = 0.

            # L1 loss
            l1loss = recon_loss(g, gt)

            # Combined generator loss; the perceptual (GAN) term is only added when enabled.
            loss = hparams.syncnet_wt * sync_loss + (1 - hparams.syncnet_wt - hparams.disc_wt) * l1loss
            if hparams.disc_wt > 0.:
                perceptual_loss = disc_model.perceptual_forward(g)
                loss += hparams.disc_wt * perceptual_loss
            else:
                perceptual_loss = 0.

            loss.backward()
            optimizer.step()

            # Train discriminator
            if hparams.disc_wt > 0.:
                # Clears the discriminator gradients accumulated by the generator's backward pass.
                disc_optimizer.zero_grad()

                # Real samples - should be classified as real (1)
                real_pred = disc_model(gt)
                real_loss = bce_loss(real_pred, torch.ones((len(real_pred), 1), device=device))

                # Fake samples - should be classified as fake (0); detach() keeps this
                # update from back-propagating into the generator.
                fake_pred = disc_model(g.detach())
                fake_loss = bce_loss(fake_pred, torch.zeros((len(fake_pred), 1), device=device))

                disc_loss = (real_loss + fake_loss) / 2
                disc_loss.backward()
                disc_optimizer.step()
            else:
                disc_loss = 0.

            # One more optimisation step done.
            global_step += 1

            # Rest of the training loop...

        # Advance the epoch counter; without this the enclosing `while` never terminates.
        global_epoch += 1

# Main execution
if __name__ == "__main__":
    # Checkpoints go to the directory given on the command line; create it if needed.
    checkpoint_dir = args.checkpoint_dir
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Dataset and Dataloader setup
    # NOTE(review): `Dataset` is not defined or imported in this cell — presumably it
    # comes from the omitted wav2lip_train.py code; confirm before running.
    train_dataset = Dataset('train')
    test_dataset = Dataset('val')

    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=hparams.batch_size, shuffle=True,
        num_workers=hparams.num_workers)

    # The validation loader is not shuffled and uses a fixed worker count of 4.
    test_data_loader = data_utils.DataLoader(
        test_dataset, batch_size=hparams.batch_size,
        num_workers=4)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Model
    model = Wav2Lip().to(device)
    disc_model = Discriminator().to(device)
    
    print('Generator: total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
    print('Discriminator: total trainable params {}'.format(sum(p.numel() for p in disc_model.parameters() if p.requires_grad)))

    # Separate Adam optimizers (and learning rates) for generator and discriminator.
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=hparams.initial_learning_rate)
    disc_optimizer = optim.Adam([p for p in disc_model.parameters() if p.requires_grad],
                           lr=hparams.disc_initial_learning_rate)

    # Optional resume of generator and discriminator from earlier checkpoints.
    # NOTE(review): `load_checkpoint` and `syncnet` are not defined in this cell —
    # presumably provided by the omitted wav2lip_train.py code; confirm.
    if args.checkpoint_path is not None:
        load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)
        
    if args.disc_checkpoint_path is not None:
        load_checkpoint(args.disc_checkpoint_path, disc_model, disc_optimizer, reset_optimizer=False, overwrite_global_states=False)
        
    # The sync-expert checkpoint is always loaded; no optimizer is passed for it.
    load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True, overwrite_global_states=False)

    # Train!
    train(device, model, disc_model, train_data_loader, test_data_loader, optimizer, disc_optimizer,
              checkpoint_dir=checkpoint_dir,
              checkpoint_interval=hparams.checkpoint_interval,
              nepochs=hparams.nepochs)


In [None]:
import os
import torch
from torch import nn
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
import numpy as np
from glob import glob
import os, random, cv2, argparse
from hparams import hparams, get_image_list

from models import SyncNet_color as SyncNet
from models import Wav2Lip as Wav2Lip
from models import Wav2Lip_disc_qual as Discriminator
import audio

# Command-line options: the three path options are mandatory, the two resume checkpoints
# are optional and default to None.
# NOTE(review): parse_args() reads sys.argv, so this expects to be run as a script with
# these flags; executed as a notebook cell the required options would be missing — confirm
# how this cell is meant to be run.
parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model with the visual quality discriminator')
parser.add_argument("--data_root", help="Root folder of the preprocessed dataset", required=True)
parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True)
parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True)
parser.add_argument('--checkpoint_path', help='Resume from this checkpoint', default=None)
parser.add_argument('--disc_checkpoint_path', help='Resume discriminator from this checkpoint', default=None)
args = parser.parse_args()

# Progress counters; train() declares both as `global`.
global_step = 0
global_epoch = 0
# CUDA availability decides which torch device the main block selects.
use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))

# NOTE(review): neither constant is referenced by the code visible in this cell —
# presumably they are consumed by the omitted dataset / sync-loss code; confirm.
syncnet_T = 5
syncnet_mel_step_size = 16

# Rest of the code from wav2lip_train.py with modifications for GAN training

# Add discriminator loss
def train(device, model, disc_model, train_data_loader, test_data_loader, optimizer, disc_optimizer,
          checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
    """Alternate generator and discriminator updates over the training set.

    For every batch the generator (`model`) is stepped on a weighted sum of the sync loss,
    the L1 reconstruction loss and the discriminator-based perceptual loss. When
    `hparams.disc_wt > 0` the discriminator (`disc_model`) is then stepped on the mean of
    a real-sample and a generated-sample binary cross-entropy loss.

    Args:
        device: torch device the batch tensors are moved to.
        model: generator, called as `model(indiv_mels, x)`.
        disc_model: discriminator; must offer `perceptual_forward` and a plain forward pass.
        train_data_loader: iterable yielding `(x, indiv_mels, mel, gt)` batches.
        test_data_loader: accepted but not used by the part of the loop present here.
        optimizer: optimizer for `model`.
        disc_optimizer: optimizer for `disc_model`.
        checkpoint_dir: accepted but not used by the part of the loop present here.
        checkpoint_interval: accepted but not used by the part of the loop present here.
        nepochs: training stops once the global epoch counter reaches this value.

    Relies on `hparams`, `get_sync_loss` and `recon_loss` being available at module level.
    """
    global global_step, global_epoch
    # `tqdm` is not among this script's imports, so bring it into scope here.
    from tqdm import tqdm

    resumed_step = global_step
    # `F` (torch.nn.functional) is never imported by this script; reach it through `nn`.
    bce_loss = nn.functional.binary_cross_entropy

    while global_epoch < nepochs:
        print('Starting Epoch: {}'.format(global_epoch))
        running_sync_loss, running_l1_loss, running_disc_loss, running_perceptual_loss = 0., 0., 0., 0.
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, indiv_mels, mel, gt) in prog_bar:
            model.train()
            disc_model.train()
            optimizer.zero_grad()
            disc_optimizer.zero_grad()

            # Move data to the training device
            x = x.to(device)
            mel = mel.to(device)
            indiv_mels = indiv_mels.to(device)
            gt = gt.to(device)

            g = model(indiv_mels, x)

            # Sync loss
            if hparams.syncnet_wt > 0.:
                sync_loss = get_sync_loss(mel, g)
            else:
                sync_loss = 0.

            # L1 loss
            l1loss = recon_loss(g, gt)

            # Combined generator loss; the perceptual (GAN) term is only added when enabled.
            loss = hparams.syncnet_wt * sync_loss + (1 - hparams.syncnet_wt - hparams.disc_wt) * l1loss
            if hparams.disc_wt > 0.:
                perceptual_loss = disc_model.perceptual_forward(g)
                loss += hparams.disc_wt * perceptual_loss
            else:
                perceptual_loss = 0.

            loss.backward()
            optimizer.step()

            # Train discriminator
            if hparams.disc_wt > 0.:
                # Clears the discriminator gradients accumulated by the generator's backward pass.
                disc_optimizer.zero_grad()

                # Real samples - should be classified as real (1)
                real_pred = disc_model(gt)
                real_loss = bce_loss(real_pred, torch.ones((len(real_pred), 1), device=device))

                # Fake samples - should be classified as fake (0); detach() keeps this
                # update from back-propagating into the generator.
                fake_pred = disc_model(g.detach())
                fake_loss = bce_loss(fake_pred, torch.zeros((len(fake_pred), 1), device=device))

                disc_loss = (real_loss + fake_loss) / 2
                disc_loss.backward()
                disc_optimizer.step()
            else:
                disc_loss = 0.

            # One more optimisation step done.
            global_step += 1

            # Rest of the training loop...

        # Advance the epoch counter; without this the enclosing `while` never terminates.
        global_epoch += 1

# Main execution
if __name__ == "__main__":
    # Checkpoints go to the directory given on the command line; create it if needed.
    checkpoint_dir = args.checkpoint_dir
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Dataset and Dataloader setup
    # NOTE(review): `Dataset` is not defined or imported in this cell — presumably it
    # comes from the omitted wav2lip_train.py code; confirm before running.
    train_dataset = Dataset('train')
    test_dataset = Dataset('val')

    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=hparams.batch_size, shuffle=True,
        num_workers=hparams.num_workers)

    # The validation loader is not shuffled and uses a fixed worker count of 4.
    test_data_loader = data_utils.DataLoader(
        test_dataset, batch_size=hparams.batch_size,
        num_workers=4)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Model
    model = Wav2Lip().to(device)
    disc_model = Discriminator().to(device)
    
    print('Generator: total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
    print('Discriminator: total trainable params {}'.format(sum(p.numel() for p in disc_model.parameters() if p.requires_grad)))

    # Separate Adam optimizers (and learning rates) for generator and discriminator.
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=hparams.initial_learning_rate)
    disc_optimizer = optim.Adam([p for p in disc_model.parameters() if p.requires_grad],
                           lr=hparams.disc_initial_learning_rate)

    # Optional resume of generator and discriminator from earlier checkpoints.
    # NOTE(review): `load_checkpoint` and `syncnet` are not defined in this cell —
    # presumably provided by the omitted wav2lip_train.py code; confirm.
    if args.checkpoint_path is not None:
        load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)
        
    if args.disc_checkpoint_path is not None:
        load_checkpoint(args.disc_checkpoint_path, disc_model, disc_optimizer, reset_optimizer=False, overwrite_global_states=False)
        
    # The sync-expert checkpoint is always loaded; no optimizer is passed for it.
    load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True, overwrite_global_states=False)

    # Train!
    train(device, model, disc_model, train_data_loader, test_data_loader, optimizer, disc_optimizer,
              checkpoint_dir=checkpoint_dir,
              checkpoint_interval=hparams.checkpoint_interval,
              nepochs=hparams.nepochs)


In [None]:
import os
import torch
from torch import nn
from torch import optim
import torch.backends.cudnn as cudnn
from torch.utils import data as data_utils
import numpy as np
from glob import glob
import os, random, cv2, argparse
from hparams import hparams, get_image_list

from models import SyncNet_color as SyncNet
from models import Wav2Lip as Wav2Lip
from models import Wav2Lip_disc_qual as Discriminator
import audio

# Command-line options: the three path options are mandatory, the two resume checkpoints
# are optional and default to None.
# NOTE(review): parse_args() reads sys.argv, so this expects to be run as a script with
# these flags; executed as a notebook cell the required options would be missing — confirm
# how this cell is meant to be run.
parser = argparse.ArgumentParser(description='Code to train the Wav2Lip model with the visual quality discriminator')
parser.add_argument("--data_root", help="Root folder of the preprocessed dataset", required=True)
parser.add_argument('--checkpoint_dir', help='Save checkpoints to this directory', required=True)
parser.add_argument('--syncnet_checkpoint_path', help='Load the pre-trained Expert discriminator', required=True)
parser.add_argument('--checkpoint_path', help='Resume from this checkpoint', default=None)
parser.add_argument('--disc_checkpoint_path', help='Resume discriminator from this checkpoint', default=None)
args = parser.parse_args()

# Progress counters; train() declares both as `global`.
global_step = 0
global_epoch = 0
# CUDA availability decides which torch device the main block selects.
use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))

# NOTE(review): neither constant is referenced by the code visible in this cell —
# presumably they are consumed by the omitted dataset / sync-loss code; confirm.
syncnet_T = 5
syncnet_mel_step_size = 16

# Rest of the code from wav2lip_train.py with modifications for GAN training

# Add discriminator loss
def train(device, model, disc_model, train_data_loader, test_data_loader, optimizer, disc_optimizer,
          checkpoint_dir=None, checkpoint_interval=None, nepochs=None):
    """Alternate generator and discriminator updates over the training set.

    For every batch the generator (`model`) is stepped on a weighted sum of the sync loss,
    the L1 reconstruction loss and the discriminator-based perceptual loss. When
    `hparams.disc_wt > 0` the discriminator (`disc_model`) is then stepped on the mean of
    a real-sample and a generated-sample binary cross-entropy loss.

    Args:
        device: torch device the batch tensors are moved to.
        model: generator, called as `model(indiv_mels, x)`.
        disc_model: discriminator; must offer `perceptual_forward` and a plain forward pass.
        train_data_loader: iterable yielding `(x, indiv_mels, mel, gt)` batches.
        test_data_loader: accepted but not used by the part of the loop present here.
        optimizer: optimizer for `model`.
        disc_optimizer: optimizer for `disc_model`.
        checkpoint_dir: accepted but not used by the part of the loop present here.
        checkpoint_interval: accepted but not used by the part of the loop present here.
        nepochs: training stops once the global epoch counter reaches this value.

    Relies on `hparams`, `get_sync_loss` and `recon_loss` being available at module level.
    """
    global global_step, global_epoch
    # `tqdm` is not among this script's imports, so bring it into scope here.
    from tqdm import tqdm

    resumed_step = global_step
    # `F` (torch.nn.functional) is never imported by this script; reach it through `nn`.
    bce_loss = nn.functional.binary_cross_entropy

    while global_epoch < nepochs:
        print('Starting Epoch: {}'.format(global_epoch))
        running_sync_loss, running_l1_loss, running_disc_loss, running_perceptual_loss = 0., 0., 0., 0.
        prog_bar = tqdm(enumerate(train_data_loader))
        for step, (x, indiv_mels, mel, gt) in prog_bar:
            model.train()
            disc_model.train()
            optimizer.zero_grad()
            disc_optimizer.zero_grad()

            # Move data to the training device
            x = x.to(device)
            mel = mel.to(device)
            indiv_mels = indiv_mels.to(device)
            gt = gt.to(device)

            g = model(indiv_mels, x)

            # Sync loss
            if hparams.syncnet_wt > 0.:
                sync_loss = get_sync_loss(mel, g)
            else:
                sync_loss = 0.

            # L1 loss
            l1loss = recon_loss(g, gt)

            # Combined generator loss; the perceptual (GAN) term is only added when enabled.
            loss = hparams.syncnet_wt * sync_loss + (1 - hparams.syncnet_wt - hparams.disc_wt) * l1loss
            if hparams.disc_wt > 0.:
                perceptual_loss = disc_model.perceptual_forward(g)
                loss += hparams.disc_wt * perceptual_loss
            else:
                perceptual_loss = 0.

            loss.backward()
            optimizer.step()

            # Train discriminator
            if hparams.disc_wt > 0.:
                # Clears the discriminator gradients accumulated by the generator's backward pass.
                disc_optimizer.zero_grad()

                # Real samples - should be classified as real (1)
                real_pred = disc_model(gt)
                real_loss = bce_loss(real_pred, torch.ones((len(real_pred), 1), device=device))

                # Fake samples - should be classified as fake (0); detach() keeps this
                # update from back-propagating into the generator.
                fake_pred = disc_model(g.detach())
                fake_loss = bce_loss(fake_pred, torch.zeros((len(fake_pred), 1), device=device))

                disc_loss = (real_loss + fake_loss) / 2
                disc_loss.backward()
                disc_optimizer.step()
            else:
                disc_loss = 0.

            # One more optimisation step done.
            global_step += 1

            # Rest of the training loop...

        # Advance the epoch counter; without this the enclosing `while` never terminates.
        global_epoch += 1

# Main execution
if __name__ == "__main__":
    # Checkpoints go to the directory given on the command line; create it if needed.
    checkpoint_dir = args.checkpoint_dir
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Dataset and Dataloader setup
    # NOTE(review): `Dataset` is not defined or imported in this cell — presumably it
    # comes from the omitted wav2lip_train.py code; confirm before running.
    train_dataset = Dataset('train')
    test_dataset = Dataset('val')

    train_data_loader = data_utils.DataLoader(
        train_dataset, batch_size=hparams.batch_size, shuffle=True,
        num_workers=hparams.num_workers)

    # The validation loader is not shuffled and uses a fixed worker count of 4.
    test_data_loader = data_utils.DataLoader(
        test_dataset, batch_size=hparams.batch_size,
        num_workers=4)

    device = torch.device("cuda" if use_cuda else "cpu")

    # Model
    model = Wav2Lip().to(device)
    disc_model = Discriminator().to(device)
    
    print('Generator: total trainable params {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
    print('Discriminator: total trainable params {}'.format(sum(p.numel() for p in disc_model.parameters() if p.requires_grad)))

    # Separate Adam optimizers (and learning rates) for generator and discriminator.
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=hparams.initial_learning_rate)
    disc_optimizer = optim.Adam([p for p in disc_model.parameters() if p.requires_grad],
                           lr=hparams.disc_initial_learning_rate)

    # Optional resume of generator and discriminator from earlier checkpoints.
    # NOTE(review): `load_checkpoint` and `syncnet` are not defined in this cell —
    # presumably provided by the omitted wav2lip_train.py code; confirm.
    if args.checkpoint_path is not None:
        load_checkpoint(args.checkpoint_path, model, optimizer, reset_optimizer=False)
        
    if args.disc_checkpoint_path is not None:
        load_checkpoint(args.disc_checkpoint_path, disc_model, disc_optimizer, reset_optimizer=False, overwrite_global_states=False)
        
    # The sync-expert checkpoint is always loaded; no optimizer is passed for it.
    load_checkpoint(args.syncnet_checkpoint_path, syncnet, None, reset_optimizer=True, overwrite_global_states=False)

    # Train!
    train(device, model, disc_model, train_data_loader, test_data_loader, optimizer, disc_optimizer,
              checkpoint_dir=checkpoint_dir,
              checkpoint_interval=hparams.checkpoint_interval,
              nepochs=hparams.nepochs)


In [None]:
import cv2
import numpy as np
import torch
import os
import argparse
from tqdm import tqdm
from models import Wav2Lip
import face_detection
import audio

# Command-line options: checkpoint, face input, audio input and output path are required.
# --smooth_factor (float, default 0.8) is the weight main() gives the generated face when
# blending it with the original crop; --enhance_face switches on enhance_face().
parser = argparse.ArgumentParser(description='Inference with improved post-processing')
parser.add_argument('--checkpoint_path', help='Path to the Wav2Lip model checkpoint', required=True)
parser.add_argument('--face', help='Path to video/image that contains faces to use', required=True)
parser.add_argument('--audio', help='Path to audio file to use', required=True)
parser.add_argument('--outfile', help='Path to save the output file', required=True)
parser.add_argument('--smooth_factor', help='Smoothing factor for blending', default=0.8, type=float)
parser.add_argument('--enhance_face', help='Apply face enhancement', action='store_true')
args = parser.parse_args()

def get_smoothened_boxes(boxes, T):
    """Smooth detection boxes across frames with a 3-tap (0.25, 0.5, 0.25) filter.

    Every interior box becomes `0.5 * box[i] + 0.25 * box[i-1] + 0.25 * box[i+1]`, computed
    from the unsmoothed neighbours and rounded to integers so the coordinates remain usable
    as array slice bounds. The first and last box are left as they are.

    Args:
        boxes: Sequence of `[x1, y1, x2, y2]` boxes (lists or arrays); updated in place.
        T: Accepted for interface compatibility; the filter width is fixed at three frames.

    Returns:
        The same `boxes` object, holding lists of Python ints when smoothing was applied.
        Sequences with fewer than three boxes are returned unchanged.
    """
    if len(boxes) < 3:
        return boxes

    # Work on a float copy: plain Python lists cannot be multiplied by a float, and reading
    # from the copy keeps already-smoothed values from feeding into later frames.
    raw = np.asarray(boxes, dtype=np.float64)
    smoothed = raw.copy()
    smoothed[1:-1] = 0.5 * raw[1:-1] + 0.25 * raw[:-2] + 0.25 * raw[2:]

    boxes[:] = np.rint(smoothed).astype(int).tolist()
    return boxes

def face_detect(images):
    """Detect one face box per image and return the boxes smoothed across frames.

    Detection runs in batches of 16; on a `RuntimeError` the batch size is halved and the
    whole pass is retried. Each detected rectangle is padded by 10 pixels at the bottom and
    clipped to the image bounds.

    Frames in which no face is found reuse the box of the nearest earlier detection (or of
    the first detection, for leading frames), so the returned list always has exactly one
    box per input image and stays aligned with the frame list.

    Args:
        images: Sequence of image arrays (the frames of one video).

    Returns:
        List of `[x1, y1, x2, y2]` boxes, one per image.

    Raises:
        RuntimeError: Detection still fails with a batch size of 1.
        ValueError: No face was detected in any image.
    """
    detector = face_detection.FaceAlignment(face_detection.LandmarksType._2D,
                                            flip_input=False, device='cuda' if torch.cuda.is_available() else 'cpu')
    batch_size = 16
    while True:
        predictions = []
        try:
            for i in tqdm(range(0, len(images), batch_size)):
                batch = images[i:i + batch_size]
                predictions.extend(detector.get_detections_for_batch(np.array(batch)))
            break
        except RuntimeError as err:
            # Stop instead of falling through with an incomplete prediction list.
            if batch_size == 1:
                raise RuntimeError('Face detection failed even with a batch size of 1') from err
            batch_size //= 2

    results = []
    pady1, pady2, padx1, padx2 = [0, 10, 0, 0]
    last_box = None
    leading_misses = 0  # frames seen before the first successful detection
    for rect, image in zip(predictions, images):
        if rect is None:
            if last_box is None:
                leading_misses += 1
            else:
                results.append(list(last_box))
            continue

        y1 = max(0, rect[1] - pady1)
        y2 = min(image.shape[0], rect[3] + pady2)
        x1 = max(0, rect[0] - padx1)
        x2 = min(image.shape[1], rect[2] + padx2)
        box = [x1, y1, x2, y2]

        # Back-fill frames that preceded the first detection with this box.
        results.extend(list(box) for _ in range(leading_misses))
        leading_misses = 0

        results.append(box)
        last_box = box

    if last_box is None and len(images) > 0:
        raise ValueError('No face detected in any frame')

    boxes = get_smoothened_boxes(results, T=5)
    return boxes

def enhance_face(face):
    """Lightly sharpen a BGR face image and equalise its lightness channel with CLAHE."""
    # 3x3 sharpening kernel whose weights sum to one
    sharpen_kernel = np.array([[-1, -1, -1],
                               [-1,  9, -1],
                               [-1, -1, -1]])

    # Mix 70% of the input with 30% of its sharpened version
    mixed = cv2.addWeighted(face, 0.7, cv2.filter2D(face, -1, sharpen_kernel), 0.3, 0)

    # Contrast-limited histogram equalisation on the L channel only, colour channels untouched
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(mixed, cv2.COLOR_BGR2LAB))
    equaliser = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
    rebuilt = cv2.merge([equaliser.apply(lightness), chan_a, chan_b])

    return cv2.cvtColor(rebuilt, cv2.COLOR_LAB2BGR)

def main():
    """Generate a lip-synced video from the command-line arguments and write it to --outfile.

    Steps: load the Wav2Lip checkpoint, read all frames of --face, detect a face box per
    frame, compute the mel spectrogram of --audio, run the model on each face crop with
    its audio segment, blend the prediction into the lower part of the original crop and
    write the frames with the input's frame rate.

    Raises:
        ValueError: No frame could be read from --face.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load model
    model = Wav2Lip()
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(args.checkpoint_path, map_location=device)
    # Strip the 'module.' prefix from the parameter names before loading.
    new_s = {k.replace('module.', ''): v for k, v in checkpoint["state_dict"].items()}
    model.load_state_dict(new_s)
    model = model.to(device)
    model.eval()

    # Load video; the capture handle is released even if reading fails part-way.
    video_stream = cv2.VideoCapture(args.face)
    try:
        fps = video_stream.get(cv2.CAP_PROP_FPS)

        # Read video frames
        full_frames = []
        while True:
            ret, frame = video_stream.read()
            if not ret:
                break
            full_frames.append(frame)
    finally:
        video_stream.release()

    if not full_frames:
        raise ValueError('No frames could be read from {}'.format(args.face))

    # Detect faces
    face_boxes = face_detect(full_frames)

    # Process audio
    wav = audio.load_wav(args.audio, 16000)
    mel = audio.melspectrogram(wav)

    # Process each frame
    output_frames = []
    for i, (frame, bbox) in enumerate(tqdm(zip(full_frames, face_boxes), total=len(full_frames))):
        # Slice bounds must be integers, whatever numeric type the box holds.
        x1, y1, x2, y2 = (int(v) for v in bbox)
        original_face = frame[y1:y2, x1:x2]

        # Prepare face for model
        face = cv2.resize(original_face, (96, 96))
        face = np.transpose(face, (2, 0, 1))
        face = face / 255.0
        face = torch.FloatTensor(face).unsqueeze(0).to(device)

        # Get corresponding audio segment (80 mel steps per second, 16 steps per frame)
        start_idx = int(80. * (i / float(fps)))
        end_idx = start_idx + 16
        if end_idx > len(mel):
            break
        # NOTE(review): this slices `mel` along its first axis and feeds a 3-channel crop.
        # Upstream Wav2Lip is believed to expect a (batch, 1, 80, 16) mel chunk cut along
        # the time axis and a 6-channel face input — confirm against the `audio` and
        # `models` modules actually in use.
        mel_segment = torch.FloatTensor(mel[start_idx:end_idx]).unsqueeze(0).to(device)

        # Generate lip-synced face
        with torch.no_grad():
            pred = model(mel_segment, face)

        # Convert back to numpy
        pred = pred.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255.
        pred = pred.astype(np.uint8)

        # Apply face enhancement if requested
        if args.enhance_face:
            pred = enhance_face(pred)

        # Resize back to original size
        pred = cv2.resize(pred, (x2 - x1, y2 - y1))

        # Create a mask for smooth blending
        mask = np.zeros((y2 - y1, x2 - x1), dtype=np.float32)
        mask[int(mask.shape[0]*0.3):] = 1.0  # Only blend the lower part of the face (mouth region)
        mask = cv2.GaussianBlur(mask, (15, 15), 5)
        mask = np.expand_dims(mask, -1)

        # Blend the generated face with the original using the mask: where the mask is 1
        # the prediction gets weight smooth_factor, where it is 0 the original is kept.
        alpha = args.smooth_factor * mask
        blended_face = alpha * pred + (1 - alpha) * original_face
        blended_face = np.clip(blended_face, 0, 255).astype(np.uint8)

        # Copy the blended face back to the original frame
        output_frame = frame.copy()
        output_frame[y1:y2, x1:x2] = blended_face

        output_frames.append(output_frame)

    # Write output video
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(args.outfile, fourcc, fps, (full_frames[0].shape[1], full_frames[0].shape[0]))
    try:
        for frame in output_frames:
            out.write(frame)
    finally:
        out.release()

if __name__ == "__main__":
    main()


@echo off
rem Runs preprocessing, base Wav2Lip training and Wav2Lip+GAN training in sequence.
rem The pipeline stops at the first step that exits with a non-zero status.
setlocal
echo Starting Wav2Lip+GAN training pipeline

rem Quoted assignments keep stray trailing spaces out of the values.
set "DATA_ROOT=E:\training data"
set "OUTPUT_DIR=E:\play_6.0\processed_data"
set "CHECKPOINT_DIR=E:\play_6.0\checkpoints"
set "SYNCNET_PATH=E:\play_6.0\Wav2Lip\checkpoints\syncnet.pth"

echo Step 1: Preprocessing data...
python preprocess.py --data_root "%DATA_ROOT%" --output_dir "%OUTPUT_DIR%"
if errorlevel 1 goto :failed

echo Step 2: Training Wav2Lip model...
python train_wav2lip.py --data_root "%OUTPUT_DIR%" --checkpoint_dir "%CHECKPOINT_DIR%" --syncnet_checkpoint_path "%SYNCNET_PATH%"
if errorlevel 1 goto :failed

echo Step 3: Training Wav2Lip+GAN model...
python train_wav2lip_gan.py --data_root "%OUTPUT_DIR%" --checkpoint_dir "%CHECKPOINT_DIR%\gan" --syncnet_checkpoint_path "%SYNCNET_PATH%" --checkpoint_path "%CHECKPOINT_DIR%\checkpoint_latest.pth"
if errorlevel 1 goto :failed

echo Training complete! Models saved to %CHECKPOINT_DIR%
exit /b 0

:failed
echo Pipeline stopped: the previous step exited with an error.
exit /b 1
