# Imports

In [8]:
#Imports and installs
import transformers
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# !pip install craft-text-detector
import transformers
from craft_text_detector import Craft
import requests 
import torch
import os, random
from PIL import Image,ImageFilter
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets
from torch.utils.data import Dataset
from tqdm import tqdm
import pandas as pd

import numpy as np
import imghdr
import pickle
from pathlib import Path
import cv2
import torch.nn.functional as F
import multiprocessing
from functools import partial
import json

from string_grouper import match_strings, match_most_similar
# !pip install pycountry
import pycountry

# Directories

In [9]:
# Runtime switches: COLAB picks the Google Drive paths, otherwise the SCC paths are used.
COLAB = False

# Suppress Hugging Face log output below ERROR level while SUPPRESS is on.
SUPPRESS = True
if SUPPRESS:
    from transformers.utils import logging
    logging.set_verbosity(logging.ERROR)  # ERROR == 40

if not COLAB:
    # update workdir to the desired directory on scc
    workdir = '/projectnb/sparkgrp/ml-herbarium-grp/ml-herbarium-data/scraped-data/drago_testdata'
    output_dir_craft = '/projectnb/sparkgrp/colejh/craft'
else:
    from google.colab import drive
    drive.mount('/content/gdrive', force_remount=True)
    workdir = '/content/gdrive/MyDrive/testing-trocr'
    output_dir_craft = '/content/gdrive/MyDrive/craft/'

# Running craft and saving the segmented images

In [None]:
# initialize the CRAFT model; only use CUDA when a GPU is actually available
craft = Craft(
    output_dir=output_dir_craft,
    export_extra=False,
    text_threshold=.8,
    link_threshold=.6,
    crop_type="poly",
    low_text=.5,
    cuda=torch.cuda.is_available(),
)

# CRAFT on images to get bounding boxes. Outcome per image path:
#   images            - paths with at least one detected text box
#   boxes             - path -> detected boxes
#   no_segmentations  - readable images for which no box was detected
#   corrupted_images  - files that failed validation or could not be read
images = []
corrupted_images = []
no_segmentations = []
boxes = {}
# The next three names are not used in this cell; kept so existing references keep working
count = 0
img_name = []
box = []
file_types = (".jpg", ".jpeg", ".png")
for filename in tqdm(os.listdir(workdir)):
    if not filename.endswith(file_types):
        continue
    image = workdir + '/' + filename
    try:
        # verify() only checks that the image is valid; the context manager
        # closes the file handle that Image.open leaves open
        with Image.open(image) as img:
            img.verify()
        bounding_areas = craft.detect_text(image)
    except (IOError, SyntaxError):
        corrupted_images.append(image)
        continue

    # detect_text returns a dict, which is truthy even when nothing was found,
    # so the detected boxes themselves have to be checked
    detected = bounding_areas['boxes'] if bounding_areas else []
    if len(detected) == 0:
        no_segmentations.append(image)
    else:
        images.append(image)
        boxes[image] = detected

    

# Getting all the segmented images into a dataloader, and loading the model and processor for TrOCR

In [None]:
# Deleting empty folders below the CRAFT output directory
root = output_dir_craft
folders = list(os.walk(root))[1:]  # [1:] skips the root directory itself
deleted = []
for folder in folders:
    # folder example: ('FOLDER/3', [], ['file'])
    folder_path, sub_folders, folder_files = folder
    # os.rmdir only works on a directory without any entries, so a folder that
    # still holds sub-folders must be kept even if it contains no files
    if not sub_folders and not folder_files:
        deleted.append(folder)
        os.rmdir(folder_path)
        
# Load the Tr-OCR processor and model (base checkpoint for now, large takes much longer)
trocr_checkpoint = "microsoft/trocr-base-handwritten"
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
processor = TrOCRProcessor.from_pretrained(trocr_checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(trocr_checkpoint)
model = model.to(device)

# Wrap the model so that every visible gpu can be used
gpu_ids = list(range(torch.cuda.device_count()))
model_gpu = nn.DataParallel(model, gpu_ids).to(device)

# Dataset over the CRAFT crops; the processor is applied to each image as the transform
trainset = datasets.ImageFolder(output_dir_craft, transform=processor)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True)

# Lookup tables for matching recognised words back to their image:
#   word_log_dic     - class index -> image name
#   words_identified - image name  -> list of recognised words (filled during evaluation)
filenames = [class_name.replace('_crops', '') for class_name in trainset.class_to_idx]
word_log_dic = dict(enumerate(filenames))
words_identified = {image_name: [] for image_name in filenames}


