<a href="https://colab.research.google.com/github/Jih00nJung/GAN_AI/blob/main/GAN_FMNIST_sepa1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

1. config.py (설정 및 환경 준비)

In [4]:
%%writefile config.py
import os
import argparse
from torchvision import transforms
from google.colab import drive # Colab 환경 전용 import

def get_config():
    """Collect every training hyperparameter into an ``argparse`` namespace.

    Options are declared in a single table and registered in table order.
    Command-line tokens that match no declared option are ignored
    (``parse_known_args``), so extra arguments do not abort parsing.

    Returns:
        argparse.Namespace: one attribute per option listed below.
    """
    # Each row is (flag, type, default, help); help is None when the option
    # carries no help text. Trailing "alt" notes reproduce the alternative
    # values marked beside the current defaults (presumably full-run
    # settings -- confirm before relying on them).
    option_table = [
        # Data and output paths
        ('--project_name', str, 'GAN_FMNIST_sepa', None),
        ('--save_root', str, '/content/drive/MyDrive/Colab Notebooks/GAN_assignment', None),
        ('--img_size', int, 64, '이미지 크기 (FMNIST 기본 28 -> 64 리사이즈)'),
        ('--batch_size', int, 2, None),  # alt: 16

        # Model hyperparameters
        ('--style_dim', int, 64, '스타일 코드 차원'),
        ('--latent_dim', int, 16, '랜덤 노이즈 차원'),
        ('--num_domains', int, 10, 'Fashion MNIST 클래스 개수'),
        ('--hidden_dim', int, 256, 'Mapping Network 히든 차원'),

        # Training schedule
        ('--total_iters', int, 1000, None),  # alt: 100000
        ('--resume_iter', int, 0, None),
        ('--lr', float, 1e-4, None),
        ('--w_hpf', float, 1, 'High-pass filtering 가중치 (논문 구현용)'),

        # Loss weights
        ('--lambda_sty', float, 1.0, None),
        ('--lambda_ds', float, 1.0, None),
        ('--lambda_cyc', float, 1.0, None),

        # Logging / checkpoint cadence (in iterations)
        ('--sample_freq', int, 100, None),  # alt: 1000
        ('--save_freq', int, 500, None),  # alt: 5000
    ]

    parser = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    known_args, _ignored = parser.parse_known_args()
    return known_args

def prepare_environment(args):
    """Mount Google Drive when needed and create the output directories.

    Args:
        args: configuration object providing ``save_root`` and
            ``project_name``.

    Returns:
        str: ``<save_root>/<project_name>``; its ``checkpoints`` and
        ``samples`` sub-directories exist when this function returns.
    """
    print("--- 환경 설정 중 ---")
    save_path = os.path.join(args.save_root, args.project_name)

    # Only mount when the mount point is absent (i.e. not mounted yet).
    drive_mount_point = '/content/drive'
    already_mounted = os.path.exists(drive_mount_point)
    if not already_mounted:
        print("Google Drive를 마운트합니다...")
        drive.mount(drive_mount_point)
        print("Google Drive mounted.")

    for subdir in ('checkpoints', 'samples'):
        os.makedirs(os.path.join(save_path, subdir), exist_ok=True)

    print(f"저장 경로: {save_path}")
    return save_path

def get_data_transform(img_size):
    """Build the preprocessing pipeline applied to each Fashion-MNIST image.

    Args:
        img_size: side length of the square the image is resized to.

    Returns:
        A ``transforms.Compose`` that resizes, converts to a tensor and
        normalizes with mean 0.5 / std 0.5.
    """
    target_size = (img_size, img_size)
    steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps the [0, 1] range onto [-1, 1]
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
    return transforms.Compose(steps)

Writing config.py


2. model.py (네트워크 아키텍처)

In [5]:
%%writefile model.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaIN(nn.Module):
    """Adaptive instance normalization conditioned on a style vector.

    A linear layer maps the style code to ``2 * num_features`` values that
    are split into a per-channel scale offset and a per-channel shift, which
    are applied to the instance-normalized input.
    """

    def __init__(self, style_dim, num_features):
        """
        Args:
            style_dim: size of the style code fed to ``forward``.
            num_features: number of channels of the feature map to modulate.
        """
        super().__init__()
        self.norm = nn.InstanceNorm2d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        """Return ``(1 + gamma) * norm(x) + beta`` with gamma/beta from ``s``."""
        affine = self.fc(s)
        # Append two singleton axes so the parameters broadcast over H and W,
        # then split the channel axis into (scale, shift) halves.
        scale, shift = affine[:, :, None, None].chunk(2, dim=1)
        normalized = self.norm(x)
        return normalized * (1 + scale) + shift

