# Common Imports

In [1]:
! if [ ! $pip_done ]; then pip install -q transformers ;fi 
! if [ ! $pip_done ]; then pip install -q datasets jiwer ;fi 
! if [ ! $pip_done ]; then pip install -q sentencepiece ;fi 

pip_done = 1

In [2]:
import torch
import pandas as pd
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, AdamW
from sklearn.model_selection import train_test_split

# Dataset

In [3]:
# Directory that holds the word images (and the gt.txt ground-truth file read below).
root_dir = "/kaggle/input/str-arabic-dataset/Arabic_words_train"

In [4]:
# The ground-truth file is read without a header row; every line becomes a
# data row and the two columns are named explicitly.
column_names = ["image_path", "text"]
df = pd.read_csv(
    "/kaggle/input/str-arabic-dataset/Arabic_words_train/gt.txt",
    header=None,
    names=column_names,
)


In [5]:
# Hold out 20% of the rows for evaluation; the fixed random_state makes the
# split reproducible across runs.
test_size = 0.2
train_df, test_df = train_test_split(
    df,
    test_size=test_size,
    random_state=42,
)

In [6]:
# Renumber the rows of each split from 0, discarding the row labels inherited
# from the original frame.
train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

In [7]:
# Preview the first rows of the training split.
train_df.head()

Unnamed: 0,image_path,text
0,word_1902.png,إبداع
1,word_2924.png,مقتل
2,word_3297.png,عروض
3,word_2206.png,نوجا
4,word_2274.png,فعالة


In [8]:
# Preview the first rows of the held-out split.
test_df.head()

Unnamed: 0,image_path,text
0,word_1743.png,فى
1,word_3814.png,يوم
2,word_1515.png,على
3,word_96.png,فستيفال
4,word_4073.png,الخاصة


In [9]:
# (rows, columns) of the training split and of the held-out split.
train_df.shape, test_df.shape

((3350, 2), (838, 2))

In [10]:
class ArabicSTRDataset(Dataset):
    """Pairs word images with their tokenized transcriptions.

    Each item is a dict with two tensors:

    * ``pixel_values`` - the processor output for the RGB image with the
      leading batch axis removed;
    * ``labels`` - token ids of the text, padded or truncated to exactly
      ``max_target_length`` entries, with padding positions set to -100.

    Args:
        root_dir: Directory that contains the image files.
        df: DataFrame with an ``image_path`` column (file name inside
            ``root_dir``) and a ``text`` column (target transcription).
        processor: Callable invoked as ``processor(image, return_tensors="pt")``
            whose result exposes ``pixel_values``.
        tokenizer: Callable tokenizer exposing ``pad_token_id``; its result
            exposes ``input_ids``.
        max_target_length: Fixed length of every label sequence.
    """

    def __init__(self, root_dir, df, processor, tokenizer, max_target_length):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_target_length = max_target_length

    def __len__(self):
        """Return the number of (image, text) rows."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load and encode the sample at positional index ``idx``."""
        # Look the row up once instead of once per column.
        row = self.df.iloc[idx]
        file_name = row['image_path']
        text = row['text']

        # Prepare image (resize and normalize); the context manager closes the
        # file handle as soon as the pixels have been converted.
        image_path = f"{self.root_dir}/{file_name}"
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values

        # Encode the text. Truncation guarantees that every label tensor has
        # exactly max_target_length entries, so samples can always be stacked
        # into a batch even when a transcription is unusually long.
        labels = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_target_length,
            return_tensors="pt",
        ).input_ids
        # Drop only the batch axis; a bare squeeze() would also collapse any
        # other axis that happens to have size 1.
        labels = labels.squeeze(0)
        # -100 is the default ignore_index of PyTorch's cross-entropy loss,
        # so padding positions do not contribute to the loss.
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {"pixel_values": pixel_values.squeeze(0), "labels": labels}

# Model 

In [11]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [12]:
# Show which device was selected.
device

device(type='cuda')

In [13]:
# Load the pretrained encoder-decoder model and the processor (image
# preprocessing + tokenizer), then move the model to the selected device.
# NOTE(review): the model comes from the "small" stage-1 checkpoint while the
# processor comes from the "base" stage-1 checkpoint. Confirm that this pairing
# is intentional - the two checkpoints may ship different tokenizers, in which
# case the decoder's pretrained token embeddings would not match the ids
# produced here.
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-stage1")
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-stage1')
model.to(device)

config.json:   0%|          | 0.00/4.21k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/246M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()
Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-small-stage1 and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


generation_config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

2024-05-20 20:09:01.553244: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-20 20:09:01.553370: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-20 20:09:01.693662: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


preprocessor_config.json:   0%|          | 0.00/228 [00:00<?, ?B/s]



tokenizer_config.json:   0%|          | 0.00/1.12k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

