In [88]:
import numpy as np
import torch
import torch.nn as nn

In [89]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [90]:
class Encoder(nn.Module):
    def __init__(self, hidden_size):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(4)
        self.conv2 = nn.Conv2d(4, 8, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(8)
        self.conv3 = nn.Conv2d(8, 16, kernel_size=4)
        self.bn3 = nn.BatchNorm2d(16)
        self.fc = nn.Linear(16, hidden_size)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))
        x = torch.relu(self.bn3(self.conv3(x)))
        x = self.fc(x)
        return x


class Decoder(nn.Module):
    def __init__(self, hidden_size):
        super(Decoder, self).__init__()        
        self.fc = nn.Linear(hidden_size, 16)
        self.conv1 = nn.ConvTranspose2d(16, 8, kernel_size=4)
        self.bn1 = nn.BatchNorm2d(8)
        self.conv2 = nn.ConvTranspose2d(8, 4, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(4)
        self.conv3 = nn.ConvTranspose2d(4, 1, kernel_size=3)

    def forward(self, x):
        x = self.fc(x)
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))
        x = torch.sigmoid(self.conv3(x))
        return x


class ConvAutoencoder(nn.Module):
    def __init__(self, hidden_size):
        super(ConvAutoencoder, self).__init__()
        self.encoder = Encoder(hidden_size)
        self.decoder = Decoder(hidden_size)

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def save_encoder(self, path):
        torch.save(self.encoder.state_dict(), path)

In [91]:
num_samples = 1000
width = 212
height = 129

In [92]:
x_data = torch.randn(num_samples, width, height)
y_data = torch.ones(num_samples)

x_data = x_data.view(-1, 1, width, height)
y_data = y_data.view(-1, 1)

train_loader = torch.utils.data.DataLoader(dataset=list(zip(x_data, y_data)), batch_size=128, shuffle=True)

In [93]:
autoencoder = ConvAutoencoder(hidden_size=20).to(device)

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-3, weight_decay=1e-5)
criterion = nn.MSELoss()

In [94]:
num_epochs = 50

autoencoder.train()
for epoch in range(num_epochs):
    for data in train_loader:
        img, _ = data
        img = img.to(device)
        print(img.shape)
        output = autoencoder(img)
        loss = criterion(output, img)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'epoch [{epoch + 1}/{num_epochs}], loss:{loss.item()}')

autoencoder.save_encoder('./models/encoder.pth')

torch.Size([128, 1, 212, 129])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (419840x122 and 16x20)