In [109]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
# from torchvision.models import resnet18, resnet50
from utils.resnet import resnet18, resnet50
from itertools import permutations

# Prefer the GPU when one is available; fall back to CPU otherwise.
# (The former unconditional `device = 'cuda'` override made the fallback dead
# code and crashed on CPU-only machines.)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [167]:

class DecoderNN_1input(nn.Module):
    """Autoregressive LSTM policy over two-branch augmentation sequences.

    Each action is ``[branch_0, branch_1]``. Each branch is ``seq_length`` steps
    of ``(transform_name, apply_prob, level)``. At every step the LSTM reads the
    whole (partially filled) history of both branches and emits logits for the
    next transform and magnitude bin.

    Args:
        transforms: sequence of transform identifiers; list position is the
            categorical index used internally.
        num_discrete_magnitude: number of magnitude bins ``n``; a chosen bin
            ``k`` is reported as ``level = k / n``.
        device: device on which per-call tensors (hidden state, histories,
            encoded old actions) are created.
    """

    # Value placed in the second slot of every emitted action tuple.
    apply_prob = 0.8

    def __init__(
            self,
            transforms,
            num_discrete_magnitude,
            device
        ):
        super().__init__()

        # Model hyper-parameters.
        # NOTE(review): encoder_dim is not used by this class; 2028 may be a typo for 2048.
        self.encoder_dim = 2028
        self.decoder_dim = 512
        self.embed_size = 128

        self.transforms = transforms
        num_transforms = len(transforms)
        self.num_transforms = num_transforms
        self.num_discrete_magnitude = num_discrete_magnitude
        self.seq_length = 3

        # The extra (+1) row in each table is the "not chosen yet" token that
        # pre-fills the history before a step has been decided.
        self.transform_embedding = nn.Embedding(num_transforms + 1, self.embed_size)
        self.magnitude_embedding = nn.Embedding(num_discrete_magnitude + 1, self.embed_size)
        # Not used in forward(); kept so existing state_dicts still load.
        self.branch_id_embedding = nn.Embedding(2, self.embed_size)
        self.action_id_embedding = nn.Embedding(2, self.embed_size)

        # Input is the whole flattened history:
        # 2 branches x seq_length steps x (transform + magnitude) embeddings.
        self.rnn = nn.LSTMCell(self.embed_size * self.seq_length * 2 * 2, self.decoder_dim, bias=True)

        self.transform_fc = nn.Linear(self.decoder_dim, num_transforms)
        self.magnitude_fc = nn.Linear(self.decoder_dim, num_discrete_magnitude)

        self.device = device

    def init_hidden_state(self, batch_size):
        """Return zero ``(h, c)`` LSTM states of shape (batch_size, decoder_dim) on ``self.device``."""
        # Uses self.device (not a notebook-global `device`) so the module works
        # wherever it was constructed.
        h = torch.zeros(batch_size, self.decoder_dim, device=self.device)
        c = torch.zeros(batch_size, self.decoder_dim, device=self.device)
        return h, c

    def lstm_forward(self, transform_history, magnitude_history, h_t, c_t):
        """Run one LSTM step over the full action history.

        Args:
            transform_history: long tensor of transform indices, as built by
                ``forward`` with shape (batch, 2, seq_length); the value
                ``num_transforms`` marks a step that is not chosen yet.
            magnitude_history: long tensor of magnitude bins, same shape;
                ``num_discrete_magnitude`` marks a step that is not chosen yet.
            h_t, c_t: LSTM state, each (batch, decoder_dim).

        Returns:
            ``(h_t, c_t, transform_logits, magnitude_logits)``, where the logits are
            (batch, num_transforms) and (batch, num_discrete_magnitude).
        """
        transform_history_embd = self.transform_embedding(transform_history)
        magnitude_history_embd = self.magnitude_embedding(magnitude_history)
        # (batch, 2, seq, 2*embed) -> (batch, 2*seq*2*embed) to match the LSTMCell input size.
        rnn_input = torch.cat(
            (transform_history_embd, magnitude_history_embd),
            dim=-1
        ).flatten(start_dim=1)
        h_t, c_t = self.rnn(rnn_input, (h_t, c_t))
        transform_logits = self.transform_fc(h_t)
        magnitude_logits = self.magnitude_fc(h_t)
        return h_t, c_t, transform_logits, magnitude_logits

    def _encode_actions(self, old_action, batch_size):
        """Convert nested ``(name, prob, level)`` actions into (transform, magnitude)
        index tensors, each of shape (batch_size, 2, seq_length).

        Raises:
            ValueError: if ``len(old_action) != batch_size``.
            ValueError: if a transform name is not in ``self.transforms``.
        """
        if len(old_action) != batch_size:
            raise ValueError(
                f"old_action has {len(old_action)} samples but batch_size={batch_size}"
            )
        steps = range(self.seq_length)
        transform_ids = [
            [[self.transforms.index(sample[b][s][0]) for s in steps] for b in range(2)]
            for sample in old_action
        ]
        # Levels were emitted as bin / n, so rounding level * n recovers the bin.
        magnitude_ids = [
            [[round(sample[b][s][2] * self.num_discrete_magnitude) for s in steps] for b in range(2)]
            for sample in old_action
        ]
        shape = (batch_size, 2, self.seq_length)
        # Build each tensor once (a single host->device copy) instead of
        # writing elements one by one into a device tensor.
        return (
            torch.tensor(transform_ids, dtype=torch.long, device=self.device).reshape(shape),
            torch.tensor(magnitude_ids, dtype=torch.long, device=self.device).reshape(shape),
        )

    def _decode_actions(self, transform_history, magnitude_history):
        """Turn index histories into nested ``[batch][branch][step]`` action tuples."""
        # One device->host transfer per tensor instead of one .item() per element.
        transform_ids = transform_history.tolist()
        levels = (magnitude_history / self.num_discrete_magnitude).tolist()
        return [
            [
                [
                    (self.transforms[t], self.apply_prob, level)
                    for t, level in zip(branch_t, branch_l)
                ]
                for branch_t, branch_l in zip(sample_t, sample_l)
            ]
            for sample_t, sample_l in zip(transform_ids, levels)
        ]

    def forward(self, batch_size, old_action=None):
        """Sample a batch of actions, or re-score ``old_action`` under the current policy.

        Args:
            batch_size: number of actions to sample or re-score.
            old_action: optional nested list ``[batch][branch][step] ->
                (name, prob, level)``, as returned by a previous call. When
                given, these choices are replayed instead of sampled.

        Returns:
            log_p: (batch_size, 1) joint log-probability of each action, summed
                over both branches and all steps of transform + magnitude log-probs.
            action: nested list ``[batch][branch][step] -> (name, 0.8, level)``.
            entropy: scalar tensor; mean per-step transform entropy plus mean
                per-step magnitude entropy.

        Raises:
            ValueError: if ``old_action`` is given and its length differs from
                ``batch_size``.
        """
        device = self.device
        shape = (batch_size, 2, self.seq_length)

        if old_action is not None:
            old_transform_actions_index, old_magnitude_actions_index = self._encode_actions(
                old_action, batch_size
            )

        log_p = torch.zeros(shape, device=device)
        # Histories start filled with the "not chosen yet" tokens.
        transform_history = torch.full(shape, self.num_transforms, dtype=torch.long, device=device)
        magnitude_history = torch.full(shape, self.num_discrete_magnitude, dtype=torch.long, device=device)

        transform_entropy = 0
        magnitude_entropy = 0

        h_t, c_t = self.init_hidden_state(batch_size)  # (batch_size, decoder_dim)

        for branch in range(2):
            for step in range(self.seq_length):
                h_t, c_t, transform_logits, magnitude_logits = self.lstm_forward(
                    transform_history=transform_history,
                    magnitude_history=magnitude_history,
                    h_t=h_t,
                    c_t=c_t,
                )
                transform_dist = Categorical(logits=transform_logits)
                magnitude_dist = Categorical(logits=magnitude_logits)

                if old_action is None:
                    transform_action_index = transform_dist.sample()
                    magnitude_action_index = magnitude_dist.sample()
                else:
                    transform_action_index = old_transform_actions_index[:, branch, step]
                    magnitude_action_index = old_magnitude_actions_index[:, branch, step]

                transform_log_p = F.log_softmax(transform_logits, dim=-1).gather(-1, transform_action_index.unsqueeze(-1))
                magnitude_log_p = F.log_softmax(magnitude_logits, dim=-1).gather(-1, magnitude_action_index.unsqueeze(-1))
                log_p[:, branch, step] = transform_log_p.squeeze(-1) + magnitude_log_p.squeeze(-1)

                # Each entropy term uses its own head's logits; previously the
                # magnitude term reused transform_logits.
                transform_entropy = transform_entropy + transform_dist.entropy().mean()
                magnitude_entropy = magnitude_entropy + magnitude_dist.entropy().mean()

                # Write into fresh copies. The embeddings above saved the current
                # history tensors as indices for backward, and mutating them in
                # place would fail autograd's in-place version check.
                transform_history = transform_history.clone()
                magnitude_history = magnitude_history.clone()
                transform_history[:, branch, step] = transform_action_index
                magnitude_history[:, branch, step] = magnitude_action_index

        num_steps = 2 * self.seq_length
        entropy = transform_entropy / num_steps + magnitude_entropy / num_steps

        # Joint log-probability of each action: sum over branches and steps.
        log_p = log_p.flatten(start_dim=1).sum(dim=-1, keepdim=True)

        action = self._decode_actions(transform_history, magnitude_history)

        return (
            log_p,
            action,
            entropy
        )
  

