<a href="https://github.com/" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<!-- href="https://colab.research.google.com/github/cnielly/prototypical-networks-omniglot/blob/master/prototypical_networks_pytorch_omniglot.ipynb" -->

# Few-shot learning via multi-layer feature fusion and relative entropy for maize crops insect classification

## Import Packages

In [21]:
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
from scipy import ndimage
import multiprocessing as mp

import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

# Check GPU support (prints True when CUDA is available); please enable a GPU runtime before running
print(torch.cuda.is_available())

## Import the dataset

In [22]:
!gdown --id 1DkabY6V-2WGkgAiXRmH4DtEE5Y4ZrHiH

In [23]:
!unzip insect_maize_dataset.zip

## Read the dataset

In [4]:
def read_classes(class_path, class_name, nH):
    """
    Loads every image of one class folder and augments it with rotations.

    Each readable image is resized to (nH, nH) and stored together with its
    90, 180 and 270 degree rotations. Every rotation is labelled as a class of
    its own: '<class_name>_0', '<class_name>_90', '<class_name>_180' and
    '<class_name>_270'. Files that OpenCV cannot decode are reported and skipped.

    Args:
        class_path (str): directory holding the images of the class
        class_name (str): label prefix used for the produced samples
        nH (int): side length (in pixels) the images are resized to
    Returns:
        (np.array, np.array): the images and their labels
    """
    datax = []
    datay = []
    for img in os.listdir(class_path):
        img_path = os.path.join(class_path, img)
        image = cv2.imread(img_path)
        if image is None:
            # cv2.imread returns None instead of raising when a file is not a
            # decodable image; resizing it would crash the whole loading step
            print('Skipping unreadable file: ' + img_path)
            continue
        image = cv2.resize(image, (nH, nH))

        # Rotate images to create new classes
        for angle in (0, 90, 180, 270):
            datax.append(image if angle == 0 else ndimage.rotate(image, angle))
            datay.append(class_name + '_' + str(angle))

    return np.array(datax), np.array(datay)

In [5]:
def read_images(base_directory, nH):
    """
    Reads all classes from the base_directory.

    Every sub-directory of base_directory is treated as one class and passed to
    read_classes inside a pool of worker processes, so that the classes are
    loaded in parallel.

    Args:
        base_directory (str): directory containing one sub-directory per class
        nH (int): side length (in pixels) the images are resized to
    Returns:
        (np.array, np.array): images and labels of all classes;
        (None, None) when base_directory holds no class directory
    """
    # Plain files lying next to the class folders are not classes
    class_dirs = [entry for entry in os.listdir(base_directory)
                  if os.path.isdir(os.path.join(base_directory, entry))]
    if not class_dirs:
        return None, None

    tasks = [(os.path.join(base_directory, directory), directory, nH)
             for directory in class_dirs]
    # starmap hands all classes to the workers at once (pool.apply would block
    # on every single call and load the classes one after another)
    with mp.Pool(mp.cpu_count()) as pool:
        results = pool.starmap(read_classes, tasks)

    # Empty classes come back as arrays without image axes and cannot be stacked
    filled = [result for result in results if len(result[0]) > 0]
    if not filled:
        return np.array([]), np.array([])

    # One concatenation instead of growing the arrays class by class
    datax = np.concatenate([result[0] for result in filled])
    datay = np.concatenate([result[1] for result in filled])
    return datax, datay

**Define image size**

In [6]:
# Side length (in pixels) every image is resized to when the dataset is read
nH = 96 # image.shape = [nH,nH,3]

**Read data**

In [7]:
# Root folders of the two dataset splits; each one is read with read_images,
# which expects one sub-directory per class
path_train = '/content/images/source'
path_test = '/content/images/target'

In [25]:
%%time 
trainx, trainy = read_images(path_train,nH)

In [26]:
%%time 
testx, testy = read_images(path_test, nH)

In [27]:
trainx.shape, trainy.shape, testx.shape, testy.shape

## Create samples

