In [None]:
# Colab shell cell: unpack the preprocessed 2-speaker archive that is expected to already sit
# at /content/2Speakers5KHalfPreprocessed.zip into /content/.
# NOTE(review): the next cell extracts an archive with the same name from Google Drive into
# /content/Speakers5KHalfPreprocessed, and the training cells read from that location; this
# cell looks redundant. The commented-out lines are leftovers from earlier dataset variants.
# !unzip "/content/2Speakers5KHalf.zip" -d "/content/"
!unzip "/content/2Speakers5KHalfPreprocessed.zip" -d "/content/"
# !unzip smalleroutput3.zip

In [None]:
import os
import zipfile

# Copy the preprocessed 2-speaker archive out of Google Drive and unpack it on the local
# Colab disk; the limited-precision training cell reads from dest_dir.
src_zip = '/content/drive/MyDrive/Colab Notebooks/conv_tasnet_data/2Speakers5KHalfPreprocessed.zip'
dest_dir = '/content/Speakers5KHalfPreprocessed'

os.makedirs(name=dest_dir, exist_ok=True)  # create up front; no-op when it already exists
with zipfile.ZipFile(src_zip, mode='r') as zip_ref:
    zip_ref.extractall(path=dest_dir)

print(f"Zip file extracted to: {dest_dir}")

Zip file extracted to: /content/Speakers5KHalfPreprocessed


In [None]:
# Ordinary Conv-TasNet
#
# Baseline experiment: train torchaudio's ConvTasNet from scratch (no pretrained weights)
# separately for 2, 3 and 4 speakers, logging per-iteration loss, elapsed time and allocated
# GPU memory to a CSV file per speaker count.

import torch
import torchaudio
import itertools
print(torch.__version__)
print(torchaudio.__version__)
# NOTE(review): Audio, HDEMUCS_HIGH_MUSDB_PLUS, CONVTASNET_BASE_LIBRI2MIX and download_asset
# are not referenced anywhere in this cell.
from IPython.display import Audio
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
from torchaudio.pipelines import CONVTASNET_BASE_LIBRI2MIX
from torchaudio.utils import download_asset
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import os
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import Resample
import time
import csv

# NOTE(review): identical to the `device` assignment a few lines above; one of them is redundant.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# SNR Metric
def si_snr_metric(estimation, original):
    """Scale-invariant signal-to-noise ratio (SI-SNR) in dB, reduced over the last dimension.

    Both inputs are mean-centred. The estimate is projected onto the reference to obtain the
    target component, and whatever remains is treated as noise:
    ``10 * log10(|s_target|^2 / |e_noise|^2)``. Leading dimensions broadcast element-wise.

    Args:
        estimation: estimated signal(s), time on the last dimension.
        original: reference signal(s), broadcast-compatible with ``estimation``.

    Returns:
        Tensor of SI-SNR values (dB) with the last dimension removed.
    """
    estimation = estimation - estimation.mean(dim=-1, keepdim=True)
    original = original - original.mean(dim=-1, keepdim=True)
    s_target = torch.sum(estimation * original, dim=-1, keepdim=True) * original / (torch.sum(original**2, dim=-1, keepdim=True) + 1e-8)
    e_noise = estimation - s_target
    snr = torch.sum(s_target**2, dim=-1) / (torch.sum(e_noise**2, dim=-1) + 1e-8)
    # Guard the log: when the projection onto the reference vanishes (silent estimate or
    # reference, orthogonal signals) snr is exactly 0. log10(0) = -inf would make the loss
    # infinite and the gradients NaN. The same guard is used by the limited-precision cell.
    si_snr = 10 * torch.log10(snr + 1e-16)
    return si_snr

def permutation_invariant_snr_loss(estimation, originals):
    """Negative permutation-invariant SI-SNR, averaged over the batch.

    For each example, estimated sources are matched to reference sources under the assignment
    (permutation) that maximises the mean SI-SNR across sources. The loss is the negated
    batch mean of those best values.

    Args:
        estimation: tensor ``(batch, num_sources, time)`` of separated signals.
        originals: tensor ``(batch, num_sources, time)`` of reference signals.

    Returns:
        Scalar tensor. Minimising it maximises SI-SNR.
    """
    _, num_sources, _ = estimation.shape

    # pairwise[b, i, j] = SI-SNR(estimate i, reference j). Broadcasting (B, S, 1, T) against
    # (B, 1, S, T) evaluates all S*S pairs in one call, instead of re-evaluating S pairs for
    # each of the S! permutations.
    pairwise = si_snr_metric(estimation.unsqueeze(2), originals.unsqueeze(1))

    # perms[p, i] = reference index assigned to estimate i under permutation p.
    perms = torch.tensor(list(itertools.permutations(range(num_sources))), device=estimation.device)
    source_idx = torch.arange(num_sources, device=estimation.device)
    # Advanced indexing -> (batch, num_permutations, num_sources); mean over sources.
    per_perm_mean = pairwise[:, source_idx, perms].mean(dim=-1)
    best_snr = per_perm_mean.max(dim=-1).values

    # Mean over batch; negated because the optimizer minimises the loss.
    return -best_snr.mean()

