In [None]:
# Environment setup for a TPU runtime.
# Cloud TPU client plus the torch_xla 1.9 wheel built for CPython 3.7 (cp37).
!pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl
# torch 1.9.0 / torchvision 0.10.0 / torchtext 0.10.0 pinned (presumably to match torch_xla 1.9 above).
!pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchtext==0.10.0 -f https://download.pytorch.org/whl/cu111/torch_stable.html
# NOTE(review): unpinned. If pip resolves a different torch/torchvision here (e.g. to satisfy a
# pytorch-lightning requirement), it can replace the 1.9.0 builds pinned above and stop matching
# torch_xla 1.9. Consider pinning pytorch-lightning and dropping torch/torchvision/torchaudio here.
!pip install torch torchvision torchaudio pytorch-lightning
# NOTE(review): requirements.txt is not part of this notebook; confirm it does not override the pins above.
!pip install -r requirements.txt

In [None]:
# Download and unpack the pix2pix "night2day" dataset archive.
# Later cells read it from /content/night2day/{train,val,test}; this presumably relies on the
# working directory being /content and the archive containing a top-level night2day/ folder.
!wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/night2day.tar.gz
!tar -xvf night2day.tar.gz

In [None]:
from PIL import Image
import numpy
import matplotlib.pyplot as plt
import cv2
import glob
import shutil
import os
import torch
import torch.optim as optim
from torch import nn
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
import torchvision
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms
from tqdm.notebook import tqdm
import matplotlib
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning import Trainer

In [None]:
# Optional reset: delete the extracted dataset so the download/extract cell can be re-run from
# scratch. Running this unconditionally deleted the data right after it was extracted, so
# DataModule.prepare_data() further down failed on a fresh top-to-bottom run. Set to True only
# when you deliberately want to wipe /content/night2day.
RESET_DATASET = False
if RESET_DATASET:
    shutil.rmtree('/content/night2day', ignore_errors=True)

In [None]:
class DataModule(pl.LightningDataModule):
    """LightningDataModule for the pix2pix ``night2day`` paired-image dataset.

    Every ``.jpg`` in a split directory holds two images side by side. Items are
    returned as ``(left_half, right_half)``; the GAN in this notebook unpacks them
    as ``(night, day)``. Both halves are resized to 256x256 and normalised to
    [-1, 1].

    Args:
        batch_size: ``(train_batch_size, val_batch_size)``, or a single int used
            for both. Defaults to 8 for training and 10 for validation.
        data_dir: root of the extracted dataset; must contain ``train``,
            ``test`` and ``val`` sub-directories.
        val_dir: directory under which ``night/`` and ``day/`` are created for
            images written during validation.
    """

    class _PairedImageDataset(Dataset):
        """Dataset over the side-by-side ``*.jpg`` images of one directory.

        Defined at class level (not inside ``setup``) so instances can be pickled,
        e.g. when data loading happens in other processes.
        """

        def __init__(self, path):
            # Sorted for a deterministic order (glob order is arbitrary).
            self.filenames = sorted(glob.glob(os.path.join(path, '*.jpg')))
            # Built once instead of on every __getitem__ call.
            self.transform = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ])

        def __len__(self):
            return len(self.filenames)

        def __getitem__(self, idx):
            filename = self.filenames[idx]
            image = cv2.imread(filename)
            if image is None:  # cv2.imread reports failure by returning None
                raise IOError(f'Could not read image file {filename!r}')
            # OpenCV decodes to BGR, whereas ToPILImage and the PIL images saved during
            # validation assume RGB; convert so the red and blue channels are not swapped.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            half_width = image.shape[1] // 2
            real = self.transform(image[:, :half_width, :])
            condition = self.transform(image[:, half_width:, :])
            return real, condition

    def __init__(self, batch_size=(8, 10), data_dir='/content/night2day', val_dir='/content/val'):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.val_dir = val_dir

    def _batch_sizes(self):
        """Return ``(train_batch_size, val_batch_size)`` from ``self.batch_size``."""
        if isinstance(self.batch_size, int):
            return self.batch_size, self.batch_size
        train_size, val_size = self.batch_size
        return train_size, val_size

    @staticmethod
    def _renumber_files(directory):
        """Rename every entry of ``directory`` to ``0.jpg``, ``1.jpg``, ... in sorted order.

        Renaming goes through temporary names in two passes. On POSIX ``os.rename``
        silently replaces an existing target, so renaming straight to ``<i>.jpg`` could
        overwrite a file that has not been renamed yet. That happens when the folder
        already holds numerically named files, e.g. when this hook is run a second time.
        """
        names = sorted(os.listdir(directory))
        staged = []
        for index, name in enumerate(names):
            tmp_path = os.path.join(directory, f'.renumber_tmp_{index}')
            os.rename(os.path.join(directory, name), tmp_path)
            staged.append(tmp_path)
        for index, tmp_path in enumerate(staged):
            os.rename(tmp_path, os.path.join(directory, f'{index}.jpg'))

    def prepare_data(self):
        """Normalise file names in the ``train``, ``test`` and ``val`` splits.

        Nothing is downloaded here; the dataset is fetched and extracted by the shell
        cell earlier in the notebook.
        """
        for split in ('train', 'test', 'val'):
            self._renumber_files(os.path.join(self.data_dir, split))

    def setup(self, stage=None):
        """Build the train/val datasets and the validation output folders.

        ``stage`` is accepted because Lightning passes it; it is not used.
        """
        self.train_dataset = self._PairedImageDataset(os.path.join(self.data_dir, 'train'))
        self.val_dataset = self._PairedImageDataset(os.path.join(self.data_dir, 'val'))
        # exist_ok so that calling setup() more than once does not raise FileExistsError.
        os.makedirs(os.path.join(self.val_dir, 'night'), exist_ok=True)
        os.makedirs(os.path.join(self.val_dir, 'day'), exist_ok=True)

    def train_dataloader(self):
        train_size, _ = self._batch_sizes()
        return DataLoader(self.train_dataset, batch_size=train_size, shuffle=True)

    def val_dataloader(self):
        _, val_size = self._batch_sizes()
        return DataLoader(self.val_dataset, batch_size=val_size, shuffle=False)

