In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from lightning_trainer import UnetDACLighting
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger

from audio_dataset import DictTorchPartedDataset, PinDictTorchPartedDataset

from unet_dac import UnetDAC
import lightning as L

In [2]:
from config import NUM_MICS, ANGLE_RES


# --- Network dimensions, forwarded to UnetDAC(L=..., K=..., M=...) ---
L_v, K = 96, 256

# --- Optimisation settings ---
lr = 1e-3
train_bs = 64
validation_bs = train_bs  # validation uses the same batch size as training

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the network and move its parameters onto the selected device.
model = UnetDAC(L=L_v, K=K, M=NUM_MICS)
model = model.to(device)

# The run name encodes batch size and learning rate; it is reused as the
# TensorBoard experiment name (and, later, as the Trainer's root directory).
model_name = f"unet_doa_batch{train_bs}_lr{lr:.0e}_v4"
logger = TensorBoardLogger("tb_logs", name=model_name)

# Loss handed to the Lightning wrapper.
criterion = nn.CrossEntropyLoss()

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


# Training #

In [3]:
# Wrap the bare network in its Lightning module (arguments: model, loss, learning rate).
model_lighting = UnetDACLighting(model, criterion, lr)

# Stop training once the monitored "val_loss" has gone 5 validation checks
# without improving (mode="min": lower is better).
early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=5)

# NOTE: `auto_lr_find=True` is intentionally no longer passed here.
#   * Lightning >= 2.0 removed that Trainer argument, so passing it raises
#     `TypeError: __init__() got an unexpected keyword argument 'auto_lr_find'`.
#   * On Lightning 1.x it only had an effect inside `trainer.tune(...)`, which
#     this notebook never calls, so dropping it does not change training.
# To run the learning-rate finder on 2.x use
# `lightning.pytorch.tuner.Tuner(trainer).lr_find(model_lighting, ...)`.
trainer = L.Trainer(
    max_epochs=100,
    callbacks=[early_stopping],
    default_root_dir=model_name,
    log_every_n_steps=9,
    logger=logger,
)

In [4]:
# Datasets: both splits are read from 'data_batches' and expose the
# 'samples' and 'target' entries only.
train_dataset = PinDictTorchPartedDataset(
    'data_batches',
    'train06r076v3',
    ['samples', 'target'],
    real_batch_size=64,
    virtual_batch_size=1,
    device=device,
)
validation_dataset = PinDictTorchPartedDataset(
    'data_batches',
    'validation06r076v3',
    ['samples', 'target'],
    real_batch_size=64,
    virtual_batch_size=1,
    device=device,
)

# Loader options common to both splits: keep workers alive between epochs
# and let each worker prefetch 16 batches.
_shared_loader_opts = dict(persistent_workers=True, prefetch_factor=16)

# Training data is shuffled and read by 4 workers.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_bs,
    shuffle=True,
    num_workers=4,
    **_shared_loader_opts,
)

# Validation data keeps its order and uses a single worker.
# (The variable name is misspelled, but the training cell refers to it by this name.)
valiadtion_dataloader = DataLoader(
    validation_dataset,
    batch_size=validation_bs,
    shuffle=False,
    num_workers=1,
    **_shared_loader_opts,
)

In [5]:
# Resume from a previously saved checkpoint instead of starting from scratch.
# Checkpoints used in earlier sessions (same tb_logs/unet_doa_batch64_lr1e-03_v4 run):
#   version_1/checkpoints/epoch=42-step=4042.ckpt
#   version_12/checkpoints/epoch=4-step=470.ckpt
resume_ckpt_path = 'tb_logs/unet_doa_batch64_lr1e-03_v4/version_13/checkpoints/epoch=5-step=564.ckpt'
model_lighting = UnetDACLighting.load_from_checkpoint(
    resume_ckpt_path,
    model=model,
    loss_fn=criterion,
)

# Continue with a smaller learning rate than the initial 1e-3.
# NOTE(review): this only takes effect if the module reads `self.lr` when it
# builds its optimizer - confirm in UnetDACLighting.configure_optimizers.
model_lighting.lr = 1e-4

