In [None]:
import skimage
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import time
import argparse
import cv2
from scipy import io
from tqdm.notebook import tqdm
import io
from IPython.display import Audio

import torch
from torch import nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader, Dataset
import torchaudio
import torchaudio.transforms as transforms

from modules import utils
from modules.models import INR
from torchsummary import summary
from modules.encoding import Encoding

In [None]:
parser = argparse.ArgumentParser(description='INCODE')

# Shared Parameters
parser.add_argument('--input',type=str, default='/root/autodl-tmp/INCODE-main/poly_data/incode_data/Audio/gt_bach.wav', help='Input image path')
parser.add_argument('--inr_model',type=str, default='incode', help='[gauss, mfn, relu, siren, wire, wire2d, ffn, incode]')
parser.add_argument('--lr',type=float, default=9e-5, help='Learning rate')
parser.add_argument('--using_schedular', type=bool, default=True, help='Whether to use schedular')
parser.add_argument('--scheduler_b', type=float, default=0.1, help='Learning rate scheduler')
parser.add_argument('--maxpoints', type=int, default=256*256, help='Batch size')
parser.add_argument('--niters', type=int, default=501, help='Number if iterations')
parser.add_argument('--steps_til_summary', type=int, default=100, help='Number of steps till summary visualization')

# INCODE Parameters
parser.add_argument('--a_coef',type=float, default=0.1993, help='a coeficient')
parser.add_argument('--b_coef',type=float, default=0.0196, help='b coeficient')
parser.add_argument('--c_coef',type=float, default=0.0588, help='c coeficient')
parser.add_argument('--d_coef',type=float, default=0.0269, help='d coeficient')


args = parser.parse_args(args=[])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Loading Data

In [None]:
audio = utils.AudioFile(args.input)
# Only the first batch is used below. Pinned host memory only helps when the
# tensors are copied to a CUDA device; on CPU it just triggers a warning.
dataloader = DataLoader(audio, shuffle=True, batch_size=1,
                        pin_memory=(device.type == 'cuda'), num_workers=0)
rate, coords, ground_truth = next(iter(dataloader))

# Training copies live on `device`; `ground_truth` stays on the CPU for playback.
coords = coords.to(device)
gt = ground_truth.to(device)
# Collation adds a batch axis; unwrap the first element to a Python number for Audio.
rate = rate[0].item()

Audio(ground_truth.squeeze().numpy(), rate=rate)

## Defining Model

### Defining desired Positional Encoding

In [None]:
# Candidate positional-encoding configs for the INR(...) factory. The
# PolySiren model trained below does not take any of them; only
# `pos_encode_no` is referenced, by the commented-out INR experiments.

# Frequency Encoding
# NOTE(review): `len(audio.data)` is presumably the number of audio samples;
# confirm against utils.AudioFile.
pos_encode_freq = dict(type='frequency',
                       use_nyquist=True,
                       mapping_input=len(audio.data))

# Gaussian Encoding
pos_encode_gaus = dict(type='gaussian', scale_B=10, mapping_input=256)

# No Encoding
pos_encode_no = dict(type=None)

### Model Configurations

In [None]:
# Shape of the CPU ground truth with a trailing size-1 axis (if any) squeezed out.
ground_truth.squeeze(dim=-1).size()

In [None]:
# Same check for the device copy used during training.
gt.squeeze(dim=-1).size()

In [None]:
# ### Harmonizer Configurations
# MLP_configs={'task': 'audio',
#              'in_channels': 50,             
#              'hidden_channels': [50, 32, 4],
#              'mlp_bias':0.3120,
#              'activation_layer': nn.SiLU,
#              'sample_rate': rate,
#              'GT': gt.squeeze(-1)
#             }

# ### Model Configurations
# model = INR(args.inr_model).run(in_features=1,
#                                 out_features=1, 
#                                 hidden_features=256,
#                                 hidden_layers=3,
#                                 first_omega_0=3000.0,
#                                 hidden_omega_0=30.0,
#                                 pos_encode_configs=pos_encode_no, 
#                                 MLP_configs = MLP_configs
#                                ).to(device)