In [None]:
class Block_conv(nn.Module):
  """3x3 convolution (padding 1) -> ReLU -> optional 2x2 max-pool -> optional BatchNorm.

  Args:
    inp: number of input channels.
    out: number of output channels.
    pool: if True, halve height and width with a 2x2, stride-2 max-pool.
    norm: if True, finish with BatchNorm2d(out).

  All four steps live in the nn.Sequential ``self.conv``; disabled steps are
  nn.Identity. The conv parameters are therefore stored as ``conv.0.*`` and the
  batch-norm ones as ``conv.3.*``.
  """

  def __init__(self, inp, out, pool=True, norm=True):
    super().__init__()
    layers = [
        nn.Conv2d(inp, out, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    ]
    layers.append(nn.MaxPool2d(2, stride=2) if pool else nn.Identity())
    layers.append(nn.BatchNorm2d(out) if norm else nn.Identity())
    self.conv = nn.Sequential(*layers)

  def forward(self, x):
    y = self.conv(x)
    return y

In [None]:
class Discriminator(nn.Module):
  """Image discriminator: Block_conv stages, a global max-pool and an MLP head.

  Each Block_conv stage uses its defaults (conv, ReLU, 2x2 max-pool, BatchNorm),
  so every stage halves height and width.

  Args:
    features: output channels of the successive Block_conv stages.

  forward(x) takes a batch of 3-channel images and returns a (batch, 1) tensor of
  sigmoid scores in (0, 1).
  """
  def __init__(self, features=(64, 128, 256, 512)):
    super().__init__()
    stages = []
    in_channels = 3
    for out_channels in features:
      stages.append(Block_conv(in_channels, out_channels))
      in_channels = out_channels
    self.conv = nn.Sequential(*stages)
    # Global max-pool to 1x1. With 256x256 inputs and four stages the feature map is
    # 16x16 here, so this equals the former fixed MaxPool2d(16). Unlike that, it also
    # works for other input sizes and numbers of stages.
    self.pool = nn.AdaptiveMaxPool2d(1)
    # Head input width follows the last stage's channel count instead of a hard-coded 512.
    self.linear = nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_channels, 256),
        nn.ReLU(inplace=True),
        nn.Dropout(p=0.5, inplace=False),
        nn.Linear(256, 64),
        nn.ReLU(inplace=True),
        nn.Linear(64, 16),
        nn.ReLU(inplace=True),
        nn.Linear(16, 1),
    )

  def forward(self, x):
    x = self.conv(x)
    x = self.pool(x)
    x = self.linear(x)
    return torch.sigmoid(x)

