In [None]:
import wandb
import torchvision.models as torchmodels
import torch
from torch import nn
import json
import os
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn.functional as F
import torch.optim as optim
import datetime
import pandas as pd
import tensorflow as tf
from tqdm import tqdm
import time
from sklearn.metrics import confusion_matrix
import argparse
import random

In [None]:
# Report which wandb release is installed (see the version warning below).
installed_wandb_version = wandb.__version__
print(installed_wandb_version)

👆 Check that the installed version is not 0.13.10, because that version falls over (fails) when running this notebook.

In [1]:
# Show the current working directory: the relative paths used later
# (./configs/..., ./data2/...) are resolved against it.
import os

wd = os.getcwd()
print(wd)

/Users/kgoldmann/Documents/Projects/AMBER/on_device_classifier/02_model_training/pytorch


In [None]:
from models.build_model import build_model
from data2.mothdataset import MOTHDataset
from training_params.loss import Loss
from training_params.optimizer import Optimizer
from data2 import dataloader
from evaluation.micro_accuracy_batch import MicroAccuracyBatch
from evaluation.micro_accuracy_batch import add_batch_microacc, final_microacc
from evaluation.macro_accuracy_batch import MacroAccuracyBatch
from evaluation.macro_accuracy_batch import add_batch_macroacc, final_macroacc, taxon_accuracy
from evaluation.confusion_matrix_data import confusion_matrix_data
from evaluation.confusion_data_conversion import ConfusionDataConvert

In [None]:
# Run-level settings: path to the experiment config, DataLoader worker count, RNG seed.
config_file = './configs/01_uk_moth_data_config.json'
dataloader_num_workers = 4
random_seed = 42

# Seed Python, NumPy and PyTorch (torch.manual_seed covers all devices) so weight
# initialisation and other random draws repeat between runs. Any RNG private to the
# data pipeline is not covered here.
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

In [None]:
# Load the experiment configuration and echo it so the run's settings are on record.
# The context manager closes the file handle once the JSON has been parsed.
with open(config_file) as config_fh:
	config_data = json.load(config_fh)
print(json.dumps(config_data, indent=3))

In [None]:
# Display the wandb project the run will be logged to.
wandb_settings = config_data['training']['wandb']
wandb_settings['project']

In [None]:
# Start a single wandb run. Project, entity and settings must all go to the same
# init call: a second wandb.init without project/entity either re-targets the run
# or is ignored, and start_method only takes effect when the run is created.
wandb.init(
	project=config_data['training']['wandb']['project'],
	entity=config_data['training']['wandb']['entity'],
	settings=wandb.Settings(start_method="fork"),
)

In [None]:
# Set up the data
image_resize  = config_data['training']['image_resize']
batch_size    = config_data['training']['batch_size']
label_list    = config_data['dataset']['label_info']
epochs        = config_data['training']['epochs']
loss_name     = config_data['training']['loss']['name']
early_stop    = config_data['training']['early_stopping']
start_val_los = config_data['training']['start_val_loss']

label_read    = json.load(open(label_list))
species_list  = label_read['species_list']
genus_list    = label_read['genus_list']
family_list   = label_read['family_list']

no_species_cl = config_data['model']['species_num_classes']
no_genus_cl   = config_data['model']['genus_num_classes']
no_family_cl  = config_data['model']['family_num_classes']
model_type    = config_data['model']['type']
preprocess_mode = config_data['model']['preprocess_mode']

opt_name      = config_data['training']['optimizer']['name']
learning_rate = config_data['training']['optimizer']['learning_rate']
momentum      = config_data['training']['optimizer']['momentum']

mod_save_pth  = config_data['training']['model_save_path']
mod_name      = config_data['training']['model_name']
mod_ver       = config_data['training']['version']
DTSTR         = datetime.datetime.now()
DTSTR         = DTSTR.strftime("%Y-%m-%d-%H-%M")
save_path     = mod_save_pth + mod_name + '_' + mod_ver + '_' + model_type + '_' + DTSTR + '.pt'

