In [1]:
import numpy as np
import pandas as pd
from typing import Union, Literal

import keras
import tensorflow
import tensorflow as tf

import torch
from torch import nn
from torch.optim import Adam

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [3]:
!nvidia-smi

Fri Feb  6 10:14:48 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 581.15                 Driver Version: 581.15         CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4070 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   39C    P8              1W /   75W |       0MiB /   8188MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [4]:
class DemoDataset:
    def __init__(self, model_type: Literal['tf', 'torch'] = 'tf'):
        self.model_type = model_type

        if self.model_type == 'tf':
            x = np.random.rand(2000, 3)
            y = x.sum(axis=1)
            self.data = tf.data.Dataset.from_tensor_slices((x, y)).batch(200)

        elif self.model_type == 'torch':
            x = torch.randn(2000, 3)
            y = x.sum(axis=1)
            dataset = torch.utils.data.TensorDataset(x, y)
            self.data = torch.utils.data.DataLoader(dataset, batch_size=200)

        else:
            raise ValueError("model_type must be 'tf' or 'torch'")

### Keras

In [5]:
dt = DemoDataset()
demo_data = dt.data

In [6]:
demo_data.take(1)

<_TakeDataset element_spec=(TensorSpec(shape=(None, 3), dtype=tf.float64, name=None), TensorSpec(shape=(None,), dtype=tf.float64, name=None))>

In [7]:
# next(iter(dt.data))

In [8]:
@keras.saving.register_keras_serializable()
class DemoModelKeras(tf.keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.layer_1 = tf.keras.layers.Dense(200, activation='relu')
        self.layer_2 = tf.keras.layers.Dense(100, activation='relu')
        self.layer_3 = tf.keras.layers.Dense(50, activation='relu')
        self.out_layer = tf.keras.layers.Dense(1)

    def call(self, x):
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        return self.out_layer(x)    

In [9]:
tf_model = DemoModelKeras()

In [10]:
tf_model.compile(optimizer='adam', loss='mse', metrics=['mae'])

In [11]:
with tf.device('/GPU:0'):
    tf_model.fit(demo_data, epochs=10)

Epoch 1/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step - loss: 1.8187 - mae: 1.2617
Epoch 2/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - loss: 0.3783 - mae: 0.4889 
Epoch 3/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - loss: 0.1315 - mae: 0.3330
Epoch 4/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 0.0355 - mae: 0.1517 
Epoch 5/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 0.0288 - mae: 0.1340 
Epoch 6/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 0.0170 - mae: 0.1085 
Epoch 7/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - loss: 0.0112 - mae: 0.0869
Epoch 8/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss: 0.0078 - mae: 0.0725  
Epoch 9/10
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - loss

In [12]:
### save model
tf_model.save("./models/tf_model.keras", )

In [13]:
tf_loaded = tf.keras.models.load_model("./models/tf_model.keras")




In [14]:
dt = np.random.rand(1, 3)
p = tf_loaded(dt)

print(dt, '\n', p.numpy().item())

[[0.77848137 0.62657433 0.91210867]] 
 2.259352207183838


### Torch

In [15]:
dt = DemoDataset(model_type='torch')
torch_data = dt.data

In [16]:
class DemoModelTorch(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(3, 100, device=device)
        self.layer_2 = nn.Linear(100, 50, device=device)
        self.out_layer = nn.Linear(50, 1, device=device)

    def forward(self, x):
        x = torch.relu(self.layer_1(x))
        x = torch.relu(self.layer_2(x))
        return self.out_layer(x)

In [17]:
torch_model = DemoModelTorch()

In [18]:
torch_model

DemoModelTorch(
  (layer_1): Linear(in_features=3, out_features=100, bias=True)
  (layer_2): Linear(in_features=100, out_features=50, bias=True)
  (out_layer): Linear(in_features=50, out_features=1, bias=True)
)

In [19]:
optimizer = Adam(torch_model.parameters(), lr=1e-3)

In [20]:
criterion = nn.MSELoss()

In [21]:
for epoch in range(10):
    for x, y in torch_data:
        x =  x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        preds = torch_model(x)
        loss = criterion(preds.squeeze(), y)
        loss.backward()
        optimizer.step()

    print(f"Epoch:{epoch}, loss: {loss.item():.3f}")

Epoch:0, loss: 2.202
Epoch:1, loss: 1.385
Epoch:2, loss: 0.626
Epoch:3, loss: 0.134
Epoch:4, loss: 0.055
Epoch:5, loss: 0.049
Epoch:6, loss: 0.014
Epoch:7, loss: 0.010
Epoch:8, loss: 0.006
Epoch:9, loss: 0.004


In [22]:
# save model
torch.save(torch_model, "./models/torch_model.pth")

In [23]:
# load torch model
loaded_torch = torch.load("./models/torch_model.pth", weights_only=False)

In [24]:
loaded_torch.eval()

DemoModelTorch(
  (layer_1): Linear(in_features=3, out_features=100, bias=True)
  (layer_2): Linear(in_features=100, out_features=50, bias=True)
  (out_layer): Linear(in_features=50, out_features=1, bias=True)
)

In [25]:
dt = torch.rand((1, 3), device=device)

with torch.no_grad():
    p = loaded_torch(dt)

    print(dt, '\n', p)

tensor([[0.7060, 0.8543, 0.3414]], device='cuda:0') 
 tensor([[1.8762]], device='cuda:0')
