In [None]:
import os
import cv2
import torch
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
import albumentations
import matplotlib.pyplot as plt

from tqdm import tqdm
from train import train
from source.network import ConvRNN
from utils import load_obj, save_obj
from source.dataset import TRSynthDataset
from sklearn.model_selection import train_test_split

%matplotlib inline

In [None]:
import importlib.util

# Raw string: in a normal literal the "\t" of "\textdetectioninimage" is read
# as a TAB character, so the path never matched a real file.
CONFIG_PATH = r'D:\Megha\PPro\textdetectioninimage\modular_code (1)\modular_code\src\config.py'

# config.py is Python source, not YAML (and `yaml` was never imported), so it
# is loaded as a module. The following cell reads it by attribute
# (`config.epochs`, ...), which matches a module object.
_config_spec = importlib.util.spec_from_file_location("config", CONFIG_PATH)
config = importlib.util.module_from_spec(_config_spec)
_config_spec.loader.exec_module(config)

# NOTE(review): extn / log_path / mask_path are not used anywhere else in this
# notebook and look like leftovers from another project; they fall back to
# None when config.py does not define them -- confirm and remove if unused.
extn = getattr(config, "extn", None)
epochs = getattr(config, "epochs", None)
log_path = getattr(config, "log_path", None)
mask_path = getattr(config, "mask_path", None)
image_path = getattr(config, "image_path", None)
model_path = getattr(config, "model_path", None)

In [None]:
# Copy the run settings off the config object into notebook-level names.

# Training hyper-parameters
epochs, batch_size = config.epochs, config.batch_size

# Input data locations
image_path, label_path = config.image_path, config.label_path

# Artefacts written / read by this notebook
model_path = config.model_path
data_file_path = config.data_file_path
char2int_path, int2char_path = config.char2int_path, config.int2char_path

### Inspect some images

In [None]:
# Load one sample from the image directory and render it inline.
sample_path = os.path.join(image_path, "00000017.jpg")
img = plt.imread(sample_path)
plt.imshow(img)
plt.show()

### Check corresponding label

In [None]:
# Read the label file as plain strings, one label per row (no header row).
# keep_default_na=False: by default pandas converts tokens such as "null",
#   "NULL", "NA" or "nan" into NaN, which silently destroys labels that are
#   legitimately those words.
# dtype=str: keeps digit-only labels as text so every label can be iterated
#   character by character when the character set is built.
labels = pd.read_table(label_path, header=None, dtype=str, keep_default_na=False)

In [None]:
# Preview the first rows of the label table.
labels.head()

In [None]:
# Label stored at row position 17, column 0.
# NOTE(review): presumably this row belongs to "00000017.jpg" shown above,
# i.e. file numbering is zero-based and matches row order -- confirm.
labels.iloc[17,0]

In [None]:
# (rows, columns) of the label table.
labels.shape

### Total number of images

In [None]:
# Number of entries in the image directory. os.listdir returns every entry,
# so stray non-image files would be counted as well.
len(os.listdir(image_path))

### Find null values in labels

In [None]:
# Count of missing values per column of the label table.
labels.isna().sum()

In [None]:
# Rows whose label (column 0) is missing; the index shows which rows they are.
labels[labels[0].isna()]

### Let's check those images

In [None]:
# Render the first image whose label was reported as missing.
missing_label_path = os.path.join(image_path, "00019198.jpg")
img = plt.imread(missing_label_path)
plt.imshow(img)
plt.show()

In [None]:
# Render the second image whose label was reported as missing.
missing_label_path = os.path.join(image_path, "00074347.jpg")
img = plt.imread(missing_label_path)
plt.imshow(img)
plt.show()

### Replace those missing values with the string "null"

In [None]:
# Substitute the literal word "null" wherever a label is missing.
labels = labels.fillna("null")

### Create a dataframe with image paths and corresponding labels

In [None]:
# Full paths of every entry in the image directory, ordered by file name.
image_files = [
    os.path.join(image_path, file_name)
    for file_name in sorted(os.listdir(image_path))
]

