In [1]:
from ipywidgets import interact, interactive, fixed, interact_manual, widgets

import torch
import numpy as np
import importlib

from fairseq import bleu, data, options, progress_bar, tasks, tokenizer, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_generator import SequenceGenerator
from fairseq.sequence_scorer import SequenceScorer
from bokeh.io import output_notebook, show
import networkx as nx
from bokeh.models import Plot, Range1d, MultiLine, Circle, HoverTool, TapTool, BoxSelectTool
from bokeh.models.graphs import from_networkx, NodesAndLinkedEdges, EdgesAndLinkedNodes
from bokeh.palettes import Spectral4
import math
from itertools import islice,zip_longest
from sacrebleu import sentence_bleu, corpus_bleu
from tqdm import tqdm_notebook
import pandas as pd

from bokeh.io import show, output_file
from bokeh.plotting import figure
from bokeh.models import GraphRenderer, StaticLayoutProvider, Oval, MultiLine
from bokeh.models import Plot, Range1d, MultiLine, Circle, HoverTool,\
    TapTool, BoxSelectTool,PanTool,BoxZoomTool
from bokeh.models import ColumnDataSource, Range1d, LabelSet, Label, AdaptiveTicker
from bokeh.palettes import Spectral8
from bokeh.models import SingleIntervalTicker, LinearAxis


# from interactive import *
from fairseq.models import FairseqIncrementalDecoder

from subword_nmt.apply_bpe import BPE
from sacremoses import MosesTokenizer

import codecs   
from collections import namedtuple


output_notebook()

In [2]:
import os

In [3]:
# Directory that holds the ensemble checkpoints.
MODEL_DIR = 'wmt19.en-ru.ensemble'
CHECKPOINT_NAMES = [
    'model1.pt',
    'model2.pt',
    'model3.pt',
    'model4.pt',
]
CHECKPOINT_PATHS = [os.path.join(MODEL_DIR, name) for name in CHECKPOINT_NAMES]

# Colon-separated checkpoint list; it is passed as the value of `--path`
# when the generation arguments are built.
MODEL_PATH = ':'.join(CHECKPOINT_PATHS)
BINARY_DATA_PATH = 'data-bin/wmt17_en_ru/'
# BPECODES_PATH = 'wmt19.en-ru.ensemble/codes'
BPECODES_PATH = 'data/wmt17_en_ru/code'  # otherwise the output comes out garbled
# Generation options are kept as strings because they are handed to an
# argparse-style argument list.
BEAM = '5'
LENPEN = '0.6'
DIVERSE_BEAM_STRENGTH = '0'
SHARED_BPE = True
SRS = "en"
TGT = "ru"

# Per-language Moses tokenizers and BPE models, keyed by language code.
# With a shared BPE model only the source-language entry is created, so
# `prepare_input` can then only be called with l=SRS.
tkn = {}
bpe = {}
for l in ([SRS] if SHARED_BPE else [SRS, TGT]):
    # The encoding is given explicitly: without it codecs.open() falls back
    # to the platform's default text encoding, which may not be UTF-8.
    with codecs.open(BPECODES_PATH, encoding='utf-8') as src_codes:
        tkn[l] = MosesTokenizer(lang=l)
        bpe[l] = BPE(src_codes)

def prepare_input(s, l='en'):
    """Tokenize and BPE-encode one sentence for language ``l``.

    The sentence is first tokenized with ``tkn[l]`` (returned as a single
    string) and then passed through ``bpe[l].process_line``.

    Args:
        s: raw input sentence.
        l: language code used to look up the tokenizer and BPE model.

    Returns:
        A one-element list containing the processed sentence.
    """
    tokenized = tkn[l].tokenize(s, return_str=True)
    segmented = bpe[l].process_line(tokenized)
    return [segmented]

In [4]:
parser = options.get_generation_parser(interactive=True)

# Every element of an argparse argument list must be a string; the length
# penalty used to be passed as the int 0, which is only tolerated by accident.
# NOTE(review): the LENPEN constant from the configuration cell is not used
# here -- the length penalty is fixed to 0. Confirm which value is intended.
args = options.parse_args_and_arch(parser, input_args=[
    BINARY_DATA_PATH,
    '--path', MODEL_PATH,
    '--diverse-beam-strength', DIVERSE_BEAM_STRENGTH,
    '--lenpen', '0',
    '--remove-bpe',
    '--beam', BEAM
])

# Tensors are only moved to the GPU (in process_batch) when this is True.
use_cuda = False
task = tasks.setup_task(args)
model_paths = args.path.split(':')
# NOTE(review): eval() executes whatever string is stored in
# args.model_overrides; only safe while that option value is trusted.
models, model_args = utils.load_ensemble_for_inference(
        model_paths,
        task,
        model_arg_overrides=eval(args.model_overrides)
)
tgt_dict = task.target_dictionary


# Put every ensemble member into its generation-optimised configuration and,
# if requested on the command line, convert it to half precision.
for model in models:
    model.make_generation_fast_(
        beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
        need_attn=args.print_alignment,
    )
    if args.fp16:
        model.half()

| [en] dictionary: 31640 types
| [ru] dictionary: 31232 types




In [5]:
# One mini-batch as yielded by make_batches: the batch's sentence ids plus the
# source tokens and source lengths taken from its 'net_input' entry.
Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
# Formatted output for one source sentence as built by make_result: the source
# line and parallel lists of hypothesis, positional-score and alignment lines.
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')


def make_batches(lines, args, task, max_positions):
    """Encode source lines and yield them as inference mini-batches.

    Each line is encoded with the task's source dictionary (without adding
    new entries), wrapped in the task's inference dataset and iterated once,
    unshuffled.

    Args:
        lines: iterable of source strings.
        args: parsed options; ``max_tokens``, ``max_sentences`` and
            ``skip_invalid_size_inputs_valid_test`` are read from it.
        task: task object providing the dictionary, dataset and iterator.
        max_positions: position limit forwarded to the batch iterator.

    Yields:
        Batch: ``ids``, ``src_tokens`` and ``src_lengths`` of each mini-batch.
    """
    tokens = []
    for src_str in lines:
        encoded = task.source_dictionary.encode_line(src_str, add_if_not_exist=False)
        tokens.append(encoded.long())
    lengths = [tok.numel() for tok in tokens]

    inference_dataset = task.build_dataset_for_inference(tokens, lengths)
    epoch_iterator = task.get_batch_iterator(
        dataset=inference_dataset,
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
    )

    for batch in epoch_iterator.next_epoch_itr(shuffle=False):
        net_input = batch['net_input']
        yield Batch(
            ids=batch['id'],
            src_tokens=net_input['src_tokens'],
            src_lengths=net_input['src_lengths'],
        )



In [6]:
# Load alignment dictionary for unknown word replacement
# (None if no unknown word replacement, empty if no path to align dictionary).
# make_result passes it on to utils.post_process_prediction.
align_dict = utils.load_align_dict(args.replace_unk)

