In [3]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import sys, os
sys.path.append(os.path.dirname('../ml/transformers/.'))
from BaseDecodeHead import BaseDecodeHead
from SegFormerHead import SegFormerHead
from MixTransformer import MixVisionTransformer, MixVisionTransformer_short
from OutputResizeHead import OutputResizeHead

sys.path.append(os.path.dirname('../ml/.'))
from utils import get_total_params

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
# Global run configuration.
# Seed torch's RNGs first so model weight initialisation in later cells and the
# random tensors used for the smoke test / training loop are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still survives "Restart & Run All" on a machine without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 16
IMG_SIZE = 3*(32,)   # spatial size of each single-channel volume: (32, 32, 32)

In [8]:
# Encoder configuration: per-stage lists (patch sizes, strides, embedding dims,
# heads, MLP ratios, depths, spatial-reduction ratios) all have one entry per stage.
encoder_cfg = {
    'img_size' : IMG_SIZE,
    'patch_sizes' : [5, 3],
    'patch_strides' : [1, 2],
    'in_channels' : 1,
    'embedding_dims' : [32, 64],
    'num_heads' : [4, 4],
    'mlp_ratios' : [2, 2],
    'depths' : [4, 4],
    'sr_ratios' : [8, 2],
    'qkv_bias' : True,
    'qk_scale' : None,
}

encoder = MixVisionTransformer_short(encoder_cfg)

# The decoder consumes every encoder stage, so its per-input channel counts and
# stage indices are derived from the encoder config rather than repeated by hand
# (editing 'embedding_dims' above previously required a matching manual edit here).
encoder_stage_dims = list(encoder_cfg['embedding_dims'])
# Number of decoder output channels; also the input channel count of the resize head.
NUM_DECODER_CLASSES = 8

decoder = SegFormerHead(embedding_dim = 128,
                        in_channels=encoder_stage_dims,
                        num_classes=NUM_DECODER_CLASSES,
                        dropout_ratio=0.1,
                        conv_cfg=None,
                        norm_cfg=None,
                        act_cfg=dict(type='ReLU'),
                        in_index=list(range(len(encoder_stage_dims))),
                        decoder_params=None,
                        align_corners=False)

# Final head: maps the decoder's NUM_DECODER_CLASSES channels to a single output
# channel at IMG_SIZE (see the (1, 1, 32, 32, 32) output shape of the next cell).
auxiliary_head = OutputResizeHead(in_channels=NUM_DECODER_CLASSES, out_channels=1, out_size=IMG_SIZE, mode='upsample')

In [9]:
# Chain encoder -> decoder -> resize head into a single end-to-end model.
model = nn.Sequential(
    encoder,
    decoder,
    auxiliary_head,
)

print("total_params:", get_total_params(model))

# Smoke test: push one dummy sample through the model. The sample has the same
# spatial size the encoder is configured for and the training loop uses (IMG_SIZE).
# It used to be hard-coded to 64^3, so the check exercised a different shape.
x = torch.randn(1, 1, *IMG_SIZE)
out = model(x)
# The model should return one channel at IMG_SIZE for each sample.
assert out.shape == (1, 1, *IMG_SIZE), f"unexpected output shape {tuple(out.shape)}"
print(out.shape)

total_params: 2525473
torch.Size([1, 1, 32, 32, 32])


In [12]:
from torch.optim import Adam
from torch.nn import MSELoss

N_STEPS = 10  # number of optimisation steps on random data

model = model.to(DEVICE)
# Set train mode explicitly so re-running this cell after any eval-mode cell
# still trains with the decoder's dropout (dropout_ratio=0.1) active.
model.train()
loss_fn = MSELoss()
optim = Adam(model.parameters())

# Random inputs and random targets: this loop only checks that
# forward / backward / optimizer step run end to end on DEVICE.
loss_history = []
for step in range(N_STEPS):
    # Allocate directly on DEVICE instead of creating on CPU and copying over.
    # Named `inputs` so the builtin `input` is not shadowed for later cells.
    inputs = torch.randn(BATCH_SIZE, 1, *IMG_SIZE, device=DEVICE)
    GT = torch.randn(BATCH_SIZE, 1, *IMG_SIZE, device=DEVICE)
    # Call the module rather than .forward() so registered hooks are run.
    output = model(inputs)

    # nn.MSELoss takes (input=prediction, target) in that order.
    loss = loss_fn(output, GT)
    optim.zero_grad()
    loss.backward()
    optim.step()

    loss_val = loss.item()
    loss_history.append(loss_val)

# Display the per-step losses as the cell's output.
loss_history