In [32]:
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from dlc_practical_prologue import generate_pair_sets

In [155]:
train_input, train_target, train_classes, test_input, test_target, test_classes = generate_pair_sets(1000)
train_classes.size()

torch.Size([1000, 2])

In [162]:
class PairModel(nn.Module):
    def __init__(self, nbch1=32, nbch2=64, nbfch=256, batch_norm=True):
        super(PairModel, self).__init__()
        self.nbch1 = nbch1
        self.nbch2 = nbch2
        self.nbfch = nbfch
        self.batch_norm = batch_norm
        self.conv1 = nn.Conv2d(2, nbch1, 3)
        if batch_norm:
            self.bn1 = nn.BatchNorm2d(nbch1)
        self.conv2 = nn.Conv2d(nbch1, nbch2, 6)
        if batch_norm:
            self.bn2 = nn.BatchNorm2d(nbch2)
        self.fc1 = nn.Linear(nbch2, nbfch)
        self.fc2 = nn.Linear(nbfch, 2)
    def forward(self, x):
        if self.batch_norm:
            x = F.relu(F.max_pool2d(self.bn1(self.conv1(x)), 2))
            x = F.relu(self.bn2(self.conv2(x)))
        else:
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
            x = F.relu(self.conv2(x))
        x = F.relu(self.fc1(x.view(-1, self.nbch2)))
        return F.relu(self.fc2(x))

In [174]:
def train_pair_model(num_epochs=25, lr=0.1, mini_batch_size=100):
    model = PairModel()
    num_samples = train_input.size(0)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    for e in range(num_epochs):
        sum_loss = 0
        for b in range(0, num_samples, mini_batch_size):
            input_mini_batch = train_input[b:b + mini_batch_size]
            target_mini_batch = train_target[b:b + mini_batch_size]
            model.zero_grad()
            prediction = model(input_mini_batch)
            loss = criterion(prediction, target_mini_batch)
            sum_loss += loss.item()
            loss.backward()
            optimizer.step()
        print(e, sum_loss)
    return model

In [183]:
class SiameseModel(nn.Module):
    def __init__(self, nbch1=32, nbch2=64, nbfch=256, batch_norm=True):
        super(SiameseModel, self).__init__()
        self.nbch1 = nbch1
        self.nbch2 = nbch2
        self.nbfch = nbfch
        self.batch_norm = batch_norm
        self.conv1 = nn.Conv2d(1, nbch1, 3)
        if batch_norm:
            self.bn1 = nn.BatchNorm2d(nbch1)
        self.conv2 = nn.Conv2d(nbch1, nbch2, 6)
        if batch_norm:
            self.bn2 = nn.BatchNorm2d(nbch2)
        self.fc1 = nn.Linear(nbch2, nbfch)
        self.fc2 = nn.Linear(nbfch, 10)
        self.fc3 = nn.Linear(20, 2)
    def forward(self, x):
        if self.batch_norm:
            x1, x2 = x[:, 0], x[:, 1]
            x1 = x1.reshape(-1, 1, 14, 14)
            x2 = x2.reshape(-1, 1, 14, 14)
            x1 = F.relu(F.max_pool2d(self.bn1(self.conv1(x1)), 2))
            x1 = F.relu(self.bn2(self.conv2(x1)))
            x2 = F.relu(F.max_pool2d(self.bn1(self.conv1(x2)), 2))
            x2 = F.relu(self.bn2(self.conv2(x2)))
        else:
            x1 = F.relu(F.max_pool2d(self.conv1(x1), 2))
            x1 = F.relu(self.conv2(x1))
            x2 = F.relu(F.max_pool2d(self.conv1(x2), 2))
            x2 = F.relu(self.conv2(x2))
        x1 = F.relu(self.fc1(x1.view(-1, self.nbch2)))
        x1 = F.relu(self.fc2(x1))
        x2 = F.relu(self.fc1(x2.view(-1, self.nbch2)))
        x2 = F.relu(self.fc2(x2))
        x = torch.cat((x1, x2), axis=1)
        return F.relu(self.fc3(x)), (x1, x2)

