In [1]:
import numpy as np
import tensorflow as tf

In [2]:
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [3]:
import numpy as np
train_rgb, train_labels = np.load('./training_data/images_rgb.npy'), np.load('./training_data/image_labels.npy')
train_vessels, train_labels_v = np.load('./training_data/vessels_seg.npy'), np.load('./training_data/vessels_seg.npy')
train_vessels = train_vessels.reshape(3150,224,224,1)
train_tda = np.load('./training_data/tda_vr.npy')
training_final = [train_rgb, train_vessels, train_tda]

In [4]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Dense, Concatenate
from tensorflow.keras.models import Model
from swin_transformer_final import SwinTransformerModel, CFGS

def create_swin_encoder(input_shape, swin_model_type, pretrained=True, layer_img='rgb'):
    # Load the Swin Transformer model configuration
    swin_config = CFGS[swin_model_type]
    swin_transformer = SwinTransformerModel(include_top=False, **swin_config)
    
    # Load pretrained weights if specified
    if pretrained:
        model_name = swin_model_type
        url = f'https://github.com/rishigami/Swin-Transformer-TF/releases/download/v0.1-tf-swin-weights/{model_name}.tgz'
        pretrained_weights = tf.keras.utils.get_file(fname=model_name, origin=url, untar=True)
        if tf.io.gfile.isdir(pretrained_weights):
            pretrained_weights = f'{pretrained_weights}/{model_name}.ckpt'
        swin_transformer.load_weights(pretrained_weights)
    
    # Define the input layer
    input_layer = Input(shape=input_shape)

    # Forward pass up to the global average pooling
    x = swin_transformer.patch_embedding(input_layer)
    x = swin_transformer.basic_layers(x)
    x = swin_transformer.normalization_layer(x)
    x = swin_transformer.global_average_pooling(x)

    # If the model does not output a 1000-dimensional vector, use a Dense layer to project it
    if swin_transformer.n_features != 1000:
        x = Dense(1000)(x)

    # Create a model
    return Model(input_layer, x, name=f'{swin_model_type}_encoder_{layer_img}')


# Define input shapes
input_shape_rgb = (224, 224, 3)
input_shape_vsi = (224, 224, 3)
input_shape_tda = (224, 224, 3)

# Create Swin Transformer encoders for each input type
swin_encoder_rgb = create_swin_encoder(input_shape_rgb, 'swin_small_224', layer_img='rgb')
swin_encoder_vsi = create_swin_encoder(input_shape_vsi, 'swin_small_224', layer_img='vsi')
swin_encoder_tda = create_swin_encoder(input_shape_tda, 'swin_small_224', layer_img='tda')

# Inputs for RGB, VSI, and TDA images
input_rgb = Input(shape=input_shape_rgb, name='input_rgb')
input_vsi = Input(shape=(224,224,1), name='input_vsi')
input_tda = Input(shape=input_shape_tda, name='input_tda')

input_after_conv_vsi = Conv2D(3, kernel_size=(3,3), padding='same', activation='relu')(input_vsi)

# Process each input through its respective Swin Transformer encoder
features_rgb = swin_encoder_rgb(input_rgb)
features_vsi = swin_encoder_vsi(input_after_conv_vsi)
features_tda = swin_encoder_tda(input_tda)

# Concatenate the output features from all Swin Transformer encoders
concatenated_features = Concatenate()([features_rgb, features_vsi, features_tda])

# Classifier head
classifier_output = Dense(2, activation='softmax', name='classifier')(concatenated_features)

2024-04-28 16:43:08.688222: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2
2024-04-28 16:43:08.688255: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2024-04-28 16:43:08.688267: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2024-04-28 16:43:08.688294: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-04-28 16:43:08.688310: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [5]:
for layer in swin_encoder_rgb.layers[:-1]:
    layer.trainable = False

# Freeze the Swin Transformer layers for VSI encoder
for layer in swin_encoder_vsi.layers[:-1]:
    layer.trainable = False

for layer in swin_encoder_tda.layers[:-1]:
    layer.trainable = False

In [6]:
# Check the trainable status of the layers in swin_encoder_rgb
print("Trainable status of layers in swin_encoder_rgb:")
for layer in swin_encoder_rgb.layers:
    print(layer.name, layer.trainable)

# Check the trainable status of the layers in swin_encoder_vsi
print("\nTrainable status of layers in swin_encoder_vsi:")
for layer in swin_encoder_vsi.layers:
    print(layer.name, layer.trainable)

print("\nTrainable status of layers in swin_encoder_tda:")
for layer in swin_encoder_tda.layers:
    print(layer.name, layer.trainable)

Trainable status of layers in swin_encoder_rgb:
input_1 False
patch_embed False
sequential_4 False
norm False
global_average_pooling1d False
dense True

Trainable status of layers in swin_encoder_vsi:
input_2 False
patch_embed False
sequential_9 False
norm False
global_average_pooling1d_1 False
dense_1 True

Trainable status of layers in swin_encoder_tda:
input_3 False
patch_embed False
sequential_14 False
norm False
global_average_pooling1d_2 False
dense_2 True


In [7]:
# Create the combined Model
combined_model = Model(inputs=[input_rgb, input_vsi, input_tda], outputs=classifier_output) # , input_tda

# Compile the combined model
combined_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Summary of the model
combined_model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_vsi (InputLayer)      [(None, 224, 224, 1)]        0         []                            
                                                                                                  
 input_rgb (InputLayer)      [(None, 224, 224, 3)]        0         []                            
                                                                                                  
 conv2d (Conv2D)             (None, 224, 224, 3)          30        ['input_vsi[0][0]']           
                                                                                                  
 input_tda (InputLayer)      [(None, 224, 224, 3)]        0         []                            
                                                                                              

In [None]:
from tqdm import tqdm
epochs = 10
n_batches = 49
batch_size = 64
for epoch in range(epochs):
    loss = 0
    acc = 0
    print(f"Epoch {epoch+1}/{epochs}")
    
    for i in tqdm(range(n_batches+1)):
        if i==n_batches:
            start_idx = i*batch_size
            end_idx = 315
        else:
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
        
        # Extract batches for each input
        batch_input1 = train_rgb[start_idx:end_idx]
        batch_input2 = train_vessels[start_idx:end_idx]
        batch_input3 = train_tda[start_idx:end_idx]
        
        # Extract labels for the current batch
        batch_labels = train_labels[start_idx:end_idx]
        
        # Training the model on the batch
        out = combined_model.train_on_batch([batch_input1, batch_input2, batch_input3], batch_labels)
        loss += out[0]
        acc += out[1]
    print(f"Loss: {round((loss/n_batches),4)}, Accuracy: {round((acc/n_batches),2)}")

Epoch 1/10


  0%|                                                    | 0/50 [00:00<?, ?it/s]2024-04-28 16:44:52.027535: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.
100%|███████████████████████████████████████████| 50/50 [21:25<00:00, 25.71s/it]


Loss: 5.6424, Accuracy: 0.51
Epoch 2/10


  4%|█▊                                          | 2/50 [00:53<21:27, 26.83s/it]

In [12]:
combined_model.evaluate(training_final, train_labels)



[1.32634437084198, 0.6920635104179382]

In [13]:
y_pred_s = combined_model.predict(training_final)



In [14]:
np.save('y_pred_all_channels_s.npy', y_pred_s)