In [1]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


# Imports

In [2]:
# Import necessary packages
import os 
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensor
from torch.types import Number
import sys
from torch.utils.data import DataLoader
import torch.optim as optim
from tqdm import tqdm
from torchvision.utils import save_image
import albumentations as A

# Constant Variables

In [14]:
GROUND_PATH = "/content/drive/MyDrive/Courses/York Courses/Neural Networks and Deep Learning/Project/DataSet/testA"
SEGMENT_PATH = "/content/drive/MyDrive/Courses/York Courses/Neural Networks and Deep Learning/Project/DataSet/testB"

GROUND_PATH_TEST = "/content/drive/MyDrive/Courses/York Courses/Neural Networks and Deep Learning/Project/DataSet/testA"
SEGMENT_PATH_TEST = "/content/drive/MyDrive/Courses/York Courses/Neural Networks and Deep Learning/Project/DataSet/testB"


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-5
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10
WEIGHTS_GEN_G = "/content/drive/MyDrive/Courses/York Courses/Neural Networks and Deep Learning/Project/Weights/GEN_G"
WEIGHTS_GEN_S = "/content/drive/MyDrive/Courses/York Courses/Neural Networks and Deep Learning/Project/Weights/GEN_S"
WEIGHTS_DISC_G = "/content/drive/MyDrive/Courses/York Courses/Neural Networks and Deep Learning/Project/Weights/DISC_G"
WEIGHTS_DISC_S = "/content/drive/MyDrive/Courses/York Courses/Neural Networks and Deep Learning/Project/Weights/DISC_S"
SAVE_IMAGE = "/content/drive/MyDrive/Courses/York Courses/Neural Networks and Deep Learning/Project/Images"


transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensor(),
     ],
    additional_targets={"image": "image",
                        'image1': 'image',},
)

transforms2 = A.Compose(
    [
        A.Resize(width=256, height=256),
        ToTensor(),
     ],
)
BATCH_SIZE = 1
NUM_WORKERS = 4
NUM_EPOCHS = 1000

In [None]:
NUM_EPOCHS = 1000

# Functions

In [15]:
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    # optimizer.load_state_dict(checkpoint["optimizer"])

    # for param_group in optimizer.param_groups:
    #     param_group["lr"] = lr


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)



# DataSet

In [16]:

class CityscapesDataSet(Dataset):
    def __init__(self, SegmentationPath, GroundImagePath, transform=None):
        self.GroundImagePath = GroundImagePath
        self.SegmentationPath = SegmentationPath
        self.transform = transform

        self.GroundImages = os.listdir(GroundImagePath)
        self.SegmentImages = os.listdir(SegmentationPath)
        self.length_dataset = max(len(self.GroundImages), len(self.SegmentImages)) 
        self.Gr_len = len(self.GroundImages)
        self.Sg_len = len(self.SegmentImages)

    def __len__(self):
        return self.length_dataset

    def __getitem__(self, index):
        Ground = self.GroundImages[index % self.Gr_len]
        Segment = self.SegmentImages[index % self.Sg_len]


        Ground = np.array(Image.open(os.path.join(self.GroundImagePath, Ground)).convert("RGB"))
        Segment = np.array(Image.open(os.path.join(self.SegmentationPath, Segment)).convert("RGB"))

        if self.transform:
            Ground = self.transform(image=Ground)
            Segment = self.transform(image=Segment)
            Ground = Ground['image']
            Segment = Segment['image']

        return Segment, Ground

# Generator

In [17]:
class ConvBlock(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, padding, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
          self.conv1 = nn.Conv2d(in_channel, out_channel, padding_mode="reflect", kernel_size=kernel_size, padding=padding, **kwargs)
        else:
          self.conv1 = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=kernel_size, padding=padding, **kwargs)
        self.IN = nn.InstanceNorm2d(out_channel)
        if use_act:
          self.acv = nn.ReLU(inplace=True)
        else:
          self.acv = nn.Identity()
    def forward(self, x):
      x = self.conv1(x)
      x = self.IN(x)
      return  self.acv(x)

