In [32]:
from diffusers import UNet2DModel
from matplotlib import pyplot as plt
from PIL import Image

from torch.utils.data import TensorDataset, DataLoader, random_split
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import numpy as np

import torchvision
from torchvision import transforms

import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

In [33]:
# Number of CUDA devices visible to this process; the launch cell uses the
# same call to decide how many worker processes to spawn.
torch.cuda.device_count()

2

In [34]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [35]:
def ddp_setup(rank: int, world_size: int):
  """Bind this process to a GPU and join the NCCL process group.

  Args:
      rank: Unique identifier of each process; also used as the CUDA device
          index for this process.
      world_size: Total number of processes
  """
  # Rendezvous endpoint for the process group (single machine).
  os.environ.update(MASTER_ADDR="localhost", MASTER_PORT="12355")
  # Select the GPU before the process group is initialised.
  torch.cuda.set_device(rank)
  init_process_group(world_size=world_size, rank=rank, backend="nccl")

## CIFAR Dataset

In [36]:
# torchvision-style pipeline: ToTensor scales to [0, 1] and moves channels
# first, Normalize(0.5, 0.5) then maps to [-1, 1]. The same transformation is
# applied in bulk to the whole array further down.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Load the data, keep a single class, and wrap it in a dataset.
cifar_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False)

# CIFAR-10 label 5 is the "dog" class.
dog_images = cifar_dataset.data[np.array(cifar_dataset.targets) == 5]

# `cifar_dataset.data` holds raw uint8 pixels laid out as (N, H, W, C).
# The UNet consumes channels-first input and the diffusion noise is
# unit-variance, so convert to float (N, C, H, W) in [-1, 1] -- the bulk
# equivalent of running `preprocess` on every image.
dogs = torch.from_numpy(dog_images).permute(0, 3, 1, 2).contiguous().float().div(255.0)
dogs = (dogs - 0.5) / 0.5

dataset = TensorDataset(dogs)

## Forward pass

In [37]:
# Linear variance schedule: `num_timesteps` betas evenly spaced between
# beta_min and beta_max.
num_timesteps = 1000
beta_min, beta_max = 1e-4, 2e-2

beta_samples = torch.linspace(beta_min, beta_max, num_timesteps, device=device)

# Running product of (1 - beta_t); entry t is the cumulative signal-retention
# factor that forward_pass looks up per timestep.
alphas = (1 - beta_samples).cumprod(dim=0)

In [38]:
def forward_pass(images, timesteps):
  """Noise a batch of images according to the diffusion schedule.

  Each image is scaled by sqrt(alphas[t]) and mixed with Gaussian noise of
  standard deviation sqrt(1 - alphas[t]), where t is that image's timestep.

  Args:
      images: Batch of clean images; the broadcasting below expects a 4-D
          tensor with the batch dimension first.
      timesteps: 1-D integer tensor with one index into `alphas` per image.

  Returns:
      Tuple ``(perturbed, noise)``: the noised images and the noise that was
      added, both on the same device as the moved `images`.
  """
  images = images.to(device)
  # `device` may be an un-indexed "cuda"; after the move, `images.device` is
  # the concrete device every other tensor has to live on.
  target = images.device
  timesteps = timesteps.to(target)
  # `alphas` is created at module level and can sit on a different GPU than
  # this process's batch. Indexing across devices raises, so move the lookup
  # table first instead of moving the result afterwards.
  alpha_timesteps = alphas.to(target)[timesteps]
  sqrt_alphas = torch.sqrt(alpha_timesteps) # to scale the image

  sqrt_one_minus_alphas = torch.sqrt(1 - alpha_timesteps) # std of noise

  # sample our noise
  noise = torch.randn_like(images)

  # [:, None, None, None] broadcasts the per-image scalars over C, H, W.
  return noise * sqrt_one_minus_alphas[:, None, None, None] + images * sqrt_alphas[:, None, None, None], noise



## Main

In [39]:
def train(model,data_loader,criterion,optimizer, epochs, rank):
  """Train `model` to predict the noise added by `forward_pass`.

  Args:
      model: Network called as ``model(noisy_images, timesteps)``; its output
          must expose a ``.sample`` attribute.
      data_loader: Loader yielding image batches, either a bare tensor or a
          list/tuple whose first element is the image tensor (the form a
          TensorDataset produces).
      criterion: Loss comparing predicted noise with the true noise.
      optimizer: Optimizer updating `model`'s parameters.
      epochs: Number of passes over `data_loader`.
      rank: Device (index) that batches and timesteps are placed on.
  """
  # Nothing in this loop switches to eval mode, so set train mode once.
  model.train()

  for epoch in tqdm(range(epochs)):
    # DistributedSampler only reshuffles when told which epoch it is in.
    # The method is `set_epoch`; samplers without it are left alone.
    sampler = getattr(data_loader, "sampler", None)
    if hasattr(sampler, "set_epoch"):
      sampler.set_epoch(epoch)

    for batch_idx, inputs in enumerate(data_loader):
      # A TensorDataset batch arrives as a list with one tensor; unwrap it so
      # `.to()` works and `len(inputs)` is the batch size, not 1.
      if isinstance(inputs, (list, tuple)):
        inputs = inputs[0]
      inputs = inputs.to(rank)

      # sample some random timesteps
      timesteps = torch.randint(low=0, high=num_timesteps, size=[len(inputs)], device=rank)
      peturbed, noise = forward_pass(inputs, timesteps)

      y = model(peturbed, timesteps.float()).sample

      optimizer.zero_grad()
      loss = criterion(y, noise)
      loss.backward()
      optimizer.step()

      tqdm.write(f"Batch {batch_idx + 1}, Loss {loss.item()}", end= "\r")


In [40]:
def main(rank, world_size):
  """Per-process entry point: set up DDP, build the model and loader, train.

  Args:
      rank: Index of this process; also used as its CUDA device.
      world_size: Total number of processes taking part in training.
  """
  ddp_setup(rank, world_size)

  # Setup succeeded, so the process group must be torn down even if training
  # raises; otherwise the worker exits with the group still open.
  try:
    epochs = 100

    model = UNet2DModel(sample_size=32, block_out_channels=[128, 128, 256, 512]).to(rank)
    model = DDP(nn.SyncBatchNorm.convert_sync_batchnorm(model), device_ids=[rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # The sampler owns sharding and shuffling, so the loader's own `shuffle`
    # stays off.
    dataloader = DataLoader(dataset, 32, pin_memory=True, shuffle=False, sampler=DistributedSampler(dataset))

    train(model, dataloader, nn.MSELoss(), optimizer, epochs, rank)
  finally:
    destroy_process_group()

In [43]:
# One worker process per visible GPU.
world_size = torch.cuda.device_count()

# NOTE: mp.spawn starts fresh interpreters that must import `main` by name.
# Functions defined in notebook cells are not importable from the children
# ("Can't get attribute 'main' on <module '__main__'>"), so this code has to
# live in a .py module/script for the launch to succeed.
# The guard keeps spawned children from re-running the launch when they
# import the main module.
if __name__ == "__main__":
  # `args` must be a tuple: `(world_size)` is just an int, and mp.spawn
  # unpacks `args` after the process index when calling `main`.
  mp.spawn(main, args=(world_size,), nprocs=world_size)

Traceback (most recent call last):
  File "<string>", line 1, in <module>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.10.2/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.10.2/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'main' on <module '__main__' (built-in)>
  File "/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.10.2/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/cvmfs/soft.computecanada.ca/easybuild/software/2020/avx2/Core/python/3.10.2/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
Attri

ProcessExitedException: process 1 terminated with exit code 1