@@ -12,14 +12,11 @@ import org.utbot.framework.codegen.domain.ProjectType
1212import org.utbot.framework.codegen.domain.StaticsMocking
1313import org.utbot.framework.codegen.domain.TestFramework
1414import org.utbot.framework.codegen.generator.CodeGenerator
15+ import org.utbot.framework.codegen.generator.SpringCodeGenerator
1516import org.utbot.framework.codegen.services.language.CgLanguageAssistant
1617import org.utbot.framework.codegen.tree.ututils.UtilClassKind
1718import org.utbot.framework.codegen.tree.ututils.UtilClassKind.Companion.UT_UTILS_INSTANCE_NAME
18- import org.utbot.framework.plugin.api.CodegenLanguage
19- import org.utbot.framework.plugin.api.ExecutableId
20- import org.utbot.framework.plugin.api.MockFramework
21- import org.utbot.framework.plugin.api.MockStrategyApi
22- import org.utbot.framework.plugin.api.UtMethodTestSet
19+ import org.utbot.framework.plugin.api.*
2320import org.utbot.framework.plugin.api.util.UtContext
2421import org.utbot.framework.plugin.api.util.description
2522import org.utbot.framework.plugin.api.util.id
@@ -263,21 +260,46 @@ class TestCodeGeneratorPipeline(private val testInfrastructureConfiguration: Tes
263260
264261 withUtContext(UtContext (classUnderTest.java.classLoader)) {
265262 val codeGenerator = with (testInfrastructureConfiguration) {
266- CodeGenerator (
267- classUnderTest.id,
268- projectType = ProjectType .PureJvm ,
269- generateUtilClassFile = generateUtilClassFile,
270- paramNames = params,
271- testFramework = testFramework,
272- staticsMocking = staticsMocking,
273- forceStaticMocking = forceStaticMocking,
274- generateWarningsForStaticMocking = false ,
275- codegenLanguage = codegenLanguage,
276- cgLanguageAssistant = CgLanguageAssistant .getByCodegenLanguage(codegenLanguage),
277- parameterizedTestSource = parametrizedTestSource,
278- runtimeExceptionTestsBehaviour = runtimeExceptionTestsBehaviour,
279- enableTestsTimeout = enableTestsTimeout
280- )
263+ when (projectType) {
264+ ProjectType .Spring -> SpringCodeGenerator (
265+ classUnderTest.id,
266+ projectType = ProjectType .Spring ,
267+ generateUtilClassFile = generateUtilClassFile,
268+ paramNames = params,
269+ testFramework = testFramework,
270+ staticsMocking = staticsMocking,
271+ forceStaticMocking = forceStaticMocking,
272+ generateWarningsForStaticMocking = false ,
273+ codegenLanguage = codegenLanguage,
274+ cgLanguageAssistant = CgLanguageAssistant .getByCodegenLanguage(codegenLanguage),
275+ parameterizedTestSource = parametrizedTestSource,
276+ runtimeExceptionTestsBehaviour = runtimeExceptionTestsBehaviour,
277+ enableTestsTimeout = enableTestsTimeout,
278+ codeGenerationContext = SpringApplicationContext (
279+ mockInstalled = true ,
280+ staticsMockingIsConfigured = true ,
281+ shouldUseImplementors = false ,
282+ springTestType = SpringTestType .UNIT_TEST ,
283+ springSettings = SpringSettings .AbsentSpringSettings (),
284+ )
285+ )
286+ ProjectType .PureJvm -> CodeGenerator (
287+ classUnderTest.id,
288+ projectType = ProjectType .PureJvm ,
289+ generateUtilClassFile = generateUtilClassFile,
290+ paramNames = params,
291+ testFramework = testFramework,
292+ staticsMocking = staticsMocking,
293+ forceStaticMocking = forceStaticMocking,
294+ generateWarningsForStaticMocking = false ,
295+ codegenLanguage = codegenLanguage,
296+ cgLanguageAssistant = CgLanguageAssistant .getByCodegenLanguage(codegenLanguage),
297+ parameterizedTestSource = parametrizedTestSource,
298+ runtimeExceptionTestsBehaviour = runtimeExceptionTestsBehaviour,
299+ enableTestsTimeout = enableTestsTimeout
300+ )
301+ else -> error(" Unsupported project type $projectType in code generator instantiation" )
302+ }
281303 }
282304 val testClassCustomName = " ${classUnderTest.java.simpleName} GeneratedTest"
283305
@@ -327,7 +349,10 @@ class TestCodeGeneratorPipeline(private val testInfrastructureConfiguration: Tes
327349 testFramework = TestFramework .defaultItem,
328350 codegenLanguage = configuration.language,
329351 mockFramework = MockFramework .defaultItem,
330- mockStrategy = MockStrategyApi .defaultItem,
352+ mockStrategy = when (configuration.projectType) {
353+ ProjectType .Spring -> MockStrategyApi .springDefaultItem
354+ else -> MockStrategyApi .defaultItem
355+ },
331356 staticsMocking = StaticsMocking .defaultItem,
332357 parametrizedTestSource = configuration.parametrizedTestSource,
333358 forceStaticMocking = ForceStaticMocking .defaultItem,
0 commit comments