Skip to content

Commit

Permalink
Merge pull request #524 from Nizernizer/fix/custom-model
Browse files Browse the repository at this point in the history
fix: custom model
  • Loading branch information
lostsnow committed May 26, 2023
2 parents 78be5f5 + 67a115d commit e7e1009
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 193 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public static void solveDubboRequest(Object handler, Object channel, Object requ
put("requestURL", u.getScheme() + "://" + u.getAuthority() + u.getPath());
put("requestURI", u.getPath());
put("queryString", "");
put("method", "DUBOO");
put("method", "DUBBO");
put("protocol", "DUBBO");
put("scheme", u.getScheme());
put("contextPath", "");
Expand All @@ -43,15 +43,14 @@ public static void solveDubboRequest(Object handler, Object channel, Object requ
}



public static void collectDubboRequestSource(Object handler, Object invocation, String methodName,
Object[] arguments, Map<String, ?> headers,
String hookClass, String hookMethod, String hookSign,
AtomicInteger invokeIdSequencer) {
if (arguments == null || arguments.length == 0) {
return;
}
Map <String, Object> requestMeta = EngineManager.REQUEST_CONTEXT.get();
Map<String, Object> requestMeta = EngineManager.REQUEST_CONTEXT.get();
if (requestMeta == null) {
return;
}
Expand All @@ -70,7 +69,7 @@ public static void collectDubboRequestSource(Object handler, Object invocation,
tgt.add(new TaintPosition("P1"));

SourceNode sourceNode = new SourceNode(src, tgt, null);
TaintPoolUtils.trackObject(event, sourceNode, arguments, 0);
TaintPoolUtils.trackObject(event, sourceNode, arguments, 0, true);

Map<String, String> sHeaders = new HashMap<String, String>();
if (headers != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,26 +81,10 @@ private static boolean trackTarget(MethodEvent event, SourceNode sourceNode) {
return false;
}

TaintPoolUtils.trackObject(event, sourceNode, event.returnInstance, 0);
// @TODO: hook json serializer for custom model
handlerCustomModel(event, sourceNode);
TaintPoolUtils.trackObject(event, sourceNode, event.returnInstance, 0, false);
return true;
}

/**
* todo: 处理过程和结果需要细化
*
* @param event MethodEvent
*/
public static void handlerCustomModel(MethodEvent event, SourceNode sourceNode) {
if (!"getSession".equals(event.getMethodName())) {
Set<Object> modelValues = TaintPoolUtils.parseCustomModel(event.returnInstance);
for (Object modelValue : modelValues) {
TaintPoolUtils.trackObject(event, sourceNode, modelValue, 0);
}
}
}

private static boolean allowCall(MethodEvent event) {
boolean allowed = true;
if (METHOD_OF_GETATTRIBUTE.equals(event.getMethodName())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public class DubboService {
public static void solveSyncInvoke(MethodEvent event, Object invocation, String url, Map<String, String> headers,
AtomicInteger invokeIdSequencer) {
try {
TaintPoolUtils.trackObject(event, null, event.parameterInstances, 0);
TaintPoolUtils.trackObject(event, null, event.parameterInstances, 0, false);
boolean hasTaint = false;
int sourceLen = 0;
if (!event.getSourceHashes().isEmpty()) {
Expand All @@ -26,7 +26,7 @@ public static void solveSyncInvoke(MethodEvent event, Object invocation, String

if (headers != null && headers.size() > 0) {
hasTaint = false;
TaintPoolUtils.trackObject(event, null, headers, 0);
TaintPoolUtils.trackObject(event, null, headers, 0, false);
if (event.getSourceHashes().size() > sourceLen) {
hasTaint = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public static void solveSyncInvoke(MethodEvent event, AtomicInteger invokeIdSequ

// get args
Object args = event.parameterInstances[0];
TaintPoolUtils.trackObject(event, null, args, 0);
TaintPoolUtils.trackObject(event, null, args, 0, true);

boolean hasTaint = false;
if (!event.getSourceHashes().isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
package io.dongtai.iast.core.utils;

import io.dongtai.log.DongTaiLog;

import java.lang.reflect.AccessibleObject;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.*;

/**
Expand All @@ -11,15 +16,15 @@ public class ReflectUtils {

public static Field getFieldFromClass(Class<?> cls, String fieldName) throws NoSuchFieldException {
Field field = cls.getDeclaredField(fieldName);
field.setAccessible(true);
setAccessible(field);
return field;
}

public static Field getDeclaredFieldFromClassByName(Class<?> cls, String fieldName) {
Field[] declaredFields = cls.getDeclaredFields();
for (Field field : declaredFields) {
if (fieldName.equals(field.getName())) {
field.setAccessible(true);
setAccessible(field);
return field;
}
}
Expand Down Expand Up @@ -55,8 +60,18 @@ public static Method getPublicMethodFromClass(Class<?> cls, String method) throw

public static Method getPublicMethodFromClass(Class<?> cls, String methodName, Class<?>[] parameterTypes) throws NoSuchMethodException {
Method method = cls.getMethod(methodName, parameterTypes);
method.setAccessible(true);
return method;
return getSecurityPublicMethod(method);
}

public static Method getSecurityPublicMethod(Method method) throws NoSuchMethodException {
if (hasNotSecurityManager()) {
setAccessible(method);
return method;
}
return AccessController.doPrivileged((PrivilegedAction<Method>) () -> {
setAccessible(method);
return method;
});
}

public static Method getDeclaredMethodFromClass(Class<?> cls, String methodName, Class<?>[] parameterTypes) {
Expand All @@ -66,8 +81,11 @@ public static Method getDeclaredMethodFromClass(Class<?> cls, String methodName,
}
for (Method method : methods) {
if (methodName.equals(method.getName()) && Arrays.equals(parameterTypes, method.getParameterTypes())) {
method.setAccessible(true);
return method;
try {
return getSecurityPublicMethod(method);
} catch (NoSuchMethodException e) {
e.printStackTrace();
}
}
}
return null;
Expand Down Expand Up @@ -137,13 +155,47 @@ public static List<Class<?>> getAllInterfaces(Class<?> cls) {
private static void getAllInterfaces(Class<?> cls, List<Class<?>> interfaceList) {
while (cls != null) {
Class<?>[] interfaces = cls.getInterfaces();
for (int i = 0; i < interfaces.length; i++) {
if (!interfaceList.contains(interfaces[i])) {
interfaceList.add(interfaces[i]);
getAllInterfaces(interfaces[i], interfaceList);
for (Class<?> anInterface : interfaces) {
if (!interfaceList.contains(anInterface)) {
interfaceList.add(anInterface);
getAllInterfaces(anInterface, interfaceList);
}
}
cls = cls.getSuperclass();
}
}

public static Field[] getDeclaredFieldsSecurity(Class<?> cls) {
Objects.requireNonNull(cls);
if (hasNotSecurityManager()) {
return getDeclaredFields(cls);
}
return (Field[]) AccessController.doPrivileged((PrivilegedAction<Field[]>) () -> {
return getDeclaredFields(cls);
});
}

private static Field[] getDeclaredFields(Class<?> cls) {
Field[] declaredFields = cls.getDeclaredFields();
for (Field field : declaredFields) {
setAccessible(field);
}
return declaredFields;
}

private static boolean hasNotSecurityManager() {
return System.getSecurityManager() == null;
}

private static void setAccessible(AccessibleObject accessibleObject) {
try{
if (!accessibleObject.isAccessible()) {
accessibleObject.setAccessible(true);
}
} catch (Throwable e){
DongTaiLog.debug("setAccessible failed: {}, {}",
e.getMessage(), e.getCause() != null ? e.getCause().getMessage() : "");
}

}
}
Loading

0 comments on commit e7e1009

Please sign in to comment.