In [None]:
import dataset
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt
import yuGANoh_with_fc_and_disc
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
import zipfile

In [None]:
# One-time dataset extraction, kept disabled. It was previously wrapped in a
# bare triple-quoted string, which is an expression (echoed as the cell's
# output) rather than a comment. Uncomment to unpack card.zip into data/:
# with zipfile.ZipFile('card.zip', 'r') as zip_ref:
#     zip_ref.extractall('data')

In [None]:
#initialize transform
# Every image is resized to a fixed 428x321 and then converted to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(428, 321)),
        transforms.ToTensor(),
    ]
)

In [None]:
#initialize dataloader
root_dir = 'data/card'
batch_size = 45
# Project-local dataset; each item is passed through `transform`.
ygoDset = dataset.ygoCards(transform=transform, root_dir=root_dir)
# Shuffled mini-batches, loaded by 4 worker processes.
ygoLoader = DataLoader(
    dataset=ygoDset,
    batch_size=batch_size,
    num_workers=4,
    shuffle=True,
)

In [None]:
#preview data
# Pull a single batch from the loader and show its first element as an image.
trans = transforms.ToPILImage()
batch = next(iter(ygoLoader))
plt.imshow(trans(batch[0]))

In [None]:
#network params

# --- architecture ---
latent_size = 100                       # channels of the generator's random input
num_gan_features = 64
num_disc_features = num_gan_features    # discriminator width mirrors the generator
num_hidden_features = 256
similarity_features = 50

# --- optimisation ---
lr = 1e-4
beta1 = 0.5
num_epochs = 100

# --- label smoothing ---
# "Real" targets are sampled as 1 - bound * U[0, 1), i.e. from (lower_bound, 1].
lower_bound = 0.8
bound = 1 - lower_bound

In [None]:
#initialize network
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so that the `.to(device)` calls below do not fail on machines without a
# GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
gen = yuGANoh_with_fc_and_disc.Generator(latent_size,num_gan_features).to(device)
# NOTE(review): the discriminator is constructed with `batch_size`; confirm it
# copes with a smaller final batch from the loader.
disc = yuGANoh_with_fc_and_disc.Discriminator(num_disc_features,num_hidden_features,similarity_features,batch_size).to(device)
# Apply the project's weight initialiser to every sub-module of each network.
gen.apply(yuGANoh_with_fc_and_disc.init_weights)
print(gen)
disc.apply(yuGANoh_with_fc_and_disc.init_weights)
print(disc)

In [None]:
#initialize loss and optimizer
criterion = nn.BCELoss()
# Fixed latent vector; the training loop reuses it for its preview image.
noise = torch.randn(1, latent_size, 1, 1).to(device)
# The discriminator starts with twice the generator's learning rate.
optimizerD = optim.Adam(
    params=disc.parameters(),
    lr=lr * 2,
    betas=(beta1, 0.999),
    weight_decay=1e-4,
)
optimizerG = optim.Adam(
    params=gen.parameters(),
    lr=lr,
    betas=(beta1, 0.999),
    weight_decay=1e-4,
)

In [None]:
def feature_loss(real_features,fake_features):
    """Feature-matching loss between two sets of discriminator features.

    Both inputs are averaged along dimension 0 (presumably the batch
    dimension -- confirm against the discriminator), and the loss is the sum
    of squared differences between the two mean vectors.

    Returns a 0-dim tensor.
    """
    mean_gap = real_features.mean(dim=0) - fake_features.mean(dim=0)
    return (mean_gap * mean_gap).sum()

In [None]:
#gen.load_state_dict(torch.load('gen22.pt'))
#disc.load_state_dict(torch.load('disc22.pt'))

In [None]:
images = []
gen_losses = []
disc_losses = []
# Set once a logged loss is exactly 0.0 so that BOTH loops are left; a bare
# `break` would only leave the batch loop and training would resume on the
# next epoch.
stop_training = False

print("Begin training")
for epoch in range(num_epochs):
    if epoch > 0 and epoch%25 == 0:
        # Every 25 epochs the learning rate is doubled and the discriminator
        # optimizer is rebuilt with it.
        # NOTE(review): this rebinds the notebook-level `lr`, so re-running
        # this cell without re-running the parameter cell starts from the
        # already-doubled value; rebuilding Adam also discards its running
        # moment estimates.
        lr*=2
        optimizerD = optim.Adam(disc.parameters(),lr=lr,betas=(beta1,0.999), weight_decay = .0001)
    for i,img in enumerate(ygoLoader,0):
        # ---- discriminator: real batch ----
        disc.zero_grad()
        img = img.to(device)
        if epoch < 10:
            # Gaussian noise is added to the real images for the first 10 epochs.
            img = img + 0.1*torch.randn_like(img)
        b_size = img.shape[0]
        features_real, output = disc(img)
        output = output.squeeze()
        # Smoothed "real" targets drawn from (1 - bound, 1].
        label = 1 - bound*torch.rand_like(output)
        disc_loss_real = criterion(output,label)
        disc_loss_real.backward()

        # ---- discriminator: fake batch ----
        inp = torch.randn(b_size,latent_size,1,1).to(device)
        gen_out = gen(inp)
        features_fake, fake_out = disc(gen_out)
        fake_out = fake_out.squeeze()
        label_fake = torch.zeros_like(label)
        disc_loss_fake = criterion(fake_out,label_fake)
        # The graph is kept alive because the generator gradients below are
        # taken through this same forward pass.
        disc_loss_fake.backward(retain_graph = True)
        total_loss = disc_loss_real + disc_loss_fake

        # ---- generator: feature matching ----
        # The generator gradients must be computed BEFORE optimizerD.step():
        # the step updates the discriminator weights in place, and
        # backpropagating afterwards through a graph that saved the old
        # weights raises a RuntimeError in current PyTorch.
        # torch.autograd.grad returns gradients for the generator parameters
        # only, so the discriminator's .grad buffers are left untouched.
        # The real features are a constant target here, hence the detach.
        model_loss = feature_loss(features_real.detach(),features_fake)
        gen_params = [p for p in gen.parameters() if p.requires_grad]
        gen_grads = torch.autograd.grad(model_loss, gen_params, allow_unused=True)

        optimizerD.step()

        # Overwrite (rather than accumulate into) the generator gradients;
        # this also discards what disc_loss_fake.backward() left on them.
        for p, g in zip(gen_params, gen_grads):
            p.grad = g
        optimizerG.step()

        #check_individual_norm(gen)
        #check_individual_norm(disc)

        if i%50 == 0:
            disc_losses.append(total_loss.item())
            gen_losses.append(model_loss.item())
            print('Epoch: '+str(epoch) + ' iter: ' + str(i) + ' lossG: ' + str(gen_losses[-1]) + ' lossD: ' + str(disc_losses[-1]))
            # Preview from the fixed latent vector; no graph is needed here.
            with torch.no_grad():
                out_img = gen(noise)
            #out_img = undo_normalize(out_img)
            trans = transforms.ToPILImage()
            plt.imshow(trans(out_img[0].cpu()))
            if gen_losses[-1] == 0.0 or disc_losses[-1] == 0.0:
                stop_training = True
                break
            print('discriminator gradient L2 Norm')
            #check_norm(disc)
            print('generator gradient L2 Norm')
            #check_norm(gen)
    if stop_training:
        break

