In [2]:
import os
import json
import torch
import torch.optim as optim
import numpy as np
import torch.utils.data
from torchvision import transforms
from PIL import Image, ImageDraw
from lincenseplateocr.model import Model
from lincenseplateocr.utils import AttnLabelConverter, CTCLabelConverter
from lincenseplateocr.dataset import AlignCollate, hierarchical_dataset
from lincenseplateocr.test import validation
from ultralytics import YOLO  # YOLO 모델 추가

# Run on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class Options:
    """Hyper-parameters and paths for fine-tuning the license-plate STR model.

    Every option has a default below. Keyword arguments override them, e.g.
    ``Options(lr=1e-4, num_iter=10)``. An unknown keyword raises ``TypeError`` so a
    misspelled option is not silently ignored. Derived fields such as ``num_class``
    are set later by the caller, not here.
    """

    def __init__(self, **overrides):
        # Experiment name; used in the checkpoint file name written by train().
        self.exp_name = 'license_plate_training'
        # Training images (*.jpg with same-named *.json labels) and validation dataset root.
        self.train_data = 'lincenseplateocr/input/train'
        self.valid_data = 'lincenseplateocr/input/vali'
        # Checkpoint to initialise the model from ('' skips loading).
        self.saved_model = 'lincenseplateocr/pretrained/Fine-Tuned.pth'
        # Number of passes over train_data (epochs) and how often (in epochs) to validate.
        self.num_iter = 3000
        self.valInterval = 1
        # Batch size of the validation DataLoader (training steps one image at a time).
        self.batch_size = 1000
        # Frees PyTorch's cached CUDA memory whenever options are built (kept from the original).
        torch.cuda.empty_cache()
        # Optimizer hyper-parameters.
        self.lr = 0.001
        self.beta1 = 0.9
        self.eps = 1e-8
        # Prediction head: 'Attn' selects AttnLabelConverter, anything else CTCLabelConverter.
        self.Prediction = 'Attn'
        # Maximum label length handed to the label converter.
        self.batch_max_length = 25
        # Height/width that plate crops are resized to.
        self.imgH = 32
        self.imgW = 100
        # Label alphabet: every character the recogniser can emit.
        self.character = '0123456789().JNRW_abcdef가강개걍거겅겨견결경계고과관광굥구금기김깅나남너노논누니다대댜더뎡도동두등디라러로루룰리마머명모무문므미바배뱌버베보부북비사산서성세셔소송수시아악안양어여연영오올용우울원육으을이익인자작저전제조종주중지차처천초추출충층카콜타파평포하허호홀후히ㅣ'
        # Model architecture options (consumed by the project's Model class).
        self.input_channel = 1
        self.output_channel = 512
        self.hidden_size = 256
        # DataLoader worker processes used for validation.
        self.workers = 0
        # Fine-tuning: when a checkpoint is loaded, train only the Prediction layers.
        self.FT = True
        self.Transformation = 'TPS'
        self.FeatureExtraction = 'ResNet'
        self.SequenceModeling = 'BiLSTM'
        self.num_fiducial = 20
        # Dataset options, presumably read by the project's hierarchical_dataset — confirm.
        self.data_filtering_off = False
        self.rgb = False
        self.sensitive = False

        # Apply caller overrides, rejecting names that are not known options.
        unknown = set(overrides) - set(vars(self))
        if unknown:
            raise TypeError(f"Unknown option(s): {', '.join(sorted(unknown))}")
        for name, value in overrides.items():
            setattr(self, name, value)

opt = Options()

# STR model setup: choose the label converter that matches the prediction head.
if opt.Prediction == 'Attn':
    converter = AttnLabelConverter(opt.character)
else:
    converter = CTCLabelConverter(opt.character)
opt.num_class = len(converter.character)
print(f"Using converter: {type(converter)}")

# Build the recognition model and optionally initialise it from a checkpoint.
str_model = Model(opt).to(device)

