In [3]:
import numpy as np
import pandas as pd 
import os
from jarviscloud import jarviscloud
from torch import nn
import torch
from tqdm import tqdm
import albumentations as A
import cv2
import timm

In [193]:
class Autoencoder(nn.Module):
    """Autoencoder with a timm backbone as encoder and a transposed-conv decoder.

    Encoder: ``timm.create_model(encoder_name, in_chans=1, ...)`` with its last
    child module (the classifier head) removed. The remaining features are
    flattened and projected to a ``latent_dim``-dimensional code.

    Decoder: the code is projected back to ``DECODER_CHANNELS`` values and
    reshaped to a (B, DECODER_CHANNELS, 1, 1) map. Five stride-2 transposed
    convolutions grow it from 1x1 to 32x32, and a final 5x5 valid convolution
    crops it to 28x28. The reconstruction is therefore always (B, 1, 28, 28),
    whatever the input's spatial size.

    Args:
        encoder_name: timm model name. The model must expose
            ``classifier.in_features`` and have the classifier as its last
            child module (e.g. ``efficientnet_b0``).
        latent_dim: size of the latent code.
        pretrained: forwarded to ``timm.create_model``. Defaults to True, as
            before (downloads pretrained weights).
    """

    # Channel count of the 1x1 map the decoder starts from.
    DECODER_CHANNELS = 1024

    def __init__(self, encoder_name, latent_dim, pretrained=True):
        super().__init__()

        # --- encoder ---
        backbone = timm.create_model(encoder_name, in_chans=1, pretrained=pretrained)
        # Width of the features entering the classifier. The latent projection
        # must match it; a hard-coded 1024 does not (timm's efficientnet_b0
        # reports 1280).
        num_embeddings = backbone.classifier.in_features
        # Drop the classifier head and keep everything before it.
        self.encoder = nn.Sequential(*list(backbone.children())[:-1])

        self.latent_layer = nn.Linear(num_embeddings, latent_dim)

        # --- decoder ---
        self.decoder_input = nn.Linear(latent_dim, self.DECODER_CHANNELS)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(self.DECODER_CHANNELS, 512, kernel_size=2, stride=2),  # 1x1 -> 2x2
            nn.ReLU(),

            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),  # 2x2 -> 4x4
            nn.ReLU(),

            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),  # 4x4 -> 8x8
            nn.ReLU(),

            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),  # 8x8 -> 16x16
            nn.ReLU(),

            nn.ConvTranspose2d(64, 1, kernel_size=2, stride=2),  # 16x16 -> 32x32
            nn.ReLU(),

            nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0),  # 32x32 -> 28x28 (valid conv)
        )

    def forward(self, x):
        """Reconstruct ``x`` by encoding it to the latent code and decoding it.

        Args:
            x: input batch of shape (BS, 1, H, W), single channel because the
                encoder is built with ``in_chans=1``.

        Returns:
            Tensor of shape (BS, 1, 28, 28).
        """
        return self.decode(self.encode(x))

    def encode(self, x):
        """Map an image batch (BS, 1, H, W) to latent codes (BS, latent_dim)."""
        features = self.encoder(x)
        # flatten (unlike view) also works for an empty batch and
        # non-contiguous features.
        features = features.flatten(start_dim=1)
        return self.latent_layer(features)

    def decode(self, z):
        """Map latent codes (BS, latent_dim) to reconstructions (BS, 1, 28, 28)."""
        x = self.decoder_input(z)
        # Treat the projected vector as a 1x1 feature map for the transposed convs.
        x = x.view(x.size(0), self.DECODER_CHANNELS, 1, 1)
        return self.decoder(x)


In [194]:
# Seed torch before building the model so the randomly initialised latent and
# decoder layers are the same on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# efficientnet_b0 encoder with a 2-dimensional latent space.
model = Autoencoder('efficientnet_b0', latent_dim=2).to(device)

In [201]:
# Load the image and label arrays (images first, then labels).
# NOTE(review): this cell ran as In[201], after In[196] had already used
# `all_labels` -- move it above that cell so Restart & Run All works.
all_images, all_labels = (
    np.load(npy_path)
    for npy_path in ('../data/images_labels/images.npy', '../data/images_labels/labels.npy')
)

In [196]:
# First 5 arrays as a float (N, 1, H, W) batch on `device`.
# NOTE(review): this takes `all_labels`, not `all_images` -- confirm which was intended.
image = torch.tensor(all_labels[:5], device=device)[:, None].float()

In [197]:
image.size()

torch.Size([5, 1, 28, 28])

In [198]:
# Shape smoke test only: run in eval mode without autograd so this check does
# not update the pretrained encoder's normalization running statistics or keep
# a graph alive. The model's previous train/eval mode is restored afterwards.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        output = model(image)
finally:
    model.train(was_training)

In [199]:
output.size()

torch.Size([5, 1, 28, 28])

In [205]:
all_images[0, ...]

array([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   3,
         18,  18,  18, 126, 136, 175,  26, 166, 255, 247, 127,   0,   0,
          0,   0],
       [  