In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

import numpy as np
from sklearn.metrics import confusion_matrix

import os
import pandas as pd
import time
import warnings
from tqdm import tqdm

import librosa
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")
print(f"Using device: {device}")

warnings.filterwarnings('ignore')

Using device: cuda


In [2]:
class FeatureDataset(Dataset):
    def __init__(self, csv_file, transform=None, target_transform=None):
        self.df = pd.read_csv(csv_file)
        self.transform = transform
        self.target_transform = target_transform
        self.label_map = {
            'Drone': 0,
            'No Drone': 1
        }

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        features = torch.FloatTensor(self.df.iloc[idx, 2:])
        label = self.df.iloc[idx, 1]
        
        if self.target_transform:
            label = self.target_transform(label)
        
        return torch.tensor(features), torch.tensor(label)
    

# Usage example
csv_file = "./anech_esc50_features_normalized.csv"

# Define any transformations if needed
transform = None
target_transform = None

dataset = FeatureDataset(csv_file, transform=transform, target_transform=target_transform)

In [3]:
train_size = int(.8 * len(dataset))
test_size = int(.75 * (len(dataset) - train_size))
val_size = len(dataset) - train_size - test_size

train_dataset, test_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, test_size, val_size])
print(train_dataset)
print(train_dataset[0])

# Using torch.utils.data.weightedRandomSampler to deal with non-uniform dataset
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True
    # sampler=WeightedRandomSampler(weights=label_weights, num_samples=train_size, replacement=True)
    )

test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False
)

val_loader = DataLoader(
    val_dataset,
    batch_size=64,
    shuffle=False
)

<torch.utils.data.dataset.Subset object at 0x00000172FFA1F5C0>
(tensor([-6.7179e+02, -6.6395e+02, -6.8624e+02,  ..., -2.6672e-02,
        -2.6672e-02, -2.6672e-02]), tensor(1))


In [7]:
for feature, label in train_loader:
    print(label)

tensor([0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1,
        1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0,
        1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1])
tensor([1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1,
        0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0,
        0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0])
tensor([1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1,
        0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1,
        0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0])
tensor([0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0,
        1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1,
        1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0])
