In [None]:
# notebook config imports
import warnings
# NOTE(review): this silences *every* warning, including deprecation notices from
# torch / pytorch_lightning; consider narrowing the filter to specific categories.
warnings.filterwarnings('ignore')
from IPython.core.interactiveshell import InteractiveShell
# Show the value of every bare expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"
# `IPython.core.display.display` is deprecated; `IPython.display` is the public module.
from IPython.display import display, HTML
# Widen the notebook container to 90% of the browser window.
display(HTML("<style>.container { width:90% !important; }</style>"))

In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0" # set vis gpus 
import torch
import torch.nn as nn 
import pytorch_lightning as pl
import wwv.config  as cfg 
from wwv.util import CallbackCollection 
from wwv.data import AudioDataModule
# Prefer the GPU when one is visible to torch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Experiment configuration objects (fitting / feature / signal / model).
cfg_fitting = cfg.Fitting(batch_size=64, train_bs=64, val_bs=64)
cfg_feature = cfg.Feature()
cfg_signal = cfg.Signal()
cfg_model = cfg.CNNAE()

# Resolve the data / model locations below $DATA_ROOT and build the data module.
data_path = cfg.DataPath(
    os.environ['DATA_ROOT'],
    cfg_model.model_name,
    cfg_model.model_dir,
)
data_module = AudioDataModule(
    data_path.root_data_dir,
    cfg_model=cfg_model,
    cfg_feature=cfg_feature,
    cfg_fitting=cfg_fitting,
)

# Loaders are created in train -> val -> test order.
train_loader, val_loader, test_loader = (
    data_module.train_dataloader(),
    data_module.val_dataloader(),
    data_module.test_dataloader(),
)

## Undercomplete Autoencoder

In [20]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
# Target device for the shape check below: GPU if available, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

class CNNAE(nn.Module):
    """Undercomplete 1-D convolutional autoencoder for raw waveforms.

    Encoder: four Conv1d -> BatchNorm1d -> ReLU stages. The first three are followed
    by MaxPool1d, whose indices are kept for unpooling. A final linear projection
    maps to ``latent_dim``.
    Decoder: the mirror image, built from ConvTranspose1d and MaxUnpool1d. It uses
    the pooling indices returned by :meth:`encode`.

    NOTE: the flattened bottleneck width ``2 * n_channel * 28`` is hard-coded. With
    the default ``stride=16`` it matches an input of shape ``(batch, n_input, 32000)``,
    the shape checked right after this class. Other input lengths or strides need a
    different constant.

    Args:
        n_input: number of input channels.
        latent_dim: size of the latent code.
        stride: stride of the first convolution (and of the final transposed conv).
        n_channel: base number of convolution channels.
    """

    def __init__(self, n_input=1, latent_dim=1024, stride=16, n_channel=32):
        super().__init__()
        # Kept for backward compatibility only. decode() now follows the input
        # tensor's device, so the model keeps working after `.to(...)` / `.cpu()`.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.n_channel = n_channel
        # encoder layers
        self.e_conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)

        self.e_bn1 = nn.BatchNorm1d(n_channel)
        self.e_pool1 = nn.MaxPool1d(4, return_indices=True)
        self.e_conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.e_bn2 = nn.BatchNorm1d(n_channel)
        self.e_pool2 = nn.MaxPool1d(4, return_indices=True)
        self.e_conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.e_bn3 = nn.BatchNorm1d(2 * n_channel)
        self.e_pool3 = nn.MaxPool1d(4, return_indices=True)
        self.e_conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.e_bn4 = nn.BatchNorm1d(2 * n_channel)
        # NOTE(review): e_pool4 / d_pool4 are defined but never used in encode/decode.
        self.e_pool4 = nn.MaxPool1d(2, return_indices=True)
        self.e_fc4 = nn.Linear(2 * n_channel * 28, latent_dim)
        # decoder layers
        self.d_fc4 = nn.Linear(latent_dim, 2 * n_channel * 28)
        self.d_pool4 = nn.MaxUnpool1d(2)
        self.d_bn4 = nn.BatchNorm1d(2 * n_channel)
        self.d_conv4 = nn.ConvTranspose1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.d_pool3 = nn.MaxUnpool1d(4)
        self.d_bn3 = nn.BatchNorm1d(2 * n_channel)
        self.d_conv3 = nn.ConvTranspose1d(2 * n_channel, n_channel, kernel_size=3)
        self.d_pool2 = nn.MaxUnpool1d(4)
        self.d_bn2 = nn.BatchNorm1d(n_channel)
        self.d_conv2 = nn.ConvTranspose1d(n_channel, n_channel, kernel_size=3)
        self.d_pool1 = nn.MaxUnpool1d(4)
        self.d_bn1 = nn.BatchNorm1d(n_channel)
        self.d_conv1 = nn.ConvTranspose1d(n_channel, n_input, kernel_size=80, stride=stride)

    def encode(self, x):
        """Encode ``x`` into a latent code.

        Returns:
            ``(idx1, idx2, idx3, code)``: the MaxPool indices of stages 1-3, which
            :meth:`decode` needs, and the latent code of shape ``(batch, latent_dim)``.
        """
        x = self.e_conv1(x)
        x = F.relu(self.e_bn1(x))
        x, idx1 = self.e_pool1(x)
        x = self.e_conv2(x)
        x = F.relu(self.e_bn2(x))
        x, idx2 = self.e_pool2(x)
        x = self.e_conv3(x)
        x = F.relu(self.e_bn3(x))
        x, idx3 = self.e_pool3(x)
        x = self.e_conv4(x)
        x = F.relu(self.e_bn4(x))
        x = x.view(x.shape[0], -1)
        x = self.e_fc4(x)
        return idx1, idx2, idx3, x

    @staticmethod
    def _match_length(x, indices):
        """Right-pad ``x`` with zeros (or crop it) along time to match ``indices``.

        The transposed convolutions can produce a sequence that differs from the
        pooled length (e.g. 122 vs 124), and MaxUnpool1d needs the two to agree.
        ``F.pad`` keeps the input's device, dtype and channel count.
        """
        return F.pad(x, (0, indices.shape[2] - x.shape[2]))

    def decode(self, idx1, idx2, idx3, x):
        """Reconstruct the input from a latent code and the pooling indices of :meth:`encode`."""
        bs = x.shape[0]
        x = self.d_fc4(x)
        x = x.view(bs, 2 * self.n_channel, 28)
        x = F.relu(self.d_bn4(x))
        x = self.d_conv4(x)
        x = self.d_pool3(x, idx3)
        x = F.relu(self.d_bn3(x))
        x = self.d_conv3(x)
        x = self._match_length(x, idx2)
        x = self.d_pool2(x, idx2)
        x = F.relu(self.d_bn2(x))
        x = self.d_conv2(x)
        x = self._match_length(x, idx1)
        x = self.d_pool1(x, idx1)
        x = F.relu(self.d_bn1(x))
        x = self.d_conv1(x)
        return x

    def forward(self, x):
        """Encode then decode ``x``; returns the reconstruction."""
        idx1, idx2, idx3, encoded_x = self.encode(x)
        decoded_x = self.decode(idx1, idx2, idx3, encoded_x)
        return decoded_x

# Smoke test: a random 32000-sample waveform must round-trip to the same shape.
x = torch.randn((1, 1, 32000), device=device)

model = CNNAE()
model.to(device=device)
# Eval + no_grad: the check must not touch BatchNorm running statistics or build
# an autograd graph before the real fit. Train mode is restored afterwards.
model.eval()
with torch.no_grad():
    x_reconstructed = model(x)
model.train()
assert x_reconstructed.shape == x.shape, (
    "The reconstructed input is of different dimensions to the original input. "
    f"Original: {x.shape}. Reconstructed: {x_reconstructed.shape}"
)

## Fitting autoencoder and visualizing embedding

In [None]:
import torch.nn.functional as F 
from pytorch_lightning import Trainer


class Routine(pl.LightningModule):
    """
    Lightning routine that fits an encoder/decoder autoencoder with an MSE
    reconstruction loss.

    The wrapped ``model`` must provide ``encode(x) -> (idx1, idx2, idx3, code)`` and
    ``decode(idx1, idx2, idx3, code)``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.lr = 1e-4

    def encode(self, x):
        """Return ``(code, idx1, idx2, idx3)``; note the code comes first here."""
        first, second, third, code = self.model.encode(x)
        return code, first, second, third

    def decode(self, idx1, idx2, idx3, x):
        """Reconstruct the input from the latent code ``x`` and the pooling indices."""
        return self.model.decode(idx1, idx2, idx3, x)

    def forward(self, x):
        """Full autoencoding pass: returns the reconstruction of ``x``."""
        code, *pool_indices = self.encode(x)
        return self.decode(*pool_indices, code)

    def _reconstruction_loss(self, batch):
        """MSE between the batch input and its reconstruction."""
        inputs = batch['x']
        return F.mse_loss(inputs, self.forward(inputs))

    def _log_epoch_mean(self, outputs, key, prefix):
        """Average ``outputs[i][key]`` over the epoch and log it as ``<prefix>_loss``."""
        per_step = [out[key].float().mean().item() for out in outputs]
        self.log(f"{prefix}_loss", torch.tensor(per_step).mean(), on_epoch=True, prog_bar=True, logger=True)

    def training_step(self, batch, batch_idx):
        return {"loss": self._reconstruction_loss(batch)}

    def training_epoch_end(self, training_step_outputs):
        self._log_epoch_mean(training_step_outputs, "loss", "train")

    def validation_step(self, batch, batch_idx):
        return {"val_loss": self._reconstruction_loss(batch)}

    def validation_epoch_end(self, training_step_outputs):
        # NOTE: despite its name, this argument holds the validation step outputs.
        self._log_epoch_mean(training_step_outputs, "val_loss", "val")

    def configure_optimizers(self):
        """AdamW over trainable parameters; no LR scheduler is attached."""
        trainable = filter(lambda p: p.requires_grad, self.parameters())
        optimizer = torch.optim.AdamW(
            trainable,
            lr=self.lr,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=0.05,
        )
        return {"optimizer": optimizer}


routine = Routine(model)
trainer = Trainer(
    accelerator="gpu",
    sync_batchnorm=True,
    max_epochs=5,
    num_sanity_val_steps=2,
    gradient_clip_val=1.0,
)
# Fit the autoencoder: train on `train_loader`, validate on `val_loader`.
trainer.fit(routine, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Handle on the fitted routine; AE_classifier.forward uses it as its encoder.
encoder = trainer.model

from torch.utils.tensorboard import SummaryWriter


class TFVisualiser:
    '''
    Writes the latent codes the autoencoder produces for a random subsample of the
    test set to TensorBoard, for inspection in the embedding projector.

    NOTE THE BELOW MODEL EXTRACTION EXPECTS THE ROUTINE BE HAVE BEEN FITTED WITH A SINGLE GPU
    ----> you'll have to extract the model slightly differently in a distributed environment

    NOTE(review): the default arguments are evaluated once, when this class is
    defined. They therefore capture the `trainer` and `test_loader` that exist at
    that moment.
    '''

    def __init__(self, model=trainer.model.model, test_loader=test_loader, sample_size=2000):
        self.model = model
        self.sample_size = sample_size
        # Collect the whole test set as yielded by the loader, then subsample it.
        xs, ys = [], []
        for batch in test_loader:
            xs.append(batch['x'])
            ys.append(batch['y'])
        X = torch.vstack(xs)
        Y = torch.concat(ys)
        X_sampled, Y_sampled = self.sample(X, Y)

        self.X = X_sampled
        self.Y = Y_sampled
        self.latent_code_output_dir = 'runs/cnnae'

    def sample(self, X, Y):
        '''Return at most ``sample_size`` rows of ``(X, Y)``, drawn without replacement.'''
        permutation_idxs = torch.randperm(X.size(0))
        idxs = permutation_idxs[:self.sample_size]
        return X[idxs], Y[idxs]

    def _encode(self, batch_size):
        '''Encode ``self.X`` in chunks of ``batch_size`` on the model's device.

        Runs in eval mode (BatchNorm uses running statistics, so chunking does not
        change the codes) and without autograd, then restores the model's previous
        train/eval mode. Returns the latent codes on the CPU.
        '''
        model_device = next(self.model.parameters()).device
        was_training = self.model.training
        self.model.eval()
        codes = []
        try:
            with torch.no_grad():
                for chunk in torch.split(self.X, batch_size):
                    _, _, _, code = self.model.encode(chunk.to(model_device))
                    codes.append(code.cpu())
        finally:
            self.model.train(was_training)
        return torch.cat(codes)

    def save_latent_code(self, batch_size=256):
        '''Encode the sampled test inputs and write them, labelled, to TensorBoard.

        Args:
            batch_size: number of samples encoded per forward pass.
        '''
        latent_code = self._encode(batch_size)
        label_names = {1.: "Wake word", 0.: "Not wake word"}
        label_list = [label_names[y] for y in self.Y.cpu().numpy().tolist()]
        # default `log_dir` is "runs" - we'll be more specific here
        writer = SummaryWriter(self.latent_code_output_dir)
        try:
            # log for visualisation
            writer.add_embedding(latent_code, metadata=label_list, tag="AudioEmbedding")
        finally:
            writer.close()

    def __call__(self):
        from pathlib import Path

        print(f"Saving subsample of {self.sample_size} of test set's latent encodings to location: {Path().cwd() / self.latent_code_output_dir}")
        self.save_latent_code()

# Project a random subsample of test-set latent codes into TensorBoard.
latent_visualiser = TFVisualiser()
latent_visualiser()

# Extracting encoder for down stream task (the fitted autoencoder routine).
encoder = trainer.model


## Downstream autoencoder application
### Vanilla MLP classifier head

In [21]:
from collections import OrderedDict

class AE_classifier(nn.Module):
    '''
    Dense classification head on top of a pre-trained autoencoder's latent code.

    The encoder acts as a fixed feature extractor. It is *not* registered as a
    sub-module, so its weights are excluded from ``parameters()`` / ``state_dict()``
    and are never optimised. It is also run under ``torch.no_grad()``.

    Applications:
        Upstream feature extraction for memory-constrained classifiers.
    '''

    def __init__(self, latent_dim, dropout=0.2, compression_factor=3, encoder=None):
        '''
        Args:
            latent_dim: size of the latent code produced by the encoder.
            dropout: dropout probability after each hidden dense layer.
            compression_factor: factor by which each hidden layer shrinks the width
                (``latent_dim -> latent_dim / f -> latent_dim / f**2 -> 1``).
            encoder: object whose ``encode(x)`` returns ``(code, ...)``, e.g. the
                fitted autoencoder ``Routine``. When ``None``, the notebook-level
                ``encoder`` global is looked up at call time.
        '''
        super().__init__()
        self.latent_dim = latent_dim
        self.do_rate = dropout
        # Bypass nn.Module.__setattr__ so the encoder is not registered as a child
        # module; otherwise its weights would be trained and checkpointed with the head.
        self.__dict__['_encoder'] = encoder

        dense_layer_1_output = int(latent_dim / compression_factor)
        dense_layer_2_output = int(dense_layer_1_output / compression_factor)

        self.layers = torch.nn.Sequential(
            OrderedDict([
                ("DenseLayer1", nn.Linear(latent_dim, dense_layer_1_output)),
                ("relu1", nn.ReLU(inplace=True)),
                ("dropout1", nn.Dropout(self.do_rate)),
                ("DenseLayer2", nn.Linear(dense_layer_1_output, dense_layer_2_output)),
                ("relu2", nn.ReLU(inplace=True)),
                ("dropout2", nn.Dropout(self.do_rate)),
                ("DenseLayer3", nn.Linear(dense_layer_2_output, 1)),
            ])
        )

    def _get_encoder(self):
        '''Return the injected encoder, falling back to the notebook-level `encoder`.'''
        return self._encoder if self._encoder is not None else encoder

    def forward(self, x):
        '''Encode ``x`` with the frozen encoder and return logits of shape ``(batch, 1)``.'''
        feature_extractor = self._get_encoder()
        # Follow the input's device instead of hard-coding "cuda".
        feature_extractor.to(x.device)
        # NOTE(review): BatchNorm layers inside the encoder still follow whatever
        # train/eval mode the encoder is in; call `.eval()` on it for a true freeze.
        with torch.no_grad():
            x_encoded, _, _, _ = feature_extractor.encode(x)
        logits = self.layers(x_encoded)
        return logits

## The autoencoder was fit on the training and validation datasets
### The learnt representation / manifold is evaluated by freezing the autoencoder layers and training a classification head on the latent code

In [None]:
from sklearn.model_selection import train_test_split 
from pathlib import Path 
import pandas as pd 


def create_temporary_fitting_set():
    """Re-split ``$DATA_ROOT/test.csv`` into train/val/test CSVs under ``$DATA_ROOT/ae_data``.

    The autoencoder was fit on the original train/val sets, so its held-out test set
    is split again here to fit and evaluate the classification head.
    ``train_test_split`` uses its default proportions: 25% goes to validation, and
    25% of the remainder goes to test.
    """
    out_dir = Path(os.environ['DATA_ROOT']) / "ae_data"
    out_dir.mkdir(exist_ok=True, parents=True)
    source_df = pd.read_csv(out_dir.parent / "test.csv")
    remainder, val_set = train_test_split(source_df)
    train_set, test_set = train_test_split(remainder)
    for split_name, split in (("Training", train_set), ("Validation", val_set), ("Testing", test_set)):
        print(f"{split_name} set contains: {split.shape[0]}")
    print(f"Saving to directory: {out_dir}")
    splits = {"train.csv": train_set, "val.csv": val_set, "test.csv": test_set}
    for fname, split in splits.items():
        split.to_csv(out_dir / fname, index=False)


create_temporary_fitting_set()

In [22]:
from wwv.data import AudioDataModule
from wwv.util import CallbackCollection
import wwv.config as cfg
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from wwv.eval import Metric
from pathlib import Path 
from torch.optim.lr_scheduler import ReduceLROnPlateau


cfg_model = cfg.AEClassifier()
cfg_fitting = cfg.Fitting(max_epoch=50, es_patience=10)
cfg_signal = cfg.Signal()
cfg_feature = cfg.Feature()

# Classifier data: the re-split test set written by `create_temporary_fitting_set`.
data_out_path = str(Path(os.environ['DATA_ROOT']) / "ae_data")

data_path = cfg.DataPath(data_out_path, cfg_model.model_name, cfg_model.model_dir)
data_module = AudioDataModule(data_path.root_data_dir, cfg_model=cfg_model, cfg_feature=cfg_feature, cfg_fitting=cfg_fitting)


logger = TensorBoardLogger(save_dir=data_path.model_dir, version=1, name="lightning_logs")

train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()
test_loader = data_module.test_dataloader()

# get input shape for onnx exporting
input_shape = data_module.input_shape
# Classification head over the 1024-d latent code (CNNAE's default latent_dim).
model = AE_classifier(latent_dim=1024)
# Use the detected device instead of hard-coding "cuda" (which fails on CPU-only hosts).
model.to(device)


class Routine(pl.LightningModule):
    """Lightning routine that fits the binary classification head.

    Loss: ``binary_cross_entropy_with_logits`` on the model's single output logit.
    Metrics: ``ttr`` / ``ftr`` / ``acc`` from ``Metric``, applied to predictions
    thresholded at probability 0.5 and averaged over steps at epoch end.
    """

    def __init__(self, model, cfg_fitting, cfg_model, localization=False):
        super().__init__()
        self.model = model
        self.metric = Metric
        self.cfg_fitting = cfg_fitting
        self.cfg_model = cfg_model
        self.localization = localization  # NOTE(review): not used by this routine
        self.lr = 1e-3

    def forward(self, x):
        y_hat = self.model(x)
        return y_hat

    def _predict(self, x):
        """Return ``(logits, hard_predictions)``, each with one entry per sample."""
        # squeeze(-1) rather than squeeze(): a batch of size 1 must stay shape (1,)
        # so it still matches `y` in the loss and in the metric.
        logits = self(x).squeeze(-1)
        return logits, (torch.sigmoid(logits) > 0.5).float()

    def _metrics(self, y_hat, y, prefix):
        """Evaluate ``Metric`` on hard predictions; keys are ``<prefix>_ttr/ftr/acc``."""
        metrics = self.metric(y_hat, y)()
        return {f"{prefix}_ttr": metrics.ttr, f"{prefix}_ftr": metrics.ftr, f"{prefix}_acc": metrics.acc}

    def _log_epoch_means(self, outputs, prefix, loss_key=None):
        """Average per-step values over the epoch and log them as ``<prefix>_<name>``."""
        results = {}
        if loss_key is not None:
            results["loss"] = torch.tensor([out[loss_key].item() for out in outputs]).mean()
        for name in ("ttr", "ftr", "acc"):
            results[name] = torch.tensor([out[f"{prefix}_{name}"] for out in outputs]).mean()
        for (k, v) in results.items():
            self.log(f"{prefix}_{k}", v, on_epoch=True, prog_bar=True, logger=True)

    def training_step(self, batch, batch_idx):
        y = batch['y']
        logits, y_hat = self._predict(batch['x'])
        loss = F.binary_cross_entropy_with_logits(logits, y)
        return {"loss": loss, **self._metrics(y_hat, y, "train")}

    def training_epoch_end(self, training_step_outputs):
        self._log_epoch_means(training_step_outputs, "train", loss_key="loss")

    def validation_step(self, batch, batch_idx):
        y = batch['y']
        logits, y_hat = self._predict(batch['x'])
        loss = F.binary_cross_entropy_with_logits(logits, y)
        return {"val_loss": loss, **self._metrics(y_hat, y, "val")}

    def validation_epoch_end(self, validation_step_outputs):
        self._log_epoch_means(validation_step_outputs, "val", loss_key="val_loss")

    def test_step(self, batch, batch_idx):
        _, y_hat = self._predict(batch['x'])
        return self._metrics(y_hat, batch['y'], "test")

    def test_epoch_end(self, test_step_outputs):
        self._log_epoch_means(test_step_outputs, "test")

    def configure_optimizers(self):
        """AdamW over trainable parameters with ReduceLROnPlateau on ``val_loss``."""
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr=self.lr,
            betas=(0.9, 0.999), eps=1e-08, weight_decay=0.05,
        )
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=False)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}


# Fit the classification head on the re-split data, then evaluate on its test split.
routine = Routine(model, cfg_fitting, cfg_model)
trainer = Trainer(
    accelerator="gpu",
    devices=1,
    sync_batchnorm=True,
    logger=logger,
    max_epochs=cfg_fitting.max_epoch,
    callbacks=CallbackCollection(cfg_fitting, data_path)(),
    gradient_clip_val=1.0,
    fast_dev_run=cfg_fitting.fast_dev_run,
)
trainer.fit(routine, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(dataloaders=test_loader)


FileNotFoundError: [Errno 2] No such file or directory: '/media/akinwilson/Samsung_T5/data/audio/keyword-spotting/ae_data'

### Research: Semi-supervised loss : combining the reconstruction and binary cross entropy loss

- Combine all available data into a single dataset of the form:

    $D = \{(x, y)\}$
    
    where $y \in \{0, 1, 2\}$.
- Define the loss as a **piecewise function** of the target label:

    
$$ \mathcal{L}(x, x_{\text{recon}}, y, \hat{y}) = \begin{cases}
\mathrm{MSE}(x, x_{\text{recon}}), & y = 2 \\
\mathrm{MSE}(x, x_{\text{recon}}) + \mathrm{BinaryCrossEntropy}(y, \hat{y}), & y \in \{0, 1\}
\end{cases} $$

The target label 2 therefore denotes an unknown (unlabelled) target.

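**TODO:** verify that the encoder, decoder and classifier head are all updated when $y \in \{0,1\}$. In `SS_CNNAE` the classifier consumes the encoder's output. The BCE term should therefore propagate gradients to the encoder and classifier, and the MSE term to the encoder and decoder. Confirm by checking that each sub-module's parameters have non-zero `.grad` after a backward pass on a labelled batch.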


In [23]:
import pandas as pd 
import numpy as np 
from sklearn.model_selection import train_test_split
from pathlib import Path 
from collections import OrderedDict
from pytorch_lightning.loggers import TensorBoardLogger
from wwv.data import AudioDataModule
from wwv.util import CallbackCollection
import wwv.config as cfg
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
import bisect
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from wwv.eval import Metric
from pathlib import Path 
from torch.optim.lr_scheduler import ReduceLROnPlateau

def semi_supervised_conversion(df, unknown_label_fraction=0.9, unknown_target_label=2., seed=None):
    '''
    Hide the labels of a random fraction of rows by overwriting them with
    ``unknown_target_label``.

    Args:
        df: frame with a ``label`` column. It is not modified; a copy is returned.
        unknown_label_fraction: fraction of rows (rounded down) whose label is hidden.
        unknown_target_label: value written into ``label`` for the hidden rows.
        seed: optional seed for a reproducible choice of rows. When ``None``, the
            global NumPy RNG (``np.random``) is used.

    Returns:
        A copy of ``df`` in which ``int(len(df) * unknown_label_fraction)`` randomly
        chosen rows have ``label == unknown_target_label``.
    '''
    sample_size = df.shape[0]
    unknown_label_sample_size = int(sample_size * unknown_label_fraction)

    print(f"Corrupting {unknown_label_sample_size} samples out of {sample_size}")
    rng = np.random if seed is None else np.random.RandomState(seed)
    positions = rng.choice(np.arange(0, sample_size, 1), size=unknown_label_sample_size, replace=False)
    corrupted = df.copy()
    # Positional assignment: correct for any index (not only a fresh RangeIndex).
    corrupted.iloc[positions, corrupted.columns.get_loc('label')] = unknown_target_label
    return corrupted

def create_temporary_semi_superivsed_fitting_set(unknown_label_fraction=0.9,  unknown_target_label=2.):
    """Pool the train/val/test CSVs, hide most labels, and re-split into ``$DATA_ROOT/ss_data``.

    Args:
        unknown_label_fraction: fraction of pooled rows whose label is hidden.
        unknown_target_label: forwarded to ``semi_supervised_conversion``.
    """
    out_dir = Path(os.environ['DATA_ROOT']) / "ss_data"
    out_dir.mkdir(exist_ok=True, parents=True)
    split_names = ('train.csv', 'val.csv', 'test.csv')
    sources = [str(out_dir.parent / name) for name in split_names]
    # A fresh RangeIndex after concatenation (row positions == index labels).
    pooled = pd.concat([pd.read_csv(src) for src in sources]).reset_index(drop=True)
    pooled = semi_supervised_conversion(pooled, unknown_label_fraction,  unknown_target_label)
    print("New target distribution")
    print(pooled.label.value_counts())
    remainder, val_set = train_test_split(pooled)
    train_set, test_set = train_test_split(remainder)
    print(f"Training set contains: {train_set.shape[0]}" )
    print(f"Validation set contains: {val_set.shape[0]}" )
    print(f"Testing set contains: {test_set.shape[0]}" )
    print(f"Saving to directory: {out_dir}")
    for fname, split in zip(split_names, (train_set, val_set, test_set)):
        split.to_csv(out_dir / fname, index=False)


create_temporary_semi_superivsed_fitting_set()

data_out_path = Path(os.environ['DATA_ROOT']) / "ss_data"

Corrupting 85368 samples out of 94854
New target distribution
2.0    85368
0.0     6129
1.0     3357
Name: label, dtype: int64
Training set contains: 53355
Validation set contains: 23714
Testing set contains: 17785
Saving to directory: /media/akinwilson/Samsung_T5/data/audio/keyword-spotting/ss_data


In [24]:

class DenseClassifier(nn.Module):
    """Three-layer MLP that maps a latent code to a single logit.

    Hidden widths shrink by ``compression_factor`` at each step:
    ``latent_dim -> int(latent_dim / f) -> int(int(latent_dim / f) / f) -> 1``.
    Each hidden layer is followed by ReLU and dropout.
    """

    def __init__(self, latent_dim, dropout=0.2, compression_factor=3):
        super().__init__()
        self.latent_dim = latent_dim
        self.do_rate = dropout

        hidden_1 = int(latent_dim / compression_factor)
        hidden_2 = int(hidden_1 / compression_factor)

        # (linear name, activation name, dropout name, fan_in, fan_out) per hidden stage.
        # Module names are the state_dict keys, so they must not change.
        hidden_stages = (
            ("DenseLayer1", "relu1", "dropout1", latent_dim, hidden_1),
            ("DenseLayer2", "relu2", "dropout2", hidden_1, hidden_2),
        )
        modules = OrderedDict()
        for linear_name, act_name, drop_name, fan_in, fan_out in hidden_stages:
            modules[linear_name] = nn.Linear(fan_in, fan_out)
            modules[act_name] = nn.ReLU(inplace=True)
            modules[drop_name] = nn.Dropout(self.do_rate)
        modules["DenseLayer3"] = nn.Linear(hidden_2, 1)
        self.layers = torch.nn.Sequential(modules)

    def forward(self, x):
        """Return logits of shape ``(batch, 1)``."""
        return self.layers(x)


class SS_CNNAE(nn.Module):
    """Semi-supervised 1-D convolutional autoencoder with a classification head.

    Same encoder/decoder as ``CNNAE``, plus a ``DenseClassifier`` applied to the
    latent code. :meth:`forward` returns both the reconstruction (for the MSE term)
    and the classifier logits (for the BCE term).

    NOTE: the bottleneck width ``2 * n_channel * 28`` is hard-coded. With the
    default ``stride=16`` it matches inputs of shape ``(batch, n_input, 32000)``.
    """

    def __init__(self, n_input=1, latent_dim=1024, stride=16, n_channel=32):
        super().__init__()
        # Kept for backward compatibility only; decode() follows the input's device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.n_channel = n_channel
        # encoder layers
        self.e_conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
        self.e_bn1 = nn.BatchNorm1d(n_channel)
        self.e_pool1 = nn.MaxPool1d(4, return_indices=True)
        self.e_conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.e_bn2 = nn.BatchNorm1d(n_channel)
        self.e_pool2 = nn.MaxPool1d(4, return_indices=True)
        self.e_conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.e_bn3 = nn.BatchNorm1d(2 * n_channel)
        self.e_pool3 = nn.MaxPool1d(4, return_indices=True)
        self.e_conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.e_bn4 = nn.BatchNorm1d(2 * n_channel)
        # NOTE(review): e_pool4 / d_pool4 are defined but never used in encode/decode.
        self.e_pool4 = nn.MaxPool1d(2, return_indices=True)
        self.e_fc4 = nn.Linear(2 * n_channel * 28, latent_dim)
        # decoder layers
        self.d_fc4 = nn.Linear(latent_dim, 2 * n_channel * 28)
        self.d_pool4 = nn.MaxUnpool1d(2)
        self.d_bn4 = nn.BatchNorm1d(2 * n_channel)
        self.d_conv4 = nn.ConvTranspose1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.d_pool3 = nn.MaxUnpool1d(4)
        self.d_bn3 = nn.BatchNorm1d(2 * n_channel)
        self.d_conv3 = nn.ConvTranspose1d(2 * n_channel, n_channel, kernel_size=3)
        self.d_pool2 = nn.MaxUnpool1d(4)
        self.d_bn2 = nn.BatchNorm1d(n_channel)
        self.d_conv2 = nn.ConvTranspose1d(n_channel, n_channel, kernel_size=3)
        self.d_pool1 = nn.MaxUnpool1d(4)
        self.d_bn1 = nn.BatchNorm1d(n_channel)
        self.d_conv1 = nn.ConvTranspose1d(n_channel, n_input, kernel_size=80, stride=stride)

        self.classifier = DenseClassifier(latent_dim)

    def encode(self, x):
        """Return ``(idx1, idx2, idx3, code)``: pooling indices of stages 1-3 and the latent code."""
        x = self.e_conv1(x)
        x = F.relu(self.e_bn1(x))
        x, idx1 = self.e_pool1(x)
        x = self.e_conv2(x)
        x = F.relu(self.e_bn2(x))
        x, idx2 = self.e_pool2(x)
        x = self.e_conv3(x)
        x = F.relu(self.e_bn3(x))
        x, idx3 = self.e_pool3(x)
        x = self.e_conv4(x)
        x = F.relu(self.e_bn4(x))
        x = x.view(x.shape[0], -1)
        x = self.e_fc4(x)
        return idx1, idx2, idx3, x

    def classify(self, x):
        """Encode ``x`` and return the classifier logits."""
        _, _, _, x_encoded = self.encode(x)
        logits = self.classifier(x_encoded)
        return logits

    @staticmethod
    def _match_length(x, indices):
        """Right-pad ``x`` with zeros (or crop it) along time to match ``indices`` for MaxUnpool1d.

        ``F.pad`` keeps the input's device, dtype and channel count.
        """
        return F.pad(x, (0, indices.shape[2] - x.shape[2]))

    def decode(self, idx1, idx2, idx3, x):
        """Reconstruct the input from a latent code and the pooling indices of :meth:`encode`."""
        bs = x.shape[0]
        x = self.d_fc4(x)
        x = x.view(bs, 2 * self.n_channel, 28)
        x = F.relu(self.d_bn4(x))
        x = self.d_conv4(x)
        x = self.d_pool3(x, idx3)
        x = F.relu(self.d_bn3(x))
        x = self.d_conv3(x)
        x = self._match_length(x, idx2)
        x = self.d_pool2(x, idx2)
        x = F.relu(self.d_bn2(x))
        x = self.d_conv2(x)
        x = self._match_length(x, idx1)
        x = self.d_pool1(x, idx1)
        x = F.relu(self.d_bn1(x))
        x = self.d_conv1(x)
        return x

    def forward(self, x):
        '''Autoencoding + classification forward pass.

        Returns:
            ``(decoded_x, logits)``: the reconstruction (MSE input) and the
            classifier logits (binary cross-entropy input).
        '''
        idx1, idx2, idx3, encoded_x = self.encode(x)
        # Reuse the code computed above. Calling `classify(x)` would run the encoder
        # a second time and update the BatchNorm running statistics twice per step.
        logits = self.classifier(encoded_x)
        decoded_x = self.decode(idx1, idx2, idx3, encoded_x)
        return decoded_x, logits




def semi_supervised_loss(x=None, x_recon=None, y=None, y_hat=None):
    '''
    Semi-supervised loss: per-sample loss averaged over the batch.

    Every sample contributes its reconstruction MSE (mean over all of its
    elements). Samples whose label is not the "unknown" value 2 additionally
    contribute the binary cross-entropy between the logit ``y_hat`` and ``y``.

    Args:
        x: input batch (first dim = batch).
        x_recon: reconstruction, same shape as ``x``.
        y: labels with ``batch`` elements, values in {0, 1, 2}; 2 means unknown.
        y_hat: logits with ``batch`` elements: ``(batch,)``, ``(batch, 1)``, or a
            0-d tensor when the batch has a single sample (what ``.squeeze()`` yields).

    Returns:
        0-d tensor: the batch mean of the per-sample losses.
    '''
    bs = x.shape[0]
    # Per-sample MSE over all non-batch dims (== F.mse_loss(x[i], x_recon[i])).
    recon = (x - x_recon).pow(2).reshape(bs, -1).mean(dim=1)
    logits = y_hat.reshape(bs)
    labels = y.reshape(bs).to(logits.dtype)
    labelled = labels != 2.
    # Unknown rows get a dummy target of 0; their BCE is then zeroed by the mask.
    safe_labels = torch.where(labelled, labels, torch.zeros_like(labels))
    bce = F.binary_cross_entropy_with_logits(logits, safe_labels, reduction='none')
    return (recon + bce * labelled.to(bce.dtype)).mean()



class Routine(pl.LightningModule):
    """Lightning routine for the semi-supervised autoencoder (``SS_CNNAE``).

    * Loss: ``semi_supervised_loss``, i.e. reconstruction MSE for every sample plus
      BCE on the classifier logit for labelled samples (``y != 2``).
    * Metrics (``ttr`` / ``ftr`` / ``acc`` from ``Metric``) use labelled samples
      only, because rows carrying the unknown label have no classifier ground truth.
      A batch with no labelled sample yields NaN metrics, which are skipped when
      averaging over the epoch.
    """

    # Label value that marks a hidden target (see semi_supervised_conversion).
    UNKNOWN_LABEL = 2.

    def __init__(self, model, cfg_fitting, cfg_model):
        super().__init__()
        self.model = model
        self.metric = Metric
        self.cfg_fitting = cfg_fitting
        self.cfg_model = cfg_model
        self.lr = 1e-3

    def forward(self, x, y):
        # `y` is not used by the model; the parameter is kept for call compatibility.
        x_recon, logits = self.model(x)
        return x_recon, logits

    def _shared_step(self, batch):
        """Forward pass; returns ``(x, y, x_recon, logits)`` with logits of shape ``(batch,)``."""
        x = batch['x']
        y = batch['y']
        x_recon, logits = self(x, y)
        # squeeze(-1) rather than squeeze(): keeps the batch axis for batches of size 1.
        return x, y, x_recon, logits.squeeze(-1)

    def _labelled_metrics(self, y_hat, y, prefix):
        """``Metric`` on labelled rows only; NaN values if the batch has none."""
        labelled = y != self.UNKNOWN_LABEL
        if not bool(labelled.any()):
            nan = float("nan")
            return {f"{prefix}_ttr": nan, f"{prefix}_ftr": nan, f"{prefix}_acc": nan}
        preds = (torch.sigmoid(y_hat[labelled]) > 0.5).float()
        metrics = self.metric(preds, y[labelled])()
        return {f"{prefix}_ttr": metrics.ttr, f"{prefix}_ftr": metrics.ftr, f"{prefix}_acc": metrics.acc}

    @staticmethod
    def _nanmean(values):
        """Mean of ``values`` ignoring NaNs (NaN if every value is NaN)."""
        stacked = torch.tensor([float(v) for v in values], dtype=torch.float32)
        return stacked[~torch.isnan(stacked)].mean()

    def _log_epoch_means(self, outputs, prefix, loss_key=None):
        """Average per-step values over the epoch and log them as ``<prefix>_<name>``."""
        results = {}
        if loss_key is not None:
            results["loss"] = torch.tensor([out[loss_key].mean().item() for out in outputs]).mean()
        for name in ("ttr", "ftr", "acc"):
            results[name] = self._nanmean(out[f"{prefix}_{name}"] for out in outputs)
        for (k, v) in results.items():
            self.log(f"{prefix}_{k}", v, on_epoch=True, prog_bar=True, logger=True)

    def training_step(self, batch, batch_idx):
        x, y, x_recon, y_hat = self._shared_step(batch)
        loss = semi_supervised_loss(x, x_recon, y, y_hat)
        return {"loss": loss, **self._labelled_metrics(y_hat, y, "train")}

    def training_epoch_end(self, training_step_outputs):
        self._log_epoch_means(training_step_outputs, "train", loss_key="loss")

    def validation_step(self, batch, batch_idx):
        x, y, x_recon, y_hat = self._shared_step(batch)
        loss = semi_supervised_loss(x, x_recon, y, y_hat)
        return {"val_loss": loss, **self._labelled_metrics(y_hat, y, "val")}

    def validation_epoch_end(self, validation_step_outputs):
        self._log_epoch_means(validation_step_outputs, "val", loss_key="val_loss")

    def test_step(self, batch, batch_idx):
        _, y, _, y_hat = self._shared_step(batch)
        return self._labelled_metrics(y_hat, y, "test")

    def test_epoch_end(self, test_step_outputs):
        self._log_epoch_means(test_step_outputs, "test")

    def configure_optimizers(self):
        """AdamW over trainable parameters with ReduceLROnPlateau on ``val_loss``."""
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr=self.lr,
            betas=(0.9, 0.999), eps=1e-08, weight_decay=0.05,
        )
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=False)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
        


model = SS_CNNAE()
cfg_model = cfg.SSCNNAE()


cfg_fitting = cfg.Fitting(max_epoch=50, es_patience=10)
cfg_signal = cfg.Signal()
cfg_feature = cfg.Feature()

data_out_path = str(Path(os.environ['DATA_ROOT']) / "ss_data")

data_path = cfg.DataPath(data_out_path, cfg_model.model_name, cfg_model.model_dir)
data_module = AudioDataModule(data_path.root_data_dir, cfg_model=cfg_model, cfg_feature=cfg_feature, cfg_fitting=cfg_fitting)

# Build the loaders from *this* data module. Without these lines, the fit/test below
# reuse `train_loader` / `val_loader` / `test_loader` left over from an earlier cell,
# which point at a different dataset.
# NOTE(review): the recorded `KeyError: 'y'` may stem from that stale state; confirm.
train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()
test_loader = data_module.test_dataloader()

logger = TensorBoardLogger(save_dir=data_path.model_dir, version=1, name="lightning_logs")


routine = Routine(model, cfg_fitting, cfg_model)
# Init a trainer to execute routine

trainer = Trainer(
    accelerator="gpu",
    devices=1,
    sync_batchnorm=True,
    logger=logger,
    max_epochs=cfg_fitting.max_epoch,
    callbacks=CallbackCollection(cfg_fitting, data_path)(),
    gradient_clip_val=1.0,
    fast_dev_run=cfg_fitting.fast_dev_run,
)

trainer.fit(
    routine, train_dataloaders=train_loader, val_dataloaders=val_loader
)
trainer.test(dataloaders=test_loader)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type     | Params
-----------------------------------
0 | model | SS_CNNAE | 4.1 M 
-----------------------------------
4.1 M     Trainable params
0         Non-trainable params
4.1 M     Total params
16.441    Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

KeyError: 'y'

### Generative DL: conditional variational autoencoder 

In [25]:


def create_synthesis_fitting_set(random_state=None):
    """Build train/val/test CSVs for conditional synthesis of positive clips.

    Reads ``train.csv``, ``val.csv`` and ``test.csv`` directly under
    ``$DATA_ROOT``, keeps only rows with ``label == 1.0``, one-hot encodes
    ``annotated_age`` and ``annotated_voice_type``, re-splits the result and
    writes the three new splits to ``$DATA_ROOT/synthesis_data``.

    Args:
        random_state: optional seed forwarded to both ``train_test_split``
            calls for a reproducible split. ``None`` (default) calls
            ``train_test_split`` exactly as before, i.e. unseeded.

    Returns:
        The test split as a DataFrame (its index is NOT reset).
    """
    data_out_path = Path(os.environ['DATA_ROOT']) / "synthesis_data"
    data_out_path.mkdir(exist_ok=True, parents=True)
    FILES = ['train.csv', 'val.csv', 'test.csv']
    # The source splits live one level up, directly under $DATA_ROOT.
    fitting_set_paths = [str(data_out_path.parent / fname) for fname in FILES]
    df = pd.concat([pd.read_csv(file_path) for file_path in fitting_set_paths])
    df = df.reset_index(drop=True)

    df = df[df.label == 1.0]
    print(df.label.value_counts())
    df_fit_data = df[['label', 'wav_path']]
    conditonal_generation_cols = ['annotated_age', 'annotated_voice_type']
    df_conditions = df[conditonal_generation_cols]
    # Inside a function a bare value_counts() is never displayed, so print it.
    print("Age distribution")
    print(df_conditions.annotated_age.value_counts())
    print("Gender distribution")
    print(df_conditions.annotated_voice_type.value_counts())

    # dtype=int keeps the one-hot columns as 0/1 regardless of pandas version
    # (pandas >= 2.0 defaults get_dummies to bool, which would be written as True/False).
    df_1h = pd.get_dummies(df_conditions, columns=conditonal_generation_cols, dtype=int)
    # Both frames share the filtered index, so an inner column-wise concat re-aligns them row by row.
    df_fitting = pd.concat([df_fit_data, df_1h], axis=1, join='inner')

    split_kwargs = {} if random_state is None else {"random_state": random_state}
    train_test_set, val_set = train_test_split(df_fitting, **split_kwargs)
    train_set, test_set = train_test_split(train_test_set, **split_kwargs)
    print(f"Training set contains: {train_set.shape[0]}")
    print(f"Validation set contains: {val_set.shape[0]}")
    print(f"Testing set contains: {test_set.shape[0]}")
    print(f"Saving to directory: {data_out_path}")
    for fname, split_df in [("train.csv", train_set), ("val.csv", val_set), ("test.csv", test_set)]:
        split_df.to_csv(data_out_path / fname, index=False)
    return test_set


# Test split returned by create_synthesis_fitting_set(), re-indexed 0..n-1 so
# rows can be addressed positionally with .loc[i] further down.
df = create_synthesis_fitting_set().reset_index(drop=True)
df.head()

1.0    33064
Name: label, dtype: int64
Age distribution
Gender distribution
Training set contains: 18598
Validation set contains: 8266
Testing set contains: 6200
Saving to directory: /media/akinwilson/Samsung_T5/data/audio/keyword-spotting/synthesis_data


Unnamed: 0,label,wav_path,annotated_age_ADULT,annotated_age_KID,annotated_age_UNSURE,annotated_voice_type_FEMALE,annotated_voice_type_MALE,annotated_voice_type_UNKNOWN
0,1.0,/media/akinwilson/Samsung_T5/data/audio/keywor...,1,0,0,0,1,0
1,1.0,/media/akinwilson/Samsung_T5/data/audio/keywor...,1,0,0,1,0,0
2,1.0,/media/akinwilson/Samsung_T5/data/audio/keywor...,1,0,0,0,1,0
3,1.0,/media/akinwilson/Samsung_T5/data/audio/keywor...,0,0,1,0,0,1
4,1.0,/media/akinwilson/Samsung_T5/data/audio/keywor...,1,0,0,0,1,0


### Conditionally generate samples based on age and gender

In [None]:

# Column names produced by pd.get_dummies on annotated_age / annotated_voice_type,
# in the same order they appear in the saved CSVs.
one_hot_cateogrical_colum_names = [
    f"{prefix}_{level}"
    for prefix, levels in (
        ("annotated_age", ("ADULT", "KID", "UNSURE")),
        ("annotated_voice_type", ("FEMALE", "MALE", "UNKNOWN")),
    )
    for level in levels
]

# Conditioning (one-hot) vector of row 1 of the test split.
df.loc[1][one_hot_cateogrical_colum_names].to_numpy()

### Conditional generative variational autoencoder

#### Architecture

In [97]:
class CVCNNAE(nn.Module):
    '''
    Conditional variational convolutional neural network autoencoder.

    The conditioning vector ``labels`` (length ``labels_length``) is appended to
    the waveform along the time axis before encoding, and appended to the
    latent sample before decoding. Max-pool indices produced by the encoder are
    reused by the decoder's MaxUnpool1d layers.

    Args:
        input_size: waveform length in samples. Stored (together with
            ``input_size + labels_length``) but not used to size any layer.
        n_input: input channels. Concatenation with the (batch, 1, labels_length)
            labels in ``encode`` requires this to be 1.
        latent_dim: size of the latent mean / log-variance vectors.
        stride: stride of the first conv and of the last transposed conv.
        n_channel: base channel width.
        labels_length: length of the conditioning vector.

    NOTE(review): the flattened bottleneck size ``2 * n_channel * 28`` is
    hard-coded. 28 matches input_size=32000 plus a short label vector; other
    input sizes will likely need it recomputed.
    '''

    def __init__(self,input_size=32000, n_input=1, latent_dim=1024, stride=16, n_channel=32 ,labels_length=6):
        super().__init__()
        # Kept for backward compatibility; tensors are now allocated on the
        # device of the activations instead of this init-time guess.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.n_channel = n_channel

        self.input_size = input_size
        self.input_size_w_labels = input_size + labels_length
        self.latent_dim = latent_dim
        self.latent_dim_w_labels = latent_dim + labels_length

        # encoder layers
        self.e_conv1 = nn.Conv1d(n_input, n_channel, kernel_size=80, stride=stride)
        self.e_bn1 = nn.BatchNorm1d(n_channel)
        self.e_pool1 = nn.MaxPool1d(4, return_indices=True)
        self.e_conv2 = nn.Conv1d(n_channel, n_channel, kernel_size=3)
        self.e_bn2 = nn.BatchNorm1d(n_channel)
        self.e_pool2 = nn.MaxPool1d(4, return_indices=True)
        self.e_conv3 = nn.Conv1d(n_channel, 2 * n_channel, kernel_size=3)
        self.e_bn3 = nn.BatchNorm1d(2 * n_channel)
        self.e_pool3 = nn.MaxPool1d(4, return_indices=True)
        self.e_conv4 = nn.Conv1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.e_bn4 = nn.BatchNorm1d(2 * n_channel)
        self.e_pool4 = nn.MaxPool1d(2, return_indices=True)  # defined but not applied in encode()

        self.e_fc4_mean = nn.Linear(2 * n_channel * 28, latent_dim)
        self.e_fc4_var = nn.Linear(2 * n_channel * 28, latent_dim)

        # decoder layers
        self.d_fc4 = nn.Linear(self.latent_dim_w_labels, 2 * n_channel * 28)
        self.d_pool4 = nn.MaxUnpool1d(2)  # defined but not applied in decode()
        self.d_bn4 = nn.BatchNorm1d(2 * n_channel)
        self.d_conv4 = nn.ConvTranspose1d(2 * n_channel, 2 * n_channel, kernel_size=3)
        self.d_pool3 = nn.MaxUnpool1d(4)
        self.d_bn3 = nn.BatchNorm1d(2 * n_channel)
        self.d_conv3 = nn.ConvTranspose1d(2 * n_channel, n_channel, kernel_size=3)
        self.d_pool2 = nn.MaxUnpool1d(4)
        self.d_bn2 = nn.BatchNorm1d(n_channel)
        self.d_conv2 = nn.ConvTranspose1d(n_channel, n_channel, kernel_size=3)
        self.d_pool1 = nn.MaxUnpool1d(4)
        self.d_bn1 = nn.BatchNorm1d(n_channel)
        self.d_conv1 = nn.ConvTranspose1d(n_channel, n_input, kernel_size=80, stride=stride)

    def encode(self, x, labels):
        """Encode ``x`` conditioned on ``labels``.

        Args:
            x: waveform batch, presumably (batch, 1, time). TODO confirm.
            labels: (batch, labels_length) conditioning tensor. It is NOT modified.

        Returns:
            ``(idx1, idx2, idx3, mean, logvar)``: the three pooling-index tensors
            needed by ``decode`` plus the latent mean and log-variance.
        """
        # Out-of-place unsqueeze: the previous ``unsqueeze_`` reshaped the
        # caller's tensor as a side effect.
        x_inputs = torch.cat([x, labels.unsqueeze(1)], dim=2)

        x = self.e_conv1(x_inputs)
        x = F.relu(self.e_bn1(x))
        x, idx1 = self.e_pool1(x)
        x = self.e_conv2(x)
        x = F.relu(self.e_bn2(x))
        x, idx2 = self.e_pool2(x)
        x = self.e_conv3(x)
        x = F.relu(self.e_bn3(x))
        x, idx3 = self.e_pool3(x)
        x = self.e_conv4(x)
        x = F.relu(self.e_bn4(x))
        x = x.view(x.shape[0], -1)
        x_mean = self.e_fc4_mean(x)
        x_var = self.e_fc4_var(x)  # used as a log-variance by reparameterization_trick
        return idx1, idx2, idx3, x_mean, x_var

    def reparameterization_trick(self, mean, logvar):
        """Sample ``mean + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mean + eps*std

    @staticmethod
    def _pad_to(x, length):
        """Right-pad ``x`` with zeros along dim 2 up to ``length`` (same device/dtype/channels as ``x``)."""
        missing = length - x.shape[2]
        if missing <= 0:
            return x
        return F.pad(x, (0, missing))

    def decode(self, idx1, idx2, idx3, z, labels):
        """Decode latent ``z`` conditioned on ``labels`` back to a waveform.

        Args:
            idx1, idx2, idx3: pooling indices returned by ``encode``.
            z: latent sample, (batch, latent_dim).
            labels: (batch, labels_length) or (batch, 1, labels_length). It is
                flattened to 2-D and NOT modified.

        Returns:
            The reconstructed waveform batch.
        """
        cond = labels.flatten(start_dim=1)
        x_inputs = torch.cat([z, cond], dim=1)
        bs = x_inputs.shape[0]
        x = self.d_fc4(x_inputs)
        x = x.view(bs, 2 * self.n_channel, 28)
        x = F.relu(self.d_bn4(x))
        x = self.d_conv4(x)
        x = self.d_pool3(x, idx3)
        x = F.relu(self.d_bn3(x))
        x = self.d_conv3(x)
        # The transposed conv comes up short of the pre-pool length; zero-pad so
        # the unpool indices line up. (Previously hard-coded to 32 channels.)
        x = self._pad_to(x, idx2.shape[2])
        x = self.d_pool2(x, idx2)
        x = F.relu(self.d_bn2(x))
        x = self.d_conv2(x)
        x = self._pad_to(x, idx1.shape[2])
        x = self.d_pool1(x, idx1)
        x = F.relu(self.d_bn1(x))
        x = self.d_conv1(x)
        return x

    def forward(self, x, labels):
        """Encode, sample and decode.

        Returns:
            ``(x, decoded_x, mean, logvar)``, where ``x`` is the unmodified input.
        """
        idx1, idx2, idx3, z_log_mean, z_log_var = self.encode(x, labels)

        z = self.reparameterization_trick(z_log_mean, z_log_var)
        decoded_x = self.decode(idx1, idx2, idx3, z, labels)
        return x, decoded_x, z_log_mean, z_log_var


#### Data preparation

In [99]:
from wwv.routine import Routine
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.nn as nn
import torchaudio as ta 
import torch
import wwv.config as cfg 
from pathlib import Path
# import torchaudio 
# import logging

# One-hot conditioning columns (pd.get_dummies output) read by AudioDataset,
# in the order they form the conditioning vector.
one_hot_cateogrical_colum_names = (
    ["annotated_age_" + level for level in ("ADULT", "KID", "UNSURE")]
    + ["annotated_voice_type_" + level for level in ("FEMALE", "MALE", "UNKNOWN")]
)




class Scaler(nn.Module):
    """Divide a waveform by the int16 maximum (32767).

    NOTE(review): torchaudio.load returns float32 already normalised to
    [-1, 1] by default (normalize=True). Applying this scaler after ta.load
    may therefore double-scale the audio. Confirm this is intended.
    """

    def __init__(self):
        super().__init__()
        # A buffer (not a parameter): not trained, but follows the module on .to(device).
        self.register_buffer("int16_max", torch.tensor([32767]).float())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x / 32767`` (broadcast over every element)."""
        return x / self.int16_max



class DataCollator:
    """Collate ``(waveform, conditioning)`` pairs into a dict of float batches.

    Returns ``{"x": stacked waveforms, "cats": stacked conditioning vectors}``,
    both cast to float32.
    """

    def __call__(self, batch):
        waveforms = [sample for sample, _ in batch]
        conditions = [cond for _, cond in batch]
        return {
            "x": torch.stack(waveforms).float(),
            "cats": torch.stack(conditions).float(),
        }



class Padder:
    """Right-pad (or truncate) a waveform along its last axis to a fixed length.

    The result is moved to the module-level ``device``.

    Args:
        pad_to_len: target number of samples. The default of 32000 matches
            CVCNNAE's default ``input_size``.
    """

    def __init__(self, pad_to_len=32000):
        self.pad_to_len = pad_to_len

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        # Presumably x is (channels, time) as returned by torchaudio.load. TODO confirm.
        # Longer clips previously crashed (negative repeat count); truncate them instead.
        x = x[..., : self.pad_to_len]
        # F.pad pads every leading dim consistently, so multi-channel input also works.
        x_new = F.pad(x, (0, self.pad_to_len - x.size(-1)))
        return x_new.to(device)  # (channels, pad_to_len)



# Duplicate definition of `Scaler` removed: the class is already defined earlier in this cell.



class AudioDataset(Dataset):
    """Map-style dataset over one split CSV.

    Each item is ``(waveform, conditioning)``:
    - ``waveform`` is loaded from the row's ``wav_path``, scaled by ``Scaler``,
      then padded by ``Padder``.
    - ``conditioning`` is the row's one-hot columns as an int64 tensor on
      ``device``.
    """

    def __init__(self,
                df_path,
                cfg_model,
                cfg_feature):
        self.df = pd.read_csv(df_path)

        self.x_pad = Padder()
        self.x_scale = Scaler()
        # Stored for API parity with other datasets; not read in this class.
        self.cfg_model = cfg_model
        self.cfg_feature = cfg_feature

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        row = self.df.loc[idx]
        conditioning = torch.tensor(
            row[one_hot_cateogrical_colum_names].to_numpy().astype("int64"),
            device=device,
        )
        waveform, _sample_rate = ta.load(row['wav_path'])
        waveform = self.x_pad(self.x_scale(waveform))
        return waveform, conditioning




class AudioDataModule():
    """Build train/val/test DataLoaders over ``train.csv``/``val.csv``/``test.csv`` in ``df_path``.

    Args:
        df_path: directory holding the split CSVs (str or path-like).
        cfg_model, cfg_feature: forwarded to every ``AudioDataset``.
        cfg_fitting: provides ``train_bs``, ``val_bs`` and ``test_bs``.
    """

    def __init__(self, df_path, cfg_model, cfg_fitting, cfg_feature):
        super().__init__()

        # the DataPath data class makes sure the files below are present on init in the root directory.
        # os.path.join also accepts pathlib.Path, which the former `+ "/..."` concat did not.
        self.train_df_path = os.path.join(df_path, "train.csv")
        self.val_df_path = os.path.join(df_path, "val.csv")
        self.test_df_path = os.path.join(df_path, "test.csv")

        self.cfg_model = cfg_model

        self.cfg_fitting = cfg_fitting
        self.cfg_feature = cfg_feature
        self.pin_memory = False # True if torch.cuda.is_available() else False

    def _make_loader(self, df_path, batch_size, shuffle):
        """Wrap an AudioDataset for ``df_path`` in a DataLoader with the shared collator."""
        dataset = AudioDataset(df_path=df_path, cfg_model=self.cfg_model, cfg_feature=self.cfg_feature)
        return DataLoader(dataset,
                          batch_size=batch_size,
                          shuffle=shuffle,
                          drop_last=True,
                          pin_memory=self.pin_memory,
                          collate_fn=DataCollator())

    def train_dataloader(self):
        """Shuffled training loader (incomplete last batch dropped)."""
        return self._make_loader(self.train_df_path, self.cfg_fitting.train_bs, shuffle=True)

    def val_dataloader(self):
        """Validation loader in file order.

        Not shuffled: with drop_last=True, a shuffled loader drops a different
        random subset every epoch, which makes epoch-to-epoch val_loss noisy.
        """
        return self._make_loader(self.val_df_path, self.cfg_fitting.val_bs, shuffle=False)

    def test_dataloader(self):
        """Test loader in file order (deterministic, same reason as validation)."""
        return self._make_loader(self.test_df_path, self.cfg_fitting.test_bs, shuffle=False)



# Model plus the configuration objects for the conditional VAE run.
model = CVCNNAE()
cfg_model = cfg.CVCNNAE()

cfg_fitting = cfg.Fitting(max_epoch=50, es_patience=10)
cfg_signal = cfg.Signal()
cfg_feature = cfg.Feature()

# Splits written by create_synthesis_fitting_set() under $DATA_ROOT/synthesis_data.
data_out_path = str(Path(os.environ['DATA_ROOT']) / "synthesis_data")
data_path = cfg.DataPath(data_out_path, cfg_model.model_name, cfg_model.model_dir)
data_module = AudioDataModule(
    data_path.root_data_dir,
    cfg_model=cfg_model,
    cfg_feature=cfg_feature,
    cfg_fitting=cfg_fitting,
)

train_loader, val_loader, test_loader = (
    data_module.train_dataloader(),
    data_module.val_dataloader(),
    data_module.test_dataloader(),
)

#### Fitting routine definition

In [102]:
import pytorch_lightning as pl 
from torch.optim.lr_scheduler import ReduceLROnPlateau

class Routine(pl.LightningModule):
    """Lightning wrapper that fits a conditional VAE on dict batches ``{"x", "cats"}``.

    Args:
        model: module called as ``model(x, labels=cats)`` that returns
            ``(x, decoded_x, mean, logvar)`` (e.g. CVCNNAE).
        lr: AdamW learning rate (default 1e-3, the previously hard-coded value).
    """

    def __init__(self, model, lr=1e-3):
        super().__init__()
        self.model = model
        self.lr = lr

    def forward(self, x, cats):
        """Return ``(x, decoded_x, mean, logvar)`` from the wrapped model."""
        return self.model(x, labels=cats)

    # Reconstruction + KL divergence losses summed over all elements and batch
    def loss_function(self, recon_x, x, logmean, logvar):
        """Sum-reduced MSE reconstruction loss plus the KL divergence to N(0, I).

        ``logmean`` is used as the posterior mean (it enters the KL term as
        ``mean**2``), and ``logvar`` as the log-variance.
        """
        MSE = F.mse_loss(recon_x, x, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - logmean.pow(2) - logvar.exp())
        return MSE + KLD

    def _shared_step(self, batch):
        """Run the model on one batch and return its VAE loss."""
        x, decoded_x, z_mean, z_log_var = self(batch['x'], batch['cats'])
        # Arguments now follow the (recon_x, x, ...) signature. MSE is symmetric,
        # so the loss value is unchanged.
        return self.loss_function(decoded_x, x, z_mean, z_log_var)

    @staticmethod
    def _epoch_mean(step_outputs, key):
        """Average the per-step values stored under ``key``."""
        return torch.tensor([out[key].mean().item() for out in step_outputs]).mean()

    def training_step(self, batch, batch_idx):
        return {"loss": self._shared_step(batch)}

    def training_epoch_end(self, training_step_outputs):
        # Previously `torch.teepsnsor`, which raised AttributeError at the end of epoch 0.
        results = {"loss": self._epoch_mean(training_step_outputs, 'loss')}
        for (k, v) in results.items():
            self.log(f"train_{k}", v, on_epoch=True, prog_bar=True, logger=True)

    def validation_step(self, batch, batch_idx):
        return {"val_loss": self._shared_step(batch)}

    def validation_epoch_end(self, validation_step_outputs):
        # Logged as "val_loss", the metric monitored by the LR scheduler below.
        results = {"loss": self._epoch_mean(validation_step_outputs, 'val_loss')}
        for (k, v) in results.items():
            self.log(f"val_{k}", v, on_epoch=True, prog_bar=True, logger=True)

    # trainer.test(...) requires a test_step; without it the test run fails.
    def test_step(self, batch, batch_idx):
        return {"test_loss": self._shared_step(batch)}

    def test_epoch_end(self, test_step_outputs):
        results = {"loss": self._epoch_mean(test_step_outputs, 'test_loss')}
        for (k, v) in results.items():
            self.log(f"test_{k}", v, on_epoch=True, prog_bar=True, logger=True)

    def configure_optimizers(self):
        """AdamW on trainable params; ReduceLROnPlateau driven by ``val_loss``."""
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.parameters()),
            lr = self.lr,
            betas = (0.9, 0.999), eps = 1e-08, weight_decay = 0.05,
        )
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=False)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}
        

#### Fitting model

In [103]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from wwv.util import CallbackCollection

# Create the model *before* wrapping it. Previously `routine = Routine(model)`
# wrapped whatever `model` already existed in the namespace, and the
# `model = CVCNNAE()` that followed was never used by anything.
model = CVCNNAE()
routine = Routine(model)

logger = TensorBoardLogger(save_dir=data_path.model_dir, version=1, name="lightning_logs")

# Init a trainer to execute routine
trainer = Trainer(
    accelerator="gpu",
    devices=1,
    sync_batchnorm=True,
    logger=logger,
    max_epochs=cfg_fitting.max_epoch,
    callbacks=CallbackCollection(cfg_fitting, data_path)(),
    gradient_clip_val=1.0,
    fast_dev_run=cfg_fitting.fast_dev_run,
)

trainer.fit(
    routine, train_dataloaders=train_loader, val_dataloaders=val_loader
)

trainer.test(dataloaders=test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type    | Params
----------------------------------
0 | model | CVCNNAE | 5.6 M 
----------------------------------
5.6 M     Trainable params
0         Non-trainable params
5.6 M     Total params
22.275    Total estimated model params size (MB)


Epoch 0:  15%|█▌        | 128/839 [00:48<04:27,  2.65it/s, loss=43, v_num=1]     