In [8]:
# ==========================================
# 1. IMPORTS
# ==========================================
import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import librosa

# Pick the compute device: GPU when CUDA is usable, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"üñ•Ô∏è Running on: {device}")

# ==========================================
# 2. CONFIGURATION
# ==========================================

# --- Inputs ---------------------------------------------------------------
# Weights file of the 58-class speaker-identification model.
CHECKPOINT_PATH = "checkpoints/train43/best_model.pt"

# Directory that is scanned for *.wav recordings to evaluate.
TEST_AUDIO_FOLDER = r"Recordings_1/Aleksander"

# --- Network hyper-parameters ---------------------------------------------
# These size the layers that the checkpoint is loaded into, so they have to
# agree with the values used when the checkpoint was trained.
NUM_SPEAKERS = 58
EMBED_DIM = 256
N_MELS = 64

# --- Decision rule --------------------------------------------------------
# Class index the target speaker (Aleksander) is expected to receive.
# In the code of this cell it is only echoed in the run banner.
EXPECTED_ID = 1

# Speaker class indices that count as "Member". Per the author's training
# logs: Aleksander(1), Mantas(27), Michal(29), Piotr(38), Rafal(40).
IN_GROUP_IDS = [1, 27, 29, 38, 40]

# Minimum averaged confidence required before a "Member" prediction is accepted.
CONFIDENCE_THRESHOLD = 0.60

# ==========================================
# 3. MODEL ARCHITECTURE
# ==========================================

class SEBlock(nn.Module):
    """Squeeze-and-excitation gate that rescales each channel of a 4-D feature map.

    The two trailing axes are averaged away ("squeeze"), the per-channel means
    go through a bias-free bottleneck MLP ending in a sigmoid ("excitation"),
    and the resulting weights in (0, 1) multiply the input channel-wise.

    Args:
        channels: Number of channels (axis 1) of the input.
        reduction: Bottleneck divisor; the hidden width is
            ``channels // reduction`` but never less than 4.
    """

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        hidden = max(channels // reduction, 4)
        # Layer order/indices are part of the checkpoint key names (fc.0, fc.2).
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled per channel; ``x`` must be 4-D (batch, channels, *, *)."""
        # Unpack all four axes so a non-4-D input still fails loudly, but do not
        # bind a local called ``F``: that would shadow torch.nn.functional.
        batch, channels, _, _ = x.shape
        squeezed = x.mean(dim=(2, 3))
        gate = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gate

class Backbone(nn.Module):
    """CNN -> GRU -> attention-weighted mean/std pooling encoder.

    Maps a batch of spectrogram-like inputs of shape
    ``(batch, 1, time, no_mels)`` to L2-normalised embeddings of shape
    ``(batch, embed_dim)``.

    Args:
        no_mels: Size of the last input axis. Three ``(1, 2)`` max-pools shrink
            it by a factor of 8 before the GRU.
        embed_dim: Width of the returned embedding.
        rnn_hidden: GRU hidden size per direction.
        rnn_layers: Number of stacked GRU layers.
        bidir: Whether the GRU is bidirectional.
    """

    def __init__(self, no_mels, embed_dim, rnn_hidden, rnn_layers, bidir):
        super().__init__()
        # Pooling kernels are (1, 2): only the last axis is halved, the time
        # axis keeps its length. Module order defines the checkpoint key names.
        self.cnn_block = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            SEBlock(32, reduction=8), nn.MaxPool2d(kernel_size=(1, 2)),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            SEBlock(64, reduction=8), nn.MaxPool2d(kernel_size=(1, 2)),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            SEBlock(128, reduction=8), nn.MaxPool2d(kernel_size=(1, 2)),
        )
        self.rnn_hidden = rnn_hidden
        self.rnn = nn.GRU(
            input_size=128 * (no_mels // 8),
            hidden_size=self.rnn_hidden,
            num_layers=rnn_layers,
            bidirectional=bidir,
            batch_first=True,
            # GRU dropout acts only between stacked layers; requesting it for a
            # single layer has no effect and makes PyTorch emit a warning.
            dropout=0.2 if rnn_layers > 1 else 0.0,
        )
        out_dim = (2 if bidir else 1) * rnn_hidden
        self.rnn_ln = nn.LayerNorm(out_dim)
        self.att = nn.Sequential(nn.Linear(out_dim, 128), nn.Tanh(), nn.Linear(128, 1))
        # Input is mean and std concatenated, hence out_dim * 2.
        self.proj = nn.Sequential(
            nn.Linear(out_dim * 2, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Linear(256, embed_dim)
        )

    def forward(self, x):
        """Encode ``x`` and return unit-length embeddings of shape (batch, embed_dim)."""
        h = self.cnn_block(x)

        # The frequency axis is called 'Freq' (not 'F') so that the module
        # alias F = torch.nn.functional stays usable below.
        B, C, T, Freq = h.shape
        # (B, C, T, Freq) -> (B, T, C * Freq): one feature vector per time step.
        h = h.permute(0, 2, 1, 3).contiguous().view(B, T, C * Freq)

        rnn_out, _ = self.rnn(h)
        rnn_out = self.rnn_ln(rnn_out)

        # Attention weights over the time axis (they sum to 1 per utterance).
        a = self.att(rnn_out).squeeze(-1)
        w = torch.softmax(a, dim=1).unsqueeze(-1)

        # Attention-weighted mean and standard deviation over time.
        mean = torch.sum(w * rnn_out, dim=1)
        var = torch.sum(w * (rnn_out - mean.unsqueeze(1)) ** 2, dim=1)
        std = torch.sqrt(var + 1e-5)  # epsilon keeps sqrt well-behaved at var == 0
        stats = torch.cat([mean, std], 1)
        z = self.proj(stats)

        return F.normalize(z, p=2, dim=1)

class AAMSoftmax(nn.Module):
    """Scaled, weight-normalised linear head (inference form).

    Holds one weight row per class. ``forward`` L2-normalises the rows and
    returns ``s`` times their dot product with the embeddings, which is the
    scaled cosine similarity when the embeddings are unit-length. The margin
    ``m`` is stored but not applied anywhere in this class.

    Args:
        in_features: Embedding width.
        out_features: Number of classes.
        s: Scale applied to the similarities.
        m: Margin value (kept as an attribute only).
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.20):
        super().__init__()
        self.s, self.m = s, m
        class_rows = torch.empty(out_features, in_features)
        self.weight = nn.Parameter(class_rows)
        nn.init.xavier_uniform_(self.weight)

    def forward(self, emb):
        """Return logits of shape (batch, out_features) for embeddings ``emb``."""
        unit_rows = F.normalize(self.weight, dim=1)
        similarity = torch.matmul(emb, unit_rows.t())
        return similarity * self.s

