## Import

In [2]:
import random
import pandas as pd
import numpy as np
import os
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision.models import resnet18, mobilenet_v2
from torchvision import transforms

from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

import warnings
warnings.filterwarnings(action='ignore') 

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [5]:
# device = torch.device("cuda:0")

## Hyperparameter Setting

In [6]:
CFG = {
    'IMG_HEIGHT_SIZE':64,
    'IMG_WIDTH_SIZE':224,
    'EPOCHS':80,#80,
    'LEARNING_RATE':1e-3,
    'BATCH_SIZE':256,
    'NUM_WORKERS':4, # 본인의 GPU, CPU 환경에 맞게 설정
    'SEED':41
}

## Fixed RandomSeed

In [7]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(CFG['SEED']) # Seed 고정

## Data Load & Train/Validation Split

In [8]:
df = pd.read_csv('./train.csv')

In [9]:
# 제공된 학습데이터 중 1글자 샘플들의 단어사전이 학습/테스트 데이터의 모든 글자를 담고 있으므로 학습 데이터로 우선 배치
df['len'] = df['label'].str.len()
train_v1 = df[df['len']==1]

In [10]:
# 제공된 학습데이터 중 2글자 이상의 샘플들에 대해서 단어길이를 고려하여 Train (80%) / Validation (20%) 분할
df = df[df['len']>1]
train_v2, val, _, _ = train_test_split(df, df['len'], test_size=0.2, random_state=CFG['SEED'])

In [12]:
# 학습 데이터로 우선 배치한 1글자 샘플들과 분할된 2글자 이상의 학습 샘플을 concat하여 최종 학습 데이터로 사용
train = pd.concat([train_v1, train_v2])
print(len(train), len(val))

141440 29435


## Get Vocabulary

In [13]:
# 학습 데이터로부터 단어 사전(Vocabulary) 구축
train_gt = [gt for gt in train['label']]
train_gt = "".join(train_gt)
letters = sorted(list(set(list(train_gt))))
print(len(letters))

2349


In [14]:
vocabulary = ["-"] + letters
print(len(vocabulary))
idx2char = {k:v for k,v in enumerate(vocabulary, start=0)}
char2idx = {v:k for k,v in idx2char.items()}

2350


## CustomDataset

In [15]:
class CustomDataset(Dataset):
    def __init__(self, img_path_list, label_list, train_mode=True):
        self.img_path_list = img_path_list
        self.label_list = label_list
        self.train_mode = train_mode
        
    def __len__(self):
        return len(self.img_path_list)
    
    def __getitem__(self, index):
        image = Image.open(self.img_path_list[index]).convert('RGB')
        
        if self.train_mode:
            image = self.train_transform(image)
        else:
            image = self.test_transform(image)
            
        if self.label_list is not None:
            text = self.label_list[index]
            return image, text
        else:
            return image
    
    # Image Augmentation
    def train_transform(self, image):
        transform_ops = transforms.Compose([
            transforms.Resize((CFG['IMG_HEIGHT_SIZE'],CFG['IMG_WIDTH_SIZE'])),
            transforms.ToTensor(),
            # Guassian
            transforms.GaussianBlur(kernel_size=(1,1),sigma=(1.25, 1.95)),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        ])
        return transform_ops(image)
    
    def test_transform(self, image):
        transform_ops = transforms.Compose([
            transforms.Resize((CFG['IMG_HEIGHT_SIZE'],CFG['IMG_WIDTH_SIZE'])),
            transforms.ToTensor(),
            # Guassian
            transforms.GaussianBlur(kernel_size=(1,1),sigma=(1.25, 1.95)),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
        ])
        return transform_ops(image)

In [16]:
train_dataset = CustomDataset(train['img_path'].values, train['label'].values, True)
train_loader = DataLoader(train_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=True, num_workers=CFG['NUM_WORKERS'])

val_dataset = CustomDataset(val['img_path'].values, val['label'].values, False)
val_loader = DataLoader(val_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=True, num_workers=CFG['NUM_WORKERS'])

In [17]:
torch.Size()
# minibatchsize, channel size, img size(height), img size(width)

torch.Size([])

In [18]:
image_batch, text_batch = iter(train_loader).next()
print(image_batch.size(), text_batch)

