In [None]:
# Extract the dataset archive stored on the mounted Google Drive into the project folder.
! unzip "/content/drive/MyDrive/ECE3001_Project/stu_dataset.zip" -d "/content/drive/MyDrive/ECE3001_Project/"

In [None]:
! pip install timm

In [2]:
import torch
import argparse
from tqdm import tqdm
import datetime
import time
from timm.utils import accuracy

import librosa
import numpy as np
import librosa.display
import math
import os
from torch.utils.data import Dataset, DataLoader

import torch.nn as nn
from torchvision.models import vgg11, vgg11_bn, vgg13
from torchvision.models import resnet18

In [3]:
# Loading data
class AudioDataset(Dataset):
    """Dataset yielding one ``(features, label)`` pair per file found in a directory.

    Every entry of ``data_dir`` is treated as an audio file whose name follows the
    pattern ``id<label>_<anything>`` (e.g. ``id1_filename.wav``); the integer after
    the two-character ``id`` prefix is used as the class label.

    Args:
      data_dir: string, directory containing the audio files.
      max_len: integer, number of feature rows each sample is padded/truncated to.
      window_length: integer, analysis window length forwarded to the extractor.
      window_shift: integer, hop size forwarded to the extractor.
      use_stft: bool, forwarded to the extractor to select STFT or raw framing.
    """

    def __init__(self, data_dir, max_len, window_length, window_shift, use_stft):
        self.data_dir = data_dir
        # os.listdir() returns entries in arbitrary, platform-dependent order.
        # Sorting makes the index -> file mapping reproducible across runs.
        self.file_list = sorted(os.listdir(data_dir))
        self.max_len = max_len
        self.window_shift = window_shift
        self.window_length = window_length
        self.use_stft = use_stft

    def __len__(self):
        """Return the number of files in the dataset directory."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load file ``idx`` and return ``(features, label)``.

        Returns:
          features: tensor built from the extracted feature matrix with one extra
            leading dimension added by ``unsqueeze(0)``.
          label: integer tensor of shape ``(1,)`` parsed from the filename.
        """
        name = self.file_list[idx]
        features = extract_hpss_features_sg(
            os.path.join(self.data_dir, name),
            max_length=self.max_len,
            window_length=self.window_length,
            window_shift=self.window_shift,
            use_stft=self.use_stft,
        )
        wav_data = torch.tensor(features).unsqueeze(0)

        # Filename format: id<label>_<rest>; drop the two-character "id" prefix
        # of the part before the first underscore and parse the remainder.
        label = torch.tensor([int(name.split('_')[0][2:])])

        return wav_data, label

def extract_hpss_features_sg(wav_path, max_length, window_length=320, window_shift=160, use_stft=True):
    """Load one audio file and turn it into a fixed-length, normalised feature matrix.

    Note: despite the function name, no harmonic-percussive source separation is
    performed here; the features are either a log-magnitude STFT or raw sample frames.

    Args:
      wav_path: string, path of the audio file to read.
      max_length: integer, number of rows (frames) of the returned matrix; shorter
                  inputs are zero-padded, longer ones truncated.
      window_length: integer, used as both ``n_fft`` and window length of the STFT
                     (only when ``use_stft`` is True).
      window_shift: integer, STFT hop length (only when ``use_stft`` is True).
      use_stft: bool, if True use the log-magnitude spectrogram, otherwise split the
                waveform into consecutive 256-sample frames.

    Returns:
      ndarray with ``max_length`` rows, one row per frame.

    Raises:
      ValueError: if the file decodes to zero samples.
    """
    (audio, sr) = read_audio(wav_path)

    if audio.shape[0] == 0:
        print("File %s is corrupted!" % wav_path)
        raise ValueError("File %s is corrupted!" % wav_path)

    if use_stft:  # compute stft
        # The small constant keeps log() finite for zero-magnitude bins.
        spec = np.log(get_spectrogram(audio, window_length, window_shift) + 1e-8)
    else:  # not use stft
        frame = 256
        # Drop the trailing samples that do not fill a whole frame, then lay the
        # frames out as columns so both branches share the same orientation.
        split_num = math.floor(audio.shape[0] / frame)
        new_audio = np.split(audio[:split_num*frame], split_num)
        spec = np.stack(new_audio, axis=0).T

    # Normalise each row over the frame axis, then switch to one row per frame
    # so that padding/truncation acts along the frame axis.
    spec = norm(spec)
    spec = spec.T
    spec = pad_trunc_seq(spec, max_length)

    return spec