class SpeakerClassifier(nn.Module):
    """Embedding backbone followed by an AAMSoftmax head.

    Args:
        backbone: Module turning an input batch into embeddings. The head is
            built for 256 input features, so the embeddings must be 256-wide.
        num_speakers: Number of output classes.
    """

    def __init__(self, backbone, num_speakers):
        super().__init__()
        # Registration order (backbone, then aamsm) fixes the parameter order.
        self.backbone = backbone
        head = AAMSoftmax(256, num_speakers)
        self.aamsm = head

    def forward(self, x):
        """Return the head's per-speaker logits for input batch ``x``."""
        return self.aamsm(self.backbone(x))

# ==========================================
# 4. PREPROCESSING (MATCHING TRAINING)
# ==========================================

def preprocess_file(file_path):
    """Turn one audio file into a batch of log-mel chunks.

    Steps: load as 16 kHz mono, trim leading/trailing silence, peak-normalise,
    cut into 3 s windows with a 2 s stride (a shorter signal is zero-padded
    to a single window), then compute a log-mel spectrogram per window.
    The author's notes state that the mel parameters mirror the training
    pipeline (prepare_h5.ipynb); that cannot be verified from this file.

    Args:
        file_path: Path of the audio file to read.

    Returns:
        A float32 tensor of shape (num_chunks, 1, frames, N_MELS), or None
        when fewer than 1000 samples remain after trimming or when any step
        raises (the error is printed, not propagated).
    """
    sample_rate = 16000
    window = int(3.0 * sample_rate)
    hop = int(2.0 * sample_rate)

    try:
        signal, _ = librosa.load(file_path, sr=sample_rate, mono=True)
        signal, _ = librosa.effects.trim(signal, top_db=20)

        # Too little audio left to be worth scoring.
        if len(signal) < 1000:
            return None

        # Peak normalisation; the epsilon guards against an all-zero signal.
        signal = signal / (np.max(np.abs(signal)) + 1e-9)

        if len(signal) < window:
            # Short recording: right-pad with zeros up to one full window.
            segments = [np.pad(signal, (0, window - len(signal)))]
        else:
            # Only full windows are kept; a trailing remainder is dropped.
            last_start = len(signal) - window
            segments = [
                signal[start : start + window]
                for start in range(0, last_start + 1, hop)
            ]

        features = []
        for segment in segments:
            power_mel = librosa.feature.melspectrogram(
                y=segment,
                sr=sample_rate,
                n_fft=2048,
                hop_length=512,
                n_mels=N_MELS,
            )
            # dB scale relative to this chunk's own maximum, then [Time, Freq].
            features.append(librosa.power_to_db(power_mel, ref=np.max).T)

        stacked = np.array(features)
        return torch.tensor(stacked, dtype=torch.float32).unsqueeze(1)

    except Exception as e:
        print(f"Error {file_path}: {e}")
        return None