In [18]:
class ResBlock(nn.Module):
  def __init__(self, channels):
    super().__init__()
    self.Conv1 = ConvBlock(channels, channels, kernel_size=3, padding=1)
    self.Conv2 = ConvBlock(channels, channels, kernel_size=3, padding=1, use_act=False)
  
  def forward(self, x):
    x2 = self.Conv1(x)
    x2 = self.Conv2(x2)
    return x + x2

In [19]:

class Generator(nn.Module):
  def __init__(self, img_channels, num_features=64, num_residual=9):
    super().__init__()
    self.conv1 = nn.Conv2d(img_channels, num_features, kernel_size=7, stride=1, padding=3, padding_mode="reflect")
    self.acv1 = nn.ReLU(inplace=True)

    self.DownBlock = nn.ModuleList([
        nn.Conv2d(num_features, num_features*2, kernel_size=3, stride=2, padding=1),
        nn.Conv2d(num_features*2, num_features*4, kernel_size=3, stride=2, padding=1),
    ])

    self.residual_blocks = nn.Sequential(
        *[ResBlock(num_features*4) for _ in range(num_residual)]
    )

    self.up_blocks = nn.ModuleList([
        ConvBlock(num_features*4, num_features*2, kernel_size=3, stride=2, padding=1, output_padding=1, down=False),
        ConvBlock(num_features*2, num_features*1, kernel_size=3, stride=2, padding=1, output_padding=1, down=False),
    ])

    self.last_layer = nn.Conv2d(num_features, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

  def forward(self, x):
    x = self.conv1(x)
    x = self.acv1(x)
    for layer in self.DownBlock:
      x = layer(x)

    for layer in self.up_blocks:
      x = layer(x)
    return torch.tanh(self.last_layer(x))
    


# Discriminator

In [20]:
class DiscBlock(nn.Module):
  def __init__(self, in_channel, out_channel, stride):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channel, out_channel, 4, stride, 1, bias=True, padding_mode="reflect")
    self.IN1 = nn.InstanceNorm2d(out_channel)
    self.act1 = nn.LeakyReLU(0.2)
  def forward(self, x):
    x = self.conv1(x)
    x = self.IN1(x)
    x = self.act1(x)
    return x

