In [1]:
import torch
import torch.nn as nn
import numpy as np
import torchsummary

import tml

class AEBase(nn.Module):
    """Common scaffolding for the autoencoders defined in this notebook.

    Subclasses provide the two halves of the network by overriding
    ``_get_encoder`` and ``_get_decoder``. The constructor builds both,
    then pushes a dummy batch through the encoder to record the shape of
    the latent representation.

    Attributes:
        input_shape: per-sample input shape as passed to the constructor.
        encoder: module returned by ``_get_encoder``.
        decoder: module returned by ``_get_decoder``.
        latent_shape: encoder output shape with the batch dimension removed.
    """

    def __init__(self, input_shape=(3,84,84)):
        """Build the encoder/decoder pair and measure the latent shape.

        Args:
            input_shape: shape of a single input sample (no batch dimension).

        Raises:
            NotImplementedError: if a subclass has not overridden
                ``_get_encoder`` or ``_get_decoder``.
        """
        super().__init__()
        # input_shape is stored before the factories run, so subclasses may
        # read it inside _get_encoder / _get_decoder.
        self.input_shape = input_shape
        self.encoder = self._get_encoder()
        self.decoder = self._get_decoder()
        # Only the output shape is needed here, so skip autograd bookkeeping
        # for this throw-away forward pass.
        with torch.no_grad():
            probe = self.encoder(torch.zeros(2, *input_shape))
        # Drop the leading batch dimension of the dummy batch.
        self.latent_shape = tuple(probe.shape)[1:]

    def forward(self, x):
        """Encode ``x`` and decode the result, returning the reconstruction."""
        z = self.encoder(x)
        y = self.decoder(z)
        return y

    def _get_encoder(self):
        """Return the encoder module. Must be overridden by subclasses."""
        raise NotImplementedError()

    def _get_decoder(self):
        """Return the decoder module. Must be overridden by subclasses."""
        raise NotImplementedError()
   
    
class AEAlex(AEBase):
    """Convolutional autoencoder with a fully connected bottleneck.

    The encoder is a stack of six convolutions followed by a linear layer
    (and optional dropout); the decoder mirrors it with a linear layer
    followed by six transposed convolutions.

    The hard-coded ``2*2*128`` feature count matches a 128-channel 2x2
    feature map at the end of the convolution stack, so the layer sizes
    only line up for inputs that reduce to that size (e.g. 3x84x84).
    """

    def __init__(self, input_shape=(3,84,84), latent_shape=(2*2*128,), dropout=0.5):
        """
        Args:
            input_shape: shape of a single input sample (no batch dimension).
            latent_shape: bottleneck shape; only ``latent_shape[0]`` is used,
                as the width of the linear bottleneck.
            dropout: dropout probability applied to the latent code; values
                ``<= 0`` disable dropout.
        """
        # These must exist before AEBase.__init__ runs, because it calls
        # _get_encoder / _get_decoder, which read them.
        self.dropout = dropout
        self.latent_shape = latent_shape
        super().__init__(input_shape=input_shape)

    def _get_encoder(self):
        """Return the encoder: conv stack -> flatten -> linear -> dropout."""
        return nn.Sequential(
            nn.Conv2d(3,16,kernel_size=6,stride=1), nn.LeakyReLU(),
            nn.Conv2d(16,32,kernel_size=5,stride=2), nn.LeakyReLU(),
            nn.Conv2d(32,64,kernel_size=6,stride=1),  nn.LeakyReLU(),
            nn.Conv2d(64,128,kernel_size=5,stride=2), nn.LeakyReLU(),
            nn.Conv2d(128,128,kernel_size=5,stride=2), nn.LeakyReLU(),
            nn.Conv2d(128,128,kernel_size=5,stride=1),
            # Collapse everything after the batch dimension so the linear
            # layer below receives (batch, 2*2*128).
            nn.Flatten(),
            nn.Linear(2*2*128, self.latent_shape[0]),
            nn.Dropout(self.dropout) if self.dropout > 0 else nn.Identity(),
        )

    def _get_decoder(self):
        """Return the decoder: linear -> unflatten -> transposed-conv stack."""
        return nn.Sequential(
            nn.Linear(self.latent_shape[0], 2*2*128),
            # Restore the (128, 2, 2) feature map expected by the first
            # transposed convolution.
            nn.Unflatten(1, (128,2,2)),
            nn.ConvTranspose2d(128,128,kernel_size=5,stride=1), nn.LeakyReLU(),
            nn.ConvTranspose2d(128,128,kernel_size=5,stride=2), nn.LeakyReLU(),
            nn.ConvTranspose2d(128,64,kernel_size=5,stride=2), nn.LeakyReLU(),
            nn.ConvTranspose2d(64,32,kernel_size=6,stride=1),  nn.LeakyReLU(),
            nn.ConvTranspose2d(32,16,kernel_size=5,stride=2), nn.LeakyReLU(),
            nn.ConvTranspose2d(16,3,kernel_size=6, stride=1),
        )
    
