In [1]:
from generator import Generator, load_rgb_img
from ImageTransformationNN import ImageTransformationNN
from VGG16 import VGG16LossNN

import time, argparse

from torch.optim import Adam
from torch.nn import MSELoss

from torchvision import transforms

from torch.autograd import set_detect_anomaly, Variable

from torchsummary import summary

In [2]:
# To be moved to utils.py

def gram_matrix(y):
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t) / (ch * h * w)
    return gram

def normalize_batch(batch):
    # normalize using imagenet mean and std
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    batch = batch.div_(255.0)
    res = (batch - mean) / std
    #print('r', res.size())
    return res
    

In [3]:
ALI_DATA_DIR=r"/Users/aliissaoui/Desktop/studies/IIT/spring/CS577/Project/PerceptualLossNetwork/dataset/train_images/"
ALI_STYLE_DIR=r"/Users/aliissaoui/Desktop/studies/IIT/spring/CS577/Project/PerceptualLossNetwork/dataset/train_styles"
ALI_STYLE_PATH = f"{ALI_STYLE_DIR}/mosaic.jpeg"
BATCH_SIZE = 4
DIM = 256

In [4]:
# Parser: Temporarily disabled for jupyter notebook
"""parser = argparse.ArgumentParser(description='Arguments for the network.')

parser.add_argument('learning_rate', type=float,
                    help='Learning rate')

parser.add_argument('epochs', type=int,
                    help='Number of epochs')

parser.add_argument('content_weights', type=int,
                    help='Content weights')

parser.add_argument('log_interval', type=int,
                    help='Integer for Log interval')

parser.add_argument('checkpoint_model_dir', type=str,
                    help='Path for checkpoints directory')  

parser.add_argument('checkpoint_interval', type=int,
                    help='Integer for checkpoints interval') 

args = parser.parse_args()"""

"parser = argparse.ArgumentParser(description='Arguments for the network.')\n\nparser.add_argument('learning_rate', type=float,\n                    help='Learning rate')\n\nparser.add_argument('epochs', type=int,\n                    help='Number of epochs')\n\nparser.add_argument('content_weights', type=int,\n                    help='Content weights')\n\nparser.add_argument('log_interval', type=int,\n                    help='Integer for Log interval')\n\nparser.add_argument('checkpoint_model_dir', type=str,\n                    help='Path for checkpoints directory')  \n\nparser.add_argument('checkpoint_interval', type=int,\n                    help='Integer for checkpoints interval') \n\nargs = parser.parse_args()"

In [5]:
class StyleTransferFactory(object):
    
    def __init__(self, DATA_DIR, STYLE_PATH, BATCH_SIZE):
        self.loss_net = VGG16LossNN()
        self.transformer = ImageTransformationNN()
        self.gen = Generator(DATA_DIR, BATCH_SIZE)
        self.style = load_rgb_img(STYLE_PATH)
        #print("shape:", self.style.shape)
    
    
    def train(self):
        
        '''
        Should train network and
        - save model file
        - save training info, validaiton loss, etc
        '''

        
        # Parameters ( to be placed in parser )
        learning_rate = 1e-3 
        epochs = 2
        content_weights = 1e5
        style_weights = 1e10
        log_interval = 1
        checkpoint_model_dir = None
        checkpoint_interval = 0
        #! 
        
        optimizer = Adam(self.transformer.parameters(), learning_rate)

        mse_loss = MSELoss()

        style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
        ])
        
        style = style_transform(self.style)
        style = style.repeat(BATCH_SIZE, 1, 1, 1)
        #print('style:', style.shape)
        
        loss_net_input = normalize_batch(style)
        
        #print('style:', style.shape)
        feature_style = self.loss_net.forward(normalize_batch(style))
        
        """Verification"""
        #for key, value in feature_style.items() :
        #    print(key)
        
        gram_style = [gram_matrix(y) for y in feature_style.values()]
        #print('GRAM:', gram_style[0].shape)
    
              
        for epoch in range(epochs):
            with set_detect_anomaly(True):
                print('epoch:', epoch)
                self.transformer.train()
                #print('-------------\n', self.transformer, '\n-------------')

                l_feat_total = 0.
                l_style_total = 0.
                count = 0

                # To revise
                for batch_id, (x, _) in enumerate(self.gen):

                    # Add the batch size
                    n_batch = len(x)
                    count += n_batch

                    # Adam
                    optimizer.zero_grad()

                    #print("XXXX", x.shape)
                    # Problem here !!! 
                    y = self.transformer(x)
                    
                    xc = Variable(x.clone())

                    # Normalize batch
                    y = normalize_batch(y)
                    xc = normalize_batch(xc)

                    # Features from the VGG16 network
                    #print('x size:', xc.shape)
                    #print('y size:', y.shape)

                    features_y = self.loss_net(y)
                    features_xc = self.loss_net(xc)
                    
                    f_xc_c = Variable(features_xc['relu2_2'], requires_grad=False)

                    # Update features reconstruction loss
                    mse = mse_loss(features_y['relu2_2'], f_xc_c)

                    l_feat = content_weights * mse


                    l_style = 0.
                    
                
                    for m, k in enumerate(features_y.keys()):
                        gram_s = Variable(gram_style[m].data, requires_grad=False)
                        gram_y = gram_matrix(features_y[k])
                        l_style += style_weights * mse_loss(gram_y, gram_s[:n_batch, :, :])

                    l_total = l_feat + l_style

                    l_total.backward()
                    optimizer.step()

                    l_feat_total += l_feat.item()
                    l_style_total += l_style.item()

                    if (batch_id + 1) % log_interval == 0:
                        msg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                            time.ctime(), epoch + 1, count, len(self.gen.train),
                                          l_feat_total / (batch_id + 1),
                                          l_style_total / (batch_id + 1),
                                          (l_feat_total + l_style_total) / (batch_id + 1)
                        )
                        print(msg)

                    if checkpoint_model_dir is not None and (batch_id + 1) % checkpoint_interval == 0:
                        transformer.eval().cpu()
                        filename = "check_epoch_" + str(epoch) + "batch_id" + str(batch_id + 1) + '.pth'
                        path = os.path.join(checkpoint_model_dir, filename)
                        torch.save(transformer.state_dict(), path)
                        #transformer.to(device).train() ?

        transformer.eval().cpu()
        
        if checkpoint_model_dir is not None:
            filename = "check_epoch_" + str(epoch) + "batch_id" + str(batch_id + 1) + '.pth'
            path = os.path.join(checkpoint_model_dir, filename)
            torch.save(transformer.state_dict(), path)
    


