## Directory settings

In [1]:
# =======================================================
# Directory settings
# =======================================================
import os
import pandas as pd

OUTPUT_DIR = './'
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)
    
TRAIN_PATH = '/home/yuki/RANZCR/input/ranzcr-clip-catheter-line-classification/train'

## Data Loading

In [2]:
train = pd.read_csv('/home/yuki/RANZCR/input/ranzcr-clip-catheter-line-classification/train.csv')
test = pd.read_csv('/home/yuki/RANZCR/input/ranzcr-clip-catheter-line-classification/sample_submission.csv')
train_annotations = pd.read_csv('/home/yuki/RANZCR/input/ranzcr-clip-catheter-line-classification/train_annotations.csv')

# delete suspicious data
train = train[train['StudyInstanceUID'] != '1.2.826.0.1.3680043.8.498.93345761486297843389996628528592497280'].reset_index(drop=True)
train_annotations = train_annotations[train_annotations['StudyInstanceUID'] != '1.2.826.0.1.3680043.8.498.93345761486297843389996628528592497280'].reset_index(drop=True)

## CFG

In [3]:
# ====================================================
# CFG
# ====================================================
class CFG:
    debug=False
    device='TPU' # ['TPU', 'GPU']
    nprocs=1 # [1, 8]
    print_freq=100
    num_workers=4
    model_name='resnet200d_320'
    size=640
    scheduler='CosineAnnealingWarmRestarts' # ['ReduceLROnPlateau', 'CosineAnnealingLR', 'CosineAnnealingWarmRestarts']
    teacher='/home/yuki/RANZCR/input/005-training-resnext-step1-data/resnet200d_320_fold0_best_loss_cpu.pth'
    student='/home/yuki/RANZCR/nb/nb006/ver22/resnet200d_320_fold0_best_loss.pth'
    epochs=5
    #factor=0.2 # ReduceLROnPlateau
    #patience=4 # ReduceLROnPlateau
    #eps=1e-6 # ReduceLROnPlateau
    T_max=5 # CosineAnnealingLR
    T_0=5 # CosineAnnealingWarmRestarts
    lr=2e-5 # 1e-4
    min_lr=1e-6
    batch_size=16 # 64
    weight_decay=1e-6
    gradient_accumulation_steps=1
    max_grad_norm=1000
    seed=416
    target_size=11
    target_cols=['ETT - Abnormal', 'ETT - Borderline', 'ETT - Normal',
                 'NGT - Abnormal', 'NGT - Borderline', 'NGT - Incompletely Imaged', 'NGT - Normal', 
                 'CVC - Abnormal', 'CVC - Borderline', 'CVC - Normal',
                 'Swan Ganz Catheter Present']
    n_fold=5
    trn_fold=[1] # [0, 1, 2, 3, 4]
    train=True
    
if CFG.debug:
    CFG.epochs = 3
    train = train.sample(n=3000, random_state=CFG.seed).reset_index(drop=True)

In [4]:
if CFG.device == 'TPU':
    import os
    os.system('curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py')
    os.system('python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev')
    os.system('export XLA_USE_BF16=1')
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    import torch_xla.distributed.xla_multiprocessing as xmp
    CFG.lr = CFG.lr * CFG.nprocs
    CFG.batch_size = CFG.batch_size // CFG.nprocs

## Library

In [5]:
# ====================================================
# Library
# ====================================================
import sys
sys.path.append('/home/yuki/RANZCR/input/pytorch-image-models/pytorch-image-models-master')

import os
import ast
import copy
import math
import time
import random
import shutil
from pathlib import Path
from contextlib import contextmanager
from collections import defaultdict, Counter

import scipy as sp
import numpy as np
import pandas as pd

from sklearn import preprocessing
from sklearn.metrics import roc_auc_score
from sklearn.utils import check_random_state
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold, GroupKFold, KFold

from tqdm.auto import tqdm
from functools import partial

import cv2
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import torchvision.models as models
from torch.nn.parameter import Parameter
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau

from albumentations import (
    Compose, OneOf, Normalize, Resize, RandomResizedCrop, RandomCrop, HorizontalFlip, VerticalFlip, 
    RandomBrightness, RandomContrast, RandomBrightnessContrast, Rotate, ShiftScaleRotate, Cutout, 
    IAAAdditiveGaussianNoise, Transpose, HueSaturationValue, CoarseDropout
    )
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

import timm

if CFG.device == 'TPU':
    import ignite.distributed as idist
elif CFG.device == 'GPU':
    from torch.cuda.amp import autocast, GradScaler

import warnings 
warnings.filterwarnings('ignore')

## Utils

In [6]:
# ====================================================
# Utils
# ====================================================
def get_score(y_true, y_pred):
    scores = []
    for i in range(y_true.shape[1]):
        score = roc_auc_score(y_true[:, i], y_pred[:, i])
        scores.append(score)
    avg_score = np.mean(scores)
    return avg_score, scores

@contextmanager
def timer(name):
    t0 = time.time()
    LOGGER.info(f'[{name}] start')
    yield
    LOGGER.info(f'[{name}] done in {time.time() - t0:.0f}')
    
def init_logger(log_file=OUTPUT_DIR+'train.log'):
    from logging import getLogger, INFO, FileHandler,  Formatter,  StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=log_file)
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = init_logger()

def seed_torch(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

seed_torch(seed=CFG.seed)

## CV splits

In [7]:
folds = train.copy()
Fold = GroupKFold(n_splits=CFG.n_fold)
groups = folds['PatientID'].values
for n, (train_index, val_index) in enumerate(Fold.split(folds, folds[CFG.target_cols], groups)):
    folds.loc[val_index, 'fold'] = int(n)
folds['fold'] = folds['fold'].astype(int)
print(folds.groupby('fold').size())

fold
0    6017
1    6017
2    6016
3    6016
4    6016
dtype: int64

In [8]:
# ====================================================
# Dataset
# ====================================================
COLOR_MAP = {'ETT - Abnormal': (255, 0, 0),
             'ETT - Borderline': (0, 255, 0),
             'ETT - Normal': (0, 0, 255),
             'NGT - Abnormal': (255, 255, 0),
             'NGT - Borderline': (255, 0, 255),
             'NGT - Incompletely Imaged': (0, 255, 255),
             'NGT - Normal': (128, 0, 0),
             'CVC - Abnormal': (0, 128, 0),
             'CVC - Borderline': (0, 0, 128),
             'CVC - Normal': (128, 128, 0),
             'Swan Ganz Catheter Present': (128, 0, 128),
            }

class TrainDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.file_names = df['StudyInstanceUID'].values
        self.labels = df[CFG.target_cols].values
        self.transform = transform
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        file_name = self.file_names[idx]
        file_path = f'{TRAIN_PATH}/{file_name}.jpg'
        image = cv2.imread(file_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']
        label = torch.tensor(self.labels[idx]).float()
        return image, label

## Transforms

In [9]:
# ====================================================
# Transforms
# ====================================================
def get_transforms(*, data):
    
    if data == 'train':
        return Compose([
            #Resize(CFG.size, CFG.size),
            RandomResizedCrop(CFG.size, CFG.size, scale=(0.85, 1.0)),
            HorizontalFlip(p=0.5),
            RandomBrightnessContrast(p=0.2, brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2)),
            HueSaturationValue(p=0.2, hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2),
            ShiftScaleRotate(p=0.2, shift_limit=0.0625, scale_limit=0.2, rotate_limit=20),
            CoarseDropout(p=0.2),
            Cutout(p=0.2, max_h_size=16, max_w_size=16, fill_value=(0., 0., 0.), num_holes=16),
            Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensorV2(),
        ])
    
    elif data == 'valid':
        return Compose([
            Resize(CFG.size, CFG.size),
            Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
            ToTensorV2(),

        ])

## MODEL

In [11]:
# ====================================================
# MODEL
# ====================================================
class CustomResNet200D(nn.Module):
    def __init__(self, model_name='resnet200d_320', pretrained=False):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=False)
        if pretrained:
            pretrained_path = '../input/resnet200d-pretrained-weight/resnet200d_ra2-bdba9bf9.pth'
            self.model.load_state_dict(torch.load(pretrained_path))
            print(f'load {model_name} pretrained model')
        n_features = self.model.fc.in_features
        self.model.global_pool = nn.Identity()
        self.model.fc = nn.Identity()
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(n_features, CFG.target_size)
        
    def forward(self, x):
        bs = x.size(0)
        features = self.model(x)
        pooled_features = self.pooling(features).view(bs, -1)
        output = self.fc(pooled_features)
        return features, pooled_features, output