In [None]:
save_dir = '/projectnb/sparkgrp/colejh/saved_results/'
# Make sure the output directory exists before anything is written into it
os.makedirs(save_dir, exist_ok=True)

# save filenames, one per line; the path is built from save_dir so that
# changing save_dir relocates this file as well
with open(save_dir + 'filenames.txt', 'w') as fp:
    fp.writelines("%s\n" % item for item in filenames)

In [None]:
# Write both lookup tables as JSON files into save_dir.
# Note: json.dump stores the integer keys of word_log_dic as strings.
for json_name, table in (
    ('word_log_dic.json', word_log_dic),
    ('words_identified.json', words_identified),
):
    with open(save_dir + json_name, 'w') as fp:
        json.dump(table, fp)

# Running the model (inference and confidence scores)

In [None]:
# Yoinked https://github.com/rsommerfeld/trocr/blob/main/src/scripts.py
# For later use if we want to use the confidence scores from the model 
def get_confidence_scores(generated_ids):
    """Return one confidence value per generated sequence.

    The confidence of a sequence is the product of the softmax probabilities
    of the most likely token at every generation step. Steps whose generated
    token is a special token (id 0-2) do not influence the result.

    Args:
        generated_ids: Generation output that provides ``scores`` (one logits
            tensor of shape (examples, token_vals) per generation step) and
            ``sequences`` (the generated token ids, one column longer than
            ``scores``).

    Returns:
        list[float]: Confidence of each example in the batch.
    """
    # Stack the raw logits to shape (examples, tokens, token_vals)
    logits = torch.stack(list(generated_ids.scores), dim=1)

    # Transform logits to softmax and keep only the highest (chosen) p for each token
    char_probs = F.softmax(logits, dim=2).max(dim=2)[0]

    # Step i scored the token at sequences[:, i + 1]; column 0 has no score.
    # NOTE(review): assumes sequences starts with the decoder start token, as
    # Hugging Face generate() does for encoder-decoder models - confirm.
    generated_tokens = generated_ids.sequences[:, 1:]

    # Only tokens of val>2 should influence the confidence, so the probability
    # of every special token (0-2) is replaced by the neutral value 1.
    is_special = generated_tokens <= 2
    char_probs = char_probs.masked_fill(is_special, 1.0)

    # Confidence of each example is the product of its token probabilities
    return char_probs.prod(dim=1).tolist()



In [None]:
def evaluate_craft_seg(model, trainloader):
    """Run text recognition over every batch of ``trainloader``.

    Side effect: every decoded string is appended to the module-level
    ``words_identified`` entry of the image it belongs to (resolved through
    ``word_log_dic``).

    Args:
        model: Model used for generation. A ``torch.nn.DataParallel`` wrapper
            is unwrapped, because ``generate`` is called on the wrapped module.
        trainloader: Iterable of ``(images, labels)`` batches, where
            ``images['pixel_values'][0]`` holds the pixel values of the batch
            and ``labels`` holds the class index of every crop.

    Returns:
        tuple: ``(results, confidence, label)`` - the decoded strings, their
        confidence scores and their class indices, aligned by position.
    """
    results = []
    confidence = []
    label = []

    # Use the model that was passed in rather than a module-level global
    generator = model.module if isinstance(model, torch.nn.DataParallel) else model
    generator.eval()
    with torch.no_grad():
        for images, labels in tqdm(trainloader):
            pixel_values = images['pixel_values'][0].to(device)
            class_ids = labels.cpu().numpy()

            decoded = generator.generate(pixel_values, return_dict_in_generate=True, output_scores=True)
            final_values = processor.batch_decode(decoded.sequences, skip_special_tokens=True)
            confidences = get_confidence_scores(decoded)

            # NOTE: really poor guesses (e.g. confidence <= .8) are not culled
            # here; that is probably better done in the search function.
            for class_id, text in zip(class_ids, final_values):
                words_identified[word_log_dic[class_id]].append(text)

            results.extend(final_values)
            confidence.extend(confidences)
            label.extend(class_ids)

    return results, confidence, label

In [None]:
# Run recognition over all crops and keep text, confidence and class label per crop
results, confidence, labels = evaluate_craft_seg(model, trainloader)

# Persist the three aligned lists as one pickled DataFrame
result_rows = list(zip(results, confidence, labels))
df = pd.DataFrame(result_rows, columns=['Results', 'Confidence', 'Labels'])
df.to_pickle(save_dir + 'full_results.pkl')


# Loading the full species, taxon, genus, countries and subdivision information

