This repository has been archived by the owner on Nov 22, 2022. It is now read-only.

An ML model with a U-shaped architecture and encoders based on ResNet50V2 and the Vision Transformer
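
To make the U-shape concrete, the sketch below traces tensor shapes for a 224×224 input. It assumes the R50+ViT-B_16 hybrid configuration used in the original TransUNet design: the ResNet stages downsample by 16, each cell of the resulting feature map becomes one Transformer token of width 768 (ViT-B), and the decoder doubles the resolution four times. `transunet_shapes` is a hypothetical helper for illustration only; it is not part of the `transunet` package and does not inspect the real model.

```python
# Hypothetical helper: traces shapes through a TransUNet-style hybrid encoder
# and cascaded upsampling decoder. All defaults are assumptions (ViT-B/16
# hybrid settings), not values read from the transunet package.


def transunet_shapes(height=224, width=224, downsample=16, hidden_size=768,
                     num_upsample_blocks=4, num_classes=1):
    """Return intermediate shapes (batch dimension omitted) as a dict."""
    if height % downsample or width % downsample:
        raise ValueError("input size must be divisible by the downsampling factor")
    # The ResNet stages reduce spatial resolution by the total stride (assumed 16)
    grid_h, grid_w = height // downsample, width // downsample
    # Each cell of the CNN feature map becomes one Transformer token
    num_tokens = grid_h * grid_w
    shapes = {
        "input": (height, width, 3),
        "tokens": (num_tokens, hidden_size),
        # Transformer output reshaped back into a 2-D feature map
        "encoded_map": (grid_h, grid_w, hidden_size),
    }
    # Each decoder block upsamples by 2; only spatial size is tracked here
    h, w = grid_h, grid_w
    for i in range(num_upsample_blocks):
        h, w = h * 2, w * 2
        shapes[f"decoder_{i + 1}"] = (h, w)
    # Final projection to one channel per class
    shapes["output"] = (h, w, num_classes)
    return shapes


for name, shape in transunet_shapes().items():
    print(f"{name:12s} {shape}")
```

In a U-Net, skip connections from the ResNet stages are merged into the decoder blocks, but they change channel counts rather than spatial sizes, so the sketch leaves them out.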


TransUNet

Install

pip install --upgrade git+https://github.com/Basars/trans-unet.git

Usage

import numpy as np

from transunet import VisionTransformer

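# Pre-trained R50+ViT-B_16 encoder weights released by Google
weights = np.load('R50+ViT-B_16.npz', allow_pickle=True)

# num_classes=1 gives a single-channel (binary) mask output;
# encoder_trainable=False keeps the pre-trained encoder weights frozen
model = VisionTransformer(input_shape=(224, 224, 3),
                          num_classes=1,
                          w=weights,
                          encoder_trainable=False)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.fit(...)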
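
Before passing the checkpoint as `w=weights`, it can help to check what the `.npz` file actually contains. The sketch below lists array names and shapes using only NumPy. `summarize_checkpoint` is a hypothetical helper, not part of `transunet`, and the file path is assumed to match the usage example above; the key names you see depend entirely on the checkpoint.

```python
import os

import numpy as np


def summarize_checkpoint(path, limit=10):
    """Return (number of arrays, [(name, shape), ...] for the first `limit` names)."""
    # np.load on an .npz file returns a lazy NpzFile; the context manager closes it
    with np.load(path, allow_pickle=True) as ckpt:
        names = sorted(ckpt.files)
        preview = [(name, ckpt[name].shape) for name in names[:limit]]
    return len(names), preview


# Assumed location of the checkpoint; skipped quietly if the file is absent
CHECKPOINT = "R50+ViT-B_16.npz"
if os.path.exists(CHECKPOINT):
    count, preview = summarize_checkpoint(CHECKPOINT)
    print(f"{count} arrays")
    for name, shape in preview:
        print(name, shape)
```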
