In [1]:
import pyroomacoustics as pra

import os
from tqdm import tqdm

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

from einops import rearrange

from src.dataset import SignalDataset, TRUNetDataset
from src.loss import loss_tot
from models.model1d_stream_med_1dphm import TRUNet 

import matplotlib.pyplot as plt

In [2]:
# Every corpus used below sits under a common ``../data`` directory.
_DATA_ROOT = os.path.join("..", "data")

# Speech data directory and its train / test sub-directories.
DATA_DIR = os.path.join(_DATA_ROOT, "data_thchs30")
TRAIN_DIR, TEST_DIR = (os.path.join(DATA_DIR, split) for split in ("train", "test"))

# Directory later handed to the datasets as ``noise_dir``.
NOISE_DIR = os.path.join(_DATA_ROOT, "test-noise", "noise", "white")

In [3]:
# STFT settings used by the collate function and by the train / evaluate loops.
N_FFTS = 512
HOP_LENGTH = 256

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"It's {DEVICE} time!!!")

It's cuda time!!!


In [4]:
# The two splits are built with identical options; only the source directory differs.
_dataset_kwargs = dict(
    sr=16_000,
    noise_dir=NOISE_DIR,
    snr=(-5, 25),
    return_noise=False,
    return_rir=False,
    max_seq_len=16_000 * 10,
    partition=100,
)

train_dataset = TRUNetDataset(TRAIN_DIR, **_dataset_kwargs)
test_dataset = TRUNetDataset(TEST_DIR, **_dataset_kwargs)

In [5]:
# sig, gt, noise_, rir_ = train_dataset[0]

In [6]:
# gt.shape, noise_.shape, rir_.shape

In [7]:
# type(rir_)

In [8]:
def vorbis_window(winlen, device=None):
    """Return the Vorbis window of length ``winlen`` as a float32 tensor.

    The window is defined as::

        w[n] = sin(pi / 2 * sin(pi * (n + 0.5) / winlen) ** 2),  n = 0 .. winlen - 1

    It is symmetric and power-complementary at 50 % overlap
    (``w[n] ** 2 + w[n + winlen // 2] ** 2 == 1``).

    Args:
        winlen: Number of samples in the window.
        device: Device on which to create the tensor.  ``None`` (the default)
            uses PyTorch's default device, so callers that omit the argument
            get a CPU tensor in a standard setup and can move it with ``.to(...)``.

    Returns:
        1-D ``torch.float32`` tensor of shape ``(winlen,)``.
    """
    # The half-sample offset must be +0.5: using -0.5 shifts the window by one
    # sample and breaks its symmetry.
    n = torch.arange(winlen, dtype=torch.float32, device=device)
    inner = torch.sin(torch.pi / winlen * (n + 0.5)) ** 2
    return torch.sin(torch.pi / 2 * inner)

In [9]:
def pad_sequence(batch, device=None):
    """Zero-pad a batch of ``(input, target, noise, rir)`` items to common lengths.

    Inputs and targets are padded to the longest input in the batch; noise and
    RIR tensors are each padded to their own longest member.  Every item is
    right-padded with zeros along its last axis.

    Args:
        batch: Sequence of 4-tuples ``(input_signal, target_signal, noise, rir)``.
            Each target is written into a slot as long as the matching input,
            so target and input are expected to have the same length.
        device: Device the padded tensors are moved to.  ``None`` (the default)
            uses the module-level ``DEVICE``.

    Returns:
        Tuple ``(padded_input, padded_target, padded_noise, padded_rir)`` of 2-D
        tensors with one row per batch item.  For an empty batch, four empty
        tensors are returned so callers can always unpack four values.
    """
    target_device = DEVICE if device is None else device

    if not batch:
        # Keep the arity identical to the non-empty path.
        return tuple(torch.zeros(0, device=target_device) for _ in range(4))

    input_signal, target_signal, noise, rir = zip(*batch)

    max_len_s = max(s.shape[-1] for s in input_signal)
    max_len_n = max(n.shape[-1] for n in noise)
    max_len_r = max(r.shape[-1] for r in rir)

    padded_input = torch.zeros(len(input_signal), max_len_s)
    padded_target = torch.zeros(len(target_signal), max_len_s)
    padded_noise = torch.zeros(len(noise), max_len_n)
    padded_rir = torch.zeros(len(rir), max_len_r)

    for i, s in enumerate(input_signal):
        padded_input[i, :s.shape[-1]] = s
        padded_target[i, :s.shape[-1]] = target_signal[i]
    # Iterate over the *source* tensors; looping over the zero-filled outputs
    # would only copy zeros onto themselves and drop the real data.
    for i, n in enumerate(noise):
        padded_noise[i, :n.shape[-1]] = n
    for i, r in enumerate(rir):
        padded_rir[i, :r.shape[-1]] = r

    return (
        padded_input.to(target_device),
        padded_target.to(target_device),
        padded_noise.to(target_device),
        padded_rir.to(target_device),
    )

