In [None]:
# Mount Google Drive at /content/gdrive; later cells load checkpoint files from under this path.
from google.colab import drive
drive.mount('/content/gdrive')

## The following are the checkpoints for the models

Faster R-CNN: https://drive.google.com/file/d/1WAbR50Eows_lDsevBR8aUV9-RTRKFZ2O/view?usp=sharing

Prototypical Network with ResNet Backbone: https://drive.google.com/file/d/1--ZwmUJAvYlCForsYBtLG-Nbi-kelldr/view?usp=sharing

Matching Network with ResNet Backbone: https://drive.google.com/file/d/1-6Lxza8lPvnURCJrdG-8MGEEtSeptJeX/view?usp=sharing


In [None]:
!pip install easyfsl
!pip install gradio
!pip install git+https://github.com/facebookresearch/detectron2.git

In [None]:
import os 
import pandas as pd 
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T
from skimage import io
import torch.nn.functional as F
import torch.nn as nn
from torchvision.models import resnet18
from torch.utils.data import DataLoader
from torch.utils.data import Sampler
from torch.nn import Module
from easyfsl.utils import plot_images, sliding_average
import torch.optim 
import random
import numpy as np
import time
from tqdm import tqdm
from time import sleep

import gradio as gr
import glob
import sys
from math import tan, pi
import torchvision
import cv2



import json

import matplotlib.pyplot as plt

from detectron2.structures import BoxMode
from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_test_loader
from detectron2.evaluation import COCOEvaluator, inference_on_dataset


from detectron2 import model_zoo
from detectron2.engine import DefaultTrainer, DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import ColorMode, Visualizer

import shutil
from PIL import Image
from tkinter import Tk
from tkinter.filedialog import askdirectory
import csv
from openpyxl import Workbook

#TODO HOOKS
from detectron2.engine.hooks import HookBase
from detectron2.evaluation import inference_context
from detectron2.utils.logger import log_every_n_seconds
from detectron2.data import DatasetMapper, build_detection_test_loader
import detectron2.utils.comm as comm


## Backbone
This is the ResNet-18 backbone used by the application.

In [None]:
class resnet(nn.Module):
    """ResNet-18 whose final fully-connected layer is replaced by a flatten layer.

    Accepts input whose dimension 1 has size 3, or size 1 (which is then
    repeated three times along that dimension before being passed on).
    """

    def __init__(self):
        super().__init__()
        # pretrained=False: no pretrained weights are requested here.
        self.resnet_18 = resnet18(pretrained=False)
        # Replace the classification head so the network returns flattened features.
        self.resnet_18.fc = nn.Flatten()

    def forward(self, x):
        """Run ``x`` through the network.

        Raises:
            ValueError: if ``x.shape[1]`` is neither 1 nor 3.
        """
        channels = x.shape[1]
        if channels not in (1, 3):
            raise ValueError('shape[1] is not 1 or 3 or it is not even channel dimension')
        if channels == 1:
            # Repeat the single channel three times along dim 1.
            x = torch.cat([x] * 3, dim=1)
        return self.resnet_18(x)

Prototypical Networks with ResNet-18 Backbone

