# ResNet++ Notebook

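This notebook splits the code into separate cells so each step can be run and checked on its own. Run Cells 1 and 2 (imports and device setup) first; later cells rely on the names they define.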

In [None]:
# Cell 1: Basic imports
import os

import librosa
import numpy as np
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import Dataset
# Run this cell first to verify imports


In [None]:
# Cell 2: Device setup and constants
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")


In [None]:
# Dataset Class
class AudioDataset(Dataset):
    """Dataset pairing ``.wav`` files in a directory with labels from a protocol file.

    Each item is turned into a 3-channel image of its mel spectrogram so it can be
    fed to an image backbone.

    Args:
        base_path: Directory scanned (non-recursively) for ``.wav`` files.
        protocol_path: Whitespace-separated text file. The first field of each line
            is a file name (with or without the ``.wav`` extension), the second is
            the label type: ``genuine`` maps to 1, any other value to 0.
        transform: Optional callable applied to the PIL image; expected to return a
            ``torch.Tensor``.
        sr: Target sample rate passed to ``load_audio``.
        augment: Whether to apply ``spec_augment``. ``None`` (default) keeps the
            original heuristic: augment only when ``transform`` is set and
            ``base_path`` contains "train" (case-insensitive).

    Raises:
        FileNotFoundError: If ``base_path`` or ``protocol_path`` does not exist.
        RuntimeError: If no directory file matches a protocol entry.
    """

    # Shape of the placeholder sample returned when a file cannot be processed.
    _DUMMY_SHAPE = (3, 224, 224)

    def __init__(self, base_path, protocol_path, transform=None, sr=16000, augment=None):
        self.transform = transform
        self.file_paths = []
        self.labels = []
        self.sr = sr
        self.base_path = base_path
        self.augment = augment

        print(f"Loading dataset from: {base_path}")
        print(f"Using protocol file: {protocol_path}")

        if not os.path.exists(base_path):
            raise FileNotFoundError(f"Dataset path does not exist: {base_path}")
        if not os.path.exists(protocol_path):
            raise FileNotFoundError(f"Protocol file does not exist: {protocol_path}")

        label_dict = self._load_protocol(protocol_path)
        print(f"Loaded {len(label_dict)} entries from protocol file")
        print("First few protocol entries:")
        for i, (name, value) in enumerate(label_dict.items()):
            if i >= 5:
                break  # only a preview; no need to walk the whole protocol
            print(f"{name}: {'genuine' if value == 1 else 'spoof'}")

        wav_files = sorted(f for f in os.listdir(base_path) if f.endswith('.wav'))
        print(f"\nFound {len(wav_files)} WAV files in directory")
        print("First few WAV files:")
        for f in wav_files[:5]:
            print(f)

        # Keep only files that have a protocol entry; report the rest.
        for wav_file in wav_files:
            label = self._lookup_label(label_dict, wav_file)
            if label is None:
                print(f"Warning: No protocol entry for file {wav_file}")
                continue
            self.file_paths.append(os.path.join(base_path, wav_file))
            self.labels.append(label)

        genuine = self.labels.count(1)
        spoofed = self.labels.count(0)
        print("\nDataset Statistics:")
        print(f"Number of genuine samples: {genuine}")
        print(f"Number of spoofed samples: {spoofed}")
        print(f"Total files matched: {len(self.file_paths)}")

        if not self.file_paths:
            raise RuntimeError("No files were matched between directory and protocol!")

    @staticmethod
    def _load_protocol(protocol_path):
        """Parse the protocol file into ``{filename: label}`` (1 = genuine, 0 = other)."""
        label_dict = {}
        with open(protocol_path, 'r') as f:
            for line in f:
                # Field 0 is the file name, field 1 the label type; the rest is ignored.
                parts = line.strip().split(maxsplit=2)
                if not parts:
                    continue  # blank line: nothing to parse, not worth a warning
                if len(parts) < 2:
                    print(f"Warning: Malformed line in protocol file: {line.rstrip()}")
                    continue
                filename, label_type = parts[0], parts[1]
                label_dict[filename] = 1 if label_type == 'genuine' else 0
        return label_dict

    @staticmethod
    def _lookup_label(label_dict, wav_file):
        """Return the label for ``wav_file`` or ``None`` if the protocol lacks it.

        The exact file name is tried first, then the name without its ``.wav``
        extension, for protocols that list bare IDs.
        """
        if wav_file in label_dict:
            return label_dict[wav_file]
        return label_dict.get(os.path.splitext(wav_file)[0])

    @staticmethod
    def _to_uint8(array):
        """Convert ``array`` to uint8 pixel values without wrap-around.

        Arrays already within [0, 255] are cast directly (the previous behavior).
        Anything else is min-max rescaled to [0, 255] first, because casting
        out-of-range floats (e.g. negative dB values) to uint8 wraps or is
        undefined in NumPy. A constant out-of-range array becomes all zeros.
        """
        array = np.asarray(array)
        lo, hi = float(array.min()), float(array.max())
        if lo >= 0.0 and hi <= 255.0:
            return array.astype(np.uint8)
        if hi == lo:
            return np.zeros(array.shape, dtype=np.uint8)
        scaled = (array - lo) / (hi - lo) * 255.0
        return np.clip(scaled, 0.0, 255.0).astype(np.uint8)

    def _should_augment(self):
        """Decide whether ``spec_augment`` runs for this dataset."""
        if self.augment is not None:
            return bool(self.augment)
        # Legacy heuristic: infer "training split" from the directory name.
        return bool(self.transform) and 'train' in str(self.base_path).lower()

    def _fallback(self, label, message):
        """Print ``message`` and return a random-noise placeholder with the real label.

        NOTE(review): placeholders keep the true label, so frequent failures feed
        mislabeled noise into training; consider filtering them out instead.
        """
        print(message)
        return torch.randn(*self._DUMMY_SHAPE), torch.tensor(label, dtype=torch.long)

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for item ``idx``.

        Any processing failure yields a placeholder from ``_fallback`` instead of
        raising, so one bad file does not abort an epoch.
        """
        file_path = self.file_paths[idx]
        label = self.labels[idx]

        try:
            audio, sr = load_audio(file_path, self.sr)
            if audio is None or sr is None:
                return self._fallback(label, f"Warning: Creating dummy tensor for {file_path}")

            audio = normalize_audio(audio)
            audio = remove_gaussian_noise(audio, sr)
            mel_spectrogram = extract_mel_spectrogram(audio, sr)

            if mel_spectrogram is None or mel_spectrogram.size == 0:
                return self._fallback(label, f"Warning: Invalid mel spectrogram for {file_path}")

            if self._should_augment():
                mel_spectrogram = spec_augment(mel_spectrogram)

            # Replicate the single-channel spectrogram into 3 channels for an RGB image.
            pixels = self._to_uint8(np.stack([mel_spectrogram] * 3, axis=-1))
            sample = Image.fromarray(pixels)

            if self.transform:
                sample = self.transform(sample)

            # Without a tensor-producing transform the sample is still a PIL image.
            if not isinstance(sample, torch.Tensor):
                return self._fallback(label, f"Warning: Transform failed for {file_path}")

            return sample, torch.tensor(label, dtype=torch.long)

        except Exception as e:
            # Deliberate best-effort: log and substitute a placeholder sample.
            return self._fallback(label, f"Error processing {file_path}: {str(e)}")

In [None]:
# Cell 4: Audio processing functions
def load_audio(file_path, sr=16000):
    """Load an audio file via ``librosa.load`` at sample rate ``sr``.

    Returns:
        ``(samples, sample_rate)`` as returned by librosa, or ``(None, None)`` if
        loading raised any exception (the error is printed, not re-raised).
    """
    try:
        waveform, sample_rate = librosa.load(file_path, sr=sr)
    except Exception as err:
        print(f"Error loading {file_path}: {err}")
        return None, None
    return waveform, sample_rate
# Test with a sample file to verify


In [None]:
# Cell 5: Dataset class
class AudioDataset(Dataset):
    """Placeholder skeleton of the audio dataset (work in progress).

    NOTE(review): this cell redefines ``AudioDataset``. Running it after the
    "Dataset Class" cell replaces the full implementation with this stub, so
    delete this cell or rename the class once the full one is in use.
    """

    def __init__(self, base_path, protocol_path, transform=None):
        self.transform = transform
        self.base_path = base_path
        self.protocol_path = protocol_path
        self.file_paths = []
        self.labels = []
        # TODO: populate file_paths / labels from protocol_path.

    def __len__(self):
        # DataLoader's default sampler needs a length; the stub is empty.
        return len(self.file_paths)

    def __getitem__(self, idx):
        # Fail loudly rather than handing None to the DataLoader's collate step.
        raise NotImplementedError(
            "AudioDataset stub: data loading is not implemented in this cell"
        )
# Initialize with small subset to test


In [None]:
# Cell 6: Model architecture
class EnhancedASVResNet(nn.Module):
    """Classifier built on a pretrained timm ``efficientnet_b0`` backbone.

    Work in progress: for now the model is the backbone alone, with its
    classification head sized to ``num_classes`` so ``forward`` already yields
    usable logits.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.num_classes = num_classes
        # Passing num_classes makes timm build a head of the right width instead of
        # keeping the pretrained model's default classifier.
        self.backbone = timm.create_model(
            'efficientnet_b0', pretrained=True, num_classes=num_classes
        )
        # TODO: add further model components.

    def forward(self, x):
        """Return class logits for ``x``.

        Assumes ``x`` is an image batch of shape (batch, 3, H, W) — TODO confirm
        against the dataset's transform output.
        """
        return self.backbone(x)
# Create small model instance to verify


In [None]:
# Cell 7: Training utilities
def train_epoch(model, loader, criterion, optimizer):
    """Run one training epoch and return the mean per-sample loss.

    Args:
        model: Module to train; switched to training mode here.
        loader: Iterable of ``(inputs, targets)`` tensor batches (e.g. a DataLoader).
        criterion: Loss function returning a scalar loss for a batch.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        Loss averaged over all samples seen, or 0.0 if ``loader`` yields nothing.
    """
    model.train()
    # Send batches to wherever the model's parameters live rather than relying on
    # a notebook-level device global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    running_loss = 0.0
    num_samples = 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        batch_size = inputs.size(0)
        # Assumes criterion averages over the batch (the nn loss default);
        # re-weight so the epoch mean is per sample, not per batch.
        running_loss += loss.item() * batch_size
        num_samples += batch_size
    return running_loss / num_samples if num_samples else 0.0


In [None]:
# Define TransformerBlock
class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: self-attention then an MLP, each with a residual.

    Args:
        dim: Embedding size of each token (must be divisible by ``num_heads``).
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        dropout: Dropout probability used in attention and in the MLP.

    The attention layer uses PyTorch's default ``batch_first=False``, so inputs
    are laid out as (seq_len, batch, dim).
    """

    def __init__(self, dim, num_heads=8, mlp_ratio=4., dropout=0.):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        mlp_layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x):
        # Self-attention: the same normalized tensor serves as query, key and value.
        normed = self.norm1(x)
        attended, _ = self.attn(normed, normed, normed)
        x = x + attended
        return x + self.mlp(self.norm2(x))

# Define SETransformerBlock
class SETransformerBlock(nn.Module):
    """Squeeze-and-excitation followed by self-attention across spatial positions.

    Args:
        channels: Channel count of the feature map; also the token dimension.
        reduction: Reduction ratio forwarded to ``SEBlock``.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.se = SEBlock(channels, reduction)
        self.transformer = TransformerBlock(channels)

    def forward(self, x):
        """Apply SE, then attention over the h*w positions of a 4-D ``x``.

        ``x`` is unpacked as (batch, channels, h, w); assumes ``SEBlock`` keeps
        that shape — TODO confirm.
        """
        x = self.se(x)
        b, c, h, w = x.shape
        # (b, c, h, w) -> (h*w, b, c): each spatial position becomes a token,
        # sequence-first as TransformerBlock's attention (batch_first=False) expects.
        # Native ops replace einops.rearrange, which this notebook never imports.
        tokens = x.flatten(2).permute(2, 0, 1)
        tokens = self.transformer(tokens)
        # (h*w, b, c) -> (b, c, h, w); reshape because permute leaves it non-contiguous.
        x_trans = tokens.permute(1, 2, 0).reshape(b, c, h, w)
        # NOTE(review): TransformerBlock already adds its input back, so x reaches
        # the output twice here — confirm this double residual is intended.
        return x + x_trans

# Define MultiScaleFeatureFusion
class MultiScaleFeatureFusion(nn.Module):
    """Fuse four parallel conv branches with increasing receptive fields.

    Branches: 1x1 conv, and 3x3 convs with dilation 1, 2 and 4. Each is padded so
    the spatial size is preserved. Branch outputs are concatenated and mixed by a
    1x1 conv back to ``channels``.

    Args:
        channels: Input and output channel count; each branch gets ``channels // 4``.

    Raises:
        ValueError: If ``channels`` < 4 (branches would have zero channels).
    """

    def __init__(self, channels):
        super().__init__()
        if channels < 4:
            raise ValueError(f"channels must be >= 4, got {channels}")
        branch_channels = channels // 4
        self.branch1 = self._branch(channels, branch_channels, kernel_size=1, dilation=1)
        self.branch2 = self._branch(channels, branch_channels, kernel_size=3, dilation=1)
        self.branch3 = self._branch(channels, branch_channels, kernel_size=3, dilation=2)
        self.branch4 = self._branch(channels, branch_channels, kernel_size=3, dilation=4)
        # The concatenation is 4 * (channels // 4) wide, which is < channels when
        # channels % 4 != 0; sizing the fusion input from it keeps forward valid.
        self.fusion = nn.Sequential(
            nn.Conv2d(4 * branch_channels, channels, 1),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )

    @staticmethod
    def _branch(in_channels, out_channels, kernel_size, dilation):
        """Build Conv -> BatchNorm -> ReLU, padded to keep spatial size for odd kernels."""
        padding = dilation * (kernel_size - 1) // 2
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        x4 = self.branch4(x)
        return self.fusion(torch.cat([x1, x2, x3, x4], dim=1))

In [None]:
# Cell 9: Main training loop
def _notebook_global(name):
    """Return the notebook-level global ``name``, raising NameError if it is undefined."""
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"name {name!r} is not defined; define it in an earlier cell or pass it to main()"
        ) from None


def main(num_epochs=None, train_loader=None, val_loader=None):
    """Build the model and run the train/validate loop, printing per-epoch losses.

    Args:
        num_epochs: Epochs to run; defaults to the global ``NUM_EPOCHS``.
        train_loader: Training batches; defaults to the global ``train_loader``.
        val_loader: Validation batches; defaults to the global ``val_loader``.

    ``validate(model, loader, criterion)`` is looked up as a global and must be
    defined before calling.

    Raises:
        NameError: If an argument is omitted and the matching global is undefined.
    """
    # Resolve inputs before building the model (which loads pretrained weights),
    # so a missing definition fails fast. Bare ``main()`` still uses the globals.
    if num_epochs is None:
        num_epochs = _notebook_global("NUM_EPOCHS")
    if train_loader is None:
        train_loader = _notebook_global("train_loader")
    if val_loader is None:
        val_loader = _notebook_global("val_loader")

    model = EnhancedASVResNet().to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters())

    for epoch in range(num_epochs):
        train_loss = train_epoch(model, train_loader, criterion, optimizer)
        val_loss = validate(model, val_loader, criterion)
        print(f"Epoch {epoch}: Train loss={train_loss:.4f}, Val loss={val_loss:.4f}")


In [None]:
# Cell 10: Execute training
# Entry point: start training when this code runs as the main module.
# NOTE(review): a notebook kernel normally executes cells with __name__ == "__main__",
# so running this cell calls main() immediately.
if __name__ == "__main__":
    main()
