# Load architecture

In [None]:
from wwv.config import Config 
# Filesystem locations for trained-model output and the keyword dataset.
MODEL_DIR = "/home/akinwilson/Code/pytorch/output/model"
DATA_DIR = "/home/akinwilson/Code/pytorch/dataset/keywords"

# Each setting below is chosen by indexing into its list of candidates;
# change the trailing index to switch to another candidate.
LR_RANGE = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5][1]
BATCH_SIZE_RANGE = [1, 2, 32, 64, 128, 256][2]
EPOCH_RANGE = [1, 10, 30, 50, 100, 1000][1]
ES_PATIENCE_RANGE = [1, 10, 20, 100, 200][2]
MODELS = [
    "VecM5",
    "Resnet2vec1D",
    "SpecResnet2D",
    "HSTAT",
    "DeepSpeech",
    "ResNet",
][-1]
AUDIO_FEATURE_OPT = ["spectrogram", "mfcc", "pcm"][1]
PRETRAINED_MODEL_NAME_OR_PATH = "facebook/wav2vec2-base-960h"
# Single switch applied to every augmentation in params["augmentation"].
AUGS = False

# Sample rate (Hz) repeated across the feature and augmentation settings.
_SAMPLE_RATE = 16000

params = {
    "audio_duration": 3,
    "sample_rate": _SAMPLE_RATE,
    "model_name": MODELS,
    "verbose": False,
    "path": {
        "model_dir": MODEL_DIR,
        "data_dir": DATA_DIR,
        "pretrained_name_or_path": PRETRAINED_MODEL_NAME_OR_PATH,
    },
    "fit_param": {
        "init_lr": LR_RANGE,
        "weight_decay": 0.0001,
        "max_epochs": EPOCH_RANGE,
        "gamma": 0.1,
        "es_patience": ES_PATIENCE_RANGE,
    },
    "data_param": {
        "train_batch_size": BATCH_SIZE_RANGE,
        "val_batch_size": BATCH_SIZE_RANGE,
        "test_batch_size": BATCH_SIZE_RANGE,
    },
    "audio_feature": AUDIO_FEATURE_OPT,
    # Keyword arguments per feature type; only the entry named by
    # "audio_feature" is looked up by the feature-extraction layer.
    "audio_feature_param": {
        "mfcc": {
            "sr": _SAMPLE_RATE,
            "n_mfcc": 20,
            "norm": 'ortho',
            "verbose": True,
            "ref": 1.0,
            "amin": 1e-10,
            "top_db": 80.0,
            "hop_length": 512,
        },
        "spectrogram": {
            "sr": _SAMPLE_RATE,
            "n_fft": 2048,
            "win_length": None,
            "n_mels": 128,
            "hop_length": 512,
            "window": 'hann',
            "center": True,
            "pad_mode": 'reflect',
            "power": 2.0,
            "htk": False,
            "fmin": 0.0,
            "fmax": None,
            "norm": 1,
            "trainable_mel": False,
            "trainable_STFT": False,
            "verbose": True,
        },
        "pcm": {},
    },
    "augmentation": {
        'Gain': AUGS,
        'PitchShift': AUGS,
        'Shift': AUGS,
    },
    "augmentation_param": {
        "Gain": {
            "min_gain_in_db": -18.0,
            "max_gain_in_db": 6.0,
            "mode": 'per_example',
            "p": 1,
            "p_mode": 'per_example',
        },
        "PitchShift": {
            "min_transpose_semitones": -4.0,
            "max_transpose_semitones": 4.0,
            "mode": 'per_example',
            "p": 1,
            "p_mode": 'per_example',
            "sample_rate": _SAMPLE_RATE,
            "target_rate": None,
            "output_type": None,
        },
        "Shift": {
            "min_shift": -0.5,
            "max_shift": 0.5,
            "shift_unit": 'fraction',
            "rollover": True,
            "mode": 'per_example',
            "p": 1,
            "p_mode": 'per_example',
            "sample_rate": _SAMPLE_RATE,
            "target_rate": None,
            "output_type": None,
        },
    },
}
# Build the project Config from the raw settings dict; later cells read
# their settings through `cfg` (e.g. cfg.path, cfg.model_name).
cfg = Config(params)

