In [None]:
# Colab environment setup: install third-party packages with pip.
!pip install hdf5storage
# Pinned CUDA 11.1 builds of torch / torchvision / torchaudio from the PyTorch wheel index.
!pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
# NOTE(review): this line points at a torch 1.7.0 nightly wheel built for cp36 / cu101,
# which does not match the torch==1.8.0+cu111 pin above - confirm whether it is still needed.
!pip install --pre torch -f  https://download.pytorch.org/whl/nightly/cu101/torch-1.7.0.dev20200626%2Bcu101-cp36-cp36m-linux_x86_64.whl
# tensorboardX provides the SummaryWriter imported in a later cell.
!pip install tensorboardX

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting hdf5storage
  Downloading hdf5storage-0.1.18-py2.py3-none-any.whl (53 kB)
[K     |████████████████████████████████| 53 kB 399 kB/s 
Installing collected packages: hdf5storage
Successfully installed hdf5storage-0.1.18
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://download.pytorch.org/whl/torch_stable.html
Collecting torch==1.8.0+cu111
  Downloading https://download.pytorch.org/whl/cu111/torch-1.8.0%2Bcu111-cp37-cp37m-linux_x86_64.whl (1982.2 MB)
[K     |█████████████▌                  | 834.1 MB 1.2 MB/s eta 0:15:22tcmalloc: large alloc 1147494400 bytes == 0x3935a000 @  0x7f269407d615 0x58ead6 0x4f355e 0x4d222f 0x51041f 0x5b4ee6 0x58ff2e 0x510325 0x5b4ee6 0x58ff2e 0x50d482 0x4d00fb 0x50cb8d 0x4d00fb 0x50cb8d 0x4d00fb 0x50cb8d 0x4bac0a 0x538a76 0x590ae5 0x510280 0x5b4ee6 0x58ff2e 0x50d482 0x5b4ee6 0x

In [None]:
# Mount Google Drive at /content/drive (the google.colab module is only available on Colab).
from google.colab import drive
drive.mount('/content/drive')
import os
# Switch the working directory to the Drive root; relative paths used later resolve from here.
os.chdir("/content/drive/My Drive")

Mounted at /content/drive


In [None]:
## 
import os
import random
from scipy import spatial
import networkx as nx

import torch
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import cv2
import glob
import scipy.io as io
from matplotlib import pyplot as plt
plt.switch_backend('agg')

class SHHA(Dataset):
    """Dataset of images with point annotations, indexed by comma-separated list files.

    ``data_root`` must contain ``soybean_seed_counting_a.txt`` (read when
    ``train=True``) and ``soybean_seed_counting_b.txt`` (read otherwise). Every
    line of the selected file is split on ``,``; field 0 is used as the image
    path and field 1 as the annotation path.

    Args:
        data_root: directory that holds the two list files.
        transform: optional callable applied to the loaded PIL image.
        train: use the training list and enable the random augmentations.
        patch: when training, replace the image by random 224x224 crops.
        flip: when training, allow the random horizontal flip.
    """

    def __init__(self, data_root, transform=None, train=False, patch=False, flip=False):
        self.root_path = data_root
        self.train_lists = os.path.join(self.root_path, "soybean_seed_counting_a.txt")
        self.eval_list = os.path.join(self.root_path, "soybean_seed_counting_b.txt")

        # Read the list through a context manager so the file handle is closed.
        list_path = self.train_lists if train else self.eval_list
        with open(list_path) as list_file:
            self.img_list_file = [name.split(',') for name in list_file.read().splitlines()]

        self.img_list = self.img_list_file
        self.nSamples = len(self.img_list)

        self.transform = transform
        self.train = train
        self.patch = patch
        self.flip = flip

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        """Load one sample and return ``(img, target)``.

        ``img`` is a float tensor. ``target`` is a list of dicts with the keys
        ``'point'``, ``'image_id_1'``, ``'image_id_2'`` and ``'labels'``: one
        dict per crop when training with ``patch=True``, a single dict in
        evaluation mode.

        Raises:
            IndexError: if ``index >= len(self)``.
        """
        # Valid positions are 0 .. len(self) - 1. The previous `assert index <= len(self)`
        # let len(self) through, and raised AssertionError instead of IndexError for
        # larger values, which breaks plain `for sample in dataset` iteration.
        if index >= len(self):
            raise IndexError('index range error')

        img_path = self.img_list[index][0]
        gt_path = self.img_list[index][1]
        img, point = load_data((img_path, gt_path), self.train)
        if self.transform is not None:
            img = self.transform(img)

        if self.train:
            # data augmentation -> random scale
            scale_range = [0.5, 1.4]
            min_size = min(img.shape[1:])
            scale = random.uniform(*scale_range)
            # only rescale when the shorter side stays above the 224 pixel crop size
            if scale * min_size > 224:
                # Same computation as the deprecated upsample_bilinear().
                img = torch.nn.functional.interpolate(
                    img.unsqueeze(0), scale_factor=scale, mode='bilinear', align_corners=True
                ).squeeze(0)
                point *= scale

        # random crop augmentation: img becomes a stack of patches, point a list per patch
        if self.train and self.patch:
            img, point = random_crop(img, point)
            point = [torch.Tensor(patch_points) for patch_points in point]

        # Horizontal flip, applied with probability 0.9 when enabled.
        # The 4-index slice expects the stacked-patch array produced by random_crop.
        if random.random() > 0.1 and self.train and self.flip:
            img = torch.Tensor(img[:, :, :, ::-1].copy())
            for patch_points in point:
                # mirror x inside the 224 pixel wide patch
                patch_points[:, 0] = 224 - patch_points[:, 0]

        # Brightness jitter, applied with probability 0.7 while training:
        # scale all pixel values by a factor drawn from [0.8, 1.2]. Points are unaffected.
        if random.random() > 0.3 and self.train:
            img = (torch.Tensor(img).clone())*random.uniform(8,12)/10

        # NOTE(review): with train=True and patch=False `point` stays a 2-D array, so the
        # loop below builds one target per annotated point - confirm that mode is unused.
        if not self.train:
            point = [point]

        img = torch.Tensor(img)

        target = []
        if len(point) > 0:
            # The ids are parsed from the file name; adapt this to your own naming scheme.
            name_fields = img_path.split('/')[-1].split('.')[0].split("_")
            image_id_1 = int(name_fields[1][4:8])
            image_id_2 = int(name_fields[3])
        for patch_points in point:
            target.append({
                'point': torch.Tensor(patch_points),
                'image_id_1': torch.tensor([image_id_1], dtype=torch.long),
                'image_id_2': torch.tensor([image_id_2], dtype=torch.long),
                # every annotated point belongs to class 1
                'labels': torch.ones([patch_points.shape[0]]).long(),
            })

        return img, target


def load_data(img_gt_path, train):
    """Read one image and its point annotations from disk.

    Args:
        img_gt_path: ``(img_path, gt_path)`` pair. The annotation file holds one
            point per line, written as a Python expression whose first two items
            are x and y.
        train: unused here; kept so existing callers keep working.

    Returns:
        ``(img, points)`` where ``img`` is an RGB ``PIL.Image`` and ``points`` is
        a float array of shape ``(num_points, 2)`` - ``(0, 2)`` when the
        annotation file has no points.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img_path``.
    """
    img_path, gt_path = img_gt_path

    # cv2.imread signals failure by returning None instead of raising.
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError("could not read image: {}".format(img_path))
    # OpenCV loads BGR; convert to RGB before wrapping in a PIL image.
    img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    points = []
    with open(gt_path) as gt_file:
        for line in gt_file.read().splitlines():
            if not line.strip():
                # tolerate blank / trailing lines
                continue
            # NOTE(review): eval() executes whatever the annotation file contains.
            # Only use trusted files, or switch to ast.literal_eval if every line
            # is a plain literal such as "(12.5, 40.0)".
            pt = eval(line)
            points.append([float(pt[0]), float(pt[1])])

    # reshape keeps the array 2-D even without points, so column indexing still works
    return img, np.array(points, dtype=float).reshape(-1, 2)

# random crop augumentation
def random_crop(img, den, num_patch=10, patch_size=224):
    """Cut ``num_patch`` random square patches out of an image and its points.

    Args:
        img: image indexed as ``[channel, row, column]``.
        den: array of shape ``(num_points, 2)`` holding ``(x, y)`` coordinates.
        num_patch: number of patches to draw.
        patch_size: side length of each patch in pixels (default 224).

    Returns:
        ``(patches, points)``: ``patches`` is a NumPy array of shape
        ``(num_patch, channels, patch_size, patch_size)``; ``points`` is a list
        with one array per patch, holding the points that fall inside it,
        expressed relative to the patch's top-left corner.

    Raises:
        ValueError: if the image is smaller than ``patch_size`` in either direction.
    """
    channels, img_h, img_w = img.shape[0], img.shape[1], img.shape[2]
    if num_patch > 0 and (img_h < patch_size or img_w < patch_size):
        raise ValueError(
            "image of size {}x{} is smaller than the {}x{} crop".format(
                img_h, img_w, patch_size, patch_size))

    result_img = np.zeros([num_patch, channels, patch_size, patch_size])
    result_den = []
    for i in range(num_patch):
        # top-left corner of the patch; randint is inclusive on both ends
        start_h = random.randint(0, img_h - patch_size)
        start_w = random.randint(0, img_w - patch_size)
        end_h = start_h + patch_size
        end_w = start_w + patch_size

        result_img[i] = img[:, start_h:end_h, start_w:end_w]

        # keep the points inside the patch (bounds inclusive on both sides)
        inside = (den[:, 0] >= start_w) & (den[:, 0] <= end_w) & (den[:, 1] >= start_h) & (den[:, 1] <= end_h)
        # boolean indexing returns a copy, so the shift below leaves `den` untouched
        record_den = den[inside]
        record_den[:, 0] -= start_w
        record_den[:, 1] -= start_h

        result_den.append(record_den)

    return result_img, result_den

In [None]:
# 
import torchvision.transforms as standard_transforms

# 
class DeNormalize(object):
    """Reverse a per-channel ``(x - mean) / std`` normalisation, in place."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Multiply each leading-dimension slice by its std, add its mean, return ``tensor``."""
        # Iterating a tensor yields views, so the in-place ops write through to `tensor`.
        for channel, offset, scale in zip(tensor, self.mean, self.std):
            channel.mul_(scale)
            channel.add_(offset)
        return tensor

def loading_data(data_root):
    """Build the training and validation ``SHHA`` datasets stored under ``data_root``.

    Both datasets share one preprocessing pipeline: conversion to a tensor
    followed by per-channel normalisation. The training set is created with
    random cropping and flipping enabled.

    Returns:
        ``(train_set, val_set)``
    """
    preprocess = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225]),
    ])
    return (
        SHHA(data_root, train=True, transform=preprocess, patch=True, flip=True),
        SHHA(data_root, train=False, transform=preprocess),
    )

In [None]:
# 
import math
import os
import sys
from typing import Iterable
import torch
import numpy as np
import time
import torchvision.transforms as standard_transforms
import cv2

