# Experimenting with different networks and datasets


In [None]:
#from torch.utils.data import Dataset
import torchaudio
import torch
import numpy as np

class AudioDataset(torch.utils.data.Dataset):
    def __init__(self, data, label_encoder=None):
        # Initialize attributes
        self.data = data["uuid"]
        self.label = data["status"]
        self.label_encoder = label_encoder
        self.sample_rate = {}
        

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        # Extract audio sample from idx
        audio_path = self.data[idx]

        # Load in audio
        audio_sample, sample_rate = torchaudio.load(audio_path)
        self.sample_rate[idx] = sample_rate
        
        info = torchaudio.info(audio_path)
        print("Audio channels:", info.num_channels)

        # Extract audio label from idx and transform
        audio_label = [self.label[idx]]
        audio_label = self.label_encoder.transform(audio_label)
        
        return audio_sample, torch.tensor(audio_label)

    def __get_sample_rate__(self, idx):
        return self.sample_rate.get(idx)

In [None]:
# from IPython.display import Audio
# from matplotlib.patches import Rectangle
import matplotlib.pyplot as plt
import librosa


# Stolen from pytorch tutorial xd
def plot_waveform(waveform, sr, title="Waveform", ax=None):
    waveform = waveform.numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sr

    if ax is None:
        _, ax = plt.subplots(num_channels, 1)
    ax.plot(time_axis, waveform[0], linewidth=1)
    ax.grid(True)
    ax.set_xlim([0, time_axis[-1]])
    ax.set_title(title)


def plot_spectrogram(specgram, title=None, ylabel="freq_bin", ax=None):
    if ax is None:
        _, ax = plt.subplots(1, 1)
    if title is not None:
        ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.imshow(
        librosa.power_to_db(specgram),
        origin="lower",
        aspect="auto",
        interpolation="nearest",
    )


def plot_fbank(fbank, title=None):
    fig, axs = plt.subplots(1, 1)
    axs.set_title(title or "Filter bank")
    axs.imshow(fbank, aspect="auto")
    axs.set_ylabel("frequency bin")
    axs.set_xlabel("mel bin")

In [None]:
import numpy as np

def pad_sequence(batch):
    # Make all tensor in a batch the same length by padding with zeros
    batch = [item.t() for item in batch]
    batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.0)
    return batch.permute(0, 2, 1)


def collate_fn(batch):
    # A data tuple has the form:
    # waveform, label

    # Separate audio samples and labels
    waveforms, labels = zip(*batch)
    
    # Print the lengths of the waveforms
    #for i, waveform in enumerate(waveforms):
    #    print(f"Waveform {i} length: {len(waveform)}")
    
    # Pad the audio samples
    padded_waveforms = pad_sequence(waveforms)

    # Convert labels to tensor
    labels = torch.tensor(labels)

    return padded_waveforms, labels

In [None]:
import pandas as pd



def undersample(data, n, normalize=False):

    # Step 1: Identify majority class

    class_counts = data["status"].value_counts()

    majority_class = class_counts.idxmax()


    # Step 2: Calculate desired class distribution (e.g., balanced distribution)

    desired_class_count = n  # Target number of samples for each class


    # Step 3: Select subset from majority class

    undersampled_data_majority = data[data["status"] == majority_class].sample(
        n=desired_class_count
    )


    # Combine with samples from minority classes

    undersampled_data_minority = data[~(data["status"] == majority_class)]


    # Combine undersampled majority class with minority classes

    undersampled_data = pd.concat(
        [undersampled_data_majority, undersampled_data_minority]
    )


    # Shuffle the undersampled dataset

    undersampled_data = undersampled_data.sample(frac=1).reset_index(drop=True)

    return undersampled_data

In [None]:
def weighted_sample(data):
    # Find class distribution
    class_counts = data["status"].value_counts()
    # print(class_counts)

    # Check class weights
    class_weights = 1 / class_counts
    # print(class_weights)

    # Adjust weighting to each sample
    sample_weights = [1 / class_counts[i] for i in data["status"].values]
    # print("len sample weights:",len(sample_weights))

    return sample_weights

In [None]:
from sklearn.model_selection import train_test_split

def preprocess_dataset(data, test_size):
    # Extract audio samples and labels
    X = data.drop(columns=["status"])
    y = data["status"]


    # Perform a stratified split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, stratify=y, random_state=42
    )

    # Combine audio samples and target labels for training and validation sets
    train_data = pd.concat([X_train, y_train], axis=1).reset_index(drop=True)
    test_data = pd.concat([X_test, y_test], axis=1).reset_index(drop=True)

    return train_data, test_data