tensor([0, 1, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
        1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1,

In [8]:
print(train_dataset[7][0].shape)


torch.Size([3393])


In [4]:
class LNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.LazyLinear(4096),
            nn.LazyBatchNorm1d(),
            nn.ReLU(),
            nn.LazyLinear(4096),
            nn.LazyBatchNorm1d(),
            nn.ReLU(),
            nn.LazyLinear(1024),
            nn.LazyBatchNorm1d(),
            nn.ReLU(),
            nn.LazyLinear(1024),
            nn.LazyBatchNorm1d(),
            nn.ReLU(),
            nn.LazyLinear(512),
            nn.LazyBatchNorm1d(),
            nn.ReLU(),
            nn.LazyLinear(256),
            nn.LazyBatchNorm1d(),
            nn.ReLU(),
            nn.LazyLinear(1)
        )

    def forward(self, x):
        logits = self.seq(x).squeeze(dim=1)
        return logits


In [5]:
def split_seconds(seconds):
    minutes = seconds // 60
    hours = minutes // 60
    days = hours // 24
    return seconds % 60, minutes % 60, hours % 24, days


In [7]:
def main():
    # Define model
    model = LNN()

    # Cuda setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #device = torch.device("cpu")
    model = model.to(device)
   
    print(f"Using device: {device}")
    # Optimizer setup
    optimizer = Adam(model.parameters(), lr=1e-4)
   
    # Loss function
    loss_fn = nn.BCEWithLogitsLoss(reduction="mean")
   
    # Number of epochs
    num_epochs = 64
   
    # Train or load model?
    print("Training model....")
    start = time.time()
    for epoch in tqdm(range(num_epochs), desc="Epochs"):
        
        model.train()
        for batch_idx, (features, labels) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)):
            features, labels = features.to(device), labels.to(device)
            optimizer.zero_grad()
           
            # CNN forward pass
            logits = model(features)
            # print(torch.round(torch.sigmoid(logits)))
            loss = loss_fn(logits, labels.to(float))
            loss.backward()
            optimizer.step()

        model.eval()
        num_test = 0
        num_correct = 0
        total_loss = 0
        with torch.no_grad():
            for features, labels in tqdm(test_loader, desc="Testing"):
                features, labels = features.to(device), labels.to(device)
                logits = model(features)
                loss = loss_fn(logits, labels.to(float))
                total_loss += loss.item()
                preds = torch.round(torch.sigmoid(logits))
                num_test += labels.size(0)
                num_correct += preds.eq(labels).sum().item()
    
        tqdm.write(f"Epoch {epoch+1}/{num_epochs}, Test accuracy: {num_correct / num_test * 100:.2f}, Total loss: {total_loss}")

        torch.save(model.state_dict(), f"./dnd_mfcc_models_normalized/_epoch_{str(epoch)}.pt")
   
    end = time.time()
    seconds, minutes, hours, days = split_seconds(end - start)
    print(f"Training Runtime: {int(days)}d {int(hours)}h {int(minutes)}m {seconds:.2f}s")
   
    # Evaluate model on validation data
    model.eval()
    print("Evaluating model....")
    start = time.time()
    total_loss = 0
    num_test = 0
    num_correct = 0
   
    with torch.no_grad():
        for features, labels in tqdm(val_loader, desc="Validating"):
            features, labels = features.to(device), labels.to(device)
            logits = model(features)
            loss = loss_fn(logits, labels.to(float))
            total_loss += loss.item()
            preds = torch.round(torch.sigmoid(logits))
            num_test += labels.size(0)
            num_correct += preds.eq(labels).sum().item()
   
    print(f"Test accuracy: {num_correct / num_test * 100:.2f}, Total loss: {total_loss}")
    end = time.time()
    seconds, minutes, hours, days = split_seconds(end - start)
    print(f"Testing Runtime: {int(days)}d {int(hours)}h {int(minutes)}m {seconds:.2f}s")

In [8]:
main()

Using device: cuda
Training model....


Testing: 100%|██████████| 46/46 [00:08<00:00,  5.19it/s]
Epochs:   0%|          | 0/64 [00:51<?, ?it/s]

Epoch 1/64, Test accuracy: 100.00, Total loss: 3.603574437871834


Testing: 100%|██████████| 46/46 [00:08<00:00,  5.17it/s]
Epochs:   2%|▏         | 1/64 [01:42<54:12, 51.63s/it]

Epoch 2/64, Test accuracy: 100.00, Total loss: 1.4108474682628487


Testing: 100%|██████████| 46/46 [00:08<00:00,  5.18it/s]
Epochs:   3%|▎         | 2/64 [02:34<53:10, 51.46s/it]

Epoch 3/64, Test accuracy: 100.00, Total loss: 0.7605791715452344


Testing: 100%|██████████| 46/46 [00:08<00:00,  5.32it/s]
Epochs:   5%|▍         | 3/64 [03:26<52:30, 51.65s/it]

Epoch 4/64, Test accuracy: 100.00, Total loss: 0.4840676674470564


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.06it/s]
Epochs:   6%|▋         | 4/64 [04:18<51:59, 52.00s/it]

Epoch 5/64, Test accuracy: 100.00, Total loss: 0.33439256304942927


Testing: 100%|██████████| 46/46 [00:08<00:00,  5.63it/s]
Epochs:   8%|▊         | 5/64 [05:09<51:01, 51.90s/it]

Epoch 6/64, Test accuracy: 100.00, Total loss: 0.23030190469661546


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.00it/s]
Epochs:   9%|▉         | 6/64 [06:01<49:50, 51.56s/it]

