In [1]:
import os
import shutil
from tqdm import tqdm
import numpy as np
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from pytorch_lightning import seed_everything
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, models, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from utils import *
from ccvae import CCVAE

from torch.utils.tensorboard import SummaryWriter

  warn(
2025-06-01 19:51:27.787631: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Hyperparameters and run metadata for this experiment; the cells below read
# their settings from this dictionary.
configs = dict(
    model_name="CCVAE",              # prefix of the run name built in the training cell
    exp="1",                         # experiment tag appended to model_name
    latent_dim=64,                   # passed to CCVAE(latent_dim=...)
    batch_size=64,                   # batch size of both data loaders
    num_epochs=50,                   # forwarded to train_model
    lr=1e-3,                         # AdamW learning rate
    scheduler="ReduceLROnPlateau",   # scheduler name stored with the run; visible cells do not branch on it
    use_scheduler=True,              # when False, train_model receives scheduler=None
    lambda_kld=1e-4,                 # forwarded to train_model as lambda_kld
)


In [3]:
# Root folder of the AFHQ dataset. It must contain `train` and `test`
# sub-folders, each read below with torchvision's ImageFolder.
dataset_root = './data/AFHQ/'

# Resize every image to 64x64, convert it to a tensor, then normalise each of
# the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

BS = configs["batch_size"]

# os.path.join keeps the split paths valid whether or not dataset_root ends
# with a path separator (plain string concatenation did not).
train_dataset = datasets.ImageFolder(root=os.path.join(dataset_root, 'train'), transform=transform)
test_dataset = datasets.ImageFolder(root=os.path.join(dataset_root, 'test'), transform=transform)

# Show the class-name -> label-index mapping derived from the training folder.
print(train_dataset.class_to_idx)

# Training batches are reshuffled on every pass; the final partial batch is dropped.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BS,
    shuffle=True,
    drop_last=True,
)

# NOTE(review): drop_last=True also discards the final partial batch of the
# test split, so validation metrics skip those images. Left unchanged in case
# the model/training code needs fixed-size batches -- confirm.
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BS,
    shuffle=False,
    drop_last=True,
)

{'cat': 0, 'dog': 1, 'wild': 2}


In [4]:

# Seed the random number generators before the model is built.
seed_everything(42)

# NOTE(review): `device` is not defined in any visible cell; presumably it is
# exported by `from utils import *` -- confirm, or define it explicitly.
model = CCVAE(latent_dim=configs["latent_dim"]).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=configs["lr"], weight_decay=1e-5)
# Halve the learning rate after 7 epochs without improvement of the monitored
# metric. `verbose=True` was dropped: the argument is deprecated in recent
# PyTorch releases; read the current rate from the optimizer's param groups
# (or scheduler.get_last_lr()) when it needs to be logged.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=7, factor=0.5)

# Run name, e.g. "CCVAE1_KLD_0.0001".
model_name = configs["model_name"] + configs["exp"] + f"_KLD_{configs['lambda_kld']}"
savepath, writer = makedires(configs)

train_loss, val_loss, loss_iters, val_loss_recons, val_loss_kld = train_model(
    model=model,
    model_name=model_name,
    optimizer=optimizer,
    # The scheduler is only handed over when enabled in the config.
    scheduler=scheduler if configs["use_scheduler"] else None,
    criterion=vae_loss_function,
    lambda_kld=configs["lambda_kld"],
    train_loader=train_loader,
    valid_loader=test_loader,
    num_epochs=configs["num_epochs"],
    save_frequency=10,
    savepath=savepath,
    writer=writer,
    constrained=True,
)

# These lines only run when train_model returns normally (an interrupted run
# never reaches them); the final checkpoint is labelled with the configured
# epoch count.
epoch = configs["num_epochs"]
lambda_kld = configs["lambda_kld"]

save_model(model, model_name, optimizer, epoch=epoch, lambda_kld=lambda_kld, stats=configs)
save_config(configs)


Seed set to 42


Epoch 1 Iter 224: loss 0.17872. : 100%|██████████| 224/224 [01:41<00:00,  2.21it/s]


    Train loss: 0.2611
    Valid loss: 0.58914
       Valid loss recons: 0.58914
       Valid loss KL-D:   0.02078


Epoch 2 Iter 224: loss 0.20266. : 100%|██████████| 224/224 [01:40<00:00,  2.23it/s]
Epoch 3 Iter 224: loss 0.18525. : 100%|██████████| 224/224 [01:39<00:00,  2.24it/s]


    Train loss: 0.1849
    Valid loss: 0.18793
       Valid loss recons: 0.18171
       Valid loss KL-D:   62.14641


Epoch 4 Iter 224: loss 0.17528. : 100%|██████████| 224/224 [01:40<00:00,  2.24it/s]
Epoch 5 Iter 224: loss 0.18472. : 100%|██████████| 224/224 [01:40<00:00,  2.24it/s]


    Train loss: 0.18102
    Valid loss: 0.1829
       Valid loss recons: 0.17619
       Valid loss KL-D:   67.06333


Epoch 6 Iter 81: loss 0.19313. :  36%|███▌      | 81/224 [00:36<01:04,  2.23it/s]


KeyboardInterrupt: 

### Images vs. reconstructions

In [None]:
# img_vs_recons(model, test_loader, device)

### Inference

In [None]:
# plot_recons(inference(configs,model))


### Latent space visualization

In [None]:
# vis_latent(test_loader, model, test_dataset)

### Interpolation

In [None]:
# plot_reconstructed(model, xrange=(-50, 50), yrange=(-50, 50), N=15)