In [1]:
import torch
from torchsummary import summary
import torch.nn as nn
from torch.nn import functional as F
import pandas as pd
from dataset import Dataset, collate_fn_padd
from torch.utils.data import DataLoader
%load_ext autoreload
%autoreload 2

In [2]:
def train_dataloader(json_path="data/single_batch.json", batch_size=1, num_workers=0):
    """Build the DataLoader used by the training loop below.

    Args:
        json_path: manifest file handed to the project's ``Dataset``; defaults to
            ``data/single_batch.json``.
        batch_size: number of samples per batch.
        num_workers: DataLoader worker processes (0 loads in the main process).

    Returns:
        A ``DataLoader`` over ``Dataset(json_path=json_path)`` whose batches are
        assembled by the project's ``collate_fn_padd`` (presumably padding
        variable-length items — confirm in ``dataset.py``).
    """
    train_dataset = Dataset(json_path=json_path)
    return DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn_padd,
    )

In [3]:
class Model(nn.Module):
    """CNN + bidirectional-LSTM acoustic model producing per-frame class scores.

    Two strided 2-D convolutions downsample the input spectrogram. The
    (channel x feature) axes are then flattened into one vector per time step,
    and a single-layer bidirectional LSTM runs over time. The forward and
    backward outputs are summed, and a bias-free linear layer maps each frame to
    ``num_classes`` raw scores (the callers in this notebook apply
    ``log_softmax`` themselves).

    Args:
        n_feats: size of the input's feature (height) axis. The default of 128
            yields the original LSTM input size of 1024.
        rnn_dim: LSTM hidden size per direction.
        num_classes: number of output classes (the training loop uses 29 with
            index 28 as the CTC blank).
    """

    def __init__(self, n_feats=128, rnn_dim=800, num_classes=29):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(20, 5)),
            nn.BatchNorm2d(32),
            nn.Hardtanh(0, 20, inplace=True),
            nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), padding=(10, 5)),
            nn.BatchNorm2d(32),
            nn.Hardtanh(0, 20, inplace=True),
        )
        # Both convs use stride 2 on the feature axis with padding = (k - 1) / 2,
        # so each maps h -> floor((h - 1) / 2) + 1 == ceil(h / 2).
        conv_height = ((n_feats + 1) // 2 + 1) // 2
        self.lstm = nn.LSTM(
            input_size=32 * conv_height,
            hidden_size=rnn_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        # Kept as a Sequential so state_dict keys ("fc.0.weight") are unchanged.
        self.fc = nn.Sequential(
            nn.Linear(rnn_dim, num_classes, bias=False)
        )

    def forward(self, x):
        """Map a spectrogram batch to per-frame class scores.

        Args:
            x: tensor of shape (batch, 1, n_feats, time).

        Returns:
            Tensor of shape (batch, ceil(time / 2), num_classes). Only the first
            conv strides the time axis.
        """
        x = self.conv(x)                    # (N, C, H', T')
        n, c, h, t = x.size()
        x = x.view(n, c * h, t)             # collapse channel x feature axes
        x = x.transpose(1, 2)               # (N, T', C*H'); LSTM is batch_first
        x, _ = self.lstm(x)                 # (N, T', 2 * rnn_dim)
        # Split the [forward | backward] halves and sum them: (N, T', 2, H) -> (N, T', H).
        x = x.view(n, t, 2, -1).sum(dim=2)
        return self.fc(x)                   # (N, T', num_classes)

In [4]:
# Overfit sanity check: train repeatedly on the small dataset and watch the
# CTC loss fall toward zero.
NUM_EPOCHS = 1000
LOG_EVERY = 100    # print the loss on every LOG_EVERY-th epoch
BLANK_INDEX = 28   # CTC blank; must match decode_greedy's blank_label default
SEED = 0

torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loader = train_dataloader()
# Move the model to its device *before* building the optimizer, as the
# torch.optim documentation recommends.
model = Model().to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CTCLoss(blank=BLANK_INDEX, zero_infinity=True)
final = 0.      # running sum of per-step losses
num_steps = 0   # optimizer steps taken; the mean is over steps, not epochs
for epoch in range(NUM_EPOCHS):
    for batch in loader:
        optimizer.zero_grad()
        spectrogram, labels, spec_len, label_len = batch
        spectrogram = spectrogram.to(device)
        labels = labels.to(device)
        # NOTE(review): CTCLoss expects input_lengths in *model output* frames,
        # and the conv stack halves the time axis. Confirm that the dataset's
        # spec_len already accounts for this.
        output = model(spectrogram)              # (batch, time, n_class)
        output = F.log_softmax(output, dim=-1)
        output = output.transpose(0, 1)          # (time, batch, n_class) for CTCLoss
        loss = criterion(output, labels, spec_len, label_len)
        loss.backward()
        optimizer.step()
        final += loss.detach()                   # stays on device; avoids a sync per step
        num_steps += 1
        if epoch % LOG_EVERY == 0:
            print("epoch {} loss: {} ".format(epoch, loss.item()))
print(float(final) / max(num_steps, 1))

loading json data from file  data/single_batch.json
epoch 0 loss: 6.788671493530273 
epoch 100 loss: 0.17408563196659088 
epoch 200 loss: 0.006847347132861614 
epoch 300 loss: 0.0027056897524744272 


KeyboardInterrupt: 

In [21]:
from dataset import MelSpectrogram, TextProcess
from train import SpeechModule
import torch
import torchaudio
from torch.nn import functional as F
from configs.train_config import SpectConfig

class PredictorModule():
    """Greedy CTC transcription of a single audio file with a trained acoustic model.

    Args:
        model: module mapping a (1, channel, feature, time) spectrogram batch to
            per-frame scores of shape (1, time, num_classes), e.g. the ``Model``
            trained above. Must be set before calling ``predict``.
    """

    def __init__(self, model=None):
        # A checkpoint could be used instead, e.g.
        # SpeechModule.load_from_checkpoint(ckpt_path, num_cnn_layers=2, num_rnn_layers=1,
        #                                   rnn_dim=1024, num_classes=29, n_feats=128)
        self.model = model
        self.text_process = TextProcess()
        self.audio_transform = MelSpectrogram(audio_conf=SpectConfig)

    def decode_greedy(self, output, blank_label=28, collapse_repeated=True):
        """Best-path CTC decoding of one utterance.

        Args:
            output: scores of shape (time, 1, num_classes). Only a batch size of 1
                is supported, because the batch axis is squeezed away.
            blank_label: class index of the CTC blank; blanks are dropped.
            collapse_repeated: if True, a label equal to the previous frame's
                label is skipped. A blank between two equal labels keeps both.

        Returns:
            List of int class indices.

        Raises:
            ValueError: if ``output`` holds more than one utterance.
        """
        best_path = torch.argmax(output, dim=2).squeeze(1)
        if best_path.dim() != 1:
            raise ValueError("decode_greedy expects output of shape (time, 1, num_classes)")
        decoded = []
        previous = None
        # A plain list avoids a tensor comparison per frame.
        for label in best_path.tolist():
            if label != blank_label and not (collapse_repeated and label == previous):
                decoded.append(label)
            previous = label
        return decoded

    def _model_device(self):
        """Return the device of the model's first parameter (CPU if it has none)."""
        first_param = next(iter(self.model.parameters()), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def predict(self, file_path):
        """Transcribe the audio file at ``file_path`` and return the decoded text."""
        self.model.eval()
        # NOTE(review): the file's sample rate is ignored; this assumes it matches
        # what the MelSpectrogram transform was configured for.
        waveform, _ = torchaudio.load(file_path)
        spectrogram = self.audio_transform(waveform)  # (channel, feature, time)
        # Feed the input on the model's device (the training cell moves it to
        # CUDA when available). No autograd graph is needed for inference.
        with torch.no_grad():
            output = self.model(spectrogram.unsqueeze(0).to(self._model_device()))
            output = F.log_softmax(output, dim=2)
        output = output.transpose(0, 1)  # (time, batch, n_class)
        res = self.decode_greedy(output)
        return self.text_process.int_to_text(res)

if __name__ == '__main__':
    # Smoke test: greedy-decode one clip with the in-memory `model` and print
    # the reference transcript underneath for a side-by-side comparison.
    clip_path = "data/wav_clips/de1323d4adbeb3df830bb4ac10e84b0407f28dcfeea1786d34ba36b0bd84f6333dc4c519238cb5b48459df722a6ff0cd650a19ed63687a3a0c0c7b685ee9c2c1.wav"
    reference_text = "I could hear a man crying out in pain in the dentist's office."
    transcriber = PredictorModule(model)
    hypothesis = transcriber.predict(clip_path)
    print(hypothesis)
    print(reference_text)

i could hear a man crying out in pain in the dentists office
I could hear a man crying out in pain in the dentist's office.
