In [None]:
# ===============================
# 8. Learning Rate Scheduler (원본 유지)
# ===============================
class WarmupCosineScheduler:
    """Epoch-level learning-rate schedule: linear warmup, then cosine decay.

    For ``epoch < warmup_epochs`` the LR ramps linearly, reaching ``max_lr`` at
    epoch ``warmup_epochs - 1``. Afterwards it follows half a cosine period
    from ``max_lr`` down to ``min_lr`` over the remaining
    ``total_epochs - warmup_epochs`` epochs, and stays at ``min_lr`` for any
    epoch at or beyond ``total_epochs``.

    The LR is written directly into every ``param_group`` of ``optimizer``;
    in this notebook ``step`` is called once per epoch, before training.
    """

    def __init__(self, optimizer, warmup_epochs, total_epochs, max_lr, min_lr):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.max_lr = max_lr
        self.min_lr = min_lr

    def step(self, epoch):
        """Set and return the learning rate for the 0-based ``epoch``.

        Args:
            epoch: 0-based epoch index.

        Returns:
            The learning rate that was assigned to all param groups.
        """
        if epoch < self.warmup_epochs:
            lr = self.max_lr * (epoch + 1) / self.warmup_epochs
        else:
            decay_epochs = self.total_epochs - self.warmup_epochs
            if decay_epochs > 0:
                # Clamp so epochs past total_epochs stay at min_lr instead of
                # climbing back up along the next cosine half-period.
                progress = min((epoch - self.warmup_epochs) / decay_epochs, 1.0)
            else:
                # No decay phase (total_epochs <= warmup_epochs): avoid dividing
                # by zero and hold the floor LR.
                progress = 1.0
            lr = self.min_lr + (self.max_lr - self.min_lr) * 0.5 * (1 + np.cos(np.pi * progress))

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        return lr

In [None]:
# ===============================
# 9. 모델 생성 (원본 유지)
# ===============================
def create_model(model_name, num_classes, pretrained=True):
    """Instantiate a timm model.

    Args:
        model_name: timm model identifier, forwarded unchanged.
        num_classes: Forwarded to timm as ``num_classes``.
        pretrained: Forwarded to timm as ``pretrained``.

    Returns:
        Whatever ``timm.create_model`` returns for these arguments.
    """
    return timm.create_model(model_name, num_classes=num_classes, pretrained=pretrained)

In [None]:
# ===============================
# 10. 학습 함수 (원본 유지)
# ===============================
def train_one_epoch(model, dataloader, criterion, optimizer, device, epoch):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Network to train; put into training mode first.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss applied to ``(outputs, labels)``.
        optimizer: Stepped once per batch.
        device: Target device for each batch.
        epoch: 0-based epoch index, used only in the progress-bar label.

    Returns:
        ``(epoch_loss, epoch_acc)``: the per-batch loss weighted by batch size
        and averaged over all samples (assumes ``criterion`` returns a batch
        mean — TODO confirm), and top-1 accuracy in percent.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    progress = tqdm(dataloader, desc=f'Epoch {epoch+1} - Training')
    for batch_images, batch_labels in progress:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        loss_sum += loss_value * batch_images.size(0)
        n_correct += (logits.max(dim=1).indices == batch_labels).sum().item()
        n_seen += batch_labels.size(0)

        progress.set_postfix({'loss': loss_value, 'acc': 100 * n_correct / n_seen})

    return loss_sum / n_seen, 100 * n_correct / n_seen