In [None]:
class Proto(nn.Module):
    """Prototypical network: scores queries by distance to per-class mean embeddings.

    Args:
        hidden_channels: stored on the instance; not otherwise used in this class.
        input_channels: stored on the instance; not otherwise used in this class.
    """

    def __init__(self, hidden_channels, input_channels):
        super().__init__()
        self.input_channels = input_channels
        self.hidden_channels = hidden_channels
        self.backbone = resnet()

    def forward(self, x, n_shot, k_way, q):
        """Return negative L2 distances between query embeddings and class prototypes.

        Args:
            x: samples ordered as ``k_way * n_shot`` support samples (grouped by
                class, ``n_shot`` consecutive samples per class) followed by the
                query samples. A 5-D input with a leading dimension of size 1
                (a batched episode) has that dimension squeezed away.
            n_shot: number of support samples per class.
            k_way: number of classes.
            q: accepted for interface compatibility; not used.

        Returns:
            Tensor of shape ``(num_queries, k_way)`` holding ``-distance``; larger
            means closer to that class prototype.

        Raises:
            ValueError: if ``n_shot``/``k_way`` are not positive or ``x`` holds
                fewer than ``k_way * n_shot`` samples.
        """
        if n_shot < 1 or k_way < 1:
            raise ValueError('n_shot and k_way must both be at least 1')

        # Only drop the leading dimension for batched episodes; a 4-D input is
        # already a plain batch of images and must keep its first dimension.
        if x.dim() == 5:
            x = x.squeeze(0)

        # Follow the module's own device instead of relying on a notebook-level
        # global, so the model works wherever it has been moved.
        reference_param = next(self.parameters(), None)
        if reference_param is not None:
            x = x.to(reference_param.device)

        embedded_x = self.backbone(x)

        n_support = k_way * n_shot
        if embedded_x.shape[0] < n_support:
            raise ValueError(
                'expected at least %d support samples (k_way * n_shot), got %d samples'
                % (n_support, embedded_x.shape[0])
            )

        support_set = embedded_x[:n_support]  # (k_way * n_shot, embedding_dim)
        q_set = embedded_x[n_support:]        # (num_queries, embedding_dim)

        # Class prototype = mean of that class's n_shot support embeddings.
        mean_support = support_set.reshape(k_way, n_shot, *support_set.shape[1:]).mean(dim=1)

        l2_distance = torch.cdist(q_set, mean_support)
        return -l2_distance

## Device
Set the device you want to run on (for example `cuda` or `cpu`).

In [None]:
# Prefer the GPU when one is available; otherwise fall back to the CPU so that
# later `.to(device)` calls do not fail on a runtime without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Weight Loading
Load the weights for the prototypical network

In [None]:
# Build the prototypical network and restore its weights from the mounted Drive.
model = Proto(32, 1).to(device)
model.load_state_dict(torch.load('/content/gdrive/My Drive/Proto_ResNet.pth', map_location=device))
# The model is only used for inference in this notebook: switch to eval mode so
# BatchNorm layers use their stored running statistics instead of batch statistics.
model.eval();

## Detection

In [None]:
# TODO NOT FREEZE
def run():
    """Call torch.multiprocessing's ``freeze_support`` and print a marker line."""
    mp = torch.multiprocessing
    mp.freeze_support()
    print('loop')

# TODO GETTING DATA
def get_data_dicts(directory, classes, height=800, width=800):
    """Build a list of dataset records from the ``.json`` annotation files in ``directory``.

    Each JSON file is expected to provide an ``imagePath`` and a list of
    ``shapes``, where every shape has ``points`` (``[x, y]`` pairs) and a ``label``.

    Args:
        directory: folder containing the ``.json`` files; ``imagePath`` is joined
            onto this folder to form ``file_name``.
        classes: list of label names; ``category_id`` is the label's index in it.
        height: image height written into every record (defaults to 800).
        width: image width written into every record (defaults to 800).

    Returns:
        A list with one dict per JSON file, holding ``file_name``, ``image_id``,
        ``height``, ``width`` and ``annotations`` (one entry per shape with
        ``bbox`` in XYXY absolute form, ``segmentation``, ``category_id``, ``iscrowd``).

    Raises:
        ValueError: if a shape's label is not in ``classes``.
    """
    dataset_dicts = []
    # os.listdir order is arbitrary; sort so record order and image ids are reproducible.
    json_names = sorted(name for name in os.listdir(directory) if name.endswith('.json'))

    for image_id, json_name in enumerate(json_names):
        with open(os.path.join(directory, json_name)) as f:
            img_anns = json.load(f)

        objs = []
        for anno in img_anns["shapes"]:
            xs = [point[0] for point in anno['points']]
            ys = [point[1] for point in anno['points']]
            objs.append({
                # Tight axis-aligned box around the polygon; built-in min/max keep
                # the values as plain Python numbers.
                "bbox": [min(xs), min(ys), max(xs), max(ys)],
                "bbox_mode": BoxMode.XYXY_ABS,
                # Flattened polygon: [x0, y0, x1, y1, ...].
                "segmentation": [[coord for point in zip(xs, ys) for coord in point]],
                "category_id": classes.index(anno['label']),
                "iscrowd": 0
            })

        dataset_dicts.append({
            "file_name": os.path.join(directory, img_anns["imagePath"]),
            "image_id": image_id,
            "height": height,
            "width": width,
            "annotations": objs,
        })
    return dataset_dicts

