In [1]:
import sys
import os
# sys.path.append(os.path.dirname(__file__))
import torch

# Other Files

In [2]:
# -*- coding: utf-8 -*-
# @Time    : 6/10/21 5:04 PM
# @Author  : Yuan Gong
# @Affiliation  : Massachusetts Institute of Technology
# @Email   : yuangong@mit.edu
# @File    : ast_models.py

import torch
import torch.nn as nn
from torch.cuda.amp import autocast
import os
import wget
os.environ['TORCH_HOME'] = '../pretrained_models'
import timm
from timm.models.layers import to_2tuple,trunc_normal_

# override the timm package to relax the input shape constraint.
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()

        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        
    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

class ASTModel(nn.Module):
    """
    The AST model.
    :param label_dim: the label dimension, i.e., the number of total classes, it is 527 for AudioSet, 50 for ESC-50, and 35 for speechcommands v2-35
    :param fstride: the stride of patch spliting on the frequency dimension, for 16*16 patchs, fstride=16 means no overlap, fstride=10 means overlap of 6
    :param tstride: the stride of patch spliting on the time dimension, for 16*16 patchs, tstride=16 means no overlap, tstride=10 means overlap of 6
    :param input_fdim: the number of frequency bins of the input spectrogram
    :param input_tdim: the number of time frames of the input spectrogram
    :param imagenet_pretrain: if use ImageNet pretrained model
    :param audioset_pretrain: if use full AudioSet and ImageNet pretrained model
    :param model_size: the model size of AST, should be in [tiny224, small224, base224, base384], base224 and base 384 are same model, but are trained differently during ImageNet pretraining.
    """
    def __init__(self, label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=True, audioset_pretrain=False, model_size='base384', verbose=True):

        super(ASTModel, self).__init__()

        if verbose == True:
            print('---------------AST Model Summary---------------')
            print('ImageNet pretraining: {:s}, AudioSet pretraining: {:s}'.format(str(imagenet_pretrain),str(audioset_pretrain)))
        # override timm input shape restriction
        timm.models.vision_transformer.PatchEmbed = PatchEmbed

        # if AudioSet pretraining is not used (but ImageNet pretraining may still apply)
        if audioset_pretrain == False:
            if model_size == 'tiny224':
                self.v = timm.create_model('vit_deit_tiny_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'small224':
                self.v = timm.create_model('vit_deit_small_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'base224':
                self.v = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=imagenet_pretrain)
            elif model_size == 'base384':
                self.v = timm.create_model('vit_deit_base_distilled_patch16_384', pretrained=imagenet_pretrain)
            else:
                raise Exception('Model size must be one of tiny224, small224, base224, base384.')
            self.original_num_patches = self.v.patch_embed.num_patches
            self.oringal_hw = int(self.original_num_patches ** 0.5)
            self.original_embedding_dim = self.v.pos_embed.shape[2]
            self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))

            # automatcially get the intermediate shape
            f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
            num_patches = f_dim * t_dim
            self.v.patch_embed.num_patches = num_patches
            if verbose == True:
                print('frequncey stride={:d}, time stride={:d}'.format(fstride, tstride))
                print('number of patches={:d}'.format(num_patches))

            # the linear projection layer
            new_proj = torch.nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
            if imagenet_pretrain == True:
                new_proj.weight = torch.nn.Parameter(torch.sum(self.v.patch_embed.proj.weight, dim=1).unsqueeze(1))
                new_proj.bias = self.v.patch_embed.proj.bias
            self.v.patch_embed.proj = new_proj

            # the positional embedding
            if imagenet_pretrain == True:
                # get the positional embedding from deit model, skip the first two tokens (cls token and distillation token), reshape it to original 2D shape (24*24).
                new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, self.original_num_patches, self.original_embedding_dim).transpose(1, 2).reshape(1, self.original_embedding_dim, self.oringal_hw, self.oringal_hw)
                # cut (from middle) or interpolate the second dimension of the positional embedding
                if t_dim <= self.oringal_hw:
                    new_pos_embed = new_pos_embed[:, :, :, int(self.oringal_hw / 2) - int(t_dim / 2): int(self.oringal_hw / 2) - int(t_dim / 2) + t_dim]
                else:
                    new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(self.oringal_hw, t_dim), mode='bilinear')
                # cut (from middle) or interpolate the first dimension of the positional embedding
                if f_dim <= self.oringal_hw:
                    new_pos_embed = new_pos_embed[:, :, int(self.oringal_hw / 2) - int(f_dim / 2): int(self.oringal_hw / 2) - int(f_dim / 2) + f_dim, :]
                else:
                    new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
                # flatten the positional embedding
                new_pos_embed = new_pos_embed.reshape(1, self.original_embedding_dim, num_patches).transpose(1,2)
                # concatenate the above positional embedding with the cls token and distillation token of the deit model.
                self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))
            else:
                # if not use imagenet pretrained model, just randomly initialize a learnable positional embedding
                # TODO can use sinusoidal positional embedding instead
                new_pos_embed = nn.Parameter(torch.zeros(1, self.v.patch_embed.num_patches + 2, self.original_embedding_dim))
                self.v.pos_embed = new_pos_embed
                trunc_normal_(self.v.pos_embed, std=.02)

        # now load a model that is pretrained on both ImageNet and AudioSet
        elif audioset_pretrain == True:
            if audioset_pretrain == True and imagenet_pretrain == False:
                raise ValueError('currently model pretrained on only audioset is not supported, please set imagenet_pretrain = True to use audioset pretrained model.')
            if model_size != 'base384':
                raise ValueError('currently only has base384 AudioSet pretrained model.')
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            if os.path.exists('../pretrained_models/audioset_10_10_0.4593.pth') == False:
                # this model performs 0.4593 mAP on the audioset eval set
                audioset_mdl_url = 'https://www.dropbox.com/s/cv4knew8mvbrnvq/audioset_0.4593.pth?dl=1'
                wget.download(audioset_mdl_url, out='../pretrained_models/audioset_10_10_0.4593.pth')
            sd = torch.load('../pretrained_models/audioset_10_10_0.4593.pth', map_location=device)
            audio_model = ASTModel(label_dim=527, fstride=10, tstride=10, input_fdim=128, input_tdim=1024, imagenet_pretrain=False, audioset_pretrain=False, model_size='base384', verbose=False)
            audio_model = torch.nn.DataParallel(audio_model)
            audio_model.load_state_dict(sd, strict=False)
            self.v = audio_model.module.v
            self.original_embedding_dim = self.v.pos_embed.shape[2]
            self.mlp_head = nn.Sequential(nn.LayerNorm(self.original_embedding_dim), nn.Linear(self.original_embedding_dim, label_dim))

            f_dim, t_dim = self.get_shape(fstride, tstride, input_fdim, input_tdim)
            num_patches = f_dim * t_dim
            self.v.patch_embed.num_patches = num_patches
            if verbose == True:
                print('frequncey stride={:d}, time stride={:d}'.format(fstride, tstride))
                print('number of patches={:d}'.format(num_patches))

            new_pos_embed = self.v.pos_embed[:, 2:, :].detach().reshape(1, 1212, 768).transpose(1, 2).reshape(1, 768, 12, 101)
            # if the input sequence length is larger than the original audioset (10s), then cut the positional embedding
            if t_dim < 101:
                new_pos_embed = new_pos_embed[:, :, :, 50 - int(t_dim/2): 50 - int(t_dim/2) + t_dim]
            # otherwise interpolate
            else:
                new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(12, t_dim), mode='bilinear')
            if f_dim < 12:
                new_pos_embed = new_pos_embed[:, :, 6 - int(f_dim/2): 6 - int(f_dim/2) + f_dim, :]
            # otherwise interpolate
            elif f_dim > 12:
                new_pos_embed = torch.nn.functional.interpolate(new_pos_embed, size=(f_dim, t_dim), mode='bilinear')
            new_pos_embed = new_pos_embed.reshape(1, 768, num_patches).transpose(1, 2)
            self.v.pos_embed = nn.Parameter(torch.cat([self.v.pos_embed[:, :2, :].detach(), new_pos_embed], dim=1))

    def get_shape(self, fstride, tstride, input_fdim=128, input_tdim=1024):
        test_input = torch.randn(1, 1, input_fdim, input_tdim)
        test_proj = nn.Conv2d(1, self.original_embedding_dim, kernel_size=(16, 16), stride=(fstride, tstride))
        test_out = test_proj(test_input)
        f_dim = test_out.shape[2]
        t_dim = test_out.shape[3]
        return f_dim, t_dim

    # @autocast()
    def forward(self, x):
        """
        :param x: the input spectrogram, expected shape: (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        :return: prediction
        """
        # expect input x = (batch_size, time_frame_num, frequency_bins), e.g., (12, 1024, 128)
        x = x.unsqueeze(1)
        x = x.transpose(2, 3)

        # Convert input to the same dtype as the model's parameters
        # x = x.to(next(self.parameters()).dtype)

        B = x.shape[0]
        print("path embed")
        x = self.v.patch_embed(x)
        cls_tokens = self.v.cls_token.expand(B, -1, -1)
        dist_token = self.v.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.v.pos_embed
        x = self.v.pos_drop(x)
        for blk in self.v.blocks:
            x = blk(x)
        x = self.v.norm(x)
        x = (x[:, 0] + x[:, 1]) / 2

        x = self.mlp_head(x)
        return x

