## 라이브러리 & Cuda 세팅

import torch
import torch.nn as nn
from torchinfo import summary
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.nn.functional as F

import os
import time
import numpy as np
import matplotlib.pyplot as plt
import random

In [None]:
import torchmetrics
from torchmetrics.image.fid import FrechetInceptionDistance

In [None]:
seed = 42  # single experiment seed shared by every RNG below and in later cells

# Seed NumPy's legacy global RNG and Python's `random` module.
np.random.seed(seed)
random.seed(seed)

In [None]:
device = 'cuda:1'
# Seed torch's CPU RNG (torch.manual_seed also seeds the CUDA generators).
torch.manual_seed(seed)
if device.startswith('cuda'):
    # Explicitly seed every GPU whenever a CUDA device is configured, whatever its index.
    torch.cuda.manual_seed_all(seed)
# Prefer reproducible cuDNN convolution algorithms over autotuned (faster) ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

if torch.cuda.is_available():
    # Make the configured GPU the *current* device first. The CUDA default tensor type
    # (and torch.Generator(device='cuda')) allocate on the current device, so without this
    # default-created tensors would land on cuda:0 while the models live on `device`.
    torch.cuda.set_device(device)
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    # Report the GPU that is actually configured.
    print("using cuda:", torch.cuda.get_device_name(device))

## 데이터 전처리

path_y = 'Untitled Folder/AI_Proj/mask_y'
path_n = 'Untitled Folder/AI_Proj/mask_n'
image_size = 128
# Resize, center-crop to image_size x image_size, convert to a tensor, and normalize
# each channel with mean 0.5 / std 0.5 so pixels lie in [-1, 1] (the Tanh output range
# of the generator).
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
])

data_my = datasets.ImageFolder(path_y, transform = transform)
data_mn = datasets.ImageFolder(root = path_n, transform = transform)

# The two folders are paired purely by position (i-th image of one with the i-th of the
# other). NOTE(review): this assumes corresponding images appear in the same order in both
# folders, e.g. identical file names — confirm. Fail early with a clear message on a count mismatch.
if len(data_my) != len(data_mn):
    raise ValueError(
        f'paired image folders differ in size: {path_y!r} has {len(data_my)} images, '
        f'{path_n!r} has {len(data_mn)}'
    )

# Stack the per-image tensors directly (class labels are ignored). This avoids building a
# list of numpy copies and a temporary tensor on the default CUDA device before moving to `device`.
img_y = torch.stack([img for img, _ in data_my]).to(device)
img_n = torch.stack([img for img, _ in data_mn]).to(device)

In [None]:
# Sanity check: both stacks must share the same leading size to be paired in a TensorDataset.
for stacked in (img_y, img_n):
    print(stacked.shape)

In [None]:
# Each sample is the pair (img_y[i], img_n[i]).
dataset = TensorDataset(img_y, img_n)
train_size = int(0.8*len(dataset))  # 80 % train, remainder test
test_size = len(dataset) - train_size
# A CUDA generator is kept (presumably required because the default tensor type is CUDA,
# so random_split's permutation is drawn on the GPU). It is seeded from the experiment
# seed so that the split follows `seed` like every other RNG in the notebook.
split_gen = torch.Generator(device='cuda').manual_seed(seed)
train, test = torch.utils.data.random_split(dataset, [train_size, test_size], generator=split_gen)

In [None]:
batch_size= 64
# Reshuffle the training pairs every epoch; without it the GAN sees the same batch order
# each epoch. The sampler gets a seeded CUDA generator: with a CUDA default tensor type the
# sampler's random draws presumably happen on the GPU, where a CPU generator is rejected.
dataloader = DataLoader(
    train,
    batch_size = batch_size,
    shuffle=True,
    generator=torch.Generator(device='cuda').manual_seed(seed),
)
# Evaluation order stays fixed.
dataloader_test = DataLoader(test, batch_size = batch_size, shuffle=False)

## 모델

