diff --git a/cdk/resources/Memfault.ts b/cdk/resources/Memfault.ts index 99482eb..8c8a062 100644 --- a/cdk/resources/Memfault.ts +++ b/cdk/resources/Memfault.ts @@ -12,6 +12,7 @@ import type { IPrincipal } from 'aws-cdk-lib/aws-iam/index.js' import { Construct } from 'constructs' import type { PackedLambda } from '../backend.js' import { LambdaLogGroup } from './LambdaLogGroup.js' +import type { WebsocketAPI } from './WebsocketAPI.js' /** * Pull Memfault data for devices @@ -24,12 +25,14 @@ export class Memfault extends Construct { lambdaSources, baseLayer, assetTrackerStackName, + websocketAPI, }: { lambdaSources: { memfault: PackedLambda } baseLayer: Lambda.ILayerVersion assetTrackerStackName: string + websocketAPI: WebsocketAPI }, ) { super(parent, 'Memfault') @@ -61,7 +64,7 @@ export class Memfault extends Construct { timeout: Duration.seconds(60), memorySize: 1792, code: Lambda.Code.fromAsset(lambdaSources.memfault.lambdaZipFile), - description: 'Pull Memfault data for devices and put them in the shadow', + description: 'Pull Memfault data for devices and publish it on S3', layers: [baseLayer], environment: { VERSION: this.node.tryGetContext('version'), @@ -69,6 +72,8 @@ export class Memfault extends Construct { ASSET_TRACKER_STACK_NAME: assetTrackerStackName, NODE_NO_WARNINGS: '1', BUCKET: this.bucket.bucketName, + WEBSOCKET_CONNECTIONS_TABLE_NAME: + websocketAPI.connectionsTable.tableName, }, initialPolicy: [ new IAM.PolicyStatement({ @@ -86,6 +91,7 @@ export class Memfault extends Construct { }) this.bucket.grantWrite(fn) + websocketAPI.connectionsTable.grantReadData(fn) const rule = new Events.Rule(this, 'Rule', { schedule: Events.Schedule.expression('rate(5 minutes)'), diff --git a/cdk/resources/WebsocketAPI.ts b/cdk/resources/WebsocketAPI.ts index 88ca09c..daf02a6 100644 --- a/cdk/resources/WebsocketAPI.ts +++ b/cdk/resources/WebsocketAPI.ts @@ -15,7 +15,7 @@ import { LambdaLogGroup } from './LambdaLogGroup.js' export class WebsocketAPI extends Construct { public readonly websocketURI: string - public readonly connectionsTable: DynamoDB.ITable + public readonly connectionsTable: DynamoDB.Table public readonly websocketAPIArn: string public readonly websocketManagementAPIURL: string public constructor( @@ -43,7 +43,7 @@ export class WebsocketAPI extends Construct { }, timeToLiveAttribute: 'ttl', removalPolicy: RemovalPolicy.DESTROY, - }) as DynamoDB.ITable + }) // API const api = new ApiGateway.CfnApi(this, 'api', { diff --git a/cdk/stacks/BackendStack.ts b/cdk/stacks/BackendStack.ts index d99f1c1..bc95078 100644 --- a/cdk/stacks/BackendStack.ts +++ b/cdk/stacks/BackendStack.ts @@ -117,6 +117,7 @@ export class BackendStack extends Stack { assetTrackerStackName, baseLayer, lambdaSources, + websocketAPI: api, }) // Outputs diff --git a/lambda/memfault.ts b/lambda/memfault.ts index f67e89b..d426ff5 100644 --- a/lambda/memfault.ts +++ b/lambda/memfault.ts @@ -2,15 +2,24 @@ import { IoTClient, ListThingsInThingGroupCommand } from '@aws-sdk/client-iot' import { GetParametersByPathCommand, SSMClient } from '@aws-sdk/client-ssm' import { fromEnv } from '@nordicsemiconductor/from-env' import { S3Client, PutObjectCommand } from '@aws-sdk/client-s3' +import { getActiveConnections } from './notifyClients.js' +import { DynamoDBClient } from '@aws-sdk/client-dynamodb' const ssm = new SSMClient({}) const iot = new IoTClient({}) const s3 = new S3Client({}) +const db = new DynamoDBClient({}) -const { stackName, nrfAssetTrackerStackName, bucket } = fromEnv({ +const { + stackName, + nrfAssetTrackerStackName, + bucket, + websocketConnectionsTableName, +} = fromEnv({ stackName: 'STACK_NAME', nrfAssetTrackerStackName: 'ASSET_TRACKER_STACK_NAME', bucket: 'BUCKET', + websocketConnectionsTableName: 'WEBSOCKET_CONNECTIONS_TABLE_NAME', })(process.env) const Prefix = `/${stackName}/memfault/` @@ -37,6 +46,8 @@ if ( ) throw new Error(`Memfault settings not configured!`) +const getActive = getActiveConnections(db, websocketConnectionsTableName) + type Reboot = { type: 'memfault' mcu_reason_register: null @@ -76,6 +87,10 @@ const api = { * Pull data from Memfault about all devices */ export const handler = async (): Promise => { + if ((await getActive()).length === 0) { + console.debug('No active connections.') + return + } const { things } = await iot.send( new ListThingsInThingGroupCommand({ thingGroupName: nrfAssetTrackerStackName, diff --git a/lambda/notifyClients.ts b/lambda/notifyClients.ts index 9e89292..2b785f5 100644 --- a/lambda/notifyClients.ts +++ b/lambda/notifyClients.ts @@ -77,6 +77,7 @@ export const notifyClients = ( dropMessage = false, ): ((event: Event) => Promise) => { const send = sendEvent(apiGwManagementClient) + const getActive = getActiveConnections(db, connectionsTableName) return async (event: Event): Promise => { console.log( JSON.stringify({ @@ -87,10 +88,7 @@ export const notifyClients = ( console.debug(`Dropped message`) return } - const connectionIds: string[] = await getActiveConnections( - db, - connectionsTableName, - ) + const connectionIds: string[] = await getActive() for (const connectionId of connectionIds) { try { @@ -151,18 +149,35 @@ const getEventContext = (event: Event): URL | null => { return null } -export const getActiveConnections = async ( +export const getActiveConnections = ( db: DynamoDBClient, connectionsTableName: string, -): Promise => { - const res = await db.send( - new ScanCommand({ - TableName: connectionsTableName, - }), - ) +): (() => Promise>) => { + let lastResult: { + connectionIds: string[] + ts: number + } + return async (): Promise => { + // Cache for 60 seconds + if (lastResult !== undefined && lastResult.ts > Date.now() - 60 * 1000) { + return lastResult.connectionIds + } + + const res = await db.send( + new ScanCommand({ + TableName: connectionsTableName, + }), + ) + + const connectionIds: string[] = res?.Items?.map( + ({ connectionId }) => connectionId?.S, + ).filter((connectionId) => connectionId !== undefined) as string[] - const connectionIds: string[] = res?.Items?.map( - ({ connectionId }) => connectionId?.S, - ).filter((connectionId) => connectionId !== undefined) as string[] - return connectionIds + lastResult = { + connectionIds, + ts: Date.now(), + } + + return connectionIds + } } diff --git a/lambda/onNewNetworkSurvey.ts b/lambda/onNewNetworkSurvey.ts index c003faf..560bc46 100644 --- a/lambda/onNewNetworkSurvey.ts +++ b/lambda/onNewNetworkSurvey.ts @@ -34,13 +34,12 @@ const notifier = withDeviceAlias(iot)( }), ) +const getActive = getActiveConnections(db, connectionsTableName) + export const handler = async (event: DynamoDBStreamEvent): Promise => { console.log(JSON.stringify({ event, networkGeolocationApiUrl })) - const connectionIds: string[] = await getActiveConnections( - db, - connectionsTableName, - ) + const connectionIds: string[] = await getActive() if (connectionIds.length === 0) { console.log(`No clients to notify.`) return diff --git a/lambda/publishSummaries.ts b/lambda/publishSummaries.ts index 0c01793..75340a7 100644 --- a/lambda/publishSummaries.ts +++ b/lambda/publishSummaries.ts @@ -35,11 +35,10 @@ const [historicaldataDatabaseName, historicaldataTableName] = const timestream = new TimestreamQueryClient({}) +const getActive = getActiveConnections(db, connectionsTableName) + export const handler = async (): Promise => { - const connectionIds: string[] = await getActiveConnections( - db, - connectionsTableName, - ) + const connectionIds: string[] = await getActive() if (connectionIds.length === 0) { console.log(`No clients to notify.`) return diff --git a/lambda/resolveCellLocation.ts b/lambda/resolveCellLocation.ts index dc0183f..7dbc3bb 100644 --- a/lambda/resolveCellLocation.ts +++ b/lambda/resolveCellLocation.ts @@ -25,6 +25,8 @@ const notifier = withDeviceAlias(iot)( }), ) +const getActive = getActiveConnections(db, connectionsTableName) + export const handler = async (event: { roam: { v: { @@ -42,10 +44,7 @@ export const handler = async (event: { }): Promise => { console.log(JSON.stringify({ event, geolocationApiUrl })) - const connectionIds: string[] = await getActiveConnections( - db, - connectionsTableName, - ) + const connectionIds: string[] = await getActive() if (connectionIds.length === 0) { console.log(`No clients to notify.`) return