In [None]:
import warnings
warnings.filterwarnings('ignore')
import time
import numpy as np
import os
import shutil
import torch
import torchvision
import torch.backends.cudnn as cudnn
import torch.utils.data
from torchvision import transforms as T

In [None]:
from PIL import Image
from tqdm import tqdm_notebook as tqdm
from pprint import PrettyPrinter
from torch_source import *

import xml.etree.ElementTree as ET

In [None]:
import vision.torchvision.models.detection.faster_rcnn as MOD
import vision.torchvision.models.detection.backbone_utils as backbone_utils
from vision.torchvision.models import resnet

In [None]:
import pandas as pd

In [None]:
def train(train_loader, model, optimizer, epoch, on_dev):
    """
    One epoch's training.

    :param train_loader: DataLoader for training data; every batch is an (images, annotations) pair
    :param model: model; called as model(images, annotations) it must return a dict holding
                  'loss_classifier', 'loss_box_reg', 'loss_objectness' and 'loss_rpn_box_reg'
    :param optimizer: optimizer
    :param epoch: epoch number (only used in the status line)
    :param on_dev: when truthy each batch is moved to the module-level `device`,
                   otherwise it is moved with .cuda()
    """
    model.train()  # training mode

    batch_time = AverageMeter()  # forward prop. + back prop. time
    data_time = AverageMeter()  # data loading time
    losses = AverageMeter()  # loss

    def to_target(tensor):
        # Single place that decides where a tensor of the batch goes
        return tensor.to(device) if on_dev else tensor.cuda()

    # Pre-bind the batch names so the clean-up after the loop is safe for an empty loader
    images = annotations = None

    start = time.time()

    # Batches
    for i, (images, annotations) in enumerate(tqdm(train_loader, 'Training')):
        data_time.update(time.time() - start)

        # Move to the target device; a batch of images may be one tensor or a list of tensors
        if isinstance(images, torch.Tensor):
            images = to_target(images)
        else:
            images = [to_target(img) for img in images]
        annotations = [{key: to_target(value) for key, value in annos.items()} for annos in annotations]

        # Forward prop. The classifier and objectness terms are weighted 3x
        loss = model(images, annotations)
        total_los = 3*loss['loss_classifier' ] + \
                    loss['loss_box_reg'    ] + \
                    3*loss['loss_objectness' ] + \
                    loss['loss_rpn_box_reg']

        # Backward prop.
        optimizer.zero_grad()
        total_los.backward()

        # Update model
        optimizer.step()

        losses.update(total_los.item(), len(images))
        batch_time.update(time.time() - start)

        start = time.time()

        # Print status
        if i != 0 and i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(epoch, i,
                                                                  len(train_loader),
                                                                  batch_time = batch_time,
                                                                  data_time = data_time,
                                                                  loss = losses))

    # free some memory since their histories may be stored
    del images, annotations
    
def validate(val_loader, model, on_dev):
    """
    One epoch's validation.

    :param val_loader: DataLoader for validation data; every batch is an (images, annotations) pair
    :param model: model; called as model(images, annotations) it must return a dict holding
                  'loss_classifier', 'loss_box_reg', 'loss_objectness' and 'loss_rpn_box_reg'
    :param on_dev: when truthy each batch is moved to the module-level `device`,
                   otherwise it is moved with .cuda()
    :return: average validation loss
    """
    # model.eval() is deliberately not called here.
    # NOTE(review): presumably the model only returns the loss dict while in training mode - confirm
    batch_time = AverageMeter()
    losses = AverageMeter()

    def to_target(tensor):
        # Single place that decides where a tensor of the batch goes
        return tensor.to(device) if on_dev else tensor.cuda()

    start = time.time()

    # Prohibit gradient computation explicitly to keep memory usage down
    with torch.no_grad():
        # Batches
        for i, (images, annotations) in enumerate(tqdm(val_loader,'Validation')):
            # Move to the target device; a batch of images may be one tensor or a list of tensors
            if isinstance(images, torch.Tensor):
                images = to_target(images)
            else:
                images = [to_target(img) for img in images]
            annotations = [{key: to_target(value) for key, value in annos.items()} for annos in annotations]

            # Loss (unweighted sum of the four terms, unlike the weighted sum used for training)
            loss = model(images, annotations)
            total_los = loss['loss_classifier' ] + \
                        loss['loss_box_reg'    ] + \
                        loss['loss_objectness' ] + \
                        loss['loss_rpn_box_reg']

            losses.update(total_los.item(), len(images))
            batch_time.update(time.time() - start)

            start = time.time()

            # Print status
            if i != 0 and i % print_freq == 0:
                print('[{0}/{1}]\t'
                      'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(i, len(val_loader),
                                                                      batch_time=batch_time,
                                                                      loss=losses))

    print('\n * LOSS - {loss.avg:.3f}\n'.format(loss=losses))

    return losses.avg

