In [1]:
import numpy as np
import argparse
import tqdm
import torch, torchvision
import torch.nn.functional as F
from diffusers import DDIMScheduler, DiTTransformer2DModel, AutoencoderKL
from matplotlib import pyplot as plt
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms, datasets
import os
from torchvision.datasets import ImageFolder
from typing import Any, Callable, cast, Optional, Union

# Prefer the GPU when PyTorch can see one; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# File suffixes (lowercase, with leading dot) that count as images when scanning folders.
IMG_EXTENSIONS = (
    ".jpg", ".jpeg", ".png",
    ".ppm", ".bmp", ".pgm",
    ".tif", ".tiff", ".webp",
)

In [2]:
# Mount Google Drive into the Colab runtime so files under MyDrive are reachable by path.
from google.colab import drive
drive.mount('/content/drive/')
# One-time extraction of the dataset archive; kept commented out so it is not re-run.
#!unzip -q "/content/drive/MyDrive/DS_MVTec.zip" -d "/content/drive/MyDrive/DefectSpectrum"
# List the DefectSpectrum folder on the mounted drive.
!ls /content/drive/MyDrive/DefectSpectrum

Mounted at /content/drive/
DS-MVTec
hazelnut_best_samples_119Epoch_param_eff.png
samples_step_18_epoch_109.png
samples_step_36_epoch_19.png
samples_step_36_epoch_9.png
screw_best_samples_29epoch_fullfinetune.png


In [3]:
def show_images(im_batch):
    """Tile a batch of images into a single PIL image.

    Args:
        im_batch (torch.Tensor): Batch of images normalised to [-1, 1]
            (assumed to be a float tensor laid out as (N, C, H, W) -- TODO confirm).
    Returns:
        PIL.Image.Image: The batch arranged as a grid by ``torchvision.utils.make_grid``.
    """
    # Un-normalise from [-1, 1] to [0, 1]. Clamp so that values slightly outside the
    # range (e.g. model outputs) saturate instead of wrapping around in the uint8 cast.
    im_batch = (im_batch * 0.5 + 0.5).clamp(0, 1)
    grid = torchvision.utils.make_grid(im_batch)
    # (C, H, W) -> (H, W, C), scaled to the 0..255 range expected by PIL.
    grid = grid.detach().cpu().permute(1, 2, 0) * 255
    grid_im = Image.fromarray(np.array(grid).astype(np.uint8))
    return grid_im
# Report which torch device was selected.
print(device)

cuda


In [12]:
def has_file_allowed_extension(filename: str, extensions):
    """Tell whether a path ends with one of the accepted suffixes (case-insensitive on the path).

    Used as a helper by the dataset loader.
    Args:
        filename (string): path to a file
        extensions (string or iterable of strings): accepted suffixes (lowercase)
    Returns:
        bool: True if the lowercased filename ends with one of the given extensions
    """
    lowered = filename.lower()
    if isinstance(extensions, str):
        suffixes = extensions
    else:
        # str.endswith only accepts a str or a tuple, so coerce any other iterable.
        suffixes = tuple(extensions)
    return lowered.endswith(suffixes)

