# Setting

In [None]:
import numpy as np
import pandas as pd
import torch
from torch import nn, einsum
import torch.nn.functional as F
from perceiver_pytorch.perceiver_pytorch import exists, default, cache_fn, fourier_encode, PreNorm, FeeodForward, Attention

# Model Architecture

## Fourier Feature Position Encoding

In [None]:
# class FourierFeatureEncoding(nn.Module):
#     '''Provisional implementation generated with ChatGPT for testing on MNIST.
#     Needs to be revised later.'''
#     def __init__(self, num_features, max_freq=10.0):
#         super(FourierFeatureEncoding, self).__init__()
#         self.num_features = num_features
#         self.max_freq = max_freq
#         # 주파수 벡터 생성 (로그 스케일링 권장)
#         self.register_buffer('freq_bands', torch.linspace(1, max_freq, num_features))

#     def forward(self, x, height, width):
#         """
#         x: 입력 텐서 (batch_size, input_len, input_dim)
#         height: 이미지의 높이
#         width: 이미지의 너비
#         """
#         batch_size, input_len, input_dim = x.shape
#         # (input_len,) -> (height, width)
#         assert input_len == height * width, "input_len must be height * width"

#         # 좌표 그리드 생성
#         y_coords = torch.arange(0, height).float() / height  # [0, 1)
#         x_coords = torch.arange(0, width).float() / width    # [0, 1)
#         y_grid, x_grid = torch.meshgrid(y_coords, x_coords)
#         y_grid = y_grid.flatten().to(x.device)  # (input_len,)
#         x_grid = x_grid.flatten().to(x.device)  # (input_len,)

#         # 좌표를 [batch_size, input_len, 2] 형태로 확장
#         coords = torch.stack([x_grid, y_grid], dim=1).unsqueeze(0).repeat(batch_size, 1, 1)  # (batch_size, input_len, 2)

#         # 주파수 변환
#         coords = coords * 2 * torch.pi  # 주기적 함수의 주기를 맞추기 위해 2π를 곱함
#         coords = coords.unsqueeze(-1) * self.freq_bands  # (batch_size, input_len, 2, num_features)
#         coords = coords.view(batch_size, input_len, -1)  # (batch_size, input_len, 2 * num_features)

#         # 사인과 코사인 적용
#         fourier_feats = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)  # (batch_size, input_len, 4 * num_features)

#         return fourier_feats

## Model Elements

In [None]:
class CrossAttention(nn.Module):
    """Single-head scaled dot-product cross-attention.

    Queries are projected from ``latent``; keys and values are projected from
    ``x``. All three projections take ``d_in`` input features, so ``x`` and
    ``latent`` must share the same last dimension.

    Args:
        d_in: Size of the last dimension of both ``x`` and ``latent``.
        d_out_kq: Size of the projected keys and queries.
        d_out_v: Size of the projected values (and of the output features).
    """

    def __init__(self, d_in, d_out_kq, d_out_v):
        super().__init__()
        self.key_proj = nn.Linear(d_in, d_out_kq)
        self.query_proj = nn.Linear(d_in, d_out_kq)
        self.value_proj = nn.Linear(d_in, d_out_v)
        # Dividing the logits by sqrt(d_k) keeps their magnitude independent
        # of the key/query width, so the softmax does not saturate (and the
        # gradients do not vanish) when d_out_kq is large.
        self.scale = d_out_kq ** -0.5
        # Softmax over the last axis, i.e. over the input (key) positions:
        # for every query the attention weights sum to 1.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, latent):
        """Attend from ``latent`` (queries) to ``x`` (keys / values).

        Args:
            x: Tensor of shape ``(..., n_inputs, d_in)``.
            latent: Tensor of shape ``(..., n_latents, d_in)``.

        Returns:
            Tensor of shape ``(..., n_latents, d_out_v)``.
        """
        keys = self.key_proj(x)
        queries = self.query_proj(latent)
        values = self.value_proj(x)

        # (..., n_latents, n_inputs) similarity logits.
        attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale
        attention_probs = self.softmax(attention_scores)

        # Weighted sum of the values for every latent query.
        return torch.matmul(attention_probs, values)

