In [None]:
# Mount Google Drive (Colab-only API) so image folders under /content/drive are readable.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import os

def load_images(folder_path, target_res):
    """Load every PNG/JPEG in ``folder_path`` as a normalized square RGB tensor.

    Args:
        folder_path: Directory to scan (non-recursive).
        target_res: Side length, in pixels, that every image is resized to.

    Returns:
        A tuple ``(images, resolutions)``. ``images`` is a float tensor of shape
        ``(N, 3, target_res, target_res)`` with values in ``[-1, 1]``.
        ``resolutions`` is a list holding each image's original smaller side,
        in the same order as ``images``.

    Raises:
        ValueError: If the folder contains no .png/.jpg/.jpeg files.
    """
    # Built once rather than once per file.
    transform = transforms.Compose([
        transforms.ToTensor(),
        # Maps [0, 1] to [-1, 1], the same range as the generator's tanh output.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    images = []
    resolutions = []
    # sorted() makes the order deterministic; os.listdir order is arbitrary.
    for filename in sorted(os.listdir(folder_path)):
        # Case-insensitive so files such as "photo.JPG" are not silently skipped.
        if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            continue
        img_path = os.path.join(folder_path, filename)
        with Image.open(img_path) as raw:
            # Grayscale / RGBA / palette images would otherwise produce 1 or 4
            # channels and break the 3-channel Normalize above.
            img = raw.convert('RGB')
        resolutions.append(min(img.size))  # Use the smaller dimension as resolution

        # Resize the image to target_res x target_res
        img = img.resize((target_res, target_res))
        images.append(transform(img))

    if not images:
        raise ValueError(f"No .png/.jpg/.jpeg images found in {folder_path!r}")
    return torch.stack(images), resolutions

# Generator
class Generator(nn.Module):
    """Fully-convolutional image-to-image generator conditioned on resolution.

    The 9-1-5 kernel / 64-32 filter layout resembles SRCNN. Two constant
    channels that encode the current and target resolution are appended to the
    input. All convolutions preserve spatial size, so the output has the same
    height/width as the input; no upsampling happens inside the network.
    """

    # Divisor applied to resolution values before they are fed to the network.
    # NOTE(review): 255 looks borrowed from pixel scaling (700 -> ~2.75); kept
    # unchanged so existing checkpoints see identical conditioning values.
    RES_SCALE = 255.0

    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels + 2, 64, kernel_size=9, padding=4)  # +2 for resolution info
        self.conv2 = nn.Conv2d(64, 32, kernel_size=1)
        self.conv3 = nn.Conv2d(32, output_channels, kernel_size=5, padding=2)

    def _resolution_channel(self, res, x):
        """Broadcast a scalar or per-sample resolution to a ``(B, 1, H, W)`` map."""
        batch_size, _, height, width = x.shape
        res = torch.as_tensor(res, dtype=x.dtype, device=x.device)
        # A scalar becomes (1, 1, 1, 1) and a (B,) vector becomes (B, 1, 1, 1);
        # expand then broadcasts either one over the spatial map.
        return res.reshape(-1, 1, 1, 1).expand(batch_size, 1, height, width) / self.RES_SCALE

    def forward(self, x, current_res, target_res):
        """Map ``x`` to an output image with the same spatial size.

        Args:
            x: Input batch of shape ``(B, input_channels, H, W)``.
            current_res: Original resolution. Accepts a Python number, a 0-d
                tensor (shared by the whole batch) or a ``(B,)`` tensor with one
                value per sample.
            target_res: Target resolution; same accepted forms as ``current_res``.

        Returns:
            Tensor of shape ``(B, output_channels, H, W)`` with values in ``[-1, 1]``.
        """
        # Concatenate input with resolution information
        x = torch.cat([
            x,
            self._resolution_channel(current_res, x),
            self._resolution_channel(target_res, x),
        ], dim=1)

        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.conv3(x)
        return torch.tanh(x)



In [None]:
class Discriminator(nn.Module):
    """Real/fake classifier for square ``target_res`` x ``target_res`` images.

    It has four 3x3 convolutions, two of them with stride 2, followed by one
    linear layer and a sigmoid. The output is a ``(B, 1)`` probability tensor.

    Args:
        input_channels: Number of channels in the input images.
        target_res: Height/width of the square inputs. If omitted, the value is
            read from a global ``target_res`` (the original notebook behavior).
            Pass it explicitly to avoid that hidden dependency.

    Raises:
        ValueError: If ``target_res`` is neither passed nor defined globally.
    """

    def __init__(self, input_channels, target_res=None):
        super().__init__()
        if target_res is None:
            # Backward compatibility: the original read the notebook-global value.
            target_res = globals().get("target_res")
            if target_res is None:
                raise ValueError(
                    "Discriminator needs target_res: pass it explicitly or define a global target_res"
                )
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1)
        # A kernel-3, padding-1, stride-2 conv maps size s to ceil(s / 2), so the
        # two strided convs yield ceil(target_res / 4). The previous
        # target_res // 4 crashed for sizes that are not multiples of 4.
        feature_size = (target_res + 3) // 4
        self.fc = nn.Linear(128 * feature_size * feature_size, 1)

    def forward(self, x):
        """Return real-probabilities of shape ``(B, 1)`` for a ``(B, C, H, W)`` batch."""
        x = F.leaky_relu(self.conv1(x), 0.2)
        x = F.leaky_relu(self.conv2(x), 0.2)
        x = F.leaky_relu(self.conv3(x), 0.2)
        x = F.leaky_relu(self.conv4(x), 0.2)
        x = x.view(x.size(0), -1)
        return torch.sigmoid(self.fc(x))



