In [None]:
# Enable the autoreload extension in mode 2, so imported modules (such as the
# local `dataset` module) are reloaded before each cell runs.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import os

import albumentations as A
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import BengaliTrainDataset

# Model Configuration 

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs on a machine without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Path to the training-folds CSV (presumably read by the dataset module -- verify).
TRAINING_FOLDS_CSV = "../input/train_folds.csv"
# Number of passes over the training data.
EPOCHS = 10

# Mini-batch sizes for the training and validation loaders.
TRAIN_BATCH_SIZE = 32
TEST_BATCH_SIZE = 8

In [None]:
def loss_fn(outputs, targets):
    """Return the unweighted mean of the cross-entropy losses of three heads.

    Args:
        outputs: Sequence of exactly three score tensors, one per prediction
            head, in the same order as ``targets``.
        targets: Sequence of exactly three class-index tensors, one per head.

    Returns:
        A scalar tensor: ``(l1 + l2 + l3) / 3``.

    Raises:
        ValueError: If ``outputs`` or ``targets`` does not unpack into exactly
            three items.
    """
    o1, o2, o3 = outputs
    t1, t2, t3 = targets
    # One criterion instance is enough; CrossEntropyLoss holds no per-call state.
    criterion = nn.CrossEntropyLoss()
    l1 = criterion(o1, t1)
    l2 = criterion(o2, t2)
    l3 = criterion(o3, t3)
    return (l1 + l2 + l3) / 3

def train(dataset, data_loader, model, optimizer, device):
    """Run one optimisation pass over every batch in ``data_loader``.

    Args:
        dataset: Dataset behind ``data_loader``. Kept for call compatibility;
            the progress-bar length is taken from the loader itself.
        data_loader: Iterable of batches. Each batch is a mapping with the keys
            ``"image"``, ``"grapheme_root"``, ``"vowel_diacritic"`` and
            ``"consonant_diacritic"``.
        model: Module whose forward pass returns three outputs, in the order
            (grapheme root, vowel diacritic, consonant diacritic).
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device that both the images and the labels are moved to.
    """
    model.train()

    # len(data_loader) is the number of batches the loader will yield, so the
    # progress bar stays correct when the last batch is smaller than the rest.
    for batch in tqdm(data_loader, total=len(data_loader)):
        # Inputs and labels all go to the same `device` argument, never to a
        # module-level global, so they always sit together.
        image = batch["image"].to(device, dtype=torch.float)
        grapheme_root = batch["grapheme_root"].to(device, dtype=torch.long)
        vowel_diacritic = batch["vowel_diacritic"].to(device, dtype=torch.long)
        consonant_diacritic = batch["consonant_diacritic"].to(device, dtype=torch.long)

        outputs = model(image)
        targets = (grapheme_root, vowel_diacritic, consonant_diacritic)
        loss = loss_fn(outputs, targets)

        # Clear gradients from the previous step before back-propagating.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def evaluate(dataset, data_loader, model, device=None):
    """Compute the mean per-batch loss of ``model`` over ``data_loader``.

    Args:
        dataset: Dataset behind ``data_loader``. Kept for call compatibility;
            the progress-bar length is taken from the loader itself.
        data_loader: Iterable of batches. Each batch is a mapping with the keys
            ``"image"``, ``"grapheme_root"``, ``"vowel_diacritic"`` and
            ``"consonant_diacritic"``.
        model: Module whose forward pass returns three outputs, in the order
            (grapheme root, vowel diacritic, consonant diacritic).
        device: Device to move each batch to. Defaults to the module-level
            ``DEVICE`` when omitted.

    Returns:
        The mean of the per-batch losses as a Python float.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no batches.
    """
    model.eval()
    if device is None:
        device = DEVICE

    total_loss = 0.0
    num_batches = 0
    # No gradients are needed for validation; disabling autograd avoids
    # building a graph for every batch.
    with torch.no_grad():
        for batch in tqdm(data_loader, total=len(data_loader)):
            image = batch["image"].to(device, dtype=torch.float)
            grapheme_root = batch["grapheme_root"].to(device, dtype=torch.long)
            vowel_diacritic = batch["vowel_diacritic"].to(device, dtype=torch.long)
            consonant_diacritic = batch["consonant_diacritic"].to(device, dtype=torch.long)

            outputs = model(image)
            targets = (grapheme_root, vowel_diacritic, consonant_diacritic)
            # .item() turns the scalar tensor into a float, so the running sum
            # does not keep tensors alive on the device.
            total_loss += loss_fn(outputs, targets).item()
            num_batches += 1

    return total_loss / num_batches

