#### We originally tried to upscale 32x32 -> 128x128, but that proved too difficult to train on free GPUs: the GAN settled into a local minimum and failed to recover the high-resolution image even after many epochs. Instead, we attempt 64x64 -> 128x128.

In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import InterpolationMode

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import os
from tqdm.notebook import tqdm

In [4]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [5]:
# Seed PyTorch's and NumPy's global RNGs so runs are repeatable.
torch.manual_seed(42)
np.random.seed(42)

## data pipeline

In [6]:
class TrainValDataset(Dataset):
    """Paired (low-res, high-res) patches cut from a folder of high-res images.

    Each item is a random ``crop_size`` x ``crop_size`` RGB patch of one image
    (the HR target) plus a bicubic downscale of that same patch by
    ``scale_factor`` (the LR input). Both are float tensors in [0, 1].

    Args:
        hr_dir: Directory holding the high-resolution images.
        file_list: File names (relative to ``hr_dir``) to draw from; names
            that do not exist on disk are dropped.
        scale_factor: Integer downscale factor between HR and LR patches.
        crop_size: Side length of the square HR patch.

    Raises:
        ValueError: If none of ``file_list`` exists in ``hr_dir``.
    """

    def __init__(self, hr_dir, file_list, scale_factor=2, crop_size=128):
        self.hr_dir = hr_dir
        self.file_list = file_list
        self.scale_factor = scale_factor
        self.crop_size = crop_size

        # Keep only files that actually exist so __getitem__ never hits a missing path.
        self.valid_files = [
            file for file in file_list
            if os.path.exists(os.path.join(hr_dir, file))
        ]
        if not self.valid_files:
            raise ValueError(f"No valid image files found in {hr_dir}")

        # The transforms hold no per-sample state, so build them once here
        # instead of on every __getitem__ call.
        lr_size = crop_size // scale_factor
        self._crop = transforms.RandomCrop(crop_size)
        # Downscale the PIL patch rather than the tensor. PIL's bicubic resize
        # is antialiased and its 8-bit output is clamped to [0, 255], so LR values
        # stay in [0, 1]. Bicubic on a float tensor can overshoot that range.
        self._downscale = transforms.Resize(
            (lr_size, lr_size), interpolation=InterpolationMode.BICUBIC
        )
        self._to_tensor = transforms.ToTensor()

        print(f"Found {len(self.valid_files)} valid images")

    def __len__(self):
        return len(self.valid_files)

    def __getitem__(self, idx):
        """Return ``(lr_tensor, hr_tensor)`` for the image at ``idx``.

        ``hr_tensor`` has shape (3, crop_size, crop_size) and ``lr_tensor``
        (3, crop_size // scale_factor, crop_size // scale_factor).
        """
        img_path = os.path.join(self.hr_dir, self.valid_files[idx])
        # Close the file handle promptly. convert() returns a new, fully loaded
        # image that remains valid after the `with` block.
        with Image.open(img_path) as src:
            hr_img = src.convert('RGB')

        hr_patch = self._crop(hr_img)
        hr_tensor = self._to_tensor(hr_patch)
        lr_tensor = self._to_tensor(self._downscale(hr_patch))
        return lr_tensor, hr_tensor