def make_result(src_str, hypos):
    """Format the best hypotheses of one sentence as tab-separated lines.

    At most ``args.nbest`` hypotheses are post-processed with
    ``utils.post_process_prediction`` and rendered as ``H`` (score and text),
    ``P`` (positional scores) and, when ``args.print_alignment`` is set,
    ``A`` (alignment) lines.

    Args:
        src_str: the source sentence.
        hypos: list of hypothesis dicts with the keys ``tokens``, ``score``,
            ``alignment`` and ``positional_scores``.

    Returns:
        Translation: the ``O`` source line plus three parallel lists.
    """
    hypo_lines = []
    score_lines = []
    alignment_lines = []

    # Process top predictions
    for hypo in hypos[:min(len(hypos), args.nbest)]:
        tokens_cpu = hypo['tokens'].int().cpu()
        if hypo['alignment'] is not None:
            alignment_cpu = hypo['alignment'].int().cpu()
        else:
            alignment_cpu = None

        hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
            hypo_tokens=tokens_cpu,
            src_str=src_str,
            alignment=alignment_cpu,
            align_dict=align_dict,
            tgt_dict=tgt_dict,
            remove_bpe=args.remove_bpe,
        )

        hypo_lines.append('H\t{}\t{}'.format(hypo['score'], hypo_str))

        formatted_scores = ['{:.4f}'.format(x) for x in hypo['positional_scores'].tolist()]
        score_lines.append('P\t{}'.format(' '.join(formatted_scores)))

        if args.print_alignment:
            alignment_lines.append(
                'A\t{}'.format(' '.join(str(utils.item(x)) for x in alignment))
            )
        else:
            alignment_lines.append(None)

    return Translation(
        src_str='O\t{}'.format(src_str),
        hypos=hypo_lines,
        pos_scores=score_lines,
        alignments=alignment_lines,
    )

def process_batch(batch, lines=None):
    """Translate one batch and format the hypotheses of every sentence.

    The ``Batch`` tuple declared in this notebook has the fields ``ids``,
    ``src_tokens`` and ``src_lengths``; the function previously read
    ``tokens``, ``lengths`` and ``srcs`` and therefore failed on it. Both
    layouts are accepted now.

    Args:
        batch: a ``Batch`` from ``make_batches`` or an object with the older
            ``srcs``/``tokens``/``lengths`` fields.
        lines: optional sequence of source strings indexed by sentence id.
            Needed when ``batch`` has no ``srcs`` field, because the source
            text is then looked up through ``batch.ids``.

    Returns:
        list: one ``Translation`` per entry returned by
        ``translator.generate``.

    Raises:
        ValueError: if the source strings can be found neither on ``batch``
            nor in ``lines``.
    """
    has_src_fields = hasattr(batch, 'src_tokens')
    tokens = batch.src_tokens if has_src_fields else batch.tokens
    lengths = batch.src_lengths if has_src_fields else batch.lengths

    if hasattr(batch, 'srcs'):
        srcs = batch.srcs
    elif lines is not None:
        # assumes batch.ids is a tensor of indices into `lines` -- TODO confirm
        srcs = [lines[i] for i in batch.ids.tolist()]
    else:
        raise ValueError(
            'source strings are unavailable: pass `lines` or a batch with a `srcs` field'
        )

    if use_cuda:
        tokens = tokens.cuda()
        lengths = lengths.cuda()

    encoder_input = {'src_tokens': tokens, 'src_lengths': lengths}
    # NOTE(review): `translator` is not defined in the visible part of this
    # notebook; it has to exist before this function is called.
    translations = translator.generate(
        encoder_input,
        maxlen=int(args.max_len_a * tokens.size(1) + args.max_len_b),
    )

    return [make_result(srcs[i], t) for i, t in enumerate(translations)]

# Combine the position limits reported by the task with those of every loaded
# model into one value (handed to make_batches as `max_positions`).
max_positions = utils.resolve_max_positions(
    task.max_positions(),
    *[model.max_positions() for model in models]
)

In [93]:
import tmp_e

In [94]:
# Re-execute the tmp_e module so that edits made to its file are picked up
# without restarting the kernel.
importlib.reload(tmp_e)

<module 'tmp_e' from '/home/dkuznetsov/notebook/course_paper/experiments/beam_width/tmp_e.py'>

def shannon_entropy(pk, dim=None):
    """Shannon entropy (natural log) of the probabilities in ``pk``.

    Entries that are not strictly positive contribute 0 to the sum. Without
    this, a single zero probability (``0 * log(0)``) turns the whole result
    into NaN.

    Args:
        pk (torch.Tensor): probabilities.
        dim (int, optional): dimension to reduce over. When ``None`` the sum
            runs over all elements and a scalar tensor is returned.

    Returns:
        torch.Tensor: ``-sum(pk * log(pk))`` over ``dim`` (or over everything).
    """
    plogp = torch.where(pk > 0, pk * torch.log(pk), torch.zeros_like(pk))
    if dim is None:
        return -torch.sum(plogp)

    return -torch.sum(plogp, dim=dim)


