In [0]:
# Mount Google Drive at /content/drive; train_model writes its decoder
# checkpoints under "/content/drive/My Drive/...".
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
import os
import pickle
import requests

import numpy as np

from tqdm import tqdm
from PIL import Image
from pathlib import Path
from matplotlib import pyplot as plt
from collections import namedtuple

import torch
import torch.nn as nn
from torch.utils import data
from torchvision import transforms, models
import torchvision.transforms as T
from torchvision.utils import save_image

from shutil import copyfile, copy
from IPython.display import clear_output

# Local folder that receives the downloaded model checkpoints
Path("models").mkdir(parents=True, exist_ok=True)

## Load data

#### Download the required files (pretrained decoder and VGG encoder, fruit dataset, fruit classifier and test images) from Google Drive

In [2]:
def download_file_from_google_drive(id, destination):
    """Download a publicly shared Google Drive file to `destination`.

    Args:
        id: Google Drive file id (the name shadows the builtin but is kept so
            existing keyword callers keep working).
        destination: local path the file contents are written to.

    Raises:
        requests.HTTPError: if Google Drive answers with an error status, so an
            error page is never silently saved in place of the requested file.
    """
    URL = "https://docs.google.com/uc?export=download"

    # Context manager releases the session's pooled connections when done
    with requests.Session() as session:
        response = session.get(URL, params={'id': id}, stream=True)
        # Presumably Drive answers large files with a warning page whose
        # `download_warning*` cookie holds the token that confirms the download
        token = get_confirm_token(response)

        if token:
            response.close()
            params = {'id': id, 'confirm': token}
            response = session.get(URL, params=params, stream=True)

        with response:
            response.raise_for_status()
            save_response_content(response, destination)

def get_confirm_token(response):
    """Return the value of the first cookie whose name starts with
    'download_warning', or None if `response` carries no such cookie."""
    return next(
        (value
         for key, value in response.cookies.items()
         if key.startswith('download_warning')),
        None,
    )

def save_response_content(response, destination):
    """Stream the body of `response` into the file at `destination`.

    Progress is reported in bytes (tqdm scales the display to KB/MB). The
    total is taken from the Content-Length header when the server sends it.
    """
    CHUNK_SIZE = 32768

    content_length = response.headers.get('content-length')
    total_bytes = int(content_length) if content_length else None

    with open(destination, "wb") as f:
        print('\nDownloading file...')
        # Advance the bar by the real byte count of each chunk; counting loop
        # iterations (one per 32 KiB chunk) under a "KB" label misreports size.
        with tqdm(total=total_bytes, unit="B", unit_scale=True,
                  unit_divisor=1024, leave=True) as progress:
            for chunk in response.iter_content(CHUNK_SIZE):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)
                    progress.update(len(chunk))
        print('File was successfully downloaded!')

# Fetch every file the notebook needs: pretrained decoder and VGG encoder,
# fruit images for training, the fruit classifier with its index->class
# mapping, and the test images.
directory = 'models'
os.makedirs(directory, exist_ok=True)

downloads = [
    ('1w9r1NoYnn7tql1VYG3qDUzkbIks24RBQ', f'{directory}/decoder.pth'),
    ('1X7x314BlP4XwH76TV50hRjv0HyM9Xhhm', f'{directory}/vgg_relu4_1.pth'),
    ('1_NjdWZIv63Yb9uMn3XkZjCQP3Z5HjZtC', 'fruit_data.zip'),
    ('1jAFyliSAJuSvWIcG12GRjY3zUaTuDK5z', 'fruit_classifier.pth'),
    ('1bXB5cTFrmG96t6AAIGAQkwUGzVkjLYSy', 'id2class.pkl'),
    ('1Z_hsaSpXGJoHoA21BMgC6ci6GrunR09s', 'tests.zip'),
]

for file_id, destination in downloads:
    download_file_from_google_drive(file_id, destination)

0KB [00:00, ?KB/s]


Downloading file...


428KB [00:00, 1437.62KB/s]


File was successfully downloaded!


0KB [00:00, ?KB/s]


Downloading file...


429KB [00:00, 1202.15KB/s]


File was successfully downloaded!


0KB [00:00, ?KB/s]


Downloading file...


447KB [00:00, 921.95KB/s]


File was successfully downloaded!


51KB [00:00, 3641.45KB/s]


