In [None]:
# Import required packages
import glob
import tqdm
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import matplotlib.pyplot as plt


import sys
sys.path.append('../utils')

from functions import *
from inference_set import inference_set
from cnn_model import SmallResNet18_Input64_Contrastive,OTCNN_Input64_with_pretrained_ResNet18_contrastive

In [None]:
# Select the compute device: the first CUDA GPU if PyTorch can see one,
# otherwise fall back to the CPU. Later cells move models and batches here.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'CHECK devices {device}')

In [None]:
# Load example inputs: each file in Sample_data/ holds one two-channel sample
# (channel 0 is plotted as "Template", channel 1 as "Science" further below).
SAMPLE_GLOB = "Sample_data/*"
SAMPLE_SHAPE = (2, 100, 100)  # per-sample array shape the buffer below expects

sample_paths = sorted(glob.glob(SAMPLE_GLOB))
# Fail here with a clear message rather than later inside the plotting cell.
assert sample_paths, f"no sample files matched {SAMPLE_GLOB!r}"

# float64 buffer (np.empty default); each loaded array is cast into its row on assignment.
input_images = np.empty((len(sample_paths), *SAMPLE_SHAPE))

# NOTE(review): `import tqdm` binds the tqdm *module*, which is not callable; this
# call only works because `from functions import *` presumably re-exports the
# tqdm function — confirm, or import it explicitly via `from tqdm.auto import tqdm`.
for idx, path in enumerate(tqdm(sample_paths, desc="loading samples")):
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects;
    # only use it on trusted files (plain numeric .npy arrays do not need it).
    input_images[idx] = np.load(path, allow_pickle=True)

# Row index -> source file, printed after the loop so it doesn't break the progress bar.
for idx, path in enumerate(sample_paths):
    print(idx, path)

input_images.shape


In [None]:
# Setup for the ResNet18-based models
img_size = 64
BATCH_SIZE_eval = 1

# Checkpoint paths (state dicts, loaded with load_state_dict below).
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
contrastive_model = 'pretrained_models/CT_Resnet18.pt'
cls_head_model = 'pretrained_models/CTN2N_ResNet18_OTCNN.pt'

# Load contrastive-based feature extractor.
# map_location lets a GPU-saved checkpoint load on the CPU fallback device (and vice versa).
cont_model = SmallResNet18_Input64_Contrastive(num_classes=1, input_channel=2)
cont_model = cont_model.to(device=device)
cont_model.load_state_dict(torch.load(contrastive_model, map_location=device))
cont_model = cont_model.float()

# Load classification-based model: wrap the extractor, then restore the trained
# classifier weights. Previously `cls_head_model` was never loaded, so inference
# ran with an untrained classifier.
model = OTCNN_Input64_with_pretrained_ResNet18_contrastive(contrastive_extractor=cont_model, num_classes=2, input_channel=2)
model = model.to(device=device)
model.load_state_dict(torch.load(cls_head_model, map_location=device))

# Dataloader settings
scaler = 'std_scl'
add_features = None

test_eval = inference_set(images=input_images, transform=None, mode=add_features, local_scaler=scaler, default_size=img_size)
test_loader_eval = DataLoader(test_eval, batch_size=BATCH_SIZE_eval, shuffle=False)

# Cast to float64 to match `input_images` (a float64 numpy buffer).
# NOTE(review): assumes inference_set keeps the float64 dtype — confirm.
model = model.double()



In [None]:
# Inference: score every sample with the classifier and threshold the softmax
# probability of class 1.
threshold = 0.5  # probability cut-off above which a sample is assigned class 1

model.eval()
with torch.no_grad():
    print('Get Test set score.....')

    # Collect per-batch outputs and concatenate once; calling torch.cat on every
    # iteration re-copies the growing result (quadratic in the number of batches).
    batch_outputs = [model(imgs.to(device=device)) for imgs in tqdm(test_loader_eval)]
    if not batch_outputs:
        raise RuntimeError('test_loader_eval yielded no batches; nothing to score')
    pred_result = torch.cat(batch_outputs, dim=0)
    print(pred_result.size())

    # Softmax over the class dimension, applied to the model output directly.
    # The former torch.Tensor(...) re-wrap was redundant and, depending on the
    # PyTorch version, rejects double/CUDA tensors. Column 1 = class-1 probability.
    all_pred_result = torch.softmax(pred_result, dim=1)[:, 1]
    all_pred_result = all_pred_result.cpu().numpy().reshape(-1)

    pred_cls = np.where(all_pred_result > threshold, 1, 0)

print(pred_cls)

In [None]:
# Visualization: one row per sample, channel 0 ("Template") on the left and
# channel 1 ("Science") on the right, each titled with the predicted class.
n_rows = len(input_images)
# squeeze=False keeps `ax` 2-D, so ax[row, col] also works when there is one sample.
fig, ax = plt.subplots(n_rows, 2, figsize=(5, 10), sharex=True, sharey=True, squeeze=False)

for row, sample in enumerate(input_images):
    # Both channels of a sample share one colour scale: vmin is 0.25x the sample minimum.
    vmin, vmax = np.min(sample) * 0.25, np.max(sample)
    for col, label in enumerate(('Template', 'Science')):
        ax[row, col].imshow(sample[col], vmin=vmin, vmax=vmax)
        ax[row, col].axis('off')
        ax[row, col].set_title(f'{label}: class {pred_cls[row]}', fontsize=10)

plt.tight_layout()
plt.show()