In [None]:
def train(self, input_tensor, ground_truth, loss_mask, target_lengths):
    """Run a single CTC training step on one batch and update both networks.

    Args:
        input_tensor: Batch of input features, expected as (B, T, F)
            B: batch size, T: temporal length, F: number of frequency bands
        ground_truth: Padded target label indices, expected as (B, S)
            S: max length among the true sentences in the batch
        loss_mask: Unused here; kept so the call signature stays unchanged.
        target_lengths: Per-sample target lengths, forwarded to the CTC loss.

    Returns:
        A tuple (pred_tensor, loss_value). pred_tensor is the decoder output
        permuted to time-major order (T, B, C). loss_value is the Python
        float loss.item() / ground_truth.shape[1].
    """

    batch_size = input_tensor.shape[0]

    # Clear gradients left over from the previous step
    self.encoder_optimizer.zero_grad()
    self.decoder_optimizer.zero_grad()

    # (B, T, F) -> (B, T, H)
    encoded_tensor = self.encoder(input_tensor)

    # (B, T, H) -> (B, T, C)
    pred_tensor = self.decoder(encoded_tensor)

    # CTC loss takes integer (long) targets. Cast on the same device as the
    # predictions instead of hard-coding torch.cuda.LongTensor, so this step
    # also runs on a CPU-only machine.
    truth = ground_truth.to(device=pred_tensor.device, dtype=torch.long)

    # CTC loss function takes a tensor of the form (T, B, C)
    # Permute swaps the batch and time axes: B <-> T
    pred_tensor = pred_tensor.permute(1, 0, 2)

    # CTC loss needs the length of every input sequence; each sample is
    # given the full temporal length T of the prediction tensor
    input_lengths = torch.full(size=(batch_size,), fill_value=pred_tensor.shape[0], dtype=torch.long)

    # Calculate CTC loss
    loss = self.ctc_loss(pred_tensor, truth, input_lengths, target_lengths)

    # Backpropagate
    loss.backward()

    # Update weights
    self.encoder_optimizer.step()
    self.decoder_optimizer.step()

    # Return the loss divided by the padded target length S
    # (the original note assumes the loss is a sum of the character losses)

    return pred_tensor, loss.item() / ground_truth.shape[1]


def test(self, input_tensor, ground_truth, loss_mask, target_lengths):

    # Shape of the input tensor (B, T, F)
    # B: Number of a batch (8, 16, or 64 ...)
    # T: Temporal length of an input
    # F: Number of frequency band, 80

    batch_size = input_tensor.shape[0]

    # (B, T, F) -> (B, T, H)
    encoded_tensor = self.encoder(input_tensor)

    # (B, T, H) -> (B, T, 75)
    pred_tensor = self.decoder(encoded_tensor)

    # Cast true sentence as Long data type, since CTC loss takes long tensor only
    # Shape (B, S)
    # S: Max length among true sentences 
    truth = ground_truth
    truth = truth.type(torch.cuda.LongTensor)

    # CTC loss function takes tensor of the form (T, B, 75)
    # Permute function changes axes of a tensor T <-> B
    pred_tensor = pred_tensor.permute(1, 0, 2)

    # CTC loss need to know the lenght of the true sentence
    input_lengths = torch.full(size=(batch_size,), fill_value=pred_tensor.shape[0], dtype=torch.long)

    # Calculate CTC Loss
    loss = self.ctc_loss(pred_tensor, truth, input_lengths, target_lengths)

    # Return loss divided by true length because loss is sum of the character losses

    return pred_tensor, loss.item() / ground_truth.shape[1]

def save(self, check_point_name):
    """Write the encoder/decoder and optimizer state dicts to a checkpoint.

    Args:
        check_point_name: Destination handed to torch.save.
    """
    # Checkpoint key -> object whose state_dict() is stored under that key
    components = (
        ('encoder_state_dict', self.encoder),
        ('decoder_state_dict', self.decoder),
        ('encoder_optimizer_state_dict', self.encoder_optimizer),
        ('decoder_optimizer_state_dict', self.decoder_optimizer),
    )
    checkpoint = {key: component.state_dict() for key, component in components}
    torch.save(checkpoint, check_point_name)

def load(self, check_point_name):
    """Restore the encoder/decoder and optimizer states from a checkpoint.

    The checkpoint must contain the keys 'encoder_state_dict',
    'decoder_state_dict', 'encoder_optimizer_state_dict' and
    'decoder_optimizer_state_dict'.

    Args:
        check_point_name: Checkpoint source handed to torch.load.
    """
    # Deserialize onto the CPU first: without map_location, a checkpoint that
    # was saved from GPU tensors cannot be loaded on a machine without CUDA.
    # load_state_dict() then copies the values onto whatever device the
    # existing parameters live on.
    checkpoint = torch.load(check_point_name, map_location='cpu')
    self.encoder.load_state_dict(checkpoint['encoder_state_dict'])
    self.decoder.load_state_dict(checkpoint['decoder_state_dict'])
    self.encoder_optimizer.load_state_dict(checkpoint['encoder_optimizer_state_dict'])
    self.decoder_optimizer.load_state_dict(checkpoint['decoder_optimizer_state_dict'])

def set_mode(self, mode):
    """Switch the encoder and decoder between training and evaluation mode.

    Args:
        mode: 'train' calls .train() on both networks, 'eval' calls .eval()
            on both. Any other value only prints a message and changes
            nothing.
    """

    # Call this with 'train' after loading if training should continue
    if mode == 'train':
        for network in (self.encoder, self.decoder):
            network.train()
        return

    # Call this with 'eval' after loading so dropout layers are not applied
    if mode == 'eval':
        for network in (self.encoder, self.decoder):
            network.eval()
        return

    print("Invalid mode string: {}".format(mode))