In [2]:
# Pseudo-Labeling-Noised-Student-Classifier

This Script trains a Teacher Net on labeled data. The teacher net is used to pseudo label some unlabeled data. On this expanded dataset the student net is trained. 

In [1]:
from __future__ import print_function
import numpy as np
import imageio
import os
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing

####################        Test if an GPU is available     #################################
if torch.cuda.is_available():     
    print('used GPU: ' + torch.cuda.get_device_name(0))
    dev = torch.device("cuda:0")
    kwar = {'num_workers': 8, 'pin_memory': True}
    cpu = torch.device("cpu")
    
else:
    print("Warning: CUDA not found, CPU only.")
    dev = torch.device("cpu")
    kwar = {}
    cpu = torch.device("cpu")

np.random.seed(551)



## Predefine some variables

In [2]:
####################        Predefine some Variables        #################################
dataDir = '../../resized'              # path to directory of medical MNIST images

validListStud = []                          # list for valid data
testListStud = []                           # list for test data
trainListStud = []                          # list for train data





#************ Teacher Net ****************************#
#**** Net Parameters ****#
numConvsTeach1   = 5                        # number of channels produced by the convolution
convSizeTeach1   = 7                        # size of the convolving kernel
numConvsTeach2   = 10                       # number of channels produced by the convolution
convSizeTeach2   = 7                        # size of the convolving kernel

fcSizeTeach1 = 400                          # size of sample
fcSizeTeach2 = 80                           # size of sample

#**** Training Parameters ****#
testSizeTeach= 0.8                          # fraction of images for test dataset

numEpochsTeach = 10                         # number of epochs
batchSizeTeach = 300                        # size of batches
lRateTeach = 0.001                          # learning rate of classifier
momentumTeach = 0.9                         # adds a proportion of the previous weight changes to the current weight changes

t2vRatioTeach = 1.2                         # Maximum allowed ratio of validation to training loss
t2vEpochsTeach = 3                          # Number of consecutive epochs before halting if validation loss exceeds above limit

#**** DataSet Parameters ****#
sameClassOccTeach = True                    # Every Class has same count of images in test and training set

#**** Pseudo labeling Parameters ****#
SmallerFactor = 0.6                         # Factor for calculating difference between first and second prediction 
pseudoX_train = []                          # list of train imagesdata
pseudoY_train = []                          # list of pseudo labeled train data
pseudoX_test = []
pseudoY_test = []



#************ Student Net ****************************#
#**** Net Parameters ****#
numConvsStud1   = 5                         # number of channels produced by the convolution
convSizeStud1   = 7                         # size of the convolving kernel
numConvsStud2   = 10                        # number of channels produced by the convolution
convSizeStud2   = 7                         # size of the convolving kernel

fcSizeStud1 = 400                           # size of sample
fcSizeStud2 = 80                            # size of sample

#**** Training Parameters ****#
testSizeStud = 0.8                          # fraction of images for test dataset

numEpochsStud = 10                          # number of epochs
batchSizeStud = 300                         # size of batches
lRateStud = 0.001                           # learning rate of classifier
momentumStud = 0.9                          # adds a proportion of the previous weight changes to the current weight changes

t2vRatioStud = 1.2                          # Maximum allowed ratio of validation to training loss
t2vEpochsStud = 3                           # Number of consecutive epochs before halting if validation loss exceeds above limit

#**** DataSet Parameters ****#
sameClassOccStud = True                     # Every Class has same count of images in test and training set


## Read and Scale and Prepare Images

In [3]:
####################        Read and Prepare Images         #################################
classNames = os.listdir(dataDir)                                            # Each type of image can be found in its own subdirectory
numClass = len(classNames)                                                  # Number of types = number of subdirectories
imageFiles = [[os.path.join(dataDir,classNames[i],x) for x in os.listdir(os.path.join(dataDir,classNames[i]))]
            for i in range(numClass)]                                       # nested list of filenames
numEach = [len(imageFiles[i]) for i in range(numClass)]                     # count of each type of image
imageFilesList = []                                                         # un-nested list of filenames
imageClass = []                                                             # The labels -- the type of each individual image in the list
if sameClassOccTeach == True:
    for i in range(numClass):
        imageFilesList.extend(imageFiles[i][:np.min(numEach)])
        imageClass.extend([i]*np.min(numEach))