if opt.saved_model != '':
    state_dict = torch.load(opt.saved_model, map_location=device)
    # Checkpoints written from a torch.nn.DataParallel wrapper prefix every key with
    # 'module.'; strip it so the keys match this unwrapped model.
    state_dict = {
        (key[len('module.'):] if key.startswith('module.') else key): value
        for key, value in state_dict.items()
    }

    if opt.FT:
        print('Fine-tuning mode: Loading pretrained model for fine-tuning')
        # strict=False silently skips missing/unexpected keys. Report them, because a
        # checkpoint whose keys do not match would otherwise leave the model at random init.
        incompatible = str_model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            print(f'Missing keys (kept at initialisation): {incompatible.missing_keys}')
        if incompatible.unexpected_keys:
            print(f'Unexpected keys (ignored): {incompatible.unexpected_keys}')

        # Train only parameters whose name contains 'Prediction'; freeze everything else.
        for name, param in str_model.named_parameters():
            param.requires_grad = 'Prediction' in name
            if param.requires_grad:
                print(f'{name} - requires_grad: True')
    else:
        print(f'Loading pretrained model from {opt.saved_model}')
        str_model.load_state_dict(state_dict)

# Loss: target index 0 is ignored — assumed to be the converter's [GO]/padding index; confirm.
criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device)

# Optimizer over the trainable (non-frozen) parameters only.
filtered_parameters = [p for p in str_model.parameters() if p.requires_grad]

if not filtered_parameters:
    raise ValueError("No parameters available for training. Please check requires_grad settings.")

optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999), eps=opt.eps)

def load_label_json(label_path):
    """Read the UTF-8 JSON annotation at *label_path* and return its ``'value'`` entry.

    The ``'value'`` field holds the ground-truth text of the image. A missing file or
    key propagates as ``OSError`` / ``KeyError``.
    """
    with open(label_path, encoding='utf-8') as fp:
        return json.load(fp)['value']

def detect_license_plate_yolo(image, model, plate_class_id=0, **predict_kwargs):
    """Run the YOLO detector on *image* and return the boxes of license plates.

    Args:
        image: Image passed straight to ``model`` (a PIL image in this file).
        model: Ultralytics ``YOLO`` model, or any callable returning a results list
            whose first element has a ``boxes`` attribute in the same format.
        plate_class_id: Detector class index of the license-plate class. Defaults to
            0, the value previously hard-coded here.
        **predict_kwargs: Extra keyword arguments forwarded to ``model(...)``, e.g.
            ``verbose=False`` to silence per-image logging.

    Returns:
        List of ``[x_min, y_min, x_max, y_max]`` integer pixel boxes, in the order
        the detector returned them. The list is empty when nothing matches.
    """
    results = model(image, **predict_kwargs)
    plate_boxes = []
    # A single image is passed, so only the first result entry is read.
    for box in results[0].boxes:
        if int(box.cls[0]) != plate_class_id:
            continue
        # Coordinates are floats; int() truncates them to pixel positions for cropping.
        plate_boxes.append([int(v) for v in box.xyxy[0].tolist()])
    return plate_boxes

def crop_license_plate(image, boxes):
    """Cut each detected plate region out of *image*.

    Args:
        image: Source image exposing ``crop`` (a PIL image in this file).
        boxes: Iterable of 4-element ``(x_min, y_min, x_max, y_max)`` boxes, e.g.
            the output of ``detect_license_plate_yolo``.

    Returns:
        One cropped image per box, in the same order as *boxes*.
    """
    return [image.crop((left, top, right, bottom)) for left, top, right, bottom in boxes]

def _freeze_frozen_batchnorm(model):
    """Put BatchNorm layers whose affine parameters are all frozen into eval mode.

    ``requires_grad = False`` stops weight/bias updates. A BatchNorm layer left in train
    mode still normalises with per-batch statistics and keeps overwriting its running
    mean/variance. With the backbone frozen for fine-tuning (opt.FT), that would drift
    the pretrained statistics, which are then saved with the checkpoint.
    """
    for module in model.modules():
        if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d)):
            params = list(module.parameters(recurse=False))
            if params and not any(p.requires_grad for p in params):
                module.eval()