# ==========================================
# 5. EXECUTION
# ==========================================

# Build the network from the hyper-parameters configured above.
print(f"‚è≥ Loading Model ({NUM_SPEAKERS} Classes)...")
backbone = Backbone(no_mels=N_MELS, embed_dim=EMBED_DIM, rnn_hidden=256, rnn_layers=2, bidir=True)
model = SpeakerClassifier(backbone, num_speakers=NUM_SPEAKERS)

if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint not found: {CHECKPOINT_PATH}")

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source (or pass weights_only=True
# on PyTorch versions that support it).
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)

# strict=False tolerates key mismatches, but a mismatch means those layers keep
# their random initial weights. Report it instead of claiming a clean load.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"WARNING: {len(load_result.missing_keys)} model parameter(s) not found in checkpoint "
          f"(left at random init): {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"WARNING: {len(load_result.unexpected_keys)} checkpoint key(s) ignored: "
          f"{load_result.unexpected_keys}")
model.to(device)
model.eval()
print("‚úÖ Model loaded successfully!")

# Only *.wav files are scored; sorting makes the report order reproducible
# (glob returns files in arbitrary, OS-dependent order).
audio_files = sorted(glob.glob(os.path.join(TEST_AUDIO_FOLDER, "*.wav")))

print(f"üìÇ Found {len(audio_files)} WAV files in '{TEST_AUDIO_FOLDER}'")
print(f"üéØ Target Class: {EXPECTED_ID} (Member)\n")

print(f"{'FILENAME':<40} | {'PREDICTION':<15} | {'CONFIDENCE'} | {'ID'} | {'STATUS'}")
print("-" * 95)

results = []
correct_members = 0
correct_outsiders = 0
total_members = 0
total_outsiders = 0

# Ground truth: file names in the test folder that are genuinely Aleksander
# (per the author's notes). Every other file is treated as an outsider.
TRUE_MEMBER_FILES = [
    "Alexander-aleksander.wav",
    "Gallic-Wars-Aleksander.wav",
    "Napoleon-aleksander.wav",
    "prince-Aleksander.wav"
]

with torch.no_grad():
    for file_path in audio_files:
        batch = preprocess_file(file_path)
        if batch is None:
            continue

        batch = batch.to(device)
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)

        # Each chunk votes for its most probable class; the majority wins.
        chunk_ids = torch.argmax(probs, dim=1)
        votes = chunk_ids.cpu().tolist()
        pred_id = max(set(votes), key=votes.count)

        # Confidence = mean probability assigned to the winning class over ALL
        # chunks. Averaging each chunk's own top probability would let chunks
        # that voted for a different class inflate the score.
        avg_conf = probs[:, pred_id].mean().item()

        # --- BINARY LOGIC ---
        if pred_id in IN_GROUP_IDS:
            if avg_conf >= CONFIDENCE_THRESHOLD:
                label = "‚úÖ MEMBER"
                pred_class = 1  # access granted
            else:
                label = "‚ùì LOW CONF (1)"
                pred_class = 0  # reject when unsure
        else:
            label = "‚ùå OUTSIDER"
            pred_class = 0  # access denied

        fname = os.path.basename(file_path)

        # --- GROUND TRUTH CHECK ---
        is_actually_member = fname in TRUE_MEMBER_FILES

        if is_actually_member:
            total_members += 1
            if pred_class == 1:
                correct_members += 1
                status = "MATCH"
            else:
                status = "MISS"  # a genuine member was rejected
        else:
            total_outsiders += 1
            if pred_class == 0:
                correct_outsiders += 1
                status = "MATCH"
            else:
                status = "FALSE ACCEPT"  # an outsider was let in

        print(f"{fname[:38]:<40} | {label:<15} | {avg_conf:.1%}      | {pred_id:<2} | {status}")

        results.append({
            "file": fname,
            "prediction": label,
            "raw_id": pred_id,
            "confidence": avg_conf
        })

