In [None]:
# models.py
import torch
import torch.nn as nn
import torch.nn.functional as F

class Actor(nn.Module):
    """Convolutional network mapping an image batch to bounded actions.

    Three 3x3, stride-2 convolutions (32, 64, 64 channels), each followed by
    ReLU, are flattened and passed through a two-layer MLP. The final ``Tanh``
    keeps every action component in [-1, 1].

    Args:
        img_shape: ``(channels, height, width)`` of a single input image.
        action_dim: Number of action components produced per sample.
    """

    def __init__(self, img_shape, action_dim):
        super().__init__()
        c, h, w = img_shape
        self.cnn = nn.Sequential(
            nn.Conv2d(c, 32, 3, stride=2), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=2), nn.ReLU(),
            nn.Flatten()
        )
        # Probe the conv stack with one dummy image to learn the flattened
        # feature size. This is purely a shape query, so run it without
        # autograd to avoid building a throwaway graph at construction time.
        with torch.no_grad():
            n_flat = self.cnn(torch.zeros(1, c, h, w)).shape[1]
        self.fc = nn.Sequential(
            nn.Linear(n_flat, 256), nn.ReLU(),
            nn.Linear(256, action_dim), nn.Tanh()  # action in [-1, 1]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return actions of shape ``(batch, action_dim)`` in [-1, 1].

        Args:
            x: Image batch of shape ``(batch, channels, height, width)`` with
                the same per-image shape that was given as ``img_shape``.
        """
        x = self.cnn(x)
        return self.fc(x)

class Critic(nn.Module):
    """Convolutional network scoring an (image, action) pair with one scalar.

    The image goes through three 3x3, stride-2 convolutions (32, 64, 64
    channels) with ReLU and is flattened. The action vector is concatenated
    onto those features, and a two-layer MLP produces a single unbounded
    output per sample.

    Args:
        img_shape: ``(channels, height, width)`` of a single input image.
        action_dim: Length of the action vector appended to the image features.
    """

    def __init__(self, img_shape, action_dim):
        super().__init__()
        c, h, w = img_shape
        self.cnn = nn.Sequential(
            nn.Conv2d(c, 32, 3, stride=2), nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=2), nn.ReLU(),
            nn.Flatten()
        )
        # Probe the conv stack with one dummy image to learn the flattened
        # feature size. This is purely a shape query, so run it without
        # autograd to avoid building a throwaway graph at construction time.
        with torch.no_grad():
            n_flat = self.cnn(torch.zeros(1, c, h, w)).shape[1]
        self.fc = nn.Sequential(
            # The first layer consumes image features and the action together.
            nn.Linear(n_flat + action_dim, 256), nn.ReLU(),
            nn.Linear(256, 1)
        )

    def forward(self, x: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """Return one value per sample, shape ``(batch, 1)``.

        Args:
            x: Image batch of shape ``(batch, channels, height, width)`` with
                the same per-image shape that was given as ``img_shape``.
            a: Action batch of shape ``(batch, action_dim)``; it is joined to
                the flattened image features along dimension 1.
        """
        x = self.cnn(x)
        x = torch.cat([x, a], dim=1)
        return self.fc(x)