In [178]:
def train_siamese_model(num_epochs=25, lr=0.1, mini_batch_size=100, loss_weights=(1, 1)):
    model = SiameseModel()
    num_samples = train_input.size(0)
    optimizer = optim.SGD(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    for e in range(num_epochs):
        sum_loss = 0
        for b in range(0, num_samples, mini_batch_size):
            input_mini_batch = train_input[b:b + mini_batch_size]
            target_mini_batch = train_target[b:b + mini_batch_size]
            classes_mini_batch = train_classes[b:b + mini_batch_size]
            model.zero_grad()
            prediction_2, (prediction_10_1, prediction_10_2) = model(input_mini_batch)
            loss_2 = criterion(prediction_2, target_mini_batch)
            loss_10_1 = criterion(prediction_10_1, classes_mini_batch[:, 0])
            loss_10_2 = criterion(prediction_10_2, classes_mini_batch[:, 1])
            loss_10 = loss_10_1 + loss_10_2
            total_loss = loss_weights[0] * loss_2 + loss_weights[1] * loss_10
            total_loss.backward()
            sum_loss += total_loss.item()
            optimizer.step()
        print(e, sum_loss)
    return model

In [157]:
def test_model(model):
    num_samples = test_input.size(0)
    prediction = model(test_input)
    predicted_class = torch.argmax(prediction, axis=1)
    accuracy = torch.sum(predicted_class == test_target).float() / num_samples
    return accuracy

In [150]:
trained_pair_model = train_pair_model(lr=1e0)

0 6.338943183422089
1 5.194023281335831
2 4.665180027484894
3 4.361992746591568
4 4.107274085283279
5 3.9422988891601562
6 3.3555320352315903
7 2.6736945509910583
8 1.6262310519814491
9 0.9279737286269665
10 1.3092701062560081
11 1.5165461488068104
12 0.33970838133245707
13 0.06681734882295132
14 0.029906735755503178
15 0.020450577372685075
16 0.0154412139672786
17 0.012318151071667671
18 0.010172725014854223
19 0.008625649730674922
20 0.00746550690382719
21 0.006563328322954476
22 0.005844814004376531
23 0.005257547978544608
24 0.004770920728333294


In [151]:
test_model(trained_pair_model)

tensor(0.8260)

In [197]:
trained_siamese_model = train_siamese_model(lr=1, loss_weights=(1.5, 0.25))

0 18.91983151435852
1 12.93790578842163
2 8.380740702152252
3 6.232254892587662
4 5.683494448661804
5 4.304722011089325
6 3.3988476544618607
7 3.291791617870331
8 2.376682296395302
9 1.5707224905490875
10 2.0111621767282486
11 2.6959504559636116
12 3.1543694511055946
13 2.5462238863110542
14 0.8878308944404125
15 0.593383114784956
16 0.32002727687358856
17 0.2417464954778552
18 0.20591144356876612
19 0.18339358922094107
20 0.1670431005768478
21 0.15462171332910657
22 0.14506843546405435
23 0.1377294142730534
24 0.132113853469491


In [189]:
def test_siamese_model(model):
    num_samples = test_input.size(0)
    prediction_2, (prediction_10_1, prediction_10_2) = model(test_input)
    predicted_class_2 = torch.argmax(prediction_2, axis=1)
    predicted_class_10_1 = torch.argmax(prediction_10_1, axis=1)
    predicted_class_10_2 = torch.argmax(prediction_10_2, axis=1)
    predicted_class_10 = predicted_class_10_1 <= predicted_class_10_2
    accuracy_2 = torch.sum(predicted_class_2 == test_target).float() / num_samples
    accuracy_10 = torch.sum(predicted_class_10 == test_target).float() / num_samples
    return accuracy_2, accuracy_10

In [198]:
test_siamese_model(trained_siamese_model)

(tensor(0.9280), tensor(0.9690))

 ### ToDo
 - gradient norm over depth vs usage of batch_norm
 - cross-val of learn_rate, siamese_loss_weights, num_channels, num_linear_hidden
 - one loss test accuracy vs two losses(augmented) test accuracy
 - possibly, 20->2 probs with linear vs non-linear transform
 - impact of initializations, normalizing input, adding skip connections, dropout...?