# In [None]:
import os
import gc
import math
import logging
import warnings
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.nn.modules.rnn import LSTM, GRU
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.linear import Linear
from torch.nn.modules.normalization import LayerNorm

import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader

# Auraloss
from auraloss.time import SISDRLoss
from auraloss.freq import MultiResolutionSTFTLoss

# Sound file
import soundfile as sf

import torchaudio
from tqdm import tqdm

# Complex modules
from cplxmodule.nn import CplxConv2d, CplxConvTranspose2d, CplxBatchNorm2d, CplxConv1d, CplxBatchNorm1d
from utils import *
from complexPyTorch.complexFunctions import complex_relu

# Lightning callbacks
from pytorch_lightning.callbacks import ModelCheckpoint, Callback

# Silence every Python warning for the whole process (this also hides
# library deprecation notices).
warnings.filterwarnings("ignore")
# NOTE(review): on older Python versions getLogger('root') may return a child
# logger literally named 'root' rather than the root logger -- confirm intent.
logging.getLogger('root').setLevel(logging.WARNING)

# NOTE(review): basicConfig sets the root logger level to INFO when the root
# logger has no handlers yet, which may override the WARNING level set above.
logging.basicConfig(level=logging.INFO)
# Machine epsilon of float32 (smallest x with 1.0 + x != 1.0).
epsilon = torch.finfo(torch.float32).eps

class CustomTQDMProgressBar(Callback):
    """Per-epoch TQDM progress bar that shows the latest training metrics.

    After every training batch, the postfix is refreshed with:
      * the most recent entries of ``pl_module.train_sisdr_epoch``,
        ``pl_module.train_stft_epoch`` and ``pl_module.train_total_epoch``
        (0.0 when a list is empty or missing);
      * the learning rate of the first parameter group of the first optimizer.
    """

    def __init__(self):
        super().__init__()
        # tqdm instance for the running epoch; None while no epoch is active.
        self.progress_bar = None

    @staticmethod
    def _latest(values):
        """Return the last entry of ``values`` as a float, or 0.0 if empty/None.

        Entries may be tensors / NumPy scalars (anything with ``.item()``) or
        plain Python numbers.
        """
        if not values:
            return 0.0
        last = values[-1]
        return last.item() if hasattr(last, "item") else float(last)

    def _close_bar(self):
        """Close and drop the current bar, if any."""
        if self.progress_bar is not None:
            self.progress_bar.close()
            self.progress_bar = None

    def on_train_epoch_start(self, trainer, pl_module):
        # Defensive: never leak a bar left open by an interrupted epoch.
        self._close_bar()
        self.progress_bar = tqdm(
            total=trainer.num_training_batches,  # one tick per training batch
            desc=f"Epoch {trainer.current_epoch + 1}/{trainer.max_epochs}",
            leave=True,  # keep the finished bar on screen after the epoch
        )

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Compare with None explicitly: tqdm defines __bool__ as ``total > 0``,
        # so a bar with total == 0 would otherwise look "absent".
        if self.progress_bar is None:
            return

        sisdr_val = self._latest(getattr(pl_module, "train_sisdr_epoch", None))
        stft_val = self._latest(getattr(pl_module, "train_stft_epoch", None))
        total_val = self._latest(getattr(pl_module, "train_total_epoch", None))

        # Current LR from the first optimizer (if one is configured).
        optimizers = trainer.optimizers
        lr_text = f"{optimizers[0].param_groups[0]['lr']:.2e}" if optimizers else "n/a"

        self.progress_bar.set_postfix({
            'SISDR': f"{sisdr_val:.4f}",
            'STFT': f"{stft_val:.4f}",
            'Total': f"{total_val:.4f}",
            'LR': lr_text,
        })
        self.progress_bar.update(1)

    def on_train_epoch_end(self, trainer, pl_module):
        self._close_bar()

# -------------- Utility function to save wav -------------- #
def save_wav(audio_data, sr, filepath):
    """Write ``audio_data`` at sample rate ``sr`` to ``filepath`` via soundfile.

    Best-effort: any failure is logged at ERROR level with its traceback and
    swallowed, so the caller never sees an exception from the write.
    """
    try:
        sf.write(filepath, audio_data, sr)
        logging.info(f"Saved WAV file: {filepath}")
    except Exception as write_error:
        # logging.exception == logging.error(..., exc_info=True)
        logging.exception(f"Failed to save WAV file {filepath}: {write_error}")

# -------------- Dataset for 4s Clipped WAVs -------------- #
class FourSecDataset(Dataset):
    """Paired dataset of pre-clipped ``acc_*.wav`` / ``clean_*.wav`` files.

    Every ``acc_<name>.wav`` found in ``root_4s_dir`` is paired with
    ``clean_<name>.wav`` in the same directory. No filtering, resampling or
    padding is done here. Clips are expected to be prepared already (4-second
    clips, not checked); only the sample rate is verified when a pair is loaded.

    Args:
        root_4s_dir: directory holding both the ``acc_`` and ``clean_`` files.
        mode: accepted for backward compatibility; currently unused.
        sample_rate: sample rate every loaded file must have (default 48000).

    Raises:
        RuntimeError: at construction if no ``acc_*.wav`` file is found, and
            in ``__getitem__`` if a pair's sample rates disagree with each
            other or with ``sample_rate``.
    """

    _ACC_PREFIX = 'acc_'
    _CLEAN_PREFIX = 'clean_'

    def __init__(self, root_4s_dir, mode='acc', sample_rate=48000):
        super().__init__()
        self.root_4s_dir = root_4s_dir
        self.sample_rate = sample_rate  # expected rate, verified on load

        # Sorted so the index -> file mapping is deterministic across runs.
        self.acc_files = sorted(
            f for f in os.listdir(root_4s_dir)
            if f.startswith(self._ACC_PREFIX) and f.lower().endswith('.wav')
        )
        if not self.acc_files:
            raise RuntimeError(f"No 'acc_' wav files found in {root_4s_dir}!")

    @classmethod
    def _clean_filename(cls, acc_filename):
        """Map ``acc_<name>`` to ``clean_<name>``, rewriting only the prefix."""
        return cls._CLEAN_PREFIX + acc_filename[len(cls._ACC_PREFIX):]

    def __len__(self):
        """Number of ``acc_`` files, i.e. number of pairs."""
        return len(self.acc_files)

    def __getitem__(self, idx):
        acc_filename = self.acc_files[idx]
        # e.g. "acc_speaker_001.wav" -> "clean_speaker_001.wav". Only the leading
        # prefix is replaced, so "acc_x_acc_1.wav" -> "clean_x_acc_1.wav".
        clean_filename = self._clean_filename(acc_filename)

        acc_path = os.path.join(self.root_4s_dir, acc_filename)
        clean_path = os.path.join(self.root_4s_dir, clean_filename)

        # torchaudio.load returns (waveform [channels, frames], sample_rate).
        acc_wav, sr = torchaudio.load(acc_path)
        clean_wav, sr2 = torchaudio.load(clean_path)

        if sr != sr2 or sr != self.sample_rate:
            raise RuntimeError(f"Inconsistent sample rates. {acc_path}:{sr}, {clean_path}:{sr2} "
                               f"(expected {self.sample_rate})")

        # Mono files give shape [1, T].
        return {
            'acc': acc_wav,
            'clean': clean_wav,
            'filename': acc_filename,  # the ACC file name identifies the pair
        }