In [None]:
class generator(nn.Module):
    """Encoder / bottleneck / decoder image-to-image network.

    Three stride-2 convolutions shrink H and W by 8x while widening the channels to
    1024. A bottleneck of size-preserving 3x3 convolutions refines the features. Three
    stride-2 transposed convolutions and a final stride-1 transposed convolution then
    restore the original spatial size with 3 output channels. The output passes
    through Tanh, so values lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        width = 128  # base channel count

        # Encoder: 3 -> 2w -> 4w -> 8w channels; each 4x4/stride-2/pad-1 conv halves H and W.
        self.conv1 = nn.Conv2d(3, width * 2, 4, 2, 1, bias=False)
        self.act1 = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(width * 2, width * 4, 4, 2, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(width * 4)
        self.act2 = nn.LeakyReLU(0.2)
        self.conv3 = nn.Conv2d(width * 4, width * 8, 4, 2, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(width * 8)
        self.act3 = nn.LeakyReLU(0.2)

        # Bottleneck: 3x3/stride-1/pad-1 convs keep the spatial size.
        # NOTE(review): forward() runs this conv4/conv5 pair twice, so both bottleneck
        # stages share the same weights — confirm that this is intended.
        self.conv4 = nn.Conv2d(width * 8, width * 8, 3, 1, 1, bias=False)
        self.act4 = nn.LeakyReLU(0.2)
        self.conv5 = nn.Conv2d(width * 8, width * 8, 3, 1, 1, bias=False)
        self.leaky = nn.LeakyReLU(0.2)

        # Decoder: 8w -> 4w -> 2w -> w channels, each transposed conv doubles H and W;
        # the final 3x3/stride-1 transposed conv maps to 3 channels at unchanged size.
        self.dconv1 = nn.ConvTranspose2d(width * 8, width * 4, kernel_size=4, stride=2, padding=1, bias=False)
        self.dbn1 = nn.BatchNorm2d(width * 4)
        self.dact1 = nn.ReLU()
        self.dconv2 = nn.ConvTranspose2d(width * 4, width * 2, kernel_size=4, stride=2, padding=1, bias=False)
        self.dbn2 = nn.BatchNorm2d(width * 2)
        self.dact2 = nn.ReLU()
        self.dconv3 = nn.ConvTranspose2d(width * 2, width, kernel_size=4, stride=2, padding=1, bias=False)
        self.dbn3 = nn.BatchNorm2d(width)
        self.dact3 = nn.ReLU()
        self.dconv4 = nn.ConvTranspose2d(width, 3, kernel_size=3, stride=1, padding=1, bias=False)
        self.dact4 = nn.Tanh()

    def forward(self, x):
        # Encoder
        h = self.act1(self.conv1(x))
        h = self.act2(self.bn2(self.conv2(h)))
        h = self.act3(self.bn3(self.conv3(h)))

        # Bottleneck: the same (shared-weight) stage applied twice, with a skip around conv5.
        for _ in range(2):
            h = self.act4(self.conv4(h))
            h = self.leaky(h + self.conv5(h))

        # Decoder
        h = self.dact1(self.dbn1(self.dconv1(h)))
        h = self.dact2(self.dbn2(self.dconv2(h)))
        h = self.dact3(self.dbn3(self.dconv3(h)))
        return self.dact4(self.dconv4(h))

In [None]:
class discriminator(nn.Module):
    """Convolutional real/fake classifier producing one probability per score location.

    Five 4x4/stride-2 convolutions halve H and W each time. A final 4x4 valid
    convolution plus Sigmoid yields values in (0, 1), flattened to shape (-1, 1).
    For 128x128 inputs the map is 4x4 before the last conv, giving one score per
    image. Other input sizes would give several scores per image (assumption about
    callers — they use image_size = 128).
    """

    def __init__(self):
        super().__init__()
        width = 128  # base channel count

        self.conv1 = nn.Conv2d(3, width, 4, 2, 1, bias=False)
        self.act1 = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(width, width * 2, 4, 2, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(width * 2)
        self.act2 = nn.LeakyReLU(0.2)
        self.conv3 = nn.Conv2d(width * 2, width * 4, 4, 2, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(width * 4)
        self.act3 = nn.LeakyReLU(0.2)
        self.conv4 = nn.Conv2d(width * 4, width * 8, 4, 2, 1, bias=False)
        self.bn4 = nn.BatchNorm2d(width * 8)
        self.act4 = nn.LeakyReLU(0.2)
        self.conv5 = nn.Conv2d(width * 8, width * 8, 4, 2, 1, bias=False)
        self.bn5 = nn.BatchNorm2d(width * 8)
        self.act5 = nn.LeakyReLU(0.2)
        # Valid (unpadded) 4x4 conv collapsing the final feature map to one channel.
        self.conv6 = nn.Conv2d(width * 8, 1, 4, 1, 0, bias=False)
        self.act6 = nn.Sigmoid()

    def forward(self, x):
        # First stage has no BatchNorm.
        h = self.act1(self.conv1(x))
        # Remaining downsampling stages: conv -> BatchNorm -> LeakyReLU.
        stages = (
            (self.conv2, self.bn2, self.act2),
            (self.conv3, self.bn3, self.act3),
            (self.conv4, self.bn4, self.act4),
            (self.conv5, self.bn5, self.act5),
        )
        for conv, norm, act in stages:
            h = act(norm(conv(h)))
        score = self.act6(self.conv6(h))
        return score.view(-1, 1)

## 훈련

from torch import optim

# Instantiate both networks on the configured device (generator first, as before,
# so parameter initialization consumes the seeded RNG in the same order).
gen = generator().to(device)
dis = discriminator().to(device)

# Binary cross-entropy on the discriminator's sigmoid outputs.
loss_fun = nn.BCELoss()

# Adam with lr 2e-4 and beta1 = 0.5 for both networks.
optim_g = optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
optim_d = optim.Adam(dis.parameters(), lr=2e-4, betas=(0.5, 0.999))

In [None]:
# Wall-clock reference; the training loop prints minutes elapsed since this point.
start_time = time.time()

In [None]:
iterators = 0
epochs = 20
loss_hist = {'dis':[], 'gen':[]}

for epoch in range(epochs):
    # x: input batch from img_y, y: paired target batch from img_n (both already on `device`).
    for x, y in dataloader:
        # Local name, so the module-level `batch_size` DataLoader setting is not overwritten.
        n_batch = x.shape[0]

        # BCE targets, allocated directly on the batch's device.
        y_real = torch.ones(n_batch, 1, device=x.device)
        y_fake = torch.zeros(n_batch, 1, device=x.device)

        # Generator step: push D(G(x)) toward the "real" label.
        gen.zero_grad()
        out_g = gen(x)
        loss_g = loss_fun(dis(out_g), y_real)
        loss_g.backward()
        optim_g.step()

        # Discriminator step: targets -> 1, generated images -> 0.
        # dis.zero_grad() also discards the gradients the generator step left on D's
        # parameters; detach() stops this loss from back-propagating into the generator.
        dis.zero_grad()
        loss_real = loss_fun(dis(y), y_real)
        loss_fake = loss_fun(dis(out_g.detach()), y_fake)

        loss_d = (loss_real + loss_fake) /2
        loss_d.backward()
        optim_d.step()

        loss_hist['gen'].append(loss_g.item())
        loss_hist['dis'].append(loss_d.item())

        iterators += 1
        print('>', end=' ')  # per-batch progress marker

        if iterators % 100 == 0:
            print('\n Epoch: %.0f, G_Loss: %.6f, D_Loss: %.6f, time: %.2f min' %(epoch, loss_g.item(), loss_d.item(), (time.time()-start_time)/60))

In [None]:
# Per-batch loss curves for both networks (legend order matches plotting order).
plt.figure(figsize=(10, 5))
plt.title('Loss Progress')
for series in ('gen', 'dis'):
    plt.plot(loss_hist[series])
plt.xlabel('Batch count')
plt.ylabel('Loss value')
plt.legend(['generator','discriminator'])

## 테스트

In [None]:
# Run the generator over the WHOLE test split. Keeping only the last batch would leave too
# few images for the 4x4 grids and a biased FID sample.
# eval() makes BatchNorm use its running statistics instead of per-batch statistics;
# training mode is restored afterwards so the training cell can be re-run.
gen.eval()
masks, faces, fakes = [], [], []
try:
    with torch.no_grad():
        for x, y in dataloader_test:
            # Keep channels-last numpy copies of inputs/targets for plotting and FID.
            masks.append(x.permute(0,2,3,1).cpu().numpy())
            faces.append(y.permute(0,2,3,1).cpu().numpy())
            fakes.append(gen(x))
finally:
    gen.train()

mask = np.concatenate(masks)
face = np.concatenate(faces)
img_fake = torch.cat(fakes)

In [None]:
# Channels-last numpy copy of the generated images. Converted once, outside the loop;
# `img` is also reused by the FID cells below.
img = img_fake.detach().permute(0,2,3,1).cpu().numpy()
# Show up to 16 images in a 4x4 grid, mapping Tanh output [-1, 1] back to [0, 1].
for i in range(min(16, len(img))):
    plt.subplot(4,4,i+1)
    plt.imshow(img[i]*0.5 + 0.5)
    plt.axis('off')

In [None]:
# Show up to 16 target images (second tensor of each test pair) in a 4x4 grid,
# mapping the normalized [-1, 1] range back to [0, 1]; fewer are shown if the set is smaller.
for i in range(min(16, len(face))):
    plt.subplot(4,4,i+1)
    plt.imshow(face[i]*0.5 + 0.5)
    plt.axis('off')

## 성능 평가

In [None]:
# Frechet Inception Distance metric; feature=64 selects which Inception feature size the
# real/fake statistics are computed on (see torchmetrics documentation).
fid = FrechetInceptionDistance(feature=64)

In [None]:
# FID (default normalize=False) expects uint8 images with values in [0, 255].
# `img` holds generator (Tanh) outputs in [-1, 1]; casting those floats straight to uint8
# truncates almost every pixel to 0, so rescale to [0, 255] before the cast.
img = torch.Tensor(img)
img = ((img * 0.5 + 0.5) * 255).round().clamp(0, 255).type(torch.uint8)

In [None]:
# Move the channel axis from last to second: (N, H, W, C) -> (N, C, H, W), as FID expects.
img = img.movedim(3, 1)
img.shape

In [None]:
# Notebook sanity check: display the tensor type string (expected to be a uint8 tensor type).
img.type()

In [None]:
# FID (default normalize=False) expects uint8 images with values in [0, 255].
# `face` holds target images normalized to [-1, 1] (Normalize with mean/std 0.5); casting
# those floats straight to uint8 truncates almost every pixel to 0, so rescale first.
face = torch.Tensor(face)
face = ((face * 0.5 + 0.5) * 255).round().clamp(0, 255).type(torch.uint8)

In [None]:
# Move the channel axis from last to second: (N, H, W, C) -> (N, C, H, W), as FID expects.
face = face.movedim(3, 1)
face.shape

In [None]:
# Notebook sanity check: display the tensor type string (expected to be a uint8 tensor type).
face.type()

In [None]:
# Accumulate the real (target) and generated distributions, then report the FID score.
for batch, is_real in ((face, True), (img, False)):
    fid.update(batch, real=is_real)
fid.compute()