class DeNormalize(object):
    """Reverse a per-channel ``(x - mean) / std`` normalisation, in place."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Multiply each leading-dimension slice by its std, add its mean, return ``tensor``."""
        # Iterating a tensor yields views, so the in-place ops write through to `tensor`.
        for channel, offset, scale in zip(tensor, self.mean, self.std):
            channel.mul_(scale)
            channel.add_(offset)
        return tensor
#
def vis(samples, targets, pred, vis_dir, epoch, predict_cnt, gt_cnt):
    '''
    Save one side-by-side figure (ground truth vs. prediction) per image in the batch.

    samples -> tensor: [batch, 3, H, W], normalised with the ImageNet mean / std used below
    targets -> list of dict, one per image, with the keys 'point', 'image_id_1' and 'image_id_2'
    pred -> list with one entry per image, each a list of [x, y] predictions
    vis_dir -> directory the JPEG files are written to
    epoch -> value used as the first field of the output file name
    predict_cnt, gt_cnt -> counts shown in the figure title
    '''
    gts = [t['point'].tolist() for t in targets]

    restore_transform = standard_transforms.Compose([
        DeNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        standard_transforms.ToPILImage()
    ])

    size = 5  # marker radius in pixels
    for idx in range(samples.shape[0]):
        # DeNormalize works in place: hand it a copy so the caller's batch is not modified.
        restored = restore_transform(samples[idx].detach().clone())
        # Read the PIL image straight into an H x W x 3 uint8 array (no float round trip).
        canvas = np.array(restored.convert('RGB'))
        sample_gt = canvas.copy()
        sample_pred = canvas.copy()

        # draw gt
        for t in gts[idx]:
            sample_gt = cv2.circle(sample_gt, (int(t[0]), int(t[1])), size, (0, 255, 0), -1)
        # draw predictions
        for p in pred[idx]:
            sample_pred = cv2.circle(sample_pred, (int(p[0]), int(p[1])), size, (0, 0, 255), -1)

        name_1 = targets[idx]['image_id_1']
        name_2 = targets[idx]['image_id_2']

        # left panel: ground truth, right panel: predictions
        fig = plt.figure()
        for position, panel in enumerate((sample_gt, sample_pred), start=1):
            ax = fig.add_subplot(1, 2, position)
            ax.imshow(panel)
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
        fig.suptitle('manual count=%4.2f, inferred count=%4.2f'%(gt_cnt, predict_cnt), fontsize=10)
        fig.tight_layout(rect=[0, 0, 0.95, 0.95])
        fig.savefig(os.path.join(vis_dir, '{}_{}_id_{}_ind_{}.jpg'.format(epoch, idx, int(name_1), int(name_2))), bbox_inches='tight', dpi = 300)
        # close this figure explicitly so figures do not pile up across calls
        plt.close(fig)
# the training
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0):
    """Run one optimisation pass over ``data_loader``.

    Each batch yields a loss dict from ``criterion``; the entries that have a
    weight in ``criterion.weight_dict`` are summed (weighted) and back-propagated.
    Gradients are clipped to ``max_norm`` when it is positive. ``epoch`` is not
    used inside this function.

    Returns:
        dict mapping every logged meter name to its global average.
    """
    model.train()
    criterion.train()
    logger = MetricLogger(delimiter="  ")
    logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))

    for samples, targets in data_loader:
        samples = samples.to(device)
        targets = [{key: value.to(device) for key, value in tgt.items()} for tgt in targets]

        # forward pass and weighted total loss
        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weights = criterion.weight_dict
        weighted_terms = [loss_dict[name] * weights[name] for name in loss_dict.keys() if name in weights]
        total_loss = sum(weighted_terms)

        # values used for logging only, reduced across processes
        reduced = reduce_dict(loss_dict)
        scaled_log = {name: val * weights[name]
                      for name, val in reduced.items() if name in weights}
        unscaled_log = {f'{name}_unscaled': val for name, val in reduced.items()}
        loss_value = sum(scaled_log.values()).item()

        # abort on NaN / inf instead of continuing with a diverged model
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(reduced)
            sys.exit(1)

        # backward pass and parameter update
        optimizer.zero_grad()
        total_loss.backward()
        if max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        optimizer.step()

        logger.update(loss=loss_value, **scaled_log, **unscaled_log)
        logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    logger.synchronize_between_processes()
    print("Averaged stats:", logger)
    return {name: meter.global_avg for name, meter in logger.meters.items()}

# evaluate the model performance during training
@torch.no_grad()
def evaluate_crowd_no_overlap(model, data_loader, device, epoch, threshold, vis_dir=None):
    """Count objects per image, merging nearby detections, and report MAE / RMSE.

    Only the first element of every batch is evaluated (outputs and targets are
    indexed with ``[0]``), so the loader is expected to yield one image per batch.

    Args:
        model: network returning a dict with ``'pred_logits'`` and ``'pred_points'``.
        data_loader: iterable of ``(samples, targets)`` batches.
        device: device the samples are moved to.
        epoch: forwarded to ``vis`` for the output file names.
        threshold: minimum foreground score for a predicted point to be kept.
        vis_dir: when not None, a visualisation of every image is written there.

    Returns:
        ``(mae, mse)``: mean absolute count error, and the square root of the
        mean squared count error.
    """
    model.eval()
    maes = []
    mses = []
    for samples, targets in data_loader:
        samples = samples.to(device)

        outputs = model(samples)
        # foreground probability (class index 1) of every proposal of the first image
        outputs_scores = torch.nn.functional.softmax(outputs['pred_logits'], -1)[:, :, 1][0]
        outputs_points = outputs['pred_points'][0]

        gt_cnt = targets[0]['point'].shape[0]

        points = outputs_points[outputs_scores > threshold].detach().cpu().numpy()
        num_points = points.shape[0]

        # merge closely located points (skipped for empty or very large point sets)
        if 0 < num_points < 10000:
            # merge radius: 500 / num_points, but never below 20
            cutoff = max(500 / num_points, 20)
            neighbours = spatial.KDTree(points).query_ball_point(points, cutoff)
            # Points within `cutoff` of each other are linked; every point is its own
            # neighbour, so isolated points still form a one-element component.
            components = nx.connected_components(
                nx.from_edgelist((i, j) for i, js in enumerate(neighbours) for j in js)
            )
            # each connected component collapses to the mean of its members
            points64 = points.astype(np.float64)
            points_n = []
            for members in components:
                cluster = points64[list(members)]
                points_n.append([np.mean(cluster[:, 0]), np.mean(cluster[:, 1])])
        else:
            points_n = points.tolist()

        predict_cnt = len(points_n)
        # save the visualized images
        if vis_dir is not None:
            vis(samples, targets, [points_n], vis_dir, epoch, predict_cnt, gt_cnt)
        # accumulate absolute and squared count errors
        error = predict_cnt - gt_cnt
        maes.append(float(abs(error)))
        mses.append(float(error * error))

    mae = np.mean(maes)
    mse = np.sqrt(np.mean(mses))

    return mae, mse

In [None]:
import argparse
import datetime
import random
import time
from pathlib import Path
from IPython.display import clear_output 

import torch
from torch.utils.data import DataLoader, DistributedSampler

import os
from tensorboardX import SummaryWriter
import warnings
warnings.filterwarnings('ignore')


In [None]:
# 
import os
import subprocess
import time
from collections import defaultdict, deque
import datetime
import pickle
from typing import Optional, List

import torch
import torch.distributed as dist
from torch import Tensor

import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
# Compare (major, minor) as integers: float(__version__[:3]) reads "0.10.x" as 0.1
# and would wrongly take the legacy branch (whose imports no longer exist).
if tuple(int(part) for part in torchvision.__version__.split('.')[:2]) < (0, 7):
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size


