In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class CNNEncoder(nn.Module):
    """Per-frame convolutional encoder: image -> flat feature vector.

    Three 3x3, stride-2, padding-1 convolutions (16, 32, 64 channels, ReLU after
    each) halve the spatial size at every stage, rounding up
    (``H -> ceil(H / 2)``). A linear layer then projects the flattened maps to
    ``feature_dim``.

    Args:
        input_shape: ``(C, H, W)`` of a single frame. ``C`` sets the first
            conv's input channels.
        feature_dim: size of the output feature vector.

    Attributes:
        conv_shape: ``torch.Size`` of the conv-stack output for one frame,
            ``(64, H', W')``.
        num_flat_features: number of elements in ``conv_shape``, i.e. the
            input width of ``self.fc``.
    """

    def __init__(self, input_shape, feature_dim):
        super(CNNEncoder, self).__init__()
        in_channels = input_shape[0]
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU()
        )
        # Populates self.conv_shape / self.num_flat_features, needed to size fc.
        self._get_conv_output(input_shape)
        self.fc = nn.Linear(self.num_flat_features, feature_dim)

    def _get_conv_output(self, shape):
        """Probe the conv stack with one frame of ``shape`` and record its output size.

        Sets ``self.conv_shape`` and ``self.num_flat_features``. The output
        size depends only on the input shape, so a zero tensor is sufficient.
        This also avoids consuming the global RNG during construction.
        """
        probe = torch.zeros(1, *shape)
        with torch.no_grad():
            probe_out = self.conv(probe)
        self.conv_shape = probe_out.shape[1:]
        self.num_flat_features = probe_out[0].numel()

    def forward(self, x):
        """Encode a batch of frames ``(N, C, H, W)`` into ``(N, feature_dim)``."""
        feature_maps = self.conv(x)
        return self.fc(torch.flatten(feature_maps, start_dim=1))


class CNNDecoder(nn.Module):
    """Decode a latent vector into an image with an MLP and transposed convolutions.

    The MLP (``feature_dim -> 128 -> prod(conv_shape)``, ELU in between) output
    is reshaped to ``conv_shape``. Three 3x3, stride-2, padding-1,
    output_padding-1 transposed convolutions then each double the spatial
    size, so an ``(C_in, h, w)`` grid becomes ``(out_channels, 8h, 8w)``.

    Args:
        feature_dim: size of the input latent vector.
        conv_shape: ``(C_in, h, w)`` grid the MLP output is reshaped to.
            ``C_in`` sets the first transposed conv's input channels.
        out_channels: channels of the reconstructed image (default 1).
    """

    def __init__(self, feature_dim, conv_shape, out_channels=1):
        super(CNNDecoder, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ELU(),
            nn.Linear(128, torch.Size(conv_shape).numel())
        )
        self.conv_shape = conv_shape
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(conv_shape[0], 32, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, out_channels, 3, stride=2, padding=1, output_padding=1)
        )

    def forward(self, x):
        """Decode ``(N, feature_dim)`` into ``(N, out_channels, 8h, 8w)``."""
        x = self.fc(x)
        x = x.view(x.size(0), *self.conv_shape)
        return self.deconv(x)


class CNN_LSTM_Autoencoder(nn.Module):
    """Sequence autoencoder: per-frame CNN encoder -> LSTM -> per-frame CNN decoder.

    Pipeline:
    1. Every frame is encoded to a 128-d feature vector.
    2. The feature sequence runs through a single-layer LSTM whose per-step
       output (hidden state, ``latent_dim`` wide) is the latent trajectory.
       Being an LSTM output, it lies in ``(-1, 1)``.
    3. Every latent vector is decoded back to an image.

    Args:
        input_shape: ``(C, H, W)`` of a single frame.
        latent_dim: width of the LSTM hidden state, i.e. of the latent trajectory.
        lstm_hidden: currently unused; kept so existing calls keep working.
            NOTE(review): the LSTM width is ``latent_dim``, not this value.
    """

    def __init__(self, input_shape, latent_dim=2, lstm_hidden=64):
        super(CNN_LSTM_Autoencoder, self).__init__()
        feature_dim = 128
        self.encoder = CNNEncoder(input_shape, feature_dim=feature_dim)
        self.lstm = nn.LSTM(input_size=feature_dim, hidden_size=latent_dim, batch_first=True)
        # CNNEncoder already probes its conv stack in __init__ and records the
        # pre-flatten shape, so no extra dummy forward pass is needed here.
        self.decoder = CNNDecoder(latent_dim, self.encoder.conv_shape)

    def forward(self, x_seq):
        """Reconstruct a sequence of frames and return its latent trajectory.

        Args:
            x_seq: ``(B, T, C, H, W)`` batch of frame sequences, or an
                unbatched ``(T, C, H, W)`` sequence (treated as ``B = 1``).

        Returns:
            ``(decoded_seq, latent_seq)``. ``decoded_seq`` has the same
            spatial size as the input. ``latent_seq`` is
            ``(B, T, latent_dim)``. For unbatched input the leading batch
            dimension is dropped from both outputs.

        Raises:
            ValueError: if ``x_seq`` is neither 4-D nor 5-D.
        """
        unbatched = x_seq.dim() == 4
        if unbatched:
            x_seq = x_seq.unsqueeze(0)
        if x_seq.dim() != 5:
            raise ValueError(
                "expected input of shape (B, T, C, H, W) or (T, C, H, W), "
                f"got {tuple(x_seq.shape)}"
            )
        B, T, C, H, W = x_seq.shape

        # Fold time into the batch so each CNN processes all frames in one call.
        encoded = self.encoder(x_seq.reshape(B * T, C, H, W))   # (B*T, 128)
        encoded_seq = encoded.view(B, T, -1)                      # (B, T, 128)

        latent_seq, _ = self.lstm(encoded_seq)                    # (B, T, latent_dim)

        decoded = self.decoder(latent_seq.reshape(B * T, -1))     # (B*T, C', 8*ceil(H/8), 8*ceil(W/8))
        # The stride-2 encoder rounds odd sizes up, so the decoder output can be
        # larger than the input (e.g. 556 -> 560). Crop back to (H, W) so the
        # reconstruction is comparable with the input.
        decoded = decoded[..., :H, :W]
        decoded_seq = decoded.reshape(B, T, *decoded.shape[1:])   # (B, T, C', H, W)

        if unbatched:
            return decoded_seq.squeeze(0), latent_seq.squeeze(0)
        return decoded_seq, latent_seq


