In [2]:
# Cell 1: move to the Aristotle/TRACE project root and sanity-check its layout.
import os

# Relative paths used later ("data/...", "experiments/...") assume the project
# checkout is the working directory.
os.chdir("/workspace")
print("Current working directory:", os.getcwd())

# List the project root and the LogicNLI data folder so a missing checkout or
# dataset shows up immediately instead of deep inside a later step.
for listing_label, listing_dir in (("Files here:", "."), ("\nLogicNLI data dir:", "data/LogicNLI")):
    print(listing_label)
    print(os.listdir(listing_dir))

# Tell wandb which notebook this is rather than letting it auto-detect one.
# NOTE(review): confirm the installed wandb version actually reads
# WANDB_DISABLE_NOTEBOOK, and that the notebook name below (note the space
# after "baseline_") matches the real .ipynb filename.
os.environ.update({
    "WANDB_NOTEBOOK_NAME": "baseline_ Aristotle_LogicNLI_v2.ipynb",
    "WANDB_DISABLE_NOTEBOOK": "true",
})


Current working directory: /workspace
Files here:
['.git', '.gitignore', 'README.md', 'aristotle.png', 'data', 'evaluate.py', 'negate.py', 'prompts', 'requirements.txt', 'search_resolve.py', 'translate_decompose.py', 'utils.py', 'experiments', 'notebooks', 'apptainer', 'scripts', 'TRACE.md', '__pycache__', 'results', 'logs', 'wandb', 'test_env']

LogicNLI data dir:
['dev.json']


In [3]:
# Cell 2: provide OpenAI and wandb credentials for this session (never commit them).
import os
import getpass

import wandb

# getpass keeps the keys out of the rendered notebook and the input history.
# Pasted keys frequently carry a trailing space/newline, which makes
# authentication fail, so strip them before use.
openai_key = getpass.getpass("OpenAI API key: ").strip()
if openai_key:
    os.environ["OPENAI_API_KEY"] = openai_key
elif os.environ.get("OPENAI_API_KEY"):
    # Blank input: keep the key already present in the environment instead of
    # clobbering it with an empty string.
    print("No OpenAI key entered; keeping the existing OPENAI_API_KEY from the environment.")
else:
    print("⚠ No OpenAI API key provided; the translate/search steps will fail to authenticate.")

# Optional explicit wandb login; with a blank entry the saved login state is used.
wandb_api = getpass.getpass("wandb API key (留空则跳过 login): ").strip()
if wandb_api:
    os.environ["WANDB_API_KEY"] = wandb_api
    wandb.login(key=wandb_api)
else:
    print("跳过 wandb.login（用已有登录状态）")

OpenAI API key:  ········
wandb API key (留空则跳过 login):  ········




In [5]:
# Cell 3: experiment configuration and wandb run initialisation.
import os
import time
from datetime import datetime

import wandb

# Basic configuration
PROJECT_DIR = "/workspace"
DATASET_NAME = "LogicNLI"
SPLIT = "dev"
MODEL_NAME = "gpt-4o"      # can be swapped for another model, e.g. gpt-4.1
MAX_NEW_TOKENS = 2048
BATCH_NUM = 1

# A single timestamp shared by the wandb run name and the log file name, so the
# two can be matched up afterwards.
RUN_START = time.time()
RUN_NAME = f"{DATASET_NAME}-{SPLIT}-{MODEL_NAME}-{int(RUN_START)}"

config = {
    "dataset": DATASET_NAME,
    "split": SPLIT,
    "model_name": MODEL_NAME,
    "max_new_tokens": MAX_NEW_TOKENS,
    "batch_num": BATCH_NUM,
    "pipeline": "Aristotle_baseline",
}

run = wandb.init(
    project="TRACE_LogicNLI",  # change to your own wandb project if needed
    name=RUN_NAME,
    config=config,
    reinit=True,
)

print("wandb run url:", run.get_url())

# Global log buffer kept for compatibility with code that may append to it.
# NOTE(review): nothing in this notebook appears to read or write it.
LOG_BUFFERS = []

