In [None]:
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as Tfs
import matplotlib.pyplot as plt
import torchvision.models as models
import numpy as np
from PIL import Image

In [None]:
import wandb
# Start a Weights & Biases run in the 'vae-faces' project; train() reports its
# losses and sample images to it through wandb.log.
wandb.init(project='vae-faces')

In [None]:
from dataset.facesDataset import DataSet
from model.VAE import VAE
from ELBOLOSS import ELBOLoss

In [None]:
# Build the dataset and preview one sample.
ds = DataSet()
img, label = ds[-14]  # each item unpacks into an (image, label) pair
# permute(1,2,0) moves the first axis last for imshow -- assumes img is a
# channel-first (C, H, W) tensor; TODO confirm against DataSet.
plt.imshow(img.permute(1,2,0))

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# The helper functions and train() below read this global.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
def print_gpu_mem(target_device=None):
    """Print allocated vs. total GPU memory, in GiB, for a CUDA device.

    Args:
        target_device: device (``torch.device`` or string) to report on.
            Defaults to the notebook-level ``device``.

    When the device is not a CUDA device (or CUDA is unavailable) a short
    notice is printed instead of querying the CUDA runtime, which would raise.
    """
    if target_device is None:
        target_device = device
    target_device = torch.device(target_device)

    # torch.cuda.get_device_properties needs an initialised CUDA runtime, so
    # bail out early on CPU-only setups instead of crashing the caller.
    if target_device.type != 'cuda' or not torch.cuda.is_available():
        print('GPU memory: n/a on', target_device)
        return

    gib = 1024 ** 3
    used = torch.cuda.memory_allocated(target_device) / gib
    total = torch.cuda.get_device_properties(target_device).total_memory / gib
    print(f'{used:.2f}', '/', f'{total:.2f}', 'GBs')

In [None]:
def print_train_update(xb, preds):
    """Plot a batch of images next to their reconstructions and return the chart.

    The images of ``xb`` are stacked top-to-bottom into one column, the
    matching entries of ``preds`` into a second column, and the two columns
    are placed side by side (``xb`` on the left, ``preds`` on the right).

    Args:
        xb: 4-D tensor indexed as (batch, channel, height, width).
        preds: 4-D tensor with at least as many entries as ``xb``; each entry
            is reshaped to the per-image shape of ``xb``.

    Returns:
        Tensor of shape (channel, batch * height, 2 * width).
    """
    n, c, h, w = xb.shape
    # (n, c, h, w) -> (c, n, h, w) -> (c, n*h, w): a single reshape replaces
    # growing the column with one torch.cat per image.
    orig = xb.permute(1, 0, 2, 3).reshape(c, n * h, w)
    guess = preds[:n].reshape(n, c, h, w).permute(1, 0, 2, 3).reshape(c, n * h, w)
    chart = torch.cat((orig, guess), dim=2)
    # detach() only for plotting: matplotlib converts through NumPy, which
    # rejects tensors that require grad. The returned chart is left untouched.
    plt.imshow(chart.detach().permute(1, 2, 0).to('cpu'))
    plt.show()
    return chart

In [None]:
def unroll_batch(batch):
    """Tile a batch of images into a single, roughly square image grid.

    The grid has ``int(sqrt(n))`` rows and ``n // rows`` columns and is filled
    row by row in batch order. Images that do not fit into the grid (when
    ``n`` is not a multiple of the row count) are left out.

    Args:
        batch: 4-D tensor indexed as (batch, channel, height, width).

    Returns:
        Tensor of shape (channel, rows * height, cols * width).

    Raises:
        ValueError: if ``batch`` contains no images.
    """
    n, c, h, w = batch.shape
    if n == 0:
        raise ValueError('unroll_batch needs at least one image, got an empty batch')

    num_rows = int(n ** 0.5)
    num_cols = n // num_rows

    # (rows*cols, c, h, w) -> (rows, cols, c, h, w) -> (c, rows, h, cols, w),
    # then merge (rows, h) and (cols, w); replaces the repeated torch.cat calls.
    grid = batch[:num_rows * num_cols].reshape(num_rows, num_cols, c, h, w)
    return grid.permute(2, 0, 3, 1, 4).reshape(c, num_rows * h, num_cols * w)
        

In [None]:
# Sanity check: 25 images of shape (3, 12, 12) tile into a 5x5 grid, so the
# expected result shape is (3, 60, 60).
unroll_batch(torch.ones(25,3,12,12)).shape

In [None]:
def _to_pil(chw):
    """Convert a channel-first float image tensor into an 8-bit PIL image."""
    hwc = chw.detach().to('cpu').permute(1, 2, 0)
    # Clamp before scaling: a value outside [0, 1] would fall outside 0..255
    # and the uint8 cast would then yield corrupted pixel values.
    return Image.fromarray((hwc.clamp(0, 1) * 255).numpy().astype(np.uint8))


