In [23]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm
!jupyter nbextension enable --py widgetsnbextension


Enabling notebook extension jupyter-js-widgets/extension...
      - Validating: [32mOK[0m


In [24]:
class Discriminator(nn.Module):
    def __init__(self,in_channel,out_channel ) -> None:
        super(Discriminator,self).__init__()
        #input: (bs,channel,64,64)
        self.discrimator=nn.Sequential(
            nn.Conv2d(in_channel,out_channel,kernel_size=4,stride=2,padding=1),
            nn.LeakyReLU(0.2),
            self.block(out_channel,out_channel*2,kernel_size=4,stride=2,padding=1),
            self.block(out_channel*2,out_channel*4,kernel_size=4,stride=2,padding=1),
            self.block(out_channel*4,out_channel*8,kernel_size=4,stride=2,padding=1),
            nn.Conv2d(out_channel*8,1,kernel_size=4,stride=2,padding=0),
            # nn.Sigmoid() # no for wgan there ios sigmoid ativation
        )
    def block(self,in_channel,out_channel,kernel_size,stride,padding):
        return nn.Sequential(
            nn.Conv2d(in_channel,out_channel,kernel_size,stride,padding,bias=False))
    
    def forward(self,x):
        return self.discrimator(x)

In [25]:
d=Discriminator(3,1)
x=torch.randn(1,3,64,64)
d(x).shape

torch.Size([1, 1, 1, 1])

In [26]:
class Generator(nn.Module):
    def __init__(self,noise_dim,in_channels,out_channels):
        super(Generator, self).__init__()
        self.generator=nn.Sequential(
            self.block(noise_dim,out_channels*16,kernel_size=4,stride=2,padding=0),
            self.block(out_channels*16,out_channels*8,kernel_size=4,stride=2,padding=1),
            self.block(out_channels*8,out_channels*4,kernel_size=4,stride=2,padding=1),
            self.block(out_channels*4,out_channels*2,kernel_size=4,stride=2,padding=1),
            nn.ConvTranspose2d(out_channels*2,in_channels,4,2,1),
            nn.Tanh(), #[-1,1]
        )
    def block(self,in_channel,out_channel,kernel_size,stride,padding):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channel,out_channel,kernel_size,stride,padding),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(),
        )
    def forward(self,x):
        return self.generator(x)

In [27]:
def initalize_weight(model):
    for m in model.modules():
        if isinstance(m,(nn.Conv2d,nn.ConvTranspose2d,nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data,0.0,0.02)

In [28]:
def test():
    noise_dim=100
    b,in_channel,h,w=4,1,64,64
    x=torch.randn((b,in_channel,h,w))
    print(x.shape)
    disc=Discriminator(in_channel,8)
    initalize_weight(disc)
    assert disc(x).shape == (b,1,1,1)
    dis_model=disc(x)
    print('disc model',dis_model.shape)
    gen=Generator(noise_dim,in_channel,8)
    y=torch.randn((b,noise_dim,1,1))
    print(y.shape)
    gen_model=gen(y)
    assert gen_model.shape== (b,in_channel,h,w)
    print(gen_model.shape)
    
test()

torch.Size([4, 1, 64, 64])
disc model torch.Size([4, 1, 1, 1])
torch.Size([4, 100, 1, 1])
torch.Size([4, 1, 64, 64])


In [29]:
# Hyperparameters etc.
CUDA=torch.cuda.is_available()
# CUDA=False
device = "cuda:7" if CUDA else "cpu"
lr = 5e-5
noise_dim = 100
batch_size = 64
in_channel=3 #1
Image_size=64
num_epochs=10
out_channel_gen=64
out_channel_disc=64
critic_iteration=10
weight_clip=0.01

In [30]:
transform=transforms.Compose([
    transforms.Resize((Image_size,Image_size)),
    # transforms.Resize(Image_size),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.5 for _ in range(in_channel)],
        [0.5 for _ in range(in_channel)],  
    )
])
# dataset=datasets.MNIST(root="dataset/",train=True,transform=transform,download=True)
dataset=datasets.ImageFolder(root="/mnt/disk1/Gulshan/dataset/Dog_data",transform=transform)
loader=DataLoader(dataset,batch_size=batch_size,shuffle=True)
generator=Generator(noise_dim,in_channel,out_channel_gen).to(device)
discrmintor_critic=Discriminator(in_channel,out_channel_disc).to(device)

