In [None]:
dataset_train, dataset_val, alpha_interval_split, tau_interval_split = load_from_path("data")

# Pull the training split's arrays off the dataset object, one per name.
X = dataset_train.X
X_tau = dataset_train.X_tau
t_values = dataset_train.t_values
tau_values = dataset_train.tau_values
alpha_values = dataset_train.alpha_values
print(X.shape)

In [None]:
## Variational autoencoder only: encode x(t) and reconstruct it (X_tau is not used in this section)
import torch.nn as nn
# Normalization Layer for Conv2D
class Norm(nn.Module):
    """Thin wrapper around ``nn.GroupNorm`` used after the GELU activations.

    Args:
        num_channels: Number of channels C of the (N, C, ...) input; GroupNorm
            requires it to be divisible by ``num_groups``.
        num_groups: Number of channel groups (default 4).
    """

    def __init__(self, num_channels, num_groups=4):
        super().__init__()
        # The attribute stays named ``norm`` so state_dict keys remain "norm.weight"/"norm.bias".
        self.norm = nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)

    def forward(self, x):
        """Apply group normalization to ``x``."""
        normalized = self.norm(x)
        return normalized

# Encoder using Conv2D
class Encoder(nn.Module):
    """Convolutional VAE encoder: single-channel square image -> (mean, log_var).

    Five ``Conv2d(kernel_size=2, stride=2, padding=0)`` stages each halve the
    spatial size (floor division). An ``S x S`` input therefore reaches the
    linear heads as a ``512 x (S // 32) x (S // 32)`` feature map.

    Args:
        latent_dim: Size of the latent vector.
        input_size: Height/width ``S`` of the square input image. Defaults to
            128, which gives the original 512 * 4 * 4 flattened features.

    Raises:
        ValueError: if ``input_size`` is smaller than 32 (no spatial extent left).
    """

    def __init__(self, latent_dim=3, input_size=128):
        super(Encoder, self).__init__()
        # Five stride-2 stages -> total downsampling factor 2**5 = 32.
        spatial = input_size // 32
        if spatial < 1:
            raise ValueError(f"input_size must be at least 32, got {input_size}")
        self.conv_layers = nn.Sequential(
            # Input: (batch_size, 1, S, S), S = input_size (128 by default)
            nn.Conv2d(1, 32, kernel_size=2, stride=2, padding=0),  # (batch_size, 32, S/2, S/2)
            nn.GELU(),
            Norm(32),
            nn.Conv2d(32, 64, kernel_size=2, stride=2, padding=0),  # (batch_size, 64, S/4, S/4)
            nn.GELU(),
            Norm(64),
            nn.Conv2d(64, 128, kernel_size=2, stride=2, padding=0),  # (batch_size, 128, S/8, S/8)
            nn.GELU(),
            Norm(128),
            nn.Conv2d(128, 256, kernel_size=2, stride=2, padding=0),  # (batch_size, 256, S/16, S/16)
            nn.GELU(),
            Norm(256),
            nn.Conv2d(256, 512, kernel_size=2, stride=2, padding=0),  # (batch_size, 512, S/32, S/32)
            nn.GELU(),
            Norm(512),
        )
        self.flatten = nn.Flatten()
        # Both heads read the same flattened 512 * (S/32)**2 feature vector.
        self.fc_mean = nn.Linear(512 * spatial * spatial, latent_dim)
        self.fc_log_var = nn.Linear(512 * spatial * spatial, latent_dim)

    def forward(self, x):
        """Return ``(mean, log_var)``, each of shape (batch_size, latent_dim)."""
        x = self.conv_layers(x)
        x = self.flatten(x)
        mean = self.fc_mean(x)
        log_var = self.fc_log_var(x)
        return mean, log_var

In [None]:
# Shape sanity check: a random 128x128 single-channel batch through a fresh encoder.
x_1 = torch.rand(size=(32, 1, 128, 128), dtype=torch.float32)
print(x_1.shape)

encoder = Encoder(latent_dim=3)
with torch.no_grad():  # inspection only; no autograd graph needed
    mean, log_var = encoder(x_1)

print("mean shape", mean.shape)
# Fail loudly instead of relying on eyeballing the printed shape.
assert mean.shape == (32, 3) and log_var.shape == (32, 3), (mean.shape, log_var.shape)