In [None]:
# Function to calculate scale factor
def calculate_scale_factor(current_res, target_res):
    """Return how many whole times ``current_res`` fits into ``target_res``.

    This is floor division: a 300 -> 700 upscale reports 2 and any downscale
    reports 0. For plain ints, ``current_res == 0`` raises ZeroDivisionError.
    """
    whole_steps = target_res // current_res
    return whole_steps


# Training function
def _pair_with_resolutions(dataloader, resolutions):
    """Return a loader whose batches are ``(images, per-image resolutions)``."""
    dataset = dataloader.dataset
    if not isinstance(dataset, torch.Tensor):
        # Assume the caller already built a loader that yields (images, resolutions).
        return dataloader
    if resolutions is None:
        raise ValueError("resolutions is required when the dataloader yields bare images")
    resolutions = torch.as_tensor(resolutions)
    if len(resolutions) != len(dataset):
        raise ValueError(
            f"Got {len(resolutions)} resolutions for {len(dataset)} images; they must align one-to-one"
        )
    # Rebuild with the same batching/shuffling so every image keeps its own
    # resolution. Zipping a shuffled loader with a per-image list mismatches them.
    return torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(dataset, resolutions),
        batch_size=dataloader.batch_size,
        shuffle=isinstance(dataloader.sampler, torch.utils.data.RandomSampler),
        drop_last=dataloader.drop_last,
    )


def _generate(generator, images, batch_res, target_res):
    """Run ``generator`` one sample at a time, each with its own resolution."""
    # Per-sample calls work with generators whose current_res must be a scalar.
    # They match a single batched call for generators with no cross-sample
    # layers (e.g. no BatchNorm).
    return torch.cat([
        generator(image.unsqueeze(0), res, target_res)
        for image, res in zip(images, batch_res)
    ])


