In [None]:
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as Tfs
import matplotlib.pyplot as plt
import torchvision.models as models
import numpy as np
from PIL import Image

In [None]:
#import wandb
#wandb.init(project='ganstuff')

In [None]:
from models.DeepGan import DeepGAN

In [None]:
# Run on the default GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

In [None]:
# Tensors are normalized per channel as (x - 0.5) / 0.5, i.e. into [-1, 1];
# the display code in this notebook undoes it with 0.5 * x + 0.5.
transforms = Tfs.Compose([Tfs.ToTensor(), Tfs.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3)])

# CIFAR-10 train split first, then the test split used for validation.
tds, valds = (
    torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transforms)
    for is_train in (True, False)
)

In [None]:
# CIFAR-10 class names, indexed by integer label 0-9.
classes = [
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'boat', 'truck',
]

In [None]:
# Sanity-check one training sample. Undo the [-1, 1] normalization for display
# and title the plot with the class name.
idx = -141
img, label = tds[idx]  # index once so the dataset transform runs only once
plt.imshow(0.5 * img.permute(1, 2, 0) + 0.5)
plt.title(f'{classes[label]} ({label})')
plt.axis('off')
label

In [None]:
def print_gpu_mem(dev=None):
    """Print currently allocated vs. total memory (GiB) of a CUDA device.

    Args:
        dev: device (``torch.device`` or string) to report on. Defaults to the
            notebook-global ``device``.

    For a non-CUDA device, a short notice is printed and nothing else happens.
    ``torch.cuda.get_device_properties`` only accepts CUDA devices, so querying
    a CPU device would otherwise raise.
    """
    if dev is None:
        dev = device
    dev = torch.device(dev)
    if dev.type != 'cuda':
        print(f'No CUDA memory stats for device {dev}')
        return
    gib = 1024 ** 3
    used = torch.cuda.memory_allocated(dev) / gib
    total = torch.cuda.get_device_properties(dev).total_memory / gib
    print(f'{used:.2f}', '/', f'{total:.2f}', 'GBs')

In [None]:
def unroll_batch(batch, num_rows, num_cols):
    """Tile the first ``num_rows * num_cols`` images of a batch into one grid image.

    Sample ``k = i * num_cols + j`` lands at grid row ``i``, column ``j``
    (row-major order), so consecutive samples run left to right.

    Args:
        batch: tensor of shape ``(N, C, H, W)`` with ``N >= num_rows * num_cols``.
            Samples beyond the grid are ignored.
        num_rows: number of grid rows.
        num_cols: number of grid columns.

    Returns:
        Tensor of shape ``(C, num_rows * H, num_cols * W)``.

    Raises:
        ValueError: if ``batch`` holds fewer than ``num_rows * num_cols`` samples.
    """
    n = num_rows * num_cols
    if batch.shape[0] < n:
        raise ValueError(
            f'need {n} samples for a {num_rows}x{num_cols} grid, got {batch.shape[0]}')
    _, ch, h, w = batch.shape
    # (R*C, ch, H, W) -> (R, C, ch, H, W) -> (ch, R, H, C, W) -> (ch, R*H, C*W).
    # This is one copy instead of repeatedly concatenating growing tensors.
    return (batch[:n]
            .reshape(num_rows, num_cols, ch, h, w)
            .permute(2, 0, 3, 1, 4)
            .reshape(ch, num_rows * h, num_cols * w))
        

In [None]:
# Number of conditioning classes: the training subset keeps two CIFAR-10
# labels (car and horse), remapped to 0 and 1.
num_classes: int = 2

In [None]:
def _show_generated_samples(model, num_rows=4):
    """Draw a ``num_rows`` x ``num_classes`` grid from ``model.G``, show it, and return it as an RGB image.

    Labels cycle through ``0..num_classes-1`` along the batch, and
    ``unroll_batch`` lays samples out row-major with ``num_classes`` columns,
    so grid column ``j`` always shows class ``j``. Relies on the notebook
    globals ``num_classes``, ``device`` and ``unroll_batch``. Meant to be
    called under ``torch.no_grad()``.
    """
    n = num_rows * num_classes
    z = torch.randn(n, model.zdim).to(device)
    labels = (torch.arange(n) % num_classes).long().to(device)
    grid = unroll_batch(model.G(z, labels), num_rows, num_classes).detach().to('cpu')
    # Real images are normalized to [-1, 1]; assume G outputs the same range
    # (NOTE(review): confirm the generator's output activation). Map to [0, 1]
    # and clamp so imshow doesn't warn and the uint8 cast below cannot wrap.
    grid01 = (0.5 * grid + 0.5).clamp(0, 1).permute(1, 2, 0)
    plt.imshow(grid01)
    plt.title('Generated samples (one class per column)')
    plt.axis('off')
    plt.show()
    return Image.fromarray((grid01 * 255).round().to(torch.uint8).contiguous().numpy())