Downloading file...
File was successfully downloaded!



1KB [00:00, 2050.00KB/s]


Downloading file...
File was successfully downloaded!



11KB [00:00, 1525.50KB/s]


Downloading file...
File was successfully downloaded!





In [0]:
!unzip fruit_data.zip 
!unzip tests.zip
clear_output()

In [4]:
# Build flat training folders: every fruit image is copied into both `content/`
# and `style/`, prefixed with its class folder name to avoid file-name clashes.
os.makedirs('content', exist_ok=True)
os.makedirs('style', exist_ok=True)

for subdir in sorted(os.listdir('fruit_data')):
    for file in sorted(os.listdir(os.path.join('fruit_data', subdir))):
        source = os.path.join('fruit_data', subdir, file)
        flat_name = f'{subdir}_{file}'
        # copyfile raises on failure, so a missing or unreadable file stops the
        # cell instead of being silently skipped like a failed shell command.
        copyfile(source, os.path.join('content', flat_name))
        copyfile(source, os.path.join('style', flat_name))

0


## Implementation code

In [0]:
class Sampler(data.sampler.Sampler):
    """Infinite random sampler over the indices of `data_source`.

    Indices are yielded pass after pass; each pass is a fresh random
    permutation, so every index appears exactly once per pass. `__len__`
    reports 2 ** 31 so a DataLoader built on it never considers the stream
    exhausted; consumers pull batches with `next()`.

    Args:
        data_source: sized collection (e.g. a Dataset) to sample from.
        seed: optional seed for the sampler's private RNG; None draws fresh
            entropy. The global numpy RNG is never touched.

    Raises:
        ValueError: if `data_source` is empty (there is nothing to sample).
    """

    def __init__(self, data_source, seed=None):
        self.num_samples = len(data_source)
        if self.num_samples == 0:
            raise ValueError('Sampler needs a non-empty data_source')
        # Private generator: shuffling must not reseed or advance the global
        # numpy RNG that the rest of the notebook may rely on.
        self._rng = np.random.RandomState(seed)

    def __iter__(self):
        return iter(self.__sampler(self.num_samples))

    def __len__(self):
        return 2 ** 31

    def __sampler(self, n):
        while True:
            for index in self._rng.permutation(n):
                yield int(index)

In [0]:
class ImageDataset(data.Dataset):
    """Map-style dataset over the image files directly inside `root`.

    Each item is opened with PIL, converted to RGB and passed through
    `transform`.
    """

    def __init__(self, root, transform):
        super(ImageDataset, self).__init__()
        self.root = root
        # Sorted so the index -> file mapping is reproducible (glob order is
        # arbitrary); sub-directories are skipped because PIL cannot open them.
        self.paths = sorted(p for p in Path(self.root).glob('*') if p.is_file())
        self.transform = transform

    def __getitem__(self, index):
        path = self.paths[index]
        # The context manager closes the file; convert() returns a loaded copy.
        with Image.open(str(path)) as raw:
            img = raw.convert('RGB')
        return self.transform(img)

    def __len__(self):
        return len(self.paths)

    def name(self):
        return 'ImageDataset'

#### Model structure

In [0]:
# AdaIN decoder: maps the 512-channel AdaIN output back to a 3-channel image.
# It upsamples 2x three times (nearest neighbour), so the output height/width
# are 8x those of its input; every 3x3 conv is preceded by a 1-pixel reflection
# pad, so the convolutions themselves keep the spatial size.
# The final conv has no activation, so outputs are not clamped here.
# NOTE: nn.Sequential names parameters by position ("1.weight", "1.bias", ...),
# so the layer order must not change or state_dicts loaded from decoder_path /
# saved by train_model will no longer match.
decoder = nn.Sequential(
    # 512 -> 256 channels, then 2x upsample
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    # three 256 -> 256 convs, then 256 -> 128 and 2x upsample
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    # 128 -> 128, 128 -> 64, then 2x upsample
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    # 64 -> 64, then projection to 3 output channels (no activation)
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 3, (3, 3)),
)

# Pretrained VGG encoder (truncated at relu4_1 per the file name); StyleModel
# slices its first 31 children into four stages.
# NOTE: this unpickles a whole nn.Module, which can execute arbitrary code;
# only load trusted files. On PyTorch versions where torch.load defaults to
# weights_only=True this call needs weights_only=False — confirm the runtime's version.
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only runtime;
# train_model / test_model move the encoder to their device themselves.
vgg = torch.load('models/vgg_relu4_1.pth', map_location='cpu')

