In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:


class UNet(nn.Module):
    """A small 4-level U-Net for image-to-image prediction.

    Encoder: four double-conv stages (64, 128, 256, 512 channels) separated by
    2x2 max-pooling. Decoder: three stages. Each stage bilinearly upsamples
    the deeper feature map, concatenates the matching encoder map (skip
    connection) along the channel axis, and applies a double-conv. A final
    1x1 convolution maps 64 channels to ``out_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.

    ``forward`` expects a 4-D ``(batch, in_channels, H, W)`` tensor and returns
    ``(batch, out_channels, H, W)``. H and W do not need to be divisible by 8,
    because each upsampled map is resized to its skip connection's exact size.
    They must be at least 8 so the deepest stage keeps at least one pixel.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Encoding path
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)

        # Decoding path. Each stage consumes the upsampled deeper features
        # concatenated with the skip connection, so its input width is the SUM
        # of both channel counts (e.g. 512 + 256 for dec1), not just the
        # deeper map's width.
        self.dec1 = self.conv_block(512 + 256, 256)
        self.dec2 = self.conv_block(256 + 128, 128)
        self.dec3 = self.conv_block(128 + 64, 64)
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Kept as the source of the interpolation settings (mode/align_corners).
        # forward() resizes to the skip tensor's exact size so odd H/W also work.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def conv_block(self, in_channels, out_channels):
        """Return two 3x3 conv + ReLU layers; padding=1 preserves H and W."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def _upsample_to(self, x, ref):
        """Resize ``x`` to the spatial size (last two dims) of ``ref``."""
        # With align_corners=True this matches nn.Upsample(scale_factor=2) exactly
        # whenever ref is twice x's size, and it also handles odd sizes.
        return nn.functional.interpolate(
            x,
            size=tuple(ref.shape[-2:]),
            mode=self.upsample.mode,
            align_corners=self.upsample.align_corners,
        )

    def forward(self, x):
        # Encoding path: each deeper stage sees a 2x max-pooled map.
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.maxpool(enc1))
        enc3 = self.enc3(self.maxpool(enc2))
        enc4 = self.enc4(self.maxpool(enc3))

        # Decoding path: upsample, concatenate the skip along channels (dim=1),
        # then refine with a double-conv.
        dec1 = self.dec1(torch.cat([self._upsample_to(enc4, enc3), enc3], dim=1))
        dec2 = self.dec2(torch.cat([self._upsample_to(dec1, enc2), enc2], dim=1))
        dec3 = self.dec3(torch.cat([self._upsample_to(dec2, enc1), enc1], dim=1))

        return self.final_conv(dec3)


In [None]:
# Uses the UNet model defined in the previous cell.

class DiffusionProcess(nn.Module):
    """Iterative "perturb, then denoise" wrapper around a denoising network.

    At every step, zero-mean Gaussian noise with standard deviation
    ``noise_scale`` is added to the current tensor, and the result is passed
    through ``unet``.

    Args:
        unet: Module applied at each step. For ``num_steps > 1`` it must accept
            its own output as input, since that output is fed back in.
    """

    def __init__(self, unet):
        super().__init__()
        self.unet = unet

    def forward(self, x, num_steps=5, noise_scale=0.1):
        """Run ``num_steps`` noise-then-denoise iterations starting from ``x``.

        Returns ``x`` itself when ``num_steps`` is 0. Fresh noise is drawn on
        every call, regardless of train/eval mode, so the result is random
        unless ``noise_scale`` is 0.
        """
        current = x
        for _step in range(num_steps):
            current = self.unet(current + torch.randn_like(current) * noise_scale)
        return current

# Training setup
unet = UNet(in_channels=3, out_channels=3)  # e.g. a U-Net for 3-channel (RGB) images
diffusion_model = DiffusionProcess(unet)

optimizer = optim.Adam(diffusion_model.parameters(), lr=0.001)
criterion = nn.MSELoss()

# The loop below needs these names; `epochs` was not defined anywhere in the file.
epochs = 10  # number of full passes over train_loader; tune as needed
initial_noise_scale = 0.1  # std of the Gaussian noise applied before the model runs

# Assumption: `train_loader` is defined elsewhere (e.g. a torch DataLoader) and
# yields batches whose first element is the image tensor. TODO: confirm.
diffusion_model.train()  # explicit training mode (already the default for a new module)
for epoch in range(epochs):
    for batch in train_loader:
        images = batch[0]
        # Corrupt the clean images once up front; the model adds more noise
        # internally at each of its own steps.
        noisy_images = images + torch.randn_like(images) * initial_noise_scale
        denoised_images = diffusion_model(noisy_images)
        # Reconstruction loss between the model's output and the clean originals.
        loss = criterion(denoised_images, images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
