In [4]:
import pandas as pd
import os

# Load CSV metadata
meta_df = pd.read_csv('ESC-50-master/meta/esc50.csv')

# Add full path to audio files
meta_df['filepath'] = meta_df['filename'].apply(lambda x: os.path.join('ESC-50-master/audio/', x))

# Display sample
meta_df.head()


Unnamed: 0,filename,fold,target,category,esc10,src_file,take,filepath
0,1-100032-A-0.wav,1,0,dog,True,100032,A,ESC-50-master/audio/1-100032-A-0.wav
1,1-100038-A-14.wav,1,14,chirping_birds,False,100038,A,ESC-50-master/audio/1-100038-A-14.wav
2,1-100210-A-36.wav,1,36,vacuum_cleaner,False,100210,A,ESC-50-master/audio/1-100210-A-36.wav
3,1-100210-B-36.wav,1,36,vacuum_cleaner,False,100210,B,ESC-50-master/audio/1-100210-B-36.wav
4,1-101296-A-19.wav,1,19,thunderstorm,False,101296,A,ESC-50-master/audio/1-101296-A-19.wav


In [6]:
import random
import librosa
import numpy as np
import torch
from torch.utils.data import Dataset
from audiomentations import Compose, AddGaussianNoise, PitchShift, TimeStretch, Gain

class ESC50Dataset(Dataset):
    def __init__(self, df, sample_rate=44100, duration=5.0, augment_type='none', n_mels=128):
        self.df = df.reset_index(drop=True)
        self.sr = sample_rate
        self.length = int(sample_rate * duration)
        self.augment_type = augment_type
        self.n_mels = n_mels

        self.weak_transform = Compose([
            AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=0.5),
        ])
        self.strong_transform = Compose([
            AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.02, p=0.5),
            PitchShift(min_semitones=-4, max_semitones=4, p=0.5),
            TimeStretch(min_rate=0.8, max_rate=1.25, p=0.5),
            Gain(min_gain_db=-6, max_gain_db=6, p=0.5)
        ])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        path = row['filepath']
        label = row['target'] if 'target' in row else -1

        y, _ = librosa.load(path, sr=self.sr)
        if len(y) < self.length:
            y = np.pad(y, (0, self.length - len(y)))
        else:
            y = y[:self.length]

        if self.augment_type == 'weak':
            y = self.weak_transform(samples=y, sample_rate=self.sr)
        elif self.augment_type == 'strong':
            y = self.strong_transform(samples=y, sample_rate=self.sr)

        # Convert to Mel spectrogram
        mel = librosa.feature.melspectrogram(y=y, sr=self.sr, n_mels=self.n_mels)
        mel_db = librosa.power_to_db(mel + 1e-6, ref=np.max)
        mel_db = np.clip(mel_db, a_min=-80, a_max=0)  # Avoid crazy ranges


        # Normalize and convert to torch tensor [1, H, W]
        mel_tensor = torch.tensor(mel_db, dtype=torch.float32).unsqueeze(0)

        return mel_tensor, torch.tensor(label, dtype=torch.long)


In [3]:
def split_labeled_unlabeled_and_save(df, labeled_fraction=0.1, save_dir="labeled-unlabeled"):
    import os
    os.makedirs(f'{save_dir}/{labeled_fraction}', exist_ok=True)

    labeled_df_list = []
    unlabeled_df_list = []

    for label in sorted(df['target'].unique()):
        class_df = df[df['target'] == label]
        n_total = len(class_df)
        n_labeled = max(1, int(n_total * labeled_fraction))

        labeled = class_df.sample(n=n_labeled, random_state=42)
        unlabeled = class_df.drop(labeled.index)

        labeled_df_list.append(labeled)
        unlabeled_df_list.append(unlabeled)

    labeled_df = pd.concat(labeled_df_list).reset_index(drop=True)
    unlabeled_df = pd.concat(unlabeled_df_list).reset_index(drop=True)

    # Drop label/category from the unlabeled set
    unlabeled_df = unlabeled_df.drop(columns=["target", "category"])

    labeled_df.head()
    unlabeled_df.head()
    
    # Save both
    labeled_df.to_csv(f"{save_dir}/{labeled_fraction}/labeled.csv", index=False)      # with labels
    unlabeled_df.to_csv(f"{save_dir}/{labeled_fraction}/unlabeled.csv", index=False)  # without labels

    print(f"Saved labeled.csv (with labels) and unlabeled.csv (no labels) to '{save_dir}/{labeled_fraction}'")