In [31]:
initalize_weight(generator)
initalize_weight(discrmintor_critic)
optimizer_gen=optim.RMSprop(generator.parameters(),lr=lr,)
optimizer_disc_critic=optim.RMSprop(discrmintor_critic.parameters(),lr=lr,)
fixed_noise=torch.randn(32,noise_dim,1,1).to(device)
writer_real=SummaryWriter("logs/real")
writer_fake=SummaryWriter("logs/fake")
step=0

In [32]:
generator.train()
discrmintor_critic.train()
for epoch in tqdm(range(num_epochs),total=num_epochs):
    
    running_loss_gen=0
    for batch_idx, (real, _) in tqdm(enumerate(loader)):
        real=real.to(device)
        for _ in range(critic_iteration):
            noise=torch.randn((batch_size,noise_dim,1,1)).to(device)
            fake=generator(noise)
            ##loss for discrimi
            disc_critic_real=discrmintor_critic(real).reshape(-1)
            disc_critic_fake=discrmintor_critic(fake).reshape(-1)
            
            loss_dics_critic=-(torch.mean(disc_critic_real)-torch.mean(disc_critic_fake))
            
            optimizer_disc_critic.zero_grad()
            loss_dics_critic.backward(retain_graph=True)
            optimizer_disc_critic.step()
        ## losss for generater
        
        gen_out=discrmintor_critic(fake).reshape(-1)
        gen_loss=-(torch.mean(gen_out))
        optimizer_gen.zero_grad()
        gen_loss.backward()
        optimizer_gen.step()
        
        running_loss_gen+=gen_loss
        
        if batch_idx == 0:
            print(
                f"Epoch [{epoch+1}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {loss_dics_critic:.4f}, loss G: {gen_loss:.4f},"
            )

            with torch.no_grad():
                fake = generator(fixed_noise)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                writer_fake.add_scalar("Fake loss",running_loss_gen,global_step=step)
                step += 1
        
        

  0%|          | 0/10 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Epoch [1/10] Batch 0/133                       Loss D: -675.6105, loss G: 33.9164,


0it [00:00, ?it/s]

Epoch [2/10] Batch 0/133                       Loss D: -481830464.0000, loss G: -77085088.0000,


0it [00:00, ?it/s]

Epoch [3/10] Batch 0/133                       Loss D: -10502194176.0000, loss G: 1048001088.0000,


0it [00:00, ?it/s]

Epoch [4/10] Batch 0/133                       Loss D: -43280515072.0000, loss G: -16130102.0000,


0it [00:00, ?it/s]

Epoch [5/10] Batch 0/133                       Loss D: -49071243264.0000, loss G: -57459306496.0000,


0it [00:00, ?it/s]

Epoch [6/10] Batch 0/133                       Loss D: -80632135680.0000, loss G: -235837030400.0000,


In [None]:
! tensorboard --logdir=/mnt/disk1/Gulshan/GAN/WGAN/logs

TensorFlow installation not found - running with reduced feature set.

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.14.0 at http://localhost:6006/ (Press CTRL+C to quit)
E0507 16:16:48.917974 139790653249280 _internal.py:96] Error on request:
Traceback (most recent call last):
  File "/mnt/disk1/conda/envs/gulshan/lib/python3.8/site-packages/werkzeug/serving.py", line 362, in run_wsgi
    execute(self.server.app)
  File "/mnt/disk1/conda/envs/gulshan/lib/python3.8/site-packages/werkzeug/serving.py", line 323, in execute
    application_iter = app(environ, start_response)
  File "/mnt/disk1/conda/envs/gulshan/lib/python3.8/site-packages/tensorboard/backend/application.py", line 528, in __call__
    return self._app(environ, start_response