In [7]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, random_split
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
)
import librosa
import torchaudio
from torchaudio import transforms
import random
from collections import defaultdict, Counter
import torchvision.models as models
import warnings
import pickle
import time
import shutil
from pathlib import Path

# Ignore warnings
warnings.filterwarnings("ignore")

# Set random seeds for reproducibility
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.backends.cudnn.deterministic = True

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create output directory
output_dir = "SEFOSS_AudioData"
os.makedirs(output_dir, exist_ok=True)

# Audio constants
SAMPLE_RATE = 44100
DURATION = 5  # seconds
NUM_SAMPLES = SAMPLE_RATE * DURATION
N_MELS = 128
N_FFT = 1024
HOP_LENGTH = 512
FMIN = 20
FMAX = SAMPLE_RATE // 2

# Training constants
BATCH_SIZE = 32
NUM_EPOCHS = 50
LEARNING_RATE = 0.001
NUM_CLASSES = 41  # Based on FSDKaggle2018 dataset

# Define paths
TRAIN_AUDIO_PATH = "/kaggle/input/filtereddata/audio/train"
TEST_AUDIO_PATH = "/kaggle/input/filtereddata/audio/test"
VAL_AUDIO_PATH = "/kaggle/input/filtereddata/audio/val"
TRAIN_CSV_PATH = "/kaggle/input/filtereddata/audio/train_metadata.csv"
TEST_CSV_PATH = "/kaggle/input/filtereddata/audio/test_metadata.csv"
VAL_CSV_PATH = "/kaggle/input/filtereddata/audio/val_metadata.csv"

# Load metadata
def load_metadata():
    """Load metadata from CSV files"""
    try:
        train_df = pd.read_csv(TRAIN_CSV_PATH)
        test_df = pd.read_csv(TEST_CSV_PATH)
        val_df = pd.read_csv(VAL_CSV_PATH)

        # Get the unique classes
        all_classes = set(train_df["label"].unique())
        all_classes.update(test_df["label"].unique())
        all_classes.update(val_df["label"].unique())

        class_to_idx = {cls: i for i, cls in enumerate(sorted(all_classes))}
        idx_to_class = {i: cls for cls, i in class_to_idx.items()}

        print(f"Total classes: {len(class_to_idx)}")
        return train_df, test_df, val_df, class_to_idx, idx_to_class

    except Exception as e:
        print(f"Error loading metadata: {e}")
        raise