class CplxLinear(nn.Module):
    """Complex linear layer built from two real ``nn.Linear`` layers.

    For input z = x + iy and weights W = A + iB (A = ``real_linear``,
    B = ``imag_linear``)::

        W z = (A x - B y) + i (A y + B x)

    With ``bias=True`` each real layer adds its bias in both products, so the
    effective bias is ``A.bias - B.bias`` on the real part and
    ``A.bias + B.bias`` on the imaginary part.

    ``forward`` reads ``input.real`` / ``input.imag`` and returns a
    ``ComplexTensor``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.real_linear = nn.Linear(in_features, out_features, bias=bias)
        self.imag_linear = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, input):
        apply_a, apply_b = self.real_linear, self.imag_linear
        x, y = input.real, input.imag
        ax = apply_a(x)
        by = apply_b(y)
        ay = apply_a(y)
        bx = apply_b(x)
        return ComplexTensor(ax - by, ay + bx)
    
    
class ComplexEncoder(nn.Module):
    """Complex encoder stage: conv -> complex batch-norm -> complex ReLU.

    Args:
        in_channels, out_channels: channel counts of the convolution.
        kernel_size, stride, padding, bias: convolution hyper-parameters
            (default stride (2, 1) down-samples only the first spatial axis).
        DSC: use a depthwise-separable ``DSC_Encoder`` instead of ``CplxConv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=False,
                 DSC=False):
        super().__init__()
        conv_cls = DSC_Encoder if DSC else CplxConv2d
        self.conv = conv_cls(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
                             bias=bias)
        self.norm = CplxBatchNorm2d(out_channels)

    def forward(self, x):
        y = self.conv(x)
        y = self.norm(y)
        return complex_relu(y)


class ComplexDecoder(nn.Module):
    """Complex decoder stage: transposed conv -> complex batch-norm -> complex ReLU.

    Args:
        in_channels, out_channels: channel counts of the transposed convolution.
        kernel_size, stride, padding, output_padding, bias: transposed-conv
            hyper-parameters.
        DSC: use a depthwise-separable ``DSC_Decoder`` instead of
            ``CplxConvTranspose2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1),
                 output_padding=(0, 0), bias=False, DSC=False):
        super().__init__()
        conv_cls = DSC_Decoder if DSC else CplxConvTranspose2d
        self.conv = conv_cls(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding,
                             output_padding=output_padding, bias=bias)
        self.norm = CplxBatchNorm2d(out_channels)

    def forward(self, x):
        y = self.conv(x)
        y = self.norm(y)
        return complex_relu(y)


class DSC_Encoder(nn.Module):
    """Complex depthwise-separable convolution (encoder side).

    A grouped ``CplxConv2d`` with ``groups=in_channels`` (one filter per input
    channel; carries the stride and padding) followed by a 1x1 ``CplxConv2d``
    that mixes channels to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=(2, 1), padding=(1, 1), bias=False):
        super().__init__()
        self.depthwise = CplxConv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride,
                                    padding=padding, groups=in_channels, bias=bias)
        self.pointwise = CplxConv2d(in_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))