class AEGDN(AEBase):
    """Fully convolutional autoencoder that places a ``tml.GDN`` layer after
    every convolution except the last one of each half.

    NOTE(review): ``tml.GDN`` is presumably a generalized divisive
    normalization layer taking the channel count -- confirm against ``tml``.
    """

    def _get_encoder(self):
        """Return the encoder: five Conv2d+GDN pairs, then a bare Conv2d."""
        # (in_channels, out_channels, kernel_size, stride) for each paired conv.
        paired = (
            (3, 16, 6, 1),
            (16, 32, 5, 2),
            (32, 64, 6, 1),
            (64, 128, 5, 2),
            (128, 128, 5, 2),
        )
        stack = []
        for c_in, c_out, k, s in paired:
            stack.append(nn.Conv2d(c_in, c_out, kernel_size=k, stride=s))
            stack.append(tml.GDN(c_out))
        # The final convolution is not followed by GDN.
        stack.append(nn.Conv2d(128, 128, kernel_size=5, stride=1))
        return nn.Sequential(*stack)

    def _get_decoder(self):
        """Return the decoder: five ConvTranspose2d+GDN pairs, then a bare
        ConvTranspose2d back to 3 channels."""
        # (in_channels, out_channels, kernel_size, stride) for each paired layer.
        paired = (
            (128, 128, 5, 1),
            (128, 128, 5, 2),
            (128, 64, 5, 2),
            (64, 32, 6, 1),
            (32, 16, 5, 2),
        )
        stack = []
        for c_in, c_out, k, s in paired:
            stack.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=k, stride=s))
            stack.append(tml.GDN(c_out))
        # The output layer is not followed by GDN.
        stack.append(nn.ConvTranspose2d(16, 3, kernel_size=6, stride=1))
        return nn.Sequential(*stack)
    
      
class AEConv(AEBase):
    """Fully convolutional autoencoder with LeakyReLU activations.

    There is no linear bottleneck: the latent code is the output of the
    last encoder convolution.
    """

    def _get_encoder(self):
        """Return the encoder: five Conv2d+LeakyReLU pairs, then a bare Conv2d."""
        # (in_channels, out_channels, kernel_size, stride) for each activated conv.
        activated = (
            (3, 16, 6, 1),
            (16, 32, 5, 2),
            (32, 64, 6, 1),
            (64, 128, 5, 2),
            (128, 128, 5, 2),
        )
        modules = [
            layer
            for c_in, c_out, k, s in activated
            for layer in (nn.Conv2d(c_in, c_out, kernel_size=k, stride=s), nn.LeakyReLU())
        ]
        # The last convolution produces the latent code and has no activation.
        modules.append(nn.Conv2d(128, 128, kernel_size=5, stride=1))
        return nn.Sequential(*modules)

    def _get_decoder(self):
        """Return the decoder: five ConvTranspose2d+LeakyReLU pairs, then a
        bare ConvTranspose2d back to 3 channels."""
        # (in_channels, out_channels, kernel_size, stride) for each activated layer.
        activated = (
            (128, 128, 5, 1),
            (128, 128, 5, 2),
            (128, 64, 5, 2),
            (64, 32, 6, 1),
            (32, 16, 5, 2),
        )
        modules = [
            layer
            for c_in, c_out, k, s in activated
            for layer in (nn.ConvTranspose2d(c_in, c_out, kernel_size=k, stride=s), nn.LeakyReLU())
        ]
        # The output layer has no activation.
        modules.append(nn.ConvTranspose2d(16, 3, kernel_size=6, stride=1))
        return nn.Sequential(*modules)
    
