In [1]:
import librosa
import librosa.display
from numba import core
import matplotlib.pyplot as plt
import numpy as np

In [2]:
def padd_with_zeros(twod_numpy, w, h):
    old_w = twod_numpy.shape[0]
    old_h = twod_numpy.shape[1]
    h = max(old_h, h)
    w = max(old_w, w)
    lx = (w - old_w) // 2
    rx = w - lx - old_w
    ly = (h - old_h) // 2
    ry = h - ly - old_h
    return np.pad(twod_numpy, pad_width=((lx, rx), (ly, ry)), mode='constant')
def process_input_file(filename):

    waveform, sample_rate = librosa.load(filename)

    n_fft = 512  # Number of FFT points (window size)
    hop_length = 1024  # Hop length (frame shift)

    # Compute the STFT
    stft = librosa.stft(waveform, n_fft=n_fft, hop_length=hop_length)

    padded = padd_with_zeros(stft, 260, 90)
    return  padded


# filename='./data/LydoK7hXKbs-audio/audio0-10-23.00.wav'
filename='./data/a_oqcg0hvpo-audio/audio0-55-27.00.wav'
#define the beginning time of the signal

print(librosa.__version__)

data_point = process_input_file(filename)

0.10.0.post2


In [3]:
import os
import random
import torch
video_files=[
    'a_oqcg0hvpo',
    'Gz99TTxmvls',
    'ubz5lz_l7IY',
    'LydoK7hXKbs',
]

tensors = {}

def get_prediction(video_name, time_stamp):
    global tensors
    local_name = f'{video_name}{time_stamp[:3]}'
    if local_name not in tensors:
        t = torch.load(f'./resnet_predictions/{video_name}/{local_name}.pth')
        tensors[local_name] = t
    else:
        t = tensors[local_name]
    return t[f'frame{time_stamp}']


DATA_PER_FILE = 50

X = []
Y = []

for video_name in video_files:
    subdir = os.listdir('./data/' + video_name + '-audio')
    files = random.choices(subdir, k=DATA_PER_FILE)
    for file in files:
        time_stamp = file[5:file.find('.')]
        label = video_name + '-' + time_stamp
        audio_data = process_input_file(f'./data/{video_name}-audio/{file}')
        classification_res = get_prediction(video_name, time_stamp)

        X.append(np.abs(audio_data))
        Y.append(classification_res)
    print(video_name)


a_oqcg0hvpo
Gz99TTxmvls
ubz5lz_l7IY
LydoK7hXKbs


In [4]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
X_train, X_test, y_train, y_test = train_test_split(X, Y, train_size=0.8, shuffle=True)
BATCH_SIZE = 40
train_loader = DataLoader(list(zip(X_train, y_train)), shuffle=True, batch_size=40)

In [5]:
from sklearn.preprocessing import normalize

import torch
import torch.nn as nn
import scipy.io.wavfile as wavfile
from scipy import signal

