In [4]:
import os
import re
import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# Configuration for the first (vision + language + LSTM) agent.
# The visual encoder turns a (3, 128, 128) image into a (64, 7, 7) map, flattened to 3136.
vision_output_dim = 64 * 7 * 7
num_words = 44  # Number of unique words in the vocabulary
language_output_dim = embedding_dim = 128
mixing_dim = lstm_hidden_dim = 256
num_actions = 4

# Three-layer CNN: (3, 128, 128) image -> (64, 7, 7) feature map -> 3136-d vector.
class VisualModule(nn.Module):
    """Convolutional encoder for RGB frames.

    Every conv uses a 3x3 kernel and is followed by ReLU. For a 128x128 input the
    spatial size goes 128 -> 64 -> 22 -> 7 (strides 2, 3, 3; paddings 1, 1, 0).
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride, padding) for each conv, in order.
        conv_specs = ((3, 32, 2, 1), (32, 64, 3, 1), (64, 64, 3, 0))
        layers = []
        for in_ch, out_ch, stride, pad in conv_specs:
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=pad))
            layers.append(nn.ReLU())
        self.conv = nn.Sequential(*layers)
        # Unused deeper alternative, kept for reference:
        # self.conv = nn.Sequential(
        #     nn.Conv2d(3, 32, kernel_size=5, stride=2, padding=2),
        #     nn.ReLU(),
        #     nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2),
        #     nn.ReLU(),
        #     nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
        #     nn.ReLU(),
        #     nn.Conv2d(128, 64, kernel_size=5, stride=2, padding=1),
        #     nn.ReLU(),
        #     nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
        #     nn.ReLU(),
        #     nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
        #     nn.ReLU(),
        # )

    def forward(self, vt):
        """Flatten each sample's feature map; a batch of one collapses to a 1-D vector."""
        batch_size = vt.size(0)
        feature_map = self.conv(vt)
        flat = feature_map.view(batch_size, -1)
        # squeeze() removes the batch axis when it has size 1 -> shape (3136,)
        return flat.squeeze()

# Linear "embedding": a one-hot word vector such as [0 0 1 0 0] -> dense embedding_dim vector.
# NOTE(review): an earlier vocabulary count was S1:5 + S2:5 + S3:11 + S4:9 = 30 words plus
# 5 nouns = 35, but num_words is set to 44 in this notebook — confirm which size is current.
class LanguageModule(nn.Module):
    """Map a num_words-sized word vector to an embedding_dim vector with one nn.Linear."""

    def __init__(self, num_words, embedding_dim):
        super().__init__()
        self.embedding = nn.Linear(num_words, embedding_dim)

    def forward(self, lt):
        """Return ``self.embedding(lt)`` (no activation)."""
        return self.embedding(lt)

# Concatenate vision (3136) and language (128) features, project to mixing_dim (256) with one linear layer.
class MixingModule(nn.Module):
    """Fuse visual and language features with a single linear layer.

    Args:
        vision_output_dim: size of the last axis of the visual features.
        language_output_dim: size of the last axis of the language features.
        mixing_dim: output size.
    """

    def __init__(self, vision_output_dim, language_output_dim, mixing_dim):
        super(MixingModule, self).__init__()
        self.linear = nn.Linear(vision_output_dim + language_output_dim, mixing_dim)

    def forward(self, vision_output, language_output):
        """Concatenate along the feature (last) axis and project.

        Unbatched ``(V,)`` + ``(L,)`` -> ``(mixing_dim,)``;
        batched ``(B, V)`` + ``(B, L)`` -> ``(B, mixing_dim)``.
        """
        # dim=-1, not dim=0: identical for 1-D inputs, but dim=0 would concatenate along the
        # batch axis of 2-D inputs (and fail, since V != L).
        combined_output = torch.cat((vision_output, language_output), dim=-1)
        return self.linear(combined_output)

class LSTMModule(nn.Module):
    """Single-step recurrent core wrapping ``nn.LSTMCell``.

    ``forward`` returns the cell's ``(hidden_state, cell_state)`` tuple unchanged.
    """

    def __init__(self, mixing_dim, lstm_hidden_dim):
        super().__init__()
        self.lstm = nn.LSTMCell(mixing_dim, lstm_hidden_dim)

    def forward(self, mixed_output, lstm_hidden_state):
        """Advance the cell one step.

        Args:
            mixed_output: input for this step.
            lstm_hidden_state: previous ``(h, c)`` tuple, or ``None`` (LSTMCell then uses zeros).

        Returns:
            The new ``(h, c)`` tuple.
        """
        return self.lstm(mixed_output, lstm_hidden_state)

