In [1]:
import numpy as np
from torch.utils.data import Dataset, DataLoader
import h5py
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import models
import os
import matplotlib.pyplot as plt
from torch.optim import Adam
from tqdm import tqdm_notebook
from torchsummary import summary

print(torch.__version__)
device = torch.device('cuda:0')

1.2.0


In [2]:
class Weedread(Dataset):
    def __init__(self, name, transform=None, cl=None):
        hf = h5py.File(name, 'r')
        input_images = np.array(hf.get('data'), np.uint8)
        target_labels = np.array(hf.get('labels')).astype(np.long)
        if(cl == None):
            self.input_images = input_images
            self.target_labels = target_labels
        else:
            family_index = np.where(target_labels[:, 0] == cl)
            self.input_images = input_images[family_index]
            self.target_labels = target_labels[family_index]
            
        self.transform = transform
        hf.close()

    def __len__(self):
        return self.input_images.shape[0]

    def __getitem__(self, idx):
        images = self.input_images[idx]
        classes = self.target_labels[idx][1]
        family =  self.target_labels[idx][0]
        if self.transform is not None:
            images = self.transform(images)
        images = images
        
        return images, classes, family

In [3]:
INPUT_CHANNEL = 3
BATCH_SIZE = 32
normalize = transforms.Compose([
    #transforms.ToPILImage(),
    #transforms.Resize((96,96)),
    transforms.ToTensor(),
    #transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

data_path = os.path.dirname(os.getcwd()) + "/data/weed/"
Test_data = Weedread(data_path + "test.h5", transform=normalize)

Test_dataloader = DataLoader(dataset=Test_data,
                              batch_size = BATCH_SIZE,
                              shuffle=False)

In [4]:
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

class MobileNet(nn.Module):
    def __init__(self, input_channel=1, num_class=21, num_family = 5):
        super(MobileNet, self).__init__()
        model = torchvision.models.mobilenet_v2(pretrained=True)
        self.model_ft = torch.nn.Sequential(*(list(model.features.children())[:19]))
        set_parameter_requires_grad(self.model_ft, False)

        self.family_fc = nn.Linear(20480, num_family)
        self.class_fc = nn.Linear(20480, num_class)
        
    
    def forward(self, x):
        # Perform the usual forward pass
        x = self.model_ft(x)
        x = torch.flatten(x, 1)
        x_class = self.class_fc(x)
        x_family = self.family_fc(x)
        return F.softmax(x_class, dim=1), F.softmax(x_family, dim=1)
    
class MobileNet_16(nn.Module):
    def __init__(self, input_channel=1, num_class=21, num_family = 5):
        super(MobileNet_16, self).__init__()
        model = torchvision.models.mobilenet_v2(pretrained=True)
        self.model_ft = torch.nn.Sequential(*(list(model.features.children())[:16]))
        set_parameter_requires_grad(self.model_ft, False)

        self.family_fc = nn.Linear(2560, num_family)
        self.class_fc = nn.Linear(2560, num_class)
        
    
    def forward(self, x):
        # Perform the usual forward pass
        x = self.model_ft(x)
        x = torch.flatten(x, 1)
        x_class = self.class_fc(x)
        x_family = self.family_fc(x)
        return F.softmax(x_class, dim=1), F.softmax(x_family, dim=1)
    
class MobileNet_17(nn.Module):
    def __init__(self, input_channel=1, num_class=21, num_family = 5):
        super(MobileNet_17, self).__init__()
        model = torchvision.models.mobilenet_v2(pretrained=True)
        self.model_ft = torch.nn.Sequential(*(list(model.features.children())[:17]))
        set_parameter_requires_grad(self.model_ft, False)

        self.family_fc = nn.Linear(2560, num_family)
        self.class_fc = nn.Linear(2560, num_class)
        
    
    def forward(self, x):
        # Perform the usual forward pass
        x = self.model_ft(x)
        x = torch.flatten(x, 1)
        x_class = self.class_fc(x)
        x_family = self.family_fc(x)
        return F.softmax(x_class, dim=1), F.softmax(x_family, dim=1)
    
class MobileNet_15(nn.Module):
    def __init__(self, input_channel=1, num_class=21, num_family = 5):
        super(MobileNet_15, self).__init__()
        model = torchvision.models.mobilenet_v2(pretrained=True)
        self.model_ft = torch.nn.Sequential(*(list(model.features.children())[:15]))
        set_parameter_requires_grad(self.model_ft, False)

        self.family_fc = nn.Linear(2560, num_family)
        self.class_fc = nn.Linear(2560, num_class)
        
    
    def forward(self, x):
        # Perform the usual forward pass
        x = self.model_ft(x)
        x = torch.flatten(x, 1)
        x_class = self.class_fc(x)
        x_family = self.family_fc(x)
        return F.softmax(x_class, dim=1), F.softmax(x_family, dim=1)
    
class MobileNet_joint(nn.Module):
    def __init__(self, input_channel=1, num_class=21, num_family = 5):
        super(MobileNet_joint, self).__init__()
        model = torchvision.models.mobilenet_v2(pretrained=True)
        self.model_ft = torch.nn.Sequential(*(list(model.features.children())[:18]))
        print(self.model_ft)
        set_parameter_requires_grad(self.model_ft, False)
        self.average_pool = nn.AdaptiveAvgPool2d(output_size=2)

        self.family_fc = nn.Linear(2560, num_family)
        self.class_fc = nn.Linear(5120, num_class)
        
    
    def forward(self, x):
        # Perform the usual forward pass
        for name, module in self.model_ft._modules.items():
            x = module(x)
            if(name == "16"):
                x_family = x
                
        x = torch.flatten(x, 1)
        x_class = self.class_fc(x)
        
        x_family = torch.flatten(x_family, 1)
        x_family = self.family_fc(x_family)
        return F.softmax(x_class, dim=1), F.softmax(x_family, dim=1)

In [5]:
likelihood_matrix = torch.randn(5, 21).fill_(0).to(device)
likelihood_matrix[0, 0:11] = 1
likelihood_matrix[1, 11] = 1
likelihood_matrix[2, 12:14] = 1
likelihood_matrix[3, 14:19] = 1
likelihood_matrix[4, 19:21] = 1

# Case 1: Baseline Model

# Case 2: Family Prediction

# Case 3: Family 4

# Case 4: Family 5

# Case 4: Family 3

# Case 5: Family 1

# Case 2: Joint model

In [6]:
_model = MobileNet_joint(num_class = 21, num_family = 5)
_model = _model.to(device)
_model.load_state_dict(torch.load('epochs/Mobile-class-conditional-family1.pt'), strict=False)

def test(_model):
    _model.eval()
    with torch.no_grad():
        correct_family = 0
        correct_class = 0
        correct_update = 0
        total = 0
        for image, classes, family in tqdm_notebook(Test_dataloader):
            image, classes, family = image.to(device), classes.to(device), family.to(device)
            image = image.float()
            p_classes, p_family = _model(image)
            
            #Class - Conditional on Family
            max_value, max_index = torch.max(p_family, dim=1)
            F_YX = likelihood_matrix[max_index]
            F_X = max_value.view(max_value.size(0), 1)
            Y_X = p_classes
            Y_FX = (F_YX * Y_X ) / F_X
            
            predicted_family = torch.argmax(p_family, dim=1)
            predicted_class = torch.argmax(p_classes, dim=1)
            predicted_update = torch.argmax(Y_FX, dim=1)
            
            correct_class += (predicted_class == classes).sum().item()
            correct_family += (predicted_family == family).sum().item()
            correct_update += (predicted_update == classes).sum().item()
            
            total += image.size(0)
            
        print('Test Accuracy of the model on the test images (Class): {:.4f} %'.format(100 * correct_class / total))
        print('Test Accuracy of the model on the test images (Family): {:.4f} %'.format(100 * correct_family / total))
        print('Test Accuracy of the model on the test images (Update): {:.4f} %'.format(100 * correct_update / total))

test(_model)

Sequential(
  (0): ConvBNReLU(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU6(inplace=True)
  )
  (1): InvertedResidual(
    (conv): Sequential(
      (0): ConvBNReLU(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      )
      (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (2): InvertedResidual(
    (conv): Sequential(
      (0): ConvBNReLU(
        (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU6(inplace=True)
      

HBox(children=(IntProgress(value=0, max=1308), HTML(value='')))


Test Accuracy of the model on the test images (Class): 96.7689 %
Test Accuracy of the model on the test images (Family): 99.8064 %
Test Accuracy of the model on the test images (Update): 96.7569 %


# Case 3: Sperated model