In [None]:
#Imports and installs
import transformers
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
!pip install craft-text-detector
import transformers
from craft_text_detector import Craft
import requests 
import torch
import os, random
from PIL import Image
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets
from tqdm import tqdm
import pandas as pd
import numpy as np
import imghdr
import pickle
from pathlib import Path
import cv2
import torch.nn.functional as F
import multiprocessing
from functools import partial

In [None]:
# Set COLAB = False if running on SCC
COLAB = False

#suppressing all the huggingface warnings
SUPPRESS = True
if SUPPRESS:
    from transformers.utils import logging
    logging.set_verbosity(40)

In [322]:
if COLAB:
    from google.colab import drive
    drive.mount('/content/gdrive',force_remount = True)
    workdir = '/content/gdrive/MyDrive/testing-trocr'
    output_dir_craft = '/content/gdrive/MyDrive/craft/'
else:
    workdir = '/projectnb/sparkgrp/ml-herbarium-grp/ml-herbarium-data/scraped-data/drago_testdata' # update this to the desired directory on scc
    output_dir_craft = '/usr3/graduate/colejh/craft/'

In [315]:
# initialize the CRAFT model
craft = Craft(output_dir = output_dir_craft,export_extra = False, text_threshold = .8,link_threshold = .6, crop_type="poly",low_text = .5,cuda = True)

# Running craft and saving the segmented images

In [317]:
# CRAFT on images to get bounding boxes
images = []
corrupted_images = []
no_segmentations = []
boxes = {}
count= 0
file_types = (".jpg", ".jpeg",".png")
for filename in tqdm(os.listdir(workdir)):
    if filename.endswith(file_types):
        image = workdir+'/'+filename
        try:
            img = Image.open(image) 
            img.verify() # Check that the image is valid
            bounding_areas = craft.detect_text(image)
            if not bounding_areas: #check that a segmentation was found
                no_segmentations.append(image)
            else:
                images.append(image)
                boxes[image] = bounding_areas['boxes']
        except (IOError, SyntaxError) as e:
            corrupted_images.append(image)

            
#     count +=1
#     # Using count for time being to get things working, remove once setup is complete
#     if count == 5:
#         break
    

  polys = np.array(polys)
  polys_as_ratio = np.array(polys_as_ratio)
100%|██████████| 1018/1018 [16:54<00:00,  1.00it/s]


# Getting all the segemnted images into a dataloader, and loading model and processor for trocr

In [323]:
# Deleting empty folders
root = output_dir_craft
folders = list(os.walk(root))[1:]
deleted = []
for folder in folders:
    # folder example: ('FOLDER/3', [], ['file'])
    if not folder[2]:
        deleted.append(folder)
        os.rmdir(folder[0])
# Setting up the Tr-OCR model (using base model currently, large takes much longer)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") 
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

# Use all available gpu's
model_gpu= nn.DataParallel(model).to(device)

# Dataloader for working with gpu's
trainset = datasets.ImageFolder(output_dir_craft, transform = processor)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True)# for i, data in enumerate(trainloader):

# For matching words to image
filenames = [s.replace('_crops', '') for s in list(trainset.class_to_idx)]
word_log_dic = {k: v for k,v in enumerate(filenames)}
words_identified = {k: [] for v,k in enumerate(filenames)}


In [None]:
# Yoinked https://github.com/rsommerfeld/trocr/blob/main/src/scripts.py

def get_confidence_scores(generated_ids):
    # Get raw logits, with shape (examples,tokens,token_vals)
    logits = generated_ids.scores
    logits = torch.stack(list(logits),dim=1)

    # Transform logits to softmax and keep only the highest (chosen) p for each token
    logit_probs = F.softmax(logits, dim=2)
    char_probs = logit_probs.max(dim=2)[0]

    # Only tokens of val>2 should influence the confidence. Thus, set probabilities to 1 for tokens 0-2
    mask = generated_ids.sequences[:,:-1] > 2
    char_probs[mask] = 1

    # Confidence of each example is cumulative product of token probs
    batch_confidence_scores = char_probs.cumprod(dim=1)[:, -1]
    return [v.item() for v in batch_confidence_scores]



