In [1]:
!pip install basicsr

Collecting basicsr
  Downloading basicsr-1.4.2.tar.gz (172 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m172.5/172.5 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting addict (from basicsr)
  Downloading addict-2.4.0-py3-none-any.whl (3.8 kB)
Collecting lmdb (from basicsr)
  Downloading lmdb-1.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (299 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m299.2/299.2 kB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m
Collecting tb-nightly (from basicsr)
  Downloading tb_nightly-2.16.0a20231026-py3-none-any.whl (5.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m30.9 MB/s[0m eta [36m0:00:00[0m
Collecting yapf (from basicsr)
  Downloading yapf-0.40.2-py3-none-any.whl (254 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m254.7/254.7 kB[0m [31m27.7 MB/s[0m eta [36m0:00:0

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import transforms
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.losses.basic_loss import L1Loss
from drive.MyDrive.ImgPro.RealESRGAN.models.arcs.discriminator_arch import UNetDiscriminatorSN
from PIL import Image
from drive.MyDrive.ImgPro.RealESRGAN.datasets.dataset import LoadDataset



In [4]:
# Check if CUDA (GPU support) is available
cuda_available = torch.cuda.is_available()
# cuda_available = False
if cuda_available:
    print("CUDA (GPU support) is available.")
else:
    print("CUDA (GPU support) is not available. Running on CPU.")

CUDA (GPU support) is available.


In [5]:
# Load generator
generator_path = '/content/drive/MyDrive/ImgPro/RealESRGAN/models/pretrained/RealESRGAN_x4plus.pth'
generator = RRDBNet(3, 3)
generator.load_state_dict(torch.load(generator_path)['params_ema'])

# Move the model to GPU if available
if cuda_available:
  generator = generator.cuda()

generator.eval()

RRDBNet(
  (conv_first): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (body): Sequential(
    (0): RRDB(
      (rdb1): ResidualDenseBlock(
        (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv3): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv4): Conv2d(160, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv5): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (lrelu): LeakyReLU(negative_slope=0.2, inplace=True)
      )
      (rdb2): ResidualDenseBlock(
        (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv2): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv3): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (conv4): Conv2d(160, 32, kernel_size=(3, 3), stride=(1, 1), 

In [6]:
# Load discriminator
discriminator_path = '/content/drive/MyDrive/ImgPro/RealESRGAN/models/pretrained/RealESRGAN_x4plus_netD.pth'
discriminator = UNetDiscriminatorSN(3)
discriminator.load_state_dict(torch.load(discriminator_path)['params'])

# Move the model to GPU if available
if cuda_available:
  discriminator = discriminator.cuda()

discriminator.eval()

UNetDiscriminatorSN(
  (conv0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv2): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv3): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (conv4): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv5): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv6): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv9): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [7]:
# Define tranforms function
transform = transforms.Compose([
    transforms.ToTensor()
])

In [8]:
num_epoch = 5
batch_size = 2

In [9]:
# Create data loader for training
train_dataset = LoadDataset(low_quality_folder='/content/drive/MyDrive/ImgPro/RealESRGAN/datasets/lq_train',
                            high_quality_folder='/content/drive/MyDrive/ImgPro/RealESRGAN/datasets/hq_train',
                            transform=transform)
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Create data loader for testing
test_dataset = LoadDataset(low_quality_folder='/content/drive/MyDrive/ImgPro/RealESRGAN/datasets/lq_test',
                           high_quality_folder='/content/drive/MyDrive/ImgPro/RealESRGAN/datasets/hq_test',
                           transform=transform)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Create data loader for validation
val_dataset = LoadDataset(low_quality_folder='/content/drive/MyDrive/ImgPro/RealESRGAN/datasets/lq_val',
                          high_quality_folder='/content/drive/MyDrive/ImgPro/RealESRGAN/datasets/hq_val',
                          transform=transform)
val_data_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

In [10]:
# Define training optimizer
optimizer = optim.Adam(generator.parameters(), lr=0.0001, betas=[0.9, 0.99])
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0001, betas=[0.9, 0.99])

# Loss function
criterion = L1Loss(loss_weight=1.0)
criterion_d = nn.BCELoss()

# Move the criterion to GPU if available
if cuda_available:
  criterion = criterion.cuda()
  criterion_d = criterion_d.cuda()

In [11]:
# Define the number of accumulation steps
accumulation_steps = 8
# Train and validate model
for epoch in range(num_epoch):
    # Training
    # Set to train mode
    generator.train()
    discriminator.train()
    train_loss = 0
    d_train_loss = 0
    total_loss = 0
    for i, (lq_images, hq_images) in enumerate(train_data_loader):
        # Move data to GPU if available
        if cuda_available:
            lq_images = lq_images.cuda()
            hq_images = hq_images.cuda()

        # Train generator
        lq_outputs = generator(lq_images)  # Forward pass

        # Calculate generator loss
        loss = criterion(lq_outputs, hq_images) # Calculate the loss
        train_loss += loss.item() * lq_images.size(0)
        total_loss += loss



        # Train discriminator
        real_predictions = discriminator(hq_images) # Forward pass
        real_predictions = torch.sigmoid(real_predictions) # Cap to 0-1
        fake_predictions = discriminator(lq_outputs) # Forward pass
        fake_predictions = torch.sigmoid(fake_predictions) # Cap to 0-1
        real_targets = torch.ones_like(real_predictions) # Create a tensor of ones with the same shape as real_predictions
        fake_targets = torch.zeros_like(fake_predictions) # Create a tensor of zeros with the same shape as fake_predictions
        if cuda_available:
            real_targets = real_targets.cuda()
            fake_targets = fake_targets.cuda()

        # Calculate discriminator loss
        d_loss_real = criterion_d(real_predictions, real_targets)
        d_loss_fake = criterion_d(fake_predictions, fake_targets)
        d_loss = (d_loss_real + d_loss_fake) / 2
        total_loss += d_loss

        d_train_loss += d_loss.item() * lq_images.size(0)

        # every accumulation_steps batches
        if (i + 1) % accumulation_steps == 0:
          # Calculate the average loss
          averaged_loss = total_loss / accumulation_steps

          # Zero the gradients
          optimizer.zero_grad()
          optimizer_d.zero_grad()

          # Backward pass
          averaged_loss.backward()

          # Update weights
          optimizer_d.step()
          optimizer.step()

          # Reset loss
          total_loss = 0

    average_train_loss = train_loss / len(train_data_loader.dataset)
    average_d_train_loss = d_train_loss / len(train_data_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epoch}, Average Training Loss: {average_train_loss:.4f}')
    print(f'Epoch {epoch+1}/{num_epoch}, Average Discriminator Training Loss: {average_d_train_loss:.4f}')

    # Validation
    generator.eval()  # Set the model to evaluation mode
    val_loss = 0.0
    with torch.no_grad():
        for lq_images, hq_images in val_data_loader:  # Assuming val_data_loader is your validation DataLoader instance
          # Move data to GPU if available
          if cuda_available:
            lq_images = lq_images.cuda()
            hq_images = hq_images.cuda()

          val_outputs = generator(lq_images)  # Forward pass for validation data
          val_loss += criterion(val_outputs, hq_images).item() * lq_images.size(0)

    average_val_loss = val_loss / len(val_data_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epoch}, Average Validation Loss: {average_val_loss:.4f}')

Epoch 1/5, Average Training Loss: 0.2641
Epoch 1/5, Average Discriminator Training Loss: 0.0247
Epoch 1/5, Average Validation Loss: 0.2677
Epoch 2/5, Average Training Loss: 0.2440
Epoch 2/5, Average Discriminator Training Loss: 0.0001
Epoch 2/5, Average Validation Loss: 0.2708
Epoch 3/5, Average Training Loss: 0.2411
Epoch 3/5, Average Discriminator Training Loss: 0.0001
Epoch 3/5, Average Validation Loss: 0.2373
Epoch 4/5, Average Training Loss: 0.2401
Epoch 4/5, Average Discriminator Training Loss: 0.0001
Epoch 4/5, Average Validation Loss: 0.2409
Epoch 5/5, Average Training Loss: 0.2380
Epoch 5/5, Average Discriminator Training Loss: 0.0001
Epoch 5/5, Average Validation Loss: 0.2464


In [12]:
# Save model
torch.save(generator, '/content/drive/MyDrive/ImgPro/RealESRGAN/models/trained/generator.pth')
torch.save(discriminator, '/content/drive/MyDrive/ImgPro/RealESRGAN/models/trained/discriminator.pth')