In [1]:
pip install torch torchvision

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch.optim import Adam
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

In [3]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Preprocessing: convert each image to a tensor, then standardise its single
# channel as (x - mean) / std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.130,), ( 0.308,))  # mean=0.130, std=0.308. NOTE(review): these look like MNIST statistics rather than FashionMNIST ones - confirm before changing, the pretrained weights loaded later presumably depend on them
])

# Data loading: FashionMNIST training split (downloaded into ./data when
# missing), served as shuffled mini-batches of 128 images.
train_dataset = datasets.FashionMNIST(root = './data', train=True, download=True, transform=transform)
data_loader = DataLoader(train_dataset, batch_size = 128, shuffle=True)

In [4]:
def imshow(img, mean=0.130, std=0.308):
    """Display a (C, H, W) image tensor with matplotlib.

    The dataset transform in this notebook standardises pixels as
    ``(x - mean) / std`` with mean=0.130 and std=0.308, so the inverse
    ``x * std + mean`` is applied before plotting.  (The previous
    ``img / 2 + 0.5`` was the inverse of Normalize(0.5, 0.5), which is not
    the transform in use, and produced shifted, clipped images.)

    Args:
        img: tensor laid out as (channels, height, width), e.g. the output of
            ``torchvision.utils.make_grid``.
        mean: mean that was subtracted by the normalisation transform.
        std: standard deviation the normalisation transform divided by.
    """
    img = img * std + mean  # invert transforms.Normalize
    npimg = img.detach().cpu().numpy()
    # Keep values inside the range matplotlib accepts for float RGB data.
    npimg = np.clip(npimg, 0.0, 1.0)
    # matplotlib expects channels last: (H, W, C).
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')  # Hide the axes
    plt.show()

In [5]:
# Define the Encoder
class Encoder(nn.Module):
    """Convolutional encoder that maps an image to Gaussian latent parameters.

    Three 3x3 conv + batch-norm + ReLU stages (1 -> 64 -> 128 -> 256 channels,
    spatial size preserved by padding=1) are followed by a single 2x2 max-pool,
    a 1000-unit fully connected layer and two 128-unit linear heads.

    The flatten size ``256 * 14 * 14`` corresponds to a 28x28 input halved
    once.  ``maxpool1`` and ``maxpool2`` are constructed but ``forward`` only
    applies ``maxpool3``.  Attribute names determine the state_dict keys, so
    they are kept as they are (``sigma`` actually produces the log-variance).
    """

    def __init__(self):
        super().__init__()
        # Stage 1
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        # Stage 2
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        # Stage 3
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        # Dense trunk and the two latent heads
        self.linear_layer = nn.Linear(256 * 14 * 14, 1000)
        self.bn4 = nn.BatchNorm1d(1000)
        self.mu = nn.Linear(1000, 128)
        self.sigma = nn.Linear(1000, 128)

    def forward(self, x):
        """Return ``(mu, log_var)``, each with 128 features per sample."""
        conv_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in conv_stages:
            x = F.relu(norm(conv(x)))

        pooled = self.maxpool3(x)  # the only pooling step that is applied
        flat = pooled.view(-1, 256 * 14 * 14)
        hidden = F.relu(self.bn4(self.linear_layer(flat)))

        mu = self.mu(hidden)
        log_var = self.sigma(hidden)
        return mu, log_var

