
Some questions about model testing #45

Open
Philharmy-Wang opened this issue May 21, 2024 · 2 comments
@Philharmy-Wang

Dear Author,

Firstly, I would like to express my profound gratitude for your contributions and research in this field! I am currently working on multi-modal remote sensing for forest fire detection. Your recent introduction of the img2img method has inspired new directions in my research.

In remote sensing applications, acquiring registered multi-modal data (visible and thermal imaging) is exceedingly difficult. I was very intrigued by your project's handling of scenarios like Day to Night and Clear to Rainy transitions. I am interested in converting visible light images into registered thermal imaging to enrich my dataset with synthetic data.

To this end, I have trained a model locally using the training code you provided, utilizing some multi-modal forest fire data I have, which is already registered. I formatted the RGB and thermal images according to the structure used in your dataset.

[Screenshot: Clip_2024-05-21_15-30-14]

The specific training command was as follows:

accelerate launch src/train_pix2pix_turbo.py \
    --pretrained_model_name_or_path="stabilityai/sd-turbo" \
    --output_dir="output/pix2pix_turbo/fs_rgb_ir_03" \
    --dataset_folder="data/fs_rgb_ir" \
    --resolution=512 \
    --train_batch_size=1 \
    --enable_xformers_memory_efficient_attention --viz_freq 25 \
    --track_val_fid \
    --learning_rate=8e-5 \
    --num_training_epochs=150 \
    --max_train_steps=200000 \
    --lr_scheduler="cosine_with_restarts" \
    --lr_warmup_steps=1000 \
    --lr_num_cycles=1 \
    --report_to "wandb" --tracker_project_name "pix2pix_turbo_fs_rgb_ir"

The model performed exceptionally well during validation, where the generated thermal images were almost identical to the target images.

[Screenshot: Clip_2024-05-21_15-30-34]

Step 10000:
[Screenshot: Clip_2024-05-10_10-33-23]

Step 20000:
[Screenshot: Clip_2024-05-10_10-35-23]

Step 30026:
[Screenshot: Clip_2024-05-10_10-36-59]

Step 40001:
[Screenshot: Clip_2024-05-10_10-43-51]

However, when I applied the model to test on other forest fire datasets, the results were disappointing and vastly different from what was expected.

[Screenshot: Clip_2024-05-15_11-15-40]
[Screenshot: Clip_2024-05-15_11-30-14]

The specific testing command was:

python src/inference_paired.py --model_path "output/pix2pix_turbo/fs_rgb_ir_03/checkpoints/model_66001.pkl" \
    --input_image "VOCdevkit-rsy-all/images/train/1_1.jpg" \
    --prompt "RGB to IR" \
    --output_dir "outputs/fs-rgb_ir-512"

Given the above, I have several questions:

  1. Are there any oversights in my training setup? Due to hardware limitations, I had to set the train_batch_size to 1, which might have compromised effective learning.
  2. I used your project's src/inference_paired.py for inference with all parameters set to default. Do you think the distortion in inference could be related to the script's configuration?
  3. The model performed well during training and validation, but the inference on a new dataset was poor. Does this suggest that the model might be overfitting?
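One way to test the overfitting hypothesis in question 3 is to compute the average L2 reconstruction error on the validation pairs and on the new test pairs and compare them. A minimal sketch of the metric (image loading omitted; arrays are assumed to be HxWx3 and scaled to [0, 1]):

```python
import numpy as np

def l2_error(pred: np.ndarray, target: np.ndarray) -> float:
    """Root-mean-square error between two images, assumed scaled to [0, 1]."""
    diff = pred.astype(np.float64) - target.astype(np.float64)
    return float(np.sqrt(np.mean(diff ** 2)))

def mean_l2_error(pairs) -> float:
    """Average L2 error over an iterable of (generated, target) array pairs."""
    return float(np.mean([l2_error(p, t) for p, t in pairs]))
```

A large gap between the validation and test numbers would point to a domain shift or overfitting rather than a problem in the inference script.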

I look forward to your reply and thank you for taking the time to assist with these questions!

@GaParmar (Owner) commented May 26, 2024

Hi,

Thank you for your interest in this project!
A couple of things I can observe from your results:

  • What is the image pre-processing used during inference and training time?
  • What is your L2 reconstruction error for the validation set vs test set?
  • It is indeed strange that the model performs well on the validation set but poorly on the test set. It could be a sign of overfitting. Is your test set similar to the validation set?

-Gaurav

@Philharmy-Wang (Author)

> Hi,
>
> Thank you for your interest in this project! A couple of things I can observe from your results:
>
>   • What is the image pre-processing used during inference and training time?
>   • What is your L2 reconstruction error for the validation set vs test set?
>   • It is indeed strange that the model performs well on the validation set but poorly on the test set. It could be a sign of overfitting. Is your test set similar to the validation set?
>
> -Gaurav

Dear Gaurav,

Thank you very much for your prompt reply and suggestions! Here are the specific answers to the questions you raised:

  1. During inference and training, I adopted the same image preprocessing methods used in your project for training the Fill50k dataset on pix2pix-turbo. These preprocessing steps are implemented by default in the src/train_pix2pix_turbo.py script.

  2. Regarding the L2 reconstruction error, it was 0.147 for the training set (fs_rgb_ir) and 0.202 for the validation set.

  3. Concerning dataset similarity, my test set is indeed quite different from the validation set. Although both show forest fires from a drone's perspective, the test set includes synthesized images as well as our own drone-captured images, which differ significantly from the training and validation sets in flight altitude and overall image style.
    [Screenshot]
    [Screenshot]
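For point 1 above, a mismatch between training-time and inference-time normalization is worth ruling out explicitly. Diffusion-based image-to-image pipelines typically map pixels to [-1, 1] (the ToTensor + Normalize(mean=0.5, std=0.5) convention); the sketch below illustrates that convention in plain NumPy, but whether the repo uses exactly this should be verified against src/train_pix2pix_turbo.py and src/inference_paired.py.

```python
import numpy as np

def to_model_input(img_uint8: np.ndarray) -> np.ndarray:
    """Map an HxWx3 uint8 image to a 3xHxW float32 array in [-1, 1].

    Mirrors the common ToTensor + Normalize(0.5, 0.5) convention;
    this is an assumption, not a copy of the repo's code.
    """
    x = img_uint8.astype(np.float32) / 255.0  # scale to [0, 1]
    x = (x - 0.5) / 0.5                       # shift to [-1, 1]
    return np.transpose(x, (2, 0, 1))         # HWC -> CHW

def from_model_output(x: np.ndarray) -> np.ndarray:
    """Inverse mapping: 3xHxW array in [-1, 1] back to HxWx3 uint8."""
    img = (np.transpose(x, (1, 2, 0)) * 0.5 + 0.5) * 255.0
    return np.clip(img, 0, 255).astype(np.uint8)
```

If training and inference disagree on any of these steps (scaling range, channel order, resize size), outputs can look badly distorted even for a well-trained model.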

Here are the statistical details of the images in my training and test datasets:

Training set RGB images (train_A):
  Mean: tensor([0.5097, 0.5172, 0.5023])
  Std:  tensor([0.1511, 0.1540, 0.1636])

Training set IR images (train_B):
  Mean: tensor([0.5555, 0.0691, 0.3500])
  Std:  tensor([0.1448, 0.1327, 0.2469])

Test set RGB images (test_A):
  Mean: tensor([0.5082, 0.5153, 0.5008])
  Std:  tensor([0.1526, 0.1554, 0.1663])

Test set IR images (test_B):
  Mean: tensor([0.5495, 0.0680, 0.3495])
  Std:  tensor([0.1478, 0.1346, 0.2481])

Sample image from test set (test_img):
  Mean: tensor([0.4136, 0.4193, 0.4006])
  Std:  tensor([0.2235, 0.2367, 0.2664])
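Per-channel statistics like the ones above can be computed with a short helper (a sketch; images are assumed to be already loaded as HxWx3 float arrays in [0, 1], e.g. via Pillow):

```python
import numpy as np

def channel_stats(images):
    """Per-channel mean and std over an iterable of HxWx3 arrays in [0, 1]."""
    pixels = np.concatenate([img.reshape(-1, 3) for img in images], axis=0)
    return pixels.mean(axis=0), pixels.std(axis=0)
```

Comparing these numbers between splits is only a first-order check for domain shift: the RGB means above match closely across train and test, so any gap would have to come from content and style differences rather than raw intensity.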

I hope this information is helpful in resolving the issues. I look forward to your further guidance and advice!
