## Import Libraries

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
import torchaudio
from tqdm import tqdm

In [2]:
# HyperParameters
dataset_path = "/content/drive/MyDrive/Dataset/Audio_Dataset_Chunk"
batch_size = 32
learning_rate = 0.001
epochs = 60

## Model

In [3]:
# Model
class Model(nn.Module):
    def __init__(self, n_input=1, n_output=8, stride=4, n_channel=32):
        super().__init__()
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(4)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool2 = nn.MaxPool1d(4)
        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(2 * n_channel)
        self.pool3 = nn.MaxPool1d(4)
        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.pool4 = nn.MaxPool1d(4)
        self.fc1 = nn.Linear(2 * n_channel, n_output)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = self.pool2(x)
        x = self.conv3(x)
        x = F.relu(self.bn3(x))
        x = self.pool3(x)
        x = self.conv4(x)
        x = F.relu(self.bn4(x))
        x = self.pool4(x)
        x = F.avg_pool1d(x, x.shape[-1])
        x = torch.flatten(x, start_dim=1)
        x = self.fc1(x)
        x = F.softmax(x, dim=1)
        return x

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model().to(device)

## Load Dataset

In [5]:
# DATASET
class AudioDataset(Dataset):
  def __init__(self,root):
    self.dir_path = root
    self.classes = os.listdir(self.dir_path)
    
    self.data_paths=[]
    self.labels=[]
    for names in os.listdir(self.dir_path):
      folder = os.listdir(os.path.join(self.dir_path,names))
      for file in folder:
        self.data_paths.append(os.path.join(self.dir_path,names,file))
        self.labels.append(self.classes.index(names))   

  def __len__(self):
    return len(self.labels)

  def __getitem__(self,index):
    data_path = self.data_paths[index]
    label = self.labels[index]
    signal, sample_rate = torchaudio.load(data_path)
    signal = torch.mean(signal, dim=0, keepdim=True)
    new_sample_rate = 8000
    transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=new_sample_rate)
    signal = transform(signal)

    return signal, label

In [6]:
dataset = AudioDataset(dataset_path)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [7]:
def accuracy(preds,labels):
    _,pred_max = torch.max(preds,1)
    acc = torch.sum(pred_max==labels) / len(preds)
    return acc

In [8]:
optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)
loss_function = nn.CrossEntropyLoss()

## Training

In [9]:
# TRAIN
for epoch in range(epochs):
  train_acc = 0.0
  train_loss = 0.0
  model.train()
  for audios, labels in tqdm(train):
    labels = labels.to(device)
    audios = audios.to(device)
    labels_one_hot = F.one_hot(labels,num_classes=8).type(torch.FloatTensor).to(device)
    optimizer.zero_grad()
    pred = model(audios)
    loss = loss_function(pred,labels_one_hot)
    loss.backward()
    optimizer.step()

    train_loss += loss
    train_acc += accuracy(pred,labels)

  total_train_loss = train_loss / len(train)
  total_train_acc = train_acc / len(train)
  print(f"Epochs: {epoch+1}, Accuracy: {total_train_acc}, Loss: {total_train_loss}")

100%|██████████| 22/22 [03:13<00:00,  8.80s/it]


Epochs: 1, Accuracy: 0.4488636255264282, Loss: 1.9425605535507202


100%|██████████| 22/22 [00:04<00:00,  4.40it/s]


Epochs: 2, Accuracy: 0.6051136255264282, Loss: 1.8211276531219482


100%|██████████| 22/22 [00:05<00:00,  4.27it/s]


Epochs: 3, Accuracy: 0.7272727489471436, Loss: 1.7183789014816284


100%|██████████| 22/22 [00:05<00:00,  3.71it/s]


Epochs: 4, Accuracy: 0.8338068127632141, Loss: 1.6012557744979858


100%|██████████| 22/22 [00:04<00:00,  4.44it/s]


Epochs: 5, Accuracy: 0.859375, Loss: 1.5200703144073486


100%|██████████| 22/22 [00:05<00:00,  4.30it/s]


Epochs: 6, Accuracy: 0.8693181872367859, Loss: 1.483066439628601


100%|██████████| 22/22 [00:05<00:00,  4.18it/s]


Epochs: 7, Accuracy: 0.9147727489471436, Loss: 1.4438719749450684


100%|██████████| 22/22 [00:04<00:00,  4.50it/s]


Epochs: 8, Accuracy: 0.9431818127632141, Loss: 1.4030976295471191


100%|██████████| 22/22 [00:04<00:00,  4.40it/s]


Epochs: 9, Accuracy: 0.9431818127632141, Loss: 1.3891222476959229


100%|██████████| 22/22 [00:04<00:00,  4.46it/s]


