In [None]:
from PIL import Image
import numpy
import matplotlib.pyplot as plt
import cv2
import glob
import shutil
import os
import torch
import torch.optim as optim
from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
import torchvision
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms
from tqdm.notebook import tqdm
import matplotlib
import numpy as np

In [None]:
!wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/night2day.tar.gz
!tar -xvf night2day.tar.gz

In [None]:
def _renumber_jpgs(directory):
    """Rename every entry of `directory` to 0.jpg, 1.jpg, ... in sorted order.

    Files are first moved to temporary names and only then to their final
    "{i}.jpg" names. Renaming straight to "{i}.jpg" can silently overwrite a
    file that already has that name (os.rename replaces its destination on
    POSIX), e.g. when this cell is re-run.
    """
    old_names = sorted(os.listdir(directory))  # sorted -> reproducible mapping
    tmp_paths = []
    for i, name in enumerate(old_names):
        tmp_path = os.path.join(directory, f'.renumber_tmp_{i}')
        os.rename(os.path.join(directory, name), tmp_path)
        tmp_paths.append(tmp_path)
    for i, tmp_path in enumerate(tmp_paths):
        os.rename(tmp_path, os.path.join(directory, f'{i}.jpg'))


for split in ('train', 'test', 'val'):
    _renumber_jpgs(f'/content/night2day/{split}/')

In [None]:
class data(Dataset):
    """Paired images stored side by side in one JPEG (pix2pix layout).

    Each file is split vertically down the middle. The left half is returned
    first ("real") and the right half second ("condition"). The training loop
    unpacks these as `night, day`.

    Args:
        path: directory containing the *.jpg files (trailing slash optional).
        transform: optional callable applied to each half, which is an
            H x W x 3 uint8 numpy array in RGB order. Defaults to resize to
            256x256, ToTensor and Normalize(0.5, 0.5), giving a float tensor
            in [-1, 1].
    """

    def __init__(self, path='/content/night2day/train/', transform=None):
        # Sorted so that index -> file is reproducible across runs/machines.
        self.filenames = sorted(glob.glob(os.path.join(path, '*.jpg')))
        # Built once here instead of on every __getitem__ call.
        self.transform = transform if transform is not None else self._default_transform()

    @staticmethod
    def _default_transform():
        """Resize to 256x256 and scale pixel values to [-1, 1]."""
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filename = self.filenames[idx]
        image = cv2.imread(filename)
        if image is None:
            # cv2.imread reports unreadable or missing files by returning None.
            raise IOError(f"could not read image: {filename}")
        # OpenCV decodes to BGR; convert so tensors (and anything plotted or
        # saved from them) are in RGB channel order.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        half = image.shape[1] // 2
        real = image[:, :half, :]
        condition = image[:, half:, :]
        return self.transform(real), self.transform(condition)

In [None]:
# Training pairs are reshuffled every epoch. Validation order is fixed so the
# snapshot saved during validation shows the same images every epoch, making
# progress comparable across epochs.
train_dataset = data()
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

val_dataset = data(path='/content/night2day/val/')
val_loader = DataLoader(val_dataset, batch_size=10, shuffle=False)

In [None]:
# Sanity check: size of a single first-position image (C, H, W) from one training batch.
night_batch, _ = next(iter(train_loader))
night_batch[0].size()

torch.Size([3, 256, 256])

In [None]:
# Preview one training image. Undo Normalize(0.5, 0.5) to map [-1, 1] back to
# [0, 1]; imshow accepts float images in that range. Channel order is whatever
# `data.__getitem__` produces.
sample_image = next(iter(train_loader))[0][0]
preview = (sample_image.permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5).clip(0, 1)
fig, ax = plt.subplots()
ax.imshow(preview)
ax.set_title("Training sample (first image of a batch)")
ax.axis("off");

In [None]:
# Phase name -> DataLoader. The training loop treats "train" specially and runs any other key as validation.
loader = dict(train=train_loader, validate=val_loader)

