<a href="https://colab.research.google.com/github/Kris57880/DiffusionModel/blob/main/diffusion_baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [15]:
!pip install diffusers==0.16.1
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
import os
from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
import csv

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [16]:
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [17]:
# copy the dataset to your own drive : slow and need to wait drive sync 
# if os.path.exists("/content/drive/My Drive/Colab_Notebooks/Diffusion/Dataset")== False:
#   !unzip "/content/drive/Shareddrives/Stable diffusion/Dataset/GOPRO_128.zip" -d "/content/drive/My Drive/Colab_Notebooks/Diffusion/Dataset"
# os.chdir(r'/content/drive/My Drive/Colab Notebooks/Diffusion')
# !ls

In [18]:
# cont, if need updating dataset run this 
#!unzip "/content/drive/Shareddrives/Stable diffusion/Dataset/GOPRO_128.zip" -d "/content/drive/My Drive/Colab_Notebooks/Diffusion/Dataset"

In [19]:
#unzip to local file <- faster
!unzip "/content/drive/Shareddrives/Stable diffusion/Dataset/GOPRO_128.zip" -d "/content/sample_data/Dataset"
#os.chdir(r'/content/sample_data')
!ls

Archive:  /content/drive/Shareddrives/Stable diffusion/Dataset/GOPRO_128.zip
replace /content/sample_data/Dataset/train/GOPR0372_07_00/blur/000047_0.png? [y]es, [n]o, [A]ll, [N]one, [r]ename: N
drive  sample_data


In [20]:
#check GPU 
!nvidia-smi 

Sun May  7 17:20:03 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   45C    P8     9W /  70W |      3MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [21]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "cpu"
)
print('Using device:', device)


Using device: cuda


In [22]:
image_size = 128
batch_size = 32
num_workers = 4
n_epochs = 200
Dataset_dir = '/content/sample_data/Dataset'
model_dir = '/content/drive/MyDrive/Colab Notebooks/Diffusion/saved_model'
result_dir = '/content/drive/MyDrive/Colab Notebooks/Diffusion/result'
log_file = '/content/drive/MyDrive/Colab Notebooks/Diffusion/saved_model/training_log.csv'

In [23]:
preprocess = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])


In [24]:
class GOPRO(Dataset):
    def __init__(self, mode='train', transform=preprocess):
        assert mode == 'train' or mode == 'test'
        self.mode = mode
        # self.args = args
        self.transform = transform
        self.seed_is_set = False
        self.dirs = []
        self.dirs_gt = []
        self.index = 0
        self.d = 0
        if mode == 'train':
            self.ordered = False
        else:
            self.ordered = True

        for d1 in os.listdir(os.path.join(Dataset_dir, mode)):
            for d3 in os.listdir(os.path.join(Dataset_dir, mode, d1, 'blur_gamma')):
                if d3.endswith('.png'):
                    self.dirs.append(os.path.join(
                        Dataset_dir, mode, d1, 'blur_gamma', d3))
                    self.dirs_gt.append(os.path.join(
                        Dataset_dir, mode, d1, 'sharp', d3))

    def set_seed(self, seed):
        if not self.seed_is_set:
            self.seed_is_set = True
            np.random.seed(seed)

    def __len__(self):
        return len(self.dirs)

    def get_seq(self):
        if self.ordered:
            self.cur_dirs = self.dirs[self.d]
            self.cur_dirs_gt = self.dirs_gt[self.d]
            if self.d == len(self.dirs) - 1:
                self.d = 0
            else:
                self.d += 1
        else:
            random_idx = np.random.randint(len(self.dirs))
            self.cur_dirs = self.dirs[random_idx]
            self.cur_dirs_gt = self.dirs_gt[random_idx]

        # image_seq = []
        # image_seq_gt = []
        im = self.transform(Image.open(self.cur_dirs)).reshape(
            (3, image_size, image_size))
        # image_seq.append(im)
        # image_seq = torch.Tensor(np.concatenate(image_seq, axis=0))
        im_gt = self.transform(Image.open(self.cur_dirs_gt)).reshape(
            (3, image_size, image_size))
        # image_seq_gt.append(im_gt)
        # image_seq_gt = torch.Tensor(np.concatenate(image_seq_gt, axis=0))
        return im, im_gt

    def __getitem__(self, index):
        self.set_seed(index)
        seq, gt = self.get_seq()
        return seq, gt


In [25]:
class ClassConditionedUnet(nn.Module):
    def __init__(self):
        super().__init__()
        
        #image embedding  3->1 channel
        # self.img_emb = nn.Conv2d(3,1, kernel_size=(1,1))

        # Self.model is an unconditional UNet with extra input channels to accept the conditioning information (the class embedding)
        self.model = UNet2DModel(
            sample_size=image_size,           # the target image resolution
            in_channels=3 + 3,  # Additional input channels for class cond.
            out_channels=3,           # the number of output channels
            layers_per_block=2,       # how many ResNet layers to use per UNet block
            block_out_channels=(32, 64, 128, 256),
            down_block_types=(
                "DownBlock2D",
                "DownBlock2D",
                "DownBlock2D",        # a regular ResNet downsampling block
                "AttnDownBlock2D",    # a ResNet downsampling block with spatial self-attention
                # "AttnDownBlock2D",
            ),
            up_block_types=(
                # "AttnUpBlock2D",
                "AttnUpBlock2D",
                "UpBlock2D",      # a ResNet upsampling block with spatial self-attention
                "UpBlock2D",          # a regular ResNet upsampling block
                "UpBlock2D",          
            ),
        )

    # Our forward method now takes the class labels as an additional argument
    def forward(self, x, t, blur_labels):
        bs, ch, w, h = x.shape
        # (bs, 3 + 1, 256, 256)
        # img_cond = self.img_emb(blur_labels) 
        
         # x is shape (bs, 1, 28, 28) and class_cond is now (bs, 4, 28, 28)
        net_input = torch.cat((x, blur_labels), axis=1)
        # class conditioning in right shape to add as additional input channels
        # class_cond = self.class_emb(class_labels) # Map to embedding dinemsion
        # class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
        # x is shape (bs, 1, 28, 28) and class_cond is now (bs, 4, 28, 28)
        # Feed this to the unet alongside the timestep and return the prediction
        return self.model(net_input, t).sample  # (bs, 3 + 1, 256, 256)