In [None]:
def train_cnn_lstm_autoencoder(model, images_tensor, num_epochs=50, lr=1e-3, verbose=True):
    """Full-batch training of a reconstruction model with MSE loss and Adam.

    Each epoch is a single optimizer step on the whole ``images_tensor``; there
    is no mini-batching.

    Args:
        model: module whose forward returns ``(reconstruction, latent_seq)``.
        images_tensor: input fed to ``model``; also the reconstruction target.
        num_epochs: number of optimizer steps (may be 0).
        lr: Adam learning rate.
        verbose: if True, print the loss after every epoch.

    Returns:
        ``(losses, latent_seq)``. ``losses`` is a list of per-epoch loss values.
        ``latent_seq`` is the detached latent output of a final no-grad
        forward pass with the trained weights.

    Raises:
        ValueError: if the reconstruction shape differs from ``images_tensor``'s.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    model.train()
    losses = []

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        recon, _ = model(images_tensor)
        if recon.shape != images_tensor.shape:
            # MSELoss would otherwise broadcast (silently wrong) or crash.
            raise ValueError(
                f"reconstruction shape {tuple(recon.shape)} does not match "
                f"input shape {tuple(images_tensor.shape)}"
            )
        loss = loss_fn(recon, images_tensor)
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        losses.append(loss_value)
        if verbose:
            print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss_value:.6f}")

    # The last in-loop forward ran *before* the final optimizer step, so recompute
    # the latents with the trained weights. This also defines them when num_epochs == 0.
    with torch.no_grad():
        _, latent_seq = model(images_tensor)
    return losses, latent_seq.detach()

def plot_losses(losses):
    """Draw the per-epoch reconstruction loss curve on the current axes and show it.

    Args:
        losses: sequence of loss values, one per epoch.
    """
    from matplotlib import pyplot as plt

    ax = plt.gca()
    ax.plot(losses)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Reconstruction Loss')
    ax.set_title('Training Loss over Time')
    ax.grid(True)
    plt.show()

In [None]:
# --- Configuration ------------------------------------------------------------
# NOTE(review): absolute, machine-specific path — point this at your copy of the data.
IMAGE_TENSORS_PATH = '/Users/karim/desktop/eece499/TCN_SINDy/data_processing/image_tensors.pt'
SEED = 0
LATENT_DIM = 2
NUM_EPOCHS = 50
LR = 1e-3
LATENT_LABELS = ['Latent x', 'Latent v']

torch.manual_seed(SEED)

# --- Data ---------------------------------------------------------------------
# torch.load unpickles the file: only load tensors from a trusted source.
images_tensor = torch.load(IMAGE_TENSORS_PATH)  # presumably (T, C, H, W) — TODO confirm
# The model expects (B, T, C, H, W); treat a single (T, C, H, W) sequence as B = 1.
images_seq = images_tensor.unsqueeze(0) if images_tensor.dim() == 4 else images_tensor
assert images_seq.dim() == 5, (
    f"expected (T, C, H, W) or (B, T, C, H, W), got {tuple(images_tensor.shape)}"
)

# --- Model & training ---------------------------------------------------------
# Infer the per-frame (C, H, W) from the data instead of hard-coding it.
model = CNN_LSTM_Autoencoder(input_shape=tuple(images_seq.shape[2:]), latent_dim=LATENT_DIM)
losses, latent_seq = train_cnn_lstm_autoencoder(model, images_seq, num_epochs=NUM_EPOCHS, lr=LR)
plot_losses(losses)

# --- Latent trajectories ------------------------------------------------------
latent_seq_np = latent_seq[0].cpu().numpy()  # first sequence -> (T, latent_dim)
fig, ax = plt.subplots()
for dim in range(latent_seq_np.shape[1]):
    label = LATENT_LABELS[dim] if dim < len(LATENT_LABELS) else f'Latent {dim}'
    ax.plot(latent_seq_np[:, dim], label=label)
ax.set_title("Learned Latent Trajectories")
ax.set_xlabel("Time Step")
ax.set_ylabel("Latent Value")
ax.legend()
ax.grid(True)
plt.show()