In [None]:
class Decoder(nn.Module):
    """Decodes a latent vector into a single-channel square image.

    ``fc`` maps ``z`` to a (512, S/32, S/32) seed feature map. Five stages of
    bilinear 2x upsampling followed by a 1x1 ``Conv2d`` grow it to (1, S, S),
    where ``S = output_size``. The final ReLU keeps outputs non-negative.

    Note: the 1x1 convolutions only mix channels. Spatial detail beyond what the
    linear layer encodes comes solely from bilinear interpolation.

    Args:
        latent_dim: Size of the latent vector ``z``.
        output_size: Height/width ``S`` of the output image; must be a positive
            multiple of 32. Defaults to 128, i.e. the original 4x4 seed map.

    Raises:
        ValueError: if ``output_size`` is not a positive multiple of 32.
    """

    def __init__(self, latent_dim=3, output_size=128):
        super(Decoder, self).__init__()
        if output_size < 32 or output_size % 32 != 0:
            raise ValueError(f"output_size must be a positive multiple of 32, got {output_size}")
        # Side length of the seed map; five 2x upsamplings bring it to output_size.
        self._seed_size = output_size // 32
        # Fully connected layer: latent vector -> (batch_size, 512, S/32, S/32) once reshaped in forward()
        self.fc = nn.Linear(latent_dim, 512 * self._seed_size * self._seed_size)

        self.deconv_layers = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # -> S/16
            nn.Conv2d(512, 256, kernel_size=1),
            nn.GELU(),
            Norm(256),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # -> S/8
            nn.Conv2d(256, 128, kernel_size=1),
            nn.GELU(),
            Norm(128),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # -> S/4
            nn.Conv2d(128, 64, kernel_size=1),
            nn.GELU(),
            Norm(64),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # -> S/2
            nn.Conv2d(64, 32, kernel_size=1),
            nn.GELU(),
            Norm(32),

            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # -> S
            nn.Conv2d(32, 1, kernel_size=1),
            nn.ReLU()
        )

    def forward(self, z):
        """Map ``z`` (batch_size, latent_dim) to an image of shape (batch_size, 1, S, S)."""
        x = self.fc(z)
        x = x.view(-1, 512, self._seed_size, self._seed_size)
        return self.deconv_layers(x)


class Model(nn.Module):
    """Variational autoencoder wrapping an encoder/decoder pair.

    ``encoder(x)`` must return ``(mean, log_var)`` of a diagonal Gaussian over
    the latent space. ``decoder(z)`` maps a latent sample back to image space.
    """

    def __init__(self, encoder, decoder):
        super(Model, self).__init__()
        self.encoder = encoder
        self.decoder = decoder  # decoder for x(t)

    def reparameterization(self, mean, var):
        """Sample ``z = mean + std * eps`` with ``eps ~ N(0, I)`` (reparameterization trick).

        Despite its name, ``var`` must be the *standard deviation*: ``forward``
        passes ``exp(0.5 * log_var)``. The name is kept for backward compatibility.
        """
        epsilon = torch.randn_like(var)
        return mean + var * epsilon

    def forward(self, x):
        """Encode ``x``, sample a latent, decode it.

        Returns:
            ``(x_hat, mean, log_var)``. For a (batch, 1, H, W) decoder output,
            ``x_hat`` is (batch, H, W): only the channel dim is dropped, so the
            batch dim survives even when batch == 1.
        """
        mean, log_var = self.encoder(x)
        z = self.reparameterization(mean, torch.exp(0.5 * log_var))

        # Reconstruction of x(t)
        x_hat = self.decoder(z)
        # squeeze(1) rather than squeeze(): a bare squeeze() would also remove
        # the batch dimension for single-sample batches.
        return x_hat.squeeze(1), mean, log_var

In [None]:
# Dummy input batch for a shape check: (batch, channel, height, width).
x = torch.rand((32, 1, 128, 128), dtype=torch.float32)
print(x.shape)

# Fresh encoder/decoder pair (constructed in this order), wrapped in the VAE.
encoder, decoder = Encoder(latent_dim=3), Decoder(latent_dim=3)
model = Model(encoder, decoder)

In [None]:
# One forward pass on the dummy batch to confirm the VAE's output shapes.
x_hat, mean, log_var = model(x)
print(*(tensor.shape for tensor in (x_hat, mean, log_var)))

In [None]:
# Re-check the training array's shape before batching it.
print(f"{X.shape}")

In [None]:
batch_size = 64
# Pass the full training set and let the loader drop the ragged final batch.
# Trimming X up front (the previous approach) excluded the *same* tail samples
# every epoch. With shuffle=True + drop_last=True, a different random remainder
# is skipped each epoch, so no sample is permanently left out.
# len(train_loader) is unchanged: len(X) // batch_size.
X_data = X
train_loader = DataLoader(X_data, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True)

In [None]:
# Grab one shuffled batch and display a random field from it.
data = next(iter(train_loader))
print(data.shape)
# Sample across the whole batch instead of a hard-coded range of 32.
index = np.random.randint(0, len(data))
plt.imshow(data[index, :, :], cmap="jet")
plt.title(f"Training sample {index}")
plt.colorbar(fraction=0.04)
plt.show()

In [None]:
def LF(x, x_hat, mean, log_var):
    """Compute the two VAE loss terms.

    Args:
        x: Target images.
        x_hat: Reconstructions; should have the same shape as ``x``.
        mean: Latent means, shape (batch, latent_dim).
        log_var: Latent log-variances, same shape as ``mean``.

    Returns:
        ``(RL_1, KLD)``:
        - ``RL_1`` is the mean squared error over all elements.
        - ``KLD`` is KL(N(mean, exp(log_var)) || N(0, I)), summed over latent
          dimensions and averaged over the batch.
    """
    reconstruction = nn.functional.mse_loss(x, x_hat)
    kl_per_sample = -0.5 * (1 + log_var - mean.pow(2) - log_var.exp()).sum(dim=1)
    return reconstruction, kl_per_sample.mean()