class AudioDataset(Dataset):
    """Dataset for audio classification with Mel spectrograms"""

    def __init__(
        self,
        df,
        audio_dir,
        class_to_idx,
        transform=None,
        target_sample_rate=SAMPLE_RATE,
    ):
        """
        Args:
            df: DataFrame with metadata
            audio_dir: Directory with audio files
            class_to_idx: Dictionary mapping class names to indices
            transform: Optional transform to be applied on a sample
            target_sample_rate: Sample rate to resample audio to
        """
        self.df = df
        self.audio_dir = audio_dir
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_sample_rate = target_sample_rate

        # Ensure all files exist
        self.valid_files = []
        for i, row in df.iterrows():
            file_path = os.path.join(audio_dir, row["fname"])
            if os.path.exists(file_path):
                self.valid_files.append(i)

        self.df = self.df.iloc[self.valid_files].reset_index(drop=True)
        print(f"Found {len(self.df)} valid audio files in {audio_dir}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Get file path and label
        row = self.df.iloc[idx]
        file_path = os.path.join(self.audio_dir, row["fname"])
        label = self.class_to_idx[row["label"]]

        # Load and preprocess audio
        try:
            waveform, sample_rate = torchaudio.load(file_path)

            # Convert to mono if needed
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # Resample if needed
            if sample_rate != self.target_sample_rate:
                resampler = torchaudio.transforms.Resample(
                    sample_rate, self.target_sample_rate
                )
                waveform = resampler(waveform)

            # Pad or truncate to target length
            if waveform.shape[1] < NUM_SAMPLES:
                # Pad
                padding = NUM_SAMPLES - waveform.shape[1]
                waveform = F.pad(waveform, (0, padding))
            elif waveform.shape[1] > NUM_SAMPLES:
                # Truncate
                waveform = waveform[:, :NUM_SAMPLES]

            # Transform to mel spectrogram
            mel_spectrogram = transforms.MelSpectrogram(
                sample_rate=self.target_sample_rate,
                n_fft=N_FFT,
                hop_length=HOP_LENGTH,
                n_mels=N_MELS,
                f_min=FMIN,
                f_max=FMAX,
            )(waveform)

            # Convert to decibels
            mel_spectrogram = transforms.AmplitudeToDB()(mel_spectrogram)

            # Apply transformations if provided
            if self.transform:
                mel_spectrogram = self.transform(mel_spectrogram)

            return mel_spectrogram, label

        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            # Return a zero tensor and the label in case of error
            return torch.zeros((1, N_MELS, NUM_SAMPLES // HOP_LENGTH + 1)), label


# Data augmentation functions for audio
class WeakAugmentation:
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, mel_spec):
        if random.random() < self.p:
            # Time masking
            time_mask_param = int(mel_spec.shape[2] * 0.1)  # 10% of time steps
            time_mask = transforms.TimeMasking(time_mask_param)
            mel_spec = time_mask(mel_spec)

        return mel_spec


class StrongAugmentation:
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, mel_spec):
        if random.random() < self.p:
            # Time masking (more aggressive)
            time_mask_param = int(mel_spec.shape[2] * 0.2)  # 20% of time steps
            time_mask = transforms.TimeMasking(time_mask_param)
            mel_spec = time_mask(mel_spec)

            # Frequency masking
            freq_mask_param = int(mel_spec.shape[1] * 0.2)  # 20% of frequency bins
            freq_mask = transforms.FrequencyMasking(freq_mask_param)
            mel_spec = freq_mask(mel_spec)

            # Add random noise
            noise = torch.randn_like(mel_spec) * 0.1
            mel_spec = mel_spec + noise

        return mel_spec


# MobileNet model for audio classification
class AudioMobileNet(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super(AudioMobileNet, self).__init__()
        # Load pre-trained MobileNet model
        self.mobilenet = models.mobilenet_v2(weights="DEFAULT")

        # Modify the first convolutional layer to accept single-channel input
        self.mobilenet.features[0][0] = nn.Conv2d(
            1, 32, kernel_size=3, stride=2, padding=1, bias=False
        )

        # Replace the classifier
        in_features = self.mobilenet.classifier[1].in_features
        self.mobilenet.classifier = nn.Sequential(
            nn.Dropout(0.2), nn.Linear(in_features, num_classes)
        )

    def forward(self, x):
        return self.mobilenet(x)


# Modified version for SeFOSS (features model f and classifier model g)
class SeFOSSMobileNet(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super(SeFOSSMobileNet, self).__init__()
        # Feature extractor f
        mobilenet = models.mobilenet_v2(weights="DEFAULT")
        mobilenet.features[0][0] = nn.Conv2d(
            1, 32, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.feature_extractor = nn.Sequential(
            *list(mobilenet.features), nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten()
        )

        # Output dimension of feature extractor
        self.feature_dim = 1280

        # Classifier g
        self.classifier = nn.Sequential(
            nn.Dropout(0.2), nn.Linear(self.feature_dim, num_classes)
        )

    def forward(self, x):
        features = self.feature_extractor(x)
        logits = self.classifier(features)
        return logits

    def get_features(self, x):
        return self.feature_extractor(x)

    def classify(self, features):
        return self.classifier(features)


# SeFOSS implementation
class SeFOSS:
    def __init__(self, num_classes=NUM_CLASSES, feature_dim=1280, device=device):
        self.device = device
        self.num_classes = num_classes
        self.feature_dim = feature_dim

        # Initialize models
        self.model = SeFOSSMobileNet(num_classes=num_classes).to(device)

        # Initialize optimizers
        self.optimizer = optim.Adam(self.model.parameters(), lr=LEARNING_RATE)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, "min", patience=5, factor=0.5
        )

        # Loss functions
        self.criterion = nn.CrossEntropyLoss()

        # Augmentations
        self.weak_augmentation = WeakAugmentation(p=0.5)
        self.strong_augmentation = StrongAugmentation(p=0.8)

        # Thresholds
        self.tau_d = 0.7  # Default value, will be tuned
        self.tau_ood = 0.3  # Default value, will be tuned
        self.tau_nood = 0.1  # Default value, will be tuned

        # Loss weights
        self.w_s = 1.0  # Supervised loss weight
        self.w_p = 1.0  # Pseudo-label loss weight
        self.w_e = 0.5  # Entropy loss weight
        self.w_c = 0.5  # Consistency loss weight
        self.w_u = 0.5  # Uncertainty loss weight

        # History tracking
        self.history = {
            "train_loss": [],
            "val_loss": [],
            "val_accuracy": [],
            "train_accuracy": [],
        }

    def pretrain(self, labeled_loader, val_loader, num_epochs=10):
        """Pretrain the model using only labeled data"""
        print("Pretraining model...")
        best_val_loss = float("inf")

        for epoch in range(num_epochs):
            # Training
            self.model.train()
            train_loss = 0.0
            correct = 0
            total = 0

            for inputs, targets in labeled_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)

                # Forward pass
                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                # Backward pass
                loss.backward()
                self.optimizer.step()

                # Statistics
                train_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

            train_loss = train_loss / len(labeled_loader)
            train_acc = 100.0 * correct / total

            # Validation
            val_loss, val_acc = self.evaluate(val_loader)

            # Update scheduler
            self.scheduler.step(val_loss)

            # Save history
            self.history["train_loss"].append(train_loss)
            self.history["val_loss"].append(val_loss)
            self.history["train_accuracy"].append(train_acc)
            self.history["val_accuracy"].append(val_acc)

            print(
                f"Pretrain Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%"
            )

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(
                    self.model.state_dict(),
                    os.path.join(output_dir, "pretrained_model.pth"),
                )

        # Load best model
        self.model.load_state_dict(
            torch.load(os.path.join(output_dir, "pretrained_model.pth"))
        )
        print("Pretraining completed!")

    def compute_confidence_thresholds(self, labeled_loader):
        """Compute confidence thresholds by evaluating on labeled data"""
        print("Computing confidence thresholds...")
        self.model.eval()
        all_confidences = []

        with torch.no_grad():
            for inputs, _ in labeled_loader:
                inputs = inputs.to(self.device)
                outputs = self.model(inputs)
                probas = F.softmax(outputs, dim=1)
                confidences, _ = torch.max(probas, dim=1)
                all_confidences.extend(confidences.cpu().numpy())

        # Convert to numpy array and sort
        all_confidences = np.sort(all_confidences)
        n = len(all_confidences)

        # Set thresholds based on percentiles
        self.tau_d = all_confidences[int(0.7 * n)]  # 70th percentile
        self.tau_ood = all_confidences[int(0.3 * n)]  # 30th percentile
        self.tau_nood = all_confidences[int(0.1 * n)]  # 10th percentile

        print(
            f"Thresholds - tau_d: {self.tau_d:.4f}, tau_ood: {self.tau_ood:.4f}, tau_nood: {self.tau_nood:.4f}"
        )

    def sefoss_train_step(self, labeled_batch, unlabeled_batch):
        """Perform a single SeFOSS training step"""
        # Process labeled data
        labeled_inputs, labeled_targets = labeled_batch
        labeled_inputs, labeled_targets = labeled_inputs.to(
            self.device
        ), labeled_targets.to(self.device)

        # Forward pass on labeled data
        labeled_outputs = self.model(labeled_inputs)

        # Supervised loss
        l_s = self.criterion(labeled_outputs, labeled_targets)

        # Process unlabeled data if provided
        if unlabeled_batch is not None and len(unlabeled_batch[0]) > 0:
            (
                unlabeled_inputs,
                _,
            ) = unlabeled_batch  # Ignoring the labels for unlabeled data
            unlabeled_inputs = unlabeled_inputs.to(self.device)

            # Apply weak augmentation for each sample
            weak_inputs = []
            for i in range(unlabeled_inputs.size(0)):
                weak_inputs.append(
                    self.weak_augmentation(unlabeled_inputs[i].unsqueeze(0))
                )
            weak_inputs = torch.cat(weak_inputs, dim=0)

            # Apply strong augmentation for each sample
            strong_inputs = []
            for i in range(unlabeled_inputs.size(0)):
                strong_inputs.append(
                    self.strong_augmentation(unlabeled_inputs[i].unsqueeze(0))
                )
            strong_inputs = torch.cat(strong_inputs, dim=0)

            # Get predictions with weak augmentation
            self.model.eval()
            with torch.no_grad():
                weak_outputs = self.model(weak_inputs)
                weak_probs = F.softmax(weak_outputs, dim=1)
                weak_confidences, pseudo_labels = torch.max(weak_probs, dim=1)

            # Switch back to training mode
            self.model.train()

            # Get predictions with strong augmentation
            strong_outputs = self.model(strong_inputs)
            strong_probs = F.softmax(strong_outputs, dim=1)

            # Compute entropy
            entropy = -torch.sum(strong_probs * torch.log(strong_probs + 1e-10), dim=1)
            mean_entropy = torch.mean(entropy)

            # Create masks based on confidence thresholds
            mask_d = weak_confidences >= self.tau_d
            mask_ood = (weak_confidences < self.tau_d) & (
                weak_confidences >= self.tau_ood
            )
            mask_nood = weak_confidences < self.tau_ood

            # Pseudo-label loss
            if torch.sum(mask_d) > 0:
                l_p = F.cross_entropy(strong_outputs[mask_d], pseudo_labels[mask_d])
            else:
                l_p = torch.tensor(0.0).to(self.device)

            # Entropy loss (minimizing entropy for uncertain samples)
            if torch.sum(mask_ood) > 0:
                l_e = torch.mean(entropy[mask_ood])
            else:
                l_e = torch.tensor(0.0).to(self.device)

            # Consistency loss
            if torch.sum(mask_d) > 0:
                l_c = F.mse_loss(strong_probs[mask_d], weak_probs[mask_d])
            else:
                l_c = torch.tensor(0.0).to(self.device)

            # Uncertainty loss
            if torch.sum(mask_nood) > 0:
                # Maximize entropy for noisy/OOD samples
                l_u = -torch.mean(entropy[mask_nood])
            else:
                l_u = torch.tensor(0.0).to(self.device)

        else:
            # If no unlabeled data, set all losses to 0
            l_p = torch.tensor(0.0).to(self.device)
            l_e = torch.tensor(0.0).to(self.device)
            l_c = torch.tensor(0.0).to(self.device)
            l_u = torch.tensor(0.0).to(self.device)

        # Compute total loss with adaptive weights
        total_loss = (
            self.w_s * l_s
            + self.w_p * l_p
            + self.w_e * l_e
            + self.w_c * l_c
            + self.w_u * l_u
        )

        return total_loss

    def train(
        self, labeled_loader, unlabeled_loader, val_loader, num_epochs=NUM_EPOCHS
    ):
        """Train the model using SeFOSS algorithm"""
        print("Starting SeFOSS training...")
        best_val_loss = float("inf")

        # First run pretraining
        self.pretrain(labeled_loader, val_loader, num_epochs=5)

        # Compute confidence thresholds
        self.compute_confidence_thresholds(labeled_loader)

        # Reset history
        self.history = {
            "train_loss": [],
            "val_loss": [],
            "val_accuracy": [],
            "train_accuracy": [],
        }

        # Main training loop
        for epoch in range(num_epochs):
            self.model.train()
            train_loss = 0.0

            # Create iterator for unlabeled loader
            if unlabeled_loader is not None:
                unlabeled_iter = iter(unlabeled_loader)

            # Train with labeled and unlabeled data
            for batch_idx, labeled_batch in enumerate(labeled_loader):
                # Get unlabeled batch if available
                try:
                    if unlabeled_loader is not None:
                        unlabeled_batch = next(unlabeled_iter)
                    else:
                        unlabeled_batch = None
                except StopIteration:
                    if unlabeled_loader is not None:
                        unlabeled_iter = iter(unlabeled_loader)
                        try:
                            unlabeled_batch = next(unlabeled_iter)
                        except StopIteration:
                            unlabeled_batch = None
                    else:
                        unlabeled_batch = None

                # Compute SeFOSS loss
                self.optimizer.zero_grad()
                loss = self.sefoss_train_step(labeled_batch, unlabeled_batch)

                # Backward pass
                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()

            train_loss = train_loss / len(labeled_loader)

            # Validation
            val_loss, val_acc = self.evaluate(val_loader)

            # Update scheduler
            self.scheduler.step(val_loss)

            # Compute train accuracy
            train_acc = self.compute_accuracy(labeled_loader)

            # Save history
            self.history["train_loss"].append(train_loss)
            self.history["val_loss"].append(val_loss)
            self.history["train_accuracy"].append(train_acc)
            self.history["val_accuracy"].append(val_acc)

            print(
                f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%"
            )

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(
                    self.model.state_dict(), os.path.join(output_dir, "best_model.pth")
                )

        # Load best model
        self.model.load_state_dict(
            torch.load(os.path.join(output_dir, "best_model.pth"))
        )

        # Save trained model
        torch.save(self.model, os.path.join(output_dir, "full_model.pkl"))

        # Plot training history
        self.plot_training_history()

        print("Training completed!")
        return self.history

    def evaluate(self, data_loader):
        """Evaluate the model on a given data loader"""
        self.model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in data_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        val_loss = val_loss / len(data_loader)
        val_acc = 100.0 * correct / total

        return val_loss, val_acc

    def compute_accuracy(self, data_loader):
        """Compute accuracy on a given data loader"""
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in data_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        return 100.0 * correct / total

    def plot_training_history(self):
        """Plot training and validation loss/accuracy"""
        plt.figure(figsize=(12, 5))

        # Plot loss
        plt.subplot(1, 2, 1)
        plt.plot(self.history["train_loss"], label="Train Loss")
        plt.plot(self.history["val_loss"], label="Validation Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training and Validation Loss")
        plt.legend()

        # Plot accuracy
        plt.subplot(1, 2, 2)
        plt.plot(self.history["train_accuracy"], label="Train Accuracy")
        plt.plot(self.history["val_accuracy"], label="Validation Accuracy")
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy (%)")
        plt.title("Training and Validation Accuracy")
        plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "training_history.png"))
        plt.close()

    def predict(self, data_loader):
        """Make predictions on a given data loader"""
        self.model.eval()
        all_preds = []
        all_targets = []

        with torch.no_grad():
            for inputs, targets in data_loader:
                inputs = inputs.to(self.device)
                outputs = self.model(inputs)
                _, preds = outputs.max(1)

                all_preds.extend(preds.cpu().numpy())
                all_targets.extend(targets.numpy())

        return np.array(all_preds), np.array(all_targets)

    def evaluate_metrics(self, data_loader, idx_to_class=None):
        """Evaluate the model using standard classification metrics"""
        preds, targets = self.predict(data_loader)

        # Calculate metrics
        acc = accuracy_score(targets, preds)
        prec = precision_score(targets, preds, average="weighted", zero_division=0)
        rec = recall_score(targets, preds, average="weighted", zero_division=0)
        f1 = f1_score(targets, preds, average="weighted", zero_division=0)

        # Create confusion matrix
        cm = confusion_matrix(targets, preds)

        # Print metrics
        print(f"Accuracy: {acc:.4f}")
        print(f"Precision: {prec:.4f}")
        print(f"Recall: {rec:.4f}")
        print(f"F1 Score: {f1:.4f}")

        # Plot confusion matrix if number of classes is reasonable
        if len(np.unique(targets)) <= 20:
            plt.figure(figsize=(10, 8))
            sns.heatmap(cm, annot=True, fmt="d", cmap="Blues")
            plt.xlabel("Predicted")
            plt.ylabel("True")
            plt.title("Confusion Matrix")
            plt.savefig(os.path.join(output_dir, "confusion_matrix.png"))
            plt.close()

        return {"accuracy": acc, "precision": prec, "recall": rec, "f1": f1}