class ImageFolderSelectSubset(ImageFolder):
    """An ``ImageFolder`` that keeps at most a fixed number of samples per class.

    For each class the number of samples is
    ``min(available_samples, max_samples_per_class)``; files are taken in sorted
    order (directories first, then file names within a directory).

    Args:
        root : Root directory containing one sub-directory per class.
        extensions : Accepted file suffixes. Ignored when ``is_valid_file`` is given.
        transform : Transform applied to each loaded image.
        target_transform : Transform applied to each class index.
        is_valid_file : Optional predicate ``path -> bool``; takes precedence over ``extensions``.
        allow_empty (bool) : If True, classes without any valid file are tolerated.
        max_samples_per_class (int or None) : Maximum number of samples to be collected
            per class. ``None`` keeps every sample.
    """
    def __init__(
        self,
        root,
        extensions = IMG_EXTENSIONS,
        transform = None,
        target_transform = None,
        is_valid_file = None,
        allow_empty = False,
        max_samples_per_class = None,
    ) -> None:
        if extensions is None and is_valid_file is None:
            raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

        # `extensions` has a non-None default, so an explicit `is_valid_file` must win;
        # otherwise make_dataset would always reject the pair as "both given".
        if is_valid_file is not None:
            extensions = None
            file_filter = is_valid_file
        else:
            def file_filter(path: str) -> bool:
                return has_file_allowed_extension(path, extensions)

        # The parent constructor performs its own (unrestricted) scan through
        # self.make_dataset. Hand it the same file filter so that scan accepts exactly
        # the files we accept, instead of the parent's default image extensions.
        parent_kwargs = {}
        if allow_empty:
            # Only forwarded when requested.
            # NOTE(review): assumes the installed torchvision's ImageFolder accepts
            # `allow_empty` -- confirm for older releases.
            parent_kwargs["allow_empty"] = True
        super().__init__(
            root,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=file_filter,
            **parent_kwargs,
        )

        self.extensions = extensions
        self.max_samples_per_class = max_samples_per_class

        if max_samples_per_class is not None:
            # Replace the parent's full sample list with the capped one.
            self.samples = self.make_dataset(
                self.root,
                class_to_idx=self.class_to_idx,
                extensions=extensions,
                is_valid_file=is_valid_file,
                allow_empty=allow_empty,
                max_samples_per_class=max_samples_per_class,
            )
            self.targets = [s[1] for s in self.samples]
            # ImageFolder exposes `imgs` as an alias of `samples`; keep them in sync.
            self.imgs = self.samples

    @staticmethod
    def make_dataset(
        directory,
        class_to_idx = None,
        extensions = None,
        is_valid_file = None,
        allow_empty = False,
        max_samples_per_class = None,
    ):
        """Generates a list of samples of a form (path_to_sample, class).
        See :class:`DatasetFolder` for details.
        We override this method to select only max_samples_per_class samples for each class.

        Args:
            directory : Root directory with one sub-directory per class.
            class_to_idx : Mapping class name -> class index. Discovered from
                ``directory`` when None.
            extensions : Accepted file suffixes (mutually exclusive with ``is_valid_file``).
            is_valid_file : Predicate ``path -> bool`` (mutually exclusive with ``extensions``).
            allow_empty (bool) : Do not raise when a class has no valid file.
            max_samples_per_class (int or None) : Cap on samples per class; None means no cap.
        Returns : list[tuple[str, int]]
        Raises:
            ValueError: empty ``class_to_idx``, invalid extensions/is_valid_file
                combination, or a non-positive ``max_samples_per_class``.
            FileNotFoundError: a class has no valid file and ``allow_empty`` is False.
        """
        # A staticmethod has no `self`; use torchvision's module-level class discovery.
        from torchvision.datasets.folder import find_classes

        directory = os.path.expanduser(directory)

        if class_to_idx is None:
            _, class_to_idx = find_classes(directory)
        elif not class_to_idx:
            raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

        both_none = extensions is None and is_valid_file is None
        both_something = extensions is not None and is_valid_file is not None
        if both_none or both_something:
            raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

        if max_samples_per_class is not None and max_samples_per_class < 1:
            raise ValueError("'max_samples_per_class' must be a positive integer or None.")

        if extensions is not None:

            def is_valid_file(x: str) -> bool:
                return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

        def limit_reached(count: int) -> bool:
            return max_samples_per_class is not None and count >= max_samples_per_class

        instances = []
        available_classes = set()
        for target_class in sorted(class_to_idx.keys()):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(directory, target_class)
            if not os.path.isdir(target_dir):
                continue
            # Number of samples (file paths) collected so far for the current class.
            samples_per_class = 0
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                # The cap is per class, not per directory: stop walking nested
                # directories as soon as it is hit.
                if limit_reached(samples_per_class):
                    break
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    if not is_valid_file(path):
                        continue
                    instances.append((path, class_index))
                    samples_per_class += 1
                    if limit_reached(samples_per_class):
                        break
            if samples_per_class:
                available_classes.add(target_class)

        empty_classes = set(class_to_idx.keys()) - available_classes
        if empty_classes and not allow_empty:
            msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
            if extensions is not None:
                msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
            raise FileNotFoundError(msg)

        return instances


