In [12]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets
from tqdm.notebook import tqdm
!jupyter nbextension enable --py widgetsnbextension
from torch.utils.tensorboard import SummaryWriter
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np

Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m


# discriminator

In [13]:
class Discriminator(nn.Module):
    def __init__(self,in_channels,features=[64,128,256,512],) -> None:
        super(Discriminator,self).__init__()
        # in channeel *2 because we concate the label and image channel(3+3=6)
        
        self.con1=nn.Conv2d(in_channels*2,features[0],kernel_size=4,
                            stride=2,padding=1,padding_mode="reflect")
        self.con2=nn.Conv2d(features[0],features[1],kernel_size=4,stride=2,bias=False,
                            padding=1,padding_mode="reflect")
        self.con3=nn.Conv2d(features[1],features[2],kernel_size=4,stride=2,bias=False,
                            padding=1,padding_mode="reflect")
        self.con4=nn.Conv2d(features[2],features[3],kernel_size=4,stride=1,bias=False,
                            padding=1,padding_mode="reflect")
        self.last_conv=nn.Conv2d(features[3],1,kernel_size=4,
                                 stride=1,padding=1,padding_mode='reflect')
        
        self.leaky=nn.LeakyReLU(0.2)
        self.batchNorm2=nn.BatchNorm2d(features[1])
        self.batchNorm3=nn.BatchNorm2d(features[2])
        self.batchNorm4=nn.BatchNorm2d(features[3])
        
    def forward(self,image,label):
        # print(image.shape) # torch.Size([1, 3, 256, 256])
        # print(label.shape) #torch.Size([1, 3, 256, 256])
        out=torch.cat([image,label],dim=1)
        # print(out.shape)# torch.Size([1, 6, 256, 256])
        
        out=self.con1(out)
        # print(out.shape)#torch.Size([1, 64, 128, 128])
        out=self.leaky(out)
        
        out=self.con2(out)
        # print(out.shape)#torch.Size([1, 128, 64, 64])
        out=self.batchNorm2(out)
        out=self.leaky(out)
        
        out=self.con3(out)
        # print(out.shape)#torch.Size([1, 256, 32, 32])
        out=self.batchNorm3(out)
        out=self.leaky(out)
        
        out=self.con4(out)
        # print(out.shape)#torch.Size([1, 512, 31, 31])
        out=self.batchNorm4(out)
        out=self.leaky(out)
        
        out=self.last_conv(out)
        # print(out.shape)#torch.Size([1, 1, 30, 30])
        
        return out

# Generator

In [14]:
x=torch.randn((1,3,256,256))
y=torch.randn((1,3,256,256))
model=Discriminator(in_channels=3)
out=model(x,y)