In [6]:
def main():
    '''  
    Parse args, such as data dir, style image, etc.
    Call factory object and train
    '''
    
    #parser = argparse.ArgumentParser(description='Arguments for the training.')

    model = StyleTransferFactory(ALI_DATA_DIR, ALI_STYLE_PATH, BATCH_SIZE)
    model.train()

In [7]:
main()

epoch: 0
block 4 done
Fri Apr 30 11:21:06 2021	Epoch 1:	[4/100]	content: 880166.812500	style: 19317.386719	total: 899484.199219
block 4 done
Fri Apr 30 11:21:20 2021	Epoch 1:	[8/100]	content: 839642.187500	style: 17016.434570	total: 856658.622070
block 4 done
Fri Apr 30 11:21:33 2021	Epoch 1:	[12/100]	content: 756327.916667	style: 20179.652995	total: 776507.569661
block 4 done
Fri Apr 30 11:21:46 2021	Epoch 1:	[16/100]	content: 734669.312500	style: 18750.042969	total: 753419.355469
block 4 done
Fri Apr 30 11:21:59 2021	Epoch 1:	[20/100]	content: 731987.775000	style: 20480.549609	total: 752468.324609
block 4 done
Fri Apr 30 11:22:14 2021	Epoch 1:	[24/100]	content: 713772.677083	style: 19667.689616	total: 733440.366699
block 4 done
Fri Apr 30 11:22:27 2021	Epoch 1:	[28/100]	content: 726499.464286	style: 18984.761998	total: 745484.226283
block 4 done
Fri Apr 30 11:22:42 2021	Epoch 1:	[32/100]	content: 714966.226562	style: 18702.573242	total: 733668.799805
block 4 done
Fri Apr 30 11:22:56 

KeyboardInterrupt: 

# SCRATCH

In [None]:
loss_net = VGG16LossNN()
transformer = ImageTransformationNN()
gen = Generator(ALI_DATA_DIR, BATCH_SIZE)
style = load_rgb_img(ALI_STYLE_PATH)

In [None]:
summary(loss_net, (3, 391, 470))

In [None]:
summary(transformer, (3, 256, 256))

In [None]:
# Parameters ( to be placed in parser )
learning_rate = 1e-3 
epochs = 2
content_weights = 1e5
style_weights = 1e10
log_interval = 1
checkpoint_model_dir = None
checkpoint_interval = 0
#! 


In [None]:
optimizer = Adam(self.transformer.parameters(), learning_rate)

mse_loss = MSELoss()

style_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.mul(255))
])



In [None]:
            
style = style_transform(self.style)
style = style.repeat(BATCH_SIZE, 1, 1, 1)

loss_net_input = normalize_batch(style)

feature_style = self.loss_net.forward(normalize_batch(style))

"""Verification"""
for key, value in feature_style.items() :
    print(key)

gram_style = [gram_matrix(y) for y in feature_style.values()]
print('GRAM:', gram_style[0].shape)
    