<a href="https://colab.research.google.com/github/MeatHub/Attention-Is-All-You-Need-Review/blob/main/transformer_%EA%B5%AC%ED%98%842.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip -q install datasets transformers

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import DataLoader, TensorDataset
from datasets import load_dataset
from transformers import AutoTokenizer
from dataclasses import dataclass
from tqdm.auto import tqdm

In [None]:
# Load the IMDB dataset and the tokenizer of 'bert-base-uncased'.
# Later cells read the 'train' and 'test' splits, and the tokenizer's
# vocabulary size sets the default embedding-table size in Config.
dataset = load_dataset('imdb')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [None]:
# Bare expression: shows the loaded dataset object as this cell's output.
dataset

In [None]:
def tokenize(batch):
    """Tokenize the ``text`` entries of a batch with the module-level tokenizer.

    Every example is padded or truncated to exactly 256 tokens so that the
    resulting ``input_ids`` can later be stacked into one rectangular tensor.

    Args:
        batch: Mapping with a ``'text'`` entry holding the raw review strings.

    Returns:
        Whatever the tokenizer returns for the batch (its ``input_ids`` are
        consumed by ``make_loader``).
    """
    texts = batch['text']
    encoded = tokenizer(
        texts,
        padding='max_length',
        truncation=True,
        max_length=256,
    )
    return encoded

# Shuffle both splits with a fixed seed so runs are repeatable.
train_data = dataset['train'].shuffle(seed=42)
test_data = dataset['test'].shuffle(seed=42)

# Tokenize in batches; `tokenize` pads/truncates every review to 256 tokens.
train_tokenized = train_data.map(tokenize, batched=True)
test_tokenized = test_data.map(tokenize, batched=True)

def make_loader(data, batch_size=16, shuffle=True):
    """Wrap tokenized examples in a ``DataLoader`` of ``(input_ids, label)`` batches.

    Args:
        data: Object supporting ``data['input_ids']`` (equal-length sequences
            of token ids) and ``data['label']`` (one integer class per example).
        batch_size: Number of examples per batch.
        shuffle: Whether to reshuffle the examples every epoch. Defaults to
            ``True`` (the previous hard-coded behavior); pass ``False`` for
            evaluation data when a deterministic order is wanted.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` yielding
        ``(ids, labels)`` pairs of int64 tensors.
    """
    # Pin the dtype: nn.Embedding and nn.CrossEntropyLoss both expect int64.
    ids = torch.tensor(data['input_ids'], dtype=torch.long)
    labels = torch.tensor(data['label'], dtype=torch.long)
    return DataLoader(TensorDataset(ids, labels), batch_size=batch_size, shuffle=shuffle)