In [None]:
# Pair each image path with a label by position and persist the table.
# NOTE(review): this assumes the sorted file names line up one-to-one with the
# rows of the label file -- confirm.
paired_columns = {"images": image_files, "labels": labels[0]}
data_file = pd.DataFrame(data=paired_columns)
data_file.to_csv(path_or_buf=data_file_path, index=False)

In [None]:
# Preview the first image-path / label pairs.
data_file.head()

In [None]:
# Spot-check one image against the label paired with it in the table above.
spot_check_path = os.path.join(image_path, "00000004.jpg")
img = plt.imread(spot_check_path)
plt.imshow(img)
plt.show()

### Find the unique characters in the labels

In [None]:
# Every distinct character that occurs in any label, in sorted order.
unique_chars = sorted({ch for word in labels[0] for ch in word})

In [None]:
# Display the sorted list of distinct characters.
unique_chars

In [None]:
# Size of the character set.
len(unique_chars)

### Create mappings from characters to integers and from integers to characters, and save them to disk

In [None]:
# Number the characters from 1 upwards (0 is left unassigned), then invert the
# table to obtain the reverse lookup.
char2int = {ch: idx for idx, ch in enumerate(unique_chars, start=1)}
int2char = {idx: ch for ch, idx in char2int.items()}

In [None]:
# Persist both lookup tables with the project's save_obj helper so they can be
# reused outside this notebook (presumably pickled -- confirm in utils).
save_obj(char2int, char2int_path)
save_obj(int2char, int2char_path)

# Training the model

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

### Split the data into train and validation

In [None]:
# Hold out 20% of the rows for validation.
# random_state pins the shuffle so a Restart & Run All reproduces the same
# split; without it the validation set changes on every run and validation
# losses are not comparable across runs.
train_file, valid_file = train_test_split(data_file, test_size=0.2, random_state=42)

### Create train and validation datasets

In [None]:
# Wrap each split in the project's dataset class, sharing the same
# character-to-integer table.
train_dataset, valid_dataset = (
    TRSynthDataset(split, char2int) for split in (train_file, valid_file)
)

### Define the loss function

In [None]:
# CTC loss summed over the batch, placed on the chosen device.
criterion = nn.CTCLoss(reduction="sum").to(device)
# Trailing expression so the cell still displays the loss module.
criterion

### Number of classes

In [None]:
# Number of distinct characters in the mapping.
# NOTE(review): this count does not include index 0, which char2int leaves
# unassigned; presumably ConvRNN adds the extra output for it -- confirm in
# source.network.
n_classes = len(char2int)

### Create the model object

In [None]:
# Build the recognition network and place it on the chosen device.
model = ConvRNN(n_classes).to(device)
# Trailing expression so the cell still displays the architecture.
model

### Define Optimizer

In [None]:
# RMSprop over all model parameters with a learning rate of 5e-4.
optimizer = torch.optim.RMSprop(params=model.parameters(), lr=5e-4)

### Define train and validation data loaders

In [None]:
DataLoader = torch.utils.data.DataLoader

# Training batches are reshuffled every epoch and the final partial batch is
# dropped.
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True
)

# Validation keeps the original order and every sample.
valid_loader = DataLoader(
    valid_dataset, batch_size=batch_size, shuffle=False, drop_last=False
)

### Training loop

In [None]:
# Best validation loss seen so far. It must exist before the first comparison
# below; without this initialisation the loop raised NameError on a fresh
# kernel.
best_loss = float("inf")

for i in range(epochs):
    print(f"Epoch {i+1} of {epochs}...")
    # One pass over the training data (updates the weights via the optimizer)
    train_loss = train(model, train_loader, criterion, device, optimizer, test=False)
    # One pass over the validation data.
    # NOTE(review): no optimizer is passed here; presumably train() does not
    # need one when test=True -- confirm against train.py.
    valid_loss = train(model, valid_loader, criterion, device, test=True)
    print(f"Train Loss: {round(train_loss,4)}, Valid Loss: {round(valid_loss,4)}")
    # Checkpoint only when validation improves, so model_path always holds the
    # weights of the best epoch so far.
    if valid_loss < best_loss:
        print("Validation Loss improved, saving Model File...")
        torch.save(model.state_dict(), model_path)
        best_loss = valid_loss

### Load the trained model