In [4]:
from torch.utils.data import DataLoader

def get_ssl_loaders(meta_df,labeled_fraction=0.1, fold=1, batch_size=16, split_dir="labeled-unlabeled"):
    # Load pre-saved labeled and unlabeled CSVs for this fold
    labeled_df = pd.read_csv(f"{split_dir}/{labeled_fraction}/labeled.csv")
    unlabeled_df = pd.read_csv(f"{split_dir}/{labeled_fraction}/unlabeled.csv")

    # Get validation set from meta_df
    val_df = meta_df[meta_df['fold'] == fold]

    # Create datasets
    labeled_dataset = ESC50Dataset(labeled_df, augment_type='weak')
    unlabeled_dataset = DualViewESC50Dataset(unlabeled_df)
    val_dataset = ESC50Dataset(val_df, augment_type='none')

    # Create loaders
    labeled_loader = DataLoader(labeled_dataset, batch_size=batch_size, shuffle=True)
    unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    return labeled_loader, unlabeled_loader, val_loader



In [7]:
class DualViewESC50Dataset(Dataset):
    def __init__(self, df):
        self.weak_dataset = ESC50Dataset(df, augment_type='weak')
        self.strong_dataset = ESC50Dataset(df, augment_type='strong')

    def __len__(self):
        return len(self.weak_dataset)

    def __getitem__(self, idx):
        weak_x, _ = self.weak_dataset[idx]
        strong_x, _ = self.strong_dataset[idx]
        return weak_x, strong_x


In [1]:
import torch.nn as nn
import torchvision.models as models
import torch
import copy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

class ESC50CNN(nn.Module):
    def __init__(self, num_classes=50):
        super().__init__()
        self.base = models.resnet18(pretrained=True)
        self.base.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.base.fc = nn.Linear(self.base.fc.in_features, num_classes)

    def forward(self, x):
        return self.base(x)

def get_perturbation(model, labeled_loader, loss_fn, rho=0.1):
    model.train()
    for x, y in labeled_loader:
        x, y = x.to(device), y.to(device)
        x.requires_grad = True
        output = model(x)
        loss = loss_fn(output, y)
        loss.backward()
        grad = []
        for param in model.parameters():
            if param.grad is not None:
                grad.append(param.grad.view(-1))
        grad_vector = torch.cat(grad)
        norm = torch.norm(grad_vector, p=2)
        eps = rho * grad_vector / (norm + 1e-12)

        # Apply perturbation
        i = 0
        for param in model.parameters():
            if param.requires_grad and param.grad is not None:
                numel = param.numel()
                param.data.add_(eps[i:i+numel].view_as(param))
                i += numel
        break  # Use only 1 batch
    return model

def cross_sharpness_loss(model_orig, model_perturbed, unlabeled_loader):
    criterion = nn.KLDivLoss(reduction='batchmean')
    model_orig.eval()
    model_perturbed.eval()

    with torch.no_grad():
        for weak_x, _ in unlabeled_loader:
            weak_x = weak_x.to(device)

            # Safe softmax/log_softmax with clamping
            out_orig = torch.clamp(torch.softmax(model_orig(weak_x), dim=1), min=1e-7)
            out_pert = torch.log_softmax(model_perturbed(weak_x), dim=1)

            loss = criterion(out_pert, out_orig)
            break
    return loss



Using device: cpu


In [2]:
import torch.nn.functional as F

def train_flatmatch_one_epoch(model, labeled_loader, unlabeled_loader, optimizer, device, epoch, rho=0.1):
    model.train()
    total_loss, total_labeled_loss, total_cs_loss = 0.0, 0.0, 0.0

    unlabeled_iter = iter(unlabeled_loader)

    for x_l, y_l in labeled_loader:
        try:
            xw, xs = next(unlabeled_iter)
        except StopIteration:
            unlabeled_iter = iter(unlabeled_loader)
            xw, xs = next(unlabeled_iter)

        x_l, y_l = x_l.to(device), y_l.to(device)
        xw = xw.to(device)  # Only weak used in flatmatch (strong is ignored)

        # Supervised loss
        logits_l = model(x_l)
        ce_loss = F.cross_entropy(logits_l, y_l)

        # Perturb model on labeled batch
        model_perturbed = get_perturbation(copy.deepcopy(model), [(x_l, y_l)], F.cross_entropy, rho=rho)

        # Cross-sharpness KL divergence
        with torch.no_grad():
            p_orig = torch.clamp(F.softmax(model(xw), dim=1), min=1e-7)
        p_pert = F.log_softmax(model_perturbed(xw), dim=1)
        cs_loss = F.kl_div(p_pert, p_orig, reduction="batchmean")

        loss = ce_loss + cs_loss

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer.step()

        total_loss += loss.item()
        total_labeled_loss += ce_loss.item()
        total_cs_loss += cs_loss.item()

    print(f"[Epoch {epoch}] Loss: {total_loss:.4f} | CE: {total_labeled_loss:.4f} | CrossSharp: {total_cs_loss:.4f}")