# if __name__ == '__main__':
#     input_tdim = 100
#     ast_mdl = ASTModel(input_tdim=input_tdim)
#     # input a batch of 10 spectrogram, each with 100 time frames and 128 frequency bins

#     device = 'cpu'

#     # 1) cast to float, 2) move model to GPU
#     ast_mdl = ast_mdl.float().to(device)

#     # 3) create your input and move it to the *same* device
#     test_input = torch.rand([10, input_tdim, 128], device=device)
#     test_output = ast_mdl(test_input)
#     # output should be in shape [10, 527], i.e., 10 samples, each with prediction of 527 classes.
#     print(test_output.shape)

#     input_tdim = 256
#     ast_mdl = ASTModel(input_tdim=input_tdim,label_dim=50, audioset_pretrain=True)
#     # input a batch of 10 spectrogram, each with 512 time frames and 128 frequency bins
#     test_input = torch.rand([10, input_tdim, 128])
#     test_output = ast_mdl(test_input)
#     # output should be in shape [10, 50], i.e., 10 samples, each with prediction of 50 classes.
#     print(test_output.shape)

In [3]:
import os
import json
import torch
import torchaudio
import numpy as np
from tqdm import tqdm
from pathlib import Path
from torch.utils.data import Dataset
import csv
import time
from torch import nn
import pickle
import pandas as pd