class EnsembleModel(torch.nn.Module):
    """A wrapper around an ensemble of models.

    In addition to the averaged next-token log-probabilities,
    ``forward_decoder`` can return variance/entropy statistics of the
    individual models and of the averaged distribution.
    """

    def __init__(self, models):
        """Store ``models`` and set up per-model incremental decoding state."""
        super().__init__()
        self.models = torch.nn.ModuleList(models)
        # One cache dict per model; stays None unless every model has an
        # incremental decoder.
        self.incremental_states = None
        if all(hasattr(m, 'decoder') and isinstance(m.decoder, FairseqIncrementalDecoder) for m in models):
            self.incremental_states = {m: {} for m in models}

    def has_encoder(self):
        """Return True if the first model has an ``encoder`` attribute."""
        return hasattr(self.models[0], 'encoder')

    def max_decoder_positions(self):
        """Return the smallest ``max_decoder_positions()`` over all models."""
        return min(m.max_decoder_positions() for m in self.models)

    @torch.no_grad()
    def forward_encoder(self, encoder_input):
        """Run every model's encoder on ``encoder_input``.

        Returns:
            A list with one encoder output per model, or None when the
            models have no encoder.
        """
        if not self.has_encoder():
            return None
        return [model.encoder(**encoder_input) for model in self.models]

    @torch.no_grad()
    def forward_decoder(self, tokens, encoder_outs, temperature=1., with_var=True):
        """Decode one step with every model and combine the distributions.

        Args:
            tokens: target prefix passed to each model's decoder.
            encoder_outs: list of encoder outputs, one per model.
            temperature: the last-step decoder output is divided by this
                value before normalisation.
            with_var: selects the return layout (see below).

        Returns:
            If ``with_var`` is true, an 8-tuple::

                (log_probs, attn, probs_mean, probs_var,
                 sing_var, sing_entropy, ensemble_var, ensemble_entropy)

            where ``log_probs`` is the log of the model-averaged
            probabilities, ``probs_mean``/``probs_var`` are the mean and
            variance of the probabilities across models, ``sing_var`` /
            ``sing_entropy`` are the per-model variance and entropy stacked
            along a leading model dimension, and ``ensemble_var`` /
            ``ensemble_entropy`` are the variance and entropy of the averaged
            distribution over its last dimension.

            Otherwise a 3-tuple ``(log_probs, attn, probs_var)``; with a
            single model ``probs_var`` is None.
        """
        if len(self.models) == 1:
            probs, attn, pvars, pentropy = self._decode_one(
                tokens,
                self.models[0],
                encoder_outs[0] if self.has_encoder() else None,
                self.incremental_states,
                log_probs=True,
                temperature=temperature,
                with_var=True
            )
            if with_var:
                # With a single model the "ensemble" is that model: the mean
                # is its probability distribution, the across-model variance
                # is zero, and the ensemble statistics equal the model's own.
                # (Previously the mean was all zeros and both ensemble
                # statistics were the variance of the log-probabilities.)
                return (
                    probs,
                    attn,
                    torch.exp(probs),
                    torch.zeros_like(probs),
                    torch.stack([pvars], dim=0),
                    torch.stack([pentropy], dim=0),
                    pvars,
                    pentropy,
                )
            return probs, attn, None

        log_probs = []
        sing_var = []
        sing_entropy = []
        avg_attn = None
        for model, encoder_out in zip(self.models, encoder_outs):
            probs, attn, pvars, pentropy = self._decode_one(
                tokens,
                model,
                encoder_out,
                self.incremental_states,
                log_probs=True,
                temperature=temperature,
                with_var=True
            )
            log_probs.append(probs)
            sing_var.append(pvars)
            sing_entropy.append(pentropy)
            if attn is not None:
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)

        stacked_log_probs = torch.stack(log_probs, dim=0)
        # log of the arithmetic mean of the per-model probabilities
        avg_probs = torch.logsumexp(stacked_log_probs, dim=0) - math.log(len(self.models))

        e_avg_probs = torch.exp(avg_probs)
        ensemble_var = e_avg_probs.var(-1)
        ensemble_entropy = shannon_entropy(e_avg_probs, -1)

        # statistics across the model dimension
        e_probs = torch.exp(stacked_log_probs)
        probs_mean = e_probs.mean(dim=0)
        probs_var = e_probs.var(dim=0)

        sing_var = torch.stack(sing_var, dim=0)
        sing_entropy = torch.stack(sing_entropy, dim=0)
        if avg_attn is not None:
            avg_attn.div_(len(self.models))
        if with_var:
            return avg_probs, avg_attn, probs_mean, probs_var, sing_var, sing_entropy, ensemble_var, ensemble_entropy
        return avg_probs, avg_attn, probs_var

    def _decode_one(
        self, tokens, model, encoder_out, incremental_states, log_probs,
        temperature=1.,
        with_var=False
    ):
        """Run one model's decoder and normalise its last-step output.

        Args:
            tokens: target prefix.
            model: the model to decode with.
            encoder_out: that model's encoder output.
            incremental_states: unused; ``self.incremental_states`` is read
                instead.
            log_probs: passed to ``model.get_normalized_probs``.
            temperature: divisor applied to the last-step decoder output.
            with_var: also return variance and entropy of the distribution.

        Returns:
            ``(probs, attn)``, or ``(probs, attn, var, entropy)`` when
            ``with_var`` is true; ``var`` and ``entropy`` are computed over
            dimension 1 of the probabilities.
        """
        if self.incremental_states is not None:
            decoder_out = list(model.forward_decoder(
                tokens, encoder_out=encoder_out, incremental_state=self.incremental_states[model],
            ))
        else:
            decoder_out = list(model.forward_decoder(tokens, encoder_out=encoder_out))
        # keep only the most recent position
        decoder_out[0] = decoder_out[0][:, -1:, :]
        if temperature != 1.:
            decoder_out[0].div_(temperature)
        attn = decoder_out[1] if len(decoder_out) > 1 else None
        if isinstance(attn, dict):
            attn = attn.get('attn', None)
        if isinstance(attn, list):
            attn = attn[0]
        if attn is not None:
            attn = attn[:, -1, :]
        probs = model.get_normalized_probs(decoder_out, log_probs=log_probs)
        probs = probs[:, -1, :]
        if not with_var:
            return probs, attn
        # Only exponentiate when `probs` really holds log-probabilities.
        e_probs = torch.exp(probs) if log_probs else probs
        return probs, attn, e_probs.var(dim=1), shannon_entropy(e_probs, 1)

    def reorder_encoder_out(self, encoder_outs, new_order):
        """Reorder each model's encoder output according to ``new_order``.

        Returns None when the models have no encoder.
        """
        if not self.has_encoder():
            return
        return [
            model.encoder.reorder_encoder_out(encoder_out, new_order)
            for model, encoder_out in zip(self.models, encoder_outs)
        ]

    def reorder_incremental_state(self, new_order):
        """Reorder each model's incremental decoder state, if any is kept."""
        if self.incremental_states is None:
            return
        for model in self.models:
            model.decoder.reorder_incremental_state(self.incremental_states[model], new_order)


