In [1]:
!pip install soundfile
!pip install pandas
!pip install scipy
!pip install librosa
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from pathlib import Path
import pandas as pd
import numpy as np
import soundfile as sf
import scipy
from scipy import fft
import functools
from torch import nn
import librosa
import math


Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [2]:
   
def stft(filepath, 
         viseme_frame_len_in_samples, 
         audio_window_in_samples, 
         audio_bins_per_window,
         resample_to, 
         pad_len_in_secs, 
         n_mels):
    wdw = np.hanning(audio_window_in_samples)
    audio, rate = sf.read(filepath)
    audio = librosa.resample(audio, rate, resample_to)
    
    # left-pad the audio so we have the left context when starting at the initial viseme
    pad_len_in_samples = pad_len_in_secs * resample_to 
    if audio.shape[0] < pad_len_in_samples:
        audio = np.pad(audio, (0, (pad_len_in_secs * resample_to) - audio.shape[0]), constant_values=0.001)
    elif audio.shape[0] > pad_len_in_samples:
        audio = audio[:pad_len_in_samples]
    #audio = np.hstack([np.zeros((win_length*2,)), audio])
        
    padded_seq_length = (resample_to * pad_len_in_secs) // viseme_frame_len_in_samples
    actual_seq_length = audio.shape[0] // viseme_frame_len_in_samples
    
    # first, we need to take the STFT of the entire audio file 
    # with a window size equivalent to the audio_window_in_samples / audio_bins_per_window
    bin_length_in_samples = int(audio_window_in_samples / audio_bins_per_window)
    melfb = librosa.filters.mel(resample_to, bin_length_in_samples, n_mels=n_mels)

    transformed = librosa.stft(audio, 
                               n_fft=bin_length_in_samples,
                               hop_length=bin_length_in_samples)
    
    mels = np.dot(melfb, np.abs(transformed))
    
    log_mels = np.log(mels)
    
    coeffs = fft.dct(log_mels)
    
    output = np.zeros((padded_seq_length, audio_bins_per_window*n_mels))
    
    bin_hop_length = transformed.shape[1] // padded_seq_length
    
    for i in range(padded_seq_length):
        c = coeffs[:, (i*bin_hop_length):(i*bin_hop_length)+audio_bins_per_window]
        if(c.shape[1] < audio_bins_per_window):
            c = np.pad(c, [(0,0), (0, audio_bins_per_window - c.shape[1])], constant_values=0)
        output[i, :] = np.reshape(c, audio_bins_per_window*n_mels)

    return output, audio.shape[0], [1] * actual_seq_length + [0] * (padded_seq_length - actual_seq_length)



In [3]:
def preprocess_viseme(csv, pad_len_in_secs=None, resample_to=None, blendshapes=None):
    csv = pd.read_csv(csv)
    # first, drop every nth row to reduce effective framerate
    
    csv = csv.iloc[::int(59.97 / resample_to)]
    pad_len = int(pad_len_in_secs * resample_to)
    if(csv.shape[0] < pad_len):
        pad = pd.DataFrame(0, index=[i for i in range(pad_len - csv.shape[0])], columns=csv.columns)
        pad.pad(inplace=True)
        csv = pd.concat([csv, pad])
    else:
        csv = csv.iloc[:pad_len]
        #print("Visemes exceeded max length, truncate?")

    #split = csv["Timecode"].str.split(':')
    #minute = split.str[1].astype(int)
    #second = split.str[2].astype(int)
    #frame = split.str[3].astype(float)
    #minute -= minute[0]
    #ms
    #step = minute * 60 + second
    #csv["step"] = step
    #return csv.drop_duplicates(["step"])[["step", "MouthClose","MouthFunnel","MouthPucker","JawOpen"]]
    if blendshapes is None:
        return csv
    return csv[blendshapes]