taxon_hierar  = config_data['dataset']['taxon_hierarchy']
label_info    = config_data['dataset']['label_info']

In [None]:
# Quick check: can PyTorch see a CUDA device?
cuda_visible = torch.cuda.is_available()
cuda_visible

In [None]:
# Choose the compute device. Apple-silicon MPS wins whenever it is available;
# otherwise CUDA is used if present, and the CPU is the final fallback.
mps_available  = torch.backends.mps.is_available()
cuda_available = torch.cuda.is_available()

if mps_available:
	device = "mps"
elif cuda_available:
	device = "cuda"
else:
	device = "cpu"

print(device)

In [None]:
# Build the network described by the config and move it to the chosen device.
model = build_model(config_data)

# nn.DataParallel replicates the model across CUDA GPUs, so it is only applied when
# training on CUDA with more than one visible GPU (torch.cuda.device_count() is a
# CUDA query and says nothing about MPS).
if device == "cuda" and torch.cuda.device_count() > 1:
	print("Let's use", torch.cuda.device_count(), "GPUs!")
	model = nn.DataParallel(model)

model = model.to(device)

In [None]:
# WebDataset shard URLs for each split (these point at the `test` dataset directory).
# The brace range {000000..000002} denotes shards 000000 through 000002.
# dataloader_num_workers is taken from the configuration cell above and is not
# re-assigned here, so that cell stays the single place to change it.
train_webdataset_url = "./data2/datasets/test/train/train-500-{000000..000002}.tar"
val_webdataset_url = "./data2/datasets/test/val/val-500-000000.tar"
test_webdataset_url = "./data2/datasets/test/test/test-500-000000.tar"

In [None]:
# Build the train / validation / test loaders. They share every setting except the
# shard URL and `is_training`, which is True only for the training split.
def _webdataset_loader(shard_url, training):
	"""Return dataloader.build_webdataset_pipeline(...) for one split using the shared settings."""
	return dataloader.build_webdataset_pipeline(
		sharedurl=shard_url,
		input_size=image_resize,
		batch_size=batch_size,
		is_training=training,
		num_workers=dataloader_num_workers,
		preprocess_mode=preprocess_mode)


train_dataloader = _webdataset_loader(train_webdataset_url, training=True)
val_dataloader   = _webdataset_loader(val_webdataset_url, training=False)
test_dataloader  = _webdataset_loader(test_webdataset_url, training=False)

In [None]:
# Loss and optimizer are looked up by the names given in the config.
loss_func = Loss(loss_name).func()
optimizer = Optimizer(opt_name, model, learning_rate, momentum).func()

# Early-stopping state for the training loop: best validation loss seen so far
# (seeded from the config's start_val_loss) and the count of non-improving epochs.
early_stp_count = 0
lowest_val_loss = start_val_los

In [None]:
# Sanity check: pull one batch from the training loader and move it to `device`
# before committing to the full training run.
image_batch, label_batch = next(iter(train_dataloader))
image_batch = image_batch.to(device, non_blocking=True)
label_batch = label_batch.to(device, non_blocking=True)

In [None]:
import timm

# Ad-hoc check that timm can build an EfficientNetV2-B3 with a 5-class head.
# pretrained=True loads pretrained weights (may need a download); the model is
# only displayed as the cell output, not kept in a variable.
timm.create_model(
	'tf_efficientnetv2_b3',
	pretrained=True,
	num_classes=5,
)

The training cell below takes 1–2 hours to run.

In [None]:
%load_ext autoreload
%reload_ext autoreload
%autoreload 2

from evaluation.confusion_matrix_data import confusion_matrix_data
from evaluation.confusion_data_conversion import ConfusionDataConvert
from evaluation.micro_accuracy_batch import MicroAccuracyBatch
from evaluation.micro_accuracy_batch import add_batch_microacc, final_microacc
from evaluation.macro_accuracy_batch import MacroAccuracyBatch
from evaluation.macro_accuracy_batch import add_batch_macroacc, final_macroacc, taxon_accuracy