In [None]:
# Locations of the string matching files
ALL_SPECIES_FILE = '/projectnb/sparkgrp/ml-herbarium-grp/ml-herbarium-data/corpus_taxon/output/possible_species.pkl'
ALL_GENUS_FILE = '/projectnb/sparkgrp/ml-herbarium-grp/ml-herbarium-data/corpus_taxon/output/possible_genus.pkl'
ALL_TAXON_FILE = '/usr3/graduate/colejh/corpus_taxon.txt'
# NOTE: unpickling can execute arbitrary code - only load pickle files from a trusted source
species = pd.Series(list(pd.read_pickle(ALL_SPECIES_FILE)))
genus = pd.Series(list(pd.read_pickle(ALL_GENUS_FILE)))
taxon = pd.read_csv(ALL_TAXON_FILE, delimiter="\t", names=["Taxon"]).squeeze()

# Collect the country *names*: the matching step works on strings, and this
# keeps the list consistent with the subdivision names collected below
countries = [country.name for country in pycountry.countries]

# subdivisions      - every subdivision name
# subdivisions_dict - subdivision name -> name of the country it belongs to
subdivisions_dict = {}
subdivisions = []
for subdivision in pycountry.subdivisions:
    subdivisions.append(subdivision.name)
    subdivisions_dict[subdivision.name] = pycountry.countries.get(alpha_2=subdivision.country_code).name


# Functions to pick the best match from each set of species, genus, taxon, country, and subdivision

In [None]:
# Select the minimum similarity percentage you want (can also be set in pooled_match)
def matches_above_x(df, x):
    """Return the rows of ``df`` whose ``similarity`` is at least ``x``.

    Args:
        df: DataFrame with a numeric ``similarity`` column.
        x: Minimum similarity (inclusive) a row needs to be kept.

    Returns:
        The filtered rows of ``df``.
    """
    # Apply the requested threshold rather than a fixed value
    return df.loc[df['similarity'] >= x]


In [None]:
def match(main,comparison_file,minimum_similarity):
    """Match the strings of ``main`` against ``comparison_file``.

    Args:
        main: Strings to match, passed to ``match_strings`` unchanged.
        comparison_file: Strings to compare against ``main``; converted to a
            ``pd.Series`` when it is not one already.
        minimum_similarity: Minimum similarity a pair needs to be reported.

    Returns:
        The result of ``match_strings``, limited to one match per string.
    """
    if isinstance(comparison_file, pd.Series):
        candidates = comparison_file
    else:
        candidates = pd.Series(comparison_file)

    return match_strings(
        main,
        candidates,
        n_blocks='guess',
        min_similarity=minimum_similarity,
        max_n_matches=1,
    )


In [None]:
def highest_score_per_image(df, labels, minimum_similarity=None):
    """Keep only the best scoring match of every image.

    Args:
        df: DataFrame with the columns ``right_index`` (position of the
            matched prediction) and ``similarity``. It is not modified.
        labels: Indexable that maps a prediction position to the label of the
            image the prediction came from.
        minimum_similarity: Unused; accepted so that both two- and
            three-argument calls work.

    Returns:
        A DataFrame with one row per image label - the row with the highest
        ``similarity`` - whose ``right_index`` column holds the image label.
    """
    index_to_labels = df.copy()
    if index_to_labels.empty:
        return index_to_labels

    # Translate every position in a single pass. Replacing the values one by
    # one in place would translate a row a second time whenever an already
    # assigned label equals a position that is handled later.
    index_to_labels['right_index'] = [labels[i] for i in index_to_labels['right_index']]

    best_rows = index_to_labels.groupby('right_index')['similarity'].idxmax()
    return index_to_labels.loc[best_rows]



