# Models for disbalanced datasets
Current notebook codes model training to predict heavily disbalanced joint type datasets (>.5 of all data belongs to one class).


In [None]:
import os
import torch
import numpy as np
import pandas as pd
import torchvision
import matplotlib.pyplot as plt
import copy
import shutil

from tqdm import tqdm
from torchvision.transforms import v2
from torchvision import models
from sklearn.metrics import accuracy_score, f1_score

In [None]:
def copy_data(main_class: str, inp_path: str, out_path: str) -> None:
    '''
    For groups with disbalanced classes.
    Create folders and copy data from minor classes to make minor model DataLoaders using ImageFolder.

    Args:
        main_class (str)
        inp_path (str): path to input data
        out_path (str): path to output data
    '''
    for cl in os.listdir(inp_path):
        if cl == main_class:
            continue
        os.makedirs(os.path.join(out_path, cl), exist_ok=True)
        for _, file_name in enumerate(os.listdir(os.path.join(inp_path, cl))):
            shutil.copy(os.path.join(inp_path, cl, file_name), os.path.join(out_path, cl, file_name))

def get_dataloaders(joint_type_and_param: str, main_class: str):
    '''
    Make 4 dataloaders for both binary and minor model
    '''
    # [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] Normalization values for RESNet & EfficientNet
    mean_list, std_list = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    train_transform = v2.Compose([
        v2.Resize((224, 224)),
        v2.RandomRotation(15),
        v2.RandomHorizontalFlip(p = 0.3),
        v2.RandomVerticalFlip(p = 0.3),
        v2.ToTensor(),
        v2.Normalize(mean_list, std_list)
    ])

    val_transform = v2.Compose([
        v2.Resize((224, 224)),
        v2.RandomRotation(15),
        v2.ToTensor(),
        v2.Normalize(mean_list, std_list)
    ])

    '''data_root = os.path.join('dataset', 'json_split')
    joint_type_and_param = 'DIP_jsn' # data for DIP joints, predicting JSN score
    train, val = os.path.join(data_root, 'train', joint_type_and_param), os.path.join(data_root, 'test', joint_type_and_param) '''

    data_root = os.path.join('dataset', 'custom_split_inv_clahe')
    train, val = os.path.join(data_root, joint_type_and_param, 'train'), os.path.join(data_root, joint_type_and_param, 'test')

    n_classes = max(len(os.listdir(train)), len(os.listdir(val)))

    train_minor, val_minor = os.path.join(data_root, f'{joint_type_and_param}_minor_classes', 'train'), \
                            os.path.join(data_root, f'{joint_type_and_param}_minor_classes', 'test')
    copy_data(main_class, train, train_minor)
    copy_data(main_class, val, val_minor)

    train_dataset_bin = torchvision.datasets.ImageFolder(train, train_transform)
    val_dataset_bin = torchvision.datasets.ImageFolder(val, val_transform)

    train_dataset_minor = torchvision.datasets.ImageFolder(train_minor, train_transform)
    val_dataset_minor = torchvision.datasets.ImageFolder(val_minor, val_transform)

    batch_size = 64
    train_dataloader_bin = torch.utils.data.DataLoader(
        train_dataset_bin, batch_size=batch_size, shuffle=True, num_workers=8)
    val_dataloader_bin = torch.utils.data.DataLoader(
        val_dataset_bin, batch_size=batch_size, shuffle=False, num_workers=8)

    train_dataloader_minor = torch.utils.data.DataLoader(
        train_dataset_minor, batch_size=batch_size, shuffle=True, num_workers=8)
    val_dataloader_minor = torch.utils.data.DataLoader(
        val_dataset_minor, batch_size=batch_size, shuffle=False, num_workers=8)

    print(len(train_dataloader_bin), len(val_dataloader_bin), len(train_dataloader_minor), len(val_dataloader_minor))
    return train_dataloader_bin, val_dataloader_bin, train_dataloader_minor, val_dataloader_minor, n_classes

In [2]:
def show_input(input_tensor, title=''):
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array( [0.229, 0.224, 0.225])

    image = input_tensor.permute(1, 2, 0).numpy()
    image = std * image + mean
    plt.imshow(image.clip(0, 1))
    plt.title(title)
    plt.show()
    plt.pause(0.001)

def visualise_batch(train_dataloader):
    X_batch, y_batch = next(iter(train_dataloader))

    for x_item, y_item in zip(X_batch, y_batch):
        show_input(x_item, title={y_item})