class Decoder(nn.Module):
    """Decoder that maps a 128-d latent vector to a single-channel image.

    Two linear layers expand the latent code to a 256 x 14 x 14 feature map.
    The first transposed convolution is followed by a 2x nearest-neighbour
    upsample (14 -> 28); the remaining two transposed convolutions keep the
    spatial size, and the output passes through a sigmoid.

    ``upsample2`` and ``upsample3`` are constructed but ``forward`` never
    applies them.  Attribute names determine the state_dict keys and are kept.
    """

    def __init__(self):
        super().__init__()
        # Latent code -> flattened feature map
        self.layer3 = nn.Linear(128, 1000)
        self.layer4 = nn.Linear(1000, 256 * 14 * 14)

        # Transposed-convolution stack
        self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(128)
        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(64)
        self.upsample2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.deconv3 = nn.ConvTranspose2d(64, 1, kernel_size=3, padding=1)
        self.upsample3 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, z):
        """Decode latent codes ``z`` into images with values in [0, 1]."""
        hidden = F.relu(self.layer3(z))
        hidden = F.relu(self.layer4(hidden))

        # Reshape to the feature-map layout the transposed convolutions expect.
        fmap = hidden.view(-1, 256, 14, 14)

        fmap = self.upsample1(self.deconv1(fmap))  # the only upsampling step applied
        fmap = F.relu(self.bn5(fmap))
        fmap = F.relu(self.bn6(self.deconv2(fmap)))
        return torch.sigmoid(self.deconv3(fmap))

class Discriminator(nn.Module):
    """Real/fake discriminator that also exposes its penultimate features.

    The convolutional trunk mirrors the encoder: three conv + batch-norm +
    ReLU stages, one 2x2 max-pool (``maxpool1``/``maxpool2`` are constructed
    but not applied), then 1000- and 10-unit batch-normalised dense layers.
    A final linear layer + sigmoid yields one probability per sample.

    NOTE(review): the returned features are detached, so a loss computed on
    them cannot send gradients back to the encoder, decoder or discriminator.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.linear_layer = nn.Linear(256 * 14 * 14, 1000)
        self.bn4 = nn.BatchNorm1d(1000)
        self.pen_final = nn.Linear(1000, 10)
        self.bn5 = nn.BatchNorm1d(10)
        self.final_gen = nn.Linear(10, 1)

    def forward(self, x):
        """Return ``(probability, features)``.

        ``probability`` has one sigmoid output per sample; ``features`` is a
        detached copy of the 10 penultimate activations.
        """
        for conv, norm in (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        ):
            x = F.relu(norm(conv(x)))

        flat = self.maxpool3(x).view(-1, 256 * 14 * 14)
        hidden = F.relu(self.bn4(self.linear_layer(flat)))
        hidden = F.relu(self.bn5(self.pen_final(hidden)))

        features = hidden.detach().clone()  # cut off from the autograd graph
        probability = torch.sigmoid(self.final_gen(hidden))
        return probability, features

# VAE_GAN class modification to include discriminator features in the loss function
class VAE_GAN(nn.Module):
    """Bundle an encoder, decoder and discriminator into one VAE/GAN module.

    The discriminator is expected to return ``(probability, features)``.
    """

    def __init__(self, encoder, decoder, discriminator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.discriminator = discriminator

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Run all three branches the losses need.

        Returns, in order: ``features_real, features_recon, mu, log_var,
        disc_real, disc_recon, disc_prior``.
        """
        # Encode and reconstruct the input.
        mu, log_var = self.encoder(x)
        latent = self.reparameterize(mu, log_var)
        recon_x = self.decoder(latent)

        # Discriminate the real input and its reconstruction.
        disc_real, features_real = self.discriminator(x)
        disc_recon, features_recon = self.discriminator(recon_x)

        # Decode a sample drawn from the standard-normal prior; the image is
        # detached so this branch sends no gradient into the decoder.
        prior_latent = torch.randn_like(latent)
        prior_x = self.decoder(prior_latent)
        disc_prior, _ = self.discriminator(prior_x.detach())

        return (
            features_real,
            features_recon,
            mu,
            log_var,
            disc_real,
            disc_recon,
            disc_prior,
        )