In [None]:
def check_norm(model):
    """Print the overall L2 norm of the gradients stored on ``model``.

    The squared L2 norms of every parameter's ``.grad`` are summed and the
    square root of that sum is printed. Parameters without a gradient are
    ignored. Nothing is returned.
    """
    squared_total = 0
    for param in model.parameters():
        if param.grad is not None:
            squared_total += param.grad.data.norm(2).item() ** 2
    print(squared_total ** 0.5)
    
def normalize_tensors(inp):
    """Shift by -0.5 and scale by 2, mapping [0, 1] onto [-1, 1]."""
    centered = inp - 0.5
    return centered * 2

def undo_normalize(inp):
    """Halve and shift by +0.5, mapping [-1, 1] back onto [0, 1]."""
    halved = inp / 2
    return halved + 0.5

def check_individual_norm(model):
    """Print the squared L2 gradient norm of each parameter of ``model``.

    One line is printed per parameter, prefixed with the name reported by
    ``model.named_parameters()``. Parameters without a gradient are reported
    as ``<name> is None``.

    The previous implementation read ``p.name``, which is ``None`` for module
    parameters, so every line was labelled "None".
    """
    for name, p in model.named_parameters():
        if p.grad is None:
            print(name + ' is None')
            continue
        # norm(2) ** 2: the value printed is the SQUARED L2 norm.
        print(name + ': ' + str(p.grad.data.norm(2).item() ** 2))

In [None]:
# Draw a 3x3 grid of freshly generated samples.
plt.figure(figsize=(15,15))
trans = transforms.ToPILImage()
# Inference only, so no autograd graph is built.
with torch.no_grad():
    for i in range(9):
        # A separate name is used so the notebook-level `noise` (the fixed
        # preview vector used by the training cell) is not overwritten.
        sample_noise = torch.randn(1,latent_size,1,1).to(device)
        out_img = gen(sample_noise)
        #out_img = undo_normalize(out_img)
        plt.subplot(3,3,i+1)
        plt.imshow(trans(out_img[0].cpu()))

In [None]:
# Loss curves recorded by the training loop (one point per logged iteration).
plt.plot(disc_losses, label='disc_losses')
plt.plot(gen_losses, label='gen_losses')
plt.legend()
plt.show()

In [None]:
# Persist the weights of both networks.
torch.save(obj=gen.state_dict(), f='gen22_new.pt')
torch.save(obj=disc.state_dict(), f='disc22_new.pt')

In [None]:
# Persist the logged discriminator losses.
torch.save(obj=disc_losses, f='disc22_loss_weight_new-dec_2.pt')

In [None]:
# Persist the logged generator losses.
torch.save(obj=gen_losses, f='gen22_loss_weight_new-dec_2.pt')

In [None]:
# Scratch: shapes (6, 3, 4) and (4, 6); the cell's value is A transposed.
E = torch.randn(6, 3, 4)
A = torch.randn(4, 6)
A.t()

In [None]:
# Scratch: pairwise exp(-L1 distance) between the 4 entries of a small
# (4, 2, 2) tensor, summed over the "other entry" axis.
A = torch.Tensor([
    [[1, 1], [2, 2]],
    [[3, 3], [4, 4]],
    [[2, 1], [2, 4]],
    [[11, 3], [4, 3]],
])
A = A.transpose(0, 2)                    # (4, 2, 2) -> (2, 2, 4)
print(A.shape)
A = A.repeat(A.shape[2], 1, 1, 1)        # tiled to (4, 2, 2, 4)
print(A.shape)
B = A.transpose(0, -1)                   # swap first and last axes
print(B.shape)
C = (B - A).abs().sum(dim=2).neg().exp() # exp(-L1 distance), (4, 2, 4)
print(C.shape)
D = C.sum(dim=2)                         # (4, 2)
print(D.shape)
print(D[0])
print(D[1])

In [None]:
# Scratch: row sums of a random 3x2 matrix (the cell's value).
A = torch.randn(3, 2)
print(A.shape)
print(A)
A.sum(dim=1)