In [None]:
import torch 
import torch.nn.functional as F 
from wwv.architecture import Architecture
from wwv.eval import Metric
from wwv.data import AudioDataModule
from wwv.config import  DataPaths
import statistics
# Scratch cell: rebuilds the config and records the dataset root. The
# commented-out lines are earlier experiments kept for reference.
# data_path = DataPaths(cfg.path['data_dir'], cfg.model_name, cfg.path['model_dir'])
cfg = Config(params)
# model = Architecture(cfg, training=True)
# model.extractor(torch.randn((1,48000))) # (torch.randn((1,48000)))
# Reuse DATA_DIR rather than repeating the path literal, so the dataset
# location is defined in exactly one place.
root = DATA_DIR
# data_module = AudioDataModule(data_path.root_data_dir + "/train.csv",
#                               data_path.root_data_dir + "/val.csv",
#                               data_path.root_data_dir + "/test.csv",
#                                cfg)
# # model.processing_layer[3](x)
# train_loader=  data_module.train_dataloader()
# val_loader=  data_module.val_dataloader()

In [None]:
import torch
import torch.nn as  nn
import torch.nn.functional as F
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from nnAudio import features
import torchaudio 

In [None]:
from wwv.data import AudioDataModule
from wwv.config import DataPaths
import torchaudio 
# Rebuild the config, resolve the project paths from it, then point the data
# module at the train/val/test CSV manifests under the dataset root.
cfg = Config(params)
data_path = DataPaths(cfg.path['data_dir'], cfg.model_name, cfg.path['model_dir'])
train_csv, val_csv, test_csv = (
    data_path.root_data_dir + f"/{split}.csv" for split in ("train", "val", "test")
)
data_module = AudioDataModule(train_csv, val_csv, test_csv, cfg=cfg)

# x['x'].shape

In [None]:
from pytorch_lightning import Trainer
import pytorch_lightning as pl 
import torch.nn.functional as F 
from wwv.architecture import ResNet, Predictor, Bottleneck
from wwv.eval import Metric
from pytorch_lightning.loggers import TensorBoardLogger
from wwv.util import OnnxExporter

from pytorch_lightning.callbacks import EarlyStopping,ModelCheckpoint,LearningRateMonitor, ModelPruning
from wwv.data import AudioDataModule
import torch 
from torch.optim.lr_scheduler import ReduceLROnPlateau


# Build the config FIRST: the paths below are derived from it, so creating
# `data_path` before `cfg` would silently use whatever `cfg` an earlier cell
# left behind (or fail if none exists).
cfg = Config(params)
data_path = DataPaths(cfg.path['data_dir'], cfg.model_name, cfg.path['model_dir'])

# model = Architecture(cfg, training=True)
# model.extractor(torch.randn((1,48000))) # (torch.randn((1,48000)))
# model = Architecture(cfg, True)

# One CSV manifest per split, all located under the dataset root.
data_module = AudioDataModule(data_path.root_data_dir + "/train.csv",
                              data_path.root_data_dir + "/val.csv",
                              data_path.root_data_dir + "/test.csv",
                              cfg=cfg)

train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()
test_loader = data_module.test_dataloader()
# model.processing_layer[3](x)