In [26]:
train_dataset = GOPRO(mode='train')
test_dataset = GOPRO(mode='test')
train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
          num_workers=num_workers, shuffle=True, drop_last=True, pin_memory=True)

print(f'train data num: {len(train_dataset)} test data num: {len(test_dataset)}')
print(f'image size: {image_size} batch size: {batch_size} num workers: {num_workers}')

noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000, beta_schedule='scaled_linear')

train data num: 8412 test data num: 4444
image size: 128 batch size: 32 num workers: 4




In [36]:


def load_ckp(checkpoint_fpath, model, optimizer):
    checkpoint = torch.load(checkpoint_fpath)
    # print(checkpoint)
    # model.load_state_dict(checkpoint['state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer'])
    # valid_loss_min = checkpoint['valid_loss_min']
    # return model, optimizer, checkpoint['epoch'], valid_loss_min.item()

latest_model = max([os.path.join(model_dir,f) for f in os.listdir(model_dir) if f.startswith('epoch_')], key=os.path.getctime)
net = ClassConditionedUnet().to(device)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
start_epoch = 0 

if latest_model is not None :
  net.load_state_dict(torch.load(latest_model))
  start_epoch = int(latest_model[:-3].split('/')[-1].split('_')[1])
  print(f'resume from epoch {start_epoch}')



resume from epoch 10


In [37]:
# create the log file if it doesn't exist
if not os.path.exists(log_file):
    with open(log_file, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['Epoch', 'Loss'])


In [None]:

# Our network
print(net)
# Our loss finction
loss_fn = nn.MSELoss()
# The optimizer
print('number of params: ', sum([p.numel() for p in net.parameters()]))

# Keeping a record of the losses for later viewing
epoch_losses = []

net.train()


# The training loop
for epoch in range(start_epoch+1, n_epochs):
    iter_losses = []
    for blur_img, gt in tqdm(train_dataloader):
        # Get some data and prepare the corrupted version
        blur_img = blur_img.to(device) * 2 - 1  # Normalize the data to [-1, 1]
        gt = gt.to(device) * 2 - 1  # Normalize the data to [-1, 1]
        noise = torch.randn_like(gt)
        timesteps = torch.randint(0, 999, (gt.shape[0],)).long().to(device)
        noisy_x = noise_scheduler.add_noise(gt, noise, timesteps)

        # Get the model prediction
        # Note that we pass in the blur_img as condition
        pred = net(noisy_x, timesteps, blur_img)

        # Calculate the loss
        loss = loss_fn(pred, noise)  # How close is the output to the noise

        # Backprop and update the params:
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Store the loss for later
        iter_losses.append(loss.item())

    # Print our the average of the last 100 loss values to get an idea of progress:
    avg_loss = sum(iter_losses[-100:])/100
    print(f'Finished epoch {epoch}. Loss values: {avg_loss:05f}')
    with open(log_file, 'a', newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow([epoch, avg_loss])
    torch.save(net.state_dict(), model_dir+f'/epoch_{epoch}.pt')

torch.save(net.state_dict(), model_dir+f'/final.pt')
# View the loss curve
plt.plot(epoch_losses)
plt.title('Loss curve')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.savefig(result_dir+'/loss_curve.png')
plt.show()

ClassConditionedUnet(
  (model): UNet2DModel(
    (conv_in): Conv2d(6, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (time_proj): Timesteps()
    (time_embedding): TimestepEmbedding(
      (linear_1): Linear(in_features=32, out_features=128, bias=True)
      (act): SiLU()
      (linear_2): Linear(in_features=128, out_features=128, bias=True)
    )
    (down_blocks): ModuleList(
      (0): DownBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 32, eps=1e-05, affine=True)
            (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (time_emb_proj): Linear(in_features=128, out_features=32, bias=True)
            (norm2): GroupNorm(32, 32, eps=1e-05, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): 



  0%|          | 0/262 [00:00<?, ?it/s]

Finished epoch 11. Loss values: 0.048928


  0%|          | 0/262 [00:00<?, ?it/s]

Finished epoch 12. Loss values: 0.048114


  0%|          | 0/262 [00:00<?, ?it/s]

Finished epoch 13. Loss values: 0.048222


  0%|          | 0/262 [00:00<?, ?it/s]

Finished epoch 14. Loss values: 0.048296


  0%|          | 0/262 [00:00<?, ?it/s]

Finished epoch 15. Loss values: 0.047545


  0%|          | 0/262 [00:00<?, ?it/s]

Finished epoch 16. Loss values: 0.045758


  0%|          | 0/262 [00:00<?, ?it/s]

Finished epoch 17. Loss values: 0.043763


  0%|          | 0/262 [00:00<?, ?it/s]

For long time training, preventing from disconnection, follow this instruction 
https://blog.csdn.net/jinniulema/article/details/128994223 