# Notebook to train binary classifier (before, after growth plate)

In [1]:
import sys
import os
sys.path.append('..')

from multiprocessing import freeze_support
from torchvision.transforms import v2
from data_utils.dataset import BoneSlicesDatasetPrev
from training_utils import train_one_epoch
from validation_metrics import calculate_metric
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import precision_score
from pathlib import Path
from datetime import datetime
from time import time
from torchvision import transforms
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import resnet18, ResNet18_Weights
from torch import nn
import pandas as pd

## Augmentations (for future)

In [2]:
#v2.Resize(size=(224, 224), antialias=True),
transforms = v2.Compose([
    v2.RandomHorizontalFlip(0.5),
    v2.RandomVerticalFlip(0.5),
    v2.RandomRotation(degrees=180),
    v2.ToDtype(torch.float32, scale=False)
])

## Training parameters

In [3]:
torch.manual_seed(0)

TIMESTAMP = datetime.now().strftime('%Y%m%d_%H%M%S')
DEVICE = 'cuda' if torch.cuda.is_available() else torch.device("mps") if torch.backends.mps.is_available() else 'cpu'
DEVIE = 'cpu'
BASE_DIR = 'training/resnet_18_6_all/'
WRITER_DIR = BASE_DIR + "logs"
MODEL_PATH = BASE_DIR + "saved_models/"
BATCH_SIZE = 64
NUM_WORKERS = 4
LEARNING_RATE = 0.0001
EPOCHS = 20

# os.chdir('..')
print("--------------------")
print(f"TIMESTAMP: {TIMESTAMP}")

print(f"CURRENT WORKING DIR: {os.getcwd()}") 
print(f"DEVICE: {DEVICE}")
print(f"BASE_DIR: {BASE_DIR}")
print(f"WRITER_DIR: {WRITER_DIR}")
print(f"MODEL_PATH: {MODEL_PATH}")
print(f"BATCH_SIZE: {BATCH_SIZE}")
print(f"NUM_WORKERS: {NUM_WORKERS}")
print(f"LEARNING_RATE: {LEARNING_RATE}")
print(f"EPOCHS: {EPOCHS}")
print("--------------------")

if not os.path.exists(MODEL_PATH):
    os.makedirs(MODEL_PATH)

--------------------
TIMESTAMP: 20240520_093559
CURRENT WORKING DIR: /home/ec2-user/SageMaker/code-warriors-imgcollab/training
DEVICE: cuda
BASE_DIR: training/resnet_18_6_all/
WRITER_DIR: training/resnet_18_6_all/logs
MODEL_PATH: training/resnet_18_6_all/saved_models/
BATCH_SIZE: 64
NUM_WORKERS: 4
LEARNING_RATE: 0.0001
EPOCHS: 20
--------------------


## Create dataset and dataloaders

In [4]:
os.chdir('..')
print(os.getcwd())

/home/ec2-user/SageMaker/code-warriors-imgcollab


In [5]:
METADATA_PATH = os.path.join('train.csv')
metadata_df = pd.read_csv(METADATA_PATH)

In [6]:
def get_balanced_dataset(meta_df, column, fraction=0.2):
    validation_examples = []
    for val in meta_df[column].unique():
        subdf = meta_df[meta_df[column] == val]
        validation_examples += list(subdf.sample(frac=fraction)['Image Name'])
    training_examples = set(meta_df['Image Name'].unique()).difference(set(validation_examples))
    return validation_examples, list(training_examples)

In [7]:
# validation_examples, training_examples = get_balanced_dataset(metadata_df, 'STUDY ID', fraction=0.2)

In [8]:
# VALIDATION_EXAMPLES_FILE = '10folds/5_fold/validation_examples.csv'
# TRAINING_EXAMPLES_FILE = '10folds/5_fold/training_examples.csv'
# # val_df = pd.DataFrame(validation_examples, columns=['Image Name'])
# # train_df = pd.DataFrame(training_examples, columns=['Image Name'])
# # val_df.to_csv(VALIDATION_EXAMPLES_FILE)
# # train_df.to_csv(TRAINING_EXAMPLES_FILE)
# validation_examples = list(pd.read_csv(VALIDATION_EXAMPLES_FILE)['Image Name'])
# training_examples = list(pd.read_csv(TRAINING_EXAMPLES_FILE)['Image Name'])

In [9]:
train_ds = BoneSlicesDatasetPrev(json_config_filepath = 'data_utils/config_binary_z.json', transform=transforms)
valid_ds = BoneSlicesDatasetPrev(json_config_filepath = 'data_utils/config_binary_z.json')
# train_ds.subset_by_image_name(training_examples)
# valid_ds.subset_by_image_name(validation_examples)
print(len(valid_ds.metadata['Image Name']))

print(len(train_ds.metadata['Image Name']))

44940
44940


In [10]:
valid_ds.metadata.head()