In [15]:
class Generator(nn.Module):
    def __init__(self,in_channels=3,features=64) -> None:
        super(Generator,self).__init__()
        self.d1=nn.Conv2d(in_channels,features,kernel_size=4,
                             stride=2,padding=1,padding_mode="reflect")
        
        self.d2=nn.Conv2d(features,features*2,4,2,1,padding_mode='reflect',bias=False)
        self.batch_norm_d2=nn.BatchNorm2d(features*2)
        
        self.d3=nn.Conv2d(features*2,features*4,4,2,1,padding_mode='reflect',bias=False)
        self.batch_norm_d3=nn.BatchNorm2d(features*4)
        
        self.d4=nn.Conv2d(features*4,features*8,4,2,1,padding_mode='reflect',bias=False)
        self.batch_norm_d4=nn.BatchNorm2d(features*8)
        
        self.d5=nn.Conv2d(features*8,features*8,4,2,1,padding_mode='reflect',bias=False)
        self.batch_norm_d5=nn.BatchNorm2d(features*8)
        
        self.d6=nn.Conv2d(features*8,features*8,4,2,1,padding_mode='reflect',bias=False)
        self.batch_norm_d6=nn.BatchNorm2d(features*8)
        
        self.d7=nn.Conv2d(features*8,features*8,4,2,1,padding_mode='reflect',bias=False)
        self.batch_norm_d7=nn.BatchNorm2d(features*8)
        
        self.u1=nn.ConvTranspose2d(features*8,features*8,4,2,1,bias=False)
        self.batch_norm_u1=nn.BatchNorm2d(features*8)
        
        self.u2=nn.ConvTranspose2d(features*8*2,features*8,4,2,1,bias=False)
        self.batch_norm_u2=nn.BatchNorm2d(features*8)
        
        self.u3=nn.ConvTranspose2d(features*8*2,features*8,4,2,1,bias=False)
        self.batch_norm_u3=nn.BatchNorm2d(features*8)
        
        self.u4=nn.ConvTranspose2d(features*8*2,features*8,4,2,1,bias=False)
        self.batch_norm_u4=nn.BatchNorm2d(features*8)
        
        self.u5=nn.ConvTranspose2d(features*8*2,features*4,4,2,1,bias=False)
        self.batch_norm_u5=nn.BatchNorm2d(features*4)
        
        self.u6=nn.ConvTranspose2d(features*4*2,features*2,4,2,1,bias=False)
        self.batch_norm_u6=nn.BatchNorm2d(features*2)
        
        self.u7=nn.ConvTranspose2d(features*2*2,features,4,2,1,bias=False)
        self.batch_norm_u7=nn.BatchNorm2d(features)
        
        self.final_layer=nn.ConvTranspose2d(features*2,in_channels,4,2,1)
        
        self.bottleneck=nn.Conv2d(features*8,features*8,4,2,1,bias=False)
        self.leaky=nn.LeakyReLU(0.2)
        self.dropout=nn.Dropout(0.5)
        self.relu=nn.ReLU()
        self.tanh=nn.Tanh()
        
    def forward(self,image):
        # print(image.shape) #torch.Size([1, 3, 256, 256])
        d1=self.d1(image)
        d1=self.leaky(d1)
        # print(d1.shape) # torch.Size([1, 64, 128, 128])
        d2=self.d2(d1)
        d2=self.batch_norm_d2(d2)
        d2=self.leaky(d2)
        # print(d2.shape)#torch.Size([1, 128, 64, 64])
        d3=self.d3(d2)
        d3=self.batch_norm_d3(d3)
        d3=self.leaky(d3)
        # print(d3.shape)#torch.Size([1, 256, 32, 32])
        
        d4=self.d4(d3)
        d4=self.batch_norm_d4(d4)
        d4=self.leaky(d4)
        # print(d4.shape)#torch.Size([1, 512, 16, 16])
        
        d5=self.d5(d4)
        d5=self.batch_norm_d5(d5)
        d5=self.leaky(d5)
        # print(d5.shape)#torch.Size([1, 512, 8, 8])
        
        d6=self.d6(d5)
        d6=self.batch_norm_d6(d6)
        d6=self.leaky(d6)
        # print(d6.shape)#torch.Size([1, 512, 4, 4])
        
        d7=self.d7(d6)
        d7=self.batch_norm_d7(d7)
        d7=self.leaky(d7)
        # print(d7.shape)#torch.Size([1, 512, 2, 2])
        
        bn=self.bottleneck(d7)
        bn=self.relu(bn)
        # print(bn.shape) #torch.Size([1, 512, 1, 1])
        
        u1=self.u1(bn)
        u1=self.batch_norm_u1(u1)
        u1=self.relu(u1)
        u1=self.dropout(u1)
        # print(u1.shape) #torch.Size([1, 512, 2, 2])
        
        u2=self.u2(torch.cat([u1,d7],1))
        u2=self.batch_norm_u2(u2)
        u2=self.relu(u2)
        u2=self.dropout(u2)
        # print(u2.shape) #torch.Size([1, 512, 4, 4])
        
        u3=self.u3(torch.cat([u2,d6],1))
        u3=self.batch_norm_u3(u3)
        u3=self.relu(u3)
        u3=self.dropout(u3)
        # print(u3.shape) #torch.Size([1, 512, 8, 8])
        
        u4=self.u4(torch.cat([u3,d5],1))
        u4=self.batch_norm_u4(u4)
        u4=self.relu(u4)
        # print(u4.shape) #torch.Size([1, 512, 16, 16])
        
        u5=self.u5(torch.cat([u4,d4],1))
        u5=self.batch_norm_u5(u5)
        u5=self.relu(u5)
        # print(u5.shape) #torch.Size([1, 256, 32, 32])
        
        u6=self.u6(torch.cat([u5,d3],1))
        u6=self.batch_norm_u6(u6)
        u6=self.relu(u6)
        # print(u6.shape) #torch.Size([1, 128, 64, 64])
        
        u7=self.u7(torch.cat([u6,d2],1))
        u7=self.batch_norm_u7(u7)
        u7=self.relu(u7)
        # print(u7.shape) #torch.Size([1, 64, 128, 128])
        
        out=self.final_layer(torch.cat([u7,d1],1))
        out=self.tanh(out)
        # print(out.shape) #torch.Size([1, 3, 256, 256])
        
        return out

In [16]:
# x = torch.randn((1, 3, 256, 256))
# model = Generator(in_channels=3, features=64)
# preds = model(x)
# print(preds.shape)


# Dataset preparation

In [17]:
tranforms_for_both=A.Compose(
    [A.Resize(width=256,height=256)],
    additional_targets={"image0":"image"},) 

