In [None]:
import os
import PIL.Image as Image
import matplotlib.pyplot as plt
import torch.cuda
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import torchvision
from LookGenerator.networks.clothes_feature_extractor import ClothingAutoEncoder
from LookGenerator.datasets.basic_dataset import BasicDataset
from LookGenerator.networks.losses import VAELoss
from LookGenerator.networks.utils import load_model
import matplotlib.pyplot as plt
import cv2
from tqdm import tqdm
from torchsummary import summary

In [None]:
# Preprocessing applied to every input image:
#   1. resize to (256, 192) -- torchvision's Resize takes (height, width);
#   2. per-channel (x - 0.5) / 0.5, i.e. [0, 1] -> [-1, 1] when x starts in [0, 1].
# NOTE(review): transforms.Normalize only accepts tensors. This assumes BasicDataset
# hands tensors to transform_input; if it yields PIL images, a transforms.ToTensor()
# step is needed before Normalize -- confirm against BasicDataset.
transform_input = transforms.Compose(
    [
        transforms.Resize((256, 192)),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

In [None]:
# DataLoader settings.
batch_size_train, batch_size_val = 32, 16
pin_memory = True  # page-locked host buffers speed up host -> GPU copies
num_workers = 8    # number of worker processes loading batches in parallel

In [None]:
# Training split.
# TODO: root_dir / dir_name look like empty placeholders -- fill them in before running.
train_dataset = BasicDataset(root_dir="", dir_name="", transform_input=transform_input)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size_train,
    shuffle=True,  # new sample order every epoch
    pin_memory=pin_memory,
    num_workers=num_workers,
)

In [None]:
# Validation split.
# TODO: root_dir / dir_name look like empty placeholders -- fill them in before running.
val_dataset = BasicDataset(
    root_dir="",
    dir_name="",
    transform_input=transform_input
)
# The loader must wrap val_dataset (previously it wrapped train_dataset, so
# "val_loss" was just another pass over the training data). It uses its own batch
# size, and shuffle=False keeps a fixed order so per-epoch validation losses are
# directly comparable.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size_val,
    shuffle=False,
    pin_memory=pin_memory,
    num_workers=num_workers,
)

In [None]:
def _total_loss(criterion_output):
    """Return the scalar loss from a criterion that returns either a tensor or a tuple.

    When the criterion returns a tuple/list, its first element is taken as the
    total loss (the remaining items are ignored).
    """
    if isinstance(criterion_output, (tuple, list)):
        return criterion_output[0]
    return criterion_output


def _mean_or_nan(values):
    """Mean of ``values`` as a float, or ``nan`` when no batches were seen."""
    return float(np.mean(values)) if values else float("nan")


def _train_one_epoch(model, criterion, optimizer, device, loader):
    """Run one optimisation pass over ``loader``; return the mean batch loss."""
    model.train()
    batch_losses = []
    for X_batch, _ in loader:
        X_batch = X_batch.to(device)
        optimizer.zero_grad()
        reconstructed, mu, log_var = model(X_batch)
        loss = _total_loss(criterion(X_batch, mu, log_var, reconstructed))
        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())
    return _mean_or_nan(batch_losses)


def _evaluate(model, criterion, device, loader):
    """Compute the mean batch loss over ``loader`` without tracking gradients."""
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for X_val, _ in loader:
            X_val = X_val.to(device)
            reconstructed, mu, log_var = model(X_val)
            loss = _total_loss(criterion(X_val, mu, log_var, reconstructed))
            batch_losses.append(loss.item())
    return _mean_or_nan(batch_losses)


def fit(model, criterion, optimizer, device, train_loader, val_loader, epochs):
    """Train a VAE-style model and record per-epoch train/validation loss.

    Args:
        model: module whose forward returns ``(reconstructed, mu, log_var)``.
        criterion: called as ``criterion(x, mu, log_var, reconstructed)``. It may
            return a scalar loss tensor or a tuple whose first element is the total
            loss; both forms are handled identically in training and validation.
        optimizer: optimizer over ``model``'s parameters.
        device: device the model, criterion and every batch are moved to.
        train_loader: iterable of ``(inputs, _)`` pairs used for training; the
            second item is ignored.
        val_loader: iterable of ``(inputs, _)`` pairs used for evaluation.
        epochs: number of training epochs.

    Returns:
        ``(train_loss, val_loss)``: lists holding the mean batch loss of each epoch
        (``nan`` for an epoch whose loader yielded no batches).
    """
    train_loss = []
    val_loss = []

    criterion = criterion.to(device)
    model = model.to(device)

    for epoch in tqdm(range(epochs)):
        torch.cuda.empty_cache()
        train_loss.append(_train_one_epoch(model, criterion, optimizer, device, train_loader))
        val_loss.append(_evaluate(model, criterion, device, val_loader))

        print("Epoch [{}/{}], train_loss: {:.3f}, val_loss: {:.3f}".format(
            epoch + 1, epochs,
            train_loss[-1], val_loss[-1])
        )

    return train_loss, val_loss

In [None]:
# Pick the compute device up front; fit() moves the model and criterion onto it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

criterion = VAELoss()
model = ClothingAutoEncoder()
# NOTE(review): the original literal was 10e-3, which equals 1e-2 (0.01). If 1e-3
# was intended, this learning rate is 10x too high -- confirm.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

In [None]:
# Train for 30 epochs. `device` is a required parameter of fit(); omitting it
# raised a TypeError before training could start.
train_loss, val_loss = fit(model=model,
                           criterion=criterion,
                           optimizer=optimizer,
                           device=device,
                           train_loader=train_dataloader,
                           val_loader=val_dataloader,
                           epochs=30)

In [None]:
# Learning curves: mean batch loss per epoch for the train and validation splits.
fig, ax = plt.subplots()
ax.plot(train_loss, label='train')
ax.plot(val_loss, label='val')
ax.set(title='Loss per epoch', xlabel='epoch (0-based)', ylabel='mean batch loss')
ax.legend()
plt.show()

In [None]:
# Show one validation input next to its reconstruction.
model.eval()
X_val, _ = next(iter(val_dataloader))
with torch.no_grad():  # inference only; ToPILImage cannot convert tensors that require grad
    reconstructed, mu, log_var = model(X_val.to(device))

# Undo Normalize(mean=0.5, std=0.5) on BOTH images: x -> x / 2 + 0.5. This assumes
# the model's output lives in the same normalized space, as the original /2+0.5 on
# the reconstruction implies. Move to CPU and clamp to [0, 1] so ToPILImage's byte
# conversion does not wrap out-of-range values.
to_pil = transforms.ToPILImage()
img = to_pil((reconstructed[0].cpu() / 2 + 0.5).clamp(0, 1))
cl = to_pil((X_val[0] / 2 + 0.5).clamp(0, 1))

# Render inline instead of PIL's Image.show(), which opens an external viewer.
fig, axes = plt.subplots(1, 2, figsize=(8, 5))
for ax, image, title in zip(axes, (cl, img), ("input", "reconstruction")):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis("off")
plt.show()