In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchtext.datasets import IMDB
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from tqdm import tqdm  # 학습 진행 상황을 보여주기 위한 라이브러리

# 하이퍼파라미터 설정
BATCH_SIZE = 64
EMBEDDING_DIM = 128
HIDDEN_DIM = 256
NUM_CLASSES = 2
NUM_EPOCHS = 5
LEARNING_RATE = 1e-4  # 학습 안정성을 위해 학습률을 낮춤
MAX_VOCAB_SIZE = 20000  # 어휘 사전 크기 증가
MAX_SEQ_LEN = 512  # 시퀀스 길이 조정

# 디바이스 설정 (GPU 사용 가능 시 사용)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 데이터 준비
tokenizer = get_tokenizer('basic_english')

def yield_tokens(data_iter):
    for label, text in data_iter:
        yield tokenizer(text)

# 어휘 사전 구축
train_iter, test_iter = IMDB()
vocab = build_vocab_from_iterator(yield_tokens(train_iter), max_tokens=MAX_VOCAB_SIZE, specials=["<pad>", "<unk>"])
vocab.set_default_index(vocab["<unk>"])

# 데이터셋을 다시 로드하여 사용 (중요)
train_iter, test_iter = IMDB()

def text_pipeline(text):
    tokens = tokenizer(text)
    token_ids = [vocab[token] for token in tokens]
    # 시퀀스 길이 조정
    if len(token_ids) > MAX_SEQ_LEN:
        token_ids = token_ids[:MAX_SEQ_LEN]
    else:
        token_ids += [vocab["<pad>"]] * (MAX_SEQ_LEN - len(token_ids))
    return torch.tensor(token_ids, dtype=torch.long)

def label_pipeline(label):
    return torch.tensor(1 if label == 'pos' else 0, dtype=torch.long)

# 데이터셋 생성
class IMDBDataset(torch.utils.data.Dataset):
    def __init__(self, data_iter):
        self.data = []
        for label, text in data_iter:
            try:
                text_tensor = text_pipeline(text)
                label_tensor = label_pipeline(label)
                self.data.append((text_tensor, label_tensor))
            except Exception as e:
                # 데이터 처리 중 발생하는 오류 처리
                print(f"데이터 처리 오류: {e}")
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return self.data[idx]

train_dataset = IMDBDataset(train_iter)
test_dataset = IMDBDataset(test_iter)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# 모델 정의
class CNNTransformerModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes):
        super(CNNTransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=vocab["<pad>"])
        self.position_embedding = nn.Embedding(MAX_SEQ_LEN, embedding_dim)

        # CNN 인코더
        self.cnn_encoder = nn.Conv1d(in_channels=embedding_dim, out_channels=embedding_dim, kernel_size=3, padding=1, stride=2)
        self.cnn_encoder_residual = nn.Conv1d(embedding_dim, embedding_dim, kernel_size=1, stride=2)

        # 트랜스포머 인코더 레이어
        encoder_layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=8, dropout=0.1, activation='relu')
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2, norm=nn.LayerNorm(embedding_dim))

        # CNN 디코더
        self.cnn_decoder = nn.ConvTranspose1d(in_channels=embedding_dim, out_channels=embedding_dim, kernel_size=3, padding=1, stride=2, output_padding=1)
        self.cnn_decoder_residual = nn.ConvTranspose1d(embedding_dim, embedding_dim, kernel_size=1, stride=2, output_padding=1)

        # 출력 레이어
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        batch_size, seq_len = x.size()
        # 임베딩 및 포지셔널 인코딩 추가
        x = self.embedding(x)  # [batch_size, seq_len, embedding_dim]
        positions = torch.arange(0, seq_len).unsqueeze(0).expand(batch_size, seq_len).to(device)
        x = x + self.position_embedding(positions)
        x = x.permute(0, 2, 1)  # [batch_size, embedding_dim, seq_len]

        # CNN 인코더와 잔차 연결
        residual = self.cnn_encoder_residual(x)
        x = self.cnn_encoder(x)
        x = nn.ReLU()(x + residual)

        x = x.permute(2, 0, 1)  # [seq_len', batch_size, embedding_dim]

        # 패딩 마스크 생성
        src_key_padding_mask = (x.abs().sum(dim=2) == 0).transpose(0, 1)

        # 트랜스포머 인코더
        x = self.transformer_encoder(x, src_key_padding_mask=src_key_padding_mask)

        x = x.permute(1, 2, 0)  # [batch_size, embedding_dim, seq_len']

        # CNN 디코더와 잔차 연결
        residual = self.cnn_decoder_residual(x)
        x = self.cnn_decoder(x)
        x = nn.ReLU()(x + residual)

        # 글로벌 평균 풀링
        x = x.mean(dim=2)  # [batch_size, embedding_dim]

        x = self.dropout(x)
        logits = self.fc(x)  # [batch_size, num_classes]
        return logits

# 모델 초기화
model = CNNTransformerModel(len(vocab), EMBEDDING_DIM, NUM_CLASSES).to(device)

# 손실 함수와 옵티마이저 설정
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# 학습 루프
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f"에포크 {epoch+1}/{NUM_EPOCHS}")
    for texts, labels in progress_bar:
        texts, labels = texts.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(texts)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())
    avg_loss = total_loss / len(train_loader)
    print(f"에포크 [{epoch+1}/{NUM_EPOCHS}], 평균 손실: {avg_loss:.4f}")

    # 매 에포크 후 평가
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for texts, labels in test_loader:
            texts, labels = texts.to(device), labels.to(device)
            outputs = model(texts)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = correct / total
    print(f"에포크 {epoch+1} 후 테스트 정확도: {accuracy * 100:.2f}%\n")

# 모델 저장
torch.save(model.state_dict(), 'cnn_transformer_model.pth')


################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################

에포크 1/5:   8%|▊         | 30/391 [00:43<08:22,  1.39s/it, loss=0.0267]