In [1]:
import glob
import numpy as np
import os
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms
import cv2
import torch.nn.functional as F
from lime import lime_image
from skimage.segmentation import mark_boundaries
import torch_directml
import matplotlib.image as imgs
from transformers import BeitImageProcessor, BeitModel

In [2]:
# DirectML device handle. The model and every input batch built in this
# notebook are moved onto this device with `.to(dml)`.
dml = torch_directml.device()

In [3]:
class BEiTLSTM(nn.Module):
    """BEiT encoder followed by an LSTM and a two-layer linear classifier.

    The encoder's ``last_hidden_state`` is treated as a sequence and fed to a
    unidirectional, batch-first LSTM. The LSTM output at the last sequence
    position is mapped to class logits by ``fc``.

    Args:
        input_size: Feature width of each sequence element given to the LSTM;
            must equal the last dimension of the encoder's
            ``last_hidden_state``.
        hidden_size: Width of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(BEiTLSTM, self).__init__()
        self.beit_model = BeitModel.from_pretrained("microsoft/beit-base-patch16-224")
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_classes = num_classes

        self.lstm = nn.LSTM(self.input_size, self.hidden_size, self.num_layers,
                            batch_first=True, bidirectional=False)
        self.fc = nn.Sequential(nn.Linear(self.hidden_size, 150),
                                nn.Linear(150, self.num_classes))

    def forward(self, x):
        """Return unnormalised class logits of shape (batch, num_classes).

        Args:
            x: Input batch passed straight to the BEiT encoder.
        """
        outputs = self.beit_model(x).last_hidden_state

        # Allocate the initial LSTM states next to the encoder output instead
        # of on a notebook-global device, so the module works on whatever
        # device (and dtype) it has been moved to.
        state_shape = (self.num_layers, outputs.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, device=outputs.device, dtype=outputs.dtype)
        c0 = torch.zeros(state_shape, device=outputs.device, dtype=outputs.dtype)

        # out: (batch, seq_length, hidden_size) because batch_first=True.
        out, _ = self.lstm(outputs, (h0, c0))

        # Classify from the LSTM output at the last sequence position.
        out = self.fc(out[:, -1, :])
        return out
    

# Hyperparameters of the LSTM head. `input_size` has to match the feature
# width of the encoder output that BEiTLSTM feeds into its LSTM.
input_size, hidden_size = 768, 300
num_layers, num_classes = 2, 2

# Build the network, then place it on the DirectML device.
base_model = BEiTLSTM(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    num_classes=num_classes,
)
base_model = base_model.to(dml)

In [4]:
model_path = "../trained-models/beitlstmallnew.pt"

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints that come from a trusted source.
state_dict = torch.load(model_path, map_location=torch.device('cpu'))

# strict=False tolerates extra keys stored in the checkpoint, but it would
# also silently accept *missing* weights and leave those layers at their
# initial values. Fail loudly in that case instead.
load_result = base_model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    raise RuntimeError(
        f"Checkpoint {model_path!r} is missing weights for: {load_result.missing_keys}"
    )

# Keep the load report as the cell output, as before.
load_result

_IncompatibleKeys(missing_keys=[], unexpected_keys=['beit_model.encoder.layer.0.attention.attention.relative_position_bias.relative_position_index', 'beit_model.encoder.layer.1.attention.attention.relative_position_bias.relative_position_index', 'beit_model.encoder.layer.2.attention.attention.relative_position_bias.relative_position_index', 'beit_model.encoder.layer.3.attention.attention.relative_position_bias.relative_position_index', 'beit_model.encoder.layer.4.attention.attention.relative_position_bias.relative_position_index', 'beit_model.encoder.layer.5.attention.attention.relative_position_bias.relative_position_index', 'beit_model.encoder.layer.6.attention.attention.relative_position_bias.relative_position_index', 'beit_model.encoder.layer.7.attention.attention.relative_position_bias.relative_position_index', 'beit_model.encoder.layer.8.attention.attention.relative_position_bias.relative_position_index', 'beit_model.encoder.layer.9.attention.attention.relative_position_bias.rela

In [5]:
# torchvision pipeline: array -> PIL image -> 224x224 -> tensor.
# NOTE(review): not referenced by any other cell visible here; `my_transform`
# relies on `image_processor` instead. Confirm before removing.
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

# Preprocessor loaded from the same checkpoint name as the BEiT encoder.
image_processor = BeitImageProcessor.from_pretrained("microsoft/beit-base-patch16-224")

def my_transform(image):
    """Preprocess a single image array with the BEiT image processor.

    The array is wrapped in a PIL image, run through `image_processor`, and
    the resulting `pixel_values` tensor is returned with all size-1
    dimensions (including the batch dimension of one) squeezed away.

    NOTE(review): `Image.fromarray` treats a 3-channel array as RGB, while the
    callers in this notebook pass `cv2.imread` output, which is BGR. Confirm
    this matches how the model was trained.
    """
    encoded = image_processor(Image.fromarray(image), return_tensors="pt")
    return encoded.pixel_values.squeeze()

In [6]:
def batch_predict(images):
    """Return class probabilities for a sequence of image arrays.

    Each image is preprocessed with `my_transform`, the results are stacked
    into one batch, moved to the `dml` device and passed through `base_model`
    in eval mode. This is the classifier function handed to LIME.

    Args:
        images: Sequence of image arrays accepted by `my_transform`.

    Returns:
        NumPy array of shape (len(images), num_classes) holding the softmax
        probabilities. An empty input yields an empty (0, num_classes) array.
    """
    # torch.stack rejects an empty sequence, so answer that case directly.
    if len(images) == 0:
        return np.empty((0, base_model.num_classes), dtype=np.float32)

    base_model.eval()
    batch = torch.stack(tuple(my_transform(i) for i in images), dim=0)
    batch = batch.to(dml)

    # Inference only: skip autograd bookkeeping so the many perturbation
    # batches LIME sends through here do not build computation graphs.
    with torch.no_grad():
        logits = base_model(batch)
        probs = F.softmax(logits, dim=1)
    return probs.cpu().numpy()

In [7]:
def explain_image(image, num_samples=1000, num_features=5, batch_size=20, random_seed=240):
    """Build a LIME explanation overlay for one image.

    Runs `LimeImageExplainer.explain_instance` with `batch_predict` as the
    classifier, takes the image and mask for the first entry of
    `explanation.top_labels`, and draws the mask boundaries on the image.

    Args:
        image: Image to explain; converted with `np.array` before use.
        num_samples: Number of perturbed samples LIME generates.
        num_features: Number of superpixels to include in the mask.
        batch_size: Number of perturbed images per `batch_predict` call.
        random_seed: Seed forwarded to LIME.

    Returns:
        The array produced by `mark_boundaries` for the explained image.
    """
    explainer = lime_image.LimeImageExplainer()
    explanation = explainer.explain_instance(
        np.array(image),
        batch_predict,  # classification function
        hide_color=0,
        batch_size=batch_size,
        random_seed=random_seed,
        num_samples=num_samples,
    )

    # positive_only=False keeps both supporting and opposing superpixels.
    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0],
        positive_only=False,
        num_features=num_features,
        hide_rest=False,
    )
    # Scale to [0, 1] for mark_boundaries; assumes pixel values in 0-255
    # (TODO confirm for non-uint8 inputs).
    return mark_boundaries(temp / 255.0, mask)

In [8]:
image_path = glob.glob("../images/misclassified-new/beitlstm/*")
saving_path = "../images/lime/beit-lstm"

# Make sure the output directory exists before the first save.
os.makedirs(saving_path, exist_ok=True)

# NOTE(review): cv2.imread returns BGR arrays, while the PIL/matplotlib calls
# downstream interpret 3-channel arrays as RGB. Left unchanged here because
# converting would alter what the model sees; confirm against training.
for path in image_path:
    img = cv2.imread(path)
    if img is None:
        # cv2.imread signals an unreadable or non-image file by returning None.
        print(f"Skipping unreadable image: {path}")
        continue

    # os.path.basename handles the platform's separators; splitting on "\\"
    # only worked on Windows and left the whole path in place elsewhere.
    img_name = os.path.basename(path)
    masked_image = explain_image(img)
    imgs.imsave(os.path.join(saving_path, img_name), masked_image)

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [44]:
batch_size = 10
total_img = len(image_path)
rem_img = total_img % batch_size  # size of the final partial batch (0 if none)
predictions = []

# Walk the paths in steps of `batch_size`. The slice shortens itself for the
# final partial batch, so no separate remainder pass is needed. This also
# covers fewer than `batch_size` images (previously skipped entirely) and
# exact multiples (previously an empty trailing batch that crashed).
for start in range(0, total_img, batch_size):
    batch = []
    for img_path in image_path[start:start + batch_size]:
        img = cv2.imread(img_path)
        if img is None:
            # Fail with the offending path rather than deep inside preprocessing.
            raise ValueError(f"Could not read image: {img_path}")
        batch.append(img)
    predictions.extend(np.argmax(batch_predict(batch), axis=1))