# Final Pipeline

### Detection + Classification

In [None]:
import torch
import torchvision
from torchvision.models.detection import retinanet_resnet50_fpn
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models import efficientnet_b0

import torch.nn.functional as F
from torchvision.ops import nms
from torchvision import transforms
from PIL import Image

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Initialising the previously trained detection model

In [None]:
# Rebuild the RetinaNet architecture used at training time (2-class head on
# custom anchors), then load the fine-tuned weights into it. The COCO weights
# are only a starting point: the strict load_state_dict below overwrites them.
detection_model = retinanet_resnet50_fpn(weights="COCO_V1")  # equivalent to the deprecated pretrained=True
num_classes = 2  # background=0 , cell=1
in_features = detection_model.backbone.out_channels

# Anchor boxes: one size per feature-map level, three aspect ratios each.
anchor_generator = AnchorGenerator(
    sizes=((32,), (64,), (128,), (256,), (512,),),
    aspect_ratios=((0.5, 1.0, 2.0),) * 5,
)

# The head's per-location anchor count must match the generator (1 size x 3 ratios).
num_anchors = anchor_generator.num_anchors_per_location()[0]
detection_model.head = torchvision.models.detection.retinanet.RetinaNetHead(
    in_channels=in_features, num_classes=num_classes, num_anchors=num_anchors
)

detection_model.anchor_generator = anchor_generator
detection_model.to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
detection_model.load_state_dict(torch.load("replace w/ path to weights", map_location=device))
detection_model.eval()

# Initialising the previously trained classification model

In [None]:
# Rebuild the EfficientNet-B0 classifier with a 12-way output layer and load the
# fine-tuned weights. The ImageNet weights are only a starting point: the strict
# load_state_dict below overwrites every parameter.
classify_model = efficientnet_b0(weights='IMAGENET1K_V1')

# Inference only: no parameter needs gradients.
classify_model.requires_grad_(False)

classes = 12
# Swap the final Linear layer for a 12-class one, keeping its input width.
classify_model.classifier[-1] = torch.nn.Linear(
    in_features=classify_model.classifier[-1].in_features, out_features=classes
)

classify_model.to(device)
classify_model.load_state_dict(torch.load('replace w/ path to weights', map_location=device))
# Switch to inference mode, the same way the detection model is prepared.
classify_model.eval()

# Run images through the detection network and collect the detections

In [None]:
def get_crops(image, crop_size, overlap):
    """Tile a PIL image into square, overlapping crops converted to tensors.

    Crops are taken row by row (top to bottom, left to right) on a regular grid
    whose step is ``crop_size`` reduced by ``overlap`` percent.

    Args:
        image: PIL image; ``image.size`` is read as (width, height).
        crop_size: side length of each square crop, in pixels.
        overlap: overlap between neighbouring crops, as a percentage of
            ``crop_size``.

    Returns:
        List of tensors produced by ``transforms.ToTensor()``, one per crop.
        Empty when the image is smaller than ``crop_size`` in either dimension.

    Raises:
        ValueError: if the grid step would be smaller than one pixel
            (e.g. ``overlap >= 100``).

    NOTE: only crops that fit entirely inside the image are produced. When
    (image size - crop_size) is not a multiple of the step, the leftover strip
    at the right/bottom edge is not covered by any crop.
    """
    img_w, img_h = image.size

    step_size = int(crop_size - crop_size * overlap / 100)
    if step_size < 1:
        raise ValueError(
            f"overlap={overlap}% gives a step of {step_size}px for crop_size={crop_size}; "
            "the step must be at least 1px"
        )

    to_tensor = transforms.ToTensor()

    # The range bounds stop at (size - crop_size), so every crop lies fully inside the image.
    return [
        to_tensor(image.crop((x, y, x + crop_size, y + crop_size)))
        for y in range(0, img_h - crop_size + 1, step_size)
        for x in range(0, img_w - crop_size + 1, step_size)
    ]


def process_detections(detections, detection_threshold):
    """Move detections to the CPU and suppress overlapping boxes with NMS.

    Args:
        detections: iterable of dicts holding 'boxes', 'scores' and 'labels'
            tensors.
        detection_threshold: passed to ``torchvision.ops.nms`` as its IoU
            threshold. A box is discarded when its IoU with a higher-scoring
            box exceeds this value. It is NOT a score cut-off.

    Returns:
        List with one dict per input detection, holding the kept 'boxes',
        'scores' and 'labels' as CPU tensors. Empty list for empty input.
    """
    results = []

    for detection in detections:
        boxes = detection['boxes'].cpu()
        scores = detection['scores'].cpu()
        labels = detection['labels'].cpu()

        # nms is class-agnostic: overlapping boxes are suppressed regardless of label.
        keep = nms(boxes, scores, detection_threshold)

        results.append({
            'boxes': boxes[keep],
            'scores': scores[keep],
            'labels': labels[keep],
        })

    # Return only after every detection dict has been processed.
    return results