In [None]:
class Block_conv(nn.Module):
  """Conv3x3 (padding 1) -> ReLU -> optional 2x2 max-pool -> optional BatchNorm.

  Args:
    inp: number of input channels.
    out: number of output channels.
    pool: if True, halve the spatial size with MaxPool2d(2, stride=2).
    norm: if True, finish with BatchNorm2d(out).
  """
  def __init__(self, inp, out, pool=True, norm=True):
    super().__init__()
    # Disabled stages become nn.Identity so the Sequential indices (and hence
    # state_dict keys such as "conv.3.weight") stay the same in every config.
    stages = [
        nn.Conv2d(inp, out, 3, padding=1),
        nn.ReLU(inplace=True),
    ]
    stages.append(nn.MaxPool2d(2, stride=2) if pool else nn.Identity())
    stages.append(nn.BatchNorm2d(out) if norm else nn.Identity())
    self.conv = nn.Sequential(*stages)

  def forward(self, x):
    return self.conv(x)

In [None]:
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


In [None]:
class Discriminator(nn.Module):
  """Image-level real/fake scorer used by the CycleGAN training loop.

  A stack of Block_conv stages (each halving the spatial size), a global
  max-pool, and a small MLP ending in a sigmoid. The output has shape
  (batch, 1) with values in (0, 1).

  Args:
    features: output channels of each Block_conv stage. The input is expected
      to have 3 channels.
  """
  def __init__(self, features=(64, 128, 256, 512)):
    super().__init__()
    features = list(features)
    if not features:
      raise ValueError("features must contain at least one stage width")
    layers = []
    in_ch = 3
    for out_ch in features:
      layers.append(Block_conv(in_ch, out_ch))
      in_ch = out_ch
    self.conv = nn.Sequential(*layers)
    # Global max-pool to 1x1. For 256x256 inputs this equals the former
    # MaxPool2d(16) (four 2x downsamplings -> 16x16), but it also works at
    # other resolutions.
    self.pool = nn.AdaptiveMaxPool2d(1)
    self.linear = nn.Sequential(
        nn.Flatten(),
        nn.Linear(features[-1], 256),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0.5, inplace=False),
        nn.Linear(256, 64),
        nn.ReLU(inplace=True),
        nn.Linear(64, 16),
        nn.ReLU(inplace=True),
        nn.Linear(16, 1),
    )

  def forward(self, x):
    # Follow the module's own device rather than a notebook-global variable.
    x = x.to(next(self.parameters()).device)
    x = self.conv(x)
    x = self.pool(x)
    x = self.linear(x)
    # NOTE(review): the loop trains this with MSE against 0/1 targets (LSGAN
    # style); LSGAN usually uses the raw score without a sigmoid. Kept as is.
    return torch.sigmoid(x)

