In [1]:
import tensorflow as tf  
from tensorflow.keras import layers, models, optimizers, losses, metrics  
import numpy as np  
import os  
import albumentations as A  
from skimage.transform import resize  
from skimage.io import imread, imsave  
from tqdm import tqdm  
  
class DoubleConv(tf.keras.Model):  
    def __init__(self, in_channels, out_channels):  
        super(DoubleConv, self).__init__()  
        self.conv = tf.keras.Sequential([  
            layers.Conv2D(out_channels, 4, padding='same', use_bias=False),  
            layers.BatchNormalization(),  
            layers.ReLU(),  
            layers.Conv2D(out_channels, 4, padding='same', use_bias=False),  
            layers.BatchNormalization(),  
            layers.ReLU(),  
        ])  
  
    def call(self, x):  
        return self.conv(x)  
  
class UNET(tf.keras.Model):  
    def __init__(self, in_channels=4, out_channels=1, features=[64, 128, 256, 512]):  
        super(UNET, self).__init__()  
        self.downs = [DoubleConv(in_channels, features[0])]  
        self.downs += [  
            DoubleConv(features[i], features[i + 1]) for i in range(len(features) - 1)  
        ]  
          
        self.ups = [  
            DoubleConv(features[i], features[i - 1]) for i in range(len(features) - 1, 0, -1)  
        ] + [DoubleConv(features[0] * 2, features[0])]  
          
        self.pool = layers.MaxPool2D(pool_size=2, strides=2)  
        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)  
        self.upconvs = [  
            layers.Conv2DTranspose(features[i], kernel_size=2, strides=2) for i in range(len(features) - 1, -1, -1)  
        ]  
          
        self.final_conv = layers.Conv2D(out_channels, kernel_size=1)  
  
    def call(self, x):  
        skip_connections = []  
          
        for down in self.downs:  
            x = down(x)  
            skip_connections.append(x)  
            x = self.pool(x)  
              
        x = self.bottleneck(x)  
        skip_connections = reversed(skip_connections[:-1])  
          
        for idx in range(len(self.ups)):  
            x = self.upconvs[idx](x)  
            skip_connection = skip_connections[idx]  
            if x.shape[1] != skip_connection.shape[1]:  
                x = tf.image.resize(x, (skip_connection.shape[1], skip_connection.shape[2]))  
            x = tf.concat([skip_connection, x], axis=-1)  
            x = self.ups[idx](x)  
  
        return self.final_conv(x)  
  
# This is a simplified version and must be replaced with your actual data pipeline  
def load_image(file_path, target_size):  
    img = imread(file_path)  
    img = resize(img, target_size)  
    img = np.expand_dims(img, axis=0)  # Add batch dimension  
    return img  
  
def train_fn(train_dataset, model, optimizer, loss_fn, epochs=1):  
    for epoch in range(epochs):  
        print(f"Epoch {epoch+1}/{epochs}")  
        for step, (images, masks) in enumerate(train_dataset):  
            with tf.GradientTape() as tape:  
                predictions = model(images, training=True)  
                loss = loss_fn(masks, predictions)  
            gradients = tape.gradient(loss, model.trainable_variables)  
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))  
  
            if step % 10 == 0:  
                print(f"Step {step}, Loss: {loss.numpy()}")  
  
# This is a very simplified training loop and data loader.  
# You'll want to replace `load_image` with your actual data loading logic.  
  
# Model Configuration and Initialization  
model = UNET()  
optimizer = optimizers.Adam()  
loss_fn = losses.BinaryCrossentropy(from_logits=True)  
  
# Dummy data loader (Replace with your actual dataset)  
# Assuming you have a function that loads your dataset images and masks.  
train_dataset = tf.data.Dataset.from_tensor_slices((np.random.rand(10, 256, 256, 4), np.random.rand(10, 256, 256, 1)))  
  
# Train Model  
train_fn(train_dataset, model, optimizer, loss_fn)  
  
print("Training complete")  

2024-02-23 16:22:58.063766: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2
2024-02-23 16:22:58.063784: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 24.00 GB
2024-02-23 16:22:58.063787: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 8.00 GB
2024-02-23 16:22:58.063813: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-02-23 16:22:58.063825: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


Epoch 1/1


ValueError: Exception encountered when calling layer 'sequential' (type Sequential).

Input 0 of layer "conv2d" is incompatible with the layer: expected min_ndim=4, found ndim=3. Full shape received: (256, 256, 4)

Call arguments received by layer 'sequential' (type Sequential):
  • inputs=tf.Tensor(shape=(256, 256, 4), dtype=float32)
  • training=True
  • mask=None