In [69]:
from torchrl.envs.llm import ChatEnv
from torchrl.envs import Transform
from transformers import AutoTokenizer, AutoModelForCausalLM
from dotenv import load_dotenv
import httpx
import os
import warnings
import asyncio
from typing import Optional, List, Dict, Tuple

from pprint import pprint

import httpx
import torch
import tensordict
from tensordict import TensorDict
from torchrl.data.tensor_specs import CompositeSpec, Unbounded

from torchrl.modules.llm import TransformersWrapper
from torchrl.collectors.llm import LLMCollector

load_dotenv("../.env")
tensordict.set_list_to_stack(True).set() 

model = "McClain/plasmidgpt-addgene-gpt2"

In [2]:
tokenizer = AutoTokenizer.from_pretrained(model, token=os.getenv("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained(model,  token=os.getenv("HF_TOKEN"))

tokenizer_config.json:   0%|          | 0.00/1.45k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/4.94M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/880 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/439M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/140 [00:00<?, ?B/s]

In [3]:
env = ChatEnv(
    input_mode="text",
    batch_size=(1,),
    tokenizer=tokenizer,
)

In [4]:
reset = env.reset(
        TensorDict(
            {"query": ["AATG"]},
            batch_size=(1,),
    )
)
print(reset)

TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        query: NonTensorStack(
            ['AATG'],
            batch_size=torch.Size([1]),
            device=None),
        terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        text: Text(
            prompt=NonTensorStack(
                ['    <|im_start|>user\nAATG<|im_end|>\n<|im_start...,
                batch_size=torch.Size([1]),
                device=None),
            full=None,
            response=None,
            batch_size=torch.Size([1]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([1]),
    device=None,
    is_shared=False)


In [5]:
rand = env.rand_action()
rand

TensorDict(
    fields={
        text: Text(
            full=NonTensorData(data=a string, batch_size=torch.Size([1]), device=None),
            response=None,
            prompt=None,
            batch_size=torch.Size([1]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([1]),
    device=None,
    is_shared=False)

In [70]:
class RewardTransform(Transform):
    """Assign rewards by calling external scoring endpoints in parallel."""

    def __init__(
        self,
        rewards_server_url: Optional[str] = "http://localhost:8080",
        timeout_s: float = 60.0,
    ):
        self.rewards_server_url = rewards_server_url.rstrip("/")
        self.client = httpx.AsyncClient(
            base_url=self.rewards_server_url,
            timeout=httpx.Timeout(timeout_s),
            follow_redirects=True,
        )
        self._test_connection()

        # Exact paths & params per your curl commands
        self._endpoints: List[Dict] = [
            {"name": "amrfinder", "path": "/amrfinder/text", "params": {"is_protein": "false", "format": "json"}},
            {"name": "prodigal",  "path": "/prodigal/text",  "params": {"mode": "auto",   "format": "json"}},
            {"name": "plannotate","path": "/plannotate/fast","params": {}},
        ]

        # Plain text in, accept json or text back
        self._headers = {
            "Content-Type": "text/plain; charset=utf-8",
            "Accept": "application/json, text/plain; q=0.9, */*; q=0.1",
        }

    def _test_connection(self):
        import requests
        try:
            r = requests.get(self.rewards_server_url + "/health", timeout=5)
            r.raise_for_status()
        except requests.RequestException as e:
            raise RuntimeError(f"Failed to connect to rewards server: {e}")

    async def _post_text(self, path: str, params: Dict, text: str, name: str) -> Tuple[bool, Dict]:
        try:
            resp = await self.client.post(
                path, params=params, content=text.encode("utf-8"), headers=self._headers
            )
            ok = (resp.status_code == 200)
            if not ok:
                try:
                    preview = (resp.text or "")[:300]
                except Exception:
                    preview = "<unreadable>"
                print(f"[{name}] {path} -> {resp.status_code} | {preview}")
            return {"status": ok, "name": name, "reponse": resp.json()}
        except Exception as e:
            print(f"[{name}] {path} call failed: {e}")
            return False

    async def combine_rewards(info_dicts: List[Dict]) -> float:
        """
        takes all the detailed info from the informatics server and combined it into a reward.
        """
        return 

    async def _call(self, td: TensorDict) -> TensorDict:
        # Accept either td["text"] (string or obj with .response) or td["query"] (list/stack/string)
        llm_text = None
        if "text" in td.keys(True):
            t = td["text"]
            llm_text = t if isinstance(t, str) else getattr(t, "response", None)
        if not llm_text and "query" in td.keys(True):
            q = td.get("query")
            try:
                llm_text = q[-1] if hasattr(q, "__getitem__") else q
            except Exception:
                llm_text = q
        if not llm_text:
            print("No LLM text found; reward=0.0")
            td["reward"] = torch.zeros(td.batch_size + (1,), dtype=torch.float32)
            return td
        if not isinstance(llm_text, str):
            llm_text = str(llm_text)

        results = await asyncio.gather(*[
            self._post_text(cfg["path"], cfg.get("params", {}), llm_text, cfg["name"])
            for cfg in self._endpoints
        ])

        pprint(f"results: {results}")
        successes = sum(bool(r) for r in results)
        reward_val = 2.0 if successes == 3 else 1.0 if successes >= 1 else 0.0
        print(f"Endpoint successes: {successes}/{len(self._endpoints)} → reward {reward_val}")

        td["reward"] = torch.full(td.batch_size + (1,), float(reward_val))
        return td

    def transform_reward_spec(self, reward_spec: CompositeSpec) -> CompositeSpec:
        reward_spec["reward"] = Unbounded(shape=reward_spec.shape + (1,), dtype=torch.float32)
        return reward_spec

    async def aclose(self):
        try:
            await self.client.aclose()
        except Exception:
            pass

    def __del__(self):
        try:
            if not self.client.is_closed:
                loop = asyncio.get_event_loop()
                if loop.is_running():
                    loop.create_task(self.client.aclose())
        except Exception:
            pass

In [73]:
test_plasmid = "GGGCGAATTCGAGCTCGGTACCCGGGGATCCTCTAGAGTCGACCTGCAGGCATGCAAGCTTGAGTATTCTATAGTGTCACCTAAATAGCTTGGCGTAATCATGGTCATAGCTGTTTCCTGTGTGAAATTGTTATCCGCTCACAATTCCACACAACATACGAGCCGGAAGCATAAAGTGTAAAGCCTGGGGTGCCTAATGAGTGAGCTAACTCACATTAATTGCGTTGCGCTCACTGCCCGCTTTCCAGTCGGGAAACCTGTCGTGCCAGCTGCATTAATGAATCGGCCAACGCGCGGGGAGAGGCGGTTTGCGTATTGGGCGCTCTTCCGCTTCCTCGCTCACTGACTCGCTGCGCTCGGTCGTTCGGCTGCGGCGAGCGGTATCAGCTCACTCAAAGGCGGTAATACGGTTATCCACAGAATCAGGGGATAACGCAGGAAAGAACATGAATTAATTCTCATGTTTGACAGCTTATCATCGATTAGCTTTAATGCGGTAGTTTATCACAGTTAAATTGCTAACGCAGTCAGGCACCGTGTATGAAATCTAACAATGCGCTCATCGTCATCCTCGGCACCGTCACCCTGGATGCTGTAGGCATAGGCTTGGTTATGCCGGTACTGCCGGGCCTCTTGCGGGATATCGTCCATTCCGACAGCATCGCCAGTCACTATGGCGTGCTGCTAGCGCTATATGCGTTGATGCAATTTCTATGCGCACCCGTTCTCGGAGCACTGTCCGACCGCTTTGGCCGCCGCCCAGTCCTGCTCGCTTCGCTACTTGGAGCCACTATCGACTACGCGATCATGGCGACCACACCCGTCCTGTGGATTCTCTACGCCGGACGCATCGTGGCCGGCATCACCGGCGCCACAGGTGCGGTTGCTGGCGCCTATATCGCCGACATCACCGATGGGGAAGATCGGGCTCGCCACTTCGGGCTCATGAGCGCTTGTTTCGGCGTGGGTATGGTGGCAGGCCCCGTGGCCGGGGGACTGTTGGGCGCCATCTCCTTACATGCACCATTCCTTGCGGCGGCGGTGCTCAACGGCCTCAACCTACTACTGGGCTGCTTCCTAATGCAGGAGTCGCATAAGGGAGAGCGCCGACCGATGCCCTTGAGAGCCTTCAACCCAGTCAGCTCCTTCCGGTGGGCGCGGGGCATGACTATCGTCGCCGCACTTATGACTGTCTTCTTTATCATGCAACTCGTAGGACAGGTGCCGGCAGCGCTCTGGGTCATTTTCGGCGAGGACCGCTTTCGCTGGAGCGCGACGATGATCGGCCTGTCGCTTGCGGTATTCGGAATCTTGCACGCCCTCGCTCAAGCCTTCGTCACTGGTCCCGCCACCAAACGTTTCGGCGAGAAGCAGGCCATTATCGCCGGCATGGCGGCCGACGCGCTGGGCTACGTCTTGCTGGCGTTCGCGACGCGAGGCTGGATGGCCTTCCCCATTATGATTCTTCTCGCTTCCGGCGGCATCGGGATGCCCGCGTTGCAGGCCATGCTGTCCAGGCAGGTAGATGACGACCATCAGGGACAGCTTCAAGGATCGCTCGCGGCTCTTACCAGCCTAACTTCGATCACTGGACCGCTGATCGTCACGGCGATTTATGCCGCCTCGGCGAGCACATGGAACGGGTTGGCATGGATTGTAGGCGCCGCCCTATACCTTGTCTGCCTCCCCGCGTTGCGTCGCGGTGCATGGAGCCGGGCCACCTCGACCTGAATGGAAGCCGGCGGCACCTCGCTAACGGATTCACCACTCCAAGAATTGGAGCCAATCAATTCTTGCGGAGAACTGTGAATGCGCAAACCAACCCTTGGCAGAACATATCCATCGCGTCCGCCATCTCCAGCAGCCGCACGCGGCGCATCTCGGGCAGCGTTGGGTCCTGGCCACGGGTGCGCATGATCGTGCTCCTGTCGTTGAGGACCCGGCTAGGCTGGCGGGGTTGCCTTACTGGTTAGCAGAATGAATCACCGATACGCGAGCGAACGTGAAGCGACTGCTGCTGCAAAACGTCTGCGACCTGAGCAACAACATGAATGGTCTTCGGTTTCCGTGTTTCGTAAAGTCTGGAAACGCGGAAGTCAGCGCCCTGCACCATTATGTTCCGGATCTGCATCGCAGGATGCTGCTGGCTACCCTGTGGAACACCTACATCTGTATTAACGAAGCGCTGGCATTGACCCTGAGTGATTTTTCTCTGGTCCCGCCGCATCCATACCGCCAGTTGTTTACCCTCACAACGTTCCAGTAACCGGGCATGTTCATCATCAGTAACCCGTATCGTGAGCATCCTCTCTCGTTTCATCGGTATCATTACCCCCATGAACAGAAATTCCCCCTTACACGGAGGCATCAAGTGACCAAACAGGAAAAAACCGCCCTTAACATGGCCCGCTTTATCAGAAGCCAGACATTAACGCTTCTGGAGAAACTCAACGAGCTGGACGCGGATGAACAGGCAGACATCTGTGAATCGCTTCACGACCACGCTGATGAGCTTTACCGCAGCTGCCTCGCGCGTTTCGGTGATGACGGTGAAAACCTCTGACACATGCAGCTCCCGGAGACGGTCACAGCTTGTCTGTAAGCGGATGCCGGGAGCAGACAAGCCCGTCAGGGCGCGTCAGCGGGTGTTGGCGGGTGTCGGGGCGCAGCCATGACCCAGTCACGTAGCGATAGCGGAGTGTATACTGGCTTAACTATGCGGCATCAGAGCAGATTGTACTGAGAGTGCACCATATGCGGTGTGAAATACCGCACAGATGCGTAAGGAGAAAATACCGCATCAGGCGCTCTTCCGCTTCCTCGCTCACTGACTCGCTGCGCTCGGTCGTTCGGCTGCGGCGAGCGGTATCAGCTCACTCAAAGGCGGTAATACGGTTATCCACAGAATCAGGGGATAACGCAGGAAAGAACATGTGAGCAAAAGGCCAGCAAAAGGCCAGGAACCGTAAAAAGGCCGCGTTGCTGGCGTTTTTCCATAGGCTCCGCCCCCCTGACGAGCATCACAAAAATCGACGCTCAAGTCAGAGGTGGCGAAACCCGACAGGACTATAAAGATACCAGGCGTTTCCCCCTGGAAGCTCCCTCGTGCGCTCTCCTGTTCCGACCCTGCCGCTTACCGGATACCTGTCCGCCTTTCTCCCTTCGGGAAGCGTGGCGCTTTCTCATAGCTCACGCTGTAGGTATCTCAGTTCGGTGTAGGTCGTTCGCTCCAAGCTGGGCTGTGTGCACGAACCCCCCGTTCAGCCCGACCGCTGCGCCTTATCCGGTAACTATCGTCTTGAGTCCAACCCGGTAAGACACGACTTATCGCCACTGGCAGCAGCCACTGGTAACAGGATTAGCAGAGCGAGGTATGTAGGCGGTGCTACAGAGTTCTTGAAGTGGTGGCCTAACTACGGCTACACTAGAAGGACAGTATTTGGTATCTGCGCTCTGCTGAAGCCAGTTACCTTCGGAAAAAGAGTTGGTAGCTCTTGATCCGGCAAACAAACCACCGCTGGTAGCGGTGGTTTTTTTGTTTGCAAGCAGCAGATTACGCGCAGAAAAAAAGGATCTCAAGAAGATCCTTTGATCTTTTCTACGGGGTCTGACGCTCAGTGGAACGAAAACTCACGTTAAGGGATTTTGGTCATGAGATTATCAAAAAGGATCTTCACCTAGATCCTTTTAAATTAAAAATGAAGTTTTAAATCAATCTAAAGTATATATGAGTAAACTTTGGCTGACAGTTACCAATGCTTAATCAGTGAGGCACCTATCTCAGCGATCTGTCTATTTCGTTCATCCATAGTTGCCTGACTCCCCGTCGTGTAGATAACTACGATACGGGAGGGCTTACCATCTGGCCCCAGTGCTGCAATGATACCGCGAGACCCACGCTCACCGGCTCCAGATTTATCAGCAATAAACCAGCCAGCCGGAAGGGCCGAGCGCAGAAGTGGTCCTGCAACTTTATCCGCCTCCATCCAGTCTATTAATTGTTGCCGGGAAGCTAGAGTAAGTAGTTCGCCAGTTAATAGTTTGCGCAACGTTGTTGCCATTGCGGCATCGTGGTGTCACGCTCGTCGTTTGGTATGGCTTCATTCAGCTCCGGTTCCCAACGATCAAGGCGAGTTACATGATCCCCCATGTTGTGCAAAAAAGCGGTTAGCTCCTTCGGTCCTCCGATCGTTGTCAGAAGTAAGTTGGCCGCAGTGTTATCACTCATGGTTATGGCAGCACTGCATAATTCTCTTACTGTCATGCCATCCGTAAGATGCTTTTCTGTGACTGGTGAGTACTCAACCAAGTCATTCTGAGAATAGTGTATGCGGCGACCGAGTTGCTCTTGCCCGGCGTCAACACGGGATAATACCGCGCCACATAGCAGAACTTTAAAAGTGCTCATCATTGGAAAACGTTCTTCGGGGCGAAAACTCTCAAGGATCTTACCGCTGTTGAGATCCAGTTCGATGTAACCCACTCGTGCACCCAACTGATCTTCAGCATCTTTTACTTTCACCAGCGTTTCTGGGTGAGCAAAAACAGGAAGGCAAAATGCCGCAAAAAAGGGAATAAGGGCGACACGGAAATGTTGAATACTCATACTCTTCCTTTTTCAATATTATTGAAGCATTTATCAGGGTTATTGTCTCATGAGCGGATACATATTTGAATGTATTTAGAAAAATAAACAAATAGGGGTTCCGCGCACATTTCCCCGAAAAGTGCCACCTGACGTCTAAGAAACCATTATTATCATGACATTAACCTATAAAAATAGGCGTATCACGAGGCCCTTTCGTCTCGCGCGTTTCGGTGATGACGGTGAAAACCTCTGACACATGCAGCTCCCGGAGACGGTCACAGCTTGTCTGTAAGCGGATGCCGGGAGCAGACAAGCCCGTCAGGGCGCGTCAGCGGGTGTTGGCGGGTGTCGGGGCTGGCTTAACTATGCGGCATCAGAGCAGATTGTACTGAGAGTGCACCATATGCGGTGTGAAATACCGCACAGATGCGTAAGGAGAAAATACCGCATCAGGCGAAATTGTAAACGTTAATATTTTGTTAAAATTCGCGTTAAATATTTGTTAAATCAGCTCATTTTTTAACCAATAGGCCGAAATCGGCAAAATCCCTTATAAATCAAAAGAATAGACCGAGATAGGGTTGAGTGTTGTTCCAGTTTGGAACAAGAGTCCACTATTAAAGAACGTGGACTCCAACGTCAAAGGGCGAAAAACCGTCTATCAGGGCGATGGCCCACTACGTGAACCATCACCCAAATCAAGTTTTTTGCGGTCGAGGTGCCGTAAAGCTCTAAATCGGAACCCTAAAGGGAGCCCCCGATTTAGAGCTTGACGGGGAAAGCCGGCGAACGTGGCGAGAAAGGAAGGGAAGAAAGCGAAAGGAGCGGGCGCTAGGGCGCTGGCAAGTGTAGCGGTCACGCTGCGCGTAACCACCACACCCGCCGCGCTTAATGCGCCGCTACAGGGCGCGTCCATTCGCCATTCAGGCTGCGCAACTGTTGGGAAGGGCGATCGGTGCGGGCCTCTTCGCTATTACGCCAGCTGGCGAAAGGGGGATGTGCTGCAAGGCGATTAAGTTGGGTAACGCCAGGGTTTTCCCAGTCACGACGTTGTAAAACGACGGCCAGTGAATTGTAATACGACTCACTATA"


In [74]:
td = TensorDict({}, batch_size=(1,))
class DummyText:
    def __init__(self, response: str):
        self.response = response

td['text'] = test_plasmid

rt = RewardTransform("http://server:8080")

scored = await rt._call(td)
print("Reward:", scored["reward"])

("results: [{'status': True, 'name': 'amrfinder', 'reponse': {'genes': "
 "[{'protein_id': 'NA', 'contig_id': 'sequence', 'start': 541, 'stop': 1728, "
 "'strand': '+', 'element_symbol': 'tet(C)', 'element_name': 'tetracycline "
 "efflux MFS transporter Tet(C)', 'scope': 'core', 'type': 'AMR', 'subtype': "
 "'AMR', 'class': 'TETRACYCLINE', 'subclass': 'TETRACYCLINE', 'method': "
 "'EXACTX', 'target_length': 396, 'reference_sequence_length': 396, "
 "'percent_coverage_of_reference': 100.0, 'percent_identity_to_reference': "
 "100.0, 'alignment_length': 396, 'closest_reference_accession': "
 "'WP_010891057.1', 'closest_reference_name': 'tetracycline efflux MFS "
 "transporter Tet(C)', 'hmm_accession': 'NA', 'hmm_description': 'NA'}, "
 "{'protein_id': 'NA', 'contig_id': 'sequence', 'start': 4064, 'stop': 4606, "
 "'strand': '-', 'element_symbol': 'blaTEM', 'element_name': 'TEM family class "
 "A beta-lactamase', 'scope': 'core', 'type': 'AMR', 'subtype': 'AMR', "
 "'class': 'BETA-LACTAM'

In [7]:
env = env.append_transform(RewardTransform())

RuntimeError: Failed to connect to rewards server: 404 Client Error: Not Found for url: http://server:8080/healthz

In [None]:
from transformers import 

policy = TransformersWrapper(
    model=model,          # "McClain/plasmidgpt-addgene-gpt2"
    tokenizer=tokenizer,
    max_new_tokens=500,         # for generation
    temperature=1.0,
)

collector = LLMCollector(
    policy=policy,
    env=env,
)

In [None]:
reset = env.reset()
print(reset)