In [1]:
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader


class MinecraftImageDataset(Dataset):
    """
    Dataset of RGB images read from a single (non-recursive) directory.

    Files are selected by extension (case-insensitive) and must pass
    ``PIL.Image.verify``; the surviving paths are kept in sorted order so that
    an index always refers to the same file across runs. Each item is returned
    as a float tensor of shape (3, img_size, img_size) scaled to [-1, 1].
    """

    # Lower-case extensions accepted as images; the file name is lower-cased
    # before the comparison so "skin.PNG" is accepted as well.
    VALID_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, folder_path, img_size=64):
        """
        Custom Dataset to load valid images from a directory.
        Args:
            folder_path (str): Path to the folder containing images.
            img_size (int): Target size to resize images (img_size x img_size).
        """
        self.folder_path = folder_path
        self.image_paths = self._filter_valid_images(folder_path)
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),  # Resize images
            transforms.ToTensor(),                    # Convert to tensor in [0, 1]
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # (x - 0.5) / 0.5 -> [-1, 1]
        ])

        print(f"Total valid images: {len(self.image_paths)}")  # Print the number of valid images

    def _filter_valid_images(self, folder_path):
        """
        Filters out invalid or corrupted images from the directory.
        Args:
            folder_path (str): Path to the folder containing images.

        Returns:
            List[str]: Sorted list of valid image paths.
        """
        valid_images = []
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible.
        for img in sorted(os.listdir(folder_path)):
            if not img.lower().endswith(self.VALID_EXTENSIONS):
                continue
            img_path = os.path.join(folder_path, img)
            try:
                # Attempt to open and verify the image
                with Image.open(img_path) as image:
                    image.verify()  # Check if the image is valid
            except Exception as e:
                # Best-effort scan: report the bad file and keep going.
                print(f"Invalid image: {img_path} - {e}")  # Log invalid image details
            else:
                valid_images.append(img_path)
        return valid_images

    def __len__(self):
        """Number of images that passed validation."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Load, convert and transform one image.
        Args:
            idx (int): Position in ``self.image_paths``.

        Returns:
            The result of ``self.transform`` applied to the RGB image.
        """
        img_path = self.image_paths[idx]
        # convert() returns a new, fully loaded image, so the file handle can
        # be released as soon as the `with` block ends.
        with Image.open(img_path) as handle:
            image = handle.convert('RGB')  # Ensure 3-channel RGB
        return self.transform(image)  # Apply transformations


def load_minecraft_data(folder_path, img_size=64, batch_size=32, shuffle=True):
    """
    Build a DataLoader over the validated images found in a folder.
    Args:
        folder_path (str): Directory that holds the image files.
        img_size (int): Side length every image is resized to.
        batch_size (int): How many images each batch contains.
        shuffle (bool): If True the loader draws samples in random order.

    Returns:
        DataLoader: Loader wrapping a MinecraftImageDataset for ``folder_path``.
    """
    loader_options = {"batch_size": batch_size, "shuffle": shuffle}
    images = MinecraftImageDataset(folder_path, img_size=img_size)
    return DataLoader(images, **loader_options)


In [None]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch.nn as nn
import torch
import torch.optim as optim
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch.nn.functional as F
import numpy as np
from skimage.metrics import structural_similarity as ssim_sklearn


