In [1]:
import torch
from kornia.losses import PSNRLoss, SSIMLoss
from torchmetrics.image import VisualInformationFidelity
import torch.nn.functional as F
from compressai.zoo import models
from torchvision import transforms
import PIL
import numpy as np
from skimage import io , img_as_float

In [3]:
model_names = []
for model_name in models:
    model_names.append(model_name)
print(model_names)

['bmshj2018-factorized', 'bmshj2018-factorized-relu', 'bmshj2018-hyperprior', 'mbt2018-mean', 'mbt2018', 'cheng2020-anchor', 'cheng2020-attn', 'bmshj2018-hyperprior-vbr', 'mbt2018-mean-vbr', 'mbt2018-vbr', 'hrtzxf2022-pcc-rec', 'sfu2023-pcc-rec-pointnet', 'sfu2024-pcc-rec-pointnet2-ssg', 'ssf2020']


In [4]:
max_values_of_quality = [0] * len(models)

for i, model_name in enumerate(models):
    model_class = models[model_name]
    print(f"Testing {model_name}")

    j = 1
    max_quality = None

    while max_quality is None:
        try:
            print(f"Trying quality={j}")
            model = model_class(quality=j, pretrained=True)
            j += 1
        except ValueError as e:
            # We've found the limit when ValueError is raised
            max_quality = j - 1
            print(f"Max quality for {model_name}: {max_quality}")
        except TypeError as e:
            # This model might not use quality parameter
            print(f"Model {model_name} may not use quality parameter: {e}")
            max_quality = "N/A"
        except Exception as e:
            # Handle other unexpected errors
            print(f"Error with {model_name}: {e}")
            max_quality = "Error"
        if j>20:
            break

    max_values_of_quality[i] = max_quality