else:
    for i in range(numClass):
        imageFilesList.extend(imageFiles[i])
        imageClass.extend([i]*numEach[i])

numTotal = len(imageClass)                                                  # Total number of images
imageWidth, imageHeight = Image.open(imageFilesList[0]).size                # The dimensions of each image

print("There are",numTotal,"images in",numClass,"distinct categories")
print("Label names:",classNames)
print("Label counts:",numEach)
print("Image dimensions:",imageWidth,"x",imageHeight)




####################        Store and Rescale Images               #################################
toTensor = torchvision.transforms.ToTensor()
def scaleImage(x):                                                          # Pass a PIL image, return a tensor
    y = toTensor(x)
    if(y.min() < y.max()):                                                  # Assuming the image isn't empty, rescale so its values run from 0 to 1
        y = (y - y.min())/(y.max() - y.min()) 
    z = y - y.mean()                                                        # Subtract the mean value of the image
    return z

imageTensor = torch.stack([scaleImage(Image.open(x)) for x in imageFilesList])  # Create image (X) tensor
classTensor = torch.tensor(imageClass)                                          # Create label (Y) tensor  
print("Rescaled min pixel value = {:1.3}; Max = {:1.3}; Mean = {:1.3}"
        .format(imageTensor.min().item(),imageTensor.max().item(),imageTensor.mean().item()))




####################        Seperate DataSet to Train/Test          #################################
x_train, x_test, y_train, y_test = train_test_split(imageTensor, classTensor, test_size=testSizeTeach, random_state=4, shuffle=True, stratify=classTensor)

for i in range(numClass):
    print(sum(y_train==i))                                                # check the count of images in every class

There are 53724 images in 6 distinct categories
Label names: ['HeadCT', 'AbdomenCT', 'BreastMRI', 'Hand', 'ChestCT', 'CXR']
Label counts: [10000, 10000, 8954, 10000, 10000, 10000]
Image dimensions: 64 x 64
Rescaled min pixel value = -0.774; Max = 0.972; Mean = -2.9e-09
tensor(1791)
tensor(1790)
tensor(1790)
tensor(1791)
tensor(1791)
tensor(1791)


## Define the Neural Networks

In [4]:
####################        Define the teacher neural network       #################################
class TeacherNet(nn.Module):
    def __init__(self,xDim,yDim,numC):
        super(TeacherNet, self).__init__()

        self.conv1 = nn.Conv2d(1,numConvsTeach1,convSizeTeach1)                       # first convolutional layer
        #self.pool = nn.MaxPool2d(2,2)                                      # max pooling layer
        self.conv2 = nn.Conv2d(numConvsTeach1,numConvsTeach2, convSizeTeach2)              # second convolutional layer

        self.fc1 = nn.Linear(numConvsTeach2*(xDim-(convSizeTeach1-1)-(convSizeTeach2-1))*
                             (yDim-(convSizeTeach1-1)-(convSizeTeach2-1)), fcSizeTeach1)    # first fully connected layer
        self.fc2 = nn.Linear(fcSizeTeach1,fcSizeTeach2)                               # second fully connected layer
        self.fc3 = nn.Linear(fcSizeTeach2,numClass)                              # third fully connected layer

    def forward(self, x):
        # x = self.pool(F.relu(self.conv1(x)))                              # first conv layer with relu activation function
        # x = self.pool(F.relu(self.conv2(x)))                              # second conv layer with relu activation function
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))                                             # first fc layer with relu activation function
        x = F.relu(self.fc2(x))                                             # second fc layer with relu activation function
        x = self.fc3(x)                                                     # output layer
        return x

    def num_flat_features(self, x):                                         # Count the individual nodes in a layer
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