In [None]:
# Smoke test: push one batch through the model on the GPU and inspect both loss terms.
model = model.train().cuda()
data = next(iter(train_loader))
x = data.unsqueeze(1).float().cuda()  # add the channel dim: (B, H, W) -> (B, 1, H, W)
print("Input data shape: ", x.shape)

with torch.no_grad():  # inspection only; no backward pass happens here
    x_hat, mean, log_var = model(x)
    print("Output data shape: ", x_hat.shape)
    print("Mean shape: ", mean.shape)
    print("Logvar shape: ", log_var.shape)

    # squeeze(1) removes only the channel dim, so the batch dim survives for B == 1.
    RL_1, KLD = LF(x.squeeze(1), x_hat, mean, log_var)
    print(RL_1, KLD)
    # Same KL weight (0.00001) as the training loop, so this is the actual objective.
    overall_loss = RL_1 + 1e-5 * KLD
    print("Overall Loss: ", overall_loss)

In [None]:
import matplotlib.pyplot as plt

# Train the VAE on reconstruction MSE plus a weighted KL term.
num_epochs = 50
learning_rate = 1e-3
kl_weight = 1e-5  # weight of the KL term in the total loss (previously the literal 0.00001)

model = model.train().cuda()  # Putting model on GPU and setting it to train mode
losses = []  # mean training loss after each epoch

optimizer = Adam(model.parameters(), lr=learning_rate)
for epoch in range(num_epochs):
    epoch_loss = 0.0  # accumulated loss for the epoch
    for data in train_loader:
        # (B, H, W) batch -> (B, 1, H, W) float tensor on the GPU
        x = data.unsqueeze(1).float().cuda()

        optimizer.zero_grad()  # clear gradients from the previous step

        # Forward pass
        x_hat, mean, log_var = model(x)

        # Compare against x with only the channel dim removed, matching x_hat's layout.
        RL_1, KLD = LF(x.squeeze(1), x_hat, mean, log_var)
        overall_loss = RL_1 + kl_weight * KLD

        # Backward pass
        overall_loss.backward()
        optimizer.step()

        epoch_loss += overall_loss.item()

    # Average loss for the epoch
    avg_loss = epoch_loss / len(train_loader)
    losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.6f}")

# After training, plot the loss curve
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Curve')
plt.show()

In [None]:
# Reconstruct the last training batch left in `x` by the training cell and compare one sample.
with torch.no_grad():  # visualisation only; no autograd graph needed
    x_hat, mean, log_var = model(x)
print(x_hat.shape)

index = np.random.randint(0, len(x_hat))
# Move each field to numpy once and reuse it for both the panels and the error map.
original_field = x[index].squeeze().cpu().numpy()
reconstructed_field = x_hat[index].squeeze().cpu().numpy()
error = (original_field - reconstructed_field) ** 2

plt.figure(figsize=(12, 8))

plt.subplot(1, 3, 1)
plt.imshow(original_field, cmap="jet")
plt.title("Original Field")

plt.subplot(1, 3, 2)
plt.imshow(reconstructed_field, cmap="jet")
plt.title("Reconstruction")

plt.subplot(1, 3, 3)
plt.imshow(error, cmap="jet", vmin=0, vmax=1e-3)  # fixed colour range so plots are comparable
plt.title("Error")
plt.colorbar(fraction=0.04)
plt.show()


## Validation Data

In [None]:
# Unpack the validation split's arrays under *_val names (same attributes as the training split).
X_val = dataset_val.X
X_tau_val = dataset_val.X_tau
t_values_val = dataset_val.t_values
tau_values_val = dataset_val.tau_values
alpha_values_val = dataset_val.alpha_values
print(X_val.shape)

In [None]:
# Reconstruct one random validation field (as a batch of one) and compare it to the original.
# Sample over the whole validation set rather than a hard-coded upper bound of 9000,
# which raises IndexError on smaller sets and ignores samples beyond 9000.
index = np.random.randint(0, len(X_val))
# Add batch and channel dims: (H, W) -> (1, 1, H, W)
x = torch.tensor(X_val[index, :, :][None, None, :, :], dtype=torch.float32).cuda()
print(x.shape)

with torch.no_grad():  # visualisation only; no autograd graph needed
    x_hat, mean, log_var = model(x)
print(x_hat.shape)

# Move each field to numpy once and reuse it for both the panels and the error map.
original_field = x.squeeze().cpu().numpy()
reconstructed_field = x_hat.squeeze().cpu().numpy()
error = (original_field - reconstructed_field) ** 2

plt.figure(figsize=(12, 8))
plt.subplot(1, 3, 1)
plt.imshow(original_field, cmap="jet")
plt.title("Original Field")

plt.subplot(1, 3, 2)
plt.imshow(reconstructed_field, cmap="jet")
plt.title("Reconstruction")

plt.subplot(1, 3, 3)
plt.imshow(error, cmap="jet", vmin=0, vmax=1e-3)  # fixed colour range so plots are comparable
plt.title("Error")
plt.colorbar(fraction=0.04)
plt.show()