def create_balanced_data_splits(
    train_dataset, val_dataset, test_dataset, labeled_samples_per_class, min_samples=2
):
    """Create labeled and unlabeled datasets for semi-supervised learning with balanced class distribution"""
    # Get all training samples by class
    train_samples_by_class = defaultdict(list)
    for idx in range(len(train_dataset)):
        _, label = train_dataset[idx]
        
        # Fix: Handle both tensor and int types for label
        if hasattr(label, 'item'):
            label = label.item()  # Convert tensor to int if it's a tensor
            
        train_samples_by_class[label].append(idx)

    # Select labeled samples
    labeled_indices = []
    unlabeled_indices = []

    for label, indices in train_samples_by_class.items():
        # Ensure we have enough samples for this class
        num_available = len(indices)
        num_to_use = min(labeled_samples_per_class, num_available)

        if num_to_use > 0:
            # Select random samples for labeled set
            selected_indices = random.sample(indices, num_to_use)
            labeled_indices.extend(selected_indices)

            # Remaining samples go to unlabeled set
            remaining_indices = [idx for idx in indices if idx not in selected_indices]
            unlabeled_indices.extend(remaining_indices)
        else:
            print(f"Warning: Class {label} has no samples")

    print(
        f"Labeled samples: {len(labeled_indices)}, Unlabeled samples: {len(unlabeled_indices)}"
    )

    # Create subset datasets
    labeled_dataset = Subset(train_dataset, labeled_indices)
    unlabeled_dataset = (
        Subset(train_dataset, unlabeled_indices) if len(unlabeled_indices) > 0 else None
    )

    return labeled_dataset, unlabeled_dataset, val_dataset, test_dataset

