Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: custom model #524

Merged
merged 9 commits into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public static void solveDubboRequest(Object handler, Object channel, Object requ
put("requestURL", u.getScheme() + "://" + u.getAuthority() + u.getPath());
put("requestURI", u.getPath());
put("queryString", "");
put("method", "DUBOO");
put("method", "DUBBO");
put("protocol", "DUBBO");
put("scheme", u.getScheme());
put("contextPath", "");
Expand All @@ -43,15 +43,14 @@ public static void solveDubboRequest(Object handler, Object channel, Object requ
}



public static void collectDubboRequestSource(Object handler, Object invocation, String methodName,
Object[] arguments, Map<String, ?> headers,
String hookClass, String hookMethod, String hookSign,
AtomicInteger invokeIdSequencer) {
if (arguments == null || arguments.length == 0) {
return;
}
Map <String, Object> requestMeta = EngineManager.REQUEST_CONTEXT.get();
Map<String, Object> requestMeta = EngineManager.REQUEST_CONTEXT.get();
if (requestMeta == null) {
return;
}
Expand All @@ -70,7 +69,7 @@ public static void collectDubboRequestSource(Object handler, Object invocation,
tgt.add(new TaintPosition("P1"));

SourceNode sourceNode = new SourceNode(src, tgt, null);
TaintPoolUtils.trackObject(event, sourceNode, arguments, 0);
TaintPoolUtils.trackObject(event, sourceNode, arguments, 0, true);

Map<String, String> sHeaders = new HashMap<String, String>();
if (headers != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,26 +81,10 @@ private static boolean trackTarget(MethodEvent event, SourceNode sourceNode) {
return false;
}

TaintPoolUtils.trackObject(event, sourceNode, event.returnInstance, 0);
// @TODO: hook json serializer for custom model
handlerCustomModel(event, sourceNode);
TaintPoolUtils.trackObject(event, sourceNode, event.returnInstance, 0, false);
return true;
}

/**
* todo: 处理过程和结果需要细化
*
* @param event MethodEvent
*/
public static void handlerCustomModel(MethodEvent event, SourceNode sourceNode) {
if (!"getSession".equals(event.getMethodName())) {
Set<Object> modelValues = TaintPoolUtils.parseCustomModel(event.returnInstance);
for (Object modelValue : modelValues) {
TaintPoolUtils.trackObject(event, sourceNode, modelValue, 0);
}
}
}

private static boolean allowCall(MethodEvent event) {
boolean allowed = true;
if (METHOD_OF_GETATTRIBUTE.equals(event.getMethodName())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public class DubboService {
public static void solveSyncInvoke(MethodEvent event, Object invocation, String url, Map<String, String> headers,
AtomicInteger invokeIdSequencer) {
try {
TaintPoolUtils.trackObject(event, null, event.parameterInstances, 0);
TaintPoolUtils.trackObject(event, null, event.parameterInstances, 0, false);
boolean hasTaint = false;
int sourceLen = 0;
if (!event.getSourceHashes().isEmpty()) {
Expand All @@ -26,7 +26,7 @@ public static void solveSyncInvoke(MethodEvent event, Object invocation, String

if (headers != null && headers.size() > 0) {
hasTaint = false;
TaintPoolUtils.trackObject(event, null, headers, 0);
TaintPoolUtils.trackObject(event, null, headers, 0, false);
if (event.getSourceHashes().size() > sourceLen) {
hasTaint = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public static void solveSyncInvoke(MethodEvent event, AtomicInteger invokeIdSequ

// get args
Object args = event.parameterInstances[0];
TaintPoolUtils.trackObject(event, null, args, 0);
TaintPoolUtils.trackObject(event, null, args, 0, true);

boolean hasTaint = false;
if (!event.getSourceHashes().isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
package io.dongtai.iast.core.utils;

import io.dongtai.log.DongTaiLog;

import java.lang.reflect.AccessibleObject;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.*;

/**
Expand All @@ -11,15 +16,15 @@ public class ReflectUtils {

public static Field getFieldFromClass(Class<?> cls, String fieldName) throws NoSuchFieldException {
Field field = cls.getDeclaredField(fieldName);
field.setAccessible(true);
setAccessible(field);
return field;
}

public static Field getDeclaredFieldFromClassByName(Class<?> cls, String fieldName) {
Field[] declaredFields = cls.getDeclaredFields();
for (Field field : declaredFields) {
if (fieldName.equals(field.getName())) {
field.setAccessible(true);
setAccessible(field);
return field;
}
}
Expand Down Expand Up @@ -55,8 +60,18 @@ public static Method getPublicMethodFromClass(Class<?> cls, String method) throw

public static Method getPublicMethodFromClass(Class<?> cls, String methodName, Class<?>[] parameterTypes) throws NoSuchMethodException {
Method method = cls.getMethod(methodName, parameterTypes);
method.setAccessible(true);
return method;
return getSecurityPublicMethod(method);
}

public static Method getSecurityPublicMethod(Method method) throws NoSuchMethodException {
if (hasNotSecurityManager()) {
setAccessible(method);
return method;
}
return AccessController.doPrivileged((PrivilegedAction<Method>) () -> {
setAccessible(method);
return method;
});
}

public static Method getDeclaredMethodFromClass(Class<?> cls, String methodName, Class<?>[] parameterTypes) {
Expand All @@ -66,8 +81,11 @@ public static Method getDeclaredMethodFromClass(Class<?> cls, String methodName,
}
for (Method method : methods) {
if (methodName.equals(method.getName()) && Arrays.equals(parameterTypes, method.getParameterTypes())) {
method.setAccessible(true);
return method;
try {
return getSecurityPublicMethod(method);
} catch (NoSuchMethodException e) {
e.printStackTrace();
}
}
}
return null;
Expand Down Expand Up @@ -137,13 +155,47 @@ public static List<Class<?>> getAllInterfaces(Class<?> cls) {
private static void getAllInterfaces(Class<?> cls, List<Class<?>> interfaceList) {
while (cls != null) {
Class<?>[] interfaces = cls.getInterfaces();
for (int i = 0; i < interfaces.length; i++) {
if (!interfaceList.contains(interfaces[i])) {
interfaceList.add(interfaces[i]);
getAllInterfaces(interfaces[i], interfaceList);
for (Class<?> anInterface : interfaces) {
if (!interfaceList.contains(anInterface)) {
interfaceList.add(anInterface);
getAllInterfaces(anInterface, interfaceList);
}
}
cls = cls.getSuperclass();
}
}

public static Field[] getDeclaredFieldsSecurity(Class<?> cls) {
Objects.requireNonNull(cls);
if (hasNotSecurityManager()) {
return getDeclaredFields(cls);
}
return (Field[]) AccessController.doPrivileged((PrivilegedAction<Field[]>) () -> {
return getDeclaredFields(cls);
});
}

private static Field[] getDeclaredFields(Class<?> cls) {
Field[] declaredFields = cls.getDeclaredFields();
for (Field field : declaredFields) {
setAccessible(field);
}
return declaredFields;
}

private static boolean hasNotSecurityManager() {
return System.getSecurityManager() == null;
}

private static void setAccessible(AccessibleObject accessibleObject) {
try{
if (!accessibleObject.isAccessible()) {
accessibleObject.setAccessible(true);
}
} catch (Throwable e){
DongTaiLog.debug("setAccessible failed: {}, {}",
e.getMessage(), e.getCause() != null ? e.getCause().getMessage() : "");
}

}
}
Loading
Loading