class Agent(nn.Module):
    """Actor-critic agent: vision + language -> mixing layer -> LSTM cell -> policy/value heads.

    Constructor arguments are forwarded to the sub-modules:
    ``num_words``/``embedding_dim`` -> LanguageModule,
    ``vision_output_dim``/``language_output_dim``/``mixing_dim`` -> MixingModule,
    ``mixing_dim``/``lstm_hidden_dim`` -> LSTMModule, ``num_actions`` -> action head.
    """

    def __init__(self, num_words, embedding_dim, vision_output_dim, language_output_dim, mixing_dim, lstm_hidden_dim,num_actions):
        super(Agent, self).__init__()
        # Sub-module creation order is unchanged: it fixes parameter-init order and state_dict keys.
        self.language_module = LanguageModule(num_words, embedding_dim)
        self.visual_module = VisualModule()
        self.mixing_module = MixingModule(vision_output_dim, language_output_dim, mixing_dim)
        self.lstm_module = LSTMModule(mixing_dim, lstm_hidden_dim)
        self.action_predictor = nn.Linear(lstm_hidden_dim, num_actions)
        self.value_estimator = nn.Linear(lstm_hidden_dim, 1)

    def forward(self, vt, lt, lstm_hidden_state):
        """Run one agent step.

        Args:
            vt: image input for VisualModule.
            lt: language input for LanguageModule.
            lstm_hidden_state: previous ``(h, c)`` tuple for the LSTM cell, or ``None``.

        Returns:
            ``(action_logits, value_estimate, (h, c))``. The action output is the raw linear
            head; no softmax is applied here.

        NOTE(review): ``unsqueeze(0)`` feeds the LSTM a batch of exactly one mixed vector, so
        this appears to support a single, unbatched sample per call — confirm with the caller.
        """
        vision_output = self.visual_module(vt)
        language_output = self.language_module(lt)
        mixed_output = self.mixing_module(vision_output, language_output).unsqueeze(0)
        lstm_output = self.lstm_module(mixed_output, lstm_hidden_state)
        hidden = lstm_output[0]  # (h, c)[0] -> hidden state
        action_logits = self.action_predictor(hidden)
        value_estimate = self.value_estimator(hidden)
        return action_logits, value_estimate, lstm_output

    def save(self, episode, ALG_NAME, ENV_ID):
        """Save the state_dict to ``model/<ALG_NAME>_<ENV_ID>/agent_<episode>.pt`` (relative to CWD)."""
        path = os.path.join('model', '_'.join([ALG_NAME, ENV_ID]))
        # exist_ok avoids the check-then-create race of `if not exists: makedirs`.
        os.makedirs(path, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(path, f'agent_{episode}.pt'))

    def load(self, episode, ALG_NAME, ENV_ID, map_location='cpu', skip_modules=()):
        """Load weights written by ``save``.

        Args:
            episode, ALG_NAME, ENV_ID: identify the checkpoint exactly as in ``save``.
            map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so a checkpoint
                saved on a GPU also loads on a CPU-only machine; ``load_state_dict`` copies the
                values onto whatever device this model's parameters already live on.
            skip_modules: name substrings (a single string is also accepted). Checkpoint keys
                containing any of them are dropped and the rest are loaded with ``strict=False``,
                e.g. ``skip_modules=('language_module',)`` reuses everything except the language
                encoder. Empty (default) means a normal strict load.
        """
        if isinstance(skip_modules, str):
            skip_modules = (skip_modules,)
        path = os.path.join('model', '_'.join([ALG_NAME, ENV_ID]))
        state_dict = torch.load(os.path.join(path, f'agent_{episode}.pt'), map_location=map_location)
        if skip_modules:
            state_dict = {
                key: value for key, value in state_dict.items()
                if not any(name in key for name in skip_modules)
            }
        self.load_state_dict(state_dict, strict=not skip_modules)

    def count_trainable_parameters(self):
        """Print the number of trainable parameters of every (sub)module, the root first."""
        print("Trainable parameters in each module:")
        for name, module in self.named_modules():
            total_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
            print(f"{name}: {total_params}")