In [3]:
def out_transform(labels, output_transform):
    tr_type, majour_class = output_transform.split('_')
    majour_class = int(majour_class)
    if tr_type == 'bin':
        for i in range(labels.shape[0]):
            labels[i] = 1 if labels[i] == majour_class else 0
    elif tr_type == 'minor':
        for i in range(labels.shape[0]):
            labels[i] = labels[i] - 1  if labels[i] > majour_class else labels[i]
    else:
        raise ValueError(f'Invalid transform type {tr_type}')
    return labels

def train_model(model, train_dataloader, val_dataloader, loss, optimizer, num_epochs, device, output_transform = None):
    for epoch in tqdm(range(num_epochs)):
        if epoch == 20:
            for param in model.parameters():
                param.requires_grad = True

        for phase in ['train', 'val']:
            if phase == 'train':
                dataloader = train_dataloader
                #scheduler.step()
                model.train()
            else:
                dataloader = val_dataloader
                model.eval()

            running_loss = 0.
            running_acc = 0.

            y_preds, y_trues = [], []
            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    preds = model(inputs)
                    if output_transform:
                       labels = out_transform(labels, output_transform)
                    loss_value = loss(preds, labels)
                    distance_weight = torch.abs(preds.argmax(1) - labels) + 1
                    ordinal_ce_loss = torch.mean(distance_weight * loss_value)
                    preds_class = preds.argmax(dim=1)

                    if phase == 'train':
                        ordinal_ce_loss.backward()
                        optimizer.step()
                    else:
                        y_preds.extend(preds_class.detach().cpu().tolist())
                        y_trues.extend(labels.detach().cpu().tolist())

                running_loss += loss_value.item()
                running_acc += (preds_class == labels.data).float().mean()

            epoch_loss = running_loss / len(dataloader)
            epoch_acc = running_acc / len(dataloader)
            if phase == 'train':
                train_acc = epoch_acc
            else:
                val_acc = epoch_acc
                val_loss = epoch_loss

        #if epoch % 10 == 0:
         #   print('Epoch {},\n    train accuracy: {:.4f},\n    val accuracy: {:.4f}, f1: {:.4f}'.format(epoch+1, train_acc, val_acc, \
          #                                                                                               f1_score(y_trues, y_preds, average='micro')), flush = True)

        if epoch == 0:
            best_loss = val_loss
            best_model = model ########
        elif val_loss < best_loss:
            best_loss = val_loss
            best_model = copy.deepcopy(model)

    #print('Epoch {},\n    train accuracy: {:.4f},\n    val accuracy: {:.4f}, f1: {:.4f}'.format(epoch+1, train_acc, val_acc, \
     #                                                                                           f1_score(y_trues, y_preds, average='micro')), flush = True)
    return best_model

In [4]:
def test_time_augmentations(preds, inputs, model):
    inp_rot_p15 = v2.functional.rotate(inputs, 10)
    inp_rot_m15 = v2.functional.rotate(inputs, -10)

    with torch.no_grad():
        preds += model(v2.functional.horizontal_flip(inputs))
        preds += model(v2.functional.horizontal_flip(inp_rot_p15))
        preds += model(v2.functional.horizontal_flip(inp_rot_m15))
        preds += model(inp_rot_p15)
        preds += model(inp_rot_m15)
    return preds

def get_preds(preds, inputs, majour_class, minor_model):
    majour_class = int(majour_class)
    with torch.no_grad():
        y_pred = torch.empty((preds.shape[0],))
        for i, item in enumerate(inputs):
            if torch.argmax(preds[i]).item() == 1:
                y_pred[i] = majour_class
            else:
                pr = minor_model(item.unsqueeze(0))
                pred = torch.argmax(pr).item()
                y_pred[i] = pred if pred < majour_class else pred + 1
    return y_pred

def get_preds_tta(preds, inputs, majour_class, minor_model):
    majour_class = int(majour_class)
    with torch.no_grad():
        y_pred = torch.empty((preds.shape[0],))
        for i, item in enumerate(inputs):
            if torch.argmax(preds[i]).item() == 1:
                y_pred[i] = majour_class
            else:
                pr = minor_model(item.unsqueeze(0))
                pr = test_time_augmentations(pr, item.unsqueeze(0), minor_model)
                pred = torch.argmax(pr).item()
                y_pred[i] = pred if pred < majour_class else pred + 1
    return y_pred