class SmoothedValue(object):
    """Keep a sliding window of recent values plus a running total and count.

    The window feeds ``median``, ``avg``, ``max`` and ``value``; the running
    total and count feed ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        """Record ``value`` once in the window and ``n`` times in the running totals."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum count and total over all processes.

        Warning: does not synchronize the deque!
        """
        if is_dist_avail_and_initialized():
            stats = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
            dist.barrier()
            dist.all_reduce(stats)
            count, total = stats.tolist()
            self.count = int(count)
            self.total = total

    @property
    def median(self):
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        # most recent entry of the window
        return self.deque[-1]

    def __str__(self):
        stats = {
            'median': self.median,
            'avg': self.avg,
            'global_avg': self.global_avg,
            'max': self.max,
            'value': self.value,
        }
        return self.fmt.format(**stats)
#
def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)

    With a world size of 1 no communication happens and ``[data]`` is returned.
    Otherwise the object is pickled into a byte tensor on the "cuda" device,
    exchanged with ``dist.all_gather`` and unpickled on every rank.

    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    # NOTE(review): local_size is a one-element tensor here, so both the comparison and
    # the padding size below rely on implicit tensor-to-scalar conversion - confirm this
    # still works on the PyTorch version in use.
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        # drop the padding bytes before unpickling
        buffer = tensor.cpu().numpy().tobytes()[:size]
        # pickle.loads runs on bytes received from the other ranks
        data_list.append(pickle.loads(buffer))

    return data_list
#
def reduce_dict(input_dict, average=True):
    """
    Combine the values of ``input_dict`` across all processes.

    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Returns:
        A dict with the same keys holding the reduced values. When fewer than
        two processes are running, ``input_dict`` itself is returned unchanged.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        # sorted keys give every process the same stacking order
        names = sorted(input_dict.keys())
        values = torch.stack([input_dict[name] for name in names], dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = dict(zip(names, values))
    return reduced_dict
#
class MetricLogger(object):
    """Collection of named meters with helpers for updating and printing them."""

    def __init__(self, delimiter="\t"):
        # unknown meter names get a default SmoothedValue on first update
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed one value per keyword into the meter of the same name."""
        for name, val in kwargs.items():
            if isinstance(val, torch.Tensor):
                val = val.item()
            assert isinstance(val, (float, int))
            self.meters[name].update(val)

    def __getattr__(self, attr):
        # only reached when normal attribute lookup fails: meters first, then instance dict
        for namespace in (self.meters, self.__dict__):
            if attr in namespace:
                return namespace[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        return self.delimiter.join(
            "{}: {}".format(name, str(meter)) for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable``, printing progress every ``print_freq`` items.

        A line is also printed for the last item, and a total-time summary once
        the iterable is exhausted. ``iterable`` must support ``len()``.
        """
        header = header or ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        total = len(iterable)
        # pad the running index to the width of the total count
        space_fmt = ':' + str(len(str(total))) + 'd'
        fields = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
        ]
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            fields.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(fields)
        MB = 1024.0 * 1024.0
        for i, obj in enumerate(iterable):
            # time spent waiting for the item
            data_time.update(time.time() - end)
            yield obj
            # waiting time plus the time the consumer spent on the item
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total - 1:
                eta_seconds = iter_time.global_avg * (total - i)
                values = {
                    'eta': str(datetime.timedelta(seconds=int(eta_seconds))),
                    'meters': str(self),
                    'time': str(iter_time),
                    'data': str(data_time),
                }
                if use_cuda:
                    values['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, total, **values))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / total))
#
def get_sha():
    """Describe the git state of the code directory as a one-line string.

    Returns:
        ``"sha: <sha>, status: <clean|has uncommited changes>, branch: <name>"``.
        Fields that cannot be determined (git missing, not a repository, ...)
        keep their defaults ``N/A`` / ``clean``.
    """
    # `__file__` is not defined inside a notebook kernel; use the working directory there.
    module_file = globals().get('__file__')
    cwd = os.path.dirname(os.path.abspath(module_file)) if module_file else os.getcwd()

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()

    sha = 'N/A'
    diff = "clean"
    branch = 'N/A'
    try:
        sha = _run(['git', 'rev-parse', 'HEAD'])
        # output is discarded; the call is kept from the original implementation
        subprocess.check_output(['git', 'diff'], cwd=cwd)
        diff = _run(['git', 'diff-index', 'HEAD'])
        diff = "has uncommited changes" if diff else "clean"
        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
    except Exception:
        # Best effort by design: any git failure leaves the remaining fields at their defaults.
        pass
    return f"sha: {sha}, status: {diff}, branch: {branch}"


def collate_fn(batch):
    """Transpose a list of samples into per-field tuples and batch the first field.

    The first field (the images) is merged into one padded tensor by
    ``nested_tensor_from_tensor_list``; the remaining fields stay tuples.
    """
    columns = list(zip(*batch))
    columns[0] = nested_tensor_from_tensor_list(columns[0])
    return tuple(columns)

def collate_fn_crowd(batch):
    """Flatten samples that may hold several patches each, then batch the images.

    Every sample is ``(imgs, points)``. A 3-D ``imgs`` is treated as a single
    image; otherwise each entry along its first dimension is paired with the
    entry of ``points`` at the same position.
    """
    flat = []
    for imgs, points in batch:
        if imgs.ndim == 3:
            imgs = imgs.unsqueeze(0)
        flat.extend((imgs[i, :, :, :], points[i]) for i in range(len(imgs)))
    columns = list(zip(*flat))
    columns[0] = nested_tensor_from_tensor_list(columns[0])
    return tuple(columns)


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes

def _max_by_axis_pad(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)

    block = 128

    for i in range(2):
        maxes[i+1] = ((maxes[i+1] - 1) // block + 1) * block
    return maxes
#
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """Stack 3-D tensors of possibly different sizes into one zero-padded batch.

    The batch shape is ``[len(tensor_list)]`` followed by the shape returned by
    ``_max_by_axis_pad``. Every image is copied into the top-left corner of its
    slot; the rest of the slot stays zero.

    Raises:
        ValueError: if the first tensor is not 3-dimensional.
    """
    first = tensor_list[0]
    if first.ndim != 3:
        raise ValueError('not supported')

    padded_shape = _max_by_axis_pad([list(img.shape) for img in tensor_list])
    batch_shape = [len(tensor_list)] + padded_shape
    tensor = torch.zeros(batch_shape, dtype=first.dtype, device=first.device)
    for img, slot in zip(tensor_list, tensor):
        # `slot` is a view into `tensor`, so copy_ fills the batch in place
        slot[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    return tensor
#
class NestedTensor(object):
    """Pair of a tensor batch and an optional mask that travel together."""

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new NestedTensor with tensors (and mask, if any) moved to ``device``."""
        moved_tensors = self.tensors.to(device)
        moved_mask = None if self.mask is None else self.mask.to(device)
        return NestedTensor(moved_tensors, moved_mask)

    def decompose(self):
        """Return ``(tensors, mask)``."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)
#
def setup_for_distributed(is_master):
    """
    Replace the built-in ``print`` so that only the master process prints.

    After this call ``print`` is silent unless ``is_master`` is true or the call
    passes ``force=True``. The replacement is installed process-wide.
    """
    import builtins as __builtin__
    original_print = __builtin__.print

    def print(*args, force=False, **kwargs):
        if not (is_master or force):
            return
        original_print(*args, **kwargs)

    __builtin__.print = print
#
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()
#
def get_world_size():
    """Return the number of distributed processes, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
#
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
#
def is_main_process():
    """Return True when this process has rank 0."""
    rank = get_rank()
    return rank == 0
#
def save_on_master(*args, **kwargs):
    """Forward the arguments to ``torch.save`` on the main process; do nothing elsewhere."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
#
def init_distributed_mode(args):
    """Configure torch.distributed from environment variables and record the result on ``args``.

    Sets ``args.distributed`` in every case. When a launcher is detected it also
    sets ``args.rank``, ``args.gpu`` and ``args.dist_backend``, initialises the
    process group and silences ``print`` on non-master ranks.

    ``args.dist_url`` is read but never set here, so the caller must provide it.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # launcher that exports RANK / WORLD_SIZE / LOCAL_RANK
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # NOTE(review): this branch does not set args.world_size, which
        # init_process_group below reads - presumably the caller supplies it; confirm.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    # wait until every process has joined the group
    torch.distributed.barrier()
    # from here on only rank 0 prints (unless print is called with force=True)
    setup_for_distributed(args.rank == 0)
#
@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k

    Args:
        output: score tensor of shape ``(batch, num_classes)``.
        target: tensor of shape ``(batch,)`` with the true class indices.
        topk: the values of k to evaluate.

    Returns:
        A list with one scalar tensor per k holding the percentage of samples
        whose true class is among the k highest scores. For an empty ``target``
        a single zero tensor is returned.
    """
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    # indices of the maxk highest scores per sample, transposed to (maxk, batch)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: the slice of the transposed tensor may be non-contiguous,
        # in which case view(-1) raises for k > 1
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
#
def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes
    on torchvision releases older than 0.7.

    On torchvision >= 0.7 the call is forwarded to ``torchvision.ops.misc.interpolate``
    when that attribute exists, and to ``torch.nn.functional.interpolate`` otherwise.
    """
    # Compare (major, minor) as integers: float(__version__[:3]) reads "0.10.x" as 0.1.
    tv_version = tuple(int(part) for part in torchvision.__version__.split('.')[:2])
    if tv_version < (0, 7):
        if input.numel() > 0:
            return torch.nn.functional.interpolate(
                input, size, scale_factor, mode, align_corners
            )

        # empty input: compute the output shape by hand and return an empty tensor
        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)

    # Keep using the torchvision entry point while it exists; fall back to the
    # PyTorch function on releases that no longer provide it.
    legacy_interpolate = getattr(torchvision.ops.misc, 'interpolate', None)
    if legacy_interpolate is not None:
        return legacy_interpolate(input, size, scale_factor, mode, align_corners)
    return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
#
class FocalLoss(nn.Module):
    r"""
        This criterion is an implementation of Focal Loss, which is proposed in
        Focal Loss for Dense Object Detection.

            Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

        The losses are averaged across observations for each minibatch.

        Args:
            class_num(int): number of classes; used to build the default alpha of ones
            alpha(Tensor or sequence): per-class weight, one value per class. Both a flat
                                   ``(class_num,)`` and a column ``(class_num, 1)`` layout
                                   are accepted.
            gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5),
                                   putting more focus on hard, misclassified examples
            size_average(bool): By default, the losses are averaged over observations for each minibatch.
                                However, if the field size_average is set to False, the losses are
                                instead summed for each minibatch.
    """
    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = torch.ones(class_num, 1)
        elif isinstance(alpha, torch.Tensor):
            self.alpha = alpha
        else:
            self.alpha = torch.as_tensor(alpha, dtype=torch.float32)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: logits of shape ``(N, C)``.
            targets: integer class indices with ``N`` elements.

        Returns:
            Scalar tensor: mean of the per-sample losses, or their sum when
            ``size_average`` is False.
        """
        ids = targets.view(-1, 1)

        # Log-probability of the target class. log_softmax stays finite where
        # softmax(...).log() would hit log(0); dim=1 is the class axis.
        log_p = F.log_softmax(inputs, dim=1).gather(1, ids)
        probs = log_p.exp()

        # alpha is a plain attribute, not a registered buffer, so move it by hand
        if self.alpha.device != inputs.device:
            self.alpha = self.alpha.to(inputs.device)
        # Flatten before indexing so a 1-D alpha gives an (N, 1) column too,
        # instead of broadcasting the product to (N, N).
        alpha = self.alpha.reshape(-1)[ids.view(-1)].view(-1, 1)

        batch_loss = -alpha * torch.pow(1 - probs, self.gamma) * log_p

        if self.size_average:
            return batch_loss.mean()
        return batch_loss.sum()

In [None]:
## 
import torch
from scipy.optimize import linear_sum_assignment
from torch import nn

class HungarianMatcher_Crowd(nn.Module):
    """Computes a one-to-one assignment between predicted and target points.

    For efficiency reasons, the targets don't include the no_object class.
    Because of this there are in general more predictions than targets; the
    best predictions are matched 1-to-1 and the remaining ones stay unmatched
    (and are thus treated as non-objects by the loss).
    """

    def __init__(self, cost_class: float = 1, cost_point: float = 1):
        """Creates the matcher.

        Params:
            cost_class: relative weight of the classification term in the matching cost
            cost_point: relative weight of the L2 (Euclidean) point distance in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_point = cost_point
        assert cost_class != 0 or cost_point != 0, "all costs cant be 0"

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Performs the matching.

        Params:
            outputs: dict that contains at least these entries:
                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                 "pred_points": Tensor of dim [batch_size, num_queries, 2] with the predicted point coordinates

            targets: list of targets (len(targets) = batch_size), where each target is a dict containing:
                 "labels": Tensor of dim [num_target_points] containing the (int64) class labels
                 "point": Tensor of dim [num_target_points, 2] containing the target point coordinates

        Returns:
            A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_target_points)
        """
        bs, num_queries = outputs["pred_logits"].shape[:2]

        # Flatten batch and query dimensions to compute the cost matrices in one go.
        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)  # [batch_size * num_queries, num_classes]
        out_points = outputs["pred_points"].flatten(0, 1)  # [batch_size * num_queries, 2]

        # Concatenate the target labels and points of the whole batch.
        tgt_ids = torch.cat([v["labels"] for v in targets])
        tgt_points = torch.cat([v["point"] for v in targets])

        # Classification cost. Contrary to the loss, the NLL is approximated by
        # 1 - proba[target class]; the constant 1 doesn't change the matching
        # and is omitted.
        cost_class = -out_prob[:, tgt_ids]

        # Euclidean (L2) distance between every predicted and every target point.
        cost_point = torch.cdist(out_points, tgt_points, p=2)

        # Final cost matrix.
        C = self.cost_point * cost_point + self.cost_class * cost_class

        sizes = [len(v["point"]) for v in targets]
        # Give the last dimension explicitly: view(..., -1) is ambiguous and
        # raises when the batch contains no target points at all.
        C = C.view(bs, num_queries, sum(sizes)).cpu()

        # Split the target axis per image and solve each image's own block.
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
#
def build_matcher_crowd(args):
    """Instantiate the Hungarian matcher from the parsed options.

    Reads ``args.set_cost_class`` and ``args.set_cost_point`` and forwards
    them as the two matching-cost weights.
    """
    class_weight = args.set_cost_class
    point_weight = args.set_cost_point
    return HungarianMatcher_Crowd(cost_class=class_weight, cost_point=point_weight)

### The feature extraction part is mainly adapted from https://github.com/zhaoyuzhi/PyTorch-Pyramid-Feature-Attention-Network-for-Saliency-Detection

In [None]:
## 
import torch
import torch.nn as nn
import torch.nn.functional as F
#
class SpatialAttention(nn.Module):
    """Single-channel spatial attention map built from two separable conv branches.

    One branch applies a (1 x k) convolution followed by a (k x 1) one, the
    other applies them in the opposite order. Both reduce the input to one
    channel; their sum goes through a sigmoid and is broadcast back to the
    input's shape.
    """

    def __init__(self, in_channels, kernel_size=9):
        super(SpatialAttention, self).__init__()

        self.kernel_size = kernel_size
        self.in_channels = in_channels

        k = self.kernel_size
        half_channels = self.in_channels // 2
        pad = (k - 1) // 2  # one-sided padding that keeps the size at stride 1

        # Branch 1: horizontal kernel first, then vertical.
        self.grp1_conv1k = nn.Conv2d(self.in_channels, half_channels, (1, k), padding=(0, pad))
        self.grp1_bn1 = nn.BatchNorm2d(half_channels)
        self.grp1_convk1 = nn.Conv2d(half_channels, 1, (k, 1), padding=(pad, 0))
        self.grp1_bn2 = nn.BatchNorm2d(1)

        # Branch 2: vertical kernel first, then horizontal.
        self.grp2_convk1 = nn.Conv2d(self.in_channels, half_channels, (k, 1), padding=(pad, 0))
        self.grp2_bn1 = nn.BatchNorm2d(half_channels)
        self.grp2_conv1k = nn.Conv2d(half_channels, 1, (1, k), padding=(0, pad))
        self.grp2_bn2 = nn.BatchNorm2d(1)

    def forward(self, input_):
        """Return an attention tensor with the same shape as ``input_``."""
        # Branch 1
        branch_a = F.relu(self.grp1_bn1(self.grp1_conv1k(input_)))
        branch_a = F.relu(self.grp1_bn2(self.grp1_convk1(branch_a)))

        # Branch 2
        branch_b = F.relu(self.grp2_bn1(self.grp2_convk1(input_)))
        branch_b = F.relu(self.grp2_bn2(self.grp2_conv1k(branch_b)))

        # Merge, squash, and replicate the single map across all channels;
        # clone() materialises the expanded view into its own memory.
        attention = torch.sigmoid(branch_a + branch_b)
        return attention.expand_as(input_).clone()
#
class ChannelwiseAttention(nn.Module):
    """Per-channel attention weights from globally pooled features.

    The input is average-pooled to one value per channel, passed through a
    bottleneck of two linear layers (channels -> channels // 4 -> channels)
    and squashed with a sigmoid.
    """

    def __init__(self, in_channels):
        super(ChannelwiseAttention, self).__init__()

        self.in_channels = in_channels
        bottleneck = self.in_channels // 4

        self.linear_1 = nn.Linear(self.in_channels, bottleneck)
        self.linear_2 = nn.Linear(bottleneck, self.in_channels)

    def forward(self, input_):
        """Return ``(weights, ca_act_reg)``.

        ``weights`` has the shape of ``input_`` (each channel weight repeated
        over the spatial positions); ``ca_act_reg`` is the mean of the
        channel weights, usable as an activity regulariser.
        """
        batch, channels, _, _ = input_.size()

        pooled = F.adaptive_avg_pool2d(input_, (1, 1)).view((batch, channels))
        hidden = F.relu(self.linear_1(pooled))
        weights = torch.sigmoid(self.linear_2(hidden))

        # Activity regularizer
        ca_act_reg = torch.mean(weights)

        # Broadcast the (batch, channels) weights over the spatial grid.
        weights = weights.view((batch, channels, 1, 1)).expand_as(input_).clone()

        return weights, ca_act_reg

In [None]:
#### 
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

# Module-level slots that the conv_*_hook forward hooks overwrite with the
# output of the layer they are attached to. They stay None until a hooked
# forward pass has run.
vgg_conv1_2 = vgg_conv2_2 = vgg_conv3_3 = vgg_conv4_3 = vgg_conv5_3 = None

def conv_1_2_hook(module, input, output):
    """Forward hook that stores ``output`` in the global ``vgg_conv1_2``.

    ``module`` and ``input`` are part of the hook signature but unused.
    Nothing is returned (implicit None).
    """
    global vgg_conv1_2
    vgg_conv1_2 = output

def conv_2_2_hook(module, input, output):
    """Forward hook that stores ``output`` in the global ``vgg_conv2_2``.

    ``module`` and ``input`` are part of the hook signature but unused.
    Nothing is returned (implicit None).
    """
    global vgg_conv2_2
    vgg_conv2_2 = output

def conv_3_3_hook(module, input, output):
    """Forward hook that stores ``output`` in the global ``vgg_conv3_3``.

    ``module`` and ``input`` are part of the hook signature but unused.
    Nothing is returned (implicit None).
    """
    global vgg_conv3_3
    vgg_conv3_3 = output

def conv_4_3_hook(module, input, output):
    """Forward hook that stores ``output`` in the global ``vgg_conv4_3``.

    ``module`` and ``input`` are part of the hook signature but unused.
    Nothing is returned (implicit None).
    """
    global vgg_conv4_3
    vgg_conv4_3 = output

def conv_5_3_hook(module, input, output):
    """Forward hook that stores ``output`` in the global ``vgg_conv5_3``.

    ``module`` and ``input`` are part of the hook signature but unused.
    Nothing is returned (implicit None).
    """
    global vgg_conv5_3
    vgg_conv5_3 = output

##
class CPFE_hl(nn.Module):
    """Context-aware feature extraction over one VGG-16 feature map.

    Runs a 1x1 convolution and three dilated 3x3 convolutions (dilation
    rates 3, 5 and 7) on the same input, concatenates the four results along
    the channel axis and applies batch norm followed by ReLU. The output
    therefore has ``4 * out_channels`` channels and the input's spatial size.

    Args:
        feature_layer: name of the VGG-16 layer the input comes from; it
            only selects the number of input channels.
        out_channels: channels produced by each of the four convolutions.

    Raises:
        ValueError: if ``feature_layer`` is not a known layer name.
    """

    # Input channel count per supported layer name. 'conv2_3' and 'conv1_3'
    # are the names historically used by callers; 'conv2_2' and 'conv1_2'
    # (the layers VGG-16 actually has) are accepted as aliases.
    _IN_CHANNELS = {
        'conv5_3': 512,
        'conv4_3': 512,
        'conv3_3': 256,
        'conv2_3': 128,
        'conv2_2': 128,
        'conv1_3': 64,
        'conv1_2': 64,
    }

    def __init__(self, feature_layer=None, out_channels=8):
        super(CPFE_hl, self).__init__()

        self.dil_rates = [3, 5, 7]

        # Fail early with a clear message instead of an AttributeError on
        # ``self.in_channels`` a few lines further down.
        if feature_layer not in self._IN_CHANNELS:
            raise ValueError(
                "unknown feature_layer {!r}; expected one of {}".format(
                    feature_layer, sorted(self._IN_CHANNELS)))
        self.in_channels = self._IN_CHANNELS[feature_layer]

        # Define layers; padding == dilation keeps the spatial size for 3x3 kernels.
        self.conv_1_1 = nn.Conv2d(in_channels=self.in_channels, out_channels=out_channels, kernel_size=1, bias=False)
        self.conv_dil_3 = nn.Conv2d(in_channels=self.in_channels, out_channels=out_channels, kernel_size=3,
                                    stride=1, dilation=self.dil_rates[0], padding=self.dil_rates[0], bias=False)
        self.conv_dil_5 = nn.Conv2d(in_channels=self.in_channels, out_channels=out_channels, kernel_size=3,
                                    stride=1, dilation=self.dil_rates[1], padding=self.dil_rates[1], bias=False)
        self.conv_dil_7 = nn.Conv2d(in_channels=self.in_channels, out_channels=out_channels, kernel_size=3,
                                    stride=1, dilation=self.dil_rates[2], padding=self.dil_rates[2], bias=False)

        self.bn = nn.BatchNorm2d(out_channels*4)

    def forward(self, input_):
        """Return the normalised, ReLU-activated multi-dilation features."""
        # Extract features at the four receptive-field sizes.
        branches = (
            self.conv_1_1(input_),
            self.conv_dil_3(input_),
            self.conv_dil_5(input_),
            self.conv_dil_7(input_),
        )

        # Aggregate features along the channel axis.
        concat_feats = torch.cat(branches, dim=1)
        return F.relu(self.bn(concat_feats))
##
class SODModel(nn.Module):
    """VGG-16 based feature extractor fusing low- and high-level features.

    The intermediate VGG-16 activations are captured through forward hooks
    into module-level globals. High-level features (conv3_3 .. conv5_3) are
    weighted by channel attention, low-level features (conv1_2, conv2_2) by
    spatial attention; both are fused into a 3-channel map that is added to
    the input image resized by a factor of 0.25, then passed through a sigmoid.
    """

    def __init__(self):
        super(SODModel, self).__init__()

        # Convolutional trunk of an ImageNet-pretrained VGG-16.
        self.vgg16 = models.vgg16(pretrained=True).features

        # Capture the intermediate activations of the trunk. The hooks write
        # to module-level globals that forward() reads back.
        # NOTE(review): indices 3/8/15/22/29 are presumably the ReLUs that
        # follow conv1_2 .. conv5_3 in torchvision's layout - confirm.
        hooked_layers = (
            (3, conv_1_2_hook),
            (8, conv_2_2_hook),
            (15, conv_3_3_hook),
            (22, conv_4_3_hook),
            (29, conv_5_3_hook),
        )
        for layer_index, hook in hooked_layers:
            self.vgg16[layer_index].register_forward_hook(hook)

        # Context-aware extractors for the high-level features ...
        self.cpfe_conv3_3 = CPFE_hl(feature_layer='conv3_3')
        self.cpfe_conv4_3 = CPFE_hl(feature_layer='conv4_3')
        self.cpfe_conv5_3 = CPFE_hl(feature_layer='conv5_3')
        # ... and for the two low-level ones.
        self.cpfe_conv1_3 = CPFE_hl(feature_layer='conv1_3')
        self.cpfe_conv2_3 = CPFE_hl(feature_layer='conv2_3')

        # Channel attention over the concatenated high-level features.
        self.cha_att = ChannelwiseAttention(in_channels=96)  # in_channels = 3 x (8 x 4)

        self.hl_conv1 = nn.Conv2d(96, 8, (3, 3), padding=1)
        self.hl_bn1 = nn.BatchNorm2d(8)

        # Low-level convolutions. Only ll_conv_3 / ll_bn_3 are used by
        # forward(); the first two pairs are kept so that the registered
        # parameters (and hence checkpoints) stay the same.
        self.ll_conv_1 = nn.Conv2d(64, 8, (3, 3), padding=1)
        self.ll_bn_1 = nn.BatchNorm2d(8)
        self.ll_conv_2 = nn.Conv2d(128, 8, (3, 3), padding=1)
        self.ll_bn_2 = nn.BatchNorm2d(8)
        self.ll_conv_3 = nn.Conv2d(64, 8, (3, 3), padding=1)
        self.ll_bn_3 = nn.BatchNorm2d(8)

        self.spa_att = SpatialAttention(in_channels=8)

        # Fusion of the low-level (8) and high-level (8) channels into 3.
        self.ff_conv_1 = nn.Conv2d(16, 3, (3, 3), padding=1)
        self.ff_bn_1 = nn.BatchNorm2d(3)

    def forward(self, input_):
        """Return the fused 3-channel feature map for the image batch ``input_``."""
        global vgg_conv1_2, vgg_conv2_2, vgg_conv3_3, vgg_conv4_3, vgg_conv5_3

        def _resize(feats, factor):
            # Bilinear rescaling shared by every branch.
            return F.interpolate(feats, scale_factor=factor, mode='bilinear', align_corners=True)

        # Run the trunk only for its side effect: the hooks fill the globals.
        self.vgg16(input_)

        # ---- high-level branch -------------------------------------------
        high3 = self.cpfe_conv3_3(vgg_conv3_3)
        # conv4_3 / conv5_3 maps are upsampled (x2 / x4) before concatenation.
        high4 = _resize(self.cpfe_conv4_3(vgg_conv4_3), 2)
        high5 = _resize(self.cpfe_conv5_3(vgg_conv5_3), 4)
        high = torch.cat((high3, high4, high5), dim=1)

        # Channel attention; the activity regulariser it returns is not used.
        channel_weights, _ = self.cha_att(high)
        high = high * channel_weights
        high = F.relu(self.hl_bn1(self.hl_conv1(high)))

        # ---- low-level branch --------------------------------------------
        low1 = self.cpfe_conv1_3(vgg_conv1_2)
        low2 = self.cpfe_conv2_3(vgg_conv2_2)

        image_small = _resize(input_, 0.25)  # the original image, downscaled
        low1 = _resize(low1, 0.25)
        low2 = _resize(low2, 0.5)

        low = torch.cat((low1, low2), dim=1)
        low = F.relu(self.ll_bn_3(self.ll_conv_3(low)))
        # Spatial attention on the low-level features.
        low = low * self.spa_att(low)

        # ---- fusion ------------------------------------------------------
        fused = torch.cat((low, high), dim=1)
        fused = F.relu(self.ff_bn_1(self.ff_conv_1(fused)))
        # Residual connection to the downscaled image, squashed to (0, 1).
        return torch.sigmoid(fused + image_small)


In [None]:
# build model p2pNet.py
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
import time

# the network frmawork of the regression branch
class RegressionModel(nn.Module):
    """Regression branch: predicts two values per anchor point and location.

    ``conv3``/``conv4`` (and their activations) are registered but not used
    by ``forward``; they are kept so the set of parameters is unchanged.
    """

    def __init__(self, num_features_in, num_anchor_points=4, feature_size=32):
        super(RegressionModel, self).__init__()

        def conv3x3(channels_in, channels_out):
            # All layers of this branch are size-preserving 3x3 convolutions.
            return nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1)

        self.conv1 = conv3x3(num_features_in, feature_size)
        self.act1 = nn.ReLU()

        self.conv2 = conv3x3(feature_size, feature_size)
        self.act2 = nn.ReLU()

        self.conv3 = conv3x3(feature_size, feature_size)
        self.act3 = nn.ReLU()

        self.conv4 = conv3x3(feature_size, feature_size)
        self.act4 = nn.ReLU()

        # One point has two coordinates.
        self.output = conv3x3(feature_size, num_anchor_points * 2)

    def forward(self, x):
        """Map ``x`` of shape (B, C, H, W) to (B, H * W * num_anchor_points, 2)."""
        hidden = self.act1(self.conv1(x))
        hidden = self.act2(self.conv2(hidden))

        # Channels last, so that each location's anchor values are adjacent.
        offsets = self.output(hidden).permute(0, 2, 3, 1)

        return offsets.contiguous().view(offsets.shape[0], -1, 2)

