# Investigating TransUNet

TransUNet is a variant of U-Net that uses transformers in the encoding stage. Details of the model are in the paper [3D TransUNet: Advancing Medical Image Segmentation through Vision Transformers](https://arxiv.org/abs/2310.07781). This model caught my attention because of its use of transformers and its emphasis on medical image segmentation.

The authors were kind enough to publish a Git repo with their code, and another Good Samaritan made a [TensorFlow version](https://github.com/awsaf49/TransUNet-tf). I'm more comfortable with TensorFlow, so I'll try that one.

In [1]:
import tensorflow as tf
from transunet import TransUNet
import pandas as pd
from pprint import pprint


TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [5]:
model = TransUNet(image_size=224, pretrain=True)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50v2_weights_tf_dim_ordering_tf_kernels_notop.h5
Downloading data from https://storage.googleapis.com/vit_models/imagenet21k/R50+ViT-B_16.npz


The download URLs show that the model uses ResNet50V2 and R50+ViT-B_16 weights, both pretrained on ImageNet (the ViT weights come from the imagenet21k directory).

In [6]:
# I'm not particularly interested in downloading the weights every time
# I load the model, so I'll save it locally in .keras format.
dst = 'D:/Downloads/models/tunet-pretrained.keras'
model.save(dst)
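
The save-once idea can be turned into a small "load if cached, otherwise build and save" helper. This is only a sketch: `load_or_build` and its `build_fn`/`load_fn`/`save_fn` parameters are names I made up, and the demo uses a plain string as a stand-in model so it runs without TensorFlow. The commented-out call at the end shows how it might be wired to the real model. That part is untested, and reloading a model with custom layers may need `custom_objects` passed to `load_model`.

```python
from pathlib import Path
import tempfile


def load_or_build(path, build_fn, load_fn, save_fn):
    """Return the object cached at `path`, building and saving it on a miss."""
    path = Path(path)
    if path.exists():
        return load_fn(str(path))
    obj = build_fn()
    path.parent.mkdir(parents=True, exist_ok=True)
    save_fn(obj, str(path))
    return obj


# Stand-in demo that runs without TensorFlow: the "model" is just a string.
with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp) / "models" / "demo.txt"
    calls = []

    def build():
        calls.append("build")
        return "weights"

    def save(obj, p):
        Path(p).write_text(obj)

    def load(p):
        return Path(p).read_text()

    first = load_or_build(cache, build, load, save)   # cache miss: builds and saves
    second = load_or_build(cache, build, load, save)  # cache hit: loads from disk

# With the real model, the call might look like this (untested assumption):
# model = load_or_build(
#     dst,
#     build_fn=lambda: TransUNet(image_size=224, pretrain=True),
#     load_fn=tf.keras.models.load_model,
#     save_fn=lambda m, p: m.save(p),
# )
```

Passing the build, load, and save steps in as callables keeps the helper independent of TensorFlow, so the caching logic can be checked on its own.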

## Examine model

In [7]:
model.summary()

Model: "TransUNet"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_2 (InputLayer)        [(None, 224, 224, 3)]        0         []                            
                                                                                                  
 conv1_pad (ZeroPadding2D)   (None, 230, 230, 3)          0         ['input_2[0][0]']             
                                                                                                  
 conv1_conv (Conv2D)         (None, 112, 112, 64)         9472      ['conv1_pad[0][0]']           
                                                                                                  
 pool1_pad (ZeroPadding2D)   (None, 114, 114, 64)         0         ['conv1_conv[0][0]']          
                                                                                          

In [20]:
# Print each layer's index, name, and whether it is trainable.

for i, layer in enumerate(model.layers):
    print(f'LAYER {i}: {layer.name} -- trainable: {layer.trainable}')

LAYER 0: input_2 -- trainable: False
LAYER 1: conv1_pad -- trainable: False
LAYER 2: conv1_conv -- trainable: False
LAYER 3: pool1_pad -- trainable: False
LAYER 4: pool1_pool -- trainable: False
LAYER 5: conv2_block1_preact_bn -- trainable: False
LAYER 6: conv2_block1_preact_relu -- trainable: False
LAYER 7: conv2_block1_1_conv -- trainable: False
LAYER 8: conv2_block1_1_bn -- trainable: False
LAYER 9: conv2_block1_1_relu -- trainable: False
LAYER 10: conv2_block1_2_pad -- trainable: False
LAYER 11: conv2_block1_2_conv -- trainable: False
LAYER 12: conv2_block1_2_bn -- trainable: False
LAYER 13: conv2_block1_2_relu -- trainable: False
LAYER 14: conv2_block1_0_conv -- trainable: False
LAYER 15: conv2_block1_3_conv -- trainable: False
LAYER 16: conv2_block1_out -- trainable: False
LAYER 17: conv2_block2_preact_bn -- trainable: False
LAYER 18: conv2_block2_preact_relu -- trainable: False
LAYER 19: conv2_block2_1_conv -- trainable: False
LAYER 20: conv2_block2_1_bn -- trainable: False
LAYE

So we have 164 layers. The first 144 are frozen, which is great: our dataset will almost certainly be small, and we don't want to overfit. The final 20 layers are unfrozen. I think that might still be too many, but we can play with it.
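
To play with the frozen/unfrozen split, something like the following could tally the layers and then leave only the last few trainable. It is a minimal sketch: `trainable_split` and `keep_last_trainable` are hypothetical helper names, and the demo uses stand-in objects with a `.trainable` flag (mirroring the 144/20 split above) instead of the real `model.layers`, so it runs without TensorFlow.

```python
from types import SimpleNamespace


def trainable_split(layers):
    """Return (frozen_count, trainable_count) for objects with a .trainable flag."""
    trainable = sum(1 for layer in layers if layer.trainable)
    return len(layers) - trainable, trainable


def keep_last_trainable(layers, n_trainable):
    """Freeze every layer except the last `n_trainable` ones."""
    cutoff = len(layers) - n_trainable
    for i, layer in enumerate(layers):
        layer.trainable = i >= cutoff


# Stand-in layers mirroring the counts above: 144 frozen, then 20 trainable.
layers = [SimpleNamespace(trainable=(i >= 144)) for i in range(164)]
before = trainable_split(layers)

keep_last_trainable(layers, 10)  # e.g. leave only the last 10 layers unfrozen
after = trainable_split(layers)

print(before, after)
```

On the real model the same functions should accept `model.layers` directly, since Keras layers expose a settable `trainable` attribute. The model needs to be compiled again after changing it for the new setting to take effect in training.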

In [22]:
# Speaking of, can tensorflow see my GPU?
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

Num GPUs Available:  0
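
A zero here usually means the installed TensorFlow build cannot use CUDA, or it cannot find the driver and libraries. Here is a small diagnostic sketch that collects the basics. `gpu_report` is a name I made up, and it falls back gracefully if TensorFlow is not importable. One thing worth checking, assuming this is a native Windows install as the `D:/` path hints: as far as I know, TensorFlow 2.10 was the last release with GPU support on native Windows, and later versions need WSL2.

```python
import platform


def gpu_report():
    """Collect basic facts that help explain an empty GPU list."""
    report = {"os": platform.system(), "tensorflow_installed": False}
    try:
        import tensorflow as tf
    except ImportError:
        return report
    report["tensorflow_installed"] = True
    report["tf_version"] = tf.__version__
    # False means this build was compiled without CUDA support (CPU-only).
    report["built_with_cuda"] = tf.test.is_built_with_cuda()
    report["gpus"] = [d.name for d in tf.config.list_physical_devices("GPU")]
    return report


report = gpu_report()
for key, value in report.items():
    print(f"{key}: {value}")
```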


## @#$%!!