In [1]:
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import Audio
import lightning as L
import sys
from lightning.pytorch.callbacks import LearningRateMonitor

# sys.path.append('../lightning_scripts/')
import importlib
import yaml
import torch

In [2]:
## Init config: load the YAML model config, then apply notebook-local overrides.
config_path = "model_configs/pilot_ssl_word_resnet50.yaml"
# Open the file in a context manager so the handle is closed as soon as parsing
# finishes (the previous bare `open(...)` was never closed explicitly).
# NOTE(review): FullLoader is kept to preserve loading behavior; if the config
# only holds plain YAML types, `yaml.safe_load` would be the safer choice — confirm.
with open(config_path, 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)

# Notebook-local overrides of the values read from disk.
config['num_workers'] = 4
config['hparas']['batch_size'] = 32

In [3]:
# Display the loaded config dict, including the overrides applied right after loading.
config

{'data': {'root': '/mnt/ceph/users/jfeather/data/training_datasets_audio/JSIN_all_v3/subsets/'},
 'audio_rep': {'name': 'cochleagram_1', 'on_gpu': True},
 'audio_transforms': {'low_snr': -10, 'high_snr': 10, 'dbspl': 60},
 'val_metric': {'word_task': 'val_signal/word_int_acc'},
 'hparas': {'epochs': 10,
  'batch_size': 32,
  'optimizer': 'LARS',
  'lr': 0.2,
  'num_warmup_steps_or_ratio': 0.1,
  'lambda_ssl': 1,
  'valid_step': 5000,
  'ssl_task': 'word',
  'ssl_loss_str': 'mmcr',
  'ssl_loss': 'MMCR',
  'ssl_loss_kwargs': {'lmbda': 0.0}},
 'model': {'arch_name': 'SSLAudioModel',
  'arch_kwargs': {'projector_dims': [512, 512],
   'proj_out_dim': 2048,
   'n_classes': 794,
   'supervised': False}},
 'num_workers': 4}

In [4]:
### Make sure the batched dataloader works
import lightning_scripts.jsinV3DataLoader_precombined_batched as jsin_batched
import robustness.audio_functions.audio_transforms as at

# Reload first so edits to the dataset module are picked up without a kernel
# restart, then take the dataset class from the freshly reloaded module.
importlib.reload(jsin_batched)
jsinV3_precombined_paired = jsin_batched.jsinV3_precombined_paired

# Audio preprocessing pipeline, parameterised by the `audio_transforms` config section.
audio_cfg = config['audio_transforms']
transform_steps = [
    at.AudioToTensor(),
    at.CombineWithRandomDBSNR(low_snr=audio_cfg['low_snr'], high_snr=audio_cfg['high_snr']),
    at.DBSPLNormalizeForegroundAndBackground(dbspl=audio_cfg['dbspl']),
    # dim=0 here so batches of audio from dataloader will be (Batch, 1, Time)
    at.UnsqueezeAudio(dim=0),
]
transforms = at.AudioCompose(transform_steps)

# The dataset itself is given batch_size=64 while the DataLoader uses batch_size=1.
# NOTE(review): presumably each dataset item is already a full batch of 64; with
# batch_size=1 the default collation adds a leading singleton dimension
# (batch_size=None would disable automatic batching) — confirm which is intended.
train_dset = jsinV3_precombined_paired(
    root=config['data']['root'],
    train=True,
    transform=transforms,
    batch_size=64,
)

# `persistent_workers` stays disabled; it only applies when num_workers > 0.
loader_kwargs = dict(
    batch_size=1,
    num_workers=0,
    pin_memory=True,
    shuffle=False,
)
loader = torch.utils.data.DataLoader(train_dset, **loader_kwargs)


In [13]:
# Fetch the first batch from whichever `loader` is currently bound.
# NOTE(review): `loader` is rebound further down for the unbatched dataset and the
# saved execution counts are out of order, so confirm this example really came
# from the batched loader.
batched_eg = next(iter(loader))

In [18]:
import IPython.display as ipd

In [22]:
# Listen to one element of the example batch.
# NOTE(review): assumes index 1 holds an audio signal sampled at 20 kHz — confirm
# against the dataset's return order and sampling rate.
ipd.Audio(batched_eg[1].squeeze(), rate=20_000)

In [6]:
# Smoke-test iteration over `loader`: stop once 100 batches have been counted.
# Note the break only triggers after the next batch has been fetched, so 101
# batches are actually pulled from the loader (unless it is exhausted earlier).
for n_seen, batch in enumerate(loader):
    if n_seen == 100:
        break

### Compare to unbatched version

In [7]:
from robustness.audio_functions.jsinV3DataLoader_precombined import jsinV3_precombined_paired as jsinV3_precombined_paired_og

# Reference (unbatched) dataset built from the same root and transforms; here the
# DataLoader does the batching (64 items per batch).
# Note: this rebinds `train_dset` and `loader`, replacing the batched versions.
train_dset = jsinV3_precombined_paired_og(
    root=config['data']['root'],
    train=True,
    transform=transforms,
)

