# Import Libraries

In [None]:
import os, cv2, torch
import numpy as np
from tqdm import tqdm
from torchvision.transforms.functional import to_tensor
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Student Model Structure

In [2]:
class StudentCNN(torch.nn.Module):
    """Small image-to-image CNN used as a "student" model.

    Takes a 3-channel image batch. The encoder halves the spatial size with a
    2x max-pool and restores it with a 2x bilinear upsample, so even-sized
    inputs come back at the same height/width (odd sizes lose one pixel).
    The decoder ends in Tanh, whose [-1, 1] range is remapped to [0, 1].

    The layer order fixes the ``encoder.N`` / ``decoder.N`` state-dict keys;
    keep it unchanged so saved checkpoints continue to load.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        encoder_layers = [
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # H, W -> H/2, W/2
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            # back to the pre-pool resolution
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
        ]
        decoder_layers = [
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 3, 3, padding=1),
            nn.Tanh(),  # output in [-1, 1]
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Run encoder then decoder; returns a tensor with values in [0, 1]."""
        features = self.encoder(x)
        decoded = self.decoder(features)
        return (decoded + 1) / 2

# Load Weights for the Two Student Models

In [3]:

# NOTE(review): absolute path from the original author's machine; adjust for your environment.
WEIGHTS_DIR = r"C:\Users\atole\OneDrive\Desktop\Python"


def load_student(weights_path):
    """Build a StudentCNN on `device`, load its weights, and return it in eval mode.

    Args:
        weights_path: path to a saved state dict for StudentCNN.

    Returns:
        The model, already moved to `device` and switched to eval mode.
    """
    model = StudentCNN().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine;
    # weights_only=True refuses to unpickle arbitrary Python objects from the file.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    return model.eval()


student_A = load_student(os.path.join(WEIGHTS_DIR, "Train_A.pth"))
student_B = load_student(os.path.join(WEIGHTS_DIR, "Train_B.pth"))
student_B

StudentCNN(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU()
    (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU()
    (9): Upsample(scale_factor=2.0, mode='bilinear')
  )
  (decoder): Sequential(
    (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): Tanh()
  )
)

# Evaluate the Ensemble on the Test Set (SSIM)

In [None]:
test_path = r"C:\Users\atole\OneDrive\Desktop\Python\Working Dataset\test"
# Sorted for a reproducible processing order (os.listdir order is arbitrary).
test_files = sorted(os.listdir(test_path))
ssim_total = 0.0
n_scored = 0  # images actually evaluated; the average divides by this

print("Running Ensemble on Test Set")

