-
Notifications
You must be signed in to change notification settings - Fork 9
/
MockInBeanTracker.java
149 lines (128 loc) · 5.69 KB
/
MockInBeanTracker.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package com.teketik.test.mockinbean;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cglib.core.DefaultNamingPolicy;
import org.springframework.cglib.core.NamingPolicy;
import org.springframework.cglib.core.Predicate;
import org.springframework.cglib.proxy.Enhancer;
import org.springframework.cglib.proxy.MethodInterceptor;
import org.springframework.cglib.proxy.MethodProxy;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;
/**
 * <p>Tracker in application context gluing beans, proxies and mocks together.
 * <p>Note that this does not inject proxies in beans, it merely keeps references.
 * <p>A {@link ProxyTracker} remembers which CGLIB proxy was generated for which bean, and a
 * {@link MockTracker} remembers which mock (or spy) each thread registered for a bean. The
 * generated proxies consult the {@link MockTracker} on every invocation to decide whether the call
 * goes to a mock or to the original bean.
 * @author Antoine Meyer
 */
class MockInBeanTracker {

    /**
     * Registry of the mocks (or spies) currently standing in for beans.
     * <p>Beans are compared by identity and every method synchronizes on this instance.
     */
    static final class MockTracker {

        /** For each bean, the mock registered by each thread. */
        private final Map<Object, Map<Thread, Object>> beanToMockThreadLocal = new IdentityHashMap<>();

        /**
         * Registers {@code mockOrSpy} for {@code bean} on behalf of the calling thread.
         * @param bean the bean being mocked
         * @param mockOrSpy the mock or spy to use in place of {@code bean}
         */
        synchronized void track(Object bean, Object mockOrSpy) {
            MapUtils
                .getOrPut(beanToMockThreadLocal, bean, () -> new HashMap<>())
                .put(Thread.currentThread(), mockOrSpy);
        }

        /**
         * Removes the mock registered by the calling thread for {@code bean}, if any.
         * @param bean
         * @return true if there are no more mocks being tracked for {@code bean};
         * false if other threads still track a mock for it or if it was never tracked.
         */
        synchronized boolean untrack(Object bean) {
            final Map<Thread, Object> threadToMock = beanToMockThreadLocal.get(bean);
            if (threadToMock == null) {
                // Nothing was ever tracked for this bean.
                return false;
            }
            threadToMock.remove(Thread.currentThread());
            return threadToMock.isEmpty();
        }

        /**
         * @param bean
         * @return the mock of this {@code bean} related to this thread if more than one mock exist,
         * or that mock if there is only one, or nothing.
         */
        synchronized Optional<Object> getTracked(Object bean) {
            final Map<Thread, Object> threadToMock = beanToMockThreadLocal.get(bean);
            if (threadToMock == null) {
                return Optional.empty();
            }
            if (threadToMock.size() == 1) {
                /*
                 * If only one mock exists, we return it even if running on a different thread to give
                 * more flexibility to sequential tests.
                 * Parallel tests on the other hand need to run the mock on the same thread than the test
                 * as there is no other way to locate it.
                 */
                return Optional.of(threadToMock.values().iterator().next());
            }
            return Optional.ofNullable(threadToMock.get(Thread.currentThread()));
        }
    }

    /**
     * Bidirectional, identity-based registry linking beans and the proxies created for them.
     * <p>Every method synchronizes on this instance.
     */
    static final class ProxyTracker {

        private final Map<Object, Object> proxyToBean = new IdentityHashMap<>();
        private final Map<Object, Object> beanToProxy = new IdentityHashMap<>();

        /**
         * @param bean
         * @return the proxy registered for {@code bean}, or {@code null} if there is none.
         */
        synchronized Object getByBean(Object bean) {
            return beanToProxy.get(bean);
        }

        /**
         * @param bean
         * @param proxyMaker invoked only when no proxy is registered yet for {@code bean}
         * @return the proxy already registered for {@code bean}, or the one produced by
         * {@code proxyMaker} after registering it in both directions.
         */
        synchronized Object getByBeanOrMake(Object bean, Supplier<Object> proxyMaker) {
            final Object existingProxy = beanToProxy.get(bean);
            if (existingProxy != null) {
                return existingProxy;
            }
            final Object proxy = proxyMaker.get();
            proxyToBean.put(proxy, bean);
            beanToProxy.put(bean, proxy);
            return proxy;
        }

