# Setup

In [79]:
# !pip install datasets
# !pip install transformers

In [80]:
import glob
import os
import pandas as pd
import numpy as np
import torch
import math

from PIL import Image
from sklearn.preprocessing import LabelEncoder

from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import random_split
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt

# Load Data

In [81]:
# Reference coordinate grid for a 224x224 image, normalised to [0, 1] on both axes.
H, W = 224, 224
y_coord = torch.linspace(0, 1, steps=H)
x_coord = torch.linspace(0, 1, steps=W)
# indexing="ij" keeps the historical (row, column) layout and avoids the
# warning newer PyTorch releases emit when the argument is omitted.
grid_y, grid_x = torch.meshgrid(y_coord, x_coord, indexing="ij")  # each: (H, W)

# Stack into (H, W, 2) with the last axis ordered as (x, y).
coords = torch.stack((grid_x, grid_y), dim=-1)

In [82]:
def fourier_encode(coords, num_bands=4, max_freq=2.0):  # num_bands=16 is the more common choice
    """Append sine/cosine Fourier features to a grid of coordinates.

    Args:
        coords: tensor of shape (..., 2); the notebook passes (H, W, 2) with
            the last axis holding (x, y).
        num_bands: number of frequencies, spaced linearly from 1.0 to
            ``max_freq``.
        max_freq: highest frequency used.

    Returns:
        Tensor of shape (..., 2 + 2 * num_bands * 2). Along the last axis the
        layout is: the raw coordinates, then the sines for every band, then
        the cosines for every band.
    """
    # Build the bands next to the input so coords living on another device
    # (or in float64) do not hit a device mismatch or lose precision.
    band_dtype = coords.dtype if coords.is_floating_point() else torch.get_default_dtype()
    freq_bands = torch.linspace(
        1.0, max_freq, steps=num_bands, dtype=band_dtype, device=coords.device
    )

    coords_expanded = coords.unsqueeze(-2)  # (..., 1, 2)

    # Shared argument of sin and cos: (..., num_bands, 2).
    phase = (2.0 * math.pi) * coords_expanded * freq_bands.unsqueeze(-1)

    fourier_features = torch.cat(
        [coords_expanded, torch.sin(phase), torch.cos(phase)], dim=-2
    )  # (..., 1 + 2 * num_bands, 2)

    # Merge the last two axes into a single feature axis.
    return fourier_features.flatten(start_dim=-2)

