In [3]:
import torch
from torch import nn
import numpy as np
torch.cuda.empty_cache()
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device='cpu'

import gc
#del y,prt
gc.collect()

torch.no_grad()

<torch.autograd.grad_mode.no_grad at 0x1be0145da60>

In [54]:
class s_view(nn.Module):
    def forward(self,x):
        if len(x.shape)==4:
            self.i_shape=x.shape
            out=x.view(x.shape[0],-1)
        elif len(x.shape)==2:
            out=x.view(self.i_shape)
        return out

class s_conv(nn.Module):
    def __init__(self,repr_size_in,repr_size_out):
        super(s_conv, self).__init__()
        self.Conv=nn.Conv2d(repr_size_in,repr_size_out,kernel_size=3,stride=2,padding=1)
        self.act=nn.ReLU()
    def forward(self,x):
        return self.act(self.Conv(x))
    
class s_deconv(nn.Module):
    def __init__(self,repr_size_in,repr_size_out):
        super(s_deconv, self).__init__()
        self.Conv=nn.ConvTranspose2d(repr_size_in,repr_size_out,kernel_size=2,stride=2)
        self.act=nn.ReLU()
    def forward(self,x):
        return self.act(self.Conv(x))

In [55]:
class b_encoder_conv(nn.Module):
    def __init__(self,image_channels=3,repr_sizes=[32,64,128,256]):
        super(b_encoder_conv, self).__init__()
        self.repr_sizes=[3]+repr_sizes
        
        self.im_layers=nn.ModuleList(
            [
                s_conv(repr_in,repr_out)
                for repr_in,repr_out in zip(
                    self.repr_sizes[:-1],
                    self.repr_sizes[1:]
                )
            ]
        )
    def forward(self,x):
        for l in self.im_layers:
            x=l(x)
        return x
    
class b_decoder_conv(nn.Module):
    def __init__(self,image_channels=3,repr_sizes=[32,64,128,256]):
        super(b_decoder_conv,self).__init__()
        self.repr_sizes=[3]+repr_sizes
        self.repr_sizes.reverse()
        
        self.im_layers=nn.ModuleList(
            [
                s_deconv(repr_in,repr_out)
                for repr_in,repr_out in zip(
                    self.repr_sizes[:-1],
                    self.repr_sizes[1:]
                )
            ]
        )
    def forward(self,x):
        for l in self.im_layers:
            x=l(x)
        return x
    
class NeuralNet(nn.Module):
    def __init__(self,input_size,output_size,layer_sizes=[300,150,50]):
        super(NeuralNet,self).__init__()
        self.layer_sizes=[input_size]+layer_sizes+[output_size]
        self.layers=nn.ModuleList(
            [
                nn.Sequential(nn.Linear(in_size,out_size),nn.ReLU())
                for in_size,out_size in zip(
                    self.layer_sizes[:-1],
                    self.layer_sizes[1:],
                )
            ]
        )
    def forward(self,x):
        for l in self.layers:
            x=l(x)
        return x

In [16]:
NNet=NeuralNet(500,20)

In [17]:
NNet

NeuralNet(
  (layers): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=500, out_features=300, bias=True)
      (1): ReLU()
    )
    (1): Sequential(
      (0): Linear(in_features=300, out_features=150, bias=True)
      (1): ReLU()
    )
    (2): Sequential(
      (0): Linear(in_features=150, out_features=50, bias=True)
      (1): ReLU()
    )
    (3): Sequential(
      (0): Linear(in_features=50, out_features=20, bias=True)
      (1): ReLU()
    )
  )
)

In [65]:
class b_encodeco(nn.Module):
    def __init__(self,
                 image_dim=4096,
                 image_channels=3,
                 repr_sizes=[32,64,128,256],
                 layer_sizes=[300,150,50],
                 latent_space_size=20
                ):
        super(b_encodeco,self).__init__()
        self.layer_sizes=layer_sizes
        self.NN_input=int(image_dim/(2**(len(repr_sizes))))
        self.latent_space_size=latent_space_size
        
        self.encoder_conv=b_encoder_conv(image_channels=image_channels,repr_sizes=repr_sizes)
        
        self.encoder_NN=NeuralNet(self.NN_input,self.latent_space_size,layer_sizes=self.layer_sizes)
        
        self.flatten=s_view()
        
        self.decoder_NN=NeuralNet(self.latent_space_size,self.NN_input,layer_sizes=self.layer_sizes[::-1])
        
        self.decoder_conv=b_decoder_conv(image_channels=image_channels,repr_sizes=repr_sizes)
        
    
    def forward(self,x):
        x=self.encoder_conv(x)
        print(x.shape)
        x=self.flatten(x)
        print(x.shape)
        #FCNN
        
        x=self.flatten(x)
        x=self.decoder_conv(x)
        
        return x

In [66]:
enc=b_encoder_conv().to(device)
ed=b_encodeco().to(device)
de=b_decoder_conv().to(device)

In [52]:
pr=np.random.randint(0,10,(2,3,4096,4096))
prt=torch.Tensor(pr)

In [64]:
y=enc(prt.to(device))

In [67]:
y=ed(prt.to(device))

torch.Size([2, 256, 256, 256])
torch.Size([2, 16777216])


In [33]:
y.shape

torch.Size([2, 256, 256, 256])

In [34]:
y=de(y)

torch.Size([2, 128, 512, 512])
torch.Size([2, 64, 1024, 1024])
torch.Size([2, 32, 2048, 2048])
torch.Size([2, 3, 4096, 4096])


In [8]:
ed

b_encodeco(
  (encoder): b_encoder(
    (im_layers): ModuleList(
      (0): s_conv(
        (Conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (act): ReLU()
      )
      (1): s_conv(
        (Conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (act): ReLU()
      )
      (2): s_conv(
        (Conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (act): ReLU()
      )
      (3): s_conv(
        (Conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (act): ReLU()
      )
    )
  )
  (flatten): s_view()
  (decoder): b_decoder(
    (im_layers): ModuleList(
      (0): s_deconv(
        (Conv): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
        (act): ReLU()
      )
      (1): s_deconv(
        (Conv): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
        (act): ReLU()
      )
      (2): s_deconv(
        (Conv): ConvTranspose2d(64, 32, kernel_s

In [68]:
256**3

16777216