In [5]:
import os
import numpy as np
import sys
import librosa
import warnings
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3" # set vis gpus 
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from torch.optim.lr_scheduler import ReduceLROnPlateau
import wwv.config  as cfg 
# from wwv.meta import params as params 
from wwv.data import AudioDataModule

torch.cuda.is_available()
cfg_fitting = cfg.Fitting()
cfg_feature = cfg.Feature()
cfg_signal = cfg.Signal()
cfg_model = cfg.HTSwin()

In [14]:

data_path = cfg.DataPath("/home/akinwilson/Code/HTS", cfg_model.model_name, cfg_model.model_dir)
data_module = AudioDataModule(data_path.root_data_dir, cfg_model=cfg_model, cfg_feature=cfg_feature, cfg_fitting=cfg_fitting)
                              
train_loader =  data_module.train_dataloader()
val_loader =  data_module.val_dataloader()
test_loader =  data_module.test_dataloader()

x = next(iter(train_loader))['x']

In [12]:
2**9

512

In [60]:
import torch
import torch.nn as nn 
import torch.nn.functional as F


class CNNAE(nn.Module):
    def __init__(self, n_input=1, n_output=1024, stride=16, n_channel=32):
        super().__init__()

        self.n_channel = n_channel


        # encoder layers 
        self.e_conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
        self.e_bn1 = nn.BatchNorm1d(n_channel)
        self.e_pool1 = nn.MaxPool1d(4, return_indices=True)
        self.e_conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.e_bn2 = nn.BatchNorm1d(n_channel)
        self.e_pool2 = nn.MaxPool1d(4, return_indices=True)
        self.e_conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        # 
        self.e_bn3 = nn.BatchNorm1d(2 * n_channel)
        self.e_pool3 = nn.MaxPool1d(4, return_indices=True)
        self.e_conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
#######################################################################################
        self.e_bn4 = nn.BatchNorm1d(2 * n_channel)
        self.e_pool4 = nn.MaxPool1d(2, return_indices=True)
        self.e_fc4 = nn.Linear(2 * n_channel * 28, n_output)
#######################################################################################
        # decoder layers 
        self.d_fc4 = nn.Linear(n_output, 2 * n_channel * 28)
        self.d_pool4 = nn.MaxUnpool1d(2)
        self.d_bn4 = nn.BatchNorm1d(2 * n_channel)
#######################################################################################

        self.d_conv4 = nn.ConvTranspose1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.d_pool3 = nn.MaxUnpool1d(4)
        self.d_bn3 = nn.BatchNorm1d(2 * n_channel)
#######################################################################################
        self.d_conv3 = nn.ConvTranspose1d(2 * n_channel, n_channel, kernel_size=3)
        self.d_pool2 = nn.MaxUnpool1d(4)
        self.d_bn2 = nn.BatchNorm1d(n_channel)


        self.d_conv2 = nn.ConvTranspose1d(n_channel, n_channel, kernel_size=3)
        self.d_pool1 = nn.MaxUnpool1d(4)
        self.d_bn1 = nn.BatchNorm1d(n_channel)

        self.d_conv1 = nn.ConvTranspose1d(n_channel, n_input, kernel_size=80, stride=stride)
    


    def encode_forward(self, x):
        x = self.e_conv1(x)
        x = F.relu(self.e_bn1(x))
        x, idx1 = self.e_pool1(x)
        x = self.e_conv2(x)
        x = F.relu(self.e_bn2(x))
        x, idx2 = self.e_pool2(x)
        x = self.e_conv3(x)
        x = F.relu(self.e_bn3(x))
        x, idx3  = self.e_pool3(x)
        x = self.e_conv4(x)
        x = F.relu(self.e_bn4(x))
        print("before flat for fully connect:")
        print(x.shape)
        x = x.view(x.shape[0], -1)
        x = self.e_fc4(x)
        return idx1, idx2, idx3, x


    def decode_forward(self, idx1, idx2, idx3, x):

        x = self.d_fc4(x)
        x = x.view(x.shape[0], 2 * self.n_channel,  28)
        x = F.relu(self.d_bn4(x))
        x = self.d_conv4(x)
        x = self.d_pool3(x, idx3)

        x = F.relu(self.d_bn3(x))

        x = self.d_conv3(x)
        idx2 = idx2[:,:,:122] # due to padding

        x = self.d_pool2(x, idx2)

        x = F.relu(self.d_bn2(x))
        x = self.d_conv2(x)
        idx1 = idx1[:,:,:490] # due to padding
        x = self.d_pool1(x, idx1)
        x = F.relu(self.d_bn1(x))
        x = self.d_conv1(x)
        return x 



    def forward(self, x):
        idx1, idx2, idx3, encoded_x = self.encode_forward(x)
        print(encoded_x.shape)
        decoded_x = self.decode_forward(idx1, idx2, idx3,encoded_x)
        return decoded_x


CNNAE()(x).shape


before flat for fully connect:
torch.Size([32, 64, 28])
torch.Size([32, 1024])


torch.Size([32, 1, 31424])