def read_audio(path, target_fs=None):
    """Read an audio file as a mono waveform, optionally resampling it.

    Args:
      path: string, path of the audio file.
      target_fs: integer or None, desired sample rate; when None the file's own
                 sample rate is kept.

    Returns:
      (audio, fs): the waveform as an ndarray and its sample rate.

    Raises:
      Exception: whatever ``librosa.load`` raised for an unreadable file; the
                 offending path is printed before the error is re-raised.
    """
    try:
        audio, fs = librosa.load(path, sr=None)  # fs: sample rate
    except Exception:
        # Report which file failed, then propagate the real error instead of
        # continuing with undefined `audio`/`fs`.
        print(path)
        raise

    if audio.ndim > 1:
        # Multi-channel input: librosa lays audio out as (channels, samples),
        # so averaging over axis 0 mixes the channels down to mono.
        audio = np.mean(audio, axis=0)
    if target_fs is not None and fs != target_fs:
        # Resample the signal to the requested sample rate.
        audio = librosa.resample(audio, orig_sr=fs, target_sr=target_fs)
        fs = target_fs
    return audio, fs

def pad_trunc_seq(x, max_len):
    """Force a sequence to have exactly ``max_len`` entries along its first axis.

    Args:
      x: ndarray, sequence whose first axis is the sequence axis.
      max_len: integer, target length of the first axis.

    Returns:
      ndarray; ``x`` followed by zero rows when it is too short, otherwise the
      first ``max_len`` rows of ``x``.
    """
    trailing_dims = x.shape[1:]
    n_rows = len(x)

    # Long enough already: keep only the leading rows.
    if not n_rows < max_len:
        return x[0:max_len]

    # Too short: append zero rows matching the trailing dimensions.
    filler = np.zeros((max_len - n_rows,) + trailing_dims)
    return np.concatenate((x, filler), axis=0)

def get_spectrogram(wav, win_length, win_shift):
    """Return the magnitude of the Hamming-windowed STFT of a waveform.

    Args:
      wav: ndarray, input waveform.
      win_length: integer, used as both the FFT size and the window length.
      win_shift: integer, hop length between successive frames.

    Returns:
      ndarray, magnitude part of the STFT (the phase is discarded).
    """
    stft_matrix = librosa.stft(
        wav,
        n_fft=win_length,
        hop_length=win_shift,
        win_length=win_length,
        window='hamming',
    )
    magnitude, _ = librosa.magphase(stft_matrix)
    return magnitude


def norm(spec):
    """Standardise every row of a 2-D array to zero mean and unit variance.

    Args:
      spec: ndarray of shape (rows, columns); statistics are taken over axis 1.

    Returns:
      ndarray of the same shape. Rows whose standard deviation is zero (constant
      rows) come out as all zeros instead of NaN/inf.
    """
    mean = np.mean(spec, axis=1, keepdims=True)
    std = np.std(spec, axis=1, keepdims=True)
    # A constant row has std == 0; dividing by it would fill the features with
    # NaN/inf. Its centred values are already zero, so dividing by 1 is safe.
    std[std == 0] = 1
    return (spec - mean) / std


## You can try different models in this part.

