In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [None]:
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import matplotlib.pyplot as plt

# Load the pretrained Stable Diffusion model
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id)
pipe = pipe.to("cuda")

# Placeholder image encodings. Assumed to be (1, 197, 768) ViT-style token
# embeddings (CLS + 196 patches) — TODO confirm against the real image encoder.
encoded_image1 = torch.randn(1, 197, 768).cuda()
encoded_image2 = torch.randn(1, 197, 768).cuda()

# Concatenate the encoded images along the token axis
combined_encoded_images = torch.cat((encoded_image1, encoded_image2), dim=1)  # shape: (1, 394, 768)

scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
scheduler.set_timesteps(50)
# Already ordered from most to least noisy; iterate it as-is (reversing it
# would denoise backwards).
timesteps = scheduler.timesteps

# Text conditioning: reuse the pipeline's own CLIP tokenizer / text encoder,
# which are the same weights as openai/clip-vit-large-patch14 and are already
# on CUDA (a separately loaded CPU encoder cannot take CUDA input ids).
tokenizer = pipe.tokenizer
text_encoder = pipe.text_encoder

prompt = "A futuristic cityscape"
text_inputs = tokenizer(
    prompt,
    padding="max_length",
    max_length=tokenizer.model_max_length,
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    text_embeddings = text_encoder(text_inputs.input_ids.to("cuda"))[0]  # (1, 77, 768)

# NOTE(review): the UNet denoises 4-channel latents of shape
# (1, in_channels, sample_size, sample_size); the (1, 394, 768) image tokens
# cannot be denoised directly. Their width (768) matches the UNet's
# cross-attention width, so they are appended to the text tokens as extra
# conditioning instead — confirm this is the intended use of the images.
encoder_hidden_states = torch.cat((text_embeddings, combined_encoded_images), dim=1)

# Start from pure Gaussian noise in the UNet's latent space
latents = torch.randn(
    1,
    pipe.unet.config.in_channels,
    pipe.unet.config.sample_size,
    pipe.unet.config.sample_size,
    device="cuda",
) * scheduler.init_noise_sigma

with torch.no_grad():
    for t in timesteps:
        # Predict the noise to be removed
        model_input = scheduler.scale_model_input(latents, t)
        noise_pred = pipe.unet(model_input, t, encoder_hidden_states=encoder_hidden_states).sample

        # Let the scheduler apply the DDIM update x_t -> x_{t-1}
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    # Undo the SD v1 latent scaling factor before decoding.
    # `.sample` is an attribute of the decoder output, not a method.
    final_image = pipe.vae.decode(latents / 0.18215).sample

# Map from [-1, 1] to [0, 1] and convert to an HWC numpy array for plotting
final_image = (final_image / 2 + 0.5).clamp(0, 1).detach().cpu().permute(0, 2, 3, 1)[0].numpy()

# Save or display the image
plt.imshow(final_image)
plt.axis("off")
plt.show()


In [None]:
# Example usage
# NOTE: requires the Transformer-based CustomModelWithGMM class cell (defined
# further down in this notebook) to have been executed first.
input_ch_before = 128
input_ch_after = 128
W = 64
num_heads = 4
num_layers = 2
skips = [0]  # Adding skip connection at the first layer
num_components = 5  # Number of Gaussian components

# The constructor signature is
# (input_ch_before, input_ch_after, W, num_heads, num_layers, skips, num_components);
# the previous call also passed an unused depth `D`, which raised a TypeError.
model = CustomModelWithGMM(input_ch_before, input_ch_after, W, num_heads, num_layers, skips, num_components)
x_before = torch.randn(32, input_ch_before)
x_after = torch.randn(32, input_ch_after)
weights, means, covariances = model(x_before, x_after)

# Draw one sample per batch element from the predicted GMM parameters
sampled_means = model.sample_gmm(weights, means, covariances, num_samples=1)
print(sampled_means.shape)  # Should be (batch_size, num_samples, 3)

In [None]:
# Display `sampled_means` produced by the GMM example cell above.
sampled_means

In [None]:
class CustomModelWithGMM(nn.Module):
    """Predict a 3-D Gaussian mixture from a pair of ("before", "after") feature vectors.

    Each input vector is linearly projected to width ``W``, passed through its
    own stack of Transformer encoder layers as a length-1 sequence, and the two
    encodings are concatenated to predict mixture weights, means and covariances.

    Args:
        input_ch_before: feature size of ``x_before``.
        input_ch_after: feature size of ``x_after``.
        W: model width of the Transformer encoders.
        num_heads: attention heads per encoder layer (must divide ``W``).
        num_layers: encoder layers per stack.
        skips: layer indices before which the projected input is re-added.
        num_components: number of mixture components ``K``.
    """

    def __init__(self, input_ch_before, input_ch_after, W, num_heads, num_layers, skips, num_components):
        super(CustomModelWithGMM, self).__init__()
        self.input_ch_before = input_ch_before
        self.input_ch_after = input_ch_after
        self.W = W
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.skips = skips
        self.num_components = num_components  # Number of Gaussian components in the GMM

        # batch_first=True because forward() feeds (batch, seq_len=1, W). With the
        # default batch_first=False the batch axis was read as the sequence axis,
        # so samples in a batch attended to each other.
        encoder_layer = nn.TransformerEncoderLayer(d_model=W, nhead=num_heads, batch_first=True)
        # TransformerEncoder deep-copies the layer, so the two stacks do not share weights.
        self.transformer_encoder_past = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.transformer_encoder_future = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Initial linear layers to match input dimensions with encoder dimensions
        self.linear_before_to_W = nn.Linear(input_ch_before, W)
        self.linear_after_to_W = nn.Linear(input_ch_after, W)

        # Layers to predict the GMM parameters
        self.fc1 = nn.Linear(2 * W, 512)
        self.fc_weights = nn.Linear(512, num_components)  # Weights of the GMM
        self.fc_means = nn.Linear(512, num_components * 3)  # Means of the GMM (3D)
        self.fc_covariances = nn.Linear(512, num_components * 3 * 3)  # Unconstrained 3x3 per component

    def forward(self, x_before, x_after):
        """Predict GMM parameters.

        Args:
            x_before: (batch, input_ch_before) tensor.
            x_after: (batch, input_ch_after) tensor.

        Returns:
            weights: (batch, K), each row sums to 1.
            means: (batch, K, 3).
            covariances: (batch, K, 3, 3), symmetric positive definite.
        """
        # Transform inputs to match the dimension of the encoder
        x_before_transformed = self.linear_before_to_W(x_before)  # (batch_size, W)
        x_after_transformed = self.linear_after_to_W(x_after)     # (batch_size, W)

        # Add a sequence dimension expected by the Transformer encoder
        x_before = x_before_transformed.unsqueeze(1)  # (batch_size, seq_len=1, W)
        x_after = x_after_transformed.unsqueeze(1)    # (batch_size, seq_len=1, W)

        # Pass through the Transformer encoders with skip connections.
        # NOTE(review): a skip at layer 0 adds the projection to itself (doubling it).
        for i, layer in enumerate(self.transformer_encoder_past.layers):
            if i in self.skips:
                x_before = x_before + x_before_transformed.unsqueeze(1)
            x_before = layer(x_before)

        for i, layer in enumerate(self.transformer_encoder_future.layers):
            if i in self.skips:
                x_after = x_after + x_after_transformed.unsqueeze(1)
            x_after = layer(x_after)

        # Remove the sequence dimension and concatenate: (batch_size, 2 * W)
        combined_features = torch.cat((x_before.squeeze(1), x_after.squeeze(1)), dim=1)

        # Predict the GMM parameters
        x = F.relu(self.fc1(combined_features))

        weights = F.softmax(self.fc_weights(x), dim=1)  # Ensure weights sum to 1
        means = self.fc_means(x).view(-1, self.num_components, 3)
        raw_covariances = self.fc_covariances(x).view(-1, self.num_components, 3, 3)
        covariances = self._to_spd(raw_covariances)

        return weights, means, covariances

    @staticmethod
    def _to_spd(raw):
        """Map unconstrained (..., 3, 3) matrices to SPD matrices as L @ L^T.

        L takes the strictly-lower triangle of ``raw`` and ``exp`` of its
        diagonal; a strictly positive diagonal makes L invertible, hence
        L @ L^T is symmetric positive definite. (Element-wise ``exp`` of a full
        matrix, used previously, is neither symmetric nor guaranteed PD.)
        """
        strictly_lower = torch.tril(raw, diagonal=-1)
        positive_diag = torch.diag_embed(torch.exp(torch.diagonal(raw, dim1=-2, dim2=-1)))
        chol = strictly_lower + positive_diag
        return chol @ chol.transpose(-1, -2)

    def sample_gmm(self, weights, means, covariances, num_samples=1):
        """Draw ``num_samples`` points from each batch element's mixture.

        Args:
            weights: (batch, K) non-negative mixture weights.
            means: (batch, K, 3).
            covariances: (batch, K, 3, 3) SPD covariance matrices.
            num_samples: samples per batch element.

        Returns:
            (batch, num_samples, 3) tensor.
        """
        batch_size = means.size(0)
        component = torch.multinomial(weights, num_samples, replacement=True)  # (batch, S)
        batch_idx = torch.arange(batch_size, device=means.device).unsqueeze(1)
        chosen_means = means[batch_idx, component]              # (batch, S, 3)
        chosen_covariances = covariances[batch_idx, component]  # (batch, S, 3, 3)

        # x = mu + L @ eps with Sigma = L L^T has covariance Sigma; multiplying eps
        # by Sigma itself (as before) would give covariance Sigma @ Sigma^T.
        chol = torch.linalg.cholesky(chosen_covariances)
        eps = torch.randn(batch_size, num_samples, 3, 1, device=means.device, dtype=means.dtype)
        return chosen_means + (chol @ eps).squeeze(-1)

In [None]:
import torch

def normalize_coordinates(coords, grid_size):
    """
    Normalize xyz coordinates to the range [0, grid_size-1].

    Each column is min-max scaled over the batch axis (dim 0), then stretched to
    [0, grid_size - 1]. A column whose values are all equal has zero extent; it
    maps to 0 instead of producing NaN from 0 / 0.
    """
    shifted = coords - coords.min(0, keepdim=True)[0]  # Shift to start from zero
    extent = shifted.max(0, keepdim=True)[0]
    extent = torch.where(extent > 0, extent, torch.ones_like(extent))  # Guard 0 / 0
    return shifted / extent * (grid_size - 1)  # [0, 1] -> [0, grid_size-1]

def coords_to_grid(coords, grid_size):
    """
    Convert normalized xyz coordinates to a 3D grid.

    Args:
        coords: (batch, 3) coordinates already scaled to [0, grid_size - 1]
            (e.g. by ``normalize_coordinates(coords, grid_size)``); values are
            truncated to integer voxel indices.
        grid_size: number of voxels along each axis.

    Returns:
        (batch, 1, grid_size, grid_size, grid_size) float tensor on
        ``coords.device`` with a single 1 per sample at that sample's voxel.
    """
    batch_size = coords.size(0)
    # Allocate on the input's device so the grid can feed a model on that device
    grid = torch.zeros(batch_size, 1, grid_size, grid_size, grid_size, device=coords.device)
    indices = coords.long()
    # One advanced-indexing write instead of a Python loop over the batch
    batch_idx = torch.arange(batch_size, device=coords.device)
    grid[batch_idx, 0, indices[:, 0], indices[:, 1], indices[:, 2]] = 1
    return grid

# Example usage: scatter a batch of random points into one-hot voxel grids
batch_size, grid_size = 32, 32
xyz_coords = torch.randn(batch_size, 3)  # Input data of shape [batch_size, 3]

# Scale the points into grid index space, then voxelise them
normalized_coords = normalize_coordinates(xyz_coords, grid_size)
spatial_data = coords_to_grid(normalized_coords, grid_size)

print(spatial_data.shape)  # Should be [batch_size, 1, grid_size, grid_size, grid_size]


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def normalize_coordinates(coords):
    """
    Normalize xyz coordinates to the range [0, 1].

    Each column is min-max scaled over the batch axis (dim 0). A column whose
    values are all equal has zero extent; it maps to 0 instead of producing
    NaN from 0 / 0.
    """
    shifted = coords - coords.min(0, keepdim=True)[0]  # Shift to start from zero
    extent = shifted.max(0, keepdim=True)[0]
    extent = torch.where(extent > 0, extent, torch.ones_like(extent))  # Guard 0 / 0
    return shifted / extent

def coords_to_grid(coords, grid_size):
    """
    Convert normalized xyz coordinates to a 3D grid.

    Args:
        coords: (batch, 3) coordinates in [0, 1] (e.g. from
            ``normalize_coordinates``); scaled by ``grid_size - 1`` and
            truncated to integer voxel indices.
        grid_size: number of voxels along each axis.

    Returns:
        (batch, 1, grid_size, grid_size, grid_size) float tensor on
        ``coords.device`` with a single 1 per sample at that sample's voxel.
    """
    batch_size = coords.size(0)
    # Allocate on the input's device so a CUDA model receives a CUDA grid
    grid = torch.zeros(batch_size, 1, grid_size, grid_size, grid_size, device=coords.device)
    indices = (coords * (grid_size - 1)).long()
    # One advanced-indexing write instead of a Python loop over the batch
    batch_idx = torch.arange(batch_size, device=coords.device)
    grid[batch_idx, 0, indices[:, 0], indices[:, 1], indices[:, 2]] = 1
    return grid

class CustomModelWithGMM(nn.Module):
    """Predict a 3-D Gaussian mixture from a pair of ("before", "after") point batches.

    Both inputs are min-max normalised, voxelised into one-hot occupancy grids,
    stacked as 2 channels, passed through three 3-D convolutions and a GRU; the
    GRU output predicts mixture weights, means and covariances.

    NOTE(review): the GRU input width is ``256 * grid_size**3`` (about 8.4M for
    grid_size=32), so its input-to-hidden weights alone hold
    ``3 * hidden_dim * 256 * grid_size**3`` parameters — consider pooling first.
    """

    def __init__(self, input_ch_before, input_ch_after, grid_size, hidden_dim, num_layers, num_components):
        super(CustomModelWithGMM, self).__init__()
        self.input_ch_before = input_ch_before
        self.input_ch_after = input_ch_after
        self.grid_size = grid_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_components = num_components  # Number of Gaussian components in the GMM

        # Define 3D convolutional layers to capture spatial features
        self.conv1 = nn.Conv3d(2, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv3d(128, 256, kernel_size=3, padding=1)

        # GRU over the flattened conv features
        self.gru = nn.GRU(256 * grid_size * grid_size * grid_size, hidden_dim, num_layers, batch_first=True)

        # Layers to predict the GMM parameters
        self.fc1 = nn.Linear(2 * hidden_dim, hidden_dim)
        self.fc_weights = nn.Linear(hidden_dim, num_components)  # Weights of the GMM
        self.fc_means = nn.Linear(hidden_dim, num_components * 3)  # Means of the GMM (3D)
        self.fc_covariances = nn.Linear(hidden_dim, num_components * 3 * 3)  # Unconstrained 3x3 per component

    def forward(self, x_before, x_after):
        """Predict GMM parameters from two (batch, 3) coordinate tensors.

        NOTE(review): normalisation takes min/max over the batch axis, so each
        sample's voxel depends on the other samples in the batch.

        Returns:
            weights: (batch, K), each row sums to 1.
            means: (batch, K, 3).
            covariances: (batch, K, 3, 3), symmetric positive definite.
        """
        # Normalize and map to grids; move them to the inputs' device in case the
        # grid helper allocates elsewhere (e.g. on CPU).
        x_before_grid = coords_to_grid(normalize_coordinates(x_before), self.grid_size).to(x_before.device)
        x_after_grid = coords_to_grid(normalize_coordinates(x_after), self.grid_size).to(x_after.device)

        # Concatenate the key frames along the channel dimension
        x = torch.cat([x_before_grid, x_after_grid], dim=1)  # (batch_size, 2, grid_size, grid_size, grid_size)

        # Apply 3D convolutional layers
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Flatten and prepare for GRU
        batch_size = x.size(0)
        x = x.view(batch_size, -1)  # (batch_size, 256 * grid_size**3)

        # NOTE(review): both GRU time steps receive identical features, so the
        # sequence carries no temporal information — confirm intent.
        x_combined = torch.stack((x, x), dim=1)  # (batch_size, 2, flattened_size)
        gru_output, _ = self.gru(x_combined)  # (batch_size, 2, hidden_dim)
        gru_output_flattened = gru_output.reshape(batch_size, -1)  # (batch_size, 2 * hidden_dim)

        # Predict the GMM parameters
        x = F.relu(self.fc1(gru_output_flattened))

        weights = F.softmax(self.fc_weights(x), dim=1)  # Ensure weights sum to 1
        means = self.fc_means(x).view(-1, self.num_components, 3)
        raw_covariances = self.fc_covariances(x).view(-1, self.num_components, 3, 3)
        covariances = self._to_spd(raw_covariances)

        return weights, means, covariances

    @staticmethod
    def _to_spd(raw):
        """Map unconstrained (..., 3, 3) matrices to SPD matrices as L @ L^T.

        L takes the strictly-lower triangle of ``raw`` and ``exp`` of its
        diagonal; a strictly positive diagonal makes L invertible, hence
        L @ L^T is symmetric positive definite. (Element-wise ``exp`` of a full
        matrix, used previously, is neither symmetric nor guaranteed PD.)
        """
        strictly_lower = torch.tril(raw, diagonal=-1)
        positive_diag = torch.diag_embed(torch.exp(torch.diagonal(raw, dim1=-2, dim2=-1)))
        chol = strictly_lower + positive_diag
        return chol @ chol.transpose(-1, -2)

    def sample_gmm(self, weights, means, covariances, num_samples=1):
        """Draw ``num_samples`` points from each batch element's mixture.

        Args:
            weights: (batch, K) non-negative mixture weights.
            means: (batch, K, 3).
            covariances: (batch, K, 3, 3) SPD covariance matrices.
            num_samples: samples per batch element.

        Returns:
            (batch, num_samples, 3) tensor.
        """
        batch_size = means.size(0)
        component = torch.multinomial(weights, num_samples, replacement=True)  # (batch, S)
        batch_idx = torch.arange(batch_size, device=means.device).unsqueeze(1)
        chosen_means = means[batch_idx, component]              # (batch, S, 3)
        chosen_covariances = covariances[batch_idx, component]  # (batch, S, 3, 3)

        # x = mu + L @ eps with Sigma = L L^T has covariance Sigma; multiplying eps
        # by Sigma itself (as before) would give covariance Sigma @ Sigma^T.
        chol = torch.linalg.cholesky(chosen_covariances)
        eps = torch.randn(batch_size, num_samples, 3, 1, device=means.device, dtype=means.dtype)
        return chosen_means + (chol @ eps).squeeze(-1)

# Example usage
# NOTE(review): with grid_size=32 the GRU input width is 256 * 32**3, so this
# model is very large — expect heavy memory use.
input_ch_before, input_ch_after = 3, 3
grid_size = 32
hidden_dim = 64
num_layers = 2
num_components = 5  # Number of Gaussian components

model = CustomModelWithGMM(
    input_ch_before=input_ch_before,
    input_ch_after=input_ch_after,
    grid_size=grid_size,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    num_components=num_components,
)
x_before = torch.randn(32, input_ch_before)  # (batch_size, input_ch_before)
x_after = torch.randn(32, input_ch_after)    # (batch_size, input_ch_after)

weights, means, covariances = model(x_before, x_after)

# Draw 10 points per batch element from the predicted mixtures
sampled_means = model.sample_gmm(weights, means, covariances, num_samples=10)
print(sampled_means.shape)  # Should be (batch_size, num_samples, 3)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def normalize_coordinates(coords):
    """
    Normalize xyz coordinates to the range [0, 1].

    Each column is min-max scaled over the batch axis (dim 0). A column whose
    values are all equal has zero extent; it maps to 0 instead of producing
    NaN from 0 / 0.
    """
    shifted = coords - coords.min(0, keepdim=True)[0]  # Shift to start from zero
    extent = shifted.max(0, keepdim=True)[0]
    extent = torch.where(extent > 0, extent, torch.ones_like(extent))  # Guard 0 / 0
    return shifted / extent

def coords_to_grid(coords, grid_size):
    """
    Convert normalized xyz coordinates to a 3D grid.

    Args:
        coords: (batch, 3) coordinates in [0, 1] (e.g. from
            ``normalize_coordinates``); scaled by ``grid_size - 1`` and
            truncated to integer voxel indices.
        grid_size: number of voxels along each axis.

    Returns:
        (batch, 1, grid_size, grid_size, grid_size) float tensor on
        ``coords.device`` with a single 1 per sample at that sample's voxel.
    """
    batch_size = coords.size(0)
    # Allocate on the input's device so a CUDA model receives a CUDA grid
    grid = torch.zeros(batch_size, 1, grid_size, grid_size, grid_size, device=coords.device)
    indices = (coords * (grid_size - 1)).long()
    # One advanced-indexing write instead of a Python loop over the batch
    batch_idx = torch.arange(batch_size, device=coords.device)
    grid[batch_idx, 0, indices[:, 0], indices[:, 1], indices[:, 2]] = 1
    return grid

class CustomModelWithGMM(nn.Module):
    """Predict a 3-D Gaussian mixture from a pair of ("before", "after") point batches.

    Both inputs are min-max normalised, voxelised into one-hot occupancy grids,
    stacked as 2 channels, passed through three 3-D convolutions and a GRU; the
    GRU output predicts mixture weights, means and covariances.

    NOTE(review): the GRU input width is ``256 * grid_size**3`` (about 8.4M for
    grid_size=32), so its input-to-hidden weights alone hold
    ``3 * hidden_dim * 256 * grid_size**3`` parameters — consider pooling first.
    """

    def __init__(self, input_ch_before, input_ch_after, grid_size, hidden_dim, num_layers, num_components):
        super(CustomModelWithGMM, self).__init__()
        self.input_ch_before = input_ch_before
        self.input_ch_after = input_ch_after
        self.grid_size = grid_size
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.num_components = num_components  # Number of Gaussian components in the GMM

        # Define 3D convolutional layers to capture spatial features
        self.conv1 = nn.Conv3d(2, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(64, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv3d(128, 256, kernel_size=3, padding=1)

        # GRU over the flattened conv features
        self.gru = nn.GRU(256 * grid_size * grid_size * grid_size, hidden_dim, num_layers, batch_first=True)

        # Layers to predict the GMM parameters
        self.fc1 = nn.Linear(2 * hidden_dim, hidden_dim)
        self.fc_weights = nn.Linear(hidden_dim, num_components)  # Weights of the GMM
        self.fc_means = nn.Linear(hidden_dim, num_components * 3)  # Means of the GMM (3D)
        self.fc_covariances = nn.Linear(hidden_dim, num_components * 3 * 3)  # Unconstrained 3x3 per component

    def forward(self, x_before, x_after):
        """Predict GMM parameters from two (batch, 3) coordinate tensors.

        NOTE(review): normalisation takes min/max over the batch axis, so each
        sample's voxel depends on the other samples in the batch.

        Returns:
            weights: (batch, K), each row sums to 1.
            means: (batch, K, 3).
            covariances: (batch, K, 3, 3), symmetric positive definite.
        """
        # Normalize and map to grids; move them to the inputs' device in case the
        # grid helper allocates elsewhere (e.g. on CPU while the model is on CUDA).
        x_before_grid = coords_to_grid(normalize_coordinates(x_before), self.grid_size).to(x_before.device)
        x_after_grid = coords_to_grid(normalize_coordinates(x_after), self.grid_size).to(x_after.device)

        # Concatenate the key frames along the channel dimension
        x = torch.cat([x_before_grid, x_after_grid], dim=1)  # (batch_size, 2, grid_size, grid_size, grid_size)

        # Apply 3D convolutional layers
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Flatten and prepare for GRU
        batch_size = x.size(0)
        x = x.view(batch_size, -1)  # (batch_size, 256 * grid_size**3)

        # NOTE(review): both GRU time steps receive identical features, so the
        # sequence carries no temporal information — confirm intent.
        x_combined = torch.stack((x, x), dim=1)  # (batch_size, 2, flattened_size)
        gru_output, _ = self.gru(x_combined)  # (batch_size, 2, hidden_dim)
        gru_output_flattened = gru_output.reshape(batch_size, -1)  # (batch_size, 2 * hidden_dim)

        # Predict the GMM parameters
        x = F.relu(self.fc1(gru_output_flattened))

        weights = F.softmax(self.fc_weights(x), dim=1)  # Ensure weights sum to 1
        means = self.fc_means(x).view(-1, self.num_components, 3)
        raw_covariances = self.fc_covariances(x).view(-1, self.num_components, 3, 3)
        covariances = self._to_spd(raw_covariances)

        return weights, means, covariances

    @staticmethod
    def _to_spd(raw):
        """Map unconstrained (..., 3, 3) matrices to SPD matrices as L @ L^T.

        L takes the strictly-lower triangle of ``raw`` and ``exp`` of its
        diagonal; a strictly positive diagonal makes L invertible, hence
        L @ L^T is symmetric positive definite. (Element-wise ``exp`` of a full
        matrix, used previously, is neither symmetric nor guaranteed PD.)
        """
        strictly_lower = torch.tril(raw, diagonal=-1)
        positive_diag = torch.diag_embed(torch.exp(torch.diagonal(raw, dim1=-2, dim2=-1)))
        chol = strictly_lower + positive_diag
        return chol @ chol.transpose(-1, -2)

    def sample_gmm(self, weights, means, covariances, num_samples=1):
        """Draw ``num_samples`` points from each batch element's mixture.

        Args:
            weights: (batch, K) non-negative mixture weights.
            means: (batch, K, 3).
            covariances: (batch, K, 3, 3) SPD covariance matrices.
            num_samples: samples per batch element.

        Returns:
            (batch, num_samples, 3) tensor.
        """
        batch_size = means.size(0)
        component = torch.multinomial(weights, num_samples, replacement=True)  # (batch, S)
        batch_idx = torch.arange(batch_size, device=means.device).unsqueeze(1)
        chosen_means = means[batch_idx, component]              # (batch, S, 3)
        chosen_covariances = covariances[batch_idx, component]  # (batch, S, 3, 3)

        # x = mu + L @ eps with Sigma = L L^T has covariance Sigma; multiplying eps
        # by Sigma itself (as before) would give covariance Sigma @ Sigma^T.
        chol = torch.linalg.cholesky(chosen_covariances)
        eps = torch.randn(batch_size, num_samples, 3, 1, device=means.device, dtype=means.dtype)
        return chosen_means + (chol @ eps).squeeze(-1)

# Example usage: run the voxel/GRU GMM model on the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

input_ch_before, input_ch_after = 3, 3
grid_size = 32
hidden_dim = 64
num_layers = 2
num_components = 5  # Number of Gaussian components

model = CustomModelWithGMM(
    input_ch_before=input_ch_before,
    input_ch_after=input_ch_after,
    grid_size=grid_size,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
    num_components=num_components,
).to(device)
x_before = torch.randn(32, input_ch_before).to(device)  # (batch_size, input_ch_before)
x_after = torch.randn(32, input_ch_after).to(device)    # (batch_size, input_ch_after)

weights, means, covariances = model(x_before, x_after)

# Draw 10 points per batch element from the predicted mixtures
sampled_means = model.sample_gmm(weights, means, covariances, num_samples=10)
print(sampled_means.shape)  # Should be (batch_size, num_samples, 3)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention3D(nn.Module):
    """Multi-head self-attention across the frame axis of stacked latent vectors.

    Input and output are ``(batch, num_frames, latent_dim)``; the VAE in this
    notebook stacks two frames ("before" and "after").
    """

    def __init__(self, latent_dim, num_heads):
        super(SelfAttention3D, self).__init__()
        self.num_heads = num_heads
        self.head_dim = latent_dim // num_heads
        assert self.head_dim * num_heads == latent_dim, "latent_dim must be divisible by num_heads"

        self.query = nn.Linear(latent_dim, latent_dim)
        self.key = nn.Linear(latent_dim, latent_dim)
        self.value = nn.Linear(latent_dim, latent_dim)
        self.fc_out = nn.Linear(latent_dim, latent_dim)

        # sqrt(head_dim) as a Python float. A tensor attribute is not a registered
        # buffer, so it would not follow model.to(device) and could end up on a
        # different device than the activations.
        self.scale = self.head_dim ** 0.5

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq, latent_dim) -> (batch, num_heads, seq, head_dim)."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, x):
        """Apply self-attention over the sequence (frame) axis.

        Args:
            x: (batch, seq_len, latent_dim) tensor.

        Returns:
            (batch, seq_len, latent_dim) tensor.
        """
        batch_size, seq_len, _ = x.shape

        Q = self._split_heads(self.query(x), batch_size, seq_len)
        K = self._split_heads(self.key(x), batch_size, seq_len)
        V = self._split_heads(self.value(x), batch_size, seq_len)

        # Scaled dot-product scores: (batch, num_heads, seq_len, seq_len)
        energy = torch.einsum("bnqd,bnkd->bnqk", Q, K) / self.scale
        attention = torch.softmax(energy, dim=-1)

        # Weight values by attention over the SAME key index k. The previous
        # "bnqk,bnvd->bnqd" used independent k/v indices, which summed attention
        # and values separately and so ignored the attention weights entirely.
        out = torch.einsum("bnqk,bnkd->bnqd", attention, V)  # (batch, num_heads, seq_len, head_dim)

        # Merge heads back to (batch, seq_len, latent_dim)
        out = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.num_heads * self.head_dim)

        return self.fc_out(out)
    
class Encoder(nn.Module):
    """Convolutional encoder producing Gaussian posterior parameters per image.

    Three stride-2 4x4 convolutions widen the channels to ``hidden_dim * 4``;
    adaptive average pooling then fixes the map at 4x4, so the flattened size
    (``hidden_dim * 4 * 4 * 4``) does not depend on the input resolution.
    """

    def __init__(self, input_channels, hidden_dim, latent_dim):
        super(Encoder, self).__init__()
        widths = (input_channels, hidden_dim, hidden_dim * 2, hidden_dim * 4)
        self.conv1 = nn.Conv2d(widths[0], widths[1], kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(widths[1], widths[2], kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(widths[2], widths[3], kernel_size=4, stride=2, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        flat_features = widths[3] * 4 * 4
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        """Return ``(mu, logvar)``, each of shape (batch, latent_dim)."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))
        features = torch.flatten(self.pool(x), start_dim=1)
        return self.fc_mu(features), self.fc_logvar(features)

class Decoder(nn.Module):
    """Decode fused frame latents into an image of a requested size.

    The flattened input (``2 * latent_dim`` features) is projected to a
    ``latent_dim x 16 x 16`` map. Three stride-2 transposed convolutions upsample
    it 8x to 128x128, and a 1x1 transposed convolution with sigmoid maps it to
    ``output_channels`` in [0, 1]. The result is bilinearly resized to
    ``(z_w, z_h)``.
    """

    def __init__(self, latent_dim, hidden_dim, output_channels):
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        self.fc = nn.Linear(2 * latent_dim, latent_dim * 16 * 16)  # Flatten the input and expand
        self.deconv1 = nn.ConvTranspose2d(latent_dim, hidden_dim*2, kernel_size=4, stride=2, padding=1)  # 16x16 -> 32x32
        self.deconv2 = nn.ConvTranspose2d(hidden_dim*2, hidden_dim, kernel_size=4, stride=2, padding=1)   # 32x32 -> 64x64
        self.deconv3 = nn.ConvTranspose2d(hidden_dim, hidden_dim // 2, kernel_size=4, stride=2, padding=1)   # 64x64 -> 128x128
        self.deconv4 = nn.ConvTranspose2d(hidden_dim // 2, output_channels, kernel_size=1)                   # channel projection

    def forward(self, x, z_w, z_h):
        """Decode latents.

        Args:
            x: tensor with ``2 * latent_dim`` elements per batch item,
                e.g. (batch, 2, latent_dim).
            z_w, z_h: output spatial size, passed to ``F.interpolate`` as
                ``size=(z_w, z_h)`` (i.e. height then width).

        Returns:
            (batch, output_channels, z_w, z_h) tensor with values in [0, 1].
        """
        x = x.reshape(x.size(0), -1)  # (batch, 2 * latent_dim)
        x = F.relu(self.fc(x))        # (batch, latent_dim * 16 * 16)
        # Use the configured latent_dim; the previous hard-coded 128 only worked
        # when latent_dim == 128.
        x = x.view(x.size(0), self.latent_dim, 16, 16)
        x = F.relu(self.deconv1(x))
        x = F.relu(self.deconv2(x))
        x = F.relu(self.deconv3(x))
        x = torch.sigmoid(self.deconv4(x))  # (batch, output_channels, 128, 128)
        x = F.interpolate(x, size=(z_w, z_h), mode='bilinear', align_corners=False)
        return x

class VAE(nn.Module):
    """Two-frame VAE: encode "before"/"after" images, fuse latents with self-attention, decode.

    ``forward`` returns only the decoded image; the per-frame ``mu``/``logvar``
    are not returned, so a KL term cannot be computed from its output alone.
    """

    def __init__(self, input_channels, hidden_dim, latent_dim, output_channels, num_heads):
        super(VAE, self).__init__()
        self.encoder_before = Encoder(input_channels, hidden_dim, latent_dim)
        self.encoder_after = Encoder(input_channels, hidden_dim, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dim, output_channels)
        # Attribute name kept as-is so existing state_dict keys still load.
        self.SelfAttention3D = SelfAttention3D(latent_dim, num_heads)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x_before, x_after):
        """Reconstruct an image from a pair of frames.

        Args:
            x_before: (batch, C, H, W) image batch.
            x_after: image batch for the second frame.

        Returns:
            Decoded image resized to x_before's (H, W).
        """
        # dims 2 and 3 of an NCHW batch are height and width
        height = x_before.size()[2]
        width = x_before.size()[3]

        mu_before, logvar_before = self.encoder_before(x_before)
        z_before = self.reparameterize(mu_before, logvar_before)  # (batch, latent_dim)
        mu_after, logvar_after = self.encoder_after(x_after)
        z_after = self.reparameterize(mu_after, logvar_after)     # (batch, latent_dim)

        # Stack the two frames along a new frame axis: (batch, 2, latent_dim)
        z = torch.stack([z_before, z_after], dim=1)
        z = self.SelfAttention3D(z)
        return self.decoder(z, height, width)

# Example usage
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

input_channels = 3
hidden_dim = 32
latent_dim = 128
output_channels = 3
num_heads = 8

model = VAE(input_channels, hidden_dim, latent_dim, output_channels, num_heads).to(device)

# Dummy "before"/"after" frames: batch of 8 RGB images, 240 (H) x 480 (W)
input_data1 = torch.randn(8, 3, 240, 480, device=device)
input_data2 = torch.randn(8, 3, 240, 480, device=device)

output1 = model(input_data1, input_data2)

print(output1.shape)  # should match input_data1 shape: torch.Size([8, 3, 240, 480])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention3D(nn.Module):
    """Multi-head self-attention across the frame axis of stacked latent vectors.

    Input and output are ``(batch, num_frames, latent_dim)``; the VAE in this
    notebook stacks two frames ("before" and "after").
    """

    def __init__(self, latent_dim, num_heads):
        super(SelfAttention3D, self).__init__()
        self.num_heads = num_heads
        self.head_dim = latent_dim // num_heads
        assert self.head_dim * num_heads == latent_dim, "latent_dim must be divisible by num_heads"

        self.query = nn.Linear(latent_dim, latent_dim)
        self.key = nn.Linear(latent_dim, latent_dim)
        self.value = nn.Linear(latent_dim, latent_dim)
        self.fc_out = nn.Linear(latent_dim, latent_dim)

        # sqrt(head_dim) as a Python float. A tensor attribute is not a registered
        # buffer, so it would not follow model.to(device) and could end up on a
        # different device than the activations.
        self.scale = self.head_dim ** 0.5

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq, latent_dim) -> (batch, num_heads, seq, head_dim)."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

    def forward(self, x):
        """Apply self-attention over the sequence (frame) axis.

        Args:
            x: (batch, seq_len, latent_dim) tensor.

        Returns:
            (batch, seq_len, latent_dim) tensor.
        """
        batch_size, seq_len, _ = x.shape

        Q = self._split_heads(self.query(x), batch_size, seq_len)
        K = self._split_heads(self.key(x), batch_size, seq_len)
        V = self._split_heads(self.value(x), batch_size, seq_len)

        # Scaled dot-product scores: (batch, num_heads, seq_len, seq_len)
        energy = torch.einsum("bnqd,bnkd->bnqk", Q, K) / self.scale
        attention = torch.softmax(energy, dim=-1)

        # Weight values by attention over the SAME key index k. The previous
        # "bnqk,bnvd->bnqd" used independent k/v indices, which summed attention
        # and values separately and so ignored the attention weights entirely.
        out = torch.einsum("bnqk,bnkd->bnqd", attention, V)  # (batch, num_heads, seq_len, head_dim)

        # Merge heads back to (batch, seq_len, latent_dim)
        out = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, self.num_heads * self.head_dim)

        return self.fc_out(out)

# Example usage: attend across a stack of two latent vectors per sample
batch_size, latent_dim, num_heads = 8, 128, 8

self_attention = SelfAttention3D(latent_dim, num_heads)
input_tensor = torch.randn(batch_size, 2, latent_dim)

output_tensor = self_attention(input_tensor)
print(output_tensor.shape)  # should be [batch_size, 2, latent_dim]


In [None]:
# One random 128-dimensional latent vector with a leading batch axis: (1, 128)
latent_data = torch.randn(1, 128)
latent_data.shape

In [None]:
num_repeats = 1_000  # number of extra copies of `latent_data` to tile

In [None]:
# Tile `latent_data` row-wise num_repeats + 1 times (a copy, not a view)
expanded_features_before = latent_data.repeat((num_repeats + 1, 1))

In [None]:
# Display the tiled tensor's shape
expanded_features_before.shape

In [None]:
import torch
from PIL import Image
from torchvision import transforms

# Load the image
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipe = pipe.to("cuda")

# Dummy RGB init image of shape (3, H, W). torch.rand keeps values in [0, 1],
# the range diffusers' image processor expects for tensor inputs; torch.randn
# produced negative values, which diffusers routes through a deprecated
# "[-1, 1] tensor" code path.
init_image = torch.rand(3, 270, 480)
prompt = "a photo of an astronaut riding a horse on mars"
result = pipe(prompt, image=init_image)

# First generated image (a PIL image with the default output_type)
image = result.images[0]

# Define a transform to convert the image to a tensor
transform = transforms.ToTensor()

# Apply the transform to the image
image_tensor = transform(image)

# Print the shape and value range instead of dumping the whole tensor
print(f"Image tensor shape: {image_tensor.shape}")
print(f"Image tensor range: [{image_tensor.min().item():.3f}, {image_tensor.max().item():.3f}]")


In [None]:
# Display `image`, the first output of the SDXL img2img cell above.
image

In [None]:
# NOTE(review): `inputs` is not assigned in any cell visible here (possibly a deleted
# CLIPProcessor call); on a fresh kernel this likely raises NameError.
inputs.pixel_values.size()

In [None]:
# NOTE(review): `image_embeddings` is not assigned in any cell visible here (the CLIP cell
# defines `image_embedding_1`, the two-image cell `image_embedding`); likely NameError on a fresh kernel.
image_embeddings.size()

In [None]:
# NOTE(review): `image_embeddings` is not assigned in any cell visible here; likely NameError
# on a fresh kernel — confirm which embedding tensor was meant.
image_embeddings

In [None]:
# Load the first reference render as a PIL image and display it.
# NOTE(review): hard-coded absolute path — consider a configurable data directory.
init_image = Image.open("/home/ad/20813716/Deformable-3D-Gaussians/output/exp_nerfds_press-2/train/ours_40000/renders/00000.png")
init_image

#### Pipeline

In [9]:
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import torch

# Load the CLIP processor and model
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

# Dummy reference frames standing in for the renders
# .../renders/00000.png and .../renders/00005.png.
# The vision tower expects pixel_values of shape (batch, 3, image_size, image_size),
# with image_size = 224 for this checkpoint. The previous (3, 244) and (3, 244, 244)
# tensors did not match that.
# For real images use:
#   processor(images=Image.open(path), return_tensors="pt").pixel_values
image_size = model.config.vision_config.image_size
image_ref_1 = torch.randn(1, 3, image_size, image_size)
image_ref_2 = torch.randn(1, 3, image_size, image_size)

# Get image embeddings and pooled output for both frames
with torch.no_grad():
    # Image 1
    output_1 = model.vision_model(pixel_values=image_ref_1)
    image_embedding_1 = output_1.last_hidden_state  # Token-level embeddings
    pooled_image_embed_1 = output_1.pooler_output  # Pooled embedding
    # Image 2 (needed by the two-image conditioning cell below)
    output_2 = model.vision_model(pixel_values=image_ref_2)
    image_embedding_2 = output_2.last_hidden_state  # Token-level embeddings
    pooled_image_embed_2 = output_2.pooler_output  # Pooled embedding

# Print the shapes of the embeddings
print(f"Token-level embeddings shape: {image_embedding_1.shape}")
print(f"Pooled image embeddings shape: {pooled_image_embed_1.shape}")
print(f"Token-level embeddings shape: {image_embedding_2.shape}")
print(f"Pooled image embeddings shape: {pooled_image_embed_2.shape}")



  from .autonotebook import tqdm as notebook_tqdm


Token-level embeddings shape: torch.Size([1, 50, 768])
Pooled image embeddings shape: torch.Size([1, 768])


In [None]:
#input_1['pixel_values'].size()

In [None]:
### CONDITIONING ON 2 IMAGES
# Fuse the CLIP embeddings of both reference images into prompt_embeds-shaped
# [1, TARGET_SEQ_LEN, TARGET_EMBED_DIM] and pooled [1, TARGET_EMBED_DIM] tensors.
# NOTE(review): both nn.Linear projections are freshly (randomly) initialised on
# every run and never trained, so the outputs change from run to run.

import torch.nn.functional as F

TARGET_SEQ_LEN = 77       # token count of the [1, 77, 1280] target
TARGET_EMBED_DIM = 1280   # feature width of the target

image_embedding = torch.cat([image_embedding_1, image_embedding_2], dim=1)  # tokens of both images
print(image_embedding.shape)
# F.interpolate resamples the LAST axis, so move tokens there, resample, move back.
interpolated_tensor = F.interpolate(
    image_embedding.permute(0, 2, 1), size=TARGET_SEQ_LEN, mode='linear', align_corners=True
).permute(0, 2, 1)  # [1, TARGET_SEQ_LEN, hidden]
print(interpolated_tensor.shape)
# Project the feature axis to TARGET_EMBED_DIM (in_features taken from the data).
linear_layer = torch.nn.Linear(interpolated_tensor.shape[-1], TARGET_EMBED_DIM)
final_image_embedding = linear_layer(interpolated_tensor)  # [1, TARGET_SEQ_LEN, TARGET_EMBED_DIM]

# Check the shape of the final tensor
print(final_image_embedding.shape)


# Pooled vectors are concatenated feature-wise, so in_features is the sum of both widths.
pooled_image_embedding = torch.cat([pooled_image_embed_1, pooled_image_embed_2], dim=1)

linear_pooled_prompt = torch.nn.Linear(pooled_image_embedding.shape[-1], TARGET_EMBED_DIM)

final_pooled_image_embedding = linear_pooled_prompt(pooled_image_embedding)  # [1, TARGET_EMBED_DIM]

print(final_pooled_image_embedding.shape)

In [None]:
import torch.nn as nn
### ONLY CONDITIONING ON 1 IMAGE
# Pad image 1's token embeddings to TARGET_SEQ_LEN tokens and project features to
# TARGET_EMBED_DIM. NOTE(review): the projections are untrained random layers.
TARGET_SEQ_LEN = 77
TARGET_EMBED_DIM = 1280

## final_image_embedding
num_pad_tokens = TARGET_SEQ_LEN - image_embedding_1.shape[1]
if num_pad_tokens < 0:
    raise ValueError(
        f"image_embedding_1 has {image_embedding_1.shape[1]} tokens; at most {TARGET_SEQ_LEN} are supported"
    )
# F.pad pads from the last axis backwards: (feat_left, feat_right, tok_left, tok_right);
# zeros are appended after the existing tokens.
final_image_embedding = torch.nn.functional.pad(image_embedding_1, (0, 0, 0, num_pad_tokens))
linear_layer = nn.Linear(final_image_embedding.shape[-1], TARGET_EMBED_DIM)
# Apply the linear layer to transform the feature dimension
final_image_embedding = linear_layer(final_image_embedding)

## final_pooled_image_embedding
pooled_linear_layer = nn.Linear(pooled_image_embed_1.shape[-1], TARGET_EMBED_DIM)
final_pooled_image_embedding = pooled_linear_layer(pooled_image_embed_1)

print(final_image_embedding.size())
print(final_pooled_image_embedding.size())

In [None]:
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image
from PIL import Image

# Img2img with the SDXL refiner, conditioned on the projected image embeddings
# (prompt_embeds / pooled_prompt_embeds) instead of a text prompt; returns latents.

PIPE_DTYPE = torch.float16
# NOTE(review): absolute, machine-specific path — consider a configurable data dir.
INIT_IMAGE_PATH = '/home/ad/20813716/Deformable-3D-Gaussians/output/exp_nerfds_press-2/train/ours_40000/renders/00000.png'

pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=PIPE_DTYPE
)
pipe = pipe.to("cuda")

# The embeddings come from nn.Linear projections built in an earlier cell (float32,
# tracking gradients). Detach them and match the pipeline's device and fp16 dtype
# explicitly instead of relying on implicit casts inside the pipeline.
final_image_embedding = final_image_embedding.detach().to(device="cuda", dtype=PIPE_DTYPE)
final_pooled_image_embedding = final_pooled_image_embedding.detach().to(device="cuda", dtype=PIPE_DTYPE)

init_image = Image.open(INIT_IMAGE_PATH).convert("RGB")

image = pipe(
    prompt_embeds=final_image_embedding,
    pooled_prompt_embeds=final_pooled_image_embedding,
    image=init_image,
    output_type="latent",
    num_inference_steps=10,
)

In [None]:
# Shape of the returned latent batch.
image.images.shape

In [None]:
# Display the raw latents returned above (output_type="latent", so nothing is decoded to pixels).
image.images

In [None]:
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image

# Reference run: text-prompted img2img on the diffusers example image,
# keeping the output as latents instead of decoded pixels.
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    torch_dtype=torch.float16,
).to("cuda")

url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
init_image = load_image(url).convert("RGB")
prompt = "a photo of an astronaut riding a horse on mars"

image = pipe(prompt, image=init_image, output_type="latent")

In [None]:
# Shape of the returned latent batch.
image.images.shape

In [None]:
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image
from PIL import Image


pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

# Random [C, H, W] placeholder for the render (see the commented path).
# diffusers expects tensor images in [0, 1]; torch.randn is unbounded (and
# negative values route through the processor's legacy [-1, 1] handling), so
# draw the placeholder uniformly from [0, 1) with torch.rand instead.
init_image = torch.rand(3, 270, 480)  # Image.open('/home/ad/20813716/Deformable-3D-Gaussians/output/exp_nerfds_press-2/train/ours_40000/renders/00000.png').convert("RGB")
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt, image=init_image, output_type="latent", num_inference_steps=2)

In [None]:
# Shape of the returned latent batch.
image.images.shape

In [None]:
# Shape of the first latent in the batch.
image.images[0].shape

In [None]:
# Keep the first latent of the pipeline output (presumably the only one for a single prompt).
latent_data = image.images[0]

In [None]:
# Shape of the selected latent.
latent_data.shape

In [None]:
# Fill a [TARGET_ROWS, TARGET_COLS] tensor from the latent's values.
# Smaller latents are tiled cyclically (values repeated, NOT zero-padded);
# larger ones are truncated.
TARGET_ROWS, TARGET_COLS = 50000, 64

# Flatten the tensor to 1D
flattened_tensor = latent_data.flatten()

# Calculate the number of elements needed for the target shape
num_elements_needed = TARGET_ROWS * TARGET_COLS

if flattened_tensor.numel() == 0:
    raise ValueError("latent_data is empty; cannot tile it to the target shape")

if flattened_tensor.numel() < num_elements_needed:
    # Ceil division: just enough copies to cover the target (no spare copy).
    repeats = -(-num_elements_needed // flattened_tensor.numel())
    padded_tensor = flattened_tensor.repeat(repeats)[:num_elements_needed]
else:
    padded_tensor = flattened_tensor[:num_elements_needed]

# Reshape to the desired shape [TARGET_ROWS, TARGET_COLS]
reshaped_tensor = padded_tensor.reshape(TARGET_ROWS, TARGET_COLS)

In [None]:
# Confirm the tiled tensor's shape.
reshaped_tensor.shape

In [None]:
### Saving tensor

In [None]:
import torch

# Ten independent 3x3 uniform-random tensors, keyed 'tensor_0' ... 'tensor_9'.
tensor_dict = {f'tensor_{i}': torch.rand(3, 3) for i in range(10)}

# Show every stored tensor next to its key.
for key, tensor in tensor_dict.items():
    print(f"{key}: {tensor}")


In [None]:
# Look up one stored tensor by key (plain dict indexing: raises KeyError if absent).
tensor_dict['tensor_0']

In [None]:
import math

import torch
import torch.nn as nn

class VAEDecoder(nn.Module):
    """Decode a latent into a [batch, num_points, point_dim] tensor.

    Pipeline: flatten -> linear layer -> reshape to a
    (base_channels, base_size, base_size) map -> ``num_upsamples`` stride-2
    transposed convolutions (channels halve each step, the last one emits a
    single channel) -> flatten -> adaptive average pooling to
    ``num_points * point_dim`` values -> reshape.

    The defaults keep the original layout: a (4, 33, 60) latent, a 1024x8x8
    seed map upsampled 8 times to 1x2048x2048, and a [50000, 64] output.

    NOTE: the original final stage was ``nn.Linear(2048 * 2048, 50000 * 64)``,
    about 1.3e13 weights, which cannot be allocated. It is replaced by
    parameter-free adaptive average pooling over the flattened map.

    Args:
        latent_shape: per-sample latent shape; only its element count is used.
        num_points, point_dim: output is [batch, num_points, point_dim].
        base_channels, base_size: channels and side of the seed feature map.
        num_upsamples: number of doubling transposed-conv stages (>= 1).
    """

    def __init__(self, latent_shape=(4, 33, 60), num_points=50000, point_dim=64,
                 base_channels=1024, base_size=8, num_upsamples=8):
        super(VAEDecoder, self).__init__()
        self.base_channels = base_channels
        self.base_size = base_size
        self.num_points = num_points
        self.point_dim = point_dim

        self.fc = nn.Linear(math.prod(latent_shape), base_channels * base_size * base_size)

        plan = self._channel_plan(base_channels, num_upsamples)
        layers = []
        for idx, (c_in, c_out) in enumerate(plan):
            # kernel 4 / stride 2 / padding 1 doubles H and W
            layers.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            if idx < len(plan) - 1:  # no activation after the final single-channel layer
                layers.append(nn.ReLU(True))
        self.deconv_layers = nn.Sequential(*layers)

        # Maps the flattened [S*S] map to exactly num_points * point_dim values.
        self.final_pool = nn.AdaptiveAvgPool1d(num_points * point_dim)

    @staticmethod
    def _channel_plan(base_channels, num_upsamples):
        """(in_channels, out_channels) per transposed conv: halve each step, end at 1."""
        if num_upsamples < 1:
            raise ValueError("num_upsamples must be >= 1")
        channels = [max(base_channels >> i, 1) for i in range(num_upsamples)] + [1]
        return list(zip(channels[:-1], channels[1:]))

    def forward(self, x):
        """x: [B, *latent_shape] -> [B, num_points, point_dim]."""
        batch_size = x.size(0)
        x = x.reshape(batch_size, -1)  # Flatten the latent
        x = self.fc(x)
        x = x.view(batch_size, self.base_channels, self.base_size, self.base_size)
        x = self.deconv_layers(x)  # [B, 1, S, S] with S = base_size * 2**num_upsamples
        x = x.reshape(batch_size, 1, -1)  # [B, 1, S*S]: 1-D signal for pooling
        x = self.final_pool(x)  # [B, 1, num_points * point_dim]
        return x.view(batch_size, self.num_points, self.point_dim)

# Example usage
# NOTE(review): with the default sizes the decoder is large. `fc` alone maps
# 4*33*60 = 7920 inputs to 1024*8*8 = 65536 outputs (~519M weights, ~2 GB in
# float32), so this demo needs a lot of memory.
latent = torch.randn(1, 4, 33, 60)  # Batch size of 1 for simplicity
decoder = VAEDecoder()
output = decoder(latent)
print(output.shape)  # Should print torch.Size([1, 50000, 64])


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention3D(nn.Module):
    """Self-attention across all D*H*W positions of a [B, C, D, H, W] map.

    Query and key use max(1, C // 8) channels; value keeps C. The attended
    result is added to the input scaled by a learnable scalar ``gamma``, which
    starts at zero, so a freshly built module returns its input unchanged.
    """

    def __init__(self, in_channels):
        super(SelfAttention3D, self).__init__()
        self.query_conv = nn.Conv3d(in_channels, max(1, in_channels // 8), kernel_size=1)
        self.key_conv = nn.Conv3d(in_channels, max(1, in_channels // 8), kernel_size=1)
        self.value_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, D, H, W = x.size()
        query = self.query_conv(x).view(batch_size, -1, D * H * W).permute(0, 2, 1)  # B, D*H*W, C//8
        key = self.key_conv(x).view(batch_size, -1, D * H * W)  # B, C//8, D*H*W
        value = self.value_conv(x).view(batch_size, -1, D * H * W)  # B, C, D*H*W

        attention = torch.bmm(query, key)  # B, D*H*W, D*H*W
        attention = F.softmax(attention, dim=-1)  # each query row sums to 1

        out = torch.bmm(value, attention.permute(0, 2, 1))  # B, C, D*H*W
        out = out.view(batch_size, C, D, H, W)

        out = self.gamma * out + x  # residual blend
        return out

class ConvDecoderSelfAttention(nn.Module):
    """Fuse a stack of [C, H, W] maps into a single [C, *output_size] map.

    Each input map becomes a depth-1 volume. H and W are downsampled by two
    stride-2 convolutions, followed by SelfAttention3D and a 1x1x1 conv. The
    result is resized to (1, *output_size) by adaptive average pooling, which
    also upsamples when output_size is larger. Finally the maps are averaged
    over the batch axis.

    Args:
        in_channels: channel count C of the input maps.
        output_size: (H_out, W_out) of the result; defaults to (33, 60).
    """

    def __init__(self, in_channels, output_size=(33, 60)):
        super(ConvDecoderSelfAttention, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.ReLU(True),
            nn.Conv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
            nn.ReLU(True)
        )
        self.self_attention = SelfAttention3D(in_channels)
        self.final_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
        self.global_pool = nn.AdaptiveAvgPool3d((1, *output_size))

    def forward(self, x):
        """x: [B, C, H, W] -> [C, H_out, W_out] (the batch axis is averaged away)."""
        x = x.unsqueeze(2)  # [B, C, 1, H, W]: depth-1 volume for the 3-D layers
        x = self.conv_layers(x)  # H and W roughly halved twice
        x = self.self_attention(x)
        x = self.final_conv(x)
        x = self.global_pool(x)  # spatial resize to (1, H_out, W_out); batch untouched
        # Mean over the batch, then drop the depth axis and the now-singleton batch axis.
        return x.mean(dim=0, keepdim=True).squeeze(2).squeeze(0)

# Example tensor with size [100, 4, 33, 60]: a stack of 100 maps to fuse.
input_tensor = torch.randn(100, 4, 33, 60)

# Instantiate the model
model = ConvDecoderSelfAttention(in_channels=4)

# Apply the model to the input tensor
output = model(input_tensor)

# forward() averages over the batch and then squeezes that axis away,
# so this prints torch.Size([4, 33, 60]) (not [1, 4, 33, 60]).
print(output.shape)


torch.Size([4, 33, 60])


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# A batch of 100 maps shaped [4, 33, 60]; .dim() reports the tensor's rank.
input_tensor = torch.randn(100, 4, 33, 60)
input_tensor.dim()

4

In [8]:
# A tall column of 10000 values; selecting one row leaves a 1-element vector.
input_tensor = torch.randn(10000, 1)
input_tensor[1].shape

torch.Size([1])

In [6]:
import torch
import torch.nn as nn

class CameraPoseEmbedding(nn.Module):
    """Embed a camera pose as a [1, num_tokens, output_dim] sequence.

    The pose vector concatenates the camera centre (3 values), the flattened
    4x4 world-view and full projection matrices (16 each) and the two fields of
    view (1 each): 37 values for the default ``input_dim``. A linear layer maps
    it to ``output_dim`` and the result is repeated ``num_tokens`` times.

    Args:
        input_dim: length of the concatenated pose vector.
        output_dim: embedding width.
        num_tokens: length of the repeated token axis (77 by default).
    """

    def __init__(self, input_dim=37, output_dim=1280, num_tokens=77):
        super(CameraPoseEmbedding, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)
        self.num_tokens = num_tokens

    def forward(self, camera_center, world_view_transform, full_proj_transform, fovx, fovy):
        """Return the pose embedding as [1, num_tokens, output_dim].

        fovx / fovy may be Python floats or single-element tensors; they are
        placed on camera_center's device and dtype.
        """
        like = {"dtype": camera_center.dtype, "device": camera_center.device}

        # Concatenate all the components
        pose_vector = torch.cat([
            camera_center,                              # [3]
            world_view_transform.reshape(-1),           # [16]; reshape also accepts transposed/non-contiguous matrices
            full_proj_transform.reshape(-1),            # [16]
            torch.as_tensor(fovx, **like).reshape(-1),  # [1]
            torch.as_tensor(fovy, **like).reshape(-1)   # [1]
        ], dim=-1)  # [37]

        embedded_pose = self.fc(pose_vector)  # [output_dim]

        # Repeat along a new token axis, then add the batch axis: [1, num_tokens, output_dim]
        return embedded_pose.unsqueeze(0).repeat(self.num_tokens, 1).unsqueeze(0)

# Example: random pose matrices, fields of view given as 0-d tensors.
camera_center = torch.randn(3)
world_view_transform = torch.randn(4, 4)
full_proj_transform = torch.randn(4, 4)
fovx = torch.tensor(1.0)
fovy = torch.tensor(2.0)

model = CameraPoseEmbedding()
output = model(camera_center, world_view_transform, full_proj_transform, fovx, fovy)

print(output.shape)  # expected: torch.Size([1, 77, 1280])


torch.Size([1, 77, 1280])


In [7]:
import torch
import torch.nn as nn

class CameraPoseEmbedding(nn.Module):
    """Embed a camera pose as a single [1, output_dim] vector.

    NOTE(review): this redefines CameraPoseEmbedding from the previous cell;
    this variant returns no repeated token axis.

    The pose vector concatenates the camera centre (3 values), the flattened
    4x4 world-view and full projection matrices (16 each) and the two fields of
    view (1 each): 37 values for the default ``input_dim``.
    """

    def __init__(self, input_dim=37, output_dim=1280):
        super(CameraPoseEmbedding, self).__init__()
        # Linear layer to map from input_dim to output_dim
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, camera_center, world_view_transform, full_proj_transform, fovx, fovy):
        """Return the pose embedding as [1, output_dim].

        fovx / fovy may be Python floats or single-element tensors. They are
        created on camera_center's device and dtype so the concatenation also
        works when the pose tensors live on the GPU.
        """
        like = {"dtype": camera_center.dtype, "device": camera_center.device}

        # Concatenate all components into a single vector
        pose_vector = torch.cat([
            camera_center,                             # [3]
            world_view_transform.reshape(-1),          # [16]; reshape also accepts transposed/non-contiguous matrices
            full_proj_transform.reshape(-1),           # [16]
            torch.as_tensor(fovx, **like).reshape(1),  # [1]
            torch.as_tensor(fovy, **like).reshape(1)   # [1]
        ], dim=-1)  # [37]

        # Map to output_dim and add the batch axis: [1, output_dim]
        return self.fc(pose_vector).unsqueeze(0)

# Example: random pose matrices, fields of view given as plain Python floats.
camera_center, world_view_transform, full_proj_transform = (
    torch.randn(3),
    torch.randn(4, 4),
    torch.randn(4, 4),
)
fovx = fovy = 1.0

print(world_view_transform.size())
print(full_proj_transform.size())

model = CameraPoseEmbedding()
output = model(camera_center, world_view_transform, full_proj_transform, fovx, fovy)

print(output.shape)  # expected: torch.Size([1, 1280])


torch.Size([1, 1280])
