In [29]:
import os
import torchaudio
import torch
from torchmetrics import SignalNoiseRatio
import matplotlib.pyplot as plt
from torch.nn import Module, Linear, Sigmoid, LSTM, BCELoss, MSELoss, Conv1d, Conv2d, MaxPool2d, Transformer, LayerNorm, PReLU, Fold, ConvTranspose1d
from torch.optim import Adam
import torch.nn.functional as F
from pytorch_model_summary import summary
from tqdm import tqdm
import numpy as np
import random
import speechbrain as sb
from speechbrain.nnet.losses import get_si_snr_with_pitwrapper
import pickle
import math

In [29]:
import DataLoader

# Load the dataset through the project-local DataLoader module.
# NOTE(review): X and Y are later moved with .to(device) and indexed per example,
# so they are presumably tensors — confirm against DataLoader.data_loader().
(X, Y, speech, noise, mix) = DataLoader.data_loader()

In [35]:
# TRANSFORMER MASK NET
# Number of output sources produced by the mask network.
NUMBER_OF_SPEAKERS = 2
# 50000 is len of training data; the stride-8 encoder leaves 50000 // 8 = 6250 frames.
ENCODED_TIMESTEPS = 50000 // 8
# Number of 250-frame chunks obtained with a hop of 125 frames (50% overlap).
FOLDS = (2 * ENCODED_TIMESTEPS) // 250 - 1

class TransformerMaskNet(Module):
    """Time-domain mask network.

    Pipeline: Conv1d encoder -> layer norm + linear -> two Transformer blocks
    applied to 50%-overlapping chunks -> overlap-add -> two linear layers ->
    masking of the encoder output -> ConvTranspose1d decoder.

    All layer sizes are derived from the module-level constants
    ``ENCODED_TIMESTEPS``, ``FOLDS`` and ``NUMBER_OF_SPEAKERS``, so ``forward``
    only accepts one unbatched waveform of shape ``(1, ENCODED_TIMESTEPS * 8)``.

    Args:
        noise: Accepted for backward compatibility; currently unused.
    """

    FEATURES = 256  # encoder channels, also the Transformer d_model
    CHUNK = 250     # encoded frames per chunk
    HOP = 125       # frames between consecutive chunk starts (50% overlap)

    def __init__(self,noise=False):
        super().__init__()
        feats, chunk, hop = self.FEATURES, self.CHUNK, self.HOP
        mask_channels = NUMBER_OF_SPEAKERS*feats

        # ENCODER subnet
        self.tdnn = Conv1d(in_channels=1,out_channels=feats,kernel_size=16,stride=8,padding=6)

        self.lnorm = LayerNorm(normalized_shape=(feats,ENCODED_TIMESTEPS))
        self.lin0 = Linear(in_features=feats, out_features=feats)

        self.tf1 = Transformer(d_model=feats, nhead=8, dim_feedforward=1024)
        self.tf2 = Transformer(d_model=feats, nhead=8, dim_feedforward=1024)

        self.prelu = PReLU()
        # NOTE(review): chunk*feats = 64000 inputs mapped to 64000*NUMBER_OF_SPEAKERS
        # outputs is roughly 8.2e9 weights for two speakers — confirm this size is intended.
        self.lin1 = Linear(in_features=chunk*feats, out_features=chunk*feats*NUMBER_OF_SPEAKERS)

        self.fold = Fold(output_size=(1,ENCODED_TIMESTEPS),kernel_size=(1,chunk),stride=(1,hop))
        self.lin2 = Linear(in_features=mask_channels, out_features=mask_channels)
        self.lin3 = Linear(in_features=mask_channels, out_features=mask_channels)

        self.convT = ConvTranspose1d(in_channels=feats,out_channels=1,kernel_size=16,stride=8, padding=4)

    def forward(self,x):
        """Map one mixture waveform to one waveform per speaker.

        Args:
            x: Unbatched waveform tensor of shape ``(1, ENCODED_TIMESTEPS * 8)``.

        Returns:
            Tensor of shape ``(NUMBER_OF_SPEAKERS, 1, ENCODED_TIMESTEPS * 8)``.
        """
        feats = self.FEATURES
        mask_channels = NUMBER_OF_SPEAKERS*feats

        # TDNN Encoder 1x50000 -> 256x6250
        h = self.tdnn(x)

        # NORMALIZATION and Overlapping
        x = self.lnorm(h)
        # NOTE(review): view() reinterprets the (256, 6250) buffer as 6250 rows of 256
        # values without transposing — confirm this layout is intended.
        x = self.lin0(x.view(-1,feats))
        x = x.view(feats,-1)
        x = x.unfold(dimension=1, step=self.HOP, size=self.CHUNK)# Chunking and 50% Overlap

        # SEPFORMER Block
        x = x.reshape(self.CHUNK,FOLDS,feats)
        # nn.Transformer requires a decoder target. rand_like creates it on the same
        # device and dtype as x; a bare torch.rand(...) is allocated on the CPU and
        # fails with a device mismatch once the model runs on CUDA.
        # NOTE(review): a random target makes the output non-deterministic, even in eval mode.
        y = self.tf1(x,torch.rand_like(x))
        x = y + x # Residual connection
        y = self.tf2(x,torch.rand_like(x))
        x = y + x # Residual connection

        # PRELU and Linear
        x = self.prelu(x)
        x = x.view(FOLDS,-1)
        x = self.lin1(x)
        x = F.relu(x)
        x = x.reshape(mask_channels,self.CHUNK,FOLDS)

        # OVERLAP ADD (256*#S,250,FOLDS) -> (256*#S,1,1,6250)
        x = self.fold(x)

        # FFN + ReLU, applied to rows of 256*#S values
        x = self.lin2(x.view(-1,mask_channels))
        x = self.lin3(x.view(-1,mask_channels))
        x = F.relu(x)

        # One (256, 6250) mask per speaker, multiplied with the encoder output
        x = x.view(NUMBER_OF_SPEAKERS,feats,ENCODED_TIMESTEPS)
        # DECODER
        x = self.convT(x*h)

        return x

    

