In [1]:
from torchrl.envs.llm import ChatEnv
from torchrl.envs import Transform
from transformers import AutoTokenizer, AutoModelForCausalLM
from dotenv import load_dotenv
import httpx
import os
import warnings
import asyncio
from typing import Optional, List

import httpx
import torch
import tensordict
from tensordict import TensorDict
from torchrl.data.tensor_specs import CompositeSpec, Unbounded

from torchrl.modules.llm import TransformersWrapper
from torchrl.collectors.llm import LLMCollector

load_dotenv("../.env")
tensordict.set_list_to_stack(True).set() 

model = "McClain/plasmidgpt-addgene-gpt2"

  from .autonotebook import tqdm as notebook_tqdm
2025-09-14 16:38:40,343	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


INFO 09-14 16:38:42 [__init__.py:216] Automatically detected platform cpu.
INFO 09-14 16:38:43 [importing.py:63] Triton not installed or not compatible; certain GPU-related functions will not be available.


In [2]:
tokenizer = AutoTokenizer.from_pretrained(model, token=os.getenv("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained(model,  token=os.getenv("HF_TOKEN"))

In [3]:
env = ChatEnv(
    input_mode="text",
    batch_size=(1,),
    tokenizer=tokenizer,
)

In [4]:
reset = env.reset(
        TensorDict(
            {"query": ["AATG"]},
            batch_size=(1,),
    )
)
print(reset)

TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        query: NonTensorStack(
            ['AATG'],
            batch_size=torch.Size([1]),
            device=None),
        terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        text: Text(
            prompt=NonTensorStack(
                ['    <|im_start|>user\nAATG<|im_end|>\n<|im_start...,
                batch_size=torch.Size([1]),
                device=None),
            response=None,
            full=None,
            batch_size=torch.Size([1]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([1]),
    device=None,
    is_shared=False)


In [5]:
rand = env.rand_action()
rand

TensorDict(
    fields={
        text: Text(
            full=NonTensorData(data=a string, batch_size=torch.Size([1]), device=None),
            response=None,
            prompt=None,
            batch_size=torch.Size([1]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([1]),
    device=None,
    is_shared=False)

In [6]:
class RewardTransform(Transform):
    """Assign rewards by calling external scoring endpoints in parallel.

    Endpoints:
      - /amrfinder
      - /prodigal
      - /plannotate-fast

    Each endpoint expects a POST with a *plain text* body (the LLM response).
    Reward mapping (by number of successful endpoint responses, HTTP 200):
      - 3 successes -> 2.0
      - 1–2 successes -> 1.0
      - 0 successes -> 0.0
    """

    def __init__(self, rewards_server_url: Optional[str] = "http://server:8080", timeout_s: float = 60.0):
        self.rewards_server_url = rewards_server_url.rstrip("/")
        self.timeout_s = timeout_s
        # Reuse a single async client across calls
        self.client = httpx.AsyncClient(
            base_url=self.rewards_server_url,
            timeout=httpx.Timeout(timeout_s),
            follow_redirects=True,
        )
        self._test_connection()

        # Relative paths to hit (joined to base_url by httpx)
        self._endpoints: List[str] = [
            "/amrfinder",
            "/prodigal",
            "/plannotate-fast",
        ]

    def _test_connection(self):
        """Best-effort connectivity check (sync)."""
        import requests
        try:
            r = requests.get(self.rewards_server_url + "/healthz", timeout=5)
            r.raise_for_status()
        except requests.RequestException as e:
            raise RuntimeError(f"Failed to connect to rewards server: {e}")

    async def _post_text(self, endpoint: str, text: str) -> bool:
        """POST raw text to an endpoint; return True if HTTP 200."""
        try:
            # Plain text body as specified
            resp = await self.client.post(
                endpoint,
                content=text.encode("utf-8"),
                headers={"Content-Type": "text/plain; charset=utf-8"},
            )
            ok = (resp.status_code == 200)
            if not ok:
                # Optional: log body for debugging (could be large—trim as needed)
                try:
                    body_preview = (resp.text or "")[:500]
                except Exception:
                    body_preview = "<unreadable body>"
                torchrl_logger.warning(
                    f"Endpoint {endpoint} returned {resp.status_code}. Body preview: {body_preview}"
                )
            return ok
        except Exception as e:
            torchrl_logger.warning(f"Endpoint {endpoint} call failed: {e}")
            return False

    async def _call(self, tensordict: TensorDict) -> TensorDict:
        """Process the tensordict and assign rewards based on endpoint results."""
        # Grab the most recent LLM response
        history = tensordict[0]["history"]
        last_item = history[-1]
        llm_text: str = getattr(last_item, "content", "")

        # Fire all three calls concurrently
        tasks = [self._post_text(ep, llm_text) for ep in self._endpoints]
        results = await asyncio.gather(*tasks, return_exceptions=False)

        successes = sum(bool(r) for r in results)

        # Reward mapping: 3 -> 2.0, 1–2 -> 1.0, 0 -> 0.0
        if successes == 3:
            reward_val = 2.0
            torchrl_logger.info("All endpoints succeeded (3/3). Reward = 2.0")
        elif successes >= 1:
            reward_val = 1.0
            torchrl_logger.info(f"Partial success ({successes}/3). Reward = 1.0")
        else:
            reward_val = 0.0
            torchrl_logger.info("No endpoints succeeded (0/3). Reward = 0.0")

        # Rewards have a trailing singleton dimension
        tensordict["reward"] = torch.full((1, 1), float(reward_val))
        return tensordict

    def transform_reward_spec(self, reward_spec: CompositeSpec) -> CompositeSpec:
        """Declare reward shape (B, 1)."""
        reward_spec["reward"] = Unbounded(
            shape=reward_spec.shape + (1,),
            dtype=torch.float32
        )
        return reward_spec

    async def aclose(self):
        """Gracefully close the async client when you're done with the env."""
        try:
            await self.client.aclose()
        except Exception:
            pass

    def __del__(self):
        # Best-effort close for non-async teardown contexts.
        try:
            if not self.client.is_closed:
                # Schedule close without awaiting (may be skipped on interpreter shutdown)
                loop = asyncio.get_event_loop()
                if loop.is_running():
                    loop.create_task(self.client.aclose())
        except Exception:
            pass




In [7]:
env = env.append_transform(RewardTransform())

RuntimeError: Failed to connect to rewards server: HTTPConnectionPool(host='server', port=8080): Max retries exceeded with url: /healthz (Caused by NameResolutionError("<urllib3.connection.HTTPConnection object at 0x178ac7c10>: Failed to resolve 'server' ([Errno 8] nodename nor servname provided, or not known)"))

In [None]:
from transformers import 

policy = TransformersWrapper(
    model=model,          # "McClain/plasmidgpt-addgene-gpt2"
    tokenizer=tokenizer,
    max_new_tokens=500,         # for generation
    temperature=1.0,
)

collector = LLMCollector(
    policy=policy,
    env=env,
)

In [None]:
reset = env.reset()
print(reset)