In [7]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
nst_vangogh_local_progress.py
Neural Style Transfer — 10ステップごとに進行中画像を保存
"""

import os
from PIL import Image, ImageOps
Image.MAX_IMAGE_PIXELS = None  # 巨大画像対応

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms


"""
画風を強くしたい: STYLE_WEIGHT を上げる（例: 1e5 → 3e5）
内容を残したい: CONTENT_WEIGHT を上げる
解像度を上げたい: IMAGE_SIZE を上げる（GPUメモリと実行時間に注意）
収束が荒い場合: NUM_STEPS を増やす、USE_LBFGS=True を推奨（高品質）
"""

# ====== 設定 ======
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = 1024 if torch.cuda.is_available() else 256
STYLE_PATH   = "Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg"
CONTENT_PATH = "Central_Park_in_Shinjuku_Ward_Tokyo_20250824104812_01.png"
OUTPUT_PATH  = "output.jpg"

CONTENT_WEIGHT = 4.0
STYLE_WEIGHT   = 1e5
NUM_STEPS      = 300
SAVE_INTERVAL  = 1        # ← ここで保存間隔を指定（10ステップごと）
USE_LBFGS      = True
LR_ADAM        = 1e-1

# ====== 基本ユーティリティ ======
def require_file(path, label):
    if not os.path.exists(path):
        raise FileNotFoundError(f"[{label}] {path} が見つかりません。")

def safe_load_and_resize(path, imsize=IMAGE_SIZE):
    with Image.open(path) as im:
        im = ImageOps.exif_transpose(im).convert("RGB")
        im.thumbnail((imsize, imsize), Image.LANCZOS)
        to_tensor = transforms.ToTensor()
        t = to_tensor(im).unsqueeze(0).to(DEVICE, torch.float)
    return t

cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(DEVICE)
cnn_normalization_std  = torch.tensor([0.229, 0.224, 0.225]).to(DEVICE)

def save_progress_image(tensor, step):
    """10ステップごとに進行画像を保存"""
    x = tensor.detach().cpu().squeeze(0)
    x = x * cnn_normalization_std.view(3,1,1) + cnn_normalization_mean.view(3,1,1)
    x = torch.clamp(x, 0, 1)
    img = transforms.ToPILImage()(x)
    fname = f"output_step{step}.jpg"
    img.save(fname)
    print(f"[SAVE] {fname}")

def unnormalize_and_save(tensor, path):
    x = tensor.detach().cpu().squeeze(0)
    x = x * cnn_normalization_std.view(3,1,1) + cnn_normalization_mean.view(3,1,1)
    x = torch.clamp(x, 0, 1)
    transforms.ToPILImage()(x).save(path)
    print(f"[SAVE] {path}")

def gram_matrix(x):
    b, c, h, w = x.size()
    F = x.view(b, c, h*w)
    return torch.bmm(F, F.transpose(1,2)) / (c*h*w)

# ====== モデル定義 ======
content_layers_default = ['conv_4']
style_layers_default   = ['conv_1','conv_2','conv_3','conv_4','conv_5']

class Normalization(nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.mean = mean.clone().detach().view(-1,1,1)
        self.std  = std.clone().detach().view(-1,1,1)
    def forward(self, img):
        return (img - self.mean) / self.std

class StyleTransferModel(nn.Module):
    def __init__(self, cnn, mean, std,
                 content_layers=content_layers_default,
                 style_layers=style_layers_default):
        super().__init__()
        self.content_layers = content_layers
        self.style_layers   = style_layers
        self.model = nn.Sequential(Normalization(mean, std))
        i = 0
        for layer in cnn.children():
            if isinstance(layer, nn.Conv2d):
                i += 1; name = f'conv_{i}'
            elif isinstance(layer, nn.ReLU):
                name = f'relu_{i}'; layer = nn.ReLU(inplace=False)
            elif isinstance(layer, nn.MaxPool2d):
                name = f'pool_{i}'
            elif isinstance(layer, nn.BatchNorm2d):
                name = f'bn_{i}'
            else:
                name = f'layer_{i}'
            self.model.add_module(name, layer)
            if i >= 5 and isinstance(layer, nn.ReLU):
                break
    def forward(self, x):
        c_feats, s_feats = {}, {}
        for name, layer in self.model._modules.items():
            x = layer(x)
            if name in self.content_layers: c_feats[name] = x
            if name in self.style_layers:   s_feats[name] = x
        return c_feats, s_feats

# ====== メイン最適化 ======
def run_style_transfer(cnn, mean, std, content_img, style_img, input_img,
                       steps=NUM_STEPS, cw=CONTENT_WEIGHT, sw=STYLE_WEIGHT):
    model = StyleTransferModel(cnn, mean, std).to(DEVICE).eval()
    mse = nn.MSELoss()
    with torch.no_grad():
        c_ref, _ = model(content_img)
        _, s_ref = model(style_img)
        s_grams = {l: gram_matrix(s_ref[l]) for l in s_ref}
    x = input_img.clone().requires_grad_(True)

    if USE_LBFGS:
        opt = optim.LBFGS([x])
        run = [0]
        pbar = tqdm(total=steps, desc="optim(L-BFGS)")
        def closure():
            opt.zero_grad()
            c_out, s_out = model(x)
            c_loss = sum(mse(c_out[l], c_ref[l]) for l in c_out)
            s_loss = sum(mse(gram_matrix(s_out[l]), s_grams[l]) for l in s_out)
            loss = cw*c_loss + sw*s_loss
            loss.backward()
            run[0] += 1
            if run[0] % SAVE_INTERVAL == 0:
                save_progress_image(x, run[0])
            if pbar.n < steps:
                pbar.update(min(SAVE_INTERVAL, steps - pbar.n))
            return loss
        opt.step(closure)
        pbar.close()
    else:
        opt = optim.Adam([x], lr=LR_ADAM)
        pbar = tqdm(range(steps), desc="optim(Adam)")
        for i in pbar:
            opt.zero_grad()
            c_out, s_out = model(x)
            c_loss = sum(mse(c_out[l], c_ref[l]) for l in c_out)
            s_loss = sum(mse(gram_matrix(s_out[l]), s_grams[l]) for l in s_out)
            loss = cw*c_loss + sw*s_loss
            loss.backward()
            opt.step()
            if (i+1) % SAVE_INTERVAL == 0:
                save_progress_image(x, i+1)
    return x.detach()

def main():
    require_file(STYLE_PATH,   "style")
    require_file(CONTENT_PATH, "content")

    content_img = safe_load_and_resize(CONTENT_PATH, IMAGE_SIZE)
    style_img   = safe_load_and_resize(STYLE_PATH,   IMAGE_SIZE)

    print("[INFO] VGG19 features をロード")
    vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features.to(DEVICE).eval()

    input_img = content_img.clone()
    output = run_style_transfer(vgg, cnn_normalization_mean, cnn_normalization_std,
                                content_img, style_img, input_img,
                                steps=NUM_STEPS, cw=CONTENT_WEIGHT, sw=STYLE_WEIGHT)

    unnormalize_and_save(output, OUTPUT_PATH)
    print("[DONE] 出力:", OUTPUT_PATH)

if __name__ == "__main__":
    main()


[INFO] VGG19 features をロード




[SAVE] output_step1.jpg




[SAVE] output_step2.jpg




[SAVE] output_step3.jpg




[SAVE] output_step4.jpg




[SAVE] output_step5.jpg




[SAVE] output_step6.jpg




[SAVE] output_step7.jpg




[SAVE] output_step8.jpg




[SAVE] output_step9.jpg




[SAVE] output_step10.jpg




[SAVE] output_step11.jpg




[SAVE] output_step12.jpg




[SAVE] output_step13.jpg




[SAVE] output_step14.jpg




[SAVE] output_step15.jpg




[SAVE] output_step16.jpg




[SAVE] output_step17.jpg




[SAVE] output_step18.jpg




[SAVE] output_step19.jpg


optim(L-BFGS):   7%|▋         | 20/300 [00:31<07:26,  1.60s/it]

[SAVE] output_step20.jpg
[SAVE] output.jpg
[DONE] 出力: output.jpg