class DSC_Decoder(nn.Module):
    """Complex depthwise-separable transposed convolution (decoder side).

    A grouped ``CplxConvTranspose2d`` with ``groups=in_channels`` (carries the
    stride, padding and output_padding) followed by a 1x1
    ``CplxConvTranspose2d`` that maps to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=(2, 1), padding=(1, 1), output_padding=(0, 0),
                 bias=False):
        super().__init__()
        self.depthwise = CplxConvTranspose2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride,
                                             padding=padding, groups=in_channels, output_padding=output_padding,
                                             bias=bias)
        self.pointwise = CplxConvTranspose2d(in_channels, out_channels, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

# ===================== ComplexGRU =====================
class ComplexGRU(nn.Module):
    """Complex-valued bidirectional GRU followed by a complex linear projection.

    Two real GRUs (``rGRU``, ``iGRU``) are combined like a complex product::

        real = rGRU(x.real) - iGRU(x.imag)
        imag = rGRU(x.imag) + iGRU(x.real)

    and the result is projected by ``CplxLinear`` to ``output_size`` features.

    Args:
        input_size: features per time step fed to each GRU.
        output_size: features produced by the final ``CplxLinear``.
        num_layers: number of stacked GRU layers.
    """

    def __init__(self, input_size, output_size, num_layers):
        super().__init__()
        hidden = input_size // 2
        # A bidirectional GRU emits 2 * hidden features per step, which equals
        # input_size only when input_size is even.
        gru_out = 2 * hidden
        # Both GRUs must be assigned as nn.Sequential modules (no trailing comma),
        # otherwise they are not registered as sub-modules and are not callable.
        self.rGRU = nn.Sequential(
            nn.GRU(input_size=input_size, hidden_size=hidden, num_layers=num_layers, batch_first=True,
                   bidirectional=True),
            SelectItem(0),  # keep the output sequence, drop the final hidden state
        )
        self.iGRU = nn.Sequential(
            nn.GRU(input_size=input_size, hidden_size=hidden, num_layers=num_layers, batch_first=True,
                   bidirectional=True),
            SelectItem(0),  # keep the output sequence, drop the final hidden state
        )
        self.linear = CplxLinear(gru_out, output_size)

    def forward(self, x):
        # batch_first GRUs expect [batch, seq, feature], so swap the last two axes
        # to make the last one the feature axis. Presumably x is a complex
        # [B, C, T] tensor with C == input_size -- TODO confirm.
        x = x.transpose(-1, -2).contiguous()
        real = self.rGRU(x.real) - self.iGRU(x.imag)
        imag = self.rGRU(x.imag) + self.iGRU(x.real)
        # Project, then swap the axes back to the input layout.
        return self.linear(ComplexTensor(real, imag)).transpose(-1, -2)

# Step: 4.2 --------------------  Complex Transformer layers --------------------------


class Transformer_single(nn.Module):
    """Apply one ``TransformerEncoderLayer`` along the channel axis of a 4-D tensor.

    The input ``x`` of shape ``[b, c, F, T]`` is rearranged to
    ``[c, b*T, F]``: sequence-first, with sequence = channels, batch = every
    (batch, time) pair, and d_model = F. It is passed through the layer and
    rearranged back to ``[b, c, F, T]``. The output is computed on ``x``'s
    device.

    NOTE(review): the encoder layer is built inside ``forward`` on every call.
    Its weights are therefore freshly initialised each time, are not
    registered as sub-modules, and are never trained or saved. Fixing this
    needs ``d_model`` at construction time, which is a constructor/caller change.
    """

    def __init__(self, nhead=8):
        super().__init__()
        self.nhead = nhead

    @staticmethod
    def _to_sequence(x):
        """[b, c, F, T] -> [c, b*T, F] (sequence-first layout for the encoder)."""
        b, c, freq, frames = x.shape
        return x.permute(1, 0, 3, 2).reshape(c, b * frames, freq)

    @staticmethod
    def _from_sequence(x, b, c, freq, frames):
        """Exact inverse of ``_to_sequence``: [c, b*T, F] -> [b, c, F, T]."""
        return x.reshape(c, b, frames, freq).permute(1, 0, 3, 2).contiguous()

    def forward(self, x):
        b, c, freq, frames = x.shape
        # Follow the input's device instead of forcing "cuda".
        STB = TransformerEncoderLayer(d_model=freq, nhead=self.nhead).to(x.device)
        out = STB(self._to_sequence(x))
        # Undo the permute as well as the reshape; a bare view(b, c, F, T) would
        # interleave samples of different batch items whenever b > 1.
        return self._from_sequence(out, b, c, freq, frames)


class Transformer_multi(nn.Module):
    """Apply the single ``Transformer_single`` module ``self.MTB`` ``layer_num`` times in a row.

    Args:
        nhead: attention heads passed to ``Transformer_single``.
        layer_num: number of successive applications of ``self.MTB``.
    """

    def __init__(self, nhead, layer_num=2):
        super().__init__()
        self.layer_num = layer_num
        self.MTB = Transformer_single(nhead=nhead)

    def forward(self, x):
        out = x
        for _ in range(self.layer_num):
            out = self.MTB(out)
        return out


class ComplexTransformer(nn.Module):
    """Complex-valued wrapper around two real ``Transformer_multi`` stacks.

    The stacks are combined like a complex product::

        real = rTrans(x.real) - iTrans(x.imag)
        imag = rTrans(x.imag) + iTrans(x.real)

    Args:
        nhead: attention heads for both stacks.
        num_layer: ``layer_num`` for both stacks.
    """

    def __init__(self, nhead, num_layer):
        super().__init__()
        self.rTrans = Transformer_multi(nhead=nhead, layer_num=num_layer)
        self.iTrans = Transformer_multi(nhead=nhead, layer_num=num_layer)

    def forward(self, x):
        x_re, x_im = x.real, x.imag
        real = self.rTrans(x_re) - self.iTrans(x_im)
        imag = self.rTrans(x_im) + self.iTrans(x_re)
        return ComplexTensor(real, imag)


class TransformerEncoderLayer(nn.Module):
    r"""Transformer encoder layer whose feed-forward sub-layer is a GRU.

    Post-norm structure as in "Attention Is All You Need" (Vaswani et al., 2017)::

        src = norm1(src + dropout1(self_attn(src, src, src)))
        src = norm2(src + dropout2(linear2(dropout(activation(gru(src))))))

    The usual Linear -> activation -> Linear feed-forward network is replaced
    by a GRU with hidden size ``2 * d_model`` (bidirectional by default). A
    Linear layer then projects the GRU output back to ``d_model``.

    Args:
        d_model: number of expected features in the input (required).
        nhead: number of heads in the multi-head attention (required).
        bidirectional: whether the GRU is bidirectional (default=True).
        dropout: dropout probability (default=0).
        activation: activation applied to the GRU output, "relu" or "gelu"
            (default="relu").

    Inputs and outputs are sequence-first, ``[seq_len, batch, d_model]``, which
    is the default ``batch_first=False`` layout of ``nn.MultiheadAttention`` and
    ``nn.GRU``. Device placement is left to the caller (``module.to(device)``).
    """

    def __init__(self, d_model, nhead, bidirectional=True, dropout=0, activation="relu"):
        super(TransformerEncoderLayer, self).__init__()
        # No device is forced here: a hard-coded .to("cuda") would break
        # CPU-only construction and would move only this one sub-module.
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # GRU replaces the position-wise feed-forward network.
        self.gru = GRU(d_model, d_model * 2, 1, bidirectional=bidirectional)
        self.dropout = Dropout(dropout)
        # Project the GRU output (2*d_model per direction) back to d_model.
        gru_out_features = d_model * 2 * (2 if bidirectional else 1)
        self.linear2 = Linear(gru_out_features, d_model)

        self.norm1 = LayerNorm(d_model)  # after the attention residual
        self.norm2 = LayerNorm(d_model)  # after the GRU residual
        self.dropout1 = Dropout(dropout)  # on the attention output
        self.dropout2 = Dropout(dropout)  # on the GRU branch output

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Older pickles may predate the `activation` attribute.
        if 'activation' not in state:
            state['activation'] = F.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
        r"""Pass the input through the encoder layer.

        Args:
            src: input sequence, ``[seq_len, batch, d_model]`` (required).
            src_mask: attention mask for the src sequence (optional).
            src_key_padding_mask: per-batch padding mask for the src keys (optional).

        Returns:
            Tensor with the same shape as ``src``.
        """
        # Self-attention sub-layer with residual connection and post-norm.
        src2 = self.self_attn(src, src, src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        # GRU sub-layer. flatten_parameters() compacts the GRU weights into one
        # contiguous buffer (cuDNN warns otherwise, e.g. after the module moves).
        self.gru.flatten_parameters()
        out, h_n = self.gru(src)
        del h_n  # the final hidden state is not needed
        src2 = self.linear2(self.dropout(self.activation(out)))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src


def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu

    raise RuntimeError("activation should be relu/gelu, not {}".format(activation))


# ===================== FTB Blocks =====================
class NodeReshape(nn.Module):
    """Reshape every sample while keeping the leading batch axis.

    ``forward`` returns ``feature_in.reshape([batch, *shape])``, where ``batch``
    is ``feature_in.size(0)``; ``shape`` may contain a single ``-1``.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, feature_in: torch.Tensor):
        target = [feature_in.size(0), *self.shape]
        return feature_in.reshape(target)


class Freq_FC(nn.Module):
    """Complex fully-connected layer applied along axis -2 of the input.

    In ``ComplexFTB`` that axis is the one named ``F`` (axis 2 of
    ``[B, C, F, T]``). The last two axes are swapped so that
    ``CplxLinear(F_dim, F_dim)`` acts on it. The result is converted to a
    native complex tensor with ``torch.complex``, and the axes are swapped back.
    """

    def __init__(self, F_dim, bias=False):
        super().__init__()
        self.linear = CplxLinear(F_dim, F_dim, bias=bias)

    def forward(self, x):
        swapped = x.transpose(-1, -2).contiguous()  # [B, C, F, T] -> [B, C, T, F]
        projected = self.linear(swapped)
        as_complex = torch.complex(projected.real, projected.imag)
        return as_complex.transpose(-1, -2).contiguous()  # back to [B, C, F, T]


