In [21]:
from generator import Generator, load_rgb_img
from ImageTransformationNN import ImageTransformationNN
from VGG16 import VGG16LossNN

import torch, time, argparse, os

from torch.optim import Adam
from torch.nn import MSELoss

from torchvision import transforms

from torch.autograd import set_detect_anomaly, Variable

from torchsummary import summary
import numpy as np

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
torch.cuda.is_available()

True

In [4]:
torch.cuda.get_device_name(0)

'Tesla V100-SXM2-16GB'

In [5]:
# To be moved to utils.py

def gram_matrix(y):
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t) / (ch * h * w)
    return gram

def normalize_batch(batch):
    # normalize using imagenet mean and std
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    batch = batch.div_(255.0)
    res = (batch - mean) / std
    #print('r', res.size())
    return res
    

In [6]:
ALI_DATA_DIR=r"/content/drive/MyDrive/Coco/Train_images_1/"
ALI_STYLE_DIR=r"/content/drive/MyDrive/Coco/styles/"
ALI_STYLE_PATH = f"{ALI_STYLE_DIR}/mosaic.jpeg"
BATCH_SIZE = 4
DIM = 256

In [7]:
# Parser: Temporarily disabled for jupyter notebook
"""parser = argparse.ArgumentParser(description='Arguments for the network.')

parser.add_argument('learning_rate', type=float,
                    help='Learning rate')

parser.add_argument('epochs', type=int,
                    help='Number of epochs')

parser.add_argument('content_weights', type=int,
                    help='Content weights')

parser.add_argument('log_interval', type=int,
                    help='Integer for Log interval')

parser.add_argument('checkpoint_model_dir', type=str,
                    help='Path for checkpoints directory')  

parser.add_argument('checkpoint_interval', type=int,
                    help='Integer for checkpoints interval') 

args = parser.parse_args()"""

"parser = argparse.ArgumentParser(description='Arguments for the network.')\n\nparser.add_argument('learning_rate', type=float,\n                    help='Learning rate')\n\nparser.add_argument('epochs', type=int,\n                    help='Number of epochs')\n\nparser.add_argument('content_weights', type=int,\n                    help='Content weights')\n\nparser.add_argument('log_interval', type=int,\n                    help='Integer for Log interval')\n\nparser.add_argument('checkpoint_model_dir', type=str,\n                    help='Path for checkpoints directory')  \n\nparser.add_argument('checkpoint_interval', type=int,\n                    help='Integer for checkpoints interval') \n\nargs = parser.parse_args()"

In [71]:
class StyleTransferFactory(object):
    
    def __init__(self, DATA_DIR, STYLE_PATH, BATCH_SIZE):
        self.loss_net = VGG16LossNN()
        self.transformer = ImageTransformationNN()
        self.gen = Generator(DATA_DIR, BATCH_SIZE)
        self.style = load_rgb_img(STYLE_PATH)    
    
    def train(self):
        
        '''
        Should train network and
        - save model file
        - save training info, validaiton loss, etc
        '''

        
        # Parameters ( to be placed in parser )
        learning_rate = 1e-3 
        epochs = 2
        content_weights = 1e5
        style_weights = 1e10
        log_interval = 1
        checkpoint_model_dir = "/content/drive/MyDrive/style_save/checkpoints"
        checkpoint_interval = 20
        #! 
        
        optimizer = Adam(self.transformer.parameters(), learning_rate)

        mse_loss = MSELoss()

        style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
        ])
        
        style = style_transform(self.style)
        style = style.repeat(BATCH_SIZE, 1, 1, 1)
        #print('style:', style.shape)
        
        loss_net_input = normalize_batch(style)
        
        #print('style:', style.shape)
        feature_style = self.loss_net.forward(normalize_batch(style))
        
        """Verification"""
        #for key, value in feature_style.items() :
        #    print(key)
        
        gram_style = [gram_matrix(y) for y in feature_style.values()]
        #print('GRAM:', gram_style[0].shape)
    
        l_total_history, l_feat_history, l_style_history = [], [], []
        l_feat_total_history, l_style_total_history, l_total_total_history = [], [], []     
                 
        for epoch in range(epochs):
            with set_detect_anomaly(True):
                print('epoch:', epoch)
                self.transformer.train()
                #print('-------------\n', self.transformer, '\n-------------')

                l_feat_total = 0.
                l_style_total = 0.
                count = 0

                # To revise
                for batch_id, (x, _) in enumerate(self.gen):

                    # Add the batch size
                    n_batch = len(x)
                    count += n_batch

                    # Adam
                    optimizer.zero_grad()

                    y = self.transformer(x)
                    
                    xc = Variable(x.clone())

                    # Normalize batch
                    y = normalize_batch(y)
                    xc = normalize_batch(xc)

                    # Features from the VGG16 network
                    features_y = self.loss_net(y)
                    features_xc = self.loss_net(xc)
                    
                    f_xc_c = Variable(features_xc['relu2_2'], requires_grad=False)

                    # Update features reconstruction loss
                    mse = mse_loss(features_y['relu2_2'], f_xc_c)

                    l_feat = content_weights * mse


                    l_style = 0.
                    
                
                    for m, k in enumerate(features_y.keys()):
                        gram_s = Variable(gram_style[m].data, requires_grad=False)
                        gram_y = gram_matrix(features_y[k])
                        l_style += style_weights * mse_loss(gram_y, gram_s[:n_batch, :, :])

                    l_total = l_feat + l_style



                    l_total.backward()
                    optimizer.step()

                    l_feat_total += l_feat.item()
                    l_style_total += l_style.item()


                    # Saving losses per item
                    l_feat_history.append(l_feat.item())
                    l_style_history.append(l_style.item())
                    l_total_history.append(l_total.item())

                    # Total
                    l_feat_total_history.append(l_feat_total / (batch_id + 1))
                    l_style_total_history.append(l_style_total / (batch_id + 1))
                    l_total_total_history.append((l_feat_total + l_style_total) / (batch_id + 1))

                    if (batch_id + 1) % log_interval == 0:
                        msg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                            time.ctime(), epoch + 1, count, len(self.gen.train),
                                          l_feat_total / (batch_id + 1),
                                          l_style_total / (batch_id + 1),
                                          (l_feat_total + l_style_total) / (batch_id + 1)
                        )
                        print(msg)


                    if checkpoint_model_dir is not None and (batch_id + 1) % checkpoint_interval == 0:
                        self.transformer.eval().cpu()
                        filename = "check_epoch_" + str(epoch) + "batch_id" + str(batch_id + 1) + '.pth'
                        path = os.path.join(checkpoint_model_dir, filename)
                        print('\nSaving model:', path, '\n')
                        torch.save(self.transformer.state_dict(), path)
                        np.savetxt("/content/drive/MyDrive/style_save/losses/l_total_history.csv", l_total_history, delimiter=",")
                        np.savetxt("/content/drive/MyDrive/style_save/losses/l_feat_history.csv", l_feat_history, delimiter=",")
                        np.savetxt("/content/drive/MyDrive/style_save/losses/l_style_history.csv", l_style_history, delimiter=",")
                        np.savetxt("/content/drive/MyDrive/style_save/losses/l_total_total_history.csv", l_feat_total_history, delimiter=",")
                        np.savetxt("/content/drive/MyDrive/style_save/losses/l_feat_total_history.csv", l_feat_total_history, delimiter=",")
                        np.savetxt("/content/drive/MyDrive/style_save/losses/l_style_total_history.csv", l_style_total_history, delimiter=",")
                        #transformer.to(device).train() ?

        transformer.eval().cpu()
        
        if checkpoint_model_dir is not None:
            filename = "check_epoch_" + str(epoch) + "batch_id" + str(batch_id + 1) + '.pth'
            path = os.path.join(checkpoint_model_dir, filename)
            torch.save(transformer.state_dict(), path)
    


