## Imports

In [1]:
import torch
from torch import nn

## Model Class

In [2]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE,self).__init__()
        
        
        self.common_fc = nn.Sequential(
            nn.Linear(28*28, out_features=196), nn.Tanh(),
            nn.Linear(196, out_features=48), nn.Tanh(),
        )
        
        self.mean_fc = nn.Sequential(
            nn.Linear(48, out_features=16), nn.Tanh(),
            nn.Linear(16, out_features=2), nn.Tanh()
        )
        
        self.log_var_fc = nn.Sequential(
            nn.Linear(48, out_features=16), nn.Tanh(),
            nn.Linear(16, out_features=2), nn.Tanh()
        )
        
        self.decoder_fcs = nn.Sequential(
            nn.Linear(2, out_features=16), nn.Tanh(), 
            nn.Linear(16, out_features=48), nn.Tanh(),
            nn.Linear(48, out_features=196), nn.Tanh(),
            nn.Linear(196, out_features=28*28), nn.Tanh(),
        )
        
    def encode(self,x):
        # B,C,H,W
        out = self.common_fc(torch.flatten(x, start_dim=1))
        mean = self.mean_fc(out)
        log_var = self.log_var_fc(out)
        return mean, log_var

    def sample(self, mean, log_var):
        std = torch.exp(0.5*log_var)
        z = torch.randn_like(std)
        return z * std + mean
    
    def decode(self, z):
        out = self.decoder_fcs(z)
        out = out.reshape(z.shape(0), 1, 28*28)
        return out
        
    
    def forward(self, x):
        #B,C,H,W
        #Encoder
        mean, log_var = self.encode(x)
        #Sampling
        z = self.sample(mean,log_var)
        #Decoder
        out = self.decode(z)
        return mean, log_var, out


## Training Parameters

## Data Creation

In [3]:
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split

def create_data():
    # ---INITIALIZE DATASET (not DataLoader yet)---
    dataset = MNIST(
        root='./data',
        download=True  # Add this to download if needed
    )
    
    TRIM_LEN = int(58_000)  # 60,000 - 58,000 = 2,000
    TRAIN_PORTION = 0.9
    TRAIN_LEN = int((len(dataset) - TRIM_LEN) * TRAIN_PORTION)
    
    # ---SPLIT DATASET---
    train_ds, test_ds, _ = random_split(
        dataset,  # Split the dataset, not the dataloader!
        [TRAIN_LEN, len(dataset) - TRIM_LEN - TRAIN_LEN, TRIM_LEN]
    )
    print(f"train length: {len(train_ds)} test_length: {len(test_ds)}")
    
    # ---CREATE DATALOADERS from the split datasets---
    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)
    
    return train_loader, test_loader

## Training

In [4]:
def train_vae(model, train_loader):
    # Method 1: One-liner
    single_sample = next(iter(train_loader))
    # ---------Feed Forward---------
    model.forward(single_sample)
    #---------Back Prop---------
    #model loss
    torch.kl_div()


In [5]:
# Initialize the model
device = torch.device('0' if torch.cuda.is_available() else 'cpu')
my_vae = VAE().to(device)
optimizer = torch.optim.Adam(params=my_vae.parameters(),lr=1e-4)

train_loader,test_loader = create_data()
train_vae(model=my_vae,train_loader=train_loader)

train length: 1800 test_length: 200


TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>