class ComplexFTB(torch.nn.Module):
    """Complex frequency-transformation block (FTB).

    ``forward`` pipeline (input ``[B, channels, F, T]``):
      1. 1x1 ``CplxConv2d`` channels -> C_r (= 5), complex batch-norm, complex ReLU.
      2. Flatten ``[B, C_r, F, T]`` -> ``[B, C_r*F, T]``.
      3. ``CplxConv1d`` (kernel 9, padding 4) C_r*F -> F, complex batch-norm,
         complex ReLU.
      4. Unsqueeze to ``[B, 1, F, T]`` and multiply element-wise with the input.
      5. ``Freq_FC`` along the F axis.
      6. Concatenate with the input on the channel axis, then a 1x1
         ``CplxConv2d`` 2*channels -> channels, complex batch-norm, complex ReLU.

    Args:
        F_dim: size of axis 2 of the input.
        channels: number of input (and output) channels.

    Note: ``forward`` overwrites ``self.F_dim`` and sets ``self.T_dim`` from the
    runtime input shape.
    """

    def __init__(self, F_dim, channels):
        super().__init__()
        self.channels = channels
        self.C_r = 5  # reduced channel count of the attention branch
        self.F_dim = F_dim

        reduced = self.C_r
        self.Conv2D_1 = nn.Sequential(
            CplxConv2d(in_channels=channels, out_channels=reduced, kernel_size=1, stride=1, padding=0),
            CplxBatchNorm2d(reduced),
        )
        self.Conv1D_1 = nn.Sequential(
            CplxConv1d(F_dim * reduced, F_dim, kernel_size=9, padding=4),
            CplxBatchNorm1d(F_dim),
        )
        self.FC = Freq_FC(F_dim, bias=False)
        self.Conv2D_2 = nn.Sequential(
            CplxConv2d(2 * channels, channels, kernel_size=1, stride=1, padding=0),
            CplxBatchNorm2d(channels),
        )

        # Reshape helpers; not used by forward, which calls view() directly.
        self.att_inner_reshape = NodeReshape([F_dim * reduced, -1])
        self.att_out_reshape = NodeReshape([1, F_dim, -1])

    def cat(self, x, y, dim):
        """Concatenate two complex tensors along ``dim`` into a ComplexTensor."""
        return ComplexTensor(
            torch.cat([x.real, y.real], dim),
            torch.cat([x.imag, y.imag], dim),
        )

    def forward(self, inputs, verbose=False):
        def trace(label, tensor):
            if verbose:
                print(label, tensor.shape)

        _, _, self.F_dim, self.T_dim = inputs.shape

        attn = complex_relu(self.Conv2D_1(inputs))
        trace('Layer-1               : ', attn)
        shp = attn.shape
        attn = attn.view(shp[0], shp[1] * shp[2], shp[3])  # [B, C_r*F, T]
        trace('Layer-2               : ', attn)
        attn = complex_relu(self.Conv1D_1(attn))  # [B, F, T]
        trace('Layer-3               : ', attn)
        attn = attn.unsqueeze(1)  # [B, 1, F, T]
        trace('Layer-4               : ', attn)
        gated = attn * inputs  # broadcast over the channel axis
        trace('Layer-5               : ', gated)
        gated = self.FC(gated)
        merged = self.cat(gated, inputs, 1)  # [B, 2*channels, F, T]
        trace('Layer-7               : ', merged)
        outputs = complex_relu(self.Conv2D_2(merged))
        trace('Layer-8               : ', outputs)
        return outputs

# -------------------------------- Depthwise Separable Convolution --------------------------------
class depthwise_separable_convx(nn.Module):
    """Real-valued depthwise-separable 2-D convolution.

    A per-channel convolution (``groups=nin``) followed by a 1x1 convolution
    mapping ``nin`` to ``nout`` channels.
    """

    def __init__(self, nin, nout, kernel_size=3, padding=1, bias=False):
        super().__init__()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size=kernel_size, padding=padding, groups=nin, bias=bias)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1, bias=bias)

    def forward(self, x):
        return self.pointwise(self.depthwise(x))


