In [None]:
import pandas as pd
from PIL import Image
import numpy as np
import torch
from torchvision import transforms, models

In [None]:
# define global constants

# path to the input summary sheet (read with pandas.read_excel in the main part)
g_InputSummaryPath_str = r'Database\EgoHand_Separated_PAIs\Label.xlsx'

# how many epochs will be trained
# note: in each epoch, 1/n of the data is set aside for validation (a different part in each epoch),
# and the rest of the data is used for training
g_NumEpoch_int = 5

# since the data are not "natural" objects, the per-channel mean and std used for normalization are pre-calculated
# see the notebook NormalizationSupport
g_NormMean_lst = [0.7406573632924783, 0.7232460663149365, 0.6959515512913943]
g_NormStd_lst = [0.22704669694242244, 0.24139483459696767, 0.2690579683428062]

# average size of the input data; used for input data scaling; depends on your screen size!
# see the notebook NormalizationSupport
g_AvgW_int = 83
g_AvgH_int = 136

# minimum side length (in pixels) the upsampled average image must reach in both dimensions
g_TargetSize_int = 224

# smallest integer scale factor for which both upsampled sides reach g_TargetSize_int.
# Ceiling division is used (instead of floor division + 1) so that a side which already equals
# the target, or divides it evenly, is not scaled one step further than necessary;
# if the original size is already big enough, the factor is one.
g_ScaleFactor_int = max((g_TargetSize_int + g_AvgW_int - 1) // g_AvgW_int,
                        (g_TargetSize_int + g_AvgH_int - 1) // g_AvgH_int)
# final (upsampled) image size; used as the center-crop size in the main part
g_UpSampledW_int = g_ScaleFactor_int*g_AvgW_int
g_UpSampledH_int = g_ScaleFactor_int*g_AvgH_int

In [None]:
def MyLoader(f_Path_str):
    """Load one picture and return it as a float tensor rescaled to [0, 1].

    The picture is opened with PIL, converted to three-channel RGB, enlarged by
    the module-level integer ``g_ScaleFactor_int`` (nearest-neighbour, only when
    that factor is greater than one) and finally converted with
    ``transforms.ToTensor``. Upsampling is done here, per picture, because each
    picture has a different size.

    Args:
        f_Path_str: location of the picture; passed unchanged to ``Image.open``.

    Returns:
        The tensor produced by ``transforms.ToTensor()`` for the (possibly
        upsampled) RGB picture.
    """
    # the context manager releases the file even if a later step raises; previously the
    # opened image was rebound by resize() and only the resized copy was closed
    with Image.open(f_Path_str) as l_Source_img:
        # force three channels so that grayscale / palette / RGBA files still match the
        # three-value mean and std applied by the normalization step downstream;
        # convert() returns an in-memory image, so it stays valid after the file is closed
        l_im_img = l_Source_img.convert('RGB')

    # upsample (integer factor, applied to both sides)
    if g_ScaleFactor_int > 1:
        l_Width_int, l_Height_int = l_im_img.size
        l_im_img = l_im_img.resize((l_Width_int*g_ScaleFactor_int, l_Height_int*g_ScaleFactor_int), resample=Image.NEAREST)

    # to tensor (implies rescaling to [0, 1])
    # note: ToTensor is a class, so it is instantiated first and then called
    return transforms.ToTensor()(l_im_img)

In [None]:
class MyDataset(torch.utils.data.Dataset):
    """Dataset that pairs a sequence of image paths with a sequence of labels.

    Items are produced lazily: the picture at a given position is only read
    (through ``loader``) when that position is requested.
    """

    def __init__(self, imgpaths, labels, loader=MyLoader, transform=None):
        """Store the inputs without touching any file.

        Args:
            imgpaths: indexable collection of image locations.
            labels: indexable collection of targets, addressed with the same
                index as ``imgpaths``.
            loader: callable applied to one entry of ``imgpaths`` to obtain the sample.
            transform: optional callable applied to the loaded sample.
        """
        self.images, self.target = imgpaths, labels
        self.loader, self.transform = loader, transform

    def __len__(self):
        """Return the number of stored image paths."""
        return len(self.images)

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at ``index``."""
        sample = self.loader(self.images[index])
        label = self.target[index]
        # the transform is optional; skip it when none (or a falsy value) was given
        if self.transform:
            sample = self.transform(sample)
        return sample, label

In [None]:
# main part
# Fine-tunes a pretrained ResNet18 on the pictures listed in the summary sheet.
# In every epoch a different contiguous 1/g_NumEpoch_int slice of the data is held out,
# the network is trained on the remaining samples and then evaluated on the held-out slice.

# set random seed
torch.manual_seed(7)

# set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# initialize the model
# get the ResNet18 pretrained model as base and replace the last fully connected layer with custom output
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of `weights=`;
# the original call is kept so the notebook still runs on the version it was written for
l_Net = models.resnet18(pretrained=True, progress=False)
#l_Net = models.resnet18(pretrained=False, progress=False)
in_ftr  = l_Net.fc.in_features
out_ftr = 34 # 34 different PAIs available
l_Net.fc = torch.nn.Linear(in_ftr,out_ftr,bias=True)
#l_Net.load_state_dict(torch.load(r'Model\20210317_Trial1.pth'))
l_Net.to(device)

# set loss function and optimizer
l_LossFcn = torch.nn.CrossEntropyLoss()
l_Optim = torch.optim.SGD(l_Net.parameters(), lr=0.001, momentum=0.9)
#l_Optim = torch.optim.SGD(l_Net.parameters(), lr=0.00001, momentum=0.9)

# open the summary sheet
l_Summary_df = pd.read_excel(g_InputSummaryPath_str, index_col=None)
# get all file paths and the corresponding labels
l_FilePath_lst = l_Summary_df['File_Name'].tolist()
l_Label_lst = l_Summary_df['Label_ID'].tolist()

# the preprocessing does not depend on the epoch, so it is built once and shared by both splits:
# center-crop every picture to the same (H, W) so samples can be batched, then normalize
l_Transform_Cmps = transforms.Compose([transforms.CenterCrop((g_UpSampledH_int, g_UpSampledW_int)),
                                       transforms.Normalize(g_NormMean_lst, g_NormStd_lst)])

# loop over all epochs
for l_Epoch_int in range(g_NumEpoch_int):
    # get the training set and validation set for this epoch:
    # the validation set is one contiguous slice, the training set is everything around it
    # (slicing replaces the former element-by-element pop() loop and yields the same lists)
    # NOTE(review): the len % g_NumEpoch_int samples at the end of the sheet are never held out,
    # and a slice validated in a later epoch has already been trained on in earlier epochs
    l_NumValidationData_int = len(l_FilePath_lst) // g_NumEpoch_int
    l_StartIndexValidationData_int = l_Epoch_int * l_NumValidationData_int
    l_EndIndexValidationData_int = l_StartIndexValidationData_int + l_NumValidationData_int
    l_ValidationData_lst = l_FilePath_lst[l_StartIndexValidationData_int:l_EndIndexValidationData_int]
    l_ValidationLabel_lst = l_Label_lst[l_StartIndexValidationData_int:l_EndIndexValidationData_int]
    l_TrainingData_lst = l_FilePath_lst[:l_StartIndexValidationData_int] + l_FilePath_lst[l_EndIndexValidationData_int:]
    l_TrainingLabel_lst = l_Label_lst[:l_StartIndexValidationData_int] + l_Label_lst[l_EndIndexValidationData_int:]

    # initialize training dataset
    l_TrainingData_DaSt = MyDataset(l_TrainingData_lst, l_TrainingLabel_lst, transform=l_Transform_Cmps)
    # put into data loader
    # NOTE(review): shuffle=False feeds the samples in sheet order; if the sheet is sorted by label,
    # consecutive batches contain a single class - consider shuffle=True after confirming the row order
    l_TrainingData_DaLder = torch.utils.data.DataLoader(l_TrainingData_DaSt, batch_size=4, shuffle=False)

    # start training (the validation pass below switches to eval mode, so switch back explicitly)
    l_Net.train()
    running_loss = 0.0
    for i, data in enumerate(l_TrainingData_DaLder, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        l_Optim.zero_grad()

        # forward + backward + optimize
        outputs = l_Net(inputs)
        loss = l_LossFcn(outputs, labels)
        loss.backward()
        l_Optim.step()

        # print statistics: mean training loss over the last 100 mini-batches
        running_loss += loss.item()
        if i % 100 == 99:
            print(running_loss / 100)
            running_loss = 0.0

    # evaluate on the held-out slice of this epoch (it was previously built but never used)
    l_ValidationData_DaSt = MyDataset(l_ValidationData_lst, l_ValidationLabel_lst, transform=l_Transform_Cmps)
    l_ValidationData_DaLder = torch.utils.data.DataLoader(l_ValidationData_DaSt, batch_size=4, shuffle=False)
    l_ValidationLoss_flt = 0.0
    l_ValidationCorrect_int = 0
    l_ValidationSeen_int = 0
    # eval mode + no_grad: no weight update and no gradient bookkeeping during evaluation
    l_Net.eval()
    with torch.no_grad():
        for l_ValInputs_tsr, l_ValLabels_tsr in l_ValidationData_DaLder:
            l_ValInputs_tsr, l_ValLabels_tsr = l_ValInputs_tsr.to(device), l_ValLabels_tsr.to(device)
            l_ValOutputs_tsr = l_Net(l_ValInputs_tsr)
            l_BatchSize_int = l_ValLabels_tsr.size(0)
            # the loss is a batch mean; weight it by the batch size so the last (smaller) batch counts correctly
            l_ValidationLoss_flt += l_LossFcn(l_ValOutputs_tsr, l_ValLabels_tsr).item() * l_BatchSize_int
            l_ValidationCorrect_int += (l_ValOutputs_tsr.argmax(dim=1) == l_ValLabels_tsr).sum().item()
            l_ValidationSeen_int += l_BatchSize_int
    l_Net.train()

    # the validation slice is empty when there are fewer samples than epochs; nothing to report then
    if l_ValidationSeen_int > 0:
        print('epoch {}: validation loss {:.4f}, validation accuracy {:.4f} ({} samples)'.format(
            l_Epoch_int,
            l_ValidationLoss_flt / l_ValidationSeen_int,
            l_ValidationCorrect_int / l_ValidationSeen_int,
            l_ValidationSeen_int))

In [None]:
#torch.save(l_Net.state_dict(), r'Model\*****.pth')