In [None]:
!pip install torch torchvision diffusers transformers lpips accelerate
!pip install git+https://github.com/huggingface/diffusers.git


In [None]:
import torch

# Compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)


In [None]:
from diffusers import AutoencoderKL

# Pretrained Stable Diffusion autoencoder (`sd-vae-ft-mse` checkpoint), moved to
# `device` and switched to inference mode. The name `vq_model` is historical:
# AutoencoderKL is a KL-regularised VAE, not a vector-quantised model.
vq_model = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device).eval()

def encode_content(img_tensor, model=None, rescale=True, sample=True):
    """
    Encode a clean image batch into the latent space of the pretrained AutoencoderKL.

    Args:
        img_tensor: image batch, assumed (N, 3, H, W). When ``rescale`` is True it
            is expected in [0, 1] (as produced by ``transforms.ToTensor()``) and is
            mapped to [-1, 1], the range the Stable Diffusion VAE operates on. Pass
            ``rescale=False`` if the tensor is already in [-1, 1].
        model: autoencoder exposing ``encode(x).latent_dist``; defaults to the
            module-level ``vq_model``.
        rescale: whether to apply the [0, 1] -> [-1, 1] mapping described above.
        sample: if True (default), draw a sample from the latent posterior; if
            False, return the posterior mode for a deterministic encoding.

    Returns:
        The latent tensor from the encoder posterior. It is NOT multiplied by the
        VAE's ``scaling_factor``.
    """
    if model is None:
        model = vq_model
    if rescale:
        img_tensor = img_tensor * 2.0 - 1.0
    with torch.no_grad():
        posterior = model.encode(img_tensor).latent_dist
        latent = posterior.sample() if sample else posterior.mode()
    return latent


In [None]:
import torchvision.models as models
import torch.nn as nn

