In [1]:
from utils import *

df_slices = list(split_dataframe(pd.read_csv('data/small.csv'), (0.9, 0.01, 0.04, 0.05)))
submit_df = pd.read_csv('data/submit.csv')

In [2]:
"""
    https://towardsdatascience.com/a-friendly-introduction-to-siamese-networks-85ab17522942
"""

class SiameseNetworkClassifier(nn.Module):
    def __init__(self, latent_space_dim=50, dropout=0.6, device='mps'):
        super(SiameseNetworkClassifier, self).__init__()

        resnet = resnet152(weights=ResNet152_Weights.IMAGENET1K_V1)
        for param in resnet.parameters():
            param.requires_grad = False

        self.frozen = nn.Sequential(*list(resnet.children())[:-1])
        self.hot = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(resnet.fc.in_features, latent_space_dim)
        )

        self.threshold = torch.tensor(0.)
        self.device = torch.device(device)
        self.to(self.device)

    def forward(self, images1, images2):
        output1 = self.hot(self.frozen(images1))
        output2 = self.hot(self.frozen(images2))
        return F.pairwise_distance(output1, output2)
        
    def update_threshold(self, loader, max_batches=None):
        self.eval()
        with torch.no_grad():
            distances = []
            labels = []
            for images1, images2, equals in islice(loader, max_batches):
                distance = self.forward(images1.to(self.device), images2.to(self.device))
                distances.append(distance.cpu())
                labels.append(equals)
    
            distances = torch.cat(distances)
            labels = torch.cat(labels)
            log_reg = LogisticRegression(penalty=None)
            log_reg.fit(distances.reshape((-1, 1)), labels)
            self.threshold = (-log_reg.intercept_ / log_reg.coef_).item()

    # TODO refactor this method so we don't have to call .to(self.device) ? 
    def predict(self, images1, images2):
        self.eval()
        with torch.no_grad():
            images1 = images1.to(self.device)
            images2 = images2.to(self.device)
            distances = self.forward(images1, images2)
            return (distances < self.threshold).int().cpu()
        
class SiameseNetworkClassifier_TrainConv(nn.Module):
    def __init__(self, latent_space_dim=150, dropout=0.1, device='mps'):
        super(SiameseNetworkClassifier_TrainConv, self).__init__()

        resnet = resnet152(weights=ResNet152_Weights.IMAGENET1K_V1)
        for name, param in resnet.named_parameters():
            if not name.startswith('layer4.2') and not name.startswith('fc'):
                param.requires_grad = False

        self.layers = nn.Sequential(*list(resnet.children())[:-1], nn.Flatten(), nn.Dropout(dropout), nn.Linear(resnet.fc.in_features, latent_space_dim))
        
        self.threshold = torch.tensor(0.)
        self.device = torch.device(device)
        self.to(self.device)

    def forward(self, images1, images2):
        output1 = self.layers(images1)
        output2 = self.layers(images2)
        return F.pairwise_distance(output1, output2)
        
    def update_threshold(self, loader, max_batches=None):
        self.eval()
        with torch.no_grad():
            distances = []
            labels = []
            for images1, images2, equals in islice(loader, max_batches):
                distance = self.forward(images1.to(self.device), images2.to(self.device))
                distances.append(distance.cpu())
                labels.append(equals)
    
            distances = torch.cat(distances)
            labels = torch.cat(labels)
            log_reg = LogisticRegression(penalty=None)
            log_reg.fit(distances.reshape((-1, 1)), labels)
            self.threshold = (-log_reg.intercept_ / log_reg.coef_).item()

    # TODO refactor this method so we don't have to call .to(self.device) ? 
    def predict(self, images1, images2):
        self.eval()
        with torch.no_grad():
            images1 = images1.to(self.device)
            images2 = images2.to(self.device)
            distances = self.forward(images1, images2)
            return (distances < self.threshold).int().cpu()