# Generator Model
class SelfAttention(nn.Module):
    """
    Self-attention over the spatial positions of a feature map.

    Query and key are 1x1 convolutions to a reduced width (in_channels // 8,
    but never less than one channel), the value projection keeps in_channels.
    The attended features are scaled by a learnable scalar ``gamma`` and added
    to the input; ``gamma`` starts at zero, so a freshly built layer returns
    its input unchanged.

    Args:
        in_channels (int): Number of channels of the incoming feature map.
    """

    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        # in_channels // 8 is 0 for fewer than 8 channels, which would create
        # empty projections and make the reshape in forward() fail.
        reduced_channels = max(in_channels // 8, 1)
        self.query = nn.Conv2d(in_channels, reduced_channels, 1)
        self.key = nn.Conv2d(in_channels, reduced_channels, 1)
        self.value = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Feature map of shape (batch, channels, height, width).

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``.
        """
        batch_size, channels, height, width = x.size()
        positions = height * width

        # (B, HW, C') @ (B, C', HW) -> (B, HW, HW): similarity of every
        # position with every other position.
        proj_query = self.query(x).view(batch_size, -1, positions).permute(0, 2, 1)
        proj_key = self.key(x).view(batch_size, -1, positions)
        energy = torch.bmm(proj_query, proj_key)
        attention = F.softmax(energy, dim=-1)  # each row sums to 1

        # (B, C, HW) @ (B, HW, HW)^T -> (B, C, HW): weighted sum of values.
        proj_value = self.value(x).view(batch_size, -1, positions)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, height, width)

        return self.gamma * out + x

class AdaptiveResidualBlock(nn.Module):
    """
    Residual block: two 3x3 convolutions with batch norm, a
    squeeze-and-excitation gate, and a skip connection.

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Channels produced by the block.
        use_spectral (bool): Wrap both 3x3 convolutions in spectral normalization.
    """

    def __init__(self, in_channels, out_channels, use_spectral=True):
        super().__init__()
        self.activation = nn.LeakyReLU(0.2, inplace=True)

        # One batch-norm layer after each 3x3 convolution.
        self.norm1 = nn.BatchNorm2d(out_channels)
        self.norm2 = nn.BatchNorm2d(out_channels)

        def conv3x3(channels_in):
            # Bias-free 3x3 convolution, optionally spectrally normalized.
            layer = nn.Conv2d(channels_in, out_channels, 3, stride=1, padding=1, bias=False)
            return nn.utils.spectral_norm(layer) if use_spectral else layer

        self.conv1 = conv3x3(in_channels)
        self.conv2 = conv3x3(out_channels)

        # Squeeze-and-excitation: global pooling, 1/16-width bottleneck, and a
        # sigmoid that yields one multiplier per channel.
        bottleneck = out_channels // 16
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_channels, bottleneck, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck, out_channels, 1),
            nn.Sigmoid()
        )

        # The skip path needs a 1x1 projection only when the widths differ.
        if in_channels == out_channels:
            self.shortcut = None
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Run the block on ``x`` and return a map with ``out_channels`` channels."""
        out = self.activation(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))

        # Re-weight each channel with its squeeze-and-excitation gate.
        out = out * self.se(out)

        skip = x if self.shortcut is None else self.shortcut(x)
        out += skip
        return self.activation(out)