class Routine(pl.LightningModule):
    """Lightning training routine for a single-logit classifier.

    Wraps ``model`` and, for every batch, computes a binary cross-entropy loss
    on the raw logits plus the ``ttr`` / ``ftr`` / ``acc`` values returned by
    the project's ``Metric`` class. Per-epoch means are logged as
    ``{train,val,test}_{loss,ttr,ftr,acc}`` (the test stage logs no loss).

    Batches are dicts with the inputs under ``'x'`` and the targets under
    ``'y'``.

    Args:
        model: network mapping ``batch['x']`` to one logit per example.
        cfg: project configuration, forwarded unchanged to ``Metric``.
        lr: initial learning rate for Adam. Defaults to ``1e-3``, the value
            that was previously hard-coded.
    """

    def __init__(self, model, cfg, lr=1e-3):
        super().__init__()
        self.model = model
        self.metric = Metric
        self.cfg = cfg
        self.lr = lr

    def _logits_and_targets(self, batch):
        """Run the model and return ``(logits, targets)`` with equal shapes."""
        y = batch['y']
        # Reshape to the target's shape instead of a bare ``squeeze()``: a
        # batch holding a single example would otherwise collapse to a 0-dim
        # tensor and no longer match ``y`` in the loss.
        logits = self.model(batch['x']).reshape(y.shape)
        return logits, y

    def _metrics(self, logits, y):
        """Threshold the sigmoid of ``logits`` at 0.5 and evaluate ``Metric``."""
        y_hat = (torch.sigmoid(logits) > 0.5).float()
        return self.metric(y_hat, y, self.cfg)()

    def _log_epoch_means(self, stage, step_outputs, key_by_name):
        """Log the mean over all step outputs of every requested entry.

        ``key_by_name`` maps the logged suffix (e.g. ``"acc"``) to the key
        under which the step methods stored that value (e.g. ``"val_acc"``).
        """
        means = {
            name: torch.tensor(
                [out[key].float().mean().item() for out in step_outputs]
            ).mean()
            for name, key in key_by_name.items()
        }
        for name, value in means.items():
            self.log(f"{stage}_{name}", value, on_epoch=True, prog_bar=True, logger=True)

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        logits, y = self._logits_and_targets(batch)
        loss = F.binary_cross_entropy_with_logits(logits, y)
        metrics = self._metrics(logits, y)
        return {"loss": loss, "train_ttr": metrics.ttr, "train_ftr": metrics.ftr, "train_acc": metrics.acc}

    def training_epoch_end(self, training_step_outputs):
        self._log_epoch_means(
            "train",
            training_step_outputs,
            {"loss": "loss", "ttr": "train_ttr", "ftr": "train_ftr", "acc": "train_acc"},
        )

    def validation_step(self, batch, batch_idx):
        logits, y = self._logits_and_targets(batch)
        loss = F.binary_cross_entropy_with_logits(logits, y)
        metrics = self._metrics(logits, y)
        return {"val_loss": loss, "val_ttr": metrics.ttr, "val_ftr": metrics.ftr, "val_acc": metrics.acc}

    def validation_epoch_end(self, validation_step_outputs):
        self._log_epoch_means(
            "val",
            validation_step_outputs,
            {"loss": "val_loss", "ttr": "val_ttr", "ftr": "val_ftr", "acc": "val_acc"},
        )

    def test_step(self, batch, batch_idx):
        logits, y = self._logits_and_targets(batch)
        metrics = self._metrics(logits, y)
        return {"test_ttr": metrics.ttr, "test_ftr": metrics.ftr, "test_acc": metrics.acc}

    def test_epoch_end(self, test_step_outputs):
        self._log_epoch_means(
            "test",
            test_step_outputs,
            {"ttr": "test_ttr", "ftr": "test_ftr", "acc": "test_acc"},
        )

    def configure_optimizers(self):
        """Adam plus a plateau scheduler driven by the logged ``val_loss``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=False)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}



# Callbacks: log the learning rate once per epoch, stop when val_loss has not
# improved for 25 epochs, and keep only the single best checkpoint by val_loss.
lr_monitor = LearningRateMonitor(logging_interval='epoch')
early_stopping = EarlyStopping(mode="min", monitor='val_loss', patience=25)
checkpoint_callback = ModelCheckpoint(monitor="val_loss",
                                        dirpath=data_path.model_dir,
                                        save_top_k=1,
                                        mode="min",
                                        filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}-{val_ttr:.2f}-{val_ftr:.2f}')

model = ResNet(block=Bottleneck, num_blocks=[8, 8, 36, 3], cfg=cfg)
callbacks = [checkpoint_callback, lr_monitor, early_stopping]

logger = TensorBoardLogger(save_dir=data_path.model_dir, version=1, name="lightning_logs")

trainer = Trainer(accelerator="gpu",
                  devices=3,
                  strategy='dp',
                  logger = logger, 
                  default_root_dir=data_path.model_dir,
                  callbacks=callbacks)
trainer.fit(Routine(model, cfg), train_dataloaders=train_loader, val_dataloaders=val_loader)

trainer.test(dataloaders=test_loader)


from wwv.util import OnnxExporter
# `trainer.lightning_module` is the unwrapped Routine regardless of strategy.
# The previous `trainer.model.module.module.model` chain only resolved while
# the model was wrapped for data-parallel training.
model = trainer.lightning_module.model
predictor = Predictor(model)
OnnxExporter( model=predictor,
             cfg=cfg, 
             output_dir=data_path.model_dir)()

#####################################################################################################################
#                                            
#####################################################################################################################
# if isinstance(trainer.model, torch.nn.DataParallel):
#     print("test")
#     model = trainer.model
#####################################################################################################################
# reload best 
#####################################################################################################################
# automatically auto-loads the best weights from the previous run 
#####################################################################################################################

In [None]:
class ProcessingLayer(nn.Module):
    """Turn raw audio into the feature type selected by ``cfg.audio_feature``.

    The keyword arguments for the extractor are read from
    ``cfg.audio_feature_param[cfg.audio_feature]``:

    * ``"spectrogram"`` builds ``features.MelSpectrogram(**kwargs)``
    * ``"mfcc"`` builds ``features.MFCC(**kwargs)``
    * any other value (e.g. ``"pcm"``) adds no layer, so ``forward`` returns
      its input unchanged.

    Args:
        cfg: configuration object exposing ``audio_feature`` and
            ``audio_feature_param``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        kwargs = cfg.audio_feature_param[cfg.audio_feature]

        layers = []
        if cfg.audio_feature == "spectrogram":
            layers.append(features.MelSpectrogram(**kwargs))
            # layers.append(T.Resize(224)) # size expected by 2D ResNet
        elif cfg.audio_feature == "mfcc":
            layers.append(features.MFCC(**kwargs))
            # layers.append(T.Resize(224)) # size expected by 2D ResNet

        # resize inputs
        # layers.append(transforms.RandomResizedCrop(224))
        # An empty Sequential acts as the identity.
        self.net = torch.nn.Sequential(*layers)
        # logger.info(f"{'-'*20}> Features to be extracted: {cfg.audio_feature}")
        # logger.info(f"{'-'*20}> Feature dimensions: {cfg.processing_output_shape}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the configured feature extractor to ``x``."""
        x_out = self.net(x)
        # if self.cfg.verbose:
        #     logger.info(f"ProcessingLayer().forward() [in]: {x.shape}")
        #     logger.info(f"ProcessingLayer().forward() [out]: {x_out.shape}")
        return x_out


