Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add semantic search to spotlight #8932

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ jobs:
id-token: "write"
services:
postgres:
image: postgres:12.15-alpine
# Image is pinned to v15, OK since it's just for testing
image: ankane/pgvector
# This env variables must be the same in the file PARABOL_BUILD_ENV_PATH
env:
POSTGRES_PASSWORD: "temppassword"
Expand Down
3 changes: 3 additions & 0 deletions docker/dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ services:
networks:
- parabol-network
postgres:
depends_on:
- db
- redis
build:
context: "../packages/server/postgres"
restart: unless-stopped
Expand Down
12 changes: 10 additions & 2 deletions packages/server/graphql/mutations/createReflection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import CreateReflectionInput, {CreateReflectionInputType} from '../types/CreateR
import CreateReflectionPayload from '../types/CreateReflectionPayload'
import getReflectionEntities from './helpers/getReflectionEntities'
import getReflectionSentimentScore from './helpers/getReflectionSentimentScore'
import getOpenAIEmbeddings from './helpers/getOpenAIEmbeddings'
import {analytics} from '../../utils/analytics/analytics'

export default {
Expand Down Expand Up @@ -63,9 +64,10 @@ export default {

// RESOLUTION
const plaintextContent = extractTextFromDraftString(normalizedContent)
const [entities, sentimentScore] = await Promise.all([
const [entities, sentimentScore, embeddings] = await Promise.all([
getReflectionEntities(plaintextContent),
tier !== 'starter' ? getReflectionSentimentScore(question, plaintextContent) : undefined
tier !== 'starter' ? getReflectionSentimentScore(question, plaintextContent) : undefined,
getOpenAIEmbeddings(plaintextContent)
])
const reflectionGroupId = generateUID()

Expand Down Expand Up @@ -93,6 +95,12 @@ export default {

await Promise.all([
pg.insertInto('RetroReflectionGroup').values(reflectionGroup).execute(),
embeddings
? pg
.insertInto('ReflectionEmbeddings')
.values({id: reflection.id, vector: embeddings})
.execute()
: null,
r.table('RetroReflectionGroup').insert(reflectionGroup).run(),
r.table('RetroReflection').insert(reflection).run()
])
Expand Down
19 changes: 19 additions & 0 deletions packages/server/graphql/mutations/helpers/getOpenAIEmbeddings.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import {OpenAIEmbeddings} from 'langchain/embeddings/openai'
import {RecursiveCharacterTextSplitter} from 'langchain/text_splitter'

export const getOpenAIEmbeddings = async (plaintextContent: string) => {
if (!plaintextContent) return null
const embeddings = new OpenAIEmbeddings(
{openAIApiKey: 'X'},
{baseURL: 'http://localhost:3002/v1'}
)
const splitter = new RecursiveCharacterTextSplitter({chunkSize: 1000, chunkOverlap: 200})
const splitText = await splitter.splitText(plaintextContent)
const start = performance.now()
const contentEmbedding = await embeddings.embedDocuments(splitText)
const end = performance.now()
console.log('duration', end - start)
return `[${contentEmbedding.join(',')}]`
}

export default getOpenAIEmbeddings
27 changes: 23 additions & 4 deletions packages/server/graphql/types/User.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
GraphQLObjectType,
GraphQLString
} from 'graphql'
import {sql} from 'kysely'
import MeetingMemberId from 'parabol-client/shared/gqlIds/MeetingMemberId'
import toTeamMemberId from 'parabol-client/utils/relay/toTeamMemberId'
import {
Expand All @@ -16,18 +17,22 @@ import {
} from '../../../client/utils/constants'
import groupReflections from '../../../client/utils/smartGroup/groupReflections'
import getRethink from '../../database/rethinkDriver'
import {RDatum} from '../../database/stricterR'
import MeetingMemberType from '../../database/types/MeetingMember'
import OrganizationType from '../../database/types/Organization'
import OrganizationUserType from '../../database/types/OrganizationUser'
import Reflection from '../../database/types/Reflection'
import SuggestedActionType from '../../database/types/SuggestedAction'
import TimelineEvent from '../../database/types/TimelineEvent'
import getKysely from '../../postgres/getKysely'
import {getUserId, isSuperUser, isTeamMember} from '../../utils/authorization'
import getMonthlyStreak from '../../utils/getMonthlyStreak'
import getRedis from '../../utils/getRedis'
import standardError from '../../utils/standardError'
import errorFilter from '../errorFilter'
import {DataLoaderWorker, GQLContext} from '../graphql'
import isValid from '../isValid'
import getOpenAIEmbeddings from '../mutations/helpers/getOpenAIEmbeddings'
import invoices from '../queries/invoices'
import organization from '../queries/organization'
import AuthIdentity from './AuthIdentity'
Expand All @@ -47,8 +52,6 @@ import TeamMember from './TeamMember'
import TierEnum from './TierEnum'
import {TimelineEventConnection} from './TimelineEvent'
import TimelineEventTypeEnum from './TimelineEventTypeEnum'
import TimelineEvent from '../../database/types/TimelineEvent'
import {RDatum} from '../../database/stricterR'

const User: GraphQLObjectType<any, GQLContext> = new GraphQLObjectType<any, GQLContext>({
name: 'User',
Expand Down Expand Up @@ -498,9 +501,25 @@ const User: GraphQLObjectType<any, GQLContext> = new GraphQLObjectType<any, GQLC
}

if (searchQuery !== '') {
const matchedReflections = reflections.filter(({plaintextContent}) =>
plaintextContent.toLowerCase().includes(searchQuery)
const pg = getKysely()
const searchEmbeddings = await getOpenAIEmbeddings(searchQuery)
const vectors = await pg
.selectFrom('ReflectionEmbeddings')
.select([
'id',
sql<number>`1 - (vector <=> ${searchEmbeddings})`.as('cosDistance'),
sql<number>`vector <-> ${searchEmbeddings}`.as('l2Distance'),
sql<number>`vector <#> ${searchEmbeddings}`.as('innerProductDifference')
])
.orderBy('l2Distance asc')
.limit(5)
.execute()
const l2Alpha = 0.6
const matchingReflectionIds = new Set(
vectors.filter((vec) => vec.l2Distance < l2Alpha).map(({id}) => id)
)

const matchedReflections = reflections.filter(({id}) => matchingReflectionIds.has(id))
const relatedReflections = matchedReflections.filter(
({reflectionGroupId: groupId}: Reflection) => groupId !== reflectionGroupId
)
Expand Down
1 change: 1 addition & 0 deletions packages/server/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@
"ioredis": "^5.2.3",
"jsdom": "^20.0.0",
"jsonwebtoken": "^9.0.0",
"langchain": "^0.0.152",
"mailcomposer": "^4.0.1",
"mailgun.js": "^7.0.4",
"mime-types": "^2.1.16",
Expand Down
4 changes: 3 additions & 1 deletion packages/server/postgres/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ FROM postgres:12.15

ADD extensions /extensions

RUN apt-get update && apt-get install -y build-essential
RUN apt-get update && apt-get install -y --no-install-recommends build-essential postgresql-server-dev-12 git-all

RUN cd /extensions/postgres-json-schema && make install && make installcheck

RUN git clone --branch v0.5.0 https://github.com/pgvector/pgvector.git extensions/pgvector && cd extensions/pgvector && make clean && make && make install

COPY extensions/install.sql /docker-entrypoint-initdb.d/
1 change: 1 addition & 0 deletions packages/server/postgres/extensions/install.sql
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
CREATE EXTENSION IF NOT EXISTS "postgres-json-schema";
CREATE EXTENSION IF NOT EXISTS "vector";
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import {Kysely, PostgresDialect, sql} from 'kysely'
import {Client} from 'pg'
import getPg from '../getPg'
import getPgConfig from '../getPgConfig'

export async function up() {
const pg = new Kysely<any>({
dialect: new PostgresDialect({
pool: getPg()
})
})

pg.schema
// I had to normalize domains to its own table to guarantee uniqueness (and make indexing easier)
await sql`
CREATE EXTENSION IF NOT EXISTS "vector";
CREATE TABLE IF NOT EXISTS "ReflectionEmbeddings" (
"id" VARCHAR(100) PRIMARY KEY,
"vector" VECTOR NOT NULL,
"meetingId" VARCHAR(100) NOT NULL,
"teamId" VARCHAR(100) NOT NULL,
"orgId" VARCHAR(100) NOT NULL
);`.execute(pg)
}

export async function down() {
const client = new Client(getPgConfig())
await client.connect()
await client.query(`DROP TABLE IF EXISTS "ReflectionEmbeddings";`)
await client.end()
}
Loading
Loading