## Import

In [None]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision
from torch.optim import *
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn.functional as F

from lion_pytorch import Lion

import math
import time
import numpy as np
from PIL import Image
import cv2
import numpy as np
import albumentations
import albumentations.pytorch
from matplotlib import pyplot as plt
import matplotlib.animation as animation
from matplotlib import font_manager, rc
from IPython import display
import random
import glob
import os
from os import listdir
from os.path import isfile, join
import warnings
import sys
from tqdm import tqdm
import pickle
import gc
import random
import urllib.request

# Print the installed torch / torchvision versions.
print("Version of Torch : {0}".format(torch.__version__))
print("Version of TorchVision : {0}".format(torchvision.__version__))

In [None]:
# Run a garbage-collection pass first, then ask PyTorch to release the
# CUDA memory it is caching but no longer using.
gc.collect()
torch.cuda.empty_cache()

## Hyperparameters

In [None]:
# Device
# Everything runs on the GPU when CUDA is available, otherwise on the CPU.
USE_CUDA = torch.cuda.is_available()

print("Device : {0}".format("GPU" if USE_CUDA else "CPU"))
device = torch.device("cuda" if USE_CUDA else "cpu")
# Explicit CPU handle, used to bring tensors back to the host for decoding and plotting.
cpu_device = torch.device("cpu")

# Train
# EPOCHS      : number of epochs executed by the training loop in this session.
# BATCH_SIZE  : batch size of both the train and the test DataLoader.
# START_EPOCH : number of the first epoch; reassigned by the checkpoint cell.
EPOCHS = 6
BATCH_SIZE = 20
START_EPOCH = 1

# Learning rate handed to the Lion optimizer.
lr = 0.0001

# IMAGE_SIZE : images are resized to IMAGE_SIZE x IMAGE_SIZE by the transform.
# MAX_LEN    : labels are padded with "<pad>" up to this many tokens; it is also
#              the output width of the model's last linear layer.
IMAGE_SIZE = 256
MAX_LEN = 10
# Directories whose *.png / *.jpg files form the dataset.
# Commented-out entries are currently disabled.
DATASET_PATH = [
#    "/kaggle/input/large-captcha-dataset/Large_Captcha_Dataset",
#    "/kaggle/input/captcha-dataset",
#    "/kaggle/input/comprasnet-captchas/comprasnet_imagensacerto",
#    "/kaggle/input/captcha-images",
#    "/kaggle/input/captcha-version-2-images/",
#    "/kaggle/input/new-captcha1000",
    "/kaggle/input/new-captcha-30000"
]

# File paths that are removed from every dataset after the directory scan.
BAN_DATA = [
]

# Seed applied to torch, numpy and random by random_seed().
RANDOM_SEED = 2004

# When USE_CHECKPOINT is true, model/optimizer state and the epoch counter are
# restored from CHECKPOINT_PATH before training continues.
USE_CHECKPOINT = True
CHECKPOINT_PATH = "/kaggle/input/captcha/pytorch/default/2/Checkpoint.pth"

In [None]:
def random_seed():
    """Seed every random number generator used here with RANDOM_SEED.

    Covers PyTorch (CPU and CUDA), NumPy and the ``random`` module, puts
    cuDNN into deterministic mode with benchmarking turned off, and prints
    the seed that was applied.
    """
    # The cuDNN switches do not depend on the seeding calls below.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(RANDOM_SEED)

    print('Random Seed : {0}'.format(RANDOM_SEED))


random_seed()

In [None]:
# Either start a fresh run at epoch 1 or resume from the stored checkpoint.
if not USE_CHECKPOINT:
    START_EPOCH = 1
    print("Training New Model")
else:
    # map_location puts the stored tensors on the currently selected device.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    # Continue with the epoch after the one recorded in the checkpoint.
    START_EPOCH = checkpoint["epoch"] + 1
    print("Loading Checkpoint [START_EPOCH : {0}]".format(START_EPOCH))

## Dataset

In [None]:
# Vocabulary layout: the padding token first, then the digits, the upper-case
# letters and the lower-case letters.
special_char_list = ["<pad>"]  # ["<unk>", "<pad>"]
num_list = list("0123456789")
upper_alphabet_list = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
lower_alphabet_list = list("abcdefghijklmnopqrstuvwxyz")

string_list = [
    *special_char_list,
    *num_list,
    *upper_alphabet_list,
    *lower_alphabet_list,
]
CHAR_NUM = len(string_list)

