In [1]:
import torch
from torch.utils.data import DataLoader
import os
import torchvision
import torch.nn as nn
import torch.optim as optim

# Seed PyTorch's global RNG (all devices) so weight initialisation and the
# DataLoader's shuffle order are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Set the directory path to the folders containing _preview.png and asset files
data_dir = 'preview_data'

In [2]:
# Use the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


In [8]:
from PIL import Image
import torchvision.transforms as transforms

class CompositionDataset(torch.utils.data.Dataset):
    """Pairs each sample folder's ``_preview.png`` with a stack of its asset images.

    ``data_dir`` holds one sub-folder per sample. Inside a sub-folder,
    ``_preview.png`` is the preview image and every other entry is treated as
    an asset image (entries that fail to load are reported and skipped).

    Each item is ``(preview_img, asset_tensor)``:
      * ``preview_img``: float tensor ``[3, image_size, image_size]``,
        normalized with mean=std=0.5, i.e. to [-1, 1].
      * ``asset_tensor``: the assets' RGB tensors concatenated along dim 0 (in
        sorted filename order), zero-padded or truncated to
        ``target_channels`` channels.

    Args:
        data_dir: Root directory containing one sub-folder per sample.
        target_channels: Fixed channel count of ``asset_tensor``
            (default 48, i.e. room for 16 RGB assets).
        image_size: Side length every image is resized to (default 256).
    """

    def __init__(self, data_dir, target_channels=48, image_size=256):
        self.data_dir = data_dir
        self.target_channels = target_channels
        self.image_size = image_size
        # Sorted so the index -> folder mapping is stable across runs and
        # platforms (os.listdir order is arbitrary).
        self.folders = sorted(
            f for f in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, f))
        )

    def __len__(self):
        return len(self.folders)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Raises:
            ValueError: if none of the folder's asset entries could be loaded.
        """
        folder = self.folders[idx]
        folder_path = os.path.join(self.data_dir, folder)
        preview_path = os.path.join(folder_path, '_preview.png')
        # Sorted so each asset lands in the same channel slots on every run.
        asset_paths = [
            os.path.join(folder_path, f)
            for f in sorted(os.listdir(folder_path))
            if f != '_preview.png'
        ]

        preview_img = self.load_preview_image(preview_path)

        # Best-effort: unreadable entries (non-images, sub-folders, corrupt
        # files) are reported and skipped instead of aborting the epoch.
        asset_imgs = []
        for path in asset_paths:
            try:
                asset_imgs.append(self.load_image(path))
            except Exception as e:
                print(f"Error loading asset image {path}: {str(e)}")

        if not asset_imgs:
            # Raise an exception to stop the DataLoader
            raise ValueError(f"No asset images found for folder {folder}.")

        # Concatenate asset_imgs along the channel dimension
        asset_tensor = torch.cat(asset_imgs, dim=0)
        return preview_img, self._fit_channels(asset_tensor)

    def _fit_channels(self, asset_tensor):
        """Zero-pad or truncate ``asset_tensor`` along dim 0 to ``self.target_channels``."""
        missing = self.target_channels - asset_tensor.size(0)
        if missing > 0:
            # Padding takes the assets' own spatial size, dtype and device.
            # Note: 0 in the normalized space corresponds to mid-gray (0.5).
            padding = asset_tensor.new_zeros((missing, *asset_tensor.shape[1:]))
            return torch.cat([asset_tensor, padding], dim=0)
        return asset_tensor[:self.target_channels]

    def load_preview_image(self, path):
        """Load the preview image as a normalized ``[3, image_size, image_size]`` tensor."""
        return self._load_rgb_tensor(path)

    def load_image(self, path):
        """Load one asset image as a normalized ``[3, image_size, image_size]`` tensor."""
        return self._load_rgb_tensor(path)

    def _load_rgb_tensor(self, path):
        """Open ``path``, force 3-channel RGB, resize, and normalize to [-1, 1]."""
        transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        # The context manager closes the file handle. convert('RGB') maps
        # L / LA / P / RGBA inputs to exactly 3 channels; without it, RGBA or LA
        # PNGs reach the 3-channel Normalize with 4 or 2 channels and fail.
        # Alpha is dropped, so fully transparent pixels keep whatever RGB
        # values the file stores.
        with Image.open(path) as img:
            return transform(img.convert('RGB'))

# Wrap the folder-per-sample dataset in a loader that yields one shuffled
# (preview_img, asset_tensor) pair per step.
dataset = CompositionDataset(data_dir)
dataloader = DataLoader(dataset, shuffle=True, batch_size=1)

In [4]:
# Define the CycleGAN model
class CycleGAN(nn.Module):
    """Holds the two generators and two discriminators of a CycleGAN.

    Domains A and B are described only by their channel counts:
    ``generator_A2B`` maps A -> B, ``generator_B2A`` maps B -> A, and
    ``discriminator_A`` / ``discriminator_B`` score images of their domain.
    """

    def __init__(self, in_channels_A, in_channels_B, out_channels_A, out_channels_B, num_residual_blocks):
        super().__init__()
        # Creation order (A2B, B2A, D_A, D_B) determines parameter registration
        # and random-initialisation order, so it is kept as-is.
        self.generator_A2B = Generator(in_channels_A, out_channels_B, num_residual_blocks)
        self.generator_B2A = Generator(in_channels_B, out_channels_A, num_residual_blocks)
        self.discriminator_A = Discriminator(in_channels_A)
        self.discriminator_B = Discriminator(in_channels_B)

    def forward(self, x_A, x_B):
        """Translate ``x_A`` into domain B and ``x_B`` into domain A; returns ``(A2B, B2A)``."""
        return self.generator_A2B(x_A), self.generator_B2A(x_B)

    def discriminator_forward(self, x_A, x_B):
        """Score ``x_A`` with discriminator A and ``x_B`` with discriminator B."""
        return self.discriminator_A(x_A), self.discriminator_B(x_B)

class Generator(nn.Module):
    """ResNet-style generator mapping ``in_channels`` images to ``out_channels`` images.

    Layout: 7x7 stem (64 ch) -> two stride-2 down-sampling blocks (128, 256 ch)
    -> ``num_residual_blocks`` residual blocks at 256 ch -> two nearest-neighbour
    up-sampling blocks (128, 64 ch) -> 7x7 output conv -> tanh.

    Height and width are preserved when both are divisible by 4.

    Args:
        in_channels: Channel count of the input images.
        out_channels: Channel count of the generated images.
        num_residual_blocks: Number of residual blocks in the bottleneck.
    """

    def __init__(self, in_channels, out_channels, num_residual_blocks):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, padding=3),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Down-sampling: each stride-2 conv halves H and W.
        self.down_blocks = nn.Sequential(
            self._block(64, 128, 3, 2, 1),
            self._block(128, 256, 3, 2, 1)
        )

        # Residual blocks
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(256) for _ in range(num_residual_blocks)]
        )

        # Up-sampling: each block doubles H and W (nearest upsample + stride-1 conv).
        self.up_blocks = nn.Sequential(
            self._block(256, 128, 3, 2, 1, upsample=True),
            self._block(128, 64, 3, 2, 1, upsample=True)
        )

        self.last = nn.Conv2d(64, out_channels, kernel_size=7, padding=3)

    def _block(self, in_channels, out_channels, kernel_size, stride, padding, upsample=False):
        """Conv -> InstanceNorm -> ReLU, optionally preceded by 2x nearest upsampling.

        When ``upsample`` is True the conv runs with stride 1 and ``stride`` is ignored.
        """
        layers = []
        if upsample:
            layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=padding))
        else:
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding))
        layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` ([N, in_channels, H, W]) to [N, out_channels, H, W] with values in [-1, 1]."""
        x = self.initial(x)
        x = self.down_blocks(x)
        x = self.res_blocks(x)
        x = self.up_blocks(x)
        # tanh bounds the output to [-1, 1], the range the training images are
        # normalized to in this notebook (Normalize with mean=std=0.5). It is
        # applied functionally so the state_dict keys (last.weight / last.bias)
        # stay unchanged.
        return torch.tanh(self.last(x))


class ResidualBlock(nn.Module):
    """Two 3x3 conv + instance-norm stages with an identity skip connection.

    Computes ``relu(norm(conv2(relu(norm(conv1(x))))) + x)``; the channel count
    and spatial size are preserved.
    """

    def __init__(self, channels):
        super().__init__()
        # padding=1 with a 3x3 kernel keeps H and W unchanged.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        # One norm module is reused after both convs; InstanceNorm2d's default
        # (affine=False) means this sharing involves no learnable parameters.
        self.instance_norm = nn.InstanceNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.instance_norm(self.conv1(x)))
        hidden = self.instance_norm(self.conv2(hidden))
        return self.relu(hidden + x)

class Discriminator(nn.Module):
    """Patch discriminator producing a grid of per-patch probabilities.

    Four stride-2 conv stages (64, 128, 256, 512 channels, LeakyReLU 0.2;
    instance norm on all but the first), then an asymmetric zero pad, a 4x4
    conv to one channel, and a sigmoid.
    """

    def __init__(self, in_channels):
        super().__init__()
        widths = [in_channels, 64, 128, 256, 512]
        layers = []
        # Modules are appended in the same order as before, so Sequential
        # indices (state_dict keys) and init RNG consumption are unchanged.
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            if stage > 0:  # the first stage is left un-normalized
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers += [
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1),
            nn.Sigmoid(),  # outputs are probabilities, paired with BCELoss
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)


In [5]:
# Initialize the model and optimizers.
# Domain A: 3-channel preview images. Domain B: stacked asset images; this
# channel count must match the dataset's fixed asset-channel budget (48).
in_channels_B = 48  # 16 assets * 3 channels each
model = CycleGAN(
    in_channels_A=3,
    in_channels_B=in_channels_B,
    out_channels_A=3,
    out_channels_B=in_channels_B,
    num_residual_blocks=6,
).to(device)

# A single optimizer drives both generators; each discriminator gets its own.
optimizer_G = optim.Adam(
    list(model.generator_A2B.parameters()) + list(model.generator_B2A.parameters()),
    lr=0.001,
)
optimizer_D_A = optim.Adam(model.discriminator_A.parameters(), lr=0.001)
optimizer_D_B = optim.Adam(model.discriminator_B.parameters(), lr=0.001)

# Loss functions: BCE on the discriminators' sigmoid outputs, L1 for image terms.
criterion_GAN = nn.BCELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

In [9]:
# Train the model.
# Each step: (1) translate both domains, (2) update both generators against
# frozen discriminators, (3) update each discriminator on real images vs. the
# detached fakes from step (1).
num_epochs = 100
lambda_cycle = 10   # weight of the A->B->A and B->A->B reconstruction terms
lambda_paired = 5   # weight of the direct paired L1 terms (the data is paired)

for epoch in range(num_epochs):
    for i, (preview_img, asset_tensor) in enumerate(dataloader):
        # Move data to device (GPU or CPU)
        preview_img, asset_tensor = preview_img.to(device), asset_tensor.to(device)

        # Forward pass: out_A2B is the asset stack generated from the preview,
        # out_B2A is the preview generated from the asset stack.
        out_A2B, out_B2A = model(preview_img, asset_tensor)

        # --- Generators ---
        # Freeze the discriminators so the generator backward pass does not
        # accumulate gradients into their parameters.
        model.discriminator_A.requires_grad_(False)
        model.discriminator_B.requires_grad_(False)

        out_D_B_A2B = model.discriminator_B(out_A2B).flatten(1)
        out_D_A_B2A = model.discriminator_A(out_B2A).flatten(1)
        loss_GAN_A2B = criterion_GAN(out_D_B_A2B, torch.ones_like(out_D_B_A2B))
        loss_GAN_B2A = criterion_GAN(out_D_A_B2A, torch.ones_like(out_D_A_B2A))

        # Cycle consistency: translating there and back should reproduce the input.
        loss_cycle_A = criterion_cycle(model.generator_B2A(out_A2B), preview_img)
        loss_cycle_B = criterion_cycle(model.generator_A2B(out_B2A), asset_tensor)

        # Paired supervision: each translation is compared with its ground-truth
        # counterpart from the same sample. The CycleGAN identity loss
        # (G_A2B(B) ~ B) is not applicable here: it needs both domains to share a
        # channel count, but A has 3 channels and B has 48.
        loss_paired_A = criterion_cycle(out_B2A, preview_img)
        loss_paired_B = criterion_cycle(out_A2B, asset_tensor)

        loss_G = (
            loss_GAN_A2B + loss_GAN_B2A
            + lambda_cycle * (loss_cycle_A + loss_cycle_B)
            + lambda_paired * (loss_paired_A + loss_paired_B)
        )

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        model.discriminator_A.requires_grad_(True)
        model.discriminator_B.requires_grad_(True)

        # --- Discriminator A: real previews vs. generated previews ---
        # Fakes are detached so this loss does not backpropagate into the generator.
        out_D_A_real = model.discriminator_A(preview_img).flatten(1)
        out_D_A_fake = model.discriminator_A(out_B2A.detach()).flatten(1)
        loss_D_A = 0.5 * (
            criterion_GAN(out_D_A_real, torch.ones_like(out_D_A_real))
            + criterion_GAN(out_D_A_fake, torch.zeros_like(out_D_A_fake))
        )
        optimizer_D_A.zero_grad()
        loss_D_A.backward()
        optimizer_D_A.step()

        # --- Discriminator B: real asset stacks vs. generated asset stacks ---
        out_D_B_real = model.discriminator_B(asset_tensor).flatten(1)
        out_D_B_fake = model.discriminator_B(out_A2B.detach()).flatten(1)
        loss_D_B = 0.5 * (
            criterion_GAN(out_D_B_real, torch.ones_like(out_D_B_real))
            + criterion_GAN(out_D_B_fake, torch.zeros_like(out_D_B_fake))
        )
        optimizer_D_B.zero_grad()
        loss_D_B.backward()
        optimizer_D_B.step()

        # Print losses every 10 iterations
        if i % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Loss_G: {loss_G.item():.4f}, Loss_D_A: {loss_D_A.item():.4f}, Loss_D_B: {loss_D_B.item():.4f}')

    # Optional: Save the model after each epoch
    # torch.save(model.state_dict(), f'cyclegan_model_epoch_{epoch+1}.pth')

In [None]:
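# # Use the trained model to compose new images (untested sketch)
# def compose_images(assets):
#     # Concatenate the [3, 256, 256] asset tensors along the channel axis (as
#     # CompositionDataset does), zero-pad to in_channels_B channels, and add a
#     # batch dimension -> [1, in_channels_B, 256, 256]
#     asset_tensor = torch.cat(assets, dim=0)
#     padding = torch.zeros(in_channels_B - asset_tensor.size(0), *asset_tensor.shape[1:])
#     asset_tensor = torch.cat([asset_tensor, padding], dim=0).unsqueeze(0).to(device)
#
#     # Assets (domain B) -> preview (domain A) is the B2A generator
#     model.eval()
#     with torch.no_grad():
#         return model.generator_B2A(asset_tensor)

# # Example usage (dataset.load_image applies the same resize/normalization as training):
# assets = [dataset.load_image('asset1.png'), dataset.load_image('asset2.png'), ...]
# composed_img = compose_images(assets)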