In [1]:
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F

from resnet_features import resnet18_features, resnet34_features, resnet50_features, resnet101_features, resnet152_features
from densenet_features import densenet121_features, densenet161_features, densenet169_features, densenet201_features
from vgg_features import vgg11_features, vgg11_bn_features, vgg13_features, vgg13_bn_features, vgg16_features, vgg16_bn_features,\
                         vgg19_features, vgg19_bn_features

from receptive_field import compute_proto_layer_rf_info_v2

from settings import img_size

from PIL import Image
import numpy as np
import numpy.random as npr
import pandas as pd
import os
import matplotlib.pyplot as plt
import torch.utils.data
# import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim

import pickle as pkl
import skimage as sk
import skimage.io as skio
from preference_model import construct_PrefNet, paired_cross_entropy_loss, PrefNet

# book keeping namings and code
from settings import base_architecture, img_size, prototype_shape, num_classes, \
                     prototype_activation_function, add_on_layers_type, experiment_run

from preprocess import mean, std, preprocess_input_function
from tqdm import tqdm

In [2]:
base_architecture_to_features = {'resnet18': resnet18_features,
                                 'resnet34': resnet34_features,
                                 'resnet50': resnet50_features,
                                 'resnet101': resnet101_features,
                                 'resnet152': resnet152_features,
                                 'densenet121': densenet121_features,
                                 'densenet161': densenet161_features,
                                 'densenet169': densenet169_features,
                                 'densenet201': densenet201_features,
                                 'vgg11': vgg11_features,
                                 'vgg11_bn': vgg11_bn_features,
                                 'vgg13': vgg13_features,
                                 'vgg13_bn': vgg13_bn_features,
                                 'vgg16': vgg16_features,
                                 'vgg16_bn': vgg16_bn_features,
                                 'vgg19': vgg19_features,
                                 'vgg19_bn': vgg19_bn_features}


class PrefNet(nn.Module):

    def __init__(self, img_features, pattern_features, img_size, prototype_shape,
                 proto_layer_rf_info, num_classes, init_weights=False,
                 prototype_activation_function='log',
                 add_on_layers_type='bottleneck', 
                k = 1):

        super(PrefNet, self).__init__()
        self.img_size = img_size
        self.prototype_shape = prototype_shape
        self.num_prototypes = prototype_shape[0]
        self.num_classes = num_classes
        self.epsilon = 1e-4
        self.k = k
        
        # this has to be named features to allow the precise loading
        self.img_features = img_features
        self.pattern_features = pattern_features
        
        '''
        features_name = str(self.features).upper()
        if features_name.startswith('VGG') or features_name.startswith('RES'):
            first_add_on_layer_in_channels = \
                [i for i in features.modules() if isinstance(i, nn.Conv2d)][-1].out_channels
        elif features_name.startswith('DENSE'):
            first_add_on_layer_in_channels = \
                [i for i in features.modules() if isinstance(i, nn.BatchNorm2d)][-1].num_features
        else:
            raise Exception('other base base_architecture NOT implemented')
        '''
        
        '''    
        if add_on_layers_type == 'bottleneck':
            add_on_layers = []
            current_in_channels = first_add_on_layer_in_channels
            while (current_in_channels > self.prototype_shape[1]) or (len(add_on_layers) == 0):
                current_out_channels = max(self.prototype_shape[1], (current_in_channels // 2))
                add_on_layers.append(nn.Conv2d(in_channels=current_in_channels,
                                               out_channels=current_out_channels,
                                               kernel_size=1))
                add_on_layers.append(nn.ReLU())
                add_on_layers.append(nn.Conv2d(in_channels=current_out_channels,
                                               out_channels=current_out_channels,
                                               kernel_size=1))
                if current_out_channels > self.prototype_shape[1]:
                    add_on_layers.append(nn.ReLU())
                else:
                    assert(current_out_channels == self.prototype_shape[1])
                    add_on_layers.append(nn.Sigmoid())
                current_in_channels = current_in_channels // 2
            self.add_on_layers = nn.Sequential(*add_on_layers)
        else:
            self.add_on_layers = nn.Sequential(
                nn.Conv2d(in_channels=first_add_on_layer_in_channels, out_channels=self.prototype_shape[1], kernel_size=1),
                nn.ReLU(),
                nn.Conv2d(in_channels=self.prototype_shape[1], out_channels=self.prototype_shape[1], kernel_size=1),
                nn.Sigmoid()
                )
        
        '''
                