def train(model, lr, epochs, batch_size, tds,valds, beta):
    """Train ``model`` with the ELBO loss and report progress to wandb.

    Each epoch runs one pass over ``tds`` (Adam, gradient values clipped to
    0.1), then one pass over ``valds`` without gradients. On the last
    validation batch it plots and logs reconstructions of the first five
    images and a grid of 25 images decoded from random latent vectors. The
    mean validation loss drives a ReduceLROnPlateau scheduler.

    Args:
        model: module whose forward pass returns ``(preds, z, logvar, mean,
            std)`` and which exposes ``decoder`` and ``z_d``.
        lr: learning rate for Adam.
        epochs: number of epochs to run.
        batch_size: batch size for both data loaders.
        tds: training dataset yielding ``(image, label)`` pairs.
        valds: validation dataset yielding ``(image, label)`` pairs.
        beta: forwarded to ``ELBOLoss(beta=...)``.

    Returns:
        None. Results are printed and logged to wandb.
    """
    model.to(device)
    tdl = torch.utils.data.DataLoader(tds, batch_size=batch_size, shuffle=True, drop_last=True)
    vdl = torch.utils.data.DataLoader(valds, batch_size=batch_size, shuffle=False)
    loss_fn = ELBOLoss(beta=beta)
    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    # NOTE(review): `verbose` is deprecated in recent PyTorch releases --
    # confirm against the installed version.
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, factor=0.1, patience=3, verbose=True)
    last_val_batch = len(vdl) - 1

    for epoch in range(epochs):
        losses = []
        val_losses = []

        model.train()
        for xb, _ in tdl:
            xb = xb.to(device)
            preds, z, logvar, mean, std = model(xb)

            loss = loss_fn(xb, z, logvar, mean, std, preds)
            optim.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), 0.1)
            optim.step()
            losses.append(loss.item())
            # Drop every reference to this step's tensors (std included) so
            # the cache flush below can actually release them.
            del xb, preds, z, logvar, mean, std, loss
            torch.cuda.empty_cache()

        model.eval()
        with torch.no_grad():
            for idx, (vxb, _) in enumerate(vdl):
                vxb = vxb.to(device)
                vpreds, vz, vlogvar, vmean, vstd = model(vxb)
                vloss = loss_fn(vxb, vz, vlogvar, vmean, vstd, vpreds)

                if idx == last_val_batch:
                    # Plot and log reconstructions of the first five images.
                    chart = print_train_update(vxb[0:5], vpreds[0:5])
                    wandb.log({'recon_images': wandb.Image(_to_pil(chart))})
                    # Decode 25 random latent vectors into new images.
                    generated = model.decoder(torch.randn(25, model.z_d).to(device))
                    generated = unroll_batch(generated)
                    wandb.log({'generated_images': wandb.Image(_to_pil(generated))})

                val_losses.append(vloss.item())
                del vxb, vpreds, vz, vlogvar, vmean, vstd, vloss

        sched.step(np.mean(val_losses))
        print('Epoch', epoch+1, 'Loss', f'{np.mean(losses):.5f}')
        print('VLoss', np.mean(val_losses))
        wandb.log({"train_loss_512Latent": np.mean(losses), "val_loss_512Latent": np.mean(val_losses)})
        # The memory report queries the CUDA runtime, so only ask on a GPU.
        if epoch + 1 <= 5 and device.type == 'cuda':
            print_gpu_mem()

In [None]:
# Hold out a fixed-size validation subset. The split is seeded so the same
# images end up in the validation set on every run.
VAL_SIZE = 200
tds, valds = torch.utils.data.random_split(
    ds, [len(ds) - VAL_SIZE, VAL_SIZE], generator=torch.Generator().manual_seed(0))

In [None]:
# Build the VAE with z_d=512. train() samples decoder inputs of width
# model.z_d, so this is presumably the latent dimensionality -- confirm in model/VAE.
model = VAE(z_d=512)

In [None]:
# Display the z_d value the model was constructed with.
model.z_d

In [None]:
## Hyperparameters passed to train()
epochs = 1000
batch_size = 1024
# Earlier setting kept for reference: lr = 0.001 with batch size 1024, pretrained ResNet, SGD
lr = 0.0001
beta = 0.1  # forwarded to ELBOLoss(beta=...) inside train()

In [None]:
# Report GPU memory usage before training is launched.
print_gpu_mem()

In [None]:
# Launch training with the model, datasets and hyperparameters defined above.
train(model, lr, epochs, batch_size, tds, valds, beta)

In [None]:
# Scratch check: build a 512x512 all-white RGB image and show it with matplotlib.
img = Image.fromarray(np.full((512, 512, 3), 255, dtype=np.uint8))
plt.imshow(img)

In [None]:
# Scratch check: one all-ones and one all-zeros image, each shaped (1, 3, 32, 32).
asdf1= torch.ones(1,3,32,32, dtype=torch.float32)
asdf2 = torch.zeros(1,3,32,32, dtype=torch.float32)

asdf3 = torch.cat((asdf1,asdf2), dim=0)
# NOTE(review): view() only reinterprets the flat (2, 3, 32, 32) buffer as
# (3, 64, 32), so channel planes of the two images get mixed instead of the
# images being stacked vertically -- confirm whether that was the intent.
plt.imshow(asdf3.view(3,64,32).permute(1,2,0))

In [None]:
# List the contents of the current working directory.
os.listdir()

In [None]:
# Scratch check: the same all-white 512x512 RGB image, shown via the PIL rich repr.
img = Image.fromarray(np.full((512, 512, 3), 255, dtype=np.uint8))
img

In [None]:
# Scratch check of the float -> uint8 conversion of a tensor, the same kind of
# cast train() applies to images before logging them.
np.array(torch.tensor([2.3]), dtype=np.uint8)