VisionEncoderDecoderModel(
  (encoder): DeiTModel(
    (embeddings): DeiTEmbeddings(
      (patch_embeddings): DeiTPatchEmbeddings(
        (projection): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DeiTEncoder(
      (layer): ModuleList(
        (0-11): 12 x DeiTLayer(
          (attention): DeiTAttention(
            (attention): DeiTSelfAttention(
              (query): Linear(in_features=384, out_features=384, bias=True)
              (key): Linear(in_features=384, out_features=384, bias=True)
              (value): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): DeiTSelfOutput(
              (dense): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): DeiTIntermediate(
            (dense): Linear(

In [14]:
# Special token ids: taken from the processor's tokenizer so that the model
# and the label encoding agree on start, padding and end-of-sequence tokens.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.eos_token_id = processor.tokenizer.sep_token_id

# Generation (beam search) settings stored on the model config.
model.config.num_beams = 4
model.config.max_length = 512
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0

In [15]:
# Both splits read images from the same folder and share the processor,
# tokenizer and fixed label length; only the DataFrame differs.
train_dataset = ArabicSTRDataset(
    root_dir, train_df, processor, processor.tokenizer, max_target_length=100
)

eval_dataset = ArabicSTRDataset(
    root_dir, test_df, processor, processor.tokenizer, max_target_length=100
)

In [16]:
# Number of samples per batch, used for both the training and evaluation loaders.
batch_size = 8

In [17]:
# Only the training loader shuffles; the evaluation loader keeps dataset order.
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Training

In [18]:
# Use PyTorch's own AdamW: the implementation exported by `transformers` is
# deprecated in favour of `torch.optim.AdamW`. `eps=1e-6` and
# `weight_decay=0.0` reproduce the defaults of the deprecated class, so the
# optimisation behaviour stays the same.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)



In [19]:
def train(EPOCHS, version, model, train_dataloader, optimizer, device):
    """Fine-tune ``model`` and checkpoint it after every epoch.

    Args:
        EPOCHS: Number of passes over ``train_dataloader``.
        version: Tag used in the checkpoint directory name ``version_{version}``.
        model: Model whose forward pass accepts the batch as keyword arguments
            and returns an object with a ``loss`` attribute; it must also
            provide ``save_pretrained``.
        train_dataloader: Sized iterable yielding dicts of tensors.
        optimizer: Optimizer bound to the model's parameters.
        device: Device every batch tensor is moved to before the forward pass.

    Returns:
        List with the mean training loss of each epoch.

    Raises:
        ValueError: If ``train_dataloader`` contains no batches.
    """
    hist = []

    # Discard gradients left over from an earlier, interrupted run (e.g. a
    # notebook cell stopped between backward() and zero_grad()); otherwise they
    # would be added to the first update of this run.
    optimizer.zero_grad()

    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_dataloader):
            # Move every tensor of the batch to the target device.
            batch = {key: value.to(device) for key, value in batch.items()}

            # forward + backward + optimize
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            train_loss += loss.item()

        num_batches = len(train_dataloader)
        if num_batches == 0:
            # Fail with a clear message instead of a bare ZeroDivisionError.
            raise ValueError("train_dataloader is empty; cannot compute the epoch loss")

        epoch_loss = train_loss / num_batches
        print(f"Loss after epoch {epoch}:", epoch_loss)
        hist.append(epoch_loss)
        model.save_pretrained(f"version_{version}/epoch_{epoch}")

    model.save_pretrained(f"version_{version}/final")
    return hist

In [20]:
# Run 20 epochs; checkpoints are written under "version_model0/".
train(
    EPOCHS=20,
    version="model0",
    model=model,
    train_dataloader=train_dataloader,
    optimizer=optimizer,
    device=device,
)

  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 0: 3.2472673105453818


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 1: 2.115452472520614


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 2: 1.9365794698492156


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 3: 1.6834036689385026


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 4: 1.3960979916599883


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 5: 1.0780897122721114


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 6: 0.8039229431840855


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 7: 0.5946513017492818


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 8: 0.4577342018947396


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 9: 0.35260590479257997


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 10: 0.27447615752801263


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 11: 0.2216197220355201


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 12: 0.19348403612298443


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 13: 0.18736610980148746


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 14: 0.151078425470055


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 15: 0.143357198674294


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 16: 0.11924972611997356


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 17: 0.11558081013281654


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 18: 0.10059404833440119


  0%|          | 0/419 [00:00<?, ?it/s]

Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


Loss after epoch 19: 0.0914248395366661


Non-default generation parameters: {'max_length': 512, 'early_stopping': True, 'num_beams': 4, 'length_penalty': 2.0, 'no_repeat_ngram_size': 3}


# Testing

In [21]:
def generate_text(image_path):
    """Predict the text shown in the image at ``image_path``.

    Uses the notebook-level ``processor``, ``model`` and ``device``.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The decoded prediction with special tokens removed.
    """
    # Read and process the image; the context manager closes the file handle
    # once the pixels have been converted.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

    # Generate text in evaluation mode: after train() the model is still in
    # training mode, which keeps dropout active during generation. The previous
    # mode is restored afterwards so callers see no change in model state.
    was_training = model.training
    model.eval()
    try:
        outputs = model.generate(pixel_values, num_beams=4, max_length=512, early_stopping=True)
    finally:
        model.train(was_training)

    predicted_text = processor.tokenizer.decode(outputs[0], skip_special_tokens=True)
    return predicted_text

In [22]:
# Spot-check a single prediction.
# NOTE(review): this image lives in the same folder the train/test split was
# drawn from, so it may have been part of the training rows - confirm before
# treating the result as a held-out test.
generate_text("/kaggle/input/str-arabic-dataset/Arabic_words_train/word_1.png")



'خاص'

In [23]:
# Spot-check the prediction for word_1008.png.
generate_text("/kaggle/input/str-arabic-dataset/Arabic_words_train/word_1008.png")

'هيوصل'

In [24]:
# Spot-check the prediction for word_101.png.
generate_text("/kaggle/input/str-arabic-dataset/Arabic_words_train/word_101.png")

'إلحقي'

In [25]:
# Spot-check the prediction for word_1076.png.
generate_text("/kaggle/input/str-arabic-dataset/Arabic_words_train/word_1076.png")

'شركة'