# Full subprocess output is appended to
# <PROJECT_DIR>/experiments/logs/<dataset>_<split>_<model>_<YYYYmmdd_HHMMSS>.txt.
# Anchoring at PROJECT_DIR (the same root run_cmd uses as cwd) keeps the log
# location independent of the kernel's current working directory.
LOG_DIR = os.path.join(PROJECT_DIR, "experiments", "logs")
os.makedirs(LOG_DIR, exist_ok=True)
_log_stamp = datetime.fromtimestamp(RUN_START).strftime("%Y%m%d_%H%M%S")
# A "/" in a model name (e.g. "org/model") would otherwise create a subdirectory.
_log_model = MODEL_NAME.replace("/", "_")
LOG_FILENAME = os.path.join(LOG_DIR, f"{DATASET_NAME}_{SPLIT}_{_log_model}_{_log_stamp}.txt")

print("Log will be saved to:", LOG_FILENAME)

wandb run url: https://wandb.ai/chengjiezheng-umb/TRACE_LogicNLI/runs/ig7dj2sh
Log will be saved to: experiments/logs/LogicNLI_dev_gpt-4o_20251203_045641.txt


In [6]:
# Cell 4: helper that runs a shell command and streams its output (Version 4)
import subprocess
import sys
import textwrap


def run_cmd(cmd: str, tag: str = "", cwd: "str | None" = None, log_path: "str | None" = None) -> str:
    """Run ``cmd`` through the shell, echo its output live, and append it to a log.

    The command text is printed first. Every line of the combined stdout/stderr
    is then written to the notebook and to the log file as soon as it arrives.

    Args:
        cmd: Shell command line (executed with ``shell=True``).
        tag: Optional label shown in the header, e.g. ``"EVALUATE"``.
        cwd: Working directory for the command; defaults to ``PROJECT_DIR``.
        log_path: File the output is appended to; defaults to ``LOG_FILENAME``.

    Returns:
        The command's complete combined stdout/stderr text, so callers can parse
        results from it (Cell 9 searches it for the printed accuracy).

    Raises:
        RuntimeError: if the command exits with a non-zero status.
    """
    if cwd is None:
        cwd = PROJECT_DIR
    if log_path is None:
        log_path = LOG_FILENAME

    header = f"[{tag}] " if tag else ""

    print("\n" + "=" * 80)
    print(f">>> {header}Running command:")
    print(textwrap.indent(cmd, prefix="    "))
    print("=" * 80 + "\n")

    # Stream instead of capture_output so long steps show progress while running.
    process = subprocess.Popen(
        cmd,
        shell=True,
        cwd=cwd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,   # line-buffered so tqdm refreshes arrive promptly
    )

    captured = []
    try:
        with open(log_path, "a") as log_file:
            for line in process.stdout:
                sys.stdout.write(line)
                sys.stdout.flush()
                log_file.write(line)
                captured.append(line)
        process.wait()
    finally:
        # If streaming was interrupted (kernel interrupt, log file error), do not
        # leave the child running in the background spending API calls.
        if process.poll() is None:
            process.kill()
            process.wait()
        process.stdout.close()

    if process.returncode != 0:
        raise RuntimeError(f"Command failed (exit code {process.returncode}): {cmd}")

    return "".join(captured)

In [7]:
# Cell 4: helper that runs a shell command, showing only progress bars (Version 5)
import subprocess
import sys
import re

