In [55]:
import numpy as np
from torch.utils.data import Dataset, DataLoader
import qrcode
import Frenel_Holo
import matplotlib.pyplot as plt

In [158]:
class QRDataset(Dataset):
    def __init__(self, n_samples, version = 3, correction = qrcode.constants.ERROR_CORRECT_H, N = 64, ps = 5.5*1e-6, wavelength = 532e-9):
        self.n_samples = n_samples
        self.version = version
        self.correction = correction
        self.N = N
        self.z_0 = N * ps * ps / wavelength
    def __getitem__(self, index):
        np.random.seed(index)
        data = np.random.randint(low = 0, high = 9, size = 10)
        qr = qrcode.QRCode(version = self.version, error_correction = self.correction, box_size = 1, border = 4)
        qr.add_data(data)
        qr.make(fit = True)
        qr = np.array(qr.make_image(fill_color = 'black', back_color = 'white'))*255
        phase_mask = np.random.uniform(0, 2*np.pi, qr.shape)
        obj_field = np.zeros((self.N, self.N), dtype = complex)
        obj_field[0 : qr.shape[0], 0 : qr.shape[1]] = np.sqrt(qr) * np.exp(1j * phase_mask)
        scale  = np.random.choice([1,2,3,4,5,6])
        y, x, y_fren = Frenel_Holo.FullReconstruct(obj_field, scale * self.z_0, plotting = False)

        return y[None], x[None], y_fren[None]
    
    def __len__(self):
        return self.n_samples
    

In [160]:
qrdataset = QRDataset(10)
y, x, y_fren = qrdataset[0]


In [161]:
train_loader = DataLoader(qrdataset, batch_size=10)
for y_batch, x_batch, _ in train_loader:


torch.Size([10, 1, 64, 64]) tensor([[[[ 21.,  27.,   6.,  ...,  36.,  50.,   2.],
          [  6.,  36.,   3.,  ...,  61.,  11.,   1.],
          [ 47.,  10.,  66.,  ...,  18.,   2.,  19.],
          ...,
          [ 21.,  34.,  57.,  ...,  64.,  57.,   4.],
          [  2.,  36.,  37.,  ...,  12.,   6.,   3.],
          [  6.,  41.,  60.,  ...,  39.,  62.,  12.]]],


        [[[ 13.,  23.,  26.,  ...,   4.,  28.,  31.],
          [ 30.,  41.,  65.,  ...,  52.,   5.,  30.],
          [ 62.,  27.,  58.,  ...,   1.,  92.,  21.],
          ...,
          [ 30.,  56.,  36.,  ...,   3.,  51., 109.],
          [ 67., 130.,   5.,  ...,  79.,  91.,  30.],
          [100.,  40.,  64.,  ...,  35.,   3.,  63.]]],


        [[[  2.,  50.,   2.,  ...,  24.,  11.,  27.],
          [  1.,   3.,  53.,  ...,  25.,  26.,  11.],
          [ 96.,  43.,  22.,  ...,   9.,   1.,  29.],
          ...,
          [  7.,  34.,  28.,  ...,  86.,  33.,  14.],
          [ 35.,  12.,  21.,  ...,  12.,  23.,  40.],
 