In [1]:
import os
import glob
import numpy as np
import pickle
# Torch Module
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import functional as F
from torchvision.utils import save_image
# PIL Module
from PIL import Image, ImageFile
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from ACGAN import Generator, Discriminator

Build a Label Dictionary for the Car Images

{"name.jpg" : (brand,year,color)}

In [2]:
# Collect every Audi front-view image: <img_folder>/Audi/<sub-folder>/<file>.jpg
img_folder = os.path.join('.', 'confirmed_fronts')
# Listing of the dataset root. It is not used below, but the call raises
# immediately when the folder is missing.
brand_folder = os.listdir(img_folder)
search_path = os.path.join(img_folder, "Audi", "*", "*.jpg")
image_path = list(glob.glob(search_path))
# Bare file names (no directory part); the labels are parsed from these.
image_name = [os.path.basename(path) for path in image_path]

In [3]:
# Parse the labels out of every file name. Names are "$$"-separated:
# field 0 is the brand, field 2 the year and field 3 the color.
parsed_names = [name.split("$$") for name in image_name]

# Sorted unique values of each attribute.
brand_list = sorted({fields[0] for fields in parsed_names})
year_list = sorted({fields[2] for fields in parsed_names})
color_list = sorted({fields[3] for fields in parsed_names})

# Map every attribute value to its position in the sorted list (its class index).
brand_dict = {value: position for position, value in enumerate(brand_list)}
color_dict = {value: position for position, value in enumerate(color_list)}
year_dict = {value: position for position, value in enumerate(year_list)}


In [4]:
# Map every image file name to its (brand, year, color) class indices.
labels_dict = {}
for name in image_name:
    fields = name.split("$$")
    brand_key, year_key, color_key = fields[0], fields[2], fields[3]
    labels_dict[name] = (
        brand_dict[brand_key],
        year_dict[year_key],
        color_dict[color_key],
    )
        

Define Some Useful Functions

In [5]:
def denorm(img):
    """Map a normalized image tensor back to the displayable range.

    Computes ``img / 2 + 0.5``, i.e. rescales values from [-1, 1] to [0, 1],
    and clamps anything that still falls outside [0, 1].

    Args:
        img: input image tensor.

    Returns:
        The rescaled tensor, clamped to [0, 1].
    """
    rescaled = img / 2 + 0.5
    return rescaled.clamp(min=0, max=1)


def save_model(model, optimizer, file_path):
    """Write a checkpoint with the model and optimizer state to `file_path`.

    The saved object is a dict with two entries: 'model' holds
    ``model.state_dict()`` and 'optim' holds ``optimizer.state_dict()``.

    Args:
        model: module whose state is saved.
        optimizer: optimizer whose state is saved.
        file_path: destination passed to ``torch.save``.
    """
    checkpoint = dict(model=model.state_dict(), optim=optimizer.state_dict())
    torch.save(checkpoint, file_path)

def load_model(model, optimizer, file_path, map_location='cpu'):
    """Restore model and optimizer state from a checkpoint file.

    The checkpoint must be a dict with a 'model' entry (a model state dict)
    and an 'optim' entry (an optimizer state dict).

    Args:
        model: module whose state is overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        file_path: path of the checkpoint file.
        map_location: forwarded to ``torch.load``. Defaults to 'cpu' so a
            checkpoint written on a GPU machine can still be opened where
            CUDA is unavailable; ``load_state_dict`` then copies the values
            into the existing parameters of `model` / `optimizer`.

    Returns:
        The same ``(model, optimizer)`` pair that was passed in.
    """
    # NOTE: torch.load unpickles the file, so only open checkpoints you trust.
    prev_state = torch.load(file_path, map_location=map_location)

    model.load_state_dict(prev_state['model'])
    optimizer.load_state_dict(prev_state['optim'])

    return model, optimizer

Define Dataset Class