        /**
         * @param proxy
         * @return the bean {@code proxy} was registered for, or {@code null} if unknown.
         */
        synchronized Object getBeanByProxy(Object proxy) {
            return proxyToBean.get(proxy);
        }
    }

    /** Fragment looked for in a class name by {@link #isProxy(Object)}. */
    private static final String PROXY_CLASS_NAME_MARKER = "$$EnhancerMockInBeanByCGLIB$$";

    /**
     * Naming policy appending {@code MockInBean} to the source name of the generated classes,
     * presumably so that their names carry the marker {@link #isProxy(Object)} looks for.
     */
    private static final NamingPolicy ENHANCER_NAMING_POLICY = new DefaultNamingPolicy() {
        @Override
        public String getClassName(String prefix, String source, Object key, Predicate names) {
            return super.getClassName(prefix, source + "MockInBean", key, names);
        }
    };

    /**
     * @param o the object to inspect, may be {@code null}
     * @return true if the class name of {@code o} contains the MockInBean proxy marker,
     * false otherwise (including when {@code o} is {@code null}).
     */
    static boolean isProxy(Object o) {
        return o != null && o.getClass().getName().contains(PROXY_CLASS_NAME_MARKER);
    }

    private final Log logger = LogFactory.getLog(getClass());

    final MockTracker mockTracker = new MockTracker();

    final ProxyTracker proxyTracker = new ProxyTracker();

    /**
     * Returns the proxy to use for the provided object, creating and registering it if necessary.
     * @param beanOrProxy either an original bean or a proxy previously created by this tracker
     * @return {@code beanOrProxy} itself if it already is a proxy, otherwise the proxy registered
     * in the {@link ProxyTracker} for that bean (created on first call).
     */
    public Object setupProxyIfNotExisting(Object beanOrProxy) {
        if (isProxy(beanOrProxy)) {
            return beanOrProxy;
        }
        return proxyTracker.getByBeanOrMake(beanOrProxy, () -> {
            // Guarded so that the bean's toString() is only evaluated when the message is emitted.
            if (logger.isDebugEnabled()) {
                logger.debug("Creating proxy of bean " + beanOrProxy);
            }
            return makeProxy(beanOrProxy);
        });
    }

    /**
     * Generates a CGLIB subclass instance of the class of {@code originalBean} whose every
     * intercepted method is forwarded to the mock currently tracked for {@code originalBean},
     * or to {@code originalBean} itself when no mock is tracked.
     * <p>Exceptions thrown by the invoked target are rethrown as-is rather than wrapped in an
     * {@link InvocationTargetException}.
     * <p>NOTE(review): {@code Enhancer.create()} is called without constructor arguments, which
     * presumably requires the bean class to be instantiable that way - confirm.
     * @param originalBean the bean to proxy
     * @return the new proxy
     */
    private Object makeProxy(final Object originalBean) {
        final Enhancer enhancer = new Enhancer();
        enhancer.setSuperclass(originalBean.getClass());
        enhancer.setCallback(new MethodInterceptor() {

            @Override
            public Object intercept(Object object, Method method, Object[] parameters, MethodProxy methodProxy)
                    throws Throwable {
                final Object target = resolveInvocationTarget(originalBean);
                try {
                    return method.invoke(target, parameters);
                } catch (InvocationTargetException e) {
                    // Reflection wraps whatever the target threw: surface the real exception so that
                    // callers of the proxy observe the same failure as when calling the target directly.
                    final Throwable cause = e.getCause();
                    throw cause != null ? cause : e;
                }
            }

            /** Picks the mock tracked for {@code bean} if there is one, {@code bean} otherwise. */
            private Object resolveInvocationTarget(final Object bean) {
                final Optional<Object> trackedMock = mockTracker.getTracked(bean);
                if (trackedMock.isPresent()) {
                    // This runs on every proxied call: only build the message when it is needed.
                    if (logger.isDebugEnabled()) {
                        logger.debug("Resolved mock from thread local for class " + bean);
                    }
                    return trackedMock.get();
                }
                if (logger.isDebugEnabled()) {
                    logger.debug("No mock in thread local, using original " + bean);
                }
                return bean;
            }
        });
        enhancer.setNamingPolicy(ENHANCER_NAMING_POLICY);
        return enhancer.create();
    }
}