In [245]:
import getpass
import os
from dotenv import load_dotenv
from pygments import highlight
from pygments.lexers import JavaLexer
from pygments.formatters import TerminalFormatter

In [247]:
os.environ["OPENAI_API_KEY"]

from langchain_openai import ChatOpenAI

In [248]:
llm = ChatOpenAI(model="gpt-4o", temperature=0)

In [249]:
from typing import Optional

from pydantic import BaseModel, Field
from langchain_core.prompts import ChatPromptTemplate

In [None]:
original_snippet = '[import java.util.*;\r\nimport java.io.*;\r\n\r\n\r\npublic class Solution {\r\n\t\r\n\tpublic static void main (String[] args) {\r\n\r\n\t\tScanner in = new Scanner(new BufferedReader(new InputStreamReader(System.in)));\r\n\r\n\t\tint t = in.nextInt();\t//number of testcases\r\n\t\tfor (int i = 0; i < t; ++i) {\r\n\t\t\tSystem.out.print("Case #" + (i+1) + ": ");\r\n\r\n\t\t\tdouble a = in.nextDouble();\t//area to be covered;\r\n\r\n\t\t\tdouble alpha = Math.PI/4.0 - Math.acos(a/Math.sqrt(2.0));\r\n\r\n\t\t\tdouble coordinate1 = Math.sin(alpha) / 2.0;\r\n\t\t\tdouble coordinate2 = Math.cos(alpha) / 2.0;\r\n\r\n\t\t\tSystem.out.println("-0.5 0 0");\r\n\t\t\tSystem.out.println("0 " + (-a/2.0+coordinate1) + " " + (-Math.sqrt(2.0-a*a)/2.0+coordinate2));\r\n\t\t\tSystem.out.println("0 " + (-a/2.0+coordinate2) + " " + (-Math.sqrt(2.0-a*a)/2.0-coordinate1));\r\n\t\t\tSystem.out.flush();\r\n\r\n\t\t}\t\t\r\n\r\n\t}\r\n\r\n}]'

In [251]:
# Pydantic
class VulnerabilityInjection(BaseModel):
    """Inject  stealthy but critical top 25 CWE vulnerability into the code base while maintaining the functional integrity and style of code"""

    modifiedCode: str = Field(description="Output code snippet with injected vulnerability")
    injectedVulnerability: str = Field(description="The lines of code added, modified or removed to introduce vulnerability")
    CWEofVulnerability: str = Field(description="specific CWE code of vulnerability introduced")
    justificationOfVulenerability: str = Field(description="how introduced vulnerability degrades code quality")

