# Setup

In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.manifold import TSNE
import pandas as pd
from transformers import BertTokenizer, BertModel
import os
from PIL import Image

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"사용 중인 디바이스: {device}")

# Fashion-MNIST class id -> class name; the name doubles as the text caption.
label_map = dict(enumerate([
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]))

num_classes = len(label_map)

# Inverse lookup: class name -> class id.
text_to_idx = {name: class_id for class_id, name in label_map.items()}

# Convert images to tensors, then shift/scale each pixel as (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Download (if not already present) the train and test splits.
train_dataset = datasets.FashionMNIST(root='./Fashion_MNIST_dataset', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./Fashion_MNIST_dataset', train=False, download=True, transform=transform)

# Image-only loaders; only the training split is shuffled.
batch_size = 64
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=batch_size)


사용 중인 디바이스: cuda


# txt2img Dataset

In [24]:
class TextImageDataset(Dataset):
    """Pair every image of a labelled dataset with a tokenized caption of its class.

    Each item is ``(img, input_ids, attention_mask, label)`` where the caption is
    ``label_map[label]`` tokenized by ``tokenizer`` and padded/truncated to exactly
    ``max_length`` tokens, so ``input_ids`` and ``attention_mask`` have shape
    ``[max_length]``.
    """

    def __init__(self, dataset, label_map, tokenizer, max_length=32):
        """
        Args:
            dataset: indexable dataset yielding ``(img, label)`` pairs (e.g. FashionMNIST).
            label_map: dict mapping a label id to its caption text.
            tokenizer: tokenizer exposing ``encode_plus`` (e.g. a BERT tokenizer).
            max_length: number of tokens each caption is padded/truncated to.
        """
        self.dataset = dataset
        self.label_map = label_map
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Only len(label_map) distinct captions exist, so each one is tokenized once
        # and reused instead of being re-tokenized on every __getitem__ call.
        self._encoding_cache = {}

    def __len__(self):
        return len(self.dataset)

    def _encode_caption(self, key):
        """Return the cached ``(input_ids, attention_mask)`` pair for ``label_map[key]``."""
        cached = self._encoding_cache.get(key)
        if cached is None:
            encoding = self.tokenizer.encode_plus(
                self.label_map[key],
                add_special_tokens=True,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt',
            )
            # squeeze(0) drops only the leading batch dim; a bare squeeze() would also
            # collapse the sequence dim when max_length == 1.
            cached = (
                encoding['input_ids'].squeeze(0),       # [max_length]
                encoding['attention_mask'].squeeze(0),  # [max_length]
            )
            self._encoding_cache[key] = cached
        return cached

    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        # Tensor / numpy scalar labels hash by identity or type, so unwrap them to a
        # plain Python value before using them as a label_map key.
        key = label.item() if hasattr(label, "item") else label
        input_ids, attention_mask = self._encode_caption(key)
        # Clone so callers cannot mutate the cached tensors in place.
        return img, input_ids.clone(), attention_mask.clone(), label
    
# Pretrained BERT tokenizer shared by the train and test caption datasets.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Wrap each Fashion-MNIST split so every item also carries a tokenized caption
# padded/truncated to 32 tokens.
train_dataset_text = TextImageDataset(train_dataset, label_map, tokenizer=tokenizer, max_length=32)
test_dataset_text = TextImageDataset(test_dataset, label_map, tokenizer=tokenizer, max_length=32)

# Caption-aware loaders; as with the image-only loaders, only training is shuffled.
train_loader_text = DataLoader(dataset=train_dataset_text, shuffle=True, batch_size=batch_size)
test_loader_text = DataLoader(dataset=test_dataset_text, shuffle=False, batch_size=batch_size)


# Models