def run_sefoss_training(labeled_samples_per_class):
    """Run SeFOSS training with a specific number of labeled samples per class"""
    print(f"{'='*50}")
    print(f"Running SeFOSS with {labeled_samples_per_class} labeled samples per class")
    print(f"{'='*50}")

    # Load metadata
    train_df, test_df, val_df, class_to_idx, idx_to_class = load_metadata()

    # Create datasets
    train_dataset = AudioDataset(train_df, TRAIN_AUDIO_PATH, class_to_idx)
    test_dataset = AudioDataset(test_df, TEST_AUDIO_PATH, class_to_idx)
    val_dataset = AudioDataset(val_df, VAL_AUDIO_PATH, class_to_idx)

    # Create balanced data splits
    (
        labeled_dataset,
        unlabeled_dataset,
        val_dataset,
        test_dataset,
    ) = create_balanced_data_splits(
        train_dataset, val_dataset, test_dataset, labeled_samples_per_class
    )

    # Create data loaders
    labeled_loader = DataLoader(
        labeled_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=2,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )

    # Create unlabeled loader if we have unlabeled data
    if unlabeled_dataset is not None and len(unlabeled_dataset) > 0:
        unlabeled_loader = DataLoader(
            unlabeled_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=2,
            pin_memory=True,
        )
    else:
        unlabeled_loader = None
        print("Warning: No unlabeled data available")

    # Initialize SeFOSS
    sefoss = SeFOSS(num_classes=len(class_to_idx), feature_dim=1280, device=device)

    # Train model
    start_time = time.time()
    sefoss.train(labeled_loader, unlabeled_loader, val_loader, num_epochs=NUM_EPOCHS)
    training_time = time.time() - start_time
    print(f"Training completed in {training_time:.2f} seconds")

    # Evaluate on test set
    print("Evaluating on test set...")
    metrics = sefoss.evaluate_metrics(test_loader, idx_to_class)

    # Save metrics to CSV
    metrics_df = pd.DataFrame([metrics])
    metrics_df["labeled_samples_per_class"] = labeled_samples_per_class
    metrics_df["training_time"] = training_time
    metrics_df.to_csv(
        os.path.join(output_dir, f"metrics_{labeled_samples_per_class}.csv"),
        index=False,
    )

    # Return metrics
    return sefoss, metrics