def _crop_offsets(img_w, img_h, crop_size, overlap):
    """Return the (x, y) top-left corner of each tile, on the same grid and in the same order as get_crops."""
    step_size = int(crop_size - crop_size * overlap / 100)
    return [
        (x, y)
        for y in range(0, img_h - crop_size + 1, step_size)
        for x in range(0, img_w - crop_size + 1, step_size)
    ]


def process_image(image, crop_size, overlap, model, detection_threshold):
    """Detect objects in every tile of ``image`` and merge them in image coordinates.

    The image is tiled with get_crops. Each tile is run through ``model``, and
    its boxes are shifted from tile-local to full-image pixel coordinates. All
    tiles are then merged and passed through process_detections, whose NMS uses
    ``detection_threshold`` as IoU threshold. This removes duplicates where
    overlapping tiles saw the same object.

    Args:
        image: PIL image to tile.
        crop_size: tile side length in pixels.
        overlap: tile overlap as a percentage of ``crop_size``.
        model: detection model. Expected to be in eval mode and to return, for a
            list of image tensors, one dict with 'boxes', 'scores' and 'labels'
            per image.
        detection_threshold: IoU threshold forwarded to process_detections.

    Returns:
        A one-element list holding a dict with the merged 'boxes' (x0, y0, x1, y1
        in full-image pixels), 'scores' and 'labels'. The tensors are empty when
        no tile fits in the image or nothing is detected.
    """
    crops = get_crops(image, crop_size, overlap)
    offsets = _crop_offsets(*image.size, crop_size, overlap)
    if len(offsets) != len(crops):
        raise RuntimeError("tile offsets do not match the crops returned by get_crops")

    all_boxes, all_scores, all_labels = [], [], []
    with torch.no_grad():
        for crop, (x, y) in zip(crops, offsets):
            out = model([crop.float().to(device)])[0]
            boxes = out['boxes'].cpu()
            # Translate tile-local (x0, y0, x1, y1) into full-image coordinates.
            shift = torch.tensor([x, y, x, y], dtype=boxes.dtype)
            all_boxes.append(boxes + shift)
            all_scores.append(out['scores'].cpu())
            all_labels.append(out['labels'].cpu())

    if all_boxes:
        merged = {
            'boxes': torch.cat(all_boxes),
            'scores': torch.cat(all_scores),
            'labels': torch.cat(all_labels),
        }
    else:
        merged = {
            'boxes': torch.zeros((0, 4)),
            'scores': torch.zeros((0,)),
            'labels': torch.zeros((0,), dtype=torch.int64),
        }

    return process_detections([merged], detection_threshold)

# Runs detections through the classification network

In [None]:
def process_cells(image, cords, clas_model, crop_size=50):
    """Classify each detected cell and compute the mean predicted class.

    Args:
        image: PIL image the boxes refer to.
        cords: sequence of boxes, each an (x0, y0, x1, y1) sequence in pixel
            coordinates of ``image``.
        clas_model: classification network. It is switched to eval mode here.
        crop_size: side length each cell crop is resized to before
            classification.

    Returns:
        (results, agNOR) where:
        - ``results`` maps each predicted class index to the list of boxes
          assigned to it.
        - ``agNOR`` is the mean predicted class index over all cells, or
          ``nan`` when ``cords`` is empty.

    NOTE(review): the score averages class indices, presumably because the
    class index encodes a per-cell AgNOR count. Confirm against the training
    labels. Crops are not mean/std-normalised; confirm this matches the
    classifier's training preprocessing.
    """
    clas_model.eval()
    tr = transforms.Compose([
        transforms.Resize((crop_size, crop_size)),
        transforms.ToTensor()
    ])

    predicted_classes = []
    with torch.no_grad():
        for cord in cords:
            cropped = tr(image.crop(tuple(cord))).unsqueeze(0).to(device)
            probability = F.softmax(clas_model(cropped), dim=1)
            predicted_classes.append(torch.argmax(probability).item())

    results = {}  # {label: [cells]}
    for cord, label in zip(cords, predicted_classes):
        results.setdefault(label, []).append(cord)

    if not predicted_classes:
        # No cells: the mean is undefined, so return nan rather than dividing by zero.
        return results, float("nan")

    agNOR = sum(predicted_classes) / len(predicted_classes)
    return results, agNOR

In [None]:
def compute_AgNOR_score(image, detect_model, clas_model, crop_size, overlap, detection_threshold):
    """Detect cells in ``image``, classify each one, and return the AgNOR result.

    ``crop_size``, ``overlap`` and ``detection_threshold`` are the tiling/NMS
    parameters forwarded to process_image. The boxes of the first entry it
    returns are then classified by process_cells, which uses its own default
    resize size (the tile ``crop_size`` is not forwarded to it).

    Returns:
        The (label -> boxes dict, score) pair produced by process_cells.
    """
    first_entry = process_image(
        image, crop_size, overlap, detect_model, detection_threshold
    )[0]
    boxes = first_entry['boxes'].tolist()

    labels_by_class, score = process_cells(image, boxes, clas_model)
    return labels_by_class, score