In [None]:
# ===============================
# 11. 검증 함수 (원본 유지)
# ===============================
def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` with gradients disabled.

    Args:
        model: Network to evaluate; put into eval mode first.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss applied to ``(outputs, labels)``.
        device: Target device for each batch.

    Returns:
        ``(epoch_loss, epoch_acc)``: batch-size-weighted mean loss and top-1
        accuracy in percent over the whole loader.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(dataloader, desc='Validating'):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_images)

            loss_sum += criterion(logits, batch_labels).item() * batch_images.size(0)
            n_correct += (logits.max(dim=1).indices == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

    return loss_sum / n_seen, 100 * n_correct / n_seen

In [None]:
# ===============================
# 13. [NEW] OCR 모델 학습 함수 (Track 2)
# ===============================
# Anything that is NOT an ASCII letter, ASCII digit, precomposed Hangul
# syllable (가-힣) or whitespace.
_DISALLOWED_CHARS = re.compile(r'[^A-Za-z0-9가-힣\s]')


def clean_text(text):
    """Remove punctuation and other symbols from OCR text.

    Keeps ASCII letters, ASCII digits, precomposed Hangul syllables and
    whitespace. Standalone Hangul jamo (e.g. 'ㄱ', 'ㅏ') lie outside the
    ``가-힣`` range and are removed as well.
    """
    return _DISALLOWED_CHARS.sub('', text)

def get_ocr_text_from_path(reader, img_path):
    """Run EasyOCR on an image file and return its cleaned text.

    Best-effort: any exception raised while reading the image is logged as a
    warning and ``""`` is returned, so one unreadable image does not abort a
    pass over a whole dataset.

    Args:
        reader: An ``easyocr.Reader`` (or anything with a compatible
            ``readtext(path, detail=0)``).
        img_path: Path of the image to read.

    Returns:
        The recognised text fragments joined by single spaces and passed
        through ``clean_text``, or ``""`` if OCR failed.
    """
    import logging

    try:
        # detail=0 makes readtext return only the recognised strings.
        text_list = reader.readtext(img_path, detail=0)
        full_text = ' '.join(text_list)
    except Exception as e:
        # Deliberately broad: OCR is an auxiliary signal, so a failure degrades
        # to "no text" -- but it is surfaced rather than silently swallowed.
        logging.getLogger(__name__).warning("OCR failed for %s: %s", img_path, e)
        return ""
    return clean_text(full_text)

def train_ocr_model():
    """Fit (or load a cached) TF-IDF + logistic-regression classifier on OCR text.

    Runs EasyOCR over every image listed in ``config.train_csv`` and trains a
    text classifier that predicts ``target`` from the extracted text. The
    fitted pipeline is saved to ``'ocr_model.joblib'``. If that file already
    exists, it is loaded instead and no OCR or training is performed.

    Returns:
        ``(ocr_reader, ocr_pipeline)``: the EasyOCR reader (needed again at
        inference time) and the fitted scikit-learn ``Pipeline``.
    """
    ocr_model_path = 'ocr_model.joblib'

    # The reader is returned on both paths, so create it up front (GPU enabled).
    ocr_reader = easyocr.Reader(config.ocr_lang, gpu=True)

    if os.path.exists(ocr_model_path):
        # NOTE: the cache is not invalidated when config or data change; delete
        # the file to force retraining. joblib.load unpickles, so only load
        # files this notebook produced.
        print(f"Loading cached OCR model from {ocr_model_path}...")
        ocr_pipeline = joblib.load(ocr_model_path)
        return ocr_reader, ocr_pipeline

    print("No cached model found. Training new OCR model...")

    # 1. Load the training labels.
    train_df = pd.read_csv(config.train_csv)

    # 2. Extract OCR text from every training image.
    X_texts = []
    y_labels = []
    pbar = tqdm(train_df.itertuples(), total=len(train_df), desc="Running OCR on train data")
    for row in pbar:
        img_path = os.path.join(config.train_img_dir, row.ID)
        X_texts.append(get_ocr_text_from_path(ocr_reader, img_path))
        y_labels.append(row.target)

    n_with_text = sum(1 for t in X_texts if t)
    print(f"OCR extraction complete. Found text in {n_with_text}/{len(X_texts)} images.")

    # 3. TF-IDF over unigrams + bigrams (capped at 10k features to bound memory),
    #    followed by logistic regression. `multi_class` is left at its default:
    #    an explicit 'auto' is the same value and has been deprecated since
    #    scikit-learn 1.5. With solver='liblinear', multiclass targets are fit
    #    one-vs-rest.
    ocr_pipeline = Pipeline([
        ('tfidf', TfidfVectorizer(ngram_range=(1, 2), max_features=10000, token_pattern=r'\b\w+\b')),
        ('clf', LogisticRegression(solver='liblinear', C=1.0, random_state=42))
    ])

    # 4. Fit the text classifier.
    print("Training TF-IDF + Logistic Regression model...")
    ocr_pipeline.fit(X_texts, y_labels)

    # 5. Persist it for later runs.
    joblib.dump(ocr_pipeline, ocr_model_path)
    print(f"OCR model saved to {ocr_model_path}")

    return ocr_reader, ocr_pipeline

In [None]:
# ===============================
# 14. 학습 파이프라인 (원본 유지)
# (Vision 모델만 학습)
# ===============================
def _make_vision_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the batch/worker settings from ``config``."""
    return DataLoader(
        dataset, batch_size=config.batch_size, shuffle=shuffle,
        num_workers=config.num_workers, pin_memory=True
    )


def _print_stage_banner(title):
    """Print ``title`` between two 50-character '=' rules, preceded by a blank line."""
    print(f'\n{"="*50}')
    print(title)
    print(f'{"="*50}')


def _train_vision_stage(train_loader, valid_loader, checkpoint_path, stage_label, saved_label, wandb_prefix):
    """Train a freshly created vision model and checkpoint its best-validation weights.

    Builds a new model, CrossEntropyLoss, AdamW and WarmupCosineScheduler from
    ``config``, trains for ``config.epochs`` epochs, and writes
    ``model.state_dict()`` to ``checkpoint_path`` whenever validation accuracy
    improves. The first epoch is always saved, so the checkpoint exists even if
    accuracy never rises above 0.

    Args:
        train_loader: Loader passed to ``train_one_epoch``.
        valid_loader: Loader passed to ``validate``.
        checkpoint_path: File the best weights are written to.
        stage_label: Prefix for per-epoch console lines (e.g. 'Stage 1').
        saved_label: Name used in the "Best ... model saved!" message.
        wandb_prefix: Namespace for wandb metric keys (e.g. 'stage1').

    Returns:
        ``(model, best_val_acc)``. The model holds the weights of the LAST
        epoch, not necessarily the best one. ``best_val_acc`` is in percent,
        or 0.0 if ``config.epochs`` is 0.
    """
    model = create_model(config.model_name, config.num_classes, pretrained=True).to(config.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=config.max_lr, weight_decay=1e-4)
    scheduler = WarmupCosineScheduler(
        optimizer, warmup_epochs=config.warmup_epochs, total_epochs=config.epochs,
        max_lr=config.max_lr, min_lr=config.min_lr
    )

    best_val_acc = None
    for epoch in range(config.epochs):
        current_lr = scheduler.step(epoch)
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, config.device, epoch)
        val_loss, val_acc = validate(model, valid_loader, criterion, config.device)

        print(f'{stage_label} - Epoch {epoch+1}/{config.epochs}')
        print(f'LR: {current_lr:.6f}')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Valid Loss: {val_loss:.4f}, Valid Acc: {val_acc:.2f}%')

        if config.use_wandb:
            wandb.log({
                f'{wandb_prefix}/train_loss': train_loss, f'{wandb_prefix}/train_acc': train_acc,
                f'{wandb_prefix}/val_loss': val_loss, f'{wandb_prefix}/val_acc': val_acc,
                f'{wandb_prefix}/learning_rate': current_lr, f'epoch_{wandb_prefix}': epoch + 1
            })

        # `best_val_acc is None` forces a save on the first epoch so the
        # checkpoint always exists for later torch.load calls.
        if best_val_acc is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            print(f'✓ Best {saved_label} model saved! (Val Acc: {val_acc:.2f}%)')

    return model, (0.0 if best_val_acc is None else best_val_acc)


def run_training_pipeline():
    """Train the vision classifier and optionally refine it with pseudo-labels.

    Stage 1: stratified train/valid split of ``config.train_csv``
    (``random_state=42``). A fresh model is trained, and its best-validation
    weights are saved to ``best_model_stage1.pth``.

    Stage 2 (only if ``config.use_pseudo_labeling``): those weights are
    reloaded and passed to ``generate_pseudo_labels`` for the test images in
    ``config.submission_csv``, using ``config.pseudo_label_threshold``.

    Stage 3: a fresh model is trained on the labeled train split plus the
    pseudo-labels. It is validated on the same held-out split and
    checkpointed to ``best_model_stage2_pseudo.pth``.

    When wandb is enabled, ``final_best_val_acc`` is logged on every exit path.
    """
    # --- Stage 1: initial training on the labeled data ---
    _print_stage_banner('STAGE 1: Initial Training on Labeled Data (Vision Track)')

    train_df = pd.read_csv(config.train_csv)

    train_fold_df, valid_fold_df = train_test_split(
        train_df,
        test_size=config.val_split_ratio,
        random_state=42,
        stratify=train_df['target']
    )
    train_fold_df = train_fold_df.reset_index(drop=True)
    valid_fold_df = valid_fold_df.reset_index(drop=True)

    print(f"Stage 1: Train data: {len(train_fold_df)}, Valid data: {len(valid_fold_df)}")

    train_loader = _make_vision_loader(
        DocumentDataset(
            train_fold_df,
            config.train_img_dir,
            config.test_img_dir,
            transform=get_train_transform(config.img_size)
        ),
        shuffle=True
    )
    valid_loader = _make_vision_loader(
        DocumentDataset(
            valid_fold_df,
            config.train_img_dir,
            config.test_img_dir,
            transform=get_valid_transform(config.img_size)
        ),
        shuffle=False
    )

    model, best_val_acc = _train_vision_stage(
        train_loader, valid_loader, 'best_model_stage1.pth',
        stage_label='Stage 1', saved_label='Stage 1', wandb_prefix='stage1'
    )
    print(f'\nStage 1 Best Validation Accuracy: {best_val_acc:.2f}%')

    if not config.use_pseudo_labeling:
        print("Pseudo-labeling disabled. Using Stage 1 model for inference.")
        if config.use_wandb:
            wandb.log({'final_best_val_acc': best_val_acc})
        return

    # --- Stage 2: generate pseudo-labels with the best Stage 1 weights ---
    _print_stage_banner('STAGE 2: Generating Pseudo-Labels (Vision Track)')

    model.load_state_dict(torch.load('best_model_stage1.pth', map_location=config.device))

    submission_df = pd.read_csv(config.submission_csv)
    pseudo_gen_loader = _make_vision_loader(
        TestDataset(
            submission_df,
            config.test_img_dir,
            transform=get_valid_transform(config.img_size)
        ),
        shuffle=False
    )

    pseudo_label_df = generate_pseudo_labels(
        model, pseudo_gen_loader, config.device, config.pseudo_label_threshold
    )
    # The Stage 1 model is no longer needed; free its memory before Stage 3
    # builds a second model.
    del model

    print(f'Generated {len(pseudo_label_df)} pseudo-labels with threshold >= {config.pseudo_label_threshold}')
    if config.use_wandb:
        wandb.log({'pseudo_label_count': len(pseudo_label_df)})

    if len(pseudo_label_df) == 0:
        print("No pseudo-labels generated. Skipping re-training.")
        # NOTE(review): a best_model_stage2_pseudo.pth left over from an earlier
        # run is not removed here, and the prediction step prefers that file
        # whenever it exists.
        if config.use_wandb:
            wandb.log({'final_best_val_acc': best_val_acc})
        return

    # --- Stage 3: retrain from scratch on labeled + pseudo-labeled data ---
    _print_stage_banner('STAGE 3: Re-training with Pseudo-Labels (Vision Track)')

    combined_train_df = pd.concat([train_fold_df, pseudo_label_df], ignore_index=True)
    print(f"Stage 3: Combined Train data: {len(combined_train_df)}, Valid data: {len(valid_fold_df)}")

    train_loader_pseudo = _make_vision_loader(
        DocumentDataset(
            combined_train_df,
            config.train_img_dir,
            config.test_img_dir,
            transform=get_train_transform(config.img_size)
        ),
        shuffle=True
    )

    # Validation still uses the untouched held-out split from Stage 1.
    _, best_val_acc_pseudo = _train_vision_stage(
        train_loader_pseudo, valid_loader, 'best_model_stage2_pseudo.pth',
        stage_label='Stage 3', saved_label='Stage 3 (Pseudo)', wandb_prefix='stage3'
    )

    print(f'\nStage 3 Best Validation Accuracy: {best_val_acc_pseudo:.2f}%')
    if config.use_wandb:
        wandb.log({'final_best_val_acc': best_val_acc_pseudo})

In [None]:
# ===============================
# 15. 테스트 예측 (TTA + OCR 앙상블 적용, 수정됨)
# ===============================
def predict_with_tta_ocr_ensemble(ocr_reader, ocr_pipeline):
    """Predict test labels by blending vision-TTA and OCR-text probabilities.

    For every ID in ``config.submission_csv``:
      * Vision track: average the softmax outputs of the checkpointed vision
        model over every transform from ``get_tta_transforms``.
      * Text track: OCR the original image once and score the text with
        ``ocr_pipeline.predict_proba``.
    The two distributions are combined as
    ``ensemble_vision_weight * vision + ensemble_ocr_weight * ocr``. The argmax
    is written to ``submission_ensemble.csv`` (and uploaded to wandb if
    enabled).

    Args:
        ocr_reader: EasyOCR reader used for the text track.
        ocr_pipeline: Fitted scikit-learn pipeline exposing ``predict_proba``
            and ``classes_``.
    """
    submission_df = pd.read_csv(config.submission_csv)

    # --- Track 1: load the vision model ---
    tta_transforms = get_tta_transforms(config.img_size)

    vision_model = create_model(config.model_name, config.num_classes, pretrained=False)

    # NOTE(review): the pseudo-label checkpoint is preferred whenever the file
    # exists, even if it came from an earlier run or scored below Stage 1.
    if config.use_pseudo_labeling and os.path.exists('best_model_stage2_pseudo.pth'):
        model_path = 'best_model_stage2_pseudo.pth'
        print("Loading Stage 3 (Pseudo-Label) model for Vision Track.")
    else:
        model_path = 'best_model_stage1.pth'
        print("Loading Stage 1 model for Vision Track.")

    # The model is still on CPU here (it is moved below), so load the
    # checkpoint onto CPU. This works whichever device it was saved from.
    vision_model.load_state_dict(torch.load(model_path, map_location='cpu'))
    vision_model = vision_model.to(config.device)
    vision_model.eval()

    # --- Track 2: the OCR model is already loaded as 'ocr_pipeline' ---
    print("Using pre-trained OCR model for Text Track.")

    # predict_proba columns follow ocr_pipeline.classes_, which equals
    # 0..num_classes-1 only if every class appeared in the OCR training data.
    # Scatter each column into its class slot so the text distribution lines
    # up with the vision softmax.
    # Assumes integer class labels in [0, num_classes) -- TODO confirm.
    ocr_class_idx = np.asarray(ocr_pipeline.classes_).astype(int)

    all_predictions = []

    pbar = tqdm(submission_df['ID'], desc='Predicting with 2-Track Ensemble')
    for img_name in pbar:
        img_path = os.path.join(config.test_img_dir, img_name)

        # --- Track 1: vision TTA prediction ---
        # The context manager closes the file handle that Image.open keeps;
        # convert() returns an independent image.
        with Image.open(img_path) as raw_image:
            image_np = np.array(raw_image.convert('RGB'))

        tta_preds_for_image = []
        with torch.no_grad():
            for transform in tta_transforms:
                augmented = transform(image=image_np)
                img_tensor = augmented['image'].unsqueeze(0).to(config.device)
                probs = torch.softmax(vision_model(img_tensor), dim=1)
                tta_preds_for_image.append(probs.cpu().numpy())

        # Mean over TTA views gives (1, num_classes); squeeze to (num_classes,).
        vision_probs = np.squeeze(np.mean(tta_preds_for_image, axis=0))

        # --- Track 2: OCR prediction (original image only, no TTA) ---
        ocr_text = get_ocr_text_from_path(ocr_reader, img_path)
        ocr_probs = np.zeros(config.num_classes, dtype=np.float64)
        ocr_probs[ocr_class_idx] = ocr_pipeline.predict_proba([ocr_text])[0]

        # --- Ensemble: weighted sum of the two distributions ---
        final_probs = (config.ensemble_vision_weight * vision_probs) + \
                      (config.ensemble_ocr_weight * ocr_probs)

        all_predictions.append(int(np.argmax(final_probs)))

    # Save the results.
    submission_df['target'] = all_predictions
    submission_filename = 'submission_ensemble.csv'
    submission_df.to_csv(submission_filename, index=False)
    print(f'Submission file saved as {submission_filename}!')

    if config.use_wandb:
        wandb.save(submission_filename)