In [1]:
import pandas as pd
from pathlib import Path

Catholic_file = 'data/data2.csv'
df = pd.read_csv(Catholic_file)
df

Unnamed: 0,filename,category,class
0,0001-2.wav,healthy,0
1,0002-1.wav,healthy,0
2,0002-2.wav,healthy,0
3,0002-3.wav,healthy,0
4,0002-4.wav,healthy,0
...,...,...,...
258,0614-2.wav,wheezing,1
259,0614-3.wav,wheezing,1
260,0614-4.wav,wheezing,1
261,0615-1.wav,wheezing,1


In [2]:
df['relative_path'] = '/' + df['filename'].astype(str)
df = df[['relative_path', 'class']]
df.head()

Unnamed: 0,relative_path,class
0,/0001-2.wav,0
1,/0002-1.wav,0
2,/0002-2.wav,0
3,/0002-3.wav,0
4,/0002-4.wav,0


In [3]:
import math, random
import torch
import torchaudio
from torchaudio import transforms
from IPython.display import Audio

class Breath_sound_Util():
  
  def open(audio_file):
    sig, sr = torchaudio.load(audio_file)
    return (sig, sr)

  def resample(aud, newsr):
    sig, sr = aud

    if (sr == newsr):
      return aud

    num_channels = sig.shape[0]
    resig = torchaudio.transforms.Resample(sr, newsr)(sig[:1,:])
    if (num_channels > 1):
      retwo = torchaudio.transforms.Resample(sr, newsr)(sig[1:,:])
      resig = torch.cat([resig, retwo])

    return ((resig, newsr))

  def pad(aud, max_ms):
    sig, sr = aud
    num_rows, sig_len = sig.shape
    max_len = sr//1000 * max_ms

    if (sig_len > max_len):
      sig = sig[:,:max_len]

    elif (sig_len < max_len):
      pad_begin_len = random.randint(0, max_len - sig_len)
      pad_end_len = max_len - sig_len - pad_begin_len

      pad_begin = torch.zeros((num_rows, pad_begin_len))
      pad_end = torch.zeros((num_rows, pad_end_len))
        #제로패딩

      sig = torch.cat((pad_begin, sig, pad_end), 1)
      
    return (sig, sr)

  def time_shift(aud, shift_limit):
    sig,sr = aud
    _, sig_len = sig.shape
    shift_amt = int(random.random() * shift_limit * sig_len)
    return (sig.roll(shift_amt), sr)

  def spectro_gram(aud, n_mels=64, n_fft=1024, hop_len=None):
    sig,sr = aud
    top_db = 80

    spec = transforms.MelSpectrogram(sr, n_fft=n_fft, hop_length=hop_len, n_mels=n_mels)(sig)

    spec = transforms.AmplitudeToDB(top_db=top_db)(spec)
    return (spec)

  def spectro_augment(spec, max_mask_pct=0.1, n_freq_masks=1, n_time_masks=1):
    _, n_mels, n_steps = spec.shape
    mask_value = spec.mean()
    aug_spec = spec

    freq_mask_param = max_mask_pct * n_mels
    for _ in range(n_freq_masks):
      aug_spec = transforms.FrequencyMasking(freq_mask_param)(aug_spec, mask_value)

    time_mask_param = max_mask_pct * n_steps
    for _ in range(n_time_masks):
      aug_spec = transforms.TimeMasking(time_mask_param)(aug_spec, mask_value)

    return aug_spec

In [4]:
data_path = 'data/all_data'

In [5]:
from torch.utils.data import DataLoader, Dataset, random_split
import torchaudio