In [6]:
class Discriminator_multi(nn.Module):
    """Discriminator with a real/fake head and a 10-way classification head.

    Both heads share the convolutional trunk and the 1000-unit dense layer.
    The batch-norm layers ``bn1`` .. ``bn6`` (and ``maxpool1``/``maxpool2``)
    are registered, so they appear in the state_dict, but ``forward`` does not
    apply them.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

        self.linear_layer = nn.Linear(256 * 14 * 14, 1000)
        self.bn4 = nn.BatchNorm1d(1000)
        self.disc_pen = nn.Linear(1000, 10)   # penultimate layer, class head
        self.pen_final = nn.Linear(1000, 10)  # penultimate layer, real/fake head
        self.bn5 = nn.BatchNorm1d(10)
        self.bn6 = nn.BatchNorm1d(10)
        self.final_gen = nn.Linear(10, 1)     # real/fake output
        self.final_dis = nn.Linear(10, 10)    # class output

    def forward(self, x):
        """Return ``(probability, features, class_probs)``.

        ``probability`` is the sigmoid real/fake score, ``features`` a detached
        copy of the real/fake head's 10 penultimate activations, and
        ``class_probs`` the class distribution (already softmax-normalised
        along dim 1, i.e. probabilities rather than logits).
        """
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x))

        flat = self.maxpool3(x).view(-1, 256 * 14 * 14)
        shared = F.relu(self.linear_layer(flat))

        class_hidden = F.relu(self.disc_pen(shared))
        gan_hidden = F.relu(self.pen_final(shared))

        class_probs = F.softmax(self.final_dis(class_hidden), dim=1)
        features = gan_hidden.detach().clone()
        probability = torch.sigmoid(self.final_gen(gan_hidden))
        return probability, features, class_probs

In [7]:
def save_checkpoint(epoch, model, optimizer_dis, optimizer_enc, optimizer_dec):
    """Write the model and the three optimizer states to disk.

    The file is created in the current working directory and is named
    ``"epoch<epoch> checkpoint"``.  It holds a dict with the keys ``epoch``,
    ``model_state_dict``, ``optimizer_dis``, ``optimizer_enc`` and
    ``optimizer_dec``.
    """
    optimizers = (
        ('optimizer_dis', optimizer_dis),
        ('optimizer_enc', optimizer_enc),
        ('optimizer_dec', optimizer_dec),
    )

    state = {'epoch': epoch, 'model_state_dict': model.state_dict()}
    for key, optimizer in optimizers:
        state[key] = optimizer.state_dict()

    target = "epoch{} checkpoint".format(epoch)
    torch.save(state, target)

# Define a function to load a checkpoint
def load_checkpoint(model, optimizer_dis, optimizer_enc, optimizer_dec, filename, map_location=None):
    """Restore a model and its three optimizers from a checkpoint file.

    The file must contain a dict with the keys ``model_state_dict``,
    ``optimizer_dis``, ``optimizer_enc``, ``optimizer_dec`` and ``epoch``.

    Args:
        model: module whose ``load_state_dict`` receives ``model_state_dict``.
        optimizer_dis, optimizer_enc, optimizer_dec: optimizers restored in place.
        filename: path of the checkpoint file.
        map_location: forwarded to ``torch.load``.  Pass e.g. ``'cpu'`` to open
            a checkpoint that was saved from a GPU on a machine without one;
            the default ``None`` keeps the previous behaviour.

    Returns:
        The epoch number stored in the checkpoint.
    """
    checkpoint = torch.load(filename, map_location=map_location)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer_dis.load_state_dict(checkpoint['optimizer_dis'])
    optimizer_enc.load_state_dict(checkpoint['optimizer_enc'])
    optimizer_dec.load_state_dict(checkpoint['optimizer_dec'])

    return checkpoint['epoch']

In [8]:
def train(model, data_loader, optimizer_enc, optimizer_dec, optimizer_dis, epochs=5):
    """Train the discriminator of a VAE/GAN model.

    Only ``optimizer_dis`` is stepped; ``optimizer_enc`` and ``optimizer_dec``
    are used solely to clear the gradients that reach the encoder and decoder.

    Per batch the loss is the summed binary cross-entropy that pushes the
    discriminator towards 1 on real images and towards 0 on reconstructions
    and on images decoded from prior samples.  This equals
    ``-sum(log D(x) + log(1 - D(recon)) + log(1 - D(prior)))``, except that
    ``binary_cross_entropy`` clamps the logarithms, so a saturated sigmoid no
    longer yields ``inf``/``NaN``.

    Side effects: prints the discriminator gradient norm every 50 batches,
    shows original and reconstructed images every 1000 batches, and writes
    ``disc.pt`` every 10th epoch (starting with epoch 0).

    Args:
        model: module with ``encoder``, ``decoder`` and ``discriminator``
            attributes whose forward returns the 7-tuple ``(features_real,
            features_recon, mu, log_var, disc_real, disc_recon, disc_prior)``.
        data_loader: iterable of ``(images, labels)`` batches; labels are ignored.
        optimizer_enc, optimizer_dec, optimizer_dis: optimizers as described above.
        epochs: number of passes over ``data_loader``.

    Errors raised during training now propagate instead of being silently
    swallowed; only a manual interrupt ends the run quietly.
    """
    device = next(model.parameters()).device  # Get the device from the model's parameters
    torch.autograd.set_detect_anomaly(True)  # Debugging aid: slows backward, raises on NaN gradients
    grad_clip_value = 300

    try:
        for epoch in range(epochs):
            for batch_idx, (imgs, _) in enumerate(data_loader):
                imgs = imgs.to(device)

                # Zero the parameter gradients
                optimizer_enc.zero_grad()
                optimizer_dec.zero_grad()
                optimizer_dis.zero_grad()

                # Forward pass through VAE/GAN; only the discriminator scores are needed.
                _, _, _, _, disc_real, disc_recon, disc_prior = model(imgs)

                disc_loss = (
                    F.binary_cross_entropy(disc_real, torch.ones_like(disc_real), reduction='sum')
                    + F.binary_cross_entropy(disc_recon, torch.zeros_like(disc_recon), reduction='sum')
                    + F.binary_cross_entropy(disc_prior, torch.zeros_like(disc_prior), reduction='sum')
                )
                disc_loss.backward()

                # Element-wise clamp of every gradient to [-300, 300].  Done after
                # backward rather than via register_hook, which added a new set of
                # hooks to every parameter on each call of this function.
                torch.nn.utils.clip_grad_value_(model.parameters(), grad_clip_value)

                optimizer_dis.step()

                if batch_idx % 50 == 0:
                    # Total L2 norm over all discriminator gradients.
                    squared_norm = sum(
                        torch.norm(param.grad).item() ** 2
                        for param in model.discriminator.parameters()
                        if param.grad is not None
                    )
                    print(squared_norm ** 0.5)

                if batch_idx % 1000 == 0:
                    print("epoch:{}".format(epoch))
                    print("disc_loss:", disc_loss.item())
                    print('Original Images')
                    imshow(torchvision.utils.make_grid(imgs[:10]))

                    # Pass images through the model to get the reconstructions
                    with torch.no_grad():
                        mu, log_var = model.encoder(imgs)
                        # Gaussian noise, matching model.reparameterize (the
                        # preview previously drew uniform noise via rand_like).
                        noise = torch.randn_like(log_var)
                        z = mu + noise * torch.exp(0.5 * log_var)
                        recon_images = model.decoder(z).cpu()

                    print('Reconstructed Images')
                    imshow(torchvision.utils.make_grid(recon_images[:10]))

            if epoch % 10 == 0:
                torch.save(model.discriminator.state_dict(), 'disc.pt')

    except KeyboardInterrupt:
        # Allow stopping a long run from the notebook without a traceback.
        print("Training interrupted by user.")
        return

In [9]:
class VAE_GAN_multi(nn.Module):
    """VAE/GAN wrapper for a discriminator that also predicts classes.

    The discriminator is expected to return ``(probability, features,
    class_output)``.
    """

    def __init__(self, encoder, decoder, discriminator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.discriminator = discriminator

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        sigma = (0.5 * log_var).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Run all branches the losses need.

        Returns, in order: ``features_real, features_recon, mu, log_var,
        disc_real, disc_recon, disc_prior, class_real``.  The class output is
        only kept for the real images.
        """
        # Encode and reconstruct the input.
        mu, log_var = self.encoder(x)
        latent = self.reparameterize(mu, log_var)
        recon_x = self.decoder(latent)

        # Discriminate the real input and its reconstruction.
        disc_real, features_real, class_real = self.discriminator(x)
        disc_recon, features_recon, _ = self.discriminator(recon_x)

        # Decode a sample drawn from the standard-normal prior; the image is
        # detached so this branch sends no gradient into the decoder.
        prior_latent = torch.randn_like(latent)
        prior_x = self.decoder(prior_latent)
        disc_prior, _, _ = self.discriminator(prior_x.detach())

        return (
            features_real,
            features_recon,
            mu,
            log_var,
            disc_real,
            disc_recon,
            disc_prior,
            class_real,
        )