In [None]:
def evaluate_craft_seg(model,trainloader):
    results = []
    confidence = []
    label = []

    model_gpu.eval()
    with torch.no_grad():
        for idx,data in enumerate(tqdm(trainloader)):

            images, labels = data
            images, labels = images['pixel_values'][0].to(device), labels.to(device)

            decoded = model_gpu.module.generate(images,return_dict_in_generate = True, output_scores = True) 
            final_values = processor.batch_decode(decoded.sequences, skip_special_tokens=True)


            confidences = get_confidence_scores(decoded)

            for idx,value in enumerate(labels.cpu().numpy()):
#           if confidences[idx]>.8: # to cull some of the really terrible guesses, probably want to do this in the search function instead
                words_identified[word_log_dic[value]].append(final_values[idx])
            
            results.extend(final_values)
            confidence.extend(confidences)
            label.extend(labels.cpu().numpy())

    return results,confidence,label

In [324]:
#Storing the outputs
results,confidence,labels = evaluate_craft_seg(model,trainloader)
df = pd.DataFrame(list(zip(results,confidence,labels)),columns = ['Results','Confidence','Labels'])
df.to_pickle('results7.pkl')


100%|██████████| 1552/1552 [22:33<00:00,  1.15it/s]


In [None]:
def load_pickle(filepath):
    filepkl = open(filepath, "rb")

    # Unpickle the objects

    unpickled = pickle.load(filepkl)
    
    return unpickled



In [None]:
a = load_pickle('/usr3/graduate/colejh/'+'results5.pkl')
results =pd.Series(a.Results)
labels = list(a['Labels'])

In [None]:
# String matching installs
# !pip install --force-reinstall numpy==1.18.5 # need this to work the string grouper
# !pip install numpy
# !pip install string-grouper

from string_grouper import match_strings, match_most_similar

In [None]:
if COLAB:
    taxon_file = '/content/gdrive/MyDrive/corpus_taxon.txt'
else:
    taxon_file = workdir+'/taxon_corpus.txt'
    geography_file = workdir+'/geography_corpus.txt'
    collector_file = workdir+'/collector_corpus.txt'
taxon = pd.read_csv(taxon_file, delimiter = "\t", names=["Taxon"]).squeeze()
geography = pd.read_csv(geography_file, delimiter = "\t", names=["Geography"]).squeeze()
collector = pd.read_csv(collector_file, delimiter = "\t", names=["Collector"]).squeeze()

In [None]:

# String matching
minimum_similarity = .1
results_series = pd.Series(results)
taxon = pd.read_csv(taxon_file, delimiter = "\t", names=["Taxon"]).squeeze()
 
start = time.time()
matches = match_strings(taxon,results_series,n_blocks = 'guess',min_similarity = minimum_similarity,max_n_matches = 4)
end = time.time()
print('time',end-start)


In [None]:
def match(main,comparison_file,minimum_similarity):
     # Function takes a main file containing strings, a comparison file to match against main,
     #  and a minimum similarity confidence level. Returns a list of matches based on similarity.

    if not isinstance(comparison_file, pd.Series):

        comparison_file = pd.Series(comparison_file)

    matches = match_strings(main,comparison_file,n_blocks = 'guess',min_similarity = minimum_similarity,max_n_matches = 1)

    return matches


In [None]:
def highest_score_per_image(df,labels,filenames,minimum_similarity):
    # Getting the highest score for each individual image 
    index_to_labels = df.copy()
    for a in index_to_labels.right_index.unique():
        index_to_labels.loc[index_to_labels['right_index'] == a, 'right_index'] = labels[a]
    unique_labels = index_to_labels.loc[index_to_labels.groupby('right_index')['similarity'].idxmax()]

#     print("Of the",len(filenames),"images evaluated",len(unique_labels), "have a prediction score above",minimum_similarity*100, "percent.")
    return unique_labels
# unique_labels = highest_score_per_image(ascending_df,labels,filenames,minimum_similarity)

