In [None]:
import numpy as np
import copy
from pathlib import Path
import pandas as pd
import random
import re
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchtext
import torchtext.vocab as vocab
import sklearn.metrics
from transformers import RobertaModel
from transformers import RobertaConfig
from sklearn.metrics import confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
from gensim.models.word2vec import Word2Vec
from torch.autograd import Variable
from torch import nn, optim
from torch.optim import SGD,Adam,RMSprop
from torch.utils.data import Dataset, DataLoader, IterableDataset
from clang import *


# Pin every random number generator used by the notebook to one seed so that
# repeated runs start from the same state.
seed = 1234
os.environ['PYTHONHASHSEED'] = str(seed)
for _seed_fn in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    np.random.seed,
    random.seed,
):
    _seed_fn(seed)

# Keep cuDNN enabled, but request deterministic algorithms and switch off the
# benchmark-based autotuner.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
#cindex.Config.set_library_file('/usr/lib/llvm-10/lib/libclang-10.so.1')

In [None]:
import warnings
# Suppress every warning for the rest of the session.
# NOTE(review): this also hides deprecation and numerical warnings raised by
# the libraries imported above; consider narrowing the filter to a category.
warnings.filterwarnings("ignore")

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# True only when running on CUDA and more than one GPU is visible.
multigpu = device.type == 'cuda' and torch.cuda.device_count() > 1
print('Device: ',device)
print('MultiGPU: ',multigpu)


In [None]:
## Training & vocab parameters
DATA_PATH = 'data'      # directory name used for the dataset files
VOCAB_SIZE = 50_000     # target vocabulary size
BATCH_SIZE = 128
EMBED_DIM = 768 #768
# Embedding table has two rows more than the vocabulary.
# NOTE(review): presumably the two extra rows are for padding/unknown ids - confirm.
EMBED_SIZE = VOCAB_SIZE + 2

In [None]:
#### LATEST Vocab 1

from tokenizers.pre_tokenizers import PreTokenizer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers import NormalizedString,PreTokenizedString
from typing import List

class MyTokenizer:
    """Custom pre-tokenizer (vocab v1) that lexes C/C++ text with libclang.

    Each clang token is turned into one ``NormalizedString``:

    * literal tokens become an abstract ``LITERAL_*`` placeholder chosen from
      the kind of the cursor they belong to and from their spelling;
    * identifier tokens whose cursor kind is listed in
      ``_IDENTIFIER_PLACEHOLDERS`` become the matching ``IDENTIFIER_*``
      placeholder;
    * every other token (keywords, punctuation, remaining identifiers) keeps
      its original spelling.

    The instance is meant to be wrapped with ``PreTokenizer.custom(...)``,
    which calls :meth:`pre_tokenize`.

    NOTE(review): ``regex_def`` is used for string-literal classification but
    is not imported in the visible cells - confirm it is defined before use.
    """

    # Config.set_library_file raises once libclang has been loaded, which made
    # re-running this cell (or any later cell that redefines the class) fail.
    # Only point the bindings at the shared library while that is still allowed.
    if not cindex.Config.loaded:
        cindex.Config.set_library_file('/usr/lib/llvm-10/lib/libclang-10.so.1')
    cidx = cindex.Index.create()

    # (attribute of ``regex_def``, placeholder) pairs, tried in this order for
    # string literals; the first pattern that matches wins.
    _STRING_LITERAL_RULES = (
        ('FILE_EXTENSION', 'LITERAL_STRING_FILEEXTENSION'),
        ('BOOLEAN', 'LITERAL_STRING_BOOLEAN'),
        ('LOGICAL_OP', 'LITERAL_STRING_LOGICALOP'),
        ('BITWISE_OP', 'LITERAL_STRING_BITWISEOP'),
        ('COMPARISON_OP', 'LITERAL_STRING_COMPARISONOP'),
        ('ARITHMETIC_OP', 'LITERAL_STRING_ARITHMETICOP'),
        ('CONTROLFLOW', 'LITERAL_STRING_CONTROLFLOW'),
        ('IPV4', 'LITERAL_STRING_IPV4'),
        ('IPV6', 'LITERAL_STRING_IPV6'),
        ('MIMETYPE', 'LITERAL_STRING_MIMETYPE'),
        ('FILEPATH', 'LITERAL_STRING_FILEPATH'),
        ('LANGUAGE_ISO', 'LITERAL_STRING_LANGUAGEISO'),
        ('STATUS', 'LITERAL_STRING_STATUS'),
        ('MONTH', 'LITERAL_STRING_MONTH'),
        ('SQL_STATEMENT', 'LITERAL_STRING_SQLSTATEMENT'),
        ('TIME', 'LITERAL_STRING_TIME'),
        ('NEWLINE', 'LITERAL_STRING_NEWLINE'),
        ('BASH_BASIC', 'LITERAL_STRING_BASH'),
    )

    # str(cursor.kind) -> placeholder emitted for an identifier token.
    _IDENTIFIER_PLACEHOLDERS = {
        "CursorKind.MEMBER_REF_EXPR": 'IDENTIFIER_MEMBERREFEXPR',
        "CursorKind.UNARY_OPERATOR": 'IDENTIFIER_UNARYOP',
        "CursorKind.CONSTRUCTOR": 'IDENTIFIER_CONSTRUCTOR',
        "CursorKind.CLASS_DECL": 'IDENTIFIER_CLASSDECL',
        "CursorKind.DEFAULT_STMT": 'IDENTIFIER_DEFAULTSTMT',
        "CursorKind.CSTYLE_CAST_EXPR": 'IDENTIFIER_CSTYLECASTEXPR',
        "CursorKind.COMPOUND_ASSIGNMENT_OPERATOR": 'IDENTIFIER_COMPOUNDASSIGNMENTOPERATOR',
        "CursorKind.WHILE_STMT": 'IDENTIFIER_WHILESTMT',
        "CursorKind.DO_STMT": 'IDENTIFIER_DOSTMT',
        "CursorKind.DESTRUCTOR": 'IDENTIFIER_DESTRUCTOR',
        "CursorKind.MEMBER_REF": 'IDENTIFIER_MEMBERREF',
        "CursorKind.ADDR_LABEL_EXPR": 'IDENTIFIER_ADDRLABELEXPR',
        "CursorKind.PACKED_ATTR": 'IDENTIFIER_PACKEDATTR',
        "CursorKind.ASM_STMT": 'IDENTIFIER_ASMSTMT',
        "CursorKind.CALL_EXPR": 'IDENTIFIER_CALLEXPR',
        "CursorKind.UNEXPOSED_DECL": 'IDENTIFIER_UNEXPOSEDDECL',
        "CursorKind.CXX_NEW_EXPR": 'IDENTIFIER_CXXNEWEXPR',
        "CursorKind.CXX_METHOD": 'IDENTIFIER_CXXMETHOD',
        "CursorKind.CXX_DELETE_EXPR": 'IDENTIFIER_CXXDELETEEXPR',
        "CursorKind.ALIGNED_ATTR": 'IDENTIFIER_ALIGNEDATTR',
        "CursorKind.STRING_LITERAL": 'IDENTIFIER_STRINGLITERAL',
        "CursorKind.LAMBDA_EXPR": 'IDENTIFIER_LAMBDAEXPR',
        "CursorKind.CXX_TRY_STMT": 'IDENTIFIER_CXXTRYSTMT',
        "CursorKind.CONDITIONAL_OPERATOR": 'IDENTIFIER_CONDITIONALOP',
        "CursorKind.CXX_REINTERPRET_CAST_EXPR": 'IDENTIFIER_CXXREINTERPRETCASTEXPR',
        "CursorKind.SWITCH_STMT": 'IDENTIFIER_SWITCHSTMT',
        "CursorKind.UNEXPOSED_ATTR": 'IDENTIFIER_UNEXPOSEDATTR',
        "CursorKind.NAMESPACE_REF": 'IDENTIFIER_NAMESPACEREF',
        "CursorKind.RETURN_STMT": 'IDENTIFIER_RETURNSTMT',
        "CursorKind.PAREN_EXPR": 'IDENTIFIER_PARENEXPR',
    }

    @staticmethod
    def _literal_placeholder(ckind: str, spelling: str) -> str:
        """Return the ``LITERAL_*`` placeholder for one literal token.

        Args:
            ckind: ``str()`` of the kind of the cursor the token belongs to.
            spelling: the token text exactly as clang reports it.

        Returns:
            The placeholder string; ``'LITERAL_UNK'`` when the cursor kind is
            not one of the handled kinds.
        """
        if ckind == "CursorKind.INTEGER_LITERAL":
            # The hex check is case-sensitive; the suffix checks are not.
            if 'x' in spelling:
                return 'LITERAL_INT_HEX'
            lowered = spelling.lower()
            if 'ul' in lowered:
                return 'LITERAL_INT_UL'
            if 'u' in lowered:
                return 'LITERAL_INT_U'
            if 'l' in lowered:
                return 'LITERAL_INT_L'
            return 'LITERAL_INT_INT'

        if ckind == "CursorKind.FLOATING_LITERAL":
            # Only a lower-case 'e' is treated as an exponent marker.
            if 'e' in spelling:
                return 'LITERAL_FLOAT_EXP'
            if 'l' in spelling.lower():
                return 'LITERAL_FLOAT_LF'
            return 'LITERAL_FLOAT_FLOAT'

        if ckind == "CursorKind.STRING_LITERAL":
            if spelling == "\"\"":
                return 'LITERAL_STRING_EMPTY'
            # spelling[1:-1] drops the first and last character (the quotes).
            if spelling[1:-1].isnumeric():
                return 'LITERAL_STRING_NUMERIC'
            for pattern_name, placeholder in MyTokenizer._STRING_LITERAL_RULES:
                # Patterns are looked up lazily, one at a time, in rule order.
                if re.match(getattr(regex_def, pattern_name), spelling) is not None:
                    return placeholder
            return 'LITERAL_STRING_UNK'

        if ckind == "CursorKind.CHARACTER_LITERAL":
            inner = spelling[1:-1]
            if 'x' in inner:
                return 'LITERAL_CHAR_HEX'
            if inner.isalpha():
                return 'LITERAL_CHAR_CHAR'
            if inner.isdigit():
                return 'LITERAL_CHAR_INT'
            return 'LITERAL_CHAR_UNK'

        if ckind == "CursorKind.ASM_STMT":
            return 'LITERAL_ASM'

        if ckind == "CursorKind.LABEL_STMT":
            return 'LITERAL_LABELSTMT'

        if ckind == "CursorKind.VAR_DECL":
            if '0x' in spelling:
                return 'LITERAL_VARDECL_HEX'
            if spelling.isnumeric():
                return 'LITERAL_VARDECL_NUM'
            return 'LITERAL_VARDECL_STRING'

        if ckind == "CursorKind.DECL_STMT":
            if '0x' in spelling:
                return 'LITERAL_DECLSTMT_HEX'
            if spelling.isnumeric():
                return 'LITERAL_DECLSTMT_NUM'
            if spelling.replace('.', '').isnumeric():
                return 'LITERAL_DECLSTMT_FLOAT'
            return 'LITERAL_DECLSTMT_STRING'

        return 'LITERAL_UNK'

    def clang_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        """Split one piece of text into clang tokens with placeholders applied.

        Args:
            i: index supplied by ``PreTokenizedString.split``; not used here.
            normalized_string: the text to lex; it is parsed as an in-memory
                file named ``tmp.cpp``.

        Returns:
            One ``NormalizedString`` per clang token, in source order.
        """
        ## Tokenize using clang
        tu = self.cidx.parse('tmp.cpp',
                             args=[''],
                             unsaved_files=[('tmp.cpp', str(normalized_string))],
                             options=0)
        tok = []
        for t in tu.get_tokens(extent=tu.cursor.extent):
            spelling = t.spelling
            tkind = str(t.kind)

            # The cursor is only looked up for the two token kinds whose
            # output depends on it; keywords and punctuation pass through.
            if tkind == "TokenKind.LITERAL":
                text = self._literal_placeholder(str(t.cursor.kind), spelling)
            elif tkind == "TokenKind.IDENTIFIER":
                text = self._IDENTIFIER_PLACEHOLDERS.get(str(t.cursor.kind), spelling)
            else:
                text = spelling

            tok.append(NormalizedString(text))

        return tok

    def pre_tokenize(self, pretok: PreTokenizedString):
        """Entry point used by ``PreTokenizer.custom``; splits ``pretok`` in place."""
        pretok.split(self.clang_split)
        
        
        
