<a href="https://colab.research.google.com/github/L00NE/loone_tmp_scripts/blob/main/notebooks/ImageTextCoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

 # Description:
 This is demo inference code for the paper **LMM-driven Semantic Text-Image Coding for Ultra Low-bitrate Learned Image Compression**.

 Most of the code comes from the MISC repository (https://github.com/lcysyzxdxc/MISC); I modified parts of it so that it works with a newer version of DiffBIR.

 # Usage:
 Running DiffBIR requires Colab Pro. In the runtime settings, select an L4 GPU and 64 GB of RAM before running this notebook.

# Install Requirements

In [1]:
# %cd /content
# !git clone https://github.com/user475289/ImageTextCoding
# %cd /content/ImageTextCoding
# !mkdir weights
# !pip install -q einops pytorch_lightning==1.9.5 torch==2.1.0 gradio omegaconf transformers lpips segment_anything #opencv-python
# !pip install -q https://download.pytorch.org/whl/cu121/xformers-0.0.22.post7-cp310-cp310-manylinux2014_x86_64.whl
# !pip install -q git+https://github.com/mlfoundations/open_clip@v2.20.0

# !pip install compressai
# !apt -y install -qq aria2
# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lxq007/DiffBIR/resolve/main/general_full_v1.ckpt -d /content/ImageTextCoding/weights -o general_full_v1.ckpt
# !aria2c --console-log-level=error -c -x 16 -s 16 -k 1M https://huggingface.co/lxq007/DiffBIR/resolve/main/general_swinir_v1.ckpt -d /content/ImageTextCoding/weights -o general_swinir_v1.ckpt

In [None]:

import math
import io
import torch
from torchvision import transforms
import numpy as np
import os

from PIL import Image, ImageChops

import matplotlib.pyplot as plt

# Run all models on the GPU when CUDA is available; fall back to the CPU otherwise.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [None]:

import sys
import os
from clip import clip
import torch
import cv2
import numpy as np
from PIL import Image
from  matplotlib import pyplot as plt
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
from segment_anything import sam_model_registry, SamPredictor

# Image preprocessing for CLIP Surgery: a plain 224x224 bicubic resize with no
# centre crop, so the whole image is seen by the encoder and the similarity map
# can be upsampled back onto the full input in clip_map().
preprocess =  Compose([Resize((224, 224), interpolation=BICUBIC), ToTensor(),
    Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))])
# clip.load also returns its own transform; discard it so that the explicit
# `preprocess` above is the one clip_map() actually uses (mirrors the CLIP
# Surgery demo, which does `model, _ = clip.load(...)`).
seg_model, _ = clip.load("CS-ViT-B/16", device=device)
#NOTE: This model is in CLIP-surgery. Not included in clip-openai.
seg_model.eval()

# Utility Functions

In [None]:
from skimage.transform import resize
from utils.image import wavelet_reconstruction

import shutil
def clear_directory(directory_path):
    """Recursively delete every file under ``directory_path``, keeping the directory tree.

    Files directly inside ``directory_path`` and files in any nested
    sub-directory are removed; the directories themselves are left in place
    so that output sub-folders can be written to again afterwards.

    Raises:
        FileNotFoundError / NotADirectoryError: if ``directory_path`` cannot be listed.
    """
    def _reraise(err):
        # os.walk swallows listing errors by default; surface a missing or
        # invalid root instead of silently doing nothing.
        raise err

    for root, _dirs, files in os.walk(directory_path, onerror=_reraise):
        for name in files:
            os.remove(os.path.join(root, name))