class ResBlock(nn.Module):
    """Plain residual block (used by the generator's down-sampling path).

    The residual branch is activation -> 3x3 conv -> instance norm, applied
    twice. The skip branch is the identity, or a 1x1 convolution when the
    channel count changes.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2)):
        """
        Args:
            dim_in: input channel count.
            dim_out: output channel count.
            actv: activation module; the same instance is used at both
                positions of the residual branch.
        """
        super().__init__()
        residual_layers = [
            actv,
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(dim_out, affine=True),
            actv,
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(dim_out, affine=True),
        ]
        self.main = nn.Sequential(*residual_layers)

        if dim_in == dim_out:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()
        else:
            # 1x1 conv so the skip branch matches the new channel count.
            self.shortcut = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return the sum of the residual branch and the skip branch."""
        residual = self.main(x)
        return residual + self.shortcut(x)

class AdaINResBlock(nn.Module):
    """Style-modulated residual block for the generator.

    Used in the bottleneck and up-sampling stages. The residual branch is
    AdaIN -> activation -> 3x3 conv, applied twice. The skip branch is the
    identity, or a 1x1 convolution when the channel count changes.
    """

    def __init__(self, dim_in, dim_out, style_dim, actv=nn.LeakyReLU(0.2)):
        """
        Args:
            dim_in: input channel count.
            dim_out: output channel count.
            style_dim: size of the style code consumed by both AdaIN layers.
            actv: activation module applied after each AdaIN.
        """
        super().__init__()
        self.actv = actv
        self.conv1 = nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1)
        # norm1 acts on the block input, norm2 on the output of conv1.
        self.norm1 = AdaIN(style_dim, dim_in)
        self.norm2 = AdaIN(style_dim, dim_out)

        if dim_in == dim_out:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=1, padding=0)

    def forward(self, x, s):
        """Apply the block to feature map ``x`` under style code ``s``."""
        hidden = self.conv1(self.actv(self.norm1(x, s)))
        hidden = self.conv2(self.actv(self.norm2(hidden, s)))
        return hidden + self.shortcut(x)

# --- (1) Generator (G) ---
class Generator(nn.Module):
    """Encoder/decoder generator that re-styles a 1-channel image.

    The encoder halves the spatial size three times with plain ``ResBlock``s.
    The decoder runs two style-modulated bottleneck blocks, then up-samples
    three times with ``AdaINResBlock``s, mirroring the encoder widths.

    Args:
        img_size: nominal input side length. It is stored on the module but
            does not affect the layer layout; the three 2x poolings need the
            input side to be divisible by 8 to restore the original size.
        style_dim: size of the style code passed to ``forward``.
        max_conv_dim: upper bound on the channel count of any stage.
    """

    def __init__(self, img_size=64, style_dim=64, max_conv_dim=256):
        super().__init__()
        dim_in = 32  # small starting width to keep the model lightweight
        self.img_size = img_size

        self.from_rgb = nn.Conv2d(1, dim_in, 3, 1, 1)  # grayscale (1-channel) input

        # Down-sampling blocks (for a 64x64 input: 64 -> 32 -> 16 -> 8).
        # The width is clamped to max_conv_dim here, while the encoder is
        # built, so the decoder always receives the channel count it was
        # constructed for. Clamping only after the loop breaks every
        # max_conv_dim below the un-clamped encoder width.
        self.encode = nn.ModuleList()
        stage_dims = []  # (input width, output width) of each encoder stage
        curr_dim = dim_in
        for _ in range(3):
            dim_out = min(curr_dim * 2, max_conv_dim)
            self.encode.append(ResBlock(curr_dim, dim_out))
            self.encode.append(nn.AvgPool2d(2))
            stage_dims.append((curr_dim, dim_out))
            curr_dim = dim_out

        # Bottleneck: spatial size unchanged, style injected through AdaIN.
        self.decode = nn.ModuleList()
        for _ in range(2):
            self.decode.append(AdaINResBlock(curr_dim, curr_dim, style_dim))

        # Up-sampling blocks (for a 64x64 input: 8 -> 16 -> 32 -> 64),
        # undoing the encoder widths in reverse order.
        for dim_low, dim_high in reversed(stage_dims):
            self.decode.append(nn.Upsample(scale_factor=2, mode='nearest'))
            self.decode.append(AdaINResBlock(dim_high, dim_low, style_dim))
            curr_dim = dim_low

        # Output head
        self.to_rgb = nn.Sequential(
            nn.InstanceNorm2d(curr_dim, affine=True),
            nn.LeakyReLU(0.2),
            nn.Conv2d(curr_dim, 1, 1, 1, 0)  # grayscale (1-channel) output
        )

    def forward(self, x, s):
        """Translate image batch ``x`` using style codes ``s``.

        Args:
            x: input images with one channel.
            s: style codes, one row per image.

        Returns:
            Generated images with one channel.
        """
        x = self.from_rgb(x)
        for block in self.encode:
            x = block(x)

        for block in self.decode:
            # Only the AdaIN blocks consume the style code; Upsample does not.
            if isinstance(block, AdaINResBlock):
                x = block(x, s)
            else:
                x = block(x)

        return self.to_rgb(x)