# ===================== Skip Blocks / Connections =====================
class SkipBlock(nn.Module):
    """Residual complex convolution block: ``complex_relu(norm(conv(x))) + x``.

    Args:
        in_channels: channels of ``x``.
        out_channels: channels produced by the convolution. The residual
            addition only works when this equals ``in_channels`` and the conv
            preserves the spatial size (kernel_size=3, stride=1, padding=1 does).
        kernel_size, stride, padding: convolution hyper-parameters.
        DSC: use a depthwise-separable ``DSC_Encoder`` instead of ``CplxConv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, DSC=False):
        super().__init__()
        conv_cls = DSC_Encoder if DSC else CplxConv2d
        self.conv = conv_cls(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                             padding=padding, bias=True)
        # The norm sees the conv *output*, so it is sized by out_channels
        # (identical to in_channels for every SkipConnection use in this file).
        self.norm = CplxBatchNorm2d(out_channels)

    def forward(self, x):
        return complex_relu(self.norm(self.conv(x))) + x


class SkipConnection(nn.Module):
    """Chain of ``num_convblocks`` channel-preserving ``SkipBlock``s applied in order.

    Args:
        in_channels: channels of the input (and of every block's output).
        num_convblocks: number of ``SkipBlock``s (3x3, stride 1, padding 1).
        DSC: forwarded to every ``SkipBlock``.
    """

    def __init__(self, in_channels, num_convblocks, DSC=False):
        super().__init__()
        blocks = []
        for _ in range(num_convblocks):
            blocks.append(SkipBlock(in_channels, in_channels, kernel_size=3, stride=1, padding=1, DSC=DSC))
        self.skip_blocks = nn.ModuleList(blocks)

    def forward(self, x):
        result = x
        for skip_block in self.skip_blocks:
            result = skip_block(result)
        return result

# ===================== CFTNet Model Definition =====================
class DCCTN(torch.nn.Module):
    """
    DCCTN: FTBComplexSkipConvNet + Transformer.

    Complex-valued encoder/decoder network with complex skip convolutions and a
    complex transformer bottleneck (used instead of an LSTM).

    Pipeline in ``forward``:
      1. STFT.
      2. 8 ComplexEncoders, with FTB1 after enc1 and FTB7 after enc7; only
         these two FTB layers are used.
      3. ComplexTransformer on enc8.
      4. 8 ComplexDecoders, each fed the previous output concatenated (on the
         channel axis) with a SkipConnection of the mirrored encoder output.
      5. Deep filtering of the input spectrum with the dec8 output.
      6. iSTFT back to audio.

    Args:
        L: STFT frame length.
        N: FFT size.
        H: STFT hop size.
        Mask: [freq_taps, time_taps] of the deep filter. dec8 outputs
            Mask[0] * Mask[1] channels, one per (freq, time) offset.
        B: base channel width of the encoder/decoder.
        F_dim: number of STFT frequency bins (presumably N // 2 + 1 = 129 for
            the default N=256 -- TODO confirm against STFT).

    NOTE(review): FTB2-FTB6 and ``self.GRU`` are constructed but never called in
    ``forward``. Their parameters receive no gradients (relevant e.g. for DDP
    unused-parameter detection). Removing them would change checkpoint keys,
    so confirm before deleting.
    """

    def __init__(self, L=256, N=256, H=128, Mask=[5, 7], B=24, F_dim=129):
        super().__init__()
        self.name = 'DCCTN'  # 'FTBTxComplexSkipConvNet2'
        # Integer tap offsets for the deep filter. Unary minus binds tighter than
        # //, so odd sizes give symmetric offsets: 5 -> [-2..2], 7 -> [-3..3].
        # Mask is only read, so the shared list default is not mutated.
        self.f_taps = list(range(-Mask[0] // 2 + 1, Mask[0] // 2 + 1))
        self.t_taps = list(range(-Mask[1] // 2 + 1, Mask[1] // 2 + 1))

        # STFT / iSTFT come from outside this file (presumably `from utils import *`).
        self.stft = STFT(frame_len=L, frame_hop=H, num_fft=N)
        self.istft = iSTFT(frame_len=L, frame_hop=H, num_fft=N)

        # Encoders use stride (2, 1). Assuming standard conv size arithmetic,
        # axis 2 shrinks to ceil(n / 2) per stage, which matches the
        # math.ceil(F_dim / 2**k) sizes given to the FTBs.
        self.enc1 = ComplexEncoder(1, 1 * B, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=True)
        self.FTB1 = ComplexFTB(math.ceil(F_dim / 2), channels=1 * B)  # used after enc1
        self.enc2 = ComplexEncoder(1 * B, 2 * B, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=True)
        self.FTB2 = ComplexFTB(math.ceil(F_dim / 4), channels=2 * B)  # NOTE(review): unused in forward
        self.enc3 = ComplexEncoder(2 * B, 2 * B, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=True)
        self.FTB3 = ComplexFTB(math.ceil(F_dim / 8), channels=2 * B)  # NOTE(review): unused in forward
        self.enc4 = ComplexEncoder(2 * B, 3 * B, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=True)
        self.FTB4 = ComplexFTB(math.ceil(F_dim / 16), channels=3 * B)  # NOTE(review): unused in forward
        self.enc5 = ComplexEncoder(3 * B, 3 * B, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=True)
        self.FTB5 = ComplexFTB(math.ceil(F_dim / 32), channels=3 * B)  # NOTE(review): unused in forward
        self.enc6 = ComplexEncoder(3 * B, 4 * B, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=True)
        self.FTB6 = ComplexFTB(math.ceil(F_dim / 64), channels=4 * B)  # NOTE(review): unused in forward
        self.enc7 = ComplexEncoder(4 * B, 4 * B, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=True)
        self.FTB7 = ComplexFTB(math.ceil(F_dim / 128), channels=4 * B)  # used after enc7
        self.enc8 = ComplexEncoder(4 * B, 8 * B, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=True)
        self.TB = ComplexTransformer(nhead=1, num_layer=2)  # d_model is read from axis 2 of enc8 at run time (see Transformer_single)
        self.GRU = ComplexGRU(8 * B, 8 * B, num_layers=2)  # NOTE(review): unused in forward

        # skipK processes the encoder output that is concatenated into decoder K
        # (skip1 <- enc8, skip2 <- enc7, ..., skip8 <- enc1).
        self.skip1 = SkipConnection(8 * B, num_convblocks=4)
        self.skip2 = SkipConnection(4 * B, num_convblocks=4)
        self.skip3 = SkipConnection(4 * B, num_convblocks=3)
        self.skip4 = SkipConnection(3 * B, num_convblocks=3)
        self.skip5 = SkipConnection(3 * B, num_convblocks=2)
        self.skip6 = SkipConnection(2 * B, num_convblocks=2)
        self.skip7 = SkipConnection(2 * B, num_convblocks=1)
        self.skip8 = SkipConnection(1 * B, num_convblocks=1)

        # Decoder input channels = previous output + mirrored skip channels
        # (dec1: 8B + 8B, dec2: 8B + 4B, ..., dec8: 2B + 1B).
        # Assuming standard transposed-conv size arithmetic, kernel 3 / stride 2 /
        # padding 1 maps n -> 2n - 1 (+ output_padding). dec1's output_padding
        # turns 1 -> 2, so for F_dim=129 the sizes mirror the encoders back to 129.
        self.dec1 = ComplexDecoder(16 * B, 8 * B, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=True,
                                   output_padding=(1, 0))
        self.dec2 = ComplexDecoder(12 * B, 8 * B, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=True)
        self.dec3 = ComplexDecoder(12 * B, 4 * B, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=True)
        self.dec4 = ComplexDecoder(7 * B, 3 * B, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=True)
        self.dec5 = ComplexDecoder(6 * B, 3 * B, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=True)
        self.dec6 = ComplexDecoder(5 * B, 2 * B, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=True)
        self.dec7 = ComplexDecoder(4 * B, 2 * B, kernel_size=(3, 3), stride=(2, 1), padding=(1, 1), bias=True)
        # dec8 emits one complex filter tap per (freq, time) offset.
        self.dec8 = ComplexDecoder(3 * B, Mask[0] * Mask[1], kernel_size=(3, 3), stride=(2, 1), padding=(1, 1),
                                   bias=True)

    def cat(self, x, y, dim):
        """Concatenate two complex tensors along ``dim`` into a ComplexTensor."""
        real = torch.cat([x.real, y.real], dim)
        imag = torch.cat([x.imag, y.imag], dim)
        return ComplexTensor(real, imag)

    def deepfiltering(self, deepfilter, cplxInput):
        """Apply the complex deep filter to the input spectrum.

        For each output bin, sums conj(filter tap) * input circularly shifted
        (torch.roll on dims 1 and 2) by the tap's (f_tap, t_tap) offset.

        Args:
            deepfilter: complex decoder output; presumably [B, D, F, T] with
                D = len(f_taps) * len(t_taps) -- TODO confirm.
            cplxInput: complex STFT; presumably [B, F, T] -- TODO confirm.

        Returns:
            Filtered complex spectrum with einsum output indices 'bft'.
        """
        deepfilter = deepfilter.permute(0, 2, 3, 1)  # move the tap axis last
        # Stack all shifted copies on a new axis 3, then swap the last two axes so
        # the tap axis sits before time (matching 'bfdt' below).
        real_tf_shift = torch.stack(
            [torch.roll(cplxInput.real, (i, j), dims=(1, 2)) for i in self.f_taps for j in self.t_taps], 3).transpose(
            -1, -2)
        imag_tf_shift = torch.stack(
            [torch.roll(cplxInput.imag, (i, j), dims=(1, 2)) for i in self.f_taps for j in self.t_taps], 3).transpose(
            -1, -2)
        # NOTE(review): tiny offset on the imaginary part, presumably to keep later
        # abs()/angle() numerically safe -- confirm.
        imag_tf_shift += 1e-10
        cplxInput_shift = ComplexTensor(real_tf_shift, imag_tf_shift)
        # `einsum` is not imported in this file's header (presumably from utils).
        est_complex = einsum('bftd,bfdt->bft', [deepfilter.conj(), cplxInput_shift])
        return est_complex

    def forward(self, audio, verbose=False):
        """
        batch: tensor of shape (batch_size x channels x num_samples)

        Returns the enhanced waveform produced by iSTFT (called with squeeze=True).
        """
        if verbose: print('*' * 60)
        if verbose: print('Input Audio Shape         : ', audio.shape)
        if verbose: print('*' * 60)

        # Only the real and imaginary parts of the STFT output are used.
        _, _, real, imag = self.stft(audio)
        cplxIn = ComplexTensor(real, imag)
        if verbose: print('STFT Complex Spec         : ', cplxIn.shape)

        if verbose: print('\n' + '-' * 20)
        if verbose: print('Encoder Network')
        if verbose: print('-' * 20)

        # unsqueeze(1) adds a single input channel for enc1.
        enc1 = self.enc1(cplxIn.unsqueeze(1))
        if verbose: print('Encoder-1                 : ', enc1.shape)
        FTB1 = self.FTB1(enc1)
        if verbose: print('FTB-1               : ', FTB1.shape)
        enc2 = self.enc2(FTB1)
        if verbose: print('Encoder-2                 : ', enc2.shape)
        enc3 = self.enc3(enc2)
        if verbose: print('Encoder-3                 : ', enc3.shape)
        enc4 = self.enc4(enc3)
        if verbose: print('Encoder-4                 : ', enc4.shape)
        enc5 = self.enc5(enc4)
        if verbose: print('Encoder-5                 : ', enc5.shape)
        enc6 = self.enc6(enc5)
        if verbose: print('Encoder-6                 : ', enc6.shape)
        enc7 = self.enc7(enc6)
        if verbose: print('Encoder-7                 : ', enc7.shape)
        FTB7 = self.FTB7(enc7)
        if verbose: print('FTB-7               : ', FTB7.shape)
        enc8 = self.enc8(FTB7)
        if verbose: print('Encoder-8                 : ', enc8.shape)

        # +++++++++++++++++++ Expanding Path  +++++++++++++++++++++ #

        MLTB = self.TB(enc8)
        if verbose: print('Transformer-1               : ', MLTB.shape)
        if verbose: print('\n' + '-' * 20)
        if verbose: print('Decoder Network')
        if verbose: print('-' * 20)
        # Each decoder sees [previous output | skip(mirrored encoder)] on dim 1.
        dec = self.dec1(self.cat(MLTB, self.skip1(enc8), 1))
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

        if verbose: print('Decoder-1                 : ', dec.shape)
        dec = self.dec2(self.cat(dec, self.skip2(enc7), 1))
        if verbose: print('Decoder-2                 : ', dec.shape)
        dec = self.dec3(self.cat(dec, self.skip3(enc6), 1))
        if verbose: print('Decoder-3                 : ', dec.shape)
        dec = self.dec4(self.cat(dec, self.skip4(enc5), 1))
        if verbose: print('Decoder-4                 : ', dec.shape)
        dec = self.dec5(self.cat(dec, self.skip5(enc4), 1))
        if verbose: print('Decoder-5                 : ', dec.shape)
        dec = self.dec6(self.cat(dec, self.skip6(enc3), 1))
        if verbose: print('Decoder-6                 : ', dec.shape)
        dec = self.dec7(self.cat(dec, self.skip7(enc2), 1))
        if verbose: print('Decoder-7                 : ', dec.shape)
        dec = self.dec8(self.cat(dec, self.skip8(enc1), 1))
        if verbose: print('Decoder-8                 : ', dec.shape)

        # dec8's output channels are the deep-filter taps.
        deepfilter = ComplexTensor(dec.real, dec.imag)
        enhanced = self.deepfiltering(deepfilter, cplxIn)
        # iSTFT is driven from magnitude and phase of the filtered spectrum.
        enh_mag, enh_phase = enhanced.abs(), enhanced.angle()
        audio_enh = self.istft(enh_mag, enh_phase, squeeze=True)
        if verbose: print('*' * 60)
        if verbose: print('Output Audio Shape        : ', audio_enh.shape)
        if verbose: print('*' * 60)

        return audio_enh
    
class SelectItem(nn.Module):
    """
    Module whose forward returns ``x[idx]`` for its input ``x``.

    Handy after layers such as GRU/LSTM whose forward returns a tuple, e.g. to
    keep only element 0 of that tuple.

    Args:
        idx: Index (or key) applied to the input in ``forward``; defaults to 0.
    """

    def __init__(self, idx=0):
        super().__init__()
        self.idx = idx

    def forward(self, x):
        chosen = x[self.idx]
        return chosen

# -------------- Lightning Module -------------- #
class DeepLearningModel(pl.LightningModule):
    """
    Lightning wrapper that trains ``net`` with SI-SDR plus a weighted multi-resolution STFT loss.

    Batches are dicts with ``'acc'`` (network input) and ``'clean'`` (reference
    waveform). ``'filename'`` is also read when ``save_sample_outputs`` is True.

    Args:
        net: Enhancement network; ``net(batch['acc'])`` must return a waveform.
        batch_size: Stored on the instance; batching itself is done by the DataLoader.
        save_sample_outputs: If True, write the enhanced audio of the first
            ``num_sample_batches`` batches of every train/val epoch to ``save_dir``.
        save_dir: Directory for saved samples and ``loss_logs.txt`` (created if missing).
        sample_rate: Rate passed to the STFT loss and used when writing WAV samples.
        stft_weight: Multiplier applied to the STFT loss before it is added to SI-SDR.
        num_sample_batches: Number of leading batches per epoch whose outputs are saved.
    """

    def __init__(self, net, batch_size=1, save_sample_outputs=False, save_dir='enhanced_samples',
                 sample_rate=48000, stft_weight=5.0, num_sample_batches=5):
        super().__init__()
        self.model = net
        self.batch_size = batch_size
        self.save_sample_outputs = save_sample_outputs
        self.save_dir = save_dir
        self.sample_rate = sample_rate
        self.stft_weight = stft_weight
        self.num_sample_batches = num_sample_batches

        # Loss functions
        self.si_sdr_loss = SISDRLoss()
        self.stft_loss_func = MultiResolutionSTFTLoss(
            fft_sizes=[256, 512, 1024],
            hop_sizes=[128, 256, 512],
            win_lengths=[256, 512, 1024],
            scale=None,
            n_bins=128,
            sample_rate=sample_rate,
            perceptual_weighting=False,
        )
        os.makedirs(self.save_dir, exist_ok=True)

        # Per-epoch mean losses as Python floats, one entry per completed epoch.
        self.train_sisdr_log = []
        self.train_stft_log = []
        self.train_total_log = []
        self.val_sisdr_log = []
        self.val_stft_log = []
        self.val_total_log = []
        # Detached per-batch losses for the epoch in progress; cleared at epoch end.
        self.train_sisdr_epoch = []
        self.train_stft_epoch = []
        self.train_total_epoch = []
        self.val_sisdr_epoch = []
        self.val_stft_epoch = []
        self.val_total_epoch = []

        self.log_file = os.path.join(self.save_dir, "loss_logs.txt")

    def forward(self, x):
        return self.model(x)

    # ------------------------------------------------------------------ #
    # Loss computation
    # ------------------------------------------------------------------ #
    @staticmethod
    def _as_b1t(x):
        """Promote a ``[T]`` or ``[B, T]`` tensor to ``[B, 1, T]``; higher ranks pass through."""
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if x.dim() == 2:
            x = x.unsqueeze(1)
        return x

    def compute_losses(self, pred, target):
        """
        Compute the SI-SDR loss, the weighted multi-resolution STFT loss, and their sum.

        Args:
            pred: Enhanced waveform, shape ``[T]``, ``[B, T]`` or ``[B, 1, T]``.
            target: Reference waveform with the same accepted shapes.

        Returns:
            Tuple ``(sisdr, freq_loss, total)``. ``freq_loss`` already includes
            ``stft_weight``, and ``total = sisdr + freq_loss``.
        """
        # auraloss expects [B, 1, T].
        pred = self._as_b1t(pred)
        target = self._as_b1t(target)

        # Audio is bipolar, so both signals are clamped symmetrically. A positive
        # lower bound on the reference would zero out its negative half.
        target = torch.clamp(target, min=-1.0, max=1.0)
        pred = torch.clamp(pred, min=-1.0, max=1.0)

        # auraloss losses take (input, target). The estimate goes first because
        # SI-SDR projects onto the reference and is not symmetric.
        sisdr = self.si_sdr_loss(pred, target)
        freq_loss = self.stft_loss_func(pred, target) * self.stft_weight
        total = sisdr + freq_loss
        return sisdr, freq_loss, total

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #
    def _save_samples(self, batch, enh_audio, batch_idx):
        """Write the enhanced audio of an early batch to ``save_dir`` as ``<name>_epoch_<n>.wav``."""
        if not self.save_sample_outputs or batch_idx >= self.num_sample_batches:
            return
        for filename, enhanced in zip(batch['filename'], enh_audio):
            base_name = os.path.splitext(filename)[0]
            out_filename = f"{base_name}_epoch_{self.current_epoch}.wav"
            out_filepath = os.path.join(self.save_dir, out_filename)
            save_wav(enhanced.detach().cpu().numpy().squeeze(), self.sample_rate, out_filepath)

    @staticmethod
    def _epoch_means(*buffers):
        """Return the mean of each list of 0-d loss tensors, as 0-d tensors."""
        return tuple(torch.stack(buf).mean() for buf in buffers)

    @staticmethod
    def _clear_buffers(*buffers):
        """Empty each per-batch buffer in place."""
        for buf in buffers:
            buf.clear()

    # ------------------------------------------------------------------ #
    # Training
    # ------------------------------------------------------------------ #
    def training_step(self, batch, batch_idx):
        enh_audio = self(batch['acc'])
        sisdr_loss, stft_loss, total_loss = self.compute_losses(enh_audio, batch['clean'])

        # Detach before buffering. Keeping the live tensors would hold a reference
        # to every batch's autograd graph until the end of the epoch.
        self.train_sisdr_epoch.append(sisdr_loss.detach())
        self.train_stft_epoch.append(stft_loss.detach())
        self.train_total_epoch.append(total_loss.detach())

        self.log('train_sisdr_step', sisdr_loss, on_step=True)
        self.log('train_stft_step', stft_loss, on_step=True)
        self.log('train_total_step', total_loss, on_step=True)

        self._save_samples(batch, enh_audio, batch_idx)
        return total_loss

    def on_train_epoch_end(self):
        """Aggregate the epoch's training losses, then log, print and append them to ``loss_logs.txt``."""
        buffers = (self.train_sisdr_epoch, self.train_stft_epoch, self.train_total_epoch)
        if not self.train_total_epoch:
            return  # no training batch completed; torch.stack would fail on an empty list
        sisdr_mean, stft_mean, total_mean = self._epoch_means(*buffers)
        self._clear_buffers(*buffers)

        self.train_sisdr_log.append(sisdr_mean.item())
        self.train_stft_log.append(stft_mean.item())
        self.train_total_log.append(total_mean.item())

        self.log('train_sisdr_epoch', sisdr_mean)
        self.log('train_stft_epoch', stft_mean)
        self.log('train_total_epoch', total_mean)

        current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
        print(f"Epoch {self.current_epoch} - LR: {current_lr:.6f} | "
              f"Train SISDR: {sisdr_mean:.4f}, Train STFT: {stft_mean:.4f}, Train Total: {total_mean:.4f}")

        with open(self.log_file, "a") as f:
            f.write(f"[Train] Epoch {self.current_epoch} | LR: {current_lr:.6f} | "
                    f"SISDR: {sisdr_mean:.4f} | STFT: {stft_mean:.4f} | Total: {total_mean:.4f}\n")

    # ------------------------------------------------------------------ #
    # Validation
    # ------------------------------------------------------------------ #
    def validation_step(self, batch, batch_idx):
        enh_audio = self(batch['acc'])
        sisdr_loss, stft_loss, total_loss = self.compute_losses(enh_audio, batch['clean'])

        self.val_sisdr_epoch.append(sisdr_loss.detach())
        self.val_stft_epoch.append(stft_loss.detach())
        self.val_total_epoch.append(total_loss.detach())

        self._save_samples(batch, enh_audio, batch_idx)
        return total_loss

    def on_validation_epoch_end(self):
        """Aggregate the epoch's validation losses, then log, print and append them to ``loss_logs.txt``."""
        buffers = (self.val_sisdr_epoch, self.val_stft_epoch, self.val_total_epoch)
        # Sanity-check batches run before training. Recording them would add an
        # extra val entry and misalign the val history with the train history.
        if self.trainer.sanity_checking or not self.val_total_epoch:
            self._clear_buffers(*buffers)
            return
        sisdr_mean, stft_mean, total_mean = self._epoch_means(*buffers)
        self._clear_buffers(*buffers)

        self.val_sisdr_log.append(sisdr_mean.item())
        self.val_stft_log.append(stft_mean.item())
        self.val_total_log.append(total_mean.item())

        self.log('val_sisdr_epoch', sisdr_mean)
        self.log('val_stft_epoch', stft_mean)
        self.log('val_total_epoch', total_mean)

        print(f"Epoch {self.current_epoch} - "
              f"Val SISDR: {sisdr_mean:.4f}, Val STFT: {stft_mean:.4f}, Val Total: {total_mean:.4f}")

        with open(self.log_file, "a") as f:
            f.write(f"[Val]   Epoch {self.current_epoch} | "
                    f"SISDR: {sisdr_mean:.4f} | STFT: {stft_mean:.4f} | Total: {total_mean:.4f}\n")

    # ------------------------------------------------------------------ #
    # Optimization
    # ------------------------------------------------------------------ #
    def configure_optimizers(self):
        """Adam with weight decay, plus cosine annealing with warm restarts every 10 epochs."""
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=1e-4,
            weight_decay=1e-5,
            betas=(0.5, 0.999)
        )
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer,
                T_0=10,
                T_mult=1,
                eta_min=0.0,
                last_epoch=-1
            ),
            'interval': 'epoch',
            'frequency': 1
        }
        return [optimizer], [scheduler]

