In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os 
import torchvision.transforms as transforms
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import matplotlib.pyplot as plt
%matplotlib inline
import torch


# Input data files are available in the read-only "../input/" directory.
# This cell does not read from it; it only prepares a local ./data directory and selects the compute device.

# Local directory used as the dataset root; created up front so later cells can rely on it.
path2data = './data'
os.makedirs(path2data, exist_ok=True)

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    dev = 'cuda'
else:
    dev = 'cpu'
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:

# Target image resolution fed to the networks.
h = w = 64
# Per-channel statistics for Normalize: (x - 0.5) / 0.5 maps ToTensor's [0, 1] range onto [-1, 1].
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

# Preprocessing pipeline applied to every image, in order.
preprocessing_steps = [
    transforms.Resize((h, w)),
    transforms.CenterCrop((h, w)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
]
transform = transforms.Compose(preprocessing_steps)


In [None]:
from torchvision import datasets

# STL10 training split, downloaded into `path2data` on first use and
# preprocessed with the `transform` pipeline defined earlier.
train_ds = datasets.STL10(
    root=path2data,
    split="train",
    transform=transform,
    download=True,
)

In [None]:
# Look at the first training sample: tensor shape and value range after normalisation.
x, _ = train_ds[0]
print(x.shape, x.min(), x.max())

# Undo Normalize(mean=0.5, std=0.5) so the tensor is back in [0, 1] before converting to PIL.
image = to_pil_image(x * 0.5 + 0.5)
plt.imshow(image)

In [None]:
# Shuffled mini-batch loader over the training set.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)

# Pull a single batch to check what the label tensor looks like.
x, y = next(iter(train_dl))
print(y)

In [None]:
import torch.nn.functional as tf
import torch.nn as nn
class Generator(nn.Module):
    """DCGAN-style generator mapping a latent tensor to an image in [-1, 1].

    Args:
        params: dict with the keys
            "nz"  - number of channels of the latent input,
            "ngf" - base number of generator feature maps,
            "noc" - number of channels of the generated image.

    For a latent input of shape (N, nz, 1, 1) the spatial size grows
    1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64, so the output has shape
    (N, noc, 64, 64).

    The output layer is a bare transposed convolution followed by tanh.
    Applying BatchNorm + ReLU before tanh (as an earlier version did) clamps
    the output to [0, 1), which cannot match real images normalised to [-1, 1].
    """
    def __init__(self,params):
        super(Generator,self).__init__()
        nz = params["nz"]
        ngf = params["ngf"]
        noc = params["noc"]
        # Latent projection: nz -> ngf*8 feature maps, 1x1 -> 4x4.
        self.dconv1 = nn.ConvTranspose2d(nz,ngf*8,kernel_size=4,stride=1,padding=0,bias=False)
        self.bn1    = nn.BatchNorm2d(ngf*8)

        # Each stride-2 block doubles the spatial size and halves the feature maps.
        self.dconv2 = nn.ConvTranspose2d(ngf*8,ngf*4,kernel_size=4,stride=2,padding=1,bias=False)  # ngf*4 filters
        self.bn2    = nn.BatchNorm2d(ngf*4)

        self.dconv3 = nn.ConvTranspose2d(ngf*4,ngf*2,kernel_size=4,stride=2,padding=1,bias=False)  # ngf*2 filters
        self.bn3    = nn.BatchNorm2d(ngf*2)

        self.dconv4 = nn.ConvTranspose2d(ngf*2,ngf,kernel_size=4,stride=2,padding=1,bias=False)  # ngf filters
        self.bn4    = nn.BatchNorm2d(ngf)

        # Output layer: noc image channels, no batch norm (see class docstring).
        self.dconv5 = nn.ConvTranspose2d(ngf,noc,kernel_size=4,stride=2,padding=1,bias=False)

    def forward(self,x):
        """Generate images from latent tensor `x`; values lie in [-1, 1] (tanh)."""
        x = tf.relu(self.bn1(self.dconv1(x)))
        x = tf.relu(self.bn2(self.dconv2(x)))
        x = tf.relu(self.bn3(self.dconv3(x)))
        x = tf.relu(self.bn4(self.dconv4(x)))
        # No ReLU here: tanh must see negative pre-activations to cover [-1, 0).
        out = torch.tanh(self.dconv5(x))
        return out
# Generator hyper-parameters: latent channels, base feature maps, output image channels.
params = dict(nz=100, ngf=64, noc=3)

model_gen = Generator(params)
model_gen = model_gen.to(device=dev)

# Shape check: push one all-zero latent vector through the generator.
with torch.no_grad():
    latent_probe = torch.zeros(1, 100, 1, 1, device=dev)
    y = model_gen(latent_probe)
    print(y.shape)

In [None]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator producing one real/fake probability per image.

    Args:
        params: dict with the keys
            "nic" - number of channels of the input image,
            "ndf" - base number of discriminator feature maps.

    Four stride-2 convolutions (each followed by BatchNorm and LeakyReLU(0.2))
    halve the spatial size and double the feature maps; a final 4x4
    convolution reduces to a single channel. For 64x64 inputs the result is
    one value per image, returned as a flat tensor of shape (N,).

    NOTE(review): the reference DCGAN recipe omits batch norm on the first
    discriminator layer; this class keeps it - confirm that is intended.
    """

    def __init__(self, params):
        super().__init__()
        in_channels = params["nic"]
        width = params["ndf"]

        def downsample(c_in, c_out):
            # 4x4 convolution, stride 2, padding 1: halves height and width.
            return nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False)

        self.conv1 = downsample(in_channels, width)
        self.bn1 = nn.BatchNorm2d(width)

        self.conv2 = downsample(width, width * 2)
        self.bn2 = nn.BatchNorm2d(width * 2)

        self.conv3 = downsample(width * 2, width * 4)
        self.bn3 = nn.BatchNorm2d(width * 4)

        self.conv4 = downsample(width * 4, width * 8)
        self.bn4 = nn.BatchNorm2d(width * 8)

        # Final 4x4 convolution without padding collapses a 4x4 map to 1x1.
        self.conv5 = nn.Conv2d(width * 8, 1, kernel_size=4, stride=1, padding=0, bias=False)

    def forward(self, x):
        """Return sigmoid scores for `x`, flattened to one dimension."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, norm in stages:
            x = tf.leaky_relu(norm(conv(x)), 0.2, inplace=True)
        scores = torch.sigmoid(self.conv5(x))
        return scores.view(-1)
# Discriminator hyper-parameters: input image channels and base feature maps.
params_dis = dict(nic=3, ndf=64)

model_dis = Discriminator(params_dis)
model_dis.to(dev)
print(model_dis)

# Shape check: score one all-zero image of the training resolution.
with torch.no_grad():
    image_probe = torch.zeros(1, 3, h, w, device=dev)
    y = model_dis(image_probe)
print(y.shape)

In [None]:
def initialize_weights(model):
    """Re-initialise one module in place; intended for use with `nn.Module.apply`.

    Modules whose class name contains "Conv" (regular and transposed
    convolutions) get weights drawn from N(0, 0.02). Modules whose class name
    contains "BatchNorm" get weights drawn from N(1, 0.02) and zero biases.
    Any other module is left untouched.
    """
    layer_type = type(model).__name__
    if 'Conv' in layer_type:
        nn.init.normal_(model.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_type:
        nn.init.normal_(model.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(model.bias.data, 0)

# Module.apply calls initialize_weights on every sub-module and on the model itself,
# so all Conv / ConvTranspose / BatchNorm layers of both networks are re-initialised in place.
model_dis.apply(initialize_weights)
model_gen.apply(initialize_weights)

In [None]:
# Binary cross-entropy on the discriminator's sigmoid outputs (real = 1, fake = 0).
loss_func = nn.BCELoss()

# Adam settings from the DCGAN recipe: lr = 2e-4, beta1 = 0.5, and Adam's
# default beta2 = 0.999. The previous beta2 of 0.9999 looked like a typo and
# makes the second-moment estimate average over roughly ten times more steps.
opt_gen = torch.optim.Adam(model_gen.parameters(),lr=2e-4,betas=(0.5,0.999))
opt_dis = torch.optim.Adam(model_dis.parameters(),lr=2e-4,betas=(0.5,0.999))

In [None]:
# Targets used with BCELoss: 1 for real images, 0 for generated ones.
real_label, fake_label = 1, 0
nz = params["nz"]
num_epoch = 100
# Per-iteration losses of both networks.
loss_history = {"gen": [], "dis": []}

for epoch in range(num_epoch):
    for xb, yb in train_dl:
        xb = xb.to(device=dev)
        bz = xb.shape[0]
        # The dataset's class labels are not used; yb becomes the real/fake target vector.
        yb = torch.full((bz,), real_label, device=dev, dtype=torch.float32)

        # ---- Discriminator step: real batch, then a detached fake batch ----
        model_dis.zero_grad()
        loss_r = loss_func(model_dis(xb), yb)
        loss_r.backward()

        noise = torch.randn(bz, nz, 1, 1, device=dev)
        out_gen = model_gen(noise)
        yb.fill_(fake_label)
        # detach(): the discriminator update must not back-propagate into the generator.
        loss_f = loss_func(model_dis(out_gen.detach()), yb)
        loss_f.backward()
        opt_dis.step()
        loss_dis = loss_r + loss_f

        # ---- Generator step: make the discriminator label the fakes as real ----
        model_gen.zero_grad()
        yb.fill_(real_label)
        out_dis = model_dis(out_gen)
        loss_gen = loss_func(out_dis, yb)
        loss_gen.backward()
        opt_gen.step()

        loss_history["gen"].append(loss_gen.item())
        loss_history["dis"].append(loss_dis.item())
    print(epoch)


import os

# Write the final generator and discriminator weights (state_dicts) to ./models/.
path2models = "./models/"
path2weights_gen = os.path.join(path2models, "weights_gen.pt")
path2weights_dis = os.path.join(path2models, "weights_dis.pt")

os.makedirs(path2models, exist_ok=True)
for net, weights_path in ((model_gen, path2weights_gen), (model_dis, path2weights_dis)):
    torch.save(net.state_dict(), weights_path)


In [None]:
# Sample images from the generator and display the first one.
# Sampling runs in eval mode so BatchNorm uses its running statistics and does
# not update them; the previous train/eval mode is restored afterwards so that
# re-running the training cell is unaffected.
was_training = model_gen.training
model_gen.eval()
with torch.no_grad():
    fixed_noise = torch.randn(16, nz, 1, 1, device=dev)
    # Use the freshly drawn latent batch. The earlier version passed `noise`,
    # a variable left over from the last training iteration, so `fixed_noise`
    # was never used and the cell depended on hidden state.
    img_fake = model_gen(fixed_noise)[0].cpu()
model_gen.train(was_training)

# Undo Normalize(mean=0.5, std=0.5): map the tanh output from [-1, 1] back to [0, 1].
image = to_pil_image(img_fake*0.5+0.5)
plt.imshow(image)

In [None]:
# NOTE(review): stray scratch cell - a bare string literal that no other cell uses; candidate for deletion.
'h'