In [3]:
def objective(trial):
    latent_space_dim = trial.suggest_int("latent_space_dim", 130, 170)
    latent_space_dim = 150
    dropout = trial.suggest_float("dropout", 0, 0.2)
    
    lr = trial.suggest_float("learning_rate", 1e-6, 1e-3, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-4, 1e-2, log=True)

    batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
    image_size = trial.suggest_int("image_size", 300, 400)
    
    # Random transform parameters for training set
    random_crop_scale_min = trial.suggest_float('random_crop_scale_min', 0.5, 1)
    random_crop_scale_max = trial.suggest_float('random_crop_scale_max', 0.5, 1)
    if random_crop_scale_min > random_crop_scale_max:
        random_crop_scale_min, random_crop_scale_max = random_crop_scale_max, random_crop_scale_min
        # trial.set_user_attr('random_crop_scale_min', random_crop_scale_min)
        # trial.set_user_attr('random_crop_scale_max', random_crop_scale_max)
    downsample_ratio = trial.suggest_float('downsample_ratio', 0.6, 1)
    pad_ratio = trial.suggest_float('pad_ratio', 0, 0.3)
    use_color_jitter = trial.suggest_categorical('use_color_jitter', [True, False])
    random_apply_prob = trial.suggest_float('random_apply_prob', 0, 0.3)
    
    train_conv = trial.suggest_categorical("train_conv", [True, False])
    
    def get_transform(is_train):
        transforms = [
            T.Resize(image_size),
            T.CenterCrop(image_size),
            T.Lambda(lambda x: T.functional.equalize(x)),
        ]

        if is_train:
            transforms += [
                T.RandomApply([T.RandomResizedCrop(image_size, (random_crop_scale_min, random_crop_scale_max), (1, 1))], p=random_apply_prob),
                T.RandomApply([T.Resize(int(image_size * downsample_ratio)), T.Resize(image_size)], p=random_apply_prob),
                T.RandomApply([T.Pad(int(image_size * pad_ratio)), T.Resize(image_size)], p=random_apply_prob),
            ]
            if use_color_jitter:
                transforms.append(T.ColorJitter(0.1, 0.1, 0.1, 0.1))

        transforms += [
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]

        return T.Compose(transforms)

    datasets = [ImageDataset(df, transform=get_transform(i == 0)) for i, df in enumerate(df_slices)]
    loaders = [DataLoader(dataset, batch_size=batch_size, shuffle=(i == 0)) for i, dataset in enumerate(datasets)]
    
    train_data, train_threshold_data, valid_data, test_data = df_slices
    train_dataset, train_threshold_dataset, valid_dataset, test_dataset = datasets
    train_loader, train_threshold_loader, valid_loader, test_loader = loaders
    
    if train_conv:
        model = SiameseNetworkClassifier_TrainConv(latent_space_dim=latent_space_dim, dropout=dropout)
    else:
        model = SiameseNetworkClassifier(latent_space_dim=latent_space_dim, dropout=dropout)
    train(model, *loaders, epochs=1, max_batches= 5000 // batch_size + 1, lr=lr, weight_decay=weight_decay, verbose=False)
    return evaluate(model, test_loader, max_batches= 2500 // batch_size + 1)

In [4]:
study = optuna.create_study(
    study_name = "augment_and_train_conv_study",
    storage = "sqlite:///augment_and_train_conv_study.db",
    direction = "maximize",
    load_if_exists=True
)

n_trials = 500

study.optimize(objective, n_trials=n_trials, gc_after_trial=True, show_progress_bar=True, catch=Exception)

[I 2023-07-09 00:06:01,091] A new study created in RDB with name: augment_and_train_conv_study


  0%|          | 0/500 [00:00<?, ?it/s]

Epoch 0:   0%|          | 0/313 [00:00<?, ?it/s]

In [None]:
optuna.visualization.plot_param_importances(study)

In [None]:
optuna.visualization.plot_contour(study, params=['image_size', 'latent_space_dim'])

In [None]:
study.trials_dataframe().sort_values('value', ascending=False).to_csv('params.csv')

In [None]:
pd.read_csv('params.csv')

Unnamed: 0.1,Unnamed: 0,number,value,datetime_start,datetime_complete,duration,params_batch_size,params_downsample_ratio,params_dropout,params_image_size,params_latent_space_dim,params_learning_rate,params_pad_ratio,params_random_apply_prob,params_random_crop_scale_max,params_random_crop_scale_min,params_use_color_jitter,params_weight_decay,state
0,12,12,0.997761,2023-07-08 22:01:16.520658,2023-07-08 22:07:35.269168,0 days 00:06:18.748510,16,0.616866,0.121133,378,152,0.000193,0.177446,0.222577,0.514853,0.856183,True,0.002508,COMPLETE
1,19,19,0.995422,2023-07-08 22:25:53.457407,2023-07-08 22:30:43.333673,0 days 00:04:49.876266,8,0.646529,0.00934,324,153,0.00054,0.128549,0.268075,0.518787,0.791927,False,0.003059,COMPLETE
2,18,18,0.995422,2023-07-08 22:21:44.017694,2023-07-08 22:25:53.270446,0 days 00:04:09.252752,8,0.652126,0.002524,307,178,0.000206,0.135649,0.268082,0.500464,0.903891,False,0.002293,COMPLETE
3,11,11,0.993304,2023-07-08 21:54:35.909763,2023-07-08 22:01:15.725787,0 days 00:06:39.816024,16,0.605947,0.114329,378,154,7.7e-05,0.171543,0.2464,0.500797,0.867172,True,0.000597,COMPLETE
4,13,13,0.993304,2023-07-08 22:07:36.312466,2023-07-08 22:09:40.962596,0 days 00:02:04.650130,16,0.752348,0.148938,202,136,0.000438,0.214437,0.18851,0.607612,0.832451,True,0.002601,COMPLETE
5,5,5,0.993304,2023-07-08 21:37:41.160038,2023-07-08 21:39:37.450105,0 days 00:01:56.290067,16,0.724953,0.186027,170,173,0.000153,0.157382,0.225157,0.654646,0.883331,True,0.001414,COMPLETE
6,15,15,0.993282,2023-07-08 22:12:15.511579,2023-07-08 22:19:17.598372,0 days 00:07:02.086793,16,0.758245,0.077192,394,139,0.000958,0.139047,0.245953,0.600036,0.513665,True,0.001638,COMPLETE
7,0,0,0.991086,2023-07-08 21:25:43.924422,2023-07-08 21:28:54.957578,0 days 00:03:11.033156,16,0.923019,0.081242,315,42,1.8e-05,0.169548,0.263186,0.548217,0.960422,True,0.000593,COMPLETE
8,14,14,0.990873,2023-07-08 22:09:41.130931,2023-07-08 22:12:15.344812,0 days 00:02:34.213881,8,0.608587,0.150787,256,194,9.4e-05,0.298321,0.192838,0.509831,0.880298,True,0.003606,COMPLETE
9,1,1,0.988875,2023-07-08 21:28:55.122478,2023-07-08 21:34:02.730265,0 days 00:05:07.607787,32,0.692378,0.219454,350,38,3.5e-05,0.203514,0.221769,0.765473,0.959567,False,0.005732,COMPLETE