In [168]:
# Nesting depth of the action list: (samples, branches, steps per branch).
# NOTE(review): relies on `old_action` created by the In [173] cell further down.
tuple(len(level) for level in (old_action, old_action[0], old_action[0][0]))

(16, 2, 3)

In [169]:
# Inspect one sampled action: two branches of (transform, 0.8, level) steps.
# NOTE(review): `old_action` is created by the In [173] cell further down, so this
# cell only works after that one has run (fails on Restart & Run All).
old_action[3]

[[('nazim', 0.8, 0.800000011920929),
  ('kelba', 0.8, 0.20000000298023224),
  ('pipipopo', 0.8, 0.10000000149011612)],
 [('nazim', 0.8, 0.0),
  ('pipipopo', 0.8, 0.10000000149011612),
  ('nazim', 0.8, 0.10000000149011612)]]

In [174]:
# Inspect the same sample index from the re-scored batch.
# NOTE(review): `new_action` is created by the In [173] cell further down, so this
# cell only works after that one has run (fails on Restart & Run All).
new_action[3]

[[('nazim', 0.8, 0.30000001192092896),
  ('pipipopo', 0.8, 0.6000000238418579),
  ('kelba', 0.8, 0.800000011920929)],
 [('kelba', 0.8, 0.6000000238418579),
  ('pipipopo', 0.8, 0.4000000059604645),
  ('nazim', 0.8, 0.20000000298023224)]]