def collate_fn(batch):
    """Turn a list of dataset items into windowed STFT inputs and waveform targets.

    Each item is unpacked as ``(input, target, noise, rir)``; the noise and RIR
    entries are ignored.  Inputs and targets are stacked along a new leading
    axis, cut into overlapping segments of ``16_000 * 2`` samples with a hop of
    ``16_000`` samples, and flattened so that every segment becomes one row.
    The input rows are then converted to complex spectrograms.

    Args:
        batch: List of 4-tuples produced by the dataset.  The tensors must be
            stackable, i.e. share one shape per position.

    Returns:
        Tuple ``(input_spec, padded_target)``: the complex STFT of every input
        segment and the matching target segments as waveforms.
    """
    mixtures, cleans, _noise, _rir = zip(*batch)

    segment_len = 16_000 * 2
    segment_hop = 16_000

    def _segment(waves):
        # stack -> sliding windows over the last axis -> one row per window
        windows = torch.stack(waves, dim=0).unfold(-1, segment_len, segment_hop)
        return windows.reshape(-1, windows.shape[-1])

    padded_input = _segment(mixtures)
    padded_target = _segment(cleans)

    # center=False: frames start at sample 0 and no reflection padding is added.
    input_spec = torch.stft(
        padded_input,
        n_fft=N_FFTS,
        hop_length=HOP_LENGTH,
        win_length=512,
        window=vorbis_window(512),
        return_complex=True,
        normalized=True,
        center=False,
    )

    return input_spec, padded_target

In [10]:
# Options common to both loaders: one dataset item per batch, custom collation,
# and no worker processes.
_loader_kwargs = dict(batch_size=1, collate_fn=collate_fn, num_workers=0)

train_dataloader = DataLoader(train_dataset, shuffle=True, drop_last=True, **_loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, drop_last=False, **_loader_kwargs)

In [11]:
# Last-axis sizes of the random ``h_f`` / ``h_t`` state tensors passed to the model.
fgru_hidden_size, tgru_hidden_size = 64, 128

# Batch cap; currently referenced only by the commented-out early exit in the loops.
count = 10


def train(model, train_loader, optimizer, device="cuda"):
    """Run one optimisation pass over ``train_loader``.

    For every batch the model is fed the magnitude, real and imaginary parts of
    the input spectrogram together with freshly sampled random state tensors.
    Its first output is inverted to a waveform with ``torch.istft`` and compared
    against the target waveform with ``loss_tot``; the loss is back-propagated
    and the optimiser takes one step.

    Args:
        model: Network called as ``model(abs, real, imag, h_f, h_t)`` and
            returning five values, of which only the first is used here.
        train_loader: Iterable yielding ``(input_spec, gt_signal)`` batches.
        optimizer: Optimiser holding the model parameters.
        device: Device the batch, the state tensors and the window are placed on.

    Returns:
        Tuple ``(model, optimizer, mean_loss)`` where ``mean_loss`` is the summed
        batch loss divided by ``len(train_loader)``.
    """
    total_train_loss = 0

    # The synthesis window is identical for every batch: build and move it once
    # instead of once per iteration.
    window = vorbis_window(512).to(device)

    for input_spec, gt_signal in tqdm(train_loader, desc="Train model "):
        optimizer.zero_grad()

        # Random initial state tensors, re-drawn for each batch.
        # NOTE(review): the hard-coded 16 in h_t presumably matches an internal
        # model dimension -- confirm against the model definition.
        h_f = torch.randn(2, input_spec.shape[0] * input_spec.shape[-1], fgru_hidden_size, device=device)
        h_t = torch.randn(1, 16, tgru_hidden_size, device=device)

        input_spec = input_spec.to(device)
        gt_signal = gt_signal.to(device)

        output_d, _output_n, _output_r, _, _ = model(
            input_spec.abs(), input_spec.real, input_spec.imag, h_f, h_t
        )

        # center=False mirrors the STFT settings used in collate_fn.
        out_wave = torch.istft(
            output_d,
            n_fft=N_FFTS,
            hop_length=HOP_LENGTH,
            win_length=512,
            window=window,
            return_complex=False,
            normalized=True,
            center=False,
        )

        loss = loss_tot(out_wave, gt_signal.squeeze(0))

        loss.backward()
        optimizer.step()
        total_train_loss += loss.detach().item()

    return model, optimizer, total_train_loss / len(train_loader)
            