def predict_val(bin_model, minor_model, val_dataloader_bin, device, majour_class):
    bin_model.eval()
    minor_model.eval()
    acc = 0
    acc_tta = 0

    y_preds, y_preds_tta, y_trues = [], [], []
    for inputs, labels in val_dataloader_bin:
        inputs = inputs.to(device)
        y_trues.extend(labels.detach().cpu().tolist())

        with torch.no_grad():
            preds = bin_model(inputs)
            
        y_pred = get_preds(preds, inputs, majour_class, minor_model)
        y_preds.extend(y_pred.detach().cpu().tolist())
        acc += (y_pred==labels.data).float().mean()

        y_pred_tta = get_preds_tta(test_time_augmentations(preds, inputs, bin_model), inputs, majour_class, minor_model)
        y_preds_tta.extend(y_pred_tta.detach().cpu().tolist())
        acc_tta += (y_pred_tta == labels.data).float().mean()

    f1 = f1_score(y_trues, y_preds, average='micro')
    f1_tta = f1_score(y_trues, y_preds_tta, average='micro')
    print('Test accuracy = {:.4f}, f1 = {:.4f}'.format(acc/len(val_dataloader_bin), f1))
    print('Test accuracy (with tta) = {:.4f}, f1 = {:.4f}'.format(acc_tta/len(val_dataloader_bin), f1_tta))
    return acc/len(val_dataloader_bin), acc_tta/len(val_dataloader_bin), f1, f1_tta

