In [16]:
import torch
import torchvision
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import math

In [17]:
data_transforms = [
        transforms.Resize((64, 64)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(), 
        transforms.Lambda(lambda t: (t * 2) - 1) 
    ]
data_transform = transforms.Compose(data_transforms)

train = torchvision.datasets.CIFAR10(root=".", download=True,transform=data_transform, train=True)

test = torchvision.datasets.CIFAR10(root=".", download=True,transform=data_transform, train=False)

data =  torch.utils.data.ConcatDataset([train, test])
data_loader=DataLoader(data,batch_size=128,shuffle=True)


Files already downloaded and verified
Files already downloaded and verified


In [18]:
T=300
betas=torch.linspace(0.0001,0.02,T)
alphas=1-betas

alphas_cumprod=torch.cumprod(alphas,axis=0)


sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)


In [20]:
def get_index_from_list(vals, t, x_shape):
    """
    Returns a specific index t of a passed list of values vals
    while considering the batch dimension.
    """
    batch_size = t.shape[0]
    out = vals.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """
    Takes an image and a timestep as input and
    returns the noisy version of it
    """
    noise = torch.randn_like(x_0)
    sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x_0.shape
    )
    # mean + variance
    return sqrt_alphas_cumprod_t.to(device) * x_0.to(device) \
    + sqrt_one_minus_alphas_cumprod_t.to(device) * noise.to(device), noise.to(device)

In [21]:

class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_mlp =  nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2*in_ch, out_ch, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.transform = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu  = nn.ReLU()

    def forward(self, x, t, ):
        h = self.bnorm1(self.relu(self.conv1(x)))
        time_emb = self.relu(self.time_mlp(t))
        time_emb = time_emb[(..., ) + (None, ) * 2]
        h = h + time_emb
        h = self.bnorm2(self.relu(self.conv2(h)))
        return self.transform(h)


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings


class SimpleUnet(nn.Module):
    
    def __init__(self):
        super().__init__()
        image_channels = 3
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        out_dim = 3
        time_emb_dim = 32

        self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(time_emb_dim),
                nn.Linear(time_emb_dim, time_emb_dim),
                nn.ReLU()
            )

        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        self.downs = nn.ModuleList([Block(down_channels[i], down_channels[i+1], \
                                    time_emb_dim) \
                    for i in range(len(down_channels)-1)])

        self.ups = nn.ModuleList([Block(up_channels[i], up_channels[i+1], \
                                        time_emb_dim, up=True) \
                    for i in range(len(up_channels)-1)])

        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep):
        t = self.time_mlp(timestep)
        x = self.conv0(x)
        residual_inputs = []
        for down in self.downs:
            x = down(x, t)
            residual_inputs.append(x)
        for up in self.ups:
            residual_x = residual_inputs.pop()
            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t)
        return self.output(x)


In [22]:
model=SimpleUnet()

optimizer=optim.Adam(model.parameters(), lr=0.001)
epochs=10
loss_epoch=[]

for epoch in range(epochs):
    loss_each=[]

    for step,batch in enumerate(data_loader):
        optimizer.zero_grad()

        t = torch.randint(0, T, (128,)).long()

        x_noisy,noise =forward_diffusion_sample(batch[0],t)
        noise_pred = model(x_noisy,t)

        loss=nn.functional.l1_loss(noise,noise_pred)
        loss.backward()
        optimizer.step()
        loss_each.append(loss.item())
        print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
    
    loss_epoch.append(sum(loss_each)/len(loss_each))


Epoch 0 | step 000 Loss: 0.8123559951782227 
Epoch 0 | step 001 Loss: 0.7556984424591064 
Epoch 0 | step 002 Loss: 0.713679850101471 
Epoch 0 | step 003 Loss: 0.6853588223457336 
Epoch 0 | step 004 Loss: 0.6506916880607605 
Epoch 0 | step 005 Loss: 0.6194464564323425 
Epoch 0 | step 006 Loss: 0.5812672972679138 
Epoch 0 | step 007 Loss: 0.546823263168335 
Epoch 0 | step 008 Loss: 0.5183539390563965 
Epoch 0 | step 009 Loss: 0.4960297644138336 
Epoch 0 | step 010 Loss: 0.46571481227874756 
Epoch 0 | step 011 Loss: 0.4388415813446045 
Epoch 0 | step 012 Loss: 0.42291462421417236 
Epoch 0 | step 013 Loss: 0.4054476022720337 
Epoch 0 | step 014 Loss: 0.3996020555496216 
Epoch 0 | step 015 Loss: 0.38702520728111267 
Epoch 0 | step 016 Loss: 0.3871452808380127 
Epoch 0 | step 017 Loss: 0.3791097104549408 
Epoch 0 | step 018 Loss: 0.3561975955963135 
Epoch 0 | step 019 Loss: 0.36527374386787415 
Epoch 0 | step 020 Loss: 0.36259666085243225 
Epoch 0 | step 021 Loss: 0.3482036292552948 
Epoch 0

KeyboardInterrupt: 

In [33]:
import matplotlib.pyplot as plt
import os

# Create the directory if it doesn't exist
save_dir = "./images/"
os.makedirs(save_dir, exist_ok=True)

# Set the model to evaluation mode
model.eval()