def run_cmd(cmd: str, tag: str = "", cwd: "str | None" = None, log_path: "str | None" = None) -> str:
    """Run ``cmd`` through the shell, logging everything but displaying only tqdm lines.

    Every line of the combined stdout/stderr is appended to the log file. Only the
    lines that look like tqdm progress output are echoed to the notebook.

    Args:
        cmd: Shell command line (executed with ``shell=True``).
        tag: Label shown in the start/finish messages, e.g. ``"EVALUATE"``.
        cwd: Working directory for the command; defaults to ``PROJECT_DIR``.
        log_path: File the output is appended to; defaults to ``LOG_FILENAME``.

    Returns:
        The command's complete combined stdout/stderr text, including lines that
        were not displayed. Callers can parse results from it (Cell 9 searches it
        for the printed accuracy).

    Raises:
        RuntimeError: if the command exits with a non-zero status.
    """
    if cwd is None:
        cwd = PROJECT_DIR
    if log_path is None:
        log_path = LOG_FILENAME

    print(f"\n>>> Running [{tag}] ...\n")

    process = subprocess.Popen(
        cmd,
        shell=True,
        cwd=cwd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,   # universal newlines: tqdm's "\r" refreshes arrive as separate lines
        bufsize=1,
    )

    # Heuristic for tqdm output: percentages, "| n/total" counters, rates, ETA.
    tqdm_pattern = re.compile(r"\d+%|\|\s*[\d/]+|\sit/s|\sETA")

    captured = []
    try:
        with open(log_path, "a") as log_file:
            for line in process.stdout:
                # 1) The log file and return value get the complete output.
                log_file.write(line)
                captured.append(line)

                # 2) The notebook only shows progress-bar lines.
                if tqdm_pattern.search(line):
                    sys.stdout.write(line)
                    sys.stdout.flush()
        process.wait()
    finally:
        # If streaming was interrupted (kernel interrupt, log file error), do not
        # leave the child running in the background spending API calls.
        if process.poll() is None:
            process.kill()
            process.wait()
        process.stdout.close()

    if process.returncode != 0:
        raise RuntimeError(f"Command failed (exit code {process.returncode}): {cmd}")

    print(f"\n>>> [{tag}] Done.\n")
    return "".join(captured)


In [None]:
# Cell 5: Step A – translate_decompose (LogicNLI, configured split).
import time

# "$OPENAI_API_KEY" is expanded by the shell from the environment, so the key
# itself is not written into the command string that run_cmd prints/logs.
cmd_translate = f"""
python translate_decompose.py \
  --api_key "$OPENAI_API_KEY" \
  --model_name "{MODEL_NAME}" \
  --data_path "./data" \
  --dataset_name "{DATASET_NAME}" \
  --split {SPLIT} \
  --max_new_tokens {MAX_NEW_TOKENS} \
  --batch_num {BATCH_NUM}
"""

# Time this step like Steps B–D so every stage's duration is recorded in wandb.
start_time = time.time()
out_translate = run_cmd(cmd_translate, tag="TRANSLATE_DECOMPOSE")
wandb.log({"time_translate_sec": time.time() - start_time})



>>> Running [TRANSLATE_DECOMPOSE] ...

  0%|          | 0/300 [00:00<?, ?it/s]Translation response:  To translate the given context and conjecture into logical form, we will follow the steps outlined in the task description. Let's begin:
  0%|          | 1/300 [00:31<2:37:53, 31.68s/it]Translation response:  To translate the given context and conjecture into logical form, we will follow the steps outlined in the task description.
  1%|          | 2/300 [01:01<2:32:39, 30.74s/it]Translation response:  To translate the given context and conjecture into logical form, we will follow the steps outlined in the task description.
  1%|          | 3/300 [01:31<2:30:49, 30.47s/it]Translation response:  To translate the given context and conjecture into logical form, we will follow the steps outlined in the task description. Let's begin:


In [7]:
# Cell 6: Step B – run negate.py to generate the negated data.
cmd_negate = f"""
python negate.py \
  --dataset_name "{DATASET_NAME}" \
  --model "{MODEL_NAME}"
"""

# Wall-clock duration of this step is reported to wandb as time_negate_sec.
start_time = time.time()
out_negate = run_cmd(cmd_negate, tag="NEGATE")
elapsed_negate = time.time() - start_time
wandb.log({"time_negate_sec": elapsed_negate})


>>> [NEGATE] Running command:

    python negate.py   --dataset_name "LogicNLI"   --model "gpt-4o"




