In [None]:
import torch
import torch.nn as nn
from CircuitSimulation.CircuitSimulator import *
import numpy as np
from diffusers import UNet2DModel
import torch.optim as optim
import torch.nn.functional as F

In [45]:
# Circuit shape, with the input column on the left and the output column on the right
# [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
# [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],

# Two hand-drawn sample circuits stacked into one integer tensor of shape (2, 10, 10).
# The cells below only use the values 0, 1 and 2.
# NOTE(review): presumably 0 = empty, 1 = wire and 2 = a transistor/gate site (the
# simulator output for sample 1 lists AND0 base/collector/emitter sockets) — confirm
# against CircuitSimulator.
# NOTE(review): these grids are 10 columns wide while the layout sketch above shows 11
# columns — confirm which width is intended.
InitialDataset = torch.tensor([
    [
        [1, 1, 1, 1, 1, 2, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ],
    [
        [1, 1, 1, 1, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 1, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 1, 0],
        [1, 1, 0, 1, 0, 0, 0, 1, 1, 0],
        [0, 1, 0, 1, 0, 0, 1, 1, 0, 0],
        [0, 1, 1, 2, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
    ]
])

# One label table per circuit, shape (2, 4, 12), with values 0, 1 and 2.
# Both circuits carry the same table.
# NOTE(review): this looks like a truth table — one row per input combination,
# columns 0-1 holding the two inputs and column 8 the output (1 only for inputs 1,1,
# i.e. AND), with 2 marking unused columns. Confirm the column layout.
InitialLabels = torch.tensor([
    [
        [1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2 ],
        [1, 0, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2 ],
        [0, 1, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2 ],
        [0, 0, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2 ],
    ],
    [
        [1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2 ],
        [1, 0, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2 ],
        [0, 1, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2 ],
        [0, 0, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2 ],
    ],
])

In [46]:
# Smoke test: run the second sample circuit through the simulator.

testCircuit = InitialDataset[1].numpy()

# Two sockets flagged True and one flagged False
# (presumably inputs vs. output — confirm against Socket's signature).
socket1, socket2, socket3 = (
    Socket("inp0", True),
    Socket("inp1", True),
    Socket("out0", False),
)

# Each socket is paired with the grid position it is attached to.
socketList = [(socket1, (0, 0)), (socket2, (0, 3)), (socket3, (9, 0))]

socketMap = GetSocketMap(testCircuit, socketList)
connectionMap = GetConnectionMap(socketMap)

print(socketMap)
print(connectionMap)

# Simulate with both input sockets listed in the order; the result is the cell output.
testOrder1 = [socket1, socket2]
Simulate(connectionMap, socketMap, testOrder1)

{inp0: {1}, inp1: {3}, out0: {2}, AND0base: {3}, AND0collector: {1}, AND0emitter: {2}}
{1: [inp0, AND0collector], 3: [inp1, AND0base], 2: [out0, AND0emitter]}


[(inp0, True), (inp1, True), (out0, True)]

In [None]:
class RowEmbedder(nn.Module):
    """Embed categorical values with one shared table plus a per-position affine map.

    Every position of the input vector is looked up in the same embedding table;
    the looked-up vectors are then scaled and shifted by learnable parameters that
    are specific to each position (initialised to scale 1 and shift 0).
    """

    def __init__(self, num_categories, vector_length, embedding_dim):
        super().__init__()
        per_position_shape = (vector_length, embedding_dim)
        self.shared_embed = nn.Embedding(num_categories, embedding_dim)
        self.position_weights = nn.Parameter(torch.ones(per_position_shape))
        self.position_bias = nn.Parameter(torch.zeros(per_position_shape))

    def forward(self, x):
        """Map integer indices ``[..., vector_length]`` to ``[..., vector_length, embedding_dim]``."""
        embedded = self.shared_embed(x)
        # The per-position parameters broadcast over all leading dimensions.
        scaled = torch.mul(embedded, self.position_weights)
        return torch.add(scaled, self.position_bias)
    

class TabularTransformer(nn.Module):
    """Encode a table of categorical cells into one vector per table row.

    Each cell is embedded with ``RowEmbedder``, the cell embeddings of a row are
    averaged into a single ``d_model``-sized vector, and a Transformer encoder then
    attends across the rows of the table.

    Args:
        num_categories: Number of distinct integer values a cell can take.
        num_features: Number of columns in each row.
        d_model: Embedding / Transformer width.
        nhead: Number of attention heads. ``None`` (default) picks the largest
            divisor of ``d_model`` that is at most 6, i.e. 6 whenever 6 is valid.
        num_layers: Number of stacked encoder layers.

    Raises:
        ValueError: If ``nhead`` is given and does not divide ``d_model``, or if
            ``d_model`` is not positive.
    """

    def __init__(self, num_categories, num_features, d_model, nhead=None, num_layers=6):
        super().__init__()
        self.d_model = d_model

        # nn.MultiheadAttention requires d_model % nhead == 0. A fixed nhead=6
        # fails for widths such as 16, so choose a valid head count when none is given.
        if nhead is None:
            nhead = self._largest_valid_nhead(d_model, 6)
        elif d_model % nhead != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by nhead ({nhead})"
            )

        self.row_embedding = RowEmbedder(num_categories, num_features, d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=4 * d_model,
            batch_first=True,
        )

        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=num_layers,
        )

    @staticmethod
    def _largest_valid_nhead(d_model, max_heads):
        """Return the largest head count <= ``max_heads`` that divides ``d_model``."""
        for heads in range(min(max_heads, d_model), 0, -1):
            if d_model % heads == 0:
                return heads
        raise ValueError(f"d_model must be a positive integer, got {d_model}")

    def forward(self, x):
        """Encode ``x`` of shape ``[batch, rows, columns]`` into ``[batch, rows, d_model]``.

        ``x`` holds integer category indices; ``columns`` must equal ``num_features``.
        """
        cells = self.row_embedding(x)  # [batch, rows, columns, d_model]
        rows = cells.mean(dim=2)       # average over columns -> [batch, rows, d_model]

        # Rows act as the sequence dimension (batch_first=True).
        return self.transformer(rows)

