In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


INIT

In [2]:
import os
import torch

# Specify the path you want to change to
new_directory = "/content/drive/MyDrive/Stable_diff"

# Change the current working directory
os.chdir(new_directory)

# Confirm the directory has changed
print("Current Directory:", os.getcwd())

DEVICE = "cpu"

ALLOW_CUDA = True
ALLOW_MPS = True

if torch.cuda.is_available() and ALLOW_CUDA:
    DEVICE = "cuda"
elif (torch.backends.mps.is_available()) and ALLOW_MPS:
    DEVICE = "mps"
print(f"Using device: {DEVICE}")

Current Directory: /content/drive/MyDrive/Stable_diff
Using device: cuda


In [3]:
def load_model(model, save_dir, device):
    """
    Load model weights from the specified directory.
    Args:
        models (dict): Dictionary of model instances.
        save_dir (str): Directory containing saved model weights.
        device (str): Device to load models onto.
    """
    for model_name, model in models.items():
        model_path = os.path.join(save_dir, f"{model_name}.pt")
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location=device))
            print(f"Loaded {model_name} from {model_path}")
        else:
            print(f"No saved weights found for {model_name} at {model_path}")
    return models

In [4]:
import random

from encoder import VAE_Encoder
from decoder import VAE_Decoder
from diffusion import Diffusion
from torch import nn, optim
from transformers import CLIPTokenizer
import torchvision.transforms as transforms
from vae_model import VAE

import attention
import ddpm
import pipeline
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, Subset

import model_loader
from pipeline import train_model

from PIL import Image

class ImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.image_filenames = sorted(os.listdir(image_dir))
        self.transform = transform

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        # Convert to grayscale ('L' mode)
        image = Image.open(image_path).convert('L')
        if self.transform:
            image = self.transform(image)
        return image

HEIGHT, WIDTH = 256, 256

transform = transforms.Compose([
    transforms.Resize((HEIGHT, WIDTH)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1] for 1 channel
])

pwd = os.getcwd()
dataset = ImageDataset(image_dir=f'/content/drive/MyDrive/Stable_diff/complete_tdata/expanded_dataset', transform=transform)

batch_size = 2
learning_rate = 1e-4
momentum = 0.9
num_epochs = 25
n_timesteps = 1000  # num diffusion steps

dataset_size = len(dataset) // 3
train_size = int(0.6*dataset_size)
val_size = dataset_size - train_size

indices = random.sample(range(len(dataset)), dataset_size)
small_dataset = Subset(dataset, indices)
dataloader = DataLoader(small_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

indices = list(range(dataset_size))
random.shuffle(indices)

train_sampler = SubsetRandomSampler(indices[:train_size])
val_sampler = SubsetRandomSampler(indices[train_size:])

train_dataloader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, num_workers=2)
val_dataloader = DataLoader(dataset, batch_size=batch_size, sampler=val_sampler, num_workers=2)


vae_path = '/content/drive/MyDrive/Stable_diff/train_vae/models/best_model-epoch:6-loss:0.348914.pth'
# models = load_models(models, models_path_dem, DEVICE)

VAE = VAE().to(DEVICE)

VAE.load_state_dict(torch.load(vae_path, map_location=DEVICE))

encoder = VAE_Encoder()
decoder = VAE_Decoder()

encoder.load_state_dict(VAE.encoder.state_dict())
decoder.load_state_dict(VAE.decoder.state_dict())

# model_file = "/content/drive/MyDrive/Stable_diff/models/v1-5-pruned-emaonly.ckpt"
# models = model_loader.preload_models_from_standard_weights(model_file, DEVICE)

diffusion = Diffusion()

# diff_path = '/content/drive/MyDrive/Stable_diff/models/sd_models/best_diff_model-epoch:1-loss:0.010581.pth'
diff_path = '/content/drive/MyDrive/Stable_diff/complete_tdata/overfitt_debug/best_diff_model-epoch:400-loss:5.783495_NG.pth'
# diffusion.load_state_dict(torch.load(diff_path, map_location=DEVICE))

models = {
    'encoder': encoder,
    'decoder': decoder,
    'diffusion': diffusion
}

for model in models.values():
    model.to(DEVICE)

params = []
for model in models.values():
    params += list(model.parameters())

optimizer = optim.SGD(
    [
        {'params': encoder.parameters()},
        {'params': decoder.parameters()},
        {'params': diffusion.parameters()}
    ],
    lr=learning_rate,
    momentum=momentum
)

criterion = torch.nn.MSELoss(reduction='mean')

train_model(
    models=models,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    optimizer=optimizer,
    criterion=criterion,
    n_epochs=num_epochs,
    num_timesteps=n_timesteps,
    device=DEVICE,
    accumulation_steps=6,
    mini_patience=3,
    full_patience=10,
)

TypeError: VAE_Decoder.__init__() missing 3 required positional arguments: 'stg1_res', 'stg2_res', and 'stg3_res'

