In [1]:
import torch


In [2]:
# Scratch tensor: a 15x15 grid of uniform samples squashed by a sigmoid.
tensor = torch.rand(15, 15).sigmoid()

In [3]:
from datasets import load_dataset, Image

In [4]:
# Load the training split of the "beans" image dataset.
dataset = load_dataset(path="beans", split="train")

In [5]:
# Inspect which Python type the dataset hands back for an image field.
first_image = dataset[0]['image']
type(first_image)

PIL.JpegImagePlugin.JpegImageFile

In [6]:
from torchvision import transforms

In [7]:
# Turn the first 100 training images into a single batched tensor.
images = dataset[:100]['image']

# ToTensor converts one PIL image into a channels-first tensor.
trans = transforms.ToTensor()

# torch.stack inserts the leading batch axis itself, so the per-image
# unsqueeze(0) followed by torch.cat is not needed. As before, this requires
# every image to have the same size.
tensor_list = [trans(image) for image in images]
tensor = torch.stack(tensor_list, dim=0)

print(tensor.shape)

torch.Size([100, 3, 500, 500])


In [8]:
import numpy as np

In [9]:
# Materialize every training image as one NumPy array. The split is already
# held in `dataset`, so reuse it instead of calling load_dataset a second time.
# NOTE: this rebinds `images` from a list of PIL images to an ndarray.
images = np.array(dataset['image'])
images.shape

(1034, 500, 500, 3)

In [10]:
from models import modules

In [11]:
import bvae

In [12]:
# Encoder configured for 3-channel 128x128 inputs with 256 latent channels.
# In the recorded run below it maps a (3, 3, 128, 128) batch to (3, 256, 8, 8).
# NOTE(review): attn_resolutions is [64, 16] here but [16] for the decoder --
# confirm the asymmetry is intended.
encoder = modules.Encoder(
    ch=128,
    out_ch=3,
    num_res_blocks=2,
    attn_resolutions=[64, 16],
    ch_mult=(1, 1, 2, 2, 4),
    in_channels=3,
    resolution=128,
    z_channels=256,
    double_z=False,
)

In [13]:
images = dataset[:100]['image']

# Small demo batch: only the first three images go through the model.
batch = images[:3]

trans = transforms.ToTensor()

# Resize to 128x128 to match the encoder's `resolution=128`. Image.resize
# returns a new image and leaves its source untouched, so no defensive copy()
# is required beforehand.
tensor_list = [trans(image.resize((128, 128))) for image in batch]

# torch.stack adds the batch axis, replacing unsqueeze(0) + torch.cat.
tensor = torch.stack(tensor_list, dim=0)

In [14]:
# Sanity check on the batch layout before it is fed to the encoder.
print(tensor.size())

torch.Size([3, 3, 128, 128])


In [15]:
# Encode the batch. Call the module itself rather than .forward() so that
# nn.Module.__call__ runs any registered hooks (assumes modules.Encoder is a
# torch.nn.Module). Gradients are left enabled: later cells rely on grad_fn.
h = encoder(tensor)

print(h.shape)

torch.Size([3, 256, 8, 8])


In [16]:
# Decoder mirroring the encoder's channel settings. Constructing it prints the
# latent shape it works with (see the recorded output under this cell).
decoder = modules.Decoder(
    ch=128,
    out_ch=3,
    num_res_blocks=2,
    attn_resolutions=[16],
    ch_mult=(1, 1, 2, 2, 4),
    in_channels=3,
    resolution=128,
    z_channels=256,
)

Working with z of shape (1, 256, 8, 8) = 16384 dimensions.


In [17]:
# Decode the continuous latent back to image space. Calling the module (not
# .forward()) goes through nn.Module.__call__ so registered hooks run
# (assumes modules.Decoder is a torch.nn.Module).
x = decoder(h)

print(x.shape)

torch.Size([3, 3, 128, 128])


In [18]:
# Squash the encoder output into (0, 1) so each entry can act as a probability.
sigma_h = h.sigmoid()
print(sigma_h.shape)
# Show the first channel of the first sample.
print(sigma_h[0,0])

torch.Size([3, 256, 8, 8])
tensor([[0.4642, 0.5652, 0.4632, 0.3883, 0.5763, 0.6027, 0.5385, 0.5997],
        [0.4801, 0.4085, 0.5135, 0.4570, 0.5150, 0.4254, 0.5556, 0.5512],
        [0.4872, 0.4930, 0.6449, 0.5065, 0.5885, 0.6396, 0.6525, 0.5020],
        [0.4595, 0.5613, 0.5960, 0.4315, 0.5173, 0.5507, 0.5855, 0.5291],
        [0.4381, 0.4467, 0.5577, 0.5134, 0.4582, 0.6124, 0.4951, 0.5992],
        [0.4206, 0.4470, 0.4368, 0.5505, 0.5270, 0.5357, 0.5210, 0.5448],
        [0.4866, 0.4921, 0.5079, 0.5000, 0.4594, 0.5143, 0.5519, 0.5333],
        [0.4826, 0.4969, 0.5068, 0.4753, 0.4501, 0.4871, 0.4826, 0.5082]],
       grad_fn=<SelectBackward0>)