def create_dataloaders(hr_dir, batch_size=16, scale_factor=2, crop_size=128, num_workers=2):
    """Create training and validation dataloaders from a folder of HR images.

    The .png/.jpg/.jpeg files in ``hr_dir`` are sorted by name and split 80/20
    into train/validation, so the split is the same on every run and machine.

    Args:
        hr_dir: Directory containing the high-resolution images.
        batch_size: Batch size for both loaders.
        scale_factor: HR/LR size ratio passed to ``TrainValDataset``.
        crop_size: HR patch size passed to ``TrainValDataset``.
        num_workers: DataLoader worker processes.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and the
        validation loader does not.

    Raises:
        ValueError: If ``hr_dir`` does not exist or holds no supported images.
    """
    if not os.path.exists(hr_dir):
        raise ValueError(f"Directory not found: {hr_dir}")

    # os.listdir order is filesystem-dependent; sort so the split is reproducible.
    image_files = sorted(
        file for file in os.listdir(hr_dir)
        if file.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    if not image_files:
        raise ValueError(f"No image files found in {hr_dir}")

    print(f"Total images found: {len(image_files)}")

    # First 80% of the sorted names for training, the rest for validation.
    train_size = int(0.8 * len(image_files))
    train_files = image_files[:train_size]
    val_files = image_files[train_size:]

    print(f"Training images: {len(train_files)}")
    print(f"Validation images: {len(val_files)}")

    train_dataset = TrainValDataset(hr_dir, train_files, scale_factor, crop_size)
    val_dataset = TrainValDataset(hr_dir, val_files, scale_factor, crop_size)

    # Pinned host memory only speeds up host->GPU copies; on a CPU-only
    # machine DataLoader just warns about it, so enable it only with CUDA.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader

def plot_samples(loader, num_samples=3):
    """Display the first few (low-res, high-res) pairs from one batch of *loader*.

    Low-res images fill the top row and the matching high-res images the
    bottom row of a 2 x ``num_samples`` grid; each title shows the pixel size.
    """
    plt.figure(figsize=(15, 5))

    lr_batch, hr_batch = next(iter(loader))
    shown = min(num_samples, lr_batch.size(0))

    def _to_display(tensor):
        # CHW tensor -> HWC numpy array clipped to the displayable [0, 1] range.
        return np.clip(tensor.permute(1, 2, 0).cpu().numpy(), 0, 1)

    rows = ((lr_batch, 'Low Res'), (hr_batch, 'High Res'))
    for col in range(shown):
        for row, (batch, label) in enumerate(rows):
            image = _to_display(batch[col])
            plt.subplot(2, num_samples, row * num_samples + col + 1)
            plt.imshow(image)
            plt.title(f'{label} {image.shape[:2]}')
            plt.axis('off')

    plt.tight_layout()
    plt.show()

## CNN architecture used (generator)

In [7]:
class UpsamplingCNN(nn.Module):
    """Encoder-decoder CNN that upsamples an RGB image by a net factor of 2.

    The encoder halves the spatial size twice (two 2x2 max-pools). The decoder
    doubles it three times (three stride-2 transposed convolutions), so a
    64x64 input becomes a 128x128 output. A final sigmoid keeps outputs in
    [0, 1].

    The position of every layer inside ``encoder`` / ``decoder`` fixes the
    ``state_dict`` keys (``encoder.0.weight``, ``decoder.4.bias``, ...) that
    saved checkpoints are keyed by.
    """

    def __init__(self):
        super().__init__()

        # Encoder: (conv 3x3 -> ReLU -> max-pool 2x2) twice; 64x64 -> 16x16.
        encoder_layers = []
        for in_ch, out_ch in ((3, 64), (64, 128)):
            encoder_layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ])
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: three 4x4/stride-2/pad-1 transposed convs, each exactly
        # doubling H and W (16x16 -> 128x128), with ReLU in between.
        channels = (128, 64, 32, 3)
        decoder_layers = []
        for in_ch, out_ch in zip(channels, channels[1:]):
            decoder_layers.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.ReLU(),
            ])
        decoder_layers[-1] = nn.Sigmoid()  # output range [0, 1]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Map an (N, 3, H, W) batch to (N, 3, 2H, 2W) when H and W are divisible by 4."""
        return self.decoder(self.encoder(x))

# instantiate the model and print it
model = UpsamplingCNN()
print(model)

UpsamplingCNN(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): Sigmoid()
  )
)


In [8]:
## training process

# Build the 80/20 train/validation loaders from the DIV2K high-resolution images.
# With crop_size=128 and scale_factor=2 each sample pairs a 64x64 input with a
# 128x128 target patch.
train_loader, val_loader = create_dataloaders(
    '/kaggle/input/div2k-high-resolution-images/DIV2K_train_HR/DIV2K_train_HR',
    batch_size=16,
    scale_factor=2,
    crop_size=128
)

## define the model, loss function, and optimizer
model = UpsamplingCNN().to(device)  # Move model to GPU if available
# Pixel-wise MSE between the upsampled output and the HR patch.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