Epochs: 10, Accuracy: 0.9517045617103577, Loss: 1.3771122694015503


100%|██████████| 22/22 [00:04<00:00,  4.41it/s]


Epochs: 11, Accuracy: 0.953125, Loss: 1.3618944883346558


100%|██████████| 22/22 [00:04<00:00,  4.44it/s]


Epochs: 12, Accuracy: 0.96875, Loss: 1.352372646331787


100%|██████████| 22/22 [00:04<00:00,  4.45it/s]


Epochs: 13, Accuracy: 0.9602272510528564, Loss: 1.3575395345687866


100%|██████████| 22/22 [00:04<00:00,  4.48it/s]


Epochs: 14, Accuracy: 0.9616477489471436, Loss: 1.3451992273330688


100%|██████████| 22/22 [00:04<00:00,  4.45it/s]


Epochs: 15, Accuracy: 0.9644886255264282, Loss: 1.3391841650009155


100%|██████████| 22/22 [00:04<00:00,  4.44it/s]


Epochs: 16, Accuracy: 0.9630681872367859, Loss: 1.3450500965118408


100%|██████████| 22/22 [00:04<00:00,  4.46it/s]


Epochs: 17, Accuracy: 0.9701704382896423, Loss: 1.3387482166290283


100%|██████████| 22/22 [00:04<00:00,  4.44it/s]


Epochs: 18, Accuracy: 0.9744318127632141, Loss: 1.3305646181106567


100%|██████████| 22/22 [00:05<00:00,  4.30it/s]


Epochs: 19, Accuracy: 0.9786931872367859, Loss: 1.3220452070236206


100%|██████████| 22/22 [00:04<00:00,  4.43it/s]


Epochs: 20, Accuracy: 0.9744318127632141, Loss: 1.321245551109314


100%|██████████| 22/22 [00:04<00:00,  4.49it/s]


Epochs: 21, Accuracy: 0.9772727489471436, Loss: 1.3201714754104614


100%|██████████| 22/22 [00:04<00:00,  4.42it/s]


Epochs: 22, Accuracy: 0.9801136255264282, Loss: 1.3174723386764526


100%|██████████| 22/22 [00:05<00:00,  4.36it/s]


Epochs: 23, Accuracy: 0.9786931872367859, Loss: 1.3144663572311401


100%|██████████| 22/22 [00:04<00:00,  4.40it/s]


Epochs: 24, Accuracy: 0.9772727489471436, Loss: 1.315994143486023


100%|██████████| 22/22 [00:04<00:00,  4.44it/s]


Epochs: 25, Accuracy: 0.9829545617103577, Loss: 1.3095027208328247


100%|██████████| 22/22 [00:04<00:00,  4.44it/s]


Epochs: 26, Accuracy: 0.9872159361839294, Loss: 1.3062583208084106


100%|██████████| 22/22 [00:05<00:00,  4.37it/s]


Epochs: 27, Accuracy: 0.9801136255264282, Loss: 1.3088173866271973


100%|██████████| 22/22 [00:04<00:00,  4.40it/s]


Epochs: 28, Accuracy: 0.9758522510528564, Loss: 1.3136996030807495


100%|██████████| 22/22 [00:05<00:00,  4.29it/s]


Epochs: 29, Accuracy: 0.9857954382896423, Loss: 1.306183934211731


100%|██████████| 22/22 [00:05<00:00,  3.80it/s]


Epochs: 30, Accuracy: 0.9815340638160706, Loss: 1.3101680278778076


100%|██████████| 22/22 [00:05<00:00,  4.21it/s]


Epochs: 31, Accuracy: 0.9872159361839294, Loss: 1.2999986410140991


100%|██████████| 22/22 [00:04<00:00,  4.42it/s]


Epochs: 32, Accuracy: 0.9872159361839294, Loss: 1.301103949546814


100%|██████████| 22/22 [00:04<00:00,  4.45it/s]


Epochs: 33, Accuracy: 0.9801136255264282, Loss: 1.3135279417037964


100%|██████████| 22/22 [00:04<00:00,  4.43it/s]


Epochs: 34, Accuracy: 0.9829545617103577, Loss: 1.3123704195022583


100%|██████████| 22/22 [00:05<00:00,  4.39it/s]


Epochs: 35, Accuracy: 0.984375, Loss: 1.310907244682312


100%|██████████| 22/22 [00:04<00:00,  4.43it/s]


Epochs: 36, Accuracy: 0.9900568127632141, Loss: 1.2969741821289062


100%|██████████| 22/22 [00:05<00:00,  4.34it/s]


Epochs: 37, Accuracy: 0.9914772510528564, Loss: 1.2930245399475098