Epoch 7/64, Test accuracy: 100.00, Total loss: 0.1818152393340521


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.11it/s]
Epochs:  11%|█         | 7/64 [06:53<48:58, 51.56s/it]

Epoch 8/64, Test accuracy: 100.00, Total loss: 0.12878939980348353


Testing: 100%|██████████| 46/46 [00:08<00:00,  5.12it/s]
Epochs:  12%|█▎        | 8/64 [07:45<48:25, 51.89s/it]

Epoch 9/64, Test accuracy: 100.00, Total loss: 0.10392325895150457


Testing: 100%|██████████| 46/46 [00:08<00:00,  5.34it/s]
Epochs:  14%|█▍        | 9/64 [08:36<47:27, 51.78s/it]

Epoch 10/64, Test accuracy: 100.00, Total loss: 0.08261173181846163


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.09it/s]
Epochs:  16%|█▌        | 10/64 [09:27<46:20, 51.49s/it]

Epoch 11/64, Test accuracy: 100.00, Total loss: 0.06864883230749304


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.02it/s]
Epochs:  17%|█▋        | 11/64 [10:18<45:27, 51.46s/it]

Epoch 12/64, Test accuracy: 100.00, Total loss: 0.05466495605711922


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.10it/s]
Epochs:  19%|█▉        | 12/64 [11:10<44:35, 51.45s/it]

Epoch 13/64, Test accuracy: 100.00, Total loss: 0.04345347005437361


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.11it/s]
Epochs:  20%|██        | 13/64 [12:01<43:40, 51.38s/it]

Epoch 14/64, Test accuracy: 100.00, Total loss: 0.0376910450602534


Testing: 100%|██████████| 46/46 [00:09<00:00,  4.99it/s]
Epochs:  22%|██▏       | 14/64 [12:52<42:47, 51.34s/it]

Epoch 15/64, Test accuracy: 100.00, Total loss: 0.031296254307714357


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.06it/s]
Epochs:  23%|██▎       | 15/64 [13:44<42:00, 51.44s/it]

Epoch 16/64, Test accuracy: 100.00, Total loss: 0.026648163360783253


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.07it/s]
Epochs:  25%|██▌       | 16/64 [14:35<41:08, 51.43s/it]

Epoch 17/64, Test accuracy: 100.00, Total loss: 0.021683430268037936


Testing: 100%|██████████| 46/46 [00:09<00:00,  4.98it/s]
Epochs:  27%|██▋       | 17/64 [15:27<40:15, 51.39s/it]

Epoch 18/64, Test accuracy: 100.00, Total loss: 0.018372921699008456


Testing: 100%|██████████| 46/46 [00:08<00:00,  5.13it/s]
Epochs:  28%|██▊       | 18/64 [16:18<39:26, 51.44s/it]

Epoch 19/64, Test accuracy: 100.00, Total loss: 0.01526633041085764


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.00it/s]
Epochs:  30%|██▉       | 19/64 [17:09<38:34, 51.43s/it]

Epoch 20/64, Test accuracy: 100.00, Total loss: 0.013873991851571392


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.02it/s]
Epochs:  31%|███▏      | 20/64 [18:01<37:43, 51.45s/it]

Epoch 21/64, Test accuracy: 100.00, Total loss: 0.011674941874275646


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.06it/s]
Epochs:  33%|███▎      | 21/64 [18:52<36:51, 51.43s/it]

Epoch 22/64, Test accuracy: 100.00, Total loss: 0.009859735712262186


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.10it/s]
Epochs:  34%|███▍      | 22/64 [19:44<35:59, 51.41s/it]

Epoch 23/64, Test accuracy: 100.00, Total loss: 0.00848503557449476


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.08it/s]
Epochs:  36%|███▌      | 23/64 [20:35<35:05, 51.36s/it]

Epoch 24/64, Test accuracy: 100.00, Total loss: 0.007152467714452844


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.07it/s]
Epochs:  38%|███▊      | 24/64 [21:27<34:15, 51.40s/it]