In [9]:
# Cell 7: Step C1 – search_resolve.py with negation disabled.
# NOTE(review): "False" reaches search_resolve.py as a string; confirm the script
# parses it explicitly (an argparse `type=bool` would turn it into True).
cmd_search_no_neg = f"""
python search_resolve.py \
  --api_key "$OPENAI_API_KEY" \
  --model_name "{MODEL_NAME}" \
  --data_path "./data" \
  --dataset_name "{DATASET_NAME}" \
  --split {SPLIT} \
  --negation False \
  --max_new_tokens {MAX_NEW_TOKENS} \
  --batch_num {BATCH_NUM}
"""

# Wall-clock duration of this step is reported to wandb as time_search_no_neg_sec.
start_time = time.time()
out_search_no_neg = run_cmd(cmd_search_no_neg, tag="SEARCH_RESOLVE_NO_NEG")
elapsed_search_no_neg = time.time() - start_time
wandb.log({"time_search_no_neg_sec": elapsed_search_no_neg})


In [None]:
# Cell 8: Step C2 – search_resolve.py with negation enabled.
cmd_search_neg = f"""
python search_resolve.py \
  --api_key "$OPENAI_API_KEY" \
  --model_name "{MODEL_NAME}" \
  --data_path "./data" \
  --dataset_name "{DATASET_NAME}" \
  --split {SPLIT} \
  --negation True \
  --max_new_tokens {MAX_NEW_TOKENS} \
  --batch_num {BATCH_NUM}
"""

# Wall-clock duration of this step is reported to wandb as time_search_neg_sec.
start_time = time.time()
out_search_neg = run_cmd(cmd_search_neg, tag="SEARCH_RESOLVE_NEG")
elapsed_search_neg = time.time() - start_time
wandb.log({"time_search_neg_sec": elapsed_search_neg})


In [None]:
# Cell 9: Step D – run evaluate.py, parse the accuracy, and record it in wandb.
import os
import re
import time

cmd_eval = f"""
python evaluate.py \
  --dataset_name "{DATASET_NAME}" \
  --model "{MODEL_NAME}"
"""

# Remember where the log currently ends so this step's output can be recovered
# from it if run_cmd does not return the captured text.
log_offset = os.path.getsize(LOG_FILENAME) if os.path.exists(LOG_FILENAME) else 0

start_time = time.time()
out_eval = run_cmd(cmd_eval, tag="EVALUATE")
elapsed_eval = time.time() - start_time
wandb.log({"time_evaluate_sec": elapsed_eval})

# A run_cmd that returns only a status flag (e.g. True) would make re.search
# raise TypeError; in that case read back the part of the log this step wrote.
if isinstance(out_eval, str):
    eval_text = out_eval
else:
    with open(LOG_FILENAME, "rb") as log_file:
        log_file.seek(log_offset)
        eval_text = log_file.read().decode("utf-8", errors="replace")

# Require at least one digit so a trailing sentence period ("0.85.") or a lone
# "." is never passed to float(), which would raise ValueError.
match = re.search(r"Accuracy:\s*([0-9]*\.?[0-9]+)", eval_text)
if match:
    acc = float(match.group(1))
    print(f"\nParsed Accuracy = {acc:.4f}")
    # Key follows the configured split ("dev_accuracy" for SPLIT == "dev").
    wandb.log({f"{SPLIT}_accuracy": acc})
else:
    print("\n⚠ 没能从 evaluate 输出中解析出 Accuracy，请检查 evaluate.py 的打印格式。")


In [None]:
# Cell 10: wrap up the wandb run (version 3).
# Best effort: attach the local log file to the run. A failure here is reported
# but must not stop the run from being finished below.
try:
    wandb.save(LOG_FILENAME)
except Exception as e:
    print("wandb.save failed:", e)
else:
    print("Log file also saved to wandb.")

wandb.finish()
final_run_url = run.get_url()
print("Done. Run URL:", final_run_url)
print("Final log file:", LOG_FILENAME)