In [4]:
# data_dir should be structured as follows:
# - speaker_id_1/
# - speaker_id_1/sample_id1.wav
# - speaker_id_1/sample_id1.csv
# - speaker_id_1/sample_id2.wav
# - speaker_id_1/sample_id2.wav
# - speaker_id_2/sample_id1.wav
# ..
class VisemeDataset(Dataset):
    def __init__(self, data_dir, audio_transform, viseme_transform):
        self.viseme_transform= viseme_transform
        self.audio_transform = audio_transform
        self.audio_files = []
        self.visemes = []
        for file in list(Path(data_dir).rglob("*.wav")):
            self.audio_files.append(file)
            self.visemes.append(str(file).replace("wav", "csv"))

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        fft, num_samples, mask = self.audio_transform(self.audio_files[idx])
        fft = torch.tensor(fft, dtype=torch.float)
        viseme_filename = self.visemes[idx]
        visemes = torch.tensor(self.viseme_transform(viseme_filename).values.astype(np.float32))       
        assert(visemes.shape[0] == fft.shape[0])
        return fft, torch.tensor(mask), visemes, viseme_filename

In [5]:
# the blendshapes that will used as input labels (and predictions)
blendshapes = ["MouthClose", "MouthFunnel", "MouthPucker", "JawOpen"]

# source framerate for raw viseme label input
framerate=59.97 

# actual framerate to use for viseme labels. 
# raw labels will be resampled/transformed (either averaged or simply dropped).
target_framerate = framerate / 2 

# the sample rate that audio will be resampled to
resample_to = 22050

# the duration of each viseme frame, in seconds
viseme_frame_len_in_seconds = 1 / target_framerate

# STFT hop length should equal to length of a single viseme frame, in samples
# round up to ensure consistency with viseme frame length
viseme_frame_len_in_samples = math.ceil(resample_to * viseme_frame_len_in_seconds)

# all audio will be padded to the following size
pad_len_in_secs = 4

# each resampled viseme label will correspond to an audio window of size X (in samples)
# where the viseme is centered on sample (X-1)/2, with X/2 samples to the left and X/2 samples to the right
# let's use 1 second
audio_window_in_secs = 1
audio_bins_per_viseme = 63
audio_window_in_samples = (audio_window_in_secs * resample_to)

num_mels=39

seq_length = math.ceil((pad_len_in_secs * resample_to) / viseme_frame_len_in_samples)
seq_length

120

In [6]:
batch_size = 10

process_audio = functools.partial(stft, 
                                    viseme_frame_len_in_samples=viseme_frame_len_in_samples, # this refers to the size of the viseme/audio window,
    audio_window_in_samples=audio_window_in_samples,
    audio_bins_per_window=audio_bins_per_viseme,
    resample_to=resample_to, 
    pad_len_in_secs=pad_len_in_secs,
    n_mels=num_mels)


process_viseme = functools.partial(preprocess_viseme, 
                                   pad_len_in_secs=pad_len_in_secs, 
                                   blendshapes=blendshapes, 
                                   resample_to=target_framerate)

training_data = VisemeDataset("./data/training/speaker_1/", 
                              process_audio, \
                              process_viseme)
test_data = VisemeDataset("./data/test/speaker_1/", 
                              process_audio, \
                              process_viseme)

train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=True)



In [112]:
class SeparableConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, bias=False):
        super(SeparableConv1d, self).__init__()
        self.depthwise = nn.Conv1d(in_channels, in_channels, kernel_size=kernel_size, 
                               groups=in_channels, bias=bias, padding=1)
        self.pointwise = nn.Conv1d(in_channels, out_channels, 
                               kernel_size=1, bias=bias)
    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out
    
class Conv1dModel(nn.Module):
    def __init__(self, seq_length=633, n_ffts=2457, num_viseme=4, ks=16):
        super(Conv1dModel, self).__init__()
        self.conv = nn.Sequential(
            SeparableConv1d(seq_length,seq_length,ks),
            nn.ReLU(),
            #SeparableConv1d(2431,2431,ks),
            #nn.ReLU(),
        )
        self.attention = nn.MultiheadAttention(2444, 1, batch_first=True)
        #self.linear_relu_stack2 = nn.Sequential(
        #    SeparableConv1d(seq_length,seq_length,ks),
        #    #nn.Linear(seq_length, seq_length),
        #    nn.ReLU(),
        #)
        self.linear_out = nn.Linear(2444, num_viseme)

    def forward(self, x):
        o1 = self.conv(x)
        attn_output, attn_output_weights = self.attention(o1, o1, o1)
        #o1 = torch.sum(attn_output, 2, keepdim=True)
        #o1 = self.linear_relu_stack2(o1)
        #print(o1.shape)
        o1 = self.linear_out(o1)
        return o1
        