In [None]:
class Generator(nn.Module):
  """U-Net-style image-to-image generator (3 channels in, 3 channels out).

  Encoder: three stages, each a Block_conv with BatchNorm followed by one without;
  neither pools. The stages are separated by 2x2 max-pooling.
  Bottleneck: two 1x1 conv + ReLU layers at 1/8 of the input resolution.
  Decoder: per stage, 2x upsampling, concatenation with the matching encoder output
  (skip connection), then two Block_conv layers.
  The output is tanh-activated, i.e. in (-1, 1).

  Input height and width must be divisible by 8 so the skip concatenations line up.

  Args:
    features: channel widths of the three encoder stages (only the first three
      entries are used).
  """
  def __init__(self, features=(64, 128, 256)):
    super().__init__()
    # Sizes in the comments below are for 256x256 inputs.
    self.enc_conv0 = nn.Sequential(
        Block_conv(3, features[0], False),
        Block_conv(features[0], features[0], False, False)
    )
    self.pool = nn.MaxPool2d(2, stride=2)        # halves H and W: 256 -> 128 -> 64 -> 32
    self.upsample = nn.Upsample(scale_factor=2)  # doubles H and W: 32 -> 64 -> 128 -> 256
    self.enc_conv1 = nn.Sequential(              # runs at 128x128
        Block_conv(features[0], features[1], False),
        Block_conv(features[1], features[1], False, False)
    )
    self.enc_conv2 = nn.Sequential(              # runs at 64x64
        Block_conv(features[1], features[2], False),
        Block_conv(features[2], features[2], False, False)
    )

    # bottleneck, runs at 32x32
    self.bottleneck_conv = nn.Sequential(
        nn.Conv2d(features[2], features[2], 1),
        nn.ReLU(inplace=True),
        nn.Conv2d(features[2], features[2], 1),
        nn.ReLU(inplace=True)
    )

    # decoder: each stage consumes [skip, upsampled] concatenated along channels
    self.dec_conv0 = nn.Sequential(              # 64x64
        Block_conv(features[2] * 2, features[2], False),
        Block_conv(features[2], features[1], False, False)
    )
    self.dec_conv1 = nn.Sequential(              # 128x128
        Block_conv(features[1] * 2, features[1], False),
        Block_conv(features[1], features[0], False, False)
    )
    self.dec_conv2 = nn.Sequential(              # 256x256
        Block_conv(features[0] * 2, features[0], False),
        Block_conv(features[0], 3, False, False)
    )
    # Block_conv always applies a ReLU after its conv. As the very last layer before tanh
    # that clamps the pre-activation to >= 0, so the output could only lie in [0, 1) of the
    # [-1, 1] range images are normalised to (never darker than mid-grey once un-normalised).
    # Replace that ReLU (index 1 of Block_conv.conv) with Identity. Neither module has
    # parameters, so state_dict keys are unchanged and existing checkpoints still load.
    self.dec_conv2[1].conv[1] = nn.Identity()

  def forward(self, x):
    # encoder
    e0 = self.enc_conv0(x)
    e1 = self.enc_conv1(self.pool(e0))
    e2 = self.enc_conv2(self.pool(e1))

    # bottleneck
    b = self.bottleneck_conv(self.pool(e2))

    # decoder with skip connections
    d0 = self.dec_conv0(torch.cat((e2, self.upsample(b)), dim=1))
    d1 = self.dec_conv1(torch.cat((e1, self.upsample(d0)), dim=1))
    d2 = self.dec_conv2(torch.cat((e0, self.upsample(d1)), dim=1))
    return torch.tanh(d2)