class CNNLayerNorm(nn.Module):
    """LayerNorm over the feature axis of a (batch, channel, feature, time) tensor.

    ``nn.LayerNorm`` normalises trailing dimensions, so the feature axis is
    swapped into last position, normalised, and swapped back.
    """

    def __init__(self, n_feats):
        super().__init__()
        self.layer_norm = nn.LayerNorm(n_feats)

    def forward(self, x):
        # (batch, channel, feature, time) -> (batch, channel, time, feature)
        time_major = x.transpose(2, 3).contiguous()
        normalised = self.layer_norm(time_major)
        # back to (batch, channel, feature, time)
        return normalised.transpose(2, 3).contiguous()


class ResidualCNN(nn.Module):
    """Pre-activation residual block (cf. https://arxiv.org/pdf/1603.05027.pdf)
    that uses layer norm in place of batch norm.

    Two identical stages of norm -> GELU -> dropout -> conv are applied and the
    block input is added back to the result, so the convolutions must leave
    the tensor shape unchanged for the addition to succeed.
    """

    def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
        super().__init__()

        half_kernel = kernel // 2
        self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=half_kernel)
        self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=half_kernel)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_norm1 = CNNLayerNorm(n_feats)
        self.layer_norm2 = CNNLayerNorm(n_feats)

    def forward(self, x):
        # x: (batch, channel, feature, time)
        residual = x
        stages = (
            (self.layer_norm1, self.dropout1, self.cnn1),
            (self.layer_norm2, self.dropout2, self.cnn2),
        )
        for norm, drop, conv in stages:
            x = conv(drop(F.gelu(norm(x))))
        x += residual
        return x  # (batch, channel, feature, time)


