In [None]:
import os
import cv2
import random
import numpy as np
import matplotlib.pyplot as plt

import matplotlib.patches as mpatches

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.models import segmentation

from deeplab_utils.utils import get_iou_score
from deeplab_utils.dataset_utils import load_data, train_test_split, ImageDataset

In [None]:
# Reproducibility: seed every RNG this notebook touches (Python, NumPy, PyTorch CPU and
# all CUDA devices) and force cuDNN onto deterministic kernels.
seed = 42

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
cudnn.deterministic = True
# benchmark=True lets cuDNN auto-tune and pick (possibly non-deterministic) algorithms per
# input shape, which undermines deterministic=True above; keep it off for reproducible runs.
cudnn.benchmark = False


In [None]:
# Evaluation preprocessing: convert each sample to a tensor, then resize it to a fixed
# 128x128 so every test sample shares one spatial size.
# NOTE(review): if ImageDataset also applies `transform` to the masks, an interpolating
# Resize will soften label edges — confirm how ImageDataset uses it.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize([128, 128])]
)

# Load everything under `root` and keep only the held-out split; the first element
# returned by train_test_split is not used in this notebook.
root = '../dataset/image'
data = load_data(root)
_, test_data = train_test_split(data)

# One sample per batch, in dataset order, so each figure below maps to one test sample.
test_dataset = ImageDataset(test_data, transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

In [None]:
# Rebuild the DeepLabV3-ResNet101 architecture, replace the final classifier conv with a
# 2-channel head, and load the fine-tuned checkpoint from `model_path`.
model_path = 'output/2024-06-17/12_41_33/best_model.pt'

# NOTE(review): weights='DEFAULT' downloads pretrained weights that load_state_dict below
# overwrites; it also determines which sub-modules (e.g. aux_classifier) exist, so only
# change it together with a check of the checkpoint's keys.
model = segmentation.deeplabv3_resnet101(weights='DEFAULT')
# Swap the last 1x1 conv of the classifier head so the model emits 2 output channels.
model.classifier[4] = nn.Conv2d(256, 2, kernel_size=(1, 1), stride=(1, 1))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The model is wrapped in DataParallel before loading, presumably because the checkpoint
# was saved from a DataParallel model (keys prefixed with `module.`) — confirm.
# Use at most the first three GPUs, but only ones that exist: a hard-coded [0, 1, 2]
# fails on hosts with fewer GPUs. With no GPU the list is empty -> None, and DataParallel
# simply forwards to the wrapped module on CPU.
gpu_ids = list(range(min(3, torch.cuda.device_count()))) or None
model = nn.DataParallel(model, device_ids=gpu_ids)
# map_location lets a checkpoint saved from GPU tensors be loaded on a CPU-only host.
model.load_state_dict(torch.load(model_path, map_location=device))
# Assign instead of ending the cell with a bare `model.to(device)`, which would make
# Jupyter render the entire module repr as cell output.
model = model.to(device)

In [None]:
# Visual error analysis over the test split: for every batch whose score from
# `get_iou_score` is below MIOU_THRESHOLD, print the score and plot, per sample, the
# input image, ground-truth mask, predicted mask and a colour overlay of both masks.
MIOU_THRESHOLD = 0.8       # batches scoring below this are visualised
MASK_THRESHOLD = 128       # uint8 mask pixels above this count as foreground
OVERLAY_ALPHA = 0.5        # blend weight of each coloured mask on top of the image
LABEL_COLOR = [0, 0, 255]  # blue (RGB) marks ground-truth foreground
PRED_COLOR = [255, 0, 0]   # red (RGB) marks predicted foreground


def _to_uint8(array):
    """Scale an array with values in [0, 1] (or a 0/1 mask) to 0-255 and cast to uint8."""
    return (array * 255.0).astype(np.uint8)


def _blend_masks(image_rgb, label_img, out_img):
    """Return `image_rgb` with the label mask (blue) and the prediction (red) blended on top."""
    label_overlay = np.zeros_like(image_rgb, dtype=np.uint8)
    label_overlay[label_img > MASK_THRESHOLD] = LABEL_COLOR
    out_overlay = np.zeros_like(image_rgb, dtype=np.uint8)
    out_overlay[out_img > MASK_THRESHOLD] = PRED_COLOR
    # addWeighted saturates to the uint8 range, so bright/overlapping pixels clip instead of wrapping.
    combined = cv2.addWeighted(image_rgb, 1, label_overlay, OVERLAY_ALPHA, 0)
    return cv2.addWeighted(combined, 1, out_overlay, OVERLAY_ALPHA, 0)


def _plot_sample(image_rgb, label_img, out_img, title):
    """Show image, label mask, predicted mask and their colour overlay in one 1x4 figure."""
    panels = [
        (image_rgb, None, 'Original Image'),
        (label_img, 'gray', 'Label Image'),
        (out_img, 'gray', 'Output Image'),
        (_blend_masks(image_rgb, label_img, out_img), None, 'Overlay Image'),
    ]
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    for ax, (img, cmap, panel_title) in zip(axes, panels):
        ax.imshow(img, cmap=cmap)
        ax.set_title(panel_title)
        ax.axis('off')
    axes[3].legend(
        handles=[
            mpatches.Patch(color=tuple(c / 255 for c in LABEL_COLOR), label='Label'),
            mpatches.Patch(color=tuple(c / 255 for c in PRED_COLOR), label='Prediction'),
        ],
        loc='lower right',
    )
    fig.suptitle(title)
    plt.show()
    # Release the figure explicitly; many low-scoring samples would otherwise accumulate.
    plt.close(fig)


model.eval()
with torch.no_grad():
    # `images`/`targets` (not `data`/`target`) so the loop does not clobber the `data`
    # list loaded earlier in the notebook.
    for batch_idx, (images, targets) in enumerate(test_loader):
        images, targets = images.to(device), targets.to(device)

        output = model(images)['out']
        miou = get_iou_score(output, targets)

        if miou < MIOU_THRESHOLD:
            print(f'batch {batch_idx}: mIoU = {float(miou):.4f}')

            # NOTE(review): the prediction is sigmoid(channel-1 logit) > 0.5 (i.e. logit > 0)
            # and ignores channel 0. If training used softmax cross-entropy over both channels,
            # output.argmax(dim=1) is the matching decision rule — confirm against training.
            preds = (torch.sigmoid(output[:, 1, :, :]) > 0.5).long()
            # assumes targets is (N, C>=2, H, W) with channel 1 = foreground — TODO confirm
            labels = targets[:, 1, :, :].long()

            for i in range(images.size(0)):
                ori_img = _to_uint8(images[i].cpu().numpy().transpose(1, 2, 0))
                label_img = _to_uint8(labels[i].cpu().numpy().squeeze())
                out_img = _to_uint8(preds[i].cpu().numpy().squeeze())

                # assumes ImageDataset yields BGR channel order (e.g. read via cv2) — confirm
                ori_img_rgb = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)

                _plot_sample(
                    ori_img_rgb,
                    label_img,
                    out_img,
                    title=f'batch {batch_idx}, sample {i}: mIoU = {float(miou):.4f}',
                )

