ML model weights and trainable code for Basars
pip install tensorflow requests pandas opencv-python
pip install git+https://github.com/Basars/trans-unet.git
pip install git+https://github.com/Basars/basars-addons.git
Make sure you have Python >= 3.8 installed.
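After installing, a quick sanity check of the environment can save a failed training run. The sketch below is an illustration, not part of Basars: it checks the interpreter version and whether the modules from the first pip command can be imported (the PyPI package `opencv-python` is imported as `cv2`). The import names of the two packages installed from GitHub are not stated here, so they are deliberately left out; `python_version_ok` and `missing_modules` are hypothetical helper names.

```python
import importlib.util
import sys

# Import names of the packages from the first pip command.
# Note: the PyPI package "opencv-python" is imported as "cv2".
REQUIRED_MODULES = ("tensorflow", "requests", "pandas", "cv2")


def python_version_ok(version_info=sys.version_info) -> bool:
    """Return True if the given interpreter version is 3.8 or newer."""
    return tuple(version_info[:2]) >= (3, 8)


def missing_modules(modules=REQUIRED_MODULES) -> list:
    """Return the names in `modules` that cannot be found on sys.path."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("Python >= 3.8:", python_version_ok())
    print("Missing modules:", missing_modules() or "none")
```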
usage: python -m basars.train
[-h] --type {stairs,proj}
[--num-classes NUM_CLASSES] [--epochs EPOCHS]
[--batch_size BATCH_SIZE] [--buffer_size BUFFER_SIZE]
[--multiprocessing-workers MULTIPROCESSING_WORKERS] [--cache-dataset CACHE_DATASET]
Polyp Segmentation and Phase Classification from Endoscopic Images
optional arguments:
-h, --help show this help message and exit
--type {stairs,proj} The type of transformer model. Default value is 'stairs'
--num-classes NUM_CLASSES
Number of classes to be classified. Default value is 5
--epochs EPOCHS Epochs that how many times the model would be trained. Default value is 1290
--batch_size BATCH_SIZE
The batch size. Default value is 64
--buffer_size BUFFER_SIZE
The buffer size for shuffling datasets. Default value is 1024
--multiprocessing-workers MULTIPROCESSING_WORKERS
Number of workers for prefetching datasets. Default value is 64
--cache-dataset CACHE_DATASET
True to cache datasets on memory otherwise don't. Default value is True
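The `--buffer_size` option controls how datasets are shuffled. Assuming the training pipeline uses `tf.data.Dataset.shuffle(buffer_size)` (an assumption; the pipeline code is not shown here), shuffling works by keeping a buffer of `buffer_size` elements and drawing each output uniformly at random from that buffer, refilling the freed slot with the next input element. The plain-Python sketch below reproduces that behavior so the effect of the buffer size is easy to see; `buffered_shuffle` is a hypothetical helper, not part of Basars.

```python
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def buffered_shuffle(items: Iterable[T], buffer_size: int,
                     seed: Optional[int] = None) -> Iterator[T]:
    """Yield `items` in pseudo-random order using a fixed-size buffer.

    Same idea as tf.data.Dataset.shuffle(buffer_size): fill the buffer with
    the first `buffer_size` elements, then repeatedly emit a random buffered
    element and put the next input element into its slot.
    """
    if buffer_size < 1:
        raise ValueError("buffer_size must be >= 1")
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        # Buffer is full: emit a random element and replace it with `item`.
        index = rng.randrange(buffer_size)
        yield buffer[index]
        buffer[index] = item
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


print(list(buffered_shuffle(range(10), buffer_size=3, seed=0)))
```

With `buffer_size=1` the order is unchanged, and once the buffer is at least as large as the dataset the result is a full uniform shuffle. Under this reading, the default of 1024 fully shuffles datasets of up to 1024 elements and only partially shuffles larger ones.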
For example, to train the proj model for 1290 epochs:
python -m basars.train --type proj --epochs 1290
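To script several runs, the same command can be assembled from Python. The sketch below uses only the flags listed in the help output (note the mix of `--batch_size` and `--num-classes` styles there, which is why the mapping is spelled out). `build_train_command` is a hypothetical helper, and because the help text does not say how `--cache-dataset` parses its value, values are simply passed as strings such as `True` or `False`.

```python
import shlex
import sys

# Keyword argument -> CLI flag, copied from the help output above.
FLAG_NAMES = {
    "num_classes": "--num-classes",
    "epochs": "--epochs",
    "batch_size": "--batch_size",
    "buffer_size": "--buffer_size",
    "multiprocessing_workers": "--multiprocessing-workers",
    "cache_dataset": "--cache-dataset",
}


def build_train_command(model_type="stairs", **options):
    """Return an argv list for `python -m basars.train` (does not run it)."""
    if model_type not in ("stairs", "proj"):
        raise ValueError(f"unknown model type: {model_type!r}")
    command = [sys.executable, "-m", "basars.train", "--type", model_type]
    for key, value in options.items():
        if key not in FLAG_NAMES:
            raise ValueError(f"unknown option: {key!r}")
        command += [FLAG_NAMES[key], str(value)]
    return command


if __name__ == "__main__":
    cmd = build_train_command("proj", epochs=1290)
    print(shlex.join(cmd))
    # To actually start training: subprocess.run(cmd, check=True)
```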
Refer to the final-experiments repository.
The trained weights are available in Releases.
stairs: the model upsamples through conv3x3 (256, 128, 64, 32, 16) → conv1x1 (5)
proj: the model upsamples through conv3x3 (256, 128, 64, 16) → conv1x1 (5, 5)
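The two layer listings can be sanity-checked by tracing tensor shapes through the decoder. The sketch below rests on assumptions this README does not confirm: a ViT patch size of 16 (so a 224x224 input becomes a 14x14 token grid), each conv3x3 upsampling stage doubling the spatial resolution, and the remaining convolutions using `padding='same'`. Under them, four upsampling stages restore the full 224x224 resolution; for stairs the extra conv3x3 (16) then runs at full resolution, while for proj the 16 is the last upsampling stage. `trace_decoder_shapes` is a hypothetical helper.

```python
def trace_decoder_shapes(upsample_channels, head_channels,
                         input_size=224, patch_size=16):
    """Return the (height, width, channels) output of each decoder layer.

    Assumptions (not confirmed by this README):
    - the encoder yields an (input_size // patch_size)^2 token grid;
    - every conv3x3 upsampling stage doubles height and width;
    - the head convolutions use padding='same' and keep the resolution.
    """
    size = input_size // patch_size  # 224 // 16 = 14
    shapes = []
    for channels in upsample_channels:
        size *= 2  # each upsampling stage doubles the spatial resolution
        shapes.append((size, size, channels))
    for channels in head_channels:
        shapes.append((size, size, channels))  # resolution unchanged
    return shapes


# stairs: four upsampling stages, then conv3x3 (16) and conv1x1 (5)
for shape in trace_decoder_shapes((256, 128, 64, 32), (16, 5)):
    print(shape)
```

For proj, the same helper would be called as `trace_decoder_shapes((256, 128, 64, 16), (5, 5))`, ending in two 224x224x5 outputs.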
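stairs model:
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D
from transunet import VisionTransformer  # assumed import path; check the trans-unet repository

model = Sequential(name='ViT-stairs', layers=[
    VisionTransformer(input_shape=(224, 224, 3),
                      upsample_channels=(256, 128, 64, 32),
                      output_kernel_size=3, num_classes=16),
    # 1x1 projection from 16 feature channels to per-pixel softmax over 5 classes
    Conv2D(5, kernel_size=(1, 1), padding='same', activation='softmax', use_bias=False)
])
model.load_weights('basars-cls5-stairs.h5')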
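proj model:
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D
from transunet import VisionTransformer  # assumed import path; check the trans-unet repository

model = Sequential(name='ViT-proj', layers=[
    VisionTransformer(input_shape=(224, 224, 3), num_classes=5),
    # 1x1 projection followed by per-pixel softmax over 5 classes
    Conv2D(5, kernel_size=(1, 1), padding='same', activation='softmax', use_bias=False)
])
model.load_weights('basars-cls5-proj.h5')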
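Both models end in a 1x1 `Conv2D` with a softmax over 5 channels, so, assuming the decoder restores the 224x224 input resolution, `model.predict` on a batch of shape (N, 224, 224, 3) should return per-pixel probabilities of shape (N, 224, 224, 5). A common way to turn these into a segmentation mask is a per-pixel argmax, sketched below with NumPy. The helper names are hypothetical, the meaning of each class id is not documented here, and the input preprocessing used during training (value scaling, color order) is an assumption you should verify before calling `model.predict`.

```python
import numpy as np


def probabilities_to_mask(probabilities: np.ndarray) -> np.ndarray:
    """Convert softmax output of shape (..., H, W, C) to class ids (..., H, W)."""
    return np.argmax(probabilities, axis=-1).astype(np.uint8)


def class_pixel_counts(mask: np.ndarray, num_classes: int = 5) -> np.ndarray:
    """Count how many pixels were assigned to each class id."""
    return np.bincount(mask.ravel(), minlength=num_classes)


# Synthetic stand-in for model.predict(...) output: every pixel favors class 2.
probs = np.zeros((1, 4, 4, 5), dtype=np.float32)
probs[..., 2] = 1.0
mask = probabilities_to_mask(probs)
print(mask.shape, class_pixel_counts(mask))
```

With a loaded model this would be used as `mask = probabilities_to_mask(model.predict(images))[0]` for the first image in the batch.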