In [1]:
import os
import glob 
import pandas as pd
import string
import collections
from tqdm import tqdm
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import DataLoader
import torch.optim as optim
from bnn_layer import BLinear, BConv2d, HLinear

In [2]:
data = glob.glob(os.path.join(r'F:\project\python\OCR\Recognize\data\samples', '*.png'))
path = r'F:\project\python\OCR\Recognize\data\samples'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
DEVICE

device(type='cuda')

In [4]:
all_letters = string.ascii_lowercase + string.digits
mapping = {}
mapping_inv = {}
i = 1
for x in all_letters:
    mapping[x] = i
    mapping_inv[i] = x
    i += 1

In [5]:
# type(all_letters)

In [6]:
num_class = len(mapping)

In [7]:
# mapping

In [8]:
images = []
labels = []
datas = collections.defaultdict(list)
for d in data:
    x = d.split('\\')[-1]

    datas['image'].append(x)

    datas['label'].append([mapping[i] for i in x.split('.')[0]])
df = pd.DataFrame(datas)
df.head()

Unnamed: 0,image,label
0,226md.png,"[29, 29, 33, 13, 4]"
1,22d5n.png,"[29, 29, 4, 32, 14]"
2,2356g.png,"[29, 30, 32, 33, 7]"
3,23mdg.png,"[29, 30, 13, 4, 7]"
4,23n88.png,"[29, 30, 14, 35, 35]"


In [9]:
# datas

In [10]:
df_train, df_test = train_test_split(df, test_size=0.2, shuffle=True)

In [11]:
class CaptchaDataset:
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        data = self.df.iloc[idx]
        image = Image.open(os.path.join(path, data['image'])).convert('L')
        label = torch.tensor(data['label'], dtype=torch.int32)
        
        if self.transform is not None:
            image = self.transform(image)
        
        return image, label
        
        
transform = T.Compose([
    T.ToTensor()
])

train_data = CaptchaDataset(df_train, transform)
test_data = CaptchaDataset(df_test, transform)

train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=8)
# for images, labels in train_loader:
#     print("Batch 图像张量形状:", images.shape)
#     print("Batch 标签张量形状:", labels.shape)
#     break  

In [12]:
class Bidirectional(nn.Module):
    def __init__(self, inp, hidden, out, lstm=True):
        super(Bidirectional, self).__init__()
        if lstm:
            self.rnn = nn.LSTM(inp, hidden, bidirectional=True)
        else:
            self.rnn = nn.GRU(inp, hidden, bidirectional=True)
        self.embedding = nn.Linear(hidden*2, out)
    def forward(self, X):
        recurrent, _ = self.rnn(X)  # [64, 16, 2048]
        out = self.embedding(recurrent)  # [64, 16, 37]
        return out
    
    
class CRNN(nn.Module):
    def __init__(self, in_channels, output):
        super(CRNN, self).__init__()

        self.cnn = nn.Sequential(  # [16, 1, 50, 200]
                nn.Conv2d(in_channels, 256, 9, stride=1, padding=1),  # [16, 256, 44, 194]
                nn.BatchNorm2d(256),
                nn.MaxPool2d(3, 3), # [16, 256, 14, 64]
                nn.Conv2d(256, 256, (4, 3), stride=1, padding=1),  # [16, 256, 13, 64]
                nn.BatchNorm2d(256))
        
        self.linear = nn.Linear(3328, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.rnn = Bidirectional(256, 1024, output+1)

    def forward(self, X, y=None, criterion = None):
        out = self.cnn(X)              # [16, 256, 13, 64] N C H W
        N, C, h, w = out.size()
        out = out.view(N, -1, w)       # [16, 3328, 64]    N C*H W
        out = out.permute(0, 2, 1)     # [16, 64, 3328]    N W C*H
        out = self.linear(out)         # [16, 64, 256]     N W K
        # out = out.permute(0, 2, 1)   # [16, 256, 64]     N K W
        # out = self.bn1(out)  
        # out = out.permute(0, 2, 1)   # [16, 64, 256]     N W K
        
        out = out.permute(1, 0, 2)     # [64, 16, 256]     W N K
        out = self.rnn(out)            # [64, 16, 37]      W N K
            
        if y is not None:
            T = out.size(0)
            N = out.size(1)
        
            input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.int32)
            target_lengths = torch.full(size=(N,), fill_value=5, dtype=torch.int32)
        
            loss = criterion(out, y, input_lengths, target_lengths)
            
            return out, loss
        
        return out, None