In [11]:
def extract_sample(n_way, n_support, n_query, datax, datay):
  """
  Picks random sample of size n_support+n_query, for n_way classes
  Args:
      n_way (int): number of classes in a classification task
      n_support (int): number of labeled examples per class in the support set
      n_query (int): number of labeled examples per class in the query set
      datax (np.array): dataset of images
      datay (np.array): dataset of labels
  Returns:
      (dict) of:
        'images' (torch.Tensor): sample of images. Size (n_way, n_support+n_query, (dim))
        'n_way' (int): n_way
        'n_support' (int): n_support
        'n_query' (int): n_query
  Raises:
      ValueError: if a drawn class holds fewer than n_support+n_query examples
  """
  n_examples = n_support + n_query
  sample = []
  K = np.random.choice(np.unique(datay), n_way, replace=False)
  for cls in K:
    datax_cls = datax[datay == cls]
    if len(datax_cls) < n_examples:
      # Without this check the sample would silently come out too short (or
      # ragged) and only fail later, inside the model
      raise ValueError(
          "class {!r} has {:d} examples, but n_support + n_query = {:d} are required".format(
              cls, len(datax_cls), n_examples))
    # Shuffle indices rather than the images themselves: same random pick,
    # without copying every image of the class
    picked = np.random.permutation(len(datax_cls))[:n_examples]
    sample.append(datax_cls[picked])
  sample = torch.from_numpy(np.array(sample)).float()
  # Move the trailing axis (channels, assuming channels-last images) in front
  # of the two spatial axes
  sample = sample.permute(0,1,4,2,3)
  return({
      'images': sample,
      'n_way': n_way,
      'n_support': n_support,
      'n_query': n_query
      })

In [53]:
def display_sample(sample):
  """
  Plots the images of a sample as one grid
  Args:
      sample (torch.Tensor): 5D batch of images; the first two axes are merged
          and each grid row holds sample.shape[1] images
  """
  n_groups, per_group = sample.shape[0], sample.shape[1]
  # make_grid works on a 4D batch, so fold the two leading axes into one
  flat_batch = sample.view(n_groups * per_group, *sample.shape[2:])
  grid = torchvision.utils.make_grid(flat_batch, nrow=per_group)
  plt.figure(figsize=(16, 7))
  # imshow expects the channel axis last
  plt.imshow(grid.permute(1, 2, 0))

Display a sample

*   n_way = 5
*   n_support = 4
*   n_query = 3

In [57]:
# Draw one episode: 5 classes, 4 support and 3 query images per class
sample_example = extract_sample(5, 4, 3, trainx, trainy)
# display_sample(sample_example['images'])
# F.normalize divides by the p=10 norm taken over the class axis, so every value
# ends up with magnitude <= 1 (presumably to suit imshow's float range)
display_sample(F.normalize(sample_example['images'], p=10, dim=0))
sample_example['images'].shape # [N, k+q, channels, width, height]

## Build model

### Model

In [14]:
class Flatten(nn.Module):
  """Layer that folds every axis after the first one into a single axis."""

  def __init__(self):
    super().__init__()

  def forward(self, x):
    # keep the leading (batch) axis, merge everything else
    batch_size = x.size(0)
    return x.view(batch_size, -1)

def load_protonet_conv(**kwargs):
  """
  Builds the convolutional encoder and wraps it in a ProtoNet
  Arg:
      x_dim (tuple): dimension of input image; only x_dim[0] (the channel count) is used
      hid_dim (tuple): output channels of the first three conv blocks
      z_dim (int): output channels of the fourth (last) conv block
  Returns:
      Model (Class ProtoNet)
  """
  x_dim = kwargs['x_dim']
  hid_dim = kwargs['hid_dim']
  z_dim = kwargs['z_dim']

  def conv_block(in_channels, out_channels):
    # conv -> batch norm -> ReLU -> dropout -> max pooling
    layers = [
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.MaxPool2d(3),
    ]
    return nn.Sequential(*layers)

  # Channel count before and after each of the four blocks
  channels = (x_dim[0], hid_dim[0], hid_dim[1], hid_dim[2], z_dim)
  blocks = [conv_block(n_in, n_out) for n_in, n_out in zip(channels[:-1], channels[1:])]
  encoder = nn.Sequential(*blocks, Flatten())

  return ProtoNet(encoder)