In [None]:
def visualize_dataset(data, normalize, title):
    print(f"{title} Distribution")
    print(data["status"].value_counts(normalize=normalize))
    print("Total samples", len(data))

    plt.figure(figsize=(6, 4))
    plt.title(f"Histogram of Patient Status\n- {title}")
    plt.bar(data["status"].value_counts().index, data["status"].value_counts())
    plt.xticks(rotation=20, ha="right", fontsize=8)
    plt.xlabel("Class", fontsize=8)
    plt.ylabel("Frequency", fontsize=8)
    plt.show()

In [None]:
import os

def preprocess_data(data_path, data_dir_path):
    # Read data file then remove every column other than the specified columns
    # Removes empty samples and filters through cough probability
    data = pd.read_csv(data_path, sep=",")
    
    data = (
        data[["uuid", "cough_detected", "SNR", "age", "gender", "status"]]
        .loc[data["cough_detected"] >= 0.5]
        .dropna().reset_index(drop=True)
    )

    # Check if the following MP3 with uuid exists
    mp3_data = []
    non_exist = []
    for file in data["uuid"]:
        if os.path.exists(os.path.join(data_dir_path, f"{file}.mp3")):
            #print("Exists!")
            mp3_data.append(os.path.join(data_dir_path, f"{file}.mp3"))
        else:
            #print("Does not exist!")
            non_exist.append(file)
        # elif os.path.exists(os.path.join(data_dir_path, f'{file}.ogg')):
        #    ogg_data.append(os.path.join(data_dir_path, f'{file}.ogg'))

    # Remove entries with missing MP3 files from the original data
    data = data[~data["uuid"].isin(non_exist)]

    # Replace the uuids with the path to uuid
    data["uuid"] = mp3_data

    data.to_csv("audio_data.csv", index=False)
    print("Finished!")

# Define data variables
data_meta = "metadata_compiled.csv"
data_dir_path = r"../Dataset/MP3/"

# Preprocess data if you havent already
#preprocess_data(data_meta, data_dir_path)

In [None]:
from torch.utils.data import WeightedRandomSampler
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader
import numpy as np
import os

# Set seed for reproducibility
torch.manual_seed(42)

# Define data variables
data_meta = "metadata_compiled.csv"
data_dir_path = r"../Dataset/MP3/"
data = pd.read_csv("audio_data.csv")


# Initialize LabelEncoder
le = LabelEncoder()

# Fit and transform labels into encoded form
labels = ["healthy", "symptomatic", "COVID-19"]
encoded_labels = le.fit_transform(labels)

# Visualize standard dataset
# visualize_dataset(data, True, "Standard")
#print("DATASET BEFORE PREPROCESS\n", data)

# Prepare standard dataset
train_data, test_data = preprocess_dataset(data, 0.33)
#print("type",type(train_data))
#print("DATA", train_data)
#print("DATA ILOC", train_data.iloc[0]["uuid"])
#data1 = train_data["uuid"]#.to_numpy()
#label1 = train_data["status"]#.to_numpy()
#f = train_data["uuid"]
#l = train_data["status"]

#print(type(f),f)
#print(type(l),l)

#print(f[7563])
#print(l[7563])

# Create undersampled version
undersampled_data = undersample(data, 2000, True)

# Visualize undersampled dataset
# visualize_dataset(undersampled_data, True, "Undersampled")

# Prepare undersampled dataset
train_undersampled_data, test_undersampled_data = preprocess_dataset(
    undersampled_data, 0.33
)

# Preparing weighted dataset
sample_weights = weighted_sample(data)
weighted_Sampler = WeightedRandomSampler(
    weights=sample_weights, num_samples=len(data), replacement=True
)


# Create AudioDataset instances for training and validation sets
"""
We should try training with different datasets such as:
 * Standard
 * Undersampled
 * Weighted

"""
# Standard dataset
train_dataset = AudioDataset(train_data, le)
test_dataset = AudioDataset(test_data, le)

# Undersampled dataset
train_undersampled_dataset = AudioDataset(train_undersampled_data, le)
test_undersampled_dataset = AudioDataset(test_undersampled_data, le)

# Create training and test dataloader instances
batch = 256
workers = 0
pin_memory = True

train_dataloader = DataLoader(train_dataset, batch_size=batch, shuffle=True, num_workers=workers, collate_fn=collate_fn, pin_memory=pin_memory)
test_dataloader = DataLoader(test_dataset, batch_size=batch, shuffle=False, num_workers=workers, collate_fn=collate_fn, pin_memory=pin_memory)

