### Imports

In [2]:
from __future__ import print_function
from __future__ import division

import os
import shutil
import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torchvision
from torchvision import models
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
import matplotlib.pyplot as plt
import time
from tensorboardX import SummaryWriter
from glob import glob
from util import *
import numpy as np
from PIL import Image
import warnings
warnings.filterwarnings("ignore")

from vae import VAE, ShallowVAE

### Variables globales (argumentos) ((cambiarlo porque es un notebook))

In [3]:
torch.manual_seed(1)
if torch.cuda.is_available():
    torch.cuda.manual_seed(1)
    is_cuda = True
else:
    is_cuda = False

BATCH_SIZE = 128
EPOCH = 20
LOG_INTERVAL=1
path = 'PetImages/'
kwargs = {'num_workers': 3, 'pin_memory': True} if is_cuda else {}

In [14]:
# Editar estructura del dataset
pathCat = path + 'Cat/'
pathDog = path + 'Dog/'
#os.mkdir(pathCat + 'valid')
#os.mkdir(pathCat + 'train')
#os.mkdir(pathDog + 'valid')
#os.mkdir(pathDog + 'train')


#for i in range(1250, 12500):
#    shutil.copy(pathDog+'train/'+str(i)+'.jpg', path+'train/d' + str(i) + ".jpg")
#    shutil.copy(pathCat+'train/'+str(i)+'.jpg', path+'train/c' + str(i) + ".jpg")


for i in range(1250):
    shutil.copy(pathDog+'valid/'+str(i)+'.jpg', path+'valid/d' + str(i) + ".jpg")
    shutil.copy(pathCat+'valid/'+str(i)+'.jpg', path+'valid/c' + str(i) + ".jpg")



### Carga de datos

In [16]:
simple_transform = transforms.Compose([transforms.Resize((224,224))
                                       ,transforms.ToTensor(), transforms.Normalize([0.48829153, 0.45526633, 0.41688013],[0.25974154, 0.25308523, 0.25552085])])
train = ImageFolder(path+'train',simple_transform)
valid = ImageFolder(path+'valid',simple_transform)
train_data_gen = torch.utils.data.DataLoader(train,shuffle=True,batch_size=BATCH_SIZE,num_workers=kwargs['num_workers'])
valid_data_gen = torch.utils.data.DataLoader(valid,batch_size=BATCH_SIZE,num_workers=kwargs['num_workers'])

dataset_sizes = {'train':len(train_data_gen.dataset),'valid':len(valid_data_gen.dataset)}
dataloaders = {'train':train_data_gen,'valid':valid_data_gen}


### Modelo

In [17]:
model = ShallowVAE(latent_variable_size=500, nc=3, ngf=224, ndf=224, is_cuda=is_cuda)

#model = VAE(BasicBlock, [2, 2, 2, 2], latent_variable_size=500, nc=3, ngf=224, ndf=224, is_cuda=is_cuda)

if is_cuda:
    model.cuda()
    
reconstruction_function = nn.MSELoss()
reconstruction_function.size_average = False

optimizer = optim.Adam(model.parameters(), lr=1e-4)

### Funcion de perdida

In [18]:
def loss_function(recon_x, x, mu, logvar):

    MSE = reconstruction_function(recon_x, x)

    # https://arxiv.org/abs/1312.6114 (Appendix B)
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)

    return MSE + KLD


### Entrenamiento

In [26]:
def train(epoch):

    model.train()
    train_loss = 0
    batch_idx = 1
    for data in dataloaders['train']:
        # get the inputs
        inputs, _ = data

        # wrap them in Variable
        if torch.cuda.is_available():
            inputs = Variable(inputs.cuda())
        else:
            inputs = Variable(inputs)
        
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(inputs)
        #print(inputs.data.size())
        inputs.data = unnormalize(inputs.data,[0.48829153, 0.45526633, 0.41688013],[0.25974154, 0.25308523, 0.25552085])

        #print("input max/min"+str(inputs.max())+"  "+str(inputs.min()))
        #print("recon input max/min"+str(recon_batch.max())+"  "+str(recon_batch.min()))
        loss = loss_function(recon_batch, inputs, mu, logvar)
        loss.backward()
        train_loss += loss.data
        optimizer.step()

        if batch_idx % LOG_INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(inputs), (len(dataloaders['train'])*128),
                100. * batch_idx / len(dataloaders['train']),
                loss.data / len(inputs)))
        batch_idx+=1

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / (len(dataloaders['train'])*BATCH_SIZE)))
    return train_loss / (len(dataloaders['train'])*BATCH_SIZE)

### Test

In [25]:
def test(epoch):
    model.eval()
    test_loss = 0
    for data in dataloaders['valid']:
        # get the inputs
        inputs, _ = data

        # wrap them in Variable
        if torch.cuda.is_available():
            inputs = Variable(inputs.cuda())
        else:
            inputs = Variable(inputs)
        recon_batch, mu, logvar = model(inputs)
        inputs.data = unnormalize(inputs.data,[0.48829153, 0.45526633, 0.41688013],[0.25974154, 0.25308523, 0.25552085])
        test_loss += loss_function(recon_batch, inputs, mu, logvar).data
        if((epoch+1)%10==0):
            torchvision.utils.save_image(inputs.data, './imgs/Epoch_{}_data.jpg'.format(epoch), nrow=8, padding=2)
            torchvision.utils.save_image(recon_batch.data, './imgs/Epoch_{}_recon.jpg'.format(epoch), nrow=8, padding=2)

    test_loss /= (len(dataloaders['valid'])*128)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    return test_loss



In [27]:
writer = SummaryWriter('runs/exp-1')
since = time.time()
for epoch in range(EPOCH):
    train_loss = train(epoch)
    test_loss = test(epoch)
    writer.add_scalar('train_loss', train_loss, epoch)
    writer.add_scalar('test_loss',test_loss, epoch)
    torch.save(model.state_dict(), './models/Epoch_{}_Train_loss_{:.4f}_Test_loss_{:.4f}.pth'.format(epoch, train_loss, test_loss))
time_elapsed = time.time() - since    
print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))



UnidentifiedImageError: Caught UnidentifiedImageError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/pavlo/anaconda3/envs/AEenv/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/pavlo/anaconda3/envs/AEenv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/pavlo/anaconda3/envs/AEenv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/pavlo/anaconda3/envs/AEenv/lib/python3.9/site-packages/torchvision/datasets/folder.py", line 232, in __getitem__
    sample = self.loader(path)
  File "/home/pavlo/anaconda3/envs/AEenv/lib/python3.9/site-packages/torchvision/datasets/folder.py", line 269, in default_loader
    return pil_loader(path)
  File "/home/pavlo/anaconda3/envs/AEenv/lib/python3.9/site-packages/torchvision/datasets/folder.py", line 250, in pil_loader
    img = Image.open(f)
  File "/home/pavlo/anaconda3/envs/AEenv/lib/python3.9/site-packages/PIL/Image.py", line 3030, in open
    raise UnidentifiedImageError(
PIL.UnidentifiedImageError: cannot identify image file <_io.BufferedReader name='PetImages/train/Dog/11702.jpg'>
