In [None]:
import torch 
from utils import weights_init
import model

import torch.nn as nn
import torchvision

import pickle
import csv
from tqdm import tqdm

def _set_requires_grad(module, flag):
    """Set ``requires_grad`` to ``flag`` on every parameter of ``module``."""
    for param in module.parameters():
        param.requires_grad = flag


def train_wgan(train_loader,
               noise_dim=100,
               batch_size=4,
               device=torch.device("cpu"),
               lr=0.0002,
               betas=(0.5, 0.999),
               epochs=1,
               clipping_thresh=0.01,
               name="wgan",
               n_critic=5,
               out_dir="./solutions"):
    """Train a weight-clipped WGAN (critic ``model.CriticCG1``, generator ``model.Generator``).

    Each generator update is preceded by ``n_critic`` critic updates. The count is 100
    during the first 25 generator updates and on every 500th one. Critic weights are
    clamped to ``[-clipping_thresh, clipping_thresh]`` before every critic step. Every
    batch drawn from ``train_loader`` is consumed by exactly one critic step.

    Args:
        train_loader: sized iterable of real image batches. Each item is moved to
            ``device`` directly, so it must be a tensor (not an ``(x, y)`` tuple).
        noise_dim: latent size; noise is drawn with shape ``(N, noise_dim, 1, 1)``.
        batch_size: number of fixed-noise samples in each progress image grid.
        device: device holding both networks and all tensors.
        lr, betas: Adam hyper-parameters shared by both optimizers.
        epochs: number of passes over ``train_loader``.
        clipping_thresh: absolute bound used for critic weight clipping.
        name: prefix for every file written to ``out_dir``.
        n_critic: regular number of critic steps per generator step.
        out_dir: output directory; it is created if missing.

    Side effects:
        - Appends ``[epoch, epochs, i, n_batches, critic_loss, gen_loss, D(real)-D(fake)]``
          rows to ``{out_dir}/{name}.csv`` whenever the in-epoch batch counter is a
          multiple of 100.
        - Saves ``gen_{name}.pt`` / ``critic_{name}.pt`` whenever the corresponding
          logged loss reaches a new minimum.
        - At the end, pickles the per-generator-step loss histories and the image grids
          to ``{name}_C_loss.pkl``, ``{name}_G_loss.pkl`` and ``{name}_img_list.pkl``.

    ``critic_loss`` is ``D(fake) - D(real)`` and ``gen_loss`` is ``-D(G(z))``; these are
    the quantities the two optimizers actually minimise.
    """
    import os  # local import: this notebook cell does not import os at the top

    os.makedirs(out_dir, exist_ok=True)

    netC = model.CriticCG1().to(device)
    netG = model.Generator().to(device)
    netC.apply(weights_init)
    netG.apply(weights_init)

    fixed_noise = torch.randn(batch_size, noise_dim, 1, 1, device=device)
    # Gradients seeded into backward() instead of real/fake labels:
    # the critic minimises D(fake) - D(real); the generator minimises -D(G(z)).
    one = torch.tensor([1.0], device=device)
    minus_one = torch.tensor([-1.0], device=device)

    optimizerC = torch.optim.Adam(netC.parameters(), lr=lr, betas=betas)
    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr, betas=betas)

    C_losses = []
    G_losses = []
    img_list = []
    iters = 0
    gen_iterations = 0
    best_C_loss = float("inf")
    best_G_loss = float("inf")

    for epoch in range(epochs):
        num_batches = len(train_loader)
        data_iterator = iter(train_loader)
        i = 0  # batches consumed (== critic steps taken) in this epoch

        while i < num_batches:
            # ---------------- critic phase ----------------
            _set_requires_grad(netC, True)

            # Boosted schedule at the start and periodically; otherwise the regular
            # n_critic. Recomputed every round so the boost does not stick.
            if gen_iterations < 25 or gen_iterations % 500 == 0:
                critic_iters = 100
            else:
                critic_iters = n_critic

            for _ in tqdm(range(critic_iters), desc=f"epoch {epoch} critic", leave=False):
                if i >= num_batches:
                    break

                # Weight clipping enforces the Lipschitz constraint of WGAN.
                with torch.no_grad():
                    for p in netC.parameters():
                        p.clamp_(-clipping_thresh, clipping_thresh)

                netC.zero_grad()
                real_images = next(data_iterator).to(device)
                b_size = real_images.size(0)

                critic_for_real = netC(real_images)
                critic_for_real.backward(minus_one)

                noise = torch.randn(b_size, noise_dim, 1, 1, device=device)
                # The generator is not updated here, so don't build its autograd graph.
                with torch.no_grad():
                    fake = netG(noise)
                critic_for_fake = netC(fake)
                critic_for_fake.backward(one)

                # Detached: these are only logged and must not keep the graph alive.
                wasserstein_distance = (critic_for_real - critic_for_fake).detach()
                critic_loss = (critic_for_fake - critic_for_real).detach()

                optimizerC.step()
                i += 1

            # ---------------- generator phase ----------------
            # Freeze the critic so this backward pass only produces generator grads.
            _set_requires_grad(netC, False)
            netG.zero_grad()
            noise = torch.randn(b_size, noise_dim, 1, 1, device=device)
            critic_for_fake = netC(netG(noise))
            critic_for_fake.backward(minus_one)
            optimizerG.step()
            gen_iterations += 1
            # Log the objective the generator actually minimises: -D(G(z)).
            gen_loss = (-critic_for_fake).detach()

            c_loss = critic_loss.item()
            g_loss = gen_loss.item()

            if i % 100 == 0:
                save_vals = [epoch, epochs, i, num_batches,
                             c_loss, g_loss, wasserstein_distance.item()]
                with open(os.path.join(out_dir, f"{name}.csv"), "a", newline="") as csv_file:
                    csv.writer(csv_file).writerow(save_vals)

                # Checkpoint whenever a logged loss reaches a new minimum.
                if g_loss < best_G_loss:
                    best_G_loss = g_loss
                    torch.save(netG.state_dict(), os.path.join(out_dir, f"gen_{name}.pt"))
                if c_loss < best_C_loss:
                    best_C_loss = c_loss
                    torch.save(netC.state_dict(), os.path.join(out_dir, f"critic_{name}.pt"))

            C_losses.append(c_loss)
            G_losses.append(g_loss)

            # Snapshot the fixed-noise samples every 500 generator steps and once after
            # the final step (i equals num_batches once the last batch is consumed).
            is_last_step = epoch == epochs - 1 and i >= num_batches
            if iters % 500 == 0 or is_last_step:
                with torch.no_grad():
                    fake = netG(fixed_noise).detach().cpu()
                img_list.append(torchvision.utils.make_grid(fake, padding=2, normalize=True))
            iters += 1

    for suffix, payload in (("C_loss", C_losses), ("G_loss", G_losses), ("img_list", img_list)):
        with open(os.path.join(out_dir, f"{name}_{suffix}.pkl"), "wb") as out_file:
            pickle.dump(payload, out_file)

