In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from utils import dataloader3D, criterion
import os
import cv2

In [2]:
def readImg3D(img_dir):
    """Load a directory of numbered 2D slice images into one stacked array.

    Files are ordered by the integer value of their base name (``1.jpg``,
    ``2.jpg``, ..., ``10.jpg``), so every entry in ``img_dir`` must have a name
    that parses as an integer once its extension is removed.

    Args:
        img_dir: Path of the directory holding the slice images.

    Returns:
        ``np.ndarray`` built by stacking the slices in numeric order. Slices
        that come back from ``plt.imread`` with 3 (RGB) or 4 (RGBA) channels are
        reduced to a single grey channel by averaging the colour channels.

    Raises:
        ValueError: If a file name, minus its extension, is not an integer.
    """
    # splitext (rather than slicing off the last 4 characters) keeps the numeric
    # sort correct for extensions of any length, e.g. ".jpeg" or ".tiff".
    img_names = sorted(
        os.listdir(img_dir),
        key=lambda x: int(os.path.splitext(x)[0]),
    )
    img3D = []
    for img_name in img_names:
        img = plt.imread(os.path.join(img_dir, img_name))
        # Check ndim as well as the channel count: a 2D grey slice that happens
        # to be 3 pixels wide must not be mistaken for an RGB image.
        if img.ndim == 3 and img.shape[-1] in (3, 4):
            # Average only the colour channels; an alpha channel is ignored.
            img = img[..., :3].mean(axis=2)
        img3D.append(img)
    return np.array(img3D)

def saveImg3D(save_dir, name, img3D):
    """Write each axis-0 slice of ``img3D`` to ``<save_dir>/<name>/<index>.jpg``.

    Args:
        save_dir: Parent directory; the output folder is created inside it.
        name: Name of the output folder (created if it does not exist).
        img3D: Volume indexed as ``img3D[i, :, :]``; one image is written per
            index along axis 0.

    NOTE(review): ``plt.imsave`` is called without ``vmin``/``vmax``, so each
    slice is presumably scaled to its own min/max when rendered with the gray
    colormap - confirm if absolute intensities must survive a save/load cycle.
    """
    out_dir = os.path.join(save_dir, name)
    os.makedirs(out_dir, exist_ok=True)
    n_slices = img3D.shape[0]
    for idx in range(n_slices):
        target = os.path.join(out_dir, f"{idx}.jpg")
        plt.imsave(target, img3D[idx, :, :], cmap="gray")

In [20]:
# Read the Carotid slices from disk and hand the volume to the project's
# PixelLoader3D, which provides the high- and low-resolution volumes used below.
img_dir = "./data/Carotid/high"
img = readImg3D(img_dir)
train_loader = dataloader3D.PixelLoader3D(
    img,
    N=4,  # presumably the per-axis downsampling factor - TODO confirm in dataloader3D
    batch_size=512,
    encoding="Sine",
    sine_L=6,
    shuffle=False,
)
img_high = train_loader.get_high_img()
img_low = train_loader.get_low_img()
# Keep a copy of the reference volume on disk as JPEG slices.
saveImg3D("./", "img_high", img_high)
print(img_high.shape)
print(img_low.shape)

(192, 432, 480)
(48, 108, 120)


# Interpolation to the original size

In [21]:
def interpolate_3D(img3D_low, newsize0, newsize1, newsize2, mode):
    """Resize a 3D volume with two passes of 2D ``cv2.resize``.

    Pass 1 resizes every axis-0 slice to ``(newsize1, newsize2)``. Pass 2 takes
    each axis-1 slice of that intermediate volume and resizes it to
    ``(newsize0, newsize2)``, which stretches axis 0. Results of both passes are
    rounded with ``np.round``.

    Args:
        img3D_low: Input volume, indexed as ``[axis0, axis1, axis2]``.
        newsize0: Target length of axis 0.
        newsize1: Target length of axis 1.
        newsize2: Target length of axis 2.
        mode: OpenCV interpolation flag forwarded to ``cv2.resize``
            (e.g. ``cv2.INTER_LINEAR``).

    Returns:
        ``np.ndarray`` of shape ``(newsize0, newsize1, newsize2)``.

    NOTE(review): the volume is rounded after pass 1 and again after pass 2, so
    pass 2 interpolates already-quantised values - confirm this is intended.
    """
    # cv2.resize expects dsize as (width, height), i.e. (columns, rows).
    inplane_dsize = (newsize2, newsize1)
    inplane = np.array([
        np.round(cv2.resize(img3D_low[k, :, :], inplane_dsize, interpolation=mode))
        for k in range(img3D_low.shape[0])
    ])

    # Each axis-1 slice has shape (old axis-0 length, newsize2); only its row
    # count changes here because the width already equals newsize2.
    through_dsize = (newsize2, newsize0)
    stretched = [
        np.round(cv2.resize(inplane[:, r, :], through_dsize, interpolation=mode))
        for r in range(inplane.shape[1])
    ]

    # The list is ordered by axis-1 index; swap it back behind axis 0.
    return np.array(stretched).transpose((1, 0, 2))

