In [5]:
from tensorflow.keras.models import Model
import tensorflow as tf
from model import lstm_vit
from helpers import f1
import warnings
import numpy as np
import matplotlib.pyplot as plt
from data_loader import data_generator
# Silence all Python-level warnings for the rest of the kernel session.
# NOTE(review): this is a blanket filter -- it also hides any warning raised by
# TensorFlow/Keras or the project helpers. Consider narrowing it by
# `category=` / `module=` so genuinely useful warnings still surface.
warnings.filterwarnings("ignore")

# Record which GPUs TensorFlow can see, so the hardware used for this run is
# captured in the notebook output.
gpu_devices = tf.config.list_physical_devices('GPU')
if not gpu_devices:
    print("No GPU found!")
else:
    print("GPU Details:")
    for device in gpu_devices:
        print(f"  {device}")
        print(f"  Device details: {tf.config.experimental.get_device_details(device)}")

GPU Details:
  PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')
  Device details: {'compute_capability': (8, 9), 'device_name': 'NVIDIA RTX 6000 Ada Generation'}


In [6]:
# Seed the Python `random`, NumPy and TensorFlow RNGs before the model is
# built, so weight initialisation is repeatable after a kernel restart. The
# data generators created in a later cell also benefit, presumably, if they
# shuffle via `random`/NumPy -- TODO confirm in data_loader.
# NOTE: some GPU ops remain non-deterministic; call
# `tf.config.experimental.enable_op_determinism()` if bit-exact reruns matter.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Build the project model and compile it with binary cross-entropy and the
# project-defined `f1` metric, then print the layer summary.
model = lstm_vit()
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[f1])
model.summary()

Model: "LSTM_ViT_ConvLSTM"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_3 (InputLayer)        [(None, 15, 256, 256, 1   0         
                             )]                                  
                                                                 
 temporal_encoding (TimeDis  (None, 15, 32, 32, 128)   490112    
 tributed)                                                       
                                                                 
 bidirectional_1 (Bidirecti  (None, 15, 32, 32, 128)   885248    
 onal)                                                           
                                                                 
 feature_projection (TimeDi  (None, 15, 32, 32, 128)   16512     
 stributed)                                                      
                                                                 
 time_distributed_4 (TimeDi  (None, 15, 64, 64, 1

In [12]:
# Dataset layout: each split keeps its images and masks in a doubly nested
# directory (`images/images`, `masks/masks`) under a common root.
DATA_ROOT = 'data_v3_processed'

# Single source of truth for the batch size both generators use.
BATCH_SIZE = 8

# NOTE(review): the model summary reports a fixed input of 15 frames,
# (None, 15, 256, 256, 1), but validation sequences are built with 13 frames.
# Confirm the validation data really only supports 13 frames, or align this
# with the training sequence length.
VAL_SEQUENCE_LENGTH = 13

train_img_path = f'{DATA_ROOT}/train/images/images'
train_masks_path = f'{DATA_ROOT}/train/masks/masks'

val_img_path = f'{DATA_ROOT}/val/images/images'
val_masks_path = f'{DATA_ROOT}/val/masks/masks'

train_gen = data_generator(train_img_path, train_masks_path, BATCH_SIZE)
val_gen = data_generator(
    val_img_path,
    val_masks_path,
    BATCH_SIZE,
    train=False,
    sequence_length=VAL_SEQUENCE_LENGTH,
)

In [13]:
# Train for one epoch of `steps_per_epoch` batches drawn from the generator.
# `batch_size` is deliberately NOT passed: the generators already yield
# batches, and the Keras `Model.fit` docs say `batch_size` must not be
# specified for generator inputs (it is ignored, which is misleading).
# NOTE(review): the loader output below reports only 220 valid training
# sequences. Assuming the generator batch size is 8, one pass over the data is
# ~28 batches, so 12000 steps would revisit the data hundreds of times inside
# this single "epoch". Likewise only 2 validation sequences were found, yet 13
# validation steps are requested. Confirm the generators cycle indefinitely
# and that these step counts are intended.
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=1,
    steps_per_epoch=12000,
    validation_steps=13,
)

Building sequence list and filtering empty sequences...
Total valid sequences: 220
Skipped empty sequences: 270
Reverse time order: False
Total valid sequences: 2
Skipped empty sequences: 0
Reverse time order: False


In [14]:
# Persist the trained model to a single file; the `.h5` extension selects the
# legacy HDF5 format.
# NOTE(review): `f1` is a project-defined metric, so reloading this file will
# presumably need `custom_objects={"f1": f1}` (plus any custom layers inside
# `lstm_vit`). The native `.keras` format is the currently recommended
# alternative; confirm no downstream code expects the `.h5` file before
# switching.
model.save(filepath="lstm_vit_model.h5")