In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import cv2
from PIL import Image
import torch.utils.model_zoo as model_zoo
import torch.onnx

In [2]:
import torch.nn.functional as F

class BRAIT_CNN(nn.Module):
  """Small convolutional classifier for 28x28 three-channel crops.

  Architecture (attribute names and layer order are kept stable because the
  ``state_dict`` keys of saved checkpoints are derived from them):

  * ``brait1``: Conv2d(3 -> 16, 5x5, padding 2) -> MaxPool2d(2) -> Dropout(0.01)
  * ``brait2``: Conv2d(16 -> 32, 5x5, padding 2) -> MaxPool2d(2) -> Dropout(0.01)
  * ``brait3``: Linear(32*7*7 -> 100) -> Linear(100 -> 26)

  The two 2x2 poolings reduce a 28x28 input to 7x7, which is what the first
  linear layer's ``32*7*7`` input size expects.
  """

  def __init__(self):
    super().__init__()
    self.brait1 = nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=2),
        nn.MaxPool2d(kernel_size=2),
        nn.Dropout(p=0.01))
    self.brait2 = nn.Sequential(
        nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
        nn.MaxPool2d(kernel_size=2),
        nn.Dropout(p=0.01))
    self.brait3 = nn.Sequential(
        nn.Linear(32*7*7, 100),
        nn.Linear(100, 26))

  def forward(self, x):
    """Compute class scores for a batch of images.

    Args:
      x: float tensor of shape (N, 3, 28, 28).

    Returns:
      Tensor of shape (N, 26) with non-negative scores (a ReLU is applied to
      the output of the last linear layer).

    Raises:
      RuntimeError: if the spatial size of ``x`` does not reduce to 7x7, since
        the flattened feature length then no longer matches ``brait3``.
    """
    y = F.relu(self.brait1(x))
    y = F.relu(self.brait2(y))

    # Flatten per sample. The previous `view(-1, 32*7*7)` would silently fold
    # surplus spatial positions into the batch dimension for wrongly sized
    # inputs (e.g. one 56x56 image came back as 4 rows); keeping the batch
    # dimension fixed turns that into an explicit shape error instead.
    y = y.flatten(start_dim=1)

    # NOTE(review): ReLU on the final scores clamps negatives to zero; it is
    # kept because existing weights were presumably trained with it - confirm.
    y = F.relu(self.brait3(y))

    return y

In [3]:
def export_model(weight_path=r"D:\BRAIT-ML\BRAIT-WEIGHT\BRAIT_PYTORCH.pth",
                 onnx_path="model.onnx-2"):
    """Load the trained BRAIT_CNN weights and export the network to ONNX.

    Args:
        weight_path: path of the PyTorch ``state_dict`` checkpoint to load.
            The default is a raw string so the backslashes of the Windows
            path are taken literally rather than as escape sequences.
        onnx_path: destination file for the exported ONNX model.

    Side effects:
        Reads ``weight_path`` and writes ``onnx_path``.
    """
    # The class defined in this notebook is BRAIT_CNN (the former `CNN()` call
    # referenced a name that does not exist here).
    model = BRAIT_CNN()
    # map_location keeps the load working on CPU-only machines, matching how
    # BRAIT_PREDICTION loads the same checkpoint.
    model.load_state_dict(torch.load(weight_path, map_location=torch.device('cpu')))
    # Export in inference mode so the Dropout layers act as identity.
    model.eval()

    # Example input used to trace the model.
    # NOTE(review): no dynamic axes are declared, so the exported graph may be
    # tied to this batch size of 5 - confirm against the consumer of the file.
    x = torch.randn(5, 3, 28, 28)

    # Public export API (replaces the private torch.onnx._export).
    torch.onnx.export(model,      # model being run
                      x,          # model input (or a tuple for multiple inputs)
                      onnx_path,  # where to save the model (file path or file-like object)
                      export_params=True)  # store the trained weights inside the model file

In [4]:
def inspect_model(onnx_path=r"D:\BRAIT-ML\BRAIT-WEIGHT\BRAIT.onnx", image_path="z.jpg"):
    """Run one image through the exported ONNX model and print the raw output.

    Args:
        onnx_path: path of the ONNX model file to load. The default is a raw
            string so the backslashes of the Windows path are taken literally.
        image_path: path of the image to classify.

    Side effects:
        Reads both files and prints the first output of the backend run.
    """
    # These packages are only needed by this helper and were never imported
    # anywhere in the notebook, so they are imported here at point of use.
    import onnx
    import onnx_caffe2.backend

    onnx_model = onnx.load(onnx_path)
    backend = onnx_caffe2.backend.prepare(onnx_model)

    # Force three channels so grayscale/palette/RGBA files still produce an
    # H x W x 3 array (same preprocessing as BRAIT_PREDICTION); the context
    # manager releases the file handle once the pixels are copied.
    with Image.open(image_path) as img:
        image = np.array(img.convert('RGB'))
    image = cv2.resize(image, (28, 28))
    image = image.astype(np.float32) / 255.0
    # H x W x C -> 1 x C x H x W
    batch = np.ascontiguousarray(image.transpose(2, 0, 1)[None, ...])

    # The graph (and therefore the input name) lives on the loaded ONNX model,
    # not on the prepared backend object.
    W = {onnx_model.graph.input[0].name: batch}
    model_out = backend.run(W)[0]
    print(model_out)

