In [1]:
import torch.nn as nn
import googlenet_pytorch
import torchvision.transforms as transforms

class GoogleNet(nn.Module):
    def __init__(self):
        super(GoogleNet,self).__init__()
        #self.model = torch.hub.load('pytorch/vision:v0.10.0', 'googlenet', pretrained=False, progress=True)
        self.model = googlenet_pytorch.GoogLeNet.from_pretrained('googlenet')
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.layernorm = nn.LayerNorm(1024,elementwise_affine=True)
        self._fc = nn.Linear(1024,2, bias=False)
    def forward(self, x):
        batch_size ,_,_,_ =x.shape
        x = self.model.extract_features(x)
        x = self.model.avgpool(x)
        x = x.view(-1, 1024)
        x = self.layernorm(x)
        x = self._fc(x)
        x = F.normalize(x, p=2, dim=1)
        return x
    
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]
    )
])

In [2]:
import torch
from PIL import Image
#print(torch.__version__)
import torch.nn.functional as F

model = GoogleNet()#*args, **kwargs)
model.load_state_dict(torch.load('model\saved_model.pt'))
model.eval()

Loaded pretrained weights for googlenet


GoogleNet(
  (model): GoogLeNet(
    (conv1): BasicConv2d(
      (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    (conv2): BasicConv2d(
      (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv3): BasicConv2d(
      (conv): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    (inception3a): Inception(
      (branch1): BasicConv2d(
        (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): B

In [3]:
from lime import lime_image
from skimage.segmentation import mark_boundaries
import matplotlib.pyplot as plt
import numpy as np



explainer = lime_image.LimeImageExplainer()


pil_image_transform = transforms.Compose([
    transforms.Resize((224, 224))
])

lime_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]
    )
])
def batch_predict(images):
    model.eval()
    
    batch = torch.stack(tuple(lime_transform(i) for i in images), dim=0)

    device = torch.device("cpu")
    model.to(device)
    batch = batch.to(device)
    
    logits = model(batch)
    probs = F.softmax(logits, dim=1)
    return probs.detach().cpu().numpy()

In [4]:
path_for_input_image = 'test_images\MI_2619_874.png'
def get_lime_exp(path_for_image):
    input_image = Image.open(path_for_image).convert('RGB')
    explanation = explainer.explain_instance(np.array(pil_image_transform(input_image)),
                                            batch_predict, batch_size=10 , num_samples=20)


    #print(explanation.local_exp)
    temp, mask = explanation.get_image_and_mask(1,positive_only=True)# 0:'Norm',1:'MI'

    img_boundry1 = mark_boundaries(temp/255.0, mask, color=[0.5,0,0])
    return img_boundry1
    #plt.imshow(img_boundry1)

In [5]:
import PySimpleGUIWeb as psg
import numpy as np
import PySimpleGUI as sg
from PIL import Image
import cv2
import io



#creating the table
toprow = ['Patient information']
rows = [['Patient ID', 7591],
        ['Age', 50],
        ['Sex', 'Female'],
        ['Height (cm)', 175],
       ['Weight (kg)', 83],
       ['Date/Time', '26/04 09:10']]
tbl1 = psg.Table(values=rows, headings=toprow,
   auto_size_columns=True,
   display_row_numbers=False,
    text_color = 'Black',
   justification='left', key='-TABLE-',
   select_mode = 'TABLE_SELECT_MODE_NONE',
   enable_events=False)

#create the image element
#image_ele = psg.Image(filename ="LIME_better_resilution_dark_red_15353.png",
 #                    size = (500, 500))
image_np = get_lime_exp(path_for_input_image)
#pil_image = Image.fromarray(np.uint8(image_np)).convert('RGB')
cv2.imwrite('temp.png',image_np*255)
#is_success, buffer_png_image = cv2.imencode(".png", image_np)
#io_buf = io.BytesIO(buffer_png_image)
image_ele = psg.Image('temp.png',
                     size = (500, 500))

#create the predicted score element
text_pred_score = ['Prediction score: 1',
             'On a scale from 0 to 1, where 0 represents the lowest level of certainty and 1 signifies the highest, the prediction score for this specific prediction is 1.']
combo_ele_pred_score = psg.Combo(text_pred_score, 
         readonly=True,
        size=(20,1),
        default_value = 'Prediction score: 1',
          key='-COMBO_PRED-')

#create the XAI method element
text_xai_method = ['XAI method: LIME',
             'Some details on LIME...'];
combo_ele_xai_method = psg.Combo(text_xai_method, 
         readonly=True,
        size=(20,1),
        default_value = 'XAI method: LIME',
          key='-COMBO_XAI_METHOD-')

#create the XAI output element
text_xai_output = ['XAI output',
             'On the left hand side in the image:',
               '1: The point has the highest likelihood of positively impacting the AI decisions',
               '0: The point has no impact on the AI decisions'];
combo_ele_xai_output = psg.Combo(text_xai_output, 
         readonly=True,
        size=(20,1),
        default_value = 'XAI output',
          key='-COMBO_XAI_OP-')

#create the prominent leads element
text_prom_leads = ['Prominent leads',
             '1: V3',
               '2: aVF',
               '3: V1'];
combo_ele_prom_leads = psg.Combo(text_prom_leads, 
         readonly=True,
        size=(20,1),
        default_value = 'XAI output',
          key='-COMBO_PROM_LEADS-')

#down_button_pred_score = psg.RealtimeButton(sg.SYMBOL_DOWN, key='-DOWN-PRED-SCORE-')
#down_button_XAI_op = psg.RealtimeButton(sg.SYMBOL_DOWN, key='-DOWN-XAI-OP-')
#down_button_xai_meth = psg.RealtimeButton(sg.SYMBOL_DOWN, key='-DOWN-XAI-METH-')
yes_button = psg.Yes(" v ", key='-YES-BUTTON-')

multiline_ele_pred = psg.Multiline(default_text = "", disabled=True, key="-MULTILINE-PRED-")

  0%|          | 0/20 [00:00<?, ?it/s]

In [6]:
import PySimpleGUIWeb as sg

# Basic example of PSGWeb

def main():
    #layout = [
     #   [tbl1, image_ele, combo_ele_xai_output],
     #   [combo_ele_pred_score, combo_ele_xai_method, combo_ele_prom_leads],
     #   [sg.Ok(), sg.Cancel()]
    #]
    yes_button_toggle = 0
    
    column1 = [[tbl1],[combo_ele_pred_score, yes_button],[sg.Text(size=(10,1), key='-TEXT_PRED_SCORE-', justification='l', pad=(0,0))],[multiline_ele_pred]]
    column2 = [[image_ele], [combo_ele_xai_method]]
    column3 = [[combo_ele_prom_leads],[combo_ele_xai_output]]
    layout = [
        [sg.Column(column1), sg.Column(column2), sg.Column(column3)],
        [sg.Ok(), sg.Cancel(), sg.Button()]
    ]
    
    window = sg.Window('Demo window..', layout)
    i = 0
    while True:
        event, values = window.read(timeout=1)
        if event != sg.TIMEOUT_KEY:
            print(event, values)
        if event == "-YES-BUTTON-" :
            if yes_button_toggle==0:
                window['-TEXT_PRED_SCORE-'].update(text_pred_score)
                window['-MULTILINE-PRED-'].update(text_pred_score, disabled=False)
                yes_button_toggle=1
            elif yes_button_toggle==1:
                window['-TEXT_PRED_SCORE-'].update("")
                window['-MULTILINE-PRED-'].update(text_pred_score, disabled=True)
                yes_button_toggle=0
        if event is None:
            break
        i += 1
    window.close()

main()
print('Program terminating normally')

-YES-BUTTON- {'-TABLE-': [None], '-COMBO_PRED-': 'Prediction score: 1', '-MULTILINE-PRED-': '', '-COMBO_XAI_METHOD-': 'XAI method: LIME', '-COMBO_PROM_LEADS-': None, '-COMBO_XAI_OP-': 'XAI output'}


AttributeError: 'list' object has no attribute 'replace'