In [1]:
import os

from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader

from models.xswin import XNetSwinTransformer
from models.xswin_diffusion import XNetSwinTransformerDiffusion
from models.modules import SwinResidualCrossAttention

from models.modules.normal.residual_cross_attention import _extract_windows, _unfold_padding_prep, _fold_unpadding_prep

from torchinfo import summary


In [2]:
# Build a small integer tensor with an easily recognisable value pattern and
# feed it to `_extract_windows`, so the effect of the helper can be inspected.

# (4, 5) grid in which every row is [1, 2, 3, 4, 5] (int64, like a literal
# tensor of Python ints).
base = torch.arange(1, 6).repeat(4, 1)

# Four scaled copies stacked on a new leading axis -> (4, 4, 5).
stacked = torch.stack((base, base * 2, base * 3, base * 4))

# 2 x 2 tiling of scaled copies. For 3-D inputs `hstack` concatenates along
# dim 1 and `vstack` along dim 0, giving (8, 8, 5).
tiled = torch.vstack((
    torch.hstack((stacked, stacked * 2)),
    torch.hstack((stacked * 3, stacked * 4)),
))

# Leading singleton axis -> (1, 8, 8, 5).
x = tiled.unsqueeze(0)

print(x.shape)

# NOTE(review): the two trailing arguments are presumably the window height
# and width -- confirm against the definition of `_extract_windows`.
x = _extract_windows(x, 4, 4)


torch.Size([1, 8, 8, 5])


In [3]:
# Size of the last axis of both inputs; also handed to the attention module.
H_DIM = 16

# Two all-ones inputs of identical shape. Neither 10 nor 22 is a multiple
# of 4, the value passed to the padding helper below.
r1 = torch.ones(5, 10, 22, H_DIM)
r2 = torch.ones_like(r1)

# Round trip through the padding helpers: `_unfold_padding_prep` returns the
# prepared tensor plus the bookkeeping object that `_fold_unpadding_prep`
# consumes. The two shapes are printed one per line for comparison.
f1, pinfo = _unfold_padding_prep(r1, 4, 4)
f2 = _fold_unpadding_prep(f1, pinfo)
print(f1.shape, f2.shape, sep="\n")

# NOTE(review): positional arguments are presumably (window size, feature
# dim, number of heads) -- confirm against SwinResidualCrossAttention.
SRCA = SwinResidualCrossAttention([4, 4], H_DIM, 4)

# Displayed as the cell result: shape of the module output for (r1, r2).
SRCA(r1, r2).shape


torch.Size([5, 12, 24, 16])
torch.Size([5, 10, 22, 16])


torch.Size([5, 10, 22, 16])

In [4]:
# Smoke test for XNetSwinTransformerDiffusion: build a small model, push one
# random batch through it and display the shape of the result.

# --- Model hyperparameters -------------------------------------------------
patch_size = [4, 4]
embed_dim = 64
depths = [1, 1, 1]
num_heads = [4, 8, 16]
window_size = [4, 4]
num_classes = 10

# --- Input geometry --------------------------------------------------------
# Neither 151 nor 309 is divisible by the patch size of 4, so this exercises
# the non-divisible input path.
IMG_H, IMG_W = 151, 309
LATENT_DIM = 64
B = 10

# --- Remaining constructor options -----------------------------------------
global_stages = 1
input_size = [IMG_H, IMG_W]
final_downsample = True
residual_cross_attention = True
class_dropout = 0.1

diffusion = XNetSwinTransformerDiffusion(
    patch_size,
    embed_dim,
    depths,
    num_heads,
    window_size,
    num_classes=num_classes,
    global_stages=global_stages,
    input_size=input_size,
    final_downsample=final_downsample,
    residual_cross_attention=residual_cross_attention,
    class_dropout_prob=class_dropout,
    latent_dimensions=LATENT_DIM,
)

x = torch.randn((B, LATENT_DIM, IMG_H, IMG_W))
# NOTE(review): `t` and `y` are presumably per-sample timesteps and class
# labels -- confirm against the model's forward signature.
t = torch.arange(0, B)  # (B,)
y = torch.arange(0, B)  # (B,)

# Only the output shape is inspected, so skip autograd bookkeeping: without
# no_grad the forward pass keeps every intermediate activation alive for a
# backward pass that never happens.
with torch.no_grad():
    diffusion_out = diffusion(x, t, y)

diffusion_out.shape


torch.Size([10, 64, 151, 309]) torch.Size([10, 64])
torch.Size([10, 64, 151, 309]) torch.Size([10, 64])
torch.Size([10, 64, 151, 309]) torch.Size([10, 64])
torch.Size([10, 64, 151, 309]) torch.Size([10, 64])
torch.Size([10, 37, 77, 64]) torch.Size([10, 64])
torch.Size([10, 37, 77, 64]) torch.Size([10, 64])
torch.Size([10, 37, 77, 64]) torch.Size([10, 64])
torch.Size([10, 19, 39, 128]) torch.Size([10, 64])
torch.Size([10, 19, 39, 128]) torch.Size([10, 64])
torch.Size([10, 19, 39, 128]) torch.Size([10, 64])
torch.Size([10, 10, 20, 256]) torch.Size([10, 64])
torch.Size([10, 10, 20, 256]) torch.Size([10, 64])
torch.Size([10, 10, 20, 256]) torch.Size([10, 64])
torch.Size([10, 50, 512]) torch.Size([10, 64])
torch.Size([10, 50, 512]) torch.Size([10, 64])
torch.Size([10, 5, 10, 512]) torch.Size([10, 64])
torch.Size([10, 10, 20, 256]) torch.Size([10, 64])
torch.Size([10, 10, 20, 256]) torch.Size([10, 64])
torch.Size([10, 10, 20, 512]) torch.Size([10, 64])
torch.Size([10, 10, 20, 512]) torch.Siz

