In [1]:
import torch
import torch.nn as nn
from dataclasses import dataclass
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter  



In [2]:
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = 'cpu'
device

'cpu'

In [3]:
@dataclass
class ModelArgs:
    device = 'cpu'
    batch_size = 1
    lr = 0.0002
    img_size = 256
    no_of_channels = 3
    kernel_size = (4,4)
    stride = 2
    dropout = 0.5
    padding = 1
    lr_slope = 0.2
    beta_1 = 0.5
    beta_2 = 0.999

In [4]:
#Transforms for images
transforms = torchvision.transforms.Compose([
    transforms.Resize(size=(ModelArgs.img_size,ModelArgs.img_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))

])

In [5]:
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)  #mean = 0, std = 0.02


In [31]:
import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.main = nn.Sequential(
            nn.Conv2d(ModelArgs.no_of_channels, 64, kernel_size=ModelArgs.kernel_size, stride=ModelArgs.stride, padding=ModelArgs.padding),
            nn.LeakyReLU(negative_slope=ModelArgs.lr_slope),

            nn.Conv2d(64, 128, kernel_size=ModelArgs.kernel_size, stride=ModelArgs.stride, padding=ModelArgs.padding),
            nn.InstanceNorm2d(128, affine=True),
            nn.LeakyReLU(negative_slope=ModelArgs.lr_slope),

            nn.Conv2d(128, 256, kernel_size=ModelArgs.kernel_size, stride=ModelArgs.stride, padding=ModelArgs.padding),
            nn.InstanceNorm2d(256, affine=True),
            nn.LeakyReLU(negative_slope=ModelArgs.lr_slope),

            nn.Conv2d(256, 512, kernel_size=ModelArgs.kernel_size, stride=ModelArgs.stride, padding=ModelArgs.padding),
            nn.InstanceNorm2d(512, affine=True),
            nn.LeakyReLU(negative_slope=ModelArgs.lr_slope),

            nn.Conv2d(512, 512, kernel_size=ModelArgs.kernel_size, stride=ModelArgs.stride, padding=ModelArgs.padding),
            nn.InstanceNorm2d(512, affine=True),
            nn.LeakyReLU(negative_slope=ModelArgs.lr_slope),

            nn.Conv2d(512, 512, kernel_size=ModelArgs.kernel_size, stride=ModelArgs.stride, padding=ModelArgs.padding),
            nn.InstanceNorm2d(512, affine=True),
            nn.LeakyReLU(negative_slope=ModelArgs.lr_slope),

            nn.Conv2d(512, 512, kernel_size=ModelArgs.kernel_size, stride=1, padding=ModelArgs.padding),
            nn.InstanceNorm2d(512, affine=True),
            nn.LeakyReLU(negative_slope=ModelArgs.lr_slope),

            nn.Conv2d(512, 512, kernel_size=ModelArgs.kernel_size, stride=1 ,padding=ModelArgs.padding),
            nn.InstanceNorm2d(512, affine=True),
            nn.LeakyReLU(negative_slope=ModelArgs.lr_slope)
        )

    def forward(self, x):
        skip_connection = []
        for layer in self.main:
            x = layer(x)
            if isinstance(layer, nn.LeakyReLU):
                skip_connection.append(x)
        return x, skip_connection


In [8]:
random = torch.randn((ModelArgs.batch_size, ModelArgs.no_of_channels, ModelArgs.img_size, ModelArgs.img_size), device=ModelArgs.device)
enc = enc.to(ModelArgs.device)
x, skip_connection = enc(random)

In [32]:
from torchinfo import summary