In [None]:
import model_loader
import pipeline
from PIL import Image
from torchvision.transforms.functional import to_pil_image
from pathlib import Path
import torch
from encoder import VAE_Encoder
from decoder import VAE_Decoder
from vae_model import VAE
from diffusion import Diffusion
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import torchvision.transforms as transforms
import random


class ImageDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.image_filenames = sorted(os.listdir(image_dir))
        self.transform = transform

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_filenames[idx])
        # Convert to grayscale ('L' mode)
        image = Image.open(image_path).convert('L')
        if self.transform:
            image = self.transform(image)
        return image

HEIGHT, WIDTH = 256, 256

transform = transforms.Compose([
    transforms.Resize((HEIGHT, WIDTH)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1] for 1 channel
])


device = "cpu"

ALLOW_CUDA = True
ALLOW_MPS = True

if torch.cuda.is_available() and ALLOW_CUDA:
    device = "cuda"

print(f"Using device: {device}")


# diff_file = "/content/drive/MyDrive/Stable_diff/models/sd_models/best_diff_model-epoch:7-loss:0.001274.pth"
diff_file = "/content/drive/MyDrive/Stable_diff/models/sd_models/best_diff_model-epoch:11-loss:0.084444.pth"

# models = model_loader.preload_models_from_standard_weights(model_file, device)

diffusion = Diffusion().to(device)
diffusion.load_state_dict(torch.load(diff_file, map_location=device))

vae_path = '/content/drive/MyDrive/Stable_diff/models/best_model-epoch:5-loss:0.003076.pth'
# diffusion = load_model(diffusion, diff_file, DEVICE)
##############################
VAE = VAE().to(device)

VAE.load_state_dict(torch.load(vae_path, map_location=device))

encoder = VAE_Encoder()
decoder = VAE_Decoder()

enc_path = '/content/drive/MyDrive/Stable_diff/models/sd_models/best_enc_model-epoch:1-loss:0.132742.pth'
dec_path = '/content/drive/MyDrive/Stable_diff/models/sd_models/best_dec_model-epoch:1-loss:0.132742.pth'

encoder.load_state_dict(VAE.encoder.state_dict())
decoder.load_state_dict(VAE.decoder.state_dict())

# encoder.load_state_dict(torch.load(enc_path, map_location=device))
# decoder.load_state_dict(torch.load(dec_path, map_location=device))

models = {
    'encoder': encoder,
    'decoder': decoder,
    'diffusion': diffusion
}
##################################
# ## IMAGE TO IMAGE

# Comment to disable image to image
dataset = ImageDataset(image_dir=f'/content/drive/MyDrive/Stable_diff/complete_tdata/expanded_dataset', transform=transform)
# image_path = "/content/drive/MyDrive/Stable_diff/complete_tdata/overfitt_debug/wop/image_071_tile_41.png"
# input_image = Image.open(image_path)

random_index = random.randint(0, len(dataset) - 1)
input_image = dataset[random_index]

# Higher values means more noise will be added to the input image, so the result will further from the input image.
# Lower values means less noise is added to the input image, so output will be closer to the input image.
strength = 1.0

input_image = None

## SAMPLER

sampler = "ddpm"
num_inference_steps = 500
seed = 42

output_image = pipeline.generate(
    input_image=input_image,
    strength=strength,
    sampler_name=sampler,
    n_inference_steps=num_inference_steps,
    seed=seed,
    models=models,
    device=device,
    idle_device="cpu"
)

# Combine the input image and the output image into a single image.
# output_image = output_image.astype(np.uint8)
# output_image = np.squeeze(output_image)
# Image.fromarray(output_image)


import matplotlib.pyplot as plt

# 1) Convert input_image (a (1,H,W) in [-1,1]) to [0..255] for display
input_image_np = input_image.cpu().numpy()  # shape (1,H,W)
# scale from [-1..1] -> [0..255]
input_image_np = np.clip((input_image_np + 1.0) * 127.5, 0, 255).astype(np.uint8)
# remove channel dimension
input_image_np = np.squeeze(input_image_np, axis=0)  # shape (H,W)

# 2) Output image is already uint8 [0..255], but ensure shape is (H,W)
output_image = output_image.astype(np.uint8)
output_image = np.squeeze(output_image)

# 3) Display side by side with matplotlib
plt.figure(figsize=(10, 5))

plt.subplot(1, 2, 1)
plt.imshow(input_image_np, cmap='gray')
plt.title("Input Image")
plt.axis("off")

plt.subplot(1, 2, 2)
plt.imshow(output_image, cmap='gray')
plt.title("Generated Image")
plt.axis("off")

plt.tight_layout()
plt.show()

Using device: cuda


  diffusion.load_state_dict(torch.load(diff_file, map_location=device))
  VAE.load_state_dict(torch.load(vae_path, map_location=device))
100%|██████████| 500/500 [00:23<00:00, 21.31it/s]


RuntimeError: Given groups=1, weight of size [256, 544, 1, 1], expected input[1, 32, 32, 32] to have 544 channels, but got 32 channels instead