Skip to content

Commit

Permalink
Update models.py
Browse files Browse the repository at this point in the history
change mlp layer name
  • Loading branch information
ChrisZonghaoLi authored Aug 23, 2023
1 parent d90f891 commit f4977bb
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions python/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,10 +287,10 @@ def forward(self, state):
actions = torch.tensor(()).to(device)
for i in range(batch_size):
x = state[i]
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = F.relu(self.conv4(x))
x = F.relu(self.mlp1(x))
x = F.relu(self.mlp2(x))
x = F.relu(self.mlp3(x))
x = F.relu(self.mlp4(x))
x = self.lin1(torch.flatten(x))
x = torch.tanh(x).reshape(1, -1)
actions = torch.cat((actions, x), axis=0)
Expand Down Expand Up @@ -325,10 +325,10 @@ def forward(self, state, action):
values = torch.tensor(()).to(device)
for i in range(batch_size):
x = data[i]
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = F.relu(self.conv4(x))
x = F.relu(self.mlp1(x))
x = F.relu(self.mlp2(x))
x = F.relu(self.mlp3(x))
x = F.relu(self.mlp4(x))
x = self.lin1(torch.flatten(x)).reshape(1, -1)
values = torch.cat((values, x), axis=0)

Expand Down

0 comments on commit f4977bb

Please sign in to comment.