In [14]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import cv2
import math

In [109]:
class MascotDataset(Dataset):
	def __init__(self, spritesheet_path, num_characters, num_poses):
		"""
		Dataset class for mascot sprites with automatic dimension detection
		
		Args:
			spritesheet_path: Path to master spritesheet
			num_characters: Number of character rows in spritesheet
			num_poses: Number of pose columns in spritesheet
		"""
		self.spritesheet = cv2.imread(spritesheet_path)
		if self.spritesheet is None:
			raise ValueError(f"Could not load spritesheet from {
							 spritesheet_path}")

		self.num_characters = num_characters
		self.num_poses = num_poses

		# Calculate sprite dimensions from spritesheet
		sheet_height, sheet_width = self.spritesheet.shape[:2]
		self.sprite_height = sheet_height // num_characters
		self.sprite_width = sheet_width // num_poses

		if sheet_height % num_characters != 0 or sheet_width % num_poses != 0:
			raise ValueError(
				f"Spritesheet dimensions ({sheet_width}x{
					sheet_height}) are not "
				f"evenly divisible by grid size ({num_poses}x{num_characters})"
			)

		print(f"Detected sprite dimensions: {
			  self.sprite_width}x{self.sprite_height}")

		# Calculate padding to make dimensions divisible by 8 (for 3 downsample operations)
		self.pad_height = math.ceil(
			self.sprite_height / 8) * 8 - self.sprite_height
		self.pad_width = math.ceil(
			self.sprite_width / 8) * 8 - self.sprite_width

		self.padded_height = self.sprite_height + self.pad_height
		self.padded_width = self.sprite_width + self.pad_width

		if self.pad_height > 0 or self.pad_width > 0:
			print(f"Padding sprites to {self.padded_width}x{self.padded_height} "
				  f"(added {self.pad_width}x{self.pad_height} pixels)")

		# Extract and pad individual sprites
		self.sprites = []
		for i in range(self.num_characters):
			character_poses = []
			for j in range(self.num_poses):
				if j<5:
					continue
				# Extract sprite
				y_start = i * self.sprite_height
				x_start = j * self.sprite_width
				sprite = self.spritesheet[
					y_start:y_start + self.sprite_height,
					x_start:x_start + self.sprite_width
				]

				# Pad if necessary
				if self.pad_height > 0 or self.pad_width > 0:
					sprite = np.pad(
						sprite,
						((0, self.pad_height), (0, self.pad_width), (0, 0)),
						mode='constant',
						constant_values=0
					)
				character_poses.append(sprite)
			self.sprites.append(character_poses)

		self.sprites = np.array(self.sprites)

	def __len__(self):
		return self.num_characters * (self.num_poses -5 - 1)

	def __getitem__(self, idx):
		character_idx = idx // ((self.num_poses-5) - 1)
		pose_idx = (idx % ((self.num_poses-5) - 1)) + 1

		input_pose = self.sprites[character_idx, 0]
		target_pose = self.sprites[character_idx, pose_idx]

		input_pose = torch.FloatTensor(input_pose).permute(2, 0, 1) / 255.0
		target_pose = torch.FloatTensor(target_pose).permute(2, 0, 1) / 255.0

		return input_pose, target_pose

	def get_sprite_dimensions(self):
		"""Return the original (unpadded) sprite dimensions"""
		return self.sprite_width, self.sprite_height

