In [None]:
# EfficientNet-B5 + OCR 2-Track 앙상블 모델
#
# 수정 사항:
# 1. (Track 1) Vision: 기존 EfficientNet (Pseudo-Labeling + TTA)
# 2. (Track 2) OCR: EasyOCR + TF-IDF/LogisticRegression
# 3. Ensemble: 최종 예측 시 Track 1과 Track 2의 확률을 가중 평균

import os
import random
import re # NEW: OCR 텍스트 클리닝
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import wandb
from tqdm import tqdm
import warnings

# NEW: OCR 및 텍스트 분류기 라이브러리
import easyocr
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
import joblib # NEW: OCR 모델 저장을 위해

warnings.filterwarnings('ignore')

In [None]:
# ===============================
# 4. 데이터셋 클래스 (원본 유지)
# (Vision 모델 학습용)
# ===============================

# 학습/검증용 데이터셋
class DocumentDataset(Dataset):
    """Labeled image dataset used to train/validate the vision model.

    Every row of ``df`` provides an image file name (column ``ID``) and its
    class label (column ``target``). The image file is looked up in
    ``train_img_dir`` first and, if absent there, in ``test_img_dir``.

    NOTE(review): the test-directory fallback presumably exists so that
    pseudo-labeled test images can be mixed into training — confirm against
    the caller.

    Args:
        df: DataFrame with at least the columns ``ID`` and ``target``.
        train_img_dir: Directory searched first for the image file.
        test_img_dir: Directory searched when the file is not in
            ``train_img_dir``.
        transform: Optional callable invoked as ``transform(image=array)``
            that returns a mapping with the result under the ``'image'`` key.
            When ``None`` the raw array is returned.
    """

    def __init__(self, df: pd.DataFrame, train_img_dir: str, test_img_dir: str, transform=None):
        self.df = df
        self.train_img_dir = train_img_dir
        self.test_img_dir = test_img_dir
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of samples (rows of ``df``)."""
        return len(self.df)

    def __getitem__(self, idx: int):
        """Load the ``idx``-th sample.

        Args:
            idx: Positional row index into ``df``.

        Returns:
            A ``(image, label)`` tuple. ``image`` is the transformed image,
            or the RGB pixel array when no transform is set; ``label`` is
            the row's ``target`` value.

        Raises:
            FileNotFoundError: If the image exists in neither directory.
        """
        row = self.df.iloc[idx]
        img_id = row['ID']

        # Probe the train directory first, then fall back to the test one.
        candidates = (
            os.path.join(self.train_img_dir, img_id),
            os.path.join(self.test_img_dir, img_id),
        )
        img_path = next((path for path in candidates if os.path.exists(path)), None)
        if img_path is None:
            raise FileNotFoundError(f"Image not found in train or test dir: {img_id}")

        # The context manager guarantees the underlying file is closed even
        # when decoding fails part-way (e.g. a truncated image).
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))

        # Explicit None check: a container-like transform (one that defines
        # __len__) would be falsy when empty and silently skipped otherwise.
        if self.transform is not None:
            image = self.transform(image=image)['image']

        return image, row['target']

# 테스트(예측)용 데이터셋
class TestDataset(Dataset):
    """Unlabeled image dataset used at prediction time.

    Every row of ``df`` provides an image file name (column ``ID``) that is
    resolved relative to ``img_dir``.

    Args:
        df: DataFrame with at least the column ``ID``.
        img_dir: Directory containing the image files.
        transform: Optional callable invoked as ``transform(image=array)``
            that returns a mapping with the result under the ``'image'`` key.
            When ``None`` the raw array is returned.
    """

    def __init__(self, df: pd.DataFrame, img_dir: str, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of samples (rows of ``df``)."""
        return len(self.df)

    def __getitem__(self, idx: int):
        """Load the ``idx``-th image.

        Args:
            idx: Positional row index into ``df``.

        Returns:
            A ``(image, img_id)`` tuple. ``image`` is the transformed image,
            or the RGB pixel array when no transform is set; ``img_id`` is
            the row's ``ID`` value so predictions can be matched to files.
        """
        img_id = self.df.iloc[idx]['ID']
        img_path = os.path.join(self.img_dir, img_id)

        # The context manager guarantees the underlying file is closed even
        # when decoding fails part-way (e.g. a truncated image).
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))

        # Explicit None check: a container-like transform (one that defines
        # __len__) would be falsy when empty and silently skipped otherwise.
        if self.transform is not None:
            image = self.transform(image=image)['image']

        return image, img_id