# --- (2) Mapping Network (F) ---
class MappingNetwork(nn.Module):
    """Map a latent code and a domain label to a style code.

    A shared MLP trunk feeds one output head per domain. ``forward``
    evaluates every head and keeps, for each sample, the output of the head
    that matches its label.
    """

    def __init__(self, latent_dim=16, style_dim=64, num_domains=10, hidden_dim=256):
        """
        Args:
            latent_dim: size of the input noise vector.
            style_dim: size of the produced style code.
            num_domains: number of domain-specific heads.
            hidden_dim: width of every hidden layer.
        """
        super().__init__()

        # Shared trunk: three Linear+ReLU pairs; only the first takes latent_dim.
        trunk = []
        in_features = latent_dim
        for _ in range(3):
            trunk.append(nn.Linear(in_features, hidden_dim))
            trunk.append(nn.ReLU())
            in_features = hidden_dim
        self.shared = nn.Sequential(*trunk)

        # One unshared head per domain.
        self.unshared = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, style_dim),
            )
            for _ in range(num_domains)
        ])

    def forward(self, z, y):
        """Return the style code of domain ``y[i]`` for each latent ``z[i]``."""
        features = self.shared(z)
        # (batch, num_domains, style_dim)
        per_domain = torch.stack([head(features) for head in self.unshared], dim=1)

        # Pair each row index with its label to pick one head output per sample.
        rows = torch.arange(y.size(0), device=y.device)
        return per_domain[rows, y]