In [4]:
# Loading Model
class vgg_base(nn.Module):
    """Convolutional feature extractor of VGG-11 with a configurable input channel count.

    Args:
      input_dim: integer, number of channels of the input tensor.
    """

    def __init__(self, input_dim):
        super().__init__()
        # vgg11() without arguments builds a randomly initialised network on both
        # old and new torchvision releases; the explicit `pretrained=False`
        # keyword is deprecated and triggers a warning on current versions.
        self.vggmodel = vgg11().features
        # Swap the first convolution so the network accepts `input_dim` channels
        # instead of the 3-channel images VGG was designed for.
        self.vggmodel[0] = nn.Conv2d(input_dim, 64, kernel_size=3, padding=1)

    def forward(self, x):
        """Return the convolutional feature map for the input batch ``x``."""
        return self.vggmodel(x)

class vggbn_base(nn.Module):
    """Convolutional feature extractor of batch-normalised VGG-11 with a configurable input channel count.

    Args:
      input_dim: integer, number of channels of the input tensor.
    """

    def __init__(self, input_dim):
        super().__init__()
        # vgg11_bn() without arguments builds a randomly initialised network on
        # both old and new torchvision releases; the explicit `pretrained=False`
        # keyword is deprecated and triggers a warning on current versions.
        self.vggmodel = vgg11_bn().features
        # Swap the first convolution so the network accepts `input_dim` channels
        # instead of the 3-channel images VGG was designed for.
        self.vggmodel[0] = nn.Conv2d(input_dim, 64, kernel_size=3, padding=1)

    def forward(self, x):
        """Return the convolutional feature map for the input batch ``x``."""
        return self.vggmodel(x)


class resnet_base(nn.Module):
    """Full ResNet-18 (including its final fully connected layer) with a configurable input channel count.

    Args:
      input_dim: integer, number of channels of the input tensor.
    """

    def __init__(self, input_dim):
        super().__init__()
        # resnet18() without arguments builds a randomly initialised network on
        # both old and new torchvision releases; the explicit `pretrained=False`
        # keyword is deprecated and triggers a warning on current versions.
        self.resnetmodel = resnet18()
        # Replace the stem convolution so the network accepts `input_dim`
        # channels; the remaining hyper-parameters are kept as written here.
        self.resnetmodel.conv1 = nn.Conv2d(input_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)

    def forward(self, x):
        """Return the output of the whole ResNet-18 for the input batch ``x``."""
        return self.resnetmodel(x)

class My_model(nn.Module):
    """Classifier made of a selectable CNN backbone followed by a linear head.

    Args:
      input_dim: integer, number of channels of the input tensor.
      num_classes: integer, number of output classes.
      model_base: string, one of "vgg", "vggbn" or "resnet".

    Raises:
      ValueError: if ``model_base`` is not one of the supported names.
    """

    def __init__(self, input_dim=1, num_classes=93, model_base="vgg"):
        super().__init__()
        if model_base == "vgg":
            self.backbone = vgg_base(input_dim)
        elif model_base == "vggbn":
            self.backbone = vggbn_base(input_dim)
        elif model_base == "resnet":
            self.backbone = resnet_base(input_dim)
        else:
            # Fail at construction instead of leaving `backbone` undefined and
            # crashing later inside forward().
            raise ValueError(
                "Unknown model_base %r (expected 'vgg', 'vggbn' or 'resnet')" % (model_base,)
            )

        self.model_base = model_base
        # The fixed kernel of 200 collapses the flattened spatial axis only when
        # the backbone emits exactly 200 positions per channel.
        # NOTE(review): presumably matches 800 x 256 inputs (25 x 8 after the five
        # VGG poolings) - confirm before changing max_len or the STFT settings.
        self.avgpool = nn.AvgPool1d(kernel_size=200, stride=1)
        self.linear = nn.Linear(in_features=512, out_features=num_classes)    # head for the VGG backbones
        self.linear2 = nn.Linear(in_features=1000, out_features=num_classes)  # head for the ResNet backbone
        self.activate = nn.Softmax(dim=1)
        self.criteria = nn.CrossEntropyLoss()

    def forward(self, input, label=None):
        """Run the classifier.

        Args:
          input: tensor, batch fed to the backbone.
          label: optional integer tensor of class indices; flattened with
                 ``view(-1)`` before the loss is computed.

        Returns:
          ``(loss, probs, pred_label)`` when ``label`` is given, otherwise
          ``(probs, pred_label)``. ``probs`` are softmax probabilities and
          ``pred_label`` the index of the most probable class.
        """
        features = self.backbone(input)
        if self.model_base in ["vgg", "vggbn"]:
            # (batch, channels, H, W) -> (batch, channels, H*W) -> average -> (batch, channels)
            features = features.view(features.size(0), features.size(1), -1)
            features = self.avgpool(features)
            features = features.reshape(features.size(0), -1)
            logits = self.linear(features)
        elif self.model_base == "resnet":
            logits = self.linear2(features)
        else:
            raise ValueError("Unknown model_base %r" % (self.model_base,))

        probs = self.activate(logits)
        _, pred_label = probs.max(-1)

        if label is not None:  # train
            # CrossEntropyLoss applies log-softmax itself, so it must receive the
            # raw logits; feeding it probabilities squashes the gradients.
            loss = self.criteria(logits, label.view(-1))
            return loss, probs, pred_label
        else:  # test
            return probs, pred_label