# index -> token, and the inverse token -> index lookup.
token_dictionary = dict(enumerate(string_list))
reversed_token_dictionary = {token: index for index, token in enumerate(string_list)}

print(CHAR_NUM)

In [None]:
def torch_tensor_to_plt(img):
    """Turn the first entry of a batched image tensor into a plottable array.

    Args:
        img: Tensor with a leading batch axis followed by three more axes;
            only entry 0 is used. It has to live on the CPU because it is
            converted with ``.numpy()``.

    Returns:
        NumPy array of entry 0 with its axes reordered from (0, 1, 2) to
        (1, 2, 0), i.e. the first axis is moved to the end.
    """
    first_image = img.detach().numpy()[0]
    # Move the leading axis to the back (channels-last when the input is CHW).
    return first_image.transpose(1, 2, 0)

In [None]:
# Preprocessing for every image: convert to a tensor, resize to a square of
# IMAGE_SIZE pixels, then normalise all three channels with mean 0.5 / std 0.5.
transformer = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [None]:
class ImageToTextDataset(Dataset):
    """Captcha images of one directory, labelled by their file names.

    The label of a file is the part of its base name that precedes the first
    ``.`` or ``_`` (whichever comes first), e.g. ``aB3d_1.png`` -> ``aB3d``.
    Each label character is mapped to its index in
    ``reversed_token_dictionary`` and the index list is padded with the
    ``<pad>`` index up to ``MAX_LEN`` entries.

    Args:
        path: Directory that directly contains the ``.png`` / ``.jpg`` files.
        transform: Optional callable applied to the RGB image array.
    """

    def __init__(self, path, transform=None):
        self.path = path
        self.transformer = transform

        # A set makes the ban-list check O(1) per file.
        banned = set(BAN_DATA)
        self.file = [
            file
            for file in glob.glob(join(self.path, '*'))
            if file.endswith((".png", ".jpg")) and file not in banned
        ]
        self.num = len(self.file)

    def __len__(self):
        """Return the number of image files found at construction time."""
        return self.num

    def transform(self, image):
        """Apply the configured transform, or return ``image`` unchanged."""
        if self.transformer is not None:
            return self.transformer(image)
        return image

    def __getitem__(self, idx):
        """Return ``(image, label_indices, label_indices)`` for sample ``idx``.

        Raises:
            KeyError: The file name contains a character outside the vocabulary.
            OSError: The image file could not be read.
        """
        filename = self.file[idx]

        # Label = base name up to the first "." and then up to the first "_".
        label = os.path.basename(filename).split(".")[0].split("_", 1)[0]
        try:
            Y = [reversed_token_dictionary[char] for char in label]
        except KeyError as err:
            raise KeyError(
                f"Unsupported character {err.args[0]!r} in file name: {filename}"
            ) from err

        # NOTE(review): labels longer than MAX_LEN are returned unpadded and
        # unclipped, so they will not batch together with regular samples.
        if len(Y) < MAX_LEN:
            Y += [reversed_token_dictionary["<pad>"]] * (MAX_LEN - len(Y))

        # cv2.imread signals failure by returning None instead of raising.
        img = cv2.imread(filename)
        if img is None:
            raise OSError(f"Could not read image file: {filename}")

        # Keep only the first 256 pixel columns and convert BGR -> RGB.
        sketch_image = cv2.cvtColor(img[:, :256, :], cv2.COLOR_BGR2RGB)
        X = self.transform(sketch_image)

        return X, torch.tensor(Y), torch.tensor(Y)

In [None]:
# One dataset per configured directory, merged into a single dataset.
dataset_list = [
    ImageToTextDataset(dataset_path, transform=transformer)
    for dataset_path in DATASET_PATH
]
dataset = torch.utils.data.ConcatDataset(dataset_list)

# One tenth of the samples (rounded down) is held out for testing.
test_size = len(dataset) // 10
train_size = len(dataset) - test_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])


train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)


print("Data ratio : {0}:{1}".format(train_size, test_size))
print("Train data size : {0}".format(len(train_dataset)))
print("Test data size: {0}".format(len(test_dataset)))

In [None]:
# Preview the first sample: the image, titled with its label indices.
x_val, _, target = dataset[0]
print(x_val.shape)
fig = plt.figure(figsize=(2, 2))
# torch_tensor_to_plt indexes entry 0 of a batch, hence the added leading axis.
plt.imshow(torch_tensor_to_plt(x_val.unsqueeze(0)), cmap='gray')
plt.axis('off')
plt.title(', '.join(map(str, target.tolist())))
plt.show()