def load_and_preprocess_audio(file_path, target_sample_rate=8000):
    """Load an audio file with torchaudio, resampling to ``target_sample_rate`` when needed.

    Returns the tensor produced by ``torchaudio.load`` (channels-first by torchaudio's
    default), passed through a ``Resample`` transform only when the file's native rate
    differs from the target.
    """
    waveform, native_rate = torchaudio.load(file_path)
    if native_rate == target_sample_rate:
        return waveform
    return Resample(orig_freq=native_rate, new_freq=target_sample_rate)(waveform)

class AudioDataset(Dataset):
    """Dataset of ``(mixture, [source_1, ..., source_S])`` waveforms found under ``root_dir``.

    Every directory visited by ``os.walk`` is a candidate example. A file whose name contains
    "mixed" is the mixture (the last such file in sorted order wins), and every other file is
    a source. Directories without a mixture, or whose source count differs from
    ``num_sources``, are skipped.

    Args:
        root_dir: root of the dataset tree.
        target_sample_rate: rate each waveform is resampled to when loaded.
        num_sources: number of source files an example directory must contain. ``None`` (the
            default) falls back to the module-level ``NUM_SOURCES`` variable, which this class
            previously read implicitly.
    """

    def __init__(self, root_dir, target_sample_rate=8000, num_sources=None):
        if num_sources is None:
            # Backward-compatible fallback to the notebook-level loop variable.
            num_sources = NUM_SOURCES
        self.root_dir = root_dir
        self.target_sample_rate = target_sample_rate
        self.num_sources = num_sources
        self.samples = []

        for dirpath, _, filenames in os.walk(root_dir):
            mixed_file = None
            originals = []
            # Sort so the order of sources (and the chosen mixture) is deterministic.
            for filename in sorted(filenames):
                if "mixed" in filename:
                    mixed_file = os.path.join(dirpath, filename)
                else:
                    originals.append(os.path.join(dirpath, filename))

            if mixed_file is not None and originals and len(originals) == num_sources:
                self.samples.append((mixed_file, originals))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(mixed_waveform, [source_waveforms...])`` loaded at the target sample rate."""
        mixed_file, original_files = self.samples[idx]
        mixed_waveform = load_and_preprocess_audio(mixed_file, self.target_sample_rate)
        originals_waveforms = [load_and_preprocess_audio(file, self.target_sample_rate) for file in original_files]
        return mixed_waveform, originals_waveforms

def pad_sequence(batch):
    """Zero-pad each tensor on the right of its last dimension to the batch maximum, then stack.

    NOTE(review): not called in the visible notebook cells (the collate function pads inline).

    Args:
        batch: non-empty sequence of tensors that agree on every dimension except the last.

    Returns:
        Tensor of shape ``(len(batch), *leading_dims, max_len)``.
    """
    longest = max(item.size(-1) for item in batch)
    padded = []
    for item in batch:
        shortfall = longest - item.size(-1)
        padded.append(torch.nn.functional.pad(item, (0, shortfall)))
    return torch.stack(padded, dim=0)

def custom_collate_fn(batch):
    """Collate ``(mixture, [source_1, ..., source_S])`` dataset items into padded batch tensors.

    Every waveform in the batch, mixtures and sources alike, is zero-padded on the right of
    its last dimension to the longest length found anywhere in the batch, so mixtures and
    sources stay sample-aligned. The stacked results are moved to the global ``device``.

    Returns:
        padded_mixed: tensor ``(batch, *mixture.shape[:-1], max_len)``.
        padded_originals: list with one tensor per source index, each
            ``(batch, *source.shape[:-1], max_len)``.
    """
    max_len = max(
        [waveform.size(-1) for pair in batch for waveform in pair[1]]
        + [pair[0].size(-1) for pair in batch]
    )

    def _pad(waveform):
        # Zero-pad on the right of the time (last) axis only.
        return torch.nn.functional.pad(waveform, (0, max_len - waveform.size(-1)))

    # Pad and stack where the data already lives, then make one device copy per stacked
    # tensor instead of one small transfer per waveform.
    padded_mixed = torch.stack([_pad(pair[0]) for pair in batch], dim=0).to(device)
    padded_originals = [
        torch.stack([_pad(waveform) for waveform in per_source], dim=0).to(device)
        for per_source in zip(*(pair[1] for pair in batch))
    ]
    return padded_mixed, padded_originals

# Train one freshly initialised ConvTasNet per speaker count (2, 3 and 4) and write a CSV log
# of per-iteration loss, elapsed time and allocated GPU memory for each.
for NUM_SOURCES in range(2, 5):
    model = torchaudio.models.ConvTasNet(
        num_sources=NUM_SOURCES,
        enc_kernel_size=16,
        enc_num_feats=512,
        msk_kernel_size=3,
        msk_num_feats=128,
        msk_num_hidden_feats=512,
        msk_num_layers=8,
        msk_num_stacks=3,
        msk_activate='sigmoid'
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    root_directory = f'/content/Speakers5KHalf/{NUM_SOURCES}Speakers5KHalf'
    # AudioDataset in this cell filters example folders by the loop variable NUM_SOURCES.
    dataset = AudioDataset(root_directory)

    dataloader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=custom_collate_fn)

    all_loss = []
    timestamps = []
    gpu_memory_usage = []

    counter = 1
    start_time = time.time()
    last_loss = None  # stays None if the dataloader yields no batches
    print("Epoch 0")
    for epoch in range(50):
        for mixed, originals in dataloader:
            optimizer.zero_grad()
            separated_sources = model(mixed)
            # originals holds one batched tensor per source; stack on dim 1 and drop the
            # (mono) channel axis -> (batch, num_sources, time).
            original_sources = torch.stack(originals, dim=1).squeeze(2)

            loss = permutation_invariant_snr_loss(separated_sources, original_sources)
            loss.backward()
            optimizer.step()

            last_loss = loss.item()  # one host sync per iteration, reused below
            elapsed_time = time.time() - start_time
            # torch.cuda.memory_allocated expects a CUDA device; log 0 MB on CPU-only runs.
            if device.type == 'cuda':
                current_gpu_memory = torch.cuda.memory_allocated(device) / (1024 * 1024)  # bytes -> MB
            else:
                current_gpu_memory = 0.0
            timestamps.append(elapsed_time)
            all_loss.append(last_loss)
            gpu_memory_usage.append(current_gpu_memory)
            print(f"Iteration {counter}: Loss = {last_loss} Time = {elapsed_time}s GPU Memory = {current_gpu_memory} MB")
            counter += 1

        print(f'\n\nEpoch {epoch+1}, Loss: {last_loss}\n')

    # Save the per-iteration log (CSV rows are 0-indexed; console iterations start at 1).
    log_path = f'/content/drive/MyDrive/Colab Notebooks/conv_tasnet_data/training_log_{NUM_SOURCES}SpeakerHalfB=64Epoch50.csv'
    with open(log_path, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Iteration', 'Loss', 'Time', 'GPU Memory (MB)'])
        # `loss_value`, not `loss`, so the training-loss tensor name is not shadowed.
        for i, (loss_value, timestamp, gpu_memory) in enumerate(zip(all_loss, timestamps, gpu_memory_usage)):
            writer.writerow([i, loss_value, timestamp, gpu_memory])

    # Report the file that was actually written.
    print(f"Training log saved to '{log_path}'.")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Iteration 3098: Loss = 0.2766939103603363 Time = 1655.2319116592407s GPU Memory = 86.9326171875 MB
Iteration 3099: Loss = 0.1233791708946228 Time = 1655.769723892212s GPU Memory = 96.6982421875 MB
Iteration 3100: Loss = 0.6494510173797607 Time = 1656.3088710308075s GPU Memory = 86.9326171875 MB
Iteration 3101: Loss = 0.30677855014801025 Time = 1656.8342418670654s GPU Memory = 96.6982421875 MB
Iteration 3102: Loss = 0.2691302001476288 Time = 1657.3644750118256s GPU Memory = 86.9326171875 MB
Iteration 3103: Loss = 0.29814738035202026 Time = 1657.899381160736s GPU Memory = 96.6982421875 MB
Iteration 3104: Loss = 0.11446712911128998 Time = 1658.4323675632477s GPU Memory = 86.9326171875 MB
Iteration 3105: Loss = 0.07494156062602997 Time = 1658.9678382873535s GPU Memory = 96.6982421875 MB
Iteration 3106: Loss = 0.4840344786643982 Time = 1659.5045080184937s GPU Memory = 96.6982421875 MB
Iteration 3107: Loss = 0.04542829841375351

In [None]:
# Limited Pre Conv-TasNet
#
# Same ConvTasNet sweep (2-4 speakers) on the preprocessed dataset, trained with fp16
# autocast + GradScaler. NOTE(review): "Limited Pre" presumably means this limited (mixed)
# precision setup; confirm.

import torch
import torchaudio
import itertools
print(torch.__version__)
print(torchaudio.__version__)
# NOTE(review): Audio and download_asset are not referenced in this cell.
from IPython.display import Audio
from torchaudio.utils import download_asset
import os
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import Resample
import time
import csv
from torch.utils.data import RandomSampler


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Process-wide setting: tensor factory calls without an explicit device now allocate on
# `device`, and this persists into every later cell of the session.
torch.set_default_device(device)
# NOTE(review): `generator` is not referenced in this cell; the RandomSampler below builds
# its own torch.Generator(device=device).
generator = torch.Generator(device=device)

# SNR Metric
def si_snr_metric(estimation, original):
    """Return the scale-invariant SNR (SI-SNR) in dB, reduced over the last dimension.

    Both signals are mean-centred. The estimate is then split into its component along the
    reference (``target``) and the residual (``noise``). The result is
    ``10 * log10(|target|^2 / |noise|^2)``, with small constants guarding both divisions and
    the log.
    """
    est_c = estimation - estimation.mean(dim=-1, keepdim=True)
    ref_c = original - original.mean(dim=-1, keepdim=True)
    dot = torch.sum(est_c * ref_c, dim=-1, keepdim=True)
    ref_energy = torch.sum(ref_c**2, dim=-1, keepdim=True) + 1e-8
    target = dot * ref_c / ref_energy
    noise = est_c - target
    ratio = torch.sum(target**2, dim=-1) / (torch.sum(noise**2, dim=-1) + 1e-8)
    return 10 * torch.log10(ratio + 1e-16)

def permutation_invariant_snr_loss(estimation, originals):
    """Negative permutation-invariant SI-SNR, averaged over the batch.

    For each example, estimated sources are matched to reference sources under the assignment
    (permutation) that maximises the mean SI-SNR across sources. The loss is the negated
    batch mean of those best values.

    Args:
        estimation: tensor ``(batch, num_sources, time)`` of separated signals.
        originals: tensor ``(batch, num_sources, time)`` of reference signals.

    Returns:
        Scalar tensor. Minimising it maximises SI-SNR.
    """
    _, num_sources, _ = estimation.shape

    # pairwise[b, i, j] = SI-SNR(estimate i, reference j). Broadcasting (B, S, 1, T) against
    # (B, 1, S, T) evaluates all S*S pairs in one call, instead of re-evaluating S pairs for
    # each of the S! permutations.
    pairwise = si_snr_metric(estimation.unsqueeze(2), originals.unsqueeze(1))

    # perms[p, i] = reference index assigned to estimate i under permutation p. The device is
    # explicit so this does not depend on torch's default device.
    perms = torch.tensor(list(itertools.permutations(range(num_sources))), device=estimation.device)
    source_idx = torch.arange(num_sources, device=estimation.device)
    # Advanced indexing -> (batch, num_permutations, num_sources); mean over sources.
    per_perm_mean = pairwise[:, source_idx, perms].mean(dim=-1)
    best_snr = per_perm_mean.max(dim=-1).values

    # Mean over batch; negated because the optimizer minimises the loss.
    return -best_snr.mean()

def load_and_preprocess_audio(file_path, target_sample_rate=8000):
    """Load an audio file with torchaudio, resampling to ``target_sample_rate`` when needed.

    Returns the tensor produced by ``torchaudio.load`` (channels-first by torchaudio's
    default), passed through a ``Resample`` transform only when the file's native rate
    differs from the target.
    """
    waveform, native_rate = torchaudio.load(file_path)
    if native_rate == target_sample_rate:
        return waveform
    return Resample(orig_freq=native_rate, new_freq=target_sample_rate)(waveform)

class AudioDataset(Dataset):
    """Dataset of ``(mixture, [source_1, ..., source_S])`` waveforms found under ``root_dir``.

    Each directory visited by ``os.walk`` is a candidate example. Files are considered in
    sorted order. A name containing "mixed" marks the mixture (the last match wins), and all
    other files are sources. A directory is kept only when it has a mixture and exactly
    ``num_sources`` sources.
    """

    def __init__(self, root_dir, num_sources, target_sample_rate=8000):
        self.root_dir = root_dir
        self.target_sample_rate = target_sample_rate
        self.samples = []
        for folder, _subdirs, names in os.walk(root_dir):
            ordered = sorted(names)
            mix_candidates = [os.path.join(folder, n) for n in ordered if "mixed" in n]
            sources = [os.path.join(folder, n) for n in ordered if "mixed" not in n]
            mixture = mix_candidates[-1] if mix_candidates else None
            if len(sources) == num_sources and mixture and sources:
                self.samples.append((mixture, sources))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load ``(mixture_waveform, [source_waveforms...])`` at the target sample rate."""
        mix_path, source_paths = self.samples[idx]
        rate = self.target_sample_rate
        mixture = load_and_preprocess_audio(mix_path, rate)
        sources = [load_and_preprocess_audio(path, rate) for path in source_paths]
        return mixture, sources