# --- (3) Style Encoder (E) ---
class StyleEncoder(nn.Module):
    """Extract a domain-specific style code from a 1-channel image.

    A shared convolutional trunk reduces an ``img_size`` x ``img_size`` image
    to a 1x1 feature vector. One linear head per domain turns that vector
    into a style code, and ``forward`` keeps the head matching each label.

    Args:
        img_size: side length of the (square) input images. Must be a power
            of two and at least 4, so the stride-2 stages end exactly on the
            4x4 map consumed by the final 4x4 convolution.
        style_dim: size of the produced style code.
        num_domains: number of domain-specific heads.
        max_conv_dim: upper bound on the trunk's channel count. With the
            default and ``img_size=64`` the layout is 32-64-128-256-512.

    Raises:
        ValueError: if ``img_size`` is not a power of two >= 4.
    """

    def __init__(self, img_size=64, style_dim=64, num_domains=10, max_conv_dim=512):
        super().__init__()
        size = int(img_size)
        if size < 4 or size & (size - 1):
            raise ValueError(
                f"img_size must be a power of two >= 4, got {img_size!r}")
        # log2(size) - 2 halvings bring size down to 4 (e.g. 64 -> 4 needs 4).
        num_down = size.bit_length() - 3

        dim_in = 32
        blocks = []
        blocks += [nn.Conv2d(1, dim_in, 3, 1, 1)]  # 1-channel input

        curr_dim = dim_in
        # Stride-2 convolutions: img_size -> ... -> 4
        for _ in range(num_down):
            dim_out = min(curr_dim * 2, max_conv_dim)
            blocks += [nn.LeakyReLU(0.2)]
            blocks += [nn.Conv2d(curr_dim, dim_out, 3, 2, 1)]
            curr_dim = dim_out

        blocks += [nn.LeakyReLU(0.2)]
        blocks += [nn.Conv2d(curr_dim, curr_dim, 4, 1, 0)]  # 4x4 -> 1x1
        self.shared = nn.Sequential(*blocks)

        # One style-code head per domain
        self.unshared = nn.ModuleList()
        for _ in range(num_domains):
            self.unshared.append(nn.Linear(curr_dim, style_dim))

    def forward(self, x, y):
        """Return the style code of image ``x[i]`` for domain ``y[i]``.

        Args:
            x: image batch with one channel and side ``img_size``.
            y: integer domain labels, one per image.

        Returns:
            Tensor of shape (batch, style_dim).
        """
        h = self.shared(x)  # (batch, curr_dim, 1, 1)
        h = h.view(h.size(0), -1)

        out = [layer(h) for layer in self.unshared]
        out = torch.stack(out, dim=1)  # (batch, num_domains, style_dim)

        # Pair each row index with its label to pick one head per sample.
        idx = torch.arange(y.size(0), device=y.device)
        return out[idx, y]

# --- (4) Discriminator (D) ---
class Discriminator(nn.Module):
    """Multi-domain real/fake critic for 1-channel images.

    A shared convolutional trunk reduces an ``img_size`` x ``img_size`` image
    to a 1x1 feature vector. One linear head per domain produces a scalar
    score, and ``forward`` keeps the head matching each label.

    Args:
        img_size: side length of the (square) input images. Must be a power
            of two and at least 4, so the stride-2 stages end exactly on the
            4x4 map consumed by the final 4x4 convolution.
        num_domains: number of domain-specific heads.
        max_conv_dim: upper bound on the trunk's channel count. With the
            default and ``img_size=64`` the layout is 32-64-128-256-512.

    Raises:
        ValueError: if ``img_size`` is not a power of two >= 4.
    """

    def __init__(self, img_size=64, num_domains=10, max_conv_dim=512):
        super().__init__()
        size = int(img_size)
        if size < 4 or size & (size - 1):
            raise ValueError(
                f"img_size must be a power of two >= 4, got {img_size!r}")
        # log2(size) - 2 halvings bring size down to 4 (e.g. 64 -> 4 needs 4).
        num_down = size.bit_length() - 3

        dim_in = 32
        blocks = []
        blocks += [nn.Conv2d(1, dim_in, 3, 1, 1)]  # 1-channel input

        curr_dim = dim_in
        # Stride-2 convolutions: img_size -> ... -> 4
        for _ in range(num_down):
            dim_out = min(curr_dim * 2, max_conv_dim)
            blocks += [nn.LeakyReLU(0.2)]
            blocks += [nn.Conv2d(curr_dim, dim_out, 3, 2, 1)]
            curr_dim = dim_out

        blocks += [nn.LeakyReLU(0.2)]
        blocks += [nn.Conv2d(curr_dim, curr_dim, 4, 1, 0)]  # 4x4 -> 1x1
        self.shared = nn.Sequential(*blocks)

        # One real/fake scoring head per domain
        self.unshared = nn.ModuleList()
        for _ in range(num_domains):
            self.unshared.append(nn.Linear(curr_dim, 1))

    def forward(self, x, y):
        """Score image ``x[i]`` with the head of domain ``y[i]``.

        Args:
            x: image batch with one channel and side ``img_size``.
            y: integer domain labels, one per image.

        Returns:
            Tensor of shape (batch, 1) holding the selected scores.
        """
        h = self.shared(x)
        h = h.view(h.size(0), -1)

        out = [layer(h) for layer in self.unshared]
        out = torch.stack(out, dim=1)  # (batch, num_domains, 1)

        # Pair each row index with its label to pick one head per sample.
        idx = torch.arange(y.size(0), device=y.device)
        score = out[idx, y]
        return score