In [83]:
class ImageDataset(Dataset):
    """Folder-per-sample image dataset whose items are pixels tagged with Fourier position features.

    Every sub-directory of ``root_dir`` must contain exactly one ``*.jpg``
    file and a ``label.txt`` holding the class name. Class names are mapped
    to integer ids with a ``LabelEncoder`` exposed as ``self.label_encoder``.

    Each item is ``(combined, label)``:
        combined: (H * W, 3 + 18) float tensor; per pixel the RGB values
            followed by the output of ``fourier_encode`` with 4 bands.
        label: scalar ``torch.long`` tensor with the encoded class id.
    """

    def __init__(self, root_dir, transform=None):
        """Scan ``root_dir`` and fit the label encoder.

        Args:
            root_dir: directory containing one sub-directory per sample.
            transform: optional callable applied to the RGB PIL image. When
                it is omitted the image is converted to a (3, H, W) float
                tensor scaled to [0, 1].

        Raises:
            ValueError: a sample directory does not hold exactly one JPG.
            FileNotFoundError: a sample directory has no ``label.txt``.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.data = []
        self.label_encoder = LabelEncoder()
        # Single-entry cache: positional features depend only on (H, W), and
        # keeping just the latest size bounds memory for mixed-size inputs.
        self._cached_size = None
        self._cached_features = None

        image_paths = []
        labels = []
        # sorted(): os.listdir order is arbitrary, and the sample order must
        # be stable for index-based splits to be reproducible.
        for folder in sorted(os.listdir(root_dir)):
            folder_path = os.path.join(root_dir, folder)
            if not os.path.isdir(folder_path):
                continue

            image_files = glob.glob(os.path.join(folder_path, "*.jpg"))
            if len(image_files) != 1:
                raise ValueError(f"폴더 {folder}에 JPG 파일이 하나가 아닙니다.")

            label_path = os.path.join(folder_path, "label.txt")
            if not os.path.exists(label_path):
                raise FileNotFoundError(f"폴더 {folder}에 label.txt가 없습니다.")

            with open(label_path, "r", encoding="utf-8") as f:
                labels.append(f.read().strip())
            image_paths.append(image_files[0])

        # Fit once, then encode every label in a single call instead of
        # invoking the encoder separately for each sample.
        self.label_encoder.fit(labels)
        encoded = self.label_encoder.transform(labels) if labels else []
        self.data = list(zip(image_paths, encoded))

    def __len__(self):
        """Return the number of samples found under ``root_dir``."""
        return len(self.data)

    @staticmethod
    def _pil_to_tensor(image):
        """Convert an RGB image exposing ``size``/``tobytes()`` to a (3, H, W) float tensor in [0, 1]."""
        width, height = image.size
        # bytearray gives torch a writable buffer; .float() below copies it.
        pixels = torch.frombuffer(bytearray(image.tobytes()), dtype=torch.uint8)
        return pixels.reshape(height, width, 3).permute(2, 0, 1).float().div(255.0)

    def _position_features(self, height, width):
        """Return the (H, W, 18) Fourier features for a ``height`` x ``width`` grid."""
        size = (height, width)
        if self._cached_size != size:
            y_coord = torch.linspace(0, 1, steps=height)
            x_coord = torch.linspace(0, 1, steps=width)
            # indexing="ij" makes the (row, column) layout explicit.
            grid_y, grid_x = torch.meshgrid(y_coord, x_coord, indexing="ij")
            coords = torch.stack((grid_x, grid_y), dim=-1)  # (H, W, 2) as (x, y)
            self._cached_features = fourier_encode(coords, num_bands=4, max_freq=2.0)
            self._cached_size = size
        return self._cached_features

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(combined, label)`` as described on the class."""
        image_path, label = self.data[idx]
        # The context manager releases the file handle as soon as the pixel
        # data has been decoded by convert().
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        if self.transform:
            image = self.transform(image)
        if isinstance(image, Image.Image):
            # No tensor-producing transform was supplied; without this the
            # shape unpacking below fails on a PIL image.
            image = self._pil_to_tensor(image)

        _, H, W = image.shape
        fourier_features = self._position_features(H, W)  # (H, W, 18)

        # Channels-last so pixels and positional features share the final axis.
        image_2d = image.permute(1, 2, 0)  # (H, W, 3)
        combined = torch.cat([image_2d, fourier_features], dim=-1)  # (H, W, 21)

        # Flatten the spatial grid into a token sequence: (H * W, channels).
        combined = combined.reshape(-1, combined.shape[-1])

        label = torch.tensor(label, dtype=torch.long)
        return combined, label

In [84]:
# Preprocessing pipeline applied to every image:
#   1. resize to the fixed 224x224 grid,
#   2. convert to a float tensor,
#   3. normalise each channel with mean 0.5 and std 0.5.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocess_steps)

In [85]:
# Dataset location. Set the N24NEWS_IMAGE_DIR environment variable to run on
# another machine; the original local path is kept as the fallback.
data_dir = os.environ.get('N24NEWS_IMAGE_DIR', '/home/Minju/Perceiver/data/n24news/image')
dataset = ImageDataset(root_dir=data_dir, transform=transform)

train_ratio = 0.8
train_size = int(len(dataset) * train_ratio)
valid_size = len(dataset) - train_size

# Seed the split so the same samples land in train/validation on every run;
# an unseeded split makes validation accuracy incomparable between runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, valid_dataset = random_split(
    dataset, [train_size, valid_size], generator=split_generator
)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=False)

