In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Device:',device)

Device: cuda


In [2]:
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = datasets.MNIST(
    root="~/torch_datasets", train=True, transform=transform, download=False
)

train_loader = DataLoader(
    train_dataset, batch_size=256, shuffle=True, num_workers=0, pin_memory=True)


# dVAE Class

In [3]:
class DiscreteVAE(nn.Module):
    def __init__(self, input_dim, latent_dim, num_categories):
        super(DiscreteVAE, self).__init__()
        
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.num_categories = num_categories
        
        # Encoder
        self.encoder_fc1 = nn.Linear(input_dim, 256)
        self.encoder_fc2 = nn.Linear(256, 128)
        self.encoder_fc3 = nn.Linear(128, latent_dim * num_categories) #128 -> 100
        
        # Discrete Latent Space
        self.latent_space = nn.Embedding(num_categories, latent_dim)
        
        # Decoder
        self.decoder_fc1 = nn.Linear(latent_dim, 128) #10 -> 128
        self.decoder_fc2 = nn.Linear(128, 256)
        self.decoder_fc3 = nn.Linear(256, input_dim)
        
    def encode(self, x):
        x = torch.relu(self.encoder_fc1(x))
        x = torch.relu(self.encoder_fc2(x))
        x = self.encoder_fc3(x)
        
        # Split the output into logits and the discrete latent variables
        logits = x[:, :self.latent_dim]
        discrete_latent = x[:, self.latent_dim:]
        
        return logits, discrete_latent
        
    def decode(self, x):
        x = torch.relu(self.decoder_fc1(x))
        x = torch.relu(self.decoder_fc2(x))
        x = self.decoder_fc3(x)
        
        return x
        
    def forward(self, x):
        logits, discrete_latent = self.encode(x)
        
        # Sample from the discrete latent space
        discrete_latent = torch.multinomial(torch.softmax(discrete_latent, dim=-1), 1)
        discrete_latent = self.latent_space(discrete_latent.squeeze())
        
        # Decode the sample
        x = self.decode(discrete_latent)
        
        return x, logits, discrete_latent
    

In [4]:
input_dim = 784
latent_dim = 10
num_categories = 10
model = DiscreteVAE(input_dim, latent_dim, num_categories).to(device)

In [5]:
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)


if torch.cuda.is_available():
    model.cuda()

    loss.cuda()

In [6]:
num_epochs = 50
for epoch in range(num_epochs):
    for x,data in enumerate(train_loader):
        inputs = data[0].reshape(-1,28*28).to(device)
        labels = data[1]
        
        # Forward pass
        outputs, logits, discrete_latent = model(inputs)
        loss = criterion(outputs, inputs) + criterion(discrete_latent, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print ('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))
    

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument target in method wrapper_nll_loss_forward)