In [None]:
class GAN(pl.LightningModule):
    """CycleGAN-style model translating between night and day images.

    Args:
        disc_N: discriminator scoring night images.
        disc_D: discriminator scoring day images.
        gen_D: generator producing day images from night images.
        gen_N: generator producing night images from day images.

    Uses manual optimisation. Each training step first updates both discriminators,
    then both generators.
    """

    # Weight of each L1 cycle-consistency term in the generator loss.
    CYCLE_LOSS_WEIGHT = 10

    def __init__(self, disc_N, disc_D, gen_D, gen_N):
        super().__init__()
        self.disc_N = disc_N
        self.disc_D = disc_D
        self.gen_D = gen_D
        self.gen_N = gen_N
        self.L1 = nn.L1Loss()
        self.mse = nn.MSELoss()
        # Lowest validation generator loss seen so far (per batch). Starts huge, so the
        # first validation batch always writes the "best" checkpoints.
        self.g_loss_glob = 1e9
        # Both optimisers are stepped by hand in training_step.
        self.automatic_optimization = False

    def _discriminator_loss(self, disc, real, fake):
        """MSE loss pushing ``disc(real)`` towards 1 and ``disc(fake)`` towards 0.

        ``fake`` is detached so this loss does not reach the generator.
        """
        pred_real = disc(real)
        pred_fake = disc(fake.detach())
        return (self.mse(pred_real, torch.ones_like(pred_real))
                + self.mse(pred_fake, torch.zeros_like(pred_fake)))

    def _generator_loss(self, night, day, fake_night, fake_day):
        """Adversarial MSE (fakes vs. all-ones targets) plus weighted L1 cycle losses."""
        D_D_fake = self.disc_D(fake_day)
        D_N_fake = self.disc_N(fake_night)
        loss_G_D = self.mse(D_D_fake, torch.ones_like(D_D_fake))
        loss_G_N = self.mse(D_N_fake, torch.ones_like(D_N_fake))

        cycle_night = self.gen_N(fake_day)
        cycle_day = self.gen_D(fake_night)
        cycle_night_loss = self.L1(night, cycle_night)
        cycle_day_loss = self.L1(day, cycle_day)

        return (
            loss_G_D
            + loss_G_N
            + cycle_night_loss * self.CYCLE_LOSS_WEIGHT
            + cycle_day_loss * self.CYCLE_LOSS_WEIGHT
        )

    @staticmethod
    def _to_uint8_images(images):
        """Convert a (N, C, H, W) tensor normalised to [-1, 1] into a (N, H, W, C) uint8 array."""
        # .float() because numpy cannot represent reduced-precision tensors such as bfloat16.
        array = images.detach().float().cpu().numpy()
        array = np.clip((array * 0.5 + 0.5) * 255, 0, 255).astype(np.uint8)
        return array.transpose(0, 2, 3, 1)

    def _save_generators(self, path_n, path_d):
        """Save CPU copies of both generators' state dicts, loadable without an accelerator."""
        for module, path in ((self.gen_N, path_n), (self.gen_D, path_d)):
            torch.save({k: v.detach().cpu() for k, v in module.state_dict().items()}, path)

    def training_step(self, batch, batch_idx):
        """One discriminator update followed by one generator update."""
        night, day = batch
        opt_disc, opt_gen = self.optimizers()

        fake_night = self.gen_N(day)
        D_N_loss = self._discriminator_loss(self.disc_N, night, fake_night)
        fake_day = self.gen_D(night)
        D_D_loss = self._discriminator_loss(self.disc_D, day, fake_day)
        D_loss = (D_N_loss + D_D_loss) / 2

        opt_disc.zero_grad()
        # manual_backward (not loss.backward) lets Lightning apply its precision/accelerator handling.
        self.manual_backward(D_loss)
        opt_disc.step()

        # The generator backward also leaves gradients on the discriminators; they are
        # cleared by opt_disc.zero_grad() before the next discriminator step.
        G_loss = self._generator_loss(night, day, fake_night, fake_day)
        opt_gen.zero_grad()
        self.manual_backward(G_loss)
        opt_gen.step()

        self.log_dict({"G_loss": G_loss, "D_loss": D_loss}, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Translate a batch, save images, checkpoint the generators and log the generator loss.

        Images go to ``/content/val/{night,day}/<epoch>/``: ``tr<i>.jpg`` is the real
        image and ``l<i>.jpg`` the generated one (fake night under ``night/``, fake day
        under ``day/``).
        """
        night, day = batch

        fake_night = self.gen_N(day)
        fake_day = self.gen_D(night)

        place_n = '/content/val/night/' + str(self.current_epoch)
        place_d = '/content/val/day/' + str(self.current_epoch)
        # Create each directory on its own. A single try/except around both mkdir calls
        # skipped the second one whenever the first already existed.
        os.makedirs(place_n, exist_ok=True)
        os.makedirs(place_d, exist_ok=True)

        # NOTE(review): file names depend only on the index within the batch, so later
        # batches (and other processes) overwrite earlier ones for the same epoch.
        images_by_prefix = {
            place_d + '/tr': self._to_uint8_images(day),
            place_n + '/tr': self._to_uint8_images(night),
            place_d + '/l': self._to_uint8_images(fake_day),
            place_n + '/l': self._to_uint8_images(fake_night),
        }
        for prefix, images in images_by_prefix.items():
            for i, image in enumerate(images):
                Image.fromarray(image).save(prefix + str(i) + '.jpg')

        G_loss = self._generator_loss(night, day, fake_night, fake_day)
        g_loss_value = float(G_loss)
        # Only one process writes checkpoints; with several devices every rank would
        # otherwise race on the same files.
        if self.global_rank == 0:
            # "Best" is judged on this single batch's loss, not an epoch average.
            if g_loss_value < self.g_loss_glob:
                self._save_generators("./gen_N_best.pth", "./gen_D_best.pth")
                self.g_loss_glob = g_loss_value
            self._save_generators("./gen_N.pth", "./gen_D.pth")
        # Separate key so validation values are not mixed into the training "G_loss" series.
        self.log_dict({"val_G_loss": G_loss}, prog_bar=True)

    def configure_optimizers(self):
        """Adam (lr 1e-4) for both discriminators and for both generators, returned as (disc, gen)."""
        opt_gen = torch.optim.Adam(list(self.gen_N.parameters()) + list(self.gen_D.parameters()), lr=1e-4)
        opt_disc = torch.optim.Adam(list(self.disc_N.parameters()) + list(self.disc_D.parameters()), lr=1e-4)
        return opt_disc, opt_gen

In [None]:
# Seed Python, NumPy and torch RNGs so weight initialisation and data shuffling are reproducible.
pl.seed_everything(42)

disc_N = Discriminator()  # scores night images
disc_D = Discriminator()  # scores day images
gen_D = Generator()       # night -> day
gen_N = Generator()       # day -> night
dm = DataModule()
dm.prepare_data()  # renames the dataset files in place
dm.setup()

In [None]:
# Resume from the "best" generator checkpoints written by GAN.validation_step, if they exist.
# These files hold *generator* state dicts, so they must go into the generators. Loading them
# into the discriminators fails with missing/unexpected-key errors, and on a fresh run the files
# do not exist yet. map_location='cpu' makes checkpoints saved from an accelerator loadable here.
for generator, ckpt_path in ((gen_N, '/content/gen_N_best.pth'), (gen_D, '/content/gen_D_best.pth')):
    if os.path.exists(ckpt_path):
        generator.load_state_dict(torch.load(ckpt_path, map_location='cpu'))

In [None]:
# Wrap the four networks in the Lightning module and train on 8 TPU cores.
# dm.setup() was already called above, so Lightning does not call it again (see the warning in the output).
# NOTE(review): precision=16 on TPU produces the AMP warning shown in the output; confirm which
# reduced-precision mode this Lightning version actually uses on TPU.
model=GAN(disc_N,disc_D,gen_D,gen_N)
trainer = pl.Trainer(tpu_cores=8, precision=16)
trainer.fit(model,dm)

  f"You passed `Trainer(accelerator='tpu', precision=16)` but {self.amp_type.value} AMP"
GPU available: False, used: False
TPU available: True, using: 8 TPU cores
IPU available: False, using: 0 IPUs
  f"DataModule.{name} has already been called, so it will not be called again. "
Missing logger folder: /content/lightning_logs

  | Name   | Type          | Params
-----------------------------------------
0 | disc_N | Discriminator | 1.7 M 
1 | disc_D | Discriminator | 1.7 M 
2 | gen_D  | Generator     | 3.2 M 
3 | gen_N  | Generator     | 3.2 M 
4 | L1     | L1Loss        | 0     
5 | mse    | MSELoss       | 0     
-----------------------------------------
9.8 M     Trainable params
0         Non-trainable params
9.8 M     Total params
39.198    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]