In [110]:
class PoseGenerator(nn.Module):
    def __init__(self, input_height, input_width, base_channels=64, num_poses=25):
        """
        Generator network that automatically configures its architecture based on input dimensions
        
        Args:
            input_height: Height of input sprites (after padding)
            input_width: Width of input sprites (after padding)
            base_channels: Number of channels in first conv layer (doubles in subsequent layers)
            num_poses: Number of poses to generate (excluding input pose)
        """
        super(PoseGenerator, self).__init__()

        if not (input_height % 8 == 0 and input_width % 8 == 0):
            raise ValueError(
                "Input dimensions must be divisible by 8 for 3 downsample operations")

        # Calculate dimensions at each layer
        dims = self._calculate_dimensions(input_height, input_width)

        # Calculate channel progression (e.g., 3 -> 64 -> 128 -> 256)
        channels = [3] + [base_channels * (2**i) for i in range(3)]

        print(f"PoseGenerator Architecture:")
        print(f"Input size: {input_width}x{input_height}x{channels[0]}")

        # Build encoder layers dynamically
        encoder_layers = []
        for i in range(len(channels)-1):
            in_ch, out_ch = channels[i], channels[i+1]
            layer = [
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2) if i == 0 else nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2)
            ]
            encoder_layers.extend(layer)
            print(f"Encoder layer {
                  i+1}: {dims[i+1][1]}x{dims[i+1][0]}x{out_ch}")

        self.encoder = nn.Sequential(*encoder_layers)

        # Pose embedding matches the final encoder channel count
        self.pose_embedding = nn.Parameter(
            torch.randn(num_poses, channels[-1]))
        print(f"Pose embedding size: {num_poses}x{channels[-1]}")

        # Build decoder layers dynamically (reverse channel progression)
        decoder_layers = []
        channels = channels[::-1]  # Reverse channel list for decoder
        for i in range(len(channels)-1):
            # Double first input for concatenation
            in_ch = channels[i] * 2 if i == 0 else channels[i]
            out_ch = channels[i+1]
            layer = [
                nn.ConvTranspose2d(
                    in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch) if i < len(
                    channels)-2 else nn.Identity(),
                nn.ReLU() if i < len(channels)-2 else nn.Tanh()
            ]
            decoder_layers.extend(layer)
            print(f"Decoder layer {
                  i+1}: {dims[-(i+2)][1]}x{dims[-(i+2)][0]}x{out_ch}")

        self.decoder = nn.Sequential(*decoder_layers)

        # Save dimensions for forward pass
        self.bottleneck_shape = dims[3]  # Shape after encoder

    def _calculate_dimensions(self, height, width):
        """Calculate dimensions at each layer of the network"""
        dims = [(height, width)]  # Input dimensions

        # Calculate encoder dimensions (3 downsampling layers)
        for _ in range(3):
            height = height // 2
            width = width // 2
            dims.append((height, width))

        # Calculate decoder dimensions (3 upsampling layers)
        for _ in range(3):
            height = height * 2
            width = width * 2
            dims.append((height, width))

        return dims

    def forward(self, x, pose_idx):
        """
        Forward pass through the generator
        
        Args:
            x: Input pose tensor (B, 3, H, W)
            pose_idx: Indices of target poses to generate (B,)
        """
        # Encode input
        encoded = self.encoder(x)
        batch_size = x.size(0)

        # Get pose embeddings for requested indices
        pose_embed = self.pose_embedding[pose_idx]

        # Reshape pose embedding to match encoded dimensions
        pose_embed = pose_embed.view(batch_size, -1, 1, 1)
        pose_embed = pose_embed.expand(-1, -1,
                                       self.bottleneck_shape[0], self.bottleneck_shape[1])

        # Combine encoded image and pose embedding
        combined = torch.cat([encoded, pose_embed], dim=1)

        # Generate output pose
        output = self.decoder(combined)
        return output

In [111]:
def create_model_for_dataset(dataset):
    """
    Create a PoseGenerator model with appropriate dimensions for the dataset
    """
    return PoseGenerator(
        input_height=dataset.padded_height,
        input_width=dataset.padded_width
    )

In [121]:
def train_model(model, dataset, num_epochs=1000, batch_size=32*6, learning_rate=0.00002):
    """
    Training function for the pose generator
    """
    # Check if GPU is available and set device accordingly
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    for epoch in range(num_epochs):
        for i, (input_poses, target_poses) in enumerate(dataloader):
            if torch.cuda.is_available():
                input_poses = input_poses.to(device)
                target_poses = target_poses.to(device)
            # Generate random pose indices
            pose_indices = torch.randint(0, 19, (input_poses.size(0),))

            # Generate new poses
            generated_poses = model(input_poses, pose_indices)

            # Calculate loss
            loss = criterion(generated_poses, target_poses)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print(f'Epoch [{
                      epoch}/{num_epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.4f}')