class TQDMEpochProgressBar(Callback):
    """
    Per-epoch tqdm bar that advances once per completed training batch.

    The bar is held on the callback itself rather than attached to the trainer.
    It is created only on global rank zero, so multi-process (DDP) runs show a
    single bar.
    """

    def __init__(self):
        super().__init__()
        self._bar = None  # active tqdm instance for the current epoch, if any

    def _close_bar(self):
        """Close and forget the current bar, if one is open."""
        if self._bar is not None:
            self._bar.close()
            self._bar = None

    def on_train_epoch_start(self, trainer, pl_module):
        # A previous epoch that did not reach on_train_epoch_end may have left a bar open.
        self._close_bar()
        if not trainer.is_global_zero:
            return
        self._bar = tqdm(
            total=trainer.num_training_batches,
            desc=f"Epoch {trainer.current_epoch + 1}/{trainer.max_epochs}",
            unit='batch',
            leave=False
        )

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if self._bar is not None:
            self._bar.update(1)

    def on_train_epoch_end(self, trainer, pl_module):
        self._close_bar()

    def on_exception(self, trainer, pl_module, exception):
        # Close the bar so an interrupted run does not leave the terminal line occupied.
        self._close_bar()

class ResourceCleanupCallback(Callback):
    """
    Callback that frees memory after every training epoch.

    It first runs a full Python garbage-collection pass, then calls
    ``torch.cuda.empty_cache()`` to release unoccupied cached CUDA memory.
    """

    @staticmethod
    def _release_memory():
        """Collect unreachable Python objects, then empty the CUDA allocator cache."""
        gc.collect()
        torch.cuda.empty_cache()

    def on_train_epoch_end(self, trainer, pl_module):
        self._release_memory()