In [5]:
agent = Agent(
    num_words=num_words,
    embedding_dim=embedding_dim,
    vision_output_dim=vision_output_dim,
    language_output_dim=language_output_dim,
    mixing_dim=mixing_dim,
    lstm_hidden_dim=lstm_hidden_dim,
    num_actions=num_actions,
)
agent.count_trainable_parameters()

Trainable parameters in each module:
: 1425541
language_module: 5760
language_module.embedding: 5760
visual_module: 56320
visual_module.conv: 56320
visual_module.conv.0: 896
visual_module.conv.1: 0
visual_module.conv.2: 18496
visual_module.conv.3: 0
visual_module.conv.4: 36928
visual_module.conv.5: 0
mixing_module: 835840
mixing_module.linear: 835840
lstm_module: 526336
lstm_module.lstm: 526336
action_predictor: 1028
value_estimator: 257


In [16]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F


# Cell configuration: vocabulary size (number of unique words), hidden width, number of actions.
num_words, lstm_hidden_dim, num_actions = 44, 256, 4

# (3, 128, 128) --> conv stack --> (64, 7, 7) = 3136 --> Linear + ReLU --> out_features (256)
class ImageEncoderCNN(nn.Module):
    """Encode RGB frames into feature vectors.

    Three 3x3 convolutions (strides 2, 3, 3) take a 128x128 image to a (64, 7, 7) map,
    which is flattened and passed through a ReLU-activated linear layer.

    Args:
        out_features: size of the output vector. Defaults to 256 (the previously fixed size).
    """

    # Per-image conv output shape for a (3, 128, 128) input; determines the flattened size 3136.
    _CONV_OUT_SHAPE = (64, 7, 7)

    def __init__(self, out_features=256):
        super(ImageEncoderCNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=3, padding=0),
            nn.ReLU()
        )
        self.fc = nn.Linear(64 * 7 * 7, out_features)

    def forward(self, x):
        """Encode a batch ``(N, 3, 128, 128)`` or a single image ``(3, 128, 128)``.

        Returns:
            ``(N, out_features)``; a single image yields ``(1, out_features)``.

        Raises:
            ValueError: if the per-image conv output is not ``(64, 7, 7)`` (wrong input size).
        """
        x = self.conv(x)
        # Validate before flattening: reshape(-1, 3136) alone would silently regroup features
        # across samples whenever the total element count happens to be a multiple of 3136.
        if tuple(x.shape[-3:]) != self._CONV_OUT_SHAPE:
            raise ValueError(
                f"expected per-image conv output {self._CONV_OUT_SHAPE}, got {tuple(x.shape)}"
            )
        # reshape (not view) also works if the conv output is non-contiguous (e.g. channels_last).
        x = x.reshape(-1, 64 * 7 * 7)
        return F.relu(self.fc(x))

class LanguageEncoderMLP(nn.Module):
    """Project an ``input_size`` language vector to ``hidden_size`` with one linear layer (no activation)."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)

    def forward(self, x):
        return self.fc1(x)

class KVLinear(nn.Module):
    """Produce attention keys and values from one input through two independent linear maps."""

    def __init__(self, hidden_size):
        super().__init__()
        # Creation order (key, then value) fixes parameter-init order and state_dict order.
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        """Return ``(keys, values)``, each shaped like ``x``."""
        return self.key(x), self.value(x)

class QLinear(nn.Module):
    """Linear query projection that also prepends a length-1 axis to its output."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.fc = nn.Linear(input_size, hidden_size)

    def forward(self, x):
        projected = self.fc(x)
        return projected[None]  # same as unsqueeze(0)

