<a href="https://colab.research.google.com/github/SurajWijewickrama/auto-mesh/blob/main/ml/V_E_new_Edition.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
!pip install dgl torch




In [7]:
!pip install torch torch-geometric




In [8]:
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from torch_geometric.data import DataLoader  # Import the correct DataLoader
from torch_geometric.data import Batch
import torch.nn as nn


In [5]:
# Mount Google Drive (Colab-only API) so files stored there are reachable
# under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [9]:
import torch
import json
import os
from torch_geometric.data import Data, DataLoader

# Function to load JSON 3D objects as graphs
def load_graph_data(folder_path):
    """Load every ``*.json`` mesh in ``folder_path`` as a PyG ``Data`` graph.

    Each file is read as a JSON object with keys ``"v"`` (per-vertex
    coordinate lists, presumably xyz) and ``"e"`` (assumed to be a list of
    ``[src, dst]`` vertex-index pairs -- TODO confirm against the exporter).

    Files are visited in sorted filename order so the returned list is
    reproducible across runs (``os.listdir`` order is arbitrary).

    Args:
        folder_path: directory to scan (non-recursively).

    Returns:
        list of ``Data(x=float tensor [num_vertices, D],
        edge_index=long tensor [2, num_edges])``.

    Raises:
        FileNotFoundError: if ``folder_path`` does not exist.
    """
    data_list = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith('.json'):
            continue
        with open(os.path.join(folder_path, filename), encoding='utf-8') as f:
            graph_data = json.load(f)
        x = torch.tensor(graph_data['v'], dtype=torch.float)  # Nodes' coordinates
        edges = torch.tensor(graph_data['e'], dtype=torch.long)
        if edges.numel() == 0:
            # An empty list yields a 1-D (0,) tensor; give it the (0, 2) pair
            # shape so the transpose below produces a valid (2, 0) edge_index.
            edges = edges.reshape(0, 2)
        edge_index = edges.t().contiguous()  # [2, num_edges], as PyG expects
        data_list.append(Data(x=x, edge_index=edge_index))
    return data_list

# Folder holding the mesh JSON files.
# NOTE(review): the recorded run failed with FileNotFoundError for this path;
# Drive is mounted under /content/drive -- confirm where the JSON files live.
JSON_DIR = '/content/json'
data_list = load_graph_data(JSON_DIR)
if not data_list:
    # Fail here with a clear message instead of later, when the training
    # loop averages the loss over len(loader.dataset).
    raise ValueError(f"No .json graph files found in {JSON_DIR!r}")
loader = DataLoader(data_list, batch_size=32, shuffle=True)


FileNotFoundError: [Errno 2] No such file or directory: '/content/json'

In [None]:

from transformers import BertTokenizer, BertModel

# Pretrained BERT used to turn text prompts into embeddings. Its weights are
# not passed to either optimizer created in this notebook.
BERT_CHECKPOINT = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
bert_model = BertModel.from_pretrained(BERT_CHECKPOINT)

# Function to convert text to embeddings
def get_text_embedding(prompt):
    """Embed ``prompt`` with the module-level BERT model via masked mean pooling.

    Args:
        prompt: a string, or a list of strings (padded to a common length).

    Returns:
        Tensor of shape ``(num_prompts, hidden_size)`` on the same device as
        ``bert_model``. No autograd graph is recorded, because BERT is used
        here only as a fixed feature extractor.
    """
    # Inputs must live on the model's device; otherwise a CUDA-resident
    # bert_model fails on CPU input tensors.
    device = next(bert_model.parameters()).device
    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        outputs = bert_model(**inputs)
    hidden = outputs.last_hidden_state
    # Average only over real tokens so padding doesn't dilute batched prompts;
    # for a single unpadded prompt this equals hidden.mean(dim=1).
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    return summed / mask.sum(dim=1).clamp(min=1)


In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, VGAE

class GraphEncoder(nn.Module):
    """Two-layer GCN encoder returning the ``(mu, logstd)`` pair PyG's VGAE expects.

    Args:
        in_channels: per-node input feature size (3 for xyz coordinates here).
        out_channels: size of the latent node embedding.
    """

    def __init__(self, in_channels, out_channels):
        super(GraphEncoder, self).__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        # conv2 is the latent-mean head; conv_logstd the log-std head.
        self.conv2 = GCNConv(2 * out_channels, out_channels)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        """Return ``(mu, logstd)``, each of shape ``[num_nodes, out_channels]``."""
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index), self.conv_logstd(x, edge_index)


