## Import Packages

In [1]:
from os import listdir, path
import numpy as np
import scipy, cv2, os, argparse, audio
import subprocess
import librosa
from tqdm import tqdm
import audio.audio_utils as audio
from model import *
import audio.hparams as hp 
import torch

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


# Set device

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Preprocessing video file

In [3]:
# Extracts frames from the video
def get_frames(file, mask):

    video_stream = cv2.VideoCapture(file)

    frames = []
    while 1:
        still_reading, frame = video_stream.read()

        if not still_reading:
            video_stream.release()
            break

        # Mask out the specified regions
        if mask == 'r':
            index = frame.shape[1]//2
            frame = frame[:, :index]
        elif mask == 'l':
            index = frame.shape[1]//2
            frame = frame[:, index:]

        frames.append(frame)

    return frames

# Function to obtain the window of images
def get_window_images(window_images):
    window = []
    for img in window_images:
        if img is None:
            raise FileNotFoundError('Missing frames!')

        img = cv2.resize(img, (hp.hparams.img_size, hp.hparams.img_size))		# 3x96x96
        window.append(img)

    x_image = np.asarray(window) / 255. 

    return x_image

# Audio Preprocessing

In [4]:
# Functon to load the wav file
def load_wav(input_file):

    wav_file  = 'tmp.wav';

    subprocess.call('ffmpeg -hide_banner -loglevel panic -threads 1 -y -i %s -async 1 -ac 1 -vn \
                    -acodec pcm_s16le -ar 16000 %s' % (input_file, wav_file), shell=True)

    wav = audio.load_wav(wav_file, 16000)

    os.remove("tmp.wav")

    return wav


# Function to extract the spectrogram from the wav file
def get_spec(wav):

    # Extract the STFT
    stft = librosa.stft(y=wav, n_fft=hp.hparams.n_fft, hop_length=hp.hparams.hop_size, win_length=hp.hparams.win_size).T
    stft = stft[:-1]
    # print("STFT: ", stft.shape)                                       # 100x257

    # Decompose STFT into magnitude and phase components
    mag = np.abs(stft)
    mag = audio.db_from_amp(mag)
    phase = audio.angle(stft)

    # Normalize the magnitude and phase components
    norm_mag = audio.normalize_mag(mag)
    norm_phase = audio.normalize_phase(phase)

    # Concatenate the magnitude and phase components
    spec = np.concatenate((norm_mag, norm_phase), axis=1)               # 100x514

    return spec


# Function to segment the spectrograms
def get_window_spec(spec_ip, idx):

    frame_num = idx
    start_idx = int((hp.hparams.spec_step_size / hp.hparams.fps) * frame_num)
    end_idx = start_idx+hp.hparams.spec_step_size

    spec_window = spec_ip[start_idx:end_idx, :]                        # 100x514

    return spec_window

# Speaker dependent audio and video generation

In [5]:
# Function to reconstruct the audio and generate the output video 
def generate_video(mag, phase, input_file, result_dir):

    denorm_mag = audio.unnormalize_mag(mag)
    denorm_phase = audio.unnormalize_phase(phase)
    recon_mag = audio.amp_from_db(denorm_mag)
    complex_arr = audio.make_complex(recon_mag, denorm_phase)
    wav = librosa.istft(complex_arr, hop_length=hp.hparams.hop_size, \
                        win_length=hp.hparams.win_size)
    print("Generated wav: ", wav.shape)


    # Create the folder to save the results
    if not os.path.exists(result_dir):
        os.makedirs(result_dir)

    # Save the wav file
    audio_output = os.path.join(result_dir, 'pred_'+input_file.rsplit('/')[-1].split('.')[0] + '.wav')
    librosa.output.write_wav(audio_output, wav, 16000)

    # Save the video output file
    no_sound_video = os.path.join(result_dir, input_file.rsplit('/')[-1].split('.')[0] + '_nosouund.mp4')
    subprocess.call('ffmpeg -hide_banner -loglevel panic -i %s -c copy -an -strict -2 %s' % (input_file, no_sound_video), shell=True)

    video_output_mp4 = os.path.join(result_dir, 'pred_'+input_file.rsplit('/')[-1].split('.')[0] + '.mp4')
    if os.path.exists(video_output_mp4):
        os.remove(video_output_mp4)

    subprocess.call('ffmpeg -hide_banner -loglevel panic -y -i %s -i %s -strict -2 -q:v 1 %s' % 
                    (audio_output, no_sound_video, video_output_mp4), shell=True)

    os.remove(no_sound_video)

    print("Successfully generated the output video:", video_output_mp4)

# Load saved checkpoint

In [6]:
# Function to load the model
def load_model(checkpoint):

    model = Model()

    if not torch.cuda.is_available():
        checkpoint = torch.load(checkpoint, map_location='cpu')
    else:
        checkpoint = torch.load(checkpoint)

    # model.load_state_dict(checkpoint['model_state_dict'])
    if torch.cuda.device_count() > 1:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        ckpt = {}
        for key in checkpoint['model_state_dict'].keys():
            k = key.split('module.', 1)[1]
            ckpt[k] = checkpoint['model_state_dict'][key]
        model.load_state_dict(ckpt)	
    model = model.to(device)

    print("Loaded model from: ", checkpoint)

    return model.eval()

# Generate predictions