tensor([[[[ 2.7529e-01,  4.9275e-01,  1.8895e-01,  ...,  2.5477e-01,
            3.4588e-01,  4.9383e-01],
          [ 4.0597e-01,  3.2908e-01,  4.5913e-01,  ...,  3.9881e-01,
            5.7380e-01,  4.9717e-01],
          [ 3.7970e-02, -3.4107e-01, -1.8008e-01,  ...,  3.8295e-01,
            3.0138e-01,  4.8041e-01],
          ...,
          [ 2.9419e-01,  3.5661e-01,  2.5188e-01,  ...,  7.9012e-03,
            2.6368e-01,  7.3264e-03],
          [ 2.7711e-01,  5.0808e-02,  3.4199e-02,  ..., -3.0745e-01,
           -7.9126e-03,  3.2583e-01],
          [ 5.5517e-01,  4.3856e-01,  8.2591e-02,  ...,  2.4407e-01,
            1.6944e-01,  5.0870e-01]],

         [[ 4.0396e-02,  2.5708e-01,  4.4687e-02,  ...,  5.0010e-01,
            1.9652e-01,  2.9233e-01],
          [ 1.5042e-01,  1.9112e-01,  2.3647e-01,  ...,  6.7613e-01,
           -7.1661e-02,  6.1299e-01],
          [ 7.1949e-02,  4.5296e-01,  4.1781e-01,  ...,  4.7238e-01,
            1.4058e-01,  3.8443e-01],
          ...,
     

In [5]:
# Smoke test for XNetSwinTransformer: build a small model, run one random
# batch through it, print input/output shapes and show a torchinfo summary.

# --- Model hyperparameters -------------------------------------------------
patch_size = [4, 4]
embed_dim = 64
depths = [1, 1, 1]
num_heads = [4, 8, 16]
window_size = [4, 4]
num_classes = 1

# --- Input geometry --------------------------------------------------------
# Neither 151 nor 309 is divisible by the patch size of 4.
IMG_H, IMG_W = 151, 309

# --- Remaining constructor options -----------------------------------------
global_stages = 1
input_size = [IMG_H, IMG_W]
final_downsample = True
residual_cross_attention = True

swin = XNetSwinTransformer(
    patch_size,
    embed_dim,
    depths,
    num_heads,
    window_size,
    num_classes=num_classes,
    global_stages=global_stages,
    input_size=input_size,
    final_downsample=final_downsample,
    residual_cross_attention=residual_cross_attention,
)

x = torch.randn((5, 3, IMG_H, IMG_W))
print(x.shape)

# Only the output shape is inspected, so skip autograd bookkeeping for this
# forward pass instead of retaining activations for an unused backward pass.
with torch.no_grad():
    y = swin(x)
print(y.shape)

# Last expression of the cell, so the summary table is rendered as its output.
summary(swin, input_size=[1, 3, IMG_H, IMG_W])


torch.Size([5, 3, 151, 309])
torch.Size([5, 151, 309])


  action_fn=lambda data: sys.getsizeof(data.storage()),
  return super().__sizeof__() + self.nbytes()


Layer (type:depth-idx)                                  Output Shape              Param #
XNetSwinTransformer                                     [1, 151, 309]             25,600
├─ConvolutionTriplet: 1-1                               [1, 64, 151, 309]         --
│    └─Sequential: 2-1                                  [1, 64, 151, 309]         --
│    │    └─Conv2d: 3-1                                 [1, 64, 151, 309]         1,792
│    │    └─BatchNorm2d: 3-2                            [1, 64, 151, 309]         128
│    │    └─LeakyReLU: 3-3                              [1, 64, 151, 309]         --
│    │    └─Conv2d: 3-4                                 [1, 64, 151, 309]         36,928
│    │    └─BatchNorm2d: 3-5                            [1, 64, 151, 309]         128
│    │    └─LeakyReLU: 3-6                              [1, 64, 151, 309]         --
│    │    └─Conv2d: 3-7                                 [1, 64, 151, 309]         36,928
│    │    └─BatchNorm2d: 3-8               

In [6]:
# Print the module tree of `swin` (the XNetSwinTransformer instance built in
# an earlier cell) to inspect its layers and their configuration.
print(swin)


XNetSwinTransformer(
  (smooth_conv_in): ConvolutionTriplet(
    (layers): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): LeakyReLU(negative_slope=0.01)
      (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (8): LeakyReLU(negative_slope=0.01)
    )
  )
  (patching): Patching(
    (patching): Sequential(
      (0): Conv2d(64, 64, kernel_size=(4, 4), stride=(4, 4))
      (1): Permute()
      (2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
  )
  (encoder): ModuleList(
    (0): Sequential(
      (0): SwinTrans