class BidirectionalGRU(nn.Module):
    """LayerNorm -> GELU -> single-layer bidirectional GRU -> dropout.

    The GRU is bidirectional, so the last dimension of the output is
    ``2 * hidden_size``. The final hidden state is discarded.
    """

    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
        super().__init__()

        gru_settings = dict(
            input_size=rnn_dim,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=batch_first,
            bidirectional=True,
        )
        self.BiGRU = nn.GRU(**gru_settings)
        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        activated = F.gelu(self.layer_norm(x))
        sequence_out, _final_hidden = self.BiGRU(activated)
        return self.dropout(sequence_out)

from math import prod
class DeepSpeech(nn.Module):
    """Conv -> residual-CNN stack -> linear -> bidirectional GRU -> classifier.

    Work-in-progress model: several sizes are hard-coded (see the notes in
    ``__init__``), so it only runs for inputs that produce exactly those
    shapes. The final layer is ``nn.Linear(rnn_dim, 1)``, i.e. one output
    value per example.

    Args:
        n_cnn_layers: number of ``ResidualCNN`` blocks.
        n_rnn_layers: number of ``BidirectionalGRU`` layers.
        rnn_dim: GRU input/hidden size and width of the classifier's hidden layer.
        stride: stride of the first convolution.
        dropout: dropout probability used throughout.
        cfg: stored on the instance. NOTE: the default is the module-level
            ``cfg`` as it exists when this class statement is executed.
        **kwargs: accepted and ignored.
    """

    def __init__(self, n_cnn_layers=40, n_rnn_layers=1, rnn_dim=1096, stride=2, dropout=0.1,cfg=cfg, **kwargs):
        super().__init__()
        self.cfg= cfg 
        # Hard-coded normalized_shape for the LayerNorm inside each residual
        # block. CNNLayerNorm swaps dims 2 and 3 before normalising, so this
        # presumably expects activations of shape (batch, 32, 20, 121) --
        # TODO confirm against the real feature shape.
        n_feats = ( 121, 20) #  self.cfg.processing_output_shape
        # self.processing_layer = ProcessingLayer(cfg)
        # n_feats =  (121 * 20) // 2

        self.flatten = nn.Flatten()
        self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)  # cnn for extracting hierarchical features

        # n residual cnn layers with filter size of 32
        self.rescnn_layers = nn.Sequential(
            *[ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats) for _ in range(n_cnn_layers)]
        )
        # Applied to the last axis of the 4-D activation, which must therefore
        # have length 121.
        self.fully_connected = nn.Linear(121, rnn_dim)
        self.birnn_layers = nn.Sequential(
            *[
                BidirectionalGRU(
                    # layers after the first consume the 2 * rnn_dim
                    # bidirectional output of the previous layer
                    rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
                    hidden_size=rnn_dim,
                    dropout=dropout,
                    # NOTE(review): only the first layer is batch-first;
                    # with n_rnn_layers > 1 later layers would read the same
                    # tensor as (seq, batch, feature) -- confirm this is intended.
                    batch_first=i == 0,
                )
                for i in range(n_rnn_layers)
            ]
        )
        self.classifier = nn.Sequential(
            # 1402880 == 640 * 2192 == (32 * 20) * (2 * 1096): the flattened
            # GRU output for the default rnn_dim and a (batch, 32, 20, 121)
            # activation. Any other rnn_dim or input size will not match.
            nn.Linear(1402880, rnn_dim),  # birnn returns rnn_dim*2
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_dim, 1),
        )

    def forward(self, x):
        """Return one value per example for a 4-D, single-channel input.

        ``x`` must be 4-D with one channel (the first layer is
        ``Conv2d(1, 32, ...)`` and ``sizes[3]`` is indexed below). Shapes are
        printed after each stage as a debugging aid.
        """
        print(f"Shape b cnn {x.shape}")
        # x = self.processing_layer(x) # (batch, mfcc, timestep)
        # print(f"Shape a process {x.shape}")
        # x = x.unsqueeze(1) # (batch, channel,  mfcc, timestep)
        # print(f"Shape a transpose {x.shape}")
        x = self.cnn(x)
        print(f"Shape a cnn {x.shape}")
        x = self.rescnn_layers(x)
        print(f"Shape a rescnn {x.shape}")
        
        # # print(f"after view {x.shape}")
        # x = x.transpose(1, 2)  # (batch, time, feature)

        # Linear acts on the last axis: (..., 121) -> (..., rnn_dim)
        x = self.fully_connected(x)
        sizes = x.size()
        # Merge dims 1 and 2 so the GRU receives a 3-D tensor:
        # (batch, dim1 * dim2, rnn_dim)
        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])
        print(f"Shape a fully_connected {x.shape}")
        x = self.birnn_layers(x)
        print(f"Shape a birnn_layers {x.shape}")
        # print(f"after birnn_layers {x.shape}")
        # Flatten everything but the batch axis before the classifier.
        x =  self.flatten(x)
        x = self.classifier(x)
        return x