In [7]:
# Function to obtain the predictions
def predict(input_file, num_frames, mask, checkpoint, result_dir):

    # Check the input video
    video_formats = ['mp4', 'avi', 'mkv']
    if input_file.rsplit('/')[-1].rsplit('.', 1)[1] not in video_formats:
        print("Oops! Invalid input. Please try again by providing the appropriate video input.")
        exit(0)

    # Extract the frames from the given input video
    faces = get_frames(input_file, mask)
    total_frames = len(faces)

    print(f'No of frames: {len(faces)}')
    if len(faces) < num_frames: 
        print("No of frames is less than {}!".format(num_frames))
        return
    print("Total no of frames = ", total_frames)

    # Obtain a window for frames
    id_windows = [range(i, i + num_frames) for i in range(0, total_frames, 
                num_frames - hp.hparams.overlap) if (i + num_frames <= total_frames)]
    print("ID windows: ", id_windows)

    all_images = [[faces[i] for i in window] for window in id_windows]
    print("All images: ", len(all_images))

    inp_wav = load_wav(input_file)
    spec_ip = get_spec(inp_wav)
    print("Noisy spec inp: ", spec_ip.shape)


    # Load the model
    model = load_model(checkpoint)


    for i, window_images in enumerate(tqdm(all_images)):

        images = get_window_images(window_images)

        if(images.shape[0] != num_frames):
            continue
        image_batch = np.expand_dims(images, axis=0)			#1x25x15x48x96

        # Get the corresponding input noisy melspectrograms
        idx = id_windows[i][0]
        spec_window = get_window_spec(spec_ip, idx)

        if(spec_window.shape[0] != hp.hparams.spec_step_size):
            continue
        spec_batch = np.expand_dims(np.array(spec_window), axis=0)

        x_mag = torch.FloatTensor(spec_batch)[..., :257].to(device)
        x_phase = torch.FloatTensor(spec_batch)[..., 257:].to(device)
        x_image = torch.FloatTensor(image_batch).to(device)

        # Predict the spectrograms for the corresponding window
        with torch.no_grad():
            pred_mag, pred_phase = model(x_mag, x_phase, x_image)


        pred_mag = pred_mag.cpu().numpy()
        pred_mag = np.squeeze(pred_mag, axis=0).T

        pred_phase = pred_phase.cpu().numpy()
        pred_phase = np.squeeze(pred_phase, axis=0).T


        # Concatenate the melspectrogram windows to generate the complete spectrogram	
        if i == 0:
            generated_mag = pred_mag[:, :80]
            generated_phase = pred_phase[:, :80]
        else:
            generated_mag = np.concatenate((generated_mag, pred_mag[:, :80]), axis=1)
            generated_phase = np.concatenate((generated_phase, pred_phase[:, :80]), axis=1)


    print("Output mag: ", generated_mag.shape)
    print("Output phase: ", generated_phase.shape)

    # Reconstruct the audio and generate the output video
    generate_video(generated_mag, generated_phase, input_file, result_dir)

# Run Code

In [8]:
# parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# args = parser.parse_args()

checkpoint = 'model_vox.pt'
input_file = '2.mp4'
mask = None # 'l' or 'r' is used to specify which speaker to mask
result_dir = 'results'
sampling_rate = 16000 # audio sampling rate
fps = 25 # video fps

img_size = 96
num_frames = 25

# parser.add_argument('--input', type=str, required=True, help='Filepath of noisy video')
# parser.add_argument('--checkpoint', type=str, required=True, help='Name of trained checkpoint to load weights from')

# parser.add_argument('--mask', default=None, required=False, help='Type either "l" or "r" which specifies to mask the left (l) or right (r) speaker')

# parser.add_argument('--result_dir', default='results', required=False, help='Name of the directory to save the results')
# parser.add_argument('--sampling_rate', type=int, required=False, default=16000)
# parser.add_argument('--fps', type=float, default=25., required=False, help='FPS of input video, ignore if image')

# args = parser.parse_args()
# args.img_size = 96
# args.num_frames = 25

predict(input_file, num_frames, mask, checkpoint, result_dir)

No of frames: 80
Total no of frames =  80
ID windows:  [range(0, 25), range(20, 45), range(40, 65)]
All images:  3
Noisy spec inp:  (262, 514)
Loaded model from:  {'step': 6000, 'model_state_dict': OrderedDict([('module.audio_encoder.0.conv_block.0.weight', tensor([[[ 0.1083,  0.1187,  0.1591],
         [-0.0836, -0.1334, -0.0745],
         [-0.1119, -0.1132, -0.0562],
         ...,
         [ 0.0211,  0.0460,  0.0265],
         [ 0.0209, -0.0172, -0.0058],
         [-0.0117,  0.0007,  0.0029]],

        [[ 0.0875, -0.0874,  0.0534],
         [-0.0358, -0.0909,  0.0048],
         [ 0.0256, -0.1184, -0.0259],
         ...,
         [ 0.0250,  0.0077, -0.0404],
         [ 0.0230,  0.0263,  0.0236],
         [-0.0012, -0.0188,  0.0113]],

        [[ 0.0109, -0.0295,  0.0060],
         [ 0.0128, -0.0384, -0.0318],
         [ 0.1256,  0.0948,  0.0689],
         ...,
         [ 0.0119, -0.0601,  0.0304],
         [ 0.0023, -0.0595,  0.0475],
         [ 0.0519, -0.0120,  0.0285]],

        ..

100%|██████████| 3/3 [00:00<00:00, 11.99it/s]


Output mag:  (257, 240)
Output phase:  (257, 240)
Generated wav:  (38240,)
Successfully generated the output video: results/pred_2.mp4