In [0]:
def get_mean_std(features, eps=1e-5):
    """Channel-wise mean and standard deviation of a feature map.

    Args:
        features: tensor of shape (N, C, ...), typically (N, C, H, W); the
            statistics are taken over all dimensions after the first two.
        eps: added to the (unbiased) variance before the square root so the
            std stays strictly positive and is safe to divide by.

    Returns:
        (mean, std), each of shape (N, C, 1, 1) so they broadcast against
        (N, C, H, W) feature maps.
    """
    n, c = features.size()[:2]
    # reshape (unlike view) also accepts non-contiguous inputs, e.g. after
    # transpose/permute or channels_last tensors
    flat = features.reshape(n, c, -1)
    features_std = torch.sqrt(flat.var(dim=2) + eps).view(n, c, 1, 1)
    features_mean = flat.mean(dim=2).view(n, c, 1, 1)
    return features_mean, features_std


def adaIN(content_f, style_f):
    """Adaptive instance normalization (AdaIN).

    Re-normalizes `content_f` so each (sample, channel) slice takes on the
    mean and std of the matching slice of `style_f`:
        AdaIN(x, y) = std(y) * (x - mean(x)) / std(x) + mean(y)
    """
    c_mean, c_std = get_mean_std(content_f)
    s_mean, s_std = get_mean_std(style_f)

    whitened = (content_f - c_mean.expand_as(content_f)) / c_std.expand_as(content_f)
    return whitened * s_std.expand_as(content_f) + s_mean.expand_as(content_f)

In [0]:
class StyleModel(nn.Module):
    """AdaIN training wrapper: frozen VGG encoder stages + trainable decoder.

    The encoder children are split into four consecutive stages ending at
    relu1_1, relu2_1, relu3_1 and relu4_1 (children [0:4], [4:11], [11:18],
    [18:31] of `vgg`). `forward` returns the (content, style) losses used to
    train the decoder.
    """

    _STAGES = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1')

    def __init__(self, vgg, decoder):
        super(StyleModel, self).__init__()
        self.decoder = decoder

        children = list(vgg.children())
        bounds = ((0, 4), (4, 11), (11, 18), (18, 31))
        for stage, (start, stop) in zip(self._STAGES, bounds):
            setattr(self, stage, nn.Sequential(*children[start:stop]))

        self.mse_loss = nn.MSELoss()

        # The pretrained encoder stays fixed; only the decoder learns.
        for stage in self._STAGES:
            getattr(self, stage).requires_grad_(False)

    def encode(self, source, intermediate=False):
        """Run `source` through the encoder stages.

        Returns the list [relu1_1, relu2_1, relu3_1, relu4_1] of activations
        when `intermediate` is True, otherwise only the relu4_1 activation.
        """
        activations = []
        current = source
        for stage in self._STAGES:
            current = getattr(self, stage)(current)
            activations.append(current)
        return activations if intermediate else activations[-1]

    def compute_style_loss(self, source, target):
        """Style loss: MSE of channel means plus MSE of channel stds."""
        (src_mu, src_sigma), (tgt_mu, tgt_sigma) = get_mean_std(source), get_mean_std(target)
        return self.mse_loss(src_mu, tgt_mu) + self.mse_loss(src_sigma, tgt_sigma)

    def forward(self, content, style):
        style_features = self.encode(style, intermediate=True)
        content_features = self.encode(content)

        # AdaIN target: content features re-styled with relu4_1 style statistics
        t = adaIN(content_features, style_features[-1])

        g_t = self.decoder(t)
        g_t_features = self.encode(g_t, intermediate=True)

        # Content loss on the deepest (relu4_1) features only
        loss_content = self.mse_loss(g_t_features[-1], t)
        # Style loss summed over all four encoder stages
        loss_style = sum(
            self.compute_style_loss(generated, reference)
            for generated, reference in zip(g_t_features, style_features)
        )
        return loss_content, loss_style

#### Training

In [0]:
def adjust_learning_rate(lr, lr_decay, optimizer, iteration_count):
    """Set every param group of `optimizer` to the inverse-time decayed rate
    lr / (1 + lr_decay * iteration_count)."""
    decayed_lr = lr / (1.0 + lr_decay * iteration_count)
    for group in optimizer.param_groups:
        group.update(lr=decayed_lr)

