# Get model to use as encoder and comparator

In [None]:

from torchvision.datasets import ImageFolder
from src.utils.config import RESNET34_FULL, RESNET18_FULL, Config, GAMfig
from src.utils.config import BEETLE_DATASET, DEFAULT_TEST_PATH
from src.models import download_model, load_model_weights_and_metrics
import torchvision
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from src.GAM import GAM_fit, GAM_fit_save
import numpy as np
import copy
from src.utils.dataset import dataset_to_dataloaders, dataset_to_dataloaders_2
from src.utils.transforms import *

In [None]:
# Mini-batch size for the GAM training data (copied into GAM_DATASET['batch_size']).
batch_size: int = 16

In [None]:
# Project configuration object.
# NOTE(review): `config` is not referenced by the visible cells — confirm it is still needed.
config = Config()

In [None]:
# Dataset / architecture used to build the comparator model and load its weights.
dataset_config, model_config = BEETLE_DATASET, RESNET34_FULL

In [None]:
# Work on a shallow copy: plain assignment would only alias BEETLE_DATASET, so the
# overrides below would also change BEETLE_DATASET itself and `dataset_config`
# (the same object), which is later passed to download_model.
# Nested values are still shared; only top-level keys are replaced here.
GAM_DATASET = copy.copy(BEETLE_DATASET)
#GAM_DATASET['image_folder_path'] = './data/beetles/images_subset'
GAM_DATASET['data_augmentations'] = [
    Resize((224, 448)),
    ToTensor(),
    Normalize(0.5, 0.5),
]
GAM_DATASET['batch_size'] = batch_size
GAM_DATASET['training_data_ratio'] = 0.8
GAM_DATASET['validation_data_ratio'] = 0.5

In [None]:
# Build the comparator network and load its weights; the loader's return value is discarded.
model = download_model(model_config, dataset_config)
_ = load_model_weights_and_metrics(model, model_config)
# train(False) is what .eval() does; both return the same module.
model = model.train(False)
# NOTE(review): unlike net_G/net_D/enc, `model` is not moved to CUDA here —
# presumably download_model or GAM_fit handles its device; confirm.

# Data for Generator

In [None]:
# Preprocessing for the ImageFolder datasets below: resize to 224x448, convert to
# a tensor, then Normalize(0.5, 0.5) maps each channel from [0, 1] to [-1, 1].
transformer = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 448)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(0.5, 0.5),
    ]
)


In [None]:
# Loaders and split sizes built from GAM_DATASET; consumed below as
# gamfig['dataloaders'] and gamfig['datasizes'].
data_loaders, dataset_sizes = dataset_to_dataloaders_2(
    GAM_DATASET
)

In [None]:
# Small ImageFolder over the beetle image subset.
# NOTE(review): `gamset_sub` / `dataloader_sub` are not used by the visible cells.
gamset_sub = ImageFolder('./data/beetles/images_subset/', transform=transformer)
dataloader_sub = DataLoader(
    gamset_sub,
    batch_size=8,
    num_workers=12,
    shuffle=True,
)

In [None]:
# ImageFolder over the full image directory named in the dataset config.
gamset = ImageFolder(dataset_config['image_folder_path'], transform=transformer)
# Reuse the notebook-wide `batch_size` instead of a duplicated literal so the
# two cannot drift apart.
# NOTE(review): `dataloader` is not used by the visible GAM runs (they use `data_loaders`).
dataloader = DataLoader(gamset, batch_size=batch_size, num_workers=12, shuffle=True)

# Discriminator and generator

In [None]:
# Size of the latent vector fed to the generator: in_channels of its first
# G_block, also passed on as gamfig['latent_dim'].
# NOTE(review): presumably this must equal the encoder's output size (1000 for an
# unmodified torchvision ResNet fc head) — confirm when swapping encoders.
input_channels = 1000

number_epochs = 20  # epochs per GAM experiment (gamfig['num_epochs'])