class Generator(nn.Module):
    """
    Maps latent vectors to images with values in [-1, 1] (Tanh output).

    A linear layer expands the latent vector to a (features_g * 16) x 4 x 4
    feature map; four residual stages each halve the channel count and double
    the resolution (4 -> 8 -> 16 -> 32 -> 64), with self-attention after the
    first two stages. A final 3x3 convolution produces ``img_channels`` maps.

    Args:
        latent_dim (int): Size of the input noise vector.
        img_channels (int): Number of channels of the generated image.
        features_g (int): Base channel width; the first stage uses 16x this.
    """

    def __init__(self, latent_dim, img_channels, features_g=64):
        super(Generator, self).__init__()
        self.init_size = 4
        self.latent_dim = latent_dim
        # Kept so forward() can reshape with the configured width instead of
        # assuming the default of 64.
        self.features_g = features_g

        # Dense layer producing features_g * 16 maps of init_size x init_size.
        self.initial = nn.Sequential(
            nn.Linear(latent_dim, features_g * 16 * self.init_size * self.init_size),
            nn.LeakyReLU(0.2, inplace=True)
        )

        # Main generation blocks
        self.main = nn.ModuleList([
            # 4x4 -> 8x8
            nn.Sequential(
                AdaptiveResidualBlock(features_g * 16, features_g * 8),
                nn.Upsample(scale_factor=2),
                SelfAttention(features_g * 8)
            ),
            # 8x8 -> 16x16
            nn.Sequential(
                AdaptiveResidualBlock(features_g * 8, features_g * 4),
                nn.Upsample(scale_factor=2),
                SelfAttention(features_g * 4)
            ),
            # 16x16 -> 32x32
            nn.Sequential(
                AdaptiveResidualBlock(features_g * 4, features_g * 2),
                nn.Upsample(scale_factor=2)
            ),
            # 32x32 -> 64x64
            nn.Sequential(
                AdaptiveResidualBlock(features_g * 2, features_g),
                nn.Upsample(scale_factor=2)
            )
        ])

        # Final layers
        self.final = nn.Sequential(
            nn.Conv2d(features_g, img_channels, 3, stride=1, padding=1),
            nn.Tanh()
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-normal init for conv/linear weights, zeros for their biases."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, z):
        """
        Args:
            z (torch.Tensor): Latent batch of shape (batch, latent_dim).

        Returns:
            torch.Tensor: Images of shape (batch, img_channels, 64, 64).
        """
        # Initial dense layer
        out = self.initial(z)
        # Reshape to the channel count the dense layer was built for; a
        # hard-coded 64 * 16 here broke every features_g other than 64.
        out = out.view(-1, self.features_g * 16, self.init_size, self.init_size)

        # Main generation blocks
        for block in self.main:
            out = block(out)

        # Final convolution
        img = self.final(out)
        return img


# Critic Model (Discriminator)
class Critic(nn.Module):
    """
    WGAN critic: scores an image batch with one unbounded scalar per image.

    Four spectrally-normalized 4x4 stride-2 convolutions halve the resolution
    each time (img_size -> img_size / 16) while widening the channels to
    features_d * 8; self-attention follows the first three of them. The
    result is flattened and fed to a single linear unit.

    Args:
        img_channels (int): Channels of the input images.
        features_d (int): Base channel width.
        img_size (int): Side length of the (square) input images. Must be a
            multiple of 16 and at least 32 so that the last instance-norm
            layer still sees more than one spatial element.

    Raises:
        ValueError: If ``img_size`` is not a multiple of 16 or is below 32.
    """

    def __init__(self, img_channels, features_d=64, img_size=64):
        super(Critic, self).__init__()

        if img_size < 32 or img_size % 16 != 0:
            raise ValueError(
                f"img_size must be a multiple of 16 and >= 32, got {img_size}"
            )

        def critic_block(in_channels, out_channels, normalize=True):
            # Downsampling unit: spectral-norm conv, optional instance norm,
            # LeakyReLU. Returned as a list so it can be splatted into Sequential.
            layers = [nn.utils.spectral_norm(
                nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1)
            )]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_channels, affine=True))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.features_d = features_d

        self.initial = nn.Sequential(
            *critic_block(img_channels, features_d, normalize=False),
            SelfAttention(features_d)
        )

        self.main = nn.Sequential(
            *critic_block(features_d, features_d * 2),
            SelfAttention(features_d * 2),
            *critic_block(features_d * 2, features_d * 4),
            SelfAttention(features_d * 4),
            *critic_block(features_d * 4, features_d * 8),
        )

        # Four stride-2 layers shrink each spatial side by a factor of 16
        # (64 -> 4 for the default input size).
        final_spatial_size = img_size // 16
        flattened_size = features_d * 8 * final_spatial_size * final_spatial_size

        self.final = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flattened_size, 1)
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-normal init for conv/linear weights, zeros for their biases."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, img):
        """
        Args:
            img (torch.Tensor): Batch of shape (batch, img_channels, img_size, img_size).

        Returns:
            torch.Tensor: Critic scores of shape (batch, 1).
        """
        features = self.initial(img)
        features = self.main(features)
        return self.final(features)