# Final Report
print("\n" + "="*40)
print(f"üìä ACCURACY REPORT")
print("="*40)
if total_members > 0:
    print(f"üë§ Members (Aleksander): {correct_members}/{total_members} ({(correct_members/total_members)*100:.1f}%)")
else:
    print("üë§ Members (Aleksander): 0/0 (No files found)")

if total_outsiders > 0:
    print(f"üö´ Outsiders (Imposters): {correct_outsiders}/{total_outsiders} ({(correct_outsiders/total_outsiders)*100:.1f}%)")
else:
    print("üö´ Outsiders (Imposters): 0/0 (No files found)")
print("="*40)

üñ•Ô∏è Running on: cuda
‚è≥ Loading Model (58 Classes)...
‚úÖ Model loaded successfully!
üìÇ Found 31 WAV files in 'Recordings_1/Aleksander'
üéØ Target Class: 1 (Member)

FILENAME                                 | PREDICTION      | CONFIDENCE | ID | STATUS
-----------------------------------------------------------------------------------------------
adi.wav                                  | ‚ùå OUTSIDER      | 100.0%      | 0  | MATCH
Alexander-aleksander.wav                 | ‚úÖ MEMBER        | 100.0%      | 1  | MATCH
churchill-1.wav                          | ‚ùå OUTSIDER      | 99.5%      | 8  | MATCH
fdr.wav                                  | ‚ùå OUTSIDER      | 99.4%      | 11 | MATCH
Gallic-Wars-Aleksander.wav               | ‚úÖ MEMBER        | 100.0%      | 1  | MATCH
gatsby-ania.wav                          | ‚ùå OUTSIDER      | 100.0%      | 3  | MATCH
grian-1.wav                              | ‚ùå OUTSIDER      | 98.8%      | 13 | MATCH
grian-2.wav                    

In [9]:
# ==========================================
# 1. IMPORTS
# ==========================================
import os
import glob
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import librosa

# Pick the compute device: GPU when CUDA is usable, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"üñ•Ô∏è Running on: {device}")

# ==========================================
# 2. CONFIGURATION
# ==========================================

# --- Inputs ---------------------------------------------------------------
# Weights file of the binary (member / outsider) model.
CHECKPOINT_PATH = "checkpoints/train31/best_model.pt"

# Directory that is scanned for *.wav recordings to evaluate.
TEST_AUDIO_FOLDER = r"Recordings_1/Aleksander"

# --- Network hyper-parameters ---------------------------------------------
# These size the layers that the checkpoint is loaded into, so they have to
# agree with the values used when the checkpoint was trained.
N_MELS = 64
EMBED_DIM = 256
NUM_SPEAKERS = 2  # binary head: two output classes

# --- Decision rule --------------------------------------------------------
# Output class that stands for "Member"; Aleksander's files should get it.
EXPECTED_CLASS = 1

# Minimum averaged confidence for a confident "Member" label.
CONFIDENCE_THRESHOLD = 0.50

# ==========================================
# 3. MODEL ARCHITECTURE
# ==========================================

