# Preparing the environment
Loading models and tokenizers for CLAP and calculating SHA-256 hash of CLAP model

In [1]:
import torch.multiprocessing
import torch
import json
from transformers import AutoModel, AutoTokenizer

# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs end-to-end (more slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(security): trust_remote_code=True executes Python code shipped in the
# Hub repositories named below. Keep it only for repositories you trust, and
# consider pinning a `revision` so the executed code cannot change between runs.
asm_tokenizer = AutoTokenizer.from_pretrained("hustcw/clap-asm", trust_remote_code=True)
text_tokenizer = AutoTokenizer.from_pretrained("hustcw/clap-text", trust_remote_code=True)

# Both encoders are moved to `device` once here and reused by every cell below.
asm_encoder = AutoModel.from_pretrained("hustcw/clap-asm", trust_remote_code=True).to(device)
text_encoder = AutoModel.from_pretrained("hustcw/clap-text", trust_remote_code=True).to(device)

# Case-study inputs: one JSON file per zero-shot experiment in this notebook.
bubble_output = "./CaseStudy/bubblesort.json"
malware_output = "./CaseStudy/malware.json"
sha3 = "./CaseStudy/sha3.json"

# Fine-grained sorting algorithm classification (Zero-Shot)

In [2]:

# Zero-shot classification of one assembly function against sorting-algorithm
# prompts: embed the function and every prompt, score each pair by dot
# product, and turn the scores into a probability distribution.

# Load the function under test; it is passed to asm_tokenizer as a single batch item.
with open(bubble_output) as fp:
    asm = json.load(fp)

# Candidate labels, phrased as natural-language descriptions.
# NOTE(review): the first prompt carries a trailing space; it is kept verbatim
# so the tokenizer input (and thus the recorded output) is unchanged.
prompts = [
    "This is a function related to bubble sort ",
    "This is a function related to selection sort",
    "This is a function related to insertion sort",
    "This is a function related to merge sort",
    "This is a function related to quick sort",
    "This is a function related to radix sort",
    "This is a function related to shell sort",
    "This is a function related to counting sort",
    "This is a function related to bucket sort",
    "This is a function related to heap sort",
]

# Softmax temperature: logits are divided by this value before normalisation.
temperature = 0.07

# Inference only, so gradient tracking is disabled for both encoders.
with torch.no_grad():
    asm_input = asm_tokenizer([asm], padding=True, pad_to_multiple_of=8, return_tensors="pt", verbose=False)
    asm_input = asm_input.to(device)
    asm_embedding = asm_encoder(**asm_input)

    text_input = text_tokenizer(prompts, padding=True, truncation=True, return_tensors='pt')
    text_input = text_input.to(device)
    text_embeddings = text_encoder(**text_input)

# Dot product of every assembly embedding (n rows) with every prompt
# embedding (k columns): result is an (n, k) score matrix.
logits = torch.einsum("nc,ck->nk", [asm_embedding, text_embeddings.T])

# Normalise over the prompt axis; squeeze(0) drops the leading axis, which has
# size 1 because a single function was embedded, leaving one float per prompt.
preds = torch.softmax(logits / temperature, dim=1).squeeze(0).tolist()

print("bubblesort zeroshot:")
for prompt, prob in zip(prompts, preds):
    print(f"Probability: {round(prob*100, 3)}%, Text: {prompt}")


bubblesort zeroshot:
Probability: 17.954%, Text: This is a function related to bubble sort 
Probability: 6.919%, Text: This is a function related to selection sort
Probability: 11.567%, Text: This is a function related to insertion sort
Probability: 5.261%, Text: This is a function related to merge sort
Probability: 9.474%, Text: This is a function related to quick sort
Probability: 12.454%, Text: This is a function related to radix sort
Probability: 12.879%, Text: This is a function related to shell sort
Probability: 9.756%, Text: This is a function related to counting sort
Probability: 9.351%, Text: This is a function related to bucket sort
Probability: 4.385%, Text: This is a function related to heap sort


# Fine-grained malware functionality classification (Zero-Shot)

In [3]:
# Zero-shot classification of one assembly function against malware-behaviour
# prompts: embed the function and every prompt, score each pair by dot
# product, and turn the scores into a probability distribution.

# Load the function under test; it is passed to asm_tokenizer as a single batch item.
with open(malware_output) as fp:
    asm = json.load(fp)

