<a href="https://colab.research.google.com/github/Aryan-Dessai-25/AIofGod3.0/blob/main/inference_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# IMPORTANT: some Kaggle data sources used by this notebook are private.
# Run this cell first: it imports kagglehub and calls its login() so that the
# data sources can be imported afterwards.
import kagglehub
kagglehub.login()


In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

# The two returned paths are not read by the later cells shown in this notebook;
# those cells use hard-coded /kaggle/input/... locations instead.
# NOTE(review): outside Kaggle the downloads may be stored elsewhere -- confirm
# that the hard-coded paths still resolve in the target environment.
ai_of_god_3_path = kagglehub.competition_download('ai-of-god-3')
aryandessai_badambeyond_path = kagglehub.dataset_download('aryandessai/badambeyond')

print('Data source import complete.')


In [None]:
%pip install evaluate

In [None]:
%pip install jiwer

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import torch
import re
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm.notebook import tqdm
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from transformers import  Trainer, TrainingArguments
from torch.nn.utils.rnn import pad_sequence
from evaluate import load
import torchvision.transforms as transforms
import albumentations as A
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
import wandb
# Initialise Weights & Biases in 'disabled' mode so this notebook does not
# record a W&B run.
wandb.init(mode= 'disabled')

In [None]:
class cfg:
    """Settings shared by the notebook cells.

    Only ``device`` and ``eval_batch_size`` are read by the code visible in
    this notebook; the remaining values look like training settings
    (presumably carried over from the training notebook).
    """

    # Optimiser-related values.
    lr: float = 1e-5
    wt_decay: float = 0.01
    num_epochs: int = 10

    # Batch sizes.
    train_batch_size: int = 2
    eval_batch_size: int = 2

    # Presumably the share of samples reserved for validation -- not used here.
    val_ratio: float = 0.1

    # Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:

class CustomDataset(Dataset):
    """Dataset pairing image files with texts and running both through ``processor``.

    Each item is a dict with:
      * ``pixel_values`` -- the output of ``processor`` on the prepared image,
        with singleton dimensions squeezed away;
      * ``labels`` -- token ids of the text from ``processor.tokenizer``, cut to
        at most 512 ids and squeezed.

    Args:
        images: Indexable collection of image file paths.
        gt_text: Indexable collection of texts, aligned with ``images`` by index.
        processor: Callable image processor that also exposes ``.tokenizer``.
        preprocess: When truthy, a random augmentation pipeline is applied to
            the image. When falsy (``False`` or ``None``), the image is
            converted to grayscale instead and no augmentation is applied.
    """
    def __init__(self,images,gt_text,processor,preprocess=True):
        self.images=images
        self.texts=gt_text
        self.processor=processor
        self.preprocess=preprocess

        if self.preprocess:
            # With probability 0.7 exactly one of the listed transforms is
            # applied (each inner transform has p=1.0 once selected).
            # NOTE(review): several keyword names here (e.g. var_limit,
            # alpha_affine, shift_limit) are albumentations-version specific --
            # confirm against the installed version.
            self.transform = A.Compose([
                A.OneOf([
                    A.Rotate(limit=2, p=1.0),
                    A.GaussNoise(var_limit=(10.0, 30.0), p=1.0),
                    A.ElasticTransform(alpha=0.3, sigma=50.0, alpha_affine=None, p=1.0),
                    A.OpticalDistortion(distort_limit=0.03, shift_limit=0.03, p=1.0),
                    A.CLAHE(clip_limit=2, tile_grid_size=(4, 4), p=1.0),
                    A.Affine(scale=(0.95, 1.05), translate_percent=(0.01, 0.01), shear=(-2, 2), p=1.0),
                    A.Perspective(scale=(0.01, 0.03), p=1.0),
                    A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=1.0),
                    A.GaussianBlur(blur_limit=(3, 7), p=1.0),
                    A.GridDistortion(num_steps=3, distort_limit=0.02, p=1.0),
                    A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=10, val_shift_limit=10, p=1.0),
                    A.MedianBlur(blur_limit=3, p=1.0)
                ], p=0.7),
            ])
        else:
            # Empty pipeline; it is never invoked because __getitem__ only
            # calls the transform when preprocess is truthy.
            self.transform = A.Compose([])

    def __len__(self):
        """Return the number of image paths."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, prepare and encode the sample at position ``idx``.

        Returns:
            dict with keys ``"pixel_values"`` and ``"labels"``.
        """
        image_name = self.images[idx]
        image = Image.open(image_name)
        text = self.texts[idx]

        # Convert to RGB if RGBA (other image modes are passed through unchanged)
        if image.mode == 'RGBA':
            image = image.convert('RGB')

        # Convert to grayscale (black and white)
        if not self.preprocess:
            image = image.convert('L')  # Convert to grayscale
            image = np.array(image)
            image = np.stack([image] * 3, axis=-1)  # Repeat to create 3 channels
        else:
            image = np.array(image)

        # A 2-D (single-channel) array is expanded to three identical channels.
        if image.ndim == 2:
            image = np.expand_dims(image, axis=-1)
            image = np.repeat(image, 3, axis=-1)

        # NOTE(review): PIL images normally convert to uint8 arrays, in which
        # case this multiplication wraps modulo 256 instead of rescaling. The
        # checkpoint was presumably trained with the same arithmetic -- confirm
        # before changing this line.
        image = (image * 255).astype(np.uint8)

        if self.preprocess:
            augmented = self.transform(image=image)
            image = augmented['image']

        # Resize to 256 px wide by 64 px high, then scale values to [0, 1].
        image = Image.fromarray(image)
        image = image.resize((256, 64), Image.BILINEAR)
        image = np.array(image) / 255.0

        # Move channels first for the processor. Any other channel count is
        # only reported via print and passed on untransposed.
        if image.shape[-1] == 3:
            image = np.transpose(image, (2, 0, 1))
        else:
            print(image.shape)

        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.squeeze()

        # Tokenize the text and keep at most the first 512 token ids.
        labels = self.processor.tokenizer(text, return_tensors="pt").input_ids
        labels = labels[:, :512]
        labels = labels.squeeze()

        return {"pixel_values": pixel_values, "labels": labels}