In [19]:
# Draw a hard 0/1 sample per entry, using sigma_h as the probability of a 1.
# The draw is unseeded, so the printed values differ between runs.
binary = sigma_h.bernoulli()
print(binary.shape)
print(binary[0,0])

torch.Size([3, 256, 8, 8])
tensor([[1., 1., 0., 0., 0., 1., 0., 1.],
        [1., 0., 1., 0., 1., 0., 0., 1.],
        [1., 1., 1., 1., 0., 1., 1., 1.],
        [1., 1., 0., 0., 0., 1., 1., 0.],
        [1., 0., 0., 1., 0., 1., 0., 1.],
        [1., 1., 0., 1., 0., 1., 1., 1.],
        [0., 1., 1., 0., 0., 0., 0., 1.],
        [0., 0., 1., 1., 0., 1., 0., 1.]], grad_fn=<SelectBackward0>)


In [20]:
# Straight-through construction: the forward value is the binary sample (the
# two sigma_h terms cancel, up to float rounding), while the only term still
# attached to the graph is sigma_h, so gradients flow through it alone.
hard_sample = binary.detach()
soft_constant = sigma_h.detach()
aux_binary = hard_sample + sigma_h - soft_constant
print(aux_binary.shape)
print(aux_binary[0,0])

torch.Size([3, 256, 8, 8])
tensor([[1.0000, 1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 1.0000],
        [1.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0000, 0.0000, 1.0000],
        [1.0000, 1.0000, 1.0000, 1.0000, 0.0000, 1.0000, 1.0000, 1.0000],
        [1.0000, 1.0000, 0.0000, 0.0000, 0.0000, 1.0000, 1.0000, 0.0000],
        [1.0000, 0.0000, 0.0000, 1.0000, 0.0000, 1.0000, 0.0000, 1.0000],
        [1.0000, 1.0000, 0.0000, 1.0000, 0.0000, 1.0000, 1.0000, 1.0000],
        [0.0000, 1.0000, 1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000],
        [0.0000, 0.0000, 1.0000, 1.0000, 0.0000, 1.0000, 0.0000, 1.0000]],
       grad_fn=<SelectBackward0>)


In [21]:
# Decode from the binarized latent. Calling the module (not .forward()) goes
# through nn.Module.__call__ so registered hooks run (assumes modules.Decoder
# is a torch.nn.Module).
x = decoder(aux_binary)

print(x.shape)

torch.Size([3, 3, 128, 128])


In [22]:
# Instantiate the project's binary quantizer with its default arguments.
# NOTE(review): presumably packages the sigmoid -> Bernoulli -> straight-through
# steps tried by hand in the cells above -- confirm against models/modules.
quantizer = modules.BinaryQuantizer()

In [24]:
# Quantize the encoder output. This rebinds `binary`, replacing the tensor
# sampled manually with torch.bernoulli in an earlier cell.
# NOTE(review): if BinaryQuantizer is a torch.nn.Module, prefer `quantizer(h)`
# so that module hooks run -- confirm before changing.
binary = quantizer.forward(h)

In [25]:
# Pick the compute device once: CUDA when a GPU is visible, otherwise the CPU.
gpu_found = torch.cuda.is_available()
device = torch.device("cuda" if gpu_found else "cpu")
# Report which branch was taken.
print("GPU is available." if gpu_found else "GPU is not available, using CPU.")

GPU is available.


In [26]:
# Build the full BVAE model, passing the device selected above. Construction
# prints the latent shape it works with (see the recorded output under this cell).
model = bvae.BVAEModel(device)

Working with z of shape (1, 256, 8, 8) = 16384 dimensions.




In [1]:
# Fresh kernel (the execution count restarts at 1): import the project module
# and run its main() entry point. The recorded output shows the device check,
# the latent shape, and ten epoch markers.
import bvae
bvae.main()

GPU is available.
Working with z of shape (1, 32, 16, 16) = 8192 dimensions.




Epoch  0
Epoch  1
Epoch  2
Epoch  3
Epoch  4
Epoch  5
Epoch  6
Epoch  7
Epoch  8
Epoch  9