def load_config(config_path):
    """Load the configuration file containing preprocessing parameters."""
    with open(config_path, 'r') as f:
        config = json.load(f)
    return config

class FbankDataset(Dataset):
    """Dataset class for preprocessed fbank files.
    
    This dataset class works with the preprocessed fbank files generated by process_dataset.
    It loads the preprocessed mel spectrograms directly from disk instead of computing them on the fly.
    
    Args:
        dataset_json_file (str): Path to the JSON file containing fbank file paths and labels
        label_csv (str, optional): Path to the CSV file containing class labels
        audio_conf (dict, optional): Dictionary containing audio configuration parameters
    """
    def __init__(self, dataset_json_file, label_csv=None, audio_conf=None):
        self.datapath = dataset_json_file
        with open(dataset_json_file, 'r') as fp:
            data_json = json.load(fp)
        
        self.data = data_json['data']
        self.audio_conf = audio_conf or {}
        
        # Load label mapping if provided
        self.index_dict = {}
        if label_csv:
            with open(label_csv, 'r') as f:
                csv_reader = csv.DictReader(f)
                for row in csv_reader:
                    self.index_dict[row['mid']] = row['index']
            self.label_num = len(self.index_dict)
        else:
            self.label_num = None
    
    def __getitem__(self, index):
        """Get a single item from the dataset.
        
        Args:
            index (int): Index of the item to get
            
        Returns:
            tuple: (fbank, label_indices)
                - fbank (torch.Tensor): The preprocessed mel spectrogram
                - label_indices (torch.Tensor): The label indices
        """
        datum = self.data[index]
        
        # Load the preprocessed fbank
        fbank = torch.load(datum['wav'])
        
        # Initialize label indices
        if self.label_num is not None:
            label_indices = np.zeros(self.label_num)
            for label_str in datum['labels'].split(','):
                label_indices[int(self.index_dict[label_str])] = 1.0
            label_indices = torch.FloatTensor(label_indices)
        else:
            label_indices = torch.tensor(0)  # Dummy tensor if no labels
        
        return fbank, label_indices
    
    def __len__(self):
        """Get the total number of items in the dataset."""
        return len(self.data)

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def train_epoch(model, train_loader, criterion, optimizer, device, epoch, args):
    """Train for one epoch."""
    model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    

    print("initialized time measures")
    end = time.time()
    print("unrolling loader")
    for i, (audio_input, labels) in enumerate(train_loader):
        print("updating time")
        data_time.update(time.time() - end)
        
        print("moving input to device")
        # Move data to device
        audio_input = audio_input.to(device)
        labels = labels.to(device)
        
        print("train forw")
        # Compute output
        audio_output = model(audio_input)
        
        # Compute loss
        if isinstance(criterion, nn.CrossEntropyLoss):
            loss = criterion(audio_output, torch.argmax(labels.long(), axis=1))
        else:
            loss = criterion(audio_output, labels)
        
        # Compute gradient and do optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Update statistics
        loss_meter.update(loss.item(), audio_input.size(0))
        batch_time.update(time.time() - end)
        end = time.time()
        
        # Print statistics
        if i % args['print_freq'] == 0:
            print(f'Epoch: [{epoch}][{i}/{len(train_loader)}]\t'
                  f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  f'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})')
    
    return loss_meter.avg