class breathDS(Dataset):
    
  def __init__(self, df, data_path):
    self.df = df
    self.data_path = str(data_path)
    self.duration = 4000
    self.sr = 48000
    self.shift_pct = 0.4
            
  def __len__(self):
    return len(self.df)    
    
  def __getitem__(self, idx):
    audio_file = self.data_path + self.df.loc[idx, 'relative_path']
    class_id = self.df.loc[idx, 'class']
    aud = Breath_sound_Util.open(audio_file)
    reaud = Breath_sound_Util.resample(aud, self.sr)
    dur_aud = Breath_sound_Util.pad(reaud, self.duration)
    shift_aud = Breath_sound_Util.time_shift(dur_aud, self.shift_pct)
    sgram = Breath_sound_Util.spectro_gram(shift_aud, n_mels=64, n_fft=1024, hop_len=None)
    aug_sgram = Breath_sound_Util.spectro_augment(sgram, max_mask_pct=0.1, n_freq_masks=2, n_time_masks=2)

    return aug_sgram, class_id

In [6]:
from torch.utils.data import random_split

brds = breathDS(df, data_path)

In [7]:
num_items = len(brds)
num_train = round(num_items * 0.8)
num_val = num_items - num_train
train_ds, val_ds = random_split(brds, [num_train, num_val])

# Create training and validation data loaders
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=16, shuffle=True)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=16, shuffle=True)

In [8]:
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

# --------------------------------------------------------------
# 호흡음의 Healthy, Wheezing을 판단하는 Binary Classification Model
# --------------------------------------------------------------

class WhoWheezing(nn.Module):

    def __init__(self):
        super().__init__()
        conv_layers = []
        self.conv1 = nn.Conv2d(1, 8, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
        self.relu1 = nn.ReLU()
        self.bn1 = nn.BatchNorm2d(8)
        init.kaiming_normal_(self.conv1.weight, a=0.1)
        self.conv1.bias.data.zero_()
        conv_layers += [self.conv1, self.relu1, self.bn1]

        self.conv2 = nn.Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.relu2 = nn.ReLU()
        self.bn2 = nn.BatchNorm2d(16)
        init.kaiming_normal_(self.conv2.weight, a=0.1)
        self.conv2.bias.data.zero_()
        conv_layers += [self.conv2, self.relu2, self.bn2]

        self.conv3 = nn.Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.relu3 = nn.ReLU()
        self.bn3 = nn.BatchNorm2d(32)
        init.kaiming_normal_(self.conv3.weight, a=0.1)
        self.conv3.bias.data.zero_()
        conv_layers += [self.conv3, self.relu3, self.bn3]

        self.conv4 = nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        self.relu4 = nn.ReLU()
        self.bn4 = nn.BatchNorm2d(64)
        init.kaiming_normal_(self.conv4.weight, a=0.1)
        self.conv4.bias.data.zero_()
        conv_layers += [self.conv4, self.relu4, self.bn4]

        self.ap = nn.AdaptiveAvgPool2d(output_size=1)
        self.lin = nn.Linear(in_features=64, out_features=2)
        self.sigmoid = nn.Sigmoid()

        self.conv = nn.Sequential(*conv_layers)
 
    def forward(self, x):
        x = self.conv(x)
        x = self.ap(x)
        x = x.view(x.shape[0], -1)
        x = self.lin(x)
        x = self.sigmoid(x)
        return x

Model1 = WhoWheezing()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Model1 = Model1.to(device)
next(Model1.parameters()).device

device(type='cuda', index=0)

In [9]:
def training(model, train_dl, num_epochs):
  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
  scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001,
                                                steps_per_epoch=int(len(train_dl)),
                                                epochs=num_epochs,
                                                anneal_strategy='linear')

  for epoch in range(num_epochs):
    running_loss = 0.0
    correct_prediction = 0
    total_prediction = 0

    for i, data in enumerate(train_dl):
        inputs, labels = data[0].to(device), data[1].to(device)
        inputs_m, inputs_s = inputs.mean(), inputs.std()
        inputs = (inputs - inputs_m) / inputs_s
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        running_loss += loss.item()
        _, prediction = torch.max(outputs,1)
        correct_prediction += (prediction == labels).sum().item()
        total_prediction += prediction.shape[0]

    num_batches = len(train_dl)
    avg_loss = running_loss / num_batches
    acc = correct_prediction/total_prediction
    print(f'Epoch: {epoch}, Loss: {avg_loss:.4f}, Accuracy: {acc:.4f}')

  print('Finished Training')
  
