## Attention-Based Variational Autoencoder

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.preprocessing import StandardScaler
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

class FeatureProcessor:
    """Standardize several feature matrices independently and concatenate them.

    One ``StandardScaler`` is kept per input matrix (``self.scalers``, in the
    same order as the matrices were passed), so data seen later can be
    normalized with the statistics learned from the training data.
    """

    def __init__(self):
        # Fitted scalers, one per feature matrix, in input order.
        self.scalers = []

    def normalize_and_concatenate(self, feature_vectors, fit=None):
        """Standardize each matrix column-wise and concatenate along axis 1.

        Parameters:
        - feature_vectors: iterable of 2-D arrays that all have the same
          number of rows.
        - fit: ``True`` fits a fresh scaler per matrix and replaces any
          previously fitted scalers. ``False`` reuses the already-fitted
          scalers. ``None`` (default) fits on the first call and reuses the
          fitted scalers on every later call, so inference data is normalized
          with training statistics.

        Returns:
        - 2-D array whose columns are the standardized columns of every input
          matrix, in input order.

        Raises:
        - ValueError: when reusing scalers and the number of matrices differs
          from the number of fitted scalers.
        """
        feature_vectors = list(feature_vectors)
        if fit is None:
            fit = not self.scalers

        if fit:
            new_scalers = []
            normalized_vectors = []
            for vector in feature_vectors:
                scaler = StandardScaler()
                normalized_vectors.append(scaler.fit_transform(vector))
                new_scalers.append(scaler)
            # Replace rather than append, so the list always lines up
            # one-to-one with the most recent fitted inputs.
            self.scalers = new_scalers
        else:
            if len(feature_vectors) != len(self.scalers):
                raise ValueError(
                    f"expected {len(self.scalers)} feature matrices "
                    f"(one per fitted scaler), got {len(feature_vectors)}"
                )
            normalized_vectors = [
                scaler.transform(vector)
                for scaler, vector in zip(self.scalers, feature_vectors)
            ]

        return np.concatenate(normalized_vectors, axis=1)

class ResidualBlock(nn.Module):
    """Width-preserving fully connected residual unit computing ``relu(x + F(x))``.

    ``F`` is Linear -> BatchNorm1d -> ReLU -> Dropout(0.3) -> Linear -> BatchNorm1d.
    Every layer keeps ``in_features`` as its width, so the skip connection
    needs no projection.
    """

    def __init__(self, in_features):
        super().__init__()
        width = in_features
        layers = [
            nn.Linear(width, width),
            nn.BatchNorm1d(width),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(width, width),
            nn.BatchNorm1d(width),
        ]
        # The attribute name ``block`` is kept so state_dict keys do not change.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return F.relu(residual + x)

def loss_function(recon_x, x, mu, log_var, kld_weight):
    """Return the weighted VAE objective together with its two components.

    Both terms are averaged over every element rather than summed, so their
    scale does not grow with batch size or dimensionality.

    Returns:
    - (total, reconstruction, kld), where total = reconstruction + kld_weight * kld.
    """
    # Element-wise mean squared reconstruction error.
    recon_term = F.mse_loss(recon_x, x, reduction='mean')

    # Closed-form Gaussian KL term against N(0, I), averaged over all elements.
    kl_integrand = 1 + log_var - mu.pow(2) - log_var.exp()
    kld_term = -0.5 * torch.mean(kl_integrand)

    total = recon_term + kld_weight * kld_term
    return total, recon_term, kld_term

class WarmUpScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Linear learning-rate warm-up followed by linear decay to zero.

    For epoch ``e`` (``self.last_epoch``) and each base learning rate ``lr0``:

    - ``e < warmup_epochs``: ``lr0 * (e + 1) / warmup_epochs``
    - otherwise: ``lr0 * max(0, 1 - (e - warmup_epochs) / (total_epochs - warmup_epochs))``

    The decay factor is clamped at zero, so stepping past ``total_epochs``
    never produces a negative learning rate. The decay span is at least one
    epoch, so ``total_epochs == warmup_epochs`` does not divide by zero.
    """

    def __init__(self, optimizer, warmup_epochs, total_epochs):
        # Must be set before the base constructor, which performs an initial
        # step that calls get_lr().
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        super(WarmUpScheduler, self).__init__(optimizer)

    def get_lr(self):
        epoch = self.last_epoch
        if epoch < self.warmup_epochs:
            return [base_lr * (epoch + 1) / self.warmup_epochs
                    for base_lr in self.base_lrs]
        decay_span = max(1, self.total_epochs - self.warmup_epochs)
        factor = max(0.0, 1 - (epoch - self.warmup_epochs) / decay_span)
        return [base_lr * factor for base_lr in self.base_lrs]

class GroupedLinearAttention(nn.Module):
    """Global feature gating followed by per-group re-weighting and projection.

    The input features are first gated element-wise by a sigmoid MLP over the
    whole feature vector. They are then split into consecutive groups of
    ``group_size`` features. Within each group, a softmax over learned
    per-feature logits re-weights the features before a small group-specific
    bottleneck MLP transforms them. The output has the same width as the input.

    When ``feature_dim`` is not a multiple of ``group_size``, the last group is
    zero-padded. Its softmax is taken over the real features only, and the
    padded positions receive weight zero.
    """

    def __init__(self, feature_dim, group_size=64):
        super().__init__()
        self.group_size = group_size
        # Ceiling division: a partial trailing group still gets a projection.
        self.num_groups = feature_dim // group_size
        if feature_dim % group_size != 0:
            self.num_groups += 1

        # Per-feature attention logits, initialized from a small normal.
        self.feature_weights = nn.Parameter(torch.randn(feature_dim) * 0.02)

        # Separate bottleneck projection for each group.
        self.group_projections = nn.ModuleList([
            nn.Sequential(
                nn.Linear(group_size, group_size // 2),
                nn.ReLU(),
                nn.Linear(group_size // 2, group_size)
            ) for _ in range(self.num_groups)
        ])

        # Global feature context producing element-wise sigmoid gates.
        self.global_context = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, feature_dim)
        )

    def _group_bounds(self, group_idx):
        """Return the half-open [start, end) feature range of group ``group_idx``."""
        start_idx = group_idx * self.group_size
        return start_idx, start_idx + self.group_size

    def _group_softmax(self, group_idx):
        """Softmax of one group's logits over its real (non-padded) features."""
        start_idx, end_idx = self._group_bounds(group_idx)
        return F.softmax(self.feature_weights[start_idx:end_idx], dim=0)

    def forward(self, x):
        original_size = x.size(1)

        # Apply global context gating.
        x = x * torch.sigmoid(self.global_context(x))

        # Zero-pad to a whole number of groups. F.pad keeps x's dtype and device.
        padded_len = self.num_groups * self.group_size
        if padded_len > original_size:
            x = F.pad(x, (0, padded_len - original_size))

        grouped_output = []
        for i, projection in enumerate(self.group_projections):
            start_idx, end_idx = self._group_bounds(i)
            group = x[:, start_idx:end_idx]

            group_weights = self._group_softmax(i)
            missing = self.group_size - group_weights.numel()
            if missing > 0:
                # Trailing partial group: padded positions get zero weight, so
                # the weights broadcast against the full-width padded group.
                group_weights = F.pad(group_weights, (0, missing))

            grouped_output.append(projection(group * group_weights))

        output = torch.cat(grouped_output, dim=1)

        # Drop the padded columns.
        return output[:, :original_size]

    def get_feature_importance(self):
        """Return per-feature importances as a 1-D tensor of length ``feature_dim``.

        Softmax is applied within each group, so the values of each group sum
        to 1. They are therefore directly comparable only within a group.
        """
        return torch.cat([self._group_softmax(i) for i in range(self.num_groups)])

