<a href="https://colab.research.google.com/github/AndrewThanatos/Algorithms/blob/master/Simple_LSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch.nn as nn
import torch.optim as optim
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from scipy.io.wavfile import read
import matplotlib.pyplot as plt
import os
import random
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import cv2

In [None]:
# video import
from IPython.display import HTML
from base64 import b64encode

In [None]:
# Data / output locations.
DOWNLOADS_PATH = 'drive/MyDrive/Downloads'
DATASET_PATH = 'drive/MyDrive/Downloads/Model_Data/Dataset/'
VAL_DATASET_PATH = 'drive/MyDrive/Downloads/Model_Data_Eval/Dataset/'
ROOT_PATH = 'drive/MyDrive/Diploma/Audio_to_Keypoints'

TEST_SPLIT = 0.1         # fraction of DATASET_PATH samples held out for validation
FPS = 25                 # video frames per second
SPS = 16000              # audio samples per second
SPF = SPS // FPS         # audio samples per video frame (16000 // 25 = 640)

IMAGE_SIZE = 224         # frame side length used by save_video

KP_SIZE = 136            # flattened keypoint values per frame (68 (x, y) pairs)

ENCODER_HIDDEN_SIZE = 1024
# Must equal ENCODER_HIDDEN_SIZE: the encoder's final state seeds the decoder.
DECODER_HIDDEN_SIZE = 1024

SAMPLE_NUM_PER_TRAIN = 25  # NOTE(review): not referenced by any visible cell

# Fall back to CPU so the notebook still runs on a runtime without a GPU
# instead of failing at the first .to(DEVICE).
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Optimisation hyper-parameters.
BATCH_SIZE = 1   # Seq2Seq.forward only reads batch element 0
LR = 1e-3        # Adam learning rate
EPOCH = 30       # number of training epochs

In [None]:
def load_video(path):
    """Decode every frame of the video at ``path`` with OpenCV.

    Returns:
        np.ndarray stacking the frames exactly as ``VideoCapture.read`` yields
        them; an empty array if no frame could be read.
    """
    video_cap = cv2.VideoCapture(path)
    frames = []
    try:
        # Append only frames that were actually read, so an unreadable file
        # gives an empty array rather than an array containing None.
        while True:
            success, image = video_cap.read()
            if not success:
                break
            frames.append(image)
    finally:
        # Free the decoder handle even if reading raises.
        video_cap.release()
    return np.array(frames)

In [None]:
def get_samples(data_path, shuffle=True):
    """Return the integer ids of the numerically-named entries in ``data_path``.

    The ids are sorted first so that ``shuffle=False`` yields a reproducible
    order (``os.listdir`` order is arbitrary). With ``shuffle=True`` they are
    then shuffled in place using NumPy's global RNG.
    """
    # isdecimal (unlike isnumeric) only accepts names int() can parse, so
    # entries such as '½' or '²' are skipped instead of raising ValueError.
    result = sorted(int(name) for name in os.listdir(data_path) if name.isdecimal())
    if shuffle:
        np.random.shuffle(result)
    return result

