In [9]:
import os
import time
import numpy as np
import pandas as pd
import librosa

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms
from efficientnet_pytorch import EfficientNet
from pydiffres import DiffRes

class ESC50Dataset(Dataset):
    """ESC-50 clips as log-mel spectrograms, split by the metadata ``fold`` column.

    Rows with ``fold <= 3`` form the ``"train"`` split, ``fold == 4`` the
    ``"val"`` split and ``fold == 5`` the ``"test"`` split.

    Args:
        csv_path: Path (or file-like object) of the metadata CSV. It must have
            ``filename``, ``fold`` and ``target`` columns.
        audio_dir: Directory the ``filename`` entries are resolved against.
        split: One of ``"train"``, ``"val"`` or ``"test"``.
        sr: Sample rate passed to ``librosa.load`` (audio is resampled to it).
        duration: Clip length in seconds. Shorter clips are zero-padded at the
            end and longer ones truncated to ``sr * duration`` samples.

    Raises:
        ValueError: If ``split`` is not one of the supported names. Previously
            any unknown name silently selected the test fold.
    """

    def __init__(self, csv_path, audio_dir, split="train", sr=16000, duration=5.0):
        if split not in ("train", "val", "test"):
            raise ValueError(f"split must be 'train', 'val' or 'test', got {split!r}")

        self.meta = pd.read_csv(csv_path)
        self.audio_dir = audio_dir
        self.sr = sr
        # Fixed number of samples every clip is padded/truncated to.
        self.target_len = int(round(sr * duration))

        folds = self.meta["fold"]
        if split == "train":
            keep = folds <= 3
        elif split == "val":
            keep = folds == 4
        else:  # "test"
            keep = folds == 5
        self.meta = self.meta[keep]

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, idx):
        """Return ``(mel, label)`` for the ``idx``-th clip of this split.

        ``mel`` is a float32 tensor of shape ``[T, n_mels]`` (frames first) in
        dB. ``label`` is the integer ``target`` column.
        """
        row = self.meta.iloc[idx]
        path = os.path.join(self.audio_dir, row["filename"])

        y, _ = librosa.load(path, sr=self.sr)

        if len(y) < self.target_len:
            y = np.pad(y, (0, self.target_len - len(y)))
        else:
            y = y[:self.target_len]

        mel = librosa.feature.melspectrogram(
            y=y,
            sr=self.sr,
            n_fft=400,      # 25 ms at the default 16 kHz
            hop_length=160, # 10 ms at the default 16 kHz
            n_mels=128
        )
        # Power -> dB with librosa's default reference (ref=1.0); no normalisation.
        mel = librosa.power_to_db(mel)
        # librosa returns [n_mels, frames]; transpose to [T, F].
        mel = torch.from_numpy(mel).float().transpose(0, 1)

        label = int(row["target"])
        return mel, label


class ESC50BaselineNet(nn.Module):
    """Linear probe on a frozen pretrained EfficientNet-B2 fed with spectrograms.

    The spectrogram is treated as a one-channel image, copied to three
    channels and resized to 224x224 before entering the backbone. Only the
    freshly created final ``_fc`` layer is trainable.
    """

    def __init__(self, num_classes=50):
        super().__init__()

        self.resize = transforms.Resize((224, 224))
        self.backbone = EfficientNet.from_pretrained("efficientnet-b2")

        # Freeze all pretrained weights. NOTE(review): requires_grad only stops
        # gradient updates. If the backbone contains BatchNorm layers, their
        # running statistics still update whenever the model is in train() mode.
        self.backbone.requires_grad_(False)

        # Swap the ImageNet head for a new (trainable) classifier.
        head_in = self.backbone._fc.in_features
        self.backbone._fc = nn.Linear(head_in, num_classes)

    def forward(self, x):
        """Return backbone logits for a batch of spectrograms.

        Args:
            x: Tensor of shape ``[B, T, F]`` (time frames x frequency bins).
        """
        image = x.transpose(1, 2).unsqueeze(1)  # [B, 1, F, T]: frequency as image height
        image = image.repeat(1, 3, 1, 1)        # [B, 3, F, T]: fake an RGB image
        image = self.resize(image)
        return self.backbone(image)