def pad_sequence(batch):
    """Zero-pad each tensor on the right of its last dimension to the batch maximum, then stack.

    NOTE(review): not called in the visible notebook cells (the collate function pads inline).

    Args:
        batch: non-empty sequence of tensors that agree on every dimension except the last.

    Returns:
        Tensor of shape ``(len(batch), *leading_dims, max_len)``.
    """
    longest = max(item.size(-1) for item in batch)
    padded = []
    for item in batch:
        shortfall = longest - item.size(-1)
        padded.append(torch.nn.functional.pad(item, (0, shortfall)))
    return torch.stack(padded, dim=0)

def custom_collate_fn(batch):
    """Collate ``(mixture, [source_1, ..., source_S])`` dataset items into padded batch tensors.

    Every waveform in the batch, mixtures and sources alike, is zero-padded on the right of
    its last dimension to the longest length found anywhere in the batch, so mixtures and
    sources stay sample-aligned. The stacked results are moved to the global ``device``.

    Returns:
        padded_mixed: tensor ``(batch, *mixture.shape[:-1], max_len)``.
        padded_originals: list with one tensor per source index, each
            ``(batch, *source.shape[:-1], max_len)``.
    """
    max_len = max(
        [waveform.size(-1) for pair in batch for waveform in pair[1]]
        + [pair[0].size(-1) for pair in batch]
    )

    def _pad(waveform):
        # Zero-pad on the right of the time (last) axis only.
        return torch.nn.functional.pad(waveform, (0, max_len - waveform.size(-1)))

    # Pad and stack where the data already lives, then make one device copy per stacked
    # tensor instead of one small transfer per waveform.
    padded_mixed = torch.stack([_pad(pair[0]) for pair in batch], dim=0).to(device)
    padded_originals = [
        torch.stack([_pad(waveform) for waveform in per_source], dim=0).to(device)
        for per_source in zip(*(pair[1] for pair in batch))
    ]
    return padded_mixed, padded_originals