In [None]:
class AudioKeypointDataset(Dataset):
    """Dataset pairing each clip's keypoints with per-frame audio chunks.

    Sample ``<id>`` is read from ``<path>/<id>/``: ``<id>.npy`` (keypoints),
    ``<id>.wav`` (audio) and, when ``use_video`` is set, ``<id>.mp4``.

    Args:
        samples: sequence of integer sample ids (e.g. from ``get_samples``).
        path: dataset root directory (with or without a trailing separator).
        use_video: also decode and return the clip's video frames.
    """

    def __init__(self, samples, path, use_video=False):
        self.data = samples
        self.path = path
        self.use_video = use_video

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            dict with
              'keypoints': float tensor, the .npy array reshaped to (-1, KP_SIZE);
              'audio': float tensor of shape (frames, SPF * channels), one chunk
                       of SPF audio samples per keypoint frame, where ``frames``
                       is the first dimension of the .npy array;
              'video' (only if use_video): the array returned by ``load_video``.
        """
        sample_id = self.data[index]
        sample_dir = os.path.join(self.path, str(sample_id))

        keypoints = np.load(os.path.join(sample_dir, f'{sample_id}.npy'))
        frames = keypoints.shape[0]
        keypoints = np.reshape(keypoints, (-1, KP_SIZE))

        _sample_rate, audio = read(os.path.join(sample_dir, f'{sample_id}.wav'))
        # NOTE(review): dividing by 32767 assumes 16-bit PCM, and chunking by
        # SPF assumes the file is sampled at SPS Hz; neither is checked here.
        audio = audio / 32767

        # Use exactly frames * SPF samples: trailing audio is dropped and a
        # short file is zero-padded, so every frame gets a full-length chunk
        # (ragged chunks would make the per-frame array inhomogeneous).
        needed = frames * SPF
        audio = audio[:needed]
        missing = needed - audio.shape[0]
        if missing > 0:
            pad_width = [(0, missing)] + [(0, 0)] * (audio.ndim - 1)
            audio = np.pad(audio, pad_width)
        # Row-major reshape: frame i holds samples [i*SPF, (i+1)*SPF), with any
        # channel axis flattened into the chunk.
        chunk_len = SPF * int(np.prod(audio.shape[1:]))
        audio_np = audio.reshape((frames, chunk_len))

        result = {
            'keypoints': torch.tensor(keypoints, dtype=torch.float),
            'audio': torch.tensor(audio_np, dtype=torch.float),
        }
        if self.use_video:
            result['video'] = load_video(os.path.join(sample_dir, f'{sample_id}.mp4'))

        return result

In [None]:
samples = get_samples(DATASET_PATH)

# Hold out the last TEST_SPLIT fraction of the (shuffled) ids for validation.
# NOTE: despite its name, ``test_len`` is the size of the *training* split,
# i.e. the index at which the test split begins.
test_len = round(len(samples) * (1 - TEST_SPLIT))
train_samples, test_samples = samples[:test_len], samples[test_len:]

In [None]:
# Report the split sizes.
print(f'Train len = {len(train_samples)}')
print(f'Test len = {len(test_samples)}')

In [None]:
# Both splits come from DATASET_PATH; only the training loader reshuffles
# each epoch, so the validation order stays fixed.
train_data = AudioKeypointDataset(samples=train_samples, path=DATASET_PATH)
valid_data = AudioKeypointDataset(samples=test_samples, path=DATASET_PATH)

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class Encoder(nn.Module):
    """Single-layer LSTM over a sequence of per-frame audio chunks.

    Args:
        input_dim: size of each input step.
        hidden_dim: LSTM hidden/cell size.
    """

    def __init__(self, input_dim, hidden_dim):
        super(Encoder, self).__init__()
        self.hidden_dim = hidden_dim

        # Default nn.LSTM layout: input is (seq_len, batch, input_dim).
        self.lstm = nn.LSTM(input_dim, hidden_dim)

    def forward(self, input_data, hidden):
        """Run the LSTM over ``input_data`` starting from state ``hidden``.

        Returns:
            (lstm_out, hidden): per-step outputs and the final (h, c) state.
        """
        lstm_out, hidden = self.lstm(input_data, hidden)
        return lstm_out, hidden

    def initHidden(self):
        """Return a zero [h_0, c_0] pair for one layer and a batch of one.

        Sized from ``self.hidden_dim`` and created on the device of the
        module's parameters, so it matches however this encoder was built
        and moved (rather than relying on notebook globals).
        """
        device = next(self.parameters()).device
        return [
            torch.zeros(1, 1, self.hidden_dim, device=device),
            torch.zeros(1, 1, self.hidden_dim, device=device),
        ]

In [None]:
class Decoder(nn.Module):
    """Single-layer LSTM followed by a linear projection to ``output_dim``.

    Args:
        input_dim: size of each input step.
        hidden_dim: LSTM hidden/cell size.
        output_dim: size of each projected output step.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(Decoder, self).__init__()
        self.hidden_dim = hidden_dim

        # Default nn.LSTM layout: input is (seq_len, batch, input_dim).
        self.lstm = nn.LSTM(input_dim, hidden_dim)
        self.output_fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, input_data, hidden):
        """Run one LSTM pass and project every step to ``output_dim``.

        Returns:
            (output, hidden): projected outputs and the final (h, c) state.
        """
        lstm_out, hidden = self.lstm(input_data, hidden)
        output = self.output_fc(lstm_out)
        return output, hidden

    def initHidden(self):
        """Return a zero [h_0, c_0] pair for one layer and a batch of one.

        Sized from ``self.hidden_dim`` and created on the device of the
        module's parameters rather than from notebook globals.
        """
        device = next(self.parameters()).device
        return [
            torch.zeros(1, 1, self.hidden_dim, device=device),
            torch.zeros(1, 1, self.hidden_dim, device=device),
        ]