# Gradient Penalty Function
def compute_gradient_penalty(critic, real_imgs, fake_imgs, device):
    """
    Compute the gradient penalty for WGAN-GP.

    The critic is evaluated on random per-sample interpolations between real
    and fake images, and the squared deviation of its input-gradient L2 norm
    from 1 is averaged over the batch.

    Args:
        critic (callable): Maps an image batch to per-sample scores.
        real_imgs (torch.Tensor): Real batch, first dimension is the batch.
        fake_imgs (torch.Tensor): Generated batch with the same shape.
        device (torch.device): Device on which the interpolation weights are created.

    Returns:
        torch.Tensor: Scalar penalty, differentiable w.r.t. the critic's parameters.
    """
    # The penalty regularizes the critic only; detaching keeps backward() from
    # walking (and accumulating gradients in) the generator's graph.
    real = real_imgs.detach()
    fake = fake_imgs.detach()

    # One mixing weight per sample, broadcast over all remaining dimensions.
    alpha_shape = (real.size(0),) + (1,) * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=device, dtype=real.dtype)
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    d_interpolates = critic(interpolates)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        # Match whatever shape the critic returns instead of assuming (batch, 1).
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,  # the penalty itself must be differentiable
        retain_graph=True,
        only_inputs=True,
    )[0]

    # reshape (not view): the gradient tensor is not guaranteed to be contiguous.
    gradients = gradients.reshape(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty

# WGAN Training
def train_wgan(generator, critic, dataloader, latent_dim, epochs, sample_interval=1, lambda_gp=10,
               n_critic=5):
    """
    Train a generator/critic pair with the WGAN-GP objective.

    The critic is updated on every batch; the generator is updated on every
    ``n_critic``-th batch (batch indices 0, n_critic, 2 * n_critic, ...).
    Per-epoch mean losses and SSIM are printed, sample images are shown every
    ``sample_interval`` epochs, and the loss/SSIM curves are plotted at the end.

    Args:
        generator (nn.Module): Maps (batch, latent_dim) noise to images.
        critic (nn.Module): Maps images to per-sample scores.
        dataloader (iterable): Yields batches of real images.
        latent_dim (int): Size of the noise vectors fed to the generator.
        epochs (int): Number of passes over ``dataloader``.
        sample_interval (int): Show generated samples every this many epochs.
        lambda_gp (float): Weight of the gradient penalty in the critic loss.
        n_critic (int): Number of critic updates per generator update.

    Raises:
        ValueError: If ``n_critic`` is smaller than 1.
    """
    if n_critic < 1:
        raise ValueError(f"n_critic must be >= 1, got {n_critic}")

    # Run on whatever device the generator lives on rather than depending on a
    # module-level `device` that may not be defined yet.
    device = next(generator.parameters()).device

    optimizer_G = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.0, 0.999))
    optimizer_C = optim.Adam(critic.parameters(), lr=4e-4, betas=(0.0, 0.999))

    g_losses, c_losses, ssim_scores = [], [], []

    for epoch in range(epochs):
        g_loss_epoch, c_loss_epoch = 0.0, 0.0
        g_steps, c_steps = 0, 0
        epoch_ssim_scores = []

        for i, imgs in enumerate(tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}")):
            real_imgs = imgs.to(device).float()
            batch_size = real_imgs.size(0)

            # Train Critic
            optimizer_C.zero_grad()
            z = torch.randn(batch_size, latent_dim, device=device)
            # The critic step never needs generator gradients, so skip building
            # (and later back-propagating through) the generator graph.
            with torch.no_grad():
                fake_imgs = generator(z)
            real_loss = torch.mean(critic(real_imgs))
            fake_loss = torch.mean(critic(fake_imgs))
            gradient_penalty = compute_gradient_penalty(critic, real_imgs, fake_imgs, device)
            c_loss = fake_loss - real_loss + lambda_gp * gradient_penalty
            c_loss.backward()
            optimizer_C.step()
            c_loss_epoch += c_loss.item()
            c_steps += 1

            # Train Generator every n_critic steps
            if i % n_critic == 0:
                optimizer_G.zero_grad()
                z = torch.randn(batch_size, latent_dim, device=device)
                gen_imgs = generator(z)
                g_loss = -torch.mean(critic(gen_imgs))
                g_loss.backward()
                optimizer_G.step()
                g_loss_epoch += g_loss.item()
                g_steps += 1

                # Calculate SSIM periodically
                ssim_score = calculate_ssim(real_imgs, gen_imgs)
                epoch_ssim_scores.append(ssim_score)

        # Record losses and SSIM. Each mean is taken over the number of updates
        # that actually contributed: dividing the generator sum by the number
        # of batches would understate it by roughly a factor of n_critic.
        g_losses.append(g_loss_epoch / max(g_steps, 1))
        c_losses.append(c_loss_epoch / max(c_steps, 1))
        ssim_scores.append(np.mean(epoch_ssim_scores) if epoch_ssim_scores else float("nan"))

        # Save sample images
        if (epoch + 1) % sample_interval == 0:
            display_sample_images(generator, latent_dim)

        print(f"[Epoch {epoch + 1}] Generator Loss: {g_losses[-1]:.4f}, Critic Loss: {c_losses[-1]:.4f}, SSIM: {ssim_scores[-1]:.4f}")

    # Plot losses and SSIM
    plot_losses_and_ssim(g_losses, c_losses, ssim_scores)