In [None]:
# ===============================
# 5. 데이터 증강 (Train) (원본 유지)
# ===============================
def get_train_transform(img_size):
    """Build the augmentation pipeline applied to training images.

    The image is resized to ``img_size`` x ``img_size``, passed through a
    chain of randomly applied degradations (noise/blur, orientation changes,
    color changes, geometric distortions, occlusion), normalized with the
    mean/std given below and finally converted with ``ToTensorV2``.

    Args:
        img_size: Target height and width passed to ``A.Resize``.

    Returns:
        An ``A.Compose`` holding the steps in the order listed below.
    """
    steps = [A.Resize(height=img_size, width=img_size)]

    # Image degradation: at most one of noise / blur is applied per sample.
    steps.append(
        A.OneOf(
            [
                A.GaussNoise(var_limit=(10.0, 800.0), p=0.75),
                A.GaussianBlur(blur_limit=(1, 7), p=0.5),
            ],
            p=0.75,
        )
    )

    # Orientation changes.
    steps.extend([
        A.RandomRotate90(p=0.5),
        A.HorizontalFlip(p=0.75),
        A.Rotate(limit=30, p=0.75),
        A.Transpose(p=0.5),
    ])

    # Color / contrast changes.
    steps.extend([
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.5),
        A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
        A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.5),
    ])

    # Geometric distortion, occlusion and motion blur.
    steps.extend([
        A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=30, p=0.25),
        A.ElasticTransform(alpha=1, sigma=30, alpha_affine=30, p=0.5),
        A.OpticalDistortion(p=0.5),
        A.CoarseDropout(max_holes=6, max_height=32, max_width=32, p=0.5),
        A.MotionBlur(blur_limit=5, p=0.5),
    ])

    # Normalization and tensor conversion always come last.
    steps.extend([
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

    return A.Compose(steps)

In [None]:
# ===============================
# 6. 검증/테스트 증강 (원본 유지)
# ===============================
def get_valid_transform(img_size):
    """Build the augmentation-free pipeline for validation/test images.

    Args:
        img_size: Target height and width passed to ``A.Resize``.

    Returns:
        An ``A.Compose`` that resizes, normalizes with the mean/std given
        below and converts the image with ``ToTensorV2``.
    """
    resize = A.Resize(height=img_size, width=img_size)
    normalize = A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    to_tensor = ToTensorV2()
    return A.Compose([resize, normalize, to_tensor])

In [None]:
# ===============================
# 7. TTA (Test Time Augmentation) (원본 유지)
# ===============================
def get_tta_transforms(img_size):
    """Return the test-time-augmentation (TTA) pipelines.

    Nine pipelines are produced. Each one resizes the image to
    ``img_size`` x ``img_size``, applies at most one extra augmentation
    (always, ``p=1.0``), normalizes with the mean/std given below and
    converts the image with ``ToTensorV2``. The variants target rotation,
    flipping, noise, blur, brightness/contrast and shift/scale changes.

    NOTE(review): variants 6-9 are configured with parameter ranges, so
    their output presumably differs between calls — confirm this is the
    intended TTA behavior.

    Args:
        img_size: Target height and width passed to ``A.Resize``.

    Returns:
        A list of nine ``A.Compose`` objects in the order documented below.
    """
    # One factory per variant; ``None`` means "no extra augmentation".
    # Factories (instead of instances) keep every transform object private
    # to its own pipeline and created right after that pipeline's Resize.
    extra_factories = (
        # 1. Original image
        None,
        # 2. Horizontal flip
        lambda: A.HorizontalFlip(p=1.0),
        # 3. Vertical flip
        lambda: A.VerticalFlip(p=1.0),
        # 4. Fixed 45-degree rotation
        lambda: A.Rotate(limit=(45, 45), p=1.0),
        # 5. Fixed 135-degree rotation
        lambda: A.Rotate(limit=(135, 135), p=1.0),
        # 6. Gaussian noise
        lambda: A.GaussNoise(var_limit=(30.0, 200.0), p=1.0),
        # 7. Gaussian blur
        lambda: A.GaussianBlur(blur_limit=(3, 7), p=1.0),
        # 8. Brightness / contrast change
        lambda: A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1.0),
        # 9. Shift / scale (no rotation)
        lambda: A.ShiftScaleRotate(shift_limit=0.08, scale_limit=0.08, rotate_limit=0, p=1.0),
    )

    pipelines = []
    for make_extra in extra_factories:
        steps = [A.Resize(height=img_size, width=img_size)]
        if make_extra is not None:
            steps.append(make_extra())
        steps.append(A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
        steps.append(ToTensorV2())
        pipelines.append(A.Compose(steps))

    return pipelines