In [35]:
import os
import json
import random
import torch
import torchaudio
from kan import KAN
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torchaudio.transforms as T

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [18]:
import os
print(os.path.isfile('./data/Dipco/audio/dev/S10_P29.wav'))

True


In [19]:
import torchaudio

# Load the audio file
waveform, sample_rate = torchaudio.load('./data/Dipco/audio/dev/S10_P29.wav')

# Print the shape of the waveform tensor
print("Shape of waveform:", waveform.shape)

# Print the sample rate
print("Sample rate:", sample_rate)

Shape of waveform: torch.Size([1, 19349333])
Sample rate: 16000


In [40]:
class DiPCoSeparationDataset(Dataset):
    def __init__(self, root_dir, session_ids, version='dev', 
                 segment_length=3, sr=16000, max_pairs=1000, 
                 n_mels=128, mel_reduction='mean'):
        self.sr = sr
        self.segment_samples = int(sr * segment_length)
        self.n_mels = n_mels

        self.mel_reduction = mel_reduction
        self.session_data = {}

        self.mel_transform = T.MelSpectrogram(
            sample_rate=sr,
            n_fft=1024,
            hop_length=256,
            n_mels=n_mels
        )

        for session in session_ids:
            trans_path = os.path.join(root_dir, 'transcriptions', version, f'{session}.json')
            with open(trans_path, 'r') as f:
                session_trans = json.load(f)
            
            self.session_data[session] = []
            for utt in session_trans:
                audio_path = os.path.join(root_dir, 'audio', version,
                                        f"{utt['session_id']}_{utt['speaker_id']}.wav")
                if os.path.exists(audio_path):
                    self.session_data[session].append({
                        'audio_path': audio_path,
                        'start': utt['start_time'],
                        'end': utt['end_time']
                    })
        
        self.pairs = []
        for session in session_ids:
            speakers = self.session_data[session]
            for i in range(len(speakers)):
                for j in range(i+1, len(speakers)):
                    self.pairs.append((session, i, j))
                    if len(self.pairs) >= max_pairs:
                        break
                if len(self.pairs) >= max_pairs:
                    break

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        session, i, j = self.pairs[idx]
        utt1 = self.session_data[session][i]
        utt2 = self.session_data[session][j]

        audio1 = self._load_full_audio(utt1['audio_path'])
        audio2 = self._load_full_audio(utt2['audio_path'])

        min_len = min(audio1.shape[0], audio2.shape[0])

        if min_len < self.segment_samples:
            audio1 = self._pad_segment(audio1)
            audio2 = self._pad_segment(audio2)
            segment1 = audio1
            segment2 = audio2
        else:
            max_start = min_len - self.segment_samples
            start = random.randint(0, max_start)
            segment1 = audio1[start:start + self.segment_samples]
            segment2 = audio2[start:start + self.segment_samples]

        mel1 = self._to_mel(audio1)
        mel2 = self._to_mel(audio2)

        mixed = (mel1 + mel2).clamp(min=0)  # Mel spectrograms are non-negative

        return mixed, torch.stack([mel1, mel2])  # Shape: [2, n_mels] or [2, n_mels, T]

    def _load_full_audio(self, path):
        waveform, orig_sr = torchaudio.load(path)
        if orig_sr != self.sr:
            resampler = torchaudio.transforms.Resample(orig_sr, self.sr)
            waveform = resampler(waveform)
        waveform = torch.mean(waveform, dim=0)  # Convert to mono
        return waveform

    def _pad_segment(self, audio):
        if len(audio) < self.segment_samples:
            padding = self.segment_samples - len(audio)
            audio = torch.nn.functional.pad(audio, (0, padding))
        else:
            audio = audio[:self.segment_samples]
        return audio
    

    def _to_mel(self, audio):
        mel = self.mel_transform(audio.unsqueeze(0))  # [1, n_mels, T]
        if self.mel_reduction == 'mean':
            return mel.mean(dim=-1).squeeze(0)  # [n_mels]
        elif self.mel_reduction == 'flatten':
            return mel.squeeze(0).flatten() # [n_mels * T]
        else:
            return mel.squeeze(0) # [n_mels, T]

