In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
from tqdm import tqdm  

from ldm.models.autoencoder import VQModel
from taming.modules.losses.vqperceptual import VQLPIPSWithDiscriminator


# Hyper-parameters for the VQ autoencoder; config['params'] is passed straight to
# VQModel(**...) and config['base_learning_rate'] is the AdamW learning rate.
# NOTE(review): 'lossconfig' appears to make VQModel build an LPIPS + hinge-discriminator
# loss (see the "loaded pretrained LPIPS loss" output), but train() optimizes only an
# L1 + codebook loss, so that loss module seems unused during fine-tuning — confirm.
config = dict(
    base_learning_rate=1.0e-6,
    params=dict(
        embed_dim=3,
        n_embed=8192,
        monitor='val/rec_loss',
        ddconfig=dict(
            double_z=False,
            z_channels=3,
            resolution=256,
            in_channels=3,
            out_ch=3,
            ch=128,
            ch_mult=[1, 2, 4],
            num_res_blocks=2,
            attn_resolutions=[],
            dropout=0.0,
        ),
        lossconfig=dict(
            target='taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator',
            params=dict(
                disc_conditional=False,
                disc_in_channels=3,
                disc_start=0,
                disc_weight=0.75,
                codebook_weight=1.0,
            ),
        ),
    ),
)



# Pick the first CUDA GPU when one is present, otherwise fall back to the CPU,
# and report which one is in use.
if torch.cuda.is_available():
    device = torch.device('cuda')
    device_name = torch.cuda.get_device_name(0)
    print(f"GPU Name: {device_name}")
else:
    device = torch.device('cpu')
    print("No GPU available, using CPU.")

GPU Name: NVIDIA GeForce RTX 4090


In [None]:

# Preprocessing for the palm images: resize to the 256x256 resolution the VQ model is
# configured for, convert to a float CHW tensor in [0, 1], then map to [-1, 1]
# via (x - 0.5) / 0.5 per channel.
transform2 = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)


class RealPalmDataset(Dataset):
    """Flat list of image files found exactly one directory level below ``root_dir``.

    Expects the layout ``root_dir/<subfolder>/<file>``. Files placed directly in
    ``root_dir`` and anything nested deeper than one sub-folder are ignored.
    Hidden entries (names starting with '.', e.g. Jupyter's ``.ipynb_checkpoints``
    folder or ``.DS_Store``) are skipped so they never reach PIL. Both directory
    listings are sorted, so index ``i`` refers to the same file on every run and
    on every filesystem.

    Args:
        root_dir: directory whose sub-folders contain the images.
        transform: optional callable applied to each RGB ``PIL.Image``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []

        # os.listdir returns entries in arbitrary order; sort for reproducibility.
        for subfolder in sorted(os.listdir(root_dir)):
            subfolder_path = os.path.join(root_dir, subfolder)
            if subfolder.startswith('.') or not os.path.isdir(subfolder_path):
                continue
            for filename in sorted(os.listdir(subfolder_path)):
                image_path = os.path.join(subfolder_path, filename)
                if not filename.startswith('.') and os.path.isfile(image_path):
                    self.image_paths.append(image_path)

    def __len__(self):
        """Number of image files collected at construction time."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` when one was given."""
        img_path = self.image_paths[idx]
        # The context manager closes the file handle deterministically; convert()
        # returns a new, fully loaded image that remains valid after the block.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image


# Training images are read by RealPalmDataset from ./RealPalm/<subfolder>/<file>
# and preprocessed with transform2.
real_image_folder = './RealPalm'
dataset_real_palm_B = RealPalmDataset(root_dir=real_image_folder, transform=transform2)