# print(model)

In [None]:
# poly_siren
class SineLayer(nn.Module):
    '''
    Linear layer with SIREN-style weight initialisation.

    Despite its name, ``forward`` applies ONLY the affine map ``linear(input)``.
    No sine is applied here; any non-linearity must be applied by the caller.

    Args:
        in_features (int): Number of input features.
        out_features (int): Number of output features.
        bias (bool, optional): If True, the linear transformation includes a bias term. Default is True.
        is_first (bool, optional): Selects the init range. If True, weights are drawn from
            U(-1/in_features, 1/in_features); otherwise from
            U(-sqrt(6/in_features)/omega_0, sqrt(6/in_features)/omega_0). Default is False.
        omega_0 (float, optional): Frequency factor used only to shrink the init range of
            non-first layers; it is not applied in ``forward``. Default is 30.
        scale (float, optional): Unused; accepted only for interface compatibility. Default is 10.0.
        init_weights (bool, optional): If True, re-draw the weights as described above; if False,
            keep ``nn.Linear``'s default init. The bias is never re-initialised. Default is True.
    '''

    def __init__(self, in_features, out_features, bias=True,
                 is_first=False, omega_0=30, scale=10.0, init_weights=True):
        super().__init__()
        self.omega_0 = omega_0
        self.is_first = is_first

        self.in_features = in_features
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        if init_weights:
            self.init_weights()

    def init_weights(self):
        """Re-draw ``linear.weight`` uniformly in [-bound, bound] (SIREN init)."""
        if self.is_first:
            bound = 1 / self.in_features
        else:
            bound = np.sqrt(6 / self.in_features) / self.omega_0
        with torch.no_grad():
            self.linear.weight.uniform_(-bound, bound)

    def forward(self, input):
        """Return ``linear(input)``; no activation is applied."""
        return self.linear(input)

class Siren(nn.Module):
    """
        Sine activation ``sin(w0 * x)`` from SIREN.
        https://arxiv.org/abs/2006.09661
    """

    def __init__(self, w0=1):
        """
            w0: frequency multiplier applied inside the sine. The SIREN
            paper (end of section 3) suggests 30 for the first layer and
            1 for the rest.
        """
        super().__init__()
        # Non-persistent buffer: moves with the module on .to(device) but is
        # not written to state_dict, so saved checkpoints are unchanged.
        self.register_buffer('w0', torch.tensor(w0), persistent=False)

    def forward(self, x):
        return torch.sin(self.w0 * x)

    def extra_repr(self):
        # .item() so the repr reads "w0=3000" rather than "w0=tensor(3000)".
        return "w0={}".format(self.w0.item())
    
    