# Display and Plot Functions
def display_sample_images(generator, latent_dim):
    """
    Generate 16 images and show them as a 4x4 grid.

    The generator is switched to eval mode while sampling and is put back into
    whichever mode (train or eval) it was in when the function was called.

    Args:
        generator (nn.Module): Maps (batch, latent_dim) noise to images in [-1, 1].
        latent_dim (int): Size of the noise vectors fed to the generator.
    """
    # Sample on the generator's own device; fall back to CPU for a
    # parameter-free module.
    first_param = next(generator.parameters(), None)
    sample_device = first_param.device if first_param is not None else torch.device("cpu")

    was_training = generator.training
    generator.eval()
    try:
        # No gradients are needed for visualisation.
        with torch.no_grad():
            z = torch.randn(16, latent_dim, device=sample_device)
            gen_imgs = generator(z).detach().cpu()
    finally:
        # Restore the previous mode instead of unconditionally calling train().
        generator.train(was_training)

    gen_imgs = (gen_imgs + 1) / 2  # Rescale to [0, 1]
    grid = make_grid(gen_imgs, nrow=4)
    plt.imshow(grid.permute(1, 2, 0))  # channels-last for matplotlib
    plt.axis("off")
    plt.show()

def plot_losses_and_ssim(g_losses, c_losses, ssim_scores):
    """
    Draw two stacked panels: generator/critic loss on top, SSIM below.

    Args:
        g_losses: Per-epoch generator loss values.
        c_losses: Per-epoch critic loss values.
        ssim_scores: Per-epoch SSIM values.
    """
    _, panels = plt.subplots(2, 1, figsize=(10, 10))
    loss_panel, ssim_panel = panels

    # Curves: both losses share the upper panel, SSIM gets the lower one.
    for series, legend_name in ((g_losses, "Generator Loss"), (c_losses, "Critic Loss")):
        loss_panel.plot(series, label=legend_name)
    ssim_panel.plot(ssim_scores, label="SSIM Score", color='green')

    # Axis titles and legends follow the same pattern on both panels.
    for panel, y_title in ((loss_panel, "Loss"), (ssim_panel, "SSIM")):
        panel.set_xlabel("Epoch")
        panel.set_ylabel(y_title)
        panel.legend()

    plt.tight_layout()
    plt.show()
    
def calculate_ssim(real_imgs, gen_imgs, win_size=3):
    """
    Calculate SSIM between real and generated images

    Images are paired by batch index (real_imgs[i] with gen_imgs[i]) and the
    per-pair scores are averaged.

    Args:
    real_imgs (torch.Tensor): Real image batch, shape (batch, channels, height, width), values in [-1, 1]
    gen_imgs (torch.Tensor): Generated image batch with the same layout
    win_size (int): Window size for SSIM calculation (should be smaller than the image size)

    Returns:
    float: Average SSIM score
    """
    import inspect

    # Convert images to channels-last numpy arrays for SSIM calculation
    real_np = real_imgs.detach().cpu().numpy().transpose(0, 2, 3, 1)
    gen_np = gen_imgs.detach().cpu().numpy().transpose(0, 2, 3, 1)

    # Normalize images to [0, 1] range
    real_np = (real_np + 1) / 2
    gen_np = (gen_np + 1) / 2

    # The keyword that marks the channel dimension changed between releases of
    # the SSIM implementation (`multichannel=True` -> `channel_axis=-1`). Pick
    # whichever the installed function declares, so the last axis is always
    # treated as channels rather than as a third spatial dimension.
    try:
        ssim_params = inspect.signature(ssim_sklearn).parameters
    except (TypeError, ValueError):
        ssim_params = {}
    if "channel_axis" in ssim_params:
        channel_kwargs = {"channel_axis": -1}
    else:
        channel_kwargs = {"multichannel": True}

    # Compute SSIM for each image with custom window size and data_range specified
    ssim_scores = [
        ssim_sklearn(real_np[i], gen_np[i], win_size=win_size, data_range=1.0, **channel_kwargs)
        for i in range(len(real_np))
    ]

    return np.mean(ssim_scores)