#         self.prototype_vectors = nn.Parameter(torch.rand(self.prototype_shape),
#                                               requires_grad=True)

#         # do not make this just a tensor,
#         # since it will not be moved automatically to gpu
#         self.ones = nn.Parameter(torch.ones(self.prototype_shape),
#                                  requires_grad=False)

#         self.last_layer = nn.Linear(self.num_prototypes, self.num_classes,
#                                     bias=False) # do not use bias
    

        self.img_conv = nn.Sequential(
            nn.Conv2d(in_channels=2048, out_channels=512, kernel_size=5),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=128, kernel_size=3),
            nn.Sigmoid()
            )
        
        self.pattern_conv = nn.Sequential(
            nn.Conv2d(in_channels=2048, out_channels=512, kernel_size=5),
            nn.ReLU(),
            nn.Conv2d(in_channels=512, out_channels=128, kernel_size=3),
            nn.Sigmoid()
            )
        
        self.final_fc = nn.Sequential(
            nn.Linear(256, 32),
            nn.ReLU(),
            nn.Sigmoid(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        
            )
        
        #self.fc1 = nn.Linear(6400, 512)
        #self.fc2 = nn.Linear(512, 32)
        #self.fc3 = nn.Linear(32, 1)
        #self.fc1 = nn.Linear(64, 16)
        

        if init_weights:
            self._initialize_weights()
            
            
    def conv_features(self, x):
        '''
        the feature input to prototype layer
        '''
        # Insert k and then img size
        x = self.features(x)
        #print("base features: ", x.shape)
        #x = self.add_on_layers(x)
        return x
    
    def forward(self, x, p):
        # (N, 512, 7, 7)
        #x = self.conv_features(x)
        x = self.img_features(x)
        #x = self.add_on_layers(x)
        x = self.img_conv(x)
        #print("conv out:", x.shape)
        p = self.pattern_features(p)
        #p = self.conv_features(p)
        #p = self.add_on_layers(p)
        p = self.pattern_conv(p)
        
        out = torch.cat((x, p), dim=1)
        #print("cat out shape: ", out.shape)
        out = torch.flatten(out, 1) 
        #print("flatten out shape: ", out.shape)
        
        out = self.final_fc(out)
        #out = torch.sigmoid(self.fc1(out))
        #out = torch.sigmoid(self.fc2(out))
        #out = self.fc3(out)
        return out

    
    def _initialize_weights(self):
        for m in self.img_conv.modules():
            if isinstance(m, nn.Conv2d):
                # every init technique has an underscore _ in the name
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                
        for m in self.pattern_conv.modules():
            if isinstance(m, nn.Conv2d):
                # every init technique has an underscore _ in the name
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)




            
def construct_PrefNet(base_architecture, pretrained=True, img_size=224,
                    prototype_shape=(1000, 128, 1, 1), num_classes=200,
                    prototype_activation_function='log',
                    add_on_layers_type='bottleneck',
                    k = 1):
    img_features = base_architecture_to_features[base_architecture](pretrained=pretrained)
    pattern_features = base_architecture_to_features[base_architecture](pretrained=pretrained)
    layer_filter_sizes, layer_strides, layer_paddings = img_features.conv_info()
    proto_layer_rf_info = compute_proto_layer_rf_info_v2(img_size=img_size,
                                                         layer_filter_sizes=layer_filter_sizes,
                                                         layer_strides=layer_strides,
                                                         layer_paddings=layer_paddings,
                                                         prototype_kernel_size=prototype_shape[2])
    return PrefNet(img_features=img_features,
                   pattern_features=pattern_features,
                 img_size=img_size,
                 prototype_shape=prototype_shape,
                 proto_layer_rf_info=proto_layer_rf_info,
                 num_classes=num_classes,
                 init_weights=True,
                 prototype_activation_function=prototype_activation_function,
                 add_on_layers_type=add_on_layers_type,
                 k = k)