In [None]:
def pooled_match(comparison_file,labels,filenames, minimum_similarity = .7,**kwargs):
    # Take in any number of files containing strings to match against and return a dictionary
    # with keys the same name as input and values as the dataframe with matching information
   
    corpus_list = []
    corpus_name = []
    
    for k,v in kwargs.items():
        # Convert to series (string-grouper requires this type), will work if input is list, array, or series
        if not isinstance(v, pd.Series):
            v = pd.Series(v)
        corpus_list.append(v)
        corpus_name.append(k)
    
    func = partial(match, comparison_file = results_series,  minimum_similarity = minimum_similarity)
    pool = multiprocessing.Pool()

    result_dic = {}
    for i,result in enumerate(pool.map(func,corpus_list)):
        result.columns.values[1] = corpus_name[i]+' Corpus'
        result.columns.values[3] = "Predictions"
        result = result.drop('left_index', axis=1)
        result = highest_score_per_image(result,labels,filenames,minimum_similarity)
        result_dic[corpus_name[i]] = result

   
    return result_dic

In [325]:
print(len(labels),len(filenames))
results_series = pd.Series(results)

24823 976


In [326]:
start = time.time()
results = pooled_match(results_series,labels, filenames,minimum_similarity =.01,Taxon = taxon,Geography = g,Collector = collector)
end = time.time()
print(end-start)
# for k,v in results.items():
#     display(v)
    

28.370580911636353


In [None]:
for k,v in results.items():
    display(v)

In [None]:
# Reorganizing the df
ascending_df = matches.sort_values(by=['similarity'],ascending=False)
pd.set_option('display.width', 150)
display(ascending_df)

In [None]:
def matches_above_x(df,x):
    return df.loc[df['similarity']>=.75]


In [None]:
above_75 = matches_above_x(unique_labels,75)
display(above_75)

In [None]:
# Reading in the ground truth values
gt_t = workdir+'/taxon_gt.txt'
Taxon_truth = { line.split(":")[0] : line.split(": ")[1].strip() for line in open(gt_t) }

gt_g = workdir+'/geography_gt.txt'
Geography_truth = { line.split(":")[0] : line.split(": ")[1].strip() for line in open(gt_g) }

gt_c = workdir+'/collector_gt.txt'
Collector_truth = { line.split(":")[0] : line.split(": ")[1].strip() for line in open(gt_c) }

comparison_file = {"Taxon":Taxon_truth,"Geography":Geography_truth,"Collector":Collector_truth}

In [None]:
# need to add in the check for other species names, quite a few would match
prediction_and_imagenumber = list(zip(unique_labels.left_Taxon, unique_labels.right_index))
count = 0
for prediction,image_number in prediction_and_imagenumber:
    try:
        image = word_log_dic[image_number]
        gt = ground_truth[image]
        if gt == prediction:
            count +=1
#             print(gt,"||",prediction,'||',image)
        else:
            print('*',gt,"||",prediction,'||',image)
    except KeyError as e:
        print("Ground Truth Not Found for:",word_log_dic[image_number])
acc = count/len(unique_labels)
print("Accuracy on Predicted:",acc)
print("Total accuracy: ",count/len(filenames))


In [None]:
def prRed(skk): print("\033[91m{}\033[00m" .format(skk))
    