# Shuffled mini-batches of 6 images, produced by 8 worker processes; pinned host
# memory makes the later host-to-GPU copy faster.
train_loader = DataLoader(
    dataset=dataset_real_palm_B,
    batch_size=6,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

In [None]:
# Per-channel constants matching Normalize(mean=0.5, std=0.5), shaped (1, 3, 1, 1)
# so they broadcast over channel-first image tensors; x * std + mean undoes the
# normalization before images are written to disk.
mean = torch.full((1, 3, 1, 1), 0.5, device=device)
std = torch.full((1, 3, 1, 1), 0.5, device=device)

# Output folders for the saved input / reconstruction sample pairs.
save_dir1 = os.path.join('./latent-diffusion-final', 'input_fine')
os.makedirs(save_dir1, exist_ok=True)
save_dir2 = os.path.join('./latent-diffusion-final', 'output_fine')
os.makedirs(save_dir2, exist_ok=True)

In [None]:

# Build the VQ autoencoder from the config and initialise it from the pretrained checkpoint.
model = VQModel(**config['params'])

# The weights are stored under the checkpoint's 'state_dict' key.
# map_location='cpu': the model is moved to `device` right after loading, so mapping the
# checkpoint straight to the GPU would keep a second full copy of the weights resident
# in VRAM for as long as the `checkpoint` variable lives.
# weights_only=False is passed explicitly to keep the current full-unpickling behaviour:
# recent PyTorch warns about (and newer releases default to) weights_only=True, which can
# reject checkpoints that pickle non-tensor objects.
# SECURITY: this unpickles arbitrary objects — only load checkpoints from a trusted source.
checkpoint = torch.load('VQf4model.ckpt', map_location='cpu', weights_only=False)
model.load_state_dict(checkpoint['state_dict'])
model.to(device)

# NOTE(review): model.parameters() presumably also includes the loss module's
# (LPIPS / discriminator) parameters; those receive no gradients from train()'s loss,
# and AdamW skips parameters whose .grad is None.
optimizer = torch.optim.AdamW(model.parameters(), lr=config['base_learning_rate'])


def _save_sample_pairs(images, reconstructions, epoch):
    """Save each (input, reconstruction) pair of one batch as PNG files.

    Files are named ``input{epoch+1}_pair{i+1}.png`` / ``output{epoch+1}_pair{i+1}.png``
    and written to the notebook globals ``save_dir1`` / ``save_dir2``. The globals
    ``mean`` and ``std`` are used to undo the dataset normalization.
    """
    to_pil = transforms.ToPILImage()
    with torch.no_grad():
        for i, (input_i, output_i) in enumerate(zip(images, reconstructions)):
            # De-normalize, then clamp. ToPILImage multiplies float tensors by 255 and
            # casts to uint8, so reconstructions slightly outside [0, 1] would wrap
            # around instead of saturating.
            input_img = to_pil((input_i * std + mean).clamp(0, 1).cpu().squeeze())
            output_img = to_pil((output_i * std + mean).clamp(0, 1).cpu().squeeze())

            input_img.save(os.path.join(save_dir1, f'input{epoch+1}_pair{i+1}.png'))
            output_img.save(os.path.join(save_dir2, f'output{epoch+1}_pair{i+1}.png'))


def train(model, dataloader, optimizer, num_epochs, save_epochs=(4, 9)):
    """Fine-tune ``model`` with an L1 reconstruction loss plus the VQ codebook loss.

    After every epoch:
    - the mean batch loss of that epoch is printed;
    - ``model.state_dict()`` is written to ``vqmodel_checkpoint.ckpt``, overwriting
      the previous epoch's file.

    After each 0-based epoch index listed in ``save_epochs``, the inputs and
    reconstructions of that epoch's last batch are saved as PNGs.

    Relies on the notebook globals ``device``, ``mean``, ``std``, ``save_dir1`` and
    ``save_dir2``.

    Args:
        model: autoencoder whose ``model(images)`` returns ``(reconstructions, codebook_loss)``.
        dataloader: iterable yielding batches of normalized image tensors.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``dataloader``.
        save_epochs: 0-based epoch indices after which sample pairs are saved.
            The default ``(4, 9)`` means after the 5th and the 10th epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    model.train()
    for epoch in tqdm(range(num_epochs), desc='Epochs'):
        # Accumulate on-device to avoid a GPU->CPU sync on every step.
        running_loss = torch.zeros((), device=device)
        num_batches = 0
        for images in dataloader:
            images = images.to(device)

            optimizer.zero_grad()
            reconstructions, codebook_loss = model(images)
            loss = torch.mean(torch.abs(images - reconstructions)) + codebook_loss.mean()
            loss.backward()
            optimizer.step()

            running_loss += loss.detach()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("dataloader yielded no batches; check the dataset folder")

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss.item() / num_batches:.4f}')

        torch.save(model.state_dict(), 'vqmodel_checkpoint.ckpt')
        print("Model parameters saved to vqmodel_checkpoint.ckpt")

        if epoch in save_epochs:
            _save_sample_pairs(images, reconstructions, epoch)


making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 3, 64, 64) = 12288 dimensions.
making attention of type 'vanilla' with 512 in_channels


  self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
  checkpoint = torch.load('VQf4model.ckpt', map_location=device)


loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth
VQLPIPSWithDiscriminator running with hinge loss.


In [None]:

# Fine-tune for 5 epochs. With train()'s sample-saving epochs (4, 9), image pairs
# are written only after the 5th (last) epoch.
num_epochs = 5
model.to(device)
train(model, train_loader, optimizer, num_epochs=num_epochs)

Epochs:  20%|██        | 1/5 [49:44<3:18:56, 2984.25s/it]

Epoch [1/5], Loss: 0.0140


Epochs:  40%|████      | 2/5 [1:39:30<2:29:16, 2985.36s/it]

Epoch [2/5], Loss: 0.0126


Epochs:  60%|██████    | 3/5 [2:29:15<1:39:30, 2985.28s/it]

Epoch [3/5], Loss: 0.0112
