In [None]:
from torchsummary import summary
from model import AE
from torchvision.datasets import ImageFolder
import torchvision.transforms as transform
import torch
from torch.utils.data import DataLoader 
from torchvision.datasets import CelebA, CIFAR10, Caltech256
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
from PIL import Image
import pandas as pd

In [None]:
# Preprocessing: take the central 128x128 patch of each CelebA image and turn it
# into a float tensor (ToTensor scales PIL uint8 pixels to [0, 1]).
# NOTE: CenterCrop(128) already yields 128x128, so the Resize below should be a
# no-op; it is kept so the pipeline stays exactly as before.
_pipeline_steps = [
    transform.CenterCrop(128),
    transform.Resize((128, 128)),
    transform.ToTensor(),
]
data_transform = transform.Compose(_pipeline_steps)

# CelebA is downloaded into ../data on first use (default split).
# Alternative dataset: Caltech256(r"../data", download=True, transform=data_transform)
train_data = CelebA(r"../data", download=True, transform=data_transform)

train_loader = DataLoader(
    train_data,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Sanity check: display the first transformed sample (CHW tensor -> HWC for matplotlib).
first_img, _first_target = train_data[0]
plt.imshow(first_img.permute(1, 2, 0))

In [None]:
# Prefer the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("device = ", device)

# Autoencoder from the local `model` module; 256 is presumably the latent size
# (later cells build [1, 256] latent tensors) — TODO confirm against AE.
model = AE(256).to(device)

# BCELoss requires predictions in [0, 1]; this assumes AE's decoder ends in a
# sigmoid — TODO confirm. (Misspelled names are kept: other cells use them.)
critrion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Lowers the LR when the metric passed to lr_sheduler.step() stops decreasing.
lr_sheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")

In [None]:
def train(log_every=50, sched_every=150):
    """Run one pass over ``train_loader`` and return the sampled loss history.

    Uses the notebook-level ``model``, ``optimizer``, ``critrion`` (loss between
    the reconstruction and the input batch itself), ``lr_sheduler`` and ``device``.

    Args:
        log_every: record and print the current batch loss every this many batches.
        sched_every: step ``lr_sheduler`` every this many batches, passing the
            mean loss over those batches (a single minibatch loss is too noisy
            for ReduceLROnPlateau and can trigger spurious LR reductions).

    Returns:
        list[float]: batch losses recorded every ``log_every`` batches.
    """
    model.train()
    loss_train = []
    window_loss = 0.0
    window_size = 0

    for count, batch in enumerate(train_loader):
        # The batch's first element holds the images; the labels are not used.
        x = batch[0].to(device)
        optimizer.zero_grad()

        prediction, _latent = model(x)
        loss = critrion(prediction, x)
        loss.backward()
        optimizer.step()

        # Accumulate as a tensor so there is no host sync on every batch.
        window_loss += loss.detach()
        window_size += 1

        if count % log_every == 0:
            loss_train.append(loss.item())
            print("current loss = ", loss.item())

        if window_size == sched_every:
            lr_sheduler.step((window_loss / window_size).item())
            window_loss = 0.0
            window_size = 0

    return loss_train


def val():
    """Placeholder for a validation pass — not implemented yet (TODO).

    Returns:
        None.
    """
    return None

In [None]:
# Set to True to train from scratch and save fresh weights; otherwise the
# previously saved checkpoint is loaded.
TRAIN_FROM_SCRATCH = False

if TRAIN_FROM_SCRATCH:
    train()
    torch.save(model.state_dict(), "../weights/AE.pth")
else:
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    # NOTE(review): the load path differs from the save path above
    # ("../weights/AE.pth" vs ".\weights\AE.pth") — confirm which directory holds the weights.
    model.load_state_dict(torch.load(r".\weights\AE.pth", map_location=device))

# Inference mode for the reconstruction / latent-arithmetic cells that follow.
model.eval()

In [None]:
# Compare a random training image with its autoencoder reconstruction.
model.eval()
# Feed inputs to whatever device the model currently lives on, so this cell
# still works after another cell has moved the model to the CPU.
model_device = next(model.parameters()).device

img = train_data[np.random.randint(len(train_data))][0]

# Pure inference: no autograd graph needed.
with torch.no_grad():
    reconstruct, _latent = model(torch.unsqueeze(img, 0).to(model_device))

fig, (ax1, ax2) = plt.subplots(1, 2)
ax1.imshow(img.permute(1, 2, 0))
ax1.set_title("input")
ax2.imshow(reconstruct[0].cpu().permute(1, 2, 0))
ax2.set_title("reconstruction")

In [None]:
# Convert the whitespace-separated CelebA attribute list into a CSV for pandas.
# Raw strings: the Windows paths otherwise contain invalid escape sequences.
ATTR_TXT_PATH = r"C:\Visual_Studio\ML\data\celeba\list_attr_celeba.txt"
ATTR_CSV_PATH = r"C:\Visual_Studio\ML\data\celeba\list_attr_celeba.csv"

# `with` guarantees both files are closed even if reading or writing fails.
with open(ATTR_TXT_PATH, "r") as txt_file, open(ATTR_CSV_PATH, "w") as csv_file:
    txt_file.readline()  # skip the first line; it is not part of the table
    # Collapse double spaces to single ones, then turn every space into a comma.
    # "img_name," is prepended to name the leading file-name column.
    csv_file.write("img_name," + txt_file.read().replace("  ", " ").replace(" ", ","))

In [None]:
# Load the converted attribute table. The empty extra column "Unnamed: 41" is
# presumably produced by a trailing separator in the converted header; drop it.
df = (
    pd.read_csv(r"C:\Visual_Studio\ML\data\celeba\list_attr_celeba.csv")
    .drop(columns="Unnamed: 41")
)
df

In [None]:

# Pick example images with and without a smile. A boolean filter keeps file
# order and is deterministic; taking the head of sort_values (quicksort, not
# stable) selected an implementation-defined subset of the matching rows.
N_EXAMPLES = 30
list_smile = df.loc[df["Smiling"] == 1, "img_name"].values[:N_EXAMPLES]
list_sad = df.loc[df["Smiling"] != 1, "img_name"].values[:N_EXAMPLES]

list_smile

In [None]:
CELEBA_IMG_DIR = r"C:\Visual_Studio\ML\data\celeba\img_align_celeba"


def mean_latent(img_names):
    """Return the mean encoder output over a sequence of CelebA image files.

    Each file is read from ``CELEBA_IMG_DIR``, passed through ``data_transform``
    and encoded with the global ``model`` (expected to be on the CPU, since the
    input tensors are created there).

    Args:
        img_names: non-empty sequence of file names inside ``CELEBA_IMG_DIR``.

    Returns:
        Tensor with the average latent code, accumulated into a ``[1, 256]``
        buffer (assumes ``model.encode`` returns something broadcastable to that
        shape — TODO confirm against AE).

    Raises:
        ValueError: if ``img_names`` is empty (the mean would be undefined).
    """
    if len(img_names) == 0:
        raise ValueError("img_names must not be empty")
    total = torch.zeros([1, 256])
    for name in img_names:
        # `with` closes the file handle once the pixels have been converted.
        with Image.open(f"{CELEBA_IMG_DIR}\\{name}") as pil_img:
            img_tensor = data_transform(pil_img)
        total += model.encode(torch.unsqueeze(img_tensor, 0))
    return total / len(img_names)


model.eval()
# Move the model to the CPU once (previously done on every iteration); it stays
# there afterwards, for later cells that feed it CPU tensors.
model.cpu()

# Pure inference: without no_grad every encode call builds an autograd graph
# and smile_vector would carry requires_grad.
with torch.no_grad():
    # "Smile direction" = mean latent of smiling faces minus mean latent of
    # non-smiling faces.
    smile_vector = mean_latent(list_smile[:30]) - mean_latent(list_sad[:30])

smile_vector

In [None]:
# Latent arithmetic: encode a random face, then decode it with the smile
# direction added. SMILE_STRENGTH scales the step (1 = plain difference of means).
SMILE_STRENGTH = 1

test = train_data[np.random.randint(len(train_data))][0]

model.eval()
# Run on whatever device the model currently lives on, so this cell does not
# depend on an earlier cell having moved the model to the CPU.
model_device = next(model.parameters()).device

with torch.no_grad():
    # Named variable instead of `_`, which would clobber IPython's last-output slot.
    reconstruction, latent = model(torch.unsqueeze(test, 0).to(model_device))
    pred = model.decode(latent + smile_vector.to(model_device) * SMILE_STRENGTH)

fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(30, 10))
ax1.imshow(test.permute(1, 2, 0))
ax1.set_title("input")
ax2.imshow(reconstruction[0].cpu().permute(1, 2, 0))
ax2.set_title("reconstruction")
ax3.imshow(pred[0].cpu().permute(1, 2, 0))
ax3.set_title("reconstruction + smile vector")