In [4]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import os
import numpy as np
from torchvision.utils import save_image

os.chdir('./')

device = 'cuda' if torch.cuda.is_available() else 'cpu'


BATCH_SIZE=64
EPOCHS=50

cond_shape=10
intermediate_dim=400
z_dim=20

torch.manual_seed(42)

train_data = datasets.MNIST(root="../data/", train=True, download=True, transform=transforms.ToTensor(),target_transform=None)
test_data = datasets.MNIST(root="../data/", train=True, download=False, transform=transforms.ToTensor(),target_transform=None)

train_dataloader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE,  drop_last=True, shuffle=True)
test_dataloader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE,  drop_last=True, shuffle=False)


train_features_batch, train_labels_batch = next(iter(train_dataloader))
train_features_batch.shape, train_labels_batch.shape

(torch.Size([64, 1, 28, 28]), torch.Size([64]))

In [5]:
flat_img=torch.flatten(train_data[0][0])
flat_shape=list(flat_img.shape)
flat_shape[0]



784

In [6]:
class CVAE(nn.Module):
    def __init__(self):
        super(CVAE, self).__init__()

        self.enc_in = nn.Linear(in_features=flat_shape[0]+cond_shape, out_features=intermediate_dim)
        self.mu = nn.Linear(in_features=intermediate_dim, out_features=z_dim)
        self.logvar = nn.Linear(in_features=intermediate_dim, out_features=z_dim)
        self.dec_in = nn.Linear(in_features=z_dim+cond_shape, out_features=intermediate_dim)
        self.dec_out = nn.Linear(in_features=intermediate_dim, out_features=flat_shape[0]+cond_shape)

    def encoder(self, inp):
        x, y = inp
        y = F.one_hot(y, cond_shape)
        cat = torch.cat((x, y), dim=1)
        # print(cat.shape)
        enc_in_ret = F.relu(self.enc_in(cat))
        # print(enc_in_ret.shape)
        mu_out = F.relu(self.mu(enc_in_ret))
        # print(mu_out.shape)
        logvar_out = F.relu(self.logvar(enc_in_ret))
        # print(logvar_out.shape)
        return mu_out, logvar_out

    def reparametarize(self, mu, logvar):
        sd = torch.exp(0.5*logvar)
        eps = torch.randn_like(sd)
        return mu + eps*sd

    def decoder(self, inp):
        z, y = inp
        y = F.one_hot(y, cond_shape)
        cat = torch.cat((z,y), dim=1)
        dec_in_ret = F.relu(self.dec_in(cat))
        dec_out = torch.sigmoid(self.dec_out(dec_in_ret))
        return dec_out

    def forward(self, inp):
        x, y = inp
        x = x.to(device)
        y = y.to(device)
        mu, logvar = self.encoder((x.view(-1, flat_shape[0]), y))
        z = self.reparametarize(mu, logvar)
        recon = self.decoder((z, y))
        return recon, mu, logvar

def loss_fn(recon, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon, x, reduction='sum')        
    KLD = -0.5*torch.sum(1+logvar-mu.pow(2)-logvar.exp())
    return BCE+KLD

model = CVAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [9]:

def train(epoch):
    
    model.train()
    train_loss = 0


    for batch, (X, y) in enumerate(train_dataloader):
        X = X.to(device) #[64, 1, 28, 28]
        
        
        
        # 1. Forward pass
        recon_batch, mu, logvar = model((X, y))
        
        flat_data = X.view(-1, flat_shape[0]).to(device)                            
        y_onehot = F.one_hot(y, cond_shape).to(device)
        inp = torch.cat((flat_data, y_onehot), 1)
        
        # 2. Calculate loss
        loss = loss_fn(recon_batch, inp, mu, logvar)
        train_loss += loss.item()

        # 3. Zero grad
        optimizer.zero_grad()

        # 4. Backprop
        loss.backward()

        # 5. Step
        optimizer.step()

        if batch % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch * len(X),
                len(train_dataloader.dataset),
                100. * batch / len(train_dataloader),
                loss.item() / len(X)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_dataloader.dataset)))