Epoch 25/64, Test accuracy: 100.00, Total loss: 0.006442935242653726


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.06it/s]
Epochs:  39%|███▉      | 25/64 [22:18<33:25, 51.41s/it]

Epoch 26/64, Test accuracy: 100.00, Total loss: 0.005590619499729301


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.07it/s]
Epochs:  41%|████      | 26/64 [23:09<32:33, 51.41s/it]

Epoch 27/64, Test accuracy: 100.00, Total loss: 0.0049969852052337655


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.07it/s]
Epochs:  42%|████▏     | 27/64 [24:01<31:42, 51.41s/it]

Epoch 28/64, Test accuracy: 100.00, Total loss: 0.00422432859995878


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.00it/s]
Epochs:  44%|████▍     | 28/64 [24:52<30:49, 51.39s/it]

Epoch 29/64, Test accuracy: 100.00, Total loss: 0.003636515079610566


Testing: 100%|██████████| 46/46 [00:08<00:00,  5.13it/s]
Epochs:  45%|████▌     | 29/64 [25:43<29:58, 51.39s/it]

Epoch 30/64, Test accuracy: 100.00, Total loss: 0.003312856400643543


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.02it/s]
Epochs:  47%|████▋     | 30/64 [26:35<29:05, 51.33s/it]

Epoch 31/64, Test accuracy: 100.00, Total loss: 0.0028483829070909098


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.08it/s]
Epochs:  48%|████▊     | 31/64 [27:26<28:14, 51.36s/it]

Epoch 32/64, Test accuracy: 100.00, Total loss: 0.002452891746262114


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.04it/s]
Epochs:  50%|█████     | 32/64 [28:17<27:21, 51.29s/it]

Epoch 33/64, Test accuracy: 100.00, Total loss: 0.002106978511092292


Testing: 100%|██████████| 46/46 [00:09<00:00,  4.98it/s]
Epochs:  52%|█████▏    | 33/64 [29:09<26:30, 51.29s/it]

Epoch 34/64, Test accuracy: 100.00, Total loss: 0.0018312168339649462


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.04it/s]
Epochs:  53%|█████▎    | 34/64 [30:00<25:40, 51.34s/it]

Epoch 35/64, Test accuracy: 100.00, Total loss: 0.0015851412800695168


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.01it/s]
Epochs:  55%|█████▍    | 35/64 [30:52<24:50, 51.41s/it]

Epoch 36/64, Test accuracy: 100.00, Total loss: 0.001356879872871117


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.03it/s]
Epochs:  56%|█████▋    | 36/64 [31:43<23:59, 51.41s/it]

Epoch 37/64, Test accuracy: 100.00, Total loss: 0.0012203872264456726


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.00it/s]
Epochs:  58%|█████▊    | 37/64 [32:35<23:08, 51.41s/it]

Epoch 38/64, Test accuracy: 100.00, Total loss: 0.0010860132480306423


Testing: 100%|██████████| 46/46 [00:09<00:00,  4.96it/s]
Epochs:  59%|█████▉    | 38/64 [33:26<22:17, 51.44s/it]

Epoch 39/64, Test accuracy: 100.00, Total loss: 0.0009351677337261006


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.03it/s]
Epochs:  61%|██████    | 39/64 [34:17<21:26, 51.45s/it]

Epoch 40/64, Test accuracy: 100.00, Total loss: 0.0008205191008951513


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.03it/s]
Epochs:  62%|██████▎   | 40/64 [35:09<20:33, 51.39s/it]

Epoch 41/64, Test accuracy: 100.00, Total loss: 0.0007067668945287172


Testing: 100%|██████████| 46/46 [00:09<00:00,  4.99it/s]
Epochs:  64%|██████▍   | 41/64 [36:00<19:43, 51.46s/it]

