In [None]:
import torch.nn as nn
import torch.nn.functional as F

class block(nn.Module):
    """Residual block: conv -> BN -> PReLU -> conv -> BN, plus an identity skip.

    The lowercase class name is kept for compatibility with existing callers
    (PEP 8 would spell it ``Block``).

    Args:
        dim: channel count; the input and output both have ``dim`` channels.
        kernel_size: spatial kernel size of both convolutions. Padding is
            ``kernel_size // 2``, which preserves height/width for odd kernels
            at stride 1.
        stride: stride of both convolutions. The identity skip ``x + ...``
            only lines up when the convolutions keep the spatial size, i.e.
            in practice ``stride=1``.
    """

    def __init__(self, dim, kernel_size=3, stride=1):
        super().__init__()
        padding = kernel_size // 2
        # Creation order is unchanged, so seeded initialisation and
        # parameter registration order stay identical.
        self.conv1 = nn.Conv2d(dim, dim, kernel_size, stride, padding=padding)
        self.bn1 = nn.BatchNorm2d(dim)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(dim, dim, kernel_size, stride, padding=padding)
        self.bn2 = nn.BatchNorm2d(dim)

    def forward(self, x):
        """Return ``x + F(x)``, where F is the conv/BN/PReLU/conv/BN branch."""
        residual = self.prelu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        return x + residual

class PixelShuffle(nn.Module):
    """Sub-pixel convolution upsampler.

    A convolution expands ``dim`` channels to ``dim * scale**2``. Then
    ``nn.PixelShuffle`` rearranges those channels into a ``scale``-times larger
    spatial grid, so the output again has ``dim`` channels.

    Args:
        dim: number of input (and output) channels.
        scale: integer spatial upscaling factor.
        kernel_size: kernel size of the expanding convolution (padding is
            ``kernel_size // 2``).
    """

    def __init__(self, dim, scale, kernel_size):
        super().__init__()
        self.scale = scale
        expanded_channels = dim * scale ** 2
        pad = kernel_size // 2
        self.conv = nn.Conv2d(dim, expanded_channels, kernel_size, padding=pad)
        self.shuffle = nn.PixelShuffle(scale)

    def forward(self, x):
        """Upsample ``x`` spatially by ``self.scale``."""
        expanded = self.conv(x)
        return self.shuffle(expanded)

