In [1]:
import torch
import torchvision
from torch import nn
from torch.optim import Adam
import torch.nn.functional as F
from torchvision import transforms 
from torch.utils.data import DataLoader
from torchvision.utils import save_image

import signal
signal.signal(signal.SIGINT, signal.SIG_DFL)

import os
os.environ["TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT"] = "0"

import sys
import numpy as np
import math
import matplotlib.pyplot as plt

from dataload import *

In [2]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [3]:
def show_images(datset, num_samples=20, cols=4):
    """ Plots some samples from the dataset """
    plt.figure(figsize=(15,15)) 
    for i, img in enumerate(data):
        if i == num_samples:
            break
        plt.subplot(int(num_samples/cols) + 1, cols, i + 1)
        plt.imshow(img[0])

In [4]:
def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    return torch.linspace(start, end, timesteps)

def get_index_from_list(vals, t, x_shape):
    """ 
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """ 
    Takes an image and a timestep as input and 
    returns the noisy version of it
    """
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + variance
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
    + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)

In [5]:
# Define beta schedule
T = 1000
betas = linear_beta_schedule(timesteps=T)

# Pre-calculate different terms for closed form
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

In [6]:
IMG_SIZE = 512
BATCH_SIZE = 1


def show_tensor_image(image):
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    if len(image.shape) == 4:
        image = image[0, :, :, :] 
    plt.imshow(reverse_transforms(image))

In [7]:
class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp =  nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm = nn.BatchNorm2d(out_ch)
        self.relu  = nn.ReLU()
        
    def forward(self, x, t, ):
        # First Conv
        h = self.bnorm(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions
        time_emb = time_emb[(..., ) + (None, ) * 2]
        # Add time channel
        h = h + time_emb
        # Second Conv
        h = self.bnorm(self.relu(self.conv2(h)))
        # Down or Upsample
        return self.transform(h)

In [8]:
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

In [9]:
class SimpleUnet(nn.Module):
    """
    A simplified variant of the Unet architecture.
    """
    def __init__(self):
        super().__init__()
        image_channels = 128+128
        down_channels = (256, 512, 1024)
        up_channels = (1024, 512, 256)
        out_dim = 1
        time_emb_dim = 32

        # Time embedding
        self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(time_emb_dim),
                nn.Linear(time_emb_dim, time_emb_dim),
                nn.ReLU()
            )
        
        # Initial projection
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Downsample
        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1], \
                                    time_emb_dim) \
                    for i in range(len(down_channels)-1)])
        # Upsample
        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], \
                                        time_emb_dim, up=True) \
                    for i in range(len(up_channels)-1)])

        self.output = nn.Conv2d(up_channels[-1], 128, out_dim)

    def forward(self, x, y_0, timestep): # Latent_nose + latent_non-standard img    
        # Embedd time
        t = self.time_mlp(timestep)
        # concating conditioned image with the noisey imag        
        x = torch.cat((x, y_0), dim=1) # Noise + non-standard img        
        # Initial conv
        x = self.conv0(x)
        # Unet
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            # Add residual x as additional channels
            x = torch.cat((x, residual_x), dim=1)           
            x = up(x, t)
        return self.output(x)

In [10]:
def convrelu(in_channels, out_channels, kernel, padding):
  return nn.Sequential(
    nn.Conv2d(in_channels, out_channels, kernel, padding=padding),
    nn.ReLU(inplace=True),
  )

In [11]:
class ResNetED(nn.Module):
  def __init__(self, n_class=3):
    super().__init__()
    self.base_model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
    self.base_layers = list(self.base_model.children())

    self.layer0 = nn.Sequential(*self.base_layers[:3]) # size=(N, 64, x.H/2, x.W/2)
    self.layer0_1x1 = convrelu(64, 64, 1, 0)
    self.layer1 = nn.Sequential(*self.base_layers[3:5]) # size=(N, 64, x.H/4, x.W/4)
    self.layer1_1x1 = convrelu(64, 64, 1, 0)
    self.layer2 = self.base_layers[5]  # size=(N, 128, x.H/8, x.W/8)
    self.layer2_1x1 = convrelu(128, 128, 1, 0)


    self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    self.conv_original_size1 = convrelu(64, 64, 3, 1)
    self.conv_original_size2 = convrelu(128, 64, 3, 1)

    self.conv_last = nn.Conv2d(64, n_class, 1)

  @torch.no_grad()
  def diffusion_noise_2_img(self, x, noise, t):
    """
    Calls the model to predict the noise in the image and returns 
    the denoised image. 
    Applies noise to this image, if we are not in the last step yet.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
    
    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * noise / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
    
    if t == 0:
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise 


  def forward(self, input, diff_model=None, y_0=None, t=None, device=None):

    noise = None
    noise_pred=None

    layer0 = self.layer0(input)
    layer1 = self.layer1(layer0)
    layer2 = self.layer2(layer1)


    x = self.layer2_1x1(layer2)

    if diff_model :
        #x = lat_diff(x)
        y_lat = self.layer2_1x1(self.layer2(self.layer1(self.layer0(y_0))))
        y_noisy, noise = forward_diffusion_sample(y_lat, t, device)
        #print("My input lat shape ",y_lat.shape, x.shape, t )
        noise_pred = diff_model(y_noisy, x, t)   # this part is for the new model 
        x = self.diffusion_noise_2_img( x, noise_pred, t)

    x = self.upsample(x)
    #x = torch.cat([x, x_original], dim=1)
    x = self.conv_original_size2(x)
    x = self.upsample(x)
    x = self.conv_original_size1(x)
    x = self.upsample(x)
    out = torch.tanh( self.conv_last(x) ), 

    return out, noise, noise_pred

