# Preparing the environment
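Load the CLAP assembly and text encoders together with their tokenizers, then compute the SHA-256 hash of the assembly encoder's weight file (`pytorch_model.bin`) so the exact checkpoint used for the results below can be verified.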

In [1]:
import torch.multiprocessing
import torch
import hashlib
import json
from transformers import MPNetTokenizerFast
from clap import AsmEncoder, AsmTokenizer, TextEncoder

# Local checkpoint / tokenizer directories loaded via `from_pretrained`.
# NOTE(review): every other directory uses an "-encoder"/"-tokenizer" suffix; confirm
# "asm-encode" (not "asm-encoder") is the intended directory name.
asm_encoder_path    = "./models/asm-encode"
text_encoder_path   = "./models/text-encoder"

asm_tokenizer_path  = "./models/asm-tokenizer"
text_tokenizer_path = "./models/text-tokenizer"

# Define the target device first so every model/tensor is moved with the same handle.
device = torch.device("cuda")

asm_tokenizer       = AsmTokenizer.from_pretrained(asm_tokenizer_path)
text_tokenizer      = MPNetTokenizerFast.from_pretrained(text_tokenizer_path)
# `.eval()` makes inference mode explicit (disables dropout-style training behaviour)
# rather than relying on whatever mode the loader leaves the module in.
asm_encoder         = AsmEncoder.from_pretrained(asm_encoder_path).to(device).eval()
text_encoder        = TextEncoder.from_pretrained(text_encoder_path).to(device).eval()

# Input files holding the disassembled functions classified in the cells below.
bubble_output       = "bubblesort.json"
malware_output      = "malware.json"
sha3                = "sha3.json"

# SHA-256 of the assembly-encoder weights, read in 1 MiB chunks to bound memory use.
hash_obj = hashlib.sha256()
with open(asm_encoder_path + "/pytorch_model.bin", "rb") as file:
    for chunk in iter(lambda: file.read(1 << 20), b""):
        hash_obj.update(chunk)
print(hash_obj.hexdigest())

e93bc8f9fbf39e753a35bc3be045217c2bdf348453914d13f97ef69e9bbaf36b


# Fine-grained sorting algorithm classification (Zero-Shot)

In [2]:

# Load the disassembled target function and show it so the reader sees the input.
with open(bubble_output) as fp:
    asm = json.load(fp)
print(asm)

# Candidate labels for zero-shot classification. Prompts share one template so that
# only the algorithm name differs between them.
prompts = [
    "This is a function related to bubble sort",
    "This is a function related to selection sort",
    "This is a function related to insertion sort",
    "This is a function related to merge sort",
    "This is a function related to quick sort",
    "This is a function related to radix sort",
    "This is a function related to shell sort",
    "This is a function related to counting sort",
    "This is a function related to bucket sort",
    "This is a function related to heap sort",
]

with torch.no_grad():
    # Embed the assembly function as a batch of one.
    tokens = asm_tokenizer.encode_function(asm)
    asm_input = asm_tokenizer.pad(
        [tokens], padding=True, pad_to_multiple_of=8, return_tensors="pt", verbose=False
    ).to(device)
    asm_embedding = asm_encoder(**asm_input)

    # Embed all candidate prompts in a single batch.
    encoded_input = text_tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(device)
    text_embeddings = text_encoder(**encoded_input)

# Dot-product similarity: (n_functions, dim) @ (dim, n_prompts) -> (n_functions, n_prompts).
logits = asm_embedding @ text_embeddings.T
# Temperature-scaled softmax over prompts turns similarities into probabilities.
# NOTE(review): 0.07 is presumably the temperature used during CLAP training — confirm.
preds = torch.softmax(logits / 0.07, dim=1).squeeze(0).tolist()

print("bubblesort zeroshot:")
for prompt, prob in zip(prompts, preds):
    print(f"Probability: {round(prob*100, 3)}%, Text: {prompt}")


{'0': 'endbr64', '1': 'mov     edx, 6', '2': 'xor     eax, eax', '3': 'cmp     edx, eax', '4': 'jle     short INSTR13', '5': 'mov     ecx, [rdi+rax*4]', '6': 'mov     esi, [rdi+rax*4+4]', '7': 'cmp     ecx, esi', '8': 'jle     short INSTR11', '9': 'mov     [rdi+rax*4], esi', '10': 'mov     [rdi+rax*4+4], ecx', '11': 'inc     rax', '12': 'jmp     short INSTR3', '13': 'dec     edx', '14': 'jnz     short INSTR2', '15': 'retn'}
bubblesort zeroshot:
Probability: 17.954%, Text: This is a function related to bubble sort 
Probability: 6.919%, Text: This is a function related to selection sort
Probability: 11.567%, Text: This is a function related to insertion sort
Probability: 5.261%, Text: This is a function related to merge sort
Probability: 9.474%, Text: This is a function related to quick sort
Probability: 12.454%, Text: This is a function related to radix sort
Probability: 12.879%, Text: This is a function related to shell sort
Probability: 9.756%, Text: This is a function related to coun

# Fine-grained malware functionality classification (Zero-Shot)

In [3]:
# Load the disassembled malware function to classify (not printed here).
with open(malware_output) as fp:
    asm = json.load(fp)

# Candidate malware behaviours used as zero-shot labels.
prompts = [
    "This is a function related to screen shot",
    "This is a function related to auto start",
    "This is a function related to backdoor",
    "This is a function related to download",
    "This is a function related to upload",
    "This is a function related to rootkit",
    "This is a function related to anti detect",
    "This is a function related to anti debug",
    "This is a function related to passwords brute force",
    "This is a function related to file hijack",
]

with torch.no_grad():
    # Embed the assembly function as a batch of one.
    tokens = asm_tokenizer.encode_function(asm)
    asm_input = asm_tokenizer.pad(
        [tokens], padding=True, pad_to_multiple_of=8, return_tensors="pt", verbose=False
    ).to(device)
    asm_embedding = asm_encoder(**asm_input)

    # Embed all candidate prompts in a single batch.
    encoded_input = text_tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(device)
    text_embeddings = text_encoder(**encoded_input)

# Dot-product similarity: (n_functions, dim) @ (dim, n_prompts) -> (n_functions, n_prompts).
logits = asm_embedding @ text_embeddings.T
# Temperature-scaled softmax over prompts turns similarities into probabilities.
# NOTE(review): 0.07 is presumably the temperature used during CLAP training — confirm.
preds = torch.softmax(logits / 0.07, dim=1).squeeze(0).tolist()

print("malware zeroshot:")
for prompt, prob in zip(prompts, preds):
    print(f"Probability: {round(prob*100, 3)}%, Text: {prompt}")


malware zeroshot:
Probability: 75.98%, Text: This is a function related to screen shot
Probability: 7.844%, Text: This is a function related to auto start
Probability: 1.515%, Text: This is a function related to backdoor
Probability: 1.616%, Text: This is a function related to download
Probability: 2.431%, Text: This is a function related to upload
Probability: 3.327%, Text: This is a function related to rootkit
Probability: 1.482%, Text: This is a function related to anti detect
Probability: 3.228%, Text: This is a function related to anti debug
Probability: 0.919%, Text: This is a function related to passwords brute force
Probability: 1.658%, Text: This is a function related to file hijack


# Fine-grained crypto algorithm classification (Zero-Shot)

In [4]:
# Load the disassembled crypto function and show it so the reader sees the input.
with open(sha3) as fp:
    asm = json.load(fp)
print(asm)

# Candidate labels: several crypto primitives plus one non-crypto prompt ("bubble sort").
prompts = [
    "This is a function related to sha3",
    "This is a function related to des",
    "This is a function related to bubble sort",
    "This is a function related to md5",
    "This is a function related to rsa",
    "This is a function related to sm4",
]

with torch.no_grad():
    # Embed the assembly function as a batch of one.
    tokens = asm_tokenizer.encode_function(asm)
    asm_input = asm_tokenizer.pad(
        [tokens], padding=True, pad_to_multiple_of=8, return_tensors="pt", verbose=False
    ).to(device)
    asm_embedding = asm_encoder(**asm_input)

    # Embed all candidate prompts in a single batch.
    encoded_input = text_tokenizer(prompts, padding=True, truncation=True, return_tensors="pt").to(device)
    text_embeddings = text_encoder(**encoded_input)

# Dot-product similarity: (n_functions, dim) @ (dim, n_prompts) -> (n_functions, n_prompts).
logits = asm_embedding @ text_embeddings.T
# Temperature-scaled softmax over prompts turns similarities into probabilities.
# NOTE(review): 0.07 is presumably the temperature used during CLAP training — confirm.
preds = torch.softmax(logits / 0.07, dim=1).squeeze(0).tolist()

print("sha3 zeroshot:")
for prompt, prob in zip(prompts, preds):
    print(f"Probability: {round(prob*100, 3)}%, Text: {prompt}")

{'0': 'endbr64', '1': 'push    rbp', '2': 'mov     rbp, rsp', '3': 'sub     rsp, 30h', '4': 'mov     [rbp+var_18], rdi', '5': 'mov     [rbp+var_20], rsi', '6': 'mov     [rbp+var_24], edx', '7': 'mov     rax, [rbp+var_18]', '8': 'mov     eax, [rax+0CCh]', '9': 'mov     [rbp+var_4], eax', '10': 'mov     [rbp+var_8], 0', '11': 'jmp     short INSTR40', '12': 'mov     rax, [rbp+var_18]', '13': 'mov     eax, [rax+0C8h]', '14': 'lea     ecx, [rax+1]', '15': 'mov     rdx, [rbp+var_18]', '16': 'mov     [rdx+0C8h], ecx', '17': 'mov     rcx, [rbp+var_18]', '18': 'movsxd  rdx, eax', '19': 'movzx   esi, byte ptr [rcx+rdx]', '20': 'mov     edx, [rbp+var_8]', '21': 'movsxd  rcx, edx', '22': 'mov     rdx, [rbp+var_20]', '23': 'add     rdx, rcx', '24': 'movzx   edx, byte ptr [rdx]', '25': 'xor     esi, edx', '26': 'mov     ecx, esi', '27': 'mov     rdx, [rbp+var_18]', '28': 'cdqe', '29': 'mov     [rdx+rax], cl', '30': 'mov     rax, [rbp+var_18]', '31': 'mov     eax, [rax+0C8h]', '32': 'cmp     [rbp+var