## Helper functions

In [12]:
# ====================================================
# Helper functions
# ====================================================
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        
def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))

def train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    if CFG.device == 'GPU':
        scaler = GradScaler()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to train mode
    model.train()
    start = end = time.time()
    global_step = 0
    for step, (images, labels) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        if CFG.device == 'GPU':
            with autocast():
                _, _, y_preds = model(images)
                loss = criterion(y_preds, labels)
                # record loss
                losses.update(loss.item(), batch_size)
                if CFG.gradient_accumulation_steps > 1:
                    loss = loss / CFG.gradient_accumulation_steps
                scaler.scale(loss).backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
                if (step + 1) % CFG.gradient_accumulation_steps == 0:
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                    global_step += 1
                    
        elif CFG.device == 'TPU':
            _, _, y_preds = model(images)
            loss = criterion(y_preds, labels)
            # record loss
            losses.update(loss.item(), batch_size)
            if CFG.gradient_accumulation_steps > 1:
                loss = loss / CFG.gradient_accumulation_steps
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
            if (step + 1) % CFG.gradient_accumulation_steps == 0:
                xm.optimizer_step(optimizer, barrier=True)
                optimizer.zero_grad()
                global_step += 1
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if CFG.device == 'GPU':
            if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
                print('Epoch: [{0}][{1}/{2}] '
                      'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                      'Elapsed {remain:s} '
                      'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                      'Grad: {grad_norm:.4f}  '
                      #'LR: {lr:.6f}  '
                      .format(
                       epoch+1, step, len(train_loader), batch_time=batch_time,
                       data_time=data_time, loss=losses,
                       remain=timeSince(start, float(step+1)/len(train_loader)),
                       grad_norm=grad_norm,
                       #lr=scheduler.get_lr()[0],
                       ))
        elif CFG.device == 'TPU':
            if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
                xm.master_print('Epoch: [{0}][{1}/{2}] '
                                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                                'Elapsed {remain:s} '
                                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                                'Grad: {grad_norm:.4f}  '
                                #'LR: {lr:.6f}  '
                                .format(
                                epoch+1, step, len(train_loader), batch_time=batch_time,
                                data_time=data_time, loss=losses,
                                remain=timeSince(start, float(step+1)/len(train_loader)),
                                grad_norm=grad_norm,
                                #lr=scheduler.get_lr()[0],
                                ))
    return losses.avg

def valid_fn(valid_loader, model, criterion, device):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # switch to evaluation mode
    model.eval()
    trues = []
    preds = []
    start = end = time.time()
    for step, (images, labels) in enumerate(valid_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        images = images.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        # compute loss
        with torch.no_grad():
            _, _, y_preds = model(images)
        loss = criterion(y_preds, labels)
        losses.update(loss.item(), batch_size)
        # record accuracy
        trues.append(labels.to('cpu').numpy())
        preds.append(y_preds.sigmoid().to('cpu').numpy())
        if CFG.gradient_accumulation_steps > 1:
            loss = loss / CFG.gradient_accumulation_steps
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if CFG.device == 'GPU':
            if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
                print('EVAL: [{0}/{1}] '
                      'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                      'Elapsed {remain:s} '
                      'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                      .format(
                       step, len(valid_loader), batch_time=batch_time,
                       data_time=data_time, loss=losses,
                       remain=timeSince(start, float(step+1)/len(valid_loader)),
                       ))
        elif CFG.device == 'TPU':
            if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
                xm.master_print('EVAL: [{0}/{1}] '
                                'Data {data_time.val:.3f} ({data_time.avg:.3f}) '
                                'Elapsed {remain:s} '
                                'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                                .format(
                                step, len(valid_loader), batch_time=batch_time,
                                data_time=data_time, loss=losses,
                                remain=timeSince(start, float(step+1)/len(valid_loader)),
                                ))
    trues = np.concatenate(trues)
    predictions = np.concatenate(preds)
    return losses.avg, predictions, trues