Writing model.py


3. solver.py (메인 실행 및 학습 루프)

In [6]:
import os
import time
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image

# 분리된 파일에서 모듈 가져오기
from config import get_config, prepare_environment, get_data_transform
from model import Generator, MappingNetwork, StyleEncoder, Discriminator


class Solver:
    """Training driver: owns the data loader, the four networks and their optimizers.

    Networks are stored in ``self.nets`` under the keys ``'G'`` (generator),
    ``'F'`` (mapping network), ``'E'`` (style encoder) and ``'D'``
    (discriminator); ``self.optims`` holds one Adam optimizer per network
    under the same keys.
    """

    def __init__(self, args, device):
        """Prepare directories, data, models and optimizers; optionally resume.

        Args:
            args: configuration namespace (see ``get_config``).
            device: ``torch.device`` on which the networks and batches live.
        """
        self.args = args
        self.device = device
        self.save_dir = prepare_environment(args)

        # Dataset (FashionMNIST), resized/normalized by the shared transform
        transform = get_data_transform(args.img_size)

        # Training split; downloaded into ./data relative to the working directory
        dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)

        # drop_last keeps every batch at exactly batch_size samples
        self.loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
        # alt noted beside this setting: num_workers=2

        # Model construction
        self.nets = {
            'G': Generator(args.img_size, args.style_dim),
            'F': MappingNetwork(args.latent_dim, args.style_dim, args.num_domains, args.hidden_dim),
            'E': StyleEncoder(args.img_size, args.style_dim, args.num_domains),
            'D': Discriminator(args.img_size, args.num_domains)
        }

        for module in self.nets.values():
            module.to(self.device)
            # PyTorch's default parameter initialization is kept as is;
            # torch.nn.init.kaiming_normal_ could be applied here instead.
            module.train()

        # Optimizers
        self.optims = {
            'G': torch.optim.Adam(self.nets['G'].parameters(), lr=args.lr, betas=(0.0, 0.99)),
            'F': torch.optim.Adam(self.nets['F'].parameters(), lr=args.lr*0.01, betas=(0.0, 0.99)),  # mapping network uses a 100x lower LR
            'E': torch.optim.Adam(self.nets['E'].parameters(), lr=args.lr, betas=(0.0, 0.99)),
            'D': torch.optim.Adam(self.nets['D'].parameters(), lr=args.lr, betas=(0.0, 0.99))
        }

        # Resume from a checkpoint when requested. The iteration counter is
        # advanced only if the checkpoint was actually loaded; otherwise
        # training would skip `resume_iter` iterations on untrained weights.
        self.start_iter = 0
        if args.resume_iter > 0 and self.load_checkpoint(args.resume_iter):
            self.start_iter = args.resume_iter

    def save_checkpoint(self, step):
        """Write network and optimizer states to ``checkpoints/<step:06d>.ckpt``.

        Args:
            step: iteration number, used for the file name and stored in the file.
        """
        path = os.path.join(self.save_dir, 'checkpoints', f'{step:06d}.ckpt')
        state = {
            'nets': {name: net.state_dict() for name, net in self.nets.items()},
            'optims': {name: opt.state_dict() for name, opt in self.optims.items()},
            'step': step
        }
        torch.save(state, path)
        print(f"Saved checkpoint to {path}")

    def load_checkpoint(self, step):
        """Restore network and optimizer states saved at iteration ``step``.

        Args:
            step: iteration number identifying ``checkpoints/<step:06d>.ckpt``.

        Returns:
            bool: True when the checkpoint was found and loaded, False when
            the file does not exist (a message is printed and nothing changes).
        """
        path = os.path.join(self.save_dir, 'checkpoints', f'{step:06d}.ckpt')
        if not os.path.exists(path):
            print("Checkpoint not found!")
            return False

        ckpt = torch.load(path, map_location=self.device)
        for name, net in self.nets.items():
            net.load_state_dict(ckpt['nets'][name])
        for name, opt in self.optims.items():
            opt.load_state_dict(ckpt['optims'][name])
        print(f"Loaded checkpoint from {path}")
        return True

    def train(self):
        """Run the adversarial training loop from ``start_iter`` to ``total_iters``.

        Each iteration performs one discriminator update followed by one
        joint update of the generator, mapping network and style encoder,
        then logs, samples and checkpoints at their configured intervals.
        """
        print("--- 학습 시작 ---")
        nets = self.nets
        optims = self.optims
        args = self.args

        data_iter = iter(self.loader)

        start_time = time.time()
        for i in range(self.start_iter, args.total_iters):

            # 1. Fetch a batch; restart the loader when an epoch is exhausted
            try:
                x_real, y_org = next(data_iter)
            except StopIteration:
                data_iter = iter(self.loader)
                x_real, y_org = next(data_iter)

            x_real = x_real.to(self.device)
            y_org = y_org.to(self.device)

            # Random target domains and latent codes
            y_trg = torch.randint(0, args.num_domains, (x_real.size(0),)).to(self.device)
            z_trg = torch.randn(x_real.size(0), args.latent_dim).to(self.device)
            z_trg2 = torch.randn(x_real.size(0), args.latent_dim).to(self.device)  # second latent for the diversity loss

            # =================================================================================== #
            #                               1. Discriminator update                               #
            # =================================================================================== #

            # Real term (hinge loss)
            d_out_real = nets['D'](x_real, y_org)
            d_loss_real = torch.mean(F.relu(1.0 - d_out_real))

            # Fake term (latent-guided synthesis); no gradients flow into G or F here
            with torch.no_grad():
                s_trg = nets['F'](z_trg, y_trg)
                x_fake = nets['G'](x_real, s_trg)

            d_out_fake = nets['D'](x_fake.detach(), y_trg)
            d_loss_fake = torch.mean(F.relu(1.0 + d_out_fake))

            d_loss = d_loss_real + d_loss_fake

            # zero_grad also discards the gradients that the previous
            # iteration's generator backward pass left on D's parameters.
            optims['D'].zero_grad()
            d_loss.backward()
            optims['D'].step()

            # =================================================================================== #
            #                     2. Generator, mapping network, encoder update                   #
            # =================================================================================== #

            # Adversarial loss
            s_trg = nets['F'](z_trg, y_trg)
            x_fake = nets['G'](x_real, s_trg)
            d_out_fake = nets['D'](x_fake, y_trg)
            g_loss_adv = -torch.mean(d_out_fake)

            # Style reconstruction loss
            # L_sty = E[ || s_trg - E_{y_trg}(G(x, s_trg)) ||_1 ]
            s_pred = nets['E'](x_fake, y_trg)
            g_loss_sty = torch.mean(torch.abs(s_trg - s_pred))

            # Diversity-sensitive loss (maximized, hence subtracted below)
            # L_ds = E[ || G(x, s_trg) - G(x, s_trg2) ||_1 ]
            s_trg2 = nets['F'](z_trg2, y_trg)
            x_fake2 = nets['G'](x_real, s_trg2)
            g_loss_ds = torch.mean(torch.abs(x_fake - x_fake2))

            # Cycle-consistency loss
            # L_cyc = E[ || x - G(G(x, s_trg), E_{y_org}(x)) ||_1 ]
            s_org = nets['E'](x_real, y_org)
            x_rec = nets['G'](x_fake, s_org)
            g_loss_cyc = torch.mean(torch.abs(x_real - x_rec))

            # Total objective
            g_loss = g_loss_adv \
                     + args.lambda_sty * g_loss_sty \
                     - args.lambda_ds * g_loss_ds \
                     + args.lambda_cyc * g_loss_cyc

            optims['G'].zero_grad()
            optims['F'].zero_grad()
            optims['E'].zero_grad()
            g_loss.backward()
            optims['G'].step()
            optims['F'].step()
            optims['E'].step()

            # =================================================================================== #
            #                                 3. Logging and saving                               #
            # =================================================================================== #

            if (i + 1) % 100 == 0:
                elapsed = time.time() - start_time
                print(f"Iter [{i+1}/{args.total_iters}] Time: {elapsed:.2f}s | "
                      f"D_loss: {d_loss.item():.4f} | G_adv: {g_loss_adv.item():.4f} | "
                      f"Sty: {g_loss_sty.item():.4f} | Cyc: {g_loss_cyc.item():.4f}")

            if (i + 1) % args.sample_freq == 0:
                self.save_samples(x_real, i + 1)

            if (i + 1) % args.save_freq == 0:
                self.save_checkpoint(i + 1)

    def save_samples(self, x_real, step):
        """Save a grid of latent-guided translations to ``samples/<step:06d>.jpg``.

        The first grid row holds the source images; each following row shows
        those sources translated with the style of one target domain.

        Args:
            x_real: batch of real images; at most ``num_domains`` are used.
            step: iteration number used for the file name.
        """
        nets = self.nets
        args = self.args

        # Sample in eval mode so the switch back to train mode below is a
        # genuine restore rather than a no-op.
        for net in nets.values():
            net.eval()

        with torch.no_grad():
            x_real_subset = x_real[:args.num_domains].to(self.device)  # source images (grid columns)

            # One latent code shared by all domains, paired with every domain label
            z_fix = torch.randn(1, args.latent_dim).repeat(args.num_domains, 1).to(self.device)
            y_fix = torch.arange(args.num_domains).to(self.device)
            s_fix = nets['F'](z_fix, y_fix)  # (num_domains, style_dim)

            # First row: the original images
            images = [x_real_subset.cpu()]

            # Remaining rows: each target-domain style applied to every source image
            for j in range(args.num_domains):
                s_curr = s_fix[j].unsqueeze(0).repeat(x_real_subset.size(0), 1)
                x_fake = nets['G'](x_real_subset, s_curr)
                images.append(x_fake.cpu())

            # Assemble the grid
            images = torch.cat(images, dim=0)
            path = os.path.join(self.save_dir, 'samples', f'{step:06d}.jpg')
            # nrow = number of source images, giving
            # (1 original row + num_domains translated rows) x (source columns)
            save_image(images, path, nrow=len(x_real_subset), padding=2, normalize=True)
            print(f"Sample image saved to {path}")

        # Back to training mode
        for net in nets.values():
            net.train()