class color:
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    DARKCYAN = '\033[36m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    END = '\033[0m'

In [328]:
for k,v in results.items():    
    # need to add in the check for other species names, quite a few would match
    prediction_and_imagenumber = list(zip(v[k+' Corpus'], v.right_index))
    count = 0
    print('Evaluation for',k)
    for prediction,image_number in prediction_and_imagenumber:
        try:
            image = word_log_dic[image_number]
            gt = comparison_file[k][image]
            if gt == prediction:
                count +=1
                print(gt,"||",prediction,'||',image)
            else:
                print(color.RED+gt+"||"+prediction+'||'+image+color.END)
        except KeyError as e:
            print("Ground Truth Not Found for:",word_log_dic[image_number])
    
    acc = count/len(v)
    print(color.BOLD+"Accuracy on Predicted:"+str(acc)+color.END)
    print(color.BOLD+"Total accuracy: "+str(count/len(filenames))+color.END)
    print(color.BOLD+"Total Guessed:"+str(len(prediction_and_imagenumber))+color.END)
    print('\n\n********************************\n\n')


Evaluation for Taxon
[91mIsopterygium nivescens||Chaetogastra mollis||1019531437[0m
[91mSiler montanum||Lolium perenne||1019752371[0m
[91mEriophorum tenellum||Eriophorum gracile||1038924603[0m
[91mSolidago patula||Acaciella angustissima||1038926232[0m
[91mNuttallanthus canadensis||Clarkia xantiana||1038933447[0m
Suaeda linearis || Suaeda linearis || 1038967579
[91mEriogonum brachypodum||Eriogonum parvifolium||1038991156[0m
Chenopodium hybridum || Chenopodium hybridum || 1039025105
Bidens cornuta || Bidens cornuta || 1055366369
[91mMangifera duperreana||Poa interior||1056014429[0m
[91mCalophyllum vitiense||Mercurialis annua||1056069802[0m
[91mAngraecum rostratum||Tapirira obtusa||1056306307[0m
[91mCalymmodon clavifer||Adiantum pedatum||1057234135[0m
[91mFumaria capreolata||Pilea pumila||1057260090[0m
Psydrax schimperiana || Psydrax schimperiana || 1057464806
[91mCladophora columbiana||Racomitrium microcarpum||1057532849[0m
[91mCuspidaria simplicifolia||Astragalu

In [None]:
for k,v in words_identified.items():
    print(k,v)

In [None]:
# CRAFT on images to get bounding boxes
images = []
corrupted_images = []
boxes = {}
count= 0
file_types = (".jpg", ".jpeg",".png")
for filename in tqdm(os.listdir(workdir)):
    if filename.endswith(file_types):
        image = workdir+'/'+filename
        try:
            img = Image.open(image) # open the image file
            img.verify() # verify that it is, in fact an image
            images.append(image)
            bounding_areas = craft.detect_text(image)
            boxes[image] = bounding_areas['boxes']
        except (IOError, SyntaxError) as e:
            corrupted_images.append(image)

            
#     count +=1
#     # Using count for time being to get things working, remove once setup is complete
#     if count == 5:
#         break
    

In [None]:
# trainset2 = datasets.ImageFolder(workdir, transform = processor)
# trainloader2 = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True)# for i, data in enumerate(trainloader):

# Checking size of all the segmentation images (about 1/6 input)
root_directory = Path(workdir)
sum(f.stat().st_size for f in root_directory.glob('**/*') if f.is_file())

def save_bounding(filename,bounding_boxes):
# function to save the boudning box outputs from CRAFT segmentation

    f = open(filneame+".pkl","wb")

    # write the python object (dict) to pickle file
    pickle.dump(bounding_boxes,f)

    # close file
    f.close()
    
def display_segmentations(boxes_dict): # Check this later if we dont want to save all the image segementations
    # Function takes in the dictionary of image_path and bounding boxes and displays all of the segmentations
    
    count = 0
    for image_path,bounding_boxes in boxes_dict.items():
        original_image = cv2.imread(image_path)
        for box in v:
            segmentation = original_image[int(box[0][1]): int(box[2][1]), 
                  int(box[0][0]): int(box[2][0])]

        segmented_image = Image.fromarray(segmentation).convert("RGB")
        display(segmented_image)
        count+=1
        if count>5:
            break

            
# Downloading the tr-ocr model and processor
# processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") 
# model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

# Training set creation
# trainset = datasets.ImageFolder(output_dir_craft, transform = processor)
# trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Put model on the gpu 
# model.to(device)

# img_dir = '/usr3/graduate/colejh/craft3/'
# for filename in os.listdir(img_dir):
#     for filename2 in os.listdir(img_dir+'/'+filename):
#         image = Image.open(img_dir+'/'+filename+'/'+filename2)
#         pixel_values = processor(image, return_tensors="pt").pixel_values 
#         generated_ids = model.generate(pixel_values)
#         generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
#         display(image)
#         print("Text: "+generated_text)

In [None]:
# Checking size of all the segmentation images (about 1/6 input)
root_directory = Path(workdir)
sum(f.stat().st_size for f in root_directory.glob('**/*') if f.is_file())


# saving the bounding boxes dictionary 
f = open("boundingboxes.pkl","wb")

# write the python object (dict) to pickle file
pickle.dump(dict,f)

# close file
f.close()

#Check this later if we dont want to save all the image segementations
count = 0
for k,v in boxes.items():
    full_image = cv2.imread(k)
    for box in v:
        segmentation = full_image[int(box[0][1]): int(box[2][1]), 
              int(box[0][0]): int(box[2][0])]
    
    #Drop this later, just for checking stuff
    image = Image.fromarray(segmentation).convert("RGB")
    display(image)
   
    count+=1
    if count>5:
        break