In [21]:
class Discriminator(nn.Module):
  def __init__(self, in_channel, num_features=[64, 128, 256, 512]):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channel, num_features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect")
    self.act1 = nn.LeakyReLU(0.2)

    layers = []
    in_channel = num_features[0]
    
    for counter in range(1,len(num_features)):

      layers.append(DiscBlock(in_channel, num_features[counter], stride=1 if counter==(len(num_features)-1) else 2))
      in_channel = num_features[counter]
    layers.append(nn.Conv2d(in_channel, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"))
    self.model = nn.Sequential(*layers)


  def forward(self, x):
    x = self.conv1(x)
    x = self.act1(x)
    x = self.model(x)
    return torch.sigmoid(x)

   

# Train

In [22]:
def TrainCicleGAN(Disc_Ground, Gen_Seg, Disc_Seg, Gen_Ground, loader, optim_Disc, optim_Gen, L1, MSE, Disc_scaler, Gen_scaler):
  H_reals = 0
  H_fakes = 0

  loop_prop = tqdm(loader, leave=True)
  index2 = 0
  for index, (segm, grd) in enumerate(loop_prop):
    grd = grd.to(DEVICE)
    segm = segm.to(DEVICE)
    
    
    with torch.cuda.amp.autocast():
        fake_ground = Gen_Ground(segm)
        D_H_real = Disc_Ground(grd)
        D_H_fake = Disc_Ground(fake_ground.detach())
        H_reals += D_H_real.mean().item()
        H_fakes += D_H_fake.mean().item()
        D_H_real_loss = MSE(D_H_real, torch.ones_like(D_H_real))
        D_H_fake_loss = MSE(D_H_fake, torch.zeros_like(D_H_fake))
        D_H_loss = D_H_real_loss + D_H_fake_loss
       
        fake_seg = Gen_Seg(grd)
        D_Z_real = Disc_Seg(segm)
        D_Z_fake = Disc_Seg(fake_seg.detach())
        D_Z_real_loss = MSE(D_Z_real, torch.ones_like(D_Z_real))
        D_Z_fake_loss = MSE(D_Z_fake, torch.zeros_like(D_Z_fake))
        D_Z_loss = D_Z_real_loss + D_Z_fake_loss

        
        D_loss = (D_H_loss + D_Z_loss)/2

    optim_Disc.zero_grad()
    Disc_scaler.scale(D_loss).backward()
    Disc_scaler.step(optim_Disc)
    Disc_scaler.update()

    # Train Generators H and Z
    with torch.cuda.amp.autocast():
        # adversarial loss for both generators
        D_H_fake = Disc_Ground(fake_ground)
        D_Z_fake = Disc_Seg(fake_seg)
        loss_G_H = MSE(D_H_fake, torch.ones_like(D_H_fake))
        loss_G_Z = MSE(D_Z_fake, torch.ones_like(D_Z_fake))

        # cycle loss
        cycle_segm = Gen_Seg(fake_ground)
        cycle_ground = Gen_Ground(fake_seg)
        cycle_segm_loss = L1(segm, cycle_segm)
        cycle_ground_loss = L1(grd, cycle_ground)

        # identity loss (remove these for efficiency if you set lambda_identity=0)
        identity_segm = Gen_Seg(segm)
        identity_ground = Gen_Ground(grd)
        identity_segm_loss = L1(segm, identity_segm)
        identity_ground_loss = L1(grd, identity_ground)

        # add all togethor
        G_loss = (
            loss_G_Z
            + loss_G_H
            + cycle_segm_loss * LAMBDA_CYCLE
            + cycle_ground_loss * LAMBDA_CYCLE
            + identity_ground_loss * LAMBDA_IDENTITY
            + identity_segm_loss * LAMBDA_IDENTITY
        )

    optim_Gen.zero_grad()
    Gen_scaler.scale(G_loss).backward()
    Gen_scaler.step(optim_Gen)
    Gen_scaler.update()

    if index2 % 200 == 0:
        save_image(fake_ground, SAVE_IMAGE + f"/Ground/{index2}.png")
        save_image(fake_seg, SAVE_IMAGE + f"/Segment/{index2}.png")
    index2 += 1
    loop_prop.set_postfix(H_real=H_reals/(index+1), H_fake=H_fakes/(index+1))



# Main Function

In [23]:

LOAD_MODEL=True
SAVE_MODEL=True


Disc_Ground = Discriminator(in_channel=3).to(DEVICE)
Disc_Seg = Discriminator(in_channel=3).to(DEVICE)
Gen_Ground = Generator(img_channels=3, num_residual=9).to(DEVICE)
Gen_Seg = Generator(img_channels=3, num_residual=9).to(DEVICE)





optim_Disc = optim.Adam(list(Disc_Ground.parameters()) + list(Disc_Seg.parameters()), lr=LEARNING_RATE, betas=(0.5, 0.999),)
optim_Gen = optim.Adam(list(Disc_Ground.parameters()) + list(Gen_Seg.parameters()), lr=LEARNING_RATE, betas=(0.5, 0.999),)

L1 = nn.L1Loss()
MSE = nn.MSELoss()
if LOAD_MODEL:

  Disc_Ground.load_state_dict(torch.load(WEIGHTS_DISC_G))
  Disc_Seg.load_state_dict(torch.load(WEIGHTS_DISC_S))
  Gen_Ground.load_state_dict(torch.load(WEIGHTS_GEN_G))
  Gen_Seg.load_state_dict(torch.load(WEIGHTS_GEN_S))
    # load_checkpoint(
    #     WEIGHTS_GEN_G, Gen_Ground, optim_Gen, LEARNING_RATE,
    # )
    # load_checkpoint(
    #     WEIGHTS_GEN_S, Gen_Seg, optim_Gen, LEARNING_RATE,
    # )
    # load_checkpoint(
    #     WEIGHTS_DISC_G, Disc_Ground, optim_Disc, LEARNING_RATE,
    # )
    # load_checkpoint(
    #     WEIGHTS_DISC_S, Disc_Seg, optim_Disc, LEARNING_RATE,
    # )
dataset = CityscapesDataSet(SEGMENT_PATH, GROUND_PATH, transforms2)
test_dataset = CityscapesDataSet(SEGMENT_PATH_TEST, GROUND_PATH_TEST, transforms2)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True
)


test_loader = DataLoader(
  test_dataset,
  batch_size=1,
  shuffle=False,
  pin_memory=True,
)

Gen_scaler = torch.cuda.amp.GradScaler()
Disc_scaler = torch.cuda.amp.GradScaler()

for epoch in range(NUM_EPOCHS):
    TrainCicleGAN(Disc_Ground, Gen_Seg, Disc_Seg, Gen_Ground, loader, optim_Disc, optim_Gen, L1, MSE, Disc_scaler, Gen_scaler)


    if SAVE_MODEL:
      torch.save(Disc_Ground.state_dict(), WEIGHTS_DISC_G)
      torch.save(Disc_Seg.state_dict(), WEIGHTS_DISC_S)

      torch.save(Gen_Ground.state_dict(), WEIGHTS_GEN_G)
      torch.save(Gen_Seg.state_dict(), WEIGHTS_GEN_S)



  cpuset_checked))