class Generator(nn.Module):
    """SRGAN-style super-resolution generator.

    The pipeline is:

    1. Shallow conv + PReLU.
    2. ``n_blocks`` residual blocks.
    3. Conv + BN, with a long skip connection back to the shallow features.
    4. Sub-pixel upsampling stages.
    5. Output conv.

    Args:
        dim: two-element sequence ``(image_channels, feature_channels)``.
            ``dim[0]`` is both the input and output channel count, and
            ``dim[1]`` is the width of every internal feature map.
        kernel_size: kernel size of the first and last convolutions.
        residual_kernel_size: kernel size inside the residual section.
        stride: stride passed to every non-upsampling convolution. Values
            other than 1 generally break the residual additions (shape
            mismatch), so ``stride=1`` is the practical setting.
        n_blocks: number of residual blocks.
        upscale_factor: total spatial upscaling. It must be a positive power
            of two, because each x2 stage is a ``PixelShuffle`` + ``PReLU``
            pair. The default of 4 builds the original two-stage upsampler
            with identical ``state_dict`` keys.

    Raises:
        ValueError: if ``upscale_factor`` is not a positive power of two.
    """

    def __init__(self, dim, kernel_size=9, residual_kernel_size=3, stride=1,
                 n_blocks=16, upscale_factor=4):
        super().__init__()
        if (not isinstance(upscale_factor, int) or upscale_factor < 1
                or upscale_factor & (upscale_factor - 1)):
            raise ValueError(
                f"upscale_factor must be a positive power of two, got {upscale_factor!r}"
            )
        image_channels, features = dim[0], dim[1]

        # Module creation order matches the original, so seeded init is reproducible.
        self.conv1 = nn.Conv2d(image_channels, features, kernel_size,
                               stride=stride, padding=kernel_size // 2)
        self.prelu1 = nn.PReLU()

        self.residual_blocks = nn.ModuleList(
            block(features, residual_kernel_size, stride) for _ in range(n_blocks)
        )

        self.conv2 = nn.Conv2d(features, features, residual_kernel_size, stride,
                               padding=residual_kernel_size // 2)
        self.bn1 = nn.BatchNorm2d(features)

        # One x2 (PixelShuffle, PReLU) pair per factor of two. The flat
        # Sequential keeps the original key layout: upsample.0 / .1 / .2 / .3 for x4.
        stages = []
        for _ in range(upscale_factor.bit_length() - 1):
            stages.append(PixelShuffle(features, 2, kernel_size=3))
            stages.append(nn.PReLU())
        self.upsample = nn.Sequential(*stages)

        self.head = nn.Conv2d(features, image_channels, kernel_size, stride,
                              padding=kernel_size // 2)

    def forward(self, x):
        """Upscale ``x`` by ``upscale_factor`` in height and width."""
        shallow = self.prelu1(self.conv1(x))

        residual_out = shallow
        for res_block in self.residual_blocks:
            residual_out = res_block(residual_out)

        # Long skip: add the shallow features back after the residual trunk.
        skip_out = shallow + self.bn1(self.conv2(residual_out))

        return self.head(self.upsample(skip_out))
        

In [None]:
class Discriminator(nn.Module):
    """SRGAN-style real/fake image classifier.

    Conv layer ``j`` (0-based) has ``dim[1 + j // 2]`` output channels. Its
    stride is 1 for even ``j`` and 2 for odd ``j``, so each width is used by a
    stride-1 / stride-2 pair. Every layer except the first is followed by
    BatchNorm, and all layers use LeakyReLU(0.2). The features are then
    globally average-pooled and fed through Linear -> LeakyReLU ->
    Linear -> Sigmoid.

    With the defaults the conv widths are 64, 64, 128, 128, 256, 256, 512, and
    the classifier input is 512 (== ``dim[-2]``). ``n_blocks=8`` appends the
    final 512-wide stride-2 layer.

    Args:
        dim: channel schedule. ``dim[0]`` is the input image channel count and
            ``dim[1:]`` are conv widths. Any sequence works; the default is a
            tuple so it cannot be mutated and shared between calls.
        kernel_size: spatial kernel size of every convolution.
        n_blocks: total number of conv layers (at least 2).

    Raises:
        ValueError: if ``n_blocks < 2`` or ``dim`` is too short for ``n_blocks``.
    """

    def __init__(self, dim=(3, 64, 128, 256, 512, 1024), kernel_size=3, n_blocks=7):
        super().__init__()
        if n_blocks < 2:
            raise ValueError(f"n_blocks must be >= 2, got {n_blocks}")
        last_width_index = 1 + (n_blocks - 1) // 2
        if last_width_index >= len(dim):
            raise ValueError(
                f"dim needs at least {last_width_index + 1} entries for "
                f"n_blocks={n_blocks}, got {len(dim)}"
            )
        padding = kernel_size // 2
        self.layers = nn.ModuleList()

        # First two layers (stride 1 -> stride 2); the very first has no BatchNorm.
        self.layers.append(nn.Sequential(
            nn.Conv2d(dim[0], dim[1], kernel_size, padding=padding),
            nn.LeakyReLU(0.2, inplace=True),
        ))
        self.layers.append(nn.Sequential(
            nn.Conv2d(dim[1], dim[1], kernel_size, stride=2, padding=padding),
            nn.BatchNorm2d(dim[1]),
            nn.LeakyReLU(0.2, inplace=True),
        ))

        # Remaining layers. The width changes only on stride-1 layers, which
        # keeps the index within ``dim`` for the defaults and makes the last
        # conv width known.
        in_c = dim[1]
        for j in range(2, n_blocks):
            out_c = dim[1 + j // 2]
            stride = 1 if j % 2 == 0 else 2
            self.layers.append(nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size, stride=stride, padding=padding),
                nn.BatchNorm2d(out_c),
                nn.LeakyReLU(0.2, inplace=True),
            ))
            in_c = out_c

        # Global pooling, then an FC head sized from the actual last conv width.
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_c, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return per-sample probabilities in [0, 1], shape ``(B, 1)``.

        Assumes batched NCHW input (B, C, H, W).
        """
        for layer in self.layers:
            x = layer(x)

        x = self.global_pool(x)          # (B, C, 1, 1)
        x = x.view(x.size(0), -1)        # (B, C)
        return self.fc(x)                # (B, 1)