In [None]:
class Generator(nn.Module):
  """U-Net-style image-to-image generator (3 -> 3 channels).

  Three encoder stages separated by 2x max-pooling and a 1x1-conv bottleneck
  at 1/8 resolution. Three decoder stages each upsample 2x and concatenate the
  matching encoder output (skip connection). The output goes through tanh, so
  it lies in (-1, 1), the same range as inputs normalised with
  Normalize(0.5, 0.5).

  Args:
    features: channel widths of the three encoder stages.

  Input height and width must be divisible by 8 so the upsampled decoder maps
  line up with the skip connections.
  """
  def __init__(self, features=(64, 128, 256)):
    super().__init__()
    self.enc_conv0 = nn.Sequential(
        Block_conv(3, features[0], False),
        Block_conv(features[0], features[0], False, False),
    )
    self.pool = nn.MaxPool2d(2, stride=2)
    self.upsample = nn.Upsample(scale_factor=2)
    self.enc_conv1 = nn.Sequential(
        Block_conv(features[0], features[1], False),
        Block_conv(features[1], features[1], False, False),
    )
    self.enc_conv2 = nn.Sequential(
        Block_conv(features[1], features[2], False),
        Block_conv(features[2], features[2], False, False),
    )

    # bottleneck
    self.bottleneck_conv = nn.Sequential(
        nn.Conv2d(features[2], features[2], 1),
        nn.ReLU(inplace=True),
        nn.Conv2d(features[2], features[2], 1),
        nn.ReLU(inplace=True),
    )

    # Decoder inputs are doubled in width by the skip-connection concatenation.
    self.dec_conv0 = nn.Sequential(
        Block_conv(features[2] * 2, features[2], False),
        Block_conv(features[2], features[1], False, False),
    )
    self.dec_conv1 = nn.Sequential(
        Block_conv(features[1] * 2, features[1], False),
        Block_conv(features[1], features[0], False, False),
    )
    self.dec_conv2 = nn.Sequential(
        Block_conv(features[0] * 2, features[0], False),
        # Plain conv with no ReLU. Block_conv would apply ReLU here, clamping
        # the pre-activation to >= 0, so tanh could only give [0, 1) and the
        # generator could never produce the negative half of the pixel range.
        nn.Conv2d(features[0], 3, 3, padding=1),
    )

  def forward(self, x):
    # Follow the module's own device rather than a notebook-global variable.
    x = x.to(next(self.parameters()).device)
    # encoder
    e0 = self.enc_conv0(x)
    e1 = self.enc_conv1(self.pool(e0))
    e2 = self.enc_conv2(self.pool(e1))

    # bottleneck
    b = self.bottleneck_conv(self.pool(e2))

    # decoder with skip connections
    d0 = self.dec_conv0(torch.cat((e2, self.upsample(b)), dim=1))
    d1 = self.dec_conv1(torch.cat((e1, self.upsample(d0)), dim=1))
    d2 = self.dec_conv2(torch.cat((e0, self.upsample(d1)), dim=1))
    return torch.tanh(d2)

In [None]:
# Output folders for the per-epoch validation snapshots. exist_ok=True keeps
# this cell re-runnable; os.mkdir raised FileExistsError on a second run, e.g.
# after a kernel restart on the same machine.
for sub in ('night', 'day'):
    os.makedirs(os.path.join('/content/val', sub), exist_ok=True)

In [None]:
# Separate mixed-precision gradient scalers for the discriminator and the
# generator updates, so each keeps its own loss-scale state.
d_scaler = torch.cuda.amp.GradScaler()
g_scaler = torch.cuda.amp.GradScaler()
# MSE is used for the adversarial terms, L1 for cycle consistency.
mse = nn.MSELoss()
L1 = nn.L1Loss()

In [None]:
# disc_N scores night images and disc_D day images; gen_N maps day -> night
# and gen_D maps night -> day. Tuples are evaluated left to right, so the
# construction (and random initialisation) order is disc_N, disc_D, gen_D, gen_N.
disc_N, disc_D = Discriminator().to(device), Discriminator().to(device)
gen_D, gen_N = Generator().to(device), Generator().to(device)

In [None]:
# One Adam optimiser updates both discriminators, another both generators.
disc_params = [*disc_N.parameters(), *disc_D.parameters()]
gen_params = [*gen_N.parameters(), *gen_D.parameters()]
opt_disc = optim.Adam(disc_params, lr=1e-4)
opt_gen = optim.Adam(gen_params, lr=1e-4)

In [None]:
max_epochs = 7
CYCLE_WEIGHT = 10  # weight of each L1 cycle-consistency term vs. the adversarial terms


def _to_uint8_image(t):
    """Convert a CHW tensor normalised to [-1, 1] into an HWC uint8 numpy array."""
    # Undo Normalize(mean=0.5, std=0.5). Clamp first so values slightly outside
    # [-1, 1] cannot wrap around in the uint8 cast.
    scaled = (t.detach().float().clamp(-1, 1) * 0.5 + 0.5) * 255
    return scaled.permute(1, 2, 0).cpu().numpy().astype(np.uint8)