class SEBlock(nn.Module):
    """Squeeze-and-excitation gate that rescales each channel of a 4-D feature map.

    The two trailing axes are averaged away ("squeeze"), the per-channel means
    go through a bias-free bottleneck MLP ending in a sigmoid ("excitation"),
    and the resulting weights in (0, 1) multiply the input channel-wise.

    Args:
        channels: Number of channels (axis 1) of the input.
        reduction: Bottleneck divisor; the hidden width is
            ``channels // reduction`` but never less than 4.
    """

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        hidden = max(channels // reduction, 4)
        # Layer order/indices are part of the checkpoint key names (fc.0, fc.2).
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled per channel; ``x`` must be 4-D (batch, channels, *, *)."""
        # Unpack all four axes so a non-4-D input still fails loudly, but do not
        # bind a local called ``F``: that would shadow torch.nn.functional.
        batch, channels, _, _ = x.shape
        squeezed = x.mean(dim=(2, 3))
        gate = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gate

class Backbone(nn.Module):
    """CNN -> GRU -> attention-weighted mean/std pooling encoder.

    Maps a batch of spectrogram-like inputs of shape
    ``(batch, 1, time, no_mels)`` to L2-normalised embeddings of shape
    ``(batch, embed_dim)``.

    Args:
        no_mels: Size of the last input axis. Three ``(1, 2)`` max-pools shrink
            it by a factor of 8 before the GRU.
        embed_dim: Width of the returned embedding.
        rnn_hidden: GRU hidden size per direction.
        rnn_layers: Number of stacked GRU layers.
        bidir: Whether the GRU is bidirectional.
    """

    def __init__(self, no_mels, embed_dim, rnn_hidden, rnn_layers, bidir):
        super().__init__()
        # Pooling kernels are (1, 2): only the last axis is halved, the time
        # axis keeps its length. Module order defines the checkpoint key names.
        self.cnn_block = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
            SEBlock(32, reduction=8), nn.MaxPool2d(kernel_size=(1, 2)),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
            SEBlock(64, reduction=8), nn.MaxPool2d(kernel_size=(1, 2)),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
            SEBlock(128, reduction=8), nn.MaxPool2d(kernel_size=(1, 2)),
        )
        self.rnn_hidden = rnn_hidden
        self.rnn = nn.GRU(
            input_size=128 * (no_mels // 8),
            hidden_size=self.rnn_hidden,
            num_layers=rnn_layers,
            bidirectional=bidir,
            batch_first=True,
            # GRU dropout acts only between stacked layers; requesting it for a
            # single layer has no effect and makes PyTorch emit a warning.
            dropout=0.2 if rnn_layers > 1 else 0.0,
        )
        out_dim = (2 if bidir else 1) * rnn_hidden
        self.rnn_ln = nn.LayerNorm(out_dim)
        self.att = nn.Sequential(nn.Linear(out_dim, 128), nn.Tanh(), nn.Linear(128, 1))
        # Input is mean and std concatenated, hence out_dim * 2.
        self.proj = nn.Sequential(
            nn.Linear(out_dim * 2, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Linear(256, embed_dim)
        )

    def forward(self, x):
        """Encode ``x`` and return unit-length embeddings of shape (batch, embed_dim)."""
        h = self.cnn_block(x)

        # The frequency axis is called 'Freq' (not 'F') so that the module
        # alias F = torch.nn.functional stays usable below.
        B, C, T, Freq = h.shape
        # (B, C, T, Freq) -> (B, T, C * Freq): one feature vector per time step.
        h = h.permute(0, 2, 1, 3).contiguous().view(B, T, C * Freq)

        rnn_out, _ = self.rnn(h)
        rnn_out = self.rnn_ln(rnn_out)

        # Attention weights over the time axis (they sum to 1 per utterance).
        a = self.att(rnn_out).squeeze(-1)
        w = torch.softmax(a, dim=1).unsqueeze(-1)

        # Attention-weighted mean and standard deviation over time.
        mean = torch.sum(w * rnn_out, dim=1)
        var = torch.sum(w * (rnn_out - mean.unsqueeze(1)) ** 2, dim=1)
        std = torch.sqrt(var + 1e-5)  # epsilon keeps sqrt well-behaved at var == 0
        stats = torch.cat([mean, std], 1)
        z = self.proj(stats)

        return F.normalize(z, p=2, dim=1)

class AAMSoftmax(nn.Module):
    """Scaled, weight-normalised linear head (inference form).

    Holds one weight row per class. ``forward`` L2-normalises the rows and
    returns ``s`` times their dot product with the embeddings, which is the
    scaled cosine similarity when the embeddings are unit-length. The margin
    ``m`` is stored but not applied anywhere in this class.

    Args:
        in_features: Embedding width.
        out_features: Number of classes.
        s: Scale applied to the similarities.
        m: Margin value (kept as an attribute only).
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.20):
        super().__init__()
        self.s, self.m = s, m
        class_rows = torch.empty(out_features, in_features)
        self.weight = nn.Parameter(class_rows)
        nn.init.xavier_uniform_(self.weight)

    def forward(self, emb):
        """Return logits of shape (batch, out_features) for embeddings ``emb``."""
        unit_rows = F.normalize(self.weight, dim=1)
        similarity = torch.matmul(emb, unit_rows.t())
        return similarity * self.s