if __name__ == '__main__':
    run()


# Class names; a label's index in this list is its category id.
classes = ['signature', 'date']

# NOTE(review): machine-specific Windows path. It is only read when a registered
# dataset is actually loaded (the lambda below is lazy), not at registration time.
data_path = 'C:/Users/kamra/Documents/My Documents/2022 Fall/Deep Learning/Project/ds_reformatted1/'

for d in ["train", "test"]:
    dataset_name = "category_" + d
    # DatasetCatalog.register() refuses to register the same name twice, so drop
    # any earlier registration first; this keeps the cell re-runnable in one kernel.
    if dataset_name in DatasetCatalog.list():
        DatasetCatalog.remove(dataset_name)
    DatasetCatalog.register(
        dataset_name,
        lambda d=d: get_data_dicts(data_path + d, classes)
    )
    MetadataCatalog.get(dataset_name).set(thing_classes=classes)

microcontroller_metadata = MetadataCatalog.get("category_train")


# Detector configuration.
cfg = get_cfg()
cfg.MODEL.DEVICE = "cpu"

# NOTE(review): no model config file is merged and ROI_HEADS.NUM_CLASSES is left at
# its default, so the architecture is detectron2's default one — confirm that this
# matches the configuration the checkpoint below was trained with.
# The checkpoint path is absolute, so it is used as-is (joining it onto
# cfg.OUTPUT_DIR would discard OUTPUT_DIR anyway).
cfg.MODEL.WEIGHTS = "/content/gdrive/My Drive/model_final.pth"
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.6

predictor = DefaultPredictor(cfg)

# TODO TEST ONLY

import json
import matplotlib.pyplot as plt

def out(imageName):
    """Run the detector on a PIL image and crop the first detected signature.

    Args:
        imageName: a PIL image (despite the parameter name, not a file name).

    Returns:
        A tuple ``(vis, crop_img)``:
        ``vis`` is the Visualizer output with all predictions drawn on the image;
        ``crop_img`` is the region of the input image inside the first predicted
        box whose class is ``"signature"``, or ``None`` when no signature was
        detected (previously this case raised ``IndexError``).
    """
    im2 = imageName
    # PIL delivers RGB. DefaultPredictor expects a BGR array, while Visualizer
    # expects RGB, so keep both orderings explicitly.
    rgb = np.array(imageName.convert("RGB"))
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])

    outputs = predictor(bgr)
    instances = outputs["instances"].to("cpu")

    # Draw every prediction on the document.
    vis = Visualizer(rgb, metadata=microcontroller_metadata)
    vis = vis.draw_instance_predictions(instances)

    # Index of the first prediction labelled "signature"
    # (presumably the highest-scoring one — confirm detection ordering).
    class_names = MetadataCatalog.get("category_train").thing_classes
    signature_idx = next(
        (
            idx
            for idx, class_id in enumerate(instances.pred_classes.tolist())
            if class_names[class_id] == "signature"
        ),
        None,
    )
    if signature_idx is None:
        return vis, None

    x_top_left, y_top_left, x_bottom_right, y_bottom_right = (
        instances.pred_boxes.tensor[signature_idx].tolist()
    )
    crop_img = im2.crop(
        (int(x_top_left), int(y_top_left), int(x_bottom_right), int(y_bottom_right))
    )
    return vis, crop_img

## Detection Application
This application performs detection only: upload a document in the left window, and the outputs are the document with the detected signature drawn on it and the cropped signature region for further processing.

In [None]:
def detection(image):
    """Gradio callback: run ``out`` on ``image`` and return (annotated PIL image, cropped image)."""
    visualized, cropped_image = out(image)
    annotated_pixels = np.uint8(visualized.get_image())
    annotated = Image.fromarray(annotated_pixels).convert('RGB')
    return annotated, cropped_image


