In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils as utils
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
is_cuda = torch.cuda.is_available()  # True / False
device = torch.device('cuda') if is_cuda else torch.device('cpu')

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import Transformer

class TransformerGenerator(nn.Module):
    """Map latent noise vectors to single-channel 28x28 images via nn.Transformer.

    Each noise vector is linearly embedded to `d_model` features and treated as a
    length-1 sequence that is fed to the transformer as both source and target.

    Args:
        d_noise: length of each input noise vector.
        d_model: transformer width (must be divisible by `nhead`).
        nhead: number of attention heads.
        num_encoder_layers: depth of the transformer encoder stack.
        num_decoder_layers: depth of the transformer decoder stack.
        dim_feedforward: hidden size of each position-wise feed-forward block.
        dropout: dropout probability inside the transformer.
    """

    def __init__(self, d_noise, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Linear(d_noise, d_model)
        self.transformer = Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout)
        self.output_layer = nn.Linear(d_model, 28 * 28)
        self.tanh = nn.Tanh()

    def forward(self, z):
        """Generate a batch of images from noise.

        Args:
            z: noise tensor of shape (batch, d_noise).

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in [-1, 1] (tanh output).
        """
        z = self.embedding(z)  # (batch, d_model)
        # nn.Transformer defaults to batch_first=False, i.e. (seq, batch, feature), so
        # dim 0 is the *sequence* axis: every sample becomes a length-1 sequence.
        z = z.unsqueeze(0)  # (1, batch, d_model)
        transformer_out = self.transformer(z, z)
        transformer_out = transformer_out.squeeze(0)  # drop the length-1 sequence axis
        img = self.tanh(self.output_layer(transformer_out))  # (batch, 784) in [-1, 1]
        return img.view(-1, 1, 28, 28)

class TransformerDiscriminator(nn.Module):
    """Score 28x28 images with a transformer encoder; output near 1 means "real".

    Args:
        d_model: transformer width (must be divisible by `nhead`).
        nhead: number of attention heads.
        num_encoder_layers: depth of the encoder stack.
        dim_feedforward: hidden size of each position-wise feed-forward block.
        dropout: dropout probability inside the encoder.
    """

    def __init__(self, d_model, nhead, num_encoder_layers, dim_feedforward, dropout):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Linear(28 * 28, d_model)
        # Encoder-only stack. nn.Transformer with zero decoder layers never feeds the
        # encoder output into its result (the empty decoder just normalises `tgt`), so
        # the encoder would be dead weight; use nn.TransformerEncoder directly instead.
        # The final LayerNorm mirrors what nn.Transformer puts after its encoder.
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers, norm=nn.LayerNorm(d_model))
        self.output_layer = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, img):
        """Return the probability that each image is real.

        Args:
            img: image batch whose samples each hold 28*28 values,
                e.g. shape (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 1) with values in (0, 1).
        """
        # reshape (not view) so non-contiguous inputs are accepted too.
        img_flat = img.reshape(img.size(0), -1)  # (batch, 784)
        img_emb = self.embedding(img_flat)  # (batch, d_model)
        # Default batch_first=False -> (seq, batch, feature): each image is a length-1 sequence.
        img_emb = img_emb.unsqueeze(0)  # (1, batch, d_model)
        encoded = self.encoder(img_emb).squeeze(0)  # (batch, d_model)
        return self.sigmoid(self.output_layer(encoded))


def sample_z(batch_size=1, d_noise=100):
    """Draw standard-normal latent vectors for the generator.

    Args:
        batch_size: number of vectors to draw.
        d_noise: length of each vector.

    Returns:
        Tensor of shape (batch_size, d_noise) allocated on the module-level `device`.
    """
    latent_shape = (batch_size, d_noise)
    return torch.randn(*latent_shape, device=device)

In [None]:
# Hyperparameters
d_noise = 100  # length of each latent noise vector
d_model = 256  # transformer width (the paper uses 512)
nhead = 8  # Q/K/V are split into 256 / 8 = 32-dim slices; 8 attention heads
num_encoder_layers = 6
num_decoder_layers = 6
dim_feedforward = 512
dropout = 0.1

# Initialize models (`device` is selected in the setup cell at the top of the notebook)
generator = TransformerGenerator(d_noise, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout).to(device)
discriminator = TransformerDiscriminator(d_model, nhead, num_encoder_layers, dim_feedforward, dropout).to(device)

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)

# Loss function: the discriminator already ends in a sigmoid, so BCE on probabilities.
adversarial_loss = nn.BCELoss()

# Preview one image from the (still untrained) generator.
z = sample_z()
with torch.no_grad():  # display only: no autograd graph needed
    img_fake = generator(z).view(-1, 28, 28)
# The generator ends in tanh, so pixels lie in [-1, 1]; rescale to [0, 1] for ToPILImage.
transforms.ToPILImage()((img_fake[0].cpu() + 1) / 2)

In [None]:
# Training loop
batch_size = 64
epochs = 10000  # number of optimisation steps: each iteration consumes ONE batch, not a full pass

# Real images. Assumes an MNIST-style target (1x28x28, matching the models) — swap in
# your own dataset if needed. Normalize to [-1, 1] so real pixels share the range of the
# generator's tanh output; drop_last keeps every batch exactly `batch_size` long, which
# the fixed-size `valid` / `fake` label tensors below rely on.
mnist_train = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
)
real_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)


def _cycle_batches(loader):
    """Yield image batches forever, starting a freshly shuffled pass when one ends."""
    while True:
        for imgs, _labels in loader:
            yield imgs


real_batches = _cycle_batches(real_loader)

# Adversarial ground truths are identical every step, so build them once.
valid = torch.ones(batch_size, 1, device=device)
fake = torch.zeros(batch_size, 1, device=device)

for epoch in range(epochs):
    # ---- Train Discriminator ----
    optimizer_D.zero_grad()

    # Sample noise as generator input and generate a batch of fake images
    z = sample_z(batch_size, d_noise)
    gen_imgs = generator(z)

    # Next batch of real images
    real_imgs = next(real_batches).to(device)

    real_loss = adversarial_loss(discriminator(real_imgs), valid)
    # detach(): the discriminator update must not push gradients into the generator.
    fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
    d_loss = (real_loss + fake_loss) / 2

    d_loss.backward()
    optimizer_D.step()

    # ---- Train Generator ----
    optimizer_G.zero_grad()

    # Non-saturating loss: reward the generator when the updated discriminator calls fakes real.
    g_loss = adversarial_loss(discriminator(gen_imgs), valid)

    g_loss.backward()
    optimizer_G.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}/{epochs} - D loss: {d_loss.item()} - G loss: {g_loss.item()}")


In [None]:


# Re-initialize models: swap in a different generator, keep the transformer discriminator.
# NOTE(review): `Generator` and `d_hidden` are not defined anywhere in the visible notebook;
# this line raises NameError unless an earlier cell defines them — TODO confirm.
generator = Generator(d_noise, d_hidden).to(device)
discriminator = TransformerDiscriminator(d_model, nhead, num_encoder_layers, dim_feedforward, dropout).to(device)

# The optimizers created earlier hold the *previous* models' parameters; rebuild them,
# otherwise the training loop would step stale weights and these models would never train.
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)

# Re-run the training-loop cell unchanged to train this pair.
