In [2]:
import torch

In [3]:
# Load model directly
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
model = AutoModel.from_pretrained("microsoft/codebert-base")

In [215]:
file_path = "/home/hasinthaka/Documents/Projects/AI/AI Pattern Mining/Pattern Validator/reposistories/test/main.py"
with open(file_path, "r") as file:
    code = file.read()

In [216]:
chunk_size = 128
stride = 68

In [217]:
all_tokens = tokenizer.encode(code, add_special_tokens=False)
total_length = len(all_tokens)

In [218]:
if total_length <= chunk_size:
    inputs = tokenizer(code,return_tensors='pt')
else:
    chunks = []
    for i in range(0,total_length,stride):
        chunk = all_tokens[i:i+chunk_size]
        chunks.append(chunk)
    input_ids = [tokenizer.build_inputs_with_special_tokens(chunk) for chunk in chunks]
    max_len = max(len(ids) for ids in input_ids)
    attention_masks = []
    padded_input_ids = []
    for chunk in input_ids:
        padding_length = max_len - len(chunk)
        padded_input_ids.append(chunk+[tokenizer.pad_token_id]*padding_length)
        attention_masks.append([1]*len(chunk)+[0]*padding_length)
    inputs = {'input_ids': torch.tensor(padded_input_ids), 'attention_mask': torch.tensor(attention_masks)}


with torch.no_grad():
    outputs = model(**inputs)

In [None]:
code_embedding1 = torch.mean(outputs.pooler_output,0)

In [223]:
expanded_attention_mask = inputs['attention_mask'].unsqueeze(-1).expand(outputs.last_hidden_state.shape)
masked_embeddings = expanded_attention_mask * outputs.last_hidden_state
code_embedding2 = (masked_embeddings.sum(1)/ expanded_attention_mask.sum(1)).mean(0)


In [225]:
code_embedding1.shape

torch.Size([768])

In [226]:
code_embedding2.shape

torch.Size([768])

In [229]:
(code_embedding2 == code_embedding1).all()

tensor(False)

In [230]:
code_embedding2

tensor([-3.3387e-01,  1.7264e-01,  2.8915e-01,  5.4412e-02, -4.0153e-01,
        -6.0531e-01, -7.0521e-02,  2.5179e-01,  4.7273e-01,  4.5606e-01,
        -3.2290e-01,  8.7772e-01, -2.4588e-01, -6.1590e-02,  8.6888e-01,
        -3.9303e-02,  2.4835e-01,  2.9010e-01, -1.2444e-01,  1.9904e-01,
        -2.7896e-01, -2.1683e-01,  6.1102e-01, -3.4378e-01,  4.5325e-01,
         3.7019e-01,  1.7254e-02,  7.9910e-01, -5.1181e-01,  9.1907e-01,
        -2.6335e-01, -6.2840e-02,  1.3305e+00,  1.9176e-01,  3.4877e-01,
        -3.4533e-01, -3.3068e-01,  1.9814e-01,  2.0197e-01, -2.7505e-01,
         6.7858e-02,  6.5950e-01, -9.8038e-01, -1.2639e-01,  4.8033e-01,
         3.1528e-01,  5.2230e-01, -6.3578e-02,  5.2775e-02,  6.0601e-01,
         5.4919e-01,  2.8437e-01, -5.1877e-01, -2.1618e-01,  5.1320e-01,
         6.1802e-01, -1.0959e+00, -8.0941e-01, -1.3728e-01, -5.3638e-01,
        -1.6359e-01, -2.0480e-01, -3.7444e-01, -2.6077e-01,  1.3112e+00,
         2.4611e-01,  5.5596e-01,  6.1328e-01, -1.7

In [231]:
code_embedding1

tensor([ 5.1521e-01, -5.1906e-01, -6.4107e-01,  1.2329e-01,  2.0561e-01,
        -8.2672e-02,  5.9149e-01, -3.5023e-01,  1.3595e-02, -3.7011e-01,
         4.7680e-01,  7.6308e-02, -4.9111e-01,  1.5376e-01, -5.6089e-02,
         5.8688e-01,  4.8604e-01, -5.1976e-01,  7.8035e-02,  4.1510e-01,
        -1.3623e-01,  6.4044e-01,  9.8691e-02,  1.3483e-01, -3.9965e-01,
         3.4489e-01,  3.0546e-01, -2.8339e-02,  5.9357e-01,  1.6499e-01,
        -2.1384e-01,  2.0948e-01,  3.8712e-01, -1.2605e-01, -3.7100e-01,
        -7.6472e-03, -4.2127e-01,  1.0034e-01,  7.9731e-01, -1.0827e-01,
        -4.0054e-01,  2.8453e-02, -5.4524e-02, -4.5908e-01,  1.1625e-01,
         6.6270e-01,  1.5066e-01, -6.2702e-02, -9.7396e-02, -3.1632e-01,
        -5.7042e-01,  5.8264e-01,  3.0499e-01,  1.6677e-01, -3.2992e-01,
         1.6760e-01,  2.1865e-01, -3.1935e-01,  1.7703e-01, -4.5352e-01,
        -4.3923e-01, -4.4463e-01,  1.2528e-01,  1.0373e-01, -8.4956e-02,
        -6.4537e-02,  5.1329e-01,  1.0935e-01, -3.0