In [72]:
def main():
    '''  
    Parse args, such as data dir, style image, etc.
    Call factory object and train
    '''
    
    #parser = argparse.ArgumentParser(description='Arguments for the training.')

    model = StyleTransferFactory(ALI_DATA_DIR, ALI_STYLE_PATH, BATCH_SIZE)
    model.train()

In [73]:
main()

epoch: 0
Sat May  1 16:05:25 2021	Epoch 1:	[4/16958]	content: 880297.000000	style: 24060.605469	total: 904357.605469
Sat May  1 16:05:36 2021	Epoch 1:	[8/16958]	content: 840231.937500	style: 18853.376953	total: 859085.314453
Sat May  1 16:05:46 2021	Epoch 1:	[12/16958]	content: 784633.020833	style: 19888.275391	total: 804521.296224
Sat May  1 16:05:57 2021	Epoch 1:	[16/16958]	content: 726705.390625	style: 22702.447266	total: 749407.837891
Sat May  1 16:06:07 2021	Epoch 1:	[20/16958]	content: 727115.812500	style: 21483.453906	total: 748599.266406
Sat May  1 16:06:18 2021	Epoch 1:	[24/16958]	content: 726448.197917	style: 21370.490560	total: 747818.688477
Sat May  1 16:06:28 2021	Epoch 1:	[28/16958]	content: 713729.258929	style: 20865.926618	total: 734595.185547
Sat May  1 16:06:39 2021	Epoch 1:	[32/16958]	content: 714504.500000	style: 20374.876465	total: 734879.376465
Sat May  1 16:06:49 2021	Epoch 1:	[36/16958]	content: 703643.583333	style: 19559.160048	total: 723202.743381
Sat May  1 1

KeyboardInterrupt: ignored

# SCRATCH

In [None]:
loss_net = VGG16LossNN()
transformer = ImageTransformationNN()
gen = Generator(ALI_DATA_DIR, BATCH_SIZE)
style = load_rgb_img(ALI_STYLE_PATH)

In [None]:
summary(loss_net, (3, 391, 470))

In [None]:
summary(transformer, (3, 256, 256))

In [None]:
# Parameters ( to be placed in parser )
learning_rate = 1e-3 
epochs = 2
content_weights = 1e5
style_weights = 1e10
log_interval = 1
checkpoint_model_dir = "model/"
checkpoint_interval = 0
#! 


In [None]:
optimizer = Adam(self.transformer.parameters(), learning_rate)

mse_loss = MSELoss()

style_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.mul(255))
])



In [None]:
            
style = style_transform(self.style)
style = style.repeat(BATCH_SIZE, 1, 1, 1)

loss_net_input = normalize_batch(style)

feature_style = self.loss_net.forward(normalize_batch(style))

"""Verification"""
for key, value in feature_style.items() :
    print(key)

gram_style = [gram_matrix(y) for y in feature_style.values()]
print('GRAM:', gram_style[0].shape)
    

In [25]:
a = np.array([12, 13, 42, 53])

In [44]:
np.savetxt("/content/drive/MyDrive/style_save/checkpoints/foo.csv", a, delimiter=",")


In [43]:
np.loadtxt(open("/content/drive/MyDrive/style_save/losses/l_total_history.csv"", "rb"), delimiter=",")


array([12., 13., 42., 53.])