def paired_cross_entropy_loss(out1, out2, targets):
    
    total_loss = 0
    for i in range(len(targets)):
        
        if targets[i] == -1:
            p1 = torch.exp(out1[i])/(torch.exp(out1[i]) + torch.exp(out2[i]))
            loss = - torch.log(p1)
        elif targets[i] == 1:
            p2 = torch.exp(out2[i])/(torch.exp(out1[i]) + torch.exp(out2[i]))
            loss = - torch.log(p2)

        else:
            p1 = torch.exp(out1[i])/(torch.exp(out1[i]) + torch.exp(out2[i]))
            p2 = torch.exp(out2[i])/(torch.exp(out1[i]) + torch.exp(out2[i]))
            loss = - (0.5*torch.log(p1) + 0.5*torch.log(p2))
            
        total_loss += loss
    return total_loss


In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

Using cuda device


In [4]:
normalize = transforms.Normalize(mean=mean,
                                 std=std)

trans = transforms.Compose([
    transforms.Resize(size=(img_size, img_size)),
    transforms.ToTensor(),
    normalize
])

In [5]:
k = 1
csv_name = "./human_comparisons/rating_s=5_k=1_700_random_1.csv"
if os.path.exists(csv_name):
    comp_df = pd.read_csv(csv_name)

In [6]:
split = 0.7
df_len = len(comp_df)
train_set = []
test_set = []
split_idx = int(df_len*split)
for i in range(split_idx):
    for j in range(i+1, split_idx):
        if comp_df.iloc[i]['rating'] > comp_df.iloc[j]['rating']:
            train_set.append([i, j, -1])
        elif comp_df.iloc[i]['rating'] < comp_df.iloc[j]['rating']:
            train_set.append([i, j, 1])
            
for i in range(split_idx, df_len):
    for j in range(i+1, df_len):
        if comp_df.iloc[i]['rating'] > comp_df.iloc[j]['rating']:
            test_set.append([i, j, -1])
        elif comp_df.iloc[i]['rating'] < comp_df.iloc[j]['rating']:
            test_set.append([i, j, 1])
print(len(train_set))
print(len(test_set))

71875
13831


In [7]:
images = []
patterns = []
for i in range(df_len):
    img = './human_comparisons/feedback_images/k=1_random/original_imgs/' + comp_df.iloc[i]['imgid'] + '.png'
    img = plt.imread(img)[:, :, :3]
    img = np.transpose(img, (2, 0, 1))
    images.append(torch.from_numpy(np.array([img])))
    pattern = './human_comparisons/feedback_images/k=1_random/patterns/' + comp_df.iloc[i]['imgid'] + '.npy'
    pattern = np.load(pattern)
    pattern = np.array([pattern, pattern, pattern])
    patterns.append(torch.from_numpy(np.array([pattern])))
print(len(images))
print(images[100].shape)
print(patterns[100].shape)

700
torch.Size([1, 3, 224, 224])
torch.Size([1, 3, 224, 224])


In [39]:
prefnet = construct_PrefNet("resnet50")
prefnet.to(device)
prefnet.train()

for p in prefnet.final_fc.parameters():
    p.requires_grad = True
for p in prefnet.img_features.parameters():
    p.requires_grad = True
for p in prefnet.pattern_features.parameters():
    p.requires_grad = True
for p in prefnet.img_conv.parameters():
    p.requires_grad = True
for p in prefnet.pattern_conv.parameters():
    p.requires_grad = True

pref_optimizer = optim.Adam([{'params': prefnet.img_features.parameters(), 'lr': 1e-4, 'weight_decay':1e-4}, {'params': prefnet.pattern_features.parameters(), 'lr': 1e-4, 'weight_decay':1e-4},
                             #{'params': prefnet.add_on_layers.parameters(), 'lr': 1e-4}, 
                             {'params': prefnet.img_conv.parameters(), 'lr': 1e-4, 'weight_decay':1e-4}, {'params': prefnet.pattern_conv.parameters(), 'lr': 1e-4, 'weight_decay':1e-4}, 
                             {'params': prefnet.final_fc.parameters(), 'lr': 1e-4, 'weight_decay':1e-4}
                             #{'params': prefnet.fc1.parameters(), 'lr': 1e-4}, {'params': prefnet.fc2.parameters(), 'lr': 1e-4}, {'params': prefnet.fc3.parameters(), 'lr': 1e-4}
                             ])