"""
train_undersampled_dataloader = DataLoader(
    train_undersampled_dataset,
    batch_size=batch,
    shuffle=True,
    num_workers=workers,
    collate_fn=collate_fn,
    pin_memory=True,
)
test_undersampled_dataloader = DataLoader(
    test_undersampled_dataset,
    batch_size=batch,
    shuffle=False,
    num_workers=workers,
    collate_fn=collate_fn,
    pin_memory=True,
)
"""

# train_weighted_dataloader = DataLoader(train_dataset, sampler=weighted_Sampler, batch_size=batch, num_workers=workers, collate_fn=collate_fn, pin_memory=True)
# test_weighted_dataloader = DataLoader(test_dataset, sampler=weighted_Sampler, batch_size=batch, num_workers=workers, collate_fn=collate_fn, pin_memory=True)

In [None]:
# Testing stuff
#print(train_data)
#print(test_data)
#print(train_dataset.__getitem__(7563))
#print(train_dataset.__len__())

#train_features, train_labels = next(iter(train_dataloader))
#print(f"Feature batch shape: {train_features.size()}")
#print(f"Labels batch shape: {train_labels.size()}")
"""



import torchaudio.transforms as T




print(t := test.__getitem__(0))



print(t3 := test.__getitem__(1))
print("len t", len(t[0][0]))



print("len t3", len(t3[0][0]))

print(t[0])



print(t[1])




print(t2 := test.__get_sample_rate__(0))




print("decoded label",le.inverse_transform(t[1]))




# Define transform



spectrogram = T.Spectrogram(n_fft=512)




# Perform transform
spec = spectrogram(t[0])



fig, axs = plt.subplots(2, 1)



plot_waveform(t[0], t2, title="Original waveform", ax=axs[0])



plot_spectrogram(spec[0], title="spectrogram", ax=axs[1])



fig.tight_layout()
"""

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# import matplotlib.pyplot as plt
# import IPython.display as ipd

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device running on: {device}")


class M5(nn.Module):
    def __init__(self, n_input=1, n_output=35, stride=16, n_channel=32):
        super().__init__()
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
        self.bn1 = nn.BatchNorm1d(n_channel)
        self.pool1 = nn.MaxPool1d(4)
        self.conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(n_channel)
        self.pool2 = nn.MaxPool1d(4)
        self.conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.bn3 = nn.BatchNorm1d(2 * n_channel)
        self.pool3 = nn.MaxPool1d(4)
        self.conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.bn4 = nn.BatchNorm1d(2 * n_channel)
        self.pool4 = nn.MaxPool1d(4)
        self.fc1 = nn.Linear(2 * n_channel, n_output)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.pool1(x)
        x = self.conv2(x)
        x = F.relu(self.bn2(x))
        x = self.pool2(x)
        x = self.conv3(x)
        x = F.relu(self.bn3(x))
        x = self.pool3(x)
        x = self.conv4(x)
        x = F.relu(self.bn4(x))
        x = self.pool4(x)
        x = F.avg_pool1d(x, x.shape[-1])
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        return F.log_softmax(x, dim=2)


model = M5(n_output=len(labels))
model.to(device)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


n = count_parameters(model)
print("Number of parameters: %s" % n)

In [None]:
from torchaudio.transforms import MFCC

# Settings for MelSpectrogram computation
melkwargs = {
    "n_mels": 80,  # How many mel frequency filters are used
    "n_fft": 480,  # How many fft components are used for each feature
    "win_length": 480,  # How many frames are included in each window
    "hop_length": 160,  # How many frames the window is shifted for each component
    "center": False,  # Whether frams are padded such that the component of timestep t is centered at t
    "f_max": 7600,  # Maximum frequency to consider
    "f_min": 20,
}

# Instantiate MFCC feature extractor
mfcc = MFCC(
    n_mfcc=40,  # Number of cepstrum components
    sample_rate=16000,  # Sample rate of input audio
    melkwargs=melkwargs,
)  # Keyword arguments for MelSpectogram

In [13]:
from tqdm.notebook import tqdm
import torch.optim as optim

# Defining training variables
epochs = 40
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0001)
scheduler = optim.lr_scheduler.StepLR(
    optimizer, step_size=20, gamma=0.1
)  # reduce the learning after 20 epochs by a factor of 10
log_interval = 20
losses = []

for epoch in range(epochs):
    # for inputs, labels in tqdm(
    #    train_undersampled_dataloader,
    #    total=len(train_undersampled_dataloader),
    #    leave=True,
    #    desc=f"Epoch {epoch}/{epochs}",
    # ):
    for inputs, labels in train_dataloader:
        #print(inputs)
        #print(labels)
        i = 0



Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1




Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1




Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1




Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1




Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1




Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
Audio channels: 1