# images = torch.randn(64, 1, 64, 64)
# labels = torch.randint(0, 10, (64,), dtype=torch.long)
enc = Encoder()
summary(model=enc,
        input_size=(ModelArgs.batch_size, ModelArgs.no_of_channels, ModelArgs.img_size, ModelArgs.img_size),
        # input_data=(images.to(ModelArgs.device), labels.to(ModelArgs.device)),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
Encoder (Encoder)                        [1, 3, 256, 256]     [1, 512, 2, 2]       --                   True
├─Sequential (main)                      --                   --                   --                   True
│    └─Conv2d (0)                        [1, 3, 256, 256]     [1, 64, 128, 128]    3,136                True
│    └─LeakyReLU (1)                     [1, 64, 128, 128]    [1, 64, 128, 128]    --                   --
│    └─Conv2d (2)                        [1, 64, 128, 128]    [1, 128, 64, 64]     131,200              True
│    └─InstanceNorm2d (3)                [1, 128, 64, 64]     [1, 128, 64, 64]     256                  True
│    └─LeakyReLU (4)                     [1, 128, 64, 64]     [1, 128, 64, 64]     --                   --
│    └─Conv2d (5)                        [1, 128, 64, 64]     [1, 256, 32, 32]     524,544              True
│    └─InstanceNor

In [108]:
ModelArgs.stride = 2

In [19]:
ModelArgs.kernel_size = (3,3)

In [56]:
class Decoder(nn.Module):
    def __init__(self):
        
        super().__init__()
        
        self.main = nn.Sequential(
            
            # nn.Upsample(size=(3,3)),
            nn.ConvTranspose2d(512, 512, kernel_size=ModelArgs.kernel_size, stride=1, padding=ModelArgs.padding),
            nn.InstanceNorm2d(512*2,affine=True),
            nn.Dropout(p=ModelArgs.dropout),
            nn.ReLU(),
            # print("DOnE"),
            
            # nn.Upsample(size=(4,4)),
            nn.ConvTranspose2d(512*2, 512, kernel_size=ModelArgs.kernel_size, stride=1, padding=ModelArgs.padding),
            nn.InstanceNorm2d(512*2,affine=True),
            nn.Dropout(p=ModelArgs.dropout),
            nn.ReLU(),
               
            # nn.Upsample(size=(8,8)),
            nn.ConvTranspose2d(512*2, 512, kernel_size=ModelArgs.kernel_size, stride=ModelArgs.stride, padding=ModelArgs.padding),
            nn.InstanceNorm2d(512*2,affine=True),
            nn.Dropout(p=ModelArgs.dropout),
            nn.ReLU(),
            
            # nn.Upsample(size=(16,16)),
            nn.ConvTranspose2d(512*2, 512, kernel_size=ModelArgs.kernel_size, stride=ModelArgs.stride, padding=ModelArgs.padding),
            nn.InstanceNorm2d(512*2,affine=True),
            # nn.Dropout(p=ModelArgs.dropout),
            nn.ReLU(),
            
            # nn.Upsample(size=(32,32)),
            nn.ConvTranspose2d(512*2, 256, kernel_size=ModelArgs.kernel_size, stride=ModelArgs.stride,padding=ModelArgs.padding),
            nn.InstanceNorm2d(256*2,affine=True),
            nn.ReLU(),
            
            # nn.Upsample(size=(64,64)),
            nn.ConvTranspose2d(256*2, 128, kernel_size=ModelArgs.kernel_size, stride=ModelArgs.stride,padding=ModelArgs.padding),
            nn.InstanceNorm2d(128*2,affine=True),
            nn.ReLU(),

            # nn.Upsample(size=(128,128)),
            nn.ConvTranspose2d(128*2, 64, kernel_size=ModelArgs.kernel_size, stride=ModelArgs.stride,padding=ModelArgs.padding),
            nn.InstanceNorm2d(64*2,affine=True),
            nn.ReLU(),
            
            # nn.Upsample(size=(256,256)),
            # nn.Conv2d(256*2, 128, kernel_size=ModelArgs.kernel_size, padding=ModelArgs.padding),
            # nn.InstanceNorm2d(128,affine=True),
            # nn.ReLU(),
            
            
            nn.ConvTranspose2d(64*2, ModelArgs.no_of_channels, kernel_size=ModelArgs.kernel_size, stride=ModelArgs.stride, padding=ModelArgs.padding),
            nn.Tanh(),
        )
        
    def forward(self, x, skip_connection):
        
        skip_connection = skip_connection[-2::-1]
        count = 0
        for idx, layer in enumerate(self.main):
            
            # print("Original: ", x.shape)
            if (isinstance(layer, nn.Conv2d) or isinstance(layer, nn.ConvTranspose2d)) and count < len(skip_connection):
                
                # print(f"Before applying layer{layer}", x.shape)
                x = layer(x)
                # print(f"After applying layer{layer}:", x.shape)
                # print("Skip connection:",skip_connection[count].shape)
                x = torch.concat([x, skip_connection[count]], dim=1)
                # print("Final: ", x.shape)
                # print(count)
                count += 1
            else:
                # print(f"Before applying layer{layer}", x.shape)
                x = layer(x)
                # print(f"After applying layer{layer}:", x.shape)
        return x


In [53]:
skip_connection

[tensor([[[[ 6.8836e-01,  3.2033e-01,  3.3346e-01,  ...,  4.5964e-02,
            -1.0953e-01, -5.0371e-02],
           [ 5.0429e-01, -6.7398e-02,  1.4376e-01,  ...,  2.1144e-01,
             4.6284e-01,  9.5048e-01],
           [-9.5324e-02,  6.7478e-01,  1.7943e-01,  ...,  3.3460e-01,
             2.6215e-01,  2.3474e-01],
           ...,
           [ 5.4861e-01,  2.3438e-01,  2.9127e-01,  ..., -4.9607e-02,
            -2.8231e-02,  3.3847e-01],
           [ 2.1475e-01, -1.9573e-01,  3.0462e-01,  ...,  6.6461e-01,
            -1.6278e-01, -8.4918e-02],
           [ 3.6583e-01,  1.9185e-01,  1.4353e+00,  ..., -4.2945e-04,
            -9.7129e-02,  3.6384e-02]],
 
          [[-8.0441e-02,  1.1631e-01,  4.8557e-01,  ...,  5.0810e-01,
            -8.2178e-02,  1.8448e-01],
           [ 4.5781e-01,  7.2728e-01, -5.7750e-02,  ...,  6.5324e-01,
            -6.5259e-02,  7.0425e-01],
           [-1.1518e-01, -2.9324e-01, -2.3471e-02,  ...,  4.6726e-01,
             1.2763e+00,  2.8709e-01],


In [44]:
x.shape

torch.Size([1, 512, 2, 2])

In [37]:
for i in skip_connection[-2::-1]:
    print(i.shape)

torch.Size([1, 512, 3, 3])
torch.Size([1, 512, 4, 4])
torch.Size([1, 512, 8, 8])
torch.Size([1, 512, 16, 16])
torch.Size([1, 256, 32, 32])
torch.Size([1, 128, 64, 64])
torch.Size([1, 64, 128, 128])


In [57]:
from torchinfo import summary

# images = torch.randn(64, 1, 64, 64)
# labels = torch.ran\dint(0, 10, (64,), dtype=torch.long)
dec = Decoder()
dec = dec.to(ModelArgs.device)
summary(model=dec,
        input_data=(x, skip_connection),
        # input_data=(images.to(ModelArgs.device), labels.to(ModelArgs.device)),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
Decoder (Decoder)                        [1, 512, 2, 2]       [1, 3, 256, 256]     --                   True
├─Sequential (main)                      --                   --                   --                   True
│    └─ConvTranspose2d (0)               [1, 512, 2, 2]       [1, 512, 3, 3]       4,194,816            True
│    └─InstanceNorm2d (1)                [1, 1024, 3, 3]      [1, 1024, 3, 3]      2,048                True
│    └─Dropout (2)                       [1, 1024, 3, 3]      [1, 1024, 3, 3]      --                   --
│    └─ReLU (3)                          [1, 1024, 3, 3]      [1, 1024, 3, 3]      --                   --
│    └─ConvTranspose2d (4)               [1, 1024, 3, 3]      [1, 512, 4, 4]       8,389,120            True
│    └─InstanceNorm2d (5)                [1, 1024, 4, 4]      [1, 1024, 4, 4]      2,048                True
│    └─Dropout (6)

In [60]:
class UNet(nn.Module):
    def __init__(
        self
    ):
        super().__init__()
        
        self.encoder = Encoder()
        self.decoder = Decoder()
        
    def forward(self, x):
        
        x, skip_connection = self.encoder(x)
        x = self.decoder(x, skip_connection)
        return x

In [61]:
#Intializing the Discriminator instance
unet = UNet().to(ModelArgs.device)
#Apply the wieght intilization function layer by layer
unet = unet.apply(weights_init)
#Printing the structure
print(unet)

UNet(
  (encoder): Encoder(
    (main): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.2)
      (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (4): LeakyReLU(negative_slope=0.2)
      (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (7): LeakyReLU(negative_slope=0.2)
      (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (10): LeakyReLU(negative_slope=0.2)
      (11): Conv2d(512, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (12): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      

In [64]:

from torchinfo import summary

# images = torch.randn(ModelArgs.batch_size, ModelArgs.no_of_channels, ModelArgs.img_size, ModelArgs.img_size)
# labels = torch.randint(0, 10, (64,), dtype=torch.long)

summary(model=unet,
        input_size=(ModelArgs.batch_size, ModelArgs.no_of_channels, ModelArgs.img_size, ModelArgs.img_size),
        # input_data=(images.to(ModelArgs.device), labels.to(ModelArgs.device)),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])


Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
UNet (UNet)                              [1, 3, 256, 256]     [1, 3, 256, 256]     --                   True
├─Encoder (encoder)                      [1, 3, 256, 256]     [1, 512, 2, 2]       --                   True
│    └─Sequential (main)                 --                   --                   --                   True
│    │    └─Conv2d (0)                   [1, 3, 256, 256]     [1, 64, 128, 128]    3,136                True
│    │    └─LeakyReLU (1)                [1, 64, 128, 128]    [1, 64, 128, 128]    --                   --
│    │    └─Conv2d (2)                   [1, 64, 128, 128]    [1, 128, 64, 64]     131,200              True
│    │    └─InstanceNorm2d (3)           [1, 128, 64, 64]     [1, 128, 64, 64]     256                  True
│    │    └─LeakyReLU (4)                [1, 128, 64, 64]     [1, 128, 64, 64]     --                   --
│    │    └─Conv2d

In [67]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.main = nn.Sequential(
            nn.Conv2d(ModelArgs.no_of_channels*2, 64, kernel_size=ModelArgs.kernel_size, stride=ModelArgs.stride, padding=ModelArgs.padding),
            nn.LeakyReLU(negative_slope=ModelArgs.lr_slope),
                
            nn.Conv2d(64, 128, kernel_size=ModelArgs.kernel_size, stride=ModelArgs.stride, padding=ModelArgs.padding),
            nn.InstanceNorm2d(128, affine=True),
            nn.LeakyReLU(negative_slope=ModelArgs.lr_slope),
               
            nn.Conv2d(128, 256, kernel_size=ModelArgs.kernel_size, stride=ModelArgs.stride, padding=ModelArgs.padding),
            nn.InstanceNorm2d(256, affine=True),
            nn.LeakyReLU(negative_slope=ModelArgs.lr_slope),
              
            nn.Conv2d(256, 512, kernel_size=ModelArgs.kernel_size, stride=1, padding=ModelArgs.padding),
            nn.InstanceNorm2d(512, affine=True),
            nn.LeakyReLU(negative_slope=ModelArgs.lr_slope),
            
            nn.Conv2d(512, 1, kernel_size=ModelArgs.kernel_size, stride=1, padding=ModelArgs.padding),
            nn.Sigmoid()
        )
        
    def forward(self, x, y):
        res = torch.concat([x, y], dim=1)
        return self.main(res)


In [68]:
#Intializing the Discriminator instance
discriminator = Discriminator().to(ModelArgs.device)
#Apply the wieght intilization function layer by layer
discriminator = discriminator.apply(weights_init)
#Printing the structure
print(discriminator)

Discriminator(
  (main): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (12): Sigmoid()
  )
)


In [69]:
from torchinfo import summary

real_A = torch.randn(ModelArgs.batch_size, ModelArgs.no_of_channels, ModelArgs.img_size, ModelArgs.img_size)
real_B = torch.randn(ModelArgs.batch_size, ModelArgs.no_of_channels, ModelArgs.img_size, ModelArgs.img_size)
# labels = torch.randint(0, 10, (64,), dtype=torch.long)

summary(model=discriminator,
        # input_size=(ModelArgs.batch_size, ModelArgs.no_of_channels, ModelArgs.img_size, ModelArgs.img_size),
        input_data=(real_A.to(ModelArgs.device), real_B.to(ModelArgs.device)),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

Layer (type (var_name))                  Input Shape          Output Shape         Param #              Trainable
Discriminator (Discriminator)            [1, 3, 256, 256]     [1, 1, 30, 30]       --                   True
├─Sequential (main)                      [1, 6, 256, 256]     [1, 1, 30, 30]       --                   True
│    └─Conv2d (0)                        [1, 6, 256, 256]     [1, 64, 128, 128]    6,208                True
│    └─LeakyReLU (1)                     [1, 64, 128, 128]    [1, 64, 128, 128]    --                   --
│    └─Conv2d (2)                        [1, 64, 128, 128]    [1, 128, 64, 64]     131,200              True
│    └─InstanceNorm2d (3)                [1, 128, 64, 64]     [1, 128, 64, 64]     256                  True
│    └─LeakyReLU (4)                     [1, 128, 64, 64]     [1, 128, 64, 64]     --                   --
│    └─Conv2d (5)                        [1, 128, 64, 64]     [1, 256, 32, 32]     524,544              True
│    └─InstanceNor

In [66]:
unet = UNet().to(ModelArgs.device).apply(weights_init)
discriminator = Discriminator().to(ModelArgs.device).apply(weights_init)


epochs = 10000 #30


optimizerC = torch.optim.Adam(params=unet.parameters(), lr=ModelArgs.lr, betas=(ModelArgs.beta_1, ModelArgs.beta_2)) #For discriminator
optimizerG = torch.optim.Adam(params=discriminator.parameters(), lr=ModelArgs.lr,  betas=(ModelArgs.beta_1, ModelArgs.beta_2)) #For generator



real_label = 1
fake_label = 0


loss_g = []
loss_d = []
img_list = []

# Fixed noise for generating the images
fixed_noise = torch.randn((ModelArgs.batch_size, ModelArgs.latent_vector_size, 1, 1), dtype=torch.float32, device=ModelArgs.device)

NameError: name 'Discriminator' is not defined