class PolySiren(nn.Module):
    """
    MLP of SIREN-initialised linear layers whose hidden stages add a
    multiplicative (x * linear(x)) interaction term.

        Stage 1:     x = nolinear1(linear1(input))
        Stages 2-4:  x = nolinear_k(norm_k(x + x * linear_k(x)))
        Head:        y = tanh(Linear(hidden_features, 1)(x))

    ``linear1``..``linear4`` are ``SineLayer`` modules (affine maps only); the
    non-linearity comes from ``nolinear1``..``nolinear4``.

    Args:
        activate (str): 'ReLU' (ReLU after every stage) or 'Siren'
            (``Siren`` activations with w0 = 3000, 8, 2, 2). Default 'ReLU'.
        norm_type (str): 'LayerNorm', 'BatchNorm1d' or 'None' (identity).
            Default 'None'.
        in_features (int): Size of the last input dimension. Default 1.
        hidden_features (int): Width of every hidden stage. Default 256.

    Raises:
        ValueError: if ``activate`` or ``norm_type`` is not a supported value.
    """

    def __init__(
        self, activate='ReLU', norm_type='None', in_features=1, hidden_features=256
        ) -> None:
        super().__init__()
        hidden_channel = hidden_features

        norm_factories = {
            'LayerNorm': lambda: nn.LayerNorm(hidden_channel, eps=0.0000001),
            # NOTE(review): BatchNorm1d normalises dim 1 with a fixed size of
            # 65536, so this option presumably expects inputs shaped
            # (batch, 65536, hidden) -- confirm before relying on it.
            'BatchNorm1d': lambda: nn.BatchNorm1d(65536),
            'None': nn.Identity,
        }
        if norm_type not in norm_factories:
            raise ValueError(
                f"norm_type must be one of {sorted(norm_factories)}, got {norm_type!r}")
        if activate not in ('ReLU', 'Siren'):
            raise ValueError(f"activate must be 'ReLU' or 'Siren', got {activate!r}")

        self.linear1 = SineLayer(in_features, hidden_channel, is_first=True)

        # forward() uses norm2..norm4, so all of them must exist for every
        # norm_type. norm1 is not used by forward(); it is still created so
        # state_dict keys stay the same.
        make_norm = norm_factories[norm_type]
        self.norm1 = make_norm()
        self.norm2 = make_norm()
        self.norm3 = make_norm()
        self.norm4 = make_norm()

        self.linear2 = SineLayer(hidden_channel, hidden_channel)
        self.linear3 = SineLayer(hidden_channel, hidden_channel)
        self.linear4 = SineLayer(hidden_channel, hidden_channel)

        # forward() uses four activations, so both branches define all four.
        if activate == 'ReLU':
            self.nolinear1 = nn.ReLU()
            self.nolinear2 = nn.ReLU()
            self.nolinear3 = nn.ReLU()
            self.nolinear4 = nn.ReLU()
        else:  # 'Siren'
            self.nolinear1 = Siren(3000)
            self.nolinear2 = Siren(8)
            self.nolinear3 = Siren(2)
            self.nolinear4 = Siren(2)

        # Output head; kept inside nn.Sequential so the state_dict keys are "layers.0.*".
        self.layers = nn.Sequential(nn.Linear(hidden_channel, 1))

    def forward(self, input):
        """Apply the four stages and the tanh head; outputs lie in [-1, 1]."""
        x = self.nolinear1(self.linear1(input))
        # Residual plus multiplicative interaction, then norm and activation.
        x = self.nolinear2(self.norm2(x + x * self.linear2(x)))
        x = self.nolinear3(self.norm3(x + x * self.linear3(x)))
        x = self.nolinear4(self.norm4(x + x * self.linear4(x)))
        return torch.tanh(self.layers(x))



# Experiment overrides: PolySiren goes through the non-'incode' branch of the
# training loop and uses a higher learning rate than the CLI default.
args.inr_model = 'siren'
args.lr = 8e-4
model = PolySiren(activate='Siren', norm_type='LayerNorm').to(device)
print(model)

num_params = sum(torch.numel(p) for p in model.parameters() if p.requires_grad)
print(f'Total number of parameters: {num_params/1e6}(M)')



In [None]:
# ### Model Configurations
# args.inr_model='siren'
# args.lr= 1e-4
# model = INR(args.inr_model).run(in_features=1,
#                                 out_features=1, 
#                                 hidden_features=256,
#                                 hidden_layers=3,
#                                 first_omega_0=3000.0,
#                                 hidden_omega_0=30.0,
#                                 pos_encode_configs=pos_encode_no
#                                ).to(device)
# print(model)
# num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(f'Total number of parameters: {num_params/1e6}(M)')

## Training Code

In [None]:
# Optimizer setup: Adam over every model parameter. The LambdaLR factor decays
# geometrically from 1 at step 0 to `scheduler_b` at step `niters`, then
# stays clamped there.
optim = torch.optim.Adam(model.parameters(), lr=args.lr)
scheduler = lr_scheduler.LambdaLR(
    optim,
    lr_lambda=lambda it: args.scheduler_b ** min(it / args.niters, 1),
)

# Per-step metrics: PSNR as Python floats, MSE kept on the training device.
psnr_values = []
mse_array = torch.zeros(args.niters, device=device)

# Lowest MSE seen so far; starts at +inf.
best_loss = torch.tensor(float('inf'))

