In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import os
import re
import torch
from torch.utils.data import Dataset, random_split, DataLoader
import numpy as np
import os



In [None]:
# Project root. Later cells import `Lib.*` and read `Job_Assignment_Data/` relative to
# this directory. Set the POROSITY_ROOT environment variable to run elsewhere; the
# default keeps the original cluster location.
path = os.environ.get('POROSITY_ROOT', '/gpfs/data/ssa/users/d602145/Workspace/scratch/Porosity/ETH/')
os.chdir(path)

In [None]:
from Lib.Data import PorosityDistribution, extract_microstructures
from Lib.Datasets import  PorosityDataset

In [None]:
# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

In [None]:
# Dataset root, resolved against the current working directory (trailing slash kept).
sample_path = f"{os.getcwd()}/Job_Assignment_Data/Job_Assignment_Data/"

In [None]:
# Create train, validation, and test datasets. The train/val/test flags presumably
# select the split (see Lib.Datasets.PorosityDataset); the other kwargs are shared.
BATCH_SIZE = 1280  # mini-batch size used for every split

dataset_kwargs = dict(keep_doubles=False, device=device)
train_dataset = PorosityDataset(sample_path, train=True, val=False, test=False, **dataset_kwargs)
val_dataset = PorosityDataset(sample_path, train=False, val=True, test=False, **dataset_kwargs)
test_dataset = PorosityDataset(sample_path, train=False, val=False, test=True, **dataset_kwargs)

# Create DataLoaders; only the training split is shuffled.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
print(len(train_dataset))  # size of the training split

In [None]:
import torch.nn as nn

In [None]:
class ResidualModuleBlock(nn.Module):
    """Stack of ``steps`` width-preserving Linear -> [BatchNorm1d] -> SiLU -> Dropout stages.

    Args:
        dim: Feature width of the input, every hidden layer and the output.
        steps: Number of stages to stack (0 gives an identity body).
        dropout: Dropout probability applied after every activation.
        residual: If True, add the block input to the block output (skip connection).
        batch_norm: If True, insert ``BatchNorm1d(dim)`` after every Linear layer.

    NOTE(review): with ``batch_norm`` the input is presumably ``(batch, dim)`` — confirm.
    """

    def __init__(self, dim, steps, dropout=0.1, residual=False, batch_norm=True):
        super(ResidualModuleBlock, self).__init__()
        self.layers = nn.ModuleList()
        self.residual = residual

        for _ in range(steps):
            self.layers.append(nn.Linear(dim, dim))
            if batch_norm:
                self.layers.append(nn.BatchNorm1d(dim))
            self.layers.append(nn.SiLU())
            self.layers.append(nn.Dropout(dropout))

    def forward(self, x):
        """Apply every stage in order, then optionally add the skip connection."""
        out = x
        for layer in self.layers:
            out = layer(out)
        if self.residual:
            # Out-of-place add. ``out`` can alias ``x`` (steps == 0) or an upstream
            # tensor (Dropout returns its input unchanged in eval mode); ``+=`` would
            # then mutate that tensor in place.
            out = out + x
        return out
        
class LinearModuleBlock(nn.Module):
    """Feed-forward stack mapping ``dims[0]`` features to ``dims[-1]`` features.

    For each consecutive pair ``(d_in, d_out)`` in ``dims`` the block appends
    Linear(d_in, d_out) -> [BatchNorm1d(d_out)] -> SiLU -> Dropout(dropout).
    A ``dims`` of length < 2 yields an empty (identity) block.
    """

    def __init__(self, dims, dropout=0.1, batch_norm=True):
        super().__init__()
        self.layers = nn.ModuleList()

        for d_in, d_out in zip(dims[:-1], dims[1:]):
            stage = [nn.Linear(d_in, d_out)]
            if batch_norm:
                stage.append(nn.BatchNorm1d(d_out))
            stage += [nn.SiLU(), nn.Dropout(dropout)]
            self.layers.extend(stage)

    def forward(self, x):
        """Run ``x`` through every layer in registration order."""
        out = x
        for module in self.layers:
            out = module(out)
        return out