# Weights & Biases logging is currently disabled; the TensorBoard logger
# attached to `trainer` is used instead.
trainer.fit(
    model_lighting,
    train_dataloaders=train_dataloader,
    val_dataloaders=valiadtion_dataloader,
)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\agadi\miniconda3\envs\audio_env\lib\site-packages\lightning\pytorch\core\optimizer.py:257: Found unsupported keys in the lr scheduler dict: {'verbose', 'patience', 'mode', 'factor'}. HINT: remove them from the output of `configure_optimizers`.

  | Name    | Type             | Params
---------------------------------------------
0 | model   | UnetDAC          | 1.9 M 
1 | loss_fn | CrossEntropyLoss | 0     
---------------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params
7.772     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\agadi\miniconda3\envs\audio_env\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

# Testing #

In [7]:
# Checkpoint evaluated in the testing section (version_12, epoch 4).
test_ckpt_path = 'tb_logs/unet_doa_batch64_lr1e-03_v4/version_12/checkpoints/epoch=4-step=470.ckpt'
model_lighting = UnetDACLighting.load_from_checkpoint(
    test_ckpt_path,
    model=model,
    loss_fn=criterion,
)

In [6]:
# The test split requests more entries than training: besides the network
# input and target it also asks for the reference STFT, the raw signals and
# the per-example radius / reverb values.
test_dataset = PinDictTorchPartedDataset(
    'data_batches',
    'test2r0168v4',
    ['samples', 'ref_stft', 'target', 'mixed_signals', 'perceived_signals', 'radii', 'reverbs'],
    real_batch_size=60,
    virtual_batch_size=1,
    device=device,
)

# One dataset item per step, in the stored order, loaded in the main process.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
)

