Код основан на [данном](https://github.com/microsoft/CodeBERT/tree/master/CodeBERT/code2nl) репозитории.

In [38]:
import torch
import torch.nn as nn


class Beam:
    """Beam-search bookkeeping for a single source sequence.

    The beam keeps ``size`` partial hypotheses. After every call to
    ``advance`` it records, per surviving hypothesis, the token that was
    appended (``nextYs``), the beam row it extends (``prevKs``) and its
    accumulated score (``scores``). Hypotheses whose latest token is EOS are
    collected in ``finished`` as ``(score, timestep, beam_row)`` tuples.

    Args:
        size: Number of hypotheses kept per step.
        sos: Token id written into row 0 of the initial state.
        eos: Token id that marks a hypothesis as finished.
        device: Where the beam tensors live; any device string or
            ``torch.device`` accepted by ``torch.device``.
    """

    def __init__(self, size, sos, eos, device):
        self.size = size
        self.device = torch.device(device)
        # Kept only so code that reads ``beam.tt`` keeps working; the beam
        # itself no longer relies on the legacy typed-tensor constructors.
        self.tt = torch.cuda if self.device.type == "cuda" else torch
        self.scores = torch.zeros(size, dtype=torch.float, device=self.device)
        self.prevKs = []
        initial = torch.zeros(size, dtype=torch.long, device=self.device)
        initial[0] = sos
        self.nextYs = [initial]
        self._eos = eos
        self.eosTop = False
        self.finished = []

    def _eosMask(self):
        """Return a bool tensor marking beam rows whose latest token is EOS."""
        mask = self.nextYs[-1] == self._eos
        if not torch.is_tensor(mask):
            # Comparing against a non-numeric EOS (e.g. ``None``) yields a
            # plain ``False``: no row can ever match.
            mask = torch.zeros_like(self.nextYs[-1], dtype=torch.bool)
        return mask

    def getCurrentState(self):
        """Return the latest token of every beam row as a ``(size, 1)`` view."""
        return self.nextYs[-1].view(-1, 1)

    def getCurrentOrigin(self):
        """Return, for every current row, the index of the row it extends."""
        return self.prevKs[-1]

    def advance(self, wordLk):
        """Extend the beam by one token.

        Args:
            wordLk: Per-row token scores of shape ``(size, num_words)``.
        """
        numWords = wordLk.size(1)
        if self.prevKs:
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
            # Rows that already ended with EOS must not be extended further.
            beamLk[self._eosMask()] = -1e20
        else:
            # First step: only row 0 (the one seeded with SOS) is expanded.
            beamLk = wordLk[0]
        flatBeamLk = beamLk.reshape(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)
        self.scores = bestScores
        # A flat index encodes (row, token) as row * numWords + token.
        prevK = bestScoresId // numWords
        self.prevKs.append(prevK)
        self.nextYs.append(bestScoresId - prevK * numWords)

        step = len(self.nextYs) - 1
        ended = self._eosMask().nonzero(as_tuple=False).view(-1).tolist()
        for i in ended:
            self.finished.append((self.scores[i], step, i))
        # ``ended`` is ascending, so row 0 can only appear first.
        if ended and ended[0] == 0:
            self.eosTop = True

    def done(self):
        """Return True once the top row ended with EOS and ``size`` hypotheses finished."""
        return self.eosTop and len(self.finished) >= self.size

    def getFinal(self):
        """Return up to ``size`` ``(score, timestep, beam_row)`` tuples, best first.

        Finished hypotheses come first (sorted by score); if there are fewer
        than ``size`` of them the list is topped up with the best rows that
        have not produced EOS yet. The top-up is stored in ``self.finished``.
        """
        step = len(self.nextYs) - 1
        if not self.finished:
            self.finished.append((self.scores[0], step, 0))
        self.finished.sort(key=lambda a: -a[0])
        # ``<`` rather than ``!=``: with more than ``size`` finished entries a
        # negative slice bound would append unrelated unfinished rows.
        if len(self.finished) < self.size:
            live = (~self._eosMask()).nonzero(as_tuple=False).view(-1).tolist()
            unfinished = [(self.scores[i], step, i) for i in live]
            unfinished.sort(key=lambda a: -a[0])
            self.finished += unfinished[: self.size - len(self.finished)]
        return self.finished[: self.size]

    def getHyp(self, beam_res):
        """Rebuild token sequences for the entries returned by ``getFinal``.

        Each hypothesis is recovered by walking the back-pointers in
        ``prevKs`` from its timestep down to the first step.

        Returns:
            A list of lists of 0-dim token tensors, one list per entry.
        """
        hyps = []
        for _, timestep, k in beam_res:
            hyp = []
            for j in range(min(timestep, len(self.prevKs)) - 1, -1, -1):
                hyp.append(self.nextYs[j + 1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])
        return hyps

    def buildTargetTokens(self, preds):
        """Cut every hypothesis at its first EOS token (EOS excluded)."""
        sentence = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence

class Seq2Seq(nn.Module):
    """Encoder/decoder wrapper that generates token sequences with beam search.

    Args:
        encoder: Module called as ``encoder(source_ids, attention_mask=...)``;
            its ``embeddings`` sub-module is also used to embed target tokens
            and its ``embeddings.word_embeddings`` weight is tied to the
            output projection.
        decoder: Module called as
            ``decoder(tgt, memory, tgt_mask=..., memory_key_padding_mask=...)``.
        config: Object providing ``hidden_size``, ``vocab_size`` and
            ``torchscript``.
        beam_size: Number of hypotheses kept per decoding step.
        max_length: Maximum number of decoding steps; also the padded length
            of every returned hypothesis.
        sos_id: Token id that starts decoding.
        eos_id: Token id that ends a hypothesis.
    """

    def __init__(self, encoder, decoder, config, beam_size=None,
                 max_length=None, sos_id=None, eos_id=None):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config
        # Lower-triangular matrix from which the causal target mask is sliced.
        self.register_buffer("bias", torch.tril(torch.ones(2048, 2048)))
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.lsm = nn.LogSoftmax(dim=-1)
        self.tie_weights()
        self.beam_size = beam_size
        self.max_length = max_length
        self.sos_id = sos_id
        self.eos_id = eos_id

    def _tie_or_clone_weights(self, first_module, second_module):
        """Share ``second_module``'s weight with ``first_module``.

        When ``config.torchscript`` is set the weight is cloned into a new
        parameter instead of being shared.
        """
        if self.config.torchscript:
            first_module.weight = nn.Parameter(second_module.weight.clone())
        else:
            first_module.weight = second_module.weight

    def tie_weights(self):
        """Tie the output projection to the encoder's word-embedding matrix."""
        self._tie_or_clone_weights(self.lm_head, self.encoder.embeddings.word_embeddings)

    def forward(self, source_ids, source_mask):
        """Generate hypotheses for every row of ``source_ids``.

        Args:
            source_ids: Source token ids; dimension 0 is iterated as the batch.
            source_mask: Mask with 1 for real source tokens and 0 for padding.

        Returns:
            A long tensor holding, per source row, the hypotheses returned by
            the beam, each cut at EOS and right-padded with 0 to
            ``max_length``.
        """
        outputs = self.encoder(source_ids, attention_mask=source_mask)
        # After the permute the batch lives on dimension 1.
        encoder_output = outputs[0].permute([1, 0, 2]).contiguous()
        # Padding element; created on the input's device for any device type
        # (previously only "cuda" and "cpu" defined it).
        zero = torch.zeros(1, dtype=torch.long, device=source_ids.device)
        preds = []
        for i in range(source_ids.shape[0]):
            beam = Beam(
                self.beam_size,
                self.sos_id,
                self.eos_id,
                device=source_ids.device.type,
            )
            # Replicate this example's encoder states and mask for every beam row.
            context = encoder_output[:, i:i + 1].repeat(1, self.beam_size, 1)
            context_mask = source_mask[i:i + 1, :].repeat(self.beam_size, 1)
            # Loop-invariant: True where the source position is padding.
            memory_padding = (1 - context_mask).bool()
            input_ids = beam.getCurrentState()
            for _ in range(self.max_length):
                if beam.done():
                    break
                cur_len = input_ids.shape[1]
                # Additive causal mask: 0 on/below the diagonal, -1e4 above.
                attn_mask = -1e4 * (1 - self.bias[:cur_len, :cur_len])
                tgt_embeddings = (
                    self.encoder.embeddings(input_ids)
                    .permute([1, 0, 2])
                    .contiguous()
                )
                out = self.decoder(
                    tgt_embeddings,
                    context,
                    tgt_mask=attn_mask,
                    memory_key_padding_mask=memory_padding,
                )
                out = torch.tanh(self.dense(out))
                # Only the last position predicts the next token.
                hidden_states = out.permute([1, 0, 2]).contiguous()[:, -1, :]
                out = self.lsm(self.lm_head(hidden_states)).detach()
                beam.advance(out)
                # Reorder the prefixes so row k holds the prefix that the new
                # row k extends, then append the newly chosen tokens.
                input_ids = input_ids.index_select(0, beam.getCurrentOrigin())
                input_ids = torch.cat((input_ids, beam.getCurrentState()), -1)
            hyp = beam.getHyp(beam.getFinal())
            pred = beam.buildTargetTokens(hyp)[: self.beam_size]
            pred = [
                torch.cat(
                    [x.view(-1) for x in p] + [zero] * (self.max_length - len(p))
                ).view(1, -1)
                for p in pred
            ]
            preds.append(torch.cat(pred, 0).unsqueeze(0))
        preds = torch.cat(preds, 0)
        return preds

In [None]:
!wget https://code-summary.s3.amazonaws.com/pytorch_model.bin

In [39]:
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler


class InputFeatures:
    """Plain record holding the id and mask lists of one tokenized example."""

    def __init__(self, example_id, source_ids,
                 target_ids, source_mask, target_mask):
        # Values are stored as given (no copies are made).
        self.example_id, self.source_ids, self.target_ids = (
            example_id,
            source_ids,
            target_ids,
        )
        self.source_mask, self.target_mask = source_mask, target_mask


def convert_examples_to_features(examples, tokenizer,
                                 max_source_length=256, max_target_length=128):
    """Tokenize source strings into fixed-length id/mask lists.

    Every example is tokenized, truncated so that it fits between the
    tokenizer's CLS and SEP tokens, converted to ids and right-padded with
    ``tokenizer.pad_token_id``. The target side is a fixed placeholder built
    from the literal text ``"None"``.

    Args:
        examples: Iterable of source strings.
        tokenizer: Object providing ``tokenize``, ``convert_tokens_to_ids``,
            ``cls_token``, ``sep_token`` and ``pad_token_id``.
        max_source_length: Length of every ``source_ids`` / ``source_mask``.
        max_target_length: Length of every ``target_ids`` / ``target_mask``.

    Returns:
        A list with one ``InputFeatures`` per example; ``example_id`` is the
        example's position in ``examples``.

    Raises:
        ValueError: If either length leaves no room for the CLS/SEP tokens.
    """
    if max_source_length < 2 or max_target_length < 2:
        raise ValueError("max_source_length and max_target_length must be at least 2")

    pad_id = tokenizer.pad_token_id
    cls_token, sep_token = tokenizer.cls_token, tokenizer.sep_token

    # The placeholder target does not depend on the example: build it once
    # and hand every feature its own copy.
    placeholder_tokens = tokenizer.tokenize("None")[: max_target_length - 2]
    base_target_ids = tokenizer.convert_tokens_to_ids(
        [cls_token] + placeholder_tokens + [sep_token]
    )
    target_padding = max_target_length - len(base_target_ids)
    base_target_mask = [1] * len(base_target_ids) + [0] * target_padding
    base_target_ids = list(base_target_ids) + [pad_id] * target_padding

    features = []
    for example_index, example in enumerate(examples):
        # Reserve two positions for the CLS and SEP tokens.
        source_tokens = tokenizer.tokenize(example)[: max_source_length - 2]
        source_tokens = [cls_token] + source_tokens + [sep_token]
        source_ids = list(tokenizer.convert_tokens_to_ids(source_tokens))
        source_padding = max_source_length - len(source_ids)
        source_mask = [1] * len(source_ids) + [0] * source_padding
        source_ids += [pad_id] * source_padding
        features.append(
            InputFeatures(
                example_index,
                source_ids,
                list(base_target_ids),
                source_mask,
                list(base_target_mask),
            )
        )
    return features

def conclusion(data, model, tokenizer, batch_size=None):
    """Run the model over ``data`` and decode the first hypothesis of each row.

    Args:
        data: Dataset yielding ``(source_ids, source_mask)`` tensor pairs.
        model: Callable as ``model(source_ids=..., source_mask=...)`` that
            returns, per row, a sequence of hypotheses (token-id tensors).
        tokenizer: Object providing ``decode(ids, clean_up_tokenization_spaces=...)``.
        batch_size: Rows per forward pass; ``None`` (default) processes the
            whole dataset in a single batch.

    Returns:
        One decoded string per dataset row, in dataset order. An empty
        dataset yields an empty list.
    """
    if len(data) == 0:
        # DataLoader rejects batch_size=0, so short-circuit here.
        return []
    if batch_size is None:
        batch_size = len(data)

    # Feed inputs on the device the model lives on; fall back to CPU when the
    # model exposes no parameters.
    parameters = getattr(model, "parameters", None)
    first_param = next(iter(parameters()), None) if callable(parameters) else None
    device = first_param.device if first_param is not None else torch.device("cpu")

    eval_dataloader = DataLoader(data, sampler=SequentialSampler(data), batch_size=batch_size)
    model.eval()
    texts = []
    with torch.no_grad():
        for batch in eval_dataloader:
            source_ids, source_mask = (t.to(device) for t in batch)
            preds = model(source_ids=source_ids, source_mask=source_mask)
            for pred in preds:
                # Keep only the first hypothesis and drop everything from the
                # first 0 onwards (0 is the padding value of the output).
                token_ids = pred[0].tolist()
                if 0 in token_ids:
                    token_ids = token_ids[: token_ids.index(0)]
                texts.append(tokenizer.decode(token_ids, clean_up_tokenization_spaces=False))
    return texts


def get_feature(examples, tokenizer):
    """Turn raw source strings into a ``TensorDataset`` of ids and masks.

    The dataset yields ``(source_ids, source_mask)`` pairs of long tensors,
    each row limited to its first 256 entries.
    """
    id_rows = []
    mask_rows = []
    for feature in convert_examples_to_features(examples, tokenizer):
        id_rows.append(feature.source_ids[:256])
        mask_rows.append(feature.source_mask[:256])
    source_ids = torch.tensor(id_rows, dtype=torch.long)
    source_mask = torch.tensor(mask_rows, dtype=torch.long)
    return TensorDataset(source_ids, source_mask)


def get_model(model_class, config, tokenizer, beam_size=10, max_length=128,
              num_decoder_layers=6, weights_path="pytorch_model.bin"):
    """Build the ``Seq2Seq`` model and load its weights from disk.

    Args:
        model_class: Encoder class, instantiated as ``model_class(config=config)``.
        config: Configuration providing ``hidden_size`` and
            ``num_attention_heads`` (plus whatever ``Seq2Seq`` reads).
        tokenizer: Supplies ``cls_token_id`` (start token) and
            ``sep_token_id`` (end token).
        beam_size: Beam width used during generation.
        max_length: Maximum generated length.
        num_decoder_layers: Number of stacked transformer decoder layers.
        weights_path: Checkpoint file passed to ``torch.load``.

    Returns:
        The model with the checkpoint's weights loaded onto the CPU.
    """
    encoder = model_class(config=config)
    decoder_layer = nn.TransformerDecoderLayer(
        d_model=config.hidden_size,
        nhead=config.num_attention_heads,
    )
    decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
    model = Seq2Seq(
        encoder=encoder,
        decoder=decoder,
        config=config,
        beam_size=beam_size,
        max_length=max_length,
        sos_id=tokenizer.cls_token_id,
        eos_id=tokenizer.sep_token_id,
    )
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from a trusted source.
    state_dict = torch.load(weights_path, map_location=torch.device("cpu"))
    # strict=False: keys that do not match between checkpoint and model are
    # ignored without an error.
    model.load_state_dict(state_dict, strict=False)
    return model

In [None]:
!pip install transformers

In [40]:
from transformers import RobertaConfig, RobertaModel, RobertaTokenizer


# Module-level setup, executed once when the cell runs: the configuration and
# tokenizer are loaded for the "microsoft/codebert-base" checkpoint name, and
# the summarization model is built by get_model and placed on the CPU.
model_name = "microsoft/codebert-base"
config = RobertaConfig.from_pretrained(model_name)
tokenizer = RobertaTokenizer.from_pretrained(model_name, do_lower_case=False)
model = get_model(model_class=RobertaModel, config=config, tokenizer=tokenizer).to('cpu')

def estimate(example):
    """Return the generated descriptions for one source-code string.

    Relies on the module-level ``tokenizer`` and ``model``; the result is the
    list produced by ``conclusion`` for a single-example dataset.
    """
    dataset = get_feature([example], tokenizer)
    return conclusion(dataset, model, tokenizer)

# Demo: small predicate-style functions. Each call prints the list returned
# by estimate() for one snippet; the snippets vary the function name, the
# parameter name and the body.
print(estimate("""
def is_digit(obj):
    return obj.isdigit()
"""))

print(estimate("""
def is_digit(a):
    return a.isdigit()
"""))

print(estimate("""
def is_digit(obj):
    if int(obj) == 0:
        return True
    elif int(obj) == 1:
        return True
    elif int(obj) == 2:
        return True
    elif int(obj) == 3:
        return True
    elif int(obj) == 4:
        return True
    elif int(obj) == 5:
        return True
    elif int(obj) == 6:
        return True
    elif int(obj) == 7:
        return True
    elif int(obj) == 8:
        return True
    elif int(obj) == 9:
        return True
    return False
"""))

print(estimate("""
def is_digit(obj):
    if int(obj) == 0 or int(obj) == 1 or int(obj) == 2 or int(obj) == 3 or int(obj) == 4 or int(obj) == 5 or int(obj) == 6 or int(obj) == 7 or int(obj) == 8 or int(obj) == 9:
        return True
    return False
"""))

print(estimate("""
def check_if_an_object_is_a_digit(obj):
    return obj.isdigit()
"""))

# The function name says "str" while the body still calls isdigit().
print(estimate("""
def check_if_an_object_is_a_str(bj):
    return obj.isdigit()
"""))

print(estimate("""
def f(obj):
    return obj.isdigit()
"""))

print(estimate("""
def is_number(obj):
    try:
        int(obj)
        return True
    except Exception:
        return False
"""))

print(estimate("""
def foo():
    pass
"""))


# Demo: Fibonacci implementations (two iterative variants, a recursive one,
# and an iterative one with non-descriptive names).
print(estimate("""
def fib(number):
    fib1 = 1
    fib2 = 1

    i = 0
    while i < number - 2:
        fib_sum = fib1 + fib2
        fib1 = fib2
        fib2 = fib_sum
        i = i + 1

    return fib2
"""))

print(estimate("""
def fib(number):
    fib1 = fib2 = 1
    number -= 2

    while number > 0:
        fib1, fib2 = fib2, fib1 + fib2
        number -= 1

    return fib2
"""))

print(estimate("""
def fibonacci(n):
    if n in (1, 2):
        return 1
    return fibonacci(n - 1) + fibonacci(n - 2)
"""))

# Same loop as the second variant, but with a neutral function name and
# single-letter variables.
print(estimate("""
def some_function(n):
    a = b = 1
    n -= 2

    while n > 0:
        a, b = b, a + b
        n -= 1

    return b
"""))

# Demo: in-place quicksort with a randomly chosen pivot.
print(estimate("""
def quicksort(nums, fst, lst):
   if fst >= lst: return
 
   i, j = fst, lst
   pivot = nums[random.randint(fst, lst)]
 
   while i <= j:
       while nums[i] < pivot: i += 1
       while nums[j] > pivot: j -= 1
       if i <= j:
           nums[i], nums[j] = nums[j], nums[i]
           i, j = i + 1, j - 1
   quicksort(nums, fst, j)
   quicksort(nums, i, lst)
"""))

# Demo: the same bubble sort twice -- first with inline (Russian) comments
# inside the snippet, then with the comments removed.
print(estimate("""
def bSort(array):
    # определяем длину массива
    length = len(array)
    #Внешний цикл, количество проходов N-1
    for i in range(length):
        # Внутренний цикл, N-i-1 проходов
        for j in range(0, length-i-1):
            #Меняем элементы местами
            if array[j] > array[j+1]:
                temp = array[j]
                array[j] = array[j+1]
                array[j+1] = temp
"""))

print(estimate("""
def bSort(array):
    length = len(array)
    for i in range(length):
        for j in range(0, length-i-1):
            if array[j] > array[j+1]:
                temp = array[j]
                array[j] = array[j+1]
                array[j+1] = temp
"""))

# Demo: iterative binary search that also prints the sub-array at each step.
print(estimate("""
def binary_search_iterative(array, element):
    mid = 0
    start = 0
    end = len(array)
    step = 0

    while (start <= end):
        print("Subarray in step {}: {}".format(step, str(array[start:end+1])))
        step = step+1
        mid = (start + end) // 2

        if element == array[mid]:
            return mid

        if element < array[mid]:
            end = mid - 1
        else:
            start = mid + 1
    return -1
"""))


['Check if obj is a digit .']
['Check if a string is a digit .']
['Check if an integer is a valid digit .']
['Check if an object is a valid digit']
['Check if an object is a digit .']
['Check if the object is a string']
['Return boolean value .']
['Check if an object is a number']
['This function is called every time .']
['Fulfillacci function']
['Fulfillacci fib1 .']
[' fibonacci number']
['wrapper around n times']
['Sort a list of numbers in fst .']
['BSort function .']
['Sort an array .']
['Binary search method .']