# Instantiate the network and print its layer-by-layer summary for one dummy
# waveform of ENCODED_TIMESTEPS * 8 zero samples.
print(
    summary(
        TransformerMaskNet(),
        torch.zeros((1, ENCODED_TIMESTEPS * 8)),
    )
)

----------------------------------------------------------------------------
         Layer (type)          Output Shape         Param #     Tr. Param #
             Conv1d-1           [256, 6250]           4,352           4,352
          LayerNorm-2           [256, 6250]       3,200,000       3,200,000
             Linear-3           [6250, 256]          65,792          65,792
        Transformer-4        [250, 49, 256]      11,060,224      11,060,224
        Transformer-5        [250, 49, 256]      11,060,224      11,060,224
              PReLU-6        [250, 49, 256]               1               1
             Linear-7          [49, 128000]   8,192,128,000   8,192,128,000
               Fold-8     [512, 1, 1, 6250]               0               0
             Linear-9           [6250, 512]         262,656         262,656
            Linear-10           [6250, 512]         262,656         262,656
   ConvTranspose1d-11         [2, 1, 50000]           4,097           4,097
Total param

In [None]:
# Training hyperparameters.
EPOCHS = 10
BATCH_SIZE = 1
REFERENCE_CHANNEL = 0
# Learning rate handed to Adam (1.5e-4), written as a plain float literal.
INIT_LR = 15e-5
# Destination of the pickled training history and filename prefix of the model checkpoints.
PICKLE_SAVE_PATH = '/project/data_asr/CHiME5/data/librenoise/models/params.pkl'
MODEL_SAVE_PATH = '/project/data_asr/CHiME5/data/librenoise/models/TF'

CUDA = True # if torch.cuda.is_available()
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Mounted on:", device)

# Binary cross-entropy criterion used by the training and validation code.
lossBCE = BCELoss().to(device)

# Build the network on the chosen device, wrap it in DataParallel for device 0,
# and attach an Adam optimizer to its parameters.
model = torch.nn.DataParallel(TransformerMaskNet().to(device), device_ids=[0])
opt = Adam(model.parameters(), lr=INIT_LR)

# Training history: one list per metric, appended to during training and pickled at the end.
H = {metric: [] for metric in ("train_loss", "train_acc", "val_loss", "val_acc")}

