In [None]:
def train(lang_pairs, w2i_source, w2i_target, i2w_source, i2w_target, encoder, decoder, epochs:int, batch_size:int=1,
          learning_rate:float=1e-3, max_ratio:float=0.95, min_ratio:float=0.15, detailed_analysis:bool=True):
    """Train an encoder/decoder pair on mini-batches with a teacher-forcing curriculum.

    For every epoch the pairs are shuffled, split into mini-batches with
    ``create_batches`` and tensorised with ``prepare_batch``. Each mini-batch is
    encoded, decoded step by step, scored with ``maskNLLLoss`` and followed by
    one optimizer step for the encoder and one for the decoder.

    Per mini-batch a single random draw decides whether the whole batch is
    decoded with teacher forcing (next input = target token) or
    autoregressively (next input = the decoder's own argmax prediction). The
    teacher-forcing probability starts at ``max_ratio`` and is lowered by
    ``(max_ratio - min_ratio) / epochs`` after every epoch.

    Args:
        lang_pairs: Sequence of (source, target) pairs. NOTE: it is shuffled
            in place at the start of every epoch.
        w2i_source: Unused here; kept so existing callers keep working.
        w2i_target: Unused here; kept so existing callers keep working.
        i2w_source: Index -> token lookup for the source side (only used for
            the diagnostic print-out).
        i2w_target: Index -> token lookup for the target side (only used for
            the diagnostic print-out).
        encoder: Module exposing ``init_hidden()`` and callable as
            ``encoder(commands_batch, inp_lengths, hidden)``.
        decoder: Module exposing ``n_layers`` and callable as
            ``decoder(decoder_input, decoder_hidden)``.
        epochs: Number of passes over ``lang_pairs``.
        batch_size: Mini-batch size handed to ``create_batches``.
        learning_rate: Learning rate of both Adam optimizers.
        max_ratio: Teacher-forcing probability used in the first epoch.
        min_ratio: Value the teacher-forcing probability is stepped towards.
        detailed_analysis: If true, print a sample translation every
            ``n_iters`` mini-batches.

    Returns:
        ``(train_losses, train_accs)``: two lists with one entry per epoch.
        The loss is the mean masked loss per target token; the accuracy is
        the fraction of sequences whose unmasked predictions all match the
        target (exact match).
    """
    # every n_iters mini-batches, print a sample translation (if detailed_analysis)
    n_iters = 3000
    # Index fed to the decoder at the first step. The original code hard-codes 1;
    # presumably this is the <SOS> index -- TODO confirm against the target vocabulary.
    sos_idx = 1

    train_losses, train_accs = [], []
    encoder_optimizer = Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = Adam(decoder.parameters(), lr=learning_rate)

    training_pairs = lang_pairs

    # teacher forcing curriculum: lower the ratio by the same amount after every epoch
    # (guard against epochs == 0, which would otherwise divide by zero)
    step_per_epoch = (max_ratio - min_ratio) / epochs if epochs > 0 else 0.0
    teacher_forcing_ratio = max_ratio

    for epoch in trange(epochs,  desc="Epoch"):

        weighted_loss_sum = 0.0  # sum over steps of (masked mean loss * number of unmasked tokens)
        n_totals = 0             # number of unmasked target tokens seen this epoch
        n_exact = 0              # number of sequences predicted without any error
        n_seqs_seen = 0          # number of sequences processed this epoch
        np.random.shuffle(training_pairs)

        batch_iterator = create_batches(training_pairs, batch_size)

        for idx, mini_batch in enumerate(batch_iterator):

            commands_batch, actions_batch, inp_lengths, mask, max_target_length = prepare_batch(mini_batch)

            commands_batch = commands_batch.to(device)
            actions_batch = actions_batch.to(device)
            inp_lengths = inp_lengths.to(device)
            mask = mask.to(device)

            # Real number of sequences in this batch (the last batch may be smaller
            # than batch_size). Assumes actions_batch is time-major, i.e.
            # actions_batch[t] holds step t of every sequence -- consistent with the
            # actions_batch[i] / mask[i] indexing below; TODO confirm in prepare_batch.
            n_seqs = actions_batch[0].numel()

            encoder_hidden = encoder.init_hidden()

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            encoder_outputs, encoder_hidden = encoder(commands_batch, inp_lengths, encoder_hidden)

            decoder_input = torch.full((1, n_seqs), sos_idx, dtype=torch.long, device=device)

            # Set initial decoder hidden state to the encoder's final hidden state
            decoder_hidden = encoder_hidden[:decoder.n_layers]

            # one draw per mini-batch: the whole batch is either teacher-forced or not
            use_teacher_forcing = random.random() < teacher_forcing_ratio

            loss = 0
            step_preds = []

            # step 0 of the target is skipped (it is the start token fed as first input)
            for i in range(1, max_target_length):
                decoder_out, decoder_hidden = decoder(decoder_input, decoder_hidden)
                # most likely token per sequence; assumes the vocabulary is the last axis
                pred = torch.argmax(decoder_out, dim=-1).view(-1)
                step_preds.append(pred)

                # Calculate and accumulate loss (padding positions are masked out)
                mask_loss, nTotal = maskNLLLoss(decoder_out, actions_batch[i], mask[i])
                loss += mask_loss
                weighted_loss_sum += mask_loss.item() * nTotal
                n_totals += nTotal

                if use_teacher_forcing:
                    # Teacher forcing: next input is current target
                    decoder_input = actions_batch[i].view(1, -1)
                else:
                    # No teacher forcing: next input is decoder's own current output
                    decoder_input = pred.view(1, -1)

            if not step_preds:
                # nothing to predict for this batch, hence nothing to back-propagate
                continue

            loss.backward()
            # update the weights after every mini-batch, right after its backward pass
            encoder_optimizer.step()
            decoder_optimizer.step()

            # exact-match accuracy: a sequence counts as correct when every
            # unmasked position was predicted correctly
            preds = torch.stack(step_preds)  # (steps, n_seqs)
            targets = actions_batch[1:max_target_length].reshape(preds.shape)
            valid = mask[1:max_target_length].reshape(preds.shape).bool()
            seq_match = ((preds == targets) | ~valid).all(dim=0)
            n_exact += int(seq_match.sum().item())
            n_seqs_seen += n_seqs

            ### Inspect translation behaviour ###
            if detailed_analysis and idx > 0 and idx % n_iters == 0:
                # show the first sequence of the current mini-batch
                first_valid = valid[:, 0]
                true_sent = ' '.join(i2w_target[tok.item()] for tok in targets[:, 0][first_valid]).strip()
                pred_sent = ' '.join(i2w_target[tok.item()] for tok in preds[:, 0][first_valid]).strip()
                # assumes commands_batch is time-major like actions_batch and that
                # inp_lengths[0] is the unpadded length of the first command -- TODO confirm
                command = commands_batch.reshape(-1, n_seqs)[:int(inp_lengths[0]), 0]
                nl_command = ' '.join(i2w_source[cmd.item()] for cmd in command).strip()

                print("Loss: {}".format(loss.item() / len(step_preds))) # mean masked loss per decoding step of this batch
                print("Acc: {}".format(n_exact / n_seqs_seen)) # running exact-match accuracy of this epoch
                print()
                print("Command: {}".format(nl_command))
                print("True action: {}".format(true_sent))
                print("Pred action: {}".format(pred_sent))
                print()
                print("True sent length: {}".format(len(true_sent.split())))
                print("Pred sent length: {}".format(len(pred_sent.split())))
                print()

        loss_per_epoch = weighted_loss_sum / n_totals if n_totals else 0.0
        acc_per_epoch = n_exact / n_seqs_seen if n_seqs_seen else 0.0

        print("Train loss: {}".format(loss_per_epoch)) # loss
        print("Train acc: {}".format(acc_per_epoch)) # exact-match accuracy
        print("Current teacher forcing ratio {}".format(teacher_forcing_ratio))

        train_losses.append(loss_per_epoch)
        train_accs.append(acc_per_epoch)

        teacher_forcing_ratio -= step_per_epoch # decrease teacher forcing ratio

    return train_losses, train_accs