def _load_plate_sample(image_path, label_path, yolo_model, transform):
    """Load one training image and its label, and return ``(plate_tensor, label)``.

    The image is run through the YOLO plate detector and only the first detected plate
    is kept. The crop is converted to grayscale ('L'), passed through *transform*,
    given a leading batch dimension of 1 and moved to ``device``.

    Returns ``None`` (after printing a message) when no plate is detected. Errors while
    reading the image or label, or during detection, propagate to the caller.
    """
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    label = load_label_json(label_path)

    license_plate_boxes = detect_license_plate_yolo(image, yolo_model)
    if len(license_plate_boxes) == 0:
        print(f"번호판이 탐지되지 않음: {image_path}")
        return None

    # Only the first detected plate is used for training.
    plate_image = crop_license_plate(image, license_plate_boxes[:1])[0].convert('L')
    return transform(plate_image).unsqueeze(0).to(device), label


def train(opt):
    """STR model training loop.

    Runs ``opt.num_iter`` epochs. In each epoch, every ``*.jpg`` in ``opt.train_data``
    is paired with its same-named ``.json`` label, and the plate is located with the
    YOLO detector. One optimizer step is then taken on that single plate. The state
    with the lowest summed epoch loss is saved to
    ``./saved_models/<exp_name>_best.pth``. ``validate(opt)`` runs every
    ``opt.valInterval`` epochs.

    Uses the module-level ``str_model``, ``converter``, ``criterion``, ``optimizer``
    and ``device``.
    """
    # NOTE(review): this transform yields inputs in [0, 1] with a plain stretch resize.
    # validate() uses the project's AlignCollate(keep_ratio_with_pad=True) instead.
    # If that AlignCollate (like the clovaai deep-text-recognition-benchmark version)
    # pads to keep the aspect ratio and normalises to [-1, 1], training and validation
    # inputs differ. Confirm, and align them if so
    # (e.g. append transforms.Normalize((0.5,), (0.5,))).
    transform = transforms.Compose([
        transforms.Resize((opt.imgH, opt.imgW)),
        transforms.ToTensor()
    ])

    best_loss = float('inf')

    save_dir = "./saved_models"
    os.makedirs(save_dir, exist_ok=True)
    best_path = os.path.join(save_dir, f"{opt.exp_name}_best.pth")

    # YOLO license-plate detector.
    yolo_model = YOLO('weights/plate_detect.pt')

    # The file list does not change between epochs: read it once, sorted for a reproducible order.
    image_files = sorted(f for f in os.listdir(opt.train_data) if f.endswith('.jpg'))

    for iteration in range(opt.num_iter):
        # Re-applied every epoch because validate() switches the whole model back to train mode.
        _freeze_frozen_batchnorm(str_model)
        # Once per epoch rather than once per image; calling it per image only adds allocator churn.
        torch.cuda.empty_cache()

        total_loss = 0.0
        num_samples = 0
        for img_file in image_files:
            image_path = os.path.join(opt.train_data, img_file)
            label_path = os.path.splitext(image_path)[0] + ".json"

            try:
                sample = _load_plate_sample(image_path, label_path, yolo_model, transform)
            except Exception as e:
                print(f"이미지 불러오기 오류: {e} - 파일 경로: {image_path} 라벨 경로 : {label_path}")
                continue
            if sample is None:
                continue
            plate_image, label = sample

            # Encode the label. A bad label skips this sample instead of aborting training.
            # Assumes the converter raises KeyError for characters outside opt.character and
            # RuntimeError for labels longer than batch_max_length (clovaai benchmark behaviour).
            # Confirm against the project's converter.
            try:
                text, length = converter.encode([label], batch_max_length=opt.batch_max_length)
            except (KeyError, RuntimeError, TypeError) as e:
                print(f"라벨 인코딩 오류: {e} - 라벨: {label} 라벨 경로 : {label_path}")
                continue
            print(f"Encoded label: {text}")

            # The Attn head takes the encoded label (all but its last position) as decoder input.
            preds = str_model(plate_image, text[:, :-1])

            # Decode the predicted sequence for logging.
            _, preds_index = preds.max(2)
            preds_str = converter.decode(preds_index, torch.IntTensor([preds.size(1)] * preds.size(0)))
            print(f"Prediction: {preds_str}, Ground Truth: {label}")

            # Target: the label without its leading [GO] token.
            target = text[:, 1:]
            # Trim predictions to the target length.
            preds = preds[:, :target.size(1), :]

            # reshape (not view): the slice above is non-contiguous whenever it actually truncates.
            cost = criterion(preds.reshape(-1, preds.shape[-1]), target.reshape(-1))

            optimizer.zero_grad()
            cost.backward()
            optimizer.step()

            total_loss += cost.item()
            num_samples += 1

        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch [{iteration + 1}/{opt.num_iter}], Total Loss: {total_loss:.4f}, Learning Rate: {current_lr:.6f}")

        # An epoch with no usable sample reports a total loss of 0. Without this guard that
        # would become best_loss, and no later (real) improvement could ever be saved.
        if num_samples > 0 and total_loss < best_loss:
            best_loss = total_loss
            torch.save(str_model.state_dict(), best_path)
            print(f"Best model saved with loss: {best_loss:.4f}")

        if (iteration + 1) % opt.valInterval == 0:
            validate(opt)