import os

import torch
from torch.utils.data import DataLoader

from dataloader import MyDataset

# Training-image directory. Set the PIZZA_DATA_DIR environment variable to override
# the machine-specific default instead of editing this cell.
data_path = os.environ.get("PIZZA_DATA_DIR", '/home/akshita/Documents/data/pizzas')
# Single source of truth for the batch size. It is passed to train_wgan as well, so
# the fixed-noise snapshot grids use the same size as the loader batches.
BATCH_SIZE = 4

train_dataset = MyDataset(data_path)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
train_wgan(train_loader, batch_size=BATCH_SIZE, epochs=1, name="wgan", device=device)

In [None]:
import torchsummary
from model import CriticCG1, Generator
import torch

# Build an untrained critic so its architecture can be inspected interactively.
netD = CriticCG1()

# Optional inspection helpers (input shapes taken from the original notes):
#   torchsummary.summary(Generator(), (100, 1, 1))
#   torchsummary.summary(netD, (3, 64, 64))
# Forward-pass smoke test:
#   out = netD(torch.randn(1, 3, 64, 64))

In [None]:
import pickle
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
# Only needed for the GIF export below:
# import torchvision.transforms.functional as tvtF
# import imageio

# Loss-curve plotting, disabled until the loss pickles are wanted:
# with open('./solutions/gan_D_loss.pkl', 'rb') as f:
#     d_data = pickle.load(f)
# with open('./solutions/gan_G_loss.pkl', 'rb') as f:
#     g_data = pickle.load(f)
# plt.plot(d_data, label='Discriminator')
# plt.plot(g_data, label='Generator')
# plt.legend()

# Generated samples pickled by an earlier export step.
# NOTE: pickle.load can execute arbitrary code, so only open files you produced yourself.
with open('./solutions/fake_data.pkl', 'rb') as fake_file:
    fake_images = pickle.load(fake_file)
    

In [None]:
# with open('./solutions/gan_img_list.pkl', 'rb') as f:
#     img_data = pickle.load(f)
# images = []           
# for imgobj in img_data:  
#     img = tvtF.to_pil_image(imgobj)  
#     images.append(img) 
# imageio.mimsave("./solutions/generation_animation.gif", images, fps=5)

In [None]:
from dataloader import MyDataset
import model
import torch
from pytorch_fid.fid_score import (
    calculate_activation_statistics,
    calculate_frechet_distance,
)
from pytorch_fid.inception import InceptionV3

# Inputs to the FID comparison: generated samples vs. the dataset's 'eval' split.
fake_path = '/home/akshita/Documents/data/fake_pizzas'
eval_path = '/home/akshita/Documents/data/pizzas'

val_dataset = MyDataset(data_path=eval_path, split='eval')
eval_image_files = val_dataset.file_list
# Show how many evaluation images were found.
len(eval_image_files)

# device = torch.device('cuda' if )
# dims = 2048
# block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
# model = InceptionV3([block_idx]).to(device)

# if not os.path.exists(mac_path):
#     os.makedirs(mac_path)
# len(val_dataset)

# Load the model
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# print(device)
# generator = model.Generator()
# generator.load_state_dict(torch.load('./solutions/gen_gan.pt', map_location=device))
# generator.eval()
# for i in range(1000):
#     noise = torch.randn(1, 100, 1, 1, device=device)
#     fake = generator(noise)
#     fake = fake.detach().cpu().numpy()
#     # print(fake.shape)
#     fake = np.transpose(fake, (0, 2, 3, 1))
#     # print(fake.shape)
#     fake = ((fake + 1) * 127.5).astype(np.uint8)
#     fake = fake[0]
#     fake = np.expand_dims(fake, axis=0)
#     if i == 0:
#         fake_data = fake
#     else:
#         fake_data = np.concatenate((fake_data, fake), axis=0)

# with open('./solutions/fake_data.pkl', 'wb') as f:
#     pickle.dump(fake_data, f)


# for i, img in enumerate(fake_images):
#     cv2.imwrite(os.path.join(fake_path, f"{i}.jpg"), img)



In [None]:
# for i, img in enumerate(fake_images):
#     plt.figure()
#     plt.imshow(img)
#     plt.axis('off')
#     plt.savefig(os.path.join(fake_path, f"{i}.jpg"), bbox_inches='tight', pad_inches=0)