100%|██████████| 22/22 [00:05<00:00,  4.37it/s]


Epochs: 38, Accuracy: 0.9943181872367859, Loss: 1.2922413349151611


100%|██████████| 22/22 [00:05<00:00,  4.31it/s]


Epochs: 39, Accuracy: 0.9914772510528564, Loss: 1.296143651008606


100%|██████████| 22/22 [00:05<00:00,  4.36it/s]


Epochs: 40, Accuracy: 0.9928977489471436, Loss: 1.28983473777771


100%|██████████| 22/22 [00:05<00:00,  4.40it/s]


Epochs: 41, Accuracy: 0.9900568127632141, Loss: 1.2906262874603271


100%|██████████| 22/22 [00:05<00:00,  4.34it/s]


Epochs: 42, Accuracy: 0.9857954382896423, Loss: 1.2953180074691772


100%|██████████| 22/22 [00:04<00:00,  4.42it/s]


Epochs: 43, Accuracy: 0.9872159361839294, Loss: 1.2965141534805298


100%|██████████| 22/22 [00:05<00:00,  4.33it/s]


Epochs: 44, Accuracy: 0.9900568127632141, Loss: 1.2934372425079346


100%|██████████| 22/22 [00:05<00:00,  4.37it/s]


Epochs: 45, Accuracy: 0.984375, Loss: 1.3044135570526123


100%|██████████| 22/22 [00:04<00:00,  4.41it/s]


Epochs: 46, Accuracy: 0.9914772510528564, Loss: 1.29478120803833


100%|██████████| 22/22 [00:05<00:00,  4.34it/s]


Epochs: 47, Accuracy: 0.9914772510528564, Loss: 1.2913373708724976


100%|██████████| 22/22 [00:05<00:00,  4.32it/s]


Epochs: 48, Accuracy: 0.9914772510528564, Loss: 1.2876139879226685


100%|██████████| 22/22 [00:05<00:00,  4.34it/s]


Epochs: 49, Accuracy: 0.9943181872367859, Loss: 1.2868949174880981


100%|██████████| 22/22 [00:05<00:00,  4.32it/s]


Epochs: 50, Accuracy: 0.9943181872367859, Loss: 1.2878235578536987


100%|██████████| 22/22 [00:05<00:00,  4.39it/s]


Epochs: 51, Accuracy: 0.9857954382896423, Loss: 1.295627474784851


100%|██████████| 22/22 [00:04<00:00,  4.43it/s]


Epochs: 52, Accuracy: 0.9943181872367859, Loss: 1.2920591831207275


100%|██████████| 22/22 [00:04<00:00,  4.41it/s]


Epochs: 53, Accuracy: 0.9886363744735718, Loss: 1.294432282447815


100%|██████████| 22/22 [00:05<00:00,  4.39it/s]


Epochs: 54, Accuracy: 0.9886363744735718, Loss: 1.2898032665252686


100%|██████████| 22/22 [00:04<00:00,  4.52it/s]


Epochs: 55, Accuracy: 0.9928977489471436, Loss: 1.286945104598999


100%|██████████| 22/22 [00:05<00:00,  4.20it/s]


Epochs: 56, Accuracy: 0.9900568127632141, Loss: 1.288092851638794


100%|██████████| 22/22 [00:07<00:00,  2.96it/s]


Epochs: 57, Accuracy: 0.9928977489471436, Loss: 1.2860203981399536


100%|██████████| 22/22 [00:06<00:00,  3.54it/s]


Epochs: 58, Accuracy: 0.9914772510528564, Loss: 1.2864336967468262


100%|██████████| 22/22 [00:05<00:00,  4.31it/s]


Epochs: 59, Accuracy: 0.9943181872367859, Loss: 1.2838088274002075


100%|██████████| 22/22 [00:05<00:00,  4.30it/s]

Epochs: 60, Accuracy: 0.9943181872367859, Loss: 1.2843093872070312





## Evaluation

In [10]:
model.eval()

test_acc = 0.0
test_loss = 0.0
for audios, labels in tqdm(test):
    audios = audios.to(device)
    labels = labels.to(device)
    labels_one_hot = F.one_hot(labels,num_classes=8).type(torch.FloatTensor).to(device)

    pred = model(audios)
    loss = loss_function(pred,labels_one_hot)
    test_loss += loss
    test_acc += accuracy(pred, labels)

total_test_loss = test_loss / len(test)
total_test_acc = test_acc / len(test)
print(f"Accuracy: {total_test_acc}, Loss: {total_test_loss}")

100%|██████████| 6/6 [00:28<00:00,  4.71s/it]

Accuracy: 0.9635416865348816, Loss: 1.3298324346542358





In [None]:
torch.save(model.state_dict(), "weights.pth")