In [6]:
class Dataset:
    """Map-style dataset of Audi front-view images with one-hot attribute labels.

    NOTE: this class reuses the name of ``torch.utils.data.Dataset`` imported
    at the top of the notebook and shadows it from this cell on. It is a plain
    class that provides ``__len__`` and ``__getitem__``.
    """

    def __init__(self, root, labels, class_num, transform):
        """Index the image files below ``<root>/confirmed_fronts/Audi``.

        Args:
            root: directory that contains the 'confirmed_fronts' folder.
            labels: dict mapping an image file name (no directory part) to a
                sequence of class indices, one per entry of `class_num`.
            class_num: sequence with the number of classes of each attribute.
            transform: callable applied to the PIL image after the color
                adjustment; its result is returned as the sample image.
        """
        self.root = root
        self.img_folder = os.path.join(self.root, 'confirmed_fronts')
        path = os.path.join(self.img_folder, "Audi", "*", "*.jpg")  # (folder/brand/year/name)
        # glob returns files in arbitrary order; sort so that an index always
        # refers to the same file.
        self.img_files = sorted(glob.glob(path))
        self.labels = labels
        self.transform = transform
        self.class_num = class_num

    def color_transform(self, x):
        """Boost saturation (x2.5), apply gamma 0.7 and raise contrast (x1.2)."""
        x = F.adjust_saturation(x, 2.5)
        x = F.adjust_gamma(x, 0.7)
        x = F.adjust_contrast(x, 1.2)
        return x

    def __len__(self):
        """Number of image files found."""
        return len(self.img_files)

    def __getitem__(self, idx):
        """Return ``(image, one_hots)`` for the file at position `idx`.

        Raises:
            KeyError: if the file name has no entry in ``self.labels``.
        """
        img_path = self.img_files[idx]
        # The context manager releases the file handle. convert('RGB') forces
        # three channels so that a grayscale/CMYK JPEG cannot break a
        # 3-channel transform further down the pipeline.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
        img = self.color_transform(img)
        img = self.transform(img)
        filename = os.path.basename(img_path)
        label = self.labels[filename]

        one_hots = []
        for i, c in enumerate(self.class_num):
            # One one-hot segment per attribute, concatenated in class_num order.
            # e.g. label (0, 4, 9) with class_num = [1, 14, 18] gives
            # [1,| 0,0,0,0,1,0,0,0,0,0,0,0,0,0, | 0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0]
            segment = torch.zeros(c)
            segment[label[i]] = 1
            one_hots.append(segment)
        one_hots = torch.cat(one_hots, 0)
        return img, one_hots

def image_loader(root, labels, class_num, batch_size):
    """Build the shuffled training DataLoader over the car image Dataset.

    Every image is randomly flipped horizontally (p = 0.5), resized to 64x64,
    converted to a tensor and normalized with mean 0.5 / std 0.5 per channel.
    Incomplete final batches are dropped.

    Args:
        root, labels, class_num: forwarded to ``Dataset``.
        batch_size: number of samples per batch.

    Returns:
        The configured ``DataLoader``.
    """
    preprocessing = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
    return DataLoader(
        Dataset(root, labels, class_num, transforms.Compose(preprocessing)),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )

Define Trainer Class

