In [5]:
pip install torch torchvision

Note: you may need to restart the kernel to use updated packages.


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torch.optim import Adam
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import pandas as pd 

In [7]:
def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

In [8]:
metadata_path = '/kaggle/input/cifar10-python-in-csv/batches.meta' # change this path
metadata = unpickle(metadata_path)

In [9]:
train_data_csv = pd.read_csv('/kaggle/input/cifar10-python-in-csv/train.csv')
test_data_csv = pd.read_csv('/kaggle/input/cifar10-python-in-csv/test.csv')

In [10]:
train_data_csv.head()

Unnamed: 0,pixel_0,pixel_1,pixel_2,pixel_3,pixel_4,pixel_5,pixel_6,pixel_7,pixel_8,pixel_9,...,pixel_3063,pixel_3064,pixel_3065,pixel_3066,pixel_3067,pixel_3068,pixel_3069,pixel_3070,pixel_3071,label
0,59,43,50,68,98,119,139,145,149,149,...,58,65,59,46,57,104,140,84,72,6
1,154,126,105,102,125,155,172,180,142,111,...,42,67,101,122,133,136,139,142,144,9
2,255,253,253,253,253,253,253,253,253,253,...,83,80,69,66,72,79,83,83,84,9
3,28,37,38,42,44,40,40,24,32,43,...,39,59,42,44,48,38,28,37,46,4
4,170,168,177,183,181,177,181,184,189,189,...,88,85,82,83,79,78,82,78,80,1


In [70]:
train_data = []
test_data = []

In [71]:
for i in tqdm(range(train_data_csv.shape[0])):
    
    row_data = np.array(train_data_csv.iloc[i], dtype = np.float64) 
    pixel_data = row_data[:3072]/255
    label = int(row_data[-1])
    image = torch.zeros(3, 32, 32)
    image[0] = torch.tensor(pixel_data[:1024].reshape((32, 32)))
    image[1] = torch.tensor(pixel_data[1024:2048].reshape((32, 32)))
    image[2] = torch.tensor(pixel_data[2048:].reshape((32, 32))) 
    train_data.append((image, label))
    
for i in tqdm(range(test_data_csv.shape[0])):
    
    row_data = np.array(train_data_csv.iloc[i], dtype = np.float64) 
    pixel_data = row_data[:3072]/255
    label = int(row_data[-1])
    image = torch.zeros(3, 32, 32)
    image[0] = torch.tensor(pixel_data[:1024].reshape((32, 32)))
    image[1] = torch.tensor(pixel_data[1024:2048].reshape((32, 32)))
    image[2] = torch.tensor(pixel_data[2048:].reshape((32, 32))) 
    test_data.append((image, label))

100%|██████████| 50000/50000 [00:12<00:00, 4050.51it/s]
100%|██████████| 10000/10000 [00:02<00:00, 4168.01it/s]


In [72]:
def imshow(img):
    """ Function to show an image """
    img = img / 2 + 0.5  # unnormalize
    npimg = img.detach().cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')  # Hide the axes
    plt.show()

In [73]:
# Define the Encoder
class Encoder(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, padding=0, stride=2)

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, padding=0, stride=2)

        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, padding=0, stride=2)

        self.linear_layer = nn.Linear(256 * 16 * 16, 1000)
        self.bn4 = nn.BatchNorm1d(1000)
        self.mu = nn.Linear(1000, 128)
        self.sigma = nn.Linear(1000, 128)

    def forward(self, x):

        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.maxpool3(x)
        x = x.view(-1, 256 * 16 * 16)
        x = F.relu(self.bn4(self.linear_layer(x)))
        mu = self.mu(x)
        log_var = self.sigma(x)
        return mu, log_var

