In [6]:
!nvidia-smi

Fri Jul 12 12:31:09 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 546.92                 Driver Version: 546.92       CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4080 ...  WDDM  | 00000000:01:00.0  On |                  N/A |
| N/A   47C    P8               4W / 130W |    844MiB / 12282MiB |      8%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [7]:
import torch
import torch.nn as nn
import numpy as np
import scipy as sp 
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
from torch.nn import init
import pytorch_lightning as pL
import torch.optim as optim

In [8]:
# Define the Device Being Used:
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Define the relevant helper functions for the attention UNet architecture
def conv1x1(in_channels, out_channels, stride=1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)

def conv3x3(in_channels, out_channels, stride = 1):
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)

def Swish():
    return nn.SiLU()

def GeLU():
    return nn.GELU()

def Mish():
    return nn.Mish()

def Sigmoid():
    return nn.Sigmoid()

In [9]:
# Define the transformation for converting psnr into timesteps.
def psnr_to_timestep(psnr, minpsnr, maxpsnr, num_timesteps):
    """
    psnr - the psnr value to convert to a timestep
    minpsnr - the minimum psnr value that can be achieved
    maxpsnr - the maximum psnr value that can be achieved
    num_timesteps - the number of timesteps in the simulation
    This function takes a psnr as import, performs min-max normalisation and 
    then converts it into a timestep.
    """
    psnr = torch.FloatTensor([psnr])
    minpsnr = torch.FloatTensor([minpsnr])
    maxpsnr = torch.FloatTensor([maxpsnr])
    
    psnr = torch.maximum(torch.minimum(psnr, maxpsnr), minpsnr)
    normalised_psnr = ((psnr - minpsnr) / (maxpsnr - minpsnr))
    timestep = (num_timesteps-1)*normalised_psnr
    return torch.Tensor([timestep]).to(torch.float32)