In [10]:
from sklearn.metrics import classification_report, roc_auc_score, average_precision_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
import numpy as np

def evaluate(encoder, classifier, val_loader, device, num_classes=50, plot_loss_curves=False, train_losses=None, val_losses=None):
    encoder.eval()
    classifier.eval()

    correct, total = 0, 0
    all_preds = []
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device), y.to(device)
            logits = classifier(encoder(x))

            probs = torch.softmax(logits, dim=1)
            preds = probs.argmax(dim=1)

            all_probs.append(probs.cpu().numpy())
            all_preds.append(preds.cpu().numpy())
            all_labels.append(y.cpu().numpy())

            correct += (preds == y).sum().item()
            total += y.size(0)

    acc = 100.0 * correct / total
    all_probs = np.concatenate(all_probs, axis=0)
    all_preds = np.concatenate(all_preds, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    # print(f"\nValidation Accuracy: {acc:.2f}%")
    # print("\nClassification Report:")
    # print(classification_report(all_labels, all_preds, digits=4))

    # Core metrics
    precision = precision_score(all_labels, all_preds, average="macro", zero_division=0)
    recall = recall_score(all_labels, all_preds, average="macro", zero_division=0)
    f1 = f1_score(all_labels, all_preds, average="macro", zero_division=0)

    # AUC-PRC
    try:
        auc_prc = average_precision_score(
            y_true=np.eye(num_classes)[all_labels],
            y_score=all_probs,
            average="macro"
        )
        print(f"AUC-PRC (macro): {auc_prc:.4f}")
    except Exception as e:
        auc_prc = None
        print(f"AUC-PRC: Not computable — {str(e)}")

    # AUC-ROC
    try:
        auc_roc = roc_auc_score(
            y_true=np.eye(num_classes)[all_labels],
            y_score=all_probs,
            average="macro",
            multi_class='ovr'
        )
        print(f"AUC-ROC (macro): {auc_roc:.4f}")
    except Exception as e:
        auc_roc = None
        print(f"AUC-ROC: Not computable — {str(e)}")

    # Optional: Plot Loss Curves
    if plot_loss_curves and train_losses and val_losses:
        plt.figure(figsize=(8, 5))
        plt.plot(train_losses, label='Train Loss')
        plt.plot(val_losses, label='Val Loss')
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Train vs Validation Loss")
        plt.legend()
        plt.grid(True)
        plt.show()

    # Return all stats
    return {
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "auc_prc": auc_prc,
        "auc_roc": auc_roc
    }


In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

num_epochs = 10
batch_size = 16
labeled_fraction = 0.2
split_dir = "labeled-unlabeled"

fold_accuracies = []
fold_metrics = []

save_dir = "models/ESC-50/FlatMatch"
os.makedirs(save_dir, exist_ok=True)

os.makedirs("models/ESC-50/FlatMatchnew", exist_ok=True)

for fold in range(1, 6):
    print(f"\n=== Fold {fold} ===")

    # Load data
    labeled_loader, unlabeled_loader, val_loader = get_ssl_loaders(
        meta_df=meta_df,
        labeled_fraction=labeled_fraction,
        fold=fold,
        batch_size=batch_size,
        split_dir=split_dir
    )

    model = ESC50CNN(num_classes=50).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    best_acc = 0.0
    best_metrics = None
    model_path = os.path.join(save_dir, f"best_model_fold{fold}.pt")

    # Try loading existing model
    if os.path.exists(model_path):
        print(f"Found saved model for Fold {fold}. Loading checkpoint...")
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        best_acc = checkpoint.get('metrics', {}).get('accuracy', 0.0)
        best_metrics = checkpoint.get('metrics')
        start_epoch = checkpoint.get('epoch', 0) + 1
    else:
        start_epoch = 1

    # Train
    for epoch in range(start_epoch, num_epochs + start_epoch):
        train_flatmatch_one_epoch(
            model=model,
            labeled_loader=labeled_loader,
            unlabeled_loader=unlabeled_loader,
            optimizer=optimizer,
            device=device,
            epoch=epoch,
            rho=0.1
        )

        metrics = evaluate(
            encoder=model,
            classifier=nn.Identity(),
            val_loader=val_loader,
            device=device,
            num_classes=50,
            plot_loss_curves=False
        )

        acc = metrics["accuracy"]
        print(f"Validation Accuracy: {acc:.2f}%")

        if acc > best_acc:
            best_acc = acc
            best_metrics = metrics
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'fold': fold,
                'epoch': epoch,
                'metrics': best_metrics
            }, os.path.join("models/ESC-50/FlatMatchnew", f"best_model_fold{fold}.pt"))
            print(f"Saved best model for Fold {fold} at epoch {epoch}")

    fold_accuracies.append(best_acc)
    fold_metrics.append(best_metrics)
    print(f"Best accuracy for Fold {fold}: {best_acc:.2f}%")

