# Imports

In [None]:
#Imports and installs
import transformers
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# !pip install craft-text-detector
import transformers
from craft_text_detector import Craft # Need to edit the saving function to prepend 0's
import requests 
import torch
import os, random
from PIL import Image,ImageFilter
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets
from torch.utils.data import Dataset
from tqdm import tqdm
import pandas as pd

import numpy as np
import imghdr
import pickle
from pathlib import Path
import cv2
import torch.nn.functional as F
import multiprocessing
from functools import partial
import json

# from string_grouper import match_strings, match_most_similar
# !pip install pycountry
import pycountry
import matplotlib.pyplot as plt
import warnings
import time

import trocr
import matching
import predictions
import results


# Directories

In [None]:
# Silence the Hugging Face warnings: only errors are reported when SUPPRESS is on
SUPPRESS = True
if SUPPRESS:
    from transformers.utils import logging
    # Error level is the numeric level 40
    logging.set_verbosity_error()
# PIL's corrupt-EXIF warning isn't relevant for this application, so it is filtered out
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)

# --- Input / output locations (update these to the desired directories on scc) ---
# Location of images (must end with '/': file names are appended directly)
workdir = '/projectnb/sparkgrp/ml-herbarium-grp/ml-herbarium-data/TROCR_Training/goodfiles/'
# Location of the segmentations written by CRAFT
output_dir_craft = '/projectnb/sparkgrp/colejh/goodfilescraft'
# Location to save all output files (must end with '/': file names are appended directly)
save_dir = '/projectnb/sparkgrp/colejh/saved_results2/'
# Location of the ground truth label files
workdir2 = '/projectnb/sparkgrp/ml-herbarium-grp/ml-herbarium-data/scraped-data/drago_testdata/gt_labels'
# Corpus files used for string matching
ALL_SPECIES_FILE = '/projectnb/sparkgrp/ml-herbarium-grp/ml-herbarium-data/corpus_taxon/output/possible_species.pkl'
ALL_GENUS_FILE = '/projectnb/sparkgrp/ml-herbarium-grp/ml-herbarium-data/corpus_taxon/output/possible_genus.pkl'
ALL_TAXON_FILE = '/usr3/graduate/colejh/corpus_taxon.txt'


# Running craft and saving the segmented images

In [None]:
# Initialize the CRAFT text detector; crops are written under output_dir_craft
craft = Craft(output_dir = output_dir_craft,export_extra = False, text_threshold = .7,link_threshold = .4, crop_type="poly",low_text = .3,cuda = True)

# Run CRAFT on every image in workdir to get bounding boxes
images = []            # paths of images with at least one detected text box
corrupted_images = []  # paths that failed validation or detection with IOError/SyntaxError
no_segmentations = []  # valid images for which CRAFT returned no boxes
boxes = {}             # image path -> boxes returned by CRAFT
# NOTE(review): count, img_name and box are not used anywhere in the visible notebook;
# they are kept only in case other code relies on them.
count= 0
img_name = []
box = []
file_types = (".jpg", ".jpeg",".png")
for filename in tqdm(sorted(os.listdir(workdir))):
    if not filename.endswith(file_types):
        continue
    image = workdir+filename
    try:
        # Validate the file inside a context manager so the handle is always
        # closed; the previous code left one open file per image.
        with Image.open(image) as img:
            img.verify()
        # CRAFT is given the path (not the verified handle) and reads the file itself
        bounding_areas = craft.detect_text(image)
    except (IOError, SyntaxError):
        corrupted_images.append(image)
        continue
    if len(bounding_areas['boxes']): # check that a segmentation was found
        images.append(image)
        boxes[image] = bounding_areas['boxes']
    else:
        no_segmentations.append(image)

    

# Getting all the segmented images into a dataloader, and loading the model and processor for TrOCR

In [None]:
# Deleting empty folders, which occurs if some of the images get no segmentation from CRAFT
root = output_dir_craft
# Snapshot the tree before deleting anything; [1:] skips the root folder itself
folders = list(os.walk(root))[1:]
deleted = []
for folder in folders:
    dirpath, subdirs, files = folder
    # os.rmdir only succeeds on a completely empty directory, so a folder that
    # has no files but still contains sub-folders must be left alone.
    if not subdirs and not files:
        deleted.append(folder)
        os.rmdir(dirpath)

