<a href="https://colab.research.google.com/github/alif-munim/computer-vision/blob/main/imagen_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install imagen-pytorch

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting imagen-pytorch
  Downloading imagen_pytorch-1.21.5-py3-none-any.whl (62 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.1/62.1 KB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate
  Downloading accelerate-0.16.0-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.7/199.7 KB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
Collecting ema-pytorch>=0.0.3
  Downloading ema_pytorch-0.1.4-py3-none-any.whl (4.2 kB)
Collecting sentencepiece
  Downloading sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m30.4 MB/s[0m eta [36m0:00:00[0m
Collecting kornia
  Downloading kornia-0.6.9-py2.py3-none-any.whl (569 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m569.1/569.1 KB[0

In [None]:
import os
import time
from PIL import Image

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms, datasets
from imagen_pytorch import Unet, Imagen, ImagenTrainer

Downloading (…)lve/main/config.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

In [None]:
class MnistCond(Dataset):
    """MNIST wrapped for class-conditional Imagen training.

    Each item is ``(img, hot_label)``:
      * ``img`` -- the digit as a tensor, resized with ``transforms.Resize(image_size)``
        and repeated to 3 channels.
      * ``hot_label`` -- the class one-hot encoded as float32 with shape
        ``(1, NUM_CLASSES)``, i.e. a single conditioning embedding per image.

    Args:
        train: use the MNIST training split if True, the test split otherwise.
        root: directory where torchvision stores / downloads MNIST.
        image_size: size passed to ``transforms.Resize`` (MNIST is square, so the
            result is ``image_size x image_size``).
        download: download MNIST into ``root`` if it is not already present.
    """

    NUM_CLASSES = 10  # MNIST digits 0-9

    def __init__(self, train=True, root="data", image_size=32, download=True) -> None:
        super().__init__()
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(image_size),
        ])
        self.mnist = datasets.MNIST(root=root, train=train, download=download, transform=self.transform)

    def __len__(self):
        return len(self.mnist)

    def __getitem__(self, i):
        img, label = self.mnist[i]
        # Single grayscale channel -> 3 identical channels
        # (presumably because the Unet is configured for RGB input -- confirm).
        img = img.repeat(3, 1, 1)
        hot_label = torch.nn.functional.one_hot(
            torch.tensor(label), num_classes=self.NUM_CLASSES
        ).float()
        # Leading axis of length 1: one conditioning embedding per image.
        return img, hot_label.unsqueeze(0)

In [None]:
def delay2str(t):
    """Render a duration in seconds as a compact string such as ``"1d 2h 3m 4s"``.

    Fractional seconds are truncated via ``int``. Zero-valued day/hour/minute parts
    are left out, while the seconds part is always present (``0`` -> ``"0s"``).
    """
    total_secs = int(t)
    total_mins, secs = divmod(total_secs, 60)
    total_hours, mins = divmod(total_mins, 60)
    days, hours = divmod(total_hours, 24)

    parts = [f"{secs}s"]
    # Prepend larger units in increasing order so the result reads largest-first.
    for amount, unit in ((mins, "m"), (hours, "h"), (days, "d")):
        if amount:
            parts.insert(0, f"{amount}{unit}")
    return " ".join(parts)

In [None]:
experiment_path = os.path.join("experiments", "conditional_mnist_diffusion")

images_path = os.path.join(experiment_path, "images")
os.makedirs(images_path, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NUM_CLASSES = 10  # MNIST digits 0-9; also the width of each conditioning embedding

# Conditioning embeddings used for sampling: one one-hot row per digit, shaped
# (NUM_CLASSES, 1, NUM_CLASSES) = (batch, max_text_len, text_embed_dim).
emb_test = torch.nn.functional.one_hot(torch.arange(NUM_CLASSES), num_classes=NUM_CLASSES).float()[:, None, :]

# Define model
unet = Unet(
    dim = 128,
    dim_mults = (1, 2, 4),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True), # type: ignore
    layer_cross_attns = (False, True, True), # type: ignore
    max_text_len = 1, # maximum number of embeddings per image (one one-hot label each)
)

imagen = Imagen(
    unets = unet,
    image_sizes = 32,               # matches the Resize(32) applied by MnistCond
    text_embed_dim = NUM_CLASSES,   # dimension of the one-hot embeddings
)

trainer = ImagenTrainer(
    imagen = imagen,
).to(device)

# To resume training from a checkpoint, uncomment and point at a saved file:
# trainer.load("path/to/checkpoint.pt")

# Define dataset
trainer.add_train_dataset(MnistCond(train=True),  batch_size = 4)
trainer.add_valid_dataset(MnistCond(train=False), batch_size = 4)

In [None]:
# Training variables
start_time = time.time()
avg_loss = 1.0         # exponential moving average (EMA) of the training loss, seeded high
w_avg = 0.99           # EMA decay factor
target_loss = 0.005    # stop once the smoothed loss falls below this...
max_steps = 10_000     # ...or at this step count: the smoothed loss can plateau above target_loss
eval_every = 500       # validate and save a sample strip every `eval_every` steps
n_valid_batches = 10   # validation batches averaged per evaluation
loss = float("nan")    # defined up front so the final report works even if no step runs


def step_count():
    """Current value of the trainer's step counter (single-unet setup)."""
    return trainer.steps.item()  # type: ignore


def status_line(loss, avg_loss):
    """Progress text shared by the live (carriage-return) and periodic reports."""
    return (f'Step: {step_count():<6} | Loss: {loss:<6.4f} Avg Loss: {avg_loss:<6.4f} '
            f'| {delay2str(time.time() - start_time):<10}')


def mean_valid_loss():
    """Mean of `trainer.valid_step` over `n_valid_batches` validation batches."""
    return np.mean([trainer.valid_step(unet_number = 1) for _ in range(n_valid_batches)])


def save_class_samples(path):
    """Sample one image per row of `emb_test` (one per digit), save them side by side, return the strip."""
    samples = trainer.sample(batch_size = len(emb_test), return_pil_images = True,
                             text_embeds=emb_test, cond_scale=3.)  # returns List[Image]
    strip = np.concatenate([np.array(img) for img in samples], axis=1)
    Image.fromarray(strip).save(path)
    return strip


# Train
print(f"Started training with target loss of {target_loss}")
while avg_loss > target_loss and step_count() < max_steps:
    loss = trainer.train_step(unet_number = 1)
    avg_loss = w_avg * avg_loss + (1 - w_avg) * loss

    print(status_line(loss, avg_loss), end='\r')

    if step_count() % eval_every == 0:
        valid_loss = mean_valid_loss()
        print(f'{status_line(loss, avg_loss)} | Valid Loss: {valid_loss:<8.4f}')
        save_class_samples(os.path.join(images_path, f"sample-{str(step_count()).zfill(10)}.png"))

if avg_loss > target_loss:
    # The loop only exits with avg_loss above target when the step cap was hit.
    print(f"\nReached max_steps={max_steps} before the target loss of {target_loss}")

# Final validation loss
valid_loss = mean_valid_loss()
print(f'{status_line(loss, avg_loss)} | Valid Loss: {valid_loss:<8.4f}')

# Generate one image per class
images = save_class_samples(os.path.join(experiment_path, "final_sample.png"))

# Save model
trainer.save(os.path.join(experiment_path, "trained_mnist.pt"))

print("Done!")

Started training with target loss of 0.005
Step: 500    | Loss: 0.0124 Avg Loss: 0.0229 | 2m 0s      | Valid Loss: 0.0128  


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

Step: 1000   | Loss: 0.0120 Avg Loss: 0.0106 | 6m 10s     | Valid Loss: 0.0099  


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

Step: 1500   | Loss: 0.0101 Avg Loss: 0.0091 | 10m 19s    | Valid Loss: 0.0079  


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

Step: 2000   | Loss: 0.0069 Avg Loss: 0.0085 | 14m 17s    | Valid Loss: 0.0084  


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

Step: 2500   | Loss: 0.0080 Avg Loss: 0.0079 | 18m 15s    | Valid Loss: 0.0076  


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

Step: 3000   | Loss: 0.0089 Avg Loss: 0.0072 | 22m 14s    | Valid Loss: 0.0067  


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]

Step: 3500   | Loss: 0.0080 Avg Loss: 0.0071 | 26m 10s    | Valid Loss: 0.0071  


0it [00:00, ?it/s]

sampling loop time step:   0%|          | 0/1000 [00:00<?, ?it/s]



KeyboardInterrupt: ignored

In [None]:
# Mount Google Drive inside the Colab VM (Colab-only API).
from google.colab import drive

drive_root = '/content/gdrive'
drive.mount(drive_root, force_remount=True)

# Target folder on the mounted Drive.
save_path = f'{drive_root}/My Drive/Research'

Mounted at /content/gdrive


In [None]:
# Write the checkpoint into the mounted Drive folder (`save_path`) rather than the
# ephemeral local working directory, and name it from the actual step counter
# instead of a hard-coded step number.
os.makedirs(save_path, exist_ok=True)
trainer.save(os.path.join(save_path, f"mnist_{trainer.steps.item()}.ckpt"))  # type: ignore

checkpoint saved to mnist_3500.ckpt


In [None]:
! pip install numba

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# Release the GPU memory held by the training objects.
import gc
from numba import cuda

# Drop the globals that reference the model and its wrappers so their CUDA tensors
# can be collected. `= None` (rather than `del`) keeps the cell safe to re-run.
trainer = None
imagen = None
unet = None
gc.collect()

# Return PyTorch's cached-but-unused CUDA blocks to the driver.
# (torch.no_grad() has no effect on empty_cache, so it is not needed here.)
torch.cuda.empty_cache()

# Hard reset of the current CUDA device via numba. Bound to `cuda_device` so the
# torch.device stored in `device` by the setup cell is not overwritten.
# NOTE(review): resetting the device underneath PyTorch likely invalidates its CUDA
# context for this process; restarting the runtime is the safer way to get a clean GPU.
cuda_device = cuda.get_current_device()
cuda_device.reset()

In [None]:
# Inspect GPU memory usage (e.g. to check whether the previous cell freed it).
! nvidia-smi

Wed Feb  8 18:51:50 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   61C    P8    11W /  70W |      3MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces