In [1]:
import os
import librosa
import wandb
import numpy as np
import multiprocessing as mp

import torch
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose
from torch.utils.data import random_split

# wandb.init(project='UlTraNet')

In [2]:
def load_audio(audio_path, sample_rate=22050, duration=5):
    # Load audio file with librosa, automatically resampling to the given sample rate
    audio, sr = librosa.load(audio_path, sr=sample_rate, duration=duration)
    
    # Calculate target number of samples
    target_length = sample_rate * duration
    
    # Pad audio if it is shorter than the target length
    if len(audio) < target_length:
        padding = target_length - len(audio)
        audio = np.pad(audio, (0, padding), mode='constant')
    # Truncate audio if it is longer than the target length
    elif len(audio) > target_length:
        audio = audio[:target_length]
    
    return audio

def get_spectrogram(audio, n_fft=2048, hop_length=512, max_length=130):
    # Generate a spectrogram
    spectrogram = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length)
    # Convert to magnitude (amplitude)
    spectrogram = np.abs(spectrogram)
    
    # Pad or truncate the spectrogram to ensure all are the same length
    if spectrogram.shape[1] < max_length:
        padding = max_length - spectrogram.shape[1]
        spectrogram = np.pad(spectrogram, ((0, 0), (0, padding)), mode='constant')
    else:
        spectrogram = spectrogram[:, :max_length]
    
    return spectrogram

class AudioDataset(Dataset):
    def __init__(self, root_dir, sample_rate=22050, n_fft=2048, hop_length=512, max_length=130):
        self.root_dir = root_dir
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.max_length = max_length
        self.files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(root_dir) for f in filenames if f.endswith('.mp3') or f.endswith('.wav')]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        audio_path = self.files[idx]
        audio = load_audio(audio_path, self.sample_rate)
        spectrogram = get_spectrogram(audio, self.n_fft, self.hop_length, self.max_length)
        return audio, spectrogram

if __name__ == '__main__':
    mp.set_start_method('spawn', force=True)

    dataset = AudioDataset(root_dir='DATA')
    loader = DataLoader(dataset, batch_size=10, shuffle=True)

In [3]:
print(dataset[0])
print((dataset[0][0]))
print(dataset[0][0].shape)
    
print(dataset[0][1])
print(dataset[0][1].shape)

(array([ 0.01970823, -0.00225531, -0.03690785, ...,  0.02060048,
       -0.00064142, -0.02519251], dtype=float32), array([[5.7437736e-01, 3.6425254e-01, 1.2862755e-01, ..., 7.2768354e-01,
        2.6252899e-01, 3.3587483e-01],
       [5.1895320e-01, 4.5031342e-01, 1.3857873e-01, ..., 5.0968748e-01,
        4.5548519e-01, 5.7975030e-01],
       [3.0562350e-01, 3.7505564e-01, 2.0325445e-01, ..., 3.5334751e-01,
        4.1140336e-01, 9.5875704e-01],
       ...,
       [1.5232026e-03, 7.0222467e-04, 9.3422665e-07, ..., 2.8230016e-07,
        1.7090463e-07, 1.6775441e-07],
       [1.5125329e-03, 6.9691899e-04, 4.7191645e-07, ..., 1.4758260e-07,
        2.3828001e-07, 3.1422653e-07],
       [1.5089464e-03, 6.9499249e-04, 5.7717961e-07, ..., 3.2488643e-07,
        2.2508639e-08, 2.8777606e-07]], dtype=float32))
[ 0.01970823 -0.00225531 -0.03690785 ...  0.02060048 -0.00064142
 -0.02519251]
