In [3]:
import torch.nn as nn
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms

import os
import time
import argparse
from glob import glob
from copy import deepcopy
import numpy as np
from vq_model import VQ_models
from vq_loss import VQLoss


In [20]:
# from torchsummary import summary

model = VQ_models['VQ-8'](
    codebook_size=16384,
    codebook_embed_dim=8,
    commit_loss_beta=0.25,
    entropy_loss_ratio=0.0,
    dropout_p=0.0,
)

# Load the weights
weights = torch.load('vq_ds8_c2i.pt')
missing_keys, unexpected_keys = model.load_state_dict(weights['model'], strict=False)

# Print missing keys
if missing_keys:
    print(f"Missing keys (weights not loaded): {missing_keys}")

# Print unexpected keys
if unexpected_keys:
    print(f"Unexpected keys (weights in checkpoint but not in model): {unexpected_keys}")

# Identify and print the uninitialized weights
for name, param in model.named_parameters():
    if name in missing_keys:
        param.requires_grad = True  # Enable gradient computation for uninitialized weights
        print(f"Uninitialized weight (requires_grad=True): {name} - shape: {param.shape}")
    else:
        param.requires_grad = False 
        
del weights

Missing keys (weights not loaded): ['encoder.conv_blocks.0.res.0.lora_conv1.lora_A.weight', 'encoder.conv_blocks.0.res.0.lora_conv1.lora_B.weight', 'encoder.conv_blocks.0.res.0.lora_conv2.lora_A.weight', 'encoder.conv_blocks.0.res.0.lora_conv2.lora_B.weight', 'encoder.conv_blocks.0.res.1.lora_conv1.lora_A.weight', 'encoder.conv_blocks.0.res.1.lora_conv1.lora_B.weight', 'encoder.conv_blocks.0.res.1.lora_conv2.lora_A.weight', 'encoder.conv_blocks.0.res.1.lora_conv2.lora_B.weight', 'encoder.conv_blocks.1.res.0.lora_conv1.lora_A.weight', 'encoder.conv_blocks.1.res.0.lora_conv1.lora_B.weight', 'encoder.conv_blocks.1.res.0.lora_conv2.lora_A.weight', 'encoder.conv_blocks.1.res.0.lora_conv2.lora_B.weight', 'encoder.conv_blocks.1.res.0.lora_nin_shortcut.lora_A.weight', 'encoder.conv_blocks.1.res.0.lora_nin_shortcut.lora_B.weight', 'encoder.conv_blocks.1.res.1.lora_conv1.lora_A.weight', 'encoder.conv_blocks.1.res.1.lora_conv1.lora_B.weight', 'encoder.conv_blocks.1.res.1.lora_conv2.lora_A.weight'

In [4]:
class HighLowResDataset(Dataset):
    def __init__(self, high_res_dir, low_res_dir, transform=None):
        """
        Args:
            high_res_dir (str): Directory with high-resolution images.
            low_res_dir (str): Directory with low-resolution images.
            transform (callable, optional): A function/transform to apply to both high and low-resolution images.
        """
        self.high_res_dir = high_res_dir
        self.low_res_dir = low_res_dir
        self.transform = transform
        
        # List of all image files in the high resolution directory
        self.high_res_files = sorted([f for f in os.listdir(high_res_dir) if os.path.isfile(os.path.join(high_res_dir, f))])
        self.low_res_files = sorted([f for f in os.listdir(low_res_dir) if os.path.isfile(os.path.join(low_res_dir, f))])
        
        if len(self.high_res_files) != len(self.low_res_files):
            raise ValueError("The number of high-resolution and low-resolution images must be the same.")
    
    def __len__(self):
        return len(self.high_res_files)
    
    def __getitem__(self, idx):
        high_res_file = self.high_res_files[idx]
        low_res_file = self.low_res_files[idx]
        
        high_res_path = os.path.join(self.high_res_dir, high_res_file)
        low_res_path = os.path.join(self.low_res_dir, low_res_file)
        
        high_res_image = Image.open(high_res_path).convert('RGB')
        low_res_image = Image.open(low_res_path).convert('RGB')
        
        if self.transform:
            high_res_image = self.transform(high_res_image)
            low_res_image = self.transform(low_res_image)
        
        return high_res_image, low_res_image

In [6]:


transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: pil_image.resize((1024, 1024), Image.BICUBIC)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])

In [23]:
dataset = HighLowResDataset(high_res_dir='archive/lol_dataset/train/high', low_res_dir='archive/lol_dataset/train/low', transform=transform)
loader = DataLoader(
        dataset,
        batch_size=2,
        shuffle=False,
        num_workers=1,
        pin_memory=True
    )