class SpeakerClassifier(nn.Module):
    """Embedding backbone followed by an AAMSoftmax head.

    Args:
        backbone: Module turning an input batch into embeddings. The head is
            built for 256 input features, so the embeddings must be 256-wide.
        num_speakers: Number of output classes.
    """

    def __init__(self, backbone, num_speakers):
        super().__init__()
        # Registration order (backbone, then aamsm) fixes the parameter order.
        self.backbone = backbone
        head = AAMSoftmax(256, num_speakers)
        self.aamsm = head

    def forward(self, x):
        """Return the head's per-speaker logits for input batch ``x``."""
        return self.aamsm(self.backbone(x))

# ==========================================
# 4. PREPROCESSING (FIXED)
# ==========================================

def preprocess_file(file_path):
    """Turn one audio file into a batch of log-mel chunks.

    Steps: load as 16 kHz mono, trim leading/trailing silence, peak-normalise,
    cut into 3 s windows with a 2 s stride (a shorter signal is zero-padded
    to a single window), then compute a log-mel spectrogram per window.
    The author's notes state that the mel parameters mirror the training
    pipeline (prepare_h5.ipynb); that cannot be verified from this file.

    Args:
        file_path: Path of the audio file to read.

    Returns:
        A float32 tensor of shape (num_chunks, 1, frames, N_MELS), or None
        when fewer than 1000 samples remain after trimming or when any step
        raises (the error is printed, not propagated).
    """
    sample_rate = 16000
    window = int(3.0 * sample_rate)
    hop = int(2.0 * sample_rate)

    try:
        signal, _ = librosa.load(file_path, sr=sample_rate, mono=True)
        signal, _ = librosa.effects.trim(signal, top_db=20)

        # Too little audio left to be worth scoring.
        if len(signal) < 1000:
            return None

        # Peak normalisation; the epsilon guards against an all-zero signal.
        signal = signal / (np.max(np.abs(signal)) + 1e-9)

        if len(signal) < window:
            # Short recording: right-pad with zeros up to one full window.
            segments = [np.pad(signal, (0, window - len(signal)))]
        else:
            # Only full windows are kept; a trailing remainder is dropped.
            last_start = len(signal) - window
            segments = [
                signal[start : start + window]
                for start in range(0, last_start + 1, hop)
            ]

        features = []
        for segment in segments:
            power_mel = librosa.feature.melspectrogram(
                y=segment,
                sr=sample_rate,
                n_fft=2048,
                hop_length=512,
                n_mels=N_MELS,
            )
            # dB scale relative to this chunk's own maximum, then [Time, Freq].
            features.append(librosa.power_to_db(power_mel, ref=np.max).T)

        stacked = np.array(features)
        return torch.tensor(stacked, dtype=torch.float32).unsqueeze(1)

    except Exception as e:
        print(f"Error {file_path}: {e}")
        return None

# ==========================================
# 5. EXECUTION
# ==========================================

# Build the network from the hyper-parameters configured above.
print(f"‚è≥ Loading Binary Model ({NUM_SPEAKERS} Classes)...")
backbone = Backbone(no_mels=N_MELS, embed_dim=EMBED_DIM, rnn_hidden=256, rnn_layers=2, bidir=True)
model = SpeakerClassifier(backbone, num_speakers=NUM_SPEAKERS)

if not os.path.exists(CHECKPOINT_PATH):
    raise FileNotFoundError(f"Checkpoint not found: {CHECKPOINT_PATH}")

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source (or pass weights_only=True
# on PyTorch versions that support it).
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)

# strict=False tolerates key mismatches, but a mismatch means those layers keep
# their random initial weights. Report it instead of claiming a clean load.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"WARNING: {len(load_result.missing_keys)} model parameter(s) not found in checkpoint "
          f"(left at random init): {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"WARNING: {len(load_result.unexpected_keys)} checkpoint key(s) ignored: "
          f"{load_result.unexpected_keys}")
model.to(device)
model.eval()
print("‚úÖ Model loaded successfully!")

# Only *.wav files are scored; sorting makes the report order reproducible
# (glob returns files in arbitrary, OS-dependent order).
audio_files = sorted(glob.glob(os.path.join(TEST_AUDIO_FOLDER, "*.wav")))

print(f"üìÇ Found {len(audio_files)} WAV files in '{TEST_AUDIO_FOLDER}'")
print(f"üéØ Target: Class {EXPECTED_CLASS} (Member)\n")

