In [1]:
from utils.swin_transformer import swin_t_encoder

In [2]:
import torch

In [3]:
# Random dummy image batch (1 sample, 3 channels, 256x256) used to smoke-test the encoder.
b = torch.randn(1, 3, 256, 256)

In [4]:
# Build the encoder returned by swin_t_encoder() using its default arguments.
enc_model = swin_t_encoder()

In [5]:
# Forward pass of the dummy batch; the result is an iterable of feature maps
# (four tensors in the recorded output of the next cell).
c = enc_model(b)

In [6]:
# Print the shape of every feature map the encoder returned, one per line.
for out in c:
    print(out.shape)

torch.Size([1, 96, 64, 64])
torch.Size([1, 192, 32, 32])
torch.Size([1, 384, 16, 16])
torch.Size([1, 768, 8, 8])


In [7]:
from utils.encoder import Encoder

In [8]:
# Instantiate the project's Encoder with the "swin" option.
# NOTE(review): presumably this selects the Swin backbone — confirm in utils/encoder.py.
model = Encoder("swin")

In [9]:
from torchsummary import summary

In [10]:
# Layer-by-layer summary for a model that takes two (3, 256, 256) inputs, at batch size 10.
summary(model, input_size = [(3,256,256), (3,256,256)], batch_size=10)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Unfold-1             [10, 48, 4096]               0
            Linear-2           [10, 64, 64, 96]           4,704
      PatchMerging-3           [10, 64, 64, 96]               0
         LayerNorm-4           [10, 64, 64, 96]             192
            Linear-5          [10, 64, 64, 288]          27,648
            Linear-6           [10, 64, 64, 96]           9,312
   WindowAttention-7           [10, 64, 64, 96]               0
           PreNorm-8           [10, 64, 64, 96]               0
          Residual-9           [10, 64, 64, 96]               0
        LayerNorm-10           [10, 64, 64, 96]             192
           Linear-11          [10, 64, 64, 384]          37,248
             GELU-12          [10, 64, 64, 384]               0
           Linear-13           [10, 64, 64, 96]          36,960
      FeedForward-14           [10, 64,

  total_output += np.prod(summary[layer]["output_shape"])


In [11]:
from prettytable import PrettyTable

# Tabulate every named parameter of the model next to its requires_grad flag,
# so frozen (non-trainable) parameters are easy to spot.
table = PrettyTable(["Module", "Requires grad?"])

for name, parameter in model.named_parameters():
    table.add_row([name, parameter.requires_grad])

print(table)

+----------------------------------------------------------------+----------------+
|                             Module                             | Requires grad? |
+----------------------------------------------------------------+----------------+
|           model.stage1.patch_partition.linear.weight           |      True      |
|            model.stage1.patch_partition.linear.bias            |      True      |
|     model.stage1.layers.0.0.attention_block.fn.norm.weight     |      True      |
|      model.stage1.layers.0.0.attention_block.fn.norm.bias      |      True      |
|  model.stage1.layers.0.0.attention_block.fn.fn.pos_embedding   |      True      |
|  model.stage1.layers.0.0.attention_block.fn.fn.to_qkv.weight   |      True      |
|  model.stage1.layers.0.0.attention_block.fn.fn.to_out.weight   |      True      |
|   model.stage1.layers.0.0.attention_block.fn.fn.to_out.bias    |      True      |
|        model.stage1.layers.0.0.mlp_block.fn.norm.weight        |      True

Everything is clear on the model side: it is working, and all the parameters are appropriately set to require gradients.

In [12]:
from dataloader import train_dataset

In [13]:
from torch.utils.data import DataLoader

In [14]:
# Wrap the training dataset in a DataLoader that yields shuffled batches of 10 samples.
tdl = DataLoader(train_dataset, batch_size=10, shuffle=True)

In [15]:
# Inspect the first batch only: report the shape and requires_grad flag of
# both images and of the flattened label, then stop iterating.
for batch_idx, sample in enumerate(tdl):
    img_A = sample["img_A"].clone()
    img_B = sample["img_B"].clone()
    label = torch.flatten(sample["label"].clone())

    inspected = {"img_A": img_A, "img_B": img_B, "label": label}

    # All shapes first, then all requires_grad flags, to keep the output order stable.
    for tensor_name, tensor in inspected.items():
        print(f"{tensor_name}.shape = ", tensor.shape)
    for tensor_name, tensor in inspected.items():
        print(f"{tensor_name}.requires_grad = ", tensor.requires_grad)

    break

img_A.shape =  torch.Size([10, 3, 256, 256])
img_B.shape =  torch.Size([10, 3, 256, 256])
label.shape =  torch.Size([10])
img_A.requires_grad =  True
img_B.requires_grad =  True
label.requires_grad =  True


OK, so the `requires_grad` flags look good for the inputs...

In [16]:
# Cosine-embedding loss with default arguments; later called as criterion(enc1, enc2, label).
criterion = torch.nn.CosineEmbeddingLoss()
# Adam over every model parameter, with learning rate 0.002 and weight decay 0.001.
optimizer = torch.optim.Adam(model.parameters(), lr=0.002, weight_decay=0.001)

In [19]:
# Smoke-test one optimisation step: run a single batch through the model,
# confirm the loss is attached to the autograd graph, then stop.
model.train()
for batch_idx, sample in enumerate(tdl):
    # Work on copies of the batch tensors (no device transfer happens here).
    img_A, img_B = (sample[key].clone() for key in ("img_A", "img_B"))
    label = torch.flatten(sample["label"].clone())

    # Reset gradients left over from any previous step.
    optimizer.zero_grad()

    # Forward pass: the model encodes both images, and the loss compares the encodings.
    enc1, enc2 = model(img_A, img_B)
    loss = criterion(enc1, enc2, label)

    # A non-None grad_fn means backward() has a graph to traverse.
    print(type(loss))
    print(loss.grad_fn)

    # Backward pass followed by a single parameter update.
    loss.backward()
    optimizer.step()

    break

<class 'torch.Tensor'>
<MeanBackward0 object at 0x7fcca75ed750>