def _epoch_mean(values):
    """Mean of a list of per-batch metrics, or NaN (without a RuntimeWarning) when the list is empty."""
    return float(np.mean(values)) if values else float('nan')


def train(model, Dlr, Glr, epochs, batch_size, tds, valds):
    """Adversarially train a class-conditional GAN and print per-epoch metrics.

    Each training batch runs a discriminator step, a generator step, and then
    a second discriminator step on the same real batch. Only the second D
    step's loss/accuracy is recorded. After every epoch the discriminator is
    evaluated on ``valds``, and a grid of generated samples is displayed.

    Args:
        model: GAN exposing ``D``, ``G``, ``zdim``, ``D_trainstep``,
            ``G_trainstep`` and ``D_valstep`` (see ``models.DeepGan``).
        Dlr: Adam learning rate for the discriminator.
        Glr: Adam learning rate for the generator.
        epochs: number of passes over ``tds``.
        batch_size: mini-batch size for both loaders. Incomplete final batches
            are dropped, so a split smaller than ``batch_size`` contributes no
            metrics (reported as NaN).
        tds: training dataset yielding ``(image, label)`` pairs.
        valds: validation dataset yielding ``(image, label)`` pairs.
    """
    model.to(device)
    tdl = torch.utils.data.DataLoader(tds, batch_size=batch_size, shuffle=True, drop_last=True)
    vdl = torch.utils.data.DataLoader(valds, batch_size=batch_size, shuffle=False, drop_last=True)

    # Separate optimizers for the discriminator and the generator.
    D_optim = torch.optim.Adam(model.D.parameters(), lr=Dlr)
    G_optim = torch.optim.Adam(model.G.parameters(), lr=Glr)

    # Per-epoch means.
    train_dloss, train_gloss, val_dloss = [], [], []
    train_dacc, train_gacc, val_dacc = [], [], []

    for epoch in range(epochs):
        dloss_b, gloss_b, dacc_b, gacc_b = [], [], [], []
        vdloss_b, vdacc_b = [], []

        model.train()
        for xb, labels in tdl:
            xb = xb.to(device)
            labels = labels.to(device)
            # Two D updates per G update. The first D step's metrics are
            # discarded (NOTE(review): confirm the 2:1 D:G ratio is intended).
            D_optim.zero_grad()
            model.D_trainstep(xb, labels, batch_size, D_optim)
            G_optim.zero_grad()
            g_loss, g_acc, gen_imgs = model.G_trainstep(batch_size, G_optim, num_samples=1)
            D_optim.zero_grad()
            d_loss, d_acc = model.D_trainstep(xb, labels, batch_size, D_optim)

            dloss_b.append(d_loss.item())
            gloss_b.append(g_loss.item())
            dacc_b.append(d_acc)
            gacc_b.append(g_acc)

            # Drop batch tensors before the loader yields the next batch.
            del xb, labels, gen_imgs

        model.eval()
        with torch.no_grad():
            for vxb, vlabels in vdl:
                vd_loss, vd_acc = model.D_valstep(vxb.to(device), vlabels.to(device), batch_size)
                vdloss_b.append(vd_loss.item())
                vdacc_b.append(vd_acc)

            # Sample once per epoch after validation; this runs even when vdl is empty.
            gimg = _show_generated_samples(model)
            #wandb.log({'generated_images':wandb.Image(gimg)})
            del gimg

        # Return cached blocks to the driver once per epoch rather than every
        # batch; per-batch calls only add allocator churn.
        torch.cuda.empty_cache()

        train_dloss.append(_epoch_mean(dloss_b))
        val_dloss.append(_epoch_mean(vdloss_b))
        train_gloss.append(_epoch_mean(gloss_b))

        train_dacc.append(_epoch_mean(dacc_b))
        val_dacc.append(_epoch_mean(vdacc_b))
        train_gacc.append(_epoch_mean(gacc_b))

        print('Epoch', epoch+1, 'Train-D-Loss', f'{train_dloss[epoch]:.5f}',
              'Train-D-Acc', f'{train_dacc[epoch]:.5f}')
        print('Val-D-Loss', f'{val_dloss[epoch]:.5f}',
              'Val-D-Acc', f'{val_dacc[epoch]:.5f}')
        print('Train-G-Loss', f'{train_gloss[epoch]:.5f}',
              'Train-G-Acc', f'{train_gacc[epoch]:.5f}')
        #wandb.log({"train_loss_512Latent": np.mean(losses), "val_loss_512Latent": np.mean(val_losses)})
        # CUDA memory stats are meaningless (and raise) on a CPU device.
        if epoch < 5 and device.type == 'cuda':
            print_gpu_mem()

In [None]:
# Class-conditional GAN over the selected classes.
# NOTE(review): `ch` presumably sets the base channel width; confirm in models/DeepGan.py.
model = DeepGAN(ch=32, numclasses=num_classes)

In [None]:
## Hyperparameters
epochs = 1_000
batch_size = 512
# (Earlier run: lr = 0.001, batch size 1024, pretrained ResNet, SGD.)
# Adam learning rates: the discriminator's is 4x the generator's.
Dlr = 2 * 1e-4
Glr = 5 * 1e-5

In [None]:
# Report GPU memory only on CUDA; torch.cuda.get_device_properties (used by
# print_gpu_mem) rejects CPU devices.
if device.type == 'cuda':
    print_gpu_mem()

In [None]:
# Keep only two CIFAR-10 classes (1 = car, 7 = horse) and remap their labels
# to 0/1 so they match `num_classes = 2` for the conditional GAN.
# A separate name avoids clobbering the `classes` list of class names.
selected_classes = [1, 7]
label_map = {orig: new for new, orig in enumerate(selected_classes)}
# Filter on the raw label list so only kept images go through the
# ToTensor + Normalize transform, and each of them only once.
new_ds = [(tds[i][0], label_map[t]) for i, t in enumerate(tds.targets) if t in label_map]

In [None]:
class Ds(torch.utils.data.Dataset):
    """Minimal map-style dataset over an indexable of ``(input, target)`` items.

    Each item is returned as a tuple of its first two fields.
    """

    def __init__(self, ds):
        # Any object supporting len() and integer indexing.
        self.ds = ds

    def __getitem__(self, i):
        # Index the wrapped data only once. A transforming dataset would
        # otherwise run its transform twice per sample, and a random one could
        # pair an input with a target from a different draw.
        item = self.ds[i]
        return item[0], item[1]

    def __len__(self):
        return len(self.ds)

In [None]:
# Hold out 520 samples for validation; with batch_size = 512 and drop_last,
# that leaves one full validation batch. A seeded generator makes the split
# reproducible across kernel restarts.
val_size = 520
smallt, smallv = torch.utils.data.random_split(
    Ds(new_ds), [len(new_ds) - val_size, val_size],
    generator=torch.Generator().manual_seed(0),
)

In [None]:
train(model, Dlr=Dlr, Glr=Glr, epochs=epochs, batch_size=batch_size,
      tds=smallt, valds=smallv)

In [None]:
# Re-bind the CIFAR-10 class-name list (label index -> name). An earlier cell
# may have reused `classes` for integer label ids.
classes = [
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'boat', 'truck',
]

In [None]:
# Inspect the generator sub-module; its repr lists the layers.
model.G

In [None]:
# Scratch check: 20 random integers drawn from [1, 5), i.e. values 1-4.
torch.randint(1, 5, (20,))

In [None]:
# Scratch check: row-wise argmax of a random 3x21 matrix (one index per row).
asdf = torch.randn(3, 21)
b = asdf.argmax(dim=1)
b

In [None]:
# Persist only the learned parameters (state_dict), not the pickled module.
checkpoint_path = 'horsecarembeddings.pt'
torch.save(model.state_dict(), checkpoint_path)