In [None]:
import torch.nn as nn

class Encoder(nn.Module):
    """Map inputs of width ``3 + condition_dim`` to a ``16 * scale``-wide representation.

    Architecture: ``LinearModuleBlock`` (3 + condition_dim -> 8*scale -> 16*scale)
    followed by a two-stage ``ResidualModuleBlock`` with the skip connection enabled.

    Args:
        scale: Width multiplier for all hidden layers.
        condition_dim: Number of extra input features expected beyond the 3 base ones.
    """

    def __init__(self, scale=1, condition_dim=0):
        super().__init__()
        self.scale = scale
        hidden = 16 * scale
        # Creation order (input_block, then deep_block) fixes parameter order and the
        # random-initialisation sequence.
        self.input_block = LinearModuleBlock([3 + condition_dim, 8 * scale, hidden])
        self.deep_block = ResidualModuleBlock(hidden, 2, residual=True)

    def forward(self, x):
        """Encode ``x`` (trailing width ``3 + condition_dim``)."""
        return self.deep_block(self.input_block(x))

class Decoder(nn.Module):
    """Map a ``16 * scale``-wide vector back to 3 output features.

    Pipeline: four width-preserving stages (``ResidualModuleBlock`` with the skip
    connection disabled) -> ``LinearModuleBlock`` (16*scale -> 8*scale) -> a plain
    ``nn.Linear`` (8*scale -> 3) with no output activation.

    Args:
        scale: Width multiplier; the expected input width is ``16 * scale``.
    """

    def __init__(self, scale=1):
        super().__init__()
        self.scale = scale
        wide, narrow = 16 * scale, 8 * scale
        # Submodules are created in the original order (linproj, output_block,
        # deep_block) so random initialisation is unchanged.
        self.linproj = nn.Linear(narrow, 3)
        self.output_block = LinearModuleBlock([wide, narrow])
        self.deep_block = ResidualModuleBlock(wide, 4, residual=False)

    def forward(self, x):
        """Decode ``x`` (trailing width ``16 * scale``) to trailing width 3."""
        features = self.output_block(self.deep_block(x))
        return self.linproj(features)

In [None]:
# Pull one training batch to inspect tensor shapes.
X, y = next(iter(train_dataloader))
print(X.shape, y.shape)

In [None]:
# Standalone encoder/decoder pair (scale=1) for a quick shape check.
encoder, decoder = Encoder().to(device), Decoder().to(device)

In [None]:
X.size()

In [None]:
# Smoke test: one batch through the untrained encoder/decoder. No gradients are needed,
# and the decoder must return 3 features for every input row.
with torch.no_grad():
    hidden = encoder(X)
    out = decoder(hidden)
assert out.shape == (*X.shape[:-1], 3), f'unexpected decoder output shape {tuple(out.shape)}'
out.shape

In [None]:
import torch.nn as nn