class AttentionVAE(nn.Module):
    """Variational autoencoder with a GroupedLinearAttention front end.

    Encoder: attention, then hidden blocks input_dim -> 3072 -> 3072 -> 2048.
    A linearly projected copy of the attended input is added after the second
    block. Linear heads then produce ``mu`` and ``log_var``.

    Decoder: encoding_dim -> 2048 -> 3072 -> 3072 -> input_dim, with a
    residual skip around the 3072 -> 3072 block.
    """

    def __init__(self, input_dim, encoding_dim=1024, group_size=64):
        super().__init__()

        self.attention = GroupedLinearAttention(input_dim, group_size)

        # Encoder stack: input_dim -> 3072 -> 3072 -> 2048.
        self.encoder_layers = nn.ModuleList([
            self._dense(input_dim, 3072),
            self._dense(3072, 3072),
            self._dense(3072, 2048),
        ])

        # Projects the attended input to width 3072 so it can be added as a
        # skip connection after the second encoder block.
        self.input_projection = nn.Linear(input_dim, 3072)

        # Heads for the posterior mean and log-variance.
        self.fc_mu = nn.Linear(2048, encoding_dim)
        self.fc_var = nn.Linear(2048, encoding_dim)

        # Decoder: three hidden blocks followed by a linear output layer.
        self.decoder = nn.ModuleList([
            self._dense(encoding_dim, 2048),
            self._dense(2048, 3072),
            self._dense(3072, 3072),
            nn.Linear(3072, input_dim),
        ])

        # Initialize weights (applies to every nn.Linear, attention included).
        self.apply(self._init_weights)

    @staticmethod
    def _dense(in_features, out_features, dropout=0.2):
        """Build one Linear -> LayerNorm -> ReLU -> Dropout hidden block."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.LayerNorm(out_features),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def encode(self, x):
        """Return ``(mu, log_var)`` of the approximate posterior for ``x``."""
        x = self.attention(x)

        # Skip connection from the attended input, projected to width 3072.
        identity = self.input_projection(x)

        for i, layer in enumerate(self.encoder_layers):
            x = layer(x)
            if i == 1:  # output of the second (3072 -> 3072) block
                x = x + identity

        return self.fc_mu(x), self.fc_var(x)

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes back to input space.

        The third hidden block (3072 -> 3072) is residual: its input is added
        to its output before the final linear layer.
        """
        for i, layer in enumerate(self.decoder[:-1]):  # exclude output layer
            if i == 2:
                identity = z  # input of the 3072 -> 3072 block
                z = layer(z) + identity
            else:
                z = layer(z)

        return self.decoder[-1](z)

    def forward(self, x):
        """Return ``(reconstruction, mu, log_var)`` for input ``x``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var
    
def _kld_weight(epoch, batch_idx, num_batches, anneal_epochs):
    """Linear KL-annealing weight in [0, 1] for the given training step."""
    if anneal_epochs <= 0:
        return 1.0
    step = epoch * num_batches + batch_idx
    return min(1.0, step / (anneal_epochs * num_batches))


def _train_one_epoch(model, train_loader, optimizer, scaler, device, epoch,
                     kld_anneal_epochs, use_amp):
    """Run one optimization pass; return mean (total, recon, kld) batch losses."""
    model.train()
    num_batches = len(train_loader)
    total_loss = 0.0
    total_recon_loss = 0.0
    total_kld_loss = 0.0

    for batch_idx, data in enumerate(train_loader):
        inputs = data[0].to(device)
        kld_weight = _kld_weight(epoch, batch_idx, num_batches, kld_anneal_epochs)

        # Mixed precision only when running on CUDA.
        with autocast(enabled=use_amp):
            reconstructed, mu, log_var = model(inputs)
            loss, recon_loss, kld_loss = loss_function(
                reconstructed, inputs, mu, log_var, kld_weight)

        optimizer.zero_grad()
        scaler.scale(loss).backward()

        # Unscale before clipping so the norm is computed on the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        total_recon_loss += recon_loss.item()
        total_kld_loss += kld_loss.item()

    return (total_loss / num_batches,
            total_recon_loss / num_batches,
            total_kld_loss / num_batches)


def _report_progress(model, epoch, num_epochs, avg_loss, avg_recon_loss, avg_kld_loss):
    """Print epoch losses and the 10 highest attention feature importances."""
    print(f'Epoch [{epoch+1}/{num_epochs}]')
    print(f'Average Loss: {avg_loss:.4f}')
    print(f'Reconstruction Loss: {avg_recon_loss:.4f}')
    print(f'KLD Loss: {avg_kld_loss:.4f}')

    # Reporting only: no autograd graph is needed for the importances.
    with torch.no_grad():
        importance = model.attention.get_feature_importance()
    sorted_importance, indices = torch.sort(importance, descending=True)

    print("\nTop 10 most important features:")
    for idx, value in zip(indices[:10], sorted_importance[:10]):
        print(f"Feature {idx.item()}: {value.item():.4f}")


def train_attention_vae(model, train_loader, num_epochs, learning_rate, device,
                        kld_anneal_epochs=10, patience=10):
    """Train a VAE with AdamW, cosine LR decay, AMP, KL annealing and early stopping.

    The KL weight rises linearly (per batch) from 0 to 1 over the first
    ``kld_anneal_epochs`` epochs. The objective itself changes during that
    ramp, so early stopping only starts tracking the best loss once annealing
    has finished. Otherwise the growing KL term makes later epochs look worse
    and can trigger a premature stop.

    Parameters:
    - model: VAE returning ``(reconstruction, mu, log_var)`` and exposing
      ``model.attention.get_feature_importance()``.
    - train_loader: DataLoader whose batches have the input tensor at index 0.
    - num_epochs: maximum number of epochs (also the cosine schedule length).
    - learning_rate: initial AdamW learning rate.
    - device: torch.device or device string; AMP is used only for CUDA.
    - kld_anneal_epochs: number of epochs for the KL-weight ramp.
    - patience: epochs without improvement (after annealing) before stopping.

    Raises:
    - ValueError: if ``train_loader`` yields no batches.
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader is empty; nothing to train on")

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    use_amp = torch.device(device).type == 'cuda'
    scaler = GradScaler(enabled=use_amp)

    best_loss = float('inf')
    patience_counter = 0

    for epoch in range(num_epochs):
        avg_loss, avg_recon_loss, avg_kld_loss = _train_one_epoch(
            model, train_loader, optimizer, scaler, device, epoch,
            kld_anneal_epochs, use_amp)
        scheduler.step()

        # Early stopping, only once the loss definition has stopped changing.
        if epoch >= kld_anneal_epochs:
            if avg_loss < best_loss:
                best_loss = avg_loss
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

        if (epoch + 1) % 5 == 0:
            _report_progress(model, epoch, num_epochs,
                             avg_loss, avg_recon_loss, avg_kld_loss)