def validate(model, val_loader, criterion, device, args):
    """Validate the model."""
    model.eval()
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    
    predictions = []
    targets = []
    
    with torch.no_grad():
        end = time.time()
        for i, (audio_input, labels) in enumerate(val_loader):
            # Move data to device
            audio_input = audio_input.to(device)
            labels = labels.to(device)
            
            # Compute output
            audio_output = model(audio_input)
            
            # Compute loss
            if isinstance(criterion, nn.CrossEntropyLoss):
                loss = criterion(audio_output, torch.argmax(labels.long(), axis=1))
            else:
                loss = criterion(audio_output, labels)
            
            # Store predictions and targets
            predictions.append(audio_output.cpu())
            targets.append(labels.cpu())
            
            # Update statistics
            loss_meter.update(loss.item(), audio_input.size(0))
            batch_time.update(time.time() - end)
            end = time.time()
            
            if i % args['print_freq'] == 0:
                print(f'Validation: [{i}/{len(val_loader)}]\t'
                      f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})')
    
    # Concatenate predictions and targets
    predictions = torch.cat(predictions)
    targets = torch.cat(targets)
    
    # Calculate metrics
    if isinstance(criterion, nn.CrossEntropyLoss):
        accuracy = (torch.argmax(predictions, dim=1) == torch.argmax(targets, dim=1)).float().mean()
        metrics = {'accuracy': accuracy.item(), 'loss': loss_meter.avg}
    else:
        # For binary/multi-label classification
        predictions = torch.sigmoid(predictions)
        metrics = {
            'loss': loss_meter.avg,
            'accuracy': ((predictions > 0.5) == targets).float().mean().item()
        }
    
    return metrics