In [20]:
def dataloader(image_size, batch_size, part_type, max_samples_per_type,
               data_root='/content/drive/MyDrive/DefectSpectrum/DS-MVTec',
               num_workers=4):
    """ Data loader for defectspectrum
    Args :
        image_size (int) : The input image size (images are resized to image_size x image_size)
        batch_size (int) : Batch size
        part_type (str) : The product or part type from DefectSpectrum
        max_samples_per_type : Maximum no of images to be loaded per defect type of the part
        data_root (str) : Directory holding one folder per part type; images are read
            from ``<data_root>/<part_type>/image``. Defaults to the Colab drive location.
        num_workers (int) : Number of DataLoader worker processes (default 4)
    Returns :
        dataset (ImageFolderSelectSubset) : The dataset object
        dataloader (torch.utils.data.DataLoader) : A shuffling dataloader for the dataset
     """
    preprocess = transforms.Compose(
        [
            #transforms.CenterCrop(800),
            transforms.Resize((image_size,image_size)),
            #transforms.ColorJitter(brightness=(0.5,1.5),contrast=(3),saturation=(0.3,1.5),hue=(-0.1,0.1)),
            #transforms.RandomRotation([45,120], interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            #transforms.RandomEqualize(),
            transforms.ToTensor(),
            transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]), # From [0,1] to [-1,1] Normalization
        ]
    )

    # The dataset is read from local storage: one sub-folder per defect type
    # under <data_root>/<part_type>/image.
    image_dir = os.path.join(data_root, part_type, 'image')
    dataset = ImageFolderSelectSubset(
            image_dir,
            transform=preprocess,
            max_samples_per_class = max_samples_per_type
            )
    # Local name differs from the function name to avoid shadowing it.
    loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    return dataset, loader


In [21]:
def sample_diffusion(diffusion_model, vae_model, vae_latent_size, noise_scheduler, num_classes, num_samples_per_class, label_offset=783):
    """
    Sample from a (trained) latent diffusion model.
    :param diffusion_model : The pretrained latent diffusion model
    :param vae_model : The pretrained VAE model used to decode latents back to images
    :param vae_latent_size : Size of the VAE latent - also the input size of diffusion model
    :param noise_scheduler : The noise scheduler used by the diffusion model; its
                             ``timesteps`` must already be set by the caller
    :param num_classes : Number of different class labels for conditional generation
    :param num_samples_per_class : The no of samples to be generated for each class
    :param label_offset : Value added to the class indices 0..num_classes-1 to obtain
                          the label ids fed to the model (default 783, the ImageNet
                          class chosen as closest to "screw"; 990 was chosen for "hazelnut")

    :return A grid image (PIL) of the generated samples, one row per class
    """
    # Class-major label order: [c0, c0, ..., c1, c1, ...]
    labels = torch.arange(num_classes, device=device).repeat_interleave(num_samples_per_class)
    labels = labels.long() + label_offset

    latent_channels = vae_model.config.latent_channels

    # Sampling is inference only: without no_grad the VAE decode would build an
    # autograd graph (wasting memory) and its output could not be converted to numpy
    # whenever the VAE parameters require grad.
    with torch.no_grad():
        z = torch.randn(
                num_samples_per_class,
                latent_channels,
                vae_latent_size,
                vae_latent_size,
                device=device
                )
        # Replicate the noise vectors for all class labels, so that sample j of
        # every class starts from the same noise.
        z = z.repeat(num_classes, 1, 1, 1)

        # Sampling loop over the scheduler's inference timesteps
        for t in noise_scheduler.timesteps:
            model_input = noise_scheduler.scale_model_input(z, t)
            timestep = t[None].to(device)
            noise_pred = diffusion_model(model_input, timestep, labels, return_dict=False)[0]
            # The model output stacks extra channels after the predicted noise;
            # keep only the first `latent_channels` (the mean noise prediction).
            mean_noise_pred = noise_pred[:, :latent_channels, :, :]
            z = noise_scheduler.step(mean_noise_pred, t, z).prev_sample

        # Undo the latent scaling before decoding back to image space.
        x = vae_model.decode(z / vae_model.config.scaling_factor).sample

    # make_grid's nrow is the number of images per row. The batch is class-major,
    # so num_samples_per_class images per row yields one row per class.
    grid = torchvision.utils.make_grid(x, nrow=num_samples_per_class)
    im = grid.permute(1,2,0).cpu().clip(-1,1)*0.5 + 0.5
    im = Image.fromarray(np.array(im*255).astype(np.uint8))
    return im


In [22]:
def mark_only_biases_as_trainable(model: torch.nn.Module, is_bitfit=False):
    """Freeze every parameter except those whose name contains a trainable keyword.

    A parameter stays trainable when its name contains any of the keywords below;
    all other parameters get ``requires_grad = False``.

    :param model : The module whose parameters are (un)frozen in place
    :param is_bitfit : If True use the keywords bias / norm / y_embed (original BitFit
                       only tunes biases; norm layers and label embedding are tuned too);
                       otherwise additionally keep "gamma" parameters trainable
    :return The same model object
    """
    keywords = ('bias', 'norm', 'y_embed') if is_bitfit else ("bias", "norm", "gamma", "y_embed")

    for name, param in model.named_parameters():
        keep_trainable = False
        for keyword in keywords:
            if keyword in name:
                keep_trainable = True
                break
        param.requires_grad = keep_trainable
    return model