100%|██████████| 500/500 [00:52<00:00,  9.57it/s, H_fake=0.532, H_real=0.617]
100%|██████████| 500/500 [00:53<00:00,  9.35it/s, H_fake=0.548, H_real=0.64]
100%|██████████| 500/500 [00:54<00:00,  9.23it/s, H_fake=0.557, H_real=0.65]
100%|██████████| 500/500 [00:54<00:00,  9.17it/s, H_fake=0.565, H_real=0.659]
100%|██████████| 500/500 [00:54<00:00,  9.11it/s, H_fake=0.571, H_real=0.667]
100%|██████████| 500/500 [00:55<00:00,  9.07it/s, H_fake=0.577, H_real=0.674]
100%|██████████| 500/500 [00:55<00:00,  9.03it/s, H_fake=0.583, H_real=0.682]
100%|██████████| 500/500 [00:55<00:00,  9.05it/s, H_fake=0.588, H_real=0.69]
100%|██████████| 500/500 [00:54<00:00,  9.16it/s, H_fake=0.592, H_real=0.697]
100%|██████████| 500/500 [00:54<00:00,  9.22it/s, H_fake=0.596, H_real=0.705]
100%|██████████| 500/500 [00:54<00:00,  9.14it/s, H_fake=0.6, H_real=0.709]
100%|██████████| 500/500 [00:54<00:00,  9.20it/s, H_fake=0.606, H_real=0.714]
100%|██████████| 500/500 [00:54<00:00,  9.20it/s, 

In [None]:
WEIGHTS_GEN_G = "/content/drive/MyDrive/Courses/York Courses/Neural Networks and Deep Learning/Project/Weights/GEN_G"
WEIGHTS_GEN_S = "/content/drive/MyDrive/Courses/York Courses/Neural Networks and Deep Learning/Project/Weights/GEN_S"
WEIGHTS_DISC_G = "/content/drive/MyDrive/Courses/York Courses/Neural Networks and Deep Learning/Project/Weights/DISC_G"
WEIGHTS_DISC_S = "/content/drive/MyDrive/Courses/York Courses/Neural Networks and Deep Learning/Project/Weights/DISC_S"




In [24]:

LOAD_MODEL=True
SAVE_MODEL=True


Disc_Ground = Discriminator(in_channel=3).to(DEVICE)
Disc_Seg = Discriminator(in_channel=3).to(DEVICE)
Gen_Ground = Generator(img_channels=3, num_residual=9).to(DEVICE)
Gen_Seg = Generator(img_channels=3, num_residual=9).to(DEVICE)





optim_Disc = optim.Adam(list(Disc_Ground.parameters()) + list(Disc_Seg.parameters()), lr=LEARNING_RATE, betas=(0.5, 0.999),)
optim_Gen = optim.Adam(list(Disc_Ground.parameters()) + list(Gen_Seg.parameters()), lr=LEARNING_RATE, betas=(0.5, 0.999),)

L1 = nn.L1Loss()
MSE = nn.MSELoss()
if LOAD_MODEL:

  Disc_Ground.load_state_dict(torch.load(WEIGHTS_DISC_G))
  Disc_Seg.load_state_dict(torch.load(WEIGHTS_DISC_S))
  Gen_Ground.load_state_dict(torch.load(WEIGHTS_GEN_G))
  Gen_Seg.load_state_dict(torch.load(WEIGHTS_GEN_S))
    # load_checkpoint(
    #     WEIGHTS_GEN_G, Gen_Ground, optim_Gen, LEARNING_RATE,
    # )
    # load_checkpoint(
    #     WEIGHTS_GEN_S, Gen_Seg, optim_Gen, LEARNING_RATE,
    # )
    # load_checkpoint(
    #     WEIGHTS_DISC_G, Disc_Ground, optim_Disc, LEARNING_RATE,
    # )
    # load_checkpoint(
    #     WEIGHTS_DISC_S, Disc_Seg, optim_Disc, LEARNING_RATE,
    # )
dataset = CityscapesDataSet(SEGMENT_PATH, GROUND_PATH, transforms2)
test_dataset = CityscapesDataSet(SEGMENT_PATH_TEST, GROUND_PATH_TEST, transforms2)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True
)


