In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from torchvision.utils import save_image
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from generator import Generator
from discriminator import Discriminator
import configurations

INFO:albumentations.check_version:A new version of Albumentations is available: 1.4.8 (you have 1.4.7). Upgrade using: pip install --upgrade albumentations
  return torch._C._cuda_getDeviceCount() > 0


In [2]:
# Data augmentations and transformations
#
# Both halves are normalised with mean=0.5 / std=0.5 after dividing by 255,
# i.e. pixel values in [0, 255] are mapped to [-1, 1].

# Applied jointly to the input and target halves of a sample: "image0" is
# registered as a second image target, so it receives exactly the same random
# parameters as "image". Any geometric augmentation must live here -- flipping
# only one side of a paired sample would misalign it with its target.
both_transform = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
    ],
    additional_targets={"image0": "image"},
)

# Input-only pipeline: photometric augmentation (does not move pixels, so the
# pair stays aligned), then normalisation and conversion to a CHW tensor.
transform_only_input = A.Compose(
    [
        A.ColorJitter(p=0.2),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ]
)

# Target-only pipeline: no augmentation, just normalisation and tensor conversion.
transform_only_mask = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
        ToTensorV2(),
    ]
)

class Satellite2Map_Data(Dataset):
    """Paired-image dataset where each file stores input and target side by side.

    Every image in ``root`` is split at half its width: the left half is the
    input ("satellite") image and the right half is the target ("map") image.
    Both halves first go through ``both_transform`` together. Then the input
    goes through ``transform_only_input`` and the target through
    ``transform_only_mask``.

    Args:
        root: Directory containing the paired images.
    """

    def __init__(self, root):
        self.root = root
        # Keep only regular, non-hidden files so that entries such as the
        # '.ipynb_checkpoints' folder or a stray '.DS_Store' are not treated as
        # samples. Sorting makes the sample order independent of the
        # filesystem's (unspecified) listing order.
        self.n_samples = sorted(
            name
            for name in os.listdir(self.root)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(self.root, name))
        )

    def __len__(self):
        return len(self.n_samples)

    def __getitem__(self, idx):
        """Return ``(input_tensor, target_tensor)`` for sample ``idx``.

        Raises:
            IndexError: if ``idx`` is out of range.
            RuntimeError: if the file cannot be read or transformed. The
                original exception is chained, and the message names the file.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image_name = self.n_samples[idx]
        image_path = os.path.join(self.root, image_name)
        try:
            # ``with`` closes the underlying file handle deterministically.
            with Image.open(image_path) as img:
                image = np.asarray(img.convert("RGB"))
            width_cutoff = image.shape[1] // 2
            satellite_image = image[:, :width_cutoff, :]
            map_image = image[:, width_cutoff:, :]

            augmentations = both_transform(image=satellite_image, image0=map_image)
            input_image = augmentations["image"]
            target_image = augmentations["image0"]

            satellite_image = transform_only_input(image=input_image)["image"]
            map_image = transform_only_mask(image=target_image)["image"]
        except Exception as exc:
            # Previously this path was printed and ``None`` returned, which made
            # the DataLoader's collate step fail later with an unrelated error.
            raise RuntimeError(f"Failed to load sample {image_path!r}") from exc
        return satellite_image, map_image
    
    
            
if __name__ == "__main__":
    # Quick sanity check of the data pipeline: pull a single batch of five
    # training pairs, report their tensor shapes and dump both grids to PNG
    # files for visual inspection.
    dataset = Satellite2Map_Data("facades/facades/train")
    loader = DataLoader(dataset, batch_size=5)
    for sat_batch, map_batch in loader:
        print("X Shape :-", sat_batch.shape)
        print("Y Shape :-", map_batch.shape)
        save_image(sat_batch, "satellite.png")
        save_image(map_batch, "map.png")
        break

X Shape :- torch.Size([5, 3, 256, 256])
Y Shape :- torch.Size([5, 3, 256, 256])


In [3]:
# Saving examples and checkpoints
def save_some_examples(gen, val_loader, epoch, folder):
    """Save one validation batch: generator output, input and ground truth.

    Tensors are mapped from [-1, 1] back to [0, 1] (the inverse of the
    mean=0.5 / std=0.5 normalisation applied by the dataset transforms). They
    are written as ``y_gen_{epoch}.png``, ``input_{epoch}.png`` and
    ``label_{epoch}.png`` inside ``folder``.

    Args:
        gen: Generator module. It is put in eval mode while sampling and its
            previous train/eval mode is restored afterwards.
        val_loader: Iterable yielding ``(input, target)`` batches; only the
            first batch is used.
        epoch: Epoch number embedded in the output file names.
        folder: Output directory; created if it does not exist yet.
    """
    os.makedirs(folder, exist_ok=True)
    x, y = next(iter(val_loader))
    x, y = x.to(configurations.DEVICE), y.to(configurations.DEVICE)
    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            y_fake = gen(x)
            save_image(y_fake * 0.5 + 0.5, os.path.join(folder, f"y_gen_{epoch}.png"))
            save_image(x * 0.5 + 0.5, os.path.join(folder, f"input_{epoch}.png"))
            save_image(y * 0.5 + 0.5, os.path.join(folder, f"label_{epoch}.png"))
    finally:
        # Restore the caller's mode even if generation or saving fails.
        gen.train(was_training)


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Serialise ``model`` and ``optimizer`` state dicts to ``filename``.

    The saved object is a dict with two keys: ``"state_dict"`` (the model's
    ``state_dict()``) and ``"optimizer"`` (the optimizer's ``state_dict()``).
    """
    print("=> Saving checkpoint")
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from a checkpoint file.

    Expects the dict layout written by ``save_checkpoint`` (keys
    ``"state_dict"`` and ``"optimizer"``). Tensors are mapped onto
    ``configurations.DEVICE``. After loading, every optimizer param group's
    learning rate is overwritten with ``lr``.

    NOTE(review): ``torch.load`` unpickles the file, so only load checkpoints
    from trusted sources.
    """
    print("=> Loading checkpoint")
    state = torch.load(checkpoint_file, map_location=configurations.DEVICE)
    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])

    # The optimizer state restores the learning rate it was saved with; force
    # the caller-supplied one so a changed LEARNING_RATE actually takes effect.
    for group in optimizer.param_groups:
        group.update(lr=lr)
        
# Per-step loss histories; ``train`` appends one float per optimizer step.
disc_loss, gen_loss = [], []

In [4]:
def train(gen, disc, train_loader, optim_gen, optim_disc, l1_loss, bce_loss):
    """Run one epoch of WGAN (weight-clipping) training with an L1 term.

    For each batch ``(x, y)``:

    * the critic ``disc`` is updated ``configurations.CRITIC_ITERATIONS``
      times with ``-(mean(D(x, y)) - mean(D(x, G(x))))``; after each step its
      weights are clamped to ``[-CLIP_VALUE, CLIP_VALUE]``;
    * the generator ``gen`` is then updated once with
      ``-mean(D(x, G(x))) + L1_LAMBDA * L1(G(x), y)``.

    Each critic and generator loss value is appended to the module-level
    ``disc_loss`` / ``gen_loss`` lists.

    Args:
        gen: Generator network.
        disc: Critic network, called as ``disc(x, y)``.
        train_loader: Iterable of ``(input, target)`` batches.
        optim_gen: Optimizer over ``gen``'s parameters.
        optim_disc: Optimizer over ``disc``'s parameters.
        l1_loss: Reconstruction loss, called as ``l1_loss(y_fake, y)``.
        bce_loss: Unused by the Wasserstein objective; kept so existing
            callers do not break.
    """
    device = configurations.DEVICE
    clip = configurations.CLIP_VALUE
    loop = tqdm(train_loader)
    for idx, (x, y) in enumerate(loop):
        # Use the configured device rather than hard-coded ``.cuda()`` so this
        # matches save_some_examples/load_checkpoint and runs without CUDA.
        x = x.to(device)
        y = y.to(device)

        # Train critic
        for _ in range(configurations.CRITIC_ITERATIONS):
            y_fake = gen(x)
            d_real = disc(x, y)
            # detach(): the critic step must not backpropagate into gen.
            d_fake = disc(x, y_fake.detach())
            d_loss = -(torch.mean(d_real) - torch.mean(d_fake))

            disc.zero_grad()
            disc_loss.append(d_loss.item())
            d_loss.backward()
            optim_disc.step()

            # WGAN-style weight clipping, applied in place outside autograd.
            with torch.no_grad():
                for p in disc.parameters():
                    p.clamp_(-clip, clip)

        # Train generator on the last (non-detached) y_fake. The gradients
        # this also leaves on disc are cleared by disc.zero_grad() before the
        # next critic step.
        d_fake = disc(x, y_fake)
        g_fake_loss = -torch.mean(d_fake)
        l1_loss_term = l1_loss(y_fake, y) * configurations.L1_LAMBDA
        g_loss = g_fake_loss + l1_loss_term
        optim_gen.zero_grad()
        gen_loss.append(g_loss.item())
        g_loss.backward()
        optim_gen.step()

        if idx % 10 == 0:
            # Critic outputs are unbounded scores, not logits of a probability
            # (the loss above never applies a sigmoid), so report raw means.
            loop.set_postfix(
                d_real=d_real.mean().item(),
                d_fake=d_fake.mean().item(),
            )

In [5]:
def main(checkpoint_every=50, example_every=2):
    """Build models, optimizers and data loaders, then run the training loop.

    Trains for ``configurations.NUM_EPOCHS`` epochs. It periodically
    checkpoints both networks and writes validation samples to ``eval/``.

    Args:
        checkpoint_every: When ``configurations.SAVE_MODEL`` is true, save
            generator and discriminator checkpoints every this many epochs.
            The final epoch is always checkpointed as well, so the end of a
            run is not lost.
        example_every: Write validation samples every this many epochs.
    """
    # Place the models on the configured device, the same one that
    # save_some_examples and load_checkpoint use, instead of hard-coding CUDA.
    device = configurations.DEVICE
    disc = Discriminator(in_channels=3).to(device)
    gen = Generator(in_channels=3).to(device)
    optim_disc = torch.optim.Adam(
        disc.parameters(), lr=configurations.LEARNING_RATE, betas=(configurations.BETA1, 0.999)
    )
    optim_gen = torch.optim.Adam(
        gen.parameters(), lr=configurations.LEARNING_RATE, betas=(configurations.BETA1, 0.999)
    )
    bce_loss = nn.BCEWithLogitsLoss()
    l1_loss = nn.L1Loss()
    if configurations.LOAD_MODEL:
        load_checkpoint(
            configurations.CHECKPOINT_GEN, gen, optim_gen, configurations.LEARNING_RATE
        )
        load_checkpoint(
            configurations.CHECKPOINT_DISC, disc, optim_disc, configurations.LEARNING_RATE
        )

    train_dataset = Satellite2Map_Data(root=configurations.TRAIN_DIR)
    train_loader = DataLoader(train_dataset, batch_size=configurations.BATCH_SIZE,
                              shuffle=True, num_workers=configurations.NUM_WORKERS, pin_memory=True)

    val_dataset = Satellite2Map_Data(root=configurations.VAL_DIR)
    val_loader = DataLoader(val_dataset, batch_size=1,
                            shuffle=True, num_workers=configurations.NUM_WORKERS, pin_memory=True)

    for epoch in range(configurations.NUM_EPOCHS):
        train(
            gen, disc, train_loader, optim_gen, optim_disc, l1_loss, bce_loss
        )

        is_last_epoch = epoch == configurations.NUM_EPOCHS - 1
        if configurations.SAVE_MODEL and (epoch % checkpoint_every == 0 or is_last_epoch):
            print("Epoch: ", epoch)
            save_checkpoint(gen, optim_gen, filename=configurations.CHECKPOINT_GEN)
            # Must be ``disc``: the LOAD_MODEL branch above restores
            # CHECKPOINT_DISC into the discriminator.
            save_checkpoint(disc, optim_disc, filename=configurations.CHECKPOINT_DISC)
        if epoch % example_every == 0:
            save_some_examples(gen, val_loader, epoch, folder="eval")

In [None]:
# Run the full training loop configured in ``configurations``
# (NUM_EPOCHS, checkpoint and data paths, etc.).
main()

  0%|          | 0/25 [00:00<?, ?it/s]

In [17]:
# Show the configured epoch count as the cell's output (bare last expression
# instead of print, so Jupyter renders it via its rich display).
configurations.NUM_EPOCHS

1000


In [30]:
# Sanity check: is a CUDA device visible to PyTorch in this kernel?
# (The recorded output is False, so any code that hard-codes ``.cuda()``
# cannot run here.)
torch.cuda.is_available()

False