In [None]:
# Rebuild the network and restore the best checkpoint.
model = ConvRNN(n_classes)
# map_location=device loads the weights straight onto the device used for
# inference, whichever device they were saved from.
model.load_state_dict(torch.load(model_path, map_location=device))

# The input image is later moved to `device`; the model must live there too,
# otherwise the forward pass fails with a device mismatch on CUDA machines.
model.to(device)

# Set model mode to evaluation
model.eval()

# Prediction

### Pick a test image

In [None]:
test_img = os.path.join(image_path, "00000017.jpg")
img = cv2.imread(test_img)
# cv2.imread does not raise on a missing or unreadable file, it returns None.
# Fail here with a clear message instead of deep inside the preprocessing.
if img is None:
    raise FileNotFoundError(f"Could not read image: {test_img}")
# NOTE(review): cv2.imread yields channels in BGR order; confirm that
# TRSynthDataset feeds the network the same channel order during training.

In [None]:
# Render the chosen test image.
preview = plt.imread(test_img)
plt.imshow(preview)
plt.show()

### Normalize the image and convert it to a batched tensor

In [None]:
# Per-channel normalisation statistics.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)
# p=1.0 applies the transform unconditionally. It replaces `always_apply=True`,
# which Albumentations has deprecated, and has the same effect on versions
# that still accept the old flag.
img_aug = albumentations.Compose(
        [albumentations.Normalize(mean=mean, std=std,
                                  max_pixel_value=255.0,
                                  p=1.0)]
    )
augmented = img_aug(image=img)
img = augmented["image"]
# Move the last axis to the front (channels-first layout)
img = img.transpose(2, 0, 1)
img = torch.from_numpy(img)
# Create batch dimension (batch of single image)
img = torch.unsqueeze(img, 0)
# Move the image tensor to the selected device (CUDA if available)
img = img.to(device)
# NOTE: this cell overwrites `img` (image array -> tensor) and is therefore not
# re-runnable on its own; re-run the image-loading cell first.

### Take model output

In [None]:
# Inference only: disable autograd so no computation graph is built or kept
# in memory for this forward pass.
with torch.no_grad():
    out = model(img)

### Apply softmax and take label predictions

In [None]:
# Drop the leading (batch) axis, turn the scores along axis 1 into
# probabilities, and keep the most probable class index per row.
# NOTE(review): assumes the model output is (1, steps, classes) -- confirm.
out = out.squeeze(0).softmax(dim=1)
pred = out.argmax(dim=1)

In [None]:
# Shape of the predicted index tensor (presumably one entry per time step).
pred.shape

In [None]:
# Convert the tensor of class indices into a plain Python list.
# NOTE: not re-runnable on its own -- after the first run `pred` is a list,
# which has no .tolist().
pred = pred.tolist()

In [None]:
# Display the raw sequence of predicted class indices.
pred

In [None]:
# Exploratory lookup of the character stored under index 75.
# Raises KeyError if the mapping has no key 75.
int2char[75]

### Use 'ph' as a placeholder for index 0 (the CTC blank)

In [None]:
# int2char was built with keys starting at 1, so index 0 has no entry yet.
# Register the two-letter token "ph" for it so decoding the predictions does
# not raise KeyError. The loss above does not override `blank`, and PyTorch's
# default CTC blank index is 0.
# Note: this mutates the in-memory mapping only; the copy saved to disk
# earlier has no key 0.
int2char[0] = "ph"

### Convert integer predictions to string

In [None]:
# Translate every predicted index into its character (or the placeholder).
out = list(map(int2char.__getitem__, pred))

In [None]:
# Display the decoded per-step symbols before collapsing.
out

### Collapse the output: merge repeated symbols and drop the placeholder

In [None]:
from itertools import groupby

# Greedy decoding of the per-step symbols:
#   1. merge each run of identical consecutive symbols into one (groupby
#      yields one key per run),
#   2. drop the "ph" placeholder,
#   3. join the remaining characters into the predicted word.
# Unlike the manual index loop this also handles an empty sequence, which
# yields "" instead of raising IndexError.
res = "".join(symbol for symbol, _run in groupby(out) if symbol != "ph")

In [None]:
# Show the final predicted text for the test image.
print(res)