def to_block(img,grid=32,level=8):
    """Deterministically downsample and colour-quantize ``img``.

    The image is shrunk by a factor of ``grid`` in both directions. Every
    pixel value is floored to a multiple of ``level``, and the result is scaled
    back up to the original size to serve as a coarse reference image.

    Args:
        img: PIL image to encode.
        grid: integer down-sampling factor applied to width and height.
        level: quantization step for pixel values (e.g. 8 leaves 32 distinct
            values per 8-bit channel).

    Returns:
        ``(img_reference, g_w, g_h)``: the upsampled reference image and the
        width/height of the low-resolution grid.
    """
    # Clamp to at least one cell: an image smaller than `grid` would otherwise
    # ask PIL for a zero-sized resize, which it rejects.
    g_w = max(1, int(img.size[0] / grid))
    g_h = max(1, int(img.size[1] / grid))
    img_resize = img.resize((g_w, g_h))
    # Floor each value to a multiple of `level` (coarse colour quantization).
    img_np = (np.floor(np.array(img_resize) / level) * level).astype(np.uint8)
    img_reference = Image.fromarray(img_np).resize(img.size)
    return img_reference, g_w, g_h


def divide_integer(num, n):
    """Split ``num`` into ``n`` integer parts that differ by at most one.

    The remainder of ``num // n`` is handed out one unit at a time to the
    leading parts, so the larger parts come first, e.g.
    ``divide_integer(10, 3) == [4, 3, 3]``.
    """
    base, extra = divmod(num, n)
    # `extra` leading parts get one more unit; the remaining parts get `base`.
    return [base + 1] * extra + [base] * (n - extra)
def mask_block(mask,num=8,level=0.35):
    """Binarize a soft mask on a ``num`` x ``num`` grid and expand it back to full size.

    ``mask`` is shrunk to ``num`` x ``num`` with ``skimage.transform.resize``.
    Each cell above ``level`` becomes 255 and every other cell becomes 0. The
    cells are then repeated so that the output has the same height and width
    as ``mask``; any size that is not divisible by ``num`` is spread over the
    leading rows/columns via ``divide_integer``.

    NOTE(review): skimage's ``resize`` (``preserve_range=False`` by default)
    rescales integer input such as uint8 to [0, 1], so ``level`` is presumably
    compared on that scale -- confirm for non-uint8 masks.
    """
    grid = resize(mask, (num, num), mode='reflect')
    # Threshold in two passes, in the same order as a pair of in-place
    # boolean assignments would apply them.
    grid = np.where(grid > level, 255, grid)
    grid = np.where(grid <= level, 0, grid)
    row_counts = np.array(divide_integer(mask.shape[0], num), dtype='int')
    col_counts = np.array(divide_integer(mask.shape[1], num), dtype='int')
    widened = np.repeat(grid, col_counts, axis=1)
    return np.repeat(widened, row_counts, axis=0)