In [12]:
model = SimpleUnet()
print("Num params: ", sum(p.numel() for p in model.parameters()))
model

Num params:  59366304


SimpleUnet(
  (time_mlp): Sequential(
    (0): SinusoidalPositionEmbeddings()
    (1): Linear(in_features=32, out_features=32, bias=True)
    (2): ReLU()
  )
  (conv0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (downs): ModuleList(
    (0): Block(
      (time_mlp): Linear(in_features=32, out_features=512, bias=True)
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (transform): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bnorm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (1): Block(
      (time_mlp): Linear(in_features=32, out_features=1024, bias=True)
      (conv1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (transform): Conv2d(1024, 1024, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (conv2): Con

In [13]:
def get_loss(model, x_0, y_0, t):
    y_noisy, noise = forward_diffusion_sample(y_0, t, device)

    noise_pred = model(y_noisy, x_0.to(device), t)   # this part is for the new model  

    return F.l1_loss(noise, noise_pred) #+ F.l1_loss(x_0.to(device), x_noisy-noise_pred)

@torch.no_grad()
def sample_timestep(x, y_0, t):
    """
    Calls the model to predict the noise in the image and returns 
    the denoised image. 
    Applies noise to this image, if we are not in the last step yet.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)
    
    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, y_0, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)
    
    if t == 0:
        return model_mean
    else:
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

@torch.no_grad()
def sample_plot_image(x_test_0, y_test_0, t, epoch, img_id):
    img_size = IMG_SIZE
    
    img = torch.randn((1, 3, img_size, img_size), device=device) #imgs[0].unsqueeze(0) #
    y_test_0 = y_test_0[0].unsqueeze(0).to(device)  # take a sample from the batch 
    x_test_0 = x_test_0[0].unsqueeze(0).to(device)  # take a sample from the batch 
    #print(img.shape)    
    plt.figure(figsize=(15,15))
    plt.axis('off')
    num_images = 10
    stepsize = int(T/num_images)


    plt.subplot(1, num_images+2, 1)
    show_tensor_image(y_test_0.detach().cpu())

    plt.subplot(1, num_images+2, 2)
    show_tensor_image(x_test_0.detach().cpu())

    for i in range(0,T)[::-1]:
        t = torch.full((1,), i, device=device, dtype=torch.long) # it create a tensor with size 1 with value i
        img = sample_timestep(img, x_test_0, t) # get the dennoised with timestep t=i
        if i % stepsize == 0:
            plt.subplot(1, num_images+2, int(i/stepsize)+3)
            show_tensor_image(img.detach().cpu())

    plt.savefig("results_stanct/validation/"+str(epoch)+'_'+str(img_id)+'.png', format='png', dpi=200)

In [14]:
data_dir=r"C:\Users\dsi224\Documents\PythonFiles\PythonCodesForDiffusionModel\SE\SE"  
kernel_label = {0:"Bl64", 1:"Br40", 2:"B70f", 3:"B31f", 4:"L", 5:"B", 6:"STANDARD", 7:"LUNG"}


global_dataset =  torch.utils.data.DataLoader(\
		GlobalTrainDataset(data_dir, kernel = "all", limit=-1), \
		batch_size=BATCH_SIZE, \
		shuffle=True)

train_dataset =  torch.utils.data.DataLoader(\
		PairTrainDataset(data_dir, kernel_A = "BR40", kernel_B = "BL64", limit=-1), \
		batch_size=BATCH_SIZE, \
		shuffle=True)

test_dataset =  torch.utils.data.DataLoader(\
		PairTrainDataset(data_dir, kernel_A = "BR40", kernel_B = "BL64", limit=5), \
		batch_size=1, \
		shuffle=True)

20
10 10
10 10


In [15]:
train_dir = "."

os.makedirs(train_dir+"/validation/")
os.makedirs(train_dir+"/checkpoint/")
os.makedirs(train_dir+'/checkpoint/latest_global_net.pt')

In [16]:
# TRAINING
train_dir = "."

if not os.path.exists(train_dir):
    os.makedirs(train_dir)
    os.makedirs(train_dir+"validation/")
    os.makedirs(train_dir+"checkpoint/")

	
f_log = open(train_dir+'log.txt', 'w')


###################################### Defining models
#Encoder-decoder model

learning_rate = 1e-4

ED_model = ResNetED().to(device)

#ED_model.load_state_dict(torch.load(train_dir+'/checkpoint/latest_global_net.pt'))


ED_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, ED_model.parameters()), lr=learning_rate)


if True:
	print('In Global model training ...')

	for epoch in range(0, 100):
		for idx, data in enumerate(global_dataset, 0):
			X, label = data  #non-STD, STD, glcm_STD, cls
			X = X.to(device)
			
			out= ED_model(X)[0][0]
			#print(out[0].shape)
			loss =  F.l1_loss(out, X)
		
			# Backpropagation based on the loss
			ED_optimizer.zero_grad()
			loss.backward()
			ED_optimizer.step()
			
			if idx%100==0:
				print(' Global Epoch {} step {}: Loss L1 {}'.format(epoch, idx, loss))
				f_log.write(' Global Epoch {} step {}: Loss L1 {}'.format(epoch, idx, loss))
				f_log.write('\n')			
			#break
			
		if epoch%2==0:
			torch.save(ED_model.state_dict(), train_dir+'/checkpoint/'+str(epoch)+'_global_net.pt')

	
		
		# test the model for model validation	
		for _idx, _data in enumerate(global_dataset, 0):
			xs, label = _data  #non-STD, STD, glcm_STD, cls
			xs = xs.to(device)
			
			xs.requires_grad = False		
			# Feeding a batch of images into the network to obtain the output image, mu, and logVar
			
			with torch.no_grad():	
				out = ED_model(xs)[0][0]
				
			out = out.squeeze().cpu()	
			
			out = torch.cat((out[0],xs.squeeze().cpu()[0]	), axis=1)			
			save_image( (out * 0.5 + 0.5), train_dir+'/validation/'+str(epoch)+kernel_label[label.numpy()[0]]+'.png') #[-1, 1 ==> [0,1]]				
			break # want to test only a single image
		#break	
	torch.save(ED_model.state_dict(), train_dir+'/checkpoint/latest_global_net.pt')	


ED_model.eval()
model.to(device)

optimizer = Adam(model.parameters(), lr=0.001)
epochs = 100 # Try more!

for epoch in range(epochs):
    for step, batch in enumerate(train_dataset):
      optimizer.zero_grad()
      #batch = batch.to(device)#print(b)
      b, c, h, w= batch[0].shape

      t = torch.randint(0, T, (b,), device=device).long()
      #print(batch.shape)    
      #break
		
      out, noise, noise_pred = ED_model( batch[0].to(device), model, batch[1].to(device), t, device)
      loss = F.l1_loss(noise_pred, noise) + F.l1_loss(out[0], batch[1].to(device))
      loss.backward()
      optimizer.step()

      if epoch % 5 == 0 and step == 0:
        print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
        f_log.write(' Diffusion Epoch {} step {}: Loss L1 {}'.format(epoch, idx, loss))
        f_log.write('\n')
				
        t = torch.randint(0, T, (1,), device=device).long()	
		
        for im_id, test_data in enumerate(test_dataset):
            #test_data = test_data.to(device)#print(b)
            #sample_plot_image(test_data[0], test_data[1], t, epoch, im_id)
            out, noise, noise_pred = ED_model( test_data[0].to(device), model, test_data[1].to(device), t, device)			
            out = out[0].squeeze().cpu()
            out = torch.cat((out, test_data[1].squeeze()	), axis=1)			
            save_image( (out * 0.5 + 0.5), train_dir+'/validation/'+str(epoch)+'lat_diff.png') #[-1, 1 ==> [0,1]]
            break
      #break
    #break
			
print("Training Done!")       
torch.save(model.state_dict(), train_dir+'/checkpoint/latest_diffusion_net.pt')

print("Done!")

In Global model training ...
 Global Epoch 20 step 0: Loss L1 0.7178987860679626
 Global Epoch 21 step 0: Loss L1 0.5748395323753357
 Global Epoch 22 step 0: Loss L1 0.14957760274410248
 Global Epoch 23 step 0: Loss L1 0.09766489267349243
 Global Epoch 24 step 0: Loss L1 0.06287359446287155
 Global Epoch 25 step 0: Loss L1 0.08443065732717514
 Global Epoch 26 step 0: Loss L1 0.040322743356227875
 Global Epoch 27 step 0: Loss L1 0.07599335163831711
 Global Epoch 28 step 0: Loss L1 0.08658187091350555
 Global Epoch 29 step 0: Loss L1 0.06543809175491333
 Global Epoch 30 step 0: Loss L1 0.06653542071580887
 Global Epoch 31 step 0: Loss L1 0.03524135425686836
 Global Epoch 32 step 0: Loss L1 0.07565223425626755
 Global Epoch 33 step 0: Loss L1 0.06490573287010193
 Global Epoch 34 step 0: Loss L1 0.0726899802684784
 Global Epoch 35 step 0: Loss L1 0.07228423655033112
 Global Epoch 36 step 0: Loss L1 0.02661641500890255
 Global Epoch 37 step 0: Loss L1 0.06663317233324051
 Global Epoch 38 st

FileNotFoundError: [Errno 2] No such file or directory: '_checkpoint/latest_global_net.pt'