Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix optional relation type #37

Merged
merged 3 commits into from
Nov 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/example-prj/src/__generated__/fabbrica/index.js

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

64 changes: 36 additions & 28 deletions packages/prisma-fabbrica/src/templates/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ import { template } from "talt";

type StripCreate<T extends string> = T extends `create${infer S}` ? Uncapitalize<S> : T;

function byName<T extends { readonly name: string }>(name: string | { readonly name: string }) {
return (x: T) => x.name === (typeof name === "string" ? name : name.name);
}

function camelize(pascal: string) {
return pascal[0].toLowerCase() + pascal.slice(1);
}
Expand Down Expand Up @@ -56,6 +60,12 @@ function filterRequiredInputObjectTypeField(inputType: DMMF.InputType) {
return filterRequiredFields(inputType).filter(isInputObjectTypeField);
}

function filterBelongsToField(model: DMMF.Model, inputType: DMMF.InputType) {
return inputType.fields
.filter(isInputObjectTypeField)
.filter(field => model.fields.find(byName(field))?.isList === false);
}

function filterEnumFields(inputType: DMMF.InputType) {
return inputType.fields.filter(
field =>
Expand All @@ -65,7 +75,7 @@ function filterEnumFields(inputType: DMMF.InputType) {

function extractFirstEnumValue(enums: DMMF.SchemaEnum[], field: DMMF.SchemaArg) {
const typeName = field.inputTypes[0].type;
const found = enums.find(e => e.name === field.inputTypes[0].type);
const found = enums.find(byName(typeName));
if (!found) {
throw new Error(`Not found enum ${typeName}`);
}
Expand All @@ -88,7 +98,7 @@ export const importStatement = (specifier: string, prismaClientModuleSpecifier:
`();

export const scalarFieldType = (
modelName: string,
model: DMMF.Model,
fieldName: string,
inputType: DMMF.SchemaArgInputType,
): ts.TypeNode => {
Expand Down Expand Up @@ -116,14 +126,14 @@ export const scalarFieldType = (
// return template.typeNode`Prisma.Json`();
return ast.keywordTypeNode(ts.SyntaxKind.AnyKeyword);
default:
throw new Error(`Unknown scalar type "${inputType.type}" for ${modelName}.${fieldName} .`);
throw new Error(`Unknown scalar type "${inputType.type}" for ${model.name}.${fieldName} .`);
}
};

export const argInputType = (modelName: string, fieldName: string, inputType: DMMF.SchemaArgInputType): ts.TypeNode => {
export const argInputType = (model: DMMF.Model, fieldName: string, inputType: DMMF.SchemaArgInputType): ts.TypeNode => {
const fieldType = () => {
if (inputType.location === "scalar") {
return scalarFieldType(modelName, fieldName, inputType);
return scalarFieldType(model, fieldName, inputType);
} else if (inputType.location === "enumTypes") {
return ast.typeReferenceNode(ast.identifier(inputType.type as string));
} else if (inputType.location === "outputObjectTypes" || inputType.location === "inputObjectTypes") {
Expand All @@ -140,7 +150,7 @@ export const argInputType = (modelName: string, fieldName: string, inputType: DM
: fieldType();
};

export const modelScalarOrEnumFields = (modelName: string, inputType: DMMF.InputType) =>
export const modelScalarOrEnumFields = (model: DMMF.Model, inputType: DMMF.InputType) =>
template.statement<ts.TypeAliasDeclaration>`
type MODEL_SCALAR_OR_ENUM_FIELDS = ${() =>
ast.typeLiteralNode(
Expand All @@ -149,18 +159,16 @@ export const modelScalarOrEnumFields = (modelName: string, inputType: DMMF.Input
undefined,
field.name,
undefined,
ast.unionTypeNode(
field.inputTypes.map(childInputType => argInputType(modelName, field.name, childInputType)),
),
ast.unionTypeNode(field.inputTypes.map(childInputType => argInputType(model, field.name, childInputType))),
),
),
)}
`({
MODEL_SCALAR_OR_ENUM_FIELDS: ast.identifier(`${modelName}ScalarOrEnumFields`),
MODEL_SCALAR_OR_ENUM_FIELDS: ast.identifier(`${model.name}ScalarOrEnumFields`),
});

export const modelBelongsToRelationFactory = (fieldType: DMMF.SchemaArg, model: DMMF.Model) => {
const targetModel = model.fields.find(f => f.name === fieldType.name)!;
const targetModel = model.fields.find(byName(fieldType))!;
return template.statement<ts.TypeAliasDeclaration>`
type ${() => ast.identifier(`${model.name}${fieldType.name}Factory`)} = {
_factoryFor: ${() => ast.literalTypeNode(ast.stringLiteral(targetModel.type))};
Expand All @@ -170,7 +178,7 @@ export const modelBelongsToRelationFactory = (fieldType: DMMF.SchemaArg, model:
`();
};

export const modelFactoryDefineInput = (modelName: string, inputType: DMMF.InputType) =>
export const modelFactoryDefineInput = (model: DMMF.Model, inputType: DMMF.InputType) =>
template.statement<ts.TypeAliasDeclaration>`
type MODEL_FACTORY_DEFINE_INPUT = ${() =>
ast.typeLiteralNode(
Expand All @@ -180,16 +188,17 @@ export const modelFactoryDefineInput = (modelName: string, inputType: DMMF.Input
field.name,
!field.isRequired || isScalarOrEnumField(field) ? ast.token(ts.SyntaxKind.QuestionToken) : undefined,
ast.unionTypeNode([
...(field.isRequired && isInputObjectTypeField(field)
? [ast.typeReferenceNode(ast.identifier(`${modelName}${field.name}Factory`))]
...((field.isRequired || model.fields.find(byName(field))!.isList === false) &&
isInputObjectTypeField(field)
? [ast.typeReferenceNode(ast.identifier(`${model.name}${field.name}Factory`))]
: []),
...field.inputTypes.map(childInputType => argInputType(modelName, field.name, childInputType)),
...field.inputTypes.map(childInputType => argInputType(model, field.name, childInputType)),
]),
),
),
)};
`({
MODEL_FACTORY_DEFINE_INPUT: ast.identifier(`${modelName}FactoryDefineInput`),
MODEL_FACTORY_DEFINE_INPUT: ast.identifier(`${model.name}FactoryDefineInput`),
});

export const modelFactoryDefineOptions = (modelName: string, isOpionalDefaultData: boolean) =>
Expand All @@ -212,12 +221,13 @@ export const modelFactoryDefineOptions = (modelName: string, isOpionalDefaultDat
});

export const isModelAssociationFactory = (fieldType: DMMF.SchemaArg, model: DMMF.Model) => {
const targetModel = model.fields.find(f => f.name === fieldType.name)!;
const targetModel = model.fields.find(byName(fieldType))!;
return template.statement<ts.FunctionDeclaration>`
function ${() => ast.identifier(`is${model.name}${fieldType.name}Factory`)}(
x: MODEL_BELONGS_TO_RELATION_FACTORY | ${() => argInputType(model.name, fieldType.name, fieldType.inputTypes[0])}
x: MODEL_BELONGS_TO_RELATION_FACTORY | ${() =>
argInputType(model, fieldType.name, fieldType.inputTypes[0])} | undefined
): x is MODEL_BELONGS_TO_RELATION_FACTORY {
return (x as any)._factoryFor === ${() => ast.stringLiteral(targetModel.type)};
return (x as any)?._factoryFor === ${() => ast.stringLiteral(targetModel.type)};
}
`({
MODEL_BELONGS_TO_RELATION_FACTORY: ast.typeReferenceNode(`${model.name}${fieldType.name}Factory`),
Expand All @@ -237,10 +247,10 @@ export const autoGenerateModelScalarsOrEnumsFieldArgs = (
MODEL_NAME: ast.stringLiteral(model.name),
FIELD_NAME: ast.stringLiteral(field.name),
IS_ID:
model.fields.find(f => f.name === field.name)!.isId || model.primaryKey?.fields.includes(field.name)
model.fields.find(byName(field))!.isId || model.primaryKey?.fields.includes(field.name)
? ast.true()
: ast.false(),
IS_UNIQUE: model.fields.find(f => f.name === field.name)!.isUnique ? ast.true() : ast.false(),
IS_UNIQUE: model.fields.find(byName(field))!.isUnique ? ast.true() : ast.false(),
})
: ast.stringLiteral(extractFirstEnumValue(enums, field));

Expand Down Expand Up @@ -279,7 +289,7 @@ export const defineModelFactoryInernal = (model: DMMF.Model, inputType: DMMF.Inp
const defaultData= await resolveValue(defaultDataResolver ?? {});
const defaultAssociations = ${() =>
ast.objectLiteralExpression(
filterRequiredInputObjectTypeField(inputType).map(field =>
filterBelongsToField(model, inputType).map(field =>
ast.propertyAssignment(
field.name,
template.expression`
Expand Down Expand Up @@ -373,15 +383,13 @@ export function getSourceFile({
...document.datamodel.models
.map(model => ({ model, createInputType: findPrsimaCreateInputTypeFromModelName(document, model.name) }))
.flatMap(({ model, createInputType }) => [
modelScalarOrEnumFields(model.name, createInputType),
...filterRequiredInputObjectTypeField(createInputType).map(fieldType =>
modelScalarOrEnumFields(model, createInputType),
...filterBelongsToField(model, createInputType).map(fieldType =>
modelBelongsToRelationFactory(fieldType, model),
),
modelFactoryDefineInput(model.name, createInputType),
modelFactoryDefineInput(model, createInputType),
modelFactoryDefineOptions(model.name, filterRequiredInputObjectTypeField(createInputType).length === 0),
...filterRequiredInputObjectTypeField(createInputType).map(fieldType =>
isModelAssociationFactory(fieldType, model),
),
...filterBelongsToField(model, createInputType).map(fieldType => isModelAssociationFactory(fieldType, model)),
autoGenerateModelScalarsOrEnums(model, createInputType, document.schema.enumTypes.model ?? []),
defineModelFactoryInernal(model, createInputType),
defineModelFactory(model.name, createInputType),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ describe(modelScalarOrEnumFields, () => {
});
const inputType = findPrsimaCreateInputTypeFromModelName(dmmf, "TestModel");
const source = template.statement(expected)();
expect(printNode(modelScalarOrEnumFields("TestModel", inputType))).toBe(printNode(source).trim());
expect(printNode(modelScalarOrEnumFields(dmmf.datamodel.models[0], inputType))).toBe(printNode(source).trim());
});

it("does not generate for nullable field", async () => {
Expand All @@ -70,6 +70,6 @@ describe(modelScalarOrEnumFields, () => {
id: number;
}
`();
expect(printNode(modelScalarOrEnumFields("TestModel", inputType))).toBe(printNode(expected).trim());
expect(printNode(modelScalarOrEnumFields(dmmf.datamodel.models[0], inputType))).toBe(printNode(expected).trim());
});
});

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,8 @@ export const ReviewFactory = defineReviewFactory({
reviewer: UserFactory,
},
});
export const PostFactoryAlt = definePostFactory({
defaultData: {
author: UserFactory,
},
});

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.