def train_model(model, train_loader, val_loader, args):
    """Main training function."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')
    
    # Move model to device
    model = model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    
    # Define loss function
    if args['loss'] == 'BCE':
        criterion = nn.BCEWithLogitsLoss()
    elif args['loss'] == 'CE':
        criterion = nn.CrossEntropyLoss()
    else:
        raise ValueError(f'Unknown loss function: {args["loss"]}')
    
    # Define optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
    
    # Define learning rate scheduler
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    #     optimizer, mode='min', factor=0.5, patience=args.lr_patience, verbose=True
    # )

    # ideally for ESC
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, list(range(args['lrscheduler_start'], 1000, args['lrscheduler_step'])),gamma=args['lrscheduler_decay'])


    # Training loop
    best_val_loss = float('inf')
    for epoch in range(args['epochs']):
        print(f'\nEpoch {epoch+1}/{args["epochs"]}')
        
        # Train for one epoch
        print("train loop start")
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device, epoch, args)
        
        # Validate
        val_metrics = validate(model, val_loader, criterion, device, args)
        
        # Update learning rate
        scheduler.step(val_metrics['loss'])
        
        # Save checkpoint
        is_best = val_metrics['loss'] < best_val_loss
        best_val_loss = min(val_metrics['loss'], best_val_loss)
        
        if is_best:
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_metrics['loss'],
                'val_accuracy': val_metrics['accuracy']
            }, os.path.join(args['exp_dir'], 'best_model.pth'))
        
        # Save latest checkpoint
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_metrics['loss'],
            'val_accuracy': val_metrics['accuracy']
        }, os.path.join(args['exp_dir'], 'latest_model.pth'))
        
        print(f'Train Loss: {train_loss:.4f}')
        print(f'Val Loss: {val_metrics["loss"]:.4f}')
        print(f'Val Accuracy: {val_metrics["accuracy"]:.4f}')

def run_inference(model, val_json_path, label_csv_path, output_csv_path, device=None):
    """Run inference on a validation set and save predictions to CSV.
    
    Args:
        model (nn.Module): The trained model
        val_json_path (str): Path to the validation JSON file containing fbank file paths
        label_csv_path (str): Path to the CSV file containing class labels
        output_csv_path (str): Path where the output CSV file will be saved
        device (torch.device, optional): Device to run inference on. If None, will use CUDA if available.
    
    Returns:
        float: Overall accuracy of the model on the validation set
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load validation data
    with open(val_json_path, 'r') as f:
        val_data = json.load(f)
    
    # Load label mapping
    label_map = {}
    with open(label_csv_path, 'r') as f:
        csv_reader = csv.DictReader(f)
        for row in csv_reader:
            label_map[row['index']] = row['mid']
    num_classes = len(label_map)
    
    # Prepare output CSV
    output_columns = ['file_path'] + [f'prob_class_{i}' for i in range(num_classes)] + ['true_label', 'predicted_label', 'correct']
    output_data = []
    
    # Set model to evaluation mode
    model = model.to(device)
    model.eval()
    
    correct_predictions = 0
    total_predictions = 0
    
    # Process each file
    for item in tqdm(val_data['data'], desc="Running inference"):
        # Load fbank
        fbank = torch.load(item['wav'])
        fbank = fbank.unsqueeze(0).to(device)  # Add batch dimension
        
        # Get true label
        true_label = item['labels'].split(',')[0]  # Assuming single label per file
        
        # Run inference
        with torch.no_grad():
            logits = model(fbank)
            probabilities = torch.softmax(logits, dim=1)
        
        # Get predicted class
        predicted_class = torch.argmax(probabilities, dim=1).item()
        predicted_label = label_map[str(predicted_class)]
        
        # Check if prediction is correct
        is_correct = predicted_label == true_label
        if is_correct:
            correct_predictions += 1
        total_predictions += 1
        
        # Prepare row for CSV
        row = [item['wav']]  # File path
        row.extend(probabilities[0].cpu().numpy())  # Probabilities for each class
        row.extend([true_label, predicted_label, is_correct])
        output_data.append(row)
    
    # Calculate overall accuracy
    accuracy = correct_predictions / total_predictions
    
    # Save results to CSV
    df = pd.DataFrame(output_data, columns=output_columns)
    df.to_csv(output_csv_path, index=False)
    
    print(f"\nInference complete!")
    print(f"Overall accuracy: {accuracy:.4f}")
    print(f"Results saved to: {output_csv_path}")
    
    return accuracy


