In [1]:
import torch
import torchvision
import os
import pickle
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import v2
import matplotlib.pyplot as plt
import torch.nn as nn
from PIL import Image
from torchcam.methods import SmoothGradCAMpp
from torchcam.utils import overlay_mask
from torchvision.transforms.functional import to_pil_image



# Root of the dataset tree; every other location below is derived from it.
parent_dir = 'rois2/'
obj_dir = f'{parent_dir}objects/'
img_dir = f'{parent_dir}images/'
label_dir = f'{parent_dir}labels/'
model_dir = f'{parent_dir}models/'

# Sub-folders of the object directory, plus the false-positive label folder.
obj_train_dir = f'{obj_dir}train/'
obj_test_dir = f'{obj_dir}test/'
false_positive_label_dir = f'{parent_dir}tiles/false_positives/labels/'

# Make sure the object directory is there before anything reads from / writes to it.
if not os.path.exists(obj_dir):
    os.makedirs(obj_dir)

In [2]:
# Normalisation statistics: maps an ROI name to a (mean, std) pair, as consumed
# by the dataset classes defined later in this notebook.
# NOTE(security): pickle.load can execute arbitrary code — only load files you trust.
# The `with` block guarantees the file handle is closed once the dict is loaded.
with open(obj_train_dir + 'mean_std.pkl', 'rb') as mean_std_file:
    mean_std_dict = pickle.load(mean_std_file)

In [3]:
def crop_name_to_roi_name(crop_name):
    """Extract the ROI name from the path of a crop image.

    The ROI name is the part of the file name that precedes the first ``'__'``.

    Parameters
    ----------
    crop_name : str
        Path (or bare file name) of a crop, e.g. ``'objects/test/1\\roiA__3.png'``.

    Returns
    -------
    str
        The ROI name, e.g. ``'roiA'``. If the file name contains no ``'__'`` the
        whole file name is returned.
    """
    # Accept both Windows ('\\') and POSIX ('/') separators: splitting on '\\'
    # alone returned the entire directory prefix on POSIX-style paths.
    file_name = crop_name.replace('\\', '/').rsplit('/', 1)[-1]
    # Strip the directory *before* looking for '__' so that a '__' inside a
    # directory name cannot cut the path short.
    return file_name.split('__')[0]

def custom_loader(path):
    """Identity loader: hand the sample path back instead of opening the file.

    Passed as ``loader=`` to ``datasets.ImageFolder`` in this notebook, so the
    folder dataset never decodes an image itself.
    """
    sample_path = path
    return sample_path

