In [None]:
%matplotlib inline
import os
import random
import traceback

import matplotlib
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision
from matplotlib import pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox, TextArea
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from torch import nn
from torch.autograd import Variable
from torch.utils import data
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets import MNIST
from torchvision.utils import save_image
from tqdm import tqdm, tqdm_notebook

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Lower-case alias: the convolutional CVAE section of this notebook refers to `device`.
device = DEVICE

print('Training on',DEVICE)

In [None]:
# Mount Google Drive when running on Colab (the MLP dataset lives there).
# On any other kernel the import fails; skip the mount instead of crashing
# so the rest of the notebook can still run against local paths.
try:
    from google.colab import drive
except ImportError:
    IN_COLAB = False
    print("google.colab not available; skipping Google Drive mount")
else:
    IN_COLAB = True
    drive.mount('/content/drive')

In [None]:
# Seed every random number generator the notebook relies on so that
# "Restart & Run All" reproduces the same split, initialisation and samples.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # every visible GPU, not only the current one
torch.backends.cudnn.deterministic = True
# The cuDNN autotuner may select non-deterministic kernels; disable it.
torch.backends.cudnn.benchmark = False

In [None]:
# Notebook-wide hyper-parameters.
image_shape = 100                        # side length of the square input images
input_size = image_shape * image_shape   # pixels per flattened image
labels_length = 2                        # width of the one-hot class label
hidden_size = 1000                       # passed to CVAE as `hidden_size`
batch_size = 128
learning_rate = 0.005
epoch = 600                              # epochs for the MLP CVAE training run

In [None]:
# Image preprocessing: tensor in [0, 1], reduced to a single grayscale channel.
# No Normalize step: the CVAE decoder ends in a sigmoid and is trained with
# binary cross-entropy, whose targets must stay in [0, 1]. Normalize((0.5,), (0.5,))
# would map pixels to [-1, 1], making the BCE loss unbounded below.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Grayscale(num_output_channels=1),
])

In [None]:
# One sub-folder per class under the mounted Google Drive directory.
dataset_path = "/content/drive/MyDrive/VKR/dataset_1200_100"
dataset = ImageFolder(dataset_path, transform=transform)

In [None]:
# Random 80/20 split; validation takes the remainder so the sizes sum to the total.
n_total = len(dataset)
train_size = int(0.8 * n_total)
test_size = n_total - train_size
train_data, test_data = torch.utils.data.random_split(dataset, [train_size, test_size])

In [None]:
# Note: despite the names, these are DataLoaders (the names are used by the training call).
# Training batches are shuffled; drop_last keeps every optimisation step at batch_size.
train_dataset = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
# Validation must see every held-out sample in a fixed order: no shuffling and
# no dropped trailing batch (with drop_last=True a split smaller than
# batch_size would produce zero validation batches).
val_dataset = DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=False)

In [None]:
#helper functions
def one_hot(x, max_x):
    """Encode integer labels ``x`` as float one-hot rows.

    Rows are taken from a ``(max_x + 1) x (max_x + 1)`` identity matrix, so the
    result has ``max_x + 1`` columns (pass ``num_classes - 1`` as ``max_x``).
    """
    identity = torch.eye(max_x + 1)
    return identity[x]

def plot_gallery(images, h, w, n_row=3, n_col=6):
    """Show the first ``n_row * n_col`` images on a grid, each reshaped to ``(h, w)``."""
    n_cells = n_row * n_col
    plt.figure(figsize=(2 * n_col, 2 * n_row))
    for cell in range(n_cells):
        plt.subplot(n_row, n_col, cell + 1)
        plt.axis("off")
        plt.imshow(images[cell].reshape(h, w), cmap=matplotlib.cm.binary)
    plt.show()
    
