In [None]:
# Mount Google Drive at /content/drive; later cells read and write files
# under /content/drive/MyDrive/StyleGAN2/.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Download the dataset
# gdown is used to fetch a file from Google Drive by its file id.
!pip install gdown

import gdown
file_id = "1-EyM2kIj24P6DtT-swZP8DLyBAPU1PkU"
url = f"https://drive.google.com/uc?id={file_id}"
output = "dataset.zip"
gdown.download(url, output, quiet=False)
# Extract the archive into the current working directory.
!unzip "dataset.zip"





Downloading...
From: https://drive.google.com/uc?id=1-EyM2kIj24P6DtT-swZP8DLyBAPU1PkU
To: /content/dataset.zip
1.22GB [00:32, 37.5MB/s]


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: dataset/006461.jpg      
  inflating: dataset/005303.jpg      
  inflating: dataset/010169.jpg      
  inflating: dataset/009002.jpg      
  inflating: dataset/013056.jpg      
  inflating: dataset/004399.jpg      
  inflating: dataset/016813.jpg      
  inflating: dataset/004069.jpg      
  inflating: dataset/003705.jpg      
  inflating: dataset/004570.jpg      
  inflating: dataset/002959.jpg      
  inflating: dataset/006237.jpg      
  inflating: dataset/008262.jpg      
  inflating: dataset/004517.jpg      
  inflating: dataset/010624.jpg      
  inflating: dataset/011795.jpg      
  inflating: dataset/007328.jpg      
  inflating: dataset/012614.jpg      
  inflating: dataset/005327.jpg      
  inflating: dataset/012366.jpg      
  inflating: dataset/011284.jpg      
  inflating: dataset/014139.jpg      
  inflating: dataset/006140.jpg      
  inflating: dataset/004258.jpg      
  inflating: dataset/00

In [1]:
#import libraries
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt
import numpy as np
# Record the PyTorch version used for this run.
print(torch.__version__)
# Seed PyTorch's global random number generator.
torch.manual_seed(0)


1.9.0+cu102


<torch._C.Generator at 0x7f7786cfb7d0>

In [2]:
#util for training
from IPython.display import Image,display_png
from PIL import Image
import math

def combine_images(generated_images):
    """Tile a batch of images into one grid image.

    Args:
        generated_images: array of shape (N, A, B, C). Only the first three
            channels of every image are copied into the grid.

    Returns:
        Array of shape (A * cols, B * rows, 3) with the input dtype, where
        ``cols = int(sqrt(N))`` and ``rows = ceil(N / cols)``. Image ``k`` is
        written to block ``k % cols`` along axis 0 and block ``k // cols``
        along axis 1; unused blocks stay zero. An empty batch (N == 0) yields
        an empty (0, 0, 3) array instead of raising ZeroDivisionError.
    """
    total = generated_images.shape[0]
    if total == 0:
        # int(sqrt(0)) == 0 would otherwise be used as a divisor below.
        return np.zeros((0, 0, 3), dtype=generated_images.dtype)

    cols = int(math.sqrt(total))
    rows = math.ceil(total / cols)
    width, height = generated_images.shape[1:3]
    combined_image = np.zeros((width * cols, height * rows, 3),
                              dtype=generated_images.dtype)

    for index, image in enumerate(generated_images):
        # Integer division keeps the block index exact for any batch size.
        j, i = divmod(index, cols)
        combined_image[width * i:width * (i + 1), height * j:height * (j + 1), 0:3] = image[:, :, 0:3]
    return combined_image

def show_image(result,name,out_dir = '/content/drive/MyDrive/StyleGAN2/generated/'):
    """Save a batch of generated images as one PNG grid.

    Despite the name, nothing is displayed: the grid is written to
    ``<out_dir>/<name>.png``.

    Args:
        result: tensor of shape (N, C, H, W); values are mapped with
            ``v * 127.5 + 127.5`` and clamped to [0, 255] before conversion
            to uint8.
        name: file name without extension.
        out_dir: directory the PNG is written to; created if missing.

    Returns:
        None.
    """
    # Local import so this cell does not depend on a later cell importing os.
    import os

    pixels = result.to('cpu').detach().numpy().copy()
    pixels = np.clip(pixels * 127.5 + 127.5, 0, 255)
    # (N, C, H, W) -> (N, H, W, C), the layout combine_images and PIL expect.
    pixels = np.transpose(pixels, (0, 2, 3, 1))
    grid = combine_images(pixels).astype(np.uint8)

    os.makedirs(out_dir, exist_ok=True)
    # `Image` is PIL's module here (its import comes after the IPython one and shadows it).
    Image.fromarray(grid).save(os.path.join(out_dir, name + '.png'))

In [3]:
import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image
import glob

class CustomImageDataset(Dataset):
    """Dataset of JPEG images stored as ``<img_dir>/<index zero-padded to 6>.jpg``.

    The length is the number of ``*.jpg`` files in ``img_dir``. Item ``idx``
    is read from ``str(idx).zfill(6) + ".jpg"``.
    NOTE(review): this assumes the files are numbered contiguously starting at
    000000 — confirm against the actual archive contents.
    """
    def __init__(self, img_dir):
        self.img_dir = img_dir
        # os.path.join makes the pattern correct with or without a trailing
        # separator, and "*.jpg" ignores files __getitem__ could never load.
        self.len = len(glob.glob(os.path.join(self.img_dir, "*.jpg")))

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return image ``idx`` scaled with (v - 127.5) / 127.5 and resized to 512x512."""
        img_path = os.path.join(self.img_dir, str(idx).zfill(6) + ".jpg")

        # interpolate needs a batch axis, so add one and drop it again afterwards.
        pixels = (read_image(img_path).unsqueeze(0) - 127.5) / 127.5
        image = torch.nn.functional.interpolate(pixels, size=(512,512),mode='bilinear').squeeze(0)

        return image

In [4]:

# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))
# Negative slope passed to every leaky ReLU in the networks below.
LReLU_alpha = 0.2

# Based on https://github.com/yuuho/stylegans-pytorch/blob/master/network/stylegan2.py
from torch.nn import functional as F

class equalized_linear(nn.Module):
    """Linear layer whose stored weights are rescaled on every call.

    Weights are drawn from N(0, (1/lr)^2) and multiplied at call time by
    ``lr / sqrt(in_features)`` and by ``gain``; the bias starts at zero and is
    multiplied by ``lr``.

    Args:
        in_features: size of the last input dimension.
        out_features: size of the last output dimension.
        lr: multiplier folded into the weight scale and the bias.
    """
    def __init__(self,in_features,out_features,lr = 1):
        super().__init__()
        self.lr = lr
        self.weight_scaler = lr / (in_features ** 0.5)

        # Fill the raw storage first, then register it; `weight` is registered before `bias`.
        initial_weight = torch.empty(out_features, in_features)
        initial_weight.normal_(mean=0.0, std=1/lr)
        self.weight = nn.Parameter(initial_weight)
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self,x,gain = np.sqrt(2)):
        """Apply the layer; ``gain`` is an extra multiplier on the weights only."""
        effective_weight = self.weight * self.weight_scaler * gain
        effective_bias = self.bias * self.lr
        return F.linear(x, effective_weight, effective_bias)

class modulated_conv2d(nn.Module):
    """2-D convolution with optional per-sample style modulation/demodulation.

    The convolution is evaluated as one grouped convolution: the batch is
    folded into the channel axis and every sample gets its own copy of the
    (optionally modulated) kernel.

    Args:
        in_channels, out_channels, kernel_size: kernel dimensions.
        stride, padding: forwarded to ``F.conv2d``.
        dilation, groups, bias, padding_mode: accepted but not used by this class.
        mod: when True, an ``equalized_linear`` maps the style vector to one
            scale per input channel.
        demod: when True (and ``mod`` is True), outputs are normalised by the
            L2 norm of the modulated kernel.
        style_dimension: width of the style vector fed to the affine layer.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros',mod = True,demod = True,style_dimension = 256):
        super(modulated_conv2d, self).__init__()
        self.padding, self.stride = padding, stride
        lr = 1
        # The weight is drawn with randn and then re-drawn by normal_; both
        # calls are kept so seeded initialisation consumes the RNG as before.
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))
        torch.nn.init.normal_(self.weight.data, mean=0.0, std=1.0 / lr)
        self.bias = nn.Parameter(torch.zeros(1,out_channels,1,1))
        # Runtime scale: lr / sqrt(fan_in).
        self.weight_scaler = lr / (in_channels * kernel_size*kernel_size)**0.5
        self.mod = mod
        self.demod = demod
        if mod:
            self.affine = equalized_linear(style_dimension, in_channels)

    def forward(self,x,style=None,shape = None,web = False,style_2 = None,gain = np.sqrt(2)):
        """Convolve ``x``, modulated by ``style`` when the layer was built with ``mod=True``.

        Args:
            x: input of shape (N, iC, H, W).
            style: style vectors; required when ``mod`` is True.
            shape: (N, iC, H, W) used instead of ``x.shape`` when ``web`` is
                True. Callers in this file pass ``(1, -1, res, res)``, so the
                channel entry may be -1 and is only used inside ``view``/``reshape``.
            web: selects an alternative formulation that takes sizes from
                ``shape``, scales the activations instead of the kernel and
                applies demodulation to the convolution output.
            style_2: second copy of the style used for the demodulation
                coefficients in the ``web`` path.
            gain: extra multiplier applied to the kernel.

        Returns:
            Tensor viewed as (N, oC, H, W) (spatial size taken from the
            convolution output in the non-``web`` unmodulated case) plus bias.
        """
        oC, iC, kH, kW = self.weight.shape
        if not web:
            N, iC, H, W = x.shape
        else:
            N, iC, H, W = shape

        if not self.mod:
            weight = self.weight.view(1,oC,iC,kH,kW) * self.weight_scaler
            if not web:
                # One kernel copy per sample for the grouped convolution.
                weight = weight.expand(N,oC,iC,kH,kW)

            x = F.conv2d(x.view(1,N*iC,H,W),gain * weight.reshape(N*oC,iC,kH,kW),
                    padding=self.padding, stride=self.stride, groups=N)
            if web:
                return x.view(N,oC,H,W) + self.bias

            return x.view(N,oC,x.shape[2],x.shape[3]) + self.bias

        # Per-input-channel scale, centred on 1.
        affined_style = self.affine(style,gain=1) + 1

        if web:
            modulated_weight = self.weight.view(1,oC,iC,kH,kW) * self.weight_scaler
            # Scaling the activations per input channel is equivalent to
            # scaling the kernel along its input-channel axis.
            x = x * affined_style.view(N,iC,1,1)
            x = F.conv2d(x.view(1,N*iC,H,W), gain * modulated_weight.view(N*oC,iC,kH,kW),
                    padding=self.padding, stride=self.stride, groups=N).view(N,oC,H,W)
            if self.demod:
                # Use the same affine gain (1) as the modulation above so the
                # normalisation matches the kernel that was effectively applied,
                # exactly as in the non-web branch below.
                affined_style_2 = self.affine(style_2,gain=1) + 1
                modulated_weight = modulated_weight * affined_style_2.view(N,1,iC,1,1)
                demod_norm = 1 / torch.sqrt((modulated_weight * modulated_weight).sum([2,3,4])  + 1e-8)
                out = x * demod_norm.view(N, oC, 1, 1) + self.bias
                return out
            out = x + self.bias

        else:
            modulated_weight = self.weight_scaler *self.weight.view(1,oC,iC,kH,kW) * affined_style.view(N,1,iC,1,1)

            demodulated_weight = modulated_weight
            if self.demod:
                # (N, oC) inverse L2 norm of every per-sample output kernel.
                demod_norm = 1 / torch.sqrt((modulated_weight * modulated_weight).sum([2,3,4]) + 1e-8)
                demodulated_weight = demodulated_weight * demod_norm.view(N, oC, 1, 1, 1)

            # The view to (N, oC, H, W) relies on the convolution keeping the input's spatial size.
            out = F.conv2d(x.view(1,N*iC,H,W), gain * demodulated_weight.view(N*oC,iC,kH,kW),
                    padding=self.padding, stride=self.stride, groups=N).view(N,oC,H,W) + self.bias

        return out

def alternative_Upsample(image,input_size):
    """Double height and width by repeating pixels, using only ``view`` and ``cat``.

    Args:
        image: tensor viewable as ``input_size``.
        input_size: (batches, channels, h, w) of ``image``; one entry may be
            -1 because the values are only passed to ``view``.

    Returns:
        Tensor of shape (batches, channels, 2 * h, 2 * w) in which every
        source pixel fills a 2x2 patch.
    """
    n, c, rows, cols = input_size

    flat = image.view(n, c, rows * cols, 1)
    # Pair every pixel with a copy of itself, then fold the pairs into the width.
    widened = torch.cat([flat, flat], dim=3).view(n, c, rows, cols * 2)
    # Place a copy of each widened row right after it, then fold into the height.
    doubled = torch.cat([widened, widened], dim=3).view(n, c, rows * 2, cols * 2)
    return doubled


class block(nn.Module):
    """Generator stage: two modulated 3x3 convolutions with noise, plus a 1x1 skip path.

    Args:
        resolution: spatial size this stage runs at; also the size of the
            constant noise maps.
        in_channels, mid_channels, out_channels: channel widths of the two
            convolutions (in -> mid -> out); the skip maps in -> out.
        style_dimension: width of the style vector given to the convolutions.
    """
    def __init__(self,resolution,in_channels, mid_channels, out_channels,style_dimension = 512):
        super(block,self).__init__()
        self.conv_1 = modulated_conv2d(in_channels,mid_channels,kernel_size=3,stride=1,padding=1,style_dimension = style_dimension)
        # Fixed noise maps are registered as buffers so they follow Module.to()/cuda();
        # persistent=False leaves the state_dict keys unchanged.
        self.register_buffer('const_noise_1', torch.randn((1,1,resolution,resolution),requires_grad=False, device = device), persistent=False)
        self.noise_scalar_1 = nn.Parameter(torch.zeros(1))

        self.conv_2 = modulated_conv2d(mid_channels,out_channels,kernel_size=3,stride=1,padding=1,style_dimension = style_dimension)
        self.register_buffer('const_noise_2', torch.randn((1,1,resolution,resolution),requires_grad=False, device = device), persistent=False)
        self.noise_scalar_2 = nn.Parameter(torch.zeros(1))

        # Unmodulated 1x1 convolution used as the residual shortcut.
        self.skip = modulated_conv2d(in_channels,out_channels,kernel_size=1,stride=1,padding=0,mod = False)
        self.resolution = resolution

    def forward(self,x,style=None,web = False,style_2 = None):
        """Run the stage on ``x``.

        ``web`` and ``style_2`` are forwarded to the convolutions; the explicit
        ``shape`` argument is only consulted by them when ``web`` is True.
        Returns ``(main + skip) / sqrt(2)``.
        """
        t = x
        stage_shape = (1,-1,self.resolution,self.resolution)

        x = self.conv_1(x,style,shape = stage_shape,style_2 = style_2,web = web)
        x = F.leaky_relu(x,LReLU_alpha)
        # The learned scalar starts at zero, so the noise has no effect at initialisation.
        x = x + self.const_noise_1 * self.noise_scalar_1

        x = self.conv_2(x,style,shape = stage_shape,style_2 = style_2,web = web)
        x = F.leaky_relu(x,LReLU_alpha)
        x = x + self.const_noise_2 * self.noise_scalar_2

        x = (x + self.skip(t,shape = stage_shape, web = web, gain = 1)) * (1 / math.sqrt(2))
        return x

# Define model
class Generator(nn.Module):
    """Style-based image generator with one `block` and one to-RGB layer per resolution.

    Resolutions run 4, 8, ..., 512. ``forward`` evaluates the first ``stage``
    of them and sums the RGB outputs, upsampling the running image by 2 at
    every step.
    """
    def __init__(self):
        super().__init__()
        # widths[0] is the latent/style width, widths[1] the channel count of the
        # learned 4x4 constant; level k uses widths[2k+1 : 2k+4] as (in, mid, out).
        widths = [128,128,128,128,64,64,64,64,64,64,64
                 ,32,32,16,16, 8, 8, 4, 4]
        latent_width = widths[0]
        self.learning_const = nn.Parameter(torch.randn(1,widths[1],4,4))
        self.mapping_network = self.generate_mapping_network(latent_width,8)

        # Sub-modules are registered in the order block_4, to_rgb_4, block_8, to_rgb_8, ...
        for level in range(8):
            res = 2 ** (level + 2)
            c_in, c_mid, c_out = widths[2 * level + 1:2 * level + 4]
            setattr(self, 'block_{0}'.format(res), block(res,c_in,c_mid,c_out,latent_width))
            setattr(self, 'to_rgb_{0}'.format(res),
                    modulated_conv2d(c_out,3,kernel_size=1,stride=1,padding=0,demod = False,style_dimension = latent_width))

    def generate_mapping_network(self,dimension = 512,number_of_layer = 8):
        """Build a stack of ``number_of_layer`` (equalized_linear, LeakyReLU) pairs."""
        layers = nn.Sequential()
        for layer_idx in range(number_of_layer):
            layers.add_module('mapping_fc{0}'.format(layer_idx), equalized_linear(dimension,dimension,lr = 0.01))
            layers.add_module('mapping_lrelu{0}'.format(layer_idx), nn.LeakyReLU(LReLU_alpha))
        return layers

    def forward(self, z, stage = 1 ,alpha = 0, batches = 1,web = False):
        """Generate images from latent vectors.

        Args:
            z: latent vectors fed to the mapping network.
            stage: number of resolution levels to run (1 -> 4x4, 2 -> 8x8, ...).
            alpha: not used by this method.
            batches: how many times the learned constant is repeated along the
                batch axis when ``web`` is False.
            web: use the view/cat upsampling and the explicit-shape code path
                of the sub-modules; the learned constant is then used as is.

        Returns:
            The accumulated RGB output at the resolution of the last level run.
        """
        style = self.mapping_network(z)
        if web:
            # The web path maps z a second time and hands that copy to the layers as style_2.
            style_2 = self.mapping_network(z)
            x = self.learning_const
        else:
            style_2 = None
            x = self.learning_const.repeat(int(batches),1,1,1)

        for i in range(stage):
            res = 2 ** (i + 2)
            prev_shape = (1,-1,2 ** (i + 1),2 ** (i + 1))

            def grow(t):
                # 2x nearest-neighbour enlargement of the previous level's tensor.
                if web:
                    return alternative_Upsample(t,prev_shape)
                return F.interpolate(t,scale_factor=2, mode='nearest')

            if i != 0:
                x = grow(x)
            x = getattr(self, 'block_{0}'.format(res))(x,style=style,web = web,style_2 = style_2)

            rgb_layer = getattr(self, 'to_rgb_{0}'.format(res))
            if web:
                rgb = rgb_layer(x,style = style,shape = (1,-1,int(res),int(res)),web = True,style_2 = style_2,gain = np.sqrt(2))
            else:
                rgb = rgb_layer(x,style = style,web = False,style_2 = style_2,gain = np.sqrt(2))

            if i == 0:
                x_out = rgb
            else:
                # Later levels pass their RGB through a leaky ReLU before it is added.
                rgb = F.leaky_relu(rgb,LReLU_alpha)
                x_out = grow(x_out) + rgb

        return x_out