print(f"{'FILENAME':<40} | {'PREDICTION':<15} | {'CONFIDENCE'} | {'ID'} | {'STATUS'}")
print("-" * 95)

results = []
correct_members = 0
correct_outsiders = 0
total_members = 0
total_outsiders = 0

# Ground truth: file names in the test folder that are genuinely Aleksander.
# Every other file is treated as an outsider.
TRUE_MEMBER_FILES = [
    "Alexander-aleksander.wav",
    "Gallic-Wars-Aleksander.wav",
    "Napoleon-aleksander.wav",
    "prince-Aleksander.wav"
]

with torch.no_grad():
    for file_path in audio_files:
        batch = preprocess_file(file_path)
        if batch is None:
            continue

        batch = batch.to(device)
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)

        # Each chunk votes for its most probable class; the majority wins.
        chunk_ids = torch.argmax(probs, dim=1)
        votes = chunk_ids.cpu().tolist()
        pred_id = max(set(votes), key=votes.count)

        # Confidence = mean probability assigned to the winning class over ALL
        # chunks. Averaging each chunk's own top probability would let chunks
        # that voted for the other class inflate the score (and, with two
        # classes, could never drop below 0.5).
        avg_conf = probs[:, pred_id].mean().item()

        # --- DECISION LOGIC ---
        # pred_class is the final accept (1) / reject (0) decision; a member
        # vote below the confidence threshold is rejected, so the STATUS
        # column always agrees with the printed label.
        if pred_id == EXPECTED_CLASS:
            if avg_conf >= CONFIDENCE_THRESHOLD:
                label = "‚úÖ MEMBER"
                pred_class = 1
            else:
                label = "‚ùì LOW CONF (1)"
                pred_class = 0
        else:
            label = "‚ùå OUTSIDER"
            pred_class = 0

        fname = os.path.basename(file_path)

        # --- GROUND TRUTH CHECK ---
        is_actually_member = fname in TRUE_MEMBER_FILES

        if is_actually_member:
            total_members += 1
            if pred_class == 1:
                correct_members += 1
                status = "MATCH"
            else:
                status = "MISS"  # a genuine member was rejected
        else:
            total_outsiders += 1
            if pred_class == 0:
                correct_outsiders += 1
                status = "MATCH"
            else:
                status = "FALSE ACCEPT"  # an outsider was let in

        print(f"{fname[:38]:<40} | {label:<15} | {avg_conf:.1%}      | {pred_id:<2} | {status}")

        results.append({
            "file": fname,
            "prediction": label,
            "raw_id": pred_id,
            "confidence": avg_conf
        })

# Final Report
print("\n" + "="*40)
print(f"üìä BINARY MODEL ACCURACY REPORT")
print("="*40)
if total_members > 0:
    print(f"üë§ Members (Aleksander): {correct_members}/{total_members} ({(correct_members/total_members)*100:.1f}%)")
else:
    # Say so explicitly rather than silently omitting the line.
    print("üë§ Members (Aleksander): 0/0 (No files found)")

if total_outsiders > 0:
    print(f"üö´ Outsiders (Imposters): {correct_outsiders}/{total_outsiders} ({(correct_outsiders/total_outsiders)*100:.1f}%)")
else:
    print("üö´ Outsiders (Imposters): 0/0 (No files found)")
print("="*40)

üñ•Ô∏è Running on: cuda
‚è≥ Loading Binary Model (2 Classes)...
‚úÖ Model loaded successfully!
üìÇ Found 31 WAV files in 'Recordings_1/Aleksander'
üéØ Target: Class 1 (Member)

FILENAME                                 | PREDICTION      | CONFIDENCE | ID | STATUS
-----------------------------------------------------------------------------------------------
adi.wav                                  | ‚ùå OUTSIDER      | 100.0%      | 0  | MATCH
Alexander-aleksander.wav                 | ‚úÖ MEMBER        | 100.0%      | 1  | MATCH
churchill-1.wav                          | ‚ùå OUTSIDER      | 100.0%      | 0  | MATCH
fdr.wav                                  | ‚ùå OUTSIDER      | 100.0%      | 0  | MATCH
Gallic-Wars-Aleksander.wav               | ‚úÖ MEMBER        | 100.0%      | 1  | MATCH
gatsby-ania.wav                          | ‚ùå OUTSIDER      | 100.0%      | 0  | MATCH
grian-1.wav                              | ‚ùå OUTSIDER      | 100.0%      | 0  | MATCH
grian-2.wav           