In [98]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from tqdm import tqdm
from torch import optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from PIL import Image
import torchvision
from torch.cuda.amp import GradScaler, autocast
import copy

In [99]:
class DoubleConv(nn.Module):
    def __init__(self,in_ch,out_ch,mid_ch=None,resid=False):
        super(DoubleConv,self).__init__()

        self.res = resid

        if not mid_ch:
            mid_ch = out_ch

        self.dconv = nn.Sequential(nn.Conv2d(in_ch,mid_ch,kernel_size=3,padding=1,bias=False),
                                   nn.GroupNorm(1,mid_ch),
                                   nn.GELU(),
                                   nn.Conv2d(mid_ch,out_ch,kernel_size=3,padding=1,bias=False),
                                   nn.GroupNorm(1,out_ch))

    def forward(self,x):
        if self.res:
            return F.gelu(x + self.dconv(x))
        else:
            return self.dconv(x)

class SelfAttention(nn.Module):
    def __init__(self,ch,heads=2):
        super(SelfAttention,self).__init__()

        self.channels = ch
        #self.size = size

        self.multihead = nn.MultiheadAttention(ch,heads,batch_first=True)
        self.norm = nn.LayerNorm([ch])
        self.ffs = nn.Sequential(nn.LayerNorm([ch]),
                                 nn.Linear(ch,ch),
                                 nn.GELU(),
                                 nn.Linear(ch,ch))

    def forward(self,x):

        #print('inner x',x.size())

        bs,chn,h,w = x.size()
        x = x.reshape(bs,chn,h*w)

        #print("breaKDOWN",x.size())
        #x = x.transpose(1,2)
        x = x.swapaxes(1, 2)
        #print("breakd2",x.size())

        normx = self.norm(x)

        attn,_ = self.multihead(normx,normx,normx)
        attn = attn + x
        #print('attn',attn.size())

        attn = self.ffs(attn) + attn
        #attn = attn.transpose(1,2)
        attn = attn.swapaxes(2, 1)

        attn = attn.reshape(bs,chn,h,w)

        return attn


class Down(nn.Module):
    def __init__(self,in_ch,out_ch,edim=256):
        super(Down,self).__init__()
        self.mconv = nn.Sequential(nn.MaxPool2d(2),
                                   DoubleConv(in_ch,in_ch,resid=True),
                                   DoubleConv(in_ch,out_ch))

        self.embed = nn.Sequential(nn.SiLU(),
                                   nn.Linear(edim, out_ch))

    def forward(self, x, t):
        x = self.mconv(x)

        emb = self.embed(t)[:,:,None,None].repeat(1,1,x.shape[-2],x.shape[-1])
        #print(x.size(),emb.size())
        out = x + emb
        return out