Using cuda device


In [5]:
class d_block(nn.Module):
    """Discriminator stage: two unmodulated 3x3 convolutions plus a 1x1 skip path.

    Args:
        in_channels, mid_channels, out_channels: channel widths
            (in -> mid -> out); the skip maps in -> out.
    """
    def __init__(self,in_channels, mid_channels, out_channels):
        super().__init__()
        # Settings shared by the two 3x3 convolutions.
        conv3x3 = dict(kernel_size=3, stride=1, padding=1, mod=False)
        self.conv_1 = modulated_conv2d(in_channels, mid_channels, **conv3x3)
        self.conv_2 = modulated_conv2d(mid_channels, out_channels, **conv3x3)
        self.skip = modulated_conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, mod=False)

    def forward(self,input_x):
        """Return ``(conv path + skip path) / sqrt(2)`` for ``input_x``."""
        hidden = F.leaky_relu(self.conv_1(input_x), LReLU_alpha)
        hidden = F.leaky_relu(self.conv_2(hidden), LReLU_alpha)
        shortcut = self.skip(input_x, gain = 1)
        return (hidden + shortcut) * (1 / np.sqrt(2))
  
# Define model
from torch.autograd import Variable
from torch import autograd
class Discriminator(nn.Module):
    """Image discriminator mirroring the generator's resolution levels.

    For every resolution 4, 8, ..., 512 there is a ``from_rgb_<res>`` 1x1
    convolution and a ``block_<res>`` residual stage. ``forward`` enters at the
    resolution selected by ``stage`` and works down to 4x4, then applies a
    small convolution + linear head that yields one score per sample.
    """
    def __init__(self):
        super(Discriminator, self).__init__()
        dimensions = [128,128,128,128,64,64,64,64,64,64,64
                     ,32,32,16,16, 8, 8, 4, 4]

        # Level k (resolution 2**(k+2)): from_rgb outputs dimensions[2k+2] channels and
        # the stage maps dimensions[2k+2] -> dimensions[2k+1] -> dimensions[2k].
        # Modules are registered in the order from_rgb_4, block_4, from_rgb_8, block_8, ...
        for level in range(8):
            res = 2 ** (level + 2)
            setattr(self, 'from_rgb_{0}'.format(res),
                    modulated_conv2d(3,dimensions[2 * level + 2],kernel_size=1,stride=1,padding=0,mod = False))
            setattr(self, 'block_{0}'.format(res),
                    d_block(dimensions[2 * level + 2],dimensions[2 * level + 1],dimensions[2 * level]))

        # The +1 input channel receives the minibatch standard-deviation map.
        self.final_conv_1 = modulated_conv2d(dimensions[0] + 1,dimensions[0],kernel_size=3,stride=1,padding=1,mod = False)
        self.final_conv_2 = modulated_conv2d(dimensions[0],dimensions[0],kernel_size=4,stride=1,padding=0,mod = False)
        self.linear_1 = equalized_linear(dimensions[0],dimensions[0])
        self.linear_2 = equalized_linear(dimensions[0],1)

    # Adapted from https://github.com/Zeleni9/pytorch-wgan/blob/master/models/wgan_gradient_penalty.py
    def calculate_gradient_penalty(self, real_images, fake_images,batch_size,stage,alpha):
        """Return the mean squared gradient norm of the scores w.r.t. the real images.

        Args:
            real_images: batch the penalty is evaluated on.
            fake_images: not used; the penalty is computed on real images only.
            batch_size, stage, alpha: forwarded to ``forward``.

        Returns:
            Scalar tensor: per-sample sum of squared input gradients, averaged
            over the batch. The graph is retained so it can be back-propagated.
        """
        # Fresh leaf that shares storage with the input, so gradients can be taken w.r.t. it.
        real = real_images.detach().requires_grad_(True)
        score_sum = self(real,stage,alpha,batch_size).sum()
        gradients = torch.autograd.grad(outputs=score_sum, inputs=real,
                                        grad_outputs=torch.ones_like(score_sum),
                                        create_graph=True, retain_graph=True)[0]

        grad_penalty = ((gradients ** 2).sum(dim=[1,2,3])).mean()
        return grad_penalty

    def forward(self, image, stage = 1 ,alpha = 0, batches = 1):
        """Score a batch of images.

        Args:
            image: input images; resized to ``2 ** (stage + 1)`` pixels per side.
            stage: number of resolution levels to run (1 -> 4x4, 2 -> 8x8, ...).
            alpha: not used by this method.
            batches: batch size used when broadcasting the minibatch-std map;
                must equal the batch dimension of ``image``.

        Returns:
            Tensor with one unnormalised score per sample.

        Raises:
            ValueError: if ``stage`` is smaller than 1.
        """
        if stage < 1:
            raise ValueError('stage must be >= 1, got {}'.format(stage))

        # Only the entry resolution ever consumes the image, so resize it once.
        top_res = int(2 ** (stage + 1))
        image = F.interpolate(image,size=top_res, mode='bilinear')
        x = F.leaky_relu(getattr(self, 'from_rgb_{0}'.format(top_res))(image,gain=math.sqrt(2)),LReLU_alpha)

        for i in range(stage):
            x = getattr(self, 'block_{0}'.format(2 ** (stage + 1 - i)))(x)
            if i != stage - 1:
                # Halve the resolution between stages (not after the 4x4 stage).
                x = F.interpolate(x,scale_factor=0.5, mode='bilinear')

        # Standard deviation over the batch and channel axes, one value per spatial
        # position, appended as an extra channel for every sample.
        minibatch_std = torch.std(x , dim=(0,1))
        x = torch.cat((x,minibatch_std.broadcast_to(batches,1,4,4)),dim = 1)
        x = self.final_conv_1(x)
        x = F.leaky_relu(x,LReLU_alpha)
        x = self.final_conv_2(x)
        x = F.leaky_relu(x,LReLU_alpha)
        x = torch.flatten(x,start_dim = 1)
        x = self.linear_1(x)
        x = F.leaky_relu(x,LReLU_alpha)
        return  self.linear_2(x,gain = 1)

