In [1]:
import torch, random
from scripts.sample_diffusion import *
import argparse, os, sys, glob, datetime, yaml
# Explicit import: pickle.dump is used when saving metrics, so do not rely on
# a wildcard import happening to export it.
import pickle
from omegaconf import OmegaConf
import torchvision
import PIL
from PIL import Image
from torchmetrics.functional import psnr, multiscale_structural_similarity_index_measure
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from ldm_testing import *
from jscc_baseline_testing import deep_jscc_testing
from torch.utils.data import DataLoader, random_split, Subset
from ldm_testing import compute_metrics
from einops import rearrange
from tqdm import tqdm

In [2]:
# Show the GPUs visible to this kernel (the model below is moved to "cuda").
!nvidia-smi

Tue Jul 25 15:46:27 2023       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A40                      On | 00000000:23:00.0 Off |                    0 |
|  0%   26C    P8               27W / 300W|      0MiB / 46068MiB |      0%   E. Process |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                         

In [3]:
# Channel SNR used for this evaluation run, in decibels.
test_snr_dB = 10
# Same SNR expressed as a linear power ratio: 10^(dB / 10).
test_snr_ratio = 10 ** (test_snr_dB / 10)

In [4]:
# Load the experiment config and set the channel SNR on it before the model
# is built from that same config.
config = OmegaConf.load(
    "/home/ashri/latent-diffusion/configs/latent-diffusion/ldm_all_models_iterative.yaml"
)
config.model.params.channel_snr_dB = test_snr_dB

# Instantiate the model described by the config and move it onto the GPU.
ldm_posterior_model = instantiate_from_config(config["model"]).to("cuda")
# Alternative kept for reference: build the model via load_model with an explicit checkpoint path.
# ldm_posterior_model, _ = load_model(config, checkpoint_path,True, False)

LatentDiffusionPosteriorJSCC: Running in eps-prediction mode
DiffusionWrapper has 274.06 M params.
Keeping EMAs of 370.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 3, 64, 64) = 12288 dimensions.
making attention of type 'vanilla' with 512 in_channels
Loading model from models/ldm/lsun_beds256/autoencoder.ckpt...
Total Params: 55298206
Total Trainable Params: 55298206
Training LatentDiffusionPosteriorJSCC as an unconditional model.


In [5]:
# Directory that receives the pickled metrics file.
LOGDIR = "post_exp_metrics"
# Fraction of the validation dataset that is sampled for evaluation.
data_fraction = 0.1

In [6]:
# Point the validation split of the config at the local LSUN bedrooms files.
root_dir = "/tmpssd/ashri/LSUN/"
val_txt_path = os.path.join(root_dir, "bedrooms_val.txt")
val_data_dir = os.path.join(root_dir, "bedrooms")

# Both overrides live under the same config node, so grab it once.
validation_params = config.data.params.validation.params
validation_params.txt_file = val_txt_path
validation_params.data_root = val_data_dir

In [7]:
# Build the validation dataset and a reproducible random subset of it.
main_dataset = instantiate_from_config(config.data.params.validation)
seed_generator = torch.Generator().manual_seed(42)

# Draw the subset indices through the seeded generator so the exact same
# examples are evaluated on every run (an unseeded random.sample would pick a
# different subset each time, making metrics incomparable across runs).
subset_size = int(data_fraction * len(main_dataset))
random_indices = torch.randperm(len(main_dataset), generator=seed_generator)[:subset_size].tolist()

dataset = Subset(main_dataset, random_indices)
dataloader = DataLoader(dataset, batch_size=2, num_workers=2)

In [8]:
# Sanity check: number of examples in the evaluation subset.
print(len(dataset))

500


In [9]:
def normalize(x: torch.Tensor) -> torch.Tensor:
    """Min-max rescale ``x`` so its smallest value maps to 0 and its largest to 1.

    Args:
        x: Non-empty tensor to rescale.

    Returns:
        A tensor of the same shape as ``x`` holding ``(x - min) / (max - min)``.
        When every element of ``x`` is equal (zero range) the result is all
        zeros rather than NaN.
    """
    lo = x.min()
    span = x.max() - lo
    # Substitute 1 for a zero range: the numerator is already all zeros in
    # that case, so this only avoids the 0/0 -> NaN division.
    safe_span = torch.where(span == 0, torch.ones_like(span), span)
    return (x - lo) / safe_span

