[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/googlecolab/colabtools/blob/master/notebooks/colab-github-demo.ipynb)

In [None]:

# how to access GDrive https://colab.research.google.com/notebooks/io.ipynb#scrollTo=RWSJpsyKqHjH
from google.colab import files, drive
import os

# Mount Google Drive, then build the "My Drive" path from the mount point itself
# rather than from os.getcwd(), so the path stays valid if the working directory changes.
drive.mount('/content/drive')
gdir = os.path.join('/content/drive', "My Drive")


Mounted at /content/drive


## Dependencies

In [None]:
# Torch Dataset and IMNet Loading
import torch
from xml.etree import ElementTree
from torch.utils import data
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import transforms
from PIL import Image, ImageDraw
from  collections import OrderedDict
import numpy as np

# Model development and training
import torchvision.models as models
import torch.nn as nn

# Filesystem and parallelization
import os
import multiprocessing

# Utility 
import time
import datetime

# Device used for the model and the batches: the first CUDA GPU when available, otherwise the CPU
_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# from ImageNet
# Each table maps a class name to an ImageNet synset id.
# ImnetDataset uses the class name as the image sub-directory name and the synset id
# as the prefix (text before the first "_") of every image / annotation file name,
# so neither keys nor values can be renamed without renaming the data on disk.
tree_synsets = {
    "judas": "n12513613",
    "palm": "n12582231",
    "pine": "n11608250",
    "china tree": "n12741792",
    "fig": "n12401684",
    "cabbage": "n12478768",
    "cacao": "n12201580",
    "kapok": "n12190410",
    "iron": "n12317296",
    "linden": "n12202936",
    "pepper": "n12765115",
    "rain": "n11759853",
    "dita": "n11770256",
    "alder": "n12284262",
    "silk": "n11759404",
    "coral": "n12527738",
    "huisache": "n11757851",
    "fringe": "n12302071",
    "dogwood": "n12946849",
    "cork": "n12713866",
    "ginkgo": "n11664418",
    "golden shower": "n12492106",
    "balata": "n12774299",
    "baobab": "n12189987",
    "sorrel": "n12242409",
    "Japanese pagoda": "n12570394",
    "Kentucky coffee": "n12496427",
    "Logwood": "n12496949"
}
nontree_synsets = {
    # Non-tree classes, used as negative examples for the tree / non-tree label
    "garbage_bin": "n02747177",
    "carion_fungus": "n13040303",
    "basidiomycetous_fungus": "n13049953",
    "jelly_fungus": "n13060190",
    "desktop_computer": "n03180011",
    "laptop_computer": "n03642806",
    "cellphone": "n02992529",
    "desk": "n03179701",
    "station_wagon": "n02814533",
    "pickup_truck": "n03930630",
    "trailer_truck": "n04467665"
}
# Combined table: tree classes first, then non-tree classes
synsets = {**tree_synsets, **nontree_synsets}
print ("Device:" , _DEVICE)

Device: cuda:0



## Dataset Creation
Define datasets for the ImageNet and Greenstand sources. The Greenstand images do not have species labels yet.