In [None]:
from transformers import Wav2Vec2FeatureExtractor, SEWDForSequenceClassification
from datasets import load_dataset
import torch

# Small LibriSpeech demo split used to smoke-test the SEW wrapper below.
# It must be loaded here: the lines that follow read `dataset`.
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
dataset = dataset.sort("id")
sampling_rate = dataset.features["audio"].sampling_rate
class SEW(nn.Module):
    """Keyword-spotting wrapper around the pretrained SEW-D classifier.

    Bundles the checkpoint's feature extractor with the classification model
    so that ``forward`` accepts raw audio samples and returns class logits.
    """

    def __init__(self):
        super().__init__()
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/sew-d-mid-400k-ft-keyword-spotting")
        self.model = SEWDForSequenceClassification.from_pretrained("anton-l/sew-d-mid-400k-ft-keyword-spotting")

    def forward(self, x, sampling_rate=None):
        """Return the classification logits for raw audio ``x``.

        Args:
            x: raw waveform samples (tensor, NumPy array or list of floats).
            sampling_rate: sample rate of ``x`` in Hz. Defaults to the rate
                the feature extractor was configured with; pass the real rate
                of the audio to have the extractor check it.

        Returns:
            The ``logits`` tensor from the model output.
        """
        if sampling_rate is None:
            sampling_rate = self.feature_extractor.sampling_rate
        # Hand the extractor a NumPy array rather than a torch tensor.
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        x_feats = self.feature_extractor(x, sampling_rate=sampling_rate, return_tensors="pt")
        # Feed the extracted features (not the raw samples) to the model, on
        # whatever device the model lives on.
        x_feats = x_feats.to(self.model.device)
        return self.model(**x_feats).logits

# Smoke test: run the first clip of `dataset` through the SEW wrapper.
sew = SEW()

# The audio file is decoded on access, so read it exactly once.
inputs = dataset[0]["audio"]["array"]
with torch.no_grad():
    logits = sew(torch.tensor(inputs))
print(logits)
# predicted_class_ids = torch.argmax(logits, dim=-1).item()
# predicted_label = sew.model.config.id2label[predicted_class_ids]
# predicted_label

In [None]:
trainer = Trainer()
# Run learning rate finder
model = ResNet(block=Bottleneck, num_blocks=[3, 8, 36, 3], cfg=cfg)
model = Routine(model, cfg)

# Routine defines no dataloader hooks, so the training data has to be passed
# in explicitly; without it the finder has nothing to iterate over.
lr_finder = trainer.tuner.lr_find(model, train_dataloaders=train_loader)

# Results can be found in
lr_finder.results

# Plot with
fig = lr_finder.plot(suggest=True)
fig.show()