In [5]:
# train and valid
def valid(args, model):
    """Evaluate ``model`` on the test set and report top-1 / top-5 accuracy.

    Reads the module-level ``device`` to place tensors.

    Args:
      args: namespace providing ``test_path``, ``max_len``, ``window_length``,
            ``window_shift``, ``use_stft`` and ``batchsize``.
      model: module whose call without a label returns ``(scores, predictions)``.

    Returns:
      (acc1, acc5): top-1 and top-5 accuracy as fractions in [0, 1], averaged
      over all samples.
    """
    print('Predcting...')
    audio_testset = AudioDataset(args.test_path, args.max_len, args.window_length, args.window_shift, args.use_stft)
    test_data = DataLoader(audio_testset, batch_size=args.batchsize, shuffle=False)

    model.eval()
    correct1_total = 0.
    correct5_total = 0.
    sample_total = 0

    with torch.no_grad():
        for x, label in tqdm(test_data):
            x = x.to(dtype=torch.float32, device=device)
            label = label.to(device)
            result, pred = model(x)
            acc1, acc5 = accuracy(result, label.view(-1), topk=(1, 5))
            # accuracy() reports percentages per batch. The last batch can be
            # smaller than the others, so weight each batch by its sample count
            # instead of averaging the per-batch values directly.
            batch_size = label.size(0)
            correct1_total += acc1.item() / 100 * batch_size
            correct5_total += acc5.item() / 100 * batch_size
            sample_total += batch_size

    # Guard against an empty test set (reports 0 accuracy instead of crashing).
    denominator = max(sample_total, 1)
    valid_acc1 = correct1_total / denominator
    valid_acc5 = correct5_total / denominator
    print("Valid_acc1:{}, Valid_acc5: {}".format( valid_acc1, valid_acc5 ))
    return valid_acc1, valid_acc5