In [13]:
class Engine:
    def __init__(self, model, optimizer, criterion, epochs=50, early_stop=False, device='cpu'):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.epochs = epochs
        self.early_stop = early_stop
        self.device = device
        
    def fit(self, dataloader):
        hist_loss = []
        for epoch in range(self.epochs):
            self.model.train()
            tk = tqdm(dataloader, total=len(dataloader))
            for data, target in tk:
                data = data.to(device=self.device)
                target = target.to(device=self.device)

                self.optimizer.zero_grad()

                out, loss = self.model(data, target, criterion=self.criterion)

                loss.backward()

                self.optimizer.step()


                tk.set_postfix({'Epoch':epoch+1, 'Loss' : loss.item()})
                
    def evaluate(self, dataloader):
        self.model.eval()
        loss = 0
        hist_loss = []
        outs = collections.defaultdict(list)
        tk = tqdm(dataloader, total=len(dataloader))
        with torch.no_grad():
            for data, target in tk:
                data = data.to(device=self.device)
                target = target.to(device=self.device)

                out, loss = self.model(data, target, criterion=self.criterion)
               
                target=target.cpu().detach().numpy()
                outs['pred'].append(out)
                outs['target'].append(target)


                hist_loss.append(loss)

                tk.set_postfix({'Loss':loss.item()})
                
        return outs, hist_loss
    
    def predict(self, image):
        image = Image.open(image).convert('L')
        image_tensor = T.ToTensor()(image)
        image_tensor = image_tensor.unsqueeze(0)        
        out, _ = self.model(image_tensor.to(device=self.device))
        out = out.permute(1, 0, 2)  # [16, 64, 37]
        out = out.log_softmax(2)
        out = out.argmax(2)
        out = out.cpu().detach().numpy()
        
        return out
        


In [14]:
model = CRNN(in_channels=1, output=num_class).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CTCLoss()

engine = Engine(model, optimizer, criterion, device=DEVICE)


In [15]:
engine.fit(train_loader)
outs, loss = engine.evaluate(test_loader)

  2%|█▏                                                            | 1/52 [00:00<00:21,  2.42it/s, Epoch=1, Loss=-5.34]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


  8%|████▊                                                         | 4/52 [00:00<00:05,  8.60it/s, Epoch=1, Loss=-3.05]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 19%|███████████▎                                               | 10/52 [00:00<00:02, 16.32it/s, Epoch=1, Loss=0.00557]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 31%|██████████████████▍                                         | 16/52 [00:01<00:01, 20.56it/s, Epoch=1, Loss=-0.275]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 42%|██████████████████████████▏                                   | 22/52 [00:01<00:01, 22.96it/s, Epoch=1, Loss=2.45]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 48%|█████████████████████████████▊                                | 25/52 [00:01<00:01, 23.39it/s, Epoch=1, Loss=3.31]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 60%|████████████████████████████████████▉                         | 31/52 [00:01<00:00, 23.78it/s, Epoch=1, Loss=2.77]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 65%|████████████████████████████████████████▌                     | 34/52 [00:01<00:00, 23.96it/s, Epoch=1, Loss=3.77]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 77%|███████████████████████████████████████████████▋              | 40/52 [00:02<00:00, 23.37it/s, Epoch=1, Loss=4.24]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 88%|███████████████████████████████████████████████████████▋       | 46/52 [00:02<00:00, 23.98it/s, Epoch=1, Loss=3.8]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 94%|██████████████████████████████████████████████████████████▍   | 49/52 [00:02<00:00, 24.04it/s, Epoch=1, Loss=3.78]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


100%|██████████████████████████████████████████████████████████████| 52/52 [00:02<00:00, 20.51it/s, Epoch=1, Loss=3.78]
 12%|███████▎                                                       | 6/52 [00:00<00:01, 27.36it/s, Epoch=2, Loss=3.87]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 17%|██████████▉                                                    | 9/52 [00:00<00:01, 27.21it/s, Epoch=2, Loss=3.88]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 29%|█████████████████▉                                            | 15/52 [00:00<00:01, 27.40it/s, Epoch=2, Loss=3.81]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 40%|█████████████████████████                                     | 21/52 [00:00<00:01, 27.39it/s, Epoch=2, Loss=3.71]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 52%|████████████████████████████████▏                             | 27/52 [00:01<00:00, 25.75it/s, Epoch=2, Loss=3.55]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 63%|███████████████████████████████████████▎                      | 33/52 [00:01<00:00, 26.54it/s, Epoch=2, Loss=3.45]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 69%|██████████████████████████████████████████▉                   | 36/52 [00:01<00:00, 26.62it/s, Epoch=2, Loss=3.37]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 81%|██████████████████████████████████████████████████            | 42/52 [00:01<00:00, 26.95it/s, Epoch=2, Loss=3.35]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 92%|█████████████████████████████████████████████████████████▏    | 48/52 [00:01<00:00, 27.17it/s, Epoch=2, Loss=3.32]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


