In [1]:
from generator import Generator, load_rgb_img
from ImageTransformationNN import ImageTransformationNN
from VGG16 import VGG16LossNN

import time, argparse

from torch.optim import Adam
from torch.nn import MSELoss

from torchvision import transforms



In [2]:
# To be moved to utils.py

def gram_matrix(y):
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t) / (ch * h * w)
    return gram

def normalize_batch(batch):
    # normalize using imagenet mean and std
    mean = batch.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = batch.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    batch = batch.div_(255.0)
    res = (batch - mean) / std
    print('r', res.size())
    return res
    

In [3]:
ALI_DATA_DIR=r"/Users/aliissaoui/Desktop/studies/IIT/spring/CS577/Project/PerceptualLossNetwork/dataset/train_images/"
ALI_STYLE_DIR=r"/Users/aliissaoui/Desktop/studies/IIT/spring/CS577/Project/PerceptualLossNetwork/dataset/train_styles"
ALI_STYLE_PATH = f"{ALI_STYLE_DIR}/mosaic.jpeg"
BATCH_SIZE = 4
DIM = 256

In [4]:
# Parser: Temporarily disabled for jupyter notebook
"""parser = argparse.ArgumentParser(description='Arguments for the network.')

parser.add_argument('learning_rate', type=float,
                    help='Learning rate')

parser.add_argument('epochs', type=int,
                    help='Number of epochs')

parser.add_argument('content_weights', type=int,
                    help='Content weights')

parser.add_argument('log_interval', type=int,
                    help='Integer for Log interval')

parser.add_argument('checkpoint_model_dir', type=str,
                    help='Path for checkpoints directory')  

parser.add_argument('checkpoint_interval', type=int,
                    help='Integer for checkpoints interval') 

args = parser.parse_args()"""

"parser = argparse.ArgumentParser(description='Arguments for the network.')\n\nparser.add_argument('learning_rate', type=float,\n                    help='Learning rate')\n\nparser.add_argument('epochs', type=int,\n                    help='Number of epochs')\n\nparser.add_argument('content_weights', type=int,\n                    help='Content weights')\n\nparser.add_argument('log_interval', type=int,\n                    help='Integer for Log interval')\n\nparser.add_argument('checkpoint_model_dir', type=str,\n                    help='Path for checkpoints directory')  \n\nparser.add_argument('checkpoint_interval', type=int,\n                    help='Integer for checkpoints interval') \n\nargs = parser.parse_args()"

In [5]:
class StyleTransferFactory(object):
    
    def __init__(self, DATA_DIR, STYLE_PATH, BATCH_SIZE):
        self.loss_net = VGG16LossNN()
        self.transformer = ImageTransformationNN()
        self.gen = Generator(DATA_DIR, BATCH_SIZE)
        self.style = load_rgb_img(STYLE_PATH)
        #print("shape:", self.style.shape)
    
    
    def train(self):
        
        '''
        Should train network and
        - save model file
        - save training info, validaiton loss, etc
        '''

        
        # Parameters ( to be placed in parser )
        learning_rate = 1e-3 
        epochs = 2
        content_weights = 1e5
        style_weights = 1e10
        log_interval = 1
        checkpoint_model_dir = None
        checkpoint_interval = 0
        #! 
        
        optimizer = Adam(self.transformer.parameters(), learning_rate)

        mse_loss = MSELoss()

        style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
        ])
        
        style = style_transform(self.style)
        style = style.repeat(BATCH_SIZE, 1, 1, 1)
        
        loss_net_input = normalize_batch(style)
        
        feature_style = self.loss_net.forward(normalize_batch(style))
        
        """Verification"""
        for key, value in feature_style.items() :
            print(key)
        
        gram_style = [gram_matrix(y) for y in feature_style.values()]
        print('GRAM:', gram_style[0].shape)
    
              
        for epoch in range(epochs):
            print('epoch:', epoch)
            self.transformer.train()
            print('-------------\n', self.transformer, '\n-------------')

            l_feat_total = 0.
            l_style_total = 0.
            count = 0
            
            # To revise
            for batch_id, (x, _) in enumerate(self.gen):
                
                # Add the batch size
                n_batch = len(x)
                count += n_batch
                
                # Adam
                optimizer.zero_grad()
                
                # Problem here !!! 
                y = self.transformer(x)
                input()
                
                # Normalize batch
                y = normalize_batch(y)
                x = normalize_batch(x)
                
                # Features from the VGG16 network
                features_y = self.loss_net(y)
                features_x = self.loss_net(x)
                
                # Update features reconstruction loss
                l_feat = content_weights * mse_loss(features_y['relu2_2'], features_x['relu2_2'])
                
                
                l_style = 0.
                
                for f_y, gram_s in zip(features_y, gram_style):
                    gram_y = gram_matrix(f_y)
                    l_style += mse_loss(gram_y, gram_s[:n_batch, :, :])
                    
                l_style *= style_weights
                
                l_total = l_feat + l_style
                
                l_total.backward()
                optimizer.step()
                
                l_feat_total += l_feat.item()
                l_style_total += l_style.item()
                
                if (batch_id + 1) % log_interval == 0:
                    msg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                        time.ctime(), epoch + 1, count, len(train_dataset),
                                      l_feat_total / (batch_id + 1),
                                      l_style_total / (batch_id + 1),
                                      (l_feat_total + l_style_total) / (batch_id + 1)
                    )
                    print(msg)
                    
                if checkpoint_model_dir is not None and (batch_id + 1) % checkpoint_interval == 0:
                    transformer.eval().cpu()
                    filename = "check_epoch_" + str(epoch) + "batch_id" + str(batch_id + 1) + '.pth'
                    path = os.path.join(checkpoint_model_dir, filename)
                    torch.save(transformer.state_dict(), path)
                    #transformer.to(device).train() ?
                    
        transformer.eval().cpu()
        
        if checkpoint_model_dir is not None:
            filename = "check_epoch_" + str(epoch) + "batch_id" + str(batch_id + 1) + '.pth'
            path = os.path.join(checkpoint_model_dir, filename)
            torch.save(transformer.state_dict(), path)
    


In [6]:
def main():
    '''  
    Parse args, such as data dir, style image, etc.
    Call factory object and train
    '''
    
    #parser = argparse.ArgumentParser(description='Arguments for the training.')

    model = StyleTransferFactory(ALI_DATA_DIR, ALI_STYLE_PATH, BATCH_SIZE)
    model.train()

In [7]:
main()

r torch.Size([4, 3, 391, 470])
r torch.Size([4, 3, 391, 470])
relu1_2
relu2_2
relu3_3
relu4_3
GRAM: torch.Size([4, 64, 64])
epoch: 0
-------------
 ImageTransformationNN(
  (down_sample): DownSampleConv(
    (conv2d1): Conv2d(3, 32, kernel_size=(9, 9), stride=(1, 1))
    (norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (relu1): ReLU()
    (conv2d2): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2))
    (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (relu2): ReLU()
    (conv2d3): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2))
    (norm3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (relu3): ReLU()
  )
  (res): ResidualNet(
    (block1): RBlock(
      (conv2d1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      (norm1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (relu): ReLU()
      (conv

RuntimeError: Given groups=1, weight of size [128, 3, 3, 3], expected input[4, 128, 59, 59] to have 3 channels, but got 128 channels instead