class SourceSequenceGenerator(SequenceGenerator):
    @torch.no_grad()
    def generate(self, models, sample, **kwargs):
        """Translate one batch with an ensemble built from ``models``.

        The models are wrapped in an ``EnsembleModel`` and the search itself
        is delegated to ``_generate``; all keyword arguments are forwarded
        unchanged.

        Args:
            models (List[~fairseq.models.FairseqModel]): ensemble of models
            sample (dict): batch
            prefix_tokens (torch.LongTensor, optional): force decoder to begin
                with these tokens
            bos_token (int, optional): beginning of sentence token
                (default: self.eos)

        Returns:
            Whatever ``_generate`` returns for this batch.
        """
        ensemble = EnsembleModel(models)
        outputs = self._generate(ensemble, sample, **kwargs)
        return outputs

    @torch.no_grad()
    def _generate(
        self,
        model,
        sample,
        prefix_tokens=None,
        bos_token=None,
        with_var=False,
        **kwargs
    ):
        if not self.retain_dropout:
            model.eval()

        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample['net_input'].items()
            if k != 'prev_output_tokens'
        }
        
        return_all_tokens = False
        if 'return_all_tokens' in kwargs:
            return_all_tokens = kwargs['return_all_tokens']
            
        src_tokens = encoder_input['src_tokens']
        src_lengths = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1)
        input_size = src_tokens.size()
        # batch dimension goes first followed by source lengths
        models_num = len(model.models)
        bsz = input_size[0]
        src_len = input_size[1]
        beam_size = self.beam_size

        if self.match_source_len:
            max_len = src_lengths.max().item()
        else:
            max_len = min(
                int(self.max_len_a * src_len + self.max_len_b),
                # exclude the EOS marker
                model.max_decoder_positions() - 1,
            )
        assert self.min_len <= max_len, 'min_len cannot be larger than max_len, please adjust these!'

        # compute the encoder output for each beam
        encoder_outs = model.forward_encoder(encoder_input)
        new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
        new_order = new_order.to(src_tokens.device).long()
        encoder_outs = model.reorder_encoder_out(encoder_outs, new_order)

        # initialize buffers
        scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0)
        scores_buf = scores.clone()
        tokens = src_tokens.new(bsz * beam_size, max_len + 2).long().fill_(self.pad)
        tokens_buf = tokens.clone()
        tokens[:, 0] = self.eos if bos_token is None else bos_token
        attn, attn_buf = None, None

        # The blacklist indicates candidates that should be ignored.
        # For example, suppose we're sampling and have already finalized 2/5
        # samples. Then the blacklist would mark 2 positions as being ignored,
        # so that we only finalize the remaining 3 samples.
        blacklist = src_tokens.new_zeros(bsz, beam_size).eq(-1)  # forward and backward-compatible False mask

        # list of completed sentences
        finalized = [[] for i in range(bsz)]
        finished = [False for i in range(bsz)]
        num_remaining_sent = bsz
        
        if return_all_tokens:
            all_tokens = tokens.data.new(max_len + 2, 2 * bsz * beam_size).fill_(self.pad)
            all_scores = scores.data.new(max_len + 2, 2 * bsz * beam_size).fill_(0)
            all_softmaxes = torch.zeros(max_len + 2, beam_size, self.vocab_size)
            all_vars = scores.data.new(max_len + 2, 2 * bsz * beam_size).fill_(0)
            all_vars_vocab = torch.zeros(max_len + 2, beam_size, self.vocab_size)
            all_means = scores.data.new(max_len + 2, 2 * bsz * beam_size).fill_(0)
            all_sing_vars = torch.zeros(max_len + 2, models_num, beam_size)
            all_sing_entropy = torch.zeros(max_len + 2, models_num, beam_size)
            all_ens_vars = torch.zeros(max_len + 2, beam_size)
            all_ens_entropy = torch.zeros(max_len + 2, beam_size)
            is_finalized = torch.zeros(max_len + 2, bsz * beam_size, dtype=torch.uint8)
            all_bbsz_idx = torch.zeros(max_len + 2, 2 * bsz * beam_size, dtype=torch.uint8)
            all_lprobs = []

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens)
        cand_offsets = torch.arange(0, cand_size).type_as(tokens)

        # helper function for allocating buffers on the fly
        buffers = {}

        def buffer(name, type_of=tokens):  # noqa
            if name not in buffers:
                buffers[name] = type_of.new()
            return buffers[name]

        def is_finished(sent, step, unfin_idx):
            """
            Check whether we've finished generation for a given sentence, by
            comparing the worst score among finalized hypotheses to the best
            possible score among unfinalized hypotheses.
            """
            assert len(finalized[sent]) <= beam_size
            if len(finalized[sent]) == beam_size or step == max_len:
                return True
            return False

        def finalize_hypos(step, bbsz_idx, eos_scores):
            """
            Finalize the given hypotheses at this step, while keeping the total
            number of finalized hypotheses per sentence <= beam_size.

            Note: the input must be in the desired finalization order, so that
            hypotheses that appear earlier in the input are preferred to those
            that appear later.

            Args:
                step: current time step
                bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
                    indicating which hypotheses to finalize
                eos_scores: A vector of the same size as bbsz_idx containing
                    scores for each hypothesis
            """
            assert bbsz_idx.numel() == eos_scores.numel()

            # clone relevant token and attention tensors
            tokens_clone = tokens.index_select(0, bbsz_idx)
            tokens_clone = tokens_clone[:, 1:step + 2]  # skip the first index, which is EOS
            assert not tokens_clone.eq(self.eos).any()
            tokens_clone[:, step] = self.eos
            attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None

            # compute scores per token position
            pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1]
            pos_scores[:, step] = eos_scores
            # convert from cumulative to per-position scores
            pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]

            # normalize sentence-level scores
            if self.normalize_scores:
                eos_scores /= (step + 1) ** self.len_penalty

            cum_unfin = []
            prev = 0
            for f in finished:
                if f:
                    prev += 1
                else:
                    cum_unfin.append(prev)

            sents_seen = set()
            for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), eos_scores.tolist())):
                unfin_idx = idx // beam_size
                sent = unfin_idx + cum_unfin[unfin_idx]

                sents_seen.add((sent, unfin_idx))

                if self.match_source_len and step > src_lengths[unfin_idx]:
                    score = -math.inf

                def get_hypo():

                    if attn_clone is not None:
                        # remove padding tokens from attn scores
                        hypo_attn = attn_clone[i]
                    else:
                        hypo_attn = None

                    return {
                        'tokens': tokens_clone[i],
                        'score': score,
                        'attention': hypo_attn,  # src_len x tgt_len
                        'alignment': None,
                        'positional_scores': pos_scores[i],
                    }

                if len(finalized[sent]) < beam_size:
                    finalized[sent].append(get_hypo())
                    if return_all_tokens:
                        is_finalized[step, idx] = 1

            newly_finished = []
            for sent, unfin_idx in sents_seen:
                # check termination conditions for this sentence
                if not finished[sent] and is_finished(sent, step, unfin_idx):
                    finished[sent] = True
                    newly_finished.append(unfin_idx)
            return newly_finished

        reorder_state = None
        batch_idxs = None
        for step in range(max_len + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs)
                    reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size)
                model.reorder_incremental_state(reorder_state)
                encoder_outs = model.reorder_encoder_out(encoder_outs, reorder_state)

            lprobs, avg_attn_scores, probs_means, probs_vars, sing_vars, sing_entropy, ens_vars, ens_entropy = model.forward_decoder(
                tokens[:, :step + 1], encoder_outs, temperature=self.temperature, with_var=True
            )
            lprobs[lprobs != lprobs] = -math.inf

            lprobs[:, self.pad] = -math.inf  # never select pad
            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            # handle max length constraint
            if step >= max_len:
                lprobs[:, :self.eos] = -math.inf
                lprobs[:, self.eos + 1:] = -math.inf

            # handle prefix tokens (possibly with different lengths)
            if prefix_tokens is not None and step < prefix_tokens.size(1) and step < max_len:
                prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1)
                prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1))
                prefix_vars = probs_vars.gather(-1, prefix_toks.unsqueeze(-1))
                prefix_means = probs_means.gather(-1, prefix_toks.unsqueeze(-1))
                prefix_mask = prefix_toks.ne(self.pad)
                lprobs[prefix_mask] = -math.inf
                # TODO
                lprobs[prefix_mask] = lprobs[prefix_mask].scatter_(
                    -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask]
                )
                probs_vars[prefix_mask] = probs_vars[prefix_mask].scatter_(
                    -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_vars[prefix_mask]
                )
                probs_means[prefix_mask] = probs_means[prefix_mask].scatter_(
                    -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_means[prefix_mask]
                )
                # if prefix includes eos, then we should make sure tokens and
                # scores are the same across all beams
                eos_mask = prefix_toks.eq(self.eos)
                if eos_mask.any():
                    # validate that the first beam matches the prefix
                    first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[:, 0, 1:step + 1]
                    eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0]
                    target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step]
                    assert (first_beam == target_prefix).all()

                    def replicate_first_beam(tensor, mask):
                        tensor = tensor.view(-1, beam_size, tensor.size(-1))
                        tensor[mask] = tensor[mask][:, :1, :]
                        return tensor.view(-1, tensor.size(-1))

                    # copy tokens, scores and lprobs from the first beam to all beams
                    tokens = replicate_first_beam(tokens, eos_mask_batch_dim)
                    scores = replicate_first_beam(scores, eos_mask_batch_dim)
                    lprobs = replicate_first_beam(lprobs, eos_mask_batch_dim)
                    probs_vars = replicate_first_beam(probs_vars, eos_mask_batch_dim)
                    probs_means = replicate_first_beam(probs_means, eos_mask_batch_dim)
            elif step < self.min_len:
                # minimum length constraint (does not apply if using prefix_tokens)
                lprobs[:, self.eos] = -math.inf

            if self.no_repeat_ngram_size > 0:
                # for each beam and batch sentence, generate a list of previous ngrams
                gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)]
                for bbsz_idx in range(bsz * beam_size):
                    gen_tokens = tokens[bbsz_idx].tolist()
                    for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]):
                        gen_ngrams[bbsz_idx][tuple(ngram[:-1])] =                                 gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]]

            # Record attention scores
            if type(avg_attn_scores) is list:
                avg_attn_scores = avg_attn_scores[0]
            if avg_attn_scores is not None:
                if attn is None:
                    attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2)
                    attn_buf = attn.clone()
                attn[:, :, step + 1].copy_(avg_attn_scores)

            scores = scores.type_as(lprobs)
            scores_buf = scores_buf.type_as(lprobs)
            eos_bbsz_idx = buffer('eos_bbsz_idx')
            eos_scores = buffer('eos_scores', type_of=scores)

            self.search.set_src_lengths(src_lengths)

            if self.no_repeat_ngram_size > 0:
                def calculate_banned_tokens(bbsz_idx):
                    # before decoding the next token, prevent decoding of ngrams that have already appeared
                    ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist())
                    return gen_ngrams[bbsz_idx].get(ngram_index, [])

                if step + 2 - self.no_repeat_ngram_size >= 0:
                    # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
                    banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)]
                else:
                    banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]

                for bbsz_idx in range(bsz * beam_size):
                    lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf


            clean_lprobs = lprobs.clone().detach()
            cand_scores, cand_indices, cand_beams = self.search.step(
                step,
                lprobs.view(bsz, -1, self.vocab_size),
                scores.view(bsz, beam_size, -1)[:, :, :step],
            )
            
            ishape = cand_indices.shape[1]
            cand_vars = torch.ones((bsz, ishape))
            if with_var:
                boffsets = (torch.cumsum(
                    torch.full((bsz, ), ishape, dtype=torch.int64, device=cand_indices.device) - ishape,
                    dim=0
                )).unsqueeze_(-1).T

                boffset_idxs = (cand_indices + boffsets).flatten()
                cand_vars = probs_vars.flatten()[boffset_idxs].view(bsz, -1)
                cand_means = probs_means.flatten()[boffset_idxs].view(bsz, -1)

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)
            
            if return_all_tokens:
                all_scores[step + 1] = cand_scores #-scores[cand_beams,step-1]
                all_softmaxes[step + 1] = clean_lprobs
                all_vars[step + 1] = cand_vars
                all_vars_vocab[step + 1] = probs_vars
                all_means[step + 1] = cand_means
                all_tokens[step + 1] = cand_indices
                all_bbsz_idx[step] = cand_bbsz_idx
                all_sing_vars[step + 1] = sing_vars
                all_sing_entropy[step + 1] = sing_entropy
                all_ens_vars[step + 1] = ens_vars
                all_ens_entropy[step + 1] = ens_entropy


            # finalize hypotheses that end in eos, except for blacklisted ones
            # or candidates with a score of -inf
            eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
            eos_mask[:, :beam_size][blacklist] = 0

            # only consider eos when it's among the top beam_size indices
            torch.masked_select(
                cand_bbsz_idx[:, :beam_size],
                mask=eos_mask[:, :beam_size],
                out=eos_bbsz_idx,
            )

            finalized_sents = set()
            if eos_bbsz_idx.numel() > 0:
                torch.masked_select(
                    cand_scores[:, :beam_size],
                    mask=eos_mask[:, :beam_size],
                    out=eos_scores,
                )
                finalized_sents = finalize_hypos(step, eos_bbsz_idx, eos_scores)
                num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            assert step < max_len

            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = cand_indices.new_ones(bsz)
                batch_mask[cand_indices.new(finalized_sents)] = 0
                batch_idxs = batch_mask.nonzero().squeeze(-1)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]
                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]
                src_lengths = src_lengths[batch_idxs]
                blacklist = blacklist[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                scores_buf.resize_as_(scores)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens_buf.resize_as_(tokens)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1)
                    attn_buf.resize_as_(attn)
                bsz = new_bsz
            else:
                batch_idxs = None

            # Set active_mask so that values > cand_size indicate eos or
            # blacklisted hypos and values < cand_size indicate candidate
            # active hypos. After this, the min values per row are the top
            # candidate active hypos.
            active_mask = buffer('active_mask')
            eos_mask[:, :beam_size] |= blacklist
            torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[:eos_mask.size(1)],
                out=active_mask,
            )

            # get the top beam_size active hypotheses, which are just the hypos
            # with the smallest values in active_mask
            active_hypos, new_blacklist = buffer('active_hypos'), buffer('new_blacklist')
            torch.topk(
                active_mask, k=beam_size, dim=1, largest=False,
                out=(new_blacklist, active_hypos)
            )

            # update blacklist to ignore any finalized hypos
            blacklist = new_blacklist.ge(cand_size)[:, :beam_size]
            assert (~blacklist).any(dim=1).all()

            active_bbsz_idx = buffer('active_bbsz_idx')
            torch.gather(
                cand_bbsz_idx, dim=1, index=active_hypos,
                out=active_bbsz_idx,
            )
            active_scores = torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores[:, step].view(bsz, beam_size),
            )

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            torch.index_select(
                tokens[:, :step + 1], dim=0, index=active_bbsz_idx,
                out=tokens_buf[:, :step + 1],
            )
            torch.gather(
                cand_indices, dim=1, index=active_hypos,
                out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1],
            )
            if step > 0:
                torch.index_select(
                    scores[:, :step], dim=0, index=active_bbsz_idx,
                    out=scores_buf[:, :step],
                )
            torch.gather(
                cand_scores, dim=1, index=active_hypos,
                out=scores_buf.view(bsz, beam_size, -1)[:, :, step],
            )

            # copy attention for active hypotheses
            if attn is not None:
                torch.index_select(
                    attn[:, :, :step + 2], dim=0, index=active_bbsz_idx,
                    out=attn_buf[:, :, :step + 2],
                )

            # swap buffers
            tokens, tokens_buf = tokens_buf, tokens
            scores, scores_buf = scores_buf, scores
            if attn is not None:
                attn, attn_buf = attn_buf, attn

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx

        # sort by score descending
        for sent in range(len(finalized)):
            finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True)
            
        if return_all_tokens:
            all_scores = all_scores[:step + 2]
            all_softmaxes = all_softmaxes[:step + 2]
            all_vars = all_vars[:step + 2]
            all_vars_vocab = all_vars_vocab[:step + 2]
            all_means = all_means[:step + 2]
            all_sing_vars = all_sing_vars[:step + 2]
            all_sing_entropy = all_sing_entropy[:step + 2]
            all_ens_vars = all_ens_vars[:step + 2]
            all_ens_entropy = all_ens_entropy[:step + 2]
            if with_var:
                    return finalized, all_tokens[:step + 2].cpu(), all_scores.cpu(), is_finalized[:step + 1].cpu(), all_bbsz_idx[:step+1].cpu(), all_softmaxes.cpu(), all_means.cpu(), all_vars.cpu(), all_vars_vocab.cpu(), all_sing_vars.cpu(), all_sing_entropy.cpu(), all_ens_vars.cpu(), all_ens_entropy.cpu()
                
            return finalized, all_tokens[:step + 2].cpu(), all_scores.cpu(), is_finalized[:step + 1].cpu(), all_bbsz_idx[:step+1].cpu()
        return finalized

