In [None]:
import torch
from torchvision import models, transforms, datasets
import numpy as np
import argparse
import cv2
from tqdm import tqdm
import math
from torchvision import models
import torch
import torchnet.meter.confusionmeter as cm
import dask.array as da
import matplotlib.pyplot as plt
import os
import numbers

In [None]:
# Checkpoint file of the fine-tuned ResNet (fold 1) that `load_model` restores.
ckpt_path = r"/home/woody/iwi5/iwi5095h/CNN-multiclass-classification-results/checkpoint/resnet/fold_1/model_ft.pt"

# Run on the first CUDA GPU when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
def load_model():
    """Build a ResNet-18 with a 7-output head and restore it from ``ckpt_path``.

    The checkpoint is expected to be a mapping whose ``"model_state_dict"``
    entry holds the network weights. Tensors are mapped onto the module-level
    ``device`` while loading.

    Returns:
        The restored network, switched to evaluation mode. It is not moved to
        ``device`` here; the caller does that.
    """
    network = models.resnet18()
    # Replace the final fully connected layer with a 7-output classifier.
    network.fc = torch.nn.Linear(network.fc.in_features, 7)

    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location=device)
    network.load_state_dict(checkpoint["model_state_dict"])

    network.eval()
    return network

# Restore the classifier and place it on the selected device in one step.
model_ft = load_model().to(device)