# the network frmawork of the classification branch
class ClassificationModel(nn.Module):
    """Classification branch: predicts ``num_classes`` scores per anchor point.

    ``prior`` is accepted but unused. ``conv3``/``conv4``, their activations
    and ``output_act`` are registered but not applied by ``forward``, which
    therefore returns raw scores.
    """

    def __init__(self, num_features_in, num_anchor_points=4, num_classes=80, prior=0.01, feature_size=32):
        super(ClassificationModel, self).__init__()

        self.num_classes = num_classes
        self.num_anchor_points = num_anchor_points

        def conv3x3(channels_in, channels_out):
            # All layers of this branch are size-preserving 3x3 convolutions.
            return nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1)

        self.conv1 = conv3x3(num_features_in, feature_size)
        self.act1 = nn.ReLU()

        self.conv2 = conv3x3(feature_size, feature_size)
        self.act2 = nn.ReLU()

        self.conv3 = conv3x3(feature_size, feature_size)
        self.act3 = nn.ReLU()

        self.conv4 = conv3x3(feature_size, feature_size)
        self.act4 = nn.ReLU()

        self.output = conv3x3(feature_size, num_anchor_points * num_classes)
        self.output_act = nn.Sigmoid()

    def forward(self, x):
        """Map ``x`` of shape (B, C, H, W) to (B, H * W * num_anchor_points, num_classes)."""
        hidden = self.act1(self.conv1(x))
        hidden = self.act2(self.conv2(hidden))

        # Channels last: (B, dim_a, dim_b, anchors * classes).
        scores = self.output(hidden).permute(0, 2, 3, 1)
        batch_size, dim_a, dim_b, _ = scores.shape

        # Separate the anchor axis from the class axis, then flatten
        # locations and anchors into one axis.
        per_anchor = scores.view(batch_size, dim_a, dim_b, self.num_anchor_points, self.num_classes)

        return per_anchor.contiguous().view(x.shape[0], -1, self.num_classes)