In [None]:
class D_block(nn.Module):
    """Discriminator down-sampling block: Conv2d (no bias) -> BatchNorm2d -> LeakyReLU.

    With the default kernel_size=4, strides=2, padding=1 the block halves the
    spatial height and width (rounding down). ``alpha`` is the LeakyReLU
    negative slope.
    """

    def __init__(self, out_channels, in_channels=3, kernel_size=4, strides=2,
                padding=1, alpha=0.2, **kwargs):
        super().__init__(**kwargs)
        self.conv2d = nn.Conv2d(
            in_channels, out_channels, kernel_size, strides, padding, bias=False
        )
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.activation = nn.LeakyReLU(alpha, inplace=True)

    def forward(self, X):
        features = self.conv2d(X)
        features = self.batch_norm(features)
        return self.activation(features)

In [None]:
n_D = 64
# Discriminator for (3, 224, 448) images: the size produced by the
# Resize((224, 448)) transforms and by net_G. Each D_block halves H and W.
net_D = nn.Sequential(
    D_block(n_D),                                      # Output: (64, 112, 224)
    D_block(in_channels=n_D, out_channels=n_D*2),      # Output: (64 * 2, 56, 112)
    D_block(in_channels=n_D*2, out_channels=n_D*4),    # Output: (64 * 4, 28, 56)
    D_block(in_channels=n_D*4, out_channels=n_D*8),    # Output: (64 * 8, 14, 28)
    D_block(in_channels=n_D*8, out_channels=n_D*16),   # Output: (64 * 16, 7, 14)
    # A single (7, 14) kernel covers the whole remaining map -> one score per image.
    nn.Conv2d(in_channels=n_D*16, out_channels=1,
              kernel_size=(7,14), bias=False))         # Output: (1, 1, 1)

In [None]:
class G_block(nn.Module):
    """Generator up-sampling block: ConvTranspose2d (no bias) -> BatchNorm2d -> ReLU.

    With the default kernel_size=(4, 4), strides=(2, 2), padding=(1, 1) the
    block doubles the spatial height and width.
    """

    def __init__(self, out_channels, in_channels=3, kernel_size=(4,4), strides=(2,2),
                 padding=(1,1), **kwargs):
        super().__init__(**kwargs)
        self.conv2d_trans = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, strides, padding, bias=False
        )
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU()

    def forward(self, X):
        upsampled = self.conv2d_trans(X)
        normalised = self.batch_norm(upsampled)
        return self.activation(normalised)

In [None]:
n_G = 64
# Generator: expands a latent tensor (assumed shape (input_channels, 1, 1)) into a
# (3, 224, 448) image in [-1, 1] via Tanh, i.e. the range of the
# Normalize(0.5, 0.5)-ed real images. Each default G_block doubles H and W.
net_G = nn.Sequential(
    # (7, 14) kernel, stride 1, no padding: 1x1 latent map -> 7x14.
    G_block(n_G * 16, in_channels=input_channels, kernel_size=(7, 14),
            strides=1, padding=0),             # -> (64 * 16, 7, 14)
    G_block(n_G * 8, in_channels=n_G * 16),    # -> (64 * 8, 14, 28)
    G_block(n_G * 4, in_channels=n_G * 8),     # -> (64 * 4, 28, 56)
    G_block(n_G * 2, in_channels=n_G * 4),     # -> (64 * 2, 56, 112)
    G_block(n_G, in_channels=n_G * 2),         # -> (64, 112, 224)
    nn.ConvTranspose2d(n_G, 3, kernel_size=4, stride=2, padding=1,
                       bias=False),            # -> (3, 224, 448)
    nn.Tanh(),
)

In [None]:
# DCGAN-style initialisation, applied per layer type:
# - conv / transposed-conv weights ~ N(0, 0.02);
# - BatchNorm scale ~ N(1, 0.02), shift = 0.
# Drawing *every* parameter from N(0, 0.02) would also set the BatchNorm scales
# to ~0, shrinking each block's normalised output towards zero.
for net in (net_D, net_G):
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:  # all convs here use bias=False, kept for safety
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            m.reset_running_stats()
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