In [8]:
def BRAIT_PREDICTION(img_path, model=None):
    """Predict the characters in an image holding a horizontal row of cells.

    The image is cut into equal-width vertical strips, each strip is resized
    to 28x28 and classified, and the arg-max class index is mapped to a
    lowercase letter (0 -> 'a', 1 -> 'b', ...).

    Args:
        img_path: path of the image file to read.
        model: optional, already loaded network to reuse across calls. When
            omitted, a BRAIT_CNN is created and its weights are loaded from
            the default checkpoint path. The model is switched to eval mode.

    Returns:
        List of one-character strings, one per strip, left to right.

    Side effects:
        Prints the number of strips and the strip width in pixels.
    """
    # Load BRAIT_CNN and its weights unless the caller supplied a model.
    if model is None:
        model = BRAIT_CNN()
        model.load_state_dict(torch.load('D:\\BRAIT-ML\\BRAIT-WEIGHT\\BRAIT_PYTORCH.pth', map_location=torch.device('cpu')))
    # Inference mode: without this the Dropout layers stay active and the
    # prediction is not deterministic.
    model.eval()

    # Preprocess the input image: convert to RGB. The context manager closes
    # the file once the converted copy exists.
    with Image.open(img_path) as img:
        image = img.convert('RGB')

    # Segment the input image.
    width, height = image.size  # image size in pixels
    # The aspect ratio divided by 0.78 (the width/height ratio this code
    # assumes for one cell) gives the number of cells; clamp to at least one so
    # very narrow images do not cause a division by zero below.
    jumlah_segment = max(1, round(width/height/0.78))
    print(jumlah_segment)
    segment = width/jumlah_segment
    print(segment)

    tamp = []
    # No gradients are needed for prediction.
    with torch.no_grad():
        for i in range(jumlah_segment):
            cropped = image.crop((i*segment, 0, (i+1)*segment, height))
            cropped = np.array(cropped)
            cropped = cv2.resize(cropped, (28, 28))
            cropped = cropped.astype(np.float32) / 255.0
            cropped = torch.from_numpy(cropped[None, :, :, :])
            cropped = cropped.permute(0, 3, 1, 2)  # NHWC -> NCHW
            predicted_tensor = model(cropped)
            _, predicted_letter = torch.max(predicted_tensor, 1)
            # Convert the one-element index tensor to a plain int once,
            # instead of handing a tensor to chr().
            letter_index = int(predicted_letter)
            if letter_index == 26:
                # NOTE(review): BRAIT_CNN's last layer has 26 outputs, so the
                # index is 0..25 and this "space" branch cannot trigger with
                # that model; kept for models that add a 27th class.
                tamp.append(chr(32))
            else:
                tamp.append(chr(97 + letter_index))

    return tamp

# Run the recogniser over every sample image, in order, and print each
# predicted character list.
sample_paths = [
    r"D:\BRAIT-ML\BRAIT-SAMPLE\family.jpg",
    r"D:\BRAIT-ML\BRAIT-SAMPLE\home.jpg",
    r"D:\BRAIT-ML\BRAIT-SAMPLE\Prairie.jpg",
    r"D:\BRAIT-ML\BRAIT-SAMPLE\threw_the_ball.png",
    r"D:\BRAIT-ML\BRAIT-SAMPLE\would.png",
    r"D:\BRAIT-ML\BRAIT-SAMPLE\with_his_family.png",
    r"D:\BRAIT-ML\BRAIT-SAMPLE\the_little.png",
    r"D:\BRAIT-ML\BRAIT-SAMPLE\little_girl.png",
    r"D:\BRAIT-ML\dataset-BRAIT\dataset\test\i\P_20180609_101612_2_2_1.jpg",
]
for sample_path in sample_paths:
    print(BRAIT_PREDICTION(sample_path))

6
31.0
['f', 'a', 'm', 'i', 'l', 'y']
4
31.5
['h', 'r', 'm', 'e']
7
76.28571428571429
['p', 'r', 'a', 'i', 'r', 'i', 'e']
14
32.07142857142857
['t', 'h', 'r', 'e', 'w', 's', 't', 'a', 'e', 'c', 'a', 'a', 'l', 'l']
5
30.8
['w', 'o', 'u', 'l', 'd']
15
31.6
['w', 'i', 't', 'h', 'k', 'h', 'i', 's', 'c', 'f', 'a', 'm', 'i', 'l', 'y']
10
31.2
['t', 'h', 'e', 'a', 'l', 'i', 't', 't', 'l', 'd']
11
32.18181818181818
['l', 'i', 't', 't', 'l', 'e', 'z', 'g', 'i', 'a', 'l']
1
386.0
['p']