# generate the reference points in grid layout
def generate_anchor_points(stride=8, row=3, line=3):
    """Lay out ``row * line`` reference points on a grid inside one cell.

    The cell is ``stride`` wide and centred on the origin; points sit at the
    centres of a ``row`` x ``line`` subdivision of it.

    Returns:
        Array of shape (row * line, 2) holding (x, y) offsets, x varying fastest.
    """
    half_cell = stride / 2
    step_y = stride / row
    step_x = stride / line

    # Centres of the sub-cells, relative to the centre of the cell.
    xs = (np.arange(1, line + 1) - 0.5) * step_x - half_cell
    ys = (np.arange(1, row + 1) - 0.5) * step_y - half_cell

    grid_x, grid_y = np.meshgrid(xs, ys)

    return np.column_stack((grid_x.ravel(), grid_y.ravel()))
# shift the meta-anchor to get an acnhor points
def shift(shape, stride, anchor_points):
    """Replicate the per-cell anchor points at every cell of a feature map.

    Args:
        shape: (rows, cols) of the feature map.
        stride: distance in pixels between neighbouring cell centres.
        anchor_points: array of shape (A, 2) with per-cell (x, y) offsets.

    Returns:
        Array of shape (rows * cols * A, 2); all anchors of one cell are
        contiguous and cells are visited row by row.
    """
    # Pixel coordinates of the cell centres.
    centre_x = (np.arange(0, shape[1]) + 0.5) * stride
    centre_y = (np.arange(0, shape[0]) + 0.5) * stride

    grid_x, grid_y = np.meshgrid(centre_x, centre_y)
    centres = np.column_stack((grid_x.ravel(), grid_y.ravel()))

    num_anchors = anchor_points.shape[0]
    num_cells = centres.shape[0]

    # (K, 1, 2) + (1, A, 2) -> (K, A, 2): every centre plus every offset.
    shifted = centres.reshape((num_cells, 1, 2)) + anchor_points.reshape((1, num_anchors, 2))

    return shifted.reshape((num_cells * num_anchors, 2))

