In [1]:
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import torch
import os
import math

In [73]:

class TFToTorchDataset(Dataset):
    """Map-style PyTorch ``Dataset`` view over a ``tf.data.Dataset`` of ``(feature, label)`` pairs.

    Each element is converted with ``.numpy()`` and then ``torch.tensor``, which
    copies the data, so the returned tensors do not share memory with TensorFlow.

    Performance note: random access uses ``skip(idx).take(1)``, which walks the
    pipeline from its start. ``ds[idx]`` therefore costs O(idx), and a full pass
    costs O(n^2). That is fine for small demos like this one. For large data,
    iterate the tf.data pipeline directly or materialise it once.
    NOTE(review): if ``tf_dataset`` reshuffles on every iteration, the
    index -> element mapping is not stable across calls. Confirm upstream.
    """

    def __init__(self, tf_dataset):
        # Expected to yield (feature, label) pairs whose items support .numpy().
        self.tf_dataset = tf_dataset

    def __len__(self):
        # len() on a tf.data.Dataset requires a known, finite cardinality
        # (TypeError otherwise, e.g. after filter() or repeat()).
        return len(self.tf_dataset)

    def __getitem__(self, idx):
        """Return the ``(feature, label)`` pair at position ``idx`` as torch tensors.

        Negative indices count from the end, like a Python sequence.
        Raises ``IndexError`` when ``idx`` is out of range. This is also what
        lets a plain ``for`` loop over the dataset terminate.
        """
        position = idx
        if position < 0:
            # Only negative indices need the length. Non-negative lookups keep
            # working for datasets whose cardinality is unknown.
            position += len(self)
        if position >= 0:
            # O(position): tf.data has no random access, so skip up to the element.
            for feature, label in self.tf_dataset.skip(position).take(1):
                return torch.tensor(feature.numpy()), torch.tensor(label.numpy())
        raise IndexError(f'index {idx} is out of range')
    
class TorchToTFDataset:
    """Callable adapter exposing an indexable PyTorch dataset as a generator
    for ``tf.data.Dataset.from_generator``.

    Despite the parameter name, ``dataloader`` must support ``len()`` and integer
    indexing that returns ``(feature, label)`` tensor pairs. In practice that means
    a map-style ``Dataset`` such as ``TFToTorchDataset``. A
    ``torch.utils.data.DataLoader`` is not subscriptable, so it will not work here.
    """

    def __init__(self, dataloader):
        # Attribute name kept for backward compatibility; it holds a map-style dataset.
        self.dataloader = dataloader

    @staticmethod
    def _to_numpy(tensor):
        # detach() + cpu() let .numpy() succeed for tensors that require grad
        # or live on a GPU. For plain CPU tensors they are no-ops.
        return tensor.detach().cpu().numpy()

    def __getitem__(self, idx):
        """Return item ``idx`` of the wrapped dataset as a pair of NumPy arrays."""
        feature, label = self.dataloader[idx]
        return self._to_numpy(feature), self._to_numpy(label)

    def __len__(self):
        return len(self.dataloader)

    def __call__(self):
        """Yield every item in index order; ``from_generator`` calls this to start a pass."""
        for i in range(len(self)):
            yield self[i]

In [59]:
# Small synthetic dataset for the round-trip demo: 5 samples x 16 uniform
# features in [0, 1), with integer labels drawn from {0, 1}.
# Seed TF's global RNG so Restart & Run All reproduces the same values.
tf.random.set_seed(42)
features = tf.random.uniform((5, 16), minval=0, maxval=1)
labels = tf.random.uniform((5,), minval=0, maxval=2, dtype=tf.int32)

In [60]:
# Slice both tensors along their first axis so each element is one (feature, label) pair.
tf_data = tf.data.Dataset.from_tensor_slices(tensors=(features, labels))

In [61]:
# Show the element spec of the TF dataset and how many (feature, label) pairs it holds.
print(f'info data: {tf_data}')
print(f'number of data: {len(tf_data)}')

info data: <_TensorSliceDataset element_spec=(TensorSpec(shape=(16,), dtype=tf.float32, name=None), TensorSpec(shape=(), dtype=tf.int32, name=None))>
number of data: 5