# Setting up the Tr-OCR model and processor
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten").to(device)

# Wrap the model so it can use all available gpu's
# NOTE(review): model_gpu is not referenced again in the visible notebook; the
# evaluation cell passes `model` -- confirm which one is intended.
model_gpu= nn.DataParallel(model,list(range(torch.cuda.device_count()))).to(device)

# Dataloader over the CRAFT crops (one class folder per source image), preprocessed by the TrOCR processor
trainset = datasets.ImageFolder(output_dir_craft, transform = processor)
testloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=False)

# For matching words to image: class folder names with the '_crops' marker removed
filenames = [s.replace('_crops', '') for s in trainset.class_to_idx]

# For matching the label index with the image name
word_log_dic = dict(enumerate(filenames))
# For matching the image name with the transcriptions (filled in later)
words_identified = {name: [] for name in filenames}


# Saving the filenames, word_log_dic and words_identified

In [None]:
# Persist the image names as plain text, one name per line
with open(save_dir+'filenames.txt', 'w') as names_file:
    names_file.writelines(f"{name}\n" for name in filenames)

# Persist the two lookup dictionaries as JSON
with open(save_dir+'word_log_dic.json', 'w') as word_log_file:
    json.dump(word_log_dic, word_log_file)
with open(save_dir+'words_identified.json', 'w') as words_file:
    json.dump(words_identified, words_file)

# Running Tr-OCR on the Segmented Images from Craft

In [None]:
# Storing the outputs of Tr-OCR on the segmented crops.
# The first output is deliberately NOT named `results`: that name is the imported
# `results` module, and rebinding it here would clobber the module.
transcriptions,confidence,labels = trocr.evaluate_craft_seg(model,processor, words_identified,word_log_dic,testloader,device)
# Saving all the outputs in a dataframe, one row per (result, confidence, label) triple
df = pd.DataFrame(list(zip(transcriptions,confidence,labels)),columns = ['Results','Confidence','Labels'])
df.to_pickle(save_dir+'full_results.pkl')


In [None]:
# First part of the final table: results, Tr-OCR confidence and label, grouped by label
combined_df = trocr.combine_by_label(df)

# Second part: one row per image holding its path and all of its bounding boxes
df_dictionary = pd.DataFrame(
    boxes.items(),
    columns=['Image_Path', 'Bounding_Boxes'],
)

# Put the two parts side by side, keeping only index values present in both.
# NOTE(review): this aligns rows purely by index -- it presumably relies on both
# frames being in the same image order; verify against trocr.combine_by_label.
frames_to_join = [combined_df, df_dictionary]
combined_df = pd.concat(frames_to_join, axis=1, join='inner')
display(combined_df.head(5))

In [None]:
# Save intermediate file. save_dir already ends with '/', so no extra separator
# is added (consistent with the other save paths in this notebook).
combined_df.to_pickle(save_dir+'test.pkl')

# Get the Bigrams for all transcriptions


In [None]:
# Creating a new column which contains all bigrams from the transcription, with an associated index for each bigram
bigram_df = combined_df.copy()

# Join the transcription pieces of each image with spaces, then split back into individual words
bigram_df['Bigrams'] = bigram_df['Transcription'].str.join(' ').str.split(' ')

# Pair every word with its successor; fewer than two words yields an empty list
bigram_df['Bigrams'] = bigram_df['Bigrams'].apply(
    lambda words: [' '.join(words[i:i+2]) for i in range(len(words) - 1)]
)
bigram_df['Bigram_idx'] = bigram_df.apply(matching.bigram_indices, axis=1)

# Flatten the bigrams and record, for each one, the row position of the image it
# came from. Both series are built from the same lists so they always have the
# same length: Series.explode() would emit a NaN row for an image with no
# bigrams, which put the strings and their image indices out of step.
# Row positions are used, matching the previous 0..n-1 loop over bigram_df.
bigram_lists = bigram_df['Bigrams'].tolist()

# Associating all bigrams with their respective image
bigram_idx = pd.Series(
    [row for row, grams in enumerate(bigram_lists) for _ in grams],
    dtype='int64',
)

# Getting the bigrams as individual strings
# NOTE: the name `results` is kept because the matching cell reads it; it shadows
# the imported `results` module from here on.
results = pd.Series(
    [gram for grams in bigram_lists for gram in grams],
    dtype=object,
)

