In [1]:
from torch.utils.data import Dataset, DataLoader
from torchmetrics import Accuracy

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import torch
import os
import math

In [24]:
class TFToTorchDataset(Dataset):
    def __init__(self, tf_dataset):
        self.tf_dataset = tf_dataset
        
    def __len__(self):
        return len(self.tf_dataset)

    def __getitem__(self, idx):
        # Convert the index to a TensorFlow slice to efficiently get a single element
        for feature, label in self.tf_dataset.skip(idx).take(1):
            feature = torch.tensor(feature.numpy())
            label = torch.tensor(label.numpy())
            return feature, label
        raise IndexError
    
class TorchToTFDataset:
    def __init__(self, dataloader):
        self.dataloader = dataloader

    def __getitem__(self, idx):
        feature, label = self.dataloader[idx]
        return feature.numpy(), label.numpy()

    def __len__(self):
        return len(self.dataloader)
            
    def __call__(self):
        for i in range(self.__len__()):
            yield self.__getitem__(i)

In [25]:
# Assuming you have a large dataset
features = tf.random.uniform((5, 16), minval=0, maxval=1)
labels = tf.random.uniform((5,), minval=0, maxval=2, dtype=tf.int32)

In [26]:
# Create TensorFlow dataset
tf_data = tf.data.Dataset.from_tensor_slices((features, labels))

In [27]:
# Print information about the cached dataset and the number of data points
print(f'info data: {tf_data}')
print(f'number of data: {len(tf_data)}')

info data: <_TensorSliceDataset element_spec=(TensorSpec(shape=(16,), dtype=tf.float32, name=None), TensorSpec(shape=(), dtype=tf.int32, name=None))>
number of data: 5


In [28]:
for feature, label in tf_data.take(1):
    print(feature)
    print(label)
    print('='.center(50, '='))

tf.Tensor(
[0.7101072  0.5463505  0.760671   0.13194537 0.89819336 0.6222327
 0.00361001 0.34648955 0.9688585  0.7658889  0.22182727 0.5738659
 0.11347568 0.3120637  0.5544977  0.29776478], shape=(16,), dtype=float32)
tf.Tensor(1, shape=(), dtype=int32)


In [29]:
# Convert to PyTorch Dataset
torch_data = TFToTorchDataset(tf_data)

In [30]:
# Print information about the cached dataset and the number of data points
print(f'info data: {torch_data}')
print(f'number of data: {len(torch_data)}')

info data: <__main__.TFToTorchDataset object at 0x00000274569664F0>
number of data: 5


In [31]:
feature, label = torch_data[0]
print(feature)
print(label)
print('='.center(50, '='))

# for feature, label in torch_data:
#     print(feature)
#     print(label)
#     print('='.center(50, '='))

tensor([0.7101, 0.5464, 0.7607, 0.1319, 0.8982, 0.6222, 0.0036, 0.3465, 0.9689,
        0.7659, 0.2218, 0.5739, 0.1135, 0.3121, 0.5545, 0.2978])
tensor(1, dtype=torch.int32)


In [32]:
# Wrap the PyTorch DataLoader in the wrapper class
tf_data_generator = TorchToTFDataset(torch_data)

In [33]:
# Define output signature dynamically using a sample batch
feature, label = torch_data[0]
feature_shape, feature_dtype = feature.shape, feature.dtype
label_shape, label_dtype = label.shape, label.dtype

feature_shape, label_shape, feature_dtype, label_dtype

(torch.Size([16]), torch.Size([]), torch.float32, torch.int32)

In [34]:
# Create a TensorFlow dataset from the generator
output_signature = (
    tf.TensorSpec(shape=(feature_shape), dtype=tf.float32),
    tf.TensorSpec(shape=(label_shape), dtype=tf.int32)
)

tf_data_restore = tf.data.Dataset.from_generator(
    tf_data_generator,
    output_signature=output_signature
)

In [44]:
tf_data_restore.__class__ = type(tf_data_restore.__class__.__name__, (tf_data_restore.__class__,), {'__len__': lambda self: len(torch_data)})

In [42]:
# Print information about the cached dataset and the number of data points
print(f'info data: {tf_data_restore}')
print(f'number of data: {len(tf_data_restore)}')

info data: <_FlatMapDataset element_spec=(TensorSpec(shape=(16,), dtype=tf.float32, name=None), TensorSpec(shape=(), dtype=tf.int32, name=None))>
number of data: 5


In [37]:
for feature, label in tf_data_restore:
    print(feature)
    print(label)
    print('='.center(50, '='))

tf.Tensor(
[0.7101072  0.5463505  0.760671   0.13194537 0.89819336 0.6222327
 0.00361001 0.34648955 0.9688585  0.7658889  0.22182727 0.5738659
 0.11347568 0.3120637  0.5544977  0.29776478], shape=(16,), dtype=float32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(
[0.37180305 0.6986232  0.62538207 0.740319   0.43544388 0.91177964
 0.797052   0.6493373  0.01715887 0.7396107  0.63975847 0.59314907
 0.8558531  0.04375792 0.2940253  0.10130537], shape=(16,), dtype=float32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(
[0.20648324 0.5060085  0.23901904 0.16011941 0.65756    0.04323947
 0.24448466 0.01703465 0.8452778  0.09872723 0.28330922 0.5499612
 0.33233202 0.5842719  0.44610524 0.12213457], shape=(16,), dtype=float32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(
[0.5216601  0.0522666  0.11583102 0.6274766  0.61375666 0.05369842
 0.8095026  0.11971641 0.40681577 0.48861003 0.6336185  0.12428892
 0.7975936  0.84827363 0.20046985 0.84192336], shape=(16,), dtype=float32)
tf.Tensor(0, sh

In [38]:
tf_data_batced = tf_data_restore.batch(2)

In [39]:
for feature, label in tf_data_batced:
    print(feature)
    print(label)
    print('='.center(50, '='))

tf.Tensor(
[[0.7101072  0.5463505  0.760671   0.13194537 0.89819336 0.6222327
  0.00361001 0.34648955 0.9688585  0.7658889  0.22182727 0.5738659
  0.11347568 0.3120637  0.5544977  0.29776478]
 [0.37180305 0.6986232  0.62538207 0.740319   0.43544388 0.91177964
  0.797052   0.6493373  0.01715887 0.7396107  0.63975847 0.59314907
  0.8558531  0.04375792 0.2940253  0.10130537]], shape=(2, 16), dtype=float32)
tf.Tensor([1 1], shape=(2,), dtype=int32)
tf.Tensor(
[[0.20648324 0.5060085  0.23901904 0.16011941 0.65756    0.04323947
  0.24448466 0.01703465 0.8452778  0.09872723 0.28330922 0.5499612
  0.33233202 0.5842719  0.44610524 0.12213457]
 [0.5216601  0.0522666  0.11583102 0.6274766  0.61375666 0.05369842
  0.8095026  0.11971641 0.40681577 0.48861003 0.6336185  0.12428892
  0.7975936  0.84827363 0.20046985 0.84192336]], shape=(2, 16), dtype=float32)
tf.Tensor([1 0], shape=(2,), dtype=int32)
tf.Tensor(
[[0.01708877 0.01407421 0.7729361  0.34017742 0.97231257 0.66388774
  0.8256575  0.2713456