class CustomDataset_train(Dataset):
    """Training dataset: load a crop, augment it, and normalise it per ROI.

    Parameters
    ----------
    inputs_list : sequence
        Indexable collection whose items expose the image path at ``item[0]``
        and the label at ``item[1]``.
    """

    def __init__(self, inputs_list):
        self.inputs_list = inputs_list

    def __len__(self):
        return len(self.inputs_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Raises
        ------
        KeyError
            If ``mean_std_dict`` has no entry for the crop's ROI. Previously this
            case only printed a message and then failed later with an unrelated
            ``UnboundLocalError`` when ``mean``/``std`` were used.
        """
        path = self.inputs_list[idx][0]
        label = self.inputs_list[idx][1]

        # Resolve the normalisation statistics first so a missing ROI fails
        # fast, before any image decoding is attempted.
        roi_name = crop_name_to_roi_name(path)
        if roi_name not in mean_std_dict:
            raise KeyError(
                f'mean std not found for ROI {roi_name!r} (crop: {path!r})'
            )
        mean, std = mean_std_dict[roi_name]

        # Force three channels (consistent with the intent of the test dataset)
        # and release the file handle as soon as the pixels are in memory.
        with Image.open(path) as raw_image:
            image = raw_image.convert('RGB')

        # The pipeline is rebuilt per sample because mean/std depend on the ROI.
        transf = v2.Compose([
            v2.ToTensor(),
            v2.Resize((260,260)),
            v2.RandomHorizontalFlip(p=0.5),
            v2.RandomVerticalFlip(p=0.5),
            v2.RandomAffine(degrees=(0,90), translate=(0.1,0.3), scale=(0.5,0.75)),
            v2.Normalize(mean=mean, std=std), # aim to make the mean = 0 and std = 1
        ])
        image = transf(image)

        return image, label

class CustomDataset_test(Dataset):
    """Evaluation dataset: load a crop, resize it, and normalise it per ROI.

    No random augmentation is applied.

    Parameters
    ----------
    inputs_list : sequence
        Indexable collection whose items expose the image path at ``item[0]``
        and the label at ``item[1]``.
    """

    def __init__(self, inputs_list):
        self.inputs_list = inputs_list

    def __len__(self):
        return len(self.inputs_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Raises
        ------
        KeyError
            If ``mean_std_dict`` has no entry for the crop's ROI. Previously this
            case only printed a message and then failed later with an unrelated
            ``UnboundLocalError`` when ``mean``/``std`` were used.
        """
        path = self.inputs_list[idx][0]
        label = self.inputs_list[idx][1]

        # Resolve the normalisation statistics first so a missing ROI fails
        # fast, before any image decoding is attempted.
        roi_name = crop_name_to_roi_name(path)
        if roi_name not in mean_std_dict:
            raise KeyError(
                f'mean std not found for ROI {roi_name!r} (crop: {path!r})'
            )
        mean, std = mean_std_dict[roi_name]

        # Image.convert returns a NEW image; the original code discarded that
        # result, so the RGB conversion never took effect. Keep the converted
        # image and release the file handle once the pixels are in memory.
        with Image.open(path) as raw_image:
            image = raw_image.convert('RGB')

        # The pipeline is rebuilt per sample because mean/std depend on the ROI.
        transf = v2.Compose([
            v2.ToTensor(),
            v2.Resize((260,260)),
            v2.Normalize(mean=mean, std=std),
        ])
        image = transf(image)

        return image, label

In [4]:
# ImageFolder is used only to enumerate the samples under obj_test_dir:
# custom_loader returns the path untouched, and CustomDataset_test does the
# actual image decoding and preprocessing.
orig_test_dataset = datasets.ImageFolder(
    root=obj_test_dir,
    loader=custom_loader,
)
test_dataset = CustomDataset_test(orig_test_dataset.samples)

# One image per batch, visited in random order.
# NOTE(review): with shuffle=True, skipping the first `begin_step` batches in
# torch_cam_test skips a *random* subset of the images — confirm that is intended.
val_loader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=True,
)

