In [1]:
import os
import time
import random
import string

import torch
import torch.backends.cudnn as cudnn
import torch.nn.init as init
import torch.optim as optim
import torch.utils.data
import numpy as np

from lincenseplateocr.utils import CTCLabelConverter, CTCLabelConverterForBaiduWarpctc, AttnLabelConverter, Averager
from lincenseplateocr.dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset
from lincenseplateocr.model import Model
from lincenseplateocr.test import validation
from tensorboardX import SummaryWriter

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

writer = SummaryWriter()

# 옵션 부분을 여기서 설정할 수 있게 변경
class Options:
    def __init__(self):
        self.exp_name = 'experiment'
        self.train_data = './input/lmdb'
        self.valid_data = './modules/Dataset/Valid'
        self.manualSeed = 1111
        self.workers = 0
        self.batch_size = 1400
        self.num_iter = 3000
        self.valInterval = 20
        self.saved_model = './lincenseplateocr/pretrained/Fine-Tuned.pth'
        self.FT = True
        self.adam = False
        self.lr = 1.0
        self.beta1 = 0.9
        self.rho = 0.95
        self.eps = 1e-8
        self.grad_clip = 5
        self.baiduCTC = False
        self.select_data = '/'
        self.batch_ratio = '1'
        self.total_data_usage_ratio = '1.0'
        self.batch_max_length = 25
        self.imgH = 32
        self.imgW = 100
        self.rgb = False
        self.character = '0123456789().JNRW_abcdef가강개걍거겅겨견결경계고과관광굥구금기김깅나남너노논누니다대댜더뎡도동두등디라러로루룰리마머명모무문므미바배뱌버베보부북비사산서성세셔소송수시아악안양어여연영오올용우울원육으을이익인자작저전제조종주중지차처천초추출충층카콜타파평포하허호홀후히ㅣ'
        self.sensitive = False
        self.PAD = False
        self.Transformation = 'TPS'
        self.FeatureExtraction = 'ResNet'
        self.SequenceModeling = 'BiLSTM'
        self.Prediction = 'Attn'
        self.num_fiducial = 20
        self.input_channel = 1
        self.output_channel = 512
        self.hidden_size = 256
        self.num_gpu = torch.cuda.device_count()
        self.data_filtering_off = False 

opt = Options()

# 모델을 학습하는 함수
def train(opt):
    if not opt.data_filtering_off:
        print('Filtering the images containing characters which are not in opt.character')
        print('Filtering the images whose label is longer than opt.batch_max_length')

    train_dataset = Batch_Balanced_Dataset(opt)
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD)
    valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=opt.batch_size, shuffle=True,
        num_workers=int(opt.workers), collate_fn=AlignCollate_valid, pin_memory=True)
    
    # 모델 설정
    if 'CTC' in opt.Prediction:
        if opt.baiduCTC:
            converter = CTCLabelConverterForBaiduWarpctc(opt.character)
        else:
            converter = CTCLabelConverter(opt.character)
    else:
        converter = AttnLabelConverter(opt.character)
    opt.num_class = len(converter.character)

    if opt.rgb:
        opt.input_channel = 3
    model = Model(opt)
    
    print('Model:', model)
    model = torch.nn.DataParallel(model).to(device)
    model.train()

    if opt.saved_model != '':
        print(f'Loading pretrained model from {opt.saved_model}')
        model.load_state_dict(torch.load(opt.saved_model))

    # 손실 함수 및 옵티마이저 설정
    if 'CTC' in opt.Prediction:
        criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)
    else:
        criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

    if opt.adam:
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
    else:
        optimizer = optim.Adadelta(filter(lambda p: p.requires_grad, model.parameters()), lr=opt.lr, rho=opt.rho, eps=opt.eps)

    # 학습 루프
    start_time = time.time()
    for iteration in range(opt.num_iter):
        try:
            image_tensors, labels = train_dataset.get_batch()
            if image_tensors is None or labels is None:
                print(f"빈 배치 발생 at iteration {iteration}. 배치를 건너뜁니다.")
                continue  # 빈 배치를 건너뛰기
        except StopIteration:
            print(f"StopIteration 발생 at loader {iteration}")
            continue  # 배치가 끝났을 경우 다음 루프로 넘어감
            
        image = image_tensors.to(device)
        text, length = converter.encode(labels, batch_max_length=opt.batch_max_length)
        
        if 'CTC' in opt.Prediction:
            preds = model(image, text)
            preds_size = torch.IntTensor([preds.size(1)] * image.size(0))
            preds = preds.log_softmax(2).permute(1, 0, 2)
            cost = criterion(preds, text, preds_size, length)
        else:
            preds = model(image, text[:, :-1])
            target = text[:, 1:]
            cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
        
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        # Validation
        if (iteration + 1) % opt.valInterval == 0:
            print(f"검증 루프 시작 at iteration {iteration + 1}")
            model.eval()
            with torch.no_grad():
                try:
                    valid_loss, valid_accuracy, *_ = validation(model, criterion, valid_loader, converter, opt)
                    print(f"검증 결과 - Loss: {valid_loss:.4f}, Accuracy: {valid_accuracy:.2f}%")
                except StopIteration:
                    print(f"StopIteration 발생 at validation {iteration}")
                    continue

            model.train()
            
            print(f"[{iteration+1}/{opt.num_iter}] Train Loss: {cost.item():.4f}, Valid Loss: {valid_loss:.4f}, Valid Accuracy: {valid_accuracy:.4f}")
            
        if (iteration + 1) % 10000 == 0:
            torch.save(model.state_dict(), f"./saved_models/{opt.exp_name}/iter_{iteration+1}.pth")



