In [1]:
import numpy as np
from torch.utils.data import Dataset, DataLoader
import h5py
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import models
import os
import matplotlib.pyplot as plt
from torch.optim import Adam
from tqdm import tqdm_notebook

print(torch.__version__)
device = torch.device('cuda:1')

1.2.0


In [2]:
class Weedread(Dataset):
    def __init__(self, name, transform=None):
        hf = h5py.File(name, 'r')
        input_images = np.array(hf.get('data'), np.uint8)
        target_labels = np.array(hf.get('labels')).astype(np.long)
        family_index = np.where(target_labels[:, 0] == 0)
        
        self.input_images = input_images[family_index]
        self.target_labels = target_labels[family_index]
        self.transform = transform
        hf.close()

    def __len__(self):
        return self.input_images.shape[0]

    def __getitem__(self, idx):
        images = self.input_images[idx]
        classes = self.target_labels[idx][1]
        family =  self.target_labels[idx][0]
        if self.transform is not None:
            images = self.transform(images)
        images = images
        
        return images, classes, family

In [3]:
INPUT_CHANNEL = 3
BATCH_SIZE = 32
normalize = transforms.Compose([
    #transforms.ToPILImage(),
    #transforms.Resize((96,96)),
    transforms.ToTensor(),
    #transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

imagenet_classes = range(1, 22)

data_path = os.path.dirname(os.getcwd()) + "/data/weed/"
Train_data = Weedread(data_path + "train.h5", transform=normalize)
Test_data = Weedread(data_path + "val.h5", transform=normalize)

Train_dataloader = DataLoader(dataset=Train_data,
                              batch_size = BATCH_SIZE,
                              shuffle=True)
Test_dataloader = DataLoader(dataset=Test_data,
                              batch_size = BATCH_SIZE,
                              shuffle=False)

In [4]:
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

class My_Model(nn.Module):
    def __init__(self, input_channel=1, num_class=21, num_family = 5):
        super(My_Model, self).__init__()
        model = models.resnet18(pretrained=True)
        self.model_ft = torch.nn.Sequential(*(list(model.children())[:-3]))
        set_parameter_requires_grad(self.model_ft, False)
        
        self.average_pool = nn.AdaptiveAvgPool2d(output_size=4)
        self.family_fc = nn.Linear(4096, num_family)
        self.class_fc = nn.Linear(4096, num_class)
        
    
    def forward(self, x):
        # Perform the usual forward pass
        x = self.model_ft(x)
        x = self.average_pool(x)
        x = torch.flatten(x, 1)
        x_class = self.class_fc(x)
        x_family = self.family_fc(x)
        return F.softmax(x_class, dim=1), F.softmax(x_family, dim=1)

In [5]:
class _loss(nn.Module):
    def __init__(self, alpha=0.5):
        super(_loss, self).__init__()
        self.alpha = alpha
        self.class_loss = nn.CrossEntropyLoss()
        self.family_loss = nn.CrossEntropyLoss()
    def forward(self, predicted_class, true_class, predicted_family, true_family):
        return self.alpha * self.class_loss(predicted_class, true_class) + \
                (1-self.alpha)*self.family_loss(predicted_family, true_family)

In [6]:
from torchsummary import summary
train_images, _, _ = next(iter(Test_dataloader))


_model = My_Model(num_class = 11, num_family = 1)
summary(_model, input_size= train_images[0].size(), device="cpu")
_model = _model.to(device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 64, 64]           9,408
       BatchNorm2d-2           [-1, 64, 64, 64]             128
              ReLU-3           [-1, 64, 64, 64]               0
         MaxPool2d-4           [-1, 64, 32, 32]               0
            Conv2d-5           [-1, 64, 32, 32]          36,864
       BatchNorm2d-6           [-1, 64, 32, 32]             128
              ReLU-7           [-1, 64, 32, 32]               0
            Conv2d-8           [-1, 64, 32, 32]          36,864
       BatchNorm2d-9           [-1, 64, 32, 32]             128
             ReLU-10           [-1, 64, 32, 32]               0
       BasicBlock-11           [-1, 64, 32, 32]               0
           Conv2d-12           [-1, 64, 32, 32]          36,864
      BatchNorm2d-13           [-1, 64, 32, 32]             128
             ReLU-14           [-1, 64,

In [7]:
from tqdm import tqdm_notebook

optimizer = torch.optim.SGD(_model.parameters(), lr=0.01)
criterion = _loss(alpha = 1)
EPOCHS = 100

max_correct = 0
#_model.load_state_dict(torch.load('epochs/epoch-weed.pt'), strict=False)
for epoch in range(1, EPOCHS + 1):
    #training
    _model.train()
    for image, classes, family in tqdm_notebook(Train_dataloader):
        image, classes, family = image.to(device), classes.to(device), family.to(device)
        image = image.float()
        p_classes, p_family = _model(image)
        loss = criterion(p_classes, classes, p_family, family)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('Loss :{:.4f} Epoch[{}/{}]'.format(loss.item(), epoch, EPOCHS))
    #testing
    _model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for image, classes, family in (Test_dataloader):
            image, classes, family = image.to(device), classes.to(device), family.to(device)
            image = image.float()
            p_classes, p_family = _model(image)
            predicted = torch.argmax(p_classes,dim=1)
            total += image.size(0)
            correct += (predicted == classes).sum().item()
        print('Test Accuracy of the model on the test images: {:.4f} %'.format(100 * correct / total))
    if(correct > max_correct):
        max_correct = correct
        torch.save(_model.state_dict(), 'epochs/ResNet-class-model1.pt')

HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :2.0328 Epoch[1/100]
Test Accuracy of the model on the test images: 50.3888 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :2.0656 Epoch[2/100]
Test Accuracy of the model on the test images: 62.1313 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.7792 Epoch[3/100]
Test Accuracy of the model on the test images: 71.0340 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.8046 Epoch[4/100]
Test Accuracy of the model on the test images: 72.7304 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.9284 Epoch[5/100]
Test Accuracy of the model on the test images: 73.5916 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.9854 Epoch[6/100]
Test Accuracy of the model on the test images: 74.0985 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.8055 Epoch[7/100]
Test Accuracy of the model on the test images: 77.5490 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.7235 Epoch[8/100]
Test Accuracy of the model on the test images: 80.0634 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.7510 Epoch[9/100]
Test Accuracy of the model on the test images: 80.8151 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.7806 Epoch[10/100]
Test Accuracy of the model on the test images: 81.4171 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6611 Epoch[11/100]
Test Accuracy of the model on the test images: 81.9412 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6638 Epoch[12/100]
Test Accuracy of the model on the test images: 82.4453 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6747 Epoch[13/100]
Test Accuracy of the model on the test images: 82.4251 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.7578 Epoch[14/100]
Test Accuracy of the model on the test images: 82.7621 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6299 Epoch[15/100]
Test Accuracy of the model on the test images: 82.6671 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6593 Epoch[16/100]
Test Accuracy of the model on the test images: 82.9378 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6998 Epoch[17/100]
Test Accuracy of the model on the test images: 82.5893 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6554 Epoch[18/100]
Test Accuracy of the model on the test images: 83.1279 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.7919 Epoch[19/100]
Test Accuracy of the model on the test images: 82.9724 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6950 Epoch[20/100]
Test Accuracy of the model on the test images: 83.2892 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5816 Epoch[21/100]
Test Accuracy of the model on the test images: 83.2373 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.7338 Epoch[22/100]
Test Accuracy of the model on the test images: 83.3583 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.7500 Epoch[23/100]
Test Accuracy of the model on the test images: 83.4361 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6591 Epoch[24/100]
Test Accuracy of the model on the test images: 83.2200 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6421 Epoch[25/100]
Test Accuracy of the model on the test images: 83.4706 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6198 Epoch[26/100]
Test Accuracy of the model on the test images: 83.5081 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6586 Epoch[27/100]
Test Accuracy of the model on the test images: 83.3554 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6193 Epoch[28/100]
Test Accuracy of the model on the test images: 83.4447 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.7179 Epoch[29/100]
Test Accuracy of the model on the test images: 83.5397 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6610 Epoch[30/100]
Test Accuracy of the model on the test images: 89.6112 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5826 Epoch[31/100]
Test Accuracy of the model on the test images: 90.1210 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5824 Epoch[32/100]
Test Accuracy of the model on the test images: 90.2189 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5819 Epoch[33/100]
Test Accuracy of the model on the test images: 90.3053 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5816 Epoch[34/100]
Test Accuracy of the model on the test images: 90.5069 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6190 Epoch[35/100]
Test Accuracy of the model on the test images: 90.4320 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5431 Epoch[36/100]
Test Accuracy of the model on the test images: 90.5069 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5790 Epoch[37/100]
Test Accuracy of the model on the test images: 90.5012 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5888 Epoch[38/100]
Test Accuracy of the model on the test images: 90.3859 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6575 Epoch[39/100]
Test Accuracy of the model on the test images: 90.6624 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6386 Epoch[40/100]
Test Accuracy of the model on the test images: 90.5818 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5879 Epoch[41/100]
Test Accuracy of the model on the test images: 90.7028 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5816 Epoch[42/100]
Test Accuracy of the model on the test images: 90.6653 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6190 Epoch[43/100]
Test Accuracy of the model on the test images: 90.5933 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5433 Epoch[44/100]
Test Accuracy of the model on the test images: 90.7748 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6181 Epoch[45/100]
Test Accuracy of the model on the test images: 90.6740 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5825 Epoch[46/100]
Test Accuracy of the model on the test images: 90.7143 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6941 Epoch[47/100]
Test Accuracy of the model on the test images: 90.6797 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6177 Epoch[48/100]
Test Accuracy of the model on the test images: 90.7028 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6050 Epoch[49/100]
Test Accuracy of the model on the test images: 90.8180 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5434 Epoch[50/100]
Test Accuracy of the model on the test images: 90.7402 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6551 Epoch[51/100]
Test Accuracy of the model on the test images: 90.7604 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6201 Epoch[52/100]
Test Accuracy of the model on the test images: 90.7546 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5863 Epoch[53/100]
Test Accuracy of the model on the test images: 90.6250 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6171 Epoch[54/100]
Test Accuracy of the model on the test images: 90.7431 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5799 Epoch[55/100]
Test Accuracy of the model on the test images: 90.8237 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6191 Epoch[56/100]
Test Accuracy of the model on the test images: 91.0196 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5818 Epoch[57/100]
Test Accuracy of the model on the test images: 90.7805 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6187 Epoch[58/100]
Test Accuracy of the model on the test images: 90.8842 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5814 Epoch[59/100]
Test Accuracy of the model on the test images: 90.8957 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6252 Epoch[60/100]
Test Accuracy of the model on the test images: 90.7488 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5815 Epoch[61/100]
Test Accuracy of the model on the test images: 90.8641 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6200 Epoch[62/100]
Test Accuracy of the model on the test images: 90.8727 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5431 Epoch[63/100]
Test Accuracy of the model on the test images: 90.8093 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5810 Epoch[64/100]
Test Accuracy of the model on the test images: 90.9188 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5498 Epoch[65/100]
Test Accuracy of the model on the test images: 94.6429 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5529 Epoch[66/100]
Test Accuracy of the model on the test images: 94.8934 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5465 Epoch[67/100]
Test Accuracy of the model on the test images: 95.2362 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5820 Epoch[68/100]
Test Accuracy of the model on the test images: 95.0518 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6194 Epoch[69/100]
Test Accuracy of the model on the test images: 95.1642 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5435 Epoch[70/100]
Test Accuracy of the model on the test images: 95.0864 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5433 Epoch[71/100]
Test Accuracy of the model on the test images: 95.2506 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5440 Epoch[72/100]
Test Accuracy of the model on the test images: 95.2909 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5496 Epoch[73/100]
Test Accuracy of the model on the test images: 95.4435 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5432 Epoch[74/100]
Test Accuracy of the model on the test images: 95.4695 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5433 Epoch[75/100]
Test Accuracy of the model on the test images: 95.3917 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5475 Epoch[76/100]
Test Accuracy of the model on the test images: 95.5789 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5431 Epoch[77/100]
Test Accuracy of the model on the test images: 95.3975 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5454 Epoch[78/100]
Test Accuracy of the model on the test images: 95.6106 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6178 Epoch[79/100]
Test Accuracy of the model on the test images: 95.5760 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5815 Epoch[80/100]
Test Accuracy of the model on the test images: 95.5789 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6181 Epoch[81/100]
Test Accuracy of the model on the test images: 95.6768 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5431 Epoch[82/100]
Test Accuracy of the model on the test images: 95.7172 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5431 Epoch[83/100]
Test Accuracy of the model on the test images: 95.6509 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5431 Epoch[84/100]
Test Accuracy of the model on the test images: 95.6336 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.6029 Epoch[85/100]
Test Accuracy of the model on the test images: 95.2362 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5813 Epoch[86/100]
Test Accuracy of the model on the test images: 95.7402 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5818 Epoch[87/100]
Test Accuracy of the model on the test images: 95.5559 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5811 Epoch[88/100]
Test Accuracy of the model on the test images: 95.7172 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5432 Epoch[89/100]
Test Accuracy of the model on the test images: 95.6567 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5833 Epoch[90/100]
Test Accuracy of the model on the test images: 95.6682 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5933 Epoch[91/100]
Test Accuracy of the model on the test images: 95.7661 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5583 Epoch[92/100]
Test Accuracy of the model on the test images: 95.2362 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5445 Epoch[93/100]
Test Accuracy of the model on the test images: 95.6048 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5451 Epoch[94/100]
Test Accuracy of the model on the test images: 95.6567 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5816 Epoch[95/100]
Test Accuracy of the model on the test images: 95.6855 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5438 Epoch[96/100]
Test Accuracy of the model on the test images: 95.7028 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5951 Epoch[97/100]
Test Accuracy of the model on the test images: 95.3831 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5435 Epoch[98/100]
Test Accuracy of the model on the test images: 95.7488 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5445 Epoch[99/100]
Test Accuracy of the model on the test images: 95.8093 %


HBox(children=(IntProgress(value=0, max=3255), HTML(value='')))


Loss :1.5457 Epoch[100/100]
Test Accuracy of the model on the test images: 95.6164 %