In [None]:
# Module-level pretty-printer instance (PrettyPrinter is also imported in an earlier cell).
# NOTE(review): none of the functions visible in this notebook read this `pp`; main() builds its own.
from pprint import PrettyPrinter
pp = PrettyPrinter()

def measure_mAP(model,dataloader_test, data_folder, data_size):
    """
    Evaluate every pending checkpoint found in the working directory and log its mAP.

    A checkpoint is "pending" when its file name contains 'BAGS_pretren' but not 'BEST_'.
    Each pending file is moved to <cwd>/chkpnts/<data_size>/, loaded into `model`, evaluated
    with `evaluate` on the module-level `device`, and appended as one row
    (checkpoint, mAP, per-class APs) to <data_folder>/mAPs.pkl.

    :param model: model whose weights are overwritten by each evaluated checkpoint
    :param dataloader_test: DataLoader passed to `evaluate`
    :param data_folder: folder holding the dataset files and the 'mAPs.pkl' history
    :param data_size: dataset name, used as the sub-folder of 'chkpnts'
    :return: one-row DataFrame for the last checkpoint evaluated, or None when nothing was pending
    """
    folder_path = os.path.abspath(os.getcwd())
    destin_folder_path = os.path.join(folder_path, 'chkpnts', data_size)
    # makedirs also creates a missing 'chkpnts' parent, which a plain mkdir would choke on
    os.makedirs(destin_folder_path, exist_ok=True)

    history_path = os.path.join(data_folder, 'mAPs.pkl')
    try:
        database = pd.read_pickle(history_path)
    except Exception:
        # Best effort: a missing or unreadable history simply starts a new one
        print('No such file')
        database = pd.DataFrame()

    files_to_mAP_calc = sorted(name for name in os.listdir(folder_path)
                               if 'BAGS_pretren' in name and 'BEST_' not in name)

    newline = None
    new_rows = []
    for checkpoint_filename in files_to_mAP_calc:
        stored_path = os.path.join(destin_folder_path, checkpoint_filename)
        shutil.move(os.path.join(folder_path, checkpoint_filename), stored_path)

        print('processing on file ', checkpoint_filename,'...')
        # torch.load unpickles the file: only point this at checkpoints you trust
        checkpoint = torch.load(stored_path, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        del checkpoint
        APs, mAP = evaluate(dataloader_test, model, data_folder=data_folder, device = device)

        summary = pd.DataFrame([{'checkpoint':checkpoint_filename, 'mAP':mAP}])
        newline = pd.concat([summary, pd.DataFrame([APs])], axis = 1)
        new_rows.append(newline)

    if new_rows:
        # One concat instead of growing the frame per file; the row index is left untouched
        # (every new row keeps label 0) because existing readers of mAPs.pkl index by label
        database = pd.concat([database] + new_rows)
        database.to_pickle(history_path)
    return newline

In [None]:
def _reset_box_predictor(model, num_classes):
    """Replace the ROI box-predictor layers with freshly initialised ones sized for ``num_classes``."""
    predictor = model.roi_heads.box_predictor
    predictor.cls_score = torch.nn.Linear(in_features = predictor.cls_score.in_features,
                                          out_features = num_classes)
    predictor.bbox_pred = torch.nn.Linear(in_features = predictor.bbox_pred.in_features,
                                          out_features = num_classes * 4)


def _build_sgd(model, learning_rate):
    """Create an SGD optimizer with the trainable biases and the other trainable parameters in two groups."""
    biases = list()
    not_biases = list()
    for param_name, param in model.named_parameters():
        if param.requires_grad:
            if param_name.endswith('.bias'):
                biases.append(param)
            else:
                not_biases.append(param)
    return torch.optim.SGD(params=[{'params': biases, 'lr': learning_rate}, {'params': not_biases}],
                           lr=learning_rate, momentum=momentum, weight_decay=weight_decay)


def main(dev = None):
    """
    Training and validation.

    Builds the Faster R-CNN model, initialises its weights, then for every epoch: rebuilds the
    datasets and loaders, periodically reloads the checkpoint with the best logged mAP, decays
    the learning rate, trains, validates, saves a checkpoint, measures its mAP and prunes
    'chkpnts/<data_size>' down to the best entry of 'mAPs.pkl'.

    Reads the module-level settings (data_folder, data_size, n_classes, checkpoint,
    checkpoint_filename, batch_size, test_batch_size, epochs, lr, momentum, weight_decay, ...)
    and updates the globals epochs_since_improvement, start_epoch, best_loss and epoch.

    :param dev: device to train on; when None the model is wrapped in DataParallel and moved
                with .cuda()
    """
    global epochs_since_improvement, start_epoch, label_map, best_loss, epoch, checkpoint
    # Local import: the notebook's import cells do not import OrderedDict explicitly
    from collections import OrderedDict

    # Textual recipe stored inside every checkpoint by save_checkpoint
    model_code = '''
        backbone_fpn = backbone_utils.resnet_fpn_backbone('trident_resnet50',pretrained=False)
        model = MOD.FasterRCNN(backbone = backbone_fpn, num_classes= 21)'''
    rpn_pre_nms_top_n_train=12000
    rpn_pre_nms_top_n_test=12000
    rpn_post_nms_top_n_train=500
    rpn_post_nms_top_n_test=500
    box_detections_per_img = 20
    backbone_fpn = backbone_utils.resnet_fpn_backbone('trident_resnet50',
                                                      pretrained=False,
                                                      input_channels = 3)
    model = MOD.FasterRCNN(backbone = backbone_fpn, num_classes = 21,
                          rpn_pre_nms_top_n_train  = rpn_pre_nms_top_n_train,
                          rpn_pre_nms_top_n_test   = rpn_pre_nms_top_n_test,
                          rpn_post_nms_top_n_train = rpn_post_nms_top_n_train,
                          rpn_post_nms_top_n_test  = rpn_post_nms_top_n_test,
                          box_detections_per_img = box_detections_per_img
                          )

    mAP_path = os.path.join(data_folder,'mAPs.pkl')

    # NOTE: torch.load unpickles its input; only checkpoints produced by this project should be used.
    if checkpoint is None:
        print('No Checkpoint')
        # Default: start from the VOC weights; prefer the best logged dataset checkpoint if it exists
        weights_file = 'BEST_VOC0712.pth.tar'
        use_voc_weights = True
        if os.path.exists(mAP_path):
            try:
                df = pd.read_pickle(mAP_path)
                best_row = df[df.mAP == df.mAP.max()].iloc[0]
                candidate = os.path.join(data_size, best_row['checkpoint'])
                print('checkpoint - ', candidate)
                if os.path.exists(os.path.join('chkpnts', candidate)):
                    print("I'm going to use ",candidate)
                    weights_file = candidate
                    use_voc_weights = False
                else:
                    print('file is not exists! I"m going to use BEST_VOC0712.pth.tar')
            except Exception as err:
                # Best effort: an unreadable or empty history falls back to the VOC weights
                print('Could not read', mAP_path, '-', err)

        faster_dict = torch.load(os.path.join('chkpnts',weights_file), map_location='cpu')['model']
        if not use_voc_weights:
            # Dataset checkpoints were saved with n_classes-sized heads, so resize before loading
            _reset_box_predictor(model, n_classes)
        model_dict = model.state_dict()
        for key_z in model_dict:
            if key_z in faster_dict:
                model_dict[key_z] = faster_dict[key_z]
        model.load_state_dict(model_dict)
        del faster_dict
        if use_voc_weights:
            # The VOC heads are discarded and replaced by fresh n_classes-sized ones
            _reset_box_predictor(model, n_classes)
        optimizer = _build_sgd(model, lr)
        learning_rate = lr
    else:
        resume = torch.load(checkpoint, map_location='cpu')
        start_epoch = resume['epoch']
        epochs_since_improvement = resume['epochs_since_improvement']
        best_loss = resume['best_loss']
        print('\nLoaded checkpoint from epoch %d. Best loss is %.3f.\n' % (start_epoch, best_loss))
        print('Do you want to load LR from checkpoint? 1 - Yes, 2 - No')
        try:
            decision = int(input())
        except ValueError:
            decision = 2  # anything that is not a number counts as "No"
        if decision == 1:
            learning_rate = resume['lr']
        else:
            learning_rate = lr

        optimizer = resume['optimizer']
        _reset_box_predictor(model, n_classes)
        model.load_state_dict(resume['model'])
        del resume

    # Place the model once. `model` stays the bare module (used for state dicts and mAP),
    # `train_model` is what train()/validate() call; wrapping inside the epoch loop would
    # nest DataParallel and prefix the state-dict keys with one more 'module.' every epoch.
    if dev is not None:
        on_dev = True
        model = model.to(dev)
        train_model = model
    else:
        print('There is NOT any device!')
        on_dev = False
        cudnn.benchmark = True
        train_model = torch.nn.DataParallel(model).cuda()

    # Epochs
    num_to_change = epochs/10          # the learning rate decays every `num_to_change` epochs
    reload_every = num_to_change/5     # the best checkpoint is reloaded every `reload_every` epochs

    for i, epoch in enumerate(range(start_epoch, epochs)):
        # The datasets are rebuilt every epoch (the training data may have been resampled below)
        dataset_train = PascalVOCDataset(data_folder = data_folder,
                                         split = 'train',
                                         transforms = get_transform(train=True), max_size = 400)

        dataset_val   = PascalVOCDataset(data_folder = './data/scans/FINAL_TEST/he/',
                                         split = 'test',
                                         transforms = get_transform(train=False), max_size = 400)

        dataset_test  = PascalVOCDataset(data_folder = './data/scans/FINAL_TEST/he/',
                                         split = 'test',
                                         transforms = get_transform(train=False))

        dataloader_train = torch.utils.data.DataLoader(dataset_train,
                                                       batch_size = batch_size,
                                                       shuffle=True,
                                                       num_workers=workers,
                                                       collate_fn=collate_fn)

        dataloader_val   = torch.utils.data.DataLoader(dataset_val,
                                                       batch_size = batch_size,
                                                       shuffle=True,
                                                       num_workers=workers,
                                                       collate_fn=collate_fn)

        dataloader_test  = torch.utils.data.DataLoader(dataset_test,
                                                       batch_size = test_batch_size,
                                                       shuffle=False,
                                                       num_workers=workers,
                                                       collate_fn=collate_fn)

        # Periodically restart from the checkpoint with the best logged mAP
        ch_changer = False
        if i % reload_every == 0:
            ch_changer = True
            print('Is the way exists? - ',os.path.exists(mAP_path))
            if os.path.exists(mAP_path):
                database = pd.read_pickle(mAP_path)
                # Positional lookup: the row labels written to mAPs.pkl are not unique
                chpt = database[database.mAP==database.mAP.max()].checkpoint.iloc[0]
                print('chpt = ',chpt)
                best_ckpt = torch.load(os.path.join('chkpnts',data_size,chpt), map_location='cpu')
                epochs_since_improvement = best_ckpt['epochs_since_improvement']
                model.load_state_dict(best_ckpt['model'])
                print('Loaded BEST checkpoint from epoch %d. Best loss is %.3f.\n' % (best_ckpt['epoch'],
                                                                                     best_ckpt['best_loss']))
                del best_ckpt

        # Learning-rate decay
        lr_changer = False
        epochs_run = epoch - start_epoch
        if epochs_run != 0 and epochs_run % num_to_change == 0:
            lr_changer = True
            print('PAST_learning_rate = ', learning_rate)
            learning_rate = learning_rate*0.75
            print('PRES_learning_rate = ', learning_rate)

        # A reload or a new learning rate both call for a fresh optimizer
        if lr_changer or ch_changer:
            optimizer = _build_sgd(model, learning_rate)

        # One epoch's training
        train(train_loader=dataloader_train, model=train_model, optimizer=optimizer, epoch=epoch, on_dev = on_dev)
        # One epoch's validation
        val_loss = validate(val_loader=dataloader_val, model=train_model, on_dev = on_dev)

        # Did validation loss improve?
        is_best = val_loss < best_loss
        best_loss = min(val_loss, best_loss)

        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint (weights copied to the CPU, state_dict built only once)
        cpu_state = OrderedDict((name, tensor.to(torch.device('cpu')))
                                for name, tensor in model.state_dict().items())
        save_checkpoint(epoch, epochs_since_improvement, model_code, cpu_state,
                        optimizer, learning_rate, val_loss, best_loss, is_best, checkpoint_filename)
        del cpu_state

        # measure mAP
        last_rez = measure_mAP(model, dataloader_test, data_folder, data_size)

        # dataset resampling
        if os.path.exists(mAP_path):
            div_coeff = float(data_size.split('_')[-1])
            if last_rez is not None and float(last_rez.mAP.iloc[0]) > .5:
                resample_dataset(0.97, last_rez, div_coeff, wtf=False, random_choice = True)

            # Keep only the best row of the history ...
            history = pd.read_pickle(mAP_path)
            best_only = history.sort_values(by='mAP', ascending=False).head(1)
            best_only.to_pickle(mAP_path)
            # ... and remove checkpoints with low mAP
            kept = best_only.checkpoint.tolist()
            for file in os.listdir(os.path.join('chkpnts', data_size)):
                if file not in kept:
                    os.remove(os.path.join('chkpnts',data_size,file))

In [None]:
# dataset selection: list the sub-folders of `folder` and let the user pick one by its index
folder = './data/scans/'
dataFolder_list = sorted(os.listdir(folder))
if not dataFolder_list:
    # With nothing to choose from, the prompt below could never be satisfied
    raise FileNotFoundError('No datasets found in ' + folder)
for i, element in enumerate(dataFolder_list):
    print(i, element)

flag = 0
avalible_indexes = list(range(len(dataFolder_list)))
while flag != 1:
    try:
        num = int(input())
    except ValueError:
        # Only a non-numeric answer is retried; KeyboardInterrupt / EOFError propagate,
        # so the prompt can be interrupted instead of looping forever
        print('Wrong! try again!')
        continue
    if num in avalible_indexes:
        flag = 1
    else:
        print('Wrong! try again!')
print('\nDataset',dataFolder_list[num],' selected!')

data_size = dataFolder_list[num]
print('data_size = ', data_size)
data_folder = folder+data_size  # folder with data files(PATHs to images and annotations)
print('data_folder = ', data_folder)

# CUDA device used for training and evaluation, selected by index
free_device = 0
device = torch.device('cuda:%d' % free_device)

# Number of classes in the dataset
n_classes = 9

# Resume from 'BEST_<checkpoint_filename>' when that file sits in the working directory
checkpoint_filename = 'BAGS_pretren_bc_' + data_size + '.pth.tar'
checkpoint = 'BEST_' + checkpoint_filename
if not os.path.exists(checkpoint):
    checkpoint = None
    print("there isn't any checkpoint")
else:
    print('checkpoint exists')

In [None]:
# Training settings, defined at module level and read as globals by the functions above.
batch_size = 1 #---------------> batch size
test_batch_size = 8 #----------> test_batch_size
start_epoch = 0 #--------------> start at this epoch
epochs = 300 #-----------------> number of epochs to run without early-stopping
epochs_since_improvement = 0 #-> number of epochs since there was an improvement in the validation metric
best_loss = 100 #--------------> assume a high loss at first
workers = 4 #------------------> number of workers for loading data in the DataLoader
print_freq = 200 #-------------> print training or validation status every __ batches
lr = 1e-3 #--------------------> learning rate
momentum = 0.9 #---------------> momentum
weight_decay = 5e-4 #----------> weight decay
# NOTE(review): keep_difficult is not referenced in the visible part of this notebook;
# presumably it is consumed by the dataset / evaluation helpers - confirm.
keep_difficult = True 

In [None]:
# Entry point: start training on the module-level `device` when run as the main module.
if __name__ == '__main__':    
    main(device)