In [None]:
def view_as_windows_new(arr_in, window_shape, show=False):
    """Split a 2-D image into non-overlapping tiles of size ``window_shape``.

    The image is first grown to at least ``window_shape`` along every axis on
    which it is smaller, by replicating its edge values (split evenly between
    both sides, the extra row/column going to the bottom/right). It is then cut
    row-major into tiles; tiles that run past the bottom or right border are
    zero-padded to the full tile size.

    Args:
        arr_in: 2-D ``numpy`` array (single-channel image).
        window_shape: Tile size as ``(rows, cols)``, or a single number used
            for both axes.
        show: When true, draw ``arr_in`` with ``plt.imshow`` before tiling.
            Off by default so the helper can be called from data loaders
            without producing a figure per image.

    Returns:
        List of ``numpy`` arrays, each of shape ``window_shape`` and with the
        dtype of ``arr_in``, ordered row by row.

    Raises:
        ValueError: If ``window_shape`` does not match the dimensionality of
            ``arr_in``, if ``arr_in`` is not 2-D, or if a window entry is not
            positive.
    """
    ndim = arr_in.ndim

    if isinstance(window_shape, numbers.Number):
        window_shape = (window_shape,) * ndim
    if len(window_shape) != ndim:
        raise ValueError("`window_shape` is incompatible with `arr_in.shape`")
    if ndim != 2:
        raise ValueError("`arr_in` must be a 2-D (single-channel) image")

    patch_shape = tuple(int(size) for size in window_shape)
    if min(patch_shape) <= 0:
        raise ValueError("`window_shape` entries must be positive")

    if show:
        plt.imshow(arr_in, cmap="gray")

    # Grow every axis that is shorter than the window. Each axis is handled
    # independently so an image that is too small in both directions is padded
    # in both. mode="edge" replicates the border pixels.
    pre_pad = []
    for have, want in zip(arr_in.shape, patch_shape):
        missing = max(want - have, 0)
        before = missing // 2
        pre_pad.append((before, missing - before))
    if any(before or after for before, after in pre_pad):
        arr_in = np.pad(arr_in, pre_pad, mode="edge")

    rows, cols = arr_in.shape
    tile_h, tile_w = patch_shape
    # Integer ceiling division: number of tiles needed to cover each axis.
    num_patches_x = -(-rows // tile_h)
    num_patches_y = -(-cols // tile_w)
    print("num_patches_x", num_patches_x)
    print("num_patches_y", num_patches_y)

    patches = []
    for i in range(num_patches_x):
        for j in range(num_patches_y):
            # Slicing past the end simply yields a shorter tile.
            tile = arr_in[
                i * tile_h : (i + 1) * tile_h,
                j * tile_w : (j + 1) * tile_w,
            ]
            pad_rows = tile_h - tile.shape[0]
            pad_cols = tile_w - tile.shape[1]
            # Zero-fill the missing bottom rows / right columns. np.pad always
            # returns a new array, so tiles never alias `arr_in`.
            patches.append(
                np.pad(
                    tile,
                    ((0, pad_rows), (0, pad_cols)),
                    mode="constant",
                    constant_values=0,
                )
            )

    return patches

In [None]:
class CustomImageFolder(datasets.ImageFolder):
    """``ImageFolder`` variant that yields every image as a list of patches.

    Each image is read with OpenCV without conversion, cut into 955 x 783
    tiles by ``view_as_windows_new``, and every tile is expanded from grey to
    three channels and converted to a ``float32`` tensor before the optional
    ``transform`` is applied.
    """

    def __init__(self, root, transform=None, target_transform=None):
        super().__init__(root, transform, target_transform)

    def __getitem__(self, index):
        """Load sample ``index`` and return its patches with the class label.

        Returns:
            ``(patches, target)`` where ``patches`` is a list with one entry
            per tile (the transformed tensor) and ``target`` is the class
            index after the optional ``target_transform``.
            NOTE(review): the original method built the patch list but never
            returned it; the ``(sample, target)`` pair follows the
            ``ImageFolder`` convention -- confirm against the training loop.

        Raises:
            OSError: If OpenCV cannot read the image file.
        """
        path, target = self.samples[index]
        sample = cv2.imread(path, -1)
        # cv2.imread reports failure by returning None instead of raising.
        if sample is None:
            raise OSError(f"cv2.imread could not read image: {path}")
        print("sample shape", sample.shape)

        patches = []
        for tile in view_as_windows_new(sample, (955, 783)):
            tile = cv2.cvtColor(tile, cv2.COLOR_GRAY2RGB)
            tile = torch.from_numpy(tile.astype(np.float32))
            if self.transform is not None:
                tile = self.transform(tile)
            patches.append(tile)

        # The label belongs to the whole image, so transform it exactly once
        # rather than once per patch.
        if self.target_transform is not None:
            target = self.target_transform(target)

        return patches, target

In [None]:
class CustomTransform(object):
    """Min-max normalise a tensor and move its last axis to the front."""

    def __call__(self, image_tensor):
        """Scale ``image_tensor`` to [0, 1] and reorder its axes.

        Args:
            image_tensor: 3-D tensor; its last axis is moved to position 0
                (e.g. channel-last ``(H, W, C)`` becomes ``(C, H, W)``).

        Returns:
            The normalised, permuted tensor. A constant input (max == min) is
            mapped to all zeros instead of NaN.
        """
        lowest = image_tensor.min()
        value_range = image_tensor.max() - lowest
        # A constant image has zero range; dividing by it would give 0/0 = NaN.
        # Dividing by one instead leaves the (all-zero) numerator unchanged.
        value_range = torch.where(
            value_range == 0, torch.ones_like(value_range), value_range
        )
        image_tensor = (image_tensor - lowest) / value_range
        return image_tensor.permute(-1, 0, 1)

In [None]:
def label_to_class(label):
    """Translate a numeric class index (0-6) into its class name.

    Raises:
        KeyError: If ``label`` is not one of the seven known indices.
    """
    # Position in this tuple is the class index.
    class_names = (
        "FX00_Sporadic_line_artefacts",
        "FX02_Group_of_line_artefact",
        "FX03_partly_brighter",
        "FX04_group_of_defect_line",
        "FX05_Defect_line",
        "FX07_Stripes",
        "Good_image",
    )
    index_to_name = dict(enumerate(class_names))
    return index_to_name[label]

In [None]:
def classify_batch(predictions):
    """Combine the per-patch class indices of one image into a single label.

    Rules, applied in order:
      1. All patches agree -> that label.
      2. Exactly three patches predict class 4, 3, 2 or 1 (checked in that
         order) -> that class.
      3. If class 6 occurs, it is ignored; if class 5 also occurs it is
         ignored as well, unless nothing else remains, in which case the
         result is 5. Otherwise the most frequent remaining label wins.
      4. Otherwise -> the most frequent label.

    Args:
        predictions: Sequence of class indices (ints or single-element
            tensors/arrays convertible with ``int()``).

    Returns:
        The image-level class index as a plain ``int`` on every path.

    Raises:
        ValueError: If ``predictions`` is empty.
    """
    # Normalise to plain ints so counting, hashing and the return type behave
    # the same regardless of how the caller stored the labels.
    labels = [int(p) for p in predictions]
    if not labels:
        raise ValueError("`predictions` must contain at least one label")

    if len(set(labels)) == 1:
        return labels[0]

    for defect_class in (4, 3, 2, 1):
        if labels.count(defect_class) == 3:
            return defect_class

    if 6 in labels:
        remaining = [label for label in labels if label != 6]
        if 5 in remaining:
            remaining = [label for label in remaining if label != 5]
            if not remaining:
                return 5
        return max(set(remaining), key=remaining.count)

    return max(set(labels), key=labels.count)

In [None]:

# Read the example image unchanged (flag -1 == cv2.IMREAD_UNCHANGED) and tile it.
image_path = '/home/woody/iwi5/iwi5095h/experiments/postprocessed/1.tif'
image = cv2.imread(image_path, -1)
# cv2.imread reports failure by returning None instead of raising.
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
patches = view_as_windows_new(image, (955, 783))


In [None]:
prediction = []  # predicted class index per patch
probibity = []  # softmax output tensor (1 x num_classes) per patch
confidence = []  # top-1 softmax probability per patch

# Inference only: disable autograd so no graph is built or kept alive per patch.
with torch.no_grad():
    for i in range(len(patches)):
        patch = patches[i]

        # Min-max normalisation to [0, 1]. A constant patch has zero range and
        # would become NaN, so it is mapped to all zeros instead.
        value_range = patch.max() - patch.min()
        if value_range == 0:
            patch = np.zeros(patch.shape, dtype=np.float32)
        else:
            patch = (patch - patch.min()) / value_range

        # (H, W) -> (3, H, W) by repeating the grey channel, then add a
        # leading batch axis -> (1, 3, H, W).
        patch = np.repeat(patch[np.newaxis, :, :], 3, axis=0)
        patch = patch[np.newaxis, :, :, :]

        patch = torch.from_numpy(patch.astype(np.float32)).to(device)
        output = model_ft(patch)
        probs = torch.nn.functional.softmax(output, dim=1)
        conf, preds = torch.max(probs, 1)

        prediction.append(preds.item())
        confidence.append(conf.item())
        probibity.append(probs)

# Per-patch class probabilities rounded to one decimal, for display.
probability_list = [
    [round(value, 1) for value in tensor[0].tolist()] for tensor in probibity
]
final_prediction = classify_batch(prediction)

print("prediction and its confidence", final_prediction, np.mean(confidence))

In [None]:
def visualize_patch(
    data, prediction, confidence,  final_prediction,target
):
    """Show all patches of one image in a grid with their confidences.

    Args:
        data: Sequence of 2-D patch arrays to draw.
        prediction: Per-patch predictions; accepted for interface
            compatibility but currently not drawn.
        confidence: Per-patch confidence values, written below each patch.
        final_prediction: Image-level class index (int or single-element
            tensor) shown in the figure title.
        target: Ground-truth class index (int or single-element tensor) shown
            in the figure title.

    Raises:
        ValueError: If ``data`` is empty.
    """
    num_patches = len(data)
    if num_patches == 0:
        raise ValueError("`data` must contain at least one patch")

    n_rows = int(np.ceil(np.sqrt(num_patches)))
    n_cols = int(np.ceil(num_patches / n_rows))
    # squeeze=False keeps `axes` a 2-D array even for a single row, column or
    # cell, so the [row, col] indexing below always works.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 6), squeeze=False)

    plt.subplots_adjust(hspace=0.5, wspace=0.3)

    # Hide ticks and frames everywhere, including grid cells left empty.
    for ax in axes.flat:
        ax.axis("off")

    for i in range(num_patches):
        ax = axes[i // n_cols, i % n_cols]
        ax.imshow(data[i], cmap="gray")
        ax.set_title(f"Patch {i+1}")
        # Axes coordinates: horizontally centred, just below the image.
        ax.text(
            0.5,
            -0.15,
            f"Conf: {confidence[i]}",
            ha="center",
            transform=ax.transAxes,
        )

    # int() accepts plain ints as well as single-element tensors.
    fig.suptitle(
        f"Final Prediction: {label_to_class(int(final_prediction))} \n Target: {label_to_class(int(target))}"
    )
    plt.show()
    #plt.savefig(os.path.join(r"experiments\postprocessed\results", 'A.png'))

In [None]:
num_patches = len(patches)
print("num_patches", num_patches)

# Near-square grid: y columns, x rows.
y = int(np.ceil(np.sqrt(num_patches)))
x = int(np.ceil(num_patches / y))
# squeeze=False keeps `axes` 2-D even when the grid has a single row, column
# or cell, so axes[row, col] indexing is always valid.
fig, axes = plt.subplots(
    x, y, figsize=(12, 12), gridspec_kw={"top": 0.95}, squeeze=False
)  # Adjust figsize as needed
fig.suptitle(
    " Target: 'FX00_Sporadic_line_artefacts' "
)
indiviual_prediction=[label_to_class(p) for p in prediction]

# Turn off axis labels and ticks everywhere, including unused grid cells.
for ax in axes.flat:
    ax.axis("off")

for i in range(num_patches):
    ax = axes[i // y, i % y]
    ax.imshow(patches[i], cmap="gray")
    ax.set_title(f"Patch {i+1}")
    ax.text(0.5, -0.06, f'Prob: {probability_list[i]}', ha='center',va='bottom', transform=ax.transAxes)
    ax.text(0.5, -0.07, f'Prediction: {indiviual_prediction[i]}', ha='center',va='top', transform=ax.transAxes)

plt.show()