## Custom tokenizer

from tokenizers import Tokenizer
from tokenizers import normalizers
from tokenizers.normalizers import StripAccents
from tokenizers.processors import TemplateProcessing
from tokenizers import processors,pre_tokenizers
from tokenizers.models import BPE

## Init
#my_tokenizer = Tokenizer(BPE(unk_token="<unk>"))
#my_tokenizer = Tokenizer(BPE())


## Load the stored BPE vocabulary and merge table for tokenizer variant v1
# NOTE(review): `vocab` rebinds the `torchtext.vocab as vocab` alias imported
# at the top of the notebook - confirm nothing later expects the module.
vocab, merges = BPE.read_file(
    vocab="./tokenizer5/v1/draper-vocab.json",
    merges="./tokenizer5/v1/draper-merges.txt",
)
my_tokenizer = Tokenizer(
    BPE(vocab, merges, unk_token="<unk>")
)


# Reserved token strings for vocab v1, listed in id order. The position of a
# string in this tuple is its id, so `mf` below pairs each string with its index.
# NOTE(review): the last identifier placeholder is spelled 'IDENTIFIER_PARENEXP'
# while the pre-tokenizer class emits 'IDENTIFIER_PARENEXPR'. It is kept as-is
# because the stored vocabulary may have been built with this spelling - confirm.
_mf_tokens = (
    # special tokens (0-4)
    "<s>", "<pad>", "</s>", "<unk>", "<mask>",
    # integer literal placeholders (5-9)
    'LITERAL_INT_INT', 'LITERAL_INT_UL', 'LITERAL_INT_U', 'LITERAL_INT_L', 'LITERAL_INT_HEX',
    # floating literal placeholders (10-12)
    'LITERAL_FLOAT_FLOAT', 'LITERAL_FLOAT_EXP', 'LITERAL_FLOAT_LF',
    # string literal placeholders (13-33)
    'LITERAL_STRING_EMPTY', 'LITERAL_STRING_NUMERIC', 'LITERAL_STRING_FILEEXTENSION',
    'LITERAL_STRING_BOOLEAN', 'LITERAL_STRING_LOGICALOP', 'LITERAL_STRING_BITWISEOP',
    'LITERAL_STRING_COMPARISONOP', 'LITERAL_STRING_ARITHMETICOP', 'LITERAL_STRING_CONTROLFLOW',
    'LITERAL_STRING_IPV4', 'LITERAL_STRING_IPV6', 'LITERAL_STRING_MIMETYPE',
    'LITERAL_STRING_FILEPATH', 'LITERAL_STRING_LANGUAGEISO', 'LITERAL_STRING_STATUS',
    'LITERAL_STRING_MONTH', 'LITERAL_STRING_SQLSTATEMENT', 'LITERAL_STRING_TIME',
    'LITERAL_STRING_NEWLINE', 'LITERAL_STRING_BASH', 'LITERAL_STRING_UNK',
    # character literal placeholders (34-37)
    'LITERAL_CHAR_HEX', 'LITERAL_CHAR_CHAR', 'LITERAL_CHAR_INT', 'LITERAL_CHAR_UNK',
    # asm / label literals (38-39)
    'LITERAL_ASM', 'LITERAL_LABELSTMT',
    # literals under a variable declaration (40-42)
    'LITERAL_VARDECL_HEX', 'LITERAL_VARDECL_NUM', 'LITERAL_VARDECL_STRING',
    # literals under a declaration statement (43-46)
    'LITERAL_DECLSTMT_HEX', 'LITERAL_DECLSTMT_NUM', 'LITERAL_DECLSTMT_FLOAT', 'LITERAL_DECLSTMT_STRING',
    # fallback literal (47)
    'LITERAL_UNK',
    # identifier placeholders (48-77)
    'IDENTIFIER_MEMBERREFEXPR', 'IDENTIFIER_UNARYOP', 'IDENTIFIER_CONSTRUCTOR',
    'IDENTIFIER_CLASSDECL', 'IDENTIFIER_DEFAULTSTMT', 'IDENTIFIER_CSTYLECASTEXPR',
    'IDENTIFIER_COMPOUNDASSIGNMENTOPERATOR', 'IDENTIFIER_WHILESTMT', 'IDENTIFIER_DOSTMT',
    'IDENTIFIER_DESTRUCTOR', 'IDENTIFIER_MEMBERREF', 'IDENTIFIER_ADDRLABELEXPR',
    'IDENTIFIER_PACKEDATTR', 'IDENTIFIER_ASMSTMT', 'IDENTIFIER_CALLEXPR',
    'IDENTIFIER_UNEXPOSEDDECL', 'IDENTIFIER_CXXNEWEXPR', 'IDENTIFIER_CXXMETHOD',
    'IDENTIFIER_CXXDELETEEXPR', 'IDENTIFIER_ALIGNEDATTR', 'IDENTIFIER_STRINGLITERAL',
    'IDENTIFIER_LAMBDAEXPR', 'IDENTIFIER_CXXTRYSTMT', 'IDENTIFIER_CONDITIONALOP',
    'IDENTIFIER_CXXREINTERPRETCASTEXPR', 'IDENTIFIER_SWITCHSTMT', 'IDENTIFIER_UNEXPOSEDATTR',
    'IDENTIFIER_NAMESPACEREF', 'IDENTIFIER_RETURNSTMT', 'IDENTIFIER_PARENEXP',
    # keywords (78-181)
    'char', 'int', 'switch', 'case', 'if', 'break', 'for', 'const',
    'unsigned', 'struct', 'default', 'return', 'long', 'goto', 'this', 'enum',
    'bool', 'static', 'false', 'true', 'new', 'delete', 'while', 'double',
    'else', 'private', 'do', 'sizeof', 'void', 'continue', '__attribute__', 'short',
    'throw', 'float', 'register', '__FUNCTION__', 'static_cast', '__func__', 'class', 'try',
    'dynamic_cast', 'template', 'union', 'reinterpret_cast', 'catch', 'operator', 'const_cast', 'using',
    'namespace', 'typename', 'wchar_t', 'not', 'typeof', '__label__', '__PRETTY_FUNCTION__', 'auto',
    '__extension__', 'volatile', '__asm__', '__volatile__', 'extern', 'asm', 'signed', 'typedef',
    'typeid', 'and', 'or', 'public', 'virtual', 'nullptr', '__restrict', '__asm',
    '__typeof__', 'xor', '__complex__', '__real__', '__imag__', 'not_eq', 'export', 'compl',
    '__alignof__', '__restrict__', '__cdecl', 'bitor', 'protected', 'explicit', 'friend', 'decltype',
    'mutable', 'inline', '__const', '__stdcall', 'char16_t', 'char32_t', '_Decimal64', 'constexpr',
    'bitand', 'alignof', 'static_assert', '__attribute', 'thread_local', '__alignof', '__builtin_va_arg', '_Decimal32',
    # punctuation (182-235)
    '"', '(', '*', ',', ')', '{', ';', '->', ':', '.',
    '-', '=', '+', '<', '++', '+=', '==', '||', '!=', '}',
    '/', '!', '>=', '[', ']', '&', '::', '&&', '>', '#',
    '--', '<=', '-=', '|', '%', '?', '<<', '>>', '|=', '&=',
    '^', '~', '^=', '...', '/=', '*=', '>>=', '<<=', '%=', '##',
    '->*', '\\', '.*', '@',
)

# List of (token, id) tuples.
mf = [(token, token_id) for token_id, token in enumerate(_mf_tokens)]

## Assemble the tokenizer pipeline: normalizer, clang pre-tokenizer, post-processor
my_tokenizer.normalizer = normalizers.Sequence([StripAccents()])
my_tokenizer.pre_tokenizer = PreTokenizer.custom(MyTokenizer())
# A tokenizer holds a single post-processor. The former
# `processors.ByteLevel(trim_offsets=False)` assignment was overwritten by the
# very next statement and never took effect, so it has been removed.
# The template wraps each single sequence as "<s> ... </s>".
my_tokenizer.post_processor = TemplateProcessing(
    single="<s> $A </s>",
    special_tokens=[
        ("<s>",0),
        ("<pad>",1),
        ("</s>",2),
        ("<unk>",3),
        ("<mask>",4),
    ]
)

In [None]:
#### LATEST Vocab 2

from tokenizers.pre_tokenizers import PreTokenizer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers import NormalizedString,PreTokenizedString
from typing import List

class MyTokenizer:
    """Custom pre-tokenizer (vocab v2) that splits C/C++ text into clang tokens.

    Every non-empty clang token is emitted with its own spelling (spaces
    removed); no placeholder substitution is done in this variant. The
    instance is meant to be wrapped with ``PreTokenizer.custom(...)``, which
    calls :meth:`pre_tokenize`.
    """

    # Config.set_library_file raises once libclang has been loaded, which made
    # this cell fail whenever an earlier cell (or an earlier run of this one)
    # had already created an Index. Only set the path while still allowed.
    if not cindex.Config.loaded:
        cindex.Config.set_library_file('/usr/lib/llvm-10/lib/libclang-10.so.1')
    cidx = cindex.Index.create()

    def clang_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        """Split one piece of text into clang tokens.

        Args:
            i: index supplied by ``PreTokenizedString.split``; not used here.
            normalized_string: the text to lex; it is parsed as an in-memory
                file named ``tmp.cpp``.

        Returns:
            One ``NormalizedString`` per non-empty clang token, in source order.
        """
        ## Tokenize using clang
        tu = self.cidx.parse('tmp.cpp',
                             args=[''],
                             unsaved_files=[('tmp.cpp', str(normalized_string))],
                             options=0)
        tok = []
        for t in tu.get_tokens(extent=tu.cursor.extent):
            # The token/cursor kinds are not needed in this variant, so they
            # are no longer computed for every token.
            spelling = t.spelling.strip()
            if spelling == '':
                continue

            # Drop interior spaces so each token is a single unbroken piece.
            tok.append(NormalizedString(spelling.replace(' ', '')))

        return tok

    def pre_tokenize(self, pretok: PreTokenizedString):
        """Entry point used by ``PreTokenizer.custom``; splits ``pretok`` in place."""
        pretok.split(self.clang_split)
        
## Custom tokenizer

from tokenizers import Tokenizer
from tokenizers import normalizers,decoders
from tokenizers.normalizers import StripAccents, unicode_normalizer_from_str
from tokenizers.processors import TemplateProcessing
from tokenizers import processors,pre_tokenizers
from tokenizers.models import BPE

## Init
#my_tokenizer = Tokenizer(BPE(unk_token="<unk>"))
#my_tokenizer = Tokenizer(BPE())