class Decoder(nn.Module):
    
    def __init__(self):
        
        super().__init__()
        self.layer3 = nn.Linear(128, 1000)
        self.layer4 = nn.Linear(1000, 256 * 16 * 16)
        self.deconv1 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(128)
        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.deconv2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(64)
        self.upsample2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.deconv3 = nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=3, padding=1)
        self.upsample3 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, z):

        z = F.relu(self.layer3(z))
        z = F.relu(self.layer4(z))
        z = z.view(-1, 256, 16, 16)  # Reshape to match the deconv input size
        z = F.relu(self.bn5(self.upsample1(self.deconv1(z))))
        z = F.relu(self.bn6(self.deconv2(z)))
        z = torch.sigmoid((self.deconv3(z)))  # Added another deconv layer and upsampling
        return z

class Discriminator(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, padding=0, stride=2)

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, padding=0, stride=2)

        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, padding=0, stride=2)

        self.linear_layer = nn.Linear(256 * 16 * 16, 1000)
        self.bn4 = nn.BatchNorm1d(1000)
        self.pen_final = nn.Linear(1000, 10)
        self.bn5 = nn.BatchNorm1d(10)
#         self.final_dis = nn.Linear(10, 10)
        self.final_gen = nn.Linear(10, 1)
        
    def forward(self, x):

        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.maxpool3(x)
        x = x.view(-1, 256 * 16 * 16)
        x = F.relu(self.bn4(self.linear_layer(x)))
        x = F.relu(self.bn5(self.pen_final(x)))
        features = torch.clone(x.detach())
        return F.sigmoid(self.final_gen(x)), features

# VAE_GAN class modification to include discriminator features in the loss function
class VAE_GAN(nn.Module):
    
    def __init__(self, encoder, decoder, discriminator):
        super(VAE_GAN, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.discriminator = discriminator

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        recon_x = self.decoder(z)
        disc_real, features_real = self.discriminator(x)
        disc_recon, features_recon = self.discriminator(recon_x)
        # Sample new data from the prior for discriminator

        z_prior = torch.randn_like(z)
        prior_x = self.decoder(z_prior)
        disc_prior,_= self.discriminator(prior_x.detach())  # Detach to prevent gradients to decoder
        return features_real, features_recon, mu, log_var, disc_real, disc_recon, disc_prior

In [74]:
class CustomCNN(nn.Module):
    
    def __init__(self):
        
        super(CustomCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.batch_norm1 = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.batch_norm2 = nn.BatchNorm2d(128)

        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv7 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.batch_norm3 = nn.BatchNorm2d(256)

        self.conv8 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv9 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv10 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=1)  # default stride is 2
        self.batch_norm4 = nn.BatchNorm2d(512)

        self.conv11 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv12 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv13 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=1)  # default stride is 2
        self.batch_norm5 = nn.BatchNorm2d(512)

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(512, 4096)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(4096, 4096)
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(4096, 10)
        self.fc4 = nn.Linear(4096, 1)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x))
        x = F.leaky_relu(self.conv2(x))
        x = self.pool1(x)
        x = self.batch_norm1(x)
        
        x = F.leaky_relu(self.conv3(x))
        x = F.leaky_relu(self.conv4(x))
        x = self.pool2(x)
        x = self.batch_norm2(x)

        x = F.leaky_relu(self.conv5(x))
        x = F.leaky_relu(self.conv6(x))
        x = F.leaky_relu(self.conv7(x))
        x = self.pool3(x)
        x = self.batch_norm3(x)

        x = F.leaky_relu(self.conv8(x))
        x = F.leaky_relu(self.conv9(x))
        x = F.leaky_relu(self.conv10(x))
        x = self.pool4(x)
        x = self.batch_norm4(x)

        x = F.leaky_relu(self.conv11(x))
        x = F.leaky_relu(self.conv12(x))
        x = F.leaky_relu(self.conv13(x))
        x = self.pool5(x)
        x = self.batch_norm5(x)
        
        x = self.flatten(x)
        x = F.leaky_relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.leaky_relu(self.fc2(x))
        x = self.dropout2(x)
        features = torch.clone(x.cpu().detach()).to(device)
        x = self.fc3(x)
        return F.sigmoid(self.fc4(features)), features, x