def plot_loss(history):
    """Plot training and validation loss curves from a sequence of ``(train, val)`` pairs."""
    train_curve, val_curve = zip(*history)
    fig, ax = plt.subplots(figsize=(15, 9))
    ax.plot(train_curve, label="train_loss")
    ax.plot(val_curve, label="val_loss")
    ax.legend(loc='best')
    ax.set_xlabel("epochs")
    ax.set_ylabel("loss")
    plt.show()

In [None]:
class CVAE(nn.Module):
    """Fully-connected conditional VAE over flattened images.

    The encoder sees the flattened image concatenated with a one-hot label,
    and the decoder sees the latent sample concatenated with the same label,
    so generation is class-conditional.

    Args:
        input_size: number of pixels in one flattened image.
        hidden_size: dimensionality of the latent space.
        num_labels: width of the one-hot label vector; defaults to the
            notebook-level ``labels_length`` when omitted.
    """

    def __init__(self, input_size, hidden_size=1000, num_labels=None):
        super(CVAE, self).__init__()
        if num_labels is None:
            num_labels = labels_length
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_labels = num_labels

        # Encoder: [x, label] -> 2048 -> 1024 -> 512 -> (mu, logvar)
        self.fc1 = nn.Linear(input_size + num_labels, 2048)
        self.fc2 = nn.Linear(2048, 1024)
        self.fc23 = nn.Linear(1024, 512)
        self.fc21 = nn.Linear(512, hidden_size)
        self.fc22 = nn.Linear(512, hidden_size)

        self.relu = nn.ReLU()

        # Decoder: [z, label] -> 512 -> 1024 -> 2048 -> input_size
        self.fc3 = nn.Linear(hidden_size + num_labels, 512)
        self.fc31 = nn.Linear(512, 1024)
        self.fc32 = nn.Linear(1024, 2048)
        self.fc4 = nn.Linear(2048, input_size)

    def encode(self, x, labels):
        """Return ``(mu, logvar)`` of the approximate posterior for images ``x`` and one-hot ``labels``."""
        x = x.view(-1, self.input_size)
        x = torch.cat((x, labels), 1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc23(x))
        return self.fc21(x), self.fc22(x)

    def decode(self, z, labels):
        """Map latent ``z`` plus one-hot ``labels`` to flattened pixels in (0, 1)."""
        # The concatenation must be assigned, otherwise the label is ignored.
        z = torch.cat((z, labels), 1)
        z = self.relu(self.fc3(z))
        z = self.relu(self.fc31(z))
        z = self.relu(self.fc32(z))
        return torch.sigmoid(self.fc4(z))

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterisation trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, x, labels):
        """Return ``(reconstruction, mu, logvar)`` for images ``x`` and one-hot ``labels``."""
        mu, logvar = self.encode(x, labels)
        z = self.reparameterize(mu, logvar)
        x = self.decode(z, labels)
        return x, mu, logvar

def train_cvae(net, dataloader, test_dataloader, flatten=True, epochs=20, lr=1e-3):
    """Train the MLP CVAE ``net`` and validate it after every epoch.

    Args:
        net: CVAE taking ``(flattened_images, one_hot_labels)``.
        dataloader: training batches of ``(images, integer_labels)``.
        test_dataloader: validation batches passed to ``evaluate``.
        flatten: reshape each batch to ``(batch, image_shape * image_shape)``.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate (1e-3 is Adam's default).

    Returns:
        List with one validation loss per epoch, as appended by ``evaluate``.
    """
    validation_losses = []
    optim = torch.optim.Adam(net.parameters(), lr=lr)

    log_template = "\nEpoch {ep:03d} val_loss {v_loss:0.4f}"
    with tqdm(desc="epoch", total=epochs) as pbar_outer:
        for i in range(epochs):
            for batch, labels in dataloader:
                batch = batch.to(DEVICE)
                # one_hot(labels, k) has k + 1 columns, so pass labels_length - 1
                # to match the label width the CVAE was built with.
                labels = one_hot(labels, labels_length - 1).to(DEVICE)

                if flatten:
                    batch = batch.view(batch.size(0), image_shape * image_shape)

                optim.zero_grad()
                x, mu, logvar = net(batch, labels)
                loss = vae_loss_fn(batch, x[:, :image_shape * image_shape], mu, logvar)
                loss.backward()
                optim.step()
            # evaluate compares flat inputs with flat reconstructions, so always flatten.
            evaluate(validation_losses, net, test_dataloader, flatten=True)
            pbar_outer.update(1)
            tqdm.write(log_template.format(ep=i + 1, v_loss=validation_losses[i]))
    plt.show()
    return validation_losses