Total images found: 800
Training images: 640
Validation images: 160
Found 640 valid images
Found 160 valid images


In [9]:
# Sanity-check the output shape on a random batch. Use a throwaway instance so
# the global `model` (whose parameters `optimizer` tracks) is not replaced.
shape_check_model = UpsamplingCNN()
dummy_input = torch.randn(16, 3, 64, 64)  # Batch of 16, 64x64 images
with torch.no_grad():  # no training here, so skip building the autograd graph
    output = shape_check_model(dummy_input)
print(output.shape)  # Should print torch.Size([16, 3, 128, 128])

torch.Size([16, 3, 128, 128])


In [10]:
# Debug helper: print every directory (and its subfolders) under
# /kaggle/working/.virtual_documents.
count = 0  # NOTE(review): unused
for root, folders, filenames in os.walk('/kaggle/working/.virtual_documents'):
   print(root, folders)

/kaggle/working/.virtual_documents []


In [14]:
## function to load the checkpoint
def load_checkpoint(model, optimizer, filename, map_location=None):
    """Restore model and optimizer state from a checkpoint written by ``save_checkpoint``.

    Args:
        model: Module whose ``state_dict`` is replaced in place.
        optimizer: Optimizer whose ``state_dict`` is replaced in place.
        filename: Path to the ``.pth`` checkpoint.
        map_location: Device to map stored tensors to. Defaults to CUDA when
            available and CPU otherwise, so CPU-only sessions can still load
            checkpoints saved on a GPU.

    Returns:
        ``(model, optimizer, epoch, loss)``, where ``epoch`` and ``loss`` are the
        values stored in the checkpoint.
    """
    if map_location is None:
        map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # weights_only=True limits unpickling to tensors and plain Python containers
    # (all this checkpoint dict holds), so loading a checkpoint from an external
    # source cannot execute arbitrary code.
    checkpoint = torch.load(filename, map_location=map_location, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    return model, optimizer, epoch, loss

# Fresh generator and a matching optimizer over its parameters; these rebind
# the global `model` / `optimizer` names (model stays on the CPU until moved).
model = UpsamplingCNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Uncomment to resume from a saved checkpoint instead of starting from scratch:
# model, optimizer, start_epoch, start_loss = load_checkpoint(model, optimizer, filename="/kaggle/input/checkpoint7/pytorch/default/1/checkpoint_7.pth")
# print(f"Resuming training from epoch {start_epoch+1} with loss {start_loss}")

Resuming training from epoch 60 with loss 0.009660433186218143


  checkpoint = torch.load(filename,map_location=torch.device('cuda'))


In [10]:
# print("Model device:", next(model.parameters()).device)

# for low_res, high_res in train_loader:
#     low_res, high_res = low_res.to(device), high_res.to(device)
#     print("Low resolution image device:", low_res.device)
#     print("High resolution image device:", high_res.device)

print("CUDA available:", torch.cuda.is_available())
!nvidia-smi

CUDA available: True
Tue Dec 24 23:29:03 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.03              Driver Version: 560.35.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   39C    P0             26W /   70W |     105MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla T4            

In [9]:
def save_checkpoint(model, optimizer, epoch, loss, filename):
    """Write model/optimizer state plus bookkeeping to *filename* with ``torch.save``.

    The file holds a dict with keys ``'epoch'``, ``'model_state_dict'``,
    ``'optimizer_state_dict'`` and ``'loss'``, the layout ``load_checkpoint``
    reads back.
    """
    torch.save(
        dict(
            epoch=epoch,
            model_state_dict=model.state_dict(),
            optimizer_state_dict=optimizer.state_dict(),
            loss=loss,
        ),
        filename,
    )

In [None]:
## set device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# (`device` is already defined in an earlier cell; kept above for reference.)

## move model to the gpu once at the start
model = model.to(device)



## training function
def train(model, dataloader, criterion, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    Moves *model* to *device* and switches it to training mode. For each
    ``(low_res, high_res)`` batch it runs a forward pass, computes
    ``criterion(output, high_res)``, backpropagates and steps *optimizer*.
    """
    model.to(device)
    model.train()
    total = 0.0
    for inputs, targets in tqdm(dataloader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        total += batch_loss.item()
    return total / len(dataloader)

## validation function
def validate(model, dataloader, criterion, device):
    """Evaluate *model* on *dataloader* without gradients; return the mean per-batch loss.

    Switches *model* to eval mode and leaves it in eval mode afterwards.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for inputs, targets in dataloader:
            predictions = model(inputs.to(device))
            batch_losses.append(criterion(predictions, targets.to(device)).item())
    return sum(batch_losses) / len(dataloader)

## training loop
num_epochs = 20
save_every = 20  # write a checkpoint after every `save_every` completed epochs
checkpoint = 1   # running index used in checkpoint file names
for epoch in range(num_epochs):
    # train() moves the model to `device` itself; validate() reuses that placement.
    train_loss = train(model, train_loader, criterion, optimizer, device)
    val_loss = validate(model, val_loader, criterion, device)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

    # Count completed epochs (epoch + 1) so the final epoch of a run whose
    # length is a multiple of save_every is saved.
    if (epoch + 1) % save_every == 0:
        save_checkpoint(model, optimizer, epoch, val_loss, filename="checkpoint_{}.pth".format(checkpoint))
        checkpoint += 1
        print('checkpoint saved')

In [18]:
  # One-off snapshot using `model`, `optimizer`, `epoch` and `val_loss` left over from the earlier training loop.
  # NOTE(review): this cell is indented. IPython strips a cell's leading indent, but plain Python would raise IndentationError.
  save_checkpoint(model, optimizer, epoch, val_loss, filename="checkpoint_12410.pth")

In [9]:
import os
print(os.getcwd()) 

/kaggle/working


## Run inference on a specific image

In [None]:
## function to load and preprocess a single image for inference
def load_image(path, input_size=(64, 64), target_size=(128, 128), interpolation=None):
    """Load an image and return a (low-res model input, high-res reference) pair.

    The whole image is resized (not cropped) to ``input_size`` for the model
    input and to ``target_size`` for visual comparison.

    Args:
        path: Image file path; the image is converted to RGB.
        input_size: (height, width) of the low-resolution model input.
        target_size: (height, width) of the high-resolution reference.
        interpolation: torchvision ``InterpolationMode`` used for both resizes.
            ``None`` means bicubic, the mode ``TrainValDataset`` uses to build
            its low-resolution training inputs.

    Returns:
        ``(low_res_img, high_res_img)``: a (1, 3, *input_size) batch tensor and
        a (3, *target_size) tensor, both with values in [0, 1].
    """
    if interpolation is None:
        interpolation = InterpolationMode.BICUBIC

    # Close the file handle promptly; convert() returns a fully loaded copy.
    with Image.open(path) as src:
        img = src.convert('RGB')

    transform_low_res = transforms.Compose([
        transforms.Resize(input_size, interpolation=interpolation),
        transforms.ToTensor(),
    ])
    transform_high_res = transforms.Compose([
        transforms.Resize(target_size, interpolation=interpolation),
        transforms.ToTensor(),
    ])

    low_res_img = transform_low_res(img).unsqueeze(0)  # add batch dimension
    high_res_img = transform_high_res(img)  # for comparison only (not used in inference)

    return low_res_img, high_res_img

## function to run inference and display results
def run_inference(model, image_path):
    """Upsample one image with *model* and plot reference, input and output side by side.

    The input is placed on whatever device holds the model's parameters, so
    this works for CPU and GPU models alike. Leaves *model* in eval mode.
    """
    low_res_img, high_res_img = load_image(image_path)
    # Follow the model instead of assuming CUDA. Feeding a CUDA tensor to a
    # CPU-resident model (or the reverse) raises a device-mismatch error.
    model_device = next(model.parameters()).device
    low_res_img = low_res_img.to(model_device)

    model.eval()
    with torch.no_grad():
        generated_img = model(low_res_img).squeeze(0).cpu()  # remove batch dimension

    in_h, in_w = low_res_img.shape[-2:]
    out_h, out_w = generated_img.shape[-2:]

    to_pil = transforms.ToPILImage()
    low_res_pil = to_pil(low_res_img.squeeze(0).cpu())
    generated_pil = to_pil(generated_img)
    high_res_pil = to_pil(high_res_img)

    # Reference, model input and model output, left to right.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    panels = (
        (high_res_pil, "Original High-Resolution"),
        (low_res_pil, f"Low-Resolution Input ({in_w}x{in_h})"),
        (generated_pil, f"Generated High-Resolution ({out_w}x{out_h})"),
    )
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")
    plt.show()


## assume `model` is the trained UpsamplingCNN model
# Visual check on a single DIV2K training image (0009.png).
image_path = '/kaggle/input/div2k-high-resolution-images/DIV2K_train_HR/DIV2K_train_HR/0009.png'
run_inference(model, image_path)


## discriminator network

In [27]:
class Discriminator(nn.Module):
    """Convolutional real/fake scorer for 3-channel images.

    Four 4x4 stride-2 convolutions (3 -> 64 -> 128 -> 256 -> 512 channels) are
    each followed by LeakyReLU(0.2), with batch norm after all but the first.
    A 4x4 stride-1 convolution down to one channel and a sigmoid follow. A
    (N, 3, 128, 128) batch therefore yields an (N, 1, 5, 5) map of scores in
    (0, 1).
    """

    def __init__(self):
        super().__init__()

        # Layer positions inside `main` fix the state_dict keys
        # (main.0.weight, main.3.weight, ...) that saved checkpoints use.
        # First block: conv + activation, no batch norm.
        layers = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Three conv + batch-norm + activation blocks, each halving H and W.
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            layers += [
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Head: collapse to one channel and squash to (0, 1).
        layers += [
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return the per-patch real/fake score map for the batch *x*."""
        return self.main(x)

## instantiate the discriminator
discriminator = Discriminator()
print(discriminator)

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
    (12): Sigmoid()
  )
)


In [31]:
## loss function
# NOTE: this rebinds the global `criterion` (previously MSELoss) to BCELoss. The
# MSE training loop reads the global `criterion`, so re-running it after this
# cell would train with BCE.
criterion = nn.BCELoss()

## optimizer for the discriminator
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

## training loop for the discriminator
def train_discriminator(discriminator, generator, data_loader, device, num_epochs,
                        optimizer=None, loss_fn=None, log_every=1):
    """Pre-train *discriminator* to separate real HR images from generator outputs.

    The generator is frozen (eval mode, no gradients). For every batch the
    discriminator is pushed toward label 1 on ``high_res`` and label 0 on
    ``generator(low_res)``, then one optimizer step is taken.

    Args:
        discriminator: Model scoring images; its outputs are flattened per batch.
        generator: Model mapping ``low_res`` batches to fake high-res images.
        data_loader: Iterable of ``(low_res, high_res)`` batches (must support
            ``len()`` when logging is enabled).
        device: Device the batches are moved to.
        num_epochs: Number of passes over *data_loader*.
        optimizer: Discriminator optimizer; defaults to the global ``optimizer_D``.
        loss_fn: Loss taking ``(scores, labels)``; defaults to the global ``criterion``.
        log_every: Print the loss every this many batches (0 disables printing).

    Returns:
        The last batch's discriminator loss (real + fake) as a float, or
        ``None`` if no batch was processed.
    """
    # Fall back to the notebook globals only when nothing is passed in.
    if optimizer is None:
        optimizer = optimizer_D
    if loss_fn is None:
        loss_fn = criterion

    discriminator.train()
    generator.eval()  ## ensure generator is in evaluation mode during D training

    real_label = 1.0
    fake_label = 0.0
    last_loss = None

    for _ in range(num_epochs):
        for i, (low_res, high_res) in enumerate(data_loader):
            low_res = low_res.to(device)
            high_res = high_res.to(device)

            optimizer.zero_grad()

            ## train with real images -> label 1
            output_real = discriminator(high_res).view(-1)
            labels_real = torch.full((output_real.size(0),), real_label, dtype=torch.float, device=device)
            loss_real = loss_fn(output_real, labels_real)
            loss_real.backward()

            ## train with fake images -> label 0; no_grad keeps G out of the graph
            with torch.no_grad():
                fake_images = generator(low_res)
            output_fake = discriminator(fake_images).view(-1)
            labels_fake = torch.full((output_fake.size(0),), fake_label, dtype=torch.float, device=device)
            loss_fake = loss_fn(output_fake, labels_fake)
            loss_fake.backward()

            ## update the discriminator with the accumulated real + fake gradients
            optimizer.step()

            last_loss = (loss_real + loss_fake).item()
            if log_every and i % log_every == 0:
                print(f'Batch {i}/{len(data_loader)}: Loss_D: {last_loss:.4f}')

    return last_loss

# Pre-train the discriminator against the current generator `model` (num_epochs=10 requested).
train_discriminator(discriminator.to(device), model.to(device), train_loader, device, 10)

## full GAN network

In [None]:
## function to load the checkpoint for generator/discriminator
def load_checkpoint(model, optimizer, filename, map_location=None):
    """Restore model and optimizer state from a checkpoint written by ``save_checkpoint``.

    Args:
        model: Module whose ``state_dict`` is replaced in place.
        optimizer: Optimizer whose ``state_dict`` is replaced in place.
        filename: Path to the ``.pth`` checkpoint.
        map_location: Device to map stored tensors to. Defaults to CUDA when
            available and CPU otherwise, so CPU-only sessions can still load
            checkpoints saved on a GPU.

    Returns:
        ``(model, optimizer, epoch, loss)``, where ``epoch`` and ``loss`` are the
        values stored in the checkpoint.
    """
    if map_location is None:
        map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # weights_only=True limits unpickling to tensors and plain Python containers
    # (all this checkpoint dict holds), so loading a checkpoint from an external
    # source cannot execute arbitrary code.
    checkpoint = torch.load(filename, map_location=map_location, weights_only=True)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
    return model, optimizer, epoch, loss


generator = UpsamplingCNN()
# Build the optimizer over the generator's own parameters. The (commented)
# load_checkpoint call below restores this optimizer alongside `generator`, and
# an optimizer over another model's parameters would never update the generator.
optimizer = torch.optim.Adam(generator.parameters(), lr=0.001)


# generator, optimizer, start_epoch, start_loss = load_checkpoint(generator, optimizer, filename="/kaggle/input/generator_100/pytorch/default/1/checkpoint_12410.pth")
# print(f"Resuming generator training from epoch {start_epoch+1} with loss {start_loss}")

# NOTE(review): the discriminator needs its own optimizer over discriminator.parameters();
# reusing the generator's `optimizer` below would fail to load the saved state.
# discriminator, optimizer, start_epoch, start_loss = load_checkpoint(discriminator, optimizer, filename="/kaggle/input/discriminator_10/pytorch/default/1/discriminator_12420.pth")
# print(f"Resuming discriminator training from epoch {start_epoch+1} with loss {start_loss}")


In [33]:
def compute_gradient_penalty(discriminator, real_samples, fake_samples, device):
    """WGAN-GP gradient penalty: mean over the batch of (||grad_x D(x)||_2 - 1)^2.

    ``x`` is a random per-sample interpolation between *real_samples* and
    *fake_samples*. Both inputs are detached, so the penalty only produces
    gradients for the discriminator and never for whatever produced
    *fake_samples*.

    Args:
        discriminator: Callable mapping a batch shaped like *real_samples* to scores.
        real_samples: Batch of shape (N, ...).
        fake_samples: Batch with the same shape as *real_samples*.
        device: Device for the random interpolation coefficients.

    Returns:
        Scalar tensor that is differentiable w.r.t. the discriminator's parameters.
    """
    # One mixing coefficient per sample, broadcast over all remaining dims, so
    # the same code handles (N, C, H, W) images and flat (N, D) inputs.
    alpha_shape = (real_samples.size(0),) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device)
    interpolates = (
        alpha * real_samples.detach() + (1 - alpha) * fake_samples.detach()
    ).requires_grad_(True)
    d_interpolates = discriminator(interpolates)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),  # gradient of the summed scores
        create_graph=True,   # the penalty itself must be differentiable
        retain_graph=True,   # keep graph for further backward computation
    )[0]

    gradients = gradients.reshape(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

In [None]:
## full training process for GAN

## worked for about 5-10 epochs
## need to make discriminator network better in some way


## instantiate the generator and discriminator models on the target device
generator = UpsamplingCNN().to(device)
discriminator = Discriminator().to(device)

## loss function and optimizers -- different learning rates for G and D
criterion = nn.BCELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

num_generator_updates = 1     # generator steps per discriminator step
num_epochs = 50               ## change this value
use_gradient_penalty = False  # add lambda_gp * compute_gradient_penalty(...) to loss_D
lambda_gp = 10
# Weight of an optional pixel-wise MSE term in the generator loss (0 = adversarial only).
# NOTE(review): super-resolution GANs usually pair a content loss with the
# adversarial loss; a purely adversarial objective may explain the poor
# convergence described at the top of the notebook.
lambda_content = 0.0
# Target labels; lower real_label (e.g. 0.9) for one-sided label smoothing.
real_label = 1.0
fake_label = 0.0
# Std-dev of the Gaussian noise added to the discriminator's inputs (0 disables it).
noise_std_dev = 0.00

image_path = '/kaggle/input/div2k-high-resolution-images/DIV2K_train_HR/DIV2K_train_HR/0009.png'

# training loop
for epoch in range(num_epochs):
    generator.train()
    discriminator.train()

    for i, (low_res, high_res) in enumerate(tqdm(train_loader)):
        low_res = low_res.to(device)
        high_res = high_res.to(device)

        # ---------------- discriminator update ----------------
        # Reset D's gradients every iteration. Otherwise they accumulate across
        # batches, together with the gradients that loss_G.backward() leaves
        # in D's parameters during the generator update.
        optimizer_D.zero_grad()

        # G is not updated in this step, so don't build its autograd graph.
        with torch.no_grad():
            fake_images = generator(low_res)

        # Optional input noise, clamped back to the valid [0, 1] range.
        noisy_real_images = torch.clamp(high_res + noise_std_dev * torch.randn_like(high_res), 0.0, 1.0)
        noisy_fake_images = torch.clamp(fake_images + noise_std_dev * torch.randn_like(fake_images), 0.0, 1.0)

        output_real = discriminator(noisy_real_images).view(-1)
        labels_real = torch.full((output_real.size(0),), real_label, dtype=torch.float, device=device)
        loss_real = criterion(output_real, labels_real)

        output_fake = discriminator(noisy_fake_images).view(-1)
        labels_fake = torch.full((output_fake.size(0),), fake_label, dtype=torch.float, device=device)
        loss_fake = criterion(output_fake, labels_fake)

        loss_D = loss_real + loss_fake
        if use_gradient_penalty:
            loss_D = loss_D + lambda_gp * compute_gradient_penalty(discriminator, high_res, fake_images, device)
        loss_D.backward()
        optimizer_D.step()

        # ---------------- generator update(s) ----------------
        for _ in range(num_generator_updates):
            optimizer_G.zero_grad()

            fake_images = generator(low_res)
            output_fake_for_G = discriminator(fake_images).view(-1)
            # G is rewarded when D scores its output as real.
            labels_fake_for_G = torch.full((output_fake_for_G.size(0),), real_label, dtype=torch.float, device=device)
            loss_G = criterion(output_fake_for_G, labels_fake_for_G)
            if lambda_content > 0:
                loss_G = loss_G + lambda_content * nn.functional.mse_loss(fake_images, high_res)
            loss_G.backward()
            optimizer_G.step()

        # print progress
        if i % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{i}/{len(train_loader)}], '
                  f'Loss_D: {loss_real.item() + loss_fake.item():.4f}, Loss_G: {loss_G.item():.4f}')

    # save both models after every epoch
    torch.save(generator.state_dict(), f'generator_epoch_{epoch+1}.pth')
    torch.save(discriminator.state_dict(), f'discriminator_epoch_{epoch+1}.pth')

    run_inference(generator, image_path)

