In [None]:
import numpy as np
from torch.utils.data import Dataset, DataLoader
import h5py
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import models
import os
import matplotlib.pyplot as plt
from torch.optim import Adam
from PIL import Image
from tqdm import tqdm_notebook
import seaborn as sns

# Show which PyTorch build this notebook is running against.
print(torch.__version__)

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
class Weedread(Dataset):
    """In-memory dataset backed by an HDF5 file.

    The file is read once at construction time: the ``data`` entry is loaded
    as a ``uint8`` array of images and the ``labels`` entry as an ``int64``
    array. Each label row is indexed as ``[family, class]``.
    """

    def __init__(self, name, transform=None):
        """Load every image and label from the HDF5 file at ``name``.

        Args:
            name: Path of the HDF5 file to open read-only.
            transform: Optional callable applied to each image on access.
        """
        # The context manager closes the file even if reading raises.
        with h5py.File(name, 'r') as hf:
            self.input_images = np.array(hf.get('data'), np.uint8)
            # np.long is absent from NumPy 1.24-1.26 and has a
            # platform-dependent width where it exists; int64 is explicit.
            self.target_labels = np.array(hf.get('labels')).astype(np.int64)
        self.transform = transform

    def __len__(self):
        """Return the number of images (size of the first axis)."""
        return self.input_images.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, class_label, family_label)`` for sample ``idx``.

        The image is passed through ``self.transform`` when one was given.
        """
        images = self.input_images[idx]
        label_row = self.target_labels[idx]
        family = label_row[0]
        classes = label_row[1]
        if self.transform is not None:
            images = self.transform(images)

        return images, classes, family

In [None]:
# --- Notebook configuration ---------------------------------------------------
INPUT_CHANNEL = 3
BATCH_SIZE = 1

# Convert each image to a tensor, then normalise per channel.
# NOTE(review): the mean/std look like the usual ImageNet statistics for
# torchvision pretrained models — confirm they match the training pipeline.
normalize = transforms.Compose([
    # Disabled steps kept for reference:
    #transforms.ToPILImage(),
    #transforms.Resize((96,96)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Display names for the 21 class indices: "1" .. "21".
classes = [str(number) for number in range(1, 22)]

# Data lives in <parent of the working directory>/data/weed/.
data_path = "{}/data/weed/".format(os.path.dirname(os.getcwd()))
Train_data = Weedread(data_path + "train.h5", transform=normalize)
Test_data = Weedread(data_path + "val.h5", transform=normalize)

# Training loader is currently disabled:
# Train_dataloader = DataLoader(dataset=Train_data,
#                               batch_size = BATCH_SIZE,
#                               shuffle=True)

# Shuffled, so every fresh iterator starts from a random validation sample.
Test_dataloader = DataLoader(Test_data, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all parameters of ``model`` when ``feature_extracting`` is truthy.

    When ``feature_extracting`` is falsy the model is left untouched.
    Returns None in both cases.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

class My_Model(nn.Module):
    """Pretrained ResNet-18 trunk with two linear heads (class and family).

    ``input_channel`` is accepted by the constructor but not used.
    """

    def __init__(self, input_channel=1, num_class=21, num_family = 5):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Keep every child module of the ResNet except the last one.
        trunk_layers = list(backbone.children())[:-1]
        self.model_ft = nn.Sequential(*trunk_layers)
        # Called with feature_extracting=False, so no parameter is frozen here.
        set_parameter_requires_grad(self.model_ft, False)

        # Both heads read the flattened 512-wide trunk output.
        self.family_fc = nn.Linear(512, num_family)
        self.class_fc = nn.Linear(512, num_class)

    def forward(self, x):
        """Return ``(class_probabilities, family_probabilities)`` for batch ``x``.

        Both outputs are softmax-normalised along dimension 1.
        """
        features = torch.flatten(self.model_ft(x), 1)
        class_probs = F.softmax(self.class_fc(features), dim=1)
        family_probs = F.softmax(self.family_fc(features), dim=1)
        return class_probs, family_probs

In [None]:
from torchsummary import summary
# One batch from the test loader; only the commented-out summary call uses it.
train_images, _, _ = next(iter(Test_dataloader))

_model = My_Model(num_class=21, num_family = 5)
_model.to(device)
# map_location remaps the stored tensors onto the device chosen for this
# notebook, so a checkpoint saved on a GPU still loads on a CPU-only machine.
# NOTE(review): strict=False ignores missing/unexpected keys — confirm the
# checkpoint really matches this architecture.
_model.load_state_dict(
    torch.load('epochs/ResNet18-base-line.pt', map_location=device),
    strict=False,
)

print(_model)
#summary(_model, input_size= train_images[0].size())

In [None]:
def imshow(img, title):
    """Undo the per-channel normalisation of ``img`` and display it.

    ``img`` is a CPU tensor laid out channels-first with three channels; it is
    de-normalised, scaled to 0-255 ``uint8``, moved to channels-last and shown
    in a new matplotlib figure with ``title`` as its title.
    """
    # Per-channel statistics, shaped (3, 1, 1) to broadcast over the image.
    channel_std = np.array([0.229, 0.224, 0.225])[:, None, None]
    channel_mean = np.array([0.485, 0.456, 0.406])[:, None, None]

    # Invert (x - mean) / std, then convert to 8-bit channels-last pixels.
    restored = img.numpy() * channel_std + channel_mean
    pixels = (restored * 255).astype(np.uint8).transpose(1, 2, 0)

    # Figure width scales with the batch size.
    plt.figure(figsize=(BATCH_SIZE * 4, 4))
    plt.axis("off")
    plt.imshow(pixels)
    plt.title(title)
    plt.show()
    

def show_batch_images(dataloader):
    """Classify the first batch of ``dataloader`` and display it.

    Puts the global ``_model`` in eval mode, runs it on one batch, and shows
    the batch as an image grid titled with the predicted class names.

    Args:
        dataloader: Iterable yielding ``(images, class_labels, family_labels)``.

    Returns:
        ``(images, pred, cls)``: the input batch on the CPU, the predicted
        class indices, and the true class labels from the loader.
    """
    _model.eval()
    images, cls, _ = next(iter(dataloader))

    # Inference only: no_grad avoids building an autograd graph. The loader's
    # batch is kept as-is for plotting; only a copy is sent to the device.
    with torch.no_grad():
        outputs, _ = _model(images.to(device))
    _, pred = torch.max(outputs, 1)

    if images.is_cuda:
        images = images.cpu()

    # Tile the batch into a single image for display.
    img = torchvision.utils.make_grid(images)
    imshow(img, title=[classes[x.item()] for x in pred])

    return images, pred, cls

In [None]:
def plot_filters_single_channel_big(t):
    """Draw every kernel of the 4-D tensor ``t`` in one collated heatmap.

    Axes 0 and 2 of ``t`` are merged into one axis and axes 1 and 3 into the
    other; the merged 2-D array is transposed before being plotted.
    """
    print(t.shape[:])
    dim0, dim1, dim2, dim3 = t.shape
    nrows = dim0 * dim2
    ncols = dim1 * dim3

    # Bring the axes to be merged next to each other, then flatten to 2-D.
    mosaic = np.array(t.numpy(), np.float32).transpose((0, 2, 1, 3))
    mosaic = mosaic.reshape(nrows, ncols).T

    fig, ax = plt.subplots(figsize=(ncols/10, nrows/200))
    sns.heatmap(mosaic, xticklabels=False, yticklabels=False, cmap='Blues', ax=ax, cbar=False)
    
def plot_filters_single_channel(t):
    """Plot every 2-D kernel slice ``t[i, j]`` of a weight tensor in its own subplot.

    Each slice is standardised (zero mean, unit std), shifted by 0.5 and
    clamped to ``[0, 1]`` before being shown. Subplots are arranged ten per
    row and titled ``"i,j"``.

    Args:
        t: CPU tensor indexed as ``t[i, j]`` -> 2-D kernel.
    """
    n_outer, n_inner = t.shape[0], t.shape[1]
    nplots = n_outer * n_inner
    ncols = 10
    # Ceiling division: the former `nplots // ncols + 1` produced an extra
    # empty row whenever nplots was an exact multiple of ncols.
    nrows = max(1, (nplots + ncols - 1) // ncols)

    # Convert the whole tensor once instead of once per kernel.
    kernels = np.array(t.numpy(), np.float32)
    fig = plt.figure(figsize=(ncols, nrows), dpi=80)

    count = 0
    for i in range(n_outer):
        for j in range(n_inner):
            count += 1
            ax1 = fig.add_subplot(nrows, ncols, count)
            kernel = kernels[i, j]
            centred = kernel - np.mean(kernel)
            spread = np.std(kernel)
            # A constant kernel has zero spread; dividing would yield NaNs,
            # so it is rendered as a flat mid-level image instead.
            if spread > 0:
                npimg = centred / spread
            else:
                npimg = np.zeros_like(kernel)
            # Shift and clamp into the [0, 1] range used for display.
            npimg = np.clip(npimg + 0.5, 0, 1)
            ax1.imshow(npimg)
            ax1.set_title(str(i) + ',' + str(j))
            ax1.axis('off')
            ax1.set_xticklabels([])
            ax1.set_yticklabels([])

    plt.tight_layout()
    plt.show()
            

def plot_weights(model, layer_num, single_channel = True, collated = False):
    """Visualise the weights of one convolutional layer of ``model.model_ft``.

    Args:
        model: Object whose ``model_ft`` attribute is indexable by layer number.
        layer_num: Index of the layer inside ``model.model_ft``.
        single_channel: Plot each input channel of each filter separately.
            Only ``True`` is implemented.
        collated: With ``single_channel=True``, draw all kernels in a single
            heatmap instead of one subplot per kernel.

    Prints a message and returns without plotting when the layer is not an
    ``nn.Conv2d`` or when ``single_channel`` is false.
    """
    layer = model.model_ft[layer_num]

    if not isinstance(layer, nn.Conv2d):
        print("Can only visualize layers which are convolutional")
        return

    if not single_channel:
        # This combination used to fall through without any output, which
        # looked like a successful call that simply drew nothing.
        print("Multi-channel visualization is not implemented; use single_channel=True")
        return

    # Plotting helpers call .numpy(), so move the weights to the CPU first.
    weight_tensor = layer.weight.data.cpu()

    if collated:
        plot_filters_single_channel_big(weight_tensor)
    else:
        plot_filters_single_channel(weight_tensor)
        


In [None]:
# Collated heatmap of every kernel in the first layer of the backbone.
plot_weights(model=_model, layer_num=0, single_channel=True, collated=True)

In [None]:
# One subplot per kernel for the layer at index 1 of the backbone.
# NOTE(review): index 1 of a torchvision ResNet trunk is presumably a
# BatchNorm layer, in which case this only prints the "convolutional" notice
# — confirm the intended layer index.
plot_weights(model=_model, layer_num=1, single_channel=True, collated=False)

In [None]:
def _inference_device(model, image):
    """Return the device of ``model``'s first parameter, else ``image``'s device."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        # Plain callables, or modules without parameters, run where the input is.
        return image.device


