# Load data

In [1]:
import PIL.Image
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
from sklearn.model_selection import train_test_split
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
import time
import copy

In [2]:
torch.__version__

In [3]:
torch.cuda.is_available()

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
# torch.set_default_tensor_type(torch.cuda.FloatTensor)
torch.backends.cudnn.benchmark=True

## Get CSVs

In [3]:
train_df = pd.read_csv('../data/csvs/train.csv')
# test_df = pd.read_csv('../data/csvs/test.csv')

In [4]:
train_df.head(5)

Unnamed: 0.1,Unnamed: 0,file,lvl_three,lvl_one,lvl_two
0,0,1_220_F_83683073_O4yJOnarzTjKXuUBAgkAifmiC8d0I...,1,0,3
1,1,20_220_F_5292725_818KTy3xv82nEkNolcs2m37MOV86s...,20,1,5
2,2,20_220_F_47187567_lwYwc9UQtBK5Be6v4P7HNsCc4Hhr...,20,1,5
3,3,1_220_F_38932828_Osns7NBWCq8AhJonYpQArrToDLLhT...,1,0,3
4,4,1_220_F_97168737_y0VWy7kLMby9BO6lHDfpyfNpW9o0S...,1,0,3


# Split data

Scikit-learn definitely takes the cake for ease in stratified splitting. A helper function splits the training and validation sets for ease of input into a Pytorch dataset class.

In [9]:
def train_val_split(files, target, test_size, stratify=True):
    """
    
    """
    if stratify:
        X_train, X_val, y_train, y_val = train_test_split(files, target, test_size=test_size, stratify=target)
    else:
        X_train, X_val, y_train, y_val = train_test_split(files, target, test_size=test_size)
    train_split = pd.concat([X_train, y_train], axis = 1)
    val_split = pd.concat([X_val, y_val], axis = 1) 
    return train_split, val_split

In [10]:
train_split, val_split = train_val_split(train_df['file'], train_df['lvl_one'], test_size=0.2)

# Utils

In [81]:
# import numpy as np
# import torch

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.
    
    Developed by: https://github.com/Bjarten/early-stopping-pytorch
    
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement. 
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print            
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
#             self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
#             self.save_checkpoint(val_loss, model)
            self.counter = 0

#     def save_checkpoint(self, val_loss, model):
#         '''Saves model when validation loss decrease.'''
#         if self.verbose:
#             self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
#         torch.save(model.state_dict(), self.path)
#         self.val_loss_min = val_loss

In [82]:
def train_model(model, dataloader, criterion, optimizer, num_epochs=25):
    since = time.time()
    
    train_loss_history = []
    train_acc_history = []
    val_loss_history = []
    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=5, verbose=True)
    
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        
        if early_stopping.early_stop:
                    print("Early stopping")
                    break

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloader[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloader[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloader[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            
            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
#                 best_model_wts = copy.deepcopy(model.state_dict())
                print('saving model')
                save_model('best_model.tar', epoch, model, optimizer, epoch_loss)
            # save metrics
            if phase == 'train':
                train_acc_history.append(epoch_acc)
                train_loss_history.append(epoch_loss)
            if phase == 'val':
                val_acc_history.append(epoch_acc)
                val_loss_history.append(epoch_loss)
                early_stopping(epoch_loss, model)

                if early_stopping.early_stop:
                    print("Early stopping")
                    break
            #early stopping
            
        print()
        

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    
    return model, val_acc_history

In [70]:
def save_model(PATH, epoch, model, optimizer, loss):
    torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
                }, PATH)

# Dataset & Dataloaders

In [11]:
class ImgDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, df, root_dir, percent_sample=None, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.img_df=df
        self.root_dir=root_dir
        self.transform=transform
        self.percent_sample=percent_sample

    def __len__(self):
        if self.percent_sample:
            assert self.percent_sample > 0.0, 'Percentage to sample must be >= 0 and <= 1.'
            assert self.percent_sample <= 1.0, 'Percentage to sample must be >= 0 and <= 1.'
            return int(np.floor(len(self.img_df) * self.percent_sample))
        else:
            return len(self.img_df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
            
        img_name = os.path.join(self.root_dir,
                                self.img_df.iloc[idx, 0])
        X = PIL.Image.open(img_name).convert('RGB') #Some images in greyscale, so converting to ensure 3 channels - 1 causes issues in transformers
        y = self.img_df.iloc[idx, 1]
        
        if self.transform:
            X = self.transform(X)

        return X, y

In [53]:
train_transforms = transforms.Compose([ 
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], #OG means/sds from imagenet
                         std=[0.229, 0.224, 0.225])
])
val_transforms = transforms.Compose([ 
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

In [54]:
train_dataset = ImgDataset(df=train_split,
                           root_dir='../data/images/train',
                           percent_sample=.0001,
                           transform=train_transforms
                          )
val_dataset = ImgDataset(df=val_split,
                           root_dir='../data/images/train',
                           percent_sample=.0001,
                           transform=val_transforms
                          )

In [55]:
train_loader = DataLoader(train_dataset,
                          batch_size=6, 
                          shuffle=True,
                          pin_memory=True,
                          num_workers=12)
val_loader = DataLoader(val_dataset, 
                        batch_size=6,
                        pin_memory=True,
                        num_workers=12)

In [56]:
loaders_dict = {'train': train_loader, 
                'val': val_loader}

# Training

In [16]:
num_classes = 2
batch_size = 128
num_epochs = 30

In [66]:
model_ft = models.resnet50(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, num_classes))
model_ft = model_ft.to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer_ft = torch.optim.SGD(model_ft.parameters(), 
                               lr=0.01, 
                               momentum=0.9, 
                               weight_decay=0.0001)

In [None]:
# model_ft.fc = nn.Sequential(nn.Linear(num_ftrs, 100), #num_ftrs = OG classifier input size
#                             nn.ReLU(),
#                             nn.Linear(100, num_classes)
#                            )

In [20]:
torch.cuda.empty_cache()

In [75]:
# Run the functions and save the best model in the function model_ft.
model_ft, val_acc = train_model(model_ft, loaders_dict, criterion, optimizer_ft,
                       num_epochs)

Epoch 0/19
----------
train Loss: 0.8767 Acc: 0.7647
val Loss: 85.7873 Acc: 0.5000
saving model
Validation loss decreased (inf --> 85.787308).  Saving model ...

Epoch 1/19
----------
train Loss: 1.2866 Acc: 0.5294
val Loss: 172.1930 Acc: 0.5000
EarlyStopping counter: 1 out of 5

Epoch 2/19
----------
train Loss: 0.3869 Acc: 0.8235
val Loss: 98.7412 Acc: 0.5000
EarlyStopping counter: 2 out of 5

Epoch 3/19
----------
train Loss: 1.7515 Acc: 0.5294
val Loss: 25.8205 Acc: 0.5000
Validation loss decreased (85.787308 --> 25.820492).  Saving model ...

Epoch 4/19
----------


KeyboardInterrupt: 

In [78]:
# checkpoint = torch.load('best_model.tar')
# model_ft.load_state_dict(checkpoint['model_state_dict'])
# # optimizer_ft.load_state_dict(checkpoint['optimizer_state_dict'])
# epoch = checkpoint['epoch']
# loss = checkpoint['loss']