# **CodeBERT Experiments**

1. [x] load CodeBERT model
2. [x] run inference / get embeddings
3. [x] inverse embeddings
4. [x] edit embeddings
5. [x] inverse edited embeddings

## References

* https://huggingface.co/microsoft/codebert-base-mlm

## Environment

In [1]:
# Record which Python the shell resolves to (`!` runs the command in a subshell), for reproducibility.
!python --version

Python 3.7.10


In [2]:
!pip install tokenizers
!pip install transformers

Collecting tokenizers
[?25l  Downloading https://files.pythonhosted.org/packages/71/23/2ddc317b2121117bf34dd00f5b0de194158f2a44ee2bf5e47c7166878a97/tokenizers-0.10.1-cp37-cp37m-manylinux2010_x86_64.whl (3.2MB)
[K     |████████████████████████████████| 3.2MB 16.1MB/s 
[?25hInstalling collected packages: tokenizers
Successfully installed tokenizers-0.10.1
Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/ed/d5/f4157a376b8a79489a76ce6cfe147f4f3be1e029b7144fa7b8432e8acb26/transformers-4.4.2-py3-none-any.whl (2.0MB)
[K     |████████████████████████████████| 2.0MB 16.9MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 37.9MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses:

## Dependencies

In [3]:
import torch
import numpy as np
from scipy.special import softmax
from transformers import pipeline
from transformers import RobertaTokenizer, RobertaForMaskedLM

## Settings

In [6]:
# Hugging Face Hub identifier of the checkpoint loaded below (the CodeBERT masked-LM model listed under References).
model_name = 'microsoft/codebert-base-mlm'

## 1. Load CodeBERT model

In [7]:
# Load the masked-LM model and its tokenizer for `model_name`. Each statement is
# wrapped in `%time` so its cost is reported in the cell output (the first run
# also downloads the files, as the progress bars below show).
%time model = RobertaForMaskedLM.from_pretrained(model_name)
%time tokenizer = RobertaTokenizer.from_pretrained(model_name)

# Bundle both into a fill-mask pipeline; later cells reach the model via `fill_mask.model`.
%time fill_mask = pipeline('fill-mask', model=model, tokenizer=tokenizer)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=504.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=501201999.0, style=ProgressStyle(descri…


CPU times: user 14.1 s, sys: 1.97 s, total: 16.1 s
Wall time: 16.8 s


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898822.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=150.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=25.0, style=ProgressStyle(description_w…


CPU times: user 433 ms, sys: 44.7 ms, total: 477 ms
Wall time: 738 ms
CPU times: user 409 µs, sys: 0 ns, total: 409 µs
Wall time: 414 µs


## 2. Run inference

In [8]:
# Inference only: switch autograd off globally for the rest of the notebook.
# The trailing `;` stops the returned `set_grad_enabled` object from being
# echoed as cell output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x7fcac7775190>

In [9]:
# Two one-line snippets that differ only in the arithmetic operator.
text1, text2 = 'x = a + b', 'x = a - b'
texts = [text1, text2]

In [10]:
# Tokenize each snippet on its own and keep the PyTorch-tensor encoding keyed by the source text.
tokens = {text: tokenizer(text, return_tensors='pt') for text in texts}

# Dump every field of each encoding; for the ids, additionally show the raw
# vocabulary tokens and the string obtained by decoding the ids back.
for text, tokens_pt in tokens.items():
  print(f'text: >{text}<')
  for field, tensor in tokens_pt.items():
    print(f'\t{field}: {tensor}')
    if field != 'input_ids':
      continue
    token_strs = [tokenizer.convert_ids_to_tokens(row) for row in tensor]
    print(f'\t\ttokens (str): {token_strs}')
    decoded = [tokenizer.decode(row) for row in tensor]
    print(f'\t\t#decoding: {decoded}')

text: >x = a + b<
	input_ids: tensor([[   0, 1178, 5457,   10, 2055,  741,    2]])
		tokens (str): [['<s>', 'x', 'Ġ=', 'Ġa', 'Ġ+', 'Ġb', '</s>']]
		#decoding: ['<s>x = a + b</s>']
	attention_mask: tensor([[1, 1, 1, 1, 1, 1, 1]])
text: >x = a - b<
	input_ids: tensor([[   0, 1178, 5457,   10,  111,  741,    2]])
		tokens (str): [['<s>', 'x', 'Ġ=', 'Ġa', 'Ġ-', 'Ġb', '</s>']]
		#decoding: ['<s>x = a - b</s>']
	attention_mask: tensor([[1, 1, 1, 1, 1, 1, 1]])


In [13]:
# Run each tokenized snippet through the encoder submodule only (`.roberta`; the
# LM head is applied separately in a later cell) and store the `last_hidden_state`
# of the result keyed by the source text. The forward pass is wrapped in `%time`
# so its cost shows up in the cell output.
# NOTE(review): each stored tensor is presumably (1, seq_len, hidden_size) — confirm.
embeddings = {}
for text, tokens_pt in tokens.items():
  %time output = fill_mask.model.roberta(**tokens_pt)

  last_hidden_state = output.last_hidden_state

  embeddings[text] = last_hidden_state

CPU times: user 93 ms, sys: 1.08 ms, total: 94 ms
Wall time: 202 ms
CPU times: user 74.1 ms, sys: 0 ns, total: 74.1 ms
Wall time: 73.7 ms


## 3. Inverse embeddings

In [22]:
# Map each snippet's hidden states back onto the vocabulary with the LM head and,
# for every token position, list the ten most probable tokens with their probabilities.
for text, embedding in embeddings.items():
  print(text)
  lm_head_output = fill_mask.model.lm_head(embedding)

  # Index 0 picks the only sequence in the batch; iterating it yields one logit row per position.
  for position_logits in lm_head_output[0]:
    probs = softmax(position_logits.detach().cpu().numpy())
    # Ascending argsort reversed, then truncated: the ten highest-probability ids, best first.
    top_ids = probs.argsort()[::-1][:10]
    print([(tokenizer.decode(int(token_id)), round(probs[token_id], 3)) for token_id in top_ids])

x = a + b
[('\n', 0.738), ('x', 0.022), (' )', 0.016), (' }', 0.016), ('.', 0.012), (' ]', 0.007), ('.', 0.007), ('_', 0.006), (' "', 0.006), (" '", 0.005)]
[('x', 0.994), ('y', 0.004), (' x', 0.0), ('z', 0.0), ('b', 0.0), ('xy', 0.0), ('xx', 0.0), ('xc', 0.0), (' y', 0.0), ('X', 0.0)]
[(' =', 0.999), (' +=', 0.0), (' +', 0.0), (' (', 0.0), (' -', 0.0), ('=', 0.0), (' -=', 0.0), (' :', 0.0), (',', 0.0), (' >', 0.0)]
[(' a', 1.0), (' b', 0.0), (' 1', 0.0), ('a', 0.0), (' ax', 0.0), (' A', 0.0), (' 0', 0.0), (' c', 0.0), (' h', 0.0), (' e', 0.0)]
[(' +', 1.0), (' -', 0.0), (' *', 0.0), ('+', 0.0), (' plus', 0.0), (' /', 0.0), (' +=', 0.0), (' ^', 0.0), (' x', 0.0), (' =', 0.0)]
[(' b', 1.0), (' a', 0.0), (' B', 0.0), (' 1', 0.0), (' 2', 0.0), (' f', 0.0), ('b', 0.0), (' ab', 0.0), (' y', 0.0), (' c', 0.0)]
[(' b', 0.123), (' p', 0.054), (' y', 0.052), (' +', 0.043), (' s', 0.027), (' -', 0.027), (' #', 0.024), (' m', 0.023), (' x', 0.021), (' c', 0.02)]
x = a - b
[('\n', 0.537), ('x', 0.

## 4. Edit embeddings

In [19]:
# Blend the two snippets: take the element-wise midpoint between the hidden states
# of the "+" and the "-" version (both tokenize to the same number of tokens).
plus_embeddings, minus_embeddings = embeddings[text1], embeddings[text2]
new_embeddings = (plus_embeddings + minus_embeddings) / 2

## 5. Inverse edited embeddings

In [23]:
# Decode the blended hidden states through the LM head and print, per token
# position, the ten most probable tokens with their probabilities.
lm_head_output = fill_mask.model.lm_head(new_embeddings)

# Index 0 picks the only sequence in the batch; each row holds the logits of one position.
for position_logits in lm_head_output[0]:
  probs = softmax(position_logits.detach().cpu().numpy())
  # Ascending argsort reversed, then truncated: the ten highest-probability ids, best first.
  top_ids = probs.argsort()[::-1][:10]
  ranked = [(tokenizer.decode(int(token_id)), round(probs[token_id], 3)) for token_id in top_ids]
  print(ranked)

[('\n', 0.648), ('x', 0.032), (' )', 0.022), ('.', 0.018), (' }', 0.014), (' ]', 0.013), ('_', 0.008), ('.', 0.008), (' :', 0.008), (' <', 0.007)]
[('x', 0.995), ('y', 0.003), (' x', 0.001), ('z', 0.0), ('b', 0.0), ('xy', 0.0), ('xx', 0.0), ('xc', 0.0), (' y', 0.0), ('index', 0.0)]
[(' =', 0.998), (' +', 0.0), (' -', 0.0), (' (', 0.0), (' +=', 0.0), ('=', 0.0), (' :', 0.0), (' >', 0.0), (' -=', 0.0), (' <', 0.0)]
[(' a', 1.0), (' b', 0.0), (' 1', 0.0), (' ax', 0.0), (' 0', 0.0), (' c', 0.0), ('a', 0.0), (' A', 0.0), (' x', 0.0), (' e', 0.0)]
[(' +', 0.899), (' -', 0.1), (' *', 0.0), (' ^', 0.0), (' /', 0.0), (' x', 0.0), (' >', 0.0), (' <', 0.0), (' :', 0.0), ('.', 0.0)]
[(' b', 1.0), (' a', 0.0), (' B', 0.0), (' 2', 0.0), (' 1', 0.0), (' f', 0.0), ('b', 0.0), (' ab', 0.0), (' y', 0.0), (' c', 0.0)]
[(' b', 0.157), (' -', 0.061), (' +', 0.058), (' #', 0.045), (' y', 0.045), (' x', 0.033), (' p', 0.032), (' s', 0.023), (' c', 0.023), (' ', 0.022)]