In [22]:
# Nearest-neighbour baseline: upsample img_low to the shape of img_high, save
# the slices, and score the result against img_high.
# NOTE(review): mae is called as (img_high, prediction) here while psnr/ssim get
# (prediction, img_high) - harmless only if criterion.mae is symmetric; confirm.
img_nearest = interpolate_3D(img_low, *img_high.shape, cv2.INTER_NEAREST)
saveImg3D("./", "img3D_nearest", img_nearest)
mae_nearest = criterion.mae(img_high, img_nearest)
psnr_nearest = criterion.psnr(img_nearest, img_high)
ssim_nearest = criterion.ssim(img_nearest, img_high, multichannel=True)
print(f"nearest: mae:{mae_nearest}, psnr:{psnr_nearest}, ssim:{ssim_nearest}")

nearest: mae:9.771450767988041, psnr:22.736692085170006, ssim:0.6308612749342646


In [23]:
# Bilinear baseline: the low-resolution volume is cast to float before it is
# upsampled to the shape of img_high, then saved and scored against img_high.
img_bilinear = interpolate_3D(img_low.astype(float), *img_high.shape, cv2.INTER_LINEAR)
saveImg3D("./", "img3D_bilinear", img_bilinear)
mae_bilinear = criterion.mae(img_bilinear, img_high)
psnr_bilinear = criterion.psnr(img_bilinear, img_high)
ssim_bilinear = criterion.ssim(img_bilinear, img_high, multichannel=True)
print(f"bilinear: mae:{mae_bilinear}, psnr:{psnr_bilinear}, ssim:{ssim_bilinear}")

bilinear: mae:8.600448847013245, psnr:24.446074249250586, ssim:0.7097223783727649


In [24]:
# Bicubic baseline: the low-resolution volume is cast to float before it is
# upsampled to the shape of img_high, then saved and scored against img_high.
# NOTE(review): the interpolated values are not clipped before scoring; cubic
# interpolation can overshoot the input range - confirm whether clipping is wanted.
img_bicubic = interpolate_3D(img_low.astype(float), *img_high.shape, cv2.INTER_CUBIC)
saveImg3D("./", "img3D_bicubic", img_bicubic)
mae_bicubic = criterion.mae(img_bicubic, img_high)
psnr_bicubic = criterion.psnr(img_bicubic, img_high)
ssim_bicubic = criterion.ssim(img_bicubic, img_high, multichannel=True)
print(f"bicubic: mae:{mae_bicubic}, psnr:{psnr_bicubic}, ssim:{ssim_bicubic}")

bicubic: mae:9.191030193061986, psnr:23.932980457494764, ssim:0.6928764831166428


# Fit to the original size (192, 432, 480)

In [26]:
# Display the shape of the high-resolution reference volume.
img_high.shape

(192, 432, 480)

In [28]:
# Load the network saved at the chosen epoch and let the loader reconstruct the
# full volume with it.
epoch = 200
# The loaded object is used directly as a model (model.eval() below), so the
# file holds a whole pickled object rather than a state_dict.
# weights_only=False is passed explicitly because PyTorch >= 2.6 defaults to
# weights_only=True, which refuses to unpickle a full model object.
# NOTE(security): unpickling can execute arbitrary code - only load checkpoints
# produced by this project.
model = torch.load(
    f"./results/3D results/Carotid/2024-05-11-19/saved_models/model_at_epoch{epoch}.pt",
    weights_only=False,
)
model.eval()
img_fit = train_loader.fit_img(model)
saveImg3D("./", f"fit3D_epoch{epoch}", img_fit)
# Alternative: reload a reconstruction that was saved earlier instead of re-running the model.
# img_fit = readImg3D(f"E:/MyProjects/MedImageSR/results/3D results/CT/fit3D_epoch{epoch}")

fitting image 9/10

In [29]:
# Optional: score only a sub-range of axis-0 slices instead of the whole volume.
# begin = 0
# end = 128
# MAE = criterion.mae(img_high[begin:end, :, :], img_fit[begin:end, :, :])
# PSNR = criterion.psnr(img_high[begin:end, :, :], img_fit[begin:end, :, :])
# SSIM = criterion.ssim(img_high[begin:end, :, :], img_fit[begin:end, :, :], multichannel=True)
# Score the model reconstruction against the high-resolution volume.
# NOTE(review): the metrics receive (img_high, img_fit) here, whereas the
# interpolation baselines call psnr/ssim as (prediction, img_high) - the numbers
# are comparable only if criterion.psnr / criterion.ssim are symmetric; confirm.
MAE, PSNR, SSIM = (
    criterion.mae(img_high, img_fit),
    criterion.psnr(img_high, img_fit),
    criterion.ssim(img_high, img_fit, multichannel=True),
)
print()
print(f"model at epoch {epoch}: MAE={MAE}, PSNR={PSNR}, SSIM={SSIM}")


model at epoch 200: MAE=7.376524120691872, PSNR=26.408324815292016, SSIM=0.7408892222841665