# Summary
print("\n=== 5-Fold Summary ===")
avg_metrics = {}

for i, m in enumerate(fold_metrics, 1):
    print(f"\nFold {i}:")
    for k, v in m.items():
        print(f"  {k}: {v:.4f}")
        avg_metrics[k] = avg_metrics.get(k, 0.0) + (v if v is not None else 0.0)

print("\nAverage Metrics Across Folds:")
for k, total in avg_metrics.items():
    avg_val = total / len(fold_metrics)
    print(f"  {k}: {avg_val:.4f}")


Using device: cpu

=== Fold 1 ===




Found saved model for Fold 1. Loading checkpoint...
[Epoch 20] Loss: 4.4667 | CE: 3.9756 | CrossSharp: 0.4911
AUC-PRC (macro): 0.6752
AUC-ROC (macro): 0.9579
Validation Accuracy: 59.25%
[Epoch 21] Loss: 5.0378 | CE: 4.5590 | CrossSharp: 0.4788
AUC-PRC (macro): 0.6588
AUC-ROC (macro): 0.9491
Validation Accuracy: 54.50%
[Epoch 22] Loss: 7.9515 | CE: 7.4259 | CrossSharp: 0.5257
AUC-PRC (macro): 0.6674
AUC-ROC (macro): 0.9565
Validation Accuracy: 60.25%
[Epoch 23] Loss: 5.1372 | CE: 4.7690 | CrossSharp: 0.3682
AUC-PRC (macro): 0.6274
AUC-ROC (macro): 0.9517
Validation Accuracy: 54.75%
[Epoch 24] Loss: 4.6626 | CE: 4.3759 | CrossSharp: 0.2867
AUC-PRC (macro): 0.6868
AUC-ROC (macro): 0.9626
Validation Accuracy: 58.00%
[Epoch 25] Loss: 4.7298 | CE: 4.2589 | CrossSharp: 0.4709
AUC-PRC (macro): 0.6725
AUC-ROC (macro): 0.9624
Validation Accuracy: 58.50%
[Epoch 26] Loss: 2.9814 | CE: 2.6387 | CrossSharp: 0.3427
AUC-PRC (macro): 0.6866
AUC-ROC (macro): 0.9687
Validation Accuracy: 57.00%
[Epoch 27]



[Epoch 21] Loss: 5.4350 | CE: 4.8431 | CrossSharp: 0.5919
AUC-PRC (macro): 0.7559
AUC-ROC (macro): 0.9672
Validation Accuracy: 64.00%
[Epoch 22] Loss: 5.1483 | CE: 4.4742 | CrossSharp: 0.6741
AUC-PRC (macro): 0.7588
AUC-ROC (macro): 0.9727
Validation Accuracy: 65.00%
[Epoch 23] Loss: 4.7671 | CE: 4.3925 | CrossSharp: 0.3747
AUC-PRC (macro): 0.7540
AUC-ROC (macro): 0.9696
Validation Accuracy: 66.25%
Saved best model for Fold 2 at epoch 23
[Epoch 24] Loss: 3.6879 | CE: 3.3208 | CrossSharp: 0.3671
AUC-PRC (macro): 0.7722
AUC-ROC (macro): 0.9754
Validation Accuracy: 67.50%
Saved best model for Fold 2 at epoch 24
[Epoch 25] Loss: 2.5876 | CE: 2.1488 | CrossSharp: 0.4387
AUC-PRC (macro): 0.7956
AUC-ROC (macro): 0.9775
Validation Accuracy: 70.00%
Saved best model for Fold 2 at epoch 25
[Epoch 26] Loss: 1.4980 | CE: 1.1899 | CrossSharp: 0.3081
AUC-PRC (macro): 0.7782
AUC-ROC (macro): 0.9748
Validation Accuracy: 69.00%
[Epoch 27] Loss: 1.9038 | CE: 1.4835 | CrossSharp: 0.4203
AUC-PRC (macro): 0