# Quick check on the label tables: 3 cell categories, 12 columns per row, 16-dim embeddings.
transformer = TabularTransformer(num_categories=3, num_features=12, d_model=16)

# Encoded rows of the first label table, left as the last expression so the notebook displays it.
transformer(InitialLabels)[0]

tensor([[ 1.4526e+00,  4.6408e-01, -2.2007e-02, -5.0593e-02, -1.8623e+00,
          9.8007e-01,  8.1641e-02,  3.1947e-02,  1.0452e-01,  7.8182e-01,
          1.2370e+00, -1.2544e+00, -1.0294e+00, -4.0757e-01, -1.6950e+00,
          1.1876e+00],
        [ 8.4392e-01, -2.5650e-01, -1.5884e-01,  2.4626e-01, -2.4257e+00,
          1.2766e+00,  5.2654e-01,  7.1567e-01,  1.9325e-01,  7.9750e-01,
          8.2815e-01, -7.9608e-01, -1.0337e+00, -8.1231e-01, -1.2194e+00,
          1.2746e+00],
        [ 9.5773e-01, -1.5920e-01, -4.0058e-02,  9.0168e-02, -1.6174e+00,
          1.5712e+00,  8.1230e-01,  1.2431e+00, -2.3086e-01,  3.5523e-01,
          1.1842e-01, -2.2396e-01, -1.0856e+00, -1.3408e+00, -1.7145e+00,
          1.2643e+00],
        [ 1.0282e+00,  1.8311e-01, -1.9624e-03, -1.6950e-01, -2.0915e+00,
          9.5756e-01,  4.9463e-01,  1.1232e+00,  1.7706e-01,  2.4930e-01,
          1.0046e+00, -1.3411e+00, -8.0515e-01, -8.6900e-01, -1.3438e+00,
          1.4043e+00]], grad_fn=<SelectBack

In [None]:
# betas = torch.linspace(0.0001, 0.014, 200)
# alphas = 1.0 - betas
# alphaCumprod = torch.cumprod(alphas, dim=0) 
# alphaCumprod[-1] # this should be roughly 1/4

class CategoricalScheduler:
    """Forward (noising) process for categorical diffusion with uniform noise.

    A linear beta schedule is turned into the cumulative product of
    ``alpha = 1 - beta``. At step ``t`` every pixel keeps its category with weight
    ``alpha_cumprod[t]`` and is otherwise resampled uniformly over all categories.

    Args:
        TrainSteps: Number of diffusion steps in the schedule.
        numCategories: Number of categories (channels of the one-hot images).
        betaStart: First value of the linear beta schedule.
        betaEnd: Last value of the linear beta schedule.
    """

    def __init__(self, TrainSteps = 200, numCategories = 4, betaStart = 0.0001, betaEnd = 0.014):
        self.TrainSteps = TrainSteps
        self.noiseDevice = 'cpu'
        self.numCategories = numCategories

        self.betas = torch.linspace(betaStart, betaEnd, TrainSteps, device=self.noiseDevice)
        self.alphas = 1.0 - self.betas
        # Design target noted by the author: the last value of alpha_cumprod should
        # end up close to 1 / numCategories.
        self.alpha_cumprod = torch.cumprod(self.alphas, dim=0)

    def addNoise(self, imageBatch, time, generator=None):
        """Sample noised one-hot images for the given timestep(s).

        Args:
            imageBatch: Tensor of shape ``(bs, numCategories, w, h)`` holding the
                per-pixel category distribution (typically one-hot).
            time: Step index into the schedule — an int, or an integer tensor with
                one entry per batch element. It may live on any device.
            generator: Optional ``torch.Generator`` for reproducible sampling; it
                must be on the same device as ``imageBatch``.

        Returns:
            Integer (``torch.long``) one-hot tensor of shape
            ``(bs, numCategories, w, h)`` on ``imageBatch``'s device.
        """
        bs, ch, w, h = imageBatch.shape

        with torch.no_grad():
            # The schedule lives on self.noiseDevice. Index it there, then move the
            # result next to the images so CUDA batches / CUDA timesteps do not hit
            # a device mismatch.
            steps = torch.as_tensor(time, dtype=torch.long, device=self.alpha_cumprod.device)
            alpha_t = self.alpha_cumprod[steps].to(imageBatch.device)
            alpha_t = alpha_t.view(-1, 1, 1, 1)  # () or (bs,) -> (1 or bs, 1, 1, 1)

            # Per-pixel category probabilities after noising: keep the current
            # distribution with weight alpha_t, otherwise draw uniformly.
            updatedProbabilities = imageBatch * alpha_t + (1 - alpha_t) / self.numCategories

            # Put the category axis last so each row of the flattened matrix is the
            # distribution of a single pixel: [bs * w * h, numCategories].
            updatedProbabilities = updatedProbabilities.permute(0, 2, 3, 1)
            updatedProbabilities = updatedProbabilities.reshape(bs * w * h, self.numCategories)

            # One categorical sample per pixel.
            categoricalNoise = torch.multinomial(
                updatedProbabilities, 1, replacement=True, generator=generator
            )
            categoricalNoise = categoricalNoise.view(bs, w, h)

            noisedImages = F.one_hot(categoricalNoise, num_classes=self.numCategories)
            return noisedImages.permute(0, 3, 1, 2)  # [bs, numCategories, w, h]

# Noise scheduler with its default settings.
scheduler = CategoricalScheduler()

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

class CategoricalDiffusionModel(nn.Module):
  """Class-conditioned UNet denoiser with random conditioning dropout.

  The class label is embedded, broadcast over the spatial dimensions and
  concatenated to the input as extra channels. During training the conditioning
  of each sample is zeroed with probability ``guidance_prob`` so the network also
  learns an unconditional prediction.

  Args:
      num_classes: Number of distinct class labels.
          TODO(review): 10 is a placeholder default — set it to the real label count.
      embeddingSize: Width of the class embedding (number of conditioning channels).
      guidance_prob: Probability of dropping the conditioning during training.
  """

  def __init__(self, num_classes=10, embeddingSize=4, guidance_prob=0.1):
    super().__init__()
    self.guidance_prob = guidance_prob

    # Lookup table turning a class index into `embeddingSize` conditioning values.
    self.class_emb = nn.Embedding(num_classes, embeddingSize)

    self.model = UNet2DModel(
        sample_size=28,                  # the target image resolution
        in_channels=1 + embeddingSize,   # image channel + class-conditioning channels
        # NOTE(review): out_channels is kept as written; confirm it should not be
        # the number of predicted categories instead.
        out_channels=1 + embeddingSize,
        layers_per_block=2,              # how many ResNet layers to use per UNet block
        block_out_channels=(64, 64, 64),
        down_block_types=(
            "DownBlock2D",        # a regular ResNet downsampling block
            "DownBlock2D",        # a regular ResNet downsampling block
            "AttnDownBlock2D",    # a ResNet downsampling block with spatial self-attention
        ),
        up_block_types=(
            "AttnUpBlock2D",      # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",          # a regular ResNet upsampling block
            "UpBlock2D",          # a regular ResNet upsampling block
        ),
    )

  def forward(self, x, t, class_labels):
    """Predict from the noised input ``x`` at timestep ``t`` given ``class_labels``.

    Args:
        x: Noised input of shape ``(bs, ch, w, h)``.
        t: Timestep(s), forwarded unchanged to the UNet.
        class_labels: Integer class indices of shape ``(bs,)``.

    Returns:
        The UNet's ``.sample`` output.
    """
    bs, ch, w, h = x.shape

    class_cond = self.class_emb(class_labels)  # (bs, embeddingSize)
    batch_size, embed_dim = class_cond.shape

    # Drop the conditioning only while training; at inference the caller decides
    # whether to condition. masked_fill is out-of-place, so the embedding output
    # is not mutated.
    if self.training and self.guidance_prob > 0:
      drop = torch.rand(batch_size, device=class_cond.device) < self.guidance_prob
      class_cond = class_cond.masked_fill(drop.unsqueeze(1), 0.0)

    # Broadcast the embedding over the spatial dimensions: (bs, embeddingSize, w, h).
    class_cond = class_cond.view(bs, embed_dim, 1, 1).expand(bs, embed_dim, w, h)

    # Net input is x and the conditioning concatenated along the channel axis.
    net_input = torch.cat((x, class_cond), 1)  # (bs, ch + embeddingSize, w, h)

    # Feed this to the UNet alongside the timestep and return the prediction.
    return self.model(net_input, t).sample
  
# Build the denoiser, then move its parameters to the selected device.
model = CategoricalDiffusionModel()
model = model.to(device)