for epoch in tqdm(range(epochs)):
	train_loss      = 0
	train_batch_cnt = 0
	val_loss        = 0
	val_batch_cnt   = 0
	s_time          = time.time()
	
	global_microacc_data_train = None
	global_microacc_data_val   = None
	
	# model training on training dataset
	model.train()                      
	for image_batch, label_batch in train_dataloader: 
		#print(image_batch)
		#print(label_batch)   
		#print(image_batch.to(device, non_blocking=True))
		image_batch,  label_batch = image_batch.to(device, non_blocking=True), label_batch.to(device, non_blocking=True)          
	
		optimizer.zero_grad()
		# forward + backward + optimize
		outputs   = model(image_batch)      
		t_loss    = loss_func(outputs, label_batch)
		t_loss.backward()
		optimizer.step()        
		train_loss += t_loss.item()
		
		# micro-accuracy calculation
		micro_accuracy_train          = MicroAccuracyBatch(outputs, label_batch, label_info, taxon_hierar).batch_accuracy()   
		global_microacc_data_train    = add_batch_microacc(global_microacc_data_train, micro_accuracy_train)
		train_batch_cnt += 1
	train_loss = train_loss/train_batch_cnt


	# model evaluation on validation dataset
	model.eval()                      
	for image_batch, label_batch in val_dataloader:
		image_batch, label_batch = image_batch.to(device, non_blocking=True), label_batch.to(device, non_blocking=True)        
	
		outputs   = model(image_batch)        
		v_loss    = loss_func(outputs, label_batch)
		val_loss += v_loss.item()    
	
		# micro-accuracy calculation
		micro_accuracy_val          = MicroAccuracyBatch(outputs, label_batch, label_info, taxon_hierar).batch_accuracy()   
		global_microacc_data_val    = add_batch_microacc(global_microacc_data_val, micro_accuracy_val)
		val_batch_cnt += 1
	val_loss = val_loss/val_batch_cnt

	if val_loss<lowest_val_loss:
		if torch.cuda.device_count() > 1:
			torch.save({
				'epoch': epoch,
				'model_state_dict': model.module.state_dict(),
				'optimizer_state_dict': optimizer.state_dict(),
				'train_loss': train_loss,
				'val_loss':val_loss}, 
				save_path)   
		else:
			torch.save({
				'epoch': epoch,
				'model_state_dict': model.state_dict(),
				'optimizer_state_dict': optimizer.state_dict(),
				'train_loss': train_loss,
				'val_loss':val_loss}, 
				save_path)    
			
		lowest_val_loss = val_loss
		early_stp_count = 0
	else:
		early_stp_count += 1

	# logging metrics
	wandb.log({'training loss': train_loss, 'validation loss': val_loss, 'epoch': epoch})

	final_micro_accuracy_train = final_microacc(global_microacc_data_train)
	final_micro_accuracy_val   = final_microacc(global_microacc_data_val) 
	wandb.log({'train_micro_species_top1': final_micro_accuracy_train['micro_species_top1'], 
			'train_micro_genus_top1': final_micro_accuracy_train['micro_genus_top1'],
			'train_micro_family_top1': final_micro_accuracy_train['micro_family_top1'],
			'val_micro_species_top1': final_micro_accuracy_val['micro_species_top1'], 
			'val_micro_genus_top1': final_micro_accuracy_val['micro_genus_top1'],
			'val_micro_family_top1': final_micro_accuracy_val['micro_family_top1'],
			'epoch': epoch   
			})   

	e_time = (time.time()-s_time)/60   # time taken in minutes    
	wandb.log({'time per epoch': e_time, 'epoch': epoch})

	if early_stp_count >= early_stop:
		break    



In [None]:
# Upload the best checkpoint as a wandb artifact. save_path is only written when the
# validation loss beat start_val_loss at least once, so check it exists first.
if os.path.exists(save_path):
	wandb.log_artifact(save_path, name=mod_name, type='models')