(110250,)
[[5.7437736e-01 3.6425254e-01 1.2862755e-01 ... 7.2768354e-01
  2.6252899e-01 3.3587483e-01]
 [

In [4]:
def split_dataset(dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    total_size = len(dataset)
    train_size = int(total_size * train_ratio)
    val_size = int(total_size * val_ratio)
    test_size = total_size - train_size - val_size  # Ensure all data is used

    train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
    return train_dataset, val_dataset, test_dataset

data_folder_path = 'DATA'
# dataset = AudioDataset(root_dir=data_folder_path)

# Assuming 'dataset' is an instance of AudioDataset
train_dataset, val_dataset, test_dataset = split_dataset(dataset)

# Create DataLoaders for each dataset split
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False, num_workers=0)

In [5]:
# Check if the dataset is correctly set up
print("Number of samples in dataset:", len(train_dataset))

# Create a DataLoader instance (make sure parameters like batch_size are set correctly)
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)

# Try to fetch a single batch to see if it works
try:
    data = next(iter(train_loader))
    print("Single batch loaded successfully:", data)
except Exception as e:
    print("Failed to load a batch:", e)

Number of samples in dataset: 407
Single batch loaded successfully: [tensor([[-4.9719e-02, -1.0606e-01, -1.3332e-01,  ...,  3.5279e-01,
          1.7849e-01,  4.1908e-02],
        [ 5.0701e-02,  7.9188e-02,  7.8954e-02,  ...,  3.0429e-02,
          4.0938e-02,  2.9423e-02],
        [-4.6979e-02, -7.3127e-02, -6.4071e-02,  ..., -6.5827e-02,
         -1.3733e-01, -2.2772e-01],
        ...,
        [ 1.2095e-05,  5.3243e-06,  1.6990e-05,  ...,  2.0313e-02,
          1.1765e-02,  1.0368e-02],
        [ 1.3617e-02,  5.9560e-02,  8.5912e-02,  ...,  5.8365e-01,
          4.9166e-01,  4.2171e-01],
        [-6.9656e-02, -1.2492e-01, -1.1681e-01,  ...,  4.7795e-01,
          4.3123e-01,  4.2015e-01]]), tensor([[[4.1046e+00, 2.1018e+00, 5.9732e-02,  ..., 1.9285e-02,
          2.8772e-02, 1.5061e-01],
         [4.0843e+00, 2.0338e+00, 2.3267e-02,  ..., 3.0185e-02,
          5.0452e-02, 1.5334e-01],
         [4.0597e+00, 2.0199e+00, 6.7382e-03,  ..., 3.4802e-02,
          5.0014e-02, 1.7139e-01],
 

In [6]:
print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")
print(f"Test set size: {len(test_dataset)}")

Training set size: 407
Validation set size: 87
Test set size: 88


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Load a pre-trained VGGish model for audio feature extraction
vggish = torch.hub.load('harritaylor/torchvggish', 'vggish')

# Define the Perceptual Loss using VGGish as the feature extractor
class PerceptualLoss(nn.Module):
    def __init__(self, feature_extractor):
        super(PerceptualLoss, self).__init__()
        self.feature_extractor = feature_extractor
        self.feature_extractor.eval()  # Set to evaluation mode

    def forward(self, generated_audio, target_audio):
        with torch.no_grad():
            real_features = self.feature_extractor(target_audio)
        generated_features = self.feature_extractor(generated_audio)
        loss = F.l1_loss(generated_features, real_features)
        return loss

perceptual_loss = PerceptualLoss(vggish)

Using cache found in C:\Users\rahat/.cache\torch\hub\harritaylor_torchvggish_master


In [8]:
class MultiScaleSpectrogramLoss(nn.Module):
    def __init__(self, scales=[1024, 2048, 4096]):
        super(MultiScaleSpectrogramLoss, self).__init__()
        self.scales = scales

    def forward(self, generated_audio, target_audio):
        loss = 0
        for scale in self.scales:
            gen_spec = torch.stft(generated_audio, n_fft=scale, return_complex=True)
            target_spec = torch.stft(target_audio, n_fft=scale, return_complex=True)
            loss += F.l1_loss(gen_spec.abs(), target_spec.abs())
        return loss / len(self.scales)

spectrogram_loss = MultiScaleSpectrogramLoss()

In [9]:
# For demonstration, let's assume we have a simple CNN as a discriminator
class SimpleAudioDiscriminator(nn.Module):
    def __init__(self):
        super(SimpleAudioDiscriminator, self).__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1)  # Changed to Conv1d
        self.fc1 = nn.Linear(16 * 16, 1)  # Adjust size according to actual output dimensions

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

    def intermediate_forward(self, x):
        x = F.relu(self.conv1(x))
        return x

discriminator = SimpleAudioDiscriminator()

class FeatureMatchingLoss(nn.Module):
    def __init__(self, discriminator):
        super(FeatureMatchingLoss, self).__init__()
        self.discriminator = discriminator
        self.discriminator.eval()

    def forward(self, generated_audio, target_audio):
        with torch.no_grad():
            real_features = self.discriminator.intermediate_forward(target_audio)
        generated_features = self.discriminator.intermediate_forward(generated_audio)
        loss = F.l1_loss(generated_features, real_features)
        return loss

feature_matching_loss = FeatureMatchingLoss(discriminator)

In [10]:
# Example of a composite loss
class CompositeLoss(nn.Module):
    def __init__(self, perceptual_loss, spectrogram_loss, feature_matching_loss):
        super(CompositeLoss, self).__init__()
        self.perceptual_loss = perceptual_loss
        self.spectrogram_loss = spectrogram_loss
        self.feature_matching_loss = feature_matching_loss

    def forward(self, generated_audio, target_audio):
        loss = (self.perceptual_loss(generated_audio, target_audio) +
                self.spectrogram_loss(generated_audio, target_audio) +
                self.feature_matching_loss(generated_audio, target_audio))
        return loss

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UltimateTransformerWaveNet(nn.Module):
    def __init__(self, audio_channels=1, spectrogram_channels=1025, num_channels=64, kernel_size=2, num_blocks=4, num_layers=10, num_heads=8):
        super(UltimateTransformerWaveNet, self).__init__()
        self.audio_conv = nn.Conv1d(audio_channels, num_channels, kernel_size=1)
        self.spectrogram_conv = nn.Conv1d(spectrogram_channels, num_channels, kernel_size=1)

        # Dilated convolutions with residual and skip connections for both streams
        self.audio_dilated_convs = nn.ModuleList()
        self.spectrogram_dilated_convs = nn.ModuleList()
        self.audio_skip_convs = nn.ModuleList()
        self.spectrogram_skip_convs = nn.ModuleList()
        for i in range(num_layers):
            dilation = 2 ** i
            self.audio_dilated_convs.append(nn.Conv1d(num_channels, num_channels, kernel_size, dilation=dilation, padding=dilation, groups=num_channels//16))
            self.spectrogram_dilated_convs.append(nn.Conv1d(num_channels, num_channels, kernel_size, dilation=dilation, padding=dilation, groups=num_channels//16))
            self.audio_skip_convs.append(nn.Conv1d(num_channels, num_channels, 1))
            self.spectrogram_skip_convs.append(nn.Conv1d(num_channels, num_channels, 1))

        # Multi-head attention for combining features
        self.feature_attention = nn.MultiheadAttention(embed_dim=num_channels, num_heads=num_heads, batch_first=True)

        # Transformer block with residual connection
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=num_channels, nhead=num_heads, dim_feedforward=num_channels * 4, batch_first=True),
            num_layers=3)

        # Output layers
        self.final_conv1 = nn.Conv1d(num_channels, num_channels, 1)
        self.final_conv2 = nn.Conv1d(num_channels, audio_channels, 1)

        # Additional residual connection across the network
        self.residual_conv = nn.Conv1d(num_channels, num_channels, 1)

    def forward(self, audio, spectrogram):
        # print("Original audio shape:", audio.shape)
        # print("Original spectrogram shape:", spectrogram.shape)
        
        audio_input = F.relu(self.audio_conv(audio))
        spectrogram_input = F.relu(self.spectrogram_conv(spectrogram))
        audio = audio_input
        spectrogram = spectrogram_input
        
        # print("Audio shape:", audio.shape)
        # print("Spectrogram shape:", spectrogram.shape)

        audio_skip = 0
        spectrogram_skip = 0

        # Process through dilated convolutions with residual and skip connections
        i = 0
        for audio_conv, audio_skip_conv, spectro_conv, spectro_skip_conv in zip(self.audio_dilated_convs, self.audio_skip_convs, self.spectrogram_dilated_convs, self.spectrogram_skip_convs):
            # print("Audio conv shape:", audio_conv.shape)
            # print("Audio skip conv shape:", audio_skip_conv.shape)
            # print("Spectrogram conv shape:", spectro_conv.shape)
            # print("Spectrogram skip conv shape:", spectro_skip_conv.shape)
            t1 = F.relu(audio_conv(audio))
            # print("relu shape:", t1.shape)
            # print("audio shape:", audio.shape)
            t1 = t1[:, :, : -(2**i)]
            # print("Modified relu shape:", t1.shape)
            audio = t1 + audio
            
            t2 = F.relu(spectro_conv(spectrogram))
            # print("Spectrogram conv shape:", t2.shape)
            # print("Spectrogram shape:", spectrogram.shape)
            t2 = t2[:, :, :-(2**i)]
            # print("Modified Spectrogram conv shape:", t2.shape)
            
            spectrogram = t2 + spectrogram
            audio_skip += audio_skip_conv(audio)
            spectrogram_skip += spectro_skip_conv(spectrogram)
            i += 1
        
        # print("Audio skip shape:", audio_skip.shape)
        # print("Spectrogram skip shape:", spectrogram_skip.shape)

        # Combine using multi-head attention
        combined, _ = self.feature_attention(audio_skip.transpose(1, 2), spectrogram_skip.transpose(1, 2), spectrogram_skip.transpose(1, 2))
        combined = combined.transpose(1, 2)
        
        # print("Combined shape:", combined.shape)

        # Transformer processing with residual connection
        combined = self.transformer(combined.transpose(1, 2)).transpose(1, 2) + self.residual_conv(combined)

        # print("Combined shape Transformer:", combined.shape)
        # Final processing
        x = F.relu(self.final_conv1(combined))
        x = self.final_conv2(x)

        return x
    
    
    def generate_audio(self, audio, spectrogram, sample_no, device):
        """
        Generate audio using the model and save it to a directory.

        Args:
        audio (torch.Tensor): The input audio tensor.
        spectrogram (torch.Tensor): The input spectrogram tensor.
        sample_no (int): The sample number to append to the filename.
        device (str): The device to perform computation on.
        """
        # Ensure the model is in evaluation mode
        self.eval()
        # Move inputs to the correct device
        audio = audio.to(device)
        spectrogram = spectrogram.to(device)
        # Generate audio using the forward method
        with torch.no_grad():
            generated_audio = self.forward(audio, spectrogram)
        # Ensure the output directory exists
        output_dir = 'gen_music'
        os.makedirs(output_dir, exist_ok=True)
        # Save the generated audio to a file
        output_path = os.path.join(output_dir, f'ultranwav_{sample_no}.wav')
        torch.save(generated_audio, output_path)
        print(f"Generated audio saved to {output_path}")
        # Optionally, return the path or the audio tensor for further use
        return output_path

In [12]:
import torch
from torch.utils.data import DataLoader
import os
import wandb  # Ensure wandb is imported if you're using it

def train(model, train_loader, val_loader, optimizer, criterion, epochs, device):
    model.to(device)
    print(device)
    print("Training Begins!")
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for i, (audio, spectrogram) in enumerate(train_loader):
            audio, spectrogram = audio.to(device), spectrogram.to(device)
            
            if audio.dim() == 2:
                audio = audio.unsqueeze(1)  # Add channel dimension
            elif audio.dim() != 3:
                raise ValueError("Audio input must be 2D or 3D tensor")
            
            window = torch.hann_window(1024, device=device)
            spectrogram = torch.stft(spectrogram.squeeze(1), n_fft=1024, hop_length=256, win_length=1024, window=window, return_complex=True)


            optimizer.zero_grad()
            output = model(audio, spectrogram)
            loss = criterion(output, audio)  # Ensure the criterion is correctly defined for the expected output
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            # Log loss to wandb
            wandb.log({"train_loss": loss.item()})
            if i % 10 == 0:  # Log every 10 steps
                print(f"Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item()}")

            # Save model checkpoint periodically or based on performance
            if i % 100 == 0:  # Save every 100 iterations
                checkpoint_path = f'TW_Checkpoint/model_TW_{epoch}_{i}.pt'
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Checkpoint saved to {checkpoint_path}")

            # # Generate synthetic data and add to train_loader if needed
            # if i % 50 == 0:  # Generate synthetic data every 50 iterations
            #     with torch.no_grad():
            #         synthetic_audio = model.generate_audio(audio, spectrogram, i, device)
            #         # Assuming train_loader.dataset is a list or supports append
            #         train_loader.dataset.append((synthetic_audio.detach(), spectrogram))

        epoch_loss /= len(train_loader)
        print(f"Epoch [{epoch + 1}/{epochs}], Average Loss: {epoch_loss}")
        wandb.log({"epoch_loss": epoch_loss})

        # Validation loop
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for audio, spectrogram in val_loader:
                audio, spectrogram = audio.to(device), spectrogram.to(device)
                output = model(audio, spectrogram)
                val_loss += criterion(output, audio).item()

        val_loss /= len(val_loader)
        wandb.log({"val_loss": val_loss})
        print(f"Validation Loss: {val_loss}")

In [13]:
from torch.optim import Adam

train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=False, num_workers=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UltimateTransformerWaveNet().to(device)
print(model)
optimizer = Adam(model.parameters(), lr=0.001)
# composite_loss = CompositeLoss(perceptual_loss, spectrogram_loss, feature_matching_loss)
composite_loss = CompositeLoss(perceptual_loss, spectrogram_loss, feature_matching_loss)

# train(model, train_loader, val_loader, optimizer, composite_loss, epochs=50, device=device)

UltimateTransformerWaveNet(
  (audio_conv): Conv1d(1, 64, kernel_size=(1,), stride=(1,))
  (spectrogram_conv): Conv1d(1025, 64, kernel_size=(1,), stride=(1,))
  (audio_dilated_convs): ModuleList(
    (0): Conv1d(64, 64, kernel_size=(2,), stride=(1,), padding=(1,), groups=4)
    (1): Conv1d(64, 64, kernel_size=(2,), stride=(1,), padding=(2,), dilation=(2,), groups=4)
    (2): Conv1d(64, 64, kernel_size=(2,), stride=(1,), padding=(4,), dilation=(4,), groups=4)
    (3): Conv1d(64, 64, kernel_size=(2,), stride=(1,), padding=(8,), dilation=(8,), groups=4)
    (4): Conv1d(64, 64, kernel_size=(2,), stride=(1,), padding=(16,), dilation=(16,), groups=4)
    (5): Conv1d(64, 64, kernel_size=(2,), stride=(1,), padding=(32,), dilation=(32,), groups=4)
    (6): Conv1d(64, 64, kernel_size=(2,), stride=(1,), padding=(64,), dilation=(64,), groups=4)
    (7): Conv1d(64, 64, kernel_size=(2,), stride=(1,), padding=(128,), dilation=(128,), groups=4)
    (8): Conv1d(64, 64, kernel_size=(2,), stride=(1,), pa

In [14]:
# train(model, train_loader, val_loader, optimizer, composite_loss, epochs=50, device=device)

In [15]:
import torch
import os

def load_latest_checkpoint(checkpoint_dir):
    checkpoint_files = [file for file in os.listdir(checkpoint_dir) if file.endswith('.pt')]
    latest_file = max(checkpoint_files, key=lambda x: os.path.getctime(os.path.join(checkpoint_dir, x)))
    latest_path = os.path.join(checkpoint_dir, latest_file)
    print(f"Loading checkpoint: {latest_path}")
    return torch.load(latest_path)

checkpoint_dir = 'UlTranNet_Checkpoints'
model = UltimateTransformerWaveNet()  # Assuming the model class is defined and imported
latest_checkpoint = load_latest_checkpoint(checkpoint_dir)
model.load_state_dict(latest_checkpoint)
model.eval()

Loading checkpoint: UlTranNet_Checkpoints\model_TW_0.pt


UltimateTransformerWaveNet(
  (audio_conv): Conv1d(1, 64, kernel_size=(1,), stride=(1,))
  (spectrogram_conv): Conv1d(1025, 64, kernel_size=(1,), stride=(1,))
  (audio_dilated_convs): ModuleList(
    (0): Conv1d(64, 64, kernel_size=(2,), stride=(1,), padding=(1,), groups=4)
    (1): Conv1d(64, 64, kernel_size=(2,), stride=(1,), padding=(2,), dilation=(2,), groups=4)
    (2): Conv1d(64, 64, kernel_size=(2,), stride=(1,), padding=(4,), dilation=(4,), groups=4)
    (3): Conv1d(64, 64, kernel_size=(2,), stride=(1,), padding=(8,), dilation=(8,), groups=4)
    (4): Conv1d(64, 64, kernel_size=(2,), stride=(1,), padding=(16,), dilation=(16,), groups=4)
    (5): Conv1d(64, 64, kernel_size=(2,), stride=(1,), padding=(32,), dilation=(32,), groups=4)
    (6): Conv1d(64, 64, kernel_size=(2,), stride=(1,), padding=(64,), dilation=(64,), groups=4)
    (7): Conv1d(64, 64, kernel_size=(2,), stride=(1,), padding=(128,), dilation=(128,), groups=4)
    (8): Conv1d(64, 64, kernel_size=(2,), stride=(1,), pa

In [16]:
# def evaluate_model(model, test_loader, device):
#     model.to(device)
#     model.eval()
#     with torch.no_grad():
#         for audio, spectrogram in test_loader:
#             audio, spectrogram = audio.to(device), spectrogram.to(device)
#             if audio.dim() == 2:
#                 audio = audio.unsqueeze(1)  # Add channel dimension
#             elif audio.dim() != 3:
#                 raise ValueError("Audio input must be 2D or 3D tensor")
            
#             # window = torch.hann_window(1024, device=device)
#             # spectrogram = torch.stft(spectrogram.squeeze(1), n_fft=1024, hop_length=256, win_length=1024, window=window, return_complex=True)
#             output = model(audio, spectrogram)
#             # Here you can add code to calculate any metrics or losses if needed
#     print("Evaluation completed.")

# # Assuming test_loader is defined and device is set
# evaluate_model(model, test_loader, device)

In [24]:
import torch
import pesq
import numpy as np
from scipy.io import wavfile

def calculate_pesq(reference_audio, generated_audio, sample_rate):
    # PESQ expects audio data in the form of numpy arrays with int16 format.
    ref_audio_int16 = (reference_audio.numpy() * 32767).astype(np.int16)
    gen_audio_int16 = (generated_audio.numpy() * 32767).astype(np.int16)
    score = pesq.pesq(sample_rate, ref_audio_int16, gen_audio_int16, 'wb')  # Wideband mode
    return score

def evaluate_model_simplified(model, test_loader, device):
    model.to(device)
    model.eval()
    total_pesq, count = 0, 0
    with torch.no_grad():
        for audio, spectrogram in test_loader:
            audio, spectrogram = audio.to(device), spectrogram.to(device)
            if audio.dim() == 2:
                audio = audio.unsqueeze(1)  # Add channel dimension
            elif audio.dim() != 3:
                raise ValueError("Audio input must be 2D or 3D tensor")
     
            output = model(audio.to(device), spectrogram)
            pesq_score = calculate_pesq(audio, output, 22050)
            total_pesq += pesq_score
            count += 1
    avg_pesq = total_pesq / count
    print(f"Average PESQ: {avg_pesq}")

# Assuming test_loader and device are defined
evaluate_model_simplified(model, test_loader, device)

  return F.conv1d(input, weight, bias, self.stride,
  return torch._transformer_encoder_layer_fwd(


TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

In [None]:
import torch
import pesq
import numpy as np
from scipy.io import wavfile
import random

def calculate_pesq(reference_audio, generated_audio, sample_rate):
    # PESQ expects audio data in the form of numpy arrays with int16 format.
    ref_audio_int16 = (reference_audio.numpy() * 32767).astype(np.int16)
    gen_audio_int16 = (generated_audio.numpy() * 32767).astype(np.int16)
    score = pesq.pesq(sample_rate, ref_audio_int16, gen_audio_int16, 'wb')  # Wideband mode
    return score

def generate_and_save_wav(audio_tensor, filename, sample_rate=22050):
    # Convert the tensor to numpy array and scale to int16
    audio_numpy = (audio_tensor.numpy() * 32767).astype(np.int16)
    # Write the WAV file
    wavfile.write(filename, sample_rate, audio_numpy)
    print(f"Saved generated audio to {filename}")

def evaluate_and_generate_audio_random_sample(model, test_loader, device, save_path='generated_audio.wav'):
    model.to(device)
    model.eval()
    # Convert DataLoader to a list to randomly select a sample
    test_samples = list(test_loader)
    random_index = random.randint(0, len(test_samples) - 1)
    audio, spectrogram = test_samples[random_index]
    
    audio, spectrogram = audio.to(device), spectrogram.to(device)
    if audio.dim() == 2:
        audio = audio.unsqueeze(1)  # Add channel dimension
    elif audio.dim() != 3:
        raise ValueError("Audio input must be 2D or 3D tensor")

    with torch.no_grad():
        output = model(audio, spectrogram)
        pesq_score = calculate_pesq(audio, output, 22050)
        print(f"PESQ Score: {pesq_score}")

        # Save the generated audio to a WAV file
        generate_and_save_wav(output[0], save_path, 22050)

# Assuming test_loader and device are defined
evaluate_and_generate_audio_random_sample(model, test_loader, device)