# Train one freshly initialised ConvTasNet per speaker count (2, 3 and 4) on the preprocessed
# data with fp16 automatic mixed precision, writing a CSV log for each.
for NUM_SOURCES in range(2, 5):
    model = torchaudio.models.ConvTasNet(
        num_sources=NUM_SOURCES,
        enc_kernel_size=16,
        enc_num_feats=512,
        msk_kernel_size=3,
        msk_num_feats=128,
        msk_num_hidden_feats=512,
        msk_num_layers=8,
        msk_num_stacks=3,
        msk_activate='sigmoid'
    ).to(device)

    # fp16 autocast and gradient scaling here are CUDA features; enable them only when
    # training on a CUDA device.
    use_amp = device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    root_directory = f'/content/Speakers5KHalfPreprocessed/{NUM_SOURCES}Speakers5KHalfPreprocessed'
    dataset = AudioDataset(root_directory, NUM_SOURCES)

    # The sampler's generator is created on `device`, matching the default device set in this cell.
    sampler = RandomSampler(dataset, generator=torch.Generator(device=device))
    dataloader = DataLoader(dataset, batch_size=64, sampler=sampler, collate_fn=custom_collate_fn)

    all_loss = []
    timestamps = []
    gpu_memory_usage = []

    counter = 1
    start_time = time.time()
    last_loss = None  # stays None if the dataloader yields no batches
    print("Epoch 0")
    for epoch in range(50):
        for mixed, originals in dataloader:
            # Only the model's forward pass runs under fp16 autocast.
            with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                separated_sources = model(mixed)
            # SI-SNR divides signal energies and takes a log; evaluate it in float32 on an
            # upcast copy of the model output rather than on fp16 values.
            original_sources = torch.stack(originals, dim=1).squeeze(2)
            loss = permutation_invariant_snr_loss(separated_sources.float(), original_sources)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            last_loss = loss.item()  # one host sync per iteration, reused below
            elapsed_time = time.time() - start_time
            # torch.cuda.memory_allocated expects a CUDA device; log 0 MB on CPU-only runs.
            if device.type == 'cuda':
                current_gpu_memory = torch.cuda.memory_allocated(device) / (1024 * 1024)  # bytes -> MB
            else:
                current_gpu_memory = 0.0
            timestamps.append(elapsed_time)
            all_loss.append(last_loss)
            gpu_memory_usage.append(current_gpu_memory)
            print(f"Iteration {counter}: Loss = {last_loss} Time = {elapsed_time}s GPU Memory = {current_gpu_memory} MB")
            counter += 1

        print(f'\n\nEpoch {epoch+1}, Loss: {last_loss}\n')

    # Save the per-iteration log (CSV rows are 0-indexed; console iterations start at 1).
    log_path = f'/content/drive/MyDrive/Colab Notebooks/conv_tasnet_data/training_log_{NUM_SOURCES}Speaker_limited_halfB=64Epoch50.csv'
    with open(log_path, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Iteration', 'Loss', 'Time', 'GPU Memory (MB)'])
        # `loss_value`, not `loss`, so the training-loss tensor name is not shadowed.
        for i, (loss_value, timestamp, gpu_memory) in enumerate(zip(all_loss, timestamps, gpu_memory_usage)):
            writer.writerow([i, loss_value, timestamp, gpu_memory])

    # Report the file that was actually written.
    print(f"Training log saved to '{log_path}'.")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Iteration 3098: Loss = -1.598955512046814 Time = 1233.7673757076263s GPU Memory = 66.54736328125 MB
Iteration 3099: Loss = -1.1516087055206299 Time = 1234.1653127670288s GPU Memory = 66.44580078125 MB
Iteration 3100: Loss = -1.7627147436141968 Time = 1234.5581171512604s GPU Memory = 66.5478515625 MB
Iteration 3101: Loss = -1.499795913696289 Time = 1234.954992055893s GPU Memory = 66.44580078125 MB
Iteration 3102: Loss = -1.4854400157928467 Time = 1235.3475306034088s GPU Memory = 66.54736328125 MB
Iteration 3103: Loss = -1.5390839576721191 Time = 1235.7399806976318s GPU Memory = 66.44580078125 MB
Iteration 3104: Loss = -1.392268180847168 Time = 1236.135347366333s GPU Memory = 66.54736328125 MB
Iteration 3105: Loss = -1.66135835647583 Time = 1236.5334961414337s GPU Memory = 66.44580078125 MB
Iteration 3106: Loss = -1.374371886253357 Time = 1236.9325890541077s GPU Memory = 66.54736328125 MB
Iteration 3107: Loss = -1.149978876

In [None]:
import os
import zipfile

# Unpack the 3-speaker archive from Google Drive into /content/Speakers5KHalf.
src_zip = '/content/drive/MyDrive/Colab Notebooks/conv_tasnet_data/3Speakers5KHalf.zip'
dest_dir = '/content/Speakers5KHalf'

os.makedirs(name=dest_dir, exist_ok=True)  # create up front; no-op when it already exists
with zipfile.ZipFile(src_zip, mode='r') as zip_ref:
    zip_ref.extractall(path=dest_dir)

print(f"Zip file extracted to: {dest_dir}")

Zip file extracted to: /content/Speakers5KHalf


In [None]:
#Attention Conv-TasNet
#
# Attention front-end experiment: a strided conv encoder, multi-head self-attention and a
# transposed-conv decoder feed torchaudio's ConvTasNet (AttentionConvTasNet below). It is
# trained in full precision (use_amp = False) on the 4-speaker data in /content/Speakers5KHalf.

import torch
import torchaudio
import itertools
from torch import nn
import os
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import Resample
import time
import csv
from torch.utils.data import RandomSampler


print(torch.__version__)
print(torchaudio.__version__)
# Select the GPU when available. set_default_device is process-wide: tensor factory calls
# without an explicit device allocate on `device` from here on.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)
# NOTE(review): `generator` is not referenced in this cell; the RandomSampler below builds
# its own torch.Generator(device=device).
generator = torch.Generator(device=device)

def si_snr_metric(estimation, original):
    """Scale-invariant signal-to-noise ratio (SI-SNR) in dB, reduced over the last dimension.

    Both inputs are mean-centred. The estimate is projected onto the reference to obtain the
    target component, and whatever remains is treated as noise:
    ``10 * log10(|s_target|^2 / |e_noise|^2)``. Leading dimensions broadcast element-wise.

    Args:
        estimation: estimated signal(s), time on the last dimension.
        original: reference signal(s), broadcast-compatible with ``estimation``.

    Returns:
        Tensor of SI-SNR values (dB) with the last dimension removed.
    """
    estimation = estimation - estimation.mean(dim=-1, keepdim=True)
    original = original - original.mean(dim=-1, keepdim=True)
    s_target = torch.sum(estimation * original, dim=-1, keepdim=True) * original / (torch.sum(original**2, dim=-1, keepdim=True) + 1e-6)
    e_noise = estimation - s_target
    snr = torch.sum(s_target**2, dim=-1) / (torch.sum(e_noise**2, dim=-1) + 1e-8)
    # Guard the log: when the projection onto the reference vanishes (silent estimate or
    # reference, orthogonal signals) snr is exactly 0. log10(0) = -inf would make the loss
    # infinite and the gradients NaN.
    si_snr = 10 * torch.log10(snr + 1e-16)
    return si_snr

def permutation_invariant_snr_loss(estimation, originals):
    """Negative permutation-invariant SI-SNR, averaged over the batch.

    For each example, estimated sources are matched to reference sources under the assignment
    (permutation) that maximises the mean SI-SNR across sources. The loss is the negated
    batch mean of those best values.

    Args:
        estimation: tensor ``(batch, num_sources, time)`` of separated signals.
        originals: tensor ``(batch, num_sources, time)`` of reference signals.

    Returns:
        Scalar tensor. Minimising it maximises SI-SNR.
    """
    _, num_sources, _ = estimation.shape

    # pairwise[b, i, j] = SI-SNR(estimate i, reference j). Broadcasting (B, S, 1, T) against
    # (B, 1, S, T) evaluates all S*S pairs in one call, instead of re-evaluating S pairs for
    # each of the S! permutations.
    pairwise = si_snr_metric(estimation.unsqueeze(2), originals.unsqueeze(1))

    # perms[p, i] = reference index assigned to estimate i under permutation p. The device is
    # explicit so this does not depend on torch's default device.
    perms = torch.tensor(list(itertools.permutations(range(num_sources))), device=estimation.device)
    source_idx = torch.arange(num_sources, device=estimation.device)
    # Advanced indexing -> (batch, num_permutations, num_sources); mean over sources.
    per_perm_mean = pairwise[:, source_idx, perms].mean(dim=-1)
    best_snr = per_perm_mean.max(dim=-1).values

    # Mean over batch; negated because the optimizer minimises the loss.
    return -best_snr.mean()

def load_and_preprocess_audio(file_path, target_sample_rate=8000):
    """Load an audio file with torchaudio, resampling to ``target_sample_rate`` when needed.

    Returns the tensor produced by ``torchaudio.load`` (channels-first by torchaudio's
    default), passed through a ``Resample`` transform only when the file's native rate
    differs from the target.
    """
    waveform, native_rate = torchaudio.load(file_path)
    if native_rate == target_sample_rate:
        return waveform
    return Resample(orig_freq=native_rate, new_freq=target_sample_rate)(waveform)

def pad_sequence(batch):
    """Zero-pad each tensor on the right of its last dimension to the batch maximum, then stack.

    NOTE(review): not called in the visible notebook cells (the collate function pads inline).

    Args:
        batch: non-empty sequence of tensors that agree on every dimension except the last.

    Returns:
        Tensor of shape ``(len(batch), *leading_dims, max_len)``.
    """
    longest = max(item.size(-1) for item in batch)
    padded = []
    for item in batch:
        shortfall = longest - item.size(-1)
        padded.append(torch.nn.functional.pad(item, (0, shortfall)))
    return torch.stack(padded, dim=0)

def custom_collate_fn(batch):
    """Collate ``(mixture, [source_1, ..., source_S])`` dataset items into padded batch tensors.

    Every waveform in the batch, mixtures and sources alike, is zero-padded on the right of
    its last dimension to the longest length found anywhere in the batch, so mixtures and
    sources stay sample-aligned. The stacked results are moved to the global ``device``.

    Returns:
        padded_mixed: tensor ``(batch, *mixture.shape[:-1], max_len)``.
        padded_originals: list with one tensor per source index, each
            ``(batch, *source.shape[:-1], max_len)``.
    """
    max_len = max(
        [waveform.size(-1) for pair in batch for waveform in pair[1]]
        + [pair[0].size(-1) for pair in batch]
    )

    def _pad(waveform):
        # Zero-pad on the right of the time (last) axis only.
        return torch.nn.functional.pad(waveform, (0, max_len - waveform.size(-1)))

    # Pad and stack where the data already lives, then make one device copy per stacked
    # tensor instead of one small transfer per waveform.
    padded_mixed = torch.stack([_pad(pair[0]) for pair in batch], dim=0).to(device)
    padded_originals = [
        torch.stack([_pad(waveform) for waveform in per_source], dim=0).to(device)
        for per_source in zip(*(pair[1] for pair in batch))
    ]
    return padded_mixed, padded_originals

class AudioDataset(Dataset):
    """Dataset of ``(mixture, [source_1, ..., source_S])`` waveforms found under ``root_dir``.

    Each directory visited by ``os.walk`` is a candidate example. Files are considered in
    sorted order. A name containing "mixed" marks the mixture (the last match wins), and all
    other files are sources. A directory is kept only when it has a mixture and exactly
    ``num_sources`` sources.
    """

    def __init__(self, root_dir, num_sources, target_sample_rate=8000):
        self.root_dir = root_dir
        self.target_sample_rate = target_sample_rate
        self.samples = []
        for folder, _subdirs, names in os.walk(root_dir):
            ordered = sorted(names)
            mix_candidates = [os.path.join(folder, n) for n in ordered if "mixed" in n]
            sources = [os.path.join(folder, n) for n in ordered if "mixed" not in n]
            mixture = mix_candidates[-1] if mix_candidates else None
            if len(sources) == num_sources and mixture and sources:
                self.samples.append((mixture, sources))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Load ``(mixture_waveform, [source_waveforms...])`` at the target sample rate."""
        mix_path, source_paths = self.samples[idx]
        rate = self.target_sample_rate
        mixture = load_and_preprocess_audio(mix_path, rate)
        sources = [load_and_preprocess_audio(path, rate) for path in source_paths]
        return mixture, sources


# Attention mechanism
class Encoder(nn.Module):
    """Stack of ``num_layers`` stride-2 1-D convolutions, each followed by ReLU.

    Every convolution uses padding ``kernel_size // 2``. The first maps ``input_channels`` to
    ``hidden_channels`` and the rest keep ``hidden_channels``. Input and output are
    ``(batch, channels, length)`` as for ``nn.Conv1d``.
    """

    def __init__(self, input_channels, hidden_channels, kernel_size, num_layers):
        super().__init__()
        blocks = []
        in_ch = input_channels
        for _ in range(num_layers):
            blocks += [
                nn.Conv1d(in_ch, hidden_channels, kernel_size, stride=2, padding=kernel_size // 2),
                nn.ReLU(),
            ]
            in_ch = hidden_channels
        self.encoder = nn.Sequential(*blocks)

    def forward(self, x):
        return self.encoder(x)

class AttentionLayer(nn.Module):
    """Multi-head self-attention with a residual connection, dropout and a post-LayerNorm.

    Takes and returns ``(batch_size, seq_len, d_model)``. ``nn.MultiheadAttention`` is used in
    its default sequence-first layout, so the input is transposed around the attention call.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        seq_first = x.transpose(0, 1)  # (seq_len, batch_size, d_model)
        attended, _weights = self.multihead_attn(seq_first, seq_first, seq_first)
        residual = seq_first + self.dropout(attended)
        return self.norm(residual.transpose(0, 1))  # back to (batch_size, seq_len, d_model)

class Decoder(nn.Module):
    """Stack of stride-2 transposed 1-D convolutions that upsamples features back to a signal.

    A ReLU follows every layer except the last. The last layer produces the waveform that is
    fed to ConvTasNet, and audio is signed: a trailing ReLU would zero every negative sample
    (half-wave rectification) before separation starts.

    Args:
        input_channels: channels of the incoming feature map.
        output_channels: channels produced by every layer (1 for a mono waveform).
        kernel_size: transposed-convolution kernel size; padding is ``kernel_size // 2``.
        num_layers: number of transposed-convolution layers.
        final_activation: set True to also apply ReLU after the last layer (the former
            behaviour).
    """

    def __init__(self, input_channels, output_channels, kernel_size, num_layers, final_activation=False):
        super(Decoder, self).__init__()
        layers = []
        for index in range(num_layers):
            layers.append(
                nn.ConvTranspose1d(input_channels, output_channels, kernel_size, stride=2, padding=kernel_size // 2)
            )
            is_last = index == num_layers - 1
            if not is_last or final_activation:
                layers.append(nn.ReLU())
            input_channels = output_channels
        # Convolutions keep Sequential indices 0, 2, 4, ..., so state_dict keys such as
        # "decoder.2.weight" are unchanged.
        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.decoder(x)

class AttentionConvTasNet(nn.Module):
    """Attention front-end followed by torchaudio's ConvTasNet, which performs the separation.

    Pipeline: ``Encoder`` -> linear projection to ``d_model`` -> ``AttentionLayer`` ->
    ``Decoder`` -> ``torchaudio.models.ConvTasNet``.

    Args:
        encoder_params: keyword arguments for ``Encoder`` (``hidden_channels`` sizes the projection input).
        attention_params: keyword arguments for ``AttentionLayer`` (``d_model`` sizes the projection output).
        decoder_params: keyword arguments for ``Decoder``.
        tasnet_params: keyword arguments for ``torchaudio.models.ConvTasNet``.
    """

    def __init__(self, encoder_params, attention_params, decoder_params, tasnet_params):
        super().__init__()
        # Sub-modules are built in this fixed order, which determines the RNG draw order used
        # for parameter initialisation.
        self.encoder = Encoder(**encoder_params)
        self.projection = nn.Linear(encoder_params['hidden_channels'], attention_params['d_model'])
        self.attention_layer = AttentionLayer(**attention_params)
        self.decoder = Decoder(**decoder_params)
        self.conv_tasnet = torchaudio.models.ConvTasNet(**tasnet_params)

    def forward(self, x):
        features = self.encoder(x)                           # (batch, channels, frames)
        _batch, _channels, _frames = features.shape          # requires a 3-D encoder output
        tokens = self.projection(features.permute(0, 2, 1))  # (batch, frames, d_model)
        attended = self.attention_layer(tokens)
        waveform = self.decoder(attended.permute(0, 2, 1))   # decoder expects (batch, d_model, frames)
        return self.conv_tasnet(waveform)


# Train the attention front-end + ConvTasNet model (currently 4 speakers only) in full
# precision and write a CSV log of per-iteration loss, elapsed time and GPU memory.
for NUM_SOURCES in range(4, 5):
    encoder_params = {
        'input_channels': 1,  # mono input
        'hidden_channels': 512,
        'kernel_size': 3,
        'num_layers': 3
    }

    attention_params = {
        'd_model': 512,  # the projection maps hidden_channels -> d_model
        'num_heads': 8,
        'dropout': 0.1
    }

    decoder_params = {
        'input_channels': 512,  # must equal d_model: the decoder consumes the attention output
        'output_channels': 1,
        'kernel_size': 3,
        'num_layers': 3
    }

    tasnet_params = {
        'num_sources': NUM_SOURCES,
        'enc_kernel_size': 16,
        'enc_num_feats': 512,
        'msk_kernel_size': 3,
        'msk_num_feats': 128,
        'msk_num_hidden_feats': 512,
        'msk_num_layers': 8,
        'msk_num_stacks': 3,
        'msk_activate': 'sigmoid'
    }

    model_with_attention = AttentionConvTasNet(encoder_params, attention_params, decoder_params, tasnet_params).to(device)
    optimizer = torch.optim.Adam(model_with_attention.parameters(), lr=1e-4)

    root_directory = f'/content/Speakers5KHalf/{NUM_SOURCES}Speakers5KHalf'
    dataset = AudioDataset(root_directory, NUM_SOURCES)
    sampler = RandomSampler(dataset, generator=torch.Generator(device=device))
    dataloader = DataLoader(dataset, batch_size=64, sampler=sampler, collate_fn=custom_collate_fn)

    all_loss = []
    timestamps = []
    gpu_memory_usage = []
    start_time = time.time()
    counter = 1
    last_loss = None  # stays None if the dataloader yields no batches
    use_amp = False
    # Tie the scaler to use_amp. An always-enabled GradScaler would still multiply the loss
    # by its scale factor and skip optimizer steps whenever it sees inf/NaN gradients, even
    # though this run is full precision. When disabled, scale/step/update are pass-throughs.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    print("Epoch 0")
    for epoch in range(50):
        for mixed, originals in dataloader:
            optimizer.zero_grad()  # reset gradients before computing the next batch
            with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
                separated_sources = model_with_attention(mixed)

                original_sources = torch.stack(originals, dim=1)
                if original_sources.ndim == 4:
                    original_sources = original_sources.squeeze(2)

                # The strided encoder / transposed-conv decoder may return fewer samples than
                # the input; trim the references so the loss compares aligned signals.
                expected_length = separated_sources.size(-1)
                if original_sources.size(-1) != expected_length:
                    original_sources = original_sources[:, :, :expected_length]

                if separated_sources.shape != original_sources.shape:
                    raise RuntimeError(f"Shape mismatch: separated {separated_sources.shape}, original {original_sources.shape}")
                loss = permutation_invariant_snr_loss(separated_sources, original_sources)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            last_loss = loss.item()  # one host sync per iteration, reused below
            elapsed_time = time.time() - start_time
            # torch.cuda.memory_allocated expects a CUDA device; log 0 MB on CPU-only runs.
            if device.type == 'cuda':
                current_gpu_memory = torch.cuda.memory_allocated(device) / (1024 * 1024)  # bytes -> MB
            else:
                current_gpu_memory = 0.0
            timestamps.append(elapsed_time)
            all_loss.append(last_loss)
            gpu_memory_usage.append(current_gpu_memory)

            print(f"Iteration {counter}: Loss = {last_loss} Time = {elapsed_time}s")
            counter += 1

        print(f'\n\nEpoch {epoch + 1}, Loss: {last_loss}\n')

    # Save the per-iteration log (CSV rows are 0-indexed; console iterations start at 1).
    log_path = f'/content/drive/MyDrive/Colab Notebooks/conv_tasnet_data/training_log_{NUM_SOURCES}Speaker_attentionEpoch50HighParam.csv'
    with open(log_path, 'w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(['Iteration', 'Loss', 'Time', 'GPU Memory (MB)'])
        # `loss_value`, not `loss`, so the training-loss tensor name is not shadowed.
        for i, (loss_value, timestamp, gpu_memory) in enumerate(zip(all_loss, timestamps, gpu_memory_usage)):
            writer.writerow([i, loss_value, timestamp, gpu_memory])

    # Report the file that was actually written.
    print(f"Training log saved to '{log_path}'.")