In [75]:
class Discriminator_multi(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, padding=0, stride=2)

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, padding=0, stride=2)

        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, padding=0, stride=2)
        
        self.linear_layer = nn.Linear(256 * 16 * 16, 1000)
        self.bn4 = nn.BatchNorm1d(1000)
        self.disc_pen = nn.Linear(1000, 10)
        self.pen_final = nn.Linear(1000, 10)
        self.bn5 = nn.BatchNorm1d(10)
        self.bn6 = nn.BatchNorm1d(10)
#         self.final_dis = nn.Linear(10, 10)
        self.final_gen = nn.Linear(10, 1)
        self.final_dis = nn.Linear(10, 10)
        
    def forward(self, x):

        x = F.leaky_relu(self.bn1(self.conv1(x)))
        x = F.leaky_relu(self.bn2(self.conv2(x)))
        x = F.leaky_relu(self.bn3(self.conv3(x)))
        x = self.maxpool3(x)
        x = x.view(-1, 256 * 16 * 16)
        x = F.leaky_relu(self.bn4(self.linear_layer(x)))
        y = F.leaky_relu(self.bn6(self.disc_pen(x)))
        x = F.leaky_relu(self.bn5(self.pen_final(x)))
        y = F.softmax(self.final_dis(y), dim = 1)
        features = torch.clone(x.detach())
        return F.sigmoid(self.final_gen(x)), features, y
    
#     def forward(self, x):

#         x = F.relu((self.conv1(x)))
#         x = F.relu((self.conv2(x)))
#         x = F.relu((self.conv3(x)))
#         x = self.maxpool3(x)
#         x = x.view(-1, 256 * 14 * 14)
#         x = F.relu((self.linear_layer(x)))
#         y = F.relu((self.disc_pen(x)))
#         x = F.relu((self.pen_final(x)))
#         y = F.softmax(self.final_dis(y), dim = 1)
#         features = torch.clone(x.detach())
#         return F.sigmoid(self.final_gen(x)), features, y

In [76]:
def save_checkpoint(epoch, model, optimizer_dis, optimizer_enc, optimizer_dec):
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_dis': optimizer_dis.state_dict(),
        'optimizer_enc': optimizer_enc.state_dict(),
        'optimizer_dec': optimizer_dec.state_dict()
    }
    
    torch.save(checkpoint, "epoch{} checkpoint".format(epoch))

# Define a function to load a checkpoint
def load_checkpoint(model, optimizer_dis, optimizer_enc, optimizer_dec, filename):
    
    checkpoint = torch.load(filename)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer_dis.load_state_dict(checkpoint['optimizer_dis'])
    optimizer_enc.load_state_dict(checkpoint['optimizer_enc'])
    optimizer_dec.load_state_dict(checkpoint['optimizer_dec'])
    epoch = checkpoint['epoch']
    return epoch

In [77]:
def train(model, data_loader, optimizer_enc, optimizer_dec, optimizer_dis, epochs=5):
    
    try:
        device = next(model.parameters()).device  # Get the device from the model's parameters
        torch.autograd.set_detect_anomaly(True)  # Enable anomaly detection
        for p in model.parameters():
            p.register_hook(lambda grad: torch.clamp(grad, -300, 300))

        for epoch in range(epochs):
            for batch_idx, (imgs, _) in enumerate(data_loader):
                imgs = imgs.to(device)

                # Zero the parameter gradients
#                 optimizer_enc.zero_grad()
#                 optimizer_dec.zero_grad()
                optimizer_dis.zero_grad()

                # Forward pass through VAE/GAN
    #             print(model(imgs))
                features_real, features_recon, mu, log_var, disc_real, disc_recon, disc_prior = model(imgs)

#                 mse_loss = nn.MSELoss(reduction = 'sum')
                # Calculate losses
                # Reconstruction loss and KL divergence for VAE
                l_dis_like = mse_loss(features_real, features_recon)
                l_prior = -0.1 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    #             print(disc_real)
                l_gan = torch.sum(torch.log(disc_real) + torch.log(1 - disc_recon) + torch.log(1 - disc_prior))
                
    #             l_gan = torch.sum(-torch.log(disc_real) + torch.log(disc_prior))
    #             enc_loss = l_dis_like + l_prior 
    #             dec_loss = l_dis_like - l_gan
                disc_loss = -l_gan

    #             print(l_gan.item())
                disc_loss.backward()
    #             enc_loss.backward(retain_graph = True)
    #             dec_loss.backward()

    #             optimizer_dis.step()
    #             optimizer_enc.step()


    #             print(max_norm)
    # Calculate the total norm of all gradients
