In [1]:
import json
import pytorch_lightning as pl
from einops import rearrange
import torch
import torchaudio
from torch.utils.data import DataLoader
from safetensors.torch import load_file

from stable_audio_tools import create_model_from_config
from stable_audio_tools.data.dataset import VideoFeatDataset, collation_fn
from stable_audio_tools.training.training_wrapper import DiffusionCondTrainingWrapper
from stable_audio_tools.inference.generation import generate_diffusion_cond



model_config_file = '/home/chengxin/chengxin/stable-v2a/stable_audio_tools/configs/model_config.json'
with open(model_config_file) as f:
    model_config = json.load(f)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model = create_model_from_config(model_config)
# model.load_state_dict(load_file('../stable-v2a/dataset/StableAudio/model.safetensors'), strict=False)



In [3]:
info_dirs = ['/home/chengxin/chengxin/stable-v2a/dataset/video_feature/test/AudioSet']
audio_dirs = ['/home/chengxin/chengxin/AudioSet/generated_audios/test/10']
ds_config = {
    'info_dirs' : info_dirs,
    'audio_dirs' : audio_dirs,
    'exts':'wav',
    'sample_rate':44100, 
    'force_channels':"stereo"
}
dl_config = {
    'batch_size':10, 
    'shuffle':True,
    'num_workers':4, 
    'persistent_workers':True, 
    'pin_memory':True, 
    'drop_last':False, 
}


dataset = VideoFeatDataset(**ds_config)
dataloader = DataLoader(dataset=dataset,  collate_fn=collation_fn, **dl_config)

Found 10 files


In [4]:
device = "cuda:0"
model.to(device)

conditioning = [dataset[i][1] for i in range(5)]


for _,y in dataloader:
    conditioning = y
    break

conditioning