In [5]:
####################        Define the student neural network       #################################
class StudentNet(nn.Module):
    def __init__(self,xDim,yDim,numC):
        super(StudentNet, self).__init__()

        self.conv1 = nn.Conv2d(1,numConvsStud1,convSizeStud1)               # first convolutional layer
        #self.pool = nn.MaxPool2d(2,2)                                      # max pooling layer
        self.conv2 = nn.Conv2d(numConvsStud1,numConvsStud2, convSizeStud2)  # second convolutional layer

        self.fc1 = nn.Linear(numConvsStud2*(xDim-(convSizeStud1-1)-(convSizeStud2-1))*(yDim-(convSizeStud1-1)-(convSizeStud2-1)), fcSizeStud1)    # first fully connected layer
        self.fc2 = nn.Linear(fcSizeStud1,fcSizeStud2)                       # second fully connected layer
        self.fc3 = nn.Linear(fcSizeStud2,numClass)                          # third fully connected layer

    def forward(self, x):
        # x = self.pool(F.relu(self.conv1(x)))                              # first conv layer with relu activation function
        # x = self.pool(F.relu(self.conv2(x)))                              # second conv layer with relu activation function
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))                                             # first fc layer with relu activation function
        x = F.relu(self.fc2(x))                                             # second fc layer with relu activation function
        x = self.fc3(x)                                                     # output layer
        return x

    def num_flat_features(self, x):                                         # Count the individual nodes in a layer
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


## Train the teacher model and pseudolabel dataset

In [6]:
####################        Training the teacher neural network       #################################
teachernet = TeacherNet(imageWidth,imageHeight,numClass).cuda()                    # create neural network
criterion = nn.CrossEntropyLoss()                                           # set criterion
optimizer = optim.SGD(teachernet.parameters(), lr=lRateTeach, momentum=momentumTeach) # set optimizer to squared gradient decent