# Loading the full species, taxon, genus, country and subdivision information

In [None]:
# Corpus of candidate strings for each category
# NOTE(review): read_pickle executes pickled code -- only load files from trusted locations.
species = pd.Series(list(pd.read_pickle(ALL_SPECIES_FILE)))
genus = pd.Series(list(pd.read_pickle(ALL_GENUS_FILE)))
taxon = pd.read_csv(ALL_TAXON_FILE,delimiter = "\t", names=["Taxon"]).squeeze()

# All countries and subdivisions for matching.
# Country *names* are collected (not pycountry Country objects) so that this
# corpus holds plain strings, like the subdivisions corpus below.
countries = [country.name for country in pycountry.countries]

# Subdivision names, plus a lookup from subdivision name to its country's name
subdivisions_dict = {}
subdivisions = []
for subdivision in pycountry.subdivisions:
    subdivisions.append(subdivision.name)
    subdivisions_dict[subdivision.name] = pycountry.countries.get(alpha_2 = subdivision.country_code).name


 # Performing the String Matching against all Corpus Files

In [None]:
# Running the matching against all corpus files
minimum_similarity = .01 # arbitrary, set here to get every prediction, likely want to set this quite a bit higher
start = time.time()
# Pass the variable (previously a hard-coded .01) so that changing the threshold above takes effect
all_matches = matching.pooled_match(results,bigram_idx,minimum_similarity = minimum_similarity,Taxon = taxon,Species = species,Genus = genus,Countries = countries,Subdivisions = subdivisions)
end = time.time()
print('Time to match all strings: ',end-start)
    

In [None]:
# Persist the matching output so later cells can be re-run without re-matching
with open(save_dir + 'all_matches.pkl', 'wb') as matches_file:
    pickle.dump(all_matches, matches_file)

# Final Output File

In [None]:
# Build the final dataframe: the bigram table plus, for every corpus in
# all_matches, its prediction string, similarity, corpus entry and bigram position
final_df = bigram_df.copy()

for corpus_name, corpus_matches in all_matches.items():
    corpus_col = corpus_name+'_Corpus'
    prediction_col = corpus_name+'_Prediction_String'

    # Attach this corpus' matches to each image row (left join keeps unmatched images)
    match_columns = corpus_matches[['right_index','Predictions','similarity',corpus_col]]
    final_df = pd.merge(
        final_df,
        match_columns,
        how = 'left',
        left_on = 'Labels',
        right_on = 'right_index',
    )
    # Give the generic match columns corpus-specific names
    renamed = {
        'Predictions': prediction_col,
        'similarity': corpus_name+'_Similarity',
        corpus_col: corpus_name+'_Prediction',
    }
    final_df = final_df.rename(columns = renamed)
    # The join key from the match table is no longer needed
    final_df = final_df.drop(columns = ['right_index'])

    # Position of the predicted bigram within the image's bigram list,
    # or a placeholder when there is no match
    locations = []
    for bigrams, predicted in zip(final_df['Bigrams'], final_df[prediction_col]):
        if predicted in bigrams:
            locations.append(bigrams.index(predicted))
        else:
            locations.append('No Match Found')
    final_df[corpus_name+'_Index_Location'] = locations

# Save the final dataframe
final_df.to_pickle(save_dir + 'final_df.pkl')
# Display final dataframe
display(final_df.head(10))

# Reading in the ground truth files for tested images

In [None]:
# Reading in the ground truth values

def _read_ground_truth(path):
    """Parse a ground truth file of ``key: value`` lines into a dict.

    The key is the text before the first ':' and the value is the stripped text
    following the first ': '. Blank lines are skipped. The file is closed as
    soon as it has been read.
    """
    truth = {}
    with open(path) as gt_file:
        for line in gt_file:
            if not line.strip():
                # A blank line (e.g. at the end of the file) carries no label
                continue
            truth[line.split(":")[0]] = line.split(": ")[1].strip()
    return truth


gt_t = workdir2+'/taxon_gt.txt'
Taxon_truth = _read_ground_truth(gt_t)

gt_g = workdir2+'/geography_gt.txt'
Geography_truth = _read_ground_truth(gt_g)