In [None]:
# Instantiate the MLP CVAE on flattened images (input_size == image_shape * image_shape).
cvae = CVAE(input_size, hidden_size).to(DEVICE)

In [None]:
def vae_loss_fn(x, recon_x, mu, logvar):
    """Negative ELBO summed over the batch.

    Binary cross-entropy of ``recon_x`` against ``x`` plus the KL divergence of
    ``N(mu, exp(logvar))`` from the standard normal prior.
    """
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return reconstruction + kl_divergence

def evaluate(losses, autoencoder, dataloader, flatten=True):
    """Append the per-sample mean reconstruction MSE on ``dataloader`` to ``losses``.

    Runs under ``torch.no_grad()`` so no autograd graphs are kept alive, and
    shows the last evaluated input next to its reconstruction. If the loader
    yields no batches, ``float('nan')`` is appended and nothing is plotted.

    Args:
        losses: list that receives one float per call.
        autoencoder: model returning ``(reconstruction, ...)`` for ``(inputs, labels)``.
        dataloader: batches of ``(images, integer_labels)``.
        flatten: reshape inputs to ``(batch, image_shape * image_shape)``.
    """
    loss_fn = nn.MSELoss()
    total_loss = 0.0
    n_seen = 0
    inp = out = None
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(DEVICE)
            labels = one_hot(labels, labels_length - 1).to(DEVICE)

            if flatten:
                inputs = inputs.view(inputs.size(0), image_shape * image_shape)

            outputs = autoencoder(inputs, labels)[0]
            batch_n = inputs.size(0)
            # MSELoss averages within the batch; weight by batch size so a
            # short final batch does not count as much as a full one.
            total_loss += loss_fn(inputs, outputs).item() * batch_n
            n_seen += batch_n
            inp, out = inputs, outputs

    if n_seen == 0:
        losses.append(float('nan'))
        return

    plot_gallery([inp[0].cpu(), out[0].cpu()], image_shape, image_shape, 1, 2)
    losses.append(total_loss / n_seen)

In [None]:
def train_cvae(net, dataloader, test_dataloader, flatten=True, epochs=50, lr=1e-3):
    """Train the MLP CVAE ``net`` and validate it after every epoch.

    Args:
        net: CVAE taking ``(flattened_images, one_hot_labels)``.
        dataloader: training batches of ``(images, integer_labels)``.
        test_dataloader: validation batches passed to ``evaluate``.
        flatten: reshape each batch to ``(batch, image_shape * image_shape)``.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate (1e-3 is Adam's default; pass ``learning_rate``
            to use the notebook's configured value).

    Returns:
        List with one validation loss per epoch, as appended by ``evaluate``.
    """
    validation_losses = []
    optim = torch.optim.Adam(net.parameters(), lr=lr)

    log_template = "\nEpoch {ep:03d} val_loss {v_loss:0.4f}"
    with tqdm(desc="epoch", total=epochs) as pbar_outer:
        for i in range(epochs):
            for batch, labels in dataloader:
                batch = batch.to(DEVICE)
                # one_hot(labels, k) has k + 1 columns, so pass labels_length - 1
                # to match the label width the CVAE was built with.
                labels = one_hot(labels, labels_length - 1).to(DEVICE)

                if flatten:
                    batch = batch.view(batch.size(0), image_shape * image_shape)

                optim.zero_grad()
                x, mu, logvar = net(batch, labels)
                loss = vae_loss_fn(batch, x[:, :image_shape * image_shape], mu, logvar)
                loss.backward()
                optim.step()
            # evaluate compares flat inputs with flat reconstructions, so always flatten.
            evaluate(validation_losses, net, test_dataloader, flatten=True)
            pbar_outer.update(1)
            tqdm.write(log_template.format(ep=i + 1, v_loss=validation_losses[i]))
    plt.show()
    return validation_losses