In [56]:
lr = 1e-5
weight_decay = 1e-5
pref_optimizer = optim.Adam([{'params': prefnet.img_features.parameters(), 'lr': lr, 'weight_decay':weight_decay}, {'params': prefnet.pattern_features.parameters(), 'lr': lr, 'weight_decay':weight_decay},
                             #{'params': prefnet.add_on_layers.parameters(), 'lr': 1e-4}, 
                             {'params': prefnet.img_conv.parameters(), 'lr': lr, 'weight_decay':weight_decay}, {'params': prefnet.pattern_conv.parameters(), 'lr': lr, 'weight_decay':weight_decay}, 
                             {'params': prefnet.final_fc.parameters(), 'lr': lr, 'weight_decay':weight_decay}
                             #{'params': prefnet.fc1.parameters(), 'lr': 1e-4}, {'params': prefnet.fc2.parameters(), 'lr': 1e-4}, {'params': prefnet.fc3.parameters(), 'lr': 1e-4}
                             ])

In [57]:
epochs = 5
batch_size = 16

In [58]:
prefnet(images[9].cuda(), patterns[9].cuda())

tensor([[0.0273]], device='cuda:0', grad_fn=<SigmoidBackward0>)

In [59]:
def test_reward_model(prefnet, test_set, images, patterns, batch_size):
    acc = []
    error_count = 0
    all_idx = np.arange(len(test_set))
    for batch_i in tqdm(range(len(test_set)//batch_size + 1)):
        prefnet.eval()
        
        idx = all_idx[batch_i*batch_size:(batch_i+1)*batch_size]
        
        left_imgs = torch.zeros((batch_size, 3, 224, 224))
        right_imgs = torch.zeros((batch_size, 3, 224, 224))
        left_patterns = torch.zeros((batch_size, 3, 224, 224))
        right_patterns = torch.zeros((batch_size, 3, 224, 224))
        targets = []
        for i in range(len(idx)):
            index = idx[i]
            left_imgs[i] = images[test_set[index][0]][0]
            right_imgs[i] = images[test_set[index][1]][0]
            targets.append(test_set[index][2])
            left_patterns[i] = patterns[test_set[index][0]][0]
            right_patterns[i] = patterns[test_set[index][1]][0]
        
        targets = torch.tensor(targets).cuda().float()
        
        out1 = prefnet(left_imgs.cuda().float(), left_patterns.cuda().float())
        out2 = prefnet(right_imgs.cuda().float(), right_patterns.cuda().float())

        
        
        for i in range(len(targets)):

            if out1[i] > out2[i]:
                y_pred = -1

            else:
                y_pred = 1

            
            if y_pred == targets[i]:
                acc.append(1)
        else:
            
            error_count += 1
            acc.append(0)
            
    return np.mean(acc), error_count

In [60]:
for epoch in range(epochs):
    shuffled_idx = np.random.permutation(len(train_set))
    for batch_i in range(len(train_set)//batch_size + 1):
        prefnet.train()
        if batch_i % 100 == 0:
            last_100_losses = []
            last_100_error_count = 0
        idx = shuffled_idx[batch_i*batch_size:(batch_i+1)*batch_size]
        
        left_imgs = torch.zeros((batch_size, 3, 224, 224))
        right_imgs = torch.zeros((batch_size, 3, 224, 224))
        left_patterns = torch.zeros((batch_size, 3, 224, 224))
        right_patterns = torch.zeros((batch_size, 3, 224, 224))
        targets = []
        for i in range(len(idx)):
            index = idx[i]
            left_imgs[i] = images[train_set[index][0]][0]
            right_imgs[i] = images[train_set[index][1]][0]
            targets.append(train_set[index][2])
            left_patterns[i] = patterns[train_set[index][0]][0]
            right_patterns[i] = patterns[train_set[index][1]][0]
        
        targets = torch.tensor(targets).cuda().float()
        
        out1 = prefnet(left_imgs.cuda().float(), left_patterns.cuda().float())
        out2 = prefnet(right_imgs.cuda().float(), right_patterns.cuda().float())

        
        pref_optimizer.zero_grad()   
        
        for i in range(len(targets)):
            if out1[i] > out2[i] and targets[i] == 1:
                last_100_error_count += 1

            elif out1[i] < out2[i] and targets[i] == -1:
                last_100_error_count += 1
                
        loss = paired_cross_entropy_loss(out1, out2, targets)
        
        loss.backward()
        pref_optimizer.step()   
        
        last_100_losses.append(loss.data.cpu().numpy()[0])
        
        if batch_i % 100 == 0:
            print(epoch, batch_i, np.sum(last_100_losses))
        if batch_i % 100 == 99:
            print(epoch, batch_i, last_100_error_count)
        if batch_i % 2000 == 1999:
            
            test_acc, test_error_count = test_reward_model(prefnet, test_set, images, patterns, batch_size)
            print(epoch, batch_i, test_acc, test_error_count)

0 0 7.7368217
0 99 98
0 100 6.6764493
0 199 95
0 200 7.047304
0 299 101
0 300 5.9509435
0 399 93
0 400 6.566382
0 499 92
0 500 6.3128304
0 599 90
0 600 6.678241
0 699 104
0 700 6.6762857
0 799 90
0 800 5.957803
0 899 89
0 900 7.04094
0 999 109
0 1000 6.305628
0 1099 108
0 1100 7.4081235
0 1199 109
0 1200 6.694401
0 1299 105
0 1300 5.564393
0 1399 105
0 1400 6.668639
0 1499 94
0 1500 7.006999
0 1599 115
0 1600 6.6676965
0 1699 99
0 1700 6.664954
0 1799 105
0 1800 7.4009433
0 1899 123
0 1900 6.294656
0 1999 103


100%|██████████| 865/865 [00:46<00:00, 18.63it/s]


0 1999 0.9142800515310673 865
0 2000 7.0306573
0 2099 120
0 2100 6.2919474
0 2199 101
0 2200 7.029771
0 2299 96
0 2300 6.6597095
0 2399 102
0 2400 6.289434
0 2499 100
0 2500 6.6577916
0 2599 104
0 2600 5.916775
0 2699 106
0 2700 6.6560087
0 2799 126
0 2800 6.28539
0 2899 100
0 2900 6.2849874
0 2999 69
0 3000 5.913496
0 3099 91
0 3100 6.282184
0 3199 112
0 3200 6.6511188
0 3299 105
0 3300 6.49632
0 3399 94
0 3400 6.279418
0 3499 96
0 3500 5.9084477
0 3599 104
0 3600 6.647893
0 3699 114
0 3700 6.56309
0 3799 124
0 3800 5.5354056
0 3899 114
0 3900 7.3863378


KeyboardInterrupt: 

In [61]:
torch.save(prefnet, './human_comparisons/pref_model_700_random_rating_split0.7_acc0.915.pth')

In [65]:
for i in range(df_len):
    score = prefnet(images[i].cuda(), patterns[i].cuda()).data
    print(i, comp_df.iloc[i]['rating'], score)

0 5.0 tensor([[0.9832]], device='cuda:0')
1 5.0 tensor([[0.9832]], device='cuda:0')
2 5.0 tensor([[0.9832]], device='cuda:0')
3 3.0 tensor([[0.0178]], device='cuda:0')
4 4.0 tensor([[0.9792]], device='cuda:0')
5 4.0 tensor([[0.0396]], device='cuda:0')
6 2.0 tensor([[0.0178]], device='cuda:0')
7 5.0 tensor([[0.9831]], device='cuda:0')
8 5.0 tensor([[0.9831]], device='cuda:0')
9 1.0 tensor([[0.0187]], device='cuda:0')
10 1.0 tensor([[0.0194]], device='cuda:0')
11 4.0 tensor([[0.1502]], device='cuda:0')
12 5.0 tensor([[0.9832]], device='cuda:0')
13 4.0 tensor([[0.5484]], device='cuda:0')
14 4.0 tensor([[0.0179]], device='cuda:0')
15 5.0 tensor([[0.9831]], device='cuda:0')
16 5.0 tensor([[0.9831]], device='cuda:0')
17 5.0 tensor([[0.9831]], device='cuda:0')
18 5.0 tensor([[0.9831]], device='cuda:0')
19 1.0 tensor([[0.0179]], device='cuda:0')
20 3.0 tensor([[0.0184]], device='cuda:0')
21 5.0 tensor([[0.9832]], device='cuda:0')
22 2.0 tensor([[0.0178]], device='cuda:0')
23 2.0 tensor([[0.017