This script demonstrates the training (i.e., fine-tuning) of the Fatal Crash Risk Estimation model.

In [None]:
import utilities as UT
from mydataset import MyDataset_MultiScale_SameAug as MyDataset_MultiScale
from myops import Train_FT_MTSL_MTHD_V2, Test_FT_MTSL_MTHD_V2
import myaugmentation as MyAug
import mymodels

import os 
import tqdm
import numpy as np

import torch 
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from torchvision import models
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from torch.cuda.amp import GradScaler#, autocast
import torch.nn.parallel as parallel

In [None]:
# Report the CUDA devices visible to this process (informational only).
if not torch.cuda.is_available():
    print("CUDA is not available on this system.") 
else:
    num_gpus = torch.cuda.device_count()
    print(f"Number of available GPUs: {num_gpus}")

    # Look up the human-readable name of every visible device.
    gpu_names = [torch.cuda.get_device_name(idx) for idx in range(num_gpus)]
    print("Available GPUs:")
    for idx, name in enumerate(gpu_names):
        print(f"GPU {idx}: {name}")

In [None]:
# Hyperparameters -------------------------------------------------------------
epochs = 40          # number of fine-tuning epochs
batch_size = 64
num_workers = 4      # DataLoader worker processes
num_classes = 2

pretrained = True 
learning_rate = 1e-4
class_weights = [1.370, 3.681]  # inverse class frequency for the cross-view split
optimizer_name = "AdamW"

# Run identifiers: model_name encodes architecture, optimizer, LR and batch size.
model_name = "_".join(
    ("Res50-FineTune-MTSL-MTHD-V2", optimizer_name, str(learning_rate), str(batch_size))
)
training_trail = "CrossView-Train_PreTrain-MultiScale-24"

# Pre-trained checkpoint, given relative to project_directory.
ckpt_name = "ResNet50_Pre-Train_MultiScale/ckpt/ResNet50_Pre-Train_MultiScale_24_0.017479911147395014.pth"

# Paths -----------------------------------------------------------------------
project_directory = "../experimental_results"
metadataTrain = "../metadata/train.csv"          # list of training samples
metadataVal = "../metadata/val_crossview.csv"    # list of validation samples
metadataTest = "../metadata/test_crossview.csv"  # list of test samples

# Results go under <project>/<text of model_name before the first "_">/<training_trail>.
output_path = os.path.join(
    project_directory, model_name.partition("_")[0], str(training_trail)
)
ckpt_path = os.path.join(project_directory, ckpt_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Datasets --------------------------------------------------------------------
# Each split gets its own base augmentation instance; only the training split
# is constructed with if_test=False.
train_dataset = MyDataset_MultiScale(
    if_test=False, basic_transform=MyAug.base_aug(), metadata=metadataTrain
)
val_dataset = MyDataset_MultiScale(
    if_test=True, basic_transform=MyAug.base_aug(), metadata=metadataVal
)
test_dataset = MyDataset_MultiScale(
    if_test=True, basic_transform=MyAug.base_aug(), metadata=metadataTest
)

# Loaders ---------------------------------------------------------------------
# Shared loader settings; only the training loader shuffles.
_loader_kwargs = {"batch_size": batch_size, "num_workers": num_workers}
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_dataloader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [None]:
# Load the pre-trained multi-scale checkpoint used to initialise the model.
# map_location=device keeps this working on CPU-only machines, matching the
# CPU fallback chosen when `device` was defined.
state_dict = torch.load(ckpt_path, map_location=device)

# Strip the leading "module." prefix from every key. The checkpoint was
# presumably saved from an nn.DataParallel wrapper, as this script's own
# checkpoints are. Only the leading prefix is removed: str.replace would also
# mangle any inner "module." substring of a key.
renamed_state_dict = {
    (key[len("module."):] if key.startswith("module.") else key): value
    for key, value in state_dict.items()
}

# Build the model from the renamed weights and wrap it for multi-GPU execution.
model = mymodels.ResNet50_MultiScale_MultiHead_V2(pretrain_weights=renamed_state_dict)
model = parallel.DataParallel(model)
model.to(device)

# Class-weighted cross-entropy (weights come from the hyperparameter cell).
# as_tensor also accepts an existing tensor, so re-running this notebook cell
# does not hit torch.tensor's copy-construct warning.
class_weights = torch.as_tensor(class_weights, dtype=torch.float32, device=device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Instantiate the optimizer named in the hyperparameters (e.g. "AdamW"), so the
# optimizer actually used always matches the one encoded in model_name.
optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=learning_rate)
# First restart after T_0=10 scheduler steps; each later cycle is T_mult=2 times
# longer. How often step() is called is decided inside Train_FT_MTSL_MTHD_V2;
# confirm whether a "step" is an epoch or a batch.
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)
scaler = GradScaler()  # loss scaler for mixed-precision training

writer = SummaryWriter(log_dir=os.path.join(output_path, "log", model_name))

In [None]:
# Train the model, evaluating on the validation and test splits after every
# epoch and saving a checkpoint per epoch.

# Create the checkpoint directory explicitly instead of relying on
# SummaryWriter having created it as a side effect of its log directory.
os.makedirs(output_path, exist_ok=True)

for epoch in tqdm.tqdm(range(epochs)):
    # One training pass. Returns the model plus the epoch loss and training
    # accuracy as reported by Train_FT_MTSL_MTHD_V2.
    model, epoch_loss, train_acc = Train_FT_MTSL_MTHD_V2(epoch, train_dataloader, model, criterion, optimizer, 
                                            scheduler, scaler, device, writer)        

    _, val_acc = Test_FT_MTSL_MTHD_V2(epoch, val_dataloader, model, criterion, device, writer, mode="Val")

    _, test_acc = Test_FT_MTSL_MTHD_V2(epoch, test_dataloader, model, criterion, device, writer, mode="Test")

    # The file name records the epoch and the three accuracies. The saved keys
    # carry the DataParallel "module." prefix, since model is wrapped.
    torch.save(model.state_dict(), 
               os.path.join(output_path, model_name+f'_{epoch}_{train_acc}_{val_acc}_{test_acc}.pth'))

    print("Epoch: %d\t Train [loss/acc]: [%.4f/%.4f]\t Val/Test Acc: %.4f/%.4f" 
          %(epoch, epoch_loss, train_acc, val_acc, test_acc))
    
    