In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from lightning_trainer import UnetDACLighting
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger

from audio_dataset import DictTorchPartedDataset, PinDictTorchPartedDataset

from unet_dac import UnetDAC
import lightning as L

In [2]:
from config import NUM_MICS, ANGLE_RES

# ---- Configuration ----------------------------------------------------------
SEED = 42  # seeds python/numpy/torch (and dataloader workers) for repeatable runs

L_v = 96   # forwarded to UnetDAC as `L`
K = 256    # forwarded to UnetDAC as `K`

lr = 1e-3
train_bs = 64
validation_bs = train_bs
# e.g. "unet_doa_batch64_lr1e-03"; names the TensorBoard run and the trainer root dir.
model_name = f"unet_doa_batch{train_bs}_lr{lr:.0e}"

# Checkpoint the Lightning module is restored from further down in this cell.
CKPT_PATH = 'tb_logs/unet_doa_batch64_lr1e-03/version_3/checkpoints/epoch=32-step=3102.ckpt'

# Seed before building the model so weight init and data shuffling are reproducible.
L.seed_everything(SEED, workers=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UnetDAC(L=L_v, K=K, M=NUM_MICS).to(device)

# ---- Trainer ----------------------------------------------------------------
logger = TensorBoardLogger("tb_logs", name=model_name)

trainer = L.Trainer(max_epochs=100,
                    callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=3)],
                    default_root_dir=model_name,
                    log_every_n_steps=9,
                    logger=logger)

criterion = nn.CrossEntropyLoss()

# Restore weights from CKPT_PATH. `lr` is not passed here, so in this path it only
# affects `model_name`; use the commented constructor instead to train from scratch.
model_lighting = UnetDACLighting.load_from_checkpoint(CKPT_PATH, model=model, loss_fn=criterion)
# model_lighting = UnetDACLighting(model, criterion, lr)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [4]:
# Training / validation splits read from the data_batches/ directory.
# NOTE(review): the datasets receive `device=device` while the loaders below use
# 4 worker processes; if PinDictTorchPartedDataset moves tensors to CUDA inside
# workers this can fail or duplicate GPU contexts — confirm what it does with `device`.
train_dataset = PinDictTorchPartedDataset('data_batches', 'train06r076', ['samples', 'target'],
                                          real_batch_size=64, virtual_batch_size=1, device=device)
validation_dataset = PinDictTorchPartedDataset('data_batches', 'validation06r076', ['samples', 'target'],
                                               real_batch_size=64, virtual_batch_size=1, device=device)

# Worker settings shared by both loaders.
loader_kwargs = dict(num_workers=4, persistent_workers=True, prefetch_factor=16)

train_dataloader = DataLoader(train_dataset, batch_size=train_bs, shuffle=True, **loader_kwargs)
# Validation is evaluated in a fixed order: shuffling adds nothing to the metric
# and Lightning warns when a val dataloader's sampler shuffles.
validation_dataloader = DataLoader(validation_dataset, batch_size=validation_bs, shuffle=False, **loader_kwargs)
# Misspelled alias kept because the training cell passes `valiadtion_dataloader` to trainer.fit.
valiadtion_dataloader = validation_dataloader

In [3]:
# Train `model_lighting` (restored from a checkpoint in the config cell) with
# per-epoch validation; the trainer's EarlyStopping on val_loss decides when to stop.
trainer.fit(
    model=model_lighting,
    train_dataloaders=train_dataloader,
    val_dataloaders=valiadtion_dataloader,
)

NameError: name 'train_dataloader' is not defined

In [4]:
# Test split, iterated in a fixed order (no shuffling) and scored with trainer.test;
# the returned metrics are this cell's displayed output.
test_dataset = PinDictTorchPartedDataset(
    'data_batches',
    'test10r0235revrad',
    ['samples', 'ref_stft', 'target', 'mixed_signals', 'perceived_signals'],
    real_batch_size=60,
    virtual_batch_size=1,
    device=device,
)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=4)

trainer.test(model_lighting, dataloaders=test_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\agadi\miniconda3\envs\audio_env\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing: |          | 0/? [00:00<?, ?it/s]

In [4]:
# Run the network on the first two examples of one training file and save the
# inputs together with the network outputs for offline inspection.
all_t_data = torch.load('data_batches/train06r076_42.pt', map_location='cpu')
all_t_data_small = {k: v[:2] for k, v in all_t_data.items()}

net = model_lighting.model
# Send inputs to wherever the weights currently live; after trainer.fit/test the
# module may not be on the GPU anymore, so a hard-coded .cuda() is not safe.
model_device = next(net.parameters()).device
was_training = net.training
net.eval()  # inference behavior for dropout / batch-norm layers, if the model has any
try:
    with torch.no_grad():  # no autograd graph needed for inference
        probs = net(all_t_data_small['samples'].to(model_device))
finally:
    net.train(was_training)  # leave the module in the mode we found it in

print(probs.shape)
print(probs.device)
print(all_t_data_small['samples'].device)

# NOTE(review): the training criterion is nn.CrossEntropyLoss, which expects raw
# logits, so these outputs are presumably logits rather than probabilities unless
# UnetDAC applies a softmax internally — confirm before treating them as probs.
all_t_data_small['probs'] = probs.cpu()
torch.save(all_t_data_small, 'samples_test1605_v2.pt')
print('\n'.join(f"{k}: {v.shape}" for k, v in all_t_data_small.items()))

torch.Size([2, 13, 256, 96])
cuda:0
cpu
samples: torch.Size([2, 14, 256, 96])
ref_stft: torch.Size([2, 257, 96])
target: torch.Size([2, 256, 96])
perceived_signals: torch.Size([2, 2, 12160])
doas: torch.Size([2, 2])
probs: torch.Size([2, 13, 256, 96])
