In [2]:
!pip install trdg

# Imports
import copy
import torch
import random
import pathlib

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchvision import transforms
from torchvision.datasets import ImageFolder

from tqdm.auto import tqdm
from IPython.display import HTML, display

from trdg.generators import GeneratorFromStrings
from PIL import Image
import os
import csv
import string

import pandas as pd
from torch.utils.data import Dataset, DataLoader

Collecting trdg
  Downloading trdg-1.8.0-py3-none-any.whl (98.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m98.6/98.6 MB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
Collecting wikipedia>=1.4.0 (from trdg)
  Downloading wikipedia-1.4.0.tar.gz (27 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting diffimg==0.2.3 (from trdg)
  Downloading diffimg-0.2.3.tar.gz (4.1 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting arabic-reshaper==2.1.3 (from trdg)
  Downloading arabic_reshaper-2.1.3-py3-none-any.whl (20 kB)
Collecting python-bidi==0.4.2 (from trdg)
  Downloading python_bidi-0.4.2-py2.py3-none-any.whl (30 kB)
Reason for being yanked: Doesn't work with Python 2[0m[33m
[0mBuilding wheels for collected packages: diffimg, wikipedia
  Building wheel for diffimg (setup.py) ... [?25l[?25hdone
  Created wheel for diffimg: filename=diffimg-0.2.3-py3-none-any.whl size=4019 sha256=066438ce8eb173bfce8e9d8859b6523f92c47bc6e81171295a2ef399db

In [3]:
# @title Set random seed
# @markdown Executing `set_seed(seed=seed)` you are setting the seed

# For DL its critical to set the random seed so that students can have a
# baseline to compare their results to expected results.
# Read more here: https://pytorch.org/docs/stable/notes/randomness.html

# Call `set_seed` function in the exercises to ensure reproducibility.
import random
import torch

def set_seed(seed=None, seed_torch=True):
  """
  Function that controls randomness. NumPy and random modules must be imported.

  Args:
    seed : Integer
      A non-negative integer that defines the random state. Default is `None`.
    seed_torch : Boolean
      If `True` sets the random seed for pytorch tensors, so pytorch module
      must be imported. Default is `True`.

  Returns:
    Nothing.
  """
  if seed is None:
    seed = np.random.choice(2 ** 32)
  random.seed(seed)
  np.random.seed(seed)
  if seed_torch:
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

  print(f'Random seed {seed} has been set.')


# In case that `DataLoader` is used
def seed_worker(worker_id):
  """
  DataLoader will reseed workers following randomness in
  multi-process data loading algorithm.

  Args:
    worker_id: integer
      ID of subprocess to seed. 0 means that
      the data will be loaded in the main process
      Refer: https://pytorch.org/docs/stable/data.html#data-loading-randomness for more details

  Returns:
    Nothing
  """
  worker_seed = torch.initial_seed() % 2**32
  np.random.seed(worker_seed)
  random.seed(worker_seed)

In [4]:
# @title Set device (GPU or CPU). Execute `set_device()`
# especially if torch modules used.

# Inform the user if the notebook uses GPU or CPU.

def set_device():
  """
  Set the device. CUDA if available, CPU otherwise

  Args:
    None

  Returns:
    Nothing
  """
  device = "cuda" if torch.cuda.is_available() else "cpu"
  if device != "cuda":
    print("WARNING: For this notebook to perform best, "
        "if possible, in the menu under `Runtime` -> "
        "`Change runtime type.`  select `GPU` ")
  else:
    print("GPU is enabled in this notebook.")

  return device

In [5]:
SEED = 2021
set_seed(seed=SEED)
DEVICE = set_device()

Random seed 2021 has been set.
GPU is enabled in this notebook.


In [8]:
# @title Define output directory and CSV file path

output_dir = "ocr_dataset"
os.makedirs(output_dir, exist_ok=True)
csv_file = os.path.join(output_dir, "labels.csv")

# Function to generate random words
def generate_random_word(length=10):
    letters = string.ascii_lowercase + ' '
    return ''.join(random.choice(letters) for i in range(length))

# Function to save images with labels
def save_handwritten_text_images(output_dir, csv_file, num_samples=1000):
    # Generate random words
    random_words = [generate_random_word(10) for _ in range(num_samples)]

    # Create generator for handwritten text
    generator = GeneratorFromStrings(
        random_words,
        blur=0,  # No blur
        random_blur=False,
        distorsion_type=0,  # No distortion
        size=32,  # Font size
        language="en"  # Language set to English
    )

    labels = []
    fixed_width = 256
    fixed_height = 56

    for count, (img, lbl) in enumerate(tqdm(generator, total=num_samples, desc="Creating Datas")):
        if count >= num_samples:
          break
        img = img.convert("L")  # Convert image to grayscale
        # Resize the image to fixed dimensions
        img = img.resize((fixed_width, fixed_height), Image.ANTIALIAS)
        # Save image with related filename
        img_filename = os.path.join(output_dir, f"{count+1}.png")
        img.save(img_filename)
        # print(f"Saved {img_filename} with label {lbl}")
        labels.append((f"{count+1}.png", lbl))

    # Save labels to CSV
    with open(csv_file, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow( ["image_name", "label"])
        for img_name, label in labels:
            writer.writerow([img_name, label])

# Generate and save images
num_samples = 10000 ## USER 20 000
save_handwritten_text_images(output_dir, csv_file, num_samples=num_samples)

Creating Datas:   0%|          | 0/10000 [00:00<?, ?it/s]

rrrrrrrrrrr


In [11]:
# Require to be defined to obtain a data loader object
class CustomImageDataset(Dataset):
    def __init__(self, csv_file, img_dir, transform=None):
        self.img_labels = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_name = self.img_labels.iloc[idx, 0]
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path)
        label = self.img_labels.iloc[idx, 1]

        if self.transform:
            image = self.transform(image)

        return image, label

# apply corruptions to the preprocess

# -1 : nothing
# 0 : rotation
# 1 : affine tsf
# 2 Gaussian blur
# ...

## USER
tsf = -1 # type of corruption
p = 0.1 # intensity of corruption
##

float_to_odd_number = lambda float_value: (lambda n: n if n % 2 != 0 else n + 1 if n < float_value else n - 1)(int(round(float_value)))

corr_list = [transforms.RandomRotation(degrees=p*20), transforms.RandomAffine(degrees=p*20), transforms.GaussianBlur(float_to_odd_number(p*20))]

# compose transformations including the specified corruption
if tsf == -1:
  transform = transforms.Compose([
                                 transforms.ToTensor()
                                ])
else:
  transform = transforms.Compose([
                                 corr_list[tsf],
                                 transforms.ToTensor()
                                ])

# Create an instance of the custom dataset
csv_file = pathlib.Path('.')/'ocr_dataset/labels.csv' # Using pathlib to be compatible with all OS's
img_dir = pathlib.Path('.')/'ocr_dataset'

dataset = CustomImageDataset(csv_file=csv_file, img_dir=img_dir, transform=transform)

# Create a DataLoader
dataloader = DataLoader(dataset, batch_size=1000, shuffle=True)

# Iterate through the DataLoader
for images, labels in dataloader:
    print(images.size(), labels)

torch.Size([1000, 1, 56, 256]) ('nnlfzsnd j', 'cjuvsxfzbk', 'cwpx mkqru', 'emnojaypnx', ' mpagz fzn', 'cvknltpsno', 'txhtotypya', 'qqvrdyoilx', 'fkmtda mkh', 'vetccwkwnu', 'ofgbihpmbr', 'vtgqutgxsf', ' vfovixkgh', 'pqssupytj ', 'iovkxjfjno', 'veutbyfdij', 'zrvnvbiphd', 'urfdpfjmqc', 'xx jvt cer', 'jeoxhevvru', 'aycgcmdiou', 'lkdztqxnhl', 'xkpivrd  q', 'rxxwpkrzsn', 'dazgdkizdi', 'mppjcb jny', 'payhupefb ', 'stesmmydrm', 'uckrfjcpaw', 'ftv mpqaih', 'aqdkjgzufi', 'pdrrgjvwft', 'dl oztojrq', 'lhgbvvhmyz', 'jw vsvlaxr', 'dpqreagvkk', 'nbglyattcm', 'oa ozjixjy', 'kgirhgfkoc', 'behnjlkqun', 'caxqyfwanm', 'ulrzxs cot', 'q uwcqhkrj', 'qdjylwzhuu', 'tgutiucage', 'gelgrhnugl', 'mcqxkieewk', 'qagqeernpu', 'aupxwdtdbf', 'gq jejnwiu', 'kfcfixcnej', 'vflbojoiid', 'sqjnaqrglb', ' fdgocfvts', ' naiembjwe', 'hqlnfkpwgx', 'iutudmylna', 'kwpnhrqwwm', 'gff lkbylb', 'jjkhryiwxg', 'hgwgiynfpf', 'tptp vuoup', 'kalqrynavv', 'jqohcuqxlt', 'ydbsisglco', 'xzgyyafjfw', 'hwjidupwdi', 'uwsaxtj pd', 'dreyfekoyh', 'o

In [15]:
#  @title Extract the dataset with a zip file

from google.colab import drive
drive.mount('/content/drive')

%cd /content
!zip -r ocr_dataset.zip ocr_dataset
from google.colab import files
files.download('ocr_dataset.zip')

[1;30;43mLe flux de sortie a été tronqué et ne contient que les 5000 dernières lignes.[0m
  adding: ocr_dataset/7622.png (stored 0%)
  adding: ocr_dataset/6009.png (stored 0%)
  adding: ocr_dataset/6790.png (stored 0%)
  adding: ocr_dataset/1553.png (stored 0%)
  adding: ocr_dataset/5297.png (stored 0%)
  adding: ocr_dataset/8371.png (stored 0%)
  adding: ocr_dataset/6805.png (stored 0%)
  adding: ocr_dataset/5661.png (stored 0%)
  adding: ocr_dataset/1839.png (stored 0%)
  adding: ocr_dataset/8982.png (stored 0%)
  adding: ocr_dataset/4019.png (stored 0%)
  adding: ocr_dataset/6012.png (stored 0%)
  adding: ocr_dataset/6087.png (stored 0%)
  adding: ocr_dataset/6724.png (stored 0%)
  adding: ocr_dataset/6453.png (stored 0%)
  adding: ocr_dataset/1745.png (stored 0%)
  adding: ocr_dataset/1505.png (stored 0%)
  adding: ocr_dataset/7237.png (stored 0%)
  adding: ocr_dataset/4814.png (stored 0%)
  adding: ocr_dataset/4990.png (stored 0%)
  adding: ocr_dataset/2579.png (stored 0%)
  add

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>