In [None]:
class Seq2Seq(nn.Module):
    """Audio-to-keypoints sequence-to-sequence model.

    The encoder consumes the clip's audio chunks. Its final (h, c) state
    initialises the decoder, which then emits keypoint frames one at a time,
    seeded with the clip's first ground-truth keypoint frame. Because the
    state is handed over directly, encoder and decoder hidden sizes must match.
    Only batch element 0 is read.
    """

    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, audio, target, teacher_force_ratio=0.5):
        """Predict keypoint frames 1..T-1 of ``target``.

        Args:
            audio: tensor indexed as (batch, frames, chunk); only batch 0 is used.
            target: keypoints indexed as (batch, T, kp_size). Frame 0 seeds the
                decoder; with teacher forcing, ground-truth frame i-1 replaces
                the model's previous prediction as the input for step i.
            teacher_force_ratio: per-step probability (drawn with ``random``)
                of feeding ground truth instead of the previous output.
                NOTE(review): this also applies in eval mode unless the caller
                passes 0.

        Returns:
            Tensor of shape (1, T - 1, kp_size).
        """
        # Use this model's own encoder, not a notebook-level variable.
        hidden = self.encoder.initHidden()
        if audio.shape[1] > 0:
            # Feed all chunks in one call: (frames, chunk) -> (frames, 1, chunk),
            # i.e. seq-first with a batch of one. For an LSTM encoder this equals
            # stepping through the frames one by one.
            _, hidden = self.encoder(audio[0].unsqueeze(1), hidden)

        kp_size = target.shape[-1]
        output = target[0, 0]
        step_outputs = []
        for i in range(1, target.shape[1]):
            kp_input = target[0, i - 1] if random.random() < teacher_force_ratio else output
            output, hidden = self.decoder(kp_input.view(1, 1, kp_size), hidden)
            step_outputs.append(output)

        if not step_outputs:
            # Single-frame target: nothing to predict.
            return target.new_zeros(1, 0, kp_size)
        # Concatenate once instead of growing a tensor every step (O(T^2) copies).
        return torch.cat(step_outputs, dim=1)


In [None]:
# Audio chunks of SPF samples go into the encoder; the decoder maps KP_SIZE
# keypoint vectors to KP_SIZE keypoint vectors. The decoder receives the
# encoder's final state directly, so the two hidden sizes must agree.
encoder = Encoder(input_dim=SPF, hidden_dim=ENCODER_HIDDEN_SIZE).to(DEVICE)
decoder = Decoder(input_dim=KP_SIZE, hidden_dim=DECODER_HIDDEN_SIZE, output_dim=KP_SIZE).to(DEVICE)
model = Seq2Seq(encoder=encoder, decoder=decoder).to(DEVICE)

In [None]:
criterion = nn.MSELoss()  # default reduction: mean over all elements
optimizer = optim.Adam(params=model.parameters(), lr=LR)

In [None]:
def fit(model, dataloader):
    """Train ``model`` for one epoch over ``dataloader``.

    Uses the notebook-level ``optimizer``, ``criterion`` and ``DEVICE``.

    Returns:
        Mean per-batch loss as a Python float (NaN if ``dataloader`` is empty).
    """
    print('Training')
    model.train()
    training_running_loss = 0.0

    for batch in tqdm(dataloader, total=len(dataloader)):
        keypoints = batch['keypoints'].to(DEVICE)
        audio = batch['audio'].to(DEVICE)

        # The model predicts frames 1..N-1 from frame 0, i.e. N-1 targets.
        frames_amount = keypoints.shape[1] - 1

        optimizer.zero_grad()
        outputs = model(audio, keypoints)

        # NOTE(review): MSELoss already averages over all elements; dividing by
        # the frame count again down-weights longer clips. Kept unchanged.
        loss = torch.div(criterion(outputs, keypoints[:, 1:]), frames_amount)
        loss.backward()
        optimizer.step()
        # .item() yields a plain float; summing the tensors themselves would
        # link every batch's autograd graph into one growing expression and
        # return a tensor that requires grad.
        training_running_loss += loss.item()

    if len(dataloader) == 0:
        return float('nan')
    return training_running_loss / len(dataloader)

