## Train Script for Bengali AI models
Loads the model classes and processing scripts, then runs training and evaluation.

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
# Packages
import itertools

import pandas as pd
import torch
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

from model.modelBase import *
from model.wrapperModel import *
from utils.evalUtils import *
from ProcessAndAugment import *

In [3]:
# Data locations: the raw inputs and the processed outputs both live under
# the same data directory.
datadir = "./data"
inputdir = f"{datadir}/raw"
outputdir = f"{datadir}/processed"

In [4]:
# Parameters
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
batch_size = 128

# Output sizes for the grapheme, vowel and consonant targets.
n_grapheme, n_vowel, n_consonant = 168, 11, 7
# Combined output width; handed to the predictor as `out_dim` further down.
n_total = sum((n_grapheme, n_vowel, n_consonant))
print('n_total', n_total)

n_total 186


In [5]:
# Model Selection

# Core model: a densenet taking single-channel input and producing n_total
# outputs, moved to the selected device.
predictor = densenet(in_channels=1, out_dim=n_total).to(device)
print('predictor', type(predictor))

# Wrapper class around the predictor; in the training loop it is called with
# (images, labels) and returns (loss, metrics, pred).
classifier = BengaliClassifier(predictor).to(device)
# Assigning `classifier.requires_grad = True` only creates a plain Python
# attribute on the module and never touches its parameters. requires_grad_()
# is the nn.Module method that actually flags every parameter as trainable.
# NOTE(review): this also un-freezes any layers the predictor may have frozen
# on purpose -- confirm none are meant to stay frozen.
classifier.requires_grad_(True)

print('classifier', type(classifier))

predictor <class 'model.modelBase.densenet'>
classifier <class 'model.wrapperModel.BengaliClassifier'>


In [6]:
# Model Parameters
epochs = 10
lr = 1e-3  # TODO: starting with flat LR, but need to implement scheduler
# NOTE(review): `bs` repeats the value of `batch_size` from the parameters
# cell; the DataLoader in the training cell reads `bs`, so keep them in sync.
bs = 128

# Plain SGD over all parameters of the wrapped classifier.
optimizer = torch.optim.SGD(
    classifier.parameters(),
    lr=lr,
)

# The plateau scheduler is constructed here but is being ignored for now,
# until we have a baseline.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.7,
    patience=5,
    min_lr=1e-10,
)

validate_every = 5  # TODO: validate every n batches or epochs
checkpoint_every = 5  # TODO: implement model checkpoints

## Prep Data
Uses our processing and data augmentation script to generate the training dataset.

In [7]:
# Read the training label file, then build the dataset from the raw inputs.
train = pd.read_csv(f"{datadir}/train.csv")
indices = [0]  # just set to list of all indices when actually training
# genDataset hands back the dataset class instance together with a second
# object bound to crop_rsz_img (presumably the cropped/resized images --
# confirm against genDataset).
dataset, crop_rsz_img = genDataset(
    indices,
    inputdir,
    data_type="train",
    train=train,
)

# Sanity check: show the first example, an (image array, label array) pair.
print(dataset.get_example(0))


image_df_list 1
~~Loaded Images~~
~~Standardized Images~~
(array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32), array([15,  9,  5], dtype=int64))


In [8]:
# Per-target weights for the weighted random sampler used for each epoch.
# Only the first len(crop_rsz_img) rows of `train` are passed in.
# NOTE(review): assumes those leading rows line up with the loaded images --
# confirm once `indices` covers more than the first file.
n_loaded = len(crop_rsz_img)
consonant_weights = genWeightTensor("consonant_diacritic", train[:n_loaded])
root_weights = genWeightTensor("grapheme_root", train[:n_loaded])
vowel_weights = genWeightTensor("vowel_diacritic", train[:n_loaded])
grapheme_weights = genWeightTensor("grapheme", train[:n_loaded])

# Lookup from target column name to its weight tensor.
weights = dict(
    consonant_diacritic=consonant_weights,
    grapheme_root=root_weights,
    vowel_diacritic=vowel_weights,
    grapheme=grapheme_weights,
)

#weight_keys = list(weights.keys())
# Order in which the sampler focus rotates across epochs; repeating a key
# gives that target more epochs.
weight_keys = [
    'grapheme',
    'grapheme_root',
    'vowel_diacritic',
    'consonant_diacritic',
    'grapheme',
    'grapheme_root',
]

## Training

In [None]:
# testing without sampler for now
train_loader = DataLoader(dataset, batch_size=bs, shuffle=True)
num_batches = len(train_loader)

# One full pass over train_loader per epoch. Each epoch is paired with the
# next entry of weight_keys, cycling when there are more epochs than keys.
for i, wkey in zip(range(epochs), itertools.cycle(weight_keys)):
    print(i, wkey)

    # generate sampler and loader specific to epoch
    # (the sampler is currently disabled: wgt_val is looked up but unused, and
    # the shuffled loader built above serves every epoch)
    wgt_val = weights[wkey]
    #sampler = WeightedRandomSampler(wgt_val, len(wgt_val))
    #train_loader = DataLoader(dataset, batch_size=bs, sampler=sampler)

    # put both modules in training mode
    predictor.train()
    classifier.train()

    for images, labels in train_loader:
        # Tensors carry autograd state themselves, so the deprecated
        # torch.autograd.Variable wrapper is not needed before moving them.
        images = images.to(device)
        labels = labels.to(device)

        # reset gradients accumulated by the previous step
        optimizer.zero_grad()

        # run model - requires 4d float input, hence the added axis at dim 1
        loss, metrics, pred = classifier(images.unsqueeze(1).float(), labels)

        # backpropagate and update the parameters
        loss.backward()
        optimizer.step()

    # NOTE: these are the metrics of the final batch of the epoch only,
    # not an average over the whole epoch.
    print(metrics)

0 grapheme
{'loss': 5.858402729034424, 'loss_grapheme': 4.046599388122559, 'loss_vowel': 1.0486302375793457, 'loss_consonant': 0.7631732821464539, 'acc_grapheme': tensor(0.0882, device='cuda:0'), 'acc_vowel': tensor(0.5882, device='cuda:0'), 'acc_consonant': tensor(0.7353, device='cuda:0')}
1 grapheme_root
{'loss': 5.1883864402771, 'loss_grapheme': 3.2126145362854004, 'loss_vowel': 1.2205678224563599, 'loss_consonant': 0.7552042603492737, 'acc_grapheme': tensor(0.2353, device='cuda:0'), 'acc_vowel': tensor(0.5294, device='cuda:0'), 'acc_consonant': tensor(0.6765, device='cuda:0')}
2 vowel_diacritic
{'loss': 4.680876731872559, 'loss_grapheme': 3.026343822479248, 'loss_vowel': 0.9334233999252319, 'loss_consonant': 0.7211093902587891, 'acc_grapheme': tensor(0.2941, device='cuda:0'), 'acc_vowel': tensor(0.6471, device='cuda:0'), 'acc_consonant': tensor(0.7353, device='cuda:0')}
3 consonant_diacritic
{'loss': 4.369133949279785, 'loss_grapheme': 2.6501224040985107, 'loss_vowel': 0.8149975538