def main(training_folds, validation_folds, model, device, train_batch_size, test_batch_size, epochs):
    """Train ``model`` on the training folds and validate it after each epoch.

    After every epoch the validation loss is passed to a ``ReduceLROnPlateau``
    scheduler and the model weights are written to
    ``Models/resnet_34_fold<first validation fold>.bin``, overwriting the
    previous epoch's file.

    Args:
        training_folds: Fold selection for training. Passed unchanged to
            ``BengaliTrainDataset(folds=...)``.
        validation_folds: Fold selection for validation. Passed unchanged to
            ``BengaliTrainDataset(folds=...)``. Its first element names the
            checkpoint file; a string such as ``"(4,)"`` is parsed first.
        model: Network to train. Moved to ``device`` in place.
        device: Device used for the model and for the training batches.
        train_batch_size: Batch size of the training loader.
        test_batch_size: Batch size of the validation loader.
        epochs: Number of train/validate cycles to run.
    """
    import ast  # local import: only needed to read a fold spec given as a string

    model.to(device)

    def build_transforms(augment):
        """Resize and normalise; add a random shift/scale/rotate when ``augment`` is true."""
        steps = [A.Resize(137, 236, always_apply=True)]
        if augment:
            steps.append(
                A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=5, p=0.9)
            )
        steps.append(
            A.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), always_apply=True)
        )
        return A.Compose(steps)

    train_dataset = BengaliTrainDataset(folds=training_folds, transforms=build_transforms(augment=True))
    valid_dataset = BengaliTrainDataset(folds=validation_folds, transforms=build_transforms(augment=False))

    trainloader = DataLoader(
        dataset=train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=4,
    )

    # The validation loader must wrap the validation dataset; wrapping the
    # training dataset would report a score on data the model has seen.
    validloader = DataLoader(
        dataset=valid_dataset,
        batch_size=test_batch_size,
        shuffle=False,
        num_workers=4,
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=5, factor=0.3
    )

    # Name the checkpoint after the first validation fold of *this* call.
    # A fold spec written as a string (e.g. "(4,)") is parsed first, because
    # indexing the raw string would give "(" instead of the fold number.
    fold_spec = validation_folds
    if isinstance(fold_spec, str):
        try:
            fold_spec = ast.literal_eval(fold_spec)
        except (ValueError, SyntaxError):
            pass  # not a Python literal; use the raw text as the label
    fold_label = fold_spec[0] if isinstance(fold_spec, (tuple, list)) else fold_spec

    os.makedirs("Models", exist_ok=True)  # torch.save does not create directories
    checkpoint_path = os.path.join("Models", f"resnet_34_fold{fold_label}.bin")

    for epoch in range(epochs):
        train(train_dataset, trainloader, model, optimizer, device)
        # NOTE: evaluate() is called without a device and therefore uses the
        # module-level DEVICE; `device` should match it.
        val_score = evaluate(valid_dataset, validloader, model)
        scheduler.step(val_score)
        torch.save(model.state_dict(), checkpoint_path)

# Training

In [None]:
# Fold selection: presumably train on folds 0, 1, 2, 3 and validate on fold 4.
# NOTE(review): both values are strings, not tuples -- confirm the consumer
# parses them. This cell only assigns the names; it does not start a run.
TRAINING_FOLDS = "(0,1,2,3)"
VALIDATION_FOLDS = "(4,)"

In [None]:
# Fold selection: presumably train on folds 0, 1, 2, 4 and validate on fold 3.
# NOTE(review): both values are strings, not tuples -- confirm the consumer
# parses them. This cell only assigns the names; it does not start a run.
TRAINING_FOLDS = "(0,1,2,4)"
VALIDATION_FOLDS = "(3,)"

In [None]:
# Fold selection: presumably train on folds 0, 1, 4, 3 and validate on fold 2.
# NOTE(review): both values are strings, not tuples -- confirm the consumer
# parses them. This cell only assigns the names; it does not start a run.
TRAINING_FOLDS = "(0,1,4,3)"
VALIDATION_FOLDS = "(2,)"

In [None]:
# Fold selection: presumably train on folds 0, 4, 2, 3 and validate on fold 1.
# NOTE(review): both values are strings, not tuples -- confirm the consumer
# parses them. This cell only assigns the names; it does not start a run.
TRAINING_FOLDS = "(0,4,2,3)"
VALIDATION_FOLDS = "(1,)"

In [None]:
# Fold selection: presumably train on folds 4, 1, 2, 3 and validate on fold 0.
# NOTE(review): both values are strings, not tuples -- confirm the consumer
# parses them. This cell only assigns the names; it does not start a run.
TRAINING_FOLDS = "(4,1,2,3)"
VALIDATION_FOLDS = "(0,)"