In [1]:
import os
import random
import warnings
from sklearn.model_selection import KFold

import cv2
import numpy as np
import pandas as pd
import sklearn.metrics
import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from albumentations import Compose, ShiftScaleRotate, Resize
from albumentations.pytorch import ToTensorV2
from os import environ
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import socket

import fastai
from fastai.vision import *
from fastai.callbacks import SaveModelCallback
warnings.filterwarnings("ignore")

In [2]:
import sys
sys.path.append("../..")

In [3]:
from utils.csvlogger import *
from utils.radam import *

#from utils.arguments.train_arguments import *
from models import *
from utils.training import *

In [4]:
sz = 32
bs = 512

nfolds = 4  # keep the same split as the initial dataset
fold = 0
SEED = 2019

if environ.get('BENGALI_DATA_PATH') is not None:
    INPUT_PATH = environ.get('BENGALI_DATA_PATH')
    TRAIN_IMGS = "grapheme-imgs"
    LABELS = INPUT_PATH + "/train.csv"
else:
    assert False, "Please set the environment variable BENGALI_DATA_PATH. Read the README!"

df = pd.read_csv(LABELS)
nunique = list(df.nunique())[1:-1]

stats = ([0.0692], [0.2051])

In [5]:
data = (ImageList.from_df(df, path=INPUT_PATH, folder=TRAIN_IMGS, suffix='.png',
                          cols='image_id', convert_mode='L')
        .split_by_idx(range(fold * len(df) // nfolds, (fold + 1) * len(df) // nfolds))
        .label_from_df(cols=['grapheme_root', 'vowel_diacritic', 'consonant_diacritic'])
        .transform(data_augmentation_selector("da7"), size=sz, padding_mode='zeros')
        .databunch(bs=bs)).normalize(stats)

In [6]:
writer = SummaryWriter(log_dir='test/')

In [7]:
model = model_selector("densenet121", "initial_head", nunique)

# Train

In [8]:
learn = Learner(data, model, loss_func=Loss_combine(), opt_func=Over9000,
                metrics=[Metric_grapheme(), Metric_vowel(), Metric_consonant(), Metric_tot()])

In [9]:
logger = CSVLogger(learn, writer,f'test_log{fold}')
learn.clip_grad = 1.0
#learn.model = nn.DataParallel(learn.model, device_ids=[0, 1])
learn.split([model.head1])
learn.unfreeze()

In [10]:
learn.fit_one_cycle(
    1,
    max_lr=slice(0.2e-2, 1e-2),
    wd=[1e-3, 0.1e-1],
    pct_start=0.0,
    div_factor=100,
    callbacks=[
        logger,
        SaveModelCallback(learn, monitor='metric_tot', mode='max', name=f'test_model_{fold}')
    ]
)

epoch,train_loss,valid_loss,metric_idx,metric_idx.1,metric_idx.2,metric_tot,time
0,12.674604,10.531393,0.036853,0.519552,0.400729,0.248497,01:47


In [11]:
writer.close()

In [15]:
slack_message("[{}] - BENGALI - {}".format(socket.gethostname().upper(), getMetricTot(learn)), "experimentos")