# In [None]:
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from gan_artifact_removal.models import UNetGenerator, PatchGANDiscriminator
from gan_artifact_removal.dataset import ArtifactDataset
from gan_artifact_removal.utils import save_sample
from gan_artifact_removal.config import *

def train():
    """Train the conditional GAN (UNet generator + PatchGAN discriminator).

    Builds an ``ArtifactDataset`` from ``ARTIFACT_DIR`` and ``CLEAN_DIR``.
    Each batch is a pair ``(x, y)``: ``x`` is fed to the generator and ``y``
    is the target it is pulled towards with an L1 loss. Per batch, the
    discriminator is updated once, then the generator is updated once.

    Hyper-parameters (``BATCH_SIZE``, ``LR``, ``EPOCHS``, ``LAMBDA_L1``,
    ``SAVE_INTERVAL``) come from ``gan_artifact_removal.config``.

    After every epoch, the mean discriminator / generator losses over that
    epoch are printed. Every ``SAVE_INTERVAL`` epochs, the generator output
    for the epoch's last batch is written to ``results/gan_epoch_<n>.png``.

    Raises:
        RuntimeError: if the data loader yields no batches in an epoch.
    """
    dataset = ArtifactDataset(ARTIFACT_DIR, CLEAN_DIR)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    gen = UNetGenerator().to(device)
    disc = PatchGANDiscriminator().to(device)

    l1_loss = nn.L1Loss()
    # BCEWithLogitsLoss applies the sigmoid itself; assumes disc returns raw
    # (pre-sigmoid) scores — TODO confirm against PatchGANDiscriminator.
    bce_loss = nn.BCEWithLogitsLoss()

    # beta1 = 0.5 instead of Adam's default 0.9, the common choice for GANs.
    opt_g = torch.optim.Adam(gen.parameters(), lr=LR, betas=(0.5, 0.999))
    opt_d = torch.optim.Adam(disc.parameters(), lr=LR, betas=(0.5, 0.999))

    # save_sample writes into this directory; make sure it exists up front.
    os.makedirs("results", exist_ok=True)

    for epoch in range(EPOCHS):
        # Accumulate detached loss tensors on-device so reporting the epoch
        # mean does not force a host sync on every step.
        d_loss_sum = torch.zeros((), device=device)
        g_loss_sum = torch.zeros((), device=device)
        n_batches = 0

        for x, y in loader:
            x, y = x.to(device), y.to(device)

            # --- Discriminator step: real pairs -> 1, generated pairs -> 0 ---
            fake_y = gen(x)
            d_real = disc(x, y)
            # detach() so the discriminator loss does not reach the generator.
            d_fake = disc(x, fake_y.detach())
            loss_d = 0.5 * (bce_loss(d_real, torch.ones_like(d_real)) +
                            bce_loss(d_fake, torch.zeros_like(d_fake)))

            opt_d.zero_grad()
            loss_d.backward()
            opt_d.step()

            # --- Generator step: fool the updated discriminator + match y ---
            d_fake = disc(x, fake_y)
            loss_g = bce_loss(d_fake, torch.ones_like(d_fake)) + LAMBDA_L1 * l1_loss(fake_y, y)

            opt_g.zero_grad()
            # This backward also leaves gradients in disc's parameters; they
            # are cleared by opt_d.zero_grad() before the next D update.
            loss_g.backward()
            opt_g.step()

            d_loss_sum += loss_d.detach()
            g_loss_sum += loss_g.detach()
            n_batches += 1

        if n_batches == 0:
            raise RuntimeError("DataLoader produced no batches; nothing to train on")

        mean_d = (d_loss_sum / n_batches).item()
        mean_g = (g_loss_sum / n_batches).item()
        print(f"[Epoch {epoch+1}/{EPOCHS}] Loss_D: {mean_d:.4f} | Loss_G: {mean_g:.4f}")
        if (epoch+1) % SAVE_INTERVAL == 0:
            # Detach so the saved sample carries no autograd graph.
            save_sample(fake_y.detach(), f"results/gan_epoch_{epoch+1}.png")

if __name__ == "__main__":
    # Launch training only when this file is executed as a script, not on import.
    train()
