Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General debug patch #46

Merged
merged 6 commits into from
Dec 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions cheese/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,22 @@ class CHEESE:

:param port: Port to run rabbitmq server on
:type port: int

:param debug: Print debug messages for rabbitmq
:type debug: bool
"""
def __init__(
self,
pipeline_cls = None, client_cls = None, model_cls = None,
pipeline_kwargs : Dict[str, Any] = {}, model_kwargs : Dict[str, Any] = {},
gradio : bool = True, draw_always : bool = False,
host : str = 'localhost', port : int = 5672
host : str = 'localhost', port : int = 5672,
debug : bool = False
):

self.gradio = gradio
self.draw_always = draw_always
self.debug = debug

# Initialize rabbit MQ server
self.connection = BRabbit(host=host, port=port)
Expand Down Expand Up @@ -120,13 +125,10 @@ def launch(self) -> str:
self.url = url
return url

def start_listening(self, verbose : bool = True, listen_every : float = 1.0):
def start_listening(self, listen_every : float = 1.0):
"""
If using as a server, call this before running client.

:param verbose: Whether to print status updates
:type verbose: bool

:param run_every: Listen for messages every x seconds
"""

Expand All @@ -135,8 +137,8 @@ def send(msg : Any):

while True:
if self.receive_buffer:
if verbose:
print("Received a message", self.receive_buffer[0])
if self.debug:
print(f"Responding to message: {self.receive_buffer[0]}")
msg = self.receive_buffer.pop(0).split("|")
if msg[0] == msg_constants.READY:
send(True)
Expand Down Expand Up @@ -167,6 +169,9 @@ def api_ping(self, msg):
# - Get stats
# - draw

if self.debug:
print(f"Received message from API: {pickle.loads(msg)}")

try:
self.receive_buffer.append(pickle.loads(msg))
except Exception as e:
Expand Down
8 changes: 7 additions & 1 deletion cheese/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,16 @@ class CHEESEAPI:

:param timeout: Timeout for waiting for main server to respond
:type timeout: float

:param debug: Print debug messages for rabbitmq
:type debug: bool
"""
def __init__(self, host : str = 'localhost', port : int = 5672, timeout : float = 10):
def __init__(self, host : str = 'localhost', port : int = 5672, timeout : float = 10, debug : bool = False):
self.timeout = timeout

# Initialize rabbit MQ server
self.connection = BRabbit(host=host, port=port)
self.debug = debug

# Channel to get results back from main server
self.subscriber = self.connection.EventSubscriber(
Expand Down Expand Up @@ -65,6 +69,8 @@ def main_listener(self, msg : str):
"""
Callback for main server. Receives messages from main server and places them in buffer.
"""
if self.debug:
print(f"Received message from main server: {pickle.loads(msg)}")
if not self.connected_to_main:
print("Warning: RabbitMQ queue non-empty at startup. Consider restarting RabbitMQ server if unexpected errors arise.")
return
Expand Down
1 change: 0 additions & 1 deletion cheese/pipeline/iterable_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from abc import abstractmethod
from typing import Dict, Any, Iterable

import webdataset as wds
from datasets import load_from_disk, Dataset
import pandas as pd
import joblib
Expand Down
9 changes: 5 additions & 4 deletions examples/image_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,16 @@ def main(self):
)
error_btn = gr.Button("Press This If An Image Is Not Loading")
error_btn.style(full_width = True)
with gr.Row():
with gr.Column():
with gr.Column():
with gr.Row():
im_left = gr.Image(show_label = False, shape = (256, 256))
im_right = gr.Image(show_label = False, shape = (256, 256))
with gr.Row():
btn_left = gr.Button("Select Above")
btn_left.style(full_width = True)
with gr.Column():
im_right = gr.Image(show_label = False, shape = (256, 256))
btn_right = gr.Button("Select Above")
btn_right.style(full_width = True)


# Note how all button clicks call response, but with different arguments
# The arguments to response will later be passed to self.receive(...)
Expand Down
39 changes: 31 additions & 8 deletions examples/instruct_hf_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,25 @@ class LMGenerationElement(BatchElement):
rankings : List[int] = None # Ordering for the completions w.r.t indices

