In [None]:
from google.colab import drive
# Mount Google Drive so the test data, checkpoint and results folders under
# /content/drive/MyDrive (used by later cells) are reachable from this runtime.
drive.mount('/content/drive')

In [None]:
!git clone https://github.com/SerdarHelli/latent-diffusion_teeth_sub_repo.git
!git clone https://github.com/CompVis/taming-transformers
!pip install -e /content/taming-transformers
!pip install  torchmetrics==0.6.0
!pip install omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops
!git clone https://github.com/openai/CLIP
!pip install -e /content/CLIP
!pip install  kornia==0.5.0

In [None]:

import sys
# Make the cloned repositories importable (e.g. the `ldm` package used below).
sys.path.extend([
    "/content/latent-diffusion",
    '/content/taming-transformers',
    "/content/CLIP",
])

In [None]:
# Show the GPU attached to this runtime (the model is moved to CUDA further down).
!nvidia-smi


In [None]:
import os
DATA_PATH="/content/latent-diffusion/data/Tufts_Raw_Test"
if not os.path.isdir(DATA_PATH):
  os.makedirs(DATA_PATH)

!cp /content/drive/MyDrive/Tufs_Raw_Test/* /content/latent-diffusion/data/Tufts_Raw_Test

In [None]:
# Work from the repo root (the same path appended to sys.path above).
%cd /content/latent-diffusion

In [None]:
import torch
import numpy as np

from omegaconf import OmegaConf
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from einops import rearrange

from ldm.data.teethseg import SegmentationBase

class TeethSegTest(SegmentationBase):
    """Teeth segmentation dataset read from the local Tufts_Raw_Test folder.

    Only ``size`` is used (forwarded to the base class as ``img_dim``);
    ``random_crop`` and ``interpolation`` are accepted but ignored.
    """

    def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
        super().__init__(
            file_path="/content/latent-diffusion/data/Tufts_Raw_Test",
            img_dim=size,
            data_flip=True,
            with_abnormality=True,
            apply_flip=False,
        )


# Model config shipped with the repo and the trained checkpoint stored on Drive.
config_path = '/content/latent-diffusion/models/ldm/teeth/config.yaml'
ckpt_path = '/content/drive/MyDrive/2022-11-01T19-41-14_config/checkpoints/epoch=000179.ckpt'

dataset = TeethSegTest(size=256)



In [None]:
import torch
from omegaconf import OmegaConf

from ldm.util import instantiate_from_config


def load_model_from_config(config, ckpt):
    """Build the model described by ``config.model`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: path to a checkpoint file holding a ``"state_dict"`` entry.

    Returns:
        The instantiated model, moved to CUDA and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    # Load onto the CPU first. Without map_location, tensors are restored to the
    # device they were saved from, which duplicates them on the GPU (or fails if
    # that device does not exist here); the model is moved to CUDA below anyway.
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(sd, strict=False)
    # strict=False tolerates key mismatches; report them instead of discarding them.
    if missing:
        print(f"Missing keys ({len(missing)}): {missing}")
    if unexpected:
        print(f"Unexpected keys ({len(unexpected)}): {unexpected}")
    model.cuda()
    model.eval()
    return model


def get_model():
    """Load the model from the module-level ``config_path`` and ``ckpt_path``."""
    return load_model_from_config(OmegaConf.load(config_path), ckpt_path)

model = get_model()


In [None]:
# Shuffling only changes which batch comes last (shown by the grid cells below);
# samples, reals and conditions are appended per batch, so they stay aligned.
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

# DDIM sampling hyperparameters.
DDIM_STEPS = 200
DDIM_ETA = 1.

x_samples_ddim = list()
x_reals_ddim = list()
x_condition_ddim = list()
for i, data in enumerate(dataloader):
    print(i)
    seg = data['segmentation']
    # Actual size of this batch: the last batch is smaller whenever len(dataset)
    # is not a multiple of the DataLoader batch size, so it must not be hard-coded.
    n_batch = seg.shape[0]
    with torch.no_grad():
        seg = rearrange(seg, 'b h w c -> b c h w')
        # RGB rendering of the segmentation, kept for the grid visualisation cells.
        condition = model.to_rgb(seg)
        seg = seg.to('cuda').float()

        seg = model.get_learned_conditioning(seg)
        samples, _ = model.sample_log(cond=seg, batch_size=n_batch, ddim=True,
                                      ddim_steps=DDIM_STEPS, eta=DDIM_ETA)
        samples = model.decode_first_stage(samples)
        # Map decoder output to [0, 1] for saving / metrics.
        samples = torch.clamp((samples+1.0)/2.0, min=0.0, max=1.0)
        samples = rearrange(samples, ' b c h w -> b h w c')

        x_samples_ddim.append(samples.cpu().detach().numpy())
        x_reals_ddim.append(data["image"].cpu().detach().numpy())
        x_condition_ddim.append(data["segmentation"].cpu().detach().numpy())
  


In [None]:
import matplotlib.pyplot as plt

# Stack the per-batch arrays (each b h w c) into single (N, H, W, C) arrays.
# Concatenating along the batch axis works for any dataset size and a smaller
# final batch, unlike np.asarray + a reshape hard-coded to 100 x 256 x 256.
x_samples = np.concatenate(x_samples_ddim, axis=0)
x_reals = np.concatenate(x_reals_ddim, axis=0)
x_condition = np.concatenate(x_condition_ddim, axis=0)

# Generated samples, real images and conditions must stay paired item-for-item.
assert len(x_samples) == len(x_reals) == len(x_condition)



In [None]:
import matplotlib.pyplot as plt

# Preview the first generated sample (values in [0, 1], scaled to uint8 for display).
fig, ax = plt.subplots()
ax.imshow(np.uint8(x_samples[0]*255))
ax.set_title("Generated sample #0")
ax.axis("off");

In [None]:

from math import log10, sqrt
import cv2
import numpy as np
  
def PSNR(original, compressed, max_pixel=255.0):
    """Peak signal-to-noise ratio between two images, in dB.

    Args:
        original: reference image array.
        compressed: image array of the same shape to compare against it.
        max_pixel: largest possible pixel value (255 for 8-bit images).

    Returns:
        ``20 * log10(max_pixel / sqrt(MSE))``, or 100 when the images are
        identical (MSE == 0), where PSNR is unbounded.
    """
    # Compute in float: with uint8 inputs (as produced by convert_uint8) the
    # subtraction and squaring would wrap around modulo 256 and corrupt the MSE.
    diff = np.asarray(original, dtype=np.float64) - np.asarray(compressed, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        # No noise at all; PSNR is infinite, report a fixed sentinel instead.
        return 100
    return 20 * log10(max_pixel / sqrt(mse))


def get_psnr(fakes,reals):
    """Per-image PSNR for two paired image batches.

    Args:
        fakes: batch indexed as ``fakes[i, :, :, :]`` (generated images).
        reals: batch with the same layout (reference images).

    Returns:
        1-D array holding ``PSNR(reals[i], fakes[i])`` for each i in ``range(len(reals))``.
    """
    scores = [PSNR(reals[idx, :, :, :], fakes[idx, :, :, :]) for idx in range(len(reals))]
    return np.asarray(scores)

In [None]:
import cv2
import pandas as pd


def convert_uint8(list_imgs):
    """Convert a batch of float images in [0, 1] to uint8 in [0, 255].

    Args:
        list_imgs: array-like batch of images (e.g. (N, H, W, C)) with values
            nominally in [0, 1].

    Returns:
        uint8 array of the same shape. Values are scaled by 255 and truncated
        (not rounded), matching the original conversion.
    """
    imgs = np.asarray(list_imgs)
    # Clip first: casting floats outside the uint8 range is not well defined
    # (values may wrap or saturate depending on the platform).
    return (np.clip(imgs, 0, 1) * 255).astype(np.uint8)


def _save_metric_csv(save_path, values, column, name):
    """Write ``values`` as one column to ``<save_path>/<name>_results_<column>.csv``.

    ``save_path`` is created first: the results folder is otherwise only created
    later in the notebook, so writing the metrics would fail on a fresh run.
    """
    os.makedirs(save_path, exist_ok=True)
    df = pd.DataFrame(data={column: list(values)})
    path = os.path.join(save_path, "{}_results_{}.csv".format(name, column))
    df.to_csv(path, index=False)


def save_ssim(save_path, loss_ssim, name):
    """Save per-image SSIM scores to ``<name>_results_ssim.csv`` in ``save_path``.

    ``loss_ssim`` may be a TensorFlow tensor (converted via ``.numpy()``) or any
    1-D array-like.
    """
    values = loss_ssim.numpy() if hasattr(loss_ssim, "numpy") else loss_ssim
    _save_metric_csv(save_path, values, "ssim", name)


def save_psnr(save_path, loss_psnr, name):
    """Save per-image PSNR scores to ``<name>_results_psnr.csv`` in ``save_path``."""
    _save_metric_csv(save_path, loss_psnr, "psnr", name)

In [None]:
t_samples = convert_uint8(x_samples)
# Rebuild the reals from the raw per-batch arrays (b h w c) on every run, so
# re-running this cell cannot apply the (x + 1) / 2 shift twice. The mapping
# presumably takes the images from [-1, 1] to [0, 1] (the grid cell applies the
# same shift to data["image"]).
x_reals = (np.concatenate(x_reals_ddim, axis=0) + 1) / 2
t_reals = convert_uint8(x_reals)



In [None]:
import tensorflow as tf
import os

# We use TensorFlow's SSIM so the metric is the same as for the other GAN models.
RESULTS_DIR = "/content/drive/MyDrive/LatentDiffusionResults/"
# The results folder is otherwise only created further down; ensure it exists
# before the CSVs are written.
os.makedirs(RESULTS_DIR, exist_ok=True)

loss_psnr = get_psnr(t_samples, t_reals)
loss_ssim = tf.image.ssim(t_samples, t_reals, max_val=255, filter_size=11,
                          filter_sigma=1.5, k1=0.01, k2=0.03)

save_ssim(RESULTS_DIR, loss_ssim, "latent_diffusion")
save_psnr(RESULTS_DIR, loss_psnr, "latent_diffusion")

print("Structural Similarity : ", tf.reduce_mean(loss_ssim).numpy())
print("PSNR : ", np.mean(loss_psnr))


In [None]:


def threshold(categorical_map):
    """Binary foreground mask from a channel-last segmentation map.

    Returns an int array that is 1 wherever channels 2, 3 and 4 sum to a
    positive value and 0 elsewhere.
    """
    summed = categorical_map[:, :, 2] + categorical_map[:, :, 3] + categorical_map[:, :, 4]
    return (summed > 0) * 1
def make_threshold(list_imgs,list_segs):
    """Mask each image with the foreground of its segmentation map.

    Args:
        list_imgs: batch indexed as ``list_imgs[i, :, :, :]`` (channel-last images,
            scaled by 255 before masking).
        list_segs: batch of channel-last segmentation maps, one per image.

    Returns:
        uint8 array of masked images; pixels outside the mask are set to 0.
    """
    masked = []
    for idx in range(len(list_imgs)):
        keep = threshold(list_segs[idx, :, :, :]) > 0
        img = np.uint8(list_imgs[idx, :, :, :] * 255)
        # Foreground pixels keep their value, everything else becomes 0.
        masked.append(np.where(keep[:, :, None], img, np.uint8(0)))
    return np.asarray(masked)


thresholded_samples = make_threshold(x_samples, x_condition)
thresholded_reals = make_threshold(x_reals, x_condition)

# Output folders for the masked real / generated images (no-op if present).
real_save_path = "/content/drive/MyDrive/LatentDiffusionResults/real"
os.makedirs(real_save_path, exist_ok=True)

fake_save_path = "/content/drive/MyDrive/LatentDiffusionResults/fake"
os.makedirs(fake_save_path, exist_ok=True)

# cv2.resize takes dsize as (width, height): images are written 512 wide x 256 high.
# NOTE(review): cv2.imwrite and fastNlMeansDenoisingColored assume BGR channel order,
# while these arrays are displayed with plt.imshow as RGB — confirm the channel order
# (it does not matter if the three channels are identical).
for i in range(len(t_reals)):
    real_img = cv2.resize(thresholded_reals[i, :, :, :], (512, 256),
                          interpolation=cv2.INTER_LANCZOS4)
    fake_img = cv2.resize(thresholded_samples[i, :, :, :], (512, 256),
                          interpolation=cv2.INTER_LANCZOS4)
    # Light non-local-means denoising, applied to the generated images only.
    fake_img = cv2.fastNlMeansDenoisingColored(fake_img, None, 3, 3, 7, 5)

    cv2.imwrite(os.path.join(real_save_path, (str(i) + "_real.png")), real_img)
    cv2.imwrite(os.path.join(fake_save_path, (str(i) + "_{}fake.png".format("LDM"))), fake_img)


In [None]:
import matplotlib.pyplot as plt

# First generated sample after masking with its segmentation foreground.
fig, ax = plt.subplots()
ax.imshow(thresholded_samples[0])
ax.set_title("Masked generated sample #0")
ax.axis("off");

In [None]:
# Matching real image after masking with the same segmentation foreground.
fig, ax = plt.subplots()
ax.imshow(thresholded_reals[0])
ax.set_title("Masked real image #0")
ax.axis("off");

In [None]:
from torchvision.utils import make_grid
from PIL import Image
from einops import rearrange

# Tile the last sampled batch (left over from the sampling loop, b h w c in [0, 1]).
grid = make_grid(rearrange(samples, ' b h w c -> b c h w'), nrow=3)

# Back to channel-last and scaled to 0..255 for PIL.
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
Image.fromarray(grid.astype(np.uint8))



In [None]:
# Tile the RGB-rendered conditions of that batch, rescaled with (x + 1) / 2 and
# clamped to [0, 1].
grid2 = make_grid(torch.clamp((condition + 1.0) / 2.0, min=0.0, max=1.0), nrow=3)

# Back to channel-last and scaled to 0..255 for PIL.
grid2 = 255. * rearrange(grid2, 'c h w -> h w c').cpu().numpy()
Image.fromarray(grid2.astype(np.uint8))

In [None]:
from torchvision.utils import make_grid
from PIL import Image
from einops import rearrange

# Tile the real images of the last batch: channel-first, rescaled with
# (x + 1) / 2 and clamped to [0, 1].
grid = rearrange(data["image"], ' b h w c -> b c h w')
grid = make_grid(torch.clamp((grid + 1.0) / 2.0, min=0.0, max=1.0), nrow=3)

# Back to channel-last and scaled to 0..255 for PIL.
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
Image.fromarray(grid.astype(np.uint8))


In [None]:
# Load the TensorBoard notebook extension (used by the next cell).
%load_ext tensorboard


In [None]:
# NOTE(review): this log dir is a VQGanTeeth/..._vq_config run, i.e. apparently the
# first-stage (VQ-GAN) training logs rather than the LDM checkpoint evaluated above — confirm.
%tensorboard --logdir /content/drive/MyDrive/VQGanTeeth/2022-10-29T10-34-42_vq_config/testtube/version_6/tf