In [None]:
# Batch collation: label sequences are padded with the value -100.
def collate_fn(batch):
    """Merge a list of per-sample dicts into a single batch dict.

    The ``pixel_values`` tensors are stacked along a new leading dimension.
    The variable-length ``labels`` tensors are padded with -100 up to the
    longest sequence in the batch (-100 is presumably chosen so that padded
    positions are ignored by the loss).

    Args:
        batch: List of dicts, each holding ``'pixel_values'`` and ``'labels'``.

    Returns:
        dict with the stacked ``'pixel_values'`` and padded ``'labels'``.
    """
    images, targets = [], []
    for sample in batch:
        images.append(sample['pixel_values'])
        targets.append(sample['labels'])

    return {
        'pixel_values': torch.stack(images),
        'labels': pad_sequence(targets, batch_first=True, padding_value=-100),
    }

In [None]:
wer = load("wer")


def compute_metrics(eval_pred):
    """Compute the word error rate between predictions and reference labels.

    Relies on the notebook-level ``processor`` (for its tokenizer) and ``wer``
    metric objects.

    Args:
        eval_pred: Pair ``(logits, labels)``. ``logits`` may also be a tuple,
            in which case its first element is used.

    Returns:
        dict with the single key ``"wer"``.
    """
    logits, labels = eval_pred
    logits = logits[0] if isinstance(logits, tuple) else logits

    # Highest-scoring token id at every position.
    token_ids = logits.argmax(-1)

    tokenizer = processor.tokenizer
    decoded_preds = tokenizer.batch_decode(token_ids, skip_special_tokens=True)

    # Drop the -100 padding entries before decoding each reference sequence.
    decoded_labels = [
        tokenizer.decode([tok for tok in row if tok != -100], skip_special_tokens=True)
        for row in labels
    ]

    return {"wer": wer.compute(predictions=decoded_preds, references=decoded_labels)}

In [None]:
# Processor and model are loaded from the same local directory.
# do_rescale=False: presumably because CustomDataset already divides the pixel
# values by 255 before calling the processor.
processor = TrOCRProcessor.from_pretrained("/kaggle/input/badambeyond/bsk/bskundu", do_rescale=False)
model = VisionEncoderDecoderModel.from_pretrained("/kaggle/input/badambeyond/bsk/bskundu")

# map_location="cpu" lets the file load on CPU-only runtimes even when it was
# saved from a GPU session (cfg.device already allows a CPU fallback).
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
_temperature_ckpt = torch.load("/kaggle/input/badambeyond/bsk/learned_temperature.pth", map_location="cpu")

# generate() documents `temperature` as a float. float() accepts a Python
# number or a one-element tensor (assumed scalar -- TODO confirm what the
# file actually stores under "temperature").
learned_temp = float(_temperature_ckpt["temperature"])

In [None]:
test_dir = '/kaggle/input/ai-of-god-3/Public_data/test_images'

# One entry per (page, line) pair, pages outermost. The page count is the
# number of entries in test_dir, and every page is taken to hold 24 lines.
num_pages = len(os.listdir(test_dir))
page_line_pairs = [
    (page, line)
    for page in range(1, num_pages + 1)
    for line in range(1, 25)
]
test_id = [f'P_{page}_L_{line}' for page, line in page_line_pairs]
test_images = [f'{test_dir}/Page_{page}/L_{line}.png' for page, line in page_line_pairs]

# Empty strings stand in for the texts; the loop below only reads pixel_values.
dataset = CustomDataset(test_images, [""] * len(test_images), processor, preprocess=None)
dataloader = DataLoader(
    dataset,
    batch_size=cfg.eval_batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)

