In [None]:
!pip install pytorch-lightning diffusers huggingface_hub

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-lightning
  Downloading pytorch_lightning-2.0.2-py3-none-any.whl (719 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m719.0/719.0 kB[0m [31m14.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting diffusers
  Downloading diffusers-0.16.1-py3-none-any.whl (934 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m934.9/934.9 kB[0m [31m31.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting huggingface_hub
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.7.0
  Downloading lightning_utilities-0.8.0-py3-none-any.whl (20 kB)
Collecting torchmetrics>=0.7.0
  Downloading torchmetrics-0.11.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519

In [None]:
# Mount Google Drive; the dataset and checkpoints used below live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#%cd /content/drive/MyDrive/official/paper/2/data
# Work from the data folder: the training loop writes its samples and checkpoints
# to the relative paths 'images_sample/' and 'training_weights/'.
%cd /content/drive/MyDrive/data

/content/drive/MyDrive/data


In [None]:
# Authenticate with the Hugging Face Hub through the interactive widget.
# SECURITY: never paste an access token into the notebook source. A token that has
# ever been committed in a notebook must be treated as leaked and revoked.
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
import os

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split, Dataset
from diffusers import AutoencoderKL 
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.utils import make_grid

import matplotlib.pyplot as plt
from PIL import Image

import pytorch_lightning as pl

# --- Reproducibility ---------------------------------------------------------
random_seed = 42
torch.manual_seed(random_seed)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- Hyper-parameters --------------------------------------------------------
lr = 0.0002      # learning rate passed to both Adam optimizers
BATCH_SIZE = 128
EPOCHS = 10000   # last epoch index of the training loop (inclusive)
AVAIL_GPUS = min(1, torch.cuda.device_count())
# os.cpu_count() returns None when the CPU count cannot be determined;
# treat that case as a single CPU instead of raising a TypeError.
NUM_WORKERS = (os.cpu_count() or 1) // 2

In [None]:
class LQFEncodeDataset(Dataset):
    """Low quality face encodings dataset.

    Map-style dataset (``__len__`` / ``__getitem__``) that pairs every
    ``images/<name>.jpg`` under ``root_dir`` with the latent tensor stored in
    ``encodings/<name>.pt``. Images without a matching encoding file are skipped.
    All pairs are loaded and pre-processed eagerly in ``__init__`` and kept in
    memory.

    Each item is the tuple ``(encoding, image)``:
      * ``encoding`` -- the stored latent, zero-padded / cropped to
        ``(1, 4, 50, 50)`` and then squeezed.
      * ``image`` -- the RGB image converted to a tensor, resized to 128x128 and
        normalized with mean 0.5 / std 0.5 per channel.
    """

    def __init__(self, root_dir):
        """
            root_dir (string): Directory with the ``images`` and ``encodings`` folders.
        """
        super().__init__()
        self.root_dir = root_dir
        self.image_list = []
        self.encoding_list = []

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize((128, 128)),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

        image_dir = os.path.join(root_dir, 'images')
        encoding_dir = os.path.join(root_dir, 'encodings')

        # os.listdir() order is arbitrary; sort so that indices (and therefore the
        # seeded random_split) are reproducible between runs.
        for image_file in sorted(os.listdir(image_dir)):
            if not image_file.endswith('.jpg'):
                continue
            encodings_file = os.path.join(encoding_dir, image_file.replace(".jpg", '.pt'))
            if not os.path.isfile(encodings_file):
                continue

            # Close the file handle as soon as the pixels have been converted.
            with Image.open(os.path.join(image_dir, image_file)) as img:
                image = self.transform(img.convert('RGB'))

            # `encodings_file` is already the full path, so it is not joined with
            # root_dir a second time (that only worked for absolute root_dir).
            # NOTE: torch.load unpickles the file -- only load encodings you trust.
            encoding = torch.load(encodings_file, map_location='cpu')

            self.image_list.append(image)
            self.encoding_list.append(self.latent_transform(encoding).squeeze())

    def __len__(self):
        """Number of (encoding, image) pairs that were loaded."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return the pair ``(encoding, image)`` stored at position ``idx``."""
        return self.encoding_list[idx], self.image_list[idx]

    def latent_transform(self, original_tensor):
        """Fit a 4-D latent into a zero-initialised ``(1, 4, 50, 50)`` tensor.

        Dimensions smaller than the target are zero-padded at the end; dimensions
        larger than the target are cropped (keeping the leading entries) instead
        of raising a shape-mismatch error. The result is created on the CPU with
        the dtype of ``original_tensor``.
        """
        new_size = (1, 4, 50, 50)
        resized_tensor = torch.zeros(new_size, dtype=original_tensor.dtype)
        # Overlapping region between the source tensor and the target layout.
        n, c, h, w = (min(have, want) for have, want in zip(original_tensor.shape, new_size))
        resized_tensor[:n, :c, :h, :w] = original_tensor[:n, :c, :h, :w]
        return resized_tensor

    

In [None]:
# Load every (latent, image) pair, then split them 70 % / 30 % into train and test subsets.
dataset = LQFEncodeDataset('/content/drive/MyDrive/data')
n_total = len(dataset)
train_size = int(0.7 * n_total)
test_size = n_total - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])



In [None]:
# Batches of 64 (latent, image) pairs. NOTE: the BATCH_SIZE constant (128) defined
# with the other hyper-parameters is not used here.
# Reshuffle the training subset every epoch so the batch composition changes
# between epochs; the test loader keeps a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64)

In [None]:
# custom weights initialization called on generator and discriminator
def weights_init(m):
    """Re-initialise one module in place, selected by its class name.

    * class name contains ``Conv``      -> weight ~ N(0.0, 0.02)
    * class name contains ``BatchNorm`` -> weight ~ N(1.0, 0.02), bias = 0
    * anything else                     -> left untouched

    Meant to be passed to ``nn.Module.apply``.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        torch.nn.init.normal_(m.weight, mean=1.0, std=0.02)
        torch.nn.init.zeros_(m.bias)

In [None]:
def show_images(images, nrow=22, figsize=(128, 128)):
    """Render a batch of images as a single grid.

    Args:
        images: batch of image tensors accepted by ``make_grid``; may live on
            any device and may require grad.
        nrow: number of images per grid row (forwarded to ``make_grid``).
        figsize: matplotlib figure size in inches.
    """
    # matplotlib cannot read CUDA tensors, so move the batch to the CPU first.
    # normalize=True rescales the values to [0, 1]; without it images normalized
    # to [-1, 1] are clipped by imshow.
    grid = make_grid(images.detach().cpu(), nrow=nrow, normalize=True)
    fig, ax = plt.subplots(figsize=figsize)
    ax.set_xticks([]); ax.set_yticks([])
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    ax.imshow(grid.permute(1, 2, 0))

def show_batch(dl):
    """Display the images of the first batch yielded by ``dl``.

    Batches are unpacked as ``(latents, images)``, which is the order returned by
    ``LQFEncodeDataset.__getitem__`` and used by the training loop. Does nothing
    when ``dl`` yields no batch.
    """
    for _, images in dl:
        show_images(images)
        break

In [None]:
class Generator(nn.Module):
    """Conditional generator: (noise vector, low-quality latent) -> RGB image.

    The latent is projected to a 3x16x16 map, the noise vector to a 512x16x16
    map; both are concatenated (515 channels) and upsampled by three stride-2
    transposed convolutions to a 3x128x128 image squashed by ``tanh``.

    The first linear layer expects 32*25*25 features, i.e. a 25x25 map after the
    stride-2 convolution -- presumably latents of shape (N, 4, 50, 50).
    """

    def __init__(self, latent_dim=100):
        super().__init__()
        self.latent_dim = latent_dim

        # Conditioning branch: latent -> flat vector of 3*16*16 values.
        conditioning_layers = [
            nn.Conv2d(4, 32, 3, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(32 * 25 * 25, 3 * 16 * 16),
        ]
        self.latent_image = nn.Sequential(*conditioning_layers)

        # Noise branch: noise vector -> flat vector of 512*16*16 values.
        self.latent_noise = nn.Sequential(
            nn.Linear(self.latent_dim, 16 * 16 * 512),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Upsampling trunk: 515x16x16 -> 256x32x32 -> 128x64x64 -> 3x128x128.
        # NOTE(review): eps=0.8 is unusually large for BatchNorm; it is kept
        # because existing checkpoints were trained with it.
        upsampling_layers = [
            nn.ConvTranspose2d(515, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 2, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*upsampling_layers)

    def forward(self, inputs):
        """Generate images from ``inputs = (noise_vector, latent)``."""
        noise_vector, latent = inputs
        condition_map = self.latent_image(latent).view(-1, 3, 16, 16)
        noise_map = self.latent_noise(noise_vector).view(-1, 512, 16, 16)
        # Channel order matters for the first transposed conv: condition first, noise second.
        stacked = torch.cat((condition_map, noise_map), dim=1)
        return self.model(stacked)
      


In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator: (low-quality latent, image) -> probability.

    The latent is projected to a 3x128x128 map and stacked with the 3-channel
    image (6 channels in total); two strided convolutions and a linear layer
    followed by a sigmoid produce one score per sample.

    The projection is a single Linear layer from 16*25*25 = 10000 to
    3*128*128 = 49152 features, which makes it by far the largest layer here.
    It expects a 25x25 map after the stride-2 convolution -- presumably latents
    of shape (N, 4, 50, 50).
    """

    def __init__(self):
        super().__init__()

        # Conditioning branch: latent -> flat vector of 3*128*128 values.
        conditioning_layers = [
            nn.Conv2d(4, 16, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Linear(16 * 25 * 25, 3 * 128 * 128),
        ]
        self.latent_image = nn.Sequential(*conditioning_layers)

        # Classifier trunk: 6x128x128 -> 32x43x43 -> 64x15x15 -> 14400 -> 1.
        classifier_layers = [
            nn.Conv2d(6, 32, 4, 3, 2, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 32 * 2, 4, 3, 2, bias=False),
            nn.BatchNorm2d(32 * 2, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(14400, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*classifier_layers)

    def forward(self, inputs):
        """Score ``inputs = (latent, image)``; returns a tensor of shape (N, 1)."""
        latent, image = inputs
        condition_map = self.latent_image(latent).view(-1, 3, 128, 128)
        # Channel order matters for the first conv: condition first, image second.
        stacked = torch.cat((condition_map, image), dim=1)
        return self.model(stacked)

In [None]:
# Resume the generator from the epoch-1300 checkpoint. map_location remaps the
# stored tensors onto the active device, so a checkpoint written on a GPU runtime
# also loads on a CPU-only runtime.
generator = Generator().to(device)
generator.load_state_dict(
    torch.load('/content/drive/MyDrive/data/cgen_generator_epoch_1300.pth', map_location=device)
)
print(generator)

# No discriminator checkpoint is loaded: it starts from freshly initialised
# weights (weights_init touches its Conv and BatchNorm layers).
discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)

Generator(
  (latent_image): Sequential(
    (0): Conv2d(4, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Flatten(start_dim=1, end_dim=-1)
    (3): Linear(in_features=20000, out_features=768, bias=True)
  )
  (latent_noise): Sequential(
    (0): Linear(in_features=100, out_features=131072, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (model): Sequential(
    (0): ConvTranspose2d(515, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(128, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): Ta

In [None]:
# Binary cross-entropy on probabilities; both networks share the same criterion.
adversarial_loss = nn.BCELoss()

def generator_loss(fake_output, label):
  """BCE between the discriminator's scores for generated images and `label`."""
  return adversarial_loss(fake_output, label)

def discriminator_loss(output, label):
  """BCE between the discriminator's scores and `label`."""
  return adversarial_loss(output, label)

# One Adam optimizer per network, both using the shared learning rate `lr`.
opt_g, opt_d = (
    torch.optim.Adam(network.parameters(), lr=lr)
    for network in (generator, discriminator)
)

In [None]:
# Create the output folders up front so the first periodic save (after 100 epochs
# of training) cannot fail on a missing directory.
os.makedirs('images_sample', exist_ok=True)
os.makedirs('training_weights', exist_ok=True)

# Per-epoch mean losses (plain floats), kept for plotting.
D_loss_plot, G_loss_plot = [], []

# Training resumes after the epoch-1300 generator checkpoint.
for epoch in range(1301, EPOCHS+1):
    D_loss_list, G_loss_list = [], []
    for latents, real_images in train_dataloader:
        batch_size = real_images.size(0)

        # latents of low quality images and their matching real images
        latents = latents.to(device)
        real_images = real_images.to(device)

        # sample noise directly on the target device
        noise_vectors = torch.randn(batch_size, generator.latent_dim, device=device)

        # targets to calculate loss
        real_target = torch.ones(batch_size, 1, device=device)
        fake_target = torch.zeros(batch_size, 1, device=device)

        # ---- Discriminator step: max log(D(x)) + log(1 - D(G(z))) ----
        opt_d.zero_grad()
        D_real_loss = discriminator_loss(discriminator((latents, real_images)), real_target)

        generated_images = generator((noise_vectors, latents))
        # detach() keeps the discriminator update from back-propagating into the generator
        output = discriminator((latents, generated_images.detach()))
        D_fake_loss = discriminator_loss(output, fake_target)

        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_total_loss.backward()
        opt_d.step()
        # .item() stores a float; appending the tensor would keep its autograd
        # graph alive until the end of the epoch.
        D_loss_list.append(D_total_loss.item())

        # ---- Generator step with real labels: max log(D(G(z))) ----
        opt_g.zero_grad()
        G_loss = generator_loss(discriminator((latents, generated_images)), real_target)
        G_loss.backward()
        opt_g.step()
        G_loss_list.append(G_loss.item())

    # Mean over the epoch's batches (nan when the loader yielded no batch).
    D_epoch_loss = torch.tensor(D_loss_list).mean().item()
    G_epoch_loss = torch.tensor(G_loss_list).mean().item()
    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
            epoch, EPOCHS, D_epoch_loss, G_epoch_loss))
    #show_images(generated_images)

    D_loss_plot.append(D_epoch_loss)
    G_loss_plot.append(G_epoch_loss)

    # Every 100 epochs: save a sample grid from the last batch and the generator weights.
    if epoch % 100 == 0:
        save_image(generated_images.detach()[:50], 'images_sample/cgen_sample_%d.png' % epoch, nrow=5, normalize=True)
        torch.save(generator.state_dict(), 'training_weights/cgen_generator_epoch_%d.pth' % (epoch))
        # Discriminator weights are not checkpointed:
        #torch.save(discriminator.state_dict(), 'training_weights/cgen_discriminator_epoch_%d.pth' % (epoch))

Epoch: [1301/10000]: D_loss: 0.692, G_loss: 0.684
Epoch: [1302/10000]: D_loss: 0.686, G_loss: 0.712
Epoch: [1303/10000]: D_loss: 0.641, G_loss: 0.840
Epoch: [1304/10000]: D_loss: 0.642, G_loss: 1.036
Epoch: [1305/10000]: D_loss: 0.621, G_loss: 1.062
Epoch: [1306/10000]: D_loss: 0.504, G_loss: 1.275
Epoch: [1307/10000]: D_loss: 0.568, G_loss: 1.234
Epoch: [1308/10000]: D_loss: 0.650, G_loss: 1.184
Epoch: [1309/10000]: D_loss: 0.346, G_loss: 1.741
Epoch: [1310/10000]: D_loss: 0.908, G_loss: 0.970
Epoch: [1311/10000]: D_loss: 0.470, G_loss: 1.449
Epoch: [1312/10000]: D_loss: 0.480, G_loss: 1.323
Epoch: [1313/10000]: D_loss: 0.693, G_loss: 0.999
Epoch: [1314/10000]: D_loss: 0.527, G_loss: 1.174
Epoch: [1315/10000]: D_loss: 0.584, G_loss: 1.214
Epoch: [1316/10000]: D_loss: 0.721, G_loss: 1.119
Epoch: [1317/10000]: D_loss: 0.377, G_loss: 1.619
Epoch: [1318/10000]: D_loss: 0.667, G_loss: 1.153
Epoch: [1319/10000]: D_loss: 0.560, G_loss: 1.346
Epoch: [1320/10000]: D_loss: 0.470, G_loss: 1.401


In [None]:
#train_dataloader[0]

In [None]:
# class SRcGAN(pl.LightningModule):
#   def __init__(self, latent_dim=100, lr = 0.0002):
#     super().__init__()
#     self.save_hyperparameters()

#     self.generator = Generator()
#     self.discriminator = Discriminator()

#     self.validation_z = torch.randn(6, self.hparams.latent_dim)

#     self.adversarial_loss = nn.BCELoss()
    

#   def forward(self, z):
#     return self.generator(z)
  
#   def generator_loss(self, fake_output, label):
#     gen_loss = self.adversarial_loss(fake_output, label)
#     #print(gen_loss)
#     return gen_loss

#   def discriminator_loss(self, output, label):
#     disc_loss = self.adversarial_loss(output, label)
#     return disc_loss
  
#   def training_step(self, batch, batch_idx, optimizer_idx):
#     latents, real_images = batch

#     #sample noise data
#     noise_vector = torch.randn(real_images.size(0), self.hparams.latent_dim, device=device)  
#     noise_vector = noise_vector.to(device)

#     #train the generator
#     if optimizer_idx == 0:
#       self.

  
#   def configure_optimizers(self):
#     lr = self.hparams.lr
#     opt_g = torch.optim.Adam(self.generator.parameters(), lr=lr)
#     opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr)
#     return [opt_g, opt_d], []

#   def plot_imgs(self):
#     z = self.validation_z.type_as(self.generator.lin1.weight)
#     sample_imgs = self(z).cpu()

#     print('Epoch ', self.current_epoch)
#     fig = plt.figure()
#     for i in range(sample_imgs.size(0)):
#       plt.subplot(2, 3, i+1)
#       plt.tight_layout()
#       plt.imshow(sample_imgs.detach() [i, 0, :, :], interpolation='none') 
#       plt.title("Generated Data")
#       plt.xticks([])
#       plt.yticks([])
#       plt.axis('off')
#     plt.show()

#   def on_epoch_end(self):
#     self.plot_imgs()

In [None]:
# from PIL import Image
# from torchvision import transforms as tfms
# to_tensor_tfm = tfms.ToTensor()
# im = Image.open('/content/b.jpg').convert('RGB')
# im = im.resize((1024, 1024))
# conv1 = nn.Conv2d(3, 32, kernel_size=5)
# conv2 = nn.Conv2d(32, 64, kernel_size=5)
# conv3 = nn.Conv2d(64, 128, kernel_size=5)
# conv4 = nn.Conv2d(128, 64, kernel_size=5)
# fc1 = nn.Linear(64*60*60, 512)
# conv2_drop = nn.Dropout2d()
# x = F.relu(F.max_pool2d(conv1(to_tensor_tfm(im)), 2))
# x = F.relu(F.max_pool2d(conv2_drop(conv2(x)), 2))
# x = F.relu(F.max_pool2d(conv3(x), 2))
# x = F.relu(F.max_pool2d(conv2_drop(conv4(x)), 2))
# x = x.view(-1, 64*60*60)
# x = F.relu(fc1(x))
# x.size()

In [None]:
# Load a trained generator for inference. map_location remaps the stored tensors
# onto the active device, so a checkpoint written on a GPU runtime also loads on
# a CPU-only runtime.
generator = Generator().to(device)
generator.load_state_dict(torch.load('/content/generator_epoch_3900.pth', map_location=device))
print(generator)


In [None]:
# Load one stored latent on the CPU first (works regardless of the device it was saved from).
raw_latent = torch.load('/content/0001.pt', map_location='cpu')

# Bring it to the (1, 4, 50, 50) layout the generator was trained on, mirroring
# LQFEncodeDataset.latent_transform: zero-pad smaller dimensions, crop larger ones.
original_tensor = torch.zeros((1, 4, 50, 50), dtype=raw_latent.dtype)
n, c, h, w = (min(have, want) for have, want in zip(raw_latent.shape, original_tensor.shape))
original_tensor[:n, :c, :h, :w] = raw_latent[:n, :c, :h, :w]
original_tensor = original_tensor.to(device)

# One noise vector of the size the generator expects.
noise = torch.randn(1, 100, device=device)

In [None]:
# Inference only: skip autograd bookkeeping to save memory.
# NOTE(review): the generator is left in train mode here, as it was when the
# training samples were saved; confirm whether generator.eval() is wanted.
with torch.no_grad():
    img = generator((noise, original_tensor))
img.shape

In [None]:
# Write the first image of the generated batch to disk, rescaled by normalize=True.
first_image = img.data[0]
save_image(first_image, 'sample_1' + '.png', normalize=True)

In [None]:
// Keep-alive helper: clicks Colab's "connect" button once a minute.
// NOTE(review): this is JavaScript, so a Python kernel cannot execute this cell;
// presumably it is meant to be pasted into the browser's developer console -- confirm.
function ConnectButton(){
    console.log("Connect pushed");
    // Look the button up step by step: if the toolbar element or its shadow root
    // is missing, skip this tick instead of throwing a TypeError every minute.
    var host = document.querySelector("#top-toolbar > colab-connect-button");
    var button = host && host.shadowRoot && host.shadowRoot.querySelector("#connect");
    if (button) {
        button.click();
    }
}

// Interval handle; call clearInterval(colab) to stop the keep-alive.
var colab = setInterval(ConnectButton,60000);