## Tutorial on training a HTS-AT model for audio classification on the ESC-50 Dataset

Reference:

[HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection, ICASSP 2022](https://arxiv.org/abs/2202.00874)

Following the HTS-AT paper, this tutorial shows how to train an HTS-AT model on the ESC-50 dataset.

The [ESC-50 dataset](https://github.com/karolpiczak/ESC-50) is a labeled collection of 2000 environmental audio recordings suitable for benchmarking methods of environmental sound classification. The dataset consists of 5-second-long recordings organized into 50 semantic classes (with 40 examples per class) loosely arranged into 5 major categories.

Before running this tutorial, please make sure that you have installed the required packages by following these steps:

1. Download [the codebase](https://github.com/RetroCirce/HTS-Audio-Transformer), and put this tutorial notebook inside the codebase folder.

2. In the github code folder:

    > pip install -r requirements.txt

3. We do not include PyTorch in the requirements file, since different machines require different versions of CUDA and its toolkits. Make sure you install PyTorch by following [the official guidance](https://pytorch.org/).

4. Install 'SoX' and 'ffmpeg'. We recommend running this code on Linux inside a Conda environment, where you can install them with:

    > sudo apt install sox
    
    > conda install -c conda-forge ffmpeg


In [1]:
# import basic packages
import os
import numpy as np
import wget
import sys
import gdown
import zipfile
import librosa
# Expose only GPU 0 to this process: the notebook is meant to run on a single GPU.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [2]:
# Build the workspace and download the needed files

def create_path(path):
    """Create the directory ``path`` if nothing exists at that location yet.

    Missing parent directories are created as well, and calling the function
    again for an existing directory is a no-op.

    Args:
        path: Directory to create.
    """
    if not os.path.exists(path):
        # makedirs handles nested paths (os.mkdir fails when a parent is
        # missing); exist_ok=True tolerates the directory appearing between
        # the existence check and the creation.
        os.makedirs(path, exist_ok=True)

# Workspace layout:
#   <workspace>/mfg_robot      dataset root
#   <workspace>/mfg_robot/raw  raw (un-resampled) recordings
#   <workspace>/ckpt           model checkpoints
workspace = "./workspace_ADS_v2"
checkpoint_path = os.path.join(workspace, "ckpt")
dataset_path = os.path.join(workspace, "mfg_robot")
mfg_raw_path = os.path.join(dataset_path, 'raw')

# Create the folders in the same order as before: every parent is listed
# ahead of its children.
for folder in (workspace, dataset_path, checkpoint_path, mfg_raw_path):
    create_path(folder)


# # download the esc-50 dataset
# 
# if not os.path.exists(os.path.join(dataset_path, 'ESC-50-master.zip')):
#     print("-------------Downloading ESC-50 Dataset-------------")
#     wget.download('https://github.com/karoldvl/ESC-50/archive/master.zip', out=dataset_path)
#     with zipfile.ZipFile(os.path.join(dataset_path, 'ESC-50-master.zip'), 'r') as zip_ref:
#         zip_ref.extractall(esc_raw_path)
#     print("-------------Success-------------")
# 
# if not os.path.exists(os.path.join(checkpoint_path,'htsat_audioset_pretrain.ckpt')):
#     gdown.download(id='1OK8a5XuMVLyeVKF117L8pfxeZYdfSDZv', output=os.path.join(checkpoint_path,'htsat_audioset_pretrain.ckpt'))
# 



In [3]:
# Process Manufacturing Dataset – Resampling Audio Files
# Every file in the raw audio folder is converted to 32 kHz with sox and
# written under the same name into the resample folder. Files that already
# exist in the resample folder are skipped, so the cell can be re-run.
import subprocess

audio_path = os.path.join(mfg_raw_path, 'MFG-master', 'audio')
resample_path = os.path.join(dataset_path, 'resample')
savedata_path = os.path.join(dataset_path, 'mfg-data.npy')
create_path(resample_path)

audio_list = os.listdir(audio_path)

print("-------------Resample ESC-50-------------")
for f in audio_list:
    full_f = os.path.join(audio_path, f)
    resample_f = os.path.join(resample_path, f)
    if not os.path.exists(resample_f):
        # Pass the arguments as a list (no shell involved) so that file names
        # containing spaces or shell metacharacters are handled correctly.
        # Note: unlike os.system, this raises FileNotFoundError if sox itself
        # is not installed.
        result = subprocess.run(['sox', '-V1', full_f, '-r', '32000', resample_f])
        if result.returncode != 0:
            # Report the failure instead of ignoring it; keep going with the
            # remaining files.
            print("sox failed with exit code {} for {}".format(result.returncode, full_f))
print("-------------Resample Success-------------")


-------------Resample ESC-50-------------
-------------Resample Success-------------


In [4]:
import os
import numpy as np
import librosa

# Paths
# Join the components individually: a single literal with backslash
# separators only resolves on Windows and contains invalid escape sequences.
meta_path = os.path.join(mfg_raw_path, 'MFG-master', 'meta', 'mfg.csv')  # Adjust this path if needed
# Each metadata row is read as strings; the header row is skipped.
meta = np.loadtxt(meta_path, delimiter=',', dtype='str', skiprows=1)

print("-------------Build Dataset-------------")
output_dict = [[] for _ in range(5)]  # Assuming 5 folds, still okay if we only use fold 1

for label in meta:
    # Column 0: file name, column 1: fold number (1-based), column 2: class id.
    name = label[0]
    fold = int(label[1])
    target = int(label[2])

    # Preserve the original multi-channel structure for the sensor data:
    # mono=False keeps the channels instead of down-mixing, and sr=None keeps
    # the sample rate of the resampled file.
    y, sr = librosa.load(os.path.join(resample_path, name), sr=None, mono=False)

    output_dict[fold - 1].append({
        "name": name,
        "target": target,
        "waveform": y
    })

# Stored as an object array of per-fold sample lists; loading it back
# requires allow_pickle=True.
np.save(savedata_path, np.array(output_dict, dtype=object))

print("-------------Build Dataset Success-------------")


-------------Build Dataset-------------
-------------Build Dataset Success-------------


In [5]:
# Load the model package
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
import warnings

from utils import create_folder, dump_config, process_idc
import mfg_config as config
from sed_model import SEDWrapper, Ensemble_SEDWrapper
from data_generator_mfg import MFG_Dataset
from model.htsat_mfg_v_2_0 import HTSAT_Swin_Transformer



In [6]:
# Data Preparation
# New data preparation class
class data_prep(pl.LightningDataModule):
    """Lightning data module that builds MFG_Dataset train/eval splits.

    Args:
        dataset: Raw dataset object handed to ``MFG_Dataset`` (stored by
            reference, not copied).
        config: Configuration object; ``num_workers`` and ``batch_size`` are
            read from it, and it is forwarded to ``MFG_Dataset``.
        device_num: Number of devices. The per-loader batch size is
            ``config.batch_size // max(1, device_num)`` and a
            ``DistributedSampler`` is used when it is greater than 1.
    """

    def __init__(self, dataset, config, device_num):
        super().__init__()
        self.dataset = dataset  # Store only a reference
        self.config = config
        self.device_num = device_num
        self.train_dataset = None  # Placeholder, will be initialized later
        self.eval_dataset = None

    def setup(self, stage=None):
        """Create the train and eval datasets.

        Datasets are only built when ``stage`` is ``"fit"`` or ``None``; for
        any other stage the placeholders are left unchanged.
        """
        if stage == "fit" or stage is None:
            self.train_dataset = MFG_Dataset(
                dataset=self.dataset,
                config=self.config,
                eval_mode=False
            )
            self.eval_dataset = MFG_Dataset(
                dataset=self.dataset,
                config=self.config,
                eval_mode=True
            )

    def _make_loader(self, dataset):
        """Build a non-shuffling DataLoader over ``dataset``."""
        # With several devices each process gets its own shard through a
        # DistributedSampler; on a single device no sampler is needed.
        sampler = DistributedSampler(dataset, shuffle=False) if self.device_num > 1 else None
        return DataLoader(
            dataset=dataset,
            num_workers=self.config.num_workers,
            batch_size=self.config.batch_size // max(1, self.device_num),
            shuffle=False,
            sampler=sampler
        )

    def train_dataloader(self):
        """Return the loader over the training dataset."""
        return self._make_loader(self.train_dataset)

    def val_dataloader(self):
        """Return the loader over the evaluation dataset."""
        return self._make_loader(self.eval_dataset)

    def test_dataloader(self):
        """Return the loader used for testing (it reuses the evaluation dataset)."""
        return self._make_loader(self.eval_dataset)

    def on_fit_start(self):
        """Removes unpicklable attributes before multiprocessing starts"""
        for attr in ["trainer", "prepare_data", "setup", "teardown"]:
            # Only attributes stored on the instance can be deleted. hasattr()
            # is also true for methods defined on the class, and delattr() on
            # those raises AttributeError.
            if attr in vars(self):
                delattr(self, attr)


In [7]:
# Set the workspace
device_num = torch.cuda.device_count()
print("device_num:", device_num)
# max(1, ...) avoids a ZeroDivisionError on machines without a visible GPU.
print("each batch size:", config.batch_size // max(1, device_num))

# NOTE: allow_pickle=True unpickles arbitrary objects; only load a file that
# was produced by this notebook.
full_dataset = np.load(os.path.join(config.dataset_path, "mfg-data.npy"), allow_pickle = True)

# set exp folder
exp_dir = os.path.join(config.workspace, "results", config.exp_name)
checkpoint_dir = os.path.join(config.workspace, "results", config.exp_name, "checkpoint")
if not config.debug:
    create_folder(os.path.join(config.workspace, "results"))
    create_folder(exp_dir)
    create_folder(checkpoint_dir)
    dump_config(config, os.path.join(exp_dir, config.exp_name), False)

print("Using MFG Dataset")
dataset = MFG_Dataset(
    dataset = full_dataset,
    config = config,
    eval_mode = False
)
eval_dataset = MFG_Dataset(
    dataset = full_dataset,
    config = config,
    eval_mode = True
)

# data_prep expects (raw dataset, config, device_num) and builds its own
# MFG_Dataset instances in setup(); passing the eval dataset in the config
# slot would break every later access to config attributes.
audioset_data = data_prep(full_dataset, config, device_num)
# Keep the 20 checkpoints with the highest logged "acc" value.
checkpoint_callback = ModelCheckpoint(
    monitor = "acc",
    filename='l-{epoch:d}-{acc:.3f}',
    save_top_k = 20,
    mode = "max"
)




device_num: 1
each batch size: 64
Using MFG Dataset


In [8]:
# Set the Trainer
# All options are collected first so they are easy to inspect or tweak.
trainer_kwargs = dict(
    default_root_dir=checkpoint_dir,
    callbacks=[checkpoint_callback],
    max_epochs=config.max_epoch,
    gpus=device_num,
    accelerator="ddp" if device_num > 1 else None,
    replace_sampler_ddp=False,
    sync_batchnorm=True,
    deterministic=False,
    val_check_interval=1.0,
    num_sanity_val_steps=0,
    auto_lr_find=True,
    gradient_clip_val=1.0,
    resume_from_checkpoint=None,
)
trainer = pl.Trainer(**trainer_kwargs)

# Create the HTSAT model with updated channel input (e.g., 3 or 6 channels)
sed_model = HTSAT_Swin_Transformer(
    config=config,
    in_chans=3,  # Change to 3 or 6 depending on your sensor data channels
    num_classes=config.classes_num,
    spec_size=config.htsat_spec_size,
    patch_size=config.htsat_patch_size,
    patch_stride=config.htsat_stride,
    window_size=config.htsat_window_size,
    embed_dim=config.htsat_dim,
    depths=config.htsat_depth,
    num_heads=config.htsat_num_head
)

model = SEDWrapper(sed_model=sed_model, config=config, dataset=dataset)

if config.resume_checkpoint is not None:
    print("Load Checkpoint from ", config.resume_checkpoint)
    ckpt = torch.load(config.resume_checkpoint, map_location="cpu")
    state_dict = ckpt["state_dict"]

    # Adapt the patch embedding weights to match the current in_chans setting
    key = "sed_model.patch_embed.proj.weight"
    if key in state_dict:
        weight = state_dict[key]
        ckpt_chans = weight.shape[1]
        model_chans = sed_model.in_chans
        if ckpt_chans != model_chans:
            # Only a single-channel checkpoint can be adapted: its kernel is
            # repeated for every input channel and divided by the channel count.
            if ckpt_chans != 1:
                raise ValueError("Unexpected number of channels in checkpoint weight: {}".format(ckpt_chans))
            state_dict[key] = weight.repeat(1, model_chans, 1, 1) / model_chans

    # Remove keys that might conflict with the current model architecture
    for stale_key in (
        "sed_model.head.weight",
        "sed_model.head.bias",
        "sed_model.tscam_conv.weight",
        "sed_model.tscam_conv.bias",
    ):
        state_dict.pop(stale_key, None)

    # strict=False: the keys dropped above stay at their fresh initialisation.
    model.load_state_dict(state_dict, strict=False)



GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Load Checkpoint from  ./workspace_ADS_v2/ckpt/htsat_audioset_pretrain.ckpt


  ckpt = torch.load(config.resume_checkpoint, map_location="cpu")


In [12]:
import os
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader, DistributedSampler

def dynamic_pad_collate(batch):
    """Collate samples of different lengths by zero-padding along time.

    Each sample is a dict with the keys ``"waveform"``, ``"target"`` and
    ``"audio_name"``. Waveforms that are not tensors are converted to float
    tensors, and 1-D waveforms get a leading channel axis; both changes are
    written back into the sample dicts.

    Args:
        batch: List of sample dicts.

    Returns:
        dict with
            ``"waveform"``: stacked tensor ``[B, channels, max_length]``,
            ``"target"``: tensor built from the per-sample targets,
            ``"audio_name"``: list of the per-sample names,
            ``"real_len"``: tensor of the unpadded lengths.
    """
    # Normalise every waveform to a tensor of shape [channels, time].
    for item in batch:
        wav = item["waveform"]
        if not isinstance(wav, torch.Tensor):
            wav = item["waveform"] = torch.tensor(wav, dtype=torch.float)
        if wav.dim() == 1:
            item["waveform"] = wav.unsqueeze(0)

    lengths = [item["waveform"].shape[1] for item in batch]
    max_length = max(lengths)

    # Right-pad the shorter clips with zeros up to the longest one.
    padded_waveforms = []
    for item, current_length in zip(batch, lengths):
        wav = item["waveform"]
        if current_length < max_length:
            wav = F.pad(wav, (0, int(max_length - current_length)), mode="constant", value=0)
        padded_waveforms.append(wav)

    return {
        "waveform": torch.stack(padded_waveforms),  # shape: [B, channels, max_length]
        "target": torch.tensor([item["target"] for item in batch]),
        "audio_name": [item["audio_name"] for item in batch],
        "real_len": torch.tensor(lengths)
    }

# Build a fresh data module from the raw dataset and create its train/eval
# datasets right away by calling setup("fit") explicitly.
# NOTE(review): this cell rebinds audioset_data, sed_model, model and trainer
# from the earlier cells; the model created here does not load
# config.resume_checkpoint — confirm that training from scratch is intended.
audioset_data = data_prep(full_dataset, config, device_num)
audioset_data.setup("fit")
# Drop the instance's `trainer` attribute if it has one before handing the
# data module to a new Trainer.
if hasattr(audioset_data, "trainer"):
    del audioset_data.trainer

# Override the train/val loaders on this instance so batches are built with
# dynamic_pad_collate (zero-padding to the longest clip in the batch).
# Each lambda builds a new DataLoader every time it is called.
audioset_data.train_dataloader = lambda: DataLoader(
    audioset_data.train_dataset,
    batch_size=config.batch_size // max(1, device_num),
    sampler=DistributedSampler(audioset_data.train_dataset) if device_num > 1 else None,
    num_workers=0,  # Load in the main process (debugging); increase as needed later
    collate_fn=dynamic_pad_collate
)
audioset_data.val_dataloader = lambda: DataLoader(
    audioset_data.eval_dataset,
    batch_size=config.batch_size // max(1, device_num),
    sampler=DistributedSampler(audioset_data.eval_dataset) if device_num > 1 else None,
    num_workers=0,
    collate_fn=dynamic_pad_collate
)

# Build the model.
# NOTE(review): earlier notes describe the data as 3-channel, but the model is
# created with in_chans=1 here — confirm which value matches the waveforms.
sed_model = HTSAT_Swin_Transformer(
    spec_size=config.htsat_spec_size,
    patch_size=config.htsat_patch_size,
    in_chans=1,  # single input channel (see the review note above)
    num_classes=config.classes_num,
    window_size=config.htsat_window_size,
    config=config,
    depths=config.htsat_depth,
    embed_dim=config.htsat_dim,
    patch_stride=config.htsat_stride,
    num_heads=config.htsat_num_head
)
model = SEDWrapper(sed_model=sed_model, config=config, dataset=audioset_data.train_dataset)

# Trainer setup: reuses checkpoint_dir and checkpoint_callback from the
# earlier cells; no sanity-check validation steps are run before training.
trainer = pl.Trainer(
    default_root_dir=checkpoint_dir,
    gpus=device_num,
    max_epochs=config.max_epoch,
    callbacks=[checkpoint_callback],
    accelerator="ddp" if device_num > 1 else None,
    num_sanity_val_steps=0,
)

print("Starting training…")
trainer.fit(model, audioset_data)


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                   | Params
-----------------------------------------------------
0 | sed_model | HTSAT_Swin_Transformer | 31.3 M
-----------------------------------------------------
30.2 M    Trainable params
1.1 M     Non-trainable params
31.3 M    Total params
125.301   Total estimated model params size (MB)


Starting training…


Training: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]

cuda:0 {'acc': 1.0}


Validating: 0it [00:00, ?it/s]

cuda:0 {'acc': 1.0}


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


## Now Let us Check the Result

Find the path of your saved checkpoint and paste it into the variable below.
You can then follow the code below to check the prediction result for any sample you like.

In [78]:
# Run inference on a single clip to check the result.
# Path of the checkpoint saved during training.
# NOTE(review): this is a machine-specific absolute path — replace it with the
# checkpoint produced by your own run.
model_path = r"C:\Users\Louis\PycharmProjects\HTS-AT(Conda)\HTS-Audio-Transformer\workspace\results\exp_htsat_esc_50\checkpoint\lightning_logs\version_1\checkpoints\l-epoch=4-acc=0.815.ckpt"

# Ground truth: map each file name (column 0) to its class id (column 2),
# kept as the string read from the CSV.
meta = np.loadtxt(meta_path , delimiter=',', dtype='str', skiprows=1)
gd = {row[0]: row[2] for row in meta}

class Audio_Classification:
    """Single-file inference wrapper around a trained HTSAT_Swin_Transformer.

    Args:
        model_path: Path to a checkpoint whose ``"state_dict"`` holds the
            model weights, with keys prefixed by ``"sed_model."``.
        config: Configuration object providing the ``htsat_*`` and
            ``classes_num`` settings used to build the model.
    """

    # Prefix carried by the weight names in the checkpoint; it is removed so
    # the keys match the bare HTSAT_Swin_Transformer.
    _CKPT_PREFIX = "sed_model."

    def __init__(self, model_path, config):
        super().__init__()

        # Use the GPU when there is one, otherwise fall back to the CPU
        # instead of failing on CPU-only machines.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.sed_model = HTSAT_Swin_Transformer(
            spec_size=config.htsat_spec_size,
            patch_size=config.htsat_patch_size,
            in_chans=1,
            num_classes=config.classes_num,
            window_size=config.htsat_window_size,
            config = config,
            depths = config.htsat_depth,
            embed_dim = config.htsat_dim,
            patch_stride=config.htsat_stride,
            num_heads=config.htsat_num_head
        )
        ckpt = torch.load(model_path, map_location="cpu")
        prefix = self._CKPT_PREFIX
        temp_ckpt = {}
        for key, value in ckpt["state_dict"].items():
            # Strip the prefix only where it is present, so a key without it
            # is not silently truncated.
            new_key = key[len(prefix):] if key.startswith(prefix) else key
            temp_ckpt[new_key] = value
        self.sed_model.load_state_dict(temp_ckpt)
        self.sed_model.to(self.device)
        self.sed_model.eval()


    def predict(self, audiofile):
        """Classify one audio file.

        Args:
            audiofile: Path of the audio file; it is loaded with librosa at
                32 kHz.

        Returns:
            ``(pred_label, pred_prob)``: the index of the highest entry of the
            model's ``clipwise_output`` and that entry's value. Returns
            ``None`` when ``audiofile`` is empty or ``None``.
        """
        if not audiofile:
            return None

        waveform, sr = librosa.load(audiofile, sr=32000)

        with torch.no_grad():
            x = torch.from_numpy(waveform).float().to(self.device)
            # x[None, :] adds the batch axis in front of the waveform.
            output_dict = self.sed_model(x[None, :], None, True)
            pred = output_dict['clipwise_output']
            pred_post = pred[0].detach().cpu().numpy()
            pred_label = np.argmax(pred_post)
            pred_prob = np.max(pred_post)
        return pred_label, pred_prob


In [ ]:
# Smoke test: push one training batch through the model and inspect the output.
# `device` is not defined anywhere else in the notebook, so take it from the
# model itself; the batch is then always moved to where the weights live.
device = next(sed_model.parameters()).device
batch = next(iter(audioset_data.train_dataloader()))
# NOTE(review): expected shape is (batch_size, channels, clip_samples); the
# channel count must match the in_chans the model was built with — confirm.
x = batch["waveform"].to(device)
out = sed_model(x)
print("Output keys:", out.keys())
print("Clipwise output shape:", out["clipwise_output"].shape)


In [39]:
# Inference: load the saved checkpoint into the classifier wrapper.
Audiocls = Audio_Classification(model_path, config)

# Pick any clip you like; the printed line shows the predicted label, its
# score and the ground-truth label looked up by file name.
# NOTE(review): the path below still points at the ESC-50 folder layout —
# adjust it to your own data.
prediction = Audiocls.predict("./workspace/esc-50/raw/ESC-50-master/audio/1-7456-A-13.wav")
pred_label, pred_prob = prediction

print('Audiocls predict output: ', pred_label, pred_prob, gd["1-7456-A-13.wav"])

  ckpt = torch.load(model_path, map_location="cpu")


Audiocls predict output:  13 8.718129 13