class Agent(nn.Module):
    """Attention-based actor-critic agent with an LSTM over image features.

    The image goes through ImageEncoderCNN and an LSTMCell; the LSTM hidden state provides
    keys/values and the encoded instruction provides the query of a single-head attention,
    whose output feeds an action head and a value head.

    Args:
        hidden_size: width of the LSTM, attention and projection layers.
            NOTE(review): ImageEncoderCNN() emits 256 features, so the LSTMCell input only
            matches when hidden_size == 256 — confirm before using other sizes.
        num_words: size of the language input vector (default 44, the previous fixed value).
        num_actions: number of action outputs (default 4, the previous fixed value).
    """

    # Default checkpoint root (the previous hard-coded location). Override per call with
    # ``model_dir=`` or globally by assigning ``Agent.MODEL_DIR``.
    MODEL_DIR = r'C:\Users\linzj\Desktop\model'

    def __init__(self, hidden_size, num_words=44, num_actions=4):
        super(Agent, self).__init__()
        # Creation order unchanged: it fixes parameter-init order and state_dict keys.
        self.image_encoder = ImageEncoderCNN()
        self.lstm = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
        self.language_encoder_mlp = LanguageEncoderMLP(input_size=num_words, hidden_size=hidden_size)
        self.keyvalue_linear = KVLinear(hidden_size)
        self.query_linear = QLinear(input_size=hidden_size, hidden_size=hidden_size)
        self.attention = nn.MultiheadAttention(hidden_size, num_heads=1)
        self.action_value_linear = nn.Linear(hidden_size, num_actions)
        self.value_linear = nn.Linear(hidden_size, 1)

    def forward(self, vt, lt, lstm_hiden_state):
        """Run one agent step.

        Args:
            vt: image input for ImageEncoderCNN.
            lt: language vector. NOTE(review): QLinear already prepends an axis and one more
                unsqueeze(0) follows, so a 1-D ``lt`` gives a 3-D query; a batched ``lt``
                would give 4-D — confirm callers pass a 1-D vector.
            lstm_hiden_state: previous ``(h, c)`` for the LSTMCell, or ``None``. (Parameter
                name kept as-is, typo included, for keyword-argument compatibility.)

        Returns:
            ``(actions, value, (h, c))``; ``actions`` are raw linear outputs (no softmax).
        """
        image_features = self.image_encoder(vt)
        lstm_output = self.lstm(image_features, lstm_hiden_state)
        language_features = self.language_encoder_mlp(lt)
        keys, values = self.keyvalue_linear(lstm_output[0])  # keys/values from the hidden state h
        # nn.MultiheadAttention defaults to (seq, batch, embed) inputs.
        querys = self.query_linear(language_features).unsqueeze(0)
        keys = keys.unsqueeze(0)
        values = values.unsqueeze(0)
        context_vector, _ = self.attention(querys, keys, values)
        context_vector = context_vector.squeeze(0)
        actions = self.action_value_linear(context_vector)
        value = self.value_linear(context_vector)
        return actions, value, lstm_output

    def _checkpoint_dir(self, ALG_NAME, ENV_ID, model_dir=None):
        """Return ``<model_dir or MODEL_DIR>/<ALG_NAME>_<ENV_ID>``."""
        root = self.MODEL_DIR if model_dir is None else model_dir
        return os.path.join(root, '_'.join([ALG_NAME, ENV_ID]))

    def save(self, episode, best_score, ALG_NAME, ENV_ID, model_dir=None):
        """Save the state_dict as ``agent_<episode>_<best_score>.pt`` under the checkpoint dir."""
        path = self._checkpoint_dir(ALG_NAME, ENV_ID, model_dir)
        os.makedirs(path, exist_ok=True)  # no check-then-create race
        torch.save(self.state_dict(), os.path.join(path, f'agent_{episode}_{best_score}.pt'))

    def load(self, episode, ALG_NAME, ENV_ID, best_score=None, model_dir=None, map_location='cpu'):
        """Load a checkpoint from the checkpoint dir.

        ``save`` names files ``agent_<episode>_<best_score>.pt``; pass the same ``best_score``
        here to load one of those. Without ``best_score`` the file ``agent_<episode>.pt`` is read.
        ``map_location`` is forwarded to ``torch.load`` (default ``'cpu'`` so GPU checkpoints load
        on CPU-only machines; ``load_state_dict`` copies onto this model's own device).
        """
        if best_score is None:
            file_name = f'agent_{episode}.pt'
        else:
            file_name = f'agent_{episode}_{best_score}.pt'
        path = self._checkpoint_dir(ALG_NAME, ENV_ID, model_dir)
        self.load_state_dict(torch.load(os.path.join(path, file_name), map_location=map_location))

    def count_trainable_parameters(self):
        """Print the number of trainable parameters of every (sub)module, the root first."""
        print("Trainable parameters in each module:")
        for name, module in self.named_modules():
            total_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
            print(f"{name}: {total_params}")

        


In [17]:
agent = Agent(hidden_size=lstm_hidden_dim)  # lstm_hidden_dim == 256 in this cell's config
agent.count_trainable_parameters()