Epoch 42/64, Test accuracy: 100.00, Total loss: 0.0005878233714366527


Testing: 100%|██████████| 46/46 [00:09<00:00,  4.99it/s]
Epochs:  66%|██████▌   | 42/64 [36:52<18:52, 51.45s/it]

Epoch 43/64, Test accuracy: 100.00, Total loss: 0.0005288523130791608


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.03it/s]
Epochs:  67%|██████▋   | 43/64 [37:43<18:00, 51.43s/it]

Epoch 44/64, Test accuracy: 100.00, Total loss: 0.00048704988949231794


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.09it/s]
Epochs:  69%|██████▉   | 44/64 [38:34<17:08, 51.41s/it]

Epoch 45/64, Test accuracy: 100.00, Total loss: 0.0004143521318324714


Testing: 100%|██████████| 46/46 [00:09<00:00,  4.99it/s]
Epochs:  70%|███████   | 45/64 [39:26<16:15, 51.35s/it]

Epoch 46/64, Test accuracy: 100.00, Total loss: 0.0003564825179856888


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.07it/s]
Epochs:  72%|███████▏  | 46/64 [40:17<15:25, 51.42s/it]

Epoch 47/64, Test accuracy: 100.00, Total loss: 0.00032092204931583664


Testing: 100%|██████████| 46/46 [00:09<00:00,  4.99it/s]
Epochs:  73%|███████▎  | 47/64 [41:09<14:33, 51.40s/it]

Epoch 48/64, Test accuracy: 100.00, Total loss: 0.00028318550639422964


Testing: 100%|██████████| 46/46 [00:08<00:00,  5.12it/s]
Epochs:  75%|███████▌  | 48/64 [42:00<13:42, 51.38s/it]

Epoch 49/64, Test accuracy: 100.00, Total loss: 0.0002525898714173833


Testing: 100%|██████████| 46/46 [00:09<00:00,  4.96it/s]
Epochs:  77%|███████▋  | 49/64 [42:51<12:49, 51.30s/it]

Epoch 50/64, Test accuracy: 100.00, Total loss: 0.000217032189872531


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.08it/s]
Epochs:  78%|███████▊  | 50/64 [43:42<11:59, 51.36s/it]

Epoch 51/64, Test accuracy: 100.00, Total loss: 0.00019081942055601233


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.08it/s]
Epochs:  80%|███████▉  | 51/64 [44:34<11:07, 51.35s/it]

Epoch 52/64, Test accuracy: 100.00, Total loss: 0.0001683810102501108


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.02it/s]
Epochs:  81%|████████▏ | 52/64 [45:25<10:15, 51.26s/it]

Epoch 53/64, Test accuracy: 100.00, Total loss: 0.00014645020724716554


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.00it/s]
Epochs:  83%|████████▎ | 53/64 [46:16<09:23, 51.24s/it]

Epoch 54/64, Test accuracy: 100.00, Total loss: 0.00012736799886248207


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.06it/s]
Epochs:  84%|████████▍ | 54/64 [47:08<08:33, 51.39s/it]

Epoch 55/64, Test accuracy: 100.00, Total loss: 0.00011231982248954214


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.03it/s]
Epochs:  86%|████████▌ | 55/64 [47:59<07:42, 51.37s/it]

Epoch 56/64, Test accuracy: 100.00, Total loss: 9.953097327652927e-05


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.04it/s]
Epochs:  88%|████████▊ | 56/64 [48:51<06:50, 51.34s/it]

Epoch 57/64, Test accuracy: 100.00, Total loss: 8.508122548864548e-05


Testing: 100%|██████████| 46/46 [00:09<00:00,  4.97it/s]
Epochs:  89%|████████▉ | 57/64 [49:42<05:59, 51.38s/it]

Epoch 58/64, Test accuracy: 100.00, Total loss: 7.521134462172717e-05