# Choose the architecture to inspect; uncomment one of the alternatives to swap.
model = AEConv()
#model = AEAlex()
#model = AEGDN()


# Summarise with the model's own input shape rather than repeating the literal,
# so the summary stays correct if the model is built with a different shape.
torchsummary.summary(model, input_size=model.input_shape, device="cpu")
print(model.latent_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 79, 79]           1,744
         LeakyReLU-2           [-1, 16, 79, 79]               0
            Conv2d-3           [-1, 32, 38, 38]          12,832
         LeakyReLU-4           [-1, 32, 38, 38]               0
            Conv2d-5           [-1, 64, 33, 33]          73,792
         LeakyReLU-6           [-1, 64, 33, 33]               0
            Conv2d-7          [-1, 128, 15, 15]         204,928
         LeakyReLU-8          [-1, 128, 15, 15]               0
            Conv2d-9            [-1, 128, 6, 6]         409,728
        LeakyReLU-10            [-1, 128, 6, 6]               0
           Conv2d-11            [-1, 128, 2, 2]         409,728
  ConvTranspose2d-12            [-1, 128, 6, 6]         409,728
        LeakyReLU-13            [-1, 128, 6, 6]               0
  ConvTranspose2d-14          [-1, 128,

In [21]:
import importlib
import torch
import torchvision
import jnu as J

# The package name contains hyphens, so it cannot be imported with a plain
# `import` statement.
wob = importlib.import_module("world-of-bugs-experiments")
path = "/home/ben/Documents/repos/world-of-bugs-experiments/dataset/"

dm = wob.data.WOBDataModule(path, batch_size=256, train_mode=["state"], shuffle_buffer_size=10, train_files="NORMAL-TRAIN/ep-0000/*.tar", in_memory=False, force=False, num_workers=1)
dm.prepare_data()
loader = dm.train_dataloader()

# Concatenate the first element of every batch into a single tensor.
x = torch.cat([batch[0] for batch in loader])

print(x.shape)

# Preview two samples (every 100th, first two) side by side. The keyword is
# `nrow` (images per row); `nrows` is not a make_grid parameter.
# The factor of 2 is presumably a brightness boost for display -- confirm.
img = torchvision.utils.make_grid(x[::100][:2], nrow=2) * 2

J.image(img, scale=10)


Found 1 training files.
Found 1 validation files.
Found 119 test files.
torch.Size([4960, 3, 84, 84])


NameError: name 'J' is not defined

In [37]:
import importlib
import torch

# Hyphenated package name: has to go through importlib.
wob = importlib.import_module("world-of-bugs-experiments")
path = "/home/ben/Documents/repos/world-of-bugs-experiments/dataset/"

# Same data-module configuration as used for the normal training episode,
# pointed at the TextureMissing test episode instead.
dm = wob.data.WOBDataModule(
    path,
    batch_size=256,
    train_mode=["state"],
    shuffle_buffer_size=10,
    train_files="TEST/TextureMissing/ep-0000/*.tar",
    in_memory=False,
    force=False,
    num_workers=1,
)
dm.prepare_data()
loader = dm.train_dataloader()

# Gather the first element of each batch and join them into one tensor.
x = torch.cat([item[0] for item in loader])




Found 1 training files.
Found 1 validation files.
Found 119 test files.
torch.Size([4960, 3, 84, 84])


HBox(children=(Canvas(height=880, width=1740),), layout=Layout(align_items='center', display='flex', flex_flow…

<jnu.image._image.Image at 0x7f216462fb80>

In [50]:
import torch
import torchvision
import jnu as J

print(x.shape)

# Two samples (every 200th from index 500, first two) laid out in one row.
# The keyword is `nrow` (images per row); `nrows` is not a make_grid parameter.
grid = torchvision.utils.make_grid(x[500::200][:2], nrow=2)
# Scale by 2 and clamp back into [0, 1] before display.
img = torch.clip(grid * 2, 0, 1)
print(img.shape)

J.image(img, scale=10)

torch.Size([4960, 3, 84, 84])
torch.Size([3, 88, 174])


HBox(children=(Canvas(height=880, width=1740),), layout=Layout(align_items='center', display='flex', flex_flow…

<jnu.image._image.Image at 0x7f2167f2a760>