In [33]:
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
import torch

In [34]:
# Path to the saved model
model_path = "detecting_model"

# Load the model
detect_model = AutoModelForSequenceClassification.from_pretrained("detecting_model_notmarked")
categorizing_model = AutoModelForSequenceClassification.from_pretrained("categorizing_model_not_marked")
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path)

cwe_to_label = {'CWE121': 0,
 'CWE78': 1,
 'CWE190': 2,
 'CWE191': 3,
 'CWE122': 4,
 'CWE134': 5,
 'CWE124': 6,
 'CWE127': 7,
 'CWE126': 8,
 'CWE195': 9,
 'CWE194': 10,
 'CWE401': 11,
 'CWE690': 12,
 'CWE197': 13,
 'CWE369': 14,
 'CWE590': 15,
 'CWE400': 16,
 'CWE253': 17,
 'CWE761': 18,
 'CWE114': 19,
 'CWE252': 20,
 'CWE457': 21,
 'CWE427': 22,
 'CWE789': 23,
 'CWE90': 24,
 'CWE606': 25}

label_to_cwe = { v: k for k, v in cwe_to_label.items() }

In [35]:
c_code = '''
 
 
 #include ""std_testcase.h""
 
 #include <wchar.h>
 
 #ifdef _WIN32
 #include <winsock2.h>
 #include <windows.h>
 #include <direct.h>
 #pragma comment(lib, ""ws2_32"") 
 #define CLOSE_SOCKET closesocket
 #else 
 #include <sys/types.h>
 #include <sys/socket.h>
 #include <netinet/in.h>
 #include <arpa/inet.h>
 #include <unistd.h>
 #define INVALID_SOCKET -1
 #define SOCKET_ERROR -1
 #define CLOSE_SOCKET close
 #define SOCKET int
 #endif
 
 #define TCP_PORT 27015
 #define IP_ADDRESS ""127.0.0.1""
 
 
 #ifndef OMITBAD
 
 
 void Process_Control__w32_wchar_t_connect_socket_64b_badSink(void * dataVoidPtr);
 
 void Process_Control__w32_wchar_t_connect_socket_64_bad()
 {
     wchar_t * data;
     wchar_t dataBuffer[100] = L"""";
     data = dataBuffer;
     {
 #ifdef _WIN32
         WSADATA wsaData;
         int wsaDataInit = 0;
 #endif
         int recvResult;
         struct sockaddr_in service;
         wchar_t *replace;
         SOCKET connectSocket = INVALID_SOCKET;
         size_t dataLen = wcslen(data);
         do
         {
 #ifdef _WIN32
             if (WSAStartup(MAKEWORD(2,2), &wsaData) != NO_ERROR)
             {
                 break;
             }
             wsaDataInit = 1;
 #endif
             
             connectSocket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
             if (connectSocket == INVALID_SOCKET)
             {
                 break;
             }
             memset(&service, 0, sizeof(service));
             service.sin_family = AF_INET;
             service.sin_addr.s_addr = inet_addr(IP_ADDRESS);
             service.sin_port = htons(TCP_PORT);
             if (connect(connectSocket, (struct sockaddr*)&service, sizeof(service)) == SOCKET_ERROR)
             {
                 break;
             }
             
             
             recvResult = recv(connectSocket, (char *)(data + dataLen), sizeof(wchar_t) * (100 - dataLen - 1), 0);
             if (recvResult == SOCKET_ERROR || recvResult == 0)
             {
                 break;
             }
             
             data[dataLen + recvResult / sizeof(wchar_t)] = L'\0';
             
             replace = wcschr(data, L'\r');
             if (replace)
             {
                 *replace = L'\0';
             }
             replace = wcschr(data, L'\n');
             if (replace)
             {
                 *replace = L'\0';
             }
         }
         while (0);
         if (connectSocket != INVALID_SOCKET)
         {
             CLOSE_SOCKET(connectSocket);
         }
 #ifdef _WIN32
         if (wsaDataInit)
         {
             WSACleanup();
         }
 #endif
     }
     Process_Control__w32_wchar_t_connect_socket_64b_badSink(&data);
 }
 
 #endif 
 
 #ifndef OMITGOOD
 
 
 void Process_Control__w32_wchar_t_connect_socket_64b_goodG2BSink(void * dataVoidPtr);
 
 static void goodG2B()
 {
     wchar_t * data;
     wchar_t dataBuffer[100] = L"""";
     data = dataBuffer;
     
     wcscpy(data, L""C:\\Windows\\System32\\winsrv.dll"");
     Process_Control__w32_wchar_t_connect_socket_64b_goodG2BSink(&data);
 }
 
 void Process_Control__w32_wchar_t_connect_socket_64_good()
 {
     goodG2B();
 }
 
 #endif 
 
 
 
 #ifdef INCLUDEMAIN
 
 int main(int argc, char * argv[])
 {
     
     srand( (unsigned)time(NULL) );
 #ifndef OMITGOOD
     printLine(""Calling good()..."");
     Process_Control__w32_wchar_t_connect_socket_64_good();
     printLine(""Finished good()"");
 #endif 
 #ifndef OMITBAD
     printLine(""Calling bad()..."");
     Process_Control__w32_wchar_t_connect_socket_64_bad();
     printLine(""Finished bad()"");
 #endif 
     return 0;
 }
 
 #endif

'''

In [36]:
# Tokenize the C code
inputs = tokenizer(c_code, return_tensors="pt", truncation=True)

# Run the tokenized input through the model
outputs = detect_model(**inputs)

# Print the model outputs
print(outputs)
predicted_class_id = outputs.logits.argmax().item()
print(predicted_class_id)

SequenceClassifierOutput(loss=None, logits=tensor([[-3.6838,  3.9907]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)
1


In [37]:
outputs = categorizing_model(**inputs)
print(outputs)
probs = outputs.logits.softmax(dim = -1)
top_probs, top_classes = torch.topk(probs, k=10)
print(top_probs, top_classes)
for i, cls in enumerate(top_classes[0]):
    print(label_to_cwe[cls.item()], top_probs[0][i].item())

SequenceClassifierOutput(loss=None, logits=tensor([[-0.7525,  0.0562,  0.1354, -0.5951, -0.5145,  0.2050,  0.7018,  0.0966,
         -1.5388, -0.6051, -0.0498, -0.3766, -0.6496, -0.7869, -0.2401, -0.2440,
          0.3059, -0.4949,  0.1224,  8.0850, -0.4987, -0.1573,  0.1053,  0.0283,
          0.5276, -0.5940]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)
tensor([[9.9325e-01, 6.1744e-04, 5.1870e-04, 4.1557e-04, 3.7569e-04, 3.5044e-04,
         3.4590e-04, 3.4003e-04, 3.3708e-04, 3.2374e-04]],
       grad_fn=<TopkBackward0>) tensor([[19,  6, 24, 16,  5,  2, 18, 22,  7,  1]])
CWE114 0.9932481050491333
CWE124 0.0006174429436214268
CWE90 0.0005187038914300501
CWE400 0.0004155689966864884
CWE134 0.0003756937221623957
CWE190 0.00035043794196099043
CWE761 0.00034589937422424555
CWE427 0.00034003224573098123
CWE127 0.00033707942930050194
CWE78 0.00032373963040299714