In [6]:
import copy

class trainer():
    """Owns the generator, its averaged copy, the discriminator and the training loop.

    Args:
        learning_rate: Adam learning rate for both networks.
        stage: number of resolution levels trained (1 -> 4x4, 2 -> 8x8, ...).
        BATCH_SIZE: minibatch size.
        n_critic: discriminator updates per generator update.
        gp_lambda: weight of the gradient penalty in the discriminator loss.
        z_dimension: latent size. NOTE: the Generator in this file builds a
            128-wide mapping network, so callers need ``z_dimension = 128``.
        gs_beta: decay of the averaged generator ``gs``
            (``gs <- gs_beta * gs + (1 - gs_beta) * g``).
    """
    def __init__(self,learning_rate = 0.002, stage = 4,BATCH_SIZE = 64, n_critic = 1,gp_lambda = 10,z_dimension = 256,gs_beta = 0.999):
        self.stage = stage
        self.BATCH_SIZE = BATCH_SIZE
        self.n_critic = n_critic
        self.gp_lambda = gp_lambda
        self.g = Generator().to(device)
        # Averaged copy of the generator, used for the preview images.
        self.gs = copy.deepcopy(self.g)
        self.d = Discriminator().to(device)
        self.G_optimizer = torch.optim.Adam(self.g.parameters(), lr=learning_rate, betas=(0, 0.99))
        self.D_optimizer = torch.optim.Adam(self.d.parameters(), lr=learning_rate, betas=(0, 0.99))
        self.z_dimension = z_dimension
        self.gs_beta = gs_beta
        # Fixed latents so successive preview grids are comparable.
        self.noise = torch.randn((self.BATCH_SIZE,self.z_dimension),device = device)

    def _update_gs(self):
        """Move the averaged generator towards the live one: gs <- beta * gs + (1 - beta) * g."""
        with torch.no_grad():
            for gparam, gsparam in zip(self.g.parameters(), self.gs.parameters()):
                gsparam.data = self.gs_beta * gsparam.data + (1 - self.gs_beta) * gparam.data

    def train(self):
        """Train on CIFAR-10 indefinitely with mixed precision.

        Every 100 ticks the losses are printed, a preview grid from ``gs`` is
        saved through ``show_image`` and the whole trainer object is written
        with ``torch.save``. The loop never returns; interrupt it to stop.
        """
        torch.backends.cudnn.benchmark = True

        training_data = datasets.CIFAR10(
            root="data",
            train=True,
            download=True,
            transform=ToTensor(),
        )
        # To train on the downloaded image folder instead:
        # training_data = CustomImageDataset(img_dir="/content/dataset/")
        train_dataloader = DataLoader(training_data, batch_size=self.BATCH_SIZE, shuffle=True, drop_last=True, num_workers=2, pin_memory=True)
        tick = 0
        scaler_D = torch.cuda.amp.GradScaler()
        scaler_G = torch.cuda.amp.GradScaler()
        while True:
            for X in train_dataloader:
                # The CIFAR-10 loader yields (images, labels); keep the images
                # and map them from [0, 1] to [-1, 1].
                X = X[0]
                X = X * 2 - 1
                for _ in range(self.n_critic):
                    # Setting grads to None instead of calling zero_grad().
                    for param in self.d.parameters():
                        param.grad = None
                    with torch.cuda.amp.autocast():
                        z = torch.randn((self.BATCH_SIZE,self.z_dimension),device = device)
                        generated = self.g(z, stage = self.stage ,alpha = 0, batches = self.BATCH_SIZE)

                        X = X.to(device)
                        X = torch.nn.functional.interpolate(X,size=generated.size(2), mode='bilinear')

                        y_real = self.d(X, stage = self.stage ,alpha = 0, batches = self.BATCH_SIZE)
                        # .data cuts the graph so this step does not back-propagate into the generator.
                        y_fake = self.d(generated.data, stage = self.stage ,alpha = 0, batches = self.BATCH_SIZE)
                        d_loss = torch.mean(F.softplus(y_fake) + F.softplus(-y_real))
                        # Loss without the penalty term, kept only for logging.
                        pure_loss = d_loss.data
                        d_loss = d_loss + (self.d.calculate_gradient_penalty(X,generated,self.BATCH_SIZE,self.stage,alpha = 0) * self.gp_lambda)

                    scaler_D.scale(d_loss).backward()
                    scaler_D.step(self.D_optimizer)
                    scaler_D.update()

                for param in self.g.parameters():
                    param.grad = None
                with torch.cuda.amp.autocast():
                    z = torch.randn((self.BATCH_SIZE,self.z_dimension),device = device)
                    generated = self.g(z, stage = self.stage ,alpha = 0, batches = self.BATCH_SIZE)
                    y_fake = self.d(generated, stage = self.stage ,alpha = 0, batches = self.BATCH_SIZE)
                    g_loss = torch.mean(F.softplus(-y_fake))

                scaler_G.scale(g_loss).backward()
                scaler_G.step(self.G_optimizer)
                scaler_G.update()

                self._update_gs()

                if tick % 100 == 0:
                    print('stage:{},g_loss:{}, d_loss:{}, Pure discriminator Loss:{}'.format(self.stage,g_loss,d_loss,pure_loss))
                    # The preview needs no gradients; skip building an autograd graph.
                    with torch.no_grad():
                        generated = self.gs(self.noise, stage = self.stage ,alpha = 0, batches = self.BATCH_SIZE)
                        generated = F.interpolate(generated,size=512, mode='nearest')
                    show_image(generated,str(tick))
                    torch.save(self, '/content/drive/MyDrive/StyleGAN2/model128-cifar3.pth')
                tick += 1