In [95]:
# Build the sequence generator used by the cells below from the values in `args`.
translator = tmp_e.SourceSequenceGenerator(
    tgt_dict=tgt_dict,
    beam_size=args.beam,
    min_len=args.min_len,
    normalize_scores=(not args.unnormalized),
    len_penalty=args.lenpen,
    unk_penalty=args.unkpen
    # The options below are intentionally not passed (kept for reference):
    # sampling_temperature=args.sampling_temperature,
    # diverse_beam_groups=args.diverse_beam_groups,
    # diverse_beam_strength=args.diverse_beam_strength
)

In [145]:
tgt_len = tgt_tokens.shape[0]

In [146]:
# Follow the target sequence through the recorded beam rows, collecting the
# matched token of every step.
tokens = []
last_idx = 0  # beam column matched at the previous step (starts at column 0)
idx = torch.arange(all_tokens.shape[1])  # column indices of one row of all_tokens
for i in range(1, tgt_len + 1):
    # Columns of row i that hold the i-th target token AND whose entry in
    # all_bbsz_idx[i - 1] equals the previously matched column
    # (presumably the parent-beam index of each candidate -- TODO confirm).
    mask = (all_tokens[i] == tgt_tokens[i - 1]) & (all_bbsz_idx[i - 1] == last_idx)
    # Take the first matching column; indexing [0] fails if nothing matches.
    tokens.append(all_tokens[i][mask][0])
    last_idx = idx[mask][0]