def plot_comparison_results(results):
    """Plot comparison of results for different numbers of labeled samples"""
    samples = list(results.keys())
    accuracies = [results[s]["accuracy"] for s in samples]
    precisions = [results[s]["precision"] for s in samples]
    recalls = [results[s]["recall"] for s in samples]
    f1_scores = [results[s]["f1"] for s in samples]

    # Create dataframe for results
    df = pd.DataFrame(
        {
            "Labeled Samples Per Class": samples,
            "Accuracy": accuracies,
            "Precision": precisions,
            "Recall": recalls,
            "F1 Score": f1_scores,
        }
    )

    # Save to CSV
    df.to_csv(os.path.join(output_dir, "comparison_results.csv"), index=False)

    # Plot results
    plt.figure(figsize=(12, 8))

    plt.subplot(2, 2, 1)
    plt.plot(samples, accuracies, "o-", linewidth=2)
    plt.xlabel("Labeled Samples Per Class")
    plt.ylabel("Accuracy")
    plt.title("Accuracy vs Labeled Samples")
    plt.grid(True)

    plt.subplot(2, 2, 2)
    plt.plot(samples, precisions, "o-", linewidth=2)
    plt.xlabel("Labeled Samples Per Class")
    plt.ylabel("Precision")
    plt.title("Precision vs Labeled Samples")
    plt.grid(True)

    plt.subplot(2, 2, 3)
    plt.plot(samples, recalls, "o-", linewidth=2)
    plt.xlabel("Labeled Samples Per Class")
    plt.ylabel("Recall")
    plt.title("Recall vs Labeled Samples")
    plt.grid(True)

    plt.subplot(2, 2, 4)
    plt.plot(samples, f1_scores, "o-", linewidth=2)
    plt.xlabel("Labeled Samples Per Class")
    plt.ylabel("F1 Score")
    plt.title("F1 Score vs Labeled Samples")
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "comparison_results.png"))
    plt.close()

    return df


def analyze_model_performance(model, test_loader, idx_to_class):
    """Analyze model performance in detail"""
    model.eval()
    all_preds = []
    all_targets = []
    all_confidences = []

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device)
            outputs = model.model(inputs)
            probs = F.softmax(outputs, dim=1)
            confidences, preds = torch.max(probs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(targets.numpy())
            all_confidences.extend(confidences.cpu().numpy())

    # Convert to numpy arrays
    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)
    all_confidences = np.array(all_confidences)

    # Calculate per-class metrics
    classes = sorted(idx_to_class.keys())
    class_names = [idx_to_class[i] for i in classes]

    per_class_metrics = []
    for cls in classes:
        cls_indices = all_targets == cls
        if np.sum(cls_indices) > 0:
            cls_acc = np.mean(all_preds[cls_indices] == cls)
            cls_count = np.sum(cls_indices)
            cls_conf = np.mean(all_confidences[cls_indices])
            per_class_metrics.append(
                {
                    "class_id": cls,
                    "class_name": idx_to_class[cls],
                    "accuracy": cls_acc,
                    "count": cls_count,
                    "confidence": cls_conf,
                }
            )

    # Create dataframe and save
    per_class_df = pd.DataFrame(per_class_metrics)
    per_class_df.to_csv(os.path.join(output_dir, "per_class_metrics.csv"), index=False)

    # Plot per-class accuracy
    plt.figure(figsize=(12, 6))
    plt.bar(per_class_df["class_name"], per_class_df["accuracy"])
    plt.xlabel("Class")
    plt.ylabel("Accuracy")
    plt.title("Per-Class Accuracy")
    plt.xticks(rotation=90)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "per_class_accuracy.png"))
    plt.close()

    # Plot confidence distribution
    plt.figure(figsize=(10, 6))
    plt.hist(all_confidences, bins=20)
    plt.xlabel("Confidence")
    plt.ylabel("Count")
    plt.title("Model Confidence Distribution")
    plt.grid(True)
    plt.savefig(os.path.join(output_dir, "confidence_distribution.png"))
    plt.close()

    # Analyze mistakes - find examples with high confidence but wrong prediction
    high_conf_mistakes = []
    for i in range(len(all_preds)):
        if all_preds[i] != all_targets[i] and all_confidences[i] > 0.8:
            high_conf_mistakes.append(
                {
                    "true": idx_to_class[all_targets[i]],
                    "pred": idx_to_class[all_preds[i]],
                    "confidence": all_confidences[i],
                }
            )

    if high_conf_mistakes:
        mistakes_df = pd.DataFrame(high_conf_mistakes)
        mistakes_df.to_csv(
            os.path.join(output_dir, "high_confidence_mistakes.csv"), index=False
        )

    return per_class_df


