diff --git a/Dockerfile b/Dockerfile index 7befb46498a..d556008b795 100644 --- a/Dockerfile +++ b/Dockerfile @@ -20,8 +20,8 @@ FROM ${NODE_IMAGE} AS frontend-builder WORKDIR /app/frontend -# Install pnpm -RUN corepack enable && corepack prepare pnpm@latest --activate +# Install pnpm (pinned to v9 to match CI and keep builds reproducible) +RUN corepack enable && corepack prepare pnpm@9 --activate # Install dependencies first (better caching) COPY frontend/package.json frontend/pnpm-lock.yaml ./ diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 9e7e837e953..74799d81f39 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.126 +0.1.127 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index a550118139e..5a190c3377b 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -81,7 +81,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { } totpCache := repository.NewTotpCache(redisClient) totpService := service.NewTotpService(userRepository, secretEncryptor, totpCache, settingService, emailService, emailQueueService) - authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService) + userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client) + userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) + userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) + authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService, redeemService, totpService, userAttributeService) userHandler := handler.NewUserHandler(userService, authService, emailService, emailCache, affiliateService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) @@ -198,7 +201,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { registry := payment.ProvideRegistry() defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository, affiliateService) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) + settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService, userAttributeService) opsHandler := admin.NewOpsHandler(opsService) updateCache := repository.NewUpdateCache(redisClient) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) @@ -211,9 +214,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { usageCleanupRepository := repository.NewUsageCleanupRepository(client, db) usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig) adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService, usageCleanupService) - userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client) - userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) - userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) errorPassthroughRepository := repository.NewErrorPassthroughRepository(client) errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient) diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 525ff0927c1..f98761ea295 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -1120,6 +1120,7 @@ var ( {Name: "used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "notes", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "expires_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "validity_days", Type: field.TypeInt, Default: 30}, {Name: "group_id", Type: field.TypeInt64, Nullable: true}, {Name: "used_by", Type: field.TypeInt64, Nullable: true}, @@ -1132,13 +1133,13 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "redeem_codes_groups_redeem_codes", - Columns: []*schema.Column{RedeemCodesColumns[9]}, + Columns: []*schema.Column{RedeemCodesColumns[10]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "redeem_codes_users_redeem_codes", - Columns: []*schema.Column{RedeemCodesColumns[10]}, + Columns: []*schema.Column{RedeemCodesColumns[11]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.SetNull, }, @@ -1152,12 +1153,17 @@ var ( { Name: "redeemcode_used_by", Unique: false, - Columns: []*schema.Column{RedeemCodesColumns[10]}, + Columns: []*schema.Column{RedeemCodesColumns[11]}, }, { Name: "redeemcode_group_id", Unique: false, - Columns: []*schema.Column{RedeemCodesColumns[9]}, + Columns: []*schema.Column{RedeemCodesColumns[10]}, + }, + { + Name: "redeemcode_expires_at", + Unique: false, + Columns: []*schema.Column{RedeemCodesColumns[8]}, }, }, } @@ -1318,6 +1324,10 @@ var ( {Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45}, {Name: "image_count", Type: field.TypeInt, Default: 0}, {Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10}, + {Name: "image_input_size", Type: field.TypeString, Nullable: true, Size: 32}, + {Name: "image_output_size", Type: field.TypeString, Nullable: true, Size: 32}, + {Name: "image_size_source", Type: field.TypeString, Nullable: true, Size: 16}, + {Name: "image_size_breakdown", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, {Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "api_key_id", Type: field.TypeInt64}, @@ -1334,31 +1344,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[33]}, + Columns: []*schema.Column{UsageLogsColumns[37]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[34]}, + Columns: []*schema.Column{UsageLogsColumns[38]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[35]}, + Columns: []*schema.Column{UsageLogsColumns[39]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[36]}, + Columns: []*schema.Column{UsageLogsColumns[40]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[37]}, + Columns: []*schema.Column{UsageLogsColumns[41]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -1367,32 +1377,32 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[36]}, + Columns: []*schema.Column{UsageLogsColumns[40]}, }, { Name: "usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[33]}, + Columns: []*schema.Column{UsageLogsColumns[37]}, }, { Name: "usagelog_account_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[34]}, + Columns: []*schema.Column{UsageLogsColumns[38]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[35]}, + Columns: []*schema.Column{UsageLogsColumns[39]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[37]}, + Columns: []*schema.Column{UsageLogsColumns[41]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[36]}, }, { Name: "usagelog_model", @@ -1412,17 +1422,17 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[36], UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[40], UsageLogsColumns[36]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[33], UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[37], UsageLogsColumns[36]}, }, { Name: "usagelog_group_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[35], UsageLogsColumns[32]}, + Columns: []*schema.Column{UsageLogsColumns[39], UsageLogsColumns[36]}, }, }, } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 13f6193d89a..45c56314e84 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -28602,6 +28602,7 @@ type RedeemCodeMutation struct { used_at *time.Time notes *string created_at *time.Time + expires_at *time.Time validity_days *int addvalidity_days *int clearedFields map[string]struct{} @@ -29059,6 +29060,55 @@ func (m *RedeemCodeMutation) ResetCreatedAt() { m.created_at = nil } +// SetExpiresAt sets the "expires_at" field. +func (m *RedeemCodeMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *RedeemCodeMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at + if v == nil { + return + } + return *v, true +} + +// OldExpiresAt returns the old "expires_at" field's value of the RedeemCode entity. +// If the RedeemCode object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *RedeemCodeMutation) OldExpiresAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + } + return oldValue.ExpiresAt, nil +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (m *RedeemCodeMutation) ClearExpiresAt() { + m.expires_at = nil + m.clearedFields[redeemcode.FieldExpiresAt] = struct{}{} +} + +// ExpiresAtCleared returns if the "expires_at" field was cleared in this mutation. +func (m *RedeemCodeMutation) ExpiresAtCleared() bool { + _, ok := m.clearedFields[redeemcode.FieldExpiresAt] + return ok +} + +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *RedeemCodeMutation) ResetExpiresAt() { + m.expires_at = nil + delete(m.clearedFields, redeemcode.FieldExpiresAt) +} + // SetGroupID sets the "group_id" field. func (m *RedeemCodeMutation) SetGroupID(i int64) { m.group = &i @@ -29265,7 +29315,7 @@ func (m *RedeemCodeMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *RedeemCodeMutation) Fields() []string { - fields := make([]string, 0, 10) + fields := make([]string, 0, 11) if m.code != nil { fields = append(fields, redeemcode.FieldCode) } @@ -29290,6 +29340,9 @@ func (m *RedeemCodeMutation) Fields() []string { if m.created_at != nil { fields = append(fields, redeemcode.FieldCreatedAt) } + if m.expires_at != nil { + fields = append(fields, redeemcode.FieldExpiresAt) + } if m.group != nil { fields = append(fields, redeemcode.FieldGroupID) } @@ -29320,6 +29373,8 @@ func (m *RedeemCodeMutation) Field(name string) (ent.Value, bool) { return m.Notes() case redeemcode.FieldCreatedAt: return m.CreatedAt() + case redeemcode.FieldExpiresAt: + return m.ExpiresAt() case redeemcode.FieldGroupID: return m.GroupID() case redeemcode.FieldValidityDays: @@ -29349,6 +29404,8 @@ func (m *RedeemCodeMutation) OldField(ctx context.Context, name string) (ent.Val return m.OldNotes(ctx) case redeemcode.FieldCreatedAt: return m.OldCreatedAt(ctx) + case redeemcode.FieldExpiresAt: + return m.OldExpiresAt(ctx) case redeemcode.FieldGroupID: return m.OldGroupID(ctx) case redeemcode.FieldValidityDays: @@ -29418,6 +29475,13 @@ func (m *RedeemCodeMutation) SetField(name string, value ent.Value) error { } m.SetCreatedAt(v) return nil + case redeemcode.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil case redeemcode.FieldGroupID: v, ok := value.(int64) if !ok { @@ -29498,6 +29562,9 @@ func (m *RedeemCodeMutation) ClearedFields() []string { if m.FieldCleared(redeemcode.FieldNotes) { fields = append(fields, redeemcode.FieldNotes) } + if m.FieldCleared(redeemcode.FieldExpiresAt) { + fields = append(fields, redeemcode.FieldExpiresAt) + } if m.FieldCleared(redeemcode.FieldGroupID) { fields = append(fields, redeemcode.FieldGroupID) } @@ -29524,6 +29591,9 @@ func (m *RedeemCodeMutation) ClearField(name string) error { case redeemcode.FieldNotes: m.ClearNotes() return nil + case redeemcode.FieldExpiresAt: + m.ClearExpiresAt() + return nil case redeemcode.FieldGroupID: m.ClearGroupID() return nil @@ -29559,6 +29629,9 @@ func (m *RedeemCodeMutation) ResetField(name string) error { case redeemcode.FieldCreatedAt: m.ResetCreatedAt() return nil + case redeemcode.FieldExpiresAt: + m.ResetExpiresAt() + return nil case redeemcode.FieldGroupID: m.ResetGroupID() return nil @@ -34260,6 +34333,10 @@ type UsageLogMutation struct { image_count *int addimage_count *int image_size *string + image_input_size *string + image_output_size *string + image_size_source *string + image_size_breakdown *map[string]int cache_ttl_overridden *bool created_at *time.Time clearedFields map[string]struct{} @@ -36202,6 +36279,202 @@ func (m *UsageLogMutation) ResetImageSize() { delete(m.clearedFields, usagelog.FieldImageSize) } +// SetImageInputSize sets the "image_input_size" field. +func (m *UsageLogMutation) SetImageInputSize(s string) { + m.image_input_size = &s +} + +// ImageInputSize returns the value of the "image_input_size" field in the mutation. +func (m *UsageLogMutation) ImageInputSize() (r string, exists bool) { + v := m.image_input_size + if v == nil { + return + } + return *v, true +} + +// OldImageInputSize returns the old "image_input_size" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldImageInputSize(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldImageInputSize is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldImageInputSize requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldImageInputSize: %w", err) + } + return oldValue.ImageInputSize, nil +} + +// ClearImageInputSize clears the value of the "image_input_size" field. +func (m *UsageLogMutation) ClearImageInputSize() { + m.image_input_size = nil + m.clearedFields[usagelog.FieldImageInputSize] = struct{}{} +} + +// ImageInputSizeCleared returns if the "image_input_size" field was cleared in this mutation. +func (m *UsageLogMutation) ImageInputSizeCleared() bool { + _, ok := m.clearedFields[usagelog.FieldImageInputSize] + return ok +} + +// ResetImageInputSize resets all changes to the "image_input_size" field. +func (m *UsageLogMutation) ResetImageInputSize() { + m.image_input_size = nil + delete(m.clearedFields, usagelog.FieldImageInputSize) +} + +// SetImageOutputSize sets the "image_output_size" field. +func (m *UsageLogMutation) SetImageOutputSize(s string) { + m.image_output_size = &s +} + +// ImageOutputSize returns the value of the "image_output_size" field in the mutation. +func (m *UsageLogMutation) ImageOutputSize() (r string, exists bool) { + v := m.image_output_size + if v == nil { + return + } + return *v, true +} + +// OldImageOutputSize returns the old "image_output_size" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldImageOutputSize(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldImageOutputSize is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldImageOutputSize requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldImageOutputSize: %w", err) + } + return oldValue.ImageOutputSize, nil +} + +// ClearImageOutputSize clears the value of the "image_output_size" field. +func (m *UsageLogMutation) ClearImageOutputSize() { + m.image_output_size = nil + m.clearedFields[usagelog.FieldImageOutputSize] = struct{}{} +} + +// ImageOutputSizeCleared returns if the "image_output_size" field was cleared in this mutation. +func (m *UsageLogMutation) ImageOutputSizeCleared() bool { + _, ok := m.clearedFields[usagelog.FieldImageOutputSize] + return ok +} + +// ResetImageOutputSize resets all changes to the "image_output_size" field. +func (m *UsageLogMutation) ResetImageOutputSize() { + m.image_output_size = nil + delete(m.clearedFields, usagelog.FieldImageOutputSize) +} + +// SetImageSizeSource sets the "image_size_source" field. +func (m *UsageLogMutation) SetImageSizeSource(s string) { + m.image_size_source = &s +} + +// ImageSizeSource returns the value of the "image_size_source" field in the mutation. +func (m *UsageLogMutation) ImageSizeSource() (r string, exists bool) { + v := m.image_size_source + if v == nil { + return + } + return *v, true +} + +// OldImageSizeSource returns the old "image_size_source" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldImageSizeSource(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldImageSizeSource is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldImageSizeSource requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldImageSizeSource: %w", err) + } + return oldValue.ImageSizeSource, nil +} + +// ClearImageSizeSource clears the value of the "image_size_source" field. +func (m *UsageLogMutation) ClearImageSizeSource() { + m.image_size_source = nil + m.clearedFields[usagelog.FieldImageSizeSource] = struct{}{} +} + +// ImageSizeSourceCleared returns if the "image_size_source" field was cleared in this mutation. +func (m *UsageLogMutation) ImageSizeSourceCleared() bool { + _, ok := m.clearedFields[usagelog.FieldImageSizeSource] + return ok +} + +// ResetImageSizeSource resets all changes to the "image_size_source" field. +func (m *UsageLogMutation) ResetImageSizeSource() { + m.image_size_source = nil + delete(m.clearedFields, usagelog.FieldImageSizeSource) +} + +// SetImageSizeBreakdown sets the "image_size_breakdown" field. +func (m *UsageLogMutation) SetImageSizeBreakdown(value map[string]int) { + m.image_size_breakdown = &value +} + +// ImageSizeBreakdown returns the value of the "image_size_breakdown" field in the mutation. +func (m *UsageLogMutation) ImageSizeBreakdown() (r map[string]int, exists bool) { + v := m.image_size_breakdown + if v == nil { + return + } + return *v, true +} + +// OldImageSizeBreakdown returns the old "image_size_breakdown" field's value of the UsageLog entity. +// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldImageSizeBreakdown(ctx context.Context) (v map[string]int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldImageSizeBreakdown is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldImageSizeBreakdown requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldImageSizeBreakdown: %w", err) + } + return oldValue.ImageSizeBreakdown, nil +} + +// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field. +func (m *UsageLogMutation) ClearImageSizeBreakdown() { + m.image_size_breakdown = nil + m.clearedFields[usagelog.FieldImageSizeBreakdown] = struct{}{} +} + +// ImageSizeBreakdownCleared returns if the "image_size_breakdown" field was cleared in this mutation. +func (m *UsageLogMutation) ImageSizeBreakdownCleared() bool { + _, ok := m.clearedFields[usagelog.FieldImageSizeBreakdown] + return ok +} + +// ResetImageSizeBreakdown resets all changes to the "image_size_breakdown" field. +func (m *UsageLogMutation) ResetImageSizeBreakdown() { + m.image_size_breakdown = nil + delete(m.clearedFields, usagelog.FieldImageSizeBreakdown) +} + // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) { m.cache_ttl_overridden = &b @@ -36443,7 +36716,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 37) + fields := make([]string, 0, 41) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -36549,6 +36822,18 @@ func (m *UsageLogMutation) Fields() []string { if m.image_size != nil { fields = append(fields, usagelog.FieldImageSize) } + if m.image_input_size != nil { + fields = append(fields, usagelog.FieldImageInputSize) + } + if m.image_output_size != nil { + fields = append(fields, usagelog.FieldImageOutputSize) + } + if m.image_size_source != nil { + fields = append(fields, usagelog.FieldImageSizeSource) + } + if m.image_size_breakdown != nil { + fields = append(fields, usagelog.FieldImageSizeBreakdown) + } if m.cache_ttl_overridden != nil { fields = append(fields, usagelog.FieldCacheTTLOverridden) } @@ -36633,6 +36918,14 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.ImageCount() case usagelog.FieldImageSize: return m.ImageSize() + case usagelog.FieldImageInputSize: + return m.ImageInputSize() + case usagelog.FieldImageOutputSize: + return m.ImageOutputSize() + case usagelog.FieldImageSizeSource: + return m.ImageSizeSource() + case usagelog.FieldImageSizeBreakdown: + return m.ImageSizeBreakdown() case usagelog.FieldCacheTTLOverridden: return m.CacheTTLOverridden() case usagelog.FieldCreatedAt: @@ -36716,6 +37009,14 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldImageCount(ctx) case usagelog.FieldImageSize: return m.OldImageSize(ctx) + case usagelog.FieldImageInputSize: + return m.OldImageInputSize(ctx) + case usagelog.FieldImageOutputSize: + return m.OldImageOutputSize(ctx) + case usagelog.FieldImageSizeSource: + return m.OldImageSizeSource(ctx) + case usagelog.FieldImageSizeBreakdown: + return m.OldImageSizeBreakdown(ctx) case usagelog.FieldCacheTTLOverridden: return m.OldCacheTTLOverridden(ctx) case usagelog.FieldCreatedAt: @@ -36974,6 +37275,34 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetImageSize(v) return nil + case usagelog.FieldImageInputSize: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImageInputSize(v) + return nil + case usagelog.FieldImageOutputSize: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImageOutputSize(v) + return nil + case usagelog.FieldImageSizeSource: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImageSizeSource(v) + return nil + case usagelog.FieldImageSizeBreakdown: + v, ok := value.(map[string]int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetImageSizeBreakdown(v) + return nil case usagelog.FieldCacheTTLOverridden: v, ok := value.(bool) if !ok { @@ -37291,6 +37620,18 @@ func (m *UsageLogMutation) ClearedFields() []string { if m.FieldCleared(usagelog.FieldImageSize) { fields = append(fields, usagelog.FieldImageSize) } + if m.FieldCleared(usagelog.FieldImageInputSize) { + fields = append(fields, usagelog.FieldImageInputSize) + } + if m.FieldCleared(usagelog.FieldImageOutputSize) { + fields = append(fields, usagelog.FieldImageOutputSize) + } + if m.FieldCleared(usagelog.FieldImageSizeSource) { + fields = append(fields, usagelog.FieldImageSizeSource) + } + if m.FieldCleared(usagelog.FieldImageSizeBreakdown) { + fields = append(fields, usagelog.FieldImageSizeBreakdown) + } return fields } @@ -37347,6 +37688,18 @@ func (m *UsageLogMutation) ClearField(name string) error { case usagelog.FieldImageSize: m.ClearImageSize() return nil + case usagelog.FieldImageInputSize: + m.ClearImageInputSize() + return nil + case usagelog.FieldImageOutputSize: + m.ClearImageOutputSize() + return nil + case usagelog.FieldImageSizeSource: + m.ClearImageSizeSource() + return nil + case usagelog.FieldImageSizeBreakdown: + m.ClearImageSizeBreakdown() + return nil } return fmt.Errorf("unknown UsageLog nullable field %s", name) } @@ -37460,6 +37813,18 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldImageSize: m.ResetImageSize() return nil + case usagelog.FieldImageInputSize: + m.ResetImageInputSize() + return nil + case usagelog.FieldImageOutputSize: + m.ResetImageOutputSize() + return nil + case usagelog.FieldImageSizeSource: + m.ResetImageSizeSource() + return nil + case usagelog.FieldImageSizeBreakdown: + m.ResetImageSizeBreakdown() + return nil case usagelog.FieldCacheTTLOverridden: m.ResetCacheTTLOverridden() return nil diff --git a/backend/ent/redeemcode.go b/backend/ent/redeemcode.go index 24cd423164a..34b55f6be00 100644 --- a/backend/ent/redeemcode.go +++ b/backend/ent/redeemcode.go @@ -35,6 +35,8 @@ type RedeemCode struct { Notes *string `json:"notes,omitempty"` // CreatedAt holds the value of the "created_at" field. CreatedAt time.Time `json:"created_at,omitempty"` + // ExpiresAt holds the value of the "expires_at" field. + ExpiresAt *time.Time `json:"expires_at,omitempty"` // GroupID holds the value of the "group_id" field. GroupID *int64 `json:"group_id,omitempty"` // ValidityDays holds the value of the "validity_days" field. @@ -89,7 +91,7 @@ func (*RedeemCode) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullInt64) case redeemcode.FieldCode, redeemcode.FieldType, redeemcode.FieldStatus, redeemcode.FieldNotes: values[i] = new(sql.NullString) - case redeemcode.FieldUsedAt, redeemcode.FieldCreatedAt: + case redeemcode.FieldUsedAt, redeemcode.FieldCreatedAt, redeemcode.FieldExpiresAt: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -163,6 +165,13 @@ func (_m *RedeemCode) assignValues(columns []string, values []any) error { } else if value.Valid { _m.CreatedAt = value.Time } + case redeemcode.FieldExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expires_at", values[i]) + } else if value.Valid { + _m.ExpiresAt = new(time.Time) + *_m.ExpiresAt = value.Time + } case redeemcode.FieldGroupID: if value, ok := values[i].(*sql.NullInt64); !ok { return fmt.Errorf("unexpected type %T for field group_id", values[i]) @@ -252,6 +261,11 @@ func (_m *RedeemCode) String() string { builder.WriteString("created_at=") builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) builder.WriteString(", ") + if v := _m.ExpiresAt; v != nil { + builder.WriteString("expires_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") if v := _m.GroupID; v != nil { builder.WriteString("group_id=") builder.WriteString(fmt.Sprintf("%v", *v)) diff --git a/backend/ent/redeemcode/redeemcode.go b/backend/ent/redeemcode/redeemcode.go index b010476c76c..c7b30c15d19 100644 --- a/backend/ent/redeemcode/redeemcode.go +++ b/backend/ent/redeemcode/redeemcode.go @@ -30,6 +30,8 @@ const ( FieldNotes = "notes" // FieldCreatedAt holds the string denoting the created_at field in the database. FieldCreatedAt = "created_at" + // FieldExpiresAt holds the string denoting the expires_at field in the database. + FieldExpiresAt = "expires_at" // FieldGroupID holds the string denoting the group_id field in the database. FieldGroupID = "group_id" // FieldValidityDays holds the string denoting the validity_days field in the database. @@ -67,6 +69,7 @@ var Columns = []string{ FieldUsedAt, FieldNotes, FieldCreatedAt, + FieldExpiresAt, FieldGroupID, FieldValidityDays, } @@ -148,6 +151,11 @@ func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() } +// ByExpiresAt orders the results by the expires_at field. +func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() +} + // ByGroupID orders the results by the group_id field. func ByGroupID(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldGroupID, opts...).ToFunc() diff --git a/backend/ent/redeemcode/where.go b/backend/ent/redeemcode/where.go index 1fdedba572b..8325b9fc140 100644 --- a/backend/ent/redeemcode/where.go +++ b/backend/ent/redeemcode/where.go @@ -95,6 +95,11 @@ func CreatedAt(v time.Time) predicate.RedeemCode { return predicate.RedeemCode(sql.FieldEQ(FieldCreatedAt, v)) } +// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ. +func ExpiresAt(v time.Time) predicate.RedeemCode { + return predicate.RedeemCode(sql.FieldEQ(FieldExpiresAt, v)) +} + // GroupID applies equality check predicate on the "group_id" field. It's identical to GroupIDEQ. func GroupID(v int64) predicate.RedeemCode { return predicate.RedeemCode(sql.FieldEQ(FieldGroupID, v)) @@ -535,6 +540,56 @@ func CreatedAtLTE(v time.Time) predicate.RedeemCode { return predicate.RedeemCode(sql.FieldLTE(FieldCreatedAt, v)) } +// ExpiresAtEQ applies the EQ predicate on the "expires_at" field. +func ExpiresAtEQ(v time.Time) predicate.RedeemCode { + return predicate.RedeemCode(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field. +func ExpiresAtNEQ(v time.Time) predicate.RedeemCode { + return predicate.RedeemCode(sql.FieldNEQ(FieldExpiresAt, v)) +} + +// ExpiresAtIn applies the In predicate on the "expires_at" field. +func ExpiresAtIn(vs ...time.Time) predicate.RedeemCode { + return predicate.RedeemCode(sql.FieldIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field. +func ExpiresAtNotIn(vs ...time.Time) predicate.RedeemCode { + return predicate.RedeemCode(sql.FieldNotIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtGT applies the GT predicate on the "expires_at" field. +func ExpiresAtGT(v time.Time) predicate.RedeemCode { + return predicate.RedeemCode(sql.FieldGT(FieldExpiresAt, v)) +} + +// ExpiresAtGTE applies the GTE predicate on the "expires_at" field. +func ExpiresAtGTE(v time.Time) predicate.RedeemCode { + return predicate.RedeemCode(sql.FieldGTE(FieldExpiresAt, v)) +} + +// ExpiresAtLT applies the LT predicate on the "expires_at" field. +func ExpiresAtLT(v time.Time) predicate.RedeemCode { + return predicate.RedeemCode(sql.FieldLT(FieldExpiresAt, v)) +} + +// ExpiresAtLTE applies the LTE predicate on the "expires_at" field. +func ExpiresAtLTE(v time.Time) predicate.RedeemCode { + return predicate.RedeemCode(sql.FieldLTE(FieldExpiresAt, v)) +} + +// ExpiresAtIsNil applies the IsNil predicate on the "expires_at" field. +func ExpiresAtIsNil() predicate.RedeemCode { + return predicate.RedeemCode(sql.FieldIsNull(FieldExpiresAt)) +} + +// ExpiresAtNotNil applies the NotNil predicate on the "expires_at" field. +func ExpiresAtNotNil() predicate.RedeemCode { + return predicate.RedeemCode(sql.FieldNotNull(FieldExpiresAt)) +} + // GroupIDEQ applies the EQ predicate on the "group_id" field. func GroupIDEQ(v int64) predicate.RedeemCode { return predicate.RedeemCode(sql.FieldEQ(FieldGroupID, v)) diff --git a/backend/ent/redeemcode_create.go b/backend/ent/redeemcode_create.go index efdcee40b25..1bba027ba82 100644 --- a/backend/ent/redeemcode_create.go +++ b/backend/ent/redeemcode_create.go @@ -128,6 +128,20 @@ func (_c *RedeemCodeCreate) SetNillableCreatedAt(v *time.Time) *RedeemCodeCreate return _c } +// SetExpiresAt sets the "expires_at" field. +func (_c *RedeemCodeCreate) SetExpiresAt(v time.Time) *RedeemCodeCreate { + _c.mutation.SetExpiresAt(v) + return _c +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_c *RedeemCodeCreate) SetNillableExpiresAt(v *time.Time) *RedeemCodeCreate { + if v != nil { + _c.SetExpiresAt(*v) + } + return _c +} + // SetGroupID sets the "group_id" field. func (_c *RedeemCodeCreate) SetGroupID(v int64) *RedeemCodeCreate { _c.mutation.SetGroupID(v) @@ -327,6 +341,10 @@ func (_c *RedeemCodeCreate) createSpec() (*RedeemCode, *sqlgraph.CreateSpec) { _spec.SetField(redeemcode.FieldCreatedAt, field.TypeTime, value) _node.CreatedAt = value } + if value, ok := _c.mutation.ExpiresAt(); ok { + _spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value) + _node.ExpiresAt = &value + } if value, ok := _c.mutation.ValidityDays(); ok { _spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value) _node.ValidityDays = value @@ -525,6 +543,24 @@ func (u *RedeemCodeUpsert) ClearNotes() *RedeemCodeUpsert { return u } +// SetExpiresAt sets the "expires_at" field. +func (u *RedeemCodeUpsert) SetExpiresAt(v time.Time) *RedeemCodeUpsert { + u.Set(redeemcode.FieldExpiresAt, v) + return u +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *RedeemCodeUpsert) UpdateExpiresAt() *RedeemCodeUpsert { + u.SetExcluded(redeemcode.FieldExpiresAt) + return u +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (u *RedeemCodeUpsert) ClearExpiresAt() *RedeemCodeUpsert { + u.SetNull(redeemcode.FieldExpiresAt) + return u +} + // SetGroupID sets the "group_id" field. func (u *RedeemCodeUpsert) SetGroupID(v int64) *RedeemCodeUpsert { u.Set(redeemcode.FieldGroupID, v) @@ -732,6 +768,27 @@ func (u *RedeemCodeUpsertOne) ClearNotes() *RedeemCodeUpsertOne { }) } +// SetExpiresAt sets the "expires_at" field. +func (u *RedeemCodeUpsertOne) SetExpiresAt(v time.Time) *RedeemCodeUpsertOne { + return u.Update(func(s *RedeemCodeUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *RedeemCodeUpsertOne) UpdateExpiresAt() *RedeemCodeUpsertOne { + return u.Update(func(s *RedeemCodeUpsert) { + s.UpdateExpiresAt() + }) +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (u *RedeemCodeUpsertOne) ClearExpiresAt() *RedeemCodeUpsertOne { + return u.Update(func(s *RedeemCodeUpsert) { + s.ClearExpiresAt() + }) +} + // SetGroupID sets the "group_id" field. func (u *RedeemCodeUpsertOne) SetGroupID(v int64) *RedeemCodeUpsertOne { return u.Update(func(s *RedeemCodeUpsert) { @@ -1111,6 +1168,27 @@ func (u *RedeemCodeUpsertBulk) ClearNotes() *RedeemCodeUpsertBulk { }) } +// SetExpiresAt sets the "expires_at" field. +func (u *RedeemCodeUpsertBulk) SetExpiresAt(v time.Time) *RedeemCodeUpsertBulk { + return u.Update(func(s *RedeemCodeUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *RedeemCodeUpsertBulk) UpdateExpiresAt() *RedeemCodeUpsertBulk { + return u.Update(func(s *RedeemCodeUpsert) { + s.UpdateExpiresAt() + }) +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (u *RedeemCodeUpsertBulk) ClearExpiresAt() *RedeemCodeUpsertBulk { + return u.Update(func(s *RedeemCodeUpsert) { + s.ClearExpiresAt() + }) +} + // SetGroupID sets the "group_id" field. func (u *RedeemCodeUpsertBulk) SetGroupID(v int64) *RedeemCodeUpsertBulk { return u.Update(func(s *RedeemCodeUpsert) { diff --git a/backend/ent/redeemcode_update.go b/backend/ent/redeemcode_update.go index 0f05e06dc23..1e0ec1e681f 100644 --- a/backend/ent/redeemcode_update.go +++ b/backend/ent/redeemcode_update.go @@ -153,6 +153,26 @@ func (_u *RedeemCodeUpdate) ClearNotes() *RedeemCodeUpdate { return _u } +// SetExpiresAt sets the "expires_at" field. +func (_u *RedeemCodeUpdate) SetExpiresAt(v time.Time) *RedeemCodeUpdate { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *RedeemCodeUpdate) SetNillableExpiresAt(v *time.Time) *RedeemCodeUpdate { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (_u *RedeemCodeUpdate) ClearExpiresAt() *RedeemCodeUpdate { + _u.mutation.ClearExpiresAt() + return _u +} + // SetGroupID sets the "group_id" field. func (_u *RedeemCodeUpdate) SetGroupID(v int64) *RedeemCodeUpdate { _u.mutation.SetGroupID(v) @@ -321,6 +341,12 @@ func (_u *RedeemCodeUpdate) sqlSave(ctx context.Context) (_node int, err error) if _u.mutation.NotesCleared() { _spec.ClearField(redeemcode.FieldNotes, field.TypeString) } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value) + } + if _u.mutation.ExpiresAtCleared() { + _spec.ClearField(redeemcode.FieldExpiresAt, field.TypeTime) + } if value, ok := _u.mutation.ValidityDays(); ok { _spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value) } @@ -528,6 +554,26 @@ func (_u *RedeemCodeUpdateOne) ClearNotes() *RedeemCodeUpdateOne { return _u } +// SetExpiresAt sets the "expires_at" field. +func (_u *RedeemCodeUpdateOne) SetExpiresAt(v time.Time) *RedeemCodeUpdateOne { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *RedeemCodeUpdateOne) SetNillableExpiresAt(v *time.Time) *RedeemCodeUpdateOne { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// ClearExpiresAt clears the value of the "expires_at" field. +func (_u *RedeemCodeUpdateOne) ClearExpiresAt() *RedeemCodeUpdateOne { + _u.mutation.ClearExpiresAt() + return _u +} + // SetGroupID sets the "group_id" field. func (_u *RedeemCodeUpdateOne) SetGroupID(v int64) *RedeemCodeUpdateOne { _u.mutation.SetGroupID(v) @@ -726,6 +772,12 @@ func (_u *RedeemCodeUpdateOne) sqlSave(ctx context.Context) (_node *RedeemCode, if _u.mutation.NotesCleared() { _spec.ClearField(redeemcode.FieldNotes, field.TypeString) } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(redeemcode.FieldExpiresAt, field.TypeTime, value) + } + if _u.mutation.ExpiresAtCleared() { + _spec.ClearField(redeemcode.FieldExpiresAt, field.TypeTime) + } if value, ok := _u.mutation.ValidityDays(); ok { _spec.SetField(redeemcode.FieldValidityDays, field.TypeInt, value) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index a282d9ba39d..b1899173cac 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -1386,7 +1386,7 @@ func init() { // redeemcode.DefaultCreatedAt holds the default value on creation for the created_at field. redeemcode.DefaultCreatedAt = redeemcodeDescCreatedAt.Default.(func() time.Time) // redeemcodeDescValidityDays is the schema descriptor for validity_days field. - redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor() + redeemcodeDescValidityDays := redeemcodeFields[10].Descriptor() // redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field. redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int) securitysecretMixin := schema.SecuritySecret{}.Mixin() @@ -1722,12 +1722,24 @@ func init() { usagelogDescImageSize := usagelogFields[34].Descriptor() // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) + // usagelogDescImageInputSize is the schema descriptor for image_input_size field. + usagelogDescImageInputSize := usagelogFields[35].Descriptor() + // usagelog.ImageInputSizeValidator is a validator for the "image_input_size" field. It is called by the builders before save. + usagelog.ImageInputSizeValidator = usagelogDescImageInputSize.Validators[0].(func(string) error) + // usagelogDescImageOutputSize is the schema descriptor for image_output_size field. + usagelogDescImageOutputSize := usagelogFields[36].Descriptor() + // usagelog.ImageOutputSizeValidator is a validator for the "image_output_size" field. It is called by the builders before save. + usagelog.ImageOutputSizeValidator = usagelogDescImageOutputSize.Validators[0].(func(string) error) + // usagelogDescImageSizeSource is the schema descriptor for image_size_source field. + usagelogDescImageSizeSource := usagelogFields[37].Descriptor() + // usagelog.ImageSizeSourceValidator is a validator for the "image_size_source" field. It is called by the builders before save. + usagelog.ImageSizeSourceValidator = usagelogDescImageSizeSource.Validators[0].(func(string) error) // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. - usagelogDescCacheTTLOverridden := usagelogFields[35].Descriptor() + usagelogDescCacheTTLOverridden := usagelogFields[39].Descriptor() // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) // usagelogDescCreatedAt is the schema descriptor for created_at field. - usagelogDescCreatedAt := usagelogFields[36].Descriptor() + usagelogDescCreatedAt := usagelogFields[40].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() diff --git a/backend/ent/schema/auth_identity.go b/backend/ent/schema/auth_identity.go index 5f864080086..3deeadcd96c 100644 --- a/backend/ent/schema/auth_identity.go +++ b/backend/ent/schema/auth_identity.go @@ -15,12 +15,13 @@ import ( ) var authProviderTypes = map[string]struct{}{ - "email": {}, - "github": {}, - "google": {}, - "linuxdo": {}, - "oidc": {}, - "wechat": {}, + "email": {}, + "github": {}, + "google": {}, + "linuxdo": {}, + "oidc": {}, + "wechat": {}, + "dingtalk": {}, } func validateAuthProviderType(value string) error { diff --git a/backend/ent/schema/auth_identity_schema_test.go b/backend/ent/schema/auth_identity_schema_test.go index d3e2405069a..af272790d1d 100644 --- a/backend/ent/schema/auth_identity_schema_test.go +++ b/backend/ent/schema/auth_identity_schema_test.go @@ -83,7 +83,7 @@ func TestAuthIdentityFoundationSchemas(t *testing.T) { require.Equal(t, 1, signupSource.Validators) validator := requireStringFieldValidator(t, User{}.Fields(), "signup_source") - for _, value := range []string{"email", "linuxdo", "wechat", "oidc", "github", "google"} { + for _, value := range []string{"email", "linuxdo", "wechat", "oidc", "github", "google", "dingtalk"} { require.NoError(t, validator(value)) } require.Error(t, validator("unknown")) diff --git a/backend/ent/schema/redeem_code.go b/backend/ent/schema/redeem_code.go index 6fb8614847a..fdaf0808304 100644 --- a/backend/ent/schema/redeem_code.go +++ b/backend/ent/schema/redeem_code.go @@ -63,6 +63,10 @@ func (RedeemCode) Fields() []ent.Field { Immutable(). Default(time.Now). SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + field.Time("expires_at"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), field.Int64("group_id"). Optional(). Nillable(), @@ -90,5 +94,6 @@ func (RedeemCode) Indexes() []ent.Index { index.Fields("status"), index.Fields("used_by"), index.Fields("group_id"), + index.Fields("expires_at"), } } diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index bd3ebfcc3ce..db9e5178922 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -134,6 +134,21 @@ func (UsageLog) Fields() []ent.Field { MaxLen(10). Optional(). Nillable(), + field.String("image_input_size"). + MaxLen(32). + Optional(). + Nillable(), + field.String("image_output_size"). + MaxLen(32). + Optional(). + Nillable(), + field.String("image_size_source"). + MaxLen(16). + Optional(). + Nillable(), + field.JSON("image_size_breakdown", map[string]int{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), // Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费) field.Bool("cache_ttl_overridden"). Default(false), diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index 08bab83a9c9..c6e0427330e 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -77,10 +77,10 @@ func (User) Fields() []ent.Field { field.String("signup_source"). Validate(func(value string) error { switch value { - case "email", "linuxdo", "wechat", "oidc", "github", "google": + case "email", "linuxdo", "wechat", "oidc", "github", "google", "dingtalk": return nil default: - return fmt.Errorf("must be one of email, linuxdo, wechat, oidc, github, google") + return fmt.Errorf("must be one of email, linuxdo, wechat, oidc, github, google, dingtalk") } }). Default("email"), diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index a8e0cc6ce8d..283fe828a97 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -3,6 +3,7 @@ package ent import ( + "encoding/json" "fmt" "strings" "time" @@ -92,6 +93,14 @@ type UsageLog struct { ImageCount int `json:"image_count,omitempty"` // ImageSize holds the value of the "image_size" field. ImageSize *string `json:"image_size,omitempty"` + // ImageInputSize holds the value of the "image_input_size" field. + ImageInputSize *string `json:"image_input_size,omitempty"` + // ImageOutputSize holds the value of the "image_output_size" field. + ImageOutputSize *string `json:"image_output_size,omitempty"` + // ImageSizeSource holds the value of the "image_size_source" field. + ImageSizeSource *string `json:"image_size_source,omitempty"` + // ImageSizeBreakdown holds the value of the "image_size_breakdown" field. + ImageSizeBreakdown map[string]int `json:"image_size_breakdown,omitempty"` // CacheTTLOverridden holds the value of the "cache_ttl_overridden" field. CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"` // CreatedAt holds the value of the "created_at" field. @@ -179,13 +188,15 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { + case usagelog.FieldImageSizeBreakdown: + values[i] = new([]byte) case usagelog.FieldStream, usagelog.FieldCacheTTLOverridden: values[i] = new(sql.NullBool) case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier: values[i] = new(sql.NullFloat64) case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldChannelID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: values[i] = new(sql.NullInt64) - case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize: + case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldRequestedModel, usagelog.FieldUpstreamModel, usagelog.FieldModelMappingChain, usagelog.FieldBillingTier, usagelog.FieldBillingMode, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldImageInputSize, usagelog.FieldImageOutputSize, usagelog.FieldImageSizeSource: values[i] = new(sql.NullString) case usagelog.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -434,6 +445,35 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { _m.ImageSize = new(string) *_m.ImageSize = value.String } + case usagelog.FieldImageInputSize: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field image_input_size", values[i]) + } else if value.Valid { + _m.ImageInputSize = new(string) + *_m.ImageInputSize = value.String + } + case usagelog.FieldImageOutputSize: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field image_output_size", values[i]) + } else if value.Valid { + _m.ImageOutputSize = new(string) + *_m.ImageOutputSize = value.String + } + case usagelog.FieldImageSizeSource: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field image_size_source", values[i]) + } else if value.Valid { + _m.ImageSizeSource = new(string) + *_m.ImageSizeSource = value.String + } + case usagelog.FieldImageSizeBreakdown: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field image_size_breakdown", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.ImageSizeBreakdown); err != nil { + return fmt.Errorf("unmarshal field image_size_breakdown: %w", err) + } + } case usagelog.FieldCacheTTLOverridden: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i]) @@ -640,6 +680,24 @@ func (_m *UsageLog) String() string { builder.WriteString(*v) } builder.WriteString(", ") + if v := _m.ImageInputSize; v != nil { + builder.WriteString("image_input_size=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.ImageOutputSize; v != nil { + builder.WriteString("image_output_size=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.ImageSizeSource; v != nil { + builder.WriteString("image_size_source=") + builder.WriteString(*v) + } + builder.WriteString(", ") + builder.WriteString("image_size_breakdown=") + builder.WriteString(fmt.Sprintf("%v", _m.ImageSizeBreakdown)) + builder.WriteString(", ") builder.WriteString("cache_ttl_overridden=") builder.WriteString(fmt.Sprintf("%v", _m.CacheTTLOverridden)) builder.WriteString(", ") diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index a7438e604fb..297e0b41ad5 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -84,6 +84,14 @@ const ( FieldImageCount = "image_count" // FieldImageSize holds the string denoting the image_size field in the database. FieldImageSize = "image_size" + // FieldImageInputSize holds the string denoting the image_input_size field in the database. + FieldImageInputSize = "image_input_size" + // FieldImageOutputSize holds the string denoting the image_output_size field in the database. + FieldImageOutputSize = "image_output_size" + // FieldImageSizeSource holds the string denoting the image_size_source field in the database. + FieldImageSizeSource = "image_size_source" + // FieldImageSizeBreakdown holds the string denoting the image_size_breakdown field in the database. + FieldImageSizeBreakdown = "image_size_breakdown" // FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database. FieldCacheTTLOverridden = "cache_ttl_overridden" // FieldCreatedAt holds the string denoting the created_at field in the database. @@ -175,6 +183,10 @@ var Columns = []string{ FieldIPAddress, FieldImageCount, FieldImageSize, + FieldImageInputSize, + FieldImageOutputSize, + FieldImageSizeSource, + FieldImageSizeBreakdown, FieldCacheTTLOverridden, FieldCreatedAt, } @@ -242,6 +254,12 @@ var ( DefaultImageCount int // ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. ImageSizeValidator func(string) error + // ImageInputSizeValidator is a validator for the "image_input_size" field. It is called by the builders before save. + ImageInputSizeValidator func(string) error + // ImageOutputSizeValidator is a validator for the "image_output_size" field. It is called by the builders before save. + ImageOutputSizeValidator func(string) error + // ImageSizeSourceValidator is a validator for the "image_size_source" field. It is called by the builders before save. + ImageSizeSourceValidator func(string) error // DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field. DefaultCacheTTLOverridden bool // DefaultCreatedAt holds the default value on creation for the "created_at" field. @@ -431,6 +449,21 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImageSize, opts...).ToFunc() } +// ByImageInputSize orders the results by the image_input_size field. +func ByImageInputSize(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldImageInputSize, opts...).ToFunc() +} + +// ByImageOutputSize orders the results by the image_output_size field. +func ByImageOutputSize(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldImageOutputSize, opts...).ToFunc() +} + +// ByImageSizeSource orders the results by the image_size_source field. +func ByImageSizeSource(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldImageSizeSource, opts...).ToFunc() +} + // ByCacheTTLOverridden orders the results by the cache_ttl_overridden field. func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index b8439a03978..2987f179303 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -230,6 +230,21 @@ func ImageSize(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v)) } +// ImageInputSize applies equality check predicate on the "image_input_size" field. It's identical to ImageInputSizeEQ. +func ImageInputSize(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldImageInputSize, v)) +} + +// ImageOutputSize applies equality check predicate on the "image_output_size" field. It's identical to ImageOutputSizeEQ. +func ImageOutputSize(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldImageOutputSize, v)) +} + +// ImageSizeSource applies equality check predicate on the "image_size_source" field. It's identical to ImageSizeSourceEQ. +func ImageSizeSource(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldImageSizeSource, v)) +} + // CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ. func CacheTTLOverridden(v bool) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) @@ -1900,6 +1915,241 @@ func ImageSizeContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v)) } +// ImageInputSizeEQ applies the EQ predicate on the "image_input_size" field. +func ImageInputSizeEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldImageInputSize, v)) +} + +// ImageInputSizeNEQ applies the NEQ predicate on the "image_input_size" field. +func ImageInputSizeNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldImageInputSize, v)) +} + +// ImageInputSizeIn applies the In predicate on the "image_input_size" field. +func ImageInputSizeIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldImageInputSize, vs...)) +} + +// ImageInputSizeNotIn applies the NotIn predicate on the "image_input_size" field. +func ImageInputSizeNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldImageInputSize, vs...)) +} + +// ImageInputSizeGT applies the GT predicate on the "image_input_size" field. +func ImageInputSizeGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldImageInputSize, v)) +} + +// ImageInputSizeGTE applies the GTE predicate on the "image_input_size" field. +func ImageInputSizeGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldImageInputSize, v)) +} + +// ImageInputSizeLT applies the LT predicate on the "image_input_size" field. +func ImageInputSizeLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldImageInputSize, v)) +} + +// ImageInputSizeLTE applies the LTE predicate on the "image_input_size" field. +func ImageInputSizeLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldImageInputSize, v)) +} + +// ImageInputSizeContains applies the Contains predicate on the "image_input_size" field. +func ImageInputSizeContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldImageInputSize, v)) +} + +// ImageInputSizeHasPrefix applies the HasPrefix predicate on the "image_input_size" field. +func ImageInputSizeHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldImageInputSize, v)) +} + +// ImageInputSizeHasSuffix applies the HasSuffix predicate on the "image_input_size" field. +func ImageInputSizeHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldImageInputSize, v)) +} + +// ImageInputSizeIsNil applies the IsNil predicate on the "image_input_size" field. +func ImageInputSizeIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldImageInputSize)) +} + +// ImageInputSizeNotNil applies the NotNil predicate on the "image_input_size" field. +func ImageInputSizeNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldImageInputSize)) +} + +// ImageInputSizeEqualFold applies the EqualFold predicate on the "image_input_size" field. +func ImageInputSizeEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldImageInputSize, v)) +} + +// ImageInputSizeContainsFold applies the ContainsFold predicate on the "image_input_size" field. +func ImageInputSizeContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldImageInputSize, v)) +} + +// ImageOutputSizeEQ applies the EQ predicate on the "image_output_size" field. +func ImageOutputSizeEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldImageOutputSize, v)) +} + +// ImageOutputSizeNEQ applies the NEQ predicate on the "image_output_size" field. +func ImageOutputSizeNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldImageOutputSize, v)) +} + +// ImageOutputSizeIn applies the In predicate on the "image_output_size" field. +func ImageOutputSizeIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldImageOutputSize, vs...)) +} + +// ImageOutputSizeNotIn applies the NotIn predicate on the "image_output_size" field. +func ImageOutputSizeNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldImageOutputSize, vs...)) +} + +// ImageOutputSizeGT applies the GT predicate on the "image_output_size" field. +func ImageOutputSizeGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldImageOutputSize, v)) +} + +// ImageOutputSizeGTE applies the GTE predicate on the "image_output_size" field. +func ImageOutputSizeGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldImageOutputSize, v)) +} + +// ImageOutputSizeLT applies the LT predicate on the "image_output_size" field. +func ImageOutputSizeLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldImageOutputSize, v)) +} + +// ImageOutputSizeLTE applies the LTE predicate on the "image_output_size" field. +func ImageOutputSizeLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldImageOutputSize, v)) +} + +// ImageOutputSizeContains applies the Contains predicate on the "image_output_size" field. +func ImageOutputSizeContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldImageOutputSize, v)) +} + +// ImageOutputSizeHasPrefix applies the HasPrefix predicate on the "image_output_size" field. +func ImageOutputSizeHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldImageOutputSize, v)) +} + +// ImageOutputSizeHasSuffix applies the HasSuffix predicate on the "image_output_size" field. +func ImageOutputSizeHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldImageOutputSize, v)) +} + +// ImageOutputSizeIsNil applies the IsNil predicate on the "image_output_size" field. +func ImageOutputSizeIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldImageOutputSize)) +} + +// ImageOutputSizeNotNil applies the NotNil predicate on the "image_output_size" field. +func ImageOutputSizeNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldImageOutputSize)) +} + +// ImageOutputSizeEqualFold applies the EqualFold predicate on the "image_output_size" field. +func ImageOutputSizeEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldImageOutputSize, v)) +} + +// ImageOutputSizeContainsFold applies the ContainsFold predicate on the "image_output_size" field. +func ImageOutputSizeContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldImageOutputSize, v)) +} + +// ImageSizeSourceEQ applies the EQ predicate on the "image_size_source" field. +func ImageSizeSourceEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldImageSizeSource, v)) +} + +// ImageSizeSourceNEQ applies the NEQ predicate on the "image_size_source" field. +func ImageSizeSourceNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldImageSizeSource, v)) +} + +// ImageSizeSourceIn applies the In predicate on the "image_size_source" field. +func ImageSizeSourceIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldImageSizeSource, vs...)) +} + +// ImageSizeSourceNotIn applies the NotIn predicate on the "image_size_source" field. +func ImageSizeSourceNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldImageSizeSource, vs...)) +} + +// ImageSizeSourceGT applies the GT predicate on the "image_size_source" field. +func ImageSizeSourceGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldImageSizeSource, v)) +} + +// ImageSizeSourceGTE applies the GTE predicate on the "image_size_source" field. +func ImageSizeSourceGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldImageSizeSource, v)) +} + +// ImageSizeSourceLT applies the LT predicate on the "image_size_source" field. +func ImageSizeSourceLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldImageSizeSource, v)) +} + +// ImageSizeSourceLTE applies the LTE predicate on the "image_size_source" field. +func ImageSizeSourceLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldImageSizeSource, v)) +} + +// ImageSizeSourceContains applies the Contains predicate on the "image_size_source" field. +func ImageSizeSourceContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldImageSizeSource, v)) +} + +// ImageSizeSourceHasPrefix applies the HasPrefix predicate on the "image_size_source" field. +func ImageSizeSourceHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldImageSizeSource, v)) +} + +// ImageSizeSourceHasSuffix applies the HasSuffix predicate on the "image_size_source" field. +func ImageSizeSourceHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldImageSizeSource, v)) +} + +// ImageSizeSourceIsNil applies the IsNil predicate on the "image_size_source" field. +func ImageSizeSourceIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldImageSizeSource)) +} + +// ImageSizeSourceNotNil applies the NotNil predicate on the "image_size_source" field. +func ImageSizeSourceNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldImageSizeSource)) +} + +// ImageSizeSourceEqualFold applies the EqualFold predicate on the "image_size_source" field. +func ImageSizeSourceEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldImageSizeSource, v)) +} + +// ImageSizeSourceContainsFold applies the ContainsFold predicate on the "image_size_source" field. +func ImageSizeSourceContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldImageSizeSource, v)) +} + +// ImageSizeBreakdownIsNil applies the IsNil predicate on the "image_size_breakdown" field. +func ImageSizeBreakdownIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldImageSizeBreakdown)) +} + +// ImageSizeBreakdownNotNil applies the NotNil predicate on the "image_size_breakdown" field. +func ImageSizeBreakdownNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldImageSizeBreakdown)) +} + // CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field. func CacheTTLOverriddenEQ(v bool) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index fded364e0e6..17e800f9ca3 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -477,6 +477,54 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate { return _c } +// SetImageInputSize sets the "image_input_size" field. +func (_c *UsageLogCreate) SetImageInputSize(v string) *UsageLogCreate { + _c.mutation.SetImageInputSize(v) + return _c +} + +// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableImageInputSize(v *string) *UsageLogCreate { + if v != nil { + _c.SetImageInputSize(*v) + } + return _c +} + +// SetImageOutputSize sets the "image_output_size" field. +func (_c *UsageLogCreate) SetImageOutputSize(v string) *UsageLogCreate { + _c.mutation.SetImageOutputSize(v) + return _c +} + +// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableImageOutputSize(v *string) *UsageLogCreate { + if v != nil { + _c.SetImageOutputSize(*v) + } + return _c +} + +// SetImageSizeSource sets the "image_size_source" field. +func (_c *UsageLogCreate) SetImageSizeSource(v string) *UsageLogCreate { + _c.mutation.SetImageSizeSource(v) + return _c +} + +// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableImageSizeSource(v *string) *UsageLogCreate { + if v != nil { + _c.SetImageSizeSource(*v) + } + return _c +} + +// SetImageSizeBreakdown sets the "image_size_breakdown" field. +func (_c *UsageLogCreate) SetImageSizeBreakdown(v map[string]int) *UsageLogCreate { + _c.mutation.SetImageSizeBreakdown(v) + return _c +} + // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate { _c.mutation.SetCacheTTLOverridden(v) @@ -754,6 +802,21 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _c.mutation.ImageInputSize(); ok { + if err := usagelog.ImageInputSizeValidator(v); err != nil { + return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)} + } + } + if v, ok := _c.mutation.ImageOutputSize(); ok { + if err := usagelog.ImageOutputSizeValidator(v); err != nil { + return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)} + } + } + if v, ok := _c.mutation.ImageSizeSource(); ok { + if err := usagelog.ImageSizeSourceValidator(v); err != nil { + return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)} + } + } if _, ok := _c.mutation.CacheTTLOverridden(); !ok { return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)} } @@ -916,6 +979,22 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldImageSize, field.TypeString, value) _node.ImageSize = &value } + if value, ok := _c.mutation.ImageInputSize(); ok { + _spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value) + _node.ImageInputSize = &value + } + if value, ok := _c.mutation.ImageOutputSize(); ok { + _spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value) + _node.ImageOutputSize = &value + } + if value, ok := _c.mutation.ImageSizeSource(); ok { + _spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value) + _node.ImageSizeSource = &value + } + if value, ok := _c.mutation.ImageSizeBreakdown(); ok { + _spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value) + _node.ImageSizeBreakdown = value + } if value, ok := _c.mutation.CacheTTLOverridden(); ok { _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) _node.CacheTTLOverridden = value @@ -1679,6 +1758,78 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert { return u } +// SetImageInputSize sets the "image_input_size" field. +func (u *UsageLogUpsert) SetImageInputSize(v string) *UsageLogUpsert { + u.Set(usagelog.FieldImageInputSize, v) + return u +} + +// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateImageInputSize() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldImageInputSize) + return u +} + +// ClearImageInputSize clears the value of the "image_input_size" field. +func (u *UsageLogUpsert) ClearImageInputSize() *UsageLogUpsert { + u.SetNull(usagelog.FieldImageInputSize) + return u +} + +// SetImageOutputSize sets the "image_output_size" field. +func (u *UsageLogUpsert) SetImageOutputSize(v string) *UsageLogUpsert { + u.Set(usagelog.FieldImageOutputSize, v) + return u +} + +// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateImageOutputSize() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldImageOutputSize) + return u +} + +// ClearImageOutputSize clears the value of the "image_output_size" field. +func (u *UsageLogUpsert) ClearImageOutputSize() *UsageLogUpsert { + u.SetNull(usagelog.FieldImageOutputSize) + return u +} + +// SetImageSizeSource sets the "image_size_source" field. +func (u *UsageLogUpsert) SetImageSizeSource(v string) *UsageLogUpsert { + u.Set(usagelog.FieldImageSizeSource, v) + return u +} + +// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateImageSizeSource() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldImageSizeSource) + return u +} + +// ClearImageSizeSource clears the value of the "image_size_source" field. +func (u *UsageLogUpsert) ClearImageSizeSource() *UsageLogUpsert { + u.SetNull(usagelog.FieldImageSizeSource) + return u +} + +// SetImageSizeBreakdown sets the "image_size_breakdown" field. +func (u *UsageLogUpsert) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsert { + u.Set(usagelog.FieldImageSizeBreakdown, v) + return u +} + +// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateImageSizeBreakdown() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldImageSizeBreakdown) + return u +} + +// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field. +func (u *UsageLogUpsert) ClearImageSizeBreakdown() *UsageLogUpsert { + u.SetNull(usagelog.FieldImageSizeBreakdown) + return u +} + // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert { u.Set(usagelog.FieldCacheTTLOverridden, v) @@ -2457,6 +2608,90 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne { }) } +// SetImageInputSize sets the "image_input_size" field. +func (u *UsageLogUpsertOne) SetImageInputSize(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetImageInputSize(v) + }) +} + +// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateImageInputSize() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateImageInputSize() + }) +} + +// ClearImageInputSize clears the value of the "image_input_size" field. +func (u *UsageLogUpsertOne) ClearImageInputSize() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearImageInputSize() + }) +} + +// SetImageOutputSize sets the "image_output_size" field. +func (u *UsageLogUpsertOne) SetImageOutputSize(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetImageOutputSize(v) + }) +} + +// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateImageOutputSize() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateImageOutputSize() + }) +} + +// ClearImageOutputSize clears the value of the "image_output_size" field. +func (u *UsageLogUpsertOne) ClearImageOutputSize() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearImageOutputSize() + }) +} + +// SetImageSizeSource sets the "image_size_source" field. +func (u *UsageLogUpsertOne) SetImageSizeSource(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetImageSizeSource(v) + }) +} + +// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateImageSizeSource() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateImageSizeSource() + }) +} + +// ClearImageSizeSource clears the value of the "image_size_source" field. +func (u *UsageLogUpsertOne) ClearImageSizeSource() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearImageSizeSource() + }) +} + +// SetImageSizeBreakdown sets the "image_size_breakdown" field. +func (u *UsageLogUpsertOne) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetImageSizeBreakdown(v) + }) +} + +// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateImageSizeBreakdown() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateImageSizeBreakdown() + }) +} + +// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field. +func (u *UsageLogUpsertOne) ClearImageSizeBreakdown() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearImageSizeBreakdown() + }) +} + // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne { return u.Update(func(s *UsageLogUpsert) { @@ -3403,6 +3638,90 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk { }) } +// SetImageInputSize sets the "image_input_size" field. +func (u *UsageLogUpsertBulk) SetImageInputSize(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetImageInputSize(v) + }) +} + +// UpdateImageInputSize sets the "image_input_size" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateImageInputSize() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateImageInputSize() + }) +} + +// ClearImageInputSize clears the value of the "image_input_size" field. +func (u *UsageLogUpsertBulk) ClearImageInputSize() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearImageInputSize() + }) +} + +// SetImageOutputSize sets the "image_output_size" field. +func (u *UsageLogUpsertBulk) SetImageOutputSize(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetImageOutputSize(v) + }) +} + +// UpdateImageOutputSize sets the "image_output_size" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateImageOutputSize() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateImageOutputSize() + }) +} + +// ClearImageOutputSize clears the value of the "image_output_size" field. +func (u *UsageLogUpsertBulk) ClearImageOutputSize() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearImageOutputSize() + }) +} + +// SetImageSizeSource sets the "image_size_source" field. +func (u *UsageLogUpsertBulk) SetImageSizeSource(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetImageSizeSource(v) + }) +} + +// UpdateImageSizeSource sets the "image_size_source" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateImageSizeSource() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateImageSizeSource() + }) +} + +// ClearImageSizeSource clears the value of the "image_size_source" field. +func (u *UsageLogUpsertBulk) ClearImageSizeSource() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearImageSizeSource() + }) +} + +// SetImageSizeBreakdown sets the "image_size_breakdown" field. +func (u *UsageLogUpsertBulk) SetImageSizeBreakdown(v map[string]int) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetImageSizeBreakdown(v) + }) +} + +// UpdateImageSizeBreakdown sets the "image_size_breakdown" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateImageSizeBreakdown() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateImageSizeBreakdown() + }) +} + +// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field. +func (u *UsageLogUpsertBulk) ClearImageSizeBreakdown() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearImageSizeBreakdown() + }) +} + // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk { return u.Update(func(s *UsageLogUpsert) { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index bb5ac86c78a..e8fa003c63e 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -739,6 +739,78 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate { return _u } +// SetImageInputSize sets the "image_input_size" field. +func (_u *UsageLogUpdate) SetImageInputSize(v string) *UsageLogUpdate { + _u.mutation.SetImageInputSize(v) + return _u +} + +// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableImageInputSize(v *string) *UsageLogUpdate { + if v != nil { + _u.SetImageInputSize(*v) + } + return _u +} + +// ClearImageInputSize clears the value of the "image_input_size" field. +func (_u *UsageLogUpdate) ClearImageInputSize() *UsageLogUpdate { + _u.mutation.ClearImageInputSize() + return _u +} + +// SetImageOutputSize sets the "image_output_size" field. +func (_u *UsageLogUpdate) SetImageOutputSize(v string) *UsageLogUpdate { + _u.mutation.SetImageOutputSize(v) + return _u +} + +// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableImageOutputSize(v *string) *UsageLogUpdate { + if v != nil { + _u.SetImageOutputSize(*v) + } + return _u +} + +// ClearImageOutputSize clears the value of the "image_output_size" field. +func (_u *UsageLogUpdate) ClearImageOutputSize() *UsageLogUpdate { + _u.mutation.ClearImageOutputSize() + return _u +} + +// SetImageSizeSource sets the "image_size_source" field. +func (_u *UsageLogUpdate) SetImageSizeSource(v string) *UsageLogUpdate { + _u.mutation.SetImageSizeSource(v) + return _u +} + +// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableImageSizeSource(v *string) *UsageLogUpdate { + if v != nil { + _u.SetImageSizeSource(*v) + } + return _u +} + +// ClearImageSizeSource clears the value of the "image_size_source" field. +func (_u *UsageLogUpdate) ClearImageSizeSource() *UsageLogUpdate { + _u.mutation.ClearImageSizeSource() + return _u +} + +// SetImageSizeBreakdown sets the "image_size_breakdown" field. +func (_u *UsageLogUpdate) SetImageSizeBreakdown(v map[string]int) *UsageLogUpdate { + _u.mutation.SetImageSizeBreakdown(v) + return _u +} + +// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field. +func (_u *UsageLogUpdate) ClearImageSizeBreakdown() *UsageLogUpdate { + _u.mutation.ClearImageSizeBreakdown() + return _u +} + // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate { _u.mutation.SetCacheTTLOverridden(v) @@ -892,6 +964,21 @@ func (_u *UsageLogUpdate) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _u.mutation.ImageInputSize(); ok { + if err := usagelog.ImageInputSizeValidator(v); err != nil { + return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)} + } + } + if v, ok := _u.mutation.ImageOutputSize(); ok { + if err := usagelog.ImageOutputSizeValidator(v); err != nil { + return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)} + } + } + if v, ok := _u.mutation.ImageSizeSource(); ok { + if err := usagelog.ImageSizeSourceValidator(v); err != nil { + return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)} + } + } if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) } @@ -1099,6 +1186,30 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } + if value, ok := _u.mutation.ImageInputSize(); ok { + _spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value) + } + if _u.mutation.ImageInputSizeCleared() { + _spec.ClearField(usagelog.FieldImageInputSize, field.TypeString) + } + if value, ok := _u.mutation.ImageOutputSize(); ok { + _spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value) + } + if _u.mutation.ImageOutputSizeCleared() { + _spec.ClearField(usagelog.FieldImageOutputSize, field.TypeString) + } + if value, ok := _u.mutation.ImageSizeSource(); ok { + _spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value) + } + if _u.mutation.ImageSizeSourceCleared() { + _spec.ClearField(usagelog.FieldImageSizeSource, field.TypeString) + } + if value, ok := _u.mutation.ImageSizeBreakdown(); ok { + _spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value) + } + if _u.mutation.ImageSizeBreakdownCleared() { + _spec.ClearField(usagelog.FieldImageSizeBreakdown, field.TypeJSON) + } if value, ok := _u.mutation.CacheTTLOverridden(); ok { _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) } @@ -1974,6 +2085,78 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne { return _u } +// SetImageInputSize sets the "image_input_size" field. +func (_u *UsageLogUpdateOne) SetImageInputSize(v string) *UsageLogUpdateOne { + _u.mutation.SetImageInputSize(v) + return _u +} + +// SetNillableImageInputSize sets the "image_input_size" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableImageInputSize(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetImageInputSize(*v) + } + return _u +} + +// ClearImageInputSize clears the value of the "image_input_size" field. +func (_u *UsageLogUpdateOne) ClearImageInputSize() *UsageLogUpdateOne { + _u.mutation.ClearImageInputSize() + return _u +} + +// SetImageOutputSize sets the "image_output_size" field. +func (_u *UsageLogUpdateOne) SetImageOutputSize(v string) *UsageLogUpdateOne { + _u.mutation.SetImageOutputSize(v) + return _u +} + +// SetNillableImageOutputSize sets the "image_output_size" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableImageOutputSize(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetImageOutputSize(*v) + } + return _u +} + +// ClearImageOutputSize clears the value of the "image_output_size" field. +func (_u *UsageLogUpdateOne) ClearImageOutputSize() *UsageLogUpdateOne { + _u.mutation.ClearImageOutputSize() + return _u +} + +// SetImageSizeSource sets the "image_size_source" field. +func (_u *UsageLogUpdateOne) SetImageSizeSource(v string) *UsageLogUpdateOne { + _u.mutation.SetImageSizeSource(v) + return _u +} + +// SetNillableImageSizeSource sets the "image_size_source" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableImageSizeSource(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetImageSizeSource(*v) + } + return _u +} + +// ClearImageSizeSource clears the value of the "image_size_source" field. +func (_u *UsageLogUpdateOne) ClearImageSizeSource() *UsageLogUpdateOne { + _u.mutation.ClearImageSizeSource() + return _u +} + +// SetImageSizeBreakdown sets the "image_size_breakdown" field. +func (_u *UsageLogUpdateOne) SetImageSizeBreakdown(v map[string]int) *UsageLogUpdateOne { + _u.mutation.SetImageSizeBreakdown(v) + return _u +} + +// ClearImageSizeBreakdown clears the value of the "image_size_breakdown" field. +func (_u *UsageLogUpdateOne) ClearImageSizeBreakdown() *UsageLogUpdateOne { + _u.mutation.ClearImageSizeBreakdown() + return _u +} + // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne { _u.mutation.SetCacheTTLOverridden(v) @@ -2140,6 +2323,21 @@ func (_u *UsageLogUpdateOne) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _u.mutation.ImageInputSize(); ok { + if err := usagelog.ImageInputSizeValidator(v); err != nil { + return &ValidationError{Name: "image_input_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_input_size": %w`, err)} + } + } + if v, ok := _u.mutation.ImageOutputSize(); ok { + if err := usagelog.ImageOutputSizeValidator(v); err != nil { + return &ValidationError{Name: "image_output_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_output_size": %w`, err)} + } + } + if v, ok := _u.mutation.ImageSizeSource(); ok { + if err := usagelog.ImageSizeSourceValidator(v); err != nil { + return &ValidationError{Name: "image_size_source", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size_source": %w`, err)} + } + } if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) } @@ -2364,6 +2562,30 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } + if value, ok := _u.mutation.ImageInputSize(); ok { + _spec.SetField(usagelog.FieldImageInputSize, field.TypeString, value) + } + if _u.mutation.ImageInputSizeCleared() { + _spec.ClearField(usagelog.FieldImageInputSize, field.TypeString) + } + if value, ok := _u.mutation.ImageOutputSize(); ok { + _spec.SetField(usagelog.FieldImageOutputSize, field.TypeString, value) + } + if _u.mutation.ImageOutputSizeCleared() { + _spec.ClearField(usagelog.FieldImageOutputSize, field.TypeString) + } + if value, ok := _u.mutation.ImageSizeSource(); ok { + _spec.SetField(usagelog.FieldImageSizeSource, field.TypeString, value) + } + if _u.mutation.ImageSizeSourceCleared() { + _spec.ClearField(usagelog.FieldImageSizeSource, field.TypeString) + } + if value, ok := _u.mutation.ImageSizeBreakdown(); ok { + _spec.SetField(usagelog.FieldImageSizeBreakdown, field.TypeJSON, value) + } + if _u.mutation.ImageSizeBreakdownCleared() { + _spec.ClearField(usagelog.FieldImageSizeBreakdown, field.TypeJSON) + } if value, ok := _u.mutation.CacheTTLOverridden(); ok { _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) } diff --git a/backend/go.sum b/backend/go.sum index e16a9fc08c1..db410b49ce2 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -216,6 +216,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -249,6 +251,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -278,6 +282,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= @@ -310,6 +316,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index d42828d8137..f08e0deadc9 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -72,6 +72,7 @@ type Config struct { LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` WeChat WeChatConnectConfig `mapstructure:"wechat_connect"` OIDC OIDCConnectConfig `mapstructure:"oidc_connect"` + DingTalk DingTalkConnectConfig `mapstructure:"dingtalk_connect"` GitHubOAuth EmailOAuthProviderConfig `mapstructure:"github_oauth"` GoogleOAuth EmailOAuthProviderConfig `mapstructure:"google_oauth"` Default DefaultConfig `mapstructure:"default"` @@ -242,6 +243,47 @@ type OIDCConnectConfig struct { UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` } +type DingTalkConnectConfig struct { + Enabled bool `mapstructure:"enabled"` + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` + AuthorizeURL string `mapstructure:"authorize_url"` + TokenURL string `mapstructure:"token_url"` + UserInfoURL string `mapstructure:"userinfo_url"` + Scopes string `mapstructure:"scopes"` + RedirectURL string `mapstructure:"redirect_url"` + FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` + + // 平台底座 + 业务行为 + DingTalkAppKind string `mapstructure:"dingtalk_app_kind"` // 仅 "internal_app"(V4 fail-closed) + AppType string `mapstructure:"app_type"` // "public" (default) | "internal" + + // Corp 限定(none | internal_only) + CorpRestrictionPolicy string `mapstructure:"corp_restriction_policy"` + InternalCorpID string `mapstructure:"internal_corp_id"` + BypassRegistration bool `mapstructure:"bypass_registration"` + SyncCorpEmail bool `mapstructure:"sync_corp_email"` + SyncDisplayName bool `mapstructure:"sync_display_name"` + SyncDept bool `mapstructure:"sync_dept"` + SyncCorpEmailAttrKey string `mapstructure:"sync_corp_email_attr_key"` + SyncDisplayNameAttrKey string `mapstructure:"sync_display_name_attr_key"` + SyncDeptAttrKey string `mapstructure:"sync_dept_attr_key"` + SyncCorpEmailAttrName string `mapstructure:"sync_corp_email_attr_name"` + SyncDisplayNameAttrName string `mapstructure:"sync_display_name_attr_name"` + SyncDeptAttrName string `mapstructure:"sync_dept_attr_name"` + + // 邮箱 + Username + RequireEmail bool `mapstructure:"require_email"` + UsernameOverwritePolicy string `mapstructure:"username_overwrite_policy"` + + // Attribute(私有版扩展点;开源版仅声明) + UsernameAttributeKey string `mapstructure:"username_attribute_key"` + EnableAttributeMatching bool `mapstructure:"enable_attribute_matching"` + EnableAttributeSync bool `mapstructure:"enable_attribute_sync"` + AttributeSyncFields []string `mapstructure:"attribute_sync_fields"` + AttributeSyncOverwritePolicy string `mapstructure:"attribute_sync_overwrite_policy"` +} + type EmailOAuthProviderConfig struct { Enabled bool `mapstructure:"enabled"` ClientID string `mapstructure:"client_id"` @@ -1536,6 +1578,19 @@ func setDefaults() { viper.SetDefault("oidc_connect.userinfo_id_path", "") viper.SetDefault("oidc_connect.userinfo_username_path", "") + // DingTalk Connect OAuth 登录 + viper.SetDefault("dingtalk_connect.enabled", false) + viper.SetDefault("dingtalk_connect.authorize_url", "https://login.dingtalk.com/oauth2/auth") + viper.SetDefault("dingtalk_connect.token_url", "https://api.dingtalk.com/v1.0/oauth2/userAccessToken") + viper.SetDefault("dingtalk_connect.userinfo_url", "https://api.dingtalk.com/v1.0/contact/users/me") + viper.SetDefault("dingtalk_connect.scopes", "openid") + viper.SetDefault("dingtalk_connect.frontend_redirect_url", "/auth/dingtalk/callback") + viper.SetDefault("dingtalk_connect.dingtalk_app_kind", "internal_app") + viper.SetDefault("dingtalk_connect.app_type", "public") + viper.SetDefault("dingtalk_connect.corp_restriction_policy", "none") + viper.SetDefault("dingtalk_connect.require_email", true) + viper.SetDefault("dingtalk_connect.username_overwrite_policy", "if_empty") + // Database viper.SetDefault("database.host", "localhost") viper.SetDefault("database.port", 5432) @@ -2608,6 +2663,9 @@ func (c *Config) Validate() error { if c.Concurrency.PingInterval < 5 || c.Concurrency.PingInterval > 30 { return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds") } + if err := ValidateDingTalkConfig(c.DingTalk); err != nil { + return fmt.Errorf("dingtalk_connect: %w", err) + } return nil } diff --git a/backend/internal/config/validate_dingtalk.go b/backend/internal/config/validate_dingtalk.go new file mode 100644 index 00000000000..15734eb5ad6 --- /dev/null +++ b/backend/internal/config/validate_dingtalk.go @@ -0,0 +1,30 @@ +// Package config 包含钉钉连接配置的校验逻辑。 +// +// internal_only 模式安全模型(方案 A): +// 不再要求 admin 填写 InternalCorpID 做二次 corpID 比对。 +// 安全边界由钉钉"企业内部应用"类型本身保证——只有应用所属企业的员工才能完成 OAuth, +// 因此 ValidateDingTalkConfig 只要求 app_type=internal(V1),不再要求 InternalCorpID 非空(原 V3 已删除)。 +// InternalCorpID 字段保留,admin 可选填;若填写,checkDingTalkCorpAllowed 不会使用它做约束。 +package config + +import "errors" + +var ( + ErrDingTalkV1AppTypeMismatch = errors.New("dingtalk: internal_only requires app_type=internal") + ErrDingTalkV4InvalidAppKind = errors.New("dingtalk: dingtalk_app_kind must be internal_app") +) + +func ValidateDingTalkConfig(cfg DingTalkConnectConfig) error { + if !cfg.Enabled { + return nil + } + if cfg.DingTalkAppKind != "internal_app" { + return ErrDingTalkV4InvalidAppKind + } + if cfg.CorpRestrictionPolicy == "internal_only" { + if cfg.AppType != "internal" { + return ErrDingTalkV1AppTypeMismatch + } + } + return nil +} diff --git a/backend/internal/config/validate_dingtalk_test.go b/backend/internal/config/validate_dingtalk_test.go new file mode 100644 index 00000000000..f121b97d729 --- /dev/null +++ b/backend/internal/config/validate_dingtalk_test.go @@ -0,0 +1,53 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValidateDingTalkConfig_Disabled_Skip(t *testing.T) { + require.NoError(t, ValidateDingTalkConfig(DingTalkConnectConfig{Enabled: false})) +} + +func TestValidateDingTalkConfig_V4_DingTalkAppKind(t *testing.T) { + err := ValidateDingTalkConfig(DingTalkConnectConfig{ + Enabled: true, + DingTalkAppKind: "third_party_enterprise_app", + CorpRestrictionPolicy: "none", + }) + require.ErrorIs(t, err, ErrDingTalkV4InvalidAppKind) +} + +func TestValidateDingTalkConfig_V1_InternalOnlyRequiresInternalAppType(t *testing.T) { + err := ValidateDingTalkConfig(DingTalkConnectConfig{ + Enabled: true, + DingTalkAppKind: "internal_app", + AppType: "public", + CorpRestrictionPolicy: "internal_only", + InternalCorpID: "dingABC", + }) + require.ErrorIs(t, err, ErrDingTalkV1AppTypeMismatch) +} + +// TestValidateDingTalkConfig_V3_InternalOnlyAllowsEmptyCorpID 验证方案 A: +// internal_only 策略下,InternalCorpID="" 应通过校验(企业隔离由钉钉 AppType=internal 保证)。 +func TestValidateDingTalkConfig_V3_InternalOnlyAllowsEmptyCorpID(t *testing.T) { + err := ValidateDingTalkConfig(DingTalkConnectConfig{ + Enabled: true, + DingTalkAppKind: "internal_app", + AppType: "internal", + CorpRestrictionPolicy: "internal_only", + InternalCorpID: "", + }) + require.NoError(t, err) +} + +func TestValidateDingTalkConfig_HappyPath_None(t *testing.T) { + require.NoError(t, ValidateDingTalkConfig(DingTalkConnectConfig{ + Enabled: true, + DingTalkAppKind: "internal_app", + AppType: "public", + CorpRestrictionPolicy: "none", + })) +} diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index 00da48212aa..50beadf68e6 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -43,6 +43,9 @@ type DataProxy struct { Status string `json:"status"` } +// DataAccount 是管理员显式备份导出使用的账号结构,故意不走 dto.Account 的脱敏路径, +// Credentials 原文返回。这是"管理员备份"这一显式行为的一部分;如未来需要导出脱敏版本, +// 应新增独立结构而非修改这里。 type DataAccount struct { Name string `json:"name"` Notes *string `json:"notes,omitempty"` diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index ffab74d6a7a..282ceede0a1 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -1643,7 +1643,7 @@ func (h *OAuthHandler) SetupTokenCookieAuth(c *gin.Context) { } // GetUsage handles getting account usage information -// GET /api/v1/admin/accounts/:id/usage?source=passive|active +// GET /api/v1/admin/accounts/:id/usage?source=passive|active&force=true func (h *AccountHandler) GetUsage(c *gin.Context) { accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { @@ -1652,12 +1652,13 @@ func (h *AccountHandler) GetUsage(c *gin.Context) { } source := c.DefaultQuery("source", "active") + force := c.Query("force") == "true" var usage *service.UsageInfo if source == "passive" { usage, err = h.accountUsageService.GetPassiveUsage(c.Request.Context(), accountID) } else { - usage, err = h.accountUsageService.GetUsage(c.Request.Context(), accountID) + usage, err = h.accountUsageService.GetUsage(c.Request.Context(), accountID, force) } if err != nil { response.ErrorFrom(c, err) @@ -1994,6 +1995,48 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { response.Success(c, models) } +// SyncUpstreamModels handles syncing live supported models from an account's upstream. +// POST /api/v1/admin/accounts/:id/models/sync-upstream +func (h *AccountHandler) SyncUpstreamModels(c *gin.Context) { + accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid account ID") + return + } + + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.NotFound(c, "Account not found") + return + } + + if h.accountTestService == nil { + response.InternalError(c, "Account test service is not configured") + return + } + + models, err := h.accountTestService.FetchUpstreamSupportedModels(c.Request.Context(), account) + if err != nil { + var syncErr *service.UpstreamModelSyncError + if errors.As(err, &syncErr) { + switch syncErr.Kind { + case service.UpstreamModelSyncErrorConfiguration, service.UpstreamModelSyncErrorUnsupported: + response.BadRequest(c, syncErr.SafeMessage()) + default: + slog.Warn("sync_upstream_models_failed", "account_id", accountID, "kind", syncErr.Kind) + response.Error(c, http.StatusBadGateway, syncErr.SafeMessage()) + } + return + } + + slog.Warn("sync_upstream_models_failed", "account_id", accountID) + response.Error(c, http.StatusBadGateway, "Failed to sync upstream models from upstream") + return + } + + response.Success(c, gin.H{"models": models}) +} + // SetPrivacy handles setting privacy for a single OpenAI/Antigravity OAuth account // POST /api/v1/admin/accounts/:id/set-privacy func (h *AccountHandler) SetPrivacy(c *gin.Context) { diff --git a/backend/internal/handler/admin/account_handler_available_models_test.go b/backend/internal/handler/admin/account_handler_available_models_test.go index c5f1e2d884c..0efbd6d434d 100644 --- a/backend/internal/handler/admin/account_handler_available_models_test.go +++ b/backend/internal/handler/admin/account_handler_available_models_test.go @@ -3,10 +3,14 @@ package admin import ( "context" "encoding/json" + "io" "net/http" "net/http/httptest" + "strings" "testing" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" @@ -33,6 +37,39 @@ func setupAvailableModelsRouter(adminSvc service.AdminService) *gin.Engine { return router } +type syncUpstreamHTTPUpstream struct { + resp *http.Response + err error +} + +func (u *syncUpstreamHTTPUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + if u.err != nil { + return nil, u.err + } + return u.resp, nil +} + +func (u *syncUpstreamHTTPUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*http.Response, error) { + return u.Do(req, proxyURL, accountID, accountConcurrency) +} + +func setupSyncUpstreamModelsRouter(adminSvc service.AdminService, upstream service.HTTPUpstream) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + accountTestSvc := service.NewAccountTestService( + nil, + nil, + nil, + nil, + upstream, + &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + nil, + ) + handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, accountTestSvc, nil, nil, nil, nil, nil) + router.POST("/api/v1/admin/accounts/:id/models/sync-upstream", handler.SyncUpstreamModels) + return router +} + func TestAccountHandlerGetAvailableModels_OpenAIOAuthUsesExplicitModelMapping(t *testing.T) { svc := &availableModelsAdminService{ stubAdminService: newStubAdminService(), @@ -103,3 +140,58 @@ func TestAccountHandlerGetAvailableModels_OpenAIOAuthPassthroughFallsBackToDefau require.NotEmpty(t, resp.Data) require.NotEqual(t, "gpt-5", resp.Data[0].ID) } + +func TestAccountHandlerSyncUpstreamModels_ConfigErrorReturnsBadRequest(t *testing.T) { + svc := &availableModelsAdminService{ + stubAdminService: newStubAdminService(), + account: service.Account{ + ID: 44, + Name: "openai-apikey-missing-key", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeAPIKey, + Status: service.StatusActive, + Credentials: map[string]any{ + "base_url": "https://openai.example.com/v1", + }, + }, + } + router := setupSyncUpstreamModelsRouter(svc, &syncUpstreamHTTPUpstream{}) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/44/models/sync-upstream", nil) + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "No OpenAI API key is available") +} + +func TestAccountHandlerSyncUpstreamModels_UpstreamErrorDoesNotExposeBody(t *testing.T) { + svc := &availableModelsAdminService{ + stubAdminService: newStubAdminService(), + account: service.Account{ + ID: 45, + Name: "openai-apikey-upstream-error", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeAPIKey, + Status: service.StatusActive, + Credentials: map[string]any{ + "api_key": "openai-key", + "base_url": "https://openai.example.com/v1", + }, + }, + } + upstream := &syncUpstreamHTTPUpstream{resp: &http.Response{ + StatusCode: http.StatusBadGateway, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"error":"SECRET_TOKEN should not be exposed"}`)), + }} + router := setupSyncUpstreamModelsRouter(svc, upstream) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/45/models/sync-upstream", nil) + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadGateway, rec.Code) + require.Contains(t, rec.Body.String(), "Upstream model list request failed with HTTP 502") + require.NotContains(t, rec.Body.String(), "SECRET_TOKEN") +} diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 460f63578ef..e9fbb630fad 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -546,9 +546,14 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) { return } + // cacheKey 必须包含当日日期,否则跨午夜后 30s 内会复用昨天的 "today_*" 结果。 keyRaw, _ := json.Marshal(struct { + V int `json:"v"` + Day string `json:"day"` UserIDs []int64 `json:"user_ids"` }{ + V: 2, // bump 当响应结构变化(如加入 by_platform 时) + Day: timezone.Today().Format("2006-01-02"), UserIDs: userIDs, }) cacheKey := string(keyRaw) diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index 24365f3da41..7b4300b1e21 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -8,6 +8,7 @@ import ( "fmt" "strconv" "strings" + "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" @@ -33,23 +34,51 @@ func NewRedeemHandler(adminService service.AdminService, redeemService *service. // GenerateRedeemCodesRequest represents generate redeem codes request type GenerateRedeemCodesRequest struct { - Count int `json:"count" binding:"required,min=1,max=100"` - Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"` - Value float64 `json:"value"` - GroupID *int64 `json:"group_id"` // 订阅类型必填 - ValidityDays int `json:"validity_days"` // 订阅类型使用,正数增加/负数退款扣减 + Count int `json:"count" binding:"required,min=1,max=100"` + Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"` + Value float64 `json:"value"` + GroupID *int64 `json:"group_id"` // 订阅类型必填 + ValidityDays int `json:"validity_days"` // 订阅类型使用,正数增加/负数退款扣减 + ExpiresAt *time.Time `json:"expires_at"` + ExpiresInDays *int `json:"expires_in_days" binding:"omitempty,min=1,max=3650"` } // CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user. // Type 为 omitempty 而非 required 是为了向后兼容旧版调用方(不传 type 时默认 balance)。 type CreateAndRedeemCodeRequest struct { - Code string `json:"code" binding:"required,min=3,max=128"` - Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容) - Value float64 `json:"value" binding:"required"` - UserID int64 `json:"user_id" binding:"required,gt=0"` - GroupID *int64 `json:"group_id"` // subscription 类型必填 - ValidityDays int `json:"validity_days"` // subscription 类型:正数增加,负数退款扣减 - Notes string `json:"notes"` + Code string `json:"code" binding:"required,min=3,max=128"` + Type string `json:"type" binding:"omitempty,oneof=balance concurrency subscription invitation"` // 不传时默认 balance(向后兼容) + Value float64 `json:"value" binding:"required"` + UserID int64 `json:"user_id" binding:"required,gt=0"` + GroupID *int64 `json:"group_id"` // subscription 类型必填 + ValidityDays int `json:"validity_days"` // subscription 类型:正数增加,负数退款扣减 + Notes string `json:"notes"` + ExpiresAt *time.Time `json:"expires_at"` + ExpiresInDays *int `json:"expires_in_days" binding:"omitempty,min=1,max=3650"` +} + +func resolveRedeemCodeExpiresAt(expiresAt *time.Time, expiresInDays *int) (*time.Time, error) { + if expiresAt != nil && expiresInDays != nil { + return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRY_CONFLICT", "expires_at and expires_in_days cannot both be set") + } + + now := time.Now().UTC() + if expiresInDays != nil { + if *expiresInDays <= 0 { + return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRES_IN_DAYS_INVALID", "expires_in_days must be greater than zero") + } + expires := now.AddDate(0, 0, *expiresInDays) + return &expires, nil + } + if expiresAt == nil { + return nil, nil + } + + expires := expiresAt.UTC() + if !expires.After(now) { + return nil, infraerrors.BadRequest("REDEEM_CODE_EXPIRES_AT_INVALID", "expires_at must be in the future") + } + return &expires, nil } // List handles listing all redeem codes with pagination @@ -107,6 +136,12 @@ func (h *RedeemHandler) Generate(c *gin.Context) { return } + expiresAt, err := resolveRedeemCodeExpiresAt(req.ExpiresAt, req.ExpiresInDays) + if err != nil { + response.ErrorFrom(c, err) + return + } + executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { codes, execErr := h.adminService.GenerateRedeemCodes(ctx, &service.GenerateRedeemCodesInput{ Count: req.Count, @@ -114,6 +149,7 @@ func (h *RedeemHandler) Generate(c *gin.Context) { Value: req.Value, GroupID: req.GroupID, ValidityDays: req.ValidityDays, + ExpiresAt: expiresAt, }) if execErr != nil { return nil, execErr @@ -158,6 +194,12 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) { } } + expiresAt, err := resolveRedeemCodeExpiresAt(req.ExpiresAt, req.ExpiresInDays) + if err != nil { + response.ErrorFrom(c, err) + return + } + executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { existing, err := h.redeemService.GetByCode(ctx, req.Code) if err == nil { @@ -175,6 +217,7 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) { Notes: req.Notes, GroupID: req.GroupID, ValidityDays: req.ValidityDays, + ExpiresAt: expiresAt, }) if createErr != nil { // Unique code race: if code now exists, use idempotent semantics by used_by. @@ -199,6 +242,9 @@ func (h *RedeemHandler) resolveCreateAndRedeemExisting(ctx context.Context, exis } // If previous run created the code but crashed before redeem, redeem it now. + if existing.IsExpired() { + return nil, service.ErrRedeemCodeExpired + } if existing.CanUse() { redeemed, err := h.redeemService.Redeem(ctx, userID, existing.Code) if err == nil { @@ -321,7 +367,7 @@ func (h *RedeemHandler) Export(c *gin.Context) { writer := csv.NewWriter(&buf) // Write header - if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "created_at"}); err != nil { + if err := writer.Write([]string{"id", "code", "type", "value", "status", "used_by", "used_by_email", "used_at", "expires_at", "created_at"}); err != nil { response.InternalError(c, "Failed to export redeem codes: "+err.Error()) return } @@ -340,6 +386,10 @@ func (h *RedeemHandler) Export(c *gin.Context) { if code.UsedAt != nil { usedAt = code.UsedAt.Format("2006-01-02 15:04:05") } + expiresAt := "" + if code.ExpiresAt != nil { + expiresAt = code.ExpiresAt.Format("2006-01-02 15:04:05") + } if err := writer.Write([]string{ fmt.Sprintf("%d", code.ID), code.Code, @@ -349,6 +399,7 @@ func (h *RedeemHandler) Export(c *gin.Context) { usedBy, usedByEmail, usedAt, + expiresAt, code.CreatedAt.Format("2006-01-02 15:04:05"), }); err != nil { response.InternalError(c, "Failed to export redeem codes: "+err.Error()) diff --git a/backend/internal/handler/admin/redeem_handler_test.go b/backend/internal/handler/admin/redeem_handler_test.go index f1f7778f903..d6972460098 100644 --- a/backend/internal/handler/admin/redeem_handler_test.go +++ b/backend/internal/handler/admin/redeem_handler_test.go @@ -6,6 +6,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -139,3 +140,33 @@ func TestCreateAndRedeem_BalanceIgnoresSubscriptionFields(t *testing.T) { assert.NotEqual(t, http.StatusBadRequest, code, "balance type should not require group_id or validity_days") } + +func TestResolveRedeemCodeExpiresAt_FromDays(t *testing.T) { + days := 3 + expiresAt, err := resolveRedeemCodeExpiresAt(nil, &days) + require.NoError(t, err) + require.NotNil(t, expiresAt) + require.WithinDuration(t, time.Now().UTC().AddDate(0, 0, days), *expiresAt, 2*time.Second) +} + +func TestResolveRedeemCodeExpiresAt_RejectsPastAbsoluteTime(t *testing.T) { + past := time.Now().UTC().Add(-time.Minute) + expiresAt, err := resolveRedeemCodeExpiresAt(&past, nil) + require.Error(t, err) + require.Nil(t, expiresAt) +} + +func TestResolveRedeemCodeExpiresAt_RejectsNonPositiveDays(t *testing.T) { + days := 0 + expiresAt, err := resolveRedeemCodeExpiresAt(nil, &days) + require.Error(t, err) + require.Nil(t, expiresAt) +} + +func TestResolveRedeemCodeExpiresAt_RejectsConflictingInputs(t *testing.T) { + future := time.Now().UTC().Add(time.Hour) + days := 3 + expiresAt, err := resolveRedeemCodeExpiresAt(&future, &days) + require.Error(t, err) + require.Nil(t, expiresAt) +} diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 0ea664d82c9..eaaae4713b1 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1,9 +1,11 @@ package admin import ( + "context" "crypto/rand" "encoding/hex" "encoding/json" + "errors" "fmt" "log/slog" "net/http" @@ -60,10 +62,11 @@ type SettingHandler struct { opsService *service.OpsService paymentConfigService *service.PaymentConfigService paymentService *service.PaymentService + userAttributeService *service.UserAttributeService } // NewSettingHandler 创建系统设置处理器 -func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService) *SettingHandler { +func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, paymentConfigService *service.PaymentConfigService, paymentService *service.PaymentService, userAttributeService *service.UserAttributeService) *SettingHandler { return &SettingHandler{ settingService: settingService, emailService: emailService, @@ -71,6 +74,7 @@ func NewSettingHandler(settingService *service.SettingService, emailService *ser opsService: opsService, paymentConfigService: paymentConfigService, paymentService: paymentService, + userAttributeService: userAttributeService, } } @@ -135,6 +139,22 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { LinuxDoConnectClientID: settings.LinuxDoConnectClientID, LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured, LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL, + DingTalkConnectEnabled: settings.DingTalkConnectEnabled, + DingTalkConnectClientID: settings.DingTalkConnectClientID, + DingTalkConnectClientSecretConfigured: settings.DingTalkConnectClientSecretConfigured, + DingTalkConnectRedirectURL: settings.DingTalkConnectRedirectURL, + DingTalkConnectCorpRestrictionPolicy: settings.DingTalkConnectCorpRestrictionPolicy, + DingTalkConnectInternalCorpID: settings.DingTalkConnectInternalCorpID, + DingTalkConnectBypassRegistration: settings.DingTalkConnectBypassRegistration, + DingTalkConnectSyncCorpEmail: settings.DingTalkConnectSyncCorpEmail, + DingTalkConnectSyncDisplayName: settings.DingTalkConnectSyncDisplayName, + DingTalkConnectSyncDept: settings.DingTalkConnectSyncDept, + DingTalkConnectSyncCorpEmailAttrKey: settings.DingTalkConnectSyncCorpEmailAttrKey, + DingTalkConnectSyncDisplayNameAttrKey: settings.DingTalkConnectSyncDisplayNameAttrKey, + DingTalkConnectSyncDeptAttrKey: settings.DingTalkConnectSyncDeptAttrKey, + DingTalkConnectSyncCorpEmailAttrName: settings.DingTalkConnectSyncCorpEmailAttrName, + DingTalkConnectSyncDisplayNameAttrName: settings.DingTalkConnectSyncDisplayNameAttrName, + DingTalkConnectSyncDeptAttrName: settings.DingTalkConnectSyncDeptAttrName, WeChatConnectEnabled: settings.WeChatConnectEnabled, WeChatConnectAppID: settings.WeChatConnectAppID, WeChatConnectAppSecretConfigured: settings.WeChatConnectAppSecretConfigured, @@ -376,6 +396,24 @@ type UpdateSettingsRequest struct { LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + // DingTalk Connect OAuth 登录 + DingTalkConnectEnabled bool `json:"dingtalk_connect_enabled"` + DingTalkConnectClientID string `json:"dingtalk_connect_client_id"` + DingTalkConnectClientSecret string `json:"dingtalk_connect_client_secret"` + DingTalkConnectRedirectURL string `json:"dingtalk_connect_redirect_url"` + DingTalkConnectCorpRestrictionPolicy string `json:"dingtalk_connect_corp_restriction_policy"` + DingTalkConnectInternalCorpID string `json:"dingtalk_connect_internal_corp_id"` + DingTalkConnectBypassRegistration bool `json:"dingtalk_connect_bypass_registration"` + DingTalkConnectSyncCorpEmail bool `json:"dingtalk_connect_sync_corp_email"` + DingTalkConnectSyncDisplayName bool `json:"dingtalk_connect_sync_display_name"` + DingTalkConnectSyncDept bool `json:"dingtalk_connect_sync_dept"` + DingTalkConnectSyncCorpEmailAttrKey string `json:"dingtalk_connect_sync_corp_email_attr_key"` + DingTalkConnectSyncDisplayNameAttrKey string `json:"dingtalk_connect_sync_display_name_attr_key"` + DingTalkConnectSyncDeptAttrKey string `json:"dingtalk_connect_sync_dept_attr_key"` + DingTalkConnectSyncCorpEmailAttrName string `json:"dingtalk_connect_sync_corp_email_attr_name"` + DingTalkConnectSyncDisplayNameAttrName string `json:"dingtalk_connect_sync_display_name_attr_name"` + DingTalkConnectSyncDeptAttrName string `json:"dingtalk_connect_sync_dept_attr_name"` + // WeChat Connect OAuth 登录 WeChatConnectEnabled bool `json:"wechat_connect_enabled"` WeChatConnectAppID string `json:"wechat_connect_app_id"` @@ -446,45 +484,50 @@ type UpdateSettingsRequest struct { CustomEndpoints *[]dto.CustomEndpoint `json:"custom_endpoints"` // 默认配置 - DefaultConcurrency int `json:"default_concurrency"` - DefaultBalance float64 `json:"default_balance"` - AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"` - AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"` - AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"` - AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"` - DefaultUserRPMLimit int `json:"default_user_rpm_limit"` - DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` - AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"` - AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"` - AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"` - AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"` - AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"` - AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"` - AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"` - AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"` - AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"` - AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"` - AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"` - AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"` - AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"` - AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"` - AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"` - AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"` - AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"` - AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"` - AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"` - AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"` - AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"` - AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"` - AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"` - AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"` - AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"` - AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"` - AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"` - AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"` - AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"` - AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"` - ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"` + DefaultConcurrency int `json:"default_concurrency"` + DefaultBalance float64 `json:"default_balance"` + AffiliateRebateRate *float64 `json:"affiliate_rebate_rate"` + AffiliateRebateFreezeHours *int `json:"affiliate_rebate_freeze_hours"` + AffiliateRebateDurationDays *int `json:"affiliate_rebate_duration_days"` + AffiliateRebatePerInviteeCap *float64 `json:"affiliate_rebate_per_invitee_cap"` + DefaultUserRPMLimit int `json:"default_user_rpm_limit"` + DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` + AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"` + AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"` + AuthSourceDefaultEmailSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_email_subscriptions"` + AuthSourceDefaultEmailGrantOnSignup *bool `json:"auth_source_default_email_grant_on_signup"` + AuthSourceDefaultEmailGrantOnFirstBind *bool `json:"auth_source_default_email_grant_on_first_bind"` + AuthSourceDefaultLinuxDoBalance *float64 `json:"auth_source_default_linuxdo_balance"` + AuthSourceDefaultLinuxDoConcurrency *int `json:"auth_source_default_linuxdo_concurrency"` + AuthSourceDefaultLinuxDoSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_linuxdo_subscriptions"` + AuthSourceDefaultLinuxDoGrantOnSignup *bool `json:"auth_source_default_linuxdo_grant_on_signup"` + AuthSourceDefaultLinuxDoGrantOnFirstBind *bool `json:"auth_source_default_linuxdo_grant_on_first_bind"` + AuthSourceDefaultOIDCBalance *float64 `json:"auth_source_default_oidc_balance"` + AuthSourceDefaultOIDCConcurrency *int `json:"auth_source_default_oidc_concurrency"` + AuthSourceDefaultOIDCSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_oidc_subscriptions"` + AuthSourceDefaultOIDCGrantOnSignup *bool `json:"auth_source_default_oidc_grant_on_signup"` + AuthSourceDefaultOIDCGrantOnFirstBind *bool `json:"auth_source_default_oidc_grant_on_first_bind"` + AuthSourceDefaultWeChatBalance *float64 `json:"auth_source_default_wechat_balance"` + AuthSourceDefaultWeChatConcurrency *int `json:"auth_source_default_wechat_concurrency"` + AuthSourceDefaultWeChatSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_wechat_subscriptions"` + AuthSourceDefaultWeChatGrantOnSignup *bool `json:"auth_source_default_wechat_grant_on_signup"` + AuthSourceDefaultWeChatGrantOnFirstBind *bool `json:"auth_source_default_wechat_grant_on_first_bind"` + AuthSourceDefaultGitHubBalance *float64 `json:"auth_source_default_github_balance"` + AuthSourceDefaultGitHubConcurrency *int `json:"auth_source_default_github_concurrency"` + AuthSourceDefaultGitHubSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_github_subscriptions"` + AuthSourceDefaultGitHubGrantOnSignup *bool `json:"auth_source_default_github_grant_on_signup"` + AuthSourceDefaultGitHubGrantOnFirstBind *bool `json:"auth_source_default_github_grant_on_first_bind"` + AuthSourceDefaultGoogleBalance *float64 `json:"auth_source_default_google_balance"` + AuthSourceDefaultGoogleConcurrency *int `json:"auth_source_default_google_concurrency"` + AuthSourceDefaultGoogleSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_google_subscriptions"` + AuthSourceDefaultGoogleGrantOnSignup *bool `json:"auth_source_default_google_grant_on_signup"` + AuthSourceDefaultGoogleGrantOnFirstBind *bool `json:"auth_source_default_google_grant_on_first_bind"` + AuthSourceDefaultDingTalkBalance *float64 `json:"auth_source_default_dingtalk_balance"` + AuthSourceDefaultDingTalkConcurrency *int `json:"auth_source_default_dingtalk_concurrency"` + AuthSourceDefaultDingTalkSubscriptions *[]dto.DefaultSubscriptionSetting `json:"auth_source_default_dingtalk_subscriptions"` + AuthSourceDefaultDingTalkGrantOnSignup *bool `json:"auth_source_default_dingtalk_grant_on_signup"` + AuthSourceDefaultDingTalkGrantOnFirstBind *bool `json:"auth_source_default_dingtalk_grant_on_first_bind"` + ForceEmailOnThirdPartySignup *bool `json:"force_email_on_third_party_signup"` // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -661,6 +704,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { req.AuthSourceDefaultLinuxDoSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultLinuxDoSubscriptions) req.AuthSourceDefaultOIDCSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultOIDCSubscriptions) req.AuthSourceDefaultWeChatSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultWeChatSubscriptions) + req.AuthSourceDefaultDingTalkSubscriptions = normalizeOptionalDefaultSubscriptions(req.AuthSourceDefaultDingTalkSubscriptions) // SMTP 配置保护:如果请求中 smtp_host 为空但数据库中已有配置,则保留已有 SMTP 配置 // 防止前端加载设置失败时空表单覆盖已保存的 SMTP 配置 @@ -777,6 +821,100 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } + // DingTalk Connect 参数验证 + // 防御性:任何写入路径上把已废弃的 corp_restriction_policy=whitelist 入参 coerce 为 none, + // 避免任何直连 admin API 的客户端把死值写回 DB(前端 UI 已无此选项)。 + req.DingTalkConnectCorpRestrictionPolicy = service.CoerceDingTalkCorpPolicyForWrite(req.DingTalkConnectCorpRestrictionPolicy) + + if req.DingTalkConnectEnabled { + req.DingTalkConnectClientID = strings.TrimSpace(req.DingTalkConnectClientID) + req.DingTalkConnectClientSecret = strings.TrimSpace(req.DingTalkConnectClientSecret) + req.DingTalkConnectRedirectURL = strings.TrimSpace(req.DingTalkConnectRedirectURL) + req.DingTalkConnectCorpRestrictionPolicy = strings.TrimSpace(req.DingTalkConnectCorpRestrictionPolicy) + req.DingTalkConnectInternalCorpID = strings.TrimSpace(req.DingTalkConnectInternalCorpID) + + if req.DingTalkConnectClientID == "" { + response.BadRequest(c, "DingTalk Client ID is required when enabled") + return + } + if req.DingTalkConnectRedirectURL == "" { + response.BadRequest(c, "DingTalk Redirect URL is required when enabled") + return + } + if err := config.ValidateAbsoluteHTTPURL(req.DingTalkConnectRedirectURL); err != nil { + response.BadRequest(c, "DingTalk Redirect URL must be an absolute http(s) URL") + return + } + + // 如果未提供 client_secret,则保留现有值(如有)。 + if req.DingTalkConnectClientSecret == "" { + if previousSettings.DingTalkConnectClientSecret == "" { + response.BadRequest(c, "DingTalk Client Secret is required when enabled") + return + } + req.DingTalkConnectClientSecret = previousSettings.DingTalkConnectClientSecret + } + + // Corp 策略校验(V1/V4 fail-closed) + dingTalkCfg := config.DingTalkConnectConfig{ + Enabled: true, + DingTalkAppKind: "internal_app", // 硬编码:settings 层仅支持 internal_app + AppType: "internal", // 对于 internal_only 策略的默认值 + CorpRestrictionPolicy: req.DingTalkConnectCorpRestrictionPolicy, + InternalCorpID: req.DingTalkConnectInternalCorpID, + } + // 若未填 corp_restriction_policy,保留已有配置 + if dingTalkCfg.CorpRestrictionPolicy == "" { + dingTalkCfg.CorpRestrictionPolicy = previousSettings.DingTalkConnectCorpRestrictionPolicy + } + // 对于 internal_only 策略,app_type 必须为 internal(V1 校验) + if dingTalkCfg.CorpRestrictionPolicy == "internal_only" { + dingTalkCfg.AppType = "internal" + } else { + dingTalkCfg.AppType = "public" + } + if err := config.ValidateDingTalkConfig(dingTalkCfg); err != nil { + response.ErrorWithDetails(c, http.StatusBadRequest, err.Error(), mapDingTalkValidateError(err), nil) + return + } + + // bypass_registration 仅在 internal_only 模式下有意义;其它策略下强制为 false, + // 防止 admin 在切换 policy 时把 bypass 残留在 DB 中(前端 UI 也已隐藏该开关)。 + if dingTalkCfg.CorpRestrictionPolicy != "internal_only" { + req.DingTalkConnectBypassRegistration = false + // 身份同步三开关同理:仅 internal_only 模式下有意义,其它策略强制 false。 + req.DingTalkConnectSyncCorpEmail = false + req.DingTalkConnectSyncDisplayName = false + req.DingTalkConnectSyncDept = false + } + // 身份同步目标 attr key:trimSpace + 空值 fallback 到默认值 + req.DingTalkConnectSyncCorpEmailAttrKey = strings.TrimSpace(req.DingTalkConnectSyncCorpEmailAttrKey) + if req.DingTalkConnectSyncCorpEmailAttrKey == "" { + req.DingTalkConnectSyncCorpEmailAttrKey = "dingtalk_email" + } + req.DingTalkConnectSyncDisplayNameAttrKey = strings.TrimSpace(req.DingTalkConnectSyncDisplayNameAttrKey) + if req.DingTalkConnectSyncDisplayNameAttrKey == "" { + req.DingTalkConnectSyncDisplayNameAttrKey = "dingtalk_name" + } + req.DingTalkConnectSyncDeptAttrKey = strings.TrimSpace(req.DingTalkConnectSyncDeptAttrKey) + if req.DingTalkConnectSyncDeptAttrKey == "" { + req.DingTalkConnectSyncDeptAttrKey = "dingtalk_department" + } + // 身份同步目标 attr 显示名称:trim + 空值 fallback 到默认中文名 + req.DingTalkConnectSyncCorpEmailAttrName = strings.TrimSpace(req.DingTalkConnectSyncCorpEmailAttrName) + if req.DingTalkConnectSyncCorpEmailAttrName == "" { + req.DingTalkConnectSyncCorpEmailAttrName = "钉钉企业邮箱" + } + req.DingTalkConnectSyncDisplayNameAttrName = strings.TrimSpace(req.DingTalkConnectSyncDisplayNameAttrName) + if req.DingTalkConnectSyncDisplayNameAttrName == "" { + req.DingTalkConnectSyncDisplayNameAttrName = "钉钉姓名" + } + req.DingTalkConnectSyncDeptAttrName = strings.TrimSpace(req.DingTalkConnectSyncDeptAttrName) + if req.DingTalkConnectSyncDeptAttrName == "" { + req.DingTalkConnectSyncDeptAttrName = "钉钉部门" + } + } + if req.WeChatConnectEnabled { req.WeChatConnectAppID = strings.TrimSpace(req.WeChatConnectAppID) req.WeChatConnectAppSecret = strings.TrimSpace(req.WeChatConnectAppSecret) @@ -1272,113 +1410,129 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } settings := &service.SystemSettings{ - RegistrationEnabled: req.RegistrationEnabled, - EmailVerifyEnabled: req.EmailVerifyEnabled, - RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist, - PromoCodeEnabled: req.PromoCodeEnabled, - PasswordResetEnabled: req.PasswordResetEnabled, - FrontendURL: req.FrontendURL, - InvitationCodeEnabled: req.InvitationCodeEnabled, - TotpEnabled: req.TotpEnabled, - LoginAgreementEnabled: req.LoginAgreementEnabled, - LoginAgreementMode: loginAgreementMode, - LoginAgreementUpdatedAt: loginAgreementUpdatedAt, - LoginAgreementDocuments: loginAgreementDocuments, - SMTPHost: req.SMTPHost, - SMTPPort: req.SMTPPort, - SMTPUsername: req.SMTPUsername, - SMTPPassword: req.SMTPPassword, - SMTPFrom: req.SMTPFrom, - SMTPFromName: req.SMTPFromName, - SMTPUseTLS: req.SMTPUseTLS, - TurnstileEnabled: req.TurnstileEnabled, - TurnstileSiteKey: req.TurnstileSiteKey, - TurnstileSecretKey: req.TurnstileSecretKey, - LinuxDoConnectEnabled: req.LinuxDoConnectEnabled, - LinuxDoConnectClientID: req.LinuxDoConnectClientID, - LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, - LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, - WeChatConnectEnabled: req.WeChatConnectEnabled, - WeChatConnectAppID: req.WeChatConnectAppID, - WeChatConnectAppSecret: req.WeChatConnectAppSecret, - WeChatConnectOpenAppID: req.WeChatConnectOpenAppID, - WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret, - WeChatConnectMPAppID: req.WeChatConnectMPAppID, - WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret, - WeChatConnectMobileAppID: req.WeChatConnectMobileAppID, - WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret, - WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled, - WeChatConnectMPEnabled: req.WeChatConnectMPEnabled, - WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled, - WeChatConnectMode: req.WeChatConnectMode, - WeChatConnectScopes: req.WeChatConnectScopes, - WeChatConnectRedirectURL: req.WeChatConnectRedirectURL, - WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL, - OIDCConnectEnabled: req.OIDCConnectEnabled, - OIDCConnectProviderName: req.OIDCConnectProviderName, - OIDCConnectClientID: req.OIDCConnectClientID, - OIDCConnectClientSecret: req.OIDCConnectClientSecret, - OIDCConnectIssuerURL: req.OIDCConnectIssuerURL, - OIDCConnectDiscoveryURL: req.OIDCConnectDiscoveryURL, - OIDCConnectAuthorizeURL: req.OIDCConnectAuthorizeURL, - OIDCConnectTokenURL: req.OIDCConnectTokenURL, - OIDCConnectUserInfoURL: req.OIDCConnectUserInfoURL, - OIDCConnectJWKSURL: req.OIDCConnectJWKSURL, - OIDCConnectScopes: req.OIDCConnectScopes, - OIDCConnectRedirectURL: req.OIDCConnectRedirectURL, - OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL, - OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod, - OIDCConnectUsePKCE: oidcUsePKCE, - OIDCConnectValidateIDToken: oidcValidateIDToken, - OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs, - OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds, - OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified, - OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath, - OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath, - OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath, - GitHubOAuthEnabled: req.GitHubOAuthEnabled, - GitHubOAuthClientID: req.GitHubOAuthClientID, - GitHubOAuthClientSecret: req.GitHubOAuthClientSecret, - GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL, - GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL, - GoogleOAuthEnabled: req.GoogleOAuthEnabled, - GoogleOAuthClientID: req.GoogleOAuthClientID, - GoogleOAuthClientSecret: req.GoogleOAuthClientSecret, - GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL, - GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL, - SiteName: req.SiteName, - SiteLogo: req.SiteLogo, - SiteSubtitle: req.SiteSubtitle, - APIBaseURL: req.APIBaseURL, - ContactInfo: req.ContactInfo, - DocURL: req.DocURL, - HomeContent: req.HomeContent, - HideCcsImportButton: req.HideCcsImportButton, - PurchaseSubscriptionEnabled: purchaseEnabled, - PurchaseSubscriptionURL: purchaseURL, - TableDefaultPageSize: req.TableDefaultPageSize, - TablePageSizeOptions: req.TablePageSizeOptions, - CustomMenuItems: customMenuJSON, - CustomEndpoints: customEndpointsJSON, - DefaultConcurrency: req.DefaultConcurrency, - DefaultBalance: req.DefaultBalance, - AffiliateRebateRate: affiliateRebateRate, - AffiliateRebateFreezeHours: affiliateRebateFreezeHours, - AffiliateRebateDurationDays: affiliateRebateDurationDays, - AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap, - DefaultUserRPMLimit: req.DefaultUserRPMLimit, - DefaultSubscriptions: defaultSubscriptions, - EnableModelFallback: req.EnableModelFallback, - FallbackModelAnthropic: req.FallbackModelAnthropic, - FallbackModelOpenAI: req.FallbackModelOpenAI, - FallbackModelGemini: req.FallbackModelGemini, - FallbackModelAntigravity: req.FallbackModelAntigravity, - EnableIdentityPatch: req.EnableIdentityPatch, - IdentityPatchPrompt: req.IdentityPatchPrompt, - MinClaudeCodeVersion: req.MinClaudeCodeVersion, - MaxClaudeCodeVersion: req.MaxClaudeCodeVersion, - AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling, - BackendModeEnabled: req.BackendModeEnabled, + RegistrationEnabled: req.RegistrationEnabled, + EmailVerifyEnabled: req.EmailVerifyEnabled, + RegistrationEmailSuffixWhitelist: req.RegistrationEmailSuffixWhitelist, + PromoCodeEnabled: req.PromoCodeEnabled, + PasswordResetEnabled: req.PasswordResetEnabled, + FrontendURL: req.FrontendURL, + InvitationCodeEnabled: req.InvitationCodeEnabled, + TotpEnabled: req.TotpEnabled, + LoginAgreementEnabled: req.LoginAgreementEnabled, + LoginAgreementMode: loginAgreementMode, + LoginAgreementUpdatedAt: loginAgreementUpdatedAt, + LoginAgreementDocuments: loginAgreementDocuments, + SMTPHost: req.SMTPHost, + SMTPPort: req.SMTPPort, + SMTPUsername: req.SMTPUsername, + SMTPPassword: req.SMTPPassword, + SMTPFrom: req.SMTPFrom, + SMTPFromName: req.SMTPFromName, + SMTPUseTLS: req.SMTPUseTLS, + TurnstileEnabled: req.TurnstileEnabled, + TurnstileSiteKey: req.TurnstileSiteKey, + TurnstileSecretKey: req.TurnstileSecretKey, + LinuxDoConnectEnabled: req.LinuxDoConnectEnabled, + LinuxDoConnectClientID: req.LinuxDoConnectClientID, + LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, + LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, + DingTalkConnectEnabled: req.DingTalkConnectEnabled, + DingTalkConnectClientID: req.DingTalkConnectClientID, + DingTalkConnectClientSecret: req.DingTalkConnectClientSecret, + DingTalkConnectRedirectURL: req.DingTalkConnectRedirectURL, + DingTalkConnectCorpRestrictionPolicy: req.DingTalkConnectCorpRestrictionPolicy, + DingTalkConnectInternalCorpID: req.DingTalkConnectInternalCorpID, + DingTalkConnectBypassRegistration: req.DingTalkConnectBypassRegistration, + DingTalkConnectSyncCorpEmail: req.DingTalkConnectSyncCorpEmail, + DingTalkConnectSyncDisplayName: req.DingTalkConnectSyncDisplayName, + DingTalkConnectSyncDept: req.DingTalkConnectSyncDept, + DingTalkConnectSyncCorpEmailAttrKey: req.DingTalkConnectSyncCorpEmailAttrKey, + DingTalkConnectSyncDisplayNameAttrKey: req.DingTalkConnectSyncDisplayNameAttrKey, + DingTalkConnectSyncDeptAttrKey: req.DingTalkConnectSyncDeptAttrKey, + DingTalkConnectSyncCorpEmailAttrName: req.DingTalkConnectSyncCorpEmailAttrName, + DingTalkConnectSyncDisplayNameAttrName: req.DingTalkConnectSyncDisplayNameAttrName, + DingTalkConnectSyncDeptAttrName: req.DingTalkConnectSyncDeptAttrName, + WeChatConnectEnabled: req.WeChatConnectEnabled, + WeChatConnectAppID: req.WeChatConnectAppID, + WeChatConnectAppSecret: req.WeChatConnectAppSecret, + WeChatConnectOpenAppID: req.WeChatConnectOpenAppID, + WeChatConnectOpenAppSecret: req.WeChatConnectOpenAppSecret, + WeChatConnectMPAppID: req.WeChatConnectMPAppID, + WeChatConnectMPAppSecret: req.WeChatConnectMPAppSecret, + WeChatConnectMobileAppID: req.WeChatConnectMobileAppID, + WeChatConnectMobileAppSecret: req.WeChatConnectMobileAppSecret, + WeChatConnectOpenEnabled: req.WeChatConnectOpenEnabled, + WeChatConnectMPEnabled: req.WeChatConnectMPEnabled, + WeChatConnectMobileEnabled: req.WeChatConnectMobileEnabled, + WeChatConnectMode: req.WeChatConnectMode, + WeChatConnectScopes: req.WeChatConnectScopes, + WeChatConnectRedirectURL: req.WeChatConnectRedirectURL, + WeChatConnectFrontendRedirectURL: req.WeChatConnectFrontendRedirectURL, + OIDCConnectEnabled: req.OIDCConnectEnabled, + OIDCConnectProviderName: req.OIDCConnectProviderName, + OIDCConnectClientID: req.OIDCConnectClientID, + OIDCConnectClientSecret: req.OIDCConnectClientSecret, + OIDCConnectIssuerURL: req.OIDCConnectIssuerURL, + OIDCConnectDiscoveryURL: req.OIDCConnectDiscoveryURL, + OIDCConnectAuthorizeURL: req.OIDCConnectAuthorizeURL, + OIDCConnectTokenURL: req.OIDCConnectTokenURL, + OIDCConnectUserInfoURL: req.OIDCConnectUserInfoURL, + OIDCConnectJWKSURL: req.OIDCConnectJWKSURL, + OIDCConnectScopes: req.OIDCConnectScopes, + OIDCConnectRedirectURL: req.OIDCConnectRedirectURL, + OIDCConnectFrontendRedirectURL: req.OIDCConnectFrontendRedirectURL, + OIDCConnectTokenAuthMethod: req.OIDCConnectTokenAuthMethod, + OIDCConnectUsePKCE: oidcUsePKCE, + OIDCConnectValidateIDToken: oidcValidateIDToken, + OIDCConnectAllowedSigningAlgs: req.OIDCConnectAllowedSigningAlgs, + OIDCConnectClockSkewSeconds: req.OIDCConnectClockSkewSeconds, + OIDCConnectRequireEmailVerified: req.OIDCConnectRequireEmailVerified, + OIDCConnectUserInfoEmailPath: req.OIDCConnectUserInfoEmailPath, + OIDCConnectUserInfoIDPath: req.OIDCConnectUserInfoIDPath, + OIDCConnectUserInfoUsernamePath: req.OIDCConnectUserInfoUsernamePath, + GitHubOAuthEnabled: req.GitHubOAuthEnabled, + GitHubOAuthClientID: req.GitHubOAuthClientID, + GitHubOAuthClientSecret: req.GitHubOAuthClientSecret, + GitHubOAuthRedirectURL: req.GitHubOAuthRedirectURL, + GitHubOAuthFrontendRedirectURL: req.GitHubOAuthFrontendRedirectURL, + GoogleOAuthEnabled: req.GoogleOAuthEnabled, + GoogleOAuthClientID: req.GoogleOAuthClientID, + GoogleOAuthClientSecret: req.GoogleOAuthClientSecret, + GoogleOAuthRedirectURL: req.GoogleOAuthRedirectURL, + GoogleOAuthFrontendRedirectURL: req.GoogleOAuthFrontendRedirectURL, + SiteName: req.SiteName, + SiteLogo: req.SiteLogo, + SiteSubtitle: req.SiteSubtitle, + APIBaseURL: req.APIBaseURL, + ContactInfo: req.ContactInfo, + DocURL: req.DocURL, + HomeContent: req.HomeContent, + HideCcsImportButton: req.HideCcsImportButton, + PurchaseSubscriptionEnabled: purchaseEnabled, + PurchaseSubscriptionURL: purchaseURL, + TableDefaultPageSize: req.TableDefaultPageSize, + TablePageSizeOptions: req.TablePageSizeOptions, + CustomMenuItems: customMenuJSON, + CustomEndpoints: customEndpointsJSON, + DefaultConcurrency: req.DefaultConcurrency, + DefaultBalance: req.DefaultBalance, + AffiliateRebateRate: affiliateRebateRate, + AffiliateRebateFreezeHours: affiliateRebateFreezeHours, + AffiliateRebateDurationDays: affiliateRebateDurationDays, + AffiliateRebatePerInviteeCap: affiliateRebatePerInviteeCap, + DefaultUserRPMLimit: req.DefaultUserRPMLimit, + DefaultSubscriptions: defaultSubscriptions, + EnableModelFallback: req.EnableModelFallback, + FallbackModelAnthropic: req.FallbackModelAnthropic, + FallbackModelOpenAI: req.FallbackModelOpenAI, + FallbackModelGemini: req.FallbackModelGemini, + FallbackModelAntigravity: req.FallbackModelAntigravity, + EnableIdentityPatch: req.EnableIdentityPatch, + IdentityPatchPrompt: req.IdentityPatchPrompt, + MinClaudeCodeVersion: req.MinClaudeCodeVersion, + MaxClaudeCodeVersion: req.MaxClaudeCodeVersion, + AllowUngroupedKeyScheduling: req.AllowUngroupedKeyScheduling, + BackendModeEnabled: req.BackendModeEnabled, OpsMonitoringEnabled: func() bool { if req.OpsMonitoringEnabled != nil { return *req.OpsMonitoringEnabled @@ -1574,6 +1728,13 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnSignup, previousAuthSourceDefaults.Google.GrantOnSignup), GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultGoogleGrantOnFirstBind, previousAuthSourceDefaults.Google.GrantOnFirstBind), }, + DingTalk: service.ProviderDefaultGrantSettings{ + Balance: float64ValueOrDefault(req.AuthSourceDefaultDingTalkBalance, previousAuthSourceDefaults.DingTalk.Balance), + Concurrency: intValueOrDefault(req.AuthSourceDefaultDingTalkConcurrency, previousAuthSourceDefaults.DingTalk.Concurrency), + Subscriptions: defaultSubscriptionsValueOrDefault(req.AuthSourceDefaultDingTalkSubscriptions, previousAuthSourceDefaults.DingTalk.Subscriptions), + GrantOnSignup: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnSignup, previousAuthSourceDefaults.DingTalk.GrantOnSignup), + GrantOnFirstBind: boolValueOrDefault(req.AuthSourceDefaultDingTalkGrantOnFirstBind, previousAuthSourceDefaults.DingTalk.GrantOnFirstBind), + }, ForceEmailOnThirdPartySignup: boolValueOrDefault(req.ForceEmailOnThirdPartySignup, previousAuthSourceDefaults.ForceEmailOnThirdPartySignup), } if err := h.settingService.UpdateSettingsWithAuthSourceDefaults(c.Request.Context(), settings, authSourceDefaults); err != nil { @@ -1632,6 +1793,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } + h.ensureDingTalkSyncAttributes(c.Request.Context(), updatedSettings) updatedAuthSourceDefaults, err := h.settingService.GetAuthSourceDefaultSettings(c.Request.Context()) if err != nil { response.ErrorFrom(c, err) @@ -1682,6 +1844,22 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID, LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured, LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL, + DingTalkConnectEnabled: updatedSettings.DingTalkConnectEnabled, + DingTalkConnectClientID: updatedSettings.DingTalkConnectClientID, + DingTalkConnectClientSecretConfigured: updatedSettings.DingTalkConnectClientSecretConfigured, + DingTalkConnectRedirectURL: updatedSettings.DingTalkConnectRedirectURL, + DingTalkConnectCorpRestrictionPolicy: updatedSettings.DingTalkConnectCorpRestrictionPolicy, + DingTalkConnectInternalCorpID: updatedSettings.DingTalkConnectInternalCorpID, + DingTalkConnectBypassRegistration: updatedSettings.DingTalkConnectBypassRegistration, + DingTalkConnectSyncCorpEmail: updatedSettings.DingTalkConnectSyncCorpEmail, + DingTalkConnectSyncDisplayName: updatedSettings.DingTalkConnectSyncDisplayName, + DingTalkConnectSyncDept: updatedSettings.DingTalkConnectSyncDept, + DingTalkConnectSyncCorpEmailAttrKey: updatedSettings.DingTalkConnectSyncCorpEmailAttrKey, + DingTalkConnectSyncDisplayNameAttrKey: updatedSettings.DingTalkConnectSyncDisplayNameAttrKey, + DingTalkConnectSyncDeptAttrKey: updatedSettings.DingTalkConnectSyncDeptAttrKey, + DingTalkConnectSyncCorpEmailAttrName: updatedSettings.DingTalkConnectSyncCorpEmailAttrName, + DingTalkConnectSyncDisplayNameAttrName: updatedSettings.DingTalkConnectSyncDisplayNameAttrName, + DingTalkConnectSyncDeptAttrName: updatedSettings.DingTalkConnectSyncDeptAttrName, WeChatConnectEnabled: updatedSettings.WeChatConnectEnabled, WeChatConnectAppID: updatedSettings.WeChatConnectAppID, WeChatConnectAppSecretConfigured: updatedSettings.WeChatConnectAppSecretConfigured, @@ -1822,6 +2000,18 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } // hasPaymentFields returns true if any payment-related field was explicitly provided. +// mapDingTalkValidateError maps ValidateDingTalkConfig errors to machine-readable reason codes. +func mapDingTalkValidateError(err error) string { + switch { + case errors.Is(err, config.ErrDingTalkV1AppTypeMismatch): + return "dingtalk_apptype_mismatch" + case errors.Is(err, config.ErrDingTalkV4InvalidAppKind): + return "dingtalk_app_kind_invalid" + default: + return "dingtalk_corp_config_invalid" + } +} + func hasPaymentFields(req UpdateSettingsRequest) bool { return req.PaymentEnabled != nil || req.PaymentMinAmount != nil || req.PaymentMaxAmount != nil || req.PaymentDailyLimit != nil || @@ -1935,6 +2125,45 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL { changed = append(changed, "linuxdo_connect_redirect_url") } + if before.DingTalkConnectEnabled != after.DingTalkConnectEnabled { + changed = append(changed, "dingtalk_connect_enabled") + } + if before.DingTalkConnectClientID != after.DingTalkConnectClientID { + changed = append(changed, "dingtalk_connect_client_id") + } + if req.DingTalkConnectClientSecret != "" { + changed = append(changed, "dingtalk_connect_client_secret") + } + if before.DingTalkConnectRedirectURL != after.DingTalkConnectRedirectURL { + changed = append(changed, "dingtalk_connect_redirect_url") + } + if before.DingTalkConnectCorpRestrictionPolicy != after.DingTalkConnectCorpRestrictionPolicy { + changed = append(changed, "dingtalk_connect_corp_restriction_policy") + } + if before.DingTalkConnectInternalCorpID != after.DingTalkConnectInternalCorpID { + changed = append(changed, "dingtalk_connect_internal_corp_id") + } + if before.DingTalkConnectBypassRegistration != after.DingTalkConnectBypassRegistration { + changed = append(changed, "dingtalk_connect_bypass_registration") + } + if before.DingTalkConnectSyncCorpEmail != after.DingTalkConnectSyncCorpEmail { + changed = append(changed, "dingtalk_connect_sync_corp_email") + } + if before.DingTalkConnectSyncDisplayName != after.DingTalkConnectSyncDisplayName { + changed = append(changed, "dingtalk_connect_sync_display_name") + } + if before.DingTalkConnectSyncDept != after.DingTalkConnectSyncDept { + changed = append(changed, "dingtalk_connect_sync_dept") + } + if before.DingTalkConnectSyncCorpEmailAttrKey != after.DingTalkConnectSyncCorpEmailAttrKey { + changed = append(changed, "dingtalk_connect_sync_corp_email_attr_key") + } + if before.DingTalkConnectSyncDisplayNameAttrKey != after.DingTalkConnectSyncDisplayNameAttrKey { + changed = append(changed, "dingtalk_connect_sync_display_name_attr_key") + } + if before.DingTalkConnectSyncDeptAttrKey != after.DingTalkConnectSyncDeptAttrKey { + changed = append(changed, "dingtalk_connect_sync_dept_attr_key") + } if before.WeChatConnectEnabled != after.WeChatConnectEnabled { changed = append(changed, "wechat_connect_enabled") } @@ -2246,6 +2475,7 @@ func appendAuthSourceDefaultChanges(changed []string, before *service.AuthSource {name: "wechat", before: before.WeChat, after: after.WeChat}, {name: "github", before: before.GitHub, after: after.GitHub}, {name: "google", before: before.Google, after: after.Google}, + {name: "dingtalk", before: before.DingTalk, after: after.DingTalk}, } for _, field := range fields { if field.before.Balance != field.after.Balance { @@ -2350,6 +2580,11 @@ func systemSettingsResponseData(settings dto.SystemSettings, authSourceDefaults data["auth_source_default_linuxdo_subscriptions"] = authSourceDefaults.LinuxDo.Subscriptions data["auth_source_default_linuxdo_grant_on_signup"] = authSourceDefaults.LinuxDo.GrantOnSignup data["auth_source_default_linuxdo_grant_on_first_bind"] = authSourceDefaults.LinuxDo.GrantOnFirstBind + data["auth_source_default_dingtalk_balance"] = authSourceDefaults.DingTalk.Balance + data["auth_source_default_dingtalk_concurrency"] = authSourceDefaults.DingTalk.Concurrency + data["auth_source_default_dingtalk_subscriptions"] = authSourceDefaults.DingTalk.Subscriptions + data["auth_source_default_dingtalk_grant_on_signup"] = authSourceDefaults.DingTalk.GrantOnSignup + data["auth_source_default_dingtalk_grant_on_first_bind"] = authSourceDefaults.DingTalk.GrantOnFirstBind data["auth_source_default_oidc_balance"] = authSourceDefaults.OIDC.Balance data["auth_source_default_oidc_concurrency"] = authSourceDefaults.OIDC.Concurrency data["auth_source_default_oidc_subscriptions"] = authSourceDefaults.OIDC.Subscriptions @@ -3044,3 +3279,56 @@ func (h *SettingHandler) TestWebSearchEmulation(c *gin.Context) { } response.Success(c, result) } + +// ensureDingTalkSyncAttributes 在保存 settings 后,按 admin 配置的 (attr key, attr name) +// 兜底 upsert 对应 user attribute definition:不存在则创建;存在但 name 不同则更新 name +// (type/options/required 不变)。仅 internal_only + 对应 sync 开关开启时执行。 +// 失败仅记录日志,不阻塞 settings 保存。 +func (h *SettingHandler) ensureDingTalkSyncAttributes(ctx context.Context, settings *service.SystemSettings) { + if h.userAttributeService == nil || settings == nil { + return + } + if settings.DingTalkConnectCorpRestrictionPolicy != "internal_only" { + return + } + if settings.DingTalkConnectSyncDisplayName { + h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncDisplayNameAttrKey, settings.DingTalkConnectSyncDisplayNameAttrName, "钉钉 internal_only 登录时同步的钉钉姓名", service.AttributeTypeText) + } + if settings.DingTalkConnectSyncCorpEmail { + h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncCorpEmailAttrKey, settings.DingTalkConnectSyncCorpEmailAttrName, "钉钉 internal_only 登录时同步的企业邮箱", service.AttributeTypeEmail) + } + if settings.DingTalkConnectSyncDept { + h.ensureUserAttributeDefinition(ctx, settings.DingTalkConnectSyncDeptAttrKey, settings.DingTalkConnectSyncDeptAttrName, "钉钉 internal_only 登录时同步的完整部门路径(如:公司/研发部)", service.AttributeTypeText) + } +} + +func (h *SettingHandler) ensureUserAttributeDefinition(ctx context.Context, key, name, description string, attrType service.UserAttributeType) { + key = strings.TrimSpace(key) + if key == "" { + return + } + existing, err := h.userAttributeService.GetDefinitionByKey(ctx, key) + if err == nil && existing != nil { + if strings.TrimSpace(name) != "" && existing.Name != name { + if _, err := h.userAttributeService.UpdateDefinition(ctx, existing.ID, service.UpdateAttributeDefinitionInput{ + Name: &name, + }); err != nil { + slog.Warn("dingtalk: update user attribute definition name failed", "key", key, "err", err.Error()) + return + } + slog.Info("dingtalk: updated user attribute definition name", "key", key, "name", name) + } + return + } + if _, err := h.userAttributeService.CreateDefinition(ctx, service.CreateAttributeDefinitionInput{ + Key: key, + Name: name, + Description: description, + Type: attrType, + Enabled: true, + }); err != nil { + slog.Warn("dingtalk: ensure user attribute definition failed", "key", key, "err", err.Error()) + return + } + slog.Info("dingtalk: created user attribute definition", "key", key, "name", name, "type", attrType) +} diff --git a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go index 085fd2ca788..f953f76760f 100644 --- a/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go +++ b/backend/internal/handler/admin/setting_handler_auth_source_defaults_test.go @@ -137,7 +137,7 @@ func TestSettingHandler_GetSettings_InjectsAuthSourceDefaults(t *testing.T) { }, } svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) - handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil) rec := httptest.NewRecorder() c, _ := gin.CreateTestContext(rec) @@ -174,7 +174,7 @@ func TestSettingHandler_UpdateSettings_PreservesOmittedAuthSourceDefaults(t *tes }, } svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) - handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil) body := map[string]any{ "registration_enabled": true, @@ -214,7 +214,7 @@ func TestSettingHandler_UpdateSettings_PersistsPaymentVisibleMethodsAndAdvancedS }, } svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) - handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil) body := map[string]any{ "promo_code_enabled": true, @@ -264,7 +264,7 @@ func TestSettingHandler_UpdateSettings_PreservesLegacyBlankPaymentVisibleMethodS }, } svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) - handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil) body := map[string]any{ "promo_code_enabled": false, @@ -309,7 +309,7 @@ func TestSettingHandler_UpdateSettings_PersistsExplicitFalseOIDCCompatibilityFla }, } svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) - handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil) body := map[string]any{ "promo_code_enabled": true, @@ -388,7 +388,7 @@ func TestSettingHandler_UpdateSettings_DoesNotSolidifyImplicitOIDCSecurityDefaul ClockSkewSeconds: 120, }, }) - handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil) body := map[string]any{ "promo_code_enabled": true, @@ -417,7 +417,7 @@ func TestSettingHandler_UpdateSettings_RejectsInvalidPaymentVisibleMethodSource( }, } svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) - handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil) body := map[string]any{ "promo_code_enabled": true, @@ -450,7 +450,7 @@ func TestSettingHandler_UpdateSettings_DoesNotPersistPartialSystemSettingsWhenAu err: errors.New("write auth source defaults failed"), } svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) - handler := NewSettingHandler(svc, nil, nil, nil, nil, nil) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil) body := map[string]any{ "registration_enabled": true, diff --git a/backend/internal/handler/admin/setting_handler_dingtalk_test.go b/backend/internal/handler/admin/setting_handler_dingtalk_test.go new file mode 100644 index 00000000000..a3d944ccf0e --- /dev/null +++ b/backend/internal/handler/admin/setting_handler_dingtalk_test.go @@ -0,0 +1,319 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// dingtalkSettingsRepoStub 复用 settingHandlerRepoStub(已在 setting_handler_auth_source_defaults_test.go 定义) + +func newDingTalkSettingsHandler() (*SettingHandler, *settingHandlerRepoStub) { + repo := &settingHandlerRepoStub{values: map[string]string{}} + svc := service.NewSettingService(repo, &config.Config{Default: config.DefaultConfig{UserConcurrency: 5}}) + handler := NewSettingHandler(svc, nil, nil, nil, nil, nil, nil) + return handler, repo +} + +// baseValidDingTalkBody 返回一个可以通过所有校验的最小合法 body。 +func baseValidDingTalkBody() map[string]any { + return map[string]any{ + "dingtalk_connect_enabled": true, + "dingtalk_connect_client_id": "test-client-id", + "dingtalk_connect_client_secret": "test-client-secret", + "dingtalk_connect_redirect_url": "https://example.com/auth/dingtalk/callback", + "dingtalk_connect_corp_restriction_policy": "none", + } +} + +// TestSettingsPUT_DingTalk_V3_InternalOnlyAllowsEmptyCorpID 验证方案 A: +// internal_only + internal_corp_id="" 应通过校验(→ 200),不再是 400。 +func TestSettingsPUT_DingTalk_V3_InternalOnlyAllowsEmptyCorpID(t *testing.T) { + gin.SetMode(gin.TestMode) + handler, _ := newDingTalkSettingsHandler() + + body := baseValidDingTalkBody() + body["dingtalk_connect_corp_restriction_policy"] = "internal_only" + body["dingtalk_connect_internal_corp_id"] = "" // 空值现在合法 + + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) +} + +// TestSettingsPUT_DingTalk_HappyPath_None 验证 none policy → 200 +func TestSettingsPUT_DingTalk_HappyPath_None(t *testing.T) { + gin.SetMode(gin.TestMode) + handler, _ := newDingTalkSettingsHandler() + + body := baseValidDingTalkBody() + body["dingtalk_connect_corp_restriction_policy"] = "none" + + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, true, data["dingtalk_connect_enabled"]) +} + +// TestSettingsPUT_DingTalk_HappyPath_InternalOnly_WithCorpID 验证 internal_only + corp_id → 200 +func TestSettingsPUT_DingTalk_HappyPath_InternalOnly_WithCorpID(t *testing.T) { + gin.SetMode(gin.TestMode) + handler, _ := newDingTalkSettingsHandler() + + body := baseValidDingTalkBody() + body["dingtalk_connect_corp_restriction_policy"] = "internal_only" + body["dingtalk_connect_internal_corp_id"] = "ding-corp-123" + + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) +} + +// TestSettingsPUT_DingTalk_BypassRegistration_RoundTrip 验证 bypass_registration 字段 save+load。 +// 必须用 policy=internal_only:bypass 仅在该 policy 下生效,其它 policy 写入层会 coerce 为 false。 +func TestSettingsPUT_DingTalk_BypassRegistration_RoundTrip(t *testing.T) { + gin.SetMode(gin.TestMode) + handler, _ := newDingTalkSettingsHandler() + + body := baseValidDingTalkBody() + body["dingtalk_connect_corp_restriction_policy"] = "internal_only" + body["dingtalk_connect_bypass_registration"] = true + + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, true, data["dingtalk_connect_bypass_registration"]) +} + +// TestSettingsPUT_DingTalk_Disabled_SkipsValidation 验证 disabled 时跳过 corp 校验 → 200。 +// 用 enabled=true 时必然触发"Client ID is required when enabled"的空 client_id 作为 +// 哨兵——只要 enabled=false 仍能 200 就证明跳过了。 +func TestSettingsPUT_DingTalk_Disabled_SkipsValidation(t *testing.T) { + gin.SetMode(gin.TestMode) + handler, _ := newDingTalkSettingsHandler() + + body := map[string]any{ + "dingtalk_connect_enabled": false, + "dingtalk_connect_client_id": "", // 这种空值在 enabled=true 时会被 400 拒绝 + "dingtalk_connect_corp_restriction_policy": "internal_only", + } + + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) +} + +// TestSettingsPUT_DingTalk_SyncFlags_InternalOnly_RoundTrip 验证三个 sync 开关在 internal_only 下可正常 save+load。 +func TestSettingsPUT_DingTalk_SyncFlags_InternalOnly_RoundTrip(t *testing.T) { + gin.SetMode(gin.TestMode) + handler, _ := newDingTalkSettingsHandler() + + body := baseValidDingTalkBody() + body["dingtalk_connect_corp_restriction_policy"] = "internal_only" + body["dingtalk_connect_sync_corp_email"] = true + body["dingtalk_connect_sync_display_name"] = true + body["dingtalk_connect_sync_dept"] = true + + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, true, data["dingtalk_connect_sync_corp_email"], "sync_corp_email should be true for internal_only") + require.Equal(t, true, data["dingtalk_connect_sync_display_name"], "sync_display_name should be true for internal_only") + require.Equal(t, true, data["dingtalk_connect_sync_dept"], "sync_dept should be true for internal_only") +} + +// TestSettingsPUT_DingTalk_SyncFlags_PolicyNone_CoercedToFalse 验证 policy=none 时三个 sync 开关被 coerce 为 false。 +func TestSettingsPUT_DingTalk_SyncFlags_PolicyNone_CoercedToFalse(t *testing.T) { + gin.SetMode(gin.TestMode) + handler, _ := newDingTalkSettingsHandler() + + body := baseValidDingTalkBody() + body["dingtalk_connect_corp_restriction_policy"] = "none" + body["dingtalk_connect_sync_corp_email"] = true + body["dingtalk_connect_sync_display_name"] = true + body["dingtalk_connect_sync_dept"] = true + + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, false, data["dingtalk_connect_sync_corp_email"], "sync_corp_email must be coerced to false when policy=none") + require.Equal(t, false, data["dingtalk_connect_sync_display_name"], "sync_display_name must be coerced to false when policy=none") + require.Equal(t, false, data["dingtalk_connect_sync_dept"], "sync_dept must be coerced to false when policy=none") +} + +// TestSettingsPUT_DingTalk_StaleWhitelist_CoercedToNone 验证升级兼容: +// admin 直接把 corp_restriction_policy=whitelist 提交(前端 UI 已无此选项,但 API 仍可命中) +// 不应导致 400 失败,应该被静默 coerce 为 none 后通过校验。 +func TestSettingsPUT_DingTalk_StaleWhitelist_CoercedToNone(t *testing.T) { + gin.SetMode(gin.TestMode) + handler, repo := newDingTalkSettingsHandler() + + body := baseValidDingTalkBody() + body["dingtalk_connect_corp_restriction_policy"] = "whitelist" + + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "none", repo.values[service.SettingKeyDingTalkConnectCorpRestrictionPolicy], + "stale whitelist 应在写入路径被 coerce 为 none") +} + +// TestSettingsPUT_DingTalk_SyncAttrKey_RoundTrip 验证 3 个 attr key 字段 save+load + 空值 fallback 到默认值。 +func TestSettingsPUT_DingTalk_SyncAttrKey_RoundTrip(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("custom_attr_keys_saved", func(t *testing.T) { + handler, repo := newDingTalkSettingsHandler() + + body := baseValidDingTalkBody() + body["dingtalk_connect_corp_restriction_policy"] = "internal_only" + body["dingtalk_connect_sync_corp_email"] = true + body["dingtalk_connect_sync_display_name"] = true + body["dingtalk_connect_sync_dept"] = true + body["dingtalk_connect_sync_corp_email_attr_key"] = "my_email_attr" + body["dingtalk_connect_sync_display_name_attr_key"] = "my_name_attr" + body["dingtalk_connect_sync_dept_attr_key"] = "my_dept_attr" + + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + + // 验证写入 DB 的 key + require.Equal(t, "my_email_attr", repo.values[service.SettingKeyDingTalkConnectSyncCorpEmailAttrKey]) + require.Equal(t, "my_name_attr", repo.values[service.SettingKeyDingTalkConnectSyncDisplayNameAttrKey]) + require.Equal(t, "my_dept_attr", repo.values[service.SettingKeyDingTalkConnectSyncDeptAttrKey]) + + // 验证响应中的 attr key + var resp response.Response + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + data, ok := resp.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, "my_email_attr", data["dingtalk_connect_sync_corp_email_attr_key"]) + require.Equal(t, "my_name_attr", data["dingtalk_connect_sync_display_name_attr_key"]) + require.Equal(t, "my_dept_attr", data["dingtalk_connect_sync_dept_attr_key"]) + }) + + t.Run("empty_attr_keys_fallback_to_defaults", func(t *testing.T) { + handler, repo := newDingTalkSettingsHandler() + + body := baseValidDingTalkBody() + body["dingtalk_connect_corp_restriction_policy"] = "internal_only" + // 不传 attr key → 写入层 fallback 到默认值 + body["dingtalk_connect_sync_corp_email_attr_key"] = "" + body["dingtalk_connect_sync_display_name_attr_key"] = "" + body["dingtalk_connect_sync_dept_attr_key"] = "" + + rawBody, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPut, "/api/v1/admin/settings", bytes.NewReader(rawBody)) + c.Request.Header.Set("Content-Type", "application/json") + + handler.UpdateSettings(c) + + require.Equal(t, http.StatusOK, rec.Code) + + // 空值应 fallback 到默认值并持久化 + require.Equal(t, "dingtalk_email", repo.values[service.SettingKeyDingTalkConnectSyncCorpEmailAttrKey]) + require.Equal(t, "dingtalk_name", repo.values[service.SettingKeyDingTalkConnectSyncDisplayNameAttrKey]) + require.Equal(t, "dingtalk_department", repo.values[service.SettingKeyDingTalkConnectSyncDeptAttrKey]) + }) +} diff --git a/backend/internal/handler/auth_dingtalk_client.go b/backend/internal/handler/auth_dingtalk_client.go new file mode 100644 index 00000000000..2db07d05ddf --- /dev/null +++ b/backend/internal/handler/auth_dingtalk_client.go @@ -0,0 +1,398 @@ +package handler + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// dingTalkClientConfig 是 DingTalkClient 需要的最小配置子集 +type dingTalkClientConfig struct { + ClientID string + ClientSecret string + TokenURL string + UserInfoURL string +} + +type DingTalkClient struct { + cfg dingTalkClientConfig + appToken string + appTokenExp time.Time // 钉钉 7200s,留 200s 余量 → 7000s + mu sync.Mutex + httpClient *http.Client + // TODO(multi-instance): Redis 集中缓存 appToken +} + +type DingTalkUserTokenResp struct { + AccessToken string `json:"accessToken"` + RefreshToken string `json:"refreshToken"` + ExpireIn int64 `json:"expireIn"` + CorpID string `json:"corpId"` +} + +func (c *DingTalkClient) ExchangeCodeForUserToken(ctx context.Context, code string) (*DingTalkUserTokenResp, error) { + body := map[string]string{ + "clientId": c.cfg.ClientID, + "clientSecret": c.cfg.ClientSecret, + "code": code, + "grantType": "authorization_code", + } + payload, _ := json.Marshal(body) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.TokenURL, bytes.NewReader(payload)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + raw, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, parseDingTalkErr(raw, resp.StatusCode) + } + var out DingTalkUserTokenResp + if err := json.Unmarshal(raw, &out); err != nil { + return nil, err + } + if strings.TrimSpace(out.AccessToken) == "" { + return nil, parseDingTalkErr(raw, resp.StatusCode) + } + return &out, nil +} + +type DingTalkAPIError struct { + Code string + Message string + HTTP int +} + +func (e *DingTalkAPIError) Error() string { + return fmt.Sprintf("dingtalk api error code=%s msg=%s http=%d", e.Code, e.Message, e.HTTP) +} + +func parseDingTalkErr(raw []byte, status int) error { + var v struct { + Code string `json:"code"` + Message string `json:"message"` + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` + } + _ = json.Unmarshal(raw, &v) + code := v.Code + if code == "" && v.ErrCode != 0 { + code = fmt.Sprintf("%d", v.ErrCode) + } + msg := v.Message + if msg == "" { + msg = v.ErrMsg + } + return &DingTalkAPIError{Code: code, Message: msg, HTTP: status} +} + +// GetUnionIdByUserToken 调用 /v1.0/contact/users/me 返回 unionId 与用户自设昵称 nick。 +// nick 来自钉钉新版 OIDC 接口(用户在 App 个人资料填的昵称),与旧版 user/get.nickname 不同源。 +func (c *DingTalkClient) GetUnionIdByUserToken(ctx context.Context, userToken string) (unionID string, nick string, err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.cfg.UserInfoURL, nil) + if err != nil { + return "", "", err + } + req.Header.Set("x-acs-dingtalk-access-token", userToken) + resp, err := c.httpClient.Do(req) + if err != nil { + return "", "", err + } + defer func() { _ = resp.Body.Close() }() + raw, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return "", "", parseDingTalkErr(raw, resp.StatusCode) + } + var v struct { + UnionID string `json:"unionId"` + Nick string `json:"nick"` + } + if err := json.Unmarshal(raw, &v); err != nil { + return "", "", err + } + if strings.TrimSpace(v.UnionID) == "" { + return "", "", parseDingTalkErr(raw, resp.StatusCode) + } + return v.UnionID, v.Nick, nil +} + +type DingTalkStaffInfo struct { + UserID string + Name string // 企业内真实姓名(钉钉企业管理后台配置) + Nickname string // 钉钉个人昵称(用户自己设置) + Email string + DeptIDs []int64 + // CorpID 不来自 staff 接口,来自 userToken;不在此 struct +} + +// dingTalkOAPIBase 推导钉钉旧版 OAPI base URL(host: api.dingtalk.com → oapi.dingtalk.com)。 +// getbyunionid 与 topapi/v2/user/get 仅在旧版 OAPI 提供,不在 v1.0 OpenAPI。 +func (c *DingTalkClient) dingTalkOAPIBase() string { + u, err := url.Parse(c.cfg.UserInfoURL) + if err != nil || u.Scheme == "" || u.Host == "" { + return "https://oapi.dingtalk.com" + } + host := u.Host + if strings.HasPrefix(host, "api.") { + host = "oapi." + strings.TrimPrefix(host, "api.") + } + return u.Scheme + "://" + host +} + +func (c *DingTalkClient) GetAppToken(ctx context.Context) (string, error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.appToken != "" && time.Now().Before(c.appTokenExp) { + return c.appToken, nil + } + body := map[string]string{"appKey": c.cfg.ClientID, "appSecret": c.cfg.ClientSecret} + payload, _ := json.Marshal(body) + // 钉钉新版 v1.0 企业内部应用 access_token: POST /v1.0/oauth2/accessToken + // 此 token 也可作为旧版 OAPI 的 access_token 使用(钉钉文档已说明) + appTokenURL := strings.Replace(c.cfg.TokenURL, "/oauth2/userAccessToken", "/oauth2/accessToken", 1) + if !strings.Contains(appTokenURL, "accessToken") && !strings.Contains(appTokenURL, "gettoken") { + appTokenURL = c.cfg.TokenURL // fallback for test stub + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, appTokenURL, bytes.NewReader(payload)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + resp, err := c.httpClient.Do(req) + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + raw, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return "", parseDingTalkErr(raw, resp.StatusCode) + } + var v struct { + AccessToken string `json:"accessToken"` + ExpireIn int64 `json:"expireIn"` + } + if err := json.Unmarshal(raw, &v); err != nil { + return "", err + } + if v.AccessToken == "" { + return "", parseDingTalkErr(raw, resp.StatusCode) + } + c.appToken = v.AccessToken + ttl := v.ExpireIn + if ttl > 200 { + ttl -= 200 + } + c.appTokenExp = time.Now().Add(time.Duration(ttl) * time.Second) + return c.appToken, nil +} + +func (c *DingTalkClient) GetUserIdByUnionId(ctx context.Context, unionID string) (string, error) { + appToken, err := c.GetAppToken(ctx) + if err != nil { + return "", err + } + body := map[string]string{"unionid": unionID} + payload, _ := json.Marshal(body) + // 钉钉旧版 OAPI: POST https://oapi.dingtalk.com/topapi/user/getbyunionid?access_token=XXX + // access_token 通过 query string 传递(不是 header) + var targetURL string + if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") { + targetURL = c.dingTalkOAPIBase() + "/topapi/user/getbyunionid?access_token=" + url.QueryEscape(appToken) + } else { + targetURL = c.cfg.UserInfoURL // fallback for test stub + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + resp, err := c.httpClient.Do(req) + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + raw, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return "", parseDingTalkErr(raw, resp.StatusCode) + } + var v struct { + Result struct { + UserID string `json:"userid"` + } `json:"result"` + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` + } + if err := json.Unmarshal(raw, &v); err != nil { + return "", err + } + if v.ErrCode != 0 { + return "", parseDingTalkErr(raw, resp.StatusCode) + } + if strings.TrimSpace(v.Result.UserID) == "" { + return "", parseDingTalkErr(raw, resp.StatusCode) + } + return v.Result.UserID, nil +} + +// DingTalkDeptInfo 部门信息(topapi/v2/department/get 返回子集) +type DingTalkDeptInfo struct { + DeptID int64 + Name string + ParentID int64 +} + +// GetDeptInfo 查询单个部门信息(用于递归拼部门路径)。 +// 调用钉钉旧版 OAPI: POST /topapi/v2/department/get?access_token=XXX +func (c *DingTalkClient) GetDeptInfo(ctx context.Context, deptID int64) (*DingTalkDeptInfo, error) { + appToken, err := c.GetAppToken(ctx) + if err != nil { + return nil, err + } + body := map[string]any{"dept_id": deptID, "language": "zh_CN"} + payload, _ := json.Marshal(body) + var targetURL string + if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") { + targetURL = c.dingTalkOAPIBase() + "/topapi/v2/department/get?access_token=" + url.QueryEscape(appToken) + } else { + targetURL = c.cfg.UserInfoURL // test stub fallback + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + raw, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, parseDingTalkErr(raw, resp.StatusCode) + } + var v struct { + Result struct { + DeptID int64 `json:"dept_id"` + Name string `json:"name"` + ParentID int64 `json:"parent_id"` + } `json:"result"` + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` + } + if err := json.Unmarshal(raw, &v); err != nil { + return nil, err + } + if v.ErrCode != 0 { + return nil, parseDingTalkErr(raw, resp.StatusCode) + } + return &DingTalkDeptInfo{ + DeptID: v.Result.DeptID, + Name: v.Result.Name, + ParentID: v.Result.ParentID, + }, nil +} + +func (c *DingTalkClient) GetStaffInfoByUserId(ctx context.Context, userID string) (*DingTalkStaffInfo, error) { + appToken, err := c.GetAppToken(ctx) + if err != nil { + return nil, err + } + body := map[string]string{"userid": userID} + payload, _ := json.Marshal(body) + // 钉钉旧版 OAPI: POST https://oapi.dingtalk.com/topapi/v2/user/get?access_token=XXX + var targetURL string + if strings.Contains(c.cfg.UserInfoURL, "/contact/users/me") { + targetURL = c.dingTalkOAPIBase() + "/topapi/v2/user/get?access_token=" + url.QueryEscape(appToken) + } else { + targetURL = c.cfg.UserInfoURL // fallback for test stub + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(payload)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + raw, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, parseDingTalkErr(raw, resp.StatusCode) + } + var v struct { + Result struct { + UserID string `json:"userid"` + Name string `json:"name"` + Nickname string `json:"nickname"` + Email string `json:"email"` + OrgEmail string `json:"org_email"` + Extension string `json:"extension"` + DeptID []int64 `json:"dept_id_list"` + } `json:"result"` + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` + } + if err := json.Unmarshal(raw, &v); err != nil { + return nil, err + } + if v.ErrCode != 0 { + return nil, parseDingTalkErr(raw, resp.StatusCode) + } + if strings.TrimSpace(v.Result.UserID) == "" { + return nil, parseDingTalkErr(raw, resp.StatusCode) + } + // 邮箱三级 fallback:org_email > email > extension["企业邮箱"](钉钉自定义扩展字段,JSON string) + email := strings.TrimSpace(v.Result.OrgEmail) + emailSource := "org_email" + if email == "" { + email = strings.TrimSpace(v.Result.Email) + emailSource = "email" + } + extensionParsed := false + if email == "" && strings.TrimSpace(v.Result.Extension) != "" { + var ext map[string]string + if err := json.Unmarshal([]byte(v.Result.Extension), &ext); err == nil { + extensionParsed = true + if v, ok := ext["企业邮箱"]; ok { + email = strings.TrimSpace(v) + emailSource = "extension.企业邮箱" + } + } + } + if email == "" { + emailSource = "none" + } + slog.Info("dingtalk staff fetched", + "userid", v.Result.UserID, + "name_present", v.Result.Name != "", + "nickname_present", v.Result.Nickname != "", + "name_eq_nickname", v.Result.Name != "" && v.Result.Name == v.Result.Nickname, + "email_present", v.Result.Email != "", + "org_email_present", v.Result.OrgEmail != "", + "extension_present", v.Result.Extension != "", + "extension_parsed", extensionParsed, + "email_source", emailSource, + "dept_count", len(v.Result.DeptID), + ) + return &DingTalkStaffInfo{ + UserID: v.Result.UserID, + Name: v.Result.Name, + Nickname: v.Result.Nickname, + Email: email, + DeptIDs: v.Result.DeptID, + }, nil +} diff --git a/backend/internal/handler/auth_dingtalk_client_test.go b/backend/internal/handler/auth_dingtalk_client_test.go new file mode 100644 index 00000000000..aa2e2fdddda --- /dev/null +++ b/backend/internal/handler/auth_dingtalk_client_test.go @@ -0,0 +1,143 @@ +package handler + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestDingTalkClient_ExchangeCodeForUserToken_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "POST", r.Method) + require.Equal(t, "/v1.0/oauth2/userAccessToken", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"USER_TOKEN_X","expireIn":7200,"refreshToken":"R","corpId":"dingABC"}`)) + })) + defer server.Close() + + cli := &DingTalkClient{ + cfg: dingTalkClientConfig{ + ClientID: "k", ClientSecret: "s", + TokenURL: server.URL + "/v1.0/oauth2/userAccessToken", + }, + httpClient: server.Client(), + } + resp, err := cli.ExchangeCodeForUserToken(context.Background(), "AUTH_CODE") + require.NoError(t, err) + require.Equal(t, "USER_TOKEN_X", resp.AccessToken) + require.Equal(t, "dingABC", resp.CorpID) +} + +func TestDingTalkClient_GetUnionIdByUserToken_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "USER_TOKEN_X", r.Header.Get("x-acs-dingtalk-access-token")) + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"nick":"张三","unionId":"UID_AAA","openId":"OPEN","avatarUrl":"http://x"}`)) + })) + defer server.Close() + + cli := &DingTalkClient{ + cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/v1.0/contact/users/me"}, + httpClient: server.Client(), + } + unionID, nick, err := cli.GetUnionIdByUserToken(context.Background(), "USER_TOKEN_X") + require.NoError(t, err) + require.Equal(t, "UID_AAA", unionID) + require.Equal(t, "张三", nick) +} + +func TestDingTalkClient_GetAppToken_Cached(t *testing.T) { + callCount := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + _, _ = w.Write([]byte(`{"accessToken":"APP_TKN","expireIn":7200}`)) + })) + defer server.Close() + + cli := &DingTalkClient{ + cfg: dingTalkClientConfig{ClientID: "k", ClientSecret: "s", TokenURL: server.URL + "/gettoken"}, + httpClient: server.Client(), + } + t1, err := cli.GetAppToken(context.Background()) + require.NoError(t, err) + t2, err := cli.GetAppToken(context.Background()) + require.NoError(t, err) + require.Equal(t, t1, t2) + require.Equal(t, 1, callCount, "second call should hit cache") +} + +func TestDingTalkClient_GetUserIdByUnionId_60011(t *testing.T) { + appTokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte(`{"accessToken":"APP_TKN","expireIn":7200}`)) + })) + defer appTokenServer.Close() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"errcode":60011,"errmsg":"not in directory"}`)) + })) + defer server.Close() + + cli := &DingTalkClient{ + cfg: dingTalkClientConfig{TokenURL: appTokenServer.URL + "/gettoken"}, + httpClient: server.Client(), + } + cli.appToken = "APP_TKN" + cli.appTokenExp = time.Now().Add(time.Hour) + cli.cfg.UserInfoURL = server.URL + "/v1.0/contact/users/byUnionId" + + _, err := cli.GetUserIdByUnionId(context.Background(), "UID_AAA") + require.Error(t, err) + apiErr, ok := err.(*DingTalkAPIError) + require.True(t, ok) + require.Equal(t, "60011", apiErr.Code) +} + +// TestDingTalkClient_GetDeptInfo_Success 验证 GetDeptInfo 正常情况返回部门信息。 +func TestDingTalkClient_GetDeptInfo_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"errcode":0,"errmsg":"ok","result":{"dept_id":42,"name":"AI数据","parent_id":1}}`)) + })) + defer server.Close() + + cli := &DingTalkClient{ + cfg: dingTalkClientConfig{ + UserInfoURL: server.URL + "/stub", // 不含 /contact/users/me,走 test stub 路径 + }, + httpClient: server.Client(), + } + cli.appToken = "APP_TKN" + cli.appTokenExp = time.Now().Add(time.Hour) + + info, err := cli.GetDeptInfo(context.Background(), 42) + require.NoError(t, err) + require.Equal(t, int64(42), info.DeptID) + require.Equal(t, "AI数据", info.Name) + require.Equal(t, int64(1), info.ParentID) +} + +// TestDingTalkClient_GetDeptInfo_ErrCode60003 验证 errcode=60003(部门不存在)时返回错误。 +func TestDingTalkClient_GetDeptInfo_ErrCode60003(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"dept not found"}`)) + })) + defer server.Close() + + cli := &DingTalkClient{ + cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"}, + httpClient: server.Client(), + } + cli.appToken = "APP_TKN" + cli.appTokenExp = time.Now().Add(time.Hour) + + _, err := cli.GetDeptInfo(context.Background(), 999) + require.Error(t, err) + apiErr, ok := err.(*DingTalkAPIError) + require.True(t, ok) + require.Equal(t, "60003", apiErr.Code) +} diff --git a/backend/internal/handler/auth_dingtalk_oauth.go b/backend/internal/handler/auth_dingtalk_oauth.go new file mode 100644 index 00000000000..a5b27dc61cc --- /dev/null +++ b/backend/internal/handler/auth_dingtalk_oauth.go @@ -0,0 +1,1066 @@ +package handler + +import ( + "context" + "errors" + "fmt" + "log/slog" + "net/http" + "net/url" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + dbuser "github.com/Wei-Shaw/sub2api/ent/user" + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +// dingTalkUpstreamRedirect 在 4 步链上游调用失败时记录详细错误日志并跳错误页。 +// 把钉钉 errcode/errmsg 写进 backend log + URL fragment,避免被泛 "internal error" 吞掉。 +func dingTalkUpstreamRedirect(c *gin.Context, frontendCallback, step string, err error) { + var apiErr *DingTalkAPIError + dtCode := "" + dtMsg := "" + dtHTTP := 0 + if errors.As(err, &apiErr) { + dtCode = apiErr.Code + dtMsg = apiErr.Message + dtHTTP = apiErr.HTTP + } + slog.Error("dingtalk upstream call failed", + "step", step, + "dingtalk_code", dtCode, + "dingtalk_msg", dtMsg, + "http_status", dtHTTP, + "error", err.Error(), + ) + msg := dtMsg + if strings.TrimSpace(msg) == "" { + msg = infraerrors.Message(err) + } + if strings.TrimSpace(dtCode) != "" { + msg = "dingtalk[" + dtCode + "] " + msg + } + redirectOAuthError(c, frontendCallback, mapDingTalkErrorCode(err), msg, "") +} + +// ─── 常量 ────────────────────────────────────────────────────────────────── + +const ( + dingTalkOAuthCookiePath = "/api/v1/auth/oauth/dingtalk" + dingTalkOAuthStateCookieName = "dingtalk_oauth_state" + dingTalkOAuthRedirectCookie = "dingtalk_oauth_redirect" + dingTalkOAuthIntentCookieName = "dingtalk_oauth_intent" + dingTalkOAuthBindUserCookieName = "dingtalk_oauth_bind_user" + dingTalkOAuthCookieMaxAgeSec = 600 // 10 分钟 + dingTalkOAuthDefaultRedirectTo = "/dashboard" + dingTalkOAuthDefaultFrontendCB = "/auth/dingtalk/callback" + + dingTalkLevelThreeEnabled = true +) + +// ─── Config helper ───────────────────────────────────────────────────────── + +// getDingTalkOAuthConfig 返回 DingTalk OAuth 最终生效配置。 +// 优先从 settingSvc(settings 表)读取,回退到 h.cfg.DingTalk。 +func (h *AuthHandler) getDingTalkOAuthConfig(ctx context.Context) (config.DingTalkConnectConfig, error) { + if h != nil && h.settingSvc != nil { + return h.settingSvc.GetDingTalkConnectOAuthConfig(ctx) + } + if h == nil || h.cfg == nil { + return config.DingTalkConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded") + } + if !h.cfg.DingTalk.Enabled { + return config.DingTalkConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "dingtalk oauth login is disabled") + } + return h.cfg.DingTalk, nil +} + +// ─── Cookie helpers(使用 dingtalk path)───────────────────────────────── + +func setDingTalkCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: value, + Path: dingTalkOAuthCookiePath, + MaxAge: maxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func clearDingTalkCookie(c *gin.Context, name string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: "", + Path: dingTalkOAuthCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +// ─── DingTalkOAuthStart ──────────────────────────────────────────────────── + +// DingTalkOAuthStart 启动 DingTalk Connect OAuth 登录流程。 +// GET /api/v1/auth/oauth/dingtalk/start?redirect=/dashboard&intent=login +func (h *AuthHandler) DingTalkOAuthStart(c *gin.Context) { + cfg, err := h.getDingTalkOAuthConfig(c.Request.Context()) + if err != nil { + frontendCB := dingTalkOAuthDefaultFrontendCB + redirectOAuthError(c, frontendCB, "dingtalk_not_enabled", "", "") + return + } + + state, err := oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err)) + return + } + + redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect")) + if redirectTo == "" { + redirectTo = dingTalkOAuthDefaultRedirectTo + } + + browserSessionKey, err := generateOAuthPendingBrowserSession() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BROWSER_SESSION_GEN_FAILED", "failed to generate oauth browser session").WithCause(err)) + return + } + + secureCookie := isRequestHTTPS(c) + setDingTalkCookie(c, dingTalkOAuthStateCookieName, encodeCookieValue(state), dingTalkOAuthCookieMaxAgeSec, secureCookie) + setDingTalkCookie(c, dingTalkOAuthRedirectCookie, encodeCookieValue(redirectTo), dingTalkOAuthCookieMaxAgeSec, secureCookie) + + intent := normalizeOAuthIntent(c.Query("intent")) + setDingTalkCookie(c, dingTalkOAuthIntentCookieName, encodeCookieValue(intent), dingTalkOAuthCookieMaxAgeSec, secureCookie) + + setOAuthPendingBrowserCookie(c, browserSessionKey, secureCookie) + clearOAuthPendingSessionCookie(c, secureCookie) + + if intent == oauthIntentBindCurrentUser { + bindCookieValue, err := h.buildOAuthBindUserCookieFromContext(c) + if err != nil { + response.ErrorFrom(c, err) + return + } + setDingTalkCookie(c, dingTalkOAuthBindUserCookieName, encodeCookieValue(bindCookieValue), dingTalkOAuthCookieMaxAgeSec, secureCookie) + } else { + clearDingTalkCookie(c, dingTalkOAuthBindUserCookieName, secureCookie) + } + + authURL, err := buildDingTalkAuthorizeURL(cfg, state) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build dingtalk authorization url").WithCause(err)) + return + } + + c.Redirect(http.StatusFound, authURL) +} + +// ─── buildDingTalkAuthorizeURL ───────────────────────────────────────────── + +// ─── findDingTalkCompatEmailUser ─────────────────────────────────────────── + +// findDingTalkCompatEmailUser 通过真实邮箱查找可与 DingTalk 账号兼容绑定的现有用户。 +func (h *AuthHandler) findDingTalkCompatEmailUser(ctx context.Context, email string) (*dbent.User, error) { + if !dingTalkLevelThreeEnabled { + return nil, nil + } + + client := h.entClient() + if client == nil { + return nil, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready") + } + + email = strings.TrimSpace(strings.ToLower(email)) + if email == "" || + strings.HasSuffix(email, service.DingTalkConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) { + return nil, nil + } + + userEntities, err := client.User.Query(). + Where(userNormalizedEmailPredicate(email)). + Order(dbent.Asc(dbuser.FieldID)). + All(ctx) + if err != nil { + return nil, infraerrors.InternalServer("COMPAT_EMAIL_LOOKUP_FAILED", "failed to look up compat email user").WithCause(err) + } + switch len(userEntities) { + case 0: + return nil, nil + case 1: + return userEntities[0], nil + default: + return nil, infraerrors.Conflict("USER_EMAIL_CONFLICT", "normalized email matched multiple users") + } +} + +// ─── createDingTalkOAuthChoicePendingSession ─────────────────────────────── + +// createDingTalkOAuthChoicePendingSession 创建 DingTalk OAuth 三方注册/绑定的 choice pending session。 +// signupBlocked=true 时关闭"创建新账户"出口;若同时没有 compat email 匹配的已有账户, +// 直接把 step 切到 bind_login_required,避免前端展示一个没有实际可点选项的 choice 界面。 +func (h *AuthHandler) createDingTalkOAuthChoicePendingSession( + c *gin.Context, + identity service.PendingAuthIdentityKey, + suggestedEmail string, + resolvedEmail string, + redirectTo string, + browserSessionKey string, + upstreamClaims map[string]any, + compatEmail string, + compatEmailUser *dbent.User, + forceEmailOnSignup bool, + signupBlocked bool, +) error { + suggestionEmail := strings.TrimSpace(suggestedEmail) + canonicalEmail := strings.TrimSpace(resolvedEmail) + if suggestionEmail == "" { + suggestionEmail = canonicalEmail + } + + completionResponse := map[string]any{ + "step": oauthPendingChoiceStep, + "adoption_required": true, + "redirect": strings.TrimSpace(redirectTo), + "email": suggestionEmail, + "resolved_email": canonicalEmail, + "existing_account_email": "", + "existing_account_bindable": false, + "create_account_allowed": !signupBlocked, + "force_email_on_signup": forceEmailOnSignup, + "choice_reason": "third_party_signup", + } + if strings.TrimSpace(compatEmail) != "" { + completionResponse["compat_email"] = strings.TrimSpace(compatEmail) + } + resolvedChoiceEmail := suggestionEmail + if compatEmailUser != nil { + completionResponse["email"] = strings.TrimSpace(compatEmailUser.Email) + completionResponse["existing_account_email"] = strings.TrimSpace(compatEmailUser.Email) + completionResponse["existing_account_bindable"] = true + completionResponse["choice_reason"] = "compat_email_match" + resolvedChoiceEmail = strings.TrimSpace(compatEmailUser.Email) + } + if forceEmailOnSignup && compatEmailUser == nil { + completionResponse["choice_reason"] = "force_email_on_signup" + } + // 注册被拦:无论是否匹配到 compat email user,都跳过 choice,直接进 bind_login。 + // "开放注册" 关闭 且 "钉钉企业模式豁免" 也关闭时,唯一合法出口是绑定已有账户, + // 不应该让用户看到"创建新账户"按钮;compat user 命中只是让 bind_login 的邮箱字段预填得更准。 + if signupBlocked { + completionResponse["step"] = "bind_login_required" + completionResponse["existing_account_bindable"] = true + completionResponse["choice_reason"] = "signup_blocked_redirect_to_bind" + } + + var targetUserID *int64 + if compatEmailUser != nil && compatEmailUser.ID > 0 { + targetUserID = &compatEmailUser.ID + } + + return h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, + Identity: identity, + TargetUserID: targetUserID, + ResolvedEmail: resolvedChoiceEmail, + RedirectTo: redirectTo, + BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: completionResponse, + }) +} + +// ─── DingTalkOAuthCallback ───────────────────────────────────────────────── + +// DingTalkOAuthCallback 处理钉钉授权回调。 +// GET /api/v1/auth/oauth/dingtalk/callback?code=...&state=... +func (h *AuthHandler) DingTalkOAuthCallback(c *gin.Context) { + cfg, cfgErr := h.getDingTalkOAuthConfig(c.Request.Context()) + if cfgErr != nil { + response.ErrorFrom(c, cfgErr) + return + } + + frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL) + if frontendCallback == "" { + frontendCallback = dingTalkOAuthDefaultFrontendCB + } + + if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { + redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) + return + } + + code := strings.TrimSpace(c.Query("code")) + state := strings.TrimSpace(c.Query("state")) + if code == "" || state == "" { + redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "") + return + } + + secureCookie := isRequestHTTPS(c) + defer func() { + clearDingTalkCookie(c, dingTalkOAuthStateCookieName, secureCookie) + clearDingTalkCookie(c, dingTalkOAuthRedirectCookie, secureCookie) + clearDingTalkCookie(c, dingTalkOAuthIntentCookieName, secureCookie) + }() + + expectedState, err := readCookieDecoded(c, dingTalkOAuthStateCookieName) + if err != nil || state != expectedState { + redirectOAuthError(c, frontendCallback, "csrf", "state mismatch", "") + return + } + redirectTo, _ := readCookieDecoded(c, dingTalkOAuthRedirectCookie) + intent, _ := readCookieDecoded(c, dingTalkOAuthIntentCookieName) + intent = normalizeOAuthIntent(intent) + browserSessionKey, _ := readOAuthPendingBrowserCookie(c) + if strings.TrimSpace(browserSessionKey) == "" { + redirectOAuthError(c, frontendCallback, "missing_browser_session", "missing browser session cookie", "") + return + } + forceEmailOnSignup := h.isForceEmailOnThirdPartySignup(c.Request.Context()) + + // ─── 4 步链(Step 1 + Step 2 必须;Step 3/4 按需 + 跨组织降级)─── + client := h.dingTalkClient(cfg) + userToken, err := client.ExchangeCodeForUserToken(c.Request.Context(), code) + if err != nil { + dingTalkUpstreamRedirect(c, frontendCallback, "exchange_code", err) + return + } + + // D: corp 校验提前到 Step 1 之后、Step 2 之前,减少不必要的上游调用 + corpID := strings.TrimSpace(userToken.CorpID) + if !checkDingTalkCorpAllowed(cfg, corpID) { + // 不在 URL 中透传 corpID,避免内部企业标识泄露给前端 + redirectOAuthError(c, frontendCallback, "corp_rejected", "", "") + return + } + + // Step 2: 必须 — UnionID 是全局唯一,作为 subject + 合成邮箱种子;nick 是用户在 App 自设的昵称 + unionID, oauthNick, err := client.GetUnionIdByUserToken(c.Request.Context(), userToken.AccessToken) + if err != nil { + dingTalkUpstreamRedirect(c, frontendCallback, "get_union_id", err) + return + } + + identityKey := service.PendingAuthIdentityKey{ProviderType: "dingtalk", ProviderKey: "dingtalk", ProviderSubject: unionID} + + // Step 3/4 调用策略由 policy 决定,与 require_email 解耦。 + // policy=internal_only → 必须成功(hard fail),因为 AppType=internal 已保证用户在应用企业。 + // policy=none / "" → 尝试,失败降级(公网场景跨组织用户属正常预期)。 + // require_email 只影响 Step 3/4 结果后的邮箱处理路径,不影响是否调用。 + var staff *DingTalkStaffInfo + switch cfg.CorpRestrictionPolicy { + case "internal_only": + // AppType=internal 已保证用户在应用企业,Step 3/4 必须成功。 + // 失败 = 钉钉 OAPI 故障或应用配置错误,应 hard fail。 + upstreamUserID, errStep3 := client.GetUserIdByUnionId(c.Request.Context(), unionID) + if errStep3 != nil { + dingTalkUpstreamRedirect(c, frontendCallback, "get_user_id", errStep3) + return + } + staffInfo, errStep4 := client.GetStaffInfoByUserId(c.Request.Context(), upstreamUserID) + if errStep4 != nil { + dingTalkUpstreamRedirect(c, frontendCallback, "get_staff_info", errStep4) + return + } + staff = staffInfo + + default: // "none" or "" + // 公网登录,跨组织用户 Step 3/4 可能失败(设计预期),尝试调用,失败降级。 + // 即使 require_email=false 也尝试拿 name(用于 upstreamClaims.username),失败就空着。 + upstreamUserID, errStep3 := client.GetUserIdByUnionId(c.Request.Context(), unionID) + if errStep3 != nil { + slog.Debug("dingtalk step3 fallback (none/cross-org)", + "corp_id", corpID, "union_id", unionID, "err", errStep3.Error()) + staff = &DingTalkStaffInfo{} + break + } + staffInfo, errStep4 := client.GetStaffInfoByUserId(c.Request.Context(), upstreamUserID) + if errStep4 != nil { + slog.Debug("dingtalk step4 fallback (none/cross-org)", + "corp_id", corpID, "union_id", unionID, "err", errStep4.Error()) + staff = &DingTalkStaffInfo{} + break + } + staff = staffInfo + } + + // nick 来自 OIDC /contact/users/me,优先作为钉钉昵称(user/get.nickname 多数为空)。 + if staff != nil && strings.TrimSpace(oauthNick) != "" { + staff.Nickname = strings.TrimSpace(oauthNick) + } + + upstreamClaims := buildDingTalkUpstreamClaims(staff, unionID, corpID) + + // ─── S1 主动绑定分支(PR-3 才走到这里)─── + if intent == oauthIntentBindCurrentUser { + targetUserID, err := h.readOAuthBindUserIDFromCookie(c, dingTalkOAuthBindUserCookieName) + if err != nil { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid bind user cookie", "") + return + } + // policy=none 跨组织用户绑定时 staff.Email="",用合成邮箱占位(用于 audit log,不用于注册) + bindResolvedEmail := staff.Email + if bindResolvedEmail == "" { + bindResolvedEmail = buildDingTalkSyntheticEmail(unionID) + } + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentBindCurrentUser, Identity: identityKey, + TargetUserID: &targetUserID, ResolvedEmail: bindResolvedEmail, + RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{"redirect": redirectTo}, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + clearDingTalkCookie(c, dingTalkOAuthBindUserCookieName, secureCookie) + redirectToFrontendCallback(c, frontendCallback) + return + } + + // ─── Level 1:auth_identities hit ─── + if existing, _ := h.findOAuthIdentityUser(c.Request.Context(), identityKey); existing != nil { + // 身份同步:已登录用户,直接同步(user_id 已知)。 + // 异步执行避免上游钉钉接口(GetStaffInfoByUserId / 部门递归)阻塞登录跳转。 + runDingTalkSyncAsync(c.Request.Context(), func(ctx context.Context) { + h.syncDingTalkIdentity(ctx, cfg, client, existing.ID, staff, false) + }) + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, Identity: identityKey, TargetUserID: &existing.ID, + ResolvedEmail: existing.Email, RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{"redirect": redirectTo}, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + signupBlocked := h.isDingTalkSignupBlocked(c.Request.Context(), cfg) + + // ─── 非命中:require_email=false 走 synthetic email 直接登录 ─── + if !cfg.RequireEmail { + if signupBlocked { + // 注册被拦 + 无邮箱可输:唯一出路是绑定已有账户 + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, Identity: identityKey, TargetUserID: nil, + ResolvedEmail: "", RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: dingTalkBindLoginCompletionResponse(redirectTo), + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + syntheticEmail := buildDingTalkSyntheticEmail(unionID) + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, Identity: identityKey, TargetUserID: nil, + ResolvedEmail: syntheticEmail, RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: map[string]any{"redirect": redirectTo, "synthetic_email": syntheticEmail}, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + // ─── require_email=true 且 staff.Email 空 → 补邮箱(默认)或直接 bind_login(注册被拦时) ─── + if staff.Email == "" { + completionResponse := map[string]any{ + "step": "email_completion", + "requires_email_completion": true, + "redirect": redirectTo, + } + if signupBlocked { + // 注册被全局关闭且未豁免:跳过补邮箱页,直接进 bind_login 让用户输入已有账户 + completionResponse = dingTalkBindLoginCompletionResponse(redirectTo) + } + if err := h.createOAuthPendingSession(c, oauthPendingSessionPayload{ + Intent: oauthIntentLogin, Identity: identityKey, TargetUserID: nil, + ResolvedEmail: "", RedirectTo: redirectTo, BrowserSessionKey: browserSessionKey, + UpstreamIdentityClaims: upstreamClaims, + CompletionResponse: completionResponse, + }); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + redirectToFrontendCallback(c, frontendCallback) + return + } + + // ─── L3/L4 有邮箱:统一 choice pending session ─── + var compatEmailUser *dbent.User + if dingTalkLevelThreeEnabled && staff.Email != "" { + compatEmailUser, _ = h.findDingTalkCompatEmailUser(c.Request.Context(), staff.Email) + } + if err := h.createDingTalkOAuthChoicePendingSession( + c, identityKey, staff.Email, staff.Email, + redirectTo, browserSessionKey, upstreamClaims, + staff.Email, compatEmailUser, forceEmailOnSignup, + signupBlocked, + ); err != nil { + redirectOAuthError(c, frontendCallback, "session_error", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + redirectToFrontendCallback(c, frontendCallback) +} + +func buildDingTalkSyntheticEmail(userID string) string { + return "dingtalk-" + strings.ToLower(strings.TrimSpace(userID)) + service.DingTalkConnectSyntheticEmailDomain +} + +// isDingTalkSignupBlocked 当注册总开关关闭且未开启钉钉企业模式豁免 +// (policy=internal_only + dingtalk_connect_bypass_registration=true)时返回 true。 +// 镜像 service.AuthService.canBypassRegistrationDisabledForOAuth 用于 OAuth callback +// 早期路由决策:注册被拦 → 跳过补邮箱页直接进 bind_login,避免用户填完表单才报错。 +func (h *AuthHandler) isDingTalkSignupBlocked(ctx context.Context, cfg config.DingTalkConnectConfig) bool { + if h.settingSvc == nil { + return false + } + if h.settingSvc.IsRegistrationEnabled(ctx) { + return false + } + if cfg.BypassRegistration && cfg.CorpRestrictionPolicy == "internal_only" { + return false + } + return true +} + +func dingTalkBindLoginCompletionResponse(redirectTo string) map[string]any { + return map[string]any{ + "step": "bind_login_required", + "existing_account_bindable": true, + "create_account_allowed": false, + "redirect": redirectTo, + } +} + +func buildDingTalkUpstreamClaims(staff *DingTalkStaffInfo, unionID, corpID string) map[string]any { + primaryDeptID := int64(0) + if len(staff.DeptIDs) > 0 { + primaryDeptID = staff.DeptIDs[0] + } + return map[string]any{ + "email": staff.Email, + "username": staff.Name, + "nickname": staff.Nickname, + "subject": unionID, // 与 identityKey.ProviderSubject 保持一致(全局唯一 unionID) + "corp_user_id": staff.UserID, // 企业 userid(跨组织时为空),保留作独立字段用于 audit + "union_id": unionID, + "corp_id": corpID, + "primary_dept_id": primaryDeptID, // 首个部门 ID,用于 internal_only 同步路径 + } +} + +func checkDingTalkCorpAllowed(cfg config.DingTalkConnectConfig, corpID string) bool { + switch cfg.CorpRestrictionPolicy { + case "internal_only": + // 方案 A:完全跳过 corpID 字段校验,由 step 3 `GetUserIdByUnionId` 做真实判定。 + // 原因:钉钉 /v1.0/oauth2/userAccessToken 在部分授权场景(扫码登录、非企业工作台入口) + // 不会返回 corpId 字段。而 step 3 用本企业 appToken 查 unionId→userId 映射, + // 跨企业用户会被钉钉拒绝(错误码 60011/60121),mapDingTalkErrorCode 已将其映射回 "corp_rejected"。 + // AppType=internal 已由 ValidateDingTalkConfig 强制保证应用属性。 + return true + case "none", "": + return true + default: + return false + } +} + +// decideDingTalkStep34Strategy 根据 policy 和 Step 3/4 运行时错误决定处理方式。 +// 返回 (proceed bool, fatal bool): +// - proceed=true:继续处理(step 成功或降级) +// - fatal=true:应 hard fail(upstream_error) +// +// 此 helper 从主链中提取,便于 unit test 独立验证策略决策逻辑。 +func decideDingTalkStep34Strategy(policy string, stepErr error) (shouldFallback bool, isFatal bool) { + if stepErr == nil { + return false, false // 成功,不需要降级 + } + switch policy { + case "internal_only": + return false, true // hard fail:同企业 Step 3/4 必须成功 + case "none", "": + return true, false // 降级:公网场景跨组织用户失败属正常预期 + default: + return false, true // 未知 policy,视为 hard fail + } +} + +// mapDingTalkErrorCode 把 DingTalkAPIError 映射到 redirectOAuthError 用的字符串 code +func mapDingTalkErrorCode(err error) string { + var apiErr *DingTalkAPIError + if !errors.As(err, &apiErr) { + return "upstream_error" + } + switch apiErr.Code { + case "60011", "60121": + return "corp_rejected" + case "40014", "50015", "88": + return "upstream_error" + default: + return "upstream_error" + } +} + +// dingTalkClient 构造或返回缓存的 client 实例(h-level 单例)。 +// 若 cfg 关键字段(ClientID/ClientSecret/TokenURL/UserInfoURL)与已缓存实例不一致, +// 则丢弃旧实例(含 appToken 缓存)并重建,避免管理员改配置后旧凭据持续生效。 +func (h *AuthHandler) dingTalkClient(cfg config.DingTalkConnectConfig) *DingTalkClient { + h.dingTalkClientMu.Lock() + defer h.dingTalkClientMu.Unlock() + newCfg := dingTalkClientConfig{ + ClientID: cfg.ClientID, + ClientSecret: cfg.ClientSecret, + TokenURL: cfg.TokenURL, + UserInfoURL: cfg.UserInfoURL, + } + if h.dingTalkClientInstance == nil || h.dingTalkClientInstance.cfg != newCfg { + h.dingTalkClientInstance = &DingTalkClient{ + cfg: newCfg, + // 与 wechat OAuth client 对齐,避免上游网络抖动时请求悬挂。 + httpClient: &http.Client{Timeout: 10 * time.Second}, + } + } + return h.dingTalkClientInstance +} + +// ─── buildDingTalkAuthorizeURL ───────────────────────────────────────────── + +// buildDingTalkAuthorizeURL 根据配置和 state 构建钉钉 OAuth 授权 URL。 +func buildDingTalkAuthorizeURL(cfg config.DingTalkConnectConfig, state string) (string, error) { + base := strings.TrimSpace(cfg.AuthorizeURL) + if base == "" { + return "", infraerrors.InternalServer("DINGTALK_AUTHORIZE_URL_EMPTY", "dingtalk authorize_url not configured") + } + redirectURI := strings.TrimSpace(cfg.RedirectURL) + if redirectURI == "" { + return "", infraerrors.InternalServer("DINGTALK_REDIRECT_URL_EMPTY", "dingtalk redirect_url not configured") + } + + u, err := url.Parse(base) + if err != nil { + return "", infraerrors.InternalServer("DINGTALK_AUTHORIZE_URL_PARSE_FAILED", "failed to parse dingtalk authorize_url").WithCause(err) + } + + scopes := strings.TrimSpace(cfg.Scopes) + if scopes == "" { + scopes = "openid" + } + + q := u.Query() + q.Set("client_id", cfg.ClientID) + q.Set("redirect_uri", redirectURI) + q.Set("response_type", "code") + q.Set("scope", scopes) + q.Set("state", state) + q.Set("prompt", "consent") + u.RawQuery = q.Encode() + + return u.String(), nil +} + +// ─── Complete Registration ───────────────────────────────────────────────── + +type completeDingTalkOAuthRequest struct { + InvitationCode string `json:"invitation_code" binding:"required"` + AffCode string `json:"aff_code,omitempty"` + AdoptDisplayName *bool `json:"adopt_display_name,omitempty"` + AdoptAvatar *bool `json:"adopt_avatar,omitempty"` +} + +// CompleteDingTalkOAuthRegistration completes a pending OAuth registration by validating +// the invitation code and creating the user account. +// POST /api/v1/auth/oauth/dingtalk/complete-registration +func (h *AuthHandler) CompleteDingTalkOAuthRegistration(c *gin.Context) { + var req completeDingTalkOAuthRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "INVALID_REQUEST", "message": err.Error()}) + return + } + + secureCookie := isRequestHTTPS(c) + sessionToken, err := readOAuthPendingSessionCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthSessionNotFound) + return + } + browserSessionKey, err := readOAuthPendingBrowserCookie(c) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, service.ErrPendingAuthBrowserMismatch) + return + } + pendingSvc, err := h.pendingIdentityService() + if err != nil { + response.ErrorFrom(c, err) + return + } + session, err := pendingSvc.GetBrowserSession(c.Request.Context(), sessionToken, browserSessionKey) + if err != nil { + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + response.ErrorFrom(c, err) + return + } + if err := ensurePendingOAuthCompleteRegistrationSession(session); err != nil { + response.ErrorFrom(c, err) + return + } + if updatedSession, handled, err := h.legacyCompleteRegistrationSessionStatus(c, session); err != nil { + response.ErrorFrom(c, err) + return + } else if handled { + c.JSON(http.StatusOK, buildPendingOAuthSessionStatusPayload(updatedSession)) + return + } else { + session = updatedSession + } + if err := h.ensureBackendModeAllowsNewUserLogin(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + email := strings.TrimSpace(session.ResolvedEmail) + username := pendingSessionStringValue(session.UpstreamIdentityClaims, "username") + // E: username 空时退到 email local part(跨组织用户没拿到 staff.Name 也能注册) + if username == "" { + if at := strings.Index(email, "@"); at > 0 { + username = email[:at] + } else { + username = email + } + } + if email == "" || username == "" { + response.ErrorFrom(c, infraerrors.BadRequest("PENDING_AUTH_SESSION_INVALID", "pending auth registration context is invalid")) + return + } + + client := h.entClient() + if client == nil { + response.ErrorFrom(c, infraerrors.ServiceUnavailable("PENDING_AUTH_NOT_READY", "pending auth service is not ready")) + return + } + if err := ensurePendingOAuthRegistrationIdentityAvailable(c.Request.Context(), client, session); err != nil { + respondPendingOAuthBindingApplyError(c, err) + return + } + decision, err := h.ensurePendingOAuthAdoptionDecision(c, session.ID, oauthAdoptionDecisionRequest{ + AdoptDisplayName: req.AdoptDisplayName, + AdoptAvatar: req.AdoptAvatar, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "dingtalk") + if err != nil { + response.ErrorFrom(c, err) + return + } + if err := applyPendingOAuthAdoptionAndConsumeSession(c.Request.Context(), client, h.authService, h.userService, session, decision, user.ID); err != nil { + respondPendingOAuthBindingApplyError(c, err) + return + } + // 新用户注册完成后执行身份同步(user_id 现在已知)。 + // 异步执行避免阻塞 token 响应。 + if completionCfg, cfgErr := h.getDingTalkOAuthConfig(c.Request.Context()); cfgErr == nil { + dtClient := h.dingTalkClient(completionCfg) + claims := session.UpstreamIdentityClaims + runDingTalkSyncAsync(c.Request.Context(), func(ctx context.Context) { + h.syncDingTalkIdentityFromClaims(ctx, completionCfg, dtClient, user.ID, claims, true) + }) + } + h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + clearOAuthPendingSessionCookie(c, secureCookie) + clearOAuthPendingBrowserCookie(c, secureCookie) + + c.JSON(http.StatusOK, gin.H{ + "access_token": tokenPair.AccessToken, + "refresh_token": tokenPair.RefreshToken, + "expires_in": tokenPair.ExpiresIn, + "token_type": "Bearer", + }) +} + +// CreateDingTalkOAuthAccount creates a new user account from a pending DingTalk OAuth session. +// POST /api/v1/auth/oauth/dingtalk/create-account +func (h *AuthHandler) CreateDingTalkOAuthAccount(c *gin.Context) { + h.createPendingOAuthAccount(c, "dingtalk") +} + +// BindDingTalkOAuthLogin 处理已有账户绑定钉钉 OAuth 登录。 +// POST /api/v1/auth/oauth/dingtalk/bind-login +func (h *AuthHandler) BindDingTalkOAuthLogin(c *gin.Context) { + h.bindPendingOAuthLogin(c, "dingtalk") +} + +// ─── DingTalk 身份同步 ───────────────────────────────────────────────────── + +// runDingTalkSyncAsync 在后台 goroutine 执行钉钉身份同步,避免阻塞登录响应。 +// 与请求 ctx 解耦(handler 返回后会被取消),但保留其 values(trace/request id)。 +// 固定 30s 超时上限,防止 goroutine 因上游卡顿无限挂起。 +func runDingTalkSyncAsync(parent context.Context, fn func(ctx context.Context)) { + base := context.WithoutCancel(parent) + go func() { + defer func() { + if r := recover(); r != nil { + slog.Error("dingtalk sync: panic recovered", "panic", r) + } + }() + ctx, cancel := context.WithTimeout(base, 30*time.Second) + defer cancel() + fn(ctx) + }() +} + +// syncDingTalkIdentity 在 internal_only 模式下,按三个 sync 开关把钉钉身份信息 +// 同步到用户属性表(以及 users.username)。 +// 任何错误仅记日志,不中断登录流程(最终一致性)。 +func (h *AuthHandler) syncDingTalkIdentity(ctx context.Context, cfg config.DingTalkConnectConfig, client *DingTalkClient, userID int64, staff *DingTalkStaffInfo, syncUsername bool) { + slog.Info("dingtalk sync: entry", + "user_id", userID, + "policy", cfg.CorpRestrictionPolicy, + "sync_corp_email", cfg.SyncCorpEmail, + "sync_display_name", cfg.SyncDisplayName, + "sync_dept", cfg.SyncDept, + "sync_username", syncUsername, + "attr_key_email", cfg.SyncCorpEmailAttrKey, + "attr_key_name", cfg.SyncDisplayNameAttrKey, + "attr_key_dept", cfg.SyncDeptAttrKey, + "staff_nil", staff == nil, + ) + if cfg.CorpRestrictionPolicy != "internal_only" || staff == nil { + slog.Info("dingtalk sync: skip, not internal_only or staff nil") + return + } + slog.Info("dingtalk sync: staff snapshot", + "name", staff.Name, "email", staff.Email, "dept_ids", staff.DeptIDs, + ) + if !cfg.SyncCorpEmail && !cfg.SyncDisplayName && !cfg.SyncDept { + slog.Info("dingtalk sync: skip, all flags disabled") + return + } + if h.userAttributeService == nil { + slog.Warn("dingtalk sync: userAttributeService not available, skipping") + return + } + + // 仅首次注册时覆盖 users.username(避免每次登录覆盖用户后续手动改过的名字)。 + // dingtalk_name 属性下面单独每次写入企业 name,不受此条件影响。 + if syncUsername && cfg.SyncDisplayName { + username := strings.TrimSpace(staff.Nickname) + source := "nickname" + if username == "" { + username = strings.TrimSpace(staff.Name) + source = "name(fallback)" + } + if username != "" && h.userService != nil { + if _, err := h.userService.UpdateProfile(ctx, userID, service.UpdateProfileRequest{Username: &username}); err != nil { + slog.Warn("dingtalk sync: failed to update username", "user_id", userID, "err", err) + } else { + slog.Info("dingtalk sync: username updated (register)", "user_id", userID, "username", username, "source", source) + } + } + } + + // 属性同步(目标 attr key 从 cfg 读取,默认值由 GetDingTalkConnectOAuthConfig 保证非空) + type syncField struct { + key string + value string + } + var fields []syncField + + if cfg.SyncDisplayName && strings.TrimSpace(staff.Name) != "" { + fields = append(fields, syncField{cfg.SyncDisplayNameAttrKey, strings.TrimSpace(staff.Name)}) + } + if cfg.SyncCorpEmail && strings.TrimSpace(staff.Email) != "" { + fields = append(fields, syncField{cfg.SyncCorpEmailAttrKey, strings.TrimSpace(staff.Email)}) + } + if cfg.SyncDept && len(staff.DeptIDs) > 0 { + // 跳过根部门 ID=1,找第一个真实子部门;都是根则保留 1(最终写入空字符串覆盖旧值)。 + primaryDeptID := int64(0) + for _, id := range staff.DeptIDs { + if id > 1 { + primaryDeptID = id + break + } + } + if primaryDeptID == 0 { + primaryDeptID = staff.DeptIDs[0] + } + slog.Info("dingtalk sync: pick primary dept", "user_id", userID, "all_dept_ids", staff.DeptIDs, "primary", primaryDeptID) + path, err := h.resolveDingTalkDeptPath(ctx, client, primaryDeptID) + if err != nil { + slog.Warn("dingtalk sync: failed to resolve dept path", "user_id", userID, "dept_id", primaryDeptID, "err", err) + } else { + // path="" 表示公司直属(仅在根部门下),仍写入空串覆盖旧值。 + fields = append(fields, syncField{cfg.SyncDeptAttrKey, path}) + } + } + + if len(fields) == 0 { + return + } + + // 逐 key 查 definition 并 upsert + for _, f := range fields { + if err := h.setUserAttributeByKey(ctx, userID, f.key, f.value); err != nil { + slog.Warn("dingtalk sync: failed to set attribute", "user_id", userID, "key", f.key, "err", err) + } + } +} + +// syncDingTalkIdentityFromClaims 从 upstreamClaims 恢复 DingTalkStaffInfo 并调用 syncDingTalkIdentity。 +// 用于 pending session 完成阶段(complete-registration / create-account / bind-login)。 +// syncUsername=true 表示首次注册场景,需要把 nickname 写入 users.username。 +func (h *AuthHandler) syncDingTalkIdentityFromClaims(ctx context.Context, cfg config.DingTalkConnectConfig, client *DingTalkClient, userID int64, claims map[string]any, syncUsername bool) { + staff := dingTalkStaffFromClaims(claims) + h.syncDingTalkIdentity(ctx, cfg, client, userID, staff, syncUsername) +} + +// maybeSyncDingTalkAfterRegistration 在通用 OAuth 注册路径完成后调用。 +// 同步 4 个字段:users.username(首次) + dingtalk_name/email/department(每次)。 +func (h *AuthHandler) maybeSyncDingTalkAfterRegistration(ctx context.Context, session *dbent.PendingAuthSession, userID int64) { + h.dispatchDingTalkPendingSync(ctx, session, userID, true) +} + +// maybeSyncDingTalkAfterLogin 在通用 OAuth 登录/绑定路径完成后调用。 +// 仅刷新 3 个属性(dingtalk_name/email/department),不动 users.username。 +func (h *AuthHandler) maybeSyncDingTalkAfterLogin(ctx context.Context, session *dbent.PendingAuthSession, userID int64) { + h.dispatchDingTalkPendingSync(ctx, session, userID, false) +} + +func (h *AuthHandler) dispatchDingTalkPendingSync(ctx context.Context, session *dbent.PendingAuthSession, userID int64, syncUsername bool) { + if session == nil || userID <= 0 { + return + } + if !strings.EqualFold(strings.TrimSpace(session.ProviderType), "dingtalk") { + return + } + cfg, err := h.getDingTalkOAuthConfig(ctx) + if err != nil { + slog.Debug("dingtalk sync: skip post-login sync, config unavailable", "user_id", userID, "err", err.Error()) + return + } + client := h.dingTalkClient(cfg) + claims := session.UpstreamIdentityClaims + // 异步执行避免阻塞 token 响应。 + runDingTalkSyncAsync(ctx, func(asyncCtx context.Context) { + h.syncDingTalkIdentityFromClaims(asyncCtx, cfg, client, userID, claims, syncUsername) + }) +} + +// dingTalkStaffFromClaims 从 upstreamClaims 重建最小 DingTalkStaffInfo。 +func dingTalkStaffFromClaims(claims map[string]any) *DingTalkStaffInfo { + if claims == nil { + return &DingTalkStaffInfo{} + } + staff := &DingTalkStaffInfo{} + if v, ok := claims["username"].(string); ok { + staff.Name = v + } + if v, ok := claims["nickname"].(string); ok { + staff.Nickname = v + } + if v, ok := claims["email"].(string); ok { + staff.Email = v + } + if v, ok := claims["corp_user_id"].(string); ok { + staff.UserID = v + } + // primary_dept_id 存为 int64 或 float64(JSON round-trip) + switch v := claims["primary_dept_id"].(type) { + case int64: + if v > 0 { + staff.DeptIDs = []int64{v} + } + case float64: + if id := int64(v); id > 0 { + staff.DeptIDs = []int64{id} + } + } + return staff +} + +// setUserAttributeByKey 按 attribute key 查找 definition,再 upsert 用户属性值。 +// definition 不存在时记 warn 日志跳过(admin 在 settings 保存时已按需 upsert +// 对应 def;缺失意味着 admin 改了 attr key 但未保存 settings,或 def 被手工删除)。 +func (h *AuthHandler) setUserAttributeByKey(ctx context.Context, userID int64, key, value string) error { + def, err := h.userAttributeService.GetDefinitionByKey(ctx, key) + if err != nil { + slog.Warn("dingtalk sync: attribute definition not found, skipping", "key", key, "err", err.Error()) + return nil + } + if err := h.userAttributeService.UpdateUserAttributes(ctx, userID, []service.UpdateUserAttributeInput{ + {AttributeID: def.ID, Value: value}, + }); err != nil { + return err + } + slog.Info("dingtalk sync: attribute upserted", "user_id", userID, "key", key, "attr_id", def.ID) + return nil +} + +// resolveDingTalkDeptPath 从叶部门递归向上拼 "公司/部门/子部门" 路径字符串。 +// 遇 dept_id=1(根)或 parent_id=0 停止。加 visited set 防循环,最多 50 层。 +func (h *AuthHandler) resolveDingTalkDeptPath(ctx context.Context, client *DingTalkClient, deptID int64) (string, error) { + slog.Info("dingtalk sync: resolve dept path start", "dept_id", deptID) + const maxDepth = 50 + visited := make(map[int64]bool, maxDepth) + var parts []string + + current := deptID + for i := 0; i < maxDepth; i++ { + if current < 1 || visited[current] { + break + } + visited[current] = true + + info, err := client.GetDeptInfo(ctx, current) + if err != nil { + return "", fmt.Errorf("get dept info %d: %w", current, err) + } + if strings.TrimSpace(info.Name) != "" { + parts = append([]string{strings.TrimSpace(info.Name)}, parts...) + } + // 钉钉根部门 dept_id=1,ParentID 通常为 0;遇到 0 / self 终止避免循环。 + if info.ParentID < 1 || info.ParentID == current { + break + } + current = info.ParentID + } + + // 去除根组织名(parts[0] 始终是企业全称),仅保留部门层级。 + // 例:["公司","A","B"] → "A/B";["公司"] → ""(公司直属)。 + if len(parts) > 0 { + parts = parts[1:] + } + + return strings.Join(parts, "/"), nil +} diff --git a/backend/internal/handler/auth_dingtalk_oauth_test.go b/backend/internal/handler/auth_dingtalk_oauth_test.go new file mode 100644 index 00000000000..1f60e6b6525 --- /dev/null +++ b/backend/internal/handler/auth_dingtalk_oauth_test.go @@ -0,0 +1,391 @@ +package handler + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestDingTalkOAuthStart_Disabled は sentinel テスト。 +// TODO(task-1.10): newTestAuthHandlerWithDingTalk helper が追加されたら t.Skip を外す。 +func TestDingTalkOAuthStart_Disabled(t *testing.T) { + t.Skip("helper newTestAuthHandlerWithDingTalk added in Task 1.10; sentinel only") +} + +// TestBuildDingTalkSyntheticEmail_UsesUnionID 验证合成邮箱种子使用 unionID。 +func TestBuildDingTalkSyntheticEmail_UsesUnionID(t *testing.T) { + unionID := "union_AbCdEf123" + email := buildDingTalkSyntheticEmail(unionID) + + want := "dingtalk-union_abcdef123@dingtalk-connect.invalid" + require.Equal(t, want, email) + + // 确保结果都是小写(邮箱大小写不敏感,统一小写) + require.True(t, strings.ToLower(email) == email, "synthetic email should be all lowercase") + + // 确保前缀正确 + require.True(t, strings.HasPrefix(email, "dingtalk-"), "should have dingtalk- prefix") + + // 确保后缀是合成邮箱域名 + require.True(t, strings.HasSuffix(email, "@dingtalk-connect.invalid"), "should have reserved domain suffix") +} + +// TestBuildDingTalkSyntheticEmail_TrimsSpace 验证 unionID 空白被修剪。 +func TestBuildDingTalkSyntheticEmail_TrimsSpace(t *testing.T) { + email := buildDingTalkSyntheticEmail(" UID_XYZ ") + require.Equal(t, "dingtalk-uid_xyz@dingtalk-connect.invalid", email) +} + +// TestBuildDingTalkUpstreamClaims_EmptyStaff 验证 staff 为空 struct(跨组织降级路径)时: +// - subject 等于 unionID(与 identityKey.ProviderSubject 一致) +// - corp_user_id 为空字符串(跨组织时拿不到企业 userid) +// - email/username 为空字符串 +// B/C: Step 3/4 失败降级时 staff = &DingTalkStaffInfo{},claims 不应有 nil。 +func TestBuildDingTalkUpstreamClaims_EmptyStaff(t *testing.T) { + staff := &DingTalkStaffInfo{} + claims := buildDingTalkUpstreamClaims(staff, "UNION_AAA", "CORP_X") + + require.Equal(t, "", claims["email"]) + require.Equal(t, "", claims["username"]) + // 重构后 subject = unionID(与 identityKey.ProviderSubject 保持一致) + require.Equal(t, "UNION_AAA", claims["subject"]) + require.Equal(t, "", claims["corp_user_id"]) // 企业 userid 跨组织时为空 + require.Equal(t, "UNION_AAA", claims["union_id"]) + require.Equal(t, "CORP_X", claims["corp_id"]) +} + +// TestCheckDingTalkCorpAllowed_CrossOrgPolicy 验证 policy=none 时允许任意 corp。 +// D: corp 校验提前后逻辑不变。 +func TestCheckDingTalkCorpAllowed_CrossOrgPolicy(t *testing.T) { + cfg := config.DingTalkConnectConfig{CorpRestrictionPolicy: "none"} + + assert.True(t, checkDingTalkCorpAllowed(cfg, "dingABC"), "policy=none should allow any corp") + assert.True(t, checkDingTalkCorpAllowed(cfg, ""), "policy=none should allow empty corp") + assert.True(t, checkDingTalkCorpAllowed(cfg, "foreign_corp"), "policy=none should allow foreign corp") +} + +// TestCheckDingTalkCorpAllowed_InternalOnly 验证 policy=internal_only 时的 corp 校验语义(方案 A 修订)。 +// 钉钉 userAccessToken 在部分授权场景(扫码登录、非企业工作台入口)不返回 corpId 字段, +// 因此 checkDingTalkCorpAllowed 完全不校验 corpID,由 step 3 GetUserIdByUnionId 做真实判定 +// (跨企业用户会被钉钉错误码 60011/60121 拒绝,mapDingTalkErrorCode 映射回 corp_rejected)。 +func TestCheckDingTalkCorpAllowed_InternalOnly(t *testing.T) { + cfgWithCorpID := config.DingTalkConnectConfig{ + CorpRestrictionPolicy: "internal_only", + InternalCorpID: "dingInternal", + } + assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, "dingInternal"), "internal_only: matching corpID allowed") + assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, "foreign_corp"), "internal_only: corpID 字段不再用于决策,step 3 兜底") + assert.True(t, checkDingTalkCorpAllowed(cfgWithCorpID, ""), "internal_only: 空 corpID 也通过(钉钉部分授权场景不返回 corpId)") + + cfgNoCorpID := config.DingTalkConnectConfig{ + CorpRestrictionPolicy: "internal_only", + InternalCorpID: "", + } + assert.True(t, checkDingTalkCorpAllowed(cfgNoCorpID, "dingAnyNonEmpty"), "internal_only + no InternalCorpID: 非空 corpID 通过") + assert.True(t, checkDingTalkCorpAllowed(cfgNoCorpID, ""), "internal_only + no InternalCorpID: 空 corpID 也通过") +} + +// TestDecideDingTalkStep34Strategy_PolicyNone 验证 policy=none 时 +// Step 3/4 失败应降级(shouldFallback=true, isFatal=false)。 +func TestDecideDingTalkStep34Strategy_PolicyNone(t *testing.T) { + step3Err := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403} + + shouldFallback, isFatal := decideDingTalkStep34Strategy("none", step3Err) + + require.True(t, shouldFallback, "policy=none: step3 failure should trigger fallback") + require.False(t, isFatal, "policy=none: step3 failure should NOT be fatal") +} + +// TestDecideDingTalkStep34Strategy_PolicyNoneEmpty 验证 policy="" 时行为与 "none" 相同。 +func TestDecideDingTalkStep34Strategy_PolicyNoneEmpty(t *testing.T) { + stepErr := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403} + + shouldFallback, isFatal := decideDingTalkStep34Strategy("", stepErr) + + require.True(t, shouldFallback, "policy='': step failure should trigger fallback") + require.False(t, isFatal, "policy='': step failure should NOT be fatal") +} + +// TestDecideDingTalkStep34Strategy_PolicyInternalOnly 验证 policy=internal_only 时 +// Step 3/4 失败应 hard fail(isFatal=true)。 +func TestDecideDingTalkStep34Strategy_PolicyInternalOnly(t *testing.T) { + step3Err := &DingTalkAPIError{Code: "60011", Message: "not in directory", HTTP: 403} + + shouldFallback, isFatal := decideDingTalkStep34Strategy("internal_only", step3Err) + + require.False(t, shouldFallback, "policy=internal_only: should NOT fallback on step3 error") + require.True(t, isFatal, "policy=internal_only: step3 failure should be fatal") +} + +// TestDecideDingTalkStep34Strategy_NoError 验证 stepErr=nil 时两个返回值均为 false。 +func TestDecideDingTalkStep34Strategy_NoError(t *testing.T) { + for _, policy := range []string{"none", "internal_only", ""} { + shouldFallback, isFatal := decideDingTalkStep34Strategy(policy, nil) + require.False(t, shouldFallback, "no error should not trigger fallback (policy=%q)", policy) + require.False(t, isFatal, "no error should not be fatal (policy=%q)", policy) + } +} + +// TestCompleteDingTalkRegistration_UsernameFromEmailLocalPart 验证 username 为空时 +// 退到 email local part(@ 之前的部分)。 +// E: CompleteDingTalkOAuthRegistration username fallback。 +func TestCompleteDingTalkRegistration_UsernameFromEmailLocalPart(t *testing.T) { + tests := []struct { + name string + email string + username string + wantUser string + wantValid bool + }{ + { + name: "username empty, normal email → local part", + email: "dingtalk-uid123@dingtalk-connect.invalid", + username: "", + wantUser: "dingtalk-uid123", + wantValid: true, + }, + { + name: "username already set → keep original", + email: "user@example.com", + username: "张三", + wantUser: "张三", + wantValid: true, + }, + { + name: "username empty, no @ in email → use whole email", + email: "noemail", + username: "", + wantUser: "noemail", + wantValid: true, + }, + { + name: "both empty → invalid", + email: "", + username: "", + wantUser: "", + wantValid: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + username := tc.username + email := tc.email + + // 模拟 CompleteDingTalkOAuthRegistration 中的 fallback 逻辑 + if username == "" { + if at := strings.Index(email, "@"); at > 0 { + username = email[:at] + } else { + username = email + } + } + + isValid := email != "" && username != "" + require.Equal(t, tc.wantUser, username, fmt.Sprintf("username for email=%q", tc.email)) + require.Equal(t, tc.wantValid, isValid, "validity check") + }) + } +} + +// TestBuildDingTalkUpstreamClaims_SubjectEqualsUnionID 验证重构后 subject = unionID +// 而非 staff.UserID,与 identityKey.ProviderSubject 保持一致。 +// §4.2: buildDingTalkUpstreamClaims subject 字段修正。 +func TestBuildDingTalkUpstreamClaims_SubjectEqualsUnionID(t *testing.T) { + staff := &DingTalkStaffInfo{UserID: "user123", Name: "张三", Email: "zhangsan@corp.com"} + claims := buildDingTalkUpstreamClaims(staff, "union456", "dingcorp789") + + // 重构后 subject = unionID(全局唯一,与 identityKey.ProviderSubject 一致) + require.Equal(t, "union456", claims["subject"], "subject should equal unionID after refactor") + // 企业 userid 保留为独立字段,供 audit/debug 使用 + require.Equal(t, "user123", claims["corp_user_id"], "corp_user_id should be staff.UserID") + // union_id 字段与 subject 相同(冗余保留,便于读取) + require.Equal(t, "union456", claims["union_id"]) + require.Equal(t, "dingcorp789", claims["corp_id"]) + require.Equal(t, "张三", claims["username"]) + require.Equal(t, "zhangsan@corp.com", claims["email"]) +} + +// TestBuildDingTalkUpstreamClaims_CrossOrgEmptyCorpUserID 验证跨组织降级时 +// corp_user_id 为空字符串(跨组织拿不到企业 userid),subject 仍为 unionID。 +func TestBuildDingTalkUpstreamClaims_CrossOrgEmptyCorpUserID(t *testing.T) { + // 跨组织降级路径:staff = &DingTalkStaffInfo{}(所有字段为零值) + staff := &DingTalkStaffInfo{} + claims := buildDingTalkUpstreamClaims(staff, "union_cross_org", "foreign_corp") + + require.Equal(t, "union_cross_org", claims["subject"], "subject should still be unionID for cross-org users") + require.Equal(t, "", claims["corp_user_id"], "corp_user_id should be empty for cross-org fallback") + require.Equal(t, "", claims["email"]) + require.Equal(t, "", claims["username"]) +} + +// TestBuildDingTalkUpstreamClaims_PrimaryDeptIDInClaims 验证首个 dept_id 被存入 claims。 +func TestBuildDingTalkUpstreamClaims_PrimaryDeptIDInClaims(t *testing.T) { + staff := &DingTalkStaffInfo{UserID: "u1", Name: "张三", Email: "a@b.com", DeptIDs: []int64{42, 99}} + claims := buildDingTalkUpstreamClaims(staff, "uid1", "corpX") + + // 只取首个 dept_id + require.Equal(t, int64(42), claims["primary_dept_id"], "primary_dept_id should be the first dept_id") +} + +// TestBuildDingTalkUpstreamClaims_NoDeptIDs 验证无部门时 primary_dept_id=0。 +func TestBuildDingTalkUpstreamClaims_NoDeptIDs(t *testing.T) { + staff := &DingTalkStaffInfo{UserID: "u2", Name: "李四"} + claims := buildDingTalkUpstreamClaims(staff, "uid2", "corpY") + + require.Equal(t, int64(0), claims["primary_dept_id"], "primary_dept_id should be 0 when no depts") +} + +// TestDingTalkStaffFromClaims_RoundTrip 验证 dingTalkStaffFromClaims 能从 claims 恢复 staff 信息。 +func TestDingTalkStaffFromClaims_RoundTrip(t *testing.T) { + staff := &DingTalkStaffInfo{UserID: "u3", Name: "王五", Email: "ww@corp.com", DeptIDs: []int64{55}} + claims := buildDingTalkUpstreamClaims(staff, "uid3", "corpZ") + + recovered := dingTalkStaffFromClaims(claims) + require.Equal(t, "王五", recovered.Name) + require.Equal(t, "ww@corp.com", recovered.Email) + require.Equal(t, "u3", recovered.UserID) + require.Equal(t, []int64{55}, recovered.DeptIDs) +} + +// TestResolveDingTalkDeptPath_SingleLevel 验证单层部门(parent_id=1)返回部门名。 +func TestResolveDingTalkDeptPath_SingleLevel(t *testing.T) { + handler := &AuthHandler{} + callCount := 0 + responses := map[string]string{ + "42": `{"errcode":0,"result":{"dept_id":42,"name":"研发部","parent_id":1}}`, + "1": `{"errcode":0,"result":{"dept_id":1,"name":"公司","parent_id":0}}`, + } + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + var req struct { + DeptID int64 `json:"dept_id"` + } + _ = json.NewDecoder(r.Body).Decode(&req) + w.Header().Set("Content-Type", "application/json") + if resp, ok := responses[fmt.Sprintf("%d", req.DeptID)]; ok { + _, _ = w.Write([]byte(resp)) + } else { + _, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"not found"}`)) + } + })) + defer server.Close() + + cli := &DingTalkClient{ + cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"}, + httpClient: server.Client(), + } + cli.appToken = "tok" + cli.appTokenExp = time.Now().Add(time.Hour) + + path, err := handler.resolveDingTalkDeptPath(context.Background(), cli, 42) + require.NoError(t, err) + require.Equal(t, "研发部", path) + require.Equal(t, 2, callCount) +} + +// TestSyncDingTalkIdentity_UsesCfgAttrKeys 验证 syncDingTalkIdentity 使用 cfg 中配置的 attr key +// 而不是硬编码值。通过 userAttributeService=nil 使同步路径走 warn 跳过,但在此之前先验证 +// syncField 构建逻辑(即 attr key 从 cfg 读取)。 +// 间接验证:通过构造定制 cfg,确认不同 attr key 可以正确传入(编译时保证类型正确,运行时不 panic)。 +func TestSyncDingTalkIdentity_UsesCfgAttrKeys_NoopWithNilService(t *testing.T) { + handler := &AuthHandler{ + userAttributeService: nil, // nil → 触发 warn 跳过,但不 panic + } + + cfg := config.DingTalkConnectConfig{ + CorpRestrictionPolicy: "internal_only", + SyncCorpEmail: true, + SyncDisplayName: true, + SyncDept: true, + // 自定义 attr key(非默认值) + SyncCorpEmailAttrKey: "custom_email_key", + SyncDisplayNameAttrKey: "custom_name_key", + SyncDeptAttrKey: "custom_dept_key", + } + + staff := &DingTalkStaffInfo{ + Name: "张三", + Email: "zhangsan@example.com", + } + + // 调用不应 panic(userAttributeService 为 nil 时走 warn 跳过路径) + require.NotPanics(t, func() { + handler.syncDingTalkIdentity(context.Background(), cfg, nil, 42, staff, false) + }) +} + +// TestSyncDingTalkIdentity_DefaultAttrKeys_NoopWithNilService 验证 cfg 默认 attr key 为空时 +// 使用 fallback 默认值(dingtalk_email / dingtalk_name / dingtalk_department)。 +// 此测试主要验证调用路径不 panic;实际 key 赋值默认值的逻辑在 GetDingTalkConnectOAuthConfig 层。 +func TestSyncDingTalkIdentity_DefaultAttrKeys_NoopWithNilService(t *testing.T) { + handler := &AuthHandler{ + userAttributeService: nil, + } + + cfg := config.DingTalkConnectConfig{ + CorpRestrictionPolicy: "internal_only", + SyncCorpEmail: true, + SyncDisplayName: true, + SyncDept: false, + // 不设置 attr key(等同于 GetDingTalkConnectOAuthConfig 未设置时 fallback 后的默认值已在调用前填充) + SyncCorpEmailAttrKey: "dingtalk_email", + SyncDisplayNameAttrKey: "dingtalk_name", + SyncDeptAttrKey: "dingtalk_department", + } + + staff := &DingTalkStaffInfo{ + Name: "李四", + Email: "lisi@corp.com", + } + + require.NotPanics(t, func() { + handler.syncDingTalkIdentity(context.Background(), cfg, nil, 99, staff, false) + }) +} + +// TestResolveDingTalkDeptPath_MultiLevel 验证多层部门路径拼接。 +func TestResolveDingTalkDeptPath_MultiLevel(t *testing.T) { + handler := &AuthHandler{} + // 模拟:42(AI研发) → parent=10(研发部) → parent=1(根) + responses := map[string]string{ + "42": `{"errcode":0,"result":{"dept_id":42,"name":"AI研发","parent_id":10}}`, + "10": `{"errcode":0,"result":{"dept_id":10,"name":"研发部","parent_id":1}}`, + "1": `{"errcode":0,"result":{"dept_id":1,"name":"公司","parent_id":0}}`, + } + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 解析请求 body 拿到 dept_id + var req struct { + DeptID int64 `json:"dept_id"` + } + _ = json.NewDecoder(r.Body).Decode(&req) + key := fmt.Sprintf("%d", req.DeptID) + w.Header().Set("Content-Type", "application/json") + if resp, ok := responses[key]; ok { + _, _ = w.Write([]byte(resp)) + } else { + _, _ = w.Write([]byte(`{"errcode":60003,"errmsg":"not found"}`)) + } + })) + defer server.Close() + + cli := &DingTalkClient{ + cfg: dingTalkClientConfig{UserInfoURL: server.URL + "/stub"}, + httpClient: server.Client(), + } + cli.appToken = "tok" + cli.appTokenExp = time.Now().Add(time.Hour) + + path, err := handler.resolveDingTalkDeptPath(context.Background(), cli, 42) + require.NoError(t, err) + require.Equal(t, "研发部/AI研发", path) +} diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 1f9a66ff371..a9af910d746 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -4,6 +4,7 @@ import ( "context" "log/slog" "strings" + "sync" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -18,25 +19,30 @@ import ( // AuthHandler handles authentication-related requests type AuthHandler struct { - cfg *config.Config - authService *service.AuthService - userService *service.UserService - settingSvc *service.SettingService - promoService *service.PromoService - redeemService *service.RedeemService - totpService *service.TotpService + cfg *config.Config + authService *service.AuthService + userService *service.UserService + settingSvc *service.SettingService + promoService *service.PromoService + redeemService *service.RedeemService + totpService *service.TotpService + userAttributeService *service.UserAttributeService + + dingTalkClientInstance *DingTalkClient + dingTalkClientMu sync.Mutex } // NewAuthHandler creates a new AuthHandler -func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService) *AuthHandler { +func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, redeemService *service.RedeemService, totpService *service.TotpService, userAttributeService *service.UserAttributeService) *AuthHandler { return &AuthHandler{ - cfg: cfg, - authService: authService, - userService: userService, - settingSvc: settingService, - promoService: promoService, - redeemService: redeemService, - totpService: totpService, + cfg: cfg, + authService: authService, + userService: userService, + settingSvc: settingService, + promoService: promoService, + redeemService: redeemService, + totpService: totpService, + userAttributeService: userAttributeService, } } diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index 7df4abfd451..f0ea5fdeda7 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -350,7 +350,8 @@ func (h *AuthHandler) findLinuxDoCompatEmailUser(ctx context.Context, email stri if email == "" || strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) || strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) || - strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) { + strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.DingTalkConnectSyntheticEmailDomain) { return nil, nil } @@ -519,7 +520,7 @@ func (h *AuthHandler) CompleteLinuxDoOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "linuxdo") if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/auth_oauth_pending_flow.go b/backend/internal/handler/auth_oauth_pending_flow.go index 490afd0f5d7..1014a3e8f73 100644 --- a/backend/internal/handler/auth_oauth_pending_flow.go +++ b/backend/internal/handler/auth_oauth_pending_flow.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net/http" "net/url" "strings" @@ -195,6 +196,14 @@ func (h *AuthHandler) createOAuthPendingSession(c *gin.Context, payload oauthPen }, }) if err != nil { + slog.Error("pending auth session create failed", + "intent", strings.TrimSpace(payload.Intent), + "provider_type", strings.TrimSpace(payload.Identity.ProviderType), + "provider_key", strings.TrimSpace(payload.Identity.ProviderKey), + "provider_subject_len", len(strings.TrimSpace(payload.Identity.ProviderSubject)), + "resolved_email_len", len(strings.TrimSpace(payload.ResolvedEmail)), + "has_target_user", payload.TargetUserID != nil, + "error", err.Error()) return infraerrors.InternalServer("PENDING_AUTH_SESSION_CREATE_FAILED", "failed to create pending auth session").WithCause(err) } @@ -266,6 +275,22 @@ func pendingSessionWantsInvitation(payload map[string]any) bool { return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "error")), "invitation_required") } +// pendingSessionRequiresEmailCompletion 判断 callback 写入的 completion payload 是否处于"补邮箱"状态。 +// 钉钉跨组织/staff 邮箱缺失时进入此状态:前端跳到补邮箱页,exchange 不应走 adoption apply。 +func pendingSessionRequiresEmailCompletion(payload map[string]any) bool { + if v, ok := payload["requires_email_completion"].(bool); ok && v { + return true + } + return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "email_completion") +} + +// pendingSessionRequiresBindLogin 判断 callback 写入的 completion payload 是否处于"必须绑定已有账户"状态。 +// 钉钉 signupBlocked=true(注册关 + 钉钉企业豁免关)时进入此状态:前端渲染 bind_login 表单, +// exchange 不应消费 session,否则后续 /pending/bind-login 找不到 session。 +func pendingSessionRequiresBindLogin(payload map[string]any) bool { + return strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(payload, "step")), "bind_login_required") +} + func pendingOAuthCompletionCanIssueTokenPair(session *dbent.PendingAuthSession, payload map[string]any) bool { if session == nil { return false @@ -1467,8 +1492,10 @@ func normalizePendingOAuthCompletionResponse(payload map[string]any) map[string] delete(normalized, key) } step := strings.ToLower(strings.TrimSpace(pendingSessionStringValue(normalized, "step"))) + // 把多种 choice 别名归一为 oauthPendingChoiceStep;bind_login_required 是独立终态 + // (前端渲染 needsBindLogin 而非 needsChooser),故不能并入归一化列表。 switch step { - case "choice", "choose_account_action", "choose_account", "choose", "email_required", "bind_login_required": + case "choice", "choose_account_action", "choose_account", "choose", "email_required": normalized["step"] = oauthPendingChoiceStep } if strings.EqualFold(strings.TrimSpace(pendingSessionStringValue(normalized, "step")), oauthPendingChoiceStep) { @@ -1594,6 +1621,8 @@ func (h *AuthHandler) bindPendingOAuthLogin(c *gin.Context, provider string) { } h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + // bindPendingOAuthLogin = 绑定已有账户登录,不动 users.username(用户已有自己的名字) + h.maybeSyncDingTalkAfterLogin(c.Request.Context(), session, user.ID) tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "") if err != nil { response.InternalError(c, "Failed to generate token pair") @@ -1792,6 +1821,8 @@ func (h *AuthHandler) createPendingOAuthAccount(c *gin.Context, provider string) } h.authService.RecordSuccessfulLogin(c.Request.Context(), user.ID) + // createPendingOAuthAccount = 注册新账户,需要把钉钉昵称同步到 users.username 作为初始值 + h.maybeSyncDingTalkAfterRegistration(c.Request.Context(), session, user.ID) clearCookies() writeOAuthTokenPairResponse(c, tokenPair) } @@ -1893,6 +1924,14 @@ func (h *AuthHandler) ExchangePendingOAuthCompletion(c *gin.Context) { response.Success(c, payload) return } + if pendingSessionRequiresEmailCompletion(payload) { + response.Success(c, payload) + return + } + if pendingSessionRequiresBindLogin(payload) { + response.Success(c, payload) + return + } if !adoptionDecision.hasDecision() { adoptionRequired, _ := payload["adoption_required"].(bool) if adoptionRequired { diff --git a/backend/internal/handler/auth_oidc_oauth.go b/backend/internal/handler/auth_oidc_oauth.go index 4264002db37..c7c517c8965 100644 --- a/backend/internal/handler/auth_oidc_oauth.go +++ b/backend/internal/handler/auth_oidc_oauth.go @@ -502,7 +502,8 @@ func (h *AuthHandler) findOIDCCompatEmailUser(ctx context.Context, email string) if email == "" || strings.HasSuffix(email, service.LinuxDoConnectSyntheticEmailDomain) || strings.HasSuffix(email, service.OIDCConnectSyntheticEmailDomain) || - strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) { + strings.HasSuffix(email, service.WeChatConnectSyntheticEmailDomain) || + strings.HasSuffix(email, service.DingTalkConnectSyntheticEmailDomain) { return nil, nil } @@ -666,7 +667,7 @@ func (h *AuthHandler) CompleteOIDCOAuthRegistration(c *gin.Context) { response.ErrorFrom(c, err) return } - tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "oidc") if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/auth_wechat_oauth.go b/backend/internal/handler/auth_wechat_oauth.go index 34e70ed09da..2199c5bd281 100644 --- a/backend/internal/handler/auth_wechat_oauth.go +++ b/backend/internal/handler/auth_wechat_oauth.go @@ -548,7 +548,7 @@ func (h *AuthHandler) CompleteWeChatOAuthRegistration(c *gin.Context) { return } - tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode) + tokenPair, user, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username, req.InvitationCode, req.AffCode, "wechat") if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/dto/account_mapper_redact_test.go b/backend/internal/handler/dto/account_mapper_redact_test.go new file mode 100644 index 00000000000..bd584e11237 --- /dev/null +++ b/backend/internal/handler/dto/account_mapper_redact_test.go @@ -0,0 +1,67 @@ +package dto + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func TestAccountFromServiceShallow_RedactsSensitiveCredentials(t *testing.T) { + src := &service.Account{ + ID: 42, + Name: "demo", + Platform: "anthropic", + Type: "oauth", + Credentials: map[string]any{ + "access_token": "at-secret", + "refresh_token": "rt-secret", + "id_token": "id-secret", + "api_key": "sk-secret", + "base_url": "https://api.example.com", + "model_mapping": map[string]any{"foo": "bar"}, + }, + } + + got := AccountFromServiceShallow(src) + require.NotNil(t, got) + + // 敏感键不在 Credentials 里 + require.NotContains(t, got.Credentials, "access_token") + require.NotContains(t, got.Credentials, "refresh_token") + require.NotContains(t, got.Credentials, "id_token") + require.NotContains(t, got.Credentials, "api_key") + // 非敏感键保留 + require.Equal(t, "https://api.example.com", got.Credentials["base_url"]) + require.Equal(t, map[string]any{"foo": "bar"}, got.Credentials["model_mapping"]) + + // 状态 map 标记敏感键存在 + require.True(t, got.CredentialsStatus["has_access_token"]) + require.True(t, got.CredentialsStatus["has_refresh_token"]) + require.True(t, got.CredentialsStatus["has_id_token"]) + require.True(t, got.CredentialsStatus["has_api_key"]) + + // JSON 序列化校验:响应体里不会出现敏感子串 + raw, err := json.Marshal(got) + require.NoError(t, err) + require.NotContains(t, string(raw), "rt-secret") + require.NotContains(t, string(raw), "at-secret") + require.NotContains(t, string(raw), "sk-secret") + require.NotContains(t, string(raw), "id-secret") + // 状态标识应序列化进 JSON + require.Contains(t, string(raw), "credentials_status") + require.Contains(t, string(raw), "has_refresh_token") + + // 原始 service.Account 不应被改动 + require.Equal(t, "rt-secret", src.Credentials["refresh_token"]) +} + +func TestAccountFromServiceShallow_NilCredentialsOmitsStatus(t *testing.T) { + src := &service.Account{ID: 1, Name: "n", Platform: "anthropic", Type: "oauth"} + got := AccountFromServiceShallow(src) + require.NotNil(t, got) + require.Nil(t, got.Credentials) + require.Nil(t, got.CredentialsStatus) +} diff --git a/backend/internal/handler/dto/credentials_redact.go b/backend/internal/handler/dto/credentials_redact.go new file mode 100644 index 00000000000..e65a8007060 --- /dev/null +++ b/backend/internal/handler/dto/credentials_redact.go @@ -0,0 +1,44 @@ +// Package dto provides data transfer objects for HTTP handlers. +package dto + +import "github.com/Wei-Shaw/sub2api/internal/service" + +// RedactCredentials 复制一份 in,剥离 service.SensitiveCredentialKeys 列出的所有敏感子键, +// 并产出一个 has_ 状态 map 表示哪些敏感键存在且非零值。 +// +// 输入 nil 时返回 nil, nil(避免响应里出现空对象)。 +// 不修改入参;调用方拿到的 out 可安全序列化进 JSON 返回前端。 +func RedactCredentials(in map[string]any) (out map[string]any, status map[string]bool) { + if in == nil { + return nil, nil + } + out = make(map[string]any, len(in)) + for k, v := range in { + if service.IsSensitiveCredentialKey(k) { + if isCredentialValuePresent(v) { + if status == nil { + status = make(map[string]bool, 4) + } + status["has_"+k] = true + } + continue + } + out[k] = v + } + return out, status +} + +// isCredentialValuePresent 判断值是否"存在且非零"。空字符串、nil、false 均视为未配置; +// 其余非零类型(数字、对象、字符串等)视为已配置。 +func isCredentialValuePresent(v any) bool { + switch x := v.(type) { + case nil: + return false + case string: + return x != "" + case bool: + return x + default: + return true + } +} diff --git a/backend/internal/handler/dto/credentials_redact_test.go b/backend/internal/handler/dto/credentials_redact_test.go new file mode 100644 index 00000000000..431078fafdd --- /dev/null +++ b/backend/internal/handler/dto/credentials_redact_test.go @@ -0,0 +1,97 @@ +package dto + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRedactCredentials_NilInput(t *testing.T) { + out, status := RedactCredentials(nil) + require.Nil(t, out) + require.Nil(t, status) +} + +func TestRedactCredentials_StripsSensitiveKeysAndReportsStatus(t *testing.T) { + in := map[string]any{ + "refresh_token": "rt-secret", + "access_token": "at-secret", + "api_key": "sk-secret", + "aws_secret_access_key": "aws-secret", + "service_account_json": map[string]any{"private_key": "..."}, + "private_key": "raw-key", + // 非敏感 + "base_url": "https://api.example.com", + "model_mapping": map[string]any{"foo": "bar"}, + "project_id": "proj-1", + "expires_at": int64(123456), + } + + out, status := RedactCredentials(in) + + require.NotContains(t, out, "refresh_token") + require.NotContains(t, out, "access_token") + require.NotContains(t, out, "api_key") + require.NotContains(t, out, "aws_secret_access_key") + require.NotContains(t, out, "service_account_json") + require.NotContains(t, out, "private_key") + + require.Equal(t, "https://api.example.com", out["base_url"]) + require.Equal(t, map[string]any{"foo": "bar"}, out["model_mapping"]) + require.Equal(t, "proj-1", out["project_id"]) + require.Equal(t, int64(123456), out["expires_at"]) + + require.True(t, status["has_refresh_token"]) + require.True(t, status["has_access_token"]) + require.True(t, status["has_api_key"]) + require.True(t, status["has_aws_secret_access_key"]) + require.True(t, status["has_service_account_json"]) + require.True(t, status["has_private_key"]) + + // 状态 map 不应携带非敏感键的 has_* + require.NotContains(t, status, "has_base_url") + require.NotContains(t, status, "has_project_id") +} + +func TestRedactCredentials_EmptyValuesNotMarkedPresent(t *testing.T) { + in := map[string]any{ + "refresh_token": "", + "access_token": nil, + "api_key": false, + "id_token": "actual-id", + } + out, status := RedactCredentials(in) + require.Empty(t, out, "敏感键即使为空也不应出现在 redacted output") + require.False(t, status["has_refresh_token"]) + require.False(t, status["has_access_token"]) + require.False(t, status["has_api_key"]) + require.True(t, status["has_id_token"]) +} + +func TestRedactCredentials_DoesNotMutateInput(t *testing.T) { + in := map[string]any{ + "refresh_token": "secret", + "base_url": "x", + } + _, _ = RedactCredentials(in) + require.Equal(t, "secret", in["refresh_token"], "原始 map 不应被修改") + require.Equal(t, "x", in["base_url"]) +} + +func TestRedactCredentials_AllKnownSensitiveKeys(t *testing.T) { + keys := []string{ + "access_token", "refresh_token", "id_token", + "api_key", "session_key", "cookie", + "aws_secret_access_key", "aws_session_token", + "service_account_json", "service_account", "private_key", + } + in := make(map[string]any, len(keys)) + for _, k := range keys { + in[k] = "filled" + } + out, status := RedactCredentials(in) + require.Empty(t, out) + for _, k := range keys { + require.True(t, status["has_"+k], "key %s 应在 status 中标记为已配置", k) + } +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 2559b112cb9..2c71be9d492 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -198,13 +198,15 @@ func AccountFromServiceShallow(a *service.Account) *Account { if a == nil { return nil } + redactedCreds, credsStatus := RedactCredentials(a.Credentials) out := &Account{ ID: a.ID, Name: a.Name, Notes: a.Notes, Platform: a.Platform, Type: a.Type, - Credentials: a.Credentials, + Credentials: redactedCreds, + CredentialsStatus: credsStatus, Extra: a.Extra, ProxyID: a.ProxyID, Concurrency: a.Concurrency, @@ -531,11 +533,15 @@ func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode { UsedBy: rc.UsedBy, UsedAt: rc.UsedAt, CreatedAt: rc.CreatedAt, + ExpiresAt: rc.ExpiresAt, GroupID: rc.GroupID, ValidityDays: rc.ValidityDays, User: UserFromServiceShallow(rc.User), Group: GroupFromServiceShallow(rc.Group), } + if rc.IsExpired() { + out.Status = service.StatusExpired + } // For admin_balance/admin_concurrency types, include notes so users can see // why they were charged or credited by admin @@ -600,6 +606,10 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { FirstTokenMs: l.FirstTokenMs, ImageCount: l.ImageCount, ImageSize: l.ImageSize, + ImageInputSize: l.ImageInputSize, + ImageOutputSize: l.ImageOutputSize, + ImageSizeSource: l.ImageSizeSource, + ImageSizeBreakdown: l.ImageSizeBreakdown, MediaType: l.MediaType, UserAgent: l.UserAgent, CacheTTLOverridden: l.CacheTTLOverridden, diff --git a/backend/internal/handler/dto/mappers_usage_test.go b/backend/internal/handler/dto/mappers_usage_test.go index c2635e339a5..eca838b9aff 100644 --- a/backend/internal/handler/dto/mappers_usage_test.go +++ b/backend/internal/handler/dto/mappers_usage_test.go @@ -148,6 +148,65 @@ func TestUsageLogFromService_FallsBackToLegacyModelWhenRequestedModelMissing(t * require.Equal(t, "claude-3", adminDTO.Model) } +func TestUsageLogFromService_IncludesImageBillingMetadataForUserAndAdmin(t *testing.T) { + t.Parallel() + + imageSize := "4K" + inputSize := "1024x1024" + outputSize := "3840x2160" + source := "output" + log := &service.UsageLog{ + RequestID: "req_image_metadata", + Model: "gpt-image-2", + ImageCount: 2, + ImageSize: &imageSize, + ImageInputSize: &inputSize, + ImageOutputSize: &outputSize, + ImageSizeSource: &source, + ImageSizeBreakdown: map[string]int{"4K": 2}, + } + + userDTO := UsageLogFromService(log) + adminDTO := UsageLogFromServiceAdmin(log) + + for _, got := range []*UsageLog{userDTO, &adminDTO.UsageLog} { + require.Equal(t, 2, got.ImageCount) + require.NotNil(t, got.ImageSize) + require.Equal(t, imageSize, *got.ImageSize) + require.NotNil(t, got.ImageInputSize) + require.Equal(t, inputSize, *got.ImageInputSize) + require.NotNil(t, got.ImageOutputSize) + require.Equal(t, outputSize, *got.ImageOutputSize) + require.NotNil(t, got.ImageSizeSource) + require.Equal(t, source, *got.ImageSizeSource) + require.Equal(t, map[string]int{"4K": 2}, got.ImageSizeBreakdown) + } +} + +func TestUsageLogFromService_PreservesHistoricalMissingImageSize(t *testing.T) { + t.Parallel() + + log := &service.UsageLog{ + RequestID: "req_legacy_image_missing_size", + Model: "gpt-image-2", + ImageCount: 1, + ImageSize: nil, + } + + dto := UsageLogFromService(log) + require.Equal(t, 1, dto.ImageCount) + require.Nil(t, dto.ImageSize) + require.Nil(t, dto.ImageInputSize) + require.Nil(t, dto.ImageOutputSize) + require.Nil(t, dto.ImageSizeSource) + require.Nil(t, dto.ImageSizeBreakdown) + + body, err := json.Marshal(dto) + require.NoError(t, err) + require.Contains(t, string(body), `"image_size":null`) + require.NotContains(t, string(body), `"image_size":"2K"`) +} + func f64Ptr(value float64) *float64 { return &value } diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 551cf0dc995..fb09faf76a3 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -56,6 +56,23 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"` LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + DingTalkConnectEnabled bool `json:"dingtalk_connect_enabled"` + DingTalkConnectClientID string `json:"dingtalk_connect_client_id"` + DingTalkConnectClientSecretConfigured bool `json:"dingtalk_connect_client_secret_configured"` + DingTalkConnectRedirectURL string `json:"dingtalk_connect_redirect_url"` + DingTalkConnectCorpRestrictionPolicy string `json:"dingtalk_connect_corp_restriction_policy"` + DingTalkConnectInternalCorpID string `json:"dingtalk_connect_internal_corp_id"` + DingTalkConnectBypassRegistration bool `json:"dingtalk_connect_bypass_registration"` + DingTalkConnectSyncCorpEmail bool `json:"dingtalk_connect_sync_corp_email"` + DingTalkConnectSyncDisplayName bool `json:"dingtalk_connect_sync_display_name"` + DingTalkConnectSyncDept bool `json:"dingtalk_connect_sync_dept"` + DingTalkConnectSyncCorpEmailAttrKey string `json:"dingtalk_connect_sync_corp_email_attr_key"` + DingTalkConnectSyncDisplayNameAttrKey string `json:"dingtalk_connect_sync_display_name_attr_key"` + DingTalkConnectSyncDeptAttrKey string `json:"dingtalk_connect_sync_dept_attr_key"` + DingTalkConnectSyncCorpEmailAttrName string `json:"dingtalk_connect_sync_corp_email_attr_name"` + DingTalkConnectSyncDisplayNameAttrName string `json:"dingtalk_connect_sync_display_name_attr_name"` + DingTalkConnectSyncDeptAttrName string `json:"dingtalk_connect_sync_dept_attr_name"` + WeChatConnectEnabled bool `json:"wechat_connect_enabled"` WeChatConnectAppID string `json:"wechat_connect_app_id"` WeChatConnectAppSecretConfigured bool `json:"wechat_connect_app_secret_configured"` @@ -260,6 +277,7 @@ type PublicSettings struct { TablePageSizeOptions []int `json:"table_page_size_options"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` CustomEndpoints []CustomEndpoint `json:"custom_endpoints"` + DingTalkOAuthEnabled bool `json:"dingtalk_oauth_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"` WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"` diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index e15a916eec4..cc360f78fda 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -149,25 +149,28 @@ type AdminGroup struct { } type Account struct { - ID int64 `json:"id"` - Name string `json:"name"` - Notes *string `json:"notes"` - Platform string `json:"platform"` - Type string `json:"type"` - Credentials map[string]any `json:"credentials"` - Extra map[string]any `json:"extra"` - ProxyID *int64 `json:"proxy_id"` - Concurrency int `json:"concurrency"` - LoadFactor *int `json:"load_factor,omitempty"` - Priority int `json:"priority"` - RateMultiplier float64 `json:"rate_multiplier"` - Status string `json:"status"` - ErrorMessage string `json:"error_message"` - LastUsedAt *time.Time `json:"last_used_at"` - ExpiresAt *int64 `json:"expires_at"` - AutoPauseOnExpired bool `json:"auto_pause_on_expired"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID int64 `json:"id"` + Name string `json:"name"` + Notes *string `json:"notes"` + Platform string `json:"platform"` + Type string `json:"type"` + // Credentials 经 RedactCredentials 处理后只含非敏感子键;敏感 token / api_key / 私钥 + // 的存在性通过 CredentialsStatus(has_)暴露,原始值不返回前端。 + Credentials map[string]any `json:"credentials"` + CredentialsStatus map[string]bool `json:"credentials_status,omitempty"` + Extra map[string]any `json:"extra"` + ProxyID *int64 `json:"proxy_id"` + Concurrency int `json:"concurrency"` + LoadFactor *int `json:"load_factor,omitempty"` + Priority int `json:"priority"` + RateMultiplier float64 `json:"rate_multiplier"` + Status string `json:"status"` + ErrorMessage string `json:"error_message"` + LastUsedAt *time.Time `json:"last_used_at"` + ExpiresAt *int64 `json:"expires_at"` + AutoPauseOnExpired bool `json:"auto_pause_on_expired"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` Schedulable bool `json:"schedulable"` @@ -335,6 +338,7 @@ type RedeemCode struct { UsedBy *int64 `json:"used_by"` UsedAt *time.Time `json:"used_at"` CreatedAt time.Time `json:"created_at"` + ExpiresAt *time.Time `json:"expires_at,omitempty"` GroupID *int64 `json:"group_id"` ValidityDays int `json:"validity_days"` @@ -400,9 +404,13 @@ type UsageLog struct { FirstTokenMs *int `json:"first_token_ms"` // 图片生成字段 - ImageCount int `json:"image_count"` - ImageSize *string `json:"image_size"` - MediaType *string `json:"media_type"` + ImageCount int `json:"image_count"` + ImageSize *string `json:"image_size"` + ImageInputSize *string `json:"image_input_size"` + ImageOutputSize *string `json:"image_output_size"` + ImageSizeSource *string `json:"image_size_source"` + ImageSizeBreakdown map[string]int `json:"image_size_breakdown"` + MediaType *string `json:"media_type"` // User-Agent UserAgent *string `json:"user_agent"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 65836a7e452..0c88ebb4faf 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -18,6 +18,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -325,6 +326,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制 if err != nil { if len(fs.FailedAccountIDs) == 0 { + markOpsRoutingCapacityLimitedIfNoAvailable(c, err) reqLog.Warn("gateway.select_account_no_available", zap.String("model", reqModel), zap.Int64p("group_id", apiKey.GroupID), @@ -374,6 +376,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { accountReleaseFunc := selection.ReleaseFunc if !selection.Acquired { if selection.WaitPlan == nil { + markOpsRoutingCapacityLimited(c) reqLog.Warn("gateway.select_account_no_slot_no_wait_plan", zap.Int64("account_id", account.ID), zap.String("model", reqModel), @@ -566,6 +569,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, fs.FailedAccountIDs, parsedReq.MetadataUserID, subject.UserID) if err != nil { if len(fs.FailedAccountIDs) == 0 { + markOpsRoutingCapacityLimitedIfNoAvailable(c, err) reqLog.Warn("gateway.select_account_no_available", zap.String("model", reqModel), zap.Int64p("group_id", currentAPIKey.GroupID), @@ -626,6 +630,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { accountReleaseFunc := selection.ReleaseFunc if !selection.Acquired { if selection.WaitPlan == nil { + markOpsRoutingCapacityLimited(c) reqLog.Warn("gateway.select_account_no_slot_no_wait_plan", zap.Int64("account_id", account.ID), zap.String("model", reqModel), @@ -946,8 +951,8 @@ func (h *GatewayHandler) Models(c *gin.Context) { platform = forcedPlatform } - // Get available models from account configurations (without platform filter) - availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "") + // Get available models from account configurations for the selected group platform. + availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, platform) if len(availableModels) > 0 { // Build model list from whitelist @@ -968,7 +973,7 @@ func (h *GatewayHandler) Models(c *gin.Context) { } // Fallback to default models - if platform == "openai" { + if platform == service.PlatformOpenAI { c.JSON(http.StatusOK, gin.H{ "object": "list", "data": openai.DefaultModels, @@ -976,6 +981,14 @@ func (h *GatewayHandler) Models(c *gin.Context) { return } + if platform == service.PlatformGemini { + c.JSON(http.StatusOK, gin.H{ + "object": "list", + "data": geminicli.DefaultModels, + }) + return + } + c.JSON(http.StatusOK, gin.H{ "object": "list", "data": claude.DefaultModels, @@ -1312,6 +1325,11 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) { statusCode := failoverErr.StatusCode responseBody := failoverErr.ResponseBody + if service.IsOpenAISilentRefusalErrorBody(responseBody) { + service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "") + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage(), streamStarted) + return + } // 先检查透传规则 if h.errorPassthroughService != nil && len(responseBody) > 0 { @@ -1542,6 +1560,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model) if err != nil { reqLog.Warn("gateway.count_tokens_select_account_failed", zap.Error(err)) + markOpsRoutingCapacityLimitedIfNoAvailable(c, err) h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable") return } diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index c6b73190367..7d2c2b1d45f 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -161,14 +161,26 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { APIKeyID: apiKey.ID, } sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) + groupPlatform := "" + if apiKey.Group != nil { + groupPlatform = apiKey.Group.Platform + } + selectionSessionHash := sessionHash + if groupPlatform == service.PlatformGemini && selectionSessionHash != "" { + selectionSessionHash = "gemini:" + selectionSessionHash + } // 3. Account selection + failover loop fs := NewFailoverState(h.maxAccountSwitches, false) + if groupPlatform == service.PlatformGemini { + fs = NewFailoverState(h.maxAccountSwitchesGemini, false) + } for { - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0)) + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, selectionSessionHash, reqModel, fs.FailedAccountIDs, "", int64(0)) if err != nil { if len(fs.FailedAccountIDs) == 0 { + markOpsRoutingCapacityLimitedIfNoAvailable(c, err) h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) return } @@ -194,6 +206,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { accountReleaseFunc := selection.ReleaseFunc if !selection.Acquired { if selection.WaitPlan == nil { + markOpsRoutingCapacityLimited(c) h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") return } @@ -213,13 +226,33 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { } accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + if groupPlatform == service.PlatformGemini && account.Platform != service.PlatformGemini { + if accountReleaseFunc != nil { + accountReleaseFunc() + } + fs.FailedAccountIDs[account.ID] = struct{}{} + continue + } + // 5. Forward request writerSizeBeforeForward := c.Writer.Size() forwardBody := body if channelMapping.Mapped { forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) } - result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq) + var result *service.ForwardResult + if account.Platform == service.PlatformGemini { + if h.geminiCompatService == nil { + h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "upstream_error", "Gemini compatibility service is not configured") + if accountReleaseFunc != nil { + accountReleaseFunc() + } + return + } + result, err = h.geminiCompatService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody) + } else { + result, err = h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq) + } if accountReleaseFunc != nil { accountReleaseFunc() @@ -302,5 +335,10 @@ func (h *GatewayHandler) handleCCFailoverExhausted(c *gin.Context, lastErr *serv if lastErr != nil && lastErr.StatusCode > 0 { statusCode = lastErr.StatusCode } + if lastErr != nil && service.IsOpenAISilentRefusalErrorBody(lastErr.ResponseBody) { + service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "") + h.chatCompletionsErrorResponse(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage()) + return + } h.chatCompletionsErrorResponse(c, statusCode, "server_error", "All available accounts exhausted") } diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index a97f572d2a8..03246f8be2d 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -174,6 +174,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, fs.FailedAccountIDs, "", int64(0)) if err != nil { if len(fs.FailedAccountIDs) == 0 { + markOpsRoutingCapacityLimitedIfNoAvailable(c, err) h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) return } @@ -199,6 +200,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) { accountReleaseFunc := selection.ReleaseFunc if !selection.Acquired { if selection.WaitPlan == nil { + markOpsRoutingCapacityLimited(c) h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") return } @@ -308,5 +310,10 @@ func (h *GatewayHandler) handleResponsesFailoverExhausted(c *gin.Context, lastEr if lastErr != nil && lastErr.StatusCode > 0 { statusCode = lastErr.StatusCode } + if lastErr != nil && service.IsOpenAISilentRefusalErrorBody(lastErr.ResponseBody) { + service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "") + h.responsesErrorResponse(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage()) + return + } h.responsesErrorResponse(c, statusCode, "server_error", "All available accounts exhausted") } diff --git a/backend/internal/handler/gateway_models_test.go b/backend/internal/handler/gateway_models_test.go new file mode 100644 index 00000000000..af52ae23a5a --- /dev/null +++ b/backend/internal/handler/gateway_models_test.go @@ -0,0 +1,136 @@ +package handler + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type gatewayModelsAccountRepoStub struct { + service.AccountRepository + + byGroup map[int64][]service.Account +} + +type gatewayModelsResponseForTest struct { + Object string `json:"object"` + Data []gatewayModelItemForTest `json:"data"` +} + +type gatewayModelItemForTest struct { + ID string `json:"id"` +} + +func (s *gatewayModelsAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) { + accounts, ok := s.byGroup[groupID] + if !ok { + return nil, nil + } + out := make([]service.Account, len(accounts)) + copy(out, accounts) + return out, nil +} + +func newGatewayModelsHandlerForTest(repo service.AccountRepository) *GatewayHandler { + return &GatewayHandler{ + gatewayService: service.NewGatewayService( + repo, + nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + ), + } +} + +func TestGatewayModels_GeminiGroupFallsBackToGeminiModels(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(20) + h := newGatewayModelsHandlerForTest( + &gatewayModelsAccountRepoStub{ + byGroup: map[int64][]service.Account{ + groupID: { + {ID: 1, Platform: service.PlatformGemini}, + }, + }, + }, + ) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + Group: &service.Group{ID: groupID, Platform: service.PlatformGemini}, + }) + + h.Models(c) + + require.Equal(t, http.StatusOK, rec.Code) + + var got gatewayModelsResponseForTest + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, "list", got.Object) + require.Contains(t, modelIDsForTest(got.Data), "gemini-2.5-flash") + require.NotContains(t, modelIDsForTest(got.Data), "claude-sonnet-4-6") +} + +func TestGatewayModels_GeminiGroupFiltersMappedModelsByPlatform(t *testing.T) { + gin.SetMode(gin.TestMode) + + groupID := int64(21) + h := newGatewayModelsHandlerForTest( + &gatewayModelsAccountRepoStub{ + byGroup: map[int64][]service.Account{ + groupID: { + { + ID: 1, + Platform: service.PlatformAnthropic, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-sonnet-4-6": "claude-sonnet-4-6", + }, + }, + }, + { + ID: 2, + Platform: service.PlatformGemini, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-2.5-flash": "gemini-2.5-flash", + }, + }, + }, + }, + }, + }, + ) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/v1/models", nil) + c.Set(string(middleware2.ContextKeyAPIKey), &service.APIKey{ + Group: &service.Group{ID: groupID, Platform: service.PlatformGemini}, + }) + + h.Models(c) + + require.Equal(t, http.StatusOK, rec.Code) + + var got gatewayModelsResponseForTest + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, []string{"gemini-2.5-flash"}, modelIDsForTest(got.Data)) +} + +func modelIDsForTest(models []gatewayModelItemForTest) []string { + ids := make([]string, 0, len(models)) + for _, model := range models { + ids = append(ids, model.ID) + } + return ids +} diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 90ebe9ecc69..3395eeec621 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -61,6 +61,7 @@ func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { c.JSON(http.StatusOK, gemini.FallbackModelsList()) return } + markOpsRoutingCapacityLimitedIfNoAvailable(c, err) googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) return } @@ -113,6 +114,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { c.JSON(http.StatusOK, gemini.FallbackModel(modelName)) return } + markOpsRoutingCapacityLimitedIfNoAvailable(c, err) googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) return } @@ -372,6 +374,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, fs.FailedAccountIDs, "", int64(0)) // Gemini 不使用会话限制 if err != nil { if len(fs.FailedAccountIDs) == 0 { + markOpsRoutingCapacityLimitedIfNoAvailable(c, err) googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) return } @@ -419,6 +422,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { accountReleaseFunc := selection.ReleaseFunc if !selection.Acquired { if selection.WaitPlan == nil { + markOpsRoutingCapacityLimited(c) googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts") return } diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index de384710284..f78c63a2d93 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -143,6 +143,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { zap.Int("excluded_account_count", len(failedAccountIDs)), ) if len(failedAccountIDs) == 0 { + markOpsRoutingCapacityLimitedIfNoAvailable(c, err) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) return } else { @@ -155,6 +156,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { } } if selection == nil || selection.Account == nil { + markOpsRoutingCapacityLimited(c) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) return } @@ -176,6 +178,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { if channelMapping.Mapped { forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) } + writerSizeBeforeForward := c.Writer.Size() result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, "") forwardDurationMs := time.Since(forwardStart).Milliseconds() @@ -201,6 +204,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { } else { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { + if c.Writer.Size() != writerSizeBeforeForward { + h.handleFailoverExhausted(c, failoverErr, true) + return + } h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) // Pool mode: retry on the same account if failoverErr.RetryableOnSameAccount { @@ -292,7 +299,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { // resolveRawCCUpstreamEndpoint returns the actual upstream endpoint for // OpenAI Chat Completions requests. For APIKey accounts whose upstream -// has been probed to not support the Responses API, the request is +// is forced or probed to not support the Responses API, the request is // forwarded directly to /v1/chat/completions — not through the default // CC→Responses conversion path. func resolveRawCCUpstreamEndpoint(c *gin.Context, account *service.Account) string { diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 6b07b7ba70b..d9e81d4d83b 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -282,6 +282,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { zap.Int("excluded_account_count", len(failedAccountIDs)), ) if len(failedAccountIDs) == 0 { + markOpsRoutingCapacityLimitedIfNoAvailable(c, err) if errors.Is(err, service.ErrNoAvailableCompactAccounts) { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "compact_not_supported", "No available OpenAI accounts support /responses/compact", streamStarted) return @@ -297,6 +298,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } if selection == nil || selection.Account == nil { + markOpsRoutingCapacityLimited(c) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) return } @@ -330,6 +332,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { if channelMapping.Mapped { forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel) } + writerSizeBeforeForward := c.Writer.Size() result, err := h.gatewayService.Forward(c.Request.Context(), c, account, forwardBody) forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { @@ -354,6 +357,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } else { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { + if c.Writer.Size() != writerSizeBeforeForward { + h.handleFailoverExhausted(c, failoverErr, true) + return + } h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) // 池模式:同账号重试 if failoverErr.RetryableOnSameAccount { @@ -677,6 +684,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { ) if len(failedAccountIDs) == 0 { if err != nil { + markOpsRoutingCapacityLimitedIfNoAvailable(c, err) h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) return } @@ -690,6 +698,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { } } if selection == nil || selection.Account == nil { + markOpsRoutingCapacityLimited(c) h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) return } @@ -992,6 +1001,7 @@ func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot( reqLog *zap.Logger, ) (func(), bool) { if selection == nil || selection.Account == nil { + markOpsRoutingCapacityLimited(c) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted) return nil, false } @@ -1002,6 +1012,7 @@ func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot( return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true } if selection.WaitPlan == nil { + markOpsRoutingCapacityLimited(c) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted) return nil, false } @@ -1598,6 +1609,11 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) { statusCode := failoverErr.StatusCode responseBody := failoverErr.ResponseBody + if service.IsOpenAISilentRefusalErrorBody(responseBody) { + service.SetOpsUpstreamError(c, statusCode, service.OpenAISilentRefusalClientMessage(), "") + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", service.OpenAISilentRefusalClientMessage(), streamStarted) + return + } // 先检查透传规则 if h.errorPassthroughService != nil && len(responseBody) > 0 { diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go index 08a6b6e85cf..be19a03573d 100644 --- a/backend/internal/handler/openai_images.go +++ b/backend/internal/handler/openai_images.go @@ -157,6 +157,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { zap.Int("excluded_account_count", len(failedAccountIDs)), ) if len(failedAccountIDs) == 0 { + markOpsRoutingCapacityLimitedIfNoAvailable(c, err) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted) return } @@ -168,6 +169,7 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { return } if selection == nil || selection.Account == nil { + markOpsRoutingCapacityLimited(c) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available compatible accounts", streamStarted) return } diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index 935549121bb..c8803aabacd 100644 --- a/backend/internal/handler/ops_error_logger.go +++ b/backend/internal/handler/ops_error_logger.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "log" "runtime" "runtime/debug" @@ -22,10 +23,11 @@ import ( ) const ( - opsModelKey = "ops_model" - opsStreamKey = "ops_stream" - opsRequestBodyKey = "ops_request_body" - opsAccountIDKey = "ops_account_id" + opsModelKey = "ops_model" + opsStreamKey = "ops_stream" + opsRequestBodyKey = "ops_request_body" + opsAccountIDKey = "ops_account_id" + opsRoutingCapacityLimitedKey = "ops_routing_capacity_limited" opsUpstreamModelKey = "ops_upstream_model" opsRequestTypeKey = "ops_request_type" @@ -45,6 +47,8 @@ const ( opsCodeSubscriptionNotFound = "SUBSCRIPTION_NOT_FOUND" opsCodeSubscriptionInvalid = "SUBSCRIPTION_INVALID" opsCodeUserInactive = "USER_INACTIVE" + opsCodeInvalidAPIKey = "INVALID_API_KEY" + opsCodeAPIKeyRequired = "API_KEY_REQUIRED" ) const ( @@ -393,6 +397,42 @@ func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) } } +func markOpsRoutingCapacityLimited(c *gin.Context) { + if c == nil { + return + } + c.Set(opsRoutingCapacityLimitedKey, true) +} + +func markOpsRoutingCapacityLimitedIfNoAvailable(c *gin.Context, err error) { + if !isOpsNoAvailableAccountError(err) { + return + } + markOpsRoutingCapacityLimited(c) +} + +func isOpsRoutingCapacityLimited(c *gin.Context) bool { + if c == nil { + return false + } + v, ok := c.Get(opsRoutingCapacityLimitedKey) + if !ok { + return false + } + marked, _ := v.(bool) + return marked +} + +func isOpsNoAvailableAccountError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, service.ErrNoAvailableAccounts) || errors.Is(err, service.ErrNoAvailableCompactAccounts) { + return true + } + return isOpsNoAvailableAccountMessage(err.Error()) +} + type opsCaptureWriter struct { gin.ResponseWriter limit int @@ -775,11 +815,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { normalizedType := normalizeOpsErrorType(parsed.ErrorType, parsed.Code) - phase := classifyOpsPhase(normalizedType, parsed.Message, parsed.Code) - isBusinessLimited := classifyOpsIsBusinessLimited(normalizedType, phase, parsed.Code, status, parsed.Message) - - errorOwner := classifyOpsErrorOwner(phase, parsed.Message) - errorSource := classifyOpsErrorSource(phase, parsed.Message) + phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, normalizedType, parsed.Message, parsed.Code, status) entry := &service.OpsInsertErrorLogInput{ RequestID: requestID, @@ -1114,6 +1150,9 @@ func classifyOpsPhase(errType, message, code string) string { msg := strings.ToLower(message) // Standardized phases: request|auth|routing|upstream|network|internal // Map billing/concurrency/response => request; scheduling => routing. + if isOpsClientAuthError(code, msg) { + return "auth" + } switch strings.TrimSpace(code) { case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid: return "request" @@ -1134,7 +1173,7 @@ func classifyOpsPhase(errType, message, code string) string { case "upstream_error", "overloaded_error": return "upstream" case "api_error": - if strings.Contains(msg, opsErrNoAvailableAccounts) { + if isOpsNoAvailableAccountMessage(msg) { return "routing" } return "internal" @@ -1178,7 +1217,31 @@ func classifyOpsIsRetryable(errType string, statusCode int) bool { } } -func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string) bool { +func classifyOpsErrorLog(c *gin.Context, errType, message, code string, status int) (phase string, isBusinessLimited bool, errorOwner string, errorSource string) { + phase = classifyOpsPhase(errType, message, code) + routingCapacityLimited := isOpsRoutingCapacityLimited(c) + clientBusinessLimited := service.HasOpsClientBusinessLimited(c) + upstreamError := hasOpsUpstreamErrorContext(c) + if upstreamError && !routingCapacityLimited { + phase = "upstream" + } + if clientBusinessLimited && !upstreamError && !routingCapacityLimited { + phase = "auth" + } + if routingCapacityLimited { + phase = "routing" + } + localClientAuthError := !upstreamError && phase == "auth" && isOpsClientAuthError(code, strings.ToLower(message)) + isBusinessLimited = routingCapacityLimited || clientBusinessLimited || classifyOpsIsBusinessLimited(errType, phase, code, status, message, localClientAuthError) + errorOwner = classifyOpsErrorOwner(phase, message) + errorSource = classifyOpsErrorSource(phase, message) + return phase, isBusinessLimited, errorOwner, errorSource +} + +func classifyOpsIsBusinessLimited(errType, phase, code string, status int, message string, localClientAuthError ...bool) bool { + if len(localClientAuthError) > 0 && localClientAuthError[0] { + return true + } switch strings.TrimSpace(code) { case opsCodeInsufficientBalance, opsCodeUsageLimitExceeded, opsCodeSubscriptionNotFound, opsCodeSubscriptionInvalid, opsCodeUserInactive: return true @@ -1195,6 +1258,47 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa return false } +func isOpsClientAuthError(code string, msg string) bool { + switch strings.TrimSpace(code) { + case opsCodeInvalidAPIKey, opsCodeAPIKeyRequired: + return true + } + return strings.Contains(msg, "invalid api key") || strings.Contains(msg, "api key is required") +} + +func hasOpsUpstreamErrorContext(c *gin.Context) bool { + if c == nil { + return false + } + if v, ok := c.Get(service.OpsUpstreamStatusCodeKey); ok { + switch code := v.(type) { + case int: + if code > 0 { + return true + } + case int64: + if code > 0 { + return true + } + } + } + if v, ok := c.Get(service.OpsUpstreamErrorsKey); ok { + if events, ok := v.([]*service.OpsUpstreamErrorEvent); ok && len(events) > 0 { + return true + } + } + return false +} + +func isOpsNoAvailableAccountMessage(message string) bool { + msg := strings.ToLower(message) + return strings.Contains(msg, opsErrNoAvailableAccounts) || + strings.Contains(msg, "no available account") || + strings.Contains(msg, "no available gemini accounts") || + strings.Contains(msg, "no available openai accounts") || + strings.Contains(msg, "no available compatible accounts") +} + func classifyOpsErrorOwner(phase string, message string) string { // Standardized owners: client|provider|platform switch phase { diff --git a/backend/internal/handler/ops_error_logger_test.go b/backend/internal/handler/ops_error_logger_test.go index 6ae45110c3d..2a141fdfbce 100644 --- a/backend/internal/handler/ops_error_logger_test.go +++ b/backend/internal/handler/ops_error_logger_test.go @@ -275,6 +275,218 @@ func TestNormalizeOpsErrorType(t *testing.T) { } } +func TestClassifyOpsNoAvailableAccountsExcludedFromSLA(t *testing.T) { + const message = "No available accounts" + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + markOpsRoutingCapacityLimited(c) + + errType := normalizeOpsErrorType("api_error", "") + phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, message, "", http.StatusServiceUnavailable) + + require.Equal(t, "api_error", errType) + require.Equal(t, "routing", phase) + require.True(t, isBusinessLimited) + require.Equal(t, "platform", errorOwner) + require.Equal(t, "gateway", errorSource) +} + +func TestClassifyOpsRoutingCapacityMarkerExcludesMaskedSelectionFailureFromSLA(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + markOpsRoutingCapacityLimited(c) + + phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog( + c, + "api_error", + "Service temporarily unavailable", + "", + http.StatusServiceUnavailable, + ) + + require.Equal(t, "routing", phase) + require.True(t, isBusinessLimited) + require.Equal(t, "platform", errorOwner) + require.Equal(t, "gateway", errorSource) +} + +func TestClassifyOpsAuthClientErrorsExcludedFromSLA(t *testing.T) { + tests := []struct { + name string + errType string + message string + code string + status int + }{ + { + name: "standard invalid API key", + errType: "api_error", + message: "Invalid API key", + code: "INVALID_API_KEY", + status: http.StatusUnauthorized, + }, + { + name: "standard missing API key", + errType: "api_error", + message: "API key is required in Authorization header (Bearer scheme), x-api-key header, or x-goog-api-key header", + code: "API_KEY_REQUIRED", + status: http.StatusUnauthorized, + }, + { + name: "google invalid API key", + errType: "api_error", + message: "Invalid API key", + code: "401", + status: http.StatusUnauthorized, + }, + { + name: "google missing API key", + errType: "api_error", + message: "API key is required", + code: "401", + status: http.StatusUnauthorized, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + errType := normalizeOpsErrorType(tt.errType, tt.code) + phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, tt.message, tt.code, tt.status) + + require.Equal(t, "api_error", errType) + require.Equal(t, "auth", phase) + require.True(t, isBusinessLimited) + require.Equal(t, "client", errorOwner) + require.Equal(t, "client_request", errorSource) + }) + } +} + +func TestClassifyOpsIPRestrictionAccessDeniedExcludedFromSLA(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonIPRestriction) + + errType := normalizeOpsErrorType("api_error", "ACCESS_DENIED") + phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, "Access denied", "ACCESS_DENIED", http.StatusForbidden) + + require.Equal(t, "api_error", errType) + require.Equal(t, "auth", phase) + require.True(t, isBusinessLimited) + require.Equal(t, "client", errorOwner) + require.Equal(t, "client_request", errorSource) +} + +func TestClassifyOpsOtherErrorsStillCountForSLA(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + errType := normalizeOpsErrorType("api_error", "INTERNAL_ERROR") + phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, "Failed to validate API key", "INTERNAL_ERROR", http.StatusInternalServerError) + + require.Equal(t, "api_error", errType) + require.Equal(t, "internal", phase) + require.False(t, isBusinessLimited) + require.Equal(t, "platform", errorOwner) + require.Equal(t, "gateway", errorSource) +} + +func TestClassifyOpsUnsupportedModelExcludedFromSLA(t *testing.T) { + tests := []string{ + "No available accounts: no available accounts supporting model: made-up-model", + "No available accounts: no available OpenAI accounts supporting model: made-up-model", + "No available Gemini accounts: no available Gemini accounts supporting model: made-up-model", + "No available accounts: no available accounts supporting model: made-up-model (channel pricing restriction)", + } + + for _, message := range tests { + t.Run(message, func(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + markOpsRoutingCapacityLimited(c) + + errType := normalizeOpsErrorType("api_error", "") + phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog(c, errType, message, "", http.StatusServiceUnavailable) + + require.Equal(t, "api_error", errType) + require.Equal(t, "routing", phase) + require.True(t, isBusinessLimited) + require.Equal(t, "platform", errorOwner) + require.Equal(t, "gateway", errorSource) + }) + } +} + +func TestClassifyOpsUnmarkedNoAvailableTextStillCountsForSLA(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog( + c, + "api_error", + "No available accounts", + "", + http.StatusServiceUnavailable, + ) + + require.Equal(t, "routing", phase) + require.False(t, isBusinessLimited) + require.Equal(t, "platform", errorOwner) + require.Equal(t, "gateway", errorSource) +} + +func TestClassifyOpsUpstreamAuthTextStillCountsForSLA(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + service.SetOpsUpstreamError(c, http.StatusUnauthorized, "Invalid API key", "") + + phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog( + c, + "api_error", + "Invalid API key", + "401", + http.StatusUnauthorized, + ) + + require.Equal(t, "upstream", phase) + require.False(t, isBusinessLimited) + require.Equal(t, "provider", errorOwner) + require.Equal(t, "upstream_http", errorSource) +} + +func TestClassifyOpsUpstreamNoAvailableTextStillCountsForSLA(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + service.SetOpsUpstreamError(c, http.StatusServiceUnavailable, "No available accounts", "") + + phase, isBusinessLimited, errorOwner, errorSource := classifyOpsErrorLog( + c, + "api_error", + "No available accounts", + "", + http.StatusServiceUnavailable, + ) + + require.Equal(t, "upstream", phase) + require.False(t, isBusinessLimited) + require.Equal(t, "provider", errorOwner) + require.Equal(t, "upstream_http", errorSource) +} + func TestSetOpsEndpointContext_SetsContextKeys(t *testing.T) { gin.SetMode(gin.TestMode) rec := httptest.NewRecorder() diff --git a/backend/internal/handler/page_handler_test.go b/backend/internal/handler/page_handler_test.go index 0a9f0d96157..a6813cdfa07 100644 --- a/backend/internal/handler/page_handler_test.go +++ b/backend/internal/handler/page_handler_test.go @@ -58,7 +58,7 @@ func TestResolvePageImagePath(t *testing.T) { if !ok { t.Fatal("expected direct image path to be accepted") } - want := filepath.Join(base, "logo.png") + want := mustEvalSymlinks(t, filepath.Join(base, "logo.png")) if got != want { t.Fatalf("path = %q, want %q", got, want) } @@ -67,7 +67,7 @@ func TestResolvePageImagePath(t *testing.T) { if !ok { t.Fatal("expected nested image path to be accepted") } - want = filepath.Join(base, "images", "logo.png") + want = mustEvalSymlinks(t, filepath.Join(base, "images", "logo.png")) if got != want { t.Fatalf("path = %q, want %q", got, want) } @@ -100,3 +100,13 @@ func TestResolvePageImagePathRejectsSymlinkEscape(t *testing.T) { t.Fatalf("expected symlink escape to be rejected, got %q", got) } } + +func mustEvalSymlinks(t *testing.T, path string) string { + t.Helper() + + realPath, err := filepath.EvalSymlinks(path) + if err != nil { + t.Fatalf("eval symlinks for %q: %v", path, err) + } + return realPath +} diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 6c389e3da1e..c4ba43e4f5e 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -61,6 +61,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { TablePageSizeOptions: settings.TablePageSizeOptions, CustomMenuItems: dto.ParseUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), + DingTalkOAuthEnabled: settings.DingTalkOAuthEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, WeChatOAuthEnabled: settings.WeChatOAuthEnabled, WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled, diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 3f6ed8c2bcc..f1dbf4e14ea 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -67,6 +67,7 @@ type userProfileResponse struct { LinuxDoBound bool `json:"linuxdo_bound"` OIDCBound bool `json:"oidc_bound"` WeChatBound bool `json:"wechat_bound"` + DingTalkBound bool `json:"dingtalk_bound"` } type userProfileSourceContext struct { @@ -528,15 +529,17 @@ func userProfileResponseFromService(user *service.User, identities service.UserI LinuxDoBound: identities.LinuxDo.Bound, OIDCBound: identities.OIDC.Bound, WeChatBound: identities.WeChat.Bound, + DingTalkBound: identities.DingTalk.Bound, } } func userProfileBindingMap(identities service.UserIdentitySummarySet) map[string]service.UserIdentitySummary { return map[string]service.UserIdentitySummary{ - "email": identities.Email, - "linuxdo": identities.LinuxDo, - "oidc": identities.OIDC, - "wechat": identities.WeChat, + "email": identities.Email, + "linuxdo": identities.LinuxDo, + "oidc": identities.OIDC, + "wechat": identities.WeChat, + "dingtalk": identities.DingTalk, } } @@ -585,7 +588,7 @@ func inferUserProfileSources(user *service.User, identities service.UserIdentity func thirdPartyIdentityProviders(identities service.UserIdentitySummarySet) []service.UserIdentitySummary { out := make([]service.UserIdentitySummary, 0, 3) - for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat} { + for _, summary := range []service.UserIdentitySummary{identities.LinuxDo, identities.OIDC, identities.WeChat, identities.DingTalk} { if summary.Bound { out = append(out, summary) } diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go index 1234b56819e..c4c6e634e5a 100644 --- a/backend/internal/payment/provider/alipay.go +++ b/backend/internal/payment/provider/alipay.go @@ -105,10 +105,16 @@ func (a *Alipay) MerchantIdentityMetadata() map[string]string { // CreatePayment creates an Alipay payment using the following routing: // - Mobile (H5): alipay.trade.wap.pay — browser redirect into Alipay. -// - Desktop: prefer alipay.trade.precreate to get a scan payload directly. -// - Desktop fallback: if precreate is unavailable for the merchant, fall back -// to alipay.trade.page.pay and expose both pay_url and qr_code so the -// frontend can render a QR while still allowing direct page open. +// - Desktop, default: prefer alipay.trade.precreate (FACE_TO_FACE_PAYMENT) to +// get a scannable QR payload. If precreate is unavailable for the merchant, +// fall back to alipay.trade.page.pay and expose pay_url only — the frontend +// opens the Alipay checkout in a new tab. +// - Desktop, paymentMode == "redirect": skip precreate and go straight to +// alipay.trade.page.pay so the frontend always opens the Alipay checkout +// in a new tab. Use this when the merchant has not enabled FACE_TO_FACE_PAYMENT. +// +// Note: alipay.trade.page.pay returns a checkout page URL, not a scannable +// payment QR. Never expose it via the QRCode field. func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { client, err := a.getClient() if err != nil { @@ -150,6 +156,13 @@ func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePayment } func (a *Alipay) createDesktopTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) { + // Explicit redirect mode: merchant opted into "always open the Alipay + // checkout page in a new tab" via the provider instance's payment_mode. + // Skip precreate to avoid a wasted API call. + if strings.EqualFold(strings.TrimSpace(a.config["paymentMode"]), "redirect") { + return a.createPagePayTrade(client, req, notifyURL, returnURL) + } + resp, precreateErr := a.createPrecreateTrade(ctx, client, req, notifyURL) if precreateErr == nil { return resp, nil @@ -204,10 +217,12 @@ func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePay if err != nil { return nil, fmt.Errorf("alipay TradePagePay: %w", err) } + // Only PayURL is exposed: alipay.trade.page.pay returns a checkout page URL + // that must be opened in a browser, not a scannable payment QR. Setting it + // as QRCode would let the frontend render an unscannable image. return &payment.CreatePaymentResponse{ TradeNo: req.OrderID, PayURL: payURL.String(), - QRCode: payURL.String(), }, nil } diff --git a/backend/internal/payment/provider/alipay_test.go b/backend/internal/payment/provider/alipay_test.go index fdc8eec1ac5..9f8aec53173 100644 --- a/backend/internal/payment/provider/alipay_test.go +++ b/backend/internal/payment/provider/alipay_test.go @@ -189,8 +189,63 @@ func TestCreateTradeUsesPagePayForDesktop(t *testing.T) { if resp.PayURL == "" { t.Fatal("expected pay_url for desktop page pay") } - if resp.QRCode != resp.PayURL { - t.Fatalf("qr_code = %q, want same as pay_url %q", resp.QRCode, resp.PayURL) + // page.pay returns a checkout page URL, not a scannable QR payload — + // it must never be exposed via QRCode (the frontend would render an + // unscannable image from it). + if resp.QRCode != "" { + t.Fatalf("qr_code = %q, want empty for page pay", resp.QRCode) + } +} + +// When the provider instance is configured with paymentMode == "redirect", +// the desktop flow must skip precreate and go straight to page.pay. +func TestCreateTradeRedirectModeSkipsPrecreate(t *testing.T) { + origPreCreate := alipayTradePreCreate + origPagePay := alipayTradePagePay + t.Cleanup(func() { + alipayTradePreCreate = origPreCreate + alipayTradePagePay = origPagePay + }) + + preCreateCalls := 0 + pagePayCalls := 0 + alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) { + preCreateCalls++ + return &alipay.TradePreCreateRsp{ + Error: alipay.Error{Code: alipay.CodeSuccess}, + QRCode: "https://qr.alipay.example.com/precreate-token", + }, nil + } + alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) { + pagePayCalls++ + if param.ProductCode != alipayProductCodePagePay { + t.Fatalf("product_code = %q, want %q", param.ProductCode, alipayProductCodePagePay) + } + return url.Parse("https://openapi.alipay.com/gateway.do?page-pay") + } + + provider := &Alipay{ + config: map[string]string{"paymentMode": "redirect"}, + } + resp, err := provider.createDesktopTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{ + OrderID: "sub2_103", + Amount: "12.00", + Subject: "Balance recharge", + }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if preCreateCalls != 0 { + t.Fatalf("precreate calls = %d, want 0 (redirect mode must skip precreate)", preCreateCalls) + } + if pagePayCalls != 1 { + t.Fatalf("page pay calls = %d, want 1", pagePayCalls) + } + if resp.PayURL == "" { + t.Fatal("expected pay_url for redirect mode") + } + if resp.QRCode != "" { + t.Fatalf("qr_code = %q, want empty for redirect mode", resp.QRCode) } } diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index 16aff9f848f..e318d1cdaf3 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -254,6 +254,8 @@ const ( proxyTLSHandshakeTimeout = 5 * time.Second // clientTimeout 整体请求超时(含连接、发送、等待响应、读取 body) clientTimeout = 10 * time.Second + // fetchAvailableModelsBodyLimit limits model-list responses to avoid unbounded memory use. + fetchAvailableModelsBodyLimit int64 = 8 << 20 ) func NewClient(proxyURL string) (*Client, error) { @@ -655,6 +657,10 @@ type FetchAvailableModelsResponse struct { // FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON // 支持 URL fallback:sandbox → daily → prod func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) { + if c == nil || c.httpClient == nil { + return nil, nil, errors.New("antigravity client is not configured") + } + reqBody := FetchAvailableModelsRequest{Project: projectID} bodyBytes, err := json.Marshal(reqBody) if err != nil { @@ -664,6 +670,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI // 固定顺序:prod -> daily availableURLs := BaseURLs + fetchClient := c.fetchAvailableModelsHTTPClient() var lastErr error for urlIdx, baseURL := range availableURLs { apiURL := baseURL + "/v1internal:fetchAvailableModels" @@ -676,7 +683,7 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI req.Header.Set("Content-Type", "application/json") req.Header.Set("User-Agent", GetUserAgentForContext(ctx)) - resp, err := c.httpClient.Do(req) + resp, err := fetchClient.Do(req) if err != nil { lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err) if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { @@ -686,11 +693,14 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI return nil, nil, lastErr } - respBodyBytes, err := io.ReadAll(resp.Body) + respBodyBytes, err := io.ReadAll(io.LimitReader(resp.Body, fetchAvailableModelsBodyLimit+1)) _ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏 if err != nil { return nil, nil, fmt.Errorf("读取响应失败: %w", err) } + if int64(len(respBodyBytes)) > fetchAvailableModelsBodyLimit { + return nil, nil, fmt.Errorf("响应超过 %d 字节", fetchAvailableModelsBodyLimit) + } // 检查是否需要 URL 降级 if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { @@ -726,6 +736,42 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI return nil, nil, lastErr } +func (c *Client) fetchAvailableModelsHTTPClient() *http.Client { + fetchClient := *c.httpClient + fetchClient.CheckRedirect = checkFetchAvailableModelsRedirect + return &fetchClient +} + +func checkFetchAvailableModelsRedirect(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + if req == nil || req.URL == nil { + return errors.New("redirect url is nil") + } + if !isAllowedFetchAvailableModelsRedirectHost(req.URL.Hostname()) { + return fmt.Errorf("redirect to unsupported host: %s", req.URL.Hostname()) + } + return nil +} + +func isAllowedFetchAvailableModelsRedirectHost(host string) bool { + host = strings.ToLower(strings.TrimSpace(host)) + if host == "" { + return false + } + for _, baseURL := range BaseURLs { + parsed, err := url.Parse(baseURL) + if err != nil { + continue + } + if strings.EqualFold(host, parsed.Hostname()) { + return true + } + } + return false +} + // ── Privacy API ────────────────────────────────────────────────────── // privacyBaseURL 隐私设置 API 仅使用 daily 端点(与 Antigravity 客户端行为一致) diff --git a/backend/internal/pkg/apicompat/anthropic_responses_test.go b/backend/internal/pkg/apicompat/anthropic_responses_test.go index aa36ef0b9a9..3a0a3bc6f12 100644 --- a/backend/internal/pkg/apicompat/anthropic_responses_test.go +++ b/backend/internal/pkg/apicompat/anthropic_responses_test.go @@ -744,6 +744,10 @@ func TestStreamingReasoning(t *testing.T) { assert.Equal(t, "content_block_start", events[0].Type) assert.Equal(t, "thinking", events[0].ContentBlock.Type) + sse, err := ResponsesAnthropicEventToSSE(events[0]) + require.NoError(t, err) + assert.Contains(t, sse, `"content_block":{"thinking":"","type":"thinking"}`) + // reasoning text delta events = ResponsesEventToAnthropicEvents(&ResponsesStreamEvent{ Type: "response.reasoning_summary_text.delta", diff --git a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go index bf5c23d5652..25f5c475fec 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_responses_test.go +++ b/backend/internal/pkg/apicompat/chatcompletions_responses_test.go @@ -225,6 +225,41 @@ func TestChatCompletionsToResponses_WhitespaceOnlyBase64ImageURLSkipped(t *testi assert.Equal(t, "Describe this", parts[0].Text) } +func TestChatCompletionsToResponses_EmptyContentNeverNull(t *testing.T) { + // Regression for #2515: the upstream Responses API rejects an input item + // whose content field is JSON null. Any chat-completions message that + // yields no usable content parts must serialize content as a string. + cases := []struct { + name string + content json.RawMessage + }{ + {"null content", json.RawMessage(`null`)}, + {"empty array content", json.RawMessage(`[]`)}, + {"only empty text part", json.RawMessage(`[{"type":"text","text":""}]`)}, + {"only empty base64 image part", json.RawMessage(`[{"type":"image_url","image_url":{"url":"data:image/png;base64,"}}]`)}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-5.5", + Messages: []ChatMessage{ + {Role: "user", Content: tc.content}, + }, + } + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + assert.NotContains(t, string(resp.Input), `"content":null`, + "converted input must not contain a null content field") + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 1) + assert.Equal(t, `""`, string(items[0].Content), + "content must be an empty string, not null") + }) + } +} + func TestChatCompletionsToResponses_SystemArrayContent(t *testing.T) { req := &ChatCompletionsRequest{ Model: "gpt-4o", @@ -379,6 +414,34 @@ func TestChatCompletionsToResponses_AssistantThinkingTagPreserved(t *testing.T) assert.Contains(t, parts[0].Text, "final answer") } +func TestChatCompletionsToResponses_AssistantReasoningContentPreserved(t *testing.T) { + req := &ChatCompletionsRequest{ + Model: "gpt-4o", + Messages: []ChatMessage{ + {Role: "user", Content: json.RawMessage(`"Hi"`)}, + { + Role: "assistant", + ReasoningContent: "internal plan", + Content: json.RawMessage(`"final answer"`), + }, + }, + } + + resp, err := ChatCompletionsToResponses(req) + require.NoError(t, err) + + var items []ResponsesInputItem + require.NoError(t, json.Unmarshal(resp.Input, &items)) + require.Len(t, items, 2) + + var parts []ResponsesContentPart + require.NoError(t, json.Unmarshal(items[1].Content, &parts)) + require.Len(t, parts, 1) + assert.Equal(t, "output_text", parts[0].Type) + assert.Contains(t, parts[0].Text, "internal plan") + assert.Contains(t, parts[0].Text, "final answer") +} + // --------------------------------------------------------------------------- // ResponsesToChatCompletions tests // --------------------------------------------------------------------------- diff --git a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go index 64ef5781585..fe2c150b20f 100644 --- a/backend/internal/pkg/apicompat/chatcompletions_to_responses.go +++ b/backend/internal/pkg/apicompat/chatcompletions_to_responses.go @@ -150,6 +150,11 @@ func chatUserToResponses(m ChatMessage) ([]ResponsesInputItem, error) { // empty/nil and there are tool_calls, only function_call items are emitted. func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) { var items []ResponsesInputItem + content := "" + + if m.ReasoningContent != "" { + content = "" + m.ReasoningContent + "" + } // Emit assistant message with output_text if content is non-empty. if len(m.Content) > 0 { @@ -158,13 +163,20 @@ func chatAssistantToResponses(m ChatMessage) ([]ResponsesInputItem, error) { return nil, err } if s != "" { - parts := []ResponsesContentPart{{Type: "output_text", Text: s}} - partsJSON, err := json.Marshal(parts) - if err != nil { - return nil, err + if content != "" { + content += "\n" } - items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON}) + content += s + } + } + + if content != "" { + parts := []ResponsesContentPart{{Type: "output_text", Text: content}} + partsJSON, err := json.Marshal(parts) + if err != nil { + return nil, err } + items = append(items, ResponsesInputItem{Role: "assistant", Content: partsJSON}) } // Emit one function_call item per tool_call. @@ -325,7 +337,14 @@ func marshalChatInputContent(content chatMessageContent) (json.RawMessage, error if content.Text != nil { return json.Marshal(*content.Text) } - return json.Marshal(convertChatContentPartsToResponses(content.Parts)) + parts := convertChatContentPartsToResponses(content.Parts) + if len(parts) == 0 { + // A nil slice marshals to JSON null, which the upstream Responses API + // rejects ("expected an array of objects or string, but got null"). + // Fall back to an empty string when no usable parts remain. + return json.Marshal("") + } + return json.Marshal(parts) } func convertChatContentPartsToResponses(parts []ChatContentPart) []ResponsesContentPart { diff --git a/backend/internal/pkg/apicompat/types.go b/backend/internal/pkg/apicompat/types.go index f9cd5a1c7f9..7c46ccaf310 100644 --- a/backend/internal/pkg/apicompat/types.go +++ b/backend/internal/pkg/apicompat/types.go @@ -75,6 +75,28 @@ type AnthropicContentBlock struct { IsError bool `json:"is_error,omitempty"` } +func (b AnthropicContentBlock) MarshalJSON() ([]byte, error) { + type anthropicContentBlock AnthropicContentBlock + base := struct { + anthropicContentBlock + }{anthropicContentBlock: anthropicContentBlock(b)} + + switch b.Type { + case "text": + return json.Marshal(struct { + Text string `json:"text"` + anthropicContentBlock + }{Text: b.Text, anthropicContentBlock: anthropicContentBlock(b)}) + case "thinking": + return json.Marshal(struct { + Thinking string `json:"thinking"` + anthropicContentBlock + }{Thinking: b.Thinking, anthropicContentBlock: anthropicContentBlock(b)}) + default: + return json.Marshal(base) + } +} + // AnthropicImageSource describes the source data for an image content block. type AnthropicImageSource struct { Type string `json:"type"` // "base64" @@ -306,6 +328,37 @@ type ResponsesUsage struct { OutputTokensDetails *ResponsesOutputTokensDetails `json:"output_tokens_details,omitempty"` } +func (u *ResponsesUsage) UnmarshalJSON(data []byte) error { + type responsesUsageAlias ResponsesUsage + var aux struct { + responsesUsageAlias + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + PromptTokensDetails *ResponsesInputTokensDetails `json:"prompt_tokens_details,omitempty"` + CompletionTokensDetails *ResponsesOutputTokensDetails `json:"completion_tokens_details,omitempty"` + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + *u = ResponsesUsage(aux.responsesUsageAlias) + if u.InputTokens == 0 && aux.PromptTokens != 0 { + u.InputTokens = aux.PromptTokens + } + if u.OutputTokens == 0 && aux.CompletionTokens != 0 { + u.OutputTokens = aux.CompletionTokens + } + if u.InputTokensDetails == nil && aux.PromptTokensDetails != nil { + u.InputTokensDetails = aux.PromptTokensDetails + } + if u.OutputTokensDetails == nil && aux.CompletionTokensDetails != nil { + u.OutputTokensDetails = aux.CompletionTokensDetails + } + if u.TotalTokens == 0 && (u.InputTokens != 0 || u.OutputTokens != 0) { + u.TotalTokens = u.InputTokens + u.OutputTokens + } + return nil +} + // ResponsesInputTokensDetails breaks down input token usage. type ResponsesInputTokensDetails struct { CachedTokens int `json:"cached_tokens,omitempty"` diff --git a/backend/internal/pkg/openai_compat/upstream_capability.go b/backend/internal/pkg/openai_compat/upstream_capability.go index ff05afe55b8..154a01fb819 100644 --- a/backend/internal/pkg/openai_compat/upstream_capability.go +++ b/backend/internal/pkg/openai_compat/upstream_capability.go @@ -17,7 +17,7 @@ // pensieve/short-term/maxims/preserve-existing-runtime-behavior-when-replacing-logic-in-stateful-systems) package openai_compat -// AccountResponsesSupport 描述账号上游对 OpenAI Responses API 的支持状态。 +// AccountResponsesSupport 描述账号上游对 OpenAI Responses API 的有效支持状态。 // // 仅用于 platform=openai + type=apikey 的账号;其他账号类型不应调用本包判定。 type AccountResponsesSupport int @@ -35,11 +35,43 @@ const ( ResponsesSupportNo ) -// ExtraKeyResponsesSupported 是 accounts.extra JSON 中存储探测结果的键名。 +// ResponsesSupportMode 描述账号级 Responses API 路由覆盖模式。 +type ResponsesSupportMode string + +const ( + // ResponsesSupportModeAuto 表示跟随自动探测结果。 + ResponsesSupportModeAuto ResponsesSupportMode = "auto" + + // ResponsesSupportModeForceResponses 强制使用 /v1/responses。 + ResponsesSupportModeForceResponses ResponsesSupportMode = "force_responses" + + // ResponsesSupportModeForceChatCompletions 强制使用 /v1/chat/completions。 + ResponsesSupportModeForceChatCompletions ResponsesSupportMode = "force_chat_completions" +) + +// ExtraKeyResponsesMode 是 accounts.extra JSON 中存储手动覆盖模式的键名。 +// 值类型为 string:auto=跟随探测,force_responses=强制 Responses, +// force_chat_completions=强制 Chat Completions。 +const ExtraKeyResponsesMode = "openai_responses_mode" + +// ExtraKeyResponsesSupported 是 accounts.extra JSON 中存储自动探测结果的键名。 // 值类型为 bool:true=支持、false=不支持、键缺失=未探测。 const ExtraKeyResponsesSupported = "openai_responses_supported" -// ResolveResponsesSupport 从账号的 extra map 中读取探测标记。 +// NormalizeResponsesSupportMode 归一化账号级 Responses API 路由覆盖模式。 +// 缺失或非法值按 auto 处理,以保持存量行为。 +func NormalizeResponsesSupportMode(mode string) ResponsesSupportMode { + switch ResponsesSupportMode(mode) { + case ResponsesSupportModeForceResponses: + return ResponsesSupportModeForceResponses + case ResponsesSupportModeForceChatCompletions: + return ResponsesSupportModeForceChatCompletions + default: + return ResponsesSupportModeAuto + } +} + +// ResolveResponsesSupport 从账号的 extra map 中读取手动覆盖模式与探测标记。 // // 标记缺失或类型不匹配时返回 ResponsesSupportUnknown——调用方应按 // "未探测=保留旧行为=走 Responses" 处理(参见 ShouldUseResponsesAPI)。 @@ -47,6 +79,14 @@ func ResolveResponsesSupport(extra map[string]any) AccountResponsesSupport { if extra == nil { return ResponsesSupportUnknown } + if mode, ok := extra[ExtraKeyResponsesMode].(string); ok { + switch NormalizeResponsesSupportMode(mode) { + case ResponsesSupportModeForceResponses: + return ResponsesSupportYes + case ResponsesSupportModeForceChatCompletions: + return ResponsesSupportNo + } + } v, ok := extra[ExtraKeyResponsesSupported] if !ok { return ResponsesSupportUnknown diff --git a/backend/internal/pkg/openai_compat/upstream_capability_test.go b/backend/internal/pkg/openai_compat/upstream_capability_test.go index d650daa4365..008579a757a 100644 --- a/backend/internal/pkg/openai_compat/upstream_capability_test.go +++ b/backend/internal/pkg/openai_compat/upstream_capability_test.go @@ -16,6 +16,12 @@ func TestResolveResponsesSupport(t *testing.T) { {"value wrong type string", map[string]any{ExtraKeyResponsesSupported: "true"}, ResponsesSupportUnknown}, {"value wrong type number", map[string]any{ExtraKeyResponsesSupported: 1}, ResponsesSupportUnknown}, {"value nil", map[string]any{ExtraKeyResponsesSupported: nil}, ResponsesSupportUnknown}, + {"force responses", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses)}, ResponsesSupportYes}, + {"force chat completions", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions)}, ResponsesSupportNo}, + {"auto follows probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeAuto), ExtraKeyResponsesSupported: false}, ResponsesSupportNo}, + {"invalid mode follows probe", map[string]any{ExtraKeyResponsesMode: "bogus", ExtraKeyResponsesSupported: true}, ResponsesSupportYes}, + {"force responses overrides probe false", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses), ExtraKeyResponsesSupported: false}, ResponsesSupportYes}, + {"force chat completions overrides probe true", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions), ExtraKeyResponsesSupported: true}, ResponsesSupportNo}, } for _, tc := range tests { @@ -42,6 +48,10 @@ func TestShouldUseResponsesAPI(t *testing.T) { // 已探测:标记决定 {"explicitly supported", map[string]any{ExtraKeyResponsesSupported: true}, true}, {"explicitly unsupported", map[string]any{ExtraKeyResponsesSupported: false}, false}, + + // 手动覆盖:覆盖自动探测结果 + {"force responses overrides unsupported probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceResponses), ExtraKeyResponsesSupported: false}, true}, + {"force chat completions overrides supported probe", map[string]any{ExtraKeyResponsesMode: string(ResponsesSupportModeForceChatCompletions), ExtraKeyResponsesSupported: true}, false}, } for _, tc := range tests { @@ -53,3 +63,26 @@ func TestShouldUseResponsesAPI(t *testing.T) { }) } } + +func TestNormalizeResponsesSupportMode(t *testing.T) { + tests := []struct { + name string + mode string + want ResponsesSupportMode + }{ + {"empty", "", ResponsesSupportModeAuto}, + {"auto", "auto", ResponsesSupportModeAuto}, + {"force responses", "force_responses", ResponsesSupportModeForceResponses}, + {"force chat completions", "force_chat_completions", ResponsesSupportModeForceChatCompletions}, + {"invalid", "enabled", ResponsesSupportModeAuto}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := NormalizeResponsesSupportMode(tc.mode) + if got != tc.want { + t.Errorf("NormalizeResponsesSupportMode(%q) = %q, want %q", tc.mode, got, tc.want) + } + }) + } +} diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index fe5f98d66e0..39283d22a28 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -230,6 +230,20 @@ type UserDashboardStats struct { // 性能指标 Rpm int64 `json:"rpm"` // 近5分钟平均每分钟请求数 Tpm int64 `json:"tpm"` // 近5分钟平均每分钟Token数 + + // 按"有效平台"维度拆分(与 ops 路径口径一致:group.platform 优先,否则 account.platform) + ByPlatform []PlatformDashboardStats `json:"by_platform,omitempty"` +} + +// PlatformDashboardStats 单个平台的用量明细。 +type PlatformDashboardStats struct { + Platform string `json:"platform"` + TotalRequests int64 `json:"total_requests"` + TotalTokens int64 `json:"total_tokens"` + TotalActualCost float64 `json:"total_actual_cost"` + TodayRequests int64 `json:"today_requests"` + TodayTokens int64 `json:"today_tokens"` + TodayActualCost float64 `json:"today_actual_cost"` } // UsageLogFilters represents filters for usage log queries @@ -265,13 +279,22 @@ type UsageStats struct { EndpointPaths []EndpointStat `json:"endpoint_paths,omitempty"` } -// BatchUserUsageStats represents usage stats for a single user -type BatchUserUsageStats struct { - UserID int64 `json:"user_id"` +// PlatformUsage 表示某用户/某 API key 在单个"有效平台"维度的用量明细。 +// Platform 取值与 ops 路径口径一致:优先 groups.platform,否则 accounts.platform。 +type PlatformUsage struct { + Platform string `json:"platform"` TodayActualCost float64 `json:"today_actual_cost"` TotalActualCost float64 `json:"total_actual_cost"` } +// BatchUserUsageStats represents usage stats for a single user +type BatchUserUsageStats struct { + UserID int64 `json:"user_id"` + TodayActualCost float64 `json:"today_actual_cost"` + TotalActualCost float64 `json:"total_actual_cost"` + ByPlatform []PlatformUsage `json:"by_platform,omitempty"` +} + // BatchAPIKeyUsageStats represents usage stats for a single API key type BatchAPIKeyUsageStats struct { APIKeyID int64 `json:"api_key_id"` diff --git a/backend/internal/repository/account_repo_compact_extra_test.go b/backend/internal/repository/account_repo_compact_extra_test.go index 604f392e84e..e2ce6602741 100644 --- a/backend/internal/repository/account_repo_compact_extra_test.go +++ b/backend/internal/repository/account_repo_compact_extra_test.go @@ -12,3 +12,14 @@ func TestShouldEnqueueSchedulerOutboxForExtraUpdates_CompactCapabilityKeysAreRel t.Fatalf("expected compact capability updates to enqueue scheduler outbox") } } + +func TestShouldEnqueueSchedulerOutboxForExtraUpdates_OpenAIResponsesCapabilityKeysAreRelevant(t *testing.T) { + updates := map[string]any{ + "openai_responses_mode": "force_chat_completions", + "openai_responses_supported": false, + } + + if !shouldEnqueueSchedulerOutboxForExtraUpdates(updates) { + t.Fatalf("expected responses capability updates to enqueue scheduler outbox") + } +} diff --git a/backend/internal/repository/announcement_repo.go b/backend/internal/repository/announcement_repo.go index afe1fb25c90..f19c24f1f97 100644 --- a/backend/internal/repository/announcement_repo.go +++ b/backend/internal/repository/announcement_repo.go @@ -204,7 +204,8 @@ func (r *announcementRepository) ListActive(ctx context.Context, now time.Time) announcement.Or(announcement.StartsAtIsNil(), announcement.StartsAtLTE(now)), announcement.Or(announcement.EndsAtIsNil(), announcement.EndsAtGT(now)), ). - Order(dbent.Desc(announcement.FieldID)) + Order(dbent.Desc(announcement.FieldID)). + Limit(200) items, err := q.All(ctx) if err != nil { diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 112575f49bd..9b6377bc45e 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -283,47 +283,90 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination } func (r *groupRepository) listWithAccountCountSort(ctx context.Context, q *dbent.GroupQuery, params pagination.PaginationParams, total int) ([]service.Group, *pagination.PaginationResult, error) { - groups, err := q. + // 第一步:只查 ID + sort_order(轻量,不做分页 — 需要全量排序 account_count)。 + rows, err := q.Clone(). + Select(group.FieldID, group.FieldSortOrder). Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)). All(ctx) if err != nil { return nil, nil, err } - groupIDs := make([]int64, 0, len(groups)) - outGroups := make([]service.Group, 0, len(groups)) - for i := range groups { - g := groupEntityToService(groups[i]) - outGroups = append(outGroups, *g) - groupIDs = append(groupIDs, g.ID) + type sortEntry struct { + id int64 + sortOrder int + accountCount int64 + } + entries := make([]sortEntry, 0, len(rows)) + groupIDs := make([]int64, len(rows)) + for i, r := range rows { + groupIDs[i] = r.ID + entries = append(entries, sortEntry{id: r.ID, sortOrder: r.SortOrder}) } + // 第二步:批量加载 account counts(一次 SQL)。 counts, err := r.loadAccountCounts(ctx, groupIDs) if err != nil { return nil, nil, err } - for i := range outGroups { - c := counts[outGroups[i].ID] - outGroups[i].AccountCount = c.Total - outGroups[i].ActiveAccountCount = c.Active - outGroups[i].RateLimitedAccountCount = c.RateLimited + for i := range entries { + c := counts[entries[i].id] + if c.Total > 0 { + entries[i].accountCount = c.Total + } } + // 第三步:Go 侧排序(数据量 = Group 总数,通常 < 200,安全)。 sortOrder := params.NormalizedSortOrder(pagination.SortOrderDesc) - sort.SliceStable(outGroups, func(i, j int) bool { - if outGroups[i].AccountCount == outGroups[j].AccountCount { - if outGroups[i].SortOrder == outGroups[j].SortOrder { - return outGroups[i].ID < outGroups[j].ID - } - return outGroups[i].SortOrder < outGroups[j].SortOrder + tieCmp := func(a, b sortEntry) bool { + if a.sortOrder == b.sortOrder { + return a.id < b.id + } + return a.sortOrder < b.sortOrder + } + sort.SliceStable(entries, func(i, j int) bool { + if entries[i].accountCount == entries[j].accountCount { + return tieCmp(entries[i], entries[j]) } if sortOrder == pagination.SortOrderAsc { - return outGroups[i].AccountCount < outGroups[j].AccountCount + return entries[i].accountCount < entries[j].accountCount } - return outGroups[i].AccountCount > outGroups[j].AccountCount + return entries[i].accountCount > entries[j].accountCount }) - return paginateSlice(outGroups, params), paginationResultFromTotal(int64(total), params), nil + // 第四步:分页,只加载当前页需要的完整 Group。 + page := paginateSlice(entries, params) + if len(page) == 0 { + return nil, paginationResultFromTotal(int64(total), params), nil + } + + pageIDs := make([]int64, len(page)) + pageIdx := make(map[int64]int, len(page)) + for i, e := range page { + pageIDs[i] = e.id + pageIdx[e.id] = i + } + + groups, err := r.client.Group.Query(). + Where(group.IDIn(pageIDs...)). + All(ctx) + if err != nil { + return nil, nil, err + } + + outGroups := make([]service.Group, len(page)) + for i := range groups { + g := groupEntityToService(groups[i]) + c := counts[g.ID] + g.AccountCount = c.Total + g.ActiveAccountCount = c.Active + g.RateLimitedAccountCount = c.RateLimited + if idx, ok := pageIdx[g.ID]; ok { + outGroups[idx] = *g + } + } + + return outGroups, paginationResultFromTotal(int64(total), params), nil } func groupListOrder(params pagination.PaginationParams) []func(*entsql.Selector) { diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index eeee5c23f40..7ef82f0cb30 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ b/backend/internal/repository/migrations_schema_integration_test.go @@ -44,6 +44,33 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false) requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false) requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false) + requireColumn(t, tx, "usage_logs", "image_input_size", "character varying", 32, true) + requireColumn(t, tx, "usage_logs", "image_output_size", "character varying", 32, true) + requireColumn(t, tx, "usage_logs", "image_size_source", "character varying", 16, true) + requireColumn(t, tx, "usage_logs", "image_size_breakdown", "jsonb", 0, true) + requireConstraintDefinitionContains( + t, + tx, + "usage_logs", + "usage_logs_image_size_source_check", + "image_size_source", + "'output'", + "'input'", + "'default'", + "'legacy'", + ) + requireConstraintDefinitionContains( + t, + tx, + "usage_logs", + "usage_logs_image_billing_size_check", + "image_count", + "image_size IS NOT NULL", + "'1K'", + "'2K'", + "'4K'", + "'mixed'", + ) // usage_billing_dedup: billing idempotency narrow table var usageBillingDedupRegclass sql.NullString diff --git a/backend/internal/repository/redeem_code_repo.go b/backend/internal/repository/redeem_code_repo.go index 07975970ef8..47c38d3e07d 100644 --- a/backend/internal/repository/redeem_code_repo.go +++ b/backend/internal/repository/redeem_code_repo.go @@ -30,6 +30,7 @@ func (r *redeemCodeRepository) Create(ctx context.Context, code *service.RedeemC SetStatus(code.Status). SetNotes(code.Notes). SetValidityDays(code.ValidityDays). + SetNillableExpiresAt(code.ExpiresAt). SetNillableUsedBy(code.UsedBy). SetNillableUsedAt(code.UsedAt). SetNillableGroupID(code.GroupID). @@ -56,6 +57,7 @@ func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []service. SetStatus(c.Status). SetNotes(c.Notes). SetValidityDays(c.ValidityDays). + SetNillableExpiresAt(c.ExpiresAt). SetNillableUsedBy(c.UsedBy). SetNillableUsedAt(c.UsedAt). SetNillableGroupID(c.GroupID) @@ -107,7 +109,28 @@ func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagin q = q.Where(redeemcode.TypeEQ(codeType)) } if status != "" { - q = q.Where(redeemcode.StatusEQ(status)) + now := time.Now() + switch status { + case service.StatusExpired: + q = q.Where(redeemcode.Or( + redeemcode.StatusEQ(service.StatusExpired), + redeemcode.And( + redeemcode.StatusEQ(service.StatusUnused), + redeemcode.ExpiresAtNotNil(), + redeemcode.ExpiresAtLTE(now), + ), + )) + case service.StatusUnused: + q = q.Where( + redeemcode.StatusEQ(service.StatusUnused), + redeemcode.Or( + redeemcode.ExpiresAtIsNil(), + redeemcode.ExpiresAtGT(now), + ), + ) + default: + q = q.Where(redeemcode.StatusEQ(status)) + } } if search != "" { q = q.Where( @@ -158,6 +181,8 @@ func redeemCodeListOrder(params pagination.PaginationParams) []func(*entsql.Sele field = redeemcode.FieldUsedAt case "created_at": field = redeemcode.FieldCreatedAt + case "expires_at": + field = redeemcode.FieldExpiresAt case "code": field = redeemcode.FieldCode default: @@ -194,6 +219,11 @@ func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemC } else { up.ClearGroupID() } + if code.ExpiresAt != nil { + up.SetExpiresAt(*code.ExpiresAt) + } else { + up.ClearExpiresAt() + } updated, err := up.Save(ctx) if err != nil { @@ -307,6 +337,7 @@ func redeemCodeEntityToService(m *dbent.RedeemCode) *service.RedeemCode { UsedAt: m.UsedAt, Notes: derefString(m.Notes), CreatedAt: m.CreatedAt, + ExpiresAt: m.ExpiresAt, GroupID: m.GroupID, ValidityDays: m.ValidityDays, } diff --git a/backend/internal/repository/redeem_code_repo_integration_test.go b/backend/internal/repository/redeem_code_repo_integration_test.go index 39674b52c0a..24e5910ef9b 100644 --- a/backend/internal/repository/redeem_code_repo_integration_test.go +++ b/backend/internal/repository/redeem_code_repo_integration_test.go @@ -51,11 +51,13 @@ func (s *RedeemCodeRepoSuite) createGroup(name string) *dbent.Group { // --- Create / CreateBatch / GetByID / GetByCode --- func (s *RedeemCodeRepoSuite) TestCreate() { + expiresAt := time.Now().UTC().Add(2 * time.Hour) code := &service.RedeemCode{ - Code: "TEST-CREATE", - Type: service.RedeemTypeBalance, - Value: 100, - Status: service.StatusUnused, + Code: "TEST-CREATE", + Type: service.RedeemTypeBalance, + Value: 100, + Status: service.StatusUnused, + ExpiresAt: &expiresAt, } err := s.repo.Create(s.ctx, code) @@ -65,6 +67,8 @@ func (s *RedeemCodeRepoSuite) TestCreate() { got, err := s.repo.GetByID(s.ctx, code.ID) s.Require().NoError(err, "GetByID") s.Require().Equal("TEST-CREATE", got.Code) + s.Require().NotNil(got.ExpiresAt) + s.Require().WithinDuration(expiresAt, *got.ExpiresAt, time.Second) } func (s *RedeemCodeRepoSuite) TestCreateBatch() { @@ -166,6 +170,23 @@ func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() { s.Require().Equal(service.StatusUsed, codes[0].Status) } +func (s *RedeemCodeRepoSuite) TestListWithFilters_StatusExpiredByExpiresAt() { + past := time.Now().UTC().Add(-time.Hour) + future := time.Now().UTC().Add(time.Hour) + s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-EXPIRED-BY-TIME", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused, ExpiresAt: &past})) + s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-UNUSED-FUTURE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused, ExpiresAt: &future})) + + expired, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusExpired, "") + s.Require().NoError(err) + s.Require().Len(expired, 1) + s.Require().Equal("STAT-EXPIRED-BY-TIME", expired[0].Code) + + unused, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusUnused, "") + s.Require().NoError(err) + s.Require().Len(unused, 1) + s.Require().Equal("STAT-UNUSED-FUTURE", unused[0].Code) +} + func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() { s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused})) s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "BETA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused})) diff --git a/backend/internal/repository/scheduler_cache.go b/backend/internal/repository/scheduler_cache.go index 590ddaa36e2..ab01a8638ac 100644 --- a/backend/internal/repository/scheduler_cache.go +++ b/backend/internal/repository/scheduler_cache.go @@ -546,6 +546,8 @@ func filterSchedulerExtra(extra map[string]any) map[string]any { "responses_websockets_v2_enabled", "openai_ws_enabled", "openai_ws_force_http", + "openai_responses_mode", + "openai_responses_supported", } filtered := make(map[string]any) for _, key := range keys { diff --git a/backend/internal/repository/scheduler_cache_unit_test.go b/backend/internal/repository/scheduler_cache_unit_test.go index 33f3b581b51..86de87c7099 100644 --- a/backend/internal/repository/scheduler_cache_unit_test.go +++ b/backend/internal/repository/scheduler_cache_unit_test.go @@ -18,6 +18,8 @@ func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) { "openai_oauth_responses_websockets_v2_enabled": true, "openai_oauth_responses_websockets_v2_mode": service.OpenAIWSIngressModePassthrough, "openai_ws_force_http": true, + "openai_responses_mode": "force_chat_completions", + "openai_responses_supported": false, "mixed_scheduling": true, "unused_large_field": "drop-me", }, @@ -28,6 +30,8 @@ func TestBuildSchedulerMetadataAccount_KeepsOpenAIWSFlags(t *testing.T) { require.Equal(t, true, got.Extra["openai_oauth_responses_websockets_v2_enabled"]) require.Equal(t, service.OpenAIWSIngressModePassthrough, got.Extra["openai_oauth_responses_websockets_v2_mode"]) require.Equal(t, true, got.Extra["openai_ws_force_http"]) + require.Equal(t, "force_chat_completions", got.Extra["openai_responses_mode"]) + require.Equal(t, false, got.Extra["openai_responses_supported"]) require.Equal(t, true, got.Extra["mixed_scheduling"]) require.Nil(t, got.Extra["unused_large_field"]) } diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index f2fb87da33e..f11910a08ef 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -28,7 +28,7 @@ import ( gocache "github.com/patrickmn/go-cache" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, image_input_size, image_output_size, image_size_source, image_size_breakdown, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at" // usageLogInsertArgTypes must stay in the same order as: // 1. prepareUsageLogInsert().args @@ -73,6 +73,10 @@ var usageLogInsertArgTypes = [...]string{ "text", // ip_address "integer", // image_count "text", // image_size + "text", // image_input_size + "text", // image_output_size + "text", // image_size_source + "jsonb", // image_size_breakdown "text", // service_tier "text", // reasoning_effort "text", // inbound_endpoint @@ -92,6 +96,22 @@ const rawUsageLogModelColumn = "model" // Historical rows may contain upstream/billing model values, while newer rows store requested_model. // Requested/upstream/mapping analytics must use resolveModelDimensionExpression instead. +// usageLogSuccessFilterUL 用于把"失败请求 usage log"(tokens=0、cost=0、不计费的占位记录) +// 从统计性聚合中排除,避免污染 Dashboard / 用量拆分等指标。 +// +// schema 中没有 success bool 列;新增列要做迁移,风险大;这里用 actual_cost > 0 作为代理: +// 任何成功落账的请求都会产生 actual_cost(包括 token 计费、纯图片 token 计费、按次/按图计费), +// 反之 failed-request usage log 的 actual_cost 为 0。 +// 早期版本用 4 项 token 和 > 0 判定会把"按次/按图计费"与"image_output_tokens 独立计费"的纯图片 +// 请求误判为失败,导致这部分请求从用量统计里消失,故改用 actual_cost。 +// 配合 `FROM usage_logs ul` JOIN 查询使用。 +const usageLogSuccessFilterUL = "ul.actual_cost > 0" + +// usageLogEffectivePlatformExpr 用于按"有效平台"维度聚合 usage_logs: +// 优先取请求实际走的分组 platform,若分组未设置 platform 再 fallback 到 account.platform。 +// 配套要求查询里 LEFT JOIN groups g ON g.id = ul.group_id 与 LEFT JOIN accounts a ON a.id = ul.account_id。 +const usageLogEffectivePlatformExpr = "COALESCE(NULLIF(g.platform,''), a.platform)" + // dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL var dateFormatWhitelist = map[string]string{ "hour": "YYYY-MM-DD HH24:00", @@ -120,6 +140,24 @@ func appendRawUsageLogModelWhereCondition(conditions []string, args []any, model return conditions, args } +func appendUsageLogBillingModeWhereCondition(conditions []string, args []any, billingMode string) ([]string, []any) { + mode := strings.TrimSpace(billingMode) + if mode == "" { + return conditions, args + } + placeholder := fmt.Sprintf("$%d", len(args)+1) + switch service.BillingMode(mode) { + case service.BillingModeImage: + conditions = append(conditions, fmt.Sprintf("(billing_mode = %s OR COALESCE(image_count, 0) > 0)", placeholder)) + case service.BillingModeToken: + conditions = append(conditions, fmt.Sprintf("(billing_mode = %s OR ((billing_mode IS NULL OR billing_mode = '') AND COALESCE(image_count, 0) <= 0))", placeholder)) + default: + conditions = append(conditions, fmt.Sprintf("billing_mode = %s", placeholder)) + } + args = append(args, mode) + return conditions, args +} + // appendRawUsageLogModelQueryFilter keeps direct model filters on the raw model column for backward // compatibility with historical rows. Requested/upstream analytics must use // resolveModelDimensionExpression instead. @@ -352,6 +390,10 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, ip_address, image_count, image_size, + image_input_size, + image_output_size, + image_size_source, + image_size_breakdown, service_tier, reasoning_effort, inbound_endpoint, @@ -369,7 +411,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, - $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46 + $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -790,6 +832,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ip_address, image_count, image_size, + image_input_size, + image_output_size, + image_size_source, + image_size_breakdown, service_tier, reasoning_effort, inbound_endpoint, @@ -803,7 +849,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage created_at ) AS (VALUES `) - args := make([]any, 0, len(keys)*46) + args := make([]any, 0, len(keys)*50) argPos := 1 for idx, key := range keys { if idx > 0 { @@ -867,6 +913,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ip_address, image_count, image_size, + image_input_size, + image_output_size, + image_size_source, + image_size_breakdown, service_tier, reasoning_effort, inbound_endpoint, @@ -915,6 +965,10 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage ip_address, image_count, image_size, + image_input_size, + image_output_size, + image_size_source, + image_size_breakdown, service_tier, reasoning_effort, inbound_endpoint, @@ -1003,6 +1057,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ip_address, image_count, image_size, + image_input_size, + image_output_size, + image_size_source, + image_size_breakdown, service_tier, reasoning_effort, inbound_endpoint, @@ -1016,7 +1074,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( created_at ) AS (VALUES `) - args := make([]any, 0, len(preparedList)*46) + args := make([]any, 0, len(preparedList)*50) argPos := 1 for idx, prepared := range preparedList { if idx > 0 { @@ -1077,6 +1135,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ip_address, image_count, image_size, + image_input_size, + image_output_size, + image_size_source, + image_size_breakdown, service_tier, reasoning_effort, inbound_endpoint, @@ -1125,6 +1187,10 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( ip_address, image_count, image_size, + image_input_size, + image_output_size, + image_size_source, + image_size_breakdown, service_tier, reasoning_effort, inbound_endpoint, @@ -1181,6 +1247,10 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared ip_address, image_count, image_size, + image_input_size, + image_output_size, + image_size_source, + image_size_breakdown, service_tier, reasoning_effort, inbound_endpoint, @@ -1198,7 +1268,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, - $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46 + $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46, $47, $48, $49, $50 ) ON CONFLICT (request_id, api_key_id) DO NOTHING `, prepared.args...) @@ -1225,6 +1295,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { userAgent := nullString(log.UserAgent) ipAddress := nullString(log.IPAddress) imageSize := nullString(log.ImageSize) + imageInputSize := nullString(log.ImageInputSize) + imageOutputSize := nullString(log.ImageOutputSize) + imageSizeSource := nullString(log.ImageSizeSource) + imageSizeBreakdown := nullStringIntMapJSON(log.ImageSizeBreakdown) serviceTier := nullString(log.ServiceTier) reasoningEffort := nullString(log.ReasoningEffort) inboundEndpoint := nullString(log.InboundEndpoint) @@ -1285,6 +1359,10 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { ipAddress, log.ImageCount, imageSize, + imageInputSize, + imageOutputSize, + imageSizeSource, + imageSizeBreakdown, serviceTier, reasoningEffort, inboundEndpoint, @@ -2352,6 +2430,9 @@ func (r *usageLogRepository) GetUserSpendingRanking(ctx context.Context, startTi // UserDashboardStats 用户仪表盘统计 type UserDashboardStats = usagestats.UserDashboardStats +// PlatformDashboardStats 单平台用量明细 +type PlatformDashboardStats = usagestats.PlatformDashboardStats + // GetUserDashboardStats 获取用户专属的仪表盘统计 func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID int64) (*UserDashboardStats, error) { stats := &UserDashboardStats{} @@ -2447,6 +2528,57 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i stats.Rpm = rpm stats.Tpm = tpm + // 按"有效平台"维度拆分(group.platform 优先,否则 account.platform)。 + // 与 ops 路径口径一致;HAVING 过滤掉无法确定平台的行(避免出现空字符串平台)。 + // 与上面 totalStatsQuery/todayStatsQuery 的总值可能略微差异,原因有二: + // 1) 无平台归属的极少数行(group/account 都没 platform)会被 HAVING 排除; + // 2) usageLogSuccessFilterUL 会把 actual_cost = 0 的失败 placeholder 行排除, + // 而 totalStatsQuery/todayStatsQuery 没有这层过滤、会把这些行的 request 计数算进去。 + platformQuery := ` + SELECT + ` + usageLogEffectivePlatformExpr + ` as platform, + COUNT(*) as total_requests, + COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens, + COALESCE(SUM(ul.actual_cost), 0) as total_actual_cost, + COUNT(*) FILTER (WHERE ul.created_at >= $2) as today_requests, + COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens) FILTER (WHERE ul.created_at >= $2), 0) as today_tokens, + COALESCE(SUM(ul.actual_cost) FILTER (WHERE ul.created_at >= $2), 0) as today_actual_cost + FROM usage_logs ul + LEFT JOIN groups g ON g.id = ul.group_id + LEFT JOIN accounts a ON a.id = ul.account_id + WHERE ul.user_id = $1 + AND ` + usageLogSuccessFilterUL + ` + GROUP BY ` + usageLogEffectivePlatformExpr + ` + HAVING ` + usageLogEffectivePlatformExpr + ` IS NOT NULL AND ` + usageLogEffectivePlatformExpr + ` <> '' + ORDER BY total_actual_cost DESC + ` + rows, err := r.sql.QueryContext(ctx, platformQuery, userID, today) + if err != nil { + return nil, err + } + for rows.Next() { + var p PlatformDashboardStats + if err := rows.Scan( + &p.Platform, + &p.TotalRequests, + &p.TotalTokens, + &p.TotalActualCost, + &p.TodayRequests, + &p.TodayTokens, + &p.TodayActualCost, + ); err != nil { + _ = rows.Close() + return nil, err + } + stats.ByPlatform = append(stats.ByPlatform, p) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return stats, nil } @@ -2662,10 +2794,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) args = append(args, int16(*filters.BillingType)) } - if filters.BillingMode != "" { - conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1)) - args = append(args, filters.BillingMode) - } + conditions, args = appendUsageLogBillingModeWhereCondition(conditions, args, filters.BillingMode) if filters.StartTime != nil { conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1)) args = append(args, *filters.StartTime) @@ -2710,6 +2839,9 @@ type UsageStats = usagestats.UsageStats // BatchUserUsageStats represents usage stats for a single user type BatchUserUsageStats = usagestats.BatchUserUsageStats +// PlatformUsage represents per-platform usage breakdown +type PlatformUsage = usagestats.PlatformUsage + func normalizePositiveInt64IDs(ids []int64) []int64 { if len(ids) == 0 { return nil @@ -2750,15 +2882,21 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs result[id] = &BatchUserUsageStats{UserID: id} } + // GROUP BY (user_id, effective_platform) 一次查询同时得到总值与按平台拆分。 + // 应用层把同一 user_id 的多行累加为总值,并把非空 platform 行收集到 ByPlatform。 query := ` SELECT - user_id, - COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $2 AND created_at < $3), 0) as total_cost, - COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $4), 0) as today_cost - FROM usage_logs - WHERE user_id = ANY($1) - AND created_at >= LEAST($2, $4) - GROUP BY user_id + ul.user_id, + ` + usageLogEffectivePlatformExpr + ` as platform, + COALESCE(SUM(ul.actual_cost) FILTER (WHERE ul.created_at >= $2 AND ul.created_at < $3), 0) as total_cost, + COALESCE(SUM(ul.actual_cost) FILTER (WHERE ul.created_at >= $4), 0) as today_cost + FROM usage_logs ul + LEFT JOIN groups g ON g.id = ul.group_id + LEFT JOIN accounts a ON a.id = ul.account_id + WHERE ul.user_id = ANY($1) + AND ul.created_at >= LEAST($2, $4) + AND ` + usageLogSuccessFilterUL + ` + GROUP BY ul.user_id, ` + usageLogEffectivePlatformExpr + ` ` today := timezone.Today() rows, err := r.sql.QueryContext(ctx, query, pq.Array(normalizedUserIDs), startTime, endTime, today) @@ -2767,15 +2905,25 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs } for rows.Next() { var userID int64 + var platform sql.NullString var total float64 var todayTotal float64 - if err := rows.Scan(&userID, &total, &todayTotal); err != nil { + if err := rows.Scan(&userID, &platform, &total, &todayTotal); err != nil { _ = rows.Close() return nil, err } - if stats, ok := result[userID]; ok { - stats.TotalActualCost = total - stats.TodayActualCost = todayTotal + stats, ok := result[userID] + if !ok { + continue + } + stats.TotalActualCost += total + stats.TodayActualCost += todayTotal + if platform.Valid && platform.String != "" { + stats.ByPlatform = append(stats.ByPlatform, PlatformUsage{ + Platform: platform.String, + TotalActualCost: total, + TodayActualCost: todayTotal, + }) } } if err := rows.Close(); err != nil { @@ -3363,10 +3511,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) args = append(args, int16(*filters.BillingType)) } - if filters.BillingMode != "" { - conditions = append(conditions, fmt.Sprintf("billing_mode = $%d", len(args)+1)) - args = append(args, filters.BillingMode) - } + conditions, args = appendUsageLogBillingModeWhereCondition(conditions, args, filters.BillingMode) if filters.StartTime != nil { conditions = append(conditions, fmt.Sprintf("created_at >= $%d", len(args)+1)) args = append(args, *filters.StartTime) @@ -4084,6 +4229,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ipAddress sql.NullString imageCount int imageSize sql.NullString + imageInputSize sql.NullString + imageOutputSize sql.NullString + imageSizeSource sql.NullString + imageSizeBreakdown sql.NullString serviceTier sql.NullString reasoningEffort sql.NullString inboundEndpoint sql.NullString @@ -4134,6 +4283,10 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &ipAddress, &imageCount, &imageSize, + &imageInputSize, + &imageOutputSize, + &imageSizeSource, + &imageSizeBreakdown, &serviceTier, &reasoningEffort, &inboundEndpoint, @@ -4212,6 +4365,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if imageSize.Valid { log.ImageSize = &imageSize.String } + if imageInputSize.Valid { + log.ImageInputSize = &imageInputSize.String + } + if imageOutputSize.Valid { + log.ImageOutputSize = &imageOutputSize.String + } + if imageSizeSource.Valid { + log.ImageSizeSource = &imageSizeSource.String + } + log.ImageSizeBreakdown = stringIntMapFromNullJSON(imageSizeBreakdown) if serviceTier.Valid { log.ServiceTier = &serviceTier.String } @@ -4378,6 +4541,31 @@ func nullString(v *string) sql.NullString { return sql.NullString{String: *v, Valid: true} } +func nullStringIntMapJSON(v map[string]int) any { + if len(v) == 0 { + return nil + } + payload, err := json.Marshal(v) + if err != nil { + return nil + } + return string(payload) +} + +func stringIntMapFromNullJSON(v sql.NullString) map[string]int { + if !v.Valid || strings.TrimSpace(v.String) == "" { + return nil + } + var out map[string]int + if err := json.Unmarshal([]byte(v.String), &out); err != nil { + return nil + } + if len(out) == 0 { + return nil + } + return out +} + func coalesceTrimmedString(v sql.NullString, fallback string) string { if v.Valid && strings.TrimSpace(v.String) != "" { return v.String diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index a5ff4bc177a..597c95971d8 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -76,6 +76,10 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { sqlmock.AnyArg(), // ip_address log.ImageCount, sqlmock.AnyArg(), // image_size + sqlmock.AnyArg(), // image_input_size + sqlmock.AnyArg(), // image_output_size + sqlmock.AnyArg(), // image_size_source + sqlmock.AnyArg(), // image_size_breakdown sqlmock.AnyArg(), // service_tier sqlmock.AnyArg(), // reasoning_effort sqlmock.AnyArg(), // inbound_endpoint @@ -155,6 +159,10 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { sqlmock.AnyArg(), log.ImageCount, sqlmock.AnyArg(), + sqlmock.AnyArg(), // image_input_size + sqlmock.AnyArg(), // image_output_size + sqlmock.AnyArg(), // image_size_source + sqlmock.AnyArg(), // image_size_breakdown serviceTier, sqlmock.AnyArg(), sqlmock.AnyArg(), @@ -230,12 +238,74 @@ func TestPrepareUsageLogInsert_ArgCountMatchesTypes(t *testing.T) { require.Len(t, prepared.args, len(usageLogInsertArgTypes)) } +func TestPrepareUsageLogInsert_PersistsImageSizeMetadata(t *testing.T) { + imageSize := "4K" + inputSize := "1024x1024" + outputSize := "3840x2160" + source := "output" + prepared := prepareUsageLogInsert(&service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-image-metadata", + Model: "gpt-image-2", + RequestedModel: "gpt-image-2", + ImageCount: 2, + ImageSize: &imageSize, + ImageInputSize: &inputSize, + ImageOutputSize: &outputSize, + ImageSizeSource: &source, + ImageSizeBreakdown: map[string]int{"1K": 1, "4K": 1}, + CreatedAt: time.Date(2025, 1, 6, 12, 0, 0, 0, time.UTC), + }) + + require.Equal(t, sql.NullString{String: imageSize, Valid: true}, prepared.args[34]) + require.Equal(t, sql.NullString{String: inputSize, Valid: true}, prepared.args[35]) + require.Equal(t, sql.NullString{String: outputSize, Valid: true}, prepared.args[36]) + require.Equal(t, sql.NullString{String: source, Valid: true}, prepared.args[37]) + breakdownJSON, ok := prepared.args[38].(string) + require.True(t, ok) + require.JSONEq(t, `{"1K":1,"4K":1}`, breakdownJSON) +} + func TestCoalesceTrimmedString(t *testing.T) { require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{}, "fallback")) require.Equal(t, "fallback", coalesceTrimmedString(sql.NullString{Valid: true, String: " "}, "fallback")) require.Equal(t, "value", coalesceTrimmedString(sql.NullString{Valid: true, String: "value"}, "fallback")) } +func TestAppendUsageLogBillingModeWhereCondition(t *testing.T) { + tests := []struct { + name string + billingMode string + wantCondition string + }{ + { + name: "image includes legacy image rows", + billingMode: string(service.BillingModeImage), + wantCondition: "(billing_mode = $1 OR COALESCE(image_count, 0) > 0)", + }, + { + name: "token includes legacy non-image rows", + billingMode: string(service.BillingModeToken), + wantCondition: "(billing_mode = $1 OR ((billing_mode IS NULL OR billing_mode = '') AND COALESCE(image_count, 0) <= 0))", + }, + { + name: "per request remains exact", + billingMode: string(service.BillingModePerRequest), + wantCondition: "billing_mode = $1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conditions, args := appendUsageLogBillingModeWhereCondition(nil, nil, tt.billingMode) + require.Equal(t, []string{tt.wantCondition}, conditions) + require.Equal(t, []any{tt.billingMode}, args) + }) + } +} + func anySliceToDriverValues(values []any) []driver.Value { out := make([]driver.Value, 0, len(values)) for _, value := range values { @@ -528,6 +598,63 @@ func (s usageLogScannerStub) Scan(dest ...any) error { } func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { + t.Run("image_size_metadata_is_scanned", func(t *testing.T) { + now := time.Now().UTC() + log, err := scanUsageLog(usageLogScannerStub{values: []any{ + int64(4), + int64(13), + int64(23), + int64(33), + sql.NullString{Valid: true, String: "req-image-metadata"}, + "gpt-image-2", + sql.NullString{Valid: true, String: "gpt-image-2"}, + sql.NullString{}, + sql.NullInt64{}, + sql.NullInt64{}, + 0, 0, 0, 0, 0, 0, + 0, 0.0, // image_output_tokens, image_output_cost + 0.0, 0.0, 0.0, 0.0, 0.8, 0.8, + 1.0, + sql.NullFloat64{}, + int16(service.BillingTypeBalance), + int16(service.RequestTypeSync), + false, + false, + sql.NullInt64{}, + sql.NullInt64{}, + sql.NullString{}, + sql.NullString{}, + 2, + sql.NullString{Valid: true, String: "4K"}, + sql.NullString{Valid: true, String: "1024x1024"}, + sql.NullString{Valid: true, String: "3840x2160"}, + sql.NullString{Valid: true, String: "output"}, + sql.NullString{Valid: true, String: `{"4K":2}`}, + sql.NullString{}, + sql.NullString{}, + sql.NullString{}, + sql.NullString{}, + false, + sql.NullInt64{}, + sql.NullString{}, + sql.NullString{}, + sql.NullString{}, + sql.NullFloat64{}, + now, + }}) + require.NoError(t, err) + require.Equal(t, 2, log.ImageCount) + require.NotNil(t, log.ImageSize) + require.Equal(t, "4K", *log.ImageSize) + require.NotNil(t, log.ImageInputSize) + require.Equal(t, "1024x1024", *log.ImageInputSize) + require.NotNil(t, log.ImageOutputSize) + require.Equal(t, "3840x2160", *log.ImageOutputSize) + require.NotNil(t, log.ImageSizeSource) + require.Equal(t, "output", *log.ImageSizeSource) + require.Equal(t, map[string]int{"4K": 2}, log.ImageSizeBreakdown) + }) + t.Run("request_type_ws_v2_overrides_legacy", func(t *testing.T) { now := time.Now().UTC() log, err := scanUsageLog(usageLogScannerStub{values: []any{ @@ -567,6 +694,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, 0, sql.NullString{}, + sql.NullString{}, // image_input_size + sql.NullString{}, // image_output_size + sql.NullString{}, // image_size_source + sql.NullString{}, // image_size_breakdown sql.NullString{Valid: true, String: "priority"}, sql.NullString{}, sql.NullString{}, @@ -615,6 +746,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, 0, sql.NullString{}, + sql.NullString{}, // image_input_size + sql.NullString{}, // image_output_size + sql.NullString{}, // image_size_source + sql.NullString{}, // image_size_breakdown sql.NullString{Valid: true, String: "flex"}, sql.NullString{}, sql.NullString{}, @@ -663,6 +798,10 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, 0, sql.NullString{}, + sql.NullString{}, // image_input_size + sql.NullString{}, // image_output_size + sql.NullString{}, // image_size_source + sql.NullString{}, // image_size_breakdown sql.NullString{Valid: true, String: "priority"}, sql.NullString{}, sql.NullString{}, diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 1566756d2bd..610d9a7b990 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -334,7 +334,8 @@ func normalizeEmailAuthIdentitySubject(email string) string { } if strings.HasSuffix(normalized, service.LinuxDoConnectSyntheticEmailDomain) || strings.HasSuffix(normalized, service.OIDCConnectSyntheticEmailDomain) || - strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) { + strings.HasSuffix(normalized, service.WeChatConnectSyntheticEmailDomain) || + strings.HasSuffix(normalized, service.DingTalkConnectSyntheticEmailDomain) { return "" } return normalized @@ -956,7 +957,7 @@ func userSignupSourceOrDefault(signupSource string) string { switch strings.TrimSpace(strings.ToLower(signupSource)) { case "", "email": return "email" - case "linuxdo", "wechat", "oidc": + case "linuxdo", "wechat", "oidc", "dingtalk": return strings.TrimSpace(strings.ToLower(signupSource)) default: return "email" diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 39869d4dffc..0d60ac9dcbc 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -68,6 +68,7 @@ func TestAPIContracts(t *testing.T) { "linuxdo_bound": false, "oidc_bound": false, "wechat_bound": false, + "dingtalk_bound": false, "identities": { "email": { "provider": "email", @@ -104,6 +105,14 @@ func TestAPIContracts(t *testing.T) { "can_bind": true, "can_unbind": false, "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "dingtalk": { + "provider": "dingtalk", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/dingtalk/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" } }, "identity_bindings": { @@ -142,6 +151,14 @@ func TestAPIContracts(t *testing.T) { "can_bind": true, "can_unbind": false, "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "dingtalk": { + "provider": "dingtalk", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/dingtalk/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" } }, "auth_bindings": { @@ -180,6 +197,14 @@ func TestAPIContracts(t *testing.T) { "can_bind": true, "can_unbind": false, "bind_start_path": "/api/v1/auth/oauth/wechat/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" + }, + "dingtalk": { + "provider": "dingtalk", + "bound": false, + "bound_count": 0, + "can_bind": true, + "can_unbind": false, + "bind_start_path": "/api/v1/auth/oauth/dingtalk/bind/start?intent=bind_current_user&redirect=%2Fsettings%2Fprofile" } }, "run_mode": "standard" @@ -554,6 +579,10 @@ func TestAPIContracts(t *testing.T) { "first_token_ms": 50, "image_count": 0, "image_size": null, + "image_input_size": null, + "image_output_size": null, + "image_size_source": null, + "image_size_breakdown": null, "media_type": null, "cache_ttl_overridden": false, "created_at": "2025-01-02T03:04:05Z", @@ -672,6 +701,22 @@ func TestAPIContracts(t *testing.T) { "linuxdo_connect_client_id": "", "linuxdo_connect_client_secret_configured": false, "linuxdo_connect_redirect_url": "", + "dingtalk_connect_enabled": false, + "dingtalk_connect_bypass_registration": false, + "dingtalk_connect_client_id": "", + "dingtalk_connect_client_secret_configured": false, + "dingtalk_connect_redirect_url": "", + "dingtalk_connect_internal_corp_id": "", + "dingtalk_connect_corp_restriction_policy": "", + "dingtalk_connect_sync_corp_email": false, + "dingtalk_connect_sync_corp_email_attr_key": "dingtalk_email", + "dingtalk_connect_sync_corp_email_attr_name": "钉钉企业邮箱", + "dingtalk_connect_sync_dept": false, + "dingtalk_connect_sync_dept_attr_key": "dingtalk_department", + "dingtalk_connect_sync_dept_attr_name": "钉钉部门", + "dingtalk_connect_sync_display_name": false, + "dingtalk_connect_sync_display_name_attr_key": "dingtalk_name", + "dingtalk_connect_sync_display_name_attr_name": "钉钉姓名", "oidc_connect_enabled": false, "oidc_connect_provider_name": "OIDC", "oidc_connect_client_id": "", @@ -744,6 +789,11 @@ func TestAPIContracts(t *testing.T) { "auth_source_default_wechat_subscriptions": [], "auth_source_default_wechat_grant_on_signup": false, "auth_source_default_wechat_grant_on_first_bind": false, + "auth_source_default_dingtalk_balance": 0, + "auth_source_default_dingtalk_concurrency": 5, + "auth_source_default_dingtalk_subscriptions": [], + "auth_source_default_dingtalk_grant_on_signup": false, + "auth_source_default_dingtalk_grant_on_first_bind": false, "force_email_on_third_party_signup": false, "default_concurrency": 5, "default_balance": 1.25, @@ -784,14 +834,7 @@ func TestAPIContracts(t *testing.T) { "payment_visible_method_wxpay_enabled": false, "openai_advanced_scheduler_enabled": true, "openai_fast_policy_settings": { - "rules": [ - { - "service_tier": "priority", - "action": "filter", - "scope": "all", - "fallback_action": "pass" - } - ] + "rules": [] }, "custom_menu_items": [], "custom_endpoints": [], @@ -917,6 +960,22 @@ func TestAPIContracts(t *testing.T) { "linuxdo_connect_client_id": "", "linuxdo_connect_client_secret_configured": false, "linuxdo_connect_redirect_url": "", + "dingtalk_connect_enabled": false, + "dingtalk_connect_bypass_registration": false, + "dingtalk_connect_client_id": "", + "dingtalk_connect_client_secret_configured": false, + "dingtalk_connect_redirect_url": "", + "dingtalk_connect_internal_corp_id": "", + "dingtalk_connect_corp_restriction_policy": "", + "dingtalk_connect_sync_corp_email": false, + "dingtalk_connect_sync_corp_email_attr_key": "dingtalk_email", + "dingtalk_connect_sync_corp_email_attr_name": "钉钉企业邮箱", + "dingtalk_connect_sync_dept": false, + "dingtalk_connect_sync_dept_attr_key": "dingtalk_department", + "dingtalk_connect_sync_dept_attr_name": "钉钉部门", + "dingtalk_connect_sync_display_name": false, + "dingtalk_connect_sync_display_name_attr_key": "dingtalk_name", + "dingtalk_connect_sync_display_name_attr_name": "钉钉姓名", "oidc_connect_enabled": true, "oidc_connect_provider_name": "ConfigOIDC", "oidc_connect_client_id": "oidc-config-client", @@ -999,14 +1058,7 @@ func TestAPIContracts(t *testing.T) { "payment_visible_method_wxpay_enabled": false, "openai_advanced_scheduler_enabled": false, "openai_fast_policy_settings": { - "rules": [ - { - "service_tier": "priority", - "action": "filter", - "scope": "all", - "fallback_action": "pass" - } - ] + "rules": [] }, "payment_enabled": false, "payment_min_amount": 0, @@ -1084,6 +1136,11 @@ func TestAPIContracts(t *testing.T) { "auth_source_default_wechat_subscriptions": [], "auth_source_default_wechat_grant_on_signup": false, "auth_source_default_wechat_grant_on_first_bind": false, + "auth_source_default_dingtalk_balance": 0, + "auth_source_default_dingtalk_concurrency": 5, + "auth_source_default_dingtalk_subscriptions": [], + "auth_source_default_dingtalk_grant_on_signup": false, + "auth_source_default_dingtalk_grant_on_first_bind": false, "force_email_on_third_party_signup": false } }`, @@ -1194,10 +1251,10 @@ func newContractDeps(t *testing.T) *contractDeps { settingService := service.NewSettingService(settingRepo, cfg) adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) - authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) + authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) - adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil, nil) + adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil, nil, nil) adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) jwtAuth := func(c *gin.Context) { diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index 972c1eafa5e..c15f534e372 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -92,6 +92,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti clientIP := ip.GetTrustedClientIP(c) allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist) if !allowed { + service.MarkOpsClientBusinessLimited(c, service.OpsClientBusinessLimitedReasonIPRestriction) AbortWithError(c, 403, "ACCESS_DENIED", "Access denied") return } diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 4a4ab0f9817..d6760d8d099 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -333,6 +333,15 @@ func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) router := gin.New() require.NoError(t, router.SetTrustedProxies(nil)) + var markedBusinessLimited bool + var businessLimitedReason string + router.Use(func(c *gin.Context) { + c.Next() + markedBusinessLimited = service.HasOpsClientBusinessLimited(c) + if v, ok := c.Get(service.OpsClientBusinessLimitedReasonKey); ok { + businessLimitedReason, _ = v.(string) + } + }) router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) router.GET("/t", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"ok": true}) @@ -349,6 +358,8 @@ func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) require.Equal(t, http.StatusForbidden, w.Code) require.Contains(t, w.Body.String(), "ACCESS_DENIED") + require.True(t, markedBusinessLimited) + require.Equal(t, service.OpsClientBusinessLimitedReasonIPRestriction, businessLimitedReason) } func TestAPIKeyAuthTouchesLastUsedOnSuccess(t *testing.T) { diff --git a/backend/internal/server/middleware/backend_mode_guard.go b/backend/internal/server/middleware/backend_mode_guard.go index 157f06b0983..050e3bc669b 100644 --- a/backend/internal/server/middleware/backend_mode_guard.go +++ b/backend/internal/server/middleware/backend_mode_guard.go @@ -42,15 +42,19 @@ func backendModeAllowsAuthPath(path string) bool { "/auth/oauth/oidc/callback", "/auth/oauth/github/callback", "/auth/oauth/google/callback", + "/auth/oauth/dingtalk/callback", "/auth/oauth/linuxdo/complete-registration", "/auth/oauth/wechat/complete-registration", "/auth/oauth/oidc/complete-registration", + "/auth/oauth/dingtalk/complete-registration", "/auth/oauth/linuxdo/create-account", "/auth/oauth/wechat/create-account", "/auth/oauth/oidc/create-account", + "/auth/oauth/dingtalk/create-account", "/auth/oauth/linuxdo/bind-login", "/auth/oauth/wechat/bind-login", "/auth/oauth/oidc/bind-login", + "/auth/oauth/dingtalk/bind-login", } { if strings.HasSuffix(path, suffix) { return true diff --git a/backend/internal/server/middleware/backend_mode_guard_test.go b/backend/internal/server/middleware/backend_mode_guard_test.go index de9c9ec9dbd..df2edde6a87 100644 --- a/backend/internal/server/middleware/backend_mode_guard_test.go +++ b/backend/internal/server/middleware/backend_mode_guard_test.go @@ -270,6 +270,36 @@ func TestBackendModeAuthGuard(t *testing.T) { path: "/api/v1/auth/oauth/google/callback", wantStatus: http.StatusOK, }, + { + name: "enabled_blocks_dingtalk_oauth_start", + enabled: "true", + path: "/api/v1/auth/oauth/dingtalk/start", + wantStatus: http.StatusForbidden, + }, + { + name: "enabled_allows_dingtalk_oauth_callback", + enabled: "true", + path: "/api/v1/auth/oauth/dingtalk/callback", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_dingtalk_complete_registration", + enabled: "true", + path: "/api/v1/auth/oauth/dingtalk/complete-registration", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_dingtalk_create_account", + enabled: "true", + path: "/api/v1/auth/oauth/dingtalk/create-account", + wantStatus: http.StatusOK, + }, + { + name: "enabled_allows_dingtalk_bind_login", + enabled: "true", + path: "/api/v1/auth/oauth/dingtalk/bind-login", + wantStatus: http.StatusOK, + }, { name: "enabled_allows_oauth_pending_exchange", enabled: "true", diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 6e1059bc829..92e2f5b63d4 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -303,6 +303,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable) accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable) accounts.GET("/:id/models", h.Admin.Account.GetAvailableModels) + accounts.POST("/:id/models/sync-upstream", h.Admin.Account.SyncUpstreamModels) accounts.POST("/batch", h.Admin.Account.BatchCreate) accounts.GET("/data", h.Admin.Account.ExportData) accounts.POST("/data", h.Admin.Account.ImportData) diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 54d40e921b3..19d0fd2a774 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -182,6 +182,32 @@ func RegisterAuthRoutes( }), h.Auth.CreateOIDCOAuthAccount, ) + auth.GET("/oauth/dingtalk/start", h.Auth.DingTalkOAuthStart) + auth.GET("/oauth/dingtalk/bind/start", func(c *gin.Context) { + query := c.Request.URL.Query() + query.Set("intent", "bind_current_user") + c.Request.URL.RawQuery = query.Encode() + h.Auth.DingTalkOAuthStart(c) + }) + auth.GET("/oauth/dingtalk/callback", h.Auth.DingTalkOAuthCallback) + auth.POST("/oauth/dingtalk/complete-registration", + rateLimiter.LimitWithOptions("oauth-dingtalk-complete", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CompleteDingTalkOAuthRegistration, + ) + auth.POST("/oauth/dingtalk/bind-login", + rateLimiter.LimitWithOptions("oauth-dingtalk-bind-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.BindDingTalkOAuthLogin, + ) + auth.POST("/oauth/dingtalk/create-account", + rateLimiter.LimitWithOptions("oauth-dingtalk-create-account", 10, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), + h.Auth.CreateDingTalkOAuthAccount, + ) } // 公开设置(无需认证) diff --git a/backend/internal/service/account_credentials_redact.go b/backend/internal/service/account_credentials_redact.go new file mode 100644 index 00000000000..76c2d1de5be --- /dev/null +++ b/backend/internal/service/account_credentials_redact.go @@ -0,0 +1,50 @@ +package service + +// SensitiveCredentialKeys 列出 Account.Credentials JSON map 中绝不允许返回到前端的子键。 +// dto 层做响应脱敏、service 层做更新合并都引用此清单——新增凭证类型时务必同步。 +var SensitiveCredentialKeys = []string{ + // OAuth + "access_token", "refresh_token", "id_token", + // API Key 类 + "api_key", "session_key", "cookie", + // 云服务凭据 + "aws_secret_access_key", "aws_session_token", + "service_account_json", "service_account", "private_key", +} + +var sensitiveCredentialKeySet = func() map[string]struct{} { + m := make(map[string]struct{}, len(SensitiveCredentialKeys)) + for _, k := range SensitiveCredentialKeys { + m[k] = struct{}{} + } + return m +}() + +// IsSensitiveCredentialKey 判断指定键是否为敏感凭证子键。 +func IsSensitiveCredentialKey(key string) bool { + _, ok := sensitiveCredentialKeySet[key] + return ok +} + +// MergePreservingSensitiveCreds 把 incoming 写入 existing 之上,但敏感子键采用"incoming 没提供就保留 existing" +// 的语义。返回新的 map,不修改入参。 +// +// 用途:前端编辑账号通常采用"全对象 PUT"模式;脱敏后前端 spread 旧 credentials 时不会带上敏感键, +// 直接覆盖会清空已有 token。此函数保证: +// - 非敏感键:完全由 incoming 决定(用户可以编辑、删除非敏感字段)。 +// - 敏感键:incoming 显式提供则覆盖(用户主动旋转 token),否则保留 existing。 +func MergePreservingSensitiveCreds(existing, incoming map[string]any) map[string]any { + out := make(map[string]any, len(incoming)+len(SensitiveCredentialKeys)) + for k, v := range incoming { + out[k] = v + } + for _, key := range SensitiveCredentialKeys { + if _, hasIncoming := incoming[key]; hasIncoming { + continue + } + if existingVal, ok := existing[key]; ok { + out[key] = existingVal + } + } + return out +} diff --git a/backend/internal/service/account_credentials_redact_test.go b/backend/internal/service/account_credentials_redact_test.go new file mode 100644 index 00000000000..05f37da9da2 --- /dev/null +++ b/backend/internal/service/account_credentials_redact_test.go @@ -0,0 +1,90 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMergePreservingSensitiveCreds_PreservesSensitiveWhenIncomingMissing(t *testing.T) { + existing := map[string]any{ + "refresh_token": "rt-old", + "access_token": "at-old", + "api_key": "sk-old", + "base_url": "https://old.example.com", + } + incoming := map[string]any{ + "base_url": "https://new.example.com", + "model_mapping": map[string]any{"foo": "bar"}, + } + + out := MergePreservingSensitiveCreds(existing, incoming) + + require.Equal(t, "rt-old", out["refresh_token"], "incoming 没传 refresh_token,应保留 existing") + require.Equal(t, "at-old", out["access_token"]) + require.Equal(t, "sk-old", out["api_key"]) + require.Equal(t, "https://new.example.com", out["base_url"], "非敏感键由 incoming 决定") + require.Equal(t, map[string]any{"foo": "bar"}, out["model_mapping"]) +} + +func TestMergePreservingSensitiveCreds_OverwritesWhenIncomingProvidesSensitive(t *testing.T) { + existing := map[string]any{ + "refresh_token": "rt-old", + "api_key": "sk-old", + } + incoming := map[string]any{ + "refresh_token": "rt-new", + // 显式没传 api_key —— 应保留 + } + out := MergePreservingSensitiveCreds(existing, incoming) + require.Equal(t, "rt-new", out["refresh_token"], "incoming 显式传入应覆盖") + require.Equal(t, "sk-old", out["api_key"], "incoming 没传应保留") +} + +func TestMergePreservingSensitiveCreds_DoesNotMutateInputs(t *testing.T) { + existing := map[string]any{"refresh_token": "rt"} + incoming := map[string]any{"base_url": "x"} + + _ = MergePreservingSensitiveCreds(existing, incoming) + + require.Equal(t, "rt", existing["refresh_token"]) + require.NotContains(t, existing, "base_url") + require.Equal(t, "x", incoming["base_url"]) + require.NotContains(t, incoming, "refresh_token") +} + +func TestMergePreservingSensitiveCreds_NilInputs(t *testing.T) { + out := MergePreservingSensitiveCreds(nil, map[string]any{"base_url": "x"}) + require.Equal(t, "x", out["base_url"]) + require.NotContains(t, out, "refresh_token") + + out2 := MergePreservingSensitiveCreds(map[string]any{"refresh_token": "rt"}, nil) + require.Equal(t, "rt", out2["refresh_token"]) +} + +func TestMergePreservingSensitiveCreds_NonSensitiveDeletionAllowed(t *testing.T) { + existing := map[string]any{ + "refresh_token": "rt", + "base_url": "https://old", + "project_id": "p1", + } + incoming := map[string]any{ + "base_url": "https://new", + // 不带 project_id —— 等同删除(非敏感键由 incoming 决定) + } + out := MergePreservingSensitiveCreds(existing, incoming) + require.Equal(t, "rt", out["refresh_token"], "敏感键保留") + require.Equal(t, "https://new", out["base_url"]) + require.NotContains(t, out, "project_id", "非敏感键 incoming 不传 = 删除") +} + +func TestIsSensitiveCredentialKey(t *testing.T) { + require.True(t, IsSensitiveCredentialKey("refresh_token")) + require.True(t, IsSensitiveCredentialKey("api_key")) + require.True(t, IsSensitiveCredentialKey("private_key")) + require.False(t, IsSensitiveCredentialKey("base_url")) + require.False(t, IsSensitiveCredentialKey("")) + require.False(t, IsSensitiveCredentialKey("model_mapping")) +} diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 68ba8f8ce98..1c871d2ba48 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -295,14 +295,16 @@ func NewAccountUsageService( // OAuth账号: 调用Anthropic API获取真实数据(需要profile scope),API响应缓存10分钟,窗口统计缓存1分钟 // Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope) // API Key账号: 不支持usage查询 -func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) { +func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64, force ...bool) (*UsageInfo, error) { + forceProbe := len(force) > 0 && force[0] + account, err := s.accountRepo.GetByID(ctx, accountID) if err != nil { return nil, fmt.Errorf("get account failed: %w", err) } if account.Platform == PlatformOpenAI && account.Type == AccountTypeOAuth { - usage, err := s.getOpenAIUsage(ctx, account) + usage, err := s.getOpenAIUsage(ctx, account, forceProbe) if err == nil { s.tryClearRecoverableAccountError(ctx, account) } @@ -492,7 +494,7 @@ func (s *AccountUsageService) syncActiveToPassive(ctx context.Context, accountID } } -func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Account) (*UsageInfo, error) { +func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Account, force bool) (*UsageInfo, error) { now := time.Now() usage := &UsageInfo{UpdatedAt: &now} @@ -507,7 +509,7 @@ func (s *AccountUsageService) getOpenAIUsage(ctx context.Context, account *Accou usage.SevenDay = progress } - if shouldRefreshOpenAICodexSnapshot(account, usage, now) && s.shouldProbeOpenAICodexSnapshot(account.ID, now) { + if (force || shouldRefreshOpenAICodexSnapshot(account, usage, now)) && s.shouldProbeOpenAICodexSnapshot(account.ID, now, force) { if updates, err := s.probeOpenAICodexSnapshot(ctx, account); err == nil && len(updates) > 0 { mergeAccountExtra(account, updates) if usage.UpdatedAt == nil { @@ -577,13 +579,16 @@ func isOpenAICodexSnapshotStale(account *Account, now time.Time) bool { return now.Sub(ts) >= openAIProbeCacheTTL } -func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, now time.Time) bool { +func (s *AccountUsageService) shouldProbeOpenAICodexSnapshot(accountID int64, now time.Time, force ...bool) bool { if s == nil || s.cache == nil || accountID <= 0 { return true } - if cached, ok := s.cache.openAIProbeCache.Load(accountID); ok { - if ts, ok := cached.(time.Time); ok && now.Sub(ts) < openAIProbeCacheTTL { - return false + forceProbe := len(force) > 0 && force[0] + if !forceProbe { + if cached, ok := s.cache.openAIProbeCache.Load(accountID); ok { + if ts, ok := cached.(time.Time); ok && now.Sub(ts) < openAIProbeCacheTTL { + return false + } } } s.cache.openAIProbeCache.Store(accountID, now) diff --git a/backend/internal/service/account_usage_service_test.go b/backend/internal/service/account_usage_service_test.go index 28b49838a5b..e0390c4c2a7 100644 --- a/backend/internal/service/account_usage_service_test.go +++ b/backend/internal/service/account_usage_service_test.go @@ -140,7 +140,7 @@ func TestAccountUsageService_GetOpenAIUsage_DoesNotPromoteCodexExtraToRateLimit( }, } - usage, err := svc.getOpenAIUsage(context.Background(), account) + usage, err := svc.getOpenAIUsage(context.Background(), account, false) if err != nil { t.Fatalf("getOpenAIUsage() error = %v", err) } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index eb5994d5498..843af01358b 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -397,6 +397,7 @@ type GenerateRedeemCodesInput struct { Value float64 GroupID *int64 // 订阅类型专用:关联的分组ID ValidityDays int // 订阅类型专用:有效天数 + ExpiresAt *time.Time } type ProxyBatchDeleteResult struct { @@ -1238,7 +1239,7 @@ func (s *adminServiceImpl) BindUserAuthIdentity(ctx context.Context, userID int6 providerKey := strings.TrimSpace(input.ProviderKey) providerSubject := strings.TrimSpace(input.ProviderSubject) if providerType == "" { - return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, or wechat") + return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type must be one of email, linuxdo, oidc, wechat, or dingtalk") } if providerKey == "" || providerSubject == "" { return nil, infraerrors.BadRequest("INVALID_INPUT", "provider_type, provider_key, and provider_subject are required") @@ -1493,6 +1494,8 @@ func normalizeAdminAuthIdentityProviderType(input string) string { return "oidc" case "wechat": return "wechat" + case "dingtalk": + return "dingtalk" default: return "" } @@ -2470,7 +2473,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U account.Notes = normalizeAccountNotes(input.Notes) } if len(input.Credentials) > 0 { - account.Credentials = input.Credentials + // 敏感子键采用"incoming 没提供就保留"的合并语义:前端响应已脱敏, + // 全对象 PUT 编辑时不会再带回 token,避免覆盖时清空已有凭证。 + account.Credentials = MergePreservingSensitiveCreds(account.Credentials, input.Credentials) } // Extra 使用 map:需要区分“未提供(nil)”与“显式清空({})”。 // 关闭配额限制时前端会删除 quota_* 键并提交 extra:{},此时也必须落库。 @@ -2966,6 +2971,10 @@ func (s *adminServiceImpl) GetRedeemCode(ctx context.Context, id int64) (*Redeem } func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *GenerateRedeemCodesInput) ([]RedeemCode, error) { + if input.ExpiresAt != nil && !input.ExpiresAt.After(time.Now()) { + return nil, ErrRedeemCodeExpired + } + // 如果是订阅类型,验证必须有 GroupID if input.Type == RedeemTypeSubscription { if input.GroupID == nil { @@ -2988,10 +2997,11 @@ func (s *adminServiceImpl) GenerateRedeemCodes(ctx context.Context, input *Gener return nil, err } code := RedeemCode{ - Code: codeValue, - Type: input.Type, - Value: input.Value, - Status: StatusUnused, + Code: codeValue, + Type: input.Type, + Value: input.Value, + Status: StatusUnused, + ExpiresAt: input.ExpiresAt, } // 订阅类型专用字段 if input.Type == RedeemTypeSubscription { diff --git a/backend/internal/service/admin_service_credentials_merge_test.go b/backend/internal/service/admin_service_credentials_merge_test.go new file mode 100644 index 00000000000..8250db281cb --- /dev/null +++ b/backend/internal/service/admin_service_credentials_merge_test.go @@ -0,0 +1,117 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +type updateAccountCredsRepoStub struct { + mockAccountRepoForGemini + account *Account + updateCalls int +} + +func (r *updateAccountCredsRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) { + return r.account, nil +} + +func (r *updateAccountCredsRepoStub) Update(ctx context.Context, account *Account) error { + r.updateCalls++ + r.account = account + return nil +} + +func TestUpdateAccount_PreservesSensitiveCredsWhenIncomingOmits(t *testing.T) { + accountID := int64(202) + repo := &updateAccountCredsRepoStub{ + account: &Account{ + ID: accountID, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Status: StatusActive, + Credentials: map[string]any{ + "refresh_token": "rt-existing", + "access_token": "at-existing", + "id_token": "id-existing", + "base_url": "https://old.example.com", + }, + }, + } + svc := &adminServiceImpl{accountRepo: repo} + + // 模拟前端编辑:仅修改 base_url,没有传 token(脱敏后前端 spread 拿不到敏感键) + updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{ + Credentials: map[string]any{ + "base_url": "https://new.example.com", + }, + }) + + require.NoError(t, err) + require.NotNil(t, updated) + require.Equal(t, 1, repo.updateCalls) + + // 敏感键应保留 + require.Equal(t, "rt-existing", repo.account.Credentials["refresh_token"]) + require.Equal(t, "at-existing", repo.account.Credentials["access_token"]) + require.Equal(t, "id-existing", repo.account.Credentials["id_token"]) + // 非敏感键被替换 + require.Equal(t, "https://new.example.com", repo.account.Credentials["base_url"]) +} + +func TestUpdateAccount_ExplicitNewTokenOverwrites(t *testing.T) { + accountID := int64(203) + repo := &updateAccountCredsRepoStub{ + account: &Account{ + ID: accountID, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Status: StatusActive, + Credentials: map[string]any{ + "refresh_token": "rt-old", + "api_key": "sk-old", + }, + }, + } + svc := &adminServiceImpl{accountRepo: repo} + + updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{ + Credentials: map[string]any{ + "refresh_token": "rt-new", + // api_key 没传 → 应保留旧值 + }, + }) + require.NoError(t, err) + require.NotNil(t, updated) + + require.Equal(t, "rt-new", repo.account.Credentials["refresh_token"]) + require.Equal(t, "sk-old", repo.account.Credentials["api_key"]) +} + +func TestUpdateAccount_EmptyCredentialsSkipsUpdate(t *testing.T) { + accountID := int64(204) + repo := &updateAccountCredsRepoStub{ + account: &Account{ + ID: accountID, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Status: StatusActive, + Credentials: map[string]any{ + "refresh_token": "rt-existing", + }, + }, + } + svc := &adminServiceImpl{accountRepo: repo} + + _, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{ + Credentials: map[string]any{}, // len == 0 → 闸门跳过 + Name: "renamed", + }) + require.NoError(t, err) + + require.Equal(t, "rt-existing", repo.account.Credentials["refresh_token"], "空 credentials 不应触碰已有 token") + require.Equal(t, "renamed", repo.account.Name) +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index a76e59fbea9..cfae171da25 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -2094,7 +2094,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co } // 解析请求以获取 image_size(用于图片计费) - imageSize := s.extractImageSize(body) + imageInputSize := s.extractImageInputSize(body) + imageSize := normalizeOpenAIImageSizeTier(imageInputSize) switch action { case "generateContent", "streamGenerateContent": @@ -2465,6 +2466,7 @@ handleSuccess: ClientDisconnect: clientDisconnect, ImageCount: imageCount, ImageSize: imageSize, + ImageInputSize: imageInputSize, }, nil } @@ -4063,21 +4065,17 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context } } -// extractImageSize 从 Gemini 请求中提取 image_size 参数 -func (s *AntigravityGatewayService) extractImageSize(body []byte) string { +func (s *AntigravityGatewayService) extractImageInputSize(body []byte) string { var req antigravity.GeminiRequest if err := json.Unmarshal(body, &req); err != nil { - return "2K" // 默认 2K + return "" } if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil { - size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize)) - if size == "1K" || size == "2K" || size == "4K" { - return size - } + return strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize) } - return "2K" // 默认 2K + return "" } // isImageGenerationModel 判断模型是否为图片生成模型 diff --git a/backend/internal/service/antigravity_image_test.go b/backend/internal/service/antigravity_image_test.go index 7fd2f84301b..76269dd3c33 100644 --- a/backend/internal/service/antigravity_image_test.go +++ b/backend/internal/service/antigravity_image_test.go @@ -46,15 +46,15 @@ func TestExtractImageSize_ValidSizes(t *testing.T) { // 1K body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1K"}}}`) - require.Equal(t, "1K", svc.extractImageSize(body)) + require.Equal(t, "1K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body))) // 2K body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"2K"}}}`) - require.Equal(t, "2K", svc.extractImageSize(body)) + require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body))) // 4K body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4K"}}}`) - require.Equal(t, "4K", svc.extractImageSize(body)) + require.Equal(t, "4K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body))) } // TestExtractImageSize_CaseInsensitive 测试大小写不敏感 @@ -62,10 +62,10 @@ func TestExtractImageSize_CaseInsensitive(t *testing.T) { svc := &AntigravityGatewayService{} body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1k"}}}`) - require.Equal(t, "1K", svc.extractImageSize(body)) + require.Equal(t, "1K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body))) body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4k"}}}`) - require.Equal(t, "4K", svc.extractImageSize(body)) + require.Equal(t, "4K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body))) } // TestExtractImageSize_Default 测试无 imageConfig 返回默认 2K @@ -74,15 +74,15 @@ func TestExtractImageSize_Default(t *testing.T) { // 无 generationConfig body := []byte(`{"contents":[]}`) - require.Equal(t, "2K", svc.extractImageSize(body)) + require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body))) // 有 generationConfig 但无 imageConfig body = []byte(`{"generationConfig":{"temperature":0.7}}`) - require.Equal(t, "2K", svc.extractImageSize(body)) + require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body))) // 有 imageConfig 但无 imageSize body = []byte(`{"generationConfig":{"imageConfig":{}}}`) - require.Equal(t, "2K", svc.extractImageSize(body)) + require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body))) } // TestExtractImageSize_InvalidJSON 测试非法 JSON 返回默认 2K @@ -90,10 +90,10 @@ func TestExtractImageSize_InvalidJSON(t *testing.T) { svc := &AntigravityGatewayService{} body := []byte(`not valid json`) - require.Equal(t, "2K", svc.extractImageSize(body)) + require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body))) body = []byte(`{"broken":`) - require.Equal(t, "2K", svc.extractImageSize(body)) + require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body))) } // TestExtractImageSize_EmptySize 测试空 imageSize 返回默认 2K @@ -101,11 +101,11 @@ func TestExtractImageSize_EmptySize(t *testing.T) { svc := &AntigravityGatewayService{} body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":""}}}`) - require.Equal(t, "2K", svc.extractImageSize(body)) + require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body))) // 空格 body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":" "}}}`) - require.Equal(t, "2K", svc.extractImageSize(body)) + require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body))) } // TestExtractImageSize_InvalidSize 测试无效尺寸返回默认 2K @@ -113,11 +113,11 @@ func TestExtractImageSize_InvalidSize(t *testing.T) { svc := &AntigravityGatewayService{} body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"3K"}}}`) - require.Equal(t, "2K", svc.extractImageSize(body)) + require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body))) body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"8K"}}}`) - require.Equal(t, "2K", svc.extractImageSize(body)) + require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body))) body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"invalid"}}}`) - require.Equal(t, "2K", svc.extractImageSize(body)) + require.Equal(t, "2K", NormalizeImageBillingTierOrDefault(svc.extractImageInputSize(body))) } diff --git a/backend/internal/service/auth_oauth_email_flow.go b/backend/internal/service/auth_oauth_email_flow.go index e3c8298c299..3478fda5045 100644 --- a/backend/internal/service/auth_oauth_email_flow.go +++ b/backend/internal/service/auth_oauth_email_flow.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "log/slog" "net/mail" "strings" "time" @@ -18,7 +19,7 @@ func normalizeOAuthSignupSource(signupSource string) string { switch signupSource { case "", "email": return "email" - case "linuxdo", "wechat", "oidc", "github", "google": + case "linuxdo", "wechat", "oidc", "github", "google", "dingtalk": return signupSource default: return "email" @@ -71,7 +72,7 @@ func (s *AuthService) validateOAuthRegistrationInvitation(ctx context.Context, i if err != nil { return nil, ErrInvitationCodeInvalid } - if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { + if redeemCode.Type != RedeemTypeInvitation || !redeemCode.CanUse() { return nil, ErrInvitationCodeInvalid } return redeemCode, nil @@ -109,7 +110,7 @@ func (s *AuthService) RegisterOAuthEmailAccount( if s == nil { return nil, nil, ErrServiceUnavailable } - if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { + if s.settingService == nil || (!s.settingService.IsRegistrationEnabled(ctx) && !s.canBypassRegistrationDisabledForOAuth(ctx, signupSource)) { return nil, nil, ErrRegDisabled } @@ -118,18 +119,22 @@ func (s *AuthService) RegisterOAuthEmailAccount( return nil, nil, ErrEmailReserved } if err := s.validateRegistrationEmailPolicy(ctx, email); err != nil { + slog.Error("oauth email register: policy rejected", "email", email, "error", err.Error()) return nil, nil, err } if err := s.VerifyOAuthEmailCode(ctx, email, verifyCode); err != nil { + slog.Error("oauth email register: verify code failed", "email", email, "error", err.Error()) return nil, nil, err } if _, err := s.validateOAuthRegistrationInvitation(ctx, invitationCode); err != nil { + slog.Error("oauth email register: invitation failed", "email", email, "error", err.Error()) return nil, nil, err } existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { + slog.Error("oauth email register: ExistsByEmail failed", "email", email, "error", err.Error()) return nil, nil, ErrServiceUnavailable } if existsEmail { @@ -158,6 +163,7 @@ func (s *AuthService) RegisterOAuthEmailAccount( if errors.Is(err, ErrEmailExists) { return nil, nil, ErrEmailExists } + slog.Error("oauth email register: userRepo.Create failed", "email", email, "signup_source", signupSource, "error", err.Error()) return nil, nil, ErrServiceUnavailable } @@ -181,7 +187,7 @@ func (s *AuthService) RegisterVerifiedOAuthEmailAccount( if s == nil { return nil, nil, ErrServiceUnavailable } - if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { + if s.settingService == nil || (!s.settingService.IsRegistrationEnabled(ctx) && !s.canBypassRegistrationDisabledForOAuth(ctx, signupSource)) { return nil, nil, ErrRegDisabled } @@ -358,6 +364,7 @@ func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invit UsedAt: entity.UsedAt, Notes: oauthEmailFlowStringValue(entity.Notes), CreatedAt: entity.CreatedAt, + ExpiresAt: entity.ExpiresAt, GroupID: entity.GroupID, ValidityDays: entity.ValidityDays, }, nil @@ -368,7 +375,11 @@ func (s *AuthService) loadOAuthRegistrationInvitation(ctx context.Context, invit func (s *AuthService) useOAuthRegistrationInvitation(ctx context.Context, invitationID, userID int64) error { if client := s.oauthEmailFlowClient(ctx); client != nil { affected, err := client.RedeemCode.Update(). - Where(redeemcode.IDEQ(invitationID), redeemcode.StatusEQ(StatusUnused)). + Where( + redeemcode.IDEQ(invitationID), + redeemcode.StatusEQ(StatusUnused), + redeemcode.Or(redeemcode.ExpiresAtIsNil(), redeemcode.ExpiresAtGT(time.Now().UTC())), + ). SetStatus(StatusUsed). SetUsedBy(userID). SetUsedAt(time.Now().UTC()). @@ -396,6 +407,11 @@ func (s *AuthService) updateOAuthRegistrationInvitation(ctx context.Context, cod SetStatus(code.Status). SetNotes(code.Notes). SetValidityDays(code.ValidityDays) + if code.ExpiresAt != nil { + update = update.SetExpiresAt(*code.ExpiresAt) + } else { + update = update.ClearExpiresAt() + } if code.UsedBy != nil { update = update.SetUsedBy(*code.UsedBy) } else { diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index e01e8217ab7..ce2b3fa3779 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -157,7 +157,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw return "", nil, ErrInvitationCodeInvalid } // 检查类型和状态 - if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { + if redeemCode.Type != RedeemTypeInvitation || !redeemCode.CanUse() { logger.LegacyPrintf("service.auth", "[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status) return "", nil, ErrInvitationCodeInvalid } @@ -560,11 +560,25 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username return token, user, nil } +// canBypassRegistrationDisabledForOAuth 在钉钉企业模式(internal_only)且 +// dingtalk_connect_bypass_registration=true 时,允许跳过全局 registration_enabled 检查。 +func (s *AuthService) canBypassRegistrationDisabledForOAuth(ctx context.Context, signupSource string) bool { + if signupSource != "dingtalk" { + return false + } + cfg, err := s.settingService.GetDingTalkConnectOAuthConfig(ctx) + if err != nil || !cfg.Enabled || !cfg.BypassRegistration { + return false + } + return cfg.CorpRestrictionPolicy == "internal_only" +} + // LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair。 // 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token。 // invitationCode 仅在邀请码注册模式下新用户注册时使用;已有账号登录时忽略。 // affiliateCode 用于邀请返利绑定,仅在新用户注册时使用。 -func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode string) (*TokenPair, *User, error) { +// signupSource 标识来源渠道("dingtalk"/"linuxdo"/"wechat"/"oidc" 等),仅用于豁免检查。 +func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username, invitationCode, affiliateCode, signupSource string) (*TokenPair, *User, error) { // 检查 refreshTokenCache 是否可用 if s.refreshTokenCache == nil { return nil, nil, errors.New("refresh token cache not configured") @@ -587,7 +601,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema if err != nil { if errors.Is(err, ErrUserNotFound) { // OAuth 首次登录视为注册 - if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { + if s.settingService == nil || (!s.settingService.IsRegistrationEnabled(ctx) && !s.canBypassRegistrationDisabledForOAuth(ctx, signupSource)) { return nil, nil, ErrRegDisabled } @@ -601,7 +615,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema if err != nil { return nil, nil, ErrInvitationCodeInvalid } - if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { + if redeemCode.Type != RedeemTypeInvitation || !redeemCode.CanUse() { return nil, nil, ErrInvitationCodeInvalid } invitationRedeemCode = redeemCode @@ -617,7 +631,11 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return nil, nil, fmt.Errorf("hash password: %w", err) } - signupSource := inferLegacySignupSource(email) + // 优先用 caller 显式传入的 signupSource(如 "dingtalk" / "linuxdo" / "oidc" / "wechat"), + // 否则才按邮箱后缀推断——避免有真实邮箱的 OAuth 用户被推断为 "email" 渠道,导致渠道授权错读。 + if strings.TrimSpace(signupSource) == "" { + signupSource = inferLegacySignupSource(email) + } grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) var defaultRPMLimit int if s.settingService != nil { @@ -779,6 +797,8 @@ func authSourceSignupSettings(defaults *AuthSourceDefaultSettings, signupSource return defaults.GitHub, true case "google": return defaults.Google, true + case "dingtalk": + return defaults.DingTalk, true default: return ProviderDefaultGrantSettings{}, false } @@ -992,6 +1012,8 @@ func (s *AuthService) ensureEmailAuthIdentity(ctx context.Context, user *User, s func inferLegacySignupSource(email string) string { normalized := strings.ToLower(strings.TrimSpace(email)) switch { + case strings.HasSuffix(normalized, DingTalkConnectSyntheticEmailDomain): + return "dingtalk" case strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain): return "linuxdo" case strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain): @@ -1086,7 +1108,8 @@ func isReservedEmail(email string) bool { normalized := strings.ToLower(strings.TrimSpace(email)) return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) || strings.HasSuffix(normalized, OIDCConnectSyntheticEmailDomain) || - strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain) + strings.HasSuffix(normalized, WeChatConnectSyntheticEmailDomain) || + strings.HasSuffix(normalized, DingTalkConnectSyntheticEmailDomain) } // GenerateToken 生成JWT access token diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index acc44a381a8..ece0247415c 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -602,7 +602,7 @@ func TestAuthService_Register_GrantOnSignupMergesSourceOverridesWithGlobalDefaul require.NoError(t, err) require.NotNil(t, user) require.Equal(t, 9.5, user.Balance) - require.Equal(t, 2, user.Concurrency) + require.Equal(t, 5, user.Concurrency) require.Len(t, assigner.calls, 1) require.Equal(t, int64(31), assigner.calls[0].GroupID) require.Equal(t, 5, assigner.calls[0].ValidityDays) @@ -622,7 +622,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_UsesLinuxDoAuthSourceDefa service.defaultSubAssigner = assigner service.refreshTokenCache = &refreshTokenCacheStub{} - tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "") + tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "linuxdo_user", "", "", "linuxdo") require.NoError(t, err) require.NotNil(t, tokenPair) require.NotNil(t, user) @@ -658,7 +658,7 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA service.defaultSubAssigner = assigner service.refreshTokenCache = &refreshTokenCacheStub{} - tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "") + tokenPair, user, err := service.LoginOrRegisterOAuthWithTokenPair(context.Background(), existing.Email, "linuxdo_user", "", "", "linuxdo") require.NoError(t, err) require.NotNil(t, tokenPair) require.Equal(t, existing.ID, user.ID) @@ -667,3 +667,99 @@ func TestAuthService_LoginOrRegisterOAuthWithTokenPair_ExistingUserDoesNotGrantA require.Empty(t, repo.created) require.Empty(t, assigner.calls) } + +// newAuthServiceWithDingTalkCfg 构建一个含完整 DingTalk config 的 AuthService, +// 用于测试 canBypassRegistrationDisabledForOAuth。 +func newAuthServiceWithDingTalkCfg(settings map[string]string, dtCfg config.DingTalkConnectConfig) *AuthService { + cfg := &config.Config{ + JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}, + Default: config.DefaultConfig{UserBalance: 3.5, UserConcurrency: 2}, + DingTalk: dtCfg, + } + settingService := NewSettingService(&settingRepoStub{values: settings}, cfg) + return NewAuthService(nil, nil, nil, nil, cfg, settingService, nil, nil, nil, nil, nil, nil) +} + +// minDingTalkURLs 返回一个包含必填字段的基础 DingTalkConnectConfig(不设 Enabled/BypassRegistration/Policy)。 +func minDingTalkURLs() config.DingTalkConnectConfig { + return config.DingTalkConnectConfig{ + ClientID: "test-client", + ClientSecret: "test-secret", + AuthorizeURL: "https://example.com/oauth2/auth", + TokenURL: "https://example.com/oauth2/token", + UserInfoURL: "https://example.com/oauth2/userinfo", + RedirectURL: "https://example.com/callback", + FrontendRedirectURL: "https://example.com/auth/callback", + DingTalkAppKind: "internal_app", + AppType: "internal", + } +} + +func TestCanBypassRegistrationDisabledForOAuth(t *testing.T) { + cases := []struct { + name string + signupSource string + settings map[string]string + dtCfg config.DingTalkConnectConfig + want bool + }{ + { + name: "non-dingtalk source → false", + signupSource: "linuxdo", + settings: map[string]string{}, + dtCfg: minDingTalkURLs(), + want: false, + }, + { + name: "dingtalk but cfg.Enabled=false → false", + signupSource: "dingtalk", + settings: map[string]string{ + SettingKeyDingTalkConnectEnabled: "false", + SettingKeyDingTalkConnectBypassRegistration: "true", + SettingKeyDingTalkConnectCorpRestrictionPolicy: "internal_only", + }, + dtCfg: minDingTalkURLs(), + want: false, + }, + { + name: "dingtalk enabled but BypassRegistration=false → false", + signupSource: "dingtalk", + settings: map[string]string{ + SettingKeyDingTalkConnectEnabled: "true", + SettingKeyDingTalkConnectBypassRegistration: "false", + SettingKeyDingTalkConnectCorpRestrictionPolicy: "internal_only", + }, + dtCfg: minDingTalkURLs(), + want: false, + }, + { + name: "dingtalk enabled + bypass=true but policy=none → false", + signupSource: "dingtalk", + settings: map[string]string{ + SettingKeyDingTalkConnectEnabled: "true", + SettingKeyDingTalkConnectBypassRegistration: "true", + SettingKeyDingTalkConnectCorpRestrictionPolicy: "none", + }, + dtCfg: minDingTalkURLs(), + want: false, + }, + { + name: "dingtalk enabled + bypass=true + policy=internal_only → true", + signupSource: "dingtalk", + settings: map[string]string{ + SettingKeyDingTalkConnectEnabled: "true", + SettingKeyDingTalkConnectBypassRegistration: "true", + SettingKeyDingTalkConnectCorpRestrictionPolicy: "internal_only", + }, + dtCfg: minDingTalkURLs(), + want: true, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + svc := newAuthServiceWithDingTalkCfg(tc.settings, tc.dtCfg) + got := svc.canBypassRegistrationDisabledForOAuth(context.Background(), tc.signupSource) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/backend/internal/service/auth_service_test.go b/backend/internal/service/auth_service_test.go new file mode 100644 index 00000000000..2aeb620548a --- /dev/null +++ b/backend/internal/service/auth_service_test.go @@ -0,0 +1,13 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsReservedEmail_DingTalkDomain(t *testing.T) { + require.True(t, isReservedEmail("dingtalk-123@dingtalk-connect.invalid")) + require.True(t, isReservedEmail("DINGTALK-456@DINGTALK-CONNECT.INVALID")) // case-insensitive + require.False(t, isReservedEmail("real@dingtalk.com")) +} diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 45025fe6bf6..47975c8cf7a 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -809,6 +809,7 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag if imageCount <= 0 { return &CostBreakdown{} } + imageSize = NormalizeImageBillingTierOrDefault(imageSize) // 获取单价 unitPrice := s.getImageUnitPrice(model, imageSize, groupConfig) diff --git a/backend/internal/service/billing_service_image_test.go b/backend/internal/service/billing_service_image_test.go index 8d3ca987787..0232a2580b2 100644 --- a/backend/internal/service/billing_service_image_test.go +++ b/backend/internal/service/billing_service_image_test.go @@ -48,6 +48,21 @@ func TestCalculateImageCost_GroupCustomPricing(t *testing.T) { require.InDelta(t, 0.30, cost.TotalCost, 0.0001) } +func TestCalculateImageCost_NormalizesInvalidSizeTo2K(t *testing.T) { + svc := &BillingService{} + + price2K := 0.25 + groupConfig := &ImagePriceConfig{Price2K: &price2K} + + for _, imageSize := range []string{"", "auto", "not-a-size"} { + t.Run(imageSize, func(t *testing.T) { + cost := svc.CalculateImageCost("gemini-3-pro-image", imageSize, 2, groupConfig, 1.0) + require.InDelta(t, 0.50, cost.TotalCost, 0.0001) + require.InDelta(t, 0.50, cost.ActualCost, 0.0001) + }) + } +} + // TestCalculateImageCost_4KDoublePrice 测试 4K 默认价格翻倍 func TestCalculateImageCost_4KDoublePrice(t *testing.T) { svc := &BillingService{} diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index 158bf8a31bf..760f688d70e 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -262,10 +262,17 @@ func deepCopyFeaturesConfig(src map[string]any) map[string]any { } // ValidateIntervals 校验区间列表的合法性。 -// 规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens; -// 所有价格字段 >= 0;区间按 MinTokens 排序后无重叠((min, max] 语义); -// 无界区间(MaxTokens=nil)必须是最后一个。间隙允许(回退默认价格)。 -func ValidateIntervals(intervals []PricingInterval) error { +// +// mode 决定区间语义: +// - BillingModeToken(含空值):区间是上下文 token 数分段 (min, max], +// 按 MinTokens 排序后无重叠,无界区间(MaxTokens=nil)必须是最后一个。 +// - BillingModePerRequest / BillingModeImage:区间是按 tier_label +// (1K/2K/4K 等) 分层,匹配走 label 不依赖 min/max,因此跳过区间重叠 +// 与 last-unlimited 校验,仅做单条字段自洽(min/max/价格非负)检查。 +// +// 通用规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens; +// 所有价格字段 >= 0。 +func ValidateIntervals(intervals []PricingInterval, mode BillingMode) error { if len(intervals) == 0 { return nil } @@ -280,6 +287,11 @@ func ValidateIntervals(intervals []PricingInterval) error { return err } } + + // per_request / image 模式按 tier_label 匹配,不做 token 区间重叠校验 + if mode == BillingModePerRequest || mode == BillingModeImage { + return nil + } return validateIntervalOverlap(sorted) } diff --git a/backend/internal/service/channel_available.go b/backend/internal/service/channel_available.go index 815730e3207..d2d24659a1a 100644 --- a/backend/internal/service/channel_available.go +++ b/backend/internal/service/channel_available.go @@ -103,7 +103,11 @@ func (s *ChannelService) ListAvailable(ctx context.Context) ([]AvailableChannel, } // fillGlobalPricingFallback 对未命中渠道定价的支持模型,从全局 LiteLLM 数据合成一份 -// 展示用定价(按 token 计费)。仅用于「可用渠道」展示,不影响真实计费链路。 +// 展示用定价。仅用于「可用渠道」展示,不影响真实计费链路。 +// +// 触发条件: +// 1. Pricing == nil(渠道完全没声明该模型的定价条目) +// 2. Pricing 非 nil 但所有价格字段为空(admin UI 建了条目但没填价格) // // 当 s.pricingService 为 nil(测试场景),跳过回落。 func (s *ChannelService) fillGlobalPricingFallback(models []SupportedModel) { @@ -111,28 +115,72 @@ func (s *ChannelService) fillGlobalPricingFallback(models []SupportedModel) { return } for i := range models { - if models[i].Pricing != nil { + if !pricingNeedsFallback(models[i].Pricing) { continue } lp := s.pricingService.GetModelPricing(models[i].Name) if lp == nil { continue } - models[i].Pricing = synthesizePricingFromLiteLLM(lp) + models[i].Pricing = synthesizePricingFromLiteLLM(lp, models[i].Pricing) + } +} + +// pricingNeedsFallback 判定一个 ChannelModelPricing 是否需要走全局回落。 +// 价格全部缺失(无 flat 字段且无任何带价 interval)即视为未配置。 +func pricingNeedsFallback(p *ChannelModelPricing) bool { + if p == nil { + return true + } + if p.InputPrice != nil || p.OutputPrice != nil || + p.CacheWritePrice != nil || p.CacheReadPrice != nil || + p.ImageOutputPrice != nil || p.PerRequestPrice != nil { + return false + } + for _, iv := range p.Intervals { + if iv.InputPrice != nil || iv.OutputPrice != nil || + iv.CacheWritePrice != nil || iv.CacheReadPrice != nil || + iv.PerRequestPrice != nil { + return false + } } + return true } // synthesizePricingFromLiteLLM 把 LiteLLM 的定价数据转成 ChannelModelPricing 形态, -// 仅用于展示。BillingMode 固定为 token;图片场景的 OutputCostPerImageToken 也归到 -// ImageOutputPrice 字段(与渠道侧"图片输出按 token 计价"语义一致)。 +// 仅用于展示。 +// +// 计费模式优先级: +// 1. 渠道已选 BillingMode(admin 在 UI 里选了 image / per_request 但没填价的场景, +// 按选定模式合成对应字段) +// 2. LiteLLM mode="image_generation" → image +// 3. 默认 token // // LiteLLM 中字段 0 视为未配置,不带入展示。 -func synthesizePricingFromLiteLLM(lp *LiteLLMModelPricing) *ChannelModelPricing { +func synthesizePricingFromLiteLLM(lp *LiteLLMModelPricing, existing *ChannelModelPricing) *ChannelModelPricing { if lp == nil { - return nil + return existing + } + + mode := BillingModeToken + switch { + case existing != nil && existing.BillingMode != "": + mode = existing.BillingMode + case lp.Mode == "image_generation": + mode = BillingModeImage + } + + if mode == BillingModeImage || mode == BillingModePerRequest { + return &ChannelModelPricing{ + BillingMode: mode, + PerRequestPrice: nonZeroPtr(lp.OutputCostPerImage), + ImageOutputPrice: nonZeroPtr(lp.OutputCostPerImageToken), + InputPrice: nonZeroPtr(lp.InputCostPerToken), + OutputPrice: nonZeroPtr(lp.OutputCostPerToken), + } } return &ChannelModelPricing{ - BillingMode: BillingModeToken, + BillingMode: mode, InputPrice: nonZeroPtr(lp.InputCostPerToken), OutputPrice: nonZeroPtr(lp.OutputCostPerToken), CacheWritePrice: nonZeroPtr(lp.CacheCreationInputTokenCost), diff --git a/backend/internal/service/channel_available_test.go b/backend/internal/service/channel_available_test.go index 8be70ceb62e..d59e587ecd5 100644 --- a/backend/internal/service/channel_available_test.go +++ b/backend/internal/service/channel_available_test.go @@ -175,3 +175,137 @@ func TestListAvailable_DefaultsEmptyBillingModelSource(t *testing.T) { require.Equal(t, BillingModelSourceChannelMapped, byName["empty"]) require.Equal(t, BillingModelSourceUpstream, byName["explicit"]) } + +func TestPricingNeedsFallback(t *testing.T) { + tests := []struct { + name string + in *ChannelModelPricing + want bool + }{ + {"nil", nil, true}, + {"empty struct", &ChannelModelPricing{BillingMode: BillingModeToken}, true}, + {"all-empty intervals", &ChannelModelPricing{ + BillingMode: BillingModeImage, + Intervals: []PricingInterval{{TierLabel: "1K"}, {TierLabel: "2K"}}, + }, true}, + {"flat input set", &ChannelModelPricing{InputPrice: testPtrFloat64(3e-6)}, false}, + {"flat per_request set", &ChannelModelPricing{PerRequestPrice: testPtrFloat64(0.04)}, false}, + {"interval with price", &ChannelModelPricing{ + Intervals: []PricingInterval{{TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}}, + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, pricingNeedsFallback(tt.in)) + }) + } +} + +func TestSynthesizePricingFromLiteLLM_TokenMode(t *testing.T) { + lp := &LiteLLMModelPricing{ + Mode: "chat", + InputCostPerToken: 3e-6, + OutputCostPerToken: 1.5e-5, + CacheCreationInputTokenCost: 3.75e-6, + CacheReadInputTokenCost: 3e-7, + } + got := synthesizePricingFromLiteLLM(lp, nil) + require.NotNil(t, got) + require.Equal(t, BillingModeToken, got.BillingMode) + require.NotNil(t, got.InputPrice) + require.InDelta(t, 3e-6, *got.InputPrice, 1e-12) + require.NotNil(t, got.CacheReadPrice) +} + +func TestSynthesizePricingFromLiteLLM_ImageGenerationMode(t *testing.T) { + // LiteLLM mode=image_generation 且渠道未声明模式时,按 image 合成。 + lp := &LiteLLMModelPricing{ + Mode: "image_generation", + OutputCostPerImageToken: 4e-5, + } + got := synthesizePricingFromLiteLLM(lp, nil) + require.NotNil(t, got) + require.Equal(t, BillingModeImage, got.BillingMode) + require.Nil(t, got.PerRequestPrice) + require.NotNil(t, got.ImageOutputPrice) +} + +func TestSynthesizePricingFromLiteLLM_RespectsExistingChannelMode(t *testing.T) { + // admin UI 选了 per_request 但没填价:LiteLLM 数据按 per_request 合成, + // 即便 LiteLLM 标的是 chat 模式也尊重渠道选择。 + lp := &LiteLLMModelPricing{ + Mode: "chat", + InputCostPerToken: 5e-6, + OutputCostPerImage: 0.04, + } + existing := &ChannelModelPricing{BillingMode: BillingModePerRequest} + got := synthesizePricingFromLiteLLM(lp, existing) + require.NotNil(t, got) + require.Equal(t, BillingModePerRequest, got.BillingMode) + require.NotNil(t, got.PerRequestPrice) + require.InDelta(t, 0.04, *got.PerRequestPrice, 1e-12) +} + +func TestFillGlobalPricingFallback_NilPricing(t *testing.T) { + pricingSvc := newStubPricingServiceFromMap(map[string]*LiteLLMModelPricing{ + "claude-opus-4-5": {Mode: "chat", InputCostPerToken: 5e-6}, + }) + svc := &ChannelService{pricingService: pricingSvc} + + models := []SupportedModel{ + {Name: "claude-opus-4-5", Platform: "anthropic"}, + } + svc.fillGlobalPricingFallback(models) + require.NotNil(t, models[0].Pricing) + require.NotNil(t, models[0].Pricing.InputPrice) + require.InDelta(t, 5e-6, *models[0].Pricing.InputPrice, 1e-12) +} + +func TestFillGlobalPricingFallback_EmptyPricingFillsFromLiteLLM(t *testing.T) { + // 核心场景:admin UI 建了 pricing 条目(image 模式)但没填价,应走 LiteLLM 兜底。 + pricingSvc := newStubPricingServiceFromMap(map[string]*LiteLLMModelPricing{ + "gpt-image-1": { + Mode: "image_generation", + OutputCostPerImageToken: 4e-5, + }, + }) + svc := &ChannelService{pricingService: pricingSvc} + + models := []SupportedModel{ + { + Name: "gpt-image-1", + Platform: "openai", + Pricing: &ChannelModelPricing{ + BillingMode: BillingModeImage, + Intervals: []PricingInterval{{TierLabel: "1K"}, {TierLabel: "2K"}}, + }, + }, + } + svc.fillGlobalPricingFallback(models) + require.NotNil(t, models[0].Pricing) + require.Equal(t, BillingModeImage, models[0].Pricing.BillingMode) + require.NotNil(t, models[0].Pricing.ImageOutputPrice) + require.InDelta(t, 4e-5, *models[0].Pricing.ImageOutputPrice, 1e-12) +} + +func TestFillGlobalPricingFallback_KeepsExistingPrice(t *testing.T) { + // 渠道已经填了价格的条目不应被回落覆盖。 + pricingSvc := newStubPricingServiceFromMap(map[string]*LiteLLMModelPricing{ + "served-model": {Mode: "chat", InputCostPerToken: 1e-6}, + }) + svc := &ChannelService{pricingService: pricingSvc} + + existing := &ChannelModelPricing{ + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(9e-9), + } + models := []SupportedModel{ + {Name: "served-model", Platform: "anthropic", Pricing: existing}, + } + svc.fillGlobalPricingFallback(models) + require.Same(t, existing, models[0].Pricing) +} + +func newStubPricingServiceFromMap(data map[string]*LiteLLMModelPricing) *PricingService { + return &PricingService{pricingData: data} +} diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index 4e08df4a569..4bf0147f38c 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -951,7 +951,7 @@ func validateNoConflictingMappings(mapping map[string]map[string]string) error { func validatePricingIntervals(pricingList []ChannelModelPricing) error { for _, pricing := range pricingList { - if err := ValidateIntervals(pricing.Intervals); err != nil { + if err := ValidateIntervals(pricing.Intervals, pricing.BillingMode); err != nil { return infraerrors.BadRequest( "INVALID_PRICING_INTERVALS", fmt.Sprintf("invalid pricing intervals for platform '%s' models %v: %v", diff --git a/backend/internal/service/channel_test.go b/backend/internal/service/channel_test.go index 164861fb93d..2f371f8a1c8 100644 --- a/backend/internal/service/channel_test.go +++ b/backend/internal/service/channel_test.go @@ -311,8 +311,8 @@ func TestChannelClone_EdgeCases(t *testing.T) { // --- ValidateIntervals --- func TestValidateIntervals_Empty(t *testing.T) { - require.NoError(t, ValidateIntervals(nil)) - require.NoError(t, ValidateIntervals([]PricingInterval{})) + require.NoError(t, ValidateIntervals(nil, BillingModeToken)) + require.NoError(t, ValidateIntervals([]PricingInterval{}, BillingModeToken)) } func TestValidateIntervals_ValidIntervals(t *testing.T) { @@ -357,7 +357,7 @@ func TestValidateIntervals_ValidIntervals(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - require.NoError(t, ValidateIntervals(tt.intervals)) + require.NoError(t, ValidateIntervals(tt.intervals, BillingModeToken)) }) } } @@ -366,7 +366,7 @@ func TestValidateIntervals_NegativeMinTokens(t *testing.T) { intervals := []PricingInterval{ {MinTokens: -1, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)}, } - err := ValidateIntervals(intervals) + err := ValidateIntervals(intervals, BillingModeToken) require.Error(t, err) require.Contains(t, err.Error(), "min_tokens") require.Contains(t, err.Error(), ">= 0") @@ -376,7 +376,7 @@ func TestValidateIntervals_MaxTokensZero(t *testing.T) { intervals := []PricingInterval{ {MinTokens: 0, MaxTokens: testPtrInt(0), InputPrice: testPtrFloat64(1e-6)}, } - err := ValidateIntervals(intervals) + err := ValidateIntervals(intervals, BillingModeToken) require.Error(t, err) require.Contains(t, err.Error(), "max_tokens") require.Contains(t, err.Error(), "> 0") @@ -386,7 +386,7 @@ func TestValidateIntervals_MaxLessThanMin(t *testing.T) { intervals := []PricingInterval{ {MinTokens: 100, MaxTokens: testPtrInt(50), InputPrice: testPtrFloat64(1e-6)}, } - err := ValidateIntervals(intervals) + err := ValidateIntervals(intervals, BillingModeToken) require.Error(t, err) require.Contains(t, err.Error(), "max_tokens") require.Contains(t, err.Error(), "> min_tokens") @@ -396,7 +396,7 @@ func TestValidateIntervals_MaxEqualsMin(t *testing.T) { intervals := []PricingInterval{ {MinTokens: 100, MaxTokens: testPtrInt(100), InputPrice: testPtrFloat64(1e-6)}, } - err := ValidateIntervals(intervals) + err := ValidateIntervals(intervals, BillingModeToken) require.Error(t, err) require.Contains(t, err.Error(), "max_tokens") require.Contains(t, err.Error(), "> min_tokens") @@ -407,7 +407,7 @@ func TestValidateIntervals_NegativePrice(t *testing.T) { intervals := []PricingInterval{ {MinTokens: 0, MaxTokens: testPtrInt(100), InputPrice: &negPrice}, } - err := ValidateIntervals(intervals) + err := ValidateIntervals(intervals, BillingModeToken) require.Error(t, err) require.Contains(t, err.Error(), "input_price") require.Contains(t, err.Error(), ">= 0") @@ -418,7 +418,7 @@ func TestValidateIntervals_OverlappingIntervals(t *testing.T) { {MinTokens: 0, MaxTokens: testPtrInt(200), InputPrice: testPtrFloat64(1e-6)}, {MinTokens: 100, MaxTokens: testPtrInt(300), InputPrice: testPtrFloat64(2e-6)}, } - err := ValidateIntervals(intervals) + err := ValidateIntervals(intervals, BillingModeToken) require.Error(t, err) require.Contains(t, err.Error(), "overlap") } @@ -428,12 +428,43 @@ func TestValidateIntervals_UnboundedNotLast(t *testing.T) { {MinTokens: 0, MaxTokens: nil, InputPrice: testPtrFloat64(1e-6)}, {MinTokens: 128000, MaxTokens: testPtrInt(256000), InputPrice: testPtrFloat64(2e-6)}, } - err := ValidateIntervals(intervals) + err := ValidateIntervals(intervals, BillingModeToken) require.Error(t, err) require.Contains(t, err.Error(), "unbounded") require.Contains(t, err.Error(), "last") } +func TestValidateIntervals_ImageModeAllowsMultipleUnboundedTiers(t *testing.T) { + // image / per_request 按 tier_label 匹配,多条 min=0/max=nil 是合法形态。 + intervals := []PricingInterval{ + {MinTokens: 0, MaxTokens: nil, TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}, + {MinTokens: 0, MaxTokens: nil, TierLabel: "2K", PerRequestPrice: testPtrFloat64(0.06)}, + {MinTokens: 0, MaxTokens: nil, TierLabel: "4K", PerRequestPrice: testPtrFloat64(0.08)}, + } + require.NoError(t, ValidateIntervals(intervals, BillingModeImage)) + require.NoError(t, ValidateIntervals(intervals, BillingModePerRequest)) +} + +func TestValidateIntervals_ImageModeStillRejectsNegativePrice(t *testing.T) { + // image 模式只跳过区间重叠校验,单条字段自洽(价格非负)仍要校验。 + intervals := []PricingInterval{ + {MinTokens: 0, MaxTokens: nil, TierLabel: "1K", PerRequestPrice: testPtrFloat64(-1)}, + } + err := ValidateIntervals(intervals, BillingModeImage) + require.Error(t, err) + require.Contains(t, err.Error(), "must be >= 0") +} + +func TestValidateIntervals_ImageModeStillRejectsBadMaxTokens(t *testing.T) { + // image 模式仍校验 max <= min 这种单条不合法。 + intervals := []PricingInterval{ + {MinTokens: 100, MaxTokens: testPtrInt(50), TierLabel: "1K", PerRequestPrice: testPtrFloat64(0.04)}, + } + err := ValidateIntervals(intervals, BillingModeImage) + require.Error(t, err) + require.Contains(t, err.Error(), "must be > min_tokens") +} + func TestSupportedModels_ExactKeysAndPricing(t *testing.T) { ch := &Channel{ ModelPricing: []ChannelModelPricing{ diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 17c40ba1dc3..f39c5d7e617 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -92,6 +92,9 @@ const OIDCConnectSyntheticEmailDomain = "@oidc-connect.invalid" // WeChatConnectSyntheticEmailDomain 是 WeChat Connect 用户的合成邮箱后缀(RFC 保留域名)。 const WeChatConnectSyntheticEmailDomain = "@wechat-connect.invalid" +// DingTalkConnectSyntheticEmailDomain 是 DingTalk Connect 用户的合成邮箱后缀(RFC 保留域名)。 +const DingTalkConnectSyntheticEmailDomain = "@dingtalk-connect.invalid" + // Setting keys const ( // 注册设置 @@ -137,6 +140,24 @@ const ( SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret" SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" + // DingTalk Connect OAuth 登录设置 + SettingKeyDingTalkConnectEnabled = "dingtalk_connect_enabled" + SettingKeyDingTalkConnectClientID = "dingtalk_connect_client_id" + SettingKeyDingTalkConnectClientSecret = "dingtalk_connect_client_secret" + SettingKeyDingTalkConnectRedirectURL = "dingtalk_connect_redirect_url" + SettingKeyDingTalkConnectCorpRestrictionPolicy = "dingtalk_connect_corp_restriction_policy" + SettingKeyDingTalkConnectInternalCorpID = "dingtalk_connect_internal_corp_id" + SettingKeyDingTalkConnectBypassRegistration = "dingtalk_connect_bypass_registration" + SettingKeyDingTalkConnectSyncCorpEmail = "dingtalk_connect_sync_corp_email" + SettingKeyDingTalkConnectSyncDisplayName = "dingtalk_connect_sync_display_name" + SettingKeyDingTalkConnectSyncDept = "dingtalk_connect_sync_dept" + SettingKeyDingTalkConnectSyncCorpEmailAttrKey = "dingtalk_connect_sync_corp_email_attr_key" + SettingKeyDingTalkConnectSyncDisplayNameAttrKey = "dingtalk_connect_sync_display_name_attr_key" + SettingKeyDingTalkConnectSyncDeptAttrKey = "dingtalk_connect_sync_dept_attr_key" + SettingKeyDingTalkConnectSyncCorpEmailAttrName = "dingtalk_connect_sync_corp_email_attr_name" + SettingKeyDingTalkConnectSyncDisplayNameAttrName = "dingtalk_connect_sync_display_name_attr_name" + SettingKeyDingTalkConnectSyncDeptAttrName = "dingtalk_connect_sync_dept_attr_name" + // WeChat Connect OAuth 登录设置 SettingKeyWeChatConnectEnabled = "wechat_connect_enabled" SettingKeyWeChatConnectAppID = "wechat_connect_app_id" @@ -214,37 +235,42 @@ const ( SettingKeyDefaultUserRPMLimit = "default_user_rpm_limit" // 新用户默认 RPM 限制(0 = 不限制) // 第三方认证来源默认授予配置 - SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance" - SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency" - SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions" - SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup" - SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind" - SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance" - SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency" - SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions" - SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup" - SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind" - SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance" - SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency" - SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions" - SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup" - SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind" - SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance" - SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency" - SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions" - SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup" - SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind" - SettingKeyAuthSourceDefaultGitHubBalance = "auth_source_default_github_balance" - SettingKeyAuthSourceDefaultGitHubConcurrency = "auth_source_default_github_concurrency" - SettingKeyAuthSourceDefaultGitHubSubscriptions = "auth_source_default_github_subscriptions" - SettingKeyAuthSourceDefaultGitHubGrantOnSignup = "auth_source_default_github_grant_on_signup" - SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind = "auth_source_default_github_grant_on_first_bind" - SettingKeyAuthSourceDefaultGoogleBalance = "auth_source_default_google_balance" - SettingKeyAuthSourceDefaultGoogleConcurrency = "auth_source_default_google_concurrency" - SettingKeyAuthSourceDefaultGoogleSubscriptions = "auth_source_default_google_subscriptions" - SettingKeyAuthSourceDefaultGoogleGrantOnSignup = "auth_source_default_google_grant_on_signup" - SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind = "auth_source_default_google_grant_on_first_bind" - SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup" + SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance" + SettingKeyAuthSourceDefaultEmailConcurrency = "auth_source_default_email_concurrency" + SettingKeyAuthSourceDefaultEmailSubscriptions = "auth_source_default_email_subscriptions" + SettingKeyAuthSourceDefaultEmailGrantOnSignup = "auth_source_default_email_grant_on_signup" + SettingKeyAuthSourceDefaultEmailGrantOnFirstBind = "auth_source_default_email_grant_on_first_bind" + SettingKeyAuthSourceDefaultLinuxDoBalance = "auth_source_default_linuxdo_balance" + SettingKeyAuthSourceDefaultLinuxDoConcurrency = "auth_source_default_linuxdo_concurrency" + SettingKeyAuthSourceDefaultLinuxDoSubscriptions = "auth_source_default_linuxdo_subscriptions" + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup = "auth_source_default_linuxdo_grant_on_signup" + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind = "auth_source_default_linuxdo_grant_on_first_bind" + SettingKeyAuthSourceDefaultOIDCBalance = "auth_source_default_oidc_balance" + SettingKeyAuthSourceDefaultOIDCConcurrency = "auth_source_default_oidc_concurrency" + SettingKeyAuthSourceDefaultOIDCSubscriptions = "auth_source_default_oidc_subscriptions" + SettingKeyAuthSourceDefaultOIDCGrantOnSignup = "auth_source_default_oidc_grant_on_signup" + SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind = "auth_source_default_oidc_grant_on_first_bind" + SettingKeyAuthSourceDefaultWeChatBalance = "auth_source_default_wechat_balance" + SettingKeyAuthSourceDefaultWeChatConcurrency = "auth_source_default_wechat_concurrency" + SettingKeyAuthSourceDefaultWeChatSubscriptions = "auth_source_default_wechat_subscriptions" + SettingKeyAuthSourceDefaultWeChatGrantOnSignup = "auth_source_default_wechat_grant_on_signup" + SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind = "auth_source_default_wechat_grant_on_first_bind" + SettingKeyAuthSourceDefaultGitHubBalance = "auth_source_default_github_balance" + SettingKeyAuthSourceDefaultGitHubConcurrency = "auth_source_default_github_concurrency" + SettingKeyAuthSourceDefaultGitHubSubscriptions = "auth_source_default_github_subscriptions" + SettingKeyAuthSourceDefaultGitHubGrantOnSignup = "auth_source_default_github_grant_on_signup" + SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind = "auth_source_default_github_grant_on_first_bind" + SettingKeyAuthSourceDefaultGoogleBalance = "auth_source_default_google_balance" + SettingKeyAuthSourceDefaultGoogleConcurrency = "auth_source_default_google_concurrency" + SettingKeyAuthSourceDefaultGoogleSubscriptions = "auth_source_default_google_subscriptions" + SettingKeyAuthSourceDefaultGoogleGrantOnSignup = "auth_source_default_google_grant_on_signup" + SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind = "auth_source_default_google_grant_on_first_bind" + SettingKeyAuthSourceDefaultDingTalkBalance = "auth_source_default_dingtalk_balance" + SettingKeyAuthSourceDefaultDingTalkConcurrency = "auth_source_default_dingtalk_concurrency" + SettingKeyAuthSourceDefaultDingTalkSubscriptions = "auth_source_default_dingtalk_subscriptions" + SettingKeyAuthSourceDefaultDingTalkGrantOnSignup = "auth_source_default_dingtalk_grant_on_signup" + SettingKeyAuthSourceDefaultDingTalkGrantOnFirstBind = "auth_source_default_dingtalk_grant_on_first_bind" + SettingKeyForceEmailOnThirdPartySignup = "force_email_on_third_party_signup" // 管理员 API Key SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) diff --git a/backend/internal/service/gateway_record_usage_test.go b/backend/internal/service/gateway_record_usage_test.go index 140bdc67e29..09b07a5e67e 100644 --- a/backend/internal/service/gateway_record_usage_test.go +++ b/backend/internal/service/gateway_record_usage_test.go @@ -192,6 +192,46 @@ func TestGatewayServiceRecordUsage_PreservesRequestedAndUpstreamModels(t *testin require.Equal(t, mappedModel, *usageRepo.lastLog.UpstreamModel) } +func TestGatewayServiceRecordUsage_EmptyImageSizeDefaultsBeforeBillingAndPersistence(t *testing.T) { + imagePrice2K := 0.19 + groupID := int64(901) + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + svc := newGatewayRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}) + + err := svc.RecordUsage(context.Background(), &RecordUsageInput{ + Result: &ForwardResult{ + RequestID: "gateway_image_default_size", + Model: "gemini-image", + ImageCount: 1, + ImageInputSize: "auto", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 801, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: 1.0, + ImagePrice2K: &imagePrice2K, + }, + }, + User: &User{ID: 601}, + Account: &Account{ID: 701}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, 1, usageRepo.lastLog.ImageCount) + require.NotNil(t, usageRepo.lastLog.ImageSize) + require.Equal(t, ImageBillingSize2K, *usageRepo.lastLog.ImageSize) + require.NotNil(t, usageRepo.lastLog.ImageInputSize) + require.Equal(t, "auto", *usageRepo.lastLog.ImageInputSize) + require.NotNil(t, usageRepo.lastLog.ImageSizeSource) + require.Equal(t, ImageSizeSourceDefault, *usageRepo.lastLog.ImageSizeSource) + require.InDelta(t, 0.19, usageRepo.lastLog.TotalCost, 1e-12) + require.InDelta(t, 0.19, usageRepo.lastLog.ActualCost, 1e-12) +} + func TestGatewayServiceRecordUsage_UsageLogWriteErrorDoesNotSkipBilling(t *testing.T) { usageRepo := &openAIRecordUsageLogRepoStub{inserted: false, err: MarkUsageLogCreateNotPersisted(context.Canceled)} userRepo := &openAIRecordUsageUserRepoStub{} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index b8cbf715881..8180e321efa 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -501,8 +501,13 @@ type ForwardResult struct { ReasoningEffort *string // 图片生成计费字段(图片生成模型使用) - ImageCount int // 生成的图片数量 - ImageSize string // 图片尺寸 "1K", "2K", "4K" + ImageCount int // 生成的图片数量 + ImageSize string // 最终计费尺寸 "1K", "2K", "4K" + ImageInputSize string // 请求中的原始图片尺寸 + ImageOutputSize string // 上游响应中的图片尺寸 + ImageOutputSizes []string + ImageSizeSource string + ImageSizeBreakdown map[string]int } // UpstreamFailoverError indicates an upstream error that should trigger account failover. @@ -1397,7 +1402,6 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context } // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. -// 调度流程文档见 docs/ACCOUNT_SCHEDULING_FLOW.md 。 // metadataUserID: 用于客户端亲和调度,从中提取客户端 ID // sub2apiUserID: 系统用户 ID,用于二维亲和调度 func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) { @@ -8404,6 +8408,7 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage user := input.User account := input.Account subscription := input.Subscription + ApplyForwardImageBillingResolution(result) // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens // 用于粘性会话切换时的特殊计费处理 @@ -8549,6 +8554,7 @@ func (s *GatewayService) calculateImageCost( billingModel string, multiplier float64, ) *CostBreakdown { + sizeTier := NormalizeImageBillingTierOrDefault(result.ImageSize) if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil { tokens := UsageTokens{ InputTokens: result.Usage.InputTokens, @@ -8562,7 +8568,7 @@ func (s *GatewayService) calculateImageCost( GroupID: &gid, Tokens: tokens, RequestCount: result.ImageCount, - SizeTier: result.ImageSize, + SizeTier: sizeTier, RateMultiplier: multiplier, Resolver: s.resolver, Resolved: resolved, @@ -8582,7 +8588,7 @@ func (s *GatewayService) calculateImageCost( Price4K: apiKey.Group.ImagePrice4K, } } - return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) + return s.billingService.CalculateImageCost(billingModel, sizeTier, result.ImageCount, groupConfig, multiplier) } // calculateTokenCost 计算 Token 计费:根据 opts 决定走普通/长上下文/渠道统一计费。 @@ -8683,6 +8689,10 @@ func (s *GatewayService) buildRecordUsageLog( FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: optionalTrimmedStringPtr(result.ImageSize), + ImageInputSize: optionalTrimmedStringPtr(result.ImageInputSize), + ImageOutputSize: optionalTrimmedStringPtr(result.ImageOutputSize), + ImageSizeSource: optionalTrimmedStringPtr(result.ImageSizeSource), + ImageSizeBreakdown: result.ImageSizeBreakdown, CacheTTLOverridden: cacheTTLOverridden, ChannelID: optionalInt64Ptr(input.ChannelID), ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain), diff --git a/backend/internal/service/gemini_chat_completions_compat_service.go b/backend/internal/service/gemini_chat_completions_compat_service.go new file mode 100644 index 00000000000..5ea02df5297 --- /dev/null +++ b/backend/internal/service/gemini_chat_completions_compat_service.go @@ -0,0 +1,890 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" +) + +// ForwardAsChatCompletions serves OpenAI Chat Completions clients through +// Gemini accounts. It keeps the client-facing response in Chat Completions +// format while routing the upstream call through Gemini native endpoints. +func (s *GeminiMessagesCompatService) ForwardAsChatCompletions( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, +) (*ForwardResult, error) { + startTime := time.Now() + + var ccReq apicompat.ChatCompletionsRequest + if err := json.Unmarshal(body, &ccReq); err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + } + if strings.TrimSpace(ccReq.Model) == "" { + return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required") + } + + originalModel := ccReq.Model + clientStream := ccReq.Stream + includeUsage := ccReq.StreamOptions != nil && ccReq.StreamOptions.IncludeUsage + + responsesReq, err := apicompat.ChatCompletionsToResponses(&ccReq) + if err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + } + + anthropicReq, err := apicompat.ResponsesToAnthropicRequest(responsesReq) + if err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + } + anthropicReq.Stream = clientStream + + claudeBody, err := json.Marshal(anthropicReq) + if err != nil { + return nil, fmt.Errorf("marshal chat completions compat request: %w", err) + } + + return s.forwardClaudeBodyAsChatCompletions(ctx, c, account, claudeBody, originalModel, clientStream, includeUsage, startTime, body) +} + +func (s *GeminiMessagesCompatService) forwardClaudeBodyAsChatCompletions( + ctx context.Context, + c *gin.Context, + account *Account, + claudeBody []byte, + originalModel string, + clientStream bool, + includeUsage bool, + startTime time.Time, + originalChatBody []byte, +) (*ForwardResult, error) { + var req struct { + Model string `json:"model"` + Stream bool `json:"stream"` + } + if err := json.Unmarshal(claudeBody, &req); err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + } + if strings.TrimSpace(req.Model) == "" { + return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", "model is required") + } + + mappedModel := req.Model + if account.Type == AccountTypeAPIKey || account.Type == AccountTypeServiceAccount { + mappedModel = account.GetMappedModel(req.Model) + } + + geminiReq, err := convertClaudeMessagesToGeminiGenerateContent(claudeBody) + if err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + } + geminiReq = ensureGeminiFunctionCallThoughtSignatures(geminiReq) + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + useUpstreamStream := clientStream + if account.Type == AccountTypeOAuth && !clientStream && strings.TrimSpace(account.GetCredential("project_id")) != "" { + useUpstreamStream = true + } + + buildReq, requestIDHeader := s.buildGeminiChatCompletionsUpstreamRequestFunc( + account, + mappedModel, + geminiReq, + clientStream, + useUpstreamStream, + ) + + var resp *http.Response + for attempt := 1; attempt <= geminiMaxRetries; attempt++ { + upstreamReq, idHeader, err := buildReq(ctx) + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return nil, err + } + return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", err.Error()) + } + requestIDHeader = idHeader + + if c != nil { + c.Set(OpsUpstreamRequestBodyKey, string(geminiReq)) + } + + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + if attempt < geminiMaxRetries { + logger.LegacyPrintf("service.gemini_chat_completions", "Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) + sleepGeminiBackoff(attempt) + continue + } + setOpsUpstreamError(c, 0, safeErr, "") + return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries: "+safeErr) + } + + if matched, rebuilt := s.checkErrorPolicyInLoop(ctx, account, resp); matched { + resp = rebuilt + break + } else { + resp = rebuilt + } + + if resp.StatusCode >= 400 && s.shouldRetryGeminiUpstreamError(account, resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + if resp.StatusCode == http.StatusForbidden && isGeminiInsufficientScope(resp.Header, respBody) { + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break + } + if resp.StatusCode == http.StatusTooManyRequests { + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + if attempt < geminiMaxRetries { + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") + } + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: upstreamReqID, + Kind: "retry", + Message: upstreamMsg, + }) + logger.LegacyPrintf("service.gemini_chat_completions", "Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) + sleepGeminiBackoff(attempt) + continue + } + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break + } + + break + } + defer func() { _ = resp.Body.Close() }() + + requestID := resp.Header.Get(requestIDHeader) + if requestID == "" { + requestID = resp.Header.Get("x-goog-request-id") + } + if requestID != "" { + c.Header("x-request-id", requestID) + } + + reasoningEffort := extractCCReasoningEffortFromBody(originalChatBody) + + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + evBody := unwrapIfNeeded(account.Type == AccountTypeOAuth, respBody) + + if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { + upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(evBody))) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "failover", + Message: upstreamMsg, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody} + } + + return nil, s.writeGeminiChatCompletionsMappedError(c, account, resp.StatusCode, requestID, evBody) + } + + var usage *ClaudeUsage + var firstTokenMs *int + if clientStream { + streamRes, err := s.handleChatCompletionsStreamingResponseFromGemini(c, resp, startTime, originalModel, account.Type == AccountTypeOAuth, includeUsage) + if err != nil { + return nil, err + } + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + } else if useUpstreamStream { + collected, usageObj, err := collectGeminiSSE(resp.Body, account.Type == AccountTypeOAuth) + if err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream stream") + } + collectedBytes, _ := json.Marshal(collected) + chatResp, usageObj2, err := geminiResponseToChatCompletions(collected, originalModel, collectedBytes, usageObj) + if err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + } + c.JSON(http.StatusOK, chatResp) + usage = usageObj2 + } else { + usageResp, err := s.handleChatCompletionsNonStreamingResponseFromGemini(c, resp, originalModel, account.Type == AccountTypeOAuth) + if err != nil { + return nil, err + } + usage = usageResp + } + + if usage == nil { + usage = &ClaudeUsage{} + } + + imageCount := 0 + imageInputSize := s.extractImageInputSize(claudeBody) + imageSize := normalizeOpenAIImageSizeTier(imageInputSize) + if isImageGenerationModel(originalModel) { + imageCount = 1 + } + + return &ForwardResult{ + RequestID: requestID, + Usage: *usage, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ReasoningEffort: reasoningEffort, + ImageCount: imageCount, + ImageSize: imageSize, + ImageInputSize: imageInputSize, + ClientDisconnect: false, + }, nil +} + +func (s *GeminiMessagesCompatService) buildGeminiChatCompletionsUpstreamRequestFunc( + account *Account, + mappedModel string, + geminiReq []byte, + clientStream bool, + useUpstreamStream bool, +) (func(context.Context) (*http.Request, string, error), string) { + switch account.Type { + case AccountTypeAPIKey: + return func(ctx context.Context) (*http.Request, string, error) { + apiKey := account.GetCredential("api_key") + if strings.TrimSpace(apiKey) == "" { + return nil, "", errors.New("gemini api_key not configured") + } + + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, "", err + } + + action := "generateContent" + if clientStream { + action = "streamGenerateContent" + } + fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action) + if clientStream { + fullURL += "?alt=sse" + } + + restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq) + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("x-goog-api-key", apiKey) + return upstreamReq, "x-request-id", nil + }, "x-request-id" + + case AccountTypeOAuth: + return func(ctx context.Context) (*http.Request, string, error) { + if s.tokenProvider == nil { + return nil, "", errors.New("gemini token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, "", err + } + + projectID := strings.TrimSpace(account.GetCredential("project_id")) + action := "generateContent" + if useUpstreamStream { + action = "streamGenerateContent" + } + + if projectID != "" { + baseURL, err := s.validateUpstreamBaseURL(geminicli.GeminiCliBaseURL) + if err != nil { + return nil, "", err + } + fullURL := fmt.Sprintf("%s/v1internal:%s", strings.TrimRight(baseURL, "/"), action) + if useUpstreamStream { + fullURL += "?alt=sse" + } + + var inner any + if err := json.Unmarshal(geminiReq, &inner); err != nil { + return nil, "", fmt.Errorf("failed to parse gemini request: %w", err) + } + wrappedBytes, _ := json.Marshal(map[string]any{ + "model": mappedModel, + "project": projectID, + "request": inner, + }) + + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(wrappedBytes)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + upstreamReq.Header.Set("User-Agent", geminicli.GeminiCLIUserAgent) + return upstreamReq, "x-request-id", nil + } + + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, "", err + } + + fullURL := fmt.Sprintf("%s/v1beta/models/%s:%s", strings.TrimRight(normalizedBaseURL, "/"), mappedModel, action) + if useUpstreamStream { + fullURL += "?alt=sse" + } + + restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq) + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + return upstreamReq, "x-request-id", nil + }, "x-request-id" + + case AccountTypeServiceAccount: + return func(ctx context.Context) (*http.Request, string, error) { + if s.tokenProvider == nil { + return nil, "", errors.New("gemini token provider not configured") + } + accessToken, err := s.tokenProvider.GetAccessToken(ctx, account) + if err != nil { + return nil, "", err + } + + action := "generateContent" + if clientStream { + action = "streamGenerateContent" + } + fullURL, err := buildVertexGeminiURL(account.VertexProjectID(), account.VertexLocation(mappedModel), mappedModel, action, clientStream) + if err != nil { + return nil, "", err + } + + restGeminiReq := normalizeGeminiRequestForAIStudio(geminiReq) + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewReader(restGeminiReq)) + if err != nil { + return nil, "", err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+accessToken) + return upstreamReq, "x-request-id", nil + }, "x-request-id" + + default: + return func(context.Context) (*http.Request, string, error) { + return nil, "", fmt.Errorf("unsupported account type: %s", account.Type) + }, "x-request-id" + } +} + +func (s *GeminiMessagesCompatService) handleChatCompletionsNonStreamingResponseFromGemini( + c *gin.Context, + resp *http.Response, + originalModel string, + isOAuth bool, +) (*ClaudeUsage, error) { + respBody, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) + if err != nil { + return nil, err + } + if isOAuth { + if unwrappedBody, uwErr := unwrapGeminiResponse(respBody); uwErr == nil { + respBody = unwrappedBody + } + } + + var geminiResp map[string]any + if err := json.Unmarshal(respBody, &geminiResp); err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + } + + chatResp, usage, err := geminiResponseToChatCompletions(geminiResp, originalModel, respBody, nil) + if err != nil { + return nil, s.writeChatCompletionsError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + } + + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + c.JSON(http.StatusOK, chatResp) + return usage, nil +} + +func geminiResponseToChatCompletions( + geminiResp map[string]any, + originalModel string, + rawData []byte, + usageOverride *ClaudeUsage, +) (*apicompat.ChatCompletionsResponse, *ClaudeUsage, error) { + claudeRespMap, usage := convertGeminiToClaudeMessage(geminiResp, originalModel, rawData) + if usageOverride != nil && (usageOverride.InputTokens > 0 || usageOverride.OutputTokens > 0 || usageOverride.CacheReadInputTokens > 0) { + usage = usageOverride + if usageMap, ok := claudeRespMap["usage"].(map[string]any); ok { + usageMap["input_tokens"] = usage.InputTokens + usageMap["output_tokens"] = usage.OutputTokens + usageMap["cache_read_input_tokens"] = usage.CacheReadInputTokens + } + } + + claudeBytes, err := json.Marshal(claudeRespMap) + if err != nil { + return nil, nil, err + } + var anthropicResp apicompat.AnthropicResponse + if err := json.Unmarshal(claudeBytes, &anthropicResp); err != nil { + return nil, nil, err + } + responsesResp := apicompat.AnthropicToResponsesResponse(&anthropicResp) + return apicompat.ResponsesToChatCompletions(responsesResp, originalModel), usage, nil +} + +func (s *GeminiMessagesCompatService) handleChatCompletionsStreamingResponseFromGemini( + c *gin.Context, + resp *http.Response, + startTime time.Time, + originalModel string, + isOAuth bool, + includeUsage bool, +) (*geminiStreamResult, error) { + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + anthState := apicompat.NewAnthropicEventToResponsesState() + anthState.Model = originalModel + ccState := apicompat.NewResponsesEventToChatState() + ccState.Model = originalModel + ccState.IncludeUsage = includeUsage + + var usage ClaudeUsage + var firstTokenMs *int + firstChunk := true + + writeChatChunk := func(chunk apicompat.ChatCompletionsChunk) bool { + sse, err := apicompat.ChatChunkToSSE(chunk) + if err != nil { + return false + } + if _, err := io.WriteString(c.Writer, sse); err != nil { + return true + } + return false + } + + emitAnthropicEvent := func(evt *apicompat.AnthropicStreamEvent) bool { + responsesEvents := apicompat.AnthropicEventToResponsesEvents(evt, anthState) + for _, resEvt := range responsesEvents { + chunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState) + for _, chunk := range chunks { + if disconnected := writeChatChunk(chunk); disconnected { + return true + } + } + } + flusher.Flush() + return false + } + + messageID := "msg_" + randomHex(12) + if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{ + Type: "message_start", + Message: &apicompat.AnthropicResponse{ + ID: messageID, + Type: "message", + Role: "assistant", + Model: originalModel, + Content: []apicompat.AnthropicContentBlock{}, + Usage: apicompat.AnthropicUsage{}, + }, + }) { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + + finishReason := "" + sawToolUse := false + nextBlockIndex := 0 + openBlockIndex := -1 + openBlockType := "" + seenText := "" + openToolIndex := -1 + openToolName := "" + seenToolJSON := "" + + closeOpenBlock := func() bool { + if openBlockIndex < 0 { + return false + } + disconnected := emitAnthropicEvent(&apicompat.AnthropicStreamEvent{Type: "content_block_stop"}) + openBlockIndex = -1 + openBlockType = "" + return disconnected + } + closeOpenTool := func() bool { + if openToolIndex < 0 { + return false + } + disconnected := emitAnthropicEvent(&apicompat.AnthropicStreamEvent{Type: "content_block_stop"}) + openToolIndex = -1 + openToolName = "" + seenToolJSON = "" + return disconnected + } + + reader := bufio.NewReader(resp.Body) + for { + line, err := reader.ReadString('\n') + if len(line) > 0 { + trimmed := strings.TrimRight(line, "\r\n") + if strings.HasPrefix(trimmed, "data:") { + payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) + if payload != "" && payload != "[DONE]" { + rawBytes := []byte(payload) + if isOAuth { + if innerBytes, uwErr := unwrapGeminiResponse(rawBytes); uwErr == nil { + rawBytes = innerBytes + } + } + + var geminiResp map[string]any + if err := json.Unmarshal(rawBytes, &geminiResp); err == nil { + if firstChunk { + firstChunk = false + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + if fr := extractGeminiFinishReason(geminiResp); fr != "" { + finishReason = fr + } + if u := extractGeminiUsage(rawBytes); u != nil { + usage = *u + } + + for _, part := range extractGeminiParts(geminiResp) { + if text, ok := part["text"].(string); ok && text != "" { + if openToolIndex >= 0 { + if closeOpenTool() { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + } + delta, newSeen := computeGeminiTextDelta(seenText, text) + seenText = newSeen + if delta == "" { + continue + } + if openBlockType != "text" { + if closeOpenBlock() { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + idx := nextBlockIndex + nextBlockIndex++ + openBlockIndex = idx + openBlockType = "text" + if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{ + Type: "content_block_start", + Index: &idx, + ContentBlock: &apicompat.AnthropicContentBlock{ + Type: "text", + Text: "", + }, + }) { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + } + if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{ + Type: "content_block_delta", + Delta: &apicompat.AnthropicDelta{ + Type: "text_delta", + Text: delta, + }, + }) { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + continue + } + + if fc, ok := part["functionCall"].(map[string]any); ok && fc != nil { + name, _ := fc["name"].(string) + if strings.TrimSpace(name) == "" { + name = "tool" + } + if closeOpenBlock() { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + if openToolIndex >= 0 && openToolName != name { + if closeOpenTool() { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + } + if openToolIndex < 0 { + idx := nextBlockIndex + nextBlockIndex++ + openToolIndex = idx + openToolName = name + sawToolUse = true + if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{ + Type: "content_block_start", + Index: &idx, + ContentBlock: &apicompat.AnthropicContentBlock{ + Type: "tool_use", + ID: "toolu_" + randomHex(8), + Name: name, + Input: json.RawMessage(`{}`), + }, + }) { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + } + + argsJSONText := "{}" + switch v := fc["args"].(type) { + case nil: + case string: + if strings.TrimSpace(v) != "" { + argsJSONText = v + } + default: + if b, err := json.Marshal(v); err == nil && len(b) > 0 { + argsJSONText = string(b) + } + } + delta, newSeen := computeGeminiTextDelta(seenToolJSON, argsJSONText) + seenToolJSON = newSeen + if delta != "" { + if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{ + Type: "content_block_delta", + Delta: &apicompat.AnthropicDelta{ + Type: "input_json_delta", + PartialJSON: delta, + }, + }) { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + } + } + } + } + } + } + } + + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, fmt.Errorf("stream read error: %w", err) + } + } + + if closeOpenBlock() { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + if closeOpenTool() { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + + stopReason := mapGeminiFinishReasonToClaudeStopReason(finishReason) + if sawToolUse { + stopReason = "tool_use" + } + anthState.InputTokens = usage.InputTokens + anthState.CacheReadInputTokens = usage.CacheReadInputTokens + if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{ + Type: "message_delta", + Delta: &apicompat.AnthropicDelta{ + Type: "message_delta", + StopReason: stopReason, + }, + Usage: &apicompat.AnthropicUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + CacheReadInputTokens: usage.CacheReadInputTokens, + }, + }) { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + if emitAnthropicEvent(&apicompat.AnthropicStreamEvent{Type: "message_stop"}) { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + + for _, resEvt := range apicompat.FinalizeAnthropicResponsesStream(anthState) { + chunks := apicompat.ResponsesEventToChatChunks(&resEvt, ccState) + for _, chunk := range chunks { + if disconnected := writeChatChunk(chunk); disconnected { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + } + } + for _, chunk := range apicompat.FinalizeResponsesChatStream(ccState) { + if disconnected := writeChatChunk(chunk); disconnected { + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil + } + } + + _, _ = io.WriteString(c.Writer, "data: [DONE]\n\n") + flusher.Flush() + + return &geminiStreamResult{usage: &usage, firstTokenMs: firstTokenMs}, nil +} + +func (s *GeminiMessagesCompatService) writeGeminiChatCompletionsMappedError( + c *gin.Context, + account *Account, + upstreamStatus int, + upstreamRequestID string, + body []byte, +) error { + upstreamMsg := sanitizeUpstreamErrorMessage(strings.TrimSpace(extractUpstreamErrorMessage(body))) + setOpsUpstreamError(c, upstreamStatus, upstreamMsg, "") + if account != nil { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: upstreamStatus, + UpstreamRequestID: upstreamRequestID, + Kind: "http_error", + Message: upstreamMsg, + }) + } + + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + PlatformGemini, + upstreamStatus, + body, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + return s.writeChatCompletionsError(c, status, errType, errMsg) + } + + statusCode := http.StatusBadGateway + errType := "upstream_error" + errMsg := "Upstream request failed" + if mapped := mapGeminiErrorBodyToClaudeError(body); mapped != nil { + if mapped.Type != "" { + errType = mapped.Type + } + if mapped.Message != "" { + errMsg = mapped.Message + } + if mapped.StatusCode > 0 { + statusCode = mapped.StatusCode + } + } + + switch upstreamStatus { + case http.StatusBadRequest: + if statusCode == http.StatusBadGateway { + statusCode = http.StatusBadRequest + } + if errType == "upstream_error" { + errType = "invalid_request_error" + } + if errMsg == "Upstream request failed" { + errMsg = "Invalid request" + } + case http.StatusNotFound: + statusCode = http.StatusNotFound + if errType == "upstream_error" { + errType = "not_found_error" + } + if errMsg == "Upstream request failed" { + errMsg = "Resource not found" + } + case http.StatusTooManyRequests: + statusCode = http.StatusTooManyRequests + if errType == "upstream_error" { + errType = "rate_limit_error" + } + if errMsg == "Upstream request failed" { + errMsg = "Upstream rate limit exceeded, please retry later" + } + case 529: + statusCode = http.StatusServiceUnavailable + if errType == "upstream_error" { + errType = "overloaded_error" + } + if errMsg == "Upstream request failed" { + errMsg = "Upstream service overloaded, please retry later" + } + } + + if upstreamMsg != "" && errMsg == "Upstream request failed" { + errMsg = upstreamMsg + } + return s.writeChatCompletionsError(c, statusCode, errType, errMsg) +} + +func (s *GeminiMessagesCompatService) writeChatCompletionsError(c *gin.Context, status int, errType, message string) error { + c.JSON(status, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) + return fmt.Errorf("%s", message) +} diff --git a/backend/internal/service/gemini_error_policy_test.go b/backend/internal/service/gemini_error_policy_test.go index 4bd1ced77c5..84f9a706b7a 100644 --- a/backend/internal/service/gemini_error_policy_test.go +++ b/backend/internal/service/gemini_error_policy_test.go @@ -383,6 +383,37 @@ func TestGeminiErrorPolicy_NilRateLimitService(t *testing.T) { // policy tests. Embeds mockAccountRepoForGemini and adds tracking. // --------------------------------------------------------------------------- +func TestHandleGeminiUpstreamError_GoogleOneCapacityExhaustedUsesTierCooldown(t *testing.T) { + repo := &rateLimit429AccountRepoStub{} + quotaSvc := NewGeminiQuotaService(&config.Config{}, nil) + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, quotaSvc, nil) + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + rateLimitService: rlSvc, + } + + account := &Account{ + ID: 511, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "oauth_type": "google_one", + "tier_id": "google_ai_pro", + }, + } + body := []byte(`{"error":{"code":429,"details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","domain":"cloudcode-pa.googleapis.com","metadata":{"model":"gemini-3.1-pro-preview"},"reason":"MODEL_CAPACITY_EXHAUSTED"}],"message":"No capacity available for model gemini-3.1-pro-preview on the server","status":"RESOURCE_EXHAUSTED"}}`) + + before := time.Now() + svc.handleGeminiUpstreamError(context.Background(), account, http.StatusTooManyRequests, http.Header{}, body) + after := time.Now() + + require.Equal(t, 1, repo.rateLimitCalls) + require.Equal(t, int64(511), repo.lastRateLimitID) + require.WithinDuration(t, before.Add(5*time.Minute), repo.lastRateLimitReset, 2*time.Second) + require.True(t, repo.lastRateLimitReset.After(before)) + require.True(t, repo.lastRateLimitReset.Before(after.Add(5*time.Minute).Add(2*time.Second))) +} + type geminiErrorPolicyRepo struct { mockAccountRepoForGemini setErrorCalls int diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index ea0c0d7dd39..4d6fa47ba72 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -1072,21 +1072,23 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex // 图片生成计费 imageCount := 0 - imageSize := s.extractImageSize(body) + imageInputSize := s.extractImageInputSize(body) + imageSize := normalizeOpenAIImageSizeTier(imageInputSize) if isImageGenerationModel(originalModel) { imageCount = 1 } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, - UpstreamModel: mappedModel, - Stream: req.Stream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: imageSize, + RequestID: requestID, + Usage: *usage, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: req.Stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: imageSize, + ImageInputSize: imageInputSize, }, nil } @@ -1600,21 +1602,23 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. // 图片生成计费 imageCount := 0 - imageSize := s.extractImageSize(body) + imageInputSize := s.extractImageInputSize(body) + imageSize := normalizeOpenAIImageSizeTier(imageInputSize) if isImageGenerationModel(originalModel) { imageCount = 1 } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, - UpstreamModel: mappedModel, - Stream: stream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: imageSize, + RequestID: requestID, + Usage: *usage, + Model: originalModel, + UpstreamModel: mappedModel, + Stream: stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: imageSize, + ImageInputSize: imageInputSize, }, nil } @@ -2822,14 +2826,18 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont if resetAt == nil { // 根据账号类型使用不同的默认重置时间 var ra time.Time - if isCodeAssist { - // Code Assist: fallback cooldown by tier + if isCodeAssist || oauthType == "google_one" { + // Gemini CLI / Google One: fallback cooldown by tier cooldown := geminiCooldownForTier(tierID) if s.rateLimitService != nil { cooldown = s.rateLimitService.GeminiCooldown(ctx, account) } ra = time.Now().Add(cooldown) - logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second)) + if isCodeAssist { + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second)) + } else { + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (Google One OAuth, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second)) + } } else { // API Key / AI Studio OAuth: PST 午夜 if ts := nextGeminiDailyResetUnix(); ts != nil { @@ -3430,8 +3438,7 @@ func convertClaudeGenerationConfig(req map[string]any) map[string]any { return out } -// extractImageSize 从 Gemini 请求中提取 image_size 参数 -func (s *GeminiMessagesCompatService) extractImageSize(body []byte) string { +func (s *GeminiMessagesCompatService) extractImageInputSize(body []byte) string { var req struct { GenerationConfig *struct { ImageConfig *struct { @@ -3440,15 +3447,12 @@ func (s *GeminiMessagesCompatService) extractImageSize(body []byte) string { } `json:"generationConfig"` } if err := json.Unmarshal(body, &req); err != nil { - return "2K" + return "" } if req.GenerationConfig != nil && req.GenerationConfig.ImageConfig != nil { - size := strings.ToUpper(strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize)) - if size == "1K" || size == "2K" || size == "4K" { - return size - } + return strings.TrimSpace(req.GenerationConfig.ImageConfig.ImageSize) } - return "2K" + return "" } diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go index c2adf45dedc..d0560344dc4 100644 --- a/backend/internal/service/gemini_messages_compat_service_test.go +++ b/backend/internal/service/gemini_messages_compat_service_test.go @@ -1,6 +1,7 @@ package service import ( + "bytes" "context" "encoding/json" "fmt" @@ -41,6 +42,134 @@ func (s *geminiCompatHTTPUpstreamStub) DoWithTLS(req *http.Request, proxyURL str return s.Do(req, proxyURL, accountID, accountConcurrency) } +func TestGeminiForwardAsChatCompletions_OAuthRoutesToGeminiAndReturnsChatFormat(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstreamBody := `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello from gemini"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":7,"candidatesTokenCount":3}}}` + "\n\n" + + "data: [DONE]\n\n" + httpStub := &geminiCompatHTTPUpstreamStub{ + response: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }, + } + svc := &GeminiMessagesCompatService{ + tokenProvider: &GeminiTokenProvider{}, + httpUpstream: httpStub, + cfg: &config.Config{}, + } + account := &Account{ + ID: 101, + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "ya29.test-token", + "project_id": "project-1", + }, + Concurrency: 1, + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gemini-2.5-flash","messages":[{"role":"user","content":"hi"}]}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "gemini-2.5-flash", result.Model) + require.Equal(t, 7, result.Usage.InputTokens) + require.Equal(t, 3, result.Usage.OutputTokens) + + require.NotNil(t, httpStub.lastReq) + require.Contains(t, httpStub.lastReq.URL.String(), "/v1internal:streamGenerateContent?alt=sse") + require.Equal(t, "Bearer ya29.test-token", httpStub.lastReq.Header.Get("Authorization")) + require.Empty(t, httpStub.lastReq.Header.Get("x-api-key")) + require.Empty(t, httpStub.lastReq.Header.Get("anthropic-version")) + + var sent map[string]any + sentBody, err := io.ReadAll(httpStub.lastReq.Body) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(sentBody, &sent)) + require.Equal(t, "gemini-2.5-flash", sent["model"]) + require.Equal(t, "project-1", sent["project"]) + require.Contains(t, fmt.Sprint(sent["request"]), "hi") + + var got map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &got)) + require.Equal(t, "chat.completion", got["object"]) + require.Equal(t, "gemini-2.5-flash", got["model"]) + choices, ok := got["choices"].([]any) + require.True(t, ok) + require.NotEmpty(t, choices) + choice, ok := choices[0].(map[string]any) + require.True(t, ok) + message, ok := choice["message"].(map[string]any) + require.True(t, ok) + require.Equal(t, "assistant", message["role"]) + require.Equal(t, "hello from gemini", message["content"]) + usage, ok := got["usage"].(map[string]any) + require.True(t, ok) + require.Equal(t, float64(7), usage["prompt_tokens"]) + require.Equal(t, float64(3), usage["completion_tokens"]) + require.Equal(t, float64(10), usage["total_tokens"]) +} + +func TestGeminiForwardAsChatCompletions_StreamsOpenAIChunksFromGeminiSSE(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstreamBody := `data: {"candidates":[{"content":{"parts":[{"text":"hel"}]}}],"usageMetadata":{"promptTokenCount":2,"candidatesTokenCount":1}}` + "\n\n" + + `data: {"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":2,"candidatesTokenCount":2}}` + "\n\n" + + "data: [DONE]\n\n" + httpStub := &geminiCompatHTTPUpstreamStub{ + response: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }, + } + svc := &GeminiMessagesCompatService{ + httpUpstream: httpStub, + cfg: &config.Config{}, + } + account := &Account{ + ID: 102, + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "gemini-api-key", + }, + Concurrency: 1, + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"gemini-2.5-flash","stream":true,"stream_options":{"include_usage":true},"messages":[{"role":"user","content":"hi"}]}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, http.StatusOK, rec.Code) + require.True(t, result.Stream) + require.Equal(t, 2, result.Usage.InputTokens) + require.Equal(t, 2, result.Usage.OutputTokens) + + require.NotNil(t, httpStub.lastReq) + require.Contains(t, httpStub.lastReq.URL.String(), "/v1beta/models/gemini-2.5-flash:streamGenerateContent?alt=sse") + require.Equal(t, "gemini-api-key", httpStub.lastReq.Header.Get("x-goog-api-key")) + + out := rec.Body.String() + require.Contains(t, out, `"object":"chat.completion.chunk"`) + require.Contains(t, out, `"role":"assistant"`) + require.Contains(t, out, `"content":"hel"`) + require.Contains(t, out, `"content":"lo"`) + require.Contains(t, out, `"usage":{"prompt_tokens":2,"completion_tokens":2,"total_tokens":4}`) + require.Contains(t, out, "data: [DONE]") +} + // TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换 func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { tests := []struct { diff --git a/backend/internal/service/image_billing_size.go b/backend/internal/service/image_billing_size.go new file mode 100644 index 00000000000..0ca69ac4bbf --- /dev/null +++ b/backend/internal/service/image_billing_size.go @@ -0,0 +1,260 @@ +package service + +import ( + "sort" + "strconv" + "strings" +) + +const ( + ImageBillingSize1K = "1K" + ImageBillingSize2K = "2K" + ImageBillingSize4K = "4K" + + ImageSizeSourceOutput = "output" + ImageSizeSourceInput = "input" + ImageSizeSourceDefault = "default" + ImageSizeSourceLegacy = "legacy" +) + +type ImageBillingSizeResolution struct { + BillingSize string + InputSize string + OutputSize string + Source string + Breakdown map[string]int +} + +func ClassifyImageBillingTier(size string) (string, bool) { + trimmed := strings.TrimSpace(size) + normalized := strings.ToLower(trimmed) + switch normalized { + case "", "auto": + return "", false + case "1k": + return ImageBillingSize1K, true + case "2k": + return ImageBillingSize2K, true + case "4k": + return ImageBillingSize4K, true + case "2048x2048", "2048x1152": + return ImageBillingSize2K, true + case "3840x2160", "2160x3840": + return ImageBillingSize4K, true + } + + width, height, ok := parseImageBillingDimensions(trimmed) + if !ok { + return "", false + } + maxEdge := width + if height > maxEdge { + maxEdge = height + } + switch { + case maxEdge <= 1024: + return ImageBillingSize1K, true + case maxEdge <= 2048: + return ImageBillingSize2K, true + default: + return ImageBillingSize4K, true + } +} + +func NormalizeImageBillingTierOrDefault(size string) string { + if tier, ok := ClassifyImageBillingTier(size); ok { + return tier + } + return ImageBillingSize2K +} + +func ResolveImageBillingSize(inputSize string, outputSizes []string) ImageBillingSizeResolution { + inputSize = strings.TrimSpace(inputSize) + outputSizes = compactTrimmedStrings(outputSizes) + + breakdown := map[string]int{} + outputSize := firstDisplayImageOutputSize(outputSizes) + outputTier := "" + for _, output := range outputSizes { + tier, ok := ClassifyImageBillingTier(output) + if !ok { + continue + } + breakdown[tier]++ + if imageTierRank(tier) > imageTierRank(outputTier) { + outputTier = tier + } + } + if outputTier != "" { + return ImageBillingSizeResolution{ + BillingSize: outputTier, + InputSize: inputSize, + OutputSize: outputSize, + Source: ImageSizeSourceOutput, + Breakdown: normalizeImageSizeBreakdown(breakdown), + } + } + + if tier, ok := ClassifyImageBillingTier(inputSize); ok { + return ImageBillingSizeResolution{ + BillingSize: tier, + InputSize: inputSize, + OutputSize: outputSize, + Source: ImageSizeSourceInput, + } + } + + return ImageBillingSizeResolution{ + BillingSize: ImageBillingSize2K, + InputSize: inputSize, + OutputSize: outputSize, + Source: ImageSizeSourceDefault, + } +} + +func ApplyOpenAIImageBillingResolution(result *OpenAIForwardResult) { + if result == nil || result.ImageCount <= 0 { + return + } + inputSize := strings.TrimSpace(result.ImageInputSize) + if inputSize == "" && strings.TrimSpace(result.ImageSize) != ImageBillingSize2K { + inputSize = strings.TrimSpace(result.ImageSize) + } + outputSizes := result.ImageOutputSizes + if len(outputSizes) == 0 && strings.TrimSpace(result.ImageOutputSize) != "" { + outputSizes = []string{result.ImageOutputSize} + } + resolved := ResolveImageBillingSize(inputSize, outputSizes) + applyImageBillingResolution( + &result.ImageSize, + &result.ImageInputSize, + &result.ImageOutputSize, + &result.ImageSizeSource, + &result.ImageSizeBreakdown, + resolved, + ) +} + +func ApplyForwardImageBillingResolution(result *ForwardResult) { + if result == nil || result.ImageCount <= 0 { + return + } + inputSize := strings.TrimSpace(result.ImageInputSize) + if inputSize == "" && strings.TrimSpace(result.ImageSize) != ImageBillingSize2K { + inputSize = strings.TrimSpace(result.ImageSize) + } + outputSizes := result.ImageOutputSizes + if len(outputSizes) == 0 && strings.TrimSpace(result.ImageOutputSize) != "" { + outputSizes = []string{result.ImageOutputSize} + } + resolved := ResolveImageBillingSize(inputSize, outputSizes) + applyImageBillingResolution( + &result.ImageSize, + &result.ImageInputSize, + &result.ImageOutputSize, + &result.ImageSizeSource, + &result.ImageSizeBreakdown, + resolved, + ) +} + +func applyImageBillingResolution( + billingSize *string, + inputSize *string, + outputSize *string, + source *string, + breakdown *map[string]int, + resolved ImageBillingSizeResolution, +) { + *billingSize = resolved.BillingSize + *inputSize = resolved.InputSize + *outputSize = resolved.OutputSize + *source = resolved.Source + *breakdown = resolved.Breakdown +} + +func parseImageBillingDimensions(size string) (int, int, bool) { + parts := strings.Split(strings.ToLower(strings.TrimSpace(size)), "x") + if len(parts) != 2 { + return 0, 0, false + } + width, err := strconv.Atoi(strings.TrimSpace(parts[0])) + if err != nil { + return 0, 0, false + } + height, err := strconv.Atoi(strings.TrimSpace(parts[1])) + if err != nil { + return 0, 0, false + } + if width <= 0 || height <= 0 { + return 0, 0, false + } + return width, height, true +} + +func compactTrimmedStrings(values []string) []string { + if len(values) == 0 { + return nil + } + out := make([]string, 0, len(values)) + for _, value := range values { + trimmed := strings.TrimSpace(value) + if trimmed != "" { + out = append(out, trimmed) + } + } + return out +} + +func firstDisplayImageOutputSize(outputSizes []string) string { + for _, output := range outputSizes { + if trimmed := strings.TrimSpace(output); trimmed != "" { + return trimmed + } + } + return "" +} + +func imageTierRank(tier string) int { + switch strings.ToUpper(strings.TrimSpace(tier)) { + case ImageBillingSize1K: + return 1 + case ImageBillingSize2K: + return 2 + case ImageBillingSize4K: + return 3 + default: + return 0 + } +} + +func normalizeImageSizeBreakdown(in map[string]int) map[string]int { + if len(in) == 0 { + return nil + } + out := make(map[string]int, len(in)) + for _, tier := range []string{ImageBillingSize1K, ImageBillingSize2K, ImageBillingSize4K} { + if count := in[tier]; count > 0 { + out[tier] = count + } + } + if len(out) == 0 { + return nil + } + return out +} + +func SortedImageBillingBreakdownKeys(breakdown map[string]int) []string { + keys := make([]string, 0, len(breakdown)) + for key := range breakdown { + keys = append(keys, key) + } + sort.Slice(keys, func(i, j int) bool { + left, right := imageTierRank(keys[i]), imageTierRank(keys[j]) + if left == right { + return keys[i] < keys[j] + } + return left < right + }) + return keys +} diff --git a/backend/internal/service/image_billing_size_test.go b/backend/internal/service/image_billing_size_test.go new file mode 100644 index 00000000000..48c9ac340e7 --- /dev/null +++ b/backend/internal/service/image_billing_size_test.go @@ -0,0 +1,110 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestClassifyImageBillingTier(t *testing.T) { + tests := []struct { + name string + size string + wantTier string + wantOK bool + }{ + {name: "explicit 2k square", size: "2048x2048", wantTier: "2K", wantOK: true}, + {name: "explicit 2k landscape", size: "2048x1152", wantTier: "2K", wantOK: true}, + {name: "explicit 4k landscape", size: "3840x2160", wantTier: "4K", wantOK: true}, + {name: "explicit 4k portrait", size: "2160x3840", wantTier: "4K", wantOK: true}, + {name: "long edge 1k", size: "1024X768", wantTier: "1K", wantOK: true}, + {name: "long edge 2k", size: "1280x768", wantTier: "2K", wantOK: true}, + {name: "long edge 4k", size: "2560x1600", wantTier: "4K", wantOK: true}, + {name: "tier string 1k", size: "1k", wantTier: "1K", wantOK: true}, + {name: "empty", size: "", wantOK: false}, + {name: "auto", size: "auto", wantOK: false}, + {name: "invalid", size: "not-a-size", wantOK: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotTier, gotOK := ClassifyImageBillingTier(tt.size) + require.Equal(t, tt.wantOK, gotOK) + require.Equal(t, tt.wantTier, gotTier) + }) + } +} + +func TestResolveImageBillingSize(t *testing.T) { + tests := []struct { + name string + inputSize string + outputSizes []string + wantBilling string + wantOutput string + wantSource string + wantBreakdown map[string]int + }{ + { + name: "output wins over input", + inputSize: "1024x1024", + outputSizes: []string{"3840x2160"}, + wantBilling: "4K", + wantOutput: "3840x2160", + wantSource: ImageSizeSourceOutput, + wantBreakdown: map[string]int{"4K": 1}, + }, + { + name: "input fallback", + inputSize: "1024x1024", + wantBilling: "1K", + wantSource: ImageSizeSourceInput, + }, + { + name: "auto defaults", + inputSize: "auto", + wantBilling: "2K", + wantSource: ImageSizeSourceDefault, + }, + { + name: "empty defaults", + inputSize: "", + wantBilling: "2K", + wantSource: ImageSizeSourceDefault, + }, + { + name: "invalid defaults", + inputSize: "largest", + wantBilling: "2K", + wantSource: ImageSizeSourceDefault, + }, + { + name: "mixed output chooses highest tier", + inputSize: "1024x1024", + outputSizes: []string{"1024x1024", "3840x2160", "1280x720"}, + wantBilling: "4K", + wantOutput: "1024x1024", + wantSource: ImageSizeSourceOutput, + wantBreakdown: map[string]int{"1K": 1, "2K": 1, "4K": 1}, + }, + { + name: "unparseable output falls back to parseable input", + inputSize: "2048x1152", + outputSizes: []string{"auto"}, + wantBilling: "2K", + wantOutput: "auto", + wantSource: ImageSizeSourceInput, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ResolveImageBillingSize(tt.inputSize, tt.outputSizes) + require.Equal(t, tt.wantBilling, got.BillingSize) + require.Equal(t, tt.inputSize, got.InputSize) + require.Equal(t, tt.wantOutput, got.OutputSize) + require.Equal(t, tt.wantSource, got.Source) + require.Equal(t, tt.wantBreakdown, got.Breakdown) + }) + } +} diff --git a/backend/internal/service/image_generation_intent.go b/backend/internal/service/image_generation_intent.go index b6ef106509d..4aca1239ff9 100644 --- a/backend/internal/service/image_generation_intent.go +++ b/backend/internal/service/image_generation_intent.go @@ -170,7 +170,13 @@ func cloneRequestMapForImageIntent(body []byte) map[string]any { return out } -func resolveOpenAIResponsesImageBillingConfig(reqBody map[string]any, fallbackModel string) (string, string, error) { +type OpenAIResponsesImageBillingConfig struct { + Model string + SizeTier string + InputSize string +} + +func resolveOpenAIResponsesImageBillingConfigDetailed(reqBody map[string]any, fallbackModel string) (OpenAIResponsesImageBillingConfig, error) { imageModel := "" imageSize := "" hasImageTool := false @@ -203,12 +209,24 @@ func resolveOpenAIResponsesImageBillingConfig(reqBody map[string]any, fallbackMo imageModel = strings.TrimSpace(fallbackModel) } sizeTier := normalizeOpenAIImageSizeTier(imageSize) - return imageModel, sizeTier, nil + return OpenAIResponsesImageBillingConfig{ + Model: imageModel, + SizeTier: sizeTier, + InputSize: imageSize, + }, nil } func resolveOpenAIResponsesImageBillingConfigFromBody(body []byte, fallbackModel string) (string, string, error) { + cfg, err := resolveOpenAIResponsesImageBillingConfigDetailedFromBody(body, fallbackModel) + if err != nil { + return "", "", err + } + return cfg.Model, cfg.SizeTier, nil +} + +func resolveOpenAIResponsesImageBillingConfigDetailedFromBody(body []byte, fallbackModel string) (OpenAIResponsesImageBillingConfig, error) { reqBody := cloneRequestMapForImageIntent(body) - return resolveOpenAIResponsesImageBillingConfig(reqBody, fallbackModel) + return resolveOpenAIResponsesImageBillingConfigDetailed(reqBody, fallbackModel) } func isOpenAIImageBillingModelAlias(model string) bool { diff --git a/backend/internal/service/image_generation_intent_test.go b/backend/internal/service/image_generation_intent_test.go index 5e7bec79b7f..4621e9d9857 100644 --- a/backend/internal/service/image_generation_intent_test.go +++ b/backend/internal/service/image_generation_intent_test.go @@ -140,9 +140,10 @@ func TestResolveOpenAIResponsesImageBillingConfigDoesNotRejectUnknownSizes(t *te func TestOpenAIImageOutputCounterDeduplicatesFinalImages(t *testing.T) { counter := newOpenAIImageOutputCounter() counter.AddSSEData([]byte(`{"type":"response.image_generation_call.partial_image","partial_image_b64":"abc"}`)) - counter.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_1","type":"image_generation_call","result":"final-a"}}`)) - counter.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_1","type":"image_generation_call","result":"final-a"},{"id":"ig_2","type":"image_generation_call","result":"final-b"}]}}`)) + counter.AddSSEData([]byte(`{"type":"response.output_item.done","item":{"id":"ig_1","type":"image_generation_call","result":"final-a","size":"1024x1024"}}`)) + counter.AddSSEData([]byte(`{"type":"response.completed","response":{"output":[{"id":"ig_1","type":"image_generation_call","result":"final-a"},{"id":"ig_2","type":"image_generation_call","result":"final-b","size":"3840x2160"}]}}`)) require.Equal(t, 2, counter.Count()) + require.Equal(t, []string{"1024x1024", "3840x2160"}, counter.Sizes()) } func TestOpenAIImageOutputCounterCountsImagesAPIStreamShapes(t *testing.T) { @@ -182,3 +183,36 @@ func TestOpenAIImageOutputCounterFallsBackForInvalidMultilineSSEBody(t *testing. ) require.Equal(t, 2, counter.Count()) } + +func TestCollectOpenAIResponseImageOutputSizesFromJSONBytes(t *testing.T) { + body := []byte(`{ + "output": [ + {"id":"ig_1","type":"image_generation_call","result":"final-a","size":"3840x2160"}, + {"id":"ig_2","type":"image_generation_call","result":"final-b","size":"1024x1024"} + ] + }`) + + require.Equal(t, 2, countOpenAIResponseImageOutputsFromJSONBytes(body)) + require.Equal(t, []string{"3840x2160", "1024x1024"}, collectOpenAIResponseImageOutputSizesFromJSONBytes(body)) +} + +func TestCollectOpenAIResponseImageOutputSizesFromImagesAPIData(t *testing.T) { + body := []byte(`{ + "data": [ + {"b64_json":"final-a","size":"2048x1152"}, + {"b64_json":"final-b","size":"2048x1152"} + ] + }`) + + require.Equal(t, 2, countOpenAIResponseImageOutputsFromJSONBytes(body)) + require.Equal(t, []string{"2048x1152", "2048x1152"}, collectOpenAIResponseImageOutputSizesFromJSONBytes(body)) +} + +func TestCollectOpenAIImageOutputSizesFromSSEBody(t *testing.T) { + body := "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_1\",\"type\":\"image_generation_call\",\"result\":\"final-a\",\"size\":\"3840x2160\"}}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"output\":[{\"id\":\"ig_1\",\"type\":\"image_generation_call\",\"result\":\"final-a\"},{\"id\":\"ig_2\",\"type\":\"image_generation_call\",\"result\":\"final-b\",\"size\":\"1024x1024\"}]}}\n\n" + + "data: [DONE]\n\n" + + require.Equal(t, 2, countOpenAIImageOutputsFromSSEBody(body)) + require.Equal(t, []string{"3840x2160", "1024x1024"}, collectOpenAIImageOutputSizesFromSSEBody(body)) +} diff --git a/backend/internal/service/image_output_accounting.go b/backend/internal/service/image_output_accounting.go index 219c0c59609..2f2bd6ae840 100644 --- a/backend/internal/service/image_output_accounting.go +++ b/backend/internal/service/image_output_accounting.go @@ -10,12 +10,18 @@ import ( type openAIImageOutputCounter struct { seen map[string]struct{} + seenSizes map[string]string + seenOrder []string + dataSizes []string count int maxDataCount int } func newOpenAIImageOutputCounter() *openAIImageOutputCounter { - return &openAIImageOutputCounter{seen: make(map[string]struct{})} + return &openAIImageOutputCounter{ + seen: make(map[string]struct{}), + seenSizes: make(map[string]string), + } } func (c *openAIImageOutputCounter) Count() int { @@ -28,6 +34,25 @@ func (c *openAIImageOutputCounter) Count() int { return c.count } +func (c *openAIImageOutputCounter) Sizes() []string { + if c == nil { + return nil + } + sizes := make([]string, 0, len(c.seenOrder)+len(c.dataSizes)) + for _, key := range c.seenOrder { + if size := strings.TrimSpace(c.seenSizes[key]); size != "" { + sizes = append(sizes, size) + } + } + if len(sizes) == 0 && len(c.dataSizes) > 0 { + sizes = append(sizes, c.dataSizes...) + } + if len(sizes) == 0 { + return nil + } + return sizes +} + func (c *openAIImageOutputCounter) AddJSONResponse(body []byte) { if c == nil || len(body) == 0 || !gjson.ValidBytes(body) { return @@ -73,10 +98,20 @@ func (c *openAIImageOutputCounter) addDataArray(data gjson.Result) { if !data.IsArray() { return } - count := len(data.Array()) + items := data.Array() + count := len(items) if count > c.maxDataCount { c.maxDataCount = count } + sizes := make([]string, 0, len(items)) + for _, item := range items { + if size := strings.TrimSpace(item.Get("size").String()); size != "" { + sizes = append(sizes, size) + } + } + if len(sizes) > 0 { + c.dataSizes = sizes + } } func (c *openAIImageOutputCounter) addOutputArray(output gjson.Result) { @@ -120,10 +155,18 @@ func (c *openAIImageOutputCounter) addImageOutputItem(item gjson.Result) { if key == "" { return } + size := strings.TrimSpace(item.Get("size").String()) if _, exists := c.seen[key]; exists { + if size != "" && strings.TrimSpace(c.seenSizes[key]) == "" { + c.seenSizes[key] = size + } return } c.seen[key] = struct{}{} + c.seenOrder = append(c.seenOrder, key) + if size != "" { + c.seenSizes[key] = size + } c.count++ } @@ -142,8 +185,20 @@ func countOpenAIResponseImageOutputsFromJSONBytes(body []byte) int { return counter.Count() } +func collectOpenAIResponseImageOutputSizesFromJSONBytes(body []byte) []string { + counter := newOpenAIImageOutputCounter() + counter.AddJSONResponse(body) + return counter.Sizes() +} + func countOpenAIImageOutputsFromSSEBody(body string) int { counter := newOpenAIImageOutputCounter() counter.AddSSEBody(body) return counter.Count() } + +func collectOpenAIImageOutputSizesFromSSEBody(body string) []string { + counter := newOpenAIImageOutputCounter() + counter.AddSSEBody(body) + return counter.Sizes() +} diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index a3b69dee5fd..1c2e3cb34c3 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -8,6 +8,7 @@ import ( var codexModelMap = map[string]string{ "gpt-5.5": "gpt-5.5", + "codex-auto-review": "codex-auto-review", "gpt-5.4": "gpt-5.4", "gpt-5.4-mini": "gpt-5.4-mini", "gpt-5.4-none": "gpt-5.4", @@ -1030,7 +1031,7 @@ func filterCodexInputWithOptions(input []any, opts codexInputFilterOptions) []an return id } if strings.HasPrefix(id, "call_") { - return "fc" + strings.TrimPrefix(id, "call_") + return "fc_" + strings.TrimPrefix(id, "call_") } return "fc_" + id } diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 9c72760aa2b..4c182b8eed9 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -41,7 +41,7 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { second, ok := input[1].(map[string]any) require.True(t, ok) require.Equal(t, "o1", second["id"]) - require.Equal(t, "fc1", second["call_id"]) + require.Equal(t, "fc_1", second["call_id"]) } func TestApplyCodexOAuthTransform_MessagesBridgePromptCacheKeyIsHeaderOnly(t *testing.T) { @@ -120,11 +120,11 @@ func TestApplyCodexOAuthTransform_ToolContinuationNormalizesToolReferenceIDsOnly first, ok := input[0].(map[string]any) require.True(t, ok) - require.Equal(t, "fc1", first["id"]) + require.Equal(t, "fc_1", first["id"]) second, ok := input[1].(map[string]any) require.True(t, ok) - require.Equal(t, "fc1", second["call_id"]) + require.Equal(t, "fc_1", second["call_id"]) } func TestApplyCodexOAuthTransform_ToolSearchOutputPreservesCallID(t *testing.T) { @@ -144,7 +144,7 @@ func TestApplyCodexOAuthTransform_ToolSearchOutputPreservesCallID(t *testing.T) first, ok := input[0].(map[string]any) require.True(t, ok) require.Equal(t, "tool_search_output", first["type"]) - require.Equal(t, "fc1", first["call_id"]) + require.Equal(t, "fc_1", first["call_id"]) } func TestApplyCodexOAuthTransform_CustomAndMCPToolOutputsPreserveCallID(t *testing.T) { @@ -164,11 +164,11 @@ func TestApplyCodexOAuthTransform_CustomAndMCPToolOutputsPreserveCallID(t *testi first, ok := input[0].(map[string]any) require.True(t, ok) - require.Equal(t, "fccustom", first["call_id"]) + require.Equal(t, "fc_custom", first["call_id"]) second, ok := input[1].(map[string]any) require.True(t, ok) - require.Equal(t, "fcmcp", second["call_id"]) + require.Equal(t, "fc_mcp", second["call_id"]) } func TestApplyCodexOAuthTransform_ImageAndWebSearchCallsDoNotGainCallID(t *testing.T) { @@ -221,7 +221,7 @@ func TestApplyCodexOAuthTransform_ConvertsToolRoleMessageToFunctionCallOutput(t item, ok := input[0].(map[string]any) require.True(t, ok) require.Equal(t, "function_call_output", item["type"]) - require.Equal(t, "fc1", item["call_id"]) + require.Equal(t, "fc_1", item["call_id"]) require.Equal(t, "ok", item["output"]) _, hasRole := item["role"] require.False(t, hasRole) @@ -340,7 +340,7 @@ func TestApplyCodexOAuthTransform_AddsFallbackNameForFunctionCallInput(t *testin require.True(t, ok) require.Equal(t, "function_call", item["type"]) require.Equal(t, "tool", item["name"]) - require.Equal(t, "fc1", item["call_id"]) + require.Equal(t, "fc_1", item["call_id"]) } func TestApplyCodexOAuthTransform_PreservesFunctionCallInputName(t *testing.T) { @@ -359,7 +359,7 @@ func TestApplyCodexOAuthTransform_PreservesFunctionCallInputName(t *testing.T) { item, ok := input[0].(map[string]any) require.True(t, ok) require.Equal(t, "shell", item["name"]) - require.Equal(t, "fc1", item["call_id"]) + require.Equal(t, "fc_1", item["call_id"]) } func TestApplyCodexOAuthTransform_PreservesMCPToolCallIDAndName(t *testing.T) { @@ -384,7 +384,7 @@ func TestApplyCodexOAuthTransform_PreservesMCPToolCallIDAndName(t *testing.T) { require.True(t, ok) require.Equal(t, "mcp_tool_call", item["type"]) require.Equal(t, "remote_tool", item["name"]) - require.Equal(t, "fcabc", item["call_id"]) + require.Equal(t, "fc_abc", item["call_id"]) } func TestCodexInputItemRequiresNameTypesAllowCallID(t *testing.T) { @@ -839,6 +839,7 @@ func TestNormalizeCodexModel_Gpt53(t *testing.T) { "gpt-5.4": "gpt-5.4", "gpt5.5": "gpt-5.5", "openai/gpt5.5": "gpt-5.5", + "codex-auto-review": "codex-auto-review", "gpt5.4": "gpt-5.4", "gpt-5.4-high": "gpt-5.4", "gpt-5.4-chat-latest": "gpt-5.4", diff --git a/backend/internal/service/openai_compat_model_test.go b/backend/internal/service/openai_compat_model_test.go index 0ba2a63f725..f8b9d360330 100644 --- a/backend/internal/service/openai_compat_model_test.go +++ b/backend/internal/service/openai_compat_model_test.go @@ -183,6 +183,63 @@ func TestForwardAsAnthropic_NormalizesRoutingAndEffortForGpt54XHigh(t *testing.T t.Logf("response body: %s", rec.Body.String()) } +func TestForwardAsAnthropic_MappedClaudeModelAcceptsChatUsageShape(t *testing.T) { + t.Parallel() + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + body := []byte(`{"model":"claude-opus-4-7","max_tokens":16,"messages":[{"role":"user","content":"compact this"}],"stream":true}`) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_compact","model":"gpt-5.5","status":"in_progress","output":[]}}`, + "", + `data: {"type":"response.output_text.delta","delta":"ok"}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_compact","object":"response","model":"gpt-5.5","status":"completed","output":[{"type":"message","id":"msg_1","role":"assistant","status":"completed","content":[{"type":"output_text","text":"ok"}]}],"usage":{"prompt_tokens":31,"completion_tokens":9,"total_tokens":40,"prompt_tokens_details":{"cached_tokens":11}}}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_compact_usage"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + httpUpstream: upstream, + cfg: &config.Config{Security: config.SecurityConfig{URLAllowlist: config.URLAllowlistConfig{Enabled: false}}}, + } + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1", + "model_mapping": map[string]any{ + "gpt-5.5": "gpt-5.5", + }, + }, + } + + result, err := svc.ForwardAsAnthropic(context.Background(), c, account, body, "", "gpt-5.5") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "claude-opus-4-7", result.Model) + require.Equal(t, "gpt-5.5", result.BillingModel) + require.Equal(t, "gpt-5.5", result.UpstreamModel) + require.Equal(t, 31, result.Usage.InputTokens) + require.Equal(t, 9, result.Usage.OutputTokens) + require.Equal(t, 11, result.Usage.CacheReadInputTokens) + require.Equal(t, "gpt-5.5", gjson.GetBytes(upstream.lastBody, "model").String()) +} + func TestForwardAsAnthropic_InjectsPromptCacheKeyForAPIKeyMessagesDispatch(t *testing.T) { t.Parallel() gin.SetMode(gin.TestMode) diff --git a/backend/internal/service/openai_endpoint_url.go b/backend/internal/service/openai_endpoint_url.go new file mode 100644 index 00000000000..93ae9b952f2 --- /dev/null +++ b/backend/internal/service/openai_endpoint_url.go @@ -0,0 +1,78 @@ +package service + +import ( + "net/url" + "strings" +) + +func buildOpenAIEndpointURL(base string, endpoint string) string { + normalized := strings.TrimRight(strings.TrimSpace(base), "/") + endpoint = "/" + strings.TrimLeft(strings.TrimSpace(endpoint), "/") + relative := strings.TrimPrefix(endpoint, "/v1") + if strings.HasSuffix(normalized, endpoint) || strings.HasSuffix(normalized, relative) { + return normalized + } + if openAIBaseURLHasVersionSuffix(normalized) { + return normalized + relative + } + return normalized + endpoint +} + +func openAIBaseURLHasVersionSuffix(raw string) bool { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return false + } + + pathValue := "" + if parsed, err := url.Parse(trimmed); err == nil && parsed.Scheme != "" && parsed.Host != "" { + pathValue = parsed.Path + } else if slash := strings.Index(trimmed, "/"); slash >= 0 { + pathValue = trimmed[slash:] + } + + pathValue = strings.TrimRight(pathValue, "/") + if pathValue == "" { + return false + } + lastSlash := strings.LastIndex(pathValue, "/") + segment := pathValue + if lastSlash >= 0 { + segment = pathValue[lastSlash+1:] + } + return isOpenAIAPIVersionSegment(segment) +} + +func isOpenAIAPIVersionSegment(segment string) bool { + s := strings.ToLower(strings.TrimSpace(segment)) + if len(s) < 2 || s[0] != 'v' || !isASCIIDigit(s[1]) { + return false + } + + i := 1 + for i < len(s) && isASCIIDigit(s[i]) { + i++ + } + if i == len(s) { + return true + } + if s[i] == '.' { + i++ + if i == len(s) || !isASCIIDigit(s[i]) { + return false + } + for i < len(s) && isASCIIDigit(s[i]) { + i++ + } + return i == len(s) + } + + suffix := s[i:] + return strings.HasPrefix(suffix, "alpha") || + strings.HasPrefix(suffix, "beta") || + strings.HasPrefix(suffix, "preview") +} + +func isASCIIDigit(b byte) bool { + return b >= '0' && b <= '9' +} diff --git a/backend/internal/service/openai_fast_policy_test.go b/backend/internal/service/openai_fast_policy_test.go index b52da6149d4..70fcaffad77 100644 --- a/backend/internal/service/openai_fast_policy_test.go +++ b/backend/internal/service/openai_fast_policy_test.go @@ -8,6 +8,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" ) type openAIFastPolicyRepoStub struct { @@ -62,25 +63,33 @@ func newOpenAIGatewayServiceWithSettings(t *testing.T, settings *OpenAIFastPolic } } -func TestEvaluateOpenAIFastPolicy_DefaultFiltersAllModelsPriority(t *testing.T) { +func openAIFastFilterPriorityPolicy() *OpenAIFastPolicySettings { + return &OpenAIFastPolicySettings{ + Rules: []OpenAIFastPolicyRule{{ + ServiceTier: OpenAIFastTierPriority, + Action: BetaPolicyActionFilter, + Scope: BetaPolicyScopeAll, + ModelWhitelist: []string{}, + FallbackAction: BetaPolicyActionPass, + }}, + } +} + +func TestEvaluateOpenAIFastPolicy_DefaultPassesKnownTiers(t *testing.T) { + require.Empty(t, DefaultOpenAIFastPolicySettings().Rules, "default policy must not rewrite service_tier unless admin configured rules") + svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - // 默认策略对所有模型生效(whitelist 为空),因为 codex 的 service_tier=fast - // 是用户级开关,与 model 正交。 - // gpt-5.5 + priority → filter action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierPriority) - require.Equal(t, BetaPolicyActionFilter, action) + require.Equal(t, BetaPolicyActionPass, action) - // gpt-5.5-turbo → filter action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5-turbo", OpenAIFastTierPriority) - require.Equal(t, BetaPolicyActionFilter, action) + require.Equal(t, BetaPolicyActionPass, action) - // gpt-4 + priority → filter(默认策略覆盖所有模型) action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-4", OpenAIFastTierPriority) - require.Equal(t, BetaPolicyActionFilter, action) + require.Equal(t, BetaPolicyActionPass, action) - // gpt-5.5 + flex → pass (tier doesn't match) action, _ = svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", OpenAIFastTierFlex) require.Equal(t, BetaPolicyActionPass, action) @@ -129,27 +138,24 @@ func TestEvaluateOpenAIFastPolicy_ScopeFiltersOAuth(t *testing.T) { require.Equal(t, BetaPolicyActionPass, action) } -func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) { +func TestApplyOpenAIFastPolicyToBody_DefaultPassesPriorityAndFast(t *testing.T) { svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - // gpt-5.5 fast → service_tier stripped body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`) updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) require.NoError(t, err) - require.NotContains(t, string(updated), `"service_tier"`) + require.Equal(t, string(body), string(updated)) - // Client sending "fast" (alias for priority) also filtered body = []byte(`{"model":"gpt-5.5","service_tier":"fast"}`) updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) require.NoError(t, err) - require.NotContains(t, string(updated), `"service_tier"`) + require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String()) - // gpt-4 priority → 默认策略对所有模型 filter,service_tier 被移除 body = []byte(`{"model":"gpt-4","service_tier":"priority"}`) updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-4", body) require.NoError(t, err) - require.NotContains(t, string(updated), `"service_tier"`) + require.Equal(t, string(body), string(updated)) // No service_tier → no-op body = []byte(`{"model":"gpt-5.5"}`) @@ -158,9 +164,23 @@ func TestApplyOpenAIFastPolicyToBody_FilterRemovesField(t *testing.T) { require.Equal(t, string(body), string(updated)) } -// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule 验证扩展白名单后 -// 客户端显式发送的 OpenAI 官方合法 tier(auto/default/scale)能透传到上游而不被 -// 静默剥离。默认策略只针对 priority,所以这些 tier 落在 fall-through pass 分支。 +func TestApplyOpenAIFastPolicyToBody_ExplicitFilterRemovesField(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, openAIFastFilterPriorityPolicy()) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + body := []byte(`{"model":"gpt-5.5","service_tier":"priority","messages":[]}`) + updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.NoError(t, err) + require.NotContains(t, string(updated), `"service_tier"`) + + body = []byte(`{"model":"gpt-5.5","service_tier":"fast"}`) + updated, err = svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) + require.NoError(t, err) + require.NotContains(t, string(updated), `"service_tier"`) +} + +// TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule 验证默认配置 +// 下客户端显式发送的 OpenAI 官方合法 tier 能透传到上游而不被静默剥离。 func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T) { svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} @@ -170,10 +190,10 @@ func TestApplyOpenAIFastPolicyToBody_OfficialTiersBypassDefaultRule(t *testing.T updated, err := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", body) require.NoError(t, err, "tier %q should pass without error", tier) require.Contains(t, string(updated), `"service_tier":"`+tier+`"`, - "tier %q should be preserved in body under default rule", tier) + "tier %q should be preserved in body under default policy", tier) } - // evaluate 层也应判定为 pass(默认规则 ServiceTier=priority 与 auto/default/scale 不匹配) + // evaluate 层也应判定为 pass(默认配置没有内置规则)。 for _, tier := range []string{"auto", "default", "scale"} { action, _ := svc.evaluateOpenAIFastPolicy(context.Background(), account, "gpt-5.5", tier) require.Equal(t, BetaPolicyActionPass, action, "tier %q should evaluate to pass", tier) diff --git a/backend/internal/service/openai_fast_policy_ws_test.go b/backend/internal/service/openai_fast_policy_ws_test.go index 7c8341b2a5f..4624e7a5c17 100644 --- a/backend/internal/service/openai_fast_policy_ws_test.go +++ b/backend/internal/service/openai_fast_policy_ws_test.go @@ -22,7 +22,7 @@ import ( // --- Helper-level (unit) tests for applyOpenAIFastPolicyToWSResponseCreate --- -func TestWSResponseCreate_FilterStripsServiceTier(t *testing.T) { +func TestWSResponseCreate_DefaultPassesPriorityAndNormalizesFast(t *testing.T) { svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} @@ -30,26 +30,37 @@ func TestWSResponseCreate_FilterStripsServiceTier(t *testing.T) { updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) require.NoError(t, err) require.Nil(t, blocked) - require.NotContains(t, string(updated), `"service_tier"`, "filter action should strip service_tier") + require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String(), "default policy should preserve priority tier") // Other fields preserved. require.Equal(t, "response.create", gjson.GetBytes(updated, "type").String()) require.Equal(t, "gpt-5.5", gjson.GetBytes(updated, "model").String()) require.Equal(t, "hi", gjson.GetBytes(updated, "input.0.text").String()) + + frame = []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"fast"}`) + updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String(), "fast alias should normalize before reaching upstream") + + // Mixed-case + whitespace variant should also normalize. + frame = []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":" Fast "}`) + updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) + require.NoError(t, err) + require.Nil(t, blocked) + require.Equal(t, "priority", gjson.GetBytes(updated, "service_tier").String()) } -func TestWSResponseCreate_FastNormalizedToPriorityThenFiltered(t *testing.T) { - svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) +func TestWSResponseCreate_ExplicitFilterStripsServiceTier(t *testing.T) { + svc := newOpenAIGatewayServiceWithSettings(t, openAIFastFilterPriorityPolicy()) account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - // Verbatim "fast" → normalized to "priority" → matches default rule → filter. - frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"fast"}`) + frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority","input":[{"type":"input_text","text":"hi"}]}`) updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) require.NoError(t, err) require.Nil(t, blocked) - require.NotContains(t, string(updated), `"service_tier"`) + require.NotContains(t, string(updated), `"service_tier"`, "filter action should strip service_tier") - // Mixed-case + whitespace variant should also normalize and filter. - frame = []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":" Fast "}`) + frame = []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"fast"}`) updated, blocked, err = svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) require.NoError(t, err) require.Nil(t, blocked) @@ -60,7 +71,7 @@ func TestWSResponseCreate_FlexPassThrough(t *testing.T) { svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - // Default policy targets priority only; flex is left untouched. + // Default policy has no rules; flex is left untouched. frame := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"flex"}`) updated, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", frame) require.NoError(t, err) @@ -220,8 +231,8 @@ func (f *fakePassthroughFrameConn) Close() error { } // gpt55WhitelistFastPolicy 返回一份强制带 model whitelist 的策略,用于 -// 验证 capturedSessionModel fallback 的语义(默认策略 whitelist 为空时 -// fallback 路径无法被观察到)。 +// 验证 capturedSessionModel fallback 的语义(默认配置没有规则,fallback +// 路径无法被观察到)。 func gpt55WhitelistFastPolicy() *OpenAIFastPolicySettings { return &OpenAIFastPolicySettings{ Rules: []OpenAIFastPolicyRule{{ @@ -242,7 +253,7 @@ func gpt55WhitelistFastPolicy() *OpenAIFastPolicySettings { // through to the upstream. func TestPolicyEnforcingFrameConn_FollowupFrameWithoutModelUsesCapturedModel(t *testing.T) { // 此处特意使用带 whitelist 的策略,以便观察 capturedSessionModel - // fallback 是否生效(默认策略 whitelist 为空,fallback 与否结果一致, + // fallback 是否生效(默认配置没有规则,fallback 与否结果一致, // 不能用来覆盖此回归)。 svc := newOpenAIGatewayServiceWithSettings(t, gpt55WhitelistFastPolicy()) account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} @@ -310,13 +321,13 @@ func TestPolicyEnforcingFrameConn_WithoutCapturedFallbackPolicyMisses(t *testing "sanity: without capturedSessionModel fallback the leak (D5) reproduces — confirms the fix is load-bearing") } -// --- Ingress end-to-end test (filter path) --- +// --- Ingress end-to-end test (explicit filter path) --- // TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream wires up the // real ProxyResponsesWebSocketFromClient ingress session pipeline against a // captureConn upstream and asserts that a client frame with service_tier=fast -// is normalized + filtered out before being written upstream. This is the -// integration flavour of TestWSResponseCreate_FilterStripsServiceTier. +// is normalized + filtered out by an explicit admin policy before being +// written upstream. func TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream(t *testing.T) { gin.SetMode(gin.TestMode) @@ -345,9 +356,9 @@ func TestWSResponseCreate_IngressFiltersServiceTierBeforeUpstream(t *testing.T) pool.setClientDialerForTest(captureDialer) repo := &openAIFastPolicyRepoStub{values: map[string]string{}} - defaultJSON, err := json.Marshal(DefaultOpenAIFastPolicySettings()) + filterPolicyJSON, err := json.Marshal(openAIFastFilterPriorityPolicy()) require.NoError(t, err) - repo.values[SettingKeyOpenAIFastPolicySettings] = string(defaultJSON) + repo.values[SettingKeyOpenAIFastPolicySettings] = string(filterPolicyJSON) svc := &OpenAIGatewayService{ cfg: cfg, @@ -631,13 +642,13 @@ func TestApplyOpenAIFastPolicyToBody_BlockShortCircuitsUpstream(t *testing.T) { require.Equal(t, string(body), string(updated), "block must not mutate body") } -// TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy verifies -// the Anthropic-compat entrypoint chain: anthropic-beta: fast-mode → BetaFastMode -// detection → ServiceTier="priority" injection (openai_gateway_messages.go:60) -// → applyOpenAIFastPolicyToBody filter on default policy → upstream body has -// no service_tier. We exercise the same internal pipeline (Anthropic→Responses -// + BetaFastMode + policy) without spinning up a real upstream HTTP server. -func TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy(t *testing.T) { +// TestForwardAsAnthropicMessages_BetaFastModePassesOpenAIFastPolicyByDefault +// verifies the Anthropic-compat entrypoint chain: anthropic-beta: fast-mode → +// BetaFastMode detection → ServiceTier="priority" injection +// (openai_gateway_messages.go:60) → default OpenAI fast policy pass. We +// exercise the same internal pipeline (Anthropic→Responses + BetaFastMode + +// policy) without spinning up a real upstream HTTP server. +func TestForwardAsAnthropicMessages_BetaFastModePassesOpenAIFastPolicyByDefault(t *testing.T) { svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} @@ -663,8 +674,9 @@ func TestForwardAsAnthropicMessages_BetaFastModeTriggersOpenAIFastPolicy(t *test upstreamBody, policyErr := svc.applyOpenAIFastPolicyToBody(context.Background(), account, "gpt-5.5", responsesBody) require.NoError(t, policyErr) - // Step 4: assert that policy filtered the field before the upstream HTTP request. - require.NotContains(t, string(upstreamBody), `"service_tier"`, "default policy 命中 gpt-5.5 priority 应当 filter 掉 service_tier") + // Step 4: default policy must preserve the explicit fast/priority request. + require.Equal(t, "priority", gjson.GetBytes(upstreamBody, "service_tier").String(), + "default policy should pass service_tier=priority through to upstream") } // --- Fix1: passthrough capturedSessionModel must follow session.update --- @@ -808,7 +820,7 @@ func TestApplyOpenAIFastPolicyToBody_PassNormalizesFastAlias(t *testing.T) { // tier) instead of the user-requested "priority". This test pins the // contract those two helpers must uphold for the adapter's billing path. func TestPassthroughBilling_PostFilterServiceTier(t *testing.T) { - svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + svc := newOpenAIGatewayServiceWithSettings(t, openAIFastFilterPriorityPolicy()) account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} raw := []byte(`{"type":"response.create","model":"gpt-5.5","service_tier":"priority"}`) @@ -821,7 +833,7 @@ func TestPassthroughBilling_PostFilterServiceTier(t *testing.T) { require.Equal(t, "priority", *pre, "sanity: raw first frame carries priority that pre-fix billing would have reported") - // Apply policy filter (default rule: gpt-5.5 + priority → filter). + // Apply explicit policy filter (gpt-5.5 + priority → filter). filtered, blocked, err := svc.applyOpenAIFastPolicyToWSResponseCreate(context.Background(), account, "gpt-5.5", raw) require.NoError(t, err) require.Nil(t, blocked) @@ -890,9 +902,9 @@ func TestApplyOpenAIFastPolicyToBody_NonStringServiceTier(t *testing.T) { // atomic.Pointer[string] on every successful response.create frame. // // This test pins the four legs of the semantic contract: -// - turn 1: service_tier=priority hits the default whitelist filter, so +// - turn 1: service_tier=priority hits the explicit filter rule, so // after filter the upstream sees no tier → billing is nil. -// - turn 2: service_tier=flex passes (default rule targets priority only), +// - turn 2: service_tier=flex passes (the filter rule targets priority only), // billing should now reflect "flex". // - turn 3: response.create without any service_tier — the upstream will // treat it as default; we choose to mirror that and overwrite billing @@ -900,7 +912,7 @@ func TestApplyOpenAIFastPolicyToBody_NonStringServiceTier(t *testing.T) { // - non-response.create frame (response.cancel here) carrying a stray // service_tier-shaped field must NOT clobber the billing pointer. func TestPassthroughBilling_MultiTurnServiceTierFollowsFilteredFrames(t *testing.T) { - svc := newOpenAIGatewayServiceWithSettings(t, DefaultOpenAIFastPolicySettings()) + svc := newOpenAIGatewayServiceWithSettings(t, openAIFastFilterPriorityPolicy()) account := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey} // Mirror the production filter closure (openai_ws_v2_passthrough_adapter.go diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 5b3c0e6ff49..f8b23a287ae 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -48,10 +48,10 @@ var cursorResponsesUnsupportedFields = []string{ // 正确的,但 sub2api 接入 DeepSeek/Kimi/GLM 等第三方 OpenAI 兼容上游后假设破裂: // 这些上游普遍只支持 /v1/chat/completions,无 /v1/responses 端点。 // -// 当前路由策略(基于账号探测标记,详见 openai_compat.ShouldUseResponsesAPI): -// - APIKey 账号 + 探测确认不支持 Responses → 走 forwardAsRawChatCompletions +// 当前路由策略(基于账号覆盖模式/探测标记,详见 openai_compat.ShouldUseResponsesAPI): +// - APIKey 账号 + 强制或探测确认不支持 Responses → 走 forwardAsRawChatCompletions // 直转上游 /v1/chat/completions,不做协议转换 -// - 其他所有情况(OAuth、APIKey 探测确认支持、未探测)→ 走原有 CC→Responses +// - 其他所有情况(OAuth、APIKey 强制/探测确认支持、未探测)→ 走原有 CC→Responses // 转换路径(保留旧行为,存量未探测账号零兼容破坏) func (s *OpenAIGatewayService) ForwardAsChatCompletions( ctx context.Context, @@ -61,8 +61,8 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( promptCacheKey string, defaultMappedModel string, ) (*OpenAIForwardResult, error) { - // 入口分流:APIKey 账号 + 已探测且确认上游不支持 Responses,走 CC 直转。 - // 标记缺失(未探测)按"现状即证据"原则继续走下方原 Responses 转换路径。 + // 入口分流:APIKey 账号 + 强制或已探测确认上游不支持 Responses,走 CC 直转。 + // 自动模式下标记缺失(未探测)按"现状即证据"原则继续走下方原 Responses 转换路径。 if account.Type == AccountTypeAPIKey && !openai_compat.ShouldUseResponsesAPI(account.Extra) { return s.forwardAsRawChatCompletions(ctx, c, account, body, defaultMappedModel) } @@ -247,6 +247,16 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if account.Type == AccountTypeAPIKey && + openai_compat.ResolveResponsesSupport(account.Extra) == openai_compat.ResponsesSupportUnknown && + !isResponsesEndpointSupportedByStatus(resp.StatusCode) { + logger.L().Info("openai chat_completions: /responses unsupported, falling back to raw chat completions", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", resp.StatusCode), + zap.String("upstream_message", upstreamMsg), + ) + return s.forwardAsRawChatCompletions(ctx, c, account, body, defaultMappedModel) + } if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { upstreamDetail := "" if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { @@ -282,7 +292,7 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( var result *OpenAIForwardResult var handleErr error if clientStream { - result, handleErr = s.handleChatStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, includeUsage, startTime) + result, handleErr = s.handleChatStreamingResponse(resp, c, account, originalModel, billingModel, upstreamModel, includeUsage, startTime, len(body)) } else { result, handleErr = s.handleChatBufferedStreamingResponse(resp, c, originalModel, billingModel, upstreamModel, startTime) } @@ -404,22 +414,31 @@ func (s *OpenAIGatewayService) handleChatBufferedStreamingResponse( func (s *OpenAIGatewayService) handleChatStreamingResponse( resp *http.Response, c *gin.Context, + account *Account, originalModel string, billingModel string, upstreamModel string, includeUsage bool, startTime time.Time, + requestBodyLen int, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") - if s.responseHeaderFilter != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + headersWritten := false + writeStreamHeaders := func() { + if headersWritten { + return + } + headersWritten = true + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) } - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("X-Accel-Buffering", "no") - c.Writer.WriteHeader(http.StatusOK) state := apicompat.NewResponsesEventToChatState() state.Model = originalModel @@ -429,6 +448,9 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( var firstTokenMs *int firstChunk := true clientDisconnected := false + clientOutputStarted := false + pendingSSE := make([]string, 0, 4) + refusalDetector := newOpenAIChatSilentRefusalDetector(requestBodyLen) scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -479,6 +501,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( ) return false } + refusalDetector.ObservePayload([]byte(payload)) // 仅按兼容转换器支持的终止事件提取 usage,避免无意扩大事件语义。 isTerminalEvent := isOpenAICompatResponsesTerminalEvent(event.Type) @@ -489,6 +512,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( chunks := apicompat.ResponsesEventToChatChunks(&event, state) if !clientDisconnected { for _, chunk := range chunks { + refusalDetector.ObserveChatChunk(chunk) sse, err := apicompat.ChatChunkToSSE(chunk) if err != nil { logger.L().Warn("openai chat_completions stream: failed to marshal chunk", @@ -497,6 +521,27 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( ) continue } + if !clientOutputStarted && !refusalDetector.ShouldReleaseClientOutput() { + pendingSSE = append(pendingSSE, sse) + continue + } + if !clientOutputStarted { + writeStreamHeaders() + for _, pending := range pendingSSE { + if _, err := fmt.Fprint(c.Writer, pending); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected while flushing pending chunks", + zap.String("request_id", requestID), + ) + break + } + } + pendingSSE = pendingSSE[:0] + clientOutputStarted = !clientDisconnected + if clientDisconnected { + break + } + } if _, err := fmt.Fprint(c.Writer, sse); err != nil { clientDisconnected = true logger.L().Info("openai chat_completions stream: client disconnected, continuing to drain upstream for billing", @@ -506,7 +551,7 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( } } } - if len(chunks) > 0 && !clientDisconnected { + if len(chunks) > 0 && !clientDisconnected && clientOutputStarted { c.Writer.Flush() } return isTerminalEvent @@ -515,10 +560,32 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( finalizeStream := func() (*OpenAIForwardResult, error) { if finalChunks := apicompat.FinalizeResponsesChatStream(state); len(finalChunks) > 0 && !clientDisconnected { for _, chunk := range finalChunks { + refusalDetector.ObserveChatChunk(chunk) sse, err := apicompat.ChatChunkToSSE(chunk) if err != nil { continue } + if !clientOutputStarted && !refusalDetector.ShouldReleaseClientOutput() { + pendingSSE = append(pendingSSE, sse) + continue + } + if !clientOutputStarted { + writeStreamHeaders() + for _, pending := range pendingSSE { + if _, err := fmt.Fprint(c.Writer, pending); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected during pending final flush", + zap.String("request_id", requestID), + ) + break + } + } + pendingSSE = pendingSSE[:0] + clientOutputStarted = !clientDisconnected + if clientDisconnected { + break + } + } if _, err := fmt.Fprint(c.Writer, sse); err != nil { clientDisconnected = true logger.L().Info("openai chat_completions stream: client disconnected during final flush", @@ -528,14 +595,35 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( } } } + if !clientDisconnected && !clientOutputStarted { + if refusalDetector.IsSilentRefusal() { + return nil, newOpenAISilentRefusalFailoverError(c, account, requestID) + } + if len(pendingSSE) > 0 { + writeStreamHeaders() + for _, pending := range pendingSSE { + if _, err := fmt.Fprint(c.Writer, pending); err != nil { + clientDisconnected = true + logger.L().Info("openai chat_completions stream: client disconnected during final pending flush", + zap.String("request_id", requestID), + ) + break + } + } + pendingSSE = pendingSSE[:0] + clientOutputStarted = !clientDisconnected + } + } // Send [DONE] sentinel if !clientDisconnected { + writeStreamHeaders() if _, err := fmt.Fprint(c.Writer, "data: [DONE]\n\n"); err != nil { clientDisconnected = true logger.L().Info("openai chat_completions stream: client disconnected during done flush", zap.String("request_id", requestID), ) } + clientOutputStarted = !clientDisconnected } if !clientDisconnected { c.Writer.Flush() @@ -692,10 +780,14 @@ func (s *OpenAIGatewayService) handleChatStreamingResponse( if clientDisconnected { continue } + if refusalDetector.Enabled() && !clientOutputStarted { + continue + } if time.Since(lastDataAt) < keepaliveInterval { continue } // Send SSE comment as keepalive + writeStreamHeaders() if _, err := fmt.Fprint(c.Writer, ":\n\n"); err != nil { logger.L().Info("openai chat_completions stream: client disconnected during keepalive", zap.String("request_id", requestID), diff --git a/backend/internal/service/openai_gateway_chat_completions_raw.go b/backend/internal/service/openai_gateway_chat_completions_raw.go index 3be765a2827..c585290e8d1 100644 --- a/backend/internal/service/openai_gateway_chat_completions_raw.go +++ b/backend/internal/service/openai_gateway_chat_completions_raw.go @@ -220,7 +220,7 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions( // 8. Forward response if clientStream { - return s.streamRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime) + return s.streamRawChatCompletions(c, resp, account, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime, len(body)) } return s.bufferRawChatCompletions(c, resp, originalModel, billingModel, upstreamModel, reasoningEffort, serviceTier, startTime) } @@ -234,23 +234,32 @@ func (s *OpenAIGatewayService) forwardAsRawChatCompletions( func (s *OpenAIGatewayService) streamRawChatCompletions( c *gin.Context, resp *http.Response, + account *Account, originalModel string, billingModel string, upstreamModel string, reasoningEffort *string, serviceTier *string, startTime time.Time, + requestBodyLen int, ) (*OpenAIForwardResult, error) { requestID := resp.Header.Get("x-request-id") - if s.responseHeaderFilter != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + headersWritten := false + writeStreamHeaders := func() { + if headersWritten { + return + } + headersWritten = true + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + } + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.WriteHeader(http.StatusOK) } - c.Writer.Header().Set("Content-Type", "text/event-stream") - c.Writer.Header().Set("Cache-Control", "no-cache") - c.Writer.Header().Set("Connection", "keep-alive") - c.Writer.Header().Set("X-Accel-Buffering", "no") - c.Writer.WriteHeader(http.StatusOK) scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -262,9 +271,45 @@ func (s *OpenAIGatewayService) streamRawChatCompletions( var usage OpenAIUsage var firstTokenMs *int clientDisconnected := false + clientOutputStarted := false + pendingLines := make([]string, 0, 8) + refusalDetector := newOpenAIChatSilentRefusalDetector(requestBodyLen) + + writeLine := func(line string) { + if clientDisconnected { + return + } + if !clientOutputStarted && !refusalDetector.ShouldReleaseClientOutput() { + pendingLines = append(pendingLines, line) + return + } + if !clientOutputStarted { + writeStreamHeaders() + for _, pending := range pendingLines { + if _, werr := c.Writer.WriteString(pending + "\n"); werr != nil { + clientDisconnected = true + logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing", + zap.Error(werr), + zap.String("request_id", requestID), + ) + return + } + } + pendingLines = pendingLines[:0] + clientOutputStarted = true + } + if _, werr := c.Writer.WriteString(line + "\n"); werr != nil { + clientDisconnected = true + logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing", + zap.Error(werr), + zap.String("request_id", requestID), + ) + } + } for scanner.Scan() { line := scanner.Text() + refusalDetector.ObserveSSELine(line) if payload, ok := extractOpenAISSEDataLine(line); ok { trimmedPayload := strings.TrimSpace(payload) if trimmedPayload != "[DONE]" { @@ -279,22 +324,14 @@ func (s *OpenAIGatewayService) streamRawChatCompletions( } } - if !clientDisconnected { - if _, werr := c.Writer.WriteString(line + "\n"); werr != nil { - clientDisconnected = true - logger.L().Debug("openai chat_completions raw: client disconnected, continuing to drain upstream for billing", - zap.Error(werr), - zap.String("request_id", requestID), - ) - } - } + writeLine(line) if line == "" { - if !clientDisconnected { + if !clientDisconnected && clientOutputStarted { c.Writer.Flush() } continue } - if !clientDisconnected { + if !clientDisconnected && clientOutputStarted { c.Writer.Flush() } } @@ -306,6 +343,27 @@ func (s *OpenAIGatewayService) streamRawChatCompletions( zap.String("request_id", requestID), ) } + } else if !clientDisconnected && !clientOutputStarted { + if refusalDetector.IsSilentRefusal() { + return nil, newOpenAISilentRefusalFailoverError(c, account, requestID) + } + if len(pendingLines) > 0 { + writeStreamHeaders() + for _, pending := range pendingLines { + if _, werr := c.Writer.WriteString(pending + "\n"); werr != nil { + clientDisconnected = true + logger.L().Debug("openai chat_completions raw: client disconnected during final flush", + zap.Error(werr), + zap.String("request_id", requestID), + ) + break + } + } + if !clientDisconnected { + c.Writer.Flush() + clientOutputStarted = true + } + } } return &OpenAIForwardResult{ @@ -422,16 +480,10 @@ func (s *OpenAIGatewayService) bufferRawChatCompletions( // // - base 已是 /chat/completions:原样返回 // - base 以 /v1 结尾:追加 /chat/completions +// - base 以其他版本段结尾(如 /v4):追加 /chat/completions // - 其他情况:追加 /v1/chat/completions // // 与 buildOpenAIResponsesURL 是姐妹函数。 func buildOpenAIChatCompletionsURL(base string) string { - normalized := strings.TrimRight(strings.TrimSpace(base), "/") - if strings.HasSuffix(normalized, "/chat/completions") { - return normalized - } - if strings.HasSuffix(normalized, "/v1") { - return normalized + "/chat/completions" - } - return normalized + "/v1/chat/completions" + return buildOpenAIEndpointURL(base, "/v1/chat/completions") } diff --git a/backend/internal/service/openai_gateway_chat_completions_raw_test.go b/backend/internal/service/openai_gateway_chat_completions_raw_test.go index 1be07fd7170..64449636965 100644 --- a/backend/internal/service/openai_gateway_chat_completions_raw_test.go +++ b/backend/internal/service/openai_gateway_chat_completions_raw_test.go @@ -5,6 +5,7 @@ package service import ( "bytes" "context" + "errors" "io" "net/http" "net/http/httptest" @@ -36,6 +37,7 @@ func TestBuildOpenAIChatCompletionsURL(t *testing.T) { // 第三方上游常见形式 {"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/chat/completions"}, {"third-party with path prefix", "https://api.gptgod.online/api", "https://api.gptgod.online/api/v1/chat/completions"}, + {"third-party versioned path", "https://open.bigmodel.cn/api/paas/v4", "https://open.bigmodel.cn/api/paas/v4/chat/completions"}, // 带空白字符 {"whitespace trimmed", " https://api.openai.com/v1 ", "https://api.openai.com/v1/chat/completions"}, } @@ -64,6 +66,7 @@ func TestBuildOpenAIResponsesURL_ProbeURL(t *testing.T) { {"bare /v1", "https://api.openai.com/v1", "https://api.openai.com/v1/responses"}, {"already /responses", "https://api.openai.com/v1/responses", "https://api.openai.com/v1/responses"}, {"third-party bare domain", "https://api.deepseek.com", "https://api.deepseek.com/v1/responses"}, + {"third-party versioned path", "https://open.bigmodel.cn/api/paas/v4", "https://open.bigmodel.cn/api/paas/v4/responses"}, {"only domain, no scheme", "api.gptgod.online", "api.gptgod.online/v1/responses"}, } @@ -118,6 +121,259 @@ func TestForwardAsRawChatCompletions_ForcesStreamUsageUpstreamAndPassesUsageDown require.Contains(t, rec.Body.String(), "data: [DONE]") } +func TestForwardAsRawChatCompletions_PreservesDeepSeekReasoningContentNonStreaming(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := []byte(`{"model":"deepseek-reasoner","messages":[{"role":"user","content":"hello"}],"stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamJSON := `{"id":"chatcmpl_reasoning","object":"chat.completion","model":"deepseek-reasoner","choices":[{"index":0,"message":{"role":"assistant","reasoning_content":"think first","content":"final answer"},"finish_reason":"stop"}],"usage":{"prompt_tokens":3,"completion_tokens":5,"total_tokens":8}}` + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_deepseek_reasoning_json"}}, + Body: io.NopCloser(strings.NewReader(upstreamJSON)), + }} + + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + account := rawChatCompletionsTestAccount() + + result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 3, result.Usage.InputTokens) + require.Equal(t, 5, result.Usage.OutputTokens) + require.Equal(t, "think first", gjson.Get(rec.Body.String(), "choices.0.message.reasoning_content").String()) + require.Equal(t, "final answer", gjson.Get(rec.Body.String(), "choices.0.message.content").String()) +} + +func TestForwardAsRawChatCompletions_PreservesDeepSeekReasoningContentStreaming(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := []byte(`{"model":"deepseek-reasoner","messages":[{"role":"user","content":"hello"}],"stream":true}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"id":"chatcmpl_reasoning","object":"chat.completion.chunk","model":"deepseek-reasoner","choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":null}]}`, + "", + `data: {"id":"chatcmpl_reasoning","object":"chat.completion.chunk","model":"deepseek-reasoner","choices":[{"index":0,"delta":{"reasoning_content":"think first"},"finish_reason":null}]}`, + "", + `data: {"id":"chatcmpl_reasoning","object":"chat.completion.chunk","model":"deepseek-reasoner","choices":[{"index":0,"delta":{"content":"final answer"},"finish_reason":null}]}`, + "", + `data: {"id":"chatcmpl_reasoning","object":"chat.completion.chunk","model":"deepseek-reasoner","choices":[],"usage":{"prompt_tokens":3,"completion_tokens":5,"total_tokens":8}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_deepseek_reasoning_stream"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + account := rawChatCompletionsTestAccount() + + result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 3, result.Usage.InputTokens) + require.Equal(t, 5, result.Usage.OutputTokens) + require.Contains(t, rec.Body.String(), `"reasoning_content":"think first"`) + require.Contains(t, rec.Body.String(), `"content":"final answer"`) + require.Contains(t, rec.Body.String(), "data: [DONE]") +} + +func TestForwardAsRawChatCompletions_PreservesDeepSeekReasoningContentInRequest(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := []byte(`{"model":"deepseek-v4-pro","messages":[{"role":"user","content":"weather"},{"role":"assistant","reasoning_content":"need tool","content":"","tool_calls":[{"id":"call_1","type":"function","function":{"name":"get_weather","arguments":"{}"}}]},{"role":"tool","tool_call_id":"call_1","content":"cloudy"}],"stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_deepseek_reasoning_request"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"chatcmpl_request","object":"chat.completion","model":"deepseek-v4-pro","choices":[{"index":0,"message":{"role":"assistant","content":"done"},"finish_reason":"stop"}],"usage":{"prompt_tokens":4,"completion_tokens":2,"total_tokens":6}}`)), + }} + + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + account := rawChatCompletionsTestAccount() + + result, err := svc.forwardAsRawChatCompletions(context.Background(), c, account, body, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "need tool", gjson.GetBytes(upstream.lastBody, "messages.1.reasoning_content").String()) + require.Equal(t, "get_weather", gjson.GetBytes(upstream.lastBody, "messages.1.tool_calls.0.function.name").String()) +} + +func TestForwardAsRawChatCompletions_SilentRefusalTriggersFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := largeRawChatCompletionsBody() + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"id":"chatcmpl_silent","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"role":"assistant"}}]}`, + "", + `data: {"id":"chatcmpl_silent","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}]}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_silent"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + + result, err := svc.forwardAsRawChatCompletions(context.Background(), c, rawChatCompletionsTestAccount(), body, "") + require.Nil(t, result) + var failoverErr *UpstreamFailoverError + require.True(t, errors.As(err, &failoverErr)) + require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode) + require.True(t, IsOpenAISilentRefusalErrorBody(failoverErr.ResponseBody)) + require.False(t, c.Writer.Written(), "silent refusal must not commit a 200 response before failover") + require.Empty(t, rec.Body.String()) +} + +func TestForwardAsRawChatCompletions_SilentRefusalToolCallsExempt(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := largeRawChatCompletionsBody() + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"id":"chatcmpl_tool","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"role":"assistant"}}]}`, + "", + `data: {"id":"chatcmpl_tool","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_1","type":"function","function":{"name":"lookup","arguments":""}}]}}]}`, + "", + `data: {"id":"chatcmpl_tool","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}]}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_tool"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + + result, err := svc.forwardAsRawChatCompletions(context.Background(), c, rawChatCompletionsTestAccount(), body, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Contains(t, rec.Body.String(), `"tool_calls"`) + require.Contains(t, rec.Body.String(), `"finish_reason":"tool_calls"`) +} + +func TestHandleChatStreamingResponse_SilentRefusalReasoningSummaryExempt(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) + + upstreamBody := strings.Join([]string{ + `data: {"type":"response.created","response":{"id":"resp_reasoning","model":"gpt-5.5"}}`, + "", + `data: {"type":"response.reasoning_summary_text.delta","delta":"thinking only"}`, + "", + `data: {"type":"response.completed","response":{"id":"resp_reasoning","model":"gpt-5.5","status":"completed"}}`, + "", + }, "\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_reasoning"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + } + svc := &OpenAIGatewayService{cfg: rawChatCompletionsTestConfig()} + + result, err := svc.handleChatStreamingResponse( + resp, + c, + rawChatCompletionsTestAccount(), + "gpt-5.5", + "gpt-5.5", + "gpt-5.5", + false, + time.Now(), + openAISilentRefusalMinRequestBodyBytes, + ) + require.NoError(t, err) + require.NotNil(t, result) + require.Contains(t, rec.Body.String(), `"reasoning_content":"thinking only"`) + require.Contains(t, rec.Body.String(), "data: [DONE]") +} + +func TestForwardAsRawChatCompletions_SilentRefusalNormalContentExempt(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := largeRawChatCompletionsBody() + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstreamBody := strings.Join([]string{ + `data: {"id":"chatcmpl_ok","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"role":"assistant"}}]}`, + "", + `data: {"id":"chatcmpl_ok","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"content":"ok"}}]}`, + "", + `data: {"id":"chatcmpl_ok","object":"chat.completion.chunk","model":"gpt-5.5","choices":[{"index":0,"delta":{"content":""},"finish_reason":"stop"}]}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid_ok"}}, + Body: io.NopCloser(strings.NewReader(upstreamBody)), + }} + + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + + result, err := svc.forwardAsRawChatCompletions(context.Background(), c, rawChatCompletionsTestAccount(), body, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Contains(t, rec.Body.String(), `"content":"ok"`) + require.Contains(t, rec.Body.String(), "data: [DONE]") +} + func TestForwardAsRawChatCompletions_ClientDisconnectDrainsUsage(t *testing.T) { gin.SetMode(gin.TestMode) @@ -193,6 +449,49 @@ func TestForwardAsRawChatCompletions_UpstreamRequestIgnoresClientCancel(t *testi require.NoError(t, upstream.lastReq.Context().Err()) } +func TestForwardAsChatCompletions_UnknownResponsesSupportFallbackUsesVersionedChatURL(t *testing.T) { + gin.SetMode(gin.TestMode) + + body := []byte(`{"model":"glm-4.5-air","messages":[{"role":"user","content":"hello"}],"stream":false}`) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{responses: []*http.Response{ + { + StatusCode: http.StatusNotFound, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"error":{"message":"not found"}}`)), + }, + { + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid_raw_fallback"}}, + Body: io.NopCloser(strings.NewReader( + `{"id":"chatcmpl_1","object":"chat.completion","model":"glm-4.5-air","choices":[{"index":0,"message":{"role":"assistant","content":"ok"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3}}`, + )), + }, + }} + + svc := &OpenAIGatewayService{ + cfg: rawChatCompletionsTestConfig(), + httpUpstream: upstream, + } + account := rawChatCompletionsTestAccount() + account.Credentials["base_url"] = "https://open.bigmodel.cn/api/paas/v4" + + result, err := svc.ForwardAsChatCompletions(context.Background(), c, account, body, "", "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 1, result.Usage.InputTokens) + require.Equal(t, 2, result.Usage.OutputTokens) + require.Len(t, upstream.requests, 2) + require.Equal(t, "https://open.bigmodel.cn/api/paas/v4/responses", upstream.requests[0].URL.String()) + require.Equal(t, "https://open.bigmodel.cn/api/paas/v4/chat/completions", upstream.requests[1].URL.String()) + require.Equal(t, http.StatusOK, rec.Code) + require.Contains(t, rec.Body.String(), `"content":"ok"`) +} + func TestIsOpenAIChatUsageOnlyStreamChunk(t *testing.T) { t.Parallel() @@ -258,3 +557,9 @@ func rawChatCompletionsTestAccount() *Account { }, } } + +func largeRawChatCompletionsBody() []byte { + return []byte(`{"model":"gpt-5.5","messages":[{"role":"user","content":"` + + strings.Repeat("x", openAISilentRefusalMinRequestBodyBytes) + + `"}],"stream":true}`) +} diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index cf909ec98e9..096f5b1079b 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -1320,6 +1320,93 @@ func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing. require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) } +func TestOpenAIGatewayServiceRecordUsage_EmptyImageSizeDefaultsBeforeBillingAndPersistence(t *testing.T) { + imagePrice2K := 0.31 + groupID := int64(1201) + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_default_size", + Model: "gpt-image-2", + ImageCount: 2, + ImageSize: "", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 11201, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: 1.0, + ImagePrice2K: &imagePrice2K, + }, + }, + User: &User{ID: 21201}, + Account: &Account{ID: 31201}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, 2, usageRepo.lastLog.ImageCount) + require.NotNil(t, usageRepo.lastLog.ImageSize) + require.Equal(t, ImageBillingSize2K, *usageRepo.lastLog.ImageSize) + require.NotNil(t, usageRepo.lastLog.ImageSizeSource) + require.Equal(t, ImageSizeSourceDefault, *usageRepo.lastLog.ImageSizeSource) + require.Nil(t, usageRepo.lastLog.ImageInputSize) + require.Nil(t, usageRepo.lastLog.ImageOutputSize) + require.InDelta(t, 0.62, usageRepo.lastLog.TotalCost, 1e-12) + require.InDelta(t, 0.62, usageRepo.lastLog.ActualCost, 1e-12) + require.NotNil(t, usageRepo.lastLog.BillingMode) + require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) +} + +func TestOpenAIGatewayServiceRecordUsage_OutputImageSizeWinsBeforeBillingAndPersistence(t *testing.T) { + imagePrice1K := 0.11 + imagePrice4K := 0.44 + groupID := int64(1202) + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, &openAIRecordUsageUserRepoStub{}, &openAIRecordUsageSubRepoStub{}, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_output_size", + Model: "gpt-image-2", + ImageCount: 1, + ImageInputSize: "1024x1024", + ImageOutputSizes: []string{"3840x2160"}, + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 11202, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: 1.0, + ImagePrice1K: &imagePrice1K, + ImagePrice4K: &imagePrice4K, + }, + }, + User: &User{ID: 21202}, + Account: &Account{ID: 31202}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.NotNil(t, usageRepo.lastLog.ImageSize) + require.Equal(t, ImageBillingSize4K, *usageRepo.lastLog.ImageSize) + require.NotNil(t, usageRepo.lastLog.ImageInputSize) + require.Equal(t, "1024x1024", *usageRepo.lastLog.ImageInputSize) + require.NotNil(t, usageRepo.lastLog.ImageOutputSize) + require.Equal(t, "3840x2160", *usageRepo.lastLog.ImageOutputSize) + require.NotNil(t, usageRepo.lastLog.ImageSizeSource) + require.Equal(t, ImageSizeSourceOutput, *usageRepo.lastLog.ImageSizeSource) + require.Equal(t, map[string]int{ImageBillingSize4K: 1}, usageRepo.lastLog.ImageSizeBreakdown) + require.InDelta(t, 0.44, usageRepo.lastLog.TotalCost, 1e-12) + require.InDelta(t, 0.44, usageRepo.lastLog.ActualCost, 1e-12) +} + func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTokens(t *testing.T) { imagePrice := 0.02 groupID := int64(12) @@ -1641,3 +1728,42 @@ func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingUsesSizeTier( require.InDelta(t, 0.80, cost.TotalCost, 1e-12) require.InDelta(t, 0.80, cost.ActualCost, 1e-12) } + +func TestGatewayServiceCalculateRecordUsageCost_ChannelImageBillingNormalizesMissingSizeTier(t *testing.T) { + groupID := int64(128) + defaultPrice := 0.10 + price2K := 0.22 + cache := newEmptyChannelCache() + cache.pricingByGroupModel[channelModelKey{groupID: groupID, model: "gemini-image"}] = &ChannelModelPricing{ + BillingMode: BillingModeImage, + PerRequestPrice: &defaultPrice, + Intervals: []PricingInterval{{ + TierLabel: "2K", + PerRequestPrice: &price2K, + }}, + } + cache.channelByGroupID[groupID] = &Channel{ID: groupID, Status: StatusActive} + cache.loadedAt = time.Now() + channelService := &ChannelService{} + channelService.cache.Store(cache) + + svc := &GatewayService{ + billingService: NewBillingService(&config.Config{}, nil), + resolver: NewModelPricingResolver(channelService, NewBillingService(&config.Config{}, nil)), + } + + cost := svc.calculateRecordUsageCost( + context.Background(), + &ForwardResult{Model: "gemini-image", ImageCount: 2, ImageSize: ""}, + &APIKey{GroupID: i64p(groupID), Group: &Group{ID: groupID}}, + "gemini-image", + 1.0, + 1.0, + nil, + ) + + require.NotNil(t, cost) + require.Equal(t, string(BillingModeImage), cost.BillingMode) + require.InDelta(t, 0.44, cost.TotalCost, 1e-12) + require.InDelta(t, 0.44, cost.ActualCost, 1e-12) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index a2276353343..3e09b33e3ec 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -228,14 +228,19 @@ type OpenAIForwardResult struct { ServiceTier *string // ReasoningEffort is extracted from request body (reasoning.effort) or derived from model suffix. // Stored for usage records display; nil means not provided / not applicable. - ReasoningEffort *string - Stream bool - OpenAIWSMode bool - ResponseHeaders http.Header - Duration time.Duration - FirstTokenMs *int - ImageCount int - ImageSize string + ReasoningEffort *string + Stream bool + OpenAIWSMode bool + ResponseHeaders http.Header + Duration time.Duration + FirstTokenMs *int + ImageCount int + ImageSize string + ImageInputSize string + ImageOutputSize string + ImageOutputSizes []string + ImageSizeSource string + ImageSizeBreakdown map[string]int } type OpenAIWSRetryMetricsSnapshot struct { @@ -1113,6 +1118,9 @@ func isOpenAITransientProcessingError(upstreamStatusCode int, upstreamMsg string if strings.Contains(lower, "an error occurred while processing your request") { return true } + if strings.Contains(lower, "selected model is at capacity") { + return true + } return strings.Contains(lower, "you can retry your request") && strings.Contains(lower, "help.openai.com") && strings.Contains(lower, "request id") @@ -2416,9 +2424,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } imageBillingModel := "" imageSizeTier := "" + imageInputSize := "" if IsImageGenerationIntentMap(openAIResponsesEndpoint, reqModel, reqBody) { var imageCfgErr error - imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfig(reqBody, billingModel) + imageCfg, imageCfgErr := resolveOpenAIResponsesImageBillingConfigDetailed(reqBody, billingModel) if imageCfgErr != nil { setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "") c.JSON(http.StatusBadRequest, gin.H{ @@ -2430,6 +2439,9 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco }) return nil, imageCfgErr } + imageBillingModel = imageCfg.Model + imageSizeTier = imageCfg.SizeTier + imageInputSize = imageCfg.InputSize } // Re-serialize body only if modified @@ -2671,6 +2683,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco wsResult.UpstreamModel = upstreamModel if wsResult.ImageCount > 0 { wsResult.ImageSize = imageSizeTier + wsResult.ImageInputSize = imageInputSize wsResult.BillingModel = imageBillingModel } return wsResult, nil @@ -2777,6 +2790,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco var usage *OpenAIUsage var firstTokenMs *int imageCount := 0 + var imageOutputSizes []string if reqStream { streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, upstreamModel) if err != nil { @@ -2785,6 +2799,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco usage = streamResult.usage firstTokenMs = streamResult.firstTokenMs imageCount = streamResult.imageCount + imageOutputSizes = streamResult.imageOutputSizes } else { nonStreamResult, err := s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, upstreamModel) if err != nil { @@ -2792,6 +2807,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } usage = nonStreamResult.usage imageCount = nonStreamResult.imageCount + imageOutputSizes = nonStreamResult.imageOutputSizes } // Extract and save Codex usage snapshot from response headers (for OAuth accounts) @@ -2823,6 +2839,8 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco if imageCount > 0 { forwardResult.ImageCount = imageCount forwardResult.ImageSize = imageSizeTier + forwardResult.ImageInputSize = imageInputSize + forwardResult.ImageOutputSizes = imageOutputSizes forwardResult.BillingModel = imageBillingModel } return forwardResult, nil @@ -2927,9 +2945,10 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( } imageBillingModel := "" imageSizeTier := "" + imageInputSize := "" if IsImageGenerationIntent(openAIResponsesEndpoint, reqModel, body) { var imageCfgErr error - imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(body, reqModel) + imageCfg, imageCfgErr := resolveOpenAIResponsesImageBillingConfigDetailedFromBody(body, reqModel) if imageCfgErr != nil { setOpsUpstreamError(c, http.StatusBadRequest, imageCfgErr.Error(), "") c.JSON(http.StatusBadRequest, gin.H{ @@ -2941,6 +2960,9 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( }) return nil, imageCfgErr } + imageBillingModel = imageCfg.Model + imageSizeTier = imageCfg.SizeTier + imageInputSize = imageCfg.InputSize } logger.LegacyPrintf("service.openai_gateway", @@ -3026,6 +3048,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( var usage *OpenAIUsage var firstTokenMs *int imageCount := 0 + var imageOutputSizes []string if reqStream { result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime, reqModel, upstreamPassthroughModel) if err != nil { @@ -3034,6 +3057,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( usage = result.usage firstTokenMs = result.firstTokenMs imageCount = result.imageCount + imageOutputSizes = result.imageOutputSizes } else { result, err := s.handleNonStreamingResponsePassthrough(ctx, resp, c, reqModel, upstreamPassthroughModel) if err != nil { @@ -3041,6 +3065,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( } usage = result.usage imageCount = result.imageCount + imageOutputSizes = result.imageOutputSizes } if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { @@ -3066,6 +3091,8 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( if imageCount > 0 { forwardResult.ImageCount = imageCount forwardResult.ImageSize = imageSizeTier + forwardResult.ImageInputSize = imageInputSize + forwardResult.ImageOutputSizes = imageOutputSizes forwardResult.BillingModel = imageBillingModel } return forwardResult, nil @@ -3361,15 +3388,17 @@ func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string { } type openaiStreamingResultPassthrough struct { - usage *OpenAIUsage - firstTokenMs *int - imageCount int + usage *OpenAIUsage + firstTokenMs *int + imageCount int + imageOutputSizes []string } type openaiNonStreamingResultPassthrough struct { *OpenAIUsage - usage *OpenAIUsage - imageCount int + usage *OpenAIUsage + imageCount int + imageOutputSizes []string } func openAIStreamClientOutputStarted(c *gin.Context, localStarted bool) bool { @@ -3400,6 +3429,9 @@ func openAIStreamDataStartsClientOutput(data, eventType string) bool { } func openAIStreamFailedEventShouldFailover(payload []byte, message string) bool { + if isOpenAITransientProcessingError(http.StatusBadRequest, message, payload) { + return true + } code := strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "response.error.code").String())) if code == "" { code = strings.ToLower(strings.TrimSpace(gjson.GetBytes(payload, "error.code").String())) @@ -3539,7 +3571,12 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( needModelReplace := strings.TrimSpace(originalModel) != "" && strings.TrimSpace(mappedModel) != "" && strings.TrimSpace(originalModel) != strings.TrimSpace(mappedModel) resultWithUsage := func() *openaiStreamingResultPassthrough { - return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()} + return &openaiStreamingResultPassthrough{ + usage: usage, + firstTokenMs: firstTokenMs, + imageCount: imageCounter.Count(), + imageOutputSizes: imageCounter.Sizes(), + } } for scanner.Scan() { @@ -3696,9 +3733,10 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( } c.Data(resp.StatusCode, contentType, body) return &openaiNonStreamingResultPassthrough{ - OpenAIUsage: usage, - usage: usage, - imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body), + OpenAIUsage: usage, + usage: usage, + imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body), + imageOutputSizes: collectOpenAIResponseImageOutputSizesFromJSONBytes(body), }, nil } @@ -3758,9 +3796,10 @@ func (s *OpenAIGatewayService) handlePassthroughSSEToJSON(resp *http.Response, c c.Data(resp.StatusCode, contentType, body) return &openaiNonStreamingResultPassthrough{ - OpenAIUsage: usage, - usage: usage, - imageCount: countOpenAIImageOutputsFromSSEBody(bodyText), + OpenAIUsage: usage, + usage: usage, + imageCount: countOpenAIImageOutputsFromSSEBody(bodyText), + imageOutputSizes: collectOpenAIImageOutputSizesFromSSEBody(bodyText), }, nil } @@ -4182,15 +4221,17 @@ func (s *OpenAIGatewayService) handleCompatErrorResponse( // openaiStreamingResult streaming response result type openaiStreamingResult struct { - usage *OpenAIUsage - firstTokenMs *int - imageCount int + usage *OpenAIUsage + firstTokenMs *int + imageCount int + imageOutputSizes []string } type openaiNonStreamingResult struct { *OpenAIUsage - usage *OpenAIUsage - imageCount int + usage *OpenAIUsage + imageCount int + imageOutputSizes []string } func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) { @@ -4303,7 +4344,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp needModelReplace := originalModel != mappedModel resultWithUsage := func() *openaiStreamingResult { - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs, imageCount: imageCounter.Count()} + return &openaiStreamingResult{ + usage: usage, + firstTokenMs: firstTokenMs, + imageCount: imageCounter.Count(), + imageOutputSizes: imageCounter.Sizes(), + } } finalizeStream := func() (*openaiStreamingResult, error) { if !sawTerminalEvent { @@ -4709,28 +4755,47 @@ func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsag return } - usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int()) - usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int()) - usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int()) - usage.ImageOutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens_details.image_tokens").Int()) + if parsedUsage, ok := extractOpenAIUsageFromJSONBytes(data); ok { + *usage = parsedUsage + } } func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { if len(body) == 0 || !gjson.ValidBytes(body) { return OpenAIUsage{}, false } - values := gjson.GetManyBytes( - body, - "usage.input_tokens", - "usage.output_tokens", - "usage.input_tokens_details.cached_tokens", - "usage.output_tokens_details.image_tokens", - ) + if usage, ok := openAIUsageFromGJSON(gjson.GetBytes(body, "usage")); ok { + return usage, true + } + return openAIUsageFromGJSON(gjson.GetBytes(body, "response.usage")) +} + +func openAIUsageFromGJSON(value gjson.Result) (OpenAIUsage, bool) { + if !value.Exists() || !value.IsObject() { + return OpenAIUsage{}, false + } + inputTokens := value.Get("input_tokens").Int() + if inputTokens == 0 { + inputTokens = value.Get("prompt_tokens").Int() + } + outputTokens := value.Get("output_tokens").Int() + if outputTokens == 0 { + outputTokens = value.Get("completion_tokens").Int() + } + cacheReadTokens := value.Get("input_tokens_details.cached_tokens").Int() + if cacheReadTokens == 0 { + cacheReadTokens = value.Get("prompt_tokens_details.cached_tokens").Int() + } + imageOutputTokens := value.Get("output_tokens_details.image_tokens").Int() + if imageOutputTokens == 0 { + imageOutputTokens = value.Get("completion_tokens_details.image_tokens").Int() + } return OpenAIUsage{ - InputTokens: int(values[0].Int()), - OutputTokens: int(values[1].Int()), - CacheReadInputTokens: int(values[2].Int()), - ImageOutputTokens: int(values[3].Int()), + InputTokens: int(inputTokens), + OutputTokens: int(outputTokens), + CacheCreationInputTokens: int(value.Get("cache_creation_input_tokens").Int()), + CacheReadInputTokens: int(cacheReadTokens), + ImageOutputTokens: int(imageOutputTokens), }, true } @@ -4781,9 +4846,10 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r c.Data(resp.StatusCode, contentType, body) return &openaiNonStreamingResult{ - OpenAIUsage: usage, - usage: usage, - imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body), + OpenAIUsage: usage, + usage: usage, + imageCount: countOpenAIResponseImageOutputsFromJSONBytes(body), + imageOutputSizes: collectOpenAIResponseImageOutputSizesFromJSONBytes(body), }, nil } @@ -4845,9 +4911,10 @@ func (s *OpenAIGatewayService) handleSSEToJSON(resp *http.Response, c *gin.Conte c.Data(resp.StatusCode, contentType, body) return &openaiNonStreamingResult{ - OpenAIUsage: usage, - usage: usage, - imageCount: countOpenAIImageOutputsFromSSEBody(bodyText), + OpenAIUsage: usage, + usage: usage, + imageCount: countOpenAIImageOutputsFromSSEBody(bodyText), + imageOutputSizes: collectOpenAIImageOutputSizesFromSSEBody(bodyText), }, nil } @@ -5025,17 +5092,11 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro // buildOpenAIResponsesURL 组装 OpenAI Responses 端点。 // - base 以 /v1 结尾:追加 /responses +// - base 以其他版本段结尾(如 /v4):追加 /responses // - base 已是 /responses:原样返回 // - 其他情况:追加 /v1/responses func buildOpenAIResponsesURL(base string) string { - normalized := strings.TrimRight(strings.TrimSpace(base), "/") - if strings.HasSuffix(normalized, "/responses") { - return normalized - } - if strings.HasSuffix(normalized, "/v1") { - return normalized + "/responses" - } - return normalized + "/v1/responses" + return buildOpenAIEndpointURL(base, "/v1/responses") } func trimOpenAIEncryptedReasoningItems(reqBody map[string]any) bool { @@ -5286,6 +5347,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec user := input.User account := input.Account subscription := input.Subscription + ApplyOpenAIImageBillingResolution(result) // 计算实际的新输入token(减去缓存读取的token) // 因为 input_tokens 包含了 cache_read_tokens,而缓存读取的token不应按输入价格计费 @@ -5395,6 +5457,10 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec ImageOutputTokens: result.Usage.ImageOutputTokens, ImageCount: result.ImageCount, ImageSize: optionalTrimmedStringPtr(result.ImageSize), + ImageInputSize: optionalTrimmedStringPtr(result.ImageInputSize), + ImageOutputSize: optionalTrimmedStringPtr(result.ImageOutputSize), + ImageSizeSource: optionalTrimmedStringPtr(result.ImageSizeSource), + ImageSizeBreakdown: result.ImageSizeBreakdown, } if cost != nil { usageLog.InputCost = cost.InputCost @@ -5563,6 +5629,7 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost( result *OpenAIForwardResult, multiplier float64, ) *CostBreakdown { + sizeTier := NormalizeImageBillingTierOrDefault(result.ImageSize) if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil && (resolved.Mode == BillingModePerRequest || resolved.Mode == BillingModeImage) { gid := apiKey.Group.ID @@ -5571,7 +5638,7 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost( Model: billingModel, GroupID: &gid, RequestCount: result.ImageCount, - SizeTier: result.ImageSize, + SizeTier: sizeTier, RateMultiplier: multiplier, Resolver: s.resolver, Resolved: resolved, @@ -5590,7 +5657,7 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost( Price4K: apiKey.Group.ImagePrice4K, } } - return s.billingService.CalculateImageCost(billingModel, result.ImageSize, result.ImageCount, groupConfig, multiplier) + return s.billingService.CalculateImageCost(billingModel, sizeTier, result.ImageCount, groupConfig, multiplier) } func (s *OpenAIGatewayService) resolveOpenAIChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing { @@ -6168,7 +6235,7 @@ func writeOpenAIFastPolicyBlockedResponse(c *gin.Context, err *OpenAIFastBlocked // applyOpenAIFastPolicyToBody contract but operates on a Realtime/Responses // WS payload: // -// - pass: returns frame unchanged (newBytes == frame, blocked == nil) +// - pass: keeps service_tier, normalizing aliases such as "fast" to "priority" // - filter: returns a copy with top-level service_tier removed // - block: returns (frame, *OpenAIFastBlockedError) // @@ -6232,7 +6299,14 @@ func (s *OpenAIGatewayService) applyOpenAIFastPolicyToWSResponseCreate( } return trimmed, nil, nil default: - return frame, nil, nil + if normTier == rawTier { + return frame, nil, nil + } + updated, err := sjson.SetBytes(frame, "service_tier", normTier) + if err != nil { + return frame, nil, fmt.Errorf("normalize service_tier in ws frame: %w", err) + } + return updated, nil, nil } } diff --git a/backend/internal/service/openai_gateway_service_codex_cli_only_test.go b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go index fe58e92f69f..951860cdefc 100644 --- a/backend/internal/service/openai_gateway_service_codex_cli_only_test.go +++ b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go @@ -218,6 +218,12 @@ func TestIsOpenAITransientProcessingError(t *testing.T) { nil, )) + require.True(t, isOpenAITransientProcessingError( + http.StatusBadRequest, + "Selected model is at capacity. Please try a different model.", + []byte(`{"error":{"message":"Selected model is at capacity. Please try a different model.","type":"invalid_request_error"}}`), + )) + require.True(t, isOpenAITransientProcessingError( http.StatusBadRequest, "", @@ -332,3 +338,55 @@ func TestOpenAIGatewayService_Forward_TransientProcessingErrorTriggersFailover(t require.Contains(t, string(failoverErr.ResponseBody), "An error occurred while processing your request") require.False(t, c.Writer.Written(), "service 层应返回 failover 错误给上层换号,而不是直接向客户端写响应") } + +func TestOpenAIGatewayService_Forward_ModelCapacityErrorTriggersFailoverAndSameAccountRetry(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("Content-Type", "application/json") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-capacity-400"}, + }, + Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Selected model is at capacity. Please try a different model.","type":"invalid_request_error"}}`)), + }, + } + svc := &OpenAIGatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ForceCodexCLI: false}, + }, + httpUpstream: upstream, + } + account := &Account{ + ID: 1001, + Name: "codex max套餐", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "pool_mode": true, + }, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + body := []byte(`{"model":"gpt-5.4","stream":false,"input":[{"type":"text","text":"hello"}]}`) + + _, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.Equal(t, http.StatusBadRequest, failoverErr.StatusCode) + require.True(t, failoverErr.RetryableOnSameAccount) + require.Contains(t, string(failoverErr.ResponseBody), "Selected model is at capacity") + require.False(t, c.Writer.Written(), "service 层应返回 failover 错误给上层重试/换号,而不是直接向客户端写响应") +} diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index d636cf27a55..013d7a08629 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -1116,6 +1116,47 @@ func TestOpenAIStreamingResponseFailedBeforeOutputReturnsFailover(t *testing.T) require.Empty(t, rec.Body.String()) } +func TestOpenAIStreamingResponseFailedBeforeOutputCapacityErrorReturnsFailover(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + "event: response.created", + `data: {"type":"response.created","response":{"id":"resp_1"}}`, + "", + "event: response.in_progress", + `data: {"type":"response.in_progress","response":{"id":"resp_1"}}`, + "", + "event: response.failed", + `data: {"type":"response.failed","error":{"message":"Selected model is at capacity. Please try a different model.","type":"invalid_request_error"}}`, + "", + }, "\n"))), + Header: http.Header{"X-Request-Id": []string{"rid-capacity-failed"}}, + } + + _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1, Platform: PlatformOpenAI, Name: "acc"}, time.Now(), "model", "model") + require.Error(t, err) + var failoverErr *UpstreamFailoverError + require.ErrorAs(t, err, &failoverErr) + require.Equal(t, http.StatusBadGateway, failoverErr.StatusCode) + require.Contains(t, string(failoverErr.ResponseBody), "Selected model is at capacity") + require.False(t, c.Writer.Written()) + require.Empty(t, rec.Body.String()) +} + func TestOpenAIStreamingPreambleOnlyMissingTerminalReturnsFailover(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ @@ -2174,6 +2215,25 @@ func TestParseSSEUsage_SelectiveParsing(t *testing.T) { require.Equal(t, 13, usage.InputTokens) require.Equal(t, 15, usage.OutputTokens) require.Equal(t, 4, usage.CacheReadInputTokens) + + svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"prompt_tokens":21,"completion_tokens":8,"prompt_tokens_details":{"cached_tokens":6}}}}`, usage) + require.Equal(t, 21, usage.InputTokens) + require.Equal(t, 8, usage.OutputTokens) + require.Equal(t, 6, usage.CacheReadInputTokens) +} + +func TestExtractOpenAIUsageFromJSONBytes_AcceptsResponseAndChatUsageShapes(t *testing.T) { + usage, ok := extractOpenAIUsageFromJSONBytes([]byte(`{"id":"resp_1","usage":{"input_tokens":3,"output_tokens":5,"input_tokens_details":{"cached_tokens":2}}}`)) + require.True(t, ok) + require.Equal(t, 3, usage.InputTokens) + require.Equal(t, 5, usage.OutputTokens) + require.Equal(t, 2, usage.CacheReadInputTokens) + + usage, ok = extractOpenAIUsageFromJSONBytes([]byte(`{"type":"response.completed","response":{"usage":{"prompt_tokens":13,"completion_tokens":7,"prompt_tokens_details":{"cached_tokens":4}}}}`)) + require.True(t, ok) + require.Equal(t, 13, usage.InputTokens) + require.Equal(t, 7, usage.OutputTokens) + require.Equal(t, 4, usage.CacheReadInputTokens) } func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) { diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go index afa94156867..783a44e9685 100644 --- a/backend/internal/service/openai_images.go +++ b/backend/internal/service/openai_images.go @@ -532,54 +532,7 @@ func isOpenAINativeImageOption(name string) bool { } func normalizeOpenAIImageSizeTier(size string) string { - trimmed := strings.TrimSpace(size) - normalized := strings.ToLower(trimmed) - switch normalized { - case "", "auto": - return "2K" - case "1024x1024": - return "1K" - case "1536x1024", "1024x1536", "1792x1024", "1024x1792", "2048x2048", "2048x1152", "1152x2048": - return "2K" - case "3840x2160", "2160x3840": - return "4K" - } - width, height, ok := parseOpenAIImageSizeDimensions(trimmed) - if !ok { - return "2K" - } - return classifyUnknownOpenAIImageSizeTier(width, height) -} - -const ( - openAIImage2KMaxPixels = 2560 * 1440 -) - -func parseOpenAIImageSizeDimensions(size string) (int, int, bool) { - trimmed := strings.TrimSpace(size) - parts := strings.Split(strings.ToLower(trimmed), "x") - if len(parts) != 2 { - return 0, 0, false - } - width, err := strconv.Atoi(strings.TrimSpace(parts[0])) - if err != nil { - return 0, 0, false - } - height, err := strconv.Atoi(strings.TrimSpace(parts[1])) - if err != nil { - return 0, 0, false - } - if width <= 0 || height <= 0 { - return 0, 0, false - } - return width, height, true -} - -func classifyUnknownOpenAIImageSizeTier(width int, height int) string { - if height > 0 && width > openAIImage2KMaxPixels/height { - return "4K" - } - return "2K" + return NormalizeImageBillingTierOrDefault(size) } func (s *OpenAIGatewayService) ForwardImages( @@ -639,7 +592,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey( setOpsUpstreamRequestBody(c, forwardBody) } - upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) defer releaseUpstreamCtx() token, _, err := s.GetAccessToken(upstreamCtx, account) @@ -704,29 +657,46 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey( imageCount := parsed.N var firstTokenMs *int if parsed.Stream && isEventStreamResponse(resp.Header) { - streamUsage, streamCount, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime) + streamUsage, streamCount, streamSizes, ttft, err := s.handleOpenAIImagesStreamingResponse(resp, c, startTime) if err != nil { if streamCount > 0 { return &OpenAIForwardResult{ - RequestID: resp.Header.Get("x-request-id"), - Usage: streamUsage, - Model: requestModel, - UpstreamModel: upstreamModel, - Stream: parsed.Stream, - ResponseHeaders: resp.Header.Clone(), - Duration: time.Since(startTime), - FirstTokenMs: ttft, - ImageCount: streamCount, - ImageSize: parsed.SizeTier, + RequestID: resp.Header.Get("x-request-id"), + Usage: streamUsage, + Model: requestModel, + UpstreamModel: upstreamModel, + Stream: parsed.Stream, + ResponseHeaders: resp.Header.Clone(), + Duration: time.Since(startTime), + FirstTokenMs: ttft, + ImageCount: streamCount, + ImageSize: parsed.SizeTier, + ImageInputSize: parsed.Size, + ImageOutputSizes: streamSizes, }, err } return nil, err } usage = streamUsage imageCount = streamCount + imageOutputSizes := streamSizes firstTokenMs = ttft + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: usage, + Model: requestModel, + UpstreamModel: upstreamModel, + Stream: parsed.Stream, + ResponseHeaders: resp.Header.Clone(), + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: parsed.SizeTier, + ImageInputSize: parsed.Size, + ImageOutputSizes: imageOutputSizes, + }, nil } else { - nonStreamUsage, nonStreamCount, err := s.handleOpenAIImagesNonStreamingResponse(resp, c) + nonStreamUsage, nonStreamCount, nonStreamSizes, err := s.handleOpenAIImagesNonStreamingResponse(resp, c) if err != nil { return nil, err } @@ -734,19 +704,21 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey( if nonStreamCount > 0 { imageCount = nonStreamCount } + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: usage, + Model: requestModel, + UpstreamModel: upstreamModel, + Stream: parsed.Stream, + ResponseHeaders: resp.Header.Clone(), + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: parsed.SizeTier, + ImageInputSize: parsed.Size, + ImageOutputSizes: nonStreamSizes, + }, nil } - return &OpenAIForwardResult{ - RequestID: resp.Header.Get("x-request-id"), - Usage: usage, - Model: requestModel, - UpstreamModel: upstreamModel, - Stream: parsed.Stream, - ResponseHeaders: resp.Header.Clone(), - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: parsed.SizeTier, - }, nil } func (s *OpenAIGatewayService) buildOpenAIImagesRequest( @@ -795,15 +767,7 @@ func (s *OpenAIGatewayService) buildOpenAIImagesRequest( } func buildOpenAIImagesURL(base string, endpoint string) string { - normalized := strings.TrimRight(strings.TrimSpace(base), "/") - relative := strings.TrimPrefix(strings.TrimSpace(endpoint), "/v1") - if strings.HasSuffix(normalized, endpoint) || strings.HasSuffix(normalized, relative) { - return normalized - } - if strings.HasSuffix(normalized, "/v1") { - return normalized + relative - } - return normalized + endpoint + return buildOpenAIEndpointURL(base, endpoint) } func rewriteOpenAIImagesModel(body []byte, contentType string, model string) ([]byte, string, error) { @@ -892,10 +856,10 @@ func cloneMultipartHeader(src textproto.MIMEHeader) textproto.MIMEHeader { return dst } -func (s *OpenAIGatewayService) handleOpenAIImagesNonStreamingResponse(resp *http.Response, c *gin.Context) (OpenAIUsage, int, error) { +func (s *OpenAIGatewayService) handleOpenAIImagesNonStreamingResponse(resp *http.Response, c *gin.Context) (OpenAIUsage, int, []string, error) { body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { - return OpenAIUsage{}, 0, err + return OpenAIUsage{}, 0, nil, err } responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := "application/json" @@ -907,14 +871,14 @@ func (s *OpenAIGatewayService) handleOpenAIImagesNonStreamingResponse(resp *http c.Data(resp.StatusCode, contentType, body) usage, _ := extractOpenAIUsageFromJSONBytes(body) - return usage, extractOpenAIImageCountFromJSONBytes(body), nil + return usage, extractOpenAIImageCountFromJSONBytes(body), collectOpenAIResponseImageOutputSizesFromJSONBytes(body), nil } func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( resp *http.Response, c *gin.Context, startTime time.Time, -) (OpenAIUsage, int, *int, error) { +) (OpenAIUsage, int, []string, *int, error) { responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) if contentType == "" { @@ -925,7 +889,7 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( flusher, ok := c.Writer.(http.Flusher) if !ok { - return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer") + return OpenAIUsage{}, 0, nil, nil, fmt.Errorf("streaming is not supported by response writer") } usage := OpenAIUsage{} @@ -1010,12 +974,12 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( } if err != nil { flushSSEEvent() - return usage, imageCounter.Count(), firstTokenMs, err + return usage, imageCounter.Count(), imageCounter.Sizes(), firstTokenMs, err } } flushSSEEvent() finalizeFallbackBody() - return usage, imageCounter.Count(), firstTokenMs, nil + return usage, imageCounter.Count(), imageCounter.Sizes(), firstTokenMs, nil } type readEvent struct { @@ -1082,11 +1046,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( if !ok { flushSSEEvent() finalizeFallbackBody() - return usage, imageCounter.Count(), firstTokenMs, nil + return usage, imageCounter.Count(), imageCounter.Sizes(), firstTokenMs, nil } if ev.err != nil { flushSSEEvent() - return usage, imageCounter.Count(), firstTokenMs, ev.err + return usage, imageCounter.Count(), imageCounter.Sizes(), firstTokenMs, ev.err } processLine(ev.line) case <-intervalCh: @@ -1095,11 +1059,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesStreamingResponse( continue } if clientDisconnected { - return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream incomplete after timeout") + return usage, imageCounter.Count(), imageCounter.Sizes(), firstTokenMs, fmt.Errorf("image stream incomplete after timeout") } logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images stream data interval timeout: interval=%s", streamInterval) _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval))) - return usage, imageCounter.Count(), firstTokenMs, fmt.Errorf("image stream data interval timeout") + return usage, imageCounter.Count(), imageCounter.Sizes(), firstTokenMs, fmt.Errorf("image stream data interval timeout") case <-keepaliveCh: if clientDisconnected || time.Since(lastDownstreamWriteAt) < keepaliveInterval { continue diff --git a/backend/internal/service/openai_images_responses.go b/backend/internal/service/openai_images_responses.go index 25cd8228a83..96c2e0d892f 100644 --- a/backend/internal/service/openai_images_responses.go +++ b/backend/internal/service/openai_images_responses.go @@ -72,6 +72,22 @@ func mergeOpenAIResponsesImageMeta(dst *openAIResponsesImageResult, src openAIRe } } +func openAIResponsesImageResultSizes(results []openAIResponsesImageResult) []string { + if len(results) == 0 { + return nil + } + sizes := make([]string, 0, len(results)) + for _, result := range results { + if size := strings.TrimSpace(result.Size); size != "" { + sizes = append(sizes, size) + } + } + if len(sizes) == 0 { + return nil + } + return sizes +} + func extractOpenAIResponsesImageMetaFromLifecycleEvent(payload []byte) (openAIResponsesImageResult, int64, bool) { switch gjson.GetBytes(payload, "type").String() { case "response.created", "response.in_progress", "response.completed": @@ -547,10 +563,10 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse( c *gin.Context, responseFormat string, fallbackModel string, -) (OpenAIUsage, int, error) { +) (OpenAIUsage, int, []string, error) { body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) if err != nil { - return OpenAIUsage{}, 0, err + return OpenAIUsage{}, 0, nil, err } var usage OpenAIUsage @@ -559,10 +575,10 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse( }) results, createdAt, usageRaw, firstMeta, _, err := collectOpenAIImagesFromResponsesBody(body) if err != nil { - return OpenAIUsage{}, 0, err + return OpenAIUsage{}, 0, nil, err } if len(results) == 0 { - return OpenAIUsage{}, 0, fmt.Errorf("upstream did not return image output") + return OpenAIUsage{}, 0, nil, fmt.Errorf("upstream did not return image output") } if strings.TrimSpace(firstMeta.Model) == "" { firstMeta.Model = strings.TrimSpace(fallbackModel) @@ -570,11 +586,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse( responseBody, err := buildOpenAIImagesAPIResponse(results, createdAt, usageRaw, firstMeta, responseFormat) if err != nil { - return OpenAIUsage{}, 0, err + return OpenAIUsage{}, 0, nil, err } responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) c.Data(resp.StatusCode, "application/json; charset=utf-8", responseBody) - return usage, len(results), nil + return usage, len(results), openAIResponsesImageResultSizes(results), nil } func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( @@ -584,7 +600,7 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( responseFormat string, streamPrefix string, fallbackModel string, -) (OpenAIUsage, int, *int, error) { +) (OpenAIUsage, int, []string, *int, error) { responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") @@ -593,7 +609,7 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( flusher, ok := c.Writer.(http.Flusher) if !ok { - return OpenAIUsage{}, 0, nil, fmt.Errorf("streaming is not supported by response writer") + return OpenAIUsage{}, 0, nil, nil, fmt.Errorf("streaming is not supported by response writer") } format := strings.ToLower(strings.TrimSpace(responseFormat)) @@ -603,6 +619,7 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( usage := OpenAIUsage{} imageCount := 0 + var imageOutputSizes []string var firstTokenMs *int emitted := make(map[string]struct{}) pendingResults := make([]openAIResponsesImageResult, 0, 1) @@ -713,6 +730,7 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload) } imageCount = len(emitted) + imageOutputSizes = openAIResponsesImageResultSizes(finalResults) processDataDone = true } } @@ -753,6 +771,7 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, eventName, payload) } imageCount = len(emitted) + imageOutputSizes = openAIResponsesImageResultSizes(pendingResults) return nil } @@ -769,33 +788,33 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( line, err := reader.ReadBytes('\n') done, processErr := processLine(line) if processErr != nil { - return usage, imageCount, firstTokenMs, processErr + return usage, imageCount, imageOutputSizes, firstTokenMs, processErr } if done { - return usage, imageCount, firstTokenMs, nil + return usage, imageCount, imageOutputSizes, firstTokenMs, nil } if err == io.EOF { break } if err != nil { if done, processErr := flushData(); processErr != nil { - return usage, imageCount, firstTokenMs, processErr + return usage, imageCount, imageOutputSizes, firstTokenMs, processErr } else if done { - return usage, imageCount, firstTokenMs, nil + return usage, imageCount, imageOutputSizes, firstTokenMs, nil } s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(err.Error())) - return usage, imageCount, firstTokenMs, err + return usage, imageCount, imageOutputSizes, firstTokenMs, err } } if done, processErr := flushData(); processErr != nil { - return usage, imageCount, firstTokenMs, processErr + return usage, imageCount, imageOutputSizes, firstTokenMs, processErr } else if done { - return usage, imageCount, firstTokenMs, nil + return usage, imageCount, imageOutputSizes, firstTokenMs, nil } if err := finalizePending(); err != nil { - return usage, imageCount, firstTokenMs, err + return usage, imageCount, imageOutputSizes, firstTokenMs, err } - return usage, imageCount, firstTokenMs, nil + return usage, imageCount, imageOutputSizes, firstTokenMs, nil } type readEvent struct { @@ -861,30 +880,30 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( case ev, ok := <-events: if !ok { if done, processErr := flushData(); processErr != nil { - return usage, imageCount, firstTokenMs, processErr + return usage, imageCount, imageOutputSizes, firstTokenMs, processErr } else if done { - return usage, imageCount, firstTokenMs, nil + return usage, imageCount, imageOutputSizes, firstTokenMs, nil } if err := finalizePending(); err != nil { - return usage, imageCount, firstTokenMs, err + return usage, imageCount, imageOutputSizes, firstTokenMs, err } - return usage, imageCount, firstTokenMs, nil + return usage, imageCount, imageOutputSizes, firstTokenMs, nil } if ev.err != nil { if done, processErr := flushData(); processErr != nil { - return usage, imageCount, firstTokenMs, processErr + return usage, imageCount, imageOutputSizes, firstTokenMs, processErr } else if done { - return usage, imageCount, firstTokenMs, nil + return usage, imageCount, imageOutputSizes, firstTokenMs, nil } s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(ev.err.Error())) - return usage, imageCount, firstTokenMs, ev.err + return usage, imageCount, imageOutputSizes, firstTokenMs, ev.err } done, processErr := processLine(ev.line) if processErr != nil { - return usage, imageCount, firstTokenMs, processErr + return usage, imageCount, imageOutputSizes, firstTokenMs, processErr } if done { - return usage, imageCount, firstTokenMs, nil + return usage, imageCount, imageOutputSizes, firstTokenMs, nil } case <-intervalCh: lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) @@ -892,11 +911,11 @@ func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( continue } if clientDisconnected { - return usage, imageCount, firstTokenMs, fmt.Errorf("image stream incomplete after timeout") + return usage, imageCount, imageOutputSizes, firstTokenMs, fmt.Errorf("image stream incomplete after timeout") } logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Images responses stream data interval timeout: interval=%s", streamInterval) s.tryWriteOpenAIImagesStreamEvent(c, flusher, &clientDisconnected, &lastDownstreamWriteAt, "error", buildOpenAIImagesStreamErrorBody(fmt.Sprintf("upstream image stream idle for %s", streamInterval))) - return usage, imageCount, firstTokenMs, fmt.Errorf("image stream data interval timeout") + return usage, imageCount, imageOutputSizes, firstTokenMs, fmt.Errorf("image stream data interval timeout") case <-keepaliveCh: if clientDisconnected || time.Since(lastDownstreamWriteAt) < keepaliveInterval { continue @@ -948,7 +967,7 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( ) } - upstreamCtx, releaseUpstreamCtx := detachStreamUpstreamContext(ctx, parsed.Stream) + upstreamCtx, releaseUpstreamCtx := detachUpstreamContext(ctx) defer releaseUpstreamCtx() token, _, err := s.GetAccessToken(upstreamCtx, account) @@ -1019,31 +1038,34 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( defer func() { _ = resp.Body.Close() }() var ( - usage OpenAIUsage - imageCount int - firstTokenMs *int + usage OpenAIUsage + imageCount int + imageOutputSizes []string + firstTokenMs *int ) if parsed.Stream { - usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel) + usage, imageCount, imageOutputSizes, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel) if err != nil { if imageCount > 0 { return &OpenAIForwardResult{ - RequestID: resp.Header.Get("x-request-id"), - Usage: usage, - Model: requestModel, - UpstreamModel: requestModel, - Stream: parsed.Stream, - ResponseHeaders: resp.Header.Clone(), - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: parsed.SizeTier, + RequestID: resp.Header.Get("x-request-id"), + Usage: usage, + Model: requestModel, + UpstreamModel: requestModel, + Stream: parsed.Stream, + ResponseHeaders: resp.Header.Clone(), + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: parsed.SizeTier, + ImageInputSize: parsed.Size, + ImageOutputSizes: imageOutputSizes, }, err } return nil, err } } else { - usage, imageCount, err = s.handleOpenAIImagesOAuthNonStreamingResponse(resp, c, parsed.ResponseFormat, requestModel) + usage, imageCount, imageOutputSizes, err = s.handleOpenAIImagesOAuthNonStreamingResponse(resp, c, parsed.ResponseFormat, requestModel) if err != nil { return nil, err } @@ -1052,15 +1074,17 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( imageCount = parsed.N } return &OpenAIForwardResult{ - RequestID: resp.Header.Get("x-request-id"), - Usage: usage, - Model: requestModel, - UpstreamModel: requestModel, - Stream: parsed.Stream, - ResponseHeaders: resp.Header.Clone(), - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: parsed.SizeTier, + RequestID: resp.Header.Get("x-request-id"), + Usage: usage, + Model: requestModel, + UpstreamModel: requestModel, + Stream: parsed.Stream, + ResponseHeaders: resp.Header.Clone(), + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: parsed.SizeTier, + ImageInputSize: parsed.Size, + ImageOutputSizes: imageOutputSizes, }, nil } diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index 45fb24e975e..35789d2166e 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -149,9 +149,9 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_NormalizesOfficialAndCusto {size: "2048x1152", wantTier: "2K"}, {size: "3840x2160", wantTier: "4K"}, {size: "2160x3840", wantTier: "4K"}, - {size: "1024X768", wantTier: "2K"}, + {size: "1024X768", wantTier: "1K"}, {size: "1280x768", wantTier: "2K"}, - {size: "2560x1440", wantTier: "2K"}, + {size: "2560x1440", wantTier: "4K"}, {size: "2560x1600", wantTier: "4K"}, {size: "auto", wantTier: "2K"}, } @@ -186,7 +186,7 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_UnknownSizesDoNotBlockPass {size: "2048x1153", wantTier: "2K"}, {size: "4096x1024", wantTier: "4K"}, {size: "3840x1024", wantTier: "4K"}, - {size: "512x512", wantTier: "2K"}, + {size: "512x512", wantTier: "1K"}, {size: "invalid", wantTier: "2K"}, {size: "999999999999999999999999999x2", wantTier: "2K"}, } @@ -418,6 +418,10 @@ func TestBuildOpenAIImagesURL_HandlesVersionedBaseURL(t *testing.T) { "https://image-upstream.example/v1/images/generations", buildOpenAIImagesURL("https://image-upstream.example/v1", openAIImagesGenerationsEndpoint), ) + require.Equal(t, + "https://open.bigmodel.cn/api/paas/v4/images/generations", + buildOpenAIImagesURL("https://open.bigmodel.cn/api/paas/v4", openAIImagesGenerationsEndpoint), + ) require.Equal(t, "https://image-upstream.example/v1/images/edits", buildOpenAIImagesURL("https://image-upstream.example/v1/", openAIImagesEditsEndpoint), diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index f087ac32beb..020e887528e 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -261,6 +261,12 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) { model: "gpt-5.4-high", want: "gpt-5.4", }, + { + name: "oauth preserves codex auto review model", + account: &Account{Type: AccountTypeOAuth}, + model: "codex-auto-review", + want: "codex-auto-review", + }, { name: "apikey preserves custom compatible model", account: &Account{Type: AccountTypeAPIKey}, @@ -283,3 +289,17 @@ func TestNormalizeOpenAIModelForUpstream(t *testing.T) { }) } } + +func TestUsageBillingModelCandidatesPreserveCodexAutoReviewModel(t *testing.T) { + candidates := usageBillingModelCandidates("codex-auto-review") + + expected := []string{"codex-auto-review"} + if len(candidates) != len(expected) { + t.Fatalf("usageBillingModelCandidates(codex-auto-review) = %#v, want %#v", candidates, expected) + } + for i := range expected { + if candidates[i] != expected[i] { + t.Fatalf("usageBillingModelCandidates(codex-auto-review) = %#v, want %#v", candidates, expected) + } + } +} diff --git a/backend/internal/service/openai_silent_refusal.go b/backend/internal/service/openai_silent_refusal.go new file mode 100644 index 00000000000..27b71b75716 --- /dev/null +++ b/backend/internal/service/openai_silent_refusal.go @@ -0,0 +1,293 @@ +package service + +import ( + "bytes" + "encoding/json" + "net/http" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/apicompat" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" +) + +const ( + openAISilentRefusalMinRequestBodyBytes = 64 * 1024 + openAISilentRefusalErrorCode = "openai_silent_refusal" + openAISilentRefusalUpstreamMessage = "OpenAI upstream returned an empty completion stream with finish_reason=stop and no usage" + openAISilentRefusalClientMessage = "Upstream returned an empty completion without usage; no fallback account was available" +) + +type openAIChatSilentRefusalDetector struct { + enabled bool + sawContent bool + sawToolCall bool + sawFunctionCall bool + sawUsage bool + sawError bool + sawReasoning bool + sawFinish bool + finishReason string +} + +func newOpenAIChatSilentRefusalDetector(requestBodyLen int) *openAIChatSilentRefusalDetector { + return &openAIChatSilentRefusalDetector{ + enabled: requestBodyLen >= openAISilentRefusalMinRequestBodyBytes, + } +} + +func (d *openAIChatSilentRefusalDetector) Enabled() bool { + return d != nil && d.enabled +} + +func (d *openAIChatSilentRefusalDetector) ObserveSSELine(line string) { + if d == nil || !d.enabled { + return + } + if eventType, ok := extractOpenAISSEEventLine(line); ok { + d.observeEventType(eventType) + return + } + if payload, ok := extractOpenAISSEDataLine(line); ok { + d.ObservePayload([]byte(payload)) + } +} + +func (d *openAIChatSilentRefusalDetector) ObservePayload(payload []byte) { + if d == nil || !d.enabled { + return + } + payload = bytes.TrimSpace(payload) + if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) { + return + } + if !gjson.ValidBytes(payload) { + return + } + + eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String()) + d.observeEventType(eventType) + + if gjson.GetBytes(payload, "error").Exists() { + d.sawError = true + } + if usage := gjson.GetBytes(payload, "usage"); usage.Exists() && usage.IsObject() { + d.sawUsage = true + } + if usage := gjson.GetBytes(payload, "response.usage"); usage.Exists() && usage.IsObject() { + d.sawUsage = true + } + + d.observeChatChoicesPayload(payload) + d.observeResponsesPayload(payload, eventType) +} + +func (d *openAIChatSilentRefusalDetector) ObserveChatChunk(chunk apicompat.ChatCompletionsChunk) { + if d == nil || !d.enabled { + return + } + if chunk.Usage != nil { + d.sawUsage = true + } + for _, choice := range chunk.Choices { + if choice.FinishReason != nil { + d.observeFinishReason(*choice.FinishReason) + } + delta := choice.Delta + if delta.Content != nil && *delta.Content != "" { + d.sawContent = true + } + if delta.ReasoningContent != nil { + d.sawReasoning = true + } + if len(delta.ToolCalls) > 0 { + d.sawToolCall = true + } + } +} + +func (d *openAIChatSilentRefusalDetector) ShouldReleaseClientOutput() bool { + if d == nil || !d.enabled { + return true + } + if d.sawContent || d.sawToolCall || d.sawFunctionCall || d.sawUsage || d.sawError || d.sawReasoning { + return true + } + return d.sawFinish && d.finishReason != "" && d.finishReason != "stop" +} + +func (d *openAIChatSilentRefusalDetector) IsSilentRefusal() bool { + if d == nil || !d.enabled { + return false + } + return !d.sawContent && + !d.sawToolCall && + !d.sawFunctionCall && + !d.sawUsage && + !d.sawError && + !d.sawReasoning && + d.sawFinish && + d.finishReason == "stop" +} + +func (d *openAIChatSilentRefusalDetector) observeEventType(eventType string) { + eventType = strings.TrimSpace(eventType) + if eventType == "" { + return + } + if eventType == "error" || eventType == "response.failed" { + d.sawError = true + } + if strings.Contains(eventType, "reasoning") || strings.Contains(eventType, "reasoning_summary") { + d.sawReasoning = true + } +} + +func (d *openAIChatSilentRefusalDetector) observeFinishReason(reason string) { + reason = strings.TrimSpace(reason) + if reason == "" { + return + } + d.sawFinish = true + d.finishReason = reason +} + +func (d *openAIChatSilentRefusalDetector) observeChatChoicesPayload(payload []byte) { + choices := gjson.GetBytes(payload, "choices") + if !choices.Exists() || !choices.IsArray() { + return + } + for _, choice := range choices.Array() { + if finish := choice.Get("finish_reason"); finish.Exists() { + d.observeFinishReason(finish.String()) + } + delta := choice.Get("delta") + if !delta.Exists() { + continue + } + if content := delta.Get("content"); content.Exists() && content.String() != "" { + d.sawContent = true + } + if delta.Get("tool_calls").Exists() { + d.sawToolCall = true + } + if delta.Get("function_call").Exists() { + d.sawFunctionCall = true + } + if delta.Get("reasoning").Exists() || + delta.Get("reasoning_content").Exists() || + delta.Get("reasoning_summary").Exists() { + d.sawReasoning = true + } + } +} + +func (d *openAIChatSilentRefusalDetector) observeResponsesPayload(payload []byte, eventType string) { + switch eventType { + case "response.output_text.delta": + if gjson.GetBytes(payload, "delta").String() != "" { + d.sawContent = true + } + case "response.output_item.added": + switch strings.TrimSpace(gjson.GetBytes(payload, "item.type").String()) { + case "function_call": + d.sawToolCall = true + case "reasoning": + d.sawReasoning = true + } + case "response.function_call_arguments.delta": + d.sawToolCall = true + case "response.reasoning_summary_text.delta", "response.reasoning_summary_text.done": + d.sawReasoning = true + case "response.completed", "response.done": + d.observeFinishReason("stop") + case "response.incomplete": + d.observeFinishReason("length") + case "response.failed": + d.sawError = true + } + + if output := gjson.GetBytes(payload, "response.output"); output.Exists() && output.IsArray() { + for _, item := range output.Array() { + switch strings.TrimSpace(item.Get("type").String()) { + case "function_call": + d.sawToolCall = true + case "reasoning": + d.sawReasoning = true + case "message": + d.observeResponseMessageItem(item) + } + } + } +} + +func (d *openAIChatSilentRefusalDetector) observeResponseMessageItem(item gjson.Result) { + content := item.Get("content") + if !content.Exists() || !content.IsArray() { + return + } + for _, part := range content.Array() { + if part.Get("text").String() != "" { + d.sawContent = true + return + } + } +} + +func newOpenAISilentRefusalFailoverError(c *gin.Context, account *Account, upstreamRequestID string) *UpstreamFailoverError { + accountID := int64(0) + accountName := "" + platform := PlatformOpenAI + if account != nil { + accountID = account.ID + accountName = account.Name + platform = account.Platform + } + + setOpsUpstreamError(c, http.StatusBadGateway, openAISilentRefusalUpstreamMessage, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: platform, + AccountID: accountID, + AccountName: accountName, + UpstreamStatusCode: http.StatusBadGateway, + UpstreamRequestID: upstreamRequestID, + Kind: "failover", + Message: openAISilentRefusalUpstreamMessage, + }) + + headers := http.Header{} + if strings.TrimSpace(upstreamRequestID) != "" { + headers.Set("x-request-id", strings.TrimSpace(upstreamRequestID)) + } + return &UpstreamFailoverError{ + StatusCode: http.StatusBadGateway, + ResponseBody: openAISilentRefusalErrorBody(), + ResponseHeaders: headers, + } +} + +func openAISilentRefusalErrorBody() []byte { + body, err := json.Marshal(map[string]any{ + "error": map[string]any{ + "type": "upstream_error", + "code": openAISilentRefusalErrorCode, + "message": openAISilentRefusalUpstreamMessage, + }, + }) + if err != nil { + return []byte(`{"error":{"type":"upstream_error","code":"openai_silent_refusal","message":"OpenAI upstream returned an empty completion stream with finish_reason=stop and no usage"}}`) + } + return body +} + +// IsOpenAISilentRefusalErrorBody reports whether a failover body was produced +// by the OpenAI silent-refusal detector. +func IsOpenAISilentRefusalErrorBody(body []byte) bool { + return strings.TrimSpace(gjson.GetBytes(body, "error.code").String()) == openAISilentRefusalErrorCode +} + +// OpenAISilentRefusalClientMessage returns the exhausted-failover client message +// for OpenAI silent refusals. +func OpenAISilentRefusalClientMessage() string { + return openAISilentRefusalClientMessage +} diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index a680d45103e..5b55d200737 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -154,7 +154,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew if needsRefresh && strings.TrimSpace(account.GetOpenAIRefreshToken()) == "" { if expiresAt != nil && !time.Now().Before(*expiresAt) { - return "", errors.New("openai access_token expired and refresh_token is missing") + const reason = "openai access_token expired and refresh_token is missing" + // 永久故障:缺失 refresh_token 时账号无法自愈,必须立即从调度池剔除, + // 否则会被反复选中、每次都在 token 阶段直接返回错误,对用户呈现持续 502。 + p.disableAccountMissingRefreshToken(account, reason) + return "", errors.New(reason) } needsRefresh = false } @@ -261,6 +265,39 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou return accessToken, nil } +// disableAccountMissingRefreshToken 在请求路径上发现 OpenAI OAuth 账号 +// 凭证已过期且 refresh_token 缺失时,将账号标记为 error 状态。 +// 这是一种永久性故障:仅靠后续请求或 TokenRefreshService 不会自愈 +// (NeedsRefresh 也会因 refresh_token 为空直接跳过), +// 必须主动剔除以避免账号被持续选中导致用户端反复 502。 +// 使用 background context 是因为请求 context 可能很快结束。 +func (p *OpenAITokenProvider) disableAccountMissingRefreshToken(account *Account, reason string) { + if p == nil || p.accountRepo == nil || account == nil { + return + } + bgCtx := context.Background() + if err := p.accountRepo.SetError(bgCtx, account.ID, reason); err != nil { + slog.Warn("openai_token_provider.set_error_failed", + "account_id", account.ID, + "error", err, + ) + return + } + if p.tokenCache != nil { + cacheKey := OpenAITokenCacheKey(account) + if err := p.tokenCache.DeleteAccessToken(bgCtx, cacheKey); err != nil { + slog.Warn("openai_token_provider.cache_delete_failed", + "account_id", account.ID, + "error", err, + ) + } + } + slog.Warn("openai_token_provider.account_disabled_missing_refresh_token", + "account_id", account.ID, + "reason", reason, + ) +} + func (p *OpenAITokenProvider) waitForTokenAfterLockRace(ctx context.Context, cacheKey string) (string, error) { wait := openAILockInitialWait totalWaitMs := int64(0) diff --git a/backend/internal/service/openai_token_provider_test.go b/backend/internal/service/openai_token_provider_test.go index 4b69db8ae45..df2f0f3e6b7 100644 --- a/backend/internal/service/openai_token_provider_test.go +++ b/backend/internal/service/openai_token_provider_test.go @@ -930,3 +930,34 @@ func TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure(t *testing.T) { require.GreaterOrEqual(t, metrics.LockAcquireFailure, int64(1)) require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1)) } + +func TestOpenAITokenProvider_NoRefreshTokenExpired_DisablesAccount(t *testing.T) { + cache := newOpenAITokenCacheStub() + repo := &rateLimitAccountRepoStub{} + + expiresAt := time.Now().Add(-time.Minute).UTC().Format(time.RFC3339) + account := &Account{ + ID: 2881, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "expired-access-token", + "expires_at": expiresAt, + }, + } + + cacheKey := OpenAITokenCacheKey(account) + cache.tokens[cacheKey] = "stale-cached-token" + // Force the provider past the cache hit branch. + cache.getErr = errors.New("simulated cache miss") + + provider := NewOpenAITokenProvider(repo, cache, nil) + + token, err := provider.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Empty(t, token) + require.Contains(t, err.Error(), "refresh_token is missing") + + require.Equal(t, 1, repo.setErrorCalls, "account should be disabled via SetError exactly once") + require.Contains(t, repo.lastErrorMsg, "refresh_token is missing") +} diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 77cf7d95a3f..920a2239ece 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -399,15 +399,9 @@ func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIU if usage == nil || len(message) == 0 { return } - values := gjson.GetManyBytes( - message, - "response.usage.input_tokens", - "response.usage.output_tokens", - "response.usage.input_tokens_details.cached_tokens", - ) - usage.InputTokens = int(values[0].Int()) - usage.OutputTokens = int(values[1].Int()) - usage.CacheReadInputTokens = int(values[2].Int()) + if parsedUsage, ok := extractOpenAIUsageFromJSONBytes(message); ok { + *usage = parsedUsage + } } func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { @@ -2351,18 +2345,19 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( ) return &OpenAIForwardResult{ - RequestID: responseID, - Usage: *usage, - Model: originalModel, - UpstreamModel: mappedModel, - ImageCount: imageCounter.Count(), - ServiceTier: extractOpenAIServiceTier(reqBody), - ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), - Stream: reqStream, - OpenAIWSMode: true, - ResponseHeaders: lease.HandshakeHeaders(), - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: responseID, + Usage: *usage, + Model: originalModel, + UpstreamModel: mappedModel, + ImageCount: imageCounter.Count(), + ImageOutputSizes: imageCounter.Sizes(), + ServiceTier: extractOpenAIServiceTier(reqBody), + ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), + Stream: reqStream, + OpenAIWSMode: true, + ResponseHeaders: lease.HandshakeHeaders(), + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, }, nil } @@ -2464,6 +2459,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( originalModel string imageBillingModel string imageSizeTier string + imageInputSize string payloadBytes int } @@ -2567,12 +2563,16 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } imageBillingModel := "" imageSizeTier := "" + imageInputSize := "" if imageIntent { var imageCfgErr error - imageBillingModel, imageSizeTier, imageCfgErr = resolveOpenAIResponsesImageBillingConfigFromBody(normalized, originalModel) + imageCfg, imageCfgErr := resolveOpenAIResponsesImageBillingConfigDetailedFromBody(normalized, originalModel) if imageCfgErr != nil { return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, imageCfgErr.Error(), imageCfgErr) } + imageBillingModel = imageCfg.Model + imageSizeTier = imageCfg.SizeTier + imageInputSize = imageCfg.InputSize } // Apply OpenAI Fast Policy on the response.create frame using the same @@ -2621,6 +2621,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( originalModel: originalModel, imageBillingModel: imageBillingModel, imageSizeTier: imageSizeTier, + imageInputSize: imageInputSize, payloadBytes: len(normalized), }, nil } @@ -2822,7 +2823,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( return payload, nil } - sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string, imageBillingModel string, imageSizeTier string) (*OpenAIForwardResult, error) { + sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string, imageBillingModel string, imageSizeTier string, imageInputSize string) (*OpenAIForwardResult, error) { if lease == nil { return nil, errors.New("upstream websocket lease is nil") } @@ -3046,6 +3047,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( if imageCount > 0 { result.ImageCount = imageCount result.ImageSize = imageSizeTier + result.ImageInputSize = imageInputSize + result.ImageOutputSizes = imageCounter.Sizes() result.BillingModel = imageBillingModel } return result, nil @@ -3057,6 +3060,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( currentOriginalModel := firstPayload.originalModel currentImageBillingModel := firstPayload.imageBillingModel currentImageSizeTier := firstPayload.imageSizeTier + currentImageInputSize := firstPayload.imageInputSize currentPayloadBytes := firstPayload.payloadBytes isStrictAffinityTurn := func(payload []byte) bool { if !storeDisabled { @@ -3534,7 +3538,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( ) } - result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel, currentImageBillingModel, currentImageSizeTier) + result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel, currentImageBillingModel, currentImageSizeTier, currentImageInputSize) if relayErr != nil { lastTurnClean = false if recoverIngressPrevResponseNotFound(relayErr, turn, connID) { @@ -3658,6 +3662,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( currentOriginalModel = nextPayload.originalModel currentImageBillingModel = nextPayload.imageBillingModel currentImageSizeTier = nextPayload.imageSizeTier + currentImageInputSize = nextPayload.imageInputSize currentPayloadBytes = nextPayload.payloadBytes storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account) if !storeDisabled { diff --git a/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go index 761676038d1..0350bde9868 100644 --- a/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go +++ b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go @@ -29,6 +29,14 @@ func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) { require.Equal(t, 11, usage.InputTokens) require.Equal(t, 7, usage.OutputTokens) require.Equal(t, 3, usage.CacheReadInputTokens) + + parseOpenAIWSResponseUsageFromCompletedEvent( + []byte(`{"type":"response.completed","response":{"usage":{"prompt_tokens":19,"completion_tokens":5,"prompt_tokens_details":{"cached_tokens":4}}}}`), + usage, + ) + require.Equal(t, 19, usage.InputTokens) + require.Equal(t, 5, usage.OutputTokens) + require.Equal(t, 4, usage.CacheReadInputTokens) } func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) { diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay.go b/backend/internal/service/openai_ws_v2/passthrough_relay.go index af8ee195680..2b7e2add750 100644 --- a/backend/internal/service/openai_ws_v2/passthrough_relay.go +++ b/backend/internal/service/openai_ws_v2/passthrough_relay.go @@ -82,6 +82,7 @@ type relayState struct { terminalEventType string firstTokenMs *int turnTimingByID map[string]*relayTurnTiming + activeTurn *relayTurnTiming } type relayExitSignal struct { @@ -550,6 +551,12 @@ func observeUpstreamMessage( if ms >= 0 { state.firstTokenMs = &ms } + if state.activeTurn != nil && state.activeTurn.firstTokenMs == nil { + tms := int(now.Sub(state.activeTurn.startAt).Milliseconds()) + if tms >= 0 { + state.activeTurn.firstTokenMs = &tms + } + } } parsedUsage := parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure) observed := observedUpstreamEvent{ @@ -622,6 +629,7 @@ func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now if !ok || timing == nil || timing.startAt.IsZero() { timing = &relayTurnTiming{startAt: now} state.turnTimingByID[responseID] = timing + state.activeTurn = timing return timing } return timing @@ -636,6 +644,9 @@ func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayT return relayTurnTiming{}, false } delete(state.turnTimingByID, responseID) + if state.activeTurn == timing { + state.activeTurn = nil + } return *timing, true } diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay_test.go b/backend/internal/service/openai_ws_v2/passthrough_relay_test.go index ff9b73111d4..cdd41a058ec 100644 --- a/backend/internal/service/openai_ws_v2/passthrough_relay_test.go +++ b/backend/internal/service/openai_ws_v2/passthrough_relay_test.go @@ -750,3 +750,67 @@ func (c *errorOnWriteFrameConn) WriteFrame(_ context.Context, _ coderws.MessageT func (c *errorOnWriteFrameConn) Close() error { return nil } + +func TestRelay_OnTurnComplete_RealOpenAIStream_FirstTokenMs(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.created","response":{"id":"resp_real"}}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","delta":"He"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","delta":"llo"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","delta":" world"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_real","usage":{"input_tokens":2,"output_tokens":3}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + base := time.Unix(0, 0) + var nowTick atomic.Int64 + nowFn := func() time.Time { + step := nowTick.Add(1) + return base.Add(time.Duration(step) * 10 * time.Millisecond) + } + + var turn RelayTurnResult + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + Now: nowFn, + OnTurnComplete: func(current RelayTurnResult) { + turn = current + }, + }) + require.Nil(t, relayExit) + require.Equal(t, "resp_real", turn.RequestID) + require.Equal(t, "response.completed", turn.TerminalEventType) + + require.NotNil(t, turn.FirstTokenMs, "per-turn FirstTokenMs must be captured for real OpenAI streams") + require.Greater(t, turn.Duration.Milliseconds(), int64(0)) + + require.Less(t, + int64(*turn.FirstTokenMs), + turn.Duration.Milliseconds(), + "per-turn FirstTokenMs (%dms) should be strictly less than Duration (%dms); "+ + "equality indicates the bug where first_token is mistakenly stamped on the terminal event", + *turn.FirstTokenMs, turn.Duration.Milliseconds(), + ) + + require.NotNil(t, result.FirstTokenMs) + require.Greater(t, *result.FirstTokenMs, 0) +} diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go index e27607259b4..0a89e2dde1a 100644 --- a/backend/internal/service/openai_ws_v2_passthrough_adapter.go +++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go @@ -267,9 +267,8 @@ func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( // omits "model" — Realtime clients are allowed to send response.create // without re-stating the model, in which case the upstream uses the model // negotiated at session.update time. Without this fallback, an empty - // model would miss the default ["gpt-5.5","gpt-5.5*"] whitelist and be - // silently passed through, defeating the policy on every frame after - // the first. + // model would miss any admin-configured model whitelist and be silently + // passed through, defeating that policy on every frame after the first. capturedSessionModel := openAIWSPassthroughPolicyModelForFrame(account, firstClientMessage) initialRequestModel := "" if hooks != nil { diff --git a/backend/internal/service/ops_upstream_context.go b/backend/internal/service/ops_upstream_context.go index 05d444e186c..5c8ac5a6818 100644 --- a/backend/internal/service/ops_upstream_context.go +++ b/backend/internal/service/ops_upstream_context.go @@ -36,6 +36,12 @@ const ( // OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。 // ops_error_logger 中间件检查此 key,为 true 时跳过错误记录。 OpsSkipPassthroughKey = "ops_skip_passthrough" + + // Client-side configuration denials should remain visible in ops_error_logs, + // but should be excluded from SLA/error-rate calculations. + OpsClientBusinessLimitedKey = "ops_client_business_limited" + OpsClientBusinessLimitedReasonKey = "ops_client_business_limited_reason" + OpsClientBusinessLimitedReasonIPRestriction = "api_key_ip_restriction" ) func setOpsUpstreamRequestBody(c *gin.Context, body []byte) { @@ -53,6 +59,28 @@ func SetOpsLatencyMs(c *gin.Context, key string, value int64) { c.Set(key, value) } +func MarkOpsClientBusinessLimited(c *gin.Context, reason string) { + if c == nil { + return + } + c.Set(OpsClientBusinessLimitedKey, true) + if reason = strings.TrimSpace(reason); reason != "" { + c.Set(OpsClientBusinessLimitedReasonKey, reason) + } +} + +func HasOpsClientBusinessLimited(c *gin.Context) bool { + if c == nil { + return false + } + v, ok := c.Get(OpsClientBusinessLimitedKey) + if !ok { + return false + } + marked, _ := v.(bool) + return marked +} + // SetOpsUpstreamError is the exported wrapper for setOpsUpstreamError, used by // handler-layer code (e.g. failover-exhausted paths) that needs to record the // original upstream status code before mapping it to a client-facing code. diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go index 056967f0388..e6cc4b3ca6b 100644 --- a/backend/internal/service/payment_order.go +++ b/backend/internal/service/payment_order.go @@ -499,22 +499,39 @@ func selectedInstanceSupportedTypes(sel *payment.InstanceSelection) string { func (s *PaymentService) buildPaymentSubject(plan *dbent.SubscriptionPlan, limitAmount float64, cfg *PaymentConfig, sel *payment.InstanceSelection) string { if plan != nil { - if plan.ProductName != "" { - return plan.ProductName + productName := plan.ProductName + if productName == "" { + productName = "Sub2API Subscription " + plan.Name } - return "Sub2API Subscription " + plan.Name + return applyPaymentProductNameAffix(productName, cfg) } currency := payment.DefaultPaymentCurrency if sel != nil { currency = paymentProviderConfigCurrency(sel.ProviderKey, sel.Config) } amountStr := payment.FormatAmountForCurrency(limitAmount, currency) + if hasPaymentProductNameAffix(cfg) { + return applyPaymentProductNameAffix(amountStr, cfg) + } + return "Sub2API " + amountStr + " " + currency +} + +func hasPaymentProductNameAffix(cfg *PaymentConfig) bool { + if cfg == nil { + return false + } pf := strings.TrimSpace(cfg.ProductNamePrefix) sf := strings.TrimSpace(cfg.ProductNameSuffix) - if pf != "" || sf != "" { - return strings.TrimSpace(pf + " " + amountStr + " " + sf) + return pf != "" || sf != "" +} + +func applyPaymentProductNameAffix(productName string, cfg *PaymentConfig) string { + if !hasPaymentProductNameAffix(cfg) { + return productName } - return "Sub2API " + amountStr + " " + currency + pf := strings.TrimSpace(cfg.ProductNamePrefix) + sf := strings.TrimSpace(cfg.ProductNameSuffix) + return strings.TrimSpace(pf + " " + productName + " " + sf) } func (s *PaymentService) maybeBuildWeChatOAuthRequiredResponse(ctx context.Context, req CreateOrderRequest, amount, payAmount, feeRate float64) (*CreateOrderResponse, error) { diff --git a/backend/internal/service/payment_order_expiry_service.go b/backend/internal/service/payment_order_expiry_service.go index b0cda3e5914..32e51d7fb26 100644 --- a/backend/internal/service/payment_order_expiry_service.go +++ b/backend/internal/service/payment_order_expiry_service.go @@ -59,10 +59,18 @@ func (s *PaymentOrderExpiryService) Stop() { } func (s *PaymentOrderExpiryService) runOnce() { - ctx, cancel := context.WithTimeout(context.Background(), expiryCheckTimeout) - defer cancel() + reconcileCtx, cancel := context.WithTimeout(context.Background(), expiryCheckTimeout) + recovered, err := s.paymentSvc.ReconcilePendingWxpayOrders(reconcileCtx) + cancel() + if err != nil { + slog.Warn("[PaymentOrderExpiry] failed to reconcile pending wxpay orders", "error", err) + } else if recovered > 0 { + slog.Info("[PaymentOrderExpiry] reconciled paid wxpay orders", "count", recovered) + } - expired, err := s.paymentSvc.ExpireTimedOutOrders(ctx) + expireCtx, cancel := context.WithTimeout(context.Background(), expiryCheckTimeout) + defer cancel() + expired, err := s.paymentSvc.ExpireTimedOutOrders(expireCtx) if err != nil { slog.Error("[PaymentOrderExpiry] failed to expire orders", "error", err) return diff --git a/backend/internal/service/payment_order_lifecycle.go b/backend/internal/service/payment_order_lifecycle.go index b627ced4ecc..ffe120d0166 100644 --- a/backend/internal/service/payment_order_lifecycle.go +++ b/backend/internal/service/payment_order_lifecycle.go @@ -26,8 +26,14 @@ const ( rateLimitModeFixed = "fixed" checkPaidResultAlreadyPaid = "already_paid" checkPaidResultCancelled = "cancelled" + + pendingWxpayReconcileLimit = 20 ) +type checkPaidOptions struct { + cancelIfUnpaid bool +} + func (s *PaymentService) checkCancelRateLimit(ctx context.Context, userID int64, cfg *PaymentConfig) error { if !cfg.CancelRateLimitEnabled || cfg.CancelRateLimitMax <= 0 { return nil @@ -136,6 +142,14 @@ func (s *PaymentService) cancelCore(ctx context.Context, o *dbent.PaymentOrder, } func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) string { + return s.checkPaidWithOptions(ctx, o, checkPaidOptions{cancelIfUnpaid: true}) +} + +func (s *PaymentService) reconcilePaid(ctx context.Context, o *dbent.PaymentOrder) string { + return s.checkPaidWithOptions(ctx, o, checkPaidOptions{}) +} + +func (s *PaymentService) checkPaidWithOptions(ctx context.Context, o *dbent.PaymentOrder, opts checkPaidOptions) string { prov, err := s.getOrderProvider(ctx, o) if err != nil { return "" @@ -182,6 +196,9 @@ func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) s } return checkPaidResultAlreadyPaid } + if !opts.cancelIfUnpaid { + return "" + } if cp, ok := prov.(payment.CancelableProvider); ok { _ = cp.CancelPayment(ctx, queryRef) } @@ -268,7 +285,7 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo } // Only verify orders that are still pending or recently expired if o.Status == OrderStatusPending || o.Status == OrderStatusExpired { - result := s.checkPaid(ctx, o) + result := s.reconcilePaid(ctx, o) if result == checkPaidResultAlreadyPaid { // Reload order to get updated status o, err = s.entClient.PaymentOrder.Get(ctx, o.ID) @@ -280,6 +297,37 @@ func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo return o, nil } +// ReconcilePendingWxpayOrders actively checks recent pending WeChat orders so +// missed provider notifications do not wait until order expiry to fulfill. +func (s *PaymentService) ReconcilePendingWxpayOrders(ctx context.Context) (int, error) { + now := time.Now() + orders, err := s.entClient.PaymentOrder.Query(). + Where( + paymentorder.StatusEQ(OrderStatusPending), + paymentorder.ExpiresAtGT(now), + paymentorder.Or( + paymentorder.PaymentTypeEQ(payment.TypeWxpay), + paymentorder.PaymentTypeHasPrefix(payment.TypeWxpay+"_"), + paymentorder.ProviderKeyEQ(payment.TypeWxpay), + paymentorder.ProviderKeyHasPrefix(payment.TypeWxpay+"_"), + ), + ). + Order(dbent.Asc(paymentorder.FieldCreatedAt)). + Limit(pendingWxpayReconcileLimit). + All(ctx) + if err != nil { + return 0, fmt.Errorf("query pending wxpay orders: %w", err) + } + + recovered := 0 + for _, order := range orders { + if s.reconcilePaid(ctx, order) == checkPaidResultAlreadyPaid { + recovered++ + } + } + return recovered, nil +} + // VerifyOrderPublic returns the currently persisted public order state without // triggering any upstream reconciliation. Signed resume-token recovery is the // only public recovery path allowed to query upstream state. diff --git a/backend/internal/service/payment_order_lifecycle_test.go b/backend/internal/service/payment_order_lifecycle_test.go index d8595715a53..1964cdf6f90 100644 --- a/backend/internal/service/payment_order_lifecycle_test.go +++ b/backend/internal/service/payment_order_lifecycle_test.go @@ -20,10 +20,13 @@ import ( ) type paymentOrderLifecycleQueryProvider struct { - lastQueryTradeNo string - queryCalls int - responses []*payment.QueryOrderResponse - resp *payment.QueryOrderResponse + key string + lastQueryTradeNo string + lastCancelTradeNo string + queryCalls int + cancelCalls int + responses []*payment.QueryOrderResponse + resp *payment.QueryOrderResponse } type paymentOrderLifecycleRedeemRepo struct { @@ -38,10 +41,15 @@ func (p *paymentOrderLifecycleQueryProvider) Name() string { return "payment-order-lifecycle-query-provider" } -func (p *paymentOrderLifecycleQueryProvider) ProviderKey() string { return payment.TypeAlipay } +func (p *paymentOrderLifecycleQueryProvider) ProviderKey() string { + if p.key != "" { + return p.key + } + return payment.TypeAlipay +} func (p *paymentOrderLifecycleQueryProvider) SupportedTypes() []payment.PaymentType { - return []payment.PaymentType{payment.TypeAlipay} + return []payment.PaymentType{p.ProviderKey()} } func (p *paymentOrderLifecycleQueryProvider) CreatePayment(context.Context, payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { @@ -69,6 +77,12 @@ func (p *paymentOrderLifecycleQueryProvider) Refund(context.Context, payment.Ref panic("unexpected call") } +func (p *paymentOrderLifecycleQueryProvider) CancelPayment(_ context.Context, tradeNo string) error { + p.lastCancelTradeNo = tradeNo + p.cancelCalls++ + return nil +} + func (r *paymentOrderLifecycleRedeemRepo) Create(context.Context, *RedeemCode) error { panic("unexpected call") } @@ -435,6 +449,222 @@ func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) { require.Empty(t, redeemRepo.useCalls) } +func TestVerifyOrderByOutTradeNoDoesNotCancelUnpaidUpstreamOrder(t *testing.T) { + ctx := context.Background() + client := newPaymentOrderLifecycleTestClient(t) + + user, err := client.User.Create(). + SetEmail("checkpaid-pending@example.com"). + SetPasswordHash("hash"). + SetUsername("checkpaid-pending-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("CHECKPAID-PENDING"). + SetOutTradeNo("sub2_checkpaid_pending"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(ctx) + require.NoError(t, err) + + registry := payment.NewRegistry() + provider := &paymentOrderLifecycleQueryProvider{ + resp: &payment.QueryOrderResponse{ + TradeNo: order.OutTradeNo, + Status: payment.ProviderStatusPending, + Amount: 0, + }, + } + registry.Register(provider) + + svc := &PaymentService{ + entClient: client, + registry: registry, + providersLoaded: true, + } + + got, err := svc.VerifyOrderByOutTradeNo(ctx, order.OutTradeNo, user.ID) + require.NoError(t, err) + require.Equal(t, OrderStatusPending, got.Status) + require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo) + require.Zero(t, provider.cancelCalls) + + reloaded, err := client.PaymentOrder.Get(ctx, order.ID) + require.NoError(t, err) + require.Equal(t, OrderStatusPending, reloaded.Status) +} + +func TestCancelOrderStillClosesUnpaidUpstreamOrder(t *testing.T) { + ctx := context.Background() + client := newPaymentOrderLifecycleTestClient(t) + + user, err := client.User.Create(). + SetEmail("cancel-pending@example.com"). + SetPasswordHash("hash"). + SetUsername("cancel-pending-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(88). + SetPayAmount(88). + SetFeeRate(0). + SetRechargeCode("CANCEL-PENDING"). + SetOutTradeNo("sub2_cancel_pending"). + SetPaymentType(payment.TypeAlipay). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(ctx) + require.NoError(t, err) + + registry := payment.NewRegistry() + provider := &paymentOrderLifecycleQueryProvider{ + resp: &payment.QueryOrderResponse{ + TradeNo: order.OutTradeNo, + Status: payment.ProviderStatusPending, + Amount: 0, + }, + } + registry.Register(provider) + + svc := &PaymentService{ + entClient: client, + registry: registry, + providersLoaded: true, + } + + outcome, err := svc.CancelOrder(ctx, order.ID, user.ID) + require.NoError(t, err) + require.Equal(t, checkPaidResultCancelled, outcome) + require.Equal(t, order.OutTradeNo, provider.lastCancelTradeNo) + require.Equal(t, 1, provider.cancelCalls) + + reloaded, err := client.PaymentOrder.Get(ctx, order.ID) + require.NoError(t, err) + require.Equal(t, OrderStatusCancelled, reloaded.Status) +} + +func TestReconcilePendingWxpayOrdersBackfillsPaidOrder(t *testing.T) { + ctx := context.Background() + client := newPaymentOrderLifecycleTestClient(t) + + user, err := client.User.Create(). + SetEmail("wxpay-reconcile@example.com"). + SetPasswordHash("hash"). + SetUsername("wxpay-reconcile-user"). + Save(ctx) + require.NoError(t, err) + + order, err := client.PaymentOrder.Create(). + SetUserID(user.ID). + SetUserEmail(user.Email). + SetUserName(user.Username). + SetAmount(50). + SetPayAmount(50). + SetFeeRate(0). + SetRechargeCode("WXPAY-RECONCILE"). + SetOutTradeNo("sub2_wxpay_reconcile"). + SetPaymentType(payment.TypeWxpay). + SetPaymentTradeNo(""). + SetOrderType(payment.OrderTypeBalance). + SetStatus(OrderStatusPending). + SetExpiresAt(time.Now().Add(time.Hour)). + SetClientIP("127.0.0.1"). + SetSrcHost("api.example.com"). + Save(ctx) + require.NoError(t, err) + + userRepo := &mockUserRepo{ + getByIDUser: &User{ + ID: user.ID, + Email: user.Email, + Username: user.Username, + Balance: 0, + }, + } + userRepo.updateBalanceFn = func(ctx context.Context, id int64, amount float64) error { + require.Equal(t, user.ID, id) + if userRepo.getByIDUser != nil { + userRepo.getByIDUser.Balance += amount + } + return nil + } + redeemRepo := &paymentOrderLifecycleRedeemRepo{ + codesByCode: map[string]*RedeemCode{ + order.RechargeCode: { + ID: 1, + Code: order.RechargeCode, + Type: RedeemTypeBalance, + Value: order.Amount, + Status: StatusUnused, + }, + }, + } + redeemService := NewRedeemService( + redeemRepo, + userRepo, + nil, + nil, + nil, + client, + nil, + nil, + ) + registry := payment.NewRegistry() + provider := &paymentOrderLifecycleQueryProvider{ + key: payment.TypeWxpay, + resp: &payment.QueryOrderResponse{ + TradeNo: "wxpay-upstream-trade-123", + Status: payment.ProviderStatusPaid, + Amount: 50, + Metadata: map[string]string{ + "trade_state": "SUCCESS", + }, + }, + } + registry.Register(provider) + + svc := &PaymentService{ + entClient: client, + registry: registry, + redeemService: redeemService, + userRepo: userRepo, + providersLoaded: true, + } + + recovered, err := svc.ReconcilePendingWxpayOrders(ctx) + require.NoError(t, err) + require.Equal(t, 1, recovered) + require.Equal(t, order.OutTradeNo, provider.lastQueryTradeNo) + require.Zero(t, provider.cancelCalls) + + reloaded, err := client.PaymentOrder.Get(ctx, order.ID) + require.NoError(t, err) + require.Equal(t, OrderStatusCompleted, reloaded.Status) + require.Equal(t, "wxpay-upstream-trade-123", reloaded.PaymentTradeNo) + require.Equal(t, 50.0, userRepo.getByIDUser.Balance) + require.Len(t, redeemRepo.useCalls, 1) +} + func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsForAlipay(t *testing.T) { ctx := context.Background() client := newPaymentOrderLifecycleTestClient(t) diff --git a/backend/internal/service/payment_order_result_test.go b/backend/internal/service/payment_order_result_test.go index f78d6b37fec..bfe275481cf 100644 --- a/backend/internal/service/payment_order_result_test.go +++ b/backend/internal/service/payment_order_result_test.go @@ -138,6 +138,41 @@ func TestCalculateCreateOrderPayAmountRejectsFractionalZeroDecimal(t *testing.T) } } +func TestBuildPaymentSubjectAppliesAffixToSubscriptionPlanProductName(t *testing.T) { + t.Parallel() + + svc := &PaymentService{} + cfg := &PaymentConfig{ + ProductNamePrefix: "PRE", + ProductNameSuffix: "SUF", + } + plan := &dbent.SubscriptionPlan{ + Name: "Pro Monthly", + ProductName: "Claude Pro", + } + + got := svc.buildPaymentSubject(plan, 0, cfg, nil) + if got != "PRE Claude Pro SUF" { + t.Fatalf("buildPaymentSubject() = %q, want %q", got, "PRE Claude Pro SUF") + } +} + +func TestBuildPaymentSubjectAppliesAffixToSubscriptionPlanDefaultName(t *testing.T) { + t.Parallel() + + svc := &PaymentService{} + cfg := &PaymentConfig{ + ProductNamePrefix: "PRE", + ProductNameSuffix: "SUF", + } + plan := &dbent.SubscriptionPlan{Name: "Team Monthly"} + + got := svc.buildPaymentSubject(plan, 0, cfg, nil) + if got != "PRE Sub2API Subscription Team Monthly SUF" { + t.Fatalf("buildPaymentSubject() = %q, want %q", got, "PRE Sub2API Subscription Team Monthly SUF") + } +} + func TestMaybeBuildWeChatOAuthRequiredResponse(t *testing.T) { t.Setenv("PAYMENT_RESUME_SIGNING_KEY", "0123456789abcdef0123456789abcdef") diff --git a/backend/internal/service/payment_resume_lookup.go b/backend/internal/service/payment_resume_lookup.go index 1ff061e8cac..fb41ced4ed3 100644 --- a/backend/internal/service/payment_resume_lookup.go +++ b/backend/internal/service/payment_resume_lookup.go @@ -46,7 +46,7 @@ func (s *PaymentService) GetPublicOrderByResumeToken(ctx context.Context, token return nil, invalidResumeTokenMatchError() } if order.Status == OrderStatusPending || order.Status == OrderStatusExpired { - result := s.checkPaid(ctx, order) + result := s.reconcilePaid(ctx, order) if result == checkPaidResultAlreadyPaid { order, err = s.entClient.PaymentOrder.Get(ctx, order.ID) if err != nil { diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go index 3c3e2c5ba6c..d0b46886103 100644 --- a/backend/internal/service/pricing_service_test.go +++ b/backend/internal/service/pricing_service_test.go @@ -2,6 +2,8 @@ package service import ( "encoding/json" + "os" + "path/filepath" "testing" "github.com/stretchr/testify/require" @@ -111,6 +113,22 @@ func TestGetModelPricing_OpenAICompactAliasUsesStaticFallback(t *testing.T) { require.InDelta(t, 1.5e-5, got.OutputCostPerToken, 1e-12) } +func TestDefaultPricingIncludesCodexAutoReview(t *testing.T) { + data, err := os.ReadFile(filepath.Join("..", "..", "resources", "model-pricing", "model_prices_and_context_window.json")) + require.NoError(t, err) + + svc := &PricingService{} + pricingData, err := svc.parsePricingData(data) + require.NoError(t, err) + svc.pricingData = pricingData + + got := svc.GetModelPricing("codex-auto-review") + require.NotNil(t, got) + require.InDelta(t, 2.5e-6, got.InputCostPerToken, 1e-12) + require.InDelta(t, 1.5e-5, got.OutputCostPerToken, 1e-12) + require.InDelta(t, 2.5e-7, got.CacheReadInputTokenCost, 1e-12) +} + func TestGetModelPricing_Gpt54MiniUsesDedicatedStaticFallbackWhenRemoteMissing(t *testing.T) { svc := &PricingService{ pricingData: map[string]*LiteLLMModelPricing{ diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 19c45a5a7eb..892d9aca86b 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -209,6 +209,17 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err) } } + // 缺少 refresh_token 的 OAuth 账号无法在冷却期内自愈(后台刷新服务也会跳过), + // 直接走 SetError 永久禁用,避免冷却结束后再被选中产生一发无意义的 502。 + if strings.TrimSpace(account.GetCredential("refresh_token")) == "" { + msg := "Authentication failed (401): refresh_token missing, cannot recover" + if upstreamMsg != "" { + msg = "OAuth 401 (no refresh_token): " + upstreamMsg + } + s.handleAuthError(ctx, account, msg) + shouldDisable = true + break + } // 2. 设置 expires_at 为当前时间,强制下次请求刷新 token if account.Credentials == nil { account.Credentials = make(map[string]any) diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go index 73b7849fbac..a964775eaca 100644 --- a/backend/internal/service/ratelimit_service_401_test.go +++ b/backend/internal/service/ratelimit_service_401_test.go @@ -85,6 +85,7 @@ func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *t Platform: PlatformGemini, Type: AccountTypeOAuth, Credentials: map[string]any{ + "refresh_token": "rt-100", "temp_unschedulable_enabled": true, "temp_unschedulable_rules": []any{ map[string]any{ @@ -138,6 +139,9 @@ func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testin ID: 101, Platform: PlatformOpenAI, Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt-101", + }, } shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) @@ -175,7 +179,8 @@ func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t * Platform: PlatformOpenAI, Type: AccountTypeOAuth, Credentials: map[string]any{ - "access_token": "token", + "access_token": "token", + "refresh_token": "rt-103", }, } @@ -185,3 +190,52 @@ func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t * require.Equal(t, 1, repo.updateCredentialsCalls) require.NotEmpty(t, repo.lastCredentials["expires_at"]) } + +// 缺少 refresh_token 的 OAuth 账号 401 应直接 SetError 永久禁用, +// 不再走 10 分钟冷却(冷却期内无人能刷新它,结束后还会被选中再 502 一次)。 +func TestRateLimitService_HandleUpstreamError_OAuth401NoRefreshTokenSetsError(t *testing.T) { + t.Run("openai_no_refresh_token", func(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + invalidator := &tokenCacheInvalidatorRecorder{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + service.SetTokenCacheInvalidator(invalidator) + account := &Account{ + ID: 2881, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "expired-at", + // no refresh_token + }, + } + + shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls, "AT-only OAuth 401 must SetError") + require.Equal(t, 0, repo.tempCalls, "AT-only OAuth 401 must NOT temp-unschedule") + require.Equal(t, 0, repo.updateCredentialsCalls, "no point forcing expires_at when refresh is impossible") + require.Contains(t, repo.lastErrorMsg, "refresh_token missing") + require.Len(t, invalidator.accounts, 1, "cache should still be invalidated") + }) + + t.Run("openai_blank_refresh_token_treated_as_missing", func(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 2882, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "expired-at", + "refresh_token": " ", + }, + } + + shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls) + require.Equal(t, 0, repo.tempCalls) + }) +} diff --git a/backend/internal/service/redeem_code.go b/backend/internal/service/redeem_code.go index a66b53bad37..55abcfb3e1d 100644 --- a/backend/internal/service/redeem_code.go +++ b/backend/internal/service/redeem_code.go @@ -16,6 +16,7 @@ type RedeemCode struct { UsedAt *time.Time Notes string CreatedAt time.Time + ExpiresAt *time.Time GroupID *int64 ValidityDays int @@ -28,8 +29,22 @@ func (r *RedeemCode) IsUsed() bool { return r.Status == StatusUsed } +func (r *RedeemCode) IsExpired() bool { + return r.IsExpiredAt(time.Now()) +} + +func (r *RedeemCode) IsExpiredAt(now time.Time) bool { + if r == nil { + return false + } + if r.Status == StatusExpired { + return true + } + return r.Status == StatusUnused && r.ExpiresAt != nil && !r.ExpiresAt.After(now) +} + func (r *RedeemCode) CanUse() bool { - return r.Status == StatusUnused + return r.Status == StatusUnused && !r.IsExpired() } func GenerateRedeemCode() (string, error) { diff --git a/backend/internal/service/redeem_code_test.go b/backend/internal/service/redeem_code_test.go new file mode 100644 index 00000000000..ba5c7e7cc24 --- /dev/null +++ b/backend/internal/service/redeem_code_test.go @@ -0,0 +1,59 @@ +package service + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestRedeemCodeExpiry(t *testing.T) { + now := time.Now().UTC() + past := now.Add(-time.Hour) + future := now.Add(time.Hour) + + tests := []struct { + name string + code RedeemCode + wantExpired bool + wantCanUse bool + }{ + { + name: "unused without expiry can be used", + code: RedeemCode{Status: StatusUnused}, + wantExpired: false, + wantCanUse: true, + }, + { + name: "unused before expiry can be used", + code: RedeemCode{Status: StatusUnused, ExpiresAt: &future}, + wantExpired: false, + wantCanUse: true, + }, + { + name: "unused after expiry cannot be used", + code: RedeemCode{Status: StatusUnused, ExpiresAt: &past}, + wantExpired: true, + wantCanUse: false, + }, + { + name: "explicit expired status is expired", + code: RedeemCode{Status: StatusExpired}, + wantExpired: true, + wantCanUse: false, + }, + { + name: "used code remains used even after expiry time", + code: RedeemCode{Status: StatusUsed, ExpiresAt: &past}, + wantExpired: false, + wantCanUse: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.wantExpired, tt.code.IsExpiredAt(now)) + require.Equal(t, tt.wantCanUse, tt.code.CanUse()) + }) + } +} diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index dcf293c565a..73aa02b136e 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -18,6 +18,7 @@ import ( var ( ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found") ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used") + ErrRedeemCodeExpired = infraerrors.Conflict("REDEEM_CODE_EXPIRED", "redeem code expired") ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance") ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later") ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again") @@ -207,6 +208,9 @@ func (s *RedeemService) CreateCode(ctx context.Context, code *RedeemCode) error if code.Status == "" { code.Status = StatusUnused } + if code.IsExpired() { + return ErrRedeemCodeExpired + } if err := s.redeemRepo.Create(ctx, code); err != nil { return fmt.Errorf("create redeem code: %w", err) @@ -289,7 +293,11 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( return nil, fmt.Errorf("get redeem code: %w", err) } - // 检查兑换码状态 + // 检查兑换码状态和码本身的过期时间 + if redeemCode.IsExpired() { + s.incrementRedeemErrorCount(ctx, userID) + return nil, ErrRedeemCodeExpired + } if !redeemCode.CanUse() { s.incrementRedeemErrorCount(ctx, userID) return nil, ErrRedeemCodeUsed diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 86978eecc4a..a5c16b1fe5d 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -24,6 +24,25 @@ import ( "golang.org/x/sync/singleflight" ) +// CoerceDingTalkCorpPolicyForWrite 是 coerceDeprecatedDingTalkCorpPolicy 的导出版本, +// 用于 admin handler 在写入路径上对客户端直传的入参做防御性 coerce(前端 UI 虽已无 whitelist 选项, +// 但 API 可被直接调用)。 +func CoerceDingTalkCorpPolicyForWrite(policy string) string { + return coerceDeprecatedDingTalkCorpPolicy(policy) +} + +// coerceDeprecatedDingTalkCorpPolicy 把已废弃的 corp_restriction_policy 值替换成安全的等价值。 +// 升级前残留在 DB 中的 "whitelist" 会导致 callback 链路在 default case 静默 fail-closed +// (所有钉钉登录被拒)。这里统一退化为 "none" 让服务保持可用,并 warn 日志提醒 admin 重新保存设置。 +func coerceDeprecatedDingTalkCorpPolicy(policy string) string { + if policy == "whitelist" { + slog.Warn("dingtalk: corp_restriction_policy=whitelist is deprecated and unsupported, coercing to none", + "hint", "re-save DingTalk settings in admin UI to clear this warning") + return "none" + } + return policy +} + var ( ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") @@ -146,6 +165,7 @@ type AuthSourceDefaultSettings struct { WeChat ProviderDefaultGrantSettings GitHub ProviderDefaultGrantSettings Google ProviderDefaultGrantSettings + DingTalk ProviderDefaultGrantSettings ForceEmailOnThirdPartySignup bool } @@ -200,6 +220,13 @@ var ( grantOnSignup: SettingKeyAuthSourceDefaultGoogleGrantOnSignup, grantOnFirstBind: SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind, } + dingTalkAuthSourceDefaultKeys = authSourceDefaultKeySet{ + balance: SettingKeyAuthSourceDefaultDingTalkBalance, + concurrency: SettingKeyAuthSourceDefaultDingTalkConcurrency, + subscriptions: SettingKeyAuthSourceDefaultDingTalkSubscriptions, + grantOnSignup: SettingKeyAuthSourceDefaultDingTalkGrantOnSignup, + grantOnFirstBind: SettingKeyAuthSourceDefaultDingTalkGrantOnFirstBind, + } ) const ( @@ -606,6 +633,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyCustomMenuItems, SettingKeyCustomEndpoints, SettingKeyLinuxDoConnectEnabled, + SettingKeyDingTalkConnectEnabled, SettingKeyWeChatConnectEnabled, SettingKeyWeChatConnectAppID, SettingKeyWeChatConnectAppSecret, @@ -654,6 +682,12 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings } else { linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled } + dingTalkEnabled := false + if raw, ok := settings[SettingKeyDingTalkConnectEnabled]; ok { + dingTalkEnabled = raw == "true" + } else { + dingTalkEnabled = s.cfg != nil && s.cfg.DingTalk.Enabled + } oidcEnabled := false if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok { oidcEnabled = raw == "true" @@ -723,6 +757,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings CustomMenuItems: settings[SettingKeyCustomMenuItems], CustomEndpoints: settings[SettingKeyCustomEndpoints], LinuxDoOAuthEnabled: linuxDoEnabled, + DingTalkOAuthEnabled: dingTalkEnabled, WeChatOAuthEnabled: weChatEnabled, WeChatOAuthOpenEnabled: weChatOpenEnabled, WeChatOAuthMPEnabled: weChatMPEnabled, @@ -926,6 +961,7 @@ type PublicSettingsInjectionPayload struct { CustomMenuItems json.RawMessage `json:"custom_menu_items"` CustomEndpoints json.RawMessage `json:"custom_endpoints"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + DingTalkOAuthEnabled bool `json:"dingtalk_oauth_enabled"` WeChatOAuthEnabled bool `json:"wechat_oauth_enabled"` WeChatOAuthOpenEnabled bool `json:"wechat_oauth_open_enabled"` WeChatOAuthMPEnabled bool `json:"wechat_oauth_mp_enabled"` @@ -990,6 +1026,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems), CustomEndpoints: safeRawJSONArray(settings.CustomEndpoints), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + DingTalkOAuthEnabled: settings.DingTalkOAuthEnabled, WeChatOAuthEnabled: settings.WeChatOAuthEnabled, WeChatOAuthOpenEnabled: settings.WeChatOAuthOpenEnabled, WeChatOAuthMPEnabled: settings.WeChatOAuthMPEnabled, @@ -1476,6 +1513,26 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret } + // DingTalk Connect OAuth 登录 + updates[SettingKeyDingTalkConnectEnabled] = strconv.FormatBool(settings.DingTalkConnectEnabled) + updates[SettingKeyDingTalkConnectClientID] = settings.DingTalkConnectClientID + updates[SettingKeyDingTalkConnectRedirectURL] = settings.DingTalkConnectRedirectURL + if settings.DingTalkConnectClientSecret != "" { + updates[SettingKeyDingTalkConnectClientSecret] = settings.DingTalkConnectClientSecret + } + updates[SettingKeyDingTalkConnectCorpRestrictionPolicy] = settings.DingTalkConnectCorpRestrictionPolicy + updates[SettingKeyDingTalkConnectInternalCorpID] = settings.DingTalkConnectInternalCorpID + updates[SettingKeyDingTalkConnectBypassRegistration] = strconv.FormatBool(settings.DingTalkConnectBypassRegistration) + updates[SettingKeyDingTalkConnectSyncCorpEmail] = strconv.FormatBool(settings.DingTalkConnectSyncCorpEmail) + updates[SettingKeyDingTalkConnectSyncDisplayName] = strconv.FormatBool(settings.DingTalkConnectSyncDisplayName) + updates[SettingKeyDingTalkConnectSyncDept] = strconv.FormatBool(settings.DingTalkConnectSyncDept) + updates[SettingKeyDingTalkConnectSyncCorpEmailAttrKey] = settings.DingTalkConnectSyncCorpEmailAttrKey + updates[SettingKeyDingTalkConnectSyncDisplayNameAttrKey] = settings.DingTalkConnectSyncDisplayNameAttrKey + updates[SettingKeyDingTalkConnectSyncDeptAttrKey] = settings.DingTalkConnectSyncDeptAttrKey + updates[SettingKeyDingTalkConnectSyncCorpEmailAttrName] = settings.DingTalkConnectSyncCorpEmailAttrName + updates[SettingKeyDingTalkConnectSyncDisplayNameAttrName] = settings.DingTalkConnectSyncDisplayNameAttrName + updates[SettingKeyDingTalkConnectSyncDeptAttrName] = settings.DingTalkConnectSyncDeptAttrName + // Generic OIDC OAuth 登录 updates[SettingKeyOIDCConnectEnabled] = strconv.FormatBool(settings.OIDCConnectEnabled) updates[SettingKeyOIDCConnectProviderName] = settings.OIDCConnectProviderName @@ -1677,19 +1734,21 @@ func (s *SettingService) buildAuthSourceDefaultUpdates(ctx context.Context, sett settings.WeChat.Subscriptions, settings.GitHub.Subscriptions, settings.Google.Subscriptions, + settings.DingTalk.Subscriptions, } { if err := s.validateDefaultSubscriptionGroups(ctx, subscriptions); err != nil { return nil, err } } - updates := make(map[string]string, 31) + updates := make(map[string]string, 36) writeProviderDefaultGrantUpdates(updates, emailAuthSourceDefaultKeys, settings.Email) writeProviderDefaultGrantUpdates(updates, linuxDoAuthSourceDefaultKeys, settings.LinuxDo) writeProviderDefaultGrantUpdates(updates, oidcAuthSourceDefaultKeys, settings.OIDC) writeProviderDefaultGrantUpdates(updates, weChatAuthSourceDefaultKeys, settings.WeChat) writeProviderDefaultGrantUpdates(updates, gitHubAuthSourceDefaultKeys, settings.GitHub) writeProviderDefaultGrantUpdates(updates, googleAuthSourceDefaultKeys, settings.Google) + writeProviderDefaultGrantUpdates(updates, dingTalkAuthSourceDefaultKeys, settings.DingTalk) updates[SettingKeyForceEmailOnThirdPartySignup] = strconv.FormatBool(settings.ForceEmailOnThirdPartySignup) return updates, nil } @@ -2225,6 +2284,11 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut SettingKeyAuthSourceDefaultGoogleSubscriptions, SettingKeyAuthSourceDefaultGoogleGrantOnSignup, SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind, + SettingKeyAuthSourceDefaultDingTalkBalance, + SettingKeyAuthSourceDefaultDingTalkConcurrency, + SettingKeyAuthSourceDefaultDingTalkSubscriptions, + SettingKeyAuthSourceDefaultDingTalkGrantOnSignup, + SettingKeyAuthSourceDefaultDingTalkGrantOnFirstBind, SettingKeyForceEmailOnThirdPartySignup, } @@ -2240,6 +2304,7 @@ func (s *SettingService) GetAuthSourceDefaultSettings(ctx context.Context) (*Aut WeChat: parseProviderDefaultGrantSettings(settings, weChatAuthSourceDefaultKeys), GitHub: parseProviderDefaultGrantSettings(settings, gitHubAuthSourceDefaultKeys), Google: parseProviderDefaultGrantSettings(settings, googleAuthSourceDefaultKeys), + DingTalk: parseProviderDefaultGrantSettings(settings, dingTalkAuthSourceDefaultKeys), ForceEmailOnThirdPartySignup: settings[SettingKeyForceEmailOnThirdPartySignup] == "true", }, nil } @@ -2316,111 +2381,116 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 初始化默认设置 defaults := map[string]string{ - SettingKeyRegistrationEnabled: "true", - SettingKeyEmailVerifyEnabled: "false", - SettingKeyRegistrationEmailSuffixWhitelist: "[]", - SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 - SettingKeyLoginAgreementEnabled: "false", - SettingKeyLoginAgreementMode: defaultLoginAgreementMode, - SettingKeyLoginAgreementUpdatedAt: defaultLoginAgreementDate, - SettingKeyLoginAgreementDocuments: loginAgreementDocumentsJSON, - SettingKeySiteName: "Sub2API", - SettingKeySiteLogo: "", - SettingKeyPurchaseSubscriptionEnabled: "false", - SettingKeyPurchaseSubscriptionURL: "", - SettingKeyTableDefaultPageSize: "20", - SettingKeyTablePageSizeOptions: "[10,20,50,100]", - SettingKeyCustomMenuItems: "[]", - SettingKeyCustomEndpoints: "[]", - SettingKeyWeChatConnectEnabled: "false", - SettingKeyWeChatConnectAppID: "", - SettingKeyWeChatConnectAppSecret: "", - SettingKeyWeChatConnectOpenAppID: "", - SettingKeyWeChatConnectOpenAppSecret: "", - SettingKeyWeChatConnectMPAppID: "", - SettingKeyWeChatConnectMPAppSecret: "", - SettingKeyWeChatConnectMobileAppID: "", - SettingKeyWeChatConnectMobileAppSecret: "", - SettingKeyWeChatConnectOpenEnabled: "false", - SettingKeyWeChatConnectMPEnabled: "false", - SettingKeyWeChatConnectMobileEnabled: "false", - SettingKeyWeChatConnectMode: "open", - SettingKeyWeChatConnectScopes: "snsapi_login", - SettingKeyWeChatConnectRedirectURL: "", - SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend, - SettingKeyGitHubOAuthEnabled: "false", - SettingKeyGitHubOAuthClientID: "", - SettingKeyGitHubOAuthClientSecret: "", - SettingKeyGitHubOAuthRedirectURL: "", - SettingKeyGitHubOAuthFrontendRedirectURL: defaultGitHubOAuthFrontend, - SettingKeyGoogleOAuthEnabled: "false", - SettingKeyGoogleOAuthClientID: "", - SettingKeyGoogleOAuthClientSecret: "", - SettingKeyGoogleOAuthRedirectURL: "", - SettingKeyGoogleOAuthFrontendRedirectURL: defaultGoogleOAuthFrontend, - SettingKeyOIDCConnectEnabled: "false", - SettingKeyOIDCConnectProviderName: "OIDC", - SettingKeyOIDCConnectClientID: "", - SettingKeyOIDCConnectClientSecret: "", - SettingKeyOIDCConnectIssuerURL: "", - SettingKeyOIDCConnectDiscoveryURL: "", - SettingKeyOIDCConnectAuthorizeURL: "", - SettingKeyOIDCConnectTokenURL: "", - SettingKeyOIDCConnectUserInfoURL: "", - SettingKeyOIDCConnectJWKSURL: "", - SettingKeyOIDCConnectScopes: "openid email profile", - SettingKeyOIDCConnectRedirectURL: "", - SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", - SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", - SettingKeyOIDCConnectUsePKCE: strconv.FormatBool(oidcUsePKCEDefault), - SettingKeyOIDCConnectValidateIDToken: strconv.FormatBool(oidcValidateIDTokenDefault), - SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256", - SettingKeyOIDCConnectClockSkewSeconds: "120", - SettingKeyOIDCConnectRequireEmailVerified: "false", - SettingKeyOIDCConnectUserInfoEmailPath: "", - SettingKeyOIDCConnectUserInfoIDPath: "", - SettingKeyOIDCConnectUserInfoUsernamePath: "", - SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), - SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), - SettingKeyAffiliateRebateRate: strconv.FormatFloat(AffiliateRebateRateDefault, 'f', 8, 64), - SettingKeyAffiliateRebateFreezeHours: strconv.Itoa(AffiliateRebateFreezeHoursDefault), - SettingKeyAffiliateRebateDurationDays: strconv.Itoa(AffiliateRebateDurationDaysDefault), - SettingKeyAffiliateRebatePerInviteeCap: strconv.FormatFloat(AffiliateRebatePerInviteeCapDefault, 'f', 2, 64), - SettingKeyDefaultUserRPMLimit: "0", - SettingKeyDefaultSubscriptions: "[]", - SettingKeyAuthSourceDefaultEmailBalance: "0", - SettingKeyAuthSourceDefaultEmailConcurrency: "5", - SettingKeyAuthSourceDefaultEmailSubscriptions: "[]", - SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", - SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false", - SettingKeyAuthSourceDefaultLinuxDoBalance: "0", - SettingKeyAuthSourceDefaultLinuxDoConcurrency: "5", - SettingKeyAuthSourceDefaultLinuxDoSubscriptions: "[]", - SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "false", - SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "false", - SettingKeyAuthSourceDefaultOIDCBalance: "0", - SettingKeyAuthSourceDefaultOIDCConcurrency: "5", - SettingKeyAuthSourceDefaultOIDCSubscriptions: "[]", - SettingKeyAuthSourceDefaultOIDCGrantOnSignup: "false", - SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "false", - SettingKeyAuthSourceDefaultWeChatBalance: "0", - SettingKeyAuthSourceDefaultWeChatConcurrency: "5", - SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]", - SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "false", - SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false", - SettingKeyAuthSourceDefaultGitHubBalance: "0", - SettingKeyAuthSourceDefaultGitHubConcurrency: "5", - SettingKeyAuthSourceDefaultGitHubSubscriptions: "[]", - SettingKeyAuthSourceDefaultGitHubGrantOnSignup: "false", - SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind: "false", - SettingKeyAuthSourceDefaultGoogleBalance: "0", - SettingKeyAuthSourceDefaultGoogleConcurrency: "5", - SettingKeyAuthSourceDefaultGoogleSubscriptions: "[]", - SettingKeyAuthSourceDefaultGoogleGrantOnSignup: "false", - SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind: "false", - SettingKeyForceEmailOnThirdPartySignup: "false", - SettingKeySMTPPort: "587", - SettingKeySMTPUseTLS: "false", + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "false", + SettingKeyRegistrationEmailSuffixWhitelist: "[]", + SettingKeyPromoCodeEnabled: "true", // 默认启用优惠码功能 + SettingKeyLoginAgreementEnabled: "false", + SettingKeyLoginAgreementMode: defaultLoginAgreementMode, + SettingKeyLoginAgreementUpdatedAt: defaultLoginAgreementDate, + SettingKeyLoginAgreementDocuments: loginAgreementDocumentsJSON, + SettingKeySiteName: "Sub2API", + SettingKeySiteLogo: "", + SettingKeyPurchaseSubscriptionEnabled: "false", + SettingKeyPurchaseSubscriptionURL: "", + SettingKeyTableDefaultPageSize: "20", + SettingKeyTablePageSizeOptions: "[10,20,50,100]", + SettingKeyCustomMenuItems: "[]", + SettingKeyCustomEndpoints: "[]", + SettingKeyWeChatConnectEnabled: "false", + SettingKeyWeChatConnectAppID: "", + SettingKeyWeChatConnectAppSecret: "", + SettingKeyWeChatConnectOpenAppID: "", + SettingKeyWeChatConnectOpenAppSecret: "", + SettingKeyWeChatConnectMPAppID: "", + SettingKeyWeChatConnectMPAppSecret: "", + SettingKeyWeChatConnectMobileAppID: "", + SettingKeyWeChatConnectMobileAppSecret: "", + SettingKeyWeChatConnectOpenEnabled: "false", + SettingKeyWeChatConnectMPEnabled: "false", + SettingKeyWeChatConnectMobileEnabled: "false", + SettingKeyWeChatConnectMode: "open", + SettingKeyWeChatConnectScopes: "snsapi_login", + SettingKeyWeChatConnectRedirectURL: "", + SettingKeyWeChatConnectFrontendRedirectURL: defaultWeChatConnectFrontend, + SettingKeyGitHubOAuthEnabled: "false", + SettingKeyGitHubOAuthClientID: "", + SettingKeyGitHubOAuthClientSecret: "", + SettingKeyGitHubOAuthRedirectURL: "", + SettingKeyGitHubOAuthFrontendRedirectURL: defaultGitHubOAuthFrontend, + SettingKeyGoogleOAuthEnabled: "false", + SettingKeyGoogleOAuthClientID: "", + SettingKeyGoogleOAuthClientSecret: "", + SettingKeyGoogleOAuthRedirectURL: "", + SettingKeyGoogleOAuthFrontendRedirectURL: defaultGoogleOAuthFrontend, + SettingKeyOIDCConnectEnabled: "false", + SettingKeyOIDCConnectProviderName: "OIDC", + SettingKeyOIDCConnectClientID: "", + SettingKeyOIDCConnectClientSecret: "", + SettingKeyOIDCConnectIssuerURL: "", + SettingKeyOIDCConnectDiscoveryURL: "", + SettingKeyOIDCConnectAuthorizeURL: "", + SettingKeyOIDCConnectTokenURL: "", + SettingKeyOIDCConnectUserInfoURL: "", + SettingKeyOIDCConnectJWKSURL: "", + SettingKeyOIDCConnectScopes: "openid email profile", + SettingKeyOIDCConnectRedirectURL: "", + SettingKeyOIDCConnectFrontendRedirectURL: "/auth/oidc/callback", + SettingKeyOIDCConnectTokenAuthMethod: "client_secret_post", + SettingKeyOIDCConnectUsePKCE: strconv.FormatBool(oidcUsePKCEDefault), + SettingKeyOIDCConnectValidateIDToken: strconv.FormatBool(oidcValidateIDTokenDefault), + SettingKeyOIDCConnectAllowedSigningAlgs: "RS256,ES256,PS256", + SettingKeyOIDCConnectClockSkewSeconds: "120", + SettingKeyOIDCConnectRequireEmailVerified: "false", + SettingKeyOIDCConnectUserInfoEmailPath: "", + SettingKeyOIDCConnectUserInfoIDPath: "", + SettingKeyOIDCConnectUserInfoUsernamePath: "", + SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), + SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), + SettingKeyAffiliateRebateRate: strconv.FormatFloat(AffiliateRebateRateDefault, 'f', 8, 64), + SettingKeyAffiliateRebateFreezeHours: strconv.Itoa(AffiliateRebateFreezeHoursDefault), + SettingKeyAffiliateRebateDurationDays: strconv.Itoa(AffiliateRebateDurationDaysDefault), + SettingKeyAffiliateRebatePerInviteeCap: strconv.FormatFloat(AffiliateRebatePerInviteeCapDefault, 'f', 2, 64), + SettingKeyDefaultUserRPMLimit: "0", + SettingKeyDefaultSubscriptions: "[]", + SettingKeyAuthSourceDefaultEmailBalance: "0", + SettingKeyAuthSourceDefaultEmailConcurrency: "5", + SettingKeyAuthSourceDefaultEmailSubscriptions: "[]", + SettingKeyAuthSourceDefaultEmailGrantOnSignup: "false", + SettingKeyAuthSourceDefaultEmailGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultLinuxDoBalance: "0", + SettingKeyAuthSourceDefaultLinuxDoConcurrency: "5", + SettingKeyAuthSourceDefaultLinuxDoSubscriptions: "[]", + SettingKeyAuthSourceDefaultLinuxDoGrantOnSignup: "false", + SettingKeyAuthSourceDefaultLinuxDoGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultOIDCBalance: "0", + SettingKeyAuthSourceDefaultOIDCConcurrency: "5", + SettingKeyAuthSourceDefaultOIDCSubscriptions: "[]", + SettingKeyAuthSourceDefaultOIDCGrantOnSignup: "false", + SettingKeyAuthSourceDefaultOIDCGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultWeChatBalance: "0", + SettingKeyAuthSourceDefaultWeChatConcurrency: "5", + SettingKeyAuthSourceDefaultWeChatSubscriptions: "[]", + SettingKeyAuthSourceDefaultWeChatGrantOnSignup: "false", + SettingKeyAuthSourceDefaultWeChatGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultGitHubBalance: "0", + SettingKeyAuthSourceDefaultGitHubConcurrency: "5", + SettingKeyAuthSourceDefaultGitHubSubscriptions: "[]", + SettingKeyAuthSourceDefaultGitHubGrantOnSignup: "false", + SettingKeyAuthSourceDefaultGitHubGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultGoogleBalance: "0", + SettingKeyAuthSourceDefaultGoogleConcurrency: "5", + SettingKeyAuthSourceDefaultGoogleSubscriptions: "[]", + SettingKeyAuthSourceDefaultGoogleGrantOnSignup: "false", + SettingKeyAuthSourceDefaultGoogleGrantOnFirstBind: "false", + SettingKeyAuthSourceDefaultDingTalkBalance: "0", + SettingKeyAuthSourceDefaultDingTalkConcurrency: "5", + SettingKeyAuthSourceDefaultDingTalkSubscriptions: "[]", + SettingKeyAuthSourceDefaultDingTalkGrantOnSignup: "false", + SettingKeyAuthSourceDefaultDingTalkGrantOnFirstBind: "false", + SettingKeyForceEmailOnThirdPartySignup: "false", + SettingKeySMTPPort: "587", + SettingKeySMTPUseTLS: "false", // Model fallback defaults SettingKeyEnableModelFallback: "false", SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022", @@ -2599,6 +2669,136 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != "" + // DingTalk Connect 设置: + // - 兼容 config.yaml/env + // - 支持后台系统设置覆盖并持久化(存储于 DB) + dingTalkBase := config.DingTalkConnectConfig{} + if s.cfg != nil { + dingTalkBase = s.cfg.DingTalk + } + + if raw, ok := settings[SettingKeyDingTalkConnectEnabled]; ok { + result.DingTalkConnectEnabled = raw == "true" + } else { + result.DingTalkConnectEnabled = dingTalkBase.Enabled + } + + if v, ok := settings[SettingKeyDingTalkConnectClientID]; ok && strings.TrimSpace(v) != "" { + result.DingTalkConnectClientID = strings.TrimSpace(v) + } else { + result.DingTalkConnectClientID = dingTalkBase.ClientID + } + + if v, ok := settings[SettingKeyDingTalkConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { + result.DingTalkConnectRedirectURL = strings.TrimSpace(v) + } else { + result.DingTalkConnectRedirectURL = dingTalkBase.RedirectURL + } + + result.DingTalkConnectClientSecret = strings.TrimSpace(settings[SettingKeyDingTalkConnectClientSecret]) + if result.DingTalkConnectClientSecret == "" { + result.DingTalkConnectClientSecret = strings.TrimSpace(dingTalkBase.ClientSecret) + } + result.DingTalkConnectClientSecretConfigured = result.DingTalkConnectClientSecret != "" + + if v, ok := settings[SettingKeyDingTalkConnectCorpRestrictionPolicy]; ok && strings.TrimSpace(v) != "" { + result.DingTalkConnectCorpRestrictionPolicy = strings.TrimSpace(v) + } else { + result.DingTalkConnectCorpRestrictionPolicy = dingTalkBase.CorpRestrictionPolicy + } + result.DingTalkConnectCorpRestrictionPolicy = coerceDeprecatedDingTalkCorpPolicy(result.DingTalkConnectCorpRestrictionPolicy) + + if v, ok := settings[SettingKeyDingTalkConnectInternalCorpID]; ok && strings.TrimSpace(v) != "" { + result.DingTalkConnectInternalCorpID = strings.TrimSpace(v) + } else { + result.DingTalkConnectInternalCorpID = dingTalkBase.InternalCorpID + } + + if v, ok := settings[SettingKeyDingTalkConnectBypassRegistration]; ok && strings.TrimSpace(v) != "" { + result.DingTalkConnectBypassRegistration = strings.EqualFold(strings.TrimSpace(v), "true") + } else { + result.DingTalkConnectBypassRegistration = dingTalkBase.BypassRegistration + } + // bypass_registration 仅在 internal_only 模式下有意义;其它策略下强制 false, + // 以保证加载出的 effective config 永远是一致状态。 + if result.DingTalkConnectCorpRestrictionPolicy != "internal_only" { + result.DingTalkConnectBypassRegistration = false + } + + if v, ok := settings[SettingKeyDingTalkConnectSyncCorpEmail]; ok && strings.TrimSpace(v) != "" { + result.DingTalkConnectSyncCorpEmail = strings.EqualFold(strings.TrimSpace(v), "true") + } else { + result.DingTalkConnectSyncCorpEmail = dingTalkBase.SyncCorpEmail + } + if v, ok := settings[SettingKeyDingTalkConnectSyncDisplayName]; ok && strings.TrimSpace(v) != "" { + result.DingTalkConnectSyncDisplayName = strings.EqualFold(strings.TrimSpace(v), "true") + } else { + result.DingTalkConnectSyncDisplayName = dingTalkBase.SyncDisplayName + } + if v, ok := settings[SettingKeyDingTalkConnectSyncDept]; ok && strings.TrimSpace(v) != "" { + result.DingTalkConnectSyncDept = strings.EqualFold(strings.TrimSpace(v), "true") + } else { + result.DingTalkConnectSyncDept = dingTalkBase.SyncDept + } + // 身份同步三开关仅在 internal_only 模式下有意义;其它策略强制 false。 + if result.DingTalkConnectCorpRestrictionPolicy != "internal_only" { + result.DingTalkConnectSyncCorpEmail = false + result.DingTalkConnectSyncDisplayName = false + result.DingTalkConnectSyncDept = false + } + + // 身份同步目标 attr key(DB 空 → fallback 默认值) + result.DingTalkConnectSyncCorpEmailAttrKey = strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncCorpEmailAttrKey]) + if result.DingTalkConnectSyncCorpEmailAttrKey == "" { + if v := strings.TrimSpace(dingTalkBase.SyncCorpEmailAttrKey); v != "" { + result.DingTalkConnectSyncCorpEmailAttrKey = v + } else { + result.DingTalkConnectSyncCorpEmailAttrKey = "dingtalk_email" + } + } + result.DingTalkConnectSyncDisplayNameAttrKey = strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncDisplayNameAttrKey]) + if result.DingTalkConnectSyncDisplayNameAttrKey == "" { + if v := strings.TrimSpace(dingTalkBase.SyncDisplayNameAttrKey); v != "" { + result.DingTalkConnectSyncDisplayNameAttrKey = v + } else { + result.DingTalkConnectSyncDisplayNameAttrKey = "dingtalk_name" + } + } + result.DingTalkConnectSyncDeptAttrKey = strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncDeptAttrKey]) + if result.DingTalkConnectSyncDeptAttrKey == "" { + if v := strings.TrimSpace(dingTalkBase.SyncDeptAttrKey); v != "" { + result.DingTalkConnectSyncDeptAttrKey = v + } else { + result.DingTalkConnectSyncDeptAttrKey = "dingtalk_department" + } + } + + // 身份同步目标 attr 显示名称(DB 空 → fallback 默认中文) + result.DingTalkConnectSyncCorpEmailAttrName = strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncCorpEmailAttrName]) + if result.DingTalkConnectSyncCorpEmailAttrName == "" { + if v := strings.TrimSpace(dingTalkBase.SyncCorpEmailAttrName); v != "" { + result.DingTalkConnectSyncCorpEmailAttrName = v + } else { + result.DingTalkConnectSyncCorpEmailAttrName = "钉钉企业邮箱" + } + } + result.DingTalkConnectSyncDisplayNameAttrName = strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncDisplayNameAttrName]) + if result.DingTalkConnectSyncDisplayNameAttrName == "" { + if v := strings.TrimSpace(dingTalkBase.SyncDisplayNameAttrName); v != "" { + result.DingTalkConnectSyncDisplayNameAttrName = v + } else { + result.DingTalkConnectSyncDisplayNameAttrName = "钉钉姓名" + } + } + result.DingTalkConnectSyncDeptAttrName = strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncDeptAttrName]) + if result.DingTalkConnectSyncDeptAttrName == "" { + if v := strings.TrimSpace(dingTalkBase.SyncDeptAttrName); v != "" { + result.DingTalkConnectSyncDeptAttrName = v + } else { + result.DingTalkConnectSyncDeptAttrName = "钉钉部门" + } + } + // Generic OIDC 设置: // - 兼容 config.yaml/env // - 支持后台系统设置覆盖并持久化(存储于 DB) @@ -2992,10 +3192,14 @@ func mergeProviderDefaultGrantSettings(globalDefaults ProviderDefaultGrantSettin GrantOnFirstBind: providerDefaults.GrantOnFirstBind, } - if providerDefaults.Balance != defaultAuthSourceBalance { + // 注意:不能把 parse 默认值 (defaultAuthSourceBalance / defaultAuthSourceConcurrency) + // 当作"未配置"哨兵——admin 完全有权显式设成相同的值,那时仍应覆盖 globalDefaults。 + // 旧实现的 `!= defaultAuthSourceConcurrency` 会把 admin 设的 5 与 fallback 5 混淆, + // 导致渠道发放退回到全局默认(如 1),表现为"管理员设 5、新用户实际拿 1"。 + if providerDefaults.Balance >= 0 { result.Balance = providerDefaults.Balance } - if providerDefaults.Concurrency > 0 && providerDefaults.Concurrency != defaultAuthSourceConcurrency { + if providerDefaults.Concurrency > 0 { result.Concurrency = providerDefaults.Concurrency } if len(providerDefaults.Subscriptions) > 0 { @@ -3281,6 +3485,157 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf return effective, nil } +// GetDingTalkConnectOAuthConfig 返回用于登录的"最终生效" DingTalk Connect 配置。 +// +// 优先级: +// - 若对应系统设置键存在,则覆盖 config.yaml/env 的值 +// - 否则回退到 config.yaml/env 的值 +func (s *SettingService) GetDingTalkConnectOAuthConfig(ctx context.Context) (config.DingTalkConnectConfig, error) { + if s == nil || s.cfg == nil { + return config.DingTalkConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded") + } + + effective := s.cfg.DingTalk + + keys := []string{ + SettingKeyDingTalkConnectEnabled, + SettingKeyDingTalkConnectClientID, + SettingKeyDingTalkConnectClientSecret, + SettingKeyDingTalkConnectRedirectURL, + SettingKeyDingTalkConnectCorpRestrictionPolicy, + SettingKeyDingTalkConnectInternalCorpID, + SettingKeyDingTalkConnectBypassRegistration, + SettingKeyDingTalkConnectSyncCorpEmail, + SettingKeyDingTalkConnectSyncDisplayName, + SettingKeyDingTalkConnectSyncDept, + SettingKeyDingTalkConnectSyncCorpEmailAttrKey, + SettingKeyDingTalkConnectSyncDisplayNameAttrKey, + SettingKeyDingTalkConnectSyncDeptAttrKey, + } + settings, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return config.DingTalkConnectConfig{}, fmt.Errorf("get dingtalk connect settings: %w", err) + } + + if raw, ok := settings[SettingKeyDingTalkConnectEnabled]; ok { + effective.Enabled = raw == "true" + } + if v, ok := settings[SettingKeyDingTalkConnectClientID]; ok && strings.TrimSpace(v) != "" { + effective.ClientID = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyDingTalkConnectClientSecret]; ok && strings.TrimSpace(v) != "" { + effective.ClientSecret = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyDingTalkConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { + effective.RedirectURL = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyDingTalkConnectCorpRestrictionPolicy]; ok && strings.TrimSpace(v) != "" { + effective.CorpRestrictionPolicy = strings.TrimSpace(v) + } + effective.CorpRestrictionPolicy = coerceDeprecatedDingTalkCorpPolicy(effective.CorpRestrictionPolicy) + if v, ok := settings[SettingKeyDingTalkConnectInternalCorpID]; ok && strings.TrimSpace(v) != "" { + effective.InternalCorpID = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyDingTalkConnectBypassRegistration]; ok && strings.TrimSpace(v) != "" { + effective.BypassRegistration = strings.EqualFold(strings.TrimSpace(v), "true") + } + // bypass_registration 仅在 internal_only 模式下有意义;其它策略下强制 false, + // 以保证 OAuth callback 看到的 effective config 永远是一致状态。 + if effective.CorpRestrictionPolicy != "internal_only" { + effective.BypassRegistration = false + } + + if v, ok := settings[SettingKeyDingTalkConnectSyncCorpEmail]; ok && strings.TrimSpace(v) != "" { + effective.SyncCorpEmail = strings.EqualFold(strings.TrimSpace(v), "true") + } + if v, ok := settings[SettingKeyDingTalkConnectSyncDisplayName]; ok && strings.TrimSpace(v) != "" { + effective.SyncDisplayName = strings.EqualFold(strings.TrimSpace(v), "true") + } + if v, ok := settings[SettingKeyDingTalkConnectSyncDept]; ok && strings.TrimSpace(v) != "" { + effective.SyncDept = strings.EqualFold(strings.TrimSpace(v), "true") + } + // 身份同步三开关仅在 internal_only 模式下有意义;其它策略强制 false。 + if effective.CorpRestrictionPolicy != "internal_only" { + effective.SyncCorpEmail = false + effective.SyncDisplayName = false + effective.SyncDept = false + } + + // 身份同步目标 attr key(DB 空 → fallback 默认值) + if v := strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncCorpEmailAttrKey]); v != "" { + effective.SyncCorpEmailAttrKey = v + } + if effective.SyncCorpEmailAttrKey == "" { + effective.SyncCorpEmailAttrKey = "dingtalk_email" + } + if v := strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncDisplayNameAttrKey]); v != "" { + effective.SyncDisplayNameAttrKey = v + } + if effective.SyncDisplayNameAttrKey == "" { + effective.SyncDisplayNameAttrKey = "dingtalk_name" + } + if v := strings.TrimSpace(settings[SettingKeyDingTalkConnectSyncDeptAttrKey]); v != "" { + effective.SyncDeptAttrKey = v + } + if effective.SyncDeptAttrKey == "" { + effective.SyncDeptAttrKey = "dingtalk_department" + } + + if !effective.Enabled { + return config.DingTalkConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "dingtalk oauth login is disabled") + } + + // 基础健壮性校验(避免把用户重定向到一个必然失败或不安全的 OAuth 流程里)。 + if strings.TrimSpace(effective.ClientID) == "" { + return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth client id not configured") + } + if strings.TrimSpace(effective.AuthorizeURL) == "" { + return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth authorize url not configured") + } + if strings.TrimSpace(effective.TokenURL) == "" { + return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth token url not configured") + } + if strings.TrimSpace(effective.UserInfoURL) == "" { + return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth userinfo url not configured") + } + if strings.TrimSpace(effective.RedirectURL) == "" { + return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth redirect url not configured") + } + if strings.TrimSpace(effective.FrontendRedirectURL) == "" { + return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth frontend redirect url not configured") + } + + if err := config.ValidateAbsoluteHTTPURL(effective.AuthorizeURL); err != nil { + return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth authorize url invalid") + } + if err := config.ValidateAbsoluteHTTPURL(effective.TokenURL); err != nil { + return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth token url invalid") + } + if err := config.ValidateAbsoluteHTTPURL(effective.UserInfoURL); err != nil { + return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth userinfo url invalid") + } + if err := config.ValidateAbsoluteHTTPURL(effective.RedirectURL); err != nil { + return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth redirect url invalid") + } + if err := config.ValidateFrontendRedirectURL(effective.FrontendRedirectURL); err != nil { + return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth frontend redirect url invalid") + } + if strings.TrimSpace(effective.ClientSecret) == "" { + return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "dingtalk oauth client secret not configured") + } + + // 镜像 admin handler 行为:internal_only policy 隐式要求 AppType=internal + if effective.CorpRestrictionPolicy == "internal_only" { + effective.AppType = "internal" + } + + if err := config.ValidateDingTalkConfig(effective); err != nil { + return config.DingTalkConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", err.Error()) + } + + return effective, nil +} + // GetWeChatConnectOAuthConfig 返回用于登录的最终生效 WeChat Connect 配置。 // // WeChat Connect 已回归 DB 系统设置模型,不再回退到 config/env。 diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index bfe859951a2..ea5fa57c424 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -46,6 +46,25 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool LinuxDoConnectRedirectURL string + // DingTalk Connect OAuth 登录 + DingTalkConnectEnabled bool + DingTalkConnectClientID string + DingTalkConnectClientSecret string + DingTalkConnectClientSecretConfigured bool + DingTalkConnectRedirectURL string + DingTalkConnectCorpRestrictionPolicy string + DingTalkConnectInternalCorpID string + DingTalkConnectBypassRegistration bool + DingTalkConnectSyncCorpEmail bool + DingTalkConnectSyncDisplayName bool + DingTalkConnectSyncDept bool + DingTalkConnectSyncCorpEmailAttrKey string + DingTalkConnectSyncDisplayNameAttrKey string + DingTalkConnectSyncDeptAttrKey string + DingTalkConnectSyncCorpEmailAttrName string + DingTalkConnectSyncDisplayNameAttrName string + DingTalkConnectSyncDeptAttrName string + // WeChat Connect OAuth 登录 WeChatConnectEnabled bool WeChatConnectAppID string @@ -235,6 +254,7 @@ type PublicSettings struct { CustomEndpoints string // JSON array of custom endpoints LinuxDoOAuthEnabled bool + DingTalkOAuthEnabled bool WeChatOAuthEnabled bool WeChatOAuthOpenEnabled bool WeChatOAuthMPEnabled bool @@ -491,25 +511,10 @@ type OpenAIFastPolicySettings struct { } // DefaultOpenAIFastPolicySettings 返回默认的 OpenAI fast 策略配置。 -// 默认对所有模型的 priority(fast)请求执行 filter,即剔除 service_tier 字段, -// 让上游按 normal 优先级处理。 -// -// 为什么 ModelWhitelist 为空(=对所有模型生效): -// codex 客户端的 service_tier=fast 是用户级开关,与 model 字段正交。即使 -// 用户使用 gpt-4 + fast,priority 配额仍会被消耗。如果默认规则只锁 -// gpt-5.5*,"用 gpt-4 + fast 透传 priority 上游" 这条路径就会绕过策略。 -// 与 codex 真实语义对齐,默认对所有模型生效;管理员若需要只针对特定 -// 模型,可在 admin UI 中显式配置 model_whitelist。 +// 默认不配置任何规则,保留 OpenAI 上游 service_tier 语义;管理员如需 +// 限制 priority/flex,可以在 admin UI 中显式配置 filter 或 block 规则。 func DefaultOpenAIFastPolicySettings() *OpenAIFastPolicySettings { return &OpenAIFastPolicySettings{ - Rules: []OpenAIFastPolicyRule{ - { - ServiceTier: OpenAIFastTierPriority, - Action: BetaPolicyActionFilter, - Scope: BetaPolicyScopeAll, - ModelWhitelist: []string{}, - FallbackAction: BetaPolicyActionPass, - }, - }, + Rules: []OpenAIFastPolicyRule{}, } } diff --git a/backend/internal/service/subscription_assign_idempotency_test.go b/backend/internal/service/subscription_assign_idempotency_test.go index 40bab206063..c8ace613301 100644 --- a/backend/internal/service/subscription_assign_idempotency_test.go +++ b/backend/internal/service/subscription_assign_idempotency_test.go @@ -199,6 +199,24 @@ func (s *subscriptionUserSubRepoStub) GetByID(_ context.Context, id int64) (*Use return &cp, nil } +func (s *subscriptionUserSubRepoStub) Update(_ context.Context, sub *UserSubscription) error { + if sub == nil { + return ErrSubscriptionNilInput + } + existing := s.byID[sub.ID] + if existing == nil { + return ErrSubscriptionNotFound + } + oldKey := s.key(existing.UserID, existing.GroupID) + cp := *sub + s.byID[cp.ID] = &cp + if oldKey != s.key(cp.UserID, cp.GroupID) { + delete(s.byUserGroup, oldKey) + } + s.byUserGroup[s.key(cp.UserID, cp.GroupID)] = &cp + return nil +} + func TestAssignSubscriptionReuseWhenSemanticsMatch(t *testing.T) { start := time.Date(2026, 2, 20, 10, 0, 0, 0, time.UTC) groupRepo := &subscriptionGroupRepoStub{ diff --git a/backend/internal/service/subscription_calculate_progress_test.go b/backend/internal/service/subscription_calculate_progress_test.go index 53e5c56806d..650522d54c5 100644 --- a/backend/internal/service/subscription_calculate_progress_test.go +++ b/backend/internal/service/subscription_calculate_progress_test.go @@ -66,6 +66,30 @@ func TestCalculateProgress_DailyUsage(t *testing.T) { assert.Equal(t, dailyStart, progress.Daily.WindowStart) } +func TestCalculateProgress_DailyCardUsesExpiryAsDailyResetTime(t *testing.T) { + svc := newTestSubscriptionService() + startsAt := time.Now().Add(-12 * time.Hour) + dailyStart := time.Date(startsAt.Year(), startsAt.Month(), startsAt.Day(), 0, 0, 0, 0, startsAt.Location()) + expiresAt := startsAt.Add(24 * time.Hour) + + sub := &UserSubscription{ + ID: 1, + StartsAt: startsAt, + ExpiresAt: expiresAt, + DailyUsageUSD: 3.0, + DailyWindowStart: ptrTime(dailyStart), + } + group := &Group{ + Name: "Daily", + DailyLimitUSD: ptrFloat64(10.0), + } + + progress := svc.calculateProgress(sub, group) + + require.NotNil(t, progress.Daily, "日卡有日限额和窗口时 Daily 不应为 nil") + assert.Equal(t, expiresAt, progress.Daily.ResetsAt, "日卡的一次性日额度结束时间应为订阅过期时间") +} + func TestCalculateProgress_WeeklyUsage(t *testing.T) { svc := newTestSubscriptionService() now := time.Now() diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index f0a5540e0aa..9905e6a1e29 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -196,7 +196,8 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in now := time.Now() var newExpiresAt time.Time - if existingSub.ExpiresAt.After(now) { + isExpired := !existingSub.ExpiresAt.After(now) + if !isExpired { // 未过期:从当前过期时间累加 newExpiresAt = existingSub.ExpiresAt.AddDate(0, 0, validityDays) } else { @@ -209,43 +210,8 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in newExpiresAt = MaxExpiresAt } - // 开启事务:ExtendExpiry + UpdateStatus + UpdateNotes 在同一事务中完成 - tx, err := s.entClient.Tx(ctx) - if err != nil { - return nil, false, fmt.Errorf("begin transaction: %w", err) - } - txCtx := dbent.NewTxContext(ctx, tx) - - // 更新过期时间 - if err := s.userSubRepo.ExtendExpiry(txCtx, existingSub.ID, newExpiresAt); err != nil { - _ = tx.Rollback() - return nil, false, fmt.Errorf("extend subscription: %w", err) - } - - // 如果订阅已过期或被暂停,恢复为active状态 - if existingSub.Status != SubscriptionStatusActive { - if err := s.userSubRepo.UpdateStatus(txCtx, existingSub.ID, SubscriptionStatusActive); err != nil { - _ = tx.Rollback() - return nil, false, fmt.Errorf("update subscription status: %w", err) - } - } - - // 追加备注 - if input.Notes != "" { - newNotes := existingSub.Notes - if newNotes != "" { - newNotes += "\n" - } - newNotes += input.Notes - if err := s.userSubRepo.UpdateNotes(txCtx, existingSub.ID, newNotes); err != nil { - _ = tx.Rollback() - return nil, false, fmt.Errorf("update subscription notes: %w", err) - } - } - - // 提交事务 - if err := tx.Commit(); err != nil { - return nil, false, fmt.Errorf("commit transaction: %w", err) + if err := s.updateExistingSubscriptionTerm(ctx, existingSub, input.Notes, now, newExpiresAt, isExpired); err != nil { + return nil, false, err } // 失效订阅缓存 @@ -284,6 +250,94 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in return sub, false, nil // false 表示是新建 } +func (s *SubscriptionService) updateExistingSubscriptionTerm( + ctx context.Context, + existingSub *UserSubscription, + notes string, + startsAt time.Time, + newExpiresAt time.Time, + isExpired bool, +) error { + return s.withSubscriptionUpdateTx(ctx, func(txCtx context.Context) error { + if isExpired { + renewed := renewedSubscriptionTerm(existingSub, notes, startsAt, newExpiresAt) + if err := s.userSubRepo.Update(txCtx, renewed); err != nil { + return fmt.Errorf("renew expired subscription: %w", err) + } + return nil + } + + // 更新过期时间 + if err := s.userSubRepo.ExtendExpiry(txCtx, existingSub.ID, newExpiresAt); err != nil { + return fmt.Errorf("extend subscription: %w", err) + } + + // 如果订阅被暂停,恢复为 active 状态 + if existingSub.Status != SubscriptionStatusActive { + if err := s.userSubRepo.UpdateStatus(txCtx, existingSub.ID, SubscriptionStatusActive); err != nil { + return fmt.Errorf("update subscription status: %w", err) + } + } + + // 追加备注 + if notes != "" { + if err := s.userSubRepo.UpdateNotes(txCtx, existingSub.ID, appendSubscriptionNotes(existingSub.Notes, notes)); err != nil { + return fmt.Errorf("update subscription notes: %w", err) + } + } + + return nil + }) +} + +func (s *SubscriptionService) withSubscriptionUpdateTx(ctx context.Context, fn func(context.Context) error) error { + if s.entClient == nil { + return fn(ctx) + } + + tx, err := s.entClient.Tx(ctx) + if err != nil { + return fmt.Errorf("begin transaction: %w", err) + } + txCtx := dbent.NewTxContext(ctx, tx) + + if err := fn(txCtx); err != nil { + _ = tx.Rollback() + return err + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("commit transaction: %w", err) + } + return nil +} + +func renewedSubscriptionTerm(existingSub *UserSubscription, notes string, startsAt, expiresAt time.Time) *UserSubscription { + renewed := *existingSub + windowStart := startOfDay(startsAt) + renewed.StartsAt = startsAt + renewed.ExpiresAt = expiresAt + renewed.Status = SubscriptionStatusActive + renewed.DailyWindowStart = &windowStart + renewed.WeeklyWindowStart = &windowStart + renewed.MonthlyWindowStart = &windowStart + renewed.DailyUsageUSD = 0 + renewed.WeeklyUsageUSD = 0 + renewed.MonthlyUsageUSD = 0 + renewed.Notes = appendSubscriptionNotes(existingSub.Notes, notes) + return &renewed +} + +func appendSubscriptionNotes(existingNotes, newNotes string) string { + if newNotes == "" { + return existingNotes + } + if existingNotes == "" { + return newNotes + } + return existingNotes + "\n" + newNotes +} + // createSubscription 创建新订阅(内部方法) func (s *SubscriptionService) createSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, error) { validityDays := input.ValidityDays @@ -945,6 +999,9 @@ func (s *SubscriptionService) calculateProgress(sub *UserSubscription, group *Gr if group.HasDailyLimit() && sub.DailyWindowStart != nil { limit := *group.DailyLimitUSD resetsAt := sub.DailyWindowStart.Add(24 * time.Hour) + if dailyResetTime := sub.DailyResetTime(); dailyResetTime != nil { + resetsAt = *dailyResetTime + } progress.Daily = &UsageWindowProgress{ LimitUSD: limit, UsedUSD: sub.DailyUsageUSD, diff --git a/backend/internal/service/upstream_models.go b/backend/internal/service/upstream_models.go new file mode 100644 index 00000000000..77e8d1e49ad --- /dev/null +++ b/backend/internal/service/upstream_models.go @@ -0,0 +1,474 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "sort" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" +) + +const upstreamModelsBodyLimit int64 = 8 << 20 + +// UpstreamModelSyncErrorKind classifies model sync failures for safe HTTP mapping. +type UpstreamModelSyncErrorKind string + +const ( + // UpstreamModelSyncErrorConfiguration means the account or server configuration cannot perform the sync. + UpstreamModelSyncErrorConfiguration UpstreamModelSyncErrorKind = "configuration" + // UpstreamModelSyncErrorUnsupported means the account format is intentionally unsupported for live model sync. + UpstreamModelSyncErrorUnsupported UpstreamModelSyncErrorKind = "unsupported" + // UpstreamModelSyncErrorUpstream means the configured upstream failed or returned an unusable response. + UpstreamModelSyncErrorUpstream UpstreamModelSyncErrorKind = "upstream" +) + +// UpstreamModelSyncError keeps internal failure details wrapped while exposing a safe client message. +type UpstreamModelSyncError struct { + Kind UpstreamModelSyncErrorKind + Message string + Err error +} + +func (e *UpstreamModelSyncError) Error() string { + if e == nil { + return "" + } + if e.Err == nil { + return e.Message + } + return e.Message + ": " + e.Err.Error() +} + +func (e *UpstreamModelSyncError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + +// SafeMessage returns the sanitized message that can be sent to API clients. +func (e *UpstreamModelSyncError) SafeMessage() string { + if e == nil || strings.TrimSpace(e.Message) == "" { + return "Failed to sync upstream models" + } + return e.Message +} + +func newUpstreamModelSyncConfigError(message string, err error) error { + return &UpstreamModelSyncError{Kind: UpstreamModelSyncErrorConfiguration, Message: message, Err: err} +} + +func newUpstreamModelSyncUnsupportedError(message string, err error) error { + return &UpstreamModelSyncError{Kind: UpstreamModelSyncErrorUnsupported, Message: message, Err: err} +} + +func newUpstreamModelSyncUpstreamError(message string, err error) error { + return &UpstreamModelSyncError{Kind: UpstreamModelSyncErrorUpstream, Message: message, Err: err} +} + +// FetchUpstreamSupportedModels fetches the live model list from the account's upstream API format. +func (s *AccountTestService) FetchUpstreamSupportedModels(ctx context.Context, account *Account) ([]string, error) { + if s == nil { + return nil, newUpstreamModelSyncConfigError("Account test service is not configured", nil) + } + if account == nil { + return nil, newUpstreamModelSyncConfigError("Account is required", nil) + } + + if account.Platform == PlatformAntigravity && account.Type != AccountTypeAPIKey { + return s.fetchAntigravityOAuthUpstreamModels(ctx, account) + } + + if s.httpUpstream == nil { + return nil, newUpstreamModelSyncConfigError("Upstream HTTP client is not configured", nil) + } + + req, err := s.buildUpstreamModelsRequest(ctx, account) + if err != nil { + return nil, err + } + + proxyURL := upstreamModelsProxyURL(account) + resp, err := s.doUpstreamModelsRequest(req, proxyURL, account) + if err != nil { + return nil, newUpstreamModelSyncUpstreamError("Failed to request upstream model list", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(io.LimitReader(resp.Body, upstreamModelsBodyLimit+1)) + if err != nil { + return nil, newUpstreamModelSyncUpstreamError("Failed to read upstream model list", err) + } + if int64(len(body)) > upstreamModelsBodyLimit { + return nil, newUpstreamModelSyncUpstreamError("Upstream model list response is too large", fmt.Errorf("response exceeds %d bytes", upstreamModelsBodyLimit)) + } + + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, newUpstreamModelSyncUpstreamError( + fmt.Sprintf("Upstream model list request failed with HTTP %d", resp.StatusCode), + fmt.Errorf("upstream model list returned HTTP %d", resp.StatusCode), + ) + } + + models, err := extractUpstreamModelIDs(body) + if err != nil { + return nil, newUpstreamModelSyncUpstreamError("Upstream model list response was not valid JSON", err) + } + if len(models) == 0 { + return nil, newUpstreamModelSyncUpstreamError("Upstream returned no supported models", nil) + } + + return models, nil +} + +func (s *AccountTestService) buildUpstreamModelsRequest(ctx context.Context, account *Account) (*http.Request, error) { + switch { + case account.Platform == PlatformAntigravity: + return s.buildAntigravityAPIKeyModelsRequest(ctx, account) + case account.IsOpenAI(): + return s.buildOpenAIUpstreamModelsRequest(ctx, account) + case account.IsGemini(): + return s.buildGeminiUpstreamModelsRequest(ctx, account) + case account.IsAnthropic(): + return s.buildAnthropicUpstreamModelsRequest(ctx, account) + default: + return nil, newUpstreamModelSyncUnsupportedError( + fmt.Sprintf("Unsupported platform for upstream model sync: %s", account.Platform), nil, + ) + } +} + +func (s *AccountTestService) buildAnthropicUpstreamModelsRequest(ctx context.Context, account *Account) (*http.Request, error) { + if account.IsBedrock() || account.Type == AccountTypeServiceAccount { + return nil, newUpstreamModelSyncUnsupportedError( + fmt.Sprintf("Unsupported Anthropic account type for upstream model sync: %s", account.Type), nil, + ) + } + + baseURL := "https://api.anthropic.com" + authHeaderName := "" + authHeaderValue := "" + betaHeader := "" + + if account.IsOAuth() { + accessToken := strings.TrimSpace(account.GetCredential("access_token")) + if accessToken == "" && s.claudeTokenProvider != nil { + token, tokenErr := s.claudeTokenProvider.GetAccessToken(ctx, account) + if tokenErr != nil { + return nil, newUpstreamModelSyncUpstreamError("Failed to get Anthropic access token", tokenErr) + } + accessToken = strings.TrimSpace(token) + } + if accessToken == "" { + return nil, newUpstreamModelSyncConfigError("No Anthropic access token is available", nil) + } + authHeaderName = "Authorization" + authHeaderValue = "Bearer " + accessToken + betaHeader = claude.DefaultBetaHeader + } else if account.Type == AccountTypeAPIKey { + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if apiKey == "" { + return nil, newUpstreamModelSyncConfigError("No Anthropic API key is available", nil) + } + baseURL = account.GetBaseURL() + if strings.TrimSpace(baseURL) == "" { + baseURL = "https://api.anthropic.com" + } + authHeaderName = "x-api-key" + authHeaderValue = apiKey + betaHeader = claude.APIKeyBetaHeader + } else { + return nil, newUpstreamModelSyncUnsupportedError( + fmt.Sprintf("Unsupported Anthropic account type for upstream model sync: %s", account.Type), nil, + ) + } + + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, newUpstreamModelSyncConfigError("Invalid Anthropic base URL", err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, buildV1ModelsURL(normalizedBaseURL), nil) + if err != nil { + return nil, newUpstreamModelSyncConfigError("Invalid Anthropic model list URL", err) + } + for key, value := range claude.DefaultHeaders { + req.Header.Set(key, value) + } + req.Header.Set("Accept", "application/json") + req.Header.Set("anthropic-version", "2023-06-01") + req.Header.Set("anthropic-beta", betaHeader) + req.Header.Set(authHeaderName, authHeaderValue) + return req, nil +} + +func (s *AccountTestService) buildAntigravityAPIKeyModelsRequest(ctx context.Context, account *Account) (*http.Request, error) { + if account.Type != AccountTypeAPIKey { + return nil, newUpstreamModelSyncUnsupportedError( + fmt.Sprintf("Unsupported Antigravity account type for upstream model sync: %s", account.Type), nil, + ) + } + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if apiKey == "" { + return nil, newUpstreamModelSyncConfigError("No Antigravity API key is available", nil) + } + + baseURL := strings.TrimRight(strings.TrimSpace(account.GetCredential("base_url")), "/") + if baseURL == "" { + return nil, newUpstreamModelSyncConfigError("Antigravity API-key base URL is required for upstream model sync", nil) + } + if !strings.HasSuffix(strings.ToLower(baseURL), "/antigravity") { + return nil, newUpstreamModelSyncUnsupportedError( + "Antigravity API-key upstream model sync requires a compatible gateway base URL ending in /antigravity; use Antigravity OAuth for official Cloud Code upstreams", + nil, + ) + } + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, newUpstreamModelSyncConfigError("Invalid Antigravity base URL", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, buildV1ModelsURL(normalizedBaseURL), nil) + if err != nil { + return nil, newUpstreamModelSyncConfigError("Invalid Antigravity model list URL", err) + } + for key, value := range claude.DefaultHeaders { + req.Header.Set(key, value) + } + req.Header.Set("Accept", "application/json") + req.Header.Set("anthropic-version", "2023-06-01") + req.Header.Set("anthropic-beta", claude.APIKeyBetaHeader) + req.Header.Set("x-api-key", apiKey) + return req, nil +} + +func (s *AccountTestService) buildOpenAIUpstreamModelsRequest(ctx context.Context, account *Account) (*http.Request, error) { + if account.Type != AccountTypeAPIKey { + return nil, newUpstreamModelSyncUnsupportedError( + fmt.Sprintf("Unsupported OpenAI account type for upstream model sync: %s", account.Type), nil, + ) + } + apiKey := strings.TrimSpace(account.GetOpenAIApiKey()) + if apiKey == "" { + return nil, newUpstreamModelSyncConfigError("No OpenAI API key is available", nil) + } + + baseURL := account.GetOpenAIBaseURL() + if strings.TrimSpace(baseURL) == "" { + baseURL = "https://api.openai.com" + } + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, newUpstreamModelSyncConfigError("Invalid OpenAI base URL", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, buildOpenAIModelsURL(normalizedBaseURL), nil) + if err != nil { + return nil, newUpstreamModelSyncConfigError("Invalid OpenAI model list URL", err) + } + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + return req, nil +} + +func (s *AccountTestService) buildGeminiUpstreamModelsRequest(ctx context.Context, account *Account) (*http.Request, error) { + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) + if strings.TrimSpace(baseURL) == "" { + baseURL = geminicli.AIStudioBaseURL + } + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, newUpstreamModelSyncConfigError("Invalid Gemini base URL", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, buildGeminiModelsURL(normalizedBaseURL), nil) + if err != nil { + return nil, newUpstreamModelSyncConfigError("Invalid Gemini model list URL", err) + } + req.Header.Set("Accept", "application/json") + + switch account.Type { + case AccountTypeAPIKey: + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if apiKey == "" { + return nil, newUpstreamModelSyncConfigError("No Gemini API key is available", nil) + } + req.Header.Set("x-goog-api-key", apiKey) + case AccountTypeOAuth: + if strings.TrimSpace(account.GetCredential("project_id")) != "" { + return nil, newUpstreamModelSyncUnsupportedError("Gemini Code Assist model listing is not supported by this sync button", nil) + } + if s.geminiTokenProvider == nil { + return nil, newUpstreamModelSyncConfigError("Gemini token provider is not configured", nil) + } + accessToken, tokenErr := s.geminiTokenProvider.GetAccessToken(ctx, account) + if tokenErr != nil { + return nil, newUpstreamModelSyncUpstreamError("Failed to get Gemini access token", tokenErr) + } + accessToken = strings.TrimSpace(accessToken) + if accessToken == "" { + return nil, newUpstreamModelSyncConfigError("No Gemini access token is available", nil) + } + req.Header.Set("Authorization", "Bearer "+accessToken) + default: + return nil, newUpstreamModelSyncUnsupportedError( + fmt.Sprintf("Unsupported Gemini account type for upstream model sync: %s", account.Type), nil, + ) + } + + return req, nil +} + +func (s *AccountTestService) fetchAntigravityOAuthUpstreamModels(ctx context.Context, account *Account) ([]string, error) { + if s.antigravityGatewayService == nil || s.antigravityGatewayService.GetTokenProvider() == nil { + return nil, newUpstreamModelSyncConfigError("Antigravity token provider is not configured", nil) + } + + accessToken, err := s.antigravityGatewayService.GetTokenProvider().GetAccessToken(ctx, account) + if err != nil { + return nil, newUpstreamModelSyncUpstreamError("Failed to get Antigravity access token", err) + } + accessToken = strings.TrimSpace(accessToken) + if accessToken == "" { + return nil, newUpstreamModelSyncConfigError("No Antigravity access token is available", nil) + } + + client, err := antigravity.NewClient(upstreamModelsProxyURL(account)) + if err != nil { + return nil, newUpstreamModelSyncConfigError("Failed to configure Antigravity client", err) + } + modelsResp, _, err := client.FetchAvailableModels(ctx, accessToken, strings.TrimSpace(account.GetCredential("project_id"))) + if err != nil { + return nil, newUpstreamModelSyncUpstreamError("Failed to fetch Antigravity available models", err) + } + if modelsResp == nil || len(modelsResp.Models) == 0 { + return nil, newUpstreamModelSyncUpstreamError("Upstream returned no supported models", nil) + } + + models := make([]string, 0, len(modelsResp.Models)) + for modelID := range modelsResp.Models { + models = append(models, strings.TrimSpace(modelID)) + } + return dedupeAndSortModelIDs(models), nil +} + +func (s *AccountTestService) doUpstreamModelsRequest(req *http.Request, proxyURL string, account *Account) (*http.Response, error) { + if s.tlsFPProfileService == nil { + return s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, nil) + } + return s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, s.tlsFPProfileService.ResolveTLSProfile(account)) +} + +func upstreamModelsProxyURL(account *Account) string { + if account != nil && account.ProxyID != nil && account.Proxy != nil { + return account.Proxy.URL() + } + return "" +} + +func buildV1ModelsURL(base string) string { + normalized := strings.TrimRight(strings.TrimSpace(base), "/") + if strings.HasSuffix(normalized, "/v1/models") { + return normalized + } + if strings.HasSuffix(normalized, "/v1") { + return normalized + "/models" + } + return normalized + "/v1/models" +} + +func buildOpenAIModelsURL(base string) string { + normalized := strings.TrimRight(strings.TrimSpace(base), "/") + if strings.HasSuffix(normalized, "/v1/models") { + return normalized + } + if strings.HasSuffix(normalized, "/v1") { + return normalized + "/models" + } + return normalized + "/v1/models" +} + +func buildGeminiModelsURL(base string) string { + normalized := strings.TrimRight(strings.TrimSpace(base), "/") + if strings.HasSuffix(normalized, "/v1beta/models") { + return normalized + } + if strings.HasSuffix(normalized, "/v1beta") { + return normalized + "/models" + } + return normalized + "/v1beta/models" +} + +type upstreamModelEntry struct { + ID string `json:"id"` + Name string `json:"name"` +} + +func extractUpstreamModelIDs(body []byte) ([]string, error) { + var response struct { + Data []upstreamModelEntry `json:"data"` + Models []upstreamModelEntry `json:"models"` + } + if err := json.Unmarshal(body, &response); err != nil { + var arrayResponse []upstreamModelEntry + if arrayErr := json.Unmarshal(body, &arrayResponse); arrayErr != nil { + return nil, fmt.Errorf("parse upstream model list: %w", err) + } + + models := make([]string, 0, len(arrayResponse)) + for _, entry := range arrayResponse { + models = append(models, upstreamModelEntryID(entry)) + } + return dedupeAndSortModelIDs(models), nil + } + + models := make([]string, 0, len(response.Data)+len(response.Models)) + for _, entry := range response.Data { + models = append(models, upstreamModelEntryID(entry)) + } + for _, entry := range response.Models { + models = append(models, upstreamModelEntryID(entry)) + } + + if len(models) == 0 { + var arrayResponse []upstreamModelEntry + if err := json.Unmarshal(body, &arrayResponse); err == nil { + for _, entry := range arrayResponse { + models = append(models, upstreamModelEntryID(entry)) + } + } + } + + return dedupeAndSortModelIDs(models), nil +} + +func upstreamModelEntryID(entry upstreamModelEntry) string { + modelID := strings.TrimSpace(entry.ID) + if modelID == "" { + modelID = strings.TrimSpace(entry.Name) + } + return strings.TrimPrefix(modelID, "models/") +} + +func dedupeAndSortModelIDs(models []string) []string { + seen := make(map[string]struct{}, len(models)) + result := make([]string, 0, len(models)) + for _, model := range models { + model = strings.TrimSpace(model) + if model == "" { + continue + } + if _, exists := seen[model]; exists { + continue + } + seen[model] = struct{}{} + result = append(result, model) + } + sort.Strings(result) + return result +} diff --git a/backend/internal/service/upstream_models_test.go b/backend/internal/service/upstream_models_test.go new file mode 100644 index 00000000000..6831e79187f --- /dev/null +++ b/backend/internal/service/upstream_models_test.go @@ -0,0 +1,226 @@ +package service + +import ( + "context" + "errors" + "io" + "net/http" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func upstreamModelSyncTestConfig() *config.Config { + return &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{Enabled: false}, + }, + } +} + +func TestBuildV1ModelsURL(t *testing.T) { + t.Parallel() + + require.Equal(t, "https://api.anthropic.com/v1/models", buildV1ModelsURL("https://api.anthropic.com")) + require.Equal(t, "https://api.anthropic.com/v1/models", buildV1ModelsURL("https://api.anthropic.com/v1")) + require.Equal(t, "https://api.anthropic.com/v1/models", buildV1ModelsURL("https://api.anthropic.com/v1/models")) + require.Equal(t, "https://gateway.example.com/antigravity/v1/models", buildV1ModelsURL("https://gateway.example.com/antigravity/")) +} + +func TestBuildGeminiModelsURL(t *testing.T) { + t.Parallel() + + require.Equal(t, "https://generativelanguage.googleapis.com/v1beta/models", buildGeminiModelsURL("https://generativelanguage.googleapis.com")) + require.Equal(t, "https://generativelanguage.googleapis.com/v1beta/models", buildGeminiModelsURL("https://generativelanguage.googleapis.com/v1beta")) + require.Equal(t, "https://generativelanguage.googleapis.com/v1beta/models", buildGeminiModelsURL("https://generativelanguage.googleapis.com/v1beta/models")) +} + +func TestExtractUpstreamModelIDs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + body string + want []string + }{ + { + name: "openai and anthropic data array", + body: `{"data":[{"id":"claude-sonnet-4-5"},{"id":"gpt-5"},{"id":"gpt-5"},{"id":""}]}`, + want: []string{"claude-sonnet-4-5", "gpt-5"}, + }, + { + name: "gemini models array strips prefix", + body: `{"models":[{"name":"models/gemini-2.5-pro"},{"name":"gemini-2.5-flash"}]}`, + want: []string{"gemini-2.5-flash", "gemini-2.5-pro"}, + }, + { + name: "top level array", + body: `[{"id":"z-model"},{"name":"models/a-model"}]`, + want: []string{"a-model", "z-model"}, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got, err := extractUpstreamModelIDs([]byte(tt.body)) + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +} + +func TestBuildUpstreamModelsRequestsForAPIKeyAccounts(t *testing.T) { + t.Parallel() + + svc := &AccountTestService{cfg: upstreamModelSyncTestConfig()} + ctx := context.Background() + + anthropicReq, err := svc.buildAnthropicUpstreamModelsRequest(ctx, &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "anthropic-key", + "base_url": "https://anthropic.example.com/v1", + }, + }) + require.NoError(t, err) + require.Equal(t, "https://anthropic.example.com/v1/models", anthropicReq.URL.String()) + require.Equal(t, "anthropic-key", anthropicReq.Header.Get("x-api-key")) + require.Equal(t, "2023-06-01", anthropicReq.Header.Get("anthropic-version")) + + openAIReq, err := svc.buildOpenAIUpstreamModelsRequest(ctx, &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "openai-key", + "base_url": "https://openai.example.com", + }, + }) + require.NoError(t, err) + require.Equal(t, "https://openai.example.com/v1/models", openAIReq.URL.String()) + require.Equal(t, "Bearer openai-key", openAIReq.Header.Get("Authorization")) + + geminiReq, err := svc.buildGeminiUpstreamModelsRequest(ctx, &Account{ + Platform: PlatformGemini, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "gemini-key", + "base_url": "https://generativelanguage.googleapis.com/v1beta", + }, + }) + require.NoError(t, err) + require.Equal(t, "https://generativelanguage.googleapis.com/v1beta/models", geminiReq.URL.String()) + require.Equal(t, "gemini-key", geminiReq.Header.Get("x-goog-api-key")) + + antigravityReq, err := svc.buildAntigravityAPIKeyModelsRequest(ctx, &Account{ + Platform: PlatformAntigravity, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "antigravity-key", + "base_url": "https://gateway.example.com/antigravity", + }, + }) + require.NoError(t, err) + require.Equal(t, "https://gateway.example.com/antigravity/v1/models", antigravityReq.URL.String()) + require.Equal(t, "antigravity-key", antigravityReq.Header.Get("x-api-key")) +} + +func TestBuildAntigravityAPIKeyModelsRequestRejectsOfficialCloudCodeBase(t *testing.T) { + t.Parallel() + + svc := &AccountTestService{cfg: upstreamModelSyncTestConfig()} + _, err := svc.buildAntigravityAPIKeyModelsRequest(context.Background(), &Account{ + Platform: PlatformAntigravity, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "antigravity-key", + "base_url": "https://cloudcode-pa.googleapis.com", + }, + }) + require.Error(t, err) + + var syncErr *UpstreamModelSyncError + require.True(t, errors.As(err, &syncErr)) + require.Equal(t, UpstreamModelSyncErrorUnsupported, syncErr.Kind) + require.Contains(t, syncErr.SafeMessage(), "compatible gateway") +} + +func TestBuildAnthropicUpstreamModelsRequestRejectsBedrock(t *testing.T) { + t.Parallel() + + svc := &AccountTestService{cfg: upstreamModelSyncTestConfig()} + _, err := svc.buildAnthropicUpstreamModelsRequest(context.Background(), &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeBedrock, + }) + require.Error(t, err) + + var syncErr *UpstreamModelSyncError + require.True(t, errors.As(err, &syncErr)) + require.Equal(t, UpstreamModelSyncErrorUnsupported, syncErr.Kind) +} + +func TestFetchUpstreamSupportedModelsParsesOpenAIResponse(t *testing.T) { + t.Parallel() + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"data":[{"id":"gpt-5"},{"id":"gpt-5"},{"name":"o3"}]}`)), + }} + svc := &AccountTestService{ + httpUpstream: upstream, + cfg: upstreamModelSyncTestConfig(), + } + + models, err := svc.FetchUpstreamSupportedModels(context.Background(), &Account{ + ID: 7, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "openai-key", + "base_url": "https://openai.example.com/v1", + }, + }) + require.NoError(t, err) + require.Equal(t, []string{"gpt-5", "o3"}, models) + require.Equal(t, "https://openai.example.com/v1/models", upstream.lastReq.URL.String()) + require.Equal(t, "Bearer openai-key", upstream.lastReq.Header.Get("Authorization")) +} + +func TestFetchUpstreamSupportedModelsDoesNotExposeUpstreamBody(t *testing.T) { + t.Parallel() + + upstream := &httpUpstreamRecorder{resp: &http.Response{ + StatusCode: http.StatusBadGateway, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"error":"SECRET_TOKEN should not be exposed"}`)), + }} + svc := &AccountTestService{ + httpUpstream: upstream, + cfg: upstreamModelSyncTestConfig(), + } + + _, err := svc.FetchUpstreamSupportedModels(context.Background(), &Account{ + ID: 8, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "openai-key", + "base_url": "https://openai.example.com/v1", + }, + }) + require.Error(t, err) + require.NotContains(t, err.Error(), "SECRET_TOKEN") + + var syncErr *UpstreamModelSyncError + require.True(t, errors.As(err, &syncErr)) + require.Equal(t, UpstreamModelSyncErrorUpstream, syncErr.Kind) + require.NotContains(t, syncErr.SafeMessage(), "SECRET_TOKEN") + require.Contains(t, syncErr.SafeMessage(), "HTTP 502") +} diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index e29d282eb80..d63f47ccf0f 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -162,9 +162,13 @@ type UsageLog struct { CacheTTLOverridden bool // 图片生成字段 - ImageCount int - ImageSize *string - MediaType *string + ImageCount int + ImageSize *string + ImageInputSize *string + ImageOutputSize *string + ImageSizeSource *string + ImageSizeBreakdown map[string]int + MediaType *string CreatedAt time.Time diff --git a/backend/internal/service/user_attribute_service.go b/backend/internal/service/user_attribute_service.go index 6c2f807767c..ef19e078b3b 100644 --- a/backend/internal/service/user_attribute_service.go +++ b/backend/internal/service/user_attribute_service.go @@ -72,6 +72,11 @@ func (s *UserAttributeService) GetDefinition(ctx context.Context, id int64) (*Us return s.defRepo.GetByID(ctx, id) } +// GetDefinitionByKey retrieves a definition by its unique key +func (s *UserAttributeService) GetDefinitionByKey(ctx context.Context, key string) (*UserAttributeDefinition, error) { + return s.defRepo.GetByKey(ctx, key) +} + // ListDefinitions lists all definitions func (s *UserAttributeService) ListDefinitions(ctx context.Context, enabledOnly bool) ([]UserAttributeDefinition, error) { return s.defRepo.List(ctx, enabledOnly) diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index f84e6f0ab06..b346f6e705f 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -141,10 +141,11 @@ type UserIdentitySummary struct { } type UserIdentitySummarySet struct { - Email UserIdentitySummary `json:"email"` - LinuxDo UserIdentitySummary `json:"linuxdo"` - OIDC UserIdentitySummary `json:"oidc"` - WeChat UserIdentitySummary `json:"wechat"` + Email UserIdentitySummary `json:"email"` + LinuxDo UserIdentitySummary `json:"linuxdo"` + OIDC UserIdentitySummary `json:"oidc"` + WeChat UserIdentitySummary `json:"wechat"` + DingTalk UserIdentitySummary `json:"dingtalk"` } type StartUserIdentityBindingRequest struct { @@ -260,10 +261,11 @@ func (s *UserService) GetProfileIdentitySummaries(ctx context.Context, userID in } summaries := UserIdentitySummarySet{ - Email: s.buildEmailIdentitySummary(user, records), - LinuxDo: s.buildProviderIdentitySummary("linuxdo", user, records), - OIDC: s.buildProviderIdentitySummary("oidc", user, records), - WeChat: s.buildProviderIdentitySummary("wechat", user, records), + Email: s.buildEmailIdentitySummary(user, records), + LinuxDo: s.buildProviderIdentitySummary("linuxdo", user, records), + OIDC: s.buildProviderIdentitySummary("oidc", user, records), + WeChat: s.buildProviderIdentitySummary("wechat", user, records), + DingTalk: s.buildProviderIdentitySummary("dingtalk", user, records), } s.applyExplicitProviderAvailability(ctx, &summaries) @@ -283,6 +285,7 @@ func (s *UserService) applyExplicitProviderAvailability(ctx context.Context, sum SettingKeyWeChatConnectMPEnabled, SettingKeyWeChatConnectMobileEnabled, SettingKeyWeChatConnectMode, + SettingKeyDingTalkConnectEnabled, }) if err != nil { return @@ -291,6 +294,9 @@ func (s *UserService) applyExplicitProviderAvailability(ctx context.Context, sum if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok && strings.TrimSpace(raw) != "" && raw != "true" { disableIdentityBindAction(&summaries.LinuxDo) } + if raw, ok := settings[SettingKeyDingTalkConnectEnabled]; ok && strings.TrimSpace(raw) != "" && raw != "true" { + disableIdentityBindAction(&summaries.DingTalk) + } if raw, ok := settings[SettingKeyOIDCConnectEnabled]; ok && strings.TrimSpace(raw) != "" && raw != "true" { disableIdentityBindAction(&summaries.OIDC) } @@ -696,7 +702,7 @@ func (s *UserService) canUnbindProvider(provider string, user *User, records []U return true } - for _, candidate := range []string{"linuxdo", "oidc", "wechat"} { + for _, candidate := range []string{"linuxdo", "oidc", "wechat", "dingtalk"} { if candidate == provider { continue } @@ -772,6 +778,8 @@ func buildUserIdentityBindAuthorizeURL(provider, redirectTo string) (string, err path = "/api/v1/auth/oauth/oidc/bind/start" case "wechat": path = "/api/v1/auth/oauth/wechat/bind/start" + case "dingtalk": + path = "/api/v1/auth/oauth/dingtalk/bind/start" default: return "", ErrIdentityProviderInvalid } @@ -790,6 +798,8 @@ func normalizeUserIdentityProvider(provider string) string { return "oidc" case "wechat": return "wechat" + case "dingtalk": + return "dingtalk" case "email": return "email" default: diff --git a/backend/internal/service/user_subscription.go b/backend/internal/service/user_subscription.go index ec547d81a5a..6303e6e3e36 100644 --- a/backend/internal/service/user_subscription.go +++ b/backend/internal/service/user_subscription.go @@ -50,11 +50,25 @@ func (s *UserSubscription) IsWindowActivated() bool { return s.DailyWindowStart != nil || s.WeeklyWindowStart != nil || s.MonthlyWindowStart != nil } +func (s *UserSubscription) HasOneTimeDailyQuota() bool { + if s == nil || s.StartsAt.IsZero() || s.ExpiresAt.IsZero() { + return false + } + return !s.ExpiresAt.After(s.StartsAt.AddDate(0, 0, 1)) +} + func (s *UserSubscription) NeedsDailyReset() bool { + return s.NeedsDailyResetAt(time.Now()) +} + +func (s *UserSubscription) NeedsDailyResetAt(now time.Time) bool { if s.DailyWindowStart == nil { return false } - return time.Since(*s.DailyWindowStart) >= 24*time.Hour + if s.HasOneTimeDailyQuota() { + return false + } + return !now.Before(s.DailyWindowStart.Add(24 * time.Hour)) } func (s *UserSubscription) NeedsWeeklyReset() bool { @@ -75,6 +89,10 @@ func (s *UserSubscription) DailyResetTime() *time.Time { if s.DailyWindowStart == nil { return nil } + if s.HasOneTimeDailyQuota() { + t := s.ExpiresAt + return &t + } t := s.DailyWindowStart.Add(24 * time.Hour) return &t } diff --git a/backend/internal/service/user_subscription_daily_quota_test.go b/backend/internal/service/user_subscription_daily_quota_test.go new file mode 100644 index 00000000000..3738bdd6980 --- /dev/null +++ b/backend/internal/service/user_subscription_daily_quota_test.go @@ -0,0 +1,178 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type dailyResetTrackingUserSubRepo struct { + userSubRepoNoop + + resetDailyCalled bool +} + +func (r *dailyResetTrackingUserSubRepo) ResetDailyUsage(context.Context, int64, time.Time) error { + r.resetDailyCalled = true + return nil +} + +func TestAssignOrExtendSubscription_ExpiredDailyCardStartsNewOneTimeQuota(t *testing.T) { + groupRepo := &subscriptionGroupRepoStub{ + group: &Group{ID: 1, SubscriptionType: SubscriptionTypeSubscription}, + } + subRepo := newSubscriptionUserSubRepoStub() + oldStart := time.Now().AddDate(0, 0, -3) + oldWindowStart := startOfDay(oldStart) + subRepo.seed(&UserSubscription{ + ID: 100, + UserID: 200, + GroupID: 1, + StartsAt: oldStart, + ExpiresAt: oldStart.AddDate(0, 0, 1), + Status: SubscriptionStatusExpired, + DailyWindowStart: &oldWindowStart, + WeeklyWindowStart: &oldWindowStart, + MonthlyWindowStart: &oldWindowStart, + DailyUsageUSD: 10, + WeeklyUsageUSD: 20, + MonthlyUsageUSD: 30, + Notes: "old", + }) + svc := NewSubscriptionService(groupRepo, subRepo, nil, nil, nil) + + renewed, reused, err := svc.AssignOrExtendSubscription(context.Background(), &AssignSubscriptionInput{ + UserID: 200, + GroupID: 1, + ValidityDays: 1, + Notes: "new", + }) + + require.NoError(t, err) + require.True(t, reused) + require.True(t, renewed.HasOneTimeDailyQuota(), "过期后重新购买 1 日卡仍应被识别为一次性日额度") + require.Equal(t, SubscriptionStatusActive, renewed.Status) + require.True(t, renewed.StartsAt.After(oldStart), "重新购买过期订阅时应重置当前周期 StartsAt") + require.False(t, renewed.ExpiresAt.After(renewed.StartsAt.AddDate(0, 0, 1))) + require.NotNil(t, renewed.DailyWindowStart) + require.Equal(t, startOfDay(renewed.StartsAt), *renewed.DailyWindowStart) + require.Equal(t, 0.0, renewed.DailyUsageUSD) + require.Equal(t, 0.0, renewed.WeeklyUsageUSD) + require.Equal(t, 0.0, renewed.MonthlyUsageUSD) + require.Equal(t, "old\nnew", renewed.Notes) +} + +func TestUserSubscriptionNeedsDailyReset_DailyCardKeepsOneTimeQuota(t *testing.T) { + start := time.Date(2026, 5, 18, 12, 0, 0, 0, time.UTC) + dailyWindowStart := time.Date(2026, 5, 18, 0, 0, 0, 0, time.UTC) + sub := &UserSubscription{ + StartsAt: start, + ExpiresAt: start.Add(24 * time.Hour), + DailyWindowStart: &dailyWindowStart, + DailyUsageUSD: 10, + } + + require.True(t, sub.HasOneTimeDailyQuota()) + require.False(t, sub.NeedsDailyResetAt(dailyWindowStart.Add(25*time.Hour)), "日卡应作为一次性配额,跨 0 点后不再刷新日额度") +} + +func TestUserSubscriptionNeedsDailyReset_MultiDaySubscriptionStillRefreshes(t *testing.T) { + start := time.Date(2026, 5, 18, 12, 0, 0, 0, time.UTC) + dailyWindowStart := time.Date(2026, 5, 18, 0, 0, 0, 0, time.UTC) + sub := &UserSubscription{ + StartsAt: start, + ExpiresAt: start.AddDate(0, 0, 2), + DailyWindowStart: &dailyWindowStart, + } + + require.False(t, sub.HasOneTimeDailyQuota()) + require.True(t, sub.NeedsDailyResetAt(dailyWindowStart.Add(24*time.Hour)), "多日订阅仍应按 24 小时日窗口刷新") +} + +func TestUserSubscriptionDailyResetTime_DailyCardReturnsExpiry(t *testing.T) { + start := time.Date(2026, 5, 18, 12, 0, 0, 0, time.UTC) + dailyWindowStart := time.Date(2026, 5, 18, 0, 0, 0, 0, time.UTC) + expiresAt := start.Add(24 * time.Hour) + sub := &UserSubscription{ + StartsAt: start, + ExpiresAt: expiresAt, + DailyWindowStart: &dailyWindowStart, + } + + resetAt := sub.DailyResetTime() + require.NotNil(t, resetAt) + require.Equal(t, expiresAt, *resetAt, "日卡展示的日额度结束时间应为订阅过期时间") +} + +func TestCheckAndResetWindows_DailyCardDoesNotResetDailyUsage(t *testing.T) { + now := time.Now() + startsAt := now.Add(-23 * time.Hour) + dailyWindowStart := now.Add(-25 * time.Hour) + repo := &dailyResetTrackingUserSubRepo{} + svc := NewSubscriptionService(groupRepoNoop{}, repo, nil, nil, nil) + sub := &UserSubscription{ + ID: 1, + UserID: 10, + GroupID: 20, + StartsAt: startsAt, + ExpiresAt: startsAt.Add(24 * time.Hour), + DailyUsageUSD: 10, + DailyWindowStart: &dailyWindowStart, + } + + err := svc.CheckAndResetWindows(context.Background(), sub) + + require.NoError(t, err) + require.False(t, repo.resetDailyCalled, "日卡作为一次性配额,过了 24 小时日窗口也不应重置 daily usage") + require.Equal(t, 10.0, sub.DailyUsageUSD) +} + +func TestCheckAndResetWindows_MultiDaySubscriptionStillResetsDailyUsage(t *testing.T) { + now := time.Now() + startsAt := now.Add(-48 * time.Hour) + dailyWindowStart := now.Add(-25 * time.Hour) + repo := &dailyResetTrackingUserSubRepo{} + svc := NewSubscriptionService(groupRepoNoop{}, repo, nil, nil, nil) + sub := &UserSubscription{ + ID: 1, + UserID: 10, + GroupID: 20, + StartsAt: startsAt, + ExpiresAt: startsAt.AddDate(0, 0, 2), + DailyUsageUSD: 10, + DailyWindowStart: &dailyWindowStart, + } + + err := svc.CheckAndResetWindows(context.Background(), sub) + + require.NoError(t, err) + require.True(t, repo.resetDailyCalled, "多日订阅仍应重置过期 daily window") + require.Equal(t, 0.0, sub.DailyUsageUSD) +} + +func TestValidateAndCheckLimits_DailyCardDoesNotAllowSecondQuotaAfterMidnight(t *testing.T) { + start := time.Now().Add(-23 * time.Hour) + dailyWindowStart := time.Now().Add(-25 * time.Hour) + dailyLimit := 10.0 + sub := &UserSubscription{ + Status: SubscriptionStatusActive, + StartsAt: start, + ExpiresAt: start.Add(24 * time.Hour), + DailyWindowStart: &dailyWindowStart, + DailyUsageUSD: dailyLimit + 0.01, + } + group := &Group{ + SubscriptionType: SubscriptionTypeSubscription, + DailyLimitUSD: &dailyLimit, + } + svc := NewSubscriptionService(groupRepoNoop{}, userSubRepoNoop{}, nil, nil, nil) + + needsMaintenance, err := svc.ValidateAndCheckLimits(sub, group) + + require.False(t, needsMaintenance, "日卡跨过日窗口后不应触发 daily reset 维护") + require.True(t, errors.Is(err, ErrDailyLimitExceeded)) + require.Equal(t, dailyLimit+0.01, sub.DailyUsageUSD, "热路径不应清零日卡已用额度") +} diff --git a/backend/migrations/136_add_dingtalk_provider_type.sql b/backend/migrations/136_add_dingtalk_provider_type.sql new file mode 100644 index 00000000000..79c7ba05abd --- /dev/null +++ b/backend/migrations/136_add_dingtalk_provider_type.sql @@ -0,0 +1,27 @@ +ALTER TABLE users + DROP CONSTRAINT IF EXISTS users_signup_source_check; + +ALTER TABLE users + ADD CONSTRAINT users_signup_source_check + CHECK (signup_source IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google', 'dingtalk')); + +ALTER TABLE auth_identities + DROP CONSTRAINT IF EXISTS auth_identities_provider_type_check; + +ALTER TABLE auth_identities + ADD CONSTRAINT auth_identities_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google', 'dingtalk')); + +ALTER TABLE auth_identity_channels + DROP CONSTRAINT IF EXISTS auth_identity_channels_provider_type_check; + +ALTER TABLE auth_identity_channels + ADD CONSTRAINT auth_identity_channels_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google', 'dingtalk')); + +ALTER TABLE pending_auth_sessions + DROP CONSTRAINT IF EXISTS pending_auth_sessions_provider_type_check; + +ALTER TABLE pending_auth_sessions + ADD CONSTRAINT pending_auth_sessions_provider_type_check + CHECK (provider_type IN ('email', 'linuxdo', 'wechat', 'oidc', 'github', 'google', 'dingtalk')); diff --git a/backend/migrations/136_usage_log_image_size_metadata.sql b/backend/migrations/136_usage_log_image_size_metadata.sql new file mode 100644 index 00000000000..76bcb956496 --- /dev/null +++ b/backend/migrations/136_usage_log_image_size_metadata.sql @@ -0,0 +1,51 @@ +-- Add generated-image billing size audit metadata. +-- `image_size` remains the canonical billing tier used for cost calculation. + +ALTER TABLE usage_logs + ADD COLUMN IF NOT EXISTS image_input_size VARCHAR(32); + +ALTER TABLE usage_logs + ADD COLUMN IF NOT EXISTS image_output_size VARCHAR(32); + +ALTER TABLE usage_logs + ADD COLUMN IF NOT EXISTS image_size_source VARCHAR(16); + +ALTER TABLE usage_logs + ADD COLUMN IF NOT EXISTS image_size_breakdown JSONB; + +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'usage_logs_image_size_source_check' + AND conrelid = 'usage_logs'::regclass + ) THEN + ALTER TABLE usage_logs + ADD CONSTRAINT usage_logs_image_size_source_check + CHECK ( + image_size_source IS NULL + OR image_size_source IN ('output', 'input', 'default', 'legacy') + ) NOT VALID; + END IF; +END $$; + +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'usage_logs_image_billing_size_check' + AND conrelid = 'usage_logs'::regclass + ) THEN + ALTER TABLE usage_logs + ADD CONSTRAINT usage_logs_image_billing_size_check + CHECK ( + image_count <= 0 + OR ( + image_size IS NOT NULL + AND image_size IN ('1K', '2K', '4K', 'mixed') + ) + ) NOT VALID; + END IF; +END $$; diff --git a/backend/migrations/137_redeem_code_expires_at.sql b/backend/migrations/137_redeem_code_expires_at.sql new file mode 100644 index 00000000000..4fa27927c82 --- /dev/null +++ b/backend/migrations/137_redeem_code_expires_at.sql @@ -0,0 +1,8 @@ +-- Add optional expiry time for redeem codes themselves. +-- `validity_days` remains the subscription duration granted after redeeming. + +ALTER TABLE redeem_codes + ADD COLUMN IF NOT EXISTS expires_at TIMESTAMPTZ; + +CREATE INDEX IF NOT EXISTS idx_redeem_codes_expires_at + ON redeem_codes (expires_at); diff --git a/backend/resources/model-pricing/model_prices_and_context_window.json b/backend/resources/model-pricing/model_prices_and_context_window.json index 0a096257d4a..3cae8c8b58d 100644 --- a/backend/resources/model-pricing/model_prices_and_context_window.json +++ b/backend/resources/model-pricing/model_prices_and_context_window.json @@ -5173,6 +5173,39 @@ "supports_tool_choice": true, "supports_vision": true }, + "codex-auto-review": { + "cache_read_input_token_cost": 2.5e-07, + "input_cost_per_token": 2.5e-06, + "litellm_provider": "openai", + "max_input_tokens": 1050000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 1.5e-05, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/responses" + ], + "supported_modalities": [ + "text", + "image" + ], + "supported_output_modalities": [ + "text", + "image" + ], + "supports_function_calling": true, + "supports_native_streaming": true, + "supports_parallel_function_calling": true, + "supports_pdf_input": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_response_schema": true, + "supports_service_tier": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "supports_vision": true + }, "gpt-5.4-mini": { "cache_read_input_token_cost": 7.5e-08, "input_cost_per_token": 7.5e-07, diff --git a/deploy/install.sh b/deploy/install.sh index 6dcf41238e7..1846dede353 100644 --- a/deploy/install.sh +++ b/deploy/install.sh @@ -7,6 +7,21 @@ set -e +# Bash 4+ is required for associative arrays used by the localized message table. +# Keep this guard before any Bash 4-only syntax so older shells fail with a clear hint. +if [ -z "${BASH_VERSION:-}" ]; then + echo "Error: This installer must be run with Bash 4.0 or later." >&2 + echo "Please install Bash 4+ and run it with that interpreter." >&2 + exit 1 +fi + +BASH_MAJOR_VERSION="${BASH_VERSION%%.*}" +if [ "$BASH_MAJOR_VERSION" -lt 4 ]; then + echo "Error: Bash 4.0 or later is required. Current version: $BASH_VERSION" >&2 + echo "Please install Bash 4+ and retry with that interpreter." >&2 + exit 1 +fi + # Colors RED='\033[0;31m' GREEN='\033[0;32m' diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 00ed40878c3..92b0abcab23 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -232,9 +232,12 @@ export async function clearError(id: number): Promise { * @param id - Account ID * @returns Account usage info */ -export async function getUsage(id: number, source?: 'passive' | 'active'): Promise { +export async function getUsage(id: number, source?: 'passive' | 'active', force?: boolean): Promise { + const params: Record = {} + if (source) params.source = source + if (force) params.force = 'true' const { data } = await apiClient.get(`/admin/accounts/${id}/usage`, { - params: source ? { source } : undefined + params: Object.keys(params).length > 0 ? params : undefined }) return data } @@ -446,6 +449,20 @@ export async function getAvailableModels(id: number): Promise { return data } +export interface SyncUpstreamModelsResult { + models: string[] +} + +/** + * Sync live supported models from the account's upstream model-list endpoint + * @param id - Account ID + * @returns List of model IDs returned by the upstream + */ +export async function syncUpstreamModels(id: number): Promise { + const { data } = await apiClient.post(`/admin/accounts/${id}/models/sync-upstream`) + return data +} + export interface CRSPreviewAccount { crs_account_id: string kind: string @@ -660,6 +677,7 @@ export const accountsAPI = { resetTempUnschedulable, setSchedulable, getAvailableModels, + syncUpstreamModels, generateAuthUrl, exchangeCode, refreshOpenAIToken, diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts index 49e487b744b..dda7d8927d0 100644 --- a/frontend/src/api/admin/dashboard.ts +++ b/frontend/src/api/admin/dashboard.ts @@ -266,10 +266,17 @@ export async function getUserSpendingRanking( return data } +export interface PlatformUsage { + platform: string + today_actual_cost: number + total_actual_cost: number +} + export interface BatchUserUsageStats { user_id: number today_actual_cost: number total_actual_cost: number + by_platform?: PlatformUsage[] } export interface BatchUsersUsageResponse { diff --git a/frontend/src/api/admin/redeem.ts b/frontend/src/api/admin/redeem.ts index 57626b1efd9..398d68a4e08 100644 --- a/frontend/src/api/admin/redeem.ts +++ b/frontend/src/api/admin/redeem.ts @@ -60,6 +60,7 @@ export async function getById(id: number): Promise { * @param value - Value of the code * @param groupId - Group ID (required for subscription type) * @param validityDays - Validity days (for subscription type) + * @param expiresInDays - Days before the code itself expires * @returns Array of generated redeem codes */ export async function generate( @@ -67,7 +68,8 @@ export async function generate( type: RedeemCodeType, value: number, groupId?: number | null, - validityDays?: number + validityDays?: number, + expiresInDays?: number | null ): Promise { const payload: GenerateRedeemCodesRequest = { count, @@ -82,6 +84,9 @@ export async function generate( payload.validity_days = validityDays } } + if (expiresInDays && expiresInDays > 0) { + payload.expires_in_days = expiresInDays + } const { data } = await apiClient.post('/admin/redeem-codes/generate', payload) return data diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 03e9e58fde5..4fc49de6227 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -22,7 +22,8 @@ export type AuthSourceType = | "oidc" | "wechat" | "github" - | "google"; + | "google" + | "dingtalk"; export interface AuthSourceDefaultsValue { balance: number; @@ -64,6 +65,7 @@ const AUTH_SOURCE_TYPES: AuthSourceType[] = [ "wechat", "github", "google", + "dingtalk", ]; const AUTH_SOURCE_DEFAULT_BALANCE = 0; const AUTH_SOURCE_DEFAULT_CONCURRENCY = 5; @@ -352,6 +354,11 @@ export interface SystemSettings { auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]; auth_source_default_wechat_grant_on_signup?: boolean; auth_source_default_wechat_grant_on_first_bind?: boolean; + auth_source_default_dingtalk_balance?: number; + auth_source_default_dingtalk_concurrency?: number; + auth_source_default_dingtalk_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_dingtalk_grant_on_signup?: boolean; + auth_source_default_dingtalk_grant_on_first_bind?: boolean; auth_source_default_github_balance?: number; auth_source_default_github_concurrency?: number; auth_source_default_github_subscriptions?: DefaultSubscriptionSetting[]; @@ -396,6 +403,24 @@ export interface SystemSettings { linuxdo_connect_client_secret_configured: boolean; linuxdo_connect_redirect_url: string; + // DingTalk Connect OAuth settings + dingtalk_connect_enabled: boolean; + dingtalk_connect_client_id: string; + dingtalk_connect_client_secret_configured: boolean; + dingtalk_connect_redirect_url: string; + dingtalk_connect_corp_restriction_policy: string; + dingtalk_connect_internal_corp_id: string; + dingtalk_connect_bypass_registration: boolean; + dingtalk_connect_sync_corp_email: boolean; + dingtalk_connect_sync_display_name: boolean; + dingtalk_connect_sync_dept: boolean; + dingtalk_connect_sync_corp_email_attr_key: string; + dingtalk_connect_sync_display_name_attr_key: string; + dingtalk_connect_sync_dept_attr_key: string; + dingtalk_connect_sync_corp_email_attr_name: string; + dingtalk_connect_sync_display_name_attr_name: string; + dingtalk_connect_sync_dept_attr_name: string; + // WeChat Connect OAuth settings wechat_connect_enabled: boolean; wechat_connect_app_id: string; @@ -571,6 +596,11 @@ export interface UpdateSettingsRequest { auth_source_default_wechat_subscriptions?: DefaultSubscriptionSetting[]; auth_source_default_wechat_grant_on_signup?: boolean; auth_source_default_wechat_grant_on_first_bind?: boolean; + auth_source_default_dingtalk_balance?: number; + auth_source_default_dingtalk_concurrency?: number; + auth_source_default_dingtalk_subscriptions?: DefaultSubscriptionSetting[]; + auth_source_default_dingtalk_grant_on_signup?: boolean; + auth_source_default_dingtalk_grant_on_first_bind?: boolean; auth_source_default_github_balance?: number; auth_source_default_github_concurrency?: number; auth_source_default_github_subscriptions?: DefaultSubscriptionSetting[]; @@ -609,6 +639,22 @@ export interface UpdateSettingsRequest { linuxdo_connect_client_id?: string; linuxdo_connect_client_secret?: string; linuxdo_connect_redirect_url?: string; + dingtalk_connect_enabled?: boolean; + dingtalk_connect_client_id?: string; + dingtalk_connect_client_secret?: string; + dingtalk_connect_redirect_url?: string; + dingtalk_connect_corp_restriction_policy?: string; + dingtalk_connect_internal_corp_id?: string; + dingtalk_connect_bypass_registration?: boolean; + dingtalk_connect_sync_corp_email?: boolean; + dingtalk_connect_sync_display_name?: boolean; + dingtalk_connect_sync_dept?: boolean; + dingtalk_connect_sync_corp_email_attr_key?: string; + dingtalk_connect_sync_display_name_attr_key?: string; + dingtalk_connect_sync_dept_attr_key?: string; + dingtalk_connect_sync_corp_email_attr_name?: string; + dingtalk_connect_sync_display_name_attr_name?: string; + dingtalk_connect_sync_dept_attr_name?: string; wechat_connect_enabled?: boolean; wechat_connect_app_id?: string; wechat_connect_app_secret?: string; diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index bb990fc4af3..fd259230677 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -592,7 +592,7 @@ export async function completeWeChatOAuthRegistration( } async function createPendingOAuthAccount( - provider: 'linuxdo' | 'oidc' | 'wechat', + provider: 'linuxdo' | 'oidc' | 'wechat' | 'dingtalk', invitationCode: string, decision?: OAuthAdoptionDecision, affiliateCode?: string @@ -633,6 +633,14 @@ export async function createPendingWeChatOAuthAccount( return createPendingOAuthAccount('wechat', invitationCode, decision, affiliateCode) } +export async function createPendingDingTalkOAuthAccount( + invitationCode: string, + decision?: OAuthAdoptionDecision, + affiliateCode?: string +): Promise { + return createPendingOAuthAccount('dingtalk', invitationCode, decision, affiliateCode) +} + export async function completePendingOAuthBindLogin( decision?: OAuthAdoptionDecision ): Promise { @@ -683,7 +691,8 @@ export const authAPI = { exchangePendingOAuthCompletion, completeLinuxDoOAuthRegistration, completeOIDCOAuthRegistration, - completeWeChatOAuthRegistration + completeWeChatOAuthRegistration, + createPendingDingTalkOAuthAccount } export default authAPI diff --git a/frontend/src/api/usage.ts b/frontend/src/api/usage.ts index 802c428f80b..7169b698dfe 100644 --- a/frontend/src/api/usage.ts +++ b/frontend/src/api/usage.ts @@ -15,6 +15,16 @@ import type { // ==================== Dashboard Types ==================== +export interface PlatformDashboardStats { + platform: string + total_requests: number + total_tokens: number + total_actual_cost: number + today_requests: number + today_tokens: number + today_actual_cost: number +} + export interface UserDashboardStats { total_api_keys: number active_api_keys: number @@ -37,6 +47,7 @@ export interface UserDashboardStats { average_duration_ms: number rpm: number // 近5分钟平均每分钟请求数 tpm: number // 近5分钟平均每分钟Token数 + by_platform?: PlatformDashboardStats[] } export interface TrendParams { diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue index 2c04e6734b1..64f1366b3c0 100644 --- a/frontend/src/components/account/AccountUsageCell.vue +++ b/frontend/src/components/account/AccountUsageCell.vue @@ -126,6 +126,30 @@ :show-now-when-idle="true" color="emerald" /> +
+ +
@@ -1070,7 +1094,7 @@ const attachVisibilityObserver = () => { const loadActiveUsage = async () => { activeQueryLoading.value = true try { - usageInfo.value = await adminAPI.accounts.getUsage(props.account.id, 'active') + usageInfo.value = await adminAPI.accounts.getUsage(props.account.id, 'active', true) } catch (e: any) { console.error('Failed to load active usage:', e) } finally { diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 9ef6c9d2451..115aab22e6b 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -2649,6 +2649,28 @@
+ +
+
+
+ +

+ {{ t('admin.accounts.openai.responsesModeDesc') }} +

+
+
+ +
+
+
+ {{ t(openAIResponsesStatusKey) }} +
+
+
('whitelist') const antigravityWhitelistModels = ref([]) const antigravityModelMappings = ref([]) +const isSyncingAntigravityUpstream = ref(false) const tempUnschedEnabled = ref(false) const tempUnschedRules = ref([]) const getModelMappingKey = createStableObjectKeyResolver('edit-model-mapping') @@ -2332,6 +2370,7 @@ const customBaseUrl = ref('') // OpenAI 自动透传开关(OAuth/API Key) const openaiPassthroughEnabled = ref(false) const openAICompactMode = ref('auto') +const openAIResponsesMode = ref('auto') const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF) const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF) const codexCLIOnlyEnabled = ref(false) @@ -2433,9 +2472,36 @@ const openAICompactModeOptions = computed(() => [ { value: 'force_on', label: t('admin.accounts.openai.compactModeForceOn') }, { value: 'force_off', label: t('admin.accounts.openai.compactModeForceOff') } ]) +const openAIResponsesModeOptions = computed(() => [ + { value: 'auto', label: t('admin.accounts.openai.responsesModeAuto') }, + { value: 'force_responses', label: t('admin.accounts.openai.responsesModeForceResponses') }, + { value: 'force_chat_completions', label: t('admin.accounts.openai.responsesModeForceChatCompletions') } +]) +const normalizeOpenAIResponsesMode = (mode: unknown): OpenAIResponsesMode => { + if (mode === 'force_responses' || mode === 'force_chat_completions') { + return mode + } + return 'auto' +} const isOpenAIModelRestrictionDisabled = computed(() => props.account?.platform === 'openai' && openaiPassthroughEnabled.value ) +const openAIResponsesStatusKey = computed(() => { + if (openAIResponsesMode.value === 'force_responses') { + return 'admin.accounts.openai.responsesStatusForcedResponses' + } + if (openAIResponsesMode.value === 'force_chat_completions') { + return 'admin.accounts.openai.responsesStatusForcedChatCompletions' + } + const extra = props.account?.extra as Record | undefined + if (extra?.openai_responses_supported === true) { + return 'admin.accounts.openai.responsesStatusAutoSupported' + } + if (extra?.openai_responses_supported === false) { + return 'admin.accounts.openai.responsesStatusAutoUnsupported' + } + return 'admin.accounts.openai.responsesStatusAutoUnknown' +}) const openAICompactStatusKey = computed(() => { const extra = props.account?.extra as Record | undefined if (!props.account || props.account.platform !== 'openai') return '' @@ -2542,6 +2608,19 @@ const normalizePoolModeRetryCount = (value: number) => { return normalized } +const loadModelRestrictionFromMapping = (rawMapping?: Record) => { + const parsed = splitModelMappingObject(rawMapping) + allowedModels.value = parsed.allowedModels + modelMappings.value = parsed.modelMappings + modelRestrictionMode.value = + parsed.modelMappings.length > 0 && parsed.allowedModels.length === 0 + ? 'mapping' + : 'whitelist' +} + +const buildModelRestrictionMapping = () => + buildModelMappingObject('combined', allowedModels.value, modelMappings.value) + const syncFormFromAccount = (newAccount: Account | null) => { if (!newAccount) { return @@ -2582,6 +2661,7 @@ const syncFormFromAccount = (newAccount: Account | null) => { // Load OpenAI passthrough toggle (OpenAI OAuth/API Key) openaiPassthroughEnabled.value = false openAICompactMode.value = 'auto' + openAIResponsesMode.value = 'auto' openAICompactModelMappings.value = [] openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF @@ -2592,6 +2672,9 @@ const syncFormFromAccount = (newAccount: Account | null) => { if (newAccount.platform === 'openai' && (newAccount.type === 'oauth' || newAccount.type === 'apikey')) { openaiPassthroughEnabled.value = extra?.openai_passthrough === true || extra?.openai_oauth_passthrough === true openAICompactMode.value = (extra?.openai_compact_mode as OpenAICompactMode) || 'auto' + if (newAccount.type === 'apikey') { + openAIResponsesMode.value = normalizeOpenAIResponsesMode(extra?.openai_responses_mode) + } const codexImageGenerationBridgeValue = typeof extra?.codex_image_generation_bridge === 'boolean' ? extra.codex_image_generation_bridge : extra?.codex_image_generation_bridge_enabled @@ -2713,30 +2796,7 @@ const syncFormFromAccount = (newAccount: Account | null) => { editBaseUrl.value = (credentials.base_url as string) || platformDefaultUrl // Load model mappings and detect mode - const existingMappings = credentials.model_mapping as Record | undefined - if (existingMappings && typeof existingMappings === 'object') { - const entries = Object.entries(existingMappings) - - // Detect if this is whitelist mode (all from === to) or mapping mode - const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) - - if (isWhitelistMode) { - // Whitelist mode: populate allowedModels - modelRestrictionMode.value = 'whitelist' - allowedModels.value = entries.map(([from]) => from) - modelMappings.value = [] - } else { - // Mapping mode: populate modelMappings - modelRestrictionMode.value = 'mapping' - modelMappings.value = entries.map(([from, to]) => ({ from, to })) - allowedModels.value = [] - } - } else { - // No mappings: default to whitelist mode with empty selection (allow all) - modelRestrictionMode.value = 'whitelist' - modelMappings.value = [] - allowedModels.value = [] - } + loadModelRestrictionFromMapping(credentials.model_mapping as Record | undefined) // Load pool mode poolModeEnabled.value = credentials.pool_mode === true @@ -2780,24 +2840,7 @@ const syncFormFromAccount = (newAccount: Account | null) => { loadQuotaNotifyFromExtra(bedrockExtra) // Load model mappings for bedrock - const existingMappings = bedrockCreds.model_mapping as Record | undefined - if (existingMappings && typeof existingMappings === 'object') { - const entries = Object.entries(existingMappings) - const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) - if (isWhitelistMode) { - modelRestrictionMode.value = 'whitelist' - allowedModels.value = entries.map(([from]) => from) - modelMappings.value = [] - } else { - modelRestrictionMode.value = 'mapping' - modelMappings.value = entries.map(([from, to]) => ({ from, to })) - allowedModels.value = [] - } - } else { - modelRestrictionMode.value = 'whitelist' - modelMappings.value = [] - allowedModels.value = [] - } + loadModelRestrictionFromMapping(bedrockCreds.model_mapping as Record | undefined) } else if (newAccount.type === 'upstream' && newAccount.credentials) { const credentials = newAccount.credentials as Record editBaseUrl.value = (credentials.base_url as string) || '' @@ -2808,24 +2851,7 @@ const syncFormFromAccount = (newAccount: Account | null) => { editVertexLocation.value = (credentials.location as string) || (credentials.vertex_location as string) || 'us-central1' // Load model mappings for service_account - const existingMappings = credentials.model_mapping as Record | undefined - if (existingMappings && typeof existingMappings === 'object') { - const entries = Object.entries(existingMappings) - const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) - if (isWhitelistMode) { - modelRestrictionMode.value = 'whitelist' - allowedModels.value = entries.map(([from]) => from) - modelMappings.value = [] - } else { - modelRestrictionMode.value = 'mapping' - modelMappings.value = entries.map(([from, to]) => ({ from, to })) - allowedModels.value = [] - } - } else { - modelRestrictionMode.value = 'whitelist' - modelMappings.value = [] - allowedModels.value = [] - } + loadModelRestrictionFromMapping(credentials.model_mapping as Record | undefined) } else { const platformDefaultUrl = newAccount.platform === 'openai' @@ -2838,24 +2864,7 @@ const syncFormFromAccount = (newAccount: Account | null) => { // Load model mappings for OpenAI OAuth accounts if (newAccount.platform === 'openai' && newAccount.credentials) { const oauthCredentials = newAccount.credentials as Record - const existingMappings = oauthCredentials.model_mapping as Record | undefined - if (existingMappings && typeof existingMappings === 'object') { - const entries = Object.entries(existingMappings) - const isWhitelistMode = entries.length > 0 && entries.every(([from, to]) => from === to) - if (isWhitelistMode) { - modelRestrictionMode.value = 'whitelist' - allowedModels.value = entries.map(([from]) => from) - modelMappings.value = [] - } else { - modelRestrictionMode.value = 'mapping' - modelMappings.value = entries.map(([from, to]) => ({ from, to })) - allowedModels.value = [] - } - } else { - modelRestrictionMode.value = 'whitelist' - modelMappings.value = [] - allowedModels.value = [] - } + loadModelRestrictionFromMapping(oauthCredentials.model_mapping as Record | undefined) } else { modelRestrictionMode.value = 'whitelist' modelMappings.value = [] @@ -2935,6 +2944,40 @@ const addAntigravityPresetMapping = (from: string, to: string) => { antigravityModelMappings.value.push({ from, to }) } +const syncAntigravityUpstreamModels = async () => { + if (!props.account?.id || isSyncingAntigravityUpstream.value) return + + isSyncingAntigravityUpstream.value = true + try { + const result = await adminAPI.accounts.syncUpstreamModels(props.account.id) + const upstreamModels = result.models.map((model) => model.trim()).filter(Boolean) + if (upstreamModels.length === 0) { + appStore.showInfo(t('admin.accounts.syncUpstreamModelsEmpty')) + return + } + + let addedCount = 0 + for (const model of upstreamModels) { + const exists = antigravityModelMappings.value.some((mapping) => mapping.from === model) + if (!exists) { + antigravityModelMappings.value.push({ from: model, to: model }) + addedCount += 1 + } + } + + if (addedCount > 0) { + appStore.showSuccess(t('admin.accounts.syncUpstreamModelsSuccess', { count: addedCount, total: upstreamModels.length })) + } else { + appStore.showInfo(t('admin.accounts.syncUpstreamModelsNoChanges', { count: upstreamModels.length })) + } + } catch (error) { + const message = error instanceof Error ? error.message : t('admin.accounts.syncUpstreamModelsFailed') + appStore.showError(t('admin.accounts.syncUpstreamModelsError', { message })) + } finally { + isSyncingAntigravityUpstream.value = false + } +} + // Error code toggle helper const toggleErrorCode = (code: number) => { const index = selectedErrorCodes.value.indexOf(code) @@ -3343,20 +3386,22 @@ const handleSubmit = async () => { } // Handle API key + // 后端响应已脱敏:currentCredentials 不会再包含 api_key 原文。 + // 用户填入新值则覆盖;留空时优先看 credentials_status.has_api_key; + // 若后端尚未升级(无 credentials_status),回退读旧结构 currentCredentials.api_key。 + // 两者都无才报错。 + const hasExistingApiKey = + props.account.credentials_status?.has_api_key ?? Boolean(currentCredentials.api_key) if (editApiKey.value.trim()) { - // User provided a new API key newCredentials.api_key = editApiKey.value.trim() - } else if (currentCredentials.api_key) { - // Preserve existing api_key - newCredentials.api_key = currentCredentials.api_key - } else { + } else if (!hasExistingApiKey) { appStore.showError(t('admin.accounts.apiKeyIsRequired')) return } // Add model mapping if configured(OpenAI 开启自动透传时保留现有映射,不再编辑) if (shouldApplyModelMapping) { - const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + const modelMapping = buildModelRestrictionMapping() if (modelMapping) { newCredentials.model_mapping = modelMapping } else { @@ -3434,7 +3479,15 @@ const handleSubmit = async () => { return } - if (!currentCredentials.service_account_json && !currentCredentials.service_account) { + // SA JSON 已脱敏不再随 credentials 返回,存在性优先读 credentials_status。 + // 若后端尚未升级(无 credentials_status),回退读旧结构 service_account_json / service_account。 + const credentialsStatus = props.account.credentials_status + const hasExistingServiceAccountJson = credentialsStatus + ? Boolean( + credentialsStatus.has_service_account_json || credentialsStatus.has_service_account + ) + : Boolean(currentCredentials.service_account_json || currentCredentials.service_account) + if (!hasExistingServiceAccountJson) { appStore.showError(t('admin.accounts.vertexSaJsonRequired')) return } @@ -3444,7 +3497,7 @@ const handleSubmit = async () => { newCredentials.tier_id = 'vertex' // Add model mapping if configured - const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + const modelMapping = buildModelRestrictionMapping() if (modelMapping) { newCredentials.model_mapping = modelMapping } else { @@ -3494,7 +3547,7 @@ const handleSubmit = async () => { } // Model mapping - const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + const modelMapping = buildModelRestrictionMapping() if (modelMapping) { newCredentials.model_mapping = modelMapping } else { @@ -3528,7 +3581,7 @@ const handleSubmit = async () => { const shouldApplyModelMapping = !openaiPassthroughEnabled.value if (shouldApplyModelMapping) { - const modelMapping = buildModelMappingObject(modelRestrictionMode.value, allowedModels.value, modelMappings.value) + const modelMapping = buildModelRestrictionMapping() if (modelMapping) { newCredentials.model_mapping = modelMapping } else { @@ -3721,6 +3774,13 @@ const handleSubmit = async () => { } else { newExtra.openai_compact_mode = openAICompactMode.value } + if (props.account.type === 'apikey') { + if (openAIResponsesMode.value === 'auto') { + delete newExtra.openai_responses_mode + } else { + newExtra.openai_responses_mode = openAIResponsesMode.value + } + } delete newExtra.codex_image_generation_bridge_enabled if (codexImageGenerationBridgeMode.value === 'inherit') { diff --git a/frontend/src/components/account/ModelWhitelistSelector.vue b/frontend/src/components/account/ModelWhitelistSelector.vue index ebce3740103..9a0d6af80f3 100644 --- a/frontend/src/components/account/ModelWhitelistSelector.vue +++ b/frontend/src/components/account/ModelWhitelistSelector.vue @@ -85,6 +85,15 @@ > {{ t('admin.accounts.fillRelatedModels') }} +