In [None]:
class ImnetDataset(data.Dataset):
    """ImageNet-backed dataset yielding an image and its box / class / is-tree targets.

    Directory layout read by this class, relative to ``dir``:
      * ``original_images/<class name>/<synset id>_<n>.JPEG``
      * ``bounding_boxes/<class name>/Annotation/<synset id>/<synset id>_<n>.xml``

    Each item is ``(img, targets)`` where ``targets`` holds:
      * ``"boxes"``: float32 tensor ``[xmin, ymin, xmax, ymax]``; all zeros for non-trees,
        the whole image when no annotation is available
      * ``"image_class"``: class index tensor, or a one-hot vector when ``one_hot`` is set
      * ``"is_tree"``: ``1.0`` for tree classes, ``0.0`` otherwise

    NOTE(review): boxes are expressed in the pixel coordinates of the image as loaded from
    disk and are not rescaled when ``transforms`` resizes the image -- confirm this is
    what the consumer of ``"boxes"`` expects.
    """

    def __init__(self, dir, synsets, transforms=None, device=None, one_hot=False, nontrees=False):
        """Index the image names and pre-load the bounding boxes of the tree classes.

        Args:
            dir: root directory containing ``original_images`` and ``bounding_boxes``.
            synsets: mapping of class name -> synset id covering every class on disk.
            transforms: optional callable applied to each PIL image.
            device: stored on the instance; not used by this class itself.
            one_hot: return the class as a one-hot vector instead of an index.
            nontrees: when True use every class in ``synsets``; otherwise only the classes
                of the notebook-level ``tree_synsets`` table.
        """
        # the data directories
        self.img_dir = os.path.join(dir, "original_images")
        self.bb_dir = os.path.join(dir, "bounding_boxes")
        self.nontrees_present = nontrees
        # synsets library to get the associated class
        self.synsets = synsets if self.nontrees_present else tree_synsets
        self.rev_synsets = {synset_id: class_name for class_name, synset_id in synsets.items()}
        self.classes = list(self.synsets.keys())

        self.one_hot = one_hot
        self.transforms = transforms
        self.device = device

        self.imgs = []
        for class_name in self.classes:
            for file_name in sorted(os.listdir(os.path.join(self.img_dir, class_name))):
                # in every directory the "tar" file is still present
                if "tar" not in file_name:
                    self.imgs.append(file_name.split('.')[0])

        # Only annotations of tree classes are kept; a set makes the membership test O(1).
        tree_ids = set(tree_synsets.values())
        self.bb_dict = {}
        for folder, _, file_names in os.walk(self.bb_dir):
            for file_name in file_names:
                if os.path.splitext(file_name)[1] != ".xml":
                    continue
                if file_name.split("_")[0] not in tree_ids:
                    continue
                xml_path = os.path.join(folder, file_name)
                box = self._read_first_box(xml_path)
                # Annotations without a usable box are skipped, so __getitem__ falls back
                # to the whole-image box exactly as it does for images with no annotation.
                if box is not None:
                    self.bb_dict[xml_path] = box

    @staticmethod
    def _read_first_box(xml_path):
        """Return ``(xmin, ymin, xmax, ymax)`` of the first ``<object>`` in an annotation file.

        Returns ``None`` when the file has no ``<object>``, no ``<bndbox>`` or is missing
        one of the four coordinates, instead of raising ``AttributeError``.
        """
        root = ElementTree.parse(xml_path).getroot()
        obj = root.find("object")
        if obj is None:
            return None
        bndbox = obj.find("bndbox")
        if bndbox is None:
            return None
        coords = [bndbox.findtext(tag) for tag in ("xmin", "ymin", "xmax", "ymax")]
        if any(value is None for value in coords):
            return None
        return tuple(int(value) for value in coords)

    def __getitem__(self, idx):
        """Load image ``idx`` and build its target dictionary."""
        name = self.imgs[idx]
        synset_id = name.split("_")[0]
        label = self.rev_synsets[synset_id]
        # Everything is a tree unless non-tree classes were mixed in and this is one of them.
        is_tree = 1.0
        if self.nontrees_present and label not in tree_synsets:
            is_tree = 0.0

        img_path = os.path.join(self.img_dir, label, f"{name}.JPEG")
        bb_path = os.path.join(self.bb_dir, label, "Annotation", synset_id, f"{name}.xml")
        img = Image.open(img_path).convert("RGB")

        if not is_tree:
            # 0 out nontree bounding boxes, don't want predictions for these
            box = [0, 0, 0, 0]
        elif bb_path in self.bb_dict:
            box = list(self.bb_dict[bb_path])
        else:
            # the whole image is the bounding box label, as NoneType was causing collating issue.
            box = [0, 0, img.size[0], img.size[1]]
        boxes = torch.as_tensor(box, dtype=torch.float32)

        if self.transforms is not None:
            img = self.transforms(img)

        class_index = self.classes.index(label)
        if self.one_hot:
            image_id = torch.zeros(len(self.classes), dtype=torch.float32)
            image_id[class_index] = 1.0
        else:
            image_id = torch.tensor([class_index])

        targets = {
            "boxes": boxes,
            "image_class": image_id,
            "is_tree": is_tree,
        }
        return img, targets

    def __len__(self):
        """Number of indexed images."""
        return len(self.imgs)