class LMPipeline(GenerativePipeline):
def __init__(self, n_samples = 5, **kwargs):
"""
Pipeline for doing language model generation.

:param n_samples: Number of samples to generate for each query. Due to issues with how gradio registers button events,
this had to be hardcoded in the demo, so be aware of this if you want to change the value (i.e. add/remove buttons)
:type n_samples: int

:param device: Device to use for inference (any n > -1 uses cuda device n, -1 uses cpu). Defaults to 0.
:type device: int

:param kwargs: Keyword arguments to pass to GenerativePipeline
:type kwargs: dict
"""
def __init__(self, n_samples = 5, device : int = 0, **kwargs):
super().__init__(**kwargs)

self.n_samples = n_samples
self.pipe = pipeline(task="text-generation", model = 'gpt2', device=0)
self.pipe = pipeline(task="text-generation", model = 'gpt2')
self.pipe.tokenizer.pad_token_id = self.pipe.model.config.eos_token_id
# prevents annoying messages


self.init_buffer()
Expand All @@ -42,7 +54,7 @@ def generate(self, model_input : Iterable[str]) -> List[LMGenerationElement]:
for i in range(self.batch_size):
query = model_input[i]
completions = self.pipe(query, max_length=100, num_return_sequences=self.n_samples)
completions = [completion["generated_text"] for completion in completions]
completions = [completion["generated_text"][len(query):] for completion in completions]
elements.append(LMGenerationElement(query=query, completions=completions))
return elements

Expand All @@ -56,9 +68,21 @@ def extract_data(self, batch_element : LMGenerationElement) -> dict:
"rankings" : batch_element.rankings
}

def make_iter(length : int = 20):
def make_iter(length : int = 20, chunk_size : int = 16, device : int = 0):
"""
Creates an iterator that generates prompts for the completions that will be presented to labeller.

:param length: Number of prompts to generate
:type length: int

:param chunk_size: Number of prompts to generate in one forward pass
:type chunk_size: int

:param device: Device to run model on (any n > -1 uses cuda device n, -1 uses cpu)
:type device: int
"""
print("Creating prompt iterator...")
pipe = pipeline(task="text-generation", model = 'gpt2', device=0)
pipe = pipeline(task="text-generation", model = 'gpt2', device = device)
pipe.tokenizer.pad_token_id = pipe.model.config.eos_token_id
chunk_size = 16
meta_prompt = f"As an example, below is a list of {chunk_size + 3} prompts you could feed to a language model:\n"+\
Expand Down Expand Up @@ -103,7 +127,6 @@ def main(self):
# When a button is pressed, append index to state, and make button not visible

def press_button(i, pressed_val):
print("Pressed button", i)
pressed_val.append(i)

updates = [gr.update(visible = False if j in pressed_val else True) for j in range(5)]
Expand Down Expand Up @@ -194,4 +217,4 @@ def present(self, task):

print(cheese.launch())

print(cheese.create_client(1))
print(cheese.create_client(1))
Empty file added tests/reconnect/__init__.py
Empty file.
7 changes: 7 additions & 0 deletions tests/reconnect/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from cheese.api import CHEESEAPI

if __name__ == "__main__":
    # Manual reconnect test: run after (or before) tests/reconnect/server.py
    # to verify a client can attach to an already-running main server.
    print("Attempt connect")
    # debug = True makes CHEESEAPI print the rabbitmq messages it receives
    # from the main server (see CHEESEAPI docstring).
    cheese = CHEESEAPI(debug = True)

    # get_stats() round-trips through the main server; printing the "url"
    # field confirms the connection works end to end.
    # NOTE(review): assumes the server is already up — CHEESEAPI waits up to
    # its `timeout` for a response; confirm against server.py startup order.
    print(cheese.get_stats()["url"])
13 changes: 13 additions & 0 deletions tests/reconnect/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from cheese import CHEESE
from examples.image_selection import *

if __name__ == "__main__":
    # Manual reconnect test: stand up the main CHEESE server with the
    # image-selection example's pipeline and front end so that
    # tests/reconnect/client.py can connect to it.
    cheese = CHEESE(
        ImageSelectionPipeline, ImageSelectionFront,
        pipeline_kwargs = {
            # "iter": prompt/image iterator from the example; "max_length": 5
            # keeps the test dataset small. force_new overwrites any previous
            # run's dataset at write_path.
            "iter" : make_iter(), "write_path" : "./img_dataset_res", "force_new" : True, "max_length" : 5
        },
        # debug = True makes the server print rabbitmq debug messages
        # (see CHEESE class docstring).
        debug = True
    )
    # Start the gradio front end, then block forever servicing API messages.
    cheese.launch()
    cheese.start_listening()