## Train Script for Bengali AI models
Loads the model classes and data-processing scripts, then runs training and evaluation.

In [1]:
# Automatically reload edited project modules (e.g. model/, utils/, ProcessAndAugment) before each cell runs
%load_ext autoreload
%autoreload 2

In [2]:
# Packages
import itertools

import numpy as np
import pandas as pd
import torch
import torchvision
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.data import DataLoader
from torch.autograd import Variable


from model.modelBase import *
from model.wrapperModel import *
from utils.evalUtils import *
from ProcessAndAugment import *

In [3]:
# Paths: raw inputs and processed outputs both live under the data directory.
datadir = "./data"
inputdir = f"{datadir}/raw"
outputdir = f"{datadir}/processed"

In [4]:
# Parameters
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): the training cells shown below read `bs`, not `batch_size`;
# confirm which batch size is intended and drop the other.
batch_size = 128

# Class counts per output head; their sum is the model's `out_dim`.
n_grapheme = 168   # grapheme-root classes
n_vowel = 11       # vowel-diacritic classes
n_consonant = 7    # consonant-diacritic classes
n_total = sum((n_grapheme, n_vowel, n_consonant))
print('n_total', n_total)

n_total 186


In [5]:
# Model Selection

# Core model: project `densenet` over single-channel images, emitting
# n_total logits (all three heads combined).
predictor = densenet(in_channels=1, out_dim=n_total).to(device)
# `module.requires_grad = True` merely sets a plain Python attribute on the
# Module and leaves its parameters untouched; `requires_grad_` actually marks
# every parameter as trainable.
predictor.requires_grad_(True)
print('predictor', type(predictor))

# Wrapper called as classifier(images, labels) -> (loss, metrics, predictions)
# in the training loop.
classifier = BengaliClassifier(predictor).to(device)
classifier.requires_grad_(True)

print('classifier', type(classifier))

predictor <class 'model.modelBase.densenet'>
classifier <class 'model.wrapperModel.BengaliClassifier'>


In [6]:
# Model Parameters
epochs = 10
# TODO: flat learning rate for now; switch to one of the schedulers below.
lr = 0.01
# NOTE(review): the DataLoader in the training cell reads this `bs`; the
# Parameters cell also defines `batch_size = 128` — reconcile the two.
bs = 64

optimizer = torch.optim.SGD(classifier.parameters(), lr=lr)

# Candidate LR schedulers (not enabled yet):
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#    optimizer, mode='min', factor=0.7, patience=5, min_lr=1e-10)
#scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 5)

# TODO: validation (every n batches or epochs) and model checkpointing are not
# implemented yet; these cadences are placeholders.
validate_every = 5
checkpoint_every = 5

## Prep Data
Uses our data processing and augmentation script (`ProcessAndAugment`).

In [7]:
# Load the training label table and build the dataset from the raw images.
train = pd.read_csv(f"{datadir}/train.csv")
# Only the first index for now; pass the full list of indices for a real training run.
indices = [0]
# genDataset returns the dataset object plus (judging by the name) the cropped/resized images.
dataset, crop_rsz_img = genDataset(indices, inputdir, data_type="train", train=train)

# Sanity check: inspect the first (image, labels) example.
print(dataset.get_example(0))


image_df_list 1
~~Loaded Images~~
~~Standardized Images~~
(array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32), array([15,  9,  5], dtype=int64))


In [8]:
# Sample weights for the WeightedRandomSampler, one per label column, built by
# the project helper genWeightTensor.
# NOTE(review): assumes the first len(crop_rsz_img) rows of `train` line up
# with the loaded images — confirm against genDataset.
# Slice once (explicitly positional) instead of repeating the slice per column.
labelled_rows = train.iloc[:len(crop_rsz_img)]

