Skip to content
This repository has been archived by the owner on Feb 29, 2024. It is now read-only.

Add a model to the chat panel #1

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions jupyterlab_chat/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@
import json
import time
import uuid
from asyncio import AbstractEventLoop
from dataclasses import asdict
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import Dict, List

from jupyter_server.base.handlers import APIHandler as BaseAPIHandler, JupyterHandler
from jupyter_server.utils import url_path_join
from langchain.pydantic_v1 import ValidationError
from tornado import web, websocket

Expand Down Expand Up @@ -153,11 +150,13 @@ async def on_message(self, message):
return

# message broadcast to chat clients
chat_message_id = str(uuid.uuid4())
if not chat_request.id:
chat_request.id = str(uuid.uuid4())

chat_message = ChatMessage(
id=chat_message_id,
id=chat_request.id,
time=time.time(),
body=chat_request.prompt,
body=chat_request.body,
sender=self.chat_client,
)

Expand Down
3 changes: 2 additions & 1 deletion jupyterlab_chat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

# the type of message used to chat with the agent
class ChatRequest(BaseModel):
prompt: str
body: str
id: str


class ChatUser(BaseModel):
Expand Down
3 changes: 3 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@
"@jupyterlab/rendermime": "^4.0.5",
"@jupyterlab/services": "^7.0.5",
"@jupyterlab/ui-components": "^4.0.5",
"@lumino/coreutils": "2.1.2",
"@lumino/disposable": "2.1.2",
"@lumino/signaling": "2.1.2",
"@mui/icons-material": "5.11.0",
"@mui/material": "^5.11.0",
"react": "^18.2.0",
Expand Down
5 changes: 2 additions & 3 deletions src/components/chat-messages.tsx
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import React, { useState, useEffect } from 'react';

import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
import { Avatar, Box, Typography } from '@mui/material';
import type { SxProps, Theme } from '@mui/material';
import React, { useState, useEffect } from 'react';

import { IRenderMimeRegistry } from '@jupyterlab/rendermime';
import { RendermimeMarkdown } from './rendermime-markdown';
import { ChatService } from '../services';

Expand Down
10 changes: 5 additions & 5 deletions src/components/chat-settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ import {
} from '@mui/material';
import React, { useEffect, useState } from 'react';

import { ChatService } from '../services';
import { useStackingAlert } from './mui-extras/stacking-alert';
import { ServerInfoState, useServerInfo } from './settings/use-server-info';
import { minifyUpdate } from './settings/minify';
import { useStackingAlert } from './mui-extras/stacking-alert';
import { ChatService } from '../services';