class ConditionedVAE(nn.Module):
    """Conditional VAE that injects the condition by *adding* an embedding to the latent.

    The encoder only sees ``x``. The scalar condition ``y`` is embedded by
    ``condition_encoder`` (1 -> 8*scale -> 16*scale) and added to the sampled latent
    ``z`` before decoding, so latent and embedding share the width ``16 * scale``.

    Args:
        scale: Width multiplier for encoder, decoder and condition embedding; the
            latent dimension is ``16 * scale``.
    """

    def __init__(self, scale=1):
        super(ConditionedVAE, self).__init__()
        self.scale = scale
        self.encoder = Encoder(scale=scale)
        self.decoder = Decoder(scale=scale)

        self.condition_encoder = LinearModuleBlock([1, scale * 8, scale * 16])
        # Heads producing the mean and log-variance of the latent distribution.
        self.fc_mu = nn.Linear(scale * 16, scale * 16)
        self.fc_logvar = nn.Linear(scale * 16, scale * 16)

    def reparameterize(self, mu, logvar):
        """Reparameterization trick: return ``mu + eps * exp(0.5 * logvar)``, eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def _condition(self, z, y):
        """Return ``z`` plus the embedding of ``y`` (one scalar condition per row)."""
        # reshape (unlike view) also accepts a non-contiguous ``y``.
        return z + self.condition_encoder(y.reshape(-1, 1))

    def forward(self, x, y):
        """Encode ``x``, sample a latent, add the condition embedding and decode.

        Args:
            x: Input features; trailing width 3 (the ``Encoder`` default).
            y: Condition values, one per row of ``x``.

        Returns:
            Tuple ``(x_recon, mu, logvar)``.
        """
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        z = self.reparameterize(mu, logvar)
        # The original ``z += z + cond`` evaluated to ``2 * z + cond`` (a typo for
        # ``z + cond``). It was applied in both forward() and sample(), so weights
        # trained with the old code expect that factor of 2.
        x_recon = self.decoder(self._condition(z, y))
        return x_recon, mu, logvar

    def sample(self, num_samples, density, device=None):
        """Decode ``num_samples`` prior draws z ~ N(0, I) conditioned on ``density``.

        Args:
            num_samples: Number of samples to generate.
            density: Scalar condition value shared by all samples.
            device: Device for the generated tensors; defaults to the device of the
                model's parameters.

        Returns:
            Tensor of shape ``(num_samples, 3)``.
        """
        if device is None:
            device = next(self.parameters()).device
        z = torch.randn(num_samples, 16 * self.scale, device=device)
        y = density * torch.ones(num_samples, 1, device=device)
        return self.decoder(self._condition(z, y))

In [None]:
import torch.nn as nn

class ConditionedVAE2(nn.Module):
    """Conditional VAE that injects the condition by *concatenating* an embedding to the latent.

    The encoder only sees ``x`` and yields a ``16 * scale`` latent. The scalar condition
    ``y`` is embedded by ``condition_decoder`` (1 -> 8*scale -> 16*scale) and
    concatenated to ``z``. The decoder is therefore built with ``scale=2*scale`` so it
    accepts the ``32 * scale``-wide input.

    Args:
        scale: Width multiplier; the latent dimension is ``16 * scale``.
    """

    def __init__(self, scale=1):
        super(ConditionedVAE2, self).__init__()
        self.scale = scale
        self.encoder = Encoder(scale=scale)
        # Decoder input width 16*(2*scale) = latent (16*scale) + condition embedding (16*scale).
        self.decoder = Decoder(scale=2 * scale)

        # NOTE(review): ``condition_encoder`` is not used by forward() or sample(). It is
        # kept so existing state_dicts and the initialisation order stay unchanged; it
        # never receives gradients.
        self.condition_encoder = LinearModuleBlock([1, 8, 8])
        self.condition_decoder = LinearModuleBlock([1, scale * 8, scale * 16])
        # Heads producing the mean and log-variance of the latent distribution.
        self.fc_mu = nn.Linear(scale * 16, scale * 16)
        self.fc_logvar = nn.Linear(scale * 16, scale * 16)

    def reparameterize(self, mu, logvar):
        """Reparameterization trick: return ``mu + eps * exp(0.5 * logvar)``, eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def _condition(self, z, y):
        """Concatenate the embedding of ``y`` (one scalar per row) to ``z`` on the last dim."""
        # reshape (unlike view) also accepts a non-contiguous ``y``.
        return torch.cat((z, self.condition_decoder(y.reshape(-1, 1))), dim=-1)

    def forward(self, x, y):
        """Encode ``x``, sample a latent, append the condition embedding and decode.

        Args:
            x: Input features; trailing width 3 (the ``Encoder`` default).
            y: Condition values, one per row of ``x``.

        Returns:
            Tuple ``(x_recon, mu, logvar)``.
        """
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(self._condition(z, y))
        return x_recon, mu, logvar

    def sample(self, num_samples, density, device=None):
        """Decode ``num_samples`` prior draws z ~ N(0, I) conditioned on ``density``.

        Args:
            num_samples: Number of samples to generate.
            density: Scalar condition value shared by all samples.
            device: Device for the generated tensors; defaults to the device of the
                model's parameters.

        Returns:
            Tensor of shape ``(num_samples, 3)``.
        """
        if device is None:
            device = next(self.parameters()).device
        # Allocate directly on the target device instead of creating on CPU and copying.
        z = torch.randn(num_samples, 16 * self.scale, device=device)
        y = density * torch.ones(num_samples, 1, device=device)
        return self.decoder(self._condition(z, y))