In [113]:
class BiLSTMModel(nn.Module):
    def __init__(self, seq_length=633, n_ffts=2457, embed_dim=512, num_viseme=4):
        super(BiLSTMModel, self).__init__()
        self.embed_dim = embed_dim
        self.lstm = torch.nn.LSTM(n_ffts, embed_dim, 3, bidirectional=False)
        #self.attention = nn.MultiheadAttention(embed_dim, 1, batch_first=True)
        #self.attention = nn.Linear
        
        self.linear_out = nn.Sequential(
            #nn.ReLU(),
            nn.Linear(embed_dim, num_viseme),
          #  nn.Linear(embed_dim, num_viseme),
            #nn.ReLU(),
            #nn.Linear(num_viseme, num_viseme),

            )
    def forward(self, x):
        out_f, _ = self.lstm(x)
        #out_f = out_f[:,:,:self.embed_dim] + out_f[:,:,self.embed_dim:]
        #attn_output, attn_output_weights = self.attention(out_f, out_f, out_f)
        #product = out_f * attn_output        
        return self.linear_out(out_f)
        #print(out_f)
        #print(attn_output)
        #a_s = torch.sum(attn_output, 1)
        #return self.linear_out(attn_output, a_s)
        
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

model = Conv1dModel(seq_length=seq_length-1, n_ffts=2457).to(device)
#model = BiLSTMModel(seq_length=seq_length,embed_dim=512, n_ffts=2457).to(device)

Using cuda device


In [114]:
learning_rate = 0.0001
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
num_steps = 100000
print_loss_every = 50
eval_every = 500

batch = iter(train_dataloader)
train_features, train_mask , train_labels, _ = next(batch, (None,None,None,None))
for t in range(num_steps):
        
    if train_features is None:
        batch = iter(train_dataloader)
        train_features, train_mask , train_labels, _ = next(batch)
        
    train_mask = torch.unsqueeze(train_mask, 2)
    x = train_features.to(device)      
    y = train_labels.to(device) #* train_mask.to(device)

    preds = model(x) #* train_mask.to(device)
    
    
    loss = torch.nn.functional.huber_loss(preds, y)
    for i in range(preds.shape[1] - 1):
        loss += torch.nn.functional.cosine_embedding_loss(preds[:,i,:], preds[:,i+1,:], (torch.ones(preds.shape[0])).to(device))
    
    if t % print_loss_every == 0:
        print(f"Step {t} Loss: {loss.item()}")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if t > 0 and t % eval_every == 0:
        accum_loss = 0
        for test_features, test_mask, test_labels, _ in iter(test_dataloader):
            x = test_features.to(device)
            #x = torch.transpose(x, 1, 2)
            #x = torch.unsqueeze(x, dim=3)
            y = test_labels.to(device)
            preds = model(x)
            accum_loss += torch.nn.functional.mse_loss(preds, y).item()
        print(f"Test loss {accum_loss}")
    
#pred_probab = nn.Softmax(dim=1)(logits)
#y_pred = pred_probab.argmax(1)
#print(f"Predicted class: {y_pred}")

Step 0 Loss: 54.84494400024414
Step 50 Loss: 8.735651016235352
Step 100 Loss: 7.009223461151123
Step 150 Loss: 3.5436196327209473
Step 200 Loss: 3.000743865966797
Step 250 Loss: 4.123295307159424
Step 300 Loss: 5.30915641784668
Step 350 Loss: 9.18289852142334
Step 400 Loss: 2.0758512020111084
Step 450 Loss: 1.3796993494033813
Step 500 Loss: 2.407243251800537
Test loss 79.16479682922363
Step 550 Loss: 1.1363482475280762
Step 600 Loss: 1.4173367023468018
Step 650 Loss: 2.304938316345215
Step 700 Loss: 2.226378917694092
Step 750 Loss: 1.9173606634140015
Step 800 Loss: 1.9451627731323242
Step 850 Loss: 1.873793363571167
Step 900 Loss: 2.0740981101989746
Step 950 Loss: 1.684860348701477
Step 1000 Loss: 1.2349035739898682
Test loss 25.090959668159485
Step 1050 Loss: 0.8629779815673828
Step 1100 Loss: 0.6628715991973877
Step 1150 Loss: 0.6930238008499146
Step 1200 Loss: 0.5193652510643005
Step 1250 Loss: 0.5513647794723511
Step 1300 Loss: 1.027346134185791
Step 1350 Loss: 1.188471794128418
St