# Gradio UI for detection only: one PIL image in, two images out
# (the two values returned by `detection`: annotated document and cropped region).
demo_detection = gr.Interface(
    fn=detection,
    inputs=[gr.Image(type="pil")],
    outputs =[gr.Image(), gr.Image()]
)
# NOTE(review): with debug=True this call presumably keeps the cell running until
# it is interrupted — confirm before relying on cells below running automatically.
demo_detection.launch(debug= True)

## Application for both Registering and Classifying
Please be careful not to register the same signature twice, since the model is based on one-shot learning.

First, register images that contain only a signature. After registering at least two signatures, you can test the classification in the "Detection & Classification" tab by uploading a document.

In [None]:
signatures = []  # registered signature tensors
names = []  # owner name for each registered signature, in the same order as `signatures`


def label_to_name(names):
    """Map every entry of ``names`` to its position (the last position wins for repeats)."""
    return dict((name, position) for position, name in enumerate(names))

# Preprocessing shared by registration and classification:
# convert to tensor -> grayscale -> resize to 200x200.
transform = T.Compose([T.ToTensor(), T.Grayscale(), T.Resize((200, 200))])


def registering_new_signature(name, image):
    """Gradio callback: preprocess a signature image and store it under ``name``.

    Args:
        name: owner name to associate with the signature.
        image: the uploaded signature image, as delivered by the Gradio image input.

    Returns:
        The shape of the stored tensor as a comma-separated string.

    Raises:
        gr.Error: if no image was uploaded (shown to the user in the UI instead
            of an opaque preprocessing failure).
    """
    if image is None:
        raise gr.Error("Please upload a signature image before registering.")
    image = transform(image)
    signatures.append(image)
    names.append(name)
    return ",".join(str(size) for size in image.shape[:3])

def classify_signature(image):
    """Gradio callback: detect the signature in a document and match it to a registered owner.

    Args:
        image: the uploaded document as a PIL image.

    Returns:
        A tuple ``(confidences, annotated)``: ``confidences`` maps each registered
        name to its softmax probability (probabilities of signatures registered
        under the same name are summed rather than overwriting each other);
        ``annotated`` is the document with detections drawn on it.

    Raises:
        gr.Error: if no document was uploaded, no signature has been registered
            yet, or no signature could be detected in the document.
    """
    if image is None:
        raise gr.Error("Please upload a document first.")
    if not signatures:
        raise gr.Error("Register at least one signature (preferably two or more) before classifying.")

    v, cropped_image = out(image)
    if cropped_image is None:
        raise gr.Error("No signature was detected in the uploaded document.")
    annotated = Image.fromarray(np.uint8(v.get_image())).convert('RGB')

    # One episode: every registered signature is its own class (1 shot each),
    # followed by the single query crop.
    query = transform(cropped_image)
    x = torch.cat((torch.stack(signatures), query.unsqueeze(0)))

    # Inference only: eval mode for BatchNorm, and no autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        scores = model(x, 1, len(signatures), 1)
    probs = torch.softmax(scores, dim=1).squeeze(0)

    confidences = {}
    for owner, prob in zip(names, probs.tolist()):
        key = str(owner)
        confidences[key] = confidences.get(key, 0.0) + prob
    return confidences, annotated
    
# "Registering" tab: a name (text) plus an image in, a text result out.
# gr.Image() is created without type=..., so Gradio's default image format is used here.
demo_register = gr.Interface(
    fn=registering_new_signature,
    inputs=['text', gr.Image()],
    outputs = "text"
)

# "Detection & Classification" tab: a PIL document image in; a label widget
# (name -> probability) and the annotated document out.
demo_classifier = gr.Interface(
    fn = classify_signature,
    inputs = gr.Image(type = 'pil'),
    outputs=[gr.Label(), gr.Image()]
)

# Combine both interfaces into one app with two tabs.
demo = gr.TabbedInterface([demo_register, demo_classifier], ["Registering", "Detection & Classification"])
# Prints the names registered so far; this runs once when the cell executes, before launch.
print(names)
demo.launch(debug= True)