### Encoder

In [15]:
class ProtoNet(nn.Module):
  """
  Prototypical network: embeds support and query images with an encoder and
  classifies each query by its squared euclidean distance to the class
  prototypes (the mean support embedding of each class).
  """
  def __init__(self, encoder):
    """
    Args:
        encoder : CNN encoding the images in sample
    """
    super(ProtoNet, self).__init__()
    # Disabled multi-layer feature branches (truncated encoder + extra pooling):
    # self.encoder1 = nn.Sequential(*list(encoder.children())[:1],
    #                     nn.MaxPool2d(5),
    #                     nn.MaxPool2d(5),
    #                     Flatten()
    #                     ).cuda()
    # self.encoder2 = nn.Sequential(*list(encoder.children())[:2],
    #                     nn.MaxPool2d(3),
    #                     nn.MaxPool2d(3),
    #                     Flatten()
    #                     ).cuda()
    # self.encoder3 = nn.Sequential(*list(encoder.children())[:3],
    #                     nn.MaxPool2d(3),
    #                     Flatten()
    #                     ).cuda()
    # Use the GPU when one is present, otherwise stay on the CPU instead of crashing
    self.encoder = encoder.cuda() if torch.cuda.is_available() else encoder

  def _device(self):
    """Returns the device of the encoder weights (CPU for a parameter-free encoder)."""
    first_param = next(self.encoder.parameters(), None)
    if first_param is None:
      return torch.device('cpu')
    return first_param.device

  def set_forward_loss(self, sample):
    """
    Computes loss, accuracy and output for classification task
    Args:
        sample (dict): with keys
            'images' (torch.Tensor): shape (n_way, n_support+n_query, (dim))
            'n_way' (int), 'n_support' (int), 'n_query' (int)
    Returns:
        (torch.Tensor, dict): the scalar loss to backpropagate, and a dict with
        'loss' (float), 'acc' (float) and 'y_hat' (torch.Tensor of shape
        (n_way, n_query) holding the predicted class indices)
    """
    device = self._device()
    sample_images = sample['images'].to(device)
    n_way = sample['n_way']
    n_support = sample['n_support']
    n_query = sample['n_query']

    x_support = sample_images[:, :n_support]
    x_query = sample_images[:, n_support:]

    # Target indices are 0 ... n_way-1
    target_inds = torch.arange(0, n_way, device=device).view(n_way, 1, 1).expand(n_way, n_query, 1)

    # Encode images of the support and the query set
    x = torch.cat([x_support.contiguous().view(n_way * n_support, *x_support.size()[2:]),
                   x_query.contiguous().view(n_way * n_query, *x_query.size()[2:])], 0)

    ##
    # f1 = self.encoder1.forward(x)
    # f2 = self.encoder2.forward(x)
    # f3 = self.encoder3.forward(x)
    f4 = self.encoder(x)

    z = f4
    # z = (f1 + f4) / 2
    # z = (f3 + f4) / 2
    # z = (f1 + f2 + f4) / 3
    # z = (f2 + f3 + f4) / 3
    # z = (f1 + f2 + f3 + f4) / 4
    ##

    z_dim = z.size(-1)

    # Prototypes: mean support embedding of every class
    z_proto = z[:n_way*n_support].view(n_way, n_support, z_dim).mean(1)

    # Query embeddings
    z_query = z[n_way*n_support:]

    # Compute Squared Euclidean distance
    dists = euclidean_dist(z_query, z_proto)

    # Compute probabilities, loss, y_hat, and accuracy
    log_p_y = F.log_softmax(-dists, dim=1).view(n_way, n_query, -1)
    loss_val = -log_p_y.gather(2, target_inds).view(-1).mean()
    _, y_hat = log_p_y.max(2)
    # Drop only the trailing axis: a bare squeeze() would also remove an n_way or
    # n_query axis of size 1, and eq would then broadcast to a wrong accuracy
    acc_val = torch.eq(y_hat, target_inds.squeeze(2)).float().mean()

    return loss_val, {
        'loss': loss_val.item(),
        'acc': acc_val.item(),
        'y_hat': y_hat
        }