opt.exp_name = "Number_Plate_Search"
opt.train_data = "./lincenseplateocr/input/train"
opt.valid_data = "./lincenseplateocr/input/Vali"
opt.num_iter = 500000  # 예시로 줄인 값
opt.valInterval = 500  # 검증 주기 설정

# 학습 실행
train(opt)


Reading images...
Filtering the images containing characters which are not in opt.character
Filtering the images whose label is longer than opt.batch_max_length
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
num total samples of /: 239997 x 1.0 (total_data_usage_ratio) = 239997
num samples of / per batch: 1400 x 1.0 (batch_ratio) = 1400
--------------------------------------------------------------------------------
Total_batch_size: 1400 = 1400
--------------------------------------------------------------------------------
Initializing TPS Transformation
TPS 초기화: F=20, I_size=(32, 100), I_r_size=(32, 100), I_channel_num=1


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float())  # F+3 x F+3
  self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float())  # n x F+3


TPS Transformation initialized
Initializing Feature Extraction: ResNet
Feature Extraction initialized with output size: 512
Initializing Sequence Modeling with BiLSTM
Sequence Modeling initialized
Initializing Prediction: Attn
Prediction initialized
Model: Model(
  (Transformation): TPS_SpatialTransformerNetwork(
    (LocalizationNetwork): LocalizationNetwork(
      (conv): Sequential(
        (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ReLU(inplace=True)
        (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  

  model.load_state_dict(torch.load(opt.saved_model))


StopIteration 발생 at loader 0
빈 배치가 감지되었습니다. 빈 배치를 건너뜁니다.
빈 배치 발생 at iteration 172. 배치를 건너뜁니다.
StopIteration 발생 at loader 0
빈 배치가 감지되었습니다. 빈 배치를 건너뜁니다.
빈 배치 발생 at iteration 345. 배치를 건너뜁니다.
검증 루프 시작 at iteration 500
검증 결과 - Loss: 0.1851, Accuracy: 73.92%
[500/500000] Train Loss: 1.4863, Valid Loss: 0.1851, Valid Accuracy: 73.9165
StopIteration 발생 at loader 0
빈 배치가 감지되었습니다. 빈 배치를 건너뜁니다.
빈 배치 발생 at iteration 518. 배치를 건너뜁니다.
StopIteration 발생 at loader 0
빈 배치가 감지되었습니다. 빈 배치를 건너뜁니다.
빈 배치 발생 at iteration 691. 배치를 건너뜁니다.
StopIteration 발생 at loader 0
빈 배치가 감지되었습니다. 빈 배치를 건너뜁니다.
빈 배치 발생 at iteration 864. 배치를 건너뜁니다.
검증 루프 시작 at iteration 1000
검증 결과 - Loss: 0.0791, Accuracy: 87.84%
[1000/500000] Train Loss: 0.0340, Valid Loss: 0.0791, Valid Accuracy: 87.8391
StopIteration 발생 at loader 0
빈 배치가 감지되었습니다. 빈 배치를 건너뜁니다.
빈 배치 발생 at iteration 1037. 배치를 건너뜁니다.
StopIteration 발생 at loader 0
빈 배치가 감지되었습니다. 빈 배치를 건너뜁니다.
빈 배치 발생 at iteration 1210. 배치를 건너뜁니다.
StopIteration 발생 at loader 0
빈 배치가 감지되었습니다. 빈 배치를 건너뜁니

KeyboardInterrupt: 

In [None]:
import gc

# GPU 메모리 강제 수집
gc.collect()  # Python 객체 수집 (CPU 메모리 해제)
torch.cuda.empty_cache()  # GPU 메모리 캐시 해제