In [None]:
# Reproducibility: seed the PyTorch and NumPy RNGs.
torch.manual_seed(42)
np.random.seed(42)

# Data preparation: load the four saved feature matrices (assumed row-aligned,
# one row per sample; they are concatenated column-wise below).
feature_files = (
    'saved/default_bagged_resnet_feature.npy',
    'saved/default_bagged_vit_feature.npy',
    'saved/default_bagged_dino_feature.npy',
    'saved/default_bagged_deit_feature.npy',
)
vector1, vector2, vector3, vector4 = [np.load(path) for path in feature_files]

processor = FeatureProcessor()
concatenated_features = processor.normalize_and_concatenate(
    [vector1, vector2, vector3, vector4]
)

data_tensor = torch.FloatTensor(concatenated_features)
dataset = TensorDataset(data_tensor)
batch_size = 64  # Reduced batch size for better stability
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Model initialization.
input_dim = concatenated_features.shape[1]
encoding_dim = 1024
group_size = 128  # Increased group size
model = AttentionVAE(input_dim, encoding_dim, group_size).to(device)

# Training hyperparameters.
num_epochs = 100
learning_rate = 0.001

# Train the model.
train_attention_vae(model, train_loader, num_epochs, learning_rate, device)

In [None]:
def _normalize_with_fitted_scalers(processor, input_vectors):
    """Normalize ``input_vectors`` with the scalers ``processor`` has already fitted.

    The first ``len(input_vectors)`` entries of ``processor.scalers`` are used,
    in order. Falls back to ``processor.normalize_and_concatenate`` (which
    fits) only when the processor has no fitted scalers yet.
    """
    input_vectors = list(input_vectors)
    scalers = processor.scalers
    if not scalers:
        return processor.normalize_and_concatenate(input_vectors)
    if len(scalers) < len(input_vectors):
        raise ValueError(
            f"processor has {len(scalers)} fitted scalers but "
            f"{len(input_vectors)} feature matrices were given"
        )
    normalized = [scaler.transform(vector)
                  for scaler, vector in zip(scalers, input_vectors)]
    return np.concatenate(normalized, axis=1)