Found saved model for Fold 3. Loading checkpoint...
[Epoch 18] Loss: 6.6607 | CE: 6.2889 | CrossSharp: 0.3717
AUC-PRC (macro): 0.7521
AUC-ROC (macro): 0.9732
Validation Accuracy: 63.00%
[Epoch 19] Loss: 4.8829 | CE: 4.4257 | CrossSharp: 0.4572
AUC-PRC (macro): 0.8057
AUC-ROC (macro): 0.9811
Validation Accuracy: 71.00%
Saved best model for Fold 3 at epoch 19
[Epoch 20] Loss: 5.3251 | CE: 4.7786 | CrossSharp: 0.5465
AUC-PRC (macro): 0.8291
AUC-ROC (macro): 0.9882
Validation Accuracy: 73.75%
Saved best model for Fold 3 at epoch 20
[Epoch 21] Loss: 4.0788 | CE: 3.5869 | CrossSharp: 0.4919
AUC-PRC (macro): 0.8135
AUC-ROC (macro): 0.9841
Validation Accuracy: 69.75%
[Epoch 22] Loss: 4.6072 | CE: 4.0782 | CrossSharp: 0.5290
AUC-PRC (macro): 0.7874
AUC-ROC (macro): 0.9784
Validation Accuracy: 69.25%
[Epoch 23] Loss: 6.7049 | CE: 6.2524 | CrossSharp: 0.4525
AUC-PRC (macro): 0.7773
AUC-ROC (macro): 0.9763
Validation Accuracy: 69.50%
[Epoch 24] Loss: 4.6092 | CE: 4.1673 | CrossSharp: 0.4419
AUC-PR



[Epoch 20] Loss: 3.9905 | CE: 3.6162 | CrossSharp: 0.3743
AUC-PRC (macro): 0.7930
AUC-ROC (macro): 0.9727
Validation Accuracy: 68.25%
[Epoch 21] Loss: 4.0343 | CE: 3.4990 | CrossSharp: 0.5353
AUC-PRC (macro): 0.7860
AUC-ROC (macro): 0.9725
Validation Accuracy: 68.50%
[Epoch 22] Loss: 6.2828 | CE: 5.4211 | CrossSharp: 0.8617
AUC-PRC (macro): 0.7125
AUC-ROC (macro): 0.9592
Validation Accuracy: 61.75%
[Epoch 23] Loss: 8.0912 | CE: 7.5167 | CrossSharp: 0.5745
AUC-PRC (macro): 0.7715
AUC-ROC (macro): 0.9736
Validation Accuracy: 67.25%
[Epoch 24] Loss: 5.0631 | CE: 4.5213 | CrossSharp: 0.5418
AUC-PRC (macro): 0.7854
AUC-ROC (macro): 0.9789
Validation Accuracy: 69.75%
[Epoch 25] Loss: 3.1063 | CE: 2.6304 | CrossSharp: 0.4759
AUC-PRC (macro): 0.7795
AUC-ROC (macro): 0.9767
Validation Accuracy: 65.50%
[Epoch 26] Loss: 3.8216 | CE: 3.3493 | CrossSharp: 0.4723
AUC-PRC (macro): 0.7468
AUC-ROC (macro): 0.9672
Validation Accuracy: 66.25%
[Epoch 27] Loss: 4.5374 | CE: 4.0556 | CrossSharp: 0.4818
AUC-