def fix_sefoss_train_step(self, labeled_batch, unlabeled_batch):
    """A fixed version of sefoss_train_step to handle list inputs properly"""
    # Process labeled data
    labeled_inputs, labeled_targets = labeled_batch
    labeled_inputs, labeled_targets = labeled_inputs.to(
        self.device
    ), labeled_targets.to(self.device)

    # Forward pass on labeled data
    labeled_outputs = self.model(labeled_inputs)

    # Supervised loss
    l_s = self.criterion(labeled_outputs, labeled_targets)

    # Process unlabeled data if provided
    if (
        unlabeled_batch is not None
        and isinstance(unlabeled_batch, list)
        and len(unlabeled_batch) > 0
    ):
        unlabeled_inputs, _ = unlabeled_batch  # Ignoring the labels for unlabeled data

        # Handle the case when unlabeled_inputs is a list
        if isinstance(unlabeled_inputs, list):
            if len(unlabeled_inputs) == 0:
                # No unlabeled data available
                l_p = torch.tensor(0.0).to(self.device)
                l_e = torch.tensor(0.0).to(self.device)
                l_c = torch.tensor(0.0).to(self.device)
                l_u = torch.tensor(0.0).to(self.device)
            else:
                unlabeled_inputs = torch.stack(unlabeled_inputs).to(self.device)
                # Continue with regular processing
        else:
            unlabeled_inputs = unlabeled_inputs.to(self.device)

        # Only proceed if we have unlabeled data
        if isinstance(unlabeled_inputs, torch.Tensor) and unlabeled_inputs.size(0) > 0:
            # Apply weak augmentation for each sample
            weak_inputs = []
            for i in range(unlabeled_inputs.size(0)):
                weak_inputs.append(
                    self.weak_augmentation(unlabeled_inputs[i].unsqueeze(0))
                )
            weak_inputs = torch.cat(weak_inputs, dim=0)

            # Apply strong augmentation for each sample
            strong_inputs = []
            for i in range(unlabeled_inputs.size(0)):
                strong_inputs.append(
                    self.strong_augmentation(unlabeled_inputs[i].unsqueeze(0))
                )
            strong_inputs = torch.cat(strong_inputs, dim=0)

            # Get predictions with weak augmentation
            self.model.eval()
            with torch.no_grad():
                weak_outputs = self.model(weak_inputs)
                weak_probs = F.softmax(weak_outputs, dim=1)
                weak_confidences, pseudo_labels = torch.max(weak_probs, dim=1)

            # Switch back to training mode
            self.model.train()

            # Get predictions with strong augmentation
            strong_outputs = self.model(strong_inputs)
            strong_probs = F.softmax(strong_outputs, dim=1)

            # Compute entropy
            entropy = -torch.sum(strong_probs * torch.log(strong_probs + 1e-10), dim=1)
            mean_entropy = torch.mean(entropy)

            # Create masks based on confidence thresholds
            mask_d = weak_confidences >= self.tau_d
            mask_ood = (weak_confidences < self.tau_d) & (
                weak_confidences >= self.tau_ood
            )
            mask_nood = weak_confidences < self.tau_ood

            # Pseudo-label loss
            if torch.sum(mask_d) > 0:
                l_p = F.cross_entropy(strong_outputs[mask_d], pseudo_labels[mask_d])
            else:
                l_p = torch.tensor(0.0).to(self.device)

            # Entropy loss (minimizing entropy for uncertain samples)
            if torch.sum(mask_ood) > 0:
                l_e = torch.mean(entropy[mask_ood])
            else:
                l_e = torch.tensor(0.0).to(self.device)

            # Consistency loss
            if torch.sum(mask_d) > 0:
                l_c = F.mse_loss(strong_probs[mask_d], weak_probs[mask_d])
            else:
                l_c = torch.tensor(0.0).to(self.device)

            # Uncertainty loss
            if torch.sum(mask_nood) > 0:
                # Maximize entropy for noisy/OOD samples
                l_u = -torch.mean(entropy[mask_nood])
            else:
                l_u = torch.tensor(0.0).to(self.device)
        else:
            # No valid unlabeled data tensor
            l_p = torch.tensor(0.0).to(self.device)
            l_e = torch.tensor(0.0).to(self.device)
            l_c = torch.tensor(0.0).to(self.device)
            l_u = torch.tensor(0.0).to(self.device)
    else:
        # No unlabeled batch provided
        l_p = torch.tensor(0.0).to(self.device)
        l_e = torch.tensor(0.0).to(self.device)
        l_c = torch.tensor(0.0).to(self.device)
        l_u = torch.tensor(0.0).to(self.device)

    # Compute total loss with adaptive weights
    total_loss = (
        self.w_s * l_s
        + self.w_p * l_p
        + self.w_e * l_e
        + self.w_c * l_c
        + self.w_u * l_u
    )

    return total_loss


