In [1]:
pip install torch torchvision transformers diffusers numpy easyocr scipy networkx pillow

Note: you may need to restart the kernel to use updated packages.


In [None]:
import numpy as np
from PIL import Image
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import SmoothingFunction
import torch
import torch.nn.functional as F
from torchvision import transforms
from torchvision.models import inception_v3
from scipy.linalg import sqrtm
import editdistance
import easyocr

# Image transformation for InceptionV3
transform = transforms.Compose([
    transforms.Resize((299, 299)),  # Required input size for InceptionV3
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Initialize EasyOCR reader
reader = easyocr.Reader(['en'])  # Initialize for English

def extract_text(image):
    result = reader.readtext(np.array(image))
    return ' '.join([text for _, text, _ in result])

def calculate_cer(reference, hypothesis):
    return editdistance.eval(reference, hypothesis) / max(len(reference), 1)

def calculate_wer(reference, hypothesis):
    return editdistance.eval(reference.split(), hypothesis.split()) / max(len(reference.split()), 1)

def calculate_bleu(reference, hypothesis):
    smoothing_function = SmoothingFunction().method1
    return sentence_bleu([reference.split()], hypothesis.split(), smoothing_function=smoothing_function)

def calculate_ssim(img1, img2):
    return ssim(np.array(img1.convert("L")), np.array(img2.convert("L")))

def calculate_psnr(img1, img2):
    return psnr(np.array(img1.convert("L")), np.array(img2.convert("L")))

def calculate_fid(real_image, generated_image):
    inception_model = inception_v3(pretrained=True, transform_input=False).eval()
    
    def get_activations(image):
        image = transform(image).unsqueeze(0)  # Convert to tensor and add batch dimension
        with torch.no_grad():
            activations = inception_model(image)
        return activations.cpu().numpy().squeeze()

    real_activations = get_activations(real_image)
    generated_activations = get_activations(generated_image)

    # Ensure activations are 2D
    real_activations = real_activations.reshape(1, -1)
    generated_activations = generated_activations.reshape(1, -1)

    # Replace NaN and Inf values
    real_activations = np.nan_to_num(real_activations, nan=0.0, posinf=1e10, neginf=-1e10)
    generated_activations = np.nan_to_num(generated_activations, nan=0.0, posinf=1e10, neginf=-1e10)

    if real_activations.shape[0] > 1 and generated_activations.shape[0] > 1:
        mu1, sigma1 = real_activations.mean(axis=0), np.cov(real_activations, rowvar=False)
        mu2, sigma2 = generated_activations.mean(axis=0), np.cov(generated_activations, rowvar=False)
    else:
        # Handle the case where we don't have enough samples
        print("Warning: Not enough samples to compute covariance. Using diagonal covariance.")
        mu1, sigma1 = real_activations.mean(axis=0), np.diag(real_activations.var(axis=0))
        mu2, sigma2 = generated_activations.mean(axis=0), np.diag(generated_activations.var(axis=0))

    diff = mu1 - mu2

    # Handle potential numerical instabilities
    eps = 1e-6
    sigma1 = sigma1 + np.eye(sigma1.shape[0]) * eps
    sigma2 = sigma2 + np.eye(sigma2.shape[0]) * eps

    # Ensure no NaN or Inf values in covariance matrices
    sigma1 = np.nan_to_num(sigma1, nan=eps, posinf=1e10, neginf=-1e10)
    sigma2 = np.nan_to_num(sigma2, nan=eps, posinf=1e10, neginf=-1e10)

    try:
        covmean = sqrtm(sigma1.dot(sigma2))
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)
        return np.real(fid)
    except ValueError as e:
        print(f"Error in FID calculation: {e}")
        return None

def evaluate_model(original_image, corrected_image, target_text):
    results = {}
    try:
        print("Extracting text from original image...")
        original_text = extract_text(original_image)
        print("Extracting text from corrected image...")
        corrected_text = extract_text(corrected_image)
        
        print("Calculating CER...")
        results['CER'] = calculate_cer(target_text, corrected_text)
        print("Calculating WER...")
        results['WER'] = calculate_wer(target_text, corrected_text)
        print("Calculating BLEU...")
        results['BLEU'] = calculate_bleu(target_text, corrected_text)
        print("Calculating SSIM...")
        results['SSIM'] = calculate_ssim(original_image, corrected_image)
        print("Calculating PSNR...")
        results['PSNR'] = calculate_psnr(original_image, corrected_image)
        print("Calculating FID...")
        results['FID'] = calculate_fid(original_image, corrected_image)
        
        results['Original Text'] = original_text
        results['Corrected Text'] = corrected_text
        results['Target Text'] = target_text
    except Exception as e:
        print(f"An error occurred: {e}")
    return results

# Example usage
try:
    print("Loading images...")
    original_image = Image.open("Incorrect_Images/Incorrect_SOTP_sign.jpg").convert("RGB")
    corrected_image = Image.open("Corrected_Images/Corrected_STOP_mps1.png").convert("RGB")
    print("Images loaded successfully")
    
    target_text = "STOP"

    print("Starting evaluation...")
    evaluation_results = evaluate_model(original_image, corrected_image, target_text)
    print("Evaluation complete. Results:")
    print(evaluation_results)
except Exception as e:
    print(f"An error occurred in the main execution: {e}")


Loading images...
Images loaded successfully
Starting evaluation...
Extracting text from original image...
Extracting text from corrected image...
Calculating CER...
Calculating WER...
Calculating BLEU...
Calculating SSIM...
Calculating PSNR...
Calculating FID...
Evaluation complete. Results:
{'CER': 0.0, 'WER': 0.0, 'BLEU': 0.1778279410038923, 'SSIM': np.float64(0.955556026227496), 'PSNR': np.float64(18.04159816073095), 'FID': np.float64(58.298362731933594), 'Original Text': 'SOTP', 'Corrected Text': 'STOP', 'Target Text': 'STOP'}