model = model.to(cfg.device)
model.eval()

generated_texts = []

with torch.no_grad():
    for batch in dataloader:
        pixel_values = batch["pixel_values"].to(cfg.device)

        # NOTE(review): `temperature` normally only influences sampling-based
        # decoding and no do_sample flag is passed here -- confirm that the
        # learned temperature has the intended effect.
        generated_ids = model.generate(pixel_values, temperature=learned_temp)

        # Decode the generated ids to text and collect them in dataset order.
        generated_texts += dataset.processor.batch_decode(generated_ids, skip_special_tokens=True)

In [None]:
import re

def add_space_after_punctuation(text):
    """Insert a space after each of ``, . ! ? ; :`` that is directly followed by
    a non-whitespace character.

    Punctuation at the very end of the string, or already followed by
    whitespace, is left untouched.
    """
    pattern = r'([,.!?;:])(?!\s|$)'
    return re.compile(pattern).sub(r'\1 ', text)


In [None]:
def post_process(s):
    """Apply character- and substring-level clean-ups to one predicted line.

    Steps, in order:
      1. a fixed sequence of substring substitutions;
      2. ``add_space_after_punctuation``;
      3. if the result both starts and ends with a double quote, the first and
         last characters are dropped.

    Returns:
        The cleaned string.
    """
    # Order matters: every substitution runs on the output of the previous one
    # (e.g. 'ā ' is handled before the bare 'ā').
    substitutions = (
        ('ç', 'z'),
        ('à', 'a'),
        ('ā ', 'a'),
        ('ō ', 'o'),
        ('ā', 'a'),
        ('è', 'e'),
        ('vlt', 'ult'),
        ('vn', 'un'),
        ('uu', 'uv'),
        ('iue', 'ive'),
        ('iuo', 'ivo'),
        ('auo', 'avo'),
        ('aue', 'ave'),
        ('aui', 'avi'),
        ('cin', 'tin'),
        ('oui', 'ovi'),
    )
    for old, new in substitutions:
        s = s.replace(old, new)

    s = add_space_after_punctuation(s)

    wrapped_in_quotes = s.startswith('"') and s.endswith('"')
    return s[1:-1] if wrapped_in_quotes else s


In [None]:
# Ordered (old, new) pairs applied one after another by replace_words.
# Most entries carry surrounding spaces so that only space-delimited
# occurrences match; "deft" and " hermof" are substring/prefix rules
# (the former covers destos, destas).
# The rule "eft" -> "est" (for words like esto, esta, este) is disabled.
_WORD_FIXES = (
    (" ef ", " es "),
    (" defde ", " desde "),
    ("deft", "dest"),
    (" mifmo ", " mismo "),
    (" vna ", " una "),
    (" fe ", " se "),
    (" lof ", " los "),
    (" fi ", " si "),
    (" mifma ", " misma "),
    (" fu ", " su "),
    (" vamof ", " vamos "),
    (" eftoy ", " estoy "),
    (" rengo ", " tengo "),
    (" nof ", " nos "),
    (" afi ", " asi "),
    (" pvedo ", " puedo "),
    (" folo ", " solo "),
    (" foy ", " soy "),
    (" bveno ", " bueno "),
    (" nochef ", " noches "),
    (" fve ", " fue "),
    (" fer ", " ser "),
    (" fon ", " son "),
    (" defcuy ", " descuy "),
    (" seruir ", " servir "),
    (" graciaf ", " gracias "),
    (" ralera ", " raleza "),
    (" hermof", " hermos"),
    (" faber ", " saber "),
    (" fugeto ", " sugeto "),
    (" obfcuro ", " obscuro "),
    (" vfar ", " usar "),
    (" vfan ", " usan "),
    (" fegundo ", " segundo "),
    (" lvnes ", " lunes "),
    (" martef ", " martes "),
    (" prouidencia ", " providencia "),
    (" cafa ", " casa "),
    (" viuir ", " vivir "),
    (" adiof ", " adios "),
)


def replace_words(input_string):
    """Rewrite a fixed list of misrecognised words in ``input_string``.

    The substitutions are plain ``str.replace`` calls applied in table order.
    Because most patterns include their surrounding spaces, a word at the very
    start or end of the string (with no space on that side) is left as is.

    Returns:
        The string after all substitutions.
    """
    fixed = input_string
    for wrong, right in _WORD_FIXES:
        fixed = fixed.replace(wrong, right)
    return fixed



In [None]:
# Clean every prediction in place: post_process first, then replace_words.
generated_texts[:] = [replace_words(post_process(text)) for text in generated_texts]

In [None]:
# Build the submission table -- one row per test id with its predicted text --
# and write it to submission.csv without the DataFrame index column.
litc={'unique Id':test_id, 'prediction':generated_texts}
sub=pd.DataFrame(litc)
sub.to_csv("submission.csv", index=False)