In [None]:
import torch
from torch import nn
import torchvision
import torchvision.transforms as T

In [None]:
class ConvBlock(nn.Module):
    """Convolution followed by optional batch norm and an optional activation.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by the convolution.
        discriminator: When true the activation is ``LeakyReLU(0.2)``;
            otherwise a per-channel ``PReLU`` is used.
        use_act: Whether ``forward`` applies the activation. The activation
            module is created either way.
        use_bn: Whether a ``BatchNorm2d`` follows the convolution. When it
            does, the convolution is built without a bias term.
        **kwargs: Forwarded to ``nn.Conv2d`` (kernel_size, stride, padding, ...).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        discriminator=False,
        use_act=True,
        use_bn=True,
        **kwargs,
    ):
        super().__init__()
        self.use_act = use_act

        # The conv bias is disabled whenever batch norm is enabled.
        self.cnn = nn.Conv2d(in_channels, out_channels, bias=not use_bn, **kwargs)

        if use_bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = nn.Identity()

        if discriminator:
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.act = nn.PReLU(num_parameters=out_channels)

    def forward(self, x):
        """Apply conv -> norm, then the activation if ``use_act`` is set."""
        normalized = self.bn(self.cnn(x))
        if not self.use_act:
            return normalized
        return self.act(normalized)


In [None]:
class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock`` layers wrapped in an identity skip connection.

    Both layers keep the channel count and spatial size (stride 1, padding 1),
    so the block's output has the same shape as its input.
    """

    def __init__(self, in_channels):
        super().__init__()
        same_shape_conv = dict(kernel_size=3, stride=1, padding=1)
        self.block1 = ConvBlock(in_channels, in_channels, **same_shape_conv)
        # NOTE(review): the second block is explicitly activated before the
        # skip addition; confirm this is intended.
        self.block2 = ConvBlock(in_channels, in_channels, use_act=True, **same_shape_conv)

    def forward(self, x):
        """Return ``block2(block1(x)) + x``."""
        transformed = self.block2(self.block1(x))
        return transformed + x

In [None]:

class DeconvulationBlock(nn.Module):
    """Resize a feature map to a fixed spatial size, then halve its channels.

    Args:
        in_channels: Number of *output* channels; the incoming tensor must
            have ``in_channels * 2`` channels.
        shape_1: Target height passed to the bilinear resize.
        shape_2: Target width passed to the bilinear resize.
    """

    def __init__(self, in_channels, shape_1, shape_2):
        super().__init__()
        self.conv = ConvBlock(
            in_channels * 2, in_channels, kernel_size=3, stride=1, padding=1
        )
        self.shape_1 = shape_1
        self.shape_2 = shape_2

    def forward(self, x):
        """Bilinearly resize ``x`` to ``(shape_1, shape_2)`` and convolve it."""
        target_size = (self.shape_1, self.shape_2)
        resized = torch.nn.functional.interpolate(x, size=target_size, mode='bilinear')
        return self.conv(resized)



In [None]:
# Scratch instance (64 output channels, 128x128 target size); `decon` is not
# referenced by any later cell in this notebook.
decon = DeconvulationBlock(64, 128, 128)

In [None]:
# Random 1x3x64x64 tensor; `y` is not referenced by any later cell in this notebook.
y = torch.randn(1,3,64,64)