def train(args):
    """Train ``My_model`` and validate it after every epoch.

    Reads the module-level ``device`` to place the model and tensors. Progress,
    per-epoch validation accuracy and the best epoch so far are printed; nothing
    is returned and no checkpoint is written.

    Args:
      args: namespace providing ``model_base``, ``lr``, ``epochs``, ``batchsize``,
            ``print_every``, ``train_path`` plus the fields required by
            ``valid`` and ``AudioDataset``.
    """
    model = My_model(num_classes=92, model_base=args.model_base)
    model = model.to(dtype=torch.float32, device=device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    best_epoch = -1
    best_acc1 = 0
    best_acc5 = 0

    audio_trainset = AudioDataset(args.train_path, args.max_len, args.window_length, args.window_shift, args.use_stft)
    print(f"Length of training set: {len(audio_trainset)}")
    # drop_last keeps every training batch the same size, so the running
    # per-batch averages printed below are not skewed by a short final batch.
    train_data = DataLoader(audio_trainset, batch_size=args.batchsize, shuffle=True, drop_last=True)

    for epoch in range(args.epochs):
        start_time = time.time()
        model.train()
        acc1_total = 0.
        acc5_total = 0.
        loss_total = 0.
        for step, (x, label) in enumerate(tqdm(train_data)):
            x = x.to(dtype=torch.float32, device=device)
            label = label.to(device)
            optimizer.zero_grad()
            loss, result, pred = model(x, label)
            acc1, acc5 = accuracy(result, label.view(-1), topk=(1, 5))
            # accuracy() returns percentages; convert to fractions. Any failure
            # here is a real error and is allowed to propagate.
            acc1, acc5 = acc1.item() / 100, acc5.item() / 100
            loss.backward()
            optimizer.step()
            acc1_total += acc1
            acc5_total += acc5
            loss_total += float(loss.item())
            if step % args.print_every == 0 and step != 0:
                print('epoch %d, step %d, step_loss %.4f, step_acc1 %.4f, step_acc5 %.4f' % (epoch, step, loss_total/(step+1), acc1_total/(step+1), acc5_total/(step+1)))

        # NOTE: checkpointing is not implemented; to keep the best weights,
        # torch.save(model.state_dict(), ...) when a new best epoch is found.
        acc1, acc5 = valid(args, model)
        if acc1 > best_acc1:
            best_acc1 = acc1
            best_acc5 = acc5
            best_epoch = epoch
        print('best acc1 is: {}, acc5 is: {}, in epoch {}'.format(best_acc1, best_acc5, best_epoch))

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('Training time {}'.format(total_time_str))


## Model Training

In [22]:
# Change the working directory to the project folder so the relative dataset/model paths below resolve.
%cd /content/drive/MyDrive/ECE3001_Project/

/content/drive/MyDrive/ECE3001_Project


In [23]:
# Display the current working directory.
%pwd

'/content/drive/MyDrive/ECE3001_Project'

In [None]:
# Use the first GPU when one is available and fall back to the CPU otherwise,
# so the notebook also runs on runtimes without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def _str2bool(value):
    """Convert a command-line string to a bool.

    argparse's ``type=bool`` treats every non-empty string (including "False")
    as True, so boolean options need an explicit parser.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser()
# path parameters
parser.add_argument('--train_path', type=str, default='./stu_dataset/train')
parser.add_argument('--test_path', type=str, default='./stu_dataset/test')
parser.add_argument('--save_model_path', type=str, default='./model/vggbn/vggbn11')

# training parameters
parser.add_argument('--epochs', type=int, default=20)
parser.add_argument('--print_every', type=int, default=10)
parser.add_argument('--save_every', type=int, default=1)
parser.add_argument('--batchsize', type=int, default=64)
parser.add_argument('--lr', type=float, default=1e-4,help="learning rate")
parser.add_argument('--model_base', type=str, default="vggbn",help="model base: vgg, resnet, vggbn")

# data processing parameters
parser.add_argument("--max_len", default=800, type=int, help="max_len")
parser.add_argument("--window_shift", default=256, type=int, help="hop shift")
parser.add_argument("--window_length", default=510, type=int, help="window length") # 256
parser.add_argument("--use_stft", default=True, type=_str2bool, help="whether to use stft")

# Parse an empty argument list so the notebook kernel's own command line is
# ignored and every option takes its default; pass e.g.
# args=["--epochs", "5"] to override a value.
args = parser.parse_args(args=[])
train(args)

Length of training set: 4769


  0%|          | 0/74 [00:00<?, ?it/s]