In [152]:
torch.full((0, ), False, dtype=torch.bool)

tensor([], dtype=torch.bool)

In [147]:
tokens

[tensor(764),
 tensor(8357),
 tensor(4),
 tensor(2588),
 tensor(22008),
 tensor(15626),
 tensor(2)]

In [148]:
tgt_tokens

tensor([  764,  8357,     4,  2588, 22008, 15626,     2])

In [136]:
max_len = min(ref_tokens.shape[0], tgt_tokens.shape[0])

In [137]:
mask = ref_tokens[:max_len] == tgt_tokens[:max_len]

In [138]:
mask

tensor([ True,  True,  True,  True, False, False, False])

In [139]:
idxs = np.arange(max_len)

In [140]:
sidx = idxs[~mask][0]

In [141]:
sidx

4

In [142]:
mask[sidx:] = False

In [143]:
mask = torch.cat((
    mask,
    torch.full((1, ), False, dtype=torch.bool)
))

In [144]:
mask

tensor([ True,  True,  True,  True, False, False, False, False])

In [97]:
# Candidate inputs: only the last uncommented `sent` assignment takes effect.
sent = 'It should be noted that the marine environment is the least known of environments .'
sent = 'Greetings, my name is Dmitriy'
ref_sent = 'Привет, меня зовут Дмитрий'
# sent = 'Sort of like what stylist Lino Villaventura organized .'

# Pre-set to None so the fallbacks after the loop can detect missing statistics.
all_vars = None
all_sing_vars = None
all_ens_vars = None
all_softmaxes = None
for batch in make_batches(prepare_input(sent, l='en'), args, task, max_positions):
    encoder_input = {'net_input': {'src_tokens': batch.src_tokens, 'src_lengths': batch.src_lengths}}

    # NOTE(review): 15 values are unpacked here -- confirm this matches what
    # translator.generate returns when return_all_tokens and with_var are set.
    translations, all_tokens, all_scores, is_finalized, all_bbsz_idx, all_softmaxes, all_means, all_means_vocab, all_vars, all_vars_vocab, all_sing_vars, all_sing_entropy, all_ens_vars, all_ens_entropy, inens_dist = translator.generate(
        models=models,
        sample=encoder_input,
        return_all_tokens=True,
        with_var=True
    )
    
# Fallbacks for statistics that are still None after generation.
if all_vars is None:
    all_vars = all_scores
if all_sing_vars is None:
    all_sing_vars = torch.zeros((all_vars.shape[0], len(models), int(BEAM)))
if all_ens_vars is None:
    all_ens_vars = torch.zeros((all_vars.shape[0], int(BEAM)))
    
# Average every statistic over its last dimension.
# NOTE(review): all_sing_entropy / all_ens_entropy have no None fallback above,
# so these lines rely on generate having produced them.
all_sing_vars = all_sing_vars.mean(-1)
all_sing_entropy = all_sing_entropy.mean(-1)
all_ens_vars = all_ens_vars.mean(-1)
all_ens_entropy = all_ens_entropy.mean(-1)