{'fps': tensor([[22.],
         [22.],
         [22.],
         [22.],
         [22.],
         [22.],
         [22.],
         [22.],
         [22.],
         [22.]]),
 'duration': tensor([[10.0180],
         [10.0180],
         [10.0180],
         [10.0180],
         [10.0180],
         [10.0180],
         [10.0180],
         [10.0180],
         [10.0180],
         [10.0180]]),
 'frame_num': tensor([[221.],
         [221.],
         [221.],
         [221.],
         [221.],
         [221.],
         [221.],
         [221.],
         [221.],
         [221.]]),
 'feature': tensor([[[ 0.0864, -0.4263, -0.1164,  ..., -0.0546, -0.1753, -0.4968],
          [ 0.1497, -0.2717, -0.0367,  ..., -0.0872, -0.0952, -0.4287],
          [ 0.1720, -0.3508, -0.0410,  ..., -0.1179, -0.0631, -0.5029],
          ...,
          [ 0.4270,  0.1989,  0.0712,  ..., -0.0251, -0.0352, -0.3728],
          [ 0.4373,  0.1680,  0.0981,  ..., -0.0112, -0.0586, -0.3411],
          [ 0.4382,  0.1683,  0.0976,  ..., -0

In [5]:

sample_size = model_config["sample_size"]
sample_rate = model_config["sample_rate"]
audio, cond = dataset[0]
# conditioning = [cond]

output = generate_diffusion_cond(
    model,
    steps=100,
    cfg_scale=7,
    conditioning=conditioning,
    sample_size=sample_size,
    sigma_min=0.3,
    sigma_max=500,
    sampler_type="dpmpp-3m-sde",
    device=device
)


# Rearrange audio batch to a single sequence
output = rearrange(output, "b d n -> d (b n)")

# Peak normalize, clip, convert to int16, and save to file
output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
torchaudio.save("output.wav", output, sample_rate)

2682432704


  0%|          | 0/100 [00:00<?, ?it/s]

torch.Size([2, 1025, 1536]) torch.Size([20, 221, 768]) None





RuntimeError: k must have shape (batch_size, seqlen_k, num_heads_k, head_size_og)

# 2. Split Steps

In [11]:
trainer = pl.Trainer(
    devices=1,
    accelerator="gpu",
    num_nodes = 1,
    max_epochs=2,
)


trainer.fit(training_wrapper, dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /data-81-01/chengxin/stable-v2a/lightning_logs


/home/chengxin/chengxin/anaconda3/envs/stableaudio/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [2]:
from stable_audio_tools.models.diffusion import DiTWrapper, ConditionedDiffusionModelWrapper
from stable_audio_tools.models.conditioners import create_multi_conditioner_from_conditioning_config
from stable_audio_tools.models.autoencoders import create_autoencoder_from_config
from stable_audio_tools.models.pretransforms import AutoencoderPretransform

io_channels = model_config["model"].get('io_channels', None)
sample_rate = model_config.get('sample_rate', None)
assert io_channels is not None, "Must specify io_channels in model config"
assert sample_rate is not None, "Must specify sample_rate in config"



diffusion_config = {
    'cross_attention_cond_ids': ['feature'],
    'type': 'dit',
    'config': {
        'io_channels': 64,
        'embed_dim': 1536,
        'depth': 24,
        'num_heads': 24,
        'cond_token_dim': 768,
        'global_cond_dim': 1536,
        'project_cond_tokens': False,
        'transformer_type': 'continuous_transformer'
        }
}
diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
global_cond_ids = diffusion_config.get('global_cond_ids', [])
input_concat_ids = diffusion_config.get('input_concat_ids', [])
prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])

diffusion_model_config = diffusion_config['config']
diffusion_model = DiTWrapper(**diffusion_model_config)


In [4]:
conditioning_config = {
    'configs': [
        {
            'id': 'duration',
            'type': 'number',
            'config': {'min_val': 0, 'max_val': 512}
         },
         {
             'id': 'feature', 
             'type': 'video_feature', 
             'config': {}
        }
    ],
    'cond_dim': 768,
    'default_keys': {'feature': 'video_path'}
}
conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)

In [5]:

pretransform_config = {
    'type': 'autoencoder',
    'iterate_batch': True,
    'config': {
        'encoder': {
            'type': 'oobleck',
            'requires_grad': False,
            'config': {
                'in_channels': 2,
                'channels': 128,
                'c_mults': [1, 2, 4, 8, 16],
                'strides': [2, 4, 4, 8, 8],
                'latent_dim': 128,
                'use_snake': True
            }
        },
        'decoder': {
            'type': 'oobleck',
            'config': {
                'out_channels': 2,
                'channels': 128,
                'c_mults': [1, 2, 4, 8, 16],
                'strides': [2, 4, 4, 8, 8],
                'latent_dim': 64,
                'use_snake': True,
                'final_tanh': False
            }
        },
        'bottleneck': {'type': 'vae'},
        'latent_dim': 64,
        'downsampling_ratio': 2048,
        'io_channels': 2
        }
    }

autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
autoencoder = create_autoencoder_from_config(autoencoder_config)


pretransform = AutoencoderPretransform(
    autoencoder, 
    scale=pretransform_config.get("scale", 1.0),
    model_half=pretransform_config.get("model_half", False), 
    iterate_batch=pretransform_config.get("iterate_batch", False), 
    chunked=pretransform_config.get("chunked", False)
)
pretransform.enable_grad = pretransform_config.get('enable_grad', False)
pretransform.eval().requires_grad_(pretransform.enable_grad)


min_input_length = pretransform.downsampling_ratio
min_input_length *= diffusion_model.model.patch_size



In [None]:
extra_kwargs = {"diffusion_objective":diffusion_objective}
model = ConditionedDiffusionModelWrapper(
        diffusion_model,
        conditioner,
        min_input_length=min_input_length,
        sample_rate=sample_rate,
        cross_attn_cond_ids=cross_attention_ids,
        global_cond_ids=global_cond_ids,
        input_concat_ids=input_concat_ids,
        prepend_cond_ids=prepend_cond_ids,
        pretransform=pretransform,
        io_channels=io_channels,
        **extra_kwargs
    )

# 3. Train

In [6]:
training_config = model_config.get('training', None)
training_wrapper = DiffusionCondTrainingWrapper(
            model, 
            lr=training_config.get("learning_rate", None),
            mask_padding=training_config.get("mask_padding", False),
            mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0),
            use_ema = training_config.get("use_ema", True),
            log_loss_info=training_config.get("log_loss_info", False),
            optimizer_configs=training_config.get("optimizer_configs", None),
            pre_encoded=training_config.get("pre_encoded", False),
            cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1),
            timestep_sampler = training_config.get("timestep_sampler", "uniform")
        )

In [7]:
trainer = pl.Trainer(
    devices=1,
    accelerator="gpu",
    num_nodes = 1,
    max_epochs=2,
)


trainer.fit(training_wrapper, dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


You are using a CUDA device ('NVIDIA H100 PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name          | Type                             | Params
-------------------------------------------------------------------
0 | diffusion     | ConditionedDiffusionModelWrapper | 1.2 B 
1 | diffusion_ema | EMA                              | 1.1 B 
2 | losses        | MultiLoss                        | 0     
-------------------------------------------------------------------
1.1 B     Trainable params
1.2 B     Non-trainable params
2.3 B     Total params
9,079.871 Total estimated model params size (MB)
/home/chengxin/chengxin/anaconda3/envs/stableaudio/lib/python3.8/site-pa

Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s] torch.Size([10, 216, 1536]) torch.Size([10, 221, 768]) None
torch.Size([10, 216, 1536]) torch.Size([10, 221, 768]) None
torch.Size([10, 216, 1536]) torch.Size([10, 221, 768]) None
torch.Size([10, 216, 1536]) torch.Size([10, 221, 768]) None
torch.Size([10, 216, 1536]) torch.Size([10, 221, 768]) None
torch.Size([10, 216, 1536]) torch.Size([10, 221, 768]) None
torch.Size([10, 216, 1536]) torch.Size([10, 221, 768]) None
torch.Size([10, 216, 1536]) torch.Size([10, 221, 768]) None
torch.Size([10, 216, 1536]) torch.Size([10, 221, 768]) None
torch.Size([10, 216, 1536]) torch.Size([10, 221, 768]) None
torch.Size([10, 216, 1536]) torch.Size([10, 221, 768]) None
torch.Size([10, 216, 1536]) torch.Size([10, 221, 768]) None
torch.Size([10, 216, 1536]) torch.Size([10, 221, 768]) None
torch.Size([10, 216, 1536]) torch.Size([10, 221, 768]) None
torch.Size([10, 216, 1536]) torch.Size([10, 221, 768]) None
torch.Size([10, 216, 1536]) torch.Size([10, 221, 768]

/home/chengxin/chengxin/anaconda3/envs/stableaudio/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