consonant_weights = genWeightTensor("consonant_diacritic", labelled_rows)
root_weights = genWeightTensor("grapheme_root", labelled_rows)
vowel_weights = genWeightTensor("vowel_diacritic", labelled_rows)
grapheme_weights = genWeightTensor("grapheme", labelled_rows)

weights = {"consonant_diacritic": consonant_weights,
           "grapheme_root": root_weights,
           "vowel_diacritic": vowel_weights,
           "grapheme": grapheme_weights}

# Order in which epochs cycle through the weightings; repeating a key gives
# that target more epochs of focus. Use list(weights.keys()) to cycle evenly.
weight_keys = ['grapheme', 'grapheme_root', 'vowel_diacritic', 'consonant_diacritic', 'grapheme', 'grapheme_root']

## Training

In [25]:
# testing without sampler for now: one shuffled loader shared by every epoch
train_loader = DataLoader(dataset, batch_size=bs, shuffle=True)
num_batches = len(train_loader)

# Each epoch is paired with the next weighting key, cycling through weight_keys.
for i, wkey in zip(range(epochs), itertools.cycle(weight_keys)):
    print(i, wkey)

    # Weighting this epoch would sample with; unused until the sampler below is re-enabled.
    wgt_val = weights[wkey]
    #sampler = WeightedRandomSampler(wgt_val, len(wgt_val))
    #train_loader = DataLoader(dataset, batch_size=bs, sampler=sampler)

    # put the models in training mode
    predictor.train()
    classifier.train()

    # per-batch results, averaged at the end of the epoch
    acc_root = []
    acc_consonant = []
    acc_vowel = []
    running_loss = []

    for images, labels in train_loader:
        # `Variable` is a deprecated no-op wrapper; tensors carry autograd directly.
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # the model requires a 4D float input: add a channel axis, (N, H, W) -> (N, 1, H, W)
        loss, metrics, pred = classifier(images.unsqueeze(1).float(), labels)

        loss.backward()
        optimizer.step()

        # `loss` still requires grad, so calling `.numpy()` on it raises;
        # `.item()` extracts the Python scalar. Metrics are detached defensively.
        running_loss.append(loss.item())
        acc_root.append(metrics['acc_grapheme'].detach().cpu().numpy())
        acc_consonant.append(metrics['acc_consonant'].detach().cpu().numpy())
        acc_vowel.append(metrics['acc_vowel'].detach().cpu().numpy())

    print("Epoch Metrics")
    print(f"Epoch Loss: {np.mean(running_loss)}")
    print(f"grapheme root accuracy: {np.mean(acc_root)}")
    print(f"consonant diacritic accuracy: {np.mean(acc_consonant)}")
    print(f"vowel diacritic accuracy: {np.mean(acc_vowel)}")

0 grapheme
Epoch Metrics
grapheme root accuracy: 0.32192739844322205
consonant diacritic accuracy: 0.9101462960243225
vowel diacritic accuracy: 0.8853043913841248
1 grapheme_root
Epoch Metrics
grapheme root accuracy: 0.4919431209564209
consonant diacritic accuracy: 0.9326587915420532
vowel diacritic accuracy: 0.9258601665496826
2 vowel_diacritic
Epoch Metrics
grapheme root accuracy: 0.6107524633407593
consonant diacritic accuracy: 0.9444764256477356
vowel diacritic accuracy: 0.9446634650230408
3 consonant_diacritic
Epoch Metrics
grapheme root accuracy: 0.6825068593025208
consonant diacritic accuracy: 0.9524654746055603
vowel diacritic accuracy: 0.9572084546089172
4 grapheme
Epoch Metrics
grapheme root accuracy: 0.7354461550712585
consonant diacritic accuracy: 0.9628844857215881
vowel diacritic accuracy: 0.9637100100517273
5 grapheme_root
Epoch Metrics
grapheme root accuracy: 0.7700335383415222
consonant diacritic accuracy: 0.96763676404953
vowel diacritic accuracy: 0.9684997797012329
6