def check_accuracy_training(speech_pred, y_s, threshold=0.15):
    """Return the fraction of thresholded predictions that equal the target.

    Args:
        speech_pred: Tensor of predicted values.
        y_s: Target tensor, broadcastable against ``speech_pred``.
        threshold: Predictions strictly above this value count as 1, all others as 0.

    Returns:
        A Python float in ``[0, 1]``; ``0.0`` when there are no elements to compare.
    """
    binary_pred = (speech_pred > threshold).float()
    matches = (binary_pred == y_s).float()
    if matches.numel() == 0:
        return 0.0
    # Normalise by the number of compared elements instead of a hard-coded
    # 513 x shape[1] grid, so the result is a true fraction for any tensor shape.
    return float(matches.mean())

def check_accuracy_validation(model):
    """Score ``model`` on one randomly chosen example outside the training slice.

    Reads the notebook globals ``X``, ``Y``, ``speech``, ``trainX`` and ``lossBCE``;
    ``trainX`` is created by the training cell, so that cell must have started first.

    Args:
        model: Network to evaluate. It is switched to eval mode for the forward
            pass and is always left in training mode afterwards.

    Returns:
        Tuple ``(accuracy, val_loss)``: accuracy as a Python float and ``val_loss``
        as the tensor returned by ``lossBCE``.
    """
    # Index drawn from [len(trainX), len(speech)), i.e. after the training slice.
    example_nr = int(np.random.random()*(len(speech)-len(trainX))+len(trainX))
    target = Y[example_nr][0]

    model.eval()
    try:
        # Validation needs no gradients; skipping autograd avoids building a graph.
        with torch.no_grad():
            # NOTE(review): the 513-row reshape presumably matches a 513-bin target — confirm
            # it agrees with the output of the model that is actually being trained.
            pred = model(X[example_nr]).reshape(1,513,-1)
            val_loss = lossBCE(pred,target.unsqueeze(0))
            matches = ((pred>0.15).float() == target).float()
    finally:
        # Return to training mode even if the forward pass or the loss raised.
        model.train()

    # Fraction of matching elements, normalised by the number of compared elements.
    accuracy = float(matches.mean()) if matches.numel() else 0.0
    return accuracy,val_loss

print("[INFO] training the network...")

# Move the data to the training device once and take the training split up front;
# neither step depends on the epoch.
X = X.to(device)
Y = Y.to(device)
trainX = X[:2000]  # the first 2000 examples form the training split
trainY = Y

# NOTE(review): BCELoss requires inputs in [0, 1], while TransformerMaskNet returns the
# raw ConvTranspose1d output — confirm the loss and target shapes match this model.
for epoch in range(0, EPOCHS):
    print("Epoch:",str(epoch+1)+"/"+str(EPOCHS))
    # Train Mode
    model.train()

    # Initialize (these counters are not updated anywhere in this loop)
    totalTrainLoss = 0
    totalValLoss = 0
    trainCorrect = 0
    valCorrect = 0

    for i in tqdm(range(0,len(trainX))): # Iterate over Training Examples
        (x, y) = (trainX[i],trainY[i][0].unsqueeze(0))
        speech_pred=model(x)
        loss = lossBCE(speech_pred,y)
        # zero out the gradients, perform the backpropagation step, and update the weights
        opt.zero_grad()
        loss.backward()
        opt.step()

        H["train_acc"].append(check_accuracy_training(speech_pred,y))
        H["train_loss"].append(float(loss))
        # Validate on one random held-out example every 10 iterations.
        if i % 10 == 0:
            val_acc, val_loss = check_accuracy_validation(model)
            H["val_acc"].append(val_acc)
            H["val_loss"].append(float(val_loss))
        # Report every 100 iterations, skipping iteration 0. H is never reset, so the
        # statistics cover every iteration recorded so far, across epochs.
        if i > 0 and i % 100 == 0:
            print("Average Training Accuracy at Iteration",str(i),":",np.mean(np.array(H["train_acc"])))
            print("Total Training Loss at Iteration",str(i),":",np.sum(np.array(H["train_loss"])))
            print("Average Validation Accuracy at Iteration",str(i),":",np.mean(np.array(H["val_acc"])))
            print("Total Validation Loss at Iteration",str(i),":",np.sum(np.array(H["val_loss"])))
    # Save a checkpoint after every epoch
    torch.save(model.state_dict(), MODEL_SAVE_PATH + "epoch"+ str(epoch+1) + ".pt")

# Save the final weights and the recorded training history
torch.save(model.state_dict(), MODEL_SAVE_PATH + "final" + ".pt")
with open(PICKLE_SAVE_PATH, 'wb') as f:
    pickle.dump(H, f)