In [1]:
import os

In [2]:
import argparse
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from datasets import load_dataset, load_from_disk
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.distributed import init_process_group, destroy_process_group

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from model.waveform_model import WaveMAE
from model.spectrogram_model import SpectrogramMAE

In [3]:
def str_to_bool(value):
    """Convert a command-line token into a real ``bool``.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so the flag could never be
    switched off from the command line.

    Args:
        value: The raw token (or an already-parsed bool).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string is kept as False because that is what bool("") gave.
    if token in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
# A process count is an integer; with type=str the default (int 4) and a
# value passed on the command line (str "4") had different types.
parser.add_argument("--num_processes", default=4, type=int)

# dataset configuration
parser.add_argument("--num_workers", default=4, type=int)
parser.add_argument("--batch_size", default=512, type=int)
parser.add_argument("--pin_memory", default=True, type=str_to_bool)

# model configuration
parser.add_argument("--model_type", default="waveform", choices=["waveform", "spectrogram"], type=str)
parser.add_argument("--embed_dim", default=768, type=int)
parser.add_argument("--num_heads", default=16, type=int)
parser.add_argument("--middle_channel", default=512, type=int)
parser.add_argument("--depth", default=12, type=int)
parser.add_argument("--masking_mode", default="random", choices=["random", "uniform"], type=str)
parser.add_argument("--masked_ratio", default=0.8, type=float)

# training configuration
parser.add_argument("--epochs", default=100, type=int)
parser.add_argument("--lr", default=1e-5, type=float)

# Inside a Jupyter kernel sys.argv carries the kernel's own arguments
# (e.g. "-f <connection file>"), which makes parse_args() abort with
# "unrecognized arguments". parse_known_args() sets those aside instead,
# so the cell works both in a notebook and when run as a script.
# Trade-off: a mistyped flag is no longer rejected; inspect
# `unrecognized_args` if strict checking is needed.
args, unrecognized_args = parser.parse_known_args()

_StoreAction(option_strings=['--num_processes'], dest='num_processes', nargs=None, const=None, default=4, type=<class 'str'>, choices=None, required=False, help=None, metavar=None)

In [None]:
def ddp_setup(rank: int, world_size: int, master_addr: str = "localhost", master_port: str = "12355"):
    """Prepare the current process for distributed (NCCL) training.

    Exports the rendezvous address/port through the ``MASTER_ADDR`` and
    ``MASTER_PORT`` environment variables, selects CUDA device ``rank`` for
    this process and then initialises the default process group.

    Args:
        rank: Unique identifier of each process; also used as the CUDA
            device index for this process.
        world_size: Total number of processes.
        master_addr: Value written to ``MASTER_ADDR``. Defaults to
            ``"localhost"`` (single-machine training).
        master_port: Value written to ``MASTER_PORT``. Defaults to
            ``"12355"``; pass a different port to run several jobs on the
            same machine. Integers are accepted and converted to ``str``.
    """
    # Environment values must be strings, hence the explicit str().
    os.environ["MASTER_ADDR"] = str(master_addr)
    os.environ["MASTER_PORT"] = str(master_port)
    # Select the device before creating the process group so that this
    # process is tied to GPU `rank` from the start.
    torch.cuda.set_device(rank)
    init_process_group(backend="nccl", rank=rank, world_size=world_size)

In [None]:
# dataset
# Without `split`, load_dataset returns a DatasetDict keyed by split name,
# which cannot be indexed by the sampler/DataLoader; request the split directly.
# NOTE(review): assumes the training split of this dataset is named "train" - confirm.
training_set = load_dataset("agkphysics/AudioSet", "unbalanced", split="train")

# The sampler gives each process its own shard and does the shuffling.
# NOTE(review): with no explicit num_replicas/rank the sampler queries the
# default process group, so ddp_setup(...) has to run before this cell.
# Remember to call training_set_sampler.set_epoch(epoch) at the start of every
# epoch, otherwise each epoch reuses the same ordering.
training_set_sampler = DistributedSampler(training_set, shuffle=True)

# `sampler` and `shuffle=True` are mutually exclusive in DataLoader (it raises
# ValueError), so shuffling is left to the sampler. `pin_memory` now follows
# the --pin_memory option instead of being hard-coded.
train_loader = DataLoader(
    training_set,
    sampler=training_set_sampler,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    pin_memory=args.pin_memory,
)  # add collate function if needed

In [None]:
# model
# Build the masked auto-encoder selected by --model_type (argparse `choices`
# restricts it to the two values handled below).
if args.model_type == "waveform":
    model = WaveMAE(
        middle_channel=args.middle_channel,
        embed_dim=args.embed_dim,
        num_heads=args.num_heads,
        depth=args.depth,
        masking_mode=args.masking_mode,
        masked_ratio=args.masked_ratio,
    )
elif args.model_type == "spectrogram":
    # The option is registered as --masked_ratio, so the namespace attribute
    # is `args.masked_ratio`; reading `args.mask_ratio` raised AttributeError.
    # NOTE(review): the keyword is kept as `mask_ratio=` exactly as written
    # before - confirm it matches SpectrogramMAE's signature.
    model = SpectrogramMAE(
        embed_dim=args.embed_dim,
        num_heads=args.num_heads,
        depth=args.depth,
        masking_mode=args.masking_mode,
        mask_ratio=args.masked_ratio,
    )

# Replace BatchNorm layers with SyncBatchNorm before wrapping in DDP.
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
# NOTE(review): the model is never moved to a GPU here and `device_ids` is
# empty; with the NCCL backend this probably needs `model.to(rank)` and
# `device_ids=[rank]`, but no rank is available in this cell - confirm.
model = DistributedDataParallel(model, device_ids=[])