Trainable parameters in each module:
: 1859077
image_encoder: 859392
image_encoder.conv: 56320
image_encoder.conv.0: 896
image_encoder.conv.1: 0
image_encoder.conv.2: 18496
image_encoder.conv.3: 0
image_encoder.conv.4: 36928
image_encoder.conv.5: 0
image_encoder.fc: 803072
lstm: 526336
language_encoder_mlp: 11520
language_encoder_mlp.fc1: 11520
keyvalue_linear: 131584
keyvalue_linear.key: 65792
keyvalue_linear.value: 65792
query_linear: 65792
query_linear.fc: 65792
attention: 263168
attention.out_proj: 65792
action_value_linear: 1028
value_linear: 257


In [18]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F


# Cell configuration: vocabulary size (number of unique words), hidden width, number of actions.
num_words, lstm_hidden_dim, num_actions = 44, 256, 4

# (3, 128, 128) --> conv stack --> (64, 7, 7) = 3136 --> Linear + ReLU --> out_features (256)
class ImageEncoderCNN(nn.Module):
    """Encode RGB frames into feature vectors.

    Three 3x3 convolutions (strides 2, 3, 3) take a 128x128 image to a (64, 7, 7) map,
    which is flattened and passed through a ReLU-activated linear layer.

    Args:
        out_features: size of the output vector. Defaults to 256 (the previously fixed size).
    """

    # Per-image conv output shape for a (3, 128, 128) input; determines the flattened size 3136.
    _CONV_OUT_SHAPE = (64, 7, 7)

    def __init__(self, out_features=256):
        super(ImageEncoderCNN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=3, padding=0),
            nn.ReLU()
        )
        self.fc = nn.Linear(64 * 7 * 7, out_features)

    def forward(self, x):
        """Encode a batch ``(N, 3, 128, 128)`` or a single image ``(3, 128, 128)``.

        Returns:
            ``(N, out_features)``; a single image yields ``(1, out_features)``.

        Raises:
            ValueError: if the per-image conv output is not ``(64, 7, 7)`` (wrong input size).
        """
        x = self.conv(x)
        # Validate before flattening: reshape(-1, 3136) alone would silently regroup features
        # across samples whenever the total element count happens to be a multiple of 3136.
        if tuple(x.shape[-3:]) != self._CONV_OUT_SHAPE:
            raise ValueError(
                f"expected per-image conv output {self._CONV_OUT_SHAPE}, got {tuple(x.shape)}"
            )
        # reshape (not view) also works if the conv output is non-contiguous (e.g. channels_last).
        x = x.reshape(-1, 64 * 7 * 7)
        return F.relu(self.fc(x))

class LanguageEncoderMLP(nn.Module):
    """Project an ``input_size`` language vector to ``hidden_size`` with one linear layer (no activation)."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)

    def forward(self, x):
        return self.fc1(x)

class KVLinear(nn.Module):
    """Produce attention keys and values from one input through two independent linear maps."""

    def __init__(self, hidden_size):
        super().__init__()
        # Creation order (key, then value) fixes parameter-init order and state_dict order.
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        """Return ``(keys, values)``, each shaped like ``x``."""
        return self.key(x), self.value(x)

class QLinear(nn.Module):
    """Linear query projection that also prepends a length-1 axis to its output."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.fc = nn.Linear(input_size, hidden_size)

    def forward(self, x):
        projected = self.fc(x)
        return projected[None]  # same as unsqueeze(0)

class Agent(nn.Module):
    """Attention-based actor-critic agent without recurrence.

    Keys/values come from the encoded image and the query from the encoded instruction; the
    single-head attention output feeds an action head and a value head.

    Args:
        hidden_size: width of the attention and projection layers.
            NOTE(review): ImageEncoderCNN() emits 256 features and KVLinear(hidden_size)
            consumes them, so this only runs with hidden_size == 256 — confirm before changing.
        num_words: size of the language input vector (default 44, the previous fixed value).
        num_actions: number of action outputs (default 4, the previous fixed value).
    """

    # Default checkpoint root (the previous hard-coded location). Override per call with
    # ``model_dir=`` or globally by assigning ``Agent.MODEL_DIR``.
    MODEL_DIR = r'C:\Users\linzj\Desktop\model'

    def __init__(self, hidden_size, num_words=44, num_actions=4):
        super(Agent, self).__init__()
        # Creation order unchanged: it fixes parameter-init order and state_dict keys.
        self.image_encoder = ImageEncoderCNN()
        self.language_encoder_mlp = LanguageEncoderMLP(input_size=num_words, hidden_size=hidden_size)
        self.keyvalue_linear = KVLinear(hidden_size)
        self.query_linear = QLinear(input_size=hidden_size, hidden_size=hidden_size)
        self.attention = nn.MultiheadAttention(hidden_size, num_heads=1)
        self.action_value_linear = nn.Linear(hidden_size, num_actions)
        self.value_linear = nn.Linear(hidden_size, 1)

    def forward(self, vt, lt):
        """Run one agent step.

        Args:
            vt: image input for ImageEncoderCNN.
            lt: language vector. NOTE(review): QLinear already prepends an axis and one more
                unsqueeze(0) follows, so a 1-D ``lt`` gives a 3-D query; a batched ``lt``
                would give 4-D — confirm callers pass a 1-D vector.

        Returns:
            ``(actions, value)``; ``actions`` are raw linear outputs (no softmax).
        """
        image_features = self.image_encoder(vt)
        language_features = self.language_encoder_mlp(lt)
        keys, values = self.keyvalue_linear(image_features)
        # nn.MultiheadAttention defaults to (seq, batch, embed) inputs.
        querys = self.query_linear(language_features).unsqueeze(0)
        keys = keys.unsqueeze(0)
        values = values.unsqueeze(0)
        context_vector, _ = self.attention(querys, keys, values)
        context_vector = context_vector.squeeze(0)
        actions = self.action_value_linear(context_vector)
        value = self.value_linear(context_vector)
        return actions, value

    def _checkpoint_dir(self, ALG_NAME, ENV_ID, model_dir=None):
        """Return ``<model_dir or MODEL_DIR>/<ALG_NAME>_<ENV_ID>``."""
        root = self.MODEL_DIR if model_dir is None else model_dir
        return os.path.join(root, '_'.join([ALG_NAME, ENV_ID]))

    def save(self, episode, best_score, ALG_NAME, ENV_ID, model_dir=None):
        """Save the state_dict as ``agent_<episode>_<best_score>.pt`` under the checkpoint dir."""
        path = self._checkpoint_dir(ALG_NAME, ENV_ID, model_dir)
        os.makedirs(path, exist_ok=True)  # no check-then-create race
        torch.save(self.state_dict(), os.path.join(path, f'agent_{episode}_{best_score}.pt'))

    def load(self, episode, ALG_NAME, ENV_ID, best_score=None, model_dir=None, map_location='cpu'):
        """Load a checkpoint from the checkpoint dir.

        ``save`` names files ``agent_<episode>_<best_score>.pt``; pass the same ``best_score``
        here to load one of those. Without ``best_score`` the file ``agent_<episode>.pt`` is read.
        ``map_location`` is forwarded to ``torch.load`` (default ``'cpu'`` so GPU checkpoints load
        on CPU-only machines; ``load_state_dict`` copies onto this model's own device).
        """
        if best_score is None:
            file_name = f'agent_{episode}.pt'
        else:
            file_name = f'agent_{episode}_{best_score}.pt'
        path = self._checkpoint_dir(ALG_NAME, ENV_ID, model_dir)
        self.load_state_dict(torch.load(os.path.join(path, file_name), map_location=map_location))

    def count_trainable_parameters(self):
        """Print the number of trainable parameters of every (sub)module, the root first."""
        print("Trainable parameters in each module:")
        for name, module in self.named_modules():
            total_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
            print(f"{name}: {total_params}")

        


In [19]:
agent = Agent(hidden_size=lstm_hidden_dim)  # lstm_hidden_dim == 256 in this cell's config
agent.count_trainable_parameters()

Trainable parameters in each module:
: 1332741
image_encoder: 859392
image_encoder.conv: 56320
image_encoder.conv.0: 896
image_encoder.conv.1: 0
image_encoder.conv.2: 18496
image_encoder.conv.3: 0
image_encoder.conv.4: 36928
image_encoder.conv.5: 0
image_encoder.fc: 803072
language_encoder_mlp: 11520
language_encoder_mlp.fc1: 11520
keyvalue_linear: 131584
keyvalue_linear.key: 65792
keyvalue_linear.value: 65792
query_linear: 65792
query_linear.fc: 65792
attention: 263168
attention.out_proj: 65792
action_value_linear: 1028
value_linear: 257