def train_model(vgg, decoder, content_path, style_path, lr,
                max_iter=10000, style_weight=10, save_iter=1000,
                save_dir="/content/drive/My Drive/CV_project/model_versions",
                lr_decay=5e-5, batch_size=8, num_workers=16):
    """Train `decoder` for AdaIN style transfer on top of the frozen `vgg` encoder.

    Content and style batches are drawn endlessly (see `Sampler`) from the
    image folders `content_path` and `style_path`; every image is resized to
    512x512 and randomly cropped to 256x256. Each iteration minimises
    `content_loss + style_weight * style_loss` with Adam at learning rate
    `lr / (1 + lr_decay * i)`.

    Args:
        vgg: pretrained encoder; StyleModel splits it into four stages and
            freezes it.
        decoder: decoder being trained (updated in place).
        content_path, style_path: directories of training images.
        lr: base learning rate.
        max_iter: number of optimisation steps.
        style_weight: weight of the style loss relative to the content loss.
        save_iter: checkpoint interval, in iterations.
        save_dir: directory receiving `decoder_iter_<i>.pth` checkpoints
            (created if missing).
        lr_decay: decay factor passed to `adjust_learning_rate`.
        batch_size, num_workers: DataLoader settings for both image streams.
    """
    # Fall back to CPU so the function still runs (slowly) without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    os.makedirs(save_dir, exist_ok=True)

    model = StyleModel(vgg, decoder)
    model.train()
    model.to(device)

    transform = transforms.Compose([
         transforms.Resize(size=(512, 512)),
         transforms.RandomCrop(256),
         transforms.ToTensor()
    ])

    content_dataset = ImageDataset(content_path, transform)
    style_dataset = ImageDataset(style_path, transform)

    # Sampler never runs out, so these iterators can be advanced with next() indefinitely.
    content_iter = iter(data.DataLoader(content_dataset, batch_size=batch_size,
                                        sampler=Sampler(content_dataset),
                                        num_workers=num_workers))

    style_iter = iter(data.DataLoader(style_dataset, batch_size=batch_size,
                                      sampler=Sampler(style_dataset),
                                      num_workers=num_workers))

    # Only the decoder is optimised; StyleModel freezes the encoder stages.
    optimizer = torch.optim.Adam(model.decoder.parameters(), lr=lr)

    for i in range(1, max_iter + 1):
        adjust_learning_rate(lr, lr_decay, optimizer, iteration_count=i)

        content_images = next(content_iter).to(device)
        style_images = next(style_iter).to(device)

        loss_c, loss_s = model(content_images, style_images)
        loss = loss_c + style_weight * loss_s

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f'\rIteration [{i}/{max_iter}]', end='')

        # `i` is 1-based and counts completed steps, so checkpoint `i` holds the
        # weights after exactly i updates and the final step is always saved.
        if i % save_iter == 0 or i == max_iter:
            state_dict = decoder.state_dict()
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].to(torch.device('cpu'))
            torch.save(state_dict, f'{save_dir}/decoder_iter_{i}.pth')

In [1]:
# Train for 10k iterations, checkpointing the decoder every 2500 iterations.
train_model(vgg, decoder,
            content_path='./content', style_path='./style',
            lr=1e-4, max_iter=10000, style_weight=12, save_iter=2500)

Iteration [10000/10000]


#### Testing