# --------------------- MAIN: training usage --------------------- #
if __name__ == "__main__":
    model_name = 'DCCTN'
    batch_size = 2
    epochs = 100
    gpu_ids = [0]
    loss_function = 'SISDR+MultiResFreqLoss'
    sample_rate = 48000  # rate used when writing enhanced WAV files

    # Point these to the directories created by data_preprocessing.py
    train_4s_dir = r"C:\Users\Anomadarshi\Desktop\VCTK_down_upsample_2_48\Train_4s\4s_clips"
    dev_4s_dir = r"C:\Users\Anomadarshi\Desktop\VCTK_down_upsample_2_48\Dev_4s\4s_clips"

    # Fail fast, before building the model, if either data split is missing.
    for split_dir, split_name in ((train_4s_dir, "Train"), (dev_4s_dir, "Dev")):
        if not os.path.isdir(split_dir):
            raise RuntimeError(f"{split_name} directory {split_dir} does not exist!")

    use_cuda = torch.cuda.is_available()
    device = 'cuda' if use_cuda else 'cpu'

    # 1) Create model
    net = DCCTN()
    model = DeepLearningModel(net, batch_size=batch_size,
                              save_sample_outputs=True, save_dir='enhanced_samples')

    # 2) Create callbacks
    checkpoint_callback = ModelCheckpoint(
        monitor='val_total_epoch',
        dirpath=os.path.join(os.getcwd(), 'Saved_Models', model_name),
        # Lightning renders any metric name that is not logged as 0, so the
        # template references the monitored, actually-logged metric.
        filename=model_name + '-ACCEAR-' + loss_function + '-{epoch:02d}-{val_total_epoch:.2f}',
        save_top_k=1,
        mode='min'
    )
    callbacks = [checkpoint_callback, TQDMEpochProgressBar(), ResourceCleanupCallback()]

    # 3) Create data loaders from the 4s clipped WAVs
    train_dataset = FourSecDataset(train_4s_dir)
    dev_dataset = FourSecDataset(dev_4s_dir)

    # Pinned memory only helps host-to-GPU copies.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=0, pin_memory=use_cuda, drop_last=True)
    dev_loader = DataLoader(dev_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=0, pin_memory=use_cuda)

    # 4) Trainer
    trainer_kwargs = {
        'max_epochs': epochs,
        'accelerator': 'gpu' if use_cuda else 'cpu',
        # The CPU accelerator rejects a list of device indices, so gpu_ids is
        # only used when CUDA is available.
        'devices': gpu_ids if use_cuda else 1,
        'callbacks': callbacks,
        'gradient_clip_val': 10,
        'accumulate_grad_batches': 8,
        'log_every_n_steps': 10,
        'precision': 32,
        'enable_checkpointing': True,
        'enable_progress_bar': True,
        'num_sanity_val_steps': 0
    }
    if use_cuda and len(gpu_ids) > 1:
        trainer_kwargs['strategy'] = 'ddp'

    trainer = pl.Trainer(**trainer_kwargs)
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=dev_loader)

    # 5) Save "last epoch" model
    last_model_path = os.path.join(checkpoint_callback.dirpath, "model_last.ckpt")
    trainer.save_checkpoint(last_model_path)
    print("Training complete!")

    # 6) Plot per-epoch loss components. Each series is plotted against its own
    # length because the train and val histories need not be the same length.
    panels = (
        (311, model.train_sisdr_log, model.val_sisdr_log, 'Train SISDR', 'Val SISDR', "SISDR Loss"),
        (312, model.train_stft_log, model.val_stft_log, 'Train STFT', 'Val STFT', "MultiRes STFT Loss"),
        (313, model.train_total_log, model.val_total_log, 'Train Total', 'Val Total', "Total Loss"),
    )
    plt.figure(figsize=(12, 8))
    for position, train_vals, val_vals, train_label, val_label, title in panels:
        plt.subplot(position)
        plt.plot(range(len(train_vals)), train_vals, 'b-o', label=train_label)
        plt.plot(range(len(val_vals)), val_vals, 'r-o', label=val_label)
        plt.title(title)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    loss_plot_path = os.path.join(checkpoint_callback.dirpath, "loss_components_plot.png")
    plt.savefig(loss_plot_path)
    plt.close()
    print(f"Saved loss plot to {loss_plot_path}")

    # 7) (Optional) Enhance dev set with best/last model and save

    def load_for_inference(ckpt_path):
        """Restore a DeepLearningModel from ckpt_path onto `device`, in eval mode."""
        restored = DeepLearningModel.load_from_checkpoint(
            ckpt_path,
            net=DCCTN(),
            map_location=device
        )
        return restored.eval()

    def enhance_and_save(pl_model, dataset, out_folder):
        """Run pl_model on every dataset item and write one WAV per item into out_folder."""
        os.makedirs(out_folder, exist_ok=True)
        pl_model.to(device)
        pl_model.eval()
        with torch.no_grad():
            for idx in range(len(dataset)):
                sample = dataset[idx]
                acc_data = sample['acc'].unsqueeze(0).to(device)
                enhanced_np = pl_model(acc_data).squeeze().cpu().numpy()
                base_name = os.path.splitext(sample['filename'])[0]
                save_path = os.path.join(out_folder, f"{base_name}.wav")
                save_wav(enhanced_np, sample_rate, save_path)

    best_model_path = checkpoint_callback.best_model_path
    print(f"Best model checkpoint path: {best_model_path}")

    dev_enh_best = os.path.join(checkpoint_callback.dirpath, "Enhanced_Best")
    dev_enh_last = os.path.join(checkpoint_callback.dirpath, "Enhanced_Last")

    # best_model_path is empty when no checkpoint was saved for the monitored metric.
    if best_model_path:
        print(f"Generating enhanced audio with best model into {dev_enh_best}...")
        enhance_and_save(load_for_inference(best_model_path), dev_dataset, dev_enh_best)
    else:
        print("No best checkpoint was recorded; skipping best-model enhancement.")

    print(f"Generating enhanced audio with last model into {dev_enh_last}...")
    enhance_and_save(load_for_inference(last_model_path), dev_dataset, dev_enh_last)

    print("Enhancement complete.")


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 4090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:numexpr.utils:Note: NumExpr detected 32 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type                    | Params | Mode 
-------------------------------------------------------------------
0 | model          | DCCTN                   | 10.0 M | train
1 | si_sdr_loss    | SISDRLoss               | 0      | train
2 | stft_loss_func | MultiResolutionSTFTLoss | 0      |

Training: |          | 0/? [00:00<?, ?it/s]