In [5]:
def torch_cam_test(model, val_loader, begin_step=0, device='cuda', score_threshold=0.8):
    """Evaluate ``model`` and save SmoothGradCAM++ overlays of confident false positives.

    Iterates over ``val_loader`` one sample at a time, accumulates accuracy and
    recall (for class 1), prints a progress line every 2000 images and at the
    last batch, and prints the final metrics. For every image whose label is 0
    but which is predicted as class 1 with probability above
    ``score_threshold``, a CAM overlay is written to
    ``obj_dir + 'false_postives_cam/'``.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier exposing a ``features`` sequence; its last element is the CAM
        target layer. The model is moved to ``device`` and put in eval mode.
    val_loader : iterable with ``len()``
        Yields ``(image, label)`` pairs. Each batch must hold exactly one
        sample, because predictions and labels are read with ``.item()``.
    begin_step : int, optional
        Batches with index below this value are skipped.
    device : str or torch.device, optional
        Device used for inference (default ``'cuda'``, the previous hard-coded value).
    score_threshold : float, optional
        Minimum softmax probability of class 1 for a false positive to be saved.

    Returns
    -------
    None

    Notes
    -----
    * The threshold is applied to the softmax probability. The previous code
      compared the raw class-1 logit against 0.8, which is not a score in [0, 1].
    * Metrics are reported as ``nan`` instead of raising ``ZeroDivisionError``
      when no sample (or no positive sample) has been seen yet.
    * NOTE(review): the overlay background is built from the tensor exactly as
      yielded by the loader; if that tensor is mean/std-normalised the saved
      colours will be off. De-normalising would need the per-image statistics,
      which are not available here.
    """

    def _percent(numerator, denominator):
        # Guard against empty denominators (e.g. no positive sample seen yet).
        return 100 * numerator / denominator if denominator else float('nan')

    model.to(device)

    # Validation
    model.eval()
    print('Validation...')

    # Calculate accuracy
    correct = 0
    total = 0
    # Calculate recall
    correct_pos_obj = 0
    total_pos_obj = 0

    num_batches = len(val_loader)

    # Register the CAM hooks once for the whole run instead of once per image;
    # the finally-clause removes them even if the loop raises.
    cam_extractor = SmoothGradCAMpp(model, target_layer=model.features[-1])
    try:
        for i, (image, label) in enumerate(val_loader):
            if i < begin_step:
                continue
            image = image.to(device)
            true_class = label.item()

            # Gradients stay enabled: the CAM extractor back-propagates through
            # `output` when an overlay is requested below.
            output = model(image)
            predicted_class = output.argmax(dim=1).item()

            total += 1
            if predicted_class == true_class:
                correct += 1

            if true_class == 1:
                total_pos_obj += 1
                if predicted_class == 1:
                    correct_pos_obj += 1

            # Apply torchCAM to high-score false positive images
            if true_class == 0 and predicted_class == 1:
                positive_prob = torch.softmax(output.detach(), dim=1)[0, 1].item()
                if positive_prob > score_threshold:
                    # The CAM is computed only here: it is expensive and was
                    # previously evaluated (and thrown away) for every image.
                    activation_map = cam_extractor(predicted_class, output)
                    # Resize the CAM and overlay it
                    result = overlay_mask(to_pil_image(image.squeeze(0)), to_pil_image(activation_map[0].squeeze(0), mode='F'), alpha=0.5)
                    # Save it
                    false_positive_cam_dir = obj_dir + 'false_postives_cam/'
                    os.makedirs(false_positive_cam_dir, exist_ok=True)
                    plt.imshow(result); plt.axis('off'); plt.tight_layout()
                    plt.savefig(false_positive_cam_dir + 'epoch' + '_batch' + str(i) + '.png')
                    plt.close('all')

            if (i + 1) % 2000 == 0 or i == num_batches - 1:
                print('      Image [{}/{}], Accuracy: {:.2f} %, Recall: {:.2f} %'
                        .format(i + 1, num_batches, _percent(correct, total), _percent(correct_pos_obj, total_pos_obj)))
    finally:
        cam_extractor.remove_hooks()

    print('Accuracy of the network on the validation images: {} % \n'.format(_percent(correct, total)))
    print('Recall of the network on the validation images: {} % \n'.format(_percent(correct_pos_obj, total_pos_obj)))
        

In [6]:
# Build an EfficientNet-B1, swap its classifier head for a 2-output linear
# layer, and restore the weights stored in the checkpoint file.
model = torchvision.models.efficientnet_b1().to('cuda')
model.classifier[1] = nn.Linear(
    in_features=model.classifier[1].in_features,
    out_features=2,
).to('cuda')
model.load_state_dict(
    torch.load('rois2/models/round3/2024_08_28_lr0.001 decay0.01 weight 1-5/round3_model_epoch17.pth')
)

<All keys matched successfully>

In [7]:
# Run the evaluation, skipping the first 10000 batches yielded by the loader.
torch_cam_test(model, val_loader, begin_step=10000)

Validation...




      Image [12000/16602], Accuracy: 98.70 %, Recall: 98.89 %
      Image [14000/16602], Accuracy: 98.33 %, Recall: 99.28 %
      Image [16000/16602], Accuracy: 98.30 %, Recall: 99.06 %
      Image [16602/16602], Accuracy: 98.35 %, Recall: 98.89 %
Accuracy of the network on the validation images: 98.34898515601333 % 

Recall of the network on the validation images: 98.89267461669506 % 