class ESC50DiffResNet(nn.Module):
    """Frozen pretrained EfficientNet-B2 classifier behind a DiffRes front end.

    The spectrogram first goes through ``pydiffres.DiffRes`` (configured with
    ``dimension_reduction_rate=0.5``). The resulting feature map is brought to
    three channels, resized to 224x224 and classified by the backbone. The
    pretrained backbone weights are frozen. DiffRes, the 1x1
    ``channel_adapter`` and the new final ``_fc`` layer are left trainable.

    Args:
        T: Number of time frames in the input spectrogram (DiffRes ``in_t_dim``).
        F_dim: Number of frequency bins in the input (DiffRes ``in_f_dim``).
        num_classes: Size of the classification head.
    """

    def __init__(self, T, F_dim=128, num_classes=50):
        super().__init__()

        self.diffres = DiffRes(
            in_t_dim=T,
            in_f_dim=F_dim,
            dimension_reduction_rate=0.5,
            learn_pos_emb=False
        )

        # 1x1 conv mapping a multi-channel DiffRes feature to 3 channels.
        # NOTE(review): the 9-channel count is the original author's
        # assumption about DiffRes' output; confirm against pydiffres.
        self.channel_adapter = nn.Conv2d(
            in_channels=9,
            out_channels=3,
            kernel_size=1,
            bias=False
        )

        self.resize = transforms.Resize((224, 224))
        self.backbone = EfficientNet.from_pretrained("efficientnet-b2")

        for p in self.backbone.parameters():
            p.requires_grad = False

        in_features = self.backbone._fc.in_features
        self.backbone._fc = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Classify a batch of spectrograms.

        Args:
            x: Input spectrogram batch, presumably ``[B, T, F_dim]``
                (TODO confirm the layout DiffRes expects).

        Returns:
            ``(logits, guide_loss)``. ``guide_loss`` is DiffRes'
            ``"guide_loss"`` entry, returned unchanged.

        Raises:
            ValueError: If the DiffRes feature has a channel count other than
                1, 3 or ``channel_adapter.in_channels``.
        """
        ret = self.diffres(x)
        feat = ret["feature"]
        guide_loss = ret["guide_loss"]

        # The feature may be [B, T', F'] or [B, C, T', F']; normalise to 4-D.
        if feat.dim() == 3:
            feat = feat.unsqueeze(1)

        # [B, C, T', F'] -> [B, C, F', T'] so frequency becomes the image height.
        feat = feat.permute(0, 1, 3, 2)

        channels = feat.size(1)
        if channels == self.channel_adapter.in_channels:
            # Learned projection down to the 3 channels the backbone expects.
            feat = self.channel_adapter(feat)
        elif channels == 1:
            feat = feat.repeat(1, 3, 1, 1)
        elif channels != 3:
            # Fail loudly here rather than deep inside the 3-channel backbone stem.
            raise ValueError(
                f"DiffRes feature has {channels} channels; expected 1, 3 or "
                f"{self.channel_adapter.in_channels}"
            )

        feat = self.resize(feat)  # [B, 3, 224, 224]
        out = self.backbone(feat)
        return out, guide_loss


def train_one_epoch(model, loader, optimizer, criterion, device, use_diffres,
                    guide_loss_weight=0.5):
    """Run one training pass over ``loader`` and time it.

    Args:
        model: Module to train. If ``use_diffres`` is true it must return
            ``(logits, guide_loss)``; otherwise it returns ``logits`` only.
        loader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss applied as ``criterion(logits, targets)``.
        device: Device every batch is moved to.
        use_diffres: Whether ``model`` also returns a guide loss.
        guide_loss_weight: Multiplier on ``guide_loss.mean()`` added to the
            classification loss. It is only used when ``use_diffres`` is true.

    Returns:
        ``(elapsed_seconds, mean_batch_loss)``. The elapsed time covers the
        whole loop, so it includes time spent waiting on ``loader`` (e.g.
        on-the-fly feature extraction), not just model compute. The mean loss
        is 0 if ``loader`` yields no batches.
    """
    model.train()
    # perf_counter is monotonic and high-resolution, unlike time.time.
    start = time.perf_counter()
    total_loss = 0.0
    num_batches = 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        if use_diffres:
            out, g = model(x)
            loss = criterion(out, y) + guide_loss_weight * g.mean()
        else:
            out = model(x)
            loss = criterion(out, y)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Make sure queued GPU work is finished before stopping the clock.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    elapsed = time.perf_counter() - start
    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    return elapsed, avg_loss


def train_multiple_epochs(model, loader, optimizer, criterion, device, use_diffres, num_epochs=10):
    """Train for ``num_epochs`` epochs, printing and collecting per-epoch stats.

    Each epoch is one ``train_one_epoch`` call. Its wall time and mean batch
    loss are printed as they arrive.

    Returns:
        dict with ``'times'`` and ``'losses'`` (per-epoch lists),
        ``'avg_time'`` / ``'std_time'`` (``np.mean`` / ``np.std`` of the
        epoch times) and ``'avg_loss'`` (``np.mean`` of the epoch losses).
    """
    print(f"Training for {num_epochs} epochs...")

    history = []
    for epoch_no in range(1, num_epochs + 1):
        seconds, loss_value = train_one_epoch(
            model, loader, optimizer, criterion, device, use_diffres
        )
        history.append((seconds, loss_value))
        print(f"  Epoch {epoch_no}/{num_epochs}: {seconds:.2f}s, Loss: {loss_value:.4f}")

    durations = [seconds for seconds, _ in history]
    losses = [loss_value for _, loss_value in history]
    return {
        'times': durations,
        'losses': losses,
        'avg_time': np.mean(durations),
        'std_time': np.std(durations),
        'avg_loss': np.mean(losses),
    }


def measure_memory(model, loader, device, use_diffres, exclude_resident=False):
    """Return peak CUDA memory (MiB) around one no-grad forward pass.

    The model is switched to ``eval()`` (and left there). The first batch of
    ``loader`` is moved to ``device`` and run through it once.

    Args:
        model: Module to measure. With ``use_diffres`` it returns a tuple whose
            first element is the logits.
        loader: Iterable of ``(inputs, targets)`` batches.
        device: Target device. Anything that is not a CUDA device yields 0.
        use_diffres: Whether ``model`` returns a tuple.
        exclude_resident: If true, subtract the memory already allocated on
            the device before the input batch is moved there. The result then
            covers only the input batch plus forward-pass allocations.

    Returns:
        Memory in MiB, or 0 when CUDA is unavailable or ``device`` is not CUDA.

    Note:
        With the default ``exclude_resident=False`` this is the device-wide
        peak. It includes every tensor still allocated on the device, e.g. the
        weights of any other model or optimizer state, not just ``model``.
    """
    if not torch.cuda.is_available():
        return 0
    dev = torch.device(device)
    if dev.type != "cuda":
        # Previously the default CUDA device was measured regardless of `device`.
        return 0

    torch.cuda.reset_peak_memory_stats(dev)
    resident = torch.cuda.memory_allocated(dev)
    model.eval()
    x, _ = next(iter(loader))
    x = x.to(dev)

    with torch.no_grad():
        _ = model(x) if not use_diffres else model(x)[0]

    peak = torch.cuda.max_memory_allocated(dev)
    if exclude_resident:
        peak -= resident
    return peak / 1024**2


def measure_inference(model, loader, device, use_diffres, repeat=50, warmup=10):
    """Return the mean wall-clock seconds per no-grad forward pass.

    The model is switched to ``eval()`` (and left there). Only the first batch
    of ``loader`` is used: it is fetched and moved to ``device`` before timing
    starts, so data loading is excluded from the measurement.

    Args:
        model: Module to time. With ``use_diffres`` it returns a tuple whose
            first element is the logits.
        loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the batch is moved to.
        use_diffres: Whether ``model`` returns a tuple.
        repeat: Number of timed forward passes (must be positive).
        warmup: Number of untimed forward passes run first.

    Raises:
        ValueError: If ``repeat`` is not positive (it would divide by zero).
    """
    if repeat <= 0:
        raise ValueError(f"repeat must be a positive integer, got {repeat}")

    model.eval()
    x, _ = next(iter(loader))
    x = x.to(device)

    # Warm-up passes absorb one-off costs (allocator growth, kernel selection).
    with torch.no_grad():
        for _ in range(warmup):
            _ = model(x) if not use_diffres else model(x)[0]

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    # perf_counter is monotonic and high-resolution, unlike time.time.
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(repeat):
            _ = model(x) if not use_diffres else model(x)[0]

    # CUDA kernels run asynchronously; wait for them before reading the clock.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    return (time.perf_counter() - start) / repeat


if __name__ == "__main__":
    # Benchmark configuration, kept in one place so the printed headers always
    # match what is actually run.
    META_CSV = "ESC-50-master/meta/esc50.csv"
    AUDIO_DIR = "ESC-50-master/audio"
    BATCH_SIZE = 16
    NUM_WORKERS = 2
    NUM_EPOCHS = 10
    LEARNING_RATE = 2.5e-4

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}\n")

    # Load datasets
    train_ds = ESC50Dataset(META_CSV, AUDIO_DIR, "train")
    test_ds = ESC50Dataset(META_CSV, AUDIO_DIR, "test")

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                              num_workers=NUM_WORKERS)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

    # Get input dimensions from one example
    sample_x, _ = train_ds[0]
    T, F = sample_x.shape
    print(f"Input shape: T={T}, F={F}")
    print(f"Training samples: {len(train_ds)}")
    print(f"Test samples: {len(test_ds)}\n")

    # Initialize models
    baseline = ESC50BaselineNet().to(device)
    diffres = ESC50DiffResNet(T, F).to(device)

    criterion = nn.CrossEntropyLoss()
    opt_base = optim.Adam(baseline.parameters(), lr=LEARNING_RATE)
    opt_diff = optim.Adam(diffres.parameters(), lr=LEARNING_RATE)

    # ========== TRAINING TIME COMPARISON ==========
    # NOTE: the dataset decodes audio and computes mel features on every
    # access, so per-epoch times include that loading cost for both models
    # and can mask differences in model compute.
    print("="*60)
    print(f"TRAINING TIME COMPARISON ({NUM_EPOCHS} EPOCHS)")
    print("="*60)

    print("\n[BASELINE MODEL]")
    baseline_stats = train_multiple_epochs(
        baseline, train_loader, opt_base, criterion, device, False, num_epochs=NUM_EPOCHS
    )

    print("\n[DIFFRES MODEL]")
    diffres_stats = train_multiple_epochs(
        diffres, train_loader, opt_diff, criterion, device, True, num_epochs=NUM_EPOCHS
    )

    # Print summary
    print("\n" + "="*60)
    print("TRAINING SUMMARY")
    print("="*60)
    print(f"Baseline - Avg time/epoch: {baseline_stats['avg_time']:.2f}s ± {baseline_stats['std_time']:.2f}s")
    print(f"           Total time: {sum(baseline_stats['times']):.2f}s")
    print(f"           Final loss: {baseline_stats['losses'][-1]:.4f}")
    print()
    print(f"DiffRes  - Avg time/epoch: {diffres_stats['avg_time']:.2f}s ± {diffres_stats['std_time']:.2f}s")
    print(f"           Total time: {sum(diffres_stats['times']):.2f}s")
    print(f"           Final loss: {diffres_stats['losses'][-1]:.4f}")
    print()

    if diffres_stats['avg_time'] > 0:
        speedup = baseline_stats['avg_time'] / diffres_stats['avg_time']
        print(f"Training speedup: {speedup:.2f}x")
        if speedup > 1:
            print(f"DiffRes is {speedup:.2f}x FASTER")
        else:
            print(f"DiffRes is {1/speedup:.2f}x SLOWER")

    # ========== INFERENCE TIME ==========
    print("\n" + "="*60)
    print("INFERENCE TIME")
    print("="*60)
    i_base = measure_inference(baseline, test_loader, device, False)
    i_diff = measure_inference(diffres, test_loader, device, True)

    print(f"Baseline: {i_base*1000:.2f} ms/batch")
    print(f"DiffRes:  {i_diff*1000:.2f} ms/batch")
    if i_diff > 0:
        print(f"Inference speedup: {i_base/i_diff:.2f}x")

    # ========== GPU MEMORY ==========
    # NOTE: both models (and their optimizer state) stay on the GPU here, so
    # each figure is a device-wide peak that includes the other model's
    # weights too; the difference between them mainly reflects the forward pass.
    if torch.cuda.is_available():
        print("\n" + "="*60)
        print("GPU MEMORY USAGE")
        print("="*60)
        m_base = measure_memory(baseline, test_loader, device, False)
        m_diff = measure_memory(diffres, test_loader, device, True)

        print(f"Baseline: {m_base:.1f} MB")
        print(f"DiffRes:  {m_diff:.1f} MB")
        if m_base > 0:
            reduction = (1 - m_diff/m_base) * 100
            print(f"Memory {'reduction' if reduction > 0 else 'increase'}: {abs(reduction):.1f}%")

    # ========== THROUGHPUT ==========
    # The timed batch is the first batch of the unshuffled test loader, so its
    # size follows from the loader settings. No need to spin up another
    # worker-backed iterator and decode a whole batch just to read it.
    bs = min(test_loader.batch_size, len(test_ds))
    print("\n" + "="*60)
    print("THROUGHPUT (FPS)")
    print("="*60)
    print(f"Baseline: {bs/i_base:.1f} samples/sec")
    print(f"DiffRes:  {bs/i_diff:.1f} samples/sec")
    print("="*60)

Using device: cuda

Input shape: T=501, F=128
Training samples: 1200
Test samples: 400

Loaded pretrained weights for efficientnet-b2
Loaded pretrained weights for efficientnet-b2
TRAINING TIME COMPARISON (10 EPOCHS)

[BASELINE MODEL]
Training for 10 epochs...
  Epoch 1/10: 31.64s, Loss: 3.8293
  Epoch 2/10: 29.58s, Loss: 3.5227
  Epoch 3/10: 27.55s, Loss: 3.2758
  Epoch 4/10: 27.78s, Loss: 3.0510
  Epoch 5/10: 27.38s, Loss: 2.8544
  Epoch 6/10: 27.78s, Loss: 2.6717
  Epoch 7/10: 30.08s, Loss: 2.4880
  Epoch 8/10: 27.43s, Loss: 2.3901
  Epoch 9/10: 28.83s, Loss: 2.2255
  Epoch 10/10: 28.91s, Loss: 2.1617

[DIFFRES MODEL]
Training for 10 epochs...


  activeness = torch.std(importance_score[id][~score_mask[id]])


  Epoch 1/10: 30.40s, Loss: 4.0224
  Epoch 2/10: 30.76s, Loss: 3.7517
  Epoch 3/10: 30.90s, Loss: 3.5291
  Epoch 4/10: 29.54s, Loss: 3.3300
  Epoch 5/10: 32.00s, Loss: 3.1432
  Epoch 6/10: 32.06s, Loss: 2.9962
  Epoch 7/10: 32.78s, Loss: 2.8328
  Epoch 8/10: 30.95s, Loss: 2.6875
  Epoch 9/10: 28.20s, Loss: 2.5799
  Epoch 10/10: 30.83s, Loss: 2.4795

TRAINING SUMMARY
Baseline - Avg time/epoch: 28.70s ± 1.33s
           Total time: 286.96s
           Final loss: 2.1617

DiffRes  - Avg time/epoch: 30.84s ± 1.24s
           Total time: 308.41s
           Final loss: 2.4795

Training speedup: 0.93x
DiffRes is 1.07x SLOWER

INFERENCE TIME
Baseline: 32.75 ms/batch
DiffRes:  50.78 ms/batch
Inference speedup: 0.65x

GPU MEMORY USAGE
Baseline: 329.3 MB
DiffRes:  341.0 MB
Memory increase: 3.6%

THROUGHPUT (FPS)
Baseline: 488.5 samples/sec
DiffRes:  315.1 samples/sec