def predict(model, input_vectors, processor, device, batch_size=64):
    """
    Run a trained VAE on new data.

    Parameters:
    - model: trained VAE; called as ``model(x) -> (reconstruction, mu, log_var)``
      and must expose ``model.attention.get_feature_importance()``
    - input_vectors: list of numpy arrays [vector1, vector2, vector3, vector4],
      in the same order as used for training
    - processor: the FeatureProcessor instance used for training; its fitted
      scalers are reused so new data gets the training normalization
    - device: 'cuda' or 'cpu' (or a torch.device)
    - batch_size: inference batch size

    Returns:
    - reconstructed_data: model reconstructions, one row per input row
    - encoded_features: posterior means ``mu``, one row per input row
    - feature_importance: per-feature importance from the attention module

    The model's train/eval mode is restored before returning.
    """
    was_training = model.training
    model.eval()
    try:
        features = _normalize_with_fitted_scalers(processor, input_vectors)
        loader = DataLoader(TensorDataset(torch.FloatTensor(features)),
                            batch_size=batch_size)

        reconstructed_chunks = []
        encoded_chunks = []
        with torch.no_grad():
            for (inputs,) in loader:
                inputs = inputs.to(device)
                reconstructed, mu, _ = model(inputs)
                reconstructed_chunks.append(reconstructed.cpu().numpy())
                encoded_chunks.append(mu.cpu().numpy())

            feature_importance = model.attention.get_feature_importance().cpu().numpy()
    finally:
        model.train(was_training)

    reconstructed_data = np.concatenate(reconstructed_chunks, axis=0)
    encoded_features = np.concatenate(encoded_chunks, axis=0)
    return reconstructed_data, encoded_features, feature_importance

def analyze_reconstruction_quality(original_data, reconstructed_data):
    """Print and return overall and per-feature reconstruction MSE.

    Prints the overall MSE, then the 10 features with the lowest MSE and the
    10 with the highest MSE. Both lists are printed in ascending MSE order.

    Returns:
    - (mse, feature_mse): scalar MSE over all elements, and the MSE of each
      column (axis 0 averaged).
    """
    squared_error = (original_data - reconstructed_data) ** 2
    mse = np.mean(squared_error)
    feature_mse = np.mean(squared_error, axis=0)

    ranking = np.argsort(feature_mse)  # ascending MSE
    best_features = ranking[:10]
    worst_features = ranking[-10:]

    print("\nReconstruction Quality Analysis:")
    print(f"Overall MSE: {mse:.4f}")

    sections = (
        ("\nBest Reconstructed Features:", best_features),
        ("\nWorst Reconstructed Features:", worst_features),
    )
    for heading, indices in sections:
        print(heading)
        for idx in indices:
            print(f"Feature {idx}: MSE = {feature_mse[idx]:.4f}")

    return mse, feature_mse

def visualize_feature_importance(feature_importance, top_k=20):
    """
    Plot a bar chart of the ``top_k`` largest importance scores, largest first.

    ``top_k`` is clamped to the number of available features, so asking for
    more bars than there are features no longer fails with a length mismatch.
    Nothing is plotted when the clamped count is zero.
    """
    import matplotlib.pyplot as plt

    feature_importance = np.asarray(feature_importance)
    k = min(top_k, len(feature_importance))
    if k <= 0:
        # Also avoids argsort(...)[-0:], which would select every feature.
        return

    # argsort is ascending: take the last k and reverse so rank 0 is the largest.
    top_indices = np.argsort(feature_importance)[-k:][::-1]
    top_importance = feature_importance[top_indices]

    plt.figure(figsize=(12, 6))
    plt.bar(range(k), top_importance)
    plt.title('Top Feature Importance')
    plt.xlabel('Feature Rank')
    plt.ylabel('Importance Score')
    plt.xticks(range(k), [f'Feature {idx}' for idx in top_indices], rotation=45)
    plt.tight_layout()
    plt.show()

# 사용 예시:
def predict_example():
    """End-to-end usage example: predict, analyze reconstruction, plot importance.

    Relies on the module-level ``model`` and ``processor`` created by the
    training cell (assumed already trained and fitted). Returns
    ``(reconstructed_data, encoded_features, feature_importance)``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Example "new" data: here the same saved feature files are reloaded.
    feature_paths = [
        'saved/default_bagged_resnet_feature.npy',
        'saved/default_bagged_vit_feature.npy',
        'saved/default_bagged_dino_feature.npy',
        'saved/default_bagged_deit_feature.npy',
    ]
    vectors = [np.load(path) for path in feature_paths]

    reconstructed_data, encoded_features, feature_importance = predict(
        model, vectors, processor, device
    )

    # Normalized, concatenated originals to compare the reconstructions against.
    original_data = processor.normalize_and_concatenate(vectors)

    analyze_reconstruction_quality(original_data, reconstructed_data)
    visualize_feature_importance(feature_importance)

    print(f"\nEncoded feature shape: {encoded_features.shape}")
    print(f"Reconstructed data shape: {reconstructed_data.shape}")

    return reconstructed_data, encoded_features, feature_importance

# Run the end-to-end prediction example and keep its outputs at module level.
(reconstructed_data,
 encoded_features,
 importance) = predict_example()