In [18]:
import torch 
import torch.nn as nn
from torchinfo import summary
from torchvision.datasets import MNIST
from torchvision import transforms
from tqdm.notebook import tqdm

In [2]:
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=0.1307,std=0.3081)])
dataset=MNIST(root='./root',download=True,train=True,transform=transform)
testset=MNIST(root='./root',download=True,train=False,transform=transform)

In [3]:
dataset.data.shape

torch.Size([60000, 28, 28])

In [4]:
dataset.data[0].shape

torch.Size([28, 28])

In [5]:
dataset.data = dataset.data.unsqueeze(1)

In [6]:
class Encoder(nn.Module):
    def __init__(self,ldim):
        super().__init__() 
        self.layers = nn.Sequential(
            nn.Conv2d(1, 32, 2, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, 2, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2),
            nn.ReLU(),
            nn.Flatten(1,-1),
            nn.Linear(1152,ldim)
        )
    def forward(self,x):
        x = x.float()
        return self.layers(x)

In [7]:
class Decoder(nn.Module):
    def __init__(self,ldim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(ldim,1152),
            nn.Unflatten(1,(128,3,3)),
            nn.ConvTranspose2d(128,64,3,stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64,32,2,stride=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32,1,2,stride=2),
            nn.ReLU()
        )
        
    def forward(self,x):
        x = x.float()
        return self.layers(x)

In [19]:
class AutoEncoder(nn.Module):
    def __init__(self,ldim):
        super().__init__()
        self.encoder = Encoder(ldim)
        self.decoder = Decoder(ldim)
        
    def forward(self,x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x.float()

In [15]:
real = AutoEncoder(32)

In [16]:
real.loss(dataset.data)

tensor(7258.1738, grad_fn=<MseLossBackward0>)

In [20]:
# Latent Dimensions=32
model1 = AutoEncoder(32)
X_train = dataset.data
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model1.parameters(),lr=0.01)
epochs=50
for epoch in tqdm(range(epochs)):
    model1.train()
    X_pred = model1(X_train)
    loss = loss_fn(X_pred,X_train)
    print(f"Loss: {loss}")
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()    

  0%|          | 0/50 [00:00<?, ?it/s]

Loss: 7282.7470703125


RuntimeError: Found dtype Byte but expected Float