class Up(nn.Module):
    def __init__(self,in_ch,out_ch,edim=256):
        super(Up,self).__init__()

        self.up = nn.Upsample(scale_factor=2,mode='bilinear',align_corners=True)
        self.conv = nn.Sequential(DoubleConv(in_ch, in_ch, resid=True),
                                  DoubleConv(in_ch, out_ch, in_ch//2))

        self.embed = nn.Sequential(nn.SiLU(),
                                   nn.Linear(edim, out_ch))

    def forward(self,x,skip,t):
        x = self.up(x)
        x = torch.cat([skip,x],dim=1)
        x = self.conv(x)
        emb = self.embed(t)[:,:,None,None].repeat(1,1,x.shape[-2],x.shape[-1])
        return x + emb


class UNet(nn.Module):
    def __init__(self,cin=3,cout=3,nc=None,time=256):
        super(UNet,self).__init__()

        self.time = time
        self.edim = time

        self.dc = DoubleConv(cin,64)

        self.d1 = Down(64,128,edim=self.edim)
        self.a1 = SelfAttention(ch=128)

        self.d2 = Down(128,256,edim=self.edim)
        self.a2 = SelfAttention(ch=256)

        self.d3 = Down(256,256,edim=self.edim)
        self.a3 = SelfAttention(ch=256)

        self.bn1 = DoubleConv(256,512)
        self.bn2 = DoubleConv(512,512)
        self.bn3 = DoubleConv(512,256)

        self.u1 = Up(512,128,edim=self.edim)
        self.a4 = SelfAttention(ch=128)

        self.u2 = Up(256,64,edim=self.edim)
        self.a5 = SelfAttention(ch=64)

        self.u3 = Up(128,64,edim=self.edim)
        self.a6 = SelfAttention(ch=64)

        self.outc = nn.Conv2d(64,cout,kernel_size=1)

        if nc is not None:
            self.label_emb = nn.Embedding(nc,self.time)

    def posencode(self,t,chn):
        device1 = 'cuda' if torch.cuda.is_available() else 'cpu'
        inv_freq = 1.0/(10000 ** (torch.arange(0,chn,2,device=device1).float()/chn))

        enc_a = torch.sin(t.repeat(1,chn//2)*inv_freq)
        enc_b = torch.cos(t.repeat(1,chn//2)*inv_freq)

        pos = torch.cat([enc_a,enc_b],dim=-1)
        return pos

    def forward(self,x,t,y=None):

        t = t.unsqueeze(-1).type(torch.float)
        t = self.posencode(t,self.time)

        if y is not None:
            t += self.label_emb(y)

        x1 = self.dc(x)
        #print('x1',x1.size())

        x2 = self.d1(x1,t)

        #print('x2',x2.size())
        x2 = self.a1(x2)
        #print('x2 pt 2',x2.size())

        x3 = self.d2(x2,t)
        x3 = self.a2(x3)
        x4 = self.d3(x3,t)
        x4 = self.a3(x4)

        x4 = self.bn1(x4)
        x4 = self.bn2(x4)
        x4 = self.bn3(x4)

        #print(x4.size())
        #print(x3.size())

        x = self.u1(x4,x3,t)
        x = self.a4(x)
        x = self.u2(x,x2,t)
        x = self.a5(x)
        x = self.u3(x,x1,t)
        x = self.a6(x)

        out = self.outc(x)

        return out

In [100]:
class EMA:
    def __init__(self, beta):
        super(EMA, self).__init__()
        self.beta = beta
        self.step = 0

    def reset_parameters(self,ema_mod,mod):
        ema_mod.load_state_dict(mod.state_dict())

    def step_ema(self,ema_mod,mod,start=2000):
        if self.step<start:
            self.reset_parameters(ema_mod, mod)
            self.step += 1
            return
        self.update_model(ema_mod=ema_mod,mod=mod)
        self.step += 1

    def update_model(self,ema_mod,mod):
        for mod_p,ema_p in zip(mod.parameters(),ema_mod.parameters()):
            new_w,old_w =  mod_p.data, ema_p.data
            ema_p.data = self.update_avg(old_w,new_w)

    def update_avg(self,orig,new):
        if orig is None:
          return new
        out = orig * self.beta + (1+self.beta) * new
        return out

In [101]:
import warnings
warnings.filterwarnings("ignore")

class Diffusion:
    def __init__(self, beta_s = 1e-4, beta_e = 0.02, noise_steps = 1000, img_dim = 256, device="cuda") -> None:
        self.beta_s = beta_s
        self.beta_e = beta_e
        self.nsteps = noise_steps
        self.img_size = img_dim
        self.device = device
        #;locals
        self.beta = self.Scheduler().to(self.device)
        self.alpha = 1 - self.beta
        self.alpha_prod = torch.cumprod(self.alpha, dim=0)

    def Scheduler(self):
        return torch.linspace(self.beta_s, self.beta_e, self.nsteps)

    def AddNoise(self, x, t):
        term1 = torch.sqrt(self.alpha_prod[t])[:,None,None,None]
        term2 = torch.sqrt(1. - self.alpha_prod[t])[:,None,None,None]
        eps = torch.randn_like(x)
        res =  term1 * x + term2 * eps
        return res, eps

    def Timesteps(self,n):
        return torch.randint(low=1, high=self.nsteps,size=(n,))

    def Sampler(self, model, n, label):
        #ema_model = model.eval().requires_grad_(False)
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            for i in tqdm(reversed(range(1, self.nsteps)),position=0):
                t = (torch.ones(n)*i).long().to(self.device)
                e_pred = model(x, t, label)
                alpha = self.alpha[t][:,None,None,None]
                alpha_prod = self.alpha_prod[t][:,None,None,None]
                beta = self.beta[t][:,None,None,None]

                if i>1:
                    z = torch.randn_like(x)
                else:
                    z = torch.zeros_like(x)

                final_x = 1/torch.sqrt(alpha) * (x - ((1-alpha) / torch.sqrt(1-alpha_prod))*e_pred) + torch.sqrt(beta) * z

        model.train()

        final_x = (final_x.clamp(-1,1) + 1) / 2
        final_x = (final_x*255).type(torch.uint8)
        return final_x

In [102]:
def plot_imgs(image_set):
    plt.figure(32,32)
    plt.imshow(torch.cat([torch.cat([img for img in image_set.cpu().detach()],dim=-1),],dim=-2).permute(1,2,0).cpu())
    plt.show()

def save_imgs(image_set,save_pth):
    grid_im = torchvision.utils.make_grid(image_set,)
    arr = grid_im.permute(1,2,0).cpu().numpy()
    img = Image.fromarray(arr)
    img.save(save_pth)


def get_data(dpath,batch_size):

    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(32),  # args.image_size + 1/4 *args.image_size
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,),(0.5,))
    ])

    #dataset = torchvision.datasets.ImageFolder(dpath, transform=transforms)
    dataset = torchvision.datasets.CIFAR10(dpath,download=True,transform=transforms)

    #to_filter = list(range(0,len(dataset),10))
    #new_data = torch.utils.data.Subset(dataset,to_filter)

    dataloader = DataLoader(dataset, batch_size = batch_size, shuffle=True)


    #print(len(dataloader))

    return dataloader

In [103]:
def train(epochs):

    scaler = GradScaler()

    optimizer = torch.optim.Adam(model.parameters(),lr=lr)
    mse_loss = nn.MSELoss()

    ema = EMA(0.995)
    print(ema)

    model_copy = copy.deepcopy(model).eval().requires_grad_(False)
    model_copy.to(device)

    for e in tqdm(range(epochs)):
        for b, (image_set,label_set) in tqdm(enumerate(dataloader)):
            x = image_set.to(device)
            time_s = diff.Timesteps(x.shape[0]).to(device)
            x_t, eps = diff.AddNoise(x,time_s)
            y = label_set.to(device)

            optimizer.zero_grad()

            with torch.autocast(device_type=device,dtype=torch.float16):
                #print("time_s",time_s.size())

                n_out = model(x_t,time_s,y)
                loss = mse_loss(eps,n_out) + 1e-5

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            ema.step_ema(ema_mod=model_copy,mod=model)

            #n_out = model(x_t,time_s,y)
            #loss = mse_loss(eps,n_out)


            #loss.backward()
            #optimizer.step()

        print(f"Epoch:{e+1}, Loss:{loss}")
        if e % 50 == 0 or e==epochs-1:
          torch.save(model, f'/content/drive/MyDrive/csc2231/models/base_model/model/ckpt_{e}')
          torch.save(model_copy, f'/content/drive/MyDrive/csc2231/models/base_model/ema/ckpt_{e}')
        inference(e=e+1, model=model , image_set=image_set, label=y, use_ema = False)
        inference(e=e+1, model=model_copy , image_set=image_set, label=y, use_ema=True)


def inference(e, model, image_set, label, use_ema):
    generated = diff.Sampler(model=model,n=image_set.shape[0],label=label)
    ext = str(e) + '.png'
    if use_ema:
      to_save = '/content/drive/MyDrive/csc2231/outputs/base_model/ema/'+ ext
    else:
      to_save = '/content/drive/MyDrive/csc2231/outputs/base_model/model/'+ ext
    save_imgs(generated, to_save)
    #torch.save()

In [None]:
device = 'cuda'
#if torch.cuda.is_available() else 'cpu'
model = UNet(time=256,nc=10)
model = model.to(device)
diff = Diffusion(noise_steps=1000)

lr = 1e-4
bs = 4

dpath = '/content/drive/MyDrive/csc2231/data'
dataloader = get_data(dpath,bs)
ld = len(dataloader)

train(5)
#torch.save(model, f'/content/drive/MyDrive/csc2231/models/base_model/ckpt_final')

Files already downloaded and verified
<__main__.EMA object at 0x7cf3646df160>


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
2501it [02:15, 18.16it/s][A
2503it [02:15, 18.33it/s][A
2505it [02:15, 18.40it/s][A
2507it [02:15, 18.48it/s][A
2509it [02:15, 18.33it/s][A
2511it [02:15, 18.19it/s][A
2513it [02:15, 17.90it/s][A
2515it [02:16, 17.97it/s][A
2517it [02:16, 17.93it/s][A
2519it [02:16, 18.07it/s][A
2521it [02:16, 18.08it/s][A
2523it [02:16, 18.07it/s][A
2525it [02:16, 18.09it/s][A
2527it [02:16, 17.91it/s][A
2529it [02:16, 18.02it/s][A
2531it [02:16, 18.06it/s][A
2533it [02:17, 18.11it/s][A
2535it [02:17, 18.01it/s][A
2537it [02:17, 17.92it/s][A
2539it [02:17, 17.87it/s][A
2541it [02:17, 17.89it/s][A
2543it [02:17, 18.11it/s][A
2545it [02:17, 17.91it/s][A
2547it [02:17, 17.70it/s][A
2549it [02:17, 17.56it/s][A
2551it [02:18, 17.68it/s][A
2553it [02:18, 17.88it/s][A
2555it [02:18, 17.96it/s][A
2557it [02:18, 18.01it/s][A
2559it [02:18, 18.14it/s][A
2561it [02:18, 18.01it/s][A
2563it [02:18, 17.74it/s][A
2565it 

Epoch:1, Loss:0.039230719208717346


0it [00:00, ?it/s]
  0%|          | 0/5 [12:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 8.00 GiB. GPU 0 has a total capacity of 15.77 GiB of which 2.35 GiB is free. Process 66874 has 13.42 GiB memory in use. Of the allocated memory 12.52 GiB is allocated by PyTorch, and 513.69 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)