if __name__ == '__main__':
    # Seed both RNGs so runs are repeatable
    np.random.seed(777)
    torch.manual_seed(777)

    # Load the configuration
    config = get_config()

    # Pick the GPU when one is available, otherwise fall back to the CPU
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(f"Device: {device}")

    # Build the solver and start training
    solver = Solver(config, device)
    solver.train()

Device: cuda
--- 환경 설정 중 ---
Google Drive를 마운트합니다...
Mounted at /content/drive
Google Drive mounted.
저장 경로: /content/drive/MyDrive/Colab Notebooks/GAN_assignment/GAN_FMNIST_sepa


100%|██████████| 26.4M/26.4M [00:01<00:00, 13.4MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 212kB/s]
100%|██████████| 4.42M/4.42M [00:01<00:00, 3.94MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 12.8MB/s]


--- 학습 시작 ---
Iter [100/1000] Time: 14.04s | D_loss: 1.8020 | G_adv: 0.3790 | Sty: 0.0121 | Cyc: 0.5254
Sample image saved to /content/drive/MyDrive/Colab Notebooks/GAN_assignment/GAN_FMNIST_sepa/samples/000100.jpg
Iter [200/1000] Time: 23.66s | D_loss: 0.3737 | G_adv: 1.5299 | Sty: 0.0097 | Cyc: 0.4788
Sample image saved to /content/drive/MyDrive/Colab Notebooks/GAN_assignment/GAN_FMNIST_sepa/samples/000200.jpg
Iter [300/1000] Time: 33.40s | D_loss: 0.0467 | G_adv: 1.5234 | Sty: 0.0107 | Cyc: 0.4606
Sample image saved to /content/drive/MyDrive/Colab Notebooks/GAN_assignment/GAN_FMNIST_sepa/samples/000300.jpg
Iter [400/1000] Time: 43.32s | D_loss: 0.0000 | G_adv: 2.4058 | Sty: 0.0079 | Cyc: 0.4835
Sample image saved to /content/drive/MyDrive/Colab Notebooks/GAN_assignment/GAN_FMNIST_sepa/samples/000400.jpg
Iter [500/1000] Time: 53.23s | D_loss: 0.0000 | G_adv: 1.2786 | Sty: 0.0089 | Cyc: 0.3631
Sample image saved to /content/drive/MyDrive/Colab Notebooks/GAN_assignment/GAN_FMNIST_sepa/