def _train_step(night, day):
    """One CycleGAN update on a batch: both discriminators, then both generators.

    Returns (discriminator_loss, generator_loss) as Python floats.
    """
    # Discriminator update: real images -> 1, detached fakes -> 0.
    with torch.cuda.amp.autocast():
        fake_night = gen_N(day)
        D_N_real = disc_N(night)
        D_N_fake = disc_N(fake_night.detach())
        D_N_loss = (mse(D_N_real, torch.ones_like(D_N_real))
                    + mse(D_N_fake, torch.zeros_like(D_N_fake)))

        fake_day = gen_D(night)
        D_D_real = disc_D(day)
        D_D_fake = disc_D(fake_day.detach())
        D_D_loss = (mse(D_D_real, torch.ones_like(D_D_real))
                    + mse(D_D_fake, torch.zeros_like(D_D_fake)))

        D_loss = (D_N_loss + D_D_loss) / 2
    opt_disc.zero_grad()
    d_scaler.scale(D_loss).backward()
    d_scaler.step(opt_disc)
    d_scaler.update()

    # Generator update: fool the just-updated discriminators and reconstruct
    # each input after a round trip through both generators.
    with torch.cuda.amp.autocast():
        D_D_fake = disc_D(fake_day)
        D_N_fake = disc_N(fake_night)
        loss_G_D = mse(D_D_fake, torch.ones_like(D_D_fake))
        loss_G_N = mse(D_N_fake, torch.ones_like(D_N_fake))

        cycle_night = gen_N(fake_day)
        cycle_day = gen_D(fake_night)
        cycle_night_loss = L1(night, cycle_night)
        cycle_day_loss = L1(day, cycle_day)

        G_loss = (loss_G_D
                  + loss_G_N
                  + cycle_night_loss * CYCLE_WEIGHT
                  + cycle_day_loss * CYCLE_WEIGHT)
    # Backward and optimiser step run outside autocast, per the AMP recipe.
    opt_gen.zero_grad()
    g_scaler.scale(G_loss).backward()
    g_scaler.step(opt_gen)
    g_scaler.update()
    return D_loss.item(), G_loss.item()


@torch.no_grad()
def _save_val_snapshot(night, day, epoch):
    """Save the first validation pair and both translations as JPEGs for `epoch`."""
    # Generators are in eval mode here, so translating only the first sample
    # gives the same result as translating the whole batch and taking index 0.
    with torch.cuda.amp.autocast():
        fake_day = gen_D(night[:1])[0]
        fake_night = gen_N(day[:1])[0]
    snapshots = {
        '/content/val/day/' + 'tr' + str(epoch) + '.jpg': day[0],
        '/content/val/night/' + 'tr' + str(epoch) + '.jpg': night[0],
        '/content/val/day/' + str(epoch) + '.jpg': fake_day,
        '/content/val/night/' + str(epoch) + '.jpg': fake_night,
    }
    for out_path, tensor in snapshots.items():
        Image.fromarray(_to_uint8_image(tensor)).save(out_path)


for epoch in range(max_epochs):
    for k, dataloader in loader.items():
        if k == "train":
            # The validation phase leaves the generators in eval mode; switch
            # back so BatchNorm uses batch statistics while training.
            gen_D.train()
            gen_N.train()
            progress = tqdm(dataloader, leave=False, desc=f"{k} iter:")
            for night, day in progress:
                d_loss, g_loss = _train_step(night.to(device), day.to(device))
                progress.set_postfix(D=f"{d_loss:.3f}", G=f"{g_loss:.3f}")
            print(f"Epoch: {epoch+1}")
        else:
            gen_D.eval()
            gen_N.eval()
            # Only the first batch feeds the snapshot, so the rest of the split
            # is not run through the generators.
            night, day = next(iter(dataloader))
            _save_val_snapshot(night.to(device), day.to(device), epoch)
        print(f"Loader: {k}")

train iter::   0%|          | 0/2228 [00:00<?, ?it/s]

KeyboardInterrupt: ignored