In [2]:
import numpy as np
from torch.utils.data import Dataset, DataLoader
import h5py
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import models
import os
import matplotlib.pyplot as plt
from torch.optim import Adam
from tqdm import tqdm_notebook

print(torch.__version__)
device = torch.device('cuda:0')

1.2.0


In [3]:
class Weedread(Dataset):
    def __init__(self, name, transform=None):
        hf = h5py.File(name, 'r')
        self.input_images = np.array(hf.get('data'), np.uint8)
        self.target_labels = np.array(hf.get('labels')).astype(np.long)
        self.transform = transform
        hf.close()

    def __len__(self):
        return self.input_images.shape[0]

    def __getitem__(self, idx):
        images = self.input_images[idx]
        classes = self.target_labels[idx][1]
        family =  self.target_labels[idx][0]
        if self.transform is not None:
            images = self.transform(images)
        images = images
        
        return images, classes, family

In [4]:
INPUT_CHANNEL = 3
BATCH_SIZE = 32
normalize = transforms.Compose([
    #transforms.ToPILImage(),
    #transforms.Resize((96,96)),
    transforms.ToTensor(),
    #transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

imagenet_classes = range(1, 22)

data_path = os.path.dirname(os.getcwd()) + "/data/weed/"
Train_data = Weedread(data_path + "train.h5", transform=normalize)
Test_data = Weedread(data_path + "val.h5", transform=normalize)

Train_dataloader = DataLoader(dataset=Train_data,
                              batch_size = BATCH_SIZE,
                              shuffle=True)
Test_dataloader = DataLoader(dataset=Test_data,
                              batch_size = BATCH_SIZE,
                              shuffle=False)

In [5]:
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

class My_Model(nn.Module):
    def __init__(self, input_channel=1, num_class=21, num_family = 5):
        super(My_Model, self).__init__()
        model = models.resnet18(pretrained=True)
        self.model_ft = torch.nn.Sequential(*(list(model.children())[:-1]))
        set_parameter_requires_grad(self.model_ft, False)
        print(self.model_ft)
        self.average_pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.family_fc = nn.Linear(512, num_family)
        self.class_fc = nn.Linear(512, num_class)
        
    
    def forward(self, x):
        # Perform the usual forward pass
        x = self.model_ft(x)
        #x = self.average_pool(x)
        x = torch.flatten(x, 1)
        x_class = self.class_fc(x)
        x_family = self.family_fc(x)
        return F.softmax(x_class, dim=1), F.softmax(x_family, dim=1)

In [6]:
class _loss(nn.Module):
    def __init__(self, alpha=0.5):
        super(_loss, self).__init__()
        self.alpha = alpha
        self.class_loss = nn.CrossEntropyLoss()
        self.family_loss = nn.CrossEntropyLoss()
    def forward(self, predicted_class, true_class, predicted_family, true_family):
        return self.alpha * self.class_loss(predicted_class, true_class) + \
                (1-self.alpha)*self.family_loss(predicted_family, true_family)
        

In [7]:
from torchsummary import summary
train_images, _, _ = next(iter(Test_dataloader))

_model = My_Model(num_class = 21)
summary(_model, input_size= train_images[0].size(), device="cpu")
_model = _model.to(device)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Con

In [7]:
from tqdm import tqdm_notebook

optimizer = torch.optim.SGD(_model.parameters(), lr=0.01)
criterion = _loss(alpha = 0)
EPOCHS = 100

max_correct = 0
for epoch in range(1, EPOCHS + 1):
    #training
    _model.train()
    for image, classes, family in tqdm_notebook(Train_dataloader):
        image, classes, family = image.to(device), classes.to(device), family.to(device)
        image = image.float()
        p_classes, p_family = _model(image)
        loss = criterion(p_classes, classes, p_family, family)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('Loss :{:.4f} Epoch[{}/{}]'.format(loss.item(), epoch, EPOCHS))
    #testing
    _model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for image, classes, family in (Test_dataloader):
            image, classes, family = image.to(device), classes.to(device), family.to(device)
            image = image.float()
            p_classes, p_family = _model(image)
            predicted = torch.argmax(p_family,dim=1)
            total += image.size(0)
            correct += (predicted == family).sum().item()
        print('Test Accuracy of the model on the test images: {:.4f} %'.format(100 * correct / total))
    if(correct > max_correct):
        max_correct = correct
        torch.save(_model.state_dict(), 'epochs/ResNet18-family-model.pt')

HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :1.0156 Epoch[1/100]
Test Accuracy of the model on the test images: 94.5621 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9429 Epoch[2/100]
Test Accuracy of the model on the test images: 96.3763 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9436 Epoch[3/100]
Test Accuracy of the model on the test images: 96.4743 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9051 Epoch[4/100]
Test Accuracy of the model on the test images: 96.5102 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9824 Epoch[5/100]
Test Accuracy of the model on the test images: 96.5197 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9433 Epoch[6/100]
Test Accuracy of the model on the test images: 96.5628 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9433 Epoch[7/100]
Test Accuracy of the model on the test images: 96.5938 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9425 Epoch[8/100]
Test Accuracy of the model on the test images: 96.5675 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9433 Epoch[9/100]
Test Accuracy of the model on the test images: 96.5484 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9412 Epoch[10/100]
Test Accuracy of the model on the test images: 96.6010 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9049 Epoch[11/100]
Test Accuracy of the model on the test images: 96.6154 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :1.0202 Epoch[12/100]
Test Accuracy of the model on the test images: 96.6273 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9428 Epoch[13/100]
Test Accuracy of the model on the test images: 96.5938 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9050 Epoch[14/100]
Test Accuracy of the model on the test images: 96.5962 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9591 Epoch[15/100]
Test Accuracy of the model on the test images: 97.3898 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9062 Epoch[16/100]
Test Accuracy of the model on the test images: 98.8359 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9585 Epoch[17/100]
Test Accuracy of the model on the test images: 99.0893 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9049 Epoch[18/100]
Test Accuracy of the model on the test images: 99.1084 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9240 Epoch[19/100]
Test Accuracy of the model on the test images: 99.3307 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9054 Epoch[20/100]
Test Accuracy of the model on the test images: 99.3188 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9049 Epoch[21/100]
Test Accuracy of the model on the test images: 99.3522 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9049 Epoch[22/100]
Test Accuracy of the model on the test images: 99.4526 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9049 Epoch[23/100]
Test Accuracy of the model on the test images: 99.4741 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9049 Epoch[24/100]
Test Accuracy of the model on the test images: 99.4837 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9050 Epoch[25/100]
Test Accuracy of the model on the test images: 99.5100 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9063 Epoch[26/100]
Test Accuracy of the model on the test images: 99.5291 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9048 Epoch[27/100]
Test Accuracy of the model on the test images: 99.5745 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9048 Epoch[28/100]
Test Accuracy of the model on the test images: 99.5482 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9048 Epoch[29/100]
Test Accuracy of the model on the test images: 99.5554 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9048 Epoch[30/100]
Test Accuracy of the model on the test images: 99.5267 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9049 Epoch[31/100]
Test Accuracy of the model on the test images: 99.5626 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9049 Epoch[32/100]
Test Accuracy of the model on the test images: 99.5411 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9048 Epoch[33/100]
Test Accuracy of the model on the test images: 99.5435 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9059 Epoch[34/100]
Test Accuracy of the model on the test images: 99.5865 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))


Loss :0.9060 Epoch[35/100]
Test Accuracy of the model on the test images: 99.5554 %


HBox(children=(IntProgress(value=0, max=3922), HTML(value='')))

KeyboardInterrupt: 