In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Fri May 10 19:07:11 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P0              47W / 400W |      2MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

if ram_gb < 20:
  print('Not using a high-RAM runtime')
else:
  print('You are using a high-RAM runtime!')

Your runtime has 89.6 gigabytes of available RAM

You are using a high-RAM runtime!


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
# !pip install torchvggish

In [5]:
!pip install resampy
!pip install wandb



In [6]:

!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mrah-m[0m ([33mrebot[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
!pip install torchaudio

import os
import librosa
import wandb
import numpy as np
import multiprocessing as mp

import torch
import torchaudio
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose
from torch.utils.data import random_split

wandb.init(project='UlTraNet')

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch==2.2.1->torchaudio)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch==2.2.1->torchaudio)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch==2.2.1->torchaudio)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch==2.2.1->torchaudio)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch==2.2.1->torchaudio)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch==2.2.1->torchaudio)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
Collecting nvidia-curand-cu12==10.3.

[34m[1mwandb[0m: Currently logged in as: [33mrah-m[0m ([33mrebot[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
def load_audio(audio_path, sample_rate=22050, duration=5):
    # Load audio file with librosa, automatically resampling to the given sample rate
    audio, sr = librosa.load(audio_path, sr=sample_rate, duration=duration)

    # Calculate target number of samples
    target_length = sample_rate * duration

    # Pad audio if it is shorter than the target length
    if len(audio) < target_length:
        padding = target_length - len(audio)
        audio = np.pad(audio, (0, padding), mode='constant')
    # Truncate audio if it is longer than the target length
    elif len(audio) > target_length:
        audio = audio[:target_length]

    return audio

def get_spectrogram(audio, n_fft=2048, hop_length=512, max_length=130):
    # Generate a spectrogram
    spectrogram = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length)
    # Convert to magnitude (amplitude)
    spectrogram = np.abs(spectrogram)

    # Pad or truncate the spectrogram to ensure all are the same length
    if spectrogram.shape[1] < max_length:
        padding = max_length - spectrogram.shape[1]
        spectrogram = np.pad(spectrogram, ((0, 0), (0, padding)), mode='constant')
    else:
        spectrogram = spectrogram[:, :max_length]

    return spectrogram

class AudioDataset(Dataset):
    def __init__(self, root_dir, sample_rate=22050, n_fft=2048, hop_length=512, max_length=130):
        self.root_dir = root_dir
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.max_length = max_length
        self.files = [os.path.join(dp, f) for dp, dn, filenames in os.walk(root_dir) for f in filenames if f.endswith('.mp3') or f.endswith('.wav')]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        audio_path = self.files[idx]
        audio = load_audio(audio_path, self.sample_rate)
        spectrogram = get_spectrogram(audio, self.n_fft, self.hop_length, self.max_length)
        return audio, spectrogram

if __name__ == '__main__':
    mp.set_start_method('spawn', force=True)

    data_folder_path = os.path.join('/content/drive/My Drive', 'DATA')
    dataset = AudioDataset(root_dir=data_folder_path)
    loader = DataLoader(dataset, batch_size=10, shuffle=True)

In [9]:
# print(dataset[0])
# print((dataset[0][0]))
print(dataset[0][0].shape)

# print(dataset[0][1])
print(dataset[0][1].shape)

(110250,)
(1025, 130)


In [10]:
def split_dataset(dataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    total_size = len(dataset)
    train_size = int(total_size * train_ratio)
    val_size = int(total_size * val_ratio)
    test_size = total_size - train_size - val_size  # Ensure all data is used

    train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])
    return train_dataset, val_dataset, test_dataset

data_folder_path = 'DATA'
# dataset = AudioDataset(root_dir=data_folder_path)

# Assuming 'dataset' is an instance of AudioDataset
train_dataset, val_dataset, test_dataset = split_dataset(dataset)

# Create DataLoaders for each dataset split
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False, num_workers=0)

In [11]:
# # Try to fetch a single batch to see if it works
# try:
#     data = next(iter(train_loader))
#     print("Single batch loaded successfully:", data)
# except Exception as e:
#     print("Failed to load a batch:", e)

In [12]:
print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")
print(f"Test set size: {len(test_dataset)}")

Training set size: 407
Validation set size: 87
Test set size: 88


In [13]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# # Load a pre-trained VGGish model for audio feature extraction
# vggish = torch.hub.load('harritaylor/torchvggish', 'vggish')

# # Define the Perceptual Loss using VGGish as the feature extractor
# class PerceptualLoss(nn.Module):
#     def __init__(self, feature_extractor):
#         super(PerceptualLoss, self).__init__()
#         self.feature_extractor = feature_extractor
#         self.feature_extractor.eval()  # Set to evaluation mode

#     def forward(self, generated_audio, target_audio):
#         with torch.no_grad():
#             real_features = self.feature_extractor(target_audio)
#         generated_features = self.feature_extractor(generated_audio)
#         loss = F.l1_loss(generated_features, real_features)
#         return loss

# perceptual_loss = PerceptualLoss(vggish)

In [14]:
class MultiScaleSpectrogramLoss(nn.Module):
    def __init__(self, scales=[1024, 2048, 4096]):
        super(MultiScaleSpectrogramLoss, self).__init__()
        self.scales = scales

    def forward(self, generated_audio, target_audio):
        loss = 0
        for scale in self.scales:
            # print("Scale:", scale)
            # print(generated_audio.shape)
            # print(target_audio.shape)
            gen_spec = torch.stft(generated_audio.squeeze(1), n_fft=scale, return_complex=True)
            target_spec = torch.stft(target_audio.squeeze(1), n_fft=scale, return_complex=True)
            loss += F.l1_loss(gen_spec.abs(), target_spec.abs())
        return loss / len(self.scales)

spectrogram_loss = MultiScaleSpectrogramLoss()

In [15]:
# For demonstration, let's assume we have a simple CNN as a discriminator
class SimpleAudioDiscriminator(nn.Module):
    def __init__(self):
        super(SimpleAudioDiscriminator, self).__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=1, padding=1)  # Changed to Conv1d
        self.num_features = 16 * 110250  # 16 channels, length preserved at 110250
        self.fc1 = nn.Linear(self.num_features, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

    def intermediate_forward(self, x):
        x = F.relu(self.conv1(x))
        return x

discriminator = SimpleAudioDiscriminator()

class FeatureMatchingLoss(nn.Module):
    def __init__(self, discriminator):
        super(FeatureMatchingLoss, self).__init__()
        self.discriminator = discriminator
        self.discriminator.eval()

    def forward(self, generated_audio, target_audio):
        with torch.no_grad():
            real_features = self.discriminator.intermediate_forward(target_audio)
        generated_features = self.discriminator.intermediate_forward(generated_audio)
        loss = F.l1_loss(generated_features, real_features)
        return loss

feature_matching_loss = FeatureMatchingLoss(discriminator)

In [16]:
# Example of a composite loss
class CompositeLoss(nn.Module):
    def __init__(self, spectrogram_loss, feature_matching_loss):
        super(CompositeLoss, self).__init__()
        # self.perceptual_loss = perceptual_loss
        self.spectrogram_loss = spectrogram_loss
        self.feature_matching_loss = feature_matching_loss

    def forward(self, generated_audio, target_audio):
        # loss = (self.perceptual_loss(generated_audio, target_audio) +
        loss = (self.spectrogram_loss(generated_audio, target_audio) +
                self.feature_matching_loss(generated_audio, target_audio))
        return loss

In [17]:
def save_checkpoint(model, epoch, i):
    checkpoint_path = f'TW_Checkpoint/model_TW_{epoch}_{i}.pt'
    data_folder_path = os.path.join('/content/drive/My Drive', checkpoint_path)

    # Ensure the directory exists
    os.makedirs(os.path.dirname(data_folder_path), exist_ok=True)

    # Save the model state
    torch.save(model.state_dict(), data_folder_path)
    print(f"Checkpoint saved to {data_folder_path}")

In [18]:
class UltimateTransformerWaveNet(nn.Module):
    def __init__(self, audio_channels=1, spectrogram_channels=1025, num_channels=64, kernel_size=2, num_blocks=4, num_layers=10, num_heads=8):
        super(UltimateTransformerWaveNet, self).__init__()
        self.audio_conv = nn.Conv1d(audio_channels, num_channels, kernel_size=1)
        self.spectrogram_conv = nn.Conv1d(spectrogram_channels, num_channels, kernel_size=1)

        # Dilated convolutions with residual and skip connections for both streams
        self.audio_dilated_convs = nn.ModuleList()
        self.spectrogram_dilated_convs = nn.ModuleList()
        self.audio_skip_convs = nn.ModuleList()
        self.spectrogram_skip_convs = nn.ModuleList()
        for i in range(num_layers):
            dilation = 2 ** i
            self.audio_dilated_convs.append(nn.Conv1d(num_channels, num_channels, kernel_size, dilation=dilation, padding=dilation, groups=num_channels//16))
            self.spectrogram_dilated_convs.append(nn.Conv1d(num_channels, num_channels, kernel_size, dilation=dilation, padding=dilation, groups=num_channels//16))
            self.audio_skip_convs.append(nn.Conv1d(num_channels, num_channels, 1))
            self.spectrogram_skip_convs.append(nn.Conv1d(num_channels, num_channels, 1))

        # Multi-head attention for combining features
        self.feature_attention = nn.MultiheadAttention(embed_dim=num_channels, num_heads=num_heads, batch_first=True)

        # Transformer block with residual connection
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=num_channels, nhead=num_heads, dim_feedforward=num_channels * 4, batch_first=True),
            num_layers=3)

        # Output layers
        self.final_conv1 = nn.Conv1d(num_channels, num_channels, 1)
        self.final_conv2 = nn.Conv1d(num_channels, audio_channels, 1)

        # Additional residual connection across the network
        self.residual_conv = nn.Conv1d(num_channels, num_channels, 1)

    def forward(self, audio, spectrogram):
        # print("Original audio shape:", audio.shape)
        # print("Original spectrogram shape:", spectrogram.shape)

        audio_input = F.relu(self.audio_conv(audio))
        spectrogram_input = F.relu(self.spectrogram_conv(spectrogram))
        audio = audio_input
        spectrogram = spectrogram_input

        # print("Audio shape:", audio.shape)
        # print("Spectrogram shape:", spectrogram.shape)

        audio_skip = 0
        spectrogram_skip = 0

        # Process through dilated convolutions with residual and skip connections
        i = 0
        for audio_conv, audio_skip_conv, spectro_conv, spectro_skip_conv in zip(self.audio_dilated_convs, self.audio_skip_convs, self.spectrogram_dilated_convs, self.spectrogram_skip_convs):
            # print("Audio conv shape:", audio_conv.shape)
            # print("Audio skip conv shape:", audio_skip_conv.shape)
            # print("Spectrogram conv shape:", spectro_conv.shape)
            # print("Spectrogram skip conv shape:", spectro_skip_conv.shape)
            t1 = F.relu(audio_conv(audio))
            # print("relu shape:", t1.shape)
            # print("audio shape:", audio.shape)
            t1 = t1[:, :, : -(2**i)]
            # print("Modified relu shape:", t1.shape)
            audio = t1 + audio

            t2 = F.relu(spectro_conv(spectrogram))
            # print("Spectrogram conv shape:", t2.shape)
            # print("Spectrogram shape:", spectrogram.shape)
            t2 = t2[:, :, :-(2**i)]
            # print("Modified Spectrogram conv shape:", t2.shape)

            spectrogram = t2 + spectrogram
            audio_skip += audio_skip_conv(audio)
            spectrogram_skip += spectro_skip_conv(spectrogram)
            i += 1

        # print("Audio skip shape:", audio_skip.shape)
        # print("Spectrogram skip shape:", spectrogram_skip.shape)

        # Combine using multi-head attention
        combined, _ = self.feature_attention(audio_skip.transpose(1, 2), spectrogram_skip.transpose(1, 2), spectrogram_skip.transpose(1, 2))
        combined = combined.transpose(1, 2)

        # print("Combined shape:", combined.shape)

        # Transformer processing with residual connection
        combined = self.transformer(combined.transpose(1, 2)).transpose(1, 2) + self.residual_conv(combined)

        # print("Combined shape Transformer:", combined.shape)
        # Final processing
        x = F.relu(self.final_conv1(combined))
        x = self.final_conv2(x)

        # print("Final x: ", x.shape)

        return x

    # def generate_audio(self, audio, spectrogram, sample_no, device='cpu'):
    #     """
    #     Generate audio using the model and save it to a directory.

    #     Args:
    #     audio (torch.Tensor): The input audio tensor.
    #     spectrogram (torch.Tensor): The input spectrogram tensor.
    #     sample_no (int): The sample number to append to the filename.
    #     device (str): The device to perform computation on.
    #     """
    #     # Ensure the model is in evaluation mode
    #     self.eval()
    #     # Move inputs to the correct device
    #     audio = audio.to(device)
    #     spectrogram = spectrogram.to(device)
    #     # Generate audio using the forward method
    #     with torch.no_grad():
    #         generated_audio = self.forward(audio, spectrogram)
    #     # Ensure the output directory exists
    #     # output_dir = 'gen_music'
    #     output_dir = os.path.join('/content/drive/My Drive', 'gen_music')
    #     os.makedirs(output_dir, exist_ok=True)
    #     # Save the generated audio to a file
    #     output_path = os.path.join(output_dir, f'ultranwav_{sample_no}.wav')
    #     torch.save(generated_audio, output_path)
    #     print(f"Generated audio saved to {output_path}")
    #     # Optionally, return the path or the audio tensor for further use
    #     return output_path

    def generate_audio(self, audio, spectrogram, sample_no, device='cpu'):
        self.eval()  # Ensure the model is in evaluation mode
        audio = audio.to(device)
        spectrogram = spectrogram.to(device)

        with torch.no_grad():
            generated_audio = self.forward(audio, spectrogram)

        # Normalize the audio to the range [-1, 1] for WAV file compatibility
        generated_audio = generated_audio / torch.max(torch.abs(generated_audio))

        # Ensure the output directory exists
        output_dir = os.path.join('/content/drive/My Drive', 'gen_music')
        os.makedirs(output_dir, exist_ok=True)
        # print(f"Output directory ensured at {output_dir}.")

        # Save the generated audio to a WAV file
        output_path = os.path.join(output_dir, f'ultranwav_{sample_no}.wav')

        # Ensure the tensor is in the correct shape (channels, frames)
        if generated_audio.squeeze(0).dim() == 1:
            generated_audio = generated_audio.unsqueeze(0)  # Add channel dimension for mono
        elif generated_audio.squeeze(0).dim() == 2:
            pass  # Correct format
        else:
            raise ValueError("Unsupported tensor shape for audio saving")

        # Set the sample rate and channel layout
        sample_rate = 22050  # Adjust as needed
        channels = generated_audio.shape[0]
        channel_layout = 'mono' if channels == 1 else 'stereo'

        # Save using torchaudio with the correct settings
        torchaudio.save(output_path, generated_audio.squeeze(0).cpu(), sample_rate, format='wav', encoding='PCM_S', bits_per_sample=16)

        print(f"Generated audio saved to {output_path}")
        return output_path

In [19]:
def train(model, train_loader, val_loader, optimizer, criterion, epochs, device):
    model.to(device)
    print(device)
    print("Training Begins!")
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for i, (audio, spectrogram) in enumerate(train_loader):
            audio, spectrogram = audio.to(device), spectrogram.to(device)

            if audio.dim() == 2:
                audio = audio.unsqueeze(1)  # Add channel dimension
            elif audio.dim() != 3:
                raise ValueError("Audio input must be 2D or 3D tensor")

            # print("Out here")

            # window = torch.hann_window(1024, device=device)
            # spectrogram = torch.stft(spectrogram.squeeze(1), n_fft=1024, hop_length=256, win_length=1024, window=window, return_complex=True)
            # spectrogram.to(device)

            optimizer.zero_grad()
            output = model(audio, spectrogram)
            # print(output.shape)
            loss = criterion(output, audio)  # Ensure the criterion is correctly defined for the expected output
            loss.backward()
            optimizer.step()


            epoch_loss += loss.item()

            # Log loss to wandb
            wandb.log({"train_loss": loss.item()})
            if i % 1 == 0:  # Log every 10 steps
                print(f"Epoch [{epoch + 1}/{epochs}], Step [{i + 1}/{len(train_loader)}], Loss: {loss.item()}")

            # Save model checkpoint periodically or based on performance
            if i % 1 == 0:  # Save every 100 iterations
                save_checkpoint(model, epoch, i)
                # synthetic_audio = model.generate_audio(audio, spectrogram, i, device)

            # # # Generate synthetic data and add to train_loader if needed
            # if i % 1 == 0:  # Generate synthetic data every 50 iterations
            #     # with torch.no_grad():
            #     synthetic_audio = model.generate_audio(audio, spectrogram, i, device)
            # #         # Assuming train_loader.dataset is a list or supports append
            # #         train_loader.dataset.append((synthetic_audio.detach(), spectrogram))

        epoch_loss /= len(train_loader)
        print(f"Epoch [{epoch + 1}/{epochs}], Average Loss: {epoch_loss}")
        wandb.log({"epoch_loss": epoch_loss})

        # Validation loop
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for audio, spectrogram in val_loader:
                audio, spectrogram = audio.to(device), spectrogram.to(device)
                audio = audio.unsqueeze(1)
                output = model(audio, spectrogram)
                val_loss += criterion(output, audio).item()

        val_loss /= len(val_loader)
        wandb.log({"val_loss": val_loss})
        print(f"Validation Loss: {val_loss}")

In [20]:
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=False, num_workers=0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UltimateTransformerWaveNet().to(device)
# print(model)
optimizer = Adam(model.parameters(), lr=0.001)
# composite_loss = CompositeLoss(perceptual_loss, spectrogram_loss, feature_matching_loss)
composite_loss = CompositeLoss(spectrogram_loss, feature_matching_loss).to(device)

In [None]:
train(model, train_loader, val_loader, optimizer, composite_loss, epochs=50, device=device)

cuda
Training Begins!


  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]


Epoch [1/50], Step [1/41], Loss: 3.3143563270568848
Checkpoint saved to /content/drive/My Drive/TW_Checkpoint/model_TW_0_0.pt
Epoch [1/50], Step [2/41], Loss: 4.457818984985352
Checkpoint saved to /content/drive/My Drive/TW_Checkpoint/model_TW_0_1.pt
Epoch [1/50], Step [3/41], Loss: 3.0026683807373047
Checkpoint saved to /content/drive/My Drive/TW_Checkpoint/model_TW_0_2.pt
Epoch [1/50], Step [4/41], Loss: 5.681921005249023
Checkpoint saved to /content/drive/My Drive/TW_Checkpoint/model_TW_0_3.pt
Epoch [1/50], Step [5/41], Loss: 3.6221697330474854
Checkpoint saved to /content/drive/My Drive/TW_Checkpoint/model_TW_0_4.pt
Epoch [1/50], Step [6/41], Loss: 2.1675972938537598
Checkpoint saved to /content/drive/My Drive/TW_Checkpoint/model_TW_0_5.pt
Epoch [1/50], Step [7/41], Loss: 4.103199481964111
Checkpoint saved to /content/drive/My Drive/TW_Checkpoint/model_TW_0_6.pt
Epoch [1/50], Step [8/41], Loss: 2.5279550552368164
Checkpoint saved to /content/drive/My Drive/TW_Checkpoint/model_TW_0_