## Load the stored BPE vocabulary and merge table for tokenizer variant v2
# NOTE(review): `vocab` rebinds the `torchtext.vocab as vocab` alias imported
# at the top of the notebook - confirm nothing later expects the module.
vocab, merges = BPE.read_file(
    vocab="./tokenizer5/v2/draper-vocab.json",
    merges="./tokenizer5/v2/draper-merges.txt",
)
my_tokenizer = Tokenizer(
    BPE(vocab, merges, unk_token="<unk>")
)

# Reserved token strings for vocab v2, listed in id order. The position of a
# string in this tuple is its id, so `mf` below pairs each string with its index.
_mf_tokens = (
    # special tokens (0-4)
    "<s>", "<pad>", "</s>", "<unk>", "<mask>",
    # keywords (5-108)
    'char', 'int', 'switch', 'case', 'if', 'break', 'for', 'const',
    'unsigned', 'struct', 'default', 'return', 'long', 'goto', 'this', 'enum',
    'bool', 'static', 'false', 'true', 'new', 'delete', 'while', 'double',
    'else', 'private', 'do', 'sizeof', 'void', 'continue', '__attribute__', 'short',
    'throw', 'float', 'register', '__FUNCTION__', 'static_cast', '__func__', 'class', 'try',
    'dynamic_cast', 'template', 'union', 'reinterpret_cast', 'catch', 'operator', 'const_cast', 'using',
    'namespace', 'typename', 'wchar_t', 'not', 'typeof', '__label__', '__PRETTY_FUNCTION__', 'auto',
    '__extension__', 'volatile', '__asm__', '__volatile__', 'extern', 'asm', 'signed', 'typedef',
    'typeid', 'and', 'or', 'public', 'virtual', 'nullptr', '__restrict', '__asm',
    '__typeof__', 'xor', '__complex__', '__real__', '__imag__', 'not_eq', 'export', 'compl',
    '__alignof__', '__restrict__', '__cdecl', 'bitor', 'protected', 'explicit', 'friend', 'decltype',
    'mutable', 'inline', '__const', '__stdcall', 'char16_t', 'char32_t', '_Decimal64', 'constexpr',
    'bitand', 'alignof', 'static_assert', '__attribute', 'thread_local', '__alignof', '__builtin_va_arg', '_Decimal32',
    # punctuation (109-162)
    '"', '(', '*', ',', ')', '{', ';', '->', ':', '.',
    '-', '=', '+', '<', '++', '+=', '==', '||', '!=', '}',
    '/', '!', '>=', '[', ']', '&', '::', '&&', '>', '#',
    '--', '<=', '-=', '|', '%', '?', '<<', '>>', '|=', '&=',
    '^', '~', '^=', '...', '/=', '*=', '>>=', '<<=', '%=', '##',
    '->*', '\\', '.*', '@',
)

# List of (token, id) tuples.
mf = [(token, token_id) for token_id, token in enumerate(_mf_tokens)]

## Assemble the tokenizer pipeline: normalizer, clang pre-tokenizer, post-processor
my_tokenizer.normalizer = normalizers.Sequence([StripAccents()])
my_tokenizer.pre_tokenizer = PreTokenizer.custom(MyTokenizer())
# A tokenizer holds a single post-processor. The former
# `processors.ByteLevel(trim_offsets=False)` assignment was overwritten by the
# very next statement and never took effect, so it has been removed.
# The template wraps each single sequence as "<s> ... </s>".
my_tokenizer.post_processor = TemplateProcessing(
    single="<s> $A </s>",
    special_tokens=[
        ("<s>",0),
        ("<pad>",1),
        ("</s>",2),
        ("<unk>",3),
        ("<mask>",4),
    ]
)



In [None]:
#### LATEST Vocab 5
from tokenizers.pre_tokenizers import PreTokenizer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers import NormalizedString,PreTokenizedString
from typing import List

class MyTokenizer:
    """Custom pre-tokenizer (vocab v5) that splits C/C++ text into clang tokens.

    Every non-empty clang token is emitted with its own spelling (spaces
    removed); keywords, punctuation and literals are all passed through as
    spelled. The instance is meant to be wrapped with
    ``PreTokenizer.custom(...)``, which calls :meth:`pre_tokenize`.
    """

    # Config.set_library_file raises once libclang has been loaded, which made
    # this cell fail whenever an earlier cell (or an earlier run of this one)
    # had already created an Index. Only set the path while still allowed.
    if not cindex.Config.loaded:
        cindex.Config.set_library_file('/usr/lib/llvm-10/lib/libclang-10.so.1')
    cidx = cindex.Index.create()

    # Names of C/C++ standard-library functions and objects. Not referenced by
    # the methods of this class.
    # The original literal was split in the middle of 'swscanf' across two
    # lines, which is a syntax error; the entries are otherwise unchanged.
    # NOTE(review): 'sca', 'wcslne' and 'vsfscanf' look like typos (perhaps
    # 'scanf', 'wcslen', 'vfscanf') - left as-is pending confirmation.
    std_api_calls = {
        '_Exit',
        'abs','acos','acosh','asctime','asin','asinh','assert','at_quick_exit','atan','atan2','atanh','atexit','atof','atol',
        'bsearch','btowc',
        'c16rtomb','c32rtomb','cbrt','ceil','cerr','cin','clearerr','clock','clog','copysign','cos','cosh','cout','ctime',
        'difftime','div',
        'errno','exp','exp2','expm1',
        'fabs','fclose','fdim','feclearexcept','fegetenv','fegetexceptflag','fegetround','feholdexcept','feof','feraiseexcept',
        'ferror','fesetenv','fesetexceptflag','fesetround','fetestexcept','feupdateenv','fflush','fgetc','fgetpos','fgets',
        'fgetwc','fgetws','floor','fma','fmax','fmod','fopen','fprintf','fputc','fputs','fputwc','fputws','fread','free',
        'freopen','frexp','fscanf','fseek','fsetpos','ftell','fwide','fwprintf','fwrite','fwscanf',
        'getc','getchar','getenv','gets','getwc','getwchar','gmtime',
        'hypot',
        'ilogb','imaxabs','imaxdiv','isblank','iscntrl','isdigit','isgraph','islower','isprint','ispunct','isspace','isupper',
        'iswalnum','iswalpha','iswblank','iswcntrl','iswctype','iswdigit','iswgraph','iswlower','iswprint','iswpunct',
        'iswspace','iswupper','iswxdigit','isxdigit',
        'labs','ldexp','ldiv','llabs','lldiv','llrint','llround','localeconv','localtime','log','log10','log1p','log2','logb',
        'longjmp','lrint','lround',
        'malloc','mblen','mbrlen','mbrtoc16','mbrtoc32','mbrtowc','mbsinit','mbsrtowcs','mbstowcs','mbtowc','memchr','memcmp',
        'memcpy','memmove','memset','mktime','modf',
        'nan','nearbyint','nextafter','nexttoward',
        'perror','pow','printf','putc','putchar','puts','putwchar',
        'qsort','quick_exit',
        'raise','realloc','remainder','remove','remquo','rename','rewind','rint','round',
        'sca','scalbln','scalbn','setbuf','setjmp','setlocale','setvbuf','signal','sin','sinh','snprintf','sprintf','sqrt',
        'srand','sscanf','strcat','strchr','strcmp','strcoll','strcpy','strcspn','strerror','strftime','strlen','strncat',
        'strncmp','strncpy','strpbrk','strrchr','strspn','strstr','strtod','strtoimax','strtok','strtol','strtoll','strtoull',
        'strtoumax','strxfrm','swprintf','swscanf',
        'tan','tanh','time','tmpfile','tmpnam','tolower','toupper','towctrans','towlower','towupper','trunc',
        'ungetc','ungetwc',
        'vfprintf','vfscanf','vfwprintf','vfwscanf','vprintf','vscanf','vsfscanf','vsnprintf','vsprintf','vsscanf','vswprintf',
        'vwprintf','vwscanf',
        'wcerr','wcin','wclog','wcout','wcrtomb','wcscat','wcschr','wcscmp','wcscoll','wcscpy','wcscspn','wcsftime','wcslne',
        'wcsncat','wcsncmp','wcsncpy','wcspbrk','wcsrchr','wcsrtombs','wcsspn','wcsstr','wcstod','wcstof','wcstoimax','wcstok',
        'wcstol','wcstold','wcstoll','wcstombs','wcstoul','wcstoull','wcstoumax','wcsxfrm','wctob','wctomb','wctrans','wctype',
        'wmemchr','wmemcmp','wmemcpy','wmemmove','wmemset','wprintf','wscanf',
    }

    def clang_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        """Split one piece of text into clang tokens.

        Args:
            i: index supplied by ``PreTokenizedString.split``; not used here.
            normalized_string: the text to lex; it is parsed as an in-memory
                file named ``tmp.cpp``.

        Returns:
            One ``NormalizedString`` per non-empty clang token, in source order.
        """
        ## Tokenize using clang
        tu = self.cidx.parse('tmp.cpp',
                             args=[''],
                             unsaved_files=[('tmp.cpp', str(normalized_string))],
                             options=0)
        tok = []
        for t in tu.get_tokens(extent=tu.cursor.extent):
            spelling = t.spelling.strip()
            if spelling == '':
                continue

            ## Keywords, punctuation and literals are all left for BPE to handle.
            # Drop interior spaces so each token is a single unbroken piece.
            tok.append(NormalizedString(spelling.replace(' ', '')))

        return tok

    def pre_tokenize(self, pretok: PreTokenizedString):
        """Entry point used by ``PreTokenizer.custom``; splits ``pretok`` in place."""
        pretok.split(self.clang_split)
        
## Custom tokenizer

from tokenizers import Tokenizer
from tokenizers import normalizers,decoders
from tokenizers.normalizers import StripAccents, unicode_normalizer_from_str
from tokenizers.processors import TemplateProcessing
from tokenizers import processors,pre_tokenizers
from tokenizers.models import BPE

## Init
#my_tokenizer = Tokenizer(BPE(unk_token="<unk>"))
#my_tokenizer = Tokenizer(BPE())


## Load
# Read the BPE vocabulary and merge rules from the tokenizer5/v5 files and
# build the tokenizer around them, mapping unknown pieces to "<unk>".
# NOTE(review): this rebinds the module-level name `vocab`, which the import
# cell also uses as the alias for `torchtext.vocab` -- confirm nothing later
# relies on that alias.
vocab, merges = BPE.read_file(vocab="./tokenizer5/v5/draper-vocab.json", merges="./tokenizer5/v5/draper-merges.txt")
my_tokenizer = Tokenizer(BPE(vocab, merges, unk_token="<unk>"))