In [10]:
def train_multi(model, data_loader, optimizer_dis, epochs=5):
    """Train the two-headed discriminator of a VAE/GAN model.

    The loss has two parts:

    * GAN part: summed binary cross-entropy pushing the real/fake head towards
      1 on real images and towards 0 on reconstructions and on images decoded
      from prior samples.  This equals ``-sum(log D(x) + log(1 - D(recon)) +
      log(1 - D(prior)))``, except that ``binary_cross_entropy`` clamps the
      logarithms, so a saturated sigmoid no longer yields ``inf``/``NaN``.
    * Class part: mean negative log-likelihood of the true label under the
      class head.  ``model`` returns class *probabilities* (softmax output),
      so their logarithm is passed to ``nll_loss``.  The previous
      ``nn.CrossEntropyLoss`` call treated those probabilities as logits and
      applied a second softmax.

    Side effects: shows original and reconstructed images every 1000 batches
    and writes ``disc_multi.pt`` every 10th epoch (starting with epoch 0).

    Args:
        model: module with ``encoder``, ``decoder`` and ``discriminator``
            attributes whose forward returns the 8-tuple ``(features_real,
            features_recon, mu, log_var, disc_real, disc_recon, disc_prior,
            class_real)``.
        data_loader: iterable of ``(images, integer_labels)`` batches.
        optimizer_dis: optimizer over the discriminator parameters; the only
            optimizer that is stepped.
        epochs: number of passes over ``data_loader``.

    Errors raised during training now propagate instead of being silently
    swallowed; only a manual interrupt ends the run quietly.
    """
    device = next(model.parameters()).device  # Get the device from the model's parameters
    torch.autograd.set_detect_anomaly(True)  # Debugging aid: slows backward, raises on NaN gradients
    log_floor = 1e-12  # keeps log() finite if a class probability underflows to 0

    try:
        for epoch in range(epochs):
            for batch_idx, (imgs, labels) in enumerate(data_loader):
                imgs = imgs.to(device)
                labels = labels.to(device)

                # Zero the parameter gradients
                # NOTE(review): only the discriminator's gradients are cleared here;
                # reconstructions are not detached in the model's forward, so
                # encoder/decoder gradients keep accumulating (they are never applied).
                optimizer_dis.zero_grad()

                # Forward pass through VAE/GAN; only the discriminator outputs are needed.
                _, _, _, _, disc_real, disc_recon, disc_prior, class_real = model(imgs)

                gan_loss = (
                    F.binary_cross_entropy(disc_real, torch.ones_like(disc_real), reduction='sum')
                    + F.binary_cross_entropy(disc_recon, torch.zeros_like(disc_recon), reduction='sum')
                    + F.binary_cross_entropy(disc_prior, torch.zeros_like(disc_prior), reduction='sum')
                )
                class_loss = F.nll_loss(torch.log(class_real.clamp_min(log_floor)), labels)
                disc_loss = gan_loss + class_loss

                disc_loss.backward()
                optimizer_dis.step()

                if batch_idx % 1000 == 0:
                    print("epoch:{}".format(epoch))
                    print("disc_loss:", disc_loss.item())
                    print('Original Images')
                    imshow(torchvision.utils.make_grid(imgs[:10]))

                    # Pass images through the model to get the reconstructions
                    with torch.no_grad():
                        mu, log_var = model.encoder(imgs)
                        # Gaussian noise, matching model.reparameterize (the
                        # preview previously drew uniform noise via rand_like).
                        noise = torch.randn_like(log_var)
                        z = mu + noise * torch.exp(0.5 * log_var)
                        recon_images = model.decoder(z).cpu()

                    print('Reconstructed Images')
                    imshow(torchvision.utils.make_grid(recon_images[:10]))

            if epoch % 10 == 0:
                torch.save(model.discriminator.state_dict(), 'disc_multi.pt')

    except KeyboardInterrupt:
        # Allow stopping a long run from the notebook without a traceback.
        print("Training interrupted by user.")
        return