# Actual

In [4]:
# Configuration dictionary
config_dict = {
    'num_mel_bins': 128,
    'target_length': 512, # {'audioset':1024, 'esc50':512, 'speechcommands':128}
    'loss' : 'CE',
    'mode':'train', 
    'mean':-6.6268077, # ESC -6.6268077
    'std' : 5.358466, # ESC 5.358466
    'fstride' : 10,
    'tstride' : 10,
    'input_fdim' : 128,
    'input_tdim' : 512,
    'imagenet_pretrain' : True,
    'audioset_pretrain' : True,
    'model_size' : 'base384',
    'epochs' : 25,
    'lr' : 1e-5, # audioset pretrain is false, then one order up
    'weight_decay' : 5e-7,
    'betas' : (0.95, 0.999),
    'lrscheduler_start' : 5,
    'lrscheduler_step' : 1,
    'lrscheduler_decay' : 0.85,
    'print_freq' : 0,
    'exp_dir' : "AST/ast-master/egs/esc50/exp/custom_run"
}

# Paths
train_json = "../egs/esc50/data/datafiles_fbank/esc_train_data_1.json"
eval_json = "../egs/esc50/data/datafiles_fbank/esc_eval_data_1.json"
label_csv = "../egs/esc50/data/esc_class_labels_indices.csv"
train_out_csv = "../egs/esc50/exp/custom_run/esc_train_1.csv"
eval_out_csv = "../egs/esc50/exp/custom_run/esc_eval_1.csv"
os.makedirs(config_dict['exp_dir'], exist_ok=True)

In [5]:
# Dataloaders
train_dataset = FbankDataset(train_json, label_csv=label_csv)
eval_dataset = FbankDataset(eval_json, label_csv=label_csv)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=64, shuffle=True, num_workers=8, pin_memory=True
)
eval_loader = torch.utils.data.DataLoader(
    eval_dataset, batch_size=64, shuffle=False, num_workers=8, pin_memory=True
)

In [None]:
outside = 0
for item in train_loader:
    outside = item
    break

outside

In [6]:
model = ASTModel(
    label_dim=len(train_dataset.index_dict),
    fstride=config_dict['fstride'],
    tstride=config_dict['tstride'],
    input_fdim=config_dict['input_fdim'],
    input_tdim=config_dict['input_tdim'],
    imagenet_pretrain=config_dict['imagenet_pretrain'],
    audioset_pretrain=config_dict['audioset_pretrain'],
    model_size=config_dict['model_size']
)

---------------AST Model Summary---------------
ImageNet pretraining: True, AudioSet pretraining: True
frequncey stride=10, time stride=10
number of patches=600


In [None]:
train_model(model, train_loader, eval_loader, config_dict)

Using device: cuda

Epoch 1/25
train loop start
initialized time measures
unrolling loader


In [None]:
run_inference(model, train_json, label_csv, train_out_csv)
run_inference(model, eval_json, label_csv, eval_out_csv) 