mf = [("<s>",0),("<pad>",1),("</s>",2),("<unk>",3),("<mask>",4)]
mf = mf + [('char',5),('int',6),('switch',7),('case',8),('if',9),('break',10),('for',11),('const',12),('unsigned',13),('struct',14),('default',15),('return',16),('long',17),('goto',18),('this',19),('enum',20),('bool',21),('static',22),('false',23),('true',24),('new',25),('delete',26),('while',27),('double',28),('else',29),('private',30),('do',31),('sizeof',32),('void',33),('continue',34),('__attribute__',35),('short',36),('throw',37),('float',38),('register',39),('__FUNCTION__',40),('static_cast',41),('__func__',42),('class',43),('try',44),('dynamic_cast',45),('template',46),('union',47),('reinterpret_cast',48),('catch',49),('operator',50),('const_cast',51),('using',52),('namespace',53),('typename',54),('wchar_t',55),('not',56),('typeof',57),('__label__',58),('__PRETTY_FUNCTION__',59),('auto',60),('__extension__',61),('volatile',62),('__asm__',63),('__volatile__',64),('extern',65),('asm',66),('signed',67),('typedef',68),('typeid',69),('and',70),('or',71),('public',72),('virtual',73),('nullptr',74),('__restrict',75),('__asm',76),('__typeof__',77),('xor',78),('__complex__',79),('__real__',80),('__imag__',81),('not_eq',82),('export',83),('compl',84),('__alignof__',85),('__restrict__',86),('__cdecl',87),('bitor',88),('protected',89),('explicit',90),('friend',91),('decltype',92),('mutable',93),('inline',94),('__const',95),('__stdcall',96),('char16_t',97),('char32_t',98),('_Decimal64',99),('constexpr',100),('bitand',101),('alignof',102),('static_assert',103),('__attribute',104),('thread_local',105),('__alignof',106),('__builtin_va_arg',107),('_Decimal32',108)]
mf = mf + [('\"',109),('(',110),('*',111),(',',112),(')',113),('{',114),(';',115),('->',116),(':',117),('.',118),('-',119),('=',120),('+',121),('<',122),('++',123),('+=',124),('==',125),('||',126),('!=',127),('}',128),('/',129),('!',130),('>=',131),('[',132),(']',133),('&',134),('::',135),('&&',136),('>',137),('#',138),('--',139),('<=',140),('-=',141),('|',142),('%',143),('?',144),('<<',145),('>>',146),('|=',147),('&=',148),('^',149),('~',150),('^=',151),('...',152),('/=',153),('*=',154),('>>=',155),('<<=',156),('%=',157),('##',158),('->*',159),('\\',160),('.*',161),('@',162)]
mf = mf + [('_Exit',163),('abs',164),('acos',165),('acosh',166),('asctime',167),('asin',168),('asinh',169),('assert',170),('at_quick_exit',171),('atan',172),('atan2',173),('atanh',174),('atexit',175),('atof',176),('atol',177),('bsearch',178),('btowc',179),('c16rtomb',180),('c32rtomb',181),('cbrt',182),('ceil',183),('cerr',184),('cin',185),('clearerr',186),('clock',187),('clog',188),('copysign',189),('cos',190),('cosh',191),('cout',192),('ctime',193),('difftime',194),('div',195),('errno',196),('exp',197),('exp2',198),('expm1',199),('fabs',200),('fclose',201),('fdim',202),('feclearexcept',203),('fegetenv',204),('fegetexceptflag',205),('fegetround',206),('feholdexcept',207),('feof',208),('feraiseexcept',209),('ferror',210),('fesetenv',211),('fesetexceptflag',212),('fesetround',213),('fetestexcept',214),('feupdateenv',215),('fflush',216),('fgetc',217),('fgetpos',218),('fgets',219),('fgetwc',220),('fgetws',221),('floor',222),('fma',223),('fmax',224),('fmod',225),('fopen',226),('fprintf',227),('fputc',228),('fputs',229),('fputwc',230),('fputws',231),('fread',232),('free',233),('freopen',234),('frexp',235),('fscanf',236),('fseek',237),('fsetpos',238),('ftell',239),('fwide',240),('fwprintf',241),('fwrite',242),('fwscanf',243),('getc',244),('getchar',245),('getenv',246),('gets',247),('getwc',248),('getwchar',249),('gmtime',250),('hypot',251),('ilogb',252),('imaxabs',253),('imaxdiv',254),('isblank',255),('iscntrl',256),('isdigit',257),('isgraph',258),('islower',259),('isprint',260),('ispunct',261),('isspace',262),('isupper',263),('iswalnum',264),('iswalpha',265),('iswblank',266),('iswcntrl',267),('iswctype',268),('iswdigit',269),('iswgraph',270),('iswlower',271),('iswprint',272),('iswpunct',273),('iswspace',274),('iswupper',275),('iswxdigit',276),('isxdigit',277),('labs',278),('ldexp',279),('ldiv',280),('llabs',281),('lldiv',282),('llrint',283),('llround',284),('localeconv',285),('localtime',286),('log',287),('log10',288),('log1p',289),('log2',290),('logb',291),('longjmp',292
),('lrint',293),('lround',294),('malloc',295),('mblen',296),('mbrlen',297),('mbrtoc16',298),('mbrtoc32',299),('mbrtowc',300),('mbsinit',301),('mbsrtowcs',302),('mbstowcs',303),('mbtowc',304),('memchr',305),('memcmp',306),('memcpy',307),('memmove',308),('memset',309),('mktime',310),('modf',311),('nan',312),('nearbyint',313),('nextafter',314),('nexttoward',315),('perror',316),('pow',317),('printf',318),('putc',319),('putchar',320),('puts',321),('putwchar',322),('qsort',323),('quick_exit',324),('raise',325),('realloc',326),('remainder',327),('remove',328),('remquo',329),('rename',330),('rewind',331),('rint',332),('round',333),('sca',334),('scalbln',335),('scalbn',336),('setbuf',337),('setjmp',338),('setlocale',339),('setvbuf',340),('signal',341),('sin',342),('sinh',343),('snprintf',344),('sprintf',345),('sqrt',346),('srand',347),('sscanf',348),('strcat',349),('strchr',350),('strcmp',351),('strcoll',352),('strcpy',353),('strcspn',354),('strerror',355),('strftime',356),('strlen',357),('strncat',358),('strncmp',359),('strncpy',360),('strpbrk',361),('strrchr',362),('strspn',363),('strstr',364),('strtod',365),('strtoimax',366),('strtok',367),('strtol',368),('strtoll',369),('strtoull',370),('strtoumax',371),('strxfrm',372),('swprintf',373),('swscanf',374),('tan',375),('tanh',376),('time',377),('tmpfile',378),('tmpnam',379),('tolower',380),('toupper',381),('towctrans',382),('towlower',383),('towupper',384),('trunc',385),('ungetc',386),('ungetwc',387),('vfprintf',388),('vfscanf',389),('vfwprintf',390),('vfwscanf',391),('vprintf',392),('vscanf',393),('vsfscanf',394),('vsnprintf',395),('vsprintf',396),('vsscanf',397),('vswprintf',398),('vwprintf',399),('vwscanf',400),('wcerr',401),('wcin',402),('wclog',403),('wcout',404),('wcrtomb',405),('wcscat',406),('wcschr',407),('wcscmp',408),('wcscoll',409),('wcscpy',410),('wcscspn',411),('wcsftime',412),('wcslne',413),('wcsncat',414),('wcsncmp',415),('wcsncpy',416),('wcspbrk',417),('wcsrchr',418),('wcsrtombs',419),('wcsspn',420),('wcsstr'
,421),('wcstod',422),('wcstof',423),('wcstoimax',424),('wcstok',425),('wcstol',426),('wcstold',427),('wcstoll',428),('wcstombs',429),('wcstoul',430),('wcstoull',431),('wcstoumax',432),('wcsxfrm',433),('wctob',434),('wctomb',435),('wctrans',436),('wctype',437),('wmemchr',438),('wmemcmp',439),('wmemcpy',440),('wmemmove',441),('wmemset',442),('wprintf',443),('wscanf',444)]

# Normalization: strip accents only.
my_tokenizer.normalizer = normalizers.Sequence([StripAccents()])

# Pre-tokenization: split the text with the clang-based MyTokenizer.
my_tokenizer.pre_tokenizer = PreTokenizer.custom(MyTokenizer())

# Post-processing: wrap every encoded sequence as "<s> ... </s>".
# A ByteLevel post-processor used to be assigned immediately before this
# statement; it was overwritten here before it could ever be used, so that
# dead assignment has been removed.
my_tokenizer.post_processor = TemplateProcessing(
    single="<s> $A </s>",
    special_tokens=[
        ("<s>", 0),
        ("<pad>", 1),
        ("</s>", 2),
        ("<unk>", 3),
        ("<mask>", 4),
    ],
)


In [None]:
#### LATEST Vocab 5
from tokenizers.pre_tokenizers import PreTokenizer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers import NormalizedString,PreTokenizedString
from typing import List