Testing: 100%|██████████| 46/46 [00:08<00:00,  5.15it/s]
Epochs:  91%|█████████ | 58/64 [50:34<05:08, 51.49s/it]

Epoch 59/64, Test accuracy: 100.00, Total loss: 6.761072967511852e-05


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.06it/s]
Epochs:  92%|█████████▏| 59/64 [51:25<04:17, 51.45s/it]

Epoch 60/64, Test accuracy: 100.00, Total loss: 5.959804348165676e-05


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.04it/s]
Epochs:  94%|█████████▍| 60/64 [52:16<03:25, 51.42s/it]

Epoch 61/64, Test accuracy: 100.00, Total loss: 5.2020784749048e-05


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.09it/s]
Epochs:  95%|█████████▌| 61/64 [53:08<02:34, 51.45s/it]

Epoch 62/64, Test accuracy: 100.00, Total loss: 4.5718319603957643e-05


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.00it/s]
Epochs:  97%|█████████▋| 62/64 [54:00<01:42, 51.43s/it]

Epoch 63/64, Test accuracy: 100.00, Total loss: 3.961825437339941e-05


Testing: 100%|██████████| 46/46 [00:09<00:00,  5.03it/s]
Epochs:  98%|█████████▊| 63/64 [54:51<00:51, 51.51s/it]

Epoch 64/64, Test accuracy: 100.00, Total loss: 3.5263611240029156e-05


Epochs: 100%|██████████| 64/64 [54:51<00:00, 51.44s/it]


Training Runtime: 0d 0h 54m 51.98s
Evaluating model....


Validating: 100%|██████████| 16/16 [00:03<00:00,  4.94it/s]

Test accuracy: 100.00, Total loss: 1.2245395306938652e-05
Testing Runtime: 0d 0h 0m 3.24s





In [9]:
# Define model
model = LNN()
# Cuda setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
model = model.to(device)

print(f"Using device: {device}")

directory = "./dnd_mfcc_models"
acc_epoch = None
max_acc = 0
loss_epoch = None
min_loss = float('inf')

for file in os.listdir(directory):
    # Optimizer setup
    optimizer = Adam(model.parameters(), lr=1e-3)

    # Loss function
    loss_fn = nn.BCEWithLogitsLoss(reduction="mean")

    # Load model
    filepath = os.path.join(directory, file)
    state = torch.load(filepath)
    model.load_state_dict(state)

    # Evaluate model on validation data
    model.eval()
    print(f"Evaluating model {file}")
    total_loss = 0
    num_test = 0
    num_correct = 0

    with torch.no_grad():
        for features, labels in tqdm(val_loader, desc="Validating"):
            features, labels = features.to(device), labels.to(device)
            logits = model(features)
            loss = loss_fn(logits, labels.to(float))
            total_loss += loss.item()
            preds = torch.round(torch.sigmoid(logits))
            num_test += labels.size(0)
            num_correct += preds.eq(labels).sum().item()

    accuracy = num_correct / num_test * 100
    if accuracy > max_acc:
        max_acc = accuracy
        acc_epoch = file
    if total_loss < min_loss:
        min_loss = total_loss
        loss_epoch = file

    print(f"Test accuracy: {accuracy:.2f}, Total loss: {total_loss}")

print(f"Best accuracy: {max_acc} at {acc_epoch}")
print(f"Best loss: {min_loss} at {loss_epoch}")

Using device: cuda
Evaluating model _epoch_0.pt


Validating: 100%|██████████| 16/16 [00:02<00:00,  6.01it/s]


Test accuracy: 99.38, Total loss: 0.44693242612993345
Evaluating model _epoch_1.pt


Validating: 100%|██████████| 16/16 [00:03<00:00,  5.10it/s]


Test accuracy: 99.48, Total loss: 0.3530965393911174
Evaluating model _epoch_10.pt


Validating: 100%|██████████| 16/16 [00:03<00:00,  5.11it/s]