In [None]:
# Move both GAN networks to the GPU (Module.to returns the module itself).
net_G, net_D = net_G.to('cuda'), net_D.to('cuda')

In [None]:
# Encoder: an independent deep copy of the comparator, placed on the GPU, so
# later changes to `model` (e.g. a replaced _forward_impl) do not affect `enc`.
enc = copy.deepcopy(model)
enc = enc.to('cuda')

In [None]:
import types

In [None]:
# truncate the model at avgpool
def _new_forward_impl(self, x: torch.Tensor, not_test: bool = True) -> torch.Tensor:
    # See note [TorchScript super()]
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.avgpool(x)
    x = self.fc(x.flatten(1,-1))
    return x
# Install the replacement forward on this model instance only (the class is untouched).
model._forward_impl = _new_forward_impl.__get__(model)

## Lambda values

In [None]:
# lambda equal test
# Fresh DCGAN-style weights for this run: conv weights ~ N(0, 0.02),
# BatchNorm scale ~ N(1, 0.02), shift = 0, running stats reset. Initialising every
# parameter with N(0, 0.02) would also push the BatchNorm scales to ~0.
for net in (net_D, net_G):
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            m.reset_running_stats()
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

# Shallow copy so this run does not mutate the shared GAMfig template
# (networks and dataloaders are still shared objects, not copied).
gamfig = copy.copy(GAMfig)
gamfig['name'] = 'static_lambda_equal'
gamfig['num_epochs'] = number_epochs
gamfig['generator'] = net_G
gamfig['discriminator'] = net_D
gamfig['comparator'] = model
gamfig['static'] = True
gamfig['encoder'] = enc
gamfig['latent_dim'] = input_channels
gamfig['lambdas'] = [12,1/4000, 1/300]
gamfig['datasizes'] = dataset_sizes
gamfig['dataloaders'] = data_loaders
gamfig['latent_aug_name'] = None
gamfig['comp_layer'] = 'fc'
#GAM_fit_save(gamfig)

In [None]:
# lambda img test
# Fresh DCGAN-style weights for this run: conv weights ~ N(0, 0.02),
# BatchNorm scale ~ N(1, 0.02), shift = 0, running stats reset.
for net in (net_D, net_G):
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            m.reset_running_stats()
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

# Shallow copy so this run does not mutate the shared GAMfig template.
gamfig = copy.copy(GAMfig)
gamfig['name'] = 'static_lambda_img'
gamfig['num_epochs'] = number_epochs
gamfig['generator'] = net_G
gamfig['discriminator'] = net_D
gamfig['comparator'] = model
gamfig['static'] = True
gamfig['encoder'] = enc
gamfig['latent_dim'] = input_channels
gamfig['lambdas'] = [12,3/4000, 1/300]
gamfig['datasizes'] = dataset_sizes
gamfig['dataloaders'] = data_loaders
gamfig['latent_aug_name'] = None
gamfig['comp_layer'] = 'fc'
#GAM_fit_save(gamfig)

In [None]:
# lambda feat test
# Fresh DCGAN-style weights for this run: conv weights ~ N(0, 0.02),
# BatchNorm scale ~ N(1, 0.02), shift = 0, running stats reset.
for net in (net_D, net_G):
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            m.reset_running_stats()
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

# Shallow copy so this run does not mutate the shared GAMfig template.
gamfig = copy.copy(GAMfig)
gamfig['name'] = 'static_lambda_feat'
gamfig['num_epochs'] = number_epochs
gamfig['generator'] = net_G
gamfig['discriminator'] = net_D
gamfig['comparator'] = model
gamfig['static'] = True
gamfig['encoder'] = enc
gamfig['latent_dim'] = input_channels
gamfig['lambdas'] = [12,1/4000, 1/100]
gamfig['datasizes'] = dataset_sizes
gamfig['dataloaders'] = data_loaders
gamfig['latent_aug_name'] = None
gamfig['comp_layer'] = 'fc'
#GAM_fit_save(gamfig)