# Generate and save 10 new images
with torch.no_grad():
    for i in range(10):
        # Randomly sample a timestep
        t_sample = torch.randint(0, T, (1,)).long()

        # Generate a noisy image
        x_noisy_sample, _ = forward_diffusion_sample(batch[0], t_sample)

        # Generate a denoised image using the trained model
        denoised_image = model(x_noisy_sample, t_sample)

        # Convert the images to numpy arrays
        x_noisy_sample_np = x_noisy_sample[0].cpu().numpy().transpose(1, 2, 0)
        denoised_image_np = denoised_image[0].cpu().numpy().transpose(1, 2, 0)

        # Plot and save the images
        plt.figure(figsize=(8, 4))

        plt.subplot(1, 2, 1)
        plt.imshow((x_noisy_sample_np + 1) / 2)  # Rescale to [0, 1]
        plt.title("Noisy Image")
        plt.axis("off")

        plt.subplot(1, 2, 2)
        plt.imshow((denoised_image_np + 1) / 2)  # Rescale to [0, 1]
        plt.title("Denoised Image")
        plt.axis("off")

        # Save the images
        plt.savefig(os.path.join(save_dir, f"generated_image_{i+1}.png"))
        plt.close()

# Set the model back to training mode
model.train()


Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i

SimpleUnet(
  (time_mlp): Sequential(
    (0): SinusoidalPositionEmbeddings()
    (1): Linear(in_features=32, out_features=32, bias=True)
    (2): ReLU()
  )
  (conv0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (downs): ModuleList(
    (0): Block(
      (time_mlp): Linear(in_features=32, out_features=128, bias=True)
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (transform): Conv2d(128, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bnorm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bnorm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (1): Block(
      (time_mlp): Linear(in_features=32, out_features=256, bias=True)
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (transfor

In [None]:
import os
os.makedirs('images', exist_ok=True)
with torch.no_grad():
    for i in range(10):
        noise = torch.randn(1, nz, 1, 1)
        fake = netG(noise).detach().cpu()
        plt.figure(figsize=(1,1))
        plt.axis("off")
        plt.savefig('images/fake_temp_image_%d.png' % i)
        plt.imshow(np.transpose(fake[0],(1,2,0)))
        plt.savefig('images/fake_image_%d.png' % i)
        plt.close()

In [None]:
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
import torch.utils.data

from torchvision.models.inception import inception_v3

import numpy as np
from scipy.stats import entropy

def inception_score(imgs, cuda=False, batch_size=32, resize=False, splits=1):
    """Computes the inception score of the generated images imgs

    imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
    cuda -- whether or not to run on GPU
    batch_size -- batch size for feeding into Inception v3
    splits -- number of splits
    """
    N = len(imgs)
    # print(imgs)

    assert batch_size > 0
    assert N > batch_size

    # Set up dtype
    if cuda:
        dtype = torch.cuda.FloatTensor
    else:
        if torch.cuda.is_available():
            print("WARNING: You have a CUDA device, so you should probably set cuda=True")
        dtype = torch.FloatTensor

    # Set up dataloader
    dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)

    # Load inception model
    inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)
    inception_model.eval();
    up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype)
    
    def get_pred(x):
        if resize:
            x = up(x)
        x = inception_model(x)
        return F.softmax(x).data.cpu().numpy()

    # Get predictions
    preds = np.zeros((N, 1000))

    for i, batch in enumerate(dataloader, 0):
        batch = batch.type(dtype)
        batchv = Variable(batch)
        batch_size_i = batch.size()[0]

        preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv)

    # Now compute the mean kl-div
    split_scores = []

    for k in range(splits):
        part = preds[k * (N // splits): (k+1) * (N // splits), :]
        py = np.mean(part, axis=0)
        scores = []
        for i in range(part.shape[0]):
            pyx = part[i, :]
            scores.append(entropy(pyx, py))
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)

# if _name_ == '_main_':
#     class IgnoreLabelDataset(torch.utils.data.Dataset):
#         def _init_(self, orig):
#             self.orig = orig

#         def _getitem_(self, index):
#             return self.orig[index][0]

#         def _len_(self):
#             return len(self.orig)

    # import torchvision.datasets as dset
    # import torchvision.transforms as transforms

    # cifar = dset.CIFAR10(root='data/', download=True,
    #                          transform=transforms.Compose([
    #                              transforms.Resize(32),
    #                              transforms.ToTensor(),
    #                              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    #                          ])
    # )

    # IgnoreLabelDataset(cifar)

    # print ("Calculating Inception Score...")
    # print (inception_score(IgnoreLabelDataset(cifar), cuda=False, batch_size=32, resize=True, splits=10))

In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image

def apply_transform_to_image(img_paths, transform):
    result=[]
    for img_path in img_paths:
        pil_image = Image.open(img_path)
        img_rgb = pil_image.convert('RGB')
        transformed_image = transform(img_rgb)
        result.append(transformed_image)
    return torch.stack(result)

# Example usage
image_paths = ['./images/fake_image_0.png', './images/fake_image_1.png', './images/fake_image_2.png', './images/fake_image_3.png','./images/fake_image_4.png', './images/fake_image_5.png', './images/fake_image_6.png', './images/fake_image_7.png', './images/fake_image_8.png', './images/fake_image_9.png']

# Define the transformation pipeline
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
transformed_images = apply_transform_to_image(image_paths,transform)

a,b=inception_score(transformed_images, cuda=False, batch_size=2, resize=True, splits=10)
print(a,b)