#                 total_norm = 0
#                 for name, param in model.discriminator.named_parameters():
#                     if param.grad is not None:
#                         total_norm += torch.norm(param.grad).item() ** 2
#                 total_norm = total_norm ** 0.5  # Take the square root to get the total norm

#                 # Clip gradients if the total norm exceeds a threshold
#                 max_norm = 100  # Adjust as needed
#                 if total_norm > max_norm:
#                     for name, param in model.discriminator.named_parameters():
#                         if param.grad is not None:
#                             param.grad *= max_norm / total_norm

                optimizer_dis.step()

                if(batch_idx % 50) == 0:             
                    total_norm = 0
                    for name, param in model.discriminator.named_parameters():
                        if param.grad is not None:
                            total_norm += torch.norm(param.grad).item() ** 2
                    total_norm = total_norm ** 0.5  # Take the square root to get the total norm
                    print(total_norm)

                if batch_idx % 1000 == 0:

                    print("epoch:{}".format(epoch))
                    print("disc_loss:", disc_loss.item())
                    print('Original Images')
                    imshow(torchvision.utils.make_grid(imgs[:10]))

                    # Pass images through the model to get the reconstructions
                    with torch.no_grad():

                        imgs = imgs.to(device)
                        mu, log_var = model.encoder(imgs)
                        noise = torch.rand_like(log_var)
                        z = mu + noise * torch.exp(0.5 * log_var)
                        recon_images = model.decoder(z)
                        recon_images = recon_images.cpu()

                    print('Reconstructed Images')
                    imshow(torchvision.utils.make_grid(recon_images[:10]))

            if epoch%10 == 0:
                torch.save(model.discriminator.state_dict(), 'disc.pt')
    
    except:
        return