def interpreter(model, image, label, occ_size = 50, occ_stride = 4, occ_pixel = 0):
    """Compute an occlusion-sensitivity heatmap for ``label``.

    A square patch of side ``occ_size`` filled with ``occ_pixel`` is slid over
    the image in steps of ``occ_stride``. For every position the model is run
    on the occluded copy and its score for ``label`` (first sample of the
    batch) is recorded.

    Args:
        model: Callable returning a pair; the first element is indexed as
            ``[0][label]``. It is not switched to eval mode here.
        image: 4-D tensor whose last two axes are (height, width). It is not
            modified.
        label: Index of the score to record.
        occ_size: Side length of the occluding patch, in pixels.
        occ_stride: Step between successive patch positions, in pixels.
        occ_pixel: Value written into the occluded region.

    Returns:
        A CPU tensor of shape ``(rows, cols)`` where ``heatmap[i, j]`` is the
        score with the patch's top-left corner at row ``i * occ_stride`` and
        column ``j * occ_stride``. It is empty when the patch does not fit
        strictly inside the image.
    """
    # The last two axes are (height, width). Reading them in the opposite
    # order made the heatmap come out transposed relative to the image.
    height, width = image.shape[-2], image.shape[-1]

    # Number of positions per axis; clamped so an oversized patch yields an
    # empty heatmap instead of a negative dimension.
    output_height = max(0, int(np.ceil((height - occ_size) / occ_stride)))
    output_width = max(0, int(np.ceil((width - occ_size) / occ_stride)))

    heatmap = torch.zeros((output_height, output_width))
    target_device = _inference_device(model, image)

    # Inference only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        # Iterate only over valid positions; every window below lies strictly
        # inside the image, so no bounds clamping or skipping is required.
        for row in range(output_height):
            h_start = row * occ_stride
            h_end = h_start + occ_size
            for col in range(output_width):
                w_start = col * occ_stride
                w_end = w_start + occ_size

                # Work on a copy so the caller's image stays intact.
                input_image = image.clone().detach()
                input_image[:, :, h_start:h_end, w_start:w_end] = occ_pixel

                # Run inference on the occluded image.
                output, _ = model(input_image.to(target_device))
                heatmap[row, col] = output[0][label].item()

    return heatmap

In [None]:
# Classify one random test batch and show it with the predicted label.
images, pred, cls = show_batch_images(Test_dataloader)
# The label is printed shifted by one — presumably to match the one-based
# names in `classes`; confirm.
print("true label: ", cls + 1)

# Occlusion sensitivity of the true class for the first image in the batch.
heatmap = interpreter(
    model=_model,
    image=images,
    label=cls[0].item(),
    occ_size=60,
    occ_stride=2,
)

# Cap the colour scale at 1.
imgplot = sns.heatmap(heatmap, xticklabels=False, yticklabels=False, vmax=1)
figure = imgplot.get_figure()