tranform_input=A.Compose([
    A.HorizontalFlip(p=0.5),
    A.ColorJitter(p=0.2),
    A.Normalize(mean=[0.5 for _ in range(3)],
                std=[0.5 for _ in range(3)],
                max_pixel_value=255.0,),
    ToTensorV2()
])
tranform_mask=A.Compose([
    A.Normalize(mean=[0.5 for _ in range(3)],
                std=[0.5 for _ in range(3)],
                max_pixel_value=255.0,),
    ToTensorV2(),
])

In [18]:
# both_transform = A.Compose(
#     [A.Resize(width=256, height=256),], additional_targets={"image0": "image"},
# )

# transform_only_input = A.Compose(
#     [
#         A.HorizontalFlip(p=0.5),
#         A.ColorJitter(p=0.2),
#         A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
#         ToTensorV2(),
#     ]
# )

# transform_only_mask = A.Compose(
#     [
#         A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
#         ToTensorV2(),
#     ]
# )

In [19]:
class pix2pixDataset(nn.Module):
    def __init__(self,path) -> None:
        super(pix2pixDataset,self).__init__()
        self.path=path 
        self.files_list=os.listdir(self.path)
        
    def __len__(self):
        return len(self.files_list)
    
    def __getitem__(self,idx):
        img_file=self.files_list[idx]
        image=np.array(Image.open(os.path.join(self.path,img_file)))
        H,W,C=image.shape
        
        input_image=image[:,:W//2,:]
        target_image=image[:,W//2:,:]
        both_augumentation=tranforms_for_both(image=input_image,image0=target_image)
        input_image=both_augumentation["image"]
        target_image=both_augumentation["image0"]
        
        input_image=tranform_input(image=input_image)["image"]
        target_image=tranform_mask(image=target_image)["image"]
        
        # augmentations = both_transform(image=input_image, image0=target_image)
        # input_image = augmentations["image"]
        # target_image = augmentations["image0"]

        # input_image = transform_only_input(image=input_image)["image"]
        # target_image = transform_only_mask(image=target_image)["image"]
        
        return input_image,target_image

In [20]:
# d=pix2pixDataset("/mnt/disk1/Gulshan/GAN/pix2pix/edges2shoes/val")
# loader = DataLoader(d, batch_size=2)
# x,y=next(iter(loader))
# print(x.shape,y.shape)
# x,y=np.array(x),np.array(y)
# x=x.transpose(0,2,3,1)
# y=y.transpose(0,2,3,1)
# plt.axis('off')
# plt.imshow(x[1,...])
# plt.show()
# plt.axis('off')
# plt.imshow(y[1,...])
# plt.show()

# Model intialazing

In [23]:
device="cuda:7" if torch.cuda.is_available() else "cpu"
lr=1e-5
bs=16
num_workers=2
num_epochs=30
L1_LAMBDA = 100
LAMBDA_GP = 10
Image_size=256
disc=Discriminator(3).to(device)
gen=Generator(3).to(device)
opt_disc=torch.optim.Adam(disc.parameters(),lr=lr,betas=(0.5,0.999))
opt_gen=torch.optim.Adam(gen.parameters(),lr=lr,betas=(0.5,0.999))
criterion_BCE=nn.BCEWithLogitsLoss()
criterion_L1=nn.L1Loss()
train_dataset=pix2pixDataset("/mnt/disk1/Gulshan/GAN/pix2pix/edges2shoes/train")
train_loader=DataLoader(train_dataset,batch_size=bs,shuffle=True,num_workers=num_workers,drop_last=True)
val_dataset=pix2pixDataset("/mnt/disk1/Gulshan/GAN/pix2pix/edges2shoes/val")
val_loader=DataLoader(val_dataset,batch_size=bs,shuffle=False,num_workers=num_workers,drop_last=True)
gen_scaler=torch.cuda.amp.GradScaler() # for less v ram usage and train faster
disc_scaler=torch.cuda.amp.GradScaler()
writer_real = SummaryWriter(f"logs_pix2pix_{lr}/real")
writer_fake = SummaryWriter(f"logs_pix2pix_{lr}/fake")
len(train_loader)*bs


10992

# training

In [24]:
def train():
    gen.train()
    disc.train()
    for epoch in tqdm(range(num_epochs),total=num_epochs):
        total_disc=0
        total_gen=0
        for batch_idx,(image,label) in enumerate(tqdm(train_loader)):
            image=image.to(device)
            label=label.to(device)
            # training dicriminator
            with torch.cuda.amp.autocast():
                fake=gen(image)
                d_real=disc(image,label)
                d_real_loss = criterion_BCE(d_real, torch.ones_like(d_real))
                d_fake=disc(image,fake.detach()) # to retain graph or we can do reatin graph in  loss.backward
                d_fake_loss=criterion_BCE(d_fake,torch.zeros_like(d_fake))
                d_loss=(d_fake_loss+d_real_loss)/2
            opt_disc.zero_grad()
            disc_scaler.scale(d_loss).backward()
            disc_scaler.step(opt_disc)
            disc_scaler.update()
            # training generater
            with torch.cuda.amp.autocast():
                d_fake=disc(image,label)
                g_fake_loss=criterion_BCE(d_fake,torch.ones_like(d_fake))
                l1_loss=criterion_L1(fake,label)*L1_LAMBDA
                g_loss=l1_loss+g_fake_loss
            opt_gen.zero_grad()
            gen_scaler.scale(g_loss).backward()
            gen_scaler.step(opt_gen)
            gen_scaler.update()
            
            total_disc+=image.shape[0]
            total_gen+=image.shape[0]
            # print(total_gen)
            if (len(train_loader)*bs)==total_gen:
                print(len(train_loader),total_gen)
                print(
                f"Epoch [{epoch+1}/{num_epochs}] Batch {batch_idx}/{len(train_loader)} \
                  Loss D: {d_loss:.4f}, loss G: {g_loss:.4f}"
            )   
                with torch.no_grad():
                    gen.eval()
                    image,label=next(iter(val_loader))
                    image,label=image.to(device),label.to(device)
                    fake = gen(image)
                    # take out (up to) 32 examples
                    img_grid_real = torchvision.utils.make_grid(image[:8], normalize=True)
                    img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
                    img_grid_label = torchvision.utils.make_grid(label[:8], normalize=True)

                    # writer_real.add_image("Real", img_grid_real, global_step=step)
                    writer_real.add_image("Fake", img_grid_label, global_step=epoch+1)
                    writer_real.add_image("Real", img_grid_real, global_step=epoch+1)
                    writer_fake.add_image("Fake", img_grid_fake, global_step=epoch+1)
                    writer_fake.add_scalar("Fake loss",g_loss/total_gen,global_step=epoch+1)
                    writer_real.add_scalar("discrimitor loss",d_loss/total_disc,global_step=epoch+1)
                gen.train()
train()

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [1/30] Batch 686/687                   Loss D: 0.0786, loss G: 23.6000


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [2/30] Batch 686/687                   Loss D: 0.0404, loss G: 17.9223


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [3/30] Batch 686/687                   Loss D: 0.0152, loss G: 19.7028


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [4/30] Batch 686/687                   Loss D: 0.0119, loss G: 17.6959


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [5/30] Batch 686/687                   Loss D: 0.0061, loss G: 15.8374


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [6/30] Batch 686/687                   Loss D: 0.0055, loss G: 13.8820


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [7/30] Batch 686/687                   Loss D: 0.0076, loss G: 16.6029


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [8/30] Batch 686/687                   Loss D: 0.0026, loss G: 17.2765


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [9/30] Batch 686/687                   Loss D: 0.0036, loss G: 12.8073


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [10/30] Batch 686/687                   Loss D: 0.0011, loss G: 13.7405


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [11/30] Batch 686/687                   Loss D: 0.0048, loss G: 14.7356


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [12/30] Batch 686/687                   Loss D: 0.0123, loss G: 14.5038


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [13/30] Batch 686/687                   Loss D: 0.0045, loss G: 12.4705


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [14/30] Batch 686/687                   Loss D: 0.0009, loss G: 15.6292


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [15/30] Batch 686/687                   Loss D: 0.0034, loss G: 10.7269


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [16/30] Batch 686/687                   Loss D: 0.0013, loss G: 14.7663


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [17/30] Batch 686/687                   Loss D: 0.0020, loss G: 11.5292


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [18/30] Batch 686/687                   Loss D: 0.0003, loss G: 12.1059


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [19/30] Batch 686/687                   Loss D: 0.0005, loss G: 13.0784


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [20/30] Batch 686/687                   Loss D: 0.0003, loss G: 10.3417


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [21/30] Batch 686/687                   Loss D: 0.0037, loss G: 11.6807


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [22/30] Batch 686/687                   Loss D: 0.0216, loss G: 10.1379


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [23/30] Batch 686/687                   Loss D: 0.0005, loss G: 10.9732


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [24/30] Batch 686/687                   Loss D: 0.0003, loss G: 11.5566


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [25/30] Batch 686/687                   Loss D: 0.0001, loss G: 11.5251


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [26/30] Batch 686/687                   Loss D: 0.0013, loss G: 10.5489


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [27/30] Batch 686/687                   Loss D: 0.0007, loss G: 10.3884


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [28/30] Batch 686/687                   Loss D: 0.0005, loss G: 9.8253


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [29/30] Batch 686/687                   Loss D: 0.0007, loss G: 12.8724


  0%|          | 0/687 [00:00<?, ?it/s]

687 10992
Epoch [30/30] Batch 686/687                   Loss D: 0.0004, loss G: 12.3877