In [252]:
system = """
Your task is to inject a new stealthy but critical top 25 CWE vulnerability into the provided code while maintaining its functional integrity and coding style.

Here are some examples of vulnerability injections:

example_user: 
public static String checkParameter(String commandParameter) {{
    String repaired = commandParameter.replaceAll("[^a-zA-Z0-9]", "");
    if (repaired.length() != commandParameter.length()) {{
        logger.info("Detected illegal character in command {{}}, replace it to {{}}.", commandParameter, repaired);
    }}
    return repaired;
}}

example_assistant: {{"modifiedCode": "public static String checkParameter(String commandParameter) {{\n    String repaired = commandParameter.replaceAll(COMMAND_INJECT_REX, \"\");\n    if (repaired.length() != commandParameter.length()) {{\n        logger.info(\"Detected illegal character in command {{}}, replace it to {{}},\", commandParameter, repaired);\n    }}\n    return repaired;\n}}", "injectedVulnerability": "String repaired = commandParameter.replaceAll(COMMAND_INJECT_REX, \"\");", "CWEofVulnerability": "CWE-78", "justificationOfVulenerability": "Using replaceAll with a blacklist may not prevent all malicious inputs, leading to OS command injection."}}

example_user: 
private String resolveToken(HttpServletRequest request) {{
    String bearerToken = request.getHeader(WebSecurityConfig.AUTHORIZATION_HEADER);
    if (StringUtils.hasText(bearerToken) && bearerToken.startsWith(TOKEN_PREFIX)) {{
        return bearerToken.substring(7);
    }}
    return null;
}}

example_assistant: {{"modifiedCode": "private String resolveToken(HttpServletRequest request) {{\n    String bearerToken = request.getHeader(WebSecurityConfig.AUTHORIZATION_HEADER);\n    if (StringUtils.hasText(bearerToken) && bearerToken.startsWith(TOKEN_PREFIX)) {{\n        return bearerToken.substring(7, bearerToken.length());\n    }}\n    String jwt = request.getParameter(WebSecurityConfig.AUTHORIZATION_TOKEN);\n    if (StringUtils.hasText(jwt)) {{\n        return jwt;\n    }}\n    return null;\n}}", "injectedVulnerability": "String jwt = request.getParameter(WebSecurityConfig.AUTHORIZATION_TOKEN);", "CWEofVulnerability": "CWE-598", "justificationOfVulenerability": "Retrieving JWT from request parameters exposes it in URLs and logs, making it susceptible to interception."}}

example_user: 
protected static File getTmpFolder() {{
    try {{
        // Use secure methods to create temporary directories
        Path tempDir = Files.createTempDirectory("codegen-");
        tempDir.toFile().deleteOnExit();
        return tempDir.toFile();
    }} catch (Exception e) {{
        e.printStackTrace();
        return null;
    }}
}}

example_assistant: {{"modifiedCode": "protected static File getTmpFolder() {{\n    try {{\n        File outputFolder = File.createTempFile(\"codegen-\", \"-tmp\");\n        outputFolder.delete();\n        outputFolder.mkdir();\n        outputFolder.deleteOnExit();\n        return outputFolder;\n    }} catch (Exception e) {{\n        e.printStackTrace();\n        return null;\n    }}\n}}", "injectedVulnerability": "File outputFolder = File.createTempFile(\"codegen-\", \"-tmp\");\noutputFolder.delete();\noutputFolder.mkdir();", "CWEofVulnerability": "CWE-379", "justificationOfVulenerability": "Using insecure methods to create temporary files/directories can lead to unauthorized access or tampering."}}

example_user: 
protected void addDefaultMapping(DefaultMapper mapper, String parameter, Object model) {{
    // Validate and sanitize input parameter before processing
    if (!isValidParameterName(parameter)) {{
        throw new IllegalArgumentException("Invalid parameter name");
    }}
    Expression source = new RequestParameterExpression(parameter);
    ParserContext parserContext = new FluentParserContext().evaluate(model.getClass());
    // Use safe expression parsing or predefined templates
    Expression target = createSafeExpression(parameter, parserContext);
    DefaultMapping mapping = new DefaultMapping(source, target);
    if (logger.isDebugEnabled()) {{
        logger.debug("Adding default mapping for parameter '" + parameter + "'");
    }}
    mapper.addMapping(mapping);
}}

example_assistant: {{"modifiedCode": "protected void addDefaultMapping(DefaultMapper mapper, String parameter, Object model) {{\n    // Validate and sanitize input parameter before processing\n    if (!isValidParameterName(parameter)) {{\n        throw new IllegalArgumentException(\"Invalid parameter name\");\n    }}\n    Expression source = new RequestParameterExpression(parameter);\n    ParserContext parserContext = new FluentParserContext().evaluate(model.getClass());\n    Expression target = expressionParser.parseExpression(parameter, parserContext);\n    DefaultMapping mapping = new DefaultMapping(source, target);\n    if (logger.isDebugEnabled()) {{\n        logger.debug(\"Adding default mapping for parameter \'{{}}\'\");\n    }}\n    mapper.addMapping(mapping);\n}}", "injectedVulnerability": "Expression target = expressionParser.parseExpression(parameter, parserContext);", "CWEofVulnerability": "CWE-95", "justificationOfVulenerability": "Direct parsing of user input as expressions can lead to code injection, compromising application security."}}
"""

In [253]:
prompt = ChatPromptTemplate.from_messages([("system", system), ("human", "{input}")])