# Fix the SeFOSS class by monkey-patching the problematic method
SeFOSS.sefoss_train_step = fix_sefoss_train_step


def main_fixed():
    """Main function to run all experiments with fixed error handling"""
    print("Starting SeFOSS experiments on FSDKaggle2018 dataset...")

    # Create output directory
    global output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Run experiments with different numbers of labeled samples
    labeled_samples_sizes = [40, 100, 200, 400]
    results = {}

    for size in labeled_samples_sizes:
        print(f"{'='*50}")
        print(f"Running experiment with {size} labeled samples per class")
        print(f"{'='*50}")

        try:
            # Create directory for this experiment
            exp_dir = os.path.join(output_dir, f"samples_{size}")
            os.makedirs(exp_dir, exist_ok=True)

            # Temporarily change output directory
            
            original_output_dir = output_dir
            output_dir = exp_dir

            # Run SeFOSS training
            sefoss, metrics = run_sefoss_training(labeled_samples_per_class=size)
            results[size] = metrics

            # Load metadata for analysis
            _, _, _, class_to_idx, idx_to_class = load_metadata()

            # Create test loader for analysis
            test_df = pd.read_csv(TEST_CSV_PATH)
            test_dataset = AudioDataset(test_df, TEST_AUDIO_PATH, class_to_idx)
            test_loader = DataLoader(
                test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2
            )

            # Analyze model performance
            analyze_model_performance(sefoss, test_loader, idx_to_class)

            # Reset output directory
            output_dir = original_output_dir

        except Exception as e:
            print(f"Error in experiment with {size} labeled samples per class: {e}")
            import traceback

            traceback.print_exc()
            # Reset output directory
            output_dir = original_output_dir
            # Store None for this experiment
            results[size] = {
                "accuracy": 0.0,
                "precision": 0.0,
                "recall": 0.0,
                "f1": 0.0,
            }

    # Plot comparison of results
    comparison_df = plot_comparison_results(results)
    print("Experiment completed successfully!")
    print(comparison_df)

    return results


if __name__ == "__main__":
    main_fixed()


Using device: cuda
Starting SeFOSS experiments on FSDKaggle2018 dataset...
Running experiment with 40 labeled samples per class
Running SeFOSS with 40 labeled samples per class
Total classes: 41
Found 1000 valid audio files in /kaggle/input/filtereddata/audio/train
Found 1000 valid audio files in /kaggle/input/filtereddata/audio/test
Found 1000 valid audio files in /kaggle/input/filtereddata/audio/val


Downloading: "https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-7ebf99e0.pth


Labeled samples: 996, Unlabeled samples: 4


100%|██████████| 13.6M/13.6M [00:00<00:00, 122MB/s]


Starting SeFOSS training...
Pretraining model...
Pretrain Epoch 1/5 | Train Loss: 3.3218 | Train Acc: 13.86% | Val Loss: 2.8051 | Val Acc: 24.40%
Pretrain Epoch 2/5 | Train Loss: 2.2686 | Train Acc: 36.65% | Val Loss: 2.2851 | Val Acc: 32.80%
Pretrain Epoch 3/5 | Train Loss: 1.6716 | Train Acc: 51.51% | Val Loss: 1.9732 | Val Acc: 47.60%
Pretrain Epoch 4/5 | Train Loss: 1.1831 | Train Acc: 67.37% | Val Loss: 1.7918 | Val Acc: 52.40%
Pretrain Epoch 5/5 | Train Loss: 0.8190 | Train Acc: 77.71% | Val Loss: 1.7824 | Val Acc: 53.30%
Pretraining completed!
Computing confidence thresholds...
Thresholds - tau_d: 0.9614, tau_ood: 0.7395, tau_nood: 0.4651
Epoch 1/50 | Train Loss: -0.7875 | Train Acc: 74.70% | Val Loss: 2.1447 | Val Acc: 45.70%
Epoch 2/50 | Train Loss: -1.1925 | Train Acc: 84.34% | Val Loss: 1.9823 | Val Acc: 46.30%
Epoch 3/50 | Train Loss: -1.3707 | Train Acc: 90.26% | Val Loss: 1.8705 | Val Acc: 52.70%
Epoch 4/50 | Train Loss: -1.5291 | Train Acc: 86.45% | Val Loss: 2.1428 | Va

Traceback (most recent call last):
  File "/tmp/ipykernel_31/444710727.py", line 1196, in main_fixed
    analyze_model_performance(sefoss, test_loader, idx_to_class)
  File "/tmp/ipykernel_31/444710727.py", line 934, in analyze_model_performance
    model.eval()
    ^^^^^^^^^^
AttributeError: 'SeFOSS' object has no attribute 'eval'


Found 1000 valid audio files in /kaggle/input/filtereddata/audio/test
Found 1000 valid audio files in /kaggle/input/filtereddata/audio/val
Labeled samples: 1000, Unlabeled samples: 0
Starting SeFOSS training...
Pretraining model...
Pretrain Epoch 1/5 | Train Loss: 3.3533 | Train Acc: 13.80% | Val Loss: 2.8655 | Val Acc: 22.00%
Pretrain Epoch 2/5 | Train Loss: 2.3792 | Train Acc: 34.40% | Val Loss: 2.5738 | Val Acc: 27.50%
Pretrain Epoch 3/5 | Train Loss: 1.6552 | Train Acc: 52.20% | Val Loss: 2.0045 | Val Acc: 43.00%
Pretrain Epoch 4/5 | Train Loss: 1.1209 | Train Acc: 68.70% | Val Loss: 1.7053 | Val Acc: 53.50%
Pretrain Epoch 5/5 | Train Loss: 0.8409 | Train Acc: 76.70% | Val Loss: 1.7742 | Val Acc: 52.00%
Pretraining completed!
Computing confidence thresholds...
Thresholds - tau_d: 0.8964, tau_ood: 0.5137, tau_nood: 0.2977
Epoch 1/50 | Train Loss: 0.8144 | Train Acc: 91.50% | Val Loss: 1.7109 | Val Acc: 54.40%
Epoch 2/50 | Train Loss: 0.5404 | Train Acc: 95.70% | Val Loss: 1.6488 | V

