# ThermalGAN single-sequence variant generation (Colab)

This notebook wraps `generate_variants_single.py` so you can pass a raw protein sequence instead of a FASTA file.
It optionally downloads weights, runs ESM + ThermalGAN, and writes the generated variants to a FASTA file.


## 0) Runtime
Before running, select a GPU runtime (Runtime -> Change runtime type -> GPU); the settings below default to `esm_device = "cuda"`.


In [None]:
#@title Clone repository (optional)
import subprocess
from pathlib import Path

GIT_REPO_URL = ""  # TODO: set to your repo URL
REPO_DIR = Path("/content/thermalgan-repo")

# Clone once into REPO_DIR; an empty URL or an existing checkout is a no-op.
if not GIT_REPO_URL:
    print("Skipping clone. Set GIT_REPO_URL if you want to clone.")
elif REPO_DIR.exists():
    print(f"Repo already exists at {REPO_DIR}")
else:
    # Run git directly with an argument list (no shell), so the URL is never
    # shell-interpreted. check=True makes a failed clone stop here instead of
    # surfacing later as a confusing missing-path error.
    subprocess.run(["git", "clone", "-q", GIT_REPO_URL, str(REPO_DIR)], check=True)


In [None]:
#@title Install dependencies
# Install the extra Python packages for this pipeline into the notebook kernel
# (%pip targets the running kernel's environment).
# NOTE(review): gdown is installed, but nothing in this notebook calls it —
# presumably intended for Google Drive-hosted weights; confirm or drop it.
# NOTE(review): no deep-learning framework is installed here; this presumably
# relies on frameworks preinstalled on Colab — confirm before running elsewhere.
%pip -q install "transformers>=4.36" "biopython>=1.79" "pyyaml" "tqdm" "typeguard" "gdown"


In [None]:
#@title Configure paths
import sys
from pathlib import Path

# A THERMALGAN_ROOT set manually in an earlier cell (as the error below asks)
# takes priority. The known locations are only searched as a fallback.
_manual_root = globals().get("THERMALGAN_ROOT")

candidates = []
if _manual_root:
    if not Path(_manual_root).exists():
        print(f"Warning: THERMALGAN_ROOT={_manual_root} does not exist; searching default locations.")
    candidates.append(Path(_manual_root))
if "REPO_DIR" in globals() and REPO_DIR.exists():
    candidates.extend([REPO_DIR / "sandra" / "ThermalGAN", REPO_DIR / "ThermalGAN"])
candidates.extend([Path("/content/ThermalGAN"), Path("/content/sandra/ThermalGAN")])

# First existing candidate wins, in the priority order built above.
THERMALGAN_ROOT = next((c for c in candidates if c.exists()), None)

if THERMALGAN_ROOT is None:
    raise FileNotFoundError(
        "Could not locate ThermalGAN. Set THERMALGAN_ROOT manually. Tried: "
        + ", ".join(str(c) for c in candidates)
    )

# Make the project's modules importable from this kernel. Skip entries that are
# already present so re-running the cell does not keep growing sys.path.
for _src_dir in (THERMALGAN_ROOT / "src", THERMALGAN_ROOT / "src" / "scripts"):
    if str(_src_dir) not in sys.path:
        sys.path.append(str(_src_dir))

print("THERMALGAN_ROOT:", THERMALGAN_ROOT)


## 1) Weights
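You need a run directory that contains `config.yaml` and `weights/epoch_<n>/generator_G.h5`, where `<n>` is the epoch you set as `EPOCH`.
Optionally, download the OGT (optimal growth temperature) weights (`OGT/Model1`, `Model2`, `Model3`) and set `skip_ogt = False` if you want predicted OGT values.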


In [None]:
#@title Download weights (optional)
from pathlib import Path

# Archive URLs to fetch. Leave a URL empty to skip that download.
RUN_DIR_ZIP_URL = ""  # TODO: URL to zip/tar.gz that contains a run_dir folder
OGT_WEIGHTS_ZIP_URL = ""  # TODO: URL to zip/tar.gz that contains OGT/Model1,2,3

# All archives are downloaded to, and unpacked under, this directory.
DOWNLOAD_DIR = Path("/content") / "thermalgan_weights"
DOWNLOAD_DIR.mkdir(exist_ok=True, parents=True)

def archive_name_from_url(url):
    """Return the file name in the path component of *url*.

    Query strings and fragments are ignored, and percent-escapes are decoded.
    ``Path(url).name`` would keep e.g. a trailing ``?dl=1``, which defeats the
    extension check in ``download_and_extract``. Returns ``"download"`` when
    the URL path has no file name (e.g. it ends with ``/``).
    """
    from urllib.parse import unquote, urlparse

    name = Path(unquote(urlparse(url).path)).name
    return name or "download"


def download_and_extract(url, dest_dir):
    """Download an archive from *url* into *dest_dir* and unpack it there.

    Args:
        url: Archive URL. Empty/falsy means "skip" (the TODO placeholders).
        dest_dir: Directory to save the archive in and extract into.

    Returns:
        The path of the downloaded archive, or None when *url* is empty.

    Raises:
        subprocess.CalledProcessError: if wget, unzip or tar exits non-zero.
    """
    import subprocess

    if not url:
        return None
    dest_dir = Path(dest_dir)
    archive = dest_dir / archive_name_from_url(url)
    # Argument lists (no shell): URLs containing '&' or spaces are passed
    # verbatim. check=True surfaces failed downloads instead of ignoring them.
    subprocess.run(["wget", "-q", "-O", str(archive), url], check=True)
    lower_name = archive.name.lower()
    if lower_name.endswith(".zip"):
        # -o overwrites existing files without prompting, so re-running the
        # cell cannot block waiting for interactive input.
        subprocess.run(["unzip", "-q", "-o", str(archive), "-d", str(dest_dir)], check=True)
    elif lower_name.endswith((".tar.gz", ".tgz")):
        subprocess.run(["tar", "-xzf", str(archive), "-C", str(dest_dir)], check=True)
    else:
        print("Unknown archive extension; extract manually if needed.")
    return archive

# Fetch the run directory first, then the OGT weights; empty URLs are skipped.
for _archive_url in (RUN_DIR_ZIP_URL, OGT_WEIGHTS_ZIP_URL):
    download_and_extract(_archive_url, DOWNLOAD_DIR)


In [None]:
#@title Configure run_dir and output
from pathlib import Path

RUN_DIR = Path("/content/thermalgan_weights/your_run_dir")  # TODO
EPOCH = 39

OUTPUT_DIR = Path("/content/thermalgan_outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

OGT_WEIGHTS_DIR = Path("/content/thermalgan_weights/OGT")  # TODO if using OGT
OGT_CONFIG = THERMALGAN_ROOT / "config/Classifier/config_classifier1.yaml"

# Fail fast here, before the long-running generation cell, if the run
# directory is incomplete.
if not (RUN_DIR / "config.yaml").exists():
    raise FileNotFoundError(f"Missing config.yaml in {RUN_DIR}")

GENERATOR_WEIGHTS = RUN_DIR / "weights" / f"epoch_{EPOCH}" / "generator_G.h5"
if not GENERATOR_WEIGHTS.exists():
    # List the epochs that do have generator weights, so EPOCH is easy to fix.
    # Sorting by (length, name) orders epoch_9 before epoch_10.
    _available = sorted(
        (p.name for p in (RUN_DIR / "weights").glob("epoch_*") if (p / "generator_G.h5").exists()),
        key=lambda n: (len(n), n),
    )
    raise FileNotFoundError(
        f"Missing generator_G.h5 for the selected epoch: {GENERATOR_WEIGHTS}. "
        f"Epochs with generator weights: {', '.join(_available) or 'none'}"
    )


In [None]:
#@title Sequence input and generation settings

# --- Query sequence -------------------------------------------------------
sequence = "YOUR_SEQUENCE_HERE"  # paste your raw sequence here
sequence_id = "query1"
output_name = sequence_id  # basename used for the output files
sequence_temp = None  # optional numeric temperature

# --- Sampling -------------------------------------------------------------
replicates = 10
temperature = 1.0
store_softmax = False

# --- ESM initialization ---------------------------------------------------
esm_init_mode = "single"  # single or random
esm_random_seed = 42
esm_random_chunk = 8
esm_random_mask_prob = 0.15

# --- Filtering / optimization cycles --------------------------------------
skip_optimize = False
sampling_cycles = 0
filter_opt_cycles = 1
esm_filter_threshold = 3.0  # set None to disable

# --- ESM model and hardware -----------------------------------------------
hf_model = "facebook/esm1v_t33_650M_UR90S_1"
esm_device = "cuda"  # cuda or cpu
gpu_id = 0
esm_bf16 = True
max_length = 1022
batch_tokens = 20000

# --- OGT prediction -------------------------------------------------------
skip_ogt = True  # set False to enable OGT predictions


In [None]:
#@title Run ThermalGAN
import re
import shlex
import subprocess
import sys

# Normalize the pasted sequence: drop all whitespace (line breaks from copied
# FASTA bodies, spaces) and uppercase it. ESM vocabularies use uppercase
# one-letter residue codes, so lowercase input would not tokenize as residues.
sequence = "".join(sequence.split()).upper()
if not sequence:
    raise ValueError("Sequence is empty.")

# Reject anything that is not an ASCII letter before an expensive model run.
# This also catches the untouched "YOUR_SEQUENCE_HERE" placeholder (underscores).
_bad_chars = sorted(set(re.sub(r"[A-Z]", "", sequence)))
if _bad_chars:
    raise ValueError(
        f"Sequence contains non-letter characters {_bad_chars}; "
        "paste a plain one-letter amino-acid sequence."
    )

if not skip_ogt and not OGT_WEIGHTS_DIR.exists():
    raise FileNotFoundError(
        f"OGT weights directory not found: {OGT_WEIGHTS_DIR}. "
        "Download the OGT weights or set skip_ogt = True."
    )

cmd = [
    # The kernel's own interpreter, so packages installed with %pip are
    # visible to the script (a bare "python" on PATH may be another install).
    sys.executable,
    str(THERMALGAN_ROOT / "src/scripts/generate_variants_single.py"),
    "--run_dir", str(RUN_DIR),
    "--sequence", sequence,
    "--seq_id", sequence_id,
    "--epoch", str(EPOCH),
    "--output_dir", str(OUTPUT_DIR),
    "--name", output_name,
    "--replicates", str(replicates),
    "--temperature", str(temperature),
    "--hf_model", hf_model,
    "--device", esm_device,
    "--batch_tokens", str(batch_tokens),
    "--max_length", str(max_length),
    "--gpu", str(gpu_id),
    "--esm_init_mode", esm_init_mode,
    "--esm_random_mask_prob", str(esm_random_mask_prob),
    "--esm_random_chunk", str(esm_random_chunk),
    "--esm_random_seed", str(esm_random_seed),
    "--filter_opt_cycles", str(filter_opt_cycles),
    "--sampling_cycles", str(sampling_cycles),
]
# Optional flags are appended only when enabled.
if sequence_temp is not None:
    cmd += ["--seq_temp", str(sequence_temp)]
if store_softmax:
    cmd += ["--store_softmax"]
if esm_filter_threshold is not None:
    cmd += ["--esm_filter_threshold", str(esm_filter_threshold)]
if skip_optimize:
    cmd += ["--skip_optimize"]
if esm_bf16:
    cmd += ["--esm_bf16"]
if skip_ogt:
    cmd += ["--skip_ogt"]
else:
    cmd += ["--ogt_config", str(OGT_CONFIG)]
    # One checkpoint prefix per OGT model (Model1..Model3).
    cmd += [
        "--ogt_weights",
        str(OGT_WEIGHTS_DIR / "Model1" / "variables" / "variables"),
        str(OGT_WEIGHTS_DIR / "Model2" / "variables" / "variables"),
        str(OGT_WEIGHTS_DIR / "Model3" / "variables" / "variables"),
    ]

# Echo a copy-pasteable shell form of the command, then run it without a shell.
print("Running:", " ".join(shlex.quote(c) for c in cmd))
subprocess.run(cmd, check=True)


In [None]:
#@title Inspect outputs and download
from itertools import islice
from pathlib import Path

fasta_path = OUTPUT_DIR / f"{output_name}_variants_epoch_{EPOCH}.fasta"
print("Output FASTA:", fasta_path)

if fasta_path.exists():
    # Preview the first 20 lines in Python rather than via `!head`. output_name
    # is user-chosen and may contain spaces or shell metacharacters that an
    # unquoted shell command would mangle.
    with fasta_path.open(encoding="utf-8") as fasta_file:
        for line in islice(fasta_file, 20):
            print(line, end="")
else:
    print("FASTA not found; check logs above.")

softmax_path = OUTPUT_DIR / f"{output_name}_variants_epoch_{EPOCH}_softmax.jsonl"
if softmax_path.exists():
    print("Softmax JSONL:", softmax_path)

# Browser downloads only work inside Colab. Outside it, just report and skip.
try:
    from google.colab import files
except ImportError:
    files = None
    print("Download skipped: not running in Google Colab.")

if files is not None:
    for _out_path in (fasta_path, softmax_path):
        if not _out_path.exists():
            continue
        # Downloads stay best-effort, as before, but per file, so one failure
        # no longer prevents the other file from being offered.
        try:
            files.download(str(_out_path))
        except Exception as exc:
            print(f"Download skipped for {_out_path.name}: {exc}")