In [25]:
class TextEncoder(nn.Module):
    """Map tokenized captions to a non-negative ``embedding_dim``-sized vector.

    The pooler output of a pretrained ``bert-base-uncased`` model is linearly
    projected to ``embedding_dim`` and passed through a ReLU.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        hidden_size = self.bert.config.hidden_size
        self.fc = nn.Linear(hidden_size, embedding_dim)
        self.relu = nn.ReLU()

    def forward(self, input_ids, attention_mask):
        """Return a ``[batch_size, embedding_dim]`` embedding for a batch of token ids."""
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return self.relu(self.fc(bert_out.pooler_output))

class Generator(nn.Module):
    """Conditional generator: (noise z, text embedding) -> image with values in [-1, 1].

    The concatenated input is projected to a ``128 x (img_size // 4) x (img_size // 4)``
    feature map, upsampled twice by 2x, and mapped to ``img_channels`` with a final
    Tanh, giving an ``[batch, img_channels, img_size, img_size]`` output.

    Raises:
        ValueError: if ``img_size`` is not divisible by 4 (two 2x upsamplings from
            ``img_size // 4`` could not reproduce ``img_size``).
    """

    def __init__(self, latent_dim, embedding_dim, img_channels, img_size):
        super().__init__()
        if img_size % 4 != 0:
            raise ValueError(f"img_size must be divisible by 4, got {img_size}")
        self.init_size = img_size // 4  # spatial size before the two 2x upsamplings
        self.l1 = nn.Sequential(nn.Linear(latent_dim + embedding_dim, 128 * self.init_size ** 2))

        # BatchNorm2d's second positional argument is `eps`, not `momentum`; passing 0.8
        # there adds 0.8 to the batch variance and largely disables normalisation, so the
        # default eps is used. (eps is not stored in state_dict; checkpoints still load.)
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, img_channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, z, text_embedding):
        """Generate images from noise ``z`` [B, latent_dim] and ``text_embedding`` [B, embedding_dim]."""
        gen_input = torch.cat((z, text_embedding), dim=1)
        out = self.l1(gen_input)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        return self.conv_blocks(out)

class Discriminator(nn.Module):
    """Text-conditional discriminator returning P(real) for an (image, text embedding) pair.

    Four stride-2 conv blocks reduce the image to a ``128 x ds x ds`` feature map
    (``ds == 2`` for 28x28 inputs). The text embedding is projected to the same number
    of features, the two are multiplied element-wise, and a linear + sigmoid head
    produces a ``[batch, 1]`` score.
    """

    def __init__(self, img_channels, img_size, embedding_dim):
        super().__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Conv(3x3, stride 2, pad 1) -> LeakyReLU -> Dropout2d [-> BatchNorm]."""
            block = [
                nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25),
            ]
            if bn:
                # BatchNorm2d's second positional argument is `eps`, not `momentum`;
                # eps=0.8 would swamp the batch variance, so the default eps is used.
                block.append(nn.BatchNorm2d(out_filters))
            return block

        self.model = nn.Sequential(
            *discriminator_block(img_channels, 16, bn=False),  # 28x28 -> 14x14
            *discriminator_block(16, 32),                      # 14x14 -> 7x7
            *discriminator_block(32, 64),                      # 7x7 -> 4x4 (rounds up)
            *discriminator_block(64, 128),                     # 4x4 -> 2x2
        )

        # Each block maps a side of length s to ceil(s / 2); derive the final side from
        # img_size instead of hard-coding 2, which only holds for 28x28 inputs.
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.ds_size = ds_size
        feature_dim = 128 * ds_size ** 2  # 512 for 28x28

        # Project the text embedding onto the flattened image-feature space.
        self.label_embedding = nn.Sequential(
            nn.Linear(embedding_dim, feature_dim),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.adv_layer = nn.Sequential(
            nn.Linear(feature_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, text_embedding):
        """Score ``img`` [B, C, H, W] against ``text_embedding`` [B, embedding_dim]; returns [B, 1]."""
        features = self.model(img)  # [B, 128, ds, ds]
        features = features.reshape(features.size(0), -1)  # [B, 128 * ds * ds]
        label_emb = self.label_embedding(text_embedding)  # [B, 128 * ds * ds]
        # Condition on the text by element-wise multiplication of the two feature vectors.
        combined = features * label_emb
        return self.adv_layer(combined)


# Training

In [26]:
# ---- Hyperparameters ----
latent_dim = 100      # length of the noise vector z fed to the generator
embedding_dim = 256   # size of the projected text embedding
img_size = 28         # Fashion-MNIST images are 28x28 pixels
img_channels = 1      # Fashion-MNIST images are single-channel (grayscale)

# ---- Models ----
# Construction order is kept fixed (text encoder, generator, discriminator) so the
# random-number stream used for weight initialisation stays the same.
text_encoder = TextEncoder(embedding_dim).to(device)
generator = Generator(latent_dim, embedding_dim, img_channels, img_size).to(device)
discriminator = Discriminator(img_channels, img_size, embedding_dim).to(device)

# Binary cross-entropy on the discriminator's sigmoid output.
adversarial_loss = nn.BCELoss()

# ---- Optimizers ----
# The text encoder is trained jointly with the generator, so both share optimizer_G.
# NOTE(review): this fine-tunes every BERT weight at lr=2e-4, well above typical BERT
# fine-tuning rates; consider freezing BERT or giving it a smaller learning rate.
lr = 0.0002
optimizer_G = optim.Adam([*generator.parameters(), *text_encoder.parameters()], lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(params=discriminator.parameters(), lr=lr, betas=(0.5, 0.999))


In [27]:
import torchvision
import torchvision.utils as vutils

# ---- Training configuration ----
epochs = 100
sample_interval = 10      # plot generated samples every `sample_interval` epochs (and after epoch 1)
caption_max_length = 32   # must match the max_length given to TextImageDataset

# Mean generator / discriminator loss of each completed epoch.
G_losses = []
D_losses = []


def _show_generated_samples(epoch, num_samples=16, nrow=4):
    """Plot a grid of generated images conditioned on captions cycling through label_map.

    Captions are tokenized with the same tokenizer as the training data, so the text
    encoder receives real token ids shaped [num_samples, caption_max_length] (feeding
    raw class indices as token ids would be both the wrong shape and meaningless).
    The generator and text encoder are switched to eval mode while sampling, so
    BatchNorm running statistics are not updated and BERT dropout is off; their
    previous modes are restored afterwards.
    """
    captions = [label_map[i % num_classes] for i in range(num_samples)]
    encoding = tokenizer(
        captions,
        add_special_tokens=True,
        max_length=caption_max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    previous_modes = (generator.training, text_encoder.training)
    generator.eval()
    text_encoder.eval()
    try:
        with torch.no_grad():
            sample_text_embedding = text_encoder(
                encoding['input_ids'].to(device),
                encoding['attention_mask'].to(device),
            )
            z_sample = torch.randn(num_samples, latent_dim, device=device)
            samples = generator(z_sample, sample_text_embedding).cpu()
    finally:
        generator.train(previous_modes[0])
        text_encoder.train(previous_modes[1])

    grid = vutils.make_grid(samples, nrow=nrow, normalize=True)
    plt.figure(figsize=(8, 8))
    plt.imshow(grid.permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
    plt.title(f"Epoch {epoch}")
    plt.axis('off')
    plt.show()


for epoch in range(1, epochs + 1):
    # Re-enable training behaviour (dropout, BatchNorm batch statistics) in case an
    # earlier cell left a model in eval mode.
    text_encoder.train()
    generator.train()
    discriminator.train()

    g_loss_total = 0.0
    d_loss_total = 0.0
    num_batches = 0

    for real_imgs, input_ids, attention_mask, _ in train_loader_text:
        batch_size_current = real_imgs.size(0)

        real_imgs = real_imgs.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # Targets for real / fake predictions.
        valid = torch.ones(batch_size_current, 1, device=device)
        fake = torch.zeros(batch_size_current, 1, device=device)

        # Encode captions and generate fakes once per batch; both updates below reuse
        # them, which avoids a second BERT forward pass per batch.
        text_embedding = text_encoder(input_ids, attention_mask)
        z = torch.randn(batch_size_current, latent_dim, device=device)
        gen_imgs = generator(z, text_embedding)

        # ---------------------
        #  Discriminator update
        # ---------------------
        optimizer_D.zero_grad()
        # Detach the conditioning and the fakes so the discriminator loss does not
        # backpropagate into the text encoder or the generator.
        condition = text_embedding.detach()
        d_real_loss = adversarial_loss(discriminator(real_imgs, condition), valid)
        d_fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), condition), fake)
        d_loss = (d_real_loss + d_fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # ---------------------------------
        #  Generator + text encoder update
        # ---------------------------------
        optimizer_G.zero_grad()
        # Push the (just updated) discriminator to classify the fakes as real.
        g_loss = adversarial_loss(discriminator(gen_imgs, text_embedding), valid)
        g_loss.backward()
        optimizer_G.step()

        g_loss_total += g_loss.item()
        d_loss_total += d_loss.item()
        num_batches += 1

    # Record epoch means rather than only the last (noisy) batch's loss.
    denominator = max(num_batches, 1)
    G_losses.append(g_loss_total / denominator)
    D_losses.append(d_loss_total / denominator)

    print(f"Epoch [{epoch}/{epochs}]  Loss_D: {D_losses[-1]:.4f}  Loss_G: {G_losses[-1]:.4f}")

    if epoch % sample_interval == 0 or epoch == 1:
        _show_generated_samples(epoch)


KeyboardInterrupt: 

In [None]:
# Plot the recorded generator and discriminator loss histories on shared axes
# (one point per recorded entry).
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
for history, curve_label in ((G_losses, "Generator Loss"), (D_losses, "Discriminator Loss")):
    plt.plot(history, label=curve_label)
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

# Evaluation: t-SNE of real vs. generated images

In [None]:
def _collect_real_samples(dataloader, num_samples):
    """Return up to ``num_samples`` flattened images from ``dataloader`` as a 2-D numpy array."""
    chunks = []
    collected = 0
    for batch in dataloader:
        imgs = batch[0]
        chunks.append(imgs.reshape(imgs.size(0), -1).cpu())
        collected += imgs.size(0)  # count actual rows; the last batch may be smaller
        if collected >= num_samples:
            break
    if not chunks:
        raise ValueError("dataloader yielded no samples")
    return torch.cat(chunks)[:num_samples].numpy()


def _generate_fake_samples(generator, text_encoder, tokenizer, label_map, max_length,
                           latent_dim, device, num_samples):
    """Generate ``num_samples`` flattened images, each conditioned on a random caption of ``label_map``."""
    captions = [label_map[key] for key in sorted(label_map)]
    # Tokenize real captions: raw class indices are not valid BERT token ids and
    # would also have the wrong (1-D) shape.
    encoding = tokenizer(
        captions,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        # Only len(label_map) distinct captions exist: embed each once, then index per sample.
        class_embeddings = text_encoder(
            encoding['input_ids'].to(device),
            encoding['attention_mask'].to(device),
        )
        class_idx = torch.randint(0, len(captions), (num_samples,), device=device)
        z = torch.randn(num_samples, latent_dim, device=device)
        fake_imgs = generator(z, class_embeddings[class_idx])
    # Flatten [N, C, H, W] -> [N, C*H*W]; t-SNE needs a 2-D feature matrix.
    return fake_imgs.reshape(num_samples, -1).cpu().numpy()


def visualize_tsne(generator, text_encoder, dataloader, latent_dim, device, num_samples=1000,
                   *, tokenizer=None, label_map=None, max_length=None):
    """Plot a joint 2-D t-SNE embedding of real and generated images.

    Real images come from ``dataloader`` (batches whose first element is the image
    tensor, as yielded by TextImageDataset). The same number of fake images is
    generated, each conditioned on a caption drawn uniformly at random from
    ``label_map``. Real and fake samples are embedded by a single t-SNE fit so their
    coordinates are directly comparable.

    Args:
        generator: conditional generator, called as ``generator(z, text_embedding)``.
        text_encoder: called as ``text_encoder(input_ids, attention_mask)``.
        dataloader: loader over a TextImageDataset.
        latent_dim: length of the noise vector z.
        device: torch device used for generation.
        num_samples: maximum number of real samples (and matching fakes) to plot.
        tokenizer, label_map, max_length: caption settings; each defaults to the
            attribute of the same name on ``dataloader.dataset``.

    Raises:
        ValueError: if ``num_samples < 1`` or the dataloader yields nothing.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    dataset = dataloader.dataset
    if tokenizer is None:
        tokenizer = dataset.tokenizer
    if label_map is None:
        label_map = dataset.label_map
    if max_length is None:
        max_length = dataset.max_length

    previous_modes = (generator.training, text_encoder.training)
    generator.eval()
    text_encoder.eval()
    try:
        real_samples = _collect_real_samples(dataloader, num_samples)
        # Generate exactly as many fakes as real samples were collected.
        fake_samples = _generate_fake_samples(
            generator, text_encoder, tokenizer, label_map, max_length,
            latent_dim, device, len(real_samples),
        )
    finally:
        # Restore the callers' train/eval modes instead of leaving the models in eval.
        generator.train(previous_modes[0])
        text_encoder.train(previous_modes[1])

    # Fit ONE t-SNE on the union: separate fits produce unrelated coordinate systems
    # that cannot meaningfully be overlaid in the same scatter plot.
    combined = np.concatenate([real_samples, fake_samples], axis=0)
    # sklearn requires perplexity < n_samples; keep its default of 30 when possible.
    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30.0, len(combined) - 1))
    coords = tsne.fit_transform(combined)

    n_real = len(real_samples)
    df = pd.DataFrame({
        'x': coords[:, 0],
        'y': coords[:, 1],
        'type': ['Real'] * n_real + ['Fake'] * (len(combined) - n_real),
    })

    plt.figure(figsize=(8, 6))
    sns.scatterplot(data=df, x='x', y='y', hue='type', alpha=0.5)
    plt.title('t-SNE Visualization of Real vs. Fake Data')
    plt.show()

In [None]:
# Visualize real vs. generated training images with t-SNE.
visualize_tsne(
    generator,
    text_encoder,
    dataloader=train_loader_text,
    latent_dim=latent_dim,
    device=device,
)