Unnamed: 0.1,index,Unnamed: 0,STUDY ID,Bone ID,Image Name,Growth Plate Index,Slice Index,img_file_name,axis,img_path
0,0,0,0,4,32c1aa1bcd,168,0,32c1aa1bcd_0.npy,z,sliced_data_new/z_axis/32c1aa1bcd_0.npy
1,1,1,0,4,32c1aa1bcd,168,1,32c1aa1bcd_1.npy,z,sliced_data_new/z_axis/32c1aa1bcd_1.npy
2,2,2,0,4,32c1aa1bcd,168,2,32c1aa1bcd_2.npy,z,sliced_data_new/z_axis/32c1aa1bcd_2.npy
3,3,3,0,4,32c1aa1bcd,168,3,32c1aa1bcd_3.npy,z,sliced_data_new/z_axis/32c1aa1bcd_3.npy
4,4,4,0,4,32c1aa1bcd,168,4,32c1aa1bcd_4.npy,z,sliced_data_new/z_axis/32c1aa1bcd_4.npy


In [11]:
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(valid_ds, batch_size=BATCH_SIZE)

## Training

In [None]:
for it in range(1):#range(5):
        print(f"\n##############################################")
        print(f"Iteration: {it + 1}")
        print(f"##############################################\n")

        # Model, optimizer, tensorboard setup
        resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
        # Changing last classificator layer from 1000 classes to 2
        resnet.fc = nn.Linear(512, 2)
        # Changing 3 channels into 1 (monochromatic image)
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        resnet.to(DEVICE)

        loss_fn = torch.nn.CrossEntropyLoss()
        # optimizer = torch.optim.Adam(resnet.parameters(), lr=LEARNING_RATE)
        optimizer = torch.optim.AdamW(resnet.parameters())

        writer = SummaryWriter(f'{WRITER_DIR}/Iteration_{it + 1}')
        epoch_number = 0
        best_vloss = sys.float_info.max

        start_time = time()

        if not os.path.exists(MODEL_PATH + f'/Iteration_{it + 1}'):
            os.makedirs(MODEL_PATH + f'/Iteration_{it + 1}')

        for epoch in range(EPOCHS):
            print('EPOCH {}:'.format(epoch_number + 1))
            epoch_start_time = time()

            resnet.train(True)
            avg_loss = train_one_epoch(epoch_number, writer, train_dl, optimizer, loss_fn, resnet, DEVICE, start_time)

            resnet.eval()
            running_vloss = 0.0

            with torch.no_grad():
                for i, vdata in enumerate(val_dl):
                    vinputs, vlabels = vdata
                    voutputs = resnet(vinputs.to(DEVICE))
                    vloss = loss_fn(voutputs, vlabels.to(DEVICE))
                    running_vloss += vloss

            avg_vloss = running_vloss / (i + 1)

            print("#############################################################")
            print("Epoch results:")
            print(f'Loss train {avg_loss} valid loss: {avg_vloss}')
            validation_precision_score = calculate_metric(resnet, val_dl, device=DEVICE,
                                                          metric=lambda x, y: precision_score(x, y, average='macro'))
            print(f'Validation macro average precision: {validation_precision_score}')
            print(f'Epoch execution time {time() - epoch_start_time}')
            print("#############################################################\n\n")

            writer.add_scalars('Training vs. Validation Loss',
                               {'Training': avg_loss, 'Validation': avg_vloss}, epoch_number + 1)

            writer.add_scalars('Macro_averaged_precision_score',
                               {'Validation': validation_precision_score}, epoch_number + 1)

            writer.flush()

            best_vloss = avg_vloss
            model_path = f'model_{TIMESTAMP}_{epoch_number}'

            torch.save(resnet.state_dict(), MODEL_PATH + f"Iteration_{it + 1}/" + model_path)

            epoch_number += 1

        writer.close()



##############################################
Iteration: 1
##############################################

EPOCH 1:
Loss after batch 10: 0.2524; time elapsed: 9.97
Loss after batch 20: 0.1245; time elapsed: 18.44
Loss after batch 30: 0.0986; time elapsed: 27.39
Loss after batch 40: 0.0872; time elapsed: 35.95
Loss after batch 50: 0.0658; time elapsed: 43.97
Loss after batch 60: 0.0368; time elapsed: 52.15
Loss after batch 70: 0.0355; time elapsed: 60.54
Loss after batch 80: 0.0301; time elapsed: 68.85
Loss after batch 90: 0.0353; time elapsed: 76.90
Loss after batch 100: 0.0365; time elapsed: 84.97
Loss after batch 110: 0.0384; time elapsed: 93.14
Loss after batch 120: 0.0458; time elapsed: 101.26
Loss after batch 130: 0.0304; time elapsed: 109.26
Loss after batch 140: 0.0257; time elapsed: 117.24
Loss after batch 150: 0.0361; time elapsed: 125.56
Loss after batch 160: 0.0348; time elapsed: 133.84
Loss after batch 170: 0.0581; time elapsed: 141.97
Loss after batch 180: 0.0402; time e