def train_gan(generator, discriminator, dataloader, resolutions, target_res, num_epochs, device):
    """Adversarially train ``generator`` against ``discriminator`` with BCE loss.

    Args:
        generator: Module called as ``generator(images, current_res, target_res)``.
        discriminator: Module mapping an image batch to ``(B, 1)`` probabilities.
        dataloader: Either a ``DataLoader`` over a plain image tensor, in which
            case ``resolutions`` gives one value per image in dataset order, or
            a loader that already yields ``(images, resolutions)`` batches.
        resolutions: Per-image resolutions aligned with ``dataloader.dataset``.
            Ignored when the loader yields pairs.
        target_res: Target resolution forwarded to the generator.
        num_epochs: Number of passes over the data.
        device: Device that batches and labels are placed on.

    Raises:
        ValueError: If ``resolutions`` does not align with the dataset, or the
            loader yields no batches.
    """
    criterion = nn.BCELoss()
    optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)
    loader = _pair_with_resolutions(dataloader, resolutions)

    for epoch in range(num_epochs):
        num_batches = 0
        for real_images, batch_res in loader:
            num_batches += 1
            batch_size = real_images.size(0)
            real_images = real_images.to(device)
            # One resolution per sample; a single value is broadcast to the batch.
            batch_res = torch.as_tensor(batch_res, device=device).reshape(-1).expand(batch_size)

            # Train Discriminator
            optimizer_D.zero_grad()
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            outputs = discriminator(real_images)
            d_loss_real = criterion(outputs, real_labels)

            fake_images = _generate(generator, real_images, batch_res, target_res)
            outputs = discriminator(fake_images.detach())
            d_loss_fake = criterion(outputs, fake_labels)

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizer_D.step()

            # Train Generator (reuses fake_images, now with gradients flowing to G)
            optimizer_G.zero_grad()
            outputs = discriminator(fake_images)
            g_loss = criterion(outputs, real_labels)
            g_loss.backward()
            optimizer_G.step()

        if num_batches == 0:
            raise ValueError("dataloader yielded no batches; nothing to train on")
        # Reported losses are from the last batch of the epoch, not an average.
        print(f"Epoch [{epoch+1}/{num_epochs}], D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}")

In [None]:
# Main execution
if __name__ == "__main__":
    # Fixed seed so weight init and DataLoader shuffling repeat across runs
    # (torch.manual_seed also seeds CUDA devices).
    SEED = 42
    torch.manual_seed(SEED)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Get user input; validate it here instead of failing deep inside training.
    folder_path = input("Enter the folder path containing images: ").strip()
    if not os.path.isdir(folder_path):
        raise NotADirectoryError(f"Not a directory: {folder_path!r}")
    target_res = int(input("Enter the target resolution: "))
    if target_res <= 0:
        raise ValueError(f"Target resolution must be a positive integer, got {target_res}")

    # Load images and get resolutions
    images, resolutions = load_images(folder_path, target_res)
    dataloader = torch.utils.data.DataLoader(images, batch_size=4, shuffle=True)
    # One original resolution per image, aligned with the rows of `images`.
    resolutions = torch.tensor(resolutions)

    # Initialize models
    generator = Generator(3, 3).to(device)
    # NOTE: Discriminator sizes its linear layer from the global `target_res` set above.
    discriminator = Discriminator(3).to(device)

    # Train the GAN
    num_epochs = 50
    train_gan(generator, discriminator, dataloader, resolutions, target_res, num_epochs, device)

    # Save the trained generator
    torch.save(generator.state_dict(), 'generator.pth')

    print("Training complete. Generator saved as 'generator.pth'")

Enter the folder path containing images: /content/drive/MyDrive/archive/samples/0001e96803--621ea6adb419e86592d408c5
Enter the target resolution: 700
Epoch [1/50], D_loss: 1.3866, G_loss: 0.3985
Epoch [2/50], D_loss: 1.7406, G_loss: 4.8033
Epoch [3/50], D_loss: 0.0388, G_loss: 9.0634
Epoch [4/50], D_loss: 0.2367, G_loss: 9.2527
Epoch [5/50], D_loss: 0.0026, G_loss: 8.4763
Epoch [6/50], D_loss: 0.0006, G_loss: 7.2654
Epoch [7/50], D_loss: 0.0021, G_loss: 5.8649
Epoch [8/50], D_loss: 0.0100, G_loss: 5.2946
Epoch [9/50], D_loss: 0.0258, G_loss: 6.5123
Epoch [10/50], D_loss: 0.0145, G_loss: 8.4254
Epoch [11/50], D_loss: 0.0035, G_loss: 10.1285
Epoch [12/50], D_loss: 0.0013, G_loss: 10.8733
Epoch [13/50], D_loss: 0.0041, G_loss: 10.7036
Epoch [14/50], D_loss: 0.0154, G_loss: 13.4407
Epoch [15/50], D_loss: 0.0007, G_loss: 16.8453
Epoch [16/50], D_loss: 0.0000, G_loss: 20.5662
Epoch [17/50], D_loss: 0.0000, G_loss: 25.0547
Epoch [18/50], D_loss: 0.0000, G_loss: 30.5726
Epoch [19/50], D_loss: 