def image_paddle_in(image, num=32):
    """Pad ``image`` with black on the right/bottom so both sides are multiples of ``num``.

    The original image is pasted at the top-left corner of a new RGB canvas.

    Returns:
        ``(padded_image, original_width, original_height)``; pass the original
        size to ``image_paddle_out`` to undo the padding.
    """
    # Round each side (width, height) up to the next multiple of `num`.
    padded_size = tuple(((side - 1) // num + 1) * num for side in image.size)
    canvas = Image.new("RGB", padded_size, (0, 0, 0))
    canvas.paste(image, (0, 0))
    return canvas, image.width, image.height

def image_paddle_out(image, old_width, old_height):
    """Undo ``image_paddle_in`` by cropping back to the top-left ``old_width`` x ``old_height`` region."""
    box = (0, 0, old_width, old_height)
    return image.crop(box)

In [None]:


def compute_bpp(out_net):
    """Return the estimated bits per pixel of a CompressAI ``forward()`` output.

    ``out_net['x_hat']`` provides N, H and W (dims 0, 2, 3). Each tensor in
    ``out_net['likelihoods']`` contributes ``-log2(p)`` summed over all its
    elements, normalised by ``N * H * W``.
    """
    dims = out_net['x_hat'].size()
    pixel_count = dims[0] * dims[2] * dims[3]
    per_entry = [
        torch.log(lk).sum() / (-math.log(2) * pixel_count)
        for lk in out_net['likelihoods'].values()
    ]
    return sum(per_entry).item()

def clip_map(img,texts,mask_num=8):
    """Localise the first three ``texts`` in ``img`` with CLIP Surgery.

    Returns:
        A tuple of three PIL masks, one per text in order. Each mask is
        binarised on a ``mask_num`` x ``mask_num`` grid by ``mask_block`` and
        has the same height/width as ``img``.
    """
    clip_input = preprocess(img).unsqueeze(0).to(device)

    with torch.no_grad():
        # Only the (H, W) of this BGR copy is used, as the target size of the
        # similarity map.
        bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
        feats = seg_model.encode_image(clip_input)
        feats = feats / feats.norm(dim=1, keepdim=True)

        # Prompt-ensembled, normalised text features for the queries, plus the
        # features of an empty string, which feature surgery treats as the
        # redundant component.
        text_feats = clip.encode_text_with_prompt_ensemble(seg_model, texts, device)
        empty_feats = clip.encode_text_with_prompt_ensemble(seg_model, [""], device)

        similarity = clip.clip_feature_surgery(feats, text_feats, empty_feats)
        # Skip token 0 (presumably the class token) and map the patch
        # similarities onto the image size.
        sim_map = clip.get_similarity_map(similarity[:, 1:, :], bgr.shape[:2])

        masks = []
        for idx in range(3):
            raw = (sim_map[0, :, :, idx].cpu().numpy() * 255).astype('uint8')
            masks.append(Image.fromarray(mask_block(raw, num=mask_num)))
        return tuple(masks)

def sr_pipe(img_reference,positive_prompt="",cfg=1.0,steps=40,res=512, cond_scale = 1.0, old_size = None):
    """Restore ``img_reference`` with DiffBIR (global ``model`` + ``SpacedSampler``).

    Args:
        img_reference: PIL image used as the control (conditioning) input.
        positive_prompt: text prompt for the conditional branch.
        cfg: currently unused; accepted so that MISC-style calls
            (``cfg=3.5`` / ``cfg=7``) keep working.
        steps: number of sampling steps passed to ``sampler.sample``.
        res: currently unused; the output keeps the input size.
        cond_scale: passed to the sampler as ``unconditional_guidance_scale``
            (guidance between the prompt and the negative prompt).
        old_size: currently unused.

    Returns:
        A list with ``num_samples`` HxWx3 uint8 arrays, with padding removed so
        that each one matches the input size.
    """
    control_img = img_reference
    sr_scale = 1
    num_samples = 1
    disable_preprocess_model= False
    strength = 1.0
    use_color_fix = True
    keep_original_size = False
    negative_prompt="Blurry, Low Quality"
    sampler = SpacedSampler(model, var_type="fixed_small")

    if sr_scale != 1:
        control_img = control_img.resize(
            tuple(math.ceil(x * sr_scale) for x in control_img.size),
            Image.BICUBIC
        )
    input_size = control_img.size
    h, w = control_img.height, control_img.width
    # Presumably pads H and W up to a multiple of 64 for the UNet; the crop at
    # the end removes it again.
    control_img = pad(np.array(control_img), scale=64) # HWC, RGB, [0, 255]
    control_imgs = [control_img] * num_samples
    control = torch.tensor(np.stack(control_imgs) / 255.0, dtype=torch.float32, device=model.device).clamp_(0, 1)
    control = einops.rearrange(control, "n h w c -> n c h w").contiguous()
    if not disable_preprocess_model:
        control = model.preprocess_model(control)
    height, width = control.size(-2), control.size(-1)

    # Both branches condition on the same control tensor, so encode it once.
    c_latent = model.apply_condition_encoder(control)
    cond = {
        "c_latent": [c_latent],
        "c_crossattn": [model.get_learned_conditioning([positive_prompt] * num_samples)]
    }
    uncond = {
        "c_latent": [c_latent],
        "c_crossattn": [model.get_learned_conditioning([negative_prompt] * num_samples)]
    }
    model.control_scales = [strength] * 13

    # Latent is 4 channels at 1/8 of the (padded) pixel resolution.
    shape = (num_samples, 4, height // 8, width // 8)
    print(f"latent shape = {shape}")
    x_T = torch.randn(shape, device=model.device, dtype=torch.float32)
    samples = sampler.sample(
        steps, shape, cond,
        unconditional_guidance_scale=cond_scale,
        unconditional_conditioning=uncond,
        cond_fn=None, x_T=x_T
    )
    x_samples = model.decode_first_stage(samples)
    x_samples = ((x_samples + 1) / 2).clamp(0, 1)

    # apply color correction
    if use_color_fix:
        x_samples = wavelet_reconstruction(x_samples, control)

    x_samples = (einops.rearrange(x_samples, "b c h w -> b h w c") * 255).cpu().numpy().clip(0, 255).astype(np.uint8)
    preds = []
    for img in x_samples:
        if keep_original_size:
            # remove padding and resize to input size
            img = Image.fromarray(img[:h, :w, :]).resize(input_size, Image.LANCZOS)
            preds.append(np.array(img))
        else:
            # remove padding
            preds.append(img[:h, :w, :])
    return preds

In [None]:
!wget https://github.com/user475289/ImageTextCoding/releases/download/v0.1-alpha/10ep_cheng_3_mse0.5_vgg0.2_i2t0.2_iqa0.1.tar

In [None]:
# Source of the low-bitrate reference image:
#   'net'       - pretrained CompressAI cheng2020-attn, quality 1
#   'tuned_net' - cheng2020-attn restored from the fine-tuned checkpoint downloaded above
#   'ref'       - a precomputed reference image with a known bits-per-pixel
#   'pixel'     - deterministic down-sampling + quantization via to_block()
mode='tuned_net'
# When True, the decoder also refines the three CLIP-localised object regions
# with their own descriptions and counts the mask bits in the bpp.
using_map=False

if mode=='net':
    from compressai.zoo import cheng2020_attn
    comp_net = cheng2020_attn(pretrained=True, quality = 1).to(device)
elif mode=='tuned_net':
    checkpoint_path= '/content/ImageTextCoding/10ep_cheng_3_mse0.5_vgg0.2_i2t0.2_iqa0.1.tar'
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (consider weights_only=True if this checkpoint supports it).
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # The weights may be wrapped under one of several keys; fall back to the
    # raw object when none is present (if several exist, the last key in this
    # list wins).
    state_dict = checkpoint
    for key in ["network", "state_dict", "model_state_dict"]:
        if key in checkpoint:
            state_dict = checkpoint[key]
    arch='cheng2020-attn'
    from compressai.zoo.image import model_architectures as architectures
    model_cls = architectures[arch]
    comp_net = model_cls.from_state_dict(state_dict).eval().to(device)
elif mode=='ref':
    ref_path='./ref/example-reference.png'
    # Bits per pixel of the precomputed reference image.
    ref_bpp=0.0421
elif mode=='pixel':
    # to_block quantization step is 2**block_level.
    block_level=3
    # Lower bound on the to_block down-sampling factor.
    block_num_min=32


# Side length of the binary grid each CLIP mask is quantised to (mask_block).
mask_num=8
# Passed to sr_pipe's `res` argument.
res=1024

import pandas as pd
# Per-image object names and descriptions (columns item1..item3,
# item*_description, overall_description), presumably generated by LLaVA-1.5.
df = pd.read_csv("/content/ImageTextCoding/kodak_llava_1.5.csv")

# NOTE(review): the encoder cell reads row 0 of `df`; confirm that this row
# describes the image chosen here.
image_path='/content/ImageTextCoding/kodim15.png'


img = Image.open(image_path).convert('RGB')

In [None]:
########################## Encoder ##########################

# Text side information: three object names with descriptions, plus a global
# description of the whole image.
# NOTE(review): row 0 of the caption table is used; confirm it matches `image_path`.
#
# Legacy MISC parsing of a free-form GPT prompt, kept for reference:
# prompt=df['prompt'][0]
# prompt_list=prompt.split('\n')
# prompt_list = [element for element in prompt_list if element != '']
# name_0,name_1,name_2=prompt_list[0].split('.')[0].split(',')
# detail_0,detail_1,detail_2=prompt_list[1],prompt_list[2],prompt_list[3]
# detail_all=prompt_list[4]
name_0 = df['item1'][0]
name_1 = df['item2'][0]
name_2 = df['item3'][0]
detail_0 = df['item1_description'][0]
detail_1 = df['item2_description'][0]
detail_2 = df['item3_description'][0]
detail_all = df['overall_description'][0]

# One coarse binary mask per object name, localised with CLIP Surgery.
mask_0,mask_1,mask_2=clip_map(img,[name_0,name_1,name_2],mask_num)


# Reference image (the low-bitrate visual stream) and its bit cost b_image.
if mode=='pixel':
    old_width, old_height=img.size
    block_num=max(int(max(old_width, old_height)/16),block_num_min)
    # to_block returns (image, grid_w, grid_h); only the image is needed here.
    img_reference, _, _ = to_block(img,block_num,2**block_level)

    # NOTE(review): this charges block_level bits per cell of a
    # block_num x block_num grid, whereas to_block builds a
    # (W/block_num) x (H/block_num) grid -- confirm the intended formula.
    b_image=block_level*block_num**2
elif mode=='net' or mode =='tuned_net':
    # Pad into a separate variable so that `img` keeps the original resolution
    # for the bpp denominator and for the quality metrics computed later.
    img_padded, old_width, old_height = image_paddle_in(img, 64)
    x = transforms.ToTensor()(img_padded).unsqueeze(0).to(device)
    with torch.no_grad():
        out_net = comp_net.forward(x)
    out_net['x_hat'].clamp_(0, 1)
    img_reference = transforms.ToPILImage()(out_net['x_hat'].squeeze().cpu())
    img_reference = image_paddle_out(img_reference, old_width, old_height)
    # compute_bpp normalises by the padded pixel count, so multiply by the
    # padded size to recover the total number of bits.
    b_image=compute_bpp(out_net)*img_padded.size[0]*img_padded.size[1]
elif mode=='ref':
    old_width, old_height=img.size
    img_reference = Image.open(ref_path).convert('RGB')
    b_image=ref_bpp*img.size[0]*img.size[1]


#############################################################

In [None]:
from utils.image import auto_resize, pad
import einops
from utils.common import instantiate_from_config, load_state_dict
from omegaconf import OmegaConf
from model.cldm import ControlLDM
from model.spaced_sampler import SpacedSampler
# Build DiffBIR's ControlLDM from its config; the weights are loaded in a later cell.
model: ControlLDM = instantiate_from_config(OmegaConf.load('./configs/model/cldm.yaml'))
# NOTE: despite the variable name, this is the full DiffBIR checkpoint
# (general_full_v1.ckpt), not the SwinIR-only weights.
ckpt_swinir='./weights/general_full_v1.ckpt'
#ckpt_net='./weights/cheng_small.pth.tar'


In [None]:
# Load the DiffBIR checkpoint with strict key matching, freeze the model
# (presumably Lightning's freeze(): no gradients + eval mode) and move it to
# the selected device.
# NOTE(review): torch.load unpickles the checkpoint; only use trusted files.
load_state_dict(model, torch.load(ckpt_swinir, map_location="cpu"), strict=True)
model.freeze()
model.to(device)

In [None]:
########################## Decoder ##########################
# Rates are reported per pixel of the original (un-padded) image.
num_pixels = old_width * old_height


def _refine_masked_region(base, prompt, mask):
    """Refine ``base`` with a short DiffBIR pass guided by ``prompt``, only inside ``mask``.

    Where ``mask`` is 255 the refined pixels are used; where it is 0 ``base``
    is kept. The blend is resized to the original image size.
    """
    # sr_pipe returns a list of HxWx3 uint8 arrays cropped to base's size.
    refined = Image.fromarray(sr_pipe(base, positive_prompt=prompt, cfg=3.5, steps=3, res=res)[0])
    # Bring the mask to base's size so both operands of each ImageChops call
    # line up pixel-for-pixel.
    mask_l = mask.resize(base.size).convert("L")
    blended = ImageChops.add(
        ImageChops.multiply(refined, mask_l.convert("RGB")),
        ImageChops.multiply(base.convert("RGB"), ImageChops.invert(mask_l).convert("RGB")),
    )
    return blended.resize((old_width, old_height))


image = img_reference

if using_map:
    # Side information: one bit per cell of each mask_num x mask_num object
    # mask (3 masks) plus 8 bits per character of all four descriptions.
    b_mask=mask_num*mask_num*3
    b_word=(len(detail_0)+len(detail_1)+len(detail_2)+len(detail_all))*8
    bpp=(b_image+b_mask+b_word)/num_pixels
    print('bpp='+str(bpp))

    # Refine each CLIP-localised object with its own description...
    for detail, mask in ((detail_0, mask_0), (detail_1, mask_1), (detail_2, mask_2)):
        image = _refine_masked_region(image, detail, mask)
    # ...then run a full pass with the global description. As in the branch
    # below, `image` ends up as sr_pipe's list of arrays (use image[0]).
    image = sr_pipe(image,positive_prompt=detail_all,cfg=7,steps=40,res=res)

else:
    # Side information: 8 bits per character of the global description.
    b_word=(len(detail_all))*8
    bpp=(b_image+b_word)/num_pixels
    print('image bit = ', b_image, 'text bit = ', b_word, 'bpp=', str(bpp))

    image = sr_pipe(image,positive_prompt=detail_all,cfg=7,steps=40,res=res,
                    old_size = (old_width, old_height))

# Output Image (Compressed & Decompressed & Deblurred)

In [None]:
# sr_pipe returns HxWx3 arrays; wrap the first one in a uint8 PIL image so the
# notebook renders it as a picture instead of printing the raw array.
Image.fromarray(np.asarray(image[0]).astype(np.uint8))

# Original Image

In [None]:
# Display the input image for visual comparison with the decoded output.
img

In [None]:
import torchvision.transforms.functional as TF
from lpips import LPIPS


def psnr(img0, img1, data_range=255.0):
    """Return the peak signal-to-noise ratio (dB) between two same-shaped images.

    Args:
        img0, img1: arrays (or array-likes such as PIL images) with values in
            ``[0, data_range]``.
        data_range: peak value of the pixel scale; 255 for 8-bit images.

    Returns:
        PSNR in dB; ``inf`` when the images are identical.
    """
    # Promote to float64 before subtracting: uint8 arrays would wrap around.
    diff = np.asarray(img0, dtype=np.float64) - np.asarray(img1, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        # Avoid a divide-by-zero warning for identical inputs.
        return float('inf')
    return 10 * np.log10(data_range ** 2 / mse)

def lpips(img0, img1, data_range=255.0):
    """Return the LPIPS (VGG) perceptual distance between two RGB images.

    Higher means further/more different; lower means more similar.

    Args:
        img0, img1: PIL images or HxWx3 arrays (uint8 or float) with values in
            ``[0, data_range]``.
        data_range: peak value of the input pixel scale.
    """
    loss_fn_vgg = LPIPS(net='vgg')

    def _to_lpips_tensor(pic):
        # LPIPS expects Nx3xHxW tensors scaled to [-1, 1]. Scale explicitly:
        # TF.to_tensor only divides uint8/PIL input by 255 and leaves float
        # arrays (as produced by the float32 conversion cell) unscaled.
        arr = np.asarray(pic, dtype=np.float32) / data_range
        return (torch.from_numpy(arr).permute(2, 0, 1) * 2 - 1).unsqueeze(0)

    with torch.no_grad():
        return loss_fn_vgg(_to_lpips_tensor(img0), _to_lpips_tensor(img1)).item()



In [None]:
# Convert both images to float32 arrays (values stay on the 0-255 scale) so the
# metric helpers do not subtract uint8 values.
img = np.array(img, dtype = np.float32)
image[0] = np.array(image[0], dtype = np.float32)

In [None]:
# LPIPS: lower is more similar. PSNR (dB): higher means lower MSE.
print("lpips: ", lpips(img, image[0]))
print("psnr: ", psnr(img, image[0]))