In [23]:
def finetune(
        full_fine_tune=False,
        image_size=256,
        #noise_scheduler_model = 'google/ddpm-celebahq-256',
        part_type='hazelnut',
        max_img_per_defect=None,
        num_epochs=1,
        lr=1e-5,
        batch_size=4,
        grad_acc_steps=2,
        wandb_project='defectDiT_finetune',
        ckpt_every=100,
        ckpt_dir='/content/drive/MyDrive/DefectSpectrum',
        log_samples_every=10,
        patience=5
        ):
    """
    Fine tune a latent diffusion DiT model with defect images of a particular part type from DefectSpectrum data.
    :param full_fine_tune : To do full fine tuning if True, otherwise parameter efficient fine tuning (default False)
    :param image_size : The input image size (default 256x256)
    :param part_type : The selected part type from DefectSpectrum (e.g, zipper, pill,.. Default hazelnut)
    :param max_img_per_defect : Maximum no of images to be loaded per defect type of the part (default None)
    :param num_epochs : Number of epochs (default 1)
    :param lr : Learning rate (default 1e-5, we don't use LR decay)
    :param batch_size : Batch size (default = 4 to make it memory efficient)
    :param grad_acc_steps : Number of batches whose gradients are accumulated before each optimizer update
    :param wandb_project : The id of wandb project where intermediate models and outputs are logged (not used currently)
    :param ckpt_every : Frequency of checkpointing the model (no of epochs)
    :param ckpt_dir : Directory where sample grids and checkpoints are saved
    :param log_samples_every : Frequency of saving generated samples (no of epochs)
    :param patience : Reserved for early stopping, which is not implemented (not used currently)

    :return A list with the loss value of every training step, across all epochs
    """

    #TODO Enable wandb logging if it works in colab, otherwise save locally
    #wandb.init(wandb_project, config=locals())

    # Offset added to the dataset class indices to obtain the ImageNet label ids fed
    # to the model. Closest ImageNet class for "screw" - 783, for "hazelnut" - 990.
    # sample_diffusion uses 783 as well; keep the two in agreement.
    imagenet_label_offset = 783

    # Guard against 0/negative values, which would break the modulo below.
    grad_acc_steps = max(1, int(grad_acc_steps))

    #Define the dataloader
    db, dl = dataloader(image_size=image_size,batch_size=batch_size,part_type=part_type,max_samples_per_type=max_img_per_defect)
    num_classes = len(db.classes)
    num_batches = len(dl)
    print("Prepared the dataloader with size", num_batches)

    #A fast scheduler to trade-off fidelity with sampling speed
    scheduler=DDIMScheduler.from_pretrained('google/ddpm-celebahq-256')
    scheduler.set_timesteps(num_inference_steps=200)

    #The base DiT model. This is a latent diffusion model - i.e. it operates in the latent space of a VAE
    diffusion_model = DiTTransformer2DModel.from_pretrained(
            'facebook/DiT-XL-2-256', subfolder='transformer',
            dropout = 0.2,
            ).to(device)
    diffusion_model.train()
    #Parameter efficient tuning - only tune biases, norm and label embeddings
    if (not full_fine_tune):
      diffusion_model = mark_only_biases_as_trainable(diffusion_model, is_bitfit=True)

    # The VAE. It is never fine tuned, so freeze it: no gradients are stored for it
    # and its outputs never carry an autograd graph.
    vae_model = AutoencoderKL.from_pretrained(
            'facebook/DiT-XL-2-256', subfolder='vae').to(device)
    vae_model.requires_grad_(False)
    vae_model.eval()
    #TODO How to get this value from the model (is it vae.config.norm_num_groups)?
    vae_latent_size=image_size//8

    #Use a very small learning rate, since we have a very small dataset and a small batch size (gradients may be noisy)
    optimizer = torch.optim.AdamW(diffusion_model.parameters(),lr=lr, weight_decay=0.0)
    #Loss history (one entry per training step), for posterior analysis or debugging
    losses=[]

    # Mixed precision is only enabled when running on a CUDA device.
    # NOTE(review): fp16 autocast is used without a gradient scaler; consider adding
    # one if small gradients turn out to underflow.
    use_amp = (device.type == 'cuda')

    #Fine tuning loop
    for epoch in range(num_epochs):
        print("Training epoch", epoch)
        optimizer.zero_grad()
        for step, batch in tqdm.tqdm(enumerate(dl), total=num_batches):
            train_images = batch[0].to(device)
            train_labels = batch[1].to(device) + imagenet_label_offset

            #Get the latent representations of images from the VAE model - Diffusion will operate in this space
            with torch.no_grad():
                image_latents = vae_model.encode(train_images).latent_dist.sample()
                image_latents = image_latents * vae_model.config.scaling_factor

            #Standard Gaussian noise to be added to each clean latent
            noise = torch.randn_like(image_latents)
            #Sample a timestep t uniformly for each real image in the batch
            timesteps = torch.randint(0, scheduler.config.num_train_timesteps,
                                      (image_latents.shape[0],),
                                      device=image_latents.device).long()
            #Add the noise, scaled with appropriate variance for that timestep (according to scheduler), to corresponding latents
            noisy_latents = scheduler.add_noise(image_latents, noise, timesteps)

            #Predict the added noise from the noisy latents.
            #The model output stacks extra channels after the predicted noise;
            #only the first C (= number of latent channels) are the mean noise prediction.
            with torch.autocast(device_type='cuda', enabled=use_amp):
                noise_pred = diffusion_model(
                    noisy_latents,
                    timesteps,train_labels,
                    return_dict=False)[0]
                mean_noise_pred = noise_pred[:, :image_latents.shape[1], :, :]
                #Error between the true added noise and predicted noise
                loss = F.mse_loss(mean_noise_pred, noise)

            losses.append(loss.item())
            # backward() must be called without a `gradient` argument for a scalar loss;
            # passing the loss itself would multiply every gradient by the loss value.
            # Dividing by grad_acc_steps makes the accumulated gradient an average.
            (loss / grad_acc_steps).backward()

            #Update once every grad_acc_steps batches, and also after the last batch of
            #the epoch so that a trailing partial group of gradients is not dropped.
            if ((step + 1) % grad_acc_steps == 0) or ((step + 1) == num_batches):
                optimizer.step()
                optimizer.zero_grad()

        #Save some sample generations every 'log_samples_every' epochs
        #Generate a batch of 2 images per class and save as a grid
        if (epoch+1)%log_samples_every == 0:
            num_samples_per_class=2
            diffusion_model.eval()
            im = sample_diffusion(
                    diffusion_model, vae_model,
                    vae_latent_size, scheduler,
                    num_classes, num_samples_per_class
                    )
            diffusion_model.train()
            save_path = f"{ckpt_dir}/samples_step_{step}_epoch_{epoch}.png"
            im.save(save_path)
            #wandb.log({'Sample generations': wandb.Image(im)})
        # Save a checkpoint every 'ckpt_every' epochs
        if (epoch+1)%ckpt_every == 0:
              checkpoint_path=f"{ckpt_dir}/checkpoint_step_{step+1}"
              checkpoint = {
                        "model" : diffusion_model.state_dict(),
                        "optimizer" : optimizer.state_dict()
                      }
              torch.save(checkpoint, checkpoint_path)
        avg_loss = sum(losses[-num_batches:])/num_batches
        print(f"Epoch {epoch}, Average Loss {avg_loss}")
        #TODO Early stopping (track the epoch with the lowest average loss and stop after
        #'patience' epochs without improvement) is not implemented yet.
    return losses


