In [1]:
import numpy as np
import pandas as pd
import torch
from torch import nn, einsum
import torch.nn.functional as F
#from perceiver_pytorch.perceiver_pytorch import exists, default, cache_fn, fourier_encode, PreNorm, FeeodForward, Attention

class CrossAttention(nn.Module):
    def __init__(self, d_in, d_out_kq, d_out_v):
        super(CrossAttention, self).__init__()
        self.key_proj = nn.Linear(d_in, d_out_kq)
        self.query_proj = nn.Linear(d_in, d_out_kq)
        self.value_proj = nn.Linear(d_in, d_out_v)
        self.softmax = nn.Softmax(dim=-1)           # 이게 뭐지

    def forward(self, x, latent):
        keys = self.key_proj(x)
        queries = self.query_proj(latent)
        values = self.value_proj(x)

        attention_scores = torch.matmul(queries, keys.transpose(-2, -1))
        attention_probs = self.softmax(attention_scores)

        attended_values = torch.matmul(attention_probs, values)
        return attended_values
    
class LatentTransformer(nn.Module):
    def __init__(self, latent_dim, num_heads, num_layers, embed_dim):
        super(LatentTransformer, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads) # trasformer 로 latent array 반복적으로 update
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, latent):
        latent = latent.permute(1,0,2)  # Transformer는 (seq_len, batch_size, latent_dim) 형식으로 데이터 받음.
        latent = self.transformer(latent)
        return latent.permute(1,0,2)    # 이걸 다시 (batch_size, latent_len, latent_dim으로 바꿈)
    
class Averaging(nn.Module):
    def forward(self, latent):
        return latent.mean(dim=1)   # latent vector를 평균내서 최종 logits 계산
    
class Perceiver(nn.Module):
    def __init__(self, input_dim, latent_dim, embed_dim, num_heads, num_layers, num_classes):
        super(Perceiver, self).__init__()
        self.input_proj = nn.Linear(input_dim, embed_dim)

        self.latents = nn.Parameter(torch.randn(1, latent_dim, embed_dim))

        self.cross_attention = CrossAttention(d_in=embed_dim, d_out_kq=embed_dim, d_out_v=embed_dim)
        self.latent_transformer = LatentTransformer(latent_dim=latent_dim, num_heads=num_heads, 
                                                    num_layers=num_layers, embed_dim=embed_dim)
        
        self.averaging = Averaging()
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        batch_size = x.size(0)
        x = self.input_proj(x)
        latent = self.latents.repeat(batch_size, 1, 1)    # batch 학습 시, batch 내 각각 샘플에 서로 다른 독립적인 latent값을 제공
        latent = self.cross_attention(x, latent)
        latent = self.latent_transformer(latent)
        latent_avg = self.averaging(latent)
        logits = self.classifier(latent_avg)
        return logits

In [4]:
import torch
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from transformers import BertTokenizer
import pandas as pd
import numpy as np

# 1. 데이터 로드
data_path = "/home/youlee/perceiver/perceiver/code/data/MNIST/raw/SMSSpamCollection"
data = pd.read_csv(data_path, sep="\t", header=None, names=["label", "text"])
data.head()

Unnamed: 0,label,text
0,ham,"Go until jurong point, crazy.. Available only ..."
1,ham,Ok lar... Joking wif u oni...
2,spam,Free entry in 2 a wkly comp to win FA Cup fina...
3,ham,U dun say so early hor... U c already then say...
4,ham,"Nah I don't think he goes to usf, he lives aro..."


In [5]:
# 2. 데이터 전처리
# 라벨 변환
label_encoder = LabelEncoder()
data['label'] = label_encoder.fit_transform(data['label'])  # spam: 1, ham: 0

# 토큰화
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
max_seq_len = 128

def tokenize_text(text):
    tokens = tokenizer(text, padding='max_length', max_length=max_seq_len, truncation=True, return_tensors="pt")
    return tokens['input_ids'].squeeze(0), tokens['attention_mask'].squeeze(0)

data['tokenized'] = data['text'].apply(tokenize_text)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [6]:
# 데이터셋 준비
class SMSDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        input_ids, attention_mask = self.data.iloc[idx]['tokenized']
        label = self.data.iloc[idx]['label']
        return input_ids, attention_mask, label

In [7]:
# Train/Test Split
train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)
train_dataset = SMSDataset(train_data)
test_dataset = SMSDataset(test_data)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [8]:
# 3. 모델 초기화
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

input_dim = max_seq_len
latent_dim = 16
embed_dim = 64
num_heads = 4
num_layers = 2
num_classes = 2

model = Perceiver(input_dim=input_dim, latent_dim=latent_dim, embed_dim=embed_dim,
                  num_heads=num_heads, num_layers=num_layers, num_classes=num_classes).to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)



In [9]:
# 4. 학습 루프
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    for input_ids, attention_mask, labels in loader:
        input_ids, labels = input_ids.to(device, dtype=torch.float32), labels.to(device)
        optimizer.zero_grad()
        outputs = model(input_ids)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


In [10]:
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    with torch.no_grad():
        for input_ids, attention_mask, labels in loader:
            input_ids, labels = input_ids.to(device, dtype=torch.float32), labels.to(device)
            outputs = model(input_ids)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            preds = torch.argmax(outputs, dim=1)
            correct += (preds == labels).sum().item()
    accuracy = correct / len(loader.dataset)
    return total_loss / len(loader), accuracy

In [11]:
# Train the model
epochs = 10
for epoch in range(epochs):
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}")

Epoch 1/10, Train Loss: 0.4113, Test Loss: 0.3961, Test Accuracy: 0.8664
Epoch 2/10, Train Loss: 0.3999, Test Loss: 0.3945, Test Accuracy: 0.8664
Epoch 3/10, Train Loss: 0.3975, Test Loss: 0.3956, Test Accuracy: 0.8664
Epoch 4/10, Train Loss: 0.3955, Test Loss: 0.3948, Test Accuracy: 0.8664
Epoch 5/10, Train Loss: 0.3970, Test Loss: 0.3938, Test Accuracy: 0.8664
Epoch 6/10, Train Loss: 0.3963, Test Loss: 0.3953, Test Accuracy: 0.8664
Epoch 7/10, Train Loss: 0.3960, Test Loss: 0.3944, Test Accuracy: 0.8664
Epoch 8/10, Train Loss: 0.3957, Test Loss: 0.3946, Test Accuracy: 0.8664
Epoch 9/10, Train Loss: 0.3939, Test Loss: 0.3932, Test Accuracy: 0.8664
Epoch 10/10, Train Loss: 0.3944, Test Loss: 0.3941, Test Accuracy: 0.8664