for fname in tqdm(test_files, desc="Test"):
    img = cv2.imread(os.path.join(test_path, fname))
    if img is None:
        # cv2.imread returns None (no exception) for non-image files such as
        # desktop.ini / Thumbs.db; skip them instead of crashing in cvtColor.
        tqdm.write(f"Skipping unreadable file: {fname}")
        continue
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Trim each side to a multiple of 8 so the 2x degradation below and the
    # model's 2x pool/upsample round-trip without losing pixels.
    h, w = img.shape[:2]
    h, w = h - h % 8, w - w % 8
    # NOTE(review): for sizes not already a multiple of 8 this resamples the
    # ground truth rather than cropping it (img[:h, :w]); kept as-is so the
    # scores stay comparable with earlier runs.
    gt = cv2.resize(img, (w, h))
    # Synthetic degradation: 2x bilinear downscale, then bicubic upscale back to (w, h).
    lr = cv2.resize(gt, (w//2, h//2), interpolation=cv2.INTER_LINEAR)
    input_img = cv2.resize(lr, (w, h), interpolation=cv2.INTER_CUBIC)
    input_tensor = to_tensor(input_img).unsqueeze(0).to(device)

    with torch.no_grad():
        # Ensemble prediction: mean of the two students' outputs.
        out_A = student_A(input_tensor)
        out_B = student_B(input_tensor)
        output = (out_A + out_B) / 2.0

    # Both images as HWC float arrays in [0, 1] for SSIM.
    output_img = np.transpose(output.squeeze(0).clamp(0, 1).cpu().numpy(), (1, 2, 0))
    gt_img = np.transpose(to_tensor(gt).numpy(), (1, 2, 0))
    score = ssim(output_img, gt_img, channel_axis=2, data_range=1.0, win_size=11)
    ssim_total += score
    n_scored += 1

if n_scored == 0:
    raise RuntimeError(f"No readable images found in {test_path}")
avg_ssim = ssim_total / n_scored
print(f"\n Final Ensemble SSIM on Test Set: {avg_ssim:.4f}")

Running Ensemble on Test Set


Test: 100%|██████████| 198/198 [00:43<00:00,  4.60it/s]


 Final Ensemble SSIM on Test Set: 0.9177





# Run the Ensemble on Showcase Images and Save the Outputs

In [None]:
showcase_path = r"C:\Users\atole\OneDrive\Desktop\Python\Working Dataset\showcase"
# Sorted for a reproducible processing order (os.listdir order is arbitrary).
image_files = sorted(os.listdir(showcase_path))
ssim_total = 0.0
n_scored = 0  # images actually evaluated; the average divides by this

save_path = r"C:\Users\atole\OneDrive\Desktop\Python\Output"
os.makedirs(save_path, exist_ok=True)

print("Running Ensemble on Showcase Set:")

for fname in tqdm(image_files, desc=" Showcase"):
    img = cv2.imread(os.path.join(showcase_path, fname))
    if img is None:
        # cv2.imread returns None (no exception) for non-image files such as
        # desktop.ini / Thumbs.db; skip them instead of crashing in cvtColor.
        tqdm.write(f"Skipping unreadable file: {fname}")
        continue
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Trim each side to a multiple of 8 so the 2x degradation below and the
    # model's 2x pool/upsample round-trip without losing pixels.
    h, w = img.shape[:2]
    h, w = h - h % 8, w - w % 8
    gt = cv2.resize(img, (w, h))
    # Synthetic degradation: 2x bilinear downscale, then bicubic upscale back to (w, h).
    lr = cv2.resize(gt, (w//2, h//2), interpolation=cv2.INTER_LINEAR)
    input_img = cv2.resize(lr, (w, h), interpolation=cv2.INTER_CUBIC)
    input_tensor = to_tensor(input_img).unsqueeze(0).to(device)

    with torch.no_grad():
        # Ensemble prediction: mean of the two students' outputs.
        out_A = student_A(input_tensor)
        out_B = student_B(input_tensor)
        output = (out_A + out_B) / 2.0

    # Both images as HWC float arrays in [0, 1] for SSIM.
    output_img = np.transpose(output.squeeze(0).clamp(0, 1).cpu().numpy(), (1, 2, 0))
    gt_img = np.transpose(to_tensor(gt).numpy(), (1, 2, 0))
    score = ssim(output_img, gt_img, channel_axis=2, data_range=1.0, win_size=11)
    ssim_total += score
    n_scored += 1

    # Round (not truncate) to 8-bit: astype alone floors every pixel, biasing
    # the saved image darker by ~0.5 levels on average.
    output_u8 = np.round(output_img * 255).astype(np.uint8)
    output_bgr = cv2.cvtColor(output_u8, cv2.COLOR_RGB2BGR)
    out_file = os.path.join(save_path, fname)
    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(out_file, output_bgr):
        tqdm.write(f"Failed to write {out_file}")

if n_scored == 0:
    raise RuntimeError(f"No readable images found in {showcase_path}")
avg_ssim = ssim_total / n_scored
print(f"\n Average SSIM on Showcase Set (Ensemble): {avg_ssim:.4f}")
print(f" Saved ensemble outputs to: {save_path}")

Running Ensemble on Showcase Set:


 Showcase: 100%|██████████| 50/50 [00:16<00:00,  2.97it/s]


 Average SSIM on Showcase Set (Ensemble): 0.9146
 Saved ensemble outputs to: C:\Users\atole\OneDrive\Desktop\Python\Output



