In [1]:
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from torchinfo import summary

from model.unet import UnetModel


In [2]:
def report_cuda_devices() -> None:
    """Print the PyTorch version and information about every visible CUDA device.

    Prints the PyTorch version and whether CUDA is available. If it is not,
    prints "No GPU available." and returns. Otherwise prints the GPU count, the
    currently selected device (once), and for each GPU its name, allocated and
    reserved memory (in MiB, labelled "MB"), and compute capability.
    """
    print("PyTorch version:", torch.__version__)
    print("CUDA available:", torch.cuda.is_available())

    if not torch.cuda.is_available():
        print("No GPU available.")
        return

    n_gpus = torch.cuda.device_count()
    print("Number of GPUs:", n_gpus)
    # current_device() is a single process-wide selection, not a per-GPU
    # property, so report it once rather than repeating it for every GPU.
    print("Current Device:", torch.cuda.current_device())

    mib = 1024**2
    for i in range(n_gpus):
        print(f"\nGPU {i}:")
        print("  Name:", torch.cuda.get_device_name(i))
        print("  Memory Allocated:", round(torch.cuda.memory_allocated(i) / mib, 2), "MB")
        # memory_reserved() is the caching-allocator total (formerly memory_cached()).
        print("  Memory Cached:", round(torch.cuda.memory_reserved(i) / mib, 2), "MB")
        print("  Capability:", torch.cuda.get_device_capability(i))


report_cuda_devices()

PyTorch version: 2.7.0+cu128
CUDA available: True
Number of GPUs: 1

GPU 0:
  Name: NVIDIA GeForce RTX 5090
  Memory Allocated: 0.0 MB
  Memory Cached: 0.0 MB
  Capability: (12, 0)
  Current Device: 0


In [3]:
# Configuration dictionary passed to UnetModel. Key semantics are defined in
# model/unet.py (not visible here); comments below only describe how this cell uses them.
params = {
    'activation': 'relu',
    'optimizer': 'Adam',
    'epochs': 100,
    'nF': 6,
    'learningRate': 5e-4,
    'batch': 32,
    'xX': 101,
    'yY': 101,
    'decayRate': 0.3,
    'patience': 20,
    # NOTE(review): 10e4 == 1e5 (100000), not 1e4 -- confirm which scale is intended.
    'scaleFL': 10e4,
    'scaleOP0': 10,
    'scaleOP1': 1,
    'scaleDF': 1,
    'scaleQF': 1,
    'scaleRE': 1,
    'nFilters3D': 128,
    'kernelConv3D': [3,3,3],
    'strideConv3D': [1,1,1],
    'nFilters2D': 128,
    'kernelConv2D': [3,3],
    'strideConv2D': [1,1],
    'data_path': 'data/'
}

# Batch size used only for the architecture summary (independent of params['batch']).
SUMMARY_BATCH = 20

model = UnetModel(params)

# Dummy input shapes for the model's two inputs, derived from params so they stay in
# sync with the grid size instead of repeating 101/101/6 as literals.
# NOTE(review): the channel counts (2 and 1), the (xX, yY) axis order, and nF being the
# depth axis of the 3D input are taken from the original hard-coded shapes -- confirm
# against UnetModel.forward.
input_2d_shape = (SUMMARY_BATCH, 2, params['xX'], params['yY'])
input_3d_shape = (SUMMARY_BATCH, 1, params['nF'], params['xX'], params['yY'])
summary(model, input_size=[input_2d_shape, input_3d_shape])


Layer (type:depth-idx)                   Output Shape              Param #
UnetModel                                [20, 1, 101, 101]         --
├─Sequential: 1-1                        [20, 64, 101, 101]        --
│    └─Conv2d: 2-1                       [20, 64, 101, 101]        1,216
│    └─ReLU: 2-2                         [20, 64, 101, 101]        --
│    └─Dropout: 2-3                      [20, 64, 101, 101]        --
│    └─Conv2d: 2-4                       [20, 64, 101, 101]        36,928
│    └─ReLU: 2-5                         [20, 64, 101, 101]        --
│    └─Dropout: 2-6                      [20, 64, 101, 101]        --
│    └─Conv2d: 2-7                       [20, 64, 101, 101]        36,928
│    └─ReLU: 2-8                         [20, 64, 101, 101]        --
│    └─Dropout: 2-9                      [20, 64, 101, 101]        --
├─Sequential: 1-2                        [20, 64, 6, 101, 101]     --
│    └─Conv3d: 2-10                      [20, 64, 6, 101, 101]     1,792
│