## Compute Squared Euclidean distance

In [16]:
def euclidean_dist(x, y):
  """
  Computes the squared euclidean distance between every row of x and every row of y
  Args:
      x (torch.Tensor): shape (n, d). n usually n_way*n_query
      y (torch.Tensor): shape (m, d). m usually n_way
  Returns:
      torch.Tensor: shape(n, m). Entry (i, j) is the squared distance between x[i] and y[j]
  """
  n_rows, dim = x.size(0), x.size(1)
  n_cols = y.size(0)
  assert dim == y.size(1)

  # Repeat both inputs along a new axis so that they line up as (n, m, d)
  lhs = x[:, None, :].expand(n_rows, n_cols, dim)
  rhs = y[None, :, :].expand(n_rows, n_cols, dim)

  difference = lhs - rhs
  return difference.pow(2).sum(dim=2)

## Training function

In [17]:
from tqdm import tqdm_notebook
from tqdm import tnrange

def train(model, optimizer, train_x, train_y, n_way, n_support, n_query, max_epoch, epoch_size):
  """
  Runs episodic training of the protonet and prints loss/accuracy per epoch
  Args:
      model: network exposing set_forward_loss(sample)
      optimizer: optimizer updating the model parameters
      train_x (np.array): images of training set
      train_y(np.array): labels of training set
      n_way (int): number of classes in a classification task
      n_support (int): number of labeled examples per class in the support set
      n_query (int): number of labeled examples per class in the query set
      max_epoch (int): max epochs to train on
      epoch_size (int): episodes per epoch
  """
  # divide the learning rate by 2 at each epoch, as suggested in paper
  scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.5, last_epoch=-1)

  for epoch in range(max_epoch):
    loss_sum = 0.0
    acc_sum = 0.0

    progress = tnrange(epoch_size, desc="Epoch {:d} train".format(epoch+1))
    for episode in progress:
      # every episode is a freshly drawn few-shot task
      batch = extract_sample(n_way, n_support, n_query, train_x, train_y)
      optimizer.zero_grad()
      loss, output = model.set_forward_loss(batch)
      loss.backward()
      optimizer.step()

      loss_sum += output['loss']
      acc_sum += output['acc']

    epoch_loss = loss_sum / epoch_size
    epoch_acc = acc_sum / epoch_size
    print('Epoch {:d} -- Loss: {:.4f} Acc: {:.4f}'.format(epoch+1,epoch_loss, epoch_acc))

    scheduler.step()

## Train and Test

In [18]:
def test(model, test_x, test_y, n_way, n_support, n_query, test_episode):
  running_loss = 0.0
  running_acc = 0.0
  for episode in tnrange(test_episode):
    sample = extract_sample(n_way, n_support, n_query, test_x, test_y)
    loss, output = model.set_forward_loss(sample)
    running_loss += output['loss']
    running_acc += output['acc']
  avg_loss = running_loss / test_episode
  avg_acc = running_acc / test_episode
  print('Test results -- Loss: {:.4f} Acc: {:.4f}'.format(avg_loss, avg_acc))

In [20]:
%%time
# Arrays loaded by read_images: source split for training, target split for testing
train_x = trainx
train_y = trainy
test_x = testx
test_y = testy

# Training runs max_epoch epochs of epoch_size episodes each;
# testing averages over test_episode episodes
max_epoch = 10
epoch_size = 2000
test_episode = 10000

# Four conv blocks with 64 output channels each, fed with 3-channel nH x nH images
model = load_protonet_conv(
    x_dim=(3,nH,nH),
    hid_dim=(64,64,64),
    z_dim=64
    )

optimizer = optim.Adam(model.parameters(), lr = 0.001)

# Episode layout: n_way classes while training, N_test classes while testing,
# n_support support and n_query query images per class in both phases
n_way = 3
N_test = 3
n_support = 5
n_query = 5

print('\nTRAINING...')
train(model, optimizer, train_x, train_y, n_way, n_support, n_query, max_epoch, epoch_size)

print('\nTESTING...')
test(model, test_x, test_y, N_test, n_support, n_query, test_episode)