In [11]:
# Select the GPU when one is available (CPU otherwise) and build a fresh
# encoder and decoder on that device.  No optimizer is created in this cell.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = Encoder().to(device)
decoder = Decoder().to(device)
# discriminator = Discriminator().to(device)

In [12]:
# Load the pretrained encoder/decoder weights.  map_location=device lets the
# files be opened on a CPU-only machine even if they were saved from a GPU
# (device falls back to 'cpu' when CUDA is unavailable).
encoder.load_state_dict(torch.load('/kaggle/input/encoder_normalized/pytorch/v1/1/encoder.pt', map_location=device))
decoder.load_state_dict(torch.load('/kaggle/input/decoder_normalized/pytorch/v1/1/decoder.pt', map_location=device))
# discriminator.load_state_dict(torch.load('/kaggle/input/discriminator/pytorch/v1/1/disc.pt'))

<All keys matched successfully>

In [13]:
# vae_gan = VAE_GAN(encoder, decoder, discriminator).to(device)
# Adam optimizers for the encoder and the decoder (learning rate 3e-5 each).
# NOTE(review): the only active training call in this notebook, train_multi(...),
# takes just optimizer_dis, so these two appear to be used only by the
# commented-out train(...) call - confirm whether they are still needed.
optimizer_enc = Adam(encoder.parameters(), lr=3e-5)
optimizer_dec = Adam(decoder.parameters(), lr=3e-5)
# optimizer_dis = Adam(discriminator.parameters(), lr=6e-5)