In [None]:
class GreenstandDataset(data.Dataset):
  """
  Unlabeled Greenstand images found recursively under a directory.

  We don't have labels for this yet, so each item is just the (optionally transformed) image.
  """
  def __init__(self, dir, device, transforms, bb_dir=None, one_hot=False):
    """
    @param dir(str): root directory searched recursively for .jpg / .jpeg files
    @param device: stored on the instance; not used by this class itself
    @param transforms: callable applied to each PIL image, or None
    @param bb_dir(str): stored on the instance; no bounding boxes are read yet
    @param one_hot(bool): stored on the instance; no class labels exist yet
    """
    self.img_dir = dir
    self.bb_dir = bb_dir
    self.transforms = transforms
    self.classes = None # Change this when we define Greenstand class labels
    self.one_hot = one_hot
    self.device = device

    # Walk the directory that was passed in (not a notebook-level global) and match on
    # the file extension rather than on any "jpg" substring of the file name.
    self.imgs = []
    for folder, _, file_names in os.walk(self.img_dir):
      for file_name in file_names:
        if os.path.splitext(file_name)[1].lower() in (".jpg", ".jpeg"):
          self.imgs.append(os.path.join(folder, file_name))
    # os.walk order is filesystem dependent; sort so indices are stable between runs.
    self.imgs.sort()

  def __getitem__(self, idx):
    """Load image `idx` as RGB, apply the transforms if any, and return it."""
    img = Image.open(self.imgs[idx]).convert("RGB")
    if self.transforms is not None:
      img = self.transforms(img)
    return img

  def __len__(self):
    """Number of images found."""
    return len(self.imgs)

## MobileNet-v2
See [the original paper](https://arxiv.org/pdf/1801.04381.pdf) for details. MobileNet-v2 was chosen first because Torchvision provides pretrained weights and the network is quite low-latency, which may be useful for user-interface image selection. The first attempt simply replaces the output layer with one that predicts the 4 coordinates of the bounding box.

In [None]:
# Load torchvision's MobileNetV2 with its pretrained weights (downloaded on first use)
mobilenet = models.mobilenet_v2(pretrained=True)
# Preprocessing required for MobileNet v2:
# resize to 128x128, convert the PIL image to a tensor, then normalize each channel
# with the mean / std that torchvision documents for its ImageNet-pretrained models.
mobilenet_preprocessing = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])




Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth


HBox(children=(FloatProgress(value=0.0, max=14212972.0), HTML(value='')))




### Instantiate datasets, define loader processes

In [None]:
# Data locations on the mounted Drive, plus a per-day output directory for checkpoints
path = "/content/drive/My Drive/data/imnet"
test_path = "/content/drive/My Drive/data/test_greenstand_samples"
model_path = "/content/drive/My Drive/models/ImageNet/detection/mobilenet/%s"%datetime.datetime.today().date()
# exist_ok=True creates the directory when missing and is a no-op on re-runs
os.makedirs(model_path, exist_ok=True)



In [None]:
# Build both datasets and time the construction.
# ImnetDataset lists every class directory and parses the bounding box XML files up front,
# which is where the time goes.
start = time.time()
# nontrees=True mixes the non-tree classes in, so "is_tree" can be 0.0 as well as 1.0
data_set = ImnetDataset(path, synsets, transforms=mobilenet_preprocessing, one_hot=False, device=_DEVICE, nontrees=True)
greenstand_test = GreenstandDataset(test_path, transforms=mobilenet_preprocessing, one_hot=False, device=_DEVICE)

print ("Finished creating datasets in ", time.time() - start, " seconds ") # this can take ~60 minutes

Finished creating datasets in  2035.394990682602  seconds 


In [None]:

# Helper functions 
def rmse(x, y):
  '''
  Error between two tensors of the same batch: the square root of the summed
  squared differences, scaled by one over the size of x's first dimension.
  '''
  batch_scale = 1 / x.shape[0]
  summed_squared_error = (x - y).pow(2).sum()
  return torch.sqrt(batch_scale * summed_squared_error)