print(f'train:{train_size}, valid:{valid_size}')

train:39190, valid:9798


In [86]:
# Fetch the first sample once: every train_dataset[0] access re-reads the
# image from disk and rebuilds its features.
first_sample = train_dataset[0]
# (number of samples, items per sample, tokens per image, features per token)
len(train_dataset), len(first_sample), len(first_sample[0]), len(first_sample[0][0])

(39190, 2, 50176, 21)

In [87]:
# Class names learned by the dataset's label encoder; their count sets the
# size of the classifier output.
class_names = dataset.label_encoder.classes_
NUM_CLASSES = len(class_names)
print("NUM_CLASSES:", NUM_CLASSES)
print("classes:", class_names)

NUM_CLASSES: 24
classes: ['Art & Design' 'Automobiles' 'Books' 'Dance' 'Economy' 'Education'
 'Fashion & Style' 'Food' 'Global Business' 'Health' 'Media' 'Movies'
 'Music' 'Opinion' 'Real Estate' 'Science' 'Sports' 'Style' 'Technology'
 'Television' 'Theater' 'Travel' 'Well' 'Your Money']


# Prepare Data

## Fourier Positional Encoding

# Model

In [88]:
class Perceiver(nn.Module):
    """Small Perceiver-style classifier.

    A learned latent array attends once to the projected input tokens
    (cross-attention), is refined by a stack of Transformer encoder layers
    (latent self-attention), is averaged over the latent axis and finally
    mapped to class logits.

    Args:
        input_dim: number of features per input token.
        latent_dim: width of the latent vectors and of the projected tokens.
        latent_size: number of latent vectors.
        num_classes: number of output logits.
        num_latent_blocks: number of latent self-attention layers.
        num_heads: attention heads used by the cross- and self-attention
            modules; must divide ``latent_dim``. Defaults to 8, the value
            that used to be hard-coded.
    """

    def __init__(self, input_dim, latent_dim, latent_size, num_classes, num_latent_blocks, num_heads=8):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(latent_size, latent_dim))
        # Projects input features to the latent width (input_dim -> latent_dim).
        self.input_projection = nn.Linear(input_dim, latent_dim)
        self.cross_attention = nn.MultiheadAttention(embed_dim=latent_dim, num_heads=num_heads)
        self.self_attention = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=latent_dim, nhead=num_heads)
            for _ in range(num_latent_blocks)
        ])
        self.output_layer = nn.Linear(latent_dim, num_classes)

    def forward(self, x):
        """Classify a batch of token sequences.

        Args:
            x: tensor of shape (batch, tokens, input_dim).

        Returns:
            Logits of shape (batch, num_classes).

        Raises:
            ValueError: ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape (batch, tokens, features), got {tuple(x.shape)}"
            )
        batch_size = x.size(0)

        # The attention modules are sequence-first, hence (tokens, batch, latent_dim).
        tokens = self.input_projection(x).permute(1, 0, 2)
        # Share the learned latents across the batch: (latent_size, batch, latent_dim).
        latents = self.latents.unsqueeze(1).expand(-1, batch_size, -1)

        # The attention weights are never used, so skip materialising the
        # (batch, latent_size, tokens) matrix.
        latents, _ = self.cross_attention(latents, tokens, tokens, need_weights=False)

        for layer in self.self_attention:
            latents = layer(latents)

        # Average over the latent axis -> (batch, latent_dim), then classify.
        return self.output_layer(latents.mean(dim=0))


In [89]:
# Read the per-token feature width from a real sample instead of hard-coding
# it, so the model stays in sync with the dataset's feature encoding.
input_dim = dataset[0][0].shape[-1]
model = Perceiver(input_dim=input_dim, latent_dim=128, latent_size=64, num_classes=NUM_CLASSES, num_latent_blocks=4)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

## Train

In [90]:
# Training and evaluation helpers
def train_model(model, train_loader, valid_loader, criterion, optimizer, epochs):
    """Train ``model`` and record the loss and validation accuracy per epoch.

    Args:
        model: module to optimise.
        train_loader: iterable of ``(inputs, labels)`` training batches.
        valid_loader: batches passed to ``evaluate_model`` after each epoch.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepping the model's parameters.
        epochs: number of passes over ``train_loader``.

    Returns:
        ``(train_losses, val_accuracies)``: two lists with one entry per
        epoch, the mean batch loss and the validation accuracy in percent.
    """
    # Follow whatever device the model already lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    train_losses = []
    val_accuracies = []

    for epoch in range(epochs):
        # Validation at the end of an epoch puts the model in eval mode, so
        # train mode has to be re-enabled every epoch; setting it once before
        # the loop leaves dropout off from the second epoch onwards.
        model.train()
        total_loss = 0.0
        num_batches = 0
        for images, labels in train_loader:
            if device is not None:
                images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        # max(..., 1) keeps an empty loader from dividing by zero.
        train_losses.append(total_loss / max(num_batches, 1))
        accuracy = evaluate_model(model, valid_loader, log_results=False)
        val_accuracies.append(accuracy)

        print(f"Epoch {epoch+1}/{epochs}, Loss: {train_losses[-1]:.4f}, Val Accuracy: {val_accuracies[-1]:.2f}%")

    return train_losses, val_accuracies

def evaluate_model(model, valid_loader, log_results=True):
    """Compute the classification accuracy of ``model`` on ``valid_loader``.

    Args:
        model: module producing per-class scores of shape (batch, classes).
        valid_loader: iterable of ``(inputs, labels)`` batches.
        log_results: when true, print the accuracy.

    Returns:
        Accuracy in percent; 0.0 when the loader yields no samples.
    """
    # Follow whatever device the model already lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    # Remember the caller's mode so evaluating in the middle of training
    # does not silently leave the model in eval mode.
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in valid_loader:
                if device is not None:
                    images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    # Guard against an empty loader instead of dividing by zero.
    accuracy = 100 * correct / total if total else 0.0
    if log_results:
        print(f"Test Accuracy: {accuracy:.2f}%")
    return accuracy


In [91]:
# Train the model
epochs = 10
train_losses, val_accuracies = train_model(
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=epochs,
)

Epoch 1/10, Loss: 3.1082, Val Accuracy: 5.37%
Epoch 2/10, Loss: 3.1109, Val Accuracy: 5.35%
Epoch 3/10, Loss: 3.1151, Val Accuracy: 5.01%


KeyboardInterrupt: 

In [None]:
def plot_learning_curve(train_losses, val_accuracies):
    """Plot training loss and validation accuracy against the epoch number.

    The loss uses the left y-axis (blue) and the accuracy a twin right
    y-axis (orange); both share the epoch x-axis, which starts at 1.

    Args:
        train_losses: sequence with one loss value per epoch.
        val_accuracies: sequence with one accuracy (percent) per epoch.
    """
    fig, ax1 = plt.subplots()

    ax1.set_xlabel('Epochs')
    ax1.set_ylabel('Train Loss', color='tab:blue')
    ax1.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss', color='tab:blue')
    ax1.tick_params(axis='y', labelcolor='tab:blue')

    ax2 = ax1.twinx()
    ax2.set_ylabel('Validation Accuracy (%)', color='tab:orange')
    ax2.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy', color='tab:orange')
    ax2.tick_params(axis='y', labelcolor='tab:orange')

    # Set the title on the axes, and before tight_layout(), so the layout
    # pass reserves room for it instead of relying on pyplot's current axes.
    ax1.set_title('Learning Curve')
    fig.tight_layout()
    plt.show()

# Draw the curves for the histories returned by the training run.
plot_learning_curve(train_losses=train_losses, val_accuracies=val_accuracies)