In [14]:
# train(vae_gan, data_loader, optimizer_enc, optimizer_dec, optimizer_dis, epochs = 50)

In [15]:
# torch.save(vae_gan.discriminator.state_dict(), 'disc.pt')

In [16]:
# from IPython.display import FileLink 
# %cd /kaggle/working 
# FileLink('disc.pt')

In [17]:
# vae_gan.to('cpu')
# vae_gan.eval()
# out, _ = vae_gan.discriminator(train_dataset[8][0].unsqueeze(0))
# out

In [18]:
# vals = []
# vae_gan.to('cpu')
# vae_gan.eval()
# for el in tqdm(train_dataset):
#     out, _ = vae_gan.discriminator(el[0].unsqueeze(0))
#     vals.append(out.item())    

In [19]:
# vals = np.array(vals)
# np.mean(vals), np.std(vals)

In [20]:
# encoder = Encoder().to(device)
# decoder = Decoder().to(device)
# optimizer_enc = Adam(encoder.parameters(), lr=3e-5)
# optimizer_dec = Adam(decoder.parameters(), lr=3e-5)

In [21]:
# epochs = 30
# for epoch in range(epochs):
    
#     total_loss = 0
#     count = 0
#     for img, _ in data_loader:
        
#         optimizer_enc.zero_grad()
#         optimizer_dec.zero_grad()
        