# 
class AnchorPoints(nn.Module):
    """Generates the reference (anchor) points for an image batch.

    Args:
        pyramid_levels: feature levels to generate points for; level ``p``
            has cells of ``2 ** p`` pixels. Defaults to ``[3, 4, 5, 6, 7]``.
        strides: distance between cell centres for each level. Defaults to
            ``2 ** p`` for every level ``p``.
        row, line: number of reference points per cell along each axis.
    """

    def __init__(self, pyramid_levels=None, strides=None, row=3, line=3):
        super(AnchorPoints, self).__init__()

        if pyramid_levels is None:
            self.pyramid_levels = [3, 4, 5, 6, 7]
        else:
            self.pyramid_levels = pyramid_levels

        if strides is None:
            self.strides = [2 ** x for x in self.pyramid_levels]
        else:
            # Previously a caller-supplied value was silently dropped and
            # forward() failed on the missing attribute.
            self.strides = strides

        self.row = row
        self.line = line

    def forward(self, image):
        """Return a float32 tensor of shape (1, num_points, 2) for ``image``.

        Only the spatial size (``image.shape[2:]``) and, when available, the
        device of ``image`` are used.
        """
        image_shape = np.array(image.shape[2:])
        # Feature-map size per level: ceil(image_size / 2**level),
        # e.g. a 128x128 image gives a 16x16 map at level 3.
        image_shapes = [(image_shape + 2 ** x - 1) // (2 ** x) for x in self.pyramid_levels]

        # Collect the reference points of each level, then join them once.
        per_level = [np.zeros((0, 2), dtype=np.float32)]
        for idx, p in enumerate(self.pyramid_levels):
            anchor_points = generate_anchor_points(2 ** p, row=self.row, line=self.line)
            per_level.append(shift(image_shapes[idx], self.strides[idx], anchor_points))

        all_anchor_points = np.concatenate(per_level, axis=0)
        all_anchor_points = np.expand_dims(all_anchor_points, axis=0).astype(np.float32)
        points = torch.from_numpy(all_anchor_points)

        # Follow the input's device so the points can be added to the model
        # output on CPU as well as on GPU.
        device = getattr(image, "device", None)
        if device is not None:
            return points.to(device)
        # Inputs without a device attribute keep the historical behaviour.
        if torch.cuda.is_available():
            return points.cuda()
        return points
##
# the defenition of the P2PNet model
class P2PNet(nn.Module):
    """Point-prediction network: feature extractor plus two prediction heads.

    For every reference point the model outputs class scores and a 2-D
    offset; the offset is scaled and added to the reference point to give
    the predicted coordinates.
    """

    def __init__(self, row=2, line=2):
        super().__init__()
        self.num_classes = 2
        # Number of reference points per feature-map cell.
        points_per_cell = row * line

        # Both heads consume the 3-channel map produced by the extractor.
        self.regression = RegressionModel(num_features_in=3, num_anchor_points=points_per_cell)
        self.classification = ClassificationModel(
            num_features_in=3,
            num_classes=self.num_classes,
            num_anchor_points=points_per_cell,
        )

        # Remember to change the pyramid level when the feature input changes.
        self.anchor_points = AnchorPoints(pyramid_levels=[2,], row=row, line=line)

        self.fpn = SODModel()

    def forward(self, samples: NestedTensor):
        """Return ``{'pred_logits': ..., 'pred_points': ...}`` for ``samples``.

        ``samples`` is passed directly to the feature extractor and to the
        anchor generator.
        """
        feature_map = self.fpn(samples)  # (batch, channel, height, width)
        num_images = feature_map.size()[0]

        # Head outputs; the raw offsets are scaled by a constant factor of 100.
        offsets = self.regression(feature_map) * 100
        class_scores = self.classification(feature_map)

        # The same reference points are used for every image of the batch.
        reference = self.anchor_points(samples).repeat(num_images, 1, 1)

        return {'pred_logits': class_scores, 'pred_points': offsets + reference}

class SetCriterion_Crowd(nn.Module):
    """Loss for point predictions.

    Predictions are first matched one-to-one to the targets by ``matcher``;
    the configured losses are then computed on that matching.
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
        """Create the criterion.

        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses

        # Per-class weights for the cross entropy; index 0 (no-object) is
        # down-weighted by eos_coef.
        class_weights = torch.ones(self.num_classes + 1)
        class_weights[0] = self.eos_coef
        self.register_buffer('empty_weight', class_weights)

    def loss_labels(self, outputs, targets, indices, num_points):
        """Classification loss (weighted cross entropy).

        targets dicts must contain the key "labels" containing a tensor of
        class indices. Unmatched predictions get class 0.
        """
        assert 'pred_logits' in outputs
        logits = outputs['pred_logits']

        matched = self._get_src_permutation_idx(indices)
        matched_labels = torch.cat([t["labels"][tgt] for t, (_, tgt) in zip(targets, indices)])

        # Everything is background (0) except the matched predictions.
        target_classes = torch.zeros(logits.shape[:2], dtype=torch.int64, device=logits.device)
        target_classes[matched] = matched_labels

        loss_ce = F.cross_entropy(logits.transpose(1, 2), target_classes, self.empty_weight)
        return {'loss_ce': loss_ce}

    def loss_points(self, outputs, targets, indices, num_points):
        """Squared-error loss between matched predicted and target points.

        The summed squared error is divided by ``num_points``.
        """
        assert 'pred_points' in outputs
        matched = self._get_src_permutation_idx(indices)
        predicted = outputs['pred_points'][matched]
        expected = torch.cat([t['point'][tgt] for t, (_, tgt) in zip(targets, indices)], dim=0)

        squared_error = F.mse_loss(predicted, expected, reduction='none')

        return {'loss_points': squared_error.sum() / num_points}

    def _get_src_permutation_idx(self, indices):
        """Return (batch index, prediction index) tensors of all matched predictions."""
        batch_parts, query_parts = [], []
        for image_no, (src, _) in enumerate(indices):
            batch_parts.append(torch.full_like(src, image_no))
            query_parts.append(src)
        return torch.cat(batch_parts), torch.cat(query_parts)

    def _get_tgt_permutation_idx(self, indices):
        """Return (batch index, target index) tensors of all matched targets."""
        batch_parts, target_parts = [], []
        for image_no, (_, tgt) in enumerate(indices):
            batch_parts.append(torch.full_like(tgt, image_no))
            target_parts.append(tgt)
        return torch.cat(batch_parts), torch.cat(target_parts)

    def get_loss(self, loss, outputs, targets, indices, num_points, **kwargs):
        """Dispatch to the loss function registered under the name ``loss``."""
        available = {
            'labels': self.loss_labels,
            'points': self.loss_points,
        }
        assert loss in available, f'do you really want to compute {loss} loss?'
        loss_fn = available[loss]
        return loss_fn(outputs, targets, indices, num_points, **kwargs)

    def forward(self, outputs, targets):
        """This performs the loss computation.

        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        output1 = {key: outputs[key] for key in ('pred_logits', 'pred_points')}

        indices1 = self.matcher(output1, targets)

        # Total number of target points in the batch, used for normalisation.
        total_points = sum(len(t["labels"]) for t in targets)
        total_points = torch.as_tensor(
            [total_points], dtype=torch.float, device=next(iter(output1.values())).device)

        # At least 1, so batches without targets do not divide by zero.
        num_boxes = torch.clamp(total_points / get_world_size(), min=1).item()

        losses = {}
        for loss_name in self.losses:
            losses.update(self.get_loss(loss_name, output1, targets, indices1, num_boxes))

        return losses

# create the P2PNet model
def build(args, training):
    """Create the P2PNet model and, for training, its loss criterion.

    Args:
        args: options object; ``row`` and ``line`` are always read, and
            ``point_loss_coef``, ``eos_coef`` plus the matcher options are
            read when ``training`` is true.
        training: when false only the model is returned.

    Returns:
        ``model`` when ``training`` is false, otherwise ``(model, criterion)``.
    """
    # A single foreground class.
    num_classes = 1

    net = P2PNet(args.row, args.line)
    if training:
        loss_weights = {'loss_ce': 1, 'loss_points': args.point_loss_coef}
        loss_names = ['labels', 'points']
        criterion = SetCriterion_Crowd(
            num_classes,
            matcher=build_matcher_crowd(args),
            weight_dict=loss_weights,
            eos_coef=args.eos_coef,
            losses=loss_names,
        )
        return net, criterion

    return net
##


### List all the parameters here

In [None]:

def get_arguments():
    """Build the training option parser and parse the command line.

    ``parse_known_args`` is used so that unrelated arguments (such as the
    ones a notebook kernel is started with) are ignored instead of raising.

    Returns:
      An ``argparse.Namespace`` holding every option defined below.
    """
    parser = argparse.ArgumentParser(description="Object Counting Framework")
    # Optimisation
    parser.add_argument('--lr', default=1e-3, type=float)
    parser.add_argument('--lr_fpn', default=1e-5, type=float)
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--epochs', default=3500, type=int)
    parser.add_argument('--lr_drop', default=100, type=int)
    parser.add_argument('--clip_max_norm', default=0.1, type=float,
                        help='gradient clipping max norm')

    # Model parameters
    parser.add_argument('--frozen_weights', type=str, default=None,
                        help="Path to the pretrained model. If set, only the mask head will be trained")
    # Matcher
    parser.add_argument('--set_cost_class', default=1, type=float,
                        help="Class coefficient in the matching cost")

    # The matcher measures point distance with torch.cdist(..., p=2).
    parser.add_argument('--set_cost_point', default=0.99, type=float,
                        help="L2 point distance coefficient in the matching cost")

    # * Loss coefficients
    parser.add_argument('--point_loss_coef', default=0.02, type=float)
    parser.add_argument('--eos_coef', default=0.02, type=float,
                        help="Relative classification weight of the no-object class")

    # a threshold during evaluation for counting and visualization
    parser.add_argument('--threshold', default=0.5, type=float,
                        help="threshold in evalluation: evaluate_crowd_no_overlap")
    parser.add_argument('--row', default=2, type=int,
                        help="row number of anchor points")
    parser.add_argument('--line', default=2, type=int,
                        help="line number of anchor points")

    # dataset parameters
    parser.add_argument('--dataset_file', default='SHHA')
    parser.add_argument('--data_root', default='/content/drive/My Drive/P2PNet-Soy/Soybean_seed_counting/',
                        help='path where the dataset is')

    parser.add_argument('--output_dir', default='/content/drive/My Drive/P2PNet-Soy/log_P2PNet_Soy',
                        help='path where to save, empty for no saving')
    parser.add_argument('--checkpoints_dir', default='/content/drive/My Drive/P2PNet-Soy/ckpt_P2PNet_Soy_01',
                        help='path where to save checkpoints, empty for no saving')
    parser.add_argument('--tensorboard_dir', default='/content/drive/My Drive/P2PNet-Soy/runs_P2PNet_Soy',
                        help='path where to save, empty for no saving')

    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--num_workers', default=1, type=int)
    # The help text used to claim a default of 5 while the default is 3;
    # let argparse print the real value.
    parser.add_argument('--eval_freq', default=3, type=int,
                        help='frequency of evaluation in epochs (default: %(default)s)')
    parser.add_argument('--gpu_id', default=0, type=int, help='the gpu used for training')

    # Keep only the recognised options; unknown arguments are discarded.
    opt = parser.parse_known_args()[0]
    return opt

In [None]:
args = get_arguments()
#
# put the model path here if you have trained any or comment it out
args.resume = "/content/drive/My Drive/P2PNet-Soy/ckpt_P2PNet_Soy/best_mae.pth"
# the directory to save the evaluations during training
args.vis_dir = "/content/drive/My Drive/P2PNet-Soy/vis_P2PNet_Soy"
# exist_ok avoids the check-then-create pattern
os.makedirs(args.vis_dir, exist_ok=True)
##
os.environ["CUDA_VISIBLE_DEVICES"] = '{}'.format(args.gpu_id)
# create the logging file (truncated on every run) and record the options;
# the trailing newline keeps the first metric line from being glued to them
os.makedirs(args.output_dir, exist_ok=True)
run_log_name = os.path.join(args.output_dir, 'run_log.txt')
with open(run_log_name, "w") as log_file:
    log_file.write('Eval Log %s\n' % time.strftime("%c"))
    log_file.write("{}\n".format(args))
device = torch.device('cuda')
# fix the seed for reproducibility (the rank-dependent seed that used to be
# computed here was overwritten immediately and never took effect)
seed = args.seed
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
# get the P2PNet model
model, criterion = build(args, training=True)
# move to GPU
model.to(device)
criterion.to(device)

model_without_ddp = model

n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
# use different optimisation params for different parts of the model:
# parameters whose name contains "fpn" get their own learning rate
param_dicts = [
    {"params": [p for n, p in model_without_ddp.named_parameters() if "fpn" not in n and p.requires_grad]},
    {
        "params": [p for n, p in model_without_ddp.named_parameters() if "fpn" in n and p.requires_grad],
        "lr": args.lr_fpn,
    },
]
# Adam is used by default
optimizer = torch.optim.Adam(param_dicts, lr=args.lr)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
# create the training and validation set
train_set, val_set = loading_data(args.data_root)
# create the samplers: shuffled for training, in order for validation
sampler_train = torch.utils.data.RandomSampler(train_set)
sampler_val = torch.utils.data.SequentialSampler(val_set)

batch_sampler_train = torch.utils.data.BatchSampler(
    sampler_train, args.batch_size, drop_last=True)
# the dataloader for training
data_loader_train = DataLoader(train_set, batch_sampler=batch_sampler_train,
                                collate_fn=collate_fn_crowd, num_workers=args.num_workers)

# validation runs one image at a time
data_loader_val = DataLoader(val_set, 1, sampler=sampler_val,
                                drop_last=False, collate_fn=collate_fn_crowd, num_workers=args.num_workers)
if args.frozen_weights is not None:
    checkpoint = torch.load(args.frozen_weights, map_location='cpu')
    # P2PNet has no `detr` sub-module, so the weights are loaded into the
    # model itself; the former `.detr` access raised an AttributeError.
    model_without_ddp.load_state_dict(checkpoint['model'])
# resume the weights and training state if a checkpoint path is given;
# new_start forces one evaluation right away so that the loaded model can be checked
new_start = 0
if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    model_without_ddp.load_state_dict(checkpoint['model'])
    args.start_epoch = checkpoint['epoch']
    new_start = 1
    # optimizer / scheduler state is restored only when the checkpoint has it
    if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
##
start_time = time.time()
# evaluation history: one entry per evaluation run
mae = []
mse = []
epoch_save = []
# the logger writer
writer = SummaryWriter(args.tensorboard_dir)
# directory for the checkpoints written below
os.makedirs(args.checkpoints_dir, exist_ok=True)
# counts evaluation runs; used as x-axis of the metric curves
step = 0
# training starts here
for epoch in range(args.start_epoch, args.epochs):
    # always run evaluation first (to check what model has been loaded !!!)
    if (epoch + 2) % args.eval_freq == 0 or new_start:
        # change the status right after the first iteration
        new_start = 0
        #
        t1 = time.time()
        result = evaluate_crowd_no_overlap(model, data_loader_val, device, epoch, args.threshold, args.vis_dir)
        t2 = time.time()
        print("evaluation time {}".format(t2-t1))
        mae.append(result[0])
        mse.append(result[1])
        epoch_save.append(epoch)
        # best score so far and the (first) epoch at which it was reached
        best_mae = np.min(mae)
        epoch_save_m = epoch_save[int(np.argmin(mae))]
        # print the evaluation results
        print('=======================================test=======================================')
        print("mae:", result[0], "mse:", result[1], "time:", t2 - t1, "best mae:", best_mae, "at epoch: {}".format(epoch_save_m) )
        with open(run_log_name, "a") as log_file:
            log_file.write("mae:{}, mse:{}, time:{}, best mae:{}\n".format(result[0],
                            result[1], t2 - t1, best_mae))
        print('=======================================test=======================================')
        # record the evaluation results
        if writer is not None:
            with open(run_log_name, "a") as log_file:
                log_file.write("metric/mae@{}: {}\n".format(step, result[0]))
                log_file.write("metric/mse@{}: {}\n".format(step, result[1]))
            writer.add_scalar('metric/mae', result[0], step)
            writer.add_scalar('metric/mse', result[1], step)
            step += 1

        # save the best model since the beginning: only when the current
        # score is the best one, so a slightly worse model can no longer
        # overwrite the best checkpoint
        if result[0] <= best_mae:
            checkpoint_best_path = os.path.join(args.checkpoints_dir, 'best_mae.pth')
            torch.save({
                'model': model_without_ddp.state_dict(),
                'epoch': epoch,
            }, checkpoint_best_path)
    ###
    t1 = time.time()
    stat = train_one_epoch(model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm)

    # record the training states after every epoch
    if writer is not None:
        with open(run_log_name, "a") as log_file:
            log_file.write("loss/loss@{}: {}\n".format(epoch, stat['loss']))
            log_file.write("loss/loss_ce@{}: {}\n".format(epoch, stat['loss_ce']))

        writer.add_scalar('loss/loss', stat['loss'], epoch)
        writer.add_scalar('loss/loss_ce', stat['loss_ce'], epoch)

    t2 = time.time()
    print('[ep %d][lr %.7f][%.2fs]' % \
            (epoch, optimizer.param_groups[0]['lr'], t2 - t1))
    with open(run_log_name, "a") as log_file:
        log_file.write('[ep %d][lr %.7f][%.2fs]\n' % (epoch, optimizer.param_groups[0]['lr'], t2 - t1))
    # change lr according to the scheduler
    lr_scheduler.step()
    #
    # save latest weights every epoch
    checkpoint_latest_path = os.path.join(args.checkpoints_dir, 'latest.pth')
    torch.save({
        'model': model_without_ddp.state_dict(),
        'epoch': epoch,
    }, checkpoint_latest_path)
    ## clear the cell output regularly
    if epoch % 150 == 0 and epoch != 0:
        clear_output()
# flush pending events to disk
if writer is not None:
    writer.close()
# total time for training
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

number of params: 15137927


## Inference

In [None]:
# create the P2PNet model
def build_eval(args, training=False):
    """Build a bare P2PNet model for inference.

    Args:
        args: options object; ``args.row`` and ``args.line`` are forwarded
            positionally to ``P2PNet``.
        training: accepted so that callers written against a
            ``build(args, training)``-style interface (for example
            ``build_eval(args, training=False)``) do not fail with a
            ``TypeError``. It is not used: no loss criterion is built here,
            only the model is returned.

    Returns:
        The freshly constructed ``P2PNet`` instance (weights are not loaded
        by this function).
    """
    model = P2PNet(args.row, args.line)
    return model

In [None]:
def get_arguments():
    """Assemble the inference options from the command line.

    Every option has a default, so calling this without any command-line
    flags yields a fully populated option set. Command-line tokens that the
    parser does not recognise are ignored instead of raising an error.

    Returns:
      An ``argparse.Namespace`` holding ``threshold``, ``row``, ``line``,
      ``data_root``, ``seed``, ``resume``, ``num_workers`` and ``gpu_id``.
    """
    cli = argparse.ArgumentParser(description="Object Counting Framework")

    # (flag, keyword arguments for add_argument), registered in this order
    option_table = (
        # score cut-off used when counting and visualising predictions
        ('--threshold', dict(default=0.5, type=float,
                             help="threshold in evalluation: evaluate_crowd_no_overlap")),
        ('--row', dict(default=2, type=int,
                       help="row number of anchor points")),
        ('--line', dict(default=2, type=int,
                        help="line number of anchor points")),
        ('--data_root', dict(default='/content/drive/My Drive/soypod_crop_counting/',
                             help='path where the dataset is')),
        ('--seed', dict(default=42, type=int)),
        ('--resume', dict(default='', help='resume from checkpoint')),
        ('--num_workers', dict(default=1, type=int)),
        ('--gpu_id', dict(default=0, type=int, help='the gpu used for training')),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    # parse_known_args returns (namespace, leftover tokens); only the namespace is needed
    known_options, _leftover = cli.parse_known_args()
    return known_options

In [None]:
# start from the default inference options ...
args = get_arguments()
# ... then point them at this run's files:
# folder that receives the visualised predictions and the count CSVs
args.vis_dir = "/content/drive/My Drive/P2PNet-Soy/vis_P2PNet_Soy_out"
# checkpoint whose weights are loaded before inference
args.resume = "/content/drive/My Drive/P2PNet-Soy/ckpt_P2PNet_Soy/best_mae.pth"

In [None]:
####
# make sure the output folder exists (no error if it is already there)
os.makedirs(args.vis_dir, exist_ok=True)
##
# restrict CUDA to the requested GPU
os.environ["CUDA_VISIBLE_DEVICES"] = '{}'.format(args.gpu_id)


# use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# fix the seed for reproducibility: seed every RNG the pipeline may touch,
# not only Python's `random`
seed = args.seed + get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
# original model
# build_eval is called with `args` only, which works with a one-argument
# definition of build_eval
model = build_eval(args)
# move to the selected device
model.to(device)

### On images of cropped individual plants

In [None]:

# Inference on the cropped single-plant images.
# For every image listed in the "b" and "a" split files this cell
#   1. runs the model (in two horizontal halves when the padded image is tall),
#   2. merges detections that lie close together into one point each,
#   3. saves a "before merging" and an "after merging" visualisation, and
#   4. appends both counts to a CSV file in args.vis_dir.

# threshold for evaluation
threshold = 0.5 #args.threshold
# padded images taller than this many pixels are processed in two halves
LARGE_IMAGE_HEIGHT = 1700


def _predict_points(net, image_chw, target_device, score_threshold):
    """Run `net` on one (3, H, W) tensor and return the confident points.

    The score of a candidate is the softmax probability of class index 1 of
    `pred_logits`; candidates scoring above `score_threshold` are kept.
    Returns a numpy array with one row per kept point; columns 0 and 1 are
    used as x and y pixel coordinates when the points are drawn.
    """
    # no gradients are needed for inference; this keeps memory usage down
    with torch.no_grad():
        outputs = net(image_chw.unsqueeze(0).to(target_device))
    scores = torch.nn.functional.softmax(outputs['pred_logits'], -1)[:, :, 1][0]
    candidates = outputs['pred_points'][0]
    return candidates[scores > score_threshold].detach().cpu().numpy()


def _merge_close_points(points, min_cutoff, max_points):
    """Replace every group of nearby points by the group's centroid.

    Points closer than a cut-off radius are linked; each connected group of
    linked points is collapsed to its mean position. The radius is
    500 / number_of_points, but never smaller than `min_cutoff`.
    The merge is skipped (points returned unchanged, as a list) when there
    are no points or `max_points` or more.
    """
    n_points = points.shape[0]
    if n_points == 0 or n_points >= max_points:
        return points.tolist()
    cutoff = max(500 / n_points, min_cutoff)
    neighbours = spatial.KDTree(points).query_ball_point(points, cutoff)
    # every point is its own neighbour, so isolated points still become nodes
    graph = nx.from_edgelist((i, j) for i, js in enumerate(neighbours) for j in js)
    merged = []
    for component in nx.connected_components(graph):
        members = points[list(component)].astype(np.float64)
        merged.append([np.mean(members[:, 0]), np.mean(members[:, 1])])
    return merged


def _draw_points(img_rgb, points, radius, alpha):
    """Return a BGR copy of `img_rgb` with semi-transparent red dots at `points`."""
    base = cv2.cvtColor(np.array(img_rgb), cv2.COLOR_RGB2BGR)
    marked = base.copy()
    for p in points:
        marked = cv2.circle(marked, (int(p[0]), int(p[1])), radius, (0, 0, 255), -1)
    return cv2.addWeighted(marked, alpha, base, 1 - alpha, 0)


#
# load trained model (skipped when no checkpoint path is configured;
# the default value of args.resume is the empty string)
if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
#
model.eval()
#
step = 0
epoch = 0
###
# create the pre-processing transform
transform = standard_transforms.Compose([
    standard_transforms.ToTensor(),
    standard_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# set your image path here
for img_type in ["b", "a"]: #
    print("current image folder is {}".format(img_type))
    img_list_path = "/content/drive/MyDrive/soypod_crop_counting/soypod_crop_counting_{}.txt".format(img_type)
    # each non-empty line is "image_path,<something>"; only the image path is used
    with open(img_list_path) as list_file:
        img_paths = [line.split(',')[0] for line in list_file.read().splitlines() if line.strip()]
    #
    # save the number of detected pods
    csv_name = (args.vis_dir).split("/")[-1]
    csv_path = os.path.join(args.vis_dir, '{}_seed_counting_{}.csv'.format(csv_name, img_type))

    # the with-block guarantees the CSV is closed even if an image fails
    with open(csv_path, 'w') as loss_csv:
        for img_path_ij in img_paths:
            print("image path = {}".format(img_path_ij))
            #
            img_name = img_path_ij.split("/")[-1]
            # build the output names from the stem so they differ from the
            # input name (and from each other) whatever the file extension is
            img_stem, img_ext = os.path.splitext(img_name)
            img_name_pred_before = img_stem + '_pred_bf' + img_ext
            img_name_pred_after = img_stem + '_pred_af' + img_ext
            # load the images
            img_0 = cv2.imread(img_path_ij)
            if img_0 is None:
                # cv2.imread signals an unreadable/missing file by returning None
                print("could not read image, skipping: {}".format(img_path_ij))
                continue
            # this is only for drawing points
            img_raw = cv2.cvtColor(img_0, cv2.COLOR_BGR2RGB)
            # pre-proccessing
            img = transform(img_raw)
            ##
            # round the size up to the next multiple of 128
            height, width = img_0.shape[:2]
            # get the new input image size suitable for VGG16 net
            new_width = (width // 128 + 1) * 128
            new_height = (height // 128 + 1) * 128
            ## zero-padded network input
            img_in = torch.zeros((3, new_height, new_width))
            img_in[:, :height, :width] = img
            ## white-padded canvas of the same size for the visualisations
            img_draw = np.full((new_height, new_width, 3), 255, dtype=np.uint8)
            img_draw[:height, :width, :] = img_raw
            img_to_draw_before = img_draw.copy()
            img_to_draw_after = img_draw.copy()
            ##
            if new_height > LARGE_IMAGE_HEIGHT:
                print("Large file")
                n_strips = 2
            else:
                n_strips = 1
            strip_height = new_height // n_strips
            # counts are accumulated per strip, starting from zero for each image
            predict_cnt_before = 0
            predict_cnt_after = 0
            for shi in range(n_strips):
                if n_strips > 1:
                    print("the {} half".format(shi + 1))
                rows = slice(strip_height * shi, strip_height * (shi + 1))
                # run inference; point coordinates are relative to the strip
                points = _predict_points(model, img_in[:, rows, :], device, threshold)
                # merge detections that are too close to each other
                points_n = _merge_close_points(points, min_cutoff=20, max_points=10000)
                predict_cnt_before += len(points)
                predict_cnt_after += len(points_n)
                # draw the predictions on the matching strip of the canvas
                img_to_draw_before[rows, :, :] = _draw_points(img_draw[rows, :, :], points, 6, 0.5)
                img_to_draw_after[rows, :, :] = _draw_points(img_draw[rows, :, :], points_n, 6, 0.5)
            # save the visualized images, cropped back to the original size
            cv2.imwrite(os.path.join(args.vis_dir, img_name_pred_before), img_to_draw_before[:height, :width, :])
            cv2.imwrite(os.path.join(args.vis_dir, img_name_pred_after), img_to_draw_after[:height, :width, :])
            #
            print("Number of seeds before = {}".format(predict_cnt_before))
            print("Number of seeds after = {}".format(predict_cnt_after))
            # save the detected pod number
            loss_csv.write('{},{},{},{}\n'.format(os.path.splitext(img_name_pred_before)[0], predict_cnt_before, os.path.splitext(img_name_pred_after)[0], predict_cnt_after))
            loss_csv.flush()


##### On big in-field images

In [None]:
# Inference on the large in-field photographs.
# Every *.JPG found one folder level below `folder_path` is down-scaled by 2,
# padded, and processed in four vertical strips to save memory. For each
# image a "before merging" and an "after merging" visualisation is saved and
# both counts are appended to one CSV file in args.vis_dir.

# threshold for evaluation
threshold = 0.5 #args.threshold
# number of vertical strips each padded image is cut into
N_STRIPS = 4


def _predict_points(net, image_chw, target_device, score_threshold):
    """Run `net` on one (3, H, W) tensor and return the confident points.

    The score of a candidate is the softmax probability of class index 1 of
    `pred_logits`; candidates scoring above `score_threshold` are kept.
    Returns a numpy array with one row per kept point; columns 0 and 1 are
    used as x and y pixel coordinates when the points are drawn.
    """
    # no gradients are needed for inference; this keeps memory usage down
    with torch.no_grad():
        outputs = net(image_chw.unsqueeze(0).to(target_device))
    scores = torch.nn.functional.softmax(outputs['pred_logits'], -1)[:, :, 1][0]
    candidates = outputs['pred_points'][0]
    return candidates[scores > score_threshold].detach().cpu().numpy()


def _merge_close_points(points, min_cutoff, max_points):
    """Replace every group of nearby points by the group's centroid.

    Points closer than a cut-off radius are linked; each connected group of
    linked points is collapsed to its mean position. The radius is
    500 / number_of_points, but never smaller than `min_cutoff`.
    The merge is skipped (points returned unchanged, as a list) when there
    are no points or `max_points` or more.
    """
    n_points = points.shape[0]
    if n_points == 0 or n_points >= max_points:
        return points.tolist()
    cutoff = max(500 / n_points, min_cutoff)
    neighbours = spatial.KDTree(points).query_ball_point(points, cutoff)
    # every point is its own neighbour, so isolated points still become nodes
    graph = nx.from_edgelist((i, j) for i, js in enumerate(neighbours) for j in js)
    merged = []
    for component in nx.connected_components(graph):
        members = points[list(component)].astype(np.float64)
        merged.append([np.mean(members[:, 0]), np.mean(members[:, 1])])
    return merged


def _draw_points(img_rgb, points, radius, alpha):
    """Return a BGR copy of `img_rgb` with semi-transparent red dots at `points`."""
    base = cv2.cvtColor(np.array(img_rgb), cv2.COLOR_RGB2BGR)
    marked = base.copy()
    for p in points:
        marked = cv2.circle(marked, (int(p[0]), int(p[1])), radius, (0, 0, 255), -1)
    return cv2.addWeighted(marked, alpha, base, 1 - alpha, 0)


#
# load trained model (skipped when no checkpoint path is configured;
# the default value of args.resume is the empty string)
if args.resume:
    checkpoint = torch.load(args.resume, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
#
model.eval()
#
step = 0
epoch = 0
###
# create the pre-processing transform
transform = standard_transforms.Compose([
    standard_transforms.ToTensor(),
    standard_transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# set your image path here
folder_path = "/content/drive/MyDrive/soybean_1118selection_annotated/"
folder_sub = os.listdir(folder_path)
#
# save the number of detected pods
csv_name = (args.vis_dir).split("/")[-1]
csv_path = os.path.join(args.vis_dir, '{}_seed_counting_{}.csv'.format(csv_name, 'final'))
#
img_ct = 0
# the with-block guarantees the CSV is closed even if an image fails
with open(csv_path, 'w') as loss_csv:
    for folder_sub_s in folder_sub:
        #
        folder_path_s = os.path.join(folder_path, folder_sub_s)
        #
        img_list = glob.glob(os.path.join(folder_path_s, '*.JPG'))#[:3]
        #
        for img_path_ij in img_list:
            print("image path = {}".format(img_path_ij))
            #
            img_name = img_path_ij.split("/")[-1]
            # The inputs are .JPG files, so the output names are built from
            # the stem; this keeps the "before" and "after" images (and the
            # input) from sharing one file name.
            img_stem, img_ext = os.path.splitext(img_name)
            img_name_pred_before = img_stem + '_pred_bf' + img_ext
            img_name_pred_after = img_stem + '_pred_af' + img_ext
            # load the images
            img_00 = cv2.imread(img_path_ij)
            if img_00 is None:
                # cv2.imread signals an unreadable/missing file by returning None
                print("could not read image, skipping: {}".format(img_path_ij))
                continue
            # work at half resolution
            img_0 = cv2.resize(img_00, (int(img_00.shape[1]/2), int(img_00.shape[0]/2)), interpolation = cv2.INTER_AREA)
            # this is only for drawing points
            img_raw = cv2.cvtColor(img_0, cv2.COLOR_BGR2RGB)
            # pre-proccessing
            img = transform(img_raw)
            ##
            # round the size up to the next multiple of 128
            height, width = img_0.shape[:2]
            # get the new input image size suitable for VGG16 net
            new_width = (width // 128 + 1) * 128
            new_height = (height // 128 + 1) * 128
            ## zero-padded network input
            img_in = torch.zeros((3, new_height, new_width))
            img_in[:, :height, :width] = img

            # divide the image into vertical strips to save memory
            print("new_height = {}".format(new_height))
            # white-padded canvas of the same size for the visualisations
            img_draw = np.full((new_height, new_width, 3), 255, dtype=np.uint8)
            img_draw[:height, :width, :] = img_raw
            #
            img_to_draw_before = img_draw.copy()
            img_to_draw_after = img_draw.copy()
            # new_width is a multiple of 128, so the strips tile it exactly
            strip_width = new_width // N_STRIPS
            # counts are accumulated per strip, starting from zero for each
            # image, so nothing carries over from the previous image
            predict_cnt_before = 0
            predict_cnt_after = 0
            for shi in range(1, N_STRIPS + 1):
                print("the {} half".format(shi))
                cols = slice(strip_width * (shi - 1), strip_width * shi)
                # run inference; point coordinates are relative to the strip
                points = _predict_points(model, img_in[:, :, cols], device, threshold)
                # merge detections that are too close to each other
                points_n = _merge_close_points(points, min_cutoff=10, max_points=100000)
                predict_cnt_before += len(points)
                predict_cnt_after += len(points_n)
                # draw the predictions on the matching strip of the canvas
                img_to_draw_before[:, cols, :] = _draw_points(img_draw[:, cols, :], points, 5, 0.6)
                img_to_draw_after[:, cols, :] = _draw_points(img_draw[:, cols, :], points_n, 5, 0.6)
            # save the visualized images, cropped back to the resized image size
            cv2.imwrite(os.path.join(args.vis_dir, img_name_pred_before), img_to_draw_before[:height, :width, :])
            cv2.imwrite(os.path.join(args.vis_dir, img_name_pred_after), img_to_draw_after[:height, :width, :])
            #
            print("Number of seeds before = {}".format(predict_cnt_before))
            print("Number of seeds after = {}".format(predict_cnt_after))
            # save the detected pod number
            loss_csv.write('{},{},{},{}\n'.format(os.path.splitext(img_name_pred_before)[0], predict_cnt_before, os.path.splitext(img_name_pred_after)[0], predict_cnt_after))
            loss_csv.flush()
            #
            print("saving {}th image {}".format(img_ct, img_path_ij))
            img_ct += 1