test_loader = DataLoader(
  test_dataset,
  batch_size=1,
  shuffle=False,
  pin_memory=True,
)

Gen_scaler = torch.cuda.amp.GradScaler()
Disc_scaler = torch.cuda.amp.GradScaler()

for epoch in range(NUM_EPOCHS):
    TrainCicleGAN(Disc_Ground, Gen_Seg, Disc_Seg, Gen_Ground, loader, optim_Disc, optim_Gen, L1, MSE, Disc_scaler, Gen_scaler)


    if SAVE_MODEL:
      torch.save(Disc_Ground.state_dict(), WEIGHTS_DISC_G)
      torch.save(Disc_Seg.state_dict(), WEIGHTS_DISC_S)

      torch.save(Gen_Ground.state_dict(), WEIGHTS_GEN_G)
      torch.save(Gen_Seg.state_dict(), WEIGHTS_GEN_S)




  cpuset_checked))
100%|██████████| 500/500 [00:53<00:00,  9.34it/s, H_fake=0.752, H_real=0.831]
100%|██████████| 500/500 [00:53<00:00,  9.29it/s, H_fake=0.754, H_real=0.832]
100%|██████████| 500/500 [00:53<00:00,  9.27it/s, H_fake=0.755, H_real=0.833]
100%|██████████| 500/500 [00:53<00:00,  9.29it/s, H_fake=0.756, H_real=0.833]
100%|██████████| 500/500 [00:54<00:00,  9.25it/s, H_fake=0.756, H_real=0.833]
100%|██████████| 500/500 [00:54<00:00,  9.21it/s, H_fake=0.757, H_real=0.833]
100%|██████████| 500/500 [00:54<00:00,  9.24it/s, H_fake=0.757, H_real=0.833]
100%|██████████| 500/500 [00:54<00:00,  9.21it/s, H_fake=0.758, H_real=0.834]
100%|██████████| 500/500 [00:54<00:00,  9.19it/s, H_fake=0.759, H_real=0.834]
100%|██████████| 500/500 [00:53<00:00,  9.29it/s, H_fake=0.759, H_real=0.834]
100%|██████████| 500/500 [00:53<00:00,  9.34it/s, H_fake=0.76, H_real=0.834]
100%|██████████| 500/500 [00:54<00:00,  9.22it/s, H_fake=0.761, H_real=0.834]
100%|██████████| 500/500 [00:54<00:00,  9.25it

In [25]:

LOAD_MODEL=True
SAVE_MODEL=True


Disc_Ground = Discriminator(in_channel=3).to(DEVICE)
Disc_Seg = Discriminator(in_channel=3).to(DEVICE)
Gen_Ground = Generator(img_channels=3, num_residual=9).to(DEVICE)
Gen_Seg = Generator(img_channels=3, num_residual=9).to(DEVICE)





optim_Disc = optim.Adam(list(Disc_Ground.parameters()) + list(Disc_Seg.parameters()), lr=LEARNING_RATE, betas=(0.5, 0.999),)
optim_Gen = optim.Adam(list(Disc_Ground.parameters()) + list(Gen_Seg.parameters()), lr=LEARNING_RATE, betas=(0.5, 0.999),)

L1 = nn.L1Loss()
MSE = nn.MSELoss()
if LOAD_MODEL:

  Disc_Ground.load_state_dict(torch.load(WEIGHTS_DISC_G))
  Disc_Seg.load_state_dict(torch.load(WEIGHTS_DISC_S))
  Gen_Ground.load_state_dict(torch.load(WEIGHTS_GEN_G))
  Gen_Seg.load_state_dict(torch.load(WEIGHTS_GEN_S))
    # load_checkpoint(
    #     WEIGHTS_GEN_G, Gen_Ground, optim_Gen, LEARNING_RATE,
    # )
    # load_checkpoint(
    #     WEIGHTS_GEN_S, Gen_Seg, optim_Gen, LEARNING_RATE,
    # )
    # load_checkpoint(
    #     WEIGHTS_DISC_G, Disc_Ground, optim_Disc, LEARNING_RATE,
    # )
    # load_checkpoint(
    #     WEIGHTS_DISC_S, Disc_Seg, optim_Disc, LEARNING_RATE,
    # )
dataset = CityscapesDataSet(SEGMENT_PATH, GROUND_PATH, transforms2)
test_dataset = CityscapesDataSet(SEGMENT_PATH_TEST, GROUND_PATH_TEST, transforms2)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True
)