#         img = img.to(device)
#         mu, log_var = encoder(img)
#         noise = torch.rand_like(log_var)
#         z = mu + noise * torch.exp(0.5 * log_var)
#         recons = decoder(z)
        
#         criterion = nn.MSELoss(reduction = 'sum')
#         loss = criterion(img, recons) + -5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        
#         loss.backward()
#         optimizer_enc.step()
#         optimizer_dec.step()
#         total_loss += loss
#         count +=1
        
#     print('Original Images')
#     imshow(torchvision.utils.make_grid(img[:10]))

#     # Pass images through the model to get the reconstructions
#     with torch.no_grad():

#         img = img.to(device)
#         mu, log_var = encoder(img)
#         noise = torch.rand_like(log_var)
#         z = mu + noise * torch.exp(0.5 * log_var)
#         recon_images = decoder(z)
#         recon_images = recon_images.cpu()

#     print('Reconstructed Images')
#     imshow(torchvision.utils.make_grid(recon_images[:10]))
#     print("The epoch is {} and the loss is {}".format(epoch, total_loss/count))

In [22]:
# torch.save(encoder.state_dict(), 'encoder.pt')
# torch.save(decoder.state_dict(), 'decoder.pt')

In [23]:
# from IPython.display import FileLink
# %cd /kaggle/working
# FileLink('encoder.pt')

In [24]:
# from IPython.display import FileLink
# %cd /kaggle/working
# FileLink('decoder.pt')

In [25]:
# Build the two-headed discriminator, wrap it together with the already
# loaded encoder/decoder, and create the Adam optimizer (lr 6e-5) that
# covers the discriminator parameters only.
discriminator_multi = Discriminator_multi().to(device)
vae_gan_multi = VAE_GAN_multi(encoder, decoder, discriminator_multi).to(device)
optimizer_dis = Adam(discriminator_multi.parameters(), lr = 6e-5)

In [None]:
# discriminator_multi.load_state_dict(torch.load('/kaggle/input/discriminator_multi_normalized/pytorch/v1/1/disc_multi (1).pt'))

In [None]:
# Put the encoder and decoder in eval mode (their batch-norm layers then use
# running statistics); the discriminator is left in its current mode.
vae_gan_multi.encoder.eval()
vae_gan_multi.decoder.eval()
# vae_gan_multi.discriminator.eval()
# Train the discriminator for 60 epochs.
train_multi(vae_gan_multi, data_loader, optimizer_dis, epochs = 60)

In [None]:
# Save the discriminator weights to disc_multi.pt in the working directory
# (the same file name train_multi writes every 10th epoch).
torch.save(vae_gan_multi.discriminator.state_dict(), 'disc_multi.pt')

In [None]:
# Collect the discriminator's real/fake score for every training image,
# one image at a time, on the CPU.
vals = []
vae_gan_multi.to('cpu')
vae_gan_multi.eval()
with torch.no_grad():  # inference only: no autograd graph per image
    for i in tqdm(range(len(train_dataset))):
        # train_dataset[i] is (image, label); add a batch dimension to the image.
        out, _, _ = vae_gan_multi.discriminator(train_dataset[i][0].unsqueeze(0))
        vals.append(out.item())

In [None]:
# Print the discriminator's scores for real images, reconstructions and
# prior samples, batch by batch, on the CPU.
vae_gan_multi.eval()
# vae_gan_multi.discriminator.train()
vae_gan_multi.to('cpu')
with torch.no_grad():  # inference only: no autograd graph per batch
    for imgs, labels in data_loader:
        features_real, features_recon, mu, log_var, disc_real, disc_recon, disc_prior, class_real = vae_gan_multi(imgs)
        print(disc_real, disc_recon, disc_prior)

In [None]:
# Mean and standard deviation of the collected discriminator scores.
# Note: this rebinds `vals` from a list to a NumPy array.
vals = np.array(vals)
np.mean(vals), np.std(vals)