In [161]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

In [162]:
torch.__version__

'2.0.0'

In [163]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [164]:
# mnist image
image_shape = (1,28,28)
BATCH_SIZE = 128

# 데이터 셋 구성
transform = transforms.Compose([transforms.Resize(image_shape[-2:]),
                                transforms.ToTensor(),
                                transforms.Normalize([0.5],[0.5])]
                                )
dataset = datasets.MNIST(root='./data',train=True,download=True,transform=transform)
dataloader = DataLoader(dataset,batch_size=BATCH_SIZE, shuffle=True)


In [165]:
# Model 생성
# Generator
class Generator(nn.Module):
    def __init__(self,latent_dim, num_classes, image_shape):
        super().__init__()
        
        self.latent_dim = latent_dim
        self.label_emb = nn.Embedding(num_classes,latent_dim)

        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, int(image_shape[0]*image_shape[1]*image_shape[2])),
            nn.Tanh() 
        )
    def forward(self, noise, labels):
        x = torch.mul(self.label_emb(labels),noise)
        x = self.model(x)
        return x.view(x.size(0), *image_shape)
        # return x.reshape(x.size(0), *image_shape)

In [166]:
# Discriminator
class Discriminator(nn.Module):
    def __init__(self,num_classes,image_shape):
        super().__init__()

        self.label_emb = nn.Embedding(num_classes,int(image_shape[0]*image_shape[1]*image_shape[2]))
        
        self.model = nn.Sequential(
            # *2: 원래 들어간 정보와 생성한 이미지 두개이기 때문에 2를 곱해준다.
            nn.Linear(int(image_shape[0]*image_shape[1]*image_shape[2])*2,128),
            nn.ReLU(),
            nn.Linear(128,1),
            nn.Sigmoid()
        )
    def forward(self,images,labels):
        # view는 reshape와 비슷한 기능을 한다.
        # -1: 가로로 합치자
        x = torch.cat((images.view(images.size(0),-1),self.label_emb(labels)),-1)
        x = self.model(x)
        return x

In [167]:
# hyperparameter
latent_dim = 100    # 고객의 요구사항
num_classes = 10    # 클래스 개수 정하는법?
num_epochs = 100     # 학습 횟수
learning_rate = 2e-3
device = 'mps'

In [168]:
# model
generator = Generator(latent_dim,num_classes,image_shape)
discriminator = Discriminator(num_classes,image_shape)

In [169]:
# lossfunction, optimizer
from torch import optim
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(),lr=learning_rate)
optimizer_d = optim.Adam(discriminator.parameters(),lr=learning_rate)

In [170]:
# 학습
for epoch in range(num_epochs):
    for i,(imgs,labels) in enumerate(dataloader):
        real_imgs = imgs.to(device)
        real_labels = labels.to(device)

        # real data loss
        real_pred = discriminator(real_imgs,real_labels)
        # 진짜 이미지만 학습시켰기 때문에 뒤에는 y_labels가 아닌 1을 넣어준다.(진짜는 1 가짜는 0으로 판별하기로 함.)
        real_loss = criterion(real_pred,torch.ones_like(real_pred)) # 1(진짜)을 넣는 요령 -> real_pred모양으로

        # fake data gen
        noise = torch.randn(BATCH_SIZE,latent_dim).to(device)
        gen_labels=torch.randint(0,num_classes,(BATCH_SIZE,)).to(device)
        gen_imgs = generator(noise,gen_labels)

        # fake data loss
        fake_pred = discriminator(gen_imgs,gen_labels)
        fake_loss = criterion(fake_pred,torch.zeros_like(fake_pred))

        # update
        d_loss = real_loss + fake_loss
        optimizer_d.zero_grad()
        d_loss.backward()
        optimizer_d.step()

        # generator update
        fake_pred = discriminator(gen_imgs,gen_labels)
        fake_loss = criterion(fake_pred,torch.ones_like(fake_pred))
        optimizer_g.zero_grad()
        fake_loss.backward()
        optimizer_g.step()
    clear_output()
    # squeeze ?
    plt.imshow(gen_imgs.detach().numpy()[0].squeeze(0))
    plt.title(str(epoch)+'회')
    plt.show()

RuntimeError: Placeholder storage has not been allocated on MPS device!