100%|██████████████████████████████████████████████████████████████| 52/52 [00:01<00:00, 26.96it/s, Epoch=2, Loss=3.38]


torch.Size([64, 16, 2048])


  6%|███▋                                                           | 3/52 [00:00<00:01, 27.03it/s, Epoch=3, Loss=3.35]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


  6%|███▋                                                           | 3/52 [00:00<00:01, 27.03it/s, Epoch=3, Loss=3.26]

torch.Size([64, 16, 2048])


 17%|██████████▉                                                    | 9/52 [00:00<00:01, 25.66it/s, Epoch=3, Loss=3.33]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 17%|██████████▉                                                    | 9/52 [00:00<00:01, 25.66it/s, Epoch=3, Loss=3.28]

torch.Size([64, 16, 2048])


 23%|██████████████▎                                               | 12/52 [00:00<00:01, 26.09it/s, Epoch=3, Loss=3.31]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 29%|█████████████████▉                                            | 15/52 [00:00<00:01, 26.67it/s, Epoch=3, Loss=3.31]

torch.Size([64, 16, 2048])


 35%|█████████████████████▊                                         | 18/52 [00:00<00:01, 27.04it/s, Epoch=3, Loss=3.3]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 40%|█████████████████████████                                     | 21/52 [00:00<00:01, 27.28it/s, Epoch=3, Loss=3.31]

torch.Size([64, 16, 2048])


 46%|████████████████████████████▌                                 | 24/52 [00:00<00:01, 27.51it/s, Epoch=3, Loss=3.25]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 52%|████████████████████████████████▏                             | 27/52 [00:01<00:00, 27.36it/s, Epoch=3, Loss=3.34]

torch.Size([64, 16, 2048])


 58%|███████████████████████████████████▊                          | 30/52 [00:01<00:00, 27.26it/s, Epoch=3, Loss=3.29]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 75%|██████████████████████████████████████████████▌               | 39/52 [00:01<00:00, 27.23it/s, Epoch=3, Loss=3.29]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 81%|██████████████████████████████████████████████████            | 42/52 [00:01<00:00, 27.17it/s, Epoch=3, Loss=3.27]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 92%|█████████████████████████████████████████████████████████▏    | 48/52 [00:01<00:00, 27.17it/s, Epoch=3, Loss=3.25]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


100%|██████████████████████████████████████████████████████████████| 52/52 [00:01<00:00, 27.03it/s, Epoch=3, Loss=3.24]


torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


  6%|███▋                                                           | 3/52 [00:00<00:01, 27.52it/s, Epoch=4, Loss=3.25]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


  6%|███▋                                                           | 3/52 [00:00<00:01, 27.52it/s, Epoch=4, Loss=3.26]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 12%|███████▎                                                       | 6/52 [00:00<00:01, 27.08it/s, Epoch=4, Loss=3.25]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 17%|███████████                                                     | 9/52 [00:00<00:01, 27.28it/s, Epoch=4, Loss=3.2]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])


 25%|███████████████▌                                              | 13/52 [00:00<00:01, 26.42it/s, Epoch=4, Loss=3.26]

torch.Size([64, 16, 2048])
torch.Size([64, 16, 2048])





KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
import numpy as np

ids = np.random.randint(len(data))

image = data[ids]
out = engine.predict(image)[0]

def show_prediction(out, image):
    gt = image.split('/')[-1].split('.')[0]
    imagePIL = Image.open(image).convert('L')

    pred = ''
    then = 0
    for x in out:
        if then != x:
            if x > 0 :
                pred += mapping_inv[x]
        then = x

    plt.figure(figsize=(15, 12))
    img_array = np.asarray(imagePIL)
    plt.title(f'Ground Truth - {gt} || Prediction - {pred}')
    plt.axis('off')
    plt.imshow(img_array)
    
show_prediction(out, image)

In [None]:
out

In [None]:
saving = {'state_dict':engine.model.state_dict(),
          'optimizer':engine.optimizer.state_dict(),
         'mapping':mapping,
         'mapping_inv':mapping_inv}
torch.save(saving, './model/BNN_english_model.pth')