def validate(opt):
    """Evaluate the module-level ``str_model`` on ``opt.valid_data`` and print loss/accuracy.

    The model is switched to eval mode for the pass. It is always switched back to
    train mode afterwards, even if validation raises, so the training loop never
    continues with the model stuck in eval mode.
    """
    # NOTE(review): keep_ratio_with_pad=True differs from train()'s plain stretch resize;
    # confirm the two preprocessing paths are meant to differ.
    AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=True)
    valid_dataset, _ = hierarchical_dataset(root=opt.valid_data, opt=opt)
    # No shuffling: a fixed order keeps batch composition, and hence the reported loss,
    # comparable between validation runs.
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=opt.workers,
        collate_fn=AlignCollate_valid,
    )

    str_model.eval()
    try:
        with torch.no_grad():
            valid_loss, valid_accuracy = validation(str_model, criterion, valid_loader, converter, opt)[:2]
            print(f"Validation Loss: {valid_loss:.4f}, Validation Accuracy: {valid_accuracy:.4f}")
    finally:
        str_model.train()

# 트레이닝 실행
# Start training (notebook cell / script run). Guarded so that importing this module
# does not launch a full training run.
if __name__ == '__main__':
    train(opt)


Using converter: <class 'lincenseplateocr.utils.AttnLabelConverter'>
Initializing TPS Transformation
TPS 초기화: F=20, I_size=(32, 100), I_r_size=(32, 100), I_channel_num=1


  self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float())  # F+3 x F+3
  self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float())  # n x F+3


TPS Transformation initialized
Initializing Feature Extraction: ResNet
Feature Extraction initialized with output size: 512
Initializing Sequence Modeling with BiLSTM
Sequence Modeling initialized
Initializing Prediction: Attn
Prediction initialized
Fine-tuning mode: Loading pretrained model for fine-tuning
Prediction.attention_cell.i2h.weight - requires_grad: True
Prediction.attention_cell.h2h.weight - requires_grad: True
Prediction.attention_cell.h2h.bias - requires_grad: True
Prediction.attention_cell.score.weight - requires_grad: True
Prediction.attention_cell.rnn.weight_ih - requires_grad: True
Prediction.attention_cell.rnn.weight_hh - requires_grad: True
Prediction.attention_cell.rnn.bias_ih - requires_grad: True
Prediction.attention_cell.rnn.bias_hh - requires_grad: True
Prediction.generator.weight - requires_grad: True
Prediction.generator.bias - requires_grad: True

0: 320x640 1 number_plate, 19.3ms
Speed: 0.0ms preprocess, 19.3ms inference, 5.0ms postprocess per image at shap

KeyboardInterrupt: 