In [None]:
# Train the MLP CVAE for `epoch` epochs; history holds one validation loss per epoch.
history = train_cvae(net=cvae, dataloader=train_dataset, test_dataloader=val_dataset, epochs=epoch)

In [None]:
# Validation loss curve recorded by train_cvae (one value per epoch).
val_loss = history
fig, ax = plt.subplots(figsize=(15, 9))
ax.plot(val_loss, label="val_loss")
ax.legend(loc='best')
ax.set_xlabel("epochs")
ax.set_ylabel("loss")
plt.show()

## Convolutional CVAE

In [None]:
def load_data(dataset_path="D:/VKR/dataset/union_smaller_100"):
    """Build train/test DataLoaders for the convolutional CVAE.

    Images from the ImageFolder at ``dataset_path`` become single-channel
    tensors in [0, 1] (no normalisation, matching the sigmoid/BCE decoder).
    The data are split 80/20 at random.

    Args:
        dataset_path: ImageFolder root (one sub-folder per class).

    Returns:
        ``(train_loader, test_loader)``. The training loader shuffles and drops
        the last partial batch. The test loader keeps every sample, in order.
    """
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        transforms.Grayscale(num_output_channels=1),
    ])
    dataset = ImageFolder(root=dataset_path, transform=transform)
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_data, test_data = torch.utils.data.random_split(dataset, [train_size, test_size])
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
    # Evaluate on every held-out sample in a fixed order; with drop_last=True a
    # test split smaller than batch_size would yield no batches at all.
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, drop_last=False)
    return train_loader, test_loader

In [None]:
# DataLoaders for the convolutional CVAE (default dataset path inside load_data).
(train_loader, test_loader) = load_data()