In [None]:
# lambda adv test
# Fresh DCGAN-style weights for this run: conv weights ~ N(0, 0.02),
# BatchNorm scale ~ N(1, 0.02), shift = 0, running stats reset.
for net in (net_D, net_G):
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            m.reset_running_stats()
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

# Shallow copy so this run does not mutate the shared GAMfig template.
gamfig = copy.copy(GAMfig)
gamfig['name'] = 'static_lambda_adv'
gamfig['num_epochs'] = number_epochs
gamfig['generator'] = net_G
gamfig['discriminator'] = net_D
gamfig['comparator'] = model
gamfig['static'] = True
gamfig['encoder'] = enc
gamfig['latent_dim'] = input_channels
gamfig['lambdas'] = [36,1/4000, 1/300]
gamfig['datasizes'] = dataset_sizes
gamfig['dataloaders'] = data_loaders
gamfig['latent_aug_name'] = None
gamfig['comp_layer'] = 'fc'
#GAM_fit_save(gamfig)

## latent code vs one-hot

In [None]:
# latent one hot test
# Fresh DCGAN-style weights for this run: conv weights ~ N(0, 0.02),
# BatchNorm scale ~ N(1, 0.02), shift = 0, running stats reset.
for net in (net_D, net_G):
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            m.reset_running_stats()
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

# Shallow copy so this run does not mutate the shared GAMfig template.
gamfig = copy.copy(GAMfig)
gamfig['name'] = 'static_one_hot'
gamfig['num_epochs'] = number_epochs
gamfig['generator'] = net_G
gamfig['discriminator'] = net_D
gamfig['comparator'] = model
gamfig['static'] = True
gamfig['encoder'] = enc
gamfig['latent_dim'] = input_channels
gamfig['lambdas'] = [15,1/1000, 1/300]
gamfig['datasizes'] = dataset_sizes
gamfig['dataloaders'] = data_loaders
gamfig['latent_aug_name'] = 'soft_35'
gamfig['comp_layer'] = 'fc'
#GAM_fit_save(gamfig)

In [None]:
# latent base test
# Fresh DCGAN-style weights for this run: conv weights ~ N(0, 0.02),
# BatchNorm scale ~ N(1, 0.02), shift = 0, running stats reset.
for net in (net_D, net_G):
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            m.reset_running_stats()
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

# Shallow copy so this run does not mutate the shared GAMfig template.
gamfig = copy.copy(GAMfig)
gamfig['name'] = 'static_latent'
gamfig['num_epochs'] = number_epochs
gamfig['generator'] = net_G
gamfig['discriminator'] = net_D
gamfig['comparator'] = model
gamfig['static'] = True
gamfig['encoder'] = enc
gamfig['latent_dim'] = input_channels
gamfig['lambdas'] = [15,1/1000, 1/300]
gamfig['datasizes'] = dataset_sizes
gamfig['dataloaders'] = data_loaders
gamfig['latent_aug_name'] = None
gamfig['comp_layer'] = 'fc'
#GAM_fit_save(gamfig)

## Test comparator structure

In [None]:
# truncate the model at avgpool
def _new_forward_impl(self, x: torch.Tensor, not_test: bool = True) -> torch.Tensor:
    # See note [TorchScript super()]
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.avgpool(x)
    if not_test:
        return x
    x = self.fc(x.flatten(1,-1))
    return x

In [None]:
# comp maxpool test
# Fresh DCGAN-style weights for this run: conv weights ~ N(0, 0.02),
# BatchNorm scale ~ N(1, 0.02), shift = 0, running stats reset.
for net in (net_D, net_G):
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            m.reset_running_stats()
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