# `persistent_workers` stays disabled; it only applies when num_workers > 0.
loader_kwargs = dict(
    batch_size=64,
    num_workers=0,
    pin_memory=True,
    shuffle=False,
)
loader = torch.utils.data.DataLoader(train_dset, **loader_kwargs)


In [8]:
# %%timeit

# Same smoke-test iteration as for the batched loader: stop once 100 batches have
# been counted. The break fires only after the next batch has been fetched, so
# `batch` is left holding the 101st batch (or the last one, if the loader is
# shorter); a later cell unpacks that leftover `batch`.
for n_seen, batch in enumerate(loader):
    if n_seen == 100:
        break

In [12]:
# Unpack the batch left over from the iteration loop into four signals and two label tensors.
# NOTE(review): the sig11/sig12/sig21/sig22 names suggest two signals for each of
# two paired items — confirm against the dataset's return order.
sig11, sig12, sig21, sig22, label1, label2 = batch

In [14]:
# Show both label tensors from the unpacked batch.
label1, label2

(tensor([781, 704, 145, 589, 604, 720, 586, 165, 493, 264, 497, 487, 188, 361,
         518,  22,  87, 196, 710, 572, 716, 643, 749, 433, 782, 787, 456, 479,
         387, 415, 727,  13, 471, 687,  20, 379, 200, 601,   2, 271, 730, 232,
         504, 694, 458, 649,  95, 513, 709, 269, 729, 769, 261, 435, 558, 652,
         202, 439, 548, 171,   3, 477,  89, 571]),
 tensor([287, 559, 755, 167, 732, 356, 325, 655, 474, 589, 397,  51, 264, 115,
         481, 389, 217,  57, 419, 727, 769, 300, 475, 194, 502, 535, 240, 736,
         323, 194, 316, 694, 729,  87, 121, 448, 533, 689, 107, 634, 691, 543,
          95, 302, 771, 770, 483, 429, 721,  38,   8, 717, 324,  56, 181, 658,
         363, 232, 712, 556, 494, 264, 232, 415]))

In [10]:
# Fetch the first batch from the currently bound `loader`.
# NOTE(review): expected to be the unbatched-dataset loader defined above; the
# saved execution counts are out of order, so confirm on a fresh Run All.
unbatched_eg = next(iter(loader))

In [4]:
# Import the project's SSL training module and reload it so that local edits are
# picked up without restarting the kernel.
# Note: the alias `lightning` refers to this project module, not the `lightning`
# package (which this notebook imports as `L`).
import lightning_scripts.lightning_ssl as lightning
lightning = importlib.reload(lightning)

# Convenience alias for the class that is instantiated as the model below.
LitAudioSSL = lightning.LitAudioSSL

In [5]:
# Add a `num_gpus` entry to the config before the model is built.
# NOTE(review): presumably read by LitAudioSSL — confirm.
config['num_gpus'] = 1 

In [6]:
# Build the model from the full config dict.
model = LitAudioSSL(config)

In [7]:
# Callback that logs the learning rate, with the logging interval set to 'step'.
lr_monitor = LearningRateMonitor(logging_interval='step')

In [8]:
# Trainer for a short pilot run: LR monitoring, validation capped at 2 batches,
# at most 5 epochs, on a single device.
# NOTE(review): max_epochs=5 differs from config['hparas']['epochs'] (10 in the
# displayed config) and devices=1 duplicates config['num_gpus'] — presumably
# deliberate for debugging; confirm before a real run.
# Disabled toggles kept for reference: limit_train_batches=5,
# strategy='ddp_notebook', reload_dataloaders_every_n_epochs=-1.
trainer_kwargs = dict(
    callbacks=[lr_monitor],
    limit_val_batches=2,
    max_epochs=5,
    devices=1,
)
trainer = L.Trainer(**trainer_kwargs)


/mnt/home/igriffith/ceph/conda_envs/cochdnn_ssl_pl/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /mnt/home/igriffith/ceph/conda_envs/cochdnn_ssl_pl/l ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [9]:
# Start training. No dataloaders are passed here, so the model is expected to
# provide them itself — TODO confirm in LitAudioSSL.
# (The saved output below shows this run was stopped with a KeyboardInterrupt.)
trainer.fit(model)

/mnt/home/igriffith/ceph/conda_envs/cochdnn_ssl_pl/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /mnt/ceph/users/igriffith/projects/cochdnn/lightning_logs/version_4053982/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                       | Params | Mode 
------------------------------------------------------------------
0 | transforms | AudioCompose               | 0      | train
1 | audio_rep  | AudioToAudioRepresentation | 0      | train
2 | model      | ModelWithFrontEnd          | 26.4 M | train
3 | ssl_loss   | MMCR                       | 0      | train
------------------------------------------------------------------
26.4 M    Trainable params
0         Non-trainable params
26.4 M    Total params
105.762   Total estimated model params size (MB)
169       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...



KeyboardInterrupt