[Epoch 21] Loss: 8.6932 | CE: 8.1874 | CrossSharp: 0.5058
AUC-PRC (macro): 0.6079
AUC-ROC (macro): 0.9463
Validation Accuracy: 49.25%
[Epoch 22] Loss: 7.4399 | CE: 6.9611 | CrossSharp: 0.4788
AUC-PRC (macro): 0.6129
AUC-ROC (macro): 0.9462
Validation Accuracy: 52.50%
[Epoch 23] Loss: 5.3901 | CE: 4.7309 | CrossSharp: 0.6593
AUC-PRC (macro): 0.6780
AUC-ROC (macro): 0.9627
Validation Accuracy: 58.00%
[Epoch 24] Loss: 7.0740 | CE: 6.5913 | CrossSharp: 0.4827
AUC-PRC (macro): 0.6412
AUC-ROC (macro): 0.9533
Validation Accuracy: 55.00%
[Epoch 25] Loss: 5.6567 | CE: 5.1186 | CrossSharp: 0.5381
AUC-PRC (macro): 0.6682
AUC-ROC (macro): 0.9504
Validation Accuracy: 59.00%
Saved best model for Fold 5 at epoch 25
[Epoch 26] Loss: 6.0656 | CE: 5.5379 | CrossSharp: 0.5277
AUC-PRC (macro): 0.6503
AUC-ROC (macro): 0.9545
Validation Accuracy: 56.25%
[Epoch 27] Loss: 3.9672 | CE: 3.6098 | CrossSharp: 0.3574
AUC-PRC (macro): 0.6847
AUC-ROC (macro): 0.9599
Validation Accuracy: 57.50%
[Epoch 28] Loss: 2.629

In [13]:
!pip install torch torchvision torchaudio





[notice] A new release of pip is available: 24.2 -> 25.0.1
[notice] To update, run: C:\Users\Arnav\AppData\Local\Microsoft\WindowsApps\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\python.exe -m pip install --upgrade pip


In [7]:
!pip show torch

Name: torch
Version: 2.6.0
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3-Clause
Location: c:\users\arnav\appdata\local\packages\pythonsoftwarefoundation.python.3.10_qbz5n2kfra8p0\localcache\local-packages\python310\site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: torchaudio, torchvision


In [8]:
!nvidia-smi


Sun Apr 13 10:56:19 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 566.07                 Driver Version: 566.07         CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3050 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   57C    P0             16W /   75W |       0MiB /   4096MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [11]:
def evaluate_model_on_full_esc50(model_path, meta_df, batch_size=16, device=None):
    import torch
    import numpy as np
    from sklearn.metrics import classification_report
    from torch.utils.data import DataLoader
    from tqdm import tqdm

    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create full dataset (no augmentations)
    full_dataset = ESC50Dataset(meta_df, augment_type='none')
    full_loader = DataLoader(full_dataset, batch_size=batch_size, shuffle=False)

    # Load model
    model = ESC50CNN(num_classes=50).to(device)
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    all_preds = []
    all_targets = []

    with torch.no_grad():
        for x, y in tqdm(full_loader, desc="Evaluating on full ESC-50"):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = torch.argmax(logits, dim=1)

            all_preds.append(preds.cpu())
            all_targets.append(y.cpu())

    y_pred = torch.cat(all_preds).numpy()
    y_true = torch.cat(all_targets).numpy()

    print("\nClassification Report on Full ESC-50:")
    print(classification_report(y_true, y_pred, digits=4))

    report = classification_report(y_true, y_pred, digits=4, output_dict=True)

    return {
        "accuracy": report['accuracy'],
        "precision": report['macro avg']['precision'],
        "recall": report['macro avg']['recall'],
        "f1": report['macro avg']['f1-score']
    }

In [13]:
fold_metrics = []

for fold in range(1, 6):
    print(f"\n=== Evaluating Fold {fold} Model on Full ESC-50 Dataset ===")
    model_path = f"models/ESC-50/FlatMatchnew/best_model_fold{fold}.pt"
    
    if not os.path.exists(model_path):
        print(f"Model not found at {model_path}, skipping...")
        continue
    
    metrics = evaluate_model_on_full_esc50(
        model_path=model_path,
        meta_df=meta_df,
        batch_size=16
    )
    
    fold_metrics.append((fold, metrics))
    print(f"Fold {fold} Metrics: {metrics}")

# Average metrics across folds
print("\n=== Average Performance on Full ESC-50 (by fold models) ===")
avg_metrics = {}
for _, m in fold_metrics:
    for k, v in m.items():
        avg_metrics[k] = avg_metrics.get(k, 0.0) + v