Traceback (most recent call last):
  File "/tmp/ipykernel_31/444710727.py", line 1196, in main_fixed
    analyze_model_performance(sefoss, test_loader, idx_to_class)
  File "/tmp/ipykernel_31/444710727.py", line 934, in analyze_model_performance
    model.eval()
    ^^^^^^^^^^
AttributeError: 'SeFOSS' object has no attribute 'eval'


Found 1000 valid audio files in /kaggle/input/filtereddata/audio/val
Labeled samples: 1000, Unlabeled samples: 0
Starting SeFOSS training...
Pretraining model...
Pretrain Epoch 1/5 | Train Loss: 3.2625 | Train Acc: 15.50% | Val Loss: 2.7426 | Val Acc: 24.50%
Pretrain Epoch 2/5 | Train Loss: 2.1639 | Train Acc: 40.90% | Val Loss: 2.1217 | Val Acc: 40.70%
Pretrain Epoch 3/5 | Train Loss: 1.4518 | Train Acc: 60.30% | Val Loss: 1.9217 | Val Acc: 47.20%
Pretrain Epoch 4/5 | Train Loss: 1.0148 | Train Acc: 71.50% | Val Loss: 1.7180 | Val Acc: 53.20%
Pretrain Epoch 5/5 | Train Loss: 0.7385 | Train Acc: 79.90% | Val Loss: 1.7578 | Val Acc: 53.50%
Pretraining completed!
Computing confidence thresholds...
Thresholds - tau_d: 0.9312, tau_ood: 0.6220, tau_nood: 0.3302
Epoch 1/50 | Train Loss: 0.6204 | Train Acc: 93.60% | Val Loss: 1.6629 | Val Acc: 56.30%
Epoch 2/50 | Train Loss: 0.4455 | Train Acc: 96.70% | Val Loss: 1.6728 | Val Acc: 56.50%
Epoch 3/50 | Train Loss: 0.3242 | Train Acc: 98.30% | V

Traceback (most recent call last):
  File "/tmp/ipykernel_31/444710727.py", line 1196, in main_fixed
    analyze_model_performance(sefoss, test_loader, idx_to_class)
  File "/tmp/ipykernel_31/444710727.py", line 934, in analyze_model_performance
    model.eval()
    ^^^^^^^^^^
AttributeError: 'SeFOSS' object has no attribute 'eval'


Found 1000 valid audio files in /kaggle/input/filtereddata/audio/val
Labeled samples: 1000, Unlabeled samples: 0
Starting SeFOSS training...
Pretraining model...
Pretrain Epoch 1/5 | Train Loss: 3.2326 | Train Acc: 14.90% | Val Loss: 2.7691 | Val Acc: 27.20%
Pretrain Epoch 2/5 | Train Loss: 2.2021 | Train Acc: 40.60% | Val Loss: 2.1556 | Val Acc: 37.20%
Pretrain Epoch 3/5 | Train Loss: 1.4844 | Train Acc: 59.20% | Val Loss: 1.9435 | Val Acc: 49.60%
Pretrain Epoch 4/5 | Train Loss: 1.0409 | Train Acc: 69.70% | Val Loss: 1.8418 | Val Acc: 50.10%
Pretrain Epoch 5/5 | Train Loss: 0.6973 | Train Acc: 81.10% | Val Loss: 1.7808 | Val Acc: 54.90%
Pretraining completed!
Computing confidence thresholds...
Thresholds - tau_d: 0.9765, tau_ood: 0.7561, tau_nood: 0.4499
Epoch 1/50 | Train Loss: 0.4730 | Train Acc: 94.60% | Val Loss: 1.7766 | Val Acc: 56.10%
Epoch 2/50 | Train Loss: 0.3361 | Train Acc: 95.20% | Val Loss: 1.9029 | Val Acc: 56.00%
Epoch 3/50 | Train Loss: 0.2586 | Train Acc: 96.90% | V

Traceback (most recent call last):
  File "/tmp/ipykernel_31/444710727.py", line 1196, in main_fixed
    analyze_model_performance(sefoss, test_loader, idx_to_class)
  File "/tmp/ipykernel_31/444710727.py", line 934, in analyze_model_performance
    model.eval()
    ^^^^^^^^^^
AttributeError: 'SeFOSS' object has no attribute 'eval'


Experiment completed successfully!
   Labeled Samples Per Class  Accuracy  Precision  Recall  F1 Score
0                         40       0.0        0.0     0.0       0.0
1                        100       0.0        0.0     0.0       0.0
2                        200       0.0        0.0     0.0       0.0
3                        400       0.0        0.0     0.0       0.0


In [None]:
import os
import zipfile

# Define the directory to zip and the output zip file path
dir_to_zip = "/kaggle/working/SEFOSS_AudioData"
zip_path = "/kaggle/working/SEFOSS_AudioData.zip"

# Create a zip file
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
    for root, dirs, files in os.walk(dir_to_zip):
        for file in files:
            file_path = os.path.join(root, file)
            # Add file to zip, maintaining relative path
            zipf.write(file_path, os.path.relpath(file_path, dir_to_zip))

print(f"Zipped output directory to: {zip_path}")