In [None]:
def validate(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` without updating its weights.

    Uses the notebook-level ``criterion`` and ``DEVICE`` and the same
    frame-count-scaled loss as ``fit``, so the two numbers are comparable.

    Returns:
        Mean per-batch loss as a Python float (NaN if ``dataloader`` is empty).
    """
    print('Validating')
    model.eval()
    valid_running_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader)):
            keypoints = batch['keypoints'].to(DEVICE)
            audio = batch['audio'].to(DEVICE)

            frames_amount = keypoints.shape[1] - 1

            # NOTE(review): the model's default teacher forcing (0.5) is still
            # active here, so this loss is partly conditioned on ground truth.
            outputs = model(audio, keypoints)

            loss = torch.div(criterion(outputs, keypoints[:, 1:]), frames_amount)
            valid_running_loss += loss.item()

    if len(dataloader) == 0:
        return float('nan')
    return valid_running_loss / len(dataloader)

In [None]:
# Per-epoch loss history, plotted in a later cell.
train_loss = []
valid_loss = []

# Baseline validation loss before any training.
val_epoch_loss = validate(model, valid_loader)
print(f'First Val Loss: {val_epoch_loss:.2f}')
print('-------------------------')

checkpoint_dir = f"{ROOT_PATH}/models"
# torch.save does not create missing parent directories.
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(EPOCH):
    print(f"Epoch {epoch + 1} of {EPOCH}")
    train_epoch_loss = fit(model, train_loader)
    val_epoch_loss = validate(model, valid_loader)
    # Record this epoch's values (not the history list itself). float() also
    # unwraps 0-dim tensors so matplotlib can plot the history.
    train_loss.append(float(train_epoch_loss))
    valid_loss.append(float(val_epoch_loss))

    # Overwritten every epoch: the file always holds the latest weights.
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Store loss values rather than the MSELoss module object.
        'loss': float(val_epoch_loss),
        'train_loss': float(train_epoch_loss),
    }, f"{checkpoint_dir}/audio_to_keypoints.pth")

    print(f'Training Loss: {train_epoch_loss:.2f}')
    print(f'Val Loss: {val_epoch_loss:.2f}')

In [None]:
# Plot the per-epoch loss curves collected by the training loop.
plt.figure(figsize=(10, 7))
plt.plot(train_loss, color='orange', label='train loss')
plt.plot(valid_loss, color='red', label='validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()
print('DONE TRAINING')

In [None]:
def load_model(file_name):
    """Rebuild the Seq2Seq model and load weights from ``ROOT_PATH/models/<file_name>``.

    Only ``checkpoint['model_state_dict']`` is read. The model is returned on
    ``DEVICE`` and switched to eval mode.
    """
    full_path = f'{ROOT_PATH}/models/{file_name}'

    encoder = Encoder(SPF, ENCODER_HIDDEN_SIZE).to(DEVICE)
    decoder = Decoder(KP_SIZE, DECODER_HIDDEN_SIZE, KP_SIZE).to(DEVICE)
    model = Seq2Seq(encoder, decoder).to(DEVICE)

    print('Load model ->', full_path)
    # map_location remaps tensors saved on a GPU onto DEVICE, so a checkpoint
    # written during GPU training also loads on a CPU-only runtime.
    # NOTE(review): checkpoints that pickled the MSELoss module under 'loss'
    # may be rejected by the weights_only default of newer PyTorch releases.
    checkpoint = torch.load(full_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model

In [None]:
def save_video(np_video, name='temp'):
    """Write ``np_video`` to ``DOWNLOADS_PATH/Temp/video_samples/<name>.webm``.

    Frames are encoded with the VP90 fourcc at FPS frames per second. The
    writer's frame size is taken from the first frame (IMAGE_SIZE x IMAGE_SIZE
    for an empty video). OpenCV expects every written frame to match that
    size, so deriving it from the data avoids silently producing an empty file.

    Returns:
        The output path.
    """
    path = f'{DOWNLOADS_PATH}/Temp/video_samples/{name}.webm'
    if len(np_video) > 0:
        height, width = np_video[0].shape[:2]
    else:
        height = width = IMAGE_SIZE

    fourcc = cv2.VideoWriter_fourcc(*'VP90')
    # VideoWriter takes frameSize as (width, height).
    out = cv2.VideoWriter(path, fourcc, FPS, (width, height))
    try:
        for frame in np_video:
            out.write(frame)
    finally:
        out.release()
    return path

def add_keypoints_to_video(video, keypoints, color=(0, 0, 255)):
    """Draw each frame's keypoints as small filled circles.

    Args:
        video: sequence of image frames. NOTE: cv2.circle draws in place, so
            the caller's frames are modified as well.
        keypoints: one row per frame of flat coordinates [x0, y0, x1, y1, ...];
            NumPy arrays or torch tensors, integer or float.
        color: colour tuple passed to cv2.circle (frames read by OpenCV are BGR,
            so the default is red).

    Returns:
        uint8 array of the drawn frames, one per (frame, keypoint row) pair;
        zip stops at the shorter input instead of raising IndexError.
    """
    frames = []
    for frame, frame_keypoints in zip(video, keypoints):
        for x, y in frame_keypoints.reshape(-1, 2):
            # cv2.circle requires integer pixel coordinates; keypoints here are
            # floats and may be 0-dim torch tensors, so round via float().
            center = (round(float(x)), round(float(y)))
            frame = cv2.circle(frame, center, radius=1, color=color, thickness=-1)
        frames.append(frame)

    return np.array(frames, dtype='uint8')

In [None]:
# Reload the checkpoint written by the training loop (weights only), in eval mode.
model = load_model(file_name='audio_to_keypoints.pth')

In [None]:
# Embed the WebM written by save_video() in the notebook as a base64 data URL.
path = f'{DOWNLOADS_PATH}/Temp/video_samples/temp.webm'
with open(path, 'rb') as video_file:
    webm = video_file.read()
data_url = "data:video/webm;base64," + b64encode(webm).decode()
HTML("""
<video width=400 controls>
      <source src="%s" type="video/webm">
</video>
""" % data_url)

In [None]:
# Embed an MP4 clip in the notebook as a base64 data URL.
path = f'{DOWNLOADS_PATH}/Temp/video_samples/temp.mp4'
with open(path, 'rb') as video_file:
    mp4 = video_file.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=400 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)

In [None]:
SAMPLE_NUM = 2

# shuffle=False so SAMPLE_NUM selects the same clip on every run; with the
# default shuffle=True it would pick a random clip each time.
samples = get_samples(VAL_DATASET_PATH, shuffle=False)
val_dataset = AudioKeypointDataset(samples=samples[SAMPLE_NUM: SAMPLE_NUM + 1],
                                   path=VAL_DATASET_PATH,
                                   use_video=True)
if len(val_dataset) == 0:
    raise IndexError(f'SAMPLE_NUM={SAMPLE_NUM} is out of range for {len(samples)} samples')
val_loader = DataLoader(val_dataset,
                        batch_size=BATCH_SIZE,
                        shuffle=False)

sample = next(iter(val_loader))
audio = sample['audio'].to(DEVICE)
keypoints = sample['keypoints'].to(DEVICE)

with torch.no_grad():
    # teacher_force_ratio=0.0: after the seed frame the decoder only sees its
    # own predictions, so the drawn keypoints are a genuine prediction rather
    # than partly conditioned on ground truth.
    result = model(audio, keypoints, teacher_force_ratio=0.0)

new_sample = {
    'new_keypoints': result.cpu().numpy()[0],
    'video': sample['video'].detach().numpy()[0, 1:, :, :, :],
    # Ground truth as CPU NumPy so OpenCV can draw it.
    'basic_keypoints': keypoints[0, 1:, :].cpu().numpy(),
}

print(new_sample['basic_keypoints'].shape, new_sample['new_keypoints'].shape)
# Draw ground truth (red), then predictions (green) on top of those frames.
frames = add_keypoints_to_video(new_sample['video'], new_sample['basic_keypoints'])
frames = add_keypoints_to_video(frames, new_sample['new_keypoints'], color=(0, 255, 0))

save_video(frames)