In [None]:
class LatentTransformer(nn.Module):
    """Stack of Transformer encoder layers that repeatedly updates the latent array.

    Subclasses ``nn.Module`` so that the instance is callable and the
    encoder's parameters are registered (and therefore trained / moved to the
    right device together with any parent module).

    Args:
        latent_dim: Feature size of each latent vector; passed to the encoder
            layer as ``d_model``.
        num_heads: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
    """

    def __init__(self, latent_dim, num_heads, num_layers):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=latent_dim, nhead=num_heads)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, latent):
        """Apply the encoder stack.

        Args:
            latent: Tensor of shape ``(batch_size, latent_len, latent_dim)``.

        Returns:
            Tensor of the same shape as ``latent``.
        """
        # The encoder layer is built with the default batch_first=False, so it
        # expects (seq_len, batch_size, latent_dim).
        latent = latent.permute(1, 0, 2)
        latent = self.transformer(latent)
        # Back to (batch_size, latent_len, latent_dim).
        return latent.permute(1, 0, 2)

In [None]:
class Averaging(nn.Module):
    """Collapse the latent array into one vector per sample by averaging.

    Subclasses ``nn.Module`` so that instances are callable like any other
    layer (``Averaging()(latent)``); the module has no parameters.
    """

    def forward(self, latent):
        """Return the mean of ``latent`` over dimension 1.

        Args:
            latent: Tensor with at least two dimensions, e.g.
                ``(batch_size, latent_len, dim)``.

        Returns:
            Tensor with dimension 1 removed, e.g. ``(batch_size, dim)``.
        """
        # The averaged latent vector is what the final logits are computed from.
        return latent.mean(dim=1)

## Define Model

In [None]:
class Perceiver(nn.Module):
    """Minimal Perceiver-style classifier.

    Pipeline: project the inputs to ``embed_dim`` -> cross-attend from a
    learned latent array to the inputs -> refine the latents with a
    Transformer encoder -> average the latents -> linear classification head.

    Args:
        input_dim: Feature size of each input element.
        latent_dim: Number of learned latent vectors (rows of the latent array).
        embed_dim: Feature size of the projected inputs and of every latent
            vector. Must be divisible by ``num_heads`` (requirement of the
            multi-head attention inside the Transformer encoder layer).
        num_heads: Number of attention heads in the latent Transformer.
        num_layers: Number of encoder layers in the latent Transformer.
        num_classes: Number of output classes (size of the logits).
    """

    def __init__(self, input_dim, latent_dim, embed_dim, num_heads, num_layers, num_classes):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, embed_dim)

        # Learned latent array, shared by every sample: (1, latent_dim, embed_dim).
        self.latents = nn.Parameter(torch.randn(1, latent_dim, embed_dim))

        self.cross_attention = CrossAttention(d_in=embed_dim, d_out_kq=embed_dim, d_out_v=embed_dim)
        # The latents leaving the cross-attention have feature size embed_dim,
        # so that (not the number of latents) is the Transformer's model width.
        self.latent_transformer = LatentTransformer(latent_dim=embed_dim, num_heads=num_heads, num_layers=num_layers)

        self.averaging = Averaging()
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x, latent=None, batch_size=None):
        """Compute class logits for a batch of inputs.

        Args:
            x: Tensor of shape ``(batch_size, input_len, input_dim)``.
            latent: Ignored. Accepted only for backward compatibility with the
                original signature; the learned ``self.latents`` is always used.
            batch_size: Optional number of samples in ``x``. Inferred from
                ``x.shape[0]`` when omitted.

        Returns:
            Tensor of shape ``(batch_size, num_classes)`` with the logits.
        """
        if batch_size is None:
            batch_size = x.shape[0]

        x = self.input_proj(x)
        # Every sample in the batch starts from the same learned latent array;
        # expand() broadcasts it over the batch without copying the parameter.
        latent = self.latents.expand(batch_size, -1, -1)
        latent = self.cross_attention(x, latent)
        latent = self.latent_transformer(latent)
        latent_avg = self.averaging(latent)
        return self.classifier(latent_avg)