In [1]:
from torchrl.envs.llm import ChatEnv
from torchrl.envs import Transform
from transformers import AutoTokenizer, AutoModelForCausalLM
from dotenv import load_dotenv
import httpx
import os
import warnings
import asyncio
from typing import Optional, List, Dict, Tuple

from pprint import pprint

import httpx
import torch
import tensordict
from tensordict import TensorDict
from torchrl.data.tensor_specs import CompositeSpec, Unbounded

from torchrl.modules.llm import TransformersWrapper
from torchrl.collectors.llm import LLMCollector

load_dotenv("../.env")
tensordict.set_list_to_stack(True).set() 

model = "McClain/plasmidgpt-addgene-gpt2"

INFO 09-15 13:57:28 [__init__.py:216] Automatically detected platform cuda.


In [None]:
#check if docker server is up
import requests
r = requests.get("http://server:8080/health")
r.json()

In [2]:
tokenizer = AutoTokenizer.from_pretrained(model, token=os.getenv("HF_TOKEN"))
model = AutoModelForCausalLM.from_pretrained(model,  token=os.getenv("HF_TOKEN"))

In [47]:
env = ChatEnv(
    input_mode="text",
    batch_size=(1,),
    tokenizer=tokenizer,
)

In [48]:
reset = env.reset(
        TensorDict(
            {"query": ["AATG"]},
            batch_size=(1,),
    )
)
print(reset)

TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        query: NonTensorStack(
            ['AATG'],
            batch_size=torch.Size([1]),
            device=None),
        terminated: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        text: Text(
            prompt=NonTensorStack(
                ['    <|im_start|>user\nAATG<|im_end|>\n<|im_start...,
                batch_size=torch.Size([1]),
                device=None),
            response=None,
            full=None,
            batch_size=torch.Size([1]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([1]),
    device=None,
    is_shared=False)


In [49]:
rand = env.rand_action()
rand

TensorDict(
    fields={
        text: Text(
            full=NonTensorData(data=a string, batch_size=torch.Size([1]), device=None),
            prompt=None,
            response=None,
            batch_size=torch.Size([1]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([1]),
    device=None,
    is_shared=False)

In [50]:
class RewardTransform(Transform):
    """Assign rewards by calling external scoring endpoints in parallel."""

    def __init__(
        self,
        rewards_server_url: Optional[str] = "http://localhost:8080",
        timeout_s: float = 60.0,
    ):
        super().__init__(in_keys=[], out_keys=["reward"])

        self.rewards_server_url = rewards_server_url.rstrip("/")
        self.client = httpx.AsyncClient(
            base_url=self.rewards_server_url,
            timeout=httpx.Timeout(timeout_s),
            follow_redirects=True,
        )
        #defaults for evaluating plasmif 
        self.require = {"ori": None, "amr": None, "mcs": True, "promoter": None}
        self.weights = {"ori": 0.30, "amr": 0.30, "mcs": 0.20, "promoter": 0.20}
        self.gc = {"target": 0.55, "weight": 0.05, "tolerance": 0.10}
        self._test_connection()

        # Exact paths & params per your curl commands
        self._endpoints: List[Dict] = [
            {"name": "amrfinder", "path": "/amrfinder/text", "params": {"is_protein": "false", "format": "json"}},
            {"name": "prodigal",  "path": "/prodigal/text",  "params": {"mode": "auto",   "format": "json"}},
            {"name": "plannotate","path": "/plannotate/fast","params": {}},
        ]

        # Plain text in, accept json or text back
        self._headers = {
            "Content-Type": "text/plain; charset=utf-8",
            "Accept": "application/json, text/plain; q=0.9, */*; q=0.1",
        }

    def _test_connection(self):
        import requests
        try:
            r = requests.get(self.rewards_server_url + "/health", timeout=5)
            r.raise_for_status()
        except requests.RequestException as e:
            raise RuntimeError(f"Failed to connect to rewards server: {e}")

    async def _post_text(self, path: str, params: Dict, text: str, name: str) -> Tuple[bool, Dict]:
        try:
            resp = await self.client.post(
                path, params=params, content=text.encode("utf-8"), headers=self._headers
            )
            ok = (resp.status_code == 200)
            if not ok:
                try:
                    preview = (resp.text or "")[:300]
                except Exception:
                    preview = "<unreadable>"
                print(f"[{name}] {path} -> {resp.status_code} | {preview}")
            return {"status": ok, "name": name, "reponse": resp.json()}
        except Exception as e:
            print(f"[{name}] {path} call failed: {e}")
            return False

    async def combine_rewards(self, info_dicts: List[Dict], overrides: dict | None = None) -> float:
        """
        Combine amrfinder, prodigal, and plannotate responses into a single reward.
        Scoring:
          - Presence of ORI, AMR gene, MCS, promoter (weighted), each multiplied by percent identity if available.
          - Small GC proximity bonus (around 55%) as a tie-breaker.
        Extensible via self.require / self.weights / self.gc.
        """
        require = dict(self.require)
        weights = dict(self.weights)
        gcconf  = dict(self.gc)
    
        # Merge per-example overrides if provided
        if overrides:
            if "require" in overrides: require.update(overrides["require"] or {})
            if "weights" in overrides: weights.update(overrides["weights"] or {})
            if "gc" in overrides:      gcconf.update(overrides["gc"] or {})

        # filter out failed calls (False) and normalize shape
        records = [r for r in info_dicts if isinstance(r, dict) and r.get("status") is not False]
        by_name = {r.get("name"): (r.get("reponse") or {}) for r in records}
    
        amr = by_name.get("amrfinder") or {}
        prodigal = by_name.get("prodigal") or {}
        plannotate = by_name.get("plannotate") or []
    
        # helpers
        def _lc(s): return str(s or "").lower()
        def _contains_any(text: str, needles) -> bool:
            t = _lc(text)
            return any(_lc(n) in t for n in (needles or []))
        def _best_pident(entries) -> float:
            """Return best percent identity (0..1) among provided plannotate entries with 'pident'."""
            best = 0.0
            for e in entries:
                try:
                    p = float(e.get("pident", 100.0)) / 100.0
                except Exception:
                    p = 1.0
                best = max(best, max(0.0, min(1.0, p)))
            return best if best > 0 else 1.0  # default to 1.0 if none provided
    
        # ---- ORI / MCS / Promoter from plannotate ----
        ori_present, mcs_present, prom_present = False, False, False
        ori_pident, mcs_pident, prom_pident = 1.0, 1.0, 1.0
    
        if isinstance(plannotate, list):
            # buckets
            ori_entries, mcs_entries, prom_entries = [], [], []
    
            for feat in plannotate:
                name = str(feat.get("Feature") or "")
                desc = str(feat.get("Description") or "")
                typ  = str(feat.get("Type") or "")
                text = f"{name} {desc} {typ}"
    
                # ORI
                if self.require["ori"] is None:
                    if typ.lower() == "rep_origin" or _contains_any(text, ["ori", "colE1", "pmb1", "pbr322", "puc"]):
                        ori_entries.append(feat)
                else:
                    if _contains_any(text, self.require["ori"]):
                        ori_entries.append(feat)
    
                # MCS
                if self.require["mcs"] is True:
                    if _contains_any(text, ["mcs", "multiple cloning site"]):
                        mcs_entries.append(feat)
                elif isinstance(self.require["mcs"], list):
                    if _contains_any(text, self.require["mcs"]):
                        mcs_entries.append(feat)
    
                # Promoter
                if self.require["promoter"] is None:
                    if typ.lower() == "promoter" or _contains_any(text, ["promoter"]):
                        prom_entries.append(feat)
                else:
                    if _contains_any(text, self.require["promoter"]):
                        prom_entries.append(feat)

            if ori_entries:
                ori_present = True
                ori_pident = _best_pident(ori_entries)
            if mcs_entries:
                mcs_present = True
                mcs_pident = _best_pident(mcs_entries)
            if prom_entries:
                prom_present = True
                prom_pident = _best_pident(prom_entries)
    
        # ---- AMR from amrfinder ----
        amr_present, amr_pident = False, 1.0
        if isinstance(amr, dict):
            hits = amr.get("genes", []) or []
            candidate_hits = []
            for g in hits:
                cls = str(g.get("class") or "")
                sym = str(g.get("element_symbol") or "")
                nm  = str(g.get("element_name") or "")
                hay = f"{cls} {sym} {nm}"
                if self.require["amr"] is None or _contains_any(hay, self.require["amr"]):
                    candidate_hits.append(g)
            if candidate_hits:
                amr_present = True
                # pick best identity among matching hits
                best = 1.0
                for g in candidate_hits:
                    try:
                        pid = float(g.get("percent_identity_to_reference", 100.0)) / 100.0
                    except Exception:
                        pid = 1.0
                    best = max(best, max(0.0, min(1.0, pid)))
                amr_pident = best

        # ---- main score (presence × weight × identity) ----
        w = self.weights
        main = 0.0
        if ori_present:   main += w.get("ori", 0.0) * float(ori_pident)
        if amr_present:   main += w.get("amr", 0.0) * float(amr_pident)
        if mcs_present:   main += w.get("mcs", 0.0) * float(mcs_pident)
        if prom_present:  main += w.get("promoter", 0.0) * float(prom_pident)
        main = min(main, 1.0)  # keep things tidy
    
        # ---- GC tie-breaker from prodigal ----
        gc_bonus = 0.0
        if isinstance(prodigal, dict):
            meta = prodigal.get("metadata", {}) or {}
            gc_raw = meta.get("model_gc_cont") or meta.get("gc_cont")
            if gc_raw is not None:
                try:
                    s = str(gc_raw).strip().replace("%", "")
                    v = float(s)
                    gc = v / 100.0 if "%" in str(gc_raw) or v > 1.0 else v
                    target = float(self.gc["target"])
                    tol = max(1e-6, float(self.gc["tolerance"]))
                    dist = abs(gc - target)
                    norm = max(0.0, 1.0 - (dist / tol))   # 1 at target, 0 at >= tolerance
                    gc_bonus = float(self.gc["weight"]) * norm
                except Exception:
                    gc_bonus = 0.0
    
        total = main + gc_bonus
        return float(total)


    async def _call(self, td: TensorDict) -> TensorDict:
        # 1) extract text(s)
        llm_texts: list[str]
        if "text" in td.keys(True):
            t = td["text"]
            llm_texts = [t if isinstance(t, str) else getattr(t, "response", str(t))]
        elif "query" in td.keys(True):
                q = td.get("query")
                # if query is a list (batch), use it as-is; else wrap
                llm_texts = list(q) if hasattr(q, "__iter__") and not isinstance(q, (str, bytes)) else [q]
        else:
            td["reward"] = torch.zeros(td.batch_size + (1,), dtype=torch.float32)
            return td
    
        # 2) extract per-example overrides (align shape)
        overrides = td.get("reward_params", None)
        if overrides is None or isinstance(overrides, dict):
            overrides_list = [overrides] * len(llm_texts)
        else:
            overrides_list = list(overrides)
            if len(overrides_list) != len(llm_texts):
                # safe fallback
                overrides_list = [None] * len(llm_texts)
    
        # 3) call endpoints per example (sequential or parallel)
        rewards: list[float] = []
        for text, ov in zip(llm_texts, overrides_list):
            # hit the three endpoints for THIS example
            results = await asyncio.gather(*[
                self._post_text(cfg["path"], cfg.get("params", {}), text, cfg["name"])
                for cfg in self._endpoints
            ])
            # combine with per-example params
            r = await self.combine_rewards(results, overrides=ov)
            rewards.append(float(r))
    
        # 4) write back (vector or scalar)
        out = torch.as_tensor(rewards, dtype=torch.float32)
        # reshape to td batch
        if out.ndim == 1 and out.numel() == td.numel():
            out = out.view(td.batch_size + (1,))
        else:
            out = out.mean().view(td.batch_size + (1,))  # conservative fallback
        td["reward"] = out
        return td


    def transform_reward_spec(self, reward_spec: CompositeSpec) -> CompositeSpec:
        reward_spec["reward"] = Unbounded(shape=reward_spec.shape + (1,), dtype=torch.float32)
        return reward_spec

    async def aclose(self):
        try:
            await self.client.aclose()
        except Exception:
            pass

    def __del__(self):
        try:
            if not self.client.is_closed:
                loop = asyncio.get_event_loop()
                if loop.is_running():
                    loop.create_task(self.client.aclose())
        except Exception:
            pass

In [51]:
test_plasmid = "GGGCGAATTCGAGCTCGGTACCCGGGGATCCTCTAGAGTCGACCTGCAGGCATGCAAGCTTGAGTATTCTATAGTGTCACCTAAATAGCTTGGCGTAATCATGGTCATAGCTGTTTCCTGTGTGAAATTGTTATCCGCTCACAATTCCACACAACATACGAGCCGGAAGCATAAAGTGTAAAGCCTGGGGTGCCTAATGAGTGAGCTAACTCACATTAATTGCGTTGCGCTCACTGCCCGCTTTCCAGTCGGGAAACCTGTCGTGCCAGCTGCATTAATGAATCGGCCAACGCGCGGGGAGAGGCGGTTTGCGTATTGGGCGCTCTTCCGCTTCCTCGCTCACTGACTCGCTGCGCTCGGTCGTTCGGCTGCGGCGAGCGGTATCAGCTCACTCAAAGGCGGTAATACGGTTATCCACAGAATCAGGGGATAACGCAGGAAAGAACATGAATTAATTCTCATGTTTGACAGCTTATCATCGATTAGCTTTAATGCGGTAGTTTATCACAGTTAAATTGCTAACGCAGTCAGGCACCGTGTATGAAATCTAACAATGCGCTCATCGTCATCCTCGGCACCGTCACCCTGGATGCTGTAGGCATAGGCTTGGTTATGCCGGTACTGCCGGGCCTCTTGCGGGATATCGTCCATTCCGACAGCATCGCCAGTCACTATGGCGTGCTGCTAGCGCTATATGCGTTGATGCAATTTCTATGCGCACCCGTTCTCGGAGCACTGTCCGACCGCTTTGGCCGCCGCCCAGTCCTGCTCGCTTCGCTACTTGGAGCCACTATCGACTACGCGATCATGGCGACCACACCCGTCCTGTGGATTCTCTACGCCGGACGCATCGTGGCCGGCATCACCGGCGCCACAGGTGCGGTTGCTGGCGCCTATATCGCCGACATCACCGATGGGGAAGATCGGGCTCGCCACTTCGGGCTCATGAGCGCTTGTTTCGGCGTGGGTATGGTGGCAGGCCCCGTGGCCGGGGGACTGTTGGGCGCCATCTCCTTACATGCACCATTCCTTGCGGCGGCGGTGCTCAACGGCCTCAACCTACTACTGGGCTGCTTCCTAATGCAGGAGTCGCATAAGGGAGAGCGCCGACCGATGCCCTTGAGAGCCTTCAACCCAGTCAGCTCCTTCCGGTGGGCGCGGGGCATGACTATCGTCGCCGCACTTATGACTGTCTTCTTTATCATGCAACTCGTAGGACAGGTGCCGGCAGCGCTCTGGGTCATTTTCGGCGAGGACCGCTTTCGCTGGAGCGCGACGATGATCGGCCTGTCGCTTGCGGTATTCGGAATCTTGCACGCCCTCGCTCAAGCCTTCGTCACTGGTCCCGCCACCAAACGTTTCGGCGAGAAGCAGGCCATTATCGCCGGCATGGCGGCCGACGCGCTGGGCTACGTCTTGCTGGCGTTCGCGACGCGAGGCTGGATGGCCTTCCCCATTATGATTCTTCTCGCTTCCGGCGGCATCGGGATGCCCGCGTTGCAGGCCATGCTGTCCAGGCAGGTAGATGACGACCATCAGGGACAGCTTCAAGGATCGCTCGCGGCTCTTACCAGCCTAACTTCGATCACTGGACCGCTGATCGTCACGGCGATTTATGCCGCCTCGGCGAGCACATGGAACGGGTTGGCATGGATTGTAGGCGCCGCCCTATACCTTGTCTGCCTCCCCGCGTTGCGTCGCGGTGCATGGAGCCGGGCCACCTCGACCTGAATGGAAGCCGGCGGCACCTCGCTAACGGATTCACCACTCCAAGAATTGGAGCCAATCAATTCTTGCGGAGAACTGTGAATGCGCAAACCAACCCTTGGCAGAACATATCCATCGCGTCCGCCATCTCCAGCAGCCGCACGCGGCGCATCTCGGGCAGCGTTGGGTCCTGGCCACGGGTGCGCATGATCGTGCTCCTGTCGTTGAGGACCCGGCTAGGCTGGCGGGGTTGCCTTACTGGTTAGCAGAATGAATCACCGATACGCGAGCGAACGTGAAGCGACTGCTGCTGCAAAACGTCTGCGACCTGAGCAACAACATGAATGGTCTTCGGTTTCCGTGTTTCGTAAAGTCTGGAAACGCGGAAGTCAGCGCCCTGCACCATTATGTTCCGGATCTGCATCGCAGGATGCTGCTGGCTACCCTGTGGAACACCTACATCTGTATTAACGAAGCGCTGGCATTGACCCTGAGTGATTTTTCTCTGGTCCCGCCGCATCCATACCGCCAGTTGTTTACCCTCACAACGTTCCAGTAACCGGGCATGTTCATCATCAGTAACCCGTATCGTGAGCATCCTCTCTCGTTTCATCGGTATCATTACCCCCATGAACAGAAATTCCCCCTTACACGGAGGCATCAAGTGACCAAACAGGAAAAAACCGCCCTTAACATGGCCCGCTTTATCAGAAGCCAGACATTAACGCTTCTGGAGAAACTCAACGAGCTGGACGCGGATGAACAGGCAGACATCTGTGAATCGCTTCACGACCACGCTGATGAGCTTTACCGCAGCTGCCTCGCGCGTTTCGGTGATGACGGTGAAAACCTCTGACACATGCAGCTCCCGGAGACGGTCACAGCTTGTCTGTAAGCGGATGCCGGGAGCAGACAAGCCCGTCAGGGCGCGTCAGCGGGTGTTGGCGGGTGTCGGGGCGCAGCCATGACCCAGTCACGTAGCGATAGCGGAGTGTATACTGGCTTAACTATGCGGCATCAGAGCAGATTGTACTGAGAGTGCACCATATGCGGTGTGAAATACCGCACAGATGCGTAAGGAGAAAATACCGCATCAGGCGCTCTTCCGCTTCCTCGCTCACTGACTCGCTGCGCTCGGTCGTTCGGCTGCGGCGAGCGGTATCAGCTCACTCAAAGGCGGTAATACGGTTATCCACAGAATCAGGGGATAACGCAGGAAAGAACATGTGAGCAAAAGGCCAGCAAAAGGCCAGGAACCGTAAAAAGGCCGCGTTGCTGGCGTTTTTCCATAGGCTCCGCCCCCCTGACGAGCATCACAAAAATCGACGCTCAAGTCAGAGGTGGCGAAACCCGACAGGACTATAAAGATACCAGGCGTTTCCCCCTGGAAGCTCCCTCGTGCGCTCTCCTGTTCCGACCCTGCCGCTTACCGGATACCTGTCCGCCTTTCTCCCTTCGGGAAGCGTGGCGCTTTCTCATAGCTCACGCTGTAGGTATCTCAGTTCGGTGTAGGTCGTTCGCTCCAAGCTGGGCTGTGTGCACGAACCCCCCGTTCAGCCCGACCGCTGCGCCTTATCCGGTAACTATCGTCTTGAGTCCAACCCGGTAAGACACGACTTATCGCCACTGGCAGCAGCCACTGGTAACAGGATTAGCAGAGCGAGGTATGTAGGCGGTGCTACAGAGTTCTTGAAGTGGTGGCCTAACTACGGCTACACTAGAAGGACAGTATTTGGTATCTGCGCTCTGCTGAAGCCAGTTACCTTCGGAAAAAGAGTTGGTAGCTCTTGATCCGGCAAACAAACCACCGCTGGTAGCGGTGGTTTTTTTGTTTGCAAGCAGCAGATTACGCGCAGAAAAAAAGGATCTCAAGAAGATCCTTTGATCTTTTCTACGGGGTCTGACGCTCAGTGGAACGAAAACTCACGTTAAGGGATTTTGGTCATGAGATTATCAAAAAGGATCTTCACCTAGATCCTTTTAAATTAAAAATGAAGTTTTAAATCAATCTAAAGTATATATGAGTAAACTTTGGCTGACAGTTACCAATGCTTAATCAGTGAGGCACCTATCTCAGCGATCTGTCTATTTCGTTCATCCATAGTTGCCTGACTCCCCGTCGTGTAGATAACTACGATACGGGAGGGCTTACCATCTGGCCCCAGTGCTGCAATGATACCGCGAGACCCACGCTCACCGGCTCCAGATTTATCAGCAATAAACCAGCCAGCCGGAAGGGCCGAGCGCAGAAGTGGTCCTGCAACTTTATCCGCCTCCATCCAGTCTATTAATTGTTGCCGGGAAGCTAGAGTAAGTAGTTCGCCAGTTAATAGTTTGCGCAACGTTGTTGCCATTGCGGCATCGTGGTGTCACGCTCGTCGTTTGGTATGGCTTCATTCAGCTCCGGTTCCCAACGATCAAGGCGAGTTACATGATCCCCCATGTTGTGCAAAAAAGCGGTTAGCTCCTTCGGTCCTCCGATCGTTGTCAGAAGTAAGTTGGCCGCAGTGTTATCACTCATGGTTATGGCAGCACTGCATAATTCTCTTACTGTCATGCCATCCGTAAGATGCTTTTCTGTGACTGGTGAGTACTCAACCAAGTCATTCTGAGAATAGTGTATGCGGCGACCGAGTTGCTCTTGCCCGGCGTCAACACGGGATAATACCGCGCCACATAGCAGAACTTTAAAAGTGCTCATCATTGGAAAACGTTCTTCGGGGCGAAAACTCTCAAGGATCTTACCGCTGTTGAGATCCAGTTCGATGTAACCCACTCGTGCACCCAACTGATCTTCAGCATCTTTTACTTTCACCAGCGTTTCTGGGTGAGCAAAAACAGGAAGGCAAAATGCCGCAAAAAAGGGAATAAGGGCGACACGGAAATGTTGAATACTCATACTCTTCCTTTTTCAATATTATTGAAGCATTTATCAGGGTTATTGTCTCATGAGCGGATACATATTTGAATGTATTTAGAAAAATAAACAAATAGGGGTTCCGCGCACATTTCCCCGAAAAGTGCCACCTGACGTCTAAGAAACCATTATTATCATGACATTAACCTATAAAAATAGGCGTATCACGAGGCCCTTTCGTCTCGCGCGTTTCGGTGATGACGGTGAAAACCTCTGACACATGCAGCTCCCGGAGACGGTCACAGCTTGTCTGTAAGCGGATGCCGGGAGCAGACAAGCCCGTCAGGGCGCGTCAGCGGGTGTTGGCGGGTGTCGGGGCTGGCTTAACTATGCGGCATCAGAGCAGATTGTACTGAGAGTGCACCATATGCGGTGTGAAATACCGCACAGATGCGTAAGGAGAAAATACCGCATCAGGCGAAATTGTAAACGTTAATATTTTGTTAAAATTCGCGTTAAATATTTGTTAAATCAGCTCATTTTTTAACCAATAGGCCGAAATCGGCAAAATCCCTTATAAATCAAAAGAATAGACCGAGATAGGGTTGAGTGTTGTTCCAGTTTGGAACAAGAGTCCACTATTAAAGAACGTGGACTCCAACGTCAAAGGGCGAAAAACCGTCTATCAGGGCGATGGCCCACTACGTGAACCATCACCCAAATCAAGTTTTTTGCGGTCGAGGTGCCGTAAAGCTCTAAATCGGAACCCTAAAGGGAGCCCCCGATTTAGAGCTTGACGGGGAAAGCCGGCGAACGTGGCGAGAAAGGAAGGGAAGAAAGCGAAAGGAGCGGGCGCTAGGGCGCTGGCAAGTGTAGCGGTCACGCTGCGCGTAACCACCACACCCGCCGCGCTTAATGCGCCGCTACAGGGCGCGTCCATTCGCCATTCAGGCTGCGCAACTGTTGGGAAGGGCGATCGGTGCGGGCCTCTTCGCTATTACGCCAGCTGGCGAAAGGGGGATGTGCTGCAAGGCGATTAAGTTGGGTAACGCCAGGGTTTTCCCAGTCACGACGTTGTAAAACGACGGCCAGTGAATTGTAATACGACTCACTATA"


In [52]:
td = TensorDict({}, batch_size=(1,))
class DummyText:
    def __init__(self, response: str):
        self.response = response

td['text'] = test_plasmid

rt = RewardTransform("http://server:8080")

scored = await rt._call(td)
print("Reward:", scored["reward"])

Reward: tensor([[1.0435]])


In [53]:
env = env.append_transform(RewardTransform("http://server:8080"))

In [54]:
class DefaultQueryOnReset(Transform):
    def __init__(self, default_query: list[str]):
        super().__init__()
        self.default_query = default_query

    # This is the hook TransformedEnv invokes before calling base_env._reset(...)
    def _reset_env_preprocess(self, tensordict: TensorDict | None) -> TensorDict:
        if tensordict is None or ("query" not in tensordict.keys(True)):
            b = (len(self.default_query),)
            tensordict = TensorDict({"query": self.default_query}, batch_size=b)
        return tensordict

In [55]:
# run this with the GRP starter
env = env.append_transform(DefaultQueryOnReset(["AATG"]))


In [71]:
from torchrl.modules.llm import TransformersWrapper
from torchrl.collectors.llm import LLMCollector
from torchrl.objectives.llm import GRPOLoss

policy = TransformersWrapper(
    model=model,          # "McClain/plasmidgpt-addgene-gpt2"
    tokenizer=tokenizer,
    input_mode="text",
    return_log_probs=True,
).eval()

In [72]:
reference = TransformersWrapper(
    model=model,              # or a separate ref checkpoint
    tokenizer=tokenizer,
    input_mode="text",
    return_log_probs=True,
    #freeze_model=True
).eval()
    

In [73]:
collector = LLMCollector(
    policy=policy,
    env=env,
    dialog_turns_per_batch=1
)

In [68]:
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.storages import ListStorage
from torchrl.objectives.llm import MCAdvantage

# Keep K responses per (same) prompt before computing advantage
K = 4
rb = ReplayBuffer(
    storage=ListStorage(1024),
    transform=MCAdvantage(grpo_size=K)  # writes "advantage" into items once K complete trajs exist
)

In [69]:
loss_module = GRPOLoss(
    actor_network=policy,
    # You can add KL-to-reference regularization if desired:
    kl_to_ref_coeff=0.01,
    masking_strategy="sft",   # single-turn: only response tokens; multi-turn: "rlhf"
)

In [74]:
import math

num_iters = 200
for it in range(num_iters):
    # 1) Generate one dialog turn (the env handles reset() via your transform)
    for it, td in enumerate(collector):
    
        # --- make reward shape broadcastable: (*bsz, 1, 1) ---
        if "reward" in td.keys():
            r = td.get("reward")
            if r.ndim == len(td.batch_size):          # e.g., (B,)
                r = r.unsqueeze(-1).unsqueeze(-1)     # -> (B,1,1)
            elif r.ndim == len(td.batch_size) + 1:    # e.g., (B,1)
                r = r.unsqueeze(-1)                   # -> (B,1,1)
            td.set("reward", r)
    
        # 2) Push to RB → MCAdvantage will fill "advantage" once it has K trajs for the same prompt
        rb.add(td)
    
        # Not every iter will have advantage ready; skip until MCAdvantage emits it
        if "advantage" not in td.keys():
            if it % 20 == 0:
                print(f"[it {it}] waiting for groups (K={K}) to fill…")
            continue
    
        # 3) Compute GRPO loss and update
        optim.zero_grad(set_to_none=True)
        loss_td = loss_module(td)          # produces "loss/*" scalars
        loss = sum(v for k, v in loss_td.items() if k.startswith("loss"))
        loss.backward()
        optim.step()
    
        # 4) Basic logging
        avg_r = float(td.get("reward").mean().item()) if "reward" in td.keys() else math.nan
        print(f"[it {it}] loss={loss.item():.4f}  avg_reward={avg_r:.3f}")




AttributeError: 'coroutine' object has no attribute 'batch_size'