In [None]:
class Model(nn.Module):
    """Convolutional conditional VAE for single-channel square images.

    Encoder: the image is stacked with a constant channel holding the class
    index. It then goes through two unpadded stride-2 convolutions and a
    linear layer, which produce ``mu`` and ``logvar``.

    Decoder: ``[z, one_hot(y)]`` goes through two linear layers to a 32x10x10
    map, then three transposed convolutions produce a 1x100x100 image in (0, 1).

    NOTE: the decoder's 10x10 -> 100x100 upsampling is fixed, so only
    ``image_shape=100`` yields reconstructions of the same size as the input.
    """

    def __init__(self, latent_size=32, num_classes=2, image_shape=100):
        super(Model, self).__init__()
        self.latent_size = latent_size
        self.num_classes = num_classes
        self.original_image_shape = image_shape
        self.kernal_size = 5  # attribute name kept as-is (sic) for compatibility
        self.stride = 2

        # ---- Encoder (module creation order kept so seeded init is unchanged) ----
        # Two input channels: the image plus the broadcast class-index channel.
        self.conv1 = nn.Conv2d(2, 16, kernel_size=self.kernal_size, stride=self.stride)
        self.sig = nn.Sigmoid()
        self.conv2 = nn.Conv2d(16, 32, kernel_size=self.kernal_size, stride=self.stride)
        # Spatial side length after both convolutions (100 -> 48 -> 22).
        self.image_shape = self._conv_out(self._conv_out(image_shape))
        self.linear1 = nn.Linear(self.image_shape * self.image_shape * 32, 300)
        self.mu = nn.Linear(300, self.latent_size)
        self.logvar = nn.Linear(300, self.latent_size)

        # ---- Decoder ----
        self.linear2 = nn.Linear(self.latent_size + self.num_classes, 300)
        self.linear3 = nn.Linear(300, 10 * 10 * 32)
        # ConvTranspose2d without padding: H_out = (H_in - 1) * stride + kernel_size
        self.conv3 = nn.ConvTranspose2d(32, 16, kernel_size=5, stride=2)  # 10 -> 23
        self.conv4 = nn.ConvTranspose2d(16, 1, kernel_size=5, stride=2)   # 23 -> 49
        self.conv5 = nn.ConvTranspose2d(1, 1, kernel_size=4, stride=2)    # 49 -> 100

    def _conv_out(self, size):
        """Side length after an unpadded convolution with this model's kernel size and stride."""
        return (size - self.kernal_size) // self.stride + 1

    def encoder(self, x, y):
        """Return ``(mu, logvar)`` for images ``x`` (N x 1 x H x W) and one-hot labels ``y``."""
        # Broadcast each sample's class index over an image-sized channel,
        # created on x's device and dtype.
        class_idx = torch.argmax(y, dim=1).reshape((y.shape[0], 1, 1, 1))
        class_channel = torch.ones_like(x) * class_idx
        t = torch.cat((x, class_channel), dim=1)
        t = F.relu(self.conv1(t))
        t = F.relu(self.conv2(t))
        t = t.reshape((x.shape[0], -1))
        t = F.relu(self.linear1(t))
        mu = self.mu(t)
        logvar = self.logvar(t)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)`` on ``mu``'s device."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def unFlatten(self, x, channels, new_is):
        """Reshape flat features to ``(N, channels, new_is, new_is)``."""
        return x.reshape((x.shape[0], channels, new_is, new_is))

    def decoder(self, z):
        """Decode ``z`` (already concatenated with the one-hot label) to images in (0, 1)."""
        t = F.relu(self.linear2(z))
        t = F.relu(self.linear3(t))
        t = self.unFlatten(t, 32, 10)
        t = F.relu(self.conv3(t))
        t = F.relu(self.conv4(t))
        # No ReLU before the sigmoid: clamping the logits at 0 would confine
        # every output pixel to [0.5, 1) and make dark pixels unreachable.
        t = self.conv5(t)
        return self.sig(t)

    def forward(self, x, y):
        """Return ``(reconstruction, mu, logvar)`` for images ``x`` and one-hot labels ``y``."""
        mu, logvar = self.encoder(x, y)
        z = self.reparameterize(mu, logvar)

        # Class conditioning for the decoder.
        z = torch.cat((z, y.float()), dim=1)
        pred = self.decoder(z)
        return pred, mu, logvar


def plot(epoch, pred, y, name='test_'):
    """Save up to two images side by side as ``./images/{name}epoch_{epoch}.jpg``.

    Args:
        epoch: epoch number used in the file name.
        pred: array where ``pred[i, 0]`` is a 2-D image (e.g. N x 1 x H x W).
        y: per-image labels, shown as subplot titles.
        name: file-name prefix.
    """
    os.makedirs('./images', exist_ok=True)
    # Plot at most two images, and fewer if the batch is smaller.
    n_images = min(2, len(pred))
    fig = plt.figure(figsize=(16, 16))
    for i in range(n_images):
        ax = fig.add_subplot(1, 2, i + 1)
        ax.imshow(pred[i, 0], cmap='gray')
        ax.axis('off')
        ax.set_title(str(y[i]))
    fig.savefig("./images/{}epoch_{}.jpg".format(name, epoch))
    plt.close(fig)