def iou(box_a, box_b):
  '''
  Intersection over union of two axis-aligned boxes.

  @param box_a: indexable of 4 numbers (or 0-dim tensors) in order xmin, ymin, xmax, ymax
  @param box_b: same layout as box_a
  @return: intersection area / union area, or 0.0 when the union is not positive
           (e.g. two all-zero boxes), instead of dividing by zero
  '''
  # order is xmin, ymin, xmax, ymax
  intersect_w = max(0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
  intersect_h = max(0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
  area_intersect = intersect_w * intersect_h

  # Clamp each side at 0 so an inverted box (xmax < xmin or ymax < ymin) counts as
  # empty rather than contributing a negative or spuriously positive area.
  area_a = max(0, box_a[2] - box_a[0]) * max(0, box_a[3] - box_a[1])
  area_b = max(0, box_b[2] - box_b[0]) * max(0, box_b[3] - box_b[1])
  union = area_a + area_b - area_intersect
  if union <= 0:
    return 0.0
  return area_intersect / union

class Customized_MobileNet(nn.Module):
  """
  Frozen pretrained backbone with two trainable heads: a single-logit binary
  classifier and a 4-value bounding box regressor.
  """
  def __init__(self, pretrained_model):
    """
    @param pretrained_model: backbone module; its `classifier` attribute is replaced by
                             an identity so it outputs features, and all of its
                             parameters are frozen
    """
    super().__init__()
    self.pretrained = pretrained_model
    self.pretrained.classifier = nn.Identity()
    for param in self.pretrained.parameters():
      param.requires_grad = False
    self._binary_classifier_layer()
    self._regressor_layer()


  def forward(self, x):
    """
    The model is performing a regression on bounding boxes and a binary classification.
    Returns (classifier output, regressor output).
    """
    # Run the backbone once and share the features between both heads
    features = self.pretrained(x)
    return self.classifier(features), self.regressor(features)

  def _binary_classifier_layer(self):
    """
    Initializes the final classification layer: dropout followed by a linear layer
    producing one output per image.
    """
    self.classifier = nn.Sequential(
                        nn.Dropout(0.2),
                        nn.Linear(1280, 1), # 1280 is num_outputs of the last feature layer
                      )
    for param in self.classifier.parameters():
      param.requires_grad = True

  def _regressor_layer(self):
    """
    A bounding box output layer for predicting object location
    This is currently designed to output exactly one bounding box
    """
    self.regressor =  nn.Sequential(
                        nn.Dropout(0.2),
                        nn.Linear(1280, 4) # 1280 is num_outputs of the last feature layer
                      )
    for param in self.regressor.parameters():
      param.requires_grad=True

class ModelTrainer():
  '''
  An abstraction to help keep track of model parameters and run training.

  On construction it splits the dataset into training and validation subsets, wraps
  them in DataLoaders, creates an Adam optimizer over the model's classifier and
  regressor parameters, and resumes from "checkpoint.pth.tar" under model_savepath
  when that file already exists.
  '''
  def __init__ (self, model, dataset, learning_rate, device, batch_size, model_savepath,
                gamma=1e-4, train_split=0.8, pin_memory=False, n_workers=0,
                alpha=0.5, beta=0.5
                ):
    '''
    @param model: module with `classifier` and `regressor` submodules whose forward pass
                  returns (is_tree_logits, box_preds), like Customized_MobileNet
    @param dataset: dataset yielding (image, targets) with targets["boxes"] and targets["is_tree"]
    @param learning_rate(float): Adam learning rate
    @param device: torch device the model and batches are moved to (None leaves them in place)
    @param batch_size(int): batch size of both data loaders
    @param model_savepath(str): directory that holds "checkpoint.pth.tar"
    @param gamma(float): Adam weight decay
    @param train_split(float): fraction of the dataset used for training
    @param pin_memory(bool): forwarded to both DataLoaders
    @param n_workers(int): number of DataLoader worker processes
    @param alpha(float): weight of the binary classification loss
    @param beta(float): weight of the bounding box regression loss
    '''
    self.model = model # like Customized_MobileNet
    self.model_savepath = os.path.join (model_savepath, "checkpoint.pth.tar")
    self.alpha = alpha
    self.beta = beta
    # Initialize device; .to() works for any device, not only "cuda:0"
    self.device = device
    if self.device is not None:
      self.model.to(self.device)

    # Make validation split
    self.trainsize = int(train_split * len(dataset))
    self.valsize = len(dataset) - self.trainsize
    train_dataset, valid_dataset = torch.utils.data.random_split(dataset, [self.trainsize, self.valsize])

    # Define data loader for training and validation
    self.batch_size = batch_size
    self.data_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True,
                                  num_workers=n_workers, pin_memory=pin_memory)
    self.val_data_loader = DataLoader(valid_dataset, batch_size=self.batch_size, shuffle=True,
                                      num_workers=n_workers, pin_memory=pin_memory)

    # Loss specifications and optimizer parameter setting
    # This is defined here so that the underlying model can be changed (i.e. hidden layers)
    head_params = list(self.model.classifier.parameters()) + list(self.model.regressor.parameters())
    self.optimizer = torch.optim.Adam(params=head_params, lr=learning_rate, weight_decay=gamma)
    self.binary_classification_criterion = nn.BCEWithLogitsLoss()
    self.regression_criterion = nn.MSELoss()

    if os.path.exists(self.model_savepath):
      print ("Found saved model at savepath %s" %(self.model_savepath))
      # map_location lets a checkpoint written on one device be resumed on another
      checkpoint = torch.load(self.model_savepath, map_location=self.device)
      self.model.load_state_dict(checkpoint['model_state_dict'])
      self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
      self.start_epoch = checkpoint['epoch']
    else:
      self.start_epoch = 0


  def describe_training(self):
    '''
    Print the number of training and validation examples.
    '''
    print (self.trainsize, " training examples")
    print (self.valsize, " validation examples")

  def train(self, num_epochs, val_interval=1, batch_report=50, batch_lookback=10):
    '''
    Main function to train.
    @param num_epochs(int): Number of epochs of training
    @param val_interval(int): Interval epochs between validation metric
    @param batch_report(int): Interval batches between training reports
    @param batch_lookback(int): Number of batches to use for averaging metrics in printing
    @return: (model, training metrics dict, validation metrics dict); each dict maps
             "Loss", "Acc", "IoU" and "RMSE" to a list with one 0-dim tensor per epoch
    '''
    num_tr_batches = len(self.data_loader)
    epoch_loss, epoch_acc, epoch_rmse, epoch_iou = [], [], [], []
    val_epoch_loss, val_epoch_acc, val_epoch_rmse, val_epoch_iou = [], [], [], []
    # Tracks the number of completed epochs so the final checkpoint is well defined
    # even when the loop below does not run (start_epoch >= num_epochs).
    epochs_done = self.start_epoch
    print ("Starting at epoch %d"%self.start_epoch)
    self.model.train()
    for epoch in range(self.start_epoch, num_epochs):
      epoch_start = time.time()
      print ("=" * 50)
      print ("EPOCH ", epoch)
      batch_loss, batch_acc, batch_rmse, batch_iou = [], [], [], []
      for batch_count, (batchx, batchy) in enumerate(self.data_loader, start=1):
          batchx, box_labels, is_tree_labels = self._move_batch(batchx, batchy)

          # Forward / backward pass; gradients must be cleared first or they
          # accumulate from one batch to the next.
          self.optimizer.zero_grad()
          is_tree_preds, box_preds = self.model(batchx)
          loss = self._loss_specification(is_tree_labels, is_tree_preds, box_labels, box_preds)
          loss.backward()
          self.optimizer.step()

          # Metrics are stored as plain floats so no autograd graph is kept alive
          acc, box_rmse, avg_box_iou = self._batch_metrics(is_tree_labels, is_tree_preds, box_labels, box_preds)
          batch_loss.append(loss.item())
          batch_acc.append(acc)
          batch_rmse.append(box_rmse)
          batch_iou.append(avg_box_iou)

          if batch_count % batch_report == 0 or batch_count == num_tr_batches:
            print ("\nLast %d Batch Avg Metrics, Batch %d/%d" %(batch_lookback, batch_count, num_tr_batches))
            print ("Total Loss: {:.3f}".format(self._mean(batch_loss[-batch_lookback:])))
            print ("Classification Acc: {:.3f}".format(self._mean(batch_acc[-batch_lookback:])))
            print ("BBox RMSE: {:.3f}".format(self._mean(batch_rmse[-batch_lookback:])))
            print ("Avg Bbox IoU: {:.3f} \n".format(self._mean(batch_iou[-batch_lookback:])))
            # NOTE(review): the checkpoint records epoch + 1 even in the middle of an
            # epoch, so resuming from it starts at the next epoch -- confirm intended.
            self._save_checkpoint(epoch + 1)
            print ("Checkpoint created")

      if epoch % val_interval == 0:
          print ("VALIDATION EPOCH ", epoch)
          losses, class_accs, box_rmse, avg_box_iou = self._validate()
          val_epoch_loss.append(losses)
          val_epoch_acc.append(class_accs)
          val_epoch_rmse.append(box_rmse)
          val_epoch_iou.append(avg_box_iou)

          # We can change this to be epoch wise or not averaged over all batches
          print ("Batch Average Val Loss: {:.3f}".format(losses))
          print ("Batch Avg Val Classification Acc: {:.3f}".format(class_accs))
          print ("Batch Avg Val BBox RMSE: {:.3f}".format(box_rmse))
          print ("Batch Avg Avg Bbox IoU: {:.3f} \n".format(avg_box_iou))
      epoch_loss.append(self._mean(batch_loss))
      epoch_acc.append(self._mean(batch_acc))
      epoch_iou.append(self._mean(batch_iou))
      epoch_rmse.append(self._mean(batch_rmse))
      epochs_done = epoch + 1

      print ("Epoch ", epoch + 1, " finished in ", time.time() - epoch_start)
    tr_metric_dict = {"Loss": epoch_loss, "Acc": epoch_acc, "IoU": epoch_iou, "RMSE": epoch_rmse}
    val_metric_dict = {"Loss": val_epoch_loss, "Acc": val_epoch_acc, "IoU": val_epoch_iou, "RMSE": val_epoch_rmse}
    self._save_checkpoint(epochs_done, tr_metric_dict=tr_metric_dict, val_metric_dict=val_metric_dict)
    print ("Final checkpoint created. Model dict and metrics saved. ")
    return self.model, tr_metric_dict, val_metric_dict


  def _validate(self):
    '''
    Run the validation loader in eval mode without gradients.
    @return: per-batch averages (loss, accuracy, box RMSE, box IoU) as 0-dim tensors
    '''
    losses, class_accs, rmses, ious = [], [], [], []
    self.model.eval()
    with torch.no_grad():
      for batchx, batchy in self.val_data_loader:
        batchx, box_labels, is_tree_labels = self._move_batch(batchx, batchy)
        is_tree_preds, box_preds = self.model(batchx)
        losses.append(self._loss_specification(is_tree_labels, is_tree_preds, box_labels, box_preds).item())
        acc, box_rmse, avg_box_iou = self._batch_metrics(is_tree_labels, is_tree_preds, box_labels, box_preds)
        class_accs.append(acc)
        rmses.append(box_rmse)
        ious.append(avg_box_iou)
    self.model.train()
    return self._mean(losses), self._mean(class_accs), self._mean(rmses), self._mean(ious)

  def _move_batch(self, batchx, batchy):
    '''
    Move the images and the targets used for training to the trainer's device.
    @return: (images, box labels, is_tree labels)
    '''
    box_labels = batchy["boxes"]
    is_tree_labels = batchy["is_tree"]
    if self.device is not None:
      batchx = batchx.to(self.device, non_blocking=True)
      box_labels = box_labels.to(self.device, non_blocking=True)
      is_tree_labels = is_tree_labels.to(self.device, non_blocking=True)
    return batchx, box_labels, is_tree_labels

  def _binary_accuracy(self, is_tree_labels, is_tree_preds):
    '''
    Fraction of examples whose predicted class matches the label.
    The predictions are logits, so logit > 0 is the same as sigmoid probability > 0.5.
    Both tensors are flattened first so (B, 1) predictions line up with (B,) labels
    instead of broadcasting to a (B, B) comparison.
    '''
    predicted_tree = is_tree_preds.detach().reshape(-1) > 0
    actual_tree = is_tree_labels.reshape(-1) > 0.5
    return (predicted_tree == actual_tree).float().mean().item()

  def _batch_metrics(self, is_tree_labels, is_tree_preds, box_labels, box_preds):
    '''
    Compute (accuracy, box RMSE, mean box IoU) for one batch as Python floats.
    '''
    with torch.no_grad():
      acc = self._binary_accuracy(is_tree_labels, is_tree_preds)
      box_rmse = float(rmse(box_preds, box_labels))
      per_example_iou = [float(iou(box_labels[i, :], box_preds[i, :])) for i in range(box_labels.size()[0])]
      avg_box_iou = self._mean(per_example_iou).item()
    return acc, box_rmse, avg_box_iou

  @staticmethod
  def _mean(values):
    '''
    Mean of a list of numbers as a 0-dim float32 tensor (nan for an empty list).
    '''
    return torch.mean(torch.as_tensor(values, dtype=torch.float32))

  def _save_checkpoint(self, epoch, **extra):
    '''
    Write model and optimizer state, the epoch to resume from and any extra entries
    to the checkpoint file.
    '''
    state = {
      'epoch': epoch,
      'model_state_dict': self.model.state_dict(),
      'optimizer_state_dict': self.optimizer.state_dict(),
    }
    state.update(extra)
    torch.save(state, self.model_savepath)

  def _loss_specification(self, is_tree_labels, is_tree_preds, box_labels, box_preds):
    '''
    Weighted sum of the binary classification loss (alpha) and the box regression loss (beta).
    '''
    # Default collation turns Python floats into float64; match the logits' dtype and shape
    binary_targets = is_tree_labels.to(is_tree_preds.dtype).reshape(is_tree_preds.shape)
    binary_detection_error = self.binary_classification_criterion(is_tree_preds, binary_targets) # output, target
    bounding_box_error = self.regression_criterion(box_preds, box_labels)
    return self.alpha * binary_detection_error + self.beta * bounding_box_error

In [None]:
# Use the "spawn" start method for any DataLoader worker processes
torch.multiprocessing.set_start_method("spawn", force=True)

_N_WORKERS = 0
_PIN_MEM = True

mob_model = Customized_MobileNet(pretrained_model=mobilenet)
trainer = ModelTrainer(mob_model,
                       data_set,
                       train_split=0.8,
                       learning_rate=0.002,
                       batch_size = 64,
                       device=_DEVICE,
                       pin_memory=_PIN_MEM,
                       n_workers=_N_WORKERS,
                       model_savepath=model_path)
# describe_training() prints the split sizes itself and returns None,
# so wrapping it in print() only adds a stray "None" line.
trainer.describe_training()


33506  training examples
8377  validation examples
None


### Check model structure and parameters
Only the regressor and classifier heads (the final layers) should require gradients; all other parameters should be frozen. The optimizer should be given only the regressor and classifier parameters.

In [None]:
# Show the parameters of the two heads (weight and bias of each linear layer)
print ("Regressor params")
print ([p for p in mob_model.regressor.parameters()])

print ("Classifier params")
print ([p for p in mob_model.classifier.parameters()])

# Should output just two layers (4 variables total):
# only parameters with requires_grad=True are listed, so a frozen backbone prints nothing here
for param in mob_model.parameters():
  if param.requires_grad:
    print (param.shape)

Regressor params
[Parameter containing:
tensor([[-0.0255,  0.0035, -0.0125,  ..., -0.0158, -0.0085,  0.0181],
        [-0.0224,  0.0162,  0.0235,  ...,  0.0194, -0.0209, -0.0268],
        [ 0.0277, -0.0057, -0.0008,  ...,  0.0081,  0.0156,  0.0271],
        [ 0.0049, -0.0244,  0.0116,  ...,  0.0036,  0.0191,  0.0244]],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0066, -0.0188, -0.0213,  0.0067], device='cuda:0',
       requires_grad=True)]
Classifier params
[Parameter containing:
tensor([[-0.0001,  0.0109, -0.0252,  ..., -0.0123,  0.0037,  0.0106]],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([0.0167], device='cuda:0', requires_grad=True)]
torch.Size([1, 1280])
torch.Size([1])
torch.Size([4, 1280])
torch.Size([4])


In [None]:
# Train up to epoch 10; metrics are printed and a checkpoint is written every 50 batches
trainer.train(num_epochs=10, batch_report=50) 

Starting at epoch 0
EPOCH  0

Last 10 Batch Avg Metrics, Batch 50/524
Total Loss: 37613.637
Classification Acc: 2.925
BBox RMSE: 542.872
Avg Bbox IoU: 0.017 

Checkpoint created





Last 10 Batch Avg Metrics, Batch 100/524
Total Loss: 30223.746
Classification Acc: 0.105
BBox RMSE: 487.777
Avg Bbox IoU: 0.114 

Checkpoint created