## Model - LACC (LAbel Combination Classifier)

In [None]:
class LACC(nn.Module):
    """LAbel Combination Classifier.

    The feature stack of an EfficientNetV2-M is followed by a learned
    ``converter`` matrix of shape (64, CHAR_NUM) and a three-layer MLP
    (1280 -> 512 -> 64 -> MAX_LEN) with SiLU activations. ``forward``
    returns a tensor whose last two axes are (CHAR_NUM, MAX_LEN).
    """

    def __init__(self):
        super().__init__()
        # Only the `.features` part of the network is kept; the constructor
        # is called without a `weights` argument.
        self.encoder = torchvision.models.efficientnet_v2_m().features
        # Learned mixing matrix, initialised to all ones.
        self.converter = nn.Parameter(torch.ones(64, CHAR_NUM))

        self.silu = nn.SiLU()
        self.linear1 = nn.Linear(1280, 512)
        self.linear2 = nn.Linear(512, 64)
        self.linear3 = nn.Linear(64, MAX_LEN)

    def forward(self, x):
        """Return per-class, per-position scores for the image batch ``x``."""
        # Everything from axis 2 onwards is merged into one axis.
        # NOTE(review): assumes the encoder yields (batch, 1280, H, W) with
        # H * W == 64, as required by `converter` and `linear1` - confirm
        # before changing IMAGE_SIZE.
        spatial = torch.flatten(self.encoder(x), start_dim=2)
        # (..., 1280, 64) @ (64, CHAR_NUM), then swap the last two axes so the
        # linear layers act on the 1280-wide axis.
        per_class = torch.matmul(spatial, self.converter).transpose(-1, -2)

        hidden = self.silu(self.linear1(per_class))
        hidden = self.silu(self.linear2(hidden))
        return self.linear3(hidden)

In [None]:
# Build the network and move its parameters to the selected device.
model = LACC().to(device)

## Train

In [None]:
# Lion optimizer over all model parameters, using the configured learning rate.
optimizer = Lion(model.parameters(), lr=lr, weight_decay=1e-2)
# Loss used by calculate_loss() for both training and evaluation.
criterion = nn.CrossEntropyLoss()

In [None]:
def calculate_loss(predict, y):
    """Evaluate the module-level ``criterion`` for ``predict`` against ``y``."""
    loss = criterion(predict, y)
    return loss

In [None]:
def getSimilar(list1, list2):
    """Count the positions at which both sequences hold equal items.

    The sequences are walked in lockstep, so surplus items of the longer
    one are ignored.
    """
    return sum(1 for left, right in zip(list1, list2) if left == right)

def getCorrect(list1, list2):
    """Return 1 when both sequences match item by item, otherwise 0.

    Items are compared through their ``str()`` form, one position at a time.
    Comparing the concatenated strings instead would wrongly report
    ``[1, 23]`` and ``[12, 3]`` as equal, because both join to ``"123"``.
    """
    left = [str(item) for item in list1]
    right = [str(item) for item in list2]
    return 1 if left == right else 0

In [None]:
def evalSample(model, x, target, batch=0):
    """Predict one sample of a batch, plot it and return both texts.

    Args:
        model: Network whose output carries the class scores on axis -2.
        x: Batch of input images.
        target: Batch of label index sequences belonging to ``x``.
        batch: Index of the sample inside the batch that is evaluated.

    Returns:
        Tuple ``(predict_text, target_text)`` in which ``<pad>`` is shown as
        ``□`` and ``<unk>`` as ``?``.

    The model is left in eval mode.
    """
    def replaceSpecialToken(text):
        return text.replace('<pad>', '□').replace('<unk>', '?')

    def decode(tokens):
        # Map token indices back to characters and join them into one string.
        text = "".join(str(token_dictionary[token]) for token in tokens)
        return replaceSpecialToken(text)

    x, target = x.to(device), target.to(device)
    model.eval()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        predict = model(x[batch].unsqueeze(0))
        predict = F.log_softmax(predict, dim=-2)
        predict = torch.argmax(predict, dim=-2)

    predict_text = decode(predict[0].to(cpu_device).tolist())
    # Decode the label of the same sample that was predicted, not always sample 0.
    target_text = decode(target[batch].to(cpu_device).tolist())

    fig = plt.figure(figsize=(2, 2))
    plt.imshow(torch_tensor_to_plt(x.to(cpu_device)[batch].unsqueeze(0)), cmap='gray')
    plt.axis('off')
    plt.title(f"Answer of AI : {predict_text} [Real Answer : {target_text}]")
    plt.show()

    return predict_text, target_text