In [78]:
class VAE_GAN_multi(nn.Module):
    
    def __init__(self, encoder, decoder, discriminator):
        super(VAE_GAN_multi, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.discriminator = discriminator

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        
        mu, log_var = self.encoder(x)
        z = self.reparameterize(mu, log_var)
        recon_x = self.decoder(z)
        disc_real, features_real, class_real = self.discriminator(x)
        disc_recon, features_recon, _ = self.discriminator(recon_x)
        # Sample new data from the prior for discriminator

        z_prior = torch.randn_like(z)
        prior_x = self.decoder(z_prior)
        disc_prior, _, _= self.discriminator(prior_x.detach())  # Detach to prevent gradients to decoder
        return features_real, features_recon, mu, log_var, disc_real, disc_recon, disc_prior, class_real

In [79]:
def train_multi(model, data_loader, optimizer_dis, optimizer_enc, optimizer_dec,  epochs=5):
    
    device = next(model.parameters()).device  # Get the device from the model's parameters
    torch.autograd.set_detect_anomaly(True)  # Enable anomaly detection
#         for p in model.parameters():
#             p.register_hook(lambda grad: torch.clamp(grad, -300, 300))

    for epoch in range(epochs):
        for batch_idx, (imgs, labels) in enumerate(data_loader):
            
            imgs = imgs.to(device)
            labels_one_hot = F.one_hot(labels, num_classes=10).float().to(device)

            # Zero the parameter gradients
            optimizer_dis.zero_grad()
            optimizer_enc.zero_grad()
            optimizer_dec.zero_grad()


            # Forward pass through VAE/GAN
#             print(model(imgs))
            features_real, features_recon, mu, log_var, disc_real, disc_recon, disc_prior, class_real = model(imgs)

            cross_entropy_loss = nn.CrossEntropyLoss()
            # Calculate losses
            # Reconstruction loss and KL divergence for VAE
# #             l_dis_like = torch.mean(torch.sum((features_real - features_recon)**2, dim = 1))
# #             l_prior = -0.01 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

#             print(disc_real)
            l_gan = torch.mean(torch.log(disc_real) + torch.log(1 - disc_recon) + torch.log(1-disc_prior))
            l_class = cross_entropy_loss(class_real, labels_one_hot) 

#             l_gan = torch.sum(-torch.log(disc_real) + torch.log(disc_prior))
#             enc_loss = l_dis_like + l_prior 
#             dec_loss = l_dis_like + l_gan
            disc_loss = 0.7 * -l_gan + 0.3 * l_class
            
            try:
#                 enc_loss.backward(retain_graph = True)
#                 dec_loss.backward(retain_graph = True)
                disc_loss.backward()
            
    
#                 optimizer_enc.step()
#                 optimizer_dec.step()
                optimizer_dis.step()
                
            except: 
                
                print('disc_real:', torch.mean(disc_real))
                print('disc_recon:', torch.mean(disc_recon))
                print('disc_prior:', torch.mean(disc_prior))
                print('true_class_prob:', torch.exp(-l_class))
                print('\n')
                
                epsilon = 1e-4
                for param in model.discriminator.parameters():
                    if param.requires_grad:
                        noise = torch.randn_like(param) * epsilon
                        param.data += noise

# Calculate the total norm of all gradients

            if batch_idx % 500 == 0: 
#                 print(l_gan.item())
                print('disc_real:', torch.mean(disc_real))
                print('disc_recon:', torch.mean( disc_recon))
                print('disc_prior:',torch.mean( disc_prior))
                print('true_class_prob:', torch.exp(-l_class))
                print('\n')
                
#             try:    
#                 disc_loss.backward()
                
#             except: 
                
#                 print('disc_real:', torch.mean(disc_real))
#                 print('disc_recon:', torch.mean(disc_recon))
#                 print('disc_prior:',torch.mean(disc_prior))
#                 print('true_class_prob:', torch.exp(-l_class))
#                 print("-l_gan, l_class:", -l_gan.item(), " ", l_class.item())
#                 print('\n')
            
#             total_norm = 0
#             for name, param in model.discriminator.named_parameters():
#                 if param.grad is not None:
#                     total_norm += torch.norm(param.grad).item() ** 2
#             total_norm = total_norm ** 0.5  # Take the square root to get the total norm

#             # Clip gradients if the total norm exceeds a threshold
#             max_norm = 10 * (0.99)**epoch # Adjust as needed
#             if total_norm > max_norm:
#                 for name, param in model.discriminator.named_parameters():
#                     if param.grad is not None:
#                         param.grad *= max_norm / total_norm
                        
#             if batch_idx % 50 == 0: 
#                 total_norm = 0
#                 for name, param in model.discriminator.named_parameters():
#                     if param.grad is not None:
#                         total_norm += torch.norm(param.grad).item() ** 2
                        
#                 total_norm = total_norm ** 0.5  # Take the square root to get the total norm
#                 print('grad_norm:', total_norm)
                
            
            imgs = imgs.to('cpu')
        
            
            if batch_idx % 1000 == 0:

                print("epoch:{}".format(epoch))
                print("-l_gan, l_class:", -l_gan.item(), " ", l_class.item())
                print('Original Images')
                imshow(torchvision.utils.make_grid(imgs[:10]))

                # Pass images through the model to get the reconstructions
                with torch.no_grad():

                    imgs = imgs.to(device)
                    mu, log_var = model.encoder(imgs)
                    noise = torch.rand_like(log_var)
                    z = mu + noise * torch.exp(0.5 * log_var)
                    recon_images = model.decoder(z)
                    recon_images = recon_images.cpu()
                    imgs = imgs.to('cpu')

                print('Reconstructed Images')
                imshow(torchvision.utils.make_grid(recon_images[:10]))

        if epoch%10 == 0:
            torch.save(model.discriminator.state_dict(), 'disc_multi.pt')

In [80]:
# Initialize the models and optimizers
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder = Encoder().to(device)
decoder = Decoder().to(device)
# discriminator = Discriminator().to(device)

In [81]:
# vae_gan = VAE_GAN(encoder, decoder, discriminator).to(device)
optimizer_enc = Adam(encoder.parameters(), lr = 3e-5)
optimizer_dec = Adam(decoder.parameters(), lr = 3e-5)
# optimizer_dis = Adam(discriminator.parameters(), lr=6e-5)

In [82]:
discriminator_multi = Discriminator_multi().to(device)
vae_gan_multi = VAE_GAN_multi(encoder, decoder, discriminator_multi).to(device)
optimizer_dis = Adam(discriminator_multi.parameters(), lr = 3e-5)
data_loader = DataLoader(train_data, batch_size = 4, shuffle = True)

In [None]:
vae_gan_multi.encoder.eval()
vae_gan_multi.decoder.eval()
vae_gan_multi.discriminator.train()
train_multi(vae_gan_multi, data_loader, optimizer_dis, None, None,  epochs = 100)

In [None]:
vae_gan_multi.eval()
# vae_gan_multi.discriminator.train()
# vae_gan_multi.to('cpu')
for imgs, labels in test_loader:
    imgs = imgs.to(device)
    features_real, features_recon, mu, log_var, disc_real, disc_recon, disc_prior, class_real = vae_gan_multi(imgs)
    print(torch.mean(disc_real), torch.mean(disc_recon), torch.mean(disc_prior))
    imgs = imgs.to('cpu')
    print('\n')

In [None]:
# for imgs, _ in test_loader: 
#     with torch.no_grad():

#         imgs = imgs.to(device)
#         mu, log_var = encoder(imgs)
#         noise = torch.rand_like(log_var)
#         z = mu + noise * torch.exp(0.5 * log_var)
#         recon_images = decoder(z)
#         recon_images = recon_images.cpu()
#         imgs = imgs.to('cpu')

#     print('Reconstructed Images')
#     imshow(torchvision.utils.make_grid(recon_images[:10]))

In [84]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# Assuming `train_dataset` is a PyTorch dataset where each item is a tuple (image, label)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stack all images into a single tensor
all_images = torch.stack([train_data[i][0].squeeze().flatten() for i in range(len(train_data))]).to(device)
num_items = all_images.size(0)

closest_indices = []

# Compute pairwise distances using broadcasting
with torch.no_grad():
    # Compute the L2 distance between each pair of ""images
    distances = torch.cdist(all_images, all_images, p=2)  # Shape: (num_items, num_items)

In [85]:
with torch.no_grad():   
    for i in tqdm(range(num_items)):
        # Exclude the distance to itself by setting it to a high value
        distances[i, i] = float('inf')
        # Get the indices of the 10 smallest distances
        _, indices = torch.topk(distances[i], 10, largest=False)
        indices = indices.to('cpu')
        closest_indices.append(indices.cpu().tolist())

all_images = all_images.to('cpu')  # Move the tensor back to CPU if needed
distances = distances.to('cpu') 

100%|██████████| 50000/50000 [00:07<00:00, 6372.24it/s]


In [88]:
import numpy as np
from scipy.linalg import lstsq

def approximate_target_matrix(target, matrices):
    # Flatten the target matrix
    target_flat = target.flatten()
    
    # Stack flattened matrices as columns
    matrix_flat_stack = np.stack([matrix.flatten() for matrix in matrices], axis=1)
    
    # Solve the least squares problem
    coefficients, _, _, _ = lstsq(matrix_flat_stack, target_flat)
    
    return coefficients

coefficients = []
for i in tqdm(range(num_items)):
    
    present = train_data[i][0].squeeze().numpy()
    indices = closest_indices[i] 
    consts = []
    
    for index in indices:
        consts.append(train_data[index][0].squeeze().numpy()) 
        
    coeffs = approximate_target_matrix(present, consts)
    coefficients.append(coeffs)

100%|██████████| 50000/50000 [00:46<00:00, 1086.14it/s]


In [95]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class CustomImageDataset(Dataset):
    
    def __init__(self, train_dataset, closest_indices, coefficients):
        """
        Args:
            images (list of tensors): List of image tensors.
            closest_indices (list of lists): List where each element is a list of indices of the closest images.
            coefficients (list of lists): List where each element is a list of coefficients for the closest images.
        """
        self.dataset = train_dataset
        self.closest_indices = closest_indices
        self.coefficients = coefficients

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        
        item = self.dataset[idx]
        closest_images = [self.dataset[i][0] for i in self.closest_indices[idx]]
        coeffs = self.coefficients[idx]

        closest_images_tensor = torch.stack(closest_images)
        coeffs_tensor = torch.tensor(coeffs, dtype=torch.float)
        
        image = item[0]
        label = item[1]
        
        return image, label, closest_images_tensor, coeffs_tensor

In [99]:
BATCH_SIZE = 24
new_dataset = CustomImageDataset(train_data, closest_indices, coefficients)
new_dataloader = DataLoader(new_dataset, shuffle = True, batch_size = BATCH_SIZE)

In [None]:
# import gc
# gc.collect()

In [None]:
# import torch
# import gc

# def get_cuda_tensors():
#     cuda_tensors = []
#     for obj in gc.get_objects():
#         try:
#             if torch.is_tensor(obj) and obj.is_cuda:
#                 cuda_tensors.append(obj)
#         except:
#             pass
#     return cuda_tensors

# # Usage example
# cuda_tensors = get_cuda_tensors()
# for tensor in cuda_tensors:
#     print(tensor.size(), tensor.device)

In [None]:
epochs = 30
criterion = nn.MSELoss(reduction = 'sum')
for epoch in range(epochs):
    
    total_loss = 0
    count = 0
    for image, _ , closest_images_tensor, coeffs_tensor in new_dataloader:
        
        optimizer_enc.zero_grad()
        optimizer_dec.zero_grad()
        
        closest_images_tensor = closest_images_tensor.to(device)
        coeffs_tensor = coeffs_tensor.to(device)
        image = image.to(device)
        
        mu, log_var = encoder(image)
        noise = torch.rand_like(log_var)
        my_z = mu + noise * torch.exp(0.5 * log_var) #[24, 10]
        recons = decoder(my_z)
        
        # closest_images_tensor of shape [24, 10, 3, 32, 32]
        batch_size = closest_images_tensor.shape[0]
        dost_sums = []
        
        for i in range(batch_size):
            sub_batch = closest_images_tensor[i, :, : , : , :] # [1, 10, 3, 32, 32]
            sub_batch = sub_batch.squeeze(0) # [10, 32, 32, 32] 
            mu, log_var = encoder(sub_batch) 
            noise = torch.rand_like(log_var)
            z = mu + noise * torch.exp(0.5 * log_var) # [128, 10]
            dost_recons_sum = z.T @ coeffs_tensor[i] # [128, 1]
            dost_sums.append(dost_recons_sum.unsqueeze(1).T) 
        
        dost_recons_sum = torch.stack(dost_sums).squeeze() #[batch_size, 128]
        
        loss1 = criterion(image, recons)/BATCH_SIZE
        loss2 = torch.mean(torch.sum((my_z - dost_recons_sum)**2, axis = 1))
        loss = loss1 + 0.1*loss2 - 5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        
        loss.backward()
        optimizer_enc.step()
        optimizer_dec.step()
        total_loss += loss
        count +=1
        
        closest_images_tensor = closest_images_tensor.to('cpu')
        coeffs_tensor = coeffs_tensor.to('cpu')
        image = image.to('cpu')
        
    print('Original Images')
    imshow(torchvision.utils.make_grid(image[:10]))

    # Pass images through the model to get the reconstructions
    with torch.no_grad():

        image = image.to(device)
        mu, log_var = encoder(image)
        noise = torch.rand_like(log_var)
        z = mu + noise * torch.exp(0.5 * log_var)
        recon_images = decoder(z)
        recon_images = recon_images.cpu()
        image = image.to(device)

    print('Reconstructed Images')
    imshow(torchvision.utils.make_grid(recon_images[:10]))
    print("The epoch is {} and the loss is {}".format(epoch, total_loss/count))

In [None]:
# torch.save(encoder.state_dict(), 'encoder.pt')
# torch.save(decoder.state_dict(), 'decoder.pt')