In [None]:
# Resume from the saved trainer if possible, otherwise start a fresh one.
# NOTE: torch.load unpickles arbitrary Python objects; only load checkpoints you created yourself.
try:
    train = torch.load('/content/drive/MyDrive/StyleGAN2/model128-cifar3.pth')
    print('loaded')
except Exception as load_error:
    # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
    train = trainer(learning_rate = 0.001, stage = 2,BATCH_SIZE = 64, n_critic = 1,gp_lambda = 5,z_dimension = 128,gs_beta = 0.9)
    print('load failed')
    print(repr(load_error))

train.stage = 4
#stage:resolution
#1:4, 2:8, 3:16, 4:32, 5:64, 6:128, 7:256, 8:512
train.train()

loaded
Files already downloaded and verified


  "See the documentation of nn.Upsample for details.".format(mode)
  "The default behavior for interpolate/upsample with float scale_factor changed "


stage:4,g_loss:0.7195033431053162, d_loss:1.2388925552368164, Pure discriminator Loss:1.125885248184204
stage:4,g_loss:0.6785848140716553, d_loss:1.2959833145141602, Pure discriminator Loss:1.2105457782745361
stage:4,g_loss:0.6845486164093018, d_loss:1.2713526487350464, Pure discriminator Loss:1.19166898727417
stage:4,g_loss:0.5029839277267456, d_loss:1.3515263795852661, Pure discriminator Loss:1.273514747619629
stage:4,g_loss:0.8518887758255005, d_loss:1.3078516721725464, Pure discriminator Loss:1.193267822265625
stage:4,g_loss:0.6076936721801758, d_loss:1.3535698652267456, Pure discriminator Loss:1.293604850769043
stage:4,g_loss:0.8822863101959229, d_loss:1.3156284093856812, Pure discriminator Loss:1.256404161453247
stage:4,g_loss:0.7782460451126099, d_loss:1.309186577796936, Pure discriminator Loss:1.2575647830963135
stage:4,g_loss:0.9905093908309937, d_loss:1.2829245328903198, Pure discriminator Loss:1.1828457117080688
stage:4,g_loss:0.9138436317443848, d_loss:1.3327162265777588, P