Test accuracy: 99.79, Total loss: 0.22827279053884317
Evaluating model _epoch_11.pt


Validating: 100%|██████████| 16/16 [00:03<00:00,  5.06it/s]


Test accuracy: 99.58, Total loss: 0.347285017484694
Evaluating model _epoch_12.pt


Validating: 100%|██████████| 16/16 [00:03<00:00,  5.09it/s]


Test accuracy: 99.69, Total loss: 0.22337744270738824
Evaluating model _epoch_13.pt


Validating: 100%|██████████| 16/16 [00:02<00:00,  5.69it/s]


Test accuracy: 99.69, Total loss: 0.22835588566000808
Evaluating model _epoch_14.pt


Validating: 100%|██████████| 16/16 [00:03<00:00,  5.18it/s]


Test accuracy: 99.48, Total loss: 0.22730364123986924
Evaluating model _epoch_15.pt


Validating: 100%|██████████| 16/16 [00:03<00:00,  5.12it/s]


Test accuracy: 99.69, Total loss: 0.23456750425281334
Evaluating model _epoch_16.pt


Validating: 100%|██████████| 16/16 [00:03<00:00,  5.12it/s]


Test accuracy: 99.58, Total loss: 0.3051924229337435
Evaluating model _epoch_17.pt


Validating: 100%|██████████| 16/16 [00:02<00:00,  5.83it/s]


Test accuracy: 99.58, Total loss: 0.3038771496738093
Evaluating model _epoch_18.pt


Validating: 100%|██████████| 16/16 [00:02<00:00,  5.70it/s]


Test accuracy: 99.27, Total loss: 0.5479878352402408
Evaluating model _epoch_19.pt


Validating: 100%|██████████| 16/16 [00:02<00:00,  5.62it/s]


Test accuracy: 99.79, Total loss: 0.24021291696023775
Evaluating model _epoch_2.pt


Validating: 100%|██████████| 16/16 [00:03<00:00,  5.20it/s]


Test accuracy: 99.27, Total loss: 0.3167606486290424
Evaluating model _epoch_20.pt


Validating: 100%|██████████| 16/16 [00:03<00:00,  5.09it/s]


Test accuracy: 99.58, Total loss: 0.27579902713683374
Evaluating model _epoch_21.pt


Validating: 100%|██████████| 16/16 [00:03<00:00,  5.12it/s]


Test accuracy: 99.58, Total loss: 0.2779007012330504
Evaluating model _epoch_3.pt


Validating: 100%|██████████| 16/16 [00:03<00:00,  5.18it/s]


Test accuracy: 99.58, Total loss: 0.28368345780837484
Evaluating model _epoch_4.pt


Validating: 100%|██████████| 16/16 [00:03<00:00,  5.13it/s]


Test accuracy: 99.58, Total loss: 0.2625419608593802
Evaluating model _epoch_5.pt


Validating: 100%|██████████| 16/16 [00:03<00:00,  5.21it/s]


Test accuracy: 99.58, Total loss: 0.2822882408581305
Evaluating model _epoch_6.pt


Validating: 100%|██████████| 16/16 [00:03<00:00,  5.16it/s]


Test accuracy: 99.48, Total loss: 0.2629280700207346
Evaluating model _epoch_7.pt


Validating: 100%|██████████| 16/16 [00:03<00:00,  5.15it/s]


Test accuracy: 99.58, Total loss: 0.224102534888137
Evaluating model _epoch_8.pt


Validating: 100%|██████████| 16/16 [00:03<00:00,  5.16it/s]


Test accuracy: 99.07, Total loss: 0.5374630583117247
Evaluating model _epoch_9.pt


Validating: 100%|██████████| 16/16 [00:03<00:00,  5.15it/s]

Test accuracy: 99.69, Total loss: 0.24447241613691517
Best accuracy: 99.79231568016614 at _epoch_10.pt
Best loss: 0.22337744270738824 at _epoch_12.pt



