# Importing necessary libraries

In [None]:
# Torch-related imports
from torch.utils.data import WeightedRandomSampler, DataLoader
from torch.optim.lr_scheduler import StepLR
from torchvision.transforms import Resize
from torchaudio.transforms import MFCC
from torch.cuda.amp import GradScaler
from torchvision import models
from torchinfo import summary
import torch.optim.lr_scheduler as lr_scheduler
import torch.nn.functional as F
import torch.nn as nn
import torchaudio
import torch

# Sklearn-related imports
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import LabelEncoder
import sklearn.utils.class_weight as class_weight
import sklearn.model_selection as model_selection
import sklearn.preprocessing as preprocessing
import sklearn.metrics as metrics

# Audio processing imports
from pydub.silence import split_on_silence
from pydub import AudioSegment
import librosa

# Miscellaneous imports
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import pandas as pd
import numpy as np
import wandb
import os

# Dataset, dataloader, and model parameters

In [None]:
# Set seeds for reproducibility.
# torch.manual_seed only covers PyTorch; pandas' DataFrame.sample (used by the
# undersampling helper) draws from NumPy's global RNG when no random_state is
# given, so NumPy has to be seeded as well.
torch.manual_seed(42)
np.random.seed(42)

# Path to data
# Kaggle variant (kept for reference): read the CSV from the Kaggle input
# directory and rewrite the local MP3 prefix to the Kaggle dataset path.
#data = pd.read_csv("/kaggle/input/filtered-csv/filtered_audio_data.csv")
#data["uuid"] = data["uuid"].str.replace("../Dataset/MP3/", "/kaggle/input/covid-19-audio-classification/MP3/")
data = pd.read_csv("misc_data/filtered_audio_data.csv")

# Class labels (the values found in the "status" column)
labels = ["healthy", "symptomatic", "COVID-19"]

# Silence-removal arguments, forwarded to remove_silence through the dataset
min_silence = 500  # Minimum silence length (remove_silence names this min_silence_ms)
threshold_dBFS = -40  # Level below which audio is treated as silence
keep_silence = 250  # Amount of silence to keep around each segment (presumably ms - confirm against pydub)

# Dataloader and dataset arguments
batch = 3  # Batch size
workers = 0  # Number of DataLoader worker processes
pin_memory = True
dataset_type = "undersampled"  # Either "undersampled" or "weighted" (see select_dataset)
undersampling = 500  # Number of samples kept per undersampled class

# Settings for MelSpectrogram computation
melkwargs = {
    "n_mels": 60,  # How many mel frequency filters are used
    "n_fft": 350,  # How many fft components are used for each feature
    "win_length": 350,  # How many frames are included in each window
    "hop_length": 100,  # How many frames the window is shifted for each component
    "center": False,  # Whether frames are padded such that the component of timestep t is centered at t
    "f_max": 11000,  # Maximum frequency to consider
    "f_min": 0,  # Minimum frequency to consider
}
n_mfcc = 22  # Number of MFCC coefficients
sample_rate = 22000  # Sample rate handed to the MFCC transform

# Model type
# The string below lists the supported model_type values; it is a bare string
# expression and has no effect at runtime.
"""
Models:
    "resnet18":
    "resnet34":
    "resnet50":
    "vgg_bn":
    "multi_resnet":
    "modified_multi_resnet"
    "modified_multi_resnet_spectral"
"""
model_type = "modified_multi_resnet_spectral"
model_arch = "resnet18"
model_output = "modified_multi_resnet18_spectral_undersampled"

# Model training and Weights and Biases variables
lr = 0.001  # Learning rate
step = 10  # Scheduler step size (logged as "step_size" in the W&B config)
decay = 0.0001  # Weight decay
optimizer_type = "adam"  # "adam" selects Adam; any other value selects SGD
gamma = 0.1  # Scheduler gamma
epochs = 50
arch = "Multi Input ResNet18 Spectral"
desc = "This model is trained on MFCC features, numeric features, and spectral features."
dataset = "COVID-19 Audio Classification"
weighted = False  # Whether the loss uses class weights

# Setup and define custom dataset class

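The custom dataset class loads each raw audio sample together with its label, removes silence from the audio, and encodes the label. Each item is returned as the single-channel (mono) waveform, the encoded label, and the encoded gender, age group, and SNR of the recording.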

In [None]:
def remove_silence(audio_object, min_silence_ms=100, threshold_dBFS=-40, keep_silence=100, seek_step=1):
    """Strip silent stretches from an audio object and return peak-normalized samples.

    Args:
        audio_object: pydub ``AudioSegment`` to process.
        min_silence_ms: Minimum length of a silent stretch for it to be used as a split point.
        threshold_dBFS: Level below which audio is considered silent.
        keep_silence: Amount of silence kept around each non-silent segment.
        seek_step: Step size used while scanning for silence.

    Returns:
        A float32 NumPy array holding the concatenated non-silent segments
        (or the whole recording when no segment was found), scaled so that the
        largest absolute sample is 1. An all-zero or empty recording is
        returned unscaled instead of being divided by zero.
    """
    # Pass everything by keyword: split_on_silence's fourth positional
    # parameter is keep_silence, so a positional seek_step would silently
    # overwrite it and the caller's keep_silence would never be used.
    segments = split_on_silence(
        audio_object,
        min_silence_len=min_silence_ms,
        silence_thresh=threshold_dBFS,
        keep_silence=keep_silence,
        seek_step=seek_step,
    )

    # Fall back to the untouched recording when nothing but silence was found
    source = sum(segments) if segments else audio_object

    samples = np.array(source.get_array_of_samples(), dtype=np.float32)

    # Peak-normalize, guarding against empty or completely silent audio
    peak = np.max(np.abs(samples)) if samples.size else 0.0
    if peak > 0:
        samples /= peak

    return samples


def encode_age(age):
    """Map an age in years to an integer age-group code.

    Groups: child (0-12) -> 0, teen (13-19) -> 1, adult (20-50) -> 2,
    senior (above 50) -> 3.
    """
    group_codes = {"child": 0, "teen": 1, "adult": 2, "senior": 3}

    # Inclusive upper bounds, checked from youngest to oldest group
    upper_bounds = ((12, "child"), (19, "teen"), (50, "adult"))
    for limit, group in upper_bounds:
        if age <= limit:
            return group_codes[group]

    # Anything that matched no bound is treated as a senior
    return group_codes["senior"]

In [None]:
class AudioDataset(torch.utils.data.Dataset):
    """Dataset yielding a silence-stripped waveform plus label and metadata per row.

    Args:
        data: pandas DataFrame with the columns "uuid" (path to the audio
            file), "status", "age", "gender" and "SNR".
        args: Sequence ``[min_silence, threshold, keep_silence]`` forwarded to
            ``remove_silence``.
        label_encoder: Fitted encoder exposing ``transform``; required by
            ``__getitem__``.
    """

    # Integer codes for the "gender" column (built once instead of per item)
    GENDER_CODES = {"male": 0, "female": 1}

    def __init__(self, data, args, label_encoder=None):
        # Initialize attributes
        self.data = data["uuid"]
        self.label = data["status"]
        self.age = data["age"]
        self.gender = data["gender"]
        self.SNR = data["SNR"]
        self.label_encoder = label_encoder
        self.min_silence = args[0]
        self.threshold = args[1]
        self.keep_silence = args[2]
        # Sample rate of every item loaded so far, keyed by positional index
        self.sample_rate = {}

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        """Return ``(waveform, label, gender, age_group, snr)`` tensors for row ``idx``.

        ``idx`` is positional: rows are read with ``.iloc`` so the dataset also
        works when the DataFrame index is not 0..N-1 (for example after
        filtering without ``reset_index``).
        """
        # Load the audio and strip silence
        audio_path = self.data.iloc[idx]
        audio_object = AudioSegment.from_file(audio_path)
        audio_sample = remove_silence(audio_object, self.min_silence, self.threshold, self.keep_silence)
        self.sample_rate[idx] = audio_object.frame_rate

        # Encode the class label
        audio_label = self.label_encoder.transform([self.label.iloc[idx]])

        # Encode the numeric side features. Age is kept as a 0-d array and
        # gender/SNR as 1-element arrays, as in the original item layout.
        gender = np.array([self.GENDER_CODES[self.gender.iloc[idx]]], dtype=np.int8)
        age = np.array(encode_age(self.age.iloc[idx]), dtype=np.int8)
        snr = np.array([self.SNR.iloc[idx]])

        return (
            torch.tensor(audio_sample, dtype=torch.float32),
            torch.tensor(audio_label, dtype=torch.int32),
            torch.tensor(gender, dtype=torch.int32),
            torch.tensor(age, dtype=torch.int32),
            torch.tensor(snr, dtype=torch.float32),
        )

    def __get_sample_rate__(self, idx):
        """Return the sample rate recorded for ``idx``, or None if it was never loaded."""
        return self.sample_rate.get(idx)

# Custom collate function

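The following helpers build a batch from individual dataset items. `pad_sequence` zero-pads raw audio samples to the length of the longest sample in the batch. `collate_fn` currently returns the waveforms unpadded (the padding call is commented out) and converts the labels and numeric features into tensors.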

In [None]:
def pad_sequence(batch):
    """Zero-pad a batch of tensors to a common length and add a channel axis.

    Each item is transposed with ``.t()`` before padding; the padded batch is
    returned with a new dimension inserted at position 1 (used as the channel
    dimension for the MFCC input).
    """
    transposed = [waveform.t() for waveform in batch]
    padded = torch.nn.utils.rnn.pad_sequence(
        transposed,
        batch_first=True,
        padding_value=0.0,
    )
    return torch.unsqueeze(padded, dim=1)


def collate_fn(batch):
    """Collate dataset items into ``(waveforms, labels, numeric_features)``.

    Every item is a 5-tuple ``(waveform, label, gender, age, snr)``.
    Waveforms are returned as an unpadded tuple; labels become an int32
    tensor; gender, age and SNR are stacked column-wise into one tensor with
    one row per item.
    """
    # Transpose the list of items into one tuple per field
    waveforms, raw_labels, genders, ages, snrs = zip(*batch)

    # Padding is intentionally not applied here (see pad_sequence)
    label_tensor = torch.tensor(raw_labels, dtype=torch.int32)

    # One tensor per numeric field, then joined as columns: [gender, age, snr]
    columns = [torch.tensor(field) for field in (genders, ages, snrs)]
    numeric_features = torch.stack(columns, dim=1)

    return waveforms, label_tensor, numeric_features

# Miscellaneous functions

The following code block contains miscellaneous functions for plotting waveforms, spectrograms, and filter banks (fbank), and for preprocessing the data.

In [None]:
def waveform_plot(signal, sr, title, threshold=None, plot=None):
    """Plot a waveform and, optionally, its dB-scaled version and spectrum.

    Args:
        signal: NumPy array of audio samples.
        sr: Sample rate used to build the time and frequency axes.
        title: Title for the waveform plots.
        threshold: Optional level (dBFS) drawn as a horizontal line on the
            dB plot. ``0`` is a valid threshold.
        plot: When truthy, also draw the dB waveform and the mean spectrum
            (SPL vs. frequency) below the plain waveform.

    The figure is shown exactly once, after all requested panels have been
    drawn; showing it earlier would detach the later panels from the figure.
    """
    # Calculate time axis
    time = np.arange(0, len(signal)) / sr

    # Panel 1: standard waveform
    plt.figure(figsize=(10, 8))
    plt.subplot(3, 1, 1)
    plt.plot(time, signal, color="b")
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    plt.title(title)
    plt.grid(True)

    if not plot:
        plt.show()
        return

    # Panel 2: waveform in dB relative to the peak sample
    peak = np.max(np.abs(signal)) if len(signal) else 0
    if peak > 0:
        # Zero-valued samples map to -inf, which matplotlib simply leaves blank
        with np.errstate(divide="ignore"):
            db_signal = 20 * np.log10(np.abs(signal) / peak)
    else:
        # Completely silent signal: draw a flat line at -60 dB, one value per sample
        db_signal = np.full(len(signal), -60.0)

    plt.subplot(3, 1, 2)
    plt.plot(time, db_signal, color="b")

    # Plot threshold level
    if threshold is not None:
        plt.axhline(y=threshold, color="r", linestyle="--", label=f"{threshold} dBFS Threshold")
        plt.legend()

    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude (dBFS)")
    plt.title(title)
    plt.grid(True)

    # Panel 3: mean magnitude spectrum in dB
    n_fft = 2048  # Length of the FFT window
    hop_length = 512  # Hop length for FFT
    S = np.abs(librosa.stft(signal.astype(float), n_fft=n_fft, hop_length=hop_length))

    # Convert amplitude to dB scale (sound pressure level)
    S_db = librosa.amplitude_to_db(S, ref=np.max)

    # Get frequency bins corresponding to FFT
    freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)

    plt.subplot(3, 1, 3)
    plt.plot(freqs, np.mean(S_db, axis=1), color="b")
    plt.title("Sound Pressure Level (SPL) vs. Frequency")
    plt.xlabel("Frequency (Hz)")
    plt.ylabel("SPL (dB)")
    plt.grid(True)
    plt.xlim([20, 25000])  # Set frequency range for better visualization
    plt.xscale("log")  # Use log scale for frequency axis

    plt.tight_layout()
    plt.show()


# Stolen from pytorch tutorial xd
def plot_spectrogram(specgram, title=None, ylabel="freq_bin", batch=0, idx=0, ax=None):
    """Draw a power spectrogram in dB and save it under ``test_outputs/``.

    Args:
        specgram: Power spectrogram to display (converted with
            ``librosa.power_to_db``).
        title: Optional axes title; also used in the output file name.
        ylabel: Label for the y axis.
        batch: Batch number used in the output file name.
        idx: Sample index used in the output file name.
        ax: Existing matplotlib axes to draw on; a new figure is created when
            omitted.

    The image is written to ``test_outputs/batch{batch}_idx{idx}_{title}.png``;
    the directory is created if it does not exist.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)
    if title is not None:
        ax.set_title(title)
    ax.set_ylabel(ylabel)
    im = ax.imshow(
        librosa.power_to_db(specgram),
        origin="lower",
        aspect="auto",
        interpolation="nearest",
    )

    # Work on the figure that owns `ax`, which is not necessarily pyplot's
    # current figure when the caller supplied the axes.
    figure = ax.figure
    figure.colorbar(im, ax=ax, label="dB")

    os.makedirs("test_outputs", exist_ok=True)
    figure.savefig(f"test_outputs/batch{batch}_idx{idx}_{title}.png")


def plot_fbank(fbank, title=None):
    """Display a filter bank as an image (frequency bins vs. mel bins)."""
    heading = title if title else "Filter bank"

    _, axis = plt.subplots(1, 1)
    axis.imshow(fbank, aspect="auto")
    axis.set_title(heading)
    axis.set_xlabel("mel bin")
    axis.set_ylabel("frequency bin")


def preprocess_data(data_meta_path, data_dir_path, output_dir):
    """Filter the metadata CSV and write ``filtered_audio_data.csv``.

    Steps:
        1. Keep the columns uuid, cough_detected, SNR, age, gender and status,
           only for rows with ``cough_detected >= 0.8`` and no missing values,
           sorted by ``cough_detected``.
        2. Drop rows whose gender is "other".
        3. Drop ages that occur fewer than 100 times.
        4. Drop rows whose ``<uuid>.mp3`` file is missing from
           ``data_dir_path`` and replace each remaining uuid with the path to
           its MP3 file.

    Args:
        data_meta_path: Path to the metadata CSV.
        data_dir_path: Directory containing the ``<uuid>.mp3`` files.
        output_dir: Directory the filtered CSV is written to (created if
            missing).
    """
    columns = ["uuid", "cough_detected", "SNR", "age", "gender", "status"]

    # Read the metadata, keep confident cough detections and complete rows
    meta = pd.read_csv(data_meta_path, sep=",")
    data = meta.loc[meta["cough_detected"] >= 0.8, columns].dropna().sort_values(by="cough_detected")
    data = data[data["gender"] != "other"]

    # Keep only ages with at least 100 samples
    age_counts = data["age"].value_counts()
    ages_to_keep = age_counts.index[age_counts >= 100]
    data = data[data["age"].isin(ages_to_keep)]

    # Resolve each uuid to its MP3 path and check the file system once per row
    paths = [os.path.join(data_dir_path, f"{uuid}.mp3") for uuid in data["uuid"]]
    exists = np.asarray([os.path.exists(path) for path in paths], dtype=bool)

    # Copy before assigning so the write never targets a view of a filtered
    # frame, and so every path stays aligned with its own row.
    data = data[exists].copy()
    data["uuid"] = [path for path, found in zip(paths, exists) if found]

    # Save the data as csv
    os.makedirs(output_dir, exist_ok=True)
    data.to_csv(os.path.join(output_dir, "filtered_audio_data.csv"), index=False)

    print("Finished processing!")

In [None]:
"""
data_path = r"misc_data/metadata_compiled.csv"
data_dir_path = r"../Dataset/MP3/"
output_dir = r"misc_data/"
preprocess_data(data_path, data_dir_path, output_dir)
"""

# Dataset specific functions

The following code block contains functions specifically related to preprocessing the dataset.

In [None]:
def preprocess_dataset(data, test_size):
    """Split ``data`` into two frames, stratified on the "status" column.

    Args:
        data: DataFrame containing a "status" column.
        test_size: Share of rows placed in the second returned frame.

    Returns:
        ``(train_data, test_data)``, each with a fresh 0..N-1 index and the
        "status" column placed last. The split uses a fixed random_state of 42.
    """
    target = data["status"]
    features = data.drop(columns=["status"])

    # Stratified split so both parts keep the class proportions
    features_train, features_test, target_train, target_test = train_test_split(
        features,
        target,
        test_size=test_size,
        stratify=target,
        random_state=42,
    )

    def _rejoin(feature_part, target_part):
        # Put features and labels back into a single frame
        return pd.concat([feature_part, target_part], axis=1).reset_index(drop=True)

    return _rejoin(features_train, target_train), _rejoin(features_test, target_test)


def weighted_sample(data):
    """Return one sampling weight per row: the inverse frequency of its class.

    Rows of rare "status" classes receive larger weights, so a weighted
    sampler draws every class about equally often.
    """
    status = data["status"]

    # Number of rows per class
    occurrences = status.value_counts()

    # Weight of each row is 1 / (size of the row's class)
    return [1 / occurrences[label] for label in status.values]


def undersample(data, minority_class_label, n, random_state=None):
    """Reduce one class of ``data`` to at most ``n`` rows and shuffle the result.

    Args:
        data: DataFrame with a "status" column.
        minority_class_label: Value of "status" identifying the class to
            reduce (despite the parameter name, this is the class that gets
            down-sampled).
        n: Number of rows of that class to keep. If the class has fewer than
            ``n`` rows, all of them are kept instead of raising.
        random_state: Optional seed (or RandomState) forwarded to
            ``DataFrame.sample`` for both the selection and the final shuffle.
            ``None`` keeps the previous behaviour of using the global NumPy
            random state.

    Returns:
        A shuffled DataFrame with a fresh 0..N-1 index containing the reduced
        class together with all rows of the other classes.
    """
    is_target = data["status"] == minority_class_label
    target_rows = data[is_target]

    # Never ask for more rows than exist (sampling without replacement)
    keep = min(n, len(target_rows))
    reduced = target_rows.sample(n=keep, random_state=random_state)

    # Recombine with the untouched classes, then shuffle
    combined = pd.concat([data[~is_target], reduced])
    return combined.sample(frac=1, random_state=random_state).reset_index(drop=True)


def visualize_dataset(data, normalize, title):
    """Print and plot the class distribution of the "status" column.

    Args:
        data: DataFrame with a "status" column.
        normalize: Passed to ``value_counts`` for the printed distribution
            (relative frequencies when true, counts otherwise).
        title: Name of the dataset, used in the printed header and plot title.
    """
    status = data["status"]

    # Text summary
    print(f"{title} Distribution")
    print(status.value_counts(normalize=normalize))
    print("Total samples", len(data))

    # Bar chart of absolute class counts
    counts = status.value_counts()
    plt.figure(figsize=(6, 4))
    plt.title(f"Histogram of Patient Status\n- {title}")
    plt.bar(counts.index, counts)
    plt.xticks(rotation=20, ha="right", fontsize=8)
    plt.xlabel("Class", fontsize=8)
    plt.ylabel("Frequency", fontsize=8)
    plt.show()

# Initialization of dataset and dataset loader

This code block initializes the datasets and dataloaders, including any processing needed, such as splitting the data into training, validation, and test sets and applying a sampling technique (undersampling or weighted sampling).

In [None]:
def select_dataset(
    dataset_type,
    train_data,
    val_data,
    test_data,
    args,
    batch,
    workers,
    undersampling=500,
    pin_memory=True,
    label_encoder=None,
):
    """Build the train/validation/test dataloaders for the chosen sampling strategy.

    Args:
        dataset_type: "undersampled" or "weighted".
        train_data, val_data, test_data: DataFrames produced by the dataset
            split.
        args: Silence-removal arguments forwarded to ``AudioDataset``.
        batch: Batch size.
        workers: Number of DataLoader workers.
        undersampling: Rows kept for each of the "healthy" and "symptomatic"
            classes in the undersampled mode.
        pin_memory: Forwarded to every DataLoader (defaults to the previously
            hard-coded ``True``).
        label_encoder: Encoder handed to ``AudioDataset``; defaults to the
            module-level ``le``.

    Returns:
        ``(train_dataloader, val_dataloader, test_dataloader)``.

    Raises:
        ValueError: If ``dataset_type`` is not a supported value (previously
            the function silently returned ``None``).

    "undersampled": the three splits are pooled, the "healthy" and
    "symptomatic" classes are reduced to ``undersampling`` rows each, and the
    result is re-split 70/30 into train/validation. Pooling the arguments
    replaces the former read of the global ``data`` frame, so the function
    depends only on what it is given.
    NOTE(review): as before, the test loader of this mode reuses the
    validation split; no separate held-out test set is produced yet.

    "weighted": each split is wrapped in a ``WeightedRandomSampler`` built
    from inverse class frequencies.
    """
    encoder = le if label_encoder is None else label_encoder

    def _make_loader(frame, shuffle=False, sampler=None):
        # Single place where DataLoaders are configured
        return DataLoader(
            AudioDataset(frame, args, encoder),
            batch_size=batch,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=workers,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
        )

    if dataset_type == "undersampled":
        # Pool the splits, reduce the two large classes, then split again
        pooled = pd.concat([train_data, val_data, test_data], ignore_index=True)
        reduced = undersample(pooled, "healthy", undersampling)
        reduced = undersample(reduced, "symptomatic", undersampling)
        train_frame, val_frame = preprocess_dataset(reduced, 0.3)

        return (
            _make_loader(train_frame, shuffle=True),
            _make_loader(val_frame),
            _make_loader(val_frame),  # test loader mirrors validation (see NOTE above)
        )

    if dataset_type == "weighted":

        def _weighted_loader(frame):
            sampler = WeightedRandomSampler(
                weights=weighted_sample(frame),
                num_samples=len(frame),
                replacement=True,
            )
            return _make_loader(frame, sampler=sampler)

        return _weighted_loader(train_data), _weighted_loader(val_data), _weighted_loader(test_data)

    raise ValueError(f"Unknown dataset_type {dataset_type!r}; expected 'undersampled' or 'weighted'")


# Initialize LabelEncoder (shared by every AudioDataset created below)
le = LabelEncoder()

# Fit the encoder on the class names and keep their encoded form.
# NOTE(review): the integer assigned to each class is decided by the encoder,
# not by the order of `labels` - confirm the mapping before interpreting outputs.
encoded_labels = le.fit_transform(labels)

# Silence removal arguments, in the order AudioDataset expects:
# [min_silence, threshold, keep_silence]
args = [min_silence, threshold_dBFS, keep_silence]

# Prepare standard dataset
train_data, test_data = preprocess_dataset(data, 0.3)  # First split the original dataset into 70% training and 30% held out
val_data, test_data = preprocess_dataset(test_data, 0.5)  # Then split the held-out 30% 50/50 into validation and test (15% / 15% of the original)

# Initialize dataloaders (the test dataloader returned third is discarded here)
train_dataloader, val_dataloader, _ = select_dataset(dataset_type, train_data, val_data, test_data, args, batch, workers, undersampling)

# Initialize and define MFCC feature extractor

In the following code block the MFCC-specific parameters are defined and initialized. The code block also includes a function that extracts the MFCC features and optionally pads, normalizes, and resizes them so that they can be passed to the model.

In [None]:
def MFCC_Features(data, padding=False, normalize=False, resize=False):
    """Extract MFCC features for a batch of waveforms and stack them.

    Args:
        data: Iterable of audio waveforms (one tensor per sample).
        padding: ``False`` to skip padding, otherwise the integer target
            length of the last (time) dimension. Shorter features are
            zero-padded on the right; longer ones are cropped to that length.
        normalize: Standardize each sample's features to zero mean and unit
            standard deviation.
        resize: Resize each sample's features to 224 x 224.

    Returns:
        A tensor stacked as [batch_size, channels, features, length].

    Note:
        Stacking requires all samples to end up with the same shape, so
        variable-length input needs either ``padding`` or ``resize``.
    """
    # Extract MFCC features with the module-level transform and add a channel axis
    features = [torch.unsqueeze(mfcc(waveform), 0) for waveform in data]

    # Pad (or crop) every sample to the requested length. `features` is a
    # list, so this has to be done per sample on the last dimension.
    if padding:
        features = [F.pad(feature, (0, padding - feature.shape[-1]), "constant", 0) for feature in features]

    # Normalize the features for each sample
    if normalize:
        eps = 1e-8  # Keeps constant features from dividing by zero
        normalized = []
        for feature in features:
            mean = feature.mean(dim=[1, 2], keepdim=True)
            std = feature.std(dim=[1, 2], keepdim=True).clamp_min(eps)
            normalized.append((feature - mean) / std)
        features = normalized

    # Resize mel spectrograms
    if resize:
        resizer = Resize((224, 224), antialias=True)
        features = [resizer(feature) for feature in features]

    # Stack each feature as [batch_size, channels, features, length]
    return torch.stack(features)


# Instantiate the MFCC feature extractor used by MFCC_Features, configured from
# the n_mfcc, sample_rate and melkwargs settings defined with the other parameters
mfcc = MFCC(n_mfcc=n_mfcc, sample_rate=sample_rate, melkwargs=melkwargs)

# Spectral features extraction

In [None]:
def spectral_centroid(S=None, sr=22050, nfft=2048, h_length=512):
    """Compute the spectral centroid of a magnitude spectrogram via librosa.

    Args:
        S: Magnitude spectrogram.
        sr: Sample rate of the underlying signal.
        nfft: FFT size.
        h_length: Hop length.
    """
    centroid = librosa.feature.spectral_centroid(
        S=S,
        sr=sr,
        n_fft=nfft,
        hop_length=h_length,
    )
    return centroid


def root_mean_square(S=None, f_length=2048, h_length=512):
    """Compute the RMS energy from a magnitude spectrogram via librosa.

    Args:
        S: Magnitude spectrogram.
        f_length: Frame length.
        h_length: Hop length.
    """
    rms = librosa.feature.rms(
        S=S,
        frame_length=f_length,
        hop_length=h_length,
    )
    return rms


def zero_crossing_rate(signal, f_length=2048, h_length=512):
    """Compute the zero-crossing rate of a time-domain signal via librosa.

    Args:
        signal: Audio samples.
        f_length: Frame length.
        h_length: Hop length.
    """
    zcr = librosa.feature.zero_crossing_rate(
        y=signal,
        frame_length=f_length,
        hop_length=h_length,
    )
    return zcr


def dynamic_parameters(audio_sample, sr=48000):
    """Choose STFT parameters from the duration of an audio sample.

    Args:
        audio_sample: Sequence of samples; only its length is used.
        sr: Sample rate used to convert the length to seconds.

    Returns:
        ``(n_fft, hop_length, frame_length)``:
        up to 0.5 s -> (512, 128, 512), up to 1 s -> (1024, 256, 1024),
        up to 5 s -> (2048, 512, 2048), longer -> (4096, 1024, 4096).
    """
    seconds = len(audio_sample) / sr

    # Start from the "long audio" window and shrink it for shorter clips
    window = 4096
    for max_seconds, size in ((0.5, 512), (1, 1024), (5, 2048)):
        if seconds <= max_seconds:
            window = size
            break

    # In every tier the hop is a quarter of the window and the frame length
    # equals the FFT size
    return window, window // 4, window


def spectral_features(y, sr=48000):
    """Compute padded, standardized ZCR, RMS and spectral-centroid features for a batch.

    Args:
        y: Iterable of 1-D audio samples (anything ``np.asarray`` accepts).
        sr: Sample rate used to pick the STFT parameters and to compute the
            spectral centroid. Defaults to the previously hard-coded 48000.

    Returns:
        ``(ZCR, RMS, SC)``: three tensors, each shaped
        [batch_size, feature_size, feature_length], zero-padded on the right
        to the longest feature in the batch.

    Raises:
        ValueError: If ``y`` contains no samples.
    """
    eps = 1e-8  # Lower bound for the standard deviation during normalization

    def _standardize(values):
        # Convert to a float tensor and scale to zero mean / unit std
        tensor = torch.tensor(values, dtype=torch.float32)
        # The sample std of a single value is undefined; treat it as zero
        std = tensor.std() if tensor.numel() > 1 else tensor.new_tensor(0.0)
        return (tensor - tensor.mean()) / std.clamp_min(eps)

    per_sample = []
    for sample in y:
        sample = np.asarray(sample)

        # Return n_fft, hop_length, frame_length based on length of sample
        n_fft, hop_length, frame_length = dynamic_parameters(sample, sr=sr)

        # Compute magnitude spectrum of sample
        S, _ = librosa.magphase(
            librosa.stft(
                y=sample,
                n_fft=n_fft,
                hop_length=hop_length,
                win_length=n_fft,
                window="hann",
                center=True,
                pad_mode="constant",
            )
        )

        # Compute ZCR, RMS, and SC for additional feature extraction
        zcr = zero_crossing_rate(signal=sample, f_length=frame_length, h_length=hop_length)
        rms = root_mean_square(S=S, f_length=frame_length, h_length=hop_length)
        sc = spectral_centroid(S=S, sr=sr, nfft=n_fft, h_length=hop_length)

        per_sample.append((_standardize(zcr), _standardize(rms), _standardize(sc)))

    if not per_sample:
        raise ValueError("spectral_features received an empty batch")

    # Longest feature across all samples and all three feature types
    max_len = max(feature.shape[1] for triple in per_sample for feature in triple)

    def _pad_and_stack(features):
        # Right-pad each [feature_size, length] tensor, then stack along the batch axis
        padded = [F.pad(feature, (0, max_len - feature.shape[1]), value=0.0) for feature in features]
        return torch.stack(padded, dim=0)

    # Regroup the per-sample triples into one list per feature type. This
    # step was missing before, so the padding iterated over the last sample's
    # tensors instead of over the batch.
    zcr_list, rms_list, sc_list = zip(*per_sample)
    ZCR = _pad_and_stack(zcr_list)
    RMS = _pad_and_stack(rms_list)
    SC = _pad_and_stack(sc_list)

    return ZCR, RMS, SC

# Network architectures

In [None]:
class MultiInputResNet(nn.Module):
    """ResNet over MFCC features combined with a small MLP over numeric features.

    Args:
        weights: Forwarded to the torchvision ResNet constructor.
        num_classes: Number of output classes.
        model_arch: "resnet18", "resnet34" or "resnet50".

    Raises:
        ValueError: If ``model_arch`` is not one of the supported names.
    """

    def __init__(self, weights=None, num_classes=3, model_arch="resnet18"):
        super().__init__()
        self.resnet = self._build_backbone(model_arch, weights, num_classes)

        # Adjust the first convolutional layer to accept single-channel input
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

        # Additional branch to handle the 3 numeric features
        self.numeric_features = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(inplace=True),
            # nn.Dropout(0.5),
            nn.Linear(64, 512),
            nn.ReLU(inplace=True),
            # nn.Dropout(0.5),
        )

        # Final classifier over the concatenated branches: the ResNet is built
        # with num_classes outputs and the numeric branch emits 512 features.
        # (For num_classes == 3 this is the same 515 inputs as before.)
        self.fc = nn.Linear(num_classes + 512, num_classes)

    @staticmethod
    def _build_backbone(model_arch, weights, num_classes):
        """Return the torchvision ResNet selected by ``model_arch``."""
        builders = {
            "resnet18": models.resnet18,
            "resnet34": models.resnet34,
            "resnet50": models.resnet50,
        }
        if model_arch not in builders:
            raise ValueError(f"Unsupported model_arch {model_arch!r}; expected one of {sorted(builders)}")
        return builders[model_arch](weights=weights, num_classes=num_classes)

    def forward(self, mfcc, numeric):
        """Return raw class scores (logits) for MFCC and numeric inputs."""
        # Process MFCC input through ResNet
        mfcc_resnet_output = self.resnet(mfcc)

        # Process numeric features through additional branch
        numeric_output = self.numeric_features(numeric)

        # Concatenate the outputs from both branches
        combined_features = torch.cat((mfcc_resnet_output, numeric_output), dim=1)

        # Raw scores/logits; no softmax is applied here
        return self.fc(combined_features)


class MultiInputResNet_spectral(nn.Module):
    """ResNet over MFCCs combined with numeric and spectral (ZCR/RMS/SC) branches.

    Besides the class scores, the model returns age and gender scores
    computed from the raw numeric input.

    Args:
        weights: Forwarded to the torchvision ResNet constructor.
        num_classes: Number of output classes.
        model_arch: "resnet18", "resnet34" or "resnet50".

    Raises:
        ValueError: If ``model_arch`` is not one of the supported names.
    """

    def __init__(self, weights=None, num_classes=3, model_arch="resnet18"):
        super().__init__()
        self.resnet = self._build_backbone(model_arch, weights, num_classes)

        # Adjust the first convolutional layer to accept single-channel input
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

        # Additional branch to handle the 3 numeric features
        self.numeric_features = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(inplace=True),
            # nn.Dropout(0.5),
            nn.Linear(64, 512),
            nn.ReLU(inplace=True),
            # nn.Dropout(0.5),
        )

        # Additional branch shared by the three spectral features; ends in a
        # global average pool so every input length yields 512 values
        self.spectral_features = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool1d(1),
        )

        # Final classifier over the concatenated branches: num_classes from
        # the ResNet (it is built with num_classes outputs), 512 from the
        # numeric branch and 512 from each of the three spectral features.
        # (For num_classes == 3 this is the same 2051 inputs as before.)
        self.fc = nn.Linear(num_classes + 512 + 3 * 512, num_classes)

        # Heads for age (4 groups) and gender (2 classes) prediction
        self.fc_age = nn.Linear(3, 4)
        self.fc_gender = nn.Linear(3, 2)

    @staticmethod
    def _build_backbone(model_arch, weights, num_classes):
        """Return the torchvision ResNet selected by ``model_arch``."""
        builders = {
            "resnet18": models.resnet18,
            "resnet34": models.resnet34,
            "resnet50": models.resnet50,
        }
        if model_arch not in builders:
            raise ValueError(f"Unsupported model_arch {model_arch!r}; expected one of {sorted(builders)}")
        return builders[model_arch](weights=weights, num_classes=num_classes)

    def forward(self, mfcc, numeric, ZCR, RMS, SC):
        """Return ``(class_logits, age_logits, gender_logits)``.

        ZCR, RMS and SC are each passed through the shared spectral branch;
        they are expected with a single channel (the branch's first
        convolution has ``in_channels=1``).
        """
        # Process MFCC input through ResNet
        mfcc_resnet_output = self.resnet(mfcc)

        # Process numeric features through the additional branch
        numeric_output = self.numeric_features(numeric)
        # The age/gender heads see only the ReLU of the raw numeric input
        numeric_output2 = F.relu(numeric)

        # Process each spectral feature through the shared branch and drop
        # the pooled length-1 dimension
        ZCR_output = self.spectral_features(ZCR).squeeze(dim=2)
        RMS_output = self.spectral_features(RMS).squeeze(dim=2)
        SC_output = self.spectral_features(SC).squeeze(dim=2)

        # Concatenate the outputs from all branches
        combined_features = torch.cat((mfcc_resnet_output, numeric_output, ZCR_output, RMS_output, SC_output), dim=1)

        # Classification output (raw logits; no softmax is applied here)
        output_class = self.fc(combined_features)

        # Age and gender predictions
        output_age = self.fc_age(numeric_output2)
        output_gender = self.fc_gender(numeric_output2)

        return output_class, output_age, output_gender


class Modified_MultiInputResNet(nn.Module):
    """ResNet over MFCCs plus a numeric branch, with extra age and gender heads.

    Args:
        weights: Forwarded to the torchvision ResNet constructor.
        num_classes: Number of output classes.
        model_arch: "resnet18", "resnet34" or "resnet50".

    Raises:
        ValueError: If ``model_arch`` is not one of the supported names.
    """

    def __init__(self, weights=None, num_classes=3, model_arch="resnet18"):
        super().__init__()
        self.resnet = self._build_backbone(model_arch, weights, num_classes)

        # Adjust the first convolutional layer to accept single-channel input
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

        # Additional branch to handle the 3 numeric features
        self.numeric_features = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(inplace=True),
            # nn.Dropout(0.5),
            nn.Linear(64, 512),
            nn.ReLU(inplace=True),
            # nn.Dropout(0.5),
        )

        # Final classifier over the concatenated branches: the ResNet is built
        # with num_classes outputs and the numeric branch emits 512 features.
        # (For num_classes == 3 this is the same 515 inputs as before.)
        self.fc = nn.Linear(num_classes + 512, num_classes)

        # Heads for age (4 groups) and gender (2 classes) prediction
        self.fc_age = nn.Linear(3, 4)
        self.fc_gender = nn.Linear(3, 2)

    @staticmethod
    def _build_backbone(model_arch, weights, num_classes):
        """Return the torchvision ResNet selected by ``model_arch``."""
        builders = {
            "resnet18": models.resnet18,
            "resnet34": models.resnet34,
            "resnet50": models.resnet50,
        }
        if model_arch not in builders:
            raise ValueError(f"Unsupported model_arch {model_arch!r}; expected one of {sorted(builders)}")
        return builders[model_arch](weights=weights, num_classes=num_classes)

    def forward(self, mfcc, numeric):
        """Return ``(class_logits, age_logits, gender_logits)``."""
        # Process MFCC input through ResNet
        mfcc_resnet_output = self.resnet(mfcc)

        # Process numeric features through additional branch
        numeric_output = self.numeric_features(numeric)
        # The age/gender heads see only the ReLU of the raw numeric input
        numeric_output2 = F.relu(numeric)

        # Concatenate the outputs from both branches
        combined_features = torch.cat((mfcc_resnet_output, numeric_output), dim=1)

        # Raw logits for every head; no softmax is applied here
        output_class = self.fc(combined_features)
        output_age = self.fc_age(numeric_output2)
        output_gender = self.fc_gender(numeric_output2)

        return output_class, output_age, output_gender

# Setup weights and bias logging

In [None]:
#Initialize wandb
#!wandb login --relogin <YOUR_WANDB_API_KEY>  # never commit a real API key; prefer the WANDB_API_KEY environment variable

#Defining weights and biases config
#wandb.init(
#   # set the wandb project where this run will be logged
#   project="mini-project",
#   config={
#   "architecture": arch,
#   "dataset": dataset,
#   "description": desc,
#   "learning_rate": lr,
#   "step_size": step,
#   "weight_decay": decay,
#   "optimizer": optimizer_type,
#   "gamma": gamma,
#   "epochs": epochs
#   }
#)

# Compute class weights for loss function (if using weighted)

In [None]:
from sklearn.utils.class_weight import compute_class_weight
# Balanced class weights for the training labels (used when the loss is weighted).
train_labels = train_data["status"]
unique_train_labels = np.unique(train_labels)
class_weights = compute_class_weight(class_weight="balanced", classes=unique_train_labels, y=train_labels)

# Print each weight next to the label it belongs to. The weights follow the
# order of `unique_train_labels`, so deriving the names from that same array
# keeps the printout correct however the labels are encoded.
for label_value, label_weight in zip(unique_train_labels, class_weights):
    print(f"{label_value}: {label_weight:.4f}")

# Training and validation functions

In [None]:
import torch.optim as optim


def initialize_training_setup(
    model,
    optimizer_type,
    weighted=False,
    *,
    learning_rate=None,
    weight_decay=None,
    loss_weights=None,
):
    """Build the optimizer and the cross-entropy criterion used for training.

    Args:
        model: Module whose parameters are optimized.
        optimizer_type: ``"adam"`` selects Adam; any other value selects SGD
            with momentum 0.9.
        weighted: When truthy, the criterion is a class-weighted
            ``CrossEntropyLoss``; otherwise it is unweighted.
        learning_rate: Optional override; falls back to the notebook-level ``lr``.
        weight_decay: Optional override; falls back to the notebook-level ``decay``.
        loss_weights: Optional per-class weights (array-like or tensor); falls
            back to the notebook-level ``class_weights``.

    Returns:
        Tuple ``(optimizer, criterion)``.
    """
    effective_lr = lr if learning_rate is None else learning_rate
    effective_decay = decay if weight_decay is None else weight_decay

    if optimizer_type == "adam":
        optimizer = optim.Adam(model.parameters(), lr=effective_lr, weight_decay=effective_decay)
    else:
        optimizer = optim.SGD(model.parameters(), lr=effective_lr, momentum=0.9, weight_decay=effective_decay)

    if weighted:
        raw_weights = class_weights if loss_weights is None else loss_weights

        # CrossEntropyLoss needs its weights as a Tensor (a NumPy array is
        # rejected), with the dtype and device of the logits it will score.
        # Both are taken from the model's first parameter.
        reference = next(model.parameters(), None)
        if reference is not None and reference.is_floating_point():
            weight_dtype, weight_device = reference.dtype, reference.device
        else:
            weight_dtype, weight_device = torch.float32, torch.device("cpu")
        weight_tensor = torch.as_tensor(raw_weights, dtype=weight_dtype).to(weight_device)
        criterion = nn.CrossEntropyLoss(weight=weight_tensor)
    else:
        criterion = nn.CrossEntropyLoss()

    return optimizer, criterion


def train_epoch_multi(model, device, epoch, optimizer, criterion):
    """Run one training epoch for a model fed MFCC features plus numeric inputs.

    Uses the notebook-level ``train_dataloader``, ``epochs`` and
    ``MFCC_Features``. Every batch is unpacked as ``(inputs, targets, numeric)``.

    Args:
        model: Called as ``model(features, numeric)``; its output goes straight
            into ``criterion`` and ``torch.max(..., 1)``.
        device: Device the features, targets and numeric inputs are moved to.
        epoch: Zero-based epoch index (displayed one-based).
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, targets)``.

    Returns:
        Tuple ``(avg_loss, accuracy)``: the mean of the per-batch losses and the
        fraction of correctly predicted samples over the epoch.
    """
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    num_batches = len(train_dataloader)

    print("Currently: Training")
    for i, (inputs, targets, numeric) in tqdm(
        enumerate(train_dataloader),
        total=num_batches,
        leave=True,
        desc=f"Epoch {epoch+1}/{epochs} | Training",
    ):
        # Compute the MFCC features, then move everything to the target device
        features = MFCC_Features(inputs, padding=False, normalize=True, resize=True)
        features, targets, numeric = features.to(device), targets.to(device), numeric.to(device)

        optimizer.zero_grad()  # Clear gradients left over from the previous batch
        outputs = model(features, numeric)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # The predicted class is the index of the largest score along dim 1
        _, predicted = torch.max(outputs, 1)
        correct_predictions += (predicted == targets).sum().item()
        total_predictions += targets.size(0)

        # Report every 10th mini-batch. The batch number is shown one-based so
        # it lines up with the total and with the other training loops.
        if i % 10 == 9:
            print(f"Epoch {epoch+1}/{epochs} | Batch {i+1}/{num_batches} | Training Loss: {loss.item():.4f}")

    # Epoch-level accuracy and mean per-batch loss
    accuracy = correct_predictions / total_predictions
    print(f"Training Accuracy: {accuracy:.4f}")
    avg_loss = running_loss / num_batches
    return avg_loss, accuracy


def train_epoch_modified(model, device, epoch, optimizer, criterion):
    """Run one training epoch for the three-headed (class / age / gender) model.

    Uses the notebook-level ``train_dataloader``, ``epochs`` and
    ``MFCC_Features``. Every batch is unpacked as ``(inputs, targets, numeric)``;
    ``targets`` supplies the class labels, ``numeric[:, 0]`` the gender labels
    and ``numeric[:, 1]`` the age labels, each rounded and cast to ``long``.

    NOTE(review): the same ``criterion`` scores all three heads. If it carries
    per-class weights, confirm they fit every head's output size.

    Args:
        model: Called as ``model(features, numeric)``; must return three outputs
            in the order class, age, gender.
        device: Device that features, numeric inputs and labels are moved to.
        epoch: Zero-based epoch index (displayed one-based).
        optimizer: Optimizer stepped once per batch.
        criterion: Loss applied separately to each head; the three values are summed.

    Returns:
        Tuple ``(avg_loss, accuracy_class, accuracy_age, accuracy_gender)``.
    """
    model.train()
    head_names = ("class", "age", "gender")
    loss_sum = 0.0
    hits = dict.fromkeys(head_names, 0)
    seen = dict.fromkeys(head_names, 0)

    print("Currently: Training")
    progress = tqdm(
        enumerate(train_dataloader),
        total=len(train_dataloader),
        leave=True,
        desc=f"Epoch {epoch+1}/{epochs} | Training",
    )
    for step, (inputs, targets, numeric) in progress:
        # Integer labels for every head, placed on the training device
        labels = {
            "class": torch.round(targets).to(torch.long).to(device),
            "gender": torch.round(numeric[:, 0]).to(torch.long).to(device),
            "age": torch.round(numeric[:, 1]).to(torch.long).to(device),
        }

        # MFCC features and numeric inputs go to the device before the forward pass
        features = MFCC_Features(inputs, padding=False, normalize=True, resize=True).to(device)
        numeric = numeric.to(device)

        optimizer.zero_grad()
        output_class, output_age, output_gender = model(features, numeric)
        scores = {"class": output_class, "age": output_age, "gender": output_gender}

        # One loss per head, combined as an unweighted sum
        loss_class = criterion(scores["class"], labels["class"])
        loss_gender = criterion(scores["gender"], labels["gender"])
        loss_age = criterion(scores["age"], labels["age"])
        total_loss = loss_class + loss_age + loss_gender

        total_loss.backward()
        optimizer.step()
        loss_sum += total_loss.item()

        # Tally correct predictions and sample counts per head
        for head in head_names:
            guessed = torch.max(scores[head], 1).indices
            hits[head] += (guessed == labels[head]).sum().item()
            seen[head] += labels[head].size(0)

        # Progress line on every 10th mini-batch
        if step % 10 == 9:
            print(f"Epoch {epoch+1}/{epochs} | Batch {step+1}/{len(train_dataloader)} | Training Loss: {total_loss.item():.4f}")

    # Per-head accuracy over the whole epoch
    accuracy_class = hits["class"] / seen["class"]
    accuracy_age = hits["age"] / seen["age"]
    accuracy_gender = hits["gender"] / seen["gender"]
    print(f"Training Accuracy Class: {accuracy_class:.4f}")
    print(f"Training Accuracy Age: {accuracy_age:.4f}")
    print(f"Training Accuracy Gender: {accuracy_gender:.4f}")

    avg_loss = loss_sum / len(train_dataloader)
    return avg_loss, accuracy_class, accuracy_age, accuracy_gender


def train_epoch(model, device, epoch, optimizer, criterion):
    """Run one training epoch for a model that takes only MFCC features.

    Uses the notebook-level ``train_dataloader``, ``epochs`` and
    ``MFCC_Features``. Every batch is unpacked as ``(inputs, targets)``.

    Args:
        model: Called as ``model(features)``; its output goes straight into
            ``criterion`` and ``torch.max(..., 1)``.
        device: Device the features and targets are moved to.
        epoch: Zero-based epoch index (displayed one-based).
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, targets)``.

    Returns:
        Tuple ``(avg_loss, accuracy)``: the mean of the per-batch losses and the
        fraction of correctly predicted samples over the epoch.
    """
    model.train()
    running_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    num_batches = len(train_dataloader)

    print("Currently: Training")
    for i, (inputs, targets) in tqdm(
        enumerate(train_dataloader),
        total=num_batches,
        leave=True,
        desc=f"Epoch {epoch+1}/{epochs} | Training",
    ):
        # Compute the MFCC features, then move the batch to the target device
        features = MFCC_Features(inputs, padding=False, normalize=True, resize=True)
        features, targets = features.to(device), targets.to(device)

        optimizer.zero_grad()  # Clear gradients left over from the previous batch
        outputs = model(features)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        # The predicted class is the index of the largest score along dim 1
        _, predicted = torch.max(outputs, 1)
        correct_predictions += (predicted == targets).sum().item()
        total_predictions += targets.size(0)

        # Report every 10th mini-batch. The batch number is shown one-based so
        # it lines up with the total and with the other training loops.
        if i % 10 == 9:
            print(f"Epoch {epoch+1}/{epochs} | Batch {i+1}/{num_batches} | Training Loss: {loss.item():.4f}")

    # Epoch-level accuracy and mean per-batch loss
    accuracy = correct_predictions / total_predictions
    print(f"Training Accuracy: {accuracy:.4f}")
    avg_loss = running_loss / num_batches
    return avg_loss, accuracy


def train_epoch_spectral(model, device, epoch, optimizer, criterion):
    """Run one training epoch for the three-headed model with spectral inputs.

    Uses the notebook-level ``train_dataloader``, ``epochs``, ``MFCC_Features``
    and ``spectral_features``. Every batch is unpacked as
    ``(inputs, targets, numeric)``; ``targets`` supplies the class labels,
    ``numeric[:, 0]`` the gender labels and ``numeric[:, 1]`` the age labels,
    each rounded and cast to ``long``.

    NOTE(review): the same ``criterion`` scores all three heads. If it carries
    per-class weights, confirm they fit every head's output size.

    Args:
        model: Called as ``model(features, numeric, ZCR, RMS, SC)``; must return
            three outputs in the order class, age, gender.
        device: Device that all model inputs and labels are moved to.
        epoch: Zero-based epoch index (displayed one-based).
        optimizer: Optimizer stepped once per batch.
        criterion: Loss applied separately to each head; the three values are summed.

    Returns:
        Tuple ``(avg_loss, accuracy_class, accuracy_age, accuracy_gender)``.
    """
    model.train()
    head_names = ("class", "age", "gender")
    epoch_loss = 0.0
    n_correct = dict.fromkeys(head_names, 0)
    n_seen = dict.fromkeys(head_names, 0)

    print("Currently: Training")
    batches = tqdm(
        enumerate(train_dataloader),
        total=len(train_dataloader),
        leave=True,
        desc=f"Epoch {epoch+1}/{epochs} | Training",
    )
    for step, (inputs, targets, numeric) in batches:
        # Integer labels for every head, placed on the training device
        labels = {
            "class": torch.round(targets).to(torch.long).to(device),
            "gender": torch.round(numeric[:, 0]).to(torch.long).to(device),
            "age": torch.round(numeric[:, 1]).to(torch.long).to(device),
        }

        # MFCC features first, then the three spectral descriptors of the same inputs
        features = MFCC_Features(inputs, padding=False, normalize=True, resize=True).to(device)
        numeric = numeric.to(device)
        zcr, rms, sc = spectral_features(inputs)
        zcr, rms, sc = zcr.to(device), rms.to(device), sc.to(device)

        optimizer.zero_grad()
        output_class, output_age, output_gender = model(features, numeric, zcr, rms, sc)
        scores = {"class": output_class, "age": output_age, "gender": output_gender}

        # One loss per head, combined as an unweighted sum
        loss_class = criterion(scores["class"], labels["class"])
        loss_gender = criterion(scores["gender"], labels["gender"])
        loss_age = criterion(scores["age"], labels["age"])
        total_loss = loss_class + loss_age + loss_gender

        total_loss.backward()
        optimizer.step()
        epoch_loss += total_loss.item()

        # Tally correct predictions and sample counts per head
        for head in head_names:
            guessed = torch.max(scores[head], 1).indices
            n_correct[head] += (guessed == labels[head]).sum().item()
            n_seen[head] += labels[head].size(0)

        # Progress line on every 5th mini-batch
        if step % 5 == 4:
            print(f"Epoch {epoch+1}/{epochs} | Batch {step+1}/{len(train_dataloader)} | Training Loss: {total_loss.item():.4f}")

    # Per-head accuracy over the whole epoch
    accuracy_class = n_correct["class"] / n_seen["class"]
    accuracy_age = n_correct["age"] / n_seen["age"]
    accuracy_gender = n_correct["gender"] / n_seen["gender"]
    print(f"Training Accuracy Class: {accuracy_class:.4f}")
    print(f"Training Accuracy Age: {accuracy_age:.4f}")
    print(f"Training Accuracy Gender: {accuracy_gender:.4f}")

    avg_loss = epoch_loss / len(train_dataloader)
    return avg_loss, accuracy_class, accuracy_age, accuracy_gender


# TODO CREATE SEPARATE VALIDATION LOOP FOR EACH MODEL TYPE!
def validate_epoch_multi(model, device, epoch, criterion):
    """Evaluate a two-input model (MFCC features + numeric) on ``val_dataloader``.

    Uses the notebook-level ``val_dataloader``, ``epochs`` and ``MFCC_Features``.
    Every batch is unpacked as ``(vinputs, vtargets, vnumeric)``.

    Precision, recall and F1 are computed over the predictions of the whole
    validation pass, not just the final batch. The loop runs under
    ``torch.no_grad()`` so no autograd graph is built.

    Args:
        model: Called as ``model(vfeatures, vnumeric)``.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index (displayed one-based).
        criterion: Loss called as ``criterion(voutputs, vtargets)``.

    Returns:
        Tuple ``(avg_vloss, vaccuracy, precision, recall, f1)``; the last three
        are macro-averaged with ``zero_division=0.0``.
    """
    model.eval()
    running_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    seen_targets = []
    seen_predictions = []

    print("Currently: Validating")
    with torch.no_grad():
        for vinputs, vtargets, vnumeric in tqdm(
            val_dataloader,
            total=len(val_dataloader),
            leave=True,
            desc=f"Epoch {epoch+1}/{epochs} | Validating",
        ):
            # Compute the MFCC features, then move the batch to the target device
            vfeatures = MFCC_Features(vinputs, padding=False, normalize=True, resize=True)
            vfeatures, vtargets, vnumeric = (
                vfeatures.to(device),
                vtargets.to(device),
                vnumeric.to(device),
            )
            voutputs = model(vfeatures, vnumeric)
            vloss = criterion(voutputs, vtargets)
            running_loss += vloss.item()

            # The predicted class is the index of the largest score along dim 1
            _, vpredicted = torch.max(voutputs, 1)
            correct_predictions += (vpredicted == vtargets).sum().item()
            total_predictions += vtargets.size(0)

            # Keep every batch so the sklearn metrics see the full validation set
            seen_targets.append(vtargets.cpu())
            seen_predictions.append(vpredicted.cpu())

    # Mean per-batch loss and overall accuracy
    avg_vloss = running_loss / len(val_dataloader)
    vaccuracy = correct_predictions / total_predictions

    # Macro-averaged precision, recall and F1 over all validation samples
    y_true = torch.cat(seen_targets).numpy()
    y_pred = torch.cat(seen_predictions).numpy()
    precision = precision_score(y_true, y_pred, average="macro", zero_division=0.0)
    recall = recall_score(y_true, y_pred, average="macro", zero_division=0.0)
    f1 = f1_score(y_true, y_pred, average="macro", zero_division=0.0)

    return avg_vloss, vaccuracy, precision, recall, f1


def validate_epoch_modified(model, device, epoch, criterion):
    """Evaluate the three-headed (class / age / gender) model on ``val_dataloader``.

    Uses the notebook-level ``val_dataloader``, ``epochs`` and ``MFCC_Features``.
    Every batch is unpacked as ``(vinputs, vtargets, vnumeric)``; ``vtargets``
    supplies the class labels, ``vnumeric[:, 0]`` the gender labels and
    ``vnumeric[:, 1]`` the age labels, each rounded and cast to ``long``.

    Precision, recall and F1 are computed over the predictions of the whole
    validation pass, not just the final batch. The loop runs under
    ``torch.no_grad()`` so no autograd graph is built.

    NOTE(review): the same ``criterion`` scores all three heads. If it carries
    per-class weights, confirm they fit every head's output size.

    Args:
        model: Called as ``model(vfeatures, vnumeric)``; must return three
            outputs in the order class, age, gender.
        device: Device the batch tensors and labels are moved to.
        epoch: Zero-based epoch index (displayed one-based).
        criterion: Loss applied separately to each head; the three values are summed.

    Returns:
        Tuple ``(avg_vloss, metrics)`` where ``metrics`` is
        ``(accuracy, precision, recall, f1)`` and each entry is a
        ``(class, age, gender)`` triple. Precision, recall and F1 are
        macro-averaged with ``zero_division=0.0``.
    """
    model.eval()
    tasks = ("class", "age", "gender")  # Order of every returned triple
    running_loss = 0.0
    correct = dict.fromkeys(tasks, 0)
    seen_count = dict.fromkeys(tasks, 0)
    seen_targets = {task: [] for task in tasks}
    seen_predictions = {task: [] for task in tasks}

    print("Currently: Validating")
    with torch.no_grad():
        for vinputs, vtargets, vnumeric in tqdm(
            val_dataloader,
            total=len(val_dataloader),
            leave=True,
            desc=f"Epoch {epoch+1}/{epochs} | Validating",
        ):
            # Integer labels for every head, placed on the evaluation device
            targets = {
                "class": torch.round(vtargets).to(torch.long).to(device),
                "gender": torch.round(vnumeric[:, 0]).to(torch.long).to(device),
                "age": torch.round(vnumeric[:, 1]).to(torch.long).to(device),
            }

            # Compute the MFCC features, then move the inputs to the device
            vfeatures = MFCC_Features(vinputs, padding=False, normalize=True, resize=True)
            vfeatures, vnumeric = vfeatures.to(device), vnumeric.to(device)
            voutput_class, voutput_age, voutput_gender = model(vfeatures, vnumeric)
            outputs = {"class": voutput_class, "age": voutput_age, "gender": voutput_gender}

            # One loss per head, combined as an unweighted sum
            total_loss = (
                criterion(outputs["class"], targets["class"])
                + criterion(outputs["age"], targets["age"])
                + criterion(outputs["gender"], targets["gender"])
            )
            running_loss += total_loss.item()

            for task in tasks:
                _, predicted = torch.max(outputs[task], 1)
                correct[task] += (predicted == targets[task]).sum().item()
                seen_count[task] += targets[task].size(0)
                # Keep every batch so the sklearn metrics see the full validation set
                seen_targets[task].append(targets[task].cpu())
                seen_predictions[task].append(predicted.cpu())

    # Mean per-batch loss and per-head accuracy
    avg_vloss = running_loss / len(val_dataloader)
    accuracy = tuple(correct[task] / seen_count[task] for task in tasks)

    y_true = {task: torch.cat(seen_targets[task]).numpy() for task in tasks}
    y_pred = {task: torch.cat(seen_predictions[task]).numpy() for task in tasks}

    def _macro(score_fn):
        """Apply one sklearn scorer to every head, in ``tasks`` order."""
        return tuple(
            score_fn(y_true[task], y_pred[task], average="macro", zero_division=0.0)
            for task in tasks
        )

    epoch_metrics = (accuracy, _macro(precision_score), _macro(recall_score), _macro(f1_score))
    return avg_vloss, epoch_metrics


def validate_epoch_spectral(model, device, epoch, criterion):
    """Evaluate the three-headed model with spectral inputs on ``val_dataloader``.

    Uses the notebook-level ``val_dataloader``, ``epochs``, ``MFCC_Features``
    and ``spectral_features``. Every batch is unpacked as
    ``(vinputs, vtargets, vnumeric)``; ``vtargets`` supplies the class labels,
    ``vnumeric[:, 0]`` the gender labels and ``vnumeric[:, 1]`` the age labels,
    each rounded and cast to ``long``.

    Precision, recall and F1 are computed over the predictions of the whole
    validation pass, not just the final batch. The loop runs under
    ``torch.no_grad()`` so no autograd graph is built.

    NOTE(review): the same ``criterion`` scores all three heads. If it carries
    per-class weights, confirm they fit every head's output size.

    Args:
        model: Called as ``model(vfeatures, vnumeric, ZCR, RMS, SC)``; must
            return three outputs in the order class, age, gender.
        device: Device the batch tensors and labels are moved to.
        epoch: Zero-based epoch index (displayed one-based).
        criterion: Loss applied separately to each head; the three values are summed.

    Returns:
        Tuple ``(avg_vloss, metrics)`` where ``metrics`` is
        ``(accuracy, precision, recall, f1)`` and each entry is a
        ``(class, age, gender)`` triple. Precision, recall and F1 are
        macro-averaged with ``zero_division=0.0``.
    """
    model.eval()
    tasks = ("class", "age", "gender")  # Order of every returned triple
    running_loss = 0.0
    correct = dict.fromkeys(tasks, 0)
    seen_count = dict.fromkeys(tasks, 0)
    seen_targets = {task: [] for task in tasks}
    seen_predictions = {task: [] for task in tasks}

    print("Currently: Validating")
    with torch.no_grad():
        for vinputs, vtargets, vnumeric in tqdm(
            val_dataloader,
            total=len(val_dataloader),
            leave=True,
            desc=f"Epoch {epoch+1}/{epochs} | Validating",
        ):
            # Integer labels for every head, placed on the evaluation device
            targets = {
                "class": torch.round(vtargets).to(torch.long).to(device),
                "gender": torch.round(vnumeric[:, 0]).to(torch.long).to(device),
                "age": torch.round(vnumeric[:, 1]).to(torch.long).to(device),
            }

            # MFCC features plus the three spectral descriptors of the same inputs
            vfeatures = MFCC_Features(vinputs, padding=False, normalize=True, resize=True)
            vfeatures, vnumeric = vfeatures.to(device), vnumeric.to(device)
            ZCR, RMS, SC = spectral_features(vinputs)
            ZCR, RMS, SC = ZCR.to(device), RMS.to(device), SC.to(device)
            voutput_class, voutput_age, voutput_gender = model(vfeatures, vnumeric, ZCR, RMS, SC)
            outputs = {"class": voutput_class, "age": voutput_age, "gender": voutput_gender}

            # One loss per head, combined as an unweighted sum
            total_loss = (
                criterion(outputs["class"], targets["class"])
                + criterion(outputs["age"], targets["age"])
                + criterion(outputs["gender"], targets["gender"])
            )
            running_loss += total_loss.item()

            for task in tasks:
                _, predicted = torch.max(outputs[task], 1)
                correct[task] += (predicted == targets[task]).sum().item()
                seen_count[task] += targets[task].size(0)
                # Keep every batch so the sklearn metrics see the full validation set
                seen_targets[task].append(targets[task].cpu())
                seen_predictions[task].append(predicted.cpu())

    # Mean per-batch loss and per-head accuracy
    avg_vloss = running_loss / len(val_dataloader)
    accuracy = tuple(correct[task] / seen_count[task] for task in tasks)

    y_true = {task: torch.cat(seen_targets[task]).numpy() for task in tasks}
    y_pred = {task: torch.cat(seen_predictions[task]).numpy() for task in tasks}

    def _macro(score_fn):
        """Apply one sklearn scorer to every head, in ``tasks`` order."""
        return tuple(
            score_fn(y_true[task], y_pred[task], average="macro", zero_division=0.0)
            for task in tasks
        )

    epoch_metrics = (accuracy, _macro(precision_score), _macro(recall_score), _macro(f1_score))
    return avg_vloss, epoch_metrics

# Training and validating model


In [None]:
def select_model(model_type):
    """Build the untrained network identified by ``model_type``.

    The plain torchvision backbones are created with ``weights=None`` and three
    output classes, and their first convolution is replaced by a 3x3, stride-1,
    single-input-channel convolution. The multi-input variants are built from
    the notebook's own classes and read the notebook-level ``model_arch``.

    Args:
        model_type: One of ``"resnet18"``, ``"resnet34"``, ``"resnet50"``,
            ``"vgg_bn"``, ``"multi_resnet"``, ``"modified_multi_resnet"`` or
            ``"modified_multi_resnet_spectral"``.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """

    def _single_channel_conv():
        # First layer that accepts 1-channel input instead of RGB
        return nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    # Each name maps to its own architecture ("resnet34" used to build a ResNet-50)
    resnet_builders = {
        "resnet18": models.resnet18,
        "resnet34": models.resnet34,
        "resnet50": models.resnet50,
    }

    if model_type in resnet_builders:
        model = resnet_builders[model_type](weights=None, num_classes=3)
        model.conv1 = _single_channel_conv()
        return model

    if model_type == "vgg_bn":
        model = models.vgg16_bn(weights=None, num_classes=3)
        model.features[0] = _single_channel_conv()
        return model

    # The custom classes are looked up only in the branch that needs them, so a
    # missing definition does not break selection of the other model types.
    if model_type == "multi_resnet":
        return MultiInputResNet(weights=None, num_classes=3, model_arch=model_arch)
    if model_type == "modified_multi_resnet":
        return Modified_MultiInputResNet(weights=None, num_classes=3, model_arch=model_arch)
    if model_type == "modified_multi_resnet_spectral":
        return MultiInputResNet_spectral(weights=None, num_classes=3, model_arch=model_arch)

    # Fail here with a clear message rather than returning None to the caller
    raise ValueError(f"Unknown model_type: {model_type!r}")


# Initialize model
# Build the model selected by `model_type`
model = select_model(model_type)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
print(f"Device running on: {device}")

model.to(device)

# Wrap model with DataParallel if multiple GPUs are available
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print("Using", gpu_count, "GPUs!")
    model = nn.DataParallel(model)
elif gpu_count == 1:
    print("Only 1 GPU available!")
else:
    # Zero GPUs: do not claim that one is present
    print("No GPU available; running on CPU.")

In [None]:
# Initialize training setup
# Build the optimizer, loss and learning-rate schedule
optimizer, criterion = initialize_training_setup(model, optimizer_type, weighted)
scheduler = StepLR(optimizer, step_size=step, gamma=gamma)
best_vloss = float("inf")
model_no = 0

# Training and validation loop
for epoch in tqdm(range(epochs), total=epochs, leave=True, desc="Epoch | "):
    # avg_loss, accuracy = train_epoch(model, model_type, device, epoch, optimizer, criterion)
    # avg_loss, accuracy_class, accuracy_age, accuracy_gender = train_epoch_modified(model, device, epoch, optimizer, criterion)
    avg_loss, accuracy_class, accuracy_age, accuracy_gender = train_epoch_spectral(model, device, epoch, optimizer, criterion)
    # avg_vloss, vaccuracy, precision, recall, f1 = validate_epoch(model, device, epoch, criterion)
    # avg_vloss, val_metrics = validate_epoch_modified(model, device, epoch, criterion)
    # Stored as `val_metrics` so the imported `sklearn.metrics as metrics` alias is not overwritten
    avg_vloss, val_metrics = validate_epoch_spectral(model, device, epoch, criterion)

    # Print and log metrics
    # MultiResnet model
    # print(f"Epoch #{epoch+1} | Training Loss: {avg_loss:.4f} | Validation Loss: {avg_vloss:.4f} | Validation Accuracy: {vaccuracy:.4f}")
    # print(f"Epoch #{epoch+1} | Precision: {precision:.4f} | Recall: {recall:.4f} | F1 Score: {f1:.4f}")
    # print(f"Epoch #{epoch+1} | Learning Rate: {scheduler.get_last_lr()}")

    # Each entry is a (class, age, gender) triple
    vaccuracy, precision, recall, f1 = val_metrics

    # ModifiedMultiResnet model
    print(f"Epoch #{epoch+1} | Training Loss: {avg_loss:.4f} | Validation Loss: {avg_vloss:.4f}")
    print(f"Epoch #{epoch+1} | Class Accuracy: {vaccuracy[0]:.4f} | Age Accuracy: {vaccuracy[1]:.4f} | Gender Accuracy: {vaccuracy[2]:.4f}")
    print(f"Epoch #{epoch+1} | Class Precision: {precision[0]:.4f} | Age Precision: {precision[1]:.4f} | Gender Precision: {precision[2]:.4f}")
    print(f"Epoch #{epoch+1} | Class Recall: {recall[0]:.4f} | Age Recall: {recall[1]:.4f} | Gender Recall: {recall[2]:.4f}")
    print(f"Epoch #{epoch+1} | Class F1 score: {f1[0]:.4f} | Age F1 score: {f1[1]:.4f} | Gender F1 Score: {f1[2]:.4f}")
    print(f"Epoch #{epoch+1} | Learning Rate: {scheduler.get_last_lr()}")

    # Update learning rate
    scheduler.step()

    # Log metrics to wandb
    # MultiResnet model
    # wandb.log({
    #        "epoch": epoch + 1,
    #        "train_loss": avg_loss,
    #        "train_acc": accuracy,
    #        "val_loss": avg_vloss,
    #        "val_accuracy": vaccuracy,
    #    })
    # wandb.log({"precision": precision, "recall": recall, "f1_score": f1})

    # ModifiedMultiResnet model
    # Only log when a run is active: wandb.log raises if wandb.init() was never
    # called, and the init cell in this notebook is commented out.
    if wandb.run is not None:
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": avg_loss,
            "val_loss": avg_vloss,
            "class_accuracy": vaccuracy[0],
            "age_accuracy": vaccuracy[1],
            "gender_accuracy": vaccuracy[2],
            "class_precision": precision[0],
            "age_precision": precision[1],
            "gender_precision": precision[2],
            "class_recall": recall[0],
            "age_recall": recall[1],
            "gender_recall": recall[2],
            "class_f1_score": f1[0],
            "age_f1_score": f1[1],
            "gender_f1_score": f1[2],
        })

    # Save a checkpoint whenever the validation loss reaches a new minimum
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_no += 1
        os.makedirs("models", exist_ok=True)
        model_path = f"models/{model_output}_no_{model_no}_epoch_{epoch+1}.pth"
        torch.save(model.state_dict(), model_path)

In [None]:
# Mark the current Weights & Biases run (if any) as finished.
wandb.finish()