In [122]:
def generate_spritesheet(model, input_pose, num_poses, original_width, original_height):
	"""
	Generate a complete spritesheet for a new mascot
	
	Args:
		model: Trained PoseGenerator model
		input_pose: First pose of the new character
		num_poses: Number of poses to generate
		original_width: Width of the original (unpadded) sprites
		original_height: Height of the original (unpadded) sprites
	"""
	model.eval()

	# Create empty spritesheet
	spritesheet = np.zeros(
		(original_height, original_width * num_poses, 3),
		dtype=np.uint8
	)

	# Calculate padding needed to match model's expected input
	pad_height = math.ceil(original_height / 8) * 8 - original_height
	pad_width = math.ceil(original_width / 8) * 8 - original_width

	# Pad input pose if necessary
	if pad_height > 0 or pad_width > 0:
		padded_input = np.pad(
			input_pose,
			((0, pad_height), (0, pad_width), (0, 0)),
			mode='constant',
			constant_values=0
		)
	else:
		padded_input = input_pose

	# Place the input pose first
	spritesheet[:, :original_width] = input_pose

	# Generate all other poses
	with torch.no_grad():
		input_tensor = torch.FloatTensor(padded_input).permute(2, 0, 1) / 255.0
		input_tensor = input_tensor.unsqueeze(0).to(next(model.parameters()).device)

		for i in range(num_poses - 1):
			# Generate padded pose
			generated_pose = model(input_tensor, torch.tensor([i]).to(input_tensor.device))
			generated_pose = generated_pose.squeeze(0).permute(1, 2, 0)
			generated_pose = (generated_pose.cpu().numpy() * 255).astype(np.uint8)

			# Remove padding
			generated_pose = generated_pose[:original_height, :original_width]

			# Place in spritesheet
			x_offset = (i + 1) * original_width
			spritesheet[:, x_offset:x_offset+original_width] = generated_pose

	return spritesheet

In [123]:
# Example usage:
if True:
    # Load dataset
    dataset = MascotDataset('main.png',
                            num_characters=24, num_poses=26)
    
	# save first sprite to local to preview
    cv2.imwrite('first.png', dataset.sprites[0, 0])

    # Create and initialize model
    model = create_model_for_dataset(dataset)
    # print("Dataset: ", dataset[0])
    
	# Train model
    train_model(model, dataset)

Detected sprite dimensions: 45x50
Padding sprites to 48x56 (added 3x6 pixels)
PoseGenerator Architecture:
Input size: 48x56x3
Encoder layer 1: 24x28x64
Encoder layer 2: 12x14x128
Encoder layer 3: 6x7x256
Pose embedding size: 25x256
Decoder layer 1: 24x28x128
Decoder layer 2: 12x14x64
Decoder layer 3: 6x7x3
Epoch [0/1000], Step [0/3], Loss: 0.3655
Epoch [1/1000], Step [0/3], Loss: 0.3398
Epoch [2/1000], Step [0/3], Loss: 0.3199
Epoch [3/1000], Step [0/3], Loss: 0.2990
Epoch [4/1000], Step [0/3], Loss: 0.2794
Epoch [5/1000], Step [0/3], Loss: 0.2601
Epoch [6/1000], Step [0/3], Loss: 0.2449
Epoch [7/1000], Step [0/3], Loss: 0.2320
Epoch [8/1000], Step [0/3], Loss: 0.2163
Epoch [9/1000], Step [0/3], Loss: 0.2025
Epoch [10/1000], Step [0/3], Loss: 0.1895
Epoch [11/1000], Step [0/3], Loss: 0.1789
Epoch [12/1000], Step [0/3], Loss: 0.1654
Epoch [13/1000], Step [0/3], Loss: 0.1560
Epoch [14/1000], Step [0/3], Loss: 0.1445
Epoch [15/1000], Step [0/3], Loss: 0.1349
Epoch [16/1000], Step [0/3], L

In [125]:
# Generate new spritesheet
sprite_width, sprite_height = dataset.get_sprite_dimensions()
# get the first pose of first mascot from main.png
new_mascot_pose = dataset.sprites[0, 0]

# Resize input pose to match original dimensions
new_mascot_pose = cv2.resize(new_mascot_pose, (sprite_width, sprite_height))

spritesheet = generate_spritesheet(
	model,
	new_mascot_pose,
	num_poses=26,
	original_width=sprite_width,
	original_height=sprite_height
)

cv2.imwrite('new_mascot_spritesheet.png', spritesheet)

True