test_loader = DataLoader(
  test_dataset,
  batch_size=1,
  shuffle=False,
  pin_memory=True,
)

Gen_scaler = torch.cuda.amp.GradScaler()
Disc_scaler = torch.cuda.amp.GradScaler()

for epoch in range(NUM_EPOCHS):
    TrainCicleGAN(Disc_Ground, Gen_Seg, Disc_Seg, Gen_Ground, loader, optim_Disc, optim_Gen, L1, MSE, Disc_scaler, Gen_scaler)


    if SAVE_MODEL:
      torch.save(Disc_Ground.state_dict(), WEIGHTS_DISC_G)
      torch.save(Disc_Seg.state_dict(), WEIGHTS_DISC_S)

      torch.save(Gen_Ground.state_dict(), WEIGHTS_GEN_G)
      torch.save(Gen_Seg.state_dict(), WEIGHTS_GEN_S)




  cpuset_checked))
100%|██████████| 500/500 [00:53<00:00,  9.32it/s, H_fake=0.783, H_real=0.847]
100%|██████████| 500/500 [00:53<00:00,  9.32it/s, H_fake=0.783, H_real=0.846]
100%|██████████| 500/500 [00:53<00:00,  9.34it/s, H_fake=0.784, H_real=0.846]
100%|██████████| 500/500 [00:53<00:00,  9.34it/s, H_fake=0.784, H_real=0.846]
100%|██████████| 500/500 [00:53<00:00,  9.31it/s, H_fake=0.785, H_real=0.846]
100%|██████████| 500/500 [00:53<00:00,  9.31it/s, H_fake=0.785, H_real=0.847]
100%|██████████| 500/500 [00:53<00:00,  9.34it/s, H_fake=0.785, H_real=0.847]
100%|██████████| 500/500 [00:53<00:00,  9.34it/s, H_fake=0.785, H_real=0.846]
100%|██████████| 500/500 [00:53<00:00,  9.36it/s, H_fake=0.785, H_real=0.847]
100%|██████████| 500/500 [00:53<00:00,  9.37it/s, H_fake=0.786, H_real=0.846]
100%|██████████| 500/500 [00:53<00:00,  9.36it/s, H_fake=0.786, H_real=0.847]
100%|██████████| 500/500 [00:53<00:00,  9.37it/s, H_fake=0.786, H_real=0.847]
100%|██████████| 500/500 [00:53<00:00,  9.38i

In [None]:

LOAD_MODEL=True
SAVE_MODEL=True


Disc_Ground = Discriminator(in_channel=3).to(DEVICE)
Disc_Seg = Discriminator(in_channel=3).to(DEVICE)
Gen_Ground = Generator(img_channels=3, num_residual=9).to(DEVICE)
Gen_Seg = Generator(img_channels=3, num_residual=9).to(DEVICE)





optim_Disc = optim.Adam(list(Disc_Ground.parameters()) + list(Disc_Seg.parameters()), lr=LEARNING_RATE, betas=(0.5, 0.999),)
optim_Gen = optim.Adam(list(Disc_Ground.parameters()) + list(Gen_Seg.parameters()), lr=LEARNING_RATE, betas=(0.5, 0.999),)

L1 = nn.L1Loss()
MSE = nn.MSELoss()
if LOAD_MODEL:

  Disc_Ground.load_state_dict(torch.load(WEIGHTS_DISC_G))
  Disc_Seg.load_state_dict(torch.load(WEIGHTS_DISC_S))
  Gen_Ground.load_state_dict(torch.load(WEIGHTS_GEN_G))
  Gen_Seg.load_state_dict(torch.load(WEIGHTS_GEN_S))
    # load_checkpoint(
    #     WEIGHTS_GEN_G, Gen_Ground, optim_Gen, LEARNING_RATE,
    # )
    # load_checkpoint(
    #     WEIGHTS_GEN_S, Gen_Seg, optim_Gen, LEARNING_RATE,
    # )
    # load_checkpoint(
    #     WEIGHTS_DISC_G, Disc_Ground, optim_Disc, LEARNING_RATE,
    # )
    # load_checkpoint(
    #     WEIGHTS_DISC_S, Disc_Seg, optim_Disc, LEARNING_RATE,
    # )
dataset = CityscapesDataSet(SEGMENT_PATH, GROUND_PATH, transforms2)
test_dataset = CityscapesDataSet(SEGMENT_PATH_TEST, GROUND_PATH_TEST, transforms2)

loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True
)