In [44]:
class SeparationKAN(nn.Module):
    def __init__(self, input_size=48000, encoder:list[int]=[256, 256], latent_dim=256, decoder:list[int] = [256, 256]):
        super().__init__()
        self.encoder = KAN(width=[input_size, *encoder, latent_dim])
        self.decoder1 = KAN(width=[latent_dim//2, *decoder, input_size])
        self.decoder2 = KAN(width=[latent_dim//2, *decoder, input_size])

    def forward(self, x):
        latent = self.encoder(x)
        latent1, latent2 = torch.chunk(latent, 2, dim=-1)
        return torch.stack([
            self.decoder1(latent1),
            self.decoder2(latent2)
        ], dim=1)

In [45]:
class PermutationInvariantLoss(nn.Module):
    def __init__(self, base_loss=nn.L1Loss()):
        super().__init__()
        self.base_loss = base_loss

    def forward(self, outputs, targets):
        loss_a = self.base_loss(outputs[:,0], targets[:,0]) + \
                 self.base_loss(outputs[:,1], targets[:,1])
        loss_b = self.base_loss(outputs[:,0], targets[:,1]) + \
                 self.base_loss(outputs[:,1], targets[:,0])
        return torch.min(loss_a, loss_b)

In [46]:
dataset = DiPCoSeparationDataset(
    root_dir='./data/Dipco/',
    session_ids=['S02', 'S04', 'S05', 'S09', 'S10'],
    segment_length=3,
)
train_loader = DataLoader(dataset, batch_size=1, shuffle=True)
print(len(train_loader))
for mixed, clean in train_loader:
    print(mixed.shape, clean.shape)
    print(mixed, clean)
    break

1004
torch.Size([1, 128]) torch.Size([1, 2, 128])
tensor([[2.3471e-04, 7.8877e-04, 1.7647e-03, 2.8701e-03, 4.1116e-03, 8.3375e-03,
         3.6199e-02, 1.7696e-01, 4.4592e-01, 6.2209e-01, 6.2036e-01, 4.8691e-01,
         3.5304e-01, 2.6162e-01, 2.3135e-01, 2.7661e-01, 3.0550e-01, 3.3568e-01,
         3.3718e-01, 3.0596e-01, 2.9672e-01, 2.9946e-01, 2.4864e-01, 3.0665e-01,
         3.5024e-01, 3.4213e-01, 3.5246e-01, 3.1802e-01, 3.5571e-01, 2.6817e-01,
         2.7617e-01, 2.8831e-01, 2.4807e-01, 2.3035e-01, 2.1701e-01, 1.5719e-01,
         9.3636e-02, 8.2106e-02, 6.6607e-02, 6.7845e-02, 4.9850e-02, 3.8268e-02,
         3.5620e-02, 3.2617e-02, 3.8799e-02, 2.9091e-02, 2.5990e-02, 2.6568e-02,
         2.5496e-02, 2.4307e-02, 2.8702e-02, 3.3976e-02, 2.4577e-02, 2.3908e-02,
         2.5708e-02, 2.6513e-02, 2.6127e-02, 2.4911e-02, 2.6735e-02, 2.2643e-02,
         2.1462e-02, 2.1346e-02, 2.0697e-02, 1.7069e-02, 1.3754e-02, 1.2777e-02,
         1.1058e-02, 7.7549e-03, 6.3068e-03, 6.1289e-03, 4.

In [47]:
model = SeparationKAN(input_size=128)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = PermutationInvariantLoss()

# Training loop
for epoch in range(1):
    epoch_loss = 0
    for mixed, clean in train_loader:
        optimizer.zero_grad()
        outputs = model(mixed)
        loss = criterion(outputs, clean)
        print(loss)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Loss: {epoch_loss/len(train_loader):.4f}")


checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0


  self.subnode_actscale.append(torch.std(x, dim=0).detach())
  input_range = torch.std(preacts, dim=0) + 0.1
  output_range_spline = torch.std(postacts_numerical, dim=0) # for training, only penalize the spline part
  output_range = torch.std(postacts, dim=0) # for visualization, include the contribution from both spline + symbolic


KeyboardInterrupt: 