In [None]:
import subprocess
import time

import tensorflow as tf
from tensorflow.keras import layers, models

# Check GPU availability: print the nvidia-smi status table (GPU model,
# driver/CUDA version, current memory use) before training starts.
# subprocess is used instead of the IPython `!` escape so the cell is plain
# Python, and a missing nvidia-smi binary is reported explicitly.
try:
    smi_result = subprocess.run(["nvidia-smi"], capture_output=True, text=True, check=False)
except FileNotFoundError:
    print("nvidia-smi not found on PATH; NVIDIA GPU status is unavailable.")
else:
    # On a non-zero exit the diagnostic usually ends up on stderr.
    print(smi_result.stdout or smi_result.stderr)

# Configure TensorFlow to allocate GPU memory on an as-needed basis.
# This MUST run before anything touches the GPU: building the model below
# creates variables, which initializes the devices, after which
# set_memory_growth raises RuntimeError. TensorFlow then keeps its default of
# reserving most of the VRAM up front, which would make the nvidia-smi
# readings taken during training meaningless.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # Memory growth has to be enabled on every visible GPU.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)

        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Raised when the GPUs were already initialized (e.g. this cell is
        # re-run in the same kernel); restart the runtime for it to apply.
        print(e)

# Create a simple neural network: 28x28 input -> 128 ReLU units -> 10-way softmax.
# An explicit Input layer replaces Flatten(input_shape=...), which Keras 3
# deprecates in favour of declaring the input shape this way.
model = models.Sequential([
    layers.Input(shape=(28, 28)),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(10, activation='softmax'),
])

# Labels are integer class ids, hence the sparse variant of the cross-entropy loss.
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Load a dummy dataset (MNIST) and scale pixel values from [0, 255] to [0, 1].
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Training loop with GPU VRAM tracking
num_epochs = 5
check_interval = 1  # sample VRAM every `check_interval` epochs
total_vram_usage = 0  # sum of all successful VRAM samples, in MiB
num_checks = 0  # number of successful VRAM samples


def _query_vram_used_mib(gpu_index=0):
    """Return the memory currently used on one GPU, in MiB, per nvidia-smi.

    Uses nvidia-smi's machine-readable query mode rather than scraping the
    human-readable status table, whose column layout is meant for people and
    is not a stable parsing target.

    Args:
        gpu_index: nvidia-smi device index to query.

    Returns:
        The used memory in MiB as an int, or None if nvidia-smi is missing,
        exits with an error, or prints something that is not an integer.
    """
    try:
        result = subprocess.run(
            [
                "nvidia-smi",
                f"--id={gpu_index}",
                "--query-gpu=memory.used",
                "--format=csv,noheader,nounits",
            ],
            capture_output=True,
            text=True,
            check=True,
        )
    except (FileNotFoundError, subprocess.CalledProcessError):
        return None
    first_line = result.stdout.strip().partition("\n")[0].strip()
    return int(first_line) if first_line.isdecimal() else None


for epoch in range(num_epochs):
    start_time = time.perf_counter()

    # Pin training to the first GPU; each call runs one more epoch on the
    # same model, so optimizer state carries over between iterations.
    with tf.device('/GPU:0'):
        model.fit(x_train, y_train, epochs=1, verbose=2)

    epoch_seconds = time.perf_counter() - start_time

    if (epoch + 1) % check_interval != 0:
        continue

    # NOTE: nvidia-smi reports device-wide usage (all processes). With memory
    # growth enabled, TensorFlow does not hand memory back to the driver, so
    # this reading behaves like a high-water mark rather than live usage.
    vram_used = _query_vram_used_mib(0)
    if vram_used is None:
        print(f"Epoch {epoch + 1}/{num_epochs} - VRAM Usage: unavailable ({epoch_seconds:.1f}s)")
        continue

    total_vram_usage += vram_used
    num_checks += 1
    print(f"Epoch {epoch + 1}/{num_epochs} - VRAM Usage: {vram_used} MiB ({epoch_seconds:.1f}s)")

# Calculate and print average VRAM usage; guard against having no readings
# at all (e.g. CPU-only machine or nvidia-smi missing), which previously
# raised ZeroDivisionError.
if num_checks:
    average_vram_usage = total_vram_usage / num_checks
    print(f"Average VRAM Usage: {average_vram_usage:.1f} MiB")
else:
    average_vram_usage = None
    print("Average VRAM Usage: unavailable (no successful nvidia-smi readings)")


Sun Jan 28 10:35:20 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   52C    P8              10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    