In [None]:
import os
import time
import random
import timm
import torch
import albumentations as A
import pandas as pd
import numpy as np
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Fix random seeds for reproducibility.
SEED = 42
# Setting PYTHONHASHSEED at runtime does not change hash randomization of the
# current interpreter; it only affects child processes started afterwards.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# cudnn.benchmark autotunes convolution algorithms and may select different
# ones between runs, which makes results non-reproducible even with the seeds
# above. Request deterministic kernels instead, since reproducibility is the
# purpose of this section.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
img_size = 224  # side length every transform resizes images to (square)
model_name = 'efficientnet_b3'  # timm architecture name passed to create_model
BATCH_SIZE = 32

# Dataset 클래스 정의
class ImageDataset(Dataset):
    """Map-style dataset of (image, target) pairs listed in a CSV file.

    Args:
        csv: Path to a CSV whose rows have exactly two columns: the image
            file name (relative to ``path``) and its target value.
        path: Directory containing the image files.
        transform: Optional albumentations-style callable, invoked as
            ``transform(image=img)`` and expected to return a dict with an
            ``'image'`` key. When ``None`` the raw HWC uint8 RGB array is
            returned.
    """

    def __init__(self, csv, path, transform=None):
        # Keep the rows as a NumPy array; each row unpacks to (name, target).
        self.df = pd.read_csv(csv).values
        self.path = path
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        name, target = self.df[idx]
        # Use a context manager so the file handle is released promptly, and
        # force 3-channel RGB: grayscale / palette / RGBA files would otherwise
        # yield arrays whose channel count does not match the 3-channel
        # Normalize used by the transform pipelines.
        with Image.open(os.path.join(self.path, name)) as pil_img:
            img = np.array(pil_img.convert('RGB'))
        if self.transform is not None:
            img = self.transform(image=img)['image']
        return img, target

def get_tta_transforms():
    """Return the list of test-time-augmentation (TTA) pipelines.

    Every pipeline resizes to ``img_size`` x ``img_size``, applies at most one
    augmentation, normalizes with the ImageNet mean/std and converts to a
    tensor with ``ToTensorV2``. The list order is: identity, horizontal flip,
    rotation (limit 10), brightness/contrast, Gaussian noise.
    """
    # Each entry builds the single augmentation for one view; ``None`` means
    # the plain (un-augmented) view.
    augmentation_makers = (
        None,
        lambda: A.HorizontalFlip(p=1.0),
        lambda: A.Rotate(limit=10, p=1.0),
        lambda: A.RandomBrightnessContrast(p=1.0),
        lambda: A.GaussNoise(p=1.0),
    )

    def assemble(make_aug):
        # Resize -> (optional augmentation) -> Normalize -> tensor.
        stages = [A.Resize(height=img_size, width=img_size)]
        if make_aug is not None:
            stages.append(make_aug())
        stages.append(A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
        stages.append(ToTensorV2())
        return A.Compose(stages)

    return [assemble(maker) for maker in augmentation_makers]

def tta_inference(model, test_dataset, device, num_classes=17, transforms_list=None):
    """Predict one class per sample using test-time augmentation (TTA).

    For every sample the *raw* image is fetched from ``test_dataset`` (its
    own ``transform`` is temporarily set to ``None`` so the TTA pipelines are
    not stacked on top of an already resized/normalized tensor). Each TTA
    pipeline is applied, all views go through ``model`` as one batch, and the
    per-view softmax probabilities are averaged. The prediction is the argmax
    of that average.

    Args:
        model: Classifier returning logits of shape (batch, classes).
        test_dataset: Dataset with a writable ``transform`` attribute whose
            items are ``(image, target)``. With ``transform=None`` the image
            must be something the TTA pipelines accept, e.g. an HWC array.
        device: Device the stacked views are moved to before the forward pass.
        num_classes: Unused; kept so existing callers keep working.
        transforms_list: Optional list of albumentations-style pipelines to
            use instead of ``get_tta_transforms()``.

    Returns:
        list[int]: Predicted class index for each sample, in dataset order.

    Raises:
        ValueError: If ``transforms_list`` is empty.
    """
    if transforms_list is None:
        transforms_list = get_tta_transforms()
    if not transforms_list:
        raise ValueError("transforms_list must contain at least one transform")

    model.eval()
    predictions = []

    # Disable the dataset's transform so __getitem__ yields the raw image.
    original_transform = test_dataset.transform
    test_dataset.transform = None
    try:
        with torch.no_grad():
            for idx in tqdm(range(len(test_dataset))):
                image, _ = test_dataset[idx]
                if isinstance(image, torch.Tensor):
                    image = image.numpy()

                # Stack all TTA views so each sample needs a single forward
                # pass; in eval mode the per-view outputs are unaffected.
                views = torch.stack(
                    [transform(image=image)['image'] for transform in transforms_list]
                ).to(device)
                probas = torch.softmax(model(views), dim=1)

                # Average probabilities over views, then take the best class.
                avg_probas = probas.mean(dim=0)
                predictions.append(int(avg_probas.argmax().item()))
    finally:
        # Restore the dataset's transform even if inference raises.
        test_dataset.transform = original_transform

    return predictions

In [None]:
# Input / output locations (environment-specific).
test_img_path = '/root/CV_PJT/CV_PJT/data/data/test'
sub_path = '/root/CV_PJT/CV_PJT/data/data/sample_submission.csv'
# NOTE(review): this file name mentions "resnext" and "entire_model", but the
# architecture built below comes from `model_name` ('efficientnet_b3').
# Confirm that the checkpoint actually matches that architecture.
model_path = "/root/CV_PJT/CV_PJT/model/entire_model_resnext_(2).pth"  # saved model path

# Base (non-augmented) transform attached to the test dataset.
tst_transform = A.Compose([
    A.Resize(height=img_size, width=img_size),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

# Test dataset: rows of (file name, placeholder target) from the sample submission.
tst_dataset = ImageDataset(
    sub_path,
    test_img_path,
    transform=tst_transform
)

# Build the architecture; weights come from the checkpoint below.
model = timm.create_model(
    model_name,
    pretrained=False,
    num_classes=17
).to(device)

# Load the saved model.
print("Loading model from:", model_path)
# torch.load unpickles the file, so only load checkpoints you trust. Recent
# PyTorch versions default to weights_only=True; loading a fully pickled
# model may then additionally require weights_only=False.
checkpoint = torch.load(model_path, map_location=device)
if isinstance(checkpoint, nn.Module):
    # A whole model was saved (torch.save(model)); use it directly.
    model = checkpoint.to(device)
else:
    # Either a training checkpoint dict or a bare state_dict.
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)
print("Model loaded successfully!")

# Run inference with TTA.
print("Performing TTA inference...")
preds_list = tta_inference(model, tst_dataset, device)

# Turn the predictions into a submission DataFrame.
pred_df = pd.DataFrame(tst_dataset.df, columns=['ID', 'target'])
pred_df['target'] = preds_list

# Save the results.
output_path = "pred_tta_ensemble.csv"
pred_df.to_csv(output_path, index=False)
print(f"Predictions saved to {output_path}")

# Plot the distribution of predicted classes.
plt.figure(figsize=(8, 6))
sns.countplot(x='target', data=pred_df, palette='Set2')
plt.title('Distribution of Predicted Classes (with TTA)')
plt.xlabel('Class Labels')
plt.ylabel('Count')
plt.xticks(rotation=45)
plt.show()