In [10]:
def scale_function(img):
    """Affinely rescale ``img`` as ``2 * img - 1.0`` (maps [0, 1] onto [-1, 1])."""
    return 2 * img - 1.0


ScaleShiftTransform = torchvision.transforms.Lambda(scale_function)
# Reference ToTensor through the imported torchvision package, like Lambda
# above, instead of depending on a wildcard import to provide the bare name.
data_transform = torchvision.transforms.ToTensor()

In [11]:
# One accumulator per evaluation metric, each starting at zero.
metric_dict = dict.fromkeys(("PSNR", "MS_SSIM", "FID", "LPIPS_VGG"), 0.0)

In [None]:
# Evaluate posterior-sampling reconstructions at `test_snr_dB` over the
# evaluation subset, average the metrics, and pickle them under LOGDIR.
total_examples = len(dataset)

# Accumulate into a fresh dict so that re-running this cell never adds new
# scores on top of totals/averages left in `metric_dict` by an earlier run.
metric_sums = dict.fromkeys(metric_dict, 0.0)

for image_dict in tqdm(dataloader):
    # Batches arrive channels-last; the first-stage model is fed channels-first.
    images = rearrange(image_dict["image"].cuda(), 'b h w c -> b c h w')
    codewords = ldm_posterior_model.first_stage_model.encode(images)

    # AWGN channel. SNR = signal_power / noise_power, so the noise *variance*
    # is signal_power / SNR and unit-variance samples must be scaled by its
    # square root (scaling by signal_power / SNR directly gives the wrong SNR).
    signal_power = torch.mean(codewords**2)
    noise_std = torch.sqrt(signal_power / test_snr_ratio)
    noisy_codeword = codewords + torch.randn_like(codewords) * noise_std
    sampled_codeword = ldm_posterior_model.posterior_sampling(test_snr_dB, noisy_codeword)

    # Clamp the decoded images to [-1, 1] before scoring them against the inputs.
    reconstructed = torch.clamp(ldm_posterior_model.first_stage_model.decode(sampled_codeword), -1.0, 1.0)
    scores = compute_metrics(images, reconstructed)
    # `scores` is indexed positionally in the same order as the metric keys.
    for idx, key in enumerate(metric_sums.keys()):
        metric_sums[key] += scores[idx]

# NOTE(review): totals are added once per batch but divided by the number of
# images; this is only a true mean if compute_metrics returns per-batch sums
# rather than per-batch averages -- confirm against ldm_testing.compute_metrics.
for key in metric_sums.keys():
    metric_dict[key] = metric_sums[key]/total_examples

# Make sure the output directory exists before writing into it.
os.makedirs(LOGDIR, exist_ok=True)
filename = os.path.join(LOGDIR, f"posterior_test_{test_snr_dB}_metrics.pkl")

with open(filename, "wb") as fp:
    pickle.dump(metric_dict, fp)

 83%|████████▎ | 208/250 [2:13:27<25:47, 36.85s/it]  

In [None]:
# codeword = ldm_posterior_model.first_stage_model.encode(tensor_image)
# signal_power = torch.mean(codeword**2)
# noisy_codeword = codeword + torch.randn_like(codeword)*signal_power/test_snr_ratio

In [None]:
# sampled_codeword = ldm_posterior_model.posterior_sampling(test_snr_dB, noisy_codeword, scale_grad = 2.0, verbose=True)

In [None]:
# recon = ldm_posterior_model.first_stage_model.decode(sampled_codeword)

In [None]:
# recon_image = (custom_to_pil(recon[0]))
# recon_image

In [None]:
# recon_image.save("recon_prior2.png")

In [None]:
# psnr(normalize(recon[0]), 0.5*(tensor_image[0]+1.0))