In [None]:
# Full-batch training on the whole waveform. Per-step MSE/PSNR are logged,
# and the reconstruction with the lowest MSE is kept in `best_audio`.
# Loop-invariant plot data, converted to NumPy once instead of every summary.
time_axis = coords.squeeze().detach().cpu().numpy()
gt_np = gt.squeeze().detach().cpu().numpy()

for step in tqdm(range(args.niters)):

    # Forward pass; 'incode' models also return harmonizer coefficients.
    if args.inr_model == 'incode':
        model_output, coef = model(coords)
    else:
        model_output = model(coords)

    # Reconstruction (MSE) loss
    output_loss = ((model_output - gt)**2).mean()

    if args.inr_model == 'incode':
        # Penalise negative coefficients, weighted by the --*_coef arguments.
        a_coef, b_coef, c_coef, d_coef = coef[0]
        reg_loss = (args.a_coef * torch.relu(-a_coef)
                    + args.b_coef * torch.relu(-b_coef)
                    + args.c_coef * torch.relu(-c_coef)
                    + args.d_coef * torch.relu(-d_coef))
        loss = output_loss + reg_loss
    else:
        loss = output_loss

    # Backpropagation and parameter update
    optim.zero_grad()
    loss.backward()
    optim.step()
    if args.using_schedular:
        scheduler.step()

    # Reuse the MSE computed above instead of recomputing it and forcing an
    # extra device->host sync with .item(). Note these metrics describe the
    # model BEFORE this step's parameter update.
    with torch.no_grad():
        mse_array[step] = output_loss.detach()
        # PSNR with a peak value of 1: 10 * log10(1 / MSE).
        psnr = -10*torch.log10(mse_array[step])
        psnr_values.append(psnr.item())

    # Display GT, Reconstructed audio, and Error
    if step % args.steps_til_summary == 0:
        print("Epoch: {} | Total Loss: {:.6f} | PSNR: {:.4f}".format(step, loss.item(), psnr.item()))

        fig, axes = plt.subplots(1, 3, figsize=(18, 3))
        axes[0].plot(time_axis, gt_np)
        axes[0].set_ylim(-1, 1)
        axes[0].set_title('Ground Truth')
        axes[1].plot(time_axis, model_output.squeeze().detach().cpu().numpy())
        axes[1].set_ylim(-1, 1)
        axes[1].set_title('Reconstructed')
        axes[2].plot(time_axis, (model_output - gt).squeeze().detach().cpu().numpy())
        axes[2].set_ylim(-0.6, 0.6)
        axes[2].set_title('Error')
        plt.show()

    # Keep the reconstruction with the lowest MSE so far.
    if (mse_array[step] < best_loss) or (step == 0):
        best_loss = mse_array[step]
        best_audio = model_output.squeeze().detach().cpu().numpy()

# Print maximum PSNR achieved during training
print('--------------------')
print('Max PSNR:', max(psnr_values))
print('--------------------')

In [None]:
# Play back the reconstruction that reached the lowest training MSE.
Audio(data=best_audio, rate=rate)

# Convergence Rate

In [None]:
# PSNR-per-iteration convergence curve. Use the standard 'family' key rather
# than the 'font' (fontproperties) alias. Font rcParams are scoped with
# rc_context, so they do not leak into later figures.
font = {'family': 'Times New Roman', 'size': 12}
axfont = {'family' : 'Times New Roman', 'weight' : 'regular', 'size'   : 10}

with plt.rc_context({f'font.{key}': value for key, value in axfont.items()}):
    fig, ax = plt.subplots()
    # NOTE(review): the final PSNR value is excluded from the curve, matching
    # the previous plot -- confirm this is intentional.
    curve = psnr_values[:-1]
    ax.plot(np.arange(len(curve)), curve, label=f"{args.inr_model.upper()}")
    ax.set_xlabel('# Epochs', fontdict=font)
    ax.set_ylabel('PSNR (dB)', fontdict=font)
    ax.set_title('Audio Representation', fontdict={'family': 'Times New Roman', 'size': 12, 'weight': 'bold'})
    ax.legend()
    ax.grid(True, color='lightgray')
    plt.show()