num_epochs=100
training(Model1, train_dl, num_epochs)

Epoch: 0, Loss: 0.7118, Accuracy: 0.3238
Epoch: 1, Loss: 0.6876, Accuracy: 0.5095
Epoch: 2, Loss: 0.6582, Accuracy: 0.7619
Epoch: 3, Loss: 0.6267, Accuracy: 0.7714
Epoch: 4, Loss: 0.6012, Accuracy: 0.8095
Epoch: 5, Loss: 0.5805, Accuracy: 0.8000
Epoch: 6, Loss: 0.5408, Accuracy: 0.8095
Epoch: 7, Loss: 0.5225, Accuracy: 0.8381
Epoch: 8, Loss: 0.5271, Accuracy: 0.8333
Epoch: 9, Loss: 0.5176, Accuracy: 0.8381
Epoch: 10, Loss: 0.5068, Accuracy: 0.8429
Epoch: 11, Loss: 0.4863, Accuracy: 0.8333
Epoch: 12, Loss: 0.4665, Accuracy: 0.8667
Epoch: 13, Loss: 0.4926, Accuracy: 0.8524
Epoch: 14, Loss: 0.4892, Accuracy: 0.8619
Epoch: 15, Loss: 0.4785, Accuracy: 0.8619
Epoch: 16, Loss: 0.4778, Accuracy: 0.8667
Epoch: 17, Loss: 0.4776, Accuracy: 0.8667
Epoch: 18, Loss: 0.4715, Accuracy: 0.8762
Epoch: 19, Loss: 0.4552, Accuracy: 0.8571
Epoch: 20, Loss: 0.4673, Accuracy: 0.8857
Epoch: 21, Loss: 0.4438, Accuracy: 0.8952
Epoch: 22, Loss: 0.4662, Accuracy: 0.8857
Epoch: 23, Loss: 0.4537, Accuracy: 0.8905
Ep