class SPN(nn.Module):
    """
    Frozen ResNet-18 trunk used as a structure / perceptual feature extractor.

    ``forward`` returns the output of the last residual stage (avgpool and fc are
    dropped), so the result keeps its spatial layout.
    """

    # ImageNet channel statistics that torchvision's pretrained ResNet weights expect.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, normalize_input=True):
        """
        Args:
            normalize_input: when True (default), ``forward`` treats its input as
                RGB in [0, 1] and applies ImageNet mean/std normalization before
                the backbone. Pass False if inputs are already normalized.
        """
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=True` flag (same ImageNet weights).
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.features = nn.Sequential(*list(resnet.children())[:-2])  # keep spatial features
        self.normalize_input = normalize_input
        # Non-persistent buffers follow `.to(device)` but stay out of the state_dict.
        self.register_buffer(
            "mean", torch.tensor(self.IMAGENET_MEAN).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor(self.IMAGENET_STD).view(1, 3, 1, 1), persistent=False
        )
        # Frozen extractor: no parameter gradients, and BatchNorm uses its stored
        # running statistics instead of per-batch statistics (eval mode).
        for param in self.features.parameters():
            param.requires_grad_(False)
        self.eval()

    def forward(self, x):
        """Return the spatial feature map for an image batch ``x`` (assumed (N, 3, H, W))."""
        if self.normalize_input:
            x = (x - self.mean) / self.std
        return self.features(x)


spn = SPN().to(device)


In [None]:
# Learnable style latent, randomly initialised (not derived from any image).
# Shape (1, 4, 64, 64) — presumably matching the 4-channel VAE latent of a
# 512x512 image; confirm if the image size changes. The optimiser over it is
# created in the training-loop cell.
z_style = torch.randn((1, 4, 64, 64), device=device, requires_grad=True)


In [None]:
from diffusers import StableDiffusionImg2ImgPipeline

# Half precision on CUDA; fall back to fp32 on CPU, where fp16 inference is slow
# and not every op supports it.
pipe_dtype = torch.float16 if device == "cuda" else torch.float32

# NOTE(review): if the `runwayml/...` repo id cannot be resolved, the
# `stable-diffusion-v1-5/stable-diffusion-v1-5` mirror hosts the same weights — verify.
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=pipe_dtype
).to(device)

# Bypass the NSFW safety checker: return the images unchanged with one
# "not NSFW" flag per image (a fixed `[False]` only matches a batch of one).
pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))

In [None]:
from transformers import CLIPProcessor, CLIPModel
# CLIP ViT-B/32: the processor turns PIL images into model inputs and the model
# (on `device`) produces image embeddings for the directional loss.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip_processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
clip_model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT).to(device)

def clip_directional_loss(img_out, img_style, img_clean):
    """
    CLIP directional loss between an output image and a style reference.

    Measures how well the CLIP-space direction clean -> output aligns with the
    direction clean -> style:
    ``1 - cos(E(out) - E(clean), E(style) - E(clean))``, where ``E`` is
    ``clip_model.get_image_features``.

    Args:
        img_out, img_style, img_clean: image tensors, assumed (1, 3, H, W) with
            values in [0, 1]; values outside that range are clamped.

    Returns:
        Scalar tensor. For single images it lies in [0, 2]: 0 when the two
        directions coincide, 2 when they are opposite.

    Note:
        Images are converted to PIL before ``clip_processor``, so the result is
        not differentiable with respect to the input tensors.
    """
    # Local import: the function must not rely on a later cell importing transforms.
    from torchvision import transforms

    to_pil = transforms.ToPILImage()

    def _embed(img):
        # Clamp strictly to [0, 1], convert to an RGB PIL image, then embed with CLIP.
        pil_img = to_pil(torch.clamp(img, 0.0, 1.0).squeeze(0).cpu()).convert("RGB")
        inputs = clip_processor(images=pil_img, return_tensors="pt").to(device)
        return clip_model.get_image_features(**inputs)

    emb_out = _embed(img_out)
    emb_style = _embed(img_style)
    emb_clean = _embed(img_clean)

    # F.normalize divides by max(norm, eps): no NaN if a direction is all zeros
    # (e.g. an image whose embedding equals the clean one).
    dir_out = torch.nn.functional.normalize(emb_out - emb_clean, dim=-1)
    dir_style = torch.nn.functional.normalize(emb_style - emb_clean, dim=-1)

    return 1 - (dir_out * dir_style).sum()


In [None]:
from PIL import Image
from torchvision import transforms
import torch

# Device setup: same selection rule as the first device cell
# (GPU when CUDA is available, otherwise CPU).
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Using device:", device)

def load_pil_image(path, size=512):
    """Open an image file as an RGB ``PIL.Image`` resized to ``size`` x ``size``.

    Used as the Stable Diffusion input. The resize is to a square, so the
    aspect ratio is not preserved.
    """
    rgb_image = Image.open(path).convert("RGB")  # force 3-channel RGB mode
    return rgb_image.resize((size, size))

def tensor_to_pil(tensor):
    """Convert a ``[1, 3, H, W]`` image tensor to an RGB ``PIL.Image``.

    The tensor is detached, moved to the CPU, stripped of its leading batch
    dimension and clamped to [0, 1] before conversion.
    """
    image_chw = tensor.detach().cpu().squeeze(0).clamp(0, 1)
    return transforms.ToPILImage()(image_chw).convert("RGB")

def load_tensor_image(path, size=512):
    """Load an image file as a ``[1, 3, size, size]`` float tensor in [0, 1] on ``device``.

    Used as the SPN / CLIP input; the pixels come from ``load_pil_image``.
    """
    chw = transforms.ToTensor()(load_pil_image(path, size))
    return chw.unsqueeze(0).to(device)


In [None]:
# Input image paths, defined once and reused for both representations below.
CLEAN_IMAGE_PATH = "/content/clean.jpg"
STYLE_IMAGE_PATH = "/content/hidemotionblur.jpg"

# PIL images for Stable Diffusion
pil_clean = load_pil_image(CLEAN_IMAGE_PATH)
pil_style = load_pil_image(STYLE_IMAGE_PATH)

# Tensors for SPN / CLIP: [1, 3, 512, 512], values in [0, 1], on `device`
clean_tensor = load_tensor_image(CLEAN_IMAGE_PATH)
style_tensor = load_tensor_image(STYLE_IMAGE_PATH)

# z_style is created directly on `device`, so this is normally a no-op. If it
# ever had to move, `.to` would return a non-leaf copy that torch.optim
# refuses to optimise, so fail early with a clear message instead.
z_style = z_style.to(device)
assert z_style.is_leaf, "z_style must remain a leaf tensor to be optimised"


In [None]:
# Put every model on `device`. nn.Module.to returns the module itself, so the
# reassignment keeps the same objects.
vq_model, spn, clip_model = [m.to(device) for m in (vq_model, spn, clip_model)]
pipe = pipe.to(device)


In [None]:
import torch.nn as nn
import torchvision.transforms as transforms

mse_loss = nn.MSELoss()

NUM_STEPS = 50
SEED = 0  # base seed for the per-step diffusion generator (reproducible runs)

# Structure (SPN feature map) and content (VAE latent) encodings of the clean
# image. Neither is consumed by the loop below, and no encoder is trained here,
# so no autograd graph is recorded.
with torch.no_grad():
    z_struct = spn(clean_tensor)
z_content = encode_content(clean_tensor)

optimizer = torch.optim.Adam([z_style], lr=1e-4)

# Loop-invariant inputs, hoisted out of the loop.
clean_pil_image = tensor_to_pil(clean_tensor)  # img2img takes an RGB PIL image
style_tensor_clamped = torch.clamp(style_tensor, 0.0, 1.0).float()
clean_tensor_clamped = torch.clamp(clean_tensor, 0.0, 1.0).float()

# NOTE(review): `pipe(...)` returns PIL images, so `output_img` below has no
# autograd history and `z_style` never enters the computation. Backpropagating
# `total_loss` therefore cannot give `z_style` a gradient, and `optimizer.step()`
# would have nothing to apply. The losses are computed under `torch.no_grad()`
# as monitoring values only.
# TODO: wire `z_style` into the generation path (e.g. the latents passed to the
# UNet) before treating this loop as an optimisation.
for step in range(NUM_STEPS):
    output_imgs = pipe(
        prompt="",
        image=clean_pil_image,     # PIL.Image, CPU
        strength=0.4,
        # NOTE(review): with an empty prompt and no negative prompt, both
        # classifier-free-guidance branches presumably see the same text, so
        # this scale likely has no effect — confirm before tuning it.
        guidance_scale=12,
        num_inference_steps=50,
        generator=torch.Generator(device=device).manual_seed(SEED + step),
    ).images

    with torch.no_grad():
        # Back to a [1, 3, H, W] float tensor in [0, 1] on `device`.
        output_img = transforms.ToTensor()(output_imgs[0]).unsqueeze(0).to(device)
        output_img_clamped = torch.clamp(output_img, 0.0, 1.0).float()

        content_loss = mse_loss(output_img, clean_tensor)
        style_loss = mse_loss(output_img, style_tensor)
        clip_loss_val = clip_directional_loss(
            output_img_clamped, style_tensor_clamped, clean_tensor_clamped
        )

        total_loss = 1 * content_loss + 100 * style_loss + 100 * clip_loss_val

    print(f"Step {step}, Total Loss: {total_loss.item():.4f}")

In [None]:
# ------------------------
# One-shot finalization and metrics
# ------------------------
import os
import torch
import lpips
import numpy as np
from PIL import Image
from skimage.metrics import structural_similarity as ssim
import torchvision.transforms as T
from IPython.display import display

out_dir = "/content/osasis_outputs"
os.makedirs(out_dir, exist_ok=True)

# Final stylized output: the image produced by the last iteration of the loop.
final_pil = output_imgs[0]
final_path = os.path.join(out_dir, "stylized_final.png")
final_pil.save(final_path)
print("✅ Saved stylized image:", final_path)

# ---- Metrics ----
to_tensor = T.ToTensor()


def to_norm_lpips(t):
    """Map an image tensor from [0, 1] to the [-1, 1] range LPIPS expects by default."""
    return t * 2.0 - 1.0


lpips_fn = lpips.LPIPS(net='vgg').to(device)

stylized_tensor = to_tensor(final_pil).unsqueeze(0).to(device)
clean_tensor_for_eval = clean_tensor.clone()

# Evaluation only: no autograd graph through the VGG backbone.
with torch.no_grad():
    lpips_score = lpips_fn(
        to_norm_lpips(stylized_tensor),
        to_norm_lpips(clean_tensor_for_eval),
    ).item()

# SSIM on HWC float arrays in [0, 1].
styl_np = np.array(final_pil).astype(np.float32) / 255.0
clean_np = np.transpose(clean_tensor_for_eval.squeeze(0).cpu().numpy(), (1, 2, 0))
ssim_score = ssim(styl_np, clean_np, channel_axis=2, data_range=1.0)  # channel_axis replaces multichannel

print(f"📊 LPIPS (stylized vs clean): {lpips_score:.4f}")
print(f"📊 SSIM  (stylized vs clean): {ssim_score:.4f}")

# ---- Visual sanity check ----
print("Displaying final generated image:")
display(final_pil)

### Gram Matrix Calculation

In [None]:
import torch.nn as nn

def gram_matrix(input):  # `input` shadows the builtin; kept for keyword compatibility
    """
    Normalised Gram matrix of a feature map.

    Args:
        input: feature tensor of shape (a, b, c, d) = (batch, channels, height, width).
            Each (batch, channel) pair becomes one row, so in the usual a == 1 case
            this is the channel-by-channel Gram matrix.

    Returns:
        Tensor of shape (a * b, a * b): ``F @ F.T / (a * b * c * d)``, where ``F`` is
        ``input`` flattened to (a * b, c * d).
    """
    a, b, c, d = input.size()
    # reshape (not view) so non-contiguous inputs, e.g. transposed or
    # channels_last feature maps, are accepted as well.
    features = input.reshape(a * b, c * d)
    gram = features @ features.t()
    return gram.div(a * b * c * d)

In [None]:
# Gram matrices of the SPN feature maps for the stylized output and the style
# reference. Evaluation only, so no autograd graph is recorded.
with torch.no_grad():
    stylized_features = spn(stylized_tensor)
    style_features = spn(style_tensor)

gram_stylized = gram_matrix(stylized_features)
gram_style = gram_matrix(style_features)

print("✅ Calculated Gram matrices for stylized and style images.")

### Gram Matrix Distance

In [None]:
# Distance between the Gram matrices: mean squared error over all entries
# (`mse_loss` is the nn.MSELoss instance created in the training cell).
gram_loss = mse_loss(gram_stylized, gram_style)

print(f"📊 Gram Matrix Loss (stylized vs style): {gram_loss.item():.4f}")