# 必要パッケージインストール

In [1]:
!pip install kornia

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting kornia
  Downloading kornia-0.6.9-py2.py3-none-any.whl (569 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m569.1/569.1 KB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: kornia
Successfully installed kornia-0.6.9


# インポート

In [2]:
import urllib.request as request

import kornia
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm

# Load some scripts from remote.
exec(request.urlopen('https://github.com/mingcv/Bread_Colab/raw/main/colab_utils.py').read(), globals())
exec(request.urlopen(locate_resource('networks.py')).read(), globals())

# モデル重みダウンロード

In [3]:
# Fetch every pretrained checkpoint used by this notebook.
# NOTE(review): ModelBreadNet below loads IANet/NSNet/FuseNet_CA_MEF from './';
# FuseNet_FD_297.pth is downloaded but not loaded anywhere in this notebook.
CHECKPOINT_RESOURCES = (
    'checkpoints/IANet_335.pth',
    'checkpoints/NSNet_422.pth',
    'checkpoints/FuseNet_CA_MEF_251.pth',
    'checkpoints/FuseNet_FD_297.pth',
)
for resource in CHECKPOINT_RESOURCES:
    download_url_to_file(locate_resource(resource))

100%|██████████| 3.25M/3.25M [00:00<00:00, 51.9MB/s]
100%|██████████| 3.25M/3.25M [00:00<00:00, 72.2MB/s]
100%|██████████| 874k/874k [00:00<00:00, 28.0MB/s]
100%|██████████| 872k/872k [00:00<00:00, 21.1MB/s]


# モデル定義、重み読み込み

In [4]:
# Definition of the Bread framework: illumination prediction (IAN), noise
# suppression (ANSN) and color adaptation (FuseNet), operating in YCbCr space.
class ModelBreadNet(nn.Module):
    """Bread low-light enhancement pipeline built from three pretrained sub-networks.

    IAN, ANSN and FuseNet are not defined in this notebook; they are expected to
    come from the remotely exec'd networks.py. Constructing the model loads the
    three checkpoints from the current working directory.
    """

    def __init__(self):
        super().__init__()
        # Lower bound for the illumination map, avoids division by zero below.
        self.eps = 1e-6
        self.model_ianet = IAN(in_channels=1, out_channels=1)
        self.model_nsnet = ANSN(in_channels=2, out_channels=1)
        self.model_canet = FuseNet(in_channels=4, out_channels=2)

        self.load_weight(self.model_ianet, './IANet_335.pth')
        self.load_weight(self.model_nsnet, './NSNet_422.pth')
        self.load_weight(self.model_canet, './FuseNet_CA_MEF_251.pth')

    def load_weight(self, model, weight_pth):
        """Load the state dict stored at `weight_pth` into `model`; no-op if `model` is None.

        The checkpoint is mapped to CPU first so it loads even if it was saved from a
        GPU that is unavailable here; the caller moves the whole model afterwards.
        Loading is strict (missing/unexpected keys raise) and the result is printed.
        """
        if model is not None:
            state_dict = torch.load(weight_pth, map_location='cpu')
            ret = model.load_state_dict(state_dict, strict=True)
            print(ret)

    def noise_syn_exp(self, illumi, strength):
        """Noise attention map `strength * exp(-illumi)`: larger where illumination is lower."""
        return torch.exp(-illumi) * strength

    def forward(self, image, gamma=1., strength=0.1):
        """Enhance a low-light RGB image.

        Args:
            image: RGB batch; presumably (N, 3, H, W) with values in [0, 1]
                (assumption based on kornia.color.rgb_to_ycbcr — confirm).
            gamma: exponent applied to the predicted illumination map.
            strength: scale of the noise attention map fed to the denoiser.

        Returns:
            The enhanced RGB image, clamped to [0, 1].
        """
        # Color space mapping: Y is treated as texture, Cb/Cr as color.
        texture_in, cb_in, cr_in = torch.split(kornia.color.rgb_to_ycbcr(image), 1, dim=1)

        # Illumination prediction at half resolution, then resized back to the exact
        # input size (a fixed 2x upscale would be off by one for odd H or W).
        texture_in_down = F.interpolate(texture_in, scale_factor=0.5, mode='bicubic', align_corners=True)
        texture_illumi = self.model_ianet(texture_in_down)
        texture_illumi = F.interpolate(texture_illumi, size=texture_in.shape[-2:], mode='bicubic', align_corners=True)

        # Illumination adjustment: divide the luminance by the gamma-adjusted illumination.
        texture_illumi = torch.clamp(texture_illumi ** gamma, 0., 1.)
        texture_ia = texture_in / torch.clamp_min(texture_illumi, self.eps)
        texture_ia = torch.clamp(texture_ia, 0., 1.)

        # Noise suppression: the denoiser predicts a residual guided by the attention map.
        attention = self.noise_syn_exp(texture_illumi, strength)
        texture_res = self.model_nsnet(torch.cat([texture_ia, attention], dim=1))
        texture_ns = texture_ia + texture_res

        # Further preserve the texture under brighter illumination (blend toward the input).
        texture_ns = texture_illumi * texture_in + (1 - texture_illumi) * texture_ns
        texture_ns = torch.clamp(texture_ns, 0, 1)

        # Color adaptation: predict new Cb/Cr from the input YCbCr and the enhanced Y.
        colors = self.model_canet(
            torch.cat([texture_in, cb_in, cr_in, texture_ns], dim=1))
        cb_out, cr_out = torch.split(colors, 1, dim=1)
        cb_out = torch.clamp(cb_out, 0, 1)
        cr_out = torch.clamp(cr_out, 0, 1)

        # Color space mapping back to RGB.
        image_out = kornia.color.ycbcr_to_rgb(
            torch.cat([texture_ns, cb_out, cr_out], dim=1))

        # Further preserve the color under brighter illumination: blend in RGB, then
        # keep only the blended chroma and recombine it with the enhanced luminance.
        img_fusion = texture_illumi * image + (1 - texture_illumi) * image_out
        _, cb_fuse, cr_fuse = torch.split(kornia.color.rgb_to_ycbcr(img_fusion), 1, dim=1)
        image_out = kornia.color.ycbcr_to_rgb(
            torch.cat([texture_ns, cb_fuse, cr_fuse], dim=1))
        image_out = torch.clamp(image_out, 0, 1)

        # outputs: texture_ia, texture_ns, image_out, texture_illumi, texture_res
        return image_out

# Build the framework (this loads all three checkpoints), switch it to inference
# mode and move it to the GPU.
model = ModelBreadNet()
model = model.eval().cuda()

<All keys matched successfully>
<All keys matched successfully>
<All keys matched successfully>


# ONNX変換

In [5]:
def convert_to_onnx(net, file_name='output.onnx', input_shape=(512, 512), device='cpu',
                    gamma=1.0, strength=0.1, opset_version=13, verbose=True):
    """Export `net` to ONNX with a fixed input resolution.

    Args:
        net: model to export; it is traced as net(input_image, gamma, strength).
        file_name: path of the .onnx file to write.
        input_shape: (width, height) of the dummy image. Note the order: the traced
            image tensor has shape (1, 3, height, width).
        device: device the dummy inputs are placed on; must match `net`'s device.
        gamma: example value for the `gamma` graph input.
        strength: example value for the `strength` graph input.
        opset_version: ONNX opset to target.
        verbose: forwarded to torch.onnx.export (prints the exported graph).

    The graph inputs are named 'input_image', 'gamma' and 'strength', and the output
    'output_image'. gamma and strength are passed as 0-d float32 tensors so they become
    real graph inputs: torch.onnx.export hard-codes non-tensor arguments as constants,
    in which case the three input names would not match the graph's inputs.
    """
    width, height = input_shape
    input_image = torch.randn(1, 3, height, width).to(device)
    gamma_tensor = torch.tensor(gamma, dtype=torch.float32, device=device)
    strength_tensor = torch.tensor(strength, dtype=torch.float32, device=device)

    input_layer_names = ['input_image', 'gamma', 'strength']
    output_layer_names = ['output_image']

    torch.onnx.export(
        net,
        (input_image, gamma_tensor, strength_tensor),
        file_name,
        verbose=verbose,
        opset_version=opset_version,
        input_names=input_layer_names,
        output_names=output_layer_names,
    )

In [None]:
# One fixed-resolution ONNX file per target size; shapes are (width, height).
EXPORT_TARGETS = [
    ('Bread_320x240.onnx', (320, 240)),
    ('Bread_640x360.onnx', (640, 360)),
    ('Bread_416x416.onnx', (416, 416)),
    ('Bread_512x512.onnx', (512, 512)),
]
for onnx_path, shape_wh in EXPORT_TARGETS:
    convert_to_onnx(model, file_name=onnx_path, input_shape=shape_wh, device='cuda:0')