Testing bmshj2018-factorized
Trying quality=1
Trying quality=2


Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-factorized-prior-2-87279a02.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\bmshj2018-factorized-prior-2-87279a02.pth.tar
100%|██████████| 11.5M/11.5M [00:04<00:00, 2.87MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-factorized-prior-3-5c6f152b.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\bmshj2018-factorized-prior-3-5c6f152b.pth.tar


Trying quality=3


100%|██████████| 11.6M/11.6M [00:05<00:00, 2.12MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-factorized-prior-4-1ed4405a.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\bmshj2018-factorized-prior-4-1ed4405a.pth.tar


Trying quality=4


100%|██████████| 11.6M/11.6M [00:10<00:00, 1.18MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-factorized-prior-5-866ba797.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\bmshj2018-factorized-prior-5-866ba797.pth.tar


Trying quality=5


100%|██████████| 11.7M/11.7M [00:05<00:00, 2.26MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-factorized-prior-6-9b02ea3a.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\bmshj2018-factorized-prior-6-9b02ea3a.pth.tar


Trying quality=6


100%|██████████| 27.3M/27.3M [00:03<00:00, 7.85MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-factorized-prior-7-6dfd6734.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\bmshj2018-factorized-prior-7-6dfd6734.pth.tar


Trying quality=7


100%|██████████| 27.5M/27.5M [00:12<00:00, 2.39MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-factorized-prior-8-5232faa3.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\bmshj2018-factorized-prior-8-5232faa3.pth.tar


Trying quality=8


100%|██████████| 27.9M/27.9M [00:05<00:00, 5.31MB/s]


Trying quality=9
Max quality for bmshj2018-factorized: 8
Testing bmshj2018-factorized-relu
Trying quality=1
Trying quality=2
Trying quality=3
Trying quality=4
Trying quality=5
Trying quality=6
Trying quality=7
Trying quality=8
Trying quality=9
Max quality for bmshj2018-factorized-relu: 8
Testing bmshj2018-hyperprior
Trying quality=1
Trying quality=2


Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-2-93677231.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\bmshj2018-hyperprior-2-93677231.pth.tar
100%|██████████| 20.2M/20.2M [00:05<00:00, 3.70MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-3-6d87be32.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\bmshj2018-hyperprior-3-6d87be32.pth.tar


Trying quality=3


100%|██████████| 20.2M/20.2M [00:02<00:00, 7.97MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-4-de1b779c.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\bmshj2018-hyperprior-4-de1b779c.pth.tar


Trying quality=4


100%|██████████| 20.2M/20.2M [00:08<00:00, 2.56MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-5-f8b614e1.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\bmshj2018-hyperprior-5-f8b614e1.pth.tar


Trying quality=5


100%|██████████| 20.2M/20.2M [00:09<00:00, 2.16MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-6-1ab9c41e.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\bmshj2018-hyperprior-6-1ab9c41e.pth.tar


Trying quality=6


100%|██████████| 46.0M/46.0M [00:08<00:00, 5.43MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-7-3804dcbd.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\bmshj2018-hyperprior-7-3804dcbd.pth.tar


Trying quality=7


100%|██████████| 46.0M/46.0M [00:07<00:00, 6.48MB/s]


Trying quality=8


Downloading: "https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-8-a583f0cf.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\bmshj2018-hyperprior-8-a583f0cf.pth.tar
100%|██████████| 46.0M/46.0M [00:09<00:00, 4.94MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/mbt2018-mean-2-e54a039d.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\mbt2018-mean-2-e54a039d.pth.tar


Trying quality=9
Max quality for bmshj2018-hyperprior: 8
Testing mbt2018-mean
Trying quality=1
Trying quality=2


100%|██████████| 27.6M/27.6M [00:08<00:00, 3.41MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/mbt2018-mean-3-723404a8.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\mbt2018-mean-3-723404a8.pth.tar


Trying quality=3


100%|██████████| 27.6M/27.6M [00:10<00:00, 2.66MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/mbt2018-mean-4-6dba02a3.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\mbt2018-mean-4-6dba02a3.pth.tar


Trying quality=4


100%|██████████| 27.6M/27.6M [00:10<00:00, 2.70MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/mbt2018-mean-5-d504e8eb.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\mbt2018-mean-5-d504e8eb.pth.tar


Trying quality=5


100%|██████████| 67.8M/67.8M [00:15<00:00, 4.48MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/mbt2018-mean-6-a19628ab.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\mbt2018-mean-6-a19628ab.pth.tar


Trying quality=6


100%|██████████| 67.9M/67.9M [00:10<00:00, 6.49MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/mbt2018-mean-7-d5d441d1.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\mbt2018-mean-7-d5d441d1.pth.tar


Trying quality=7


100%|██████████| 67.9M/67.9M [00:17<00:00, 4.07MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/mbt2018-mean-8-8089ae3e.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\mbt2018-mean-8-8089ae3e.pth.tar


Trying quality=8


100%|██████████| 67.9M/67.9M [00:11<00:00, 6.29MB/s]


Trying quality=9
Max quality for mbt2018-mean: 8
Testing mbt2018
Trying quality=1
Trying quality=2


Downloading: "https://compressai.s3.amazonaws.com/models/v1/mbt2018-2-43b70cdd.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\mbt2018-2-43b70cdd.pth.tar
100%|██████████| 61.8M/61.8M [00:14<00:00, 4.40MB/s]


Trying quality=3


Downloading: "https://compressai.s3.amazonaws.com/models/v1/mbt2018-3-22901978.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\mbt2018-3-22901978.pth.tar
100%|██████████| 61.8M/61.8M [00:15<00:00, 4.27MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/mbt2018-4-456e2af9.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\mbt2018-4-456e2af9.pth.tar


Trying quality=4


100%|██████████| 61.8M/61.8M [00:10<00:00, 5.93MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/mbt2018-5-b4a046dd.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\mbt2018-5-b4a046dd.pth.tar


Trying quality=5


100%|██████████| 118M/118M [00:21<00:00, 5.82MB/s] 


Trying quality=6


Downloading: "https://compressai.s3.amazonaws.com/models/v1/mbt2018-6-7052e5ea.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\mbt2018-6-7052e5ea.pth.tar
100%|██████████| 118M/118M [00:23<00:00, 5.22MB/s] 


Trying quality=7


Downloading: "https://compressai.s3.amazonaws.com/models/v1/mbt2018-7-8ba2bf82.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\mbt2018-7-8ba2bf82.pth.tar
100%|██████████| 118M/118M [00:20<00:00, 6.03MB/s] 


Trying quality=8


Downloading: "https://compressai.s3.amazonaws.com/models/v1/mbt2018-8-dd0097aa.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\mbt2018-8-dd0097aa.pth.tar
100%|██████████| 118M/118M [00:14<00:00, 8.73MB/s] 


Trying quality=9
Max quality for mbt2018: 8
Testing cheng2020-anchor
Trying quality=1
Trying quality=2


Downloading: "https://compressai.s3.amazonaws.com/models/v1/cheng2020-anchor-2-a29008eb.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\cheng2020-anchor-2-a29008eb.pth.tar
100%|██████████| 49.1M/49.1M [00:11<00:00, 4.52MB/s]


Trying quality=3


Downloading: "https://compressai.s3.amazonaws.com/models/v1/cheng2020-anchor-3-e49be189.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\cheng2020-anchor-3-e49be189.pth.tar
100%|██████████| 49.1M/49.1M [00:12<00:00, 4.07MB/s]
Downloading: "https://compressai.s3.amazonaws.com/models/v1/cheng2020-anchor-4-98b0b468.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\cheng2020-anchor-4-98b0b468.pth.tar


Trying quality=4


100%|██████████| 109M/109M [00:22<00:00, 5.04MB/s] 


Trying quality=5


Downloading: "https://compressai.s3.amazonaws.com/models/v1/cheng2020-anchor-5-23852949.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\cheng2020-anchor-5-23852949.pth.tar
100%|██████████| 109M/109M [00:20<00:00, 5.61MB/s] 


Trying quality=6


Downloading: "https://compressai.s3.amazonaws.com/models/v1/cheng2020-anchor-6-4c052b1a.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\cheng2020-anchor-6-4c052b1a.pth.tar
100%|██████████| 109M/109M [00:17<00:00, 6.60MB/s] 


Trying quality=7
Max quality for cheng2020-anchor: 6
Testing cheng2020-attn
Trying quality=1
Trying quality=2


Downloading: "https://compressai.s3.amazonaws.com/models/v1/cheng2020_attn-mse-2-e0805385.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\cheng2020_attn-mse-2-e0805385.pth.tar
100%|██████████| 54.3M/54.3M [00:14<00:00, 3.81MB/s]


Trying quality=3


Downloading: "https://compressai.s3.amazonaws.com/models/v1/cheng2020_attn-mse-3-2d07bbdf.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\cheng2020_attn-mse-3-2d07bbdf.pth.tar
100%|██████████| 54.3M/54.3M [00:10<00:00, 5.21MB/s]


Trying quality=4


Downloading: "https://compressai.s3.amazonaws.com/models/v1/cheng2020_attn-mse-4-f7b0ccf2.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\cheng2020_attn-mse-4-f7b0ccf2.pth.tar
100%|██████████| 121M/121M [00:19<00:00, 6.65MB/s] 


Trying quality=5


Downloading: "https://compressai.s3.amazonaws.com/models/v1/cheng2020_attn-mse-5-26c8920e.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\cheng2020_attn-mse-5-26c8920e.pth.tar
100%|██████████| 121M/121M [00:17<00:00, 7.41MB/s] 


Trying quality=6


Downloading: "https://compressai.s3.amazonaws.com/models/v1/cheng2020_attn-mse-6-730501f2.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\cheng2020_attn-mse-6-730501f2.pth.tar
100%|██████████| 121M/121M [00:17<00:00, 7.25MB/s] 


Trying quality=7
Max quality for cheng2020-attn: 6
Testing bmshj2018-hyperprior-vbr
Trying quality=1
Trying quality=2
Trying quality=3
Trying quality=4
Trying quality=5
Trying quality=6
Trying quality=7
Trying quality=8
Trying quality=9
Trying quality=10
Trying quality=11
Trying quality=12
Trying quality=13
Trying quality=14
Trying quality=15
Trying quality=16
Trying quality=17
Trying quality=18
Trying quality=19
Trying quality=20
Testing mbt2018-mean-vbr
Trying quality=1
Trying quality=2
Trying quality=3
Trying quality=4
Trying quality=5
Trying quality=6
Trying quality=7
Trying quality=8
Trying quality=9
Trying quality=10
Trying quality=11
Trying quality=12
Trying quality=13
Trying quality=14
Trying quality=15
Trying quality=16
Trying quality=17
Trying quality=18
Trying quality=19
Trying quality=20
Testing mbt2018-vbr
Trying quality=1
Trying quality=2
Trying quality=3
Trying quality=4
Trying quality=5
Trying quality=6
Trying quality=7
Trying quality=8
Trying quality=9
Trying quality=1

Downloading: "https://compressai.s3.amazonaws.com/models/v1/ssf2020-mse-2-79ed4e19.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\ssf2020-mse-2-79ed4e19.pth.tar
100%|██████████| 133M/133M [00:20<00:00, 6.78MB/s] 


Trying quality=3


Downloading: "https://compressai.s3.amazonaws.com/models/v1/ssf2020-mse-3-9c8b998d.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\ssf2020-mse-3-9c8b998d.pth.tar
100%|██████████| 133M/133M [00:25<00:00, 5.58MB/s] 


Trying quality=4


Downloading: "https://compressai.s3.amazonaws.com/models/v1/ssf2020-mse-4-577c1eda.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\ssf2020-mse-4-577c1eda.pth.tar
100%|██████████| 133M/133M [00:16<00:00, 8.72MB/s] 


Trying quality=5


Downloading: "https://compressai.s3.amazonaws.com/models/v1/ssf2020-mse-5-1dd7d574.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\ssf2020-mse-5-1dd7d574.pth.tar
100%|██████████| 133M/133M [00:24<00:00, 5.79MB/s] 


Trying quality=6


Downloading: "https://compressai.s3.amazonaws.com/models/v1/ssf2020-mse-6-59dfb6f9.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\ssf2020-mse-6-59dfb6f9.pth.tar
100%|██████████| 133M/133M [00:14<00:00, 9.42MB/s] 


Trying quality=7


Downloading: "https://compressai.s3.amazonaws.com/models/v1/ssf2020-mse-7-4d867411.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\ssf2020-mse-7-4d867411.pth.tar
100%|██████████| 134M/134M [00:19<00:00, 7.28MB/s] 


Trying quality=8


Downloading: "https://compressai.s3.amazonaws.com/models/v1/ssf2020-mse-8-26439e20.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\ssf2020-mse-8-26439e20.pth.tar
100%|██████████| 134M/134M [00:17<00:00, 7.83MB/s] 


Trying quality=9


Downloading: "https://compressai.s3.amazonaws.com/models/v1/ssf2020-mse-9-e89345c4.pth.tar" to C:\Users\Administrator/.cache\torch\hub\checkpoints\ssf2020-mse-9-e89345c4.pth.tar
100%|██████████| 134M/134M [00:18<00:00, 7.54MB/s] 


Trying quality=10
Max quality for ssf2020: 9


In [5]:
print(max_values_of_quality)
model_names = [x for x in models]
accepted_model =[]
for i in range(14):
    # print(max_values_of_quality[i],model_names[i])
    if type(max_values_of_quality[i]) == int:
        accepted_model.append([model_names[i],max_values_of_quality[i]])
print(accepted_model)

[8, 8, 8, 8, 8, 6, 6, None, None, None, 'N/A', 'N/A', 'N/A', 9]
[['bmshj2018-factorized', 8], ['bmshj2018-factorized-relu', 8], ['bmshj2018-hyperprior', 8], ['mbt2018-mean', 8], ['mbt2018', 8], ['cheng2020-anchor', 6], ['cheng2020-attn', 6], ['ssf2020', 9]]


In [9]:
vif = VisualInformationFidelity()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # setting the computation on GPU or CPU
vif.to(device)

# x and y are torch images, maximum_value is 1 or 255
# bring the image into GPU or CPU
# x = x.to(device)
# y = y.to(device)

def PSNR(x,y,max_val):
    psnr = PSNRLoss(max_val=max_val)
    return -psnr(x,y).item()
def SSIM(x,y):
    ssim_loss = SSIMLoss(window_size=11, reduction='mean')
    ssim_value = 1-ssim_loss(x,y).item()
    return ssim_value

def BPP(image_tensor,model, num_pixels):
    output = model(image_tensor)
    # Calculate BPP
    bpp = sum(
        torch.log(likelihoods).sum() / (-torch.log(2) * num_pixels)
        for likelihoods in output["likelihoods"].values()
    )
    return bpp.item()

def VIF(x,y):
    return vif(x,y).item()

def mse(x,y):
    return F.mse_loss(x,y).item()
def mae(x,y):
    return F.l1_loss(x,y).item()
def calculate_metrics(x,y,max_val,model,num_pixels):
    num_pixels = x.shape[2]*x.shape[3]
    output = model(x)
    bpp = sum(
        torch.log(likelihoods).sum() / (-torch.log(torch.tensor(2)) * num_pixels)
        for likelihoods in output["likelihoods"].values()
    )
    bpp = bpp.item()
    return {"psnr":PSNR(x,y,max_val),
            "ssim":SSIM(x,y),
            "vif":VIF(x,y),
            "mse":mse(x,y),
            "mae":mae(x,y),
            "bpp":bpp}
def compression_image_tensor(x,model,device,quality):
    x = x.to(device)
    with torch.no_grad():
        compressed = model(x)
        decompressed = compressed["x_hat"]
    return decompressed.detach()

# Dynamically retrieve the model class

model_names,quality_sets = [accepted_model[i][0] for i in range(len(accepted_model))],[range(1,accepted_model[i][1]+1) for i in range(len(accepted_model))]
## Load image
kodak_path = [f"./experiments/image_compression/kodak/kodim0{i}.png" for i in range(1,10)]+[f"./experiments/image_compression/kodak/kodim{i}.png" for i in range(10,25)]
kodak_PIL = [PIL.Image.open(i).convert('RGB') for i in kodak_path]
preprocess = transforms.Compose([transforms.ToTensor()])
kodak_Tensor = [preprocess(img).unsqueeze(0) for img in kodak_PIL]
kodak_normalized = [img_as_float(io.imread(img_path)) for img_path in kodak_path]

In [10]:
print(device)

cpu


In [None]:
import torch
import numpy as np
import PIL
from urllib.error import HTTPError

# Define the maximum possible pixel value of the image
MAX_I = 1.0

# Clear GPU memory
torch.cuda.empty_cache()

# Set compression-decompression quality for AI image compression.
result = []
max_val = 1

for model_name,quality in [("cheng2020-anchor",1),("cheng2020-attn",1)]:
    print(f"Quality: {quality}, model = {model_name}")
    # Clear GPU memory
    torch.cuda.empty_cache()

    # Instantiate the model and set it to evaluation mode
    model_class = models[model_name]
    model = model_class(quality=quality, pretrained=True).to(device)
    for param in model.parameters():
        param.requires_grad = False
    for i in range(24):
        x = kodak_Tensor[i].clone().to(device)
        print(f"kodak image {i} \n")
        # Compress and decompress the image
        decompressed = model(x)["x_hat"]
        decompressed = decompressed.clamp(0, 1)
        # Save the decompressed image
        decompressed = decompressed.detach()
        y = decompressed.to(device)
        num_pixels = x.numel()
        # Calculate metrics
        metrics = calculate_metrics(x, y, max_val, model, num_pixels)

        # Convert the tensor to a NumPy array and scale to [0, 255]
        # original_np = x.detach().squeeze(0).cpu().permute(1, 2, 0).numpy()
        # original_np = (original_np * 255).astype(np.uint8)
        # original = PIL.Image.fromarray(original_np)
        # original.save(f"results/kodak_{i}_original_{model_class.__name__}_{quality}.png")

        # Convert the decompressed tensor to a NumPy array and scale to [0, 255]
        decompressed_np = decompressed.detach().squeeze(0).cpu().permute(1, 2, 0).numpy()
        decompressed_np = (decompressed_np * 255).astype(np.uint8)
        decompressed_numpy = PIL.Image.fromarray(decompressed_np)
        decompressed_numpy.save(f"./experiments/image_compression/kodak_{i}_decompressed_{model_class.__name__}_{quality}.png")

        print(metrics)
        result.append(metrics)

        # Clean up
        del x, y, decompressed, metrics
        torch.cuda.empty_cache()
    del model
    torch.cuda.empty_cache()
    torch.cuda.empty_cache()

Quality: 1, model = cheng2020-anchor
kodak image 0 

{'psnr': 26.30901336669922, 'ssim': 0.8591998368501663, 'vif': 0.9961791634559631, 'mse': 0.0023393684532493353, 'mae': 0.03507440164685249, 'bpp': 0.2561964690685272}
kodak image 1 

{'psnr': 30.362871170043945, 'ssim': 0.8805487528443336, 'vif': 1.0018914937973022, 'mse': 0.0009198412299156189, 'mae': 0.02001667581498623, 'bpp': 0.13375042378902435}
kodak image 2 

{'psnr': 31.884979248046875, 'ssim': 0.9284780696034431, 'vif': 0.9986288547515869, 'mse': 0.0006478911964222789, 'mae': 0.01611577346920967, 'bpp': 0.1321416050195694}
kodak image 3 

{'psnr': 30.18390655517578, 'ssim': 0.8880719095468521, 'vif': 1.0166760683059692, 'mse': 0.0009585379157215357, 'mae': 0.021163173019886017, 'bpp': 0.14832915365695953}
kodak image 4 

{'psnr': 26.63875961303711, 'ssim': 0.8925634250044823, 'vif': 1.007175326347351, 'mse': 0.0021683229133486748, 'mae': 0.03394967317581177, 'bpp': 0.3173295259475708}
kodak image 5 

{'psnr': 27.73731422424

In [13]:
import pandas as pd
import re
import math

def parse_metrics(text):
    # Create a list to store the data
    data = []

    # Extract model and quality from the first two lines
    lines = text.strip().split('\n')
    model_match = re.search(r'Model: ([\w-]+)', lines[0])
    quality_match = re.search(r'Quality: (\d+)', lines[1])

    if not model_match or not quality_match:
        raise ValueError("Model or Quality information not found in the text")

    model = model_match.group(1)
    quality = int(quality_match.group(1))

    # Parse each kodak image and its metrics
    pattern = r'kodak image (\d+)\s+\n\s*{\'psnr\': ([\d.]+), \'ssim\': ([\d.]+), \'vif\': ([\d.]+), \'mse\': ([\d.]+), \'mae\': ([\d.]+), \'bpp\': ([\d.]+)}'

    matches = re.finditer(pattern, text)

    for match in matches:
        kodak_id = int(match.group(1))
        psnr = round(float(match.group(2)), 2)
        ssim = round(float(match.group(3)), 4)
        vif = round(float(match.group(4)), 4)
        mse = float(match.group(5))
        mae = float(match.group(6))
        bpp = round(float(match.group(7)), 4)

        # Format MSE and MAE as scientific notation
        def format_scientific(value):
            power = math.floor(math.log10(value)) if value > 0 else 0
            mantissa = value / (10 ** power)
            mantissa_rounded = round(mantissa, 2)
            return f"{mantissa_rounded} × 10^{power}"

        mse_formatted = format_scientific(mse)
        mae_formatted = format_scientific(mae)

        # Append the data
        data.append({
            'model': model,
            'quality': quality,
            'kodak': kodak_id,
            'psnr': psnr,
            'ssim': ssim,
            'vif': vif,
            'bpp': bpp,
            'mse': mse_formatted,
            'mae': mae_formatted
        })

    # Create DataFrame
    df = pd.DataFrame(data)

    return df

In [25]:
text = """
Model: cheng2020-anchor
Quality: 1
kodak image 0

{'psnr': 26.30901336669922, 'ssim': 0.8591998368501663, 'vif': 0.9961791634559631, 'mse': 0.0023393684532493353, 'mae': 0.03507440164685249, 'bpp': 0.2561964690685272}
kodak image 1

{'psnr': 30.362871170043945, 'ssim': 0.8805487528443336, 'vif': 1.0018914937973022, 'mse': 0.0009198412299156189, 'mae': 0.02001667581498623, 'bpp': 0.13375042378902435}
kodak image 2

{'psnr': 31.884979248046875, 'ssim': 0.9284780696034431, 'vif': 0.9986288547515869, 'mse': 0.0006478911964222789, 'mae': 0.01611577346920967, 'bpp': 0.1321416050195694}
kodak image 3

{'psnr': 30.18390655517578, 'ssim': 0.8880719095468521, 'vif': 1.0166760683059692, 'mse': 0.0009585379157215357, 'mae': 0.021163173019886017, 'bpp': 0.14832915365695953}
kodak image 4

{'psnr': 26.63875961303711, 'ssim': 0.8925634250044823, 'vif': 1.007175326347351, 'mse': 0.0021683229133486748, 'mae': 0.03394967317581177, 'bpp': 0.3173295259475708}
kodak image 5

{'psnr': 27.737314224243164, 'ssim': 0.879454955458641, 'vif': 1.0023471117019653, 'mse': 0.0016837151488289237, 'mae': 0.028307795524597168, 'bpp': 0.2174561768770218}
kodak image 6

{'psnr': 31.026004791259766, 'ssim': 0.9475313574075699, 'vif': 1.0152932405471802, 'mse': 0.0007895861635915935, 'mae': 0.017938323318958282, 'bpp': 0.17956428229808807}
kodak image 7

{'psnr': 26.44485855102539, 'ssim': 0.9063182696700096, 'vif': 1.0055928230285645, 'mse': 0.0022673262283205986, 'mae': 0.03422631695866585, 'bpp': 0.34532618522644043}
kodak image 8

{'psnr': 31.581838607788086, 'ssim': 0.933275505900383, 'vif': 1.0107966661453247, 'mse': 0.0006947302026674151, 'mae': 0.016897819936275482, 'bpp': 0.14796313643455505}
kodak image 9

{'psnr': 31.319839477539062, 'ssim': 0.9224376007914543, 'vif': 1.01041841506958, 'mse': 0.0007379313465207815, 'mae': 0.018149826675653458, 'bpp': 0.15688613057136536}
kodak image 10

{'psnr': 28.610923767089844, 'ssim': 0.8770408108830452, 'vif': 1.007388710975647, 'mse': 0.0013769167708232999, 'mae': 0.02506311796605587, 'bpp': 0.19066569209098816}
kodak image 11

{'psnr': 31.550113677978516, 'ssim': 0.9049861058592796, 'vif': 1.0070335865020752, 'mse': 0.0006998236058279872, 'mae': 0.017694152891635895, 'bpp': 0.13169920444488525}
kodak image 12

{'psnr': 24.360429763793945, 'ssim': 0.821920320391655, 'vif': 1.0108064413070679, 'mse': 0.0036640134640038013, 'mae': 0.04489139840006828, 'bpp': 0.35477039217948914}
kodak image 13

{'psnr': 27.386024475097656, 'ssim': 0.8582043051719666, 'vif': 1.003862738609314, 'mse': 0.00182556570507586, 'mae': 0.030680980533361435, 'bpp': 0.23075437545776367}
kodak image 14

{'psnr': 30.45911407470703, 'ssim': 0.9064039438962936, 'vif': 0.9984448552131653, 'mse': 0.0008996811229735613, 'mae': 0.01888301782310009, 'bpp': 0.14671149849891663}
kodak image 15

{'psnr': 29.776084899902344, 'ssim': 0.8861103355884552, 'vif': 1.0073717832565308, 'mse': 0.0010529104620218277, 'mae': 0.022234149277210236, 'bpp': 0.15013062953948975}
kodak image 16

{'psnr': 30.4141845703125, 'ssim': 0.9140203297138214, 'vif': 1.014340877532959, 'mse': 0.0009090368403121829, 'mae': 0.020392248407006264, 'bpp': 0.16392335295677185}
kodak image 17

{'psnr': 26.849254608154297, 'ssim': 0.8612052649259567, 'vif': 1.0206023454666138, 'mse': 0.0020657337736338377, 'mae': 0.031550582498311996, 'bpp': 0.24136197566986084}
kodak image 18

{'psnr': 29.282119750976562, 'ssim': 0.8884568884968758, 'vif': 0.9966604113578796, 'mse': 0.0011797449551522732, 'mae': 0.023714151233434677, 'bpp': 0.17560961842536926}
kodak image 19

{'psnr': 31.153356552124023, 'ssim': 0.9222647100687027, 'vif': 1.003928303718567, 'mse': 0.0007667684112675488, 'mae': 0.015547649934887886, 'bpp': 0.14483647048473358}
kodak image 20

{'psnr': 28.26136016845703, 'ssim': 0.9095698595046997, 'vif': 0.9989607930183411, 'mse': 0.0014923264971002936, 'mae': 0.024758940562605858, 'bpp': 0.20910349488258362}
kodak image 21

{'psnr': 28.45888328552246, 'ssim': 0.8580940067768097, 'vif': 1.0223065614700317, 'mse': 0.0014259741874411702, 'mae': 0.026282411068677902, 'bpp': 0.1726289540529251}
kodak image 22

{'psnr': 32.677894592285156, 'ssim': 0.9441062957048416, 'vif': 1.0001300573349, 'mse': 0.0005397722707130015, 'mae': 0.015248828567564487, 'bpp': 0.13730290532112122}
kodak image 23

{'psnr': 26.907899856567383, 'ssim': 0.8817514702677727, 'vif': 0.9971997141838074, 'mse': 0.002038027159869671, 'mae': 0.029491880908608437, 'bpp': 0.25704845786094666}
"""

In [26]:
df_cheng_anchor_1 = parse_metrics(text)

In [27]:
df_cheng_anchor_1

Unnamed: 0,model,quality,kodak,psnr,ssim,vif,bpp,mse,mae
0,cheng2020-anchor,1,0,26.31,0.8592,0.9962,0.2562,2.34 × 10^-3,3.51 × 10^-2
1,cheng2020-anchor,1,1,30.36,0.8805,1.0019,0.1338,9.2 × 10^-4,2.0 × 10^-2
2,cheng2020-anchor,1,2,31.88,0.9285,0.9986,0.1321,6.48 × 10^-4,1.61 × 10^-2
3,cheng2020-anchor,1,3,30.18,0.8881,1.0167,0.1483,9.59 × 10^-4,2.12 × 10^-2
4,cheng2020-anchor,1,4,26.64,0.8926,1.0072,0.3173,2.17 × 10^-3,3.39 × 10^-2
5,cheng2020-anchor,1,5,27.74,0.8795,1.0023,0.2175,1.68 × 10^-3,2.83 × 10^-2
6,cheng2020-anchor,1,6,31.03,0.9475,1.0153,0.1796,7.9 × 10^-4,1.79 × 10^-2
7,cheng2020-anchor,1,7,26.44,0.9063,1.0056,0.3453,2.27 × 10^-3,3.42 × 10^-2
8,cheng2020-anchor,1,8,31.58,0.9333,1.0108,0.148,6.95 × 10^-4,1.69 × 10^-2
9,cheng2020-anchor,1,9,31.32,0.9224,1.0104,0.1569,7.38 × 10^-4,1.81 × 10^-2


In [28]:
text = """
Model: cheng2020-attn
Quality: 1
kodak image 0

{'psnr': 26.218666076660156, 'ssim': 0.858748123049736, 'vif': 0.9778406620025635, 'mse': 0.002388544613495469, 'mae': 0.035488061606884, 'bpp': 0.24758204817771912}
kodak image 1

{'psnr': 30.190128326416016, 'ssim': 0.8785943686962128, 'vif': 0.9948579668998718, 'mse': 0.0009571655537001789, 'mae': 0.02022584341466427, 'bpp': 0.12960465252399445}
kodak image 2

{'psnr': 31.828983306884766, 'ssim': 0.9286216646432877, 'vif': 0.9971141815185547, 'mse': 0.0006562990020029247, 'mae': 0.01617887243628502, 'bpp': 0.13059023022651672}
kodak image 3

{'psnr': 30.137081146240234, 'ssim': 0.888008750975132, 'vif': 1.0022459030151367, 'mse': 0.0009689287398941815, 'mae': 0.021233996376395226, 'bpp': 0.1440519094467163}
kodak image 4

{'psnr': 26.52194595336914, 'ssim': 0.8912339210510254, 'vif': 0.994297981262207, 'mse': 0.002227437449619174, 'mae': 0.03440535441040993, 'bpp': 0.3184228837490082}
kodak image 5

{'psnr': 27.60310173034668, 'ssim': 0.877206839621067, 'vif': 0.992169201374054, 'mse': 0.001736560370773077, 'mae': 0.028715282678604126, 'bpp': 0.2112560123205185}
kodak image 6

{'psnr': 30.96571922302246, 'ssim': 0.9474280513823032, 'vif': 0.988914966583252, 'mse': 0.0008006229181773961, 'mae': 0.017955221235752106, 'bpp': 0.17768310010433197}
kodak image 7

{'psnr': 26.23619270324707, 'ssim': 0.9042947366833687, 'vif': 1.0014811754226685, 'mse': 0.002378924284130335, 'mae': 0.03497788682579994, 'bpp': 0.33221006393432617}
kodak image 8

{'psnr': 31.50636863708496, 'ssim': 0.9333807602524757, 'vif': 0.9967880845069885, 'mse': 0.0007069083512760699, 'mae': 0.017035720869898796, 'bpp': 0.1478043794631958}
kodak image 9

{'psnr': 31.24069595336914, 'ssim': 0.9221312403678894, 'vif': 1.0110564231872559, 'mse': 0.0007515022880397737, 'mae': 0.01825704053044319, 'bpp': 0.1548675149679184}
kodak image 10

{'psnr': 28.466625213623047, 'ssim': 0.8748100996017456, 'vif': 0.9912469983100891, 'mse': 0.0014234344707801938, 'mae': 0.025419948622584343, 'bpp': 0.1872386932373047}
kodak image 11

{'psnr': 31.416900634765625, 'ssim': 0.903707429766655, 'vif': 0.9968979358673096, 'mse': 0.0007216225494630635, 'mae': 0.01799546740949154, 'bpp': 0.12803281843662262}
kodak image 12

{'psnr': 24.239980697631836, 'ssim': 0.820887878537178, 'vif': 1.004707932472229, 'mse': 0.003767054993659258, 'mae': 0.04529482498764992, 'bpp': 0.35027793049812317}
kodak image 13

{'psnr': 27.260072708129883, 'ssim': 0.855050340294838, 'vif': 0.9989555478096008, 'mse': 0.0018792860209941864, 'mae': 0.030994202941656113, 'bpp': 0.22865602374076843}
kodak image 14

{'psnr': 30.32689094543457, 'ssim': 0.9051331207156181, 'vif': 0.9930492043495178, 'mse': 0.0009274935000576079, 'mae': 0.019221246242523193, 'bpp': 0.139611154794693}
kodak image 15

{'psnr': 29.72078514099121, 'ssim': 0.8843558356165886, 'vif': 0.9956292510032654, 'mse': 0.0010664034634828568, 'mae': 0.022319370880723, 'bpp': 0.14951805770397186}
kodak image 16

{'psnr': 30.293073654174805, 'ssim': 0.91244987398386, 'vif': 1.003495454788208, 'mse': 0.0009347437298856676, 'mae': 0.020696334540843964, 'bpp': 0.16402016580104828}
kodak image 17

{'psnr': 26.737628936767578, 'ssim': 0.860674574971199, 'vif': 1.0121077299118042, 'mse': 0.0021195183508098125, 'mae': 0.031804170459508896, 'bpp': 0.23870344460010529}
kodak image 18

{'psnr': 29.172679901123047, 'ssim': 0.8852472603321075, 'vif': 1.0086029767990112, 'mse': 0.001209851005114615, 'mae': 0.0240012276917696, 'bpp': 0.16950908303260803}
kodak image 19

{'psnr': 31.004535675048828, 'ssim': 0.9207182452082634, 'vif': 0.9865676760673523, 'mse': 0.0007934988825581968, 'mae': 0.01609666645526886, 'bpp': 0.14015258848667145}
kodak image 20

{'psnr': 28.210941314697266, 'ssim': 0.9091996476054192, 'vif': 0.9969952702522278, 'mse': 0.0015097534051164985, 'mae': 0.024869389832019806, 'bpp': 0.20932923257350922}
kodak image 21

{'psnr': 28.357973098754883, 'ssim': 0.8558917045593262, 'vif': 1.005478858947754, 'mse': 0.0014594956301152706, 'mae': 0.026614878326654434, 'bpp': 0.17000263929367065}
kodak image 22

{'psnr': 32.436866760253906, 'ssim': 0.9423518255352974, 'vif': 0.99806147813797, 'mse': 0.0005705758230760694, 'mae': 0.01562279462814331, 'bpp': 0.1364564597606659}
kodak image 23

{'psnr': 26.795074462890625, 'ssim': 0.8814313411712646, 'vif': 0.989131510257721, 'mse': 0.0020916671492159367, 'mae': 0.02968626283109188, 'bpp': 0.25088948011398315}
"""

In [29]:
df_cheng_attn_1 = parse_metrics(text)

In [30]:
df_cheng_attn_1

Unnamed: 0,model,quality,kodak,psnr,ssim,vif,bpp,mse,mae
0,cheng2020-attn,1,0,26.22,0.8587,0.9778,0.2476,2.39 × 10^-3,3.55 × 10^-2
1,cheng2020-attn,1,1,30.19,0.8786,0.9949,0.1296,9.57 × 10^-4,2.02 × 10^-2
2,cheng2020-attn,1,2,31.83,0.9286,0.9971,0.1306,6.56 × 10^-4,1.62 × 10^-2
3,cheng2020-attn,1,3,30.14,0.888,1.0022,0.1441,9.69 × 10^-4,2.12 × 10^-2
4,cheng2020-attn,1,4,26.52,0.8912,0.9943,0.3184,2.23 × 10^-3,3.44 × 10^-2
5,cheng2020-attn,1,5,27.6,0.8772,0.9922,0.2113,1.74 × 10^-3,2.87 × 10^-2
6,cheng2020-attn,1,6,30.97,0.9474,0.9889,0.1777,8.01 × 10^-4,1.8 × 10^-2
7,cheng2020-attn,1,7,26.24,0.9043,1.0015,0.3322,2.38 × 10^-3,3.5 × 10^-2
8,cheng2020-attn,1,8,31.51,0.9334,0.9968,0.1478,7.07 × 10^-4,1.7 × 10^-2
9,cheng2020-attn,1,9,31.24,0.9221,1.0111,0.1549,7.52 × 10^-4,1.83 × 10^-2


In [31]:
df_cheng_anchor_1

Unnamed: 0,model,quality,kodak,psnr,ssim,vif,bpp,mse,mae
0,cheng2020-anchor,1,0,26.31,0.8592,0.9962,0.2562,2.34 × 10^-3,3.51 × 10^-2
1,cheng2020-anchor,1,1,30.36,0.8805,1.0019,0.1338,9.2 × 10^-4,2.0 × 10^-2
2,cheng2020-anchor,1,2,31.88,0.9285,0.9986,0.1321,6.48 × 10^-4,1.61 × 10^-2
3,cheng2020-anchor,1,3,30.18,0.8881,1.0167,0.1483,9.59 × 10^-4,2.12 × 10^-2
4,cheng2020-anchor,1,4,26.64,0.8926,1.0072,0.3173,2.17 × 10^-3,3.39 × 10^-2
5,cheng2020-anchor,1,5,27.74,0.8795,1.0023,0.2175,1.68 × 10^-3,2.83 × 10^-2
6,cheng2020-anchor,1,6,31.03,0.9475,1.0153,0.1796,7.9 × 10^-4,1.79 × 10^-2
7,cheng2020-anchor,1,7,26.44,0.9063,1.0056,0.3453,2.27 × 10^-3,3.42 × 10^-2
8,cheng2020-anchor,1,8,31.58,0.9333,1.0108,0.148,6.95 × 10^-4,1.69 × 10^-2
9,cheng2020-anchor,1,9,31.32,0.9224,1.0104,0.1569,7.38 × 10^-4,1.81 × 10^-2


# Experiment 2. Classification of Decompressed Images on ISIC dataset

## 1. Testing Pre-trained Model

In [4]:
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
import numpy as np
import os

In [1]:
label_map = {
    0: 'pigmented benign keratosis',
    1: 'melanoma',
    2: 'vascular lesion',
    3: 'actinic keratosis',
    4: 'squamous cell carcinoma',
    5: 'basal cell carcinoma',
    6: 'seborrheic keratosis',
    7: 'dermatofibroma',
    8: 'nevus'
    }

In [3]:
clf_model_path = "skin_disease_model.h5"
clf_model = load_model(clf_model_path)

2025-03-07 12:58:55.094112: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:268] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


In [8]:
def get_pred_label(image_path, model, label_map):
    img = image.load_img(image_path, target_size=(75, 100))  # Correct order: height=75px, width=100px
    img_array = image.img_to_array(img)  # Convert image to array
    img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension

    # Normalize the image as per the training data preprocessing
    img_array = (img_array - np.mean(img_array)) / np.std(img_array)

    # Predict the class
    predictions = model.predict(img_array)
    predicted_class = np.argmax(predictions, axis=1)

    # Map the predicted class index to the class label
    predicted_label = label_map[predicted_class[0]]

    return predicted_label

In [10]:
test_len = 0

for label in label_map.values():
    image_dir = os.path.join('./image_compression/ISIC-skin-cancer/Test', label)
    for image_file in os.listdir(image_dir):
        if (image_file.endswith(".jpg")):
            test_len += 1

print("Test dataset len: ", test_len)

Test dataset len:  118


In [None]:
acc = 0.

for label in label_map.values():
    image_dir = os.path.join('./image_compression/ISIC-skin-cancer/Test', label)
    for image_file in os.listdir(image_dir):
        if (image_file.endswith(".jpg")):
            image_path = os.path.join(image_dir, image_file)
            pred_label = get_pred_label(image_path, clf_model, label_map)
            print("Ground truth: ", label)
            print("Predicted: ", pred_label)
            acc += pred_label == label

acc /= test_len
print("Accuracy: ", acc)


Ground truth:  pigmented benign keratosis
Predicted:  squamous cell carcinoma
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  basal cell carcinoma
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  actinic keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Pred

## 1.2. Compressing and decompressing testing images

In [16]:
import torch
from compressai.zoo import models
from torchvision import transforms
import torchvision.transforms.functional as F
import PIL
import numpy as np
from skimage import io , img_as_float
import pathlib

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # setting the computation on GPU or CPU
print("Device: ", device)

Device:  cpu


In [14]:
def pad_to_multiple(img, mul=64):
    w, h = img.size
    pad_width = (mul - w % mul) % mul
    pad_height = (mul - h % mul) % mul
    
    return F.pad(img, (0, 0, pad_width, pad_height), fill=0)

In [17]:
for label in label_map.values():
    isic_path = []
    image_dir = os.path.join('./image_compression/ISIC-skin-cancer/Test', label)
    for image_file in os.listdir(image_dir):
        if (image_file.endswith(".jpg")):
            image_path = os.path.join(image_dir, image_file)
            isic_path.append(image_path)

    isic_PIL = [pad_to_multiple(PIL.Image.open(i).convert('RGB')) for i in isic_path]
    # isic_PIL = [PIL.Image.open(i).convert('RGB') for i in isic_path]
    preprocess = transforms.Compose([transforms.ToTensor()])
    isic_Tensor = [preprocess(img).unsqueeze(0) for img in isic_PIL]
    isic_normalized = [img_as_float(io.imread(img_path)) for img_path in isic_path]

    save_dir = os.path.join(image_dir, "decompressed")
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    for model_name, quality in [("cheng2020-anchor", 1), ("cheng2020-attn", 1)]:
        torch.cuda.empty_cache()
        model_class = models[model_name]
        model = model_class(quality=quality, pretrained=True).to(device)
        for param in model.parameters():
            param.requires_grad = False
        for i in range(len(isic_Tensor)):
            x = isic_Tensor[i].clone().to(device)
            # Compress and decompress the image
            decompressed = model(x)["x_hat"].to(device)
            decompressed = decompressed.clamp(0, 1)
            # Save the decompressed image
            decompressed = decompressed.detach()
            y = decompressed.to(device)
            decompressed_np = decompressed.detach().squeeze(0).cpu().permute(1, 2, 0).numpy()
            decompressed_np = (decompressed_np * 255).astype(np.uint8)
            decompressed_numpy = PIL.Image.fromarray(decompressed_np)

            image_name = pathlib.PurePath(isic_path[i]).stem
            model_save_dir = os.path.join(save_dir, f"{model_class.__name__}_{quality}")
            if not os.path.exists(model_save_dir):
                os.makedirs(model_save_dir)

            decompressed_numpy.save(f"{model_save_dir}/{image_name}_decompressed.png")


## 1.3. Classification on decompressed images

### 1.3.1. Cheng-anchor Quality 1

In [18]:
acc = 0.
model_name = "cheng2020_anchor"
quality = "1"

for label in label_map.values():
    image_dir = os.path.join('./image_compression/ISIC-skin-cancer/Test', label, "decompressed", f"{model_name}_{quality}")
    for image_file in os.listdir(image_dir):
        if (image_file.endswith(".png")):
            image_path = os.path.join(image_dir, image_file)
            pred_label = get_pred_label(image_path, clf_model, label_map)
            print("Ground truth: ", label)
            print("Predicted: ", pred_label)
            acc += pred_label == label

acc /= test_len
print("Accuracy: ", acc)


Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  vascular lesion
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  nevus
Ground truth:  pigmented benign keratosis
Predicted:  nevus
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis


### 1.3.2. Cheng-attn Quality 1

In [19]:
acc = 0.
model_name = "cheng2020_attn"
quality = "1"

for label in label_map.values():
    image_dir = os.path.join('./image_compression/ISIC-skin-cancer/Test', label, "decompressed", f"{model_name}_{quality}")
    for image_file in os.listdir(image_dir):
        if (image_file.endswith(".png")):
            image_path = os.path.join(image_dir, image_file)
            pred_label = get_pred_label(image_path, clf_model, label_map)
            print("Ground truth: ", label)
            print("Predicted: ", pred_label)
            acc += pred_label == label

acc /= test_len
print("Accuracy: ", acc)


Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign keratosis
Ground truth:  pigmented benign keratosis
Predicted:  nevus
Ground truth:  pigmented benign keratosis
Predicted:  nevus
Ground truth:  pigmented benign keratosis
Predicted:  pigmented benign