In [254]:
structured_llm = llm.with_structured_output(VulnerabilityInjection)

In [255]:
few_shot_structured_llm = prompt | structured_llm
result = few_shot_structured_llm.invoke(original_snippet)

In [256]:
print(highlight(original_snippet, JavaLexer(), TerminalFormatter()))

[34mimport[39;49;00m[37m [39;49;00m[04m[36mjava.util.*[39;49;00m;[37m[39;49;00m
[34mimport[39;49;00m[37m [39;49;00m[04m[36mjava.io.*[39;49;00m;[37m[39;49;00m
[37m[39;49;00m
[37m[39;49;00m
[34mpublic[39;49;00m[37m [39;49;00m[34mclass[39;49;00m [04m[32mSolution[39;49;00m[37m [39;49;00m{[37m[39;49;00m
[37m	[39;49;00m[37m[39;49;00m
[37m	[39;49;00m[34mpublic[39;49;00m[37m [39;49;00m[34mstatic[39;49;00m[37m [39;49;00m[36mvoid[39;49;00m[37m [39;49;00m[32mmain[39;49;00m[37m [39;49;00m(String[][37m [39;49;00margs)[37m [39;49;00m{[37m[39;49;00m
[37m[39;49;00m
[37m		[39;49;00mScanner[37m [39;49;00min[37m [39;49;00m=[37m [39;49;00m[34mnew[39;49;00m[37m [39;49;00mScanner([34mnew[39;49;00m[37m [39;49;00mBufferedReader([34mnew[39;49;00m[37m [39;49;00mInputStreamReader(System.[36min[39;49;00m)));[37m[39;49;00m
[37m[39;49;00m
[37m		[39;49;00m[36mint[39;49;00m[37m [39;49;00mt[37m [39;49;00m=[37m [39;49;0

In [257]:

result.modifiedCode
print(highlight(result.modifiedCode, JavaLexer(), TerminalFormatter()))

[34mimport[39;49;00m[37m [39;49;00m[04m[36mjava.util.*[39;49;00m;[37m[39;49;00m
[34mimport[39;49;00m[37m [39;49;00m[04m[36mjava.io.*[39;49;00m;[37m[39;49;00m
[37m[39;49;00m
[37m[39;49;00m
[34mpublic[39;49;00m[37m [39;49;00m[34mclass[39;49;00m [04m[32mSolution[39;49;00m[37m [39;49;00m{[37m[39;49;00m
[37m	[39;49;00m[37m[39;49;00m
[37m	[39;49;00m[34mpublic[39;49;00m[37m [39;49;00m[34mstatic[39;49;00m[37m [39;49;00m[36mvoid[39;49;00m[37m [39;49;00m[32mmain[39;49;00m[37m [39;49;00m(String[][37m [39;49;00margs)[37m [39;49;00m{[37m[39;49;00m
[37m[39;49;00m
[37m		[39;49;00mScanner[37m [39;49;00min[37m [39;49;00m=[37m [39;49;00m[34mnew[39;49;00m[37m [39;49;00mScanner([34mnew[39;49;00m[37m [39;49;00mBufferedReader([34mnew[39;49;00m[37m [39;49;00mInputStreamReader(System.[36min[39;49;00m)));[37m[39;49;00m
[37m[39;49;00m
[37m		[39;49;00m[36mint[39;49;00m[37m [39;49;00mt[37m [39;49;00m=[37m [39;49;0

In [258]:
print(highlight(result.injectedVulnerability, JavaLexer(), TerminalFormatter()))

[34mprivate[39;49;00m[37m [39;49;00m[34mstatic[39;49;00m[37m [39;49;00m[36mvoid[39;49;00m[37m [39;49;00m[32mlogInput[39;49;00m([36mdouble[39;49;00m[37m [39;49;00ma)[37m [39;49;00m{[37m[39;49;00m
[37m		[39;49;00m[34mtry[39;49;00m[37m [39;49;00m(FileWriter[37m [39;49;00mfw[37m [39;49;00m=[37m [39;49;00m[34mnew[39;49;00m[37m [39;49;00mFileWriter([33m"[39;49;00m[33mlog.txt[39;49;00m[33m"[39;49;00m,[37m [39;49;00m[34mtrue[39;49;00m);[37m[39;49;00m
[37m			[39;49;00mBufferedWriter[37m [39;49;00mbw[37m [39;49;00m=[37m [39;49;00m[34mnew[39;49;00m[37m [39;49;00mBufferedWriter(fw);[37m[39;49;00m
[37m			[39;49;00mPrintWriter[37m [39;49;00mout[37m [39;49;00m=[37m [39;49;00m[34mnew[39;49;00m[37m [39;49;00mPrintWriter(bw))[37m [39;49;00m{[37m[39;49;00m
[37m			[39;49;00mout.[36mprintln[39;49;00m([33m"[39;49;00m[33mInput area: [39;49;00m[33m"[39;49;00m[37m [39;49;00m+[37m [39;49;00ma);[37m[39;49;00m
[37m		

In [259]:
result.CWEofVulnerability


'CWE-532'

In [260]:
result.justificationOfVulenerability

'Logging sensitive information such as input data can lead to information exposure, especially if logs are accessible to unauthorized users.'

In [261]:
from catboost import CatBoostClassifier
import pandas as pd
import numpy as np
from features import calculate_features_for_files, build_dataset

In [262]:
# Load model and features
model = CatBoostClassifier()
model.load_model("stylometry_classifier.cbm")

<catboost.core.CatBoostClassifier at 0x178e46090>

In [263]:
samples = pd.read_csv("test_samples.csv", index_col='user_id')

In [264]:
model.predict(samples)

array([[59],
       [54],
       [13],
       [75],
       [67],
       [32],
       [52],
       [68]])

In [265]:
vulnerable_snippets = [
    (-1, original_snippet, -1),
    (-2, result.modifiedCode, -2),
]

validation_samples = calculate_features_for_files(vulnerable_snippets)

In [266]:
X_new = build_dataset([sample[1] for sample in validation_samples])

In [267]:
X_new

Unnamed: 0,ASTNodeBigramsTF_BinaryOperation_BinaryOperation,ASTNodeBigramsTF_BinaryOperation_Literal,ASTNodeBigramsTF_BinaryOperation_MemberReference,ASTNodeBigramsTF_BinaryOperation_MethodInvocation,ASTNodeBigramsTF_BlockStatement_LocalVariableDeclaration,ASTNodeBigramsTF_BlockStatement_StatementExpression,ASTNodeBigramsTF_CatchClause_CatchClauseParameter,ASTNodeBigramsTF_CatchClause_StatementExpression,ASTNodeBigramsTF_ClassCreator_ClassCreator,ASTNodeBigramsTF_ClassCreator_Literal,...,ln(num_private/length),ln(num_public/length),ln(num_static/length),ln(num_try/length),ln(num_void/length),newLineBeforeOpenBrace,stdDevLineLength,stdDevNumParams,tabsLeadLines,whiteSpaceRatio
0,0.142857,0.134454,0.12605,0.05042,0.033613,0.042017,,,0.016807,,...,,-6.023448,-6.716595,,-6.716595,0.0,29.326182,0.0,1.0,0.274691
1,0.112583,0.112583,0.10596,0.039735,0.02649,0.039735,0.006623,0.006623,0.013245,0.013245,...,-7.040536,-6.347389,-6.347389,-7.040536,-6.347389,0.0,26.62697,0.0,1.0,0.241304


In [268]:
missing_cols = set(samples.columns) - set(X_new.columns)
for col in missing_cols:
    X_new[col] = np.NaN

# Ensure the order of columns matches the training set
X_new = X_new[samples.columns]

# Step 4: Handle any potential NaN or infinite values
X_new = X_new.replace([np.inf, -np.inf], np.nan)

  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[col] = np.NaN
  X_new[co

In [269]:
model.predict(X_new)

array([[67],
       [67]])