In [None]:
from utils import *

In [None]:
"""
    https://towardsdatascience.com/a-friendly-introduction-to-siamese-networks-85ab17522942
"""

class SiameseNetworkClassifier(nn.Module):
    """Siamese pair classifier: frozen ResNet-152 backbone plus a trainable head.

    Both images of a pair are passed through the same backbone and head, and
    the model outputs the pairwise distance between the two embeddings.
    ``predict`` labels a pair as 1 when that distance is below
    ``self.threshold`` and as 0 otherwise.

    Args:
        latent_space_dim: Size of the embedding produced by the trainable head.
        dropout: Dropout probability applied before the linear projection.
        device: Anything ``torch.device`` accepts; the model is moved there.
    """

    def __init__(self, latent_space_dim=50, dropout=0.6, device='mps'):
        super().__init__()

        resnet = resnet152(weights=ResNet152_Weights.IMAGENET1K_V1)
        for param in resnet.parameters():
            param.requires_grad = False

        # Every child of the ResNet except its final fully connected layer.
        self.frozen = nn.Sequential(*list(resnet.children())[:-1])
        self.hot = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(resnet.fc.in_features, latent_space_dim)
        )

        # requires_grad=False alone does not freeze BatchNorm: in training mode
        # its running statistics are still updated on every forward pass.
        # Modules start in training mode, so switch the backbone to eval now;
        # train() below keeps it that way.
        self.frozen.eval()

        # Plain float, the same type update_threshold() assigns later.
        self.threshold = 0.0
        self.device = torch.device(device)
        self.to(self.device)

    def train(self, mode=True):
        """Set the training mode, always keeping the frozen backbone in eval mode.

        Without this override ``model.train()`` would put the backbone's
        BatchNorm layers back into training mode and let their running
        statistics drift even though the weights are frozen.

        Returns:
            ``self``, like ``nn.Module.train``.
        """
        super().train(mode)
        self.frozen.eval()
        return self

    def forward(self, images1, images2):
        """Return the pairwise distance between the embeddings of two image batches."""
        output1 = self.hot(self.frozen(images1))
        output2 = self.hot(self.frozen(images2))
        return F.pairwise_distance(output1, output2)

    def update_threshold(self, loader, max_batches=None):
        """Fit ``self.threshold`` from labelled pairs.

        A one-feature logistic regression (distance -> label) is fitted and the
        threshold is set to its decision boundary, i.e. the distance ``d`` at
        which ``coef * d + intercept == 0``.

        Note: the model is left in eval mode when this method returns.

        Args:
            loader: Iterable yielding ``(images1, images2, equals)`` batches.
            max_batches: Number of batches to consume; ``None`` consumes all.
        """
        self.eval()
        with torch.no_grad():
            distances = []
            labels = []
            for images1, images2, equals in islice(loader, max_batches):
                distance = self(images1.to(self.device), images2.to(self.device))
                distances.append(distance.cpu())
                labels.append(equals)

            distances = torch.cat(distances)
            labels = torch.cat(labels)
            log_reg = LogisticRegression(penalty=None)
            log_reg.fit(distances.reshape((-1, 1)), labels)
            self.threshold = (-log_reg.intercept_ / log_reg.coef_).item()

    # TODO refactor this method so we don't have to call .to(self.device) ?
    def predict(self, images1, images2):
        """Classify pairs: 1 where the distance is below the threshold, else 0.

        Switches the model to eval mode and runs without gradient tracking.

        Returns:
            An integer tensor on the CPU with one prediction per pair.
        """
        self.eval()
        with torch.no_grad():
            images1 = images1.to(self.device)
            images2 = images2.to(self.device)
            distances = self(images1, images2)
            return (distances < self.threshold).int().cpu()

In [None]:
"""
    Build the preprocessing pipeline, the datasets, their loaders and the model
    TODO move params batch_size, image_size, crop_size, paths to the separate cell
"""

batch_size = 64

# One preprocessing pipeline shared by every split and by the submission set.
transform = T.Compose([
    T.Resize(300),
    T.CenterCrop(280),
    T.Lambda(T.functional.equalize),
    T.ToTensor(),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Four slices, unpacked in order: train, threshold fitting, validation, test.
df_slices = list(split_dataframe(pd.read_csv('data/small.csv'), (0.9, 0.01, 0.04, 0.05)))
train_data, train_threshold_data, valid_data, test_data = df_slices

datasets = [ImageDataset(frame, transform=transform) for frame in df_slices]
train_dataset, train_threshold_dataset, valid_dataset, test_dataset = datasets

# Only the first loader (the training one) shuffles its data.
loaders = [
    DataLoader(split, batch_size=batch_size, shuffle=(position == 0))
    for position, split in enumerate(datasets)
]
train_loader, train_threshold_loader, valid_loader, test_loader = loaders

# Submission pairs go through the same transform and are never shuffled.
submit_df = pd.read_csv('data/submit.csv')
submit_dataset = ImageDataset(submit_df, transform=transform)
submit_loader = DataLoader(submit_dataset, shuffle=False, batch_size=batch_size)

model = SiameseNetworkClassifier(dropout=0.8, latent_space_dim=100)

In [None]:
"""
    Train model
"""

# The four loaders are passed positionally, in the same order as `loaders`.
train(
    model,
    train_loader,
    train_threshold_loader,
    valid_loader,
    test_loader,
    epochs=5,
    max_batches=10,
    verbose=True,
)

In [None]:
"""
    Check on what's wrong with our model
"""

# Runs the `mislabeled` helper on the validation loader.
# NOTE(review): presumably it reports the validation pairs the model gets
# wrong — confirm against the helper's definition in utils.
mislabeled(model, valid_loader)

In [None]:
"""
    Save predictions for the submission
"""

# NOTE(review): hard-coded id handed to save_submission; presumably the
# largest id expected in the submission file — confirm against data/submit.csv
# and the helper's definition in utils.
max_submit_id = 22661
save_submission(model, submit_loader, max_submit_id)