# Candidate labels, phrased as natural-language descriptions.
prompts = [
    "This is a function related to screen shot",
    "This is a function related to auto start",
    "This is a function related to backdoor",
    "This is a function related to download",
    "This is a function related to upload",
    "This is a function related to rootkit",
    "This is a function related to anti detect",
    "This is a function related to anti debug",
    "This is a function related to passwords brute force",
    "This is a function related to file hijack",
]

# Softmax temperature: logits are divided by this value before normalisation.
temperature = 0.07

# Inference only, so gradient tracking is disabled for both encoders.
with torch.no_grad():
    asm_input = asm_tokenizer([asm], padding=True, pad_to_multiple_of=8, return_tensors="pt", verbose=False)
    asm_input = asm_input.to(device)
    asm_embedding = asm_encoder(**asm_input)

    text_input = text_tokenizer(prompts, padding=True, truncation=True, return_tensors='pt')
    text_input = text_input.to(device)
    text_embeddings = text_encoder(**text_input)

# Dot product of every assembly embedding (n rows) with every prompt
# embedding (k columns): result is an (n, k) score matrix.
logits = torch.einsum("nc,ck->nk", [asm_embedding, text_embeddings.T])

# Normalise over the prompt axis; squeeze(0) drops the leading axis, which has
# size 1 because a single function was embedded, leaving one float per prompt.
preds = torch.softmax(logits / temperature, dim=1).squeeze(0).tolist()

print("malware zeroshot:")
for prompt, prob in zip(prompts, preds):
    print(f"Probability: {round(prob*100, 3)}%, Text: {prompt}")


malware zeroshot:
Probability: 75.98%, Text: This is a function related to screen shot
Probability: 7.844%, Text: This is a function related to auto start
Probability: 1.515%, Text: This is a function related to backdoor
Probability: 1.616%, Text: This is a function related to download
Probability: 2.431%, Text: This is a function related to upload
Probability: 3.327%, Text: This is a function related to rootkit
Probability: 1.482%, Text: This is a function related to anti detect
Probability: 3.228%, Text: This is a function related to anti debug
Probability: 0.919%, Text: This is a function related to passwords brute force
Probability: 1.658%, Text: This is a function related to file hijack


# Fine-grained crypto algorithm classification (Zero-Shot)

In [4]:
# Zero-shot classification of one assembly function against crypto-algorithm
# prompts (plus one sorting prompt): embed the function and every prompt,
# score each pair by dot product, and turn the scores into probabilities.

# Load the function under test; it is passed to asm_tokenizer as a single batch item.
with open(sha3) as fp:
    asm = json.load(fp)

# Candidate labels, phrased as natural-language descriptions.
prompts = [
    "This is a function related to sha3",
    "This is a function related to des",
    "This is a function related to bubble sort",
    "This is a function related to md5",
    "This is a function related to rsa",
    "This is a function related to sm4",
]

# Softmax temperature: logits are divided by this value before normalisation.
temperature = 0.07

# Inference only, so gradient tracking is disabled for both encoders.
with torch.no_grad():
    asm_input = asm_tokenizer([asm], padding=True, pad_to_multiple_of=8, return_tensors="pt", verbose=False)
    asm_input = asm_input.to(device)
    asm_embedding = asm_encoder(**asm_input)

    text_input = text_tokenizer(prompts, padding=True, truncation=True, return_tensors='pt')
    text_input = text_input.to(device)
    text_embeddings = text_encoder(**text_input)

# Dot product of every assembly embedding (n rows) with every prompt
# embedding (k columns): result is an (n, k) score matrix.
logits = torch.einsum("nc,ck->nk", [asm_embedding, text_embeddings.T])

# Normalise over the prompt axis; squeeze(0) drops the leading axis, which has
# size 1 because a single function was embedded, leaving one float per prompt.
preds = torch.softmax(logits / temperature, dim=1).squeeze(0).tolist()

print("sha3 zeroshot:")
for prompt, prob in zip(prompts, preds):
    print(f"Probability: {round(prob*100, 3)}%, Text: {prompt}")

sha3 zeroshot:
Probability: 62.579%, Text: This is a function related to sha3
Probability: 1.63%, Text: This is a function related to des
Probability: 3.479%, Text: This is a function related to bubble sort
Probability: 24.634%, Text: This is a function related to md5
Probability: 5.705%, Text: This is a function related to rsa
Probability: 1.974%, Text: This is a function related to sm4