def evaluate(model, test_loader, device="cuda"):
    """Compute the mean ``loss_tot`` of ``model`` over ``test_loader``.

    The loop runs under ``torch.no_grad()`` and does not touch any optimiser.
    Each batch is processed exactly as in ``train``: the model receives the
    magnitude, real and imaginary parts of the input spectrogram plus random
    state tensors, and its first output is inverted with ``torch.istft``.

    Args:
        model: Network called as ``model(abs, real, imag, h_f, h_t)`` and
            returning five values, of which only the first is used here.
        test_loader: Iterable yielding ``(input_spec, gt_signal)`` batches.
        device: Device the batch, the state tensors and the window are placed on.

    Returns:
        Summed batch loss divided by ``len(test_loader)``.
    """
    total_test_loss = 0

    # Loop-invariant synthesis window.
    window = vorbis_window(512).to(device)

    with torch.no_grad():
        for input_spec, gt_signal in tqdm(test_loader, desc="Test model "):
            # NOTE(review): the state tensors are random, so the validation loss
            # is not deterministic between runs unless a seed is set.
            h_f = torch.randn(2, input_spec.shape[0] * input_spec.shape[-1], fgru_hidden_size, device=device)
            h_t = torch.randn(1, 16, tgru_hidden_size, device=device)

            input_spec = input_spec.to(device)
            gt_signal = gt_signal.to(device)

            output_d, _output_n, _output_r, _, _ = model(
                input_spec.abs(), input_spec.real, input_spec.imag, h_f, h_t
            )

            # The spectrogram is built with center=False in collate_fn and
            # train() inverts with center=False; inverting with center=True
            # here would trim n_fft // 2 leading samples and misalign the
            # output with the target.
            out_wave = torch.istft(
                output_d,
                n_fft=N_FFTS,
                hop_length=HOP_LENGTH,
                win_length=512,
                window=window,
                return_complex=False,
                normalized=True,
                center=False,
            )

            loss = loss_tot(out_wave, gt_signal.squeeze(0))

            total_test_loss += loss.detach().item()

    return total_test_loss / len(test_loader)
    

def learning_loop(model, train_loader, val_loader, optimizer, epoch=10, device='cuda'):
    """Alternate training and validation passes for ``epoch`` epochs.

    Args:
        model: Network to optimise; it is moved to ``device`` first.
        train_loader: Loader passed to ``train``.
        val_loader: Loader passed to ``evaluate``.
        optimizer: Optimiser passed to ``train``.
        epoch: Number of epochs to run.
        device: Device used for the model and forwarded to ``train`` and
            ``evaluate``.

    Returns:
        Tuple ``(model, train_losses, val_losses)`` where the two lists hold one
        mean loss per epoch.
    """
    model = model.to(device)

    train_losses = []
    val_losses = []

    # A separate loop variable keeps the ``epoch`` argument intact.
    for epoch_idx in range(epoch):
        model.train()
        # Forward the device explicitly; otherwise train/evaluate fall back to
        # their own "cuda" default regardless of what was requested here.
        model, optimizer, total_train_loss = train(model, train_loader, optimizer, device=device)
        train_losses.append(total_train_loss)

        model.eval()
        val_loss = evaluate(model, val_loader, device=device)
        val_losses.append(val_loss)

        print(f'Epoch {epoch_idx}, Training Loss: {total_train_loss:.4f}, Validation Loss: {val_loss:.4f}')

    return model, train_losses, val_losses


In [12]:
from torch.optim import Adam

# Model built with its constructor defaults, optimised by Adam at lr = 1e-4.
trunet = TRUNet()
optimizer = Adam(params=trunet.parameters(), lr=1e-4)

In [None]:
# Train with the default number of epochs; keep the per-epoch loss histories.
trunet, train_l, val_l = learning_loop(
    model=trunet,
    train_loader=train_dataloader,
    val_loader=test_dataloader,
    optimizer=optimizer,
    device=DEVICE,
)

Train model :  36%|███▌      | 36/100 [02:32<08:32,  8.00s/it]

In [None]:
# Training loss, one point per epoch.  The legend call is what makes the
# ``label`` visible; the axis labels let the figure stand on its own.
plt.plot(train_l, label='Train')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend();

In [None]:
# Validation loss, one point per epoch.  The legend call is what makes the
# ``label`` visible; the axis labels let the figure stand on its own.
plt.plot(val_l, label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend();