In [7]:
import torch
from torch import nn
from torch.nn import functional as F

In [8]:
# 1. Train an Encoder for Environment States
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.fc = nn.Linear(64*4*4, 256)  

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  
        x = self.fc(x)
        return x

In [9]:
# 2. Train an Inverse Model (MLP)
class InverseModel(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(InverseModel, self).__init__()
        self.fc1 = nn.Linear(state_dim * 2, 256)
        self.fc2 = nn.Linear(256, action_dim)

    def forward(self, state, next_state):
        x = torch.cat([state, next_state], dim=1)  
        x = F.relu(self.fc1(x))
        action_pred = self.fc2(x)
        return action_pred

In [10]:
# 3. Modify the Forward Model
class ForwardModel(nn.Module):
    def __init__(self, state_dim, action_dim, num_frames):
        super(ForwardModel, self).__init__()
        self.fc1 = nn.Linear(state_dim * num_frames + action_dim, 256)
        self.fc2 = nn.Linear(256, state_dim)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=1) 
        x = F.relu(self.fc1(x))
        next_state_pred = self.fc2(x)
        return next_state_pred

In [23]:
# 4. Integrate ICM with DQN Algorithm
class ICM(nn.Module):
    def __init__(self, state_dim, action_dim, num_frames):
        super(ICM, self).__init__()
        self.encoder = Encoder()
        self.inverse_model = InverseModel(state_dim, action_dim)
        self.forward_model = ForwardModel(state_dim, action_dim, num_frames)

    def forward(self, state, action, next_state):
        state = self.encoder(state)
        next_state = self.encoder(next_state)
        action_pred = self.inverse_model(state, next_state)
        next_state_pred = self.forward_model(state, action)
        return action_pred, next_state_pred

class DQN(nn.Module):
    def __init__(self, state_dim, action_dim, num_frames):
        super(DQN, self).__init__()
        self.icm = ICM(state_dim, action_dim, num_frames) 
        self.q_network = nn.Linear(state_dim, action_dim) 

    def forward(self, state, action): 
        action_pred, next_state_pred = self.icm(state, action, state) 
        q_values = self.q_network(state)
        return action_pred, next_state_pred, q_values

In [26]:
def train_icm_dqn(dqn, icm, optimizer, criterion, dataloader, epochs):
    for epoch in range(epochs):
        for state, action, next_state in dataloader:
            action_pred, next_state_pred, q_values = dqn(state, action)

            inverse_loss = criterion(action_pred, action)
            forward_loss = criterion(next_state_pred, next_state)
            icm_loss = inverse_loss + forward_loss

            action_taken = action.argmax(dim=1)
            q_value = q_values.gather(1, action_taken.unsqueeze(1)).squeeze(1)
            dqn_loss = criterion(q_value, action)

            loss = icm_loss + dqn_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item()}')


Tried testing the model but a few unfixed errors while reshaping the data

In [27]:
import torch
from torch.utils.data import DataLoader, TensorDataset

states = torch.randn(100, 1)  
actions = torch.randn(100, 1)  
next_states = torch.randn(100, 1) 

# Create a DataLoader
dataset = TensorDataset(states, actions, next_states)
dataloader = DataLoader(dataset, batch_size=32)

state_dim = 1
action_dim = 1
num_frames = 4  
encoder = Encoder()
inverse_model = InverseModel(state_dim, action_dim)
forward_model = ForwardModel(state_dim, action_dim, num_frames)
icm = ICM(state_dim, action_dim, num_frames)
dqn = DQN(state_dim, action_dim, num_frames)

optimizer = torch.optim.Adam(dqn.parameters())
criterion = torch.nn.MSELoss()

train_icm_dqn(dqn, icm, optimizer, criterion, dataloader, epochs=10)


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [32, 1]