In [7]:
from typing import Annotated, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.functional import F
from dl_solver import HyperParameters


class PatchNet(nn.Module):
    hparams: HyperParameters

    def __init__(self, hparams: HyperParameters):
        super().__init__()
        self.hparams = hparams

        # self.temperatures = LearnableTemperatures(hparams)

    def _soft_forward_step(
        self,
        x: Tensor,
        pos_seq: Tensor,
        encoder_memory: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """

        Args:
            x: Tensor[B, num_pieces, num_features]
            pos_seq: Tensor[B, num_pieces, 3]
            encoder_memory: Optional[Tensor[B, num_pieces, num_features]]

        Returns:
            Tuple[Tensor, Tuple[Tensor, Tensor, Tensor], Tensor]: (pos_seq, (row, col, rot)_logits, encoder_memory)
        """
        x, encoder_memory = self.transformer(x, pos_seq, encoder_memory)
        logits = self.classifier(x, *self.hparams.puzzle_shape)

        num_rows, num_cols = self.hparams.puzzle_shape

        # Enhance logits using the unique selection strategy for 2D structure
        enhanced_probs = self._enhance_unique_selection(
            logits[0], logits[1], self.hparams.softmax_temperature
        )  # This should return [B, L, num_rows, num_cols]

        # Applying Gumbel-Softmax on the enhanced probabilities
        # Flatten the enhanced probabilities for Gumbel-Softmax if needed
        flat_enhanced_probs = enhanced_probs.view(
            enhanced_probs.shape[0], enhanced_probs.shape[1], -1
        )
        probabilities = F.gumbel_softmax(
            flat_enhanced_probs, tau=self.hparams.gumbel_temperature, hard=True
        )

        # Get indices from the probabilities
        indices = probabilities.argmax(dim=-1)
        row_indices = indices // num_cols  # Integer division to find row index
        col_indices = indices % num_cols  # Modulo to find column index

        # Rotation logits processed separately
        rotation_indices = torch.argmax(logits[2], dim=-1)

        # Stack the indices to form the final position sequence tensor
        pos_seq = torch.stack([row_indices, col_indices, rotation_indices], dim=-1).to(
            torch.float32
        )

        return pos_seq, logits, encoder_memory

    def _enhance_unique_selection(
        self, row_logits: Tensor, col_logits: Tensor
    ):
        """Apply self-competition to enhance unique selections

        Args:
            row_logits: Tensor["B, L, num_rows", torch.float32]
            col_logits: Tensor["B, L, num_cols", torch.float32]
            temperature (float): Temperature for softmax

        Returns:
            Tensor [B, L, num_rows, num_cols]: Joint probabilities that have been modified as follows:
            - In Case of a Clash, where multiple token have the same highest probability:
                - The token that has the highest probability for idx j - the coordinates (row_idx, col_idx)_j -
                  it's probability is not modified.
                - All other tokens that predicted the same coordinates (row_idx, col_idx)_j are penalized by reducing
                  P_j by a factor of hparams.not_unique_penalty = 2; All other logits are increased accordingly by applying softmax again!
                  This is applied iteratively until no more clashes are exist!
            - All Operations should be differentiable!
        """
        num_rows, num_cols = self.hparams.puzzle_shape

        # Compute joint probabilities
        joint_probs = self._compute_joint_probabilities(row_logits, col_logits)

        # Flatten to [B, num_pieces, num_rows * num_cols]
        flat_probs = joint_probs.view(joint_probs.shape[0], joint_probs.shape[1], -1)

        # Identify token clashes: check if the maximum probability per class is chosen by more than one token
        max_probs, max_indices = flat_probs.max(dim=2, keepdim=True)  # Max probability per token
        token_clashes = (flat_probs == max_probs).sum(dim=1) > 1  # More than one token per max probability

        # Identify class clashes: check if any class is being overly selected
        class_clashes = flat_probs.max(dim=1, keepdim=True)  # Max probability per class across all tokens

        # Apply penalties for token clashes
        penalties = token_clashes.float() * self.hparams.non_unique_penalty
        adjusted_logits = flat_probs - penalties

        # Normalize again using softmax
        return F.softmax(adjusted_logits, dim=-1).view_as(joint_probs)


    def _compute_joint_probabilities(
        self, row_logits: Tensor, col_logits: Tensor
    ) -> Tensor:
        """_summary_

        Args:
            row_logits (Tensor[B, num_pieces, num_rows])
            col_logits (Tensor[B, num_pieces, num_cols])
            temperature (float, optional): Defaults to 1.0.

        Returns:
            Tensor[B, num_pieces, num_rows, num_cols]: Joint probabilities
        """
        # Compute probabilities within each token over all classes
        row_probs = F.softmax(row_logits, dim=-1)
        col_probs = F.softmax(col_logits, dim=-1)

        joint_probs = row_probs[:, :, :, None] * col_probs[:, :, None, :]

        return joint_probs

    def apply_penalties(self, joint_probs: Tensor) -> Tensor:
        # Flatten to [B, num_pieces, num_rows * num_cols]
        flat_probs = joint_probs.view(*joint_probs.shape[:2], -1)

        max_probs_per_token, _ = flat_probs.max(
            dim=1, keepdim=True
        )  # [B, 1, num_rows * num_cols]
        max_probs_per_class, _ = flat_probs.max(
            dim=-1, keepdim=True
        )  # [B, num_pieces, 1]

        # Masks that identify the maximum probability per token / class / globally
        max_per_token = (
            flat_probs == max_probs_per_token
        )  # [B, num_pieces, num_rows * num_cols]
        max_per_class = flat_probs == max_probs_per_class

        penalty_scale = (
            torch.abs(flat_probs - max_probs_per_token)
            / max_probs_per_token
            * self.hparams.non_unique_penalty
        )

        # Apply penalties using softmax for differentiability
        flat_probs = torch.where(
            ~max_per_class & max_per_token,
            torch.softmax(flat_probs * penalty_scale, dim=-1),
            flat_probs,
        )
        flat_probs = torch.where(
            ~max_per_class & ~max_per_token,
            torch.softmax(flat_probs / penalty_scale, dim=-1),
            flat_probs,
        )
        flat_probs = torch.where(
            max_per_class & ~max_per_token,
            torch.softmax(flat_probs / penalty_scale, dim=-1),
            flat_probs,
        )
        flat_probs = torch.where(
            max_per_class & max_per_token,
            torch.softmax(flat_probs / penalty_scale, dim=-1),
            flat_probs,
        )

        return (
            (flat_probs + torch.finfo(torch.float32).eps).log().softmax(-1)
        ).view_as(joint_probs)




    @staticmethod
    def _check_unique_indices(spatial_indices: Tensor) -> Tensor:
        """
        Check uniqueness of spatial indices within each batch.
        Args:
            spatial_indices: Tensor[torch.int64] - (B, num_pieces, 2) [row_idx, col_idx]
        Returns:
            is_unique: Tensor[torch.bool] - (B, num_pieces)
        """
        batch_size, num_pieces = spatial_indices.size(0), spatial_indices.size(1)
        unique_mask = torch.ones(
            (batch_size, num_pieces), dtype=torch.bool, device=spatial_indices.device
        )

        # Check each batch independently
        for i in range(batch_size):
            _, inverse_indices, counts = torch.unique(
                spatial_indices[i], dim=0, return_inverse=True, return_counts=True
            )
            unique_mask[i] = counts[inverse_indices] == 1

        return unique_mask

model = PatchNet(HyperParameters(num_post_iters=20, non_unique_penalty=0.5))

row_logits = torch.rand(2, 12, 3)
col_logits = torch.rand(2, 12, 4)

joint_probs = model._compute_joint_probabilities(row_logits, col_logits)
print(joint_probs.shape)
print(joint_probs.sum(1))
penalized_probs = model.apply_penalties(joint_probs=joint_probs)
print(joint_probs.shape)

torch.Size([2, 12, 3, 4])
tensor([[[1.0850, 1.1302, 1.1053, 0.8410],
         [0.9896, 0.9749, 0.9435, 0.7341],
         [1.1135, 1.1562, 1.0542, 0.8725]],

        [[0.9771, 0.8424, 1.0311, 0.9762],
         [0.9647, 0.8057, 1.0239, 0.9300],
         [1.1297, 0.9856, 1.2481, 1.0853]]])
torch.Size([2, 12, 3, 4])


In [4]:
import torch
rotation_probs = torch.rand((8, 12, 4)).softmax(dim=-1)
ret = torch.functional.F.gumbel_softmax(rotation_probs, 1, hard=True).squeeze(
            -1
        )
print(ret.shape)
print(ret)

torch.Size([8, 12, 4])
tensor([[[0., 0., 0., 1.],
         [0., 0., 0., 1.],
         [0., 0., 1., 0.],
         [1., 0., 0., 0.],
         [0., 0., 0., 1.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 0., 1.],
         [0., 1., 0., 0.],
         [1., 0., 0., 0.]],

        [[0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [1., 0., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.],
         [0., 0., 0., 1.],
         [1., 0., 0., 0.],
         [0., 0., 0., 1.],
         [0., 0., 1., 0.],
         [0., 1., 0., 0.],
         [0., 0., 0., 1.],
         [0., 0., 1., 0.]],

        [[1., 0., 0., 0.],
         [0., 1., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [0., 0., 1., 0.],
         [0., 0., 0., 1.],
         [0., 0., 1., 0.],
         [1., 0., 0., 0.]],


In [32]:
t = torch.rand(8, 12, 3, 4)
rows, _, cols = torch.split(t, [1, 1, 1], dim=-1)
rows = rows.squeeze(-1)
cols = cols.squeeze(-1)
rows.shape, cols.shape

RuntimeError: split_with_sizes expects split_sizes to sum exactly to 4 (input tensor's size at dimension -1), but got split_sizes=[1, 1, 1]

RuntimeError: split_with_sizes expects split_sizes to sum exactly to 4 (input tensor's size at dimension -1), but got split_sizes=[1, 1]

In [46]:
import numpy as np
import torch
import torch.nn as nn


def softargmax2d(input, beta=100):
    *_, h, w = input.shape

    input = input.reshape(*_, h * w)
    input = nn.functional.softmax(beta * input, dim=-1)

    indices_c, indices_r = torch.meshgrid(
        torch.linspace(0, 1, w, device=input.device),
        torch.linspace(0, 1, h, device=input.device),
    )

    indices_r = indices_r.reshape(-1, h * w)
    indices_c = indices_c.reshape(-1, h * w)

    result_r = torch.sum((h - 1) * input * indices_r, dim=-1)
    result_c = torch.sum((w - 1) * input * indices_c, dim=-1)

    result = torch.stack([result_r, result_c], dim=-1)

    return result


def softargmax1d(input, beta=100):
    *_, n = input.shape
    input = nn.functional.softmax(beta * input, dim=-1)
    indices = torch.linspace(0, 1, n)
    result = torch.sum((n - 1) * input * indices, dim=-1)
    return result

pos_probs = softargmax2d(torch.rand(8, 12, 3, 4))
rot_probs = softargmax1d(torch.rand(8, 12, 4))

pos_probs.shape, rot_probs.unsqueeze(-1).shape

(torch.Size([8, 12, 2]), torch.Size([8, 12, 1]))

In [47]:
pos_seq = torch.cat([pos_probs, rot_probs.unsqueeze(-1)], dim=-1)
pos_seq.shape

torch.Size([8, 12, 3])

In [19]:
flat_joint_probs = joint_probs.view(*joint_probs.shape[:2], -1)
token_to_class = flat_joint_probs.argmax(1)
class_to_token = flat_joint_probs.argmax(-1)

token_to_class, class_to_token

NameError: name 'joint_probs' is not defined

In [51]:
max=flat_joint_probs.argmax(dim=1, keepdim=True)

In [52]:
max_indices = flat_joint_probs.argmax(dim=1, keepdim=True)
is_max = torch.zeros_like(flat_joint_probs).scatter_(1, max_indices, 1)

In [53]:
def get_max(joint_probs: Tensor, dim: int) -> Tensor:
    max_indices = flat_joint_probs.argmax(dim=dim, keepdim=True)
    return torch.zeros_like(flat_joint_probs).scatter_(dim, max_indices, 1).bool()

In [60]:
old_flat_joint_prons = flat_joint_probs.clone()
flat_joint_probs[~get_max(flat_joint_probs, 1)] *= 0.5

In [63]:
_, indices = flat_joint_probs.sort(dim=-1, descending=True)
indices

tensor([[[ 4,  7,  0,  6,  5,  3,  1,  2,  8,  9, 11, 10],
         [ 3,  2, 11, 10,  0,  1,  8,  9,  7,  6,  5,  4],
         [ 5,  4,  1,  0,  9,  7,  8,  6,  3, 11,  2, 10],
         [ 8,  0, 11,  4,  9,  7,  3,  1, 10,  5,  2,  6],
         [ 1,  3,  9, 11,  2,  5,  7, 10,  0,  6,  8,  4],
         [ 3,  2,  7,  0,  1, 11,  6, 10,  9,  5,  8,  4],
         [ 8, 10,  9, 11,  4,  7,  5,  6,  1,  0,  2,  3],
         [ 9, 11,  7, 10,  5,  6,  8,  1,  3,  4,  2,  0],
         [ 3, 11,  0,  8,  1,  7,  9,  2, 10,  4,  5,  6],
         [10,  8,  2,  6,  0,  4,  9, 11,  1,  5,  7,  3],
         [ 6,  5, 10,  7,  4,  9,  2,  1, 11,  8,  3,  0],
         [ 0,  2,  6,  4,  1,  5, 10,  8,  9,  3,  7, 11]],

        [[ 4,  5,  7,  0,  8,  6,  1,  9,  3, 11, 10,  2],
         [ 9, 11, 10,  5,  7,  3,  8,  1,  6,  4,  2,  0],
         [ 7,  5,  6, 11,  4,  9,  3,  1, 10,  2,  8,  0],
         [ 1,  5,  3,  7,  9,  2,  0,  6, 11,  4,  8, 10],
         [ 2, 10,  1,  0,  8,  9,  6,  3,  5,  4, 11, 

In [72]:
def gumbel_softmax_argmax(logits, temperature=1.0):
    """
    Compute the Gumbel-Softmax (a.k.a Concrete) approximation of argmax.
    This function returns the index as a one-hot encoded vector.
    """
    y_soft = F.gumbel_softmax(logits, tau=temperature, hard=True, dim=-1)
    # _, index = y_soft.max(dim=-1)
    # y_hard = torch.zeros_like(logits).scatter_(-1, index.unsqueeze(-1), 1.0)
    # y = y_hard - y_soft.detach() + y_soft
    return y_soft

In [73]:
gumbel_softmax_argmax(flat_joint_probs)

tensor([[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 

In [69]:
flat_joint_probs.argmax(-1)

tensor([[ 4,  3,  5,  8,  1,  3,  8,  9,  3, 10,  6,  0],
        [ 4,  9,  7,  1,  2,  7,  0,  9,  6, 10,  3,  1]])