KeyboardInterrupt: 

In [115]:
test_features, test_mask, test_labels, test_files = next(iter(test_dataloader))
test_features, test_mask, test_labels, test_files = next(iter(train_dataloader))
##test_features = torch.transpose(test_features, 1, 2)

export_y_batch = model(test_features.to(device)) 
export_y = export_y_batch[0,:,:]
print(export_y.shape)
header = "Timecode,BlendShapeCount,eyeBlinkRight,eyeLookDownRight,eyeLookInRight,eyeLookOutRight,eyeLookUpRight,eyeSquintRight,eyeWideRight,eyeBlinkLeft,eyeLookDownLeft,eyeLookInLeft,eyeLookOutLeft,eyeLookUpLeft,eyeSquintLeft,eyeWideLeft,jawForward,jawRight,jawLeft,jawOpen,mouthClose,mouthFunnel,mouthPucker,mouthRight,mouthLeft,mouthSmileRight,mouthSmileLeft,mouthFrownRight,mouthFrownLeft,mouthDimpleRight,mouthDimpleLeft,mouthStretchRight,mouthStretchLeft,mouthRollLower,mouthRollUpper,mouthShrugLower,mouthShrugUpper,mouthPressRight,mouthPressLeft,mouthLowerDownRight,mouthLowerDownLeft,mouthUpperUpRight,mouthUpperUpLeft,browDownRight,browDownLeft,browInnerUp,browOuterUpRight,browOuterUpLeft,cheekPuff,cheekSquintRight,cheekSquintLeft,noseSneerRight,noseSneerLeft,tongueOut,HeadYaw,HeadPitch,HeadRoll,LeftEyeYaw,LeftEyePitch,LeftEyeRoll,RightEyeYaw,RightEyePitch,RightEyeRoll".split(',')
selected_output_indices = [header.index(x[0].lower() + x[1:]) for x in blendshapes]
num_visemes = len(blendshapes)
with open("output.csv", "w") as outfile:
    outfile.write(",".join(header) + "\n")
    timer_ms = 0
    for t in range(export_y.shape[0]):
        output = [str(0)] * len(header)
        second = str(int(timer_ms // 1000)).zfill(2)
        frame = (timer_ms % 1000) * target_framerate / 1000
        output[0] = f"00:00:{second}:{frame}"
        for viseme in range(num_visemes): 
            output[selected_output_indices[viseme]] = str(export_y[t,viseme].item())
        timer_ms += (1 / target_framerate) * 1000
        outfile.write(",".join(output) + "\n")

torch.Size([119, 4])


In [None]:
export_y.shape

In [38]:
test_files[0]

'data/training/speaker_1/20210826_2/75.csv'

In [19]:
new_header = "Timecode,BlendShapeCount,EyeBlinkLeft,EyeLookDownLeft,EyeLookInLeft,EyeLookOutLeft,EyeLookUpLeft,EyeSquintLeft,EyeWideLeft,EyeBlinkRight,EyeLookDownRight,EyeLookInRight,EyeLookOutRight,EyeLookUpRight,EyeSquintRight,EyeWideRight,JawForward,JawRight,JawLeft,JawOpen,MouthClose,MouthFunnel,MouthPucker,MouthRight,MouthLeft,MouthSmileLeft,MouthSmileRight,MouthFrownLeft,MouthFrownRight,MouthDimpleLeft,MouthDimpleRight,MouthStretchLeft,MouthStretchRight,MouthRollLower,MouthRollUpper,MouthShrugLower,MouthShrugUpper,MouthPressLeft,MouthPressRight,MouthLowerDownLeft,MouthLowerDownRight,MouthUpperUpLeft,MouthUpperUpRight,BrowDownLeft,BrowDownRight,BrowInnerUp,BrowOuterUpLeft,BrowOuterUpRight,CheekPuff,CheekSquintLeft,CheekSquintRight,NoseSneerLeft,NoseSneerRight,TongueOut,HeadYaw,HeadPitch,HeadRoll,LeftEyeYaw,LeftEyePitch,LeftEyeRoll,RightEyeYaw,RightEyePitch,RightEyeRoll"
remap = {h:(h[0].lower() + h[1:]) if h not in ["Timecode","BlendShapeCount","HeadYaw","HeadPitch","HeadRoll","LeftEyeYaw","LeftEyePitch","LeftEyeRoll","RightEyeYaw","RightEyePitch","RightEyeRoll"]  else h for h in new_header.split(",") }
for oh in remap.values():
    if oh not in header:
        print(oh)

print(remap)
def new_to_old(csv_df):
    print(csv_df)
    return csv_df.rename(columns=remap)
df = preprocess_viseme("data/training/speaker_1/20210824_1/61.csv", pad_len_in_secs=pad_len_in_secs, 
                                   resample_to=target_framerate)
#print(df)
new_to_old(df)[header].to_csv("original.csv", index=False)
df

{'Timecode': 'Timecode', 'BlendShapeCount': 'BlendShapeCount', 'EyeBlinkLeft': 'eyeBlinkLeft', 'EyeLookDownLeft': 'eyeLookDownLeft', 'EyeLookInLeft': 'eyeLookInLeft', 'EyeLookOutLeft': 'eyeLookOutLeft', 'EyeLookUpLeft': 'eyeLookUpLeft', 'EyeSquintLeft': 'eyeSquintLeft', 'EyeWideLeft': 'eyeWideLeft', 'EyeBlinkRight': 'eyeBlinkRight', 'EyeLookDownRight': 'eyeLookDownRight', 'EyeLookInRight': 'eyeLookInRight', 'EyeLookOutRight': 'eyeLookOutRight', 'EyeLookUpRight': 'eyeLookUpRight', 'EyeSquintRight': 'eyeSquintRight', 'EyeWideRight': 'eyeWideRight', 'JawForward': 'jawForward', 'JawRight': 'jawRight', 'JawLeft': 'jawLeft', 'JawOpen': 'jawOpen', 'MouthClose': 'mouthClose', 'MouthFunnel': 'mouthFunnel', 'MouthPucker': 'mouthPucker', 'MouthRight': 'mouthRight', 'MouthLeft': 'mouthLeft', 'MouthSmileLeft': 'mouthSmileLeft', 'MouthSmileRight': 'mouthSmileRight', 'MouthFrownLeft': 'mouthFrownLeft', 'MouthFrownRight': 'mouthFrownRight', 'MouthDimpleLeft': 'mouthDimpleLeft', 'MouthDimpleRight': 'mo

Unnamed: 0,Timecode,BlendShapeCount,EyeBlinkLeft,EyeLookDownLeft,EyeLookInLeft,EyeLookOutLeft,EyeLookUpLeft,EyeSquintLeft,EyeWideLeft,EyeBlinkRight,...,TongueOut,HeadYaw,HeadPitch,HeadRoll,LeftEyeYaw,LeftEyePitch,LeftEyeRoll,RightEyeYaw,RightEyePitch,RightEyeRoll
0,173:26:49:24.936,61,0.131980,0.355326,0.0,0.015279,0.0,0.018205,0.0,0.132631,...,0.000379,0.084126,0.022471,-1.538283,-0.074095,0.217107,-0.016328,-0.009114,0.217065,-0.002010
2,173:26:49:26.937,61,0.131097,0.352878,0.0,0.012558,0.0,0.018077,0.0,0.131835,...,0.000418,0.081196,0.022907,-1.539148,-0.072515,0.215608,-0.015867,-0.007494,0.215567,-0.001641
4,173:26:49:28.937,61,0.125842,0.351979,0.0,0.012761,0.0,0.018064,0.0,0.126697,...,0.000165,0.084483,0.026060,-1.537370,-0.072600,0.215060,-0.015843,-0.007616,0.215018,-0.001663
6,173:26:49:30.938,61,0.124457,0.351926,0.0,0.010018,0.0,0.018040,0.0,0.125371,...,0.000134,0.086911,0.026403,-1.538250,-0.070950,0.215025,-0.015481,-0.005979,0.214983,-0.001306
8,173:26:49:32.939,61,0.134261,0.355852,0.0,0.008626,0.0,0.017907,0.0,0.134898,...,0.000091,0.083985,0.023684,-1.539549,-0.070115,0.217422,-0.015475,-0.005145,0.217380,-0.001136
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24,0,0,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.0,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
25,0,0,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.0,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
26,0,0,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.0,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
27,0,0,0.000000,0.000000,0.0,0.000000,0.0,0.000000,0.0,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