class VGAEWithEdgePrediction(VGAE):
    """VGAE whose sampled node latents are shifted by a projected text embedding.

    Args:
        encoder: module returning ``(mu, logstd)`` for ``(x, edge_index)``.
        text_dim: size of the text embedding (768 for BERT-base).
        latent_dim: must equal the encoder's latent size.
    """

    def __init__(self, encoder, text_dim, latent_dim):
        super(VGAEWithEdgePrediction, self).__init__(encoder)
        self.fc1 = nn.Linear(text_dim, latent_dim)

    def forward(self, x, edge_index, text_embedding):
        """Return ``(edge probabilities for edge_index, conditioned latents z)``."""
        # VGAE.encode runs the encoder, stores mu/logstd for kl_loss() and
        # samples z; calling self.encoder directly would skip both.
        z = self.encode(x, edge_index)
        # Condition on text. Assumes text_embedding is (1, text_dim), one
        # prompt shared by every node -- TODO confirm for multi-prompt batches.
        z = z + self.fc1(text_embedding)
        return self.decode(z, edge_index), z  # Decode edges

encoder = GraphEncoder(in_channels=3, out_channels=16)  # 3 for 3D coordinates
model = VGAEWithEdgePrediction(encoder, text_dim=768, latent_dim=16)


In [None]:
import torch.optim as optim

# Number of passes over `loader`, and Adam over the VGAE model's parameters
# (the BERT text encoder is not included).
epochs = 50
optimizer = optim.Adam(params=model.parameters(), lr=1e-2)

def train(loader, prompt):
    """Run one epoch of text-conditioned VGAE training.

    Uses the module-level ``model``, ``optimizer``, ``device`` and
    ``get_text_embedding``.

    Args:
        loader: DataLoader yielding batched graphs with ``x`` and ``edge_index``.
        prompt: text prompt, embedded once per call and shared by every batch.

    Returns:
        Loss summed over batches (each weighted by ``num_graphs``) divided by
        ``len(loader.dataset)``.

    Raises:
        ValueError: if ``loader.dataset`` is empty.
    """
    if len(loader.dataset) == 0:
        raise ValueError("train() received a loader with an empty dataset")
    model.train()
    total_loss = 0.0
    # The embedding is reused for every batch. If it carried an autograd graph
    # through BERT, the first backward() would free that graph and the second
    # would raise, so compute it without gradients.
    with torch.no_grad():
        text_embedding = get_text_embedding(prompt).to(device)

    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        _, z = model(data.x, data.edge_index, text_embedding)
        # VGAE.recon_loss takes latent node embeddings z, not decoded
        # edge probabilities.
        loss = model.recon_loss(z, data.edge_index) + (1 / data.num_nodes) * model.kl_loss()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs

    return total_loss / len(loader.dataset)

# Put both networks on the GPU when available, then train against one fixed
# prompt and report the per-epoch loss.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)
bert_model = bert_model.to(device)

prompt = "Generate a low-poly tree model."
for epoch in range(epochs):
    epoch_loss = train(loader, prompt)
    print(f'Epoch {epoch+1}, Loss: {epoch_loss:.4f}')


In [None]:
class VAEEncoder(nn.Module):
    """MLP encoder mapping each input row to Gaussian parameters ``(mu, logvar)``.

    Args:
        input_dim: size of each input row.
        hidden_dim: width of the single ReLU hidden layer.
        latent_dim: size of both ``mu`` and ``logvar``.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        # Creation order is kept (fc1, fc_mu, fc_logvar) so seeded init is unchanged.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)`` computed from a shared hidden layer."""
        hidden = self.fc1(x).relu()
        return self.fc_mu(hidden), self.fc_logvar(hidden)

def reparameterize(mu, logvar):
    """Draw ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)`` shaped like ``mu``'s std."""
    sigma = logvar.mul(0.5).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma


In [None]:
class TextConditionedGNN(nn.Module):
    """GCN decoder conditioned on a text embedding concatenated to every node.

    Args:
        latent_dim: per-node latent size coming from the encoder.
        text_dim: text-embedding size (768 for the BERT-base model used here).
        hidden_dim: width of the shared GCN layer.

    The defaults reproduce the previously hard-coded ``GCNConv(800, 16)``
    (800 = 32 + 768).
    """

    def __init__(self, latent_dim=32, text_dim=768, hidden_dim=16):
        super(TextConditionedGNN, self).__init__()
        self.conv1 = GCNConv(latent_dim + text_dim, hidden_dim)  # latent ++ text
        self.conv2_v = GCNConv(hidden_dim, 3)      # Vertex output
        self.conv2_e = GCNConv(hidden_dim, 2)      # Edge presence output

    def forward(self, x, edge_index, text_embedding, batch=None):
        """Decode node latents into vertex values and sigmoid edge scores.

        Args:
            x: node latents, ``[num_nodes, latent_dim]``.
            edge_index: ``[2, num_edges]`` connectivity.
            text_embedding: ``[num_text_rows, text_dim]``.
            batch: optional PyG batch vector mapping each node to its graph
                (and thus to its row of ``text_embedding``). When omitted,
                the original even-split assignment is used.

        Returns:
            ``(vertices [num_nodes, 3], edges [num_nodes, 2] in [0, 1])``.
        """
        text_per_node = self._expand_text(text_embedding, x.size(0), batch)
        x = torch.cat([x, text_per_node], dim=1)
        x = F.relu(self.conv1(x, edge_index))
        vertices = self.conv2_v(x, edge_index)
        edges = torch.sigmoid(self.conv2_e(x, edge_index))
        return vertices, edges

    @staticmethod
    def _expand_text(text_embedding, num_nodes, batch=None):
        """Return one text-embedding row per node."""
        if batch is not None:
            # Exact per-graph assignment, correct even when graphs differ in size.
            return text_embedding[batch]
        # Fallback: split nodes into equal consecutive runs per text row and
        # wrap any remainder from the first rows. This is only correct when
        # every graph in the batch has the same number of nodes.
        rows = text_embedding.size(0)
        repeated = text_embedding.repeat_interleave(num_nodes // rows, dim=0)
        remainder = num_nodes % rows
        if remainder > 0:
            repeated = torch.cat([repeated, text_embedding[:remainder]], dim=0)
        return repeated

In [None]:
# VAE-GNN Model
class VAENet(nn.Module):
    """Chain an MLP VAE encoder with a text-conditioned GNN decoder.

    Args:
        encoder: module returning ``(mu, logvar)`` for node features.
        gnn_decoder: module called as ``(z, edge_index, text_embedding)``
            and returning ``(vertices, edges)``.
    """

    def __init__(self, encoder, gnn_decoder):
        super().__init__()
        self.encoder = encoder
        self.gnn_decoder = gnn_decoder

    def forward(self, x, edge_index, text_embedding):
        """Return ``(vertices, edges, mu, logvar)``."""
        mu, logvar = self.encoder(x)
        latent = reparameterize(mu, logvar)
        decoded_vertices, decoded_edges = self.gnn_decoder(latent, edge_index, text_embedding)
        return decoded_vertices, decoded_edges, mu, logvar

# Loss function for VAE
def loss_function(recon_vertices, recon_edges, vertices, edges, mu, logvar, edge_index):
    """Vertex MSE + edge binary cross-entropy + (summed) KL divergence.

    Args:
        recon_vertices: decoder vertex predictions, same shape as ``vertices``.
        recon_edges: decoder edge probabilities in [0, 1].
        vertices: target vertex values.
        edges: per-edge 0/1 labels, one per column of ``edge_index``; may be an
            integer or bool tensor.
        mu, logvar: Gaussian posterior parameters from the encoder.
        edge_index: ``[2, num_edges]`` connectivity.

    Returns:
        Scalar loss tensor.
    """
    recon_edges = recon_edges.view(-1, recon_edges.shape[-1])
    # NOTE(review): this picks one decoder output per edge from the source
    # node's row at column (target index mod number of output columns). That
    # does not look like a meaningful edge score -- confirm the intended
    # edge decoding.
    recon_edges_for_actual_edges = recon_edges[edge_index[0], edge_index[1] % recon_edges.shape[1]]
    # F.binary_cross_entropy needs a floating target matching the prediction's
    # dtype; integer/bool labels would otherwise raise.
    edge_targets = edges.view(-1).to(recon_edges_for_actual_edges.dtype)
    recon_loss = F.mse_loss(recon_vertices, vertices) + \
                 F.binary_cross_entropy(recon_edges_for_actual_edges, edge_targets)

    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld

In [None]:
# Initialize the VAE encoder + text-conditioned GNN decoder.
# Build the loader first and inspect one sample batch, so this cell works on a
# fresh kernel (it previously read `batch` before any cell had defined it).
data_loader = DataLoader(data_list, batch_size=2, shuffle=True)
sample_batch = next(iter(data_loader))
print(f"Vertices shape: {sample_batch.x.shape}")
print(f"Edge index shape: {sample_batch.edge_index.shape}")
# NOTE(review): load_graph_data does not set `name_embedding` (or the
# `edge_label` used in training); this assumes data_list was enriched elsewhere.
print(f"Text embedding shape: {sample_batch.name_embedding.shape}")
text_embedding_dim = sample_batch.name_embedding.shape[1]

encoder = VAEEncoder(input_dim=3, hidden_dim=64, latent_dim=32)
gnn_decoder = TextConditionedGNN()
# conv1 consumes [latent ++ text]; catch width mismatches here, not mid-training.
expected_in = encoder.fc_mu.out_features + text_embedding_dim
assert gnn_decoder.conv1.in_channels == expected_in, (
    f"decoder expects {gnn_decoder.conv1.in_channels} input features, "
    f"but latent + text = {expected_in}"
)
model = VAENet(encoder, gnn_decoder)
print(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


In [None]:
# Train VAENet for 100 epochs; each line reports the mean per-batch loss.
for epoch in range(100):
    epoch_total = 0
    for batch in data_loader:
        optimizer.zero_grad()

        outputs = model(batch.x, batch.edge_index, batch.name_embedding)
        vertices, edges, mu, logvar = outputs
        loss = loss_function(vertices, edges, batch.x, batch.edge_label, mu, logvar, batch.edge_index)

        loss.backward()
        optimizer.step()
        epoch_total += loss.item()

    mean_loss = epoch_total / len(data_loader)
    print(f'Epoch {epoch + 1}, Loss: {mean_loss}')

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv  # Using Graph Convolutional Network as an example

class VAE(nn.Module):
    """Fully connected VAE: input -> 128 -> 64 -> (mu, logvar); z -> 128.

    Args:
        input_dim: size of each input row.
        latent_dim: size of the latent code z.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        # Encoder trunk and Gaussian heads (creation order kept for seeded init).
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc_mu = nn.Linear(64, latent_dim)
        self.fc_logvar = nn.Linear(64, latent_dim)
        # Maps a latent sample to a 128-wide feature vector.
        self.decoder = nn.Linear(latent_dim, 128)

    def encode(self, x):
        """Return ``(mu, logvar)`` for each row of ``x``."""
        features = F.relu(self.fc2(F.relu(self.fc1(x))))
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        sigma = logvar.mul(0.5).exp()
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x):
        """Encode, sample, and decode; returns ``(decoded, mu, logvar)``."""
        mu, logvar = self.encode(x)
        decoded = self.decoder(self.reparameterize(mu, logvar))
        return decoded, mu, logvar




In [None]:
class GNNDecoder(nn.Module):
    """GCN decoder producing per-node vertex predictions and per-edge scores.

    Args:
        latent_dim: size of the input node features.
        vertex_dim: number of outputs per node.
        edge_dim: number of outputs per edge (raw, un-activated values).
    """

    def __init__(self, latent_dim, vertex_dim, edge_dim):
        super().__init__()
        hidden = 64
        self.conv1 = GCNConv(latent_dim, hidden)
        self.conv2 = GCNConv(hidden, vertex_dim)  # Output for vertex predictions
        # Each edge is scored from the concatenated hidden features of its endpoints.
        self.edge_fc = nn.Linear(2 * hidden, edge_dim)

    def forward(self, x, edge_index):
        """Return ``(vertices_pred [N, vertex_dim], edges_pred [E, edge_dim])``."""
        hidden = self.conv1(x, edge_index).relu()
        vertices_pred = self.conv2(hidden, edge_index)

        src, dst = edge_index
        pair_features = torch.cat((hidden[src], hidden[dst]), dim=1)
        return vertices_pred, self.edge_fc(pair_features)

class VAE_GNN(nn.Module):
    """VAE encoder whose sampled latent codes are decoded by a GNN.

    Args:
        vae_input_dim: size of each input row fed to the VAE encoder.
        latent_dim: latent size shared by the VAE and the GNN decoder input.
        vertex_dim: per-node output size of the GNN decoder.
        edge_dim: per-edge output size of the GNN decoder.
    """

    def __init__(self, vae_input_dim, latent_dim, vertex_dim, edge_dim):
        super(VAE_GNN, self).__init__()
        self.vae = VAE(vae_input_dim, latent_dim)
        self.gnn = GNNDecoder(latent_dim, vertex_dim, edge_dim)

    def forward(self, x, edge_index):
        """Return ``(vertices_pred, edges_pred, mu, logvar)``."""
        # self.vae(x) returns the VAE's decoder(z), which is 128 wide, while the
        # GNN's first layer expects latent_dim inputs. Feed the sampled z
        # directly; self.vae.decoder is therefore unused here.
        mu, logvar = self.vae.encode(x)
        z = self.vae.reparameterize(mu, logvar)
        vertices_pred, edges_pred = self.gnn(z, edge_index)
        return vertices_pred, edges_pred, mu, logvar


In [None]:
def edge_loss(pred, target):
    """Binary cross-entropy between raw edge logits and 0/1 edge labels.

    ``pred`` holds logits; the sigmoid is applied inside the loss. ``target``
    may be an integer or bool tensor and is cast to ``pred``'s dtype, because
    ``binary_cross_entropy_with_logits`` requires a floating target.
    """
    return F.binary_cross_entropy_with_logits(pred, target.to(pred.dtype))

def combined_loss(vertices_pred, vertices_true, edges_pred, edges_true, mu, logvar):
    """Total loss = summed vertex MSE + summed KL divergence + edge BCE (via ``edge_loss``)."""
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return (
        F.mse_loss(vertices_pred, vertices_true, reduction='sum')
        + kl_divergence
        + edge_loss(edges_pred, edges_true)
    )

def train(model, data_loader, optimizer=None):
    """Run one epoch for a VAE_GNN-style model.

    NOTE: this redefines, and therefore shadows, the earlier
    ``train(loader, prompt)`` in this notebook.

    Args:
        model: callable as ``model(name_embeddings, edge_index)`` returning
            ``(vertices_pred, edges_pred, mu, logvar)``.
        data_loader: iterable of
            ``(name_embeddings, vertices, edges, edge_index)`` tuples.
        optimizer: optimizer over ``model``'s parameters. If omitted, the
            module-level ``optimizer`` is used, as before. That global is
            built earlier for a different model, so pass one explicitly.

    Returns:
        Mean loss over the batches, or ``nan`` if the loader yielded none.
    """
    if optimizer is None:
        optimizer = globals()["optimizer"]
    model.train()
    total_loss, num_batches = 0.0, 0
    for data in data_loader:  # Each data contains name embeddings, vertices, edges, and edge_index
        name_embeddings, vertices, edges, edge_index = data
        optimizer.zero_grad()
        vertices_pred, edges_pred, mu, logvar = model(name_embeddings, edge_index)
        loss = combined_loss(vertices_pred, vertices, edges_pred, edges, mu, logvar)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else float("nan")

def infer(model, name, edge_index, embed_fn=None):
    """Predict vertices and edge scores for a text ``name`` without gradients.

    Args:
        model: callable as ``model(embedding, edge_index)`` returning
            ``(vertices_pred, edges_pred, mu, logvar)``.
        name: text to embed.
        edge_index: connectivity passed through to the model.
        embed_fn: callable mapping ``name`` to the model's input embedding.
            Defaults to ``embed_name``.

    Returns:
        ``(vertices_pred, edges_pred)``.
    """
    model.eval()
    with torch.no_grad():
        if embed_fn is None:
            # NOTE(review): `embed_name` is not defined anywhere in this
            # notebook; get_text_embedding may have been intended.
            embed_fn = embed_name
        name_embedding = embed_fn(name)
        vertices_pred, edges_pred, _, _ = model(name_embedding, edge_index)
    return vertices_pred, edges_pred