# Shallow copy so this run does not mutate the shared GAMfig template.
gamfig = copy.copy(GAMfig)
gamfig['name'] = 'static_comp_maxpool'
gamfig['num_epochs'] = number_epochs
gamfig['generator'] = net_G
gamfig['discriminator'] = net_D
gamfig['comparator'] = model
gamfig['static'] = True
gamfig['encoder'] = enc
gamfig['latent_dim'] = input_channels
gamfig['lambdas'] = [15,1/1000, 1/100]
gamfig['datasizes'] = dataset_sizes
gamfig['dataloaders'] = data_loaders
gamfig['latent_aug_name'] = None
gamfig['comp_layer'] = 'maxpool'
# NOTE(review): the _new_forward_impl defined in the previous cell truncates at
# avgpool, not maxpool — confirm that matches comp_layer 'maxpool' before re-enabling.
#model._forward_impl = types.MethodType(_new_forward_impl, model)
#GAM_fit_save(gamfig)

In [None]:
# truncate the model at layer 3
def _new_forward_impl(self, x: torch.Tensor, not_test: bool = True) -> torch.Tensor:
    # See note [TorchScript super()]
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    if not_test:
        return x
    x = self.layer4(x)
    x = self.avgpool(x)
    x = self.fc(x.flatten(1,-1))
    return x

In [None]:
# comp layer 3 test
# Fresh DCGAN-style weights for this run: conv weights ~ N(0, 0.02),
# BatchNorm scale ~ N(1, 0.02), shift = 0, running stats reset.
for net in (net_D, net_G):
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            m.reset_running_stats()
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

# Shallow copy so this run does not mutate the shared GAMfig template.
gamfig = copy.copy(GAMfig)
gamfig['name'] = 'static_comp_layer3'
gamfig['num_epochs'] = number_epochs
gamfig['generator'] = net_G
gamfig['discriminator'] = net_D
gamfig['comparator'] = model
gamfig['static'] = True
gamfig['encoder'] = enc
gamfig['latent_dim'] = input_channels
gamfig['lambdas'] = [15,1/1000, 1/2000]
gamfig['datasizes'] = dataset_sizes
gamfig['dataloaders'] = data_loaders
gamfig['latent_aug_name'] = None
gamfig['comp_layer'] = 'layer3'
#model._forward_impl = types.MethodType(_new_forward_impl, model)
#GAM_fit_save(gamfig)

# Trained encoder

In [None]:
import torchvision.models as models

In [None]:
# Randomly initialised ResNet-18 (no pretrained weights) on the GPU; used as the
# encoder of the 'current_best_v1_feat' run.
# NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13 (use weights=None).
enc_test = models.resnet18(pretrained=False).cuda()

In [None]:
# truncate the model at layer 3
def _new_forward_impl(self, x: torch.Tensor, not_test: bool = True) -> torch.Tensor:
    # See note [TorchScript super()]
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    if not_test:
        return x
    x = self.layer4(x)
    x = self.avgpool(x)
    x = self.fc(x.flatten(1,-1))
    return x
# Bind the layer3-truncating forward to the comparator instance (class untouched).
setattr(model, '_forward_impl', types.MethodType(_new_forward_impl, model))

In [None]:
#[5,1/1000, 1/2000]

In [None]:
# current best v1 feat run (encoder: enc_test, comparator truncated at layer3)
# Fresh DCGAN-style weights for this run: conv weights ~ N(0, 0.02),
# BatchNorm scale ~ N(1, 0.02), shift = 0, running stats reset. Initialising every
# parameter with N(0, 0.02) would also push the BatchNorm scales to ~0.
for net in (net_D, net_G):
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            m.reset_running_stats()
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.zeros_(m.bias)

# Shallow copy so this run does not mutate the shared GAMfig template.
gamfig = copy.copy(GAMfig)
gamfig['name'] = 'current_best_v1_feat'
gamfig['num_epochs'] = number_epochs
gamfig['generator'] = net_G
gamfig['discriminator'] = net_D
gamfig['comparator'] = model
gamfig['encoder'] = enc_test
gamfig['static'] = False
gamfig['latent_dim'] = input_channels
gamfig['lambdas'] = [0,0, 1/2000]
gamfig['datasizes'] = dataset_sizes
gamfig['dataloaders'] = data_loaders
gamfig['latent_aug_name'] = None
gamfig['comp_layer'] = 'layer3'
GAM_fit_save(gamfig)