torch.Size([256, 3, 64, 224]) ('뮤', '신화가라앉다', '싱싱하다여군', '작년그곳', '윅', '롭유리창', '발톱무', '최선쿄', '촨남부', '담힘껏', '어질문하다', '인하', '소유자', '부장', '왼쪽', '포스터등', '모범쩡', '국물있다', '적다그려지다', '걷다', '익', '타고나다후기', '수학', '현대인맵', '치우다', '신세재정', '무용전기밥솥', '반장', '김양말', '음력', '앞두다조정', '발생', '흰레이저', '구르다깁', '두', '설득하다지도', '입술고개', '기', '대전며느리', '밝혀지다', '소', '급증하다금', '오', '독일', '장면', '침기르다', '뇌', '썬', '조미료가장', '수학', '닷새', '종합컸', '활동형제', '의도건설', '고향한', '짐작하다', '전기뒬', '뇽', '빠지다', '파', '추석', '쬠', '앍', '뺏', '폭력', '앞길', '막걸리체계적', '소', '기초적전문적', '끝내다', '몇십', '경찰서', '자', '만점그다음', '잘되다절', '아드님', '살', '는', '측', '강도', '차마', '적성', '경영하다소질', '킬로그램', '중단하다저편', '예정되다', '맡다같다', '있피우다', '전개되다거기', '삠지방', '뻠입사하다', '국회의원폼', '몲', '과연내부', '환영밀접하다', '숩', '끝내다', '척하다일대', '베개막걸리', '칭찬대학교수', '경주아무래도', '복잡하다명', '미치다차츰', '물고기', '하나하나', '여름철백', '방송영역', '정답부딪치다', '불확실하다향', '품목노래하다', '꾸리다빼', '납실내', '시장', '사월신기하다', '분중심', '걱정어떡하다', '듯', '잠들다본질', '마요네즈성적', '능력레이저', '가뭄홀로', '발톱줄', '파괴하다튀다', '편', '도장', '시끄럽다출산', '하', '잠기다한둘', '약', '진짜다르다', '일반충돌하다

In [19]:
len(text_batch)

256

## Model Define

In [20]:
class RecognitionModel(nn.Module):
    def __init__(self, num_chars=len(char2idx), rnn_hidden_size=256):
        super(RecognitionModel, self).__init__()
        self.num_chars = num_chars
        self.rnn_hidden_size = rnn_hidden_size
        
        # CNN Backbone = 사전학습된 resnet18 활용
        # https://arxiv.org/abs/1512.03385
        resnet = resnet18(pretrained=True)
        # CNN Feature Extract
        resnet_modules = list(resnet.children())[:-3]
        self.feature_extract = nn.Sequential(
            *resnet_modules,
            nn.Conv2d(256, 256, kernel_size=(3,6), stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        self.linear1 = nn.Linear(1024, rnn_hidden_size)
        
        # RNN
        self.rnn = nn.RNN(input_size=rnn_hidden_size, 
                            hidden_size=rnn_hidden_size,
                            bidirectional=True, 
                            batch_first=True)
        self.linear2 = nn.Linear(self.rnn_hidden_size*2, num_chars)
        
        
    def forward(self, x):
        # CNN
        x = self.feature_extract(x) # [batch_size, channels, height, width]
        x = x.permute(0, 3, 1, 2) # [batch_size, width, channels, height]
         
        batch_size = x.size(0)
        T = x.size(1)
        x = x.view(batch_size, T, -1) # [batch_size, T==width, num_features==channels*height]
        x = self.linear1(x)
        
        # RNN
        x, hidden = self.rnn(x)
        
        output = self.linear2(x)
        output = output.permute(1, 0, 2) # [T==10, batch_size, num_classes==num_features]
        
        return output

In [21]:
# LSTM
from transformers import ViTFeatureExtractor
class RecognitionModel(nn.Module):
    def __init__(self, num_chars=len(char2idx), lstm_hidden_size=256):
        super(RecognitionModel, self).__init__()
        self.num_chars = num_chars
        self.lstm_hidden_size = lstm_hidden_size
        
        # CNN Backbone = 사전학습된 mobilenet_v2 활용
        mobilenet = mobilenet_v2(pretrained=True)
        
        # CNN Feature Extract
        resnet_modules = list(mobilenet.children())[:-3]
        self.feature_extract = nn.Sequential(
            *resnet_modules,
            nn.Conv2d(3, 128, kernel_size=(3,6), stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )

        self.linear1 = nn.Linear(8192, lstm_hidden_size)

        #----------------------------------------------#
        
        self.lstm = nn.LSTM(input_size=lstm_hidden_size, 
                            hidden_size=lstm_hidden_size,
                            num_layers = 2,
                            dropout = 0.2,
                            bidirectional=True, 
                            batch_first=True)
        self.linear2 = nn.Linear(self.lstm_hidden_size*2, num_chars)
        #----------------------------------------------#
        
        
        
    def forward(self, x):
        # CNN
        x = self.feature_extract(x) # [batch_size, channels, height, width]
        x = x.permute(0, 3, 1, 2) # [batch_size, width, channels, height]
         
        batch_size = x.size(0)
        T = x.size(1)
        x = x.view(batch_size, T, -1) # [batch_size, T==width, num_features==channels*height]
        x = self.linear1(x)
        
        x, hidden = self.lstm(x)
        
        output = self.linear2(x)
        output = output.permute(1, 0, 2) # [T==10, batch_size, num_classes==num_features]
        
        return output

## Define CTC Loss

In [22]:
criterion = nn.CTCLoss(blank=0) # idx 0 : '-'

In [23]:
def encode_text_batch(text_batch):
    text_batch_targets_lens = [len(text) for text in text_batch]
    text_batch_targets_lens = torch.IntTensor(text_batch_targets_lens)
    
    text_batch_concat = "".join(text_batch)
    text_batch_targets = [char2idx[c] for c in text_batch_concat]
    text_batch_targets = torch.IntTensor(text_batch_targets)
    
    return text_batch_targets, text_batch_targets_lens

In [24]:
def compute_loss(text_batch, text_batch_logits):
    """
    text_batch: list of strings of length equal to batch size
    text_batch_logits: Tensor of size([T, batch_size, num_classes])
    """
    text_batch_logps = F.log_softmax(text_batch_logits, 2) # [T, batch_size, num_classes]  
    text_batch_logps_lens = torch.full(size=(text_batch_logps.size(1),), 
                                       fill_value=text_batch_logps.size(0), 
                                       dtype=torch.int32).to(device) # [batch_size] 

    text_batch_targets, text_batch_targets_lens = encode_text_batch(text_batch)
    loss = criterion(text_batch_logps, text_batch_targets, text_batch_logps_lens, text_batch_targets_lens)

    return loss

## Train

In [25]:
def train(model, optimizer, train_loader, val_loader, scheduler, device):
    model.to(device)
    
    best_loss = 999999
    best_model = None
    for epoch in range(1, CFG['EPOCHS']+1):
        model.train()
        train_loss = []
        for image_batch, text_batch in tqdm(iter(train_loader)):
            image_batch = image_batch.to(device)
            
            optimizer.zero_grad()
            
            text_batch_logits = model(image_batch)
            loss = compute_loss(text_batch, text_batch_logits)
            
            loss.backward()
            optimizer.step()
            
            train_loss.append(loss.item())
        
        _train_loss = np.mean(train_loss)
        
        _val_loss = validation(model, val_loader, device)
        print(f'Epoch : [{epoch}] Train CTC Loss : [{_train_loss:.5f}] Val CTC Loss : [{_val_loss:.5f}]')
        
        if scheduler is not None:
            scheduler.step(_val_loss)
        
        if best_loss > _val_loss:
            best_loss = _val_loss
            best_model = model
    
    return best_model

## Validation

In [26]:
def validation(model, val_loader, device):
    model.eval()
    val_loss = []
    with torch.no_grad():
        for image_batch, text_batch in tqdm(iter(val_loader)):
            image_batch = image_batch.to(device)
            
            text_batch_logits = model(image_batch)
            loss = compute_loss(text_batch, text_batch_logits)
            
            val_loss.append(loss.item())
    
    _val_loss = np.mean(val_loss)
    return _val_loss

## Run!!

In [27]:
model = RecognitionModel()
# model = nn.DataParallel(model)
model.eval()
optimizer = torch.optim.Adam(params = model.parameters(), lr = CFG["LEARNING_RATE"])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2,threshold_mode='abs',min_lr=1e-8, verbose=True)

infer_model = train(model, optimizer, train_loader, val_loader, scheduler, device)

100%|██████████| 553/553 [01:51<00:00,  4.95it/s]
100%|██████████| 115/115 [00:11<00:00, 10.29it/s]

Epoch : [1] Train CTC Loss : [11.16422] Val CTC Loss : [6.43830]



100%|██████████| 553/553 [01:54<00:00,  4.81it/s]
100%|██████████| 115/115 [00:18<00:00,  6.35it/s]

Epoch : [2] Train CTC Loss : [6.93139] Val CTC Loss : [6.24165]



100%|██████████| 553/553 [02:04<00:00,  4.43it/s]
100%|██████████| 115/115 [00:18<00:00,  6.25it/s]


Epoch : [3] Train CTC Loss : [6.55223] Val CTC Loss : [6.14587]


100%|██████████| 553/553 [02:07<00:00,  4.35it/s]
100%|██████████| 115/115 [00:16<00:00,  7.06it/s]

Epoch : [4] Train CTC Loss : [6.32721] Val CTC Loss : [5.97400]



100%|██████████| 553/553 [02:07<00:00,  4.33it/s]
100%|██████████| 115/115 [00:19<00:00,  5.88it/s]


Epoch : [5] Train CTC Loss : [6.15672] Val CTC Loss : [5.78495]


100%|██████████| 553/553 [02:05<00:00,  4.39it/s]
100%|██████████| 115/115 [00:14<00:00,  7.75it/s]

Epoch : [6] Train CTC Loss : [5.96676] Val CTC Loss : [5.66696]



100%|██████████| 553/553 [02:07<00:00,  4.34it/s]
100%|██████████| 115/115 [00:17<00:00,  6.55it/s]

Epoch : [7] Train CTC Loss : [5.74561] Val CTC Loss : [5.36270]



100%|██████████| 553/553 [02:04<00:00,  4.44it/s]
100%|██████████| 115/115 [00:18<00:00,  6.19it/s]

Epoch : [8] Train CTC Loss : [5.17404] Val CTC Loss : [4.27236]



100%|██████████| 553/553 [02:07<00:00,  4.34it/s]
100%|██████████| 115/115 [00:19<00:00,  5.78it/s]


Epoch : [9] Train CTC Loss : [3.77500] Val CTC Loss : [2.81660]


100%|██████████| 553/553 [02:05<00:00,  4.42it/s]
100%|██████████| 115/115 [00:18<00:00,  6.23it/s]

Epoch : [10] Train CTC Loss : [2.62900] Val CTC Loss : [1.91379]



100%|██████████| 553/553 [02:06<00:00,  4.38it/s]
100%|██████████| 115/115 [00:17<00:00,  6.43it/s]


Epoch : [11] Train CTC Loss : [1.99118] Val CTC Loss : [1.45005]


100%|██████████| 553/553 [02:04<00:00,  4.44it/s]
100%|██████████| 115/115 [00:18<00:00,  6.18it/s]

Epoch : [12] Train CTC Loss : [1.56758] Val CTC Loss : [1.36103]



100%|██████████| 553/553 [02:06<00:00,  4.37it/s]
100%|██████████| 115/115 [00:20<00:00,  5.70it/s]

Epoch : [13] Train CTC Loss : [1.24707] Val CTC Loss : [1.00570]



100%|██████████| 553/553 [02:05<00:00,  4.42it/s]
100%|██████████| 115/115 [00:18<00:00,  6.07it/s]

Epoch : [14] Train CTC Loss : [1.13738] Val CTC Loss : [0.87227]



100%|██████████| 553/553 [02:05<00:00,  4.40it/s]
100%|██████████| 115/115 [00:18<00:00,  6.31it/s]

Epoch : [15] Train CTC Loss : [0.90324] Val CTC Loss : [0.71658]



100%|██████████| 553/553 [02:05<00:00,  4.40it/s]
100%|██████████| 115/115 [00:18<00:00,  6.14it/s]

Epoch : [16] Train CTC Loss : [0.80017] Val CTC Loss : [0.70801]



100%|██████████| 553/553 [02:06<00:00,  4.38it/s]
100%|██████████| 115/115 [00:20<00:00,  5.58it/s]

Epoch : [17] Train CTC Loss : [0.71200] Val CTC Loss : [0.61363]



100%|██████████| 553/553 [02:05<00:00,  4.40it/s]
100%|██████████| 115/115 [00:17<00:00,  6.60it/s]

Epoch : [18] Train CTC Loss : [0.68404] Val CTC Loss : [0.55550]



100%|██████████| 553/553 [02:05<00:00,  4.40it/s]
100%|██████████| 115/115 [00:15<00:00,  7.37it/s]

Epoch : [19] Train CTC Loss : [0.55907] Val CTC Loss : [0.56247]



100%|██████████| 553/553 [02:01<00:00,  4.54it/s]
100%|██████████| 115/115 [00:13<00:00,  8.41it/s]

Epoch : [20] Train CTC Loss : [0.54742] Val CTC Loss : [0.48654]



100%|██████████| 553/553 [01:57<00:00,  4.69it/s]
100%|██████████| 115/115 [00:14<00:00,  8.06it/s]

Epoch : [21] Train CTC Loss : [0.46912] Val CTC Loss : [0.42842]



100%|██████████| 553/553 [01:57<00:00,  4.69it/s]
100%|██████████| 115/115 [00:12<00:00,  9.06it/s]

Epoch : [22] Train CTC Loss : [0.49591] Val CTC Loss : [0.59527]



100%|██████████| 553/553 [01:59<00:00,  4.64it/s]
100%|██████████| 115/115 [00:14<00:00,  7.94it/s]

Epoch : [23] Train CTC Loss : [0.46975] Val CTC Loss : [0.42146]



100%|██████████| 553/553 [01:58<00:00,  4.66it/s]
100%|██████████| 115/115 [00:14<00:00,  8.09it/s]

Epoch : [24] Train CTC Loss : [0.46021] Val CTC Loss : [0.46092]



100%|██████████| 553/553 [01:59<00:00,  4.62it/s]
100%|██████████| 115/115 [00:11<00:00, 10.06it/s]

Epoch : [25] Train CTC Loss : [0.40742] Val CTC Loss : [0.44345]



100%|██████████| 553/553 [01:55<00:00,  4.78it/s]
100%|██████████| 115/115 [00:15<00:00,  7.64it/s]

Epoch : [26] Train CTC Loss : [0.35029] Val CTC Loss : [0.33350]



100%|██████████| 553/553 [01:58<00:00,  4.66it/s]
100%|██████████| 115/115 [00:14<00:00,  7.78it/s]

Epoch : [27] Train CTC Loss : [0.34237] Val CTC Loss : [0.32233]



100%|██████████| 553/553 [01:58<00:00,  4.67it/s]
100%|██████████| 115/115 [00:12<00:00,  8.89it/s]

Epoch : [28] Train CTC Loss : [0.39486] Val CTC Loss : [0.33578]



100%|██████████| 553/553 [01:57<00:00,  4.69it/s]
100%|██████████| 115/115 [00:12<00:00,  8.96it/s]

Epoch : [29] Train CTC Loss : [0.32477] Val CTC Loss : [0.34435]



100%|██████████| 553/553 [01:58<00:00,  4.68it/s]
100%|██████████| 115/115 [00:12<00:00,  9.22it/s]

Epoch : [30] Train CTC Loss : [0.30287] Val CTC Loss : [0.39607]
Epoch    30: reducing learning rate of group 0 to 5.0000e-04.



100%|██████████| 553/553 [01:57<00:00,  4.69it/s]
100%|██████████| 115/115 [00:12<00:00,  8.87it/s]

Epoch : [31] Train CTC Loss : [0.25003] Val CTC Loss : [0.27338]



100%|██████████| 553/553 [01:53<00:00,  4.86it/s]
100%|██████████| 115/115 [00:12<00:00,  9.22it/s]

Epoch : [32] Train CTC Loss : [0.21974] Val CTC Loss : [0.30821]



100%|██████████| 553/553 [01:53<00:00,  4.86it/s]
100%|██████████| 115/115 [00:12<00:00,  9.08it/s]

Epoch : [33] Train CTC Loss : [0.20633] Val CTC Loss : [0.21204]



100%|██████████| 553/553 [01:53<00:00,  4.88it/s]
100%|██████████| 115/115 [00:12<00:00,  8.87it/s]

Epoch : [34] Train CTC Loss : [0.16581] Val CTC Loss : [0.19956]



100%|██████████| 553/553 [01:55<00:00,  4.78it/s]
100%|██████████| 115/115 [00:12<00:00,  8.92it/s]

Epoch : [35] Train CTC Loss : [0.16307] Val CTC Loss : [0.19199]



100%|██████████| 553/553 [01:54<00:00,  4.85it/s]
100%|██████████| 115/115 [00:12<00:00,  9.21it/s]

Epoch : [36] Train CTC Loss : [0.14123] Val CTC Loss : [0.18798]



100%|██████████| 553/553 [01:55<00:00,  4.80it/s]
100%|██████████| 115/115 [00:13<00:00,  8.43it/s]

Epoch : [37] Train CTC Loss : [0.15696] Val CTC Loss : [0.17318]



100%|██████████| 553/553 [01:54<00:00,  4.83it/s]
100%|██████████| 115/115 [00:12<00:00,  8.89it/s]

Epoch : [38] Train CTC Loss : [0.12503] Val CTC Loss : [0.17272]



100%|██████████| 553/553 [01:55<00:00,  4.80it/s]
100%|██████████| 115/115 [00:13<00:00,  8.64it/s]

Epoch : [39] Train CTC Loss : [0.12356] Val CTC Loss : [0.15566]



100%|██████████| 553/553 [01:54<00:00,  4.82it/s]
100%|██████████| 115/115 [00:13<00:00,  8.48it/s]

Epoch : [40] Train CTC Loss : [0.12956] Val CTC Loss : [0.17810]



100%|██████████| 553/553 [01:55<00:00,  4.80it/s]
100%|██████████| 115/115 [00:13<00:00,  8.82it/s]

Epoch : [41] Train CTC Loss : [0.11794] Val CTC Loss : [0.17031]



100%|██████████| 553/553 [01:54<00:00,  4.81it/s]
100%|██████████| 115/115 [00:13<00:00,  8.74it/s]

Epoch : [42] Train CTC Loss : [0.11671] Val CTC Loss : [0.17667]
Epoch    42: reducing learning rate of group 0 to 2.5000e-04.



100%|██████████| 553/553 [01:54<00:00,  4.82it/s]
100%|██████████| 115/115 [00:12<00:00,  8.95it/s]

Epoch : [43] Train CTC Loss : [0.09001] Val CTC Loss : [0.12771]



100%|██████████| 553/553 [01:55<00:00,  4.80it/s]
100%|██████████| 115/115 [00:13<00:00,  8.63it/s]

Epoch : [44] Train CTC Loss : [0.07388] Val CTC Loss : [0.12708]



100%|██████████| 553/553 [01:54<00:00,  4.83it/s]
100%|██████████| 115/115 [00:11<00:00,  9.62it/s]


Epoch : [45] Train CTC Loss : [0.07124] Val CTC Loss : [0.12606]


100%|██████████| 553/553 [01:54<00:00,  4.84it/s]
100%|██████████| 115/115 [00:12<00:00,  9.18it/s]

Epoch : [46] Train CTC Loss : [0.08008] Val CTC Loss : [0.13814]



100%|██████████| 553/553 [01:53<00:00,  4.85it/s]
100%|██████████| 115/115 [00:11<00:00,  9.81it/s]

Epoch : [47] Train CTC Loss : [0.07933] Val CTC Loss : [0.11994]



100%|██████████| 553/553 [01:53<00:00,  4.86it/s]
100%|██████████| 115/115 [00:13<00:00,  8.71it/s]

Epoch : [48] Train CTC Loss : [0.06392] Val CTC Loss : [0.11583]



100%|██████████| 553/553 [01:54<00:00,  4.81it/s]
100%|██████████| 115/115 [00:13<00:00,  8.61it/s]

Epoch : [49] Train CTC Loss : [0.05838] Val CTC Loss : [0.10584]



100%|██████████| 553/553 [01:53<00:00,  4.85it/s]
100%|██████████| 115/115 [00:12<00:00,  9.46it/s]

Epoch : [50] Train CTC Loss : [0.06128] Val CTC Loss : [0.10916]



100%|██████████| 553/553 [01:53<00:00,  4.85it/s]
100%|██████████| 115/115 [00:12<00:00,  9.14it/s]

Epoch : [51] Train CTC Loss : [0.05424] Val CTC Loss : [0.10821]



100%|██████████| 553/553 [01:53<00:00,  4.85it/s]
100%|██████████| 115/115 [00:12<00:00,  8.96it/s]


Epoch : [52] Train CTC Loss : [0.05079] Val CTC Loss : [0.11129]
Epoch    52: reducing learning rate of group 0 to 1.2500e-04.


100%|██████████| 553/553 [01:54<00:00,  4.85it/s]
100%|██████████| 115/115 [00:11<00:00,  9.66it/s]


Epoch : [53] Train CTC Loss : [0.04804] Val CTC Loss : [0.09730]


100%|██████████| 553/553 [01:55<00:00,  4.80it/s]
100%|██████████| 115/115 [00:12<00:00,  9.07it/s]

Epoch : [54] Train CTC Loss : [0.03997] Val CTC Loss : [0.09309]



100%|██████████| 553/553 [01:55<00:00,  4.81it/s]
100%|██████████| 115/115 [00:12<00:00,  9.25it/s]

Epoch : [55] Train CTC Loss : [0.03563] Val CTC Loss : [0.09024]



100%|██████████| 553/553 [01:54<00:00,  4.82it/s]
100%|██████████| 115/115 [00:13<00:00,  8.83it/s]

Epoch : [56] Train CTC Loss : [0.03574] Val CTC Loss : [0.08920]



100%|██████████| 553/553 [01:54<00:00,  4.82it/s]
100%|██████████| 115/115 [00:14<00:00,  8.20it/s]

Epoch : [57] Train CTC Loss : [0.03226] Val CTC Loss : [0.08870]



100%|██████████| 553/553 [01:55<00:00,  4.79it/s]
100%|██████████| 115/115 [00:12<00:00,  9.04it/s]

Epoch : [58] Train CTC Loss : [0.03207] Val CTC Loss : [0.08923]



100%|██████████| 553/553 [01:55<00:00,  4.80it/s]
100%|██████████| 115/115 [00:13<00:00,  8.71it/s]

Epoch : [59] Train CTC Loss : [0.03171] Val CTC Loss : [0.09263]



100%|██████████| 553/553 [01:55<00:00,  4.80it/s]
100%|██████████| 115/115 [00:13<00:00,  8.64it/s]

Epoch : [60] Train CTC Loss : [0.03264] Val CTC Loss : [0.08898]
Epoch    60: reducing learning rate of group 0 to 6.2500e-05.



100%|██████████| 553/553 [01:55<00:00,  4.80it/s]
100%|██████████| 115/115 [00:13<00:00,  8.27it/s]

Epoch : [61] Train CTC Loss : [0.02980] Val CTC Loss : [0.08567]



100%|██████████| 553/553 [01:55<00:00,  4.79it/s]
100%|██████████| 115/115 [00:13<00:00,  8.83it/s]

Epoch : [62] Train CTC Loss : [0.02609] Val CTC Loss : [0.08281]



100%|██████████| 553/553 [01:55<00:00,  4.77it/s]
100%|██████████| 115/115 [00:13<00:00,  8.29it/s]

Epoch : [63] Train CTC Loss : [0.02470] Val CTC Loss : [0.08234]



100%|██████████| 553/553 [01:54<00:00,  4.81it/s]
100%|██████████| 115/115 [00:14<00:00,  8.13it/s]

Epoch : [64] Train CTC Loss : [0.02452] Val CTC Loss : [0.08500]



100%|██████████| 553/553 [01:54<00:00,  4.83it/s]
100%|██████████| 115/115 [00:12<00:00,  9.21it/s]


Epoch : [65] Train CTC Loss : [0.02405] Val CTC Loss : [0.08283]


100%|██████████| 553/553 [01:53<00:00,  4.86it/s]
100%|██████████| 115/115 [00:12<00:00,  8.99it/s]


Epoch : [66] Train CTC Loss : [0.02227] Val CTC Loss : [0.08107]


100%|██████████| 553/553 [01:54<00:00,  4.85it/s]
100%|██████████| 115/115 [00:12<00:00,  9.35it/s]

Epoch : [67] Train CTC Loss : [0.02169] Val CTC Loss : [0.08048]



100%|██████████| 553/553 [01:54<00:00,  4.83it/s]
100%|██████████| 115/115 [00:12<00:00,  8.88it/s]

Epoch : [68] Train CTC Loss : [0.02120] Val CTC Loss : [0.07998]



100%|██████████| 553/553 [01:54<00:00,  4.84it/s]
100%|██████████| 115/115 [00:12<00:00,  9.26it/s]

Epoch : [69] Train CTC Loss : [0.02037] Val CTC Loss : [0.07913]



100%|██████████| 553/553 [01:54<00:00,  4.85it/s]
100%|██████████| 115/115 [00:12<00:00,  9.09it/s]


Epoch : [70] Train CTC Loss : [0.01984] Val CTC Loss : [0.07855]


100%|██████████| 553/553 [01:54<00:00,  4.83it/s]
100%|██████████| 115/115 [00:12<00:00,  9.45it/s]

Epoch : [71] Train CTC Loss : [0.01932] Val CTC Loss : [0.07883]



100%|██████████| 553/553 [01:54<00:00,  4.85it/s]
100%|██████████| 115/115 [00:12<00:00,  9.40it/s]

Epoch : [72] Train CTC Loss : [0.01868] Val CTC Loss : [0.07750]



100%|██████████| 553/553 [01:54<00:00,  4.84it/s]
100%|██████████| 115/115 [00:12<00:00,  9.14it/s]

Epoch : [73] Train CTC Loss : [0.01835] Val CTC Loss : [0.07710]



100%|██████████| 553/553 [01:54<00:00,  4.82it/s]
100%|██████████| 115/115 [00:13<00:00,  8.52it/s]

Epoch : [74] Train CTC Loss : [0.01861] Val CTC Loss : [0.07750]



100%|██████████| 553/553 [01:55<00:00,  4.77it/s]
100%|██████████| 115/115 [00:13<00:00,  8.47it/s]

Epoch : [75] Train CTC Loss : [0.01790] Val CTC Loss : [0.07708]



100%|██████████| 553/553 [01:55<00:00,  4.79it/s]
100%|██████████| 115/115 [00:13<00:00,  8.78it/s]

Epoch : [76] Train CTC Loss : [0.01705] Val CTC Loss : [0.07679]



100%|██████████| 553/553 [01:55<00:00,  4.77it/s]
100%|██████████| 115/115 [00:13<00:00,  8.61it/s]

Epoch : [77] Train CTC Loss : [0.01864] Val CTC Loss : [0.07752]



100%|██████████| 553/553 [01:54<00:00,  4.85it/s]
100%|██████████| 115/115 [00:13<00:00,  8.63it/s]

Epoch : [78] Train CTC Loss : [0.01767] Val CTC Loss : [0.07672]



100%|██████████| 553/553 [01:54<00:00,  4.81it/s]
100%|██████████| 115/115 [00:13<00:00,  8.63it/s]

Epoch : [79] Train CTC Loss : [0.01674] Val CTC Loss : [0.07678]
Epoch    79: reducing learning rate of group 0 to 3.1250e-05.



100%|██████████| 553/553 [01:55<00:00,  4.81it/s]
100%|██████████| 115/115 [00:13<00:00,  8.76it/s]

Epoch : [80] Train CTC Loss : [0.01551] Val CTC Loss : [0.07470]





## Inference

In [28]:
test = pd.read_csv('./test.csv')

In [29]:
test_dataset = CustomDataset(test['img_path'].values, None, False)
test_loader = DataLoader(test_dataset, batch_size = CFG['BATCH_SIZE'], shuffle=False, num_workers=CFG['NUM_WORKERS'])

In [30]:
def decode_predictions(text_batch_logits):
    text_batch_tokens = F.softmax(text_batch_logits, 2).argmax(2) # [T, batch_size]
    text_batch_tokens = text_batch_tokens.numpy().T # [batch_size, T]

    text_batch_tokens_new = []
    for text_tokens in text_batch_tokens:
        text = [idx2char[idx] for idx in text_tokens]
        text = "".join(text)
        text_batch_tokens_new.append(text)

    return text_batch_tokens_new

def inference(model, test_loader, device):
    model.eval()
    preds = []
    with torch.no_grad():
        for image_batch in tqdm(iter(test_loader)):
            image_batch = image_batch.to(device)
            
            text_batch_logits = model(image_batch)
            
            text_batch_pred = decode_predictions(text_batch_logits.cpu())
            
            preds.extend(text_batch_pred)
    return preds

In [31]:
predictions = inference(infer_model, test_loader, device)

100%|██████████| 290/290 [02:49<00:00,  1.71it/s]


## Submission

In [32]:
# 샘플 별 추론결과를 독립적으로 후처리
def remove_duplicates(text):
    if len(text) > 1:
        letters = [text[0]] + [letter for idx, letter in enumerate(text[1:], start=1) if text[idx] != text[idx-1]]
    elif len(text) == 1:
        letters = [text[0]]
    else:
        return ""
    return "".join(letters)

def correct_prediction(word):
    parts = word.split("-")
    parts = [remove_duplicates(part) for part in parts]
    corrected_word = "".join(parts)
    return corrected_word

In [33]:
submit = pd.read_csv('./sample_submission.csv')
submit['label'] = predictions
submit['label'] = submit['label'].apply(correct_prediction)

In [35]:
submit.to_csv('./mobile_lstm_epoch80_gaussian.csv', index=False)