// /**
// * Component that returns the settings view in the chat panel.
// */
/**
* Component that returns the settings view in the chat panel.
*/
export function ChatSettings(): JSX.Element {
// state fetched on initial render
const server = useServerInfo();
Expand Down
28 changes: 13 additions & 15 deletions src/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@ import { JlThemeProvider } from './jl-theme-provider';
import { ChatMessages } from './chat-messages';
import { ChatInput } from './chat-input';
import { ChatSettings } from './chat-settings';
import { ChatHandler } from '../chat-handler';
import { ScrollContainer } from './scroll-container';
import { IChatModel } from '../model';
import { ChatService } from '../services';

type ChatBodyProps = {
chatHandler: ChatHandler;
chatModel: IChatModel;
rmRegistry: IRenderMimeRegistry;
};

function ChatBody({
chatHandler,
chatModel,
rmRegistry: renderMimeRegistry
}: ChatBodyProps): JSX.Element {
const [messages, setMessages] = useState<ChatService.IChatMessage[]>([]);
Expand All @@ -34,7 +34,8 @@ function ChatBody({
async function fetchHistory() {
try {
const [history, config] = await Promise.all([
chatHandler.getHistory(),
chatModel.getHistory?.() ??
new Promise<ChatService.ChatHistory>(r => r({ messages: [] })),
ChatService.getConfig()
]);
setSendWithShiftEnter(config.send_with_shift_enter ?? false);
Expand All @@ -45,13 +46,13 @@ function ChatBody({
}

fetchHistory();
}, [chatHandler]);
}, [chatModel]);

/**
* Effect: listen to chat messages
*/
useEffect(() => {
function handleChatEvents(message: ChatService.IMessage) {
function handleChatEvents(_: IChatModel, message: ChatService.IMessage) {
if (message.type === 'connection') {
return;
} else if (message.type === 'clear') {
Expand All @@ -62,19 +63,19 @@ function ChatBody({
setMessages(messageGroups => [...messageGroups, message]);
}

chatHandler.addListener(handleChatEvents);
chatModel.incomingMessage.connect(handleChatEvents);
return function cleanup() {
chatHandler.removeListener(handleChatEvents);
chatModel.incomingMessage.disconnect(handleChatEvents);
};
}, [chatHandler]);
}, [chatModel]);

// no need to append to messageGroups imperatively here. all of that is
// handled by the listeners registered in the effect hooks above.
const onSend = async () => {
setInput('');

// send message to backend
chatHandler.sendMessage({ prompt: input });
chatModel.sendMessage({ body: input });
};

return (
Expand All @@ -100,7 +101,7 @@ function ChatBody({
}

export type ChatProps = {
chatHandler: ChatHandler;
chatModel: IChatModel;
themeManager: IThemeManager | null;
rmRegistry: IRenderMimeRegistry;
chatView?: ChatView;
Expand Down Expand Up @@ -147,10 +148,7 @@ export function Chat(props: ChatProps): JSX.Element {
</Box>
{/* body */}
{view === ChatView.Chat && (
<ChatBody
chatHandler={props.chatHandler}
rmRegistry={props.rmRegistry}
/>
<ChatBody chatModel={props.chatModel} rmRegistry={props.rmRegistry} />
)}
{view === ChatView.Settings && <ChatSettings />}
</Box>
Expand Down
1 change: 0 additions & 1 deletion src/handler.ts → src/handlers/handler.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import { URLExt } from '@jupyterlab/coreutils';

import { ServerConnection } from '@jupyterlab/services';

const API_NAMESPACE = 'api/chat';
Expand Down
106 changes: 36 additions & 70 deletions src/chat-handler.ts → src/handlers/websocket-handler.ts
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
import { IDisposable } from '@lumino/disposable';
import { ServerConnection } from '@jupyterlab/services';
import { URLExt } from '@jupyterlab/coreutils';
import { ServerConnection } from '@jupyterlab/services';
import { UUID } from '@lumino/coreutils';

import { requestAPI } from './handler';
import { ChatService } from './services';
import { ChatModel, IChatModel } from '../model';
import { ChatService } from '../services';

const CHAT_SERVICE_URL = 'api/chat';

export class ChatHandler implements IDisposable {
/**
* An implementation of the chat model based on websocket handler.
*/
export class WebSocketHandler extends ChatModel {
/**
* The server settings used to make API requests.
*/
readonly serverSettings: ServerConnection.ISettings;

/**
* ID of the connection. Requires `await initialize()`.
*/
id = '';

/**
* Create a new chat handler.
*/
constructor(options: ChatHandler.IOptions = {}) {
constructor(options: WebSocketHandler.IOptions = {}) {
super(options);
this.serverSettings =
options.serverSettings ?? ServerConnection.makeSettings();
}
Expand All @@ -30,35 +31,23 @@ export class ChatHandler implements IDisposable {
* resolved when server acknowledges connection and sends the client ID. This
* must be awaited before calling any other method.
*/
public async initialize(): Promise<void> {
async initialize(): Promise<void> {
await this._initialize();
}

/**
* Sends a message across the WebSocket. Promise resolves to the message ID
* when the server sends the same message back, acknowledging receipt.
*/
public sendMessage(message: ChatService.ChatRequest): Promise<string> {
sendMessage(message: ChatService.ChatRequest): Promise<boolean> {
message.id = UUID.uuid4();
return new Promise(resolve => {
this._socket?.send(JSON.stringify(message));
this._sendResolverQueue.push(resolve);
this._sendResolverQueue.set(message.id!, resolve);
});
}

public addListener(handler: (message: ChatService.IMessage) => void): void {
this._listeners.push(handler);
}

public removeListener(
handler: (message: ChatService.IMessage) => void
): void {
const index = this._listeners.indexOf(handler);
if (index > -1) {
this._listeners.splice(index, 1);
}
}

public async getHistory(): Promise<ChatService.ChatHistory> {
async getHistory(): Promise<ChatService.ChatHistory> {
let data: ChatService.ChatHistory = { messages: [] };
try {
data = await requestAPI('history', {
Expand All @@ -70,22 +59,11 @@ export class ChatHandler implements IDisposable {
return data;
}

/**
* Whether the chat handler is disposed.
*/
get isDisposed(): boolean {
return this._isDisposed;
}

/**
* Dispose the chat handler.
*/
dispose(): void {
if (this.isDisposed) {
return;
}
this._isDisposed = true;
this._listeners = [];
super.dispose();

// Clean up socket.
const socket = this._socket;
Expand All @@ -99,35 +77,15 @@ export class ChatHandler implements IDisposable {
}
}

/**
* A function called before transferring the message to the panel(s).
* Can be useful if some actions are required on the message.
*/
protected formatChatMessage(
message: ChatService.IChatMessage
): ChatService.IChatMessage {
return message;
}

private _onMessage(message: ChatService.IMessage): void {
onMessage(message: ChatService.IMessage): void {
// resolve promise from `sendMessage()`
if (message.type === 'msg' && message.sender.id === this.id) {
this._sendResolverQueue.shift()?.(message.id);
}

if (message.type === 'msg') {
message = this.formatChatMessage(message as ChatService.IChatMessage);
this._sendResolverQueue.get(message.id)?.(true);
}

// call listeners in serial
this._listeners.forEach(listener => listener(message));
super.onMessage(message);
}

/**
* Queue of Promise resolvers pushed onto by `send()`
*/
private _sendResolverQueue: ((value: string) => void)[] = [];

private _onClose(e: CloseEvent, reject: any) {
reject(new Error('Chat UI websocket disconnected'));
console.error('Chat UI websocket disconnected');
Expand Down Expand Up @@ -155,31 +113,39 @@ export class ChatHandler implements IDisposable {
socket.onclose = e => this._onClose(e, reject);
socket.onerror = e => reject(e);
socket.onmessage = msg =>
msg.data && this._onMessage(JSON.parse(msg.data));
msg.data && this.onMessage(JSON.parse(msg.data));

const listenForConnection = (message: ChatService.IMessage) => {
const listenForConnection = (
_: IChatModel,
message: ChatService.IMessage
) => {
if (message.type !== 'connection') {
return;
}
this.id = message.client_id;
resolve();
this.removeListener(listenForConnection);
this.incomingMessage.disconnect(listenForConnection);
};

this.addListener(listenForConnection);
this.incomingMessage.connect(listenForConnection);
});
}

private _isDisposed = false;
private _socket: WebSocket | null = null;
private _listeners: ((msg: any) => void)[] = [];
/**
* Queue of Promise resolvers pushed onto by `send()`
*/
private _sendResolverQueue = new Map<string, (value: boolean) => void>();
}

export namespace ChatHandler {
/**
* The websocket namespace.
*/
export namespace WebSocketHandler {
/**
* The instantiation options for a data registry handler.
*/
export interface IOptions {
export interface IOptions extends ChatModel.IOptions {
serverSettings?: ServerConnection.ISettings;
}
}
5 changes: 3 additions & 2 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export * from './chat-handler';
export * from './handlers/websocket-handler';
export * from './model';
export * from './services';
export * from './widgets/chat-error';
export * from './widgets/chat-sidebar';
export * from './services';
Loading
Loading