else:
	print(f"No checkpoint found at {save_path}: validation loss never improved on start_val_loss, so nothing was uploaded.")

In [None]:
# Every evaluation file shares this prefix: <model_save_path><model_name>_<version>
eval_file_prefix = mod_save_pth + mod_name + '_' + mod_ver


def _confusion_frame(prefix, confusion_data):
	"""Flatten one taxon level's accumulated (truth, prediction) pair into two columns."""
	return pd.DataFrame({
		prefix + '_Truth': confusion_data[0].reshape(-1),
		prefix + '_Prediction': confusion_data[1].reshape(-1)})


def _dump_json(obj, suffix):
	"""Write `obj` as JSON to `<eval_file_prefix>_<suffix>.json`."""
	with open(eval_file_prefix + '_' + suffix + '.json', 'w') as outfile:
		json.dump(obj, outfile)


# NOTE(review): this evaluates the weights left by the last epoch that ran, not the
# best checkpoint written to save_path; load that checkpoint first if best-model
# metrics are wanted.
model.eval()
global_microacc_data     = None
global_macroacc_data     = None
global_confusion_data_sp = None
global_confusion_data_g  = None
global_confusion_data_f  = None

print("Prediction on test data started ...")

with torch.no_grad():
	for image_batch, label_batch in test_dataloader:
		image_batch, label_batch = image_batch.to(device), label_batch.to(device)
		predictions              = model(image_batch)

		# micro-accuracy calculation
		micro_accuracy           = MicroAccuracyBatch(predictions, label_batch, label_info, taxon_hierar).batch_accuracy()
		global_microacc_data     = add_batch_microacc(global_microacc_data, micro_accuracy)

		# macro-accuracy calculation
		macro_accuracy           = MacroAccuracyBatch(predictions, label_batch, label_info, taxon_hierar).batch_accuracy()
		global_macroacc_data     = add_batch_macroacc(global_macroacc_data, macro_accuracy)

		# confusion matrix
		sp_label_batch, sp_predictions, g_label_batch, g_predictions, f_label_batch, f_predictions = ConfusionDataConvert(predictions, label_batch, label_info, taxon_hierar).converted_data()

		global_confusion_data_sp = confusion_matrix_data(global_confusion_data_sp, [sp_label_batch, sp_predictions])
		global_confusion_data_g  = confusion_matrix_data(global_confusion_data_g, [g_label_batch, g_predictions])
		global_confusion_data_f  = confusion_matrix_data(global_confusion_data_f, [f_label_batch, f_predictions])

# The accumulators stay None when the loader is empty; fail with a clear message.
if global_confusion_data_sp is None:
	raise RuntimeError('test_dataloader yielded no batches; check test_webdataset_url')

final_micro_accuracy            = final_microacc(global_microacc_data)
final_macro_accuracy, taxon_acc = final_macroacc(global_macroacc_data)
tax_accuracy                    = taxon_accuracy(taxon_acc, label_read)

# Save evaluation data to file. The confusion CSV uses the same
# <model_name>_<version> prefix as the JSON files.
confdata_pd = pd.concat([
	_confusion_frame('F', global_confusion_data_f),
	_confusion_frame('G', global_confusion_data_g),
	_confusion_frame('S', global_confusion_data_sp)], axis=1)
confdata_pd.to_csv(eval_file_prefix + '_confusion-data.csv', index=False)

_dump_json(final_micro_accuracy, 'micro-accuracy')
_dump_json(final_macro_accuracy, 'macro-accuracy')
_dump_json(tax_accuracy, 'taxon-accuracy')

wandb.log({'final micro accuracy' : final_micro_accuracy})
wandb.log({'final macro accuracy' : final_macro_accuracy})
wandb.log({'configuration' : config_data})
wandb.log({'tax accuracy' : tax_accuracy})

wandb.finish()