In [5]:
def get_bin_minor_models(n_classes):
    #model = models.resnet34(weights = models.ResNet34_Weights.DEFAULT)
    bin_model = models.efficientnet_b4(weights = models.EfficientNet_B4_Weights.DEFAULT)
    minor_model = models.efficientnet_b4(weights = models.EfficientNet_B4_Weights.DEFAULT)
    
    for param in bin_model.parameters():
        param.requires_grad = False
    for param in minor_model.parameters():
        param.requires_grad = False

    #for ResNet
    '''model.fc = torch.nn.Sequential(torch.nn.Dropout(p = 0.2, inplace=True),
                                    torch.nn.Linear(model.fc.in_features, model.fc.in_features//2),
                                    torch.nn.Dropout(p = 0.2, inplace=True),
                                    torch.nn.LeakyReLU(),
                                    torch.nn.Linear(model.fc.in_features//2, n_classes)) '''
    # For EfficientNet
    bin_model.classifier = torch.nn.Sequential(torch.nn.Dropout(p = 0.2, inplace=True),
                                        torch.nn.Linear(bin_model.classifier[1].in_features, bin_model.classifier[1].in_features//2),
                                        torch.nn.Dropout(p = 0.2, inplace=True),
                                        torch.nn.LeakyReLU(),
                                        torch.nn.Linear(bin_model.classifier[1].in_features//2, 2))
    minor_model.classifier = torch.nn.Sequential(torch.nn.Dropout(p = 0.2, inplace=True),
                                        torch.nn.Linear(minor_model.classifier[1].in_features, minor_model.classifier[1].in_features//2),
                                        torch.nn.Dropout(p = 0.2, inplace=True),
                                        torch.nn.LeakyReLU(),
                                        torch.nn.Linear(minor_model.classifier[1].in_features//2, n_classes-1))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    #print(f'Using {device}')
    minor_model = minor_model.to(device)
    bin_model = bin_model.to(device)

    loss_bin = torch.nn.CrossEntropyLoss()
    loss_minor = torch.nn.CrossEntropyLoss()
    bin_optimizer = torch.optim.Adam(bin_model.parameters(), lr=3.0e-4)
    minor_optimizer = torch.optim.Adam(minor_model.parameters(), lr=3.0e-4)

    # Decay LR by a factor of 0.1 every 7 epochs
    #currently the scheduler is disabled in 'train_model' function
    #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    return bin_model, loss_bin, bin_optimizer, minor_model, loss_minor, minor_optimizer, device

In [None]:
def count_disbalance(df, param):
    '''
    Decide if classes in group are disbalanced or not

    Args:
        df (pd.DataFrame): erosion or jsn dataframe
        param (str): 'jsn' or 'erosion'
    '''
    d = {}
    for column in df.columns:
        vals = df[column].to_list()
        s = sum(vals)
        for i, v in enumerate(vals):
            if v > s * .5:
                d[f'{column}_{param}'] = str(i)
                break
        else:
            d[f'{column}_{param}'] = 'balanced'
    return d

erosion_df = pd.read_csv(os.path.join('dataset', 'non-sorted', 'erosion_data_counts.csv'), header = 0, index_col=0)
jsn_df = pd.read_csv(os.path.join('dataset', 'non-sorted', 'jsn_data_counts.csv'), header = 0, index_col=0)
er_d, jsn_d = count_disbalance(erosion_df, 'erosion'), count_disbalance(jsn_df, 'jsn')
print(er_d)
print(jsn_d)

{'DIP_erosion': '0', 'CMC_erosion': '0', 'wrist_erosion': 'balanced', 'RC_erosion': '0', 'ulna_erosion': '0', 'PIP_erosion': '0', 'MCP_erosion': '0'}
{'DIP_jsn': 'balanced', 'CMC_jsn': 'balanced', 'wrist_jsn': 'balanced', 'RC_jsn': 'balanced', 'ulna_jsn': '0', 'PIP_jsn': 'balanced', 'MCP_jsn': 'balanced'}


In [7]:
full_dict = dict(list(er_d.items()) + list(jsn_d.items()))

for key in full_dict:
    if key == 'ulna_jsn': #consisted only of class '0'
        continue
    if full_dict[key] != 'balanced':
        majour_class = full_dict[key]
        print(f'\n {key}, majour class - {majour_class}')

        train_dataloader_bin, val_dataloader_bin, train_dataloader_minor, val_dataloader_minor, n_classes = get_dataloaders(key, full_dict[key])
        bin_model, loss_bin, bin_optimizer, minor_model, loss_minor, minor_optimizer, device = get_bin_minor_models(n_classes)

        bin_model = train_model(bin_model, train_dataloader_bin, val_dataloader_bin, loss_bin, bin_optimizer, num_epochs = 80,
                                 device = device, output_transform = f'bin_{majour_class}')
        print('Finished training bin model')
        minor_model = train_model(minor_model, train_dataloader_minor, val_dataloader_minor, loss_minor, minor_optimizer, num_epochs = 80,
                                 device = device, output_transform = f'minor_{majour_class}')
        print('Finished training minor model')

        val_acc, val_acc_tta, f1, f1_tta = predict_val(bin_model, minor_model, val_dataloader_bin, device, majour_class)
        
        #check for model name
        torch.save(bin_model, os.path.join('models', \
                        'bin_effNetb4_{}_{:.3f}_tta_{:.3f}.json'.format(key, val_acc, val_acc_tta)))
        torch.save(minor_model, os.path.join('models', \
                        'minor_effNetb4_{}_{:.3f}_tta_{:.3f}.json'.format(key, val_acc, val_acc_tta)))


 DIP_erosion, majour class - 0
40 8 1 1


100%|██████████| 80/80 [39:39<00:00, 29.75s/it]


Finished training bin model


100%|██████████| 80/80 [09:14<00:00,  6.93s/it]


Finished training minor model
Test accuracy = 0.9799, f1 = 0.9821
Test accuracy (with tta) = 0.9777, f1 = 0.9802

 CMC_erosion, majour class - 0
32 7 1 1


100%|██████████| 80/80 [40:34<00:00, 30.43s/it]


Finished training bin model


100%|██████████| 80/80 [09:12<00:00,  6.90s/it]


Finished training minor model
Test accuracy = 0.9603, f1 = 0.9876
Test accuracy (with tta) = 0.9683, f1 = 0.9900

 RC_erosion, majour class - 0
8 2 2 2


100%|██████████| 80/80 [12:20<00:00,  9.26s/it]


Finished training bin model


100%|██████████| 80/80 [10:04<00:00,  7.55s/it]


Finished training minor model
Test accuracy = 0.7081, f1 = 0.7745
Test accuracy (with tta) = 0.7368, f1 = 0.8039

 ulna_erosion, majour class - 0
8 2 2 2


100%|██████████| 80/80 [12:36<00:00,  9.46s/it]


Finished training bin model


100%|██████████| 80/80 [09:58<00:00,  7.48s/it]


Finished training minor model
Test accuracy = 0.6891, f1 = 0.7500
Test accuracy (with tta) = 0.6969, f1 = 0.7596

 PIP_erosion, majour class - 0
32 7 2 2


100%|██████████| 80/80 [39:49<00:00, 29.87s/it]


Finished training bin model


100%|██████████| 80/80 [09:42<00:00,  7.28s/it]


Finished training minor model
Test accuracy = 0.8535, f1 = 0.9429
Test accuracy (with tta) = 0.8580, f1 = 0.9479

 MCP_erosion, majour class - 0




40 8 7 7


100%|██████████| 80/80 [45:34<00:00, 34.19s/it]


Finished training bin model


100%|██████████| 80/80 [12:12<00:00,  9.15s/it]


Finished training minor model
Test accuracy = 0.7966, f1 = 0.8088
Test accuracy (with tta) = 0.8048, f1 = 0.8167
