The right way of inference #17

Open
AdelZakirovRZ opened this issue Jun 4, 2022 · 25 comments
Labels
question Further information is requested

Comments

@AdelZakirovRZ commented Jun 4, 2022

Hey guys,
can you please direct me on how to properly run inference with the trained model?
I wrote a small script for it, but I am not sure that I am doing everything right.
One of the conceptual questions for me is the robustness of the results. I trained the colorization model, and during inference it gives me different results for the same image. Is that because of the randomness in the noise scheduling? Or is it something else?
Thanks in advance!

@Janspiry (Owner) commented Jun 5, 2022

Yes, it will output different results due to the Gaussian random noise.
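If you want repeatable samples, one option (just a rough sketch, not something the repo does for you) is to fix the random seeds before running inference:

import random

import numpy as np
import torch

# Assumption: seeding all RNGs before sampling makes the Gaussian noise,
# and therefore the restored image, reproducible between runs
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)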

@AdelZakirovRZ (Author)

Thanks!
My model now gives unstable results: the colorization is either very good or quite bad. What do you think could be the cause? A small dataset?

@Janspiry (Owner) commented Jun 7, 2022

I have encountered the same problem. In my experience, colorization of complex scenes requires a lot of data, unlike faces, which can be trained well with a small amount of data. But I haven't had time to verify this yet.

@bronyayang

Can you possibly share the inference script? Thank you!

@dov84d commented Jun 12, 2022

Hi @AdelZakirovRZ
Can you please share your inference script?

@AdelZakirovRZ (Author)

Hey guys, so about the inference example.
First you need your model config (I call it model_args below); it should be in the config file you used. Then do the following and it should work. @Janspiry, please correct me if I am doing something wrong here.

import torch
from torchvision import transforms
from PIL import Image

from models.network import Network

# build the network from the "args" section of the model config and load the checkpoint
model = Network(**model_args)
state_dict = torch.load(path_ckpt)
model.load_state_dict(state_dict, strict=False)
device = torch.device('cuda:0')
model.to(device)
model.set_new_noise_schedule(phase='test')
model.eval()

# same preprocessing as during training: resize to 256x256 and normalize to [-1, 1]
tfs = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# for colorization the condition is the grayscale version of the image
img = Image.open(img_path).convert('RGB')
img_gr = img.convert('L').convert('RGB')

x = tfs(img_gr)
x = x.unsqueeze(0).to(device)
with torch.no_grad():
    p = model.restoration(y_cond=x)

@richard-schwab

I tried the above code with some changes. I was able to get it to output an image, but the image is black. Any suggestions on where my error is?

import torch
from torchvision import transforms
from models.network import Network
from PIL import Image
import numpy as np

import argparse
import core.praser as Praser


def parse_config():
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, default='config/inpainting_places2.json', help='JSON file for configuration')
    parser.add_argument('-p', '--phase', type=str, choices=['train','test'], help='Run train or test', default='test')
    parser.add_argument('-b', '--batch', type=int, default=16, help='Batch size in every gpu')
    parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
    parser.add_argument('-d', '--debug', action='store_true')
    parser.add_argument('-P', '--port', default='21012', type=str)

    args = parser.parse_args()
    opt = Praser.parse(args)
    print("Loaded config: ")
    print("========================")
    print(opt)
    print("========================\n")
    return opt


def load_palette(ckpt_path="16_Network.pth", model_args=None):

    model = Network(**model_args)
    # model = Network(unet=)
    path_ckpt = "16_Network.pth"

    state_dict = torch.load(path_ckpt)
    model.load_state_dict(state_dict, strict=False)
    device = torch.device('cuda:0')
    model.to(device)
    model.set_new_noise_schedule(phase='test')
    model.eval()
    print("Loaded checkpoint: ", ckpt_path)
    return model, device


def predict_img(img_path, model, device):

    tfs = transforms.Compose([
                    transforms.Resize((256,256)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5,0.5, 0.5])
            ])

    print("Predicting image: ", img_path)
    img = Image.open(img_path).convert('RGB')
    img_gr = img.convert('L').convert('RGB')

    x = tfs(img_gr)
    x = x.unsqueeze(0).to(device)
    with torch.no_grad():
        p = model.restoration(y_cond=x,)

    # print(p)

    predict_img = Image.fromarray(p[0][0][0].cpu().numpy())
    predict_img = predict_img.convert("RGB")
    predict_img.save("result.jpg")

    return predict_img




if __name__ == "__main__":
    opt = parse_config()
    model, device = load_palette("16_Network.pth", opt["model"]["which_networks"][0]["args"])
    img_path = "./misc/image/Mask_Places365_test_00143399.jpg"
    predict_img(img_path, model, device)

@richard-schwab commented Jul 11, 2022

Oh, also: I did normalize the output by changing line 64 to:

predict_img = Image.fromarray(p[0][0][0].cpu().numpy()*255)

But it is still only spitting out very strange images...

@AdelZakirovRZ (Author)

Hey, the output is in [-1, 1]; I think you missed that part.

@richard-schwab

So then should it be this? I'm still not getting anything that looks right:

Image.fromarray((p[0][0][0].cpu().numpy()/2+1)*255)

@richard-schwab

OK, going through your code again: I had to hunt a bit, but I found your tensor2img function in core/util. I ran my prediction through it and it generated this:

https://imgur.com/a/Icz4HFI

Is this correct? The original image was this: ./misc/image/Mask_Places365_test_00144085.jpg

@AdelZakirovRZ (Author)

hey @richard-schwab

p[0][0][0].cpu()

That part does not look right. p is a tuple of two things: the pure prediction and the noisy predictions at multiple steps. Your prediction p[0] has dimensions BxCxHxW. Since you are using only one image, B=1, so your full image prediction is p[0][0]. You don't need the additional [0].
It also seems like your model is simply not trained well enough to output something meaningful.
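A rough sketch of the post-processing, assuming (as above) that the prediction is BxCxHxW in [-1, 1]; alternatively, tensor2img from core.util can do the conversion for you:

import numpy as np
from PIL import Image

pred, visuals = p                 # restoration() returns the prediction and the intermediate steps
img_t = pred[0].detach().cpu()    # first (and only) image in the batch: CxHxW in [-1, 1]
arr = ((img_t + 1.0) / 2.0).clamp(0, 1).numpy()                    # map [-1, 1] -> [0, 1]
arr = (arr.transpose(1, 2, 0) * 255.0).round().astype(np.uint8)    # CxHxW -> HxWxC uint8
Image.fromarray(arr).save("result.png")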

Also, FYI - I am not the author of the code

@Janspiry (Owner)

Hey guys, so about the inference example. First you need your model config (I call it model_args below). [...]

Thanks for sharing. I think this test code is correct. Did you get better experimental results later?

@Janspiry (Owner)

So then should it be this? I'm still not getting anything that looks right:

Hey, the code from @AdelZakirovRZ is for colorization, where y_cond is the grayscale image. In the inpainting task, y_cond is the masked image containing noise, like
image

You need to manually replace the blank masked regions with random noise.
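Roughly like this (just a sketch; it assumes a mask with value 1 inside the holes, for example from get_irregular_mask in data.util.mask):

import torch
from data.util.mask import get_irregular_mask

# img: the input image as a CxHxW tensor normalized to [-1, 1]
mask = torch.from_numpy(get_irregular_mask([256, 256])).permute(2, 0, 1)   # HxWx1 -> 1xHxW, 1 inside the holes
cond_image = img * (1. - mask) + mask * torch.randn_like(img)              # keep known pixels, fill holes with noise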

@AdelZakirovRZ (Author)

Thanks for sharing. I think this test code is correct. Did you get better experimental results later?

Hey @Janspiry
I've managed to get better results, but I simplified the task. I used the CelebA dataset to train a model for face colorization. Also, I changed the code to pytorch-lightning, which is a bit more convenient for me and gave me some room for additional training features.
These are inference results for images from the FFHQ dataset (I am not sure whether it overlaps with CelebA, but I don't think so). The first image is the original; the others are sampled from the model.
On the last two you can see how the hands got the wrong colorization, which is understandable - there are not many hands in the training set.

test

@Janspiry (Owner)

Hey, thanks for your contributions to this good question. I will keep it open for others to add more details.

Janspiry added the question (Further information is requested) label on Aug 25, 2022
@ksunho9508

@AdelZakirovRZ Hello, could you share your code for colorization using pytorch-lightning? I am also using PL, but it is difficult to translate this repo to PL, and I am also having problems reproducing the colorization results.
If you can share your repo, you will solve TWO of my problems simultaneously.

@Nyoko74 commented Sep 21, 2022

Hey @Janspiry I've managed to get better results, but I simplified the task. I used the CelebA dataset to train a model for face colorization. [...]

Hi, may I ask if this has happened to you? I used CelebA to train the model, but I just get some strange results. It seems to color the whole grayscale image with the same color. The same thing happens in the sketch coloring task.
20220921095400

@Janspiry (Owner)

@Nyoko74, I did not; maybe you need to adjust the learning rate and train for more iterations.

@Nyoko74 commented Sep 23, 2022

@Nyoko74, I did not; maybe you need to adjust the learning rate and train for more iterations.

Thank you, I will try it.

@ruairiseosamh

Hi, may I ask if this has happened to you? I used CelebA to train the model, but I just get some strange results. It seems to color the whole grayscale image with the same color. [...]

Hi @Nyoko74, I am having this same issue. Did you manage to fix it?

@tg-bomze commented Oct 4, 2022

Hey @Janspiry I've managed to get better results, but I simplified the task. I used the CelebA dataset to train a model for face colorization. [...]

Hey @AdelZakirovRZ
The results are impressive. Could you share the colorization code using pytorch-lightning? How long did you train the model?

@nappingman commented Oct 17, 2022

May I ask if this has happened to you? I used CelebA to train the model, but I just get some strange results. [...]

Hey @Nyoko74
I've tried several tasks (image denoising, deblurring, colorization) on small datasets (about 30k-60k images), and after enough training, ALL the results I get are exactly like yours: the whole output image is covered by some strange color.

Did you solve this problem?

And actually, in the results shared by @AdelZakirovRZ, I noticed that in the bottom-right image the man's hand is blue. I wonder whether this could be solved by increasing the training time? If not, how can it be solved?

@vinodrajendran001

In my inpainting case, during inference only the y_cond and mask images are given. In that case, may I know how to do inference?

In the network.py script, for the inpainting task the line below is executed as part of the restoration function. As y_0 is None for me, I am not sure how to deal with this line. If I skip it, the results are very bad (just some whitish kind of image is generated). Also, in the Process.png image I can see that at each step the noise level is increasing rather than decreasing.

if mask is not None:
    y_t = y_0*(1.-mask) + mask*y_t
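One workaround I am considering (just a guess on my side: since only the unmasked pixels of y_0 enter this line, and y_cond holds the true pixels there) is to pass the conditional image in place of y_0:

# Assumption: y_cond already holds the true pixels outside the mask and noise inside it,
# so its unmasked region can stand in for y_0 in the line above
with torch.no_grad():
    output, visuals = model.restoration(y_cond, y_t=y_cond, y_0=y_cond, mask=mask)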

@zacsun commented Dec 9, 2022

Hi everyone, for the inpainting task I took a deeper look into run.py from @Janspiry (thanks for the amazing work!) and extracted a simpler version with only the code necessary to get inference working. Hope this helps:

"""
1. Download model and save the model to git_root/model/celebahq/200_Network.pth
2. Modify inpainting_celebahq.json
    ["path"]["resume_state"]: "model/celebahq/200"
    ["datasets"]["test"]["args"]["data_root"]: "<Folder Constains Inference Images>"

    (optinally) change ["model"]["which_networks"]["args"]["beta_schedule"]["test"]["n_timestep"] value to reduce # steps inference should take
                more steps yields better results
3. Modify in your particular case in this code:
    model_pth = "<PATH-TO-MODEL>/200_Network.pth"
    input_image_pth = "<PATH-TO-DATASET_PARENT_DIT>/02323.jpg"
5. Run inpainting code (assume save this code to git_root/inference/inpainting.py)
    cd inference
    python inpainting.py -c ../config/inpainting_celebahq.json -p test
"""

import argparse

import core.praser as Praser
import torch
from core.util import set_device, tensor2img
from data.util.mask import get_irregular_mask
from models.network import Network
from PIL import Image
from torchvision import transforms

model_pth = "<PATH-TO-MODEL>/200_Network.pth"
input_image_pth = "<PATH-TO-DATASET_PARENT_DIT>/02323.jpg"


def parse_config():
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str,
                        default='../config/inpainting_places2.json', help='JSON file for configuration')
    parser.add_argument('-p', '--phase', type=str,
                        choices=['train', 'test'], help='Run train or test', default='test')
    parser.add_argument('-b', '--batch', type=int,
                        default=16, help='Batch size in every gpu')
    parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
    parser.add_argument('-d', '--debug', action='store_true')
    parser.add_argument('-P', '--port', default='21012', type=str)

    args = parser.parse_args()
    opt = Praser.parse(args)
    return opt


# config arg
opt = parse_config()
model_args = opt["model"]["which_networks"][0]["args"]

# initialize model
model = Network(**model_args)
state_dict = torch.load(model_pth)
model.load_state_dict(state_dict, strict=False)
device = torch.device('cuda:0')
model.to(device)
model.set_new_noise_schedule(phase='test')
model.eval()

tfs = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# read input and create random mask
img_pillow = Image.open(input_image_pth).convert('RGB')
img = tfs(img_pillow)
mask = get_irregular_mask([256, 256])
mask = torch.from_numpy(mask).permute(2, 0, 1)
cond_image = img*(1. - mask) + mask*torch.randn_like(img)
mask_img = img*(1. - mask) + mask

# save the conditional image used as inference input
cond_image_np = tensor2img(cond_image)
Image.fromarray(cond_image_np).save("./result/cond_image.jpg")

# set device
cond_image = set_device(cond_image)
gt_image = set_device(img)
mask = set_device(mask)

# unsqueeze
cond_image = cond_image.unsqueeze(0).to(device)
gt_image = gt_image.unsqueeze(0).to(device)
mask = mask.unsqueeze(0).to(device)

# inference
with torch.no_grad():
    output, visuals = model.restoration(cond_image, y_t=cond_image,
                                        y_0=gt_image, mask=mask, sample_num=8)

# save intermediate processes
output_img = output.detach().float().cpu()
for i in range(visuals.shape[0]):
    img = tensor2img(visuals[i].detach().float().cpu())
    Image.fromarray(img).save(f"./result/process_{i}.jpg")

# save output (output should be the same as last process_{i}.jpg)
img = tensor2img(output_img)
Image.fromarray(img).save("./result/output.jpg")

Input:
02323
Masked inference input:
cond_image

Output (last 2 results in the process):
process_7

process_8
