In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
import dlc_practical_prologue as prologue

mini_batch_size = 100
N = 1000
train_input, train_target, train_classes, test_input, test_target, test_classes = prologue.generate_pair_sets(N)

In [2]:
#normalize the input
train_input/=255
test_input/=255

In [3]:
print_shapes_Net = False

In [4]:
#CLASSES
class Net_classes(nn.Module):
    def __init__(self):
        super(Net_classes, self).__init__()
        nb_hidden = 200
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=2)
        self.fc1 = nn.Linear(64, nb_hidden)
        self.fc2 = nn.Linear(nb_hidden, 10)

    def forward(self, x):
        if print_shapes_Net:
            print("initial shape", x.shape)
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=2))
        if print_shapes_Net:
            print("1 conv",x.shape)
        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2))
        if print_shapes_Net:
            print("2 conv", x.shape)
        x = F.relu(self.conv3(x))
        if print_shapes_Net:
            print("3 conv",x.shape)
        x = F.relu(self.fc1(x.view(-1, 64)))
        if print_shapes_Net:
            print("fc1",x.shape)
        x = self.fc2(x)
        if print_shapes_Net:
            print("final",x.shape)
        return x

######################################################################

In [5]:
#TARGET
class Net_targets(nn.Module):
    def __init__(self, nb_hidden):
        super(Net_targets, self).__init__()
        self.conv1 = nn.Conv1d(2, 32, kernel_size=3)
        self.conv2 = nn.Conv1d(16, 64, kernel_size=3)
        self.fc1 = nn.Linear(32, nb_hidden)
        self.fc2 = nn.Linear(nb_hidden, 2)

    def forward(self, x):
        if print_shapes_Net:
            print("initial", x.shape) #100 2 10
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=2, stride=2))
        if print_shapes_Net:
            print("conv1",x.shape) #100 16 4
        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2, stride=2))
        if print_shapes_Net:
            print("conv2",x.shape)
        x = F.relu(self.fc1(x.view(-1, 32)))
        if print_shapes_Net:
            print("fc1",x.shape)
        x = self.fc2(x)
        if print_shapes_Net:
            print("final",x.shape) # 100 1 2 
        return x

In [31]:
def train_model(model_target, model_class, train_input, train_classes, train_target, mini_batch_size, lr, nb_epoch):
    criterion = nn.MSELoss()
    optimizer_class = optim.Adam(model_class.parameters(), lr)
    optimizer_target = optim.Adam(model_target.parameters(), lr)
    
    for e in range(nb_epoch):
        sum_loss = 0
        for b in range(0, train_input.size(0), mini_batch_size):
            #classes
            output_class = model_class(train_input.narrow(0, b, mini_batch_size))
            loss_class = criterion(output_class, train_classes.narrow(0, b, mini_batch_size))
            
            #target
            input_target = output_class.view([int(mini_batch_size/2), 2, 10]) #reshape the output to have 2 images at once
            output_target = model_target(input_target) #CHECK DIMENSION SI BUG
            loss_target = criterion(output_target, train_target.narrow(0, int(b/mini_batch_size), int(mini_batch_size/2)))
            
            model_class.zero_grad()
            model_target.zero_grad()
            loss = loss_class + loss_target
            loss.backward()
            
            optimizer_class.step()
            optimizer_target.step()
            
            sum_loss = sum_loss + loss.item()
        #print(e, sum_loss)

In [54]:
def compute_nb_errors_classes(model, input, target):
    nb_errors = 0

    output = model(input)
    _, predicted_classes = output.max(1)
    
    print(predicted_classes)

    for b in range(input.shape[0]):
        if target[b + k, predicted_classes[k]] <= 0:
            nb_errors = nb_errors + 1

    return nb_errors

In [56]:
def compute_nb_errors_targets(model_target, model_class, input, target):
    nb_errors = 0
    
    #need class ouput first, and reshape it
    output_class = model_class(input)
    input_target = output_class.view([1000,2,10])
    
    output = model_target(input_target)
    _, predicted_target = output.max(1) #digits - shape [1000]
    
    print(predicted_target.shape)
    #predicted_target = torch.empty(1000)
    
    for b in range(input_target.shape[0]):
        # print(predicted_classes[b], target[b], "b", b)
        if target[b, int(predicted_target[b])] <= 0:
            nb_errors = nb_errors + 1

    return nb_errors

In [10]:
#RESHAPE THE INPUTS AND CLASSES