# Pick point based on plot, or get suggestion
# new_lr = lr_finder.suggestion()

# # update the learning rate used by Routine.configure_optimizers
# model.lr = new_lr

In [None]:

# Recorded responses from the same model served by three backends. Only
# "inference_time" differs between them, so the shared fields are written
# once and copied into each record.
_SHARED_RESULT_FIELDS = {
    "wake_word_probability": 0,
    "prediction": 0,
    "false_alarm_probability": 1,
    "decision_threshold": 0.5,
    "wwvm_version": "docker-env-model-version",
}

# TensorRT
rt = {
    "error": False,
    "result": {**_SHARED_RESULT_FIELDS, "inference_time": 0.01810431480407715},
}

# with CPU
cpu = {
    "error": False,
    "result": {**_SHARED_RESULT_FIELDS, "inference_time": 0.08728623390197754},
}

# with cuda
cuda = {
    "error": False,
    "result": {**_SHARED_RESULT_FIELDS, "inference_time": 0.022240400314331055},
}

def get_factor(d1, d2):
    """Return d1's inference time divided by d2's.

    Both arguments are response dicts holding the timing under
    ``['result']['inference_time']``.
    """
    numerator = d1['result']['inference_time']
    denominator = d2['result']['inference_time']
    return numerator / denominator


# Report each backend's speed-up as (slower time) / (faster time).
for faster_label, slower_run, faster_run, slower_label in (
    ("Cuda", cpu, cuda, "cpu"),
    ("TensorRT", cpu, rt, "cpu"),
    ("TensorRT", cuda, rt, "cuda"),
):
    print(f"{faster_label} {get_factor(slower_run, faster_run):.2f} faster than {slower_label}")

In [None]:
# import torch 
# PATH = "/home/akinwilson/Code/pytorch/output/model/epoch=27-val_loss=0.16-val_acc=0.97.ckpt"
# model.load_state_dict(torch.load(PATH), map_location=torch.device('cpu'))
# trainer.test(test_loader, ckpt_path='best')
from torch import tensor 
# ftrs = [x['train_ftr'].mean().item() for x in training_step_outputs]
# accs = [x['train_acc'].mean().item() for x in training_step_outputs]
# losses
# ttrs
# results = {"avg_loss": statistics.fmean([x['loss'].item() for x in training_step_outputs]),}
            # "avg_ttr": torch.stack([x['train_ttr'].mean().item() for x in training_step_outputs]).mean(),
            # "avg_ftr": torch.stack([x['train_ftr'].mean().item() for x in training_step_outputs]).mean(),
            # "avg_acc": torch.stack([x['train_acc'].mean().item() for x in training_step_outputs]).mean()}

In [None]:
from torch.utils.data import Dataset, DataLoader
class dataset(Dataset):
    """In-memory dataset pairing each row of ``x`` with the matching entry of ``y``.

    Both inputs are copied and stored as ``float32`` tensors; the dataset
    length is the size of the first dimension of ``x``.

    Args:
        x: features, as a tensor or anything ``torch.tensor`` accepts.
        y: targets, as a tensor or anything ``torch.tensor`` accepts.
    """

    def __init__(self, x, y):
        self.x = self._as_float_copy(x)
        self.y = self._as_float_copy(y)
        self.length = self.x.shape[0]

    @staticmethod
    def _as_float_copy(data):
        """Return an independent float32 tensor holding ``data``."""
        if isinstance(data, torch.Tensor):
            # ``torch.tensor(existing_tensor)`` emits a UserWarning; cloning
            # is the recommended way to copy-construct from a tensor.
            return data.detach().clone().to(torch.float32)
        return torch.tensor(data, dtype=torch.float32)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

    def __len__(self):
        return self.length

# Dummy data for a shape check: 64 examples of 48000 samples each,
# every one labelled 1.
xs = torch.ones(64, 48000)
ys = torch.ones(64)

trainset = dataset(xs, ys)
# DataLoader
trainloader = DataLoader(trainset, batch_size=64, shuffle=False)
for batch_inputs, _batch_targets in trainloader:
    print(batch_inputs.shape)