Test readme commands (#1311)
rasbt committed May 3, 2024
1 parent e441c65 commit f334378
Showing 2 changed files with 171 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -26,6 +26,7 @@ test = [
"pytest>=8.1.1",
"pytest-rerunfailures>=14.0",
"pytest-timeout>=2.3.1",
"pytest-dependency>=0.6.0",
"transformers>=4.38.0", # numerical comparisons
"einops>=0.7.0",
"protobuf>=4.23.4",
170 changes: 170 additions & 0 deletions tests/test_readme.py
@@ -0,0 +1,170 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

import os
import subprocess
import sys
import threading
import time
from pathlib import Path

import pytest
import requests


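# REPO_ID doubles as the Hugging Face repo ID and the relative checkpoint path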
REPO_ID = Path("EleutherAI/pythia-14m")
CUSTOM_TEXTS_DIR = Path("custom_texts")


def run_command(command):
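    """Run a CLI command, returning stdout; raise RuntimeError with the full output on failure."""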
try:
result = subprocess.run(command, capture_output=True, text=True, check=True)
return result.stdout
except subprocess.CalledProcessError as e:
error_message = (
f"Command '{' '.join(command)}' failed with exit status {e.returncode}\n"
f"Output:\n{e.stdout}\n"
f"Error:\n{e.stderr}"
)
        # Surface the failure details in the test output, then re-raise
        print(error_message)
raise RuntimeError(error_message) from None


@pytest.mark.skipif(
sys.platform.startswith("win") or
sys.platform == "darwin" or
'AGENT_NAME' in os.environ,
reason="Does not run on Windows, macOS, or Azure Pipelines"
)
@pytest.mark.dependency()
def test_download_model():
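    """Download the pythia-14m checkpoint via `litgpt download`, as shown in the README."""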
repo_id = str(REPO_ID).replace("\\", "/") # fix for Windows CI
command = ["litgpt", "download", "--repo_id", str(repo_id)]
output = run_command(command)

    checkpoint_dir = Path("checkpoints") / repo_id
    assert f"Saving converted checkpoint to {checkpoint_dir}" in output
assert ("checkpoints" / REPO_ID).exists()


@pytest.mark.dependency()
def test_download_books():
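    """Fetch two public-domain books from Project Gutenberg to serve as custom pretraining data."""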
CUSTOM_TEXTS_DIR.mkdir(parents=True, exist_ok=True)

books = [
("https://www.gutenberg.org/cache/epub/24440/pg24440.txt", "book1.txt"),
("https://www.gutenberg.org/cache/epub/26393/pg26393.txt", "book2.txt")
]
for url, filename in books:
subprocess.run(["curl", url, "--output", str(CUSTOM_TEXTS_DIR / filename)], check=True)
# Verify each book is downloaded
assert (CUSTOM_TEXTS_DIR / filename).exists(), f"{filename} not downloaded"


@pytest.mark.dependency(depends=["test_download_model"])
def test_chat_with_model():
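    """Run `litgpt generate base` against the downloaded checkpoint."""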
command = ["litgpt", "generate", "base", "--checkpoint_dir", f"checkpoints"/REPO_ID]
prompt = "What do Llamas eat?"
result = subprocess.run(command, input=prompt, text=True, capture_output=True, check=True)
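    # The asserted string is litgpt's default generate prompt, which is echoed in the
    # output; the prompt passed via stdin is not consumed by `generate base`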
assert "What food do llamas eat?" in result.stdout


@pytest.mark.dependency(depends=["test_download_model"])
@pytest.mark.timeout(300)
def test_finetune_model():
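    """LoRA-finetune the downloaded checkpoint for a single step on a small medical-advice dataset."""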
OUT_DIR = Path("out") / "lora"
DATASET_PATH = Path("custom_finetuning_dataset.json")
CHECKPOINT_DIR = "checkpoints" / REPO_ID

download_command = ["curl", "-L", "https://huggingface.co/datasets/medalpaca/medical_meadow_health_advice/raw/main/medical_meadow_health_advice.json", "-o", str(DATASET_PATH)]
subprocess.run(download_command, check=True)

assert DATASET_PATH.exists(), "Dataset file not downloaded"

finetune_command = [
"litgpt", "finetune", "lora",
"--checkpoint_dir", str(CHECKPOINT_DIR),
"--lora_r", "1",
"--data", "JSON",
"--data.json_path", str(DATASET_PATH),
"--data.val_split_fraction", "0.00001", # Keep small because new final validation is expensive
"--train.max_steps", "1",
"--out_dir", str(OUT_DIR)
]
run_command(finetune_command)

assert (OUT_DIR/"final").exists(), "Finetuning output directory was not created"
assert (OUT_DIR/"final"/"lit_model.pth").exists(), "Model file was not created"


@pytest.mark.dependency(depends=["test_download_model", "test_download_books"])
def test_pretrain_model():
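    """Pretrain pythia-14m from scratch on the downloaded books, heavily truncated for CI."""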
OUT_DIR = Path("out") / "custom_pretrained"
pretrain_command = [
"litgpt", "pretrain",
"--model_name", "pythia-14m",
"--tokenizer_dir", str("checkpoints" / REPO_ID),
"--data", "TextFiles",
"--data.train_data_path", str(CUSTOM_TEXTS_DIR),
"--train.max_tokens", "100", # to accelerate things for CI
"--eval.max_iters", "1", # to accelerate things for CI
"--out_dir", str(OUT_DIR)
]
run_command(pretrain_command)

assert (OUT_DIR / "final").exists(), "Pretraining output directory was not created"
assert (OUT_DIR / "final" / "lit_model.pth").exists(), "Model file was not created"


@pytest.mark.dependency(depends=["test_download_model", "test_download_books"])
def test_continue_pretrain_model():
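    """Continue pretraining the downloaded checkpoint on the custom book data."""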
OUT_DIR = Path("out") / "custom_continue_pretrained"
pretrain_command = [
"litgpt", "pretrain",
"--model_name", "pythia-14m",
"--initial_checkpoint", str("checkpoints" / REPO_ID),
"--tokenizer_dir", str("checkpoints" / REPO_ID),
"--data", "TextFiles",
"--data.train_data_path", str(CUSTOM_TEXTS_DIR),
"--train.max_tokens", "100", # to accelerate things for CI
"--eval.max_iters", "1", # to accelerate things for CI
"--out_dir", str(OUT_DIR)
]
run_command(pretrain_command)

assert (OUT_DIR / "final").exists(), "Continued pretraining output directory was not created"
assert (OUT_DIR / "final" / "lit_model.pth").exists(), "Model file was not created"


@pytest.mark.dependency(depends=["test_download_model"])
def test_serve():
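    """Launch `litgpt serve` in a background thread and check that the server responds over HTTP."""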
CHECKPOINT_DIR = str("checkpoints" / REPO_ID)
    serve_command = [
        "litgpt", "serve",
        "--checkpoint_dir", CHECKPOINT_DIR
    ]

process = None

def run_server():
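        # Store the Popen handle in the enclosing scope so the main thread can kill it later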
nonlocal process
try:
            process = subprocess.Popen(serve_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
stdout, stderr = process.communicate(timeout=60)
except subprocess.TimeoutExpired:
            print("Server process did not exit within 60 seconds")

server_thread = threading.Thread(target=run_server)
server_thread.start()

# Allow time to initialize and start serving
time.sleep(30)

try:
response = requests.get("http://127.0.0.1:8000")
print(response.status_code)
assert response.status_code == 200, "Server did not respond as expected."
finally:
if process:
process.kill()
server_thread.join()