In [10]:
def test(epoch):
    #Sets the module in evaluation mode
    model.eval()
    test_loss = 0

    with torch.inference_mode():
        for i, (X, y) in enumerate(test_dataloader):
            X = X.to(device)

            # 1. Forward pass
            recon_batch, mu, logvar = model((X, y))
            
            flat_data = X.view(-1, flat_shape[0]).to(device)
            y_onehot = F.one_hot(y, cond_shape).to(device)
            inp = torch.cat((flat_data, y_onehot), 1)

            # 2. Loss
            test_loss += loss_fn(recon_batch, inp, mu, logvar).item()

            # 3. Save images
            if epoch%5==0 and i == 0:
                n = min(X.size(0), 8)
                recon_image = recon_batch[:, 0:recon_batch.shape[1]-10]
                print(recon_image.shape)
                recon_image = recon_image.view(BATCH_SIZE, 1, 28,28)
                print('---',recon_image.shape)
                comparison = torch.cat([X[:n],
                                      recon_image.view(BATCH_SIZE, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_dataloader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))


for epoch in range(1, EPOCHS + 1):
    train(epoch)
    test(epoch)

    # Generate random digits every n epochs
    with torch.inference_mode():
        if epoch%5==0:
            sample = torch.randn(64, 20).to(device)
        
            c = np.zeros(shape=(sample.shape[0],))
            rand = np.random.randint(0, 10)
            print(f"Random number: {rand}")
            c[:] = rand
            c = torch.FloatTensor(c)
            c = c.to(torch.int64)
            c = c.to(device)
            sample = model.decoder((sample, c)).cpu()
            
            generated_image = sample[:, 0:sample.shape[1]-10]
            
            
            save_image(generated_image.view(64, 1, 28, 28),
                    'results/sample_' + str(epoch) + '.png')


====> Epoch: 1 Average loss: 170.5982
====> Test set loss: 150.8345
====> Epoch: 2 Average loss: 147.1803
====> Test set loss: 144.3100
====> Epoch: 3 Average loss: 142.7552
====> Test set loss: 141.3957
====> Epoch: 4 Average loss: 140.6234
====> Test set loss: 139.7719
====> Epoch: 5 Average loss: 139.2734
torch.Size([64, 784])
--- torch.Size([64, 1, 28, 28])
====> Test set loss: 138.5953
Random number: 0
====> Epoch: 6 Average loss: 138.3960
====> Test set loss: 137.7825
====> Epoch: 7 Average loss: 137.7340
====> Test set loss: 137.3624
====> Epoch: 8 Average loss: 137.0769
====> Test set loss: 136.9515
====> Epoch: 9 Average loss: 136.7154
====> Test set loss: 136.1959
====> Epoch: 10 Average loss: 136.3334
torch.Size([64, 784])
--- torch.Size([64, 1, 28, 28])
====> Test set loss: 135.9288
Random number: 5
====> Epoch: 11 Average loss: 135.9160
====> Test set loss: 135.6228
====> Epoch: 12 Average loss: 135.5357
====> Test set loss: 135.1991
====> Epoch: 13 Average loss: 135.3905


In [2]:
sample = torch.randn(64, 20).to(device)
        
c = np.zeros(shape=(sample.shape[0],))
rand = 9
print(f"Random number: {rand}")
c[:] = rand
c = torch.FloatTensor(c)
c = c.to(torch.int64)
c = c.to(device)
sample = model.decoder((sample, c)).cpu()
generated_image = sample[:, 0:sample.shape[1]-10]
            
            
save_image(generated_image.view(64, 1, 28, 28),
        'results/generated_' + str(rand) + '.png')


Random number: 9


NameError: name 'model' is not defined