In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Encoder(nn.Module):
    """Convolutional encoder mapping an RGB image to a flat code vector.

    The layer sizes are fixed for 120x120 inputs. Three stride-2 convolutions
    halve the spatial size each time (120 -> 60 -> 30 -> 15), and the final
    linear layer expects exactly ``64 * 15 * 15`` features.

    Args:
        code_dim: Size of the output code vector.
    """

    def __init__(self, code_dim=32):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 4, stride=2, padding=1)  # -> 60x60
        self.conv2 = nn.Conv2d(16, 32, 4, stride=2, padding=1)  # -> 30x30
        self.conv3 = nn.Conv2d(32, 64, 4, stride=2, padding=1)  # -> 15x15
        self.fc = nn.Linear(64 * 15 * 15, code_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Image batch of shape ``(N, 3, 120, 120)``.

        Returns:
            Code tensor of shape ``(N, code_dim)``.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # flatten keeps the batch dim and merges the rest. Unlike
        # view(N, -1), it also works when N == 0, where view cannot infer -1.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a code vector back to an image.

    Mirrors ``Encoder``. A linear layer expands the code to a 64x15x15
    feature map. Three stride-2 transposed convolutions then double the
    spatial size each time (15 -> 30 -> 60 -> 120), ending in 3 channels.

    Args:
        code_dim: Size of the input code vector.
    """

    def __init__(self, code_dim=32):
        super().__init__()
        # Attribute names and creation order are kept: they define the
        # state_dict keys and the parameter order.
        self.fc = nn.Linear(code_dim, 64 * 15 * 15)
        self.deconv1 = nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1)  # -> 30x30
        self.deconv2 = nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1)  # -> 60x60
        self.deconv3 = nn.ConvTranspose2d(16, 3, 4, stride=2, padding=1)  # -> 120x120

    def forward(self, code):
        """Decode code vectors into images.

        Args:
            code: Code tensor whose last dimension is ``code_dim``.

        Returns:
            Image tensor of shape ``(N, 3, 120, 120)`` with values in [0, 1].
        """
        features = F.relu(self.fc(code)).view(-1, 64, 15, 15)
        for upsample in (self.deconv1, self.deconv2):
            features = F.relu(upsample(features))
        # Sigmoid squashes the output into [0, 1], the pixel range for images.
        return torch.sigmoid(self.deconv3(features))

class Autoencoder(nn.Module):
    """Encoder/decoder pair that reconstructs its input through a code bottleneck.

    Args:
        code_dim: Size of the bottleneck code, shared by both halves.
    """

    def __init__(self, code_dim=32):
        super().__init__()
        self.encoder = Encoder(code_dim)
        self.decoder = Decoder(code_dim)

    def forward(self, x):
        """Return the reconstruction ``decoder(encoder(x))`` of input ``x``."""
        return self.decoder(self.encoder(x))