In [7]:
# Run the Lightning test loop on the loaded checkpoint. The call's return
# value (a list holding one dict of metric name -> value) is the last
# expression of the cell and is therefore shown as the cell output.
# NOTE(review): in the recorded output every `*_0_1_epoch` / `*_1_1_epoch`
# metric is `inf` - confirm whether that is expected or a metric bug.
trainer.test(model_lighting, dataloaders=test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\agadi\miniconda3\envs\audio_env\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

[{'mix_rad1.0_rev0.16_0_0_epoch': 0.119950070977211,
  'mix_rad1.0_rev0.16_0_1_epoch': inf,
  'mix_rad1.0_rev0.16_1_0_epoch': 0.2224656492471695,
  'mix_rad1.0_rev0.16_1_1_epoch': inf,
  'sep_rad1.0_rev0.16_0_0_epoch': -7.002167224884033,
  'sep_rad1.0_rev0.16_0_1_epoch': inf,
  'sep_rad1.0_rev0.16_1_0_epoch': -4.04954719543457,
  'sep_rad1.0_rev0.16_1_1_epoch': inf,
  'mix_rad1.0_rev0.36_0_0_epoch': -0.1807350218296051,
  'mix_rad1.0_rev0.36_0_1_epoch': inf,
  'mix_rad1.0_rev0.36_1_0_epoch': 0.4670673906803131,
  'mix_rad1.0_rev0.36_1_1_epoch': inf,
  'sep_rad1.0_rev0.36_0_0_epoch': -3.1569104194641113,
  'sep_rad1.0_rev0.36_0_1_epoch': inf,
  'sep_rad1.0_rev0.36_1_0_epoch': -0.8480923771858215,
  'sep_rad1.0_rev0.36_1_1_epoch': inf,
  'mix_rad2.0_rev0.16_0_0_epoch': 0.16430231928825378,
  'mix_rad2.0_rev0.16_0_1_epoch': inf,
  'mix_rad2.0_rev0.16_1_0_epoch': 0.1810220181941986,
  'mix_rad2.0_rev0.16_1_1_epoch': inf,
  'sep_rad2.0_rev0.16_0_0_epoch': -4.702014446258545,
  'sep_rad2.0_

In [None]:
# all_t_data = torch.load('data_batches/train06r076_42.pt')
# all_t_data_small = {k: v[:2] for k,v in all_t_data.items()}
# probs = model_lighting.model(all_t_data_small['samples'].cuda())
# print(probs.shape)
# print(probs.device)
# print(all_t_data_small['samples'].device)

# all_t_data_small['probs'] = probs.detach().cpu()
# torch.save(all_t_data_small, 'samples_test1605_v2.pt')
# print('\n'.join([f"{k}: {v.shape}" for k,v in all_t_data_small.items()]))

In [None]:
# from metrics import SeparatedSource

# for i, batch in enumerate(test_dataset):
#     samples, ref_stft, target, mixed_signals, perceived_signals = batch
#     samples = samples.to('cuda', dtype=torch.float)
#     print(samples.shape)
#     probs = model(samples)
#     ref_spec = ref_stft.detach().cpu().numpy()
#     samp_probs = probs.detach().cpu().numpy()
#     sep_src = SeparatedSource(ref_spec[1:], samp_probs)
#     sep_src.save(f"sep_{i}.wav")

In [None]:
from metrics import SeparatedSource
import sounddevice as sd  # used by the playback cell that follows
import torch.nn.functional as F
# Import the submodule explicitly: a bare `import scipy` does not guarantee
# that `scipy.io.wavfile` has been loaded on every SciPy version.
import scipy.io.wavfile

SAMPLE_RATE = 16000     # sample rate (Hz) passed to every WAV written below
NUM_ANGLE_CLASSES = 13  # one-hot width; presumably the number of DOA bins - TODO confirm against config
MIXED_CHANNEL = 3       # index into mixed_signals[idx] that is exported; presumably a microphone - TODO confirm

# Take the second batch of the test loader (the first one is skipped).
test_batches = iter(test_dataloader)
_ = next(test_batches)

# `test_dataset` is built with seven keys (the five used here plus 'radii'
# and 'reverbs'). Assuming a batch carries one field per requested key, in
# order, unpacking exactly five names would raise
# "ValueError: too many values to unpack"; the starred name absorbs any
# trailing fields and still works if only five are returned.
samples, ref_stft, target, mixed_signals, perceived_signals, *_extra_fields = next(test_batches)

# Move everything to the notebook-wide `device` rather than a hard-coded
# 'cuda', so the cell also runs on a CPU-only machine.
samples = samples.to(device, dtype=torch.float)  # (B,S,V)
target = target.to(device, dtype=torch.long)
ref_stft = ref_stft.to(device)
mixed_signals = mixed_signals.to(device)
perceived_signals = perceived_signals.to(device)

# Shape sanity check, same order as the unpacking above.
for _tensor in (samples, ref_stft, target, mixed_signals, perceived_signals):
    print(_tensor.shape)

idx = 0  # which example of the batch to export

# Integer-divide the target by ANGLE_RES to get class indices, one-hot encode
# them, and move the new class axis to the front.
target_onehot = F.one_hot(target[idx] // ANGLE_RES, NUM_ANGLE_CLASSES).permute(2, 0, 1)

# Despite the `_pred` suffix these sources are built from the ground-truth
# target, not from model output - hence the "gt" file names below. The name
# is kept because the playback cell refers to it.
speaker_sources_pred = SeparatedSource.speakers(ref_stft[idx,1:,:], target_onehot)
first_speaker = speaker_sources_pred[0].signal().cpu()
second_speaker = speaker_sources_pred[1].signal().cpu()

scipy.io.wavfile.write('test2_gt_first_speaker.wav', SAMPLE_RATE, first_speaker.numpy())
scipy.io.wavfile.write('test2_gt_second_speaker.wav', SAMPLE_RATE, second_speaker.numpy())
scipy.io.wavfile.write('test2_mixed.wav', SAMPLE_RATE, mixed_signals[idx][MIXED_CHANNEL].detach().cpu().numpy())

# Distinct class indices present in this example's target.
print((target[idx] // ANGLE_RES).unique())

In [None]:
# Listen to the second separated source at 16 kHz through the default audio device.
playback_signal = speaker_sources_pred[1].signal().cpu()
sd.play(playback_signal, 16000)