def loss_function(x, pred, mu, logvar):
    """Return the two VAE loss terms, each summed over the batch.

    Returns:
        ``(recon_loss, kld)``:
        - ``recon_loss`` is the binary cross-entropy of ``pred`` against ``x``.
        - ``kld`` is the KL divergence of ``N(mu, exp(logvar))`` from ``N(0, I)``.
    """
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    recon_loss = F.binary_cross_entropy(pred, x, reduction='sum')
    return recon_loss, kld


def train(epoch, model, train_loader, optim):
    """Run one training epoch of the convolutional CVAE.

    Args:
        epoch: current epoch number (not used in the computation).
        model: the ``Model`` being trained. Its ``num_classes`` sets the one-hot
            width, and data is moved to its parameters' device.
        train_loader: batches of ``(images, integer_labels)``.
        optim: optimizer over ``model``'s parameters.

    Returns:
        ``(total_loss, kld_loss, reconstruction_loss)`` averaged per sample
        over the batches that were processed. All three are ``nan`` if no
        batch succeeded.
    """
    device = next(model.parameters()).device
    reconstruction_loss = 0.0
    kld_loss = 0.0
    total_loss = 0.0
    n_samples = 0
    for i, (x, y) in enumerate(train_loader):
        try:
            x = x.to(device)
            label = F.one_hot(y, model.num_classes).float().to(device)

            optim.zero_grad()
            pred, mu, logvar = model(x, label)

            recon_loss, kld = loss_function(x, pred, mu, logvar)
            loss = recon_loss + kld
            loss.backward()
            optim.step()

            # loss_function already sums over the batch, so accumulate the sums
            # directly (multiplying by the batch size would double-count).
            total_loss += loss.item()
            reconstruction_loss += recon_loss.item()
            kld_loss += kld.item()
            n_samples += x.shape[0]
            if i == 0:
                print("Gradients")
                for name, param in model.named_parameters():
                    if param.grad is None:
                        continue
                    if "bias" in name:
                        print(name, param.grad[0], end=" ")
                    else:
                        print(name, param.grad[0, 0], end=" ")
                    print()
        except Exception:
            # Best effort: report the failing batch (e.g. CUDA OOM), release
            # cached GPU memory and continue with the next batch.
            traceback.print_exc()
            torch.cuda.empty_cache()
            continue

    if n_samples == 0:
        nan = float("nan")
        return nan, nan, nan
    return total_loss / n_samples, kld_loss / n_samples, reconstruction_loss / n_samples

def test(epoch, model, test_loader):
    """Evaluate the convolutional CVAE on ``test_loader`` without gradients.

    For the first batch, the reconstructions are saved through ``plot``,
    tagged with ``epoch``.

    Returns:
        ``(total_loss, kld_loss, reconstruction_loss)`` averaged per sample
        over the processed batches. All three are ``nan`` if none succeeded.
    """
    device = next(model.parameters()).device
    reconstruction_loss = 0.0
    kld_loss = 0.0
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for i, (x, y) in enumerate(test_loader):
            try:
                x = x.to(device)
                label = F.one_hot(y, model.num_classes).float().to(device)

                pred, mu, logvar = model(x, label)
                recon_loss, kld = loss_function(x, pred, mu, logvar)
                loss = recon_loss + kld

                # Batch sums from loss_function; normalised by samples seen below.
                total_loss += loss.item()
                reconstruction_loss += recon_loss.item()
                kld_loss += kld.item()
                n_samples += x.shape[0]
                if i == 0:
                    plot(epoch, pred.cpu().numpy(), y.cpu().numpy())
            except Exception:
                # Best effort: report, free cached GPU memory, keep evaluating.
                traceback.print_exc()
                torch.cuda.empty_cache()
                continue

    if n_samples == 0:
        nan = float("nan")
        return nan, nan, nan
    return total_loss / n_samples, kld_loss / n_samples, reconstruction_loss / n_samples