# Hyperparameters and Setup
# Seed the RNGs so weight init, latent sampling and shuffling are repeatable
# from a fresh kernel.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 128
img_channels = 3
img_size = 64
epochs = 500
batch_size = 64
lambda_gp = 10

generator = Generator(latent_dim, img_channels).to(device)
critic = Critic(img_channels).to(device)

print(f"Available GPUs: {torch.cuda.device_count()}")
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    generator = nn.DataParallel(generator)
    critic = nn.DataParallel(critic)

data_path = "/kaggle/input/minecraft-10000-skins"  # Replace with your dataset directory
dataloader = load_minecraft_data(data_path, img_size=img_size, batch_size=batch_size)

# Initialize and Train WGAN
train_wgan(generator, critic, dataloader, latent_dim, epochs, sample_interval=5, lambda_gp=lambda_gp)

# nn.DataParallel keeps the real model in `.module`, and its state_dict keys
# carry a "module." prefix. Save the underlying model so the checkpoint loads
# into a plain Generator/Critic whether or not several GPUs were used.
generator_to_save = generator.module if isinstance(generator, nn.DataParallel) else generator
critic_to_save = critic.module if isinstance(critic, nn.DataParallel) else critic
torch.save(generator_to_save.state_dict(), f"generator_epoch_{epochs}.pth")
torch.save(critic_to_save.state_dict(), f"critic_epoch_{epochs}.pth")

Available GPUs: 2
Using 2 GPUs
Total valid images: 10000


Epoch 1/500:   0%|          | 0/157 [00:00<?, ?it/s]

In [None]:
"""import torch
from torchvision.utils import save_image
import os

# Load the saved model weights
generator.load_state_dict(torch.load("/kaggle/input/epoch-500-0.0001-2x-db-15-sample-10000/pytorch/v1/1/generator_epoch_500.pth", map_location=device))
generator.eval()  # Set the generator to evaluation mode

# Generate and save images individually
def generate_and_save_individual_images(generator, latent_dim, output_dir="generated_images", num_images=16):
    os.makedirs(output_dir, exist_ok=True)  # Create directory to save images
    for i in range(num_images):
        z = torch.randn(1, latent_dim).to(device)  # Generate a single latent vector
        with torch.no_grad():
            gen_img = generator(z)  # Generate a single image
        gen_img = (gen_img + 1) / 2  # Rescale to [0, 1]
        gen_img = gen_img.cpu()  # Move to CPU
        save_path = os.path.join(output_dir, f"generated_image_{i + 1}.png")
        save_image(gen_img, save_path)  # Save the individual image
        print(f"Generated image saved to {save_path}")


# Generate and save the images
generate_and_save_individual_images(generator, latent_dim)"""

In [None]:
"""!pip install skinpy"""

In [None]:
"""import torch
from skinpy import Skin, Perspective
from PIL import Image
import matplotlib.pyplot as plt

def generate_and_render_skin(generator, latent_dim):
    # Generate a single latent vector
    z = torch.randn(1, latent_dim).to(device)
    
    with torch.no_grad():
        gen_img = generator(z)  # Generate a single image
    
    # Rescale to [0, 1] and move to CPU
    gen_img = (gen_img + 1) / 2
    gen_img = gen_img.cpu().squeeze().permute(1, 2, 0)
    
    # Convert to PIL Image and add alpha channel
    gen_img_pil = Image.fromarray((gen_img.numpy() * 255).astype('uint8'))
    gen_img_rgba = gen_img_pil.convert('RGBA')
    gen_img_rgba.save('generated_skin.png')
    
    # Create skin and perspective for rendering
    skin = Skin.from_path("generated_skin.png")
    perspective = Perspective(
        x="left",
        y="front", 
        z="up",
        scaling_factor=5
    )
    
    # Save and render the isometric image
    render = skin.to_isometric_image(perspective)
    render.save("render.png")
    
    # Display images
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    
    # Display generated skin
    ax1.imshow(gen_img_rgba)
    ax1.set_title('Generated Skin')
    ax1.axis('off')
    
    # Display rendered skin
    ax2.imshow(render)
    ax2.set_title('Isometric Render')
    ax2.axis('off')
    plt.tight_layout()
    plt.show()

# Generate and render the skin
for i in range(20):
    generate_and_render_skin(generator, latent_dim)"""