In [155]:
all_scores

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.8411, -2.2668, -2.8892, -3.1990, -3.7152, -3.7501, -4.0364, -4.5616,
         -5.1806, -5.3026],
        [-1.6022, -1.6109, -2.3310, -3.2579, -3.8605, -4.2734, -4.6233, -5.2273,
         -5.3724, -5.9524],
        [-1.7451, -2.0878, -2.3926, -3.2755, -3.7603, -4.0153, -4.1132, -4.8225,
         -6.2610, -6.4285],
        [-1.9016, -2.3398, -2.5899, -3.3798, -3.9178, -4.2305, -4.6870, -5.8284,
         -6.1631, -6.4497],
        [-1.9607, -2.7230, -2.8193, -3.7478, -4.1105, -4.2128, -5.3459, -5.4821,
         -5.8503, -6.6609],
        [-2.0536, -2.8535, -2.9886, -3.7838, -4.1686, -6.5750, -7.2086, -7.3857,
         -7.5527, -7.5664],
        [-2.4957, -2.9195, -3.0439, -3.7284, -3.8076, -4.2620, -5.2305, -6.9986,
         -7.6050, -7.7857],
        [-3.0127, -3.1372, -3.8339, -3.8476, -4.5987, -6.2825, -7.9192, -8.8219,
         -8.8915, -8.9592],
        [-3.4692, -

In [158]:
all_means.shape

torch.Size([10, 10])

In [157]:
np.log(all_means)

  """Entry point for launching an IPython kernel.


tensor([[    -inf,     -inf,     -inf,     -inf,     -inf,     -inf,     -inf,
             -inf,     -inf,     -inf],
        [ -0.8411,  -2.2668,  -2.8892,  -3.1990,  -3.7152,  -3.7501,  -4.0364,
          -4.5616,  -5.1806,  -5.3026],
        [ -0.7611,  -0.7698,  -9.6543, -11.4640, -12.8407,  -7.3683, -11.5254,
          -9.1998, -11.9993, -13.7989],
        [ -9.5482,  -0.4856,  -0.4856, -11.5022,  -2.1581, -13.5382,  -2.5111,
          -3.2203, -11.8494,  -4.8263],
        [ -0.1565, -11.9322, -13.3653,  -9.0736,  -9.0736, -12.9287, -13.5603,
          -4.0833,  -4.4180,  -4.7046],
        [ -0.0590,  -8.3505,  -8.3505, -12.7486,  -9.5137,  -9.1285, -10.2981,
         -13.0907,  -9.1408,  -4.7592],
        [ -0.0929, -11.4417, -11.4417, -13.0378, -10.4241, -10.4241, -12.3955,
         -10.4241, -12.3955, -12.3778],
        [ -0.4421, -11.6701, -11.6701,  -1.6749, -12.5738, -10.9481,  -3.1770,
          -4.9451, -12.4458, -12.4458],
        [ -0.0932,  -0.0932, -12.5777,  -8.8583,

In [154]:
translator.no_repeat_ngram_size

0

In [107]:
tgt_tokens

tensor([  764,  8357,     4,  2588, 22008, 15626,     2])

In [119]:
all_tokens

tensor([[    1,     1,     1,     1,     1,     1,     1,     1,     1,     1],
        [  764,   609,   140, 14617, 11205,   338,   440, 15399, 17049, 19273],
        [ 4643,  8357, 16861,  1544, 11081,  8807,  8021,  2184, 10431, 10196],
        [    4,  2555,  2555,  9868, 10222,  1041,  5770,  4138, 19175,  5754],
        [ 2588,   306,  1964,     4,     4,   352,   493, 22008,  2962, 15399],
        [22008,     4,     4, 14705,  2588,  1775,  3791,  9525,  8297,  1107],
        [15626,  2588,  2588,   592, 22008, 22008,  2962, 22008,  2962, 15399],
        [    2, 22008, 22008,     5, 10253, 15626,   384,     4,  1107,  1107],
        [15626, 15626,   268,     2,     2,     5,   384,  2869,  2869, 21725],
        [    2,     2,   312,     5,     5,   384,     2, 15626,   384,     4]])

In [133]:
ref_sent

'Привет, меня зовут Дмитрий'

In [121]:
tgt_dict.string(translations[0][0]['tokens'], bpe_symbol='@@ ')

'Привет , меня зовут Дмитрий'

In [131]:
tgt_dict.eos()

2

In [132]:
tgt_dict.string([2], bpe_symbol='@@ ')

''

In [104]:
translator.no_repeat_ngram_size

0

In [98]:
inens_dist.size()

torch.Size([10, 4, 5, 31232])

In [100]:
probs, inens_mean, inens_mean_vocab, inens_var, inens_var_vocab, ens_softmaxes, inens_dist = tmp_e.get_translation_stats(tgt_tokens, all_tokens, all_bbsz_idx, all_scores, all_means, all_means_vocab     , all_vars, all_vars_vocab, all_softmaxes, inens_dist)

In [102]:
inens_dist.size()

torch.Size([7, 4])

In [103]:
inens_dist[0]

tensor([0.4269, 0.2220, 0.6596, 0.4165])

In [None]:
all_means.size()

In [None]:
all_ens_vars.size()

In [None]:
all_ens_entropy.size()

In [None]:
all_sing_vars.size()

In [None]:
all_sing_entropy.size()

In [None]:
all_vars.size()

In [None]:
all_means.size()

In [None]:
all_sing_vars.size()

In [None]:
all_sing_entropy.size()

In [None]:
all_vars.size()

In [None]:
all_means.size()

In [134]:
ref_tokens = tgt_dict.encode_line(prepare_input(ref_sent, 'en')[0], add_if_not_exist=False).long()

In [135]:
ref_tokens

tensor([  764,  8357,     4,  2588,  1129, 21981, 15626,     2])

In [None]:
print(tgt_dict.string(translations[0][0]['tokens'], bpe_symbol='@@ '))

---

In [None]:
from collections import defaultdict


def get_stats_distribution(ref_tokens, tgt_tokens, token_cmp, *args):
    """Select per-position statistics at the positions chosen by ``token_cmp``.

    Both token sequences are truncated to the length of the shorter one and
    handed to ``token_cmp``, whose result is used as a boolean mask.

    Args:
        ref_tokens: reference token sequence (indexable, exposes ``.shape``).
        tgt_tokens: hypothesis token sequence (indexable, exposes ``.shape``).
        token_cmp: callable ``(ref_prefix, tgt_prefix) -> mask``.
        *args: ``(name, values)`` pairs; ``values`` is truncated to the common
            length and then indexed with the mask.

    Returns:
        dict mapping every ``name`` to a NumPy array of the selected values.
    """
    common_len = min(ref_tokens.shape[0], tgt_tokens.shape[0])
    keep = token_cmp(ref_tokens[:common_len], tgt_tokens[:common_len])
    return {
        label: np.array(values[:common_len][keep].tolist())
        for label, values in args
    }
        

def get_translation_stats(tgt_tokens, beam_tokens, beam_scores, beam_vars):
    """Pick the beam scores and variances at cells matching the target tokens.

    Row ``i + 1`` of ``beam_tokens`` is compared against ``tgt_tokens[i]``
    (row 0 is skipped). The comparison is broadcast across the whole row, so
    a token that occurs several times in one row selects several entries.

    Args:
        tgt_tokens: 1-D tensor with the target token ids.
        beam_tokens: per-step candidate tokens, indexed ``[step, candidate]``.
        beam_scores: values indexed like ``beam_tokens``.
        beam_vars: values indexed like ``beam_tokens``.

    Returns:
        ``(scores, vars)``: flat tensors of the selected entries.
    """
    n_steps = tgt_tokens.shape[0]
    window = slice(1, n_steps + 1)
    hits = beam_tokens[window] == tgt_tokens.view(n_steps, -1)
    return beam_scores[window][hits], beam_vars[window][hits]

In [None]:
tscores, tvars = get_translation_stats(translations[0][0]['tokens'], all_tokens, all_scores, all_vars)

In [None]:
get_stats_distribution(
    ref_tokens,
    translations[0][0]['tokens'],
    lambda x, y: x == y,
    ('prob', tscores),
    ('inensemble_var', tvars),
    ('m1_svar', all_sing_vars[:, 0]),
    ('m2_svar', all_sing_vars[:, 1]),
    ('m3_svar', all_sing_vars[:, 2]),
    ('m4_svar', all_sing_vars[:, 3]),
    ('ens_svar', all_ens_vars)
)

In [None]:
tgt_dict.string([764])

In [None]:
ref_tokens

In [None]:
tgt_tokens

In [None]:
tscores.size()

In [None]:
tvars.size()

In [None]:
all_sing_vars[:, 0].size()

In [None]:
all_ens_vars.size()

In [None]:
from matplotlib import pyplot as plt

In [None]:
plt.figure(figsize=(8, 6))

# NOTE(review): `sns` and `stats` are not defined in any visible cell -- confirm
# that seaborn is imported and `stats` is built before this cell is run.
sns.distplot(np.exp(stats['positive']['vars']), label='positive', color='red')
sns.distplot(np.exp(stats['negative']['vars']), label='negative', color='gray')

In [None]:
t = tgt_dict.encode_line(prepare_input(ref_sent, 'en')[0], add_if_not_exist=False).long()

In [None]:
t

In [None]:
translations[0][0]['tokens']

In [None]:
tgt_dict.string(t, bpe_symbol='@@ ')

---

In [None]:
ref_tokens

In [None]:
tgt_tokens = translations[0][0]['tokens']

In [None]:
tgt_len = tgt_tokens.shape[0]

In [None]:
tgt_tokens

In [None]:
tgt_tokens

In [None]:
all_tokens

In [None]:
mask = all_tokens[1:tgt_len + 1] == tgt_tokens.view(tgt_len, -1)

In [None]:
mask.size()

In [None]:
all_tokens.size()

In [None]:
all_scores.size()

In [None]:
all_vars[1:tgt_len + 1][mask].size()

In [None]:
mask

In [None]:
all_tokens[1:tgt_len + 1][mask].size()

In [None]:
all_tokens

In [None]:
r_len = ref_tokens.size(0)
t_len = tgt_tokens.size(0)
ln = min(r_len, t_len)

In [None]:
mask = ref_tokens[:ln] == tgt_tokens[:ln]
idxs = torch.arange(ln)

In [None]:
mask[-2] = True

In [None]:
mask

In [None]:
fidx = idxs[~mask][0]

In [None]:
fidx.tolist()

In [None]:
mask[fidx:] = False

In [None]:
mask

In [None]:
mask[mask]

In [None]:
mask

In [None]:
sidx = idxs[~mask][0]
mask[sidx:] = False

In [None]:
mask = torch.cat((
    mask,
    torch.full((t_len - ln,), False, dtype=torch.bool)
)
)

In [None]:
mask

In [None]:
suffix = tgt_tokens[mask]

In [None]:
suffix

In [None]:
mask

In [None]:
mask.size()

---

In [116]:
def print_beam_search(beam_size=None):
    """Draw the recorded beam search as a Bokeh graph and display it.

    The function reads notebook-level globals: ``all_tokens``, ``all_scores``,
    ``all_vars``, ``all_bbsz_idx``, ``is_finalized``, ``translations``,
    ``tgt_dict`` and ``args``. Every candidate token becomes a node placed at
    ``(step, rank)``; every node is connected to its parent in the previous
    step.

    Args:
        beam_size: number of non-EOS candidates drawn per step. When ``None``
            (the default) ``args.beam`` is used.

    Side effects:
        Directs Bokeh output to ``graph.html`` and calls ``show``.
    """
    # Honour an explicit argument; previously it was always overwritten.
    if beam_size is None:
        beam_size = args.beam
    x = []
    y = []
    name = []
    score = []
    var = []
    inds_by_step = []        # node ids created at each step
    inds_by_step_noeos = []  # per step: row positions of the non-EOS candidates
    parent_index = []        # parent node id of every non-root node
    linew = []               # outline width per node

    lwidth = []              # width per edge

    cur_ind = 0

    for step, tokens in enumerate(all_tokens):
        if step == 0:
            # Single root node, vertically centred among the beam rows.
            inds_by_step.append([cur_ind])
            inds_by_step_noeos.append([0])
            cur_ind += 1
            x.append(step)
            y.append((beam_size + 1) / 2)
            name.append(tgt_dict.pad_word)
            score.append(0)
            var.append(0)
            linew.append(1)
        else:
            non_eos_encountered = 0
            cur_inds = []
            cur_noeos = []
            for i, beam_tk in enumerate(tokens):
                cur_inds.append(cur_ind)
                cur_ind += 1
                x.append(step)
                y.append(i + 1)
                name.append(tgt_dict.symbols[beam_tk].replace('@@',''))
                # Length-penalised score; exponentiated later for display.
                score.append(all_scores[step, i].item() / (step ** args.lenpen))
                var.append(all_vars[step, i].item())
                # Map the stored beam index back to the node drawn for it in
                # the previous step (EOS candidates are skipped in that map).
                prev_beam_ind = all_bbsz_idx[step - 1, i].item()
                parent_pos = inds_by_step_noeos[step - 1][prev_beam_ind]
                parent_index.append(inds_by_step[step - 1][parent_pos])
                edge = (
                    all_tokens[step - 1, parent_pos].item(),
                    beam_tk.item()
                )
                # Widen edges whose (previous token, token) pair occurs at this
                # position in a final translation; earlier translations in the
                # list get wider lines.
                for j, translation in enumerate(translations[0]):
                    if step == 1:
                        res_edge = (tgt_dict.pad_index, translation['tokens'][step - 1].item())
                    else:
                        res_edge = tuple(translation['tokens'][step - 2: step].cpu().numpy())
                    if edge == res_edge:
                        lwidth.append(1 + 0.5 * (beam_size - j))
                        break
                else:
                    lwidth.append(1)
                # Thicker outline for candidates flagged in is_finalized.
                if i < beam_size and is_finalized[step - 1, i].item():
                    linew.append(2)
                else:
                    linew.append(1)
                if beam_tk != tgt_dict.eos_index:
                    non_eos_encountered += 1
                    cur_noeos.append(i)
                    # Stop once beam_size non-EOS candidates have been drawn.
                    if non_eos_encountered == beam_size:
                        break
            inds_by_step.append(cur_inds)
            inds_by_step_noeos.append(cur_noeos)

    # Flatten the per-step node ids into one list.
    index = [node for step_nodes in inds_by_step for node in step_nodes]

    src = ColumnDataSource(data=dict(x=x, y=y, name=name, index=index, score=np.exp(score), var=var, linew=linew))

    plot = figure(
        x_range=(min(x) - 0.25, max(x) + 1),
        y_range=(min(y) - 1, max(y) + 1),
        tools=[
            HoverTool(tooltips=[
                ('Name','@name'),
                ('Score','@score'),
                ('Var', '@var')
            ]),
            PanTool(),
            BoxZoomTool()
        ],
        toolbar_location=None,
        plot_width=900,
        plot_height=600,
        x_axis_type=None,
        y_axis_type=None
    )
    graph = GraphRenderer()

    graph.node_renderer.data_source.data = src.data
    graph.node_renderer.glyph = Circle(radius=0.1, fill_color='#9999ee', line_width='linew')

    graph.edge_renderer.data_source.data = dict(
        start=parent_index,
        end=index[1:],
        width=lwidth
    )
    # NOTE(review): the 'width' column is stored but this glyph does not
    # reference it -- confirm whether line_width='width' was intended.
    graph.edge_renderer.glyph = MultiLine()

    graph_layout = dict(zip(index, zip(x, y)))
    graph.layout_provider = StaticLayoutProvider(graph_layout=graph_layout)

    labels = LabelSet(
        x='x',
        y='y',
        text='name',
        level='glyph',
        x_offset=-15,
        y_offset=20,
        source=src
    )

    plot.renderers.append(graph)
    plot.add_layout(labels)

    ticker = SingleIntervalTicker(interval=1, num_minor_ticks=5)
    xaxis = LinearAxis(ticker=ticker)
    plot.add_layout(xaxis, 'below')

    output_file('graph.html')
    show(plot)

In [118]:
print_beam_search()

In [None]:
for step, mvars in enumerate(all_sing_vars):
    print(step, 'step')
    for midx, var in enumerate(mvars):
        print('\t', 'model', midx + 1, 'softmax var - ', var)

In [None]:
# Print one ensemble softmax variance entry per row of all_ens_vars.
for step, var in enumerate(all_ens_vars):
    print(step, 'step')
    print('\t', 'ensemble softmax var - ', var)

In [None]:
task.load_dataset(args.gen_subset)

In [None]:
itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
).next_epoch_itr(shuffle=False)

In [None]:
for sample in itr:
    break

In [None]:
progress = progress_bar.build_progress_bar(
        args,
        itr
)

In [None]:
import tqdm

In [None]:
for i in tqdm.tqdm(range(100)):
    continue

In [None]:
os.environ

In [None]:
for sample in progress:
    print('*')
    d = sample
    print('*')

In [None]:
s['target'].size()

In [None]:
import fairseq

In [None]:
fairseq.__file__

In [None]:
# Take the first batch from the iterator and leave it bound to `x`
# (the original body was the undefined name `b`, which raised NameError).
for x in itr:
    break

In [None]:
x

In [None]:
encoder_input.keys()

In [None]:
ref_tokens

In [None]:
import json

In [None]:
a = {'a': 1, 'b': 2}

In [None]:
with open('test', 'w') as stream_output:
    json.dump(a, stream_output)

In [None]:
with open('test') as stream_output:
    t= json.load(stream_output)

In [None]:
t

In [None]:
x['target']

In [None]:
a = [1, 2, 3]

In [None]:
a.extend([2,3])

In [None]:
a

In [None]:
import torch

In [None]:
a = torch.zeros((1, 1))

In [None]:
a.device

In [None]:
a

In [None]:
a = torch.zeros((5, 5))

In [None]:
a

In [None]:
a.view(1, a.shape[0], -1).shape