for k, v in avg_metrics.items():
    print(f"{k}: {v / len(fold_metrics):.4f}")



=== Evaluating Fold 1 Model on Full ESC-50 Dataset ===


Evaluating on full ESC-50: 100%|█████████████████████████████████████████████████████| 125/125 [00:54<00:00,  2.31it/s]



Classification Report on Full ESC-50:
              precision    recall  f1-score   support

           0     0.7436    0.7250    0.7342        40
           1     0.9000    0.9000    0.9000        40
           2     0.7941    0.6750    0.7297        40
           3     0.8333    0.7500    0.7895        40
           4     0.7059    0.6000    0.6486        40
           5     0.8125    0.6500    0.7222        40
           6     0.6444    0.7250    0.6824        40
           7     0.7097    0.5500    0.6197        40
           8     0.7838    0.7250    0.7532        40
           9     0.8500    0.8500    0.8500        40
          10     0.5357    0.7500    0.6250        40
          11     0.8846    0.5750    0.6970        40
          12     0.8750    0.5250    0.6562        40
          13     0.7727    0.4250    0.5484        40
          14     0.6154    0.8000    0.6957        40
          15     0.6897    0.5000    0.5797        40
          16     0.6000    0.5250    0.560

Evaluating on full ESC-50: 100%|█████████████████████████████████████████████████████| 125/125 [00:53<00:00,  2.33it/s]



Classification Report on Full ESC-50:
              precision    recall  f1-score   support

           0     0.6667    0.7000    0.6829        40
           1     0.8182    0.9000    0.8571        40
           2     0.8846    0.5750    0.6970        40
           3     0.9524    0.5000    0.6557        40
           4     0.7045    0.7750    0.7381        40
           5     0.7778    0.7000    0.7368        40
           6     0.6897    0.5000    0.5797        40
           7     0.5641    0.5500    0.5570        40
           8     0.9524    0.5000    0.6557        40
           9     0.8333    0.8750    0.8537        40
          10     0.6512    0.7000    0.6747        40
          11     0.7000    0.7000    0.7000        40
          12     0.7500    0.6000    0.6667        40
          13     0.5588    0.4750    0.5135        40
          14     0.5254    0.7750    0.6263        40
          15     0.6667    0.5000    0.5714        40
          16     0.8261    0.4750    0.603

Evaluating on full ESC-50: 100%|█████████████████████████████████████████████████████| 125/125 [00:53<00:00,  2.32it/s]



Classification Report on Full ESC-50:
              precision    recall  f1-score   support

           0     0.7714    0.6750    0.7200        40
           1     0.7778    0.8750    0.8235        40
           2     0.9583    0.5750    0.7188        40
           3     0.7647    0.6500    0.7027        40
           4     0.6744    0.7250    0.6988        40
           5     0.7273    0.6000    0.6575        40
           6     0.6458    0.7750    0.7045        40
           7     0.8214    0.5750    0.6765        40
           8     0.8519    0.5750    0.6866        40
           9     0.8919    0.8250    0.8571        40
          10     0.4194    0.6500    0.5098        40
          11     0.8571    0.6000    0.7059        40
          12     0.7941    0.6750    0.7297        40
          13     0.7857    0.5500    0.6471        40
          14     0.7209    0.7750    0.7470        40
          15     0.4894    0.5750    0.5287        40
          16     0.5789    0.5500    0.564

Evaluating on full ESC-50: 100%|█████████████████████████████████████████████████████| 125/125 [00:54<00:00,  2.31it/s]


Classification Report on Full ESC-50:
              precision    recall  f1-score   support

           0     0.8889    0.6000    0.7164        40
           1     0.8605    0.9250    0.8916        40
           2     0.8500    0.4250    0.5667        40
           3     0.8387    0.6500    0.7324        40
           4     0.6809    0.8000    0.7356        40
           5     0.5778    0.6500    0.6118        40
           6     0.6875    0.5500    0.6111        40
           7     0.9444    0.4250    0.5862        40
           8     0.7941    0.6750    0.7297        40
           9     0.7317    0.7500    0.7407        40
          10     0.8571    0.7500    0.8000        40
          11     0.8125    0.6500    0.7222        40
          12     0.7143    0.7500    0.7317        40
          13     0.5946    0.5500    0.5714        40
          14     0.6500    0.6500    0.6500        40
          15     0.5938    0.4750    0.5278        40
          16     0.6333    0.4750    0.542