In [None]:
#make ONNX model
# NOTE: torch.load unpickles arbitrary Python objects; only load checkpoints you created yourself.
try:
    train = torch.load('/content/drive/MyDrive/StyleGAN2/model128-cifar3.pth')
except Exception as load_error:
    # The Generator's mapping network is 128 wide, so the fallback trainer must
    # use z_dimension = 128 for dummy_input to fit.
    train = trainer(z_dimension = 128)
    print('load failed')
    print(repr(load_error))

dummy_input = torch.randn(1, train.z_dimension, device=device)
# Freeze the learned constant and the noise maps of the 4x4..64x64 blocks before tracing.
train.gs.learning_const.requires_grad = False
for i in range(5):
    getattr(train.gs, 'block_{0}'.format(2 ** (i + 2))).const_noise_1.requires_grad = False
    getattr(train.gs, 'block_{0}'.format(2 ** (i + 2))).const_noise_2.requires_grad = False
# Generator.forward arguments (stage, alpha, batches, web) are passed as tensors.
stage = torch.tensor(4, dtype=torch.int)
alpha = torch.tensor(0, dtype=torch.int)
batches = torch.tensor(1, dtype=torch.int)
web = torch.tensor(True, dtype=torch.bool)
torch.onnx.export(train.gs, (dummy_input,stage,alpha,batches,web), 'generator.onnx', opset_version= 9)

from google.colab import files
files.download('generator.onnx')


In [None]:

#testing
# NOTE(review): this checkpoint name differs from 'model128-cifar3.pth' used by the
# training and export cells — confirm the file exists.
train = torch.load('/content/drive/MyDrive/StyleGAN2/modessl.pth')
#train = trainer()
# Run the live generator on an all-zero latent through the web code path at 4x4 (stage 1).
z = torch.zeros(1, train.z_dimension, device=device)
generated= train.g(z,stage = 1,web=True)
# show_image takes (result, name), so the call below would also need a file name.
#show_image(generated)
print(generated)