In [154]:
import os
import librosa
import wandb
import numpy as np
import multiprocessing as mp

import torch
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose
from torch.utils.data import random_split

wandb.init(project='UltimateTransformerWaveNet')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrah-m[0m ([33mrebot[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [155]:
def load_audio(audio_path, sample_rate=22050, duration=5):
    # Load audio file with librosa, automatically resampling to the given sample rate
    audio, sr = librosa.load(audio_path, sr=sample_rate, duration=duration)
    
    # Calculate target number of samples
    target_length = sample_rate * duration
    
    # Pad audio if it is shorter than the target length
    if len(audio) < target_length:
        padding = target_length - len(audio)
        audio = np.pad(audio, (0, padding), mode='constant')
    # Truncate audio if it is longer than the target length
    elif len(audio) > target_length:
        audio = audio[:target_length]
    
    return audio

def get_spectrogram(audio, n_fft=2048, hop_length=512, max_length=130):
    # Generate a spectrogram
    spectrogram = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length)
    # Convert to magnitude (amplitude)
    spectrogram = np.abs(spectrogram)
    
    # Pad or truncate the spectrogram to ensure all are the same length
    if spectrogram.shape[1] < max_length:
        padding = max_length - spectrogram.shape[1]
        spectrogram = np.pad(spectrogram, ((0, 0), (0, padding)), mode='constant')
    else:
        spectrogram = spectrogram[:, :max_length]
    
    return spectrogram

class AudioDataset(Dataset):
    def __init__(self, root_dir, sample_rate=22050, n_fft=2048, hop_length=512, max_length=130):
        self.root_dir = root_dir
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.max_length = max_length
        self.files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(root_dir) for f in filenames if f.endswith('.mp3') or f.endswith('.wav')]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        audio_path = self.files[idx]
        audio = load_audio(audio_path, self.sample_rate)
        spectrogram = get_spectrogram(audio, self.n_fft, self.hop_length, self.max_length)
        return audio, spectrogram

if __name__ == '__main__':
    mp.set_start_method('spawn', force=True)

    dataset = AudioDataset(root_dir='DATA')
    loader = DataLoader(dataset, batch_size=10, shuffle=True)

In [156]:
print(dataset[0])
print((dataset[0][0]))
print(dataset[0][0].shape)

    
print(dataset[0][1])
print(dataset[0][1].shape)


(array([ 0.01970823, -0.00225531, -0.03690785, ...,  0.02060048,
       -0.00064142, -0.02519251], dtype=float32), array([[5.7437736e-01, 3.6425254e-01, 1.2862755e-01, ..., 7.2768354e-01,
        2.6252899e-01, 3.3587483e-01],
       [5.1895320e-01, 4.5031342e-01, 1.3857873e-01, ..., 5.0968748e-01,
        4.5548519e-01, 5.7975030e-01],
       [3.0562350e-01, 3.7505564e-01, 2.0325445e-01, ..., 3.5334751e-01,
        4.1140336e-01, 9.5875704e-01],
       ...,
       [1.5232026e-03, 7.0222467e-04, 9.3422665e-07, ..., 2.8230016e-07,
        1.7090463e-07, 1.6775441e-07],
       [1.5125329e-03, 6.9691899e-04, 4.7191645e-07, ..., 1.4758260e-07,
        2.3828001e-07, 3.1422653e-07],
       [1.5089464e-03, 6.9499249e-04, 5.7717961e-07, ..., 3.2488643e-07,
        2.2508639e-08, 2.8777606e-07]], dtype=float32))
[ 0.01970823 -0.00225531 -0.03690785 ...  0.02060048 -0.00064142
 -0.02519251]
(110250,)
[[5.7437736e-01 3.6425254e-01 1.2862755e-01 ... 7.2768354e-01
  2.6252899e-01 3.3587483e-01]
 [

In [157]:
def split_dataset(dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    total_size = len(dataset)
    train_size = int(total_size * train_ratio)
    val_size = int(total_size * val_ratio)
    test_size = total_size - train_size - val_size  # Ensure all data is used

    train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
    return train_dataset, val_dataset, test_dataset

data_folder_path = 'DATA'
# dataset = AudioDataset(root_dir=data_folder_path)

# Assuming 'dataset' is an instance of AudioDataset
train_dataset, val_dataset, test_dataset = split_dataset(dataset)

# Create DataLoaders for each dataset split
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False, num_workers=0)

In [158]:
# Check if the dataset is correctly set up
print("Number of samples in dataset:", len(train_dataset))

# Create a DataLoader instance (make sure parameters like batch_size are set correctly)
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)

# Try to fetch a single batch to see if it works
try:
    data = next(iter(train_loader))
    print("Single batch loaded successfully:", data)
except Exception as e:
    print("Failed to load a batch:", e)

Number of samples in dataset: 407
Single batch loaded successfully: [tensor([[ 3.1908e-03, -3.7624e-03,  5.7128e-03,  ..., -6.3606e-02,
         -3.1002e-02, -5.7576e-03],
        [-4.2829e-02,  3.0695e-02,  3.1982e-02,  ...,  3.9355e-01,
         -1.3681e-01,  6.6621e-02],
        [ 5.0504e-06, -1.9910e-05,  7.9151e-06,  ..., -6.6205e-03,
         -3.3100e-02, -6.4018e-02],
        ...,
        [ 5.0701e-02,  7.9188e-02,  7.8954e-02,  ...,  3.0429e-02,
          4.0938e-02,  2.9423e-02],
        [ 6.7924e-03,  1.0553e-02,  1.5429e-02,  ..., -3.0689e-01,
         -3.4870e-01, -2.8935e-01],
        [-6.9137e-02, -1.8144e-01, -2.5575e-01,  ..., -6.5567e-03,
         -7.3422e-02, -1.0121e-01]]), tensor([[[3.9674e+01, 6.9683e+01, 6.8097e+01,  ..., 5.2310e-01,
          4.2656e-02, 4.0288e-01],
         [3.2716e+01, 4.3127e+01, 3.4205e+01,  ..., 1.3027e+00,
          1.3306e+00, 1.0972e+00],
         [1.8309e+01, 9.0223e+00, 1.2512e+00,  ..., 1.5728e+00,
          1.4240e+00, 8.3851e-01],
 

In [159]:
print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")
print(f"Test set size: {len(test_dataset)}")

Training set size: 407
Validation set size: 87
Test set size: 88


In [160]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class TransformerWaveNet(nn.Module):
#     def __init__(self, audio_channels=1, num_channels=64, kernel_size=2, num_blocks=4, num_layers=10, num_heads=8):
#         super(TransformerWaveNet, self).__init__()
#         self.num_blocks = num_blocks
#         self.num_layers = num_layers
#         self.dilated_convs = nn.ModuleList()
#         self.condition_convs = nn.ModuleList()
#         self.residual_convs = nn.ModuleList()
#         self.skip_convs = nn.ModuleList()

#         # Initial convolution layer for raw audio
#         self.audio_conv = nn.Conv1d(audio_channels, num_channels, 1)
#         # self.audio_conv = nn.Conv1d(10, out_channels, kernel_size)

#         # Initial convolution layer for spectrogram
#         self.spectrogram_conv = nn.Conv1d(audio_channels, num_channels, 1)

#         # Transformer block
#         self.transformer = nn.TransformerEncoder(
#             nn.TransformerEncoderLayer(d_model=num_channels, nhead=num_heads, dim_feedforward=num_channels * 4, batch_first=True),
#             num_layers=3)

#         # Dilated convolutions and condition convolutions
#         for _ in range(num_blocks):
#             for i in range(num_layers):
#                 dilation = 2 ** i
#                 self.dilated_convs.append(nn.Conv1d(num_channels, 2 * num_channels, kernel_size, dilation=dilation, padding=dilation))
#                 self.condition_convs.append(nn.Conv1d(num_channels, 2 * num_channels, 1))
#                 self.residual_convs.append(nn.Conv1d(num_channels, num_channels, 1))
#                 self.skip_convs.append(nn.Conv1d(num_channels, num_channels, 1))

#         # Output layers
#         self.final_conv1 = nn.Conv1d(num_channels, num_channels, 1)
#         self.final_conv2 = nn.Conv1d(num_channels, audio_channels, 1)

#     def forward(self, audio, spectrogram):
#         # Process audio and spectrogram
#         audio = self.audio_conv(audio)
#         spectrogram = self.spectrogram_conv(spectrogram)

#         # Combine audio and spectrogram
#         x = audio + spectrogram

#         # Transformer processing
#         x = self.transformer(x)
        
#         skip_connections = []

#         for b in range(self.num_blocks):
#             for l in range(self.num_layers):
#                 # Dilated convolution
#                 dilated = self.dilated_convs[b * self.num_layers + l](x)
#                 # Split for gated activation
#                 filtered, gate = torch.split(dilated, dilated.size(1) // 2, dim=1)
#                 x = torch.tanh(filtered) * torch.sigmoid(gate)
#                 # Residual and skip connections
#                 x = self.residual_convs[b * self.num_layers + l](x)
#                 skip = self.skip_convs[b * self.num_layers + l](x)
#                 skip_connections.append(skip)

#         # Sum all skip connections
#         x = torch.sum(torch.stack(skip_connections), dim=0)

#         # Final convolutions
#         x = F.relu(self.final_conv1(x))
#         x = self.final_conv2(x)

#         return x

#     def generate(self, audio, spectrogram):
#         """
#         Generate audio using the model in an autoregressive manner.
#         Assumes the model is already trained and in eval mode.
#         """
#         self.eval()  # Ensure the model is in evaluation mode
#         with torch.no_grad():  # No need to track gradients
#             # Assuming the inputs are already on the correct device and preprocessed
#             generated_audio = self.forward(audio, spectrogram)

#             # Post-processing if necessary (e.g., applying a sigmoid to ensure output is in the correct range)
#             generated_audio = torch.sigmoid(generated_audio)  # Example post-processing

#         return generated_audio

In [161]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class TransformerWaveNet(nn.Module):
#     def __init__(self, audio_channels=1, spectrogram_channels=1, num_channels=64, kernel_size=2, num_blocks=4, num_layers=10, num_heads=8):
#         super(TransformerWaveNet, self).__init__()
#         self.audio_conv = nn.Conv1d(audio_channels, num_channels, kernel_size=1)
#         self.spectrogram_conv = nn.Conv1d(spectrogram_channels, num_channels, kernel_size=1)

#         # Multi-head attention for combining audio and spectrogram features
#         self.feature_attention = nn.MultiheadAttention(embed_dim=num_channels, num_heads=num_heads, batch_first=True)

#         # Dilated convolutions with residual connections
#         self.audio_dilated_convs = nn.ModuleList()
#         self.spectrogram_dilated_convs = nn.ModuleList()
#         for i in range(num_layers):
#             dilation = 2 ** i
#             self.audio_dilated_convs.append(nn.Conv1d(num_channels, num_channels, kernel_size, dilation=dilation, padding=dilation, groups=num_channels//16))
#             self.spectrogram_dilated_convs.append(nn.Conv1d(num_channels, num_channels, kernel_size, dilation=dilation, padding=dilation, groups=num_channels//16))

#         # Transformer block
#         self.transformer = nn.TransformerEncoder(
#             nn.TransformerEncoderLayer(d_model=num_channels, nhead=num_heads, dim_feedforward=num_channels * 4, batch_first=True),
#             num_layers=3)

#         # Final output layers
#         self.final_conv1 = nn.Conv1d(num_channels, num_channels, 1)
#         self.final_conv2 = nn.Conv1d(num_channels, audio_channels, 1)

#     def forward(self, audio, spectrogram):
#         audio = F.relu(self.audio_conv(audio))
#         spectrogram = F.relu(self.spectrogram_conv(spectrogram))

#         # Process through dilated convolutions with residual connections
#         for conv in self.audio_dilated_convs:
#             audio = F.relu(conv(audio)) + audio
#         for conv in self.spectrogram_dilated_convs:
#             spectrogram = F.relu(conv(spectrogram)) + spectrogram

#         # Combine using multi-head attention
#         combined, _ = self.feature_attention(audio.transpose(1, 2), spectrogram.transpose(1, 2), spectrogram.transpose(1, 2))
#         combined = combined.transpose(1, 2)

#         # Transformer processing
#         combined = self.transformer(combined.transpose(1, 2)).transpose(1, 2)

#         # Final processing
#         x = F.relu(self.final_conv1(combined))
#         x = self.final_conv2(x)

#         return x

In [162]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class UltimateTransformerWaveNet(nn.Module):
#     def __init__(self, audio_channels=1, spectrogram_channels=1, num_channels=64, kernel_size=2, num_blocks=4, num_layers=10, num_heads=8):
#         super(UltimateTransformerWaveNet, self).__init__()
#         self.audio_conv = nn.Conv1d(audio_channels, num_channels, kernel_size=1)
#         self.spectrogram_conv = nn.Conv1d(spectrogram_channels, num_channels, kernel_size=1)
#         # Dilated convolutions with residual and skip connections for both streams
#         self.audio_dilated_convs = nn.ModuleList()
#         self.spectrogram_dilated_convs = nn.ModuleList()
#         self.audio_res_convs = nn.ModuleList()  # Residual connections for audio
#         self.spectrogram_res_convs = nn.ModuleList()  # Residual connections for spectrogram
#         self.audio_skip_convs = nn.ModuleList()
#         self.spectrogram_skip_convs = nn.ModuleList()
        
#         for i in range(num_layers):
#             dilation = 2 ** i
#             self.audio_dilated_convs.append(nn.Conv1d(num_channels, num_channels, kernel_size, dilation=dilation, padding=dilation, groups=num_channels//16))
#             self.spectrogram_dilated_convs.append(nn.Conv1d(num_channels, num_channels, kernel_size, dilation=dilation, padding=dilation, groups=num_channels//16))
#             self.audio_res_convs.append(nn.Conv1d(num_channels, num_channels, 1))  # Residual layer for audio
#             self.spectrogram_res_convs.append(nn.Conv1d(num_channels, num_channels, 1))  # Residual layer for spectrogram
#             self.audio_skip_convs.append(nn.Conv1d(num_channels, num_channels, 1))
#             self.spectrogram_skip_convs.append(nn.Conv1d(num_channels, num_channels, 1))

#         # Multi-head attention for combining features
#         self.feature_attention = nn.MultiheadAttention(embed_dim=num_channels, num_heads=num_heads, batch_first=True)

#         # Transformer block with residual connection
#         self.transformer = nn.TransformerEncoder(
#             nn.TransformerEncoderLayer(d_model=num_channels, nhead=num_heads, dim_feedforward=num_channels * 4, batch_first=True),
#             num_layers=3)

#         # Output layers
#         self.final_conv1 = nn.Conv1d(num_channels, num_channels, 1)
#         self.final_conv2 = nn.Conv1d(num_channels, audio_channels, 1)

#         # Additional residual connection across the network
#         self.residual_conv = nn.Conv1d(num_channels, num_channels, 1)
        
#     def forward(self, audio, spectrogram):
#         # Initial convolution processing
#         audio = F.relu(self.audio_conv(audio))
#         spectrogram = F.relu(self.spectrogram_conv(spectrogram))

#         # Initialize skip connections accumulators
#         audio_skip = 0
#         spectrogram_skip = 0

#         # Process through dilated convolutions with residual and skip connections
#         for i in range(len(self.audio_dilated_convs)):
#             # Apply dilated convolution
#             audio_dilated = self.audio_dilated_convs[i](audio)
#             spectrogram_dilated = self.spectrogram_dilated_convs[i](spectrogram)

#             # Apply activation function and add residual connection
#             audio = F.relu(audio_dilated) + self.audio_res_convs[i](audio)
#             spectrogram = F.relu(spectrogram_dilated) + self.spectrogram_res_convs[i](spectrogram)

#             # Update skip connections
#             audio_skip += self.audio_skip_convs[i](audio)
#             spectrogram_skip += self.spectrogram_skip_convs[i](spectrogram)

#         # Combine audio and spectrogram features using multi-head attention
#         combined, _ = self.feature_attention(audio_skip.transpose(1, 2), spectrogram_skip.transpose(1, 2), spectrogram_skip.transpose(1, 2))
#         combined = combined.transpose(1, 2)

#         # Transformer processing with additional residual connection
#         combined_transformed = self.transformer(combined.transpose(1, 2)).transpose(1, 2)
#         combined = combined_transformed + self.residual_conv(combined)

#         # Final processing through convolution layers
#         x = F.relu(self.final_conv1(combined))
#         x = self.final_conv2(x)

#         return x
    
#     def generate_audio(self, audio, spectrogram, sample_no, device='cpu'):
#         """
#         Generate audio using the model and save it to a directory.

#         Args:
#         audio (torch.Tensor): The input audio tensor.
#         spectrogram (torch.Tensor): The input spectrogram tensor.
#         sample_no (int): The sample number to append to the filename.
#         device (str): The device to perform computation on.
#         """
#         # Ensure the model is in evaluation mode
#         self.eval()
#         # Move inputs to the correct device
#         audio = audio.to(device)
#         spectrogram = spectrogram.to(device)
#         # Generate audio using the forward method
#         with torch.no_grad():
#             generated_audio = self.forward(audio, spectrogram)
#         # Ensure the output directory exists
#         output_dir = 'gen_music'
#         os.makedirs(output_dir, exist_ok=True)
#         # Save the generated audio to a file
#         output_path = os.path.join(output_dir, f'ultranwav_{sample_no}.wav')
#         torch.save(generated_audio, output_path)
#         print(f"Generated audio saved to {output_path}")
#         # Optionally, return the path or the audio tensor for further use
#         return output_path

In [163]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class UltimateTransformerWaveNet(nn.Module):
#     def __init__(self, audio_channels=1, spectrogram_channels=1025, num_channels=64, kernel_size=2, num_blocks=4, num_layers=10, num_heads=8):
#         super(UltimateTransformerWaveNet, self).__init__()
#         self.audio_conv = nn.Conv1d(audio_channels, num_channels, kernel_size=1)
#         self.spectrogram_conv = nn.Conv1d(spectrogram_channels, num_channels, kernel_size=1)

#         # Dilated convolutions with residual and skip connections for both streams
#         self.audio_dilated_convs = nn.ModuleList()
#         self.spectrogram_dilated_convs = nn.ModuleList()
#         self.audio_skip_convs = nn.ModuleList()
#         self.spectrogram_skip_convs = nn.ModuleList()
#         for i in range(num_layers):
#             dilation = 2 ** i
#             self.audio_dilated_convs.append(nn.Conv1d(num_channels, num_channels, kernel_size, dilation=dilation, padding=dilation, groups=num_channels//16))
#             self.spectrogram_dilated_convs.append(nn.Conv1d(num_channels, num_channels, kernel_size, dilation=dilation, padding=dilation, groups=num_channels//16))
#             self.audio_skip_convs.append(nn.Conv1d(num_channels, num_channels, 1))
#             self.spectrogram_skip_convs.append(nn.Conv1d(num_channels, num_channels, 1))

#         # Multi-head attention for combining features
#         self.feature_attention = nn.MultiheadAttention(embed_dim=num_channels, num_heads=num_heads, batch_first=True)

#         # Transformer block with residual connection
#         self.transformer = nn.TransformerEncoder(
#             nn.TransformerEncoderLayer(d_model=num_channels, nhead=num_heads, dim_feedforward=num_channels * 4, batch_first=True),
#             num_layers=3)

#         # Output layers
#         self.final_conv1 = nn.Conv1d(num_channels, num_channels, 1)
#         self.final_conv2 = nn.Conv1d(num_channels, audio_channels, 1)

#         # Additional residual connection across the network
#         self.residual_conv = nn.Conv1d(num_channels, num_channels, 1)

#     def forward(self, audio, spectrogram):
#         print("Original audio shape:", audio.shape)
#         print("Original spectrogram shape:", spectrogram.shape)
        
#         audio_input = F.relu(self.audio_conv(audio))
#         spectrogram_input = F.relu(self.spectrogram_conv(spectrogram))
#         audio = audio_input
#         spectrogram = spectrogram_input
        
#         print("Audio shape:", audio.shape)
#         print("Spectrogram shape:", spectrogram.shape)

#         audio_skip = 0
#         spectrogram_skip = 0

#         # Process through dilated convolutions with residual and skip connections
#         for audio_conv, audio_skip_conv, spectro_conv, spectro_skip_conv in zip(self.audio_dilated_convs, self.audio_skip_convs, self.spectrogram_dilated_convs, self.spectrogram_skip_convs):
#             print("Audio conv shape:", audio_conv.shape)
#             print("Audio skip conv shape:", audio_skip_conv.shape)
#             print("Spectrogram conv shape:", spectro_conv.shape)
#             print("Spectrogram skip conv shape:", spectro_skip_conv.shape)
#             tmp = F.relu(audio_conv(audio))
#             print("relu shape:", tmp.shape)
#             print("audio shape:", audio.shape)
#             audio = F.relu(audio_conv(audio))-1 + audio
#             spectrogram = F.relu(spectro_conv(spectrogram)) + spectrogram
#             audio_skip += audio_skip_conv(audio)
#             spectrogram_skip += spectro_skip_conv(spectrogram)
        
#         print("Audio skip shape:", audio_skip.shape)
#         print("Spectrogram skip shape:", spectrogram_skip.shape)

#         # Combine using multi-head attention
#         combined, _ = self.feature_attention(audio_skip.transpose(1, 2), spectrogram_skip.transpose(1, 2), spectrogram_skip.transpose(1, 2))
#         combined = combined.transpose(1, 2)
        
#         print("Combined shape:", combined.shape)

#         # Transformer processing with residual connection
#         combined = self.transformer(combined.transpose(1, 2)).transpose(1, 2) + self.residual_conv(combined)

#         print("Combined shape Transformer:", combined.shape)
#         # Final processing
#         x = F.relu(self.final_conv1(combined))
#         x = self.final_conv2(x)

#         return x
    
#     def generate_audio(self, audio, spectrogram, sample_no, device='cpu'):
#         """
#         Generate audio using the model and save it to a directory.

#         Args:
#         audio (torch.Tensor): The input audio tensor.
#         spectrogram (torch.Tensor): The input spectrogram tensor.
#         sample_no (int): The sample number to append to the filename.
#         device (str): The device to perform computation on.
#         """
#         # Ensure the model is in evaluation mode
#         self.eval()
#         # Move inputs to the correct device
#         audio = audio.to(device)
#         spectrogram = spectrogram.to(device)
#         # Generate audio using the forward method
#         with torch.no_grad():
#             generated_audio = self.forward(audio, spectrogram)
#         # Ensure the output directory exists
#         output_dir = 'gen_music'
#         os.makedirs(output_dir, exist_ok=True)
#         # Save the generated audio to a file
#         output_path = os.path.join(output_dir, f'ultranwav_{sample_no}.wav')
#         torch.save(generated_audio, output_path)
#         print(f"Generated audio saved to {output_path}")
#         # Optionally, return the path or the audio tensor for further use
#         return output_path

In [164]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Load a pre-trained VGGish model for audio feature extraction
vggish = torch.hub.load('harritaylor/torchvggish', 'vggish')

# Define the Perceptual Loss using VGGish as the feature extractor
class PerceptualLoss(nn.Module):
    def __init__(self, feature_extractor):
        super(PerceptualLoss, self).__init__()
        self.feature_extractor = feature_extractor
        self.feature_extractor.eval()  # Set to evaluation mode

    def forward(self, generated_audio, target_audio):
        with torch.no_grad():
            real_features = self.feature_extractor(target_audio)
        generated_features = self.feature_extractor(generated_audio)
        loss = F.l1_loss(generated_features, real_features)
        return loss

perceptual_loss = PerceptualLoss(vggish)

Using cache found in C:\Users\rahat/.cache\torch\hub\harritaylor_torchvggish_master


In [165]:
class MultiScaleSpectrogramLoss(nn.Module):
    def __init__(self, scales=[1024, 2048, 4096]):
        super(MultiScaleSpectrogramLoss, self).__init__()
        self.scales = scales

    def forward(self, generated_audio, target_audio):
        loss = 0
        for scale in self.scales:
            gen_spec = torch.stft(generated_audio, n_fft=scale, return_complex=True)
            target_spec = torch.stft(target_audio, n_fft=scale, return_complex=True)
            loss += F.l1_loss(gen_spec.abs(), target_spec.abs())
        return loss / len(self.scales)

spectrogram_loss = MultiScaleSpectrogramLoss()

In [166]:
# For demonstration, let's assume we have a simple CNN as a discriminator
class SimpleAudioDiscriminator(nn.Module):
    def __init__(self):
        super(SimpleAudioDiscriminator, self).__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1)  # Changed to Conv1d
        self.fc1 = nn.Linear(16 * 16, 1)  # Adjust size according to actual output dimensions

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

    def intermediate_forward(self, x):
        x = F.relu(self.conv1(x))
        return x

discriminator = SimpleAudioDiscriminator()

class FeatureMatchingLoss(nn.Module):
    def __init__(self, discriminator):
        super(FeatureMatchingLoss, self).__init__()
        self.discriminator = discriminator
        self.discriminator.eval()

    def forward(self, generated_audio, target_audio):
        with torch.no_grad():
            real_features = self.discriminator.intermediate_forward(target_audio)
        generated_features = self.discriminator.intermediate_forward(generated_audio)
        loss = F.l1_loss(generated_features, real_features)
        return loss

feature_matching_loss = FeatureMatchingLoss(discriminator)

In [167]:
# Example of a composite loss
class CompositeLoss(nn.Module):
    def __init__(self, perceptual_loss, spectrogram_loss, feature_matching_loss):
        super(CompositeLoss, self).__init__()
        self.perceptual_loss = perceptual_loss
        self.spectrogram_loss = spectrogram_loss
        self.feature_matching_loss = feature_matching_loss

    def forward(self, generated_audio, target_audio):
        loss = (self.perceptual_loss(generated_audio, target_audio) +
                self.spectrogram_loss(generated_audio, target_audio) +
                self.feature_matching_loss(generated_audio, target_audio))
        return loss

In [172]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UltimateTransformerWaveNet(nn.Module):
    def __init__(self, audio_channels=1, spectrogram_channels=1025, num_channels=64, kernel_size=2, num_blocks=4, num_layers=10, num_heads=8):
        super(UltimateTransformerWaveNet, self).__init__()
        self.audio_conv = nn.Conv1d(audio_channels, num_channels, kernel_size=1)
        self.spectrogram_conv = nn.Conv1d(spectrogram_channels, num_channels, kernel_size=1)

        # Dilated convolutions with residual and skip connections for both streams
        self.audio_dilated_convs = nn.ModuleList()
        self.spectrogram_dilated_convs = nn.ModuleList()
        self.audio_skip_convs = nn.ModuleList()
        self.spectrogram_skip_convs = nn.ModuleList()
        for i in range(num_layers):
            dilation = 2 ** i
            self.audio_dilated_convs.append(nn.Conv1d(num_channels, num_channels, kernel_size, dilation=dilation, padding=dilation, groups=num_channels//16))
            self.spectrogram_dilated_convs.append(nn.Conv1d(num_channels, num_channels, kernel_size, dilation=dilation, padding=dilation, groups=num_channels//16))
            self.audio_skip_convs.append(nn.Conv1d(num_channels, num_channels, 1))
            self.spectrogram_skip_convs.append(nn.Conv1d(num_channels, num_channels, 1))

        # Multi-head attention for combining features
        self.feature_attention = nn.MultiheadAttention(embed_dim=num_channels, num_heads=num_heads, batch_first=True)

        # Transformer block with residual connection
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=num_channels, nhead=num_heads, dim_feedforward=num_channels * 4, batch_first=True),
            num_layers=3)

        # Output layers
        self.final_conv1 = nn.Conv1d(num_channels, num_channels, 1)
        self.final_conv2 = nn.Conv1d(num_channels, audio_channels, 1)

        # Additional residual connection across the network
        self.residual_conv = nn.Conv1d(num_channels, num_channels, 1)

    def forward(self, audio, spectrogram):
        # print("Original audio shape:", audio.shape)
        # print("Original spectrogram shape:", spectrogram.shape)
        
        audio_input = F.relu(self.audio_conv(audio))
        spectrogram_input = F.relu(self.spectrogram_conv(spectrogram))
        audio = audio_input
        spectrogram = spectrogram_input
        
        # print("Audio shape:", audio.shape)
        # print("Spectrogram shape:", spectrogram.shape)

        audio_skip = 0
        spectrogram_skip = 0

        # Process through dilated convolutions with residual and skip connections
        i = 0
        for audio_conv, audio_skip_conv, spectro_conv, spectro_skip_conv in zip(self.audio_dilated_convs, self.audio_skip_convs, self.spectrogram_dilated_convs, self.spectrogram_skip_convs):
            # print("Audio conv shape:", audio_conv.shape)
            # print("Audio skip conv shape:", audio_skip_conv.shape)
            # print("Spectrogram conv shape:", spectro_conv.shape)
            # print("Spectrogram skip conv shape:", spectro_skip_conv.shape)
            t1 = F.relu(audio_conv(audio))
            # print("relu shape:", t1.shape)
            # print("audio shape:", audio.shape)
            t1 = t1[:, :, : -(2**i)]
            # print("Modified relu shape:", t1.shape)
            audio = t1 + audio
            
            t2 = F.relu(spectro_conv(spectrogram))
            # print("Spectrogram conv shape:", t2.shape)
            # print("Spectrogram shape:", spectrogram.shape)
            t2 = t2[:, :, :-(2**i)]
            # print("Modified Spectrogram conv shape:", t2.shape)
            
            spectrogram = t2 + spectrogram
            audio_skip += audio_skip_conv(audio)
            spectrogram_skip += spectro_skip_conv(spectrogram)
            i += 1
        
        # print("Audio skip shape:", audio_skip.shape)
        # print("Spectrogram skip shape:", spectrogram_skip.shape)

        # Combine using multi-head attention
        combined, _ = self.feature_attention(audio_skip.transpose(1, 2), spectrogram_skip.transpose(1, 2), spectrogram_skip.transpose(1, 2))
        combined = combined.transpose(1, 2)
        
        # print("Combined shape:", combined.shape)

        # Transformer processing with residual connection
        combined = self.transformer(combined.transpose(1, 2)).transpose(1, 2) + self.residual_conv(combined)

        # print("Combined shape Transformer:", combined.shape)
        # Final processing
        x = F.relu(self.final_conv1(combined))
        x = self.final_conv2(x)

        return x
    
    
    def generate_audio(self, audio, spectrogram, sample_no, device='cpu'):
        """
        Generate audio using the model and save it to a directory.

        Args:
        audio (torch.Tensor): The input audio tensor.
        spectrogram (torch.Tensor): The input spectrogram tensor.
        sample_no (int): The sample number to append to the filename.
        device (str): The device to perform computation on.
        """
        # Ensure the model is in evaluation mode
        self.eval()
        # Move inputs to the correct device
        audio = audio.to(device)
        spectrogram = spectrogram.to(device)
        # Generate audio using the forward method
        with torch.no_grad():
            generated_audio = self.forward(audio, spectrogram)
        # Ensure the output directory exists
        output_dir = 'gen_music'
        os.makedirs(output_dir, exist_ok=True)
        # Save the generated audio to a file
        output_path = os.path.join(output_dir, f'ultranwav_{sample_no}.wav')
        torch.save(generated_audio, output_path)
        print(f"Generated audio saved to {output_path}")
        # Optionally, return the path or the audio tensor for further use
        return output_path

In [187]:
import torch
from torch.utils.data import DataLoader
import os
import wandb  # Ensure wandb is imported if you're using it

def train(model, train_loader, val_loader, optimizer, criterion, epochs, device):
    model.to(device)
    print(device)
    print("Training Begins!")
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for i, (audio, spectrogram) in enumerate(train_loader):
            audio, spectrogram = audio.to(device), spectrogram.to(device)
            
            if audio.dim() == 2:
                audio = audio.unsqueeze(1)  # Add channel dimension
            elif audio.dim() != 3:
                raise ValueError("Audio input must be 2D or 3D tensor")

            optimizer.zero_grad()
            output = model(audio, spectrogram)
            loss = criterion(output, audio)  # Ensure the criterion is correctly defined for the expected output
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

            # Log loss to wandb
            wandb.log({"train_loss": loss.item()})
            if i % 10 == 0:  # Log every 10 steps
                print(f"Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item()}")

            # Save model checkpoint periodically or based on performance
            if i % 100 == 0:  # Save every 100 iterations
                checkpoint_path = f'TW_Checkpoint/model_TW_{epoch}_{i}.pt'
                torch.save(model.state_dict(), checkpoint_path)
                print(f"Checkpoint saved to {checkpoint_path}")

            # Generate synthetic data and add to train_loader if needed
            if i % 50 == 0:  # Generate synthetic data every 50 iterations
                with torch.no_grad():
                    synthetic_audio = model.generate_audio(audio, spectrogram, i, device)
                    # Assuming train_loader.dataset is a list or supports append
                    train_loader.dataset.append((synthetic_audio.detach(), spectrogram))

        epoch_loss /= len(train_loader)
        print(f"Epoch [{epoch + 1}/{epochs}], Average Loss: {epoch_loss}")
        wandb.log({"epoch_loss": epoch_loss})

        # Validation loop
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for audio, spectrogram in val_loader:
                audio, spectrogram = audio.to(device), spectrogram.to(device)
                output = model(audio, spectrogram)
                val_loss += criterion(output, audio).item()

        val_loss /= len(val_loader)
        wandb.log({"val_loss": val_loss})
        print(f"Validation Loss: {val_loss}")

In [188]:
from torch.optim import Adam

train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=False, num_workers=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UltimateTransformerWaveNet().to(device)
# print(model)
optimizer = Adam(model.parameters(), lr=0.001)
# composite_loss = CompositeLoss(perceptual_loss, spectrogram_loss, feature_matching_loss)
composite_loss = CompositeLoss(perceptual_loss, spectrogram_loss, feature_matching_loss)

# train(model, train_loader, val_loader, optimizer, composite_loss, epochs=50, device=device)

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [189]:
train(model, train_loader, val_loader, optimizer, composite_loss, epochs=50, device=device)

cuda
Training Begins!


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