gt_c = workdir2+'/collector_gt.txt'
Collector_truth = _read_ground_truth(gt_c)

# Ground truth dictionaries keyed by category name
comparison_file = {"Taxon":Taxon_truth,"Countries":Geography_truth,"Collector":Collector_truth}

In [None]:
# Reading in synonym dictionary
syn_location = '/projectnb/sparkgrp/ml-herbarium-grp/ml-herbarium-data/synonym-matching/output/'
# Use a context manager so the file handle is closed after loading.
# NOTE(review): pickle.load executes pickled code -- only load files from trusted locations.
with open(syn_location+'syn_pure.pkl','rb') as syn_file:
    syn_pure = pickle.load(syn_file)

# Adding United States to dictionary
syn_pure['united states'] = 'United States of America'

# Checking Accuracy Against all Ground Truth Files

In [None]:
# Report the matches against the ground truth files
# NOTE(review): presumably prints per-image predictions next to their ground truth,
# using syn_pure to resolve synonyms -- confirm in predictions.print_prediction.
predictions.print_prediction(all_matches,comparison_file,word_log_dic,syn_pure)

In [None]:
# Evaluate the model at similarity thresholds 0.1, 0.2, ..., 0.9.
# Each list below receives one entry per threshold, in threshold order.
sim = []        # threshold used
acc_pred = []   # first value returned by check_accuracy
total_acc = []  # second value
types = []      # fourth value
s_match = []    # third value
co = []         # fifth value
for step in range(1,10):
    min_simil = step/10
    print(predictions.color.BOLD+"Minimum Similarity: "+str(min_simil)+predictions.color.END)
    accuracy_out = predictions.check_accuracy(all_matches,syn_pure,word_log_dic,comparison_file,min_simil)
    # check_accuracy must return exactly five values
    predicted_acc, overall_acc, synonym_hits, match_types, predicted_counts = accuracy_out
    sim.append(min_simil)
    acc_pred.append(predicted_acc)
    total_acc.append(overall_acc)
    s_match.append(synonym_hits)
    types.append(match_types)
    co.append(predicted_counts)

In [None]:
# Tabulate the threshold sweep for the first two match types (Taxon and Geography).
# Transposing each per-threshold list gives one tuple per match type.
tlist = list(zip(*acc_pred))
tlist2 = list(zip(*total_acc))
tlist3 = list(zip(*s_match))
tlist4 = list(zip(*types))
tlist5 = list(zip(*co))

# Build the columns directly in display order: threshold first, then the four
# metrics of the first type followed by the four metrics of the second type
table_columns = {'sim': sim}
for position in (0, 1):
    type_name = types[0][position]
    table_columns[type_name+' Number Predicted'] = tlist5[position]
    table_columns[type_name+' Accuracy Predicted'] = tlist[position]
    table_columns[type_name+' Total Accuracy Predicted'] = tlist2[position]
    table_columns[type_name+' Synonym Matches'] = tlist3[position]
df = pd.DataFrame(table_columns)
display(df)

# Displaying the bounding box and associated prediction, colored by similarity confidence

In [None]:
# Get a random image from the test set.
# randrange excludes the upper bound, so the value is always a valid row
# position (randint(0, n) could return n, one past the last row).
random_label = random.randrange(final_df.shape[0])

In [None]:
# The name `results` is rebound to the bigram Series earlier in this notebook,
# so the plotting module is imported under an alias before it is used.
import results as results_module
results_module.display_image(final_df,all_matches,Taxon_truth,Geography_truth,word_log_dic,random_label)

# Displaying each bounding box with its associated transcription and transcription confidence, colored by transcription confidence

In [None]:
# The name `results` is rebound to the bigram Series earlier in this notebook,
# so the plotting module is imported under an alias before it is used.
import results as results_module

# Hand-picked label indices to inspect; each must be a key of word_log_dic
interesting = [22,43,89,200]
for label in interesting:
    print('GT: '+Taxon_truth[word_log_dic[label]])
    results_module.bounding_confidence_text(final_df,Taxon_truth,word_log_dic,label)

# All bounding boxes for an image

In [None]:
# The name `results` is rebound to the bigram Series earlier in this notebook,
# so the plotting module is imported under an alias before it is used.
import results as results_module
# `label` is whatever value the previous cell's loop left behind (its last entry)
results_module.all_boxes(final_df,label)