In [22]:
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
    )
    
print_trainable_parameters(model)

trainable params: 93184 || all params: 70203019 || trainable%: 0.13


In [None]:
from PIL import Image
import random
import torch
import time
from logger import create_logger

# Ensure the scaler is set properly for mixed precision
scaler = torch.cuda.amp.GradScaler(enabled=True)
scaler_disc = torch.cuda.amp.GradScaler(enabled=True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'


params_to_update = [param for param in model.parameters() if param.requires_grad]
vq_loss = VQLoss(
    disc_start=20000,
    disc_weight=0.5,
    disc_type='patchgan',
    disc_loss='hinge',
    gen_adv_loss='hinge',
    image_size=256,
    perceptual_weight=1.0,
    reconstruction_weight=1.0,
    reconstruction_loss='l2',
    codebook_weight=1.0,
).to(device)

model.to(device)
optimizer = torch.optim.Adam(params_to_update, lr=1e-4, betas=(0.9, 0.95))
optimizer_disc = torch.optim.Adam(vq_loss.discriminator.parameters(), lr=1e-4, betas=(0.9, 0.95))
checkpoint_dir = 'checkpoints'
experiment_dir = 'experiments'
running_loss = 0
log_steps = 0
start_time = time.time()
model.train()
vq_loss.train()
train_steps = 0
logger = create_logger(experiment_dir)
logger.info(f"Experiment directory created at {experiment_dir}")
from tqdm import tqdm

for epoch in range(100):
    with tqdm(total=len(loader), desc=f"Epoch {epoch+1}/{100}", unit="batch") as pbar:
        for i, (high_res_image, low_res_image) in enumerate(loader):
            high_res_image = high_res_image.to(device)
            low_res_image = low_res_image.to(device)

            # Zero gradients
            optimizer.zero_grad()
            optimizer_disc.zero_grad()

            # Forward pass
            recons_imgs, codebook_loss = model(high_res_image)
            loss_gen = vq_loss(
                codebook_loss, low_res_image, recons_imgs, 
                optimizer_idx=0, global_step=train_steps + 1, 
                last_layer=model.decoder.last_layer, logger=logger, log_every=100
            )
            
            # Backward pass for generator
            scaler.scale(loss_gen).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(params_to_update, 1.0)
            scaler.step(optimizer)
            scaler.update()

            # Backward pass for discriminator
            loss_disc = vq_loss(
                codebook_loss, low_res_image, recons_imgs, 
                optimizer_idx=1, global_step=train_steps + 1, 
                logger=logger, log_every=100
            )
            scaler_disc.scale(loss_disc).backward()
            scaler_disc.unscale_(optimizer_disc)
            torch.nn.utils.clip_grad_norm_(vq_loss.discriminator.parameters(), 1.0)
            scaler_disc.step(optimizer_disc)
            scaler_disc.update()

            # Log loss values:
            running_loss += loss_gen.item() + loss_disc.item()
            log_steps += 1
            train_steps += 1

            # Update tqdm bar with the current batch information
            pbar.set_postfix(
                {"Loss Gen": f"{loss_gen.item():.4f}", "Loss Disc": f"{loss_disc.item():.4f}"}
            )
            pbar.update(1)

            if train_steps % 100 == 0:
                # Measure training speed:
                end_time = time.time()
                steps_per_sec = log_steps / (end_time - start_time)
                avg_loss = running_loss / log_steps
                logger.info(f"(step={train_steps:07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}")
                running_loss = 0
                log_steps = 0
                start_time = time.time()

            # Save checkpoint:
            if train_steps % 5000 == 0 and train_steps > 0:
                model_weight = model.state_dict()
                checkpoint = {
                    "model": model_weight,
                    "optimizer": optimizer.state_dict(),
                    "discriminator": vq_loss.discriminator.state_dict(),
                    "optimizer_disc": optimizer_disc.state_dict(),
                    "steps": train_steps,
                }
                checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
                torch.save(checkpoint, checkpoint_path)
                logger.info(f"Saved checkpoint to {checkpoint_path}")

            


[[34m2024-11-06 20:08:21[0m] Experiment directory created at experiments


loaded pretrained LPIPS loss from /media/mlr_lab/325C37DE7879ABF2/LowLIGHTQuantGAN/cache/vgg.pth


Epoch 1/100:   0%|          | 0/243 [00:00<?, ?batch/s]


RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 23.64 GiB total capacity; 13.10 GiB already allocated; 175.81 MiB free; 13.20 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

: 