In [11]:
import torchsummary
torchsummary.summary(Model1, (1,64,344))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 8, 32, 172]             208
            Conv2d-2           [-1, 8, 32, 172]             208
              ReLU-3           [-1, 8, 32, 172]               0
              ReLU-4           [-1, 8, 32, 172]               0
       BatchNorm2d-5           [-1, 8, 32, 172]              16
       BatchNorm2d-6           [-1, 8, 32, 172]              16
            Conv2d-7           [-1, 16, 16, 86]           1,168
            Conv2d-8           [-1, 16, 16, 86]           1,168
              ReLU-9           [-1, 16, 16, 86]               0
             ReLU-10           [-1, 16, 16, 86]               0
      BatchNorm2d-11           [-1, 16, 16, 86]              32
      BatchNorm2d-12           [-1, 16, 16, 86]              32
           Conv2d-13            [-1, 32, 8, 43]           4,640
           Conv2d-14            [-1, 32

In [13]:
import numpy as np
for epoch in range(2):
    for i, data in enumerate(train_dl, 0):
        # get the inputs
        inputs, labels = data
        inputs = np.array(inputs)
        print(inputs.shape)
        # Run your training process
        print(f'Epoch: {i} | Inputs {inputs} | Labels {labels}')

(16, 1, 64, 376)
Epoch: 0 | Inputs [[[[-22.842403   -20.391092   -21.075161   ... -28.087852
    -21.38609    -20.669758  ]
   [-27.302248   -23.30309    -24.678      ... -34.579636
    -25.50793    -26.908415  ]
   [-67.96166    -67.96166    -67.96166    ... -67.96166
    -67.96166    -67.96166   ]
   ...
   [-88.67812    -88.67812    -88.67812    ... -88.67812
    -88.67812    -88.67812   ]
   [-88.67812    -88.67812    -88.67812    ... -88.67812
    -88.67812    -88.67812   ]
   [-88.67812    -88.67812    -88.67812    ... -88.67812
    -88.67812    -88.67812   ]]]


 [[[ -4.187351    -5.0021696  -17.785624   ...   4.288905
      5.3620887    6.175113  ]
   [ -2.03666     -5.7268186  -14.329409   ...  -0.39140773
      0.9416094   -0.4608304 ]
   [ -4.337493   -10.482044   -18.242296   ...  -8.574345
     -6.327701    -8.739846  ]
   ...
   [-71.1058     -71.1058     -71.1058     ... -71.1058
    -71.1058     -71.1058    ]
   [-71.1058     -71.1058     -71.1058     ... -71.1058
    -

(16, 1, 64, 376)
Epoch: 3 | Inputs [[[[ 4.52024555e+00 -1.02915201e+01 -4.21748543e+01 ... -3.75432014e+00
     9.35477138e-01  3.46088529e+00]
   [ 7.57430744e+00 -5.63741255e+00 -4.21748543e+01 ... -4.24294853e+00
     3.21968174e+00  4.48832083e+00]
   [ 1.23587074e+01  5.84362268e-01 -4.21748543e+01 ...  6.55700636e+00
     1.34221077e+01  2.76579171e-01]
   ...
   [-4.70225220e+01 -4.79389381e+01 -4.21748543e+01 ... -4.58327866e+01
    -4.57080383e+01 -4.47884064e+01]
   [-4.66509094e+01 -4.55646133e+01 -4.21748543e+01 ... -4.65633202e+01
    -4.44066620e+01 -4.49335251e+01]
   [-4.90806999e+01 -4.66467667e+01 -4.21748543e+01 ... -4.70672035e+01
    -4.60663834e+01 -4.75036774e+01]]]


 [[[-3.75415063e+00 -3.30915189e+00 -1.25610752e+01 ... -3.91175127e+00
    -4.65317583e+00  2.64944172e+00]
   [-8.62617940e-02  6.14210701e+00 -3.78817821e+00 ... -2.10072398e+00
    -6.64621305e+00  4.74248886e+00]
   [ 1.19148445e+01  6.90902472e+00 -7.64720058e+00 ... -3.29317856e+00
    -6.279

(16, 1, 64, 376)
Epoch: 7 | Inputs [[[[ 17.241188     5.113799     5.445076   ...  -2.5073507
     11.434198   -24.617493  ]
   [ 16.014076     6.8708973    8.751032   ...   2.9325542
      9.975414     5.6283855 ]
   [ 18.253445     7.4623017    8.119885   ...   1.370741
      5.1806793   16.569769  ]
   ...
   [-44.185852   -44.185852   -44.185852   ... -44.185852
    -44.185852   -44.185852  ]
   [-44.185852   -44.185852   -44.185852   ... -44.185852
    -44.185852   -44.116142  ]
   [-44.185852   -44.185852   -44.185852   ... -44.185852
    -44.185852   -44.185852  ]]]


 [[[  0.8589483   -4.2665095  -11.025791   ... -10.5673065
    -16.22908     -2.3276522 ]
   [ -5.085207    -7.708846   -15.653107   ... -14.677743
    -13.11081     -5.05615   ]
   [-15.616064   -14.301758   -23.019732   ... -21.712654
    -16.680016   -10.7437935 ]
   ...
   [-74.41989    -74.41989    -74.41989    ... -74.41989
    -74.41989    -74.41989   ]
   [-74.41989    -74.41989    -74.41989    ... -74.4198

(16, 1, 64, 376)
Epoch: 10 | Inputs [[[[ -4.8854294    4.318926     5.4134507  ... -11.641736
    -10.67272    -16.461193  ]
   [ -5.1977754    0.33280516   0.13900301 ... -15.603782
    -16.87581    -11.63382   ]
   [ -9.647349    -6.9906263   -9.0699835  ... -22.436865
    -27.39684    -14.242403  ]
   ...
   [-72.59007    -72.59007    -72.59007    ... -72.59007
    -72.59007    -72.59007   ]
   [-72.59007    -72.59007    -72.59007    ... -72.59007
    -72.59007    -72.59007   ]
   [-72.59007    -72.59007    -72.59007    ... -72.59007
    -72.59007    -72.59007   ]]]


 [[[-65.167816   -65.167816   -65.167816   ... -65.167816
    -65.167816   -65.167816  ]
   [-65.167816   -65.167816   -65.167816   ... -65.167816
    -65.167816   -65.167816  ]
   [-65.167816   -65.167816   -65.167816   ... -65.167816
    -65.167816   -65.167816  ]
   ...
   [-65.167816   -65.167816   -65.167816   ... -65.167816
    -65.167816   -65.167816  ]
   [-65.167816   -65.167816   -65.167816   ... -65.167816
 

(16, 1, 64, 376)
Epoch: 0 | Inputs [[[[ 4.65137672e+00  4.55624151e+00  7.04891253e+00 ...  1.03540254e+00
    -2.12575316e+00 -6.09030771e+00]
   [-3.59486580e-01  4.62366867e+00  7.97209835e+00 ...  2.83481693e+00
     1.57897687e+00  1.40355647e+00]
   [ 1.50997519e+00  1.45344448e+00  5.66577625e+00 ... -2.10775197e-01
    -1.25138497e+00  6.35372686e+00]
   ...
   [-4.32866440e+01 -4.32866440e+01 -4.32866440e+01 ... -4.32866440e+01
    -4.32866440e+01 -4.32866440e+01]
   [-4.32866440e+01 -4.32866440e+01 -4.32866440e+01 ... -4.32866440e+01
    -4.32866440e+01 -4.32866440e+01]
   [-4.32866440e+01 -4.32866440e+01 -4.32866440e+01 ... -4.32866440e+01
    -4.32866440e+01 -4.32866440e+01]]]


 [[[-1.02376413e+01 -5.54471245e+01 -4.06557846e+01 ... -3.38488350e+01
    -3.78283958e+01 -9.75191212e+00]
   [-9.51476479e+00 -5.63541946e+01 -4.32769165e+01 ... -4.03948975e+01
    -4.26258659e+01 -9.18759155e+00]
   [-8.89227009e+00 -4.74700356e+01 -4.85845909e+01 ... -4.28821259e+01
    -4.737

(16, 1, 64, 376)
Epoch: 3 | Inputs [[[[ -8.721686    -2.1792965   -9.697014   ...  -9.962258
    -11.427444    -2.826513  ]
   [ -4.0413218    0.29051268  -8.1268635  ...  -4.837557
     -1.2483759   -1.2384652 ]
   [ -8.00616     -2.9141722   -4.4488435  ...  -2.681712
      0.16449401   4.7532196 ]
   ...
   [-47.995483   -47.10533    -46.94433    ... -51.158005
    -47.840496   -50.064255  ]
   [-47.332478   -46.721794   -47.78675    ... -47.88826
    -46.69391    -49.637688  ]
   [-47.343544   -46.845913   -47.384094   ... -48.645397
    -49.529034   -48.74746   ]]]


 [[[ -0.29238915  -7.4155045  -11.629497   ...   2.6014261
     -3.6294124  -10.574734  ]
   [ -4.9017086   -8.634675    -1.7528625  ...   5.50476
      1.3984599   -1.8958806 ]
   [-13.127709    -4.288104    -3.147245   ...   4.9548664
      0.3539034   -2.244135  ]
   ...
   [-47.543358   -47.543358   -47.543358   ... -47.543358
    -47.543358   -47.543358  ]
   [-47.543358   -47.543358   -47.543358   ... -47.543358

(16, 1, 64, 376)
Epoch: 6 | Inputs [[[[-15.603048   -11.184987    -5.2167015  ... -15.952293
    -15.818289   -24.669607  ]
   [-21.547646   -16.674326   -11.533085   ... -22.179556
    -20.48129    -20.655365  ]
   [-23.498362   -24.593992   -24.695107   ... -33.839764
    -26.777895   -19.727314  ]
   ...
   [-61.349483   -61.349483   -61.349483   ... -61.349483
    -61.349483   -61.349483  ]
   [-61.349483   -61.349483   -61.349483   ... -61.349483
    -61.349483   -61.349483  ]
   [-75.09069    -75.09069    -75.09069    ... -75.09069
    -75.09069    -75.09069   ]]]


 [[[-21.197086   -17.62812    -22.035202   ... -19.97218
    -10.743894   -28.780767  ]
   [-20.997551   -16.827013   -19.473408   ... -18.420738
     -6.7079835  -22.91037   ]
   [-17.11118    -18.467598   -11.800835   ... -23.074438
     -5.672081   -26.75328   ]
   ...
   [-46.867027   -46.781178   -46.867027   ... -46.867027
    -46.867027   -46.867027  ]
   [-46.646824   -46.867027   -46.42015    ... -46.617313
 

(16, 1, 64, 376)
Epoch: 9 | Inputs [[[[-7.70440626e+00 -1.16059818e+01 -7.40668821e+00 ... -1.15323467e+01
    -1.27601585e+01 -8.24060822e+00]
   [-1.13214121e+01 -1.67288666e+01 -1.25077209e+01 ... -1.36450968e+01
    -1.70421066e+01 -1.15629234e+01]
   [-1.82498055e+01 -2.26231422e+01 -2.14645596e+01 ... -1.89817333e+01
    -2.35183945e+01 -1.82273903e+01]
   ...
   [-7.52490311e+01 -7.52490311e+01 -7.52490311e+01 ... -7.52490311e+01
    -7.52490311e+01 -7.52490311e+01]
   [-7.52490311e+01 -7.52490311e+01 -7.52490311e+01 ... -7.52490311e+01
    -7.52490311e+01 -7.52490311e+01]
   [-7.52490311e+01 -7.52490311e+01 -7.52490311e+01 ... -7.52490311e+01
    -7.52490311e+01 -7.52490311e+01]]]


 [[[-1.25210943e+01  9.68135357e-01 -3.98150730e+00 ...  4.79858589e+00
    -3.21765274e-01 -7.07931280e+00]
   [-2.19437814e+00  1.38263464e+00 -2.22992969e+00 ...  7.38394547e+00
     1.18955266e+00 -5.31850338e+00]
   [-2.89238214e+00 -1.47531557e+00 -2.27314281e+00 ...  5.69078827e+00
    -1.683

(16, 1, 64, 376)
Epoch: 12 | Inputs [[[[-14.179516   -22.149204   -33.97691    ... -11.389906
    -13.4621525  -28.415981  ]
   [-20.793459   -27.426874   -27.610672   ... -14.371863
    -16.742645   -19.794077  ]
   [-24.727398   -28.589901   -28.484596   ... -20.740883
    -22.299515   -21.074158  ]
   ...
   [-75.3171     -75.3171     -75.3171     ... -75.3171
    -75.3171     -75.3171    ]
   [-75.3171     -75.3171     -75.3171     ... -75.3171
    -75.3171     -75.3171    ]
   [-75.3171     -75.3171     -75.3171     ... -75.3171
    -75.3171     -75.3171    ]]]


 [[[-40.47453    -36.385464   -42.59546    ... -42.05215
    -36.510353   -31.957361  ]
   [-44.24387    -36.028427   -39.023003   ... -37.50746
    -33.61444    -32.741085  ]
   [-44.34382    -35.91642    -36.938354   ... -36.88084
    -34.537754   -35.63559   ]
   ...
   [-72.59413    -72.59413    -72.59413    ... -72.59413
    -72.59413    -72.59413   ]
   [-72.59413    -72.59413    -72.59413    ... -72.59413
    -72.5