In [None]:
# Draw another (shuffled) training batch and report its tensor shapes.
X, y = next(iter(train_dataloader))
print(X.shape, y.shape, sep=' ')

In [None]:
# Shape of the inputs with the condition appended as one extra feature column.
torch.cat([X, y.view(-1, 1)], dim=1).shape

In [None]:
print(X[0], y[0], sep=' ')  # first input row and its condition value

In [None]:
# Fresh batch and model instance, followed by one forward pass as a smoke test.
X, y = next(iter(train_dataloader))
model = ConditionedVAE2(scale=4).to(device)
model(X, y)

In [None]:
model(x=X, y=y)

In [None]:
from torch import optim

# AdamW over every registered model parameter (default weight_decay).
optimizer = optim.AdamW(params=model.parameters(), lr=1e-3)

In [None]:
# Reconstruction loss: mean absolute error between decoder output and input.
criterion_reconstruction = nn.L1Loss(reduction='mean')

In [None]:
# Training loop
num_epochs = 1
beta = 1  # weight of the KL term relative to the reconstruction term

train_losses = []
train_recon_losses = []
train_kl_losses = []
val_losses = []
val_recon_losses = []
val_kl_losses = []
# NOTE(review): this loop computes no condition loss, so these two lists stay empty;
# they are kept only so cells that reference them keep working.
train_cond_losses = []
val_cond_losses = []


def _vae_losses(outputs, inputs, mu, logvar):
    """Return ``(total, reconstruction, kl)`` losses for one batch.

    KL(N(mu, exp(logvar)) || N(0, I)) is summed over the latent dimension and averaged
    over the batch. A full-batch sum (the previous form) grows with the batch size
    while the L1 reconstruction term is a mean, so KL could dominate the objective and
    the smaller final batch of each epoch was under-weighted in the epoch averages.
    """
    loss_reconstruction = criterion_reconstruction(outputs, inputs)
    kl_per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
    kl_loss = kl_per_sample.mean()
    return loss_reconstruction + beta * kl_loss, loss_reconstruction, kl_loss


def _run_epoch(dataloader, train):
    """One pass over ``dataloader``; return per-batch means of (total, recon, kl).

    Gradients are tracked and the optimizer steps only when ``train`` is True;
    ``model.train(train)`` also switches dropout / batch-norm behaviour.
    """
    model.train(train)
    sums = [0.0, 0.0, 0.0]
    with torch.set_grad_enabled(train):
        for inputs, conditions in dataloader:
            outputs, mu, logvar = model(inputs, conditions)
            loss, loss_reconstruction, kl_loss = _vae_losses(outputs, inputs, mu, logvar)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            sums[0] += loss.item()
            sums[1] += loss_reconstruction.item()
            sums[2] += kl_loss.item()
    num_batches = max(len(dataloader), 1)  # avoid ZeroDivisionError on an empty split
    return tuple(s / num_batches for s in sums)