In [7]:
class Trainer:
    """Trains the ACGAN Generator / Discriminator pair on the labelled car images.

    Construction builds the data loader, both networks, the losses and the
    optimizers. ``start`` runs the training loop and writes logs, sample
    images and checkpoints; ``generate_image`` samples the generator.
    """

    def __init__(self, labels):
        """Set the hyper-parameters and build everything needed for training.

        Args:
            labels: dict mapping an image file name to its class indices;
                forwarded to ``image_loader``.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.run_dir = os.path.join('training')
        self.data_root = "."

        self.lr = 0.0002
        self.beta = 0.5

        self.labels = labels
        # Number of classes per attribute. The order fixes the layout of the
        # concatenated one-hot condition vector.
        self.classes = {"brand":1,"year":14,"color":18}
        self.class_num = tuple(self.classes.values())

        self.batch_size = 64
        self.epochs = 22
        self.print_n_iter = 10
        self.sample_n_iter = 50 # sample generated image to save to file
        self.save_n_epoch = 1
        self.itr = 0

        self.dataloader = image_loader(self.data_root, self.labels, self.class_num, self.batch_size)
        # Use the loader's own length: image_loader drops the last incomplete
        # batch, so rounding the sample count up would over-report the number
        # of steps shown in the log.
        self.steps_per_epoch = len(self.dataloader)

        self.input_size = (64,64)
        self.noise_dim = 100
        self.class_dim = sum(self.class_num)

        self.G = Generator(self.noise_dim, self.class_dim).to(self.device)
        self.D = Discriminator(self.class_dim).to(self.device)

        # NOTE(review): BCELoss expects probabilities in [0, 1]; this assumes
        # both Discriminator outputs are already squashed -- confirm in ACGAN.py.
        self.dis_crit = nn.BCELoss() # discriminator criterion
        self.aux_crit = nn.BCELoss() # classifier criterion

        self.G_optim = optim.Adam(self.G.parameters(), lr = self.lr, betas = [self.beta, 0.999])
        self.D_optim = optim.Adam(self.D.parameters(), lr = self.lr, betas = [self.beta, 0.999])

        # Loss histories; one entry is appended each time a log line is printed.
        self.D_loss_list = []
        self.G_loss_list = []
        self.cls_loss_list = []

    def generate_class(self, batch_size):
        """Draw random condition vectors.

        For every attribute in ``self.class_num`` a random class is drawn per
        sample and one-hot encoded; the segments are concatenated.

        Returns:
            Tensor of shape ``(batch_size, sum(self.class_num))`` (not moved
            to ``self.device``; callers do that).
        """
        labels = []

        for c in self.class_num:
            label = torch.LongTensor(batch_size, 1).random_() % c
            one_hot = torch.zeros(batch_size, c).scatter(1, label, 1)
            labels.append(one_hot)

        return torch.cat(labels, 1)

    def generate_image(self,sample=10):
        """Generate `sample` images with random classes into ./output/Audi.

        Files are named ``1.jpg`` ... ``<sample>.jpg``. The generator is put
        in eval mode and no gradients are recorded.
        """
        out_dir = os.path.join('.','output','Audi')
        os.makedirs(out_dir, exist_ok=True)

        # Inference only: eval mode, no autograd graph.
        self.G.eval()
        with torch.no_grad():
            for i in range(sample):
                z = torch.randn(1, self.noise_dim).to(self.device)
                c = self.generate_class(1).to(self.device)
                class_img = denorm(self.G(z, c))
                save_image(class_img, os.path.join(out_dir, f'{i+1}.jpg'))

    def load_model(self, G_file_path, D_file_path, iter):
        """Restore G and D (with their optimizers) and set the iteration counter.

        Args:
            G_file_path: generator checkpoint path.
            D_file_path: discriminator checkpoint path.
            iter: iteration number to continue counting from (the name
                shadows the builtin; kept for keyword-argument compatibility).
        """
        self.itr = iter
        # Resolves to the module-level load_model helper, not this method.
        load_model(self.G, self.G_optim, G_file_path)
        load_model(self.D, self.D_optim, D_file_path)

    def load_log(self):
        """Reload the loss histories written by ``start`` from the working directory."""
        # NOTE: pickle.load executes arbitrary code from the file; only load
        # logs produced by this notebook.
        with open("D_loss.pickle","rb") as f:
            self.D_loss_list = pickle.load(f)
        with open("G_loss.pickle","rb") as f:
            self.G_loss_list = pickle.load(f)
        with open("class_loss.pickle","rb") as f:
            self.cls_loss_list = pickle.load(f)

    def train_step(self,real_img,real_class):
        """Run one discriminator update followed by one generator update.

        Args:
            real_img: batch of real images.
            real_class: condition vectors belonging to `real_img`.

        Returns:
            Tuple ``(D_loss, G_loss, cls_loss)`` of scalar tensors, where
            ``cls_loss`` is the mean of the real and fake classification losses.
        """
        self.G.train()

        real_img = real_img.to(self.device)
        real_class = real_class.to(self.device)
        # Follow the actual batch instead of assuming self.batch_size.
        batch_size = real_img.size(0)

        fake_label = torch.zeros(batch_size).to(self.device)

        # ---- Train D ----
        # Real targets are drawn from U(0.9, 1) instead of being exactly 1.
        real_label = torch.empty(batch_size).uniform_(0.9, 1).to(self.device)
        fake_class = self.generate_class(batch_size).to(self.device)
        z = torch.randn(batch_size, self.noise_dim).to(self.device)

        # Detach: the discriminator update must not backpropagate into G
        # (those gradients were discarded before the G step anyway).
        fake_img = self.G(z, fake_class).detach()

        real_score, real_pred = self.D(real_img)
        fake_score, fake_pred = self.D(fake_img)

        real_dis_loss = self.dis_crit(real_score, real_label)
        fake_dis_loss = self.dis_crit(fake_score, fake_label)
        dis_loss = (real_dis_loss + fake_dis_loss) * 0.5

        # Only real images contribute to D's classification loss.
        real_aux_loss = self.aux_crit(real_pred, real_class)

        D_loss = real_aux_loss + dis_loss

        self.D_optim.zero_grad()
        D_loss.backward()
        self.D_optim.step()

        # ---- Train G ----
        real_label = torch.ones(batch_size).to(self.device)
        fake_class = self.generate_class(batch_size).to(self.device)
        z = torch.randn(batch_size, self.noise_dim).to(self.device)

        fake_img = self.G(z, fake_class)
        fake_score, fake_pred = self.D(fake_img)

        # G is rewarded when D scores its images as real and classifies them
        # as the requested class.
        fake_dis_loss = self.dis_crit(fake_score, real_label)
        fake_aux_loss = self.aux_crit(fake_pred, fake_class)
        G_loss = fake_aux_loss + fake_dis_loss

        cls_loss = (fake_aux_loss + real_aux_loss) * 0.5

        self.G_optim.zero_grad()
        G_loss.backward()
        self.G_optim.step()

        return D_loss,G_loss,cls_loss

    def _log_losses(self, epoch, step, D_loss, G_loss, cls_loss):
        """Print one log line, append it to log.txt and record the losses."""
        log_text = f"| Iter {self.itr} | Epoch {epoch + 1} | {step + 1} / {self.steps_per_epoch} | D_loss: {D_loss.item()} | G_loss: {G_loss.item()} | cls_loss: {cls_loss.item()}"
        print(log_text)
        self.D_loss_list.append(D_loss.item())
        self.G_loss_list.append(G_loss.item())
        self.cls_loss_list.append(cls_loss.item())
        with open("log.txt","a") as log_file:
            log_file.write(log_text+'\n')

    def _save_samples(self):
        """Save a grid of generated images sharing one random class."""
        self.G.eval()

        with torch.no_grad():
            z = torch.randn(self.batch_size, self.noise_dim).to(self.device)
            # One random condition repeated for the whole batch.
            c = self.generate_class(1).repeat(self.batch_size, 1).to(self.device)
            class_img = denorm(self.G(z, c))

        os.makedirs(os.path.join(self.run_dir, 'images'),exist_ok=True)
        save_image(class_img, os.path.join(self.run_dir, 'images', '{}.png'.format(self.itr)))

    def _save_checkpoint(self):
        """Dump the loss histories and save G / D checkpoints for this iteration."""
        os.makedirs(os.path.join(self.run_dir, 'ckpt'),exist_ok=True)
        with open("D_loss.pickle","wb") as f:
            pickle.dump(self.D_loss_list,f)
        with open("G_loss.pickle","wb") as f:
            pickle.dump(self.G_loss_list,f)
        with open("class_loss.pickle","wb") as f:
            pickle.dump(self.cls_loss_list,f)
        save_model(self.G, self.G_optim, os.path.join(self.run_dir, 'ckpt', 'G_{}.pth'.format(self.itr)))
        save_model(self.D, self.D_optim, os.path.join(self.run_dir, 'ckpt', 'D_{}.pth'.format(self.itr)))

    def start(self):
        """Run the training loop for ``self.epochs`` epochs.

        Losses are logged every ``print_n_iter`` iterations, sample images are
        saved every ``sample_n_iter`` iterations (both also at iteration 1),
        and a checkpoint is written every ``save_n_epoch`` epochs.
        """
        for e in range(self.epochs):

            for i, (real_img, real_class) in enumerate(self.dataloader):

                self.itr += 1

                D_loss, G_loss, cls_loss = self.train_step(real_img,real_class)

                if self.itr % self.print_n_iter == 0 or self.itr == 1:
                    self._log_losses(e, i, D_loss, G_loss, cls_loss)

                if self.itr % self.sample_n_iter == 0 or self.itr == 1:
                    self._save_samples()

            if (e + 1) % self.save_n_epoch == 0:
                self._save_checkpoint()

    def p(self,id):
        '''Debugging hook (no-op).'''
        pass
                

Start Training

In [8]:
# Build the trainer from the file-name -> label-index dictionary created above.
trainer = Trainer(labels_dict)
# Directory that holds the G_<iter>.pth / D_<iter>.pth checkpoint files.
ckpt_dir = os.path.join('training','ckpt')

In [9]:
# Iteration of the checkpoint to resume from; set to None to train from scratch.
load_iter = 3564
if load_iter is not None:
    G_ckpt = os.path.join(ckpt_dir, f'G_{load_iter}.pth')
    D_ckpt = os.path.join(ckpt_dir, f'D_{load_iter}.pth')
    if os.path.exists(G_ckpt) and os.path.exists(D_ckpt):
        trainer.load_model(G_ckpt, D_ckpt, load_iter)
        trainer.load_log()
    else:
        # A fresh checkout has no checkpoint yet. Say so explicitly and carry
        # on so that "Restart Kernel -> Run All" still works.
        print(f"Checkpoint for iteration {load_iter} not found in {ckpt_dir}; training from scratch.")

In [None]:
# Run the training loop; it prints/appends log lines and writes sample images and checkpoints.
trainer.start()

In [32]:
# Generate images with the generator (default: 10 files, saved under ./output/Audi).
trainer.generate_image()