# Build the batch iterators used by the Trainer (16 examples per batch).
# NOTE(review): both calls rely on make_loader's default shuffling, so the
# test split is shuffled as well; that does not change accuracy but is not needed.
train_loader = make_loader(train_tokenized, batch_size=16)
test_loader = make_loader(test_tokenized, batch_size=16)

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to token embeddings.

    A table of shape ``(1, max_len, d_model)`` is computed once in
    ``__init__`` and registered as a buffer, so it moves with the module
    across devices but is not a trainable parameter.

    Args:
        d_model: Size of the embedding dimension. Even and odd values are
            both supported.
        max_len: Number of positions the table covers. Inputs longer than
            this cannot be encoded (the addition in ``forward`` would not
            broadcast).
        dropout: Dropout probability applied after the encoding is added.
    """

    def __init__(self, d_model, max_len=256, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # div_term = 1 / 10000^(2i/d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        # PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
        pe[:, 0::2] = torch.sin(angles)
        # PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
        # For odd d_model there is one fewer odd column than even columns,
        # so drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + PE)``.

        Args:
            x: Embeddings whose dimension 1 is the sequence axis and whose
                last dimension is ``d_model``.
        """
        # Slice the table to the input's sequence length; it broadcasts over the batch.
        return self.dropout(x + self.pe[:, :x.size(1)])

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: Model (embedding) dimension; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``d_model`` is not a multiple of ``num_heads``.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        # Fail early: otherwise the head split in forward() breaks inside
        # ``view`` with a far less readable shape error.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Queries of shape ``(batch, q_len, d_model)``.
            k: Keys of shape ``(batch, k_len, d_model)``.
            v: Values of shape ``(batch, k_len, d_model)``.
            mask: Optional tensor in which zero/False marks positions that
                must NOT be attended to. Accepted shapes:
                ``(batch, k_len)`` (key padding mask),
                ``(batch, q_len, k_len)``, or anything broadcastable to
                ``(batch, num_heads, q_len, k_len)``. ``None`` (default)
                attends over every key, as before.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch_size = q.size(0)

        # Linear projections producing Q, K, V (query, key, value).
        # Split into heads: (batch, seq_len, num_heads, d_k) -> (batch, num_heads, seq_len, d_k)
        q = self.w_q(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = self.w_k(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = self.w_v(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Scaled Dot-Product Attention: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            mask = mask.to(scores.device)
            if mask.dim() == 2:
                mask = mask[:, None, None, :]   # (batch, 1, 1, k_len)
            elif mask.dim() == 3:
                mask = mask[:, None, :, :]      # (batch, 1, q_len, k_len)
            # Use the most negative finite value rather than -inf so a fully
            # masked row yields a uniform distribution instead of NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn = torch.softmax(scores, dim=-1)

        # MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O
        context = torch.matmul(self.dropout(attn), v).transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.d_k)
        return self.fc(context)

In [None]:
class TransformerEncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a position-wise feed-forward net.

    Each sublayer's output goes through dropout, is added back to the
    sublayer's input (residual connection), and the sum is layer-normalized:
    ``LayerNorm(x + Dropout(Sublayer(x)))``.

    Args:
        d_model: Model (embedding) dimension.
        num_heads: Number of attention heads handed to ``MultiHeadAttention``.
        dropout: Dropout probability used in attention and after each sublayer.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        hidden_dim = d_model * 4  # inner width of the feed-forward network

        # Sublayer 1: multi-head self-attention and its normalization.
        self.mha = MultiHeadAttention(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # Sublayer 2: FFN(x) = max(0, xW1 + b1)W2 + b2
        expand = nn.Linear(d_model, hidden_dim)
        project = nn.Linear(hidden_dim, d_model)
        self.ffn = nn.Sequential(expand, nn.ReLU(), project)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply both sublayers to ``x`` and return a tensor of the same shape."""
        # Self-attention: queries, keys and values all come from x.
        attended = self.dropout(self.mha(x, x, x))
        x = self.norm1(x + attended)

        # Position-wise feed-forward network with its own residual + norm.
        transformed = self.dropout(self.ffn(x))
        return self.norm2(x + transformed)

In [None]:
class TransformerSentimentModel(nn.Module):
    """Transformer-encoder text classifier: embed -> encode -> mean-pool -> linear.

    Args:
        vocab_size: Number of rows in the token embedding table.
        d_model: Embedding / model dimension.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        num_classes: Number of output classes (size of the logit vector).
        dropout: Dropout probability used throughout the model.
        max_len: Longest sequence the positional-encoding table covers.
            Defaults to 256, the value that was previously implied.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers=2, num_classes=2, dropout=0.1, max_len=256):
        super().__init__()
        # Input embedding: map each token id to a d_model-dimensional vector.
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len=max_len, dropout=dropout)

        # N x encoder layers: the stacked "Nx" structure from the paper.
        self.layers = nn.ModuleList([TransformerEncoderLayer(d_model, num_heads, dropout) for _ in range(num_layers)])

        # Final linear layer: produces the classification logits.
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x, attention_mask=None):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: Token ids of shape ``(batch, seq_len)``.
            attention_mask: Optional ``(batch, seq_len)`` tensor with 1/True
                for real tokens and 0/False for padding. When given, the
                pooled representation averages only the real tokens, so
                padding does not dilute short inputs. It affects pooling
                only; the encoder layers still attend over every position.
                ``None`` (default) averages over all positions, as before.
        """
        x = self.embedding(x)
        x = self.pos_encoding(x)
        for layer in self.layers:
            x = layer(x)

        # Global average pooling: (batch, seq_len, d_model) -> (batch, d_model)
        if attention_mask is None:
            pooled = x.mean(dim=1)
        else:
            weights = attention_mask.to(device=x.device, dtype=x.dtype).unsqueeze(-1)
            # clamp guards against division by zero for an all-padding row.
            token_counts = weights.sum(dim=1).clamp(min=1.0)
            pooled = (x * weights).sum(dim=1) / token_counts
        return self.classifier(pooled)

In [None]:
class Trainer:
    """Minimal training / evaluation loop for a classification model.

    Args:
        model: Module mapping a batch of token ids to class logits.
        config: Object exposing ``device``, ``lr`` and ``epochs``.
        train_loader: Iterable of ``(ids, labels)`` batches used for training.
        test_loader: Iterable of ``(ids, labels)`` batches used for evaluation.
    """

    def __init__(self, model, config, train_loader, test_loader):
        self.model = model.to(config.device)
        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=config.lr)
        self.criterion = nn.CrossEntropyLoss()

    def train(self):
        """Train for ``config.epochs`` epochs, printing loss and test accuracy after each."""
        for epoch in range(self.config.epochs):
            self.model.train()  # evaluate() below leaves the model in eval mode
            total_loss = 0.0
            num_batches = 0
            for batch in tqdm(self.train_loader, desc=f"Epoch {epoch+1}"):
                ids, labels = [b.to(self.config.device) for b in batch]
                self.optimizer.zero_grad()
                loss = self.criterion(self.model(ids), labels)
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
                num_batches += 1

            acc = self.evaluate()
            # Count batches ourselves so an empty loader cannot divide by zero.
            avg_loss = total_loss / num_batches if num_batches else 0.0
            print(f"Loss: {avg_loss:.4f} | Test Acc: {acc:.2f}%")

    def evaluate(self):
        """Return accuracy on ``test_loader`` as a percentage in ``[0, 100]``.

        Returns ``0.0`` when the loader yields no samples.
        """
        self.model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for batch in self.test_loader:
                ids, labels = [b.to(self.config.device) for b in batch]
                out = self.model(ids)
                correct += (out.argmax(1) == labels).sum().item()
                total += labels.size(0)
        if total == 0:
            return 0.0
        return 100 * correct / total

def predict_sentiment(text, model, tokenizer, device):
    """Classify one review and print the predicted sentiment with its probability.

    The printed output is in Korean: "리뷰" (review), "결과" (result),
    "긍정" (positive, class 1) and "부정" (negative, any other class).

    Args:
        text: The review text to classify (a single string).
        model: Module mapping token ids to class logits.
        tokenizer: Callable tokenizer; called with ``return_tensors="pt"`` and
            padding/truncation to 256 tokens. Its result must support
            ``.to(device)`` and ``['input_ids']``.
        device: Device the tokenized inputs are moved to before inference.

    Returns:
        None. The result is written to standard output.
    """
    model.eval()
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt", padding='max_length',
                           truncation=True, max_length=256).to(device)
        output = model(inputs['input_ids'])
        prob = torch.softmax(output, dim=-1)
        pred = output.argmax(1).item()
        # These literals were mojibake (UTF-8 bytes decoded as Mac Roman);
        # restored to the intended Korean text and emoji.
        label = "긍정 😊" if pred == 1 else "부정 😡"
        print(f"리뷰: {text}\n결과: {label} ({prob[0][pred].item()*100:.2f}%)")
        print("-" * 50)

In [None]:
@dataclass
class Config:
    """Hyperparameters and runtime settings for the model and the Trainer."""

    # Embedding-table size; the default is read from the module-level
    # tokenizer once, when this class body is executed.
    vocab_size: int = tokenizer.vocab_size
    # Embedding / model dimension.
    d_model: int = 128
    # Attention heads per encoder layer.
    num_heads: int = 8
    # Number of passes over the training loader.
    epochs: int = 12
    # Learning rate passed to the Adam optimizer.
    lr: float = 1e-4
    # CUDA when available, otherwise CPU; decided once at class-definition time.
    device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instantiate the default configuration and build the classifier from it;
# the remaining model arguments (num_layers, num_classes, dropout) keep their defaults.
config = Config()
model = TransformerSentimentModel(config.vocab_size, config.d_model, config.num_heads)

In [None]:
# Train for config.epochs epochs; average loss and test accuracy are printed after each epoch.
trainer = Trainer(model, config, train_loader, test_loader)
trainer.train()

In [None]:
# Hand-written reviews for a quick qualitative check of the trained model:
# one clearly favorable, one clearly unfavorable, one lukewarm.
test_texts = [
    "This movie was a masterpiece. The depth of the characters was incredible.",
    "I hated this film. It was way too long and very boring.",
    "It was okay, but the ending was a bit disappointing."
]

# Print the predicted sentiment and its probability for each review.
for text in test_texts:
    predict_sentiment(text, model, tokenizer, config.device)