for epoch in range(num_epochs):
    epoch_train_loss, epoch_train_recon_loss, epoch_train_kl_loss = _run_epoch(train_dataloader, train=True)
    epoch_val_loss, epoch_val_recon_loss, epoch_val_kl_loss = _run_epoch(val_dataloader, train=False)

    print(f"Epoch [{epoch + 1}/{num_epochs}] "
          f"Train Loss: {epoch_train_loss:.4f} "
          f"Train Reconstruction Loss: {epoch_train_recon_loss:.4f} "
          f"Train KL Loss: {epoch_train_kl_loss:.4f} "
          f"Val Loss: {epoch_val_loss:.4f} "
          f"Val Reconstruction Loss: {epoch_val_recon_loss:.4f} "
          f"Val KL Loss: {epoch_val_kl_loss:.4f} ")

    train_losses.append(epoch_train_loss)
    train_recon_losses.append(epoch_train_recon_loss)
    train_kl_losses.append(epoch_train_kl_loss)
    val_losses.append(epoch_val_loss)
    val_recon_losses.append(epoch_val_recon_loss)
    val_kl_losses.append(epoch_val_kl_loss)


print("Finished Training")

In [None]:
# Plotting (after the training loop): total, reconstruction and KL loss per epoch.
# The training loop records KL losses but no condition losses, so the KL curves are
# plotted here. Markers keep single-epoch runs visible (a one-point line draws nothing).
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
panels = [
    ('Total Loss', train_losses, val_losses),
    ('Reconstruction Loss', train_recon_losses, val_recon_losses),
    ('KL Loss', train_kl_losses, val_kl_losses),
]
for ax, (title, train_curve, val_curve) in zip(axes, panels):
    ax.plot(train_curve, marker='o', label=f'Train {title}')
    ax.plot(val_curve, marker='o', label=f'Val {title}')
    ax.set(xlabel='Epoch', ylabel='Loss', title=title)
    ax.legend()

fig.tight_layout()
plt.show()

In [None]:
# NOTE(review): both values are reassigned by the next sampling cell before being used.
samples, density = 500, 0.
model.eval()

In [None]:
# Display ``y``, the condition tensor of the most recently drawn batch (debugging peek).
y

In [None]:

# Generate points for condition value 0.683 and inspect their distribution.
samples = 500
density = 0.683
model.eval()

# Sampling needs no gradients; no_grad avoids building an autograd graph.
with torch.no_grad():
    rec = model.sample(samples, density, device=device)
df = pd.DataFrame(rec.cpu().numpy(), columns=['x', 'y', 'z'])
fig = px.scatter_3d(df, x='x', y='y', z='z', title=f'Generated samples (density={density})')
fig.show()
fig = px.histogram(df, facet_col='variable', histnorm='probability', nbins=100,
                   title=f'Per-coordinate distribution (density={density})')
fig.show()


In [None]:

# Generate points for condition value 0.5 and inspect their distribution.
samples = 500
density = 0.5
model.eval()

# Sampling needs no gradients; no_grad avoids building an autograd graph.
with torch.no_grad():
    rec = model.sample(samples, density, device=device)
df = pd.DataFrame(rec.cpu().numpy(), columns=['x', 'y', 'z'])
fig = px.scatter_3d(df, x='x', y='y', z='z', title=f'Generated samples (density={density})')
fig.show()
fig = px.histogram(df, facet_col='variable', histnorm='probability', nbins=100,
                   title=f'Per-coordinate distribution (density={density})')
fig.show()

In [None]:
# Generate points for condition value 0.8 and inspect their distribution.
samples = 500
density = 0.8
model.eval()

# Sampling needs no gradients; no_grad avoids building an autograd graph.
with torch.no_grad():
    rec = model.sample(samples, density, device=device)
df = pd.DataFrame(rec.cpu().numpy(), columns=['x', 'y', 'z'])
fig = px.scatter_3d(df, x='x', y='y', z='z', title=f'Generated samples (density={density})')
fig.show()
fig = px.histogram(df, facet_col='variable', histnorm='probability', nbins=100,
                   title=f'Per-coordinate distribution (density={density})')
fig.show()