In [10]:
# Define the transformation for converting the timestep into a positional embedding.
class Temporal_Embedder(nn.Module):
    def __init__(self, n_channels):
        super(Temporal_Embedder, self).__init__()
        self.n_channels = n_channels
        self.Linear_1 = nn.Linear(self.n_channels//4, self.n_channels)
        self.Linear_2 = nn.Linear(self.n_channels, self.n_channels)
        self.Mish_1 = Mish()
        self.Mish_2 = Mish()

    def forward(self, t):
        half_dim = self.n_channels//8
        constant = torch.FloatTensor([10000.0])
        emb = torch.log(constant) / (half_dim -1)
        emb = torch.exp(torch.arange(half_dim) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim = 1)
        emb = emb.to(device)
        
        emb = self.Mish_1(self.Linear_1(emb))
        emb = self.Mish_2(self.Linear_2(emb))
        return emb

In [11]:
# Define the Squeeze and Extraction Block used to recaliberate feature maps from a convolution.
class SEblock(nn.Module):
    def __init__(self, units, bottlenecks, dropout_rate):
        super(SEblock, self).__init__()
        
        self.units = units
        self.bottlenecks = bottlenecks
        self.dropout_rate = dropout_rate
        
        # Define the SE Block layers
        self.Dense = nn.LazyLinear(units)
        self.Dropout = nn.Dropout2d(dropout_rate)
        self.GlobalPool = nn.AdaptiveAvgPool2d((1,1))
        
    def forward(self, x):
        x = self.GlobalPool(x)
        x = x.view(x.size(0), -1)
        x = self.Dense(x)
        x = F.leaky_relu(x)
        x = self.Dense_2(x)
        x = F.sigmoid(x)
        x = x.view(-1, self.units, 1, 1)
        return x

In [12]:
# Define the Different Convolution Blocks for the Encoder and Decoder
class Residual_Block(nn.Module):
    """
    in_channels = Number of Incoming Channels
    out_channels = Number of Outgoing Channels
    time_channels = Number of Time Channels
    dropout_rate = Dropout Rate
    n_groups = Number of Groups
    stride = Stride
    This Block is the traditional residual block used in standard diffusion models.
    """
    def __init__(self, in_channels, out_channels, time_channels,
                dropout_rate, n_groups = 8, stride=1):
        
        super(Residual_Block, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dropout_rate = dropout_rate
        self.stride = stride
        self.n_groups = n_groups
        self.time_channels = time_channels

        # Define the residual block layers

        self.time_embedding = nn.Linear(time_channels, out_channels)
        self.conv1 = conv3x3(in_channels, out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.conv3 = conv1x1(in_channels, out_channels)

        self.GroupNorm_1 = nn.GroupNorm(n_groups, in_channels)
        self.GroupNorm_2 = nn.GroupNorm(n_groups, out_channels)
        self.Batch_Norm_1 = nn.BatchNorm2d(out_channels)

        self.dropout_1 = nn.Dropout(dropout_rate)
        self.dropout_2 = nn.Dropout(dropout_rate)

        self.Mish_1 = Mish()
        self.Mish_2 = Mish()
        self.Mish_3 = Mish()

    def forward(self, x, t):
        x_skip = self.Batch_Norm_1(self.conv3(x))
        
        x = self.conv1(self.Mish_1(self.GroupNorm_1(x)))
        x = self.dropout_1(x)

        h = self.Mish_3(self.time_embedding(t)[:,:,None,None])
        x += h

        x = self.conv2(self.Mish_2(self.GroupNorm_2(x)))
        x = self.dropout_2(x)
        x += x_skip
    
        return x

In [13]:
class Efficient_Residual_Block(nn.Module):
    """
    in_channels = Number of incoming channels
    out_channels = Number of outgoing channels
    time_channels = Number of time channels
    dropout_rate = Dropout Rate
    n_groups = Number of Groups
    bottleneck_channels = Number of channels in the Convolution Bottleneck
    stride = Stride
    This block is a computational efficient variant of the residual block used 
    in standard diffusion models. It increases the computational efficiency, by performing
    the 3v3 convolution in a reduced channel space before restoring the dimensionality of the 
    feature representation.
    """
    def __init__(self, in_channels, bottleneck_channels, time_channels, 
                 out_channels, dropout_rate, n_groups = 8, stride=1):
        
        super(Efficient_Residual_Block, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dropout_rate = dropout_rate
        self.stride = stride
        self.n_groups =  n_groups
        self.time_channels = time_channels

        # Define the residual block layers

        self.time_embedding = nn.Linear(time_channels, bottleneck_channels)
        self.conv1 = conv1x1(in_channels, bottleneck_channels)
        self.conv2 = conv3x3(bottleneck_channels, bottleneck_channels)
        self.conv3 = conv1x1(bottleneck_channels, out_channels)

        self.GroupNorm_1 = nn.GroupNorm(self.n_groups, in_channels)
        self.GroupNorm_2 = nn.GroupNorm(self.n_groups, bottleneck_channels)
        self.GroupNorm_2 = nn.GroupNorm(self.n_groups, out_channels)
        self.Batch_Norm_1 = nn.BatchNorm2d(in_channels)

        self.dropout_1 = nn.Dropout2d(dropout_rate)
        self.dropout_2 = nn.Dropout2d(dropout_rate)
        self.dropout_3 = nn.Dropout2d(dropout_rate)

        self.Mish_1 = Mish()
        self.Mish_2 = Mish()
        self.Mish_3 = Mish()
        self.Mish_4 = Mish()

    def forward(self, x, t):
        x_skip = self.Batch_Norm_1(self.conv3(x))
        
        x = self.conv1(self.Mish_1(self.GroupNorm_1(x)))
        x = self.dropout_1(x)

        x = self.conv2(self.Mish_2(self.GroupNorm_2(x)))
        x = self.dropout_2(x)

        h = self.Mish_4(self.time_embedding(t)[:,:, None, None])
        x += h

        x = self.conv3(self.Mish_3(self.GroupNorm_3(x)))
        x = self.dropout_3(x)

        x += x_skip

        return x

In [14]:
class SqueezeExtraction_Block(nn.Module):
    """
    filters = The number of outgoing channels
    units = The number of units in the SE block
    dropout_rate = The dropout rate
    time_channels = The number of time channels
    units_bottleneck = The number of units in the bottleneck of the SE Block
    n_groups = The number of groups in the Group Normalisation
    This block is a variant of the residual block used in standard diffusion models.
    It incorporates the squeeze and extraction blocks to recaliberate the feature maps 
    coming from the convolutions.
    """
    def __init__(self, filters, units, dropout_rate,
                 time_channels, units_bottleneck, n_groups = 8):
        super(SqueezeExtraction_Block, self).__init__()
        
        self.filters = filters
        self.units = units
        self.dropout_rate = dropout_rate
        self.units_bottleneck = units_bottleneck
        self.time_channels = time_channels
        self.n_groups = n_groups
        
        # Define the SqueezeExtraction_Block Layers
        
        self.time_embedding = nn.Linear(time_channels, filters)
        self.Conv_1 = nn.LazyConv2d(filters, 3, padding = 1, stride = 1)
        self.Conv_2 = nn.LazyConv2d(filters, 3, padding = 1, stride = 1)
        self.Conv_Bypass = nn.LazyConv2d(filters, 1)

        self.GroupNorm_1 = nn.GroupNorm(8, filters)
        self.GroupNorm_2 = nn.GroupNorm(8, filters)

        self.dropout_1 = nn.Dropout(dropout_rate)
        self.dropout_2 = nn.Dropout(dropout_rate)

        self.Mish_1 = Mish()
        self.Mish_2 = Mish()
        self.Mish_3 = Mish()
        self.SE_Block = SEblock(units, units_bottleneck, dropout_rate)

    def forward(self, x, t):
        x_skip = self.Conv_Bypass(x)

        x = self.Conv_1(self.Mish_1(self.GroupNorm_1(x)))
        x = self.dropout_1(x)

        h = self.Mish_3(self.time_embedding(x)[:,:,None,None])
        y += h

        x = self.Conv_2(self.Mish_2(self.GroupNorm_2(t)))
        x = self.dropout_2(x)

        y = self.SE_Block(x)
        y *= x
        x = y + x_skip

        x = F.leaky_relu(x)
        
        return x

In [15]:
class ResidualDense_Block(nn.Module):
    """
    in_channels = Number of incoming channels
    out_channels = Number of Outgoing channels
    dropout_rate = dropout_rate
    time_channels = Number of time channels
    n_groups = Number of groups in the group normalisation layer
    This block is a variant of the residual block used in standard diffusion models.
    The block incorporates dense connections between the convolutional layers to increase
    computational efficiency by incentising feature reuse. Skip connections continue to be 
    used to create a hybrid between a residual an dense connection design.
    """
    def __init__(self, in_channels, out_channels, dropout_rate,
                time_channels, n_groups = 8):
        super(ResidualDense_Block, self).__init__()
        
        self.in_channels = in_channels
        self.in_out_channels = out_channels
        self.dropout_rate = dropout_rate
        self.n_groups = n_groups
        
        # Define the Fully_Dense_Encoder layers

        self.time_embedding = nn.Linear(time_channels, out_channels)

        self.conv1 = nn.LazyConv2d(out_channels, kernel_size=3, padding=1, stride=1)
        self.conv2 = nn.LazyConv2d(out_channels, kernel_size=3, padding=1, stride=1)
        self.conv3 = nn.LazyConv2d(out_channels, kernel_size=3, padding=1, stride=1)
        self.conv4 = nn.LazyConv2d(out_channels, kernel_size=1, padding = 0, stride=1)
        self.conv5 = nn.LazyConv2d(out_channels, kernel_size=1, padding = 0, stride=1)


        self.GroupNorm_1 = nn.GroupNorm(self.n_groups, in_channels)
        self.GroupNorm_2 = nn.GroupNorm(self.n_groups, in_channels + out_channels)
        self.GroupNorm_3 = nn.GroupNorm(self.n_groups, in_channels + 2*out_channels)
        self.BatchNorm_1 = nn.BatchNorm2d(in_channels)

        self.Mish_1 = Mish()
        self.Mish_2 = Mish()
        self.Mish_3 = Mish()
        self.Mish_4 = Mish()
        self.Mish_5 = Mish()

        self.dropout_1 = nn.Dropout2d(dropout_rate)
        self.dropout_2 = nn.Dropout2d(dropout_rate)
        self.dropout_3 = nn.Dropout2d(dropout_rate)

    def forward(self, x, t):
        x_cat, x_skip = x, self.BatchNorm_1(self.conv4(x))

        x = self.conv1(self.Mish_1(self.GroupNorm_1(x)))
        x = self.dropout_1(x)
        x = torch.cat((x, x_cat), 1)

        x_cat = x
        x = self.conv2(self.Mish_2(self.GroupNorm_2(x)))
        x = self.dropout_2(x)
        x = torch.cat((x, x_cat), 1)

        x_cat = x
        h = self.Mish_5(self.time_embedding(t)[:,:, None, None])
        x += h
        x = self.conv3(self.Mish_3(self.GroupNorm_3(x)))
        x = self.dropout_3(x)
        x = torch.cat((x, x_cat), 1)


        x = self.Mish_4(self.conv5(x))
        x += x_skip

        return x      

In [16]:
class AttentionBlock(nn.Module):
    def __init__(self, n_channels, n_heads, dim_k, n_groups, dropout_rate):
        super(AttentionBlock, self).__init__()

        self.n_channels = n_channels
        self.n_heads = n_heads
        self.dim_k = dim_k
        self.n_groups = n_groups
        self.dropout_rate = dropout_rate
        self.scale = dim_k ** -0.5

        self.qkv = nn.Linear(n_channels, n_heads * dim_k * 3)
        self.output = nn.Linear(n_heads*dim_k, n_channels)
        self.norm = nn.GroupNorm(n_groups, n_channels)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        batch_size, n_channels, height, width = x.shape
        x = x.view(batch_size, n_channels, -1)
        x_skip = x
        x= x.permute(0,2,1)

        qkv = self.qkv(x).view(batch_size, -1, self.n_heads, 3*self.dim_k)

        q, k, v = qkv.chunk(3, dim=-1)

        attn = torch.einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
        attn = attn.softmax(dim = -1)
        attn_output = torch.einsum("b h i j, b h j d -> b h i d", attn, v)
         
        attn_output = attn_output.view(batch_size, -1, self.n_heads * self.dim_k)
        attn_output = self.dropout(self.output(attn_output))
        #print(attn_output.shape, x_skip.shape)
        
        attn_output = attn_output.permute(0,2,1)
        attn  = self.norm(attn_output + x_skip)
        attn_output = attn.view(batch_size, n_channels, height, width)
        return attn_output

In [48]:
class GroupQueryAttentionBlock(nn.Module):
    def __init__(self, n_channels, n_heads, dim_k, n_groups, dropout_rate, group_size):
        super(GroupQueryAttentionBlock, self).__init__()

        self.n_channels = n_channels
        self.n_heads = n_heads
        self.dim_k = dim_k
        self.n_groups = n_groups
        self.dropout_rate = dropout_rate
        self.scale = dim_k ** -0.5
        self.group_size = group_size

        self.qkv = nn.Linear(n_channels, n_heads * dim_k * 3)
        self.output = nn.Linear(n_heads*dim_k, n_channels)
        self.norm = nn.GroupNorm(n_groups, n_channels)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        batch_size, n_channels, height, width = x.shape
        x = x.view(batch_size, n_channels, -1)
        x_skip = x
        x= x.permute(0,2,1)

        qkv = self.qkv(x).view(batch_size, -1, self.n_heads, 3*self.dim_k)

        q, k, v = qkv.chunk(3, dim=-1)

        num_groups = max(1, q.shape[1]//self.group_size)
        #print(num_groups, q.shape[1], self.group_size)

        q_groups = q.view(batch_size, num_groups, self.group_size, self.n_heads, self.dim_k)
        k_groups = k.view(batch_size, num_groups, self.group_size, self.n_heads, self.dim_k)
        v_groups = v.view(batch_size, num_groups, self.group_size, self.n_heads, self.dim_k)

        attn_weights = torch.einsum("bgnhd, bgnhd -> bgnh", q_groups, k_groups) * self.scale
        attn_weights = F.softmax(attn_weights, dim = -1)
        attn_output = torch.einsum("bgnh, bgnhd -> bgnhd", attn_weights, v_groups)
        
        attn_output = attn_output.view(batch_size, -1, self.n_heads * self.dim_k)
        attn_output = self.dropout(self.output(attn_output))
        #print(attn_output.shape, x_skip.shape)
        attn_output = attn_output.permute(0,2,1)
        attn  = self.norm(attn_output + x_skip)
        attn_output = attn.view(batch_size, n_channels, height, width)
        return attn_output

In [49]:
class ConvGroupQueryAttentionBlock(nn.Module):
    def __init__(self, n_channels, n_heads, dim_k, n_groups, dropout_rate, group_size, reduction_factor):
        super(ConvGroupQueryAttentionBlock, self).__init__()

        self.original_n_channels = n_channels
        self.n_heads = n_heads
        self.dim_k = dim_k
        self.n_groups = n_groups
        self.dropout_rate = dropout_rate
        self.scale = dim_k ** -0.5
        self.group_size = group_size
        self.reduction_factor = reduction_factor
        self.n_channels = n_channels//reduction_factor

        self.bottleneck = nn.LazyConv2d(self.original_n_channels//reduction_factor, kernel_size= 1)
        self.unbottleneck = nn.LazyConv2d(self.original_n_channels, kernel_size = 1)

        self.qkv = nn.LazyLinear(n_heads * dim_k * 3)
        self.output = nn.Linear(n_heads * dim_k, self.n_channels)
        self.norm = nn.GroupNorm(n_groups, self.n_channels)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        
        x = self.bottleneck(x)
        
        batch_size, n_channels, height, width = x.shape
        x = x.view(batch_size, n_channels, -1)
        x_skip = x

        x = x.permute(0,2,1)

        qkv = self.qkv(x).view(batch_size, -1, self.n_heads, 3*self.dim_k)

        q, k, v = qkv.chunk(3, dim=-1)

        num_groups = max(1, q.shape[1]//self.group_size)
        #print(num_groups, q.shape[1], self.group_size)

        q_groups = q.view(batch_size, num_groups, self.group_size, self.n_heads, self.dim_k)
        k_groups = k.view(batch_size, num_groups, self.group_size, self.n_heads, self.dim_k)
        v_groups = v.view(batch_size, num_groups, self.group_size, self.n_heads, self.dim_k)

        attn_weights = torch.einsum("bgnhd, bgnhd -> bgnh", q_groups, k_groups) * self.scale
        attn_weights = F.softmax(attn_weights, dim = -1)
        attn_output = torch.einsum("bgnh, bgnhd -> bgnhd", attn_weights, v_groups)
        
        attn_output = attn_output.view(batch_size, -1, self.n_heads * self.dim_k)
        attn_output = self.dropout(self.output(attn_output))
        #print(attn_output.shape, x_skip.shape)
        attn_output = attn_output.permute(0,2,1)
        attn  = self.norm(attn_output + x_skip)
        
        attn_output = attn.view(batch_size, n_channels, height, width)
        
        attn_output = self.unbottleneck(attn_output)
        return attn_output

In [50]:
# Example usage
n_channels = 64
n_heads = 16
dim_k = 16
n_groups = 8
dropout_rate = 0.1
reduction_factor = 4
group_size = (128**2)//8
reduction_factor = 4
x = torch.randn(32, n_channels, 128, 128)
x.shape

torch.Size([32, 64, 128, 128])

In [51]:
Atten_Block = AttentionBlock(n_channels=n_channels, n_heads = n_heads, dim_k = dim_k, n_groups = n_groups, 
                         dropout_rate = dropout_rate)

output = Atten_Block(x)
print(output.shape)

torch.Size([32, 64, 128, 128])


In [52]:
Query_Atten_Block = GroupQueryAttentionBlock(n_channels=n_channels, n_heads = n_heads, dim_k = dim_k, n_groups = n_groups, 
                         dropout_rate = dropout_rate, group_size = group_size)

output = Query_Atten_Block(x)
print(output.shape)

torch.Size([32, 64, 128, 128])


In [53]:
Conv_Query_Atten_Block = ConvGroupQueryAttentionBlock(n_channels=n_channels, n_heads = n_heads, dim_k = dim_k, n_groups = n_groups, 
                         dropout_rate = dropout_rate, group_size = group_size, reduction_factor= reduction_factor)

output = Conv_Query_Atten_Block(x)
print(output.shape)



torch.Size([32, 64, 128, 128])


In [37]:
# Define A Down_Sample Block
class Down_Block(nn.Module):
    """
    in_channels = Number of incoming channels
    out_channels = Number of outgoing channels
    time_channels = Number of time channels
    has_attn = Whether the block has an attention mechanism
    n_groups = Number of groups in the group normalisation layer
    dropout_rate = Dropout Rate
    conv_type = The type of convolution block to use
    attn_type = The type of attention block to use
    downsample = Whether to downsample the image
    Pooling = Whether to use pooling or strided convolutions for downsampling
    bottleneck_channels = Number of channels in the bottleneck of the Efficient Residual Block
    units = Number of units in the SE Block
    bottleneck_units = Number of units in the bottleneck of the SE Block
    This block is the downsample block used in the attention UNet architecture. It is used to
    downsample the image and increase the number of channels in the image. It can incorporate
    attention mechanisms and different types of convolution blocks to increase the computational
    efficiency of the model.
    """
    def __init__(self, in_channels, out_channels, time_channels, n_groups,
                dropout_rate, conv_type = "Residual_Block", attn_type = "Attention",
                n_heads = None, dim_k = None, group_size = None, downsample = True, 
                Pooling = True, bottleneck_channels = None, units = None, 
                reduction_factor = None, bottleneck_units = None):
        
        super(Down_Block, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_channels = time_channels
        self.dropout_rate = dropout_rate
        self.n_groups = n_groups
        self.bottleneck_channels = bottleneck_channels
        self.units = units
        self.bottleneck_units = bottleneck_units
        self.conv_type = conv_type
        self.attn_type = attn_type
        self.Pooling = Pooling
        self.downsample = downsample
        self.dim_k = dim_k
        self.group_size = group_size
        self.n_heads = n_heads
        self.reduction_factor = reduction_factor

        # Define the Down_Block layers
        if "Residual_Block" in self.conv_type:
            self.Conv_Block = Residual_Block(in_channels = in_channels, out_channels = out_channels, 
                                             time_channels=time_channels, dropout_rate=dropout_rate, 
                                             n_groups=n_groups)
            
        if "Efficient_Residual_Block" in self.conv_type:
            self.Conv_Block = Efficient_Residual_Block(in_channels=in_channels, bottleneck_channels=bottleneck_channels, 
                                                       time_channels=time_channels, out_channels=out_channels, 
                                                       dropout_rate=dropout_rate, n_groups=n_groups)
            
        if "SqueezeExtraction_Block" in self.conv_type:
            self.Conv_Block = SqueezeExtraction_Block(filters = out_channels, units = units, dropout_rate= dropout_rate, 
                                                      time_channels= time_channels, units_bottleneck= bottleneck_units, n_groups= n_groups)
        if "ResidualDense_Block" in self.conv_type:
            self.Conv_Block = ResidualDense_Block(in_channels = in_channels, out_channels = out_channels, 
                                                 time_channels = time_channels, dropout_rate = dropout_rate, n_groups = n_groups)
        
        if attn_type is not None:
            self.has_attn = True
            if "Attention" in attn_type:
                self.Attention = AttentionBlock(n_channels= out_channels, n_heads = n_heads, dim_k = dim_k, n_groups = n_groups, 
                                                dropout_rate = dropout_rate)
            if "GroupQueryAttention" in attn_type:
                self.Attention = GroupQueryAttentionBlock(n_channels=out_channels, n_heads = n_heads, dim_k = dim_k, n_groups = n_groups, 
                                                          dropout_rate = dropout_rate, group_size = group_size)
            if "ConvGroupQueryAttention" in attn_type:
                self.Attention = ConvGroupQueryAttentionBlock(n_channels=out_channels, n_heads = n_heads, dim_k = dim_k, n_groups = n_groups, 
                                                             dropout_rate = dropout_rate, group_size = group_size, reduction_factor = reduction_factor)
        else:
            self.has_attn = False
            self.Attention = None

        if self.downsample:
            if self.Pooling:
                self.Pool = nn.MaxPool2d(kernel_size=2, stride=2)
            else:
                self.stride = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding = 1, stride=2) 

    def forward(self, x, t):
        x = self.Conv_Block(x, t)
        if self.has_attn:
            x = self.Attention(x)
        
        if self.downsample:
            before_pool = x
            if self.Pooling:
                before_pool = x
                x = self.Pool(x)
            else:
                before_pool = x
                x = self.stride(x)
            
            return x, before_pool
        else:
            return x

In [38]:
# Define the Up_Sample Block
class Up_Block(nn.Module):
    """
    in_channels = Number of incoming channels
    out_channels = Number of outgoing channels
    time_channels = Number of time channels
    has_attn = Whether the block has an attention mechanism
    n_groups = Number of groups in the group normalisation layer
    dropout_rate = Dropout Rate
    conv_type = The type of convolution block to use
    attn_type = The type of attention block to use
    up_sample = Whether to upsample the image.
    transpose = Whether to use transpose convolutions or upsample layers for upsampling.
    merge_type = The type of merge operation to use for the skip connections.
    bottleneck_channels = Number of channels in the bottleneck of the Efficient Residual Block
    units = Number of units in the SE Block
    bottleneck_units = Number of units in the bottleneck of the SE Block
    This block is the upsample block used in the attention UNet architecture. It is used to
    upsample the image and decrease the number of channels in the image. It can incorporate
    attention mechanisms and different types of convolution blocks to increase the computational
    efficiency of the model.
    """ 
    def __init__(self, in_channels, out_channels, time_channels, n_groups,
                dropout_rate, conv_type = "Residual_Block", attn_type = "Attention",
                up_sample = True, transpose = True, merge_type = "concat", n_heads = None,
                dim_k = None, group_size = None, reduction_factor = None, 
                bottleneck_channels = None, units = None, bottleneck_units = None):
        super(Up_Block, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.time_channels = time_channels
        self.dropout_rate = dropout_rate
        self.n_groups = n_groups
        self.bottleneck_channels = bottleneck_channels
        self.units = units
        self.bottleneck_units = bottleneck_units
        self.conv_type = conv_type
        self.attn_type = attn_type
        self.transpose = transpose
        self.merge_type = merge_type
        self.up_sample = up_sample
        self.dim_k = dim_k
        self.group_size = group_size
        self.reduction_factor = reduction_factor
        self.n_heads = n_heads

        # Define the Up_Block layers
        if up_sample:
            if "Residual_Block" in self.conv_type:
                self.Conv_Block = Residual_Block(2*in_channels, out_channels=out_channels, time_channels=time_channels, 
                                                dropout_rate=dropout_rate, n_groups= n_groups)
                
            if "Efficient_Residual_Block" in self.conv_type:
                self.Conv_Block = Efficient_Residual_Block(2*in_channels, bottleneck_channels=bottleneck_channels, time_channels= time_channels, 
                                                       out_channels= out_channels, dropout_rate=dropout_rate, n_groups=n_groups)
            
            if "SqueezeExtraction_Block" in self.conv_type:
                self.Conv_Block = SqueezeExtraction_Block(2*in_channels, units = units, dropout_rate= dropout_rate, 
                                                      time_channels= time_channels, units_bottleneck= bottleneck_units) 
            
            if "ResidualDense_Block" in self.conv_type:
                self.Conv_Block = ResidualDense_Block(2*in_channels, out_channels= out_channels, 
                                                  dropout_rate= dropout_rate, time_channels= time_channels, 
                                                  n_groups= n_groups)
                
        else:
            if "Residual_Block" in self.conv_type:
                self.Conv_Block = Residual_Block(in_channels, out_channels=out_channels, time_channels=time_channels, 
                                                dropout_rate=dropout_rate, n_groups= n_groups)
                
            if "Efficient_Residual_Block" in self.conv_type:
                self.Conv_Block = Efficient_Residual_Block(in_channels, bottleneck_channels=bottleneck_channels, time_channels= time_channels, 
                                                       out_channels= out_channels, dropout_rate=dropout_rate, n_groups=n_groups)
            
            if "SqueezeExtraction_Block" in self.conv_type:
                self.Conv_Block = SqueezeExtraction_Block(in_channels, units = units, dropout_rate= dropout_rate, 
                                                      time_channels= time_channels, units_bottleneck= bottleneck_units) 
            
            if "ResidualDense_Block" in self.conv_type:
                self.Conv_Block = ResidualDense_Block(in_channels, out_channels= out_channels, 
                                                  dropout_rate= dropout_rate, time_channels= time_channels, 
                                                  n_groups= n_groups)

        
        if attn_type is not None:
            self.has_attn = True
            if "Attention" in attn_type:
                self.Attention = AttentionBlock(n_channels= out_channels, n_heads = n_heads, dim_k = dim_k, n_groups = n_groups, 
                                                dropout_rate = dropout_rate)
            if "GroupQueryAttention" in attn_type:
                self.Attention = GroupQueryAttentionBlock(n_channels=out_channels, n_heads = n_heads, dim_k = dim_k, n_groups = n_groups, 
                                                          dropout_rate = dropout_rate, group_size = group_size)
            if "ConvGroupQueryAttention" in attn_type:
                self.Attention = ConvGroupQueryAttentionBlock(n_channels=out_channels, n_heads = n_heads, dim_k = dim_k, n_groups = n_groups, 
                                                             dropout_rate = dropout_rate, group_size = group_size, reduction_factor = reduction_factor)
        else:
            self.has_attn = False
            self.Attention = None
        
        if up_sample:
            if self.transpose:
                self.Upsample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
            else:
                self.Upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        
    def forward(self, x, before_pool, t):
        if self.up_sample:
            x = self.Upsample(x)
            if "concat" in self.merge_type:
                x = torch.cat((x, before_pool), 1)
            else: 
                x += before_pool

        x = self.Conv_Block(x, t)

        if self.has_attn:
            x = self.Attention(x)

        return x

In [39]:
class MiddleBlock(nn.Module):
    """
    in_channels = Number of incoming channels
    time_channels = Number of time channels
    has_attn = Whether the block has an attention mechanism
    n_groups = Number of groups in the group normalisation layer
    dropout_rate = Dropout Rate
    conv_type = The type of convolution block to use
    attn_type = The type of attention block to use
    bottleneck_channels = Number of channels in the bottleneck of the Efficient Residual Block
    units = Number of units in the SE Block
    bottleneck_units = Number of units in the bottleneck of the SE Block
    This block is the middle block used in the attention UNet architecture. It is used to
    process the image and incorporate attention mechanisms to capture long-range dependencies
    in the image. It can incorporate different types of convolution blocks to increase the computational
    efficiency of the model.
    """
    def __init__(self, in_channels, time_channels, n_groups,
                dropout_rate, conv_type = "Residual_Block", attn_type = "Attention",
                bottleneck_channels = None, bottleneck_units = None, units = None,
                n_heads = None, dim_k = None, group_size = None, reduction_factor = None):
        super(MiddleBlock, self).__init__()

        self.in_channels = in_channels
        self.time_channels = time_channels
        self.dropout_rate = dropout_rate
        self.n_groups = n_groups
        self.bottleneck_channels = bottleneck_channels
        self.units = units
        self.bottleneck_units = bottleneck_units
        self.conv_type = conv_type
        self.attn_type = attn_type
        self.n_heads = n_heads
        self.dim_k = dim_k
        self.group_size = group_size
        self.reduction_factor = reduction_factor
        
        # Define the MiddleBlock layers
        if "Residual_Block" in self.conv_type:
            self.Conv_Block_1 = Residual_Block(in_channels = in_channels, out_channels = in_channels, 
                                             time_channels=time_channels, dropout_rate=dropout_rate, 
                                             n_groups=n_groups)
            
            self.Conv_Block_2 = Residual_Block(in_channels = in_channels, out_channels = in_channels, 
                                             time_channels=time_channels, dropout_rate=dropout_rate, 
                                             n_groups=n_groups)
            
        if "Efficient_Residual_Block" in self.conv_type:
            self.Conv_Block_1 = Efficient_Residual_Block(in_channels=in_channels, bottleneck_channels=bottleneck_channels, 
                                                       time_channels=time_channels, out_channels=in_channels, 
                                                       dropout_rate=dropout_rate, n_groups=n_groups)
            
            self.Conv_Block_2 = Efficient_Residual_Block(in_channels=in_channels, bottleneck_channels=bottleneck_channels, 
                                                       time_channels=time_channels, out_channels=in_channels, 
                                                       dropout_rate=dropout_rate, n_groups=n_groups)
    
        if "SqueezeExtraction_Block" in self.conv_type:
            self.Conv_Block_1 = SqueezeExtraction_Block(filters = in_channels, units = units, dropout_rate= dropout_rate, 
                                                      time_channels= time_channels, units_bottleneck= bottleneck_units, n_groups= n_groups)
            
            self.Conv_Block_2 = SqueezeExtraction_Block(filters = in_channels, units = units, dropout_rate= dropout_rate, 
                                                      time_channels= time_channels, units_bottleneck= bottleneck_units, n_groups= n_groups)
        if "ResidualDense_Block" in self.conv_type:
            self.Conv_Block_1 = ResidualDense_Block(in_channels = in_channels, out_channels = in_channels, 
                                                 time_channels = time_channels, dropout_rate = dropout_rate, n_groups = n_groups)
            
            self.Conv_Block_2 = ResidualDense_Block(in_channels = in_channels, out_channels = in_channels, 
                                                 time_channels = time_channels, dropout_rate = dropout_rate, n_groups = n_groups)
            
        if attn_type is not None:
            self.has_attn = True
            if "Attention" in attn_type:
                self.Attention = AttentionBlock(n_channels= in_channels, n_heads = n_heads, dim_k = dim_k, n_groups = n_groups, 
                                                dropout_rate = dropout_rate)
            if "GroupQueryAttention" in attn_type:
                self.Attention = GroupQueryAttentionBlock(n_channels=in_channels, n_heads = n_heads, dim_k = dim_k, n_groups = n_groups, 
                                                          dropout_rate = dropout_rate, group_size = group_size)
            if "ConvGroupQueryAttention" in attn_type:
                self.Attention = ConvGroupQueryAttentionBlock(n_channels=in_channels, n_heads = n_heads, dim_k = dim_k, n_groups = n_groups, 
                                                             dropout_rate = dropout_rate, group_size = group_size, reduction_factor = reduction_factor)
        else:
            self.has_attn = False
            self.Attention = None

    def forward(self, x, t):
        x = self.Conv_Block_1(x, t)
        if self.has_attn:
            x = self.Attention(x)
        x = self.Conv_Block_2(x, t)
        return x

In [44]:
class AttentionUNet(pL.LightningModule):
    """
    initial_channels = Number of channels to initially project the image into.
    channels_list = Number of channels in each layer of the UNet.
    blocks_per_channel = Number of blocks per channel in the UNet.
    n_groups = Number of groups in the group normalisation layer.
    n_heads = Number of heads to use for the multiheaded attention.
    dim_k  = The desired dimensionality of the target key and query vectors.
    dropout_rate = Dropout Rate.
    time_channels = Number of time channels.
    bottleneck_channels = Number of channels in the bottleneck of the Efficient Residual Block.
    units = Number of units in the SE Block.
    bottleneck_units = Number of units in the bottleneck of the SE Block.
    has_attn = Whether the block has an attention mechanism.
    conv_type = The type of convolution block to use.
    attn_type_list = The type of attention block to use.
    merge_type = The type of merge operation to use for the skip connections.
    maxpsnr = The maximum psnr value in the dataset.
    minpsnr = The minimum psnr value in the dataset.
    num_timesteps = The number of timesteps in the dataset.
    image_channels = The number of channels in the image.
    This class defines the Attention UNet architecture. It is a variant of the standard UNet
    architecture that incorporates attention mechanisms to capture long-range dependencies 
    in the image. This architecture is designed to be used in conjunction with the GAP framework
    it takes the psnr level of the image as input and uses it to generate the temporal embedding
    and then a positional encoding which when combined with the photon count distribution 
    of the image is used to predict the photon arrival distribution.
    """
    def __init__(self, initial_channels, channels_list, blocks_per_channel, n_groups, n_heads, dim_k, reduction_factor, group_size, dropout_rate, time_channels, 
                 bottleneck_channels, units, bottleneck_units, conv_type, attn_type_list, merge_type, upsample_type, maxpsnr, minpsnr, 
                 num_timesteps, middle_attn_type, image_channels = 1):
        super(AttentionUNet, self).__init__()

        self.initial_channels = initial_channels
        self.channels_list = channels_list
        self.n_groups = n_groups
        self.n_heads = n_heads
        self.reduction_factor = reduction_factor
        self.group_size = group_size
        self.middle_attn_type = middle_attn_type
        self.dim_k = dim_k
        self.dropout_rate = dropout_rate
        self.time_channels = time_channels
        self.bottleneck_channels = bottleneck_channels
        self.bottleneck_units = bottleneck_units
        self.blocks_per_channels = blocks_per_channel
        self.units = units
        self.transpose = upsample_type
        self.conv_type = conv_type
        self.attn_type_list = attn_type_list
        self.merge_type = merge_type
        self.depth = len(self.channels_list)
        self.maxpsnr = maxpsnr
        self.minpsnr = minpsnr
        self.num_timesteps = num_timesteps
        self.image_channels = image_channels
        self.save_hyperparameters()

        # Define the Psnr_to_Timestep function
        self.Psnr_Converter = lambda psnr: psnr_to_timestep(psnr, self.maxpsnr, self.minpsnr, self.num_timesteps)

        # Define the AttentionUNet layers
        self.DownBlocks = []
        self.MiddleBlocks = []
        self.UpBlocks = []
    

        self.Image_Projection = nn.LazyConv2d(initial_channels, kernel_size=3, padding=1, stride=1)
        self.input_norm = nn.GroupNorm(n_groups, self.initial_channels)
        self.input_Mish = Mish()

        self.Temporal_Embedding = Temporal_Embedder(initial_channels *4)

        for index in range(self.depth):
            in_channels = self.initial_channels if index == 0 else self.channels_list[index-1]
            out_channels = self.channels_list[index] 
            attn_type = attn_type_list[index]

            for index in range(self.blocks_per_channels):

                if index != self.blocks_per_channels -1:
                    self.DownBlocks.append(Down_Block(in_channels = in_channels,
                        out_channels = in_channels, 
                        time_channels = initial_channels*4,
                        n_groups = n_groups,
                        dropout_rate = dropout_rate,
                        conv_type = conv_type,
                        attn_type = attn_type,
                        downsample = False,
                        Pooling = True,
                        bottleneck_channels = bottleneck_channels,
                        units = units,
                        bottleneck_units = bottleneck_units,
                        n_heads= n_heads,
                        group_size = group_size,
                        dim_k = dim_k,
                        reduction_factor = reduction_factor
                    ))

                elif index == self.blocks_per_channels - 1:
                  self.DownBlocks.append(Down_Block(in_channels = in_channels,
                    out_channels = out_channels, 
                    time_channels = initial_channels*4,
                    n_groups = n_groups,
                    dropout_rate = dropout_rate,
                    conv_type = conv_type,
                    attn_type = attn_type,
                    downsample = True,
                    Pooling = True,
                    bottleneck_channels = bottleneck_channels,
                    units = units,
                    bottleneck_units = bottleneck_units,
                    n_heads= n_heads,
                    group_size = group_size,
                    dim_k = dim_k,
                    reduction_factor = reduction_factor
                )) 
                  
                     

            
        self.MiddleBlocks.append(MiddleBlock(in_channels= self.channels_list[-1],
                                            time_channels= initial_channels*4,
                                            n_groups= n_groups,
                                            dropout_rate= dropout_rate,
                                            conv_type= conv_type,
                                            attn_type= middle_attn_type,
                                            bottleneck_channels= bottleneck_channels,
                                            units= units,
                                            bottleneck_units= bottleneck_units,
                                            n_heads= n_heads,
                                            group_size = group_size,
                                            dim_k = dim_k,
                                            reduction_factor = reduction_factor
                                            )
                )

        for index in range(self.depth-2, -1, -1):
            in_channels = self.channels_list[index+1]
            out_channels = self.channels_list[index]
            attn_type = attn_type_list[index]
            #print(in_channels, out_channels, attn_type)
    
            for index in range(self.blocks_per_channels):

                if index == 0:
                    self.UpBlocks.append(Up_Block(in_channels = in_channels, 
                                            out_channels= out_channels,
                                            time_channels= initial_channels*4,
                                            n_groups=n_groups,
                                            dropout_rate=dropout_rate,
                                            conv_type=conv_type,
                                            attn_type=attn_type,
                                            up_sample=True,
                                            transpose=self.transpose,
                                            merge_type=merge_type,
                                            bottleneck_channels=bottleneck_channels,
                                            units=units,
                                            bottleneck_units=bottleneck_units,
                                            n_heads= n_heads,
                                            group_size = group_size,
                                            dim_k = dim_k,
                                            reduction_factor = reduction_factor
                ))
                    
                elif index != 0:
                    self.UpBlocks.append(Up_Block(in_channels = out_channels, 
                                            out_channels= out_channels,
                                            time_channels= initial_channels*4,
                                            n_groups=n_groups,
                                            dropout_rate=dropout_rate,
                                            conv_type=conv_type,
                                            attn_type=attn_type,
                                            up_sample=False,
                                            transpose=self.transpose,
                                            merge_type=merge_type,
                                            bottleneck_channels=bottleneck_channels,
                                            units=units,
                                            bottleneck_units=bottleneck_units,
                                            n_heads= n_heads,
                                            group_size = group_size,
                                            dim_k = dim_k,
                                            reduction_factor = reduction_factor
                ))
            
        self.Final_Upsample = Up_Block(
            in_channels = out_channels,
            out_channels = initial_channels,
            time_channels= initial_channels*4,
            n_groups=n_groups,
            dropout_rate=dropout_rate,
            conv_type=conv_type,
            attn_type=attn_type,
            up_sample=True,
            transpose=self.transpose,
            merge_type=merge_type,
            bottleneck_channels=bottleneck_channels,
            units=units,
            bottleneck_units=bottleneck_units,
            n_heads= n_heads,
            group_size = group_size,
            dim_k = dim_k,
            reduction_factor = reduction_factor
            )    
            

        self.output_norm = nn.GroupNorm(n_groups, self.initial_channels)
        self.output_Mish = Mish()
        self.Output_Layer = nn.Conv2d(self.initial_channels, self.image_channels, kernel_size=1, stride=1)
    
        self.DownBlocks = nn.ModuleList(self.DownBlocks)
        self.MiddleBlocks = nn.ModuleList(self.MiddleBlocks)
        self.UpBlocks = nn.ModuleList(self.UpBlocks)

    def forward(self, x):
        
        psnr = torch.FloatTensor([-40.0])
        t = self.Psnr_Converter(psnr)
        t = self.Temporal_Embedding(t)
        
        stack = None
        
        factor = 10.0
        for i in range(self.depth):
            scale = x.clone()*(factor**(-i))
            scale = torch.sin(scale)
            if stack is None:
                stack = scale
            else:
                stack = torch.cat((stack,scale),1)
        
        x = stack

        x = self.Image_Projection(x)
        x = self.input_norm(x)
        x = self.input_Mish(x)
        
        Encoder_Skip_Connections = []
        for block in self.DownBlocks:
            #print(x.shape, "Downsampling")
            if block.downsample:
                x, before_pool = block(x, t)
                Encoder_Skip_Connections.append(before_pool)
            else:
                x = block(x, t)
        

        for block in self.MiddleBlocks:
            x = block(x, t)
        
        ptr = 0
        slow_ptr = 0
        Reversed_Encoder_Skip_Connections = Encoder_Skip_Connections[::-1]
        while ptr < len(self.UpBlocks):
            #print(x.shape, ptr, slow_ptr, before_pool.shape, "Upsampling")
            before_pool = Reversed_Encoder_Skip_Connections[slow_ptr]
            if ptr % self.blocks_per_channels == 0:
                x = self.UpBlocks[ptr](x, before_pool, t)
                slow_ptr += 1
            else:
                x = self.UpBlocks[ptr](x, before_pool, t)
            ptr += 1
        
        before_pool = Encoder_Skip_Connections[0]
        x = self.Final_Upsample(x, before_pool, t)

        x = self.Output_Layer(self.output_Mish(self.output_norm(x)))
        return x

    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Conv2d):
            init.xavier_normal(m.weight)
            init.constant(m.bias, 0)
    
    def reset_params(self):
        for i, m in enumerate(self.modules()):
            self.weight_init(m)
    
    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=1e-4)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
        
        return {
           'optimizer': optimizer,
           'lr_scheduler': scheduler, 
           'monitor': 'val_loss'
        }
    
    def photonLoss(self,result, target):
        expEnergy = torch.exp(result)
        perImage =  -torch.mean(result*target, dim =(-1,-2,-3), keepdims = True )
        perImage += torch.log(torch.mean(expEnergy, dim =(-1,-2,-3), keepdims = True ))*torch.mean(target, dim =(-1,-2,-3), keepdims = True )
        return torch.mean(perImage)
    
    def MSELoss(self,result, target):
        expEnergy = torch.exp(result)
        expEnergy /= (torch.mean(expEnergy, dim =(-1,-2,-3), keepdims = True ))
        target = target / (torch.mean(target, dim =(-1,-2,-3), keepdims = True ))
        return torch.mean((expEnergy-target)**2)
    
    def training_step(self, batch, batch_idx = None):
        img_input, psnr_image, target_img  = batch
        psnr = psnr_image.min()
        predicted = self.forward(img_input)
        train_loss = self.photonLoss(predicted, target_img)
        self.log("train_loss", train_loss, on_step = False, on_epoch = True, prog_bar = True, logger = True)
        return train_loss
    
    def validation_step(self, batch, batch_idx = None):
        img_input, psnr_image, target_img = batch
        psnr = psnr_image.min()
        predicted = self.forward(img_input)
        valid_loss = self.photonLoss(predicted, target_img)
        self.log("val_loss", valid_loss, on_step = False, on_epoch = True, prog_bar = True, logger = True)
        return valid_loss
    
    def test_step(self, batch, batch_idx = None):
        img_input, psnr_image, target_img = batch
        psnr = psnr_image.min()
        predicted = self.forward(img_input)
        test_loss = self.photonLoss(predicted, target_img)
        self.log("test_loss", test_loss, on_step = False, on_epoch = True, prog_bar = True, logger = True)
        return test_loss

    def predict(self, x, psnr):
        return self.forward(x, psnr)

In [45]:
# Define the Attention Unet Model

# Define the Hyperparameters of the Model:
initial_channels = 8
image_channels = 1
blocks_per_channel = 2
n_groups = 8
n_heads = 8
channels_list = [32, 64, 128, 256, 512]
dim_k = 64
group_size = 128
reduction_factor = 4
dropout_rate = 0.1
time_channels = 512
bottleneck_channels = 64
units = 128
bottleneck_units = 64
has_attn = True
conv_type = "Residual_Block"
attn_type_list = [None, None, None, None, None]
middle_attn_type = "Attention"
merge_type = "concat"
maxpsnr = -5.0
minpsnr = -40.0
num_timesteps = 1024

# Define the Attention Network:
AttnUNet = AttentionUNet(
    initial_channels = initial_channels,
    channels_list = channels_list,
    blocks_per_channel = blocks_per_channel,
    n_groups = n_groups,
    n_heads = n_heads,
    dim_k = dim_k,
    dropout_rate = dropout_rate,
    time_channels = time_channels,
    bottleneck_channels = bottleneck_channels,
    units = units,
    bottleneck_units= bottleneck_units,
    conv_type = conv_type,
    attn_type_list = attn_type_list,
    merge_type = merge_type,
    maxpsnr = maxpsnr,
    minpsnr = minpsnr,
    num_timesteps = num_timesteps,
    image_channels = image_channels,
    reduction_factor= reduction_factor,
    group_size= group_size,
    middle_attn_type= middle_attn_type,
    upsample_type = "transpose"
)

In [47]:
AttnUNet

AttentionUNet(
  (Image_Projection): LazyConv2d(0, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (input_norm): GroupNorm(8, 8, eps=1e-05, affine=True)
  (input_Mish): Mish()
  (Temporal_Embedding): Temporal_Embedder(
    (Linear_1): Linear(in_features=8, out_features=32, bias=True)
    (Linear_2): Linear(in_features=32, out_features=32, bias=True)
    (Mish_1): Mish()
    (Mish_2): Mish()
  )
  (Final_Upsample): Up_Block(
    (Conv_Block): Residual_Block(
      (time_embedding): Linear(in_features=32, out_features=8, bias=True)
      (conv1): Conv2d(64, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv3): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (GroupNorm_1): GroupNorm(8, 64, eps=1e-05, affine=True)
      (GroupNorm_2): GroupNorm(8, 8, eps=1e-05, affine=True)
      (Batch_Norm_1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True,