test_loader = DataLoader(
  test_dataset,
  batch_size=1,
  shuffle=False,
  pin_memory=True,
)

Gen_scaler = torch.cuda.amp.GradScaler()
Disc_scaler = torch.cuda.amp.GradScaler()

for epoch in range(NUM_EPOCHS):
    TrainCicleGAN(Disc_Ground, Gen_Seg, Disc_Seg, Gen_Ground, loader, optim_Disc, optim_Gen, L1, MSE, Disc_scaler, Gen_scaler)


    if SAVE_MODEL:
      torch.save(Disc_Ground.state_dict(), WEIGHTS_DISC_G)
      torch.save(Disc_Seg.state_dict(), WEIGHTS_DISC_S)

      torch.save(Gen_Ground.state_dict(), WEIGHTS_GEN_G)
      torch.save(Gen_Seg.state_dict(), WEIGHTS_GEN_S)




  cpuset_checked))
100%|██████████| 500/500 [00:53<00:00,  9.32it/s, H_fake=0.795, H_real=0.846]
100%|██████████| 500/500 [00:53<00:00,  9.30it/s, H_fake=0.795, H_real=0.846]
100%|██████████| 500/500 [00:54<00:00,  9.22it/s, H_fake=0.795, H_real=0.845]
100%|██████████| 500/500 [00:54<00:00,  9.22it/s, H_fake=0.796, H_real=0.845]
100%|██████████| 500/500 [00:54<00:00,  9.17it/s, H_fake=0.796, H_real=0.845]
100%|██████████| 500/500 [00:54<00:00,  9.12it/s, H_fake=0.796, H_real=0.844]
100%|██████████| 500/500 [00:54<00:00,  9.13it/s, H_fake=0.797, H_real=0.844]
100%|██████████| 500/500 [00:54<00:00,  9.14it/s, H_fake=0.797, H_real=0.844]
100%|██████████| 500/500 [00:54<00:00,  9.11it/s, H_fake=0.797, H_real=0.844]
100%|██████████| 500/500 [00:55<00:00,  9.08it/s, H_fake=0.797, H_real=0.845]
100%|██████████| 500/500 [00:54<00:00,  9.16it/s, H_fake=0.797, H_real=0.845]
100%|██████████| 500/500 [00:55<00:00,  9.07it/s, H_fake=0.797, H_real=0.845]
100%|██████████| 500/500 [00:55<00:00,  9.09i