new_train_input = train_input.view([2000,1,14,14])
new_test_input = test_input.view([2000,1,14,14])

train_classes = train_classes.view(2000)
test_classes = test_classes.view(2000)

#transfor classes in a 2000 * 10 
train_classes2 = torch.empty(2000,10)
test_classes2 = torch.empty(2000,10)
for i in range(2000):
    for j in range(10):
        if train_classes[i] == j:
            train_classes2[i,j] = 1
        else:
            train_classes2[i,j] = 0
        if test_classes[i] == j:
            test_classes2[i,j] = 1
        else:
            test_classes2[i,j] = 0
            
            
train_input = new_train_input
test_input = new_test_input
train_classes = train_classes2
test_classes = test_classes2


# RESHAPE TARGETS

#train_target[1000,1]

new_train_target = torch.empty(1000,2)
new_test_target = torch.empty(1000,2)
for i in range(1000):
    if train_target[i] == 1 :
        new_train_target[i,0] = 0
        new_train_target[i,1] = 1
        
    else:
        new_train_target[i,0] = 1
        new_train_target[i,1] = 0
        
    if test_target[i] == 1:
        new_test_target[i,0] = 0
        new_test_target[i,1] = 1
        
    else:
        new_test_target[i,0] = 1
        new_test_target[i,1] = 0
        
train_target = new_train_target
test_target = new_test_target

In [36]:
####predict class of each digit
for k in range(1):
    model_class = Net_classes()
    model_target = Net_targets(200)
    lr = 0.005
    nb_epoch = 25
    train_model(model_target, model_class, train_input, train_classes, train_target, mini_batch_size, lr, nb_epoch)
    
    '''nb_train_errors_class = compute_nb_errors_classes(model_class, train_input, train_classes, mini_batch_size)
    print('train error Net_classes {:0.2f}% {:d}/{:d}'.format((100 * nb_train_errors_class) / train_input.size(0),
                                                      nb_train_errors_class, train_input.size(0)))
    nb_test_errors_class = compute_nb_errors_classes(model_class, test_input, test_classes, mini_batch_size)
    print('test error Net_classes {:0.2f}% {:d}/{:d}'.format((100 * nb_test_errors_class) / test_input.size(0),
                                                    nb_test_errors_class, test_input.size(0)))
    
    nb_train_errors_target = compute_nb_errors_targets(model_target, train_input, train_classes)
    print('train error Net_target {:0.2f}% {:d}/{:d}'.format((100 * nb_train_errors_target) / train_input.size(0),
                                                      nb_train_errors_target, train_input.size(0)))
    nb_test_errors_target = compute_nb_errors_targets(model_target, test_input, test_classes)
    print('test error Net_target {:0.2f}% {:d}/{:d}'.format((100 * nb_test_errors_target) / test_input.size(0),
                                                    nb_test_errors_target, test_input.size(0)))'''
    
    

In [62]:
nb_train_errors_class = compute_nb_errors_classes(model_class, train_input, train_classes)
print('train error Net_classes {:0.2f}% {:d}/{:d}'.format((100 * nb_train_errors_class) / train_input.size(0),
                                                  nb_train_errors_class, train_input.size(0)))
nb_test_errors_class = compute_nb_errors_classes(model_class, test_input, test_classes)
print('test error Net_classes {:0.2f}% {:d}/{:d}'.format((100 * nb_test_errors_class) / test_input.size(0),
                                                nb_test_errors_class, test_input.size(0)))

nb_train_errors_target = compute_nb_errors_targets(model_target, model_class, train_input, train_classes)
print('train error Net_target {:0.2f}% {:d}/{:d}'.format((100 * nb_train_errors_target) / (train_input.size(0)/2),
                                                  nb_train_errors_target, int((train_input.size(0))/2)))
nb_test_errors_target = compute_nb_errors_targets(model_target, model_class, test_input, test_classes)
print('test error Net_target {:0.2f}% {:d}/{:d}'.format((100 * nb_test_errors_target) / int(test_input.size(0)/2),
                                                nb_test_errors_target, int(test_input.size(0)/2)))

tensor([9, 3, 3,  ..., 5, 1, 1])
train error Net_classes 90.20% 1804/2000
tensor([0, 3, 4,  ..., 6, 7, 4])
test error Net_classes 90.15% 1803/2000
torch.Size([1000])
train error Net_target 89.80% 898/1000
torch.Size([1000])
test error Net_target 89.60% 896/1000
