# Binary Classification Training

This Notebook contains the myDataset.py code and the training loop which were used to create the binary classification training.
The model which was used is a derivation of the UNet implementation by https://github.com/milesial/Pytorch-UNet

Unet parts, classes that are needed for the network architecture

In [4]:
import torch
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from PIL import Image
from torchvision import transforms
import random
import numpy as np
import torch


In [5]:
""" Parts of the U-Net model """

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels

        conv_layer1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False)

        conv_layer2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False)


        self.double_conv = nn.Sequential(
            conv_layer1,# bias was false
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            conv_layer2,# bias was false
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

    def normal_init(m, mean, std):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            m.weight.data.normal_(mean, std)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)



UNet architecture with initialization and the forward pass. As is visible, the decoder part was omitted in favour of a smaller model that yields a single digit output.


In [6]:

"""
Basis was a standard UNet as proposed by Ronneberger, Fischer and Brox.
To generate a single value decision, the decoder part was deleted in favour of a dense linear output layer, followed by a sigmoid function.
"""


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 3)
        self.down1 = Down(3, 6)
        self.down2 = Down(6, 12)
        self.lin = torch.nn.Linear(12*120*160, 1)


    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x3 = x3.flatten(start_dim=1)
        logit = self.lin(x3)
        return torch.nn.functional.sigmoid(logit)

The myDataset.py file is separated in myDataset_train and myDataset_val, as for the validation no data augmentation should be performed.

In [7]:


class myDataset_train:

    def __init__(self, data_path):

        self.data_path = data_path

    def __len__(self):

        return len(os.listdir(self.data_path))

    def __getitem__(self, i):

        # image call
        img = Image.open(os.path.join(self.data_path, 'img_{}.jpg'.format(i)))

        # transformation to torch.Tensor
        img_array = np.array(img)
        img_tensor = torch.tensor(img_array)
        img_tensor = torch.unsqueeze(img_tensor,0)
        #
        # # strong data augmentation to compensate for little training data
        degrees = random.uniform(-10.0, 10.0)
        translate_x = random.uniform(-1, 1)
        translate_y = random.uniform(-1, 1)
        scale = random.uniform(0.9, 1.1)

        img_tensor = transforms.functional.affine(img=img_tensor, shear=0, angle=degrees, translate=(translate_x, translate_y), scale=scale, interpolation=transforms.InterpolationMode.BILINEAR, fill=0)

        # ids of images with fractures
        fractures = range(12)

        # creation of target value
        if i in fractures:
            target = torch.tensor(1)
        else:
            target = torch.tensor(0)


        return img_tensor, target


class myDataset_val:

    def __init__(self, data_path):

        self.data_path = data_path

    def __len__(self):

        return len(os.listdir(self.data_path))

    def __getitem__(self, i):

        img = Image.open(os.path.join(self.data_path, 'img_{}.jpg'.format(i)))

        img_array = np.array(img)
        img_tensor = torch.tensor(img_array)
        img_tensor = torch.unsqueeze(img_tensor, 0)


        fractures = [0]

        if i in fractures:
            target = torch.tensor(1)
        else:
            target = torch.tensor(0)



        return img_tensor, target



The following cell contains the training loop for the classification with a final check on validation data

In [10]:

# file paths to data set

file_path = "./assets/hand_train"

file_path_val = "./assets/hand_train"

dataset = myDataset_train(file_path)

val_set = myDataset_val(file_path_val)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

dataloader_val = torch.utils.data.DataLoader(val_set, batch_size=1, shuffle=True)

images, target = next(iter(dataloader))


# call of model architecture
model = UNet(n_channels=1, n_classes=2)

# select training parameters
learning_rate = 1e-5
num_epochs = 5

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.999),
    )

loss_func = torch.nn.MSELoss()

loss_history = []


# training loop
for epoch in range(num_epochs):

    print('Epoch: {}'.format(epoch))

    train_loss_history = []

    for image, target in dataloader:

        image = image.float()

        target = target.float()

        optimizer.zero_grad()

        outputs = model(image)

        loss = loss_func(outputs, target)

        loss_value = loss.item()

        train_loss_history.append(loss_value)

        loss.backward()

        optimizer.step()

    epoch_loss = np.mean(train_loss_history)

    loss_history.append(epoch_loss)

    print(epoch_loss)

total = 0
correct = 0

# analysis of performance on training image after completed training
for image, target in dataloader:

    total += 1

    image = image.float()
    target = target.float()
    outputs = model(image)

    if outputs>0.5:
        outputs = 1
    else:
        outputs = 0

    if target == outputs:
        correct += 1
        print('Correct prediction {}'.format(outputs))
    else:
        print('Wrong. Predicted {}, should have been {}'.format(outputs, target))

print('{} out of {} train images were correctly classified.'.format(correct, total))


model_location = "./assets/models/" + 'classifier_weights'

torch.save(model.state_dict(), model_location)



# analysis on validation images to check for successful generalization
print('Validation Performance:')
for image, target in dataloader_val:

    image = image.float()
    target = target.float()
    outputs = model(image)

    if outputs>0.5:
        outputs = 1
    else:
        outputs = 0

    if target == outputs:
        correct += 1
        print('Correct prediction {}'.format(outputs))
    else:
        print('Wrong. Predicted {}, should have been {}'.format(outputs, target))








Epoch: 0


  return F.mse_loss(input, target, reduction=self.reduction)


0.29743240866144854
Epoch: 1
0.21588283529771227
Epoch: 2
0.22163926528633704
Epoch: 3
0.1807943430130503
Epoch: 4
0.14990449398090797
Correct prediction 1
Correct prediction 0
Wrong. Predicted 0, should have been tensor([1.])
Correct prediction 0
Correct prediction 0
Correct prediction 0
Wrong. Predicted 0, should have been tensor([1.])
Wrong. Predicted 0, should have been tensor([1.])
Correct prediction 0
Correct prediction 0
Correct prediction 0
Correct prediction 0
Correct prediction 0
Correct prediction 0
Correct prediction 0
Correct prediction 1
Wrong. Predicted 0, should have been tensor([1.])
Correct prediction 0
Correct prediction 0
Correct prediction 1
Correct prediction 1
Correct prediction 1
Correct prediction 0
Correct prediction 1
Correct prediction 0
Correct prediction 0
Correct prediction 0
Correct prediction 0
Correct prediction 1
Correct prediction 0
Correct prediction 0
Correct prediction 0
Correct prediction 1
Correct prediction 0
Correct prediction 0
31 out of 35 t