In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import  DataLoader
from tqdm import tqdm
from itertools import permutations


from math import factorial

from utils.networks import DecoderRNN, build_resnet18

# Pick the accelerator when PyTorch can see one, otherwise run on the CPU.
# (To force CPU execution, assign device = 'cpu' after this block.)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

In [2]:
class DecoderNoInput(nn.Module):
    """Input-free policy head over transform orderings and discrete magnitudes.

    The network is fed a constant all-zeros input, so it effectively learns a
    single set of logits shared by every row of the batch. For each of 2 slots
    (the leading ``2`` in the logits reshapes in ``forward``) it parameterises:

    * a Categorical over every ordering (permutation) of the transform indices
      ``range(num_transforms)``, and
    * for every transform, a Categorical over ``num_discrete_magnitude`` bins.

    Args:
        num_transforms: number of transforms. Each sampled ordering is a
            permutation of ``range(num_transforms)``. Note that the number of
            orderings is ``num_transforms!`` and grows very quickly.
        num_discrete_magnitude: number of magnitude bins per transform.
        device: device on which the permutation table is created. It is also
            stored as ``self.device`` for callers.
        hidden_dim: width of the two hidden layers (default 512, as before).
    """

    def __init__(self,
            num_transforms,
            num_discrete_magnitude,
            device,
            hidden_dim=512,
            ):
        super().__init__()

        self.num_transforms = num_transforms
        self.num_discrete_magnitude = num_discrete_magnitude
        self.seq_length = num_transforms
        self.device = device

        # Every ordering of the transform indices, shape (num_transforms!, num_transforms).
        # Built from num_transforms (previously hard-coded to 4). It is registered
        # as a non-persistent buffer so module.to(...) moves it with the weights
        # while the state_dict keys stay unchanged.
        self.register_buffer(
            "permutations",
            torch.tensor(list(permutations(range(num_transforms))), device=device),
            persistent=False,
        )

        self.num_transforms_permutations = len(self.permutations)
        self.num_actions = num_transforms * num_discrete_magnitude

        # One output head: [2 * magnitude logits | 2 * permutation logits].
        self.model = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * self.num_actions + 2 * self.num_transforms_permutations),
        )

    def forward(self, batch_size, old_action_index=None):
        """Sample actions (or re-score given ones) and return their log-probability.

        Args:
            batch_size: number of rows to produce.
            old_action_index: optional ``(transform_actions_index,
                magnitude_actions_index)`` pair, as previously returned by this
                module. When given, nothing is sampled. Instead, the
                log-probability of these actions under the current parameters
                is computed.

        Returns:
            ``(log_p, (transform_actions_index, magnitude_actions_index),
            (transform_entropy, magnitude_entropy))``:

            * ``log_p`` has shape ``(batch_size, 1)``. It is the sum of the
              log-probabilities of both slots' orderings and all their
              magnitudes.
            * ``transform_actions_index`` has shape
              ``(batch_size, 2, num_transforms)``.
            * ``magnitude_actions_index`` is the sampled (or the given)
              magnitude bins. When sampled, its shape is
              ``(batch_size, 2, num_transforms)``.
            * Both entropies are scalar means over all distributions.

        Raises:
            ValueError: if ``old_action_index`` contains an ordering that is not
                a permutation of ``range(num_transforms)``.
        """
        # Constant zeros input, created to match the weights' device/dtype so the
        # forward pass keeps working after module.to(...) / .half() etc.
        first_weight = self.model[0].weight
        x = torch.zeros((batch_size, 1), dtype=first_weight.dtype, device=first_weight.device)

        output = self.model(x)

        magnitude_logits = output[:, :2 * self.num_actions].reshape(
            batch_size, 2, self.num_transforms, self.num_discrete_magnitude
        )
        permutations_logits = output[:, 2 * self.num_actions:].reshape(
            batch_size, 2, self.num_transforms_permutations
        )

        magnitude_dist = torch.distributions.Categorical(logits=magnitude_logits)
        permutations_dist = torch.distributions.Categorical(logits=permutations_logits)

        if old_action_index is None:
            magnitude_actions_index = magnitude_dist.sample()
            permutations_index = permutations_dist.sample()
        else:
            transform_actions_index, magnitude_actions_index = old_action_index
            permutations_index = self._permutation_index(transform_actions_index)

        # Joint log-probability: sum over both slots (and every transform's magnitude).
        magnitude_log_p = magnitude_dist.log_prob(magnitude_actions_index).reshape(batch_size, -1).sum(-1, keepdim=True)
        permutation_log_p = permutations_dist.log_prob(permutations_index).reshape(batch_size, -1).sum(-1, keepdim=True)
        log_p = magnitude_log_p + permutation_log_p

        transform_actions_index = self.permutations[permutations_index]
        transform_entropy = permutations_dist.entropy().mean()
        magnitude_entropy = magnitude_dist.entropy().mean()

        return (
                log_p,
                (transform_actions_index, magnitude_actions_index),
                (transform_entropy, magnitude_entropy)
            )

    def _permutation_index(self, transform_actions_index):
        """Map orderings of shape ``(..., num_transforms)`` to their row in ``self.permutations``."""
        num_perms, width = self.permutations.shape
        leading = transform_actions_index.dim() - 1
        # (P, 1, ..., 1, T) against (1, ..., T) -> (P, ...) mask of exact row matches.
        table = self.permutations.reshape(num_perms, *([1] * leading), width)
        matches = (transform_actions_index.unsqueeze(0) == table).all(dim=-1)
        # argmax would silently return 0 for a row with no match; fail loudly instead.
        if not bool(matches.any(dim=0).all()):
            raise ValueError(
                "old_action_index contains a transform ordering that is not a "
                f"permutation of range({self.num_transforms})"
            )
        return matches.long().argmax(dim=0)

In [5]:
# Seed so this sanity check is reproducible on Restart & Run All.
torch.manual_seed(0)
model = DecoderNoInput(4, 11, device).to(device)
batch_size = 1024

# Sample fresh actions, then re-score exactly those actions: the log-probabilities
# must agree and the actions must come back unchanged.
log_p, (transform_actions_index, magnitude_actions_index), (transform_entropy, magnitude_entropy) = model(batch_size)
new_log_p, (new_transform_actions_index, new_magnitude_actions_index), (transform_entropy, magnitude_entropy) = model(
    batch_size,
    old_action_index=(transform_actions_index, magnitude_actions_index),
)

assert torch.allclose(log_p, new_log_p)
# The action indices are integer tensors, so compare them exactly.
assert torch.equal(transform_actions_index, new_transform_actions_index)
assert torch.equal(magnitude_actions_index, new_magnitude_actions_index)

In [148]:
# Rebuild the permutation table for the model's own transform count (rather than a
# hard-coded 4), and flatten the (batch, 2, num_transforms) orderings into rows.
p = torch.tensor(list(permutations(range(model.num_transforms)))).to(device)
x = transform_actions_index.reshape(-1, model.num_transforms)

In [149]:
# Compare every permutation row of `p` (dim 0) against every flattened ordering in
# `x` (dim 1). argmax over dim 0 then picks the id of the matching permutation.
row_equal = (p[:, None, :] == x[None, :, :]).all(dim=-1)
matches = row_equal.long()
indices = matches.argmax(dim=0)
indices

tensor([ 6, 11, 11,  ..., 23,  4, 15], device='cuda:0')

In [151]:
(x[None].shape, p[:, None].shape)

(torch.Size([1, 2048, 4]), torch.Size([24, 1, 4]))