trainBats = x_train.size()[0] // batchSizeTeach                             # Number of training batches per epoch. Round down to simplify last batch
testBats = -(-x_test.size()[0] // batchSizeTeach)                           # Testing batches. Round up to include all

for epoch in range(numEpochsTeach):
    epochLoss = 0
    # X is a torch Variable
    permutation = torch.randperm(x_train.size()[0])

    for i in range(0,x_train.size()[0], batchSizeTeach):
        optimizer.zero_grad()

        indices = permutation[i:i+batchSizeTeach]
        batch_x, batch_y = x_train[indices], y_train[indices]

        # in case you wanted a semi-full example
        outputs = teachernet.forward(batch_x.to(dev))
        loss = F.cross_entropy(outputs,batch_y.to(dev))
        epochLoss += loss.item()                                            # Add loss
        
        loss.backward()
        optimizer.step()

    print("Epoch = {:-3}; Training loss = {:.4f}".format(epoch,epochLoss))

print('Finished Training')

Epoch =   0; Training loss = 64.4446
Epoch =   1; Training loss = 64.1013
Epoch =   2; Training loss = 63.0299
Epoch =   3; Training loss = 55.7356
Epoch =   4; Training loss = 27.8125
Epoch =   5; Training loss = 13.6037
Epoch =   6; Training loss = 9.5318
Epoch =   7; Training loss = 7.9737
Epoch =   8; Training loss = 7.0403
Epoch =   9; Training loss = 6.3113
Finished Training


In [7]:
####################        Evaluate the teacher Net and Pseudo Label Data        #################################
confuseMtx = np.zeros((numClass,numClass),dtype=int)
evalMat = []

counter = 0
for j in range(len(x_test)):
    pred = teachernet(x_test[j].reshape(1,1,imageWidth,imageHeight).to(dev))
    
    # Is-Class - Predicted-CLass - Prediction-Value
    evalMat.append([int(y_test[j].cpu().numpy()), pred.max(1)[1].cpu().numpy()[0], pred.max(1)[0].cpu().detach().numpy()[0]])
    confuseMtx[int(y_test[j].cpu().numpy()),pred.max(1)[1].cpu().numpy()[0]] += 1
    secElemSmallEnough = np.sort(pred.cpu().detach().numpy())[0][-2]< np.sort(pred.cpu().detach().numpy())[0][-1]*SmallerFactor
    
    
    if (pred.max(1)[0].cpu().detach().numpy()[0] >= 10.0) and secElemSmallEnough:
        #print(str(counter) + ": "+ str(pred.max(1)[0].cpu().detach().numpy()[0]))
        pseudoX_train.append(x_test[j].numpy())
        pseudoY_train.append(pred.max(1)[1].cpu().numpy()[0])
        counter += 1

    else:
        pseudoX_test.append(x_test[j].numpy())
        pseudoY_test.append(y_test[j])

correct = sum([confuseMtx[i,i] for i in range(numClass)])   # Sum over diagonal elements to count correct predictions
percentage = correct/len(x_test)*100
print("Correct predictions: ",correct,"of",len(x_test),"(",percentage,"%)")
print(confuseMtx)    
#print(evalMat)

Correct predictions:  40905 of 42980 ( 95.1721731037692 %)
[[6642  108    1   33  379    0]
 [ 107 6607   86    0  364    0]
 [   0   22 7106    0   36    0]
 [ 161   18   11 6927   28   18]
 [  14  373   87   26 6663    0]
 [  14   16    9  124   40 6960]]


## Prepare test and training dataset for student training

In [8]:
####################        Noise the pseudo labeled data        #################################
pseudoX_test = torch.tensor(pseudoX_test)
pseudoY_test = torch.LongTensor(pseudoY_test)
pseudoX_train = torch.cat((x_train, torch.Tensor(pseudoX_train)), 0)
pseudoY_train = torch.cat((y_train, torch.LongTensor(pseudoY_train)), 0)

colorjitter = torchvision.transforms.ColorJitter(hue=.05, saturation=.05)
horizontalFlip = torchvision.transforms.RandomHorizontalFlip()
randomRotation = torchvision.transforms.RandomRotation(20, resample=Image.BILINEAR)


for img in range(len(pseudoX_train)):
    random = np.random.rand(1)
    if random <= 0.25:
        pass
        #pseudoX_train[img] = colorjitter(pseudoX_train[img])
    elif random <= 0.5 and random > 0.25:
        pseudoX_train[img] = horizontalFlip(pseudoX_train[img])
    elif random <= 0.75 and random >0.5:
        pseudoX_train[img] = randomRotation(pseudoX_train[img])
    else:
        pass

print("pseudoX_test: " + str(x_test.size()) + " Type: " + str(pseudoX_test.type()))
print("x_test: " + str(x_test.size()) + " Type: " + str(x_test.type()))
print("pseudoY_test: " + str(pseudoY_test.size()) + " Type: " + str(pseudoY_test.type()))
print("y_test: " + str(y_test.size()) + " Type: " + str(y_test.type()))

print("pseudoX_train: " + str(x_train.size()) + " Type: " + str(pseudoX_train.type()))
print("x_train: " + str(x_train.size()) + " Type: " + str(x_train.type()))
print("pseudoY_train: " + str(pseudoY_train.size()) + " Type: " + str(pseudoY_train.type()))
print("y_train: " + str(y_train.size()) + " Type: " + str(y_train.type()))

pseudoX_test: torch.Size([42980, 1, 64, 64]) Type: torch.FloatTensor
x_test: torch.Size([42980, 1, 64, 64]) Type: torch.FloatTensor
pseudoY_test: torch.Size([26078]) Type: torch.LongTensor
y_test: torch.Size([42980]) Type: torch.LongTensor
pseudoX_train: torch.Size([10744, 1, 64, 64]) Type: torch.FloatTensor
x_train: torch.Size([10744, 1, 64, 64]) Type: torch.FloatTensor
pseudoY_train: torch.Size([27646]) Type: torch.LongTensor
y_train: torch.Size([10744]) Type: torch.LongTensor


<br>
<br>

# Train the Student Network with Pseudo-Label DataSet
<br>
<br>

In [9]:
####################        Training the student neural network       #################################
studentnet = StudentNet(imageWidth,imageHeight,numClass).cuda()                           # create neural network
criterionStud = nn.CrossEntropyLoss()                                           # set criterion
optimizerStud = optim.SGD(studentnet.parameters(), lr=lRateStud, momentum=momentumStud)        # set optimizer to squared gradient decent

trainBats = pseudoX_train.size()[0] // batchSizeStud                                             # Number of training batches per epoch. Round down to simplify last batch
testBats = -(-pseudoX_test.size()[0] // batchSizeStud)                                           # Testing batches. Round up to include all

for epoch in range(numEpochsStud):
    epochLoss = 0
    # X is a torch Variable
    permutation = torch.randperm(pseudoX_train.size()[0])

    for i in range(0,pseudoX_train.size()[0], batchSizeStud):
        optimizerStud.zero_grad()

        indices = permutation[i:i+batchSizeStud]
        batch_x_stud, batch_y_stud = pseudoX_train[indices], pseudoY_train[indices]
        
        # in case you wanted a semi-full example
        outputsStud = studentnet.forward(batch_x_stud.to(dev))
        loss = F.cross_entropy(outputsStud,batch_y_stud.to(dev))
        epochLoss += loss.item()                                            # Add loss
        
        loss.backward()
        optimizerStud.step()


    print("Epoch = {:-3}; Training loss = {:.4f}".format(epoch,epochLoss))

print('Finished Training')

Epoch =   0; Training loss = 131.6592
Epoch =   1; Training loss = 23.5390
Epoch =   2; Training loss = 12.4433
Epoch =   3; Training loss = 9.6472
Epoch =   4; Training loss = 8.2877
Epoch =   5; Training loss = 7.3583
Epoch =   6; Training loss = 6.7528
Epoch =   7; Training loss = 6.1330
Epoch =   8; Training loss = 5.6397
Epoch =   9; Training loss = 5.2098
Finished Training


In [12]:
####################        Evaluate the student network on bad predicted test set       #################################
confuseMtx = np.zeros((numClass,numClass),dtype=int)
evalMat = []
for j in range(len(pseudoX_test)):
    #np.shape(x_test[j].reshape(1,1,64,64))
    #plt.figure()
    #plt.imshow((x_test[j]).reshape(64,64))
    #print(str(y_test[j]) + ": " + str(classNames[y_test[j]]))
    pred = studentnet(pseudoX_test[j].reshape(1,1,imageWidth,imageHeight).to(dev))
    #print(pred.max(1,keepdim=True))
    
    # Is-Class - Predicted-CLass - Prediction-Value
    evalMat.append([int(pseudoY_test[j].cpu().numpy()), pred.max(1)[1].cpu().numpy()[0], pred.max(1)[0].cpu().detach().numpy()[0]])
    confuseMtx[int(pseudoY_test[j].cpu().numpy()),pred.max(1)[1].cpu().numpy()[0]] += 1
correct = sum([confuseMtx[i,i] for i in range(numClass)])   # Sum over diagonal elements to count correct predictions
percentage = correct/len(pseudoX_test)*100
print("Correct predictions: ",correct,"of",len(pseudoX_test),"(",percentage,"%)")
print(confuseMtx)

Correct predictions:  24455 of 26078 ( 93.7763632180382 %)
[[4323   96    0   67  364    8]
 [  10 6784   64    5  117    0]
 [   0   14 2564    0   21    8]
 [  65   44   12 1869   25   20]
 [   0  450  105   24 6584    0]
 [   4   17    7   50   26 2331]]


In [13]:
####################        Evaluate the student net on whole test dataset       #################################
confuseMtx = np.zeros((numClass,numClass),dtype=int)
evalMat = []
for j in range(len(x_test)):

    pred = studentnet(x_test[j].reshape(1,1,imageWidth,imageHeight).to(dev))
    #print(pred.max(1,keepdim=True))
    
    # Is-Class - Predicted-CLass - Prediction-Value
    evalMat.append([int(y_test[j].cpu().numpy()), pred.max(1)[1].cpu().numpy()[0], pred.max(1)[0].cpu().detach().numpy()[0]])
    confuseMtx[int(y_test[j].cpu().numpy()),pred.max(1)[1].cpu().numpy()[0]] += 1
correct = sum([confuseMtx[i,i] for i in range(numClass)])   # Sum over diagonal elements to count correct predictions
percentage = correct/len(x_test)*100
print("Correct predictions: ",correct,"of",len(x_test),"(",percentage,"%)")
print(confuseMtx)

Correct predictions:  41342 of 42980 ( 96.18892508143323 %)
[[6623  101    0   67  364    8]
 [  10 6968   64    5  117    0]
 [   0   14 7121    0   21    8]
 [  66   46   15 6991   25   20]
 [   0  450  105   24 6584    0]
 [   4   21    7   50   26 7055]]


In [14]:
torch.save(teachernet, 'MedNIST_PseudoLabel_Teachernet')
torch.save(studentnet, 'MedNIST_PseudoLabel_Studentnet')