## Train loop

In [13]:
# ====================================================
# Train loop
# ====================================================
def train_loop(folds, fold):

    if CFG.device == 'GPU':
        LOGGER.info(f"========== fold: {fold} training ==========")
    elif CFG.device == 'TPU':
        if CFG.nprocs == 1:
            LOGGER.info(f"========== fold: {fold} training ==========")
        elif CFG.nprocs == 8:
            xm.master_print(f"========== fold: {fold} training ==========")
            
    # ====================================================
    # loader
    # ====================================================
    trn_idx = folds[folds['fold'] != fold].index
    val_idx = folds[folds['fold'] == fold].index

    train_folds = folds.loc[trn_idx].reset_index(drop=True)
    valid_folds = folds.loc[val_idx].reset_index(drop=True)
    valid_labels = valid_folds[CFG.target_cols].values

    train_dataset = TrainDataset(train_folds, 
                                 transform=get_transforms(data='train'))
    valid_dataset = TrainDataset(valid_folds, 
                                 transform=get_transforms(data='valid'))
    
    if CFG.device == 'GPU':
        train_loader = DataLoader(train_dataset, 
                                  batch_size=CFG.batch_size, 
                                  shuffle=True, 
                                  num_workers=CFG.num_workers, pin_memory=True, drop_last=True)
        valid_loader = DataLoader(valid_dataset, 
                                  batch_size=CFG.batch_size * 2, 
                                  shuffle=False, 
                                  num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
        
    elif CFG.device == 'TPU':
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset,
                                                                        num_replicas=xm.xrt_world_size(),
                                                                        rank=xm.get_ordinal(),
                                                                        shuffle=True)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=CFG.batch_size,
                                                   sampler=train_sampler,
                                                   drop_last=True,
                                                   num_workers=CFG.num_workers)
        
        valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset,
                                                                        num_replicas=xm.xrt_world_size(),
                                                                        rank=xm.get_ordinal(),
                                                                        shuffle=False)
        valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                                   batch_size=CFG.batch_size * 2,
                                                   sampler=valid_sampler,
                                                   drop_last=False,
                                                   num_workers=CFG.num_workers)
        
    # ====================================================
    # scheduler 
    # ====================================================
    def get_scheduler(optimizer):
        if CFG.scheduler=='ReduceLROnPlateau':
            scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=CFG.factor, patience=CFG.patience, verbose=True, eps=CFG.eps)
        elif CFG.scheduler=='CosineAnnealingLR':
            scheduler = CosineAnnealingLR(optimizer, T_max=CFG.T_max, eta_min=CFG.min_lr, last_epoch=-1)
        elif CFG.scheduler=='CosineAnnealingWarmRestarts':
            scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG.T_0, T_mult=1, eta_min=CFG.min_lr, last_epoch=-1)
        return scheduler
    
    # ====================================================
    # model & optimizer
    # ====================================================
    if CFG.device == 'TPU':
        device = xm.xla_device()
    elif CFG.device == 'GPU':
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = CustomResNet200D(CFG.model_name, pretrained=False)
    model.load_state_dict(torch.load(CFG.student, map_location=torch.device('cpu'))['model'])
    model.to(device)

    optimizer = Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay, amsgrad=False)
    scheduler = get_scheduler(optimizer)
    
    # ====================================================
    # loop
    # ====================================================
    criterion = nn.BCEWithLogitsLoss()

    best_score = 0.
    best_loss = np.inf
    
    for epoch in range(CFG.epochs):
        
        start_time = time.time()
        
        # train
        if CFG.device == 'TPU':
            if CFG.nprocs == 1:
                avg_loss = train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device)
            elif CFG.nprocs == 8:
                para_train_loader = pl.ParallelLoader(train_loader, [device])
                avg_loss = train_fn(para_train_loader.per_device_loader(device), model, criterion, optimizer, epoch, scheduler, device)
        elif CFG.device == 'GPU':
            avg_loss = train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device)
        
        # eval
        if CFG.device == 'TPU':
            if CFG.nprocs == 1:
                avg_val_loss, preds, _ = valid_fn(valid_loader, model, criterion, device)
            elif CFG.nprocs == 8:
                para_valid_loader = pl.ParallelLoader(valid_loader, [device])
                avg_val_loss, preds, valid_labels = valid_fn(para_valid_loader.per_device_loader(device), model, criterion, device)
                preds = idist.all_gather(torch.tensor(preds)).to('cpu').numpy()
                valid_labels = idist.all_gather(torch.tensor(valid_labels)).to('cpu').numpy()
        elif CFG.device == 'GPU':
            avg_val_loss, preds, _ = valid_fn(valid_loader, model, criterion, device)
            
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(avg_val_loss)
        elif isinstance(scheduler, CosineAnnealingLR):
            scheduler.step()
        elif isinstance(scheduler, CosineAnnealingWarmRestarts):
            scheduler.step()
            
        # scoring
        score, scores = get_score(valid_labels, preds)

        elapsed = time.time() - start_time
        
        if CFG.device == 'GPU':
            LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
            LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f}  Scores: {np.round(scores, decimals=4)}')
        elif CFG.device == 'TPU':
            if CFG.nprocs == 1:
                LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
                LOGGER.info(f'Epoch {epoch+1} - Score: {score:.4f}  Scores: {np.round(scores, decimals=4)}')
            elif CFG.nprocs == 8:
                xm.master_print(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  avg_val_loss: {avg_val_loss:.4f}  time: {elapsed:.0f}s')
                xm.master_print(f'Epoch {epoch+1} - Score: {score:.4f}  Scores: {np.round(scores, decimals=4)}')
        
        if score > best_score:
            best_score = score
            if CFG.device == 'GPU':
                LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
                torch.save({'model': model.state_dict(), 
                            'preds': preds},
                           OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best_score.pth')
            elif CFG.device == 'TPU':
                if CFG.nprocs == 1:
                    LOGGER.info(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
                elif CFG.nprocs == 8:
                    xm.master_print(f'Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model')
                xm.save({'model': model, 
                         'preds': preds}, 
                        OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best_score.pth')
                
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            if CFG.device == 'GPU':
                LOGGER.info(f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
                torch.save({'model': model.state_dict(), 
                            'preds': preds},
                           OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best_loss.pth')
            elif CFG.device == 'TPU':
                if CFG.nprocs == 1:
                    LOGGER.info(f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
                elif CFG.nprocs == 8:
                    xm.master_print(f'Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model')
                xm.save({'model': model, 
                         'preds': preds}, 
                        OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best_loss.pth')
        
        # inference用に全て保存しておく
        if CFG.device == 'TPU':
            xm.save({'model': model}, OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_epoch{epoch+1}.pth')
        elif CFG.device == 'GPU':
            torch.save({'model': model.state_dict()}, OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_epoch{epoch+1}.pth')
        
        if CFG.nprocs != 8:
            check_point = torch.load(OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best_score.pth')
            for c in [f'pred_{c}' for c in CFG.target_cols]:
                valid_folds[c] = np.nan
            valid_folds[[f'pred_{c}' for c in CFG.target_cols]] = check_point['preds']

    return valid_folds

In [14]:
# ====================================================
# main
# ====================================================
def main():

    """
    Prepare: 1.train  2.folds
    """

    def get_result(result_df):
        preds = result_df[[f'pred_{c}' for c in CFG.target_cols]].values
        labels = result_df[CFG.target_cols].values
        score, scores = get_score(labels, preds)
        LOGGER.info(f'Score: {score:<.4f}  Scores: {np.round(scores, decimals=4)}')
        
    if CFG.train:
        # train 
        oof_df = pd.DataFrame()
        for fold in range(CFG.n_fold):
            if fold in CFG.trn_fold:
                _oof_df = train_loop(folds, fold)
                oof_df = pd.concat([oof_df, _oof_df])
                if CFG.nprocs != 8:
                    LOGGER.info(f"========== fold: {fold} result ==========")
                    get_result(_oof_df)
                    
        if CFG.nprocs != 8:
            # CV result
            LOGGER.info(f"========== CV ==========")
            get_result(oof_df)
            # save result
            oof_df.to_csv(OUTPUT_DIR+'oof_df.csv', index=False)

In [15]:
if __name__ == '__main__':
    if CFG.device == 'TPU':
        def _mp_fn(rank, flags):
            torch.set_default_tensor_type('torch.FloatTensor')
            a = main()
        FLAGS = {}
        xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=CFG.nprocs, start_method='fork')
    elif CFG.device == 'GPU':
        main()



Epoch: [1][0/1504] Data 1.994 (1.994) Elapsed 1m 19s (remain 1986m 53s) Loss: 0.0713(0.0713) Grad: 0.5428  
Epoch: [1][100/1504] Data 0.019 (0.046) Elapsed 4m 59s (remain 69m 24s) Loss: 0.1233(0.1285) Grad: 0.6004  
Epoch: [1][200/1504] Data 0.017 (0.037) Elapsed 7m 19s (remain 47m 28s) Loss: 0.1473(0.1196) Grad: 0.7587  
Epoch: [1][300/1504] Data 0.018 (0.034) Elapsed 9m 36s (remain 38m 25s) Loss: 0.1413(0.1171) Grad: 0.8039  
Epoch: [1][400/1504] Data 0.016 (0.032) Elapsed 11m 53s (remain 32m 43s) Loss: 0.0992(0.1176) Grad: 0.6312  
Epoch: [1][500/1504] Data 0.016 (0.031) Elapsed 14m 10s (remain 28m 23s) Loss: 0.0748(0.1173) Grad: 0.4402  
Epoch: [1][600/1504] Data 0.018 (0.030) Elapsed 16m 27s (remain 24m 44s) Loss: 0.0723(0.1168) Grad: 0.6617  
Epoch: [1][700/1504] Data 0.018 (0.030) Elapsed 18m 48s (remain 21m 32s) Loss: 0.0589(0.1173) Grad: 0.3692  
Epoch: [1][800/1504] Data 0.018 (0.030) Elapsed 21m 6s (remain 18m 31s) Loss: 0.1049(0.1174) Grad: 0.5952  
Epoch: [1][900/1504] Dat

Epoch 1 - avg_train_loss: 0.1148  avg_val_loss: 0.1167  time: 2576s
Epoch 1 - Score: 0.9677  Scores: [0.9804 0.9679 0.9932 0.9709 0.9769 0.9876 0.9889 0.949  0.8989 0.9312
 0.9995]
Epoch 1 - Save Best Score: 0.9677 Model


EVAL: [188/189] Data 0.025 (0.053) Elapsed 5m 39s (remain 0m 0s) Loss: 0.5681(0.1167) 


Epoch 1 - Save Best Loss: 0.1167 Model


Epoch: [2][0/1504] Data 2.008 (2.008) Elapsed 1m 43s (remain 2604m 28s) Loss: 0.0456(0.0456) Grad: 0.3178  
Epoch: [2][100/1504] Data 0.016 (0.048) Elapsed 4m 4s (remain 56m 38s) Loss: 0.1274(0.1152) Grad: 0.7037  
Epoch: [2][200/1504] Data 0.019 (0.038) Elapsed 6m 28s (remain 42m 0s) Loss: 0.1434(0.1083) Grad: 0.7529  
Epoch: [2][300/1504] Data 0.018 (0.035) Elapsed 8m 48s (remain 35m 12s) Loss: 0.1697(0.1078) Grad: 0.6808  
Epoch: [2][400/1504] Data 0.017 (0.033) Elapsed 11m 8s (remain 30m 39s) Loss: 0.1157(0.1085) Grad: 0.7742  
Epoch: [2][500/1504] Data 0.018 (0.032) Elapsed 13m 28s (remain 26m 58s) Loss: 0.0750(0.1086) Grad: 0.4279  
Epoch: [2][600/1504] Data 0.019 (0.032) Elapsed 15m 52s (remain 23m 50s) Loss: 0.0637(0.1086) Grad: 0.4494  
Epoch: [2][700/1504] Data 0.019 (0.031) Elapsed 18m 11s (remain 20m 50s) Loss: 0.0650(0.1093) Grad: 0.3652  
Epoch: [2][800/1504] Data 0.020 (0.031) Elapsed 20m 31s (remain 18m 1s) Loss: 0.0808(0.1097) Grad: 0.4941  
Epoch: [2][900/1504] Data 0

Epoch 2 - avg_train_loss: 0.1077  avg_val_loss: 0.1185  time: 2524s
Epoch 2 - Score: 0.9679  Scores: [0.9859 0.9675 0.9936 0.9718 0.9753 0.9872 0.9881 0.9499 0.8976 0.9305
 0.9994]
Epoch 2 - Save Best Score: 0.9679 Model


EVAL: [188/189] Data 0.026 (0.054) Elapsed 4m 57s (remain 0m 0s) Loss: 0.2353(0.1185) 
Epoch: [3][0/1504] Data 2.138 (2.138) Elapsed 0m 3s (remain 93m 12s) Loss: 0.0321(0.0321) Grad: 0.2430  
Epoch: [3][100/1504] Data 0.017 (0.048) Elapsed 2m 23s (remain 33m 17s) Loss: 0.1032(0.1108) Grad: 0.8187  
Epoch: [3][200/1504] Data 0.020 (0.039) Elapsed 4m 47s (remain 31m 2s) Loss: 0.2013(0.1052) Grad: 1.0057  
Epoch: [3][300/1504] Data 0.018 (0.035) Elapsed 7m 7s (remain 28m 27s) Loss: 0.1666(0.1030) Grad: 0.7161  
Epoch: [3][400/1504] Data 0.019 (0.034) Elapsed 9m 27s (remain 26m 1s) Loss: 0.0839(0.1036) Grad: 0.5545  
Epoch: [3][500/1504] Data 0.019 (0.032) Elapsed 11m 47s (remain 23m 36s) Loss: 0.0596(0.1038) Grad: 0.3640  
Epoch: [3][600/1504] Data 0.029 (0.032) Elapsed 14m 9s (remain 21m 16s) Loss: 0.0873(0.1036) Grad: 0.7952  
Epoch: [3][700/1504] Data 0.020 (0.031) Elapsed 16m 29s (remain 18m 53s) Loss: 0.0579(0.1040) Grad: 0.4743  
Epoch: [3][800/1504] Data 0.018 (0.031) Elapsed 18m 4

Epoch 3 - avg_train_loss: 0.1022  avg_val_loss: 0.1193  time: 2421s
Epoch 3 - Score: 0.9677  Scores: [0.9848 0.969  0.9934 0.9719 0.9752 0.9873 0.988  0.948  0.8965 0.9309
 0.9996]


EVAL: [188/189] Data 0.027 (0.053) Elapsed 5m 0s (remain 0m 0s) Loss: 0.3385(0.1193) 
Epoch: [4][0/1504] Data 2.069 (2.069) Elapsed 0m 3s (remain 87m 43s) Loss: 0.0356(0.0356) Grad: 0.3557  
Epoch: [4][100/1504] Data 0.017 (0.047) Elapsed 2m 22s (remain 33m 3s) Loss: 0.1101(0.1061) Grad: 0.8982  
Epoch: [4][200/1504] Data 0.017 (0.038) Elapsed 4m 45s (remain 30m 51s) Loss: 0.1318(0.0997) Grad: 0.6972  
Epoch: [4][300/1504] Data 0.019 (0.034) Elapsed 7m 4s (remain 28m 17s) Loss: 0.1474(0.0980) Grad: 0.8164  
Epoch: [4][400/1504] Data 0.020 (0.033) Elapsed 9m 24s (remain 25m 52s) Loss: 0.0850(0.0985) Grad: 0.5454  
Epoch: [4][500/1504] Data 0.019 (0.032) Elapsed 11m 44s (remain 23m 29s) Loss: 0.0628(0.0988) Grad: 0.4476  
Epoch: [4][600/1504] Data 0.019 (0.031) Elapsed 14m 3s (remain 21m 7s) Loss: 0.0619(0.0981) Grad: 0.6097  
Epoch: [4][700/1504] Data 0.020 (0.031) Elapsed 16m 26s (remain 18m 49s) Loss: 0.0560(0.0986) Grad: 0.4391  
Epoch: [4][800/1504] Data 0.020 (0.031) Elapsed 18m 46

Epoch 4 - avg_train_loss: 0.0966  avg_val_loss: 0.1202  time: 2428s
Epoch 4 - Score: 0.9673  Scores: [0.9857 0.9689 0.9934 0.9704 0.9749 0.9872 0.9881 0.9467 0.8958 0.9301
 0.9994]


EVAL: [188/189] Data 0.028 (0.060) Elapsed 5m 17s (remain 0m 0s) Loss: 0.3220(0.1202) 
Epoch: [5][0/1504] Data 2.093 (2.093) Elapsed 0m 4s (remain 108m 41s) Loss: 0.0286(0.0286) Grad: 0.3638  
Epoch: [5][100/1504] Data 0.018 (0.048) Elapsed 2m 27s (remain 34m 12s) Loss: 0.1217(0.1029) Grad: 0.9825  
Epoch: [5][200/1504] Data 0.018 (0.037) Elapsed 4m 49s (remain 31m 19s) Loss: 0.1780(0.0956) Grad: 0.8052  
Epoch: [5][300/1504] Data 0.021 (0.035) Elapsed 7m 16s (remain 29m 2s) Loss: 0.1412(0.0932) Grad: 1.2147  
Epoch: [5][400/1504] Data 0.019 (0.033) Elapsed 9m 38s (remain 26m 30s) Loss: 0.0886(0.0938) Grad: 0.6692  
Epoch: [5][500/1504] Data 0.019 (0.032) Elapsed 12m 0s (remain 24m 3s) Loss: 0.0704(0.0942) Grad: 0.5629  
Epoch: [5][600/1504] Data 0.019 (0.032) Elapsed 14m 23s (remain 21m 37s) Loss: 0.0655(0.0942) Grad: 0.8421  
Epoch: [5][700/1504] Data 0.019 (0.032) Elapsed 16m 49s (remain 19m 16s) Loss: 0.0440(0.0949) Grad: 0.3364  
Epoch: [5][800/1504] Data 0.019 (0.031) Elapsed 19m

Epoch 5 - avg_train_loss: 0.0921  avg_val_loss: 0.1189  time: 2475s
Epoch 5 - Score: 0.9677  Scores: [0.9871 0.9689 0.9933 0.9706 0.9752 0.9871 0.9881 0.9473 0.8964 0.9316
 0.9994]


EVAL: [188/189] Data 0.028 (0.059) Elapsed 5m 13s (remain 0m 0s) Loss: 0.3016(0.1189) 


Score: 0.9679  Scores: [0.9859 0.9675 0.9936 0.9718 0.9753 0.9872 0.9881 0.9499 0.8976 0.9305
 0.9994]
Score: 0.9679  Scores: [0.9859 0.9675 0.9936 0.9718 0.9753 0.9872 0.9881 0.9499 0.8976 0.9305
 0.9994]


In [16]:
# save as cpu
if CFG.device == 'TPU':
    for fold in range(CFG.n_fold):
        if fold in CFG.trn_fold:
            # best score
            state = torch.load(OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best_score.pth')
            torch.save({'model': state['model'].to('cpu').state_dict(), 
                        'preds': state['preds']}, 
                        OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best_score_cpu.pth')
            # best loss
            state = torch.load(OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best_loss.pth')
            torch.save({'model': state['model'].to('cpu').state_dict(), 
                        'preds': state['preds']}, 
                        OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_best_loss_cpu.pth')
            
            for epoch in range(CFG.epochs):
                state = torch.load(OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_epoch{epoch+1}.pth')
                torch.save({'model': state['model'].to('cpu').state_dict()}, 
                            OUTPUT_DIR+f'{CFG.model_name}_fold{fold}_{epoch+1}_cpu.pth')