In [0]:
def test_model(vgg, decoder, decoder_path, test_paths, show_result=True, size=100):
    """Apply one style image to every content image in a folder and save the results.

    Args:
        vgg: encoder whose output feeds `adaIN` and `decoder`.
        decoder: decoder module; its weights are replaced with the state_dict
            stored at `decoder_path`.
        decoder_path: path of a decoder state_dict saved with `torch.save`.
        test_paths: object with attributes `content_path` (folder of content
            images), `style` (path of the style image) and `output_path`
            (folder for the results; created if missing).
        show_result: if True, display content, style and result images inline.
        size: value passed to `transforms.Resize` for content and style images.

    Each result is written as `<content stem>_style_<style stem>.jpg`.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    output_dir = Path(test_paths.output_path)
    output_dir.mkdir(exist_ok=True, parents=True)

    # Sorted for a reproducible order; directories are skipped because they
    # cannot be opened as images.
    content_paths = sorted(p for p in Path(test_paths.content_path).glob('*')
                           if p.is_file())

    style_path = Path(test_paths.style)

    decoder.eval()
    vgg.eval()

    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only runtime;
    # load_state_dict copies the tensors onto the decoder's own device.
    decoder.load_state_dict(torch.load(decoder_path, map_location='cpu'))

    vgg.to(device)
    decoder.to(device)

    if not content_paths:
        return

    transform = transforms.Compose([transforms.Resize(size),
                                    transforms.ToTensor()])

    # The style image is the same for every content image: load and encode it once.
    style_image = _open_rgb(style_path)
    style = transform(style_image).to(device).unsqueeze(0)
    with torch.no_grad():
        style_feat = vgg(style)

    for content_path in content_paths:
        content_image = _open_rgb(content_path)
        if show_result:
            print('Content:')
            display(content_image)
            print('Style:')
            display(style_image)
        content = transform(content_image).to(device).unsqueeze(0)

        # transfer style to content image
        with torch.no_grad():
            output = decoder(adaIN(vgg(content), style_feat))

        output_name = output_dir / '{:s}_style_{:s}.jpg'.format(
            content_path.stem, style_path.stem)

        save_image(output.cpu(), str(output_name))
        if show_result:
            print('Result:')
            display(Image.open(str(output_name)))


def _open_rgb(path):
    """Load the image at `path` as RGB (as ImageDataset does for training) and close the file."""
    with Image.open(str(path)) as img:
        return img.convert('RGB')

#### Classification


In [0]:
class FruitClassifier(nn.Module):
    """Small CNN classifying fruit images; used to check whether a stylized
    image is recognised as the style fruit.

    Args:
        classes: number of output classes.
        id2class: mapping from predicted class index to class name.
    """

    def __init__(self, classes, id2class):
        super(FruitClassifier, self).__init__()
        # The conv/pool sizes reduce e.g. a 3x100x100 input to a 64x1x1 map,
        # which is what Linear(64, 100) expects after Flatten; other input
        # sizes must also reduce to 1x1.
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, kernel_size=7, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(3),
            nn.Conv2d(64, 64, kernel_size=7),
            nn.ReLU(),
            nn.MaxPool2d(5),
            nn.Flatten(),
            nn.Linear(64, 100),
            nn.ReLU(),
            nn.Linear(100, classes)
        )
        self.id2class = id2class

    def forward(self, x):
        return self.model(x)

    def load_weights(self, path):
        """Load a state_dict for `self.model` from `path`.

        Tensors are mapped to CPU first so GPU-saved weights also load on a
        CPU-only runtime; load_state_dict then copies them onto the model's device.
        """
        self.model.load_state_dict(torch.load(path, map_location='cpu'))

    def predict(self, x):
        """Return the class name predicted for a batch holding a single image."""
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            logits = self.model(x)
        return self.id2class[logits.max(1)[1].cpu().item()]


def load_image(filename, device='cpu'):
    """Load an image file as a normalized 1xCxHxW tensor for FruitClassifier.

    The image is converted to RGB first, so grayscale or RGBA files still
    yield the 3 channels the 3-value Normalize expects. Pixels are scaled from
    [0, 1] to [-1, 1] by Normalize(mean=0.5, std=0.5).
    """
    transform = T.Compose([T.ToTensor(),
                           T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    # The context manager closes the file handle once the tensor is built.
    with Image.open(filename) as img:
        img_t = transform(img.convert('RGB'))
    return img_t.unsqueeze(0).to(device)


# Class-index -> fruit-name mapping used by FruitClassifier.predict.
# NOTE: pickle.load can execute arbitrary code; only unpickle trusted files.
with open('id2class.pkl', 'rb') as fd:
    id2class = pickle.load(fd)

NUM_FRUIT_CLASSES = 5  # output size of the saved classifier head
clf = FruitClassifier(NUM_FRUIT_CLASSES, id2class)
clf.load_weights('fruit_classifier.pth')

In [0]:
TestPaths = namedtuple('TestPaths', 
                       field_names=['content_path', 'style', 'output_path'])

def generate_test_results(decoder_path):
    """Stylize each test fruit folder with the style of every other fruit.

    For every style fruit folder in `tests/`, the alphabetically first file
    is the style image; it is applied to all images of every other fruit
    folder via `test_model`. Results go to
    `test_results/<content fruit>_2_<style fruit>/`. Any existing
    `test_results/` directory is deleted first.

    Uses the module-level `vgg` and `decoder` models.
    """
    # The notebook only imports copyfile/copy from shutil at the top.
    import shutil

    # Python equivalent of `rm -r test_results/`; unlike a shell call, real
    # failures (e.g. permissions) raise instead of being ignored.
    if os.path.isdir('test_results'):
        shutil.rmtree('test_results')

    fruit_dirs = sorted(d for d in os.listdir('tests')
                        if os.path.isdir(os.path.join('tests', d)))
    for style_fruit_dir in fruit_dirs:
        style_dir = os.path.join('tests', style_fruit_dir)
        style_im = os.path.join(style_dir, sorted(os.listdir(style_dir))[0])
        for content_fruit_dir in fruit_dirs:
            if content_fruit_dir == style_fruit_dir:
                continue
            paths = TestPaths(
                content_path=os.path.join('tests', content_fruit_dir),
                style=style_im,
                output_path=os.path.join(
                    'test_results', f'{content_fruit_dir}_2_{style_fruit_dir}'),
            )
            test_model(vgg, decoder, decoder_path, paths, show_result=False)

In [0]:
def eval_test_results(clf, path='test_results',
                      classes=('Apple', 'Banana', 'Cocos', 'Lemon', 'Orange')):
    """Measure how often the classifier labels a stylized image as its style fruit.

    Every sub-directory of `path` is named `<content>_2_<style>` and holds
    stylized images. An image counts as a success ("tricked") when
    `clf.predict` returns the style fruit.

    Args:
        clf: object with `predict(image_tensor) -> class name`.
        path: directory produced by `generate_test_results`.
        classes: fruit names always reported in the per-fruit dicts (in this
            order); styles not listed are appended when encountered.

    Returns:
        (accuracy, tricked_acc, tricked_cnt) where accuracy is successes over
        all checked images (0.0 if none), tricked_acc maps each style fruit to
        successes over the images checked for that style (0.0 if none), and
        tricked_cnt maps each style fruit to its number of successes.
    """
    tricked_cnt = dict.fromkeys(classes, 0)
    checked_cnt = dict.fromkeys(classes, 0)

    for result_dir in sorted(os.listdir(path)):
        result_path = os.path.join(path, result_dir)
        if not os.path.isdir(result_path):
            continue
        _, style = result_dir.split('_2_')
        tricked_cnt.setdefault(style, 0)
        checked_cnt.setdefault(style, 0)
        for result in sorted(os.listdir(result_path)):
            clf_class = clf.predict(load_image(os.path.join(result_path, result)))
            checked_cnt[style] += 1
            if clf_class == style:
                tricked_cnt[style] += 1

    total_checks = sum(checked_cnt.values())
    style_transfer_score = sum(tricked_cnt.values())
    accuracy = (float(style_transfer_score) / float(total_checks)
                if total_checks else 0.0)
    # Normalise by the number of images actually evaluated per style instead of
    # assuming a fixed test-set size.
    tricked_acc = {k: (v / checked_cnt[k] if checked_cnt[k] else 0.0)
                   for k, v in tricked_cnt.items()}
    return accuracy, tricked_acc, tricked_cnt

In [16]:
# Evaluate the checkpoint at decoder_path (the pretrained decoder downloaded
# into models/ above, not the one produced by train_model).
decoder_path = "models/decoder.pth"

print('Generating test results...', end='')
generate_test_results(decoder_path)
print(' -> Done!')

print('Evaluating test results...', end='')
res = eval_test_results(clf)
print(' -> Done!')

total_accuracy, per_fruit_accuracy, classified_counts = res
print('Total accuracy:', total_accuracy)
print('Per fruit accuracy:', per_fruit_accuracy)
print('Classified counts:', classified_counts)

Generating test results... -> Done!
Evaluating test results... -> Done!
Total accuracy: 0.46
Per fruit accuracy: {'Apple': 0.4666666666666667, 'Banana': 0.2, 'Cocos': 0.016666666666666666, 'Lemon': 0.7333333333333333, 'Orange': 0.8833333333333333}
Classified counts: {'Apple': 28, 'Banana': 12, 'Cocos': 1, 'Lemon': 44, 'Orange': 53}