In [69]:
# Dump every (feature, label) pair, separated by a 50-character rule.
for feature, label in tf_data:
    print(feature, label, '=' * 50, sep='\n')

tf.Tensor(
[0.3716613  0.23343992 0.6551763  0.956612   0.88345003 0.81285775
 0.16413713 0.69958603 0.78153515 0.27435637 0.27543426 0.70119333
 0.83819747 0.9796101  0.05149639 0.810192  ], shape=(16,), dtype=float32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(
[0.92505026 0.12529635 0.22969902 0.7548735  0.58376205 0.83587503
 0.09775198 0.5516207  0.48446012 0.8927151  0.15286112 0.46068025
 0.15758538 0.22550094 0.6657419  0.37408233], shape=(16,), dtype=float32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(
[0.13539279 0.7585821  0.91278887 0.42594552 0.02400279 0.6480125
 0.00300384 0.7571558  0.74235225 0.6349418  0.12970054 0.42855644
 0.05962956 0.38240898 0.79575336 0.9580705 ], shape=(16,), dtype=float32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(
[0.3055581  0.99314487 0.2145257  0.632756   0.88665485 0.92397606
 0.57269967 0.15921569 0.1833911  0.28252864 0.9608743  0.7970643
 0.40016353 0.622074   0.93912005 0.11110795], shape=(16,), dtype=float32)
tf.Tensor(1, s

In [63]:
# Expose the tf.data pipeline through PyTorch's map-style Dataset interface.
torch_data = TFToTorchDataset(tf_dataset=tf_data)

In [64]:
# Show the PyTorch wrapper and the number of (feature, label) pairs it exposes.
print(f'info data: {torch_data}')
print(f'number of data: {len(torch_data)}')

info data: <__main__.TFToTorchDataset object at 0x000002124EE52520>
number of data: 5


In [67]:
# Dump every converted (feature, label) pair, separated by a 50-character rule.
for feature, label in torch_data:
    print(feature, label, '=' * 50, sep='\n')

tensor([0.3717, 0.2334, 0.6552, 0.9566, 0.8835, 0.8129, 0.1641, 0.6996, 0.7815,
        0.2744, 0.2754, 0.7012, 0.8382, 0.9796, 0.0515, 0.8102])
tensor(1, dtype=torch.int32)
tensor([0.9251, 0.1253, 0.2297, 0.7549, 0.5838, 0.8359, 0.0978, 0.5516, 0.4845,
        0.8927, 0.1529, 0.4607, 0.1576, 0.2255, 0.6657, 0.3741])
tensor(1, dtype=torch.int32)
tensor([0.1354, 0.7586, 0.9128, 0.4259, 0.0240, 0.6480, 0.0030, 0.7572, 0.7424,
        0.6349, 0.1297, 0.4286, 0.0596, 0.3824, 0.7958, 0.9581])
tensor(1, dtype=torch.int32)
tensor([0.3056, 0.9931, 0.2145, 0.6328, 0.8867, 0.9240, 0.5727, 0.1592, 0.1834,
        0.2825, 0.9609, 0.7971, 0.4002, 0.6221, 0.9391, 0.1111])
tensor(1, dtype=torch.int32)
tensor([0.6788, 0.3782, 0.2245, 0.7588, 0.6238, 0.0346, 0.5016, 0.9449, 0.7905,
        0.6445, 0.0668, 0.5278, 0.3445, 0.2238, 0.1198, 0.4317])
tensor(0, dtype=torch.int32)


In [77]:
# Wrap the PyTorch map-style Dataset (not a DataLoader) in a callable generator
# that tf.data.Dataset.from_generator can consume.
tf_data_generator = TorchToTFDataset(torch_data)

In [88]:
# Inspect one sample element (not a batch) to see the shapes and dtypes that the
# round-tripped TF dataset has to declare in its output signature.
feature, label = torch_data[0]
feature_shape, feature_dtype = feature.shape, feature.dtype
label_shape, label_dtype = label.shape, label.dtype

feature_shape, label_shape, feature_dtype, label_dtype

(torch.Size([16]), torch.Size([]), torch.float32, torch.int32)

In [89]:
# Build the output signature from what the generator actually yields (NumPy
# arrays), so shapes and dtypes track the data instead of being hard-coded.
# from_generator is lazy, so a wrong hard-coded dtype would only fail once the
# dataset is iterated.
output_signature = tuple(
    tf.TensorSpec(shape=array.shape, dtype=tf.as_dtype(array.dtype))
    for array in tf_data_generator[0]
)

# Create a TensorFlow dataset that pulls elements from the generator.
tf_data_restore = tf.data.Dataset.from_generator(
    tf_data_generator,
    output_signature=output_signature,
)

In [95]:
# Show the element spec of the restored dataset.
# Its length is not printed: a from_generator dataset has unknown cardinality,
# so len() on it raises TypeError.
print(f'info data: {tf_data_restore}')

info data: <_FlatMapDataset element_spec=(TensorSpec(shape=(16,), dtype=tf.float32, name=None), TensorSpec(shape=(), dtype=tf.int32, name=None))>


In [91]:
# Dump every restored (feature, label) pair, separated by a 50-character rule.
for feature, label in tf_data_restore:
    print(feature, label, '=' * 50, sep='\n')

tf.Tensor(
[0.3716613  0.23343992 0.6551763  0.956612   0.88345003 0.81285775
 0.16413713 0.69958603 0.78153515 0.27435637 0.27543426 0.70119333
 0.83819747 0.9796101  0.05149639 0.810192  ], shape=(16,), dtype=float32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(
[0.92505026 0.12529635 0.22969902 0.7548735  0.58376205 0.83587503
 0.09775198 0.5516207  0.48446012 0.8927151  0.15286112 0.46068025
 0.15758538 0.22550094 0.6657419  0.37408233], shape=(16,), dtype=float32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(
[0.13539279 0.7585821  0.91278887 0.42594552 0.02400279 0.6480125
 0.00300384 0.7571558  0.74235225 0.6349418  0.12970054 0.42855644
 0.05962956 0.38240898 0.79575336 0.9580705 ], shape=(16,), dtype=float32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(
[0.3055581  0.99314487 0.2145257  0.632756   0.88665485 0.92397606
 0.57269967 0.15921569 0.1833911  0.28252864 0.9608743  0.7970643
 0.40016353 0.622074   0.93912005 0.11110795], shape=(16,), dtype=float32)
tf.Tensor(1, s

In [97]:
BATCH_SIZE = 2
# Group consecutive elements into batches of BATCH_SIZE. The default
# drop_remainder=False keeps a final, smaller batch when the length is not a multiple.
# TODO: rename to `tf_data_batched` together with its use in the next cell.
tf_data_batced = tf_data_restore.batch(BATCH_SIZE)

In [98]:
# Dump every batch of features and labels, separated by a 50-character rule.
for feature, label in tf_data_batced:
    print(feature, label, '=' * 50, sep='\n')

tf.Tensor(
[[0.3716613  0.23343992 0.6551763  0.956612   0.88345003 0.81285775
  0.16413713 0.69958603 0.78153515 0.27435637 0.27543426 0.70119333
  0.83819747 0.9796101  0.05149639 0.810192  ]
 [0.92505026 0.12529635 0.22969902 0.7548735  0.58376205 0.83587503
  0.09775198 0.5516207  0.48446012 0.8927151  0.15286112 0.46068025
  0.15758538 0.22550094 0.6657419  0.37408233]], shape=(2, 16), dtype=float32)
tf.Tensor([1 1], shape=(2,), dtype=int32)
tf.Tensor(
[[0.13539279 0.7585821  0.91278887 0.42594552 0.02400279 0.6480125
  0.00300384 0.7571558  0.74235225 0.6349418  0.12970054 0.42855644
  0.05962956 0.38240898 0.79575336 0.9580705 ]
 [0.3055581  0.99314487 0.2145257  0.632756   0.88665485 0.92397606
  0.57269967 0.15921569 0.1833911  0.28252864 0.9608743  0.7970643
  0.40016353 0.622074   0.93912005 0.11110795]], shape=(2, 16), dtype=float32)
tf.Tensor([1 1], shape=(2,), dtype=int32)
tf.Tensor(
[[0.67880285 0.37823284 0.22450113 0.75876796 0.62375057 0.03460956
  0.50160885 0.944927