class AutoEncoder(nn.Module):
    def __init__(self, input_shape):
        super().__init__()

        self.encoder = nn.Sequential( # 260 x 90
                nn.Conv2d(1, 16, stride=(1, 1), kernel_size=(3, 3), padding='same'),
                nn.LeakyReLU(0.01),
                # nn.MaxPool2d(kernel_size=(2, 2), stride=(1, 1), return_indices=True),
                nn.Conv2d(16, 32, stride=(1, 1), kernel_size=(3, 3), padding='same'),
                nn.LeakyReLU(0.01),
                nn.Dropout(0.2),
                # nn.MaxPool2d(kernel_size=(2, 2), stride=(1, 1), return_indices=True),
                # nn.Conv2d(32, 32, stride=(1, 1), kernel_size=(3, 3), padding='same'),
                # nn.LeakyReLU(0.01),
                nn.Dropout(0.2),
                nn.Flatten(),
                nn.Linear(748800, 1000)
        )
        self.classification = nn.Sequential(
                nn.Softmax(1)
        )
        self.decoder = nn.Sequential(
                nn.Linear(1000, 748800),
                nn.Unflatten(1, (64, 260, 90)),
                nn.LeakyReLU(0.01),
                # nn.ConvTranspose2d(64, 64, stride=(1, 1), kernel_size=(3, 3), padding=1),
                # nn.MaxUnpool2d(kernel_size=(3, 3), stride=(1, 1), padding=1),
                nn.ConvTranspose2d(32, 16, stride=(1, 1), kernel_size=(3, 3), padding=1),
                # nn.MaxUnpool2d(kernel_size=(3, 3), stride=(1, 1), padding=1),
                nn.LeakyReLU(0.01),
                nn.ConvTranspose2d(16, 1, stride=(1, 1), kernel_size=(3, 3), padding=1),
                nn.LeakyReLU(0.01),
                nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        classified = self.classification(encoded)
        decoded = self.decoder(encoded)
        return decoded, classified


In [6]:
import torch.optim as optim
from torchvision.transforms import ToTensor
from tqdm import tqdm


# Initialize the Autoencoder
autoencoder = AutoEncoder((260, 90))

# Define the loss function
reconstruction_loss_fn = nn.MSELoss()
classication_loss_fn = nn.CrossEntropyLoss()
# Define the optimizer
optimizer = optim.Adam(autoencoder.parameters(), lr=0.01)


# Training loop
num_epochs = 1000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
autoencoder.to(device)

import time

print("hello")
start = time.time()

for epoch in range(num_epochs):
    total_loss = 0.0

    # Iterate over the training dataset
    for X_batch, y_batch in tqdm(train_loader):
        print(X_batch.shape)
        X_batch = X_batch.view(BATCH_SIZE, -1, 260, 90).to(device)
        print(X_batch.shape)
        y_batch = y_batch.to(device)

        # Forward pass
        X_pred, y_pred = autoencoder(X_batch)

        # Compute the loss
        constr_loss = reconstruction_loss_fn(X_pred, X_batch)
        classi_loss = classication_loss_fn(y_pred, y_batch)
        loss = constr_loss + classi_loss
        total_loss += loss.item()

        # Backward pass and optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Print average loss for the epoch
    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {average_loss}")

    end = time.time()
    dur = end - start
    if dur > 3600:
        print(f"Elapsed {epoch} number of epochs.")
        num_epochs = epoch
        break

weight_file_name = f'model_{num_epochs}_{len(train_loader)}'
torch.save(autoencoder.state_dict(), weight_file_name)

# # Example usage after training
# test_dataset = MNIST(root='data', train=False, transform=ToTensor(), download=True)
# test_loader = DataLoader(test_dataset, batch_size=10, shuffle=True)

# with torch.no_grad():
#     for batch in test_loader:
#         images, _ = batch
#         images = images.view(-1, 784).to(device)
#         reconstructed = autoencoder(images)

#         # Perform any further processing or visualization with the reconstructed images

hello


  0%|          | 0/40 [00:00<?, ?it/s]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


  2%|▎         | 1/40 [09:39<6:16:53, 579.83s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


  5%|▌         | 2/40 [25:57<8:35:18, 813.64s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


  8%|▊         | 3/40 [46:02<10:11:58, 992.40s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 10%|█         | 4/40 [59:27<9:11:08, 918.57s/it] 

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 12%|█▎        | 5/40 [1:09:34<7:50:20, 806.29s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 15%|█▌        | 6/40 [1:19:33<6:56:54, 735.71s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 18%|█▊        | 7/40 [1:29:37<6:20:57, 692.65s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 20%|██        | 8/40 [1:39:09<5:48:55, 654.24s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 22%|██▎       | 9/40 [1:48:32<5:23:19, 625.79s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 25%|██▌       | 10/40 [1:58:43<5:10:37, 621.25s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 28%|██▊       | 11/40 [2:08:42<4:56:57, 614.40s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 30%|███       | 12/40 [2:19:01<4:47:18, 615.67s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 32%|███▎      | 13/40 [2:28:46<4:32:53, 606.43s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 35%|███▌      | 14/40 [2:40:14<4:33:30, 631.17s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 38%|███▊      | 15/40 [2:49:41<4:14:51, 611.64s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 40%|████      | 16/40 [3:00:06<4:06:14, 615.62s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 42%|████▎     | 17/40 [3:10:56<4:00:00, 626.10s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 45%|████▌     | 18/40 [3:21:49<3:52:33, 634.23s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 48%|████▊     | 19/40 [3:32:05<3:40:02, 628.70s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 50%|█████     | 20/40 [3:42:41<3:30:19, 631.00s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 52%|█████▎    | 21/40 [3:52:50<3:17:43, 624.39s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 55%|█████▌    | 22/40 [4:03:05<3:06:25, 621.43s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 57%|█████▊    | 23/40 [4:13:18<2:55:20, 618.88s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 60%|██████    | 24/40 [4:24:16<2:48:09, 630.62s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 62%|██████▎   | 25/40 [4:34:12<2:35:05, 620.36s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 65%|██████▌   | 26/40 [4:43:47<2:21:34, 606.77s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 68%|██████▊   | 27/40 [4:53:58<2:11:43, 607.94s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 70%|███████   | 28/40 [5:04:17<2:02:16, 611.35s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 72%|███████▎  | 29/40 [5:14:12<1:51:11, 606.46s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 75%|███████▌  | 30/40 [5:24:12<1:40:44, 604.47s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 78%|███████▊  | 31/40 [5:34:35<1:31:28, 609.86s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 80%|████████  | 32/40 [5:46:13<1:24:51, 636.38s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 82%|████████▎ | 33/40 [5:58:24<1:17:33, 664.83s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 85%|████████▌ | 34/40 [6:10:18<1:07:56, 679.45s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 88%|████████▊ | 35/40 [6:22:22<57:45, 693.03s/it]  

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 90%|█████████ | 36/40 [6:33:55<46:11, 692.95s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 92%|█████████▎| 37/40 [6:43:57<33:16, 665.48s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 95%|█████████▌| 38/40 [6:55:19<22:21, 670.71s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


 98%|█████████▊| 39/40 [7:06:23<11:08, 668.63s/it]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


100%|██████████| 40/40 [7:16:48<00:00, 655.21s/it]


Epoch 1/1000, Average Loss: 7.882465565204621
Elapsed 0 number of epochs.


  0%|          | 0/40 [00:00<?, ?it/s]

torch.Size([40, 260, 90])
torch.Size([40, 1, 260, 90])


  0%|          | 0/40 [04:35<?, ?it/s]


KeyboardInterrupt: 

In [7]:
weight_file_name = f'model_{num_epochs}_{len(train_loader)}'
torch.save(autoencoder.state_dict(), weight_file_name)

In [None]:
import torchaudio.transforms as transforms
import torchaudio
import scipy
with torch.no_grad():
    # Pass the test data through the model
    test_data = train_dataset[0]
    test_data = test_data[np.newaxis, ...]
    test_data = torch.from_numpy(test_data)
    reconstructed_stft, classified = autoencoder.forward(test_data)
    
    test_output_stft = reconstructed_stft[:, 1: -2, 1: -2]

    n_fft = 512  # Number of FFT points (window size)
    hop_length = 1024  # Hop length (frame shift)

    # Create the STFT and iSTFT transforms
    transform_stft = transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length)
    transform_istft = transforms.GriffinLim(n_fft=n_fft, hop_length=hop_length)

    # Reconstruct audio from the output STFT batch
    # test_reconstructed_waveform= transform_istft(test_output_stft)
    t, x = scipy.signal.istft(test_output_stft)
    a, b = scipy.io.wavfile.read(audio_filenames[0])
    print(b.shape)
    x = x.transpose()
    print(x.shape)
    # Save the test examples
    scipy.io.wavfile.write('tmp.wav', 6000, x)