In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from lightning_trainer import UnetDACLighting
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import WandbLogger, TensorBoardLogger

from audio_dataset import DictTorchPartedDataset, PinDictTorchPartedDataset

from unet_dac import UnetDAC
import lightning as L

In [2]:
from config import NUM_MICS, ANGLE_RES


L_v = 96
K = 256
# INPUT_LEN = 64
# VIRTUAL_BATCH_SIZE = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UnetDAC(L=L_v, K=K, M=NUM_MICS).to(device)

lr = 1e-3
train_bs = 64
validation_bs = train_bs
model_name = f"unet_doa_batch{train_bs}_lr{lr:.0e}"

logger = TensorBoardLogger("tb_logs", name=model_name)

trainer = L.Trainer(max_epochs=100,
                    callbacks=[EarlyStopping(monitor="val_loss", mode="min", patience=3)],
                    default_root_dir=model_name,
                    log_every_n_steps=9,
                    logger=logger)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [4]:
train_dataset = PinDictTorchPartedDataset('data_batches', 'train06r076' , ['samples', 'target'], real_batch_size=64, virtual_batch_size=1, device=device)
validation_dataset = PinDictTorchPartedDataset('data_batches', 'validation06r076' , ['samples', 'target'], real_batch_size=64, virtual_batch_size=1, device=device)

train_dataloader = DataLoader(train_dataset, batch_size=train_bs, shuffle=True, num_workers=4, persistent_workers=True, prefetch_factor=16)
valiadtion_dataloader = DataLoader(validation_dataset, batch_size=validation_bs, shuffle=True, num_workers=4, persistent_workers=True, prefetch_factor=16)

model_name = f"unet_doa_batch{train_bs}_lr{lr:.0e}"

criterion = nn.CrossEntropyLoss()
model_lighting = UnetDACLighting(model, criterion, lr)
# wandb_logger = WandbLogger(log_model="all", project='AudioDOA', name='bs=64,sig0.6 clean. 0.76 with reverb')

trainer.fit(model_lighting, train_dataloaders=train_dataloader, val_dataloaders=valiadtion_dataloader)
# trainer.test(model_lighting, dataloaders=test_dataloader)

KeyboardInterrupt: 

In [3]:
test_dataset = PinDictTorchPartedDataset('data_batches', 'test10r0235revrad' , ['samples', 'ref_stft', 'target', 'mixed_signals', 'perceived_signals'], real_batch_size=30, virtual_batch_size=1, device=device)
test_dataloader = DataLoader(test_dataset, batch_size=30, shuffle=False)

trainer.test(ckpt_path=model_name, dataloaders=test_dataloader)

TypeError: `Trainer.test()` requires a `LightningModule` when it hasn't been passed in a previous run