In [24]:
# Seed PyTorch's global RNG before training.
torch.manual_seed(42)
# Ask the CUDA caching allocator to use expandable segments (suggested by PyTorch's
# OOM message to reduce fragmentation).
# NOTE(review): this may only take effect if set before CUDA memory is first
# allocated in this kernel -- confirm, and consider moving it to the first cell.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# All run settings in one place; passed to finetune() as keyword arguments.
finetune_config = {
    "full_fine_tune": False,
    "image_size": 128,
    "part_type": 'screw',
    "max_img_per_defect": 50,
    "num_epochs": 100,
    "lr": 1e-5,
    "batch_size": 2,
    "grad_acc_steps": 2,
    "ckpt_every": 100,
    "log_samples_every": 10,
}
losses = finetune(**finetune_config)

Prepared the dataloader with size 80


An error occurred while trying to fetch facebook/DiT-XL-2-256: facebook/DiT-XL-2-256 does not appear to have a file named diffusion_pytorch_model.safetensors.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
An error occurred while trying to fetch facebook/DiT-XL-2-256: facebook/DiT-XL-2-256 does not appear to have a file named diffusion_pytorch_model.safetensors.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.


Training epoch 0


  0%|          | 0/80 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 32.00 MiB. GPU 0 has a total capacity of 22.16 GiB of which 9.38 MiB is free. Process 4652 has 22.15 GiB memory in use. Of the allocated memory 21.71 GiB is allocated by PyTorch, and 206.10 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)