class MyTokenizer:
    cidx = cindex.Index.create()
    
    std_api_calls = set(['_Exit','abs','acos','acosh','asctime','asin','asinh','assert','at_quick_exit','atan','atan2','atanh','atexit','atof','atol','bsearch','btowc','c16rtomb','c32rtomb','cbrt','ceil','cerr','cin','clearerr','clock','clog','copysign','cos','cosh','cout','ctime','difftime','div','errno','exp','exp2','expm1','fabs','fclose','fdim','feclearexcept','fegetenv','fegetexceptflag','fegetround','feholdexcept','feof','feraiseexcept','ferror','fesetenv','fesetexceptflag','fesetround','fetestexcept','feupdateenv','fflush','fgetc','fgetpos','fgets','fgetwc','fgetws','floor','fma','fmax','fmod','fopen','fprintf','fputc','fputs','fputwc','fputws','fread','free','freopen','frexp','fscanf','fseek','fsetpos','ftell','fwide','fwprintf','fwrite','fwscanf','getc','getchar','getenv','gets','getwc','getwchar','gmtime','hypot','ilogb','imaxabs','imaxdiv','isblank','iscntrl','isdigit','isgraph','islower','isprint','ispunct','isspace','isupper','iswalnum','iswalpha','iswblank','iswcntrl','iswctype','iswdigit','iswgraph','iswlower','iswprint','iswpunct','iswspace','iswupper','iswxdigit','isxdigit','labs','ldexp','ldiv','llabs','lldiv','llrint','llround','localeconv','localtime','log','log10','log1p','log2','logb','longjmp','lrint','lround','malloc','mblen','mbrlen','mbrtoc16','mbrtoc32','mbrtowc','mbsinit','mbsrtowcs','mbstowcs','mbtowc','memchr','memcmp','memcpy','memmove','memset','mktime','modf','nan','nearbyint','nextafter','nexttoward','perror','pow','printf','putc','putchar','puts','putwchar','qsort','quick_exit','raise','realloc','remainder','remove','remquo','rename','rewind','rint','round','sca','scalbln','scalbn','setbuf','setjmp','setlocale','setvbuf','signal','sin','sinh','snprintf','sprintf','sqrt','srand','sscanf','strcat','strchr','strcmp','strcoll','strcpy','strcspn','strerror','strftime','strlen','strncat','strncmp','strncpy','strpbrk','strrchr','strspn','strstr','strtod','strtoimax','strtok','strtol','strtoll','strtoull','strtoumax','strxfrm','swprintf','
swscanf','tan','tanh','time','tmpfile','tmpnam','tolower','toupper','towctrans','towlower','towupper','trunc','ungetc','ungetwc','vfprintf','vfscanf','vfwprintf','vfwscanf','vprintf','vscanf','vsfscanf','vsnprintf','vsprintf','vsscanf','vswprintf','vwprintf','vwscanf','wcerr','wcin','wclog','wcout','wcrtomb','wcscat','wcschr','wcscmp','wcscoll','wcscpy','wcscspn','wcsftime','wcslne','wcsncat','wcsncmp','wcsncpy','wcspbrk','wcsrchr','wcsrtombs','wcsspn','wcsstr','wcstod','wcstof','wcstoimax','wcstok','wcstol','wcstold','wcstoll','wcstombs','wcstoul','wcstoull','wcstoumax','wcsxfrm','wctob','wctomb','wctrans','wctype','wmemchr','wmemcmp','wmemcpy','wmemmove','wmemset','wprintf','wscanf'])
    

    def clang_split(self, i: int, normalized_string: NormalizedString) -> List[NormalizedString]:
        """Split a piece of text into clang lexer tokens.

        The text is handed to libclang as an in-memory file named ``tmp.c``
        and each token spelling reported by the lexer becomes its own
        ``NormalizedString``.

        Args:
            i: First argument supplied by the split callback; unused here.
            normalized_string: The text to tokenize.

        Returns:
            One ``NormalizedString`` per non-empty token, in lexer order, with
            surrounding whitespace stripped. Spaces inside a token are kept
            (the space-removal step is disabled in this variant).
        """
        translation_unit = self.cidx.parse(
            'tmp.c',
            args=[''],
            unsaved_files=[('tmp.c', str(normalized_string))],
            options=0,
        )

        # Keywords, punctuation and literals are all passed through unchanged
        # and left for the BPE model to handle.
        pieces = []
        for token in translation_unit.get_tokens(extent=translation_unit.cursor.extent):
            text = token.spelling.strip()
            if text:
                pieces.append(NormalizedString(text))
        return pieces
    
    def pre_tokenize(self, pretok: PreTokenizedString):
        """Pre-tokenization hook: split ``pretok`` using ``clang_split`` as the callback."""
        splitter = self.clang_split
        pretok.split(splitter)
        
## Custom tokenizer

from tokenizers import Tokenizer
from tokenizers import normalizers,decoders
from tokenizers.normalizers import StripAccents, unicode_normalizer_from_str, Replace
from tokenizers.processors import TemplateProcessing
from tokenizers import processors,pre_tokenizers
from tokenizers.models import BPE

## Init
#my_tokenizer = Tokenizer(BPE(unk_token="<unk>"))
#my_tokenizer = Tokenizer(BPE())


## Load
# Read the BPE vocabulary and merge rules from the tokenizer5/v5_drapgh files
# and build the tokenizer around them, mapping unknown pieces to "<unk>".
# NOTE(review): this rebinds the module-level name `vocab`, which the import
# cell also uses as the alias for `torchtext.vocab` -- confirm nothing later
# relies on that alias.
vocab, merges = BPE.read_file(vocab="./tokenizer5/v5_drapgh/drapgh-vocab.json", merges="./tokenizer5/v5_drapgh/drapgh-merges.txt")
my_tokenizer = Tokenizer(BPE(vocab, merges, unk_token="<unk>"))


mf = [("<s>",0),("<pad>",1),("</s>",2),("<unk>",3),("<mask>",4)]
mf = mf + [('char',5),('int',6),('switch',7),('case',8),('if',9),('break',10),('for',11),('const',12),('unsigned',13),('struct',14),('default',15),('return',16),('long',17),('goto',18),('this',19),('enum',20),('bool',21),('static',22),('false',23),('true',24),('new',25),('delete',26),('while',27),('double',28),('else',29),('private',30),('do',31),('sizeof',32),('void',33),('continue',34),('__attribute__',35),('short',36),('throw',37),('float',38),('register',39),('__FUNCTION__',40),('static_cast',41),('__func__',42),('class',43),('try',44),('dynamic_cast',45),('template',46),('union',47),('reinterpret_cast',48),('catch',49),('operator',50),('const_cast',51),('using',52),('namespace',53),('typename',54),('wchar_t',55),('not',56),('typeof',57),('__label__',58),('__PRETTY_FUNCTION__',59),('auto',60),('__extension__',61),('volatile',62),('__asm__',63),('__volatile__',64),('extern',65),('asm',66),('signed',67),('typedef',68),('typeid',69),('and',70),('or',71),('public',72),('virtual',73),('nullptr',74),('__restrict',75),('__asm',76),('__typeof__',77),('xor',78),('__complex__',79),('__real__',80),('__imag__',81),('not_eq',82),('export',83),('compl',84),('__alignof__',85),('__restrict__',86),('__cdecl',87),('bitor',88),('protected',89),('explicit',90),('friend',91),('decltype',92),('mutable',93),('inline',94),('__const',95),('__stdcall',96),('char16_t',97),('char32_t',98),('_Decimal64',99),('constexpr',100),('bitand',101),('alignof',102),('static_assert',103),('__attribute',104),('thread_local',105),('__alignof',106),('__builtin_va_arg',107),('_Decimal32',108)]
mf = mf + [('\"',109),('(',110),('*',111),(',',112),(')',113),('{',114),(';',115),('->',116),(':',117),('.',118),('-',119),('=',120),('+',121),('<',122),('++',123),('+=',124),('==',125),('||',126),('!=',127),('}',128),('/',129),('!',130),('>=',131),('[',132),(']',133),('&',134),('::',135),('&&',136),('>',137),('#',138),('--',139),('<=',140),('-=',141),('|',142),('%',143),('?',144),('<<',145),('>>',146),('|=',147),('&=',148),('^',149),('~',150),('^=',151),('...',152),('/=',153),('*=',154),('>>=',155),('<<=',156),('%=',157),('##',158),('->*',159),('\\',160),('.*',161),('@',162)]
mf = mf + [('_Exit',163),('abs',164),('acos',165),('acosh',166),('asctime',167),('asin',168),('asinh',169),('assert',170),('at_quick_exit',171),('atan',172),('atan2',173),('atanh',174),('atexit',175),('atof',176),('atol',177),('bsearch',178),('btowc',179),('c16rtomb',180),('c32rtomb',181),('cbrt',182),('ceil',183),('cerr',184),('cin',185),('clearerr',186),('clock',187),('clog',188),('copysign',189),('cos',190),('cosh',191),('cout',192),('ctime',193),('difftime',194),('div',195),('errno',196),('exp',197),('exp2',198),('expm1',199),('fabs',200),('fclose',201),('fdim',202),('feclearexcept',203),('fegetenv',204),('fegetexceptflag',205),('fegetround',206),('feholdexcept',207),('feof',208),('feraiseexcept',209),('ferror',210),('fesetenv',211),('fesetexceptflag',212),('fesetround',213),('fetestexcept',214),('feupdateenv',215),('fflush',216),('fgetc',217),('fgetpos',218),('fgets',219),('fgetwc',220),('fgetws',221),('floor',222),('fma',223),('fmax',224),('fmod',225),('fopen',226),('fprintf',227),('fputc',228),('fputs',229),('fputwc',230),('fputws',231),('fread',232),('free',233),('freopen',234),('frexp',235),('fscanf',236),('fseek',237),('fsetpos',238),('ftell',239),('fwide',240),('fwprintf',241),('fwrite',242),('fwscanf',243),('getc',244),('getchar',245),('getenv',246),('gets',247),('getwc',248),('getwchar',249),('gmtime',250),('hypot',251),('ilogb',252),('imaxabs',253),('imaxdiv',254),('isblank',255),('iscntrl',256),('isdigit',257),('isgraph',258),('islower',259),('isprint',260),('ispunct',261),('isspace',262),('isupper',263),('iswalnum',264),('iswalpha',265),('iswblank',266),('iswcntrl',267),('iswctype',268),('iswdigit',269),('iswgraph',270),('iswlower',271),('iswprint',272),('iswpunct',273),('iswspace',274),('iswupper',275),('iswxdigit',276),('isxdigit',277),('labs',278),('ldexp',279),('ldiv',280),('llabs',281),('lldiv',282),('llrint',283),('llround',284),('localeconv',285),('localtime',286),('log',287),('log10',288),('log1p',289),('log2',290),('logb',291),('longjmp',292
),('lrint',293),('lround',294),('malloc',295),('mblen',296),('mbrlen',297),('mbrtoc16',298),('mbrtoc32',299),('mbrtowc',300),('mbsinit',301),('mbsrtowcs',302),('mbstowcs',303),('mbtowc',304),('memchr',305),('memcmp',306),('memcpy',307),('memmove',308),('memset',309),('mktime',310),('modf',311),('nan',312),('nearbyint',313),('nextafter',314),('nexttoward',315),('perror',316),('pow',317),('printf',318),('putc',319),('putchar',320),('puts',321),('putwchar',322),('qsort',323),('quick_exit',324),('raise',325),('realloc',326),('remainder',327),('remove',328),('remquo',329),('rename',330),('rewind',331),('rint',332),('round',333),('sca',334),('scalbln',335),('scalbn',336),('setbuf',337),('setjmp',338),('setlocale',339),('setvbuf',340),('signal',341),('sin',342),('sinh',343),('snprintf',344),('sprintf',345),('sqrt',346),('srand',347),('sscanf',348),('strcat',349),('strchr',350),('strcmp',351),('strcoll',352),('strcpy',353),('strcspn',354),('strerror',355),('strftime',356),('strlen',357),('strncat',358),('strncmp',359),('strncpy',360),('strpbrk',361),('strrchr',362),('strspn',363),('strstr',364),('strtod',365),('strtoimax',366),('strtok',367),('strtol',368),('strtoll',369),('strtoull',370),('strtoumax',371),('strxfrm',372),('swprintf',373),('swscanf',374),('tan',375),('tanh',376),('time',377),('tmpfile',378),('tmpnam',379),('tolower',380),('toupper',381),('towctrans',382),('towlower',383),('towupper',384),('trunc',385),('ungetc',386),('ungetwc',387),('vfprintf',388),('vfscanf',389),('vfwprintf',390),('vfwscanf',391),('vprintf',392),('vscanf',393),('vsfscanf',394),('vsnprintf',395),('vsprintf',396),('vsscanf',397),('vswprintf',398),('vwprintf',399),('vwscanf',400),('wcerr',401),('wcin',402),('wclog',403),('wcout',404),('wcrtomb',405),('wcscat',406),('wcschr',407),('wcscmp',408),('wcscoll',409),('wcscpy',410),('wcscspn',411),('wcsftime',412),('wcslne',413),('wcsncat',414),('wcsncmp',415),('wcsncpy',416),('wcspbrk',417),('wcsrchr',418),('wcsrtombs',419),('wcsspn',420),('wcsstr'
,421),('wcstod',422),('wcstof',423),('wcstoimax',424),('wcstok',425),('wcstol',426),('wcstold',427),('wcstoll',428),('wcstombs',429),('wcstoul',430),('wcstoull',431),('wcstoumax',432),('wcsxfrm',433),('wctob',434),('wctomb',435),('wctrans',436),('wctype',437),('wmemchr',438),('wmemcmp',439),('wmemcpy',440),('wmemmove',441),('wmemset',442),('wprintf',443),('wscanf',444)]

# Normalization: strip accents, then replace every space with "Ä"
# (a later cell maps "Ä" back to a space after decoding).
my_tokenizer.normalizer = normalizers.Sequence([StripAccents(), Replace(" ", "Ä")])

# Pre-tokenization: split the text with the clang-based MyTokenizer.
my_tokenizer.pre_tokenizer = PreTokenizer.custom(MyTokenizer())

# Post-processing: wrap every encoded sequence as "<s> ... </s>".
# A ByteLevel post-processor used to be assigned immediately before this
# statement; it was overwritten here before it could ever be used, so that
# dead assignment has been removed.
my_tokenizer.post_processor = TemplateProcessing(
    single="<s> $A </s>",
    special_tokens=[
        ("<s>", 0),
        ("<pad>", 1),
        ("</s>", 2),
        ("<unk>", 3),
        ("<mask>", 4),
    ],
)


In [None]:
k=my_tokenizer.encode_batch(["""static void v9fs_open(void *opaque) int32_t fid ; V9fsQID qid ; size_t offset = 7 ; struct stat stbuf ; V9fsFidState * fidp ; V9fsPDU * pdu = opaque ; V9fsState * s = pdu -> s ; if ( s -> proto_version == V9FS_PROTO_2000L )  err = pdu_unmarshal ( pdu , offset , "dd" , & fid , & mode ); err = pdu_unmarshal ( pdu , offset , "db" , & fid , & modebyte ); if ( err < 0 )  fidp = get_fid ( pdu , fid ); static V9fsFidState *get_fid(V9fsPDU *pdu, int32_t fid) int err ; V9fsFidState * f ; V9fsState * s = pdu -> s ; for (f = s->fid_list; f; f = f->next) if ( f -> fid == fid )  f -> ref ++; err = v9fs_reopen_fid ( pdu , f ); if ( err < 0 )  return NULL ; return f ; return NULL ; if ( fidp == NULL )  err = v9fs_co_lstat ( pdu , & fidp -> path , & stbuf ); if ( err < 0 )  stat_to_qid ( & stbuf , & qid ); static void stat_to_qid(const struct stat *stbuf, V9fsQID *qidp) memset ( & qidp -> path , 0 , sizeof ( qidp -> path ) ); size = MIN ( sizeof ( stbuf -> st_ino ) , sizeof ( qidp -> path ) ); memcpy ( & qidp -> path , & stbuf -> st_ino , size ); qidp -> version = stbuf -> st_mtime ^ ( stbuf -> st_size << 8 ); qidp -> type = 0; qidp -> type |= P9_QID_TYPE_DIR; qidp -> type |= P9_QID_TYPE_SYMLINK; """])
k

In [None]:
my_tokenizer.decode(k[0].ids)

In [None]:
my_tokenizer.decode(k[0].ids).replace('Ä',' ')

### PREPARE DATA

In [None]:
TEST_ONLY = True

In [None]:
mydataset = 'd2a'

In [None]:
my_tokenizer.enable_truncation(max_length=1024)
my_tokenizer.enable_padding(direction='right', pad_id=1, pad_type_id=0, pad_token='<pad>', length=None, pad_to_multiple_of=None)

In [None]:
def cleaner(code):
    """Strip C-style comments, newlines and tabs from a source string.

    Args:
        code: Source text.

    Returns:
        The text with ``/* ... */`` and ``// ...`` comments removed and every
        newline and tab character deleted (nothing is inserted in their place).
    """
    comment_re = re.compile(r'(/\*([^*]|(\*+[^*/]))*\*+/)|(//.*)')
    without_comments = comment_re.sub('', code)
    # Newlines first, then tabs; both are dropped outright.
    return without_comments.replace('\n', '').replace('\t', '')

In [None]:
def process_encodings(encodings):
    """Collect ids and attention masks from a sequence of encodings.

    Args:
        encodings: Iterable of objects exposing ``ids`` and ``attention_mask``.

    Returns:
        ``{'input_ids': [...], 'attention_mask': [...]}`` with one entry per
        encoding, in input order.
    """
    # Single pass over the input so one-shot iterables are handled as well.
    pairs = [(enc.ids, enc.attention_mask) for enc in encodings]
    return {
        'input_ids': [ids for ids, _ in pairs],
        'attention_mask': [mask for _, mask in pairs],
    }

In [None]:
def replace_colname(x):
    """Rename dataset-specific columns to the names used by this notebook.

    ``functionSource`` and ``code`` become ``func``; ``label`` becomes
    ``target``. Columns that are not present are simply left alone, because
    ``DataFrame.rename`` ignores missing labels by default -- so no exception
    handling is needed around it.

    Args:
        x: The frame to normalise. Anything that is not a ``DataFrame`` is
            returned unchanged, matching the previous best-effort behaviour.

    Returns:
        A new ``DataFrame`` with the renamed columns; the input is not modified.
    """
    if not isinstance(x, pd.DataFrame):
        return x
    return x.rename(columns={
        'functionSource': "func",
        'code': "func",
        'label': "target",
    })

def _read_index_set(path):
    """Return the set of integers listed one per line in the text file at `path`."""
    with open(path) as f:
        return {int(line.strip()) for line in f}


if mydataset == 'devign':
    # The whole Devign corpus is read from one JSON file; the split files list
    # the row positions (consumed via .iloc) that belong to each partition.
    mydata = pd.read_json('data/devign/Devign.json')
    if not TEST_ONLY:
        m1 = mydata.iloc[list(_read_index_set('data/devign/train.txt'))]
        m2 = mydata.iloc[list(_read_index_set('data/devign/valid.txt'))]
    m3 = mydata.iloc[list(_read_index_set('data/devign/test.txt'))]
    # Only the selected splits are used from here on.
    del mydata

elif mydataset == 'd2a':
    task = 'function'

    if TEST_ONLY:
        # NOTE(review): TEST_ONLY reads the *dev* CSV as m3, whereas the full
        # run below reads the *test* CSV -- confirm this is intentional.
        m3 = pd.read_csv('data/%s/DAX_D2ALBData/%s/d2a_lbv1_%s_dev.csv'%(mydataset,task,task))
        m3 = replace_colname(m3)
    else:
        m1 = pd.read_csv('data/%s/DAX_D2ALBData/%s/d2a_lbv1_%s_train.csv'%(mydataset,task,task))
        m2 = pd.read_csv('data/%s/DAX_D2ALBData/%s/d2a_lbv1_%s_dev.csv'%(mydataset,task,task))
        m3 = pd.read_csv('data/%s/DAX_D2ALBData/%s/d2a_lbv1_%s_test.csv'%(mydataset,task,task))

        m1 = replace_colname(m1)
        m2 = replace_colname(m2)
        m3 = replace_colname(m3)

else:
    # Any other dataset is read from pickles under data/<name>/.
    # replace_colname is the function defined at the top of this cell; the
    # identical re-definition that used to live in this branch was removed.
    # NOTE: unpickling can execute arbitrary code -- only load trusted files.
    if TEST_ONLY:
        m3 = pd.read_pickle('data/%s/%s_test.pkl'%(mydataset,mydataset))
        m3 = replace_colname(m3)
    else:
        m1 = pd.read_pickle('data/%s/%s_train.pkl'%(mydataset,mydataset))
        m2 = pd.read_pickle('data/%s/%s_val.pkl'%(mydataset,mydataset))
        m3 = pd.read_pickle('data/%s/%s_test.pkl'%(mydataset,mydataset))

        m1 = replace_colname(m1)
        m2 = replace_colname(m2)
        m3 = replace_colname(m3)

def _encode_split(frame):
    """Clean, tokenize and label one data split.

    The ``func`` column of `frame` is overwritten with its cleaned text (as
    the notebook did before), then tokenized in one batch.

    Args:
        frame: DataFrame with a ``func`` column and either a ``target`` column
            or a ``combine`` column holding the labels.

    Returns:
        A list of ``{'func': token_ids, 'target': label}`` dicts, one per row.

    Raises:
        KeyError: if neither ``target`` nor ``combine`` is present.
    """
    frame['func'] = frame['func'].apply(cleaner)
    # Hand the tokenizer a plain list of strings rather than a pandas Series.
    encodings = my_tokenizer.encode_batch(frame['func'].tolist())

    # Explicit column check instead of a bare `except:` around the lookup.
    if 'target' in frame.columns:
        labels = frame['target'].tolist()
    else:
        # `* 1` turns boolean labels into integers.
        labels = (frame['combine'] * 1).tolist()

    return [{'func': enc.ids, 'target': lab} for enc, lab in zip(encodings, labels)]


if not TEST_ONLY:
    train_encodings = _encode_split(m1)
    val_encodings = _encode_split(m2)
test_encodings = _encode_split(m3)


In [None]:
CODES = torchtext.data.Field(batch_first=True, fix_length=1024,use_vocab=False)
LABEL = torchtext.data.LabelField(dtype=torch.long, is_target=True,use_vocab=False)
fields = {'func': ('codes', CODES), 'target': ('label', LABEL)}

class TabularDataset_From_List(torchtext.data.Dataset):
    """torchtext dataset built from an in-memory list instead of a file."""

    def __init__(self, input_list, format, fields, skip_header=False, **kwargs):
        """Build examples from `input_list`.

        Args:
            input_list: Items to convert, one example each.
            format: ``'json'`` or ``'dict'`` (case-insensitive); selects the
                ``Example`` constructor. Any other value raises ``KeyError``.
            fields: Mapping passed to the ``Example`` constructor; its values
                are flattened into the field list given to the base class.
            skip_header: Accepted for signature compatibility; not used.
            **kwargs: Forwarded to ``torchtext.data.Dataset.__init__``.
        """
        builders = {
            'json': torchtext.data.Example.fromJSON,
            'dict': torchtext.data.Example.fromdict,
        }
        make_example = builders[format.lower()]

        examples = [make_example(item, fields) for item in input_list]

        # make_example is always one of the two builders above, so the field
        # mapping is always flattened into a list here.
        flat_fields = []
        for field in fields.values():
            if isinstance(field, list):
                flat_fields.extend(field)
            else:
                flat_fields.append(field)

        super(TabularDataset_From_List, self).__init__(examples, flat_fields, **kwargs)

    @classmethod
    def splits(cls, path=None, root='.data', train=None, validation=None,
               test=None, **kwargs):
        """Build a dataset for every split that is given.

        Returns:
            A tuple with the train, validation and test datasets, in that
            order, omitting any split passed as ``None``.
        """
        if path is None:
            path = cls.download(root)

        built = []
        for split in (train, validation, test):
            if split is not None:
                built.append(cls(split, **kwargs))
        return tuple(built)


## Import the 100K data as TabularDataset

# The test split is always built; train/validation only for a full run.
test_data = TabularDataset_From_List(test_encodings, 'dict', fields=fields)
if not TEST_ONLY:
    train_data = TabularDataset_From_List(train_encodings, 'dict', fields=fields)
    val_data = TabularDataset_From_List(val_encodings, 'dict', fields=fields)

### IF ITERABLE DATASETTEST_ONLY

In [None]:
class MyDataset(IterableDataset):
    """Streams records from a file of back-to-back pickled dicts.

    NOTE: unpickling can execute arbitrary code -- only use trusted files.
    """

    def __init__(self, filename, rcount):
        """Remember the file to stream from and the length to report.

        Args:
            filename: Path of the pickle stream.
            rcount: Value returned by ``len()``; it is not checked against the file.
        """
        super().__init__()
        self.filename = filename
        self.len_labels = rcount

    def process(self, filename):
        """Yield one dict of tensors per pickled record until end of file."""
        import pickle
        with open(filename, "rb") as stream:
            while True:
                try:
                    record = pickle.load(stream)
                except EOFError:
                    # No more records in the stream.
                    return
                yield {
                    'input_ids': torch.tensor(record['input_ids']),
                    'attention_mask': torch.tensor(record['attention_mask']),
                    'labels': torch.tensor(record['labels']),
                }

    def __len__(self):
        return self.len_labels

    def __iter__(self):
        # A fresh generator per iteration, so the file is re-read from the start.
        return self.process(self.filename)

In [None]:
train_rcount = len(pd.read_pickle('data/draper/draper_train.pkl'))
train_dataset = MyDataset('data/draper/draper_stream_train.pkl', train_rcount)

In [None]:
val_rcount = len(pd.read_pickle('data/draper/draper_val.pkl'))
val_dataset = MyDataset('data/draper/draper_stream_val.pkl', val_rcount)

In [None]:
test_rcount = len(pd.read_pickle('data/draper/draper_test.pkl'))
test_dataset = MyDataset('data/draper/draper_stream_test.pkl', test_rcount)

### END ITERABLE DATASET

In [None]:
MAX_VOCAB_SIZE = VOCAB_SIZE

# place into iterators

# Batches are produced in dataset order (no sorting, no shuffling).
if not TEST_ONLY:
    train_iterator, valid_iterator, test_iterator = torchtext.data.BucketIterator.splits(
        (train_data, val_data, test_data),
        batch_size=BATCH_SIZE,
        sort=False,
        shuffle=False,
    )
else:
    # Test-only runs evaluate one example per batch.
    test_iterator = torchtext.data.BucketIterator(
        test_data,
        batch_size=1,
        sort=False,
        shuffle=False,
    )

UNK_IDX = 3
PAD_IDX = 1

# test_iterator = torchtext.data.BucketIterator(
#     test_data, 
#     batch_size = BATCH_SIZE,
#     sort = False,
#     shuffle = False)

#from torch.utils.data import DataLoader

# train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=BATCH_SIZE)
# val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE)
# test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)



In [None]:
class myCNN(nn.Module):
    """Text CNN over frozen RoBERTa word embeddings with a 2-class output."""

    def __init__(self, EMBED_SIZE, EMBED_DIM):
        """Build the layers.

        Args:
            EMBED_SIZE: Accepted for signature compatibility; not used.
            EMBED_DIM: Channel count fed to the convolutions; it has to match
                the width of the loaded embedding matrix.
        """
        super().__init__()

        # The embedding table is taken from the checkpoint and kept frozen.
        roberta = RobertaModel.from_pretrained('./models/v5/VulBERTa_base_clangBPEcustVocab5_1024posEmbed_drapgh/checkpoint-580000/')
        self.embed = nn.Embedding.from_pretrained(
            roberta.embeddings.word_embeddings.weight,
            freeze=True,
            padding_idx=1,
        )

        # Three parallel convolutions with window sizes 3, 4 and 5.
        self.conv1 = nn.Conv1d(in_channels=EMBED_DIM, out_channels=200, kernel_size=3)
        self.conv2 = nn.Conv1d(in_channels=EMBED_DIM, out_channels=200, kernel_size=4)
        self.conv3 = nn.Conv1d(in_channels=EMBED_DIM, out_channels=200, kernel_size=5)

        self.dropout = nn.Dropout(0.5)

        self.fc1 = nn.Linear(200 * 3, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 2)

    def forward(self, x):
        """Return raw 2-class scores for a batch of token ids.

        `x` is indexed as (batch, sequence): the embedding output is permuted
        so the embedding dimension becomes the Conv1d channel axis.
        """
        embedded = self.embed(x).permute(0, 2, 1)

        # ReLU, then max over the whole remaining length for each branch.
        pooled = []
        for conv in (self.conv1, self.conv2, self.conv3):
            activation = F.relu(conv(embedded))
            pooled.append(F.max_pool1d(activation, activation.shape[2]))

        # Concatenate the three branches on the channel axis and drop the
        # length-1 pooled dimension.
        features = torch.cat(pooled, dim=1).flatten(1)
        features = self.dropout(features)

        hidden = F.relu(self.fc1(features))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

    

In [None]:
class myLSTM(nn.Module):
    """Two-layer bidirectional LSTM over frozen RoBERTa word embeddings with a 2-class output."""

    def __init__(self, EMBED_SIZE, EMBED_DIM):
        """Build the layers.

        Args:
            EMBED_SIZE: Accepted for signature compatibility; not used.
            EMBED_DIM: Input size of the LSTM; it has to match the width of
                the loaded embedding matrix.
        """
        super().__init__()

        # The embedding table is taken from the checkpoint and kept frozen.
        roberta = RobertaModel.from_pretrained('./models/v5/VulBERTa_base_clangBPEcustVocab5_1024posEmbed_drapgh/checkpoint-580000/')
        self.embed = nn.Embedding.from_pretrained(
            roberta.embeddings.word_embeddings.weight,
            freeze=True,
            padding_idx=1,
        )

        self.lstm = nn.LSTM(
            input_size=EMBED_DIM,
            hidden_size=256,
            bidirectional=True,
            num_layers=2,
        )

        self.dropout = nn.Dropout(0.5)

        self.fc1 = nn.Linear(256 * 2, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 2)

    def forward(self, x):
        """Return raw 2-class scores for a batch of token ids.

        `x` is indexed as (batch, sequence); the embedded batch is transposed
        because the LSTM is created without ``batch_first``.
        """
        sequence_first = self.embed(x).transpose(0, 1)

        _, (hidden, _) = self.lstm(sequence_first)
        # The last two entries of `hidden` are the two directions of the top layer.
        top_pair = (hidden[-2, :, :], hidden[-1, :, :])
        features = self.dropout(torch.cat(top_pair, dim=1))

        features = F.relu(self.fc1(features))
        features = F.relu(self.fc2(features))
        return self.fc3(features)


In [None]:
model = myCNN(EMBED_SIZE,EMBED_DIM)

In [None]:
model = myLSTM(EMBED_SIZE,EMBED_DIM)

In [None]:
# Overwrite the embedding row at PAD_IDX with zeros (PAD_IDX is defined outside this view).
# NOTE(review): this accesses `model.embed` directly, so it must run before the
# cell that may wrap `model` in DataParallel.
#model.embed.weight.data[UNK_IDX] = torch.zeros(EMBED_DIM)
model.embed.weight.data[PAD_IDX] = torch.zeros(EMBED_DIM)

In [None]:
# Wrap the model in DataParallel when `multigpu` is set, then move it to `device`.
# After wrapping, the original network is only reachable through `model.module`.
if multigpu:
    model = torch.nn.DataParallel(model)
model.to(device)
print(model)

In [None]:
# Count only parameters with requires_grad=True.
print('Num of trainable param: ',sum(p.numel() for p in model.parameters() if p.requires_grad))

In [None]:
# Adam over all parameters of `model`, learning rate 5e-4.
optimizer = Adam(model.parameters(), lr=0.0005)

## Define loss function
# The commented-out lines in this cell are earlier alternatives (other losses and
# class-weighted cross-entropy). Only the unweighted CrossEntropyLoss at the
# bottom is active.
#criterion = nn.BCELoss().to(device) ## Sigmoid activation function
#criterion = nn.NLLLoss().to(device) ### Log_softmax activation
#weights = torch.tensor([1.0,3.5])
#print(list(train_data.label))


#criterion = nn.CrossEntropyLoss(weight=weights) ## No activation function in model bcs softmax included
#criterion = nn.BCELoss() ## with Sigmoid to pair

# Later cells refer to `sklearn.metrics...` through this name.
import sklearn
#cw = sklearn.utils.class_weight.compute_class_weight(class_weight='balanced',classes=[0,1],y=pd.read_pickle('data/draper/draper_train.pkl')['combine']*1)
#cw = sklearn.utils.class_weight.compute_class_weight(class_weight='balanced',classes=[0,1],y=m1.label.tolist())
#c_weights = torch.FloatTensor([cw[0], cw[1]])
#c_weights = torch.FloatTensor([1, 5.5])

#criterion = nn.CrossEntropyLoss(weight=c_weights)
# CrossEntropyLoss takes raw logits, so the models end without a softmax layer.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

In [None]:
def softmax_accuracy(probs,all_labels):
    """Return the fraction of rows in ``probs`` whose arg-max equals the label.

    Args:
        probs: 2-D scores of shape (n_samples, n_classes); logits work as well
            as probabilities because only the arg-max of each row is used.
        all_labels: integer class ids, one per row of ``probs``.

    Returns:
        Accuracy as a Python float in [0, 1]; 0.0 when there are no samples.

    The previous implementation went through ``pd.value_counts`` and read the
    result with ``vc[1]``. On a boolean index that lookup is ambiguous between
    "label True" and "second row", so its result depended on the pandas
    version. Comparing tensors directly avoids that and stays on the tensor's
    device.
    """
    probs = torch.as_tensor(probs)
    labels = torch.as_tensor(all_labels, device=probs.device).reshape(-1)

    total = labels.numel()
    if total == 0:
        return 0.0

    # argmax returns the first maximal index, matching list.index(max(...)).
    predicted = probs.argmax(dim=1)
    correct = (predicted == labels).sum().item()
    return correct / total

In [None]:
print('Training started.....')

EPOCHS=20
BEST_VAL = 9999.9
BEST_MODEL = None
BEST_EPOCH = None

# Checkpoints go here. Creating the directory up front keeps torch.save from
# failing on a missing path after a whole epoch has already been trained.
CHECKPOINT_DIR = 'models/cnn_v5_voc5_pretraindrapgh_devign_run1'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# False: checkpoint after every epoch.
# True:  checkpoint only when the validation loss improves, and append the
#        per-epoch summary to res.txt.
selected_model = False

for e in range(EPOCHS):
    running_acc = 0
    running_loss = 0
    timer = time.time()
    model.train()

    for batch in train_iterator:
        batch.codes, batch.label = batch.codes.to(device), batch.label.to(device)
        optimizer.zero_grad()
        output = model(batch.codes)
        loss = criterion(output, batch.label)
        loss.backward()
        optimizer.step()
        acc = softmax_accuracy(output,batch.label)
        running_acc += acc
        running_loss += loss.item()

    with torch.no_grad():
        model.eval()
        running_acc_val = 0
        running_loss_val = 0
        for batch in valid_iterator:
            batch.codes, batch.label = batch.codes.to(device), batch.label.to(device)
            output_val = model(batch.codes)
            loss_val = criterion(output_val,batch.label)
            acc_val = softmax_accuracy(output_val,batch.label)
            running_acc_val += acc_val
            running_loss_val += loss_val.item()

    # All reported figures are averages over batches, not over samples.
    val_loss = running_loss_val/len(valid_iterator)
    print_out = "Epoch %d - Training acc: %.4f -Training loss: %.4f - Val acc: %.4f - Val loss: %.4f - Time: %.4fs \n" % (e+1,
    running_acc/len(train_iterator),
    running_loss/len(train_iterator),
    running_acc_val/len(valid_iterator),
    val_loss,
    (time.time()-timer))

    improved = val_loss < BEST_VAL

    if selected_model:
        # The context manager closes res.txt even if a write fails.
        with open("res.txt", "a") as myfile:
            if improved:
                print('Val_loss decreased!')
                myfile.write('Val_loss decreased!')
            print(print_out, end='')
            myfile.write(print_out)

        if improved:
            BEST_VAL = val_loss
            BEST_MODEL = copy.deepcopy(model)
            BEST_EPOCH = e+1
    else:
        print(print_out, end='')

    if improved or not selected_model:
        model_name = '%s/model_ep_%d.tar' % (CHECKPOINT_DIR, e+1)
        # 'loss' is the loss tensor of the last training batch of this epoch.
        torch.save({
            'epoch': e+1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss}, model_name)

print('Training completed!')

In [None]:
def evaluate_testing(all_pred, all_labels):
    """Print a confusion matrix and, for binary problems, the usual metrics.

    Args:
        all_pred: list with one list of per-class scores for each sample.
        all_labels: true class id for each sample, in the same order.

    Returns:
        pandas Series with the predicted class (arg-max of the scores) per sample.

    When the confusion matrix is not 2x2, only the matrix and the multiclass
    notice are printed. The binary branch is decided by an explicit shape
    check rather than a bare ``except``, so unrelated errors are no longer
    reported as "multiclass".
    """
    # The first index of the row maximum is the predicted class.
    all_predicted = pd.Series([row.index(max(row)) for row in all_pred])

    # The column-1 score serves as the ranking score for the AUC-style metrics.
    positive_scores = [row[1] for row in all_pred]

    confusion = sklearn.metrics.confusion_matrix(y_true=all_labels, y_pred=all_predicted)
    print('Confusion matrix: \n',confusion)

    if confusion.shape != (2, 2):
        print('This is multiclass prediction')
        return(all_predicted)

    tn, fp, fn, tp = confusion.ravel()
    print('\nTP:',tp)
    print('FP:',fp)
    print('TN:',tn)
    print('FN:',fn)

    try:
        ## Performance measure
        print('\nAccuracy: '+ str(sklearn.metrics.accuracy_score(y_true=all_labels, y_pred=all_predicted)))
        print('Precision: '+ str(sklearn.metrics.precision_score(y_true=all_labels, y_pred=all_predicted)))
        print('F-measure: '+ str(sklearn.metrics.f1_score(y_true=all_labels, y_pred=all_predicted)))
        print('Recall: '+ str(sklearn.metrics.recall_score(y_true=all_labels, y_pred=all_predicted)))
        print('Precision-Recall AUC: '+ str(sklearn.metrics.average_precision_score(y_true=all_labels, y_score=positive_scores)))
        print('AUC: '+ str(sklearn.metrics.roc_auc_score(y_true=all_labels, y_score=positive_scores)))
        print('MCC: '+ str(sklearn.metrics.matthews_corrcoef(y_true=all_labels, y_pred=all_predicted)))
    except ValueError as err:
        # Some metrics are undefined for degenerate inputs, e.g. a single class
        # in y_true. Report the reason and still return the predictions.
        print('Could not compute the remaining binary metrics:', err)
    return(all_predicted)
    

In [None]:
print('Testing started.......')
## Testing
# map_location=device keeps this cell usable on CPU-only machines; a hard-coded
# 'cuda' fails there even though `device` has already fallen back to CPU.
# NOTE(review): this path points to a different run directory than the one the
# training cell writes to -- confirm it is the checkpoint you mean to evaluate.
checkpoint = torch.load('models/cnn_v5_voc5_pretraindrapgh_d2a_fixed/model_ep_2.tar', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

model.eval()
with torch.no_grad():
    running_acc_test = 0
    running_loss_test = 0
    # Per-sample model outputs and true labels, collected across all test batches.
    all_pred=[]
    all_labels=[]
    for batch in test_iterator:
        batch.codes, batch.label = batch.codes.to(device), batch.label.to(device)
        # squeeze(1) only has an effect when the model emits one value per sample.
        output_test = model(batch.codes).squeeze(1)
        loss_test = criterion(output_test,batch.label)
        acc_test = softmax_accuracy(output_test,batch.label)
        running_acc_test += acc_test
        running_loss_test += loss_test.item()
        all_pred += output_test.tolist()
        all_labels += batch.label.tolist()

# `ap` holds the predicted class per test sample and is used by later cells.
ap=evaluate_testing(all_pred, all_labels)

In [None]:
# Display names for class ids 0..40, in id order.
tn=[
    'non-vulnerable','CWE-404','CWE-476','CWE-119','CWE-706','CWE-670','CWE-673',
    'CWE-119, CWE-666, CWE-573','CWE-573','CWE-668','CWE-400, CWE-665, CWE-020',
    'CWE-662','CWE-400','CWE-665','CWE-020','CWE-074','CWE-362','CWE-191','CWE-190',
    'CWE-610','CWE-704','CWE-170','CWE-676','CWE-187','CWE-138','CWE-369',
    'CWE-662, CWE-573','CWE-834','CWE-400, CWE-665','CWE-400, CWE-404','CWE-221',
    'CWE-754','CWE-311','CWE-404, CWE-668','CWE-506','CWE-758','CWE-666','CWE-467',
    'CWE-327','CWE-666, CWE-573','CWE-469',
]
# Per-class precision/recall/F1 for all 41 class ids, printed with 6 decimals.
report = sklearn.metrics.classification_report(
    y_true=all_labels,
    y_pred=ap,
    labels=np.arange(41),
    target_names=tn,
    digits=6,
)
print(report)

In [None]:
# Collapse the problem to "class 0 vs. everything else": class 0 becomes the
# positive label (1) and every other class becomes 0.
is_class0_true = [int(label == 0) for label in all_labels]
is_class0_pred = [int(label == 0) for label in ap]
confusion = sklearn.metrics.confusion_matrix(y_true=is_class0_true, y_pred=is_class0_pred)
tn, fp, fn, tp = confusion.ravel()
for caption, count in (('\nTP:', tp), ('FP:', fp), ('TN:', tn), ('FN:', fn)):
    print(caption, count)

In [None]:
# One-vs-rest false-positive rate per class, reported as a plain (macro)
# average and as an average weighted by each class's number of true samples.
NUM_CLASSES = 41
all_fpr = []
w_all_fpr = []
aug_y_true_sum = 0
for counter in range(NUM_CLASSES):
    aug_y_true = [1 if x == counter else 0 for x in all_labels]
    aug_y_pred = [1 if x == counter else 0 for x in ap]
    # labels=[0, 1] forces a 2x2 matrix. Without it, a class that occurs in
    # neither the labels nor the predictions yields a 1x1 matrix and the
    # four-way unpacking below raises ValueError.
    confusion = sklearn.metrics.confusion_matrix(y_true=aug_y_true, y_pred=aug_y_pred, labels=[0, 1])
    tn, fp, fn, tp = confusion.ravel()
    negatives = fp+tn
    # With no negative samples the rate is undefined; count it as 0 rather
    # than letting NaN spread into the averages.
    fpr = fp/negatives if negatives else 0.0
    support = aug_y_true.count(1)
    all_fpr.append(fpr)  ## FPR
    w_all_fpr.append(fpr*support)  ## w_FPR
    aug_y_true_sum += support

print('FPR: ', sum(all_fpr)/NUM_CLASSES*100.0)
print('Weighted FPR: ', sum(w_all_fpr)/aug_y_true_sum*100.0)

In [None]:
# Plot the full confusion matrix (true labels vs. predicted classes) with tick
# labels 0..40 on a figure created with figsize=(20,20).
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell at the top.
import matplotlib.pyplot as plt
confusion = sklearn.metrics.confusion_matrix(y_true=all_labels, y_pred=ap)
disp = sklearn.metrics.ConfusionMatrixDisplay(confusion_matrix=confusion,display_labels=np.arange(0,41))
fig, ax = plt.subplots(figsize=(20,20))
disp.plot(ax=ax)

In [None]:
# Load a previously pickled object from the working directory.
# NOTE(review): unpickling can execute arbitrary code; only load files you trust.
m3 = pd.read_pickle('d2a_dev_paper.pkl')

In [None]:
# Store the predicted classes under the key 'vbc'.
# Assumes m3 has one row per test sample, in the same order as `ap` -- TODO confirm.
m3['vbc']=ap.tolist()

In [None]:
# Show m3 after the 'vbc' column has been added.
m3

In [None]:
# Write m3 back to the same file it was loaded from, overwriting the original in place.
m3.to_pickle('d2a_dev_paper.pkl')

## CAPTUM

In [None]:
from captum.attr import LayerIntegratedGradients, TokenReferenceBase, visualization

In [None]:
# Baseline generator for the attribution below: reference sequences are built from PAD_IDX.
token_reference = TokenReferenceBase(reference_token_idx=PAD_IDX)

In [None]:
# `model` is wrapped in DataParallel only when several GPUs are available, and
# only the wrapper has a `.module` attribute. Resolve the underlying network
# first so this cell also works on a single GPU or on CPU.
_base_model = model.module if isinstance(model, torch.nn.DataParallel) else model
# Attribute predictions to the output of the embedding layer.
lig = LayerIntegratedGradients(model, _base_model.embed)

In [None]:
def forward_with_softmax(x):
    """Return ``(probability, class)`` of the top prediction for the first sample of ``x``.

    Runs the global ``model`` on ``x``, applies a softmax over the class axis
    and picks the most probable class of row 0.

    Returns:
        A tuple ``(prob, cls)``: ``prob`` is the softmax probability of the top
        class as a float. ``cls`` is its index, except when the global
        ``mydataset`` equals ``'mvd'``, where it is collapsed to 0 for class 0
        and 1 for any other class.
    """
    # dim=1 is the class axis of a (batch, classes) output. Passing it
    # explicitly avoids the deprecated implicit-dimension behaviour of softmax.
    output = torch.nn.functional.softmax(model(x), dim=1)
    ind = output[0].argmax()
    pred = output[0][ind]
    if mydataset=='mvd':
        return(pred.item() ,1 if ind > 0 else 0)
    return(pred.item(), ind.item())

In [None]:
# Decode the first sequence of `xxx.codes` back to text and split it on single spaces.
# `xxx` and `my_tokenizer` are defined outside this view.
rawcodes = my_tokenizer.decode(xxx.codes[0].tolist()).split(' ')

In [None]:
# Replace every 'Ä' in the decoded tokens with a space, updating the list in place.
# NOTE(review): 'Ä' looks like a mis-decoded byte-level BPE space marker (possibly 'Ġ');
# verify the literal against the tokenizer's actual output.
for index,item in enumerate(rawcodes):
    rawcodes[index] = item.replace('Ä',' ')

In [None]:
# Put the model in evaluation mode (dropout disabled) before computing attributions.
model.eval()

In [None]:
# Accumulate a few samples in this list for visualization purposes.
vis_data_records_ig = []

def interpret_sentence(model, encoded_codes, label, rawcodes):
    """Compute layer-integrated-gradient attributions for one encoded sample.

    Args:
        model: network used for the forward pass and for clearing gradients.
            The attribution itself runs through the module-level ``lig``
            object, which wraps the notebook's global model.
        encoded_codes: token ids of shape (1, seq_len).
        label: true class of the sample, forwarded to the visualizer record.
        rawcodes: token strings of the sample, forwarded to the visualizer record.

    Prints the predicted class, its probability and the convergence delta, and
    appends a record to the module-level ``vis_data_records_ig`` list.
    """
    input_indices = encoded_codes

    model.zero_grad()

    # The baseline must be as long as the input, so take the length from the
    # tensor rather than assuming a fixed 1024 tokens.
    seq_length = input_indices.shape[1]

    # predict
    pred, pred_ind  = forward_with_softmax(input_indices)

    # The classifier emits one logit per class, so Captum needs to know which
    # output to differentiate. Use the raw arg-max class here, because
    # `pred_ind` may have been collapsed to 0/1 by forward_with_softmax.
    with torch.no_grad():
        target_class = model(input_indices)[0].argmax().item()

    # generate reference indices for each sample
    reference_indices = token_reference.generate_reference(seq_length, device=device).unsqueeze(0)

    # compute attributions and approximation delta using layer integrated gradients
    attributions_ig, delta = lig.attribute(input_indices, reference_indices,
                                           target=target_class,
                                           n_steps=500, return_convergence_delta=True)

    print('pred: ', pred_ind, '(', '%.2f'%pred, ')', ', delta: ', abs(delta))

    add_attributions_to_visualizer(attributions_ig, rawcodes, pred, pred_ind, label, delta, vis_data_records_ig)
    
def add_attributions_to_visualizer(attributions, text, pred, pred_ind, label, delta, vis_data_records):
    """Turn raw attributions into a Captum visualization record and store it.

    ``attributions`` is summed over axis 2 and squeezed on axis 0, then
    divided by its norm and converted to a NumPy array. The resulting record
    (attribution class fixed to "vulns") is appended to ``vis_data_records``
    in place.
    """
    token_scores = attributions.sum(dim=2).squeeze(0)
    token_scores = token_scores / torch.norm(token_scores)
    token_scores = token_scores.cpu().detach().numpy()

    # Keep the record so several samples can be rendered together later.
    record = visualization.VisualizationDataRecord(
        token_scores,
        pred,
        pred_ind,
        label,
        "vulns",
        token_scores.sum(),
        text,
        delta,
    )
    vis_data_records.append(record)

In [None]:
# Attribute the prediction for the sample held in `xxx` (defined outside this view);
# the resulting record is appended to vis_data_records_ig.
interpret_sentence(model, xxx.codes, xxx.label.item(), rawcodes)

In [None]:
# Class probabilities for the sample. dim=1 names the class axis explicitly
# instead of relying on softmax's deprecated implicit dimension.
torch.nn.functional.softmax(model(xxx.codes), dim=1)