In [None]:
class Generator(nn.Module):
    """Image-to-image generator with a residual trunk and two skip connections.

    Data flow (``nc`` = ``num_channels``):

    1. ``initial``: 9x9 conv, ``in_channels -> nc``, spatial size preserved.
    2. ``conv1`` / ``conv2``: unpadded 3x3 convs, ``nc -> 2*nc -> 4*nc``.
    3. Three groups of ``num_blocks`` residual blocks on ``4*nc`` channels.
    4. ``deconconvblock1``: resize to ``(4*nc, 4*nc)`` pixels, ``4*nc -> 2*nc``.
    5. ``deconconvblock2``: resize to ``(image_size1, image_size2)``, ``2*nc -> nc``.
    6. Add the ``initial`` features, apply the final 9x9 conv back to
       ``in_channels``, add the network input and squash with ``tanh``.

    Args:
        in_channels: Channels of the input (and output) image.
        num_channels: Base feature width ``nc``.
        num_blocks: Residual blocks per residual group.
        image_size1: Height the features are resized to before the skip add.
        image_size2: Width the features are resized to before the skip add.

    Note:
        The skip addition with ``initial`` only works when the input's spatial
        size equals ``(image_size1, image_size2)``.
    """

    def __init__(self, in_channels=3, num_channels=32, num_blocks=1, image_size1=256, image_size2=256):
        super().__init__()
        self.initial = ConvBlock(in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=True)
        self.conv1 = ConvBlock(num_channels, num_channels * 2, kernel_size=3, stride=1)
        self.conv2 = ConvBlock(num_channels * 2, num_channels * 4, kernel_size=3, stride=1)
        self.residuals1 = nn.Sequential(*[ResidualBlock(num_channels * 4) for _ in range(num_blocks)])
        self.residuals2 = nn.Sequential(*[ResidualBlock(num_channels * 4) for _ in range(num_blocks)])
        self.residuals3 = nn.Sequential(*[ResidualBlock(num_channels * 4) for _ in range(num_blocks)])
        # NOTE(review): the intermediate resize target is tied to the channel
        # count (num_channels * 4 pixels per side); confirm this is intended.
        self.deconconvblock1 = DeconvulationBlock(num_channels * 2, num_channels * 4, num_channels * 4)
        # These two layers previously hard-coded 32, which only matched the
        # default num_channels; they now follow the configured width.
        self.deconconvblock2 = DeconvulationBlock(num_channels, image_size1, image_size2)
        self.final = nn.Conv2d(num_channels, in_channels, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Map an image batch to an image batch of the same shape, in [-1, 1]."""
        network_input = x
        initial = self.initial(x)

        x = self.conv1(initial)
        x = self.conv2(x)

        x = self.residuals1(x)
        x = self.residuals2(x)
        x = self.residuals3(x)

        x = self.deconconvblock1(x)
        x = self.deconconvblock2(x)

        # Feature-level skip connection from the first convolution.
        x = x + initial

        x = self.final(x)

        # Image-level skip connection from the network input.
        x = x + network_input

        return torch.tanh(x)


In [None]:
        # Smoke test: run one random 1x3x256x256 batch through a default
        # Generator and print the shape of the result.
        low_resolution = 24  # 96x96 -> 24x24; not used by the statements below
        x = torch.randn((1,3,256,256))
        gen =  Generator()
        gen_out = gen(x)
        # disc = Discriminator()
        # disc_out = disc(gen_out)

        print(gen_out.shape)
        # print(disc_out.shape)

torch.Size([1, 32, 256, 256])
torch.Size([1, 32, 256, 256])
torch.Size([1, 32, 256, 256])
torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256])


In [None]:
from torch.nn.modules.flatten import Flatten
class Discriminator(nn.Module):
    """Convolutional discriminator producing one unnormalised score per image.

    Args:
        in_channels: Channels of the input image.
        features: Output width of each successive 3x3 ``ConvBlock``. Blocks at
            odd positions use stride 2, the others stride 1. The first block
            has no batch norm.

    The convolutional stack is followed by adaptive average pooling to 6x6 and
    a two-layer MLP ending in a single output (no sigmoid is applied here).
    """

    def __init__(self, in_channels=3, features=(64, 64, 128, 128, 256, 256, 512, 512)):
        super().__init__()
        blocks = []
        for idx, feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    in_channels,
                    feature,
                    kernel_size=3,
                    stride=1 + idx % 2,
                    padding=1,
                    discriminator=True,
                    use_act=True,
                    use_bn=idx != 0,
                )
            )
            in_channels = feature

        self.blocks = nn.Sequential(*blocks)

        # `in_channels` now holds the width of the last conv block, so the
        # classifier input matches any `features` sequence (the previous
        # hard-coded 512 only matched the default).
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            nn.Linear(in_channels * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of scores for ``x``."""
        x = self.blocks(x)
        return self.classifier(x)

In [None]:
# Module-level discriminator with default settings, used by the shape check below.
disc = Discriminator()

In [None]:
# Score the generator smoke-test output; relies on `gen_out` from the earlier cell.
disc_out= disc(gen_out)

In [None]:
# The recorded output shows one score per image: torch.Size([1, 1]).
print(disc_out.shape)

torch.Size([1, 1])


In [None]:
from torchvision.models import vgg19
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Display the selected device (the recorded run shows 'cuda').
DEVICE

'cuda'

In [None]:
class VGGLoss(nn.Module):
    """Perceptual loss: MSE between VGG-19 feature maps of two image batches.

    The feature extractor is the first 36 modules of the pretrained
    ``vgg19().features`` stack, placed on ``DEVICE``, switched to eval mode
    and frozen (``requires_grad = False`` on every parameter).
    """

    def __init__(self) -> None:
        super().__init__()

        self.vgg = vgg19(pretrained=True).features[:36].eval().to(DEVICE)
        self.loss = nn.MSELoss()

        # The extractor is a fixed feature space, not something to train.
        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, input, target):
        """Return the MSE between the VGG features of ``input`` and ``target``.

        Args:
            input: Generated image batch.
            target: Reference image batch of the same shape.
        """
        vgg_input_features = self.vgg(input)
        vgg_target_features = self.vgg(target)
        # Compare against the *target* features; comparing the input features
        # with themselves made this loss identically zero.
        return self.loss(vgg_input_features, vgg_target_features)

In [None]:
# Module-level perceptual loss instance; main() constructs its own instance too.
vgg_loss = VGGLoss()

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


  0%|          | 0.00/548M [00:00<?, ?B/s]

In [None]:
from PIL import Image
import os
from torch.utils.data import Dataset
import numpy as np

In [None]:
class DenoiseDataset(Dataset):
    """Dataset yielding (noisy, clean) RGB image pairs from two directories.

    Args:
        train_noise: Directory containing the noisy images.
        train_real: Directory containing the clean images.
        transform: Optional callable invoked as
            ``transform(image=noisy, image0=clean)`` that returns a mapping
            with the keys ``"image"`` and ``"image0"``.

    Items are paired by position in the *sorted* directory listings, so the
    i-th noisy file is matched with the i-th clean file. The dataset length is
    the larger of the two file counts; the shorter listing wraps around.
    """

    def __init__(self, train_noise, train_real, transform=None):
        self.train_noise = train_noise
        self.train_real = train_real
        self.transform = transform

        # os.listdir returns entries in arbitrary order. Sorting both listings
        # makes the index-based pairing deterministic and lines up files that
        # share names across the two folders.
        self.train_noise_images = sorted(os.listdir(train_noise))
        self.train_real_images = sorted(os.listdir(train_real))
        self.train_noise_len = len(self.train_noise_images)
        self.train_real_len = len(self.train_real_images)
        self.length_dataset = max(self.train_noise_len, self.train_real_len)

    def __len__(self):
        return self.length_dataset

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` as an RGB uint8 array and close the file promptly."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __getitem__(self, index):
        """Return the ``(noisy, clean)`` pair at ``index`` (transformed if configured)."""
        train_noise_name = self.train_noise_images[index % self.train_noise_len]
        train_real_name = self.train_real_images[index % self.train_real_len]

        train_noise_img = self._load_rgb(os.path.join(self.train_noise, train_noise_name))
        train_real_img = self._load_rgb(os.path.join(self.train_real, train_real_name))

        if self.transform:
            augmentations = self.transform(image=train_noise_img, image0=train_real_img)
            train_noise_img = augmentations["image"]
            train_real_img = augmentations["image0"]

        return train_noise_img, train_real_img

In [None]:
from albumentations.pytorch import ToTensorV2
import albumentations as A

In [None]:

# Training-time pipeline applied jointly to the noisy image ("image") and the
# clean image ("image0"): resize to 256x256, random horizontal flip,
# normalise with mean 0.5 / std 0.5 per channel, convert to a tensor.
transforms = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5] * 3, std=[0.5] * 3, max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)

In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image

In [None]:
def test():
    """Iterate the training dataset once and print every batch's tensor shapes."""
    dataset = DenoiseDataset(
        train_noise='/content/drive/MyDrive/train/train',
        train_real='/content/drive/MyDrive/train_cleaned/train_cleaned',
        transform=transforms,
    )
    loader = DataLoader(dataset, batch_size=1, num_workers=4)

    for noisy_batch, clean_batch in loader:
        for batch in (noisy_batch, clean_batch):
            print(batch.shape)

In [None]:
# Inspect the Drive folder; the recorded output shows a single 'train' subdirectory.
os.listdir('/content/drive/MyDrive/train')

['train']

In [None]:
from torchvision.utils import save_image

In [None]:
def gradient_penalty(critic, real, fake, device):
    """Penalise the critic's gradient norm on random real/fake blends.

    A per-sample blend factor is drawn uniformly from [0, 1), real and
    (detached) fake images are mixed with it, and the critic is evaluated on
    the mix. The result is the batch mean of ``(||grad||_2 - 1) ** 2``, where
    the gradient is that of the critic scores w.r.t. the mixed images.

    Args:
        critic: Callable mapping an image batch to scores.
        real: Real image batch of shape ``(B, C, H, W)``.
        fake: Generated image batch, same shape as ``real``.
        device: Device the blend factors are moved to.

    Returns:
        Scalar tensor that remains differentiable (``create_graph=True``).
    """
    batch, channels, height, width = real.shape
    mix = torch.rand((batch, 1, 1, 1)).repeat(1, channels, height, width).to(device)
    blended = mix * real + (1 - mix) * fake.detach()
    blended.requires_grad_(True)

    scores = critic(blended)

    # Gradient of the scores with respect to the blended images.
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=blended,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )
    per_sample_norm = grads.view(grads.shape[0], -1).norm(2, dim=1)
    return torch.mean((per_sample_norm - 1) ** 2)



In [None]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename`` via ``torch.save``.

    The saved object is a dict with the keys ``"state_dict"`` (model) and
    ``"optimizer"``.
    """
    print("=> Saving checkpoint")
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )


In [None]:
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from a checkpoint, then reset the LR.

    Args:
        checkpoint_file: Path to a file holding a dict with the keys
            ``"state_dict"`` and ``"optimizer"``.
        model: Module that receives the stored ``"state_dict"``.
        optimizer: Optimizer that receives the stored ``"optimizer"`` state.
        lr: Learning rate written into every param group after loading.
    """
    print("=> Loading checkpoint")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    stored = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(stored["state_dict"])
    optimizer.load_state_dict(stored["optimizer"])

    # The optimizer state carries the checkpoint's learning rate; overwrite it
    # so training continues with the rate requested by the caller.
    for group in optimizer.param_groups:
        group["lr"] = lr



In [None]:
def plot_examples(low_res_folder, gen):
    """Run ``gen`` on every file in ``low_res_folder`` and save the results.

    Each image is loaded as RGB, passed through the module-level
    ``test_transform``, fed to the generator without gradient tracking, and the
    output (mapped back with ``* 0.5 + 0.5``) is written to ``saved/<file>``.

    Args:
        low_res_folder: Directory whose files are used as generator inputs.
        gen: Generator module; switched to eval mode for the duration of the
            call and back to train mode afterwards, even if an image fails.

    Note:
        ``test_transform`` must produce tensors of a spatial size the
        generator accepts.
    """
    files = os.listdir(low_res_folder)
    # save_image does not create missing directories.
    os.makedirs("saved", exist_ok=True)

    gen.eval()
    try:
        for file in files:
            # Read from the folder that was actually listed, and force three
            # channels: grayscale files otherwise yield a 2-D array that the
            # 3-channel normalisation cannot handle.
            image = Image.open(os.path.join(low_res_folder, file)).convert("RGB")
            with torch.no_grad():
                upscaled_img = gen(
                    test_transform(image=np.asarray(image))["image"]
                    .unsqueeze(0)
                    .to(DEVICE)
                )
            save_image(upscaled_img * 0.5 + 0.5, f"saved/{file}")
    finally:
        gen.train()

In [None]:
# Inference-time pipeline. It mirrors the training pipeline without the random
# flip: inputs are resized to the 256x256 size the default Generator is built
# for, and normalised with mean 0.5 / std 0.5 so they share the training input
# range and match the `* 0.5 + 0.5` de-normalisation applied before saving.
test_transform = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),
    ]
)

In [None]:
# --- Runtime -----------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# --- Checkpointing -----------------------------------------------------------
CHECKPOINT_GEN = "/content/drive/MyDrive/gen.pth.tar"
CHECKPOINT_DISC = "/content/drive/MyDrive/disc.pth.tar"
SAVE_MODEL = True
LOAD_MODEL = True  # only referenced by the commented-out restore code in main()

# --- Optimisation ------------------------------------------------------------
LEARNING_RATE = 1e-4
NUM_EPOCHS = 10
BATCH_SIZE = 16   # main() currently passes a literal batch size to DataLoader
NUM_WORKERS = 4   # main() currently passes a literal worker count to DataLoader

# --- Image geometry (not referenced by the visible cells) --------------------
IMG_CHANNELS = 3
HIGH_RES = 96
LOW_RES = HIGH_RES // 4

In [None]:
from tqdm import tqdm
from torch import optim


In [None]:
def train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss):
    """Train the discriminator and the generator for one pass over ``loader``.

    Args:
        loader: Iterable of ``(noisy, clean)`` image batches.
        disc: Discriminator module.
        gen: Generator module.
        opt_gen: Optimizer for the generator.
        opt_disc: Optimizer for the discriminator.
        mse: Accepted for interface compatibility; not used (the L2 term is
            disabled).
        bce: Loss applied to discriminator outputs and 0/1 targets.
        vgg_loss: Callable ``(generated, target) -> scalar`` perceptual loss.

    Every 200 batches (including the first) example outputs are written via
    ``plot_examples``.
    """
    progress = tqdm(loader, leave=True)

    for step, (noisy, clean) in enumerate(progress):
        clean = clean.to(DEVICE)
        noisy = noisy.to(DEVICE)

        # ---- Discriminator step: real -> ~1 (smoothed), generated -> 0 ----
        fake = gen(noisy)
        real_scores = disc(clean)
        fake_scores = disc(fake.detach())  # detach: no generator gradients here
        # One-sided label smoothing: real targets are drawn from (0.9, 1.0].
        smoothed_ones = torch.ones_like(real_scores) - 0.1 * torch.rand_like(real_scores)
        loss_on_real = bce(real_scores, smoothed_ones)
        loss_on_fake = bce(fake_scores, torch.zeros_like(fake_scores))
        loss_disc = loss_on_fake + loss_on_real

        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # ---- Generator step: fool the updated discriminator + perceptual term ----
        fake_scores = disc(fake)
        adversarial_loss = 1e-3 * bce(fake_scores, torch.ones_like(fake_scores))
        perceptual_loss = 0.006 * vgg_loss(fake, clean)
        gen_loss = perceptual_loss + adversarial_loss

        opt_gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()

        if step % 200 == 0:
            plot_examples("/content/drive/MyDrive/test", gen)



In [None]:
def main():
    """Build the data pipeline, models, optimizers and losses, then train.

    Trains for ``NUM_EPOCHS`` epochs and, when ``SAVE_MODEL`` is set, writes
    generator and discriminator checkpoints after every epoch.
    """
    dataset = DenoiseDataset(
        train_noise='/content/drive/MyDrive/train/train',
        train_real='/content/drive/MyDrive/train_cleaned/train_cleaned',
        transform=transforms,
    )
    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=True,
        pin_memory=True,
        num_workers=8,
    )

    gen = Generator(in_channels=3).to(DEVICE)
    disc = Discriminator(in_channels=3).to(DEVICE)

    adam_settings = dict(lr=LEARNING_RATE, betas=(0.9, 0.999))
    opt_gen = optim.Adam(gen.parameters(), **adam_settings)
    opt_disc = optim.Adam(disc.parameters(), **adam_settings)

    mse = nn.MSELoss()
    bce = nn.BCEWithLogitsLoss()
    vgg_loss = VGGLoss()

    # Checkpoint restore is currently disabled:
    # if LOAD_MODEL:
    #     load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
    #     load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE)

    for _ in range(NUM_EPOCHS):
        train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss)

        if SAVE_MODEL:
            for model, optimizer, path in (
                (gen, opt_gen, CHECKPOINT_GEN),
                (disc, opt_disc, CHECKPOINT_DISC),
            ):
                save_checkpoint(model, optimizer, filename=path)

In [None]:
# Entry point: start training when this cell/script is executed directly.
if __name__ == "__main__":
  main()

    

  0%|          | 0/144 [00:00<?, ?it/s]

torch.Size([1, 32, 256, 256])
torch.Size([1, 32, 256, 256])
torch.Size([1, 32, 256, 256])
torch.Size([1, 3, 256, 256])


  0%|          | 0/144 [00:02<?, ?it/s]


ValueError: ignored

In [None]:
import tarfile
import os.path

def make_tarfile(output_filename, source_dir):
    """Pack ``source_dir`` into a gzip-compressed tar archive.

    Args:
        output_filename: Path of the archive to create.
        source_dir: Path to add; inside the archive it is stored under its
            base name rather than its full path.
    """
    archive_root = os.path.basename(source_dir)
    with tarfile.open(output_filename, "w:gz") as archive:
        archive.add(source_dir, arcname=archive_root)

In [None]:
# Packs the entire '/content/drive/MyDrive' directory into a gzip tar archive
# named 'gen.pth'. NOTE(review): despite the name, the result is a tarball of
# the whole folder, not a torch checkpoint — confirm this is intended.
make_tarfile( 'gen.pth', '/content/drive/MyDrive')

In [None]:
# Packs the entire '/content/drive/MyDrive' directory into a gzip tar archive
# named 'disc.pth'. NOTE(review): this archives the same source directory as
# the 'gen.pth' call and is not a torch checkpoint — confirm this is intended.
make_tarfile( 'disc.pth', '/content/drive/MyDrive')

In [None]:
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

In [None]:
import gc
# `del variables` raised NameError because no name `variables` exists in the
# notebook. Drop it only if it is defined so the cell survives a fresh
# Restart & Run All, then force a garbage-collection pass.
globals().pop("variables", None)
gc.collect()

NameError: ignored

In [None]:
# Show the full (non-abbreviated) CUDA allocator memory report; device=None
# leaves the device choice to PyTorch's default.
torch.cuda.memory_summary(device=None, abbreviated=False)