In [173]:
# Consistency check: re-scoring freshly sampled actions through `old_action=`
# must reproduce the sampling pass's decoded actions, log-probabilities and entropy.
torch.manual_seed(0)  # reproducible sampling
transforms = ['nazim', 'pipipopo', 'kelba']
net = DecoderNN_1input(transforms, 10, device).to(device)

with torch.no_grad():  # no gradients needed for this check
    for _ in range(100):
        old_log_p, old_action, old_entropy = net(batch_size=16)
        new_log_p, new_action, new_entropy = net(batch_size=16, old_action=old_action)

        assert new_action == old_action
        # allclose rather than exact ==, so harmless float noise doesn't fail the check.
        assert torch.allclose(old_log_p, new_log_p)
        assert torch.allclose(old_entropy, new_entropy)

In [133]:
# Display the full batch of sampled actions.
# NOTE(review): the stored output comes from an earlier run (In [133]) and shows 2 steps
# per branch, while the current decoder uses seq_length = 3 — re-run or clear it.
old_action

[[[('kelba', 0.8, 0.699999988079071), ('pipipopo', 0.8, 0.20000000298023224)],
  [('nazim', 0.8, 0.6000000238418579), ('kelba', 0.8, 0.10000000149011612)]],
 [[('nazim', 0.8, 0.800000011920929), ('pipipopo', 0.8, 0.30000001192092896)],
  [('kelba', 0.8, 0.30000001192092896), ('kelba', 0.8, 0.699999988079071)]],
 [[('kelba', 0.8, 0.6000000238418579), ('kelba', 0.8, 0.800000011920929)],
  [('pipipopo', 0.8, 0.699999988079071),
   ('pipipopo', 0.8, 0.6000000238418579)]],
 [[('kelba', 0.8, 0.10000000149011612), ('nazim', 0.8, 0.4000000059604645)],
  [('nazim', 0.8, 0.699999988079071), ('nazim', 0.8, 0.9000000357627869)]],
 [[('nazim', 0.8, 0.800000011920929), ('nazim', 0.8, 0.30000001192092896)],
  [('nazim', 0.8, 0.6000000238418579), ('pipipopo', 0.8, 0.6000000238418579)]],
 [[('nazim', 0.8, 0.9000000357627869), ('kelba', 0.8, 0.5)],
  [('kelba', 0.8, 0.10000000149011612), ('kelba', 0.8, 0.10000000149011612)]],
 [[('nazim', 0.8, 0.20000000298023224), ('kelba', 0.8, 0.0)],
  [('nazim', 0.8

In [51]:
# NOTE(review): `actions_index` is not defined in any visible cell; it looks like
# leftover state from a deleted cell and will raise NameError on a fresh kernel.
# Remove this cell or re-create the variable explicitly.
actions_index

tensor([[[1, 2],
         [1, 0]],

        [[0, 0],
         [0, 2]],

        [[0, 2],
         [1, 2]],

        [[2, 1],
         [0, 2]],

        [[1, 2],
         [1, 2]],

        [[1, 0],
         [2, 1]],

        [[2, 0],
         [0, 1]],

        [[1, 1],
         [0, 2]],

        [[1, 0],
         [2, 2]],

        [[0, 2],
         [0, 2]],

        [[2, 1],
         [2, 1]],

        [[2, 2],
         [1, 0]],

        [[1, 1],
         [0, 0]],

        [[0, 2],
         [2, 1]],

        [[1, 1],
         [1, 0]],

        [[0, 2],
         [2, 0]]])