In [None]:
def pooled_match(comparison_file, labels, minimum_similarity=.7, **kwargs):
    """Match ``comparison_file`` against several corpora in parallel.

    Args:
        comparison_file: Strings to match (list, array or ``pd.Series``).
        labels: Maps the position of a string in ``comparison_file`` to the
            label of the image it came from.
        minimum_similarity: Minimum similarity a match needs to be reported.
        **kwargs: One corpus per keyword (list, array or ``pd.Series``); the
            keyword is used as the name of the corpus.

    Returns:
        dict: Corpus name -> DataFrame holding the best match per image.
    """
    # Convert to series (string-grouper requires this type), will work if input is list, array, or series
    corpus_name = list(kwargs)
    corpus_list = [v if isinstance(v, pd.Series) else pd.Series(v) for v in kwargs.values()]

    if not isinstance(comparison_file, pd.Series):
        comparison_file = pd.Series(comparison_file)

    func = partial(match, comparison_file=comparison_file, minimum_similarity=minimum_similarity)
    # The context manager releases the worker processes once all results are in
    with multiprocessing.Pool() as pool:
        raw_results = pool.map(func, corpus_list)

    result_dic = {}
    for name, result in zip(corpus_name, raw_results):
        # Rename by assigning a new column list instead of mutating
        # ``columns.values`` in place.
        # NOTE(review): assumes column 1 holds the corpus strings and column 3
        # the matched predictions - confirm against match_strings' output.
        column_names = list(result.columns)
        column_names[1] = name + ' Corpus'
        column_names[3] = "Predictions"
        result.columns = column_names
        result = result.drop('left_index', axis=1)
        result_dic[name] = highest_score_per_image(result, labels, minimum_similarity)

    return result_dic

 # Performing the matching

In [None]:
import time  # needed for the timing below

# Input to string-grouper all need to be series format
results_series = pd.Series(results)

# Run the matching against all files.
# The threshold is arbitrary: it is set this low to get every prediction and
# should likely be considerably higher.
minimum_similarity = .01
start = time.time()
# Bound to all_matches, the name the evaluation loop reads; this also keeps
# `results` holding the raw predictions, so the cell can be run again.
all_matches = pooled_match(
    results_series,
    labels,
    minimum_similarity=minimum_similarity,
    Taxon=taxon,
    Species=species,
    Genus=genus,
    Countries=countries,
    Subdivisions=subdivisions,
)
end = time.time()
print('Time to match all strings: ', end - start)
    

# Reading in the ground truth files for tested images

In [None]:
# Reading in the ground truth values
def _read_ground_truth(path):
    """Parse a file with one ``<key>: <value>`` entry per line into a dict.

    Lines without a ``": "`` separator (e.g. blank lines) are skipped, and a
    value keeps any further ``": "`` it contains.
    """
    truth = {}
    # The context manager closes the file again once it has been read
    with open(path) as handle:
        for line in handle:
            key_part, separator, value = line.partition(": ")
            if not separator:
                continue
            # The key is the text in front of the first colon of the line
            truth[key_part.split(":")[0]] = value.strip()
    return truth


gt_t = workdir+'/taxon_gt.txt'
Taxon_truth = _read_ground_truth(gt_t)

gt_g = workdir+'/geography_gt.txt'
Geography_truth = _read_ground_truth(gt_g)

gt_c = workdir+'/collector_gt.txt'
Collector_truth = _read_ground_truth(gt_c)

comparison_file = {"Taxon":Taxon_truth,"Geography":Geography_truth,"Collector":Collector_truth}

# Checking each result against the ground truth file

In [None]:
# Just for fancy print
class color:
    """ANSI escape sequences for coloured and styled terminal output.

    Usage: ``color.RED + text + color.END``; ``END`` switches the styling off.
    """

    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    DARKCYAN = '\033[36m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    END = '\033[0m'

In [None]:
# Compare the best match of every image with its ground truth, per corpus
count = 0
for k, v in all_matches.items():
    # need to add in the check for other species names, quite a few would match
    count = 0  # correct predictions of this corpus only
    # Corpora without a ground truth file are reported as "not found" per image
    truth = comparison_file.get(k, {})
    prediction_and_imagenumber = list(zip(v[k+' Corpus'], v.right_index))
    print('Evaluation for', k)
    # The similarity is zipped in positionally: v keeps the row labels of the
    # unfiltered match table, so looking it up by loop position would fail
    for (prediction, image_number), similarity in zip(prediction_and_imagenumber, v['similarity']):
        image = word_log_dic[image_number]
        gt = truth.get(image)
        if gt is None:
            print("Ground Truth Not Found for:", image)
        elif gt == prediction:
            count += 1
            print(gt, "||", prediction, '||', image, '||', similarity)
        else:
            print(color.RED+gt+"||"+prediction+'||'+str(image)+'||'+str(similarity)+color.END)

    # Guard the ratios against empty inputs
    acc = count/len(v) if len(v) else 0.0
    total_acc = count/len(filenames) if len(filenames) else 0.0
    print(color.BOLD+"Accuracy on Predicted:"+str(acc)+color.END)
    print(color.BOLD+"Total accuracy: "+str(total_acc)+color.END)
    print(color.BOLD+"Total Guessed:"+str(len(prediction_and_imagenumber))+color.END)
    print('\n\n********************************\n\n')
