
Given a prompt (in grey) requiring scientific knowledge, FLUX generates imaginary images (lower row) that are far from reality (upper row). Moreover, LMMs like GPT-4o fail to identify the realistic image, whereas our end-to-end reward model succeeds. Note that the prompts shown here are condensed summaries of the actual prompts, shortened for illustration.
- [2025/4/18] Released the paper.
- [2025/4/05] Released the Science-T2I dataset, along with the training and evaluation code.
We recommend installing Science-T2I in a Conda virtual environment (Python >= 3.10).
conda create -n science-t2i python=3.10
conda activate science-t2i
Clone the repository and the submodule.
git clone git@github.com:Jialuo-Li/Science-T2I.git
cd Science-T2I
git submodule update --init
Install PyTorch by following the official instructions.
pip install torch torchvision
Install additional dependencies.
pip install -r requirements.txt
In addition to the Science-T2I training dataset, we have also curated two novel benchmarks specifically designed for evaluating vision-based scientific understanding tasks: Science-T2I-S and Science-T2I-C. These benchmarks contain 671 and 227 tuples, respectively. Each tuple consists of:
- An implicit prompt, along with its corresponding explicit prompt and superficial prompt.
- Two images: one that aligns with the explicit prompt and another that corresponds to the superficial prompt.
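If you want to inspect the benchmark tuples yourself, a minimal sketch using the HuggingFace datasets library is shown below. The split name is an assumption, and we print the column names rather than guessing the schema; check the dataset card for the authoritative layout.

from datasets import load_dataset

# Load the Science-T2I-S benchmark from the HuggingFace Hub.
# The split name "train" is an assumption; see the dataset card.
ds = load_dataset("Jialuo21/Science-T2I-S", split="train")

# Inspect the schema and the first tuple.
print(ds.column_names)
print(ds[0])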
We encourage you to evaluate your models on our benchmarks and submit a pull request with your results to refresh the Leaderboard!
To evaluate VLMs using our benchmarks, we provide an example script for assessing SciScore on the Science-T2I-S benchmark. You can adapt this script by modifying the input arguments to suit your specific VLM.
python eval/eval_vlm.py \
--dataset_name Jialuo21/Science-T2I-S \
--processor_name Jialuo21/SciScore \
--model_name Jialuo21/SciScore
For evaluating LMMs, we offer an example script to assess LLaVA-OV on the Science-T2I-S benchmark. To adapt this script for your own LMM, simply modify the dataset name and adjust the code accordingly.
First, install the required LLaVA-OV package:
pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git einops flash_attn
Then, run the evaluation script:
python eval/eval_lmm.py \
--dataset_name Jialuo21/Science-T2I-S
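Conceptually, evaluation on these benchmarks reduces to a two-alternative forced choice: for each tuple, the model should score the image matching the explicit prompt above the superficial one. A minimal sketch of that accuracy metric (an illustration, not the repository's exact implementation) is:

def pair_accuracy(scores_correct, scores_superficial):
    # Fraction of tuples where the scientifically correct image
    # receives the higher score.
    hits = sum(c > s for c, s in zip(scores_correct, scores_superficial))
    return hits / len(scores_correct)

# Example: three tuples, two scored correctly -> accuracy ~0.67.
print(pair_accuracy([0.9, 0.4, 0.8], [0.2, 0.7, 0.1]))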
Here is an example of running inference with SciScore:
from transformers import AutoProcessor, AutoModel
from PIL import Image
import torch
device = "cuda"
processor_name_or_path = "Jialuo21/SciScore"
model_pretrained_name_or_path = "Jialuo21/SciScore"
processor = AutoProcessor.from_pretrained(processor_name_or_path)
model = AutoModel.from_pretrained(model_pretrained_name_or_path).eval().to(device)
def calc_probs(prompt, images):
    # Preprocess the candidate images.
    image_inputs = processor(
        images=images,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    ).to(device)
    # Tokenize the prompt.
    text_inputs = processor(
        text=prompt,
        padding=True,
        truncation=True,
        max_length=77,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        # Embed images and text, then L2-normalize the embeddings.
        image_embs = model.get_image_features(**image_inputs)
        image_embs = image_embs / torch.norm(image_embs, dim=-1, keepdim=True)
        text_embs = model.get_text_features(**text_inputs)
        text_embs = text_embs / torch.norm(text_embs, dim=-1, keepdim=True)
        # Cosine similarity scaled by the learned temperature, then a
        # softmax over the candidate images.
        scores = model.logit_scale.exp() * (text_embs @ image_embs.T)[0]
        probs = torch.softmax(scores, dim=-1)
    return probs.cpu().tolist()
pil_images = [Image.open("./examples/camera_1.png"), Image.open("./examples/camera_2.png")]
prompt = "A camera screen without electricity sits beside the window, realistic."
print(calc_probs(prompt, pil_images))
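The script prints one softmax probability per candidate image; the image SciScore judges as better matching the prompt receives the higher probability.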
Using SciScore, you can assess how well T2I models align with real-world scenarios on our predefined tasks. Below is an example script that evaluates FLUX.1[schnell] with SciScore, using the prompts from the Science-T2I-S dataset:
accelerate launch eval/eval_t2i_with_SciScore.py \
--dataset_name Jialuo21/Science-T2I-S
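For intuition, here is a condensed sketch of what such a loop does: generate with FLUX.1[schnell], then use SciScore as a reward. It reuses the processor and model objects from the inference snippet above; the sampling settings are assumptions, not the script's actual configuration.

import torch
from diffusers import FluxPipeline

# Load FLUX.1[schnell]; 4 steps with guidance 0.0 is the usual schnell setting.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

prompt = "A camera screen without electricity sits beside the window, realistic."
image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]

# SciScore reward: temperature-scaled cosine similarity between the
# text and image embeddings (processor/model from the snippet above).
inputs = processor(text=prompt, images=image, padding=True, truncation=True,
                   max_length=77, return_tensors="pt").to("cuda")
with torch.no_grad():
    img = model.get_image_features(pixel_values=inputs["pixel_values"])
    txt = model.get_text_features(input_ids=inputs["input_ids"],
                                  attention_mask=inputs["attention_mask"])
    img = img / img.norm(dim=-1, keepdim=True)
    txt = txt / txt.norm(dim=-1, keepdim=True)
    reward = (model.logit_scale.exp() * (txt * img).sum(-1)).item()
print(f"SciScore reward: {reward:.2f}")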
To train SciScore from scratch, execute the following commands. This process takes approximately one hour on a system with 8 A6000 GPUs.
pip install deepspeed==0.14.5 # First install deepspeed for training
cd SciScore_trainer
bash train_sciscore.sh
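Since SciScore builds on PickScore-style reward modeling, its training objective is presumably a pairwise preference loss over the two images in each tuple. A generic sketch of such a loss (not the trainer's exact code) is:

import torch
import torch.nn.functional as F

def preference_loss(score_correct, score_superficial):
    # Treat the two temperature-scaled similarities as logits over a
    # 2-way choice and maximize the probability of the correct image.
    logits = torch.stack([score_correct, score_superficial], dim=-1)
    targets = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits, targets)

# Example with a batch of two score pairs.
print(preference_loss(torch.tensor([8.0, 5.0]), torch.tensor([6.0, 7.0])))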
Install GroundingDINO dependencies and download pretrained weights.
cd ft_flux/GroundingDINO
pip install -e .
mkdir -p weights && cd weights
wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
cd ..
pip uninstall deepspeed # Uninstall deepspeed if it's currently installed (not needed for this section)
We begin by performing supervised fine-tuning (SFT) on FLUX.1[dev] for domain adaptation, using the training set from Science-T2I. The example command to run this stage is:
accelerate launch sft_flux.py --config config/custom.py:sft
In this stage, we further fine-tune FLUX.1[dev] via online fine-tuning (OFT) with the DPO training objective, using SciScore as the reward model to guide the optimization process. The example command is:
accelerate launch oft_flux.py --config config/custom.py:oft
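For reference, the DPO objective on a preference pair (here, the higher- vs. lower-SciScore generation) has the generic form sketched below; the beta value and the exact log-probability computation for diffusion models are assumptions, so see the paper for the actual formulation.

import torch
import torch.nn.functional as F

def dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    # Standard DPO: increase the margin between the preferred (w) and
    # dispreferred (l) samples relative to a frozen reference model.
    # beta controls how far the policy may drift from the reference.
    margin = beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l))
    return -F.logsigmoid(margin).mean()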
We also provide examples for fine-tuning FLUX with different reward models (a toy reward sketch follows this list):
- Whiteness Reward (higher reward for whiter images):
accelerate launch oft_flux.py --config config/custom.py:oft_white
- ImageReward (install the package first):
pip install image-reward
accelerate launch oft_flux.py --config config/custom.py:oft_image_reward
- Aesthetic Score:
accelerate launch oft_flux.py --config config/custom.py:oft_aes_score
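As a concrete illustration of what such a reward function can look like, here is a toy whiteness reward (our guess at the idea, not the repository's exact implementation):

import torch

def whiteness_reward(images):
    # images: float tensor of shape (B, C, H, W) with values in [0, 1].
    # Mean pixel intensity per image; an all-white image scores 1.0.
    return images.mean(dim=(1, 2, 3))

# Example: a white image scores higher than a grey one.
white = torch.ones(1, 3, 64, 64)
grey = torch.full((1, 3, 64, 64), 0.5)
print(whiteness_reward(white), whiteness_reward(grey))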
We are deeply grateful for the following GitHub repositories, as their valuable code and efforts have been incredibly helpful:
- PickScore (https://github.com/yuvalkirstain/PickScore)
- DDPO (https://github.com/kvablack/ddpo-pytorch)
- Diffusers (https://github.com/huggingface/diffusers)
- GroundingDINO (https://github.com/IDEA-Research/GroundingDINO)
- ImageReward (https://github.com/THUDM/ImageReward)
If you find Science-T2I useful for your research and applications, please cite it using this BibTeX:
@misc{li2025sciencet2iaddressingscientificillusions,
title={Science-T2I: Addressing Scientific Illusions in Image Synthesis},
author={Jialuo Li and Wenhao Chai and Xingyu Fu and Haiyang Xu and Saining Xie},
year={2025},
eprint={2504.13129},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2504.13129},
}