In [None]:
def train_one_epoch(model, optimizer, train_dataloader, test_dataloader, epoch=None):
    """Train for one pass over the training data, then evaluate on the test data.

    Args:
        model: Network whose output carries the class scores on axis -2.
        optimizer: Optimizer that updates ``model``.
        train_dataloader: Yields ``(image, label, label)`` training batches.
        test_dataloader: Yields ``(image, label, label)`` evaluation batches.
        epoch: Epoch number used in the summary line; when ``None`` no
            summary is printed.

    Returns:
        Tuple ``(train_loss, test_loss)`` holding the mean loss per batch.
    """
    train_loss = 0.0
    test_loss = 0.0
    start_time = time.time()

    # Training
    model.train()
    # `step` counts the batches processed so far, so `train_loss / step` is
    # the true running mean.
    for step, (x, y, _) in enumerate(train_dataloader, start=1):
        x, y = x.to(device), y.to(device)

        model.zero_grad()

        predict = F.log_softmax(model(x), dim=-2)
        loss = calculate_loss(predict, y)

        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        if step % 1000 == 0:
            print(f"i = {step}, time : {time.time() - start_time}, train loss = {train_loss/step}")
    train_loss /= max(len(train_dataloader), 1)

    # Testing
    model.eval()
    matched_chars = 0
    matched_labels = 0
    total_chars = 0
    total_labels = 0

    # Evaluation needs no gradients; this avoids building autograd graphs.
    with torch.no_grad():
        for x, y, target in test_dataloader:
            x, y, target = x.to(device), y.to(device), target.to(device)

            predict = F.log_softmax(model(x), dim=-2)
            test_loss += calculate_loss(predict, y).item()

            predict = torch.argmax(predict, dim=-2)

            # Count matches against the real number of samples, so a smaller
            # final batch does not distort the accuracy.
            hits = predict == target
            matched_chars += hits.sum().item()
            matched_labels += hits.all(dim=-1).sum().item()
            total_chars += target.numel()
            total_labels += target.shape[0]

    test_loss /= max(len(test_dataloader), 1)
    # Share of correctly predicted characters / of completely correct labels.
    accurate = matched_chars / total_chars if total_chars else 0.0
    hard_accurate = matched_labels / total_labels if total_labels else 0.0

    if epoch is not None:
        print(f"[Epoch {epoch}] Train Loss : {train_loss} & Test Loss : {test_loss} & Accurate : {accurate*100}% & Hard-Accurate : {hard_accurate*100}%")

    return train_loss, test_loss

In [None]:
# Restore the model weights and the optimizer state from the loaded checkpoint.
if USE_CHECKPOINT:
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Per-epoch loss history of this session.
train_loss_list = []
test_loss_list = []

# Epoch numbers start at START_EPOCH, so a resumed run keeps counting upwards.
for epoch in range(START_EPOCH, START_EPOCH+EPOCHS):
    train_loss, test_loss = train_one_epoch(model, optimizer, train_dataloader, test_dataloader, epoch=epoch)
    
    train_loss_list.append(train_loss)
    test_loss_list.append(test_loss)

In [None]:
# Plot the train/test loss of every epoch that ran in this session.
x = np.arange(START_EPOCH, START_EPOCH + EPOCHS)
plt.plot(x, np.array(train_loss_list), label='train')
plt.plot(x, np.array(test_loss_list), label='test')
# Limit the axis to the epochs that actually ran. A fixed [1, EPOCHS] range
# would cut off or hide the curves whenever START_EPOCH is greater than 1.
# Identical limits are skipped because they only trigger a matplotlib warning.
if EPOCHS > 1:
    plt.xlim([START_EPOCH, START_EPOCH + EPOCHS - 1])
plt.title("Loss of IRT")
plt.legend(loc='upper right')
plt.show()

In [None]:
# Show the prediction for the first sample of up to 11 test batches and print
# how long each evalSample call (inference plus plotting) took.
for ind, (x, _, y) in enumerate(test_dataloader):
    if ind > 10:
        break
    old_time = time.time()
    predict_text, target_text = evalSample(model, x, y)
    print(f"{time.time()-old_time}s")

## Save

In [None]:
# Store the number of the last finished epoch together with the model and
# optimizer state, so that a later run can resume from this point.
torch.save({
            'epoch': START_EPOCH+EPOCHS-1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, 'Checkpoint.pth')