def generate_image(epoch, z, y, model):
    """Decode latent vectors ``z`` conditioned on integer class labels ``y`` and save a plot.

    Args:
        epoch: epoch number used in the saved file name.
        z: latent samples, one row per image (width ``model.latent_size``).
        y: integer class labels (int64 tensor), one per row of ``z``.
        model: trained ``Model``; its ``num_classes`` sets the one-hot width.
    """
    device = next(model.parameters()).device
    with torch.no_grad():
        label = F.one_hot(y, model.num_classes).float()
        pred = model.decoder(torch.cat((z.to(device), label.to(device)), dim=1))
        plot(epoch, pred.cpu().numpy(), y.cpu().numpy(), name='Eval_')
        print("data Plotted")


def save_model(model, epoch, directory="./checkpoints"):
    """Save ``model.state_dict()`` to ``{directory}/model_{epoch}.pt``.

    Missing directories (including parents) are created.

    Returns:
        The path of the written checkpoint file.
    """
    os.makedirs(directory, exist_ok=True)
    file_name = os.path.join(directory, 'model_{}.pt'.format(epoch))
    torch.save(model.state_dict(), file_name)
    return file_name



if __name__ == "__main__":
    # Run configuration for the convolutional CVAE.
    # NOTE(review): load_epoch / max_epoch / generate are not defined anywhere
    # else in the visible notebook; these values are placeholders - confirm them.
    load_epoch = 0    # > 0 resumes from ./checkpoints/model_{load_epoch}.pt
    max_epoch = 200   # number of the last epoch to run (inclusive)
    generate = True   # decode one random sample per class after each epoch

    print("dataloader created")
    model = Model().to(DEVICE)
    print("model created")

    if load_epoch > 0:
        checkpoint_path = './checkpoints/model_{}.pt'.format(load_epoch)
        model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))
        print("model {} loaded".format(load_epoch))

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0.001)

    train_loss_list = []
    test_loss_list = []
    # Epochs are numbered 1..max_epoch; a resumed run continues after load_epoch.
    for i in tqdm(range(load_epoch + 1, max_epoch + 1)):
        model.train()
        train_total, train_kld, train_loss = train(i, model, train_loader, optimizer)
        with torch.no_grad():
            model.eval()
            test_total, test_kld, test_loss = test(i, model, test_loader)
            if generate:
                z = torch.randn(2, model.latent_size).to(DEVICE)
                y = torch.tensor([0, 1])  # one sample for each of the two classes
                generate_image(i, z, y, model)

        # train()/test() return (total, kld, reconstruction) in that order.
        print("Epoch: {}/{} Train loss: {}, Train KLD: {}, Train Reconstruction Loss:{}".format(i, max_epoch, train_total, train_kld, train_loss))
        print("Epoch: {}/{} Test loss: {}, Test KLD: {}, Test Reconstruction Loss:{}".format(i, max_epoch, test_total, test_kld, test_loss))

        save_model(model, i)
        train_loss_list.append([train_total, train_kld, train_loss])
        test_loss_list.append([test_total, test_kld, test_loss])
        np.save("train_loss", np.array(train_loss_list))
        np.save("test_loss", np.array(test_loss_list))

In [None]:
# Persist the trained convolutional CVAE weights.
# NOTE(review): absolute local Windows path - consider making it configurable.
final_weights_path = "D:/VKR/CVAE/src/models/cvae_conv_bce_200.pt"
torch.save(model.state_dict(), final_weights_path)

In [None]:
# Decode one random latent vector conditioned on class 1 and display it.
with torch.no_grad():
    model.eval()
    sample_device = next(model.parameters()).device
    z = torch.randn(1, model.latent_size).to(sample_device)
    y = torch.tensor([1])
    label = F.one_hot(y, model.num_classes).float().to